import pickle
from pathlib import Path
from typing import Tuple, Dict, Optional, Union
import msgpack
import msgpack_numpy
import numpy as np
from numpy.random import RandomState
from flatland.envs.malfunction_effects_generators import MalfunctionEffectsGenerator
msgpack_numpy.patch()
from flatland.envs.step_utils.states import StateTransitionSignals
from flatland.envs.rail_grid_transition_map import RailGridTransitionMap
from flatland.envs import rail_env
from flatland.envs.step_utils import env_utils
from flatland.utils.seeding import random_state_to_hashablestate
from flatland.core.env_observation_builder import DummyObservationBuilder, ObservationBuilder
from flatland.envs.agent_utils import EnvAgent, load_env_agent
# cannot import objects / classes directly because of circular import
from flatland.envs import malfunction_generators as mal_gen
from flatland.envs import rail_generators as rail_gen
from flatland.envs import line_generators as line_gen
from flatland.envs import timetable_generators as tt_gen
class RailEnvPersister(object):
    """
    Save and load the full state of a ``RailEnv``.

    The on-disk format is selected by the filename suffix: ``pkl`` uses
    :mod:`pickle`, ``mpk`` uses msgpack.
    """

    @classmethod
    def save(cls, env, filename: Union[str, Path], save_distance_maps=False):
        """
        Saves environment and distance map information in a file.

        Parameters
        ----------
        env : RailEnv
            Environment to serialize (via :meth:`get_full_state`).
        filename : Union[str, Path]
            Target path. A suffix of ``mpk`` selects msgpack, ``pkl`` selects pickle.
        save_distance_maps : bool
            When exactly ``True``, also store ``env.distance_map.get()`` under the
            ``"distance_map"`` key. A warning is printed if that map is ``None`` or empty.

        Raises
        ------
        ValueError
            If ``filename`` ends with neither ``mpk`` nor ``pkl``. This is checked
            before anything is opened, so no empty/truncated file is left behind.
        """
        if isinstance(filename, Path):
            filename = str(filename)

        # Choose the serializer before touching the file system. Previously an
        # unknown suffix truncated the target and then crashed with
        # UnboundLocalError on the never-assigned `data`.
        if filename.endswith("mpk"):
            serialize = msgpack.packb
        elif filename.endswith("pkl"):
            serialize = pickle.dumps
        else:
            raise ValueError(f"filename {filename} must end with either pkl or mpk")

        env_dict = cls.get_full_state(env)
        if save_distance_maps is True:
            distance_map = env.distance_map.get()
            # len() instead of truthiness: the map may be an array, whose truth value is ambiguous.
            if distance_map is not None and len(distance_map) > 0:
                env_dict["distance_map"] = distance_map
            else:
                print("[WARNING] Unable to save the distance map for this environment, as none was found !")

        # Serialize first so a serialization error cannot clobber an existing file.
        data = serialize(env_dict)
        with open(filename, "wb") as file_out:
            file_out.write(data)

    @classmethod
    def save_episode(cls, env, filename: Union[str, Path]):
        """
        Saves the full environment state together with the recorded episode.

        On top of :meth:`get_full_state`, this stores ``env.cur_episode``
        (``"episode"``), ``env.list_actions`` (``"actions"``),
        ``(env.width, env.height)`` (``"shape"``) and ``env._max_episode_steps``
        (``"max_episode_steps"``).

        Parameters
        ----------
        env : RailEnv
            Environment whose state and episode are written.
        filename : Union[str, Path]
            Target path; must end with ``.mpk`` (msgpack) or ``.pkl`` (pickle).

        Raises
        ------
        ValueError
            For any other suffix. Previously an empty file was written silently.
        """
        if isinstance(filename, Path):
            filename = str(filename)
        if not filename.endswith((".mpk", ".pkl")):
            raise ValueError(f"filename {filename} must end with either .pkl or .mpk")

        dict_env = cls.get_full_state(env)
        # Add additional info to dict_env before saving
        dict_env["episode"] = env.cur_episode
        dict_env["actions"] = env.list_actions
        # NB: stored as (width, height), whereas the grid shape is (height, width).
        dict_env["shape"] = (env.width, env.height)
        dict_env["max_episode_steps"] = env._max_episode_steps

        with open(filename, "wb") as file_out:
            if filename.endswith(".mpk"):
                file_out.write(msgpack.packb(dict_env))
            else:
                pickle.dump(dict_env, file_out)

    @classmethod
    def load(cls, env: "RailEnv",
             filename: Optional[Union[str, Path]] = None,
             env_dict=None,
             load_from_package: Optional[str] = None,
             obs_builder: Optional["ObservationBuilder[RailEnv]"] = None):
        """
        Load environment with distance map from a file into existing env.

        Parameters
        ----------
        env : RailEnv
            Environment instance to overwrite.
        filename : Union[str, Path], optional
            File to read. Ignored when ``env_dict`` is given.
        env_dict : dict, optional
            Already-deserialized state (as from :meth:`load_env_dict`). If given,
            no file is read. Note that :meth:`set_full_state` modifies it in place.
        load_from_package : Optional[str]
            If set, ``filename`` is read as a resource of this package.
            Defaults to `None`.
        obs_builder : ObservationBuilder[RailEnv]
            Defaults to `None`. If `None`, a `DummyObservationBuilder` is installed.
        """
        if env_dict is None:
            env_dict = cls.load_env_dict(filename, load_from_package=load_from_package)
        cls.set_full_state(env, env_dict)

        if obs_builder is None:
            obs_builder = DummyObservationBuilder()
        env.obs_builder = obs_builder
        env.obs_builder.set_env(env)
        env.obs_builder.reset()

        # Install "from file" generators built from the same env_dict (presumably
        # so that a later reset() reproduces the loaded scenario -- confirm).
        env.rail_generator = rail_gen.rail_from_file(env_dict=env_dict)
        env.line_generator = line_gen.line_from_file(env_dict=env_dict)
        env.timetable_generator = tt_gen.timetable_from_file(env_dict=env_dict)
        env.malfunction_generator = mal_gen.FileMalfunctionGen(env_dict=env_dict)
        # TODO generic effects generator serialization
        env.effects_generator = MalfunctionEffectsGenerator(env.malfunction_generator)

    @classmethod
    def load_new(cls,
                 filename: Union[str, Path],
                 load_from_package=None,
                 obs_builder: Optional["ObservationBuilder[RailEnv]"] = None
                 ) -> Tuple["RailEnv", Dict]:
        """
        Load environment with distance map from a file into new env.

        Parameters
        ----------
        filename : Union[str, Path]
            File (or package resource name) to read.
        load_from_package : Optional[str]
            If set, ``filename`` is read as a resource of this package.
            Defaults to `None`.
        obs_builder : ObservationBuilder[RailEnv]
            Defaults to `None`, in which case a `DummyObservationBuilder` is used.

        Returns
        -------
        Tuple[RailEnv, Dict]
            The newly created environment and the deserialized env dict.
        """
        env_dict = cls.load_env_dict(filename, load_from_package=load_from_package)
        grid_rows = env_dict["grid"]
        height = len(grid_rows)
        width = len(grid_rows[0])

        if obs_builder is None:
            obs_builder = DummyObservationBuilder()

        env = rail_env.RailEnv(
            width=width, height=height,
            rail_generator=rail_gen.rail_from_file(env_dict=env_dict),
            line_generator=line_gen.line_from_file(env_dict=env_dict),
            timetable_generator=tt_gen.timetable_from_file(env_dict=env_dict),
            malfunction_generator=mal_gen.FileMalfunctionGen(env_dict=env_dict),
            obs_builder_object=obs_builder,
            record_steps=True)

        cls.set_full_state(env, env_dict)
        return env, env_dict

    @classmethod
    def load_env_dict(cls, filename: Union[str, Path], load_from_package=None) -> Dict:
        """
        Read a serialized env dict from a file or from a package resource.

        Parameters
        ----------
        filename : Union[str, Path]
            Path (or resource name). A suffix of ``mpk`` selects msgpack,
            ``pkl`` selects pickle.
        load_from_package : Optional[str]
            If set, read ``filename`` as a binary resource of this package.

        Returns
        -------
        Dict
            The deserialized dict. If the suffix is neither ``mpk`` nor ``pkl``,
            a message is printed and ``{}`` is returned.
        """
        if isinstance(filename, Path):
            filename = str(filename)

        if load_from_package is not None:
            from importlib_resources import read_binary
            load_data = read_binary(load_from_package, filename)
        else:
            with open(filename, "rb") as file_in:
                load_data = file_in.read()

        if filename.endswith("mpk"):
            return msgpack.unpackb(load_data, use_list=False, raw=False)

        if filename.endswith("pkl"):
            # SECURITY: pickle.loads can execute arbitrary code embedded in the
            # payload -- only load .pkl files from trusted sources.
            try:
                return pickle.loads(load_data)
            except (ValueError, pickle.UnpicklingError):
                # Legacy ".pkl" files may contain msgpack data; pickle rejects
                # those with either of these exceptions.
                print("pickle failed to load file:", filename, " trying msgpack (deprecated)...")
                return msgpack.unpackb(load_data, use_list=False, raw=False)

        print(f"filename {filename} must end with either pkl or mpk")
        return {}

    @classmethod
    def load_resource(cls, package, resource):
        """
        Load a new environment from a binary resource shipped inside ``package``.

        Thin wrapper around :meth:`load_new` with ``load_from_package=package``;
        returns its ``(env, env_dict)`` tuple.
        """
        return cls.load_new(resource, load_from_package=package)

    @classmethod
    def set_full_state(cls, env, env_dict):
        """
        Sets environment state from env_dict.

        Restores the rail grid, agents, step counters, distance map, RNG
        state, malfunction-generator cache, done flags and per-agent
        transition buffers on ``env``.

        Note: ``env_dict`` is modified in place. ``"agents"`` is replaced by
        ``EnvAgent`` objects, and a legacy ``"agents_static"`` key is removed.

        Parameters
        ----------
        env : RailEnv
            Environment to populate.
        env_dict : dict
            Serialized state as produced by :meth:`get_full_state`. Only
            ``"grid"`` is required; every other key is optional.
        """
        env.rail = RailGridTransitionMap(1, 1)  # dummy placeholder, replaced below
        grid = np.array(env_dict["grid"])

        # Replace the serialized agent records with EnvAgent objects.
        if "agents_static" in env_dict:
            env_dict["agents"] = EnvAgent.load_legacy_static_agent(env_dict["agents_static"])
            # remove the legacy key
            del env_dict["agents_static"]
        elif "agents" in env_dict:
            env_dict["agents"] = [load_env_agent(d) for d in env_dict["agents"]]

        # Initialise the env with the frozen agents in the file
        env.agents = env_dict.get("agents", [])
        # For consistency, set number_of_agents, which is the number which will be generated on reset
        env.number_of_agents = env.get_num_agents()

        env.height, env.width = grid.shape
        # use new rail object instance for lru cache scoping and garbage collection to work properly
        env.rail = RailGridTransitionMap(height=env.height, width=env.width)
        env.rail.grid = grid
        env.dones = dict.fromkeys(list(range(env.get_num_agents())) + ["__all__"], False)

        max_episode_steps = env_dict.get("max_episode_steps", None)
        if max_episode_steps is not None:
            env._max_episode_steps = max_episode_steps
        elapsed_steps = env_dict.get("elapsed_steps", None)
        if elapsed_steps is not None:
            env._elapsed_steps = elapsed_steps

        env.distance_map.distance_map = env_dict.get("distance_map", None)
        env.distance_map.reset(env.agents, env.rail)
        env.distance_map._compute(env.agents, env.rail)

        # Overwrite the seed only when the file provides one. Previously it was
        # always overwritten (possibly with None), which made this guard dead code.
        random_seed = env_dict.get("random_seed", None)
        if random_seed is not None:
            env.random_seed = random_seed
        seed_history = env_dict.get("seed_history", None)
        if seed_history is not None:
            env.seed_history = seed_history

        # it's not sufficient to store random_seed, as seeding from random_seed is done
        # at start of reset (before rail/line/timetable (re-)generation,
        # hence np_random depends on rail/line/timetable generation
        np_random_state = env_dict.get("np_random_state", None)
        if np_random_state is not None:
            env.np_random.set_state(np_random_state)

        dev_pred_dict = env_dict.get("dev_pred_dict", None)
        if dev_pred_dict is not None:
            env.dev_pred_dict = dev_pred_dict
        dev_obs_dict = env_dict.get("dev_obs_dict", None)
        # Fixed: this used to test dev_pred_dict, so dev_obs_dict could be
        # set to None or be skipped depending on an unrelated key.
        if dev_obs_dict is not None:
            env.dev_obs_dict = dev_obs_dict

        # backwards compatibility
        malfunction_cached_rand = env_dict.get("malfunction_cached_rand", None)
        if malfunction_cached_rand is not None:
            env.malfunction_generator._cached_rand = malfunction_cached_rand
        malfunction_rand_idx = env_dict.get("malfunction_rand_idx", None)
        if malfunction_rand_idx is not None:
            env.malfunction_generator._rand_idx = malfunction_rand_idx

        malfunction_cached_random_state = env_dict.get("malfunction_cached_random_state", None)
        if malfunction_cached_random_state is not None:
            env.malfunction_generator._cached_random_state = malfunction_cached_random_state
            # Rebuild a RandomState from the stored state and hand it to the generator.
            np_random = RandomState()
            np_random.set_state(malfunction_cached_random_state)
            env.malfunction_generator.generate_rand_numbers(np_random)

        # Fresh, empty transition buffers for every agent.
        env.temp_transition_data = {
            i: env_utils.AgentTransitionData(None, None, None, None, None, None, None, None)
            for i in range(env.get_num_agents())
        }
        for transition_data in env.temp_transition_data.values():
            transition_data.state_transition_signal = StateTransitionSignals()

        dones = env_dict.get("dones", None)
        if dones is not None:
            env.dones = dones

        # TODO bad code smell - agent_position initialized in reset() only.
        env.agent_positions = np.zeros((env.height, env.width), dtype=int) - 1

    @classmethod
    def get_full_state(cls, env):
        """
        Returns state of environment in dict object, ready for serialization.

        The grid is converted to nested lists and the agents to their
        ``to_agent()`` records. Malfunction-generator attributes that the
        generator does not have are stored as ``None``.
        """
        grid_data = env.rail.grid.tolist()
        # msgpack cannot persist EnvAgent so use the Agent namedtuple.
        agent_data = [agent.to_agent() for agent in env.agents]
        malfunction_data: mal_gen.MalfunctionProcessData = env.malfunction_process_data
        return {
            "grid": grid_data,
            "agents": agent_data,
            "malfunction": malfunction_data,
            # Not every malfunction generator has these attributes.
            "malfunction_cached_random_state": getattr(env.malfunction_generator, "_cached_random_state", None),
            "malfunction_rand_idx": getattr(env.malfunction_generator, "_rand_idx", None),
            "max_episode_steps": env._max_episode_steps,
            "elapsed_steps": env._elapsed_steps,
            "random_seed": env.random_seed,
            "seed_history": env.seed_history,
            "np_random_state": random_state_to_hashablestate(env.np_random),
            "dev_pred_dict": env.dev_pred_dict,
            "dev_obs_dict": env.dev_obs_dict,
            "dones": env.dones,
        }

    ################################################################################################
    # deprecated methods moved from RailEnv. Most likely broken.
    # They expect a RailEnv-like `self` (they read self.rail, self.agents, ...).

    def deprecated_get_full_state_msg(self) -> bytes:
        """
        Returns state of environment as msgpack-encoded bytes.
        """
        # NOTE(review): `get_full_state_dict` is not defined on this class; the
        # object passed as `self` must provide it -- confirm, or port to get_full_state.
        msg_data_dict = self.get_full_state_dict()
        return msgpack.packb(msg_data_dict, use_bin_type=True)

    def deprecated_get_agent_state_msg(self) -> bytes:
        """
        Returns agents information as msgpack-encoded bytes.
        """
        agent_data = [agent.to_agent() for agent in self.agents]
        msg_data = {
            "agents": agent_data}
        return msgpack.packb(msg_data, use_bin_type=True)

    def deprecated_get_full_state_dist_msg(self) -> bytes:
        """
        Returns environment information with distance map information as msgpack-encoded bytes.
        """
        grid_data = self.rail.grid.tolist()
        agent_data = [agent.to_agent() for agent in self.agents]
        distance_map_data = self.distance_map.get()
        malfunction_data: mal_gen.MalfunctionProcessData = self.malfunction_process_data
        msg_data = {
            "grid": grid_data,
            "agents": agent_data,
            "distance_map": distance_map_data,
            "malfunction": malfunction_data}
        return msgpack.packb(msg_data, use_bin_type=True)

    def deprecated_set_full_state_msg(self, msg_data):
        """
        Sets environment state with msgdata object passed as argument.

        Parameters
        ----------
        msg_data : bytes
            msgpack-encoded dict with ``"grid"`` and ``"agents"`` (or legacy ``"agents_static"``).
        """
        # raw=False decodes strings as UTF-8. It replaces `encoding='utf-8'`, which
        # msgpack >= 1.0 no longer accepts, and matches load_env_dict's usage.
        data = msgpack.unpackb(msg_data, use_list=False, raw=False)
        self.rail.grid = np.array(data["grid"])
        # agents are always reset as not moving
        if "agents_static" in data:
            self.agents = EnvAgent.load_legacy_static_agent(data["agents_static"])
        else:
            self.agents = [EnvAgent(*d[0:12]) for d in data["agents"]]
        # setup with loaded data
        self.height, self.width = self.rail.grid.shape
        self.rail.height = self.height
        self.rail.width = self.width
        self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False)

    def deprecated_set_full_state_dist_msg(self, msg_data):
        """
        Sets environment grid state and distance map with msgdata object passed as argument.

        Parameters
        ----------
        msg_data : bytes
            msgpack-encoded dict with ``"grid"``, ``"agents"`` (or legacy
            ``"agents_static"``) and optionally ``"distance_map"``.
        """
        # raw=False replaces the removed `encoding='utf-8'` argument (msgpack >= 1.0).
        data = msgpack.unpackb(msg_data, use_list=False, raw=False)
        self.rail.grid = np.array(data["grid"])
        # agents are always reset as not moving
        if "agents_static" in data:
            self.agents = EnvAgent.load_legacy_static_agent(data["agents_static"])
        else:
            self.agents = [EnvAgent(*d[0:12]) for d in data["agents"]]
        if "distance_map" in data:
            self.distance_map.set(data["distance_map"])
        # setup with loaded data
        self.height, self.width = self.rail.grid.shape
        self.rail.height = self.height
        self.rail.width = self.width
        self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False)