Source code for flatland.envs.step_utils.state_machine
from flatland.envs.fast_methods import fast_position_equal
from flatland.envs.step_utils.states import TrainState, StateTransitionSignals
[docs]
class TrainStateMachine:
    """State machine tracking the ``TrainState`` of a single agent.

    The next state is computed from the current state together with the flags
    held in ``st_signals`` (a ``StateTransitionSignals`` instance, replaceable
    via ``set_transition_signals``). ``step`` applies one transition;
    ``update_if_reached`` forces ``DONE`` when the agent's position equals its
    target.
    """

    def __init__(self, initial_state=TrainState.WAITING):
        # Remembered so that reset() can restore it.
        self._initial_state = initial_state
        self._state = initial_state
        self.st_signals = StateTransitionSignals()
        # Scratch slot filled by the _handle_* methods and consumed by step().
        self.next_state = None
        # State held before the most recent set_state() call.
        self.previous_state = None

    def _handle_waiting(self):
        """Compute the successor of WAITING.

        A malfunction takes priority and sends the agent to
        MALFUNCTION_OFF_MAP; otherwise the agent becomes READY_TO_DEPART once
        its earliest departure is reached, and stays WAITING before that.
        """
        if self.st_signals.in_malfunction:
            self.next_state = TrainState.MALFUNCTION_OFF_MAP
        elif self.st_signals.earliest_departure_reached:
            self.next_state = TrainState.READY_TO_DEPART
        else:
            self.next_state = TrainState.WAITING

    def _handle_ready_to_depart(self):
        """Compute the successor of READY_TO_DEPART.

        A malfunction sends the agent to MALFUNCTION_OFF_MAP. The agent only
        goes to MOVING when a movement action is given AND movement is
        allowed; otherwise it stays READY_TO_DEPART.
        """
        if self.st_signals.in_malfunction:
            self.next_state = TrainState.MALFUNCTION_OFF_MAP
        elif self.st_signals.movement_action_given and self.st_signals.movement_allowed:
            self.next_state = TrainState.MOVING
        else:
            self.next_state = TrainState.READY_TO_DEPART

    def _handle_malfunction_off_map(self):
        """Compute the successor of MALFUNCTION_OFF_MAP.

        While the malfunction lasts the state is kept. Once it ends, an agent
        whose earliest departure has not been reached returns to WAITING;
        otherwise it goes to MOVING / STOPPED (movement or stop action given
        and movement allowed) or falls back to READY_TO_DEPART.
        """
        if not self.st_signals.in_malfunction:
            if self.st_signals.earliest_departure_reached:
                # TODO revise design: should we not go to the READY_TO_DEPART first instead of directly to MOVING and STOPPED?
                if self.st_signals.movement_action_given and self.st_signals.movement_allowed:
                    self.next_state = TrainState.MOVING
                elif self.st_signals.stop_action_given and self.st_signals.movement_allowed:
                    self.next_state = TrainState.STOPPED
                else:
                    self.next_state = TrainState.READY_TO_DEPART
            else:
                self.next_state = TrainState.WAITING
        else:
            self.next_state = TrainState.MALFUNCTION_OFF_MAP

    def _handle_moving(self):
        """Compute the successor of MOVING.

        Malfunction -> MALFUNCTION; target reached -> DONE; a stop action that
        brings the speed to zero, or movement not being allowed -> STOPPED;
        otherwise the agent keeps MOVING.
        """
        if self.st_signals.in_malfunction:
            self.next_state = TrainState.MALFUNCTION
        elif self.st_signals.target_reached:
            # this branch is never used as target reached is not handled by state_machine.step() but by state_machine.update_if_reached()!
            self.next_state = TrainState.DONE
        elif (self.st_signals.stop_action_given and self.st_signals.new_speed == 0.0) or not self.st_signals.movement_allowed:
            self.next_state = TrainState.STOPPED
        else:
            self.next_state = TrainState.MOVING

    def _handle_stopped(self):
        """Compute the successor of STOPPED.

        Malfunction -> MALFUNCTION; movement action given and movement allowed
        -> MOVING; otherwise the agent stays STOPPED.
        """
        if self.st_signals.in_malfunction:
            self.next_state = TrainState.MALFUNCTION
        elif self.st_signals.movement_action_given and self.st_signals.movement_allowed:
            self.next_state = TrainState.MOVING
        else:
            self.next_state = TrainState.STOPPED

    def _handle_malfunction(self):
        """Compute the successor of MALFUNCTION (on map).

        While the malfunction lasts the state is kept. Once it ends the agent
        goes to MOVING if a movement action is given and movement is allowed,
        otherwise to STOPPED.
        """
        if not self.st_signals.in_malfunction:
            if self.st_signals.movement_action_given and self.st_signals.movement_allowed:
                self.next_state = TrainState.MOVING
            else:
                self.next_state = TrainState.STOPPED
        else:
            self.next_state = TrainState.MALFUNCTION

    def _handle_done(self):
        """DONE is terminal: the successor is always DONE."""
        self.next_state = TrainState.DONE

    def calculate_next_state(self, current_state):
        """Store in ``self.next_state`` the successor of ``current_state``.

        Raises
        ------
        ValueError
            If ``current_state`` is not one of the handled ``TrainState`` values.
        """
        # _Handle the current state
        if current_state == TrainState.WAITING:
            self._handle_waiting()
        elif current_state == TrainState.READY_TO_DEPART:
            self._handle_ready_to_depart()
        elif current_state == TrainState.MALFUNCTION_OFF_MAP:
            self._handle_malfunction_off_map()
        elif current_state == TrainState.MOVING:
            self._handle_moving()
        elif current_state == TrainState.STOPPED:
            self._handle_stopped()
        elif current_state == TrainState.MALFUNCTION:
            self._handle_malfunction()
        elif current_state == TrainState.DONE:
            self._handle_done()
        else:
            raise ValueError(f"Got unexpected state {current_state}")

    def step(self):
        """ Steps the state machine to the next state """
        current_state = self._state
        # Clear next state
        self.clear_next_state()
        # Handle current state to get next_state
        self.calculate_next_state(current_state)
        # Set next state
        self.set_state(self.next_state)

    def clear_next_state(self):
        """Reset the scratch ``next_state`` slot to ``None``."""
        self.next_state = None

    def set_state(self, state):
        """Move to ``state``, remembering the current one as ``previous_state``.

        Raises
        ------
        ValueError
            If ``TrainState.check_valid_state(state)`` is falsy.
        """
        if not TrainState.check_valid_state(state):
            raise ValueError(f"Cannot set invalid state {state}")
        self.previous_state = self._state
        self._state = state

    def reset(self):
        """Restore the initial state and clear history, signals and next_state."""
        self._state = self._initial_state
        self.previous_state = None
        self.st_signals = StateTransitionSignals()
        self.clear_next_state()

    def update_if_reached(self, position, target):
        """Record whether ``position`` equals ``target`` and, if so, switch to DONE.

        The result is stored in ``st_signals.target_reached``; the state is
        only changed when the target is reached.
        """
        # Need to do this hacky fix for now, state machine needed speed related states for proper handling
        self.st_signals.target_reached = fast_position_equal(position, target)
        if self.st_signals.target_reached:
            self.next_state = TrainState.DONE
            self.set_state(self.next_state)

    @staticmethod
    def can_get_moving_independent(state: TrainState, in_malfunction: bool, movement_action_given: bool, new_speed: float, stop_action_given: bool):
        """
        Incoming transitions to go into state MOVING (for motions to be checked - independently of other agents' position):
        - keep MOVING unless (stop action given and new speed is zero) or in malfunction
        - from MALFUNCTION (on map): if the malfunction has ended and movement action given
        - from STOPPED: if movement action given and not in malfunction

        Parameters
        ----------
        state : TrainState
        in_malfunction : bool
        movement_action_given : bool
        new_speed : float
        stop_action_given : bool

        Returns
        -------
        Whether the agent wants to move given its state (independently of other agents' position)
        """
        can_get_moving = state == TrainState.MOVING and not (stop_action_given and new_speed == 0.0)
        # malfunction ends and (explicit) movement action given
        can_get_moving |= state == TrainState.MALFUNCTION and not in_malfunction and movement_action_given
        can_get_moving |= state == TrainState.STOPPED and movement_action_given
        # An active malfunction vetoes every incoming transition above.
        can_get_moving &= not in_malfunction
        return can_get_moving

    @property
    def state(self):
        """The current ``TrainState``."""
        return self._state

    @property
    def state_transition_signals(self):
        """The ``StateTransitionSignals`` currently used for transitions."""
        return self.st_signals

    def set_transition_signals(self, state_transition_signals):
        """Replace the signals used by subsequent transitions."""
        self.st_signals = state_transition_signals

    def state_position_sync_check(self, position, i_agent, remove_agents_at_target):
        """Raise ValueError if the state disagrees with whether ``position`` is None.

        On-map states require a position, off-map states require ``None``, and
        DONE requires ``None`` when ``remove_agents_at_target`` is set.
        """
        if self.state.is_on_map_state() and position is None:
            raise ValueError("Agent ID {} Agent State {} is on map Agent Position {} if off map ".format(
                i_agent, str(self.state), str(position)))
        elif self.state.is_off_map_state() and position is not None:
            raise ValueError("Agent ID {} Agent State {} is off map Agent Position {} if on map ".format(
                i_agent, str(self.state), str(position)))
        elif self.state == TrainState.DONE and remove_agents_at_target and position is not None:
            raise ValueError("Agent ID {} Agent State {} is not None Agent Position {} if remove_agents_at_target".format(
                i_agent, str(self.state), str(position)))

    def __repr__(self):
        return (
            f"TrainStateMachine(\n"
            f"\tstate={str(self.state)},\n"
            f"\tprevious_state={str(self.previous_state) if self.previous_state is not None else None},\n"
            f"\tst_signals={self.st_signals}\n"
            f")"
        )

    def to_dict(self):
        """Serialize the current and previous state (signals are not included)."""
        return {"state": self._state,
                "previous_state": self.previous_state}

    @staticmethod
    def from_dict(load_dict) -> "TrainStateMachine":
        """Build a state machine from the output of ``to_dict``."""
        sm = TrainStateMachine()
        sm.set_state(load_dict['state'])
        # set_state() overwrote previous_state with the default initial state; restore the saved one.
        sm.previous_state = load_dict['previous_state']
        return sm

    def __eq__(self, other):
        """Equal when current and previous states match; signals are ignored."""
        # Previously any non-TrainStateMachine operand raised AttributeError.
        if not isinstance(other, TrainStateMachine):
            return NotImplemented
        return self._state == other._state and self.previous_state == other.previous_state

    # Defining __eq__ already disables hashing; state explicitly since instances are mutable.
    __hash__ = None