Policy Evaluation and Trajectories#
We use the terms from arc42 for the different views.
Building Block View Trajectory Generation and Evaluation#
This is a conceptual view and reflects the target state of the implementation. In the Flatland implementation, we do not yet distinguish between configuration and state; they go together in RailEnvPersister, and a trajectory currently consists of full snapshots.
classDiagram
    class Runner {
        +generate_trajectory_from_policy(Policy policy, ObservationBuilder obs_builder, int snapshot_interval)$ Trajectory
    }
    class Evaluator {
        +evaluate(Trajectory trajectory)
    }
    class Trajectory {
        +Path data_dir
        +UUID ep_id
        +run(int from_step, int to_step=-1)
    }
    class EnvSnapshot {
        +Path data_dir
        +UUID ep_id
    }
    class EnvConfiguration {
        +int max_episode_steps
        +int height
        +int width
        +Rewards reward_function
        +MalGen
        +RailGen etc. reset
    }
    class EnvState {
        +Grid rail
    }
    class EnvActions
    class EnvRewards
    EnvSnapshot --> "1" EnvConfiguration
    EnvSnapshot --> "1" EnvState
    Trajectory --> "1" EnvConfiguration
    Trajectory --> "1..*" EnvState
    Trajectory --> "1..*" EnvActions
    Trajectory --> "1..*" EnvRewards
    class Policy {
        +act(int handle, Observation observation)
    }
    class ObservationBuilder {
        +get()
        +get_many()
    }
    class Submission
    Submission --> "1" Policy
    Submission --> ObservationBuilder
Remarks:
- A trajectory need not start at step 0.
- A trajectory need not contain a state for every step; however, when starting a trajectory from an intermediate step, the snapshot for that step must exist (see the snapshot sketch below).
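For illustration, here is a minimal sketch of how a full snapshot is written and read back via RailEnvPersister (the file name is hypothetical):

from flatland.env_generation.env_generator import env_generator
from flatland.envs.persistence import RailEnvPersister

env, _, _ = env_generator(seed=44)
# configuration and state are persisted together as one full snapshot
RailEnvPersister.save(env, "episode_step0000.pkl")
# load_new returns the restored env together with the raw dict it was deserialised from
env, env_dict = RailEnvPersister.load_new("episode_step0000.pkl")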
Runtime View Trajectory Generation#
flowchart TD
subgraph PolicyRunner.create_from_policy
start((" ")) -->|data_dir| D0
D0(RailEnvPersister.load_new) -->|env| E{env done?}
E -->|no:<br/>observations| G{Agent loop:<br/> more agents?}
G --->|observation| G1(policy.act)
G1 -->|action| G
G -->|no:<br/> actions| F3(env.step)
F3 -->|observations,rewards,info| E
E -->|yes:<br/> rewards| H(((" ")))
end
style Policy fill: #ffe, stroke: #333, stroke-width: 1px, color: black
style G1 fill: #ffe, stroke: #333, stroke-width: 1px, color: black
style Env fill: #fcc, stroke: #333, stroke-width: 1px, color: black
style F3 fill: #fcc, stroke: #333, stroke-width: 1px, color: black
subgraph legend
Env(Environment)
Policy(Policy)
Trajectory(Trajectory)
end
PolicyRunner.create_from_policy~~~legend
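In code, the flow above corresponds roughly to the following loop (a simplified sketch; env loading, logging and snapshot writing are omitted, and policy.act follows the signature from the class diagram above):

# simplified sketch of the generation loop
observations, info = env.reset()
done = False
while not done:
    # agent loop: query one action per agent from the policy
    actions = {handle: policy.act(handle, observations[handle]) for handle in env.get_agent_handles()}
    # advance the env by one step with the collected actions
    observations, rewards, dones, info = env.step(actions)
    done = dones["__all__"]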
Trajectory Generation and Evaluation#
Create a trajectory from a random policy and inspect the output:
import tempfile
from pathlib import Path
from typing import Any, Optional
from flatland.env_generation.env_generator import env_generator
from flatland.core.policy import Policy
from flatland.trajectories.policy_runner import PolicyRunner
from flatland.utils.seeding import np_random, random_state_to_hashablestate
from flatland.evaluators.trajectory_evaluator import TrajectoryEvaluator, evaluate_trajectory
from flatland.trajectories.trajectories import Trajectory
class RandomPolicy(Policy):
def __init__(self, action_size: int = 5, seed=42):
super(RandomPolicy, self).__init__()
self.action_size = action_size
self.np_random, _ = np_random(seed=seed)
def act(self, observation: Any, **kwargs):
return self.np_random.choice(self.action_size)
with tempfile.TemporaryDirectory() as tmpdirname:
data_dir = Path(tmpdirname)
env, _, _ = env_generator(seed=44)
trajectory = PolicyRunner.create_from_policy(policy=RandomPolicy(), env=env, data_dir=data_dir, snapshot_interval=15, tqdm_kwargs={"disable": True})
    # np_random in the loaded episode is the same as if it came directly from env_generator (incl. reset())!
    env = trajectory.load_env()
    # we need to seed explicitly to have the same env re-generated
    gen, _, _ = env_generator(seed=44)
    assert random_state_to_hashablestate(env.np_random) == random_state_to_hashablestate(gen.np_random)
# inspect output
for p in sorted(data_dir.rglob("**/*")):
print(p)
# inspect the actions taken by the policy
print(trajectory._read_actions())
    # verify steps 15 to 25 - we can start at 15 as there is a snapshot for step 15
    TrajectoryEvaluator(trajectory).evaluate(start_step=15, end_step=25, tqdm_kwargs={"disable": True})
/tmp/tmpz2t1gs0y/event_logs
/tmp/tmpz2t1gs0y/event_logs/ActionEvents.discrete_action.tsv
/tmp/tmpz2t1gs0y/event_logs/TrainMovementEvents.trains_arrived.tsv
/tmp/tmpz2t1gs0y/event_logs/TrainMovementEvents.trains_positions.tsv
/tmp/tmpz2t1gs0y/event_logs/TrainMovementEvents.trains_rewards_dones_infos.tsv
/tmp/tmpz2t1gs0y/outputs
/tmp/tmpz2t1gs0y/serialised_state
/tmp/tmpz2t1gs0y/serialised_state/42c44911-bce9-47a2-92ad-0ec6db0d2e43.pkl
/tmp/tmpz2t1gs0y/serialised_state/42c44911-bce9-47a2-92ad-0ec6db0d2e43_step0000.pkl
/tmp/tmpz2t1gs0y/serialised_state/42c44911-bce9-47a2-92ad-0ec6db0d2e43_step0015.pkl
/tmp/tmpz2t1gs0y/serialised_state/42c44911-bce9-47a2-92ad-0ec6db0d2e43_step0030.pkl
/tmp/tmpz2t1gs0y/serialised_state/42c44911-bce9-47a2-92ad-0ec6db0d2e43_step0045.pkl
/tmp/tmpz2t1gs0y/serialised_state/42c44911-bce9-47a2-92ad-0ec6db0d2e43_step0060.pkl
/tmp/tmpz2t1gs0y/serialised_state/42c44911-bce9-47a2-92ad-0ec6db0d2e43_step0075.pkl
/tmp/tmpz2t1gs0y/serialised_state/42c44911-bce9-47a2-92ad-0ec6db0d2e43_step0090.pkl
/tmp/tmpz2t1gs0y/serialised_state/42c44911-bce9-47a2-92ad-0ec6db0d2e43_step0105.pkl
/tmp/tmpz2t1gs0y/serialised_state/42c44911-bce9-47a2-92ad-0ec6db0d2e43_step0120.pkl
/tmp/tmpz2t1gs0y/serialised_state/42c44911-bce9-47a2-92ad-0ec6db0d2e43_step0135.pkl
/tmp/tmpz2t1gs0y/serialised_state/42c44911-bce9-47a2-92ad-0ec6db0d2e43_step0150.pkl
/tmp/tmpz2t1gs0y/serialised_state/42c44911-bce9-47a2-92ad-0ec6db0d2e43_step0165.pkl
/tmp/tmpz2t1gs0y/serialised_state/42c44911-bce9-47a2-92ad-0ec6db0d2e43_step0180.pkl
/tmp/tmpz2t1gs0y/serialised_state/42c44911-bce9-47a2-92ad-0ec6db0d2e43_step0195.pkl
/tmp/tmpz2t1gs0y/serialised_state/42c44911-bce9-47a2-92ad-0ec6db0d2e43_step0210.pkl
/tmp/tmpz2t1gs0y/serialised_state/42c44911-bce9-47a2-92ad-0ec6db0d2e43_step0225.pkl
/tmp/tmpz2t1gs0y/serialised_state/42c44911-bce9-47a2-92ad-0ec6db0d2e43_step0240.pkl
/tmp/tmpz2t1gs0y/serialised_state/42c44911-bce9-47a2-92ad-0ec6db0d2e43_step0255.pkl
/tmp/tmpz2t1gs0y/serialised_state/42c44911-bce9-47a2-92ad-0ec6db0d2e43_step0270.pkl
                                episode_id  env_time  agent_id                       action
0     42c44911-bce9-47a2-92ad-0ec6db0d2e43         0         0    RailEnvActions.MOVE_RIGHT
1     42c44911-bce9-47a2-92ad-0ec6db0d2e43         0         1     RailEnvActions.MOVE_LEFT
2     42c44911-bce9-47a2-92ad-0ec6db0d2e43         0         2    RailEnvActions.MOVE_RIGHT
3     42c44911-bce9-47a2-92ad-0ec6db0d2e43         0         3  RailEnvActions.MOVE_FORWARD
4     42c44911-bce9-47a2-92ad-0ec6db0d2e43         0         4  RailEnvActions.MOVE_FORWARD
...                                    ...       ...       ...                          ...
1983  42c44911-bce9-47a2-92ad-0ec6db0d2e43       283         2     RailEnvActions.MOVE_LEFT
1984  42c44911-bce9-47a2-92ad-0ec6db0d2e43       283         3    RailEnvActions.DO_NOTHING
1985  42c44911-bce9-47a2-92ad-0ec6db0d2e43       283         4   RailEnvActions.STOP_MOVING
1986  42c44911-bce9-47a2-92ad-0ec6db0d2e43       283         5     RailEnvActions.MOVE_LEFT
1987  42c44911-bce9-47a2-92ad-0ec6db0d2e43       283         6  RailEnvActions.MOVE_FORWARD

[1988 rows x 4 columns]
List of Policy implementations#
- tests.trajectories.test_trajectories.RandomPolicy
- flatland.envs.rail_env_policies.ShortestPathPolicy
- flatland_baselines.deadlock_avoidance_heuristic.policy.deadlock_avoidance_policy.DeadLockAvoidancePolicy (see flatland-baselines)
- flatland/ml/ray/wrappers.ray_policy_wrapper and flatland/ml/ray/wrappers.ray_checkpoint_policy_wrapper for wrapping RLlib RLModules
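Any of these can be plugged into PolicyRunner.create_from_policy in place of the RandomPolicy above, for example (a sketch reusing the setup from the example above and assuming a no-argument constructor):

from flatland.envs.rail_env_policies import ShortestPathPolicy

# same call as in the example above, with the shortest-path policy swapped in
trajectory = PolicyRunner.create_from_policy(policy=ShortestPathPolicy(), env=env_generator(seed=44)[0], data_dir=data_dir, snapshot_interval=15)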
Flatland Callbacks#
Flatland callbacks can be used for custom metrics and custom postprocessing.
import inspect
from flatland.envs.rail_env import RailEnv
from flatland.callbacks.callbacks import FlatlandCallbacks, make_multi_callbacks
lines, _ = inspect.getsourcelines(FlatlandCallbacks)
print("".join(lines))
class FlatlandCallbacks(Generic[EnvType]):
"""
Abstract base class for Flatland callbacks similar to rllib, see https://github.com/ray-project/ray/blob/master/rllib/callbacks/callbacks.py.
These callbacks can be used for custom metrics and custom postprocessing.
By default, all of these callbacks are no-ops.
"""
def on_episode_start(
self,
*,
env: Optional[EnvType] = None,
data_dir: Path = None,
**kwargs,
) -> None:
"""Callback run right after an Episode has been started.
This method gets called after `env.reset()`.
Parameters
---------
env : Environment
the env
data_dir : Path
trajectory data dir
kwargs:
Forward compatibility placeholder.
"""
pass
def on_episode_step(
self,
*,
env: Optional[EnvType] = None,
data_dir: Path = None,
**kwargs,
) -> None:
"""Called on each episode step (after the action(s) has/have been logged).
This callback is also called after the final step of an episode,
meaning when terminated/truncated are returned as True
from the `env.step()` call.
The exact time of the call of this callback is after `env.step([action])` and
also after the results of this step (observation, reward, terminated, truncated,
infos) have been logged to the given `episode` object.
Parameters
---------
env : Environment
the env
data_dir : Path
trajectory data dir
kwargs:
Forward compatibility placeholder.
"""
pass
def on_episode_end(
self,
*,
env: Optional[EnvType] = None,
data_dir: Path = None,
**kwargs,
) -> None:
"""Called when an episode is done (after terminated/truncated have been logged).
The exact time of the call of this callback is after `env.step([action])`
Parameters
---------
env : Environment
the env
data_dir : Path
trajectory data dir
kwargs:
Forward compatibility placeholder.
"""
pass
class DummyCallbacks(FlatlandCallbacks):
def on_episode_step(
self,
*,
env: Optional[RailEnv] = None,
**kwargs,
) -> None:
if (env._elapsed_steps - 1) % 10 == 0:
print(f"step{env._elapsed_steps - 1}")
with tempfile.TemporaryDirectory() as tmpdirname:
data_dir = Path(tmpdirname)
trajectory = PolicyRunner.create_from_policy(policy=RandomPolicy(), env=env_generator()[0], data_dir=data_dir, snapshot_interval=15, tqdm_kwargs={"disable": True})
TrajectoryEvaluator(trajectory, callbacks=make_multi_callbacks(DummyCallbacks())).evaluate(tqdm_kwargs={"disable": True})
step0
step10
step20
step30
step40
step50
step60
step70
step80
step90
step100
step110
step120
step130
step140
step150
step160
step170
step180
step190
step200
step210
step220
step230
step240
step250
step260
step270
0.0% trains arrived. Expected 0.0%. 273 elapsed steps.
List of FlatlandCallbacks#
- flatland.callbacks.generate_movie_callbacks.GenerateMovieCallbacks
- flatland.integrations.interactiveai.interactiveai.FlatlandInteractiveAICallbacks
- flatland.trajectories.trajectory_snapshot_callbacks.TrajectorySnapshotCallbacks
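For instance, movie generation can be hooked into an evaluation run in the same way as the DummyCallbacks above (a sketch; constructor arguments may vary):

from flatland.callbacks.generate_movie_callbacks import GenerateMovieCallbacks

# render the evaluated trajectory as a movie alongside the evaluation
TrajectoryEvaluator(trajectory, callbacks=make_multi_callbacks(GenerateMovieCallbacks())).evaluate()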