86 changes: 86 additions & 0 deletions python/sglang/srt/layers/amx_utils.py
@@ -0,0 +1,86 @@
import logging

import torch

from sglang.srt.utils import cpu_has_amx_support

logger = logging.getLogger(__name__)


def amx_process_weight_after_loading(weight):
    if weight.device != torch.device("cpu"):
        return weight
    if not cpu_has_amx_support():
        return weight

    return torch.ops.sgl_kernel.convert_weight_packed(weight)


# TODO: currently the gemm kernel has the following requirements:
# OC % TILE_N == 0, where TILE_N = 16
# IC % TILE_K == 0, where TILE_K = 32
def dim_is_supported(weight):
    TILE_N = 16
    TILE_K = 32
    ndim = weight.ndim
    OC = weight.size(1) if ndim == 3 else weight.size(0)
    IC = weight.size(2) if ndim == 3 else weight.size(1)
    return OC % TILE_N == 0 and IC % TILE_K == 0


def _amx_process_weight_after_loading(
    module, weight_names, transpose_dims=None
) -> None:
    # Pack weights to get better performance on CPU.
    devices = {getattr(module, weight_name).device for weight_name in weight_names}
    assert len(devices) == 1, "Expected all weights to be on the same device"
    device = devices.pop()

    if transpose_dims:
        assert len(weight_names) == len(
            transpose_dims
        ), "len(weight_names) should be equal to len(transpose_dims)"

    for i, weight_name in enumerate(weight_names):
        weight_tensor = getattr(module, weight_name)

        if transpose_dims and transpose_dims[i]:
            weight_tensor = weight_tensor.transpose(*transpose_dims[i])

        # Don't pack the weight or use the Intel AMX backend if any weight of this module has an unsupported dim.
        if not dim_is_supported(weight_tensor):
            logger.warning(
                f"Unsupported dimension for prepacking for weight '{weight_name}' with shape {weight_tensor.shape} in {module}. "
                f"The derived (OC, IC) dimensions must be divisible by (16, 32)."
            )
            module.use_intel_amx_backend = False
            return

        packed_weight = torch.nn.Parameter(
            amx_process_weight_after_loading(weight_tensor),
            requires_grad=False,
        )
        packed_weight.__dict__ = weight_tensor.__dict__
        setattr(module, weight_name, packed_weight)

    module.use_intel_amx_backend = (
        device == torch.device("cpu") and cpu_has_amx_support()
    )

    if (
        module.use_intel_amx_backend
        and hasattr(module, "bias")
        and module.bias is not None
    ):
        module.bias = torch.nn.Parameter(module.bias.data.float(), requires_grad=False)


class PackWeightMethod:
    def __init__(self, weight_names, transpose_dims=None):
        self.weight_names = weight_names
        self.transpose_dims = transpose_dims

    def process_weights_after_loading(self, module) -> None:
        _amx_process_weight_after_loading(
            module, self.weight_names, self.transpose_dims
        )
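
For orientation, here is a minimal sketch of how this helper is meant to be wired up. The DummyLinear module and its shapes are hypothetical; the real callers are the quantization methods in the diffs below, and on a machine with AMX support the packing additionally assumes the sgl_kernel extension is installed.

import torch

from sglang.srt.layers.amx_utils import (
    PackWeightMethod,
    _amx_process_weight_after_loading,
)


class DummyLinear(torch.nn.Module):
    # Hypothetical module; OC=32 and IC=64 satisfy the TILE_N/TILE_K constraints.
    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(32, 64), requires_grad=False)
        self.bias = None


# Direct call, as the linear/MoE quantization methods below do:
layer = DummyLinear()
_amx_process_weight_after_loading(layer, ["weight"])
print(layer.use_intel_amx_backend)  # True only on a CPU with AMX support

# Equivalent, via the wrapper used for modules such as vocab-parallel embeddings:
layer = DummyLinear()
PackWeightMethod(weight_names=["weight"]).process_weights_after_loading(layer)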
7 changes: 4 additions & 3 deletions python/sglang/srt/layers/linear.py
@@ -17,6 +17,7 @@
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce,
)
from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
from sglang.srt.layers.parameter import (
BasevLLMParameter,
BlockQuantScaleParameter,
@@ -31,10 +32,10 @@
QuantizeMethodBase,
)
from sglang.srt.utils import (
_process_weight_after_loading,
cpu_has_amx_support,
is_cpu,
set_weight_attrs,
use_intel_amx_backend,
)

logger = logging.getLogger(__name__)
@@ -175,7 +176,7 @@ def create_weights(

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
if _is_cpu and _is_cpu_amx_available:
_process_weight_after_loading(layer, ["weight"])
_amx_process_weight_after_loading(layer, ["weight"])

def apply(
self,
@@ -184,7 +185,7 @@ def apply(
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:

if getattr(layer, "use_intel_amx_backend", False):
if use_intel_amx_backend(layer):
return torch.ops.sgl_kernel.weight_packed_linear(
x, layer.weight, bias, True # is_vnni
)
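
The use_intel_amx_backend helper imported from sglang.srt.utils is not shown in this diff. Judging from the expressions it replaces here and in the files below, it presumably reduces to something like the following sketch (an assumption, not the actual sglang implementation):

def use_intel_amx_backend(layer) -> bool:
    # True only if the module was flagged by _amx_process_weight_after_loading;
    # defaults to False for modules that were never packed.
    return getattr(layer, "use_intel_amx_backend", False)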
4 changes: 2 additions & 2 deletions python/sglang/srt/layers/logits_processor.py
@@ -42,7 +42,7 @@
ForwardBatch,
ForwardMode,
)
from sglang.srt.utils import dump_to_file
from sglang.srt.utils import dump_to_file, use_intel_amx_backend

logger = logging.getLogger(__name__)

@@ -442,7 +442,7 @@ def _get_logits(
dp_gather_replicate(hidden_states, local_hidden_states, logits_metadata)

if hasattr(lm_head, "weight"):
if getattr(lm_head, "use_intel_amx_backend", False):
if use_intel_amx_backend(lm_head):
logits = torch.ops.sgl_kernel.weight_packed_linear(
hidden_states.to(lm_head.weight.dtype),
lm_head.weight,
10 changes: 4 additions & 6 deletions python/sglang/srt/layers/moe/fused_moe_triton/layer.py
@@ -12,19 +12,20 @@
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce,
)
from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
from sglang.srt.layers.moe.fused_moe_native import moe_forward_native
from sglang.srt.layers.moe.topk import select_experts
from sglang.srt.layers.quantization.base_config import (
QuantizationConfig,
QuantizeMethodBase,
)
from sglang.srt.utils import (
_process_weight_after_loading,
cpu_has_amx_support,
get_bool_env_var,
is_cpu,
is_hip,
set_weight_attrs,
use_intel_amx_backend,
)

if torch.cuda.is_available():
@@ -129,7 +130,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:

# Pack weights to get better performance on CPU
if _is_cpu and _is_cpu_amx_available:
_process_weight_after_loading(layer, ["w13_weight", "w2_weight"])
_amx_process_weight_after_loading(layer, ["w13_weight", "w2_weight"])

return

@@ -264,10 +265,7 @@ def forward_cpu(
) -> torch.Tensor:
assert activation == "silu", f"activation = {activation} is not supported."

if (
getattr(layer, "use_intel_amx_backend", False)
and not apply_router_weight_on_input
):
if use_intel_amx_backend(layer) and not apply_router_weight_on_input:
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
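
As a concrete illustration of the dim_is_supported constraint for the 3-D expert weights packed in process_weights_after_loading above (the shapes here are hypothetical, chosen only to show which tensors qualify for prepacking):

import torch

from sglang.srt.layers.amx_utils import dim_is_supported

# Fused expert weights laid out as [num_experts, OC, IC]
w13_weight = torch.empty(8, 256, 128)  # 256 % 16 == 0 and 128 % 32 == 0 -> packable
w2_weight = torch.empty(8, 100, 128)   # 100 % 16 != 0 -> module falls back, no AMX packing

print(dim_is_supported(w13_weight))  # True
print(dim_is_supported(w2_weight))   # False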
11 changes: 6 additions & 5 deletions python/sglang/srt/layers/quantization/fp8.py
@@ -27,6 +27,7 @@ def dummy_func(*args, **kwargs):


from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
from sglang.srt.layers.linear import (
LinearBase,
LinearMethodBase,
@@ -64,7 +65,6 @@ def dummy_func(*args, **kwargs):
)
from sglang.srt.layers.utils import is_sm100_supported
from sglang.srt.utils import (
_process_weight_after_loading,
cpu_has_amx_support,
get_bool_env_var,
is_cpu,
@@ -74,6 +74,7 @@ def dummy_func(*args, **kwargs):
log_info_on_rank0,
print_warning_once,
set_weight_attrs,
use_intel_amx_backend,
)

_is_hip = is_hip()
@@ -335,7 +336,7 @@ def process_weights_after_loading(self, layer: Module) -> None:
assert (
_is_cpu_amx_available
), "Fp8LinearMethod on CPU requires that CPU has AMX support"
_process_weight_after_loading(layer, ["weight"])
_amx_process_weight_after_loading(layer, ["weight"])
return
else:
weight, weight_scale = layer.weight.data, layer.weight_scale_inv.data
@@ -433,7 +434,7 @@ def apply(
)

if self.block_quant:
if getattr(layer, "use_intel_amx_backend", False):
if use_intel_amx_backend(layer):
return torch.ops.sgl_kernel.fp8_scaled_mm_cpu(
x,
layer.weight,
@@ -769,7 +770,7 @@ def process_weights_after_loading(self, layer: Module) -> None:
assert (
_is_cpu_amx_available
), "Fp8MoEMethod on CPU requires that CPU has AMX support"
_process_weight_after_loading(layer, ["w13_weight", "w2_weight"])
_amx_process_weight_after_loading(layer, ["w13_weight", "w2_weight"])

return

@@ -996,7 +997,7 @@ def apply(
routed_scaling_factor=routed_scaling_factor,
)

if getattr(layer, "use_intel_amx_backend", False):
if use_intel_amx_backend(layer):
return torch.ops.sgl_kernel.fused_experts_cpu(
x,
layer.w13_weight,
11 changes: 6 additions & 5 deletions python/sglang/srt/layers/quantization/w8a8_int8.py
@@ -4,6 +4,7 @@
from torch.nn.parameter import Parameter

from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
from sglang.srt.layers.linear import LinearMethodBase
from sglang.srt.layers.parameter import ChannelQuantScaleParameter, ModelWeightParameter
from sglang.srt.layers.quantization.base_config import (
@@ -12,11 +13,11 @@
)
from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
from sglang.srt.utils import (
_process_weight_after_loading,
cpu_has_amx_support,
is_cpu,
is_cuda,
set_weight_attrs,
use_intel_amx_backend,
)

_is_cuda = is_cuda()
@@ -84,7 +85,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
assert (
_is_cpu_amx_available
), "W8A8Int8LinearMethod on CPU requires that CPU has AMX support"
_process_weight_after_loading(layer, ["weight"])
_amx_process_weight_after_loading(layer, ["weight"])
return

layer.weight = Parameter(layer.weight.t(), requires_grad=False)
@@ -127,7 +128,7 @@ def apply(
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
):
if getattr(layer, "use_intel_amx_backend", False):
if use_intel_amx_backend(layer):
return torch.ops.sgl_kernel.int8_scaled_mm_with_quant(
x,
layer.weight,
@@ -235,7 +236,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
assert (
_is_cpu_amx_available
), "W8A8Int8MoEMethod on CPU requires that CPU has AMX support"
_process_weight_after_loading(layer, ["w13_weight", "w2_weight"])
_amx_process_weight_after_loading(layer, ["w13_weight", "w2_weight"])
return

layer.w13_weight = Parameter(layer.w13_weight, requires_grad=False)
@@ -284,7 +285,7 @@ def apply(
routed_scaling_factor=routed_scaling_factor,
)

if getattr(layer, "use_intel_amx_backend", False):
if use_intel_amx_backend(layer):
return torch.ops.sgl_kernel.fused_experts_cpu(
x,
layer.w13_weight,
8 changes: 2 additions & 6 deletions python/sglang/srt/layers/vocab_parallel_embedding.py
@@ -13,19 +13,15 @@
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce,
)
from sglang.srt.layers.amx_utils import PackWeightMethod
from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size
from sglang.srt.layers.parameter import BasevLLMParameter
from sglang.srt.layers.quantization.base_config import (
QuantizationConfig,
QuantizeMethodBase,
method_has_implemented_embedding,
)
from sglang.srt.utils import (
PackWeightMethod,
cpu_has_amx_support,
is_cpu,
set_weight_attrs,
)
from sglang.srt.utils import cpu_has_amx_support, is_cpu, set_weight_attrs

DEFAULT_VOCAB_PADDING_SIZE = 64
