# Source code for flatland.envs.timetable_generators
"""Timetable generators: Railway Undertaking (RU) / Eisenbahnverkehrsunternehmen (EVU)."""
import functools
import pickle
import warnings
from pathlib import Path
from typing import Callable, List, Optional, Union

import numpy as np
from numpy.random.mtrand import RandomState

from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.envs import persistence
from flatland.envs.agent_utils import EnvAgent
from flatland.envs.grid.distance_map import DistanceMap
from flatland.envs.timetable_utils import Timetable
[docs]
def _waypoint_time_windows(earliest_departure, latest_arrival, segment_lengths):
    """
    Spread an agent's overall time window over the waypoints of its route.

    Earliest departures are propagated forward from ``earliest_departure`` and
    latest arrivals backward from ``latest_arrival`` by adding / subtracting the
    shortest-path length of each route segment (path length only; the agent's
    speed is not applied here).

    Returns ``(eds, las)``: two lists with ``len(segment_lengths) + 1`` entries,
    one per waypoint. There is no departure from the last waypoint and no
    arrival constraint at the first, so ``eds[-1]`` and ``las[0]`` are ``None``.
    """
    eds = [earliest_departure]
    for length in segment_lengths:
        eds.append(eds[-1] + length)
    las = [latest_arrival]
    for length in reversed(segment_lengths):
        las.insert(0, las[0] - length)
    eds[-1] = None
    las[0] = None
    return eds, las


