diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index d73410026..cc5e7b6e6 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -1136,6 +1136,14 @@ def create_standalone_class( # Strip decorators from class source if present # This fixes issues with classes like Qwen3NextExperts which have decorators that cause compilation failures + STRIP_DECORATORS = { + "use_experts_implementation", + "use_kernel_forward_from_hub", + "use_kernelized_func", + "auto_docstring", + # add more here if needed + } + if full_class.lstrip().startswith("@"): start = re.search(r"^class ", full_class, flags=re.MULTILINE) if start: @@ -1148,30 +1156,64 @@ def create_standalone_class( lines = preamble.split('\n') new_lines = [] + # Capture decorator head, including dotted paths: @pkg.decorator(...) + decorator_head_re = re.compile( + r"^\s*@\s*([A-Za-z_][A-Za-z0-9_]*(?:\.[A-Za-z_][A-Za-z0-9_]*)*)\b" + ) + + skipping = False + paren_depth = 0 + skip_base_name = None + for line in lines: + if skipping: + # Continue skipping decorator args until balanced + paren_depth += line.count("(") - line.count(")") + if paren_depth <= 0: + skipping = False + paren_depth = 0 + skip_base_name = None + continue + stripped = line.strip() if stripped.startswith("@"): - if ( - "use_experts_implementation" in stripped - or "use_kernel_forward_from_hub" in stripped - or "use_kernelized_func" in stripped - or stripped.startswith("@auto_docstring") - ): - decorator_name = stripped.split("(")[0].lstrip("@") - logger.info(f"Unsloth: stripped {decorator_name} decorator from {module}") - continue # Strip it - else: - logger.warning(f"Unsloth: Warning: Unknown decorator {stripped} found for {module}.") - new_lines.append(line) # Keep it + m = decorator_head_re.match(line) + if not m: + logger.warning( + f"Unsloth: Warning: Unparseable decorator {stripped} found for {module}." + ) + new_lines.append(line) + continue + + decorator_full = m.group(1) # e.g. "foo.auto_docstring" + decorator_base = decorator_full.split(".")[-1] # e.g. "auto_docstring" + + if decorator_base in STRIP_DECORATORS: + logger.info( + f"Unsloth: stripped {decorator_full} decorator from {module}" + ) + + # If decorator has args and spans multiple lines, skip until parens close + paren_depth = line.count("(") - line.count(")") + if paren_depth > 0: + skipping = True + skip_base_name = decorator_base + continue # Strip this decorator line + + # Unknown decorator -> keep it but warn + logger.warning( + f"Unsloth: Warning: Unknown decorator {stripped} found for {module}." + ) + new_lines.append(line) else: new_lines.append(line) - full_class = '\n'.join(new_lines) + class_def + full_class = "\n".join(new_lines) + class_def # Check if forward was replaced by a temporary patch (renamed function) # In this case, keep the patched source as-is and replace the class forward body. patched_forward_info = None - if "@torch.compiler.disable" in forward_source: + if 'gptossexperts' != module.lower(): func_match = re.search(r"def\s+(\w+)\s*\(", forward_source) if func_match and func_match.group(1) != "forward": # Find original forward in class to replace it