import importlib
from pathlib import Path
from typing import Optional
import click
import tqdm
from flatland.callbacks.callbacks import FlatlandCallbacks, make_multi_callbacks
from flatland.core.env_observation_builder import ObservationBuilder
from flatland.core.policy import Policy
from flatland.env_generation.env_generator import env_generator
from flatland.envs.persistence import RailEnvPersister
from flatland.envs.rail_env import RailEnv
from flatland.evaluators.trajectory_evaluator import TrajectoryEvaluator
from flatland.trajectories.trajectories import Trajectory, SERIALISED_STATE_SUBDIR
class PolicyRunner:
    """Generates :class:`Trajectory` data by running a policy (and obs builder) inside a ``RailEnv``."""

    @staticmethod
    def create_from_policy(
            policy: Policy,
            data_dir: Path,
            env: RailEnv = None,
            n_agents=7,
            x_dim=30,
            y_dim=30,
            n_cities=2,
            max_rail_pairs_in_city=4,
            grid_mode=False,
            max_rails_between_cities=2,
            malfunction_duration_min=20,
            malfunction_duration_max=50,
            malfunction_interval=540,
            speed_ratios=None,
            seed=42,
            obs_builder: Optional[ObservationBuilder] = None,
            snapshot_interval: int = 1,
            ep_id: str = None,
            callbacks: FlatlandCallbacks = None,
            tqdm_kwargs: dict = None,
            start_step: int = 0,
            end_step: int = None,
            fork_from_trajectory: "Trajectory" = None,
    ) -> "Trajectory":
        """
        Creates trajectory by running submission (policy and obs builder).

        Always backs up the actions and positions for steps executed in the tsvs.
        Can start from existing trajectory.

        Parameters
        ----------
        policy : Policy
            the submission's policy
        data_dir : Path
            the path to write the trajectory to
        env: RailEnv
            directly inject env, skip env generation
        n_agents: int
            number of agents
        x_dim: int
            number of columns
        y_dim: int
            number of rows
        n_cities: int
            Max number of cities to build. The generator tries to achieve this number given all the parameters. Goes into `sparse_rail_generator`.
        max_rail_pairs_in_city: int
            Number of parallel tracks in the city. This represents the number of tracks in the train stations. Goes into `sparse_rail_generator`.
        grid_mode: bool
            How to distribute the cities in the path, either equally in a grid or random. Goes into `sparse_rail_generator`.
        max_rails_between_cities: int
            Max number of rails connecting to a city. This is only the number of connection points at city border.
        malfunction_duration_min: int
            Minimal duration of malfunction. Goes into `ParamMalfunctionGen`.
        malfunction_duration_max: int
            Max duration of malfunction. Goes into `ParamMalfunctionGen`.
        malfunction_interval: int
            Inverse of rate of malfunction occurrence. Goes into `ParamMalfunctionGen`.
        speed_ratios: Dict[float, float]
            Speed ratios of all agents. They are probabilities of all different speeds and have to add up to 1. Goes into `sparse_line_generator`. Defaults to `{1.0: 0.25, 0.5: 0.25, 0.33: 0.25, 0.25: 0.25}`.
        seed: int
            Initiate random seed generators. Goes into `reset`.
        obs_builder: Optional[ObservationBuilder]
            Defaults to `TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv(max_depth=50))`
        snapshot_interval : int
            interval to write pkl snapshots
        ep_id: str
            episode ID to store data under. If not provided, generate one.
        callbacks: FlatlandCallbacks
            callbacks to run during trajectory creation
        tqdm_kwargs: dict
            additional kwargs for tqdm
        start_step : int
            start evaluation from intermediate step incl. (requires snapshot to be present); take actions from start_step and first step executed is start_step + 1. Defaults to 0 with first elapsed step 1.
        end_step : int
            stop evaluation at intermediate step excl. Capped by env's max_episode_steps
        fork_from_trajectory : Trajectory
            copy data from this trajectory up to start step and run policy from there on

        Returns
        -------
        Trajectory
        """
        if ep_id is not None:
            trajectory = Trajectory(data_dir=data_dir, ep_id=ep_id)
        else:
            trajectory = Trajectory(data_dir=data_dir)
        trajectory.load()
        # ensure to start with new empty df to avoid inconsistencies:
        assert len(trajectory.trains_positions) == 0
        assert len(trajectory.actions) == 0
        assert len(trajectory.trains_arrived) == 0
        assert len(trajectory.trains_rewards_dones_infos) == 0
        if fork_from_trajectory is not None:
            # try to restore the forked env directly at start_step (requires a snapshot)
            env = fork_from_trajectory.restore_episode(start_step=start_step)
            fork_from_trajectory.load(episode_only=True)
            # will run action start_step into step start_step+1;
            # actions are taken strictly before start_step, positions/arrivals/rewards up to and including it
            trajectory.actions = fork_from_trajectory.actions[fork_from_trajectory.actions["env_time"] < start_step]
            trajectory.trains_positions = fork_from_trajectory.trains_positions[fork_from_trajectory.trains_positions["env_time"] <= start_step]
            trajectory.trains_arrived = fork_from_trajectory.trains_arrived[fork_from_trajectory.trains_arrived["env_time"] <= start_step]
            trajectory.trains_rewards_dones_infos = fork_from_trajectory.trains_rewards_dones_infos[
                fork_from_trajectory.trains_rewards_dones_infos["env_time"] <= start_step]
            # re-label the copied data with the new trajectory's episode ID
            trajectory.actions["episode_id"] = trajectory.ep_id
            trajectory.trains_positions["episode_id"] = trajectory.ep_id
            trajectory.trains_arrived["episode_id"] = trajectory.ep_id
            trajectory.trains_rewards_dones_infos["episode_id"] = trajectory.ep_id
            trajectory.persist()
            if env is None:
                # no snapshot at start_step: restore from the episode start and replay up to start_step
                env = fork_from_trajectory.restore_episode()
                # exist_ok=True for consistency with the later mkdir below (avoids FileExistsError on re-run)
                (trajectory.data_dir / SERIALISED_STATE_SUBDIR).mkdir(parents=True, exist_ok=True)
                RailEnvPersister.save(env, trajectory.data_dir / SERIALISED_STATE_SUBDIR / f"{trajectory.ep_id}.pkl")
                env = TrajectoryEvaluator(trajectory=trajectory, callbacks=callbacks).evaluate(end_step=start_step)
                trajectory.load()
            # TODO bad code smell - private method - check num resets?
            observations = env._get_observations()
        elif env is not None:
            # TODO bad code smell - private method - check num resets?
            observations = env._get_observations()
        else:
            env, observations, _ = env_generator(
                n_agents=n_agents,
                x_dim=x_dim,
                y_dim=y_dim,
                n_cities=n_cities,
                max_rail_pairs_in_city=max_rail_pairs_in_city,
                grid_mode=grid_mode,
                max_rails_between_cities=max_rails_between_cities,
                malfunction_duration_min=malfunction_duration_min,
                malfunction_duration_max=malfunction_duration_max,
                malfunction_interval=malfunction_interval,
                speed_ratios=speed_ratios,
                seed=seed,
                obs_builder_object=obs_builder)
        assert start_step == env._elapsed_steps, f"Expected env at {start_step}, found {env._elapsed_steps}."
        if tqdm_kwargs is None:
            tqdm_kwargs = {}
        (data_dir / SERIALISED_STATE_SUBDIR).mkdir(parents=True, exist_ok=True)
        RailEnvPersister.save(env, str(data_dir / SERIALISED_STATE_SUBDIR / f"{trajectory.ep_id}.pkl"))
        if snapshot_interval > 0:
            # local import to avoid a circular import at module load time
            from flatland.trajectories.trajectory_snapshot_callbacks import TrajectorySnapshotCallbacks
            if callbacks is None:
                callbacks = TrajectorySnapshotCallbacks(trajectory, snapshot_interval=snapshot_interval, data_dir_override=data_dir)
            else:
                callbacks = make_multi_callbacks(callbacks,
                                                 TrajectorySnapshotCallbacks(trajectory, snapshot_interval=snapshot_interval, data_dir_override=data_dir))
        trajectory.outputs_dir.mkdir(exist_ok=True)
        n_agents = env.get_num_agents()
        assert len(env.agents) == n_agents
        # pre-initialize loop variables so an empty range (start_step == end_step) is handled gracefully
        env_time = start_step
        done = False
        if end_step is None:
            end_step = env._max_episode_steps
        env_time_range = range(start_step, end_step)
        if callbacks is not None and start_step == 0:
            callbacks.on_episode_start(env=env, data_dir=trajectory.outputs_dir)
        for env_time in tqdm.tqdm(env_time_range, **tqdm_kwargs):
            assert env_time == env._elapsed_steps
            action_dict = policy.act_many(env.get_agent_handles(), observations)
            for handle, action in action_dict.items():
                trajectory.action_collect(env_time=env_time, agent_id=handle, action=action)
            observations, rewards, dones, infos = env.step(action_dict)
            # the action taken at env_time produces the state at env_time + 1
            for agent_id in range(n_agents):
                agent = env.agents[agent_id]
                actual_position = (agent.position, agent.direction)
                trajectory.position_collect(env_time=env_time + 1, agent_id=agent_id, position=actual_position)
                trajectory.rewards_dones_infos_collect(env_time=env_time + 1, agent_id=agent_id, reward=rewards.get(agent_id, 0.0),
                                                       info={k: v[agent_id] for k, v in infos.items()},
                                                       done=dones[agent_id])
            done = dones['__all__']
            if callbacks is not None:
                callbacks.on_episode_step(env=env, data_dir=trajectory.outputs_dir)
            if done:
                if callbacks is not None:
                    callbacks.on_episode_end(env=env, data_dir=trajectory.outputs_dir)
                break
        # state 6 is TrainState.DONE, i.e. the agent reached its target
        actual_success_rate = sum(agent.state == 6 for agent in env.agents) / n_agents
        if done:
            trajectory.arrived_collect(env_time, actual_success_rate)
        trajectory.persist()
        return trajectory
