Policy Evaluation and Trajectories
==================================

We use the terms from [arc42](https://docs.arc42.org/section-7/) for the different views.

Building Block View Trajectory Generation and Evaluation
--------------------------------------------------------

This is a conceptual view that reflects the target state of the implementation. The current Flatland implementation does not yet distinguish between configuration and state: both are persisted together by `RailEnvPersister`, and a trajectory currently consists of full snapshots.

```mermaid
classDiagram
    class Runner {
        +generate_trajectory_from_policy(Policy policy, ObservationBuilder obs_builder, int snapshot_interval)$ Trajectory
    }

    class Evaluator {
        +evaluate(Trajectory trajectory)
    }

    class Trajectory {
        +Path data_dir
        +UUID ep_id
        +run(int from_step, int to_step=-1)
    }

    class EnvSnapshot {
        +Path data_dir
        +UUID ep_id
    }

    class EnvConfiguration
    EnvConfiguration: +int max_episode_steps
    EnvConfiguration: +int height
    EnvConfiguration: +int width
    EnvConfiguration: +Rewards reward_function
    EnvConfiguration: +MalGen
    EnvConfiguration: +RailGen etc. reset

    class EnvState {
        +Grid rail
    }

    class EnvActions

    class EnvRewards

    EnvSnapshot --> "1" EnvConfiguration
    EnvSnapshot --> "1" EnvState
    Trajectory --> "1" EnvConfiguration
    Trajectory --> "1..*" EnvState
    Trajectory --> "1..*" EnvActions
    Trajectory --> "1..*" EnvRewards

    class Policy
    Policy: act(int handle, Observation observation)

    class ObservationBuilder
    ObservationBuilder: get()
    ObservationBuilder: get_many()

    class Submission
    Submission --> "1" Policy
    Submission --> ObservationBuilder
```
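To make the target separation of configuration and state concrete, here is a minimal sketch of the diagram's data model as Python dataclasses. It is an illustration under stated assumptions, not Flatland's API: the field types are simplified (the rail grid as a nested list, actions and rewards as dicts keyed by agent handle, per-step maps keyed by step number), and all class names here are hypothetical stand-ins for the conceptual blocks.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Optional

# Hypothetical, simplified model of the conceptual building blocks.
# Not Flatland's classes: the current implementation persists configuration
# and state together via RailEnvPersister.


@dataclass(frozen=True)
class EnvConfiguration:
    # Static parameters that stay fixed for the whole episode.
    max_episode_steps: int
    height: int
    width: int


@dataclass
class EnvState:
    # Dynamic state at one step (simplified here to just the rail grid).
    rail: List[List[int]]


@dataclass
class EnvSnapshot:
    # Exactly one configuration plus one state: enough to restore an env.
    configuration: EnvConfiguration
    state: EnvState


@dataclass
class Trajectory:
    # One configuration, plus states, actions and rewards indexed by step.
    ep_id: str
    configuration: EnvConfiguration
    states: Dict[int, EnvState] = field(default_factory=dict)  # only for snapshot steps
    actions: Dict[int, Dict[int, int]] = field(default_factory=dict)  # step -> {handle: action}
    rewards: Dict[int, Dict[int, float]] = field(default_factory=dict)  # step -> {handle: reward}

    def snapshot(self, step: int) -> Optional[EnvSnapshot]:
        # A snapshot can only be assembled for steps whose state was stored.
        state = self.states.get(step)
        return None if state is None else EnvSnapshot(self.configuration, state)


cfg = EnvConfiguration(max_episode_steps=100, height=30, width=30)
traj = Trajectory(ep_id="ep-1", configuration=cfg, states={0: EnvState(rail=[[0]])})
print(traj.snapshot(0) is not None)  # True
print(traj.snapshot(5) is None)      # True
```

Sharing one frozen `EnvConfiguration` between the trajectory and all its snapshots mirrors the `"1"` multiplicity in the diagram, while states are only kept for some steps.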

Remarks:

* A trajectory need not start at step 0.
* A trajectory need not contain the state for every step. However, to start the trajectory from an intermediate step, a snapshot for that step must exist.
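The second remark can be illustrated with a short sketch. It assumes, purely for illustration and not as a statement about Flatland's persistence format, that a snapshot is stored at the trajectory's first step and then every `snapshot_interval` steps; `snapshot_steps` and `can_start_at` are hypothetical helpers.

```python
from typing import List


def snapshot_steps(first_step: int, last_step: int, snapshot_interval: int) -> List[int]:
    """Steps that have a stored snapshot under the assumed scheme."""
    return list(range(first_step, last_step + 1, snapshot_interval))


def can_start_at(step: int, snapshots: List[int]) -> bool:
    # Starting from an intermediate step requires a snapshot for exactly that step.
    return step in snapshots


snaps = snapshot_steps(first_step=0, last_step=40, snapshot_interval=15)
print(snaps)                    # [0, 15, 30]
print(can_start_at(15, snaps))  # True
print(can_start_at(5, snaps))   # False

# A trajectory need not start at step 0:
print(snapshot_steps(first_step=10, last_step=40, snapshot_interval=15))  # [10, 25, 40]
```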

Runtime View Trajectory Generation
----------------------------------

```mermaid
flowchart TD
    subgraph PolicyRunner.create_from_policy
        start(("&nbsp;")) -->|data_dir| D0
        D0(RailEnvPersister.load_new) -->|env| E{env done?}
        E -->|no:<br/>observations| G{Agent loop:<br/> more agents?}
        G --->|observation| G1(policy.act)
        G1 -->|action| G
        G -->|no:<br/> actions| F3(env.step)
        F3 -->|observations,rewards,info| E
        E -->|yes:<br/> rewards| H((("&nbsp;")))
    end

    style Policy fill: #ffe, stroke: #333, stroke-width: 1px, color: black
    style G1 fill: #ffe, stroke: #333, stroke-width: 1px, color: black
    style Env fill: #fcc, stroke: #333, stroke-width: 1px, color: black
    style F3 fill: #fcc, stroke: #333, stroke-width: 1px, color: black
    subgraph legend
        Env(Environment)
        Policy(Policy)
        Trajectory(Trajectory)
    end

    PolicyRunner.create_from_policy~~~legend
```
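The control flow in the flowchart can be read as a plain loop: while the episode is not done, collect one action per agent from the policy, then call `step` once with all actions. The following is a self-contained sketch of that loop only. `ToyEnv`, `RandomToyPolicy` and `run_episode` are hypothetical stand-ins for `RailEnv`, `Policy` and `PolicyRunner`; the sketch skips loading via `RailEnvPersister` and does not persist anything.

```python
import random
from typing import Any, Dict, List, Tuple


class ToyEnv:
    """Hypothetical stand-in for RailEnv: the episode ends after max_steps steps."""

    def __init__(self, n_agents: int = 2, max_steps: int = 5):
        self.n_agents = n_agents
        self.max_steps = max_steps
        self.elapsed_steps = 0

    def get_agent_handles(self) -> List[int]:
        return list(range(self.n_agents))

    def reset(self) -> Dict[int, int]:
        self.elapsed_steps = 0
        return {h: 0 for h in self.get_agent_handles()}  # dummy observations

    def step(self, actions: Dict[int, int]) -> Tuple[Dict[int, int], Dict[int, float], Dict[str, bool], Dict[str, Any]]:
        self.elapsed_steps += 1
        observations = {h: self.elapsed_steps for h in actions}
        rewards = {h: -1.0 for h in actions}  # toy rule: constant per-step penalty
        dones = {"__all__": self.elapsed_steps >= self.max_steps}
        return observations, rewards, dones, {}


class RandomToyPolicy:
    """Hypothetical policy with the act(handle, observation) shape from the class diagram."""

    def __init__(self, action_size: int = 5, seed: int = 42):
        self.action_size = action_size
        self.rng = random.Random(seed)

    def act(self, handle: int, observation: int) -> int:
        return self.rng.randrange(self.action_size)


def run_episode(env: ToyEnv, policy: RandomToyPolicy) -> List[Dict[int, int]]:
    observations = env.reset()
    actions_log: List[Dict[int, int]] = []
    done = False
    while not done:  # "env done?"
        # Agent loop: one policy.act call per agent handle.
        actions = {h: policy.act(h, observations[h]) for h in env.get_agent_handles()}
        observations, rewards, dones, info = env.step(actions)  # env.step
        actions_log.append(actions)
        done = dones["__all__"]
    return actions_log


log = run_episode(ToyEnv(n_agents=2, max_steps=5), RandomToyPolicy())
print(len(log))  # 5
```

Collecting all actions before a single `step` call matches the flowchart, where `env.step` only runs once the agent loop has no more agents.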

Trajectory Generation and Evaluation
------------------------------------

Create a trajectory from a random policy and inspect the output.

```python
import tempfile
from pathlib import Path
from typing import Any, Optional

from flatland.env_generation.env_generator import env_generator
from flatland.core.policy import Policy
from flatland.trajectories.policy_runner import PolicyRunner
from flatland.utils.seeding import np_random, random_state_to_hashablestate
from flatland.evaluators.trajectory_evaluator import TrajectoryEvaluator, evaluate_trajectory
from flatland.trajectories.trajectories import Trajectory
```

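```python
class RandomPolicy(Policy):
    """Chooses one of `action_size` actions uniformly at random, seeded for reproducibility."""

    def __init__(self, action_size: int = 5, seed=42):
        super().__init__()
        self.action_size = action_size
        self.np_random, _ = np_random(seed=seed)

    def act(self, observation: Any, **kwargs):
        # The observation is ignored: the action is drawn uniformly at random.
        return self.np_random.choice(self.action_size)
```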

```python
with tempfile.TemporaryDirectory() as tmpdirname:
    data_dir = Path(tmpdirname)
    env, _, _ = env_generator()
    trajectory = PolicyRunner.create_from_policy(
        policy=RandomPolicy(), env=env, data_dir=data_dir, snapshot_interval=15, tqdm_kwargs={"disable": True}
    )
    # The np_random state of the restored episode equals that of an env created
    # directly by env_generator (including reset())!
    env = trajectory.restore_episode()
    gen, _, _ = env_generator()
    assert random_state_to_hashablestate(env.np_random) == random_state_to_hashablestate(gen.np_random)

    # inspect output
    for p in sorted(data_dir.rglob("*")):
        print(p)

    # inspect the actions taken by the policy
    print(trajectory._read_actions())

    # verify steps 15 to 25 - we can start at 15 as there is a snapshot for step 15
    TrajectoryEvaluator(trajectory).evaluate(start_step=15, end_step=25, tqdm_kwargs={"disable": True})
```
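Conceptually, verifying a step range amounts to restoring the snapshot at the start step, replaying the recorded actions, and checking that the recomputed outcomes match the recorded ones. The sketch below illustrates that idea with a hypothetical deterministic `CounterEnv` and a hypothetical helper `replay_and_verify`. It is an assumption about the general approach, not `TrajectoryEvaluator`'s actual implementation, and it compares only rewards.

```python
import copy
from typing import Dict, List


class CounterEnv:
    """Hypothetical deterministic toy env: each agent's position grows by its action."""

    def __init__(self, positions: Dict[int, int]):
        self.positions = dict(positions)

    def step(self, actions: Dict[int, int]) -> Dict[int, float]:
        for handle, action in actions.items():
            self.positions[handle] += action
        # Toy reward rule: negative distance travelled in this step.
        return {h: -float(a) for h, a in actions.items()}


def replay_and_verify(snapshot: CounterEnv, actions: List[Dict[int, int]],
                      recorded_rewards: List[Dict[int, float]]) -> bool:
    if len(actions) != len(recorded_rewards):
        return False
    env = copy.deepcopy(snapshot)  # never mutate the stored snapshot
    for step_actions, expected in zip(actions, recorded_rewards):
        if env.step(step_actions) != expected:
            return False
    return True


# Record a short trajectory starting from a snapshot ...
snapshot = CounterEnv({0: 0, 1: 0})
actions = [{0: 1, 1: 2}, {0: 0, 1: 1}]
recording_env = copy.deepcopy(snapshot)
rewards = [recording_env.step(a) for a in actions]

# ... then verify it by replaying from the same snapshot.
print(replay_and_verify(snapshot, actions, rewards))  # True
tampered = [rewards[0], {0: -5.0, 1: -1.0}]
print(replay_and_verify(snapshot, actions, tampered))  # False
```

Replaying only works if the environment is deterministic given the snapshot, which is why the restored `np_random` state is checked in the cell above.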

### List of `Policy` implementations
* `tests.trajectories.test_trajectories.RandomPolicy`
* `flatland.envs.rail_env_policies.ShortestPathPolicy`
* `flatland_baselines.deadlock_avoidance_heuristic.policy.deadlock_avoidance_policy.DeadLockAvoidancePolicy` (see [flatland-baselines](https://github.com/flatland-association/flatland-baselines/blob/main/flatland_baselines/deadlock_avoidance_heuristic/policy/deadlock_avoidance_policy.py))
* `flatland.ml.ray.wrappers.ray_policy_wrapper` and `flatland.ml.ray.wrappers.ray_checkpoint_policy_wrapper` for wrapping RLlib RLModules.

Flatland Callbacks
------------------
Flatland callbacks can be used for custom metrics and custom postprocessing.
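As an illustration of the custom-metrics use case, here is a self-contained sketch of the hook pattern: a base class with no-op hooks that an episode loop calls, and a subclass that overrides them to compute a metric. Only `on_episode_step` with an `env` keyword argument appears in this document; the `on_episode_end` hook, the `Callbacks` base class, `StepCounterCallbacks` and the `run` loop are hypothetical and do not reproduce Flatland's `FlatlandCallbacks`.

```python
from typing import Any, Optional


class Callbacks:
    """Hypothetical base class: no-op hooks invoked by an episode loop."""

    def on_episode_step(self, *, env: Optional[Any] = None, **kwargs) -> None:
        pass

    def on_episode_end(self, *, env: Optional[Any] = None, **kwargs) -> None:
        pass


class StepCounterCallbacks(Callbacks):
    """Custom metric: counts how many steps were observed."""

    def __init__(self):
        self.steps = 0

    def on_episode_step(self, *, env: Optional[Any] = None, **kwargs) -> None:
        self.steps += 1

    def on_episode_end(self, *, env: Optional[Any] = None, **kwargs) -> None:
        print(f"episode finished after {self.steps} steps")


def run(n_steps: int, callbacks: Callbacks) -> None:
    # Stand-in for the evaluation loop: call the hooks after each step and at the end.
    for _ in range(n_steps):
        callbacks.on_episode_step(env=None)
    callbacks.on_episode_end(env=None)


cb = StepCounterCallbacks()
run(3, cb)  # prints "episode finished after 3 steps"
```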

```python
import inspect
from flatland.envs.rail_env import RailEnv
from flatland.callbacks.callbacks import FlatlandCallbacks, make_multi_callbacks
```

```python
# Print the source of FlatlandCallbacks to see which hooks can be overridden.
lines, _ = inspect.getsourcelines(FlatlandCallbacks)
print("".join(lines))
```

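```python
class DummyCallbacks(FlatlandCallbacks):
    """Prints every 10th step of the episode."""

    def on_episode_step(
        self,
        *,
        env: Optional[RailEnv] = None,
        **kwargs,
    ) -> None:
        if env is None:
            return
        # _elapsed_steps has already been incremented for the step just taken.
        step = env._elapsed_steps - 1
        if step % 10 == 0:
            print(f"step {step}")
```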

```python
with tempfile.TemporaryDirectory() as tmpdirname:
    data_dir = Path(tmpdirname)
    trajectory = PolicyRunner.create_from_policy(
        policy=RandomPolicy(), env=env_generator()[0], data_dir=data_dir, snapshot_interval=15, tqdm_kwargs={"disable": True}
    )
    TrajectoryEvaluator(trajectory, callbacks=make_multi_callbacks(DummyCallbacks())).evaluate(tqdm_kwargs={"disable": True})
```
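`make_multi_callbacks` combines several callback objects into one. A minimal sketch of that composite pattern follows, assuming it simply forwards each hook call to every wrapped callback in order. `Recorder`, `MultiCallbacks` and `make_multi` are hypothetical illustrations, not Flatland's implementation, and the `step` keyword argument is made up for the example.

```python
from typing import Any, List


class Recorder:
    """Hypothetical callback that records the steps it was called for."""

    def __init__(self, name: str):
        self.name = name
        self.seen: List[int] = []

    def on_episode_step(self, *, step: int, **kwargs: Any) -> None:
        self.seen.append(step)


class MultiCallbacks:
    """Composite: forwards every hook call to each wrapped callback in order."""

    def __init__(self, callbacks: List[Any]):
        self._callbacks = list(callbacks)

    def on_episode_step(self, **kwargs: Any) -> None:
        for cb in self._callbacks:
            cb.on_episode_step(**kwargs)


def make_multi(*callbacks: Any) -> MultiCallbacks:
    return MultiCallbacks(list(callbacks))


a, b = Recorder("a"), Recorder("b")
multi = make_multi(a, b)
for step in range(3):
    multi.on_episode_step(step=step)
print(a.seen, b.seen)  # [0, 1, 2] [0, 1, 2]
```

Because the composite exposes the same hooks as a single callback, the evaluation loop does not need to know how many callbacks are attached.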

### List of FlatlandCallbacks
* `flatland.callbacks.generate_movie_callbacks.GenerateMovieCallbacks`
* `flatland.integrations.interactiveai.interactiveai.FlatlandInteractiveAICallbacks`
* `flatland.trajectories.trajectory_snapshot_callbacks.TrajectorySnapshotCallbacks`