Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
d6322e0
Support eagle
elvischenv Dec 4, 2025
0c9e202
Fixes for running eagle
Linda-Stadter Dec 4, 2025
0575231
Merge branch 'main' into elvis/eagle
elvischenv Dec 5, 2025
1402ae4
Added w4a16 loading support.
dcampora Nov 29, 2025
b8b4cc6
Adding w4a4 support for compressed tensors.
dcampora Dec 5, 2025
c54fc52
Do not change sgl kernel.
dcampora Dec 5, 2025
ab9fe7a
add compressed tensors w4a4 nvfp4 moe support
elvischenv Dec 5, 2025
09d4eba
Merge branch 'main' into dcampora/nvfp4_support
JustinTong0323 Dec 5, 2025
fdf7a4e
lint
JustinTong0323 Dec 5, 2025
6b67c6c
fix marlin undefined name
JustinTong0323 Dec 5, 2025
1c5b478
Merge branch 'main' into dcampora/nvfp4_support
JustinTong0323 Dec 5, 2025
c4b3399
Merge branch 'main' into dcampora/nvfp4_support
JustinTong0323 Dec 5, 2025
da20243
Merge branch 'main' into dcampora/nvfp4_support
JustinTong0323 Dec 6, 2025
854bfa5
Merge branch 'main' into dcampora/nvfp4_support
Fridge003 Dec 6, 2025
3ae63da
clean up
elvischenv Dec 9, 2025
c82ae89
clean up w4a16 nvfp4 (#5)
elvischenv Dec 9, 2025
6f47107
Merge branch 'main' into dcampora/nvfp4_support
elvischenv Dec 9, 2025
527e399
fix rope issue
elvischenv Dec 9, 2025
a54c7f5
Merge remote-tracking branch 'origin/main' into dcampora/nvfp4_support
elvischenv Dec 9, 2025
022553d
add assertion for yarn
elvischenv Dec 9, 2025
bc11677
Remove assertion
elvischenv Dec 9, 2025
e044562
Merge branch 'main' into dcampora/nvfp4_support
Fridge003 Dec 10, 2025
b2683a0
Merge branch 'main' into dcampora/nvfp4_support
Fridge003 Dec 10, 2025
e2c42cb
Revert rope fix
elvischenv Dec 11, 2025
b8a2643
Revert rope fix
elvischenv Dec 11, 2025
8879db9
Merge branch 'main' into dcampora/nvfp4_support
elvischenv Dec 11, 2025
1e24516
Merge branch 'main' into dcampora/nvfp4_support
Fridge003 Dec 11, 2025
2e724b2
Merge branch 'main' into dcampora/nvfp4_support
ispobock Dec 12, 2025
1fcb827
update copyright
elvischenv Dec 12, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Adapted from https://github.com/vllm-project/vllm/tree/v0.8.2/vllm/model_executor/layers/quantization/compressed_tensors
# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/quantization/compressed_tensors
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations

Expand Down Expand Up @@ -30,6 +30,7 @@
from sglang.srt.layers.quantization.compressed_tensors.schemes import (
WNA16_SUPPORTED_BITS,
CompressedTensorsScheme,
CompressedTensorsW4A4Fp4,
CompressedTensorsW8A8Fp8,
CompressedTensorsW8A8Int8,
CompressedTensorsW8A16Fp8,
Expand Down Expand Up @@ -376,6 +377,35 @@ def _is_fp8_w8a16(self, weight_quant: BaseModel, input_quant: BaseModel) -> bool
# All conditions satisfied.
return True

def _is_fp4a4_nvfp4(
self, weight_quant: QuantizationArgs, input_quant: QuantizationArgs
):
if weight_quant is None or input_quant is None:
return False

is_tensor_group_quant = (
weight_quant.strategy == QuantizationStrategy.TENSOR_GROUP.value
and input_quant.strategy == QuantizationStrategy.TENSOR_GROUP.value
)
is_symmetric = weight_quant.symmetric and input_quant.symmetric

is_group_size_16 = (
weight_quant.group_size == 16 and input_quant.group_size == 16
)
is_float_type = (
weight_quant.type == QuantizationType.FLOAT
and input_quant.type == QuantizationType.FLOAT
)
is_4_bits = weight_quant.num_bits == 4 and input_quant.num_bits == 4

return (
is_tensor_group_quant
and is_float_type
and is_4_bits
and is_group_size_16
and is_symmetric
)

def _is_wNa16_group_channel(
self, weight_quant: BaseModel, input_quant: BaseModel
) -> bool:
Expand Down Expand Up @@ -411,6 +441,17 @@ def _get_scheme_from_parts(
)

if is_activation_quantization_format(self.quant_format):
if self._is_fp4a4_nvfp4(weight_quant, input_quant):
is_fp4a4_nvfp4_supported = self._check_scheme_supported(
CompressedTensorsW4A4Fp4.get_min_capability(), error=False
)
if is_fp4a4_nvfp4_supported:
return CompressedTensorsW4A4Fp4()
else:
raise NotImplementedError(
"Current platform does not support w4a4 nvfp4 quantization."
)

if self._is_fp8_w8a8(weight_quant, input_quant):
is_fp8_w8a8_supported = self._check_scheme_supported(
CompressedTensorsW8A8Fp8.get_min_capability(), error=False
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Adapted from https://github.com/vllm-project/vllm/tree/v0.8.2/vllm/model_executor/layers/quantization/compressed_tensors
# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/quantization/compressed_tensors
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations

Expand All @@ -13,19 +13,24 @@

from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams, CutlassMoEType
from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
from sglang.srt.layers.quantization.base_config import FusedMoEMethodBase
from sglang.srt.layers.quantization.compressed_tensors.schemes import (
WNA16_SUPPORTED_BITS,
)
from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz, scaled_fp8_quant
from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
from sglang.srt.layers.quantization.fp8_utils import (
is_blackwell_supported,
normalize_e4m3fn_to_e4m3fnuz,
)
from sglang.srt.layers.quantization.gptq import gptq_marlin_moe_repack
from sglang.srt.layers.quantization.marlin_utils import marlin_moe_permute_scales
from sglang.srt.layers.quantization.utils import (
all_close_1d,
per_tensor_dequantize,
replace_parameter,
swizzle_blockscale,
)
from sglang.srt.utils import get_bool_env_var, is_cuda, is_hip, set_weight_attrs

Expand Down Expand Up @@ -60,6 +65,7 @@ class GPTQMarlinState(Enum):

__all__ = [
"CompressedTensorsMoEMethod",
"CompressedTensorsW4A4Nvfp4MoEMethod",
"CompressedTensorsW8A8Fp8MoEMethod",
"CompressedTensorsWNA16MoEMethod",
]
Expand All @@ -86,14 +92,251 @@ def get_moe_method(
if quant_config._is_wNa16_group_channel(weight_quant, input_quant):
logger.info_once("Using CompressedTensorsWNA16MarlinMoEMethod")
return CompressedTensorsWNA16MoEMethod(quant_config)
elif quant_config._is_fp4a4_nvfp4(weight_quant, input_quant):
logger.info_once("Using CompressedTensorsW4A4Nvfp4MoEMethod")
return CompressedTensorsW4A4Nvfp4MoEMethod(quant_config)
elif quant_config._is_fp8_w8a8(weight_quant, input_quant):
logger.info_once("Using CompressedTensorsW8A8Fp8MoEMethod")
return CompressedTensorsW8A8Fp8MoEMethod(quant_config)
else:
raise RuntimeError(
f"Unsupported FusedMoe scheme: {weight_quant}, {input_quant}"
)


class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
    """Fused-MoE method for compressed-tensors NVFP4 (w4a4) checkpoints.

    Expert weights arrive packed as fp4 (two 4-bit values per uint8 byte)
    with fp8 block scales over groups of 16 input elements, plus per-tensor
    global scales for both weights and activations.  ``apply`` executes the
    MoE with the CUTLASS block-scaled fp4 kernel (``cutlass_moe_fp4``).
    """

    def __init__(self, quant_config: CompressedTensorsConfig):
        # NVFP4 block-scaled GEMMs require Blackwell-or-newer hardware.
        if not is_blackwell_supported():
            raise ValueError(
                "Current platform does not support NVFP4"
                " quantization. Please use Blackwell and"
                " above."
            )
        self.quant_config = quant_config
        # NVFP4 uses a fixed micro-block of 16 elements per fp8 scale.
        self.group_size = 16

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ) -> None:
        """Register packed fp4 expert weights and all scale parameters on ``layer``.

        Registered parameters:
          - w13_weight_packed / w2_weight_packed: uint8, two fp4 per byte.
          - w13_weight_scale / w2_weight_scale: fp8 block scales (GROUP).
          - *_weight_global_scale / *_input_global_scale: fp32 per-tensor
            scales (TENSOR).
        """
        from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported

        layer.num_experts = num_experts
        layer.params_dtype = params_dtype

        w13_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                # 2 fp4 items are packed in the input dimension
                hidden_size // 2,
                requires_grad=False,
                dtype=torch.uint8,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight_packed", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size,
                # 2 fp4 items are packed in the input dimension
                intermediate_size_per_partition // 2,
                dtype=torch.uint8,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight_packed", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # Weight Scales
        w13_weight_scale = torch.nn.Parameter(
            torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                # one fp8 scale per group of `group_size` input elements
                hidden_size // self.group_size,
                dtype=torch.float8_e4m3fn,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight_scale", w13_weight_scale)
        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.GROUP.value}
        )
        set_weight_attrs(w13_weight_scale, extra_weight_attrs)

        w2_weight_scale = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size,
                # one fp8 scale per group of `group_size` input elements
                intermediate_size_per_partition // self.group_size,
                dtype=torch.float8_e4m3fn,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight_scale", w2_weight_scale)
        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.GROUP.value}
        )
        set_weight_attrs(w2_weight_scale, extra_weight_attrs)

        # Weight Global Scales (one per expert; w13 has separate w1/w3 entries)
        w13_weight_scale_2 = torch.nn.Parameter(
            torch.empty(num_experts, 2, dtype=torch.float32), requires_grad=False
        )
        layer.register_parameter("w13_weight_global_scale", w13_weight_scale_2)
        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
        )
        set_weight_attrs(w13_weight_scale_2, extra_weight_attrs)

        w2_weight_scale_2 = torch.nn.Parameter(
            torch.empty(num_experts, dtype=torch.float32), requires_grad=False
        )
        layer.register_parameter("w2_weight_global_scale", w2_weight_scale_2)
        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
        )
        set_weight_attrs(w2_weight_scale_2, extra_weight_attrs)

        # Input Global Scales (per-tensor activation scales)
        w13_input_scale = torch.nn.Parameter(
            torch.empty(num_experts, 2, dtype=torch.float32), requires_grad=False
        )
        layer.register_parameter("w13_input_global_scale", w13_input_scale)
        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
        )
        set_weight_attrs(w13_input_scale, extra_weight_attrs)

        w2_input_scale = torch.nn.Parameter(
            torch.empty(num_experts, dtype=torch.float32), requires_grad=False
        )
        layer.register_parameter("w2_input_global_scale", w2_input_scale)
        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
        )
        set_weight_attrs(w2_input_scale, extra_weight_attrs)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Post-load rework: rename packed weights, invert and fold the
        global scales into kernel alphas, swizzle block scales, and build
        the CUTLASS MoE parameters.
        """
        # From packed to weight: expose the packed uint8 tensors under the
        # names the fused kernel reads (w13_weight / w2_weight).
        layer.w13_weight = torch.nn.Parameter(
            layer.w13_weight_packed.data, requires_grad=False
        )
        delattr(layer, "w13_weight_packed")

        layer.w2_weight = torch.nn.Parameter(
            layer.w2_weight_packed.data, requires_grad=False
        )
        delattr(layer, "w2_weight_packed")

        # w1 and w3 were saved with separate global scales; only column 0 is
        # used below, so warn when they disagree.
        if not torch.allclose(
            layer.w13_weight_global_scale[:, 0], layer.w13_weight_global_scale[:, 1]
        ):
            logger.warning_once(
                "w1_weight_global_scale must match w3_weight_global_scale. "
                "Accuracy may be affected."
            )

        # Take inverse of global scale saved to disk
        layer.w13_weight_scale_2 = torch.nn.Parameter(
            1 / layer.w13_weight_global_scale[:, 0], requires_grad=False
        )

        layer.w2_weight_scale_2 = torch.nn.Parameter(
            1 / layer.w2_weight_global_scale.data, requires_grad=False
        )

        # w13: fold the inverted input and weight global scales into the
        # per-expert alpha applied in the GEMM epilogue.  The min over the
        # two (w1, w3) input scales is taken so one scale covers both.
        w13_input_global_scale = layer.w13_input_global_scale.min(dim=1).values.to(
            torch.float32
        )
        layer.g1_alphas = torch.nn.Parameter(
            ((1 / w13_input_global_scale) * layer.w13_weight_scale_2),
            requires_grad=False,
        )

        layer.w13_input_scale_quant = torch.nn.Parameter(
            (w13_input_global_scale), requires_grad=False
        )

        # w2
        w2_input_global_scale = layer.w2_input_global_scale

        layer.g2_alphas = torch.nn.Parameter(
            ((1 / w2_input_global_scale) * layer.w2_weight_scale_2).to(torch.float32),
            requires_grad=False,
        )

        layer.w2_input_scale_quant = torch.nn.Parameter(
            (w2_input_global_scale), requires_grad=False
        )

        # Swizzle the fp8 block scales into the layout the CUTLASS kernel expects.
        layer.w13_weight_scale = torch.nn.Parameter(
            swizzle_blockscale(layer.w13_weight_scale), requires_grad=False
        )

        layer.w2_weight_scale = torch.nn.Parameter(
            swizzle_blockscale(layer.w2_weight_scale), requires_grad=False
        )

        # Logical sizes are recovered from the packed shapes (x2 undoes the
        # two-fp4-per-byte packing along the input dimension).
        layer.cutlass_moe_params = CutlassMoEParams(
            CutlassMoEType.BlockscaledFP4,
            layer.w13_weight.device,
            num_experts=layer.num_experts,
            intermediate_size_per_partition=layer.w2_weight.shape[2] * 2,
            hidden_size=layer.w13_weight.shape[2] * 2,
        )

    def create_moe_runner(
        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
    ) -> None:
        # NOTE(review): the runner is created with the TRITON backend, but
        # `apply` below dispatches to the CUTLASS fp4 path directly and only
        # reads `self.moe_runner_config` — confirm the runner is required.
        self.moe_runner_config = moe_runner_config
        self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)

    def apply(
        self,
        layer: torch.nn.Module,
        dispatch_output: StandardDispatchOutput,
    ) -> CombineInput:
        """Run the fused MoE with the CUTLASS block-scaled fp4 kernel."""

        from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4
        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput

        x = dispatch_output.hidden_states
        topk_output = dispatch_output.topk_output
        topk_weights, topk_ids = topk_output.topk_weights, topk_output.topk_ids

        output = cutlass_moe_fp4(
            a=x,
            a1_gscale=layer.w13_input_scale_quant,
            w1_fp4=layer.w13_weight,
            w1_blockscale=layer.w13_weight_scale,
            w1_alphas=layer.g1_alphas,
            a2_gscale=layer.w2_input_scale_quant,
            w2_fp4=layer.w2_weight,
            w2_blockscale=layer.w2_weight_scale,
            w2_alphas=layer.g2_alphas,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            params=layer.cutlass_moe_params,
            apply_router_weight_on_input=self.moe_runner_config.apply_router_weight_on_input,
        ).to(x.dtype)  # cast back to the activation dtype of the input

        return StandardCombineInput(hidden_states=output)


class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):

def __init__(self, quant_config: CompressedTensorsConfig):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0

from .compressed_tensors_scheme import CompressedTensorsScheme
from .compressed_tensors_w4a4_nvfp4 import CompressedTensorsW4A4Fp4
from .compressed_tensors_w8a8_fp8 import CompressedTensorsW8A8Fp8
from .compressed_tensors_w8a8_int8 import CompressedTensorsW8A8Int8
from .compressed_tensors_w8a16_fp8 import CompressedTensorsW8A16Fp8
Expand All @@ -13,4 +14,5 @@
"CompressedTensorsW8A8Int8",
"CompressedTensorsWNA16",
"WNA16_SUPPORTED_BITS",
"CompressedTensorsW4A4Fp4",
]
Loading
Loading