-
Notifications
You must be signed in to change notification settings - Fork 265
Keep GptOss working, Fix Qwen35, Fix Gemma3n decorator stripping #525
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -1136,6 +1136,14 @@ def create_standalone_class( | |||||
|
|
||||||
| # Strip decorators from class source if present | ||||||
| # This fixes issues with classes like Qwen3NextExperts which have decorators that cause compilation failures | ||||||
| STRIP_DECORATORS = { | ||||||
| "use_experts_implementation", | ||||||
| "use_kernel_forward_from_hub", | ||||||
| "use_kernelized_func", | ||||||
| "auto_docstring", | ||||||
| # add more here if needed | ||||||
| } | ||||||
|
|
||||||
| if full_class.lstrip().startswith("@"): | ||||||
| start = re.search(r"^class ", full_class, flags=re.MULTILINE) | ||||||
| if start: | ||||||
|
|
@@ -1148,30 +1156,64 @@ def create_standalone_class( | |||||
| lines = preamble.split('\n') | ||||||
| new_lines = [] | ||||||
|
|
||||||
| # Capture decorator head, including dotted paths: @pkg.decorator(...) | ||||||
| decorator_head_re = re.compile( | ||||||
| r"^\s*@\s*([A-Za-z_][A-Za-z0-9_]*(?:\.[A-Za-z_][A-Za-z0-9_]*)*)\b" | ||||||
| ) | ||||||
|
|
||||||
| skipping = False | ||||||
| paren_depth = 0 | ||||||
| skip_base_name = None | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this is true, can be removed |
||||||
|
|
||||||
| for line in lines: | ||||||
| if skipping: | ||||||
| # Continue skipping decorator args until balanced | ||||||
| paren_depth += line.count("(") - line.count(")") | ||||||
| if paren_depth <= 0: | ||||||
| skipping = False | ||||||
| paren_depth = 0 | ||||||
| skip_base_name = None | ||||||
| continue | ||||||
|
|
||||||
| stripped = line.strip() | ||||||
| if stripped.startswith("@"): | ||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we need this when we check regex?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It's technically cheaper to do this `startswith("@")` check first than to run the regex against every line. |
||||||
| if ( | ||||||
| "use_experts_implementation" in stripped | ||||||
| or "use_kernel_forward_from_hub" in stripped | ||||||
| or "use_kernelized_func" in stripped | ||||||
| or stripped.startswith("@auto_docstring") | ||||||
| ): | ||||||
| decorator_name = stripped.split("(")[0].lstrip("@") | ||||||
| logger.info(f"Unsloth: stripped {decorator_name} decorator from {module}") | ||||||
| continue # Strip it | ||||||
| else: | ||||||
| logger.warning(f"Unsloth: Warning: Unknown decorator {stripped} found for {module}.") | ||||||
| new_lines.append(line) # Keep it | ||||||
| m = decorator_head_re.match(line) | ||||||
| if not m: | ||||||
| logger.warning( | ||||||
| f"Unsloth: Warning: Unparseable decorator {stripped} found for {module}." | ||||||
| ) | ||||||
| new_lines.append(line) | ||||||
| continue | ||||||
|
|
||||||
| decorator_full = m.group(1) # e.g. "foo.auto_docstring" | ||||||
| decorator_base = decorator_full.split(".")[-1] # e.g. "auto_docstring" | ||||||
|
|
||||||
| if decorator_base in STRIP_DECORATORS: | ||||||
| logger.info( | ||||||
| f"Unsloth: stripped {decorator_full} decorator from {module}" | ||||||
| ) | ||||||
|
|
||||||
| # If decorator has args and spans multiple lines, skip until parens close | ||||||
| paren_depth = line.count("(") - line.count(")") | ||||||
| if paren_depth > 0: | ||||||
| skipping = True | ||||||
|
Comment on lines
+1197
to
+1199
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
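The new multiline-strip logic uses raw `line.count("(") - line.count(")")` to track parenthesis depth, so a parenthesis inside a string literal or comment within the decorator arguments would throw off the count and could cause too many or too few lines to be skipped. Useful? React with 👍 / 👎.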
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this is a legitimate edge case, but also one I don't think likely to occur in practice. |
||||||
| skip_base_name = decorator_base | ||||||
| continue # Strip this decorator line | ||||||
|
|
||||||
| # Unknown decorator -> keep it but warn | ||||||
| logger.warning( | ||||||
| f"Unsloth: Warning: Unknown decorator {stripped} found for {module}." | ||||||
| ) | ||||||
| new_lines.append(line) | ||||||
| else: | ||||||
| new_lines.append(line) | ||||||
|
|
||||||
| full_class = '\n'.join(new_lines) + class_def | ||||||
| full_class = "\n".join(new_lines) + class_def | ||||||
|
|
||||||
| # Check if forward was replaced by a temporary patch (renamed function) | ||||||
| # In this case, keep the patched source as-is and replace the class forward body. | ||||||
| patched_forward_info = None | ||||||
| if "@torch.compiler.disable" in forward_source: | ||||||
| if 'gptossexperts' != module.lower(): | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. There seems to be a contradiction between the PR description and this line of code. The PR description states: "We update the patch to specifically bypass gptoss as a temporary workaround." This implies that the special bypass logic (which prevents recompilation of a patched function) should apply only to gptoss. However, the condition `'gptossexperts' != module.lower()` runs this block for every module except gptoss. Should this condition be `'gptossexperts' == module.lower()`?
Suggested change
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. No, the original code is correct. We want to skip this block for gptoss so that it keeps working.
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we explore doing it in the previous function itself? To avoid any further confusion later
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We certainly can but this PR is intended to be a quicker fix for qwen3.5 and gemma3n. I think this snippet originated during the moe release so I'm missing some context about what is safe and not safe to do. |
||||||
| func_match = re.search(r"def\s+(\w+)\s*\(", forward_source) | ||||||
| if func_match and func_match.group(1) != "forward": | ||||||
| # Find original forward in class to replace it | ||||||
|
|
||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this works for now, but if a decorator has args, do we capture that?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this regex is only there to capture the decorator name. The overall loop logic handles multiline multiarg decorators.