Source code for flatland.envs.step_utils.state_machine

from flatland.envs.fast_methods import fast_position_equal
from flatland.envs.step_utils.states import TrainState, StateTransitionSignals
from flatland.envs.step_utils import env_utils
from flatland.envs.fast_methods import fast_position_equal

class TrainStateMachine:
    """Per-train state machine driven by ``StateTransitionSignals``.

    Each call to :meth:`step` reads the signals currently stored in
    ``self.st_signals`` and moves ``self._state`` to the next ``TrainState``
    according to the ``_handle_<state>`` methods below.
    """

    def __init__(self, initial_state=TrainState.WAITING):
        # State restored by reset().
        self._initial_state = initial_state
        # Current committed state.
        self._state = initial_state
        # Signals consulted by the _handle_* methods; replaced via
        # set_transition_signals() or reset().
        self.st_signals = StateTransitionSignals()
        # Scratch slot written by the _handle_* methods during step().
        self.next_state = None
        # State held before the most recent set_state() call.
        self.previous_state = None

    def _handle_waiting(self):
        """Waiting state goes to ready to depart when earliest departure is reached."""
        # TODO: Important - The malfunction handling is not like proper state machine
        # Both transition signals can happen at the same time
        # Atleast mention it in the diagram
        if self.st_signals.in_malfunction:
            self.next_state = TrainState.MALFUNCTION_OFF_MAP
        elif self.st_signals.earliest_departure_reached:
            self.next_state = TrainState.READY_TO_DEPART
        else:
            self.next_state = TrainState.WAITING

    def _handle_ready_to_depart(self):
        """Can only go to MOVING if a valid action is provided."""
        if self.st_signals.in_malfunction:
            self.next_state = TrainState.MALFUNCTION_OFF_MAP
        elif self.st_signals.valid_movement_action_given:
            self.next_state = TrainState.MOVING
        else:
            self.next_state = TrainState.READY_TO_DEPART

    def _handle_malfunction_off_map(self):
        """Stay in MALFUNCTION_OFF_MAP until the malfunction counter completes.

        Once complete: WAITING if earliest departure is not yet reached,
        otherwise MOVING (valid movement action), STOPPED (stop action) or
        READY_TO_DEPART (no action).
        """
        if not self.st_signals.malfunction_counter_complete:
            self.next_state = TrainState.MALFUNCTION_OFF_MAP
        elif not self.st_signals.earliest_departure_reached:
            self.next_state = TrainState.WAITING
        elif self.st_signals.valid_movement_action_given:
            self.next_state = TrainState.MOVING
        elif self.st_signals.stop_action_given:
            self.next_state = TrainState.STOPPED
        else:
            self.next_state = TrainState.READY_TO_DEPART

    def _handle_moving(self):
        """Malfunction takes priority, then reaching the target, then stopping."""
        if self.st_signals.in_malfunction:
            self.next_state = TrainState.MALFUNCTION
        elif self.st_signals.target_reached:
            self.next_state = TrainState.DONE
        elif self.st_signals.stop_action_given or self.st_signals.movement_conflict:
            self.next_state = TrainState.STOPPED
        else:
            self.next_state = TrainState.MOVING

    def _handle_stopped(self):
        """Resume MOVING on a valid movement action unless a malfunction occurs."""
        if self.st_signals.in_malfunction:
            self.next_state = TrainState.MALFUNCTION
        elif self.st_signals.valid_movement_action_given:
            self.next_state = TrainState.MOVING
        else:
            self.next_state = TrainState.STOPPED

    def _handle_malfunction(self):
        """Stay in MALFUNCTION until the counter completes, then MOVING or STOPPED."""
        if self.st_signals.malfunction_counter_complete:
            if self.st_signals.valid_movement_action_given:
                self.next_state = TrainState.MOVING
            else:
                self.next_state = TrainState.STOPPED
        else:
            self.next_state = TrainState.MALFUNCTION

    def _handle_done(self):
        """Done state is terminal."""
        self.next_state = TrainState.DONE

    def calculate_next_state(self, current_state):
        """Compute ``self.next_state`` for ``current_state`` from the current signals.

        Raises:
            ValueError: if ``current_state`` matches none of the handled states.
        """
        if current_state == TrainState.WAITING:
            self._handle_waiting()
        elif current_state == TrainState.READY_TO_DEPART:
            self._handle_ready_to_depart()
        elif current_state == TrainState.MALFUNCTION_OFF_MAP:
            self._handle_malfunction_off_map()
        elif current_state == TrainState.MOVING:
            self._handle_moving()
        elif current_state == TrainState.STOPPED:
            self._handle_stopped()
        elif current_state == TrainState.MALFUNCTION:
            self._handle_malfunction()
        elif current_state == TrainState.DONE:
            self._handle_done()
        else:
            raise ValueError(f"Got unexpected state {current_state}")

    def step(self):
        """Steps the state machine to the next state.

        Clears ``next_state``, computes it from the current state and
        ``st_signals``, then commits it via :meth:`set_state`.
        """
        current_state = self._state
        self.clear_next_state()
        self.calculate_next_state(current_state)
        self.set_state(self.next_state)

    def clear_next_state(self):
        """Reset the ``next_state`` scratch slot to ``None``."""
        self.next_state = None

    def set_state(self, state):
        """Commit ``state`` as current, remembering the old one in ``previous_state``.

        Raises:
            ValueError: if ``TrainState.check_valid_state(state)`` is falsy.
        """
        if not TrainState.check_valid_state(state):
            raise ValueError(f"Cannot set invalid state {state}")
        self.previous_state = self._state
        self._state = state

    def reset(self):
        """Return to the initial state with fresh signals and no history."""
        self._state = self._initial_state
        self.previous_state = None
        self.st_signals = StateTransitionSignals()
        self.clear_next_state()

    def update_if_reached(self, position, target):
        """Record whether ``position`` equals ``target``; if so, jump straight to DONE.

        This bypasses the normal ``step()`` transition logic.
        """
        # Need to do this hacky fix for now, state machine needed speed related
        # states for proper handling
        self.st_signals.target_reached = fast_position_equal(position, target)
        if self.st_signals.target_reached:
            self.next_state = TrainState.DONE
            self.set_state(self.next_state)

    @property
    def state(self):
        """Current committed ``TrainState``."""
        return self._state

    @property
    def state_transition_signals(self):
        """The ``StateTransitionSignals`` object consulted on the next step."""
        return self.st_signals

    def set_transition_signals(self, state_transition_signals):
        """Replace the signals object used by subsequent transitions."""
        self.st_signals = state_transition_signals

    def __repr__(self):
        # str() is kept explicit: format() of an IntEnum member may render the
        # bare integer on older Pythons, whereas str() gives the member name.
        return (
            f"\n state: {str(self.state)} previous_state {str(self.previous_state)} \n"
            f" st_signals: {self.st_signals}"
        )

    def to_dict(self):
        """Serialize the current and previous state (signals are not included)."""
        return {"state": self._state, "previous_state": self.previous_state}

    def from_dict(self, load_dict):
        """Restore state and previous_state from a dict produced by :meth:`to_dict`.

        Raises:
            ValueError: if ``load_dict['state']`` is not a valid state.
        """
        self.set_state(load_dict['state'])
        # set_state() overwrote previous_state; restore the serialized one.
        self.previous_state = load_dict['previous_state']

    def __eq__(self, other):
        """Machines are equal when their current and previous states match."""
        # Returning NotImplemented (rather than raising AttributeError on
        # other._state) lets comparisons with unrelated objects fall back
        # to Python's default and evaluate to False.
        if not isinstance(other, TrainStateMachine):
            return NotImplemented
        return self._state == other._state and self.previous_state == other.previous_state