Merged
21 commits
ccae96b
support W8A8 for DeepSeek-R1 on CPU (#19)
chunyuan-w Mar 24, 2025
da85df4
Use shared_expert if use_intel_amx_backend and n_shared_experts is no…
chunyuan-w Mar 28, 2025
05df591
Use int8_scaled_mm_with_quant (#28)
chunyuan-w Mar 28, 2025
ae7d006
Use sgl_kernel.cpu.fp8_scaled_mm in Fp8LinearMethod (#50)
chunyuan-w Apr 11, 2025
01cb26f
Enable FP8 DeepSeek R1 (#52)
chunyuan-w Apr 14, 2025
8c5dc85
Integrate FP8 shared_expert (#59)
chunyuan-w Apr 17, 2025
4159a92
Integrate FP8 fused_moe (#62)
chunyuan-w Apr 17, 2025
575bc3c
fix capability check
chunyuan-w May 28, 2025
2a4513a
fix FP8 linear after rebase
chunyuan-w May 28, 2025
9793c80
convert bmm weight from fp8 to bf16
blzheng Apr 11, 2025
b61007d
fix format
chunyuan-w Jun 3, 2025
5ff3d9a
refine frontend change in deepseek_v2.py
chunyuan-w Jun 5, 2025
1c5be08
simplify shared_experts_is_int8 and shared_experts_is_fp8
chunyuan-w Jun 5, 2025
331ba0e
remove unnecessary view
chunyuan-w Jun 5, 2025
8723032
fix if the obj does not have the use_intel_amx_backend attr
chunyuan-w Jun 12, 2025
3ef5f18
use _is_cpu and _is_cpu_amx_available to check when converting bmm FP…
chunyuan-w Jun 25, 2025
b253b91
Merge branch 'main' into chunyuan/pr_int8_fp8
chunyuan-w Jun 25, 2025
3b778ee
use _is_cpu for device check
chunyuan-w Jun 26, 2025
4742008
Merge branch 'main' into chunyuan/pr_int8_fp8
chunyuan-w Jun 26, 2025
ac476cf
Merge branch 'main' into chunyuan/pr_int8_fp8
zhyncs Jun 26, 2025
63a0197
Merge branch 'main' into chunyuan/pr_int8_fp8
zhyncs Jun 28, 2025
2 changes: 1 addition & 1 deletion python/sglang/srt/layers/moe/fused_moe_triton/layer.py
@@ -291,7 +291,7 @@ def forward_cpu(
torch.float
), # TODO: the topk_weights of llama4 is computed via Llama4MoE:custom_routing_function and is bfloat16 while the kernel requires it to be float32
topk_ids,
True, # inplace
False, # inplace # See [Note] inplace should be False in fused_experts.
False, # use_int8_w8a8
False, # use_fp8_w8a16
None, # w1_scale
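The flip from True to False here matters because the routed-experts kernel would otherwise overwrite its input activations, which the DeepSeek shared experts still need afterwards (see the [Note] in deepseek_v2.py further down). A minimal sketch of the aliasing problem, using a stand-in for the real sgl_kernel op:

import torch

def fake_fused_experts(x: torch.Tensor, inplace: bool) -> torch.Tensor:
    # Stand-in for fused_experts: with inplace=True the result is written into x.
    out = x * 2.0
    if inplace:
        x.copy_(out)
        return x
    return out

hidden_states = torch.ones(4, 8)
routed_out = fake_fused_experts(hidden_states, inplace=True)
shared_out = hidden_states + 1.0  # bug: reads the routed output, not the layer input
# With inplace=False, hidden_states is untouched and both paths see the same input.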
43 changes: 43 additions & 0 deletions python/sglang/srt/layers/quantization/fp8.py
@@ -64,6 +64,7 @@ def dummy_func(*args, **kwargs):
)
from sglang.srt.layers.utils import is_sm100_supported
from sglang.srt.utils import (
_process_weight_after_loading,
Review comment (Collaborator): Refine this method after addressing #6641 (review)
Reply (Contributor Author): Fixed in #7647

cpu_has_amx_support,
get_bool_env_var,
is_cpu,
@@ -330,6 +331,12 @@ def process_weights_after_loading(self, layer: Module) -> None:
)

layer.input_scale = None
elif layer.weight.device.type == "cpu":
assert (
cpu_has_amx_support()
), "Fp8LinearMethod on CPU requires that CPU has AMX support"
_process_weight_after_loading(layer, ["weight"])
return
else:
weight, weight_scale = layer.weight.data, layer.weight_scale_inv.data
layer.weight = torch.nn.Parameter(weight, requires_grad=False)
@@ -426,6 +433,17 @@ def apply(
)

if self.block_quant:
if getattr(layer, "use_intel_amx_backend", False):
Review comment (Collaborator): Refine this line after addressing #6641 (review)
Reply (Contributor Author): Fixed in #7647

return torch.ops.sgl_kernel.fp8_scaled_mm_cpu(
x,
layer.weight,
layer.weight_scale_inv,
self.quant_config.weight_block_size,
bias,
x.dtype,
True, # is_vnni
)

return self.w8a8_block_fp8_linear(
input=x,
weight=layer.weight,
@@ -746,6 +764,13 @@ def process_weights_after_loading(self, layer: Module) -> None:
layer.w2_weight.data = shuffle_weight(
layer.w2_weight.contiguous(), (16, 16)
)

if all(w.device.type == "cpu" for w in [layer.w13_weight, layer.w2_weight]):
assert (
cpu_has_amx_support()
), "Fp8MoEMethod on CPU requires that CPU has AMX support"
_process_weight_after_loading(layer, ["w13_weight", "w2_weight"])

return

# If checkpoint is fp16 or bfloat16, quantize in place.
@@ -971,6 +996,24 @@ def apply(
routed_scaling_factor=routed_scaling_factor,
)

if getattr(layer, "use_intel_amx_backend", False):
return torch.ops.sgl_kernel.fused_experts_cpu(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
False, # inplace See [Note] inplace should be False in fused_experts.
False, # use_int8_w8a8
True, # use_fp8_w8a16
layer.w13_weight_scale_inv, # w1_scale
layer.w2_weight_scale_inv, # w2_scale
self.quant_config.weight_block_size, # block_size
None, # a1_scale
None, # a2_scale
True, # is_vnni
)

if _is_hip:
ret = self.maybe_apply_hip_fused_experts(
layer,
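Both fp8_scaled_mm_cpu and the FP8 fused_experts_cpu call above take weight_scale_inv together with quant_config.weight_block_size, i.e. the weights are block-quantized. For intuition only, a rough reference of block-wise dequantization to bf16, with hypothetical shapes that divide evenly by the block size (the real kernels work on pre-packed VNNI weights and never materialize this full-precision copy):

import torch

N, K = 256, 512                      # hypothetical weight shape
block_n, block_k = 128, 128          # e.g. weight_block_size = [128, 128]
weight_fp8 = torch.randn(N, K).to(torch.float8_e4m3fn)
scale_inv = torch.rand(N // block_n, K // block_k)  # one scale per tile

# Expand the per-tile scales to the full weight shape and dequantize.
scale_full = scale_inv.repeat_interleave(block_n, dim=0).repeat_interleave(block_k, dim=1)
weight_bf16 = (weight_fp8.to(torch.float32) * scale_full).to(torch.bfloat16)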
2 changes: 1 addition & 1 deletion python/sglang/srt/layers/quantization/moe_wna16.py
@@ -131,7 +131,7 @@ def is_moe_wna16_compatible(cls, quant_config: Dict[str, Any]):
capability_tuple = get_device_capability()
device_capability = (
-1
if capability_tuple is None
if all(capability is None for capability in capability_tuple)
else capability_tuple[0] * 10 + capability_tuple[1]
)
# Avoid circular import
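The original check only handled get_device_capability() returning None; on a CPU-only build it presumably returns a (None, None) tuple instead, which would then crash the major * 10 + minor arithmetic. A small illustrative helper (not the real API) showing the intended behavior of the fixed check:

def capability_score(capability_tuple):
    # get_device_capability() presumably returns a (major, minor) tuple whose
    # entries are None when no GPU is present.
    if capability_tuple is None or all(c is None for c in capability_tuple):
        return -1
    major, minor = capability_tuple
    return major * 10 + minor

assert capability_score((None, None)) == -1   # CPU-only case the old `is None` check missed
assert capability_score((9, 0)) == 90         # e.g. a Hopper GPU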
49 changes: 48 additions & 1 deletion python/sglang/srt/layers/quantization/w8a8_int8.py
@@ -11,7 +11,12 @@
QuantizeMethodBase,
)
from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
from sglang.srt.utils import is_cuda, set_weight_attrs
from sglang.srt.utils import (
_process_weight_after_loading,
cpu_has_amx_support,
is_cuda,
set_weight_attrs,
)

_is_cuda = is_cuda()
if _is_cuda:
@@ -72,6 +77,13 @@ def __init__(self, quantization_config: W8A8Int8Config):
self.quantization_config = quantization_config

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
if layer.weight.device == torch.device("cpu"):
assert (
cpu_has_amx_support()
), "W8A8Int8LinearMethod on CPU requires that CPU has AMX support"
_process_weight_after_loading(layer, ["weight"])
return

layer.weight = Parameter(layer.weight.t(), requires_grad=False)
layer.weight_scale = Parameter(layer.weight_scale.data, requires_grad=False)

@@ -112,6 +124,16 @@ def apply(
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
):
if getattr(layer, "use_intel_amx_backend", False):
return torch.ops.sgl_kernel.int8_scaled_mm_with_quant(
x,
layer.weight,
layer.weight_scale,
bias,
x.dtype,
True, # is_vnni
)

x_q, x_scale = per_token_quant_int8(x)

return int8_scaled_mm(
@@ -206,6 +228,13 @@ def create_weights(
layer.register_parameter("w2_input_scale", w2_input_scale)

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
if all(w.device.type == "cpu" for w in [layer.w13_weight, layer.w2_weight]):
assert (
cpu_has_amx_support()
), "W8A8Int8MoEMethod on CPU requires that CPU has AMX support"
_process_weight_after_loading(layer, ["w13_weight", "w2_weight"])
return

layer.w13_weight = Parameter(layer.w13_weight, requires_grad=False)
layer.w2_weight = Parameter(layer.w2_weight, requires_grad=False)
layer.w13_weight_scale = Parameter(
@@ -252,6 +281,24 @@ def apply(
routed_scaling_factor=routed_scaling_factor,
)

if getattr(layer, "use_intel_amx_backend", False):
return torch.ops.sgl_kernel.fused_experts_cpu(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
False, # inplace See [Note] inplace should be False in fused_experts.
True, # use_int8_w8a8
False, # use_fp8_w8a16
layer.w13_weight_scale, # w1_scale
layer.w2_weight_scale, # w2_scale
None, # block_size
layer.w13_input_scale, # a1_scale
layer.w2_input_scale, # a2_scale
True, # is_vnni
)

return fused_experts(
x,
layer.w13_weight,
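On CUDA this path quantizes activations with per_token_quant_int8 and then calls int8_scaled_mm; the CPU op int8_scaled_mm_with_quant appears to fuse that activation quantization into the GEMM. For intuition, a reference version of per-token symmetric int8 quantization (not the sgl_kernel implementation):

import torch

def per_token_quant_int8_ref(x: torch.Tensor):
    # One scale per token (row): scale = absmax / 127, q = round(x / scale).
    absmax = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8)
    scale = absmax / 127.0
    x_q = torch.round(x / scale).clamp(-128, 127).to(torch.int8)
    return x_q, scale

x = torch.randn(4, 16)
x_q, x_scale = per_token_quant_int8_ref(x)
x_approx = x_q.float() * x_scale   # roughly recovers x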
83 changes: 83 additions & 0 deletions python/sglang/srt/models/deepseek_v2.py
@@ -300,6 +300,9 @@ def __init__(
),
)

self.shared_experts_is_int8 = False
self.shared_experts_is_fp8 = False
self.shared_experts_weight_block_size = None
if config.n_shared_experts is not None and self.num_fused_shared_experts == 0:
intermediate_size = config.moe_intermediate_size * config.n_shared_experts
# disable tp for shared experts when enable deepep moe
@@ -316,6 +319,20 @@
else {}
),
)
self.shared_experts_is_int8 = (
self.shared_experts.gate_up_proj.weight.dtype == torch.int8
)
self.shared_experts_is_fp8 = (
self.shared_experts.gate_up_proj.weight.dtype == torch.float8_e4m3fn
)
if self.shared_experts_is_fp8:
assert (
self.shared_experts.gate_up_proj.quant_method.quant_config.weight_block_size
== self.shared_experts.down_proj.quant_method.quant_config.weight_block_size
)
self.shared_experts_weight_block_size = (
self.shared_experts.gate_up_proj.quant_method.quant_config.weight_block_size
)

self.top_k = config.num_experts_per_tok

@@ -394,6 +411,11 @@ def forward_normal_dual_stream(self, hidden_states: torch.Tensor) -> torch.Tensor:
return final_hidden_states

def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
if hasattr(self, "shared_experts") and getattr(
self.shared_experts.gate_up_proj, "use_intel_amx_backend", False
):
return self.forward_cpu(hidden_states)

shared_output = self._forward_shared_experts(hidden_states)
# router_logits: (num_tokens, n_experts)
router_logits = self.gate(hidden_states)
@@ -409,6 +431,59 @@ def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
return final_hidden_states

def forward_cpu(self, hidden_states: torch.Tensor) -> torch.Tensor:
# router_logits: (num_tokens, n_experts)
router_logits = self.gate(hidden_states)
fused_experts_out = self.experts(
hidden_states=hidden_states, router_logits=router_logits
)

assert getattr(
self.shared_experts.gate_up_proj, "use_intel_amx_backend", False
) == getattr(self.shared_experts.down_proj, "use_intel_amx_backend", False)
# [Note] inplace should be False in fused_experts.
# If inplace is True in fused_experts (self.experts), hidden_states will be changed after fused_experts
# While hidden_states is still needed in shared_expert.
final_hidden_states = torch.ops.sgl_kernel.shared_expert_cpu(
hidden_states,
self.shared_experts.gate_up_proj.weight,
self.shared_experts.down_proj.weight,
fused_experts_out,
self.routed_scaling_factor,
True, # inplace
self.shared_experts_is_int8, # use_int8_w8a8
self.shared_experts_is_fp8, # use_fp8_w8a16
(
self.shared_experts.gate_up_proj.weight_scale
if self.shared_experts_is_int8
else (
self.shared_experts.gate_up_proj.weight_scale_inv
if self.shared_experts_is_fp8
else None
)
), # w1_scale
(
self.shared_experts.down_proj.weight_scale
if self.shared_experts_is_int8
else (
self.shared_experts.down_proj.weight_scale_inv
if self.shared_experts_is_fp8
else None
)
), # w2_scale
(
self.shared_experts_weight_block_size
if self.shared_experts_is_fp8
else None
), # block_size
None, # a1_scale
None, # a2_scale
True, # is_vnni
)
if self.tp_size > 1:
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
return final_hidden_states

def forward_deepep(
self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
) -> torch.Tensor:
@@ -2107,6 +2182,14 @@ def post_load_weights(self, is_nextn=False, weight_names=None):
)
if _is_hip:
self_attn.w_scale *= 2.0
# TODO: remove this after adding FP8 support in bmm cpu kernel
if _is_cpu and _is_cpu_amx_available and w.dtype == torch.float8_e4m3fn:
self_attn.w_kc = (
self_attn.w_kc.to(torch.bfloat16) * self_attn.w_scale
)
self_attn.w_vc = (
self_attn.w_vc.to(torch.bfloat16) * self_attn.w_scale
)
else:
num_tiles_k = self_attn.qk_nope_head_dim // weight_block_size[1]
num_tiles_n = self_attn.v_head_dim // weight_block_size[0]
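The post_load_weights change above works around the missing FP8 path in the CPU bmm kernel by dequantizing w_kc / w_vc to bf16 at load time, using the single w_scale applied in this branch. In isolation, with hypothetical shapes and an illustrative scale value, the conversion amounts to:

import torch

# Hypothetical shapes; the real tensors come from the attention projection weights.
w_kc_fp8 = torch.randn(16, 128, 512).to(torch.float8_e4m3fn)
w_scale = 0.0123  # illustrative per-tensor scale

# Mirrors the branch above: dequantize once so the CPU bmm kernel only ever
# sees bf16 weights until an FP8 bmm path is added.
w_kc_bf16 = w_kc_fp8.to(torch.bfloat16) * w_scale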