3 changes: 2 additions & 1 deletion python/sglang/srt/layers/moe/ep_moe/layer.py
@@ -548,8 +548,9 @@ def get_moe_impl_class(quant_config: Optional[QuantizationConfig]):
quant_config is None
or quant_config.get_name() == "fp8"
or quant_config.get_name() == "modelopt_fp8"
or quant_config.get_name() == "compressed_tensors"
):
# FlashInferFusedMoE support bf16 and fp8
# FlashInferFusedMoE supports bf16, fp8 and compressed_tensors
return FlashInferFusedMoE

if get_moe_runner_backend().is_flashinfer_cutlass():
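Note: a minimal sketch of the gate this hunk extends, assuming only the names visible in the diff; the helper name below is illustrative, not the real function.

    # Illustrative: compressed_tensors now joins the quant configs that may take
    # the FlashInferFusedMoE path.
    def _may_use_flashinfer_fused_moe(quant_config) -> bool:
        name = None if quant_config is None else quant_config.get_name()
        return name in (None, "fp8", "modelopt_fp8", "compressed_tensors")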
1 change: 0 additions & 1 deletion python/sglang/srt/layers/moe/fused_moe_triton/layer.py
@@ -1093,7 +1093,6 @@ def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):

else:

# FP8 Matrix multiply.
final_hidden_states = self.quant_method.apply_with_router_logits(
layer=self,
dispatch_output=StandardDispatchOutput(
python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -11,10 +11,15 @@
from compressed_tensors import CompressionFormat
from compressed_tensors.quantization import QuantizationStrategy

from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.distributed import get_tensor_model_parallel_world_size, get_tp_group
from sglang.srt.distributed.device_communicators.pynccl_allocator import (
use_symmetric_memory,
)
from sglang.srt.layers.dp_attention import is_allocation_symmetric
from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams, CutlassMoEType
from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
from sglang.srt.layers.moe.utils import get_moe_runner_backend
from sglang.srt.layers.quantization.base_config import FusedMoEMethodBase
from sglang.srt.layers.quantization.compressed_tensors.schemes import (
WNA16_SUPPORTED_BITS,
@@ -29,10 +34,18 @@
from sglang.srt.layers.quantization.utils import (
all_close_1d,
per_tensor_dequantize,
prepare_static_weights_for_trtllm_fp4_moe,
reorder_w1w3_to_w3w1,
replace_parameter,
swizzle_blockscale,
)
from sglang.srt.utils import get_bool_env_var, is_cuda, is_hip, set_weight_attrs
from sglang.srt.utils import (
get_bool_env_var,
is_cuda,
is_hip,
next_power_of_2,
set_weight_attrs,
)

if TYPE_CHECKING:
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
@@ -115,6 +128,7 @@ def __init__(self, quant_config: CompressedTensorsConfig):
)
self.quant_config = quant_config
self.group_size = 16
self.use_flashinfer_trtllm = get_moe_runner_backend().is_flashinfer_trtllm()

def create_weights(
self,
@@ -127,7 +141,6 @@ def create_weights(
):
from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported

layer.num_experts = num_experts
layer.params_dtype = params_dtype

w13_weight = torch.nn.Parameter(
@@ -240,6 +253,13 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
)
delattr(layer, "w2_weight_packed")

if self.use_flashinfer_trtllm:
w, s = reorder_w1w3_to_w3w1(
layer.w13_weight.data, layer.w13_weight_scale.data, dim=-2
)
layer.w13_weight = torch.nn.Parameter(w, requires_grad=False)
layer.w13_weight_scale = torch.nn.Parameter(s, requires_grad=False)
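# Illustrative note (assumption, not part of this diff): reordering w1/w3 to
# w3/w1 along dim=-2 amounts to swapping the two fused halves of the gate/up
# projection, e.g. a, b = torch.chunk(w13, 2, dim=-2); torch.cat([b, a], dim=-2),
# presumably because the TRT-LLM kernel expects the halves in the opposite order.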

if not torch.allclose(
layer.w13_weight_global_scale[:, 0], layer.w13_weight_global_scale[:, 1]
):
@@ -258,9 +278,16 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
)

# w13
w13_input_global_scale = layer.w13_input_global_scale.min(dim=1).values.to(
torch.float32
)
if self.use_flashinfer_trtllm:
w13_input_global_scale = (
layer.w13_input_global_scale.min()
.to(torch.float32)
.expand(layer.num_local_experts)
)
else:
w13_input_global_scale = layer.w13_input_global_scale.min(dim=1).values.to(
torch.float32
)
layer.g1_alphas = torch.nn.Parameter(
((1 / w13_input_global_scale) * layer.w13_weight_scale_2),
requires_grad=False,
@@ -271,7 +298,14 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
)

# w2
w2_input_global_scale = layer.w2_input_global_scale
if self.use_flashinfer_trtllm:
w2_input_global_scale = (
layer.w2_input_global_scale.min()
.to(torch.float32)
.expand(layer.num_local_experts)
)
else:
w2_input_global_scale = layer.w2_input_global_scale
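# Illustrative shape note (assumption): w13_input_global_scale is [E, 2], one value
# per fused gate/up half per local expert. The Cutlass path keeps a per-expert
# minimum (min(dim=1).values -> [E]); the TRT-LLM path collapses the w13 and w2
# input scales to a single global minimum replicated to [E] (min().expand(E)),
# apparently because the kernel takes one shared activation scale per expert slot.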

layer.g2_alphas = torch.nn.Parameter(
((1 / w2_input_global_scale) * layer.w2_weight_scale_2).to(torch.float32),
@@ -282,22 +316,66 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
(w2_input_global_scale), requires_grad=False
)

# swizzle weight scales
layer.w13_weight_scale = torch.nn.Parameter(
swizzle_blockscale(layer.w13_weight_scale), requires_grad=False
)
# TensorRT-LLM-specific processing
if self.use_flashinfer_trtllm:
# Prepare static weights for TRT-LLM kernel
(
gemm1_weights_fp4_shuffled,
gemm1_scales_fp4_shuffled,
gemm2_weights_fp4_shuffled,
gemm2_scales_fp4_shuffled,
) = prepare_static_weights_for_trtllm_fp4_moe(
layer.w13_weight,
layer.w2_weight,
layer.w13_weight_scale,
layer.w2_weight_scale,
layer.w2_weight.size(-2), # hidden_size
layer.w13_weight.size(-2) // 2, # intermediate_size
layer.w13_weight.size(0), # num_experts
)
logger.debug("Finished shuffling weights for TRT-LLM MOE")

layer.w2_weight_scale = torch.nn.Parameter(
swizzle_blockscale(layer.w2_weight_scale), requires_grad=False
)
layer.gemm1_weights_fp4_shuffled = torch.nn.Parameter(
gemm1_weights_fp4_shuffled, requires_grad=False
)
layer.gemm2_weights_fp4_shuffled = torch.nn.Parameter(
gemm2_weights_fp4_shuffled, requires_grad=False
)
layer.gemm1_scales_fp4_shuffled = torch.nn.Parameter(
gemm1_scales_fp4_shuffled, requires_grad=False
)
layer.gemm2_scales_fp4_shuffled = torch.nn.Parameter(
gemm2_scales_fp4_shuffled, requires_grad=False
)

layer.cutlass_moe_params = CutlassMoEParams(
CutlassMoEType.BlockscaledFP4,
layer.w13_weight.device,
num_experts=layer.num_experts,
intermediate_size_per_partition=layer.w2_weight.shape[2] * 2,
hidden_size=layer.w13_weight.shape[2] * 2,
)
# Additional parameter needed for TRT-LLM
layer.g1_scale_c = torch.nn.Parameter(
(layer.w2_input_scale_quant * layer.g1_alphas).to(torch.float32),
requires_grad=False,
)
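# Illustrative reading (assumption, not stated in the diff): g1_scale_c folds the
# GEMM1 dequant scale (g1_alphas) with the GEMM2 input quant scale
# (w2_input_scale_quant), letting the kernel rescale GEMM1 output straight into
# GEMM2's quantized input domain.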

# Clean up weights that won't be used by TRT-LLM
del layer.w2_weight
del layer.w2_weight_scale
del layer.w13_weight
del layer.w13_weight_scale
else:
# swizzle weight scales
layer.w13_weight_scale = torch.nn.Parameter(
swizzle_blockscale(layer.w13_weight_scale), requires_grad=False
)

layer.w2_weight_scale = torch.nn.Parameter(
swizzle_blockscale(layer.w2_weight_scale), requires_grad=False
)

layer.cutlass_moe_params = CutlassMoEParams(
CutlassMoEType.BlockscaledFP4,
layer.w13_weight.device,
num_experts=layer.num_experts,
intermediate_size_per_partition=layer.w2_weight.shape[2] * 2,
hidden_size=layer.w13_weight.shape[2] * 2,
)
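A rough check of the shape bookkeeping used by both branches, assuming NVFP4 weights pack two 4-bit values per byte (so the stored last dimension is half the logical width) and that w13 stacks the gate and up projections along dim -2, as the inline comments in the TRT-LLM branch suggest; the numbers below are arbitrary.

    # Illustrative only: packed-shape arithmetic behind the size arguments above.
    num_experts, hidden_size, intermediate_size = 8, 4096, 14336
    w13_shape = (num_experts, 2 * intermediate_size, hidden_size // 2)  # packed FP4
    w2_shape = (num_experts, hidden_size, intermediate_size // 2)       # packed FP4
    assert w13_shape[2] * 2 == hidden_size         # Cutlass: hidden_size
    assert w2_shape[2] * 2 == intermediate_size    # Cutlass: intermediate_size_per_partition
    assert w2_shape[1] == hidden_size              # TRT-LLM: w2_weight.size(-2)
    assert w13_shape[1] // 2 == intermediate_size  # TRT-LLM: w13_weight.size(-2) // 2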

def create_moe_runner(
self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
@@ -336,6 +414,100 @@ def apply(

return StandardCombineInput(hidden_states=output)

def apply_with_router_logits(
self,
layer: torch.nn.Module,
dispatch_output: StandardDispatchOutput,
) -> torch.Tensor:
assert self.use_flashinfer_trtllm

x = dispatch_output.hidden_states
topk_output = dispatch_output.topk_output

from flashinfer import fp4_quantize, trtllm_fp4_block_scale_moe

from sglang.srt.layers.moe.utils import RoutingMethodType

router_logits = topk_output.router_logits
topk_config = topk_output.topk_config

# Quantize input hidden states using fp4_quantize
hs_fp4_bytes, hs_sf_bytes = fp4_quantize(
x,
layer.w13_input_scale_quant,
self.group_size, # sf_vec_size
False, # use_ue8m0
False, # is_sf_swizzled_layout
)
hs_fp4 = hs_fp4_bytes.reshape(x.shape[0], x.shape[1] // 2)
hs_scale = hs_sf_bytes.view(torch.float8_e4m3fn).reshape(-1)

correction_bias = (
None
if topk_config.correction_bias is None
else topk_config.correction_bias.to(x.dtype)
)

assert layer.routing_method_type is not None

# DeepSeekV3 style routing requires float32 router logits
if layer.routing_method_type == RoutingMethodType.DeepSeekV3:
router_logits = router_logits.to(torch.float32)

routed_scaling_factor = self.moe_runner_config.routed_scaling_factor
routed_scaling_factor = (
routed_scaling_factor if routed_scaling_factor is not None else 1.0
)

with use_symmetric_memory(
get_tp_group(), disabled=not is_allocation_symmetric()
):
num_tokens = hs_fp4.shape[0]
hidden_size = (
hs_fp4.shape[-1] * 2
if hs_fp4.dtype == torch.uint8
else hs_fp4.shape[-1]
)
symm_output = torch.empty(
num_tokens, hidden_size, dtype=torch.bfloat16, device=hs_fp4.device
)

return trtllm_fp4_block_scale_moe(
routing_logits=router_logits,
routing_bias=correction_bias,
hidden_states=hs_fp4,
hidden_states_scale=hs_scale,
gemm1_weights=layer.gemm1_weights_fp4_shuffled,
gemm1_weights_scale=layer.gemm1_scales_fp4_shuffled.view(
torch.float8_e4m3fn
),
gemm1_bias=None,
gemm1_alpha=None,
gemm1_beta=None,
gemm1_clamp_limit=None,
gemm2_weights=layer.gemm2_weights_fp4_shuffled,
gemm2_weights_scale=layer.gemm2_scales_fp4_shuffled.view(
torch.float8_e4m3fn
),
gemm2_bias=None,
output1_scale_scalar=layer.g1_scale_c,
output1_scale_gate_scalar=layer.g1_alphas,
output2_scale_scalar=layer.g2_alphas,
num_experts=layer.num_experts,
top_k=topk_config.top_k,
n_group=topk_config.num_expert_group,
topk_group=topk_config.topk_group,
intermediate_size=layer.intermediate_size_per_partition,
local_expert_offset=layer.moe_ep_rank * layer.num_local_experts,
local_num_experts=layer.num_local_experts,
routed_scaling_factor=routed_scaling_factor,
tile_tokens_dim=None,
routing_method_type=layer.routing_method_type,
do_finalize=True,
tune_max_num_tokens=next_power_of_2(hs_fp4.shape[0]),
output=symm_output,
)[0]
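A small sketch of the activation-quantization bookkeeping in apply_with_router_logits, assuming fp4_quantize packs two 4-bit values per uint8 and emits one float8_e4m3 block scale per group_size (16) elements; the numbers are arbitrary.

    # Illustrative only: element counts for the packed activations and block scales.
    num_tokens, hidden_size, group_size = 32, 4096, 16
    hs_fp4_shape = (num_tokens, hidden_size // 2)              # two FP4 values per uint8
    hs_scale_numel = num_tokens * hidden_size // group_size    # one e4m3 scale per 16 values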


class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
