15 changes: 14 additions & 1 deletion python/sglang/srt/layers/linear.py
@@ -30,7 +30,12 @@
QuantizationConfig,
QuantizeMethodBase,
)
from sglang.srt.utils import set_weight_attrs
from sglang.srt.utils import (
_process_weight_after_loading,
cpu_has_amx_support,
prepack_weight_if_needed,
set_weight_attrs,
)

logger = logging.getLogger(__name__)

@@ -165,13 +170,21 @@ def create_weights(
layer.register_parameter("weight", weight)
set_weight_attrs(weight, extra_weight_attrs)

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
_process_weight_after_loading(layer, ["weight"])

def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:

if layer.use_intel_amx_backend:
return torch.ops.sgl_kernel.weight_packed_linear(
x, layer.weight, bias, True # is_vnni
)
Comment on lines +183 to +186 (Contributor, medium):

The is_vnni parameter for weight_packed_linear is hardcoded to True. Could you confirm whether this is always the case when use_intel_amx_backend is true? It is likely correct, given that AMX usage often implies VNNI-packed weights, but a confirmation or a brief comment explaining this assumption would be helpful for future maintainability.


return F.linear(x, layer.weight, bias)


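Following up on the is_vnni comment above: a minimal sketch of how the assumption could be recorded where the flag is consumed. The helper name is hypothetical, and it assumes, as _process_weight_after_loading suggests, that prepack_weight_if_needed leaves weights in VNNI layout whenever the AMX backend is enabled:

import torch
import torch.nn.functional as F
from typing import Optional

def _amx_or_dense_linear(
    layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None
) -> torch.Tensor:
    # use_intel_amx_backend is only set after _process_weight_after_loading has
    # prepacked the weight for the CPU backend, so is_vnni=True is safe here
    # (assumption based on this PR, not verified against sgl_kernel docs).
    if getattr(layer, "use_intel_amx_backend", False):
        return torch.ops.sgl_kernel.weight_packed_linear(x, layer.weight, bias, True)  # is_vnni
    return F.linear(x, layer.weight, bias)

Keeping the reasoning next to the hardcoded True makes the coupling between prepacking and the kernel flag visible at the call site.
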
15 changes: 12 additions & 3 deletions python/sglang/srt/layers/logits_processor.py
@@ -454,11 +454,20 @@ def _get_logits(
dp_gather_replicate(hidden_states, local_hidden_states, logits_metadata)

if hasattr(lm_head, "weight"):
logits = torch.matmul(
hidden_states.to(lm_head.weight.dtype), lm_head.weight.T
)
if lm_head.use_intel_amx_backend:
logits = torch.ops.sgl_kernel.weight_packed_linear(
hidden_states.to(lm_head.weight.dtype),
lm_head.weight,
None, # bias
True, # is_vnni
)
Comment on lines +457 to +463 (Contributor, medium):

Similar to other weight_packed_linear calls, is_vnni is hardcoded to True. Is this assumption universally valid when lm_head.use_intel_amx_backend is true? A brief comment clarifying this would be beneficial.

else:
logits = torch.matmul(
hidden_states.to(lm_head.weight.dtype), lm_head.weight.T
)
else:
# GGUF models
# TODO: use weight_packed_linear for GGUF models
logits = lm_head.quant_method.apply(lm_head, hidden_states, embedding_bias)

if self.logit_scale is not None:
74 changes: 61 additions & 13 deletions python/sglang/srt/layers/moe/fused_moe_triton/layer.py
@@ -18,7 +18,12 @@
QuantizationConfig,
QuantizeMethodBase,
)
from sglang.srt.utils import get_bool_env_var, is_hip, set_weight_attrs
from sglang.srt.utils import (
_process_weight_after_loading,
get_bool_env_var,
is_hip,
set_weight_attrs,
)

if torch.cuda.is_available():
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
@@ -115,6 +120,9 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
requires_grad=False,
)
torch.cuda.empty_cache()

_process_weight_after_loading(layer, ["w13_weight", "w2_weight"])

return

def apply(
@@ -236,18 +244,58 @@ def forward_cpu(
correction_bias: Optional[torch.Tensor] = None,
inplace: bool = True,
) -> torch.Tensor:
return moe_forward_native(
layer,
x,
use_grouped_topk,
top_k,
router_logits,
renormalize,
topk_group,
num_expert_group,
custom_routing_function,
correction_bias,
)
assert activation == "silu", f"activation = {activation} is not supported."
Comment (Contributor, medium):

The assert activation == "silu" restricts this optimized forward_cpu path (and consequently the AMX path for MoE) to SiLU activation. Is this an intended limitation for the initial AMX support, perhaps due to the fused_experts_cpu kernel's capabilities? If so, it might be worth a comment.
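
If the restriction does come from the CPU kernel, it could be documented at the check itself. A sketch only, with a hypothetical helper name; the stated reason (fused_experts_cpu fusing only SiLU) is an assumption this PR does not confirm:

def _check_cpu_moe_activation(activation: str) -> None:
    # Assumption: fused_experts_cpu fuses only the SiLU activation; anything
    # else should fall back to the native MoE path or fail early with a clear message.
    supported = ("silu",)
    assert activation in supported, (
        f"forward_cpu only supports activation in {supported}, got {activation!r}"
    )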


# TODO: rebase after #6441 lands
if layer.use_intel_amx_backend:
# if cpu_has_amx_support(): ---> #6441
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
correction_bias=correction_bias,
routed_scaling_factor=routed_scaling_factor,
)

return torch.ops.sgl_kernel.fused_experts_cpu(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
True, # inplace
False, # use_int8_w8a8
False, # use_fp8_w8a16
None, # w1_scale
None, # w2_scale
None, # block_size
None, # a1_scale
None, # a2_scale
True, # is_vnni
Comment on lines +265 to +279 (Contributor, medium):

The is_vnni parameter for fused_experts_cpu is hardcoded to True. Is this always the correct setting when use_intel_amx_backend is active? This seems consistent with the other AMX kernel calls.

)
else:
return moe_forward_native(
layer,
x,
use_grouped_topk,
top_k,
router_logits,
renormalize,
topk_group,
num_expert_group,
custom_routing_function,
correction_bias,
activation,
apply_router_weight_on_input,
inplace,
no_combine,
routed_scaling_factor,
)
Comment on lines +282 to +298 (Contributor, critical):

The call to moe_forward_native here includes apply_router_weight_on_input, inplace, and no_combine as arguments. However, the moe_forward_native function defined in python/sglang/srt/layers/moe/fused_moe_native.py (as per the full file context) does not seem to accept these parameters.

Its signature is:

def moe_forward_native(
    layer: torch.nn.Module,
    x: torch.Tensor,
    use_grouped_topk: bool,
    # ... other params ...
    activation: str = "silu",
    routed_scaling_factor: Optional[float] = None,
) -> torch.Tensor:

Could you clarify if moe_forward_native's signature is expected to be updated in one of the prerequisite PRs (e.g., #6641)? If not, this call would lead to a runtime error.
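
If the signature is not widened by a prerequisite PR, one way to keep the fallback compatible is to pass only what the quoted signature accepts. A sketch of the else branch inside forward_cpu, reusing the positional arguments already present and dropping apply_router_weight_on_input, inplace, and no_combine:

else:
    return moe_forward_native(
        layer,
        x,
        use_grouped_topk,
        top_k,
        router_logits,
        renormalize,
        topk_group,
        num_expert_group,
        custom_routing_function,
        correction_bias,
        activation=activation,
        routed_scaling_factor=routed_scaling_factor,
    )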


def forward_tpu(self, *args, **kwargs) -> torch.Tensor:
raise NotImplementedError("The TPU backend currently does not support MoE.")
4 changes: 4 additions & 0 deletions python/sglang/srt/layers/vocab_parallel_embedding.py
@@ -549,6 +549,10 @@ def __init__(
use_presharded_weights=use_presharded_weights,
)
self.quant_config = quant_config

from sglang.srt.utils import PackWeightMethod

self.quant_method = PackWeightMethod(weight_names=["weight"])
if bias:
self.bias = Parameter(
torch.empty(self.num_embeddings_per_partition, dtype=params_dtype)
10 changes: 10 additions & 0 deletions python/sglang/srt/models/deepseek_v2.py
@@ -90,6 +90,7 @@
from sglang.srt.utils import (
BumpAllocator,
DeepEPMode,
PackWeightMethod,
add_prefix,
get_bool_env_var,
get_int_env_var,
@@ -201,8 +202,17 @@ def __init__(
)
else:
self.e_score_correction_bias = None
self.quant_method = PackWeightMethod(weight_names=["weight"])

def forward(self, hidden_states):
if self.use_intel_amx_backend:
return torch.ops.sgl_kernel.weight_packed_linear(
hidden_states,
self.weight,
None, # bias
True, # is_vnni
)
Comment on lines +208 to +214 (Contributor, medium):

The is_vnni parameter is hardcoded to True for the weight_packed_linear call. Is this always the case for DeepSeekV2MoEGate when AMX is used? A clarifying comment could be helpful.


logits = F.linear(hidden_states, self.weight, None)
return logits

29 changes: 29 additions & 0 deletions python/sglang/srt/utils.py
@@ -2169,3 +2169,32 @@ def with_value(self, new_value: T):
finally:
assert self._value is new_value
self._value = None


def _process_weight_after_loading(module, weight_names) -> None:
# Prepack weights to get better performance on CPU
devices = {getattr(module, weight_name).device for weight_name in weight_names}
assert len(devices) == 1, f"Expects all weights to be on the same device, but got {devices}"
device = devices.pop()

for weight_name in weight_names:
setattr(
module,
weight_name,
torch.nn.Parameter(
prepack_weight_if_needed(getattr(module, weight_name)),
requires_grad=False,
),
)

module.use_intel_amx_backend = (
device == torch.device("cpu") and cpu_has_amx_support()
)


class PackWeightMethod:
def __init__(self, weight_names):
self.weight_names = weight_names

def process_weights_after_loading(self, module) -> None:
_process_weight_after_loading(module, self.weight_names)
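
For reference, a minimal usage sketch of the new helpers. The PackedGate module below is hypothetical, and it assumes the model loader calls quant_method.process_weights_after_loading(module) once the weights are loaded, as it does for the layers touched in this PR:

import torch
import torch.nn.functional as F
from sglang.srt.utils import PackWeightMethod

class PackedGate(torch.nn.Module):
    def __init__(self, hidden_size: int, num_experts: int):
        super().__init__()
        # bfloat16 is assumed here only for illustration.
        self.weight = torch.nn.Parameter(
            torch.empty(num_experts, hidden_size, dtype=torch.bfloat16),
            requires_grad=False,
        )
        # After loading, process_weights_after_loading prepacks the weight and
        # sets use_intel_amx_backend on this module (True only on AMX-capable CPUs).
        self.quant_method = PackWeightMethod(weight_names=["weight"])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if getattr(self, "use_intel_amx_backend", False):
            return torch.ops.sgl_kernel.weight_packed_linear(
                x, self.weight, None, True  # bias=None, is_vnni
            )
        return F.linear(x, self.weight)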