diff --git a/python/sglang/srt/configs/qwen3_next.py b/python/sglang/srt/configs/qwen3_next.py index 630227a2c625..684939776490 100644 --- a/python/sglang/srt/configs/qwen3_next.py +++ b/python/sglang/srt/configs/qwen3_next.py @@ -22,6 +22,7 @@ from sglang.srt.configs.mamba_utils import Mamba2CacheParams, Mamba2StateShape from sglang.srt.layers.dp_attention import get_attention_tp_size +from sglang.srt.utils import is_cpu logger = logging.get_logger(__name__) @@ -277,6 +278,17 @@ def full_attention_layer_ids(self): @property def mamba2_cache_params(self) -> Mamba2CacheParams: + world_size = get_attention_tp_size() + if is_cpu() and ( + self.linear_num_key_heads % world_size != 0 + or self.linear_num_value_heads % world_size != 0 + ): + pad_size = world_size + groups = self.linear_num_value_heads // self.linear_num_key_heads + self.linear_num_key_heads = ( + (self.linear_num_key_heads + pad_size - 1) // pad_size + ) * pad_size + self.linear_num_value_heads = self.linear_num_key_heads * groups shape = Mamba2StateShape.create( tp_world_size=get_attention_tp_size(), intermediate_size=self.linear_value_head_dim * self.linear_num_value_heads, diff --git a/python/sglang/srt/configs/update_config.py b/python/sglang/srt/configs/update_config.py index abbd724fb141..aa259b6375f0 100644 --- a/python/sglang/srt/configs/update_config.py +++ b/python/sglang/srt/configs/update_config.py @@ -119,6 +119,28 @@ def adjust_config_with_unaligned_cpu_tp( model_config.hf_config.num_attention_heads = num_attention_heads model_config.hf_text_config.num_attention_heads = num_attention_heads + # Linear attn padding logic + if ( + model_config.hf_config.linear_num_key_heads % tp_size != 0 + or model_config.hf_config.linear_num_value_heads % tp_size != 0 + ): + pad_size = get_num_heads_padding_size(tp_size, weight_block_size) + model_config.hf_config.linear_num_key_heads_cpu = pad_vocab_size( + model_config.hf_config.linear_num_key_heads, pad_size + ) + 
model_config.hf_config.linear_num_value_heads_cpu = ( + model_config.hf_config.linear_num_key_heads_cpu + * model_config.hf_config.linear_num_value_heads + // model_config.hf_config.linear_num_key_heads + ) + else: + model_config.hf_config.linear_num_key_heads_cpu = ( + model_config.hf_config.linear_num_key_heads + ) + model_config.hf_config.linear_num_value_heads_cpu = ( + model_config.hf_config.linear_num_value_heads + ) + intermediate_padding_size = tp_size * get_moe_padding_size(weight_block_size) model_config = update_intermediate_size( model_config, "moe_intermediate_size", intermediate_padding_size @@ -129,6 +151,9 @@ def adjust_config_with_unaligned_cpu_tp( model_config = update_intermediate_size( model_config, "intermediate_size_mlp", intermediate_padding_size ) + model_config = update_intermediate_size( + model_config, "shared_expert_intermediate_size", intermediate_padding_size + ) if ( hasattr(model_config.hf_config, "vision_config") and model_config.hf_config.vision_config.model_type == "siglip_vision_model" diff --git a/python/sglang/srt/layers/attention/mamba/mamba.py b/python/sglang/srt/layers/attention/mamba/mamba.py index 2520284899ee..7954dc8c9e20 100644 --- a/python/sglang/srt/layers/attention/mamba/mamba.py +++ b/python/sglang/srt/layers/attention/mamba/mamba.py @@ -30,7 +30,7 @@ composed_weight_loader, sharded_weight_loader, ) -from sglang.srt.utils import is_cuda, is_npu, set_weight_attrs +from sglang.srt.utils import is_cpu, is_cuda, is_npu, set_weight_attrs if is_cuda(): from sglang.srt.layers.attention.mamba.causal_conv1d import ( @@ -70,6 +70,18 @@ def loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None: # - track boundary of (sharded) param, and loaded_weight, respectively boundary, loaded_boundary = 0, 0 + # Calculate padding size for CPU when TP odd size + full_dim_sum = 0 + full_dim_list = [] + weight_full_dim_list = [] + for full_dim, _, _ in shard_spec: + full_dim_sum = full_dim_sum + full_dim + 
full_dim_list.append(full_dim) + for full_dim in full_dim_list: + weight_full_dim_list.append( + int(full_dim / full_dim_sum * loaded_weight.size(0)) + ) + # - iterate over the shard specs for full_dim, extra, duplicate_groups in shard_spec: # - full dim is the model dim (before TP). @@ -96,6 +108,33 @@ def loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None: # - take these many dims from the loaded weight. take = min(shard_size, full_dim - extra - loaded_skip) + # CPU logic of padding size for qwen3-next + # TODO : make this common for all mamba. + if is_cpu() and loaded_weight.size(0) % tp_size != 0: + import copy + + loaded_weight_ = copy.deepcopy(loaded_weight) + q, k, v = torch.split( + loaded_weight_, + weight_full_dim_list, + dim=0, + ) + pad_qk = torch.zeros( + full_dim_list[0] - weight_full_dim_list[0], + loaded_weight.size(1), + loaded_weight.size(2), + ).to(loaded_weight.dtype) + pad_v = torch.zeros( + full_dim_list[2] - weight_full_dim_list[2], + loaded_weight.size(1), + loaded_weight.size(2), + ).to(loaded_weight.dtype) + q = torch.cat((q, pad_qk), dim=0) + k = torch.cat((k, pad_qk), dim=0) + v = torch.cat((v, pad_v), dim=0) + loaded_weight_qk = torch.cat((q, k), dim=0) + loaded_weight = torch.cat((loaded_weight_qk, v), dim=0) + # - always shard on dim 0 # - the ignore is for a mundane mypy error as it does not # seem to handle slices well. 
diff --git a/python/sglang/srt/model_loader/weight_utils.py b/python/sglang/srt/model_loader/weight_utils.py index a7b987e110f4..bdfdfb852d0a 100644 --- a/python/sglang/srt/model_loader/weight_utils.py +++ b/python/sglang/srt/model_loader/weight_utils.py @@ -35,13 +35,18 @@ from sglang.srt.configs.load_config import LoadConfig from sglang.srt.configs.model_config import ModelConfig from sglang.srt.distributed import get_tensor_model_parallel_rank -from sglang.srt.layers.dp_attention import get_attention_tp_rank +from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size from sglang.srt.layers.quantization import QuantizationConfig, get_quantization_config from sglang.srt.layers.quantization.modelopt_quant import ( ModelOptFp4Config, ModelOptFp8Config, ) -from sglang.srt.utils import find_local_repo_dir, log_info_on_rank0, print_warning_once +from sglang.srt.utils import ( + find_local_repo_dir, + is_cpu, + log_info_on_rank0, + print_warning_once, +) from sglang.utils import is_in_ci logger = logging.getLogger(__name__) @@ -855,10 +860,21 @@ def sharded_weight_loader(shard_axis: int) -> LoaderFunction: def loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None: tp_rank = get_attention_tp_rank() + tp_size = get_attention_tp_size() shard_size = param.data.shape[shard_axis] start_idx = tp_rank * shard_size - loaded_weight = loaded_weight.narrow(shard_axis, start_idx, shard_size) + if is_cpu() and loaded_weight.size(0) % tp_size != 0: + param.data, loaded_weight = narrow_padded_param_and_loaded_weight( + param.data, + loaded_weight, + 0, # param_data_start + start_idx, + shard_axis, + shard_size, + ) + else: + loaded_weight = loaded_weight.narrow(shard_axis, start_idx, shard_size) return default_weight_loader(param, loaded_weight) diff --git a/python/sglang/srt/models/qwen3_next.py b/python/sglang/srt/models/qwen3_next.py index d817076fbfad..1827961f371d 100644 ---
a/python/sglang/srt/models/qwen3_next.py +++ b/python/sglang/srt/models/qwen3_next.py @@ -43,6 +43,7 @@ from sglang.srt.utils import ( LazyValue, add_prefix, + is_cpu, is_cuda, is_npu, make_layers, @@ -52,7 +53,7 @@ logger = logging.getLogger(__name__) _is_cuda = is_cuda() _is_npu = is_npu() - +_is_cpu = is_cpu() import triton import triton.language as tl @@ -248,8 +249,16 @@ def __init__( self.attn_tp_rank = get_attention_tp_rank() self.attn_tp_size = get_attention_tp_size() self.hidden_size = config.hidden_size - self.num_v_heads = config.linear_num_value_heads - self.num_k_heads = config.linear_num_key_heads + self.num_v_heads = ( + config.linear_num_value_heads + if not _is_cpu + else config.linear_num_value_heads_cpu + ) + self.num_k_heads = ( + config.linear_num_key_heads + if not _is_cpu + else config.linear_num_key_heads_cpu + ) self.head_k_dim = config.linear_key_head_dim self.head_v_dim = config.linear_value_head_dim self.key_dim = self.head_k_dim * self.num_k_heads