import inspect
from flatland.env_generation.env_generator import env_generator
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.ml.observations.flatten_tree_observation_for_rail_env import FlattenedNormalizedTreeObsForRailEnv
from flatland.ml.pettingzoo.wrappers import PettingzooFlatland
from flatland.ml.pettingzoo.examples.flatland_pettingzoo_stable_baselines import train_flatland_pettingzoo_supersuit, eval_flatland_pettingzoo
PettingZoo#
PettingZoo (https://www.pettingzoo.ml/) is a collection of multi-agent environments for reinforcement learning. We provide a PettingZoo interface for Flatland.
Background#
PettingZoo is a popular multi-agent environment library (https://arxiv.org/abs/2009.14471) that aims to be the Gym standard for multi-agent reinforcement learning. The following advantages make it suitable for use with Flatland:
Works with both rllib (https://docs.ray.io/en/latest/rllib.html) and Stable-Baselines3 (https://stable-baselines3.readthedocs.io/) using wrappers from SuperSuit.
Clean API (https://www.pettingzoo.ml/api) with additional facilities for parallel execution, saving observations, recording via the gym monitor, and processing/normalising observations.
Scikit-learn-inspired API, e.g.
act = model.predict(obs, deterministic=True)[0]
Parallel learning with Stable-Baselines3 takes just two lines of code:
env = ss.pettingzoo_env_to_vec_env_v0(env)
env = ss.concat_vec_envs_v0(env, 8, num_cpus=4, base_class='stable_baselines3')
Tested with multi-agent environments whose agent counts are comparable to Flatland, e.g. MAgent (https://www.pettingzoo.ml/magent).
The clean interface makes it straightforward to plug in an experiment-tracking tool such as wandb, with full flexibility over which information to save; the interaction-loop sketch below illustrates the API.
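To make the Parallel API concrete, here is a minimal interaction-loop sketch that drives the Flatland PettingZoo wrapper with random actions. It reuses the imports from the first cell of this notebook; the loop follows the standard PettingZoo Parallel API and is an illustrative assumption, not a documented Flatland example.
# Minimal sketch: random actions through the Flatland PettingZoo parallel env.
raw_env, _, _ = env_generator(
    obs_builder_object=FlattenedNormalizedTreeObsForRailEnv(
        max_depth=3, predictor=ShortestPathPredictorForRailEnv(max_depth=50)
    )
)
env = PettingzooFlatland(raw_env).parallel_env()
obs, info = env.reset(seed=42)
while env.agents:
    # one action per live agent, sampled from that agent's action space
    actions = {a: env.action_space(a).sample() for a in env.agents}
    obs, rewards, terminations, truncations, infos = env.step(actions)
env.close()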
PettingZoo Demo#
This demo uses Stable-Baselines3 to train agents on the Flatland environment via SuperSuit vector envs.
For more information, see https://stable-baselines3.readthedocs.io/en/master/modules/ppo.html
Based on Farama-Foundation/PettingZoo
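The source listings below reference several names defined in their home module. A plausible set of imports, inferred from the identifiers used (an assumption, not a copy of the module header):
import glob
import os
import time

import supersuit as ss
from pettingzoo import ParallelEnv
from stable_baselines3 import PPO
from stable_baselines3.ppo import MlpPolicy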
Inspect Training Code#
%pycat inspect.getsource(train_flatland_pettingzoo_supersuit)
def train_flatland_pettingzoo_supersuit(
    env_fn, steps: int = 10_000, seed: int | None = 0, **env_kwargs
):
    # Train a single model to play as each agent in a cooperative Parallel environment
    env = env_fn.parallel_env(**env_kwargs)
    env.reset(seed=seed)

    print(f"Starting training on {str(env.metadata['name'])}.")

    # Each agent becomes one slot of the vectorized env; concatenating 8 copies
    # yields 8 * n_agents single-agent envs that share one PPO policy.
    env = ss.pettingzoo_env_to_vec_env_v1(env)
    env = ss.concat_vec_envs_v1(env, 8, num_cpus=2, base_class="stable_baselines3")

    model = PPO(
        MlpPolicy,
        env,
        verbose=3,
        learning_rate=1e-3,
        batch_size=256,
    )

    model.learn(total_timesteps=steps)

    model.save(f"{env.unwrapped.metadata.get('name')}_{time.strftime('%Y%m%d-%H%M%S')}")
    print("Model has been saved.")
    print(f"Finished training on {str(env.unwrapped.metadata['name'])}.")

    env.close()
Inspect Eval Code#
%pycat inspect.getsource(eval_flatland_pettingzoo)
def eval_flatland_pettingzoo(env_fn, num_games: int = 100, render_mode: str | None = None, **env_kwargs):
    # Evaluate the trained model, which controls every agent in the environment
    env: ParallelEnv = env_fn.parallel_env(render_mode=render_mode, **env_kwargs)

    print(
        f"\nStarting evaluation on {str(env.metadata['name'])} (num_games={num_games}, render_mode={render_mode})"
    )

    # Load the most recently saved checkpoint for this environment
    try:
        latest_policy = max(
            glob.glob(f"{env.metadata['name']}*.zip"), key=os.path.getctime
        )
    except ValueError:
        print("Policy not found.")
        exit(0)

    model = PPO.load(latest_policy)

    rewards = {agent: 0 for agent in env.possible_agents}

    for i in range(num_games):
        obs, _ = env.reset(seed=i)
        done = False
        while not done:
            # One deterministic action per live agent from the shared policy
            act = {a: int(model.predict(obs[a], deterministic=True)[0]) for a in env.agents}
            obs, rew, terminations, truncations, infos = env.step(act)
            for a in env.agents:
                rewards[a] += rew[a]
            # Stop once every agent is terminated or truncated
            done = all(terminations.values()) or all(truncations.values())
    env.close()

    avg_reward = sum(rewards.values()) / len(rewards.values())
    print("Rewards: ", rewards)
    print(f"Avg reward: {avg_reward}")
    return avg_reward
Train a model#
raw_env, _, _ = env_generator(obs_builder_object=FlattenedNormalizedTreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv(max_depth=50)))
env_fn = PettingzooFlatland(raw_env)
env_kwargs = {}
/Users/che/workspaces/flatland-rl-2/flatland/envs/rail_generators.py:321: UserWarning: Could not set all required cities! Created 1/2
warnings.warn(city_warning)
/Users/che/workspaces/flatland-rl-2/flatland/envs/rail_generators.py:217: UserWarning: [WARNING] Changing to Grid mode to place at least 2 cities.
warnings.warn("[WARNING] Changing to Grid mode to place at least 2 cities.")
train_flatland_pettingzoo_supersuit(env_fn, steps=196_608, seed=0, **env_kwargs)
Starting training on flatland_pettingzoo.
Using cpu device
/Users/che/workspaces/flatland-rl-2/flatland/envs/rail_generators.py:321: UserWarning: Could not set all required cities! Created 1/2
warnings.warn(city_warning)
/Users/che/workspaces/flatland-rl-2/flatland/envs/rail_generators.py:217: UserWarning: [WARNING] Changing to Grid mode to place at least 2 cities.
warnings.warn("[WARNING] Changing to Grid mode to place at least 2 cities.")
/Users/che/workspaces/flatland-rl-2/flatland/envs/rail_generators.py:321: UserWarning: Could not set all required cities! Created 1/2
warnings.warn(city_warning)
/Users/che/workspaces/flatland-rl-2/flatland/envs/rail_generators.py:217: UserWarning: [WARNING] Changing to Grid mode to place at least 2 cities.
warnings.warn("[WARNING] Changing to Grid mode to place at least 2 cities.")
-------------------------------
| time/ | |
| fps | 1467 |
| iterations | 1 |
| time_elapsed | 78 |
| total_timesteps | 114688 |
-------------------------------
-----------------------------------------
| time/ | |
| fps | 1223 |
| iterations | 2 |
| time_elapsed | 187 |
| total_timesteps | 229376 |
| train/ | |
| approx_kl | 0.008171785 |
| clip_fraction | 0.0342 |
| clip_range | 0.2 |
| entropy_loss | -1.6 |
| explained_variance | -0.00243 |
| learning_rate | 0.001 |
| loss | 447 |
| n_updates | 10 |
| policy_gradient_loss | -0.000753 |
| value_loss | 957 |
-----------------------------------------
Model has been saved.
Finished training on flatland_pettingzoo.
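A note on the logged numbers, assuming Stable-Baselines3 PPO's default rollout length of n_steps = 2048 per vectorized env: the 8 concatenated copies of this 7-agent environment give 8 × 7 = 56 single-agent envs, so each iteration collects 2048 × 56 = 114,688 timesteps. Training therefore stops after the second iteration (229,376 ≥ the requested 196,608 steps).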
Evaluate 10 games (Flatland rewards are negative penalties, so less negative is better; results can vary significantly)#
eval_flatland_pettingzoo(env_fn, num_games=10, render_mode=None, **env_kwargs)
Starting evaluation on flatland_pettingzoo (num_games=10, render_mode=None)
/Users/che/workspaces/flatland-rl-2/flatland/envs/rail_generators.py:321: UserWarning: Could not set all required cities! Created 1/2
warnings.warn(city_warning)
/Users/che/workspaces/flatland-rl-2/flatland/envs/rail_generators.py:217: UserWarning: [WARNING] Changing to Grid mode to place at least 2 cities.
warnings.warn("[WARNING] Changing to Grid mode to place at least 2 cities.")
Rewards: {0: -1306, 1: -683, 2: -1485, 3: -955, 4: -1498, 5: -782, 6: -1194}
Avg reward: -1129.0
-1129.0
Watch 2 games#
eval_flatland_pettingzoo(env_fn, num_games=2, render_mode="human", **env_kwargs)
Starting evaluation on flatland_pettingzoo (num_games=2, render_mode=human)
Rewards: {0: -213, 1: -76, 2: -517, 3: -143, 4: -584, 5: -139, 6: -94}
Avg reward: -252.28571428571428
-252.28571428571428