"""
TransitionMap and derived classes.
"""
import traceback
import uuid
import warnings
from functools import lru_cache
from typing import Tuple, Generic, TypeVar, Any
import numpy as np
from importlib_resources import path
from numpy import array
from flatland.core.grid.grid4 import Grid4Transitions
from flatland.core.grid.grid4_utils import get_new_position, get_direction
from flatland.core.grid.grid_utils import IntVector2DArray, IntVector2D
from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d
from flatland.core.transitions import Transitions
from flatland.utils.ordered_set import OrderedSet
ConfigurationType = TypeVar('ConfigurationType')
UnderlyingTransitionsType = TypeVar('UnderlyingTransitionsType')
UnderlyingTransitionsValidityType = TypeVar('UnderlyingTransitionsValidityType')
ActionsType = TypeVar('ActionsType')
class TransitionMap(Generic[ConfigurationType, UnderlyingTransitionsType, UnderlyingTransitionsValidityType, ActionsType]):
    """
    Base TransitionMap class.

    Generic interface for a collection of transitions defined over a set of
    cells. This class only fixes method names and signatures: every method
    raises ``NotImplementedError`` and must be overridden by a subclass.

    The type parameters describe, in order: how a cell is addressed
    (``ConfigurationType``), how a cell's transitions are stored
    (``UnderlyingTransitionsType``), the value describing whether a single
    transition is allowed (``UnderlyingTransitionsValidityType``) and the
    action type accepted by ``check_action_on_agent`` (``ActionsType``).
    """

    def get_transitions(self, configuration: ConfigurationType) -> Tuple[UnderlyingTransitionsValidityType]:
        """
        Return the transitions available in the cell addressed by `configuration`.

        Typically a tuple with one entry per possible transition, holding 0 or 1
        (or a value in between for stochastic transitions).

        Parameters
        ----------
        configuration : ConfigurationType
            Cell identifier; its form depends on the implementation.

        Returns
        -------
        tuple
            Validity of each transition in the cell.

        Raises
        ------
        NotImplementedError
            Always, in this base class.
        """
        raise NotImplementedError

    def set_transitions(self, configuration: ConfigurationType, new_transitions: UnderlyingTransitionsType):
        """
        Replace all transitions of cell `configuration` with `new_transitions`.

        `new_transitions` must contain one element per possible transition.

        Parameters
        ----------
        configuration : ConfigurationType
            Cell identifier; depends on the implementation, usually an int
            (e.g. an index) or a tuple of indices.
        new_transitions : UnderlyingTransitionsType
            New transition validity values for the cell.

        Raises
        ------
        NotImplementedError
            Always, in this base class.
        """
        raise NotImplementedError

    def get_transition(self, configuration: ConfigurationType, transition_index: int) -> UnderlyingTransitionsValidityType:
        """
        Return whether an agent in cell `configuration` may follow transition `transition_index`.

        Parameters
        ----------
        configuration : ConfigurationType
            Cell identifier; depends on the implementation, usually an int
            (e.g. an index) or a tuple of indices.
        transition_index : int
            Position of the transition in the tuple returned by
            ``get_transitions()`` (e.g. the NESW direction of movement for
            agents on a grid).

        Returns
        -------
        int or float (depending on the Transitions used)
            Validity of the transition, e.g. 0/1 for not allowed/allowed or a
            probability in [0, 1].

        Raises
        ------
        NotImplementedError
            Always, in this base class.
        """
        raise NotImplementedError

    def set_transition(self, configuration: ConfigurationType, transition_index, new_transition):
        """
        Set the validity of transition `transition_index` in cell `configuration` to `new_transition`.

        Parameters
        ----------
        configuration : ConfigurationType
            Cell identifier; depends on the implementation, usually an int
            (e.g. an index) or a tuple of indices.
        transition_index : int
            Position of the transition in the tuple returned by
            ``get_transitions()`` (e.g. the NESW direction of movement for
            agents on a grid).
        new_transition : int or float (depending on the Transitions used)
            New validity, e.g. 0/1 for not allowed/allowed or a probability
            in [0, 1].

        Raises
        ------
        NotImplementedError
            Always, in this base class.
        """
        raise NotImplementedError

    def check_action_on_agent(self, action: ActionsType, configuration: ConfigurationType) -> Tuple[bool, ConfigurationType, bool, ActionsType]:
        """
        Apply `action` to a train, ignoring the positions of all other agents.

        Implementations check that the target cell is valid and that the rail
        transition exists.

        Parameters
        ----------
        action : ActionsType
            Action to execute.
        configuration : ConfigurationType
            Position and orientation of the train.

        Returns
        -------
        new_cell_valid : bool
            Whether the new position/direction is valid (within bounds and with
            more than zero outgoing transitions).
        new_position : ConfigurationType
            Position after applying the action.
        transition_valid : bool
            Whether the transition from the old position and direction is
            defined in the grid.
        preprocessed_action : ActionsType
            The action, corrected if the transition was not valid.

        Raises
        ------
        NotImplementedError
            Always, in this base class.
        """
        raise NotImplementedError
class GridTransitionMap(TransitionMap[Tuple[Tuple[int, int], int], Grid4Transitions, Tuple[bool], Any], Generic[ActionsType]):
    """
    Implements a TransitionMap over a 2D grid.

    ``self.grid`` is a ``(height, width)`` array; each cell stores its full
    transition encoding in the format of ``self.transitions``. Cells are
    addressed as ``(row, column)``.

    Caching: several query methods are memoized with ``functools.lru_cache``,
    whose keys include ``self``. ``__eq__``/``__hash__`` are derived from
    ``self.uuid`` and every mutating method calls ``_reset_cache()`` to draw a
    new uuid, so results cached before a mutation are no longer found (they
    stay in the caches until evicted). Code writing to ``self.grid`` directly
    bypasses this and must call ``_reset_cache()`` itself, otherwise cached
    queries may return stale results.
    """

    def __init__(self, width, height, transitions: Transitions = Grid4Transitions([]), grid: np.ndarray = None):
        """
        Parameters
        ----------
        width : int
            Width of the grid (number of columns).
        height : int
            Height of the grid (number of rows).
        transitions : Transitions
            The Transitions object used to encode/decode transitions over the
            grid. NOTE: the default instance is created once, when the class is
            defined, and is shared by every map built without this argument.
        grid : np.ndarray, optional
            Pre-built transition grid, used as-is (not copied). If omitted, an
            all-zero ``(height, width)`` grid with dtype
            ``transitions.get_type()`` is created. A dtype mismatch only emits a
            warning; ``width``/``height`` are not checked against ``grid.shape``.
        """
        self.width = width
        self.height = height
        self.transitions = transitions
        if grid is None:
            self.grid = np.zeros((height, width), dtype=self.transitions.get_type())
        else:
            if grid.dtype != self.transitions.get_type():
                warnings.warn(f"Expected dtype {self.transitions.get_type()}, found {grid.dtype}.")
            self.grid = grid
        self._reset_cache()

    def _reset_cache(self):
        """
        Invalidate all lru_cache'd results for this instance.

        Drawing a new uuid changes ``hash(self)``, so lookups made after the
        reset no longer match entries cached before it. See
        https://docs.python.org/3/faq/programming.html#how-do-i-cache-method-calls.
        """
        self.uuid = uuid.uuid4().int

    def __eq__(self, other):
        # Equality is keyed on the cache uuid, not on grid contents: two
        # distinct maps (practically) never compare equal.
        return isinstance(other, GridTransitionMap) and self.uuid == other.uuid

    def __hash__(self):
        return self.uuid

    @lru_cache(maxsize=1_000_000)
    def get_full_transitions(self, row, column) -> UnderlyingTransitionsType:
        """
        Return the full transition encoding of the cell at (row, column).

        Parameters
        ----------
        row : int
        column : int
            (row, column) specifies the cell in this transition map.

        Returns
        -------
        self.transitions.get_type()
            The cell content in the format of this map's Transitions.
        """
        return self.grid[(row, column)]

    @lru_cache(maxsize=4_000_000)
    def get_transitions(self, configuration: Tuple[Tuple[int, int], int]) -> Tuple[Grid4Transitions]:
        """
        Return the transitions available to an agent in a given cell and orientation.

        Parameters
        ----------
        configuration : tuple
            ``((row, column), orientation)``, where orientation is the direction
            an agent is facing within the cell. Must be hashable (the method is
            lru_cached), so pass tuples, not numpy arrays.

        Returns
        -------
        tuple
            Validity of each transition, as returned by
            ``self.transitions.get_transitions`` for this cell's encoding.
        """
        row_col, orientation = configuration
        return self.transitions.get_transitions(self.grid[row_col], orientation)

    def set_transitions(self, configuration: Tuple[int, ...], new_transitions: Transitions):
        """
        Replace the transitions of a cell.

        Parameters
        ----------
        configuration : tuple
            Either ``(row, column, orientation)``: ``new_transitions`` replaces
            the transitions for an agent facing ``orientation`` within the cell;
            or ``(row, column)``: ``new_transitions`` replaces the full cell
            content.
        new_transitions : tuple or cell encoding
            Tuple of new transition validities (3-element form) or the full cell
            value (2-element form).

        Raises
        ------
        ValueError
            If ``configuration`` has neither 2 nor 3 elements (this was
            previously silently ignored).
        """
        self._reset_cache()
        # Always index with an explicit (row, column) tuple: slicing a list or
        # ndarray configuration would trigger advanced (fancy) indexing of
        # whole rows instead of addressing a single cell.
        if len(configuration) == 3:
            cell = (configuration[0], configuration[1])
            self.grid[cell] = self.transitions.set_transitions(self.grid[cell],
                                                               configuration[2],
                                                               new_transitions)
        elif len(configuration) == 2:
            self.grid[(configuration[0], configuration[1])] = new_transitions
        else:
            raise ValueError('GridTransitionMap.set_transitions() ERROR: configuration tuple must have length 2 or 3.')

    @lru_cache(maxsize=4_000_000)
    def get_transition(self, configuration: Tuple[Tuple[int, int], int], transition_index):
        """
        Return whether an agent in a cell and orientation may move along `transition_index`.

        Parameters
        ----------
        configuration : tuple
            ``((row, column), orientation)``; must be hashable (lru_cached).
        transition_index : int
            Index of the transition to probe, e.g. the NESW direction of movement.

        Returns
        -------
        The validity value returned by ``self.transitions.get_transition`` for
        this cell's encoding.
        """
        row_col, orientation = configuration
        return self.transitions.get_transition(self.grid[row_col], orientation, transition_index)

    @staticmethod
    def _first_as_int(value) -> int:
        """Return the first element (C order) of array-like `value` as a Python int."""
        return int(np.asarray(value).ravel()[0])

    def set_transition(self, configuration: Tuple[int, int, int], transition_index, new_transition, remove_deadends=False):
        """
        Replace the validity of transition `transition_index` in a cell with `new_transition`.

        Parameters
        ----------
        configuration : tuple
            ``(row, column, orientation)`` as a flat 3-tuple, where orientation
            is the direction an agent is facing within the cell.
        transition_index : int
            Index of the transition to set, as index in the tuple returned by
            get_transitions(), e.g. the NESW direction of movement.
        new_transition : int or float (depending on the Transitions used in the map)
            Validity of the transition (e.g. 0/1 allowed/not allowed, a
            probability in [0,1], ...).
        remove_deadends : bool
            Forwarded unchanged to ``self.transitions.set_transition``.

        Notes
        -----
        An orientation, transition index or new value passed as a numpy array is
        reduced to the int value of its first element. For the orientation a
        warning is emitted as well, since it usually indicates a caller bug.
        """
        self._reset_cache()
        row, column = configuration[0], configuration[1]
        orientation = configuration[2]
        if isinstance(orientation, np.ndarray):
            warnings.warn(f"GridTransitionMap.set_transition(): cell direction is not an int, "
                          f"using first element of {orientation!r} (configuration={configuration!r}).",
                          stacklevel=2)
            orientation = self._first_as_int(orientation)
        if isinstance(transition_index, np.ndarray):
            transition_index = self._first_as_int(transition_index)
        if isinstance(new_transition, np.ndarray):
            new_transition = self._first_as_int(new_transition)
        self.grid[row, column] = self.transitions.set_transition(
            self.grid[row, column],
            orientation,
            transition_index,
            new_transition,
            remove_deadends)

    def save_transition_map(self, filename):
        """
        Save the transitions grid as `filename`, in npy format.

        Parameters
        ----------
        filename : string
            Name of the file to which to save the transitions grid.
        """
        np.save(filename, self.grid)

    def load_transition_map(self, package, resource):
        """
        Load the transitions grid (npy format) from a package resource.

        Replaces ``self.grid`` and sets ``width``/``height`` to the loaded
        grid's size. The object must already have been initialized with the
        matching ``transitions`` object; it is not changed here.

        Parameters
        ----------
        package : string
            Name of the package from which to load the transitions grid.
        resource : string
            Name of the file from which to load the transitions grid within the package.
        """
        self._reset_cache()
        with path(package, resource) as file_in:
            new_grid = np.load(file_in)
        self.height = new_grid.shape[0]
        self.width = new_grid.shape[1]
        self.grid = new_grid

    @lru_cache(maxsize=1_000_000)
    def is_dead_end(self, rcPos: IntVector2DArray):
        """
        Check if the cell is a dead-end.

        Parameters
        ----------
        rcPos : Tuple[int, int]
            (row, column) grid coordinate; must be hashable (lru_cached).

        Returns
        -------
        bool
            Result of ``Grid4Transitions.has_deadend`` on the cell's full transitions.
        """
        cell_transition = self.get_full_transitions(rcPos[0], rcPos[1])
        return Grid4Transitions.has_deadend(cell_transition)

    @lru_cache(maxsize=1_000_000)
    def is_simple_turn(self, rcPos: IntVector2DArray):
        """
        Check if the cell is a left/right simple turn.

        Parameters
        ----------
        rcPos : Tuple[int, int]
            (row, column) grid coordinate; must be hashable (lru_cached).

        Returns
        -------
        bool
            True if and only if the cell's full transitions equal one of the
            simple-turn encodings below, in any of its four 90-degree rotations.
        """
        cell_transition = self.get_full_transitions(rcPos[0], rcPos[1])
        simple_turns = OrderedSet()
        for base_turn in (int('0100000000000010', 2),  # Case 1b (8) - simple turn right
                          int('0001001000000000', 2)):  # Case 1c (9) - simple turn left
            rotated = base_turn
            # Collect the 0, 90, 180 and 270 degree rotations.
            for _ in range(4):
                simple_turns.add(int(rotated))
                rotated = self.transitions.rotate_transition(rotated, rotation=90)
        # Compare the cell itself (the previous version tested a shadowed loop
        # variable and therefore always returned True).
        return int(cell_transition) in simple_turns

    @lru_cache(maxsize=4_000_000)
    def check_path_exists(self, start: IntVector2DArray, direction: int, end: IntVector2DArray):
        """
        Depth-first search for a path from `start`, entered with orientation `direction`, to `end`.

        The search follows this map's transitions and succeeds as soon as a
        position equal to `end` is reached, whatever the orientation on arrival.
        Positions outside the grid are never expanded.

        :param start: Start cell (row, column) from which we want to check the path; must be hashable (lru_cached)
        :param direction: Start orientation for the path we are testing
        :param end: Cell that we try to reach from the start cell
        :return: True if a path exists, False otherwise
        """
        n_rows, n_cols = self.grid.shape[0], self.grid.shape[1]
        visited = OrderedSet()
        stack = [(start, direction)]
        while stack:
            node = stack.pop()
            node_position, node_direction = node
            if Vec2d.is_equal(node_position, end):
                return True
            if node in visited:
                continue
            visited.add(node)
            row, col = node_position[0], node_position[1]
            if not (0 <= row < n_rows and 0 <= col < n_cols):
                # Off-grid cells have no transitions; a negative index would
                # otherwise silently wrap to the opposite edge of the grid.
                continue
            moves = self.get_transitions((node_position, node_direction))
            for move_index in range(4):
                if moves[move_index]:
                    stack.append((get_new_position(node_position, move_index), move_index))
        return False

    @staticmethod
    def _outbound_directions(binTrans) -> np.ndarray:
        """Return the directions (as ints) in which a 16-bit cell encoding has any outbound transition."""
        lnBinTrans = array([binTrans >> 8, binTrans & 0xff], dtype=np.uint8)  # 2 x uint8
        g2binTrans = np.unpackbits(lnBinTrans).reshape(4, 4)  # 4x4 binary(0,1); columns = outbound direction
        gDirOut = g2binTrans.any(axis=0)  # outbound directions as boolean array (4)
        return np.argwhere(gDirOut)[:, 0]

    def _is_on_grid(self, pos) -> bool:
        """Return True if the numpy (row, column) vector `pos` lies inside ``self.grid``."""
        return not (np.any(pos < 0) or np.any(pos >= self.grid.shape))

    @lru_cache(maxsize=1_000_000)
    def cell_neighbours_valid(self, rcPos: IntVector2DArray, check_this_cell=False):
        """
        Check validity of the cell at rcPos = (row, column) with respect to its neighbours.

        Checks that:
        - every outbound transition of this cell leads to an in-bounds neighbour
          that has at least one transition for an agent entering it in that
          direction;
        - if the cell is empty, no neighbour has a transition leading into it.

        If ``check_this_cell`` is True, the cell's own encoding must also pass
        ``self.transitions.is_valid``. Otherwise these are NOT checked here
        (see transitions.is_valid):
        - all transitions have the mirror transitions (N->E <=> W->S)
        - reverse transitions (N -> S) only exist for a dead-end
        - a cell contains either no dead-ends or exactly one

        Returns: True (valid) or False (invalid)
        """
        cell_transition = self.grid[tuple(rcPos)]
        if check_this_cell and not self.transitions.is_valid(cell_transition):
            return False
        gDir2dRC = self.transitions.gDir2dRC  # [[-1,0] = N, [0,1]=E, etc]
        grcPos = array(rcPos)
        binTrans = self.get_full_transitions(*rcPos)  # 16bit integer - all trans in/out

        for iDirOut in self._outbound_directions(binTrans):
            gPos2 = grcPos + gDir2dRC[iDirOut]  # next cell in that direction
            if not self._is_on_grid(gPos2):
                return False  # the transition leads off the grid
            # Transitions out of gPos2 for an agent that entered heading iDirOut;
            # none at all means this cell's outbound transition is a dead link.
            if not any(self.get_transitions(((gPos2[0], gPos2[1]), iDirOut))):
                return False

        # An empty cell must not be the target of any neighbour's transition.
        if binTrans < 1:
            for iDirOut in range(4):
                gPos2 = grcPos + gDir2dRC[iDirOut]
                if not self._is_on_grid(gPos2):
                    continue  # off-grid neighbours cannot point into this cell
                neighbour = (gPos2[0], gPos2[1])
                # Moving from the neighbour back towards rcPos is direction mirror(iDirOut).
                if any(self.get_transition((neighbour, orientation), mirror(iDirOut))
                       for orientation in range(4)):
                    return False
        return True

    def fix_neighbours(self, rcPos: IntVector2DArray, check_this_cell=False):
        """
        Check the outbound transitions of the cell at rcPos = (row, column) and repair the first dangling one.

        For each outbound direction ``d`` of the cell, the neighbour in direction
        ``d`` must have at least one transition for an agent facing ``d``. At the
        first neighbour without one, the transition from orientation ``d`` to
        direction ``mirror(d)`` (back towards rcPos) is enabled there, and False
        is returned; later directions are not examined in that call.

        Returns
        -------
        bool
            False if the cell itself is invalid (only when ``check_this_cell``),
            if an outbound transition leads off the grid (not repaired), or if a
            neighbour was repaired; True if nothing needed fixing.
        """
        self._reset_cache()
        cell_transition = self.grid[tuple(rcPos)]
        if check_this_cell and not self.transitions.is_valid(cell_transition):
            return False
        gDir2dRC = self.transitions.gDir2dRC  # [[-1,0] = N, [0,1]=E, etc]
        grcPos = array(rcPos)
        binTrans = self.get_full_transitions(*rcPos)  # 16bit integer - all trans in/out

        for iDirOut in self._outbound_directions(binTrans):
            gPos2 = grcPos + gDir2dRC[iDirOut]  # next cell in that direction
            if not self._is_on_grid(gPos2):
                return False  # the transition leads off the grid
            row, col = gPos2[0], gPos2[1]
            # Pass a hashable tuple: get_transitions is lru_cached, and the raw
            # ndarray gPos2 raised TypeError here.
            if not any(self.get_transitions(((row, col), iDirOut))):
                self.set_transition((row, col, iDirOut), mirror(iDirOut), 1)
                return False
        return True

    @lru_cache(maxsize=1_000_000)
    def validate_new_transition(self, prev_pos: IntVector2D, current_pos: IntVector2D,
                                new_pos: IntVector2D, end_pos: IntVector2D):
        """
        Utility function to test that a path drawn by the A* algorithm uses valid transition elements.

        We use this to guide A*, as there are many transition elements that are
        not allowed in RailEnv. The grid itself is not modified.

        Parameters
        ----------
        prev_pos : IntVector2D
            The previous position we were checking (None at the start of a path)
        current_pos : IntVector2D
            The current position we are checking
        new_pos : IntVector2D
            Possible child position we move into
        end_pos : IntVector2D
            End cell of path we are drawing

        Returns
        ----------
        True if the transition is valid, False if the transition element is illegal
        """
        # start by getting direction used to get to current node
        # and direction from current node to possible child node
        new_dir = get_direction(current_pos, new_pos)
        if prev_pos is not None:
            current_dir = get_direction(prev_pos, current_pos)
        else:
            current_dir = new_dir
        # create new transition that would go to child
        new_trans = self.grid[current_pos]
        if prev_pos is None:
            if new_trans == 0:
                # need to flip direction because of how end points are defined
                new_trans = self.transitions.set_transition(new_trans, mirror(current_dir), new_dir, 1)
            else:
                # check if matches existing layout
                new_trans = self.transitions.set_transition(new_trans, current_dir, new_dir, 1)
        else:
            # set the forward path
            new_trans = self.transitions.set_transition(new_trans, current_dir, new_dir, 1)
            # set the backwards path
            new_trans = self.transitions.set_transition(new_trans, mirror(new_dir), mirror(current_dir), 1)
        if Vec2d.is_equal(new_pos, end_pos):
            # need to validate end pos setup as well
            new_trans_e = self.grid[end_pos]
            if new_trans_e == 0:
                # need to flip direction because of how end points are defined
                new_trans_e = self.transitions.set_transition(new_trans_e, new_dir, mirror(new_dir), 1)
            else:
                # check if matches existing layout
                new_trans_e = self.transitions.set_transition(new_trans_e, new_dir, new_dir, 1)
            if not self.transitions.is_valid(new_trans_e):
                return False
        # is transition is valid?
        return self.transitions.is_valid(new_trans)
def mirror(dir):
    """
    Return the direction opposite to `dir` on a 4-direction (NESW-indexed) grid.

    Opposite directions are two steps apart modulo 4, e.g. ``mirror(0) == 2``
    and ``mirror(3) == 1``.
    """
    # Going two steps back is the same as two steps forward modulo 4.
    return (dir - 2) % 4
# TODO: improvement override __getitem__ and __setitem__ (cell contents, not transitions?)