Source code for flatland.trajectories.policy_grid_runner

from pathlib import Path

import pandas as pd

from flatland.trajectories.policy_runner import generate_trajectory_from_policy


[docs] def generate_trajectories_from_metadata( metadata_csv: Path, data_dir: Path, policy_pkg: str, policy_cls: str, obs_builder_pkg: str, obs_builder_cls: str): metadata = pd.read_csv(metadata_csv) for k, v in metadata.iterrows(): try: test_folder = data_dir / v["test_id"] / v["env_id"] test_folder.mkdir(parents=True, exist_ok=True) generate_trajectory_from_policy( ["--data-dir", test_folder, "--policy-pkg", policy_pkg, "--policy-cls", policy_cls, "--obs-builder-pkg", obs_builder_pkg, "--obs-builder-cls", obs_builder_cls, "--n_agents", v["n_agents"], "--x_dim", v["x_dim"], "--y_dim", v["y_dim"], "--n_cities", v["n_cities"], "--max_rail_pairs_in_city", v["max_rail_pairs_in_city"], "--grid_mode", v["grid_mode"], "--max_rails_between_cities", v["max_rails_between_cities"], "--malfunction_duration_min", v["malfunction_duration_min"], "--malfunction_duration_max", v["malfunction_duration_max"], "--malfunction_interval", v["malfunction_interval"], "--speed_ratios", "1.0", "0.25", "--speed_ratios", "0.5", "0.25", "--speed_ratios", "0.33", "0.25", "--speed_ratios", "0.25", "0.25", "--seed", v["seed"], "--snapshot-interval", 0, "--ep-id", v["test_id"] + "_" + v["env_id"] ]) except SystemExit as exc: assert exc.code == 0
if __name__ == '__main__': metadata_csv = Path("./episodes/trajectories/malfunction_deadlock_avoidance_heuristics/metadata.csv").resolve() data_dir = Path("./episodes/trajectories/malfunction_deadlock_avoidance_heuristics").resolve() generate_trajectories_from_metadata( metadata_csv=metadata_csv, data_dir=data_dir, policy_pkg="flatland_baselines.deadlock_avoidance_heuristic.policy.deadlock_avoidance_policy", policy_cls="DeadLockAvoidancePolicy", obs_builder_pkg="flatland_baselines.deadlock_avoidance_heuristic.observation.full_env_observation", obs_builder_cls="FullEnvObservation" )