Source code for flatland.envs.rail_env_policies

from typing import List

from flatland.envs.agent_utils import EnvAgent
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_env_action import RailEnvActions
from flatland.envs.rail_env_policy import RailEnvPolicy
from flatland.envs.rail_env_shortest_paths import get_k_shortest_paths
from flatland.envs.rail_trainrun_data_structures import Waypoint
from flatland.envs.step_utils.states import TrainState


[docs] class ShortestPathPolicy(RailEnvPolicy[RailEnv, RailEnv, RailEnvActions]): def __init__(self): super().__init__() self._shortest_paths = {} def _act(self, env: RailEnv, agent: EnvAgent): if agent.position is None: return RailEnvActions.MOVE_FORWARD if len(self._shortest_paths[agent.handle]) == 0: return RailEnvActions.DO_NOTHING for a in {RailEnvActions.MOVE_FORWARD, RailEnvActions.MOVE_LEFT, RailEnvActions.MOVE_RIGHT}: new_cell_valid, (new_position, new_direction), transition_valid, preprocessed_action = env.rail.check_action_on_agent( RailEnvActions.from_value(a), (agent.position, agent.direction) ) if new_cell_valid and transition_valid and ( new_position == self._shortest_paths[agent.handle][1].position and new_direction == self._shortest_paths[agent.handle][1].direction): return a raise Exception("Invalid state")
[docs] def act_many(self, handles: List[int], observations: List[RailEnv], **kwargs): actions = {} for handle, env in zip(handles, observations): agent = env.agents[handle] self._update_agent(agent, env) actions[handle] = self._act(env, agent) return actions
def _update_agent(self, agent: EnvAgent, env: RailEnv): """ Update `_shortest_paths`. """ if agent.state == TrainState.DONE: self._shortest_paths.pop(agent.handle, None) return if agent.handle not in self._shortest_paths: p = [] for pp1, pp2 in zip(agent.waypoints, agent.waypoints[1:]): p1: Waypoint = pp1[0] p2: Waypoint = pp2[0] if len(p) > 0: assert p[-1] == p1, (p[-1], p1) pp_next = get_k_shortest_paths(None, p1.position, p1.direction, p2.position, rail=env.rail) p_next = None if p2.direction is None: p_next = pp_next[0] else: for _p_next in pp_next: if _p_next[-1].direction == p2.direction: p_next = _p_next break assert p_next is not None, f"Not found next path from {p1} to {p2}" if len(p) > 0: p += p_next[1:] else: p += p_next self._shortest_paths[agent.handle] = p if agent.position is None: return while self._shortest_paths[agent.handle][0].position != agent.position: self._shortest_paths[agent.handle] = self._shortest_paths[agent.handle][1:] assert self._shortest_paths[agent.handle][0].position == agent.position