diff --git a/src/axolotl/integrations/kernels/constants.py b/src/axolotl/integrations/kernels/constants.py index a03761484e..b0f1009301 100644 --- a/src/axolotl/integrations/kernels/constants.py +++ b/src/axolotl/integrations/kernels/constants.py @@ -41,6 +41,10 @@ "glm4v_moe": "Glm4vMoeTextMoE", # sigmoid -> topk routing (no group selection) "minimax_m2": "MiniMaxM2SparseMoeBlock", + # Non-GLU MoE (no gate_proj, experts have up_proj + down_proj only) + "nemotron_h": "NemotronHMoE", + # Models below need custom routing (not yet implemented): + # "deepseek_v2": "DeepseekV2Moe", # softmax->topk, group_limited_greedy, different attr names (num_group) # softmax->topk, e_score_correction_bias between softmax and topk "ernie4_5_moe": "Ernie4_5_MoeSparseMoeBlock", # softmax->topk, group_limited_greedy, different attr names (num_group) diff --git a/src/axolotl/integrations/kernels/libs/scattermoe_lora/layers.py b/src/axolotl/integrations/kernels/libs/scattermoe_lora/layers.py index c6c01e255a..c416e733bd 100644 --- a/src/axolotl/integrations/kernels/libs/scattermoe_lora/layers.py +++ b/src/axolotl/integrations/kernels/libs/scattermoe_lora/layers.py @@ -196,12 +196,14 @@ def _unwrap_experts_lora(experts_module): if num_experts is None: # Fallback: infer from parameter shape gup = getattr(base_experts, "gate_up_proj", None) + if gup is None: + gup = getattr(base_experts, "up_proj", None) if gup is not None: num_experts = gup.shape[0] - # Extract gate_up_proj LoRA (needs A<->B swap due to transposition) + # Extract gate_up_proj (or up_proj for non-GLU) LoRA gup_lora = None - gup_wrapper = wrappers.get("gate_up_proj") + gup_wrapper = wrappers.get("gate_up_proj") or wrappers.get("up_proj") if gup_wrapper is not None: lora_A, lora_B, scaling = get_lora_params_from_wrapper(gup_wrapper) if lora_A is not None: @@ -489,6 +491,21 @@ def forward(self: nn.Module, layer_input: torch.Tensor): # ==================================================================== experts, gup_lora, down_lora = _unwrap_experts_lora(self.experts) + # ==================================================================== + # Detect GLU vs non-GLU expert architecture + # ==================================================================== + # GLU models (Qwen, Mixtral, etc.): gate_up_proj [E, 2*I, H] + # Non-GLU models (Nemotron-H, etc.): up_proj [E, I, H] + has_glu = hasattr(experts, "gate_up_proj") + up_proj_name = "gate_up_proj" if has_glu else "up_proj" + + # ==================================================================== + # Optional latent projection before experts (e.g. Nemotron-H) + # ==================================================================== + fc1_latent = getattr(self, "fc1_latent_proj", None) + if fc1_latent is not None: + hidden_states_flat = fc1_latent(hidden_states_flat) + # ==================================================================== # Selective expert weight dequantization # ==================================================================== @@ -498,7 +515,7 @@ def forward(self: nn.Module, layer_input: torch.Tensor): use_selective = ( getattr(self, "_use_selective_dequant", False) and hasattr(experts, "parametrizations") - and "gate_up_proj" in experts.parametrizations + and up_proj_name in experts.parametrizations ) if use_selective: @@ -517,11 +534,11 @@ def forward(self: nn.Module, layer_input: torch.Tensor): num_experts, ) # Dequantize only active experts' weights - gate_up_W = selective_expert_weights( + up_W = selective_expert_weights( experts, - "gate_up_proj", + up_proj_name, active_experts, - ).transpose(2, 1) # [num_active, hidden, 2*inter] + ).transpose(2, 1) # Remap LoRA weights to match compact expert indices if gup_lora is not None: @@ -538,18 +555,18 @@ def forward(self: nn.Module, layer_input: torch.Tensor): sei_gup = remapped_expert_idxs eo_gup = compact_offsets else: - gate_up_W = experts.gate_up_proj.transpose(2, 1) # [E, hidden, 2*inter] + up_W = getattr(experts, up_proj_name).transpose(2, 1) sei_gup = sorted_expert_idxs eo_gup = expert_offsets # ==================================================================== - # Gate + Up projection + # Up projection (GLU: gate+up fused, non-GLU: up only) # ==================================================================== if gup_lora is not None: gup_A, gup_B, gup_scaling = gup_lora - gup = parallel_linear_lora( + up_out = parallel_linear_lora( hidden_states_flat, - gate_up_W, + up_W, top_k, sei_gup, sorted_scattered_idxs, @@ -563,9 +580,9 @@ def forward(self: nn.Module, layer_input: torch.Tensor): use_fused_gather=True, ) else: - gup = parallel_linear( + up_out = parallel_linear( hidden_states_flat, - gate_up_W, + up_W, top_k, sei_gup, sorted_scattered_idxs, @@ -574,8 +591,18 @@ def forward(self: nn.Module, layer_input: torch.Tensor): grouped_out=True, ) - gates, h = gup.chunk(2, dim=-1) - h = experts.act_fn(gates) * h + # GLU: split into gate and up, apply act_fn(gate) * up + # Non-GLU: apply act_fn directly + if has_glu: + gates, h = up_out.chunk(2, dim=-1) + h = experts.act_fn(gates) * h + else: + h = experts.act_fn(up_out) + + # Some activations (e.g. relu2) upcast to fp32 internally. + # Cast back to weight dtype for the down projection Triton kernel. + if h.dtype != experts.down_proj.dtype: + h = h.to(experts.down_proj.dtype) # ==================================================================== # Down projection @@ -635,6 +662,13 @@ def forward(self: nn.Module, layer_input: torch.Tensor): gates=routing_weights, ) + # ==================================================================== + # Optional latent projection after experts (e.g. Nemotron-H) + # ==================================================================== + fc2_latent = getattr(self, "fc2_latent_proj", None) + if fc2_latent is not None: + expert_output = fc2_latent(expert_output) + # ==================================================================== # Combine with shared expert and reshape # ==================================================================== diff --git a/src/axolotl/integrations/kernels/sonicmoe/patch.py b/src/axolotl/integrations/kernels/sonicmoe/patch.py index a3b96f12a1..a783ff54c1 100644 --- a/src/axolotl/integrations/kernels/sonicmoe/patch.py +++ b/src/axolotl/integrations/kernels/sonicmoe/patch.py @@ -39,6 +39,8 @@ def patch_sonicmoe(model_type: str, torch_compile: bool = False): torch_compile: If True, wrap routing functions with torch.compile for kernel fusion (fuses softmax+topk+renorm into fewer launches). """ + from sonicmoe.enums import is_glu + from .routing import get_model_moe_config from .weight_converter import register_sonicmoe_weight_converter @@ -49,7 +51,11 @@ def patch_sonicmoe(model_type: str, torch_compile: bool = False): for moe_cls in resolve_moe_block_classes(model_type): _patch_forward(moe_cls, routing_fn, activation, router_attr) - register_sonicmoe_weight_converter(model_type) + + # Weight interleaving only applies to GLU models (gate_up_proj). + # Non-GLU models have a plain up_proj that needs no conversion. + if is_glu(activation): + register_sonicmoe_weight_converter(model_type) def _try_compile_routing(routing_fn): @@ -98,6 +104,9 @@ def _patch_forward(moe_cls, routing_fn, activation, router_attr): def _make_general_forward(moe_cls, routing_fn, activation): """Create forward using routing_fn + moe_general_routing_inputs.""" + from sonicmoe.enums import is_glu + + glu_activation = is_glu(activation) def sonicmoe_forward(self, hidden_states: torch.Tensor) -> torch.Tensor: from sonicmoe import moe_general_routing_inputs @@ -105,7 +114,7 @@ def sonicmoe_forward(self, hidden_states: torch.Tensor) -> torch.Tensor: batch_size, sequence_length, hidden_dim = hidden_states.shape hidden_states_flat = hidden_states.view(-1, hidden_dim) - # Shared expert (computed early, matching original model ordering) + # Shared expert shared_expert_output = _compute_shared_expert(self, hidden_states_flat) # Routing @@ -113,28 +122,42 @@ def sonicmoe_forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states_flat, self ) - # Permute weights to SonicMoE layout: - # gate_up: [E, 2*I, H] -> [2*I, H, E] - # down: [E, H, I] -> [H, I, E] - gate_up_weight = self.experts.gate_up_proj.permute(1, 2, 0) + # Optional latent projection before experts (e.g. Nemotron-H) + expert_input = hidden_states_flat + fc1_latent = getattr(self, "fc1_latent_proj", None) + if fc1_latent is not None: + expert_input = fc1_latent(expert_input) + + # Permute weights to SonicMoE layout. + # GLU models: gate_up_proj [E, 2*I, H] -> [2*I, H, E] + # Non-GLU: up_proj [E, I, H] -> [I, H, E] + if glu_activation: + up_weight = self.experts.gate_up_proj.permute(1, 2, 0) + else: + up_weight = self.experts.up_proj.permute(1, 2, 0) down_weight = self.experts.down_proj.permute(1, 2, 0) - E = gate_up_weight.shape[-1] + E = up_weight.shape[-1] output, _ = moe_general_routing_inputs( - hidden_states_flat, + expert_input, router_scores, token_indices, expert_indices, - gate_up_weight, - None, # b1 (no gate/up bias) + up_weight, + None, # b1 (no bias) down_weight, - None, # b2 (no down bias) + None, # b2 (no bias) E, torch.cuda.current_stream().cuda_stream, activation, False, # is_inference_mode ) + # Optional latent projection after experts (e.g. Nemotron-H) + fc2_latent = getattr(self, "fc2_latent_proj", None) + if fc2_latent is not None: + output = fc2_latent(output) + # Add shared expert contribution if present if shared_expert_output is not None: if hasattr(self, "shared_expert_gate"): @@ -151,6 +174,9 @@ def sonicmoe_forward(self, hidden_states: torch.Tensor) -> torch.Tensor: def _make_fused_forward(moe_cls, activation, router_attr): """Create forward using moe_TC_softmax_topk_layer (topk -> softmax).""" + from sonicmoe.enums import is_glu + + glu_activation = is_glu(activation) def sonicmoe_fused_forward(self, hidden_states: torch.Tensor) -> torch.Tensor: from sonicmoe import moe_TC_softmax_topk_layer @@ -158,30 +184,44 @@ def sonicmoe_fused_forward(self, hidden_states: torch.Tensor) -> torch.Tensor: batch_size, sequence_length, hidden_dim = hidden_states.shape hidden_states_flat = hidden_states.view(-1, hidden_dim) - # Shared expert (computed early, matching original model ordering) + # Shared expert shared_expert_output = _compute_shared_expert(self, hidden_states_flat) router = getattr(self, router_attr) - # Permute weights to SonicMoE layout: - # gate_up: [E, 2*I, H] -> [2*I, H, E] - # down: [E, H, I] -> [H, I, E] - gate_up_weight = self.experts.gate_up_proj.permute(1, 2, 0) + # Optional latent projection before experts (e.g. Nemotron-H) + expert_input = hidden_states_flat + fc1_latent = getattr(self, "fc1_latent_proj", None) + if fc1_latent is not None: + expert_input = fc1_latent(expert_input) + + # Permute weights to SonicMoE layout. + # GLU models: gate_up_proj [E, 2*I, H] -> [2*I, H, E] + # Non-GLU: up_proj [E, I, H] -> [I, H, E] + if glu_activation: + up_weight = self.experts.gate_up_proj.permute(1, 2, 0) + else: + up_weight = self.experts.up_proj.permute(1, 2, 0) down_weight = self.experts.down_proj.permute(1, 2, 0) output, _router_logits, _expert_freq = moe_TC_softmax_topk_layer( - hidden_states_flat, + expert_input, router.weight, - gate_up_weight, - None, # b1 (no gate/up bias) + up_weight, + None, # b1 (no bias) down_weight, - None, # b2 (no down bias) + None, # b2 (no bias) router.top_k, torch.cuda.current_stream().cuda_stream, activation, False, # is_inference_mode ) + # Optional latent projection after experts (e.g. Nemotron-H) + fc2_latent = getattr(self, "fc2_latent_proj", None) + if fc2_latent is not None: + output = fc2_latent(output) + # Add shared expert contribution if present if shared_expert_output is not None: if hasattr(self, "shared_expert_gate"): diff --git a/src/axolotl/integrations/kernels/sonicmoe/routing.py b/src/axolotl/integrations/kernels/sonicmoe/routing.py index 09bffc7421..7a8cdf31b8 100644 --- a/src/axolotl/integrations/kernels/sonicmoe/routing.py +++ b/src/axolotl/integrations/kernels/sonicmoe/routing.py @@ -59,6 +59,13 @@ def get_model_moe_config(model_type: str): "minimax_m2", ): return sigmoid_topk_routing, ActivationType.SWIGLU, "gate" + # Non-GLU MoE (no gate_proj, experts use up_proj + down_proj only) + elif model_type in ("nemotron_h",): + return sigmoid_topk_routing, ActivationType.RELU_SQ, "gate" + # elif model_type in ("deepseek_v2",): + # # Softmax→topk with group_limited_greedy. Different attr names: num_group + # # (not n_group), gate is nn.Linear (not a router class). + # return ..., ActivationType.SWIGLU, "gate" elif model_type in ("ernie4_5_moe",): return softmax_bias_topk_routing, ActivationType.SWIGLU, "gate" elif model_type in ("hunyuan_v1_moe",): diff --git a/tests/e2e/integrations/test_scattermoe_lora_kernels.py b/tests/e2e/integrations/test_scattermoe_lora_kernels.py index c204a15035..9b5208e636 100644 --- a/tests/e2e/integrations/test_scattermoe_lora_kernels.py +++ b/tests/e2e/integrations/test_scattermoe_lora_kernels.py @@ -13,6 +13,7 @@ 4. Various configurations: top-k, grouped_in/out, with/without bias 5. Numerical stability: bf16/fp16 outputs within tolerance of fp32 reference 6. HFScatterMoEGatedMLP with sigmoid routing (GLM/DeepSeek/MiniMax M2) +7. Non-GLU MoE forward (Nemotron-H style: up_proj + relu2 + down_proj) Test strategy: - Reference implementation uses pure PyTorch ops (no Triton) @@ -1858,3 +1859,286 @@ def test_shared_expert_with_gate(self): hidden = torch.randn(1, T, H, device="cuda") output = HFScatterMoEGatedMLP.forward(moe_block, hidden) assert output.shape == (1, T, H) + + +# ============================================================================= +# Test: Non-GLU MoE forward (Nemotron-H style) +# ============================================================================= + + +def _reference_non_glu_moe_forward( + hidden_states, + up_proj, + down_proj, + act_fn, + routing_weights, + selected_experts, + num_experts, +): + """Pure PyTorch reference for a non-GLU MoE forward pass. + + Non-GLU experts have up_proj [E, I, H] and down_proj [E, H, I] with + a direct activation (no gate/up split). Forward per expert: + output = down_proj(act_fn(up_proj(x))) + + Args: + hidden_states: [T, H] + up_proj: [E, I, H] + down_proj: [E, H, I] + act_fn: activation function (e.g. relu2) + routing_weights: [T, K] routing weights + selected_experts: [T, K] expert indices + num_experts: int + + Returns: + output: [T, H] + """ + T, H = hidden_states.shape + K = selected_experts.shape[1] + output = torch.zeros(T, H, device=hidden_states.device, dtype=hidden_states.dtype) + + for t in range(T): + for j in range(K): + e = selected_experts[t, j].item() + w = routing_weights[t, j].item() + + # up projection: [I] + up_out = hidden_states[t] @ up_proj[e].T + + # activation (no gating, just direct activation) + h = act_fn(up_out) + + # Cast back to weight dtype (relu2 upcasts to fp32) + if h.dtype != down_proj.dtype: + h = h.to(down_proj.dtype) + + # down projection: [H] + out = h @ down_proj[e].T + + output[t] += w * out + + return output + + +class _ReLU2(torch.nn.Module): + """ReLU squared activation (relu2): relu(x)^2.""" + + def forward(self, x): + return torch.relu(x).square() + + +def _make_mock_non_glu_moe_block(T=16, H=64, FF=32, E=8, K=2, n_group=2, topk_group=1): + """Create a mock non-GLU MoE block (Nemotron-H style) for GPU testing. + + Non-GLU: experts have up_proj [E, I, H] and down_proj [E, H, I], + no gate_up_proj. Uses relu2 activation. + """ + up_proj = torch.randn(E, FF, H, device="cuda") * 0.02 + down_proj = torch.randn(E, H, FF, device="cuda") * 0.02 + act_fn = _ReLU2() + + experts = SimpleNamespace( + up_proj=up_proj, + down_proj=down_proj, + act_fn=act_fn, + num_experts=E, + ) + + gate = SimpleNamespace( + weight=torch.randn(E, H, device="cuda") * 0.1, + e_score_correction_bias=torch.zeros(E, device="cuda"), + ) + moe_block = SimpleNamespace( + gate=gate, + experts=experts, + top_k=K, + n_routed_experts=E, + n_group=n_group, + topk_group=topk_group, + norm_topk_prob=True, + routed_scaling_factor=1.0, + ) + + return moe_block, T, H, FF, E, K + + +@pytest.mark.slow +class TestHFScatterMoENonGLU: + """Test HFScatterMoEGatedMLP forward with non-GLU experts (Nemotron-H style).""" + + def test_forward_matches_reference(self): + """Non-GLU forward pass with sigmoid routing matches reference.""" + from axolotl.integrations.kernels.libs.scattermoe_lora.layers import ( + HFScatterMoEGatedMLP, + _sigmoid_topk_route, + ) + + moe_block, T, H, FF, E, K = _make_mock_non_glu_moe_block( + T=16, H=64, FF=32, E=8, K=2, n_group=2, topk_group=1 + ) + + hidden = torch.randn(1, T, H, device="cuda") + + # Get routing for reference + gate = moe_block.gate + hidden_flat = hidden.view(-1, H) + routing_weights, selected_experts, _, _ = _sigmoid_topk_route( + moe_block, gate, hidden_flat, gate.weight, None + ) + + # Reference output (non-GLU) + ref_output = _reference_non_glu_moe_forward( + hidden_flat, + moe_block.experts.up_proj, + moe_block.experts.down_proj, + moe_block.experts.act_fn, + routing_weights, + selected_experts, + E, + ) + + # Kernel output + kernel_output = HFScatterMoEGatedMLP.forward(moe_block, hidden) + kernel_output_flat = kernel_output.view(-1, H) + + torch.testing.assert_close( + kernel_output_flat.float(), + ref_output.float(), + atol=5e-2, + rtol=5e-2, + ) + + def test_forward_with_softmax_routing(self): + """Non-GLU forward with softmax routing (hypothetical non-GLU softmax model).""" + from axolotl.integrations.kernels.libs.scattermoe_lora.layers import ( + HFScatterMoEGatedMLP, + _softmax_topk_route, + ) + + T, H, FF, E, K = 16, 64, 32, 4, 2 + up_proj = torch.randn(E, FF, H, device="cuda") * 0.02 + down_proj = torch.randn(E, H, FF, device="cuda") * 0.02 + act_fn = torch.nn.ReLU() + + experts = SimpleNamespace( + up_proj=up_proj, + down_proj=down_proj, + act_fn=act_fn, + num_experts=E, + ) + gate = SimpleNamespace( + weight=torch.randn(E, H, device="cuda") * 0.1, + top_k=K, + num_experts=E, + norm_topk_prob=True, + ) + moe_block = SimpleNamespace(gate=gate, experts=experts) + + hidden = torch.randn(1, T, H, device="cuda") + hidden_flat = hidden.view(-1, H) + + routing_weights, selected_experts, _, _ = _softmax_topk_route( + moe_block, gate, hidden_flat, gate.weight, None + ) + + ref_output = _reference_non_glu_moe_forward( + hidden_flat, + up_proj, + down_proj, + act_fn, + routing_weights, + selected_experts, + E, + ) + + kernel_output = HFScatterMoEGatedMLP.forward(moe_block, hidden) + kernel_output_flat = kernel_output.view(-1, H) + + torch.testing.assert_close( + kernel_output_flat.float(), + ref_output.float(), + atol=5e-2, + rtol=5e-2, + ) + + def test_relu2_dtype_cast(self): + """relu2 activation upcasts to fp32; verify output is still correct in bf16.""" + from axolotl.integrations.kernels.libs.scattermoe_lora.layers import ( + HFScatterMoEGatedMLP, + ) + + T, H, FF, E, K = 8, 64, 32, 4, 2 + up_proj = torch.randn(E, FF, H, device="cuda", dtype=torch.bfloat16) * 0.02 + down_proj = torch.randn(E, H, FF, device="cuda", dtype=torch.bfloat16) * 0.02 + act_fn = _ReLU2() + + experts = SimpleNamespace( + up_proj=up_proj, + down_proj=down_proj, + act_fn=act_fn, + num_experts=E, + ) + gate = SimpleNamespace( + weight=torch.randn(E, H, device="cuda", dtype=torch.bfloat16) * 0.1, + top_k=K, + num_experts=E, + norm_topk_prob=True, + ) + moe_block = SimpleNamespace(gate=gate, experts=experts) + + hidden = torch.randn(1, T, H, device="cuda", dtype=torch.bfloat16) + + # Should not raise despite relu2 upcasting to fp32 internally + output = HFScatterMoEGatedMLP.forward(moe_block, hidden) + assert output.shape == (1, T, H) + assert torch.isfinite(output).all() + + def test_forward_with_latent_projections(self): + """Non-GLU forward with fc1_latent_proj / fc2_latent_proj (Nemotron-H).""" + from axolotl.integrations.kernels.libs.scattermoe_lora.layers import ( + HFScatterMoEGatedMLP, + ) + + T, H, FF, E, K = 8, 64, 32, 4, 2 + LATENT = 48 # intermediate latent dim + + up_proj = torch.randn(E, FF, LATENT, device="cuda") * 0.02 + down_proj = torch.randn(E, LATENT, FF, device="cuda") * 0.02 + act_fn = _ReLU2() + + experts = SimpleNamespace( + up_proj=up_proj, + down_proj=down_proj, + act_fn=act_fn, + num_experts=E, + ) + + # Latent projections: H -> LATENT before experts, LATENT -> H after + fc1_latent = torch.nn.Linear(H, LATENT, bias=False).cuda() + fc2_latent = torch.nn.Linear(LATENT, H, bias=False).cuda() + + gate = SimpleNamespace( + weight=torch.randn(E, H, device="cuda") * 0.1, + top_k=K, + num_experts=E, + norm_topk_prob=True, + ) + moe_block = SimpleNamespace( + gate=gate, + experts=experts, + fc1_latent_proj=fc1_latent, + fc2_latent_proj=fc2_latent, + ) + + hidden = torch.randn(1, T, H, device="cuda") + + output = HFScatterMoEGatedMLP.forward(moe_block, hidden) + assert output.shape == (1, T, H) + assert torch.isfinite(output).all() + + def test_non_glu_no_gate_up_proj_attribute(self): + """Verify non-GLU block does NOT have gate_up_proj on experts.""" + moe_block, T, H, FF, E, K = _make_mock_non_glu_moe_block() + assert not hasattr(moe_block.experts, "gate_up_proj") + assert hasattr(moe_block.experts, "up_proj") + assert hasattr(moe_block.experts, "down_proj") diff --git a/tests/integrations/test_scattermoe_lora.py b/tests/integrations/test_scattermoe_lora.py index bd50d06feb..de2880d1c3 100644 --- a/tests/integrations/test_scattermoe_lora.py +++ b/tests/integrations/test_scattermoe_lora.py @@ -13,6 +13,7 @@ - HFScatterMoEGatedMLP / ScatterMoEGatedMLP: return value contract - Routing strategy detection and sigmoid routing - Generic shared expert handling +- Non-GLU expert detection and activation path (Nemotron-H) """ from types import SimpleNamespace @@ -709,3 +710,54 @@ def test_no_shared_expert(self): moe_block = SimpleNamespace() result = _compute_shared_expert(moe_block, torch.randn(4, 8)) assert result is None + + +# ============================================================================ +# 9. Non-GLU expert detection and activation path +# ============================================================================ + + +class TestNonGLUExpertDetection: + """Test GLU vs non-GLU expert architecture detection in HFScatterMoEGatedMLP.""" + + def test_glu_detected_when_gate_up_proj_exists(self): + """Experts with gate_up_proj should be detected as GLU.""" + experts = SimpleNamespace( + gate_up_proj=torch.randn(4, 64, 16), + down_proj=torch.randn(4, 16, 32), + act_fn=torch.nn.SiLU(), + ) + assert hasattr(experts, "gate_up_proj") + assert not hasattr(experts, "up_proj") + + def test_non_glu_detected_when_only_up_proj(self): + """Experts with up_proj (no gate_up_proj) should be detected as non-GLU.""" + experts = SimpleNamespace( + up_proj=torch.randn(4, 32, 16), + down_proj=torch.randn(4, 16, 32), + act_fn=torch.nn.ReLU(), + ) + has_glu = hasattr(experts, "gate_up_proj") + up_proj_name = "gate_up_proj" if has_glu else "up_proj" + assert not has_glu + assert up_proj_name == "up_proj" + + def test_unwrap_experts_lora_fallback_to_up_proj(self): + """_unwrap_experts_lora falls back to up_proj when gate_up_proj is absent.""" + _skip_without_triton() + from axolotl.integrations.kernels.libs.scattermoe_lora.layers import ( + _unwrap_experts_lora, + ) + + # Non-GLU experts: only up_proj, no gate_up_proj + experts = SimpleNamespace( + up_proj=torch.randn(4, 32, 16), + down_proj=torch.randn(4, 16, 32), + act_fn=torch.nn.ReLU(), + num_experts=4, + ) + base_experts, gup_lora, down_lora = _unwrap_experts_lora(experts) + # Should return the experts unchanged with no LoRA + assert base_experts is experts + assert gup_lora is None + assert down_lora is None diff --git a/tests/integrations/test_sonicmoe.py b/tests/integrations/test_sonicmoe.py index 7d26d9d936..2006b69e2f 100644 --- a/tests/integrations/test_sonicmoe.py +++ b/tests/integrations/test_sonicmoe.py @@ -786,3 +786,73 @@ def test_uses_gate_wg_weight(self): scores2, _, _, _ = softmax_topk_wg_routing(hidden, moe_block) assert not torch.equal(scores1, scores2) + + +# ============================================================================ +# Non-GLU model config (Nemotron-H) +# ============================================================================ + + +class TestNemotronHMoEConfig: + """Test routing config for Nemotron-H (non-GLU, relu2 activation).""" + + def test_get_model_moe_config_returns_relu_sq(self): + """nemotron_h should map to RELU_SQ activation, not SWIGLU.""" + sonicmoe_enums = pytest.importorskip("sonicmoe.enums") + ActivationType = sonicmoe_enums.ActivationType + is_glu = sonicmoe_enums.is_glu + + from axolotl.integrations.kernels.sonicmoe.routing import get_model_moe_config + + routing_fn, activation, router_attr = get_model_moe_config("nemotron_h") + + assert activation == ActivationType.RELU_SQ + assert not is_glu(activation) + assert router_attr == "gate" + assert routing_fn is sigmoid_topk_routing + + def test_non_glu_skips_weight_converter(self): + """patch_sonicmoe should NOT register weight converter for non-GLU models. + + Non-GLU models have plain up_proj, so there's nothing to interleave. + This test verifies the is_glu() guard in patch_sonicmoe(). + """ + sonicmoe_enums = pytest.importorskip("sonicmoe.enums") + ActivationType = sonicmoe_enums.ActivationType + is_glu = sonicmoe_enums.is_glu + + # RELU_SQ is non-GLU + assert not is_glu(ActivationType.RELU_SQ) + # SWIGLU is GLU + assert is_glu(ActivationType.SWIGLU) + + def test_nemotron_h_in_constants(self): + """nemotron_h should be in the SPARSE_MOE_BLOCK mapping.""" + from axolotl.integrations.kernels.constants import SPARSE_MOE_BLOCK + + assert "nemotron_h" in SPARSE_MOE_BLOCK + assert SPARSE_MOE_BLOCK["nemotron_h"] == "NemotronHMoE" + + def test_nemotron_h_routing_shapes(self): + """Verify sigmoid_topk_routing works with Nemotron-H-style block.""" + T, H, E, K = 8, 16, 8, 2 + gate = SimpleNamespace( + weight=torch.randn(E, H), + e_score_correction_bias=torch.zeros(E), + ) + moe_block = SimpleNamespace( + gate=gate, + top_k=K, + n_routed_experts=E, + n_group=1, + norm_topk_prob=True, + routed_scaling_factor=1.0, + ) + hidden = torch.randn(T, H) + + scores, token_idx, expert_idx, logits = sigmoid_topk_routing(hidden, moe_block) + + assert scores.shape == (T * K,) + assert token_idx.shape == (T * K,) + assert expert_idx.shape == (T * K,) + assert logits.shape == (T, E)