diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index 029cc9eedde9..e11f208455b1 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -932,52 +932,45 @@ def process_weights_after_loading_block_quant(self, layer: Module) -> None: will_use_deepgemm = self.is_deepgemm_moe_runner_backend_enabled() if self.is_fp4_expert: - if get_moe_runner_backend().is_marlin(): - layer.w13_weight.data = layer.w13_weight.data.view(torch.int8) - layer.w2_weight.data = layer.w2_weight.data.view(torch.int8) - elif not get_moe_runner_backend().is_flashinfer_mxfp4(): - raise NotImplementedError( - "DeepSeekV4 FP4 experts now require a native FP4 MoE backend. " - "Use `--moe-runner-backend marlin` on Hopper or " - "`--moe-runner-backend flashinfer_mxfp4` when available." + # FP4 experts support three MoE backends: + # - marlin (Hopper w4a16): only needs int8 view + # - flashinfer_mxfp4: only needs int8 view + # - deepgemm/auto (Blackwell): int8 view + mega_moe or scale conversion + layer.w13_weight.data = layer.w13_weight.data.view(torch.int8) + layer.w2_weight.data = layer.w2_weight.data.view(torch.int8) + + if envs.SGLANG_OPT_USE_DEEPGEMM_MEGA_MOE.get(): + from sglang.srt.models.deepseek_v4 import ( + build_mega_moe_experts_weights, ) - else: - layer.w13_weight.data = layer.w13_weight.data.view(torch.int8) - layer.w2_weight.data = layer.w2_weight.data.view(torch.int8) + build_mega_moe_experts_weights(layer) + return - if envs.SGLANG_OPT_USE_DEEPGEMM_MEGA_MOE.get(): - from sglang.srt.models.deepseek_v4 import ( - build_mega_moe_experts_weights, + if ( + envs.SGLANG_OPT_DEEPGEMM_SCALE_CONVERT_AT_INIT.get() + and envs.SGLANG_DSV4_MODE.get() == "2604" + and deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0 + and will_use_deepgemm + ): + from deep_gemm import transform_sf_into_required_layout + + for scale_param, weight_param in [ + (layer.w13_weight_scale_inv, layer.w13_weight), + (layer.w2_weight_scale_inv, layer.w2_weight), + ]: + num_experts, n, _ = scale_param.data.shape + k = weight_param.shape[2] * 2 + scale_param.data = transform_sf_into_required_layout( + scale_param.data, + mn=n, + k=k, + recipe=(1, 32), + num_groups=num_experts, + disable_ue8m0_cast=False, ) - - build_mega_moe_experts_weights(layer) - return - - if ( - envs.SGLANG_OPT_DEEPGEMM_SCALE_CONVERT_AT_INIT.get() - and envs.SGLANG_DSV4_MODE.get() == "2604" - and deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0 - and will_use_deepgemm - ): - from deep_gemm import transform_sf_into_required_layout - - for scale_param, weight_param in [ - (layer.w13_weight_scale_inv, layer.w13_weight), - (layer.w2_weight_scale_inv, layer.w2_weight), - ]: - num_experts, n, _ = scale_param.data.shape - k = weight_param.shape[2] * 2 - scale_param.data = transform_sf_into_required_layout( - scale_param.data, - mn=n, - k=k, - recipe=(1, 32), - num_groups=num_experts, - disable_ue8m0_cast=False, - ) - layer.w13_weight_scale_inv.format_ue8m0 = True - layer.w2_weight_scale_inv.format_ue8m0 = True + layer.w13_weight_scale_inv.format_ue8m0 = True + layer.w2_weight_scale_inv.format_ue8m0 = True if ( not self.is_fp4_expert