# Source code for flatland.envs.step_utils.speed_counter
import numbers
from decimal import Decimal
from fractions import Fraction
from typing import Optional

import numpy as np
# Length of one cell, in the same units as speed and distance: the speed
# counter wraps its in-cell distance once it reaches this value.
SEGMENT_LENGTH: Fraction = Fraction(1, 1)
def _pseudo_fractional(v: Optional[float], atol=1.e-2) -> Optional[Fraction]:
"""
Convert float to fractional with special consideration of inverses of integers.
E.g. with tolerance `atol=1.e-2`, `float(0.33)` is converted to `Fraction(1,3)`.
Parameters
----------
v : Optional[float]
d the float to be converted to fractional; if the float is the inverse of an integer by tolerance, then the corresponding fraction is returned
atol : float
the tolerance to determine inverse of integers
Returns
-------
Fraction
"""
if v is None:
return None
elif isinstance(v, Fraction):
return v
elif isinstance(v, Decimal):
return Fraction.from_decimal(v)
elif isinstance(v, int):
return Fraction(v)
elif isinstance(v, float):
if np.isclose(v % 1, 0.0):
return Fraction(0, 1) + int(v // 1)
elif np.isclose(1 / round(1 / (v % 1)), v % 1, atol=atol):
return Fraction(1, round(1 / (v % 1))) + int(v // 1)
elif v < 0 and np.isclose(1 / round(1 / ((-v) % 1)), (-v) % 1, atol=atol):
return - Fraction(1, round(1 / ((-v) % 1))) + int((-v) // 1)
elif np.isclose(float(Decimal(str(v))), v):
return Fraction.from_decimal(Decimal(str(v)))
else:
return Fraction.from_float(v)
raise ValueError(f"Cannot convert {v} to Fraction.")
class SpeedCounter:
    """
    Track speed and the distance travelled in the current cell, using exact fractions.

    Speed, maximum speed and in-cell distance are stored as ``Fraction``; the
    in-cell distance wraps around once it reaches ``SEGMENT_LENGTH``.
    """

    def __init__(self, speed: float, max_speed: Optional[float] = None):
        """
        Parameters
        ----------
        speed : float
            Initial speed, converted with ``_pseudo_fractional``.
        max_speed : Optional[float]
            Upper bound for the speed; if ``None`` the initial speed is used
            (old constant speed behaviour).
        """
        self._speed: Fraction = _pseudo_fractional(speed)
        self._distance: Fraction = Fraction(0)
        self._is_cell_entry: bool = True
        self._max_speed: Fraction
        if max_speed is not None:
            self._max_speed = _pseudo_fractional(max_speed)
        else:
            # old constant speed behaviour
            self._max_speed = self._speed
        assert self._max_speed <= 1.0
        assert self._speed <= self._max_speed
        assert self._speed >= 0.0
        self.reset()

    def _clamp_speed(self, speed) -> Fraction:
        """Convert ``speed`` to a Fraction and clamp it to ``[0, max_speed]``."""
        return max(min(_pseudo_fractional(speed), self._max_speed), Fraction(0))

    def step(self, speed: Optional[Fraction] = None):
        """
        Step the speed counter.

        Parameters
        ----------
        speed : Optional[Fraction]
            Set new speed effective immediately. It is converted with
            ``_pseudo_fractional`` and clamped to ``[0, max_speed]``.
            If ``None``, the current speed is kept.
        """
        if speed is not None:
            self._speed = self._clamp_speed(speed)
        assert isinstance(self._speed, Fraction)
        assert self._speed >= 0.0
        assert self.speed <= 1.0
        self._distance += self._speed
        # If trains cannot move to the next cell, they are in state stopped, so it's safe to apply modulo to reflect the distance travelled in the new cell!
        while self.distance >= SEGMENT_LENGTH:
            self._distance = self._distance - SEGMENT_LENGTH
        # Given 0 <= distance < SEGMENT_LENGTH before this step, the new distance
        # is below the speed exactly when a cell boundary was crossed in this step.
        self._is_cell_entry = self._distance < self._speed

    def __repr__(self):
        return (
            f"speed: {self.speed} "
            f"max_speed: {self.max_speed} "
            f"distance: {self.distance} "
            f"is_cell_entry: {self.is_cell_entry}"
        )

    def reset(self):
        """Set the in-cell distance back to zero and flag the counter as at cell entry."""
        # Keep the documented Fraction type of the distance.
        self._distance = Fraction(0)
        self._is_cell_entry = True

    @property
    def is_cell_entry(self):
        """
        Have just entered the cell in the previous step?
        """
        return self._is_cell_entry

    def is_cell_exit(self, speed: Fraction) -> bool:
        """
        With the given speed, do we exit cell at next time step?

        ``speed`` is converted and clamped exactly as in ``step``, so the
        prediction agrees with what ``step(speed)`` does.
        """
        return self._distance + self._clamp_speed(speed) >= SEGMENT_LENGTH

    @property
    def speed(self) -> Fraction:
        """Current speed."""
        return self._speed

    @property
    def max_speed(self) -> Fraction:
        """Upper bound for the speed."""
        return self._max_speed

    @property
    def distance(self) -> Fraction:
        """
        Distance travelled in current cell.
        """
        return self._distance

    def __getstate__(self):
        return {
            "speed": self._speed,
            "max_speed": self._max_speed,
            "distance": self._distance,
            "is_cell_entry": self._is_cell_entry,
        }

    def __setstate__(self, load_dict):
        """
        Restore from a state dict, also accepting older formats.

        Older formats may use ``_speed`` instead of ``speed``, a step
        ``counter`` instead of ``distance``, and may lack ``max_speed`` or
        ``is_cell_entry``.
        """
        if "_speed" in load_dict:
            # backwards compatibility
            self._speed = _pseudo_fractional(load_dict['_speed'])
        else:
            self._speed = _pseudo_fractional(load_dict["speed"])
        if "counter" in load_dict:
            # old pickles have constant speed
            self._distance = _pseudo_fractional(load_dict['counter'] * self._speed)
            derived_cell_entry = load_dict['counter'] == 0
        else:
            self._distance = _pseudo_fractional(load_dict['distance'])
            # Fallback so the attribute is always set; mirrors the counter == 0 rule above.
            derived_cell_entry = self._distance == 0
        # An explicitly stored flag takes precedence over the derived value.
        self._is_cell_entry = load_dict.get("is_cell_entry", derived_cell_entry)
        if "max_speed" in load_dict:
            self._max_speed = _pseudo_fractional(load_dict["max_speed"])
        else:
            # old pickles have constant speed
            self._max_speed = _pseudo_fractional(self._speed)

    def __eq__(self, other):
        # Compares speed, distance and max_speed; is_cell_entry is not compared.
        if not isinstance(other, SpeedCounter):
            return False
        return self._speed == other._speed and self._distance == other._distance and self._max_speed == other._max_speed

    # Mutable and defines __eq__: explicitly unhashable (same as Python's implicit behaviour).
    __hash__ = None