Source code for flatland.envs.observations_perturbed

from typing import Optional, List, Dict

import numpy as np
from numpy.random import RandomState

from flatland.core.env import Environment
from flatland.core.env_observation_builder import ObservationBuilder, AgentHandle
from flatland.envs.malfunction_generators import ParamMalfunctionGen, MalfunctionParameters
from flatland.envs.observations import TreeObsForRailEnv, Node
from flatland.envs.step_utils.malfunction_handler import MalfunctionHandler


def perturbation_tree_observation_builder_wrapper(
    builder: TreeObsForRailEnv,
    np_random: RandomState,
    perturbation_rate: Optional[float] = None,
    min_duration: Optional[float] = None,
    max_duration: Optional[float] = None,
    blank: float = -np.inf,
) -> ObservationBuilder[Node]:
    """
    Make some trains blind for some time according to Poisson process.

    The returned builder delegates to ``builder`` and replaces the observation
    of every requested agent that is currently perturbed by a ``Node`` whose
    fields are all set to ``blank`` and which has no children.

    Parameters
    ----------
    builder : TreeObsForRailEnv
        the wrapped observation builder
    np_random : RandomState
        random state passed to the malfunction handlers when drawing perturbations
    perturbation_rate : float, optional
        Poisson process with given rate. ``None`` is treated as 0.
    min_duration : float, optional
        If perturbed, duration uniformly in [min_duration,max_duration]. ``None`` is treated as 1.
    max_duration : float, optional
        If perturbed, duration uniformly in [min_duration,max_duration]. ``None`` is treated as 1.
    blank : float
        value to insert for perturbed trains.

    Returns
    -------
    ObservationBuilder[Node]
        Observations with some trains not seeing anything.
    """

    class _PerturbedTreeObsForRailEnv(ObservationBuilder[Node]):
        """Observation builder that blanks the tree observation of perturbed agents."""

        def __init__(self, builder: TreeObsForRailEnv):
            super().__init__()
            # Fall back to "never perturbed" / "one step long" when the caller passes None.
            self._malfunction_rate = perturbation_rate if perturbation_rate is not None else 0
            self._min_duration = min_duration if min_duration is not None else 1
            self._max_duration = max_duration if max_duration is not None else 1
            self._builder = builder
            self._np_random = np_random
            # Filled in reset(), one handler per agent handle of the env.
            self._malfunction_handlers: Dict[AgentHandle, MalfunctionHandler] = {}
            self._malfunction_generator = ParamMalfunctionGen(
                MalfunctionParameters(
                    malfunction_rate=self._malfunction_rate,
                    min_duration=self._min_duration,
                    max_duration=self._max_duration,
                )
            )
            self._blank = blank

        def set_env(self, env: Environment):
            """Attach ``env`` to this builder and to the wrapped builder."""
            super().set_env(env)
            self._builder.set_env(env)

        def reset(self):
            """Reset the wrapped builder and start with a fresh handler per agent."""
            self._builder.reset()
            self._malfunction_handlers = {
                handle: MalfunctionHandler() for handle in self.env.get_agent_handles()
            }

        def _blank_node(self) -> Node:
            """Return a childless ``Node`` with every field set to the blank value."""
            fill = self._blank
            return Node(
                dist_own_target_encountered=fill,
                dist_other_target_encountered=fill,
                dist_other_agent_encountered=fill,
                dist_potential_conflict=fill,
                dist_unusable_switch=fill,
                dist_to_next_branch=fill,
                dist_min_to_target=fill,
                num_agents_same_direction=fill,
                num_agents_opposite_direction=fill,
                num_agents_malfunctioning=fill,
                speed_min_fractional=fill,
                num_agents_ready_to_depart=fill,
                childs={},
            )

        def get_many(self, handles: Optional[List[AgentHandle]] = None) -> Dict[AgentHandle, Node]:
            """
            Return the wrapped builder's observations with perturbed agents blanked.

            A perturbation is drawn for every agent of the env on each call,
            whether or not it was requested, but only handles already present
            in the wrapped builder's result are overwritten. The returned dict
            therefore never gains keys the caller did not ask for.
            """
            obs = self._builder.get_many(handles)
            for handle in self.env.get_agent_handles():
                handler = self._malfunction_handlers[handle]
                # NOTE(review): only generate_malfunction() is called here; nothing in this
                # wrapper explicitly counts an active perturbation down - confirm that
                # MalfunctionHandler ends a perturbation without further calls.
                handler.generate_malfunction(self._malfunction_generator, self._np_random)
                if handle in obs and handler.in_malfunction:
                    # Perturbed agent is blind: its own observation is replaced by blanks.
                    obs[handle] = self._blank_node()
            return obs

    return _PerturbedTreeObsForRailEnv(builder)