-
-
Notifications
You must be signed in to change notification settings - Fork 5.9k
Fix gpt temporary patch for grpo to happen after compile #4180
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -254,8 +254,21 @@ def prefer_flex_attn_if_supported(model_class, config): | |
| return None | ||
|
|
||
|
|
||
| for temporary_patch in TEMPORARY_PATCHES: | ||
| temporary_patch() | ||
| def _run_temporary_patches(phase): | ||
| import inspect | ||
|
|
||
| for temporary_patch in TEMPORARY_PATCHES: | ||
| try: | ||
| sig = inspect.signature(temporary_patch) | ||
| if "phase" in sig.parameters: | ||
| temporary_patch(phase = phase) | ||
| else: | ||
| temporary_patch() | ||
| except (ValueError, TypeError): | ||
| temporary_patch() | ||
|
Comment on lines
+267
to
+268
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Catching Useful? React with 👍 / 👎. |
||
|
|
||
|
|
||
| _run_temporary_patches("init") | ||
|
|
||
| # ============================================= | ||
| # Disable some warnings which can get annoying | ||
|
|
@@ -2095,8 +2108,7 @@ def unsloth_compile_transformers( | |
|
|
||
| # Run patches BEFORE compiler so class replacements (e.g. GptOssTopKRouter, | ||
| # GptOssExperts) are in place before the compiler caches references to them. | ||
| for temporary_patch in TEMPORARY_PATCHES: | ||
| temporary_patch() | ||
| _run_temporary_patches("pre_compile") | ||
|
|
||
| for model_type in model_types: | ||
| _unsloth_compile_transformers( | ||
|
|
@@ -2128,8 +2140,7 @@ def unsloth_compile_transformers( | |
| supports_sdpa = supports_sdpa, | ||
| ) | ||
| # Redo patches which override compiler | ||
| for temporary_patch in TEMPORARY_PATCHES: | ||
| temporary_patch() | ||
| _run_temporary_patches("post_compile") | ||
| return model_types, supports_sdpa[0] | ||
|
|
||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The
except (ValueError, TypeError)block is overly broad and can maskTypeErrorexceptions that signal a genuine issue with a patch function, such as incorrect arguments. By catchingTypeErrorand then unconditionally callingtemporary_patch(), you might suppress the original error and cause a new, potentially more confusing one. It's better to only catchValueError, whichinspect.signatureraises for callables it can't inspect (like some C built-ins). ATypeErrorshould generally be allowed to propagate to indicate a problem with the patch or its invocation, leading to faster bug detection.