[RLlib] Fix failing env step in MultiAgentEnvRunner. #55567
Changes from all commits
```diff
@@ -158,9 +158,9 @@ def sample(
 
         Args:
             num_timesteps: The number of timesteps to sample during this call.
-                Note that only one of `num_timetseps` or `num_episodes` may be provided.
+                Note that only one of `num_timesteps` or `num_episodes` may be provided.
             num_episodes: The number of episodes to sample during this call.
-                Note that only one of `num_timetseps` or `num_episodes` may be provided.
+                Note that only one of `num_timesteps` or `num_episodes` may be provided.
             explore: If True, will use the RLModule's `forward_exploration()`
                 method to compute actions. If False, will use the RLModule's
                 `forward_inference()` method. If None (default), will use the `explore`
```
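For orientation, the docstring contract above means a caller passes at most one of the two sizing arguments per call. A hypothetical usage sketch (assuming an already-constructed `MultiAgentEnvRunner` instance named `runner`; not part of this diff):

```python
# Sample by timestep budget OR by episode count, never both.
episodes = runner.sample(num_timesteps=200, explore=True)   # 200 env steps
episodes = runner.sample(num_episodes=4, explore=False)     # 4 full episodes

# Passing both trips the assertion added below in this PR:
# runner.sample(num_timesteps=200, num_episodes=4)  # AssertionError
```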
```diff
@@ -183,8 +183,11 @@ def sample(
                 f"{self} doesn't have an env! Can't call `sample()` on it."
             )
 
-        assert not (num_timesteps is not None and num_episodes is not None)
+        assert not (num_timesteps is not None and num_episodes is not None), (
+            "Provide "
+            "either `num_timesteps` or `num_episodes`. Both provided here:"
+            f"{num_timesteps=}, {num_episodes=}"
+        )
 
         # Log time between `sample()` requests.
         if self._time_after_sampling is not None:
             self.metrics.log_value(
```
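The new assertion message relies on Python 3.8+ f-string debug specifiers (`{name=}`), which print both the variable name and its value. A quick standalone illustration, unrelated to RLlib itself:

```python
num_timesteps, num_episodes = 200, 4
# The "=" specifier renders "name=value" for each expression.
print(f"{num_timesteps=}, {num_episodes=}")
# Output: num_timesteps=200, num_episodes=4
```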
```diff
@@ -214,23 +217,22 @@ def sample(
                 * self.num_envs
             )
 
-        # Sample n timesteps.
+        # Sample "num_timesteps" timesteps.
         if num_timesteps is not None:
             samples = self._sample(
                 num_timesteps=num_timesteps,
                 explore=explore,
                 random_actions=random_actions,
                 force_reset=force_reset,
             )
-        # Sample m episodes.
+        # Sample "num_episodes" episodes.
         elif num_episodes is not None:
             samples = self._sample(
                 num_episodes=num_episodes,
                 explore=explore,
                 random_actions=random_actions,
             )
-        # For complete episodes mode, sample as long as the number of timesteps
-        # done is smaller than the `train_batch_size`.
+        # For batch_mode="complete_episodes" (env_runners configuration), continue sampling as long as the number of timesteps done is smaller than the `train_batch_size`.
         else:
             samples = self._sample(
                 num_episodes=self.num_envs,
```
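The rewritten comment points at the `batch_mode="complete_episodes"` flag of the env-runner configuration. As a rough sketch of where that flag is set, assuming the standard `AlgorithmConfig` builder API (the environment and runner count here are illustrative only):

```python
from ray.rllib.algorithms.ppo import PPOConfig

config = (
    PPOConfig()
    .environment("CartPole-v1")  # illustrative env choice
    # With "complete_episodes", each sample() call keeps stepping until all
    # started episodes are done, instead of truncating at a timestep budget.
    .env_runners(num_env_runners=2, batch_mode="complete_episodes")
)
```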
```diff
@@ -346,21 +348,24 @@ def _sample(
                     metrics_prefix_key=(MODULE_TO_ENV_CONNECTOR,),
                 )
             # In case all environments had been terminated `to_module` will be
-            # empty and no actions are needed b/c we reset all environemnts.
+            # empty and no actions are needed b/c we reset all environments.
```

> **Contributor:** I wonder why this happens now that the […] What should happen is: env resets automatically; init obs goes through the […]
>
> **Contributor (author):** Let me investigate this more. This first happened last Thursday in the release tests. Will look into the code diff.

```diff
             else:
                 to_env = {}
                 shared_data["vector_env_episodes_map"] = {}
 
             # Extract the (vectorized) actions (to be sent to the env) from the
             # module/connector output. Note that these actions are fully ready (e.g.
-            # already unsquashed/clipped) to be sent to the environment) and might not
+            # already unsquashed/clipped) to be sent to the environment and might not
             # be identical to the actions produced by the RLModule/distribution, which
             # are the ones stored permanently in the episode objects.
             actions = to_env.pop(Columns.ACTIONS, [{} for _ in episodes])
             actions_for_env = to_env.pop(Columns.ACTIONS_FOR_ENV, actions)
             # Try stepping the environment.
             results = self._try_env_step(actions_for_env)
             if results == ENV_STEP_FAILURE:
+                logging.warning(
+                    f"RLlib {self.__class__.__name__}: Environment step failed. Will force reset env(s) in this EnvRunner."
+                )
                 return self._sample(
                     num_timesteps=num_timesteps,
                     num_episodes=num_episodes,
```
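The failure branch depends on `_try_env_step()` returning the `ENV_STEP_FAILURE` sentinel rather than raising. A simplified, self-contained sketch of that sentinel pattern; this illustrates the idea only and is not RLlib's actual helper:

```python
import logging

# Sentinel signaling that stepping the (vector) env raised an exception.
ENV_STEP_FAILURE = "env_step_failure"

def try_env_step(env, actions):
    """Step `env`; return a sentinel instead of raising on failure."""
    try:
        return env.step(actions)
    except Exception as e:
        # Swallow the error and let the caller decide what to do, e.g.
        # force-reset and re-enter its sampling loop, as the diff above does.
        logging.warning(f"Env step failed with {e!r}.")
        return ENV_STEP_FAILURE
```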
```diff
@@ -372,7 +377,7 @@ def _sample(
 
         call_on_episode_start = set()
         # Store the data from the last environment step into the
-        # episodes for all sub-envrironments.
+        # episodes for all sub-environments.
         for env_index in range(self.num_envs):
             extra_model_outputs = defaultdict(dict)
             # `to_env` returns a dictionary with column keys and
```
```diff
@@ -710,7 +715,7 @@ def set_state(self, state: StateDict) -> None:
         # update.
         weights_seq_no = state.get(WEIGHTS_SEQ_NO, 0)
 
-        # Only update the weigths, if this is the first synchronization or
+        # Only update the weights, if this is the first synchronization or
         # if the weights of this `EnvRunner` lacks behind the actual ones.
         if weights_seq_no == 0 or self._weights_seq_no < weights_seq_no:
             rl_module_state = state[COMPONENT_RL_MODULE]
```
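For context, the `weights_seq_no` guard is a monotonic version check that drops stale weight broadcasts. A minimal sketch of the pattern, with hypothetical state keys that only loosely follow the diff:

```python
def set_state_sketch(runner, state: dict) -> None:
    # Version stamp attached by whoever broadcast these weights.
    weights_seq_no = state.get("weights_seq_no", 0)
    # Apply only on the first synchronization (seq_no == 0) or when the
    # incoming weights are strictly newer than what this runner holds.
    if weights_seq_no == 0 or runner._weights_seq_no < weights_seq_no:
        runner.module.set_state(state["rl_module"])  # hypothetical layout
        runner._weights_seq_no = weights_seq_no
```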