Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
197 changes: 36 additions & 161 deletions python/sglang/srt/layers/quantization/fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -915,86 +915,6 @@ def process_weights_after_loading(self, layer: Module) -> None:

if _is_hip:
self.process_weights_hip_scale_padding(layer)

# Align FP8 weights to FlashInfer per-tensor kernel layout if enabled
if get_moe_runner_backend().is_flashinfer_trtllm():
from flashinfer import reorder_rows_for_gated_act_gemm, shuffle_matrix_a

# Note: No need to swap W13 halves, they are already in the correct order: [Gate, Up]
num_experts, two_n, hidden = layer.w13_weight.shape

# 2) Reorder rows for fused gated activation (W13)
w13_interleaved = [
reorder_rows_for_gated_act_gemm(layer.w13_weight[i])
for i in range(num_experts)
]
w13_interleaved = torch.stack(w13_interleaved).reshape(
num_experts, two_n, hidden
)

# 3) Shuffle weights for transposed MMA output (both W13, W2)
epilogue_tile_m = 128
w13_shuffled = [
shuffle_matrix_a(
w13_interleaved[i].view(torch.uint8), epilogue_tile_m
)
for i in range(num_experts)
]
w2_shuffled = [
shuffle_matrix_a(
layer.w2_weight[i].view(torch.uint8), epilogue_tile_m
)
for i in range(num_experts)
]

layer.w13_weight = Parameter(
torch.stack(w13_shuffled).view(torch.float8_e4m3fn),
requires_grad=False,
)
layer.w2_weight = Parameter(
torch.stack(w2_shuffled).view(torch.float8_e4m3fn),
requires_grad=False,
)

# Precompute and register per-expert output scaling factors for FI MoE
# Note: w13_input_scale and w2_input_scale are scalar Parameters post-reduction
assert (
hasattr(layer, "w13_input_scale")
and layer.w13_input_scale is not None
)
assert (
hasattr(layer, "w2_input_scale")
and layer.w2_input_scale is not None
)
assert (
hasattr(layer, "w13_weight_scale")
and layer.w13_weight_scale is not None
)
assert (
hasattr(layer, "w2_weight_scale")
and layer.w2_weight_scale is not None
)

input_scale = layer.w13_input_scale.to(torch.float32)
activation_scale = layer.w2_input_scale.to(torch.float32)
w13_weight_scale = layer.w13_weight_scale.to(torch.float32)
w2_weight_scale = layer.w2_weight_scale.to(torch.float32)

output1_scales_scalar = (
w13_weight_scale * input_scale * (1.0 / activation_scale)
)
output1_scales_gate_scalar = w13_weight_scale * input_scale
output2_scales_scalar = activation_scale * w2_weight_scale

layer.output1_scales_scalar = Parameter(
output1_scales_scalar, requires_grad=False
)
layer.output1_scales_gate_scalar = Parameter(
output1_scales_gate_scalar, requires_grad=False
)
layer.output2_scales_scalar = Parameter(
output2_scales_scalar, requires_grad=False
)
return

