[CPU] Use weight_packed_linear kernel for linear, MoEGate and lm_head #6657
lm_head logits computation (`_get_logits`):

```diff
@@ -454,11 +454,20 @@ def _get_logits(
             dp_gather_replicate(hidden_states, local_hidden_states, logits_metadata)
 
         if hasattr(lm_head, "weight"):
-            logits = torch.matmul(
-                hidden_states.to(lm_head.weight.dtype), lm_head.weight.T
-            )
+            if lm_head.use_intel_amx_backend:
+                logits = torch.ops.sgl_kernel.weight_packed_linear(
+                    hidden_states.to(lm_head.weight.dtype),
+                    lm_head.weight,
+                    None,  # bias
+                    True,  # is_vnni
+                )
+            else:
+                logits = torch.matmul(
+                    hidden_states.to(lm_head.weight.dtype), lm_head.weight.T
+                )
         else:
             # GGUF models
+            # TODO: use weight_packed_linear for GGUF models
             logits = lm_head.quant_method.apply(lm_head, hidden_states, embedding_bias)
 
         if self.logit_scale is not None:
```
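For reference, a minimal standalone sketch of the dispatch this hunk adds. The helper name `lm_head_logits` and the toy `torch.nn.Linear` module are illustrative only; the `use_intel_amx_backend` flag and the `weight_packed_linear(input, weight, bias, is_vnni)` call follow the diff above, and the fallback is the same `x @ W.T` matmul.

```python
import torch

def lm_head_logits(hidden_states: torch.Tensor, lm_head: torch.nn.Module) -> torch.Tensor:
    # Cast activations to the weight dtype, as the diff does.
    x = hidden_states.to(lm_head.weight.dtype)
    if getattr(lm_head, "use_intel_amx_backend", False):
        # AMX path: the weight was repacked at load time; is_vnni=True tells the
        # kernel the weight is already in its packed (VNNI) layout.
        return torch.ops.sgl_kernel.weight_packed_linear(x, lm_head.weight, None, True)
    # Fallback path: mathematically the same result, logits = x @ W^T.
    return torch.matmul(x, lm_head.weight.T)

# Usage with the fallback path only (no sgl_kernel needed):
lm_head = torch.nn.Linear(16, 32, bias=False, dtype=torch.bfloat16)
x = torch.randn(4, 16)
print(lm_head_logits(x, lm_head).shape)  # torch.Size([4, 32])
```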
Fused-MoE method (imports):

```diff
@@ -18,7 +18,12 @@
     QuantizationConfig,
     QuantizeMethodBase,
 )
-from sglang.srt.utils import get_bool_env_var, is_hip, set_weight_attrs
+from sglang.srt.utils import (
+    _process_weight_after_loading,
+    get_bool_env_var,
+    is_hip,
+    set_weight_attrs,
+)
 
 if torch.cuda.is_available():
     from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
```
Fused-MoE method (`process_weights_after_loading`):

```diff
@@ -115,6 +120,9 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
             requires_grad=False,
         )
         torch.cuda.empty_cache()
+
+        _process_weight_after_loading(layer, ["w13_weight", "w2_weight"])
+
         return
 
     def apply(
```
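`_process_weight_after_loading` itself is not shown in this diff. A rough control-flow sketch of its assumed behavior follows: it replaces each named weight with a packed copy and marks the layer so the forward path can take the AMX kernel. The `repack_for_amx` helper below is a hypothetical stand-in, not an sgl_kernel API.

```python
import torch

def repack_for_amx(w: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for the real packing op; identity here so the sketch
    # stays runnable. The actual kernel expects its own packed (VNNI) layout.
    return w.contiguous()

def process_weight_after_loading(module: torch.nn.Module, weight_names: list) -> None:
    # Assumed behavior of the helper called in the hunk above: repack each listed
    # weight in place and flag the module for the torch.ops.sgl_kernel fast path.
    for name in weight_names:
        w = getattr(module, name)
        packed = torch.nn.Parameter(repack_for_amx(w.data), requires_grad=False)
        setattr(module, name, packed)
    module.use_intel_amx_backend = True
```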
Fused-MoE method (`forward_cpu`):

```diff
@@ -236,18 +244,58 @@ def forward_cpu(
         correction_bias: Optional[torch.Tensor] = None,
         inplace: bool = True,
     ) -> torch.Tensor:
-        return moe_forward_native(
-            layer,
-            x,
-            use_grouped_topk,
-            top_k,
-            router_logits,
-            renormalize,
-            topk_group,
-            num_expert_group,
-            custom_routing_function,
-            correction_bias,
-        )
+        assert activation == "silu", f"activation = {activation} is not supported."
+
+        # TODO: rebase after #6441 lands
+        if layer.use_intel_amx_backend:
+            # if cpu_has_amx_support(): ---> #6441
+            topk_weights, topk_ids = select_experts(
+                hidden_states=x,
+                router_logits=router_logits,
+                use_grouped_topk=use_grouped_topk,
+                top_k=top_k,
+                renormalize=renormalize,
+                topk_group=topk_group,
+                num_expert_group=num_expert_group,
+                custom_routing_function=custom_routing_function,
+                correction_bias=correction_bias,
+                routed_scaling_factor=routed_scaling_factor,
+            )
+
+            return torch.ops.sgl_kernel.fused_experts_cpu(
+                x,
+                layer.w13_weight,
+                layer.w2_weight,
+                topk_weights,
+                topk_ids,
+                True,  # inplace
+                False,  # use_int8_w8a8
+                False,  # use_fp8_w8a16
+                None,  # w1_scale
+                None,  # w2_scale
+                None,  # block_size
+                None,  # a1_scale
+                None,  # a2_scale
+                True,  # is_vnni
+            )
+        else:
+            return moe_forward_native(
+                layer,
+                x,
+                use_grouped_topk,
+                top_k,
+                router_logits,
+                renormalize,
+                topk_group,
+                num_expert_group,
+                custom_routing_function,
+                correction_bias,
+                activation,
+                apply_router_weight_on_input,
+                inplace,
+                no_combine,
+                routed_scaling_factor,
+            )
```
Contributor comment on lines +282 to +298:

The call to `moe_forward_native` in the `else` branch passes its trailing arguments positionally. Its signature is:

```python
def moe_forward_native(
    layer: torch.nn.Module,
    x: torch.Tensor,
    use_grouped_topk: bool,
    # ... other params ...
    activation: str = "silu",
    routed_scaling_factor: Optional[float] = None,
) -> torch.Tensor:
```

Could you clarify if the positional arguments (`activation`, `apply_router_weight_on_input`, `inplace`, `no_combine`, `routed_scaling_factor`) line up with the parameters they are meant for?
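A toy illustration of the concern: with positional arguments, a value binds to whatever parameter sits at that position, so an argument intended for a later keyword can silently land in `activation`. The function `f` below is just a stand-in, not the real signature.

```python
def f(x, activation="silu", routed_scaling_factor=None):
    return activation, routed_scaling_factor

# Positional: 2.5 lands in `activation`, even if it was meant as the scaling factor.
print(f(1, 2.5))                        # (2.5, None)
# Keyword arguments keep the binding explicit and robust to parameter reordering.
print(f(1, routed_scaling_factor=2.5))  # ('silu', 2.5)
```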
Remaining context of the same hunk:

```diff
 
     def forward_tpu(self, *args, **kwargs) -> torch.Tensor:
         raise NotImplementedError("The TPU backend currently does not support MoE.")
```
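For intuition about what the new CPU path computes, here is an unfused pure-PyTorch reference of the `select_experts` + `fused_experts_cpu` combination in the `forward_cpu` hunk above: softmax top-k routing, then a SiLU-gated MLP per expert, with `w13` holding the concatenated gate/up projections. The tensor shapes and the gate-before-up split are assumptions for illustration, and only the plain (non-grouped) routing case is shown.

```python
import torch
import torch.nn.functional as F

def moe_reference(x, w13, w2, router_logits, top_k, renormalize=True):
    # x: [T, H], w13: [E, 2*I, H] (gate and up stacked), w2: [E, H, I], router_logits: [T, E]
    topk_weights, topk_ids = torch.topk(F.softmax(router_logits, dim=-1), top_k, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    out = torch.zeros_like(x)
    for e in range(w13.shape[0]):
        token_idx, slot = (topk_ids == e).nonzero(as_tuple=True)
        if token_idx.numel() == 0:
            continue
        h = x[token_idx] @ w13[e].T          # [t, 2*I]: fused gate/up projection
        gate, up = h.chunk(2, dim=-1)
        h = F.silu(gate) * up                # SwiGLU-style activation ("silu")
        h = h @ w2[e].T                      # [t, H]: down projection
        out.index_add_(0, token_idx, h * topk_weights[token_idx, slot, None])
    return out

# Tiny smoke test: 4 tokens, hidden 8, intermediate 16, 3 experts, top-2 routing.
T, H, I, E, K = 4, 8, 16, 3, 2
out = moe_reference(torch.randn(T, H), torch.randn(E, 2 * I, H),
                    torch.randn(E, H, I), torch.randn(T, E), K)
print(out.shape)  # torch.Size([4, 8])
```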
MoEGate (imports):

```diff
@@ -90,6 +90,7 @@
 from sglang.srt.utils import (
     BumpAllocator,
     DeepEPMode,
+    PackWeightMethod,
     add_prefix,
     get_bool_env_var,
     get_int_env_var,
```
MoEGate (`__init__` and `forward`):

```diff
@@ -201,8 +202,17 @@ def __init__(
             )
         else:
             self.e_score_correction_bias = None
+        self.quant_method = PackWeightMethod(weight_names=["weight"])
 
     def forward(self, hidden_states):
+        if self.use_intel_amx_backend:
+            return torch.ops.sgl_kernel.weight_packed_linear(
+                hidden_states,
+                self.weight,
+                None,  # bias
+                True,  # is_vnni
+            )
+
         logits = F.linear(hidden_states, self.weight, None)
         return logits
```
Contributor comment:

The `is_vnni` parameter for `weight_packed_linear` is hardcoded to `True`. Could you confirm if this is always the case when `use_intel_amx_backend` is true? It's likely correct given that AMX usage often implies VNNI-packed weights, but a confirmation or a brief comment explaining this assumption would be helpful for future maintainability.
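Not part of the PR, but a toy illustration of what a VNNI-style layout refers to: the reduction (K) dimension is grouped into small blocks (pairs for bf16, groups of four for int8) so each AMX/VNNI instruction consumes a full block per lane. The exact layout `weight_packed_linear` expects is defined by sgl_kernel and may differ from this sketch.

```python
import torch

def pack_k_pairs(w: torch.Tensor) -> torch.Tensor:
    # Group the K (reduction) dimension of an [N, K] weight into pairs: [N, K/2, 2].
    n, k = w.shape
    assert k % 2 == 0, "VNNI-style packing needs an even K"
    return w.reshape(n, k // 2, 2).contiguous()

def unpack_k_pairs(packed: torch.Tensor) -> torch.Tensor:
    n, k_half, _ = packed.shape
    return packed.reshape(n, k_half * 2)

w = torch.randn(4, 6, dtype=torch.bfloat16)
# Packing only changes the memory layout, not the values the matmul consumes.
assert torch.equal(unpack_k_pairs(pack_k_pairs(w)), w)
```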