import sys
import warnings
from typing import Tuple, NamedTuple, List, TypeVar, Generic, Optional, Set
import numpy as np
from attr import attrs, attrib, Factory
from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.envs.rail_trainrun_data_structures import Waypoint
from flatland.envs.step_utils.action_saver import ActionSaver
from flatland.envs.step_utils.malfunction_handler import MalfunctionHandler
from flatland.envs.step_utils.speed_counter import SpeedCounter, _pseudo_fractional
from flatland.envs.step_utils.state_machine import TrainStateMachine
from flatland.envs.step_utils.states import TrainState
from flatland.envs.timetable_utils import Line
from flatland.envs.timetable_utils import Timetable
[docs]
class Agent(NamedTuple):
initial_position: Tuple[int, int]
initial_direction: Grid4TransitionsEnum
direction: Grid4TransitionsEnum
target: Tuple[int, int]
moving: bool
earliest_departure: int
latest_arrival: int
handle: int
position: Tuple[int, int]
arrival_time: int
old_direction: Grid4TransitionsEnum
old_position: Tuple[int, int]
speed_counter: SpeedCounter
action_saver: ActionSaver
state_machine: TrainStateMachine
malfunction_handler: MalfunctionHandler
waypoints: List[List[Waypoint]] = None
waypoints_earliest_departure: List[int] = None
waypoints_latest_arrival: List[int] = None
[docs]
def load_env_agent(agent_tuple: Agent):
return EnvAgent(
initial_configuration=(agent_tuple.initial_position, agent_tuple.initial_direction),
current_configuration=(agent_tuple.position, agent_tuple.direction) if agent_tuple.position is not None and agent_tuple.direction is not None else None,
old_configuration=(
agent_tuple.old_position, agent_tuple.old_direction) if agent_tuple.old_position is not None and agent_tuple.old_direction is not None else None,
targets={(agent_tuple.target, d) for d in Grid4TransitionsEnum},
moving=agent_tuple.moving,
earliest_departure=agent_tuple.earliest_departure,
latest_arrival=agent_tuple.latest_arrival,
handle=agent_tuple.handle,
arrival_time=agent_tuple.arrival_time,
speed_counter=agent_tuple.speed_counter,
action_saver=agent_tuple.action_saver,
state_machine=agent_tuple.state_machine,
malfunction_handler=agent_tuple.malfunction_handler,
waypoints=agent_tuple.waypoints if agent_tuple.waypoints is not None else [Waypoint(agent_tuple.initial_position, agent_tuple.initial_direction),
Waypoint(agent_tuple.target, None)],
waypoints_earliest_departure=agent_tuple.waypoints_earliest_departure if agent_tuple.waypoints_earliest_departure is not None else [
agent_tuple.earliest_departure, None],
waypoints_latest_arrival=agent_tuple.waypoints_latest_arrival if agent_tuple.waypoints_latest_arrival is not None else [None,
agent_tuple.latest_arrival],
)
ConfigurationType = TypeVar('ConfigurationType')
[docs]
@attrs
class EnvAgent(Generic[ConfigurationType]):
# backwards compatibility
# TODO https://github.com/flatland-association/flatland-rl/issues/366 split grid special cases from general implementation
@property
def initial_position(self):
return self.initial_configuration[0]
@initial_position.setter
def initial_position(self, value):
self.initial_configuration = (value, self.initial_direction)
@property
def initial_direction(self):
return self.initial_configuration[1]
@initial_direction.setter
def initial_direction(self, value):
self.initial_configuration = (self.initial_position, value)
@property
def position(self):
if self.current_configuration is None:
return None
return self.current_configuration[0]
@position.setter
def position(self, value):
self.current_configuration = (value, self.direction)
@property
def direction(self):
if self.current_configuration is None:
return None
return self.current_configuration[1]
@direction.setter
def direction(self, value):
self.current_configuration = (self.position, value)
# used in rendering
@property
def old_position(self):
if self.old_configuration is None:
return None
return self.old_configuration[0]
@old_position.setter
def old_position(self, value):
self.old_configuration = (value, self.old_direction)
@property
def old_direction(self):
if self.old_configuration is None:
return None
return self.old_configuration[1]
@old_direction.setter
def old_direction(self, value):
self.old_configuration = (self.old_position, value)
@property
def target(self):
# assuming same cell for all
return list(self.targets)[0][0]
@target.setter
def target(self, value):
# backwards compatibility to mean any direction into the cell. Will valid directions be post-cleaned in _agents_from_line.
self.targets = {(value, d) for d in Grid4TransitionsEnum}
# INIT FROM HERE IN _from_line()
initial_configuration = attrib(type=ConfigurationType)
current_configuration = attrib(type=Optional[ConfigurationType], default=Factory(lambda: None))
targets = attrib(type=Set[ConfigurationType], default=Factory(lambda: set()))
moving = attrib(default=False, type=bool)
# NEW : EnvAgent - Schedule properties
earliest_departure = attrib(default=0, type=int)
latest_arrival = attrib(default=sys.maxsize, type=int)
# including initial and target, routing flexibility
waypoints = attrib(type=List[List[Waypoint]], default=Factory(lambda: [[]]))
# None at target, same for all in routing flexibility
waypoints_earliest_departure = attrib(type=List[int], default=Factory(lambda: []))
# None at initial, same for all in routing flexibility
waypoints_latest_arrival = attrib(type=List[int], default=Factory(lambda: []))
handle = attrib(default=None)
# INIT TILL HERE IN _from_line()
# Env step facelift
speed_counter = attrib(default=Factory(lambda: SpeedCounter(1.0)), type=SpeedCounter)
action_saver = attrib(default=Factory(lambda: ActionSaver()), type=ActionSaver)
state_machine = attrib(default=Factory(lambda: TrainStateMachine(initial_state=TrainState.WAITING)),
type=TrainStateMachine)
malfunction_handler = attrib(default=Factory(lambda: MalfunctionHandler()), type=MalfunctionHandler)
# NEW : EnvAgent Reward Handling
arrival_time = attrib(default=None, type=int)
old_configuration = attrib(type=Optional[ConfigurationType], default=Factory(lambda: None))
[docs]
def reset(self):
"""
Resets the agents to their initial values of the episode. Called after ScheduleTime generation.
"""
self.current_configuration = None
self.old_configuration = None
self.moving = False
self.arrival_time = None
self.malfunction_handler.reset()
self.action_saver.clear_saved_action()
self.speed_counter.reset()
self.state_machine.reset()
[docs]
def to_agent(self) -> Agent:
return Agent(initial_position=self.initial_position,
initial_direction=self.initial_direction,
direction=self.direction,
# N.B. not serialized, valid targets post-processed.
target=self.target,
moving=self.moving,
earliest_departure=self.earliest_departure,
latest_arrival=self.latest_arrival,
handle=self.handle,
position=self.position,
old_direction=self.old_direction,
old_position=self.old_position,
speed_counter=self.speed_counter,
action_saver=self.action_saver,
arrival_time=self.arrival_time,
state_machine=self.state_machine,
malfunction_handler=self.malfunction_handler,
waypoints=self.waypoints,
waypoints_earliest_departure=self.waypoints_earliest_departure,
waypoints_latest_arrival=self.waypoints_latest_arrival,
)
[docs]
def get_shortest_path(self, distance_map) -> List[Waypoint]:
return distance_map.get_shortest_paths(agent_handle=self.handle)[self.handle]
[docs]
def get_travel_time_on_shortest_path(self, distance_map) -> int:
shortest_path = self.get_shortest_path(distance_map)
if shortest_path is not None:
distance = len(shortest_path)
else:
distance = 0
speed = self.speed_counter.max_speed
return int(np.ceil(distance / speed))
[docs]
def get_time_remaining_until_latest_arrival(self, elapsed_steps: int) -> int:
return self.latest_arrival - elapsed_steps
[docs]
def get_current_delay(self, elapsed_steps: int, distance_map) -> int:
'''
+ve if arrival time is projected before latest arrival
-ve if arrival time is projected after latest arrival
'''
return self.get_time_remaining_until_latest_arrival(elapsed_steps) - \
self.get_travel_time_on_shortest_path(distance_map)
[docs]
@classmethod
def from_line(cls, line: Line):
""" Create a list of EnvAgent from lists of positions, directions and targets
"""
num_agents = len(line.agent_waypoints)
agent_list = []
for i_agent in range(num_agents):
speed = line.agent_speeds[i_agent] if line.agent_speeds is not None else 1.0
agent = EnvAgent(
initial_configuration=(line.agent_waypoints[i_agent][0][0].position, line.agent_waypoints[i_agent][0][0].direction),
# why
current_configuration=(line.agent_waypoints[i_agent][0][0].position, line.agent_waypoints[i_agent][0][0].direction),
old_configuration=None,
targets={(line.agent_waypoints[i_agent][-1][0].position, d) for d in Grid4TransitionsEnum},
waypoints=line.agent_waypoints[i_agent],
moving=False,
earliest_departure=None,
latest_arrival=None,
waypoints_earliest_departure=None,
waypoints_latest_arrival=None,
handle=i_agent,
speed_counter=SpeedCounter(speed=speed))
agent_list.append(agent)
return agent_list
[docs]
@staticmethod
def to_line(agents: List["EnvAgent"]):
return Line(
agent_waypoints={agent.handle: agent.waypoints for agent in agents},
agent_speeds={agent.handle: agent.speed_counter.max_speed for agent in agents},
)
[docs]
@classmethod
def load_legacy_static_agent(cls, static_agents_data: Tuple):
agents = []
for i, static_agent in enumerate(static_agents_data):
if len(static_agent) >= 6:
speed = static_agent[4]['speed']
speed = _pseudo_fractional(speed)
agent = EnvAgent(
initial_configuration=(static_agent[0], static_agent[1]),
current_configuration=(static_agent[0], static_agent[1]),
old_configuration=None,
# N.B. valid targets cleaned in _agents_from_line
targets={(static_agent[2], d) for d in Grid4TransitionsEnum},
moving=static_agent[3],
speed_counter=SpeedCounter(speed), handle=i,
)
else:
agent = EnvAgent(
initial_configuration=(static_agent[0], static_agent[1]),
current_configuration=(static_agent[0], static_agent[1]),
old_configuration=None,
# N.B. valid targets cleaned in _agents_from_line
targets={(static_agent[2], d) for d in Grid4TransitionsEnum},
moving=False,
speed_counter=SpeedCounter(1.0),
handle=i,
)
agents.append(agent)
return agents
def __str__(self):
return (
f"EnvAgent(\n"
f"\thandle={self.handle},\n"
f"\tinitial_position={self.initial_position},\n"
f"\tinitial_direction={self.initial_direction},\n"
f"\tposition={self.position},\n"
f"\tdirection={self.direction if self.direction is None else Grid4TransitionsEnum(self.direction).value},\n"
f"\ttarget={self.target},\n"
f"\ttargets={self.targets},\n"
f"\told_position={self.old_position},\n"
f"\told_direction={self.old_direction},\n"
f"\tearliest_departure={self.earliest_departure},\n"
f"\tlatest_arrival={self.latest_arrival},\n"
f"\tstate_machine={str(self.state_machine)},\n"
f"\tmalfunction_handler={self.malfunction_handler},\n"
f"\twaypoints={self.waypoints},\n"
f"\twaypoints_earliest_departure={self.waypoints_earliest_departure},\n"
f"\twaypoints_latest_arrival={self.waypoints_latest_arrival},\n"
f")"
)
@property
def state(self):
return self.state_machine.state
@state.setter
def state(self, state):
self._set_state(state)
def _set_state(self, state):
warnings.warn("Not recommended to set the state with this function unless completely required")
self.state_machine.set_state(state)
@property
def malfunction_data(self):
raise ValueError("agent.malunction_data is deprecated, please use agent.malfunction_hander instead")
@property
def speed_data(self):
raise ValueError("agent.speed_data is deprecated, please use agent.speed_counter instead")