Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
72 commits
Select commit Hold shift + click to select a range
3d3dc09
Compressed tensors MoE schemes initial implementation
TamirBaydasov Jan 21, 2026
cb87af5
Apply suggestion from @gemini-code-assist[bot]
TamirBaydasov Jan 21, 2026
db90ec7
Apply suggestion from @gemini-code-assist[bot]
TamirBaydasov Jan 21, 2026
f6ddc08
Apply suggestion from @gemini-code-assist[bot]
TamirBaydasov Jan 21, 2026
d94e307
Apply suggestion from @gemini-code-assist[bot]
TamirBaydasov Jan 21, 2026
b6c554e
Merge branch 'sgl-project:main' into compressed_tensors_moe_schemes
TamirBaydasov Jan 26, 2026
1e57ef6
Update layer.py
TamirBaydasov Jan 27, 2026
d7351c0
Update kt_ep_wrapper.py
TamirBaydasov Jan 27, 2026
8d5b3ee
Update compressed_tensors.py
TamirBaydasov Jan 27, 2026
96e49df
Update __init__.py
TamirBaydasov Jan 27, 2026
e4219c6
Update compressed_tensors_w4a4_nvfp4_moe.py
TamirBaydasov Jan 27, 2026
47b93b1
Update compressed_tensors_w4a8_int8_moe.py
TamirBaydasov Jan 27, 2026
8b2364a
Update compressed_tensors_w8a8_fp8_moe.py
TamirBaydasov Jan 27, 2026
845a724
Update compressed_tensors_w8a8_int8.py
TamirBaydasov Jan 27, 2026
c144c81
Update compressed_tensors_w8a8_int8_moe.py
TamirBaydasov Jan 27, 2026
c411d95
Update compressed_tensors_wNa16_moe.py
TamirBaydasov Jan 27, 2026
b7d9194
Merge branch 'main' into compressed_tensors_moe_schemes
TamirBaydasov Jan 27, 2026
25b3fd4
Update layer.py
TamirBaydasov Jan 27, 2026
e439a63
Update compressed_tensors.py
TamirBaydasov Jan 27, 2026
620e43b
Delete python/sglang/srt/layers/quantization/compressed_tensors/compr…
TamirBaydasov Jan 27, 2026
00dcfc1
Update __init__.py
TamirBaydasov Jan 27, 2026
3d15808
Update compressed_tensors_w4a4_nvfp4_moe.py
TamirBaydasov Jan 27, 2026
6b8fc97
Update compressed_tensors_w4a8_int8_moe.py
TamirBaydasov Jan 27, 2026
a50a58a
Update compressed_tensors_w8a8_fp8_moe.py
TamirBaydasov Jan 27, 2026
2b0aae2
Update compressed_tensors_w8a8_int8.py
TamirBaydasov Jan 27, 2026
dc0e31a
Update compressed_tensors_w8a8_int8_moe.py
TamirBaydasov Jan 27, 2026
5d76dd1
Update compressed_tensors_wNa16_moe.py
TamirBaydasov Jan 27, 2026
4da8049
Update compressed_tensors_w4a8_int8_moe.py
TamirBaydasov Jan 27, 2026
1e275bb
Update compressed_tensors_w8a8_fp8_moe.py
TamirBaydasov Jan 27, 2026
5f70889
Update compressed_tensors_w8a8_int8_moe.py
TamirBaydasov Jan 27, 2026
fa62578
Update compressed_tensors_wNa16_moe.py
TamirBaydasov Jan 27, 2026
81886cb
Update layer.py
TamirBaydasov Jan 27, 2026
c2da0ea
Update compressed_tensors.py
TamirBaydasov Jan 27, 2026
d9b7829
Update __init__.py
TamirBaydasov Jan 27, 2026
ae85653
Create compressed_tensors_w4a4_mxint4_moe.py
TamirBaydasov Jan 27, 2026
691bcf3
Merge branch 'main' into compressed_tensors_moe_schemes
TamirBaydasov Jan 27, 2026
886a1eb
Update compressed_tensors_w4a4_mxint4_moe.py
TamirBaydasov Jan 28, 2026
edeb1fd
Update layer.py
TamirBaydasov Jan 28, 2026
a3cc9e4
Merge branch 'sgl-project:main' into compressed_tensors_moe_schemes
TamirBaydasov Jan 28, 2026
5bb3f99
Update compressed_tensors.py
TamirBaydasov Jan 28, 2026
5c71144
Update compressed_tensors_w4a4_mxint4_moe.py
TamirBaydasov Jan 28, 2026
3ee1e8b
Update compressed_tensors_w4a4_nvfp4_moe.py
TamirBaydasov Jan 28, 2026
1bd3b72
Update compressed_tensors_w4a8_int8_moe.py
TamirBaydasov Jan 28, 2026
33acfb0
Update compressed_tensors_w8a8_fp8_moe.py
TamirBaydasov Jan 28, 2026
bceb9e5
Update compressed_tensors_w8a8_int8_moe.py
TamirBaydasov Jan 28, 2026
08d4324
Update compressed_tensors_wNa16_moe.py
TamirBaydasov Jan 28, 2026
dc9db4d
Update compressed_tensors_w8a8_int8_moe.py
TamirBaydasov Jan 28, 2026
71ac841
Update compressed_tensors.py
TamirBaydasov Jan 28, 2026
b0b0622
Update compressed_tensors_w4a4_mxint4_moe.py
TamirBaydasov Jan 28, 2026
e0f6d14
Update compressed_tensors_w4a4_nvfp4_moe.py
TamirBaydasov Jan 28, 2026
a5a0a5e
Update compressed_tensors_w4a8_int8_moe.py
TamirBaydasov Jan 28, 2026
04ba6c3
Update compressed_tensors_w8a8_fp8_moe.py
TamirBaydasov Jan 28, 2026
75f9f44
Update compressed_tensors_w8a8_int8_moe.py
TamirBaydasov Jan 28, 2026
e6adca3
Update compressed_tensors_wNa16_moe.py
TamirBaydasov Jan 28, 2026
33a6f60
Merge branch 'main' into compressed_tensors_moe_schemes
TamirBaydasov Jan 30, 2026
872e917
Merge branch 'sgl-project:main' into compressed_tensors_moe_schemes
TamirBaydasov Jan 30, 2026
e77a4c1
slight fixes
TamirBaydasov Jan 30, 2026
320fe57
apply_without_routing_weights addition to compressedtensorsfusedmoe
TamirBaydasov Jan 30, 2026
779244d
Merge branch 'main' into compressed_tensors_moe_schemes
TamirBaydasov Jan 30, 2026
74c946d
Merge branch 'main' into compressed_tensors_moe_schemes
TamirBaydasov Feb 2, 2026
a5aada3
Merge branch 'main' into compressed_tensors_moe_schemes
ping1jing2 Feb 2, 2026
45d1daa
Merge branch 'main' into compressed_tensors_moe_schemes
TamirBaydasov Feb 3, 2026
8ad6243
Merge branch 'main' into compressed_tensors_moe_schemes
TamirBaydasov Feb 3, 2026
3480aa3
Merge branch 'main' into compressed_tensors_moe_schemes
ping1jing2 Feb 4, 2026
2043bf3
Merge branch 'main' into compressed_tensors_moe_schemes
ping1jing2 Feb 4, 2026
85d2c80
Merge branch 'main' into compressed_tensors_moe_schemes
TamirBaydasov Feb 10, 2026
7e67240
Base scheme implementation
TamirBaydasov Feb 10, 2026
6a7cc1f
Remove device capability from base schemes
TamirBaydasov Feb 10, 2026
3633b77
Docstring fix
TamirBaydasov Feb 10, 2026
ac55ea5
Merge branch 'main' into compressed_tensors_moe_schemes
TamirBaydasov Feb 11, 2026
639d2c9
Merge branch 'main' into compressed_tensors_moe_schemes
ping1jing2 Feb 14, 2026
2ff3ac5
Merge branch 'main' into compressed_tensors_moe_schemes
AniZpZ Feb 16, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions python/sglang/srt/layers/moe/ep_moe/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@
from sglang.srt.layers.moe.token_dispatcher.moriep import MoriEPNormalCombineInput
from sglang.srt.layers.moe.topk import TopKOutput, TopKOutputChecker
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors_moe import (
NPUCompressedTensorsW4A16Int4DynamicMoEMethod,
from sglang.srt.layers.quantization.compressed_tensors.schemes import (
NPUCompressedTensorsW4A16Int4DynamicMoE,
)
from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz
Expand Down Expand Up @@ -377,7 +377,7 @@ def forward_npu(
else:
input_quant = get_bool_env_var("DEEP_NORMAL_MODE_USE_INT8_QUANT")
if not input_quant and not isinstance(
self.quant_method, NPUCompressedTensorsW4A16Int4DynamicMoEMethod
self.quant_method, NPUCompressedTensorsW4A16Int4DynamicMoE
):
hidden_states, hidden_states_scale = torch_npu.npu_dynamic_quant(
hidden_states
Expand Down
35 changes: 20 additions & 15 deletions python/sglang/srt/layers/moe/fused_moe_triton/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,8 @@
FusedMoEMethodBase,
QuantizationConfig,
)
from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors_moe import (
CompressedTensorsMxInt4MoEMethod,
from sglang.srt.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsMxInt4MoE,
)
from sglang.srt.layers.quantization.fp8 import Fp8MoEMethod
from sglang.srt.layers.quantization.modelopt_quant import ModelOptNvFp4FusedMoEMethod
Expand Down Expand Up @@ -682,6 +682,8 @@ def _weight_loader_impl(
# TODO (mgoin): check self.quant_method.quant_config.quant_format
# against known CompressionFormat enum values that have this quality
Comment thread
TamirBaydasov marked this conversation as resolved.
Comment thread
TamirBaydasov marked this conversation as resolved.
method = self.quant_method
if hasattr(self, "scheme"):
method = self.scheme
if method.__class__.__name__ == "KTEPWrapperMethod":
method = method.gpu_method

Expand All @@ -690,9 +692,9 @@ def _weight_loader_impl(
if (
method.__class__.__name__
in [
"CompressedTensorsWNA16MarlinMoEMethod",
"CompressedTensorsWNA16MoEMethod",
"CompressedTensorsWNA16TritonMoEMethod",
"CompressedTensorsWNA16MarlinMoE",
"CompressedTensorsWNA16MoE",
"CompressedTensorsWNA16TritonMoE",
]
)
else loaded_weight
Expand All @@ -703,10 +705,10 @@ def _weight_loader_impl(

# Flashinfer assumes w31 format for w13_weight. Same for the scales.
if self.use_flashinfer_trtllm_moe and (
isinstance(self.quant_method, ModelOptNvFp4FusedMoEMethod)
or isinstance(self.quant_method, Fp8MoEMethod)
or isinstance(self.quant_method, UnquantizedFusedMoEMethod)
or isinstance(self.quant_method, CompressedTensorsMxInt4MoEMethod)
isinstance(method, ModelOptNvFp4FusedMoEMethod)
or isinstance(method, Fp8MoEMethod)
or isinstance(method, UnquantizedFusedMoEMethod)
or isinstance(method, CompressedTensorsMxInt4MoE)
):
shard_id = {"w1": "w3", "w3": "w1", "w2": "w2"}[shard_id]

Expand Down Expand Up @@ -739,7 +741,7 @@ def _weight_loader_impl(

if (
(
"compressed" in self.quant_method.__class__.__name__.lower()
"compressed" in method.__class__.__name__.lower()
or "w4afp8" in self.quant_config.get_name()
)
and (param.data[expert_id] != 1).any()
Expand Down Expand Up @@ -767,9 +769,9 @@ def _weight_loader_impl(
)
return

if "ModelOpt" in self.quant_method.__class__.__name__:
if "ModelOpt" in method.__class__.__name__:
# Determine per-tensor weight scale patterns based on variant
is_fp4_variant = isinstance(self.quant_method, ModelOptNvFp4FusedMoEMethod)
is_fp4_variant = isinstance(method, ModelOptNvFp4FusedMoEMethod)

# FP4 uses "weight_scale_2" for per-tensor, FP8 uses "weight_scale" for per-tensor
per_tensor_conditions = (
Expand Down Expand Up @@ -902,13 +904,16 @@ def weight_loader_fused(
# compressed-tensors checkpoints with packed weights are stored flipped
# TODO: check self.quant_method.quant_config.quant_format
# against known CompressionFormat enum values that have this quality
method = self.quant_method
if hasattr(self, "scheme"):
method = self.scheme
loaded_weight = (
loaded_weight.t().contiguous()
if (
self.quant_method.__class__.__name__
method.__class__.__name__
in [
"CompressedTensorsWNA16MoEMethod",
"CompressedTensorsWNA16TritonMoEMethod",
"CompressedTensorsWNA16MoE",
"CompressedTensorsWNA16TritonMoE",
]
)
else loaded_weight
Expand Down
2 changes: 1 addition & 1 deletion python/sglang/srt/layers/moe/kt_ep_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ class KTEPWrapperMethod(FusedMoEMethodBase):

Example:
# Wrap any GPU method with AMX/AVX CPU expert support
gpu_method = CompressedTensorsWNA16MoEMethod(quant_config, prefix)
gpu_method = CompressedTensorsWNA16MoE(quant_config, prefix)
kt_config = KTConfig(layer_idx=0, num_gpu_experts=4, ...)
method = KTEPWrapperMethod(gpu_method, kt_config)
"""
Expand Down
99 changes: 99 additions & 0 deletions python/sglang/srt/layers/quantization/base_scheme.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
# SPDX-License-Identifier: Apache-2.0

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Optional

import torch

from sglang.srt.layers.moe import MoeRunnerConfig

if TYPE_CHECKING:
from sglang.srt.layers.moe.token_dispatcher import StandardDispatchOutput

__all__ = ["BaseLinearScheme", "BaseMoEScheme"]


class BaseLinearScheme(ABC):
    """
    Abstract base class describing the weight creation and forward pass
    of a quantized linear-layer scheme.
    """

    @abstractmethod
    def create_weights(self, *args, **kwargs):
        """
        Create and register the weights for the particular scheme on the
        layer. Inputs are scheme-specific, hence the generic signature.
        """
        raise NotImplementedError

    @abstractmethod
    def process_weights_after_loading(self, layer: torch.nn.Module):
        """
        Called after weight loading is complete for any cleanup
        (e.g. repacking/requantization) that needs to occur.

        :param layer: torch.nn.Module holding the loaded weights.
        """
        raise NotImplementedError

    @abstractmethod
    def apply_weights(
        self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor]
    ):
        """
        Run the forward pass for the particular scheme. This is where
        scheme-specific dequant/quant steps/kernels should be applied.

        :param layer: torch.nn.Module with the registered weights and
            other parameters relevant to the particular scheme.
        :param x: input to the layer
        :param bias: optional bias parameter
        """
        raise NotImplementedError

class BaseMoEScheme(ABC):
    """
    Abstract base class describing the weight creation and forward pass
    of a quantized fused-MoE scheme.
    """

    @abstractmethod
    def create_weights(self, *args, **kwargs):
        """
        Create and register the expert weights for the particular scheme
        on the layer. Inputs are scheme-specific, hence the generic
        signature.
        """
        raise NotImplementedError

    @abstractmethod
    def create_moe_runner(
        self, layer: torch.nn.Module, moe_runner_config: "MoeRunnerConfig"
    ):
        """
        Build the MoE runner used to execute the fused experts for this
        scheme.

        :param layer: torch.nn.Module holding the expert weights.
        :param moe_runner_config: configuration for the MoE runner.
        """
        raise NotImplementedError

    @abstractmethod
    def process_weights_after_loading(self, layer: torch.nn.Module):
        """
        Called after weight loading is complete for any cleanup
        (e.g. repacking/requantization) that needs to occur.

        :param layer: torch.nn.Module holding the loaded weights.
        """
        raise NotImplementedError

    @abstractmethod
    def apply_weights(
        self,
        layer: torch.nn.Module,
        dispatch_output: "StandardDispatchOutput",
    ):
        """
        Run the forward pass for the particular scheme. This is where
        scheme-specific dequant/quant steps/kernels should be applied.

        :param layer: torch.nn.Module with the registered weights and
            other parameters relevant to the particular scheme.
        :param dispatch_output: dispatched hidden states and routing
            information for the experts.
        """
        raise NotImplementedError
Loading
Loading