-
Notifications
You must be signed in to change notification settings - Fork 266
tests: CPU regression detectors for the MoE merge / save path (#5410) #655
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||
|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,172 @@ | ||||||||
| # Unsloth Zoo - Utilities for Unsloth | ||||||||
| # Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved. | ||||||||
| # | ||||||||
| # This program is free software: you can redistribute it and/or modify | ||||||||
| # it under the terms of the GNU Affero General Public License as published by | ||||||||
| # the Free Software Foundation, either version 3 of the License, or | ||||||||
| # (at your option) any later version. | ||||||||
| # | ||||||||
| # This program is distributed in the hope that it will be useful, | ||||||||
| # but WITHOUT ANY WARRANTY; without even the implied warranty of | ||||||||
| # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | ||||||||
| # GNU Affero General Public License for more details. | ||||||||
| # | ||||||||
| # You should have received a copy of the GNU Affero General Public License | ||||||||
| # along with this program. If not, see <https://www.gnu.org/licenses/>. | ||||||||
|
|
||||||||
| """CPU end-to-end regression for per-expert MoE merge (#5410).""" | ||||||||
|
|
||||||||
| from __future__ import annotations | ||||||||
|
|
||||||||
| from dataclasses import dataclass | ||||||||
|
|
||||||||
| import pytest | ||||||||
| import torch | ||||||||
|
|
||||||||
| from unsloth_zoo.saving_utils import ( | ||||||||
| LoraStats, | ||||||||
| _MOE_MERGE_STATE, | ||||||||
| _detect_moe_lora_layout, | ||||||||
| _merge_moe_down_proj_expert, | ||||||||
| _merge_moe_gate_expert, | ||||||||
| _merge_moe_up_expert, | ||||||||
| _reset_moe_merge_state, | ||||||||
| _resolve_num_experts_from_lora_stats, | ||||||||
| ) | ||||||||
|
|
||||||||
|
|
||||||||
| SEED = 5410 | ||||||||
|
|
||||||||
|
|
||||||||
| @dataclass | ||||||||
| class _InnerMoE: | ||||||||
| num_experts: int | ||||||||
|
|
||||||||
|
|
||||||||
| @dataclass | ||||||||
| class _OuterParamWrapper: | ||||||||
| base_layer: object | ||||||||
|
|
||||||||
|
|
||||||||
| def _build_synthetic_layer(num_experts, rank_per, hidden, intermediate, layout, alpha, dtype=torch.float32): | ||||||||
| torch.manual_seed(SEED) | ||||||||
| TR = num_experts * rank_per | ||||||||
| fused_gate_up = torch.randn(num_experts, 2 * intermediate, hidden, dtype=dtype) | ||||||||
| fused_down = torch.randn(num_experts, hidden, intermediate, dtype=dtype) | ||||||||
| if layout == "swapped": | ||||||||
| A_gu = torch.randn(TR, 2 * intermediate, dtype=dtype) * 0.05 | ||||||||
| B_gu = torch.randn(hidden, TR, dtype=dtype) * 0.05 | ||||||||
| A_dn = torch.randn(TR, hidden, dtype=dtype) * 0.05 | ||||||||
| B_dn = torch.randn(intermediate, TR, dtype=dtype) * 0.05 | ||||||||
| elif layout == "standard": | ||||||||
| A_gu = torch.randn(TR, hidden, dtype=dtype) * 0.05 | ||||||||
| B_gu = torch.randn(2 * intermediate, TR, dtype=dtype) * 0.05 | ||||||||
| A_dn = torch.randn(TR, intermediate, dtype=dtype) * 0.05 | ||||||||
| B_dn = torch.randn(hidden, TR, dtype=dtype) * 0.05 | ||||||||
| else: | ||||||||
| raise ValueError(layout) | ||||||||
| return fused_gate_up, fused_down, A_gu, B_gu, A_dn, B_dn | ||||||||
|
|
||||||||
|
|
||||||||
| def _analytic_gate_up_delta(A, B, alpha, expert_idx, num_experts, role, layout, I, H): | ||||||||
| r = A.shape[0] // num_experts | ||||||||
| s, e = expert_idx * r, (expert_idx + 1) * r | ||||||||
| a = A[s:e].to(torch.float64); b = B[:, s:e].to(torch.float64) | ||||||||
| if layout == "swapped": | ||||||||
| half = a[:, :I] if role == "gate" else a[:, I:] | ||||||||
| return alpha * (b @ half).T | ||||||||
| half = b[:I, :] if role == "gate" else b[I:, :] | ||||||||
| return alpha * (half @ a) | ||||||||
|
|
||||||||
|
|
||||||||
| def _analytic_down_delta(A, B, alpha, expert_idx, num_experts, layout): | ||||||||
| r = A.shape[0] // num_experts | ||||||||
| s, e = expert_idx * r, (expert_idx + 1) * r | ||||||||
| a = A[s:e].to(torch.float64); b = B[:, s:e].to(torch.float64) | ||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Avoid using semicolons to place multiple statements on a single line, as per PEP 8 guidelines. a = A[s:e].to(torch.float64)
b = B[:, s:e].to(torch.float64)References
|
||||||||
| if layout == "swapped": | ||||||||
| return alpha * (b @ a).T | ||||||||
| return alpha * (b @ a) | ||||||||
|
|
||||||||
|
|
||||||||
| @pytest.mark.parametrize("layout", ["swapped", "standard"]) | ||||||||
| def test_per_layer_merge_round_trip(layout): | ||||||||
| num_layers, num_experts, rank_per = 2, 4, 4 | ||||||||
| hidden, intermediate = 12, 8 | ||||||||
| alpha = 8.0 | ||||||||
| dtype = torch.float32 | ||||||||
|
|
||||||||
| _reset_moe_merge_state() | ||||||||
| total_expected_apply = 0 | ||||||||
| max_err = 0.0 | ||||||||
| for layer in range(num_layers): | ||||||||
| fused_gu, fused_dn, A_gu, B_gu, A_dn, B_dn = _build_synthetic_layer( | ||||||||
| num_experts, rank_per, hidden, intermediate, layout, alpha, dtype | ||||||||
| ) | ||||||||
| stats_gu = LoraStats(module=_InnerMoE(num_experts), lora_A=A_gu, lora_B=B_gu, alpha=alpha) | ||||||||
| stats_dn = LoraStats(module=_InnerMoE(num_experts), lora_A=A_dn, lora_B=B_dn, alpha=alpha) | ||||||||
|
|
||||||||
| for ei in range(num_experts): | ||||||||
| gate_disk = fused_gu[ei, :intermediate, :].clone() | ||||||||
| up_disk = fused_gu[ei, intermediate:, :].clone() | ||||||||
| down_disk = fused_dn[ei].clone() | ||||||||
|
|
||||||||
| gate_out = _merge_moe_gate_expert(gate_disk, stats_gu, ei, num_experts, dtype) | ||||||||
| up_out = _merge_moe_up_expert (up_disk, stats_gu, ei, num_experts, dtype) | ||||||||
| down_out = _merge_moe_down_proj_expert(down_disk, stats_dn, ei, num_experts, dtype) | ||||||||
|
|
||||||||
| gate_ref = (fused_gu[ei, :intermediate, :].to(torch.float64) | ||||||||
| + _analytic_gate_up_delta(A_gu, B_gu, alpha, ei, num_experts, "gate", layout, intermediate, hidden)).to(dtype) | ||||||||
| up_ref = (fused_gu[ei, intermediate:, :].to(torch.float64) | ||||||||
| + _analytic_gate_up_delta(A_gu, B_gu, alpha, ei, num_experts, "up", layout, intermediate, hidden)).to(dtype) | ||||||||
| down_ref = (fused_dn[ei].to(torch.float64) | ||||||||
| + _analytic_down_delta(A_dn, B_dn, alpha, ei, num_experts, layout)).to(dtype) | ||||||||
|
|
||||||||
| for out, ref in ((gate_out, gate_ref), (up_out, up_ref), (down_out, down_ref)): | ||||||||
| err = (out.cpu() - ref.cpu()).abs().max().item() | ||||||||
| if err > max_err: max_err = err | ||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Using the max() function is more idiomatic and readable than a conditional assignment for updating a maximum value. Additionally, compound one-liner if statements are discouraged by PEP 8.
Suggested change
References
|
||||||||
| total_expected_apply += 1 | ||||||||
|
|
||||||||
| assert max_err < 1e-4, f"merge delta error too large: {max_err:.2e}" | ||||||||
| assert _MOE_MERGE_STATE["applied"] == total_expected_apply | ||||||||
| assert _MOE_MERGE_STATE["attempted"] == total_expected_apply | ||||||||
| assert _MOE_MERGE_STATE["fallback"] == 0 | ||||||||
| assert _MOE_MERGE_STATE["first_error"] is None | ||||||||
| _reset_moe_merge_state() | ||||||||
|
|
||||||||
|
|
||||||||
| def test_unrecognised_layout_records_fallback_and_first_error(): | ||||||||
| _reset_moe_merge_state() | ||||||||
| num_experts, rank_per, intermediate, hidden = 4, 4, 8, 12 | ||||||||
| TR = num_experts * rank_per | ||||||||
| W = torch.randn(intermediate, hidden) | ||||||||
| A = torch.randn(TR, hidden + 7); B = torch.randn(hidden, TR) | ||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Avoid using semicolons to place multiple statements on a single line, as per PEP 8 guidelines.
Suggested change
References
|
||||||||
| stats = LoraStats(module=_InnerMoE(num_experts), lora_A=A, lora_B=B, alpha=1.0) | ||||||||
| out = _merge_moe_gate_expert(W.clone(), stats, 0, num_experts, torch.float32) | ||||||||
| assert torch.equal(out.cpu(), W) | ||||||||
| assert _MOE_MERGE_STATE["fallback"] >= 1 | ||||||||
| err = _MOE_MERGE_STATE["first_error"] | ||||||||
| assert err is not None and err["role"] == "gate" | ||||||||
| assert err["lora_A_shape"] == (TR, hidden + 7) | ||||||||
| _reset_moe_merge_state() | ||||||||
|
|
||||||||
|
|
||||||||
| def test_resolver_walks_outer_wrapper_chain(): | ||||||||
| """Walks past outer ParamWrapper (.module=None) to inner num_experts.""" | ||||||||
| outer = _OuterParamWrapper(base_layer=_InnerMoE(num_experts=128)) | ||||||||
| stats = LoraStats(module=outer, lora_A=None, lora_B=None, alpha=0.0) | ||||||||
| assert _resolve_num_experts_from_lora_stats(stats, fallback=-1) == 128 | ||||||||
|
|
||||||||
|
|
||||||||
| def test_resolver_terminates_on_self_cycle(): | ||||||||
| class SelfCycle: pass | ||||||||
| sc = SelfCycle(); sc.base_layer = sc | ||||||||
| stats = LoraStats(module=sc, lora_A=None, lora_B=None, alpha=0.0) | ||||||||
| assert _resolve_num_experts_from_lora_stats(stats, fallback=42) == 42 | ||||||||
|
|
||||||||
|
|
||||||||
| def test_detector_is_stable_against_non_divisor_num_experts(): | ||||||||
| num_experts, rank_per, intermediate, hidden = 128, 4, 8, 12 | ||||||||
| TR = num_experts * rank_per | ||||||||
| A = torch.empty(TR, hidden); B = torch.empty(2 * intermediate, TR) | ||||||||
| layout, _ = _detect_moe_lora_layout(A, B, num_experts=17, out_dim=2*intermediate, in_dim=hidden) | ||||||||
| assert layout == "unknown" | ||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,108 @@ | ||
| # Unsloth Zoo - Utilities for Unsloth | ||
| # Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved. | ||
| # | ||
| # This program is free software: you can redistribute it and/or modify | ||
| # it under the terms of the GNU Affero General Public License as published by | ||
| # the Free Software Foundation, either version 3 of the License, or | ||
| # (at your option) any later version. | ||
| # | ||
| # This program is distributed in the hope that it will be useful, | ||
| # but WITHOUT ANY WARRANTY; without even the implied warranty of | ||
| # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | ||
| # GNU Affero General Public License for more details. | ||
| # | ||
| # You should have received a copy of the GNU Affero General Public License | ||
| # along with this program. If not, see <https://www.gnu.org/licenses/>. | ||
|
|
||
| """PEFT 3D-ParamWrapper layout drift canary (#5410). Fires if PEFT | ||
| introduces a third layout. CPU only.""" | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| import pytest | ||
| import torch | ||
| import torch.nn as nn | ||
|
|
||
| peft = pytest.importorskip("peft") | ||
| from peft import LoraConfig, get_peft_model | ||
|
|
||
|
|
||
| # 3D MoE fused parameter (num_experts, 2*intermediate, hidden). | ||
| NUM_EXPERTS = 4 | ||
| INTERMEDIATE = 8 | ||
| HIDDEN = 12 | ||
| TWO_INTER = 2 * INTERMEDIATE | ||
| PER_EXPERT_R = 4 | ||
| TOTAL_RANK = NUM_EXPERTS * PER_EXPERT_R | ||
|
|
||
|
|
||
| class _ToyMoE(nn.Module): | ||
| num_experts = NUM_EXPERTS | ||
|
|
||
| def __init__(self): | ||
| super().__init__() | ||
| self.gate_up_proj = nn.Parameter(torch.randn(NUM_EXPERTS, TWO_INTER, HIDDEN)) | ||
|
|
||
| def forward(self, x): | ||
| return torch.einsum("bh,eih->bei", x, self.gate_up_proj) | ||
|
|
||
|
|
||
| def _peft_supports_target_parameters() -> bool: | ||
| try: | ||
| LoraConfig(r=1, target_parameters=["dummy"]) | ||
| return True | ||
| except TypeError: | ||
| return False | ||
| except Exception: | ||
| return True | ||
|
|
||
|
|
||
| @pytest.mark.skipif(not _peft_supports_target_parameters(), | ||
| reason="PEFT < 0.18 lacks target_parameters") | ||
| def test_paramwrapper_lora_shape_is_one_of_two_known_layouts(): | ||
| torch.manual_seed(0) | ||
| base = _ToyMoE() | ||
|
|
||
| cfg_kwargs = dict(r=PER_EXPERT_R, lora_alpha=PER_EXPERT_R * 2, lora_dropout=0.0, bias="none") | ||
| try: | ||
| cfg = LoraConfig(target_parameters=["gate_up_proj"], **cfg_kwargs) | ||
| except TypeError: | ||
| pytest.skip("Installed PEFT does not accept target_parameters yet") | ||
|
|
||
| try: | ||
| peft_model = get_peft_model(base, cfg) | ||
| except Exception as e: | ||
| pytest.skip(f"PEFT failed to wrap fused 3D param on this build: {e}") | ||
|
|
||
| lora_A = lora_B = None | ||
| for name, p in peft_model.named_parameters(): | ||
| if name.endswith("lora_A.default") or name.endswith("lora_A.default.weight"): | ||
| lora_A = p | ||
| elif name.endswith("lora_B.default") or name.endswith("lora_B.default.weight"): | ||
| lora_B = p | ||
|
|
||
| assert lora_A is not None and lora_B is not None, ( | ||
| f"lora_A / lora_B not found in named_parameters: " | ||
| f"{[n for n, _ in peft_model.named_parameters()]}" | ||
| ) | ||
|
|
||
| A_shape, B_shape = tuple(lora_A.shape), tuple(lora_B.shape) | ||
| swapped = ((TOTAL_RANK, TWO_INTER), (HIDDEN, TOTAL_RANK)) | ||
| standard = ((TOTAL_RANK, HIDDEN), (TWO_INTER, TOTAL_RANK)) | ||
| observed = (A_shape, B_shape) | ||
| layout = "swapped" if observed == swapped else "standard" if observed == standard else "unknown" | ||
|
|
||
| assert layout != "unknown", ( | ||
| f"PEFT layout drift: peft={peft.__version__} A={A_shape} B={B_shape}; " | ||
| f"expected swapped={swapped} or standard={standard}. Update " | ||
| f"_detect_moe_lora_layout + merge math (#5410)." | ||
| ) | ||
| assert A_shape[0] // NUM_EXPERTS == PER_EXPERT_R | ||
|
|
||
|
|
||
| def test_zoo_detector_classifies_both_known_layouts(): | ||
| from unsloth_zoo.saving_utils import _detect_moe_lora_layout | ||
| A = torch.empty(TOTAL_RANK, TWO_INTER); B = torch.empty(HIDDEN, TOTAL_RANK) | ||
| assert _detect_moe_lora_layout(A, B, NUM_EXPERTS, TWO_INTER, HIDDEN) == ("swapped", PER_EXPERT_R) | ||
| A = torch.empty(TOTAL_RANK, HIDDEN); B = torch.empty(TWO_INTER, TOTAL_RANK) | ||
| assert _detect_moe_lora_layout(A, B, NUM_EXPERTS, TWO_INTER, HIDDEN) == ("standard", PER_EXPERT_R) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Avoid using semicolons to place multiple statements on a single line. Following PEP 8 guidelines by splitting these into separate lines improves code readability and maintainability.
References