Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
195 changes: 195 additions & 0 deletions unsloth/models/rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,6 +539,193 @@ def wrapped(self, *args, **kwargs):
trainer_cls._generate_and_score_completions = wrapped


# Attribute name looked up (via getattr) on a model instance and on its class
# by _model_supports_unsloth_return_hidden_states(); a truthy value means the
# GRPO hidden-state fallback wrapper is not installed for that model.
_UNSLOTH_RETURN_HIDDEN_STATES_SUPPORT_MARKER = (
"__UNSLOTH_SUPPORTS_RETURN_HIDDEN_STATES__"
)
# Attribute set to True on a model (and on its wrap target) once the fallback
# forward wrapper has been installed, so installation happens only once.
_UNSLOTH_GRPO_HIDDEN_STATES_WRAPPED_ATTR = "_unsloth_grpo_hidden_states_forward_wrapped"
# Attribute set to True on a model after the fallback warning has been logged,
# so the warning is emitted at most once per model object.
_UNSLOTH_GRPO_HIDDEN_STATES_WARNING_ATTR = "_unsloth_grpo_hidden_states_warning_issued"


def _grpo_hidden_states_wrap_target(model):
    """Pick the object whose ``forward`` the GRPO fallback should wrap.

    Resolution order:
      1. ``None`` for a ``None`` model.
      2. The result of ``model.get_base_model()`` when that attribute is
         callable and returns something other than ``None`` or ``model``.
      3. The first of ``model.base_model`` / ``model.model`` that is not
         ``None``, is not ``model`` itself and has a ``forward`` attribute.
      4. ``model`` itself.

    NOTE(review): a reviewer on this change flagged that for standard
    ``*ForCausalLM`` models this may resolve to the decoder block rather than
    the top-level LM, so the wrapper would run before the LM head - confirm
    whether the top-level module should be preferred.
    """
    if model is None:
        return None

    unwrap = getattr(model, "get_base_model", None)
    if callable(unwrap):
        unwrapped = unwrap()
        if not (unwrapped is None or unwrapped is model):
            return unwrapped

    for attribute_name in ("base_model", "model"):
        candidate = getattr(model, attribute_name, None)
        if candidate is None or candidate is model:
            continue
        if hasattr(candidate, "forward"):
            return candidate

    # Nothing to descend into: the model is its own wrap target.
    return model


def _model_supports_unsloth_return_hidden_states(model):
    """Report whether ``model`` or its wrap target carries the support marker.

    Both ``model`` and the object returned by
    ``_grpo_hidden_states_wrap_target(model)`` are inspected; for each one the
    marker attribute is looked up on the instance and then on its type. The
    first truthy marker yields ``True``; otherwise ``False`` is returned.

    NOTE(review): a reviewer on this change pointed out that the check relies
    solely on this custom marker and that no model in the repository sets it,
    so the fallback wrapper may end up installed for every GRPO model -
    confirm whether native support should be detected another way.
    """
    marker = _UNSLOTH_RETURN_HIDDEN_STATES_SUPPORT_MARKER
    candidates = (model, _grpo_hidden_states_wrap_target(model))
    return any(
        getattr(candidate, marker, False)
        or getattr(type(candidate), marker, False)
        for candidate in candidates
        if candidate is not None
    )


Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

To improve performance, _drop_forward_kwargs_consumed_positionally should accept a pre-computed signature instead of calling inspect.signature() on every forward pass. This avoids redundant introspection overhead during training.

Suggested change
def _drop_forward_kwargs_consumed_positionally(sig, args, kwargs):
References
  1. To improve efficiency, avoid redundant data iterations. Combine checks and transformations into a single loop and return computed values for callers to reuse.

def _drop_forward_kwargs_consumed_positionally(forward_signature, args, kwargs):
    """Return a copy of ``kwargs`` without names already filled by ``args``.

    Walks the positional parameters of ``forward_signature`` (positional-only
    and positional-or-keyword, stopping at ``*args``) and removes from the
    keyword arguments every name that one of the supplied positional ``args``
    occupies, so the same parameter is not passed twice.

    Args:
        forward_signature: ``inspect.Signature`` of the forward callable,
            computed once by the caller.
        args: Positional arguments of the pending call.
        kwargs: Keyword arguments of the pending call. Never mutated.

    Returns:
        A new ``dict`` on every path. Callers add keys to the result, so it
        must never alias the caller's ``kwargs``.
    """
    # Copy up front: previously the early-return paths handed back ``kwargs``
    # itself, so a caller mutating the result silently changed the original.
    remaining = dict(kwargs)
    if not args or not remaining:
        return remaining

    unfilled_positionals = len(args)
    for parameter in forward_signature.parameters.values():
        if parameter.kind == inspect.Parameter.VAR_POSITIONAL:
            # Everything beyond this point is swallowed by *args, not by name.
            break
        if parameter.kind in (
            inspect.Parameter.POSITIONAL_ONLY,
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
        ):
            remaining.pop(parameter.name, None)
            unfilled_positionals -= 1
            if unfilled_positionals <= 0:
                break
    return remaining


