import inspect
from ray.rllib.utils.metrics import ENV_RUNNER_RESULTS, EPISODE_RETURN_MEAN

from flatland.ml.ray.examples.flatland_inference_with_random_policy import add_flatland_inference_with_random_policy_args, rollout
from flatland.ml.ray.examples.flatland_training_with_parameter_sharing import train, add_flatland_training_with_parameter_sharing_args, \
    register_flatland_ray_cli_observation_builders

RLlib

RLlib (https://arxiv.org/abs/1712.09381) is an open-source library for reinforcement learning (RL), offering support for production-level, highly scalable, and fault-tolerant RL workloads, while maintaining simple and unified APIs for a large variety of industry applications.

Register observation builders in the RLlib input registry

These are the registered keys you can use for the `--obs-builder` parameter below. Use `register_input` to register your own.

%pycat inspect.getsource(register_flatland_ray_cli_observation_builders)
def register_flatland_ray_cli_observation_builders():
    """Register the observation builders selectable via ``--obs-builder``.

    Each builder is registered with ``register_input`` under a string key together
    with a zero-argument factory, so a new builder instance is created every time
    the factory is called (``train`` calls ``registry_get_input(key)()``).
    """
    register_input("DummyObservationBuilderGym", lambda: DummyObservationBuilderGym())
    register_input("GlobalObsForRailEnvGym", lambda: GlobalObsForRailEnvGym())
    # Tree observation with max_depth=3 and a shortest-path predictor with max_depth=50
    # (hence the "_3_50" suffix of the key).
    register_input(
        "FlattenedNormalizedTreeObsForRailEnv_max_depth_3_50",
        lambda: FlattenedNormalizedTreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv(max_depth=50)),
    )
# Populate the registry with the observation builder keys accepted by --obs-builder.
register_flatland_ray_cli_observation_builders()

RLlib Training

# Build the argument parser of the training CLI; it is used below to build the arguments passed to `train()`.
parser = add_flatland_training_with_parameter_sharing_args()

Inspect Training CLI Options

!python -m flatland.ml.ray.examples.flatland_training_with_parameter_sharing --help
usage: flatland_training_with_parameter_sharing.py [-h] [--algo ALGO]
                                                   [--enable-new-api-stack]
                                                   [--framework {tf,tf2,torch}]
                                                   [--env ENV]
                                                   [--num-env-runners NUM_ENV_RUNNERS]
                                                   [--num-envs-per-env-runner NUM_ENVS_PER_ENV_RUNNER]
                                                   [--num-agents NUM_AGENTS]
                                                   [--evaluation-num-env-runners EVALUATION_NUM_ENV_RUNNERS]
                                                   [--evaluation-interval EVALUATION_INTERVAL]
                                                   [--evaluation-duration EVALUATION_DURATION]
                                                   [--evaluation-duration-unit {episodes,timesteps}]
                                                   [--evaluation-parallel-to-training]
                                                   [--output OUTPUT]
                                                   [--log-level {INFO,DEBUG,WARN,ERROR}]
                                                   [--no-tune]
                                                   [--num-samples NUM_SAMPLES]
                                                   [--max-concurrent-trials MAX_CONCURRENT_TRIALS]
                                                   [--verbose VERBOSE]
                                                   [--checkpoint-freq CHECKPOINT_FREQ]
                                                   [--checkpoint-at-end]
                                                   [--wandb-key WANDB_KEY]
                                                   [--wandb-project WANDB_PROJECT]
                                                   [--wandb-run-name WANDB_RUN_NAME]
                                                   [--stop-reward STOP_REWARD]
                                                   [--stop-iters STOP_ITERS]
                                                   [--stop-timesteps STOP_TIMESTEPS]
                                                   [--as-test]
                                                   [--as-release-test]
                                                   [--num-learners NUM_LEARNERS]
                                                   [--num-gpus-per-learner NUM_GPUS_PER_LEARNER]
                                                   [--num-aggregator-actors-per-learner NUM_AGGREGATOR_ACTORS_PER_LEARNER]
                                                   [--num-cpus NUM_CPUS]
                                                   [--local-mode]
                                                   [--num-gpus NUM_GPUS]
                                                   [--train-batch-size-per-learner TRAIN_BATCH_SIZE_PER_LEARNER]
                                                   [--obs-builder OBS_BUILDER]
                                                   [--ray-address RAY_ADDRESS]
                                                   [--env_var [KEY=VALUE ...]]

