Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions rllib/algorithms/algorithm_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4452,6 +4452,12 @@ def get_rl_module_spec(
# If module_config_dict is not defined, set to our generic one.
if rl_module_spec.model_config is None:
rl_module_spec.model_config = self.model_config
# Otherwise we combine the two dictionaries where settings from the
# `RLModuleSpec` have higher priority.
else:
rl_module_spec.model_config = (
self.model_config | rl_module_spec._get_model_config()
)

if inference_only is not None:
rl_module_spec.inference_only = inference_only
Expand Down
14 changes: 13 additions & 1 deletion rllib/core/rl_module/torch/torch_rl_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,19 @@ def set_state(self, state: StateDict) -> None:
# these keys (strict=False). This is most likely due to `state` coming from
# an `inference_only=False` RLModule, while `self` is an `inference_only=True`
# RLModule.
self.load_state_dict(convert_to_torch_tensor(state), strict=False)
missing_keys, unexpected_keys = self.load_state_dict(
convert_to_torch_tensor(state), strict=False
)

# For inference_only modules, missing_keys should always be empty.
# If there are missing keys, it means the target module expects parameters
# that don't exist in the source, indicating an architecture mismatch.
if self.inference_only and missing_keys:
raise ValueError(
"Architecture mismatch detected when loading state into inference_only module! "
f"Missing parameters (not found in source state): {list(missing_keys)} "
"This usually indicates the learner and env-runner have different architectures."
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's please be a little more precise here.
-> What does having a different architecture mean here?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Basically this Error should give a good clue to the user about what they are doing wrong.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good spot, it should probably reference the layer names being difference.

)

@OverrideToImplementCustomLogic
@override(RLModule)
Expand Down