Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions python/sglang/srt/configs/qwen3_next.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

from sglang.srt.configs.mamba_utils import Mamba2CacheParams, Mamba2StateShape
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.utils import is_cpu

logger = logging.get_logger(__name__)

Expand Down Expand Up @@ -277,6 +278,17 @@ def full_attention_layer_ids(self):

@property
def mamba2_cache_params(self) -> Mamba2CacheParams:
world_size = get_attention_tp_size()
if is_cpu() and (
self.linear_num_key_heads % world_size != 0
or self.linear_num_value_heads % world_size != 0
):
pad_size = world_size
groups = self.linear_num_value_heads // self.linear_num_key_heads
self.linear_num_key_heads = (
(self.linear_num_key_heads + pad_size - 1) // pad_size
) * pad_size
self.linear_num_value_heads = self.linear_num_key_heads * groups
shape = Mamba2StateShape.create(
tp_world_size=get_attention_tp_size(),
intermediate_size=self.linear_value_head_dim * self.linear_num_value_heads,
Expand Down
25 changes: 25 additions & 0 deletions python/sglang/srt/configs/update_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,28 @@ def adjust_config_with_unaligned_cpu_tp(
model_config.hf_config.num_attention_heads = num_attention_heads
model_config.hf_text_config.num_attention_heads = num_attention_heads

# Linear attn padding logic
if (
model_config.hf_config.linear_num_key_heads % tp_size != 0
or model_config.hf_config.linear_num_value_heads % tp_size != 0
):
pad_size = get_num_heads_padding_size(tp_size, weight_block_size)
model_config.hf_config.linear_num_key_heads_cpu = pad_vocab_size(
model_config.hf_config.linear_num_key_heads, pad_size
)
model_config.hf_config.linear_num_value_heads_cpu = (
model_config.hf_config.linear_num_key_heads_cpu
* model_config.hf_config.linear_num_value_heads
// model_config.hf_config.linear_num_key_heads
)
else:
model_config.hf_config.linear_num_key_heads_cpu = (
model_config.hf_config.linear_num_key_heads
)
model_config.hf_config.linear_num_value_heads_cpu = (
model_config.hf_config.linear_num_value_heads
)

intermediate_padding_size = tp_size * get_moe_padding_size(weight_block_size)
model_config = update_intermediate_size(
model_config, "moe_intermediate_size", intermediate_padding_size
Expand All @@ -129,6 +151,9 @@ def adjust_config_with_unaligned_cpu_tp(
model_config = update_intermediate_size(
model_config, "intermediate_size_mlp", intermediate_padding_size
)
model_config = update_intermediate_size(
model_config, "shared_expert_intermediate_size", intermediate_padding_size
)
if (
hasattr(model_config.hf_config, "vision_config")
and model_config.hf_config.vision_config.model_type == "siglip_vision_model"
Expand Down
41 changes: 40 additions & 1 deletion python/sglang/srt/layers/attention/mamba/mamba.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
composed_weight_loader,
sharded_weight_loader,
)
from sglang.srt.utils import is_cuda, is_npu, set_weight_attrs
from sglang.srt.utils import is_cpu, is_cuda, is_npu, set_weight_attrs

if is_cuda():
from sglang.srt.layers.attention.mamba.causal_conv1d import (
Expand Down Expand Up @@ -70,6 +70,18 @@ def loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
# - track boundary of (sharded) param, and loaded_weight, respectively
boundary, loaded_boundary = 0, 0

# Calculate padding size for CPU when TP odd size
full_dim_sum = 0
full_dim_list = []
weight_full_dim_list = []
for full_dim, _, _ in shard_spec:
full_dim_sum = full_dim_sum + full_dim
full_dim_list.append(full_dim)
for full_dim in full_dim_list:
weight_full_dim_list.append(
int(full_dim / full_dim_sum * loaded_weight.size(0))
)

# - iterate over the shard specs
for full_dim, extra, duplicate_groups in shard_spec:
# - full dim is the model dim (before TP).
Expand All @@ -96,6 +108,33 @@ def loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
# - take these many dims from the loaded weight.
take = min(shard_size, full_dim - extra - loaded_skip)

# CPU logic of padding size for qwen3-next
# TODO : make this common for all mamba.
if is_cpu() and loaded_weight.size(0) % tp_size != 0:
import copy

loaded_weight_ = copy.deepcopy(loaded_weight)
q, k, v = torch.split(
loaded_weight_,
weight_full_dim_list,
dim=0,
)
pad_qk = torch.zeros(
full_dim_list[0] - weight_full_dim_list[0],
loaded_weight.size(1),
loaded_weight.size(2),
).to(loaded_weight.dtype)
pad_v = torch.zeros(
full_dim_list[2] - weight_full_dim_list[2],
loaded_weight.size(1),
loaded_weight.size(2),
).to(loaded_weight.dtype)
q = torch.cat((q, pad_qk), dim=0)
k = torch.cat((k, pad_qk), dim=0)
v = torch.cat((v, pad_v), dim=0)
loaded_weight_qk = torch.cat((q, k), dim=0)
loaded_weight = torch.cat((loaded_weight_qk, v), dim=0)

# - always shard on dim 0
# - the ignore is for a mundane mypy error as it does not
# seem to handle slices well.
Expand Down
24 changes: 21 additions & 3 deletions python/sglang/srt/model_loader/weight_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,18 @@
from sglang.srt.configs.load_config import LoadConfig
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.distributed import get_tensor_model_parallel_rank
from sglang.srt.layers.dp_attention import get_attention_tp_rank
from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size
from sglang.srt.layers.quantization import QuantizationConfig, get_quantization_config
from sglang.srt.layers.quantization.modelopt_quant import (
ModelOptFp4Config,
ModelOptFp8Config,
)
from sglang.srt.utils import find_local_repo_dir, log_info_on_rank0, print_warning_once
from sglang.srt.utils import (
find_local_repo_dir,
is_cpu,
log_info_on_rank0,
print_warning_once,
)
from sglang.utils import is_in_ci

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -855,10 +860,23 @@ def sharded_weight_loader(shard_axis: int) -> LoaderFunction:

def loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
    """Shard ``loaded_weight`` along ``shard_axis`` for this attention-TP rank.

    On CPU, when the weight's leading dimension is not divisible by the
    attention TP size, fall back to the padded narrow helper so every rank
    still receives an equal-sized (zero-padded) shard; otherwise a plain
    ``narrow`` on the shard axis suffices.

    Args:
        param: The destination parameter; its shape on ``shard_axis``
            determines the per-rank shard size.
        loaded_weight: The full (unsharded) checkpoint tensor.
    """
    tp_rank = get_attention_tp_rank()
    tp_size = get_attention_tp_size()

    shard_size = param.data.shape[shard_axis]
    start_idx = tp_rank * shard_size

    if is_cpu() and loaded_weight.size(0) % tp_size != 0:
        # Unaligned CPU TP: pad param/weight so each rank gets a full shard.
        param.data, loaded_weight = narrow_padded_param_and_loaded_weight(
            param.data,
            loaded_weight,
            0,  # param_data_start
            start_idx,
            shard_axis,
            shard_size,
        )
    else:
        loaded_weight = loaded_weight.narrow(shard_axis, start_idx, shard_size)

    return default_weight_loader(param, loaded_weight)

Expand Down
15 changes: 12 additions & 3 deletions python/sglang/srt/models/qwen3_next.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
from sglang.srt.utils import (
LazyValue,
add_prefix,
is_cpu,
is_cuda,
is_npu,
make_layers,
Expand All @@ -52,7 +53,7 @@
logger = logging.getLogger(__name__)
_is_cuda = is_cuda()
_is_npu = is_npu()

_is_cpu = is_cpu()
import triton
import triton.language as tl

Expand Down Expand Up @@ -248,8 +249,16 @@ def __init__(
self.attn_tp_rank = get_attention_tp_rank()
self.attn_tp_size = get_attention_tp_size()
self.hidden_size = config.hidden_size
self.num_v_heads = config.linear_num_value_heads
self.num_k_heads = config.linear_num_key_heads
self.num_v_heads = (
config.linear_num_value_heads
if not _is_cpu
else config.linear_num_value_heads_cpu
)
self.num_k_heads = (
config.linear_num_key_heads
if not _is_cpu
else config.linear_num_key_heads_cpu
)
self.head_k_dim = config.linear_key_head_dim
self.head_v_dim = config.linear_value_head_dim
self.key_dim = self.head_k_dim * self.num_k_heads
Expand Down