Source code for flatland.trajectories.trajectory_snapshot_callbacks
from pathlib import Path
from typing import Optional
from flatland.callbacks.callbacks import FlatlandCallbacks
from flatland.envs.persistence import RailEnvPersister
from flatland.envs.rail_env import RailEnv
from flatland.trajectories.trajectories import SERIALISED_STATE_SUBDIR
from flatland.trajectories.trajectories import Trajectory
[docs]
class TrajectorySnapshotCallbacks(FlatlandCallbacks):
"""
FlatlandCallbacks to write env snapshots at the configured interval.
Parameters
----------
trajectory: Trajectory
the trajectory
data_dir_override : Path
use this override instead of the `data_dir` passed in the callback.
snapshot_interval : int
interval to write pkl snapshots to serialised_state subdirectory of the data_dir or data_dir_override. 1 means at every step. 0 means never.
"""
def __init__(self, trajectory: Trajectory, data_dir_override: Path = None, snapshot_interval: int = None):
self.trajectory = trajectory
self.snapshot_interval = snapshot_interval
self.data_dir_override = data_dir_override
[docs]
def on_episode_start(
self,
*,
env: Optional[RailEnv] = None,
data_dir: Path = None,
**kwargs,
) -> None:
if self.data_dir_override is not None:
data_dir = self.data_dir_override
if self.snapshot_interval > 0:
(data_dir / SERIALISED_STATE_SUBDIR).mkdir(exist_ok=True)
RailEnvPersister.save(env, str(data_dir / SERIALISED_STATE_SUBDIR / f"{self.trajectory.ep_id}_step{env._elapsed_steps:04d}.pkl"))
[docs]
def on_episode_step(
self,
*,
env: Optional[RailEnv] = None,
data_dir: Path = None,
**kwargs,
) -> None:
if self.data_dir_override is not None:
data_dir = self.data_dir_override
elapsed_steps = env._elapsed_steps
if self.snapshot_interval > 0 and elapsed_steps % self.snapshot_interval == 0:
(data_dir / SERIALISED_STATE_SUBDIR).mkdir(exist_ok=True)
RailEnvPersister.save(env, str(data_dir / SERIALISED_STATE_SUBDIR / f"{self.trajectory.ep_id}_step{elapsed_steps :04d}.pkl"))