def timetable_generator(agents: List[EnvAgent], distance_map: DistanceMap,
                        agents_hints: dict, np_random: RandomState = None) -> Timetable:
    """
    Calculate earliest departure and latest arrival times for every agent.

    This is the timetable logic added in Flatland 3. It also derives the
    ``max_episode_steps`` of the episode from the grid size, the density of
    agents per city and the agents' speed-normalized shortest-path times.

    Parameters
    ----------
    agents : List[EnvAgent]
        All agents of the environment (``rail_env.agents``). Their
        ``earliest_departure`` and ``latest_arrival`` attributes are set as a
        side effect.
    distance_map : DistanceMap
        Distance map of positions to targets of each agent in each direction.
    agents_hints : dict
        Only ``agents_hints['city_positions']`` is read, to count the cities.
        When falsy, two cities are assumed.
    np_random : RandomState, optional
        RNG used to draw the departure times. When ``None``, a fresh unseeded
        ``RandomState`` is used, so results are not reproducible.

    Returns
    -------
    Timetable
        ``earliest_departures`` and ``latest_arrivals`` hold one list per agent
        with one entry per waypoint. The last departure and the first arrival
        are ``None``. ``max_episode_steps`` is also returned.
    """
    if np_random is None:
        # Fall back to an unseeded generator so the ``None`` default is usable.
        np_random = RandomState()

    # --- Baseline ("old") episode length from grid size and agent density ---
    if agents_hints:
        num_cities = len(agents_hints['city_positions'])
    else:
        num_cities = 2
    timedelay_factor = 4
    alpha = 2
    num_agents = len(agents)
    max_episode_steps = int(timedelay_factor * alpha *
                            (distance_map.rail.width + distance_map.rail.height + (num_agents / num_cities)))

    # Multipliers
    old_max_episode_steps_multiplier = 3.0
    new_max_episode_steps_multiplier = 1.5
    travel_buffer_multiplier = 1.3  # must be strictly less than new_max_episode_steps_multiplier
    assert new_max_episode_steps_multiplier > travel_buffer_multiplier
    end_buffer_multiplier = 0.05
    mean_shortest_path_multiplier = 0.2

    # --- Shortest path length of every agent, split into route segments ---
    if len(agents[0].waypoints) > 1:
        # Multi-stop routes: one helper agent per (segment i, agent) pair travels
        # from waypoint i to waypoint i + 1; the segment lengths are summed per agent.
        # NOTE(review): assumes every agent has as many waypoints as agents[0] and
        # that handles are 0 .. num_agents - 1 - confirm.
        line_length = len(agents[0].waypoints) - 1
        fake_agents = []
        for i in range(line_length):
            for a in agents:
                waypoints = a.waypoints
                fake_agents.append(EnvAgent(
                    # Encodes (segment, agent): handle % num_agents recovers the agent.
                    handle=i * num_agents + a.handle,
                    initial_configuration=(waypoints[i][0].position, waypoints[i][0].direction),
                    # N.B. routing flexibility is ignored by this timetable generator
                    current_configuration=(waypoints[i][0].position, waypoints[i][0].direction),
                    old_configuration=(None, None),
                    # Any direction at the next waypoint's position counts as a target.
                    targets={(waypoints[i + 1][0].position, d) for d in Grid4TransitionsEnum},
                ))
        distance_map_with_intermediates = DistanceMap(fake_agents, distance_map.env_height, distance_map.env_width)
        distance_map_with_intermediates.reset(fake_agents, distance_map.rail)
        shortest_paths = distance_map_with_intermediates.get_shortest_paths()
        shortest_path_segment_lengths = [[] for _ in range(num_agents)]
        for k, v in shortest_paths.items():
            shortest_path_segment_lengths[k % num_agents].append(len_handle_none(v))
        shortest_paths_lengths = [sum(segments) for segments in shortest_path_segment_lengths]
    else:
        shortest_paths = distance_map.get_shortest_paths()
        shortest_paths_lengths = [len_handle_none(v) for v in shortest_paths.values()]
        shortest_path_segment_lengths = [[length] for length in shortest_paths_lengths]

    # --- Episode length from speed-normalized shortest path times ---
    agent_speeds = [agent.speed_counter.speed for agent in agents]
    agent_shortest_path_times = np.array(shortest_paths_lengths) / np.array(agent_speeds)
    mean_shortest_path_time = np.mean(agent_shortest_path_times)
    longest_speed_normalized_time = np.max(agent_shortest_path_times)
    mean_path_delay = mean_shortest_path_time * mean_shortest_path_multiplier
    max_episode_steps_new = int(np.ceil(longest_speed_normalized_time * new_max_episode_steps_multiplier) + mean_path_delay)
    max_episode_steps_old = int(max_episode_steps * old_max_episode_steps_multiplier)
    # Use the tighter of the path-based and the grid-size-based estimates.
    max_episode_steps = min(max_episode_steps_new, max_episode_steps_old)

    end_buffer = int(max_episode_steps * end_buffer_multiplier)
    latest_arrival_max = max_episode_steps - end_buffer

    # --- Per-agent time windows ---
    earliest_departures = []
    latest_arrivals = []
    for agent in agents:
        # NOTE(review): indexes by handle, assuming handle == position in `agents`.
        agent_shortest_path_time = agent_shortest_path_times[agent.handle]
        agent_travel_time_max = int(np.ceil((agent_shortest_path_time * travel_buffer_multiplier) + mean_path_delay))
        # Departures are drawn so the latest arrival stays below latest_arrival_max;
        # the window is clamped to 1 so departure step 0 is always possible.
        departure_window_max = max(latest_arrival_max - agent_travel_time_max, 1)
        earliest_departure = np_random.randint(0, departure_window_max)
        latest_arrival = earliest_departure + agent_travel_time_max
        agent.earliest_departure = earliest_departure
        agent.latest_arrival = latest_arrival
        eds, las = _waypoint_time_windows(
            earliest_departure, latest_arrival, shortest_path_segment_lengths[agent.handle])
        earliest_departures.append(eds)
        latest_arrivals.append(las)
    return Timetable(earliest_departures=earliest_departures, latest_arrivals=latest_arrivals,
                     max_episode_steps=max_episode_steps)
[docs]
def ttgen_flatland2(agents: List[EnvAgent], distance_map: DistanceMap,
                    agents_hints: dict, np_random: RandomState = None) -> Timetable:
    """
    Flatland 2 style timetable: every agent departs from step 0 and arrives by step 1000.

    Only ``len(agents)`` is used. ``distance_map``, ``agents_hints`` and
    ``np_random`` are accepted for signature compatibility with the other
    timetable generators.

    Returns
    -------
    Timetable
        One ``[0]`` earliest departure and one ``[1000]`` latest arrival per
        agent, with ``max_episode_steps`` of 1000.
    """
    n_max_steps = 1000
    num_agents = len(agents)
    # Build a fresh inner list per agent: ``[[0]] * n`` would repeat one list object
    # n times, so mutating one agent's entry would change every agent's entry.
    return Timetable(
        earliest_departures=[[0] for _ in range(num_agents)],
        latest_arrivals=[[n_max_steps] for _ in range(num_agents)],
        max_episode_steps=n_max_steps)
