diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
index f7b886b39dea..8da7d8eef330 100644
--- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
+++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
@@ -1226,12 +1226,13 @@ def forward_impl(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
         symm_output = torch.empty(
             num_tokens, hidden_size, dtype=torch.bfloat16, device=hs_fp4.device
         )
-
         result = trtllm_fp4_block_scale_moe(
             routing_logits=router_logits,
             routing_bias=correction_bias,
             hidden_states=hs_fp4,
-            hidden_states_scale=hs_scale_linear.view(torch.float8_e4m3fn),
+            hidden_states_scale=hs_scale_linear.view(torch.float8_e4m3fn).reshape(
+                *hs_scale_linear.shape[:-1], -1
+            ),
             gemm1_weights=self.gemm1_weights_fp4_shuffled.data,
             gemm1_weights_scale=self.gemm1_scales_fp4_shuffled.data.view(
                 torch.float8_e4m3fn
diff --git a/python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py b/python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py
index 589f0c0edfbb..9deede2b8a6a 100644
--- a/python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py
+++ b/python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py
@@ -442,7 +442,9 @@ def fused_experts_none_to_flashinfer_trtllm_fp4(
         routing_logits=router_logits,
         routing_bias=correction_bias,
         hidden_states=hs_fp4,
-        hidden_states_scale=hs_scale_linear.view(torch.float8_e4m3fn).flatten(),
+        hidden_states_scale=hs_scale_linear.view(torch.float8_e4m3fn).reshape(
+            *hs_scale_linear.shape[:-1], -1
+        ),
         gemm1_weights=quant_info.gemm1_weights_fp4_shuffled,
         gemm1_weights_scale=quant_info.gemm1_scales_fp4_shuffled.view(
             torch.float8_e4m3fn
diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4_moe.py b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4_moe.py
index 01f909a3093f..5898a078dbba 100644
--- a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4_moe.py
+++ b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4_moe.py
@@ -331,7 +331,9 @@ def apply_weights(
             False,  # is_sf_swizzled_layout
         )
         hs_fp4 = hs_fp4_bytes.reshape(x.shape[0], x.shape[1] // 2)
-        hs_scale = hs_sf_bytes.view(torch.float8_e4m3fn).reshape(-1)
+        hs_scale = hs_sf_bytes.view(torch.float8_e4m3fn).reshape(
+            *hs_sf_bytes.shape[:-1], -1
+        )
         correction_bias = (
             None