flatland.trajectories.trajectories module
- class flatland.trajectories.trajectories.Trajectory(data_dir: Path, ep_id: str = NOTHING)
Bases: object
Encapsulates episode data (actions, positions etc.) for one or multiple episodes for further analysis/evaluation.
Also known as: Episode, Recording.
In contrast to RLlib (ray-project/ray), we use a tabular (TSV-backed) approach instead of `dict`s.
Directory structure:
- event_logs
  - ActionEvents.discrete_action – holds the set of actions to be replayed for the related episodes.
  - TrainMovementEvents.trains_arrived – holds the success rate for the related episodes.
  - TrainMovementEvents.trains_positions – holds the positions for the related episodes.
- serialised_state
  - <ep_id>.pkl – holds the pickled environment version for the episode.
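A minimal sketch of the layout described above, using only `pathlib`. The helper `trajectory_paths` and its dictionary keys are hypothetical, not part of Flatland, and the `.tsv` suffix on the event logs is an assumption based on the file names used by `action_lookup` and `trains_arrived_lookup`.

```python
from pathlib import Path
from typing import Dict


def trajectory_paths(data_dir: Path, ep_id: str) -> Dict[str, Path]:
    """Hypothetical helper: map each artefact of a trajectory to its path."""
    event_logs = data_dir / "event_logs"
    serialised_state = data_dir / "serialised_state"
    return {
        # Tabular event logs shared by all episodes in data_dir (assumed .tsv suffix).
        "actions": event_logs / "ActionEvents.discrete_action.tsv",
        "trains_arrived": event_logs / "TrainMovementEvents.trains_arrived.tsv",
        "trains_positions": event_logs / "TrainMovementEvents.trains_positions.tsv",
        # Pickled environment for this particular episode.
        "env_pickle": serialised_state / f"{ep_id}.pkl",
    }


paths = trajectory_paths(Path("runs/demo"), "ep_0001")
print(paths["env_pickle"].as_posix())  # runs/demo/serialised_state/ep_0001.pkl
```

All episodes share the event logs, so the episode ID only appears in the pickle file name.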
- action_collect(df: DataFrame, env_time: int, agent_id: int, action: RailEnvActions)
- action_lookup(actions_df: DataFrame, env_time: int, agent_id: int) → RailEnvActions
Retrieves the stored action (if available). Defaults to 2 (MOVE_FORWARD) otherwise.
Parameters
- actions_df: pd.DataFrame
Data frame from ActionEvents.discrete_action.tsv
- env_time: int
action going into step env_time
- agent_id: int
agent ID
Returns
- RailEnvActions
The action to step the env.
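The collect/lookup pair can be sketched as follows. This is a hypothetical stdlib illustration, not the Flatland implementation: a list of dicts stands in for the `pd.DataFrame`, the column names (`episode_id`, `env_time`, `agent_id`, `action`) are assumed, the episode ID is passed explicitly rather than taken from the `Trajectory` instance, and `RailEnvActions` here is a local stand-in with the same integer codes (2 = MOVE_FORWARD).

```python
from enum import IntEnum
from typing import Dict, List


class RailEnvActions(IntEnum):
    """Local stand-in for Flatland's action enum (same integer codes)."""
    DO_NOTHING = 0
    MOVE_LEFT = 1
    MOVE_FORWARD = 2
    MOVE_RIGHT = 3
    STOP_MOVING = 4


Row = Dict[str, object]


def action_collect(rows: List[Row], ep_id: str, env_time: int, agent_id: int,
                   action: RailEnvActions) -> None:
    # One row per (episode, step, agent); the column names are hypothetical.
    rows.append({"episode_id": ep_id, "env_time": env_time,
                 "agent_id": agent_id, "action": int(action)})


def action_lookup(rows: List[Row], ep_id: str, env_time: int,
                  agent_id: int) -> RailEnvActions:
    for row in rows:
        if (row["episode_id"], row["env_time"], row["agent_id"]) == (ep_id, env_time, agent_id):
            return RailEnvActions(row["action"])
    # Nothing recorded for this step and agent: fall back to MOVE_FORWARD (2).
    return RailEnvActions.MOVE_FORWARD


rows: List[Row] = []
action_collect(rows, "ep_0001", 0, 3, RailEnvActions.STOP_MOVING)
print(action_lookup(rows, "ep_0001", 0, 3).name)  # STOP_MOVING
print(action_lookup(rows, "ep_0001", 1, 3).name)  # MOVE_FORWARD
```

The linear scan keeps the sketch short; a data frame would filter on the three key columns instead.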
- static create_from_policy(policy: Policy, data_dir: Path, env: RailEnv | None = None, n_agents=7, x_dim=30, y_dim=30, n_cities=2, max_rail_pairs_in_city=4, grid_mode=False, max_rails_between_cities=2, malfunction_duration_min=20, malfunction_duration_max=50, malfunction_interval=540, speed_ratios=None, seed=42, obs_builder: ObservationBuilder | None = None, snapshot_interval: int = 1, ep_id: str | None = None, callbacks: FlatlandCallbacks | None = None) → Trajectory
Creates a trajectory by running a submission (policy and observation builder).
Parameters
- policy: Policy
the submission’s policy
- data_dir: Path
the path to write the trajectory to
- env: RailEnv
directly inject env, skip env generation
- n_agents: int
number of agents
- x_dim: int
number of columns
- y_dim: int
number of rows
- n_cities: int
Max number of cities to build. The generator tries to achieve this number given all the parameters. Goes into sparse_rail_generator.
- max_rail_pairs_in_city: int
Number of parallel tracks in the city. This represents the number of tracks in the train stations. Goes into sparse_rail_generator.
- grid_mode: bool
How to distribute the cities on the map: either evenly in a grid or randomly. Goes into sparse_rail_generator.
- max_rails_between_cities: int
Max number of rails connecting to a city. This is only the number of connection points at the city border.
- malfunction_duration_min: int
Minimal duration of malfunction. Goes into ParamMalfunctionGen.
- malfunction_duration_max: int
Max duration of malfunction. Goes into ParamMalfunctionGen.
- malfunction_interval: int
Inverse of the rate of malfunction occurrence. Goes into ParamMalfunctionGen.
- speed_ratios: Dict[float, float]
Speed ratios of all agents: the probabilities of the different speeds, which must add up to 1. Goes into sparse_line_generator. Defaults to {1.0: 0.25, 0.5: 0.25, 0.33: 0.25, 0.25: 0.25}.
- seed: int
Seed for the random number generators. Goes into reset.
- obs_builder: Optional[ObservationBuilder]
Defaults to TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv(max_depth=50))
- snapshot_interval: int
interval to write pkl snapshots
- ep_id: str
episode ID to store data under. If not provided, one is generated.
- callbacks: FlatlandCallbacks
callbacks to run during trajectory creation
Returns
- Trajectory
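To make the `speed_ratios` contract concrete, here is a sketch that draws one speed per agent from the ratio dictionary, using the documented default. It is an illustration only: `sample_speeds` is a hypothetical helper, and it is not the code that `sparse_line_generator` runs.

```python
import random
from typing import Dict, List, Optional

# Documented default for speed_ratios.
DEFAULT_SPEED_RATIOS = {1.0: 0.25, 0.5: 0.25, 0.33: 0.25, 0.25: 0.25}


def sample_speeds(n_agents: int, speed_ratios: Optional[Dict[float, float]] = None,
                  seed: int = 42) -> List[float]:
    ratios = DEFAULT_SPEED_RATIOS if speed_ratios is None else speed_ratios
    # The ratios are probabilities over speeds, so they must add up to 1.
    if abs(sum(ratios.values()) - 1.0) > 1e-9:
        raise ValueError("speed_ratios must add up to 1")
    rng = random.Random(seed)  # seeded, so the draw is reproducible
    speeds = list(ratios.keys())
    weights = list(ratios.values())
    # Draw one speed per agent, with replacement, weighted by the ratios.
    return rng.choices(speeds, weights=weights, k=n_agents)


speeds = sample_speeds(n_agents=7)
```

Seeding a dedicated `random.Random` instance keeps the draw reproducible without touching the global random state.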
- property outputs_dir: Path
- position_collect(df: DataFrame, env_time: int, agent_id: int, position: Tuple[Tuple[int, int], int])
- position_lookup(df: DataFrame, env_time: int, agent_id: int) → Tuple[Tuple[int, int], int]
Retrieves the stored position (if available).
Parameters
- df: pd.DataFrame
Data frame from event_logs/TrainMovementEvents.trains_positions.tsv
- env_time: int
the position before (not after) step env_time
- agent_id: int
agent ID
Returns
- Tuple[Tuple[int, int], int]
The position in the format ((row, column), direction).
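A round-trip sketch for positions in the documented `((row, column), direction)` format. It assumes, hypothetically, that each position is stored as the text of the tuple literal in a TSV cell and parsed back with `ast.literal_eval`. The column names are invented for illustration, and raising `KeyError` for a missing entry is a choice of this sketch, since what `position_lookup` does when no position is stored is not documented here.

```python
import ast
from typing import Dict, List, Tuple

Position = Tuple[Tuple[int, int], int]  # ((row, column), direction)


def position_collect(rows: List[Dict[str, str]], ep_id: str, env_time: int,
                     agent_id: int, position: Position) -> None:
    # A TSV cell holds text, so the nested tuple is stored as its repr().
    rows.append({"episode_id": ep_id, "env_time": str(env_time),
                 "agent_id": str(agent_id), "position": repr(position)})


def position_lookup(rows: List[Dict[str, str]], ep_id: str, env_time: int,
                    agent_id: int) -> Position:
    for row in rows:
        if (row["episode_id"] == ep_id and int(row["env_time"]) == env_time
                and int(row["agent_id"]) == agent_id):
            # literal_eval parses the tuple literal without executing code.
            (r, c), direction = ast.literal_eval(row["position"])
            return (r, c), direction
    raise KeyError((ep_id, env_time, agent_id))


positions: List[Dict[str, str]] = []
position_collect(positions, "ep_0001", 0, 0, ((3, 5), 1))
print(position_lookup(positions, "ep_0001", 0, 0))  # ((3, 5), 1)
```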
- read_trains_positions() → DataFrame
Returns a pd.DataFrame with the positions of all trains for all episodes.
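Reading all episodes' positions from one tab-separated log can be sketched with the standard `csv` module. `read_trains_positions` itself returns a `pd.DataFrame` (with pandas, `pd.read_csv(path, sep="\t")` reads such a file). The helpers `read_positions_tsv` and `rows_per_episode` and the column names below are hypothetical, and the rows are made-up sample data.

```python
import csv
import tempfile
from collections import defaultdict
from pathlib import Path
from typing import Dict, List


def read_positions_tsv(path: Path) -> List[Dict[str, str]]:
    # Tab-separated file with a header row; one row per (episode, step, agent).
    with path.open(newline="") as f:
        return list(csv.DictReader(f, delimiter="\t"))


def rows_per_episode(rows: List[Dict[str, str]]) -> Dict[str, int]:
    # The log covers all episodes, so group by the episode column.
    counts: Dict[str, int] = defaultdict(int)
    for row in rows:
        counts[row["episode_id"]] += 1
    return dict(counts)


with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "TrainMovementEvents.trains_positions.tsv"
    path.write_text(
        "episode_id\tenv_time\tagent_id\tposition\n"
        "ep_a\t0\t0\t((3, 5), 1)\n"
        "ep_a\t1\t0\t((3, 6), 1)\n"
        "ep_b\t0\t0\t((7, 2), 0)\n"
    )
    rows = read_positions_tsv(path)

print(rows_per_episode(rows))  # {'ep_a': 2, 'ep_b': 1}
```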
- restore_episode(start_step: int | None = None) → RailEnv
Restore an episode.
Parameters
- start_step: Optional[int]
start from snapshot (if it exists)
Returns
- RailEnv
the episode
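A sketch of snapshot-based restoring with `pickle`. The snapshot naming scheme `<ep_id>_<step>.pkl` and the fallback to `<ep_id>.pkl` when no snapshot exists for `start_step` are assumptions for illustration, not Flatland's documented behaviour. A plain dict stands in for the pickled `RailEnv`, and all three helpers are hypothetical. As with any pickle, only load files you trust.

```python
import pickle
import tempfile
from pathlib import Path
from typing import Any, Optional


def snapshot_path(state_dir: Path, ep_id: str, step: Optional[int] = None) -> Path:
    # Hypothetical naming: <ep_id>.pkl for the initial state, <ep_id>_<step>.pkl for snapshots.
    return state_dir / (f"{ep_id}.pkl" if step is None else f"{ep_id}_{step:04d}.pkl")


def save_state(state_dir: Path, ep_id: str, state: Any, step: Optional[int] = None) -> None:
    path = snapshot_path(state_dir, ep_id, step)
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("wb") as f:
        pickle.dump(state, f)


def restore_state(state_dir: Path, ep_id: str, start_step: Optional[int] = None) -> Any:
    path = snapshot_path(state_dir, ep_id, start_step)
    if start_step is not None and not path.exists():
        # Assumed fallback: no snapshot for start_step, so start from the initial state.
        path = snapshot_path(state_dir, ep_id)
    with path.open("rb") as f:
        return pickle.load(f)


with tempfile.TemporaryDirectory() as tmp:
    state_dir = Path(tmp) / "serialised_state"
    save_state(state_dir, "ep_0001", {"env_time": 0})
    save_state(state_dir, "ep_0001", {"env_time": 10}, step=10)
    restored_10 = restore_state(state_dir, "ep_0001", start_step=10)  # {'env_time': 10}
    restored_5 = restore_state(state_dir, "ep_0001", start_step=5)    # {'env_time': 0}
    restored_initial = restore_state(state_dir, "ep_0001")            # {'env_time': 0}
```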
- trains_arrived_lookup(movements_df: DataFrame) → Series
Retrieves the trains-arrived data for the episode.
Parameters
- movements_df: pd.DataFrame
Data frame from event_logs/TrainMovementEvents.trains_arrived.tsv
Returns
- pd.Series
The trains arrived data.
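A sketch of the per-episode lookup, with a list of dicts standing in for the data frame and a dict returned instead of a `pd.Series`. The helper and column names are hypothetical (the directory structure only says this log holds the success rate), as is the assumption of exactly one row per episode. The values are made-up sample data.

```python
from typing import Dict, List


def trains_arrived_lookup(rows: List[Dict[str, object]], ep_id: str) -> Dict[str, object]:
    matches = [row for row in rows if row["episode_id"] == ep_id]
    # Assumed: the log holds exactly one summary row per episode.
    if len(matches) != 1:
        raise ValueError(f"expected one row for {ep_id!r}, found {len(matches)}")
    return matches[0]


rows = [
    {"episode_id": "ep_a", "success_rate": 0.75},
    {"episode_id": "ep_b", "success_rate": 1.0},
]
print(trains_arrived_lookup(rows, "ep_a")["success_rate"])  # 0.75
```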