From d6322e051ccd81c108170bb46a74d9631307e563 Mon Sep 17 00:00:00 2001 From: elvischenv <219235043+elvischenv@users.noreply.github.com> Date: Wed, 3 Dec 2025 23:44:45 -0800 Subject: [PATCH 01/16] Support eagle --- python/sglang/srt/configs/model_config.py | 12 +- .../layers/attention/trtllm_mla_backend.py | 3 +- python/sglang/srt/layers/quantization/fp8.py | 197 ++++++++++++++---- python/sglang/srt/models/deepseek_v2.py | 4 +- .../srt/models/mistral_large_3_eagle.py | 104 +++++++++ python/sglang/srt/server_args.py | 10 +- python/sglang/srt/utils/mistral_utils.py | 9 +- 7 files changed, 295 insertions(+), 44 deletions(-) create mode 100644 python/sglang/srt/models/mistral_large_3_eagle.py diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index 87fc66b63c6f..6483487fd863 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -338,6 +338,7 @@ def _derive_model_shapes(self): or "DotsVLMForCausalLM" in self.hf_config.architectures or "MistralLarge3ForCausalLM" in self.hf_config.architectures or "PixtralForConditionalGeneration" in self.hf_config.architectures + or "MistralLarge3ForCausalLMEagle" in self.hf_config.architectures ): self.head_dim = 256 self.attention_arch = AttentionArch.MLA @@ -702,7 +703,16 @@ def _verify_quantization(self) -> None: if self.quantization is None: self.quantization = quant_method elif self.quantization != quant_method: - if ( + # Allow auto-detection of quantization from checkpoint for draft model + # even if it differs from main model's quantization + if self.is_draft_model: + logger.info( + f"Draft model quantization ({quant_method}) differs from " + f"main model quantization ({self.quantization}). " + f"Using draft model's detected quantization: {quant_method}" + ) + self.quantization = quant_method + elif ( self.quantization not in compatible_quantization_methods or quant_method not in compatible_quantization_methods[self.quantization] diff --git a/python/sglang/srt/layers/attention/trtllm_mla_backend.py b/python/sglang/srt/layers/attention/trtllm_mla_backend.py index 7afa0a09ed7c..b409cfdbf4ae 100755 --- a/python/sglang/srt/layers/attention/trtllm_mla_backend.py +++ b/python/sglang/srt/layers/attention/trtllm_mla_backend.py @@ -964,7 +964,8 @@ def forward_extend( # Apply llama 4 scaling if provided if llama_4_scaling is not None: - q *= llama_4_scaling + q = q.to(self.q_data_type) * llama_4_scaling + q = q.to(self.data_type) if ( forward_batch.forward_mode.is_target_verify() diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index d1a42236f689..dbebcab64df0 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -922,6 +922,86 @@ def process_weights_after_loading(self, layer: Module) -> None: if _is_hip: self.process_weights_hip_scale_padding(layer) + + # Align FP8 weights to FlashInfer per-tensor kernel layout if enabled + if get_moe_runner_backend().is_flashinfer_trtllm(): + from flashinfer import reorder_rows_for_gated_act_gemm, shuffle_matrix_a + + # Note: No need to swap W13 halves, they are already in the correct order: [Gate, Up] + num_experts, two_n, hidden = layer.w13_weight.shape + + # 2) Reorder rows for fused gated activation (W13) + w13_interleaved = [ + reorder_rows_for_gated_act_gemm(layer.w13_weight[i]) + for i in range(num_experts) + ] + w13_interleaved = torch.stack(w13_interleaved).reshape( + num_experts, two_n, hidden + ) + + # 3) Shuffle weights for transposed MMA output (both W13, W2) + epilogue_tile_m = 128 + w13_shuffled = [ + shuffle_matrix_a( + w13_interleaved[i].view(torch.uint8), epilogue_tile_m + ) + for i in range(num_experts) + ] + w2_shuffled = [ + shuffle_matrix_a( + layer.w2_weight[i].view(torch.uint8), epilogue_tile_m + ) + for i in range(num_experts) + ] + + layer.w13_weight = Parameter( + torch.stack(w13_shuffled).view(torch.float8_e4m3fn), + requires_grad=False, + ) + layer.w2_weight = Parameter( + torch.stack(w2_shuffled).view(torch.float8_e4m3fn), + requires_grad=False, + ) + + # Precompute and register per-expert output scaling factors for FI MoE + # Note: w13_input_scale and w2_input_scale are scalar Parameters post-reduction + assert ( + hasattr(layer, "w13_input_scale") + and layer.w13_input_scale is not None + ) + assert ( + hasattr(layer, "w2_input_scale") + and layer.w2_input_scale is not None + ) + assert ( + hasattr(layer, "w13_weight_scale") + and layer.w13_weight_scale is not None + ) + assert ( + hasattr(layer, "w2_weight_scale") + and layer.w2_weight_scale is not None + ) + + input_scale = layer.w13_input_scale.to(torch.float32) + activation_scale = layer.w2_input_scale.to(torch.float32) + w13_weight_scale = layer.w13_weight_scale.to(torch.float32) + w2_weight_scale = layer.w2_weight_scale.to(torch.float32) + + output1_scales_scalar = ( + w13_weight_scale * input_scale * (1.0 / activation_scale) + ) + output1_scales_gate_scalar = w13_weight_scale * input_scale + output2_scales_scalar = activation_scale * w2_weight_scale + + layer.output1_scales_scalar = Parameter( + output1_scales_scalar, requires_grad=False + ) + layer.output1_scales_gate_scalar = Parameter( + output1_scales_gate_scalar, requires_grad=False + ) + layer.output2_scales_scalar = Parameter( + output2_scales_scalar, requires_grad=False + ) return def process_weights_hip_int4(self, layer: Module): @@ -1225,7 +1305,10 @@ def apply_with_router_logits( activation = self.moe_runner_config.activation routed_scaling_factor = self.moe_runner_config.routed_scaling_factor - from flashinfer.fused_moe import trtllm_fp8_block_scale_moe + from flashinfer.fused_moe import ( + trtllm_fp8_block_scale_moe, + trtllm_fp8_per_tensor_scale_moe, + ) from sglang.srt.layers.moe.topk import TopKOutputChecker from sglang.srt.layers.moe.utils import RoutingMethodType @@ -1236,9 +1319,15 @@ def apply_with_router_logits( assert ( activation == "silu" ), "Only silu is supported for flashinfer blockscale fp8 moe" - a_q, a_sf = per_token_group_quant_fp8(x, self.quant_config.weight_block_size[1]) - # NOTE: scales of hidden states have to be transposed! - a_sf_t = a_sf.t().contiguous() + + if self.block_quant: + a_q, a_sf = per_token_group_quant_fp8( + x, self.quant_config.weight_block_size[1] + ) + # NOTE: scales of hidden states have to be transposed! + a_sf_t = a_sf.t().contiguous() + else: + a_q, _ = scaled_fp8_quant(x, layer.w13_input_scale) correction_bias = ( None @@ -1246,43 +1335,79 @@ def apply_with_router_logits( else topk_config.correction_bias.to(x.dtype) ) - routing_method_type = getattr(layer, "routing_method_type") + routing_method_type = getattr( + layer, "routing_method_type", RoutingMethodType.DeepSeekV3 + ) with use_symmetric_memory( get_tp_group(), disabled=not is_allocation_symmetric() ): - # FIXME: there is a bug in the trtllm_fp8_block_scale_moe. - # It ignored the `output`` argument. https://github.com/flashinfer-ai/flashinfer/blob/da01b1bd8f9f22aec8c0eea189ad54860b034947/flashinfer/fused_moe/core.py#L1323-L1325 - # so we put the whole function under the ``use_symmetric_memory`` context manager. - # If the bug is fixed, we can only put the output tensor allocation under the context manager. - return trtllm_fp8_block_scale_moe( - routing_logits=( - router_logits.to(torch.float32) - if routing_method_type == RoutingMethodType.DeepSeekV3 - else router_logits - ), - routing_bias=correction_bias, - hidden_states=a_q, - hidden_states_scale=a_sf_t, - gemm1_weights=layer.w13_weight, - gemm1_weights_scale=layer.w13_weight_scale_inv, - gemm2_weights=layer.w2_weight, - gemm2_weights_scale=layer.w2_weight_scale_inv, - num_experts=layer.num_experts, - top_k=topk_config.top_k, - n_group=topk_config.num_expert_group, - topk_group=topk_config.topk_group, - intermediate_size=layer.w2_weight.shape[2], - local_expert_offset=layer.moe_ep_rank * layer.num_local_experts, - local_num_experts=layer.num_local_experts, - routed_scaling_factor=( - routed_scaling_factor if routed_scaling_factor is not None else 1.0 - ), - tile_tokens_dim=None, - routing_method_type=routing_method_type, - use_shuffled_weight=False, - ) + if self.block_quant: + # FIXME: there is a bug in the trtllm_fp8_block_scale_moe. + # It ignored the `output`` argument. https://github.com/flashinfer-ai/flashinfer/blob/da01b1bd8f9f22aec8c0eea189ad54860b034947/flashinfer/fused_moe/core.py#L1323-L1325 + # so we put the whole function under the ``use_symmetric_memory`` context manager. + # If the bug is fixed, we can only put the output tensor allocation under the context manager. + return trtllm_fp8_block_scale_moe( + routing_logits=( + router_logits.to(torch.float32) + if routing_method_type == RoutingMethodType.DeepSeekV3 + else router_logits + ), + routing_bias=correction_bias, + hidden_states=a_q, + hidden_states_scale=a_sf_t, + gemm1_weights=layer.w13_weight, + gemm1_weights_scale=layer.w13_weight_scale_inv, + gemm2_weights=layer.w2_weight, + gemm2_weights_scale=layer.w2_weight_scale_inv, + num_experts=layer.num_experts, + top_k=topk_config.top_k, + n_group=topk_config.num_expert_group, + topk_group=topk_config.topk_group, + intermediate_size=layer.w2_weight.shape[2], + local_expert_offset=layer.moe_ep_rank * layer.num_local_experts, + local_num_experts=layer.num_local_experts, + routed_scaling_factor=( + routed_scaling_factor + if routed_scaling_factor is not None + else 1.0 + ), + tile_tokens_dim=None, + routing_method_type=routing_method_type, + use_shuffled_weight=False, + ) + else: + routing_bias_cast = ( + None + if correction_bias is None + else correction_bias.to(torch.bfloat16) + ) + + return trtllm_fp8_per_tensor_scale_moe( + routing_logits=router_logits.to(torch.bfloat16), + routing_bias=routing_bias_cast, + hidden_states=a_q, + gemm1_weights=layer.w13_weight, + output1_scales_scalar=layer.output1_scales_scalar, + output1_scales_gate_scalar=layer.output1_scales_gate_scalar, + gemm2_weights=layer.w2_weight, + output2_scales_scalar=layer.output2_scales_scalar, + num_experts=layer.num_experts, + top_k=topk_config.top_k, + n_group=topk_config.num_expert_group, + topk_group=topk_config.topk_group, + intermediate_size=layer.w2_weight.shape[2], + local_expert_offset=layer.moe_ep_rank * layer.num_local_experts, + local_num_experts=layer.num_local_experts, + routed_scaling_factor=( + routed_scaling_factor + if routed_scaling_factor is not None + else 1.0 + ), + use_routing_scales_on_input=False, + routing_method_type=routing_method_type, + ) def maybe_apply_hip_fused_experts( self, diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 30d7224dc3e1..d928e237ef94 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -658,7 +658,9 @@ def __init__( layer_id=self.layer_id, quant_config=quant_config, routed_scaling_factor=self.routed_scaling_factor, - routing_method_type=RoutingMethodType.DeepSeekV3, + routing_method_type=getattr( + config, "routing_method_type", RoutingMethodType.DeepSeekV3 + ), prefix=add_prefix("experts", prefix), ) diff --git a/python/sglang/srt/models/mistral_large_3_eagle.py b/python/sglang/srt/models/mistral_large_3_eagle.py new file mode 100644 index 000000000000..e85a629e76f0 --- /dev/null +++ b/python/sglang/srt/models/mistral_large_3_eagle.py @@ -0,0 +1,104 @@ +from typing import Optional + +import torch +from torch import nn +from transformers import PretrainedConfig + +from python.sglang.srt.layers.attention.nsa.utils import is_nsa_enable_prefill_cp +from sglang.srt.distributed import get_pp_group +from sglang.srt.layers.layernorm import RMSNorm +from sglang.srt.layers.linear import RowParallelLinear +from sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding +from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors +from sglang.srt.models.deepseek_v2 import DeepseekV2DecoderLayer, DeepseekV2Model +from sglang.srt.models.mistral_large_3 import MistralLarge3ForCausalLM +from sglang.srt.utils import add_prefix + + +class MistralLarge3Model(DeepseekV2Model): + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + nn.Module.__init__(self) + + self.config = config + self.vocab_size = config.vocab_size + assert get_pp_group().world_size == 1 + self.pp_group = get_pp_group() + self.nsa_enable_prefill_cp = is_nsa_enable_prefill_cp() + + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + prefix=add_prefix("embed_tokens", prefix), + ) + + self.layers = nn.ModuleList( + [ + DeepseekV2DecoderLayer( + config=config, + prefix=add_prefix(prefix, f"layers.{i}"), + quant_config=quant_config, + layer_id=i, + ) + for i in range(self.config.num_hidden_layers) + ] + ) + self.start_layer = 0 + self.end_layer = self.config.num_hidden_layers + + self.fc = RowParallelLinear( + self.config.hidden_size * 2, + self.config.hidden_size, + bias=False, + quant_config=quant_config, + prefix=add_prefix(prefix, "fc"), + input_is_parallel=False, + ) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.layers_to_capture = [] + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + input_embeds: torch.Tensor = None, + pp_proxy_tensors: Optional[PPProxyTensors] = None, + ) -> torch.Tensor: + if input_embeds is None: + input_embeds = self.embed_tokens(input_ids) + input_embeds, _ = self.fc( + torch.cat((input_embeds, forward_batch.spec_info.hidden_states), dim=-1) + ) + output = super().forward( + input_ids, positions, forward_batch, input_embeds, pp_proxy_tensors + ) + assert isinstance(output, torch.Tensor) + return output + + +class MistralLarge3ForCausalLMEagle(MistralLarge3ForCausalLM): + remapping = MistralLarge3ForCausalLM.remapping | { + r"eagle_linear\.weight": r"model.fc.weight", + r"eagle_linear\.qscale_act": r"model.fc.input_scale", + r"eagle_linear\.qscale_weight": r"model.fc.weight_scale", + } + + def __init__( + self, + *, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + config.quant_config = quant_config + self.model_cls = MistralLarge3Model + super().__init__(config=config, quant_config=quant_config, prefix=prefix) + + +EntryClass = [MistralLarge3ForCausalLMEagle] diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index a07b71c5a02e..82032f8edbb2 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -1711,9 +1711,13 @@ def _handle_speculative_decoding(self): self.speculative_draft_model_path = self.model_path self.speculative_draft_model_revision = self.revision else: - logger.warning( - "DeepSeek MTP does not require setting speculative_draft_model_path." - ) + if model_arch not in [ + "MistralLarge3ForCausalLM", + "PixtralForConditionalGeneration", + ]: + logger.warning( + "DeepSeek MTP does not require setting speculative_draft_model_path." + ) if self.speculative_num_steps is None: assert ( diff --git a/python/sglang/srt/utils/mistral_utils.py b/python/sglang/srt/utils/mistral_utils.py index e23abb53c5fd..ecce3042df02 100644 --- a/python/sglang/srt/utils/mistral_utils.py +++ b/python/sglang/srt/utils/mistral_utils.py @@ -22,11 +22,15 @@ def adapt_config_dict( is_mistral_large_3 = ( is_moe and (config_dict["moe"].get("num_shared_experts") or 0) > 0 ) + is_eagle = "eagle" in model.lower() if is_moe: if is_mistral_large_3: config_dict = _remap_moe_args(config_dict) config_dict["model_type"] = "deepseek_v3" - config_dict["architectures"] = ["MistralLarge3ForCausalLM"] + if is_eagle: + config_dict["architectures"] = ["MistralLarge3ForCausalLMEagle"] + else: + config_dict["architectures"] = ["MistralLarge3ForCausalLM"] assert ( "llama_4_scaling" in config_dict @@ -77,6 +81,8 @@ def adapt_config_dict( config_dict = _remap_mistral_vision_args(config_dict) if is_audio: config_dict = _remap_mistral_audio_args(config_dict) + if is_eagle: + config_dict["routing_method_type"] = 1 # RoutingMethodType.Renormalize config = PretrainedConfig.from_dict(config_dict) @@ -227,7 +233,6 @@ def _remap_moe_args(config: dict) -> dict: config[new_name] = value config["topk_method"] = None - config["routing_method_type"] = 1 # RoutingMethodType.Renormalize config["scoring_func"] = "softmax" return config From 0c9e202396d6374a30e1f63ad83df97effeaa5aa Mon Sep 17 00:00:00 2001 From: Linda-Stadter <57756729+Linda-Stadter@users.noreply.github.com> Date: Thu, 4 Dec 2025 08:51:50 -0800 Subject: [PATCH 02/16] Fixes for running eagle Signed-off-by: Linda-Stadter <57756729+Linda-Stadter@users.noreply.github.com> --- .../schemes/compressed_tensors_w8a8_fp8.py | 16 ++++++---------- python/sglang/srt/models/deepseek_v2.py | 16 +++++++++++----- python/sglang/srt/models/mistral_large_3.py | 3 --- .../sglang/srt/models/mistral_large_3_eagle.py | 1 + 4 files changed, 18 insertions(+), 18 deletions(-) diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py index 2effad0360f9..423e74132610 100644 --- a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +++ b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py @@ -106,14 +106,12 @@ def create_weights( weight_loader=weight_loader, ) weight_scale[:] = torch.finfo(torch.float32).min - layer.register_parameter("weight_scale", weight_scale) elif self.strategy == QuantizationStrategy.TENSOR: weight_scale = PerTensorScaleParameter( data=torch.empty(len(output_partition_sizes), dtype=torch.float32), weight_loader=weight_loader, ) weight_scale[:] = torch.finfo(torch.float32).min - layer.register_parameter("weight_scale", weight_scale) elif self.strategy == QuantizationStrategy.BLOCK: assert layer.weight_block_size is not None block_n, block_k = layer.weight_block_size[0], layer.weight_block_size[1] @@ -130,8 +128,8 @@ def create_weights( ) weight_scale.format_ue8m0 = False weight_scale[:] = torch.finfo(torch.float32).min - layer.register_parameter("weight_scale_inv", weight_scale) + layer.register_parameter("weight_scale", weight_scale) # INPUT SCALE if self.is_static_input_scheme: input_scale = PerTensorScaleParameter( @@ -190,16 +188,14 @@ def process_weights_after_loading(self, layer) -> None: elif self.strategy == QuantizationStrategy.BLOCK: assert self.is_static_input_scheme is False weight = layer.weight - weight_scale_inv = layer.weight_scale_inv + weight_scale = layer.weight_scale if is_fp8_fnuz(): - weight, weight_scale_inv, _ = normalize_e4m3fn_to_e4m3fnuz( - weight=weight, weight_scale=weight_scale_inv + weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz( + weight=weight, weight_scale=weight_scale ) layer.weight = Parameter(weight.data, requires_grad=False) - layer.weight_scale_inv = Parameter( - weight_scale_inv.data, requires_grad=False - ) + layer.weight_scale = Parameter(weight_scale.data, requires_grad=False) else: raise ValueError(f"Unknown quantization strategy {self.strategy}") @@ -221,7 +217,7 @@ def apply_weights( input=x, weight=layer.weight, block_size=self.weight_block_size, - weight_scale=layer.weight_scale_inv, + weight_scale=layer.weight_scale, input_scale=layer.input_scale, bias=bias, ) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index d928e237ef94..c74ba98cedde 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -3182,6 +3182,7 @@ def forward( class DeepseekV2ForCausalLM(nn.Module): # for quark model load packed_modules_mapping = {} + model_cls = DeepseekV2Model def __init__( self, @@ -3190,7 +3191,6 @@ def __init__( prefix: str = "", ) -> None: super().__init__() - # for quark model load # Fuse q_a_proj and kv_a_proj_with_mqa along output dimension when q_lora_rank is not None self.fuse_qkv_a_proj = ( @@ -3208,7 +3208,7 @@ def __init__( self.quant_config = quant_config self.determine_num_fused_shared_experts() self.use_nsa = is_deepseek_nsa(config) - self.model = DeepseekV2Model( + self.model = self.model_cls( config, quant_config, prefix=add_prefix("model", prefix) ) if self.pp_group.is_last_rank: @@ -3400,16 +3400,22 @@ def post_load_weights(self, is_nextn=False, weight_names=None): selected_quant_config, "weight_block_size", None ) if weight_block_size is not None: - assert hasattr(self_attn.kv_b_proj, "weight_scale_inv") + assert hasattr(self_attn.kv_b_proj, "weight_scale_inv") or hasattr( + self_attn.kv_b_proj, "weight_scale" + ) + weight_scale = ( + self_attn.kv_b_proj.weight_scale + if hasattr(self_attn.kv_b_proj, "weight_scale") + else self_attn.kv_b_proj.weight_scale_inv + ) if _is_fp8_fnuz: weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz( weight=w, - weight_scale=self_attn.kv_b_proj.weight_scale_inv, + weight_scale=weight_scale, input_scale=None, ) else: weight = w - weight_scale = self_attn.kv_b_proj.weight_scale_inv # In multiple weight loading scenarios (e.g. RL), we need to inverse the scale of the weights after the requantization happened at the first loading. if ( diff --git a/python/sglang/srt/models/mistral_large_3.py b/python/sglang/srt/models/mistral_large_3.py index 719d47915407..fd60ef61f7c0 100644 --- a/python/sglang/srt/models/mistral_large_3.py +++ b/python/sglang/srt/models/mistral_large_3.py @@ -72,9 +72,6 @@ def _iterable_remap_mistral_to_ds( elif name.endswith(".qscale_weight"): name = re.sub(r"\.qscale_weight$", ".weight_scale", name) - if name.endswith(".weight_scale") and ".experts." not in name: - name = re.sub(r"\.weight_scale$", ".weight_scale_inv", name) - yield name, loaded_weight diff --git a/python/sglang/srt/models/mistral_large_3_eagle.py b/python/sglang/srt/models/mistral_large_3_eagle.py index e85a629e76f0..f136640fde78 100644 --- a/python/sglang/srt/models/mistral_large_3_eagle.py +++ b/python/sglang/srt/models/mistral_large_3_eagle.py @@ -61,6 +61,7 @@ def __init__( ) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.layers_to_capture = [] + self.llama_4_scaling_config = getattr(config, "llama_4_scaling", None) def forward( self, From 1402ae439897d496988c6e14b279b87b8f29d217 Mon Sep 17 00:00:00 2001 From: Daniel Campora <961215+dcampora@users.noreply.github.com> Date: Sat, 29 Nov 2025 14:30:50 -0800 Subject: [PATCH 03/16] Added w4a16 loading support. Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com> --- .../compressed_tensors/compressed_tensors.py | 33 ++ .../compressed_tensors/schemes/__init__.py | 2 + .../schemes/compressed_tensors_w4a16_nvfp4.py | 125 ++++++ .../layers/quantization/marlin_utils_fp4.py | 404 ++++++++++++++++++ 4 files changed, 564 insertions(+) create mode 100644 python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_nvfp4.py create mode 100644 python/sglang/srt/layers/quantization/marlin_utils_fp4.py diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py index 31f47b88bc2f..01893f7406d2 100644 --- a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py @@ -34,6 +34,7 @@ CompressedTensorsW8A8Int8, CompressedTensorsW8A16Fp8, CompressedTensorsWNA16, + CompressedTensorsW4A16Fp4, ) from sglang.srt.layers.quantization.compressed_tensors.utils import ( find_matched_target, @@ -388,11 +389,43 @@ def _is_wNa16_group_channel( is_static = not weight_quant.dynamic return is_channel_group and input_quant_none and is_symmetric and is_static + + def _is_fp4a16_nvfp4( + self, weight_quant: QuantizationArgs, input_quant: QuantizationArgs + ): + is_weight_only = weight_quant is not None and input_quant is None + is_tensor_group_quant = ( + weight_quant.strategy == QuantizationStrategy.TENSOR_GROUP.value + ) + is_symmetric = weight_quant.symmetric + + is_group_size_16 = weight_quant.group_size == 16 + is_float_type = weight_quant.type == QuantizationType.FLOAT + is_4_bits = weight_quant.num_bits == 4 + + print("is_weight_only", is_weight_only) + print("is_tensor_group_quant", is_tensor_group_quant) + print("is_float_type", is_float_type) + print("is_4_bits", is_4_bits) + print("is_group_size_16", is_group_size_16) + print("is_symmetric", is_symmetric) + + return ( + is_weight_only + and is_tensor_group_quant + and is_float_type + and is_4_bits + and is_group_size_16 + and is_symmetric + ) def _get_scheme_from_parts( self, weight_quant: BaseModel, input_quant: BaseModel ) -> CompressedTensorsScheme: + if self._is_fp4a16_nvfp4(weight_quant, input_quant): + return CompressedTensorsW4A16Fp4() + # Detect If Mixed Precision if self._is_wNa16_group_channel(weight_quant, input_quant): if ( diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py index 6d9871917bbb..70d6f5889c2b 100644 --- a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +++ b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py @@ -5,6 +5,7 @@ from .compressed_tensors_w8a8_int8 import CompressedTensorsW8A8Int8 from .compressed_tensors_w8a16_fp8 import CompressedTensorsW8A16Fp8 from .compressed_tensors_wNa16 import WNA16_SUPPORTED_BITS, CompressedTensorsWNA16 +from .compressed_tensors_w4a16_nvfp4 import CompressedTensorsW4A16Fp4 __all__ = [ "CompressedTensorsScheme", @@ -13,4 +14,5 @@ "CompressedTensorsW8A8Int8", "CompressedTensorsWNA16", "WNA16_SUPPORTED_BITS", + "CompressedTensorsW4A16Fp4", ] diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_nvfp4.py b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_nvfp4.py new file mode 100644 index 000000000000..54e16ac1fbe8 --- /dev/null +++ b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_nvfp4.py @@ -0,0 +1,125 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Callable + +import torch +from torch.nn.parameter import Parameter + +from sglang.srt.layers.quantization.compressed_tensors.schemes import ( + CompressedTensorsScheme, +) + +from sglang.srt.layers.quantization.marlin_utils_fp4 import ( + apply_fp4_marlin_linear, + prepare_fp4_layer_for_marlin, +) +from sglang.srt.layers.parameter import ( + GroupQuantScaleParameter, + ModelWeightParameter, + PerTensorScaleParameter, +) + +__all__ = ["CompressedTensorsW4A16Fp4"] + + +class CompressedTensorsW4A16Fp4(CompressedTensorsScheme): + def __init__(self, has_input_global_scale: bool = False): + self.has_input_global_scale = has_input_global_scale + self.group_size = 16 + + @classmethod + def get_min_capability(cls) -> int: + # dont restrict as emulations + return 80 + + def create_weights( + self, + layer: torch.nn.Module, + output_partition_sizes: list[int], + input_size_per_partition: int, + params_dtype: torch.dtype, + weight_loader: Callable, + **kwargs, + ): + output_size_per_partition = sum(output_partition_sizes) + layer.logical_widths = output_partition_sizes + layer.input_size_per_partition = input_size_per_partition + layer.output_size_per_partition = output_size_per_partition + + # Weight + weight = ModelWeightParameter( + data=torch.empty( + sum(output_partition_sizes), + input_size_per_partition // 2, + dtype=torch.uint8, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + layer.register_parameter("weight_packed", weight) + + # Global Weight Scale + weight_global_scale = PerTensorScaleParameter( + data=torch.empty(len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader, + ) + layer.register_parameter("weight_global_scale", weight_global_scale) + + # Per Group Weight Scale + weight_scale = GroupQuantScaleParameter( + data=torch.empty( + sum(output_partition_sizes), + input_size_per_partition // self.group_size, + dtype=torch.float8_e4m3fn, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + + layer.register_parameter("weight_scale", weight_scale) + + if self.has_input_global_scale: + input_global_scale = PerTensorScaleParameter( + data=torch.empty(len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader, + ) + layer.register_parameter("input_global_scale", input_global_scale) + + def process_weights_after_loading(self, layer) -> None: + # Process parameters for marlin repacking + + # Rename weight_packed to weight that marlin expects + layer.weight = Parameter(layer.weight_packed.data, requires_grad=False) + del layer.weight_packed + # Rename weight_global_scale to weight_scale_2 that marlin expects + # Note: ct stores the inverse of what is expected by the marlin kernel + layer.weight_scale_2 = Parameter( + 1 / layer.weight_global_scale.max().to(torch.float32), requires_grad=False + ) + del layer.weight_global_scale + + if self.has_input_global_scale: + layer.input_global_scale = torch.nn.Parameter( + layer.input_global_scale.data, requires_grad=False + ) + + prepare_fp4_layer_for_marlin(layer) + + def apply_weights( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: torch.Tensor | None = None, + ) -> torch.Tensor: + return apply_fp4_marlin_linear( + input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + weight_scale_2=layer.weight_scale_2, + workspace=layer.workspace, + size_n=layer.output_size_per_partition, + size_k=layer.input_size_per_partition, + bias=bias, + ) diff --git a/python/sglang/srt/layers/quantization/marlin_utils_fp4.py b/python/sglang/srt/layers/quantization/marlin_utils_fp4.py new file mode 100644 index 000000000000..44fe090ce93b --- /dev/null +++ b/python/sglang/srt/layers/quantization/marlin_utils_fp4.py @@ -0,0 +1,404 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + + +import logging +from typing import Optional + +import torch + +from sglang.srt.layers.quantization.marlin_utils import ( + USE_FP32_REDUCE_DEFAULT, + marlin_make_workspace, + marlin_permute_bias, + marlin_permute_scales, + should_use_atomic_add_reduce, +) +from sglang.srt.layers.quantization.utils import get_scalar_types +from sglang.srt.utils import is_cuda + +_is_cuda = is_cuda() +if _is_cuda: + from sgl_kernel import gptq_marlin_gemm, gptq_marlin_repack + +ScalarType, scalar_types = get_scalar_types() + +FP4_MARLIN_SUPPORTED_GROUP_SIZES = [16] + +logger = logging.getLogger(__name__) + + +def is_fp4_marlin_supported(): + return True + + +def nvfp4_marlin_process_scales(marlin_scales): + if not (marlin_scales >= 0).all(): + logger.warning_once( + "NVFP4 Marlin assumes the scales to be >=0, but has encountered " + "negative scales. Accuracy will likely be degraded. This is " + "because it changes the scales from FP8-S1E4M3 to a special " + "FP8-S0E5M3 format to speedup the dequantization." + ) + + # convert to half first, we would convert to fp8 later + marlin_scales = marlin_scales.to(torch.half) + + # 8 is the number of scale number using by one thread + marlin_scales = marlin_scales.view(marlin_scales.size(0) // 2, 2, -1, 8) + marlin_scales = marlin_scales.permute(0, 2, 1, 3).reshape( + marlin_scales.size(0) * 2, -1 + ) + + # fit the layout of fp8 dequantization + marlin_scales = marlin_scales.view(-1, 4)[:, [0, 2, 1, 3]].view( + marlin_scales.size(0), -1 + ) + + # We assume that weight_scale (FP8-S1E4M3) is always greater + # than or equal to 0. So we can convert + # (weight_scale * (2 ** 7) to a special FP8-S0E5M3 format. + # After multiplying by 2 ** 7, the top bit of FP8-S0E5M3 would always be 1 + # when weight_scale > 0. This allows us to have an exponent bias + # closer to zero after dequantization. + + marlin_scales = (marlin_scales * (2**7)).view(torch.int16) << 1 + marlin_scales = marlin_scales.view(torch.float8_e4m3fn) + marlin_scales = marlin_scales[:, 1::2].contiguous() + + return marlin_scales + + +def mxfp4_marlin_process_scales(marlin_scales): + # 8 is the number of scale number using by one thread + marlin_scales = marlin_scales.view(marlin_scales.size(0) // 2, 2, -1, 8) + marlin_scales = marlin_scales.permute(0, 2, 1, 3).reshape( + marlin_scales.size(0) * 2, -1 + ) + + # fit the layout of fp8 dequantization + marlin_scales = marlin_scales.view(-1, 4)[:, [0, 2, 1, 3]].view( + marlin_scales.size(0), -1 + ) + marlin_scales = marlin_scales.to(torch.float8_e8m0fnu) + return marlin_scales + + +def nvfp4_marlin_process_global_scale(global_scale): + assert global_scale.dtype in [torch.half, torch.bfloat16] + fp4_exponent = 2 + if global_scale.dtype == torch.half: + target_exponent = 5 + elif global_scale.dtype == torch.bfloat16: + target_exponent = 8 + # exponent_bias_fp16 = 2 ** 4 - 2 ** 1 = 14 + # exponent_bias_bf16 = 2 ** 7 - 2 ** 1 = 126 + exponent_bias = 2 ** (target_exponent - 1) - 2 ** (fp4_exponent - 1) + return global_scale * (2.0 ** (exponent_bias - 7)) + + +def apply_fp4_marlin_linear( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + weight_scale_2: torch.Tensor | None, + workspace: torch.Tensor, + size_n: int, + size_k: int, + bias: torch.Tensor | None = None, + use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT, +) -> torch.Tensor: + # For GPUs that lack FP4 hardware support, we can leverage the + # Marlin kernel for fast weight-only FP4 quantization + + reshaped_x = input.reshape(-1, input.shape[-1]) + out_shape = input.shape[:-1] + (size_n,) + + use_atomic_add = should_use_atomic_add_reduce( + m=reshaped_x.size(0), n=size_n, k=size_k, device=input.device, dtype=input.dtype + ) + + output = gptq_marlin_gemm( + a=reshaped_x, + c=None, + b_q_weight=weight, + b_bias=bias, + b_scales=weight_scale, + global_scale=weight_scale_2, + b_zeros=None, + g_idx=None, + perm=None, + workspace=workspace, + b_q_type=scalar_types.float4_e2m1f, + size_m=reshaped_x.size(0), + size_n=size_n, + size_k=size_k, + use_atomic_add=use_atomic_add, + use_fp32_reduce=use_fp32_reduce, + ) + + return output.reshape(out_shape) + + +def prepare_fp4_layer_for_marlin(layer: torch.nn.Module) -> None: + logger.warning_once( + "Your GPU does not have native support for FP4 computation but " + "FP4 quantization is being used. Weight-only FP4 compression will " + "be used leveraging the Marlin kernel. This may degrade " + "performance for compute-heavy workloads." + ) + + is_nvfp4 = hasattr(layer, "weight_scale_2") + group_size = 16 if is_nvfp4 else 32 + + part_size_n = layer.output_size_per_partition + part_size_k = layer.input_size_per_partition + param_dtype = layer.params_dtype + + assert layer.weight.shape == (part_size_n, part_size_k // 2) + + device = layer.weight.device + + # WORKSPACE + layer.workspace = marlin_make_workspace(device) + + # WEIGHT + # Repack weights to marlin format + perm = torch.empty(0, dtype=torch.int, device=device) + qweight = layer.weight.view(torch.int32).T.contiguous() + + marlin_qweight = gptq_marlin_repack( + b_q_weight=qweight, + perm=perm, + size_k=part_size_k, + size_n=part_size_n, + num_bits=4, + ) + layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False) + + # WEIGHT SCALES + # Permute scales + weight_scale = layer.weight_scale.T.contiguous() + + if not is_nvfp4: + weight_scale = weight_scale.view(torch.float8_e8m0fnu) + + weight_scale = weight_scale.to(param_dtype) + weight_scale = marlin_permute_scales( + s=weight_scale, size_k=part_size_k, size_n=part_size_n, group_size=group_size + ) + + if is_nvfp4: + weight_scale = nvfp4_marlin_process_scales(weight_scale) + layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False) + + weight_scale_2 = layer.weight_scale_2.to(param_dtype) + weight_scale_2 = nvfp4_marlin_process_global_scale(weight_scale_2) + layer.weight_scale_2 = torch.nn.Parameter(weight_scale_2, requires_grad=False) + else: + weight_scale = mxfp4_marlin_process_scales(weight_scale) + layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False) + + if hasattr(layer, "bias") and layer.bias is not None: + assert layer.bias.shape == (part_size_n,) + bias = marlin_permute_bias(layer.bias) + layer.bias = torch.nn.Parameter(bias, requires_grad=False) + + return + + +def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None: + logger.warning_once( + "Your GPU does not have native support for FP4 computation but " + "FP4 quantization is being used. Weight-only FP4 compression will " + "be used leveraging the Marlin kernel. This may degrade " + "performance for compute-heavy workloads." + ) + + is_nvfp4 = hasattr(layer, "w13_weight_scale_2") + group_size = 16 if is_nvfp4 else 32 + + e = layer.num_experts + k = layer.hidden_size + n = layer.intermediate_size_per_partition + + # WORKSPACE + device = layer.w13_weight.device + param_dtype = layer.params_dtype + layer.workspace = marlin_make_workspace_new(device, 4) + perm = torch.empty(0, dtype=torch.int, device=device) + + # WEIGHT + # Repack weights to marlin format + for name in ["w13_weight", "w2_weight"]: + weight = getattr(layer, name) + tensor_list = [] + if "w13" in name: + size_n, size_k = n * 2, k + else: + size_n, size_k = k, n + + assert weight.shape == (e, size_n, size_k // 2) + + for i in range(e): + qweight = weight[i].view(torch.int32).T.contiguous() + + marlin_qweight = ops.gptq_marlin_repack( + b_q_weight=qweight, perm=perm, size_k=size_k, size_n=size_n, num_bits=4 + ) + tensor_list.append(marlin_qweight) + + weight = torch.cat([x.unsqueeze(0) for x in tensor_list], 0) + weight = torch.nn.Parameter(weight, requires_grad=False) + + setattr(layer, name, weight) + + # WEIGHT SCALES + # Permute scales + for name in ["w13", "w2"]: + scales = getattr(layer, name + "_weight_scale") + if not is_nvfp4: + scales = scales.view(torch.float8_e8m0fnu) + scales = scales.to(param_dtype) + if is_nvfp4: + global_scale = getattr(layer, name + "_weight_scale_2").to(param_dtype) + + tensor_list = [] + if "w13" in name: + size_n, size_k = n * 2, k + else: + size_n, size_k = k, n + + for i in range(e): + scale = scales[i].T + + marlin_scales = marlin_permute_scales( + s=scale, size_k=size_k, size_n=size_n, group_size=group_size + ) + if is_nvfp4: + marlin_scales = nvfp4_marlin_process_scales(marlin_scales) + else: + marlin_scales = mxfp4_marlin_process_scales(marlin_scales) + tensor_list.append(marlin_scales) + + scales = torch.cat([x.unsqueeze(0) for x in tensor_list], 0) + scales = torch.nn.Parameter(scales, requires_grad=False) + setattr(layer, name + "_weight_scale", scales) + + if is_nvfp4: + global_scale = nvfp4_marlin_process_global_scale(global_scale) + global_scale = torch.nn.Parameter(global_scale, requires_grad=False) + setattr(layer, name + "_weight_scale_2", global_scale) + + # BIAS + # Permute bias + for name in ["w13_bias", "w2_bias"]: + if not hasattr(layer, name): + continue + bias = getattr(layer, name).to(param_dtype) + + tensor_list = [] + for i in range(e): + expert_bias = bias[i] + + tensor_list.append(marlin_permute_bias(expert_bias)) + + bias = torch.cat([x.unsqueeze(0) for x in tensor_list], 0) + bias = torch.nn.Parameter(bias, requires_grad=False) + setattr(layer, name, bias) + + +def rand_marlin_weight_nvfp4_like(weight, group_size): + assert group_size > 0 + size_n, size_k = weight.shape + device = weight.device + + scales = weight.view(size_n, -1, group_size).abs().max(-1)[0] / 6 + global_scale = scales.max() / 448 + scales = (scales / global_scale).to(torch.float8_e4m3fn) + + fp4_weight = torch.randint( + 0, 256, (size_n, size_k // 2), dtype=torch.uint8, device=weight.device + ) + fp4_weight_part_1 = (fp4_weight & 0b10000000) | ((fp4_weight & 0b01110000) >> 2) + fp4_weight_part_1 = fp4_weight_part_1.view(torch.float8_e4m3fn) + fp4_weight_part_1 = fp4_weight_part_1.to(weight.dtype) * (2**6) + + fp4_weight2 = fp4_weight << 4 + fp4_weight_part_2 = (fp4_weight2 & 0b10000000) | ((fp4_weight2 & 0b01110000) >> 2) + fp4_weight_part_2 = fp4_weight_part_2.view(torch.float8_e4m3fn) + fp4_weight_part_2 = fp4_weight_part_2.to(weight.dtype) * (2**6) + + weight_ref = torch.cat( + [fp4_weight_part_2.unsqueeze(2), fp4_weight_part_1.unsqueeze(2)], 2 + ).view(size_n, size_k) + weight_ref = ( + weight_ref + * global_scale.to(weight.dtype) + * scales.repeat_interleave(group_size, 1).to(weight.dtype) + ) + + marlin_qweight = ops.gptq_marlin_repack( + b_q_weight=fp4_weight.view(torch.int32).T.contiguous(), + perm=torch.empty(0, dtype=torch.int, device=device), + size_k=size_k, + size_n=size_n, + num_bits=4, + ) + + marlin_scales = marlin_permute_scales( + s=scales.T.to(weight.dtype), size_k=size_k, size_n=size_n, group_size=group_size + ) + marlin_scales = nvfp4_marlin_process_scales(marlin_scales) + + global_scale = nvfp4_marlin_process_global_scale(global_scale) + + return weight_ref.T, marlin_qweight, marlin_scales, global_scale + + +def rand_marlin_weight_mxfp4_like(weight, group_size): + assert group_size > 0 + size_n, size_k = weight.shape + device = weight.device + + scales = torch.randint( + 100, + 125, + (size_n, size_k // group_size), + dtype=torch.uint8, + device=weight.device, + ) + scales = scales.view(torch.float8_e8m0fnu) + + fp4_weight = torch.randint( + 0, 256, (size_n, size_k // 2), dtype=torch.uint8, device=weight.device + ) + fp4_weight_part_1 = (fp4_weight & 0b10000000) | ((fp4_weight & 0b01110000) >> 2) + fp4_weight_part_1 = fp4_weight_part_1.view(torch.float8_e4m3fn) + fp4_weight_part_1 = fp4_weight_part_1.to(weight.dtype) * (2**6) + + fp4_weight2 = fp4_weight << 4 + fp4_weight_part_2 = (fp4_weight2 & 0b10000000) | ((fp4_weight2 & 0b01110000) >> 2) + fp4_weight_part_2 = fp4_weight_part_2.view(torch.float8_e4m3fn) + fp4_weight_part_2 = fp4_weight_part_2.to(weight.dtype) * (2**6) + + weight_ref = torch.cat( + [fp4_weight_part_2.unsqueeze(2), fp4_weight_part_1.unsqueeze(2)], 2 + ).view(size_n, size_k) + weight_ref = weight_ref * scales.repeat_interleave(group_size, 1).to(weight.dtype) + + marlin_qweight = ops.gptq_marlin_repack( + b_q_weight=fp4_weight.view(torch.int32).T.contiguous(), + perm=torch.empty(0, dtype=torch.int, device=device), + size_k=size_k, + size_n=size_n, + num_bits=4, + ) + + marlin_scales = marlin_permute_scales( + s=scales.T.to(weight.dtype), size_k=size_k, size_n=size_n, group_size=group_size + ) + + marlin_scales = mxfp4_marlin_process_scales(marlin_scales) + + return weight_ref.T, marlin_qweight, marlin_scales.to(torch.float8_e8m0fnu) From b8b4cc614479737431c42f773d9524dbe18942be Mon Sep 17 00:00:00 2001 From: Daniel Campora <961215+dcampora@users.noreply.github.com> Date: Thu, 4 Dec 2025 20:29:03 -0800 Subject: [PATCH 04/16] Adding w4a4 support for compressed tensors. Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com> --- .../compressed_tensors/compressed_tensors.py | 47 ++++- .../compressed_tensors/schemes/__init__.py | 2 + .../schemes/compressed_tensors_w4a4_nvfp4.py | 188 ++++++++++++++++++ .../quantization/compressed_tensors/utils.py | 1 + .../srt/layers/quantization/modelopt_quant.py | 36 +--- .../sglang/srt/layers/quantization/utils.py | 25 +++ python/sglang/srt/utils/common.py | 4 + sgl-kernel/python/sgl_kernel/__init__.py | 1 + sgl-kernel/python/sgl_kernel/gemm.py | 45 +++++ 9 files changed, 312 insertions(+), 37 deletions(-) create mode 100644 python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py index 01893f7406d2..87b6cc010611 100644 --- a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py @@ -35,6 +35,7 @@ CompressedTensorsW8A16Fp8, CompressedTensorsWNA16, CompressedTensorsW4A16Fp4, + CompressedTensorsW4A4Fp4, ) from sglang.srt.layers.quantization.compressed_tensors.utils import ( find_matched_target, @@ -43,6 +44,7 @@ ) from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod +from sglang.srt.utils import cutlass_fp4_supported logger = logging.getLogger(__name__) @@ -377,6 +379,33 @@ def _is_fp8_w8a16(self, weight_quant: BaseModel, input_quant: BaseModel) -> bool # All conditions satisfied. return True + def _is_fp4a4_nvfp4(self, weight_quant: QuantizationArgs, input_quant: QuantizationArgs): + if weight_quant is None or input_quant is None: + return False + + is_tensor_group_quant = ( + weight_quant.strategy == QuantizationStrategy.TENSOR_GROUP.value + and input_quant.strategy == QuantizationStrategy.TENSOR_GROUP.value + ) + is_symmetric = weight_quant.symmetric and input_quant.symmetric + + is_group_size_16 = ( + weight_quant.group_size == 16 and input_quant.group_size == 16 + ) + is_float_type = ( + weight_quant.type == QuantizationType.FLOAT + and input_quant.type == QuantizationType.FLOAT + ) + is_4_bits = weight_quant.num_bits == 4 and input_quant.num_bits == 4 + + return ( + is_tensor_group_quant + and is_float_type + and is_4_bits + and is_group_size_16 + and is_symmetric + ) + def _is_wNa16_group_channel( self, weight_quant: BaseModel, input_quant: BaseModel ) -> bool: @@ -403,13 +432,6 @@ def _is_fp4a16_nvfp4( is_float_type = weight_quant.type == QuantizationType.FLOAT is_4_bits = weight_quant.num_bits == 4 - print("is_weight_only", is_weight_only) - print("is_tensor_group_quant", is_tensor_group_quant) - print("is_float_type", is_float_type) - print("is_4_bits", is_4_bits) - print("is_group_size_16", is_group_size_16) - print("is_symmetric", is_symmetric) - return ( is_weight_only and is_tensor_group_quant @@ -444,6 +466,17 @@ def _get_scheme_from_parts( ) if is_activation_quantization_format(self.quant_format): + if self._is_fp4a4_nvfp4(weight_quant, input_quant): + if cutlass_fp4_supported(): + return CompressedTensorsW4A4Fp4() + else: + logger.warning_once( + "Current platform does not support cutlass NVFP4." + " Running CompressedTensorsW4A16Fp4." + ) + return CompressedTensorsW4A16Fp4(has_input_global_scale=True) + + if self._is_fp8_w8a8(weight_quant, input_quant): is_fp8_w8a8_supported = self._check_scheme_supported( CompressedTensorsW8A8Fp8.get_min_capability(), error=False diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py index 70d6f5889c2b..c147ef45a38c 100644 --- a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +++ b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py @@ -6,6 +6,7 @@ from .compressed_tensors_w8a16_fp8 import CompressedTensorsW8A16Fp8 from .compressed_tensors_wNa16 import WNA16_SUPPORTED_BITS, CompressedTensorsWNA16 from .compressed_tensors_w4a16_nvfp4 import CompressedTensorsW4A16Fp4 +from .compressed_tensors_w4a4_nvfp4 import CompressedTensorsW4A4Fp4 __all__ = [ "CompressedTensorsScheme", @@ -15,4 +16,5 @@ "CompressedTensorsWNA16", "WNA16_SUPPORTED_BITS", "CompressedTensorsW4A16Fp4", + "CompressedTensorsW4A4Fp4", ] diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py new file mode 100644 index 000000000000..efbeed6157c0 --- /dev/null +++ b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py @@ -0,0 +1,188 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Callable + +import torch +from torch.nn.parameter import Parameter + +from sglang.srt.layers.quantization.compressed_tensors.schemes import ( + CompressedTensorsScheme, +) + +from sglang.srt.layers.quantization.marlin_utils_fp4 import ( + apply_fp4_marlin_linear, + prepare_fp4_layer_for_marlin, +) +from sglang.srt.layers.parameter import ( + GroupQuantScaleParameter, + ModelWeightParameter, + PerTensorScaleParameter, +) +from sglang.srt.utils import is_flashinfer_available, cutlass_fp4_supported +import logging +from sglang.srt.layers.quantization.utils import swizzle_blockscale +from sgl_kernel import scaled_fp4_quant, cutlass_scaled_fp4_mm, flashinfer_scaled_fp4_mm + +logger = logging.getLogger(__name__) + +__all__ = ["CompressedTensorsW4A4Fp4"] + + +class CompressedTensorsW4A4Fp4(CompressedTensorsScheme): + def __init__(self): + self.backend = "none" + if is_flashinfer_available(): + self.backend = "flashinfer-cutlass" + elif cutlass_fp4_supported(): + self.backend = "cutlass" + else: + self.backend = "none" + + if self.backend == "none": + raise ValueError( + "No valid NVFP4 GEMM backend found. " + "Please check your platform capability." + ) + + logger.info_once(f"Using {self.backend} for NVFP4 GEMM") + self.group_size = 16 + + @classmethod + def get_min_capability(cls) -> int: + return 100 + + def create_weights( + self, + layer: torch.nn.Module, + output_partition_sizes: list[int], + input_size_per_partition: int, + params_dtype: torch.dtype, + weight_loader: Callable, + **kwargs, + ): + output_size_per_partition = sum(output_partition_sizes) + layer.logical_widths = output_partition_sizes + layer.input_size_per_partition = input_size_per_partition + layer.output_size_per_partition = output_size_per_partition + + # Weight + weight = ModelWeightParameter( + data=torch.empty( + sum(output_partition_sizes), + input_size_per_partition // 2, + dtype=torch.uint8, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + layer.register_parameter("weight_packed", weight) + + # Global Weight Scale + weight_global_scale = PerTensorScaleParameter( + data=torch.empty(len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader, + ) + layer.register_parameter("weight_global_scale", weight_global_scale) + + # Per Group Weight Scale + weight_scale = GroupQuantScaleParameter( + data=torch.empty( + sum(output_partition_sizes), + input_size_per_partition // self.group_size, + dtype=torch.float8_e4m3fn, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + + layer.register_parameter("weight_scale", weight_scale) + + input_global_scale = PerTensorScaleParameter( + data=torch.empty(len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader, + ) + layer.register_parameter("input_global_scale", input_global_scale) + + def process_weights_after_loading(self, layer) -> None: + global_input_scale = layer.input_global_scale.max().to(torch.float32) + layer.input_global_scale = Parameter(global_input_scale, requires_grad=False) + + layer.weight_global_scale = Parameter( + layer.weight_global_scale.max().to(torch.float32), requires_grad=False + ) + + if self.backend == "flashinfer-trtllm": + # FlashInfer TRTLLM FP4 GEMM requires a different weight layout. + # FlashInfer provides nvfp4_quantize to quantize + shuffle the + # layout but we use our own quantization so we have to call + # shuffles ourselves. + from flashinfer import shuffle_matrix_a, shuffle_matrix_sf_a + + weight = layer.weight_packed.data + weight_scale = layer.weight_scale.data + + epilogue_tile_m = 128 + weight = shuffle_matrix_a(weight.view(torch.uint8), epilogue_tile_m) + weight_scale = ( + shuffle_matrix_sf_a(weight_scale.view(torch.uint8), epilogue_tile_m) + .reshape(weight_scale.shape) + .view(torch.float8_e4m3fn) + ) + + layer.weight_scale = Parameter(weight_scale, requires_grad=False) + layer.weight_packed = Parameter(weight, requires_grad=False) + else: + swizzled_weight_scale = swizzle_blockscale(layer.weight_scale) + if self.backend == "fbgemm": + swizzled_weight_scale = swizzled_weight_scale.view(-1).view(torch.uint8) + layer.weight_scale = Parameter(swizzled_weight_scale, requires_grad=False) + layer.weight_packed = Parameter( + layer.weight_packed.data, requires_grad=False + ) + + layer.alpha = Parameter( + 1 / (layer.input_global_scale * layer.weight_global_scale), + requires_grad=False, + ) + + def apply_weights( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: torch.Tensor | None = None, + ) -> torch.Tensor: + output_dtype = x.dtype + output_shape = [x.shape[0], layer.weight_packed.shape[0]] + + # quantize BF16 or FP16 to (FP4 and interleaved block scale) + x_fp4, x_blockscale = scaled_fp4_quant(x, layer.input_global_scale) + + mm_args = ( + x_fp4, + layer.weight_packed, + x_blockscale, + layer.weight_scale, + layer.alpha, + output_dtype, + ) + if self.backend.startswith("flashinfer-"): + backend_name = self.backend[len("flashinfer-") :] + out = flashinfer_scaled_fp4_mm(*mm_args, backend=backend_name) + elif self.backend == "fbgemm": + out = torch.ops.fbgemm.f4f4bf16( + x_fp4, + layer.weight_packed, + x_blockscale.view(-1).view(torch.uint8), + layer.weight_scale, + layer.alpha, + use_mx=False, + ).to(output_dtype) + else: + assert self.backend == "cutlass" + out = cutlass_scaled_fp4_mm(*mm_args) + + if bias is not None: + out = out + bias + return out.view(*output_shape) diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/utils.py b/python/sglang/srt/layers/quantization/compressed_tensors/utils.py index ddefa5ea3de5..0ea3c73d2e6b 100644 --- a/python/sglang/srt/layers/quantization/compressed_tensors/utils.py +++ b/python/sglang/srt/layers/quantization/compressed_tensors/utils.py @@ -14,6 +14,7 @@ def is_activation_quantization_format(format: str) -> bool: CompressionFormat.naive_quantized.value, CompressionFormat.int_quantized.value, CompressionFormat.float_quantized.value, + CompressionFormat.nvfp4_pack_quantized.value, ] return format in _ACTIVATION_QUANTIZATION_FORMATS diff --git a/python/sglang/srt/layers/quantization/modelopt_quant.py b/python/sglang/srt/layers/quantization/modelopt_quant.py index 0da0aba82203..506c6815250d 100755 --- a/python/sglang/srt/layers/quantization/modelopt_quant.py +++ b/python/sglang/srt/layers/quantization/modelopt_quant.py @@ -42,6 +42,7 @@ is_layer_skipped, per_tensor_dequantize, requantize_with_max_scale, + swizzle_blockscale, ) from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.utils.common import ( @@ -1333,7 +1334,7 @@ def create_weights( # Only use `swizzle_blockscale` for shapes, not for real content layer.w13_blockscale_swizzled = Parameter( - self.swizzle_blockscale(layer.w13_weight_scale), requires_grad=False + swizzle_blockscale(layer.w13_weight_scale), requires_grad=False ) w2_weight_scale = ModelWeightParameter( @@ -1350,7 +1351,7 @@ def create_weights( layer.register_parameter("w2_weight_scale", w2_weight_scale) layer.w2_blockscale_swizzled = Parameter( - self.swizzle_blockscale(layer.w2_weight_scale), requires_grad=False + swizzle_blockscale(layer.w2_weight_scale), requires_grad=False ) from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported @@ -1395,31 +1396,6 @@ def create_weights( w2_input_scale._sglang_require_global_experts = True layer.register_parameter("w2_input_scale", w2_input_scale) - def swizzle_blockscale(self, scale: torch.Tensor): - assert scale.dtype == torch.float8_e4m3fn - # Pad and blockwise interleave weight_scale - scale_ndim = scale.ndim - if scale.ndim == 2: - scale = scale.unsqueeze(0) - assert scale.ndim == 3 - B, M, K = scale.shape - round_up_multiple = lambda x, m: (x + m - 1) // m * m - M_padded = round_up_multiple(M, 128) - K_padded = round_up_multiple(K, 4) - padded_scale = torch.zeros((B, M_padded, K_padded), dtype=scale.dtype) - padded_scale[:B, :M, :K] = scale - batches, rows, cols = padded_scale.shape - assert rows % 128 == 0 - assert cols % 4 == 0 - padded_scale = padded_scale.reshape(batches, rows // 128, 4, 32, cols // 4, 4) - swizzled_scale = padded_scale.permute((0, 1, 4, 3, 2, 5)) - swizzled_scale = swizzled_scale.contiguous().cuda() - return ( - swizzled_scale.reshape(M_padded, K_padded) - if scale_ndim == 2 - else swizzled_scale.reshape(B, M_padded, K_padded) - ) - def prepare_static_weights_for_kernel( self, # args_dequant, @@ -1694,7 +1670,7 @@ def _slice_scale(w): # CUTLASS processing - handle w13 and w2 separately # Process w13 weights - w13_blockscale_swizzled = self.swizzle_blockscale(layer.w13_weight_scale) + w13_blockscale_swizzled = swizzle_blockscale(layer.w13_weight_scale) del layer.w13_weight_scale layer.w13_blockscale_swizzled.data.copy_(w13_blockscale_swizzled) @@ -1727,13 +1703,13 @@ def _slice_scale(w): requires_grad=False, ) layer.w2_blockscale_swizzled = Parameter( - self.swizzle_blockscale(layer.w2_weight_scale), requires_grad=False + swizzle_blockscale(layer.w2_weight_scale), requires_grad=False ) layer.w13_weight = Parameter(layer.w13_weight.data, requires_grad=False) # Process w2 weights - w2_blockscale_swizzled = self.swizzle_blockscale(layer.w2_weight_scale) + w2_blockscale_swizzled = swizzle_blockscale(layer.w2_weight_scale) del layer.w2_weight_scale layer.w2_blockscale_swizzled.data.copy_(w2_blockscale_swizzled) diff --git a/python/sglang/srt/layers/quantization/utils.py b/python/sglang/srt/layers/quantization/utils.py index fc81f3140660..589331520ee9 100644 --- a/python/sglang/srt/layers/quantization/utils.py +++ b/python/sglang/srt/layers/quantization/utils.py @@ -563,3 +563,28 @@ def sort_weights(q_w: torch.Tensor, g_idx: torch.Tensor): g_idx.to(device=orig_device), sort_indices.to(device=orig_device), ) + +def swizzle_blockscale(scale: torch.Tensor): + assert scale.dtype == torch.float8_e4m3fn + # Pad and blockwise interleave weight_scale + scale_ndim = scale.ndim + if scale.ndim == 2: + scale = scale.unsqueeze(0) + assert scale.ndim == 3 + B, M, K = scale.shape + round_up_multiple = lambda x, m: (x + m - 1) // m * m + M_padded = round_up_multiple(M, 128) + K_padded = round_up_multiple(K, 4) + padded_scale = torch.zeros((B, M_padded, K_padded), dtype=scale.dtype) + padded_scale[:B, :M, :K] = scale + batches, rows, cols = padded_scale.shape + assert rows % 128 == 0 + assert cols % 4 == 0 + padded_scale = padded_scale.reshape(batches, rows // 128, 4, 32, cols // 4, 4) + swizzled_scale = padded_scale.permute((0, 1, 4, 3, 2, 5)) + swizzled_scale = swizzled_scale.contiguous().cuda() + return ( + swizzled_scale.reshape(M_padded, K_padded) + if scale_ndim == 2 + else swizzled_scale.reshape(B, M_padded, K_padded) + ) diff --git a/python/sglang/srt/utils/common.py b/python/sglang/srt/utils/common.py index 08a13f604f81..e34cf8a11ee8 100644 --- a/python/sglang/srt/utils/common.py +++ b/python/sglang/srt/utils/common.py @@ -351,6 +351,10 @@ def is_flashinfer_available(): return importlib.util.find_spec("flashinfer") is not None and is_cuda() +def cutlass_fp4_supported() -> bool: + return torch.cuda.is_available() and torch.cuda.get_device_capability() >= (10, 0) and torch.version.cuda >= "12.8" + + def is_nvidia_cublas_cu12_version_ge_12_9(): """ temporary fix for issue #11272 diff --git a/sgl-kernel/python/sgl_kernel/__init__.py b/sgl-kernel/python/sgl_kernel/__init__.py index 8e8994e04c95..75364aec3e9e 100644 --- a/sgl-kernel/python/sgl_kernel/__init__.py +++ b/sgl-kernel/python/sgl_kernel/__init__.py @@ -43,6 +43,7 @@ awq_dequantize, bmm_fp8, cutlass_scaled_fp4_mm, + flashinfer_scaled_fp4_mm, dsv3_fused_a_gemm, dsv3_router_gemm, fp8_blockwise_scaled_mm, diff --git a/sgl-kernel/python/sgl_kernel/gemm.py b/sgl-kernel/python/sgl_kernel/gemm.py index 4e23ebdd9fc2..4f47fabd07eb 100644 --- a/sgl-kernel/python/sgl_kernel/gemm.py +++ b/sgl-kernel/python/sgl_kernel/gemm.py @@ -5,6 +5,22 @@ from sgl_kernel.utils import _get_cache_buf +def flashinfer_mm_fp4( + A: torch.Tensor, + B: torch.Tensor, + A_scale: torch.Tensor, + B_scale: torch.Tensor, + g_scale: torch.Tensor, + dtype: torch.dtype, + backend: str, +) -> torch.Tensor: + from flashinfer.gemm import mm_fp4 as flashinfer_mm_fp4_ + + return flashinfer_mm_fp4_( + A, B, A_scale, B_scale, g_scale, dtype, block_size=16, backend=backend + ) + + def awq_dequantize( qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor ) -> torch.ByteTensor: @@ -177,6 +193,35 @@ def cutlass_scaled_fp4_mm( return out +def flashinfer_scaled_fp4_mm( + a: torch.Tensor, + b: torch.Tensor, + block_scale_a: torch.Tensor, + block_scale_b: torch.Tensor, + alpha: torch.Tensor, + out_dtype: torch.dtype, + backend: str, +) -> torch.Tensor: + assert a.ndim == 2 and b.ndim == 2 + assert block_scale_a.ndim == 2 and block_scale_b.ndim == 2 + assert a.stride(-1) == 1 and b.stride(-1) == 1 + assert a.shape[1] == b.shape[1] + + if backend == "cutlass": + block_scale_a = block_scale_a.view(torch.uint8) + block_scale_b = block_scale_b.view(torch.uint8) + + return flashinfer_mm_fp4( + a, + b.t(), + block_scale_a, + block_scale_b.t(), + alpha, + out_dtype, + backend=backend, + ) + + def scaled_fp4_quant( input: torch.Tensor, input_global_scale: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: From c54fc5282312b1d25ef880ea8bb83325bdada2c3 Mon Sep 17 00:00:00 2001 From: Daniel Campora <961215+dcampora@users.noreply.github.com> Date: Thu, 4 Dec 2025 20:36:44 -0800 Subject: [PATCH 05/16] Do not change sgl kernel. Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com> --- .../srt/layers/attention/flashinfer_ops.py | 68 +++++++++++++++++++ .../schemes/compressed_tensors_w4a4_nvfp4.py | 3 +- sgl-kernel/python/sgl_kernel/__init__.py | 1 - sgl-kernel/python/sgl_kernel/gemm.py | 45 ------------ 4 files changed, 70 insertions(+), 47 deletions(-) create mode 100644 python/sglang/srt/layers/attention/flashinfer_ops.py diff --git a/python/sglang/srt/layers/attention/flashinfer_ops.py b/python/sglang/srt/layers/attention/flashinfer_ops.py new file mode 100644 index 000000000000..0cd7f1684604 --- /dev/null +++ b/python/sglang/srt/layers/attention/flashinfer_ops.py @@ -0,0 +1,68 @@ +import torch + +from sglang.srt.utils import is_flashinfer_available + +if is_flashinfer_available(): + + @torch.library.custom_op( + "sglang::flashinfer_mm_fp4", + mutates_args=[], + device_types="cuda", + ) + def flashinfer_mm_fp4( + A: torch.Tensor, + B: torch.Tensor, + A_scale: torch.Tensor, + B_scale: torch.Tensor, + g_scale: torch.Tensor, + dtype: torch.dtype, + backend: str, + ) -> torch.Tensor: + from flashinfer.gemm import mm_fp4 as flashinfer_mm_fp4_ + + return flashinfer_mm_fp4_( + A, B, A_scale, B_scale, g_scale, dtype, block_size=16, backend=backend + ) + + @torch.library.register_fake( + "sglang::flashinfer_mm_fp4", + ) + def flashinfer_mm_fp4_fake( + A: torch.Tensor, + B: torch.Tensor, + A_scale: torch.Tensor, + B_scale: torch.Tensor, + g_scale: torch.Tensor, + dtype: torch.dtype, + backend: str, + ) -> torch.Tensor: + return torch.empty(A.shape[0], B.shape[1], dtype=dtype, device=A.device) + + +def flashinfer_scaled_fp4_mm( + a: torch.Tensor, + b: torch.Tensor, + block_scale_a: torch.Tensor, + block_scale_b: torch.Tensor, + alpha: torch.Tensor, + out_dtype: torch.dtype, + backend: str, +) -> torch.Tensor: + assert a.ndim == 2 and b.ndim == 2 + assert block_scale_a.ndim == 2 and block_scale_b.ndim == 2 + assert a.stride(-1) == 1 and b.stride(-1) == 1 + assert a.shape[1] == b.shape[1] + + if backend == "cutlass": + block_scale_a = block_scale_a.view(torch.uint8) + block_scale_b = block_scale_b.view(torch.uint8) + + return flashinfer_mm_fp4( + a, + b.t(), + block_scale_a, + block_scale_b.t(), + alpha, + out_dtype, + backend=backend, + ) diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py index efbeed6157c0..621553506e68 100644 --- a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py +++ b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py @@ -21,7 +21,8 @@ from sglang.srt.utils import is_flashinfer_available, cutlass_fp4_supported import logging from sglang.srt.layers.quantization.utils import swizzle_blockscale -from sgl_kernel import scaled_fp4_quant, cutlass_scaled_fp4_mm, flashinfer_scaled_fp4_mm +from sgl_kernel import scaled_fp4_quant, cutlass_scaled_fp4_mm +from sglang.srt.layers.attention.flashinfer_ops import flashinfer_scaled_fp4_mm logger = logging.getLogger(__name__) diff --git a/sgl-kernel/python/sgl_kernel/__init__.py b/sgl-kernel/python/sgl_kernel/__init__.py index 75364aec3e9e..8e8994e04c95 100644 --- a/sgl-kernel/python/sgl_kernel/__init__.py +++ b/sgl-kernel/python/sgl_kernel/__init__.py @@ -43,7 +43,6 @@ awq_dequantize, bmm_fp8, cutlass_scaled_fp4_mm, - flashinfer_scaled_fp4_mm, dsv3_fused_a_gemm, dsv3_router_gemm, fp8_blockwise_scaled_mm, diff --git a/sgl-kernel/python/sgl_kernel/gemm.py b/sgl-kernel/python/sgl_kernel/gemm.py index 4f47fabd07eb..4e23ebdd9fc2 100644 --- a/sgl-kernel/python/sgl_kernel/gemm.py +++ b/sgl-kernel/python/sgl_kernel/gemm.py @@ -5,22 +5,6 @@ from sgl_kernel.utils import _get_cache_buf -def flashinfer_mm_fp4( - A: torch.Tensor, - B: torch.Tensor, - A_scale: torch.Tensor, - B_scale: torch.Tensor, - g_scale: torch.Tensor, - dtype: torch.dtype, - backend: str, -) -> torch.Tensor: - from flashinfer.gemm import mm_fp4 as flashinfer_mm_fp4_ - - return flashinfer_mm_fp4_( - A, B, A_scale, B_scale, g_scale, dtype, block_size=16, backend=backend - ) - - def awq_dequantize( qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor ) -> torch.ByteTensor: @@ -193,35 +177,6 @@ def cutlass_scaled_fp4_mm( return out -def flashinfer_scaled_fp4_mm( - a: torch.Tensor, - b: torch.Tensor, - block_scale_a: torch.Tensor, - block_scale_b: torch.Tensor, - alpha: torch.Tensor, - out_dtype: torch.dtype, - backend: str, -) -> torch.Tensor: - assert a.ndim == 2 and b.ndim == 2 - assert block_scale_a.ndim == 2 and block_scale_b.ndim == 2 - assert a.stride(-1) == 1 and b.stride(-1) == 1 - assert a.shape[1] == b.shape[1] - - if backend == "cutlass": - block_scale_a = block_scale_a.view(torch.uint8) - block_scale_b = block_scale_b.view(torch.uint8) - - return flashinfer_mm_fp4( - a, - b.t(), - block_scale_a, - block_scale_b.t(), - alpha, - out_dtype, - backend=backend, - ) - - def scaled_fp4_quant( input: torch.Tensor, input_global_scale: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: From ab9fe7a9153564662e5e6b42a29e1e89732e3e6b Mon Sep 17 00:00:00 2001 From: elvischenv <219235043+elvischenv@users.noreply.github.com> Date: Thu, 4 Dec 2025 21:33:57 -0800 Subject: [PATCH 06/16] add compressed tensors w4a4 nvfp4 moe support --- .../compressed_tensors_moe.py | 245 +++++++++++++++++- python/sglang/srt/models/deepseek_v2.py | 4 + 2 files changed, 248 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 8d0d475ac1f2..a36127e7d52c 100644 --- a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -13,19 +13,24 @@ from sglang.srt.distributed import get_tensor_model_parallel_world_size from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig +from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams, CutlassMoEType from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo from sglang.srt.layers.quantization.base_config import FusedMoEMethodBase from sglang.srt.layers.quantization.compressed_tensors.schemes import ( WNA16_SUPPORTED_BITS, ) from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz, scaled_fp8_quant -from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz +from sglang.srt.layers.quantization.fp8_utils import ( + is_blackwell_supported, + normalize_e4m3fn_to_e4m3fnuz, +) from sglang.srt.layers.quantization.gptq import gptq_marlin_moe_repack from sglang.srt.layers.quantization.marlin_utils import marlin_moe_permute_scales from sglang.srt.layers.quantization.utils import ( all_close_1d, per_tensor_dequantize, replace_parameter, + swizzle_blockscale, ) from sglang.srt.utils import get_bool_env_var, is_cuda, is_hip, set_weight_attrs @@ -60,6 +65,7 @@ class GPTQMarlinState(Enum): __all__ = [ "CompressedTensorsMoEMethod", + "CompressedTensorsW4A4Nvfp4MoEMethod", "CompressedTensorsW8A8Fp8MoEMethod", "CompressedTensorsWNA16MoEMethod", ] @@ -86,7 +92,11 @@ def get_moe_method( if quant_config._is_wNa16_group_channel(weight_quant, input_quant): logger.info_once("Using CompressedTensorsWNA16MarlinMoEMethod") return CompressedTensorsWNA16MoEMethod(quant_config) + elif quant_config._is_fp4a4_nvfp4(weight_quant, input_quant): + logger.info_once("Using CompressedTensorsW4A4Nvfp4MoEMethod") + return CompressedTensorsW4A4Nvfp4MoEMethod(quant_config) elif quant_config._is_fp8_w8a8(weight_quant, input_quant): + logger.info_once("Using CompressedTensorsW8A8Fp8MoEMethod") return CompressedTensorsW8A8Fp8MoEMethod(quant_config) else: raise RuntimeError( @@ -94,6 +104,239 @@ def get_moe_method( ) +class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod): + + def __init__(self, quant_config: CompressedTensorsConfig): + if not is_blackwell_supported(): + raise ValueError( + "Current platform does not support NVFP4" + " quantization. Please use Blackwell and" + " above." + ) + self.quant_config = quant_config + self.group_size = 16 + + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported + + layer.num_experts = num_experts + layer.params_dtype = params_dtype + + w13_weight = torch.nn.Parameter( + torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + # 2 fp4 items are packed in the input dimension + hidden_size // 2, + requires_grad=False, + dtype=torch.uint8, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight_packed", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + w2_weight = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + # 2 fp4 items are packed in the input dimension + intermediate_size_per_partition // 2, + dtype=torch.uint8, + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight_packed", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + # Weight Scales + w13_weight_scale = torch.nn.Parameter( + torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + # 2 fp4 items are packed in the input dimension + hidden_size // self.group_size, + dtype=torch.float8_e4m3fn, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.GROUP.value} + ) + set_weight_attrs(w13_weight_scale, extra_weight_attrs) + + w2_weight_scale = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + # 2 fp4 items are packed in the input dimension + intermediate_size_per_partition // self.group_size, + dtype=torch.float8_e4m3fn, + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.GROUP.value} + ) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) + + # Weight Global Scales + w13_weight_scale_2 = torch.nn.Parameter( + torch.empty(num_experts, 2, dtype=torch.float32), requires_grad=False + ) + layer.register_parameter("w13_weight_global_scale", w13_weight_scale_2) + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value} + ) + set_weight_attrs(w13_weight_scale_2, extra_weight_attrs) + + w2_weight_scale_2 = torch.nn.Parameter( + torch.empty(num_experts, dtype=torch.float32), requires_grad=False + ) + layer.register_parameter("w2_weight_global_scale", w2_weight_scale_2) + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value} + ) + set_weight_attrs(w2_weight_scale_2, extra_weight_attrs) + + # Input Global Scales + w13_input_scale = torch.nn.Parameter( + torch.empty(num_experts, 2, dtype=torch.float32), requires_grad=False + ) + layer.register_parameter("w13_input_global_scale", w13_input_scale) + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value} + ) + set_weight_attrs(w13_input_scale, extra_weight_attrs) + + w2_input_scale = torch.nn.Parameter( + torch.empty(num_experts, dtype=torch.float32), requires_grad=False + ) + layer.register_parameter("w2_input_global_scale", w2_input_scale) + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value} + ) + set_weight_attrs(w2_input_scale, extra_weight_attrs) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + # From packed to weight + layer.w13_weight = torch.nn.Parameter( + layer.w13_weight_packed.data, requires_grad=False + ) + delattr(layer, "w13_weight_packed") + + layer.w2_weight = torch.nn.Parameter( + layer.w2_weight_packed.data, requires_grad=False + ) + delattr(layer, "w2_weight_packed") + + if not torch.allclose( + layer.w13_weight_global_scale[:, 0], layer.w13_weight_global_scale[:, 1] + ): + logger.warning_once( + "w1_weight_global_scale must match w3_weight_global_scale. " + "Accuracy may be affected." + ) + + # Take inverse of global scale saved to disk + layer.w13_weight_scale_2 = torch.nn.Parameter( + 1 / layer.w13_weight_global_scale[:, 0], requires_grad=False + ) + + layer.w2_weight_scale_2 = torch.nn.Parameter( + 1 / layer.w2_weight_global_scale.data, requires_grad=False + ) + + # w13 + w13_input_global_scale = layer.w13_input_global_scale.min(dim=1).values.to( + torch.float32 + ) + layer.g1_alphas = torch.nn.Parameter( + ((1 / w13_input_global_scale) * layer.w13_weight_scale_2), + requires_grad=False, + ) + + layer.w13_input_scale_quant = torch.nn.Parameter( + (w13_input_global_scale), requires_grad=False + ) + + # w2 + w2_input_global_scale = layer.w2_input_global_scale + + layer.g2_alphas = torch.nn.Parameter( + ((1 / w2_input_global_scale) * layer.w2_weight_scale_2).to(torch.float32), + requires_grad=False, + ) + + layer.w2_input_scale_quant = torch.nn.Parameter( + (w2_input_global_scale), requires_grad=False + ) + + # swizzle weight scales + layer.w13_weight_scale = torch.nn.Parameter( + swizzle_blockscale(layer.w13_weight_scale), requires_grad=False + ) + + layer.w2_weight_scale = torch.nn.Parameter( + swizzle_blockscale(layer.w2_weight_scale), requires_grad=False + ) + + layer.cutlass_moe_params = CutlassMoEParams( + CutlassMoEType.BlockscaledFP4, + layer.w13_weight.device, + num_experts=layer.num_experts, + intermediate_size_per_partition=layer.w2_weight.shape[2] * 2, + hidden_size=layer.w13_weight.shape[2] * 2, + ) + + def create_moe_runner( + self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig + ): + self.moe_runner_config = moe_runner_config + self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config) + + def apply( + self, + layer: torch.nn.Module, + dispatch_output: StandardDispatchOutput, + ) -> CombineInput: + + from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4 + from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput + + x = dispatch_output.hidden_states + topk_output = dispatch_output.topk_output + topk_weights, topk_ids = topk_output.topk_weights, topk_output.topk_ids + + output = cutlass_moe_fp4( + a=x, + a1_gscale=layer.w13_input_scale_quant, + w1_fp4=layer.w13_weight, + w1_blockscale=layer.w13_weight_scale, + w1_alphas=layer.g1_alphas, + a2_gscale=layer.w2_input_scale_quant, + w2_fp4=layer.w2_weight, + w2_blockscale=layer.w2_weight_scale, + w2_alphas=layer.g2_alphas, + topk_weights=topk_weights, + topk_ids=topk_ids, + params=layer.cutlass_moe_params, + apply_router_weight_on_input=self.moe_runner_config.apply_router_weight_on_input, + ).to(x.dtype) + + return StandardCombineInput(hidden_states=output) + + class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): def __init__(self, quant_config: CompressedTensorsConfig): diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index c74ba98cedde..dcc326d80043 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -482,6 +482,10 @@ def __init__( tp_rank=tp_rank, tp_size=tp_size, ) + if not hasattr(self.gate_up_proj, "weight"): + self.gate_up_proj.weight = getattr(self.gate_up_proj, "weight_packed") + if not hasattr(self.down_proj, "weight"): + self.down_proj.weight = getattr(self.down_proj, "weight_packed") if hidden_act != "silu": raise ValueError( f"Unsupported activation: {hidden_act}. " From fdf7a4e76ba9e8d9703a13b9408d913b32b60c95 Mon Sep 17 00:00:00 2001 From: Xinyuan Tong Date: Fri, 5 Dec 2025 15:49:26 +0000 Subject: [PATCH 07/16] lint Signed-off-by: Xinyuan Tong --- .../compressed_tensors/compressed_tensors.py | 11 ++++++----- .../compressed_tensors/schemes/__init__.py | 4 ++-- .../schemes/compressed_tensors_w4a16_nvfp4.py | 11 +++++------ .../schemes/compressed_tensors_w4a4_nvfp4.py | 19 +++++++------------ .../layers/quantization/marlin_utils_fp4.py | 1 - .../sglang/srt/layers/quantization/utils.py | 1 + python/sglang/srt/utils/common.py | 6 +++++- 7 files changed, 26 insertions(+), 27 deletions(-) diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py index 87b6cc010611..d985737d3781 100644 --- a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py @@ -30,12 +30,12 @@ from sglang.srt.layers.quantization.compressed_tensors.schemes import ( WNA16_SUPPORTED_BITS, CompressedTensorsScheme, + CompressedTensorsW4A4Fp4, + CompressedTensorsW4A16Fp4, CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8, CompressedTensorsW8A16Fp8, CompressedTensorsWNA16, - CompressedTensorsW4A16Fp4, - CompressedTensorsW4A4Fp4, ) from sglang.srt.layers.quantization.compressed_tensors.utils import ( find_matched_target, @@ -379,7 +379,9 @@ def _is_fp8_w8a16(self, weight_quant: BaseModel, input_quant: BaseModel) -> bool # All conditions satisfied. return True - def _is_fp4a4_nvfp4(self, weight_quant: QuantizationArgs, input_quant: QuantizationArgs): + def _is_fp4a4_nvfp4( + self, weight_quant: QuantizationArgs, input_quant: QuantizationArgs + ): if weight_quant is None or input_quant is None: return False @@ -418,7 +420,7 @@ def _is_wNa16_group_channel( is_static = not weight_quant.dynamic return is_channel_group and input_quant_none and is_symmetric and is_static - + def _is_fp4a16_nvfp4( self, weight_quant: QuantizationArgs, input_quant: QuantizationArgs ): @@ -476,7 +478,6 @@ def _get_scheme_from_parts( ) return CompressedTensorsW4A16Fp4(has_input_global_scale=True) - if self._is_fp8_w8a8(weight_quant, input_quant): is_fp8_w8a8_supported = self._check_scheme_supported( CompressedTensorsW8A8Fp8.get_min_capability(), error=False diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py index c147ef45a38c..0da6d20434aa 100644 --- a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +++ b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py @@ -1,12 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 from .compressed_tensors_scheme import CompressedTensorsScheme +from .compressed_tensors_w4a4_nvfp4 import CompressedTensorsW4A4Fp4 +from .compressed_tensors_w4a16_nvfp4 import CompressedTensorsW4A16Fp4 from .compressed_tensors_w8a8_fp8 import CompressedTensorsW8A8Fp8 from .compressed_tensors_w8a8_int8 import CompressedTensorsW8A8Int8 from .compressed_tensors_w8a16_fp8 import CompressedTensorsW8A16Fp8 from .compressed_tensors_wNa16 import WNA16_SUPPORTED_BITS, CompressedTensorsWNA16 -from .compressed_tensors_w4a16_nvfp4 import CompressedTensorsW4A16Fp4 -from .compressed_tensors_w4a4_nvfp4 import CompressedTensorsW4A4Fp4 __all__ = [ "CompressedTensorsScheme", diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_nvfp4.py b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_nvfp4.py index 54e16ac1fbe8..42e4b32e1dba 100644 --- a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_nvfp4.py +++ b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_nvfp4.py @@ -5,19 +5,18 @@ import torch from torch.nn.parameter import Parameter +from sglang.srt.layers.parameter import ( + GroupQuantScaleParameter, + ModelWeightParameter, + PerTensorScaleParameter, +) from sglang.srt.layers.quantization.compressed_tensors.schemes import ( CompressedTensorsScheme, ) - from sglang.srt.layers.quantization.marlin_utils_fp4 import ( apply_fp4_marlin_linear, prepare_fp4_layer_for_marlin, ) -from sglang.srt.layers.parameter import ( - GroupQuantScaleParameter, - ModelWeightParameter, - PerTensorScaleParameter, -) __all__ = ["CompressedTensorsW4A16Fp4"] diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py index 621553506e68..219af5fb0d5f 100644 --- a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py +++ b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py @@ -1,28 +1,23 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import logging from collections.abc import Callable import torch +from sgl_kernel import cutlass_scaled_fp4_mm, scaled_fp4_quant from torch.nn.parameter import Parameter -from sglang.srt.layers.quantization.compressed_tensors.schemes import ( - CompressedTensorsScheme, -) - -from sglang.srt.layers.quantization.marlin_utils_fp4 import ( - apply_fp4_marlin_linear, - prepare_fp4_layer_for_marlin, -) +from sglang.srt.layers.attention.flashinfer_ops import flashinfer_scaled_fp4_mm from sglang.srt.layers.parameter import ( GroupQuantScaleParameter, ModelWeightParameter, PerTensorScaleParameter, ) -from sglang.srt.utils import is_flashinfer_available, cutlass_fp4_supported -import logging +from sglang.srt.layers.quantization.compressed_tensors.schemes import ( + CompressedTensorsScheme, +) from sglang.srt.layers.quantization.utils import swizzle_blockscale -from sgl_kernel import scaled_fp4_quant, cutlass_scaled_fp4_mm -from sglang.srt.layers.attention.flashinfer_ops import flashinfer_scaled_fp4_mm +from sglang.srt.utils import cutlass_fp4_supported, is_flashinfer_available logger = logging.getLogger(__name__) diff --git a/python/sglang/srt/layers/quantization/marlin_utils_fp4.py b/python/sglang/srt/layers/quantization/marlin_utils_fp4.py index 44fe090ce93b..ccd11265d939 100644 --- a/python/sglang/srt/layers/quantization/marlin_utils_fp4.py +++ b/python/sglang/srt/layers/quantization/marlin_utils_fp4.py @@ -3,7 +3,6 @@ import logging -from typing import Optional import torch diff --git a/python/sglang/srt/layers/quantization/utils.py b/python/sglang/srt/layers/quantization/utils.py index 589331520ee9..aab350bcffe4 100644 --- a/python/sglang/srt/layers/quantization/utils.py +++ b/python/sglang/srt/layers/quantization/utils.py @@ -564,6 +564,7 @@ def sort_weights(q_w: torch.Tensor, g_idx: torch.Tensor): sort_indices.to(device=orig_device), ) + def swizzle_blockscale(scale: torch.Tensor): assert scale.dtype == torch.float8_e4m3fn # Pad and blockwise interleave weight_scale diff --git a/python/sglang/srt/utils/common.py b/python/sglang/srt/utils/common.py index e34cf8a11ee8..a903f0f1632e 100644 --- a/python/sglang/srt/utils/common.py +++ b/python/sglang/srt/utils/common.py @@ -352,7 +352,11 @@ def is_flashinfer_available(): def cutlass_fp4_supported() -> bool: - return torch.cuda.is_available() and torch.cuda.get_device_capability() >= (10, 0) and torch.version.cuda >= "12.8" + return ( + torch.cuda.is_available() + and torch.cuda.get_device_capability() >= (10, 0) + and torch.version.cuda >= "12.8" + ) def is_nvidia_cublas_cu12_version_ge_12_9(): From 6b67c6c8ecdf76f453bf8d85a42f99cdf625a7b4 Mon Sep 17 00:00:00 2001 From: Xinyuan Tong Date: Fri, 5 Dec 2025 16:42:29 +0000 Subject: [PATCH 08/16] fix marlin undefined name Signed-off-by: Xinyuan Tong --- python/sglang/srt/layers/quantization/marlin_utils_fp4.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/layers/quantization/marlin_utils_fp4.py b/python/sglang/srt/layers/quantization/marlin_utils_fp4.py index ccd11265d939..eba307acf634 100644 --- a/python/sglang/srt/layers/quantization/marlin_utils_fp4.py +++ b/python/sglang/srt/layers/quantization/marlin_utils_fp4.py @@ -224,7 +224,7 @@ def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None: # WORKSPACE device = layer.w13_weight.device param_dtype = layer.params_dtype - layer.workspace = marlin_make_workspace_new(device, 4) + layer.workspace = marlin_make_workspace(device, 4) perm = torch.empty(0, dtype=torch.int, device=device) # WEIGHT @@ -242,7 +242,7 @@ def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None: for i in range(e): qweight = weight[i].view(torch.int32).T.contiguous() - marlin_qweight = ops.gptq_marlin_repack( + marlin_qweight = gptq_marlin_repack( b_q_weight=qweight, perm=perm, size_k=size_k, size_n=size_n, num_bits=4 ) tensor_list.append(marlin_qweight) @@ -337,7 +337,7 @@ def rand_marlin_weight_nvfp4_like(weight, group_size): * scales.repeat_interleave(group_size, 1).to(weight.dtype) ) - marlin_qweight = ops.gptq_marlin_repack( + marlin_qweight = gptq_marlin_repack( b_q_weight=fp4_weight.view(torch.int32).T.contiguous(), perm=torch.empty(0, dtype=torch.int, device=device), size_k=size_k, @@ -386,7 +386,7 @@ def rand_marlin_weight_mxfp4_like(weight, group_size): ).view(size_n, size_k) weight_ref = weight_ref * scales.repeat_interleave(group_size, 1).to(weight.dtype) - marlin_qweight = ops.gptq_marlin_repack( + marlin_qweight = gptq_marlin_repack( b_q_weight=fp4_weight.view(torch.int32).T.contiguous(), perm=torch.empty(0, dtype=torch.int, device=device), size_k=size_k, From 3ae63da49c5b09a7099fb3df62e5c3dd31034f1e Mon Sep 17 00:00:00 2001 From: elvischenv <219235043+elvischenv@users.noreply.github.com> Date: Mon, 8 Dec 2025 18:11:05 -0800 Subject: [PATCH 09/16] clean up --- .../srt/layers/attention/flashinfer_ops.py | 68 ------------------ .../compressed_tensors/compressed_tensors.py | 12 ++-- .../schemes/compressed_tensors_w4a4_nvfp4.py | 70 +++++++------------ .../sglang/srt/layers/quantization/utils.py | 3 + python/sglang/srt/utils/common.py | 8 --- 5 files changed, 36 insertions(+), 125 deletions(-) delete mode 100644 python/sglang/srt/layers/attention/flashinfer_ops.py diff --git a/python/sglang/srt/layers/attention/flashinfer_ops.py b/python/sglang/srt/layers/attention/flashinfer_ops.py deleted file mode 100644 index 0cd7f1684604..000000000000 --- a/python/sglang/srt/layers/attention/flashinfer_ops.py +++ /dev/null @@ -1,68 +0,0 @@ -import torch - -from sglang.srt.utils import is_flashinfer_available - -if is_flashinfer_available(): - - @torch.library.custom_op( - "sglang::flashinfer_mm_fp4", - mutates_args=[], - device_types="cuda", - ) - def flashinfer_mm_fp4( - A: torch.Tensor, - B: torch.Tensor, - A_scale: torch.Tensor, - B_scale: torch.Tensor, - g_scale: torch.Tensor, - dtype: torch.dtype, - backend: str, - ) -> torch.Tensor: - from flashinfer.gemm import mm_fp4 as flashinfer_mm_fp4_ - - return flashinfer_mm_fp4_( - A, B, A_scale, B_scale, g_scale, dtype, block_size=16, backend=backend - ) - - @torch.library.register_fake( - "sglang::flashinfer_mm_fp4", - ) - def flashinfer_mm_fp4_fake( - A: torch.Tensor, - B: torch.Tensor, - A_scale: torch.Tensor, - B_scale: torch.Tensor, - g_scale: torch.Tensor, - dtype: torch.dtype, - backend: str, - ) -> torch.Tensor: - return torch.empty(A.shape[0], B.shape[1], dtype=dtype, device=A.device) - - -def flashinfer_scaled_fp4_mm( - a: torch.Tensor, - b: torch.Tensor, - block_scale_a: torch.Tensor, - block_scale_b: torch.Tensor, - alpha: torch.Tensor, - out_dtype: torch.dtype, - backend: str, -) -> torch.Tensor: - assert a.ndim == 2 and b.ndim == 2 - assert block_scale_a.ndim == 2 and block_scale_b.ndim == 2 - assert a.stride(-1) == 1 and b.stride(-1) == 1 - assert a.shape[1] == b.shape[1] - - if backend == "cutlass": - block_scale_a = block_scale_a.view(torch.uint8) - block_scale_b = block_scale_b.view(torch.uint8) - - return flashinfer_mm_fp4( - a, - b.t(), - block_scale_a, - block_scale_b.t(), - alpha, - out_dtype, - backend=backend, - ) diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py index d985737d3781..b3d80c9df8d4 100644 --- a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py @@ -44,7 +44,6 @@ ) from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod -from sglang.srt.utils import cutlass_fp4_supported logger = logging.getLogger(__name__) @@ -469,14 +468,15 @@ def _get_scheme_from_parts( if is_activation_quantization_format(self.quant_format): if self._is_fp4a4_nvfp4(weight_quant, input_quant): - if cutlass_fp4_supported(): + is_fp4a4_nvfp4_supported = self._check_scheme_supported( + CompressedTensorsW4A4Fp4.get_min_capability(), error=False + ) + if is_fp4a4_nvfp4_supported: return CompressedTensorsW4A4Fp4() else: - logger.warning_once( - "Current platform does not support cutlass NVFP4." - " Running CompressedTensorsW4A16Fp4." + raise NotImplementedError( + "Current platform does not support w4a4 nvfp4 quantization." ) - return CompressedTensorsW4A16Fp4(has_input_global_scale=True) if self._is_fp8_w8a8(weight_quant, input_quant): is_fp8_w8a8_supported = self._check_scheme_supported( diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py index 219af5fb0d5f..9ad06bb5603c 100644 --- a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py +++ b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py @@ -2,12 +2,11 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import logging from collections.abc import Callable +from typing import Optional import torch -from sgl_kernel import cutlass_scaled_fp4_mm, scaled_fp4_quant from torch.nn.parameter import Parameter -from sglang.srt.layers.attention.flashinfer_ops import flashinfer_scaled_fp4_mm from sglang.srt.layers.parameter import ( GroupQuantScaleParameter, ModelWeightParameter, @@ -16,8 +15,13 @@ from sglang.srt.layers.quantization.compressed_tensors.schemes import ( CompressedTensorsScheme, ) +from sglang.srt.layers.quantization.modelopt_quant import ( + FLASHINFER_FP4_GEMM_BACKEND, + _sglang_fp4_gemm, + enable_flashinfer_fp4_gemm, + fp4_quantize, +) from sglang.srt.layers.quantization.utils import swizzle_blockscale -from sglang.srt.utils import cutlass_fp4_supported, is_flashinfer_available logger = logging.getLogger(__name__) @@ -26,21 +30,6 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme): def __init__(self): - self.backend = "none" - if is_flashinfer_available(): - self.backend = "flashinfer-cutlass" - elif cutlass_fp4_supported(): - self.backend = "cutlass" - else: - self.backend = "none" - - if self.backend == "none": - raise ValueError( - "No valid NVFP4 GEMM backend found. " - "Please check your platform capability." - ) - - logger.info_once(f"Using {self.backend} for NVFP4 GEMM") self.group_size = 16 @classmethod @@ -109,7 +98,7 @@ def process_weights_after_loading(self, layer) -> None: layer.weight_global_scale.max().to(torch.float32), requires_grad=False ) - if self.backend == "flashinfer-trtllm": + if FLASHINFER_FP4_GEMM_BACKEND == "trtllm": # FlashInfer TRTLLM FP4 GEMM requires a different weight layout. # FlashInfer provides nvfp4_quantize to quantize + shuffle the # layout but we use our own quantization so we have to call @@ -131,8 +120,6 @@ def process_weights_after_loading(self, layer) -> None: layer.weight_packed = Parameter(weight, requires_grad=False) else: swizzled_weight_scale = swizzle_blockscale(layer.weight_scale) - if self.backend == "fbgemm": - swizzled_weight_scale = swizzled_weight_scale.view(-1).view(torch.uint8) layer.weight_scale = Parameter(swizzled_weight_scale, requires_grad=False) layer.weight_packed = Parameter( layer.weight_packed.data, requires_grad=False @@ -147,38 +134,35 @@ def apply_weights( self, layer: torch.nn.Module, x: torch.Tensor, - bias: torch.Tensor | None = None, + bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: output_dtype = x.dtype - output_shape = [x.shape[0], layer.weight_packed.shape[0]] + w_n, _ = layer.weight_packed.shape + output_shape = [x.shape[0], w_n] # quantize BF16 or FP16 to (FP4 and interleaved block scale) - x_fp4, x_blockscale = scaled_fp4_quant(x, layer.input_global_scale) + x_fp4, x_blockscale = fp4_quantize(x, layer.input_global_scale) - mm_args = ( + assert x_fp4.dtype == torch.uint8 + assert layer.weight_packed.dtype == torch.uint8 + assert layer.weight_scale.dtype == torch.float8_e4m3fn + assert layer.alpha.dtype == torch.float32 + + w = layer.weight_packed + w_blockscale = layer.weight_scale + if enable_flashinfer_fp4_gemm: + w = layer.weight_packed.T + w_blockscale = layer.weight_scale.T + + out = _sglang_fp4_gemm( x_fp4, - layer.weight_packed, + w, x_blockscale, - layer.weight_scale, + w_blockscale, layer.alpha, output_dtype, + w_n, ) - if self.backend.startswith("flashinfer-"): - backend_name = self.backend[len("flashinfer-") :] - out = flashinfer_scaled_fp4_mm(*mm_args, backend=backend_name) - elif self.backend == "fbgemm": - out = torch.ops.fbgemm.f4f4bf16( - x_fp4, - layer.weight_packed, - x_blockscale.view(-1).view(torch.uint8), - layer.weight_scale, - layer.alpha, - use_mx=False, - ).to(output_dtype) - else: - assert self.backend == "cutlass" - out = cutlass_scaled_fp4_mm(*mm_args) - if bias is not None: out = out + bias return out.view(*output_shape) diff --git a/python/sglang/srt/layers/quantization/utils.py b/python/sglang/srt/layers/quantization/utils.py index aab350bcffe4..a2da44d00b69 100644 --- a/python/sglang/srt/layers/quantization/utils.py +++ b/python/sglang/srt/layers/quantization/utils.py @@ -566,6 +566,9 @@ def sort_weights(q_w: torch.Tensor, g_idx: torch.Tensor): def swizzle_blockscale(scale: torch.Tensor): + """ + Swizzle the scale tensor into a blockwise interleaved format for NVFP4 quantization. + """ assert scale.dtype == torch.float8_e4m3fn # Pad and blockwise interleave weight_scale scale_ndim = scale.ndim diff --git a/python/sglang/srt/utils/common.py b/python/sglang/srt/utils/common.py index 666582d356b4..6fa0b2404ba0 100644 --- a/python/sglang/srt/utils/common.py +++ b/python/sglang/srt/utils/common.py @@ -351,14 +351,6 @@ def is_flashinfer_available(): return importlib.util.find_spec("flashinfer") is not None and is_cuda() -def cutlass_fp4_supported() -> bool: - return ( - torch.cuda.is_available() - and torch.cuda.get_device_capability() >= (10, 0) - and torch.version.cuda >= "12.8" - ) - - def is_nvidia_cublas_cu12_version_ge_12_9(): """ temporary fix for issue #11272 From c82ae89d73334d6d353a1f72f7703bd649a09c34 Mon Sep 17 00:00:00 2001 From: elvischenv <219235043+elvischenv@users.noreply.github.com> Date: Tue, 9 Dec 2025 11:11:40 +0800 Subject: [PATCH 10/16] clean up w4a16 nvfp4 (#5) --- .../compressed_tensors/compressed_tensors.py | 26 -- .../compressed_tensors/schemes/__init__.py | 2 - .../schemes/compressed_tensors_w4a16_nvfp4.py | 124 ------ .../layers/quantization/marlin_utils_fp4.py | 403 ------------------ 4 files changed, 555 deletions(-) delete mode 100644 python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_nvfp4.py delete mode 100644 python/sglang/srt/layers/quantization/marlin_utils_fp4.py diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py index b3d80c9df8d4..933c27c059c3 100644 --- a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py @@ -31,7 +31,6 @@ WNA16_SUPPORTED_BITS, CompressedTensorsScheme, CompressedTensorsW4A4Fp4, - CompressedTensorsW4A16Fp4, CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8, CompressedTensorsW8A16Fp8, @@ -420,35 +419,10 @@ def _is_wNa16_group_channel( return is_channel_group and input_quant_none and is_symmetric and is_static - def _is_fp4a16_nvfp4( - self, weight_quant: QuantizationArgs, input_quant: QuantizationArgs - ): - is_weight_only = weight_quant is not None and input_quant is None - is_tensor_group_quant = ( - weight_quant.strategy == QuantizationStrategy.TENSOR_GROUP.value - ) - is_symmetric = weight_quant.symmetric - - is_group_size_16 = weight_quant.group_size == 16 - is_float_type = weight_quant.type == QuantizationType.FLOAT - is_4_bits = weight_quant.num_bits == 4 - - return ( - is_weight_only - and is_tensor_group_quant - and is_float_type - and is_4_bits - and is_group_size_16 - and is_symmetric - ) - def _get_scheme_from_parts( self, weight_quant: BaseModel, input_quant: BaseModel ) -> CompressedTensorsScheme: - if self._is_fp4a16_nvfp4(weight_quant, input_quant): - return CompressedTensorsW4A16Fp4() - # Detect If Mixed Precision if self._is_wNa16_group_channel(weight_quant, input_quant): if ( diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py index 0da6d20434aa..d5f4146e606f 100644 --- a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +++ b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py @@ -2,7 +2,6 @@ from .compressed_tensors_scheme import CompressedTensorsScheme from .compressed_tensors_w4a4_nvfp4 import CompressedTensorsW4A4Fp4 -from .compressed_tensors_w4a16_nvfp4 import CompressedTensorsW4A16Fp4 from .compressed_tensors_w8a8_fp8 import CompressedTensorsW8A8Fp8 from .compressed_tensors_w8a8_int8 import CompressedTensorsW8A8Int8 from .compressed_tensors_w8a16_fp8 import CompressedTensorsW8A16Fp8 @@ -15,6 +14,5 @@ "CompressedTensorsW8A8Int8", "CompressedTensorsWNA16", "WNA16_SUPPORTED_BITS", - "CompressedTensorsW4A16Fp4", "CompressedTensorsW4A4Fp4", ] diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_nvfp4.py b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_nvfp4.py deleted file mode 100644 index 42e4b32e1dba..000000000000 --- a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_nvfp4.py +++ /dev/null @@ -1,124 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from collections.abc import Callable - -import torch -from torch.nn.parameter import Parameter - -from sglang.srt.layers.parameter import ( - GroupQuantScaleParameter, - ModelWeightParameter, - PerTensorScaleParameter, -) -from sglang.srt.layers.quantization.compressed_tensors.schemes import ( - CompressedTensorsScheme, -) -from sglang.srt.layers.quantization.marlin_utils_fp4 import ( - apply_fp4_marlin_linear, - prepare_fp4_layer_for_marlin, -) - -__all__ = ["CompressedTensorsW4A16Fp4"] - - -class CompressedTensorsW4A16Fp4(CompressedTensorsScheme): - def __init__(self, has_input_global_scale: bool = False): - self.has_input_global_scale = has_input_global_scale - self.group_size = 16 - - @classmethod - def get_min_capability(cls) -> int: - # dont restrict as emulations - return 80 - - def create_weights( - self, - layer: torch.nn.Module, - output_partition_sizes: list[int], - input_size_per_partition: int, - params_dtype: torch.dtype, - weight_loader: Callable, - **kwargs, - ): - output_size_per_partition = sum(output_partition_sizes) - layer.logical_widths = output_partition_sizes - layer.input_size_per_partition = input_size_per_partition - layer.output_size_per_partition = output_size_per_partition - - # Weight - weight = ModelWeightParameter( - data=torch.empty( - sum(output_partition_sizes), - input_size_per_partition // 2, - dtype=torch.uint8, - ), - input_dim=1, - output_dim=0, - weight_loader=weight_loader, - ) - layer.register_parameter("weight_packed", weight) - - # Global Weight Scale - weight_global_scale = PerTensorScaleParameter( - data=torch.empty(len(output_partition_sizes), dtype=torch.float32), - weight_loader=weight_loader, - ) - layer.register_parameter("weight_global_scale", weight_global_scale) - - # Per Group Weight Scale - weight_scale = GroupQuantScaleParameter( - data=torch.empty( - sum(output_partition_sizes), - input_size_per_partition // self.group_size, - dtype=torch.float8_e4m3fn, - ), - input_dim=1, - output_dim=0, - weight_loader=weight_loader, - ) - - layer.register_parameter("weight_scale", weight_scale) - - if self.has_input_global_scale: - input_global_scale = PerTensorScaleParameter( - data=torch.empty(len(output_partition_sizes), dtype=torch.float32), - weight_loader=weight_loader, - ) - layer.register_parameter("input_global_scale", input_global_scale) - - def process_weights_after_loading(self, layer) -> None: - # Process parameters for marlin repacking - - # Rename weight_packed to weight that marlin expects - layer.weight = Parameter(layer.weight_packed.data, requires_grad=False) - del layer.weight_packed - # Rename weight_global_scale to weight_scale_2 that marlin expects - # Note: ct stores the inverse of what is expected by the marlin kernel - layer.weight_scale_2 = Parameter( - 1 / layer.weight_global_scale.max().to(torch.float32), requires_grad=False - ) - del layer.weight_global_scale - - if self.has_input_global_scale: - layer.input_global_scale = torch.nn.Parameter( - layer.input_global_scale.data, requires_grad=False - ) - - prepare_fp4_layer_for_marlin(layer) - - def apply_weights( - self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: torch.Tensor | None = None, - ) -> torch.Tensor: - return apply_fp4_marlin_linear( - input=x, - weight=layer.weight, - weight_scale=layer.weight_scale, - weight_scale_2=layer.weight_scale_2, - workspace=layer.workspace, - size_n=layer.output_size_per_partition, - size_k=layer.input_size_per_partition, - bias=bias, - ) diff --git a/python/sglang/srt/layers/quantization/marlin_utils_fp4.py b/python/sglang/srt/layers/quantization/marlin_utils_fp4.py deleted file mode 100644 index eba307acf634..000000000000 --- a/python/sglang/srt/layers/quantization/marlin_utils_fp4.py +++ /dev/null @@ -1,403 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - - -import logging - -import torch - -from sglang.srt.layers.quantization.marlin_utils import ( - USE_FP32_REDUCE_DEFAULT, - marlin_make_workspace, - marlin_permute_bias, - marlin_permute_scales, - should_use_atomic_add_reduce, -) -from sglang.srt.layers.quantization.utils import get_scalar_types -from sglang.srt.utils import is_cuda - -_is_cuda = is_cuda() -if _is_cuda: - from sgl_kernel import gptq_marlin_gemm, gptq_marlin_repack - -ScalarType, scalar_types = get_scalar_types() - -FP4_MARLIN_SUPPORTED_GROUP_SIZES = [16] - -logger = logging.getLogger(__name__) - - -def is_fp4_marlin_supported(): - return True - - -def nvfp4_marlin_process_scales(marlin_scales): - if not (marlin_scales >= 0).all(): - logger.warning_once( - "NVFP4 Marlin assumes the scales to be >=0, but has encountered " - "negative scales. Accuracy will likely be degraded. This is " - "because it changes the scales from FP8-S1E4M3 to a special " - "FP8-S0E5M3 format to speedup the dequantization." - ) - - # convert to half first, we would convert to fp8 later - marlin_scales = marlin_scales.to(torch.half) - - # 8 is the number of scale number using by one thread - marlin_scales = marlin_scales.view(marlin_scales.size(0) // 2, 2, -1, 8) - marlin_scales = marlin_scales.permute(0, 2, 1, 3).reshape( - marlin_scales.size(0) * 2, -1 - ) - - # fit the layout of fp8 dequantization - marlin_scales = marlin_scales.view(-1, 4)[:, [0, 2, 1, 3]].view( - marlin_scales.size(0), -1 - ) - - # We assume that weight_scale (FP8-S1E4M3) is always greater - # than or equal to 0. So we can convert - # (weight_scale * (2 ** 7) to a special FP8-S0E5M3 format. - # After multiplying by 2 ** 7, the top bit of FP8-S0E5M3 would always be 1 - # when weight_scale > 0. This allows us to have an exponent bias - # closer to zero after dequantization. - - marlin_scales = (marlin_scales * (2**7)).view(torch.int16) << 1 - marlin_scales = marlin_scales.view(torch.float8_e4m3fn) - marlin_scales = marlin_scales[:, 1::2].contiguous() - - return marlin_scales - - -def mxfp4_marlin_process_scales(marlin_scales): - # 8 is the number of scale number using by one thread - marlin_scales = marlin_scales.view(marlin_scales.size(0) // 2, 2, -1, 8) - marlin_scales = marlin_scales.permute(0, 2, 1, 3).reshape( - marlin_scales.size(0) * 2, -1 - ) - - # fit the layout of fp8 dequantization - marlin_scales = marlin_scales.view(-1, 4)[:, [0, 2, 1, 3]].view( - marlin_scales.size(0), -1 - ) - marlin_scales = marlin_scales.to(torch.float8_e8m0fnu) - return marlin_scales - - -def nvfp4_marlin_process_global_scale(global_scale): - assert global_scale.dtype in [torch.half, torch.bfloat16] - fp4_exponent = 2 - if global_scale.dtype == torch.half: - target_exponent = 5 - elif global_scale.dtype == torch.bfloat16: - target_exponent = 8 - # exponent_bias_fp16 = 2 ** 4 - 2 ** 1 = 14 - # exponent_bias_bf16 = 2 ** 7 - 2 ** 1 = 126 - exponent_bias = 2 ** (target_exponent - 1) - 2 ** (fp4_exponent - 1) - return global_scale * (2.0 ** (exponent_bias - 7)) - - -def apply_fp4_marlin_linear( - input: torch.Tensor, - weight: torch.Tensor, - weight_scale: torch.Tensor, - weight_scale_2: torch.Tensor | None, - workspace: torch.Tensor, - size_n: int, - size_k: int, - bias: torch.Tensor | None = None, - use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT, -) -> torch.Tensor: - # For GPUs that lack FP4 hardware support, we can leverage the - # Marlin kernel for fast weight-only FP4 quantization - - reshaped_x = input.reshape(-1, input.shape[-1]) - out_shape = input.shape[:-1] + (size_n,) - - use_atomic_add = should_use_atomic_add_reduce( - m=reshaped_x.size(0), n=size_n, k=size_k, device=input.device, dtype=input.dtype - ) - - output = gptq_marlin_gemm( - a=reshaped_x, - c=None, - b_q_weight=weight, - b_bias=bias, - b_scales=weight_scale, - global_scale=weight_scale_2, - b_zeros=None, - g_idx=None, - perm=None, - workspace=workspace, - b_q_type=scalar_types.float4_e2m1f, - size_m=reshaped_x.size(0), - size_n=size_n, - size_k=size_k, - use_atomic_add=use_atomic_add, - use_fp32_reduce=use_fp32_reduce, - ) - - return output.reshape(out_shape) - - -def prepare_fp4_layer_for_marlin(layer: torch.nn.Module) -> None: - logger.warning_once( - "Your GPU does not have native support for FP4 computation but " - "FP4 quantization is being used. Weight-only FP4 compression will " - "be used leveraging the Marlin kernel. This may degrade " - "performance for compute-heavy workloads." - ) - - is_nvfp4 = hasattr(layer, "weight_scale_2") - group_size = 16 if is_nvfp4 else 32 - - part_size_n = layer.output_size_per_partition - part_size_k = layer.input_size_per_partition - param_dtype = layer.params_dtype - - assert layer.weight.shape == (part_size_n, part_size_k // 2) - - device = layer.weight.device - - # WORKSPACE - layer.workspace = marlin_make_workspace(device) - - # WEIGHT - # Repack weights to marlin format - perm = torch.empty(0, dtype=torch.int, device=device) - qweight = layer.weight.view(torch.int32).T.contiguous() - - marlin_qweight = gptq_marlin_repack( - b_q_weight=qweight, - perm=perm, - size_k=part_size_k, - size_n=part_size_n, - num_bits=4, - ) - layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False) - - # WEIGHT SCALES - # Permute scales - weight_scale = layer.weight_scale.T.contiguous() - - if not is_nvfp4: - weight_scale = weight_scale.view(torch.float8_e8m0fnu) - - weight_scale = weight_scale.to(param_dtype) - weight_scale = marlin_permute_scales( - s=weight_scale, size_k=part_size_k, size_n=part_size_n, group_size=group_size - ) - - if is_nvfp4: - weight_scale = nvfp4_marlin_process_scales(weight_scale) - layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False) - - weight_scale_2 = layer.weight_scale_2.to(param_dtype) - weight_scale_2 = nvfp4_marlin_process_global_scale(weight_scale_2) - layer.weight_scale_2 = torch.nn.Parameter(weight_scale_2, requires_grad=False) - else: - weight_scale = mxfp4_marlin_process_scales(weight_scale) - layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False) - - if hasattr(layer, "bias") and layer.bias is not None: - assert layer.bias.shape == (part_size_n,) - bias = marlin_permute_bias(layer.bias) - layer.bias = torch.nn.Parameter(bias, requires_grad=False) - - return - - -def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None: - logger.warning_once( - "Your GPU does not have native support for FP4 computation but " - "FP4 quantization is being used. Weight-only FP4 compression will " - "be used leveraging the Marlin kernel. This may degrade " - "performance for compute-heavy workloads." - ) - - is_nvfp4 = hasattr(layer, "w13_weight_scale_2") - group_size = 16 if is_nvfp4 else 32 - - e = layer.num_experts - k = layer.hidden_size - n = layer.intermediate_size_per_partition - - # WORKSPACE - device = layer.w13_weight.device - param_dtype = layer.params_dtype - layer.workspace = marlin_make_workspace(device, 4) - perm = torch.empty(0, dtype=torch.int, device=device) - - # WEIGHT - # Repack weights to marlin format - for name in ["w13_weight", "w2_weight"]: - weight = getattr(layer, name) - tensor_list = [] - if "w13" in name: - size_n, size_k = n * 2, k - else: - size_n, size_k = k, n - - assert weight.shape == (e, size_n, size_k // 2) - - for i in range(e): - qweight = weight[i].view(torch.int32).T.contiguous() - - marlin_qweight = gptq_marlin_repack( - b_q_weight=qweight, perm=perm, size_k=size_k, size_n=size_n, num_bits=4 - ) - tensor_list.append(marlin_qweight) - - weight = torch.cat([x.unsqueeze(0) for x in tensor_list], 0) - weight = torch.nn.Parameter(weight, requires_grad=False) - - setattr(layer, name, weight) - - # WEIGHT SCALES - # Permute scales - for name in ["w13", "w2"]: - scales = getattr(layer, name + "_weight_scale") - if not is_nvfp4: - scales = scales.view(torch.float8_e8m0fnu) - scales = scales.to(param_dtype) - if is_nvfp4: - global_scale = getattr(layer, name + "_weight_scale_2").to(param_dtype) - - tensor_list = [] - if "w13" in name: - size_n, size_k = n * 2, k - else: - size_n, size_k = k, n - - for i in range(e): - scale = scales[i].T - - marlin_scales = marlin_permute_scales( - s=scale, size_k=size_k, size_n=size_n, group_size=group_size - ) - if is_nvfp4: - marlin_scales = nvfp4_marlin_process_scales(marlin_scales) - else: - marlin_scales = mxfp4_marlin_process_scales(marlin_scales) - tensor_list.append(marlin_scales) - - scales = torch.cat([x.unsqueeze(0) for x in tensor_list], 0) - scales = torch.nn.Parameter(scales, requires_grad=False) - setattr(layer, name + "_weight_scale", scales) - - if is_nvfp4: - global_scale = nvfp4_marlin_process_global_scale(global_scale) - global_scale = torch.nn.Parameter(global_scale, requires_grad=False) - setattr(layer, name + "_weight_scale_2", global_scale) - - # BIAS - # Permute bias - for name in ["w13_bias", "w2_bias"]: - if not hasattr(layer, name): - continue - bias = getattr(layer, name).to(param_dtype) - - tensor_list = [] - for i in range(e): - expert_bias = bias[i] - - tensor_list.append(marlin_permute_bias(expert_bias)) - - bias = torch.cat([x.unsqueeze(0) for x in tensor_list], 0) - bias = torch.nn.Parameter(bias, requires_grad=False) - setattr(layer, name, bias) - - -def rand_marlin_weight_nvfp4_like(weight, group_size): - assert group_size > 0 - size_n, size_k = weight.shape - device = weight.device - - scales = weight.view(size_n, -1, group_size).abs().max(-1)[0] / 6 - global_scale = scales.max() / 448 - scales = (scales / global_scale).to(torch.float8_e4m3fn) - - fp4_weight = torch.randint( - 0, 256, (size_n, size_k // 2), dtype=torch.uint8, device=weight.device - ) - fp4_weight_part_1 = (fp4_weight & 0b10000000) | ((fp4_weight & 0b01110000) >> 2) - fp4_weight_part_1 = fp4_weight_part_1.view(torch.float8_e4m3fn) - fp4_weight_part_1 = fp4_weight_part_1.to(weight.dtype) * (2**6) - - fp4_weight2 = fp4_weight << 4 - fp4_weight_part_2 = (fp4_weight2 & 0b10000000) | ((fp4_weight2 & 0b01110000) >> 2) - fp4_weight_part_2 = fp4_weight_part_2.view(torch.float8_e4m3fn) - fp4_weight_part_2 = fp4_weight_part_2.to(weight.dtype) * (2**6) - - weight_ref = torch.cat( - [fp4_weight_part_2.unsqueeze(2), fp4_weight_part_1.unsqueeze(2)], 2 - ).view(size_n, size_k) - weight_ref = ( - weight_ref - * global_scale.to(weight.dtype) - * scales.repeat_interleave(group_size, 1).to(weight.dtype) - ) - - marlin_qweight = gptq_marlin_repack( - b_q_weight=fp4_weight.view(torch.int32).T.contiguous(), - perm=torch.empty(0, dtype=torch.int, device=device), - size_k=size_k, - size_n=size_n, - num_bits=4, - ) - - marlin_scales = marlin_permute_scales( - s=scales.T.to(weight.dtype), size_k=size_k, size_n=size_n, group_size=group_size - ) - marlin_scales = nvfp4_marlin_process_scales(marlin_scales) - - global_scale = nvfp4_marlin_process_global_scale(global_scale) - - return weight_ref.T, marlin_qweight, marlin_scales, global_scale - - -def rand_marlin_weight_mxfp4_like(weight, group_size): - assert group_size > 0 - size_n, size_k = weight.shape - device = weight.device - - scales = torch.randint( - 100, - 125, - (size_n, size_k // group_size), - dtype=torch.uint8, - device=weight.device, - ) - scales = scales.view(torch.float8_e8m0fnu) - - fp4_weight = torch.randint( - 0, 256, (size_n, size_k // 2), dtype=torch.uint8, device=weight.device - ) - fp4_weight_part_1 = (fp4_weight & 0b10000000) | ((fp4_weight & 0b01110000) >> 2) - fp4_weight_part_1 = fp4_weight_part_1.view(torch.float8_e4m3fn) - fp4_weight_part_1 = fp4_weight_part_1.to(weight.dtype) * (2**6) - - fp4_weight2 = fp4_weight << 4 - fp4_weight_part_2 = (fp4_weight2 & 0b10000000) | ((fp4_weight2 & 0b01110000) >> 2) - fp4_weight_part_2 = fp4_weight_part_2.view(torch.float8_e4m3fn) - fp4_weight_part_2 = fp4_weight_part_2.to(weight.dtype) * (2**6) - - weight_ref = torch.cat( - [fp4_weight_part_2.unsqueeze(2), fp4_weight_part_1.unsqueeze(2)], 2 - ).view(size_n, size_k) - weight_ref = weight_ref * scales.repeat_interleave(group_size, 1).to(weight.dtype) - - marlin_qweight = gptq_marlin_repack( - b_q_weight=fp4_weight.view(torch.int32).T.contiguous(), - perm=torch.empty(0, dtype=torch.int, device=device), - size_k=size_k, - size_n=size_n, - num_bits=4, - ) - - marlin_scales = marlin_permute_scales( - s=scales.T.to(weight.dtype), size_k=size_k, size_n=size_n, group_size=group_size - ) - - marlin_scales = mxfp4_marlin_process_scales(marlin_scales) - - return weight_ref.T, marlin_qweight, marlin_scales.to(torch.float8_e8m0fnu) From 527e3994af7f2665a71281e6a81861042e7d1d2b Mon Sep 17 00:00:00 2001 From: elvischenv <219235043+elvischenv@users.noreply.github.com> Date: Mon, 8 Dec 2025 23:16:13 -0800 Subject: [PATCH 11/16] fix rope issue --- python/sglang/srt/configs/model_config.py | 22 ++++++++++++++-------- python/sglang/srt/models/pixtral.py | 4 ++++ python/sglang/srt/utils/mistral_utils.py | 4 ++-- 3 files changed, 20 insertions(+), 10 deletions(-) diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index 33938ff481cf..12f9e0a8150d 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -354,14 +354,20 @@ def _derive_model_shapes(self): # Handle rope scaling with yarn self.scaling = 1 / math.sqrt(self.qk_nope_head_dim + self.qk_rope_head_dim) - if self.hf_config.rope_scaling: - mscale_all_dim = self.hf_config.rope_scaling.get( - "mscale_all_dim", False - ) - scaling_factor = self.hf_config.rope_scaling.get("factor") - if scaling_factor is not None: - mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim)) - self.scaling = self.scaling * mscale * mscale + + if ( + "MistralLarge3ForCausalLM" in self.hf_config.architectures + or "PixtralForConditionalGeneration" in self.hf_config.architectures + or "MistralLarge3ForCausalLMEagle" in self.hf_config.architectures + ): + rope_scaling = self.hf_text_config.rope_scaling + else: + rope_scaling = self.hf_config.rope_scaling + if rope_scaling: + mscale_all_dim = rope_scaling.get("mscale_all_dim", False) + scaling_factor = rope_scaling["factor"] + mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim)) + self.scaling = self.scaling * mscale * mscale elif "MiniCPM3ForCausalLM" in self.hf_config.architectures: self.head_dim = 128 diff --git a/python/sglang/srt/models/pixtral.py b/python/sglang/srt/models/pixtral.py index c59770f4578a..55ce2114851a 100644 --- a/python/sglang/srt/models/pixtral.py +++ b/python/sglang/srt/models/pixtral.py @@ -83,6 +83,10 @@ def __init__(self, *, config, prefix: str = "", **kwargs): for key, value in self.config.vision_config.to_dict().items() if key in dataclass_fields } + if "rope_theta" not in vision_args: + vision_args["rope_theta"] = self.config.vision_config.rope_scaling[ + "rope_theta" + ] self.vision_args = VisionEncoderArgs(**vision_args) diff --git a/python/sglang/srt/utils/mistral_utils.py b/python/sglang/srt/utils/mistral_utils.py index ecce3042df02..fd43ed0b882b 100644 --- a/python/sglang/srt/utils/mistral_utils.py +++ b/python/sglang/srt/utils/mistral_utils.py @@ -11,7 +11,7 @@ def adapt_config_dict( config_dict: dict[str, Any], model: str, **kwargs -) -> PretrainedConfig: +) -> tuple[dict, PretrainedConfig]: config_dict.update(kwargs) config_dict = _remap_general_mistral_args(config_dict) @@ -127,7 +127,7 @@ def _remap_mistral_yarn_args(config: dict) -> dict: if old_name in yarn_config: value = yarn_config.pop(old_name) if new_name is not None: - config["rope_scaling"][new_name] = value + config["rope_scaling"][new_name] = float(value) assert len(yarn_config) == 0, f"Unparsed yarn config: {yarn_config}" From 022553dbf0db1a83c8be7cbcda6ac515c3734ae8 Mon Sep 17 00:00:00 2001 From: elvischenv <219235043+elvischenv@users.noreply.github.com> Date: Mon, 8 Dec 2025 23:29:42 -0800 Subject: [PATCH 12/16] add assertion for yarn --- python/sglang/srt/configs/model_config.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index 12f9e0a8150d..fb4fdb6bcbca 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -364,6 +364,7 @@ def _derive_model_shapes(self): else: rope_scaling = self.hf_config.rope_scaling if rope_scaling: + assert rope_scaling["rope_type"] == "yarn" mscale_all_dim = rope_scaling.get("mscale_all_dim", False) scaling_factor = rope_scaling["factor"] mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim)) From bc1167736d1be420bb400fe90e87b38d1e671a0c Mon Sep 17 00:00:00 2001 From: elvischenv <219235043+elvischenv@users.noreply.github.com> Date: Tue, 9 Dec 2025 17:25:41 +0800 Subject: [PATCH 13/16] Remove assertion --- python/sglang/srt/configs/model_config.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index fb4fdb6bcbca..12f9e0a8150d 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -364,7 +364,6 @@ def _derive_model_shapes(self): else: rope_scaling = self.hf_config.rope_scaling if rope_scaling: - assert rope_scaling["rope_type"] == "yarn" mscale_all_dim = rope_scaling.get("mscale_all_dim", False) scaling_factor = rope_scaling["factor"] mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim)) From e2c42cbc3f35f9637671bfdfe05d2c4ac259cfb3 Mon Sep 17 00:00:00 2001 From: elvischenv <219235043+elvischenv@users.noreply.github.com> Date: Thu, 11 Dec 2025 09:21:59 +0800 Subject: [PATCH 14/16] Revert rope fix --- python/sglang/srt/utils/mistral_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/utils/mistral_utils.py b/python/sglang/srt/utils/mistral_utils.py index fd43ed0b882b..3aed7a652350 100644 --- a/python/sglang/srt/utils/mistral_utils.py +++ b/python/sglang/srt/utils/mistral_utils.py @@ -127,7 +127,7 @@ def _remap_mistral_yarn_args(config: dict) -> dict: if old_name in yarn_config: value = yarn_config.pop(old_name) if new_name is not None: - config["rope_scaling"][new_name] = float(value) + config["rope_scaling"][new_name] = value assert len(yarn_config) == 0, f"Unparsed yarn config: {yarn_config}" From b8a26437aeb1309c7faaf4782f5ae7fa1e8aab02 Mon Sep 17 00:00:00 2001 From: elvischenv <219235043+elvischenv@users.noreply.github.com> Date: Thu, 11 Dec 2025 09:23:51 +0800 Subject: [PATCH 15/16] Revert rope fix --- python/sglang/srt/models/pixtral.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/python/sglang/srt/models/pixtral.py b/python/sglang/srt/models/pixtral.py index 55ce2114851a..c59770f4578a 100644 --- a/python/sglang/srt/models/pixtral.py +++ b/python/sglang/srt/models/pixtral.py @@ -83,10 +83,6 @@ def __init__(self, *, config, prefix: str = "", **kwargs): for key, value in self.config.vision_config.to_dict().items() if key in dataclass_fields } - if "rope_theta" not in vision_args: - vision_args["rope_theta"] = self.config.vision_config.rope_scaling[ - "rope_theta" - ] self.vision_args = VisionEncoderArgs(**vision_args) From 1fcb827eb0b891c5d4f4a59e63f155ce92db7c5a Mon Sep 17 00:00:00 2001 From: elvischenv <219235043+elvischenv@users.noreply.github.com> Date: Fri, 12 Dec 2025 06:29:05 -0800 Subject: [PATCH 16/16] update copyright --- .../quantization/compressed_tensors/compressed_tensors.py | 2 +- .../quantization/compressed_tensors/compressed_tensors_moe.py | 2 +- .../compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py | 2 +- python/sglang/srt/models/mistral_large_3.py | 2 +- python/sglang/srt/models/mistral_large_3_eagle.py | 2 ++ python/sglang/srt/utils/mistral_utils.py | 2 +- 6 files changed, 7 insertions(+), 5 deletions(-) diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py index 933c27c059c3..14f4e5f86904 100644 --- a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py @@ -1,4 +1,4 @@ -# Adapted from https://github.com/vllm-project/vllm/tree/v0.8.2/vllm/model_executor/layers/quantization/compressed_tensors +# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/quantization/compressed_tensors # SPDX-License-Identifier: Apache-2.0 from __future__ import annotations diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 1db5009ccff5..b5e3964c85f4 100644 --- a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -1,4 +1,4 @@ -# Adapted from https://github.com/vllm-project/vllm/tree/v0.8.2/vllm/model_executor/layers/quantization/compressed_tensors +# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/quantization/compressed_tensors # SPDX-License-Identifier: Apache-2.0 from __future__ import annotations diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py index 9ad06bb5603c..a155b160422e 100644 --- a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py +++ b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py @@ -1,5 +1,5 @@ +# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/quantization/compressed_tensors # SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project import logging from collections.abc import Callable from typing import Optional diff --git a/python/sglang/srt/models/mistral_large_3.py b/python/sglang/srt/models/mistral_large_3.py index fd60ef61f7c0..cbbf893b6e9d 100644 --- a/python/sglang/srt/models/mistral_large_3.py +++ b/python/sglang/srt/models/mistral_large_3.py @@ -1,5 +1,5 @@ +# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/mistral_large_3.py # SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Iterable import regex as re diff --git a/python/sglang/srt/models/mistral_large_3_eagle.py b/python/sglang/srt/models/mistral_large_3_eagle.py index 08f7271fde6c..a5ce7b6aabb6 100644 --- a/python/sglang/srt/models/mistral_large_3_eagle.py +++ b/python/sglang/srt/models/mistral_large_3_eagle.py @@ -1,3 +1,5 @@ +# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/mistral_large_3_eagle.py +# SPDX-License-Identifier: Apache-2.0 from typing import Optional import torch diff --git a/python/sglang/srt/utils/mistral_utils.py b/python/sglang/srt/utils/mistral_utils.py index 3aed7a652350..7be1ba5bc6f0 100644 --- a/python/sglang/srt/utils/mistral_utils.py +++ b/python/sglang/srt/utils/mistral_utils.py @@ -1,5 +1,5 @@ +# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/transformers_utils/configs/mistral.py # SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project import json from pathlib import Path from typing import Any