Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
90098a5
Add Mistral Large 3 support.
dcampora Nov 17, 2025
0cef9be
Merge branch 'main' into ml3_support
zhyncs Dec 1, 2025
b92e641
Apply suggestions from code review
dcampora Dec 1, 2025
4bfe175
Commit suggestions from gemini.
dcampora Dec 1, 2025
94f923f
Formatting.
dcampora Dec 1, 2025
4fe81ec
Add back routed scaling factor for nvfp4.
dcampora Dec 1, 2025
6942b91
Merge branch 'main' into ml3_support
JustinTong0323 Dec 1, 2025
efc8ffa
Add back codepath for deepseek_v32 as it was.
dcampora Dec 2, 2025
896aa66
Use str(model).lower().
dcampora Dec 2, 2025
f1fa66f
Adapt router manager formatting.
dcampora Dec 2, 2025
1a3b88d
Adding fp8 loader for attn.
dcampora Dec 2, 2025
0353f52
Use preexisting dispatch fn.
dcampora Dec 2, 2025
69186b3
fp8 block moe
elvischenv Dec 2, 2025
7aac216
Fixing issues.
dcampora Dec 2, 2025
ee68d59
Fix.
dcampora Dec 2, 2025
bc4ed51
Possible fix.
dcampora Dec 2, 2025
74cda59
Test.
dcampora Dec 2, 2025
7f0c021
Revert "Test."
dcampora Dec 2, 2025
15d6201
fix quant config
elvischenv Dec 2, 2025
f4ad69c
Add fix.
dcampora Dec 2, 2025
0a48bd0
Fix.
dcampora Dec 2, 2025
90f589a
fix to run
elvischenv Dec 2, 2025
c7be4b3
tuple simplification
Linda-Stadter Dec 2, 2025
5100fa9
Merge pull request #1 from dcampora/dcampora/block_fp8_loader
dcampora Dec 2, 2025
f0f5064
Merge branch 'main' into ml3_support
ishandhanani Dec 2, 2025
50fdab2
Merge branch 'main' into ml3_support
JustinTong0323 Dec 2, 2025
1ef9948
fix lint
elvischenv Dec 3, 2025
b9d5f64
fix dsv3 accuracy
elvischenv Dec 3, 2025
dc500a0
fix ckpt loading
elvischenv Dec 3, 2025
77319c4
fix no module named deep_gemm error
Linda-Stadter Dec 3, 2025
623b33a
fix triton moe
elvischenv Dec 3, 2025
73abf6b
clean up (#2)
elvischenv Dec 3, 2025
d7a7b6c
add init llama4 scale
elvischenv Dec 1, 2025
1b1c84b
clean up eagle
elvischenv Dec 3, 2025
fcc2c6d
Merge pull request #3 from dcampora/elvis/clean-up-eagle
dcampora Dec 3, 2025
5d077e1
Merge pull request #4 from dcampora/elvis/fix-llama4-scale
dcampora Dec 3, 2025
5d52732
fix Llama4 scaling None
elvischenv Dec 4, 2025
13ccef9
Merge remote-tracking branch 'upstream/main' into ml3_support
elvischenv Dec 4, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion benchmark/kernels/fused_moe_triton/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,17 @@ def get_model_config(
"DeepseekV2ForCausalLM",
"DeepseekV3ForCausalLM",
"Glm4MoeForCausalLM",
"MistralLarge3ForCausalLM",
]:
E = (config.n_routed_experts // ep_size) + (
0
if disable_shared_experts_fusion
or architecture not in ["DeepseekV3ForCausalLM", "Glm4MoeForCausalLM"]
or architecture
not in [
"DeepseekV3ForCausalLM",
"Glm4MoeForCausalLM",
"MistralLarge3ForCausalLM",
]
else 1
)
topk = config.num_experts_per_tok + (
Expand Down
5 changes: 5 additions & 0 deletions python/sglang/srt/configs/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ def is_deepseek_nsa(config: PretrainedConfig) -> bool:
"DeepseekV3ForCausalLM",
"DeepseekV32ForCausalLM",
"DeepseekV3ForCausalLMNextN",
"MistralLarge3ForCausalLM",
"PixtralForConditionalGeneration",
]
and getattr(config, "index_topk", None) is not None
)
Expand Down Expand Up @@ -334,6 +336,8 @@ def _derive_model_shapes(self):
or "LongcatFlashForCausalLM" in self.hf_config.architectures
or "LongcatFlashForCausalLMNextN" in self.hf_config.architectures
or "DotsVLMForCausalLM" in self.hf_config.architectures
or "MistralLarge3ForCausalLM" in self.hf_config.architectures
or "PixtralForConditionalGeneration" in self.hf_config.architectures
):
self.head_dim = 256
self.attention_arch = AttentionArch.MLA
Expand Down Expand Up @@ -939,6 +943,7 @@ def is_generation_model(model_architectures: List[str], is_embedding: bool = Fal
"MultiModalityCausalLM",
"MllamaForConditionalGeneration",
"NemotronH_Nano_VL_V2",
"PixtralForConditionalGeneration",
"Qwen2AudioForConditionalGeneration",
"Qwen2VLForConditionalGeneration",
"Qwen2_5_VLForConditionalGeneration",
Expand Down
11 changes: 11 additions & 0 deletions python/sglang/srt/layers/attention/trtllm_mla_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -802,6 +802,7 @@ def forward_decode(
k_rope: Optional[torch.Tensor] = None,
cos_sin_cache: Optional[torch.Tensor] = None,
is_neox: Optional[bool] = False,
llama_4_scaling: Optional[torch.Tensor] = None,
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we add this logic here, we will have to do the same thing for a ton of attention backends and keep them all in sync. One alternative may be to put the logic in the models folder (deepseek_v2.py, mistral.py, etc.) instead, since it looks like it is just scaling the q tensor before entering the core attention logic.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From Mistral's implementation in vLLM, the scaling is applied between RoPE and attention, so it implements inside the mla layer: https://github.com/vllm-project/vllm/pull/29757/files#diff-6ffcb4f51daf85df32c7d35433c3393f1602663960f677ae61f55af1ed3ab524

In SGLang, for the trtllm MLA backend, RoPE is fused into the attention backend. This is quite different from vLLM, so the scaling factor needs to be passed into the backend in order to be applied correctly.

) -> torch.Tensor:
"""Run forward for decode using TRTLLM MLA kernel."""
merge_query = q_rope is not None
Expand Down Expand Up @@ -843,6 +844,11 @@ def forward_decode(
# For FP8 path, we already have the query and rope parts merged because of the quantize_and_rope_for_fp8 function
query = q.view(-1, layer.tp_q_head_num, layer.head_dim)

# Apply llama 4 scaling if provided
if llama_4_scaling is not None:
query = query.to(self.q_data_type) * llama_4_scaling
query = query.to(self.data_type)

# Ensure query has shape [bs, acc_q_len, num_q_heads, head_dim] when seq_len 1
if query.dim() == 3:
query = query.unsqueeze(1)
Expand Down Expand Up @@ -903,6 +909,7 @@ def forward_extend(
k_rope: Optional[torch.Tensor] = None,
cos_sin_cache: Optional[torch.Tensor] = None,
is_neox: Optional[bool] = False,
llama_4_scaling: Optional[torch.Tensor] = None,
) -> torch.Tensor:

if (
Expand Down Expand Up @@ -955,6 +962,10 @@ def forward_extend(

q = q.view(-1, layer.tp_q_head_num, layer.head_dim)

# Apply llama 4 scaling if provided
if llama_4_scaling is not None:
q *= llama_4_scaling

if (
forward_batch.forward_mode.is_target_verify()
or forward_batch.forward_mode.is_draft_extend(include_v2=True)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,6 @@ def __init__(
self.sparsity_ignore_list = sparsity_ignore_list
self.config = config
self.packed_modules_mapping = packed_modules_mapping or {}
# FP8 config for linear layers, compressed tensor currently does not support block fp8, this is used for ktransformers
self.linear_fp8_config = linear_fp8_config

def get_linear_method(self) -> CompressedTensorsLinearMethod:
Expand Down Expand Up @@ -142,6 +141,15 @@ def get_quant_method(
return CompressedTensorsMoEMethod.get_moe_method(self, layer, prefix)
return None

@property
def weight_block_size(self) -> Optional[List[int]]:
"""Get the weight block size from the quantization config."""
if "Linear" in self.target_scheme_map:
weights_config = self.target_scheme_map["Linear"].get("weights")
if weights_config and hasattr(weights_config, "block_structure"):
return weights_config.block_structure
return None

@classmethod
def from_config(cls, config: Dict[str, Any]) -> CompressedTensorsConfig:
ignore: List[str] = cast(List[str], config.get("ignore", []))
Expand Down Expand Up @@ -306,7 +314,9 @@ def _is_dynamic_token_w8a8(
# Only symmetric weight quantization supported.
return is_8_bits and is_token and weight_quant.symmetric and is_dynamic

def _is_fp8_w8a8(self, weight_quant: BaseModel, input_quant: BaseModel) -> bool:
def _is_fp8_w8a8(
self, weight_quant: QuantizationArgs, input_quant: QuantizationArgs
) -> bool:
# Confirm weights and activations quantized.
if weight_quant is None or input_quant is None:
return False
Expand All @@ -318,15 +328,16 @@ def _is_fp8_w8a8(self, weight_quant: BaseModel, input_quant: BaseModel) -> bool:
)
is_symmetric_weight = weight_quant.symmetric
is_static_weight = not weight_quant.dynamic
is_per_tensor_or_channel_weight = weight_quant.strategy in [
is_tensor_or_channel_or_block_weight = weight_quant.strategy in [
QuantizationStrategy.TENSOR,
QuantizationStrategy.CHANNEL,
QuantizationStrategy.BLOCK,
]
if not (
is_floating_point
and is_symmetric_weight
and is_static_weight
and is_per_tensor_or_channel_weight
and is_tensor_or_channel_or_block_weight
):
return False

Expand Down Expand Up @@ -406,7 +417,7 @@ def _get_scheme_from_parts(
)
if is_fp8_w8a8_supported:
return CompressedTensorsW8A8Fp8(
strategy=weight_quant.strategy,
weight_quant=weight_quant,
is_static_input_scheme=(
input_quant and not input_quant.dynamic
),
Expand Down Expand Up @@ -608,6 +619,7 @@ class CompressedTensorsLinearMethod(LinearMethodBase):

def __init__(self, quantization_config: CompressedTensorsConfig):
self.quantization_config = quantization_config
self.quant_config = quantization_config

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
layer.scheme.process_weights_after_loading(layer)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from compressed_tensors import CompressionFormat
from compressed_tensors.quantization import QuantizationStrategy

from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
from sglang.srt.layers.quantization.base_config import FusedMoEMethodBase
Expand Down Expand Up @@ -81,8 +82,8 @@ def get_moe_method(

weight_quant = quant_config.target_scheme_map["Linear"].get("weights")
input_quant = quant_config.target_scheme_map["Linear"].get("input_activations")
if quant_config._is_wNa16_group_channel(weight_quant, input_quant):

if quant_config._is_wNa16_group_channel(weight_quant, input_quant):
logger.info_once("Using CompressedTensorsWNA16MarlinMoEMethod")
return CompressedTensorsWNA16MoEMethod(quant_config)
elif quant_config._is_fp8_w8a8(weight_quant, input_quant):
Expand All @@ -102,7 +103,28 @@ def __init__(self, quant_config: CompressedTensorsConfig):
"input_activations"
)

per_tensor = (
self.weight_quant.strategy == QuantizationStrategy.TENSOR
and self.input_quant.strategy == QuantizationStrategy.TENSOR
)
per_channel = (
self.weight_quant.strategy == QuantizationStrategy.CHANNEL
and self.input_quant.strategy == QuantizationStrategy.TOKEN
)
if not (per_tensor or per_channel):
assert self.weight_quant.strategy == QuantizationStrategy.BLOCK
self.weight_block_size = self.weight_quant.block_structure
assert self.weight_quant.dynamic is not None
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

self.weight_quant.dynamic is false, so it can be None?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This part is mostly from vLLM: https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py#L667-L681
Given that compressed-tensors is also published by them, could you confirm this with them?

Copy link
Copy Markdown

@Wangzheee Wangzheee Dec 18, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

else:
self.weight_block_size = None
self.block_quant = self.weight_block_size is not None

self.static_input_scales = not self.input_quant.dynamic
if self.static_input_scales and per_channel:
raise ValueError(
"For FP8 Fused MoE layer, we require either per tensor or "
"channelwise, dynamic per token quantization."
)

def create_weights(
self,
Expand All @@ -117,6 +139,32 @@ def create_weights(

params_dtype = torch.float8_e4m3fn

if self.block_quant:
assert self.weight_block_size is not None
layer.weight_block_size = self.weight_block_size
tp_size = get_tensor_model_parallel_world_size()
block_n, block_k = (
self.weight_block_size[0],
self.weight_block_size[1],
)
# NOTE: To ensure proper alignment of the block-wise quantization
# scales, the output_size of the weights for both the gate and up
# layers must be divisible by block_n.
# Required by column parallel or enabling merged weights
if intermediate_size_per_partition % block_n != 0:
raise ValueError(
f"The output_size of gate's and up's weight = "
f"{intermediate_size_per_partition} is not divisible by "
f"weight quantization block_n = {block_n}."
)
if tp_size > 1 and intermediate_size_per_partition % block_k != 0:
# Required by row parallel
raise ValueError(
f"The input_size of down's weight = "
f"{intermediate_size_per_partition} is not divisible by "
f"weight quantization block_k = {block_k}."
)

# WEIGHTS
w13_weight = torch.nn.Parameter(
torch.empty(
Expand Down Expand Up @@ -169,6 +217,26 @@ def create_weights(
requires_grad=False,
)
weight_quant_method = FusedMoeWeightScaleSupported.CHANNEL.value
elif self.weight_quant.strategy == QuantizationStrategy.BLOCK:
w13_weight_scale = torch.nn.Parameter(
torch.ones(
num_experts,
2 * ((intermediate_size_per_partition + block_n - 1) // block_n),
(hidden_size + block_k - 1) // block_k,
dtype=torch.float32,
),
requires_grad=False,
)
w2_weight_scale = torch.nn.Parameter(
torch.ones(
num_experts,
(hidden_size + block_n - 1) // block_n,
(intermediate_size_per_partition + block_k - 1) // block_k,
dtype=torch.float32,
),
requires_grad=False,
)
weight_quant_method = FusedMoeWeightScaleSupported.BLOCK.value
else:
raise ValueError(
f"Unsupported weight quantization strategy: {self.weight_quant.strategy}"
Expand Down Expand Up @@ -343,6 +411,18 @@ def apply(
a2_scale=layer.w2_input_scale,
)
return StandardCombineInput(hidden_states=output)
elif self.weight_quant.strategy == QuantizationStrategy.BLOCK:
quant_info = TritonMoeQuantInfo(
w13_weight=layer.w13_weight,
w2_weight=layer.w2_weight,
use_fp8_w8a8=True,
w13_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a13_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
block_shape=self.weight_block_size,
)
return self.runner.run(dispatch_output, quant_info)
else:
quant_info = TritonMoeQuantInfo(
w13_weight=layer.w13_weight,
Expand Down
Loading
Loading