options:
  -h, --help            show this help message and exit
  --algo ALGO           The RLlib-registered algorithm to use.
  --enable-new-api-stack
                        Whether to use the `enable_rl_module_and_learner`
                        config setting.
  --framework {tf,tf2,torch}
                        The DL framework specifier.
  --env ENV             The gym.Env identifier to run the experiment with.
  --num-env-runners NUM_ENV_RUNNERS
                        The number of (remote) EnvRunners to use for the
                        experiment.
  --num-envs-per-env-runner NUM_ENVS_PER_ENV_RUNNER
                        The number of (vectorized) environments per EnvRunner.
                        Note that this is identical to the batch size for
                        (inference) action computations.
  --num-agents NUM_AGENTS
                        If 0 (default), will run as single-agent. If > 0, will
                        run as multi-agent with the environment simply cloned
                        n times and each agent acting independently at every
                        single timestep. The overall reward for this
                        experiment is then the sum over all individual agents'
                        rewards.
  --evaluation-num-env-runners EVALUATION_NUM_ENV_RUNNERS
                        The number of evaluation (remote) EnvRunners to use
                        for the experiment.
  --evaluation-interval EVALUATION_INTERVAL
                        Every how many iterations to run one round of
                        evaluation. Use 0 (default) to disable evaluation.
  --evaluation-duration EVALUATION_DURATION
                        The number of evaluation units to run each evaluation
                        round. Use `--evaluation-duration-unit` to count
                        either in 'episodes' or 'timesteps'. If 'auto', will
                        run as many as possible during train pass
                        (`--evaluation-parallel-to-training` must be set
                        then).
  --evaluation-duration-unit {episodes,timesteps}
                        The evaluation duration unit to count by. One of
                        'episodes' or 'timesteps'. This unit will be run
                        `--evaluation-duration` times in each evaluation
                        round. If `--evaluation-duration=auto`, this setting
                        does not matter.
  --evaluation-parallel-to-training
                        Whether to run evaluation parallel to training. This
                        might help speed up your overall iteration time. Be
                        aware that when using this option, your reported
                        evaluation results are referring to one iteration
                        before the current one.
  --output OUTPUT       The output directory to write trajectories to, which
                        are collected by the algo's EnvRunners.
  --log-level {INFO,DEBUG,WARN,ERROR}
                        The log-level to be used by the RLlib logger.
  --no-tune             Whether to NOT use tune.Tuner(), but rather a simple
                        for-loop calling `algo.train()` repeatedly until one
                        of the stop criteria is met.
  --num-samples NUM_SAMPLES
                        How many (tune.Tuner.fit()) experiments to execute -
                        if possible in parallel.
  --max-concurrent-trials MAX_CONCURRENT_TRIALS
                        How many (tune.Tuner) trials to run concurrently.
  --verbose VERBOSE     The verbosity level for the `tune.Tuner()` running the
                        experiment.
  --checkpoint-freq CHECKPOINT_FREQ
                        The frequency (in training iterations) with which to
                        create checkpoints. Note that if --wandb-key is
                        provided, all checkpoints will automatically be
                        uploaded to WandB.
  --checkpoint-at-end   Whether to create a checkpoint at the very end of the
                        experiment. Note that if --wandb-key is provided, all
                        checkpoints will automatically be uploaded to WandB.
  --wandb-key WANDB_KEY
                        The WandB API key to use for uploading results.
  --wandb-project WANDB_PROJECT
                        The WandB project name to use.
  --wandb-run-name WANDB_RUN_NAME
                        The WandB run name to use.
  --stop-reward STOP_REWARD
                        Reward at which the script should stop training.
  --stop-iters STOP_ITERS
                        The number of iterations to train.
  --stop-timesteps STOP_TIMESTEPS
                        The number of (environment sampling) timesteps to
                        train.
  --as-test             Whether this script should be run as a test. If set,
                        --stop-reward must be achieved within --stop-timesteps
                        AND --stop-iters, otherwise this script will throw an
                        exception at the end.
  --as-release-test     Whether this script should be run as a release test.
                        If set, all that applies to the --as-test option is
                        true, plus, a short JSON summary will be written into
                        a results file whose location is given by the ENV
                        variable `TEST_OUTPUT_JSON`.
  --num-learners NUM_LEARNERS
                        The number of Learners to use. If `None`, use the
                        algorithm's default value.
  --num-gpus-per-learner NUM_GPUS_PER_LEARNER
                        The number of GPUs per Learner to use. If `None` and
                        there are enough GPUs for all required Learners
                        (--num-learners), use a value of 1, otherwise 0.
  --num-aggregator-actors-per-learner NUM_AGGREGATOR_ACTORS_PER_LEARNER
                        The number of Aggregator actors to use per Learner. If
                        `None`, use the algorithm's default value.
  --num-cpus NUM_CPUS
  --local-mode          Init Ray in local mode for easier debugging.
  --num-gpus NUM_GPUS   The number of GPUs to use (only on the old API stack).
  --train-batch-size-per-learner TRAIN_BATCH_SIZE_PER_LEARNER
                        See https://docs.ray.io/en/latest/rllib/package_ref/do
                        c/ray.rllib.algorithms.algorithm_config.AlgorithmConfi
                        g.training.html#ray.rllib.algorithms.algorithm_config.
                        AlgorithmConfig.training
  --obs-builder OBS_BUILDER
  --ray-address RAY_ADDRESS
                        The address of the ray cluster to connect to in the
                        form ray://<head_node_ip_address>:10001. Leave empty
                        to start a new cluster. Passed to
                        ray.init(address=...). See
                        https://docs.ray.io/en/latest/ray-
                        core/api/doc/ray.init.html
  --env_var [KEY=VALUE ...], -e [KEY=VALUE ...]
                        Set ray runtime environment variables like -e
                        RAY_DEBUG=1, passed to ray.init(runtime_env={env_vars:
                        {...}}), see https://docs.ray.io/en/latest/ray-
                        core/handling-dependencies.html#api-reference

Inspect Training CLI Code

%pycat inspect.getsource(train)
def train(args: Optional[argparse.Namespace] = None, init_args=None) -> Union[ResultDict, tune.result_grid.ResultGrid]:
    """Train one policy ("p0") shared by all Flatland agents with RLlib.

    Parameters
    ----------
    args : Optional[argparse.Namespace]
        Parsed CLI options (see ``add_flatland_training_with_parameter_sharing_args``).
        If ``None``, they are parsed from the process command line.
    init_args : Optional[dict]
        Keyword arguments passed verbatim to ``ray.init``. If ``None``, they are
        built from ``--env_var`` and ``--ray-address``.

    Returns
    -------
    The result of ``run_rllib_example_script_experiment``.

    Raises
    ------
    AssertionError
        If ``--num-agents``, ``--enable-new-api-stack`` or ``--obs-builder`` are
        missing, or if the experiment reports errors.
    ValueError
        If an ``--env_var`` entry is not of the form ``KEY=VALUE``.
    """
    if args is None:
        parser = add_flatland_training_with_parameter_sharing_args()
        args = parser.parse_args()
    assert args.num_agents > 0, "Must set --num-agents > 0 when running this script!"
    assert (
        args.enable_new_api_stack
    ), "Must set --enable-new-api-stack when running this script!"
    assert (
        args.obs_builder
    ), "Must set --obs-builder <obs builder ID> when running this script!"

    setup_func()
    if init_args is None:
        # Parse `-e KEY=VALUE` entries. Split on the first '=' only, so values may contain '='.
        env_vars = {}
        for entry in args.env_var or []:
            key, sep, value = entry.partition("=")
            if not sep:
                raise ValueError(f"Invalid --env_var entry {entry!r}: expected KEY=VALUE.")
            env_vars[key] = value
        init_args = {
            # https://docs.ray.io/en/latest/ray-core/handling-dependencies.html#runtime-environments
            "runtime_env": {
                "env_vars": env_vars,
                # https://docs.ray.io/en/latest/ray-observability/user-guides/configure-logging.html
                "worker_process_setup_hook": "flatland.ml.ray.examples.flatland_training_with_parameter_sharing.setup_func"
            },
            "ignore_reinit_error": True,
        }
        if args.ray_address is not None:
            init_args['address'] = args.ray_address

    # https://docs.ray.io/en/latest/ray-core/api/doc/ray.init.html
    ray.init(
        **init_args,
    )
    env_name = "flatland_env"
    # The env creator ignores the env config it receives and builds the env from the CLI options instead.
    register_env(env_name, lambda _: ray_env_generator(n_agents=args.num_agents, obs_builder_object=registry_get_input(args.obs_builder)()))

    # TODO could be extracted to cli - keep it low key as illustration only
    additional_training_config = {}
    if args.algo == "DQN":
        additional_training_config = {"replay_buffer_config": {
            "type": "MultiAgentEpisodeReplayBuffer",
        }}
    base_config = (
        # N.B. the warning `passive_env_checker.py:164: UserWarning: WARN: The obs returned by the `reset()` method was expecting numpy array dtype to be float32, actual type: float64`
        #   comes from ray.tune.registry._register_all() -->  import ray.rllib.algorithms.dreamerv3 as dreamerv3!
        get_trainable_cls(args.algo)
        .get_default_config()
        .environment(env_name)
        .multi_agent(
            policies={"p0"},
            # Parameter sharing: all agents map to the exact same policy.
            # (Parameter names chosen so they do not shadow the outer `args`.)
            policy_mapping_fn=(lambda agent_id, *_args, **_kwargs: "p0"),
        )
        .training(
            model={
                "vf_share_layers": True,
            },
            train_batch_size=args.train_batch_size_per_learner,
            **additional_training_config
        )
        .rl_module(
            rl_module_spec=MultiRLModuleSpec(
                rl_module_specs={"p0": RLModuleSpec()},
            )
        )
    )
    res = run_rllib_example_script_experiment(base_config, args)

    if res.num_errors > 0:
        raise AssertionError(f"{res.errors}")
    return res

