diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 5200bfefd2..ee9bdda26a 100755 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -539,6 +539,193 @@ def wrapped(self, *args, **kwargs): trainer_cls._generate_and_score_completions = wrapped +_UNSLOTH_RETURN_HIDDEN_STATES_SUPPORT_MARKER = ( + "__UNSLOTH_SUPPORTS_RETURN_HIDDEN_STATES__" +) +_UNSLOTH_GRPO_HIDDEN_STATES_WRAPPED_ATTR = "_unsloth_grpo_hidden_states_forward_wrapped" +_UNSLOTH_GRPO_HIDDEN_STATES_WARNING_ATTR = "_unsloth_grpo_hidden_states_warning_issued" + + +def _grpo_hidden_states_wrap_target(model): + if model is None: + return None + get_base_model = getattr(model, "get_base_model", None) + if callable(get_base_model): + base_model = get_base_model() + if base_model is not None and base_model is not model: + return base_model + for attr in ("base_model", "model"): + child = getattr(model, attr, None) + if child is not None and child is not model and hasattr(child, "forward"): + return child + return model + + +def _model_supports_unsloth_return_hidden_states(model): + target_model = _grpo_hidden_states_wrap_target(model) + for candidate in (model, target_model): + if candidate is None: + continue + if getattr(candidate, _UNSLOTH_RETURN_HIDDEN_STATES_SUPPORT_MARKER, False): + return True + if getattr( + type(candidate), _UNSLOTH_RETURN_HIDDEN_STATES_SUPPORT_MARKER, False + ): + return True + return False + + +def _drop_forward_kwargs_consumed_positionally(forward_signature, args, kwargs): + if len(args) == 0 or len(kwargs) == 0: + return kwargs + + consumed_names = [] + for parameter in forward_signature.parameters.values(): + if parameter.kind == inspect.Parameter.VAR_POSITIONAL: + break + if parameter.kind in ( + inspect.Parameter.POSITIONAL_ONLY, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + ): + consumed_names.append(parameter.name) + if len(consumed_names) >= len(args): + break + + if len(consumed_names) == 0: + return kwargs + + kwargs = dict(kwargs) + for name in consumed_names: + kwargs.pop(name, None) + return kwargs + + +def _get_num_logits_to_keep(forward_signature, args, kwargs): + try: + bound = forward_signature.bind_partial(*args, **kwargs) + arguments = bound.arguments + num_logits_to_keep = arguments.get("num_logits_to_keep", 0) or 0 + logits_to_keep = arguments.get("logits_to_keep", 0) or 0 + for parameter in forward_signature.parameters.values(): + if parameter.kind != inspect.Parameter.VAR_KEYWORD: + continue + extra_kwargs = arguments.get(parameter.name, {}) + num_logits_to_keep = max( + num_logits_to_keep, + extra_kwargs.get("num_logits_to_keep", 0) or 0, + ) + logits_to_keep = max( + logits_to_keep, + extra_kwargs.get("logits_to_keep", 0) or 0, + ) + break + return max(num_logits_to_keep, logits_to_keep) + except TypeError: + logger.debug( + "Unsloth: Could not bind forward arguments for GRPO hidden-state fallback.", + exc_info = True, + ) + + num_logits_to_keep = kwargs.get("num_logits_to_keep", 0) or 0 + logits_to_keep = kwargs.get("logits_to_keep", 0) or 0 + return max(num_logits_to_keep, logits_to_keep) + + +def _warn_grpo_hidden_states_fallback_once(model, message): + if getattr(model, _UNSLOTH_GRPO_HIDDEN_STATES_WARNING_ATTR, False): + return + setattr(model, _UNSLOTH_GRPO_HIDDEN_STATES_WARNING_ATTR, True) + logger.warning(message) + + +def _replace_outputs_logits(outputs, hidden_states): + if hasattr(outputs, "logits"): + outputs.logits = hidden_states + return outputs + if isinstance(outputs, dict): + outputs["logits"] = hidden_states + return outputs + if isinstance(outputs, tuple) and len(outputs) != 0: + return (hidden_states,) + tuple(outputs[1:]) + raise TypeError( + f"Unsupported output type for GRPO hidden-state fallback: {type(outputs)}" + ) + + +def _install_grpo_hidden_states_forward_wrapper(model): + if model is None or getattr(model, _UNSLOTH_GRPO_HIDDEN_STATES_WRAPPED_ATTR, False): + return False + if _model_supports_unsloth_return_hidden_states(model): + return False + + target_model = _grpo_hidden_states_wrap_target(model) + if getattr(target_model, _UNSLOTH_GRPO_HIDDEN_STATES_WRAPPED_ATTR, False): + setattr(model, _UNSLOTH_GRPO_HIDDEN_STATES_WRAPPED_ATTR, True) + return False + + original_forward = target_model.forward + forward_signature = inspect.signature(original_forward) + model_name = type(target_model).__name__ + + def wrapped_forward(*args, **kwargs): + if os.environ.get("UNSLOTH_RETURN_HIDDEN_STATES", "0") != "1": + return original_forward(*args, **kwargs) + + forward_kwargs = _drop_forward_kwargs_consumed_positionally( + forward_signature, args, kwargs + ) + num_logits_to_keep = _get_num_logits_to_keep( + forward_signature, args, forward_kwargs + ) + forward_kwargs["output_hidden_states"] = True + forward_kwargs["return_dict"] = True + try: + outputs = original_forward(*args, **forward_kwargs) + except TypeError as error: + if "output_hidden_states" not in str(error) and "return_dict" not in str( + error + ): + raise + _warn_grpo_hidden_states_fallback_once( + target_model, + f"Unsloth: GRPO fallback could not request hidden states for unsupported model {model_name}; using logits directly.", + ) + return original_forward(*args, **kwargs) + + hidden_states = getattr(outputs, "hidden_states", None) + if hidden_states is None or len(hidden_states) == 0: + _warn_grpo_hidden_states_fallback_once( + target_model, + f"Unsloth: GRPO fallback did not receive hidden states for unsupported model {model_name}; using logits directly.", + ) + return outputs + + hidden_states = hidden_states[-1] + if num_logits_to_keep != 0: + hidden_states = hidden_states[:, -num_logits_to_keep:, :] + return _replace_outputs_logits(outputs, hidden_states) + + wrapped_forward._unsloth_grpo_hidden_states_forward_wrapped = True + target_model.forward = wrapped_forward + setattr(target_model, _UNSLOTH_GRPO_HIDDEN_STATES_WRAPPED_ATTR, True) + setattr(model, _UNSLOTH_GRPO_HIDDEN_STATES_WRAPPED_ATTR, True) + return True + + +def _wrap_grpo_hidden_states_fallback(trainer_cls): + original_init = trainer_cls.__init__ + if getattr(original_init, "_unsloth_grpo_hidden_states_init_wrapped", False): + return + + def wrapped_init(self, *args, **kwargs): + original_init(self, *args, **kwargs) + _install_grpo_hidden_states_forward_wrapper(getattr(self, "model", None)) + _install_grpo_hidden_states_forward_wrapper(getattr(self, "ref_model", None)) + + wrapped_init._unsloth_grpo_hidden_states_init_wrapped = True + trainer_cls.__init__ = wrapped_init + + def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): # Defensive wrapper: matches patch_trl_rl_trainers()'s try/except so # direct callers don't see exceptions from the impl on TRL versions @@ -1627,6 +1814,14 @@ def _patch_trl_rl_trainers_impl(trainer_file = "grpo_trainer"): logger.info( f"Unsloth: Could not wrap _generate_and_score_completions for {RLTrainer_name}: {e}" ) + try: + _wrap_grpo_hidden_states_fallback( + getattr(created_module, f"Unsloth{RLTrainer_name}") + ) + except Exception as e: + logger.info( + f"Unsloth: Could not wrap GRPO hidden-state fallback for {RLTrainer_name}: {e}" + ) def patch_functions(RLTrainer, trainer_file, RLTrainer_name, all_imports, imports):