@click.command()
@click.option('--data-dir',
              type=click.Path(exists=True, path_type=Path),
              help="Path to folder containing Flatland episode",
              required=True
              )
@click.option('--policy-pkg',
              type=str,
              help="Policy's fully qualified package name.",
              required=True
              )
@click.option('--policy-cls',
              type=str,
              help="Policy class name.",
              required=True
              )
@click.option('--obs-builder-pkg',
              type=str,
              help="Observation builder's fully qualified package name. Defaults to `TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv(max_depth=50))`",
              required=False,
              default=None
              )
@click.option('--obs-builder-cls',
              type=str,
              help="Observation builder class name. Defaults to `TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv(max_depth=50))`",
              required=False,
              default=None
              )
@click.option('--n_agents',
              type=int,
              help="Number of agents.",
              required=False,
              default=7)
@click.option('--x_dim',
              type=int,
              help="Number of columns.",
              required=False,
              default=30)
@click.option('--y_dim',
              type=int,
              help="Number of rows.",
              required=False,
              default=30)
@click.option('--n_cities',
              type=int,
              help="Max number of cities to build. The generator tries to achieve this number given all the parameters. Goes into `sparse_rail_generator`.",
              required=False,
              default=2)
@click.option('--max_rail_pairs_in_city',
              type=int,
              help="Number of parallel tracks in the city. This represents the number of tracks in the train stations. Goes into `sparse_rail_generator`.",
              required=False,
              default=4)
