From 626ec0a7c5c1ccf4ca86cdbefa00193e052ac0ee Mon Sep 17 00:00:00 2001 From: jianan-gu Date: Sun, 2 Nov 2025 21:40:24 -0500 Subject: [PATCH 01/20] add CPU optimized frontend for qwen3-next --- python/sglang/srt/layers/amx_utils.py | 20 +++++++--- .../attention/hybrid_linear_attn_backend.py | 10 ++++- python/sglang/srt/layers/layernorm.py | 24 +++++++++++ python/sglang/srt/models/qwen3_next.py | 40 ++++++++++++++----- 4 files changed, 77 insertions(+), 17 deletions(-) diff --git a/python/sglang/srt/layers/amx_utils.py b/python/sglang/srt/layers/amx_utils.py index df2a05ba538a..078bf2298cc4 100644 --- a/python/sglang/srt/layers/amx_utils.py +++ b/python/sglang/srt/layers/amx_utils.py @@ -7,13 +7,17 @@ logger = logging.getLogger(__name__) -def amx_process_weight_after_loading(weight): +def amx_process_weight_after_loading(weight, is_conv=False): if weight.device != torch.device("cpu"): return weight if not cpu_has_amx_support(): return weight - - return torch.ops.sgl_kernel.convert_weight_packed(weight) + if is_conv: + return torch.ops.sgl_kernel.causal_conv1d_weight_pack( + weight.view(-1, weight.size(-1)) + ) + else: + return torch.ops.sgl_kernel.convert_weight_packed(weight) # TODO: currently gemm kernel has the below requirements: @@ -28,6 +32,10 @@ def dim_is_supported(weight): return OC % TILE_N == 0 and IC % TILE_K == 0 +def is_dim_conv_weight(weight): + weight.dim() == 3 and weight.size(1) == 1 + + def _amx_process_weight_after_loading( module, weight_names, transpose_dims=None ) -> None: @@ -46,9 +54,9 @@ def _amx_process_weight_after_loading( if transpose_dims and transpose_dims[i]: weight_tensor = weight_tensor.transpose(*transpose_dims[i]) - + is_conv_weight = is_dim_conv_weight(weight_tensor) # We don't pack weight or use intel amx backend if any weight of this module has unsupported dim. 
- if not dim_is_supported(weight_tensor): + if (not dim_is_supported(weight_tensor)) and (not is_conv_weight): logger.warning( f"Unsupported dimension for prepacking for weight '{weight_name}' with shape {weight_tensor.shape} in {module}. " f"The derived (OC, IC) dimensions must be divisible by (16, 32). " @@ -57,7 +65,7 @@ def _amx_process_weight_after_loading( return packed_weight = torch.nn.Parameter( - amx_process_weight_after_loading(weight_tensor), + amx_process_weight_after_loading(weight_tensor, is_conv_weight), requires_grad=False, ) packed_weight.__dict__ = weight_tensor.__dict__ diff --git a/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py b/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py index 016b803ffbbb..6c08facf6009 100644 --- a/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py +++ b/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py @@ -33,7 +33,7 @@ from sglang.srt.models.qwen3_next import fused_gdn_gating from sglang.srt.speculative.eagle_info import EagleDraftInput, EagleVerifyInput from sglang.srt.speculative.spec_info import SpecInput -from sglang.srt.utils import is_cuda, is_npu +from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_npu if is_cuda(): from sglang.srt.layers.attention.mamba.causal_conv1d import ( @@ -55,6 +55,14 @@ fused_sigmoid_gating_delta_rule_update = fused_sigmoid_gating_delta_rule_update_npu causal_conv1d_fn = causal_conv1d_fn_npu causal_conv1d_update = causal_conv1d_update_npu +elif is_cpu() and cpu_has_amx_support(): + chunk_gated_delta_rule = torch.ops.sgl_kernel.chunk_gated_delta_rule_cpu + causal_conv1d_fn = torch.ops.sgl_kernel.causal_conv1d_fwd_cpu + causal_conv1d_update = torch.ops.sgl_kernel.causal_conv1d_update_cpu + fused_sigmoid_gating_delta_rule_update = ( + torch.ops.sgl_kernel.fused_sigmoid_gating_delta_rule_update_cpu + ) + fused_gdn_gating = torch.ops.sgl_kernel.fused_gdn_gating_cpu class MambaAttnBackendBase(AttentionBackend): 
diff --git a/python/sglang/srt/layers/layernorm.py b/python/sglang/srt/layers/layernorm.py index 3be20f2ea5c7..6eb311bc7818 100644 --- a/python/sglang/srt/layers/layernorm.py +++ b/python/sglang/srt/layers/layernorm.py @@ -18,6 +18,7 @@ import torch import torch.nn as nn +import torch.nn.functional as F from packaging.version import Version from sglang.srt.batch_invariant_ops import ( @@ -388,6 +389,29 @@ def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.eps}" +class Qwen3NextRMSNormGated(nn.Module): + def __init__(self, hidden_size, eps=1e-6, **kwargs): + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward_native(self, hidden_states, gate=None): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + # Norm before gate + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + hidden_states = self.weight * hidden_states.to(input_dtype) + hidden_states = hidden_states * F.silu(gate.to(torch.float32)) + return hidden_states.to(input_dtype) + + def forward_cpu(self, hidden_states, gate=None): + out = torch.ops.sgl_kernel.qwen3_next_rmsnorm_gated_cpu( + hidden_states, self.weight, gate, self.variance_epsilon + ) + return out + + if not ( _is_cuda or _is_hip or _is_npu or (_is_cpu and _is_cpu_amx_available) or _is_xpu ): diff --git a/python/sglang/srt/models/qwen3_next.py b/python/sglang/srt/models/qwen3_next.py index d817076fbfad..0aa32648e69a 100644 --- a/python/sglang/srt/models/qwen3_next.py +++ b/python/sglang/srt/models/qwen3_next.py @@ -17,7 +17,7 @@ get_attention_tp_size, is_dp_attention_enabled, ) -from sglang.srt.layers.layernorm import GemmaRMSNorm +from sglang.srt.layers.layernorm import GemmaRMSNorm, Qwen3NextRMSNormGated from sglang.srt.layers.linear import ( ColumnParallelLinear, QKVParallelLinear, @@ -43,6 +43,8 @@ from sglang.srt.utils import ( LazyValue, 
add_prefix, + cpu_has_amx_support, + is_cpu, is_cuda, is_npu, make_layers, @@ -52,6 +54,10 @@ logger = logging.getLogger(__name__) _is_cuda = is_cuda() _is_npu = is_npu() +_is_cpu = is_cpu() +_is_amx_available = cpu_has_amx_support() +if _is_cpu and _is_amx_available: + pass import triton import triton.language as tl @@ -326,15 +332,18 @@ def __init__( set_weight_attrs(self.A_log, {"weight_loader": sharded_weight_loader(0)}) set_weight_attrs(self.dt_bias, {"weight_loader": sharded_weight_loader(0)}) - - self.norm = RMSNormGated( - self.head_v_dim, - eps=self.layer_norm_epsilon, - group_size=None, - norm_before_gate=True, - device=torch.get_device_module().current_device(), - dtype=config.torch_dtype, - ) + if _is_cpu: + self.norm = Qwen3NextRMSNormGated( + self.head_v_dim, eps=self.layer_norm_epsilon + ) + self.norm = RMSNormGated( + self.head_v_dim, + eps=self.layer_norm_epsilon, + group_size=None, + norm_before_gate=True, + device=torch.get_device_module().current_device(), + dtype=config.torch_dtype, + ) self.out_proj = RowParallelLinear( self.value_dim, @@ -429,6 +438,17 @@ def forward( self.head_k_dim, self.head_v_dim, ) + elif _is_cpu and _is_amx_available: + mixed_qkv, z, b, a = ( + torch.ops.sgl_kernel.fused_qkvzba_split_reshape_cat_cpu( + projected_states_qkvz, + projected_states_ba, + self.num_k_heads // self.attn_tp_size, + self.num_v_heads // self.attn_tp_size, + self.head_k_dim, + self.head_v_dim, + ) + ) else: query, key, value, z, b, a = self.fix_query_key_value_ordering( projected_states_qkvz, projected_states_ba From 7d7fa12e6673fdbefa8bffcea1b5cd179e846ffd Mon Sep 17 00:00:00 2001 From: jianan-gu Date: Mon, 3 Nov 2025 10:42:27 +0800 Subject: [PATCH 02/20] minor fix --- python/sglang/srt/models/qwen3_next.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/sglang/srt/models/qwen3_next.py b/python/sglang/srt/models/qwen3_next.py index 0aa32648e69a..d5e380d78ce3 100644 --- a/python/sglang/srt/models/qwen3_next.py +++ 
b/python/sglang/srt/models/qwen3_next.py @@ -336,6 +336,7 @@ def __init__( self.norm = Qwen3NextRMSNormGated( self.head_v_dim, eps=self.layer_norm_epsilon ) + else: self.norm = RMSNormGated( self.head_v_dim, eps=self.layer_norm_epsilon, From b1472a19c1816d5f2be564a3a6d3373e63de9a2a Mon Sep 17 00:00:00 2001 From: jianan-gu Date: Sun, 2 Nov 2025 21:48:58 -0500 Subject: [PATCH 03/20] memory pool changes for amx conv --- python/sglang/srt/mem_cache/memory_pool.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index d9047d1ce47d..33177ef9de55 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -49,7 +49,9 @@ set_mla_kv_scale_buffer_triton, ) from sglang.srt.utils import ( + cpu_has_amx_support, get_bool_env_var, + is_cpu, is_cuda, is_float4_e2m1fn_x2, is_npu, @@ -66,6 +68,8 @@ GB = 1024 * 1024 * 1024 _is_cuda = is_cuda() _is_npu = is_npu() +_is_cpu = is_cpu() +_cpu_has_amx_support = cpu_has_amx_support() if _is_npu: import torch_npu @@ -192,12 +196,23 @@ def __init__( ] else: # assume conv_state = (dim, state_len) + assert conv_state_shape[0] > conv_state_shape[1] conv_state = torch.zeros( size=(num_mamba_layers, size + 1) + conv_state_shape, dtype=conv_dtype, device=device, ) + if _is_cpu and _cpu_has_amx_support: + conv_state = conv_state.as_strided_( + conv_state.size(), + ( + conv_state.stride(0), + conv_state.stride(1), + 1, + conv_state.size(2), + ), + ) temporal_state = torch.zeros( size=(num_mamba_layers, size + 1) + temporal_state_shape, dtype=ssm_dtype, From 6be8b132db6bdb8feb8a8edf80f1d58fb7ebb515 Mon Sep 17 00:00:00 2001 From: jianan-gu Date: Fri, 31 Oct 2025 03:45:32 -0400 Subject: [PATCH 04/20] add TP padding for qwen3-next on CPU --- python/sglang/srt/configs/qwen3_next.py | 13 ++++++ python/sglang/srt/configs/update_config.py | 25 +++++++++++ .../srt/layers/attention/mamba/mamba.py | 41 
++++++++++++++++++- .../sglang/srt/model_loader/weight_utils.py | 24 +++++++++-- python/sglang/srt/models/qwen3_next.py | 12 +++++- 5 files changed, 109 insertions(+), 6 deletions(-) diff --git a/python/sglang/srt/configs/qwen3_next.py b/python/sglang/srt/configs/qwen3_next.py index cd1b6f1ea59a..347b5ee015c6 100644 --- a/python/sglang/srt/configs/qwen3_next.py +++ b/python/sglang/srt/configs/qwen3_next.py @@ -21,6 +21,7 @@ from transformers.utils import logging from sglang.srt.configs.mamba_utils import Mamba2CacheParams, Mamba2StateShape +from sglang.srt.utils import is_cpu logger = logging.get_logger(__name__) @@ -278,6 +279,18 @@ def full_attention_layer_ids(self): def mamba2_cache_params(self) -> Mamba2CacheParams: from sglang.srt.layers.dp_attention import get_attention_tp_size + world_size = get_attention_tp_size() + if is_cpu() and ( + self.linear_num_key_heads % world_size != 0 + or self.linear_num_value_heads % world_size != 0 + ): + pad_size = world_size + groups = self.linear_num_value_heads // self.linear_num_key_heads + self.linear_num_key_heads = ( + (self.linear_num_key_heads + pad_size - 1) // pad_size + ) * pad_size + self.linear_num_value_heads = self.linear_num_key_heads * groups + shape = Mamba2StateShape.create( tp_world_size=get_attention_tp_size(), intermediate_size=self.linear_value_head_dim * self.linear_num_value_heads, diff --git a/python/sglang/srt/configs/update_config.py b/python/sglang/srt/configs/update_config.py index abbd724fb141..aa259b6375f0 100644 --- a/python/sglang/srt/configs/update_config.py +++ b/python/sglang/srt/configs/update_config.py @@ -119,6 +119,28 @@ def adjust_config_with_unaligned_cpu_tp( model_config.hf_config.num_attention_heads = num_attention_heads model_config.hf_text_config.num_attention_heads = num_attention_heads + # Linear attn padding logic + if ( + model_config.hf_config.linear_num_key_heads % tp_size != 0 + or model_config.hf_config.linear_num_value_heads % tp_size != 0 + ): + pad_size = 
get_num_heads_padding_size(tp_size, weight_block_size) + model_config.hf_config.linear_num_key_heads_cpu = pad_vocab_size( + model_config.hf_config.linear_num_key_heads, pad_size + ) + model_config.hf_config.linear_num_value_heads_cpu = ( + model_config.hf_config.linear_num_key_heads_cpu + * model_config.hf_config.linear_num_value_heads + // model_config.hf_config.linear_num_key_heads + ) + else: + model_config.hf_config.linear_num_key_heads_cpu = ( + model_config.hf_config.linear_num_key_heads + ) + model_config.hf_config.linear_num_value_heads_cpu = ( + model_config.hf_config.linear_num_value_heads + ) + intermediate_padding_size = tp_size * get_moe_padding_size(weight_block_size) model_config = update_intermediate_size( model_config, "moe_intermediate_size", intermediate_padding_size @@ -129,6 +151,9 @@ def adjust_config_with_unaligned_cpu_tp( model_config = update_intermediate_size( model_config, "intermediate_size_mlp", intermediate_padding_size ) + model_config = update_intermediate_size( + model_config, "shared_expert_intermediate_size", intermediate_padding_size + ) if ( hasattr(model_config.hf_config, "vision_config") and model_config.hf_config.vision_config.model_type == "siglip_vision_model" diff --git a/python/sglang/srt/layers/attention/mamba/mamba.py b/python/sglang/srt/layers/attention/mamba/mamba.py index 2520284899ee..7954dc8c9e20 100644 --- a/python/sglang/srt/layers/attention/mamba/mamba.py +++ b/python/sglang/srt/layers/attention/mamba/mamba.py @@ -30,7 +30,7 @@ composed_weight_loader, sharded_weight_loader, ) -from sglang.srt.utils import is_cuda, is_npu, set_weight_attrs +from sglang.srt.utils import is_cpu, is_cuda, is_npu, set_weight_attrs if is_cuda(): from sglang.srt.layers.attention.mamba.causal_conv1d import ( @@ -70,6 +70,18 @@ def loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None: # - track boundary of (sharded) param, and loaded_weight, respectively boundary, loaded_boundary = 0, 0 + # Calculate padding size for CPU 
when TP odd size + full_dim_sum = 0 + full_dim_list = [] + weight_full_dim_list = [] + for full_dim, _, _ in shard_spec: + full_dim_sum = full_dim_sum + full_dim + full_dim_list.append(full_dim) + for full_dim in full_dim_list: + weight_full_dim_list.append( + int(full_dim / full_dim_sum * loaded_weight.size(0)) + ) + # - iterate over the shard specs for full_dim, extra, duplicate_groups in shard_spec: # - full dim is the model dim (before TP). @@ -96,6 +108,33 @@ def loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None: # - take these many dims from the loaded weight. take = min(shard_size, full_dim - extra - loaded_skip) + # CPU logic of padding size for qwen3-next + # TODO : make this common for all mamba. + if is_cpu() and loaded_weight.size(0) % tp_size != 0: + import copy + + loaded_weight_ = copy.deepcopy(loaded_weight) + q, k, v = torch.split( + loaded_weight_, + weight_full_dim_list, + dim=0, + ) + pad_qk = torch.zeros( + full_dim_list[0] - weight_full_dim_list[0], + loaded_weight.size(1), + loaded_weight.size(2), + ).to(loaded_weight.dtype) + pad_v = torch.zeros( + full_dim_list[2] - weight_full_dim_list[2], + loaded_weight.size(1), + loaded_weight.size(2), + ).to(loaded_weight.dtype) + q = torch.cat((q, pad_qk), dim=0) + k = torch.cat((k, pad_qk), dim=0) + v = torch.cat((v, pad_v), dim=0) + loaded_weight_qk = torch.cat((q, k), dim=0) + loaded_weight = torch.cat((loaded_weight_qk, v), dim=0) + # - always shard on dim 0 # - the ignore is for a mundane mypy error as it does not # seem to handle slices well. 
diff --git a/python/sglang/srt/model_loader/weight_utils.py b/python/sglang/srt/model_loader/weight_utils.py index a7b987e110f4..bdfdfb852d0a 100644 --- a/python/sglang/srt/model_loader/weight_utils.py +++ b/python/sglang/srt/model_loader/weight_utils.py @@ -35,13 +35,18 @@ from sglang.srt.configs.load_config import LoadConfig from sglang.srt.configs.model_config import ModelConfig from sglang.srt.distributed import get_tensor_model_parallel_rank -from sglang.srt.layers.dp_attention import get_attention_tp_rank +from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size from sglang.srt.layers.quantization import QuantizationConfig, get_quantization_config from sglang.srt.layers.quantization.modelopt_quant import ( ModelOptFp4Config, ModelOptFp8Config, ) -from sglang.srt.utils import find_local_repo_dir, log_info_on_rank0, print_warning_once +from sglang.srt.utils import ( + find_local_repo_dir, + is_cpu, + log_info_on_rank0, + print_warning_once, +) from sglang.utils import is_in_ci logger = logging.getLogger(__name__) @@ -855,10 +860,23 @@ def sharded_weight_loader(shard_axis: int) -> LoaderFunction: def loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None: tp_rank = get_attention_tp_rank() + tp_size = get_attention_tp_size() shard_size = param.data.shape[shard_axis] start_idx = tp_rank * shard_size - loaded_weight = loaded_weight.narrow(shard_axis, start_idx, shard_size) + if is_cpu() and loaded_weight.size(0) % tp_size != 0: + param.data, loaded_weight = narrow_padded_param_and_loaded_weight( + param.data, + loaded_weight, + 0, # param_data_start + start_idx, + shard_axis, + shard_size, + ) + if loaded_weight.size(0) == 32: + print(loaded_weight) + else: + loaded_weight = loaded_weight.narrow(shard_axis, start_idx, shard_size) return default_weight_loader(param, loaded_weight) diff --git a/python/sglang/srt/models/qwen3_next.py b/python/sglang/srt/models/qwen3_next.py index d5e380d78ce3..0922a2771231 100644 --- 
a/python/sglang/srt/models/qwen3_next.py +++ b/python/sglang/srt/models/qwen3_next.py @@ -254,8 +254,16 @@ def __init__( self.attn_tp_rank = get_attention_tp_rank() self.attn_tp_size = get_attention_tp_size() self.hidden_size = config.hidden_size - self.num_v_heads = config.linear_num_value_heads - self.num_k_heads = config.linear_num_key_heads + self.num_v_heads = ( + config.linear_num_value_heads + if not _is_cpu + else config.linear_num_value_heads_cpu + ) + self.num_k_heads = ( + config.linear_num_key_heads + if not _is_cpu + else config.linear_num_key_heads_cpu + ) self.head_k_dim = config.linear_key_head_dim self.head_v_dim = config.linear_value_head_dim self.key_dim = self.head_k_dim * self.num_k_heads From 13571bd0202038fd6d21749b3e576a06a7ba0586 Mon Sep 17 00:00:00 2001 From: jianan-gu Date: Tue, 18 Nov 2025 21:39:45 -0500 Subject: [PATCH 05/20] fix lint --- python/sglang/srt/mem_cache/memory_pool.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index 447f302b5ea3..d7d659a38008 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -192,7 +192,7 @@ def __init__( ) for conv_shape in conv_state_shape ] - + if _is_cpu and _cpu_has_amx_support: conv_state = conv_state.as_strided_( conv_state.size(), From 0d1559d1796167487b1572cc35058b0c103f9638 Mon Sep 17 00:00:00 2001 From: jianan-gu Date: Fri, 5 Dec 2025 02:30:01 -0500 Subject: [PATCH 06/20] rebase with latest kernels --- python/sglang/srt/configs/update_config.py | 7 +- python/sglang/srt/layers/amx_utils.py | 18 ++- .../attention/hybrid_linear_attn_backend.py | 123 ++++++++++++------ .../srt/layers/attention/intel_amx_backend.py | 11 +- python/sglang/srt/layers/layernorm.py | 4 +- python/sglang/srt/mem_cache/memory_pool.py | 22 ++-- .../sglang/srt/model_loader/weight_utils.py | 18 ++- python/sglang/srt/models/qwen3_next.py | 21 +-- 
sgl-kernel/csrc/cpu/topk.cpp | 6 + 9 files changed, 159 insertions(+), 71 deletions(-) diff --git a/python/sglang/srt/configs/update_config.py b/python/sglang/srt/configs/update_config.py index 200c07a2e2dc..d35d5de85ffe 100644 --- a/python/sglang/srt/configs/update_config.py +++ b/python/sglang/srt/configs/update_config.py @@ -142,7 +142,12 @@ def adjust_config_with_unaligned_cpu_tp( model_config.hf_config.linear_num_key_heads % tp_size != 0 or model_config.hf_config.linear_num_value_heads % tp_size != 0 ): - pad_size = get_num_heads_padding_size(tp_size, weight_block_size) + head_dim = ( + model_config.hf_config.qk_head_dim + if hasattr(model_config.hf_config, "qk_head_dim") + else model_config.hf_config.head_dim + ) + pad_size = get_num_heads_padding_size(tp_size, weight_block_size, head_dim) model_config.hf_config.linear_num_key_heads_cpu = pad_vocab_size( model_config.hf_config.linear_num_key_heads, pad_size ) diff --git a/python/sglang/srt/layers/amx_utils.py b/python/sglang/srt/layers/amx_utils.py index 95b50d64ca0c..29c7129001f1 100644 --- a/python/sglang/srt/layers/amx_utils.py +++ b/python/sglang/srt/layers/amx_utils.py @@ -34,8 +34,17 @@ def dim_is_supported(weight): return is_oc_support and is_ic_support +def dtype_is_supported(weight): + return weight.dtype in [ + torch.float16, + torch.bfloat16, + torch.int8, + torch.float8_e4m3fn, + ] + + def is_dim_conv_weight(weight): - weight.dim() == 3 and weight.size(1) == 1 + return weight.dim() == 3 and weight.size(1) == 1 def _amx_process_weight_after_loading( @@ -58,9 +67,12 @@ def _amx_process_weight_after_loading( weight_tensor = weight_tensor.transpose(*transpose_dims[i]) is_conv_weight = is_dim_conv_weight(weight_tensor) # We don't pack weight or use intel amx backend if any weight of this module has unsupported dim. 
- if (not dim_is_supported(weight_tensor)) and (not is_conv_weight): + if ( + (not dim_is_supported(weight_tensor)) + or not dtype_is_supported(weight_tensor) + ) and (not is_conv_weight): logger.warning( - f"Unsupported dimension for prepacking for weight '{weight_name}' with shape {weight_tensor.shape} in {module}. " + f"Unsupported dimension or dtype for prepacking for weight '{weight_name}' with shape {weight_tensor.shape} and dtype {weight_tensor.dtype} in {module}. " f"The derived (OC, IC) dimensions must be divisible by (16, 32). " ) module.use_intel_amx_backend = False diff --git a/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py b/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py index 4323ac24f851..67267b9c5f83 100644 --- a/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py +++ b/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py @@ -4,7 +4,6 @@ from einops import rearrange from sglang.srt.layers.attention.base_attn_backend import AttentionBackend -from sglang.srt.layers.attention.fla.chunk import chunk_gated_delta_rule from sglang.srt.layers.attention.fla.fused_gdn_gating import fused_gdn_gating from sglang.srt.layers.attention.fla.fused_recurrent import ( fused_recurrent_gated_delta_rule_update, @@ -12,11 +11,6 @@ from sglang.srt.layers.attention.fla.fused_sigmoid_gating_recurrent import ( fused_sigmoid_gating_delta_rule_update, ) -from sglang.srt.layers.attention.fla.kda import ( - chunk_kda, - fused_kda_gate, - fused_recurrent_kda, -) from sglang.srt.layers.attention.mamba.causal_conv1d_triton import ( PAD_SLOT_ID, causal_conv1d_fn, @@ -35,6 +29,15 @@ from sglang.srt.speculative.spec_info import SpecInput from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_npu +if not is_cpu(): + # fix import error on CPU device, no impacts when non-CPU path + from sglang.srt.layers.attention.fla.chunk import chunk_gated_delta_rule + from sglang.srt.layers.attention.fla.kda import ( + 
chunk_kda, + fused_kda_gate, + fused_recurrent_kda, + ) + if is_cuda(): from sglang.srt.layers.attention.mamba.causal_conv1d import ( causal_conv1d_fn as causal_conv1d_fn_cuda, @@ -560,14 +563,27 @@ def forward_decode( query_start_loc = self.forward_metadata.query_start_loc cache_indices = self.forward_metadata.mamba_cache_indices - mixed_qkv = causal_conv1d_update( - mixed_qkv, - conv_states, - conv_weights, - bias, - activation, - conv_state_indices=cache_indices, - ) + if is_cpu() and cpu_has_amx_support(): + mixed_qkv = causal_conv1d_update( + mixed_qkv, + conv_states, + conv_weights, + bias, + activation == "silu", + None, + cache_indices, + self.pad_slot_id, + True, + ) + else: + mixed_qkv = causal_conv1d_update( + mixed_qkv, + conv_states, + conv_weights, + bias, + activation, + conv_state_indices=cache_indices, + ) query, key, value = torch.split( mixed_qkv, @@ -674,17 +690,31 @@ def forward_extend( ) mixed_qkv = mixed_qkv_processed.transpose(1, 2).view(seq_len, -1) else: - mixed_qkv = causal_conv1d_fn( - mixed_qkv.transpose(0, 1), - conv_weights, - bias, - activation=activation, - conv_states=conv_states, - has_initial_state=has_initial_states, - cache_indices=cache_indices, - query_start_loc=query_start_loc, - seq_lens_cpu=forward_batch.extend_seq_lens_cpu, - ).transpose(0, 1)[:seq_len] + if is_cpu() and cpu_has_amx_support(): + mixed_qkv = causal_conv1d_fn( + mixed_qkv.transpose(0, 1), + conv_weights, + bias, + conv_states, + query_start_loc, + cache_indices, + has_initial_states, + activation == "silu", + self.pad_slot_id, + True, + ).transpose(0, 1)[:seq_len] + else: + mixed_qkv = causal_conv1d_fn( + mixed_qkv.transpose(0, 1), + conv_weights, + bias, + activation=activation, + conv_states=conv_states, + has_initial_state=has_initial_states, + cache_indices=cache_indices, + query_start_loc=query_start_loc, + seq_lens_cpu=forward_batch.extend_seq_lens_cpu, + ).transpose(0, 1)[:seq_len] key_split_dim = key_dim // attn_tp_size value_split_dim = value_dim 
// attn_tp_size @@ -702,7 +732,6 @@ def forward_extend( query = query.view(1, actual_seq_len, num_heads, head_k_dim) key = key.view(1, actual_seq_len, num_heads, head_k_dim) value = value.view(1, actual_seq_len, num_value_heads, head_v_dim) - g, beta = fused_gdn_gating(A_log, a, b, dt_bias) if is_target_verify: @@ -723,19 +752,35 @@ def forward_extend( ) else: recurrent_state = ssm_states[cache_indices] - core_attn_out, last_recurrent_state = chunk_gated_delta_rule( - q=query, - k=key, - v=value, - g=g, - beta=beta, - initial_state=recurrent_state, - output_final_state=True, - cu_seqlens=query_start_loc, - head_first=False, - use_qk_l2norm_in_kernel=True, - ) - last_recurrent_state = last_recurrent_state.to(ssm_states.dtype, copy=False) + if is_cpu() and cpu_has_amx_support(): + core_attn_out, last_recurrent_state = chunk_gated_delta_rule( + query=query, + key=key, + value=value, + g=g, + beta=beta, + initial_state=recurrent_state, + output_final_state=True, + cu_seqlens=query_start_loc, + head_first=False, + use_qk_l2norm_in_kernel=True, + ) + else: + core_attn_out, last_recurrent_state = chunk_gated_delta_rule( + q=query, + k=key, + v=value, + g=g, + beta=beta, + initial_state=recurrent_state, + output_final_state=True, + cu_seqlens=query_start_loc, + head_first=False, + use_qk_l2norm_in_kernel=True, + ) + last_recurrent_state = last_recurrent_state.to( + ssm_states.dtype, copy=False + ) ssm_states[cache_indices] = last_recurrent_state return core_attn_out diff --git a/python/sglang/srt/layers/attention/intel_amx_backend.py b/python/sglang/srt/layers/attention/intel_amx_backend.py index 4b2974c44e0d..8a1a8ec3b317 100644 --- a/python/sglang/srt/layers/attention/intel_amx_backend.py +++ b/python/sglang/srt/layers/attention/intel_amx_backend.py @@ -24,8 +24,15 @@ def __init__(self, model_runner: ModelRunner): model_runner.model_config.num_attention_heads // model_runner.tp_size ) - self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1] - + # 
self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1] + layer_id = 0 + if hasattr(model_runner.token_to_kv_pool, "full_attention_layer_id_mapping"): + layer_id = [*model_runner.token_to_kv_pool.full_attention_layer_id_mapping][ + 0 + ] + self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer( + layer_id + ).shape[-1] self.decode_attention_fwd = torch.ops.sgl_kernel.decode_attention_cpu self.extend_attention_fwd = torch.ops.sgl_kernel.extend_attention_cpu diff --git a/python/sglang/srt/layers/layernorm.py b/python/sglang/srt/layers/layernorm.py index a6bf4fb9ddaf..131147b3069c 100644 --- a/python/sglang/srt/layers/layernorm.py +++ b/python/sglang/srt/layers/layernorm.py @@ -463,7 +463,7 @@ def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.eps}" -class Qwen3NextRMSNormGated(nn.Module): +class Qwen3NextRMSNormGated(CustomOp): def __init__(self, hidden_size, eps=1e-6, **kwargs): super().__init__() self.weight = nn.Parameter(torch.ones(hidden_size)) @@ -480,7 +480,7 @@ def forward_native(self, hidden_states, gate=None): return hidden_states.to(input_dtype) def forward_cpu(self, hidden_states, gate=None): - out = torch.ops.sgl_kernel.qwen3_next_rmsnorm_gated_cpu( + out = torch.ops.sgl_kernel.fused_rmsnorm_gated_cpu( hidden_states, self.weight, gate, self.variance_epsilon ) return out diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index f2a9fd2019cc..6e236ccd86d0 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -195,15 +195,19 @@ def __init__( ] if _is_cpu and _cpu_has_amx_support: - conv_state = conv_state.as_strided_( - conv_state.size(), - ( - conv_state.stride(0), - conv_state.stride(1), - 1, - conv_state.size(2), - ), - ) + conv_state_cpu = [] + for conv_shape_t in conv_state: + conv_shape_new = conv_shape_t.as_strided_( + conv_shape_t.size(), + ( + conv_shape_t.stride(0), + conv_shape_t.stride(1), + 1, 
+ conv_shape_t.size(2), + ), + ) + conv_state_cpu.append(conv_shape_new) + conv_state = conv_state_cpu temporal_state = torch.zeros( size=(num_mamba_layers, size + 1) + temporal_state_shape, dtype=ssm_dtype, diff --git a/python/sglang/srt/model_loader/weight_utils.py b/python/sglang/srt/model_loader/weight_utils.py index 8d2408c5ef42..4b17ebad25af 100644 --- a/python/sglang/srt/model_loader/weight_utils.py +++ b/python/sglang/srt/model_loader/weight_utils.py @@ -894,21 +894,25 @@ def loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None: shard_size = param.data.shape[shard_axis] start_idx = tp_rank * shard_size - if is_cpu() and loaded_weight.size(0) % tp_size != 0: - param.data, loaded_weight = narrow_padded_param_and_loaded_weight( - param.data, + + if ( + is_cpu() + and loaded_weight.size(0) % tp_size != 0 + and loaded_weight.dim() == 1 + ): + param_data = param.data # view copy on param for uneven padding + param_data, loaded_weight = narrow_padded_param_and_loaded_weight( + param_data, loaded_weight, 0, # param_data_start start_idx, shard_axis, shard_size, ) - if loaded_weight.size(0) == 32: - print(loaded_weight) + return default_weight_loader(param_data, loaded_weight) else: loaded_weight = loaded_weight.narrow(shard_axis, start_idx, shard_size) - - return default_weight_loader(param, loaded_weight) + return default_weight_loader(param, loaded_weight) return loader diff --git a/python/sglang/srt/models/qwen3_next.py b/python/sglang/srt/models/qwen3_next.py index b4561f542aa7..3c848f692588 100644 --- a/python/sglang/srt/models/qwen3_next.py +++ b/python/sglang/srt/models/qwen3_next.py @@ -297,7 +297,6 @@ def __init__( ).uniform_(0, 16) self.A_log = nn.Parameter(torch.log(A)) self.A_log._no_weight_decay = True - set_weight_attrs(self.A_log, {"weight_loader": sharded_weight_loader(0)}) set_weight_attrs(self.dt_bias, {"weight_loader": sharded_weight_loader(0)}) if _is_cpu: @@ -378,7 +377,7 @@ def _forward_input_proj(self, hidden_states: torch.Tensor): 
DUAL_STREAM_TOKEN_THRESHOLD = 1024 seq_len, _ = hidden_states.shape - if seq_len < DUAL_STREAM_TOKEN_THRESHOLD: + if seq_len < DUAL_STREAM_TOKEN_THRESHOLD and not _is_cpu: current_stream = torch.cuda.current_stream() self.alt_stream.wait_stream(current_stream) projected_states_qkvz, _ = self.in_proj_qkvz(hidden_states) @@ -418,7 +417,11 @@ def _forward( hidden_states ) - if self.num_v_heads // self.num_k_heads in [1, 2, 4] and is_cuda_graph: + if ( + self.num_v_heads // self.num_k_heads in [1, 2, 4] + and is_cuda_graph + and not _is_cpu + ): mixed_qkv, z, b, a = fused_qkvzba_split_reshape_cat( projected_states_qkvz, projected_states_ba, @@ -447,11 +450,13 @@ def _forward( ) mixed_qkv = torch.cat((query, key, value), dim=-1) # mixed_qkv = rearrange(mixed_qkv, "b l d -> b d l") - - # 2. Convolution sequence transformation - conv_weights = self.conv1d.weight.view( - self.conv1d.weight.size(0), self.conv1d.weight.size(2) - ) + if _is_cpu: + conv_weights = self.conv1d.weight + else: + # 2. Convolution sequence transformation + conv_weights = self.conv1d.weight.view( + self.conv1d.weight.size(0), self.conv1d.weight.size(2) + ) kwargs = { "mixed_qkv": mixed_qkv, diff --git a/sgl-kernel/csrc/cpu/topk.cpp b/sgl-kernel/csrc/cpu/topk.cpp index b4bdcd0b7b37..0471661e58a7 100644 --- a/sgl-kernel/csrc/cpu/topk.cpp +++ b/sgl-kernel/csrc/cpu/topk.cpp @@ -525,6 +525,12 @@ topk_softmax_cpu(at::Tensor& hidden_states, at::Tensor& gating_output, int64_t t case 256: LAUNCH_TOPK_SOFTMAX_KERNEL(256); break; + case 384: + LAUNCH_TOPK_SOFTMAX_KERNEL(384); + break; + case 512: + LAUNCH_TOPK_SOFTMAX_KERNEL(512); + break; default: TORCH_CHECK(false, "Unexpected num_experts: ", num_experts); } From eeebbb2c2bb98ffc458f1cd1334e15d54e4f6606 Mon Sep 17 00:00:00 2001 From: jianan-gu Date: Mon, 8 Dec 2025 11:06:04 +0800 Subject: [PATCH 07/20] Update python/sglang/srt/layers/attention/intel_amx_backend.py Co-authored-by: Ma Mingfei --- python/sglang/srt/layers/attention/intel_amx_backend.py | 2 +- 1 
file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/layers/attention/intel_amx_backend.py b/python/sglang/srt/layers/attention/intel_amx_backend.py index 8a1a8ec3b317..7a183dc1578d 100644 --- a/python/sglang/srt/layers/attention/intel_amx_backend.py +++ b/python/sglang/srt/layers/attention/intel_amx_backend.py @@ -24,7 +24,7 @@ def __init__(self, model_runner: ModelRunner): model_runner.model_config.num_attention_heads // model_runner.tp_size ) - # self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1] + # [NB]: `layer_id` set to 0 for xxx models, and blah blah blah layer_id = 0 if hasattr(model_runner.token_to_kv_pool, "full_attention_layer_id_mapping"): layer_id = [*model_runner.token_to_kv_pool.full_attention_layer_id_mapping][ From c032b1ed58ac8598826b97c1d68be6a110e7ff71 Mon Sep 17 00:00:00 2001 From: jianan-gu Date: Mon, 8 Dec 2025 00:34:02 -0500 Subject: [PATCH 08/20] refine codes --- python/sglang/srt/configs/qwen3_next.py | 16 ++--- python/sglang/srt/configs/update_config.py | 65 +++++++++++-------- .../layers/attention/fla/layernorm_gated.py | 31 +++++---- .../attention/hybrid_linear_attn_backend.py | 12 ++-- .../srt/layers/attention/intel_amx_backend.py | 3 +- .../srt/layers/attention/mamba/mamba.py | 21 +++--- python/sglang/srt/layers/layernorm.py | 23 ------- python/sglang/srt/models/qwen3_next.py | 28 ++++---- 8 files changed, 97 insertions(+), 102 deletions(-) diff --git a/python/sglang/srt/configs/qwen3_next.py b/python/sglang/srt/configs/qwen3_next.py index 6fe705a357eb..94b74cfcc7a2 100644 --- a/python/sglang/srt/configs/qwen3_next.py +++ b/python/sglang/srt/configs/qwen3_next.py @@ -20,9 +20,11 @@ from transformers.utils import logging from sglang.srt.configs.mamba_utils import Mamba2CacheParams, Mamba2StateShape +from sglang.srt.configs.update_config import adjust_tp_num_heads_if_necessary from sglang.srt.utils import is_cpu logger = logging.get_logger(__name__) +_is_cpu = is_cpu() class 
HybridLayerType(enum.Enum): @@ -277,17 +279,9 @@ def full_attention_layer_ids(self): def mamba2_cache_params(self) -> Mamba2CacheParams: from sglang.srt.layers.dp_attention import get_attention_tp_size - world_size = get_attention_tp_size() - if is_cpu() and ( - self.linear_num_key_heads % world_size != 0 - or self.linear_num_value_heads % world_size != 0 - ): - pad_size = world_size - groups = self.linear_num_value_heads // self.linear_num_key_heads - self.linear_num_key_heads = ( - (self.linear_num_key_heads + pad_size - 1) // pad_size - ) * pad_size - self.linear_num_value_heads = self.linear_num_key_heads * groups + if _is_cpu: + world_size = get_attention_tp_size() + adjust_tp_num_heads_if_necessary(self, world_size, False) shape = Mamba2StateShape.create( tp_world_size=get_attention_tp_size(), diff --git a/python/sglang/srt/configs/update_config.py b/python/sglang/srt/configs/update_config.py index d35d5de85ffe..55bda7787af7 100644 --- a/python/sglang/srt/configs/update_config.py +++ b/python/sglang/srt/configs/update_config.py @@ -54,6 +54,44 @@ def get_num_heads_padding_size(tp_size, weight_block_size, head_dim): return pad_size +def adjust_tp_num_heads_if_necessary(model_config, tp_size, is_post_update): + # is_post_update: whether to update an existing config + from sglang.srt.layers.vocab_parallel_embedding import pad_vocab_size + + # Linear attn check logic + if hasattr(model_config, "linear_num_key_heads") and hasattr( + model_config, "linear_num_value_heads" + ): + if ( + model_config.linear_num_key_heads % tp_size != 0 + or model_config.linear_num_value_heads % tp_size != 0 + ): + pad_size = tp_size + linear_num_key_heads_cpu = pad_vocab_size( + model_config.linear_num_key_heads, pad_size + ) + linear_num_value_heads_cpu = ( + linear_num_key_heads_cpu + * model_config.linear_num_value_heads + // model_config.linear_num_key_heads + ) + if is_post_update: + model_config.linear_num_key_heads_cpu = linear_num_key_heads_cpu + 
model_config.linear_num_value_heads_cpu = linear_num_value_heads_cpu + else: + model_config.linear_num_key_heads = linear_num_key_heads_cpu + model_config.linear_num_value_heads = linear_num_value_heads_cpu + + else: + if is_post_update: + model_config.linear_num_key_heads_cpu = ( + model_config.hf_config.linear_num_key_heads + ) + model_config.linear_num_value_heads_cpu = ( + model_config.linear_num_value_heads + ) + + def update_intermediate_size(model_config, attr_name, intermediate_padding_size): attr_value = intermediate_padding_size if hasattr(model_config, "hf_config") and hasattr( @@ -137,32 +175,7 @@ def adjust_config_with_unaligned_cpu_tp( model_config.hf_config.num_attention_heads = num_attention_heads model_config.hf_text_config.num_attention_heads = num_attention_heads - # Linear attn padding logic - if ( - model_config.hf_config.linear_num_key_heads % tp_size != 0 - or model_config.hf_config.linear_num_value_heads % tp_size != 0 - ): - head_dim = ( - model_config.hf_config.qk_head_dim - if hasattr(model_config.hf_config, "qk_head_dim") - else model_config.hf_config.head_dim - ) - pad_size = get_num_heads_padding_size(tp_size, weight_block_size, head_dim) - model_config.hf_config.linear_num_key_heads_cpu = pad_vocab_size( - model_config.hf_config.linear_num_key_heads, pad_size - ) - model_config.hf_config.linear_num_value_heads_cpu = ( - model_config.hf_config.linear_num_key_heads_cpu - * model_config.hf_config.linear_num_value_heads - // model_config.hf_config.linear_num_key_heads - ) - else: - model_config.hf_config.linear_num_key_heads_cpu = ( - model_config.hf_config.linear_num_key_heads - ) - model_config.hf_config.linear_num_value_heads_cpu = ( - model_config.hf_config.linear_num_value_heads - ) + adjust_tp_num_heads_if_necessary(model_config.hf_config, tp_size, True) intermediate_padding_size = tp_size * get_moe_padding_size(weight_block_size) model_config = update_intermediate_size( diff --git 
a/python/sglang/srt/layers/attention/fla/layernorm_gated.py b/python/sglang/srt/layers/attention/fla/layernorm_gated.py index ed64cc438d00..69387629fe9a 100644 --- a/python/sglang/srt/layers/attention/fla/layernorm_gated.py +++ b/python/sglang/srt/layers/attention/fla/layernorm_gated.py @@ -12,9 +12,10 @@ import triton.language as tl from einops import rearrange -from sglang.srt.utils import device_context, is_npu +from sglang.srt.utils import cpu_has_amx_support, device_context, is_cpu, is_npu _is_npu = is_npu() +_use_cpu = is_cpu() and cpu_has_amx_support() def rms_norm_ref( @@ -338,13 +339,21 @@ def reset_parameters(self): def forward(self, x, z=None): """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))""" - return layernorm_fn( - x, - self.weight, - self.bias, - z=z, - eps=self.eps, - group_size=self.group_size, - norm_before_gate=self.norm_before_gate, - is_rms_norm=True, - ) + if _use_cpu: + assert ( + self.norm_before_gate and self.group_size is None + ), "CPU rmsnorm_gated currently only supports norm before gate without group size" + return torch.ops.sgl_kernel.fused_rmsnorm_gated_cpu( + x, self.weight, z, self.eps + ) + else: + return layernorm_fn( + x, + self.weight, + self.bias, + z=z, + eps=self.eps, + group_size=self.group_size, + norm_before_gate=self.norm_before_gate, + is_rms_norm=True, + ) diff --git a/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py b/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py index 4c312793d5dd..ed970620dcc8 100644 --- a/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py +++ b/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py @@ -58,7 +58,11 @@ fused_sigmoid_gating_delta_rule_update = fused_sigmoid_gating_delta_rule_update_npu causal_conv1d_fn = causal_conv1d_fn_npu causal_conv1d_update = causal_conv1d_update_npu -elif is_cpu() and cpu_has_amx_support(): +elif is_cpu(): + assert ( + cpu_has_amx_support() + ), "CPU requires AMX 
support for hybrid linear attn backend" + _use_cpu = True chunk_gated_delta_rule = torch.ops.sgl_kernel.chunk_gated_delta_rule_cpu causal_conv1d_fn = torch.ops.sgl_kernel.causal_conv1d_fwd_cpu causal_conv1d_update = torch.ops.sgl_kernel.causal_conv1d_update_cpu @@ -563,7 +567,7 @@ def forward_decode( query_start_loc = self.forward_metadata.query_start_loc cache_indices = self.forward_metadata.mamba_cache_indices - if is_cpu() and cpu_has_amx_support(): + if _use_cpu: mixed_qkv = causal_conv1d_update( mixed_qkv, conv_states, @@ -690,7 +694,7 @@ def forward_extend( ) mixed_qkv = mixed_qkv_processed.transpose(1, 2).view(seq_len, -1) else: - if is_cpu() and cpu_has_amx_support(): + if _use_cpu: mixed_qkv = causal_conv1d_fn( mixed_qkv.transpose(0, 1), conv_weights, @@ -752,7 +756,7 @@ def forward_extend( ) else: recurrent_state = ssm_states[cache_indices] - if is_cpu() and cpu_has_amx_support(): + if _use_cpu: core_attn_out, last_recurrent_state = chunk_gated_delta_rule( query=query, key=key, diff --git a/python/sglang/srt/layers/attention/intel_amx_backend.py b/python/sglang/srt/layers/attention/intel_amx_backend.py index 7a183dc1578d..7ab2753741c3 100644 --- a/python/sglang/srt/layers/attention/intel_amx_backend.py +++ b/python/sglang/srt/layers/attention/intel_amx_backend.py @@ -24,7 +24,8 @@ def __init__(self, model_runner: ModelRunner): model_runner.model_config.num_attention_heads // model_runner.tp_size ) - # [NB]: `layer_id` set to 0 for xxx models, and blah blah blah + # [NB]: `layer_id` set to 0 for qwen3-next models, as not all attn layers require kv pool + # using "full_attention_layer_id_mapping" to map which layer needs kv pool layer_id = 0 if hasattr(model_runner.token_to_kv_pool, "full_attention_layer_id_mapping"): layer_id = [*model_runner.token_to_kv_pool.full_attention_layer_id_mapping][ diff --git a/python/sglang/srt/layers/attention/mamba/mamba.py b/python/sglang/srt/layers/attention/mamba/mamba.py index 4529b1682f88..550ee171e18e 100644 --- 
a/python/sglang/srt/layers/attention/mamba/mamba.py +++ b/python/sglang/srt/layers/attention/mamba/mamba.py @@ -70,16 +70,17 @@ def loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None: boundary, loaded_boundary = 0, 0 # Calculate padding size for CPU when TP odd size - full_dim_sum = 0 - full_dim_list = [] - weight_full_dim_list = [] - for full_dim, _, _ in shard_spec: - full_dim_sum = full_dim_sum + full_dim - full_dim_list.append(full_dim) - for full_dim in full_dim_list: - weight_full_dim_list.append( - int(full_dim / full_dim_sum * loaded_weight.size(0)) - ) + if is_cpu(): + full_dim_sum = 0 + full_dim_list = [] + weight_full_dim_list = [] + for full_dim, _, _ in shard_spec: + full_dim_sum = full_dim_sum + full_dim + full_dim_list.append(full_dim) + for full_dim in full_dim_list: + weight_full_dim_list.append( + int(full_dim / full_dim_sum * loaded_weight.size(0)) + ) # - iterate over the shard specs for full_dim, extra, duplicate_groups in shard_spec: diff --git a/python/sglang/srt/layers/layernorm.py b/python/sglang/srt/layers/layernorm.py index 131147b3069c..3293a8a59b50 100644 --- a/python/sglang/srt/layers/layernorm.py +++ b/python/sglang/srt/layers/layernorm.py @@ -463,29 +463,6 @@ def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.eps}" -class Qwen3NextRMSNormGated(CustomOp): - def __init__(self, hidden_size, eps=1e-6, **kwargs): - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - - def forward_native(self, hidden_states, gate=None): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - # Norm before gate - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - hidden_states = self.weight * hidden_states.to(input_dtype) - hidden_states = hidden_states * F.silu(gate.to(torch.float32)) - return hidden_states.to(input_dtype) - - def forward_cpu(self, 
hidden_states, gate=None): - out = torch.ops.sgl_kernel.fused_rmsnorm_gated_cpu( - hidden_states, self.weight, gate, self.variance_epsilon - ) - return out - - if not ( _is_cuda or _is_hip or _is_npu or (_is_cpu and _is_cpu_amx_available) or _is_xpu ): diff --git a/python/sglang/srt/models/qwen3_next.py b/python/sglang/srt/models/qwen3_next.py index 3c848f692588..43c07cbfdd27 100644 --- a/python/sglang/srt/models/qwen3_next.py +++ b/python/sglang/srt/models/qwen3_next.py @@ -18,7 +18,7 @@ get_attention_tp_size, is_dp_attention_enabled, ) -from sglang.srt.layers.layernorm import GemmaRMSNorm, Qwen3NextRMSNormGated +from sglang.srt.layers.layernorm import GemmaRMSNorm from sglang.srt.layers.linear import ( ColumnParallelLinear, QKVParallelLinear, @@ -299,19 +299,15 @@ def __init__( self.A_log._no_weight_decay = True set_weight_attrs(self.A_log, {"weight_loader": sharded_weight_loader(0)}) set_weight_attrs(self.dt_bias, {"weight_loader": sharded_weight_loader(0)}) - if _is_cpu: - self.norm = Qwen3NextRMSNormGated( - self.head_v_dim, eps=self.layer_norm_epsilon - ) - else: - self.norm = RMSNormGated( - self.head_v_dim, - eps=self.layer_norm_epsilon, - group_size=None, - norm_before_gate=True, - device=torch.get_device_module().current_device(), - dtype=config.torch_dtype, - ) + + self.norm = RMSNormGated( + self.head_v_dim, + eps=self.layer_norm_epsilon, + group_size=None, + norm_before_gate=True, + device=torch.get_device_module().current_device(), + dtype=config.torch_dtype, + ) self.out_proj = RowParallelLinear( self.value_dim, @@ -371,13 +367,13 @@ def fix_query_key_value_ordering(self, mixed_qkvz, mixed_ba): return query, key, value, z, b, a def _forward_input_proj(self, hidden_states: torch.Tensor): - if _is_npu or get_global_server_args().enable_piecewise_cuda_graph: + if _is_cpu or _is_npu or get_global_server_args().enable_piecewise_cuda_graph: DUAL_STREAM_TOKEN_THRESHOLD = 0 else: DUAL_STREAM_TOKEN_THRESHOLD = 1024 seq_len, _ = hidden_states.shape - if 
seq_len < DUAL_STREAM_TOKEN_THRESHOLD and not _is_cpu: + if seq_len < DUAL_STREAM_TOKEN_THRESHOLD: current_stream = torch.cuda.current_stream() self.alt_stream.wait_stream(current_stream) projected_states_qkvz, _ = self.in_proj_qkvz(hidden_states) From 8a6dae6bce68378fb5a6e53866640a160b896d66 Mon Sep 17 00:00:00 2001 From: jianan-gu Date: Mon, 8 Dec 2025 19:58:33 -0500 Subject: [PATCH 09/20] minor fix after rebase --- python/sglang/srt/configs/update_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/configs/update_config.py b/python/sglang/srt/configs/update_config.py index 55bda7787af7..f357a7635608 100644 --- a/python/sglang/srt/configs/update_config.py +++ b/python/sglang/srt/configs/update_config.py @@ -85,7 +85,7 @@ def adjust_tp_num_heads_if_necessary(model_config, tp_size, is_post_update): else: if is_post_update: model_config.linear_num_key_heads_cpu = ( - model_config.hf_config.linear_num_key_heads + model_config.linear_num_key_heads ) model_config.linear_num_value_heads_cpu = ( model_config.linear_num_value_heads From 742ea2660b6535f2c9a9e56ae82d6fd9cfc76fc5 Mon Sep 17 00:00:00 2001 From: jianan-gu Date: Tue, 9 Dec 2025 07:46:10 -0500 Subject: [PATCH 10/20] refine mamba apis --- .../attention/hybrid_linear_attn_backend.py | 122 ++++++------------ python/sglang/srt/mem_cache/memory_pool.py | 1 + sgl-kernel/python/sgl_kernel/__init__.py | 8 +- sgl-kernel/python/sgl_kernel/mamba.py | 70 ++++++++++ 4 files changed, 121 insertions(+), 80 deletions(-) diff --git a/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py b/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py index ed970620dcc8..2d78e7cbe5bc 100644 --- a/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py +++ b/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py @@ -62,10 +62,15 @@ assert ( cpu_has_amx_support() ), "CPU requires AMX support for hybrid linear attn backend" - _use_cpu = True - 
chunk_gated_delta_rule = torch.ops.sgl_kernel.chunk_gated_delta_rule_cpu - causal_conv1d_fn = torch.ops.sgl_kernel.causal_conv1d_fwd_cpu - causal_conv1d_update = torch.ops.sgl_kernel.causal_conv1d_update_cpu + from sgl_kernel.mamba import ( + causal_conv1d_fn_cpu, + causal_conv1d_update_cpu, + chunk_gated_delta_rule_cpu, + ) + + chunk_gated_delta_rule = chunk_gated_delta_rule_cpu + causal_conv1d_fn = causal_conv1d_fn_cpu + causal_conv1d_update = causal_conv1d_update_cpu fused_sigmoid_gating_delta_rule_update = ( torch.ops.sgl_kernel.fused_sigmoid_gating_delta_rule_update_cpu ) @@ -567,27 +572,14 @@ def forward_decode( query_start_loc = self.forward_metadata.query_start_loc cache_indices = self.forward_metadata.mamba_cache_indices - if _use_cpu: - mixed_qkv = causal_conv1d_update( - mixed_qkv, - conv_states, - conv_weights, - bias, - activation == "silu", - None, - cache_indices, - self.pad_slot_id, - True, - ) - else: - mixed_qkv = causal_conv1d_update( - mixed_qkv, - conv_states, - conv_weights, - bias, - activation, - conv_state_indices=cache_indices, - ) + mixed_qkv = causal_conv1d_update( + mixed_qkv, + conv_states, + conv_weights, + bias, + activation, + conv_state_indices=cache_indices, + ) query, key, value = torch.split( mixed_qkv, @@ -694,31 +686,18 @@ def forward_extend( ) mixed_qkv = mixed_qkv_processed.transpose(1, 2).view(seq_len, -1) else: - if _use_cpu: - mixed_qkv = causal_conv1d_fn( - mixed_qkv.transpose(0, 1), - conv_weights, - bias, - conv_states, - query_start_loc, - cache_indices, - has_initial_states, - activation == "silu", - self.pad_slot_id, - True, - ).transpose(0, 1)[:seq_len] - else: - mixed_qkv = causal_conv1d_fn( - mixed_qkv.transpose(0, 1), - conv_weights, - bias, - activation=activation, - conv_states=conv_states, - has_initial_state=has_initial_states, - cache_indices=cache_indices, - query_start_loc=query_start_loc, - seq_lens_cpu=forward_batch.extend_seq_lens_cpu, - ).transpose(0, 1)[:seq_len] + + mixed_qkv = causal_conv1d_fn( + 
mixed_qkv.transpose(0, 1), + conv_weights, + bias, + activation=activation, + conv_states=conv_states, + has_initial_state=has_initial_states, + cache_indices=cache_indices, + query_start_loc=query_start_loc, + seq_lens_cpu=forward_batch.extend_seq_lens_cpu, + ).transpose(0, 1)[:seq_len] key_split_dim = key_dim // attn_tp_size value_split_dim = value_dim // attn_tp_size @@ -756,35 +735,20 @@ def forward_extend( ) else: recurrent_state = ssm_states[cache_indices] - if _use_cpu: - core_attn_out, last_recurrent_state = chunk_gated_delta_rule( - query=query, - key=key, - value=value, - g=g, - beta=beta, - initial_state=recurrent_state, - output_final_state=True, - cu_seqlens=query_start_loc, - head_first=False, - use_qk_l2norm_in_kernel=True, - ) - else: - core_attn_out, last_recurrent_state = chunk_gated_delta_rule( - q=query, - k=key, - v=value, - g=g, - beta=beta, - initial_state=recurrent_state, - output_final_state=True, - cu_seqlens=query_start_loc, - head_first=False, - use_qk_l2norm_in_kernel=True, - ) - last_recurrent_state = last_recurrent_state.to( - ssm_states.dtype, copy=False - ) + + core_attn_out, last_recurrent_state = chunk_gated_delta_rule( + q=query, + k=key, + v=value, + g=g, + beta=beta, + initial_state=recurrent_state, + output_final_state=True, + cu_seqlens=query_start_loc, + head_first=False, + use_qk_l2norm_in_kernel=True, + ) + last_recurrent_state = last_recurrent_state.to(ssm_states.dtype, copy=False) ssm_states[cache_indices] = last_recurrent_state return core_attn_out diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index 6e236ccd86d0..c02037a672d3 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -195,6 +195,7 @@ def __init__( ] if _is_cpu and _cpu_has_amx_support: + # CPU uses a different layout of conv_state for kernel optimization conv_state_cpu = [] for conv_shape_t in conv_state: conv_shape_new = conv_shape_t.as_strided_( 
diff --git a/sgl-kernel/python/sgl_kernel/__init__.py b/sgl-kernel/python/sgl_kernel/__init__.py index 8e8994e04c95..a24d3573be4a 100644 --- a/sgl-kernel/python/sgl_kernel/__init__.py +++ b/sgl-kernel/python/sgl_kernel/__init__.py @@ -78,7 +78,13 @@ transfer_kv_per_layer, transfer_kv_per_layer_mla, ) -from sgl_kernel.mamba import causal_conv1d_fwd, causal_conv1d_update +from sgl_kernel.mamba import ( + causal_conv1d_fn_cpu, + causal_conv1d_fwd, + causal_conv1d_update, + causal_conv1d_update_cpu, + chunk_gated_delta_rule_cpu, +) from sgl_kernel.marlin import ( awq_marlin_moe_repack, awq_marlin_repack, diff --git a/sgl-kernel/python/sgl_kernel/mamba.py b/sgl-kernel/python/sgl_kernel/mamba.py index 85aa5b9479e1..a3bb0c613832 100644 --- a/sgl-kernel/python/sgl_kernel/mamba.py +++ b/sgl-kernel/python/sgl_kernel/mamba.py @@ -48,3 +48,73 @@ def causal_conv1d_update( conv_state_indices, pad_slot_id, ) + + +def causal_conv1d_fn_cpu( + mixed_qkv_transposed, + conv_weights, + bias, + activation, + conv_states, + has_initial_state, + cache_indices, + query_start_loc, + seq_lens_cpu, +): + return torch.ops.sgl_kernel.causal_conv1d_fwd_cpu( + mixed_qkv_transposed, + conv_weights, + bias, + conv_states, + query_start_loc, + cache_indices, + has_initial_state, + activation == "silu", + -1, + True, + ) + + +def causal_conv1d_update_cpu( + mixed_qkv, conv_states, conv_weights, bias, activation, conv_state_indices +): + return torch.ops.sgl_kernel.causal_conv1d_update_cpu( + mixed_qkv, + conv_states, + conv_weights, + bias, + activation == "silu", + None, + conv_state_indices, + -1, + True, + ) + + +def chunk_gated_delta_rule_cpu( + q, + k, + v, + g, + beta, + initial_state, + output_final_state, + cu_seqlens, + head_first, + use_qk_l2norm_in_kernel, +): + core_attn_out, last_recurrent_state = ( + torch.ops.sgl_kernel.chunk_gated_delta_rule_cpu( + q, + k, + v, + g, + beta, + initial_state, + output_final_state, + cu_seqlens, + head_first, + use_qk_l2norm_in_kernel, + ) + ) + return 
core_attn_out, last_recurrent_state From c40ae9d15053f796dc2dbc74873fd6c8c87b6f7b Mon Sep 17 00:00:00 2001 From: jianan-gu Date: Mon, 15 Dec 2025 02:32:02 -0500 Subject: [PATCH 11/20] adjust mamba cache after rebase --- .../layers/attention/hybrid_linear_attn_backend.py | 14 +++++++++----- python/sglang/srt/utils/common.py | 3 +++ sgl-kernel/python/sgl_kernel/mamba.py | 2 +- 3 files changed, 13 insertions(+), 6 deletions(-) diff --git a/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py b/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py index f08f37990ade..571e26f6efd3 100644 --- a/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py +++ b/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py @@ -34,7 +34,9 @@ if not is_cpu(): # fix import error on CPU device, no impacts when non-CPU path from sglang.srt.layers.attention.fla.chunk import chunk_gated_delta_rule - from sglang.srt.layers.attention.fla.chunk_delta_h import CHUNK_SIZE as FLA_CHUNK_SIZE + from sglang.srt.layers.attention.fla.chunk_delta_h import ( + CHUNK_SIZE as FLA_CHUNK_SIZE, + ) from sglang.srt.layers.attention.fla.kda import ( chunk_kda, fused_kda_gate, @@ -65,6 +67,7 @@ assert ( cpu_has_amx_support() ), "CPU requires AMX support for hybrid linear attn backend" + _is_cpu = True from sgl_kernel.mamba import ( causal_conv1d_fn_cpu, causal_conv1d_update_cpu, @@ -862,9 +865,10 @@ def __init__(self, model_runner: ModelRunner): self.conv_states_shape = ( model_runner.req_to_token_pool.mamba_pool.mamba_cache.conv[0].shape ) - assert ( - self.conv_states_shape[-1] < FLA_CHUNK_SIZE - ), f"{self.conv_states_shape[-1]=} should be less than {FLA_CHUNK_SIZE}" + if not _is_cpu: + assert ( + self.conv_states_shape[-1] < FLA_CHUNK_SIZE + ), f"{self.conv_states_shape[-1]=} should be less than {FLA_CHUNK_SIZE}" def forward_decode( self, @@ -1098,7 +1102,7 @@ def forward_extend( ) last_recurrent_state = last_recurrent_state.to(ssm_states.dtype, copy=False) 
ssm_states[cache_indices] = last_recurrent_state - + self._track_mamba_state_extend( forward_batch, h, ssm_states, forward_metadata ) diff --git a/python/sglang/srt/utils/common.py b/python/sglang/srt/utils/common.py index b4fef6ed61d7..8642ab432a43 100644 --- a/python/sglang/srt/utils/common.py +++ b/python/sglang/srt/utils/common.py @@ -1826,6 +1826,9 @@ def get_device_name(device_id: int = 0) -> str: if hasattr(torch, "npu") and torch.npu.is_available(): return torch.npu.get_device_name(device_id) + if hasattr(torch, "cpu") and torch.cpu.is_available(): + return torch.cpu.current_device() + @lru_cache(maxsize=1) def is_habana_available() -> bool: diff --git a/sgl-kernel/python/sgl_kernel/mamba.py b/sgl-kernel/python/sgl_kernel/mamba.py index 01ffa12b2127..14dd75b06dbc 100644 --- a/sgl-kernel/python/sgl_kernel/mamba.py +++ b/sgl-kernel/python/sgl_kernel/mamba.py @@ -117,5 +117,5 @@ def chunk_gated_delta_rule_cpu( use_qk_l2norm_in_kernel, ) ) - h = None # todo: add return h support + h = None # todo: add return h support return core_attn_out, last_recurrent_state, h From c777aa36e96cd3e81fcf4edc02b1a66e071508e0 Mon Sep 17 00:00:00 2001 From: jianan-gu Date: Mon, 15 Dec 2025 02:35:53 -0500 Subject: [PATCH 12/20] minor refinements --- .../sglang/srt/layers/attention/hybrid_linear_attn_backend.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py b/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py index 571e26f6efd3..cc5b2ed793a2 100644 --- a/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py +++ b/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py @@ -67,7 +67,6 @@ assert ( cpu_has_amx_support() ), "CPU requires AMX support for hybrid linear attn backend" - _is_cpu = True from sgl_kernel.mamba import ( causal_conv1d_fn_cpu, causal_conv1d_update_cpu, @@ -865,7 +864,7 @@ def __init__(self, model_runner: ModelRunner): self.conv_states_shape = ( 
model_runner.req_to_token_pool.mamba_pool.mamba_cache.conv[0].shape ) - if not _is_cpu: + if not is_cpu(): assert ( self.conv_states_shape[-1] < FLA_CHUNK_SIZE ), f"{self.conv_states_shape[-1]=} should be less than {FLA_CHUNK_SIZE}" From cbf8adb6c886eac67abc211ebbc93033b8dd1849 Mon Sep 17 00:00:00 2001 From: jianan-gu Date: Mon, 15 Dec 2025 02:48:52 -0500 Subject: [PATCH 13/20] final minor refinements --- python/sglang/srt/server_args.py | 2 +- python/sglang/srt/utils/common.py | 3 --- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index a8b141477fd3..50890bb66b6c 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -1429,7 +1429,7 @@ def _handle_model_specific_adjustments(self): # for models with explicit support (DeepseekV3, GptOss, Glm4Moe, Qwen3Moe) # TODO: currently, it is only supported in the single node scenario. https://github.com/flashinfer-ai/flashinfer/issues/2006 # TODO: there is currently a bug on H20 device specifically, https://github.com/flashinfer-ai/flashinfer/issues/2204 - device_name = get_device_name() + device_name = get_device_name() if self.device is not "cpu" else "CPU" is_h20_device = "H20" in device_name and "H200" not in device_name if ( not self.enable_flashinfer_allreduce_fusion diff --git a/python/sglang/srt/utils/common.py b/python/sglang/srt/utils/common.py index 8642ab432a43..b4fef6ed61d7 100644 --- a/python/sglang/srt/utils/common.py +++ b/python/sglang/srt/utils/common.py @@ -1826,9 +1826,6 @@ def get_device_name(device_id: int = 0) -> str: if hasattr(torch, "npu") and torch.npu.is_available(): return torch.npu.get_device_name(device_id) - if hasattr(torch, "cpu") and torch.cpu.is_available(): - return torch.cpu.current_device() - @lru_cache(maxsize=1) def is_habana_available() -> bool: From b98753054baef7848d45341dacb3b16d284a47ea Mon Sep 17 00:00:00 2001 From: jianan-gu Date: Mon, 15 Dec 2025 02:56:37 -0500 
Subject: [PATCH 14/20] format --- python/sglang/srt/server_args.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 50890bb66b6c..2851c16e7fd7 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -1429,7 +1429,7 @@ def _handle_model_specific_adjustments(self): # for models with explicit support (DeepseekV3, GptOss, Glm4Moe, Qwen3Moe) # TODO: currently, it is only supported in the single node scenario. https://github.com/flashinfer-ai/flashinfer/issues/2006 # TODO: there is currently a bug on H20 device specifically, https://github.com/flashinfer-ai/flashinfer/issues/2204 - device_name = get_device_name() if self.device is not "cpu" else "CPU" + device_name = get_device_name() if self.device != "cpu" else "CPU" is_h20_device = "H20" in device_name and "H200" not in device_name if ( not self.enable_flashinfer_allreduce_fusion From 9520d90afeb5036b2bd5e037a428e00fe7e0c65c Mon Sep 17 00:00:00 2001 From: jianan-gu Date: Mon, 12 Jan 2026 21:52:40 -0500 Subject: [PATCH 15/20] rebase api --- sgl-kernel/python/sgl_kernel/mamba.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/sgl-kernel/python/sgl_kernel/mamba.py b/sgl-kernel/python/sgl_kernel/mamba.py index 14dd75b06dbc..4f011c3fe690 100644 --- a/sgl-kernel/python/sgl_kernel/mamba.py +++ b/sgl-kernel/python/sgl_kernel/mamba.py @@ -98,11 +98,12 @@ def chunk_gated_delta_rule_cpu( g, beta, initial_state, - output_final_state, cu_seqlens, head_first, use_qk_l2norm_in_kernel, + initial_state_indices, ): + recurrent_state = initial_state[initial_state_indices] core_attn_out, last_recurrent_state = ( torch.ops.sgl_kernel.chunk_gated_delta_rule_cpu( q, @@ -110,12 +111,12 @@ def chunk_gated_delta_rule_cpu( v, g, beta, - initial_state, - output_final_state, + recurrent_state, + True, # output_final_state cu_seqlens, head_first, use_qk_l2norm_in_kernel, ) ) - h = None # todo: add 
return h support + h = None # Todo: add return h support return core_attn_out, last_recurrent_state, h From e008c0ab13891db9635b6a5c9654529ccf794a43 Mon Sep 17 00:00:00 2001 From: jianan-gu Date: Mon, 12 Jan 2026 22:18:55 -0500 Subject: [PATCH 16/20] refine api --- .../sglang/srt/layers/attention/hybrid_linear_attn_backend.py | 4 ++-- sgl-kernel/python/sgl_kernel/mamba.py | 4 +--- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py b/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py index 6b74962ff6f7..f26f3b3b0e16 100644 --- a/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py +++ b/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py @@ -1086,7 +1086,7 @@ def forward_extend( # Only cuda env uses fuse ssm_states update recurrent_state = ssm_states recurrent_state_indices_args = {"initial_state_indices": cache_indices} - if is_npu(): + if is_npu() or is_cpu(): recurrent_state = ssm_states[cache_indices] recurrent_state_indices_args = {} core_attn_out, last_recurrent_state, h = chunk_gated_delta_rule( @@ -1101,7 +1101,7 @@ def forward_extend( use_qk_l2norm_in_kernel=True, **recurrent_state_indices_args, ) - if is_npu(): + if is_npu() or is_cpu(): last_recurrent_state = last_recurrent_state.to( ssm_states.dtype, copy=False ) diff --git a/sgl-kernel/python/sgl_kernel/mamba.py b/sgl-kernel/python/sgl_kernel/mamba.py index 4f011c3fe690..a9ffbfcb5418 100644 --- a/sgl-kernel/python/sgl_kernel/mamba.py +++ b/sgl-kernel/python/sgl_kernel/mamba.py @@ -101,9 +101,7 @@ def chunk_gated_delta_rule_cpu( cu_seqlens, head_first, use_qk_l2norm_in_kernel, - initial_state_indices, ): - recurrent_state = initial_state[initial_state_indices] core_attn_out, last_recurrent_state = ( torch.ops.sgl_kernel.chunk_gated_delta_rule_cpu( q, @@ -111,7 +109,7 @@ def chunk_gated_delta_rule_cpu( v, g, beta, - recurrent_state, + initial_state, True, # output_final_state cu_seqlens, 
head_first, From 9cc7b92ccc9631daa732ddf25553d00ba18c5aa9 Mon Sep 17 00:00:00 2001 From: jianan-gu Date: Sun, 18 Jan 2026 20:35:24 -0500 Subject: [PATCH 17/20] format after rebase --- .../sglang/srt/layers/attention/hybrid_linear_attn_backend.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py b/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py index 558283e69615..03134cfa5ba2 100644 --- a/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py +++ b/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py @@ -5,7 +5,6 @@ import triton.language as tl from einops import rearrange -from sglang.jit_kernel.cutedsl_gdn import cutedsl_fused_sigmoid_gating_delta_rule_update from sglang.srt.environ import Envs from sglang.srt.layers.attention.base_attn_backend import AttentionBackend from sglang.srt.layers.attention.fla.fused_gdn_gating import fused_gdn_gating @@ -37,6 +36,9 @@ if not is_cpu(): # fix import error on CPU device, no impacts when non-CPU path + from sglang.jit_kernel.cutedsl_gdn import ( + cutedsl_fused_sigmoid_gating_delta_rule_update, + ) from sglang.srt.layers.attention.fla.chunk import chunk_gated_delta_rule from sglang.srt.layers.attention.fla.chunk_delta_h import ( CHUNK_SIZE as FLA_CHUNK_SIZE, From 2c5309a3b72df48ac64cccf3bc8d7dc6a9ecc584 Mon Sep 17 00:00:00 2001 From: jianan-gu Date: Tue, 20 Jan 2026 00:13:18 -0500 Subject: [PATCH 18/20] minor refinements --- python/sglang/srt/model_loader/weight_utils.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/model_loader/weight_utils.py b/python/sglang/srt/model_loader/weight_utils.py index 24472c5a2521..e22acc2b2312 100644 --- a/python/sglang/srt/model_loader/weight_utils.py +++ b/python/sglang/srt/model_loader/weight_utils.py @@ -33,8 +33,11 @@ from sglang.srt.configs.load_config import LoadConfig from sglang.srt.configs.model_config import 
ModelConfig -from sglang.srt.distributed import get_tensor_model_parallel_rank -from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size +from sglang.srt.distributed import ( + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) +from sglang.srt.layers.dp_attention import get_attention_tp_rank from sglang.srt.layers.quantization import QuantizationConfig, get_quantization_config from sglang.srt.layers.quantization.modelopt_quant import ( ModelOptFp4Config, @@ -997,14 +1000,13 @@ def sharded_weight_loader(shard_axis: int) -> LoaderFunction: def loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None: tp_rank = get_attention_tp_rank() - tp_size = get_attention_tp_size() shard_size = param.data.shape[shard_axis] start_idx = tp_rank * shard_size if ( is_cpu() - and loaded_weight.size(0) % tp_size != 0 + and loaded_weight.size(0) % get_tensor_model_parallel_world_size() != 0 and loaded_weight.dim() == 1 ): param_data = param.data # view copy on param for uneven padding From 1a97f90d48f7cd6cdae45f5a2f0b069ff2c9ffdc Mon Sep 17 00:00:00 2001 From: jianan-gu Date: Tue, 20 Jan 2026 23:54:22 -0500 Subject: [PATCH 19/20] refinements per reviews --- python/sglang/srt/layers/amx_utils.py | 17 +++++++++++++++++ python/sglang/srt/mem_cache/memory_pool.py | 17 ++++------------- python/sglang/srt/models/qwen3_next.py | 2 -- 3 files changed, 21 insertions(+), 15 deletions(-) diff --git a/python/sglang/srt/layers/amx_utils.py b/python/sglang/srt/layers/amx_utils.py index 29c7129001f1..9fffa7f89670 100644 --- a/python/sglang/srt/layers/amx_utils.py +++ b/python/sglang/srt/layers/amx_utils.py @@ -47,6 +47,23 @@ def is_dim_conv_weight(weight): return weight.dim() == 3 and weight.size(1) == 1 +def _init_amx_conv_state(conv_state): + # CPU AMX layout for conv_state kernel optimization + conv_state_cpu = [] + for conv_shape_t in conv_state: + conv_shape_new = conv_shape_t.as_strided_( + conv_shape_t.size(), + ( + 
conv_shape_t.stride(0), + conv_shape_t.stride(1), + 1, + conv_shape_t.size(2), + ), + ) + conv_state_cpu.append(conv_shape_new) + return conv_state_cpu + + def _amx_process_weight_after_loading( module, weight_names, transpose_dims=None ) -> None: diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index e3f9305055a1..c562c8b185ff 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -241,20 +241,11 @@ def __init__( ] if _is_cpu and _cpu_has_amx_support: + from sglang.srt.layers.amx_utils import _init_amx_conv_state + # CPU uses a different layout of conv_state for kernel optimization - conv_state_cpu = [] - for conv_shape_t in conv_state: - conv_shape_new = conv_shape_t.as_strided_( - conv_shape_t.size(), - ( - conv_shape_t.stride(0), - conv_shape_t.stride(1), - 1, - conv_shape_t.size(2), - ), - ) - conv_state_cpu.append(conv_shape_new) - conv_state = conv_state_cpu + conv_state = _init_amx_conv_state(conv_state) + temporal_state = torch.zeros( size=(num_mamba_layers, size + 1) + temporal_state_shape, dtype=ssm_dtype, diff --git a/python/sglang/srt/models/qwen3_next.py b/python/sglang/srt/models/qwen3_next.py index e0abe13e1428..3be096b62152 100644 --- a/python/sglang/srt/models/qwen3_next.py +++ b/python/sglang/srt/models/qwen3_next.py @@ -59,8 +59,6 @@ _is_npu = is_npu() _is_cpu = is_cpu() _is_amx_available = cpu_has_amx_support() -if _is_cpu and _is_amx_available: - pass import triton From 7be31cb2108333db3c752a6d42eef6a890967a33 Mon Sep 17 00:00:00 2001 From: jianan-gu Date: Thu, 22 Jan 2026 09:44:53 -0500 Subject: [PATCH 20/20] minor refine after rebase --- python/sglang/srt/layers/amx_utils.py | 5 +++++ python/sglang/srt/models/qwen3_next.py | 4 +--- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/layers/amx_utils.py b/python/sglang/srt/layers/amx_utils.py index 9fffa7f89670..1fb704f9517a 100644 --- 
a/python/sglang/srt/layers/amx_utils.py +++ b/python/sglang/srt/layers/amx_utils.py @@ -101,6 +101,11 @@ def _amx_process_weight_after_loading( ) packed_weight.__dict__ = weight_tensor.__dict__ setattr(module, weight_name, packed_weight) + if is_conv_weight: + # Need to use an in-place copy for conv-weight AMX packing, + # because radix_linear_attention still reads the original conv weight tensor. + weight_tensor = weight_tensor.view(-1, weight_tensor.size(-1)) + weight_tensor.copy_(packed_weight) module.use_intel_amx_backend = ( device == torch.device("cpu") and cpu_has_amx_support() diff --git a/python/sglang/srt/models/qwen3_next.py b/python/sglang/srt/models/qwen3_next.py index f5b22ae8261b..329678f0f868 100644 --- a/python/sglang/srt/models/qwen3_next.py +++ b/python/sglang/srt/models/qwen3_next.py @@ -323,9 +323,7 @@ def __init__( head_qk_dim=self.head_k_dim, head_v_dim=self.head_v_dim, attention_tp_size=self.attn_tp_size, - conv_weights=( - self.conv1d.weight if _is_cpu else self.conv1d.weight.squeeze(1) - ), + conv_weights=self.conv1d.weight.squeeze(1), bias=self.conv1d.bias, activation=self.activation, A_log=self.A_log,