Merged
70 changes: 56 additions & 14 deletions unsloth_zoo/compiler.py
@@ -1136,6 +1136,14 @@ def create_standalone_class(

     # Strip decorators from class source if present
     # This fixes issues with classes like Qwen3NextExperts which have decorators that cause compilation failures
+    STRIP_DECORATORS = {
+        "use_experts_implementation",
+        "use_kernel_forward_from_hub",
+        "use_kernelized_func",
+        "auto_docstring",
+        # add more here if needed
+    }
+
     if full_class.lstrip().startswith("@"):
         start = re.search(r"^class ", full_class, flags=re.MULTILINE)
         if start:
@@ -1148,30 +1156,64 @@ def create_standalone_class(
             lines = preamble.split('\n')
             new_lines = []

+            # Capture decorator head, including dotted paths: @pkg.decorator(...)
+            decorator_head_re = re.compile(
Collaborator

I think this works for now, but if a decorator has args, do we capture that?

Collaborator Author

This regex is only there to capture the decorator name. The surrounding loop handles multi-line, multi-argument decorators.
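
To illustrate the reply above, here is a minimal sketch that runs the same pattern against a few decorator lines. The sample lines and the `decorator_name` helper are hypothetical, made up for illustration; only the pattern itself comes from the diff.

```python
import re

# Same pattern as in the diff: it captures only the (possibly dotted)
# decorator name and stops before any argument list.
decorator_head_re = re.compile(
    r"^\s*@\s*([A-Za-z_][A-Za-z0-9_]*(?:\.[A-Za-z_][A-Za-z0-9_]*)*)\b"
)

def decorator_name(line):
    """Hypothetical helper: return the dotted decorator name, or None."""
    m = decorator_head_re.match(line)
    return m.group(1) if m else None

# Invented sample lines, not taken from any real model file.
samples = [
    "@auto_docstring",
    "@auto_docstring(custom_intro='x')",
    "    @pkg.use_kernelized_func(",
    "@(lambda f: f)",
]
names = [decorator_name(s) for s in samples]

# The base name used for the STRIP_DECORATORS lookup is the last dotted part.
base_names = [n.split(".")[-1] if n else None for n in names]
```

Arguments never reach the capture group, so whether a decorator has arguments only matters to the parenthesis-counting part of the loop.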

+                r"^\s*@\s*([A-Za-z_][A-Za-z0-9_]*(?:\.[A-Za-z_][A-Za-z0-9_]*)*)\b"
+            )
+
+            skipping = False
+            paren_depth = 0
+            skip_base_name = None
Contributor

medium

The variable skip_base_name is initialized here, and reassigned on lines 1175 and 1200, but its value is never read. This appears to be dead code and can be removed to improve clarity.

Collaborator Author

This is true; it can be removed.


             for line in lines:
+                if skipping:
+                    # Continue skipping decorator args until balanced
+                    paren_depth += line.count("(") - line.count(")")
+                    if paren_depth <= 0:
+                        skipping = False
+                        paren_depth = 0
+                        skip_base_name = None
+                    continue
+
                 stripped = line.strip()
                 if stripped.startswith("@"):
Collaborator

Do we need this check when we also match against the regex?
Our regex already seems to match the leading @.

Collaborator Author

@mmathew23 Feb 27, 2026

It's technically cheaper to do this check than to run

m = decorator_head_re.match(line)
if m:

on every line.

-                    if (
-                        "use_experts_implementation" in stripped
-                        or "use_kernel_forward_from_hub" in stripped
-                        or "use_kernelized_func" in stripped
-                        or stripped.startswith("@auto_docstring")
-                    ):
-                        decorator_name = stripped.split("(")[0].lstrip("@")
-                        logger.info(f"Unsloth: stripped {decorator_name} decorator from {module}")
-                        continue # Strip it
-                    else:
-                        logger.warning(f"Unsloth: Warning: Unknown decorator {stripped} found for {module}.")
-                        new_lines.append(line) # Keep it
+                    m = decorator_head_re.match(line)
+                    if not m:
+                        logger.warning(
+                            f"Unsloth: Warning: Unparseable decorator {stripped} found for {module}."
+                        )
+                        new_lines.append(line)
+                        continue
+
+                    decorator_full = m.group(1) # e.g. "foo.auto_docstring"
+                    decorator_base = decorator_full.split(".")[-1] # e.g. "auto_docstring"
+
+                    if decorator_base in STRIP_DECORATORS:
+                        logger.info(
+                            f"Unsloth: stripped {decorator_full} decorator from {module}"
+                        )
+
+                        # If decorator has args and spans multiple lines, skip until parens close
+                        paren_depth = line.count("(") - line.count(")")
+                        if paren_depth > 0:
+                            skipping = True
Comment on lines +1197 to +1199

P2: Avoid counting literal parentheses as decorator continuation

The new multiline-strip logic uses raw line.count("(") - line.count(")") to decide whether to keep skipping lines, which also counts parentheses inside string literals and comments. For a valid one-line decorator like @auto_docstring("("), paren_depth becomes positive and the loop starts skipping subsequent lines (including class ...), corrupting full_class and causing generated source to fail at import/runtime. This is a regression from the previous behavior for single-line decorators and will surface whenever stripped decorators contain unmatched parentheses characters in their arguments.

Collaborator Author

I think this is a legitimate edge case, but one I don't think is likely to occur in practice.
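
The edge case can be reproduced in isolation. The sketch below compares the raw character count used in the diff with one possible alternative that counts only parenthesis tokens via the standard-library `tokenize` module. `naive_depth` and `token_depth` are hypothetical names, and the tokenizer approach is an assumption about how this could be hardened, not something this PR does.

```python
import io
import tokenize

def naive_depth(line):
    # The counting used in the diff: every "(" and ")" character counts,
    # including those inside string literals and comments.
    return line.count("(") - line.count(")")

def token_depth(line):
    """Hypothetical alternative: count only real parenthesis tokens."""
    depth = 0
    reader = io.StringIO(line.strip() + "\n").readline
    try:
        for tok in tokenize.generate_tokens(reader):
            if tok.type == tokenize.OP and tok.string == "(":
                depth += 1
            elif tok.type == tokenize.OP and tok.string == ")":
                depth -= 1
    except (tokenize.TokenError, SyntaxError):
        # Raised when the line ends inside an open bracket or string.
        # The tokens seen before the error still give the running depth.
        pass
    return depth

one_line = '@auto_docstring("(")'    # complete decorator, paren inside a string
multi_line_head = "@auto_docstring("  # decorator that really continues

naive_one, token_one = naive_depth(one_line), token_depth(one_line)
naive_head, token_head = naive_depth(multi_line_head), token_depth(multi_line_head)
```

With the raw count, the one-line decorator looks unbalanced and would switch the loop into skipping mode; the token-based count treats it as complete while still reporting the genuinely open decorator as open. Whether the extra tokenizer cost is worth it for this preamble-sized input is a judgment call.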

+                            skip_base_name = decorator_base
+                        continue # Strip this decorator line
+
+                    # Unknown decorator -> keep it but warn
+                    logger.warning(
+                        f"Unsloth: Warning: Unknown decorator {stripped} found for {module}."
+                    )
+                    new_lines.append(line)
                 else:
                     new_lines.append(line)

-            full_class = '\n'.join(new_lines) + class_def
+            full_class = "\n".join(new_lines) + class_def
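
For reference, the new stripping loop can be exercised on its own. This is a simplified sketch under stated assumptions: `strip_known_decorators` is a hypothetical wrapper, logging is omitted, the unused `skip_base_name` is dropped, and the sample preamble is invented rather than taken from a real model class.

```python
import re

STRIP_DECORATORS = {
    "use_experts_implementation",
    "use_kernel_forward_from_hub",
    "use_kernelized_func",
    "auto_docstring",
}

DECORATOR_HEAD_RE = re.compile(
    r"^\s*@\s*([A-Za-z_][A-Za-z0-9_]*(?:\.[A-Za-z_][A-Za-z0-9_]*)*)\b"
)

def strip_known_decorators(preamble):
    """Hypothetical helper: drop decorators listed in STRIP_DECORATORS."""
    kept = []
    skipping = False
    paren_depth = 0
    for line in preamble.split("\n"):
        if skipping:
            # Still inside the argument list of a stripped decorator.
            paren_depth += line.count("(") - line.count(")")
            if paren_depth <= 0:
                skipping = False
                paren_depth = 0
            continue
        if line.strip().startswith("@"):
            m = DECORATOR_HEAD_RE.match(line)
            if m and m.group(1).split(".")[-1] in STRIP_DECORATORS:
                # Skip following lines only if the parens are still open.
                paren_depth = line.count("(") - line.count(")")
                skipping = paren_depth > 0
                continue
        # Unknown or unparseable decorators and all other lines are kept.
        kept.append(line)
    return "\n".join(kept)

preamble = "\n".join([
    "@auto_docstring(",
    "    custom_intro='hypothetical text',",
    ")",
    "@pkg.use_kernelized_func(fn)",
    "@dataclass",
    "",
])
result = strip_known_decorators(preamble)
```

Here the multi-line `@auto_docstring(...)` and the dotted `@pkg.use_kernelized_func(fn)` are both removed, while the unknown `@dataclass` line survives.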

     # Check if forward was replaced by a temporary patch (renamed function)
     # In this case, keep the patched source as-is and replace the class forward body.
     patched_forward_info = None
-    if "@torch.compiler.disable" in forward_source:
+    if 'gptossexperts' != module.lower():
Contributor

critical

There seems to be a contradiction between the PR description and this line of code.

The PR description states: "We update the patch to specifically bypass gptoss as a temporary workaround." This implies that the special bypass logic (which prevents recompilation of a patched function) should apply only to gptossexperts.

However, the condition if 'gptossexperts' != module.lower(): applies this bypass logic to all modules except gptossexperts. This seems to be the opposite of the intended behavior.

Should this condition be if 'gptossexperts' == module.lower(): to match the description and correctly apply the bypass only to gptossexperts?

Suggested change
- if 'gptossexperts' != module.lower():
+ if 'gptossexperts' == module.lower():

Collaborator Author

No, the original code is correct. We want to skip this block for gptoss to work.

Collaborator

Can we explore doing it in the previous function itself, to avoid any further confusion later?

Collaborator Author

We certainly can, but this PR is intended to be a quicker fix for qwen3.5 and gemma3n. I think this snippet originated during the MoE release, so I'm missing some context about what is and isn't safe to do.

         func_match = re.search(r"def\s+(\w+)\s*\(", forward_source)
         if func_match and func_match.group(1) != "forward":
             # Find original forward in class to replace it