Source code for flatland.envs.step_utils.speed_counter
from functools import lru_cache
import numpy as np
from flatland.envs.step_utils.states import TrainState
@lru_cache()
def _calc_max_count(speed):
    """Return the largest counter value for a given ``speed``.

    The value is ``int(1.0 / speed) - 1``: the reciprocal of the speed is
    truncated towards zero and reduced by one. Results are memoised with
    ``lru_cache`` so repeated look-ups for the same speed are cheap.

    :param speed: speed value; must be non-zero, otherwise the division
        raises ``ZeroDivisionError``.
    :return: the truncated reciprocal of ``speed`` minus one, as an ``int``.
    """
    truncated_reciprocal = int(1.0 / speed)
    return truncated_reciprocal - 1
[docs]
class SpeedCounter:
    """Cyclic step counter driven by a speed value.

    ``counter`` starts at 0, is incremented by :meth:`update_counter` and
    wraps back to 0 once it has passed ``max_count`` (which is derived from
    the speed via ``_calc_max_count``).
    """

    def __init__(self, speed):
        """Store ``speed`` and start with the counter reset to 0.

        :param speed: speed value used to derive ``max_count``.
        """
        self._speed = speed
        self.counter = None
        self.reset_counter()

    def update_counter(self, state, old_position):
        """Advance the counter by one step, wrapping after ``max_count``.

        The counter only advances when ``state`` equals
        ``TrainState.MOVING`` and ``old_position`` is not ``None``.

        :param state: current train state, compared to ``TrainState.MOVING``.
        :param old_position: previous position, or ``None``.
        """
        # Can't start counting when adding train to the map
        if state == TrainState.MOVING and old_position is not None:
            self.counter = (self.counter + 1) % (self.max_count + 1)

    def __repr__(self):
        # Adjacent literals (rather than backslash continuations inside one
        # literal) keep source indentation out of the resulting string.
        return (
            f"speed: {self.speed} "
            f"max_count: {self.max_count} "
            f"is_cell_entry: {self.is_cell_entry} "
            f"is_cell_exit: {self.is_cell_exit} "
            f"counter: {self.counter}"
        )

    def reset_counter(self):
        """Set the counter back to 0."""
        self.counter = 0

    @property
    def is_cell_entry(self):
        """``True`` while the counter is 0."""
        return self.counter == 0

    @property
    def is_cell_exit(self):
        """``True`` while the counter equals ``max_count``."""
        return self.counter == self.max_count

    @property
    def speed(self):
        """The speed value this counter was configured with (read-only)."""
        return self._speed

    @property
    def max_count(self):
        """Largest value the counter takes before wrapping to 0."""
        return _calc_max_count(self._speed)

    def to_dict(self):
        """Return the state as a dict with keys ``"speed"`` and ``"counter"``."""
        return {"speed": self._speed,
                "counter": self.counter}

    def from_dict(self, load_dict):
        """Overwrite speed and counter in place from ``load_dict``.

        :param load_dict: mapping with the keys ``'speed'`` and ``'counter'``
            (the format produced by :meth:`to_dict`).
        :raises KeyError: if either key is missing.
        """
        self._speed = load_dict['speed']
        self.counter = load_dict['counter']

    def __eq__(self, other):
        """Two counters are equal when both speed and counter match.

        Comparing against anything that is not a ``SpeedCounter`` returns
        ``NotImplemented`` so that Python falls back to its default handling
        (``==`` yields ``False``) instead of raising ``AttributeError``.
        """
        if not isinstance(other, SpeedCounter):
            return NotImplemented
        return self._speed == other._speed and self.counter == other.counter