import inspect
from flatland.env_generation.env_generator import env_generator
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.ml.observations.flatten_tree_observation_for_rail_env import FlattenedNormalizedTreeObsForRailEnv
from flatland.ml.pettingzoo.wrappers import PettingzooFlatland
from flatland.ml.pettingzoo.examples.flatland_pettingzoo_stable_baselines import train_flatland_pettingzoo_supersuit, eval_flatland_pettingzoo

PettingZoo#

PettingZoo (https://www.pettingzoo.ml/) is a collection of multi-agent environments for reinforcement learning. We provide a PettingZoo interface for Flatland.

Background#

PettingZoo is a popular multi-agent environment library (https://arxiv.org/abs/2009.14471) that aims to be the Gym-style standard for multi-agent reinforcement learning. The advantages listed below make it suitable for use with Flatland:

act = model.predict(obs, deterministic=True)[0] 
  • Parallel learning with Stable-Baselines3 takes just two lines of code:

env = ss.pettingzoo_env_to_vec_env_v1(env)
env = ss.concat_vec_envs_v1(env, 8, num_cpus=4, base_class="stable_baselines3")
  • Tested with, and supports, various multi-agent environments whose agent counts are comparable to Flatland's, e.g. https://www.pettingzoo.ml/magent

  • A clean interface means we can plug in a custom experiment-tracking tool such as wandb and keep full flexibility over which information we save

PettingZoo Demo#

Uses Stable-Baselines3 to train agents to play the Flatland environment using SuperSuit vector envs.

For more information, see https://stable-baselines3.readthedocs.io/en/master/modules/ppo.html

Based on Farama-Foundation/PettingZoo

Inspect Training Code#

%pycat inspect.getsource(train_flatland_pettingzoo_supersuit)
def train_flatland_pettingzoo_supersuit(
    env_fn, steps: int = 10_000, seed: int | None = 0, **env_kwargs
):
    """Train one shared PPO model that acts for every agent of a parallel env.

    Args:
        env_fn: Object exposing ``parallel_env(**env_kwargs)``.
        steps: Total number of timesteps passed to ``model.learn``.
        seed: Seed passed to the initial ``reset`` of the parallel environment.
        **env_kwargs: Forwarded unchanged to ``env_fn.parallel_env``.

    The trained model is saved under ``<env name>_<YYYYmmdd-HHMMSS>`` and the
    vectorised environment is closed afterwards. Nothing is returned.
    """
    parallel_env = env_fn.parallel_env(**env_kwargs)
    parallel_env.reset(seed=seed)

    print(f"Starting training on {str(parallel_env.metadata['name'])}.")

    # Convert to a vector env, then concatenate 8 of them spread over 2 CPUs
    # so that Stable-Baselines3 can consume the result directly.
    vec_env = ss.concat_vec_envs_v1(
        ss.pettingzoo_env_to_vec_env_v1(parallel_env),
        8,
        num_cpus=2,
        base_class="stable_baselines3",
    )

    ppo_settings = {
        "verbose": 3,
        "learning_rate": 1e-3,
        "batch_size": 256,
    }
    ppo_model = PPO(MlpPolicy, vec_env, **ppo_settings)
    ppo_model.learn(total_timesteps=steps)

    # The file name carries a timestamp so evaluation can pick the newest policy.
    saved_name = vec_env.unwrapped.metadata.get('name')
    timestamp = time.strftime('%Y%m%d-%H%M%S')
    ppo_model.save(f"{saved_name}_{timestamp}")
    print("Model has been saved.")

    print(f"Finished training on {str(vec_env.unwrapped.metadata['name'])}.")

    vec_env.close()

Inspect Eval Code#

%pycat inspect.getsource(eval_flatland_pettingzoo)
def eval_flatland_pettingzoo(env_fn, num_games: int = 100, render_mode: str | None = None, **env_kwargs):
    """Evaluate the most recently saved PPO policy on the parallel environment.

    The same loaded model picks the action of every agent from that agent's
    own observation (deterministic prediction).

    Args:
        env_fn: Object exposing ``parallel_env(render_mode=..., **env_kwargs)``.
        num_games: Number of episodes to play; episode ``i`` is reset with ``seed=i``.
        render_mode: Forwarded to ``env_fn.parallel_env``.
        **env_kwargs: Forwarded unchanged to ``env_fn.parallel_env``.

    Returns:
        The rewards accumulated over *all* games, summed over agents and
        divided by the number of possible agents (not by ``num_games``).

    If no ``<env name>*.zip`` policy file exists in the working directory,
    prints "Policy not found." and calls ``exit(0)``.
    """
    env: ParallelEnv = env_fn.parallel_env(render_mode=render_mode, **env_kwargs)

    print(
        f"\nStarting evaluation on {str(env.metadata['name'])} (num_games={num_games}, render_mode={render_mode})"
    )

    try:
        # Pick the newest saved policy by creation time.
        latest_policy = max(
            glob.glob(f"{env.metadata['name']}*.zip"), key=os.path.getctime
        )
    except ValueError:
        # max() raises ValueError when glob found no matching file.
        print("Policy not found.")
        exit(0)

    model = PPO.load(latest_policy)

    rewards = {agent: 0 for agent in env.possible_agents}

    for i in range(num_games):
        obs, _ = env.reset(seed=i)

        done = False
        while not done:
            act = {a: int(model.predict(obs[a], deterministic=True)[0]) for a in env.agents}
            obs, rew, terminations, truncations, infos = env.step(act)
            for a in env.agents:
                rewards[a] += rew[a]
            # An agent is finished when it is terminated OR truncated; looking
            # at terminations alone never ends an episode that was cut short
            # by truncation (e.g. a step limit).
            done = all(
                terminations[a] or truncations.get(a, False) for a in terminations
            )

    env.close()

    avg_reward = sum(rewards.values()) / len(rewards.values())
    print("Rewards: ", rewards)
    print(f"Avg reward: {avg_reward}")
    return avg_reward

Train a model#

# Generate a Flatland rail environment whose observations are flattened,
# normalised tree observations (depth 3) fed by a shortest-path predictor.
raw_env, _, _ = env_generator(
    obs_builder_object=FlattenedNormalizedTreeObsForRailEnv(
        max_depth=3,
        predictor=ShortestPathPredictorForRailEnv(max_depth=50),
    ),
)
# Wrap the rail environment so it exposes the PettingZoo interface.
env_fn = PettingzooFlatland(raw_env)
# No extra keyword arguments are passed to parallel_env() in this demo.
env_kwargs = {}
/Users/che/workspaces/flatland-rl-2/flatland/envs/rail_generators.py:321: UserWarning: Could not set all required cities! Created 1/2
  warnings.warn(city_warning)
/Users/che/workspaces/flatland-rl-2/flatland/envs/rail_generators.py:217: UserWarning: [WARNING] Changing to Grid mode to place at least 2 cities.
  warnings.warn("[WARNING] Changing to Grid mode to place at least 2 cities.")
# Train the shared PPO policy for 196,608 timesteps with a fixed seed;
# the call saves the resulting model to the working directory.
train_flatland_pettingzoo_supersuit(
    env_fn,
    steps=196_608,
    seed=0,
    **env_kwargs,
)
Starting training on flatland_pettingzoo.
Using cpu device
/Users/che/workspaces/flatland-rl-2/flatland/envs/rail_generators.py:321: UserWarning: Could not set all required cities! Created 1/2
  warnings.warn(city_warning)
/Users/che/workspaces/flatland-rl-2/flatland/envs/rail_generators.py:217: UserWarning: [WARNING] Changing to Grid mode to place at least 2 cities.
  warnings.warn("[WARNING] Changing to Grid mode to place at least 2 cities.")
/Users/che/workspaces/flatland-rl-2/flatland/envs/rail_generators.py:321: UserWarning: Could not set all required cities! Created 1/2
  warnings.warn(city_warning)
/Users/che/workspaces/flatland-rl-2/flatland/envs/rail_generators.py:217: UserWarning: [WARNING] Changing to Grid mode to place at least 2 cities.
  warnings.warn("[WARNING] Changing to Grid mode to place at least 2 cities.")
-------------------------------
| time/              |        |
|    fps             | 1467   |
|    iterations      | 1      |
|    time_elapsed    | 78     |
|    total_timesteps | 114688 |
-------------------------------
-----------------------------------------
| time/                   |             |
|    fps                  | 1223        |
|    iterations           | 2           |
|    time_elapsed         | 187         |
|    total_timesteps      | 229376      |
| train/                  |             |
|    approx_kl            | 0.008171785 |
|    clip_fraction        | 0.0342      |
|    clip_range           | 0.2         |
|    entropy_loss         | -1.6        |
|    explained_variance   | -0.00243    |
|    learning_rate        | 0.001       |
|    loss                 | 447         |
|    n_updates            | 10          |
|    policy_gradient_loss | -0.000753   |
|    value_loss           | 957         |
-----------------------------------------
Model has been saved.
Finished training on flatland_pettingzoo.

Evaluate 10 games (average reward should be positive but can vary significantly)#

# Evaluate the most recently saved policy over 10 games without rendering.
eval_flatland_pettingzoo(
    env_fn,
    num_games=10,
    render_mode=None,
    **env_kwargs,
)
Starting evaluation on flatland_pettingzoo (num_games=10, render_mode=None)
/Users/che/workspaces/flatland-rl-2/flatland/envs/rail_generators.py:321: UserWarning: Could not set all required cities! Created 1/2
  warnings.warn(city_warning)
/Users/che/workspaces/flatland-rl-2/flatland/envs/rail_generators.py:217: UserWarning: [WARNING] Changing to Grid mode to place at least 2 cities.
  warnings.warn("[WARNING] Changing to Grid mode to place at least 2 cities.")
Rewards:  {0: -1306, 1: -683, 2: -1485, 3: -955, 4: -1498, 5: -782, 6: -1194}
Avg reward: -1129.0
-1129.0

Watch 2 games#

# Play 2 games with render_mode="human" so the episodes can be watched.
eval_flatland_pettingzoo(
    env_fn,
    num_games=2,
    render_mode="human",
    **env_kwargs,
)
Starting evaluation on flatland_pettingzoo (num_games=2, render_mode=human)
Rewards:  {0: -213, 1: -76, 2: -517, 3: -143, 4: -584, 5: -139, 6: -94}
Avg reward: -252.28571428571428
-252.28571428571428