# Source code for flatland.trajectories.trajectory_observation_callbacks
import pickle
from pathlib import Path
from typing import Optional, Literal
import msgpack
from flatland.callbacks.callbacks import FlatlandCallbacks
from flatland.envs.rail_env import RailEnv
from flatland.trajectories.trajectories import SERIALISED_STATE_SUBDIR
from flatland.trajectories.trajectories import Trajectory
class TrajectoryObservationCallbacks(FlatlandCallbacks):
    """
    FlatlandCallbacks to write observations.

    On episode start and on every episode step, the environment's current
    observations (``env._get_observations()``) are serialised and written to
    ``<data_dir>/<SERIALISED_STATE_SUBDIR>/<ep_id>_obs<step>.<format>``,
    where ``<ep_id>`` is ``trajectory.ep_id`` and ``<step>`` is
    ``env._elapsed_steps`` zero-padded to at least four digits.

    Parameters
    ----------
    trajectory: Trajectory
        the trajectory; its ``ep_id`` prefixes the output file names.
    data_dir_override : Optional[Path]
        use this override instead of the `data_dir` passed in the callback.
    format : Literal["pkl", "mpk"]
        serialisation format: ``"pkl"`` for pickle (default) or ``"mpk"``
        for msgpack. Also used as the output file extension.

    Raises
    ------
    ValueError
        if ``format`` is not ``"pkl"`` or ``"mpk"``.
    """

    # Supported serialisation formats (also the output file extensions).
    _FORMATS = ("pkl", "mpk")

    def __init__(
        self,
        trajectory: Trajectory,
        data_dir_override: Optional[Path] = None,
        # NOTE: shadows the builtin `format`; name kept for API compatibility.
        format: Literal["pkl", "mpk"] = "pkl",
    ):
        # Fail fast on a misconfigured format rather than at the first callback.
        if format not in self._FORMATS:
            raise ValueError("Format must be \"mpk\" (msgpack) or \"pkl\" for pickle")
        self.trajectory = trajectory
        self.data_dir_override = data_dir_override
        self.format = format

    def _resolve_data_dir(self, data_dir: Optional[Path]) -> Path:
        """
        Return the output base directory.

        ``data_dir_override`` takes precedence over the ``data_dir`` passed to
        the callback. ``str`` paths are accepted and converted to ``Path``.

        Raises
        ------
        ValueError
            if neither an override nor a ``data_dir`` is available.
        """
        if self.data_dir_override is not None:
            data_dir = self.data_dir_override
        if data_dir is None:
            raise ValueError("No data_dir passed to the callback and no data_dir_override set.")
        return Path(data_dir)

    def _dump(self, data_dir: Path, env: RailEnv):
        """
        Serialise ``env``'s current observations and write them below ``data_dir``.

        The file is written to
        ``data_dir / SERIALISED_STATE_SUBDIR / f"{ep_id}_obs{step:04d}.{format}"``;
        the subdirectory is created if it does not exist.

        Raises
        ------
        ValueError
            if ``env`` is None or ``self.format`` is unsupported.
        """
        if env is None:
            raise ValueError("An env is required to dump observations.")
        if self.format == "pkl":
            data = pickle.dumps(env._get_observations())
        elif self.format == "mpk":
            # NOTE(review): msgpack.packb without a `default` hook only encodes
            # msgpack-native types; confirm the observation type is compatible.
            data = msgpack.packb(env._get_observations())
        else:
            # Only reachable if `self.format` was reassigned after construction.
            raise ValueError("Format must be \"mpk\" (msgpack) or \"pkl\" for pickle")
        out_dir = data_dir / SERIALISED_STATE_SUBDIR
        out_dir.mkdir(exist_ok=True, parents=True)
        out_file = out_dir / f"{self.trajectory.ep_id}_obs{env._elapsed_steps:04d}.{self.format}"
        out_file.write_bytes(data)

    def on_episode_start(
        self,
        *,
        env: Optional[RailEnv] = None,
        data_dir: Optional[Path] = None,
        **kwargs,
    ) -> None:
        """Write the observations available at episode start."""
        self._dump(self._resolve_data_dir(data_dir), env)

    def on_episode_step(
        self,
        *,
        env: Optional[RailEnv] = None,
        data_dir: Optional[Path] = None,
        **kwargs,
    ) -> None:
        """Write the observations after an episode step."""
        self._dump(self._resolve_data_dir(data_dir), env)