diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index 349ddd15f3ba..f86fcf413b10 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -915,86 +915,6 @@ def process_weights_after_loading(self, layer: Module) -> None: if _is_hip: self.process_weights_hip_scale_padding(layer) - - # Align FP8 weights to FlashInfer per-tensor kernel layout if enabled - if get_moe_runner_backend().is_flashinfer_trtllm(): - from flashinfer import reorder_rows_for_gated_act_gemm, shuffle_matrix_a - - # Note: No need to swap W13 halves, they are already in the correct order: [Gate, Up] - num_experts, two_n, hidden = layer.w13_weight.shape - - # 2) Reorder rows for fused gated activation (W13) - w13_interleaved = [ - reorder_rows_for_gated_act_gemm(layer.w13_weight[i]) - for i in range(num_experts) - ] - w13_interleaved = torch.stack(w13_interleaved).reshape( - num_experts, two_n, hidden - ) - - # 3) Shuffle weights for transposed MMA output (both W13, W2) - epilogue_tile_m = 128 - w13_shuffled = [ - shuffle_matrix_a( - w13_interleaved[i].view(torch.uint8), epilogue_tile_m - ) - for i in range(num_experts) - ] - w2_shuffled = [ - shuffle_matrix_a( - layer.w2_weight[i].view(torch.uint8), epilogue_tile_m - ) - for i in range(num_experts) - ] - - layer.w13_weight = Parameter( - torch.stack(w13_shuffled).view(torch.float8_e4m3fn), - requires_grad=False, - ) - layer.w2_weight = Parameter( - torch.stack(w2_shuffled).view(torch.float8_e4m3fn), - requires_grad=False, - ) - - # Precompute and register per-expert output scaling factors for FI MoE - # Note: w13_input_scale and w2_input_scale are scalar Parameters post-reduction - assert ( - hasattr(layer, "w13_input_scale") - and layer.w13_input_scale is not None - ) - assert ( - hasattr(layer, "w2_input_scale") - and layer.w2_input_scale is not None - ) - assert ( - hasattr(layer, "w13_weight_scale") - and layer.w13_weight_scale is not None - ) - assert ( - hasattr(layer, "w2_weight_scale") - and layer.w2_weight_scale is not None - ) - - input_scale = layer.w13_input_scale.to(torch.float32) - activation_scale = layer.w2_input_scale.to(torch.float32) - w13_weight_scale = layer.w13_weight_scale.to(torch.float32) - w2_weight_scale = layer.w2_weight_scale.to(torch.float32) - - output1_scales_scalar = ( - w13_weight_scale * input_scale * (1.0 / activation_scale) - ) - output1_scales_gate_scalar = w13_weight_scale * input_scale - output2_scales_scalar = activation_scale * w2_weight_scale - - layer.output1_scales_scalar = Parameter( - output1_scales_scalar, requires_grad=False - ) - layer.output1_scales_gate_scalar = Parameter( - output1_scales_gate_scalar, requires_grad=False - ) - layer.output2_scales_scalar = Parameter( - output2_scales_scalar, requires_grad=False - ) return def process_weights_hip_int4(self, layer: Module): @@ -1298,10 +1218,7 @@ def apply_with_router_logits( activation = self.moe_runner_config.activation routed_scaling_factor = self.moe_runner_config.routed_scaling_factor - from flashinfer.fused_moe import ( - trtllm_fp8_block_scale_moe, - trtllm_fp8_per_tensor_scale_moe, - ) + from flashinfer.fused_moe import trtllm_fp8_block_scale_moe from sglang.srt.layers.moe.topk import TopKOutputChecker from sglang.srt.layers.moe.utils import RoutingMethodType @@ -1312,15 +1229,9 @@ def apply_with_router_logits( assert ( activation == "silu" ), "Only silu is supported for flashinfer blockscale fp8 moe" - - if self.block_quant: - a_q, a_sf = per_token_group_quant_fp8( - x, self.quant_config.weight_block_size[1] - ) - # NOTE: scales of hidden states have to be transposed! - a_sf_t = a_sf.t().contiguous() - else: - a_q, _ = scaled_fp8_quant(x, layer.w13_input_scale) + a_q, a_sf = per_token_group_quant_fp8(x, self.quant_config.weight_block_size[1]) + # NOTE: scales of hidden states have to be transposed! + a_sf_t = a_sf.t().contiguous() correction_bias = ( None @@ -1328,79 +1239,43 @@ def apply_with_router_logits( else topk_config.correction_bias.to(x.dtype) ) - routing_method_type = getattr( - layer, "routing_method_type", RoutingMethodType.DeepSeekV3 - ) + routing_method_type = getattr(layer, "routing_method_type") with use_symmetric_memory( get_tp_group(), disabled=not is_allocation_symmetric() ): - if self.block_quant: - # FIXME: there is a bug in the trtllm_fp8_block_scale_moe. - # It ignored the `output`` argument. https://github.com/flashinfer-ai/flashinfer/blob/da01b1bd8f9f22aec8c0eea189ad54860b034947/flashinfer/fused_moe/core.py#L1323-L1325 - # so we put the whole function under the ``use_symmetric_memory`` context manager. - # If the bug is fixed, we can only put the output tensor allocation under the context manager. - return trtllm_fp8_block_scale_moe( - routing_logits=( - router_logits.to(torch.float32) - if routing_method_type == RoutingMethodType.DeepSeekV3 - else router_logits - ), - routing_bias=correction_bias, - hidden_states=a_q, - hidden_states_scale=a_sf_t, - gemm1_weights=layer.w13_weight, - gemm1_weights_scale=layer.w13_weight_scale_inv, - gemm2_weights=layer.w2_weight, - gemm2_weights_scale=layer.w2_weight_scale_inv, - num_experts=layer.num_experts, - top_k=topk_config.top_k, - n_group=topk_config.num_expert_group, - topk_group=topk_config.topk_group, - intermediate_size=layer.w2_weight.shape[2], - local_expert_offset=layer.moe_ep_rank * layer.num_local_experts, - local_num_experts=layer.num_local_experts, - routed_scaling_factor=( - routed_scaling_factor - if routed_scaling_factor is not None - else 1.0 - ), - tile_tokens_dim=None, - routing_method_type=routing_method_type, - use_shuffled_weight=False, - ) - else: - routing_bias_cast = ( - None - if correction_bias is None - else correction_bias.to(torch.bfloat16) - ) - - return trtllm_fp8_per_tensor_scale_moe( - routing_logits=router_logits.to(torch.bfloat16), - routing_bias=routing_bias_cast, - hidden_states=a_q, - gemm1_weights=layer.w13_weight, - output1_scales_scalar=layer.output1_scales_scalar, - output1_scales_gate_scalar=layer.output1_scales_gate_scalar, - gemm2_weights=layer.w2_weight, - output2_scales_scalar=layer.output2_scales_scalar, - num_experts=layer.num_experts, - top_k=topk_config.top_k, - n_group=topk_config.num_expert_group, - topk_group=topk_config.topk_group, - intermediate_size=layer.w2_weight.shape[2], - local_expert_offset=layer.moe_ep_rank * layer.num_local_experts, - local_num_experts=layer.num_local_experts, - routed_scaling_factor=( - routed_scaling_factor - if routed_scaling_factor is not None - else 1.0 - ), - use_routing_scales_on_input=False, - routing_method_type=routing_method_type, - ) + # FIXME: there is a bug in the trtllm_fp8_block_scale_moe. + # It ignored the `output`` argument. https://github.com/flashinfer-ai/flashinfer/blob/da01b1bd8f9f22aec8c0eea189ad54860b034947/flashinfer/fused_moe/core.py#L1323-L1325 + # so we put the whole function under the ``use_symmetric_memory`` context manager. + # If the bug is fixed, we can only put the output tensor allocation under the context manager. + return trtllm_fp8_block_scale_moe( + routing_logits=( + router_logits.to(torch.float32) + if routing_method_type == RoutingMethodType.DeepSeekV3 + else router_logits + ), + routing_bias=correction_bias, + hidden_states=a_q, + hidden_states_scale=a_sf_t, + gemm1_weights=layer.w13_weight, + gemm1_weights_scale=layer.w13_weight_scale_inv, + gemm2_weights=layer.w2_weight, + gemm2_weights_scale=layer.w2_weight_scale_inv, + num_experts=layer.num_experts, + top_k=topk_config.top_k, + n_group=topk_config.num_expert_group, + topk_group=topk_config.topk_group, + intermediate_size=layer.w2_weight.shape[2], + local_expert_offset=layer.moe_ep_rank * layer.num_local_experts, + local_num_experts=layer.num_local_experts, + routed_scaling_factor=( + routed_scaling_factor if routed_scaling_factor is not None else 1.0 + ), + tile_tokens_dim=None, + routing_method_type=routing_method_type, + use_shuffled_weight=False, + ) def maybe_apply_hip_fused_experts( self, diff --git a/python/sglang/srt/layers/quantization/fp8_utils.py b/python/sglang/srt/layers/quantization/fp8_utils.py index c6de3f4cb0ac..70914c9d3fd2 100644 --- a/python/sglang/srt/layers/quantization/fp8_utils.py +++ b/python/sglang/srt/layers/quantization/fp8_utils.py @@ -1014,9 +1014,3 @@ def validate_fp8_block_shape( f"{output_partition_size} is not divisible by " f"weight quantization block_n = {block_n}." ) - - -def expert_weight_is_col_major(x: torch.Tensor) -> bool: - assert x.dim() == 3 - b, m, n = x.shape - return x.stride(0) == m * n and x.stride(1) == 1 and x.stride(2) == m diff --git a/python/sglang/srt/layers/quantization/modelopt_quant.py b/python/sglang/srt/layers/quantization/modelopt_quant.py index ba31c12c6fdb..6b106379eec0 100755 --- a/python/sglang/srt/layers/quantization/modelopt_quant.py +++ b/python/sglang/srt/layers/quantization/modelopt_quant.py @@ -1565,27 +1565,10 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: w13_weight_scale_2 = layer.w13_weight_scale_2[:] layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2, requires_grad=False) - def _slice_scale(w): - assert w.shape == (layer.num_experts,) - assert layer.moe_ep_size * layer.num_local_experts == layer.num_experts - return w[ - layer.moe_ep_rank - * layer.num_local_experts : (layer.moe_ep_rank + 1) - * layer.num_local_experts - ] - # Calculate input scales based on strategy if self.enable_flashinfer_cutlass_moe or self.enable_flashinfer_trtllm_moe: - w13_input_scale = ( - layer.w13_input_scale.max().to(torch.float32).expand(layer.num_experts) - ) - w2_input_scale = ( - layer.w2_input_scale.max().to(torch.float32).expand(layer.num_experts) - ) - - if layer.moe_ep_size > 1: - w13_input_scale = _slice_scale(w13_input_scale) - w2_input_scale = _slice_scale(w2_input_scale) + w13_input_scale = layer.w13_input_scale.max().to(torch.float32) + w2_input_scale = layer.w2_input_scale.max().to(torch.float32) elif self.enable_flashinfer_cutedsl_moe: # All-expert-one-input-scale is mathematically different from default per-expert-input-scale # Thus we allow users to switch the flag to do thorough testing @@ -1602,6 +1585,15 @@ def _slice_scale(w): w2_input_scale = layer.w2_input_scale + def _slice_scale(w): + assert w.shape == (layer.num_experts,) + assert layer.moe_ep_size * layer.num_local_experts == layer.num_experts + return w[ + layer.moe_ep_rank + * layer.num_local_experts : (layer.moe_ep_rank + 1) + * layer.num_local_experts + ] + w13_input_scale = _slice_scale(w13_input_scale) w2_input_scale = _slice_scale(w2_input_scale) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index b3b2fa0d7e1b..9dc9fed76524 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -649,9 +649,7 @@ def __init__( layer_id=self.layer_id, quant_config=quant_config, routed_scaling_factor=self.routed_scaling_factor, - routing_method_type=getattr( - config, "routing_method_type", RoutingMethodType.DeepSeekV3 - ), + routing_method_type=RoutingMethodType.DeepSeekV3, prefix=add_prefix("experts", prefix), ) @@ -3349,7 +3347,6 @@ def forward( class DeepseekV2ForCausalLM(nn.Module): # for quark model load packed_modules_mapping = {} - model_cls = DeepseekV2Model def __init__( self, @@ -3376,7 +3373,7 @@ def __init__( self.quant_config = quant_config self.determine_num_fused_shared_experts() self.use_nsa = is_deepseek_nsa(config) - self.model = self.model_cls( + self.model = DeepseekV2Model( config, quant_config, prefix=add_prefix("model", prefix) ) if self.pp_group.is_last_rank: