Commit 94937a1

[RLlib] Cleanup examples folder #14: Add example script for policy (RLModule) inference on new API stack. (#45831)
1 parent e48dba9 commit 94937a1

2 files changed: +139 -100 lines changed

rllib/BUILD

Lines changed: 2 additions & 13 deletions

@@ -2495,24 +2495,13 @@ py_test(
 
 # subdirectory: inference/
 # ....................................
-#@OldAPIStack
-py_test(
-    name = "examples/inference/policy_inference_after_training_tf",
-    main = "examples/inference/policy_inference_after_training.py",
-    tags = ["team:rllib", "exclusive", "examples"],
-    size = "medium",
-    srcs = ["examples/inference/policy_inference_after_training.py"],
-    args = ["--stop-iters=3", "--framework=tf"]
-)
-
-#@OldAPIStack
 py_test(
-    name = "examples/inference/policy_inference_after_training_torch",
+    name = "examples/inference/policy_inference_after_training",
     main = "examples/inference/policy_inference_after_training.py",
     tags = ["team:rllib", "exclusive", "examples"],
     size = "medium",
     srcs = ["examples/inference/policy_inference_after_training.py"],
-    args = ["--stop-iters=3", "--framework=torch"]
+    args = ["--enable-new-api-stack", "--stop-reward=100.0"]
 )
 
 #@OldAPIStack
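
The two old-API-stack targets (tf and torch) are folded into this single new-API-stack test. Mirroring the updated args line, the example can also be run directly; the script path below is taken from the srcs entry above, and running it from the root of a Ray checkout is an assumption:

python rllib/examples/inference/policy_inference_after_training.py --enable-new-api-stack --stop-reward=100.0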

rllib/examples/inference/policy_inference_after_training.py

Lines changed: 137 additions & 87 deletions

@@ -1,53 +1,107 @@
-"""
-Example showing how you can use your trained policy for inference
-(computing actions) in an environment.
+"""Example on how to compute actions in production on an already trained policy.
+
+This example uses the simplest setup possible: An RLModule (policy net) recovered
+from a checkpoint and a manual env-loop (CartPole-v1). No ConnectorV2s or EnvRunners are
+used in this example.
+
+This example shows ..
+    - .. how to use an already existing checkpoint to extract a single-agent RLModule
+    from (our policy network).
+    - .. how to setup this recovered policy net for action computations (with or without
+    using exploration).
+    - .. have the policy run through a very simple gymnasium based env-loop, w/o using
+    RLlib's ConnectorV2s or EnvRunners.
+
+
+How to run this script
+----------------------
+`python [script file name].py --enable-new-api-stack --stop-reward=200.0`
+
+Use the `--explore-during-inference` option to switch on exploratory behavior
+during inference. Normally, you should not explore during inference, though,
+unless your environment has a stochastic optimal solution.
+Use the `--num-episodes-during-inference=[int]` option to set the number of
+episodes to run through during the inference phase using the restored RLModule.
+
+For debugging, use the following additional command line options
+`--no-tune --num-env-runners=0`
+which should allow you to set breakpoints anywhere in the RLlib code and
+have the execution stop there for inspection and debugging.
+
+Note that the shown GPU settings in this script also work in case you are not
+running via tune, but instead are using the `--no-tune` command line option.
+
+For logging to your WandB account, use:
+`--wandb-key=[your WandB API key] --wandb-project=[some project name]
+--wandb-run-name=[optional: WandB run name (within the defined project)]`
+
+You can visualize experiment results in ~/ray_results using TensorBoard.
 
-Includes options for LSTM-based models (--use-lstm), attention-net models
-(--use-attention), and plain (non-recurrent) models.
+
+Results to expect
+-----------------
+
+For the training step - depending on your `--stop-reward` setting, you should see
+something similar to this:
+
+Number of trials: 1/1 (1 TERMINATED)
++-----------------------------+------------+-----------------+--------+
+| Trial name                  | status     | loc             |   iter |
+|                             |            |                 |        |
+|-----------------------------+------------+-----------------+--------+
+| PPO_CartPole-v1_6660c_00000 | TERMINATED | 127.0.0.1:43566 |      8 |
++-----------------------------+------------+-----------------+--------+
++------------------+------------------------+------------------------+
+|   total time (s) |   num_env_steps_sample |   num_env_steps_traine |
+|                  |             d_lifetime |             d_lifetime |
++------------------+------------------------+------------------------+
+|          21.0283 |                  32000 |                  32000 |
++------------------+------------------------+------------------------+
+
+Then, after restoring the RLModule for the inference phase, your output should
+look similar to:
+
+Training completed. Restoring new RLModule for action inference.
+Episode done: Total reward = 500.0
+Episode done: Total reward = 500.0
+Episode done: Total reward = 500.0
+Episode done: Total reward = 500.0
+Episode done: Total reward = 500.0
+Episode done: Total reward = 500.0
+Episode done: Total reward = 500.0
+Episode done: Total reward = 500.0
+Episode done: Total reward = 500.0
+Episode done: Total reward = 500.0
+Done performing action inference through 10 Episodes
 """
-import argparse
 import gymnasium as gym
+import numpy as np
 import os
 
-import ray
-from ray import air, tune
-from ray.air.constants import TRAINING_ITERATION
-from ray.rllib.algorithms.algorithm import Algorithm
+from ray.rllib.core import DEFAULT_MODULE_ID
+from ray.rllib.core.columns import Columns
+from ray.rllib.core.rl_module.rl_module import RLModule
+from ray.rllib.utils.framework import try_import_torch
+from ray.rllib.utils.numpy import convert_to_numpy, softmax
 from ray.rllib.utils.metrics import (
     ENV_RUNNER_RESULTS,
     EPISODE_RETURN_MEAN,
-    NUM_ENV_STEPS_SAMPLED_LIFETIME,
+)
+from ray.rllib.utils.test_utils import (
+    add_rllib_example_script_args,
+    run_rllib_example_script_experiment,
 )
 from ray.tune.registry import get_trainable_cls
 
-parser = argparse.ArgumentParser()
-parser.add_argument(
-    "--run", type=str, default="PPO", help="The RLlib-registered algorithm to use."
-)
-parser.add_argument("--num-cpus", type=int, default=0)
-parser.add_argument(
-    "--framework",
-    choices=["tf", "tf2", "torch"],
-    default="torch",
-    help="The DL framework specifier.",
-)
-parser.add_argument(
-    "--stop-iters",
-    type=int,
-    default=200,
-    help="Number of iterations to train before we do inference.",
-)
-parser.add_argument(
-    "--stop-timesteps",
-    type=int,
-    default=100000,
-    help="Number of timesteps to train before we do inference.",
-)
-parser.add_argument(
-    "--stop-reward",
-    type=float,
-    default=150.0,
-    help="Reward at which we stop training before we do inference.",
+torch, _ = try_import_torch()
+
+parser = add_rllib_example_script_args(default_reward=200.0)
+parser.set_defaults(
+    # Make sure that - by default - we produce checkpoints during training.
+    checkpoint_freq=1,
+    checkpoint_at_end=True,
+    # Use CartPole-v1 by default.
+    env="CartPole-v1",
 )
 parser.add_argument(
     "--explore-during-inference",
@@ -59,74 +113,70 @@
     "--num-episodes-during-inference",
     type=int,
     default=10,
-    help="Number of episodes to do inference over after training.",
+    help="Number of episodes to do inference over (after restoring from a checkpoint).",
 )
 
+
 if __name__ == "__main__":
     args = parser.parse_args()
 
-    ray.init(num_cpus=args.num_cpus or None)
+    assert (
+        args.enable_new_api_stack
+    ), "Must set --enable-new-api-stack when running this script!"
 
-    config = (
-        get_trainable_cls(args.run)
-        .get_default_config()
-        .environment("FrozenLake-v1")
-        # Run with tracing enabled for tf2?
-        .framework(args.framework)
-        # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
-        .resources(num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "0")))
-    )
-
-    stop = {
-        TRAINING_ITERATION: args.stop_iters,
-        NUM_ENV_STEPS_SAMPLED_LIFETIME: args.stop_timesteps,
-        f"{ENV_RUNNER_RESULTS}/{EPISODE_RETURN_MEAN}": args.stop_reward,
-    }
+    base_config = get_trainable_cls(args.algo).get_default_config()
 
     print("Training policy until desired reward/timesteps/iterations. ...")
-    tuner = tune.Tuner(
-        args.run,
-        param_space=config.to_dict(),
-        run_config=air.RunConfig(
-            stop=stop,
-            verbose=2,
-            checkpoint_config=air.CheckpointConfig(
-                checkpoint_frequency=1, checkpoint_at_end=True
-            ),
-        ),
-    )
-    results = tuner.fit()
+    results = run_rllib_example_script_experiment(base_config, args)
 
-    print("Training completed. Restoring new Algorithm for action inference.")
+    print("Training completed. Restoring new RLModule for action inference.")
     # Get the last checkpoint from the above training run.
-    checkpoint = results.get_best_result().checkpoint
+    best_result = results.get_best_result(
+        metric=f"{ENV_RUNNER_RESULTS}/{EPISODE_RETURN_MEAN}", mode="max"
+    )
     # Create new Algorithm and restore its state from the last checkpoint.
-    algo = Algorithm.from_checkpoint(checkpoint)
+    rl_module = RLModule.from_checkpoint(
+        os.path.join(
+            best_result.checkpoint.path,
+            "learner",
+            "module_state",
+            DEFAULT_MODULE_ID,
+        )
+    )
 
     # Create the env to do inference in.
-    env = gym.make("FrozenLake-v1")
+    env = gym.make(args.env)
     obs, info = env.reset()
 
     num_episodes = 0
-    episode_reward = 0.0
+    episode_return = 0.0
 
     while num_episodes < args.num_episodes_during_inference:
-        # Compute an action (`a`).
-        a = algo.compute_single_action(
-            observation=obs,
-            explore=args.explore_during_inference,
-            policy_id="default_policy",  # <- default value
-        )
+        # Compute an action using a B=1 observation "batch".
+        input_dict = {Columns.OBS: torch.from_numpy(obs).unsqueeze(0)}
+        # No exploration.
+        if not args.explore_during_inference:
+            rl_module_out = rl_module.forward_inference(input_dict)
+        # Using exploration.
+        else:
+            rl_module_out = rl_module.forward_exploration(input_dict)
+
+        # For discrete action spaces used here, normally, an RLModule "only"
+        # produces action logits, from which we then have to sample.
+        # However, you can also write custom RLModules that output actions
+        # directly, performing the sampling step already inside their
+        # `forward_...()` methods.
+        logits = convert_to_numpy(rl_module_out[Columns.ACTION_DIST_INPUTS])
+        # Perform the sampling step in numpy for simplicity.
+        action = np.random.choice(env.action_space.n, p=softmax(logits[0]))
         # Send the computed action `a` to the env.
-        obs, reward, done, truncated, _ = env.step(a)
-        episode_reward += reward
+        obs, reward, terminated, truncated, _ = env.step(action)
+        episode_return += reward
         # Is the episode `done`? -> Reset.
-        if done:
-            print(f"Episode done: Total reward = {episode_reward}")
+        if terminated or truncated:
+            print(f"Episode done: Total reward = {episode_return}")
             obs, info = env.reset()
             num_episodes += 1
-            episode_reward = 0.0
-
-    algo.stop()
+            episode_return = 0.0
 
-    ray.shutdown()
+    print(f"Done performing action inference through {num_episodes} Episodes")
