Source code for flatland.envs.rail_grid_transition_map

from functools import lru_cache
from typing import Set, List, Optional
from typing import Tuple

import numpy as np

from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.core.grid.grid4_utils import get_new_position
from flatland.core.grid.grid_utils import IntVector2D
from flatland.core.transition_map import GridTransitionMap
from flatland.core.transitions import Transitions
from flatland.envs.fast_methods import fast_argmax, fast_count_nonzero
from flatland.envs.grid.rail_env_grid import RailEnvTransitions
from flatland.envs.rail_env_action import RailEnvActions
from flatland.utils.ordered_set import OrderedSet


class RailGridTransitionMap(GridTransitionMap[RailEnvActions]):
    """
    Grid transition map for rail environments.

    A configuration is a ``((row, col), direction)`` pair, where ``direction`` is
    an int in ``0..3``. In this class, ``(direction + 2) % 4`` is the reverse
    direction and ``(direction - 1) % 4`` / ``(direction + 1) % 4`` are
    left/right turns.

    NOTE(review): several methods are memoized with ``functools.lru_cache``, which
    keys on ``self``. Cached entries keep the instance alive, and results would be
    stale if the grid were modified after a call. Confirm that callers clear the
    caches (``<method>.cache_clear()``) when the grid is rebuilt.
    """

    def __init__(self, width, height, transitions: Transitions = RailEnvTransitions(), grid: Optional[np.ndarray] = None):
        """
        Parameters
        ----------
        width : int
            Grid width, forwarded to ``GridTransitionMap``.
        height : int
            Grid height, forwarded to ``GridTransitionMap``.
        transitions : Transitions
            Transition set used to decode cell values; defaults to ``RailEnvTransitions()``.
        grid : Optional[np.ndarray]
            Optional pre-built grid, forwarded to ``GridTransitionMap``.
        """
        # NOTE(review): the default ``RailEnvTransitions()`` is evaluated once, at class
        # definition time, so every instance built without ``transitions`` shares the
        # same object. This is presumably intended (it acts as a lookup table); confirm.
        super().__init__(width=width, height=height, transitions=transitions, grid=grid)

    @lru_cache(maxsize=4_000_000)
    def get_successor_configurations(self, configuration: Tuple[Tuple[int, int], int]) -> Set[Tuple[Tuple[int, int], int]]:
        """
        Return the configurations reachable in one step from ``configuration``.

        MOVE_LEFT, MOVE_FORWARD and MOVE_RIGHT are tried in that order via
        :meth:`apply_action_independent`. Every resulting configuration that is
        valid is kept.

        Parameters
        ----------
        configuration : Tuple[Tuple[int, int], int]
            ``((row, col), direction)``.

        Returns
        -------
        OrderedSet
            Successor configurations in insertion order (left, forward, right).
            Duplicates collapse, e.g. when a left/right request is corrected to
            forward. The same (cached) object is returned on repeated calls, so
            callers should not mutate it.
        """
        position, direction = configuration
        successors = OrderedSet()
        for action in [RailEnvActions.MOVE_LEFT, RailEnvActions.MOVE_FORWARD, RailEnvActions.MOVE_RIGHT]:
            t = self.apply_action_independent(action, (position, direction))
            if t is not None:
                new_configuration, _ = t
                if self.is_valid_configuration(new_configuration):
                    successors.add(new_configuration)
        return successors

    @lru_cache(maxsize=4_000_000)
    def get_predecessor_configurations(self, configuration: Tuple[Tuple[int, int], int]) -> Set[Tuple[Tuple[int, int], int]]:
        """
        Return the configurations from which ``configuration`` is reached in one step.

        The only candidate cell is the neighbour in the reverse direction
        ``(direction + 2) % 4``. Every orientation in that cell whose transitions
        allow leaving in ``direction`` is a predecessor.

        Parameters
        ----------
        configuration : Tuple[Tuple[int, int], int]
            ``((row, col), direction)``.

        Returns
        -------
        Set[Tuple[Tuple[int, int], int]]
            Predecessor configurations. The set is empty when the candidate cell
            is out of bounds or no orientation there leads into ``configuration``.
        """
        position, direction = configuration
        previous_cell = get_new_position(position, (direction + 2) % 4)
        if self.check_bounds(previous_cell):
            # Check all possible transitions from previous cell
            return {
                (previous_cell, agent_orientation)
                for agent_orientation in range(4)
                if self.get_transition(((previous_cell[0], previous_cell[1]), agent_orientation), direction)
            }
        return set()

    @lru_cache(maxsize=1_000_000)
    def _check_action_new(self, action: RailEnvActions, position: IntVector2D, direction: int) -> Tuple[int, bool, RailEnvActions, bool]:
        """
        Check whether an action at position and direction leads to a valid new position in the grid.

        - Sets the action to MOVE_FORWARD if MOVE_LEFT/MOVE_RIGHT is provided but that
          turn is not possible and another transition is.
        - Sets the action to STOP_MOVING if MOVE_FORWARD, DO_NOTHING or STOP_MOVING is
          provided but the agent is facing a symmetric switch (no straight transition).
        - Sets the action to STOP_MOVING if no transition at all can serve the action.

        Parameters
        ----------
        action : RailEnvActions
        position : IntVector2D
        direction : int

        Returns
        -------
        new_direction : int
            Direction after applying the (possibly corrected) action. Equals
            ``direction`` when the action is rejected.
        transition_valid : bool
            Whether the requested action maps directly onto a transition of the
            cell. False when the action was corrected or rejected.
        preprocessed_action : RailEnvActions
            The action after correction: MOVE_FORWARD replacing an impossible
            MOVE_LEFT/MOVE_RIGHT, or STOP_MOVING when the action is rejected.
            Otherwise the input action.
        action_valid : bool
            False only when the action is rejected, i.e. no transition can serve it.
        """
        possible_transitions = self.get_transitions((position, direction))
        num_transitions = fast_count_nonzero(possible_transitions)

        if num_transitions == 1:
            # Dead-end, straight line or curved line: the only valid transition
            # determines new_direction, whatever action was requested.
            new_direction = fast_argmax(possible_transitions)
            if action == RailEnvActions.MOVE_LEFT and new_direction != (direction - 1) % 4:
                action = RailEnvActions.MOVE_FORWARD
                return new_direction, False, action, True
            elif action == RailEnvActions.MOVE_RIGHT and new_direction != (direction + 1) % 4:
                action = RailEnvActions.MOVE_FORWARD
                return new_direction, False, action, True
            # Straight, dead-end, or a turn requested with the matching/neutral action.
            return new_direction, True, action, True

        # Zero or several transitions from here (switches, or no track in this direction).
        if action == RailEnvActions.MOVE_LEFT:
            new_direction = (direction - 1) % 4
            if possible_transitions[new_direction]:
                return new_direction, True, RailEnvActions.MOVE_LEFT, True
            elif possible_transitions[direction]:
                # Left not possible here: fall back to going straight.
                return direction, False, RailEnvActions.MOVE_FORWARD, True
        elif action == RailEnvActions.MOVE_RIGHT:
            new_direction = (direction + 1) % 4
            if possible_transitions[new_direction]:
                return new_direction, True, RailEnvActions.MOVE_RIGHT, True
            elif possible_transitions[direction]:
                # Right not possible here: fall back to going straight.
                return direction, False, RailEnvActions.MOVE_FORWARD, True
        elif possible_transitions[direction]:
            return direction, True, action, True

        # No transition can serve the action (e.g. going straight into a facing
        # symmetric switch, or no transitions at all): reject it.
        return direction, False, RailEnvActions.STOP_MOVING, False

    @lru_cache(maxsize=1_000_000)
    def _check_action_on_agent(self, action: RailEnvActions, configuration: Tuple[Tuple[int, int], int]) -> Tuple[
        bool, Tuple[Tuple[int, int], int], bool, RailEnvActions, bool]:
        """
        Apply ``action`` to ``configuration`` and report validity information.

        Parameters
        ----------
        action : RailEnvActions
        configuration : Tuple[Tuple[int, int], int]
            ``((row, col), direction)``.

        Returns
        -------
        new_cell_valid : bool
            Whether ``is_valid_configuration`` accepts the new position and
            direction (i.e. is it within bounds and does it have > 0 outgoing
            transitions).
        new_configuration : Tuple[Tuple[int, int], int]
            ``(new_position, new_direction)`` after applying the (possibly
            corrected) action.
        transition_valid : bool
            Whether the transition from the old position and direction is defined
            in the grid. In other words, can the action be applied directly?
            False if:

            - MOVE_FORWARD/DO_NOTHING is requested when entering a symmetric switch;
            - MOVE_LEFT/MOVE_RIGHT is corrected to MOVE_FORWARD in switches and dead-ends.

            However, transition_valid is True for dead-ends and turns with either the
            matching MOVE_RIGHT/MOVE_LEFT or MOVE_FORWARD/DO_NOTHING.
        preprocessed_action : RailEnvActions
            The corrected action if not transition_valid. The preprocessed action has
            the following semantics:

            - MOVE_LEFT/MOVE_RIGHT: turn left/right without acceleration.
            - MOVE_FORWARD: move forward with acceleration (swap direction in a
              dead-end; also works in left/right turns and non-facing symmetric
              switches).
            - DO_NOTHING: if already moving, keep moving forward without
              acceleration (swap direction in a dead-end; also works in left/right
              turns and non-facing symmetric switches); if stopped, stay stopped.
        action_valid : bool
            Whether the action is valid at all, irrespective of transition_valid
            (action_valid=False implies transition_valid=False, but not the
            reverse). False when facing a symmetric switch with
            MOVE_FORWARD/DO_NOTHING/STOP_MOVING, or when the configuration has no
            transitions at all.
        """
        position, direction = configuration
        new_direction, transition_valid, preprocessed_action, action_valid = self._check_action_new(action, position, direction)
        new_position = get_new_position(position, new_direction)
        new_cell_valid = self.is_valid_configuration((new_position, new_direction))
        return new_cell_valid, (new_position, new_direction), transition_valid, preprocessed_action, action_valid

    @lru_cache(maxsize=1_000_000)
    def apply_action_independent(self, action: RailEnvActions, configuration: Tuple[Tuple[int, int], int]) -> Optional[
        Tuple[Tuple[Tuple[int, int], int], bool]]:
        """
        Apply ``action`` to ``configuration``, ignoring other agents.

        Parameters
        ----------
        action : RailEnvActions
        configuration : Tuple[Tuple[int, int], int]
            ``((row, col), direction)``.

        Returns
        -------
        Optional[Tuple[Tuple[Tuple[int, int], int], bool]]
            ``None`` if the action is rejected or the resulting configuration is
            not valid. Otherwise ``(new_configuration, straight)``, where
            ``straight`` is True when the new direction lies on the same axis as
            the old one. That means a straight move or a reversal in a dead-end;
            it is False for a turn.
        """
        # new_cell_valid is is_valid_configuration(new_configuration), already
        # computed by _check_action_on_agent, so it is not re-evaluated here.
        new_cell_valid, new_configuration, _, _, action_valid = self._check_action_on_agent(action, configuration)
        if not (action_valid and new_cell_valid):
            return None
        _, direction = configuration
        _, new_direction = new_configuration
        # TODO https://github.com/flatland-association/flatland-rl/issues/280 revise design: allow acceleration in turns? dis-allow in dead-ends?
        straight = new_direction % 2 == direction % 2
        return new_configuration, straight

    def get_valid_directions_on_grid(self, row: int, col: int) -> List[int]:
        """
        Return the directions in which an agent can move on cell ``(row, col)``.

        Delegates to ``self.transitions.get_entry_directions`` applied to the
        cell's full transition value.
        """
        return self.transitions.get_entry_directions(self.get_full_transitions(row, col))