Run training with PPO for one iteration with a reduced batch size and checkpointing

!echo $PWD
/Users/che/workspaces/flatland-rl-2/notebooks
algo = "PPO"
obid = "FlattenedNormalizedTreeObsForRailEnv_max_depth_3_50"
# Call `train()` directly from Python (rather than via the CLI) so that its return value is available in `results`.
results = train(
    parser.parse_args(
        [
            "--num-agents", "2",
            "--obs-builder", obid,
            "--algo", algo,
            "--stop-iters", "1",
            "--train-batch-size-per-learner", "200",
            "--checkpoint-freq", "1",
        ]
    )
)
2025-03-21 16:51:26,759	INFO worker.py:1832 -- Started a local Ray instance. View the dashboard at http://127.0.0.1:8265 
/Users/che/Miniconda3/miniconda3/envs/flatland/lib/python3.12/site-packages/gymnasium/spaces/box.py:235: UserWarning: WARN: Box low's precision lowered by casting to float32, current low.dtype=float64
  gym.logger.warn(
/Users/che/Miniconda3/miniconda3/envs/flatland/lib/python3.12/site-packages/gymnasium/spaces/box.py:305: UserWarning: WARN: Box high's precision lowered by casting to float32, current high.dtype=float64
  gym.logger.warn(
/Users/che/Miniconda3/miniconda3/envs/flatland/lib/python3.12/site-packages/gymnasium/utils/passive_env_checker.py:134: UserWarning: WARN: The obs returned by the `reset()` method was expecting numpy array dtype to be float32, actual type: float64
  logger.warn(
/Users/che/Miniconda3/miniconda3/envs/flatland/lib/python3.12/site-packages/gymnasium/utils/passive_env_checker.py:158: UserWarning: WARN: The obs returned by the `reset()` method is not within the observation space.
  logger.warn(f"{pre} is not within the observation space.")
2025-03-21 16:51:28,339	INFO worker.py:1672 -- Calling ray.init() again after it has already been called.
2025-03-21 16:51:28,384	WARNING algorithm_config.py:4726 -- You are running PPO on the new API stack! This is the new default behavior for this algorithm. If you don't want to use the new API stack, set `config.api_stack(enable_rl_module_and_learner=False,enable_env_runner_and_connector_v2=False)`. For a detailed migration guide, see here: https://docs.ray.io/en/master/rllib/new-api-stack-migration-guide.html
2025-03-21 16:51:28,421	WARNING algorithm_config.py:4726 -- You are running PPO on the new API stack! This is the new default behavior for this algorithm. If you don't want to use the new API stack, set `config.api_stack(enable_rl_module_and_learner=False,enable_env_runner_and_connector_v2=False)`. For a detailed migration guide, see here: https://docs.ray.io/en/master/rllib/new-api-stack-migration-guide.html
== Status ==
Current time: 2025-03-21 16:51:28 (running for 00:00:00.21)
Using FIFO scheduling algorithm.
Logical resource usage: 0/8 CPUs, 0/0 GPUs
Result logdir: /tmp/ray/session_2025-03-21_16-51-24_808186_84981/artifacts/2025-03-21_16-51-28/PPO_2025-03-21_16-51-28/driver_artifacts
Number of trials: 1/1 (1 PENDING)
+------------------------------+----------+-------+
| Trial name                   | status   | loc   |
|------------------------------+----------+-------|
| PPO_flatland_env_59ef6_00000 | PENDING  |       |
+------------------------------+----------+-------+


== Status ==
Current time: 2025-03-21 16:51:33 (running for 00:00:05.31)
Using FIFO scheduling algorithm.
Logical resource usage: 0/8 CPUs, 0/0 GPUs
Result logdir: /tmp/ray/session_2025-03-21_16-51-24_808186_84981/artifacts/2025-03-21_16-51-28/PPO_2025-03-21_16-51-28/driver_artifacts
Number of trials: 1/1 (1 PENDING)
+------------------------------+----------+-------+
| Trial name                   | status   | loc   |
|------------------------------+----------+-------|
| PPO_flatland_env_59ef6_00000 | PENDING  |       |
+------------------------------+----------+-------+
(raylet) [2025-03-21 16:51:36,766 E 85004 3531068] file_system_monitor.cc:116: /tmp/ray/session_2025-03-21_16-51-24_808186_84981 is over 95% full, available space: 86.4491 GB; capacity: 1858.19 GB. Object creation will fail if spilling is required.
== Status ==
Current time: 2025-03-21 16:51:38 (running for 00:00:10.35)
Using FIFO scheduling algorithm.
Logical resource usage: 3.0/8 CPUs, 0/0 GPUs
Result logdir: /tmp/ray/session_2025-03-21_16-51-24_808186_84981/artifacts/2025-03-21_16-51-28/PPO_2025-03-21_16-51-28/driver_artifacts
Number of trials: 1/1 (1 PENDING)
+------------------------------+----------+-------+
| Trial name                   | status   | loc   |
|------------------------------+----------+-------|
| PPO_flatland_env_59ef6_00000 | PENDING  |       |
+------------------------------+----------+-------+
(PPO pid=85022) 2025-03-21 16:51:39,841	WARNING algorithm_config.py:4726 -- You are running PPO on the new API stack! This is the new default behavior for this algorithm. If you don't want to use the new API stack, set `config.api_stack(enable_rl_module_and_learner=False,enable_env_runner_and_connector_v2=False)`. For a detailed migration guide, see here: https://docs.ray.io/en/master/rllib/new-api-stack-migration-guide.html
== Status ==
Current time: 2025-03-21 16:51:43 (running for 00:00:15.43)
Using FIFO scheduling algorithm.
Logical resource usage: 3.0/8 CPUs, 0/0 GPUs
Result logdir: /tmp/ray/session_2025-03-21_16-51-24_808186_84981/artifacts/2025-03-21_16-51-28/PPO_2025-03-21_16-51-28/driver_artifacts
Number of trials: 1/1 (1 PENDING)
+------------------------------+----------+-------+
| Trial name                   | status   | loc   |
|------------------------------+----------+-------|
| PPO_flatland_env_59ef6_00000 | PENDING  |       |
+------------------------------+----------+-------+
(MultiAgentEnvRunner pid=85026) /Users/che/workspaces/flatland-rl-2/flatland/envs/rail_generators.py:321: UserWarning: Could not set all required cities! Created 1/2
(MultiAgentEnvRunner pid=85026)   warnings.warn(city_warning)
(MultiAgentEnvRunner pid=85026) /Users/che/workspaces/flatland-rl-2/flatland/envs/rail_generators.py:217: UserWarning: [WARNING] Changing to Grid mode to place at least 2 cities.
(MultiAgentEnvRunner pid=85026)   warnings.warn("[WARNING] Changing to Grid mode to place at least 2 cities.")
(MultiAgentEnvRunner pid=85026) 2025-03-21 16:51:46,093	WARNING rl_module.py:419 -- Could not create a Catalog object for your RLModule! If you are not using the new API stack yet, make sure to switch it off in your config: `config.api_stack(enable_rl_module_and_learner=False, enable_env_runner_and_connector_v2=False)`. All algos use the new stack by default. Ignore this message, if your RLModule does not use a Catalog to build its sub-components.
(MultiAgentEnvRunner pid=85026) 2025-03-21 16:51:46,093	WARNING deprecation.py:50 -- DeprecationWarning: `RLModule(config=[RLModuleConfig object])` has been deprecated. Use `RLModule(observation_space=.., action_space=.., inference_only=.., model_config=.., catalog_class=..)` instead. This will raise an error in the future!
(MultiAgentEnvRunner pid=85026) agent_to_module_mapping.py: [Discrete(5), Discrete(5)]
(MultiAgentEnvRunner pid=85026) agent_to_module_mapping.py: [Box(0.0, 2.0, (1020,), float64), Box(0.0, 2.0, (1020,), float64)]
(MultiAgentEnvRunner pid=85026) agent_to_module_mapping.py: [Discrete(5), Discrete(5)]
(MultiAgentEnvRunner pid=85026) agent_to_module_mapping.py: [Box(0.0, 2.0, (1020,), float64), Box(0.0, 2.0, (1020,), float64)]
(MultiAgentEnvRunner pid=85026) agent_to_module_mapping.py: [Discrete(5), Discrete(5)]
(MultiAgentEnvRunner pid=85026) agent_to_module_mapping.py: [Box(0.0, 2.0, (1020,), float64), Box(0.0, 2.0, (1020,), float64)]
(PPO pid=85022) 2025-03-21 16:51:46,491	WARNING algorithm_config.py:4726 -- You are running PPO on the new API stack! This is the new default behavior for this algorithm. If you don't want to use the new API stack, set `config.api_stack(enable_rl_module_and_learner=False,enable_env_runner_and_connector_v2=False)`. For a detailed migration guide, see here: https://docs.ray.io/en/master/rllib/new-api-stack-migration-guide.html
(raylet) [2025-03-21 16:51:46,860 E 85004 3531068] file_system_monitor.cc:116: /tmp/ray/session_2025-03-21_16-51-24_808186_84981 is over 95% full, available space: 86.4488 GB; capacity: 1858.19 GB. Object creation will fail if spilling is required.
(PPO pid=85022) Install gputil for GPU system monitoring.
== Status ==
Current time: 2025-03-21 16:51:48 (running for 00:00:20.50)
Using FIFO scheduling algorithm.
Logical resource usage: 3.0/8 CPUs, 0/0 GPUs
Result logdir: /tmp/ray/session_2025-03-21_16-51-24_808186_84981/artifacts/2025-03-21_16-51-28/PPO_2025-03-21_16-51-28/driver_artifacts
Number of trials: 1/1 (1 RUNNING)
+------------------------------+----------+-----------------+
| Trial name                   | status   | loc             |
|------------------------------+----------+-----------------|
| PPO_flatland_env_59ef6_00000 | RUNNING  | 127.0.0.1:85022 |
+------------------------------+----------+-----------------+
2025-03-21 16:51:50,270	WARNING trial.py:863 -- Stopping criterion 'env_runners/episode_return_mean' not found in result dict! Available keys are ['num_training_step_calls_per_iteration', 'num_env_steps_sampled_lifetime', 'num_env_steps_sampled_lifetime_throughput', 'done', 'training_iteration', 'trial_id', 'date', 'timestamp', 'time_this_iter_s', 'time_total_s', 'pid', 'hostname', 'node_ip', 'time_since_restore', 'iterations_since_restore', 'timers/training_iteration', 'timers/restore_env_runners', 'timers/training_step', 'timers/env_runner_sampling_timer', 'timers/learner_update_timer', 'timers/synch_weights', 'env_runners/num_env_steps_sampled', 'env_runners/num_env_steps_sampled_lifetime', 'env_runners/env_to_module_sum_episodes_length_in', 'env_runners/env_to_module_sum_episodes_length_out', 'env_runners/num_env_steps_sampled_lifetime_throughput', 'fault_tolerance/num_healthy_workers', 'fault_tolerance/num_remote_worker_restarts', 'env_runner_group/actor_manager_num_outstanding_async_reqs', 'config/placement_strategy', 'config/num_gpus', 'config/_fake_gpus', 'config/num_cpus_for_main_process', 'config/eager_tracing', 'config/eager_max_retraces', 'config/torch_compile_learner', 'config/torch_compile_learner_what_to_compile', 'config/torch_compile_learner_dynamo_backend', 'config/torch_compile_learner_dynamo_mode', 'config/torch_compile_worker', 'config/torch_compile_worker_dynamo_backend', 'config/torch_compile_worker_dynamo_mode', 'config/torch_skip_nan_gradients', 'config/env', 'config/observation_space', 'config/action_space', 'config/clip_rewards', 'config/normalize_actions', 'config/clip_actions', 'config/_is_atari', 'config/disable_env_checking', 'config/render_env', 'config/action_mask_key', 'config/env_runner_cls', 'config/num_env_runners', 'config/num_envs_per_env_runner', 'config/gym_env_vectorize_mode', 'config/num_cpus_per_env_runner', 'config/num_gpus_per_env_runner', 'config/validate_env_runners_after_construction', 'config/episodes_to_numpy', 
'config/max_requests_in_flight_per_env_runner', 'config/sample_timeout_s', 'config/_env_to_module_connector', 'config/add_default_connectors_to_env_to_module_pipeline', 'config/_module_to_env_connector', 'config/add_default_connectors_to_module_to_env_pipeline', 'config/episode_lookback_horizon', 'config/rollout_fragment_length', 'config/batch_mode', 'config/compress_observations', 'config/remote_worker_envs', 'config/remote_env_batch_wait_ms', 'config/enable_tf1_exec_eagerly', 'config/sample_collector', 'config/preprocessor_pref', 'config/observation_filter', 'config/update_worker_filter_stats', 'config/use_worker_filter_stats', 'config/sampler_perf_stats_ema_coef', 'config/num_learners', 'config/num_gpus_per_learner', 'config/num_cpus_per_learner', 'config/num_aggregator_actors_per_learner', 'config/max_requests_in_flight_per_aggregator_actor', 'config/local_gpu_idx', 'config/max_requests_in_flight_per_learner', 'config/gamma', 'config/lr', 'config/grad_clip', 'config/grad_clip_by', 'config/_train_batch_size_per_learner', 'config/train_batch_size', 'config/num_epochs', 'config/minibatch_size', 'config/shuffle_batch_per_epoch', 'config/_learner_connector', 'config/add_default_connectors_to_learner_pipeline', 'config/_learner_class', 'config/callbacks_on_algorithm_init', 'config/callbacks_on_env_runners_recreated', 'config/callbacks_on_checkpoint_loaded', 'config/callbacks_on_environment_created', 'config/callbacks_on_episode_created', 'config/callbacks_on_episode_start', 'config/callbacks_on_episode_step', 'config/callbacks_on_episode_end', 'config/callbacks_on_evaluate_start', 'config/callbacks_on_evaluate_end', 'config/callbacks_on_sample_end', 'config/callbacks_on_train_result', 'config/explore', 'config/enable_rl_module_and_learner', 'config/enable_env_runner_and_connector_v2', 'config/count_steps_by', 'config/policy_map_capacity', 'config/policy_mapping_fn', 'config/policies_to_train', 'config/policy_states_are_swappable', 'config/observation_fn', 
'config/offline_data_class', 'config/input_read_method', 'config/input_read_episodes', 'config/input_read_sample_batches', 'config/input_read_batch_size', 'config/input_filesystem', 'config/input_compress_columns', 'config/input_spaces_jsonable', 'config/materialize_data', 'config/materialize_mapped_data', 'config/prelearner_class', 'config/prelearner_buffer_class', 'config/prelearner_module_synch_period', 'config/dataset_num_iters_per_learner', 'config/actions_in_input_normalized', 'config/postprocess_inputs', 'config/shuffle_buffer_size', 'config/output', 'config/output_compress_columns', 'config/output_max_file_size', 'config/output_max_rows_per_file', 'config/output_write_remaining_data', 'config/output_write_method', 'config/output_filesystem', 'config/output_write_episodes', 'config/offline_sampling', 'config/evaluation_interval', 'config/evaluation_duration', 'config/evaluation_duration_unit', 'config/evaluation_sample_timeout_s', 'config/evaluation_parallel_to_training', 'config/evaluation_force_reset_envs_before_iteration', 'config/evaluation_config', 'config/ope_split_batch_by_episode', 'config/evaluation_num_env_runners', 'config/in_evaluation', 'config/sync_filters_on_rollout_workers_timeout_s', 'config/keep_per_episode_custom_metrics', 'config/metrics_episode_collection_timeout_s', 'config/metrics_num_episodes_for_smoothing', 'config/min_time_s_per_iteration', 'config/min_train_timesteps_per_iteration', 'config/min_sample_timesteps_per_iteration', 'config/log_gradients', 'config/export_native_model_files', 'config/checkpoint_trainable_policies_only', 'config/logger_creator', 'config/logger_config', 'config/log_level', 'config/log_sys_usage', 'config/fake_sampler', 'config/seed', 'config/restart_failed_env_runners', 'config/ignore_env_runner_failures', 'config/max_num_env_runner_restarts', 'config/delay_between_env_runner_restarts_s', 'config/restart_failed_sub_environments', 'config/num_consecutive_env_runner_failures_tolerance', 
'config/env_runner_health_probe_timeout_s', 'config/env_runner_restore_timeout_s', 'config/_rl_module_spec', 'config/_validate_config', 'config/_use_msgpack_checkpoints', 'config/_torch_grad_scaler_class', 'config/_torch_lr_scheduler_classes', 'config/_tf_policy_handles_more_than_one_loss', 'config/_disable_preprocessor_api', 'config/_disable_action_flattening', 'config/_disable_initialize_loss_from_dummy_batch', 'config/_dont_auto_sync_env_runner_states', 'config/env_task_fn', 'config/enable_connectors', 'config/simple_optimizer', 'config/policy_map_cache', 'config/worker_cls', 'config/synchronize_filters', 'config/enable_async_evaluation', 'config/custom_async_evaluation_function', 'config/_enable_rl_module_api', 'config/auto_wrap_old_gym_envs', 'config/always_attach_evaluation_results', 'config/replay_sequence_length', 'config/_disable_execution_plan_api', 'config/use_critic', 'config/use_gae', 'config/use_kl_loss', 'config/kl_coeff', 'config/kl_target', 'config/vf_loss_coeff', 'config/entropy_coeff', 'config/clip_param', 'config/vf_clip_param', 'config/entropy_coeff_schedule', 'config/lr_schedule', 'config/sgd_minibatch_size', 'config/vf_share_layers', 'config/__stdout_file__', 'config/__stderr_file__', 'config/lambda', 'config/input', 'config/callbacks', 'config/create_env_on_driver', 'config/custom_eval_function', 'config/framework', 'perf/cpu_util_percent', 'perf/ram_util_percent', 'env_runners/num_agent_steps_sampled_lifetime/0', 'env_runners/num_agent_steps_sampled_lifetime/1', 'env_runners/num_module_steps_sampled/p0', 'env_runners/num_agent_steps_sampled/1', 'env_runners/num_agent_steps_sampled/0', 'env_runners/num_module_steps_sampled_lifetime/p0', 'learners/__all_modules__/num_non_trainable_parameters', 'learners/__all_modules__/learner_connector_sum_episodes_length_in', 'learners/__all_modules__/learner_connector_sum_episodes_length_out', 'learners/__all_modules__/num_env_steps_trained_lifetime', 'learners/__all_modules__/num_trainable_parameters', 
'learners/__all_modules__/num_module_steps_trained', 'learners/__all_modules__/num_env_steps_trained', 'learners/__all_modules__/num_module_steps_trained_lifetime', 'learners/__all_modules__/num_env_steps_trained_lifetime_throughput', 'learners/p0/num_trainable_parameters', 'learners/p0/weights_seq_no', 'learners/p0/entropy', 'learners/p0/module_train_batch_size_mean', 'learners/p0/policy_loss', 'learners/p0/total_loss', 'learners/p0/vf_loss_unclipped', 'learners/p0/num_module_steps_trained', 'learners/p0/mean_kl_loss', 'learners/p0/num_module_steps_trained_lifetime', 'learners/p0/curr_kl_coeff', 'learners/p0/num_non_trainable_parameters', 'learners/p0/vf_explained_var', 'learners/p0/default_optimizer_learning_rate', 'learners/p0/gradients_default_optimizer_global_norm', 'learners/p0/curr_entropy_coeff', 'learners/p0/diff_num_grad_updates_vs_sampler_policy', 'learners/p0/vf_loss', 'config/tf_session_args/intra_op_parallelism_threads', 'config/tf_session_args/inter_op_parallelism_threads', 'config/tf_session_args/log_device_placement', 'config/tf_session_args/allow_soft_placement', 'config/local_tf_session_args/intra_op_parallelism_threads', 'config/local_tf_session_args/inter_op_parallelism_threads', 'config/model/fcnet_hiddens', 'config/model/fcnet_activation', 'config/model/fcnet_weights_initializer', 'config/model/fcnet_weights_initializer_config', 'config/model/fcnet_bias_initializer', 'config/model/fcnet_bias_initializer_config', 'config/model/conv_filters', 'config/model/conv_activation', 'config/model/conv_kernel_initializer', 'config/model/conv_kernel_initializer_config', 'config/model/conv_bias_initializer', 'config/model/conv_bias_initializer_config', 'config/model/conv_transpose_kernel_initializer', 'config/model/conv_transpose_kernel_initializer_config', 'config/model/conv_transpose_bias_initializer', 'config/model/conv_transpose_bias_initializer_config', 'config/model/post_fcnet_hiddens', 'config/model/post_fcnet_activation', 
'config/model/post_fcnet_weights_initializer', 'config/model/post_fcnet_weights_initializer_config', 'config/model/post_fcnet_bias_initializer', 'config/model/post_fcnet_bias_initializer_config', 'config/model/free_log_std', 'config/model/log_std_clip_param', 'config/model/no_final_linear', 'config/model/vf_share_layers', 'config/model/use_lstm', 'config/model/max_seq_len', 'config/model/lstm_cell_size', 'config/model/lstm_use_prev_action', 'config/model/lstm_use_prev_reward', 'config/model/lstm_weights_initializer', 'config/model/lstm_weights_initializer_config', 'config/model/lstm_bias_initializer', 'config/model/lstm_bias_initializer_config', 'config/model/_time_major', 'config/model/use_attention', 'config/model/attention_num_transformer_units', 'config/model/attention_dim', 'config/model/attention_num_heads', 'config/model/attention_head_dim', 'config/model/attention_memory_inference', 'config/model/attention_memory_training', 'config/model/attention_position_wise_mlp_dim', 'config/model/attention_init_gru_gate_bias', 'config/model/attention_use_n_prev_actions', 'config/model/attention_use_n_prev_rewards', 'config/model/framestack', 'config/model/dim', 'config/model/grayscale', 'config/model/zero_mean', 'config/model/custom_model', 'config/model/custom_action_dist', 'config/model/custom_preprocessor', 'config/model/encoder_latent_dim', 'config/model/always_check_shapes', 'config/model/lstm_use_prev_action_reward', 'config/model/_use_default_native_models', 'config/model/_disable_preprocessor_api', 'config/model/_disable_action_flattening', 'config/_prior_exploration_config/type', 'config/policies/p0', 'env_runners/timers/connectors/AgentToModuleMapping', 'env_runners/timers/connectors/NumpyToTensor', 'env_runners/timers/connectors/UnBatchToIndividualItems', 'env_runners/timers/connectors/AddStatesFromEpisodesToBatch', 'env_runners/timers/connectors/TensorToNumpy', 'env_runners/timers/connectors/NormalizeAndClipActions', 
'env_runners/timers/connectors/ModuleToAgentUnmapping', 'env_runners/timers/connectors/AddTimeDimToBatchAndZeroPad', 'env_runners/timers/connectors/ListifyDataForVectorEnv', 'env_runners/timers/connectors/AddObservationsFromEpisodesToBatch', 'env_runners/timers/connectors/GetActions', 'env_runners/timers/connectors/BatchIndividualItems', 'env_runners/timers/connectors/RemoveSingleTsTimeRankFromBatch', 'config/tf_session_args/gpu_options/allow_growth', 'config/tf_session_args/device_count/CPU', 'learners/__all_modules__/timers/connectors/BatchIndividualItems', 'learners/__all_modules__/timers/connectors/AddOneTsToEpisodesAndTruncate', 'learners/__all_modules__/timers/connectors/AgentToModuleMapping', 'learners/__all_modules__/timers/connectors/NumpyToTensor', 'learners/__all_modules__/timers/connectors/GeneralAdvantageEstimation', 'learners/__all_modules__/timers/connectors/AddStatesFromEpisodesToBatch', 'learners/__all_modules__/timers/connectors/AddTimeDimToBatchAndZeroPad', 'learners/__all_modules__/timers/connectors/AddObservationsFromEpisodesToBatch', 'learners/__all_modules__/timers/connectors/AddColumnsFromEpisodesToTrainBatch']. If 'env_runners/episode_return_mean' is never reported, the run will continue until training is finished.

Trial Progress

Trial name env_runner_group env_runners fault_tolerance learners num_env_steps_sampled_lifetime num_env_steps_sampled_lifetime_throughput num_training_step_calls_per_iterationperf timers
PPO_flatland_env_59ef6_00000{'actor_manager_num_outstanding_async_reqs': 0}{'num_agent_steps_sampled_lifetime': {'0': 200, '1': 200}, 'timers': {'connectors': {'AgentToModuleMapping': 7.898451490889527e-06, 'NumpyToTensor': 0.0006514624725359729, 'UnBatchToIndividualItems': 3.3355799842430625e-05, 'AddStatesFromEpisodesToBatch': 6.196388367066876e-06, 'TensorToNumpy': 0.0002189852724555016, 'NormalizeAndClipActions': 0.00017074410228305476, 'ModuleToAgentUnmapping': 6.448348645382307e-06, 'AddTimeDimToBatchAndZeroPad': 2.2233997937464932e-05, 'ListifyDataForVectorEnv': 8.073932804185945e-06, 'AddObservationsFromEpisodesToBatch': 3.9197153741186344e-05, 'GetActions': 0.020396071205466673, 'BatchIndividualItems': 5.00980064539732e-05, 'RemoveSingleTsTimeRankFromBatch': 2.0477231243917827e-06}}, 'num_module_steps_sampled': {'p0': 400}, 'num_env_steps_sampled': 200, 'num_agent_steps_sampled': {'1': 200, '0': 200}, 'num_module_steps_sampled_lifetime': {'p0': 400}, 'num_env_steps_sampled_lifetime': 200, 'env_to_module_sum_episodes_length_in': 37.23720178604972, 'env_to_module_sum_episodes_length_out': 37.23720178604972, 'num_env_steps_sampled_lifetime_throughput': nan}{'num_healthy_workers': 2, 'num_remote_worker_restarts': 0}{'__all_modules__': {'timers': {'connectors': {'BatchIndividualItems': 0.009110875020269305, 'AddOneTsToEpisodesAndTruncate': 0.0038188330072443932, 'AgentToModuleMapping': 0.000297874998068437, 'NumpyToTensor': 0.0008266670047305524, 'GeneralAdvantageEstimation': 0.017291625001234934, 'AddStatesFromEpisodesToBatch': 9.374984074383974e-06, 'AddTimeDimToBatchAndZeroPad': 3.024999750778079e-05, 'AddObservationsFromEpisodesToBatch': 8.495798101648688e-05, 'AddColumnsFromEpisodesToTrainBatch': 0.008209916995838284}}, 'num_non_trainable_parameters': 0, 'learner_connector_sum_episodes_length_in': 200, 'learner_connector_sum_episodes_length_out': 200, 'num_env_steps_trained_lifetime': 200, 'num_trainable_parameters': 655878, 'num_module_steps_trained': 
404, 'num_env_steps_trained': 200, 'num_module_steps_trained_lifetime': 404, 'num_env_steps_trained_lifetime_throughput': nan}, 'p0': {'num_trainable_parameters': 655878, 'weights_seq_no': 1.0, 'entropy': 1.5917836427688599, 'module_train_batch_size_mean': 404, 'policy_loss': -0.15537874400615692, 'total_loss': -0.15024973452091217, 'vf_loss_unclipped': 1.6708769123852107e-07, 'num_module_steps_trained': 404, 'mean_kl_loss': 0.025644313544034958, 'num_module_steps_trained_lifetime': 404, 'curr_kl_coeff': 0.30000001192092896, 'num_non_trainable_parameters': 0, 'vf_explained_var': 0.9406133890151978, 'default_optimizer_learning_rate': 5e-05, 'gradients_default_optimizer_global_norm': 0.9358460307121277, 'curr_entropy_coeff': 0.0, 'diff_num_grad_updates_vs_sampler_policy': 0.0, 'vf_loss': 1.6708769123852107e-07}} 200 nan 1{'cpu_util_percent': 67.225, 'ram_util_percent': 96.525}{'training_iteration': 2.536862332985038, 'restore_env_runners': 2.4582986952736974e-05, 'training_step': 2.5366032080200966, 'env_runner_sampling_timer': 0.5369193330116104, 'learner_update_timer': 1.9899427500204183, 'synch_weights': 0.009463166003115475}
2025-03-21 16:51:50,353	INFO tune.py:1009 -- Wrote the latest version of all result files and experiment state to '/Users/che/ray_results/PPO_2025-03-21_16-51-28' in 0.0099s.
(PPO(env=flatland_env; env-runners=2; learners=0; multi-agent=True) pid=85022) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/Users/che/ray_results/PPO_2025-03-21_16-51-28/PPO_flatland_env_59ef6_00000_0_2025-03-21_16-51-28/checkpoint_000000)
2025-03-21 16:51:50,489	INFO tune.py:1041 -- Total run time: 22.13 seconds (21.97 seconds for the tuning loop).
== Status ==
Current time: 2025-03-21 16:51:50 (running for 00:00:21.98)
Using FIFO scheduling algorithm.
Logical resource usage: 3.0/8 CPUs, 0/0 GPUs
Result logdir: /tmp/ray/session_2025-03-21_16-51-24_808186_84981/artifacts/2025-03-21_16-51-28/PPO_2025-03-21_16-51-28/driver_artifacts
Number of trials: 1/1 (1 TERMINATED)
+------------------------------+------------+-----------------+--------+------------------+------+
| Trial name                   | status     | loc             |   iter |   total time (s) |   ts |
|------------------------------+------------+-----------------+--------+------------------+------|
| PPO_flatland_env_59ef6_00000 | TERMINATED | 127.0.0.1:85022 |      1 |          2.54778 |  200 |
+------------------------------+------------+-----------------+--------+------------------+------+
(MultiAgentEnvRunner pid=85027) /Users/che/workspaces/flatland-rl-2/flatland/envs/rail_generators.py:321: UserWarning: Could not set all required cities! Created 1/2 [repeated 4x across cluster]
(MultiAgentEnvRunner pid=85027)   warnings.warn(city_warning) [repeated 4x across cluster]
(MultiAgentEnvRunner pid=85027) /Users/che/workspaces/flatland-rl-2/flatland/envs/rail_generators.py:217: UserWarning: [WARNING] Changing to Grid mode to place at least 2 cities. [repeated 4x across cluster]
(MultiAgentEnvRunner pid=85027)   warnings.warn("[WARNING] Changing to Grid mode to place at least 2 cities.") [repeated 4x across cluster]
(PPO pid=85022) 2025-03-21 16:51:46,512	WARNING rl_module.py:419 -- Could not create a Catalog object for your RLModule! If you are not using the new API stack yet, make sure to switch it off in your config: `config.api_stack(enable_rl_module_and_learner=False, enable_env_runner_and_connector_v2=False)`. All algos use the new stack by default. Ignore this message, if your RLModule does not use a Catalog to build its sub-components. [repeated 3x across cluster]
(PPO pid=85022) 2025-03-21 16:51:46,472	WARNING deprecation.py:50 -- DeprecationWarning: `RLModule(config=[RLModuleConfig object])` has been deprecated. Use `RLModule(observation_space=.., action_space=.., inference_only=.., model_config=.., catalog_class=..)` instead. This will raise an error in the future! [repeated 2x across cluster]
(PPO pid=85022) agent_to_module_mapping.py: [Box(0.0, 2.0, (1020,), float64), Box(0.0, 2.0, (1020,), float64)] [repeated 12x across cluster] (Ray deduplicates logs by default. Set RAY_DEDUP_LOGS=0 to disable log deduplication, or see https://docs.ray.io/en/master/ray-observability/user-guides/configure-logging.html#log-deduplication for more options.)

Rollout from Checkpoint#

# Create the argument parser of the rollout/inference CLI (its options are shown by `--help` below).
parser = add_flatland_inference_with_random_policy_args()

Inspect Rollout CLI Options#

!python -m flatland.ml.ray.examples.flatland_inference_with_random_policy --help
usage: flatland_inference_with_random_policy.py [-h] [--num-agents NUM_AGENTS]
                                                [--obs-builder OBS_BUILDER]
                                                [--num-episodes-during-inference NUM_EPISODES_DURING_INFERENCE]
                                                [--policy-id POLICY_ID]
                                                [--cp CP]

options:
  -h, --help            show this help message and exit
  --num-agents NUM_AGENTS
  --obs-builder OBS_BUILDER
  --num-episodes-during-inference NUM_EPISODES_DURING_INFERENCE
                        Number of episodes to do inference over (after
                        restoring from a checkpoint).
  --policy-id POLICY_ID
  --cp CP

Inspect Rollout CLI Code#

%pycat inspect.getsource(rollout)
def rollout(args: Namespace):
    """Run inference episodes in a Flatland env and print the total reward of each episode.

    Actions come from the RLModule stored in the checkpoint ``args.cp`` (module id
    ``args.policy_id``) or, if no checkpoint is given, from a ``RandomRLModule``.

    Args:
        args: parsed CLI namespace; reads ``num_agents``, ``obs_builder``, ``cp``,
            ``policy_id`` and ``num_episodes_during_inference``.
    """
    # Create an env to do inference in.
    env = ray_env_generator(n_agents=args.num_agents, obs_builder_object=registry_get_input(args.obs_builder)())
    obs, _ = env.reset()

    num_episodes = 0
    episode_return = 0.0

    if args.cp is not None:
        # The trained module lives in the learner's rl_module sub-directory of the checkpoint.
        cp = os.path.join(
            args.cp,
            "learner_group",
            "learner",
            "rl_module",
            args.policy_id,
        )
        rl_module = RLModule.from_checkpoint(cp)
    else:
        rl_module = RandomRLModule(action_space=env.action_space)

    while num_episodes < args.num_episodes_during_inference:
        # Agent ids in the same order as the stacked observation rows, so that the
        # i-th row of the module output is mapped back to the i-th agent.
        agent_ids = list(obs.keys())
        if args.cp is not None:
            obss = np.stack([obs[agent_id] for agent_id in agent_ids])
            # Add a leading batch dim: (1, n_agents, *obs_shape).
            rl_module_out = rl_module.forward_inference({"obs": torch.from_numpy(obss).unsqueeze(0).float()})
            if Columns.ACTIONS in rl_module_out:
                actions = convert_to_numpy(rl_module_out[Columns.ACTIONS][0])
                action_dict = dict(zip(agent_ids, actions))
            else:
                # The module only returned logits: sample each agent's action from softmax(logits).
                logits = convert_to_numpy(rl_module_out[Columns.ACTION_DIST_INPUTS])[0]
                action_dict = {
                    agent_id: np.random.choice(len(RailEnvActions), p=softmax(agent_logits))
                    for agent_id, agent_logits in zip(agent_ids, logits)
                }
        else:
            action_dict = rl_module.forward_inference({"obs": np.expand_dims(obs, 0)})
            action_dict = {h: a[0] for h, a in action_dict['actions'].items()}

        obs, rewards, terminateds, truncateds, _ = env.step(action_dict)
        for _, v in rewards.items():
            episode_return += v
        # Is the episode `done`? -> Reset.
        if terminateds["__all__"] or truncateds["__all__"]:
            print(f"Episode done: Total reward = {episode_return}")
            # Keep the new episode's initial observations; otherwise the next step
            # would be computed from the last observations of the finished episode.
            obs, _ = env.reset()
            num_episodes += 1
            episode_return = 0.0
    print(f"Done performing action inference through {num_episodes} Episodes")

Rollout on the best checkpoint from the previous training#

# Register the CLI observation-builder keys in this kernel (they were already registered at the top of the notebook).
register_flatland_ray_cli_observation_builders()
# Select the training result with the highest env-runner mean episode return.
# NOTE(review): `results` is presumably the Tune ResultGrid returned by the training cell above.
# Tune warned that this metric may never be reported; in that case there may be no best
# result/checkpoint to select — confirm before relying on `best_result.checkpoint`.
best_result = results.get_best_result(
    metric=f"{ENV_RUNNER_RESULTS}/{EPISODE_RETURN_MEAN}", mode="max"
)
# Run one inference episode in a subprocess, restoring module `p0` from the selected checkpoint.
# `obid` is presumably the observation-builder key used for training, defined in an earlier cell.
!python -m flatland.ml.ray.examples.flatland_inference_with_random_policy --num-agents 2 --obs-builder {obid} --cp {best_result.checkpoint.path} --policy-id p0  --num-episodes-during-inference 1
/Users/che/workspaces/flatland-rl-2/flatland/envs/rail_generators.py:321: UserWarning: Could not set all required cities! Created 1/2
  warnings.warn(city_warning)
/Users/che/workspaces/flatland-rl-2/flatland/envs/rail_generators.py:217: UserWarning: [WARNING] Changing to Grid mode to place at least 2 cities.
  warnings.warn("[WARNING] Changing to Grid mode to place at least 2 cities.")
2025-03-21 16:52:03,227	WARNING deprecation.py:50 -- DeprecationWarning: `RLModule(config=[RLModuleConfig object])` has been deprecated. Use `RLModule(observation_space=.., action_space=.., inference_only=.., model_config=.., catalog_class=..)` instead. This will raise an error in the future!
Episode done: Total reward = -230.0
Done performing action inference through 1 Episodes