# Source code for flatland.core.transition_map

"""
TransitionMap and derived classes.
"""
import traceback
import uuid
import warnings
from functools import lru_cache
from typing import Tuple, Generic, TypeVar, Any

import numpy as np
from importlib_resources import path
from numpy import array

from flatland.core.grid.grid4 import Grid4Transitions
from flatland.core.grid.grid4_utils import get_new_position, get_direction
from flatland.core.grid.grid_utils import IntVector2DArray, IntVector2D
from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d
from flatland.core.transitions import Transitions
from flatland.utils.ordered_set import OrderedSet

# Type variables parameterising the generic TransitionMap interface below.

# Type of the `configuration` argument that identifies a cell (and possibly an orientation).
ConfigurationType = TypeVar('ConfigurationType')
# Type of the `new_transitions` argument accepted by `TransitionMap.set_transitions`.
UnderlyingTransitionsType = TypeVar('UnderlyingTransitionsType')
# Type of the validity value(s) returned by `get_transition` / `get_transitions`.
UnderlyingTransitionsValidityType = TypeVar('UnderlyingTransitionsValidityType')
# Type of the `action` argument accepted by `check_action_on_agent`.
ActionsType = TypeVar('ActionsType')


class TransitionMap(Generic[ConfigurationType, UnderlyingTransitionsType, UnderlyingTransitionsValidityType, ActionsType]):
    """
    Base TransitionMap class.

    Generic interface for a collection of transitions over a set of cells.
    Every method here raises `NotImplementedError`; concrete subclasses
    provide the storage and the actual behaviour.
    """

    def get_transitions(self, configuration: ConfigurationType) -> Tuple[UnderlyingTransitionsValidityType]:
        """
        Return the validity of every transition available in the cell
        identified by `configuration`.

        The result typically has one entry per possible transition, with
        values 0 or 1 (or values in between for stochastic transitions).

        Parameters
        ----------
        configuration : ConfigurationType
            Identifies the cell; the concrete type depends on the implementation.

        Returns
        -------
        tuple
            Validity of each transition in the cell.
        """
        raise NotImplementedError()

    def set_transitions(self, configuration: ConfigurationType, new_transitions: UnderlyingTransitionsType):
        """
        Replace the transitions available in cell `configuration` with
        `new_transitions`.

        `new_transitions` must have one element for each possible transition.

        Parameters
        ----------
        configuration : ConfigurationType
            Identifies the cell; the concrete type depends on the
            implementation. It generally is an int (e.g., an index) or a
            tuple of indices.
        new_transitions : UnderlyingTransitionsType
            New transition validities for the cell.
        """
        raise NotImplementedError()

    def get_transition(self, configuration: ConfigurationType, transition_index: int) -> UnderlyingTransitionsValidityType:
        """
        Return whether an agent in cell `configuration` can move along
        transition `transition_index` (e.g., the NESW direction of movement,
        for agents on a grid).

        Parameters
        ----------
        configuration : ConfigurationType
            Identifies the cell; the concrete type depends on the
            implementation. It generally is an int (e.g., an index) or a
            tuple of indices.
        transition_index : int
            Index of the transition to probe, as an index into the tuple
            returned by `get_transitions()`.

        Returns
        -------
        int or float (depending on the Transitions used)
            Validity of the requested transition (e.g., 0/1 for
            not allowed/allowed, a probability in [0, 1], etc.).
        """
        raise NotImplementedError()

    def set_transition(self, configuration: ConfigurationType, transition_index, new_transition):
        """
        Replace the validity of transition `transition_index` in cell
        `configuration` with `new_transition`.

        Parameters
        ----------
        configuration : ConfigurationType
            Identifies the cell; the concrete type depends on the
            implementation. It generally is an int (e.g., an index) or a
            tuple of indices.
        transition_index : int
            Index of the transition to change, as an index into the tuple
            returned by `get_transitions()`.
        new_transition : int or float (depending on the Transitions used)
            New validity of the transition (e.g., 0/1 for
            not allowed/allowed, a probability in [0, 1], etc.).
        """
        raise NotImplementedError()

    def check_action_on_agent(self, action: ActionsType, configuration: ConfigurationType) -> Tuple[bool, ConfigurationType, bool, ActionsType]:
        """
        Apply the action on the train regardless of the locations of other
        agents, checking for valid cells to move to and valid rail transitions.

        Parameters
        ----------
        action : ActionsType
            Action to execute.
        configuration : ConfigurationType
            Position and orientation.

        Returns
        -------
        new_cell_valid : bool
            Whether the new position and direction are valid (i.e. within
            bounds and with more than zero outgoing transitions).
        new_position : ConfigurationType
            New position after applying the action.
        transition_valid : bool
            Whether the transition from the old position and direction is
            defined in the grid.
        preprocessed_action : ActionsType
            Corrected action if the transition is not valid.
        """
        raise NotImplementedError()
