|
1 | | -""" |
2 | | -Example showing how you can use your trained policy for inference |
3 | | -(computing actions) in an environment. |
| 1 | +"""Example on how to compute actions in production on an already trained policy. |
| 2 | +
|
| 3 | +This example uses the simplest setup possible: An RLModule (policy net) recovered |
| 4 | +from a checkpoint and a manual env-loop (CartPole-v1). No ConnectorV2s or EnvRunners are |
| 5 | +used in this example. |
| 6 | +
|
| 7 | +This example shows: |
| 8 | + - how to extract a single-agent RLModule (our policy network) from an already |
| 9 | + existing checkpoint. |
| 10 | + - how to set up this recovered policy net for action computations (with or |
| 11 | + without exploration). |
| 12 | + - how to run the policy through a very simple gymnasium-based env-loop (see the |
| 13 | + sketch below), without using RLlib's ConnectorV2s or EnvRunners. |
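| | +
|
| | +At its core, the inference part of this example boils down to the following |
| | +pattern (a minimal sketch mirroring the code further below; `checkpoint_dir` and |
| | +`obs_batch`, a batched torch observation, are placeholders): |
| | +
|
| | + rl_module = RLModule.from_checkpoint( |
| | + os.path.join(checkpoint_dir, "learner", "module_state", DEFAULT_MODULE_ID) |
| | + ) |
| | + out = rl_module.forward_inference({Columns.OBS: obs_batch}) |
| | + action_logits = out[Columns.ACTION_DIST_INPUTS] |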
| 14 | +
|
| 15 | +
|
| 16 | +How to run this script |
| 17 | +---------------------- |
| 18 | +`python [script file name].py --enable-new-api-stack --stop-reward=200.0` |
| 19 | +
|
| 20 | +Use the `--explore-during-inference` option to switch on exploratory behavior |
| 21 | +during inference. Normally, you should not explore during inference unless your |
| 22 | +environment's optimal policy is stochastic. |
| 23 | +Use the `--num-episodes-during-inference=[int]` option to set the number of |
| 24 | +episodes to run through during the inference phase using the restored RLModule. |
| 25 | +
|
| 26 | +For debugging, use the following additional command line options |
| 27 | +`--no-tune --num-env-runners=0` |
| 28 | +which should allow you to set breakpoints anywhere in the RLlib code and |
| 29 | +have the execution stop there for inspection and debugging. |
| 30 | +
|
| 31 | +Note that the GPU settings shown in this script also work when you are not |
| 32 | +running via Tune, i.e. when using the `--no-tune` command line option. |
| 33 | +
|
| 34 | +For logging to your WandB account, use: |
| 35 | +`--wandb-key=[your WandB API key] --wandb-project=[some project name] |
| 36 | +--wandb-run-name=[optional: WandB run name (within the defined project)]` |
| 37 | +
|
| 38 | +You can visualize experiment results in ~/ray_results using TensorBoard. |
4 | 39 |
|
5 | | -Includes options for LSTM-based models (--use-lstm), attention-net models |
6 | | -(--use-attention), and plain (non-recurrent) models. |
| 40 | +
|
| 41 | +Results to expect |
| 42 | +----------------- |
| 43 | +
|
| 44 | +For the training step - depending on your `--stop-reward` setting, you should see |
| 45 | +something similar to this: |
| 46 | +
|
| 47 | +Number of trials: 1/1 (1 TERMINATED) |
| 48 | ++-----------------------------+------------+-----------------+--------+ |
| 49 | +| Trial name | status | loc | iter | |
| 50 | +| | | | | |
| 51 | +|-----------------------------+------------+-----------------+--------+ |
| 52 | +| PPO_CartPole-v1_6660c_00000 | TERMINATED | 127.0.0.1:43566 | 8 | |
| 53 | ++-----------------------------+------------+-----------------+--------+ |
| 54 | ++------------------+------------------------+------------------------+ |
| 55 | +| total time (s) | num_env_steps_sample | num_env_steps_traine | |
| 56 | +| | d_lifetime | d_lifetime | |
| 57 | ++------------------+------------------------+------------------------+ |
| 58 | +| 21.0283 | 32000 | 32000 | |
| 59 | ++------------------+------------------------+------------------------+ |
| 60 | +
|
| 61 | +Then, after restoring the RLModule for the inference phase, your output should |
| 62 | +look similar to: |
| 63 | +
|
| 64 | +Training completed. Restoring new RLModule for action inference. |
| 65 | +Episode done: Total reward = 500.0 |
| 66 | +Episode done: Total reward = 500.0 |
| 67 | +Episode done: Total reward = 500.0 |
| 68 | +Episode done: Total reward = 500.0 |
| 69 | +Episode done: Total reward = 500.0 |
| 70 | +Episode done: Total reward = 500.0 |
| 71 | +Episode done: Total reward = 500.0 |
| 72 | +Episode done: Total reward = 500.0 |
| 73 | +Episode done: Total reward = 500.0 |
| 74 | +Episode done: Total reward = 500.0 |
| 75 | +Done performing action inference through 10 Episodes |
7 | 76 | """ |
8 | | -import argparse |
9 | 77 | import gymnasium as gym |
| 78 | +import numpy as np |
10 | 79 | import os |
11 | 80 |
|
12 | | -import ray |
13 | | -from ray import air, tune |
14 | | -from ray.air.constants import TRAINING_ITERATION |
15 | | -from ray.rllib.algorithms.algorithm import Algorithm |
| 81 | +from ray.rllib.core import DEFAULT_MODULE_ID |
| 82 | +from ray.rllib.core.columns import Columns |
| 83 | +from ray.rllib.core.rl_module.rl_module import RLModule |
| 84 | +from ray.rllib.utils.framework import try_import_torch |
| 85 | +from ray.rllib.utils.numpy import convert_to_numpy, softmax |
16 | 86 | from ray.rllib.utils.metrics import ( |
17 | 87 | ENV_RUNNER_RESULTS, |
18 | 88 | EPISODE_RETURN_MEAN, |
19 | | - NUM_ENV_STEPS_SAMPLED_LIFETIME, |
| 89 | +) |
| 90 | +from ray.rllib.utils.test_utils import ( |
| 91 | + add_rllib_example_script_args, |
| 92 | + run_rllib_example_script_experiment, |
20 | 93 | ) |
21 | 94 | from ray.tune.registry import get_trainable_cls |
22 | 95 |
|
23 | | -parser = argparse.ArgumentParser() |
24 | | -parser.add_argument( |
25 | | - "--run", type=str, default="PPO", help="The RLlib-registered algorithm to use." |
26 | | -) |
27 | | -parser.add_argument("--num-cpus", type=int, default=0) |
28 | | -parser.add_argument( |
29 | | - "--framework", |
30 | | - choices=["tf", "tf2", "torch"], |
31 | | - default="torch", |
32 | | - help="The DL framework specifier.", |
33 | | -) |
34 | | -parser.add_argument( |
35 | | - "--stop-iters", |
36 | | - type=int, |
37 | | - default=200, |
38 | | - help="Number of iterations to train before we do inference.", |
39 | | -) |
40 | | -parser.add_argument( |
41 | | - "--stop-timesteps", |
42 | | - type=int, |
43 | | - default=100000, |
44 | | - help="Number of timesteps to train before we do inference.", |
45 | | -) |
46 | | -parser.add_argument( |
47 | | - "--stop-reward", |
48 | | - type=float, |
49 | | - default=150.0, |
50 | | - help="Reward at which we stop training before we do inference.", |
| 96 | +torch, _ = try_import_torch() |
| 97 | + |
| 98 | +parser = add_rllib_example_script_args(default_reward=200.0) |
| 99 | +parser.set_defaults( |
| 100 | + # Make sure that - by default - we produce checkpoints during training. |
| 101 | + checkpoint_freq=1, |
| 102 | + checkpoint_at_end=True, |
| 103 | + # Use CartPole-v1 by default. |
| 104 | + env="CartPole-v1", |
51 | 105 | ) |
52 | 106 | parser.add_argument( |
53 | 107 | "--explore-during-inference", |
|
59 | 113 | "--num-episodes-during-inference", |
60 | 114 | type=int, |
61 | 115 | default=10, |
62 | | - help="Number of episodes to do inference over after training.", |
| 116 | + help="Number of episodes to do inference over (after restoring from a checkpoint).", |
63 | 117 | ) |
64 | 118 |
|
| 119 | + |
65 | 120 | if __name__ == "__main__": |
66 | 121 | args = parser.parse_args() |
67 | 122 |
|
68 | | - ray.init(num_cpus=args.num_cpus or None) |
| 123 | + assert ( |
| 124 | + args.enable_new_api_stack |
| 125 | + ), "Must set --enable-new-api-stack when running this script!" |
69 | 126 |
|
70 | | - config = ( |
71 | | - get_trainable_cls(args.run) |
72 | | - .get_default_config() |
73 | | - .environment("FrozenLake-v1") |
74 | | - # Run with tracing enabled for tf2? |
75 | | - .framework(args.framework) |
76 | | - # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0. |
77 | | - .resources(num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "0"))) |
78 | | - ) |
79 | | - |
80 | | - stop = { |
81 | | - TRAINING_ITERATION: args.stop_iters, |
82 | | - NUM_ENV_STEPS_SAMPLED_LIFETIME: args.stop_timesteps, |
83 | | - f"{ENV_RUNNER_RESULTS}/{EPISODE_RETURN_MEAN}": args.stop_reward, |
84 | | - } |
| 127 | + base_config = get_trainable_cls(args.algo).get_default_config() |
85 | 128 |
|
86 | 129 | print("Training policy until desired reward/timesteps/iterations ...") |
87 | | - tuner = tune.Tuner( |
88 | | - args.run, |
89 | | - param_space=config.to_dict(), |
90 | | - run_config=air.RunConfig( |
91 | | - stop=stop, |
92 | | - verbose=2, |
93 | | - checkpoint_config=air.CheckpointConfig( |
94 | | - checkpoint_frequency=1, checkpoint_at_end=True |
95 | | - ), |
96 | | - ), |
97 | | - ) |
98 | | - results = tuner.fit() |
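| | + # Run the training via Ray Tune using the example utility imported above. The |
| | + # stop criteria (e.g. `--stop-reward`) come from the parsed command line args, |
| | + # and the returned object is the Tune result grid used below to pick the best |
| | + # checkpoint. |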
| 130 | + results = run_rllib_example_script_experiment(base_config, args) |
99 | 131 |
|
100 | | - print("Training completed. Restoring new Algorithm for action inference.") |
| 132 | + print("Training completed. Restoring new RLModule for action inference.") |
101 | 133 | # Get the best checkpoint (by mean episode return) from the above training run. |
102 | | - checkpoint = results.get_best_result().checkpoint |
| 134 | + best_result = results.get_best_result( |
| 135 | + metric=f"{ENV_RUNNER_RESULTS}/{EPISODE_RETURN_MEAN}", mode="max" |
| 136 | + ) |
103 | 137 | # Create a new RLModule and restore its state from that checkpoint. |
104 | | - algo = Algorithm.from_checkpoint(checkpoint) |
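| | + # The checkpoint written during training contains (among other things) the |
| | + # Learner state; the single-agent RLModule's weights live in the |
| | + # `learner/module_state/<module ID>` sub-directory, which is the path handed |
| | + # to `RLModule.from_checkpoint()` below. |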
| 138 | + rl_module = RLModule.from_checkpoint( |
| 139 | + os.path.join( |
| 140 | + best_result.checkpoint.path, |
| 141 | + "learner", |
| 142 | + "module_state", |
| 143 | + DEFAULT_MODULE_ID, |
| 144 | + ) |
| 145 | + ) |
105 | 146 |
|
106 | 147 | # Create the env to do inference in. |
107 | | - env = gym.make("FrozenLake-v1") |
| 148 | + env = gym.make(args.env) |
108 | 149 | obs, info = env.reset() |
109 | 150 |
|
110 | 151 | num_episodes = 0 |
111 | | - episode_reward = 0.0 |
| 152 | + episode_return = 0.0 |
112 | 153 |
|
113 | 154 | while num_episodes < args.num_episodes_during_inference: |
114 | | - # Compute an action (`a`). |
115 | | - a = algo.compute_single_action( |
116 | | - observation=obs, |
117 | | - explore=args.explore_during_inference, |
118 | | - policy_id="default_policy", # <- default value |
119 | | - ) |
| 155 | + # Compute an action using a B=1 observation "batch". |
| 156 | + input_dict = {Columns.OBS: torch.from_numpy(obs).unsqueeze(0)} |
| 157 | + # No exploration. |
| 158 | + if not args.explore_during_inference: |
| 159 | + rl_module_out = rl_module.forward_inference(input_dict) |
| 160 | + # Using exploration. |
| 161 | + else: |
| 162 | + rl_module_out = rl_module.forward_exploration(input_dict) |
| 163 | + |
| 164 | + # For the discrete action space used here, an RLModule normally only |
| 165 | + # produces action logits, from which we then have to sample an action. |
| 166 | + # However, you can also write custom RLModules that output actions |
| 167 | + # directly, performing the sampling step already inside their |
| 168 | + # `forward_...()` methods. |
| 169 | + logits = convert_to_numpy(rl_module_out[Columns.ACTION_DIST_INPUTS]) |
| 170 | + # Perform the sampling step in numpy for simplicity. |
| 171 | + action = np.random.choice(env.action_space.n, p=softmax(logits[0])) |
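| | + # Note: Even with exploration switched off, we still sample from the |
| | + # softmax distribution over the logits here. For fully deterministic |
| | + # (greedy) evaluation you could instead pick the argmax action, e.g. |
| | + # `action = int(np.argmax(logits[0]))`. |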
120 | 172 | # Send the computed action to the env. |
121 | | - obs, reward, done, truncated, _ = env.step(a) |
122 | | - episode_reward += reward |
| 173 | + obs, reward, terminated, truncated, _ = env.step(action) |
| 174 | + episode_return += reward |
123 | 175 | # Is the episode done (terminated or truncated)? -> Reset. |
124 | | - if done: |
125 | | - print(f"Episode done: Total reward = {episode_reward}") |
| 176 | + if terminated or truncated: |
| 177 | + print(f"Episode done: Total reward = {episode_return}") |
126 | 178 | obs, info = env.reset() |
127 | 179 | num_episodes += 1 |
128 | | - episode_reward = 0.0 |
129 | | - |
130 | | - algo.stop() |
| 180 | + episode_return = 0.0 |
131 | 181 |
|
132 | | - ray.shutdown() |
| 182 | + print(f"Done performing action inference through {num_episodes} Episodes") |