Merged
25 commits
20993f3
handle failing env step gracefully
kamil-kaczmarek Aug 13, 2025
ca8d68d
lint
kamil-kaczmarek Aug 13, 2025
d08aaed
Merge branch 'master' into kk/fix-failing-env-step
kamil-kaczmarek Aug 13, 2025
11dd9a3
Fix bool assignment for cases where terminated and truncated are both…
kamil-kaczmarek Aug 14, 2025
f77c13d
typos
kamil-kaczmarek Aug 14, 2025
4b9b8df
Merge branch 'master' into kk/fix-failing-env-step
kamil-kaczmarek Aug 14, 2025
c57b2f7
Merge branch 'master' into kk/fix-failing-env-step
kamil-kaczmarek Aug 15, 2025
b2ce4b0
Merge branch 'master' into kk/fix-failing-env-step
kamil-kaczmarek Aug 20, 2025
ffc9f90
Merge branch 'master' into kk/fix-failing-env-step
kamil-kaczmarek Sep 4, 2025
b7361e9
Merge branch 'master' into kk/fix-failing-env-step
kamil-kaczmarek Sep 10, 2025
4924d5f
Merge branch 'master' into kk/fix-failing-env-step
kamil-kaczmarek Sep 11, 2025
13c05c8
Merge branch 'master' into kk/fix-failing-env-step
kamil-kaczmarek Sep 12, 2025
9399dd7
Merge branch 'master' into kk/fix-failing-env-step
kamil-kaczmarek Sep 16, 2025
f996cfb
test fix
kamil-kaczmarek Sep 17, 2025
46fa378
lint
kamil-kaczmarek Sep 17, 2025
8085e5b
drop oldAPIStack param
kamil-kaczmarek Sep 17, 2025
48a4062
drop default options
kamil-kaczmarek Sep 17, 2025
92fa743
lint and imports
kamil-kaczmarek Sep 17, 2025
bc50ba0
Merge branch 'master' into kk/fix-failing-env-step
kamil-kaczmarek Sep 17, 2025
2d75129
typos, lint
kamil-kaczmarek Sep 17, 2025
4dad4dd
PEP8-like clean up
kamil-kaczmarek Sep 17, 2025
ae2ef40
better params, replace print statement with logger
kamil-kaczmarek Sep 17, 2025
5820d51
lint
kamil-kaczmarek Sep 17, 2025
96ebcab
Merge branch 'master' into kk/fix-failing-env-step
kamil-kaczmarek Sep 17, 2025
4c6b05c
Merge branch 'master' into kk/fix-failing-env-step
kamil-kaczmarek Sep 17, 2025
8 changes: 6 additions & 2 deletions rllib/env/env_runner.py
@@ -228,17 +228,21 @@ def _try_env_reset(
raise e

def _try_env_step(self, actions):
"""Tries stepping the env and - if an error orrurs - handles it gracefully."""
"""Tries stepping the env and - if an error occurs - handles it gracefully."""
try:
with self.metrics.log_time(ENV_STEP_TIMER):
results = self.env.step(actions)
return results
except Exception as e:
self.metrics.log_value(NUM_ENV_STEP_FAILURES_LIFETIME, 1, reduce="sum")

# @OldAPIStack (config.restart_failed_sub_environments)
if self.config.restart_failed_sub_environments:
if not isinstance(e, StepFailedRecreateEnvError):
logger.exception("Stepping the env resulted in an error!")
logger.exception(
"Stepping the env resulted in an error! The original error "
f"is: {e}"
)
# Recreate the env.
self.make_env()
# And return that the stepping failed. The caller will then handle
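The error-handling pattern in this hunk can be sketched roughly as follows. This is a simplified illustration with hypothetical names (only the idea of an ENV_STEP_FAILURE sentinel and of recreating the env on failure is taken from the diff), not the actual RLlib implementation:

import logging

logger = logging.getLogger(__name__)

# Hypothetical sentinel standing in for RLlib's ENV_STEP_FAILURE constant.
ENV_STEP_FAILURE = "env_step_failure"

class EnvRunnerSketch:
    """Minimal sketch: try stepping the env, recreate it on failure."""

    def __init__(self, make_env_fn):
        self.make_env_fn = make_env_fn
        self.env = make_env_fn()

    def _try_env_step(self, actions):
        try:
            return self.env.step(actions)
        except Exception as e:
            logger.exception(f"Stepping the env resulted in an error! The original error is: {e}")
            # Recreate the env and tell the caller that this step failed;
            # the caller then force-resets and re-samples.
            self.env = self.make_env_fn()
            return ENV_STEP_FAILURE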
2 changes: 1 addition & 1 deletion rllib/env/multi_agent_env.py
@@ -444,7 +444,7 @@ def step(self, action_dict):
# an additional episode_done bool that covers cases where all agents are
# either terminated or truncated, but not all are truncated and not all are
# terminated. We can then get rid of the aweful `__all__` special keys!
terminated["__all__"] = len(self.terminateds) + len(self.truncateds) == len(
terminated["__all__"] = len(self.terminateds | self.truncateds) == len(
self.envs
)
truncated["__all__"] = len(self.truncateds) == len(self.envs)
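A quick worked example of why the set union matters here (assuming, as the surrounding code suggests, that `self.terminateds` and `self.truncateds` are sets of IDs): an ID that is both terminated and truncated would otherwise be counted twice.

envs = {"a0", "a1", "a2"}
terminateds = {"a0", "a1"}
truncateds = {"a1"}  # "a1" is both terminated and truncated

# Old check double-counts "a1" and wrongly declares everything done:
old_all_done = len(terminateds) + len(truncateds) == len(envs)  # 2 + 1 == 3 -> True (wrong)

# New check counts distinct done IDs; "a2" is still running:
new_all_done = len(terminateds | truncateds) == len(envs)  # 2 == 3 -> False (correct)

assert old_all_done and not new_all_done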
29 changes: 17 additions & 12 deletions rllib/env/multi_agent_env_runner.py
@@ -158,9 +158,9 @@ def sample(

Args:
num_timesteps: The number of timesteps to sample during this call.
Note that only one of `num_timetseps` or `num_episodes` may be provided.
Note that only one of `num_timesteps` or `num_episodes` may be provided.
num_episodes: The number of episodes to sample during this call.
Note that only one of `num_timetseps` or `num_episodes` may be provided.
Note that only one of `num_timesteps` or `num_episodes` may be provided.
explore: If True, will use the RLModule's `forward_exploration()`
method to compute actions. If False, will use the RLModule's
`forward_inference()` method. If None (default), will use the `explore`
@@ -183,8 +183,11 @@ def sample(
f"{self} doesn't have an env! Can't call `sample()` on it."
)

assert not (num_timesteps is not None and num_episodes is not None)

assert not (num_timesteps is not None and num_episodes is not None), (
"Provide "
"either `num_timesteps` or `num_episodes`. Both provided here:"
f"{num_timesteps=}, {num_episodes=}"
)
# Log time between `sample()` requests.
if self._time_after_sampling is not None:
self.metrics.log_value(
@@ -214,23 +217,22 @@ def sample(
* self.num_envs
)

# Sample n timesteps.
# Sample "num_timesteps" timesteps.
if num_timesteps is not None:
samples = self._sample(
num_timesteps=num_timesteps,
explore=explore,
random_actions=random_actions,
force_reset=force_reset,
)
# Sample m episodes.
# Sample "num_episodes" episodes.
elif num_episodes is not None:
samples = self._sample(
num_episodes=num_episodes,
explore=explore,
random_actions=random_actions,
)
# For complete episodes mode, sample as long as the number of timesteps
# done is smaller than the `train_batch_size`.
# For batch_mode="complete_episodes" (env_runners configuration), continue sampling as long as the number of timesteps done is smaller than the `train_batch_size`.
else:
samples = self._sample(
num_episodes=self.num_envs,
@@ -346,21 +348,24 @@ def _sample(
metrics_prefix_key=(MODULE_TO_ENV_CONNECTOR,),
)
# In case all environments had been terminated `to_module` will be
# empty and no actions are needed b/c we reset all environemnts.
# empty and no actions are needed b/c we reset all environments.
Contributor

I wonder why this happens now that the to_module is None. Can we debug another round and see where this happens? Then check the autoreset and the connector run (I know this is complex).

What should happen is: the env resets automatically; the init obs goes through the to_module connector pipeline and produces `to_module`, which can in turn be passed through the module and the to_env pipeline to produce an action.

Contributor Author

Let me investigate this more. This first happened last Thursday in the release tests. Will look into the code diff.

else:
to_env = {}
shared_data["vector_env_episodes_map"] = {}

# Extract the (vectorized) actions (to be sent to the env) from the
# module/connector output. Note that these actions are fully ready (e.g.
# already unsquashed/clipped) to be sent to the environment) and might not
# already unsquashed/clipped) to be sent to the environment and might not
# be identical to the actions produced by the RLModule/distribution, which
# are the ones stored permanently in the episode objects.
actions = to_env.pop(Columns.ACTIONS, [{} for _ in episodes])
actions_for_env = to_env.pop(Columns.ACTIONS_FOR_ENV, actions)
# Try stepping the environment.
results = self._try_env_step(actions_for_env)
if results == ENV_STEP_FAILURE:
logging.warning(
f"RLlib {self.__class__.__name__}: Environment step failed. Will force reset env(s) in this EnvRunner."
)
return self._sample(
num_timesteps=num_timesteps,
num_episodes=num_episodes,
@@ -372,7 +377,7 @@

call_on_episode_start = set()
# Store the data from the last environment step into the
# episodes for all sub-envrironments.
# episodes for all sub-environments.
for env_index in range(self.num_envs):
extra_model_outputs = defaultdict(dict)
# `to_env` returns a dictionary with column keys and
@@ -710,7 +715,7 @@ def set_state(self, state: StateDict) -> None:
# update.
weights_seq_no = state.get(WEIGHTS_SEQ_NO, 0)

# Only update the weigths, if this is the first synchronization or
# Only update the weights, if this is the first synchronization or
# if the weights of this `EnvRunner` lacks behind the actual ones.
if weights_seq_no == 0 or self._weights_seq_no < weights_seq_no:
rl_module_state = state[COMPONENT_RL_MODULE]
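The flow the reviewer describes in the inline conversation above (automatic reset, then the connector pipelines, then an action) can be summarized as a rough sketch. Every name and signature below is illustrative only and does not mirror RLlib's actual connector/RLModule APIs:

def act_after_autoreset(env, env_to_module, rl_module, module_to_env, episodes):
    """Rough sketch of the intended autoreset flow (illustrative names only)."""
    # 1) The env resets automatically and yields the initial observation(s).
    init_obs, init_infos = env.reset()
    # 2) The initial obs run through the env-to-module connector pipeline ...
    to_module = env_to_module(observations=init_obs, episodes=episodes)
    # 3) ... the RLModule computes its forward outputs ...
    module_out = rl_module.forward_exploration(to_module)
    # 4) ... and the module-to-env pipeline turns them into env-ready actions.
    to_env = module_to_env(module_out, episodes=episodes)
    return to_env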
26 changes: 13 additions & 13 deletions rllib/env/single_agent_env_runner.py
@@ -157,9 +157,9 @@ def sample(

Args:
num_timesteps: The number of timesteps to sample during this call.
Note that only one of `num_timetseps` or `num_episodes` may be provided.
Note that only one of `num_timesteps` or `num_episodes` may be provided.
num_episodes: The number of episodes to sample during this call.
Note that only one of `num_timetseps` or `num_episodes` may be provided.
Note that only one of `num_timesteps` or `num_episodes` may be provided.
explore: If True, will use the RLModule's `forward_exploration()`
method to compute actions. If False, will use the RLModule's
`forward_inference()` method. If None (default), will use the `explore`
@@ -328,7 +328,7 @@ def _sample(

# Extract the (vectorized) actions (to be sent to the env) from the
# module/connector output. Note that these actions are fully ready (e.g.
# already unsquashed/clipped) to be sent to the environment) and might not
# already unsquashed/clipped) to be sent to the environment and might not
# be identical to the actions produced by the RLModule/distribution, which
# are the ones stored permanently in the episode objects.
actions = to_env.pop(Columns.ACTIONS)
@@ -362,7 +362,7 @@ def _sample(

# Call `add_env_step()` method on episode.
else:
# Only increase ts when we actually stepped (not reset'd as a reset
# Only increase ts when we actually stepped (not reset as a reset
# does not count as a timestep).
ts += 1
episodes[env_index].add_env_step(
@@ -375,7 +375,7 @@
extra_model_outputs=extra_model_output,
)

# Env-to-module connector pass (cache results as we will do the RLModule
# Env-to-module connector pass cache results as we will do the RLModule
# forward pass only in the next `while`-iteration.
if self.module is not None:
self._cached_to_module = self._env_to_module(
@@ -442,7 +442,7 @@ def _sample(
]

for eps in self._episodes:
# Just started Episodes do not have to be returned. There is no data
# Just started episodes do not have to be returned. There is no data
# in them anyway.
if eps.t == 0:
continue
@@ -554,8 +554,8 @@ def set_state(self, state: StateDict) -> None:
# update.
weights_seq_no = state.get(WEIGHTS_SEQ_NO, 0)

# Only update the weigths, if this is the first synchronization or
# if the weights of this `EnvRunner` lacks behind the actual ones.
# Only update the weights, if this is the first synchronization or
# if the weights of this `EnvRunner` lag behind the actual ones.
if weights_seq_no == 0 or self._weights_seq_no < weights_seq_no:
rl_module_state = state[COMPONENT_RL_MODULE]
if isinstance(rl_module_state, ray.ObjectRef):
@@ -609,13 +609,13 @@ def get_checkpointable_components(self):
def assert_healthy(self):
"""Checks that self.__init__() has been completed properly.

Ensures that the instances has a `MultiRLModule` and an
Ensures that the instance has a `MultiRLModule` and an
environment defined.

Raises:
AssertionError: If the EnvRunner Actor has NOT been properly initialized.
"""
# Make sure, we have built our gym.vector.Env and RLModule properly.
# Make sure we have built our gym.vector.Env and RLModule properly.
assert self.env and hasattr(self, "module")

@override(EnvRunner)
Expand All @@ -626,8 +626,8 @@ def make_env(self) -> None:
`self.config.env_config`) and then call this method to create new environments
with the updated configuration.
"""
# If an env already exists, try closing it first (to allow it to properly
# cleanup).
# If an env already exists, try closing it first
# to allow it to properly clean up.
if self.env is not None:
try:
self.env.close()
@@ -854,7 +854,7 @@ def _log_episode_metrics(self, length, ret, sec):
# Log general episode metrics.
# Use the configured window, but factor in the parallelism of the EnvRunners.
# As a result, we only log the last `window / num_env_runners` steps here,
# b/c everything gets parallel-merged in the Algorithm process.
# because everything gets parallel-merged in the Algorithm process.
win = max(
1,
int(
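A small numeric illustration of the windowing in this hunk, with assumed values (the real window comes from the EnvRunner/metrics configuration):

window = 100          # assumed configured smoothing window
num_env_runners = 4   # assumed number of parallel EnvRunners

# Each EnvRunner only keeps its share of the window; the Algorithm process
# later parallel-merges the per-runner results back into the full window.
win = max(1, int(window / num_env_runners))
assert win == 25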