Source code for flatland.envs.step_utils.malfunction_handler

def get_number_of_steps_to_break(malfunction_generator, np_random):
    """Draw one malfunction and return how many steps it lasts.

    Args:
        malfunction_generator: Either an object exposing a
            ``generate(np_random)`` method, or a plain callable taking
            ``np_random``. In both cases the result must expose a
            ``num_broken_steps`` attribute.
        np_random: Passed through unchanged to the generator
            (presumably a random number source -- confirm with callers).

    Returns:
        The ``num_broken_steps`` attribute of the drawn malfunction.
    """
    # Prefer the ``generate`` method when present; otherwise the generator
    # itself is the callable.
    draw = getattr(malfunction_generator, "generate", malfunction_generator)
    return draw(np_random).num_broken_steps
class MalfunctionHandler:
    """Track a malfunction countdown and the number of malfunctions started.

    The handler stores two integers:

    * ``_malfunction_down_counter`` -- remaining steps of the current
      malfunction (``0`` means no malfunction is active);
    * ``num_malfunctions`` -- how many malfunctions with a positive duration
      have been started through this handler.
    """

    def __init__(self, malfunction_down_counter: int = 0, num_malfunctions: int = 0):
        """Initialise the handler with the given counter values.

        The values are stored as-is; no validation is applied here.
        """
        self._malfunction_down_counter = malfunction_down_counter
        self.num_malfunctions = num_malfunctions

    def reset(self) -> None:
        """Clear the countdown and the malfunction tally back to zero."""
        self._malfunction_down_counter = 0
        self.num_malfunctions = 0

    @property
    def in_malfunction(self) -> bool:
        """``True`` while the down counter is strictly positive."""
        return self._malfunction_down_counter > 0

    @property
    def malfunction_counter_complete(self) -> bool:
        """``True`` when the down counter is exactly zero."""
        return self._malfunction_down_counter == 0

    @property
    def malfunction_down_counter(self) -> int:
        """Remaining steps of the current malfunction."""
        return self._malfunction_down_counter

    @malfunction_down_counter.setter
    def malfunction_down_counter(self, val: int) -> None:
        self._set_malfunction_down_counter(val)

    def _set_malfunction_down_counter(self, val: int) -> None:
        """Start a malfunction of ``val`` steps if none is currently active.

        Args:
            val: Requested countdown length; must be non-negative.

        Raises:
            ValueError: If ``val`` is negative.
        """
        if val < 0:
            raise ValueError("Cannot set a negative value to malfunction down counter")
        # Only set new malfunction value if old malfunction is completed;
        # a request made while a malfunction is running is silently ignored.
        if self._malfunction_down_counter == 0:
            self._malfunction_down_counter = val
            # A zero-length request does not count as a malfunction.
            if val > 0:
                self.num_malfunctions += 1

    def generate_malfunction(self, malfunction_generator, np_random) -> None:
        """Draw a malfunction duration and try to start that malfunction.

        The duration comes from ``get_number_of_steps_to_break``; it is
        applied with the same rules as assigning ``malfunction_down_counter``.
        """
        num_broken_steps = get_number_of_steps_to_break(malfunction_generator, np_random)
        self._set_malfunction_down_counter(num_broken_steps)

    def update_counter(self) -> None:
        """Decrease the down counter by one step, never going below zero."""
        if self._malfunction_down_counter > 0:
            self._malfunction_down_counter -= 1

    def __repr__(self) -> str:
        return (
            f"MalfunctionHandler(\n"
            f"\tmalfunction_down_counter={self._malfunction_down_counter},\n"
            f"\tnum_malfunctions={self.num_malfunctions},\n"
            f")"
        )

    def to_dict(self) -> dict:
        """Return the handler state as a plain dictionary."""
        return {
            "malfunction_down_counter": self._malfunction_down_counter,
            "num_malfunctions": self.num_malfunctions,
        }

    def from_dict(self, load_dict) -> None:
        """Overwrite this handler's state from a ``to_dict``-style mapping.

        Raises:
            KeyError: If either expected key is missing from ``load_dict``.
        """
        self._malfunction_down_counter = load_dict['malfunction_down_counter']
        self.num_malfunctions = load_dict['num_malfunctions']

    def __eq__(self, other):
        """Compare both counters; defer for operands of any other type."""
        # Returning NotImplemented (instead of touching attributes that
        # ``other`` may not have) makes ``handler == None`` evaluate to
        # False rather than raising AttributeError.
        if not isinstance(other, MalfunctionHandler):
            return NotImplemented
        return (
            self._malfunction_down_counter == other._malfunction_down_counter
            and self.num_malfunctions == other.num_malfunctions
        )

    # Defining __eq__ already makes instances unhashable; spelled out here
    # because the handler is mutable and must not be used as a dict/set key.
    __hash__ = None