Evaluation#

We use the terms from arc42 for the different views.

Building Block View: Trajectory Generation and Evaluation#

This is a conceptual view and reflects the target state of the implementation. In the current Flatland implementation, we do not yet distinguish between configuration and state; they go together in RailEnvPersister, and a trajectory currently consists of full snapshots.
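
For illustration, a full snapshot can be written and read back with RailEnvPersister. This is a minimal sketch, assuming the RailEnvPersister.save / RailEnvPersister.load_new API and an environment coming from env_generator; the snapshot file name is illustrative.

import tempfile
from pathlib import Path

from flatland.env_generation.env_generator import env_generator
from flatland.envs.persistence import RailEnvPersister

env, observations, info = env_generator()
with tempfile.TemporaryDirectory() as tmpdirname:
    # a full snapshot bundles configuration and state in a single file
    snapshot = Path(tmpdirname) / "step_0000.pkl"  # illustrative file name
    RailEnvPersister.save(env, str(snapshot))
    # load_new returns the restored environment together with the raw dict it was built from
    restored, _ = RailEnvPersister.load_new(str(snapshot))
    assert (restored.rail.grid == env.rail.grid).all()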

        classDiagram
    class Runner {
        +generate_trajectory_from_policy(Policy policy, ObservationBuilder obs_builder, int snapshot_interval)$ Trajectory
    }

    class Evaluator {
        +evaluate(Trajectory trajectory)
    }

    class Trajectory {
        +Path data_dir
        +UUID ep_id
        +run(int from_step, int to_step=-1)
    }

    class EnvSnapshot {
        +Path data_dir
        +UUID ep_id
    }

    class EnvConfiguration {
        +int max_episode_steps
        +int height
        +int width
        +Rewards reward_function
        +MalGen
        +RailGen etc. reset
    }

    class EnvState {
        +Grid rail
    }

    class EnvActions

    class EnvRewards

    EnvSnapshot --> "1" EnvConfiguration
    EnvSnapshot --> "1" EnvState
    Trajectory --> "1" EnvConfiguration
    Trajectory --> "1..*" EnvState
    Trajectory --> "1..*" EnvActions
    Trajectory --> "1..*" EnvRewards

    class Policy {
        +act(int handle, Observation observation)
    }

    class ObservationBuilder {
        +get()
        +get_many()
    }

    class Submission

    Submission --> "1" Policy
    Submission --> ObservationBuilder


Remarks:

  • A trajectory need not start at step 0.

  • A trajectory need not contain a state for every step; however, when starting the trajectory from an intermediate step, the snapshot for that step must exist.
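
The building block view pairs a Policy with an ObservationBuilder to form a Submission. As a sketch of the ObservationBuilder side (assuming the flatland.core.env_observation_builder.ObservationBuilder base class; the builder class below is illustrative, not part of Flatland), a custom builder only needs to implement reset() and get():

from typing import Any

from flatland.core.env_observation_builder import ObservationBuilder

class PositionObservationBuilder(ObservationBuilder):
    """Illustrative builder: the observation is the agent's position and direction."""

    def reset(self):
        # nothing to precompute for this toy builder
        pass

    def get(self, handle: int = 0) -> Any:
        agent = self.env.agents[handle]
        return agent.position, agent.direction

The base class's default get_many() calls get() for each handle; such a builder can be passed to PolicyRunner.create_from_policy via its obs_builder parameter.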

Runtime View: Trajectory Generation#

        flowchart TD
    subgraph PolicyRunner.create_from_policy
        start((" ")) -->|data_dir| D0
        D0(RailEnvPersister.load_new) -->|env| E{env done?}
        E -->|no:<br/>observations| G{Agent loop:<br/> more agents?}
        G --->|observation| G1(policy.act)
        G1 -->|action| G
        G -->|no:<br/> actions| F3(env.step)
        F3 -->|observations,rewards,info| E
        E -->|yes:<br/> rewards| H((("&nbsp;")))
    end

    style Policy fill: #ffe, stroke: #333, stroke-width: 1px, color: black
    style G1 fill: #ffe, stroke: #333, stroke-width: 1px, color: black
    style Env fill: #fcc, stroke: #333, stroke-width: 1px, color: black
    style F3 fill: #fcc, stroke: #333, stroke-width: 1px, color: black
    subgraph legend
        Env(Environment)
        Policy(Policy)
        Trajectory(Trajectory)
    end

    PolicyRunner.create_from_policy~~~legend
    

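In code, the loop in the flowchart corresponds roughly to the following sketch (simplified: persisting snapshots, actions and rewards into the Trajectory is omitted, and run_episode is an illustrative name, not part of the Flatland API):

def run_episode(env, observations, policy):
    # loop until the environment reports the episode as done
    done = {"__all__": False}
    total_rewards = {handle: 0 for handle in env.get_agent_handles()}
    while not done["__all__"]:
        # agent loop: one action per agent handle
        action_dict = policy.act_many(env.get_agent_handles(), list(observations.values()))
        observations, rewards, done, info = env.step(action_dict)
        for handle, reward in rewards.items():
            total_rewards[handle] += reward
    return total_rewards
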
Trajectory Generation and Evaluation#

Create a trajectory from a random policy and inspect the output

import tempfile
from pathlib import Path
from typing import Any, Optional

from flatland.env_generation.env_generator import env_generator
from flatland.core.policy import Policy
from flatland.trajectories.policy_runner import PolicyRunner
from flatland.utils.seeding import np_random, random_state_to_hashablestate
from flatland.evaluators.trajectory_evaluator import TrajectoryEvaluator, evaluate_trajectory
from flatland.trajectories.trajectories import Trajectory
class RandomPolicy(Policy):
    def __init__(self, action_size: int = 5, seed=42):
        super(RandomPolicy, self).__init__()
        self.action_size = action_size
        self.np_random, _ = np_random(seed=seed)

    def act(self, observation: Any, **kwargs):
        return self.np_random.choice(self.action_size)
with tempfile.TemporaryDirectory() as tmpdirname:
    data_dir = Path(tmpdirname)
    trajectory = PolicyRunner.create_from_policy(policy=RandomPolicy(), data_dir=data_dir, snapshot_interval=15, tqdm_kwargs={"disable": True})
    # np_random in loaded episode is same as if it comes directly from env_generator incl. reset()!
    env = trajectory.restore_episode()
    gen, _, _ = env_generator()
    assert random_state_to_hashablestate(env.np_random) == random_state_to_hashablestate(gen.np_random)

    
    # inspect output
    for p in sorted(data_dir.rglob("**/*")):
        print(p)

    # inspect the actions taken by the policy
    print(trajectory._read_actions())

    # verify steps 15 to 25 - we can start at step 15 as there is a snapshot for step 15.
    TrajectoryEvaluator(trajectory).evaluate(start_step=15, end_step=25, tqdm_kwargs={"disable": True})
/opt/hostedtoolcache/Python/3.10.18/x64/lib/python3.10/site-packages/flatland/envs/rail_generators.py:335: UserWarning: Could not set all required cities! Created 1/2
  warnings.warn(city_warning)
/opt/hostedtoolcache/Python/3.10.18/x64/lib/python3.10/site-packages/flatland/envs/rail_generators.py:231: UserWarning: [WARNING] Changing to Grid mode to place at least 2 cities.
  warnings.warn("[WARNING] Changing to Grid mode to place at least 2 cities.")

List of Policy implementations#

  • tests.trajectories.test_trajectories.RandomPolicy

  • flatland_baselines.deadlock_avoidance_heuristic.policy.deadlock_avoidance_policy.DeadLockAvoidancePolicy (see flatland-baselines)

  • flatland.ml.ray.wrappers.ray_policy_wrapper and flatland.ml.ray.wrappers.ray_checkpoint_policy_wrapper for wrapping RLlib RLModules.

Flatland Callbacks#

Flatland callbacks can be used for custom metrics and custom postprocessing.

import inspect
from flatland.envs.rail_env import RailEnv
from flatland.callbacks.callbacks import FlatlandCallbacks, make_multi_callbacks
lines, _ = inspect.getsourcelines(FlatlandCallbacks)
print("".join(lines))
class DummyCallbacks(FlatlandCallbacks):
    def on_episode_step(
        self,
        *,
        env: Optional[RailEnv] = None,
        **kwargs,
    ) -> None:
        if (env._elapsed_steps - 1) % 10 == 0:
            print(f"step{env._elapsed_steps - 1}")
with tempfile.TemporaryDirectory() as tmpdirname:
    data_dir = Path(tmpdirname)
    trajectory = PolicyRunner.create_from_policy(policy=RandomPolicy(), data_dir=data_dir, snapshot_interval=15, tqdm_kwargs={"disable": True})
    TrajectoryEvaluator(trajectory, callbacks=make_multi_callbacks(DummyCallbacks())).evaluate(tqdm_kwargs={"disable": True})
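
Beyond printing, callbacks can also collect custom metrics. The following sketch records how many agents are placed on the grid at each step; the callback class name is illustrative, and it relies only on on_episode_step, env._elapsed_steps and env.agents, as in the example above.

class ActiveAgentsMetricCallbacks(FlatlandCallbacks):
    def __init__(self):
        # custom metric: number of agents currently placed on the grid, per step
        self.active_agents_per_step = {}

    def on_episode_step(
        self,
        *,
        env: Optional[RailEnv] = None,
        **kwargs,
    ) -> None:
        step = env._elapsed_steps - 1
        self.active_agents_per_step[step] = sum(1 for agent in env.agents if agent.position is not None)

metrics = ActiveAgentsMetricCallbacks()
with tempfile.TemporaryDirectory() as tmpdirname:
    data_dir = Path(tmpdirname)
    trajectory = PolicyRunner.create_from_policy(policy=RandomPolicy(), data_dir=data_dir, snapshot_interval=15, tqdm_kwargs={"disable": True})
    TrajectoryEvaluator(trajectory, callbacks=make_multi_callbacks(metrics)).evaluate(tqdm_kwargs={"disable": True})
print(metrics.active_agents_per_step)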

List of FlatlandCallbacks#

  • flatland.callbacks.generate_movie_callbacks.GenerateMovieCallbacks

  • flatland.integrations.interactiveai.interactiveai.FlatlandInteractiveAICallbacks

  • flatland.trajectories.trajectory_snapshot_callbacks.TrajectorySnapshotCallbacks