In [1]:
import inspect

In [2]:
from ray.rllib.utils.metrics import ENV_RUNNER_RESULTS, EPISODE_RETURN_MEAN

from flatland.ml.ray.examples.flatland_inference_with_random_policy import add_flatland_inference_with_random_policy_args, rollout
from flatland.ml.ray.examples.flatland_training_with_parameter_sharing import train, add_flatland_training_with_parameter_sharing_args, \
    register_flatland_ray_cli_observation_builders

# RLlib
RLlib (https://arxiv.org/abs/1712.09381) is an open-source library for reinforcement learning (RL), offering support for production-level, highly scalable, and fault-tolerant RL workloads, while maintaining simple and unified APIs for a large variety of industry applications.

### Register observation builders in the RLlib input registry

These are the registered keys you can use for the `--obs-builder` param below. Use `register_input` to register your own.

In [4]:
# Display the source of the registration helper to see which keys it registers.
%pycat inspect.getsource(register_flatland_ray_cli_observation_builders)

[0;32mdef[0m [0mregister_flatland_ray_cli_observation_builders[0m[0;34m([0m[0;34m)[0m[0;34m:[0m[0;34m[0m
[0;34m[0m    [0mregister_input[0m[0;34m([0m[0;34m"DummyObservationBuilderGym"[0m[0;34m,[0m [0;32mlambda[0m[0;34m:[0m [0mDummyObservationBuilderGym[0m[0;34m([0m[0;34m)[0m[0;34m)[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mregister_input[0m[0;34m([0m[0;34m"GlobalObsForRailEnvGym"[0m[0;34m,[0m [0;32mlambda[0m[0;34m:[0m [0mGlobalObsForRailEnvGym[0m[0;34m([0m[0;34m)[0m[0;34m)[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mregister_input[0m[0;34m([0m[0;34m"FlattenedNormalizedTreeObsForRailEnv_max_depth_3_50"[0m[0;34m,[0m[0;34m[0m
[0;34m[0m                   [0;32mlambda[0m[0;34m:[0m [0mFlattenedNormalizedTreeObsForRailEnv[0m[0;34m([0m[0mmax_depth[0m[0;34m=[0m[0;36m3[0m[0;34m,[0m [0mpredictor[0m[0;34m=[0m[0mShortestPathPredictorForRailEnv[0m[0;34m([0m[0mmax_depth[0m[0;34m=[0m[0;36m50[0m[0;34m)[0m

In [5]:
# Run the registration so the keys shown above can be passed to `--obs-builder`.
register_flatland_ray_cli_observation_builders()

## RLlib Training

In [6]:
# Build the argument parser of the training CLI (the same helper `train()`
# itself calls when it is invoked without pre-parsed args).
parser = add_flatland_training_with_parameter_sharing_args()

#### Inspect Training cli Options

In [7]:
# Print the training script's command-line help.
!python -m flatland.ml.ray.examples.flatland_training_with_parameter_sharing --help

usage: flatland_training_with_parameter_sharing.py [-h] [--algo ALGO]
                                                   [--enable-new-api-stack]
                                                   [--framework {tf,tf2,torch}]
                                                   [--env ENV]
                                                   [--num-env-runners NUM_ENV_RUNNERS]
                                                   [--num-envs-per-env-runner NUM_ENVS_PER_ENV_RUNNER]
                                                   [--num-agents NUM_AGENTS]
                                                   [--evaluation-num-env-runners EVALUATION_NUM_ENV_RUNNERS]
                                                   [--evaluation-interval EVALUATION_INTERVAL]
                                                   [--evaluation-duration EVALUATION_DURATION]
                                                   [--evaluation-duration-unit {episodes,timesteps}]
                                            

#### Inspect Training cli Code

In [8]:
# Display the source of `train`, the function called directly further below.
%pycat inspect.getsource(train)

[0;32mdef[0m [0mtrain[0m[0;34m([0m[0margs[0m[0;34m:[0m [0mOptional[0m[0;34m[[0m[0margparse[0m[0;34m.[0m[0mNamespace[0m[0;34m][0m [0;34m=[0m [0;32mNone[0m[0;34m,[0m [0minit_args[0m[0;34m=[0m[0;32mNone[0m[0;34m)[0m [0;34m->[0m [0mUnion[0m[0;34m[[0m[0mResultDict[0m[0;34m,[0m [0mtune[0m[0;34m.[0m[0mresult_grid[0m[0;34m.[0m[0mResultGrid[0m[0;34m][0m[0;34m:[0m[0;34m[0m
[0;34m[0m    [0;32mif[0m [0margs[0m [0;32mis[0m [0;32mNone[0m[0;34m:[0m[0;34m[0m
[0;34m[0m        [0mparser[0m [0;34m=[0m [0madd_flatland_training_with_parameter_sharing_args[0m[0;34m([0m[0;34m)[0m[0;34m[0m
[0;34m[0m        [0margs[0m [0;34m=[0m [0mparser[0m[0;34m.[0m[0mparse_args[0m[0;34m([0m[0;34m)[0m[0;34m[0m
[0;34m[0m    [0;32massert[0m [0margs[0m[0;34m.[0m[0mnum_agents[0m [0;34m>[0m [0;36m0[0m[0;34m,[0m [0;34m"Must set --num-agents > 0 when running this script!"[0m[0;34m[0m
[0;34m[0m    [0;32m

#### Run Training with PPO for one iteration with reduced batch size and checkpointing

In [9]:
# Print the shell's current working directory.
!echo $PWD

/Users/che/workspaces/flatland-rl-2/notebooks


In [10]:
algo = "PPO"
obid = "FlattenedNormalizedTreeObsForRailEnv_max_depth_3_50"
# One PPO iteration with a reduced batch size and checkpointing.
train_cli_args = [
    "--num-agents", "2",
    "--obs-builder", obid,
    "--algo", algo,
    "--stop-iters", "1",
    "--train-batch-size-per-learner", "200",
    "--checkpoint-freq", "1",
]
# In order to get the results, we call `train()` directly from Python.
# The training parser is built here instead of being read from the global
# `parser`, because a later cell rebinds `parser` to the inference parser;
# this keeps the cell re-runnable at any point in the session.
results = train(add_flatland_training_with_parameter_sharing_args().parse_args(train_cli_args))

2025-03-21 16:51:26,759	INFO worker.py:1832 -- Started a local Ray instance. View the dashboard at [1m[32mhttp://127.0.0.1:8265 [39m[22m
  gym.logger.warn(
  gym.logger.warn(
  logger.warn(
  logger.warn(f"{pre} is not within the observation space.")
2025-03-21 16:51:28,339	INFO worker.py:1672 -- Calling ray.init() again after it has already been called.


== Status ==
Current time: 2025-03-21 16:51:28 (running for 00:00:00.21)
Using FIFO scheduling algorithm.
Logical resource usage: 0/8 CPUs, 0/0 GPUs
Result logdir: /tmp/ray/session_2025-03-21_16-51-24_808186_84981/artifacts/2025-03-21_16-51-28/PPO_2025-03-21_16-51-28/driver_artifacts
Number of trials: 1/1 (1 PENDING)
+------------------------------+----------+-------+
| Trial name                   | status   | loc   |
|------------------------------+----------+-------|
| PPO_flatland_env_59ef6_00000 | PENDING  |       |
+------------------------------+----------+-------+


== Status ==
Current time: 2025-03-21 16:51:33 (running for 00:00:05.31)
Using FIFO scheduling algorithm.
Logical resource usage: 0/8 CPUs, 0/0 GPUs
Result logdir: /tmp/ray/session_2025-03-21_16-51-24_808186_84981/artifacts/2025-03-21_16-51-28/PPO_2025-03-21_16-51-28/driver_artifacts
Number of trials: 1/1 (1 PENDING)
+------------------------------+----------+-------+
| Trial name                   | status   | loc 

[33m(raylet)[0m [2025-03-21 16:51:36,766 E 85004 3531068] file_system_monitor.cc:116: /tmp/ray/session_2025-03-21_16-51-24_808186_84981 is over 95% full, available space: 86.4491 GB; capacity: 1858.19 GB. Object creation will fail if spilling is required.


== Status ==
Current time: 2025-03-21 16:51:38 (running for 00:00:10.35)
Using FIFO scheduling algorithm.
Logical resource usage: 3.0/8 CPUs, 0/0 GPUs
Result logdir: /tmp/ray/session_2025-03-21_16-51-24_808186_84981/artifacts/2025-03-21_16-51-28/PPO_2025-03-21_16-51-28/driver_artifacts
Number of trials: 1/1 (1 PENDING)
+------------------------------+----------+-------+
| Trial name                   | status   | loc   |
|------------------------------+----------+-------|
| PPO_flatland_env_59ef6_00000 | PENDING  |       |
+------------------------------+----------+-------+






== Status ==
Current time: 2025-03-21 16:51:43 (running for 00:00:15.43)
Using FIFO scheduling algorithm.
Logical resource usage: 3.0/8 CPUs, 0/0 GPUs
Result logdir: /tmp/ray/session_2025-03-21_16-51-24_808186_84981/artifacts/2025-03-21_16-51-28/PPO_2025-03-21_16-51-28/driver_artifacts
Number of trials: 1/1 (1 PENDING)
+------------------------------+----------+-------+
| Trial name                   | status   | loc   |
|------------------------------+----------+-------|
| PPO_flatland_env_59ef6_00000 | PENDING  |       |
+------------------------------+----------+-------+






[36m(MultiAgentEnvRunner pid=85026)[0m agent_to_module_mapping.py: [Discrete(5), Discrete(5)]
[36m(MultiAgentEnvRunner pid=85026)[0m agent_to_module_mapping.py: [Box(0.0, 2.0, (1020,), float64), Box(0.0, 2.0, (1020,), float64)]
[36m(MultiAgentEnvRunner pid=85026)[0m agent_to_module_mapping.py: [Discrete(5), Discrete(5)]
[36m(MultiAgentEnvRunner pid=85026)[0m agent_to_module_mapping.py: [Box(0.0, 2.0, (1020,), float64), Box(0.0, 2.0, (1020,), float64)]
[36m(MultiAgentEnvRunner pid=85026)[0m agent_to_module_mapping.py: [Discrete(5), Discrete(5)]
[36m(MultiAgentEnvRunner pid=85026)[0m agent_to_module_mapping.py: [Box(0.0, 2.0, (1020,), float64), Box(0.0, 2.0, (1020,), float64)]


[33m(raylet)[0m [2025-03-21 16:51:46,860 E 85004 3531068] file_system_monitor.cc:116: /tmp/ray/session_2025-03-21_16-51-24_808186_84981 is over 95% full, available space: 86.4488 GB; capacity: 1858.19 GB. Object creation will fail if spilling is required.
[36m(PPO pid=85022)[0m Install gputil for GPU system monitoring.


== Status ==
Current time: 2025-03-21 16:51:48 (running for 00:00:20.50)
Using FIFO scheduling algorithm.
Logical resource usage: 3.0/8 CPUs, 0/0 GPUs
Result logdir: /tmp/ray/session_2025-03-21_16-51-24_808186_84981/artifacts/2025-03-21_16-51-28/PPO_2025-03-21_16-51-28/driver_artifacts
Number of trials: 1/1 (1 RUNNING)
+------------------------------+----------+-----------------+
| Trial name                   | status   | loc             |
|------------------------------+----------+-----------------|
| PPO_flatland_env_59ef6_00000 | RUNNING  | 127.0.0.1:85022 |
+------------------------------+----------+-----------------+






Trial name,env_runner_group,env_runners,fault_tolerance,learners,num_env_steps_sampled_lifetime,num_env_steps_sampled_lifetime_throughput,num_training_step_calls_per_iteration,perf,timers
PPO_flatland_env_59ef6_00000,{'actor_manager_num_outstanding_async_reqs': 0},"{'num_agent_steps_sampled_lifetime': {'0': 200, '1': 200}, 'timers': {'connectors': {'AgentToModuleMapping': 7.898451490889527e-06, 'NumpyToTensor': 0.0006514624725359729, 'UnBatchToIndividualItems': 3.3355799842430625e-05, 'AddStatesFromEpisodesToBatch': 6.196388367066876e-06, 'TensorToNumpy': 0.0002189852724555016, 'NormalizeAndClipActions': 0.00017074410228305476, 'ModuleToAgentUnmapping': 6.448348645382307e-06, 'AddTimeDimToBatchAndZeroPad': 2.2233997937464932e-05, 'ListifyDataForVectorEnv': 8.073932804185945e-06, 'AddObservationsFromEpisodesToBatch': 3.9197153741186344e-05, 'GetActions': 0.020396071205466673, 'BatchIndividualItems': 5.00980064539732e-05, 'RemoveSingleTsTimeRankFromBatch': 2.0477231243917827e-06}}, 'num_module_steps_sampled': {'p0': 400}, 'num_env_steps_sampled': 200, 'num_agent_steps_sampled': {'1': 200, '0': 200}, 'num_module_steps_sampled_lifetime': {'p0': 400}, 'num_env_steps_sampled_lifetime': 200, 'env_to_module_sum_episodes_length_in': 37.23720178604972, 'env_to_module_sum_episodes_length_out': 37.23720178604972, 'num_env_steps_sampled_lifetime_throughput': nan}","{'num_healthy_workers': 2, 'num_remote_worker_restarts': 0}","{'__all_modules__': {'timers': {'connectors': {'BatchIndividualItems': 0.009110875020269305, 'AddOneTsToEpisodesAndTruncate': 0.0038188330072443932, 'AgentToModuleMapping': 0.000297874998068437, 'NumpyToTensor': 0.0008266670047305524, 'GeneralAdvantageEstimation': 0.017291625001234934, 'AddStatesFromEpisodesToBatch': 9.374984074383974e-06, 'AddTimeDimToBatchAndZeroPad': 3.024999750778079e-05, 'AddObservationsFromEpisodesToBatch': 8.495798101648688e-05, 'AddColumnsFromEpisodesToTrainBatch': 0.008209916995838284}}, 'num_non_trainable_parameters': 0, 'learner_connector_sum_episodes_length_in': 200, 'learner_connector_sum_episodes_length_out': 200, 'num_env_steps_trained_lifetime': 200, 'num_trainable_parameters': 655878, 
'num_module_steps_trained': 404, 'num_env_steps_trained': 200, 'num_module_steps_trained_lifetime': 404, 'num_env_steps_trained_lifetime_throughput': nan}, 'p0': {'num_trainable_parameters': 655878, 'weights_seq_no': 1.0, 'entropy': 1.5917836427688599, 'module_train_batch_size_mean': 404, 'policy_loss': -0.15537874400615692, 'total_loss': -0.15024973452091217, 'vf_loss_unclipped': 1.6708769123852107e-07, 'num_module_steps_trained': 404, 'mean_kl_loss': 0.025644313544034958, 'num_module_steps_trained_lifetime': 404, 'curr_kl_coeff': 0.30000001192092896, 'num_non_trainable_parameters': 0, 'vf_explained_var': 0.9406133890151978, 'default_optimizer_learning_rate': 5e-05, 'gradients_default_optimizer_global_norm': 0.9358460307121277, 'curr_entropy_coeff': 0.0, 'diff_num_grad_updates_vs_sampler_policy': 0.0, 'vf_loss': 1.6708769123852107e-07}}",200,,1,"{'cpu_util_percent': 67.225, 'ram_util_percent': 96.525}","{'training_iteration': 2.536862332985038, 'restore_env_runners': 2.4582986952736974e-05, 'training_step': 2.5366032080200966, 'env_runner_sampling_timer': 0.5369193330116104, 'learner_update_timer': 1.9899427500204183, 'synch_weights': 0.009463166003115475}"


2025-03-21 16:51:50,353	INFO tune.py:1009 -- Wrote the latest version of all result files and experiment state to '/Users/che/ray_results/PPO_2025-03-21_16-51-28' in 0.0099s.
[36m(PPO(env=flatland_env; env-runners=2; learners=0; multi-agent=True) pid=85022)[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/Users/che/ray_results/PPO_2025-03-21_16-51-28/PPO_flatland_env_59ef6_00000_0_2025-03-21_16-51-28/checkpoint_000000)
2025-03-21 16:51:50,489	INFO tune.py:1041 -- Total run time: 22.13 seconds (21.97 seconds for the tuning loop).


== Status ==
Current time: 2025-03-21 16:51:50 (running for 00:00:21.98)
Using FIFO scheduling algorithm.
Logical resource usage: 3.0/8 CPUs, 0/0 GPUs
Result logdir: /tmp/ray/session_2025-03-21_16-51-24_808186_84981/artifacts/2025-03-21_16-51-28/PPO_2025-03-21_16-51-28/driver_artifacts
Number of trials: 1/1 (1 TERMINATED)
+------------------------------+------------+-----------------+--------+------------------+------+
| Trial name                   | status     | loc             |   iter |   total time (s) |   ts |
|------------------------------+------------+-----------------+--------+------------------+------|
| PPO_flatland_env_59ef6_00000 | TERMINATED | 127.0.0.1:85022 |      1 |          2.54778 |  200 |
+------------------------------+------------+-----------------+--------+------------------+------+






[36m(PPO pid=85022)[0m agent_to_module_mapping.py: [Box(0.0, 2.0, (1020,), float64), Box(0.0, 2.0, (1020,), float64)][32m [repeated 12x across cluster] (Ray deduplicates logs by default. Set RAY_DEDUP_LOGS=0 to disable log deduplication, or see https://docs.ray.io/en/master/ray-observability/user-guides/configure-logging.html#log-deduplication for more options.)[0m


## Rollout from Checkpoint

In [11]:
# Build the argument parser of the inference CLI.
# NOTE(review): this rebinds `parser`, which held the training parser up to here.
parser = add_flatland_inference_with_random_policy_args()

#### Inspect Rollout cli Options

In [12]:
# Print the inference script's command-line help.
!python -m flatland.ml.ray.examples.flatland_inference_with_random_policy --help

usage: flatland_inference_with_random_policy.py [-h] [--num-agents NUM_AGENTS]
                                                [--obs-builder OBS_BUILDER]
                                                [--num-episodes-during-inference NUM_EPISODES_DURING_INFERENCE]
                                                [--policy-id POLICY_ID]
                                                [--cp CP]

options:
  -h, --help            show this help message and exit
  --num-agents NUM_AGENTS
  --obs-builder OBS_BUILDER
  --num-episodes-during-inference NUM_EPISODES_DURING_INFERENCE
                        Number of episodes to do inference over (after
                        restoring from a checkpoint).
  --policy-id POLICY_ID
  --cp CP


#### Inspect Rollout cli Code

In [13]:
# Display the source of `rollout`.
%pycat inspect.getsource(rollout)

[0;32mdef[0m [0mrollout[0m[0;34m([0m[0margs[0m[0;34m:[0m [0mNamespace[0m[0;34m)[0m[0;34m:[0m[0;34m[0m
[0;34m[0m    [0;31m# Create an env to do inference in.[0m[0;34m[0m
[0;34m[0m    [0menv[0m [0;34m=[0m [0mray_env_generator[0m[0;34m([0m[0mn_agents[0m[0;34m=[0m[0margs[0m[0;34m.[0m[0mnum_agents[0m[0;34m,[0m [0mobs_builder_object[0m[0;34m=[0m[0mregistry_get_input[0m[0;34m([0m[0margs[0m[0;34m.[0m[0mobs_builder[0m[0;34m)[0m[0;34m([0m[0;34m)[0m[0;34m)[0m[0;34m[0m
[0;34m[0m    [0mobs[0m[0;34m,[0m [0m_[0m [0;34m=[0m [0menv[0m[0;34m.[0m[0mreset[0m[0;34m([0m[0;34m)[0m[0;34m[0m
[0;34m[0m[0;34m[0m
[0;34m[0m    [0mnum_episodes[0m [0;34m=[0m [0;36m0[0m[0;34m[0m
[0;34m[0m    [0mepisode_return[0m [0;34m=[0m [0;36m0.0[0m[0;34m[0m
[0;34m[0m[0;34m[0m
[0;34m[0m    [0;32mif[0m [0margs[0m[0;34m.[0m[0mcp[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m[0;34m:[0m[0;34m[0m
[0;34

#### Rollout on best checkpoint from previous training

In [14]:
# NOTE(review): the builders were already registered by an earlier cell of this
# notebook; this second call looks redundant on a linear run — TODO confirm.
register_flatland_ray_cli_observation_builders()

In [15]:
# Pick the trial with the highest value of the
# `ENV_RUNNER_RESULTS/EPISODE_RETURN_MEAN` metric.
return_metric = f"{ENV_RUNNER_RESULTS}/{EPISODE_RETURN_MEAN}"
best_result = results.get_best_result(metric=return_metric, mode="max")

In [16]:
!python -m flatland.ml.ray.examples.flatland_inference_with_random_policy --num-agents 2 --obs-builder {obid} --cp {best_result.checkpoint.path} --policy-id p0  --num-episodes-during-inference 1

Episode done: Total reward = -230.0
Done performing action inference through 1 Episodes