[docs]
class FileTimetableGenerator:
    """
    Timetable generator that replays a pickled ``Timetable``.

    The timetable is read from a file on disk or, when ``load_from_package`` is
    given, from a resource inside that package. Instances are callable and
    ignore all call arguments.

    NOTE(security): the data is unpickled, which can execute arbitrary code -
    only load timetables from trusted sources.
    """

    def __init__(self, filename: Path, load_from_package: Optional[str] = None):
        """
        Parameters
        ----------
        filename : Path
            Path of the pickle file, or the resource name inside
            ``load_from_package``.
        load_from_package : Optional[str]
            Package passed to ``read_binary``. ``None`` reads ``filename`` from
            the file system.
        """
        self.filename = filename
        self.load_from_package = load_from_package

    def generate(self, *args, **kwargs) -> Timetable:
        """Load and return the pickled timetable; all arguments are ignored."""
        if self.load_from_package is not None:
            from importlib_resources import read_binary
            load_data = read_binary(self.load_from_package, self.filename)
        else:
            with open(self.filename, "rb") as file_in:
                load_data = file_in.read()
        # Unpickling untrusted data is unsafe - see class docstring.
        return pickle.loads(load_data)

    @staticmethod
    def save(filename: Path, tt: Timetable):
        """Pickle ``tt`` to ``filename``, overwriting any existing file."""
        with open(filename, "wb") as file_out:
            pickle.dump(tt, file_out)

    def __call__(self, *args, **kwargs):
        """Alias for :meth:`generate`."""
        return self.generate(*args, **kwargs)

    @staticmethod
    def wrap(timetable_generator: Callable[..., Timetable], tt_pkl: Path) -> Callable[..., Timetable]:
        """
        Wrap ``timetable_generator`` so each timetable it returns is also saved.

        The returned callable forwards all arguments to ``timetable_generator``,
        pickles the result to ``tt_pkl`` via :meth:`save`, and returns it.
        """

        @functools.wraps(timetable_generator)
        def _wrap(*args, **kwargs):
            tt = timetable_generator(*args, **kwargs)
            FileTimetableGenerator.save(tt_pkl, tt)
            return tt

        return _wrap
[docs]
def timetable_from_file(filename: Optional[Union[str, Path]] = None, load_from_package=None,
                        env_dict=None) -> Callable[..., Timetable]:
    """
    Build a timetable generator that replays the timetable of a persisted env.

    Uses ``env_dict`` if populated. Otherwise the env dict is loaded from the
    file / package via ``RailEnvPersister.load_env_dict`` each time the
    returned generator is called.

    Parameters
    ----------
    filename : Union[str, Path]
        pickle file with persisted env, defaults to `None`.
    load_from_package : str
        package, defaults to `None`.
    env_dict: dict
        env_dict, defaults to `None`.

    Returns
    -------
    Callable[..., Timetable]
        A generator with the same signature as the other timetable generators.
        It ignores its arguments and returns a ``Timetable`` with one earliest
        departure and one latest arrival per persisted agent, taken from the
        agents' ``earliest_departure`` / ``latest_arrival`` attributes.
        ``max_episode_steps`` is taken from the env dict; it falls back to 100,
        with a warning, when missing or falsy.
    """

    def generator(agents: List[EnvAgent], distance_map: DistanceMap, agents_hints: dict,
                  np_random: RandomState = None) -> Timetable:
        _env_dict = env_dict
        if _env_dict is None:
            _env_dict = persistence.RailEnvPersister.load_env_dict(filename, load_from_package=load_from_package)
        # The persisted agents, not the ``agents`` argument, define the timetable.
        persisted_agents = _env_dict["agents"]
        max_episode_steps = _env_dict.get("max_episode_steps", 0)
        # ``not`` also catches a persisted ``None``, which ``== 0`` let through.
        if not max_episode_steps:
            warnings.warn("This env file has no max_episode_steps (deprecated) - setting to 100")
            max_episode_steps = 100
        earliest_departures = [[a.earliest_departure] for a in persisted_agents]
        latest_arrivals = [[a.latest_arrival] for a in persisted_agents]
        return Timetable(earliest_departures=earliest_departures, latest_arrivals=latest_arrivals,
                         max_episode_steps=max_episode_steps)

    return generator