Source code for flatland.envs.rail_env_policies

from typing import List

from flatland.envs.agent_utils import EnvAgent
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_env_action import RailEnvActions
from flatland.envs.rail_env_policy import RailEnvPolicy
from flatland.envs.rail_env_shortest_paths import get_k_shortest_paths


class ShortestPathPolicy(RailEnvPolicy[RailEnv, RailEnv, RailEnvActions]):
    """Policy that moves each agent along a path obtained from ``get_k_shortest_paths``.

    The "observation" handed to this policy is the ``RailEnv`` itself.  For every
    agent the policy caches the first path returned by ``get_k_shortest_paths``
    and, on each call, picks the movement action whose outcome matches the next
    element of that path.
    """

    def __init__(self):
        super().__init__()
        # agent handle -> path currently followed; path elements are read via
        # their ``.position`` and ``.direction`` attributes.
        self._shortest_paths = {}
        # agent handle -> waypoints (taken from ``agent.waypoints``) not yet ticked off.
        self._remaining_targets = {}

    @staticmethod
    def _path_from_position(path, position):
        """Return the tail of ``path`` starting at the first element located at ``position``.

        Raises:
            IndexError: if no element of ``path`` has the given position
                (including the case of an empty path).
        """
        for index, waypoint in enumerate(path):
            if waypoint.position == position:
                # One slice instead of repeatedly dropping the head element.
                return path[index:]
        raise IndexError(f"Position {position} is not on the planned path")

    def _act(self, env: RailEnv, agent: EnvAgent) -> RailEnvActions:
        """Choose the next action for a single agent.

        Args:
            env: environment the agent lives in; used for re-planning and for
                checking the outcome of candidate actions via ``env.rail``.
            agent: the agent to act for.  ``self._shortest_paths`` must already
                contain an entry for ``agent.handle`` (``act_many`` ensures this).

        Returns:
            ``MOVE_FORWARD`` while the agent has no position yet, ``DO_NOTHING``
            once no targets remain, otherwise the movement action that leads to
            the next target or to the next element of the current path.

        Raises:
            KeyError: if no path has been cached for ``agent.handle``.
            IndexError: if the agent's position is not on its cached path.
            Exception: if none of the candidate actions matches ("Invalid state").
        """
        if agent.position is None:
            return RailEnvActions.MOVE_FORWARD

        handle = agent.handle
        if handle not in self._remaining_targets:
            self._remaining_targets[handle] = agent.waypoints

        # The cached path is left untouched; only the local view is advanced to
        # where the agent currently is.
        shortest_path = self._path_from_position(self._shortest_paths[handle], agent.position)

        # NOTE(review): this compares ``agent.position`` with the target entry
        # itself, whereas the re-planning call below reads ``.position`` from the
        # same kind of entry -- confirm which of the two is intended.
        if agent.position == self._remaining_targets[handle][0]:
            self._remaining_targets[handle] = self._remaining_targets[handle][1:]
            if len(self._remaining_targets[handle]) > 0:
                # Re-plan towards the next target and keep following the *new*
                # path below; previously the local variable still pointed at the
                # old path, so the action lookup used stale path elements.
                shortest_path = get_k_shortest_paths(
                    env,
                    agent.position,
                    agent.direction,
                    self._remaining_targets[handle][0].position,
                )[0]
                self._shortest_paths[handle] = shortest_path

        if len(self._remaining_targets[handle]) == 0:
            return RailEnvActions.DO_NOTHING

        next_target = self._remaining_targets[handle][0]
        for a in {RailEnvActions.MOVE_FORWARD, RailEnvActions.MOVE_LEFT, RailEnvActions.MOVE_RIGHT}:
            new_cell_valid, (new_position, new_direction), transition_valid, _ = env.rail.check_action_on_agent(
                RailEnvActions.from_value(a),
                (agent.position, agent.direction)
            )
            if not (new_cell_valid and transition_valid):
                continue
            # ``shortest_path[1]`` is only evaluated when the target check fails,
            # exactly as in the original short-circuit expression.
            if new_position == next_target or (
                    new_position == shortest_path[1].position and new_direction == shortest_path[1].direction):
                return a
        raise Exception("Invalid state")

    def act_many(self, handles: List[int], observations: List[RailEnv], **kwargs):
        """Compute one action per agent handle.

        Args:
            handles: agent handles to act for.
            observations: one ``RailEnv`` per handle, paired positionally with
                ``handles``; the agent is looked up as ``env.agents[handle]``.
            **kwargs: accepted but not used here.

        Returns:
            Dict mapping each handle to the chosen ``RailEnvActions`` value.
        """
        actions = {}
        for handle, env in zip(handles, observations):
            agent = env.agents[handle]
            if agent.handle not in self._shortest_paths:
                # Plan lazily on first sight of the agent, from its initial
                # position and direction to its target.
                self._shortest_paths[agent.handle] = get_k_shortest_paths(
                    env, agent.initial_position, agent.initial_direction, agent.target
                )[0]
            actions[handle] = self._act(env, agent)
        return actions