Source code for flatland.evaluators.trajectory_analysis

from pathlib import Path
from typing import Tuple

import click
import pandas as pd
from pandas import DataFrame

from flatland.envs.rail_env import RailEnv
from flatland.trajectories.trajectories import Trajectory


[docs] def data_frame_for_trajectories(root_data_dir: Path) -> Tuple[DataFrame, DataFrame, DataFrame, DataFrame, DataFrame, DataFrame]: all_actions = [] all_trains_positions = [] all_trains_arrived = [] all_trains_rewards_dones_infos = [] env_stats = [] agent_stats = [] data_dirs = sorted([serialised_state.parent for serialised_state in (root_data_dir.resolve().glob("**/serialised_state"))]) print(data_dirs) for data_dir in data_dirs: snapshots = [snapshot for snapshot in (data_dir / "serialised_state").glob("*.pkl") if "step" not in snapshot.name] # must be data dir with single episode data assert len(snapshots) == 1, snapshots ep_id = snapshots[0].stem trajectory = Trajectory.load_existing(data_dir=data_dir, ep_id=ep_id) env: RailEnv = trajectory.load_env() all_actions.append(trajectory.actions) all_trains_positions.append(trajectory.trains_positions) all_trains_arrived.append(trajectory.trains_arrived) trajectory.trains_rewards_dones_infos["action_required"] = trajectory.trains_rewards_dones_infos["info"].map(lambda d: d["action_required"]) trajectory.trains_rewards_dones_infos["malfunction"] = trajectory.trains_rewards_dones_infos["info"].map(lambda d: d["malfunction"]) trajectory.trains_rewards_dones_infos["speed"] = trajectory.trains_rewards_dones_infos["info"].map(lambda d: d["speed"]) trajectory.trains_rewards_dones_infos["state"] = trajectory.trains_rewards_dones_infos["info"].map(lambda d: d["state"]) all_trains_rewards_dones_infos.append(trajectory.trains_rewards_dones_infos) env_stats.append(pd.DataFrame.from_records([{ "episode_id": ep_id, "max_episode_steps": env._max_episode_steps, "num_agents": len(env.agents), "x_dim": env.width, "y_dim": env.height, # TODO https://github.com/flatland-association/flatland-rl/issues/242 rail/line/timetable/malfunction generator not serializable currently. # max_rail_pairs_in_city=4, # grid_mode=False, # max_rails_between_cities=2, "malfunction_process_data": env.malfunction_process_data, # malfunction_duration_min=20, # malfunction_duration_max=50, # malfunction_interval=540, # speed_ratios=None, # line_length=2, # TODO https://github.com/flatland-association/flatland-rl/issues/7 standardization of obs builder interface and serialization "obs_builder": type(env.obs_builder), "acceleration_delta": env.acceleration_delta, "braking_delta": env.braking_delta, # TODO https://github.com/flatland-association/flatland-rl/issues/242 rewards not serializable currently. "rewards": type(env.rewards), }])) agent_stats.append(pd.DataFrame.from_records([{ "episode_id": ep_id, "agent_id": agent.handle, "earliest_departure": agent.earliest_departure, "latest_arrival": agent.latest_arrival, "num_waypoints": len(agent.waypoints), } for agent in env.agents])) all_actions = pd.concat(all_actions) all_trains_positions = pd.concat(all_trains_positions) all_trains_arrived = pd.concat(all_trains_arrived) all_trains_rewards_dones_infos = pd.concat(all_trains_rewards_dones_infos) env_stats = pd.concat(env_stats) agent_stats = pd.concat(agent_stats) print(all_trains_arrived) return all_actions, all_trains_positions, all_trains_arrived, all_trains_rewards_dones_infos, env_stats, agent_stats
[docs] def persist_data_frame_for_trajectories(agent_stats, all_actions, all_trains_arrived, all_trains_positions, all_trains_rewards_dones_infos, env_stats, output_dir): output_dir.mkdir(exist_ok=True, parents=True) assert len(list(output_dir.glob("*"))) == 0 all_actions.to_csv(output_dir / "all_actions.csv", index=False) all_trains_positions.to_csv(output_dir / "all_trains_positions.csv", index=False) all_trains_arrived.to_csv(output_dir / "all_trains_arrived.csv", index=False) all_trains_rewards_dones_infos.to_csv(output_dir / "all_trains_rewards_dones_infos.csv", index=False) env_stats.to_csv(output_dir / "env_stats.csv", index=False) agent_stats.to_csv(output_dir / "agent_stats.csv", index=False)
@click.command() @click.option( '--root-data-dir', type=click.Path(exists=True, path_type=Path), help="Path to existing trjajectories. Defaults to current directory.", default=Path("."), ) @click.option( '--output-dir', type=click.Path(file_okay=False, path_type=Path), help="Path store data frames to. Must be empty.", required=False, default=None ) def cli(root_data_dir: Path, output_dir: Path): all_actions, all_trains_positions, all_trains_arrived, all_trains_rewards_dones_infos, env_stats, agent_stats = data_frame_for_trajectories( root_data_dir=root_data_dir) if output_dir is not None: persist_data_frame_for_trajectories(agent_stats, all_actions, all_trains_arrived, all_trains_positions, all_trains_rewards_dones_infos, env_stats, output_dir)