Source code for flatland.envs.graph.distance_map
from collections import defaultdict
from typing import List, Dict
import numpy as np
from flatland.core.distance_map import AbstractDistanceMap
from flatland.core.distance_map_walker import DistanceMapWalker
from flatland.envs.agent_utils import EnvAgent
from flatland.envs.graph.rail_graph_transition_map import GraphTransitionMap
[docs]
class GraphDistanceMap(AbstractDistanceMap[GraphTransitionMap, Dict[int, Dict[str, int]], str, str]):
    """Distance map for graph-based rail transition maps.

    Distances are held in a two-level mapping that is indexed first by a
    target number and then by a configuration string (see ``_set_distance``).
    Both levels are ``defaultdict`` instances, so an entry that was never
    written reads as ``np.inf``.
    """

    def __init__(self, agents: List[EnvAgent]) -> None:
        """Initialise the base distance map with ``str`` as waypoint initialiser.

        Parameters
        ----------
        agents
            Agents forwarded unchanged to the base class constructor.
        """
        super().__init__(agents=agents, waypoint_init=str)

    def _compute(self, agents: List[EnvAgent], rail: GraphTransitionMap) -> None:
        """Rebuild the distance map for ``agents`` on ``rail``.

        The walker is run once per distinct ``agent.targets`` value. Agents
        whose targets equal those of an earlier agent reuse that agent's
        already computed inner map instead of walking the graph again.

        All entries are keyed by ``agent.handle``: the same key that is handed
        to the walker is used when storing and looking up reused entries, so
        the result stays consistent even if an agent's handle differs from its
        position in ``agents``.

        Parameters
        ----------
        agents
            Agents to compute distances for.
        rail
            Transition map passed through to the distance map walker.
        """
        self.agents_previous_computation = self.agents
        # Missing targets / configurations read as "unreachable" (np.inf).
        self.distance_map = defaultdict(lambda: defaultdict(lambda: np.inf))
        distance_map_walker = DistanceMapWalker[GraphDistanceMap, GraphTransitionMap, str](self)

        # Parallel lists instead of a dict: membership is tested with ``==``
        # because it is not visible here whether ``agent.targets`` is hashable.
        seen_targets = []
        seen_handles = []
        for agent in agents:
            if agent.targets in seen_targets:
                # Same targets as an earlier agent: reuse its inner map.
                # The inner map object is shared, not copied.
                source_handle = seen_handles[seen_targets.index(agent.targets)]
                self.distance_map[agent.handle] = self.distance_map[source_handle]
            else:
                # NOTE(review): the walker presumably fills the map for this
                # handle via ``_set_distance`` - confirm in DistanceMapWalker.
                distance_map_walker._distance_map_walker(rail, agent.handle, agent.targets)
                seen_targets.append(agent.targets)
                seen_handles.append(agent.handle)

    def _set_distance(self, configuration: str, target_nr: int, new_distance: int) -> None:
        """Store ``new_distance`` for ``configuration`` under ``target_nr``."""
        self.distance_map[target_nr][configuration] = new_distance

    def _get_distance(self, configuration: str, target_nr: int):
        """Return the distance stored for ``configuration`` under ``target_nr``.

        The map is obtained through ``self.get()`` (defined in the base class)
        rather than by reading ``self.distance_map`` directly.
        """
        return self.get()[target_nr][configuration]