[GRPO] Try returning hidden statex for GRPO#5142
Conversation
There was a problem hiding this comment.
Code Review
This pull request implements a fallback mechanism for GRPO to retrieve hidden states from models that do not natively support it by wrapping the forward method. Feedback suggests optimizing performance by pre-computing function signatures during installation and using them to robustly handle positional arguments. It was also recommended to log exceptions during argument binding to aid debugging.
|
|
||
| def _get_num_logits_to_keep(kwargs): | ||
| num_logits_to_keep = kwargs.get("num_logits_to_keep", 0) or 0 | ||
| logits_to_keep = kwargs.get("logits_to_keep", 0) or 0 |
There was a problem hiding this comment.
The _get_num_logits_to_keep function currently only checks kwargs. If logits_to_keep or num_logits_to_keep are passed as positional arguments, they will be missed. It's better to use the model's signature to bind the arguments and extract the values robustly. Per repository rules, ensure exceptions are logged rather than silently ignored.
| def _get_num_logits_to_keep(kwargs): | |
| num_logits_to_keep = kwargs.get("num_logits_to_keep", 0) or 0 | |
| logits_to_keep = kwargs.get("logits_to_keep", 0) or 0 | |
| def _get_num_logits_to_keep(sig, args, kwargs): | |
| try: | |
| bound = sig.bind_partial(*args, **kwargs) | |
| return max(bound.arguments.get("num_logits_to_keep", 0) or 0, | |
| bound.arguments.get("logits_to_keep", 0) or 0) | |
| except Exception as e: | |
| import logging | |
| logging.debug(f"Error binding signature: {e}") | |
| return max(kwargs.get("num_logits_to_keep", 0) or 0, | |
| kwargs.get("logits_to_keep", 0) or 0) |
References
- Avoid using broad, silent exception handlers like
except Exception: pass. Instead, log the exception, even if at a debug level, to aid in future debugging.
| ) | ||
| return outputs | ||
|
|
||
| hidden_states = hidden_states[-1] |
| return True | ||
| return False | ||
|
|
||
|
|
There was a problem hiding this comment.
To improve performance, _drop_forward_kwargs_consumed_positionally should accept a pre-computed signature instead of calling inspect.signature() on every forward pass. This avoids redundant introspection overhead during training.
| def _drop_forward_kwargs_consumed_positionally(sig, args, kwargs): |
References
- To improve efficiency, avoid redundant data iterations. Combine checks and transformations into a single loop and return computed values for callers to reuse.
| if len(args) == 0 or len(kwargs) == 0: | ||
| return kwargs | ||
|
|
||
| consumed_names = [] |
| return False | ||
|
|
There was a problem hiding this comment.
Pre-compute the signature once during installation to avoid the overhead of inspect.signature() on every forward pass.
| return False | |
| original_forward = target_model.forward | |
| sig = inspect.signature(original_forward) | |
| model_name = type(target_model).__name__ |
References
- To improve efficiency, avoid redundant data iterations. Combine checks and transformations into a single loop and return computed values for callers to reuse.
| return original_forward(*args, **kwargs) | ||
|
|
||
| forward_kwargs = _drop_forward_kwargs_consumed_positionally( |
520de3b to
6be5ab1
Compare
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: cf6d545e1d
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
| get_base_model = getattr(model, "get_base_model", None) | ||
| if callable(get_base_model): | ||
| base_model = get_base_model() | ||
| if base_model is not None and base_model is not model: | ||
| return base_model |
There was a problem hiding this comment.
Wrap the LM forward instead of the base decoder
Choose the top-level causal LM module here instead of get_base_model/base_model: for standard *ForCausalLM models this helper resolves to the decoder block, so the wrapper runs before the LM head and cannot prevent full-vocab logits materialization. In GRPO (UNSLOTH_RETURN_HIDDEN_STATES=1), this means we still pay for full logits in the parent forward while also forcing hidden-state collection, which can significantly increase memory and still miss the intended optimization.
Useful? React with 👍 / 👎.
| if getattr(candidate, _UNSLOTH_RETURN_HIDDEN_STATES_SUPPORT_MARKER, False): | ||
| return True | ||
| if getattr( | ||
| type(candidate), _UNSLOTH_RETURN_HIDDEN_STATES_SUPPORT_MARKER, False | ||
| ): |
There was a problem hiding this comment.
Detect native hidden-state support before installing wrapper
This support check relies only on a custom marker attribute, but no model in this repo sets __UNSLOTH_SUPPORTS_RETURN_HIDDEN_STATES__, so the fallback wrapper is effectively installed for all GRPO models. For models that already implement UNSLOTH_RETURN_HIDDEN_STATES, the new wrapper still forces output_hidden_states=True, which unnecessarily materializes all layer hidden states and can cause avoidable VRAM spikes/OOM during training.
Useful? React with 👍 / 👎.
Conflict resolution for .github/workflows/release-desktop.yml. main moved forward with PR #5394 (Chore(deps): bump the actions group across 1 directory with 4 updates) which bumped action SHAs on the build job's `actions/checkout` line, colliding with the harden-runner audit step that this PR inserts above the checkout. Resolution: - Keep the `step-security/harden-runner@<sha> # v2.19.1` audit step at the head of the build job (this PR's contribution). - Accept main's newer `actions/checkout@de0fac2e4500...` SHA (was `34e114876b0b...`). No functional change beyond the action SHA bump: harden-runner still runs in audit mode (logs egress, never blocks), and actions/checkout v6.0.2 is the dependabot-shipped upgrade from v6.0.x. Auto-merged cleanly: - .github/workflows/security-audit.yml - .github/workflows/studio-tauri-smoke.yml plus eight non-workflow files from main (studio backend / tests / unsloth GRPO changes from #5142, #5197, #5346, etc.). None touch this PR's surface area. Verified: pytest tests/security -> 34 passed in 2.71s; every .github/workflows/*.yml parses cleanly under PyYAML (24 files).
unslothai/unsloth-zoo#602 fixed an important issue where when some models return logits we were failing with shape mismatch. This is because for GRPO we generally expect hidden states to be returned with our wrappers (for most models ofc) adn lm_head is applied chunk wise to avoid materialising full large logits which are much larger than hidden states (4K vs 256K ish for eg)
This is an effort to make more models return hidden states for efficiency reasons. Orthogonal to the above mentioned PR :)
Ref: unslothai/unsloth-zoo#609