Merged
36 commits
dd54bcf
Use fused_experts_cpu and add weight packing (#8)
chunyuan-w Mar 12, 2025
c1781be
switch to weight_packed_linear if cpu_has_amx_support (#11)
chunyuan-w Mar 14, 2025
1e4e60b
Switch to weight_packed_linear for MoEGate and lm_head (#16)
chunyuan-w Mar 19, 2025
238d29f
Replace torch.bmm in forward_absorb with sgl_kernel.cpu.bmm (#21)
chunyuan-w Mar 26, 2025
1f77ae6
don't use c++ kernel if apply_router_weight_on_input is True
chunyuan-w Jun 5, 2025
27e6501
Integrate qkv_proj_with_rope (#34)
chunyuan-w Apr 2, 2025
16157b3
update API for fused_qkv_a_proj_with_mqa
chunyuan-w Jun 6, 2025
146470f
revert changes to bmm
chunyuan-w Jun 6, 2025
a6253a9
update qkv_proj OP name
chunyuan-w Jun 6, 2025
6c98ba8
refine comment
chunyuan-w Jun 6, 2025
f687f62
only pack weight if using the fused_qkv_proj_with_rope kernel
chunyuan-w Jun 6, 2025
21b04a6
remove dead code
chunyuan-w Jun 11, 2025
d170e11
Merge branch 'main' into chunyuan/pr_frontend_moe
zhyncs Jun 12, 2025
7e43b73
update qkv_proj OP name
chunyuan-w Jun 12, 2025
19078ee
fix if the obj does not have the use_intel_amx_backend attr
chunyuan-w Jun 12, 2025
db7fcc9
cast bias to FP32 in process weight after load
gau-nernst Apr 9, 2025
cea74fd
fix when module.bias is None
chunyuan-w Jun 17, 2025
7411cf1
Merge branch 'main' into chunyuan/pr_frontend_moe
chunyuan-w Jun 17, 2025
6b1500b
Merge branch 'main' into chunyuan/pr_frontend_moe
chunyuan-w Jun 17, 2025
fb15b38
copy __dict__ from original weight param to packed weight param
chunyuan-w Jun 18, 2025
88426bb
only pack LMHead weight if it's not quantized
chunyuan-w Jun 18, 2025
e5aa0ca
Only pack w_kc and w_vc for CPU
chunyuan-w Jun 18, 2025
9542064
Merge branch 'main' into chunyuan/pr_frontend_moe
chunyuan-w Jun 18, 2025
08d3140
revert the debug change
chunyuan-w Jun 18, 2025
726289d
Merge branch 'main' into chunyuan/pr_frontend_moe
zhyncs Jun 18, 2025
410aa32
use _is_cpu and _is_cpu_amx_available to check device
chunyuan-w Jun 18, 2025
056b7ef
Merge branch 'main' into chunyuan/pr_frontend_moe
zhyncs Jun 19, 2025
ea8fd2c
cast tok_weights to fp32 for llama4
chunyuan-w Jun 19, 2025
39c684c
add comment for topk_weights
chunyuan-w Jun 19, 2025
c3990ec
fix gemm kernel when N is small
chunyuan-w Jun 19, 2025
0c4c6b5
Merge branch 'main' into chunyuan/pr_frontend_moe
zhyncs Jun 20, 2025
14908f1
move import to the top
chunyuan-w Jun 20, 2025
04d5e39
only set self.quant_method or call _process_weight_after_loading if _…
chunyuan-w Jun 20, 2025
4526d13
Merge branch 'main' into chunyuan/pr_frontend_moe
zhyncs Jun 20, 2025
f5ef0f2
don't pack weight or use intel amx backend if any weight of this modu…
chunyuan-w Jun 24, 2025
71c48a4
Merge branch 'main' into chunyuan/pr_frontend_moe
zhyncs Jun 25, 2025
19 changes: 18 additions & 1 deletion python/sglang/srt/layers/linear.py
@@ -30,7 +30,12 @@
QuantizationConfig,
QuantizeMethodBase,
)
from sglang.srt.utils import set_weight_attrs
from sglang.srt.utils import (
_process_weight_after_loading,
cpu_has_amx_support,
is_cpu,
set_weight_attrs,
)

logger = logging.getLogger(__name__)

@@ -52,6 +57,9 @@
"IPEXAWQLinearMethod",
]

_is_cpu_amx_available = cpu_has_amx_support()
_is_cpu = is_cpu()


def adjust_marlin_shard(param, shard_size, shard_offset):
marlin_tile_size = getattr(param, "marlin_tile_size", None)
@@ -165,13 +173,22 @@ def create_weights(
layer.register_parameter("weight", weight)
set_weight_attrs(weight, extra_weight_attrs)

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
if _is_cpu and _is_cpu_amx_available:
_process_weight_after_loading(layer, ["weight"])

def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:

if getattr(layer, "use_intel_amx_backend", False):
Collaborator:
Let's make getattr(sth, "use_intel_amx_backend", False) a method in utils.py

Contributor Author:
Fixed in #7647
return torch.ops.sgl_kernel.weight_packed_linear(
x, layer.weight, bias, True # is_vnni
)

return F.linear(x, layer.weight, bias)


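The review comment above suggests wrapping the repeated getattr(sth, "use_intel_amx_backend", False) check in a utils.py helper; per the author it was addressed in a follow-up (#7647). A minimal sketch of what such a helper could look like (the name and location are assumptions, not the merged implementation):

# Hypothetical helper sketched from the review suggestion above; the actual
# change landed in a follow-up PR (#7647) and may use a different name or location.
def use_intel_amx_backend(layer) -> bool:
    """Return True if this layer has been marked to dispatch to the Intel AMX kernels."""
    return getattr(layer, "use_intel_amx_backend", False)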
15 changes: 12 additions & 3 deletions python/sglang/srt/layers/logits_processor.py
@@ -442,11 +442,20 @@ def _get_logits(
dp_gather_replicate(hidden_states, local_hidden_states, logits_metadata)

if hasattr(lm_head, "weight"):
logits = torch.matmul(
hidden_states.to(lm_head.weight.dtype), lm_head.weight.T
)
if getattr(lm_head, "use_intel_amx_backend", False):
logits = torch.ops.sgl_kernel.weight_packed_linear(
hidden_states.to(lm_head.weight.dtype),
lm_head.weight,
None, # bias
True, # is_vnni
)
else:
logits = torch.matmul(
hidden_states.to(lm_head.weight.dtype), lm_head.weight.T
)
else:
# GGUF models
# TODO: use weight_packed_linear for GGUF models
logits = lm_head.quant_method.apply(lm_head, hidden_states, embedding_bias)

if self.logit_scale is not None:
7 changes: 7 additions & 0 deletions python/sglang/srt/layers/moe/fused_moe_native.py
@@ -77,8 +77,15 @@ def moe_forward_native(
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
inplace: bool = True,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
) -> torch.Tensor:

if apply_router_weight_on_input:
raise NotImplementedError()

topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
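The native MoE path now rejects apply_router_weight_on_input (matching the commit "don't use c++ kernel if apply_router_weight_on_input is True"). For context, that flag conventionally means the router weights scale the expert inputs rather than the expert outputs; the sketch below only illustrates that idea under a top_k == 1 assumption and is not the sgl_kernel implementation.

import torch

# Illustrative sketch only: with apply_router_weight_on_input=True, the routing
# weight multiplies the token activations before the expert GEMMs instead of the
# expert outputs afterwards. Folding it into the input this way is only
# straightforward when top_k == 1.
def apply_router_weight_to_input(x: torch.Tensor, topk_weights: torch.Tensor) -> torch.Tensor:
    # x: [num_tokens, hidden_size], topk_weights: [num_tokens, top_k]
    assert topk_weights.shape[1] == 1, "this simple folding assumes top_k == 1"
    return x * topk_weights.to(x.dtype)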
87 changes: 73 additions & 14 deletions python/sglang/srt/layers/moe/fused_moe_triton/layer.py
@@ -18,7 +18,14 @@
QuantizationConfig,
QuantizeMethodBase,
)
from sglang.srt.utils import get_bool_env_var, is_hip, set_weight_attrs
from sglang.srt.utils import (
_process_weight_after_loading,
cpu_has_amx_support,
get_bool_env_var,
is_cpu,
is_hip,
set_weight_attrs,
)

if torch.cuda.is_available():
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
@@ -28,6 +35,8 @@
import logging

_is_hip = is_hip()
_is_cpu_amx_available = cpu_has_amx_support()
_is_cpu = is_cpu()
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip

if _use_aiter:
@@ -117,6 +126,11 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
requires_grad=False,
)
torch.cuda.empty_cache()

# Pack weight to get better performance on CPU
if _is_cpu and _is_cpu_amx_available:
_process_weight_after_loading(layer, ["w13_weight", "w2_weight"])

return

def apply(
@@ -248,19 +262,64 @@ def forward_cpu(
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
) -> torch.Tensor:
return moe_forward_native(
layer,
x,
use_grouped_topk,
top_k,
router_logits,
renormalize,
topk_group,
num_expert_group,
num_fused_shared_experts,
custom_routing_function,
correction_bias,
)
assert activation == "silu", f"activation = {activation} is not supported."

if (
getattr(layer, "use_intel_amx_backend", False)
and not apply_router_weight_on_input
):
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
num_fused_shared_experts=num_fused_shared_experts,
custom_routing_function=custom_routing_function,
correction_bias=correction_bias,
routed_scaling_factor=routed_scaling_factor,
)

# TODO: support apply_router_weight_on_input in the fused_experts_cpu kernel
return torch.ops.sgl_kernel.fused_experts_cpu(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights.to(
torch.float
), # TODO: the topk_weights of llama4 are computed via Llama4MoE:custom_routing_function as bfloat16, while the kernel requires float32
topk_ids,
True, # inplace
False, # use_int8_w8a8
False, # use_fp8_w8a16
None, # w1_scale
None, # w2_scale
None, # block_size
None, # a1_scale
None, # a2_scale
True, # is_vnni
)
else:
return moe_forward_native(
layer,
x,
use_grouped_topk,
top_k,
router_logits,
renormalize,
topk_group,
num_expert_group,
num_fused_shared_experts,
custom_routing_function,
correction_bias,
activation,
apply_router_weight_on_input,
inplace,
no_combine,
routed_scaling_factor,
)

def forward_tpu(self, *args, **kwargs) -> torch.Tensor:
raise NotImplementedError("The TPU backend currently does not support MoE.")
15 changes: 14 additions & 1 deletion python/sglang/srt/layers/vocab_parallel_embedding.py
@@ -20,10 +20,18 @@
QuantizeMethodBase,
method_has_implemented_embedding,
)
from sglang.srt.utils import set_weight_attrs
from sglang.srt.utils import (
PackWeightMethod,
cpu_has_amx_support,
is_cpu,
set_weight_attrs,
)

DEFAULT_VOCAB_PADDING_SIZE = 64

_is_cpu_amx_available = cpu_has_amx_support()
_is_cpu = is_cpu()


class UnquantizedEmbeddingMethod(QuantizeMethodBase):
"""Unquantized method for embeddings."""
@@ -549,6 +557,11 @@ def __init__(
use_presharded_weights=use_presharded_weights,
)
self.quant_config = quant_config

# We only support packing the LMHead weight if it's not quantized. For an LMHead with quant_config, the weight name will be "qweight"
if self.quant_config is None and _is_cpu and _is_cpu_amx_available:
self.quant_method = PackWeightMethod(weight_names=["weight"])

if bias:
self.bias = Parameter(
torch.empty(self.num_embeddings_per_partition, dtype=params_dtype)
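The LMHead change above attaches PackWeightMethod(weight_names=["weight"]) as the quant method when running on a CPU with AMX support and no quantization config. A rough sketch of the shape such a method can take, reusing the _process_weight_after_loading helper imported elsewhere in this PR (the real PackWeightMethod in sglang.srt.utils may differ):

import torch

from sglang.srt.utils import _process_weight_after_loading

# Rough sketch of a weight-packing "quant method"; not the exact class shipped
# in sglang.srt.utils.
class PackWeightMethodSketch:
    def __init__(self, weight_names):
        self.weight_names = weight_names

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        # Repack the named weights into the layout expected by the sgl_kernel
        # CPU GEMMs so later forwards can skip the conversion.
        _process_weight_after_loading(layer, self.weight_names)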