@click.option('--grid_mode',
              type=bool,
              help="How to distribute the cities in the path, either equally in a grid or random. Goes into `sparse_rail_generator`.",
              required=False,
              default=False)
@click.option('--max_rails_between_cities',
              type=int,
              help="Max number of rails connecting to a city. This is only the number of connection points at city border.",
              required=False,
              default=2)
@click.option('--malfunction_duration_min',
              type=int,
              help="Minimal duration of malfunction. Goes into `ParamMalfunctionGen`.",
              required=False,
              default=20)
@click.option('--malfunction_duration_max',
              type=int,
              help="Max duration of malfunction. Goes into `ParamMalfunctionGen`.",
              required=False,
              default=50)
@click.option('--malfunction_interval',
              type=int,
              help="Inverse of rate of malfunction occurrence. Goes into `ParamMalfunctionGen`.",
              required=False,
              default=540)
@click.option('--speed_ratios',
              multiple=True,
              nargs=2,
              type=click.Tuple(types=[float, float]),
              help="Speed ratios of all agents. They are probabilities of all different speeds and have to add up to 1. Goes into `sparse_line_generator`. Defaults to `{1.0: 0.25, 0.5: 0.25, 0.33: 0.25, 0.25: 0.25}`.",
              required=False,
              default=None)
