Source code for flatland.core.policy
from abc import ABC
from typing import List, Dict, TypeVar, Generic
from flatland.core.env import Environment
# Type parameter for the environment type of a policy; bounded by `Environment`,
# so only `Environment` or its subclasses are accepted.
T_env = TypeVar('T_env', bound=Environment)
# Type parameter for one agent's observation (declared covariant).
# NOTE(review): `T_obs` is used in parameter position in `Policy.act()` / `Policy.act_many()`;
# static type checkers usually expect a contravariant type variable there - confirm the intended variance.
T_obs = TypeVar('T_obs', covariant=True)
# Type parameter for one agent's action, i.e. what `Policy.act()` returns (declared covariant).
T_act = TypeVar('T_act', covariant=True)
[docs]
class Policy(ABC, Generic[T_env, T_obs, T_act]):
    """
    Abstract base class for Flatland policies. Used for evaluation.

    Loosely corresponding to https://github.com/ray-project/ray/blob/master/rllib/core/rl_module/rl_module.py, but much simpler.

    The class is generic in the environment type (`T_env`), the per-agent
    observation type (`T_obs`) and the per-agent action type (`T_act`).
    """

    def act(self, observation: T_obs, **kwargs) -> T_act:
        """
        Get action for a single agent. Called by `act_many()` for each agent.

        Parameters
        ----------
        observation: T_obs
            the agent's observation
        kwargs
            forward compatibility placeholder

        Returns
        -------
        T_act
            the agent's action

        Raises
        ------
        NotImplementedError
            always in this base class; concrete policies must override this
            method (or override `act_many()` so that it is never called).
        """
        raise NotImplementedError()

    def act_many(self, handles: List[int], observations: List[T_obs], **kwargs) -> Dict[int, T_act]:
        """
        Get action_dict for all agents. Default implementation calls `act()` for each handle in the list.

        Override if you need to initialize before / cleanup after calling `act()` for individual agents.

        Parameters
        ----------
        handles: List[int]
            the agents' handles
        observations: List[T_obs]
            the agents' observations, looked up by handle
            (`observations[handle]` must exist for every handle in `handles`)
        kwargs
            forward compatibility placeholder

        Returns
        -------
        Dict[int, T_act]
            the action dict, mapping each handle to the action returned by `act()`
        """
        # `kwargs` is accepted for forward compatibility but is not passed on to `act()`.
        return {handle: self.act(observations[handle]) for handle in handles}