Source code for flatland.callbacks.callbacks
from pathlib import Path
from typing import Optional, List
from flatland.core.env import Environment
[docs]
class FlatlandCallbacks:
    """
    Base class for Flatland callbacks similar to rllib, see
    https://github.com/ray-project/ray/blob/master/rllib/callbacks/callbacks.py.

    These callbacks can be used for custom metrics and custom postprocessing.
    By default, all of these callbacks are no-ops; subclasses override only
    the hooks they need.

    Every hook takes keyword-only arguments plus ``**kwargs``, so new
    arguments can be passed in the future without breaking existing
    subclasses.
    """

    def on_episode_start(
        self,
        *,
        env: Optional[Environment] = None,
        data_dir: Optional[Path] = None,
        **kwargs,
    ) -> None:
        """Callback run right after an episode has been started.

        This method gets called after `env.reset()`.

        Parameters
        ----------
        env : Optional[Environment]
            The env, or ``None`` if not provided.
        data_dir : Optional[Path]
            Trajectory data dir, or ``None`` if not provided.
        kwargs:
            Forward compatibility placeholder.
        """
        pass

    def on_episode_step(
        self,
        *,
        env: Optional[Environment] = None,
        data_dir: Optional[Path] = None,
        **kwargs,
    ) -> None:
        """Called on each episode step (after the action(s) have been logged).

        This callback is also called after the final step of an episode,
        i.e. when terminated/truncated are returned as True from the
        `env.step()` call.

        The exact time of the call of this callback is after
        `env.step([action])` and also after the results of this step
        (observation, reward, terminated, truncated, infos) have been logged.

        Parameters
        ----------
        env : Optional[Environment]
            The env, or ``None`` if not provided.
        data_dir : Optional[Path]
            Trajectory data dir, or ``None`` if not provided.
        kwargs:
            Forward compatibility placeholder.
        """
        pass

    def on_episode_end(
        self,
        *,
        env: Optional[Environment] = None,
        data_dir: Optional[Path] = None,
        **kwargs,
    ) -> None:
        """Called when an episode is done (after terminated/truncated have been logged).

        The exact time of the call of this callback is after the
        `env.step([action])` that ended the episode.

        Parameters
        ----------
        env : Optional[Environment]
            The env, or ``None`` if not provided.
        data_dir : Optional[Path]
            Trajectory data dir, or ``None`` if not provided.
        kwargs:
            Forward compatibility placeholder.
        """
        pass
# https://github.com/ray-project/ray/blob/3b94e5ff0038798a6955cde37459a0d30aa718c4/rllib/callbacks/utils.py#L41
[docs]
class _MultiFlatlandCallbacks(FlatlandCallbacks):
    """Callback container that forwards every hook to its callbacks, in order.

    Defined at module level rather than inside ``make_multi_callbacks``:
    pickle locates classes by their qualified name, so instances of a
    function-local class cannot be pickled, whereas instances of this class can.
    """

    # Class-level flag carried over from rllib's implementation, marking this
    # object as a container of other callbacks.
    IS_CALLBACK_CONTAINER = True

    def __init__(self, callback_list: List[FlatlandCallbacks]):
        # Materialize as a list (matching the annotation); callers may pass
        # any iterable of callbacks, e.g. a tuple of varargs.
        self._callback_list = list(callback_list)

    def on_episode_start(self, **kwargs) -> None:
        """Invoke ``on_episode_start`` on each contained callback with the same kwargs."""
        for callback in self._callback_list:
            callback.on_episode_start(**kwargs)

    def on_episode_step(self, **kwargs) -> None:
        """Invoke ``on_episode_step`` on each contained callback with the same kwargs."""
        for callback in self._callback_list:
            callback.on_episode_step(**kwargs)

    def on_episode_end(self, **kwargs) -> None:
        """Invoke ``on_episode_end`` on each contained callback with the same kwargs."""
        for callback in self._callback_list:
            callback.on_episode_end(**kwargs)


def make_multi_callbacks(*_callback_list: FlatlandCallbacks) -> FlatlandCallbacks:
    """Combine several callbacks into a single callbacks object.

    Each hook of the returned object calls the same hook on every given
    callback, in the order the callbacks were passed, forwarding all keyword
    arguments unchanged.

    Parameters
    ----------
    _callback_list : FlatlandCallbacks
        The callbacks to combine, passed as positional arguments.

    Returns
    -------
    FlatlandCallbacks
        A container callback dispatching to all given callbacks.
    """
    return _MultiFlatlandCallbacks(_callback_list)