"""
Definition of the RailEnv environment.
"""
import random
import warnings
from functools import lru_cache
from typing import List, Optional, Dict, Tuple, Set
import numpy as np
import flatland.envs.timetable_generators as ttg
from flatland.core.effects_generator import EffectsGenerator
from flatland.core.env import Environment
from flatland.core.env_observation_builder import ObservationBuilder
from flatland.core.grid.grid4 import Grid4Transitions
from flatland.core.grid.grid_utils import Vector2D
from flatland.core.grid.rail_env_grid import RailEnvTransitionsEnum
from flatland.envs import agent_chains as ac
from flatland.envs import line_generators as line_gen
from flatland.envs import malfunction_generators as mal_gen
from flatland.envs import persistence
from flatland.envs import rail_generators as rail_gen
from flatland.envs.agent_utils import EnvAgent
from flatland.envs.distance_map import DistanceMap
from flatland.envs.observations import GlobalObsForRailEnv
from flatland.envs.rail_env_action import RailEnvActions
from flatland.envs.rail_grid_transition_map import RailGridTransitionMap
from flatland.envs.rewards import Rewards
from flatland.envs.step_utils import env_utils
from flatland.envs.step_utils.state_machine import TrainStateMachine
from flatland.envs.step_utils.states import TrainState, StateTransitionSignals
from flatland.utils import seeding
from flatland.utils.rendertools import RenderTool, AgentRenderVariant
class RailEnv(Environment):
"""
RailEnv environment class.
RailEnv is an environment inspired by a (simplified version of) a rail
network, in which agents (trains) have to navigate to their target
locations in the shortest time possible, while at the same time cooperating
to avoid bottlenecks.
The valid actions in the environment are:
- 0: do nothing (continue moving or stay still)
- 1: turn left at switch and move to the next cell; if the agent was not moving, movement is started
- 2: move to the next cell in front of the agent; if the agent was not moving, movement is started
- 3: turn right at switch and move to the next cell; if the agent was not moving, movement is started
- 4: stop moving
Moving forward in a dead-end cell makes the agent turn 180 degrees and step
to the cell it came from.
In order for agents to be able to "understand" the simulation behaviour from the observations,
the execution order of actions should not matter (i.e. not depend on the agent handle).
However, the agent ordering is still used to resolve conflicts between two agents trying to move into the same cell,
for example, head-on collisions, or agents "merging" at junctions.
See `MotionCheck` for more details.
Stochastic malfunctioning of trains:
Trains in RailEnv can malfunction if they are halted too often (either by their own choice or because an invalid
action or cell was selected).
Every time an agent stops, it has a certain probability of malfunctioning. Malfunctions of trains follow a
Poisson process with a certain rate. To keep complexity manageable, not all trains are affected by malfunctions
during an episode.
TODO: currently, the parameters that control the stochasticity of the environment are hard-coded in init().
For Round 2, they will be passed to the constructor as arguments, to allow for more flexibility.
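Examples
--------
A minimal usage sketch (grid size, agent count and seed are illustrative;
the default sparse rail/line generators are assumed to produce a valid map):
>>> from flatland.envs.rail_env import RailEnv
>>> from flatland.envs.rail_env_action import RailEnvActions
>>> env = RailEnv(width=30, height=30, number_of_agents=2)
>>> obs, info = env.reset(random_seed=42)
>>> actions = {h: RailEnvActions.MOVE_FORWARD for h in env.get_agent_handles()}
>>> obs, rewards, dones, info = env.step(actions)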
"""
def __init__(self,
width,
height,
rail_generator: "RailGenerator" = None,
line_generator: "LineGenerator" = None,
number_of_agents=2,
obs_builder_object: ObservationBuilder = GlobalObsForRailEnv(),
malfunction_generator_and_process_data=None,
malfunction_generator: "MalfunctionGenerator" = None,
remove_agents_at_target=True,
random_seed=None,
record_steps=False,
timetable_generator=ttg.timetable_generator,
acceleration_delta=1.0,
braking_delta=-1.0,
rewards: Rewards = None,
effects_generator: EffectsGenerator["RailEnv"] = None
):
"""
Environment init.
Parameters
----------
rail_generator : function
The rail_generator function is a function that takes the width,
height and agents handles of a rail environment, along with the number of times
the env has been reset, and returns a GridTransitionMap object and a list of
starting positions, targets, and initial orientations for each agent handle.
The rail_generator can pass a distance map in the hints or information for specific line_generators.
Implementations can be found in flatland/envs/rail_generators.py
line_generator : function
The line_generator function is a function that takes the grid, the number of agents and optional hints
and returns a list of starting positions, targets, initial orientations and speed for all agent handles.
Implementations can be found in flatland/envs/line_generators.py
width : int
The width of the rail map. Potentially in the future,
a range of widths to sample from.
height : int
The height of the rail map. Potentially in the future,
a range of heights to sample from.
number_of_agents : int
Number of agents to spawn on the map. Potentially in the future,
a range of number of agents to sample from.
obs_builder_object: ObservationBuilder object
ObservationBuilder-derived object that builds observation
vectors for each agent.
remove_agents_at_target : bool
If set to true, agents are removed from the grid (placed at
RailEnv.DEPOT_POSITION) once they reach their target position.
random_seed : int or None
If None, it is ignored; otherwise the random generators are seeded with this number so that
stochastic operations are reproducible across runs.
acceleration_delta : float
Determines how much speed is increased by a MOVE_FORWARD action, up to the max_speed set by the train's Line (sampled from `speed_ratios` by the `LineGenerator`).
As speed is between 0.0 and 1.0, acceleration_delta=1.0 restores the previous constant-speed behaviour
(i.e. MOVE_FORWARD always sets the maximum speed allowed for the train).
braking_delta : float
Determines how much speed is decreased by a STOP_MOVING action.
As speed is between 0.0 and 1.0, braking_delta=-1.0 restores the previous full-stop behaviour.
See the worked example in Notes below.
rewards : Rewards
The rewards function to use. Defaults to standard settings of Flatland 3 behaviour.
effects_generator : Optional[EffectsGenerator["RailEnv"]]
The effects generator, which can modify the env at the end of env reset, at the beginning of an env step and at the end of an env step.
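Notes
-----
Worked speed example (values are illustrative, not defaults): with
acceleration_delta=0.25 and max_speed=1.0 from the train's Line, a stopped
train issuing MOVE_FORWARD reaches speeds 0.25, 0.5, 0.75 and 1.0 over four
steps; a single STOP_MOVING with braking_delta=-1.0 brings it back to 0.0,
since the new speed is always clipped to the range [0.0, max_speed].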
"""
super().__init__()
if malfunction_generator_and_process_data is not None:
print("DEPRECATED - RailEnv arg: malfunction_and_process_data - use malfunction_generator")
self.malfunction_generator, self.malfunction_process_data = malfunction_generator_and_process_data
elif malfunction_generator is not None:
self.malfunction_generator = malfunction_generator
# malfunction_process_data is not used
# self.malfunction_generator, self.malfunction_process_data = malfunction_generator_and_process_data
self.malfunction_process_data = self.malfunction_generator.get_process_data()
# replace default values here because we can't use default args values because of cyclic imports
else:
self.malfunction_generator = mal_gen.NoMalfunctionGen()
self.malfunction_process_data = self.malfunction_generator.get_process_data()
self.number_of_agents = number_of_agents
if rail_generator is None:
rail_generator = rail_gen.sparse_rail_generator()
self.rail_generator = rail_generator
if line_generator is None:
line_generator = line_gen.sparse_line_generator()
self.line_generator: "LineGenerator" = line_generator
self.timetable_generator = timetable_generator
self.rail: Optional[RailGridTransitionMap] = None
self.width = width
self.height = height
self.remove_agents_at_target = remove_agents_at_target
self.obs_builder = obs_builder_object
self.obs_builder.set_env(self)
self._max_episode_steps: Optional[int] = None
self._elapsed_steps = 0
self.obs_dict = {}
self.rewards_dict = {}
self.dev_obs_dict = {}
self.dev_pred_dict = {}
self.agents: List[EnvAgent] = []
self.num_resets = 0
self.distance_map = DistanceMap(self.agents, self.height, self.width)
self.action_space = [5]
self._seed(seed=random_seed)
self.agent_positions = None
# save episode timesteps, i.e. agent positions and orientations (not yet actions / observations)
self.record_steps = record_steps # whether to save timesteps
# save timesteps in here: [[[row, col, dir, malfunction],...nAgents], ...nSteps]
self.cur_episode = []
self.list_actions = [] # save actions in here
self.motion_check = ac.MotionCheck()
self.level_free_positions: Set[Vector2D] = set()
if rewards is None:
self.rewards = Rewards()
else:
self.rewards = rewards
self.acceleration_delta = acceleration_delta
self.braking_delta = braking_delta
self.effects_generator = effects_generator
self.temp_transition_data = {i: env_utils.AgentTransitionData(None, None, None, None, None, None, None, None) for i in range(self.get_num_agents())}
for i_agent in range(self.get_num_agents()):
self.temp_transition_data[i_agent].state_transition_signal = StateTransitionSignals()
def _seed(self, seed):
self.np_random, seed = seeding.np_random(seed)
random.seed(seed)
self.random_seed = seed
# Keep track of all the seeds in order
if not hasattr(self, 'seed_history'):
self.seed_history = [seed]
if self.seed_history[-1] != seed:
self.seed_history.append(seed)
return [seed]
# no more agent_handles
def get_agent_handles(self) -> List[int]:
return list(range(self.get_num_agents()))
def get_num_agents(self) -> int:
return len(self.agents)
def add_agent(self, agent):
""" Add static info for a single agent.
Returns the index of the new agent.
"""
self.agents.append(agent)
return len(self.agents) - 1
def reset_agents(self):
""" Reset the agents to their starting positions
"""
for agent in self.agents:
agent.reset()
self.active_agents = [i for i in range(len(self.agents))]
@lru_cache()
@staticmethod
def action_required(agent_state, is_cell_entry):
"""
Check whether an agent needs to provide an action.
Parameters
----------
agent_state : TrainState
State of the agent to check
is_cell_entry : bool
Whether the agent has just entered a new cell
Returns
-------
True: Agent needs to provide an action
False: Agent does not need to provide an action
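Examples
--------
Illustrative calls (TrainState as imported in this module):
>>> RailEnv.action_required(TrainState.READY_TO_DEPART, False)
True
>>> RailEnv.action_required(TrainState.MOVING, False)
False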
"""
return agent_state == TrainState.READY_TO_DEPART or \
(agent_state.is_on_map_state() and is_cell_entry)
def reset(self, regenerate_rail: bool = True, regenerate_schedule: bool = True, *, random_seed: int = None) -> Tuple[Dict, Dict]:
"""
reset(regenerate_rail, regenerate_schedule, random_seed)
The method resets the rail environment
Parameters
----------
regenerate_rail : bool, optional
regenerate the rails
regenerate_schedule : bool, optional
regenerate the schedule and the static agents
random_seed : int, optional
random seed for environment
Returns
-------
observation_dict: Dict
Dictionary with an observation for each agent
info_dict: Dict with agent specific information
"""
if random_seed is not None:
self._seed(random_seed)
optionals = {}
if regenerate_rail or self.rail is None:
if "__call__" in dir(self.rail_generator):
rail, optionals = self.rail_generator(
self.width, self.height, self.number_of_agents, self.num_resets, self.np_random)
elif "generate" in dir(self.rail_generator):
rail, optionals = self.rail_generator.generate(
self.width, self.height, self.number_of_agents, self.num_resets, self.np_random)
else:
raise ValueError("Could not invoke __call__ or generate on rail_generator")
self.rail = rail
self.height, self.width = self.rail.grid.shape
# Do a new set_env call on the obs_builder to ensure
# that obs_builder specific instantiations are made according to the
# specifications of the current environment : like width, height, etc
self.obs_builder.set_env(self)
if optionals and 'distance_map' in optionals:
self.distance_map.set(optionals['distance_map'])
if regenerate_schedule or regenerate_rail or self.get_num_agents() == 0:
agents_hints = None
if optionals and 'agents_hints' in optionals:
agents_hints = optionals['agents_hints']
if optionals and 'level_free_positions' in optionals:
self.level_free_positions = optionals['level_free_positions']
line = self.line_generator(self.rail, self.number_of_agents, agents_hints,
self.num_resets, self.np_random)
self.agents = EnvAgent.from_line(line)
# Reset distance map - basically initializing
self.distance_map.reset(self.agents, self.rail)
# NEW : Timetable Generation
timetable = self.timetable_generator(self.agents, self.distance_map,
agents_hints, self.np_random)
self._max_episode_steps = timetable.max_episode_steps
EnvAgent.apply_timetable(self.agents, timetable)
else:
self.distance_map.reset(self.agents, self.rail)
# Reset agents to initial states
self.reset_agents()
self.num_resets += 1
self._elapsed_steps = 0
# Agent positions map
self.agent_positions = np.full((self.height, self.width), -1, dtype=int)
self._update_agent_positions_map(ignore_old_positions=False)
if self.effects_generator is not None:
self.effects_generator.on_episode_start(self)
self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False)
# Reset the state of the observation builder with the new environment
self.obs_builder.reset()
# Empty the episode store of agent positions
self.cur_episode = []
self.temp_transition_data = {i: env_utils.AgentTransitionData(None, None, None, None, None, None, None, None) for i in range(self.get_num_agents())}
for i_agent in range(self.get_num_agents()):
self.temp_transition_data[i_agent].state_transition_signal = StateTransitionSignals()
info_dict = self.get_info_dict()
# Return the new observation vectors for each agent
observation_dict: Dict = self._get_observations()
if hasattr(self, "renderer") and self.renderer is not None:
self.renderer = None
return observation_dict, info_dict
def _update_agent_positions_map(self, ignore_old_positions=True):
""" Update the agent_positions array for agents that changed positions """
for agent in self.agents:
if not ignore_old_positions or agent.old_position != agent.position:
if agent.position is not None:
self.agent_positions[agent.position] = agent.handle
if agent.old_position is not None:
self.agent_positions[agent.old_position] = -1
def clear_rewards_dict(self):
""" Reset the rewards dictionary """
self.rewards_dict = {i_agent: 0 for i_agent in range(len(self.agents))}
def get_info_dict(self):
"""
Returns a dictionary of infos for all agents.
dict_keys:
action_required - whether the agent needs to provide an action at this step
malfunction - malfunction counter value; > 0 means the train is malfunctioning
speed - speed of the train
state - state from the train's state machine
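Example shape for a two-agent env (values are illustrative):
{'action_required': {0: True, 1: False},
'malfunction': {0: 0, 1: 2},
'speed': {0: 1.0, 1: 0.5},
'state': {0: TrainState.MOVING, 1: TrainState.MALFUNCTION}}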
"""
info_dict = {
# TODO https://github.com/flatland-association/flatland-rl/issues/149 revise action required
'action_required': {i: RailEnv.action_required(agent.state, agent.speed_counter.is_cell_entry)
for i, agent in enumerate(self.agents)},
'malfunction': {
i: agent.malfunction_handler.malfunction_down_counter for i, agent in enumerate(self.agents)
},
'speed': {i: agent.speed_counter.speed for i, agent in enumerate(self.agents)},
'state': {i: agent.state for i, agent in enumerate(self.agents)}
}
return info_dict
def end_of_episode_update(self, have_all_agents_ended):
"""
Updates made when the episode ends.
Parameters
----------
have_all_agents_ended : bool
Indicates whether all agents have reached the DONE state
"""
if have_all_agents_ended or \
((self._max_episode_steps is not None) and (self._elapsed_steps >= self._max_episode_steps)):
for i_agent, agent in enumerate(self.agents):
reward = self.rewards.end_of_episode_reward(agent, self.distance_map, self._elapsed_steps)
self.rewards_dict[i_agent] += reward
self.dones[i_agent] = True
self.dones["__all__"] = True
def handle_done_state(self, agent):
""" Any updates to agent to be made in Done state """
if agent.state == TrainState.DONE and agent.arrival_time is None:
agent.arrival_time = self._elapsed_steps
self.dones[agent.handle] = True
if self.remove_agents_at_target:
agent.position = None
def step(self, action_dict: Dict[int, RailEnvActions]):
"""
Performs one environment time step: preprocesses the given actions, resolves
movement conflicts between agents, updates agent states, positions, speeds and
rewards, and returns the tuple (observations, rewards, dones, info).
"""
self._elapsed_steps += 1
# Not allowed to step further once done
if self.dones["__all__"]:
raise Exception("Episode is done, cannot call step()")
self.clear_rewards_dict()
self.motion_check = ac.MotionCheck() # reset the motion check
if self.effects_generator is not None:
self.effects_generator.on_episode_step_start(self)
for agent in self.agents:
i_agent = agent.handle
agent.old_position = agent.position
agent.old_direction = agent.direction
# Generate malfunction
agent.malfunction_handler.generate_malfunction(self.malfunction_generator, self.np_random)
# Get action for the agent
raw_action = action_dict.get(i_agent, RailEnvActions.DO_NOTHING)
# Try moving actions on current position
current_position, current_direction = agent.position, agent.direction
if current_position is None: # Agent not added on map yet
current_position, current_direction = agent.initial_position, agent.initial_direction
_, new_direction_independent, new_position_independent, _, preprocessed_action = self.rail.check_action_on_agent(
RailEnvActions.from_value(raw_action),
current_position,
current_direction
)
# get desired new_position and new_direction
stop_action_given = preprocessed_action == RailEnvActions.STOP_MOVING
in_malfunction = agent.malfunction_handler.in_malfunction
movement_action_given = RailEnvActions.is_moving_action(preprocessed_action)
earliest_departure_reached = agent.earliest_departure <= self._elapsed_steps
new_speed = agent.speed_counter.speed
state = agent.state
agent_max_speed = agent.speed_counter.max_speed
# TODO revise design: should we, instead of correcting LEFT/RIGHT to FORWARD, preprocess to DO_NOTHING? Caveat: DO_NOTHING would be undefined for symmetric switches!
if (state == TrainState.STOPPED or state == TrainState.MALFUNCTION) and movement_action_given:
# start moving
new_speed += self.acceleration_delta
elif preprocessed_action == RailEnvActions.MOVE_FORWARD and raw_action == RailEnvActions.MOVE_FORWARD:
# accelerate, but not if left/right corrected to forward
new_speed += self.acceleration_delta
elif stop_action_given:
# decelerate
new_speed += self.braking_delta
new_speed = max(0.0, min(agent_max_speed, new_speed))
if state == TrainState.READY_TO_DEPART and movement_action_given:
new_position = agent.initial_position
new_direction = agent.initial_direction
elif state == TrainState.MALFUNCTION_OFF_MAP and not in_malfunction and earliest_departure_reached and (
movement_action_given or stop_action_given):
# TODO revise design: weirdly, MALFUNCTION_OFF_MAP does not go via READY_TO_DEPART, but STOP_MOVING and MOVE_* add the agent to the map if possible
new_position = agent.initial_position
new_direction = agent.initial_direction
elif state.is_on_map_state():
new_position, new_direction = agent.position, agent.direction
# transition to next cell: at end of cell and next state potentially MOVING
if (agent.speed_counter.is_cell_exit(new_speed)
and
TrainStateMachine.can_get_moving_independent(state, in_malfunction, movement_action_given, new_speed, stop_action_given)
):
new_position, new_direction = new_position_independent, new_direction_independent
assert agent.position is not None
else:
assert state.is_off_map_state() or state == TrainState.DONE
new_position = None
new_direction = None
if new_position is not None:
valid_position_direction = any(self.rail.get_transitions(*new_position, new_direction))
if not valid_position_direction:
warnings.warn(f"{(new_position, new_direction)} not valid on the grid."
f" Coming from {(agent.position, agent.direction)} with raw action {raw_action} and preprocessed action {preprocessed_action}. {RailEnvTransitionsEnum(self.rail.get_full_transitions(*agent.position)).name}")
assert valid_position_direction
# only conflict if the level-free cell is traversed along the same axis (vertically: 0 north or 2 south; horizontally: 1 east or 3 west)
new_position_level_free = new_position
if new_position in self.level_free_positions:
new_position_level_free = (new_position, new_direction % 2)
agent_position_level_free = agent.position
if agent.position in self.level_free_positions:
agent_position_level_free = (agent.position, agent.direction % 2)
# Malfunction starts when in_malfunction is set to true (inverse of malfunction_counter_complete)
self.temp_transition_data[i_agent].state_transition_signal.in_malfunction = agent.malfunction_handler.in_malfunction
# Earliest departure reached - Train is allowed to move now
self.temp_transition_data[i_agent].state_transition_signal.earliest_departure_reached = self._elapsed_steps >= agent.earliest_departure
# Stop action given
self.temp_transition_data[i_agent].state_transition_signal.stop_action_given = stop_action_given
# Movement action given
self.temp_transition_data[i_agent].state_transition_signal.movement_action_given = movement_action_given
# Target reached - we only know after state and positions update - see handle_done_state below
self.temp_transition_data[i_agent].state_transition_signal.target_reached = None # we only know after motion check
# Movement allowed if inside cell or at end of cell and no conflict with other trains - we only know after motion check!
self.temp_transition_data[i_agent].state_transition_signal.movement_allowed = None # we only know after motion check
# New desired speed zero?
self.temp_transition_data[i_agent].state_transition_signal.new_speed_zero = new_speed == 0.0
self.temp_transition_data[i_agent].speed = agent.speed_counter.speed
self.temp_transition_data[i_agent].agent_position_level_free = agent_position_level_free
self.temp_transition_data[i_agent].new_position = new_position
self.temp_transition_data[i_agent].new_direction = new_direction
self.temp_transition_data[i_agent].new_speed = new_speed
self.temp_transition_data[i_agent].new_position_level_free = new_position_level_free
self.temp_transition_data[i_agent].preprocessed_action = preprocessed_action
# self.temp_transition_data[i_agent].state_transition_signal = state_transition_signals
self.motion_check.add_agent(i_agent, agent_position_level_free, new_position_level_free)
# Find conflicts between trains trying to occupy same cell
self.motion_check.find_conflicts()
have_all_agents_ended = True
for agent in self.agents:
i_agent = agent.handle
# Fetch the saved transition data
agent_transition_data = self.temp_transition_data[i_agent]
# motion_check is False if agent wants to stay in the cell
motion_check = self.motion_check.check_motion(i_agent, agent_transition_data.agent_position_level_free)
# Movement allowed if inside cell or at end of cell and no conflict with other trains
movement_allowed = (agent.state.is_on_map_state() and not agent.speed_counter.is_cell_exit(agent_transition_data.new_speed)) or motion_check
agent_transition_data.state_transition_signal.movement_allowed = movement_allowed
# state machine step
agent.state_machine.set_transition_signals(agent_transition_data.state_transition_signal)
agent.state_machine.step()
# position and speed_counter update
if agent.state == TrainState.MOVING:
# only position update while MOVING and motion_check OK
agent.position = agent_transition_data.new_position
agent.direction = agent_transition_data.new_direction
# N.B. no movement in first time step after READY_TO_DEPART or MALFUNCTION_OFF_MAP!
if not (agent.state_machine.previous_state == TrainState.READY_TO_DEPART or
agent.state_machine.previous_state == TrainState.MALFUNCTION_OFF_MAP):
agent.speed_counter.step(speed=agent_transition_data.new_speed)
agent.state_machine.update_if_reached(agent.position, agent.target)
elif agent.state_machine.previous_state == TrainState.MALFUNCTION_OFF_MAP and agent.state == TrainState.STOPPED:
agent.position = agent.initial_position
agent.direction = agent.initial_direction
# TODO revise design: condition could be generalized to not MOVING if we would enforce MALFUNCTION_OFF_MAP to go to READY_TO_DEPART first.
if agent.state.is_on_map_state() and agent.state != TrainState.MOVING:
agent.speed_counter.step(speed=0)
# Handle done state actions, optionally remove agents
self.handle_done_state(agent)
have_all_agents_ended &= (agent.state == TrainState.DONE)
## Update rewards
self.rewards_dict[i_agent] += self.rewards.step_reward(agent, agent_transition_data, self.distance_map, self._elapsed_steps)
# update malfunction counter
agent.malfunction_handler.update_counter()
# Off map or on map state and position should match
if not self._fast_state_position_sync_check(agent.state, agent.position, self.remove_agents_at_target):
agent.state_machine.state_position_sync_check(agent.position, agent.handle, self.remove_agents_at_target)
# Check if episode has ended and update rewards and dones
self.end_of_episode_update(have_all_agents_ended)
self._update_agent_positions_map()
self._verify_mutually_exclusive_resource_allocation()
if self.record_steps:
self.record_timestep(action_dict)
if self.effects_generator is not None:
self.effects_generator.on_episode_step_end(self)
return self._get_observations(), self.rewards_dict, self.dones, self.get_info_dict()
@lru_cache()
def _fast_state_position_sync_check(self, state, position, remove_agents_at_target):
""" Check for whether on map and off map states are matching with position being None """
if TrainState.is_on_map_state(state) and position is None:
return False
elif TrainState.is_off_map_state(state) and position is not None:
return False
elif state == TrainState.DONE and remove_agents_at_target and position is not None:
return False
return True
def _verify_mutually_exclusive_resource_allocation(self):
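# Each occupied cell is treated as a resource; a level-free crossing cell is split
# into two resources keyed by axis (direction % 2), so two trains may traverse it
# on different axes at the same time without conflict.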
resources = [agent.position if agent.position not in self.level_free_positions else (*agent.position, agent.direction % 2) for agent in self.agents if
agent.position is not None]
if len(resources) != len(set(resources)):
msgs = f"Found two agents occupying same resource (cell or level-free cell) in step {self._elapsed_steps}: {resources}\n"
msgs += f"- motion check: {list(self.motion_check.stopped)}"
warnings.warn(msgs)
counts = {resource: resources.count(resource) for resource in set(resources)}
dup_positions = [pos for pos, count in counts.items() if count > 1]
for dup in dup_positions:
for agent in self.agents:
if agent.position == dup:
msg = (f"\n================== BAD AGENT ==================================\n\n\n\n\n"
f"- agent:\t{agent} \n"
f"- state_machine:\t{agent.state_machine}\n"
f"- speed_counter:\t{agent.speed_counter}\n"
f"- breakpoint:\tself._elapsed_steps == {self._elapsed_steps} and agent.handle == {agent.handle}\n"
f"- motion check:\t{list(self.motion_check.stopped)}\n\n\n"
f"- agents:\t{self.agents}")
warnings.warn(msg)
msgs += msg
assert len(resources) == len(set(resources)), msgs
# TODO extract to callbacks instead!
def record_timestep(self, dActions):
"""
Record the positions and orientations of all agents in memory, in the cur_episode
"""
list_agents_state = []
for i_agent in range(self.get_num_agents()):
agent = self.agents[i_agent]
# the int cast is to avoid numpy types which may cause problems with msgpack
# in env v2, agents may have position None, before starting
if agent.position is None:
pos = (0, 0)
else:
pos = (int(agent.position[0]), int(agent.position[1]))
# print("pos:", pos, type(pos[0]))
list_agents_state.append([
*pos, int(agent.direction),
agent.malfunction_handler.malfunction_down_counter,
agent.state.value,
int(agent.position in self.motion_check.deadlocked),
])
self.cur_episode.append(list_agents_state)
self.list_actions.append(dActions)
def _get_observations(self):
"""
Utility which returns the dictionary of observations for all agents, as built by the obs_builder.
"""
# print(f"_get_obs - num agents: {self.get_num_agents()} {list(range(self.get_num_agents()))}")
self.obs_dict = self.obs_builder.get_many(list(range(self.get_num_agents())))
return self.obs_dict
def get_valid_directions_on_grid(self, row: int, col: int) -> List[int]:
"""
Returns directions in which the agent can move
"""
return Grid4Transitions.get_entry_directions(self.rail.get_full_transitions(row, col))
def _exp_distirbution_synced(self, rate: float) -> float:
"""
Generates a sample from an exponential distribution using inverse-CDF (inversion) sampling.
We need this to guarantee synchronicity between different instances with the same seed.
Note: the `rate` parameter is used as the scale (mean) of the distribution.
:param rate:
:return:
"""
u = self.np_random.rand()
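# Inversion sampling: for U ~ Uniform(0, 1), -ln(1 - U) * scale is exponentially
# distributed with mean `scale`.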
x = - np.log(1 - u) * rate
return x
def _is_agent_ok(self, agent: EnvAgent) -> bool:
"""
Checks if an agent is ok, meaning it can move and is not malfunctioning.
Parameters
----------
agent
Returns
-------
True if agent is ok, False otherwise
"""
return not agent.malfunction_handler.in_malfunction
def save(self, filename):
print("DEPRECATED call to env.save() - pls call RailEnvPersister.save()")
persistence.RailEnvPersister.save(self, filename)
def render(self, mode="rgb_array", gl="PGL", agent_render_variant=AgentRenderVariant.ONE_STEP_BEHIND,
show_debug=False, clear_debug_text=True, show=False,
screen_height=600, screen_width=800,
show_observations=False, show_predictions=False,
show_rowcols=False, return_image=True):
"""
Provides the option to render the
environment's behavior as an image or to a window.
Parameters
----------
mode
Returns
-------
Image if mode is rgb_array, opens a window otherwise
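Example (illustrative sketch; assumes `env` is a RailEnv that has been reset):
>>> frame = env.render(mode="rgb_array", show=False)  # returns an H x W x 3 RGB array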
"""
if not hasattr(self, "renderer") or self.renderer is None:
self.initialize_renderer(mode=mode, gl=gl, # gl="TKPILSVG",
agent_render_variant=agent_render_variant,
show_debug=show_debug,
clear_debug_text=clear_debug_text,
show=show,
screen_height=screen_height, # Adjust these parameters to fit your resolution
screen_width=screen_width)
return self.update_renderer(mode=mode, show=show, show_observations=show_observations,
show_predictions=show_predictions,
show_rowcols=show_rowcols, return_image=return_image)
def initialize_renderer(self, mode, gl,
agent_render_variant,
show_debug,
clear_debug_text,
show,
screen_height,
screen_width):
# Initiate the renderer
self.renderer = RenderTool(self, gl=gl, # gl="TKPILSVG",
agent_render_variant=agent_render_variant,
show_debug=show_debug,
clear_debug_text=clear_debug_text,
screen_height=screen_height, # Adjust these parameters to fit your resolution
screen_width=screen_width) # Adjust these parameters to fit your resolution
self.renderer.show = show
self.renderer.reset()
def update_renderer(self, mode, show, show_observations, show_predictions,
show_rowcols, return_image):
"""
This method updates the renderer.
Parameters
----------
mode
Returns
-------
Image if mode is rgb_array, None otherwise
"""
image = self.renderer.render_env(show=show, show_observations=show_observations,
show_predictions=show_predictions,
show_rowcols=show_rowcols, return_image=return_image)
if mode == 'rgb_array':
return image[:, :, :3]
def close(self):
"""
Closes any renderer window.
"""
if hasattr(self, "renderer") and self.renderer is not None:
try:
if self.renderer.show:
self.renderer.close_window()
except Exception as e:
print("Could Not close window due to:", e)
self.renderer = None