def _get_num_logits_to_keep(forward_signature, args, kwargs):
    """Work out how many trailing positions the caller asked to keep.

    The call is bound against ``forward_signature`` so that
    ``num_logits_to_keep`` / ``logits_to_keep`` are found whether they were
    passed positionally, by keyword, or inside the signature's ``**kwargs``
    catch-all. Missing or falsy values count as 0 and the largest value wins.

    If binding raises ``TypeError`` the failure is logged at debug level and
    only the raw ``kwargs`` mapping is consulted.
    """
    option_names = ("num_logits_to_keep", "logits_to_keep")

    def largest_requested(mappings):
        # Name-major order keeps max()'s tie-breaking identical to comparing
        # the per-name maxima.
        return max(
            mapping.get(name, 0) or 0
            for name in option_names
            for mapping in mappings
        )

    try:
        bound_arguments = forward_signature.bind_partial(*args, **kwargs).arguments
        lookups = [bound_arguments]
        catch_all_name = next(
            (
                parameter.name
                for parameter in forward_signature.parameters.values()
                if parameter.kind == inspect.Parameter.VAR_KEYWORD
            ),
            None,
        )
        if catch_all_name is not None:
            # Unknown keywords land in the **kwargs bucket of the binding.
            lookups.append(bound_arguments.get(catch_all_name, {}))
        return largest_requested(lookups)
    except TypeError:
        logger.debug(
            "Unsloth: Could not bind forward arguments for GRPO hidden-state fallback.",
            exc_info = True,
        )

    return largest_requested([kwargs])


def _warn_grpo_hidden_states_fallback_once(model, message):
    """Log ``message`` as a warning unless ``model`` was already warned.

    The "already warned" state is stored on ``model`` itself under
    ``_UNSLOTH_GRPO_HIDDEN_STATES_WARNING_ATTR``.
    """
    already_warned = getattr(model, _UNSLOTH_GRPO_HIDDEN_STATES_WARNING_ATTR, False)
    if not already_warned:
        # Flag first, then log, matching the order callers have always seen.
        setattr(model, _UNSLOTH_GRPO_HIDDEN_STATES_WARNING_ATTR, True)
        logger.warning(message)


def _replace_outputs_logits(outputs, hidden_states):
    """Put ``hidden_states`` where ``outputs`` carries its logits.

    Supported containers, checked in this order:
      * any object exposing a ``logits`` attribute - assigned in place;
      * a ``dict`` - the ``"logits"`` key is set in place;
      * a non-empty ``tuple`` - a new tuple is built with the first element
        replaced.

    Returns:
        The updated ``outputs`` (the same object for the in-place cases).

    Raises:
        TypeError: if ``outputs`` is none of the supported containers.
    """
    if hasattr(outputs, "logits"):
        setattr(outputs, "logits", hidden_states)
    elif isinstance(outputs, dict):
        outputs["logits"] = hidden_states
    elif isinstance(outputs, tuple) and len(outputs) > 0:
        # Tuples are immutable, so rebuild with the first slot swapped out.
        outputs = (hidden_states, *outputs[1:])
    else:
        raise TypeError(
            f"Unsupported output type for GRPO hidden-state fallback: {type(outputs)}"
        )
    return outputs


