Source code for flatland.envs.step_utils.state_machine
from flatland.envs.fast_methods import fast_position_equal
from flatland.envs.step_utils.states import TrainState, StateTransitionSignals
from flatland.envs.step_utils import env_utils
from flatland.envs.fast_methods import fast_position_equal
[docs]
class TrainStateMachine:
def __init__(self, initial_state=TrainState.WAITING):
self._initial_state = initial_state
self._state = initial_state
self.st_signals = StateTransitionSignals()
self.next_state = None
self.previous_state = None
def _handle_waiting(self):
"""" Waiting state goes to ready to depart when earliest departure is reached"""
# TODO: Important - The malfunction handling is not like proper state machine
# Both transition signals can happen at the same time
# Atleast mention it in the diagram
if self.st_signals.in_malfunction:
self.next_state = TrainState.MALFUNCTION_OFF_MAP
elif self.st_signals.earliest_departure_reached:
self.next_state = TrainState.READY_TO_DEPART
else:
self.next_state = TrainState.WAITING
def _handle_ready_to_depart(self):
""" Can only go to MOVING if a valid action is provided """
if self.st_signals.in_malfunction:
self.next_state = TrainState.MALFUNCTION_OFF_MAP
elif self.st_signals.valid_movement_action_given:
self.next_state = TrainState.MOVING
else:
self.next_state = TrainState.READY_TO_DEPART
def _handle_malfunction_off_map(self):
if self.st_signals.malfunction_counter_complete:
if self.st_signals.earliest_departure_reached:
if self.st_signals.valid_movement_action_given:
self.next_state = TrainState.MOVING
elif self.st_signals.stop_action_given:
self.next_state = TrainState.STOPPED
else:
self.next_state = TrainState.READY_TO_DEPART
else:
self.next_state = TrainState.WAITING
else:
self.next_state = TrainState.MALFUNCTION_OFF_MAP
def _handle_moving(self):
if self.st_signals.in_malfunction:
self.next_state = TrainState.MALFUNCTION
elif self.st_signals.target_reached:
self.next_state = TrainState.DONE
elif self.st_signals.stop_action_given or self.st_signals.movement_conflict:
self.next_state = TrainState.STOPPED
else:
self.next_state = TrainState.MOVING
def _handle_stopped(self):
if self.st_signals.in_malfunction:
self.next_state = TrainState.MALFUNCTION
elif self.st_signals.valid_movement_action_given:
self.next_state = TrainState.MOVING
else:
self.next_state = TrainState.STOPPED
def _handle_malfunction(self):
if self.st_signals.malfunction_counter_complete:
if self.st_signals.valid_movement_action_given:
self.next_state = TrainState.MOVING
else:
self.next_state = TrainState.STOPPED
else:
self.next_state = TrainState.MALFUNCTION
def _handle_done(self):
"""" Done state is terminal """
self.next_state = TrainState.DONE
[docs]
def calculate_next_state(self, current_state):
# _Handle the current state
if current_state == TrainState.WAITING:
self._handle_waiting()
elif current_state == TrainState.READY_TO_DEPART:
self._handle_ready_to_depart()
elif current_state == TrainState.MALFUNCTION_OFF_MAP:
self._handle_malfunction_off_map()
elif current_state == TrainState.MOVING:
self._handle_moving()
elif current_state == TrainState.STOPPED:
self._handle_stopped()
elif current_state == TrainState.MALFUNCTION:
self._handle_malfunction()
elif current_state == TrainState.DONE:
self._handle_done()
else:
raise ValueError(f"Got unexpected state {current_state}")
[docs]
def step(self):
""" Steps the state machine to the next state """
current_state = self._state
# Clear next state
self.clear_next_state()
# Handle current state to get next_state
self.calculate_next_state(current_state)
# Set next state
self.set_state(self.next_state)
[docs]
def clear_next_state(self):
self.next_state = None
[docs]
def set_state(self, state):
if not TrainState.check_valid_state(state):
raise ValueError(f"Cannot set invalid state {state}")
self.previous_state = self._state
self._state = state
[docs]
def reset(self):
self._state = self._initial_state
self.previous_state = None
self.st_signals = StateTransitionSignals()
self.clear_next_state()
[docs]
def update_if_reached(self, position, target):
# Need to do this hacky fix for now, state machine needed speed related states for proper handling
self.st_signals.target_reached = fast_position_equal(position, target)
if self.st_signals.target_reached:
self.next_state = TrainState.DONE
self.set_state(self.next_state)
@property
def state(self):
return self._state
@property
def state_transition_signals(self):
return self.st_signals
[docs]
def set_transition_signals(self, state_transition_signals):
self.st_signals = state_transition_signals
def __repr__(self):
return f"\n \
state: {str(self.state)} previous_state {str(self.previous_state)} \n \
st_signals: {self.st_signals}"
[docs]
def to_dict(self):
return {"state": self._state,
"previous_state": self.previous_state}
[docs]
def from_dict(self, load_dict):
self.set_state(load_dict['state'])
self.previous_state = load_dict['previous_state']
def __eq__(self, other):
return self._state == other._state and self.previous_state == other.previous_state