diff --git a/python/sglang/srt/layers/linear.py b/python/sglang/srt/layers/linear.py
index f664d6bb634..3304340f18d 100644
--- a/python/sglang/srt/layers/linear.py
+++ b/python/sglang/srt/layers/linear.py
@@ -30,7 +30,12 @@
     QuantizationConfig,
     QuantizeMethodBase,
 )
-from sglang.srt.utils import set_weight_attrs
+from sglang.srt.utils import (
+    _process_weight_after_loading,
+    cpu_has_amx_support,
+    is_cpu,
+    set_weight_attrs,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -52,6 +57,9 @@
     "IPEXAWQLinearMethod",
 ]
 
+_is_cpu_amx_available = cpu_has_amx_support()
+_is_cpu = is_cpu()
+
 
 def adjust_marlin_shard(param, shard_size, shard_offset):
     marlin_tile_size = getattr(param, "marlin_tile_size", None)
@@ -165,6 +173,10 @@ def create_weights(
         layer.register_parameter("weight", weight)
         set_weight_attrs(weight, extra_weight_attrs)
 
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        if _is_cpu and _is_cpu_amx_available:
+            _process_weight_after_loading(layer, ["weight"])
+
     def apply(
         self,
         layer: torch.nn.Module,
@@ -172,6 +184,11 @@ def apply(
         bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
 
+        if getattr(layer, "use_intel_amx_backend", False):
+            return torch.ops.sgl_kernel.weight_packed_linear(
+                x, layer.weight, bias, True  # is_vnni
+            )
+
         return F.linear(x, layer.weight, bias)
diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py
index c8b4ecd4cc4..e9829518446 100644
--- a/python/sglang/srt/layers/logits_processor.py
+++ b/python/sglang/srt/layers/logits_processor.py
@@ -442,11 +442,20 @@ def _get_logits(
             dp_gather_replicate(hidden_states, local_hidden_states, logits_metadata)
 
         if hasattr(lm_head, "weight"):
-            logits = torch.matmul(
-                hidden_states.to(lm_head.weight.dtype), lm_head.weight.T
-            )
+            if getattr(lm_head, "use_intel_amx_backend", False):
+                logits = torch.ops.sgl_kernel.weight_packed_linear(
+                    hidden_states.to(lm_head.weight.dtype),
+                    lm_head.weight,
+                    None,  # bias
+                    True,  # is_vnni
+                )
+            else:
+                logits = torch.matmul(
+                    hidden_states.to(lm_head.weight.dtype), lm_head.weight.T
+                )
         else:
             # GGUF models
+            # TODO: use weight_packed_linear for GGUF models
             logits = lm_head.quant_method.apply(lm_head, hidden_states, embedding_bias)
 
         if self.logit_scale is not None:
diff --git a/python/sglang/srt/layers/moe/fused_moe_native.py b/python/sglang/srt/layers/moe/fused_moe_native.py
index 05e8070381a..25645ad00e9 100644
--- a/python/sglang/srt/layers/moe/fused_moe_native.py
+++ b/python/sglang/srt/layers/moe/fused_moe_native.py
@@ -77,8 +77,15 @@ def moe_forward_native(
     custom_routing_function: Optional[Callable] = None,
     correction_bias: Optional[torch.Tensor] = None,
     activation: str = "silu",
+    apply_router_weight_on_input: bool = False,
+    inplace: bool = True,
+    no_combine: bool = False,
     routed_scaling_factor: Optional[float] = None,
 ) -> torch.Tensor:
+
+    if apply_router_weight_on_input:
+        raise NotImplementedError()
+
     topk_weights, topk_ids = select_experts(
         hidden_states=x,
         router_logits=router_logits,
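The three files above share one dispatch pattern: process_weights_after_loading() packs the weight when running on a CPU with AMX support, and the hot path checks a per-module use_intel_amx_backend flag before calling the packed kernel, falling back to the existing F.linear / torch.matmul path otherwise. Below is a minimal sketch of that pattern; the packed_linear() stub is a hypothetical stand-in for torch.ops.sgl_kernel.weight_packed_linear so the snippet runs without sgl-kernel installed.

import torch
import torch.nn.functional as F


def packed_linear(x, weight, bias, is_vnni):
    # Stand-in for torch.ops.sgl_kernel.weight_packed_linear; the real kernel
    # consumes a VNNI-prepacked weight produced once at load time.
    return F.linear(x, weight, bias)


def apply_linear(layer, x, bias=None):
    # Same gating as UnquantizedLinearMethod.apply above.
    if getattr(layer, "use_intel_amx_backend", False):
        return packed_linear(x, layer.weight, bias, True)  # is_vnni
    return F.linear(x, layer.weight, bias)


layer = torch.nn.Linear(64, 32, bias=False)
layer.use_intel_amx_backend = False  # flipped to True by _process_weight_after_loading on AMX CPUs
print(apply_linear(layer, torch.randn(4, 64)).shape)  # torch.Size([4, 32])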
diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
index 3d9dbf64f28..fd189889184 100644
--- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
+++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
@@ -18,7 +18,14 @@
     QuantizationConfig,
     QuantizeMethodBase,
 )
-from sglang.srt.utils import get_bool_env_var, is_hip, set_weight_attrs
+from sglang.srt.utils import (
+    _process_weight_after_loading,
+    cpu_has_amx_support,
+    get_bool_env_var,
+    is_cpu,
+    is_hip,
+    set_weight_attrs,
+)
 
 if torch.cuda.is_available():
     from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
@@ -28,6 +35,8 @@
 import logging
 
 _is_hip = is_hip()
+_is_cpu_amx_available = cpu_has_amx_support()
+_is_cpu = is_cpu()
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 
 if _use_aiter:
@@ -117,6 +126,11 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
             requires_grad=False,
         )
         torch.cuda.empty_cache()
+
+        # Pack weight to get better performance on CPU
+        if _is_cpu and _is_cpu_amx_available:
+            _process_weight_after_loading(layer, ["w13_weight", "w2_weight"])
+
         return
 
     def apply(
@@ -248,19 +262,64 @@ def forward_cpu(
         no_combine: bool = False,
         routed_scaling_factor: Optional[float] = None,
     ) -> torch.Tensor:
-        return moe_forward_native(
-            layer,
-            x,
-            use_grouped_topk,
-            top_k,
-            router_logits,
-            renormalize,
-            topk_group,
-            num_expert_group,
-            num_fused_shared_experts,
-            custom_routing_function,
-            correction_bias,
-        )
+        assert activation == "silu", f"activation = {activation} is not supported."
+
+        if (
+            getattr(layer, "use_intel_amx_backend", False)
+            and not apply_router_weight_on_input
+        ):
+            topk_weights, topk_ids = select_experts(
+                hidden_states=x,
+                router_logits=router_logits,
+                use_grouped_topk=use_grouped_topk,
+                top_k=top_k,
+                renormalize=renormalize,
+                topk_group=topk_group,
+                num_expert_group=num_expert_group,
+                num_fused_shared_experts=num_fused_shared_experts,
+                custom_routing_function=custom_routing_function,
+                correction_bias=correction_bias,
+                routed_scaling_factor=routed_scaling_factor,
+            )
+
+            # TODO: support apply_router_weight_on_input in the fused_experts_cpu kernel
+            return torch.ops.sgl_kernel.fused_experts_cpu(
+                x,
+                layer.w13_weight,
+                layer.w2_weight,
+                # TODO: for llama4, topk_weights is computed by
+                # Llama4MoE.custom_routing_function in bfloat16, while the
+                # kernel requires float32.
+                topk_weights.to(torch.float),
+                topk_ids,
+                True,  # inplace
+                False,  # use_int8_w8a8
+                False,  # use_fp8_w8a16
+                None,  # w1_scale
+                None,  # w2_scale
+                None,  # block_size
+                None,  # a1_scale
+                None,  # a2_scale
+                True,  # is_vnni
+            )
+        else:
+            return moe_forward_native(
+                layer,
+                x,
+                use_grouped_topk,
+                top_k,
+                router_logits,
+                renormalize,
+                topk_group,
+                num_expert_group,
+                num_fused_shared_experts,
+                custom_routing_function,
+                correction_bias,
+                activation,
+                apply_router_weight_on_input,
+                inplace,
+                no_combine,
+                routed_scaling_factor,
+            )
 
     def forward_tpu(self, *args, **kwargs) -> torch.Tensor:
         raise NotImplementedError("The TPU backend currently does not support MoE.")
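For reference, the computation that fused_experts_cpu is expected to perform can be written out in plain PyTorch. The sketch below assumes w13_weight stacks the gate and up projections per expert (matching the existing Triton path) and uses SiLU activation; renormalization, grouped top-k, and correction bias, which select_experts handles above, are omitted, and the names are illustrative rather than the kernel's signature.

import torch
import torch.nn.functional as F

E, H, I = 4, 64, 128  # experts, hidden size, intermediate size (example values)
topk = 2
x = torch.randn(8, H)
w13 = torch.randn(E, 2 * I, H)  # per-expert [gate; up] projection
w2 = torch.randn(E, H, I)       # per-expert down projection

router = torch.randn(8, E)
topk_weights, topk_ids = router.softmax(dim=-1).topk(topk, dim=-1)

out = torch.zeros_like(x)
for t in range(x.size(0)):
    for w, e in zip(topk_weights[t], topk_ids[t]):
        gate, up = w13[e].matmul(x[t]).split(I)
        out[t] += w * w2[e].matmul(F.silu(gate) * up)

print(out.shape)  # torch.Size([8, 64])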
diff --git a/python/sglang/srt/layers/vocab_parallel_embedding.py b/python/sglang/srt/layers/vocab_parallel_embedding.py
index ec7c140ae01..19a281e481e 100644
--- a/python/sglang/srt/layers/vocab_parallel_embedding.py
+++ b/python/sglang/srt/layers/vocab_parallel_embedding.py
@@ -20,10 +20,18 @@
     QuantizeMethodBase,
     method_has_implemented_embedding,
 )
-from sglang.srt.utils import set_weight_attrs
+from sglang.srt.utils import (
+    PackWeightMethod,
+    cpu_has_amx_support,
+    is_cpu,
+    set_weight_attrs,
+)
 
 DEFAULT_VOCAB_PADDING_SIZE = 64
 
+_is_cpu_amx_available = cpu_has_amx_support()
+_is_cpu = is_cpu()
+
 
 class UnquantizedEmbeddingMethod(QuantizeMethodBase):
     """Unquantized method for embeddings."""
@@ -549,6 +557,11 @@ def __init__(
             use_presharded_weights=use_presharded_weights,
         )
         self.quant_config = quant_config
+
+        # We only support packing the LMHead if it is not quantized.
+        # For an LMHead with quant_config, the weight name will be "qweight".
+        if self.quant_config is None and _is_cpu and _is_cpu_amx_available:
+            self.quant_method = PackWeightMethod(weight_names=["weight"])
+
         if bias:
             self.bias = Parameter(
                 torch.empty(self.num_embeddings_per_partition, dtype=params_dtype)
diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index 700a8ef6101..10079227d83 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -93,6 +93,7 @@
     BumpAllocator,
     DeepEPMode,
     LazyValue,
+    PackWeightMethod,
     add_prefix,
     bind_or_assign,
     cpu_has_amx_support,
@@ -144,6 +145,9 @@ class AttnForwardMethod(IntEnum):
     # Use MLA but with fused RoPE
     MLA_FUSED_ROPE = auto()
 
+    # Use MLA with fused RoPE kernel for CPU
+    MLA_FUSED_ROPE_CPU = auto()
+
 
 class DeepseekV2MLP(nn.Module):
     def __init__(
@@ -212,8 +216,18 @@ def __init__(
             )
         else:
             self.e_score_correction_bias = None
+        if _is_cpu and _is_cpu_amx_available:
+            self.quant_method = PackWeightMethod(weight_names=["weight"])
 
     def forward(self, hidden_states):
+        if getattr(self, "use_intel_amx_backend", False):
+            return torch.ops.sgl_kernel.weight_packed_linear(
+                hidden_states,
+                self.weight,
+                None,  # bias
+                True,  # is_vnni
+            )
+
         logits = F.linear(hidden_states, self.weight, None)
         return logits
 
@@ -778,6 +792,37 @@ def __init__(
             "SGL_CHUNKED_PREFIX_CACHE_THRESHOLD", 8192
         )
 
+        # If we have self.fused_qkv_a_proj_with_mqa and we are running on CPU, we
+        # choose the torch.ops.sgl_kernel.qkv_proj_with_rope_fused_weight kernel,
+        # which requires self.w_kc and self.w_vc to be packed. Otherwise we use
+        # torch.bmm, and the weights should not be packed in that case.
+        if (
+            hasattr(self, "fused_qkv_a_proj_with_mqa")
+            and _is_cpu
+            and _is_cpu_amx_available
+        ):
+            self.quant_method = PackWeightMethod(
+                weight_names=["w_kc", "w_vc"], transpose_dims=[[1, 2], [1, 2]]
+            )
+
+        self.qkv_proj_with_rope_is_int8 = (
+            hasattr(self, "fused_qkv_a_proj_with_mqa")
+            and self.fused_qkv_a_proj_with_mqa.weight.dtype == torch.int8
+        )
+        self.qkv_proj_with_rope_is_fp8 = (
+            hasattr(self, "fused_qkv_a_proj_with_mqa")
+            and self.fused_qkv_a_proj_with_mqa.weight.dtype == torch.float8_e4m3fn
+        )
+
+        self.weight_block_size = None
+        if self.qkv_proj_with_rope_is_fp8:
+            assert (
+                self.fused_qkv_a_proj_with_mqa.quant_method.quant_config.weight_block_size
+                == self.q_b_proj.quant_method.quant_config.weight_block_size
+            )
+            self.weight_block_size = (
+                self.fused_qkv_a_proj_with_mqa.quant_method.quant_config.weight_block_size
+            )
+
     def dispatch_attn_forward_method(
         self, forward_batch: ForwardBatch
     ) -> AttnForwardMethod:
@@ -791,7 +836,12 @@ def _dispatch_mla_subtype():
                 else:
                     return AttnForwardMethod.MLA
             else:
-                return AttnForwardMethod.MLA
+                if hasattr(self, "fused_qkv_a_proj_with_mqa") and getattr(
+                    self, "use_intel_amx_backend", False
+                ):
+                    return AttnForwardMethod.MLA_FUSED_ROPE_CPU
+                else:
+                    return AttnForwardMethod.MLA
 
         if self.attention_backend == "flashinfer":
             # Flashinfer MLA: Do not absorb when enabling ragged prefill
@@ -905,6 +955,10 @@ def forward_prepare(
             inner_state = self.forward_absorb_fused_mla_rope_prepare(
                 positions, hidden_states, forward_batch, zero_allocator
             )
+        elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE_CPU:
+            inner_state = self.forward_absorb_fused_mla_rope_cpu_prepare(
+                positions, hidden_states, forward_batch, zero_allocator
+            )
         else:
             raise NotImplementedError
         return None, attn_forward_method, forward_batch, inner_state
@@ -924,6 +978,8 @@ def forward_core(self, intermediate_state):
             return self.forward_absorb_core(*inner_state)
         elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE:
             return self.forward_absorb_fused_mla_rope_core(*inner_state)
+        elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE_CPU:
+            return self.forward_absorb_fused_mla_rope_cpu_core(*inner_state)
         else:
             raise NotImplementedError
 
@@ -1241,6 +1297,57 @@ def forward_absorb_fused_mla_rope_prepare(
             zero_allocator,
         )
 
+    def forward_absorb_fused_mla_rope_cpu_prepare(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+        zero_allocator: BumpAllocator,
+    ):
+        assert self.q_lora_rank is not None and getattr(
+            self, "use_intel_amx_backend", False
+        ), "forward_absorb_fused_mla_rope_cpu_prepare requires q_lora_rank is not None and use_intel_amx_backend"
+
+        q_input, k_input, v_input = (
+            torch.ops.sgl_kernel.qkv_proj_with_rope_fused_weight(
+                hidden_states,
+                self.fused_qkv_a_proj_with_mqa.weight,
+                self.q_b_proj.weight,
+                self.w_kc,
+                self.q_a_layernorm.weight,
+                self.kv_a_layernorm.weight,
+                positions,
+                self.rotary_emb.cos_sin_cache,
+                self.kv_a_layernorm.variance_epsilon,
+                self.qkv_proj_with_rope_is_int8,
+                self.qkv_proj_with_rope_is_fp8,
+                (
+                    self.fused_qkv_a_proj_with_mqa.weight_scale
+                    if self.qkv_proj_with_rope_is_int8
+                    else (
+                        self.fused_qkv_a_proj_with_mqa.weight_scale_inv
+                        if self.qkv_proj_with_rope_is_fp8
+                        else None
+                    )
+                ),
+                (
+                    self.q_b_proj.weight_scale
+                    if self.qkv_proj_with_rope_is_int8
+                    else (
+                        self.q_b_proj.weight_scale_inv
+                        if self.qkv_proj_with_rope_is_fp8
+                        else None
+                    )
+                ),
+                True,  # is_vnni
+                self.weight_block_size,
+                self.q_lora_rank,
+                self.kv_lora_rank,
+                self.qk_rope_head_dim,
+            )
+        )
+        return (q_input, k_input, v_input, forward_batch, zero_allocator)
+
     def forward_absorb_fused_mla_rope_core(
         self,
         q_input,
@@ -1314,6 +1421,43 @@ def forward_absorb_fused_mla_rope_core(
         return output
 
+    def forward_absorb_fused_mla_rope_cpu_core(
+        self, q_input, k_input, v_input, forward_batch, zero_allocator
+    ):
+        assert self.q_lora_rank is not None and getattr(
+            self, "use_intel_amx_backend", False
+        ), "forward_absorb_fused_mla_rope_cpu_core requires q_lora_rank is not None and use_intel_amx_backend"
+
+        attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch)
+        attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
+
+        # [Note] Align shapes of bmm inputs.
+        # Shapes of inputs:
+        #   q_nope: [M, B, K]
+        #   original self.w_kc: [B, K, N]
+        #   current self.w_kc (which has been converted in PackWeightMethod): [B, N, K]
+
+        # Shapes of inputs to sgl_kernel.cpu.bmm:
+        #   out: [B, M, N]
+        #   mat1: [B, M, K]
+        #   mat2: [B, N, K]
+        B = self.w_vc.size(0)
+        N = self.w_vc.size(1)
+        M = attn_output.size(0)
+        output = torch.empty([M, int(B * N)], dtype=attn_output.dtype)
+        attn_bmm_output = output.view([M, B, N]).transpose_(0, 1)
+        torch.ops.sgl_kernel.bmm_cpu(
+            attn_bmm_output,
+            attn_output.transpose(0, 1),
+            self.w_vc,
+            True,  # is_vnni
+            None,  # scale
+        )
+        attn_output = output
+        output, _ = self.o_proj(attn_output)
+
+        return output
+
     def _chunked_prefix_attn_mha(
         self,
         q: torch.Tensor,
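The [Note] block above can be sanity-checked with plain torch.bmm as the reference (sgl_kernel.bmm_cpu additionally expects the weight pre-transposed to [B, N, K] and VNNI-packed). A small sketch of the buffer-aliasing trick, with made-up sizes:

import torch

B, M, K, N = 16, 3, 512, 128            # heads, tokens, kv_lora_rank, v_head_dim (example sizes)
attn_output = torch.randn(M, B, K)      # per-token, per-head attention output
w_vc = torch.randn(B, K, N)             # un-packed weight layout

# Write the [B, M, N] bmm result through a transposed view so that `output`
# ends up laid out as [M, B * N], ready to feed o_proj.
output = torch.empty(M, B * N)
attn_bmm_output = output.view(M, B, N).transpose(0, 1)  # [B, M, N] view into output
attn_bmm_output.copy_(torch.bmm(attn_output.transpose(0, 1), w_vc))

reference = torch.einsum("mbk,bkn->mbn", attn_output, w_vc).reshape(M, B * N)
assert torch.allclose(output, reference, atol=1e-5)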
diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py
index ed43e04503f..2c0c86f3704 100644
--- a/python/sglang/srt/utils.py
+++ b/python/sglang/srt/utils.py
@@ -2457,6 +2457,77 @@ def cpu_has_amx_support():
     return torch._C._cpu._is_amx_tile_supported() and is_intel_amx_backend_available
 
 
+def prepack_weight_if_needed(weight):
+    if weight.device != torch.device("cpu"):
+        return weight
+    if not cpu_has_amx_support():
+        return weight
+
+    return torch.ops.sgl_kernel.convert_weight_packed(weight)
+
+
+# TODO: currently the gemm kernel has the below requirements:
+# OC % TILE_N == 0, where TILE_N = 16
+# IC % TILE_K == 0, where TILE_K = 32
+def dim_is_supported(weight):
+    return weight.size(0) % 16 == 0 and weight.size(1) % 32 == 0
+
+
+def _process_weight_after_loading(module, weight_names, transpose_dims=None) -> None:
+    # Pack weight to get better performance on CPU
+    devices = {getattr(module, weight_name).device for weight_name in weight_names}
+    assert len(devices) == 1, "Expects all weights to be on the same device"
+    device = devices.pop()
+
+    if transpose_dims:
+        assert len(weight_names) == len(
+            transpose_dims
+        ), "len(weight_names) should be equal to len(transpose_dims)"
+
+    for i, weight_name in enumerate(weight_names):
+        weight_tensor = getattr(module, weight_name)
+
+        # We don't pack the weight or use the intel amx backend if any weight
+        # of this module has an unsupported dim.
+        if not dim_is_supported(weight_tensor):
+            logger.warning(
+                f"Expects weight.size(0) % 16 == 0 and weight.size(1) % 32 == 0 "
+                f"but {weight_tensor.size(0)=} and {weight_tensor.size(1)=} in {module}. "
+                f"{module} won't use intel amx backend."
+            )
+            module.use_intel_amx_backend = False
+            return
+
+        if transpose_dims and transpose_dims[i]:
+            weight_tensor = weight_tensor.transpose(*transpose_dims[i])
+
+        packed_weight = torch.nn.Parameter(
+            prepack_weight_if_needed(weight_tensor),
+            requires_grad=False,
+        )
+        packed_weight.__dict__ = weight_tensor.__dict__
+        setattr(module, weight_name, packed_weight)
+
+    module.use_intel_amx_backend = (
+        device == torch.device("cpu") and cpu_has_amx_support()
+    )
+
+    if (
+        module.use_intel_amx_backend
+        and hasattr(module, "bias")
+        and module.bias is not None
+    ):
+        module.bias = torch.nn.Parameter(module.bias.data.float(), requires_grad=False)
+
+
+class PackWeightMethod:
+    def __init__(self, weight_names, transpose_dims=None):
+        self.weight_names = weight_names
+        self.transpose_dims = transpose_dims
+
+    def process_weights_after_loading(self, module) -> None:
+        _process_weight_after_loading(module, self.weight_names, self.transpose_dims)
+
+
 class LazyValue:
     def __init__(self, creator: Callable):
         self._creator = creator
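Packing is skipped, and the module keeps its original forward path, whenever any of its weights fails the tiling constraints above. A tiny standalone illustration of the dim_is_supported check, using hypothetical shapes:

import torch

TILE_N, TILE_K = 16, 32  # output-channel / input-channel tile sizes required by the kernel


def dim_is_supported(weight: torch.Tensor) -> bool:
    # Mirrors the helper above: OC % TILE_N == 0 and IC % TILE_K == 0.
    return weight.size(0) % TILE_N == 0 and weight.size(1) % TILE_K == 0


print(dim_is_supported(torch.empty(4096, 11008)))  # True: both dims tile cleanly, weight gets packed
print(dim_is_supported(torch.empty(1000, 4096)))   # False: 1000 % 16 != 0, use_intel_amx_backend stays False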
diff --git a/sgl-kernel/csrc/cpu/gemm.cpp b/sgl-kernel/csrc/cpu/gemm.cpp
index 8cdebb9a2c9..2cce8ddac5a 100644
--- a/sgl-kernel/csrc/cpu/gemm.cpp
+++ b/sgl-kernel/csrc/cpu/gemm.cpp
@@ -318,8 +318,8 @@ void weight_packed_linear_kernel_impl(
   const int64_t MB = div_up(M, BLOCK_M);
   const int64_t NB = div_up(N, BLOCK_N);
 
-  // use avx512-bf16 when a) M is small; b) dtype is bfloat16, otherwise use amx
-  const bool use_brgemm = (M > 4) || (!std::is_same_v<scalar_t, at::BFloat16>);
+  // use avx512-bf16 when a) M is small; b) dtype is bfloat16; c) N is not small, otherwise use amx
+  const bool use_brgemm = (M > 4) || (!std::is_same_v<scalar_t, at::BFloat16>) || (N < 64);
 
   // parallel on [MB, NB]
   AT_DISPATCH_BOOL(bias != nullptr, has_bias, [&] {
diff --git a/test/srt/cpu/test_gemm.py b/test/srt/cpu/test_gemm.py
index 7404d060e4b..ebadad7a038 100644
--- a/test/srt/cpu/test_gemm.py
+++ b/test/srt/cpu/test_gemm.py
@@ -28,7 +28,7 @@ def forward(self, x):
 
 
 class TestGemm(CustomTestCase):
     M = [1, 101]
-    N = [32 * 13]
+    N = [16, 32 * 13]
     K = [32 * 16]
     has_bias = [False, True]
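The gemm.cpp change adds a third condition to the kernel-dispatch heuristic, and the new N=16 test value exercises it (16 is below the 64 threshold while still a multiple of TILE_N). A Python mirror of the predicate, written only to make the cases explicit; the real decision happens in the C++ kernel per dtype and shape:

import torch


def use_brgemm(M: int, N: int, dtype: torch.dtype) -> bool:
    # Mirrors: (M > 4) || (!std::is_same_v<scalar_t, at::BFloat16>) || (N < 64)
    return (M > 4) or (dtype != torch.bfloat16) or (N < 64)


print(use_brgemm(1, 416, torch.bfloat16))    # False -> avx512-bf16 path (small M, bf16, large N)
print(use_brgemm(1, 16, torch.bfloat16))     # True  -> new small-N case takes the brgemm/amx path
print(use_brgemm(101, 416, torch.bfloat16))  # True  -> large M always takes the brgemm/amx path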