def _install_grpo_hidden_states_forward_wrapper(model):
    """Install the GRPO hidden-state fallback around the wrap target's forward.

    The wrapper only changes behaviour while the environment variable
    ``UNSLOTH_RETURN_HIDDEN_STATES`` equals ``"1"``; otherwise it forwards the
    call untouched. When active it calls the original forward with
    ``output_hidden_states=True`` and ``return_dict=True`` and stores the last
    entry of ``outputs.hidden_states`` (cut down to the trailing
    ``num_logits_to_keep`` positions when that value is non-zero) in the
    logits slot of the returned outputs.

    If the original forward raises a ``TypeError`` mentioning one of the two
    injected keywords, a one-time warning is logged and the call is retried
    with the caller's original arguments. If no hidden states come back, a
    one-time warning is logged and the outputs are returned as received.

    Nothing is installed when ``model`` is ``None``, is already marked as
    wrapped, or reports support via
    ``_model_supports_unsloth_return_hidden_states``.

    Returns:
        True if this call installed the wrapper, False otherwise.
    """
    if model is None or getattr(model, _UNSLOTH_GRPO_HIDDEN_STATES_WRAPPED_ATTR, False):
        return False
    if _model_supports_unsloth_return_hidden_states(model):
        return False

    target_model = _grpo_hidden_states_wrap_target(model)
    if getattr(target_model, _UNSLOTH_GRPO_HIDDEN_STATES_WRAPPED_ATTR, False):
        # The target is already patched (reached through another object);
        # mark this model too so later calls return early.
        setattr(model, _UNSLOTH_GRPO_HIDDEN_STATES_WRAPPED_ATTR, True)
        return False

    original_forward = target_model.forward
    # Computed once at install time so the per-call path never re-inspects.
    forward_signature = inspect.signature(original_forward)
    model_name = type(target_model).__name__

    def wrapped_forward(*args, **kwargs):
        if os.environ.get("UNSLOTH_RETURN_HIDDEN_STATES", "0") != "1":
            return original_forward(*args, **kwargs)

        # Work on a private copy: the helper may hand back ``kwargs`` itself,
        # and ``kwargs`` must stay pristine for the TypeError fallback below.
        # Without the copy the retry would resend the very keywords that the
        # forward had just rejected.
        forward_kwargs = dict(
            _drop_forward_kwargs_consumed_positionally(
                forward_signature, args, kwargs
            )
        )
        num_logits_to_keep = _get_num_logits_to_keep(
            forward_signature, args, forward_kwargs
        )
        forward_kwargs["output_hidden_states"] = True
        forward_kwargs["return_dict"] = True
        try:
            outputs = original_forward(*args, **forward_kwargs)
        except TypeError as error:
            error_text = str(error)
            if (
                "output_hidden_states" not in error_text
                and "return_dict" not in error_text
            ):
                # Unrelated TypeError: not ours to swallow.
                raise
            _warn_grpo_hidden_states_fallback_once(
                target_model,
                f"Unsloth: GRPO fallback could not request hidden states for unsupported model {model_name}; using logits directly.",
            )
            return original_forward(*args, **kwargs)

        hidden_states = getattr(outputs, "hidden_states", None)
        if hidden_states is None or len(hidden_states) == 0:
            _warn_grpo_hidden_states_fallback_once(
                target_model,
                f"Unsloth: GRPO fallback did not receive hidden states for unsupported model {model_name}; using logits directly.",
            )
            return outputs

        hidden_states = hidden_states[-1]
        if num_logits_to_keep != 0:
            # Assumes a (batch, sequence, hidden) layout - TODO confirm.
            hidden_states = hidden_states[:, -num_logits_to_keep:, :]
        return _replace_outputs_logits(outputs, hidden_states)

    wrapped_forward._unsloth_grpo_hidden_states_forward_wrapped = True
    target_model.forward = wrapped_forward
    setattr(target_model, _UNSLOTH_GRPO_HIDDEN_STATES_WRAPPED_ATTR, True)
    setattr(model, _UNSLOTH_GRPO_HIDDEN_STATES_WRAPPED_ATTR, True)
    return True


def _wrap_grpo_hidden_states_fallback(trainer_cls):
    """Make ``trainer_cls.__init__`` install the GRPO hidden-state fallback.

    After the original ``__init__`` has run, the forward wrapper is installed
    on ``self.model`` and ``self.ref_model`` (each may be missing or ``None``).
    The replacement ``__init__`` is tagged with
    ``_unsloth_grpo_hidden_states_init_wrapped`` so that calling this function
    again on the same class is a no-op.

    The replacement is decorated with ``functools.wraps`` so the original
    name, docstring, signature (via ``__wrapped__``) and any attributes set on
    the original ``__init__`` stay visible to introspection.
    """
    # Function-scope import keeps this block self-contained.
    import functools

    original_init = trainer_cls.__init__
    if getattr(original_init, "_unsloth_grpo_hidden_states_init_wrapped", False):
        return

    @functools.wraps(original_init)
    def wrapped_init(self, *args, **kwargs):
        original_init(self, *args, **kwargs)
        _install_grpo_hidden_states_forward_wrapper(getattr(self, "model", None))
        _install_grpo_hidden_states_forward_wrapper(getattr(self, "ref_model", None))

    wrapped_init._unsloth_grpo_hidden_states_init_wrapped = True
    trainer_cls.__init__ = wrapped_init


def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"):
# Defensive wrapper: matches patch_trl_rl_trainers()'s try/except so
# direct callers don't see exceptions from the impl on TRL versions
Expand Down Expand Up @@ -1627,6 +1814,14 @@ def _patch_trl_rl_trainers_impl(trainer_file = "grpo_trainer"):
logger.info(
f"Unsloth: Could not wrap _generate_and_score_completions for {RLTrainer_name}: {e}"
)
try:
_wrap_grpo_hidden_states_fallback(
getattr(created_module, f"Unsloth{RLTrainer_name}")
)
except Exception as e:
logger.info(
f"Unsloth: Could not wrap GRPO hidden-state fallback for {RLTrainer_name}: {e}"
)


def patch_functions(RLTrainer, trainer_file, RLTrainer_name, all_imports, imports):
Expand Down
Loading