@click.option('--seed',
              type=int,
              help="Initiate random seed generators. Goes into `reset`.",
              required=False, default=42)
@click.option('--snapshot-interval',
              type=int,
              help="Interval to write snapshots. Use 0 to switch off, 1 for every step, ....",
              required=False,
              default=1)
@click.option('--ep-id',
              type=str,
              help="Set the episode ID used - if not set, a UUID will be sampled.",
              required=False)
@click.option('--env-path',
              type=click.Path(exists=True, path_type=Path),
              help="Path to existing RailEnv to start trajectory from",
              required=False
              )
@click.option('--start-step',
              type=int,
              help="Start evaluation from this intermediate step (inclusive; requires snapshot to be present). Defaults to 0.",
              required=False, default=0
              )
@click.option('--end-step',
              type=int,
              help="Stop evaluation at this intermediate step (exclusive). Capped by the env's max episode steps.",
              required=False, default=None
              )
@click.option('--fork-data-dir',
              type=click.Path(exists=True, path_type=Path),
              help="Path to folder containing the trajectory to fork from.",
              required=False, default=None
              )
@click.option('--fork-ep-id',
              # NOTE: episode IDs are strings (UUIDs) - `type=int` would reject every valid ID
              type=str,
              help="Episode ID of the trajectory to fork from.",
              required=False, default=None
              )
def generate_trajectory_from_policy(
        data_dir: Path,
        policy_pkg: str, policy_cls: str,
        obs_builder_pkg: str, obs_builder_cls: str,
        n_agents=7,
        x_dim=30,
        y_dim=30,
        n_cities=2,
        max_rail_pairs_in_city=4,
        grid_mode=False,
        max_rails_between_cities=2,
        malfunction_duration_min=20,
        malfunction_duration_max=50,
        malfunction_interval=540,
        speed_ratios=None,
        seed: int = 42,
        snapshot_interval: int = 1,
        ep_id: str = None,
        env_path: Path = None,
        start_step: int = 0,
        end_step: int = None,
        fork_data_dir: Path = None,
        fork_ep_id: str = None,
):
    """CLI entry point: load policy (and optional obs builder) by name and generate a trajectory."""
    # resolve the policy class from its fully qualified name
    module = importlib.import_module(policy_pkg)
    policy_cls = getattr(module, policy_cls)
    # resolve the optional observation builder; both pkg and cls must be given
    obs_builder = None
    if obs_builder_pkg is not None and obs_builder_cls is not None:
        module = importlib.import_module(obs_builder_pkg)
        obs_builder_cls = getattr(module, obs_builder_cls)
        obs_builder = obs_builder_cls()
    env = None
    if env_path is not None:
        env, _ = RailEnvPersister.load_new(str(env_path))
    fork_from_trajectory = None
    if fork_data_dir is not None and fork_ep_id is not None:
        fork_from_trajectory = Trajectory(data_dir=fork_data_dir, ep_id=fork_ep_id)
    PolicyRunner.create_from_policy(
        policy=policy_cls(),
        data_dir=data_dir,
        n_agents=n_agents,
        x_dim=x_dim,
        y_dim=y_dim,
        n_cities=n_cities,
        max_rail_pairs_in_city=max_rail_pairs_in_city,
        grid_mode=grid_mode,
        max_rails_between_cities=max_rails_between_cities,
        malfunction_duration_min=malfunction_duration_min,
        malfunction_duration_max=malfunction_duration_max,
        malfunction_interval=malfunction_interval,
        # click's multiple=True yields a tuple of (speed, probability) pairs; () means "use default"
        speed_ratios=dict(speed_ratios) if speed_ratios else None,
        seed=seed,
        obs_builder=obs_builder,
        snapshot_interval=snapshot_interval,
        ep_id=ep_id,
        env=env,
        start_step=start_step,
        end_step=end_step,
        fork_from_trajectory=fork_from_trajectory,
    )