# Source code for flatland.core.distance_map
import math
from typing import Dict, List, Optional, Generic, TypeVar, Callable
from flatland.core.transition_map import TransitionMap
from flatland.envs.agent_utils import EnvAgent
from flatland.envs.rail_grid_transition_map import RailGridTransitionMap
from flatland.envs.step_utils.states import TrainState
# Type parameters of AbstractDistanceMap.
# Transition map type accepted by ``reset``/``_compute``; bounded by TransitionMap.
UnderlyingTransitionMapType = TypeVar('UnderlyingTransitionMapType', bound=TransitionMap)
# Type of the stored distance map: what ``set`` accepts and ``get`` returns.
UnderlyingDistanceMapType = TypeVar('UnderlyingDistanceMapType')
# Type of a configuration: the input of ``waypoint_init`` and of ``_get_distance``/``_set_distance``.
UnderlyingConfigurationType = TypeVar('UnderlyingConfigurationType')
# Type of a waypoint: the output of ``waypoint_init`` and the element type of the shortest paths.
UnderlyingWaypointType = TypeVar('UnderlyingWaypointType')
[docs]
class AbstractDistanceMap(Generic[UnderlyingTransitionMapType, UnderlyingDistanceMapType, UnderlyingConfigurationType, UnderlyingWaypointType]):
    """
    Base class for distance maps that are computed lazily on ``get()``.

    The class keeps the agents, the rail (transition map) and the distance map itself.
    Concrete subclasses must provide ``_compute``, ``_set_distance`` and ``_get_distance``;
    the versions defined here only raise ``NotImplementedError``.
    """

    def __init__(self, agents: List[EnvAgent], waypoint_init: Callable[[UnderlyingConfigurationType], UnderlyingWaypointType]):
        """
        Parameters
        ----------
        agents : the agents the distance map is maintained for
        waypoint_init : callable turning a configuration into a waypoint; used to build the shortest paths
        """
        self.agents: List[EnvAgent] = agents
        self.waypoint_init = waypoint_init
        # The rail is only known after reset() has been called.
        self.rail: Optional[RailGridTransitionMap] = None
        self.distance_map = None
        # Not assigned anywhere in this class; only read in get().
        self.agents_previous_computation = None
        self.reset_was_called = False

    def set(self, distance_map: UnderlyingDistanceMapType):
        """
        Store the given distance map.
        """
        self.distance_map = distance_map

    def get(self) -> UnderlyingDistanceMapType:
        """
        Return the distance map, computing it first if necessary.

        After a reset the map is recomputed, unless a map is present while no previous
        computation is recorded (i.e. the map was loaded). Without a pending reset, the
        map is computed only if none is present yet.
        """
        if self.reset_was_called:
            # Consume the reset flag before (possibly) computing.
            self.reset_was_called = False
            # Don't compute the distance map if it was loaded
            was_loaded = self.agents_previous_computation is None and self.distance_map is not None
            needs_compute = not was_loaded
        else:
            needs_compute = self.distance_map is None

        if needs_compute:
            self._compute(self.agents, self.rail)
        return self.distance_map

    def reset(self, agents: List[EnvAgent], rail: UnderlyingTransitionMapType):
        """
        Register new agents and a new rail and flag that a reset happened.

        Nothing is computed here; the flag is evaluated by the next call to ``get()``.
        """
        self.agents: List[EnvAgent] = agents
        self.rail = rail
        self.reset_was_called = True

    # N.B. get_shortest_paths is not part of distance_map since it refers to RailEnvActions (would lead to circularity!)
    def get_shortest_paths(self, max_depth: Optional[int] = None, agent_handle: Optional[int] = None) -> Dict[int, Optional[List[UnderlyingWaypointType]]]:
        """
        Derive, for each agent, the shortest path to one of its targets from the distance map.

        Starting from the agent's configuration (the initial one while off map, the current
        one while on map), the path greedily follows the successor configuration with the
        smallest distance, as long as that distance is strictly smaller than the best one
        seen so far. If no such successor exists (rail disconnected or distance map
        incorrect), the path is given as None. Agents that are neither off map nor on map
        (e.g. done) get None as well.
        The agent state (moving or not) and its speed are not taken into account.

        example:
            agent_fixed_travel_paths = distance_map.get_shortest_paths(None, agent.handle)
            path = agent_fixed_travel_paths[agent.handle]

        Parameters
        ----------
        max_depth : max path length, if the shortest path is longer, it will be cut
        agent_handle : if set, only the path of ``self.agents[agent_handle]`` is returned, otherwise the paths of all agents

        Returns
        -------
        Dict[int, Optional[List[UnderlyingWaypointType]]]
            maps ``agent.handle`` to the list of waypoints, or to None if there is no path
        """

        def _trace_path(agent: EnvAgent) -> Optional[List[UnderlyingWaypointType]]:
            # Pick the start configuration according to the agent's state.
            if agent.state.is_off_map_state():
                current = agent.initial_configuration
            elif agent.state.is_on_map_state():
                current = agent.current_configuration
            else:
                # neither off map nor on map (e.g. done): no path
                return None

            waypoints = []
            # Kept across all steps: every step has to strictly decrease the distance.
            best_distance = math.inf
            steps = 0
            while current not in agent.targets:
                if max_depth is not None and steps >= max_depth:
                    break

                successor = None
                for candidate in self.rail.get_successor_configurations(current):
                    candidate_distance = self._get_distance(candidate, agent.handle)
                    if candidate_distance < best_distance:
                        best_distance = candidate_distance
                        successor = candidate

                waypoints.append(self.waypoint_init(current))
                steps += 1

                # if there is no way to continue, the rail must be disconnected!
                # (or distance map is incorrect)
                if successor is None:
                    return None
                current = successor

            # Add the final configuration only if the depth budget is not used up.
            if max_depth is None or steps < max_depth:
                waypoints.append(self.waypoint_init(current))
            return waypoints

        if agent_handle is None:
            selected_agents = self.agents
        else:
            selected_agents = [self.agents[agent_handle]]
        return {agent.handle: _trace_path(agent) for agent in selected_agents}

    def _compute(self, agents: List[EnvAgent], rail: UnderlyingTransitionMapType):
        """
        Compute the distance map for the given agents and rail; to be implemented by subclasses.
        """
        raise NotImplementedError()

    def _set_distance(self, configuration: UnderlyingConfigurationType, target_nr: int, new_distance: int):
        """
        Store ``new_distance`` for ``configuration`` and ``target_nr``; to be implemented by subclasses.
        """
        raise NotImplementedError()

    def _get_distance(self, configuration: UnderlyingConfigurationType, target_nr: int):
        """
        Return the distance for ``configuration`` and ``target_nr``; to be implemented by subclasses.

        ``get_shortest_paths`` passes the agent handle as ``target_nr``.
        """
        raise NotImplementedError()