class GridTransitionMap(TransitionMap[Tuple[Tuple[int, int], int], Grid4Transitions, Tuple[bool], Any], Generic[ActionsType]):
    """
    Implements a TransitionMap over a 2D grid.

    The cell contents are stored in `self.grid`, a NumPy array indexed as
    `grid[row, column]` (shape `(height, width)` when created by `__init__`).

    Several read-only queries are memoised with `functools.lru_cache`. The
    cache key includes `self`, whose hash and equality are driven by
    `self.uuid`; every mutating method calls `_reset_cache()` to draw a fresh
    uuid so that previously cached results are no longer matched.
    Consequently, all arguments of the cached methods must be hashable
    (tuples of ints, not lists or NumPy arrays).
    """

    def __init__(self, width, height, transitions: Transitions = Grid4Transitions([]), grid: np.ndarray = None):
        """
        Parameters
        ----------
        width : int
            Width of the grid.
        height : int
            Height of the grid.
        transitions : Transitions
            The Transitions object to use to encode/decode transitions over
            the grid.
        grid : np.ndarray, optional
            Pre-existing grid to use as is. If omitted, a zero-filled grid of
            shape `(height, width)` with dtype `transitions.get_type()` is
            created. A warning is issued when the dtype of a supplied grid
            differs from `transitions.get_type()`.
        """
        self.width = width
        self.height = height
        self.transitions = transitions

        if grid is None:
            self.grid = np.zeros((height, width), dtype=self.transitions.get_type())
        else:
            if grid.dtype != self.transitions.get_type():
                warnings.warn(f"Expected dtype {self.transitions.get_type()}, found {grid.dtype}.")
            self.grid = grid
        self._reset_cache()

    def _reset_cache(self):
        """Invalidate memoised results by giving this instance a new identity for hashing."""
        # use __eq__ and __hash__ to control cache lifecycle of instance methods,
        # see https://docs.python.org/3/faq/programming.html#how-do-i-cache-method-calls.
        self.uuid = uuid.uuid4().int

    def __eq__(self, __value):
        """Two maps are equal only if they carry the same cache-generation uuid."""
        return isinstance(__value, GridTransitionMap) and self.uuid == __value.uuid

    def __hash__(self):
        """Hash on the cache-generation uuid (changes whenever the map is mutated)."""
        return self.uuid

    @lru_cache(maxsize=1_000_000)
    def get_full_transitions(self, row, column) -> UnderlyingTransitionsType:
        """
        Return the full cell content at (row, column) in the format of this
        map's Transitions.

        Parameters
        ----------
        row : int
        column : int
            (row, column) specifies the cell in this transition map.

        Returns
        -------
        self.transitions.get_type()
            The cell content in the format of this map's Transitions.
        """
        return self.grid[(row, column)]

    @lru_cache(maxsize=4_000_000)
    def get_transitions(self, configuration: Tuple[Tuple[int, int], int]) -> Tuple[Grid4Transitions]:
        """
        Return the transitions available in a cell for a given orientation
        (e.g., a tuple with one entry per possible transition, with values 0
        or 1, or potentially in between for stochastic transitions).

        Parameters
        ----------
        configuration : tuple
            `((row, column), orientation)`, where orientation is the
            direction an agent is facing within the cell. Must be hashable
            because results are memoised.

        Returns
        -------
        tuple
            Validity of the transitions in the cell, as decoded by the map's
            Transitions object.
        """
        row_col, orientation = configuration
        return self.transitions.get_transitions(self.grid[row_col], orientation)

    def set_transitions(self, configuration: IntVector2D, new_transitions: Transitions):
        """
        Replace the available transitions in a cell.

        Parameters
        ----------
        configuration : tuple
            Either `(row, column, orientation)` to replace the transitions of
            one orientation (`new_transitions` is then merged into the cell
            through `self.transitions.set_transitions`), or `(row, column)`
            to overwrite the full cell content with `new_transitions`.
            Any other length leaves the grid untouched.
        new_transitions : tuple
            New transition validities (or full cell content for the
            two-element form).
        """
        self._reset_cache()

        if len(configuration) == 3:
            # Coerce to a tuple so that list/array input addresses a single
            # cell instead of triggering NumPy advanced (row) indexing.
            cell = tuple(configuration[0:2])
            self.grid[cell] = self.transitions.set_transitions(self.grid[cell], configuration[2], new_transitions)
        elif len(configuration) == 2:
            self.grid[tuple(configuration)] = new_transitions

    @lru_cache(maxsize=4_000_000)
    def get_transition(self, configuration: Tuple[Tuple[int, int], int], transition_index):
        """
        Return the validity of a single transition of a cell.

        Parameters
        ----------
        configuration : tuple
            `((row, column), orientation)`; must be hashable because results
            are memoised.
        transition_index : int
            Index of the transition to probe (e.g., the NESW direction of
            movement).

        Returns
        -------
        Validity of the requested transition, as decoded by the map's
        Transitions object.
        """
        row_col, orientation = configuration
        return self.transitions.get_transition(self.grid[row_col], orientation, transition_index)

    def set_transition(self, configuration: Tuple[int, int, int], transition_index, new_transition, remove_deadends=False):
        """
        Replace the validity of transition `transition_index` in a cell with
        `new_transition`.

        Parameters
        ----------
        configuration : tuple
            `(row, column, orientation)`, where orientation is the direction
            an agent is facing within the cell.
        transition_index : int
            Index of the transition to change, as an index into the tuple
            returned by `get_transitions()` (e.g., the NESW direction of
            movement).
        new_transition : int or float (depending on the Transitions used in the map)
            New validity of the transition (e.g., 0/1, a probability in
            [0, 1], etc.).
        remove_deadends : bool
            Forwarded unchanged to `self.transitions.set_transition`.
        """
        self._reset_cache()

        nDir = configuration[2]
        if type(nDir) is np.ndarray:
            # A direction wrapped in an array is unexpected: report where it
            # came from and unwrap it. This must not depend on `assert`,
            # which is stripped under `python -O`.
            traceback.print_stack()
            print("fixing nDir:", configuration, nDir)
            nDir = int(nDir[0])

        if isinstance(transition_index, np.ndarray):
            if type(transition_index) is np.ndarray:
                transition_index = int(transition_index.ravel()[0])
            else:
                # ndarray subclasses are converted as a whole.
                transition_index = int(transition_index)

        if isinstance(new_transition, np.ndarray):
            new_transition = int(new_transition.ravel()[0])

        # Always address exactly one cell, whatever sequence type was passed.
        cell = (configuration[0], configuration[1])
        self.grid[cell] = self.transitions.set_transition(
            self.grid[cell],
            nDir,
            transition_index,
            new_transition,
            remove_deadends)

    def save_transition_map(self, filename):
        """
        Save the transitions grid as `filename`, in npy format.

        Parameters
        ----------
        filename : string
            Name of the file to which to save the transitions grid.
        """
        np.save(filename, self.grid)

    def load_transition_map(self, package, resource):
        """
        Load the transitions grid from a package resource (npy format).

        Only the grid, `width` and `height` are replaced (the latter two are
        taken from the shape of the loaded array); the object has to have
        been initialised with the correct `transitions` object already.

        Parameters
        ----------
        package : string
            Name of the package from which to load the transitions grid.
        resource : string
            Name of the file from which to load the transitions grid within
            the package.
        """
        self._reset_cache()
        with path(package, resource) as file_in:
            new_grid = np.load(file_in)

        new_height = new_grid.shape[0]
        new_width = new_grid.shape[1]

        self.width = new_width
        self.height = new_height
        self.grid = new_grid

    @lru_cache(maxsize=1_000_000)
    def is_dead_end(self, rcPos: IntVector2DArray):
        """
        Check if the cell is a dead-end.

        Parameters
        ----------
        rcPos : Tuple[int, int]
            tuple(row, column) with grid coordinate

        Returns
        -------
        boolean
            True if and only if the cell is a dead-end.
        """
        cell_transition = self.get_full_transitions(rcPos[0], rcPos[1])
        return Grid4Transitions.has_deadend(cell_transition)

    @lru_cache(maxsize=1_000_000)
    def is_simple_turn(self, rcPos: IntVector2DArray):
        """
        Check if the cell is a left/right simple turn.

        Parameters
        ----------
        rcPos : Tuple[int, int]
            tuple(row, column) with grid coordinate

        Returns
        -------
        boolean
            True if and only if the cell content equals one of the two base
            simple-turn patterns or one of their 90-degree rotations.
        """
        cell_transition = self.get_full_transitions(rcPos[0], rcPos[1])

        simple_turns = OrderedSet()
        base_turns = (
            int('0100000000000010', 2),  # Case 1b (8)  - simple turn right
            int('0001001000000000', 2),  # Case 1c (9)  - simple turn left
        )
        for base_turn in base_turns:
            # Use a dedicated variable for the rotating pattern so the cell's
            # own transitions are never overwritten before the membership test.
            rotated = base_turn
            simple_turns.add(rotated)
            for _ in range(3):
                rotated = self.transitions.rotate_transition(rotated, rotation=90)
                simple_turns.add(rotated)

        return cell_transition in simple_turns

    @lru_cache(maxsize=4_000_000)
    def check_path_exists(self, start: IntVector2DArray, direction: int, end: IntVector2DArray):
        """
        Depth-first search for a possible path from one node with a certain
        orientation to a target node.

        :param start: Start cell from where we want to check the path
            (must be hashable, e.g. a tuple, since results are memoised)
        :param direction: Start direction for the path we are testing
        :param end: Cell that we try to reach from the start cell
        :return: True if a path exists, False otherwise
        """
        visited = OrderedSet()
        stack = [(start, direction)]
        while stack:
            node = stack.pop()  # LIFO => depth-first traversal
            node_position = node[0]
            node_direction = node[1]

            if Vec2d.is_equal(node_position, end):
                return True
            if node not in visited:
                visited.add(node)

                moves = self.get_transitions((node_position, node_direction))
                for move_index in range(4):
                    if moves[move_index]:
                        stack.append((get_new_position(node_position, move_index),
                                      move_index))

        return False

    @lru_cache(maxsize=1_000_000)
    def cell_neighbours_valid(self, rcPos: IntVector2DArray, check_this_cell=False):
        """
        Check validity of cell at rcPos = tuple(row, column)

        Checks that:
        - surrounding cells have inbound transitions for all the outbound
          transitions of this cell;
        - an empty cell has no neighbour with a transition pointing into it.

        These are NOT checked - see transition.is_valid (only run when
        `check_this_cell` is True):
        - all transitions have the mirror transitions (N->E <=> W->S)
        - Reverse transitions (N -> S) only exist for a dead-end
        - a cell contains either no dead-ends or exactly one

        Returns: True (valid) or False (invalid)
        """
        cell_transition = self.grid[tuple(rcPos)]

        if check_this_cell:
            if not self.transitions.is_valid(cell_transition):
                return False

        gDir2dRC = self.transitions.gDir2dRC  # [[-1,0] = N, [0,1]=E, etc]
        grcPos = array(rcPos)
        grcMax = self.grid.shape

        binTrans = self.get_full_transitions(*rcPos)  # 16bit integer - all trans in/out
        lnBinTrans = array([binTrans >> 8, binTrans & 0xff], dtype=np.uint8)  # 2 x uint8
        g2binTrans = np.unpackbits(lnBinTrans).reshape(4, 4)  # 4x4 x uint8 binary(0,1)
        gDirOut = g2binTrans.any(axis=0)  # outbound directions as boolean array (4)
        giDirOut = np.argwhere(gDirOut)[:, 0]  # valid outbound directions as array of int

        # loop over available outbound directions (indices) for rcPos
        for iDirOut in giDirOut:
            gdRC = gDir2dRC[iDirOut]  # row,col increment
            gPos2 = grcPos + gdRC  # next cell in that direction

            # Check the adjacent cell is within bounds
            # if not, then this transition is invalid!
            if np.any(gPos2 < 0):
                return False
            if np.any(gPos2 >= grcMax):
                return False

            # Get the transitions out of gPos2, using iDirOut as the inbound direction
            # if there are no available transitions, ie (0,0,0,0), then rcPos is invalid
            t4Trans2 = self.get_transitions(((gPos2[0], gPos2[1]), iDirOut))
            if any(t4Trans2):
                continue
            else:
                return False

        # If the cell is empty but has incoming connections we return false
        if binTrans < 1:
            connected = 0

            for iDirOut in np.arange(4):
                gdRC = gDir2dRC[iDirOut]  # row,col increment
                gPos2 = grcPos + gdRC  # next cell in that direction

                # Check the adjacent cell is within bounds
                # if not, then ignore it for the count of incoming connections
                if np.any(gPos2 < 0):
                    continue
                if np.any(gPos2 >= grcMax):
                    continue

                # Count every transition of the neighbour, over all of its
                # orientations, that points back towards rcPos.
                for orientation in range(4):
                    connected += self.get_transition(((gPos2[0], gPos2[1]), orientation), mirror(iDirOut))
            if connected > 0:
                return False

        return True

    def fix_neighbours(self, rcPos: IntVector2DArray, check_this_cell=False):
        """
        Check the neighbours of the cell at rcPos = tuple(row, column) and
        repair the first one found to be missing an inbound connection.

        For every outbound direction of the cell, the adjacent cell must
        offer at least one transition when entered in that direction. The
        first neighbour that does not is modified: for the inbound
        orientation, the transition towards the mirrored direction is set
        to 1. The method then returns immediately.

        When `check_this_cell` is True, the cell itself is first validated
        with `self.transitions.is_valid`.

        Returns
        -------
        bool
            True if every outbound transition already had a matching
            neighbour; False if the cell itself is invalid, an outbound
            transition leaves the grid, or a neighbour had to be fixed.
        """
        self._reset_cache()

        cell_transition = self.grid[tuple(rcPos)]

        if check_this_cell:
            if not self.transitions.is_valid(cell_transition):
                return False

        gDir2dRC = self.transitions.gDir2dRC  # [[-1,0] = N, [0,1]=E, etc]
        grcPos = array(rcPos)
        grcMax = self.grid.shape

        binTrans = self.get_full_transitions(*rcPos)  # 16bit integer - all trans in/out
        lnBinTrans = array([binTrans >> 8, binTrans & 0xff], dtype=np.uint8)  # 2 x uint8
        g2binTrans = np.unpackbits(lnBinTrans).reshape(4, 4)  # 4x4 x uint8 binary(0,1)
        gDirOut = g2binTrans.any(axis=0)  # outbound directions as boolean array (4)
        giDirOut = np.argwhere(gDirOut)[:, 0]  # valid outbound directions as array of int

        # loop over available outbound directions (indices) for rcPos
        for iDirOut in giDirOut:
            gdRC = gDir2dRC[iDirOut]  # row,col increment
            gPos2 = grcPos + gdRC  # next cell in that direction

            # Check the adjacent cell is within bounds
            # if not, then this transition is invalid!
            if np.any(gPos2 < 0):
                return False
            if np.any(gPos2 >= grcMax):
                return False

            # Get the transitions out of gPos2, using iDirOut as the inbound direction.
            # The position is passed as a tuple of scalars: get_transitions is
            # memoised, so a NumPy array inside the key would be unhashable.
            t4Trans2 = self.get_transitions(((gPos2[0], gPos2[1]), iDirOut))
            if any(t4Trans2):
                continue
            else:
                self.set_transition((gPos2[0], gPos2[1], iDirOut), mirror(iDirOut), 1)
                return False

        return True

    @lru_cache(maxsize=1_000_000)
    def validate_new_transition(self, prev_pos: IntVector2D, current_pos: IntVector2D,
                                new_pos: IntVector2D, end_pos: IntVector2D):
        """
        Utility function to test that a path drawn by the a-star algorithm
        uses valid transition objects. We use this to guide a-star, as there
        are many transition elements that are not allowed in RailEnv.

        The grid itself is not modified; the candidate cell contents are
        built on local copies and only validated.

        Parameters
        ----------
        prev_pos : IntVector2D
            The previous position we were checking (None at the start of a path)
        current_pos : IntVector2D
            The current position we are checking
        new_pos : IntVector2D
            Possible child position we move into
        end_pos : IntVector2D
            End cell of path we are drawing

        Returns
        -------
        bool
            True if the transition is valid, False if the transition element
            is illegal
        """
        # start by getting direction used to get to current node
        # and direction from current node to possible child node
        new_dir = get_direction(current_pos, new_pos)
        if prev_pos is not None:
            current_dir = get_direction(prev_pos, current_pos)
        else:
            current_dir = new_dir

        # create new transition that would go to child
        new_trans = self.grid[current_pos]
        if prev_pos is None:
            if new_trans == 0:
                # need to flip direction because of how end points are defined
                new_trans = self.transitions.set_transition(new_trans, mirror(current_dir), new_dir, 1)
            else:
                # check if matches existing layout
                new_trans = self.transitions.set_transition(new_trans, current_dir, new_dir, 1)
        else:
            # set the forward path
            new_trans = self.transitions.set_transition(new_trans, current_dir, new_dir, 1)
            # set the backwards path
            new_trans = self.transitions.set_transition(new_trans, mirror(new_dir), mirror(current_dir), 1)

        if Vec2d.is_equal(new_pos, end_pos):
            # need to validate end pos setup as well
            new_trans_e = self.grid[end_pos]
            if new_trans_e == 0:
                # need to flip direction because of how end points are defined
                new_trans_e = self.transitions.set_transition(new_trans_e, new_dir, mirror(new_dir), 1)
            else:
                # check if matches existing layout
                new_trans_e = self.transitions.set_transition(new_trans_e, new_dir, new_dir, 1)

            if not self.transitions.is_valid(new_trans_e):
                return False

        # is the transition valid?
        return self.transitions.is_valid(new_trans)
def mirror(dir):
    """
    Return the direction opposite to `dir`.

    Directions are indices 0..3; the opposite one is two steps further
    round, wrapping modulo 4.
    """
    half_turn = dir + 2
    return half_turn % 4
# TODO: possible improvement - override __getitem__ and __setitem__ (to access cell contents, not transitions?)