
Commit 29bc824

[RLlib] Fix failing env step in MultiAgentEnvRunner. (#55567)
## Why are these changes needed?

Fix failing release test: `learning_tests_multi_agent_cartpole_appo_multi_gpu`.

## Related issue number

<!-- For example: "Closes #1234" -->

## Checks

- [x] I've signed off every commit(by using the -s flag, i.e., `git commit -s`) in this PR.
- [x] I've run `scripts/format.sh` to lint the changes in this PR.
- [ ] I've included any doc changes needed for https://docs.ray.io/en/master/.
- [ ] I've added any new APIs to the API Reference. For example, if I added a method in Tune, I've added it in `doc/source/tune/api/` under the corresponding `.rst` file.
- [ ] I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/
- Testing Strategy
  - [x] Unit tests
  - [x] Release tests
  - [ ] This PR is not tested :(

---------

Signed-off-by: Kamil Kaczmarek <[email protected]>
Signed-off-by: Kamil Kaczmarek <[email protected]>
1 parent 9539786 commit 29bc824


6 files changed: +115, -99 lines


rllib/env/env_runner.py

Lines changed: 6 additions & 2 deletions
@@ -228,17 +228,21 @@ def _try_env_reset(
            raise e

    def _try_env_step(self, actions):
-       """Tries stepping the env and - if an error orrurs - handles it gracefully."""
+       """Tries stepping the env and - if an error occurs - handles it gracefully."""
        try:
            with self.metrics.log_time(ENV_STEP_TIMER):
                results = self.env.step(actions)
            return results
        except Exception as e:
            self.metrics.log_value(NUM_ENV_STEP_FAILURES_LIFETIME, 1, reduce="sum")

+           # @OldAPIStack (config.restart_failed_sub_environments)
            if self.config.restart_failed_sub_environments:
                if not isinstance(e, StepFailedRecreateEnvError):
-                   logger.exception("Stepping the env resulted in an error!")
+                   logger.exception(
+                       "Stepping the env resulted in an error! The original error "
+                       f"is: {e}"
+                   )
                # Recreate the env.
                self.make_env()
                # And return that the stepping failed. The caller will then handle
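The hunk above adds the original exception to the log message and, under the old-API-stack `restart_failed_sub_environments` option, recreates the env before signaling the failure to the caller. Below is a minimal, self-contained sketch of that flow; the `ToyRunner` class, its `env_factory` argument, and the string sentinel are illustrative assumptions, not RLlib code (the real `EnvRunner` also records the `NUM_ENV_STEP_FAILURES_LIFETIME` metric and special-cases `StepFailedRecreateEnvError`).

import logging

logger = logging.getLogger(__name__)

# Sentinel returned to the caller when stepping fails; the name mirrors the diff.
ENV_STEP_FAILURE = "env_step_failure"


class ToyRunner:
    """Stand-in runner illustrating the try-step / recreate-on-error pattern."""

    def __init__(self, env_factory, restart_failed_sub_environments=True):
        self._env_factory = env_factory
        self.restart_failed_sub_environments = restart_failed_sub_environments
        self.env = env_factory()

    def make_env(self):
        # Rebuild the (possibly crashed) environment from scratch.
        self.env = self._env_factory()

    def _try_env_step(self, actions):
        """Tries stepping the env and, on error, handles the failure gracefully."""
        try:
            return self.env.step(actions)
        except Exception as e:
            if self.restart_failed_sub_environments:
                # Include the original error in the log, as the patch does.
                logger.exception(
                    f"Stepping the env resulted in an error! The original error is: {e}"
                )
                # Recreate the env and signal the failure; the caller can then
                # force-reset and retry sampling.
                self.make_env()
                return ENV_STEP_FAILURE
            raise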

rllib/env/multi_agent_env.py

Lines changed: 1 addition & 1 deletion
@@ -444,7 +444,7 @@ def step(self, action_dict):
        # an additional episode_done bool that covers cases where all agents are
        # either terminated or truncated, but not all are truncated and not all are
        # terminated. We can then get rid of the aweful `__all__` special keys!
-       terminated["__all__"] = len(self.terminateds) + len(self.truncateds) == len(
+       terminated["__all__"] = len(self.terminateds | self.truncateds) == len(
            self.envs
        )
        truncated["__all__"] = len(self.truncateds) == len(self.envs)

rllib/env/multi_agent_env_runner.py

Lines changed: 17 additions & 12 deletions
@@ -158,9 +158,9 @@ def sample(

        Args:
            num_timesteps: The number of timesteps to sample during this call.
-               Note that only one of `num_timetseps` or `num_episodes` may be provided.
+               Note that only one of `num_timesteps` or `num_episodes` may be provided.
            num_episodes: The number of episodes to sample during this call.
-               Note that only one of `num_timetseps` or `num_episodes` may be provided.
+               Note that only one of `num_timesteps` or `num_episodes` may be provided.
            explore: If True, will use the RLModule's `forward_exploration()`
                method to compute actions. If False, will use the RLModule's
                `forward_inference()` method. If None (default), will use the `explore`
@@ -183,8 +183,11 @@ def sample(
                f"{self} doesn't have an env! Can't call `sample()` on it."
            )

-       assert not (num_timesteps is not None and num_episodes is not None)
-
+       assert not (num_timesteps is not None and num_episodes is not None), (
+           "Provide "
+           "either `num_timesteps` or `num_episodes`. Both provided here:"
+           f"{num_timesteps=}, {num_episodes=}"
+       )
        # Log time between `sample()` requests.
        if self._time_after_sampling is not None:
            self.metrics.log_value(
@@ -214,23 +217,22 @@ def sample(
                * self.num_envs
            )

-       # Sample n timesteps.
+       # Sample "num_timesteps" timesteps.
        if num_timesteps is not None:
            samples = self._sample(
                num_timesteps=num_timesteps,
                explore=explore,
                random_actions=random_actions,
                force_reset=force_reset,
            )
-       # Sample m episodes.
+       # Sample "num_episodes" episodes.
        elif num_episodes is not None:
            samples = self._sample(
                num_episodes=num_episodes,
                explore=explore,
                random_actions=random_actions,
            )
-       # For complete episodes mode, sample as long as the number of timesteps
-       # done is smaller than the `train_batch_size`.
+       # For batch_mode="complete_episodes" (env_runners configuration), continue sampling as long as the number of timesteps done is smaller than the `train_batch_size`.
        else:
            samples = self._sample(
                num_episodes=self.num_envs,
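The improved assertion message and the renamed comments make the three sampling modes explicit: at most one of `num_timesteps` or `num_episodes` may be set, and if neither is set the runner samples complete episodes until the train batch size is reached. A hedged sketch of that dispatch logic follows; the function body is a toy stand-in, not the actual `_sample()` implementation.

def sample(num_timesteps=None, num_episodes=None):
    # Both arguments set at once is a caller error; mirror the diff's message.
    assert not (num_timesteps is not None and num_episodes is not None), (
        "Provide either `num_timesteps` or `num_episodes`. Both provided here: "
        f"{num_timesteps=}, {num_episodes=}"
    )
    if num_timesteps is not None:
        return f"sample {num_timesteps} timesteps"       # num_timesteps mode
    elif num_episodes is not None:
        return f"sample {num_episodes} episodes"         # num_episodes mode
    else:
        return "sample complete episodes until train_batch_size is reached"


print(sample(num_timesteps=200))
print(sample(num_episodes=4))
print(sample())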
@@ -346,21 +348,24 @@ def _sample(
                    metrics_prefix_key=(MODULE_TO_ENV_CONNECTOR,),
                )
            # In case all environments had been terminated `to_module` will be
-           # empty and no actions are needed b/c we reset all environemnts.
+           # empty and no actions are needed b/c we reset all environments.
            else:
                to_env = {}
                shared_data["vector_env_episodes_map"] = {}

            # Extract the (vectorized) actions (to be sent to the env) from the
            # module/connector output. Note that these actions are fully ready (e.g.
-           # already unsquashed/clipped) to be sent to the environment) and might not
+           # already unsquashed/clipped) to be sent to the environment and might not
            # be identical to the actions produced by the RLModule/distribution, which
            # are the ones stored permanently in the episode objects.
            actions = to_env.pop(Columns.ACTIONS, [{} for _ in episodes])
            actions_for_env = to_env.pop(Columns.ACTIONS_FOR_ENV, actions)
            # Try stepping the environment.
            results = self._try_env_step(actions_for_env)
            if results == ENV_STEP_FAILURE:
+               logging.warning(
+                   f"RLlib {self.__class__.__name__}: Environment step failed. Will force reset env(s) in this EnvRunner."
+               )
                return self._sample(
                    num_timesteps=num_timesteps,
                    num_episodes=num_episodes,
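The new `logging.warning(...)` makes the previously silent retry visible: when `_try_env_step()` returns the `ENV_STEP_FAILURE` sentinel, the runner logs a warning and re-enters `_sample()`, force-resetting the env(s) per the warning text. A rough sketch of that control flow; `toy_sample` and `try_step` are illustrative stand-ins, and the exact retry arguments continue past the hunk boundary, so they are assumed here.

import logging
import random

ENV_STEP_FAILURE = "env_step_failure"  # sentinel, as in the diff


def try_step():
    # Toy step that fails half of the time.
    return ENV_STEP_FAILURE if random.random() < 0.5 else {"obs": 0}


def toy_sample(force_reset=False, _depth=0):
    results = try_step()
    if results == ENV_STEP_FAILURE and _depth < 3:
        logging.warning(
            "Environment step failed. Will force reset env(s) in this EnvRunner."
        )
        # Retry sampling after recreating/resetting the env(s). The _depth guard
        # only exists in this toy to avoid unbounded recursion.
        return toy_sample(force_reset=True, _depth=_depth + 1)
    return results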
@@ -372,7 +377,7 @@ def _sample(

            call_on_episode_start = set()
            # Store the data from the last environment step into the
-           # episodes for all sub-envrironments.
+           # episodes for all sub-environments.
            for env_index in range(self.num_envs):
                extra_model_outputs = defaultdict(dict)
                # `to_env` returns a dictionary with column keys and
@@ -710,7 +715,7 @@ def set_state(self, state: StateDict) -> None:
        # update.
        weights_seq_no = state.get(WEIGHTS_SEQ_NO, 0)

-       # Only update the weigths, if this is the first synchronization or
+       # Only update the weights, if this is the first synchronization or
        # if the weights of this `EnvRunner` lacks behind the actual ones.
        if weights_seq_no == 0 or self._weights_seq_no < weights_seq_no:
            rl_module_state = state[COMPONENT_RL_MODULE]

rllib/env/single_agent_env_runner.py

Lines changed: 13 additions & 13 deletions
@@ -157,9 +157,9 @@ def sample(

        Args:
            num_timesteps: The number of timesteps to sample during this call.
-               Note that only one of `num_timetseps` or `num_episodes` may be provided.
+               Note that only one of `num_timesteps` or `num_episodes` may be provided.
            num_episodes: The number of episodes to sample during this call.
-               Note that only one of `num_timetseps` or `num_episodes` may be provided.
+               Note that only one of `num_timesteps` or `num_episodes` may be provided.
            explore: If True, will use the RLModule's `forward_exploration()`
                method to compute actions. If False, will use the RLModule's
                `forward_inference()` method. If None (default), will use the `explore`
@@ -328,7 +328,7 @@ def _sample(

            # Extract the (vectorized) actions (to be sent to the env) from the
            # module/connector output. Note that these actions are fully ready (e.g.
-           # already unsquashed/clipped) to be sent to the environment) and might not
+           # already unsquashed/clipped) to be sent to the environment and might not
            # be identical to the actions produced by the RLModule/distribution, which
            # are the ones stored permanently in the episode objects.
            actions = to_env.pop(Columns.ACTIONS)
@@ -362,7 +362,7 @@ def _sample(

                # Call `add_env_step()` method on episode.
                else:
-                   # Only increase ts when we actually stepped (not reset'd as a reset
+                   # Only increase ts when we actually stepped (not reset as a reset
                    # does not count as a timestep).
                    ts += 1
                    episodes[env_index].add_env_step(
@@ -375,7 +375,7 @@ def _sample(
                        extra_model_outputs=extra_model_output,
                    )

-           # Env-to-module connector pass (cache results as we will do the RLModule
+           # Env-to-module connector pass cache results as we will do the RLModule
            # forward pass only in the next `while`-iteration.
            if self.module is not None:
                self._cached_to_module = self._env_to_module(
@@ -442,7 +442,7 @@ def _sample(
        ]

        for eps in self._episodes:
-           # Just started Episodes do not have to be returned. There is no data
+           # Just started episodes do not have to be returned. There is no data
            # in them anyway.
            if eps.t == 0:
                continue
@@ -554,8 +554,8 @@ def set_state(self, state: StateDict) -> None:
        # update.
        weights_seq_no = state.get(WEIGHTS_SEQ_NO, 0)

-       # Only update the weigths, if this is the first synchronization or
-       # if the weights of this `EnvRunner` lacks behind the actual ones.
+       # Only update the weights, if this is the first synchronization or
+       # if the weights of this `EnvRunner` lag behind the actual ones.
        if weights_seq_no == 0 or self._weights_seq_no < weights_seq_no:
            rl_module_state = state[COMPONENT_RL_MODULE]
            if isinstance(rl_module_state, ray.ObjectRef):
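The corrected comment in `set_state()` describes a simple monotonic gate: incoming weights are applied only on the first synchronization (a sequence number of 0 in the incoming state) or when this runner's local sequence number lags behind the incoming one, so stale broadcasts are ignored. A tiny sketch of that check; the helper name is made up for illustration, while the real code compares `self._weights_seq_no` against `state[WEIGHTS_SEQ_NO]`.

def should_update_weights(local_seq_no: int, incoming_seq_no: int) -> bool:
    """Mirror of the `weights_seq_no` gate: update on the first sync (0) or
    whenever the local weights lag behind the incoming ones."""
    return incoming_seq_no == 0 or local_seq_no < incoming_seq_no


assert should_update_weights(local_seq_no=0, incoming_seq_no=0)      # first sync
assert should_update_weights(local_seq_no=3, incoming_seq_no=5)      # lagging
assert not should_update_weights(local_seq_no=5, incoming_seq_no=4)  # stale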
@@ -609,13 +609,13 @@ def get_checkpointable_components(self):
    def assert_healthy(self):
        """Checks that self.__init__() has been completed properly.

-       Ensures that the instances has a `MultiRLModule` and an
+       Ensures that the instance has a `MultiRLModule` and an
        environment defined.

        Raises:
            AssertionError: If the EnvRunner Actor has NOT been properly initialized.
        """
-       # Make sure, we have built our gym.vector.Env and RLModule properly.
+       # Make sure we have built our gym.vector.Env and RLModule properly.
        assert self.env and hasattr(self, "module")

    @override(EnvRunner)
@@ -626,8 +626,8 @@ def make_env(self) -> None:
        `self.config.env_config`) and then call this method to create new environments
        with the updated configuration.
        """
-       # If an env already exists, try closing it first (to allow it to properly
-       # cleanup).
+       # If an env already exists, try closing it first
+       # to allow it to properly clean up.
        if self.env is not None:
            try:
                self.env.close()
@@ -854,7 +854,7 @@ def _log_episode_metrics(self, length, ret, sec):
        # Log general episode metrics.
        # Use the configured window, but factor in the parallelism of the EnvRunners.
        # As a result, we only log the last `window / num_env_runners` steps here,
-       # b/c everything gets parallel-merged in the Algorithm process.
+       # because everything gets parallel-merged in the Algorithm process.
        win = max(
            1,
            int(
