Merged
40 changes: 34 additions & 6 deletions tests/test_upstream_pinned_symbols_accelerator.py
@@ -172,8 +172,15 @@ def test_moe_expert_merges_call_active_merge_device():
     through _active_merge_device(). A regression to a hardcoded "cuda" or
     DEVICE_TYPE_TORCH inside any one of them silently drops MPS/XPU
     placement and was the exact 2564f39 bug class.
+
+    After unsloth-zoo#647 the gate / up wrappers delegate to a unified
+    helper ``_merge_moe_gate_or_up_expert``; the check follows that
+    delegation by inspecting the union of each entry-point's source and
+    the source of any sibling ``_merge_moe_*`` helper it explicitly
+    forwards to.
     """
     import inspect
+    import re
     import unsloth_zoo.saving_utils as su
 
     targets = [
@@ -183,21 +190,42 @@ def test_moe_expert_merges_call_active_merge_device():
         "_merge_moe_fused_gate_up_expert",
         "_merge_moe_fused_down_proj_expert",
     ]
+    _helper_call_re = re.compile(r"\b(_merge_moe_[A-Za-z0-9_]+)\(")


medium

The regex for detecting helper calls is slightly fragile as it doesn't account for optional whitespace between the function name and the opening parenthesis. While the current codebase follows a strict style, making the regex more robust ensures the test doesn't fail on valid Python syntax variations (e.g., _merge_moe_helper (...)). This minor improvement is preferred as it increases robustness without requiring a large-scale refactor.

Suggested change
-    _helper_call_re = re.compile(r"\b(_merge_moe_[A-Za-z0-9_]+)\(")
+    _helper_call_re = re.compile(r"\b(_merge_moe_[A-Za-z0-9_]+)\s*\(")
References
  1. The codebase follows a strict style regarding whitespace in function calls.
  2. Fragile string-matching is acceptable for code patching only if it maintains consistency and avoids large-scale refactors.
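
To make the whitespace point concrete, here is a minimal sketch comparing the committed pattern with the suggested one. The `sample_src` string is a hypothetical wrapper written for illustration; it is not taken from `unsloth_zoo.saving_utils`.

```python
import re

# Pattern as committed in the diff, and the variant proposed in the review.
strict_re = re.compile(r"\b(_merge_moe_[A-Za-z0-9_]+)\(")
lenient_re = re.compile(r"\b(_merge_moe_[A-Za-z0-9_]+)\s*\(")

# Hypothetical wrapper source. The space before "(" in the call is valid
# Python, even if a strict formatter would normally remove it.
sample_src = (
    "def _merge_moe_gate_expert(*args):\n"
    "    return _merge_moe_gate_or_up_expert (*args)\n"
)

# Each pattern reports the _merge_moe_* names it sees followed by "(".
strict_hits = set(strict_re.findall(sample_src))
lenient_hits = set(lenient_re.findall(sample_src))

print(sorted(strict_hits))
print(sorted(lenient_hits))
```

With the strict pattern only the `def` line matches, so the delegated helper would never be followed. The lenient pattern picks up both names.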


+    def _effective_source(name: str, seen: set) -> str:
+        """Return the entry-point's source plus the source of any
+        sibling ``_merge_moe_*`` helper it explicitly forwards to.
+        One-level follow is enough: zoo never chains wrapper -> wrapper
+        -> helper, and the implementations all live in saving_utils."""
+        if name in seen:
+            return ""
+        seen.add(name)
+        fn = getattr(su, name, None)
+        if fn is None:
+            return ""
+        src = inspect.getsource(fn)
+        callees = set(_helper_call_re.findall(src)) - {name}
+        for callee in callees:
+            src += "\n" + _effective_source(callee, seen)
+        return src

     for name in targets:
         fn = getattr(su, name, None)
         assert fn is not None, (
             f"{name} missing; the MoE-expert merge dispatch surface "
             "shrank without notice — see commit 2564f39."
         )
-        src = inspect.getsource(fn)
+        src = _effective_source(name, set())
         assert "_active_merge_device(" in src, (


P2: Keep the active-device check on the entry point

Because _effective_source() concatenates every sibling helper before this assertion, any target can now satisfy the check via a callee even if the entry point itself starts doing its own device move with DEVICE_TYPE_TORCH or another non-helper device before delegating. That reopens the exact drift the test documents for non-thin helpers or wrappers that grow logic: the union still contains _active_merge_device( from _merge_moe_gate_or_up_expert, so the regression is missed unless it is a literal .to('cuda').
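
The gap described above can be reproduced with plain strings. This is a rough sketch and not what the PR implements: `ENTRY_SRC` and `HELPER_SRC` are hypothetical stand-ins for `inspect.getsource()` output, and `_bad_move_re`, `union_check`, and `entry_point_check` are names invented here for illustration.

```python
import re

# Hypothetical entry point that moves a tensor with DEVICE_TYPE_TORCH
# before delegating to the unified helper.
ENTRY_SRC = (
    "def _merge_moe_gate_expert(W, lora, scale):\n"
    "    W = W.to(DEVICE_TYPE_TORCH)\n"
    "    return _merge_moe_gate_or_up_expert(W, lora, scale)\n"
)

# Hypothetical helper that resolves the device correctly.
HELPER_SRC = (
    "def _merge_moe_gate_or_up_expert(W, lora, scale):\n"
    "    device = _active_merge_device()\n"
    "    return W.to(device)\n"
)

# Assumed pattern for a device move that bypasses the helper.
_bad_move_re = re.compile(r"\.to\(\s*(DEVICE_TYPE_TORCH|[\"']cuda[\"'])")


def union_check(entry_src: str, helper_src: str) -> bool:
    """Mirror the assertions in the diff, applied to the combined source."""
    combined = entry_src + "\n" + helper_src
    return (
        "_active_merge_device(" in combined
        and '.to("cuda"' not in combined
        and ".to('cuda'" not in combined
    )


def entry_point_check(entry_src: str) -> bool:
    """Extra check on the entry point alone: no direct device move."""
    return _bad_move_re.search(entry_src) is None


union_passes = union_check(ENTRY_SRC, HELPER_SRC)
entry_passes = entry_point_check(ENTRY_SRC)
print(union_passes, entry_passes)
```

The union check passes because the helper contributes `_active_merge_device(`. The entry-point check fails on the `DEVICE_TYPE_TORCH` move, which is the drift the comment is concerned about.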


-            f"{name} no longer routes through _active_merge_device(). "
-            "That regresses 2564f39 + fd58aa1: hardcoded 'cuda' breaks "
-            "Intel XPU and Apple MPS LoRA merge."
+            f"{name} (and any sibling _merge_moe_* it delegates to) no "
+            "longer routes through _active_merge_device(). That regresses "
+            "2564f39 + fd58aa1: hardcoded 'cuda' breaks Intel XPU and "
+            "Apple MPS LoRA merge."
         )
         assert '.to("cuda"' not in src and ".to('cuda'" not in src, (
-            f"{name} hardcodes .to('cuda', ...) again — same regression "
-            "class as commit 2564f39."
+            f"{name} (or the helper it delegates to) hardcodes "
+            ".to('cuda', ...) again — same regression class as commit "
+            "2564f39."
         )

