Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,21 @@

from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
from sglang.srt.layers.moe.moe_runner.flashinfer_trtllm import (
FlashInferTrtllmFp8MoeQuantInfo,
)
from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
from sglang.srt.layers.moe.utils import get_moe_runner_backend
from sglang.srt.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsMoEScheme,
)
from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz, scaled_fp8_quant
from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
from sglang.srt.layers.quantization.utils import all_close_1d, per_tensor_dequantize
from sglang.srt.layers.quantization.utils import (
all_close_1d,
per_tensor_dequantize,
swap_w13_to_w31,
)
from sglang.srt.utils import get_bool_env_var, is_hip, set_weight_attrs

if TYPE_CHECKING:
Expand Down Expand Up @@ -43,6 +51,7 @@ class CompressedTensorsW8A8Fp8MoE(CompressedTensorsMoEScheme):
def __init__(self, weight_quant, input_quant):
self.weight_quant = weight_quant
self.input_quant = input_quant
self.use_flashinfer_trtllm = get_moe_runner_backend().is_flashinfer_trtllm()

per_tensor = (
self.weight_quant.strategy == QuantizationStrategy.TENSOR
Expand Down Expand Up @@ -305,11 +314,27 @@ def process_weights_after_loading(self, layer: torch.nn.Module | FusedMoE) -> No
)
torch.cuda.empty_cache()

if (
self.weight_quant.strategy == QuantizationStrategy.BLOCK
and self.use_flashinfer_trtllm
):
layer.w13_weight = torch.nn.Parameter(
swap_w13_to_w31(layer.w13_weight.data),
requires_grad=False,
)
layer.w13_weight_scale = torch.nn.Parameter(
swap_w13_to_w31(layer.w13_weight_scale.data),
requires_grad=False,
)

def create_moe_runner(
    self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
):
    """Create and attach the MoE runner for this quantization scheme.

    Args:
        layer: The MoE layer the runner will serve (unused here; kept for
            interface parity with other schemes).
        moe_runner_config: Runner configuration, stored on the scheme and
            forwarded to the runner.

    Side effects:
        Sets ``self.moe_runner_config`` and ``self.runner``.
    """
    # NOTE(review): the pasted diff contained both the old hard-coded
    # TRITON construction and the new backend-resolving code; only the
    # merged (new) logic is kept here.
    self.moe_runner_config = moe_runner_config
    moe_runner_backend = get_moe_runner_backend()
    # "auto" falls back to the Triton backend for this scheme.
    if moe_runner_backend.is_auto():
        moe_runner_backend = MoeRunnerBackend.TRITON
    self.runner = MoeRunner(moe_runner_backend, moe_runner_config)

def apply_weights(
self,
Expand Down Expand Up @@ -358,16 +383,31 @@ def apply_weights(
)
return StandardCombineInput(hidden_states=output)
elif self.weight_quant.strategy == QuantizationStrategy.BLOCK:
quant_info = TritonMoeQuantInfo(
w13_weight=layer.w13_weight,
w2_weight=layer.w2_weight,
use_fp8_w8a8=True,
w13_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a13_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
block_shape=self.weight_block_size,
)
if self.use_flashinfer_trtllm:
quant_info = FlashInferTrtllmFp8MoeQuantInfo(
w13_weight=layer.w13_weight,
w2_weight=layer.w2_weight,
global_num_experts=layer.num_experts,
local_expert_offset=layer.moe_ep_rank * layer.num_local_experts,
local_num_experts=layer.num_local_experts,
intermediate_size=layer.w2_weight.shape[2],
routing_method_type=layer.routing_method_type,
block_quant=self.block_quant,
weight_block_k=self.weight_block_size[1],
w13_weight_scale_inv=layer.w13_weight_scale,
w2_weight_scale_inv=layer.w2_weight_scale,
)
else:
quant_info = TritonMoeQuantInfo(
w13_weight=layer.w13_weight,
w2_weight=layer.w2_weight,
use_fp8_w8a8=True,
w13_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a13_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
block_shape=self.weight_block_size,
)
return self.runner.run(dispatch_output, quant_info)
else:
quant_info = TritonMoeQuantInfo(
Expand Down
6 changes: 6 additions & 0 deletions python/sglang/srt/layers/quantization/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -594,6 +594,12 @@ def swizzle_blockscale(scale: torch.Tensor):
)


def swap_w13_to_w31(x: torch.Tensor) -> torch.Tensor:
    """Swap the two halves of ``x`` along its second-to-last dimension.

    Used to reorder fused w1/w3 MoE weights (and their scales) into w3/w1
    order. Assumes ``x.shape[-2]`` is even. The operation is an involution:
    applying it twice returns the original tensor.
    """
    half = x.shape[-2] // 2
    upper = x[..., :half, :]
    lower = x[..., half:, :]
    return torch.cat((lower, upper), dim=-2)


def reorder_w1w3_to_w3w1(
weight: torch.Tensor, scale: torch.Tensor, dim: int = -2
) -> tuple[torch.Tensor, torch.Tensor]:
Expand Down
7 changes: 2 additions & 5 deletions test/registered/8-gpu-models/test_mistral_large3.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def test_mistral_large3_all_variants(self):
base_args = [
"--tp=8",
"--attention-backend=trtllm_mla",
"--moe-runner-backend=flashinfer_trtllm",
"--model-loader-extra-config",
'{"enable_multithread_load": true}',
"--chat-template=mistral",
Expand All @@ -58,10 +59,6 @@ def test_mistral_large3_all_variants(self):
"--speculative-num-draft-tokens=4",
"--kv-cache-dtype=auto",
]
# TODO: add this to base args when FP8 TRTLLM moe is supported
nvfp4_args = [
"--moe-runner-backend=flashinfer_trtllm",
]

variants = [
# Variant: "basic" - FP8 model + TP=8 + trtllm_mla backend
Expand All @@ -83,7 +80,7 @@ def test_mistral_large3_all_variants(self):
ModelLaunchSettings(
MISTRAL_LARGE3_NVFP4_MODEL_PATH,
tp_size=8,
extra_args=base_args + nvfp4_args,
extra_args=base_args,
variant="NVFP4",
),
]
Expand Down
Loading