-
-
Notifications
You must be signed in to change notification settings - Fork 5.9k
[GRPO] Try returning hidden statex for GRPO #5142
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
6be5ab1
f03a341
3d1b4e8
bdcb08e
cf6d545
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -539,6 +539,193 @@ def wrapped(self, *args, **kwargs): | |||||||||||||||||||||||||||||
| trainer_cls._generate_and_score_completions = wrapped | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| _UNSLOTH_RETURN_HIDDEN_STATES_SUPPORT_MARKER = ( | ||||||||||||||||||||||||||||||
| "__UNSLOTH_SUPPORTS_RETURN_HIDDEN_STATES__" | ||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||
| _UNSLOTH_GRPO_HIDDEN_STATES_WRAPPED_ATTR = "_unsloth_grpo_hidden_states_forward_wrapped" | ||||||||||||||||||||||||||||||
| _UNSLOTH_GRPO_HIDDEN_STATES_WARNING_ATTR = "_unsloth_grpo_hidden_states_warning_issued" | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| def _grpo_hidden_states_wrap_target(model): | ||||||||||||||||||||||||||||||
| if model is None: | ||||||||||||||||||||||||||||||
| return None | ||||||||||||||||||||||||||||||
| get_base_model = getattr(model, "get_base_model", None) | ||||||||||||||||||||||||||||||
| if callable(get_base_model): | ||||||||||||||||||||||||||||||
| base_model = get_base_model() | ||||||||||||||||||||||||||||||
| if base_model is not None and base_model is not model: | ||||||||||||||||||||||||||||||
| return base_model | ||||||||||||||||||||||||||||||
| for attr in ("base_model", "model"): | ||||||||||||||||||||||||||||||
| child = getattr(model, attr, None) | ||||||||||||||||||||||||||||||
| if child is not None and child is not model and hasattr(child, "forward"): | ||||||||||||||||||||||||||||||
| return child | ||||||||||||||||||||||||||||||
| return model | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| def _model_supports_unsloth_return_hidden_states(model): | ||||||||||||||||||||||||||||||
| target_model = _grpo_hidden_states_wrap_target(model) | ||||||||||||||||||||||||||||||
| for candidate in (model, target_model): | ||||||||||||||||||||||||||||||
| if candidate is None: | ||||||||||||||||||||||||||||||
| continue | ||||||||||||||||||||||||||||||
| if getattr(candidate, _UNSLOTH_RETURN_HIDDEN_STATES_SUPPORT_MARKER, False): | ||||||||||||||||||||||||||||||
| return True | ||||||||||||||||||||||||||||||
| if getattr( | ||||||||||||||||||||||||||||||
| type(candidate), _UNSLOTH_RETURN_HIDDEN_STATES_SUPPORT_MARKER, False | ||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||
|
Comment on lines
+569
to
+573
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
This support check relies only on a custom marker attribute, but no model in this repo sets Useful? React with 👍 / 👎. |
||||||||||||||||||||||||||||||
| return True | ||||||||||||||||||||||||||||||
| return False | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. To improve performance,
Suggested change
References
|
||||||||||||||||||||||||||||||
| def _drop_forward_kwargs_consumed_positionally(forward_signature, args, kwargs): | ||||||||||||||||||||||||||||||
| if len(args) == 0 or len(kwargs) == 0: | ||||||||||||||||||||||||||||||
| return kwargs | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| consumed_names = [] | ||||||||||||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||||||||||||||||||||||||||||||
| for parameter in forward_signature.parameters.values(): | ||||||||||||||||||||||||||||||
| if parameter.kind == inspect.Parameter.VAR_POSITIONAL: | ||||||||||||||||||||||||||||||
| break | ||||||||||||||||||||||||||||||
| if parameter.kind in ( | ||||||||||||||||||||||||||||||
| inspect.Parameter.POSITIONAL_ONLY, | ||||||||||||||||||||||||||||||
| inspect.Parameter.POSITIONAL_OR_KEYWORD, | ||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||
| consumed_names.append(parameter.name) | ||||||||||||||||||||||||||||||
| if len(consumed_names) >= len(args): | ||||||||||||||||||||||||||||||
| break | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| if len(consumed_names) == 0: | ||||||||||||||||||||||||||||||
| return kwargs | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| kwargs = dict(kwargs) | ||||||||||||||||||||||||||||||
| for name in consumed_names: | ||||||||||||||||||||||||||||||
| kwargs.pop(name, None) | ||||||||||||||||||||||||||||||
| return kwargs | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| def _get_num_logits_to_keep(forward_signature, args, kwargs): | ||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||
| bound = forward_signature.bind_partial(*args, **kwargs) | ||||||||||||||||||||||||||||||
| arguments = bound.arguments | ||||||||||||||||||||||||||||||
| num_logits_to_keep = arguments.get("num_logits_to_keep", 0) or 0 | ||||||||||||||||||||||||||||||
| logits_to_keep = arguments.get("logits_to_keep", 0) or 0 | ||||||||||||||||||||||||||||||
| for parameter in forward_signature.parameters.values(): | ||||||||||||||||||||||||||||||
| if parameter.kind != inspect.Parameter.VAR_KEYWORD: | ||||||||||||||||||||||||||||||
| continue | ||||||||||||||||||||||||||||||
| extra_kwargs = arguments.get(parameter.name, {}) | ||||||||||||||||||||||||||||||
| num_logits_to_keep = max( | ||||||||||||||||||||||||||||||
| num_logits_to_keep, | ||||||||||||||||||||||||||||||
| extra_kwargs.get("num_logits_to_keep", 0) or 0, | ||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||
| logits_to_keep = max( | ||||||||||||||||||||||||||||||
| logits_to_keep, | ||||||||||||||||||||||||||||||
| extra_kwargs.get("logits_to_keep", 0) or 0, | ||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||
| break | ||||||||||||||||||||||||||||||
| return max(num_logits_to_keep, logits_to_keep) | ||||||||||||||||||||||||||||||
| except TypeError: | ||||||||||||||||||||||||||||||
| logger.debug( | ||||||||||||||||||||||||||||||
| "Unsloth: Could not bind forward arguments for GRPO hidden-state fallback.", | ||||||||||||||||||||||||||||||
| exc_info = True, | ||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| num_logits_to_keep = kwargs.get("num_logits_to_keep", 0) or 0 | ||||||||||||||||||||||||||||||
| logits_to_keep = kwargs.get("logits_to_keep", 0) or 0 | ||||||||||||||||||||||||||||||
|
Comment on lines
+602
to
+630
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The
Suggested change
References
|
||||||||||||||||||||||||||||||
| return max(num_logits_to_keep, logits_to_keep) | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| def _warn_grpo_hidden_states_fallback_once(model, message): | ||||||||||||||||||||||||||||||
| if getattr(model, _UNSLOTH_GRPO_HIDDEN_STATES_WARNING_ATTR, False): | ||||||||||||||||||||||||||||||
| return | ||||||||||||||||||||||||||||||
| setattr(model, _UNSLOTH_GRPO_HIDDEN_STATES_WARNING_ATTR, True) | ||||||||||||||||||||||||||||||
| logger.warning(message) | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| def _replace_outputs_logits(outputs, hidden_states): | ||||||||||||||||||||||||||||||
| if hasattr(outputs, "logits"): | ||||||||||||||||||||||||||||||
| outputs.logits = hidden_states | ||||||||||||||||||||||||||||||
| return outputs | ||||||||||||||||||||||||||||||
| if isinstance(outputs, dict): | ||||||||||||||||||||||||||||||
| outputs["logits"] = hidden_states | ||||||||||||||||||||||||||||||
| return outputs | ||||||||||||||||||||||||||||||
| if isinstance(outputs, tuple) and len(outputs) != 0: | ||||||||||||||||||||||||||||||
| return (hidden_states,) + tuple(outputs[1:]) | ||||||||||||||||||||||||||||||
| raise TypeError( | ||||||||||||||||||||||||||||||
| f"Unsupported output type for GRPO hidden-state fallback: {type(outputs)}" | ||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| def _install_grpo_hidden_states_forward_wrapper(model): | ||||||||||||||||||||||||||||||
| if model is None or getattr(model, _UNSLOTH_GRPO_HIDDEN_STATES_WRAPPED_ATTR, False): | ||||||||||||||||||||||||||||||
| return False | ||||||||||||||||||||||||||||||
| if _model_supports_unsloth_return_hidden_states(model): | ||||||||||||||||||||||||||||||
| return False | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| target_model = _grpo_hidden_states_wrap_target(model) | ||||||||||||||||||||||||||||||
| if getattr(target_model, _UNSLOTH_GRPO_HIDDEN_STATES_WRAPPED_ATTR, False): | ||||||||||||||||||||||||||||||
| setattr(model, _UNSLOTH_GRPO_HIDDEN_STATES_WRAPPED_ATTR, True) | ||||||||||||||||||||||||||||||
| return False | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
|
Comment on lines
+664
to
+665
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Pre-compute the signature once during installation to avoid the overhead of
Suggested change
References
|
||||||||||||||||||||||||||||||
| original_forward = target_model.forward | ||||||||||||||||||||||||||||||
| forward_signature = inspect.signature(original_forward) | ||||||||||||||||||||||||||||||
| model_name = type(target_model).__name__ | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| def wrapped_forward(*args, **kwargs): | ||||||||||||||||||||||||||||||
| if os.environ.get("UNSLOTH_RETURN_HIDDEN_STATES", "0") != "1": | ||||||||||||||||||||||||||||||
| return original_forward(*args, **kwargs) | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| forward_kwargs = _drop_forward_kwargs_consumed_positionally( | ||||||||||||||||||||||||||||||
|
Comment on lines
+672
to
+674
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||||||||||||||||||||||||||||||
| forward_signature, args, kwargs | ||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||
| num_logits_to_keep = _get_num_logits_to_keep( | ||||||||||||||||||||||||||||||
| forward_signature, args, forward_kwargs | ||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||
| forward_kwargs["output_hidden_states"] = True | ||||||||||||||||||||||||||||||
| forward_kwargs["return_dict"] = True | ||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||
| outputs = original_forward(*args, **forward_kwargs) | ||||||||||||||||||||||||||||||
| except TypeError as error: | ||||||||||||||||||||||||||||||
| if "output_hidden_states" not in str(error) and "return_dict" not in str( | ||||||||||||||||||||||||||||||
| error | ||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||
| raise | ||||||||||||||||||||||||||||||
| _warn_grpo_hidden_states_fallback_once( | ||||||||||||||||||||||||||||||
| target_model, | ||||||||||||||||||||||||||||||
| f"Unsloth: GRPO fallback could not request hidden states for unsupported model {model_name}; using logits directly.", | ||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||
| return original_forward(*args, **kwargs) | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| hidden_states = getattr(outputs, "hidden_states", None) | ||||||||||||||||||||||||||||||
| if hidden_states is None or len(hidden_states) == 0: | ||||||||||||||||||||||||||||||
| _warn_grpo_hidden_states_fallback_once( | ||||||||||||||||||||||||||||||
| target_model, | ||||||||||||||||||||||||||||||
| f"Unsloth: GRPO fallback did not receive hidden states for unsupported model {model_name}; using logits directly.", | ||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||
| return outputs | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| hidden_states = hidden_states[-1] | ||||||||||||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||||||||||||||||||||||||||||||
| if num_logits_to_keep != 0: | ||||||||||||||||||||||||||||||
| hidden_states = hidden_states[:, -num_logits_to_keep:, :] | ||||||||||||||||||||||||||||||
| return _replace_outputs_logits(outputs, hidden_states) | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| wrapped_forward._unsloth_grpo_hidden_states_forward_wrapped = True | ||||||||||||||||||||||||||||||
| target_model.forward = wrapped_forward | ||||||||||||||||||||||||||||||
| setattr(target_model, _UNSLOTH_GRPO_HIDDEN_STATES_WRAPPED_ATTR, True) | ||||||||||||||||||||||||||||||
| setattr(model, _UNSLOTH_GRPO_HIDDEN_STATES_WRAPPED_ATTR, True) | ||||||||||||||||||||||||||||||
| return True | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| def _wrap_grpo_hidden_states_fallback(trainer_cls): | ||||||||||||||||||||||||||||||
| original_init = trainer_cls.__init__ | ||||||||||||||||||||||||||||||
| if getattr(original_init, "_unsloth_grpo_hidden_states_init_wrapped", False): | ||||||||||||||||||||||||||||||
| return | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| def wrapped_init(self, *args, **kwargs): | ||||||||||||||||||||||||||||||
| original_init(self, *args, **kwargs) | ||||||||||||||||||||||||||||||
| _install_grpo_hidden_states_forward_wrapper(getattr(self, "model", None)) | ||||||||||||||||||||||||||||||
| _install_grpo_hidden_states_forward_wrapper(getattr(self, "ref_model", None)) | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| wrapped_init._unsloth_grpo_hidden_states_init_wrapped = True | ||||||||||||||||||||||||||||||
| trainer_cls.__init__ = wrapped_init | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): | ||||||||||||||||||||||||||||||
| # Defensive wrapper: matches patch_trl_rl_trainers()'s try/except so | ||||||||||||||||||||||||||||||
| # direct callers don't see exceptions from the impl on TRL versions | ||||||||||||||||||||||||||||||
|
|
@@ -1627,6 +1814,14 @@ def _patch_trl_rl_trainers_impl(trainer_file = "grpo_trainer"): | |||||||||||||||||||||||||||||
| logger.info( | ||||||||||||||||||||||||||||||
| f"Unsloth: Could not wrap _generate_and_score_completions for {RLTrainer_name}: {e}" | ||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||
| _wrap_grpo_hidden_states_fallback( | ||||||||||||||||||||||||||||||
| getattr(created_module, f"Unsloth{RLTrainer_name}") | ||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||
| except Exception as e: | ||||||||||||||||||||||||||||||
| logger.info( | ||||||||||||||||||||||||||||||
| f"Unsloth: Could not wrap GRPO hidden-state fallback for {RLTrainer_name}: {e}" | ||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| def patch_functions(RLTrainer, trainer_file, RLTrainer_name, all_imports, imports): | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Choose the top-level causal LM module here instead of
get_base_model/base_model: for standard*ForCausalLMmodels this helper resolves to the decoder block, so the wrapper runs before the LM head and cannot prevent full-vocab logits materialization. In GRPO (UNSLOTH_RETURN_HIDDEN_STATES=1), this means we still pay for full logits in the parent forward while also forcing hidden-state collection, which can significantly increase memory and still miss the intended optimization.Useful? React with 👍 / 👎.