def process_weights_hip_int4(self, layer: Module):
Expand Down Expand Up @@ -1298,10 +1218,7 @@ def apply_with_router_logits(
activation = self.moe_runner_config.activation
routed_scaling_factor = self.moe_runner_config.routed_scaling_factor

from flashinfer.fused_moe import (
trtllm_fp8_block_scale_moe,
trtllm_fp8_per_tensor_scale_moe,
)
from flashinfer.fused_moe import trtllm_fp8_block_scale_moe

from sglang.srt.layers.moe.topk import TopKOutputChecker
from sglang.srt.layers.moe.utils import RoutingMethodType
Expand All @@ -1312,95 +1229,53 @@ def apply_with_router_logits(
assert (
activation == "silu"
), "Only silu is supported for flashinfer blockscale fp8 moe"

if self.block_quant:
a_q, a_sf = per_token_group_quant_fp8(
x, self.quant_config.weight_block_size[1]
)
# NOTE: scales of hidden states have to be transposed!
a_sf_t = a_sf.t().contiguous()
else:
a_q, _ = scaled_fp8_quant(x, layer.w13_input_scale)
a_q, a_sf = per_token_group_quant_fp8(x, self.quant_config.weight_block_size[1])
# NOTE: scales of hidden states have to be transposed!
a_sf_t = a_sf.t().contiguous()

correction_bias = (
None
if topk_config.correction_bias is None
else topk_config.correction_bias.to(x.dtype)
)

routing_method_type = getattr(
layer, "routing_method_type", RoutingMethodType.DeepSeekV3
)
routing_method_type = getattr(layer, "routing_method_type")

with use_symmetric_memory(
get_tp_group(), disabled=not is_allocation_symmetric()
):

if self.block_quant:
# FIXME: there is a bug in the trtllm_fp8_block_scale_moe.
# It ignored the `output` argument. https://github.com/flashinfer-ai/flashinfer/blob/da01b1bd8f9f22aec8c0eea189ad54860b034947/flashinfer/fused_moe/core.py#L1323-L1325
# so we put the whole function under the ``use_symmetric_memory`` context manager.
# If the bug is fixed, we can only put the output tensor allocation under the context manager.
return trtllm_fp8_block_scale_moe(
routing_logits=(
router_logits.to(torch.float32)
if routing_method_type == RoutingMethodType.DeepSeekV3
else router_logits
),
routing_bias=correction_bias,
hidden_states=a_q,
hidden_states_scale=a_sf_t,
gemm1_weights=layer.w13_weight,
gemm1_weights_scale=layer.w13_weight_scale_inv,
gemm2_weights=layer.w2_weight,
gemm2_weights_scale=layer.w2_weight_scale_inv,
num_experts=layer.num_experts,
top_k=topk_config.top_k,
n_group=topk_config.num_expert_group,
topk_group=topk_config.topk_group,
intermediate_size=layer.w2_weight.shape[2],
local_expert_offset=layer.moe_ep_rank * layer.num_local_experts,
local_num_experts=layer.num_local_experts,
routed_scaling_factor=(
routed_scaling_factor
if routed_scaling_factor is not None
else 1.0
),
tile_tokens_dim=None,
routing_method_type=routing_method_type,
use_shuffled_weight=False,
)
else:
routing_bias_cast = (
None
if correction_bias is None
else correction_bias.to(torch.bfloat16)
)

return trtllm_fp8_per_tensor_scale_moe(
routing_logits=router_logits.to(torch.bfloat16),
routing_bias=routing_bias_cast,
hidden_states=a_q,
gemm1_weights=layer.w13_weight,
output1_scales_scalar=layer.output1_scales_scalar,
output1_scales_gate_scalar=layer.output1_scales_gate_scalar,
gemm2_weights=layer.w2_weight,
output2_scales_scalar=layer.output2_scales_scalar,
num_experts=layer.num_experts,
top_k=topk_config.top_k,
n_group=topk_config.num_expert_group,
topk_group=topk_config.topk_group,
intermediate_size=layer.w2_weight.shape[2],
local_expert_offset=layer.moe_ep_rank * layer.num_local_experts,
local_num_experts=layer.num_local_experts,
routed_scaling_factor=(
routed_scaling_factor
if routed_scaling_factor is not None
else 1.0
),
use_routing_scales_on_input=False,
routing_method_type=routing_method_type,
)
# FIXME: there is a bug in the trtllm_fp8_block_scale_moe.
# It ignored the `output` argument. https://github.com/flashinfer-ai/flashinfer/blob/da01b1bd8f9f22aec8c0eea189ad54860b034947/flashinfer/fused_moe/core.py#L1323-L1325
# so we put the whole function under the ``use_symmetric_memory`` context manager.
# If the bug is fixed, we can only put the output tensor allocation under the context manager.
return trtllm_fp8_block_scale_moe(
routing_logits=(
router_logits.to(torch.float32)
if routing_method_type == RoutingMethodType.DeepSeekV3
else router_logits
),
routing_bias=correction_bias,
hidden_states=a_q,
hidden_states_scale=a_sf_t,
gemm1_weights=layer.w13_weight,
gemm1_weights_scale=layer.w13_weight_scale_inv,
gemm2_weights=layer.w2_weight,
gemm2_weights_scale=layer.w2_weight_scale_inv,
num_experts=layer.num_experts,
top_k=topk_config.top_k,
n_group=topk_config.num_expert_group,
topk_group=topk_config.topk_group,
intermediate_size=layer.w2_weight.shape[2],
local_expert_offset=layer.moe_ep_rank * layer.num_local_experts,
local_num_experts=layer.num_local_experts,
routed_scaling_factor=(
routed_scaling_factor if routed_scaling_factor is not None else 1.0
),
tile_tokens_dim=None,
routing_method_type=routing_method_type,
use_shuffled_weight=False,
)

def maybe_apply_hip_fused_experts(
self,
Expand Down
6 changes: 0 additions & 6 deletions python/sglang/srt/layers/quantization/fp8_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1014,9 +1014,3 @@ def validate_fp8_block_shape(
f"{output_partition_size} is not divisible by "
f"weight quantization block_n = {block_n}."
)


def expert_weight_is_col_major(x: torch.Tensor) -> bool:
    """Return True if every per-expert matrix in *x* is stored column-major.

    *x* is a 3-D tensor of shape ``(num_experts, rows, cols)``. Each expert's
    ``rows x cols`` matrix is column-major when elements of a column are
    contiguous, i.e. the strides are ``(rows * cols, 1, rows)``.
    """
    assert x.dim() == 3
    _, rows, cols = x.shape
    # Compare all three strides at once against the column-major layout.
    return tuple(x.stride()) == (rows * cols, 1, rows)
30 changes: 11 additions & 19 deletions python/sglang/srt/layers/quantization/modelopt_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -1565,27 +1565,10 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
w13_weight_scale_2 = layer.w13_weight_scale_2[:]
layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2, requires_grad=False)

def _slice_scale(w):
assert w.shape == (layer.num_experts,)
assert layer.moe_ep_size * layer.num_local_experts == layer.num_experts
return w[
layer.moe_ep_rank
* layer.num_local_experts : (layer.moe_ep_rank + 1)
* layer.num_local_experts
]

# Calculate input scales based on strategy
if self.enable_flashinfer_cutlass_moe or self.enable_flashinfer_trtllm_moe:
w13_input_scale = (
layer.w13_input_scale.max().to(torch.float32).expand(layer.num_experts)
)
w2_input_scale = (
layer.w2_input_scale.max().to(torch.float32).expand(layer.num_experts)
)

if layer.moe_ep_size > 1:
w13_input_scale = _slice_scale(w13_input_scale)
w2_input_scale = _slice_scale(w2_input_scale)
w13_input_scale = layer.w13_input_scale.max().to(torch.float32)
w2_input_scale = layer.w2_input_scale.max().to(torch.float32)
elif self.enable_flashinfer_cutedsl_moe:
# All-expert-one-input-scale is mathematically different from default per-expert-input-scale
# Thus we allow users to switch the flag to do thorough testing
Expand All @@ -1602,6 +1585,15 @@ def _slice_scale(w):

w2_input_scale = layer.w2_input_scale

def _slice_scale(w):
assert w.shape == (layer.num_experts,)
assert layer.moe_ep_size * layer.num_local_experts == layer.num_experts
return w[
layer.moe_ep_rank
* layer.num_local_experts : (layer.moe_ep_rank + 1)
* layer.num_local_experts
]

w13_input_scale = _slice_scale(w13_input_scale)
w2_input_scale = _slice_scale(w2_input_scale)

Expand Down
7 changes: 2 additions & 5 deletions python/sglang/srt/models/deepseek_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -649,9 +649,7 @@ def __init__(
layer_id=self.layer_id,
quant_config=quant_config,
routed_scaling_factor=self.routed_scaling_factor,
routing_method_type=getattr(
config, "routing_method_type", RoutingMethodType.DeepSeekV3
),
routing_method_type=RoutingMethodType.DeepSeekV3,
prefix=add_prefix("experts", prefix),
)

Expand Down Expand Up @@ -3349,7 +3347,6 @@ def forward(
class DeepseekV2ForCausalLM(nn.Module):
# for quark model load
packed_modules_mapping = {}
model_cls = DeepseekV2Model

def __init__(
self,
Expand All @@ -3376,7 +3373,7 @@ def __init__(
self.quant_config = quant_config
self.determine_num_fused_shared_experts()
self.use_nsa = is_deepseek_nsa(config)
self.model = self.model_cls(
self.model = DeepseekV2Model(
config, quant_config, prefix=add_prefix("model", prefix)
)
if self.pp_group.is_last_rank:
Expand Down
Loading