Source code for flatland.envs.step_utils.env_utils

from dataclasses import dataclass
from typing import Tuple

from flatland.core.grid.grid4_utils import get_new_position
from flatland.envs.step_utils import transition_utils
from flatland.envs.rail_env_action import RailEnvActions
from flatland.core.grid.grid4 import Grid4Transitions
from flatland.utils.decorators import enable_infrastructure_lru_cache


@dataclass
class AgentTransitionData:
    """
    Scratch record holding one agent's pending position update.

    Instances bundle the values needed while an agent's next position is
    being worked out: where it is, which way it faces, and the action it
    will execute.
    """

    # Grid cell of the agent as an integer pair.
    position: Tuple[int, int]
    # Heading of the agent.
    # NOTE(review): annotated as Grid4Transitions; confirm whether callers
    # actually store a direction value here rather than a transitions object.
    direction: Grid4Transitions
    # Action after preprocessing, as a RailEnvActions value.
    preprocessed_action: RailEnvActions
@enable_infrastructure_lru_cache(maxsize=1_000_000)
def apply_action_independent(action, rail, position, direction):
    """
    Compute where a train would end up after an action, ignoring other trains.

    ---------------------------------------------------------------------
    Parameters:
        action - Action to execute
        rail - Flatland env.rail object
        position - current position of the train
        direction - current direction of the train
    ---------------------------------------------------------------------
    Returns:
        new_position - Position after applying the action
        new_direction - Direction after applying the action

    Only moving actions change anything. For them, the new heading comes
    from transition_utils.check_action, and the position is advanced one
    cell along that heading. The second value returned by check_action is
    discarded here, so callers must do any validity checking themselves.
    Results are memoised by the decorator, so arguments must be hashable.
    """
    # Guard clause: anything that is not a moving action keeps the train put.
    if not action.is_moving_action():
        return position, direction

    heading, _unused = transition_utils.check_action(action, position, direction, rail)
    return get_new_position(position, heading), heading
def state_position_sync_check(state, position, i_agent) -> None:
    """
    Check that an agent's state and position agree about being on the map.

    Parameters:
        state - agent state; must provide is_on_map_state() and
                is_off_map_state()
        position - agent position, or None when the agent has no cell
        i_agent - agent index, used only in the error message

    Raises:
        ValueError - if the state is on-map but position is None, or the
                     state is off-map but position is not None.
    """
    # Explicit ``!s`` keeps the original str() rendering. Plain f-string
    # formatting of enum-like states can differ from str().
    if state.is_on_map_state() and position is None:
        raise ValueError(
            f"Agent ID {i_agent} Agent State {state!s} is on map "
            f"Agent Position {position!s} is off map "
        )
    elif state.is_off_map_state() and position is not None:
        raise ValueError(
            f"Agent ID {i_agent} Agent State {state!s} is off map "
            f"Agent Position {position!s} is on map "
        )