Policy Evaluation and Trajectories#

We use the terms from arc42 for the different views.

Building Block View: Trajectory Generation and Evaluation#

This is a conceptual view that reflects the target state of the implementation. In the current Flatland implementation, we do not yet distinguish between configuration and state: both are handled together by RailEnvPersister, and a trajectory currently consists of full snapshots.
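
For illustration, such a full snapshot can be reloaded directly with RailEnvPersister (a minimal sketch; the file name is a placeholder for one of the pickles written under serialised_state/, see the directory listing further below):

from flatland.envs.persistence import RailEnvPersister

# Load a full snapshot (configuration and state together) back into a RailEnv.
# The file name is a placeholder; real snapshots are named <ep_id>_step<NNNN>.pkl.
env, env_dict = RailEnvPersister.load_new("serialised_state/<ep_id>_step0015.pkl")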

        classDiagram
    class Runner {
        +generate_trajectory_from_policy(Policy policy, ObservationBuilder obs_builder, int snapshot_interval)$ Trajectory
    }

    class Evaluator {
        +evaluate(Trajectory trajectory)
    }

    class Trajectory {
        +Path data_dir
        +UUID ep_id
        +run(int from_step, int to_step=-1)
    }

    class EnvSnapshot {
        +Path data_dir
        +UUID ep_id
    }

    class EnvConfiguration {
        +int max_episode_steps
        +int height
        +int width
        +Rewards reward_function
        +MalGen
        +RailGen etc. reset
    }

    class EnvState {
        +Grid rail
    }

    class EnvActions

    class EnvRewards

    EnvSnapshot --> "1" EnvConfiguration
    EnvSnapshot --> "1" EnvState
    Trajectory --> "1" EnvConfiguration
    Trajectory --> "1..*" EnvState
    Trajectory --> "1..*" EnvActions
    Trajectory --> "1..*" EnvRewards

    class Policy
    Policy: act(int handle, Observation observation)

    class ObservationBuilder
    ObservationBuilder: get()
    ObservationBuilder: get_many()

    class Submission
    Submission --> "1" Policy
    Submission --> ObservationBuilder
    

Remarks:

  • A trajectory need not start at step 0.

  • A trajectory need not contain a state snapshot for every step. However, when starting from an intermediate step, the snapshot for that step must exist (see the sketch below).
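
For example, evaluation can be restarted at any step for which a snapshot exists (a minimal sketch; trajectory and the snapshot interval of 15 are assumed to come from the worked example further below):

from flatland.evaluators.trajectory_evaluator import TrajectoryEvaluator

# Steps 15..25 can be re-run because a snapshot exists for step 15
# (snapshot_interval=15 in the example below); starting at a step without
# a snapshot would not work.
TrajectoryEvaluator(trajectory).evaluate(start_step=15, end_step=25)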

Runtime View: Trajectory Generation#

        flowchart TD
    subgraph PolicyRunner.create_from_policy
        start((" ")) -->|data_dir| D0
        D0(RailEnvPersister.load_new) -->|env| E{env done?}
        E -->|no:<br/>observations| G{Agent loop:<br/> more agents?}
        G --->|observation| G1(policy.act)
        G1 -->|action| G
        G -->|no:<br/> actions| F3(env.step)
        F3 -->|observations,rewards,info| E
        E -->|yes:<br/> rewards| H((("&nbsp;")))
    end

    style Policy fill: #ffe, stroke: #333, stroke-width: 1px, color: black
    style G1 fill: #ffe, stroke: #333, stroke-width: 1px, color: black
    style Env fill: #fcc, stroke: #333, stroke-width: 1px, color: black
    style F3 fill: #fcc, stroke: #333, stroke-width: 1px, color: black
    subgraph legend
        Env(Environment)
        Policy(Policy)
        Trajectory(Trajectory)
    end

    PolicyRunner.create_from_policy~~~legend
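
The core of this loop can be sketched as follows (a simplified sketch, not the actual PolicyRunner implementation; env and policy are assumed to be set up as in the example below, and writing snapshots and event logs is omitted):

# Simplified generation loop: ask the policy for an action per agent, then step the env.
observations, info = env.reset()
dones = {"__all__": False}
while not dones["__all__"]:
    actions = {handle: policy.act(observations[handle])
               for handle in env.get_agent_handles()}
    observations, rewards, dones, info = env.step(actions)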
    

Trajectory Generation and Evaluation#

Create a trajectory from a random policy and inspect the output

import tempfile
from pathlib import Path
from typing import Any, Optional

from flatland.env_generation.env_generator import env_generator
from flatland.core.policy import Policy
from flatland.trajectories.policy_runner import PolicyRunner
from flatland.utils.seeding import np_random, random_state_to_hashablestate
from flatland.evaluators.trajectory_evaluator import TrajectoryEvaluator, evaluate_trajectory
from flatland.trajectories.trajectories import Trajectory

class RandomPolicy(Policy):
    def __init__(self, action_size: int = 5, seed=42):
        super(RandomPolicy, self).__init__()
        self.action_size = action_size
        self.np_random, _ = np_random(seed=seed)

    def act(self, observation: Any, **kwargs):
        return self.np_random.choice(self.action_size)

with tempfile.TemporaryDirectory() as tmpdirname:
    data_dir = Path(tmpdirname)
    env, _, _ = env_generator(seed=44)
    trajectory = PolicyRunner.create_from_policy(policy=RandomPolicy(), env=env, data_dir=data_dir, snapshot_interval=15, tqdm_kwargs={"disable": True})
    # np_random in the loaded episode is the same as if it came directly from env_generator (incl. reset())!
    env = trajectory.load_env()
    # we need to seed explicitly to have the same env re-generated!
    gen, _, _ = env_generator(seed=44)
    assert random_state_to_hashablestate(env.np_random) == random_state_to_hashablestate(gen.np_random)

    
    # inspect output
    for p in sorted(data_dir.rglob("**/*")):
        print(p)

    # inspect the actions taken by the policy
    print(trajectory._read_actions())

    # verify steps 15 to 25 - we can start at step 15 as there is a snapshot for step 15.
    TrajectoryEvaluator(trajectory).evaluate(start_step=15, end_step=25, tqdm_kwargs={"disable": True})
/opt/hostedtoolcache/Python/3.10.19/x64/lib/python3.10/site-packages/flatland/envs/rail_generators.py:344: UserWarning: Could not set all required cities! Created 1/2
  warnings.warn(city_warning)
/opt/hostedtoolcache/Python/3.10.19/x64/lib/python3.10/site-packages/flatland/envs/rail_generators.py:238: UserWarning: [WARNING] Changing to Grid mode to place at least 2 cities.
  warnings.warn("[WARNING] Changing to Grid mode to place at least 2 cities.")
/tmp/tmpz2t1gs0y/event_logs
/tmp/tmpz2t1gs0y/event_logs/ActionEvents.discrete_action.tsv
/tmp/tmpz2t1gs0y/event_logs/TrainMovementEvents.trains_arrived.tsv
/tmp/tmpz2t1gs0y/event_logs/TrainMovementEvents.trains_positions.tsv
/tmp/tmpz2t1gs0y/event_logs/TrainMovementEvents.trains_rewards_dones_infos.tsv
/tmp/tmpz2t1gs0y/outputs
/tmp/tmpz2t1gs0y/serialised_state
/tmp/tmpz2t1gs0y/serialised_state/42c44911-bce9-47a2-92ad-0ec6db0d2e43.pkl
/tmp/tmpz2t1gs0y/serialised_state/42c44911-bce9-47a2-92ad-0ec6db0d2e43_step0000.pkl
/tmp/tmpz2t1gs0y/serialised_state/42c44911-bce9-47a2-92ad-0ec6db0d2e43_step0015.pkl
/tmp/tmpz2t1gs0y/serialised_state/42c44911-bce9-47a2-92ad-0ec6db0d2e43_step0030.pkl
/tmp/tmpz2t1gs0y/serialised_state/42c44911-bce9-47a2-92ad-0ec6db0d2e43_step0045.pkl
/tmp/tmpz2t1gs0y/serialised_state/42c44911-bce9-47a2-92ad-0ec6db0d2e43_step0060.pkl
/tmp/tmpz2t1gs0y/serialised_state/42c44911-bce9-47a2-92ad-0ec6db0d2e43_step0075.pkl
/tmp/tmpz2t1gs0y/serialised_state/42c44911-bce9-47a2-92ad-0ec6db0d2e43_step0090.pkl
/tmp/tmpz2t1gs0y/serialised_state/42c44911-bce9-47a2-92ad-0ec6db0d2e43_step0105.pkl
/tmp/tmpz2t1gs0y/serialised_state/42c44911-bce9-47a2-92ad-0ec6db0d2e43_step0120.pkl
/tmp/tmpz2t1gs0y/serialised_state/42c44911-bce9-47a2-92ad-0ec6db0d2e43_step0135.pkl
/tmp/tmpz2t1gs0y/serialised_state/42c44911-bce9-47a2-92ad-0ec6db0d2e43_step0150.pkl
/tmp/tmpz2t1gs0y/serialised_state/42c44911-bce9-47a2-92ad-0ec6db0d2e43_step0165.pkl
/tmp/tmpz2t1gs0y/serialised_state/42c44911-bce9-47a2-92ad-0ec6db0d2e43_step0180.pkl
/tmp/tmpz2t1gs0y/serialised_state/42c44911-bce9-47a2-92ad-0ec6db0d2e43_step0195.pkl
/tmp/tmpz2t1gs0y/serialised_state/42c44911-bce9-47a2-92ad-0ec6db0d2e43_step0210.pkl
/tmp/tmpz2t1gs0y/serialised_state/42c44911-bce9-47a2-92ad-0ec6db0d2e43_step0225.pkl
/tmp/tmpz2t1gs0y/serialised_state/42c44911-bce9-47a2-92ad-0ec6db0d2e43_step0240.pkl
/tmp/tmpz2t1gs0y/serialised_state/42c44911-bce9-47a2-92ad-0ec6db0d2e43_step0255.pkl
/tmp/tmpz2t1gs0y/serialised_state/42c44911-bce9-47a2-92ad-0ec6db0d2e43_step0270.pkl
                                episode_id  env_time  agent_id  \
0     42c44911-bce9-47a2-92ad-0ec6db0d2e43         0         0   
1     42c44911-bce9-47a2-92ad-0ec6db0d2e43         0         1   
2     42c44911-bce9-47a2-92ad-0ec6db0d2e43         0         2   
3     42c44911-bce9-47a2-92ad-0ec6db0d2e43         0         3   
4     42c44911-bce9-47a2-92ad-0ec6db0d2e43         0         4   
...                                    ...       ...       ...   
1983  42c44911-bce9-47a2-92ad-0ec6db0d2e43       283         2   
1984  42c44911-bce9-47a2-92ad-0ec6db0d2e43       283         3   
1985  42c44911-bce9-47a2-92ad-0ec6db0d2e43       283         4   
1986  42c44911-bce9-47a2-92ad-0ec6db0d2e43       283         5   
1987  42c44911-bce9-47a2-92ad-0ec6db0d2e43       283         6   

                           action  
0       RailEnvActions.MOVE_RIGHT  
1        RailEnvActions.MOVE_LEFT  
2       RailEnvActions.MOVE_RIGHT  
3     RailEnvActions.MOVE_FORWARD  
4     RailEnvActions.MOVE_FORWARD  
...                           ...  
1983     RailEnvActions.MOVE_LEFT  
1984    RailEnvActions.DO_NOTHING  
1985   RailEnvActions.STOP_MOVING  
1986     RailEnvActions.MOVE_LEFT  
1987  RailEnvActions.MOVE_FORWARD  

[1988 rows x 4 columns]
/opt/hostedtoolcache/Python/3.10.19/x64/lib/python3.10/site-packages/flatland/trajectories/trajectories.py:81: FutureWarning: The behavior of DataFrame concatenation with empty or all-NA entries is deprecated. In a future version, this will no longer exclude empty or all-NA columns when determining the result dtypes. To retain the old behavior, exclude the relevant entries before the concat operation.
  self.trains_arrived = pd.concat([self.trains_arrived, self._collected_trains_arrived_to_df()])
/opt/hostedtoolcache/Python/3.10.19/x64/lib/python3.10/site-packages/flatland/trajectories/trajectories.py:82: FutureWarning: The behavior of DataFrame concatenation with empty or all-NA entries is deprecated. In a future version, this will no longer exclude empty or all-NA columns when determining the result dtypes. To retain the old behavior, exclude the relevant entries before the concat operation.
  self.trains_rewards_dones_infos = pd.concat([self.trains_rewards_dones_infos, self._collected_trains_rewards_dones_infos_to_df()])
/opt/hostedtoolcache/Python/3.10.19/x64/lib/python3.10/site-packages/flatland/envs/rail_generators.py:344: UserWarning: Could not set all required cities! Created 1/2
  warnings.warn(city_warning)
/opt/hostedtoolcache/Python/3.10.19/x64/lib/python3.10/site-packages/flatland/envs/rail_generators.py:238: UserWarning: [WARNING] Changing to Grid mode to place at least 2 cities.
  warnings.warn("[WARNING] Changing to Grid mode to place at least 2 cities.")

List of Policy implementations#

  • tests.trajectories.test_trajectories.RandomPolicy

  • flatland.envs.rail_env_policies.ShortestPathPolicy

  • flatland_baselines.deadlock_avoidance_heuristic.policy.deadlock_avoidance_policy.DeadLockAvoidancePolicy (see flatland-baselines)

  • flatland.ml.ray.wrappers.ray_policy_wrapper and flatland.ml.ray.wrappers.ray_checkpoint_policy_wrapper for wrapping RLlib RLModules.
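
Any of these can be plugged into PolicyRunner.create_from_policy in the same way as RandomPolicy above, for example (a sketch; it assumes ShortestPathPolicy can be constructed without arguments, check the module for its actual constructor):

import tempfile
from pathlib import Path

from flatland.env_generation.env_generator import env_generator
from flatland.envs.rail_env_policies import ShortestPathPolicy
from flatland.trajectories.policy_runner import PolicyRunner

with tempfile.TemporaryDirectory() as tmpdirname:
    env, _, _ = env_generator(seed=44)
    # assumption: ShortestPathPolicy takes no required constructor arguments
    trajectory = PolicyRunner.create_from_policy(policy=ShortestPathPolicy(), env=env, data_dir=Path(tmpdirname), snapshot_interval=15, tqdm_kwargs={"disable": True})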

Flatland Callbacks#

Flatland callbacks can be used for custom metrics and custom postprocessing.

import inspect
from flatland.envs.rail_env import RailEnv
from flatland.callbacks.callbacks import FlatlandCallbacks, make_multi_callbacks
lines, _ = inspect.getsourcelines(FlatlandCallbacks)
print("".join(lines))
class FlatlandCallbacks(Generic[EnvType]):
    """
    Abstract base class for Flatland callbacks similar to rllib, see https://github.com/ray-project/ray/blob/master/rllib/callbacks/callbacks.py.

    These callbacks can be used for custom metrics and custom postprocessing.

    By default, all of these callbacks are no-ops.
    """

    def on_episode_start(
        self,
        *,
        env: Optional[EnvType] = None,
        data_dir: Path = None,
        **kwargs,
    ) -> None:
        """Callback run right after an Episode has been started.

        This method gets called after `env.reset()`.

        Parameters
        ---------
            env : Environment
                the env
            data_dir : Path
                trajectory data dir
            kwargs:
                Forward compatibility placeholder.
        """
        pass

    def on_episode_step(
        self,
        *,
        env: Optional[EnvType] = None,
        data_dir: Path = None,
        **kwargs,
    ) -> None:
        """Called on each episode step (after the action(s) has/have been logged).

        This callback is also called after the final step of an episode,
        meaning when terminated/truncated are returned as True
        from the `env.step()` call.

        The exact time of the call of this callback is after `env.step([action])` and
        also after the results of this step (observation, reward, terminated, truncated,
        infos) have been logged to the given `episode` object.

        Parameters
        ---------
            env : Environment
                the env
            data_dir : Path
                trajectory data dir
            kwargs:
                Forward compatibility placeholder.
        """
        pass

    def on_episode_end(
        self,
        *,
        env: Optional[EnvType] = None,
        data_dir: Path = None,
        **kwargs,
    ) -> None:
        """Called when an episode is done (after terminated/truncated have been logged).

        The exact time of the call of this callback is after `env.step([action])`

        Parameters
        ---------
            env : Environment
                the env
            data_dir : Path
                trajectory data dir
            kwargs:
                Forward compatibility placeholder.
        """
        pass
class DummyCallbacks(FlatlandCallbacks):
    def on_episode_step(
        self,
        *,
        env: Optional[RailEnv] = None,
        **kwargs,
    ) -> None:
        if (env._elapsed_steps - 1) % 10 == 0:
            print(f"step{env._elapsed_steps - 1}")

with tempfile.TemporaryDirectory() as tmpdirname:
    data_dir = Path(tmpdirname)
    trajectory = PolicyRunner.create_from_policy(policy=RandomPolicy(), env=env_generator()[0], data_dir=data_dir, snapshot_interval=15, tqdm_kwargs={"disable": True})
    TrajectoryEvaluator(trajectory, callbacks=make_multi_callbacks(DummyCallbacks())).evaluate(tqdm_kwargs={"disable": True})
/opt/hostedtoolcache/Python/3.10.19/x64/lib/python3.10/site-packages/flatland/trajectories/trajectories.py:81: FutureWarning: The behavior of DataFrame concatenation with empty or all-NA entries is deprecated. In a future version, this will no longer exclude empty or all-NA columns when determining the result dtypes. To retain the old behavior, exclude the relevant entries before the concat operation.
  self.trains_arrived = pd.concat([self.trains_arrived, self._collected_trains_arrived_to_df()])
/opt/hostedtoolcache/Python/3.10.19/x64/lib/python3.10/site-packages/flatland/trajectories/trajectories.py:82: FutureWarning: The behavior of DataFrame concatenation with empty or all-NA entries is deprecated. In a future version, this will no longer exclude empty or all-NA columns when determining the result dtypes. To retain the old behavior, exclude the relevant entries before the concat operation.
  self.trains_rewards_dones_infos = pd.concat([self.trains_rewards_dones_infos, self._collected_trains_rewards_dones_infos_to_df()])
step0
step10
step20
step30
step40
step50
step60
step70
step80
step90
step100
step110
step120
step130
step140
step150
step160
step170
step180
step190
step200
step210
step220
step230
step240
step250
step260
step270
0.0% trains arrived. Expected 0.0%. 273 elapsed steps.
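
Beyond printing, a callback can use the data_dir argument for custom postprocessing, e.g. writing a small summary file when the episode ends (a sketch; EpisodeLengthCallbacks is a hypothetical name, and env._elapsed_steps is used as in DummyCallbacks above):

from pathlib import Path
from typing import Optional

from flatland.callbacks.callbacks import FlatlandCallbacks
from flatland.envs.rail_env import RailEnv

class EpisodeLengthCallbacks(FlatlandCallbacks):
    """Hypothetical callback: write the number of elapsed steps to <data_dir>/episode_length.txt."""

    def on_episode_end(self, *, env: Optional[RailEnv] = None, data_dir: Path = None, **kwargs) -> None:
        if data_dir is not None:
            (data_dir / "episode_length.txt").write_text(str(env._elapsed_steps))

Such a callback can be passed in the same way as DummyCallbacks above, e.g. TrajectoryEvaluator(trajectory, callbacks=make_multi_callbacks(EpisodeLengthCallbacks())).evaluate().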

List of FlatlandCallbacks#

  • flatland.callbacks.generate_movie_callbacks.GenerateMovieCallbacks

  • flatland.integrations.interactiveai.interactiveai.FlatlandInteractiveAICallbacks

  • flatland.trajectories.trajectory_snapshot_callbacks.TrajectorySnapshotCallbacks