Merged

40 commits
c5d2c5e
Support the case where `num_attention_heads` is not divisible by the …
chunyuan-w Mar 11, 2025
0871a3a
pad intermediate_size (#32)
chunyuan-w Apr 2, 2025
d039ce7
FP8: update MoE padding value and support TP=6 (#53)
chunyuan-w Apr 15, 2025
8e34004
fix issue after rebase in commit: not divisible by the TP size
chunyuan-w May 29, 2025
8f2a6ea
fix issue after rebase in commit: MoE padding
chunyuan-w May 29, 2025
3232e35
fix issue after rebase in commit: FP8
chunyuan-w Jun 3, 2025
df59cfa
refactor narrow and reset part into a util function
chunyuan-w Jun 3, 2025
df1a274
use util function in MergedColumnParallelLinear
chunyuan-w Jun 3, 2025
3cc1042
use util function in RowParallelLinear
chunyuan-w Jun 3, 2025
187a099
remove unused import
chunyuan-w Jun 3, 2025
1fb220e
use util func in moe _load_w2
chunyuan-w Jun 3, 2025
407aec8
use util func in moe _load_w13
chunyuan-w Jun 3, 2025
85b7eca
remove unused import in moe
chunyuan-w Jun 3, 2025
91e41b3
use util func in FP8 _ColumnvLLMParameter
chunyuan-w Jun 3, 2025
6c73063
do not overwrite self.data
chunyuan-w Jun 3, 2025
201785f
use util func in FP8 _ColumnvLLMParameter merged column weight
chunyuan-w Jun 3, 2025
ed200db
use util func in FP8 RowvLLMParameter
chunyuan-w Jun 3, 2025
8fa4539
revert unintended change
chunyuan-w Jun 3, 2025
25bb617
fix the case where actual_shard_size = 0
chunyuan-w Jun 3, 2025
0079fcc
add narrow logic in QKVParallelLinear
chunyuan-w Jun 3, 2025
85ea205
only update embedding padding_size for CPU
chunyuan-w Jun 12, 2025
cbe7591
use the use_cpu() API from #6614
chunyuan-w Jun 13, 2025
edc3953
fix FP8 TP=3
chunyuan-w Jun 13, 2025
f8dc8c4
fix the init of weight_block_size
chunyuan-w Jun 16, 2025
e3aefac
change _use_cpu to _is_cpu
chunyuan-w Jun 18, 2025
971db2b
fix head_dim in qwen2 when num_attention_heads is padded (#82)
chunyuan-w Jun 19, 2025
ac3d715
add padding if total_num_kv_heads % tp_size != 0
chunyuan-w Jun 19, 2025
32b711e
llama4: use original num heads and kv heads when permute loaded weight
chunyuan-w Jun 19, 2025
0b6e45d
llama4: add is_cpu() when getting num heads
chunyuan-w Jun 19, 2025
d2f768a
fix format and remove unused import
chunyuan-w Jun 27, 2025
4a0a2cc
move functions to srt/configs/update_config.py and srt/model_loader/w…
chunyuan-w Jul 1, 2025
96f00ba
add update_config.py
chunyuan-w Jul 1, 2025
40be63a
Merge branch 'main' into chunyuan/pr_tp
zhyncs Jul 1, 2025
d09e60f
Merge branch 'main' into chunyuan/pr_tp
zhyncs Jul 1, 2025
21deb40
make original_total_num_kv_heads available if _is_cpu; change attr name
chunyuan-w Jul 1, 2025
ff7a2c6
Merge branch 'main' into chunyuan/pr_tp
zhyncs Jul 3, 2025
a8c32b9
Merge branch 'main' into chunyuan/pr_tp
zhyncs Jul 3, 2025
b4dc9a8
Merge branch 'main' into chunyuan/pr_tp
zhyncs Jul 3, 2025
f5fe517
only call cuda_graph_mem_usage if not _is_cpu
chunyuan-w Jul 3, 2025
4129599
Merge branch 'main' into chunyuan/pr_tp
Alcanderian Jul 3, 2025
76 changes: 65 additions & 11 deletions python/sglang/srt/layers/linear.py
@@ -34,6 +34,7 @@
_process_weight_after_loading,
cpu_has_amx_support,
is_cpu,
narrow_padded_param_and_loaded_weight,
set_weight_attrs,
)

@@ -425,8 +426,22 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
if output_dim is not None and not use_bitsandbytes_4bit:
shard_size = param_data.shape[output_dim]
start_idx = self.tp_rank * shard_size
if not self.use_presharded_weights:
loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)

if _is_cpu:
param_data, loaded_weight = narrow_padded_param_and_loaded_weight(
param_data,
loaded_weight,
0, # param_data_start
start_idx,
output_dim,
shard_size,
not self.use_presharded_weights,
)
else:
if not self.use_presharded_weights:
loaded_weight = loaded_weight.narrow(
output_dim, start_idx, shard_size
)

# Special case for loading scales off disk, which often do not
# have a shape (such as in the case of AutoFP8).
@@ -643,10 +658,25 @@ def weight_loader(

param_data = param_data.narrow(output_dim, shard_offset, shard_size)
start_idx = self.tp_rank * shard_size
# bitsandbytes loads the weights of the specific portion
# no need to narrow here
if not use_bitsandbytes_4bit and not self.use_presharded_weights:
loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)

if _is_cpu:
param_data, loaded_weight = narrow_padded_param_and_loaded_weight(
param_data,
loaded_weight,
0, # param_data_start
start_idx,
output_dim,
shard_size,
not use_bitsandbytes_4bit and not self.use_presharded_weights,
)
else:
# bitsandbytes loads the weights of the specific portion
# no need to narrow here
if not use_bitsandbytes_4bit and not self.use_presharded_weights:
loaded_weight = loaded_weight.narrow(
output_dim, start_idx, shard_size
)

# Special case for AQLM codebooks.
elif is_metadata:
# metadata indicates fixed size concatenated along dim 0
@@ -1111,10 +1141,23 @@ def weight_loader(
shard_id = self.tp_rank // self.num_kv_head_replicas
start_idx = shard_id * shard_size

# bitsandbytes loads the weights of the specific portion
# no need to narrow here
if not use_bitsandbytes_4bit and not self.use_presharded_weights:
loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
if _is_cpu:
param_data, loaded_weight = narrow_padded_param_and_loaded_weight(
param_data,
loaded_weight,
0, # param_data_start
start_idx,
output_dim,
shard_size,
not use_bitsandbytes_4bit and not self.use_presharded_weights,
)
else:
# bitsandbytes loads the weights of the specific portion
# no need to narrow here
if not use_bitsandbytes_4bit and not self.use_presharded_weights:
loaded_weight = loaded_weight.narrow(
output_dim, start_idx, shard_size
)

# Special case for for AQLM codebooks.
elif is_metadata:
@@ -1256,7 +1299,18 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
):
shard_size = param_data.shape[input_dim]
start_idx = self.tp_rank * shard_size
loaded_weight = loaded_weight.narrow(input_dim, start_idx, shard_size)

if _is_cpu:
param_data, loaded_weight = narrow_padded_param_and_loaded_weight(
param_data,
loaded_weight,
0, # param_data_start
start_idx,
input_dim,
shard_size,
)
else:
loaded_weight = loaded_weight.narrow(input_dim, start_idx, shard_size)

# Special case for loading scales off disk, which often do not
# have a shape (such as in the case of AutoFP8).
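Throughout the diff, the CPU branches funnel into the new `narrow_padded_param_and_loaded_weight` helper imported from `sglang.srt.utils`. Its body is not part of this PR excerpt; the block below is a minimal sketch inferred only from the call sites above (argument order: parameter shard, checkpoint weight, parameter start offset, weight start index, dimension, shard size, plus an optional flag for whether the weight still needs narrowing). The real implementation may differ.

```python
import torch


def narrow_padded_param_and_loaded_weight(
    param_data: torch.Tensor,
    loaded_weight: torch.Tensor,
    param_data_start: int,
    start_idx: int,
    dim: int,
    shard_size: int,
    narrow_weight: bool = True,
):
    """Sketch only: align a possibly padded parameter shard with the checkpoint.

    When the full dimension is not divisible by the TP size, the parameter is
    padded, so the last ranks may own a shard that is partly (or entirely)
    padding: the number of real rows to copy can be smaller than `shard_size`,
    or even zero.
    """
    # Real (non-padding) rows covered by this rank's shard of the checkpoint.
    actual_shard_size = min(shard_size, max(loaded_weight.shape[dim] - start_idx, 0))

    if narrow_weight:
        if actual_shard_size > 0:
            loaded_weight = loaded_weight.narrow(dim, start_idx, actual_shard_size)
        else:
            # Fully padded shard: take an empty slice so the copy is a no-op.
            loaded_weight = loaded_weight.narrow(dim, 0, 0)

    # Narrow the parameter to the same number of real rows; any remaining
    # padded rows keep their initialized values.
    param_data = param_data.narrow(dim, param_data_start, actual_shard_size)
    return param_data, loaded_weight
```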
42 changes: 33 additions & 9 deletions python/sglang/srt/layers/moe/fused_moe_triton/layer.py
@@ -24,6 +24,7 @@
get_bool_env_var,
is_cpu,
is_hip,
narrow_padded_param_and_loaded_weight,
set_weight_attrs,
)

@@ -537,11 +538,6 @@ def _load_w13(
# gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
shard_size = expert_data.shape[shard_dim] // 2

if not self.use_presharded_weights:
loaded_weight = loaded_weight.narrow(
shard_dim, shard_size * tp_rank, shard_size
)

# Narrow parameter and load.
# w1, gate_proj: Load into first logical weight of w13.
# w3, up_proj: Load into second logical weight of w13.
@@ -552,7 +548,24 @@
start = shard_size
else:
start = 0
expert_data = expert_data.narrow(shard_dim, start, shard_size)

if _is_cpu:
expert_data, loaded_weight = narrow_padded_param_and_loaded_weight(
expert_data,
loaded_weight,
start,
shard_size * tp_rank,
shard_dim,
shard_size,
not self.use_presharded_weights,
)
else:
if not self.use_presharded_weights:
loaded_weight = loaded_weight.narrow(
shard_dim, shard_size * tp_rank, shard_size
)

expert_data = expert_data.narrow(shard_dim, start, shard_size)
expert_data.copy_(loaded_weight)

def _load_w2(
@@ -569,10 +582,21 @@
# Narrow parameter and load.
shard_size = expert_data.shape[shard_dim]

if not self.use_presharded_weights:
loaded_weight = loaded_weight.narrow(
shard_dim, shard_size * tp_rank, shard_size
if _is_cpu:
expert_data, loaded_weight = narrow_padded_param_and_loaded_weight(
expert_data,
loaded_weight,
0, # param_data_start
shard_size * tp_rank,
shard_dim,
shard_size,
not self.use_presharded_weights,
)
else:
if not self.use_presharded_weights:
loaded_weight = loaded_weight.narrow(
shard_dim, shard_size * tp_rank, shard_size
)

# w2, down_proj: Load into only logical weight of w2.
expert_data.copy_(loaded_weight)
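For the MoE weights the same helper also covers the `actual_shard_size = 0` case fixed in one of the commits: once `intermediate_size` is padded, the shard offsets of the last ranks can run past the real rows in the checkpoint. A purely illustrative calculation with made-up sizes (an intermediate size of 9 padded to 12 for TP=6):

```python
intermediate_size = 9                 # hypothetical checkpoint size
padded_size = 12                      # padded so it divides evenly by tp_size
tp_size = 6
shard_size = padded_size // tp_size   # 2 rows per rank

for tp_rank in range(tp_size):
    start_idx = shard_size * tp_rank
    actual = min(shard_size, max(intermediate_size - start_idx, 0))
    print(f"rank {tp_rank}: start={start_idx}, real rows copied={actual}")
# rank 4 copies only 1 real row and rank 5 copies none; their remaining
# rows stay at the padding value instead of indexing past the checkpoint.
```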
72 changes: 63 additions & 9 deletions python/sglang/srt/layers/parameter.py
@@ -7,6 +7,8 @@
import torch
from torch.nn import Parameter

from sglang.srt.utils import is_cpu

__all__ = [
"BasevLLMParameter",
"PackedvLLMParameter",
@@ -21,6 +23,8 @@

logger = logging.getLogger(__name__)

_is_cpu = is_cpu()


class BasevLLMParameter(Parameter):
"""
@@ -93,9 +97,26 @@ def load_column_parallel_weight(
):
if not use_presharded_weights:
shard_size = self.data.shape[self.output_dim]
loaded_weight = loaded_weight.narrow(
self.output_dim, tp_rank * shard_size, shard_size
)

from sglang.srt.utils import narrow_padded_param_and_loaded_weight

if _is_cpu:
param_data, loaded_weight = narrow_padded_param_and_loaded_weight(
self.data,
loaded_weight,
0, # param_data_start
tp_rank * shard_size,
self.output_dim,
shard_size,
)
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
return
else:
loaded_weight = loaded_weight.narrow(
self.output_dim, tp_rank * shard_size, shard_size
)

assert self.data.shape == loaded_weight.shape
self.data.copy_(loaded_weight)

@@ -116,10 +137,25 @@ def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs):
param_data = self.data

param_data = param_data.narrow(self.output_dim, shard_offset, shard_size)
if not use_presharded_weights:
loaded_weight = loaded_weight.narrow(
self.output_dim, tp_rank * shard_size, shard_size

from sglang.srt.utils import narrow_padded_param_and_loaded_weight

if _is_cpu:
param_data, loaded_weight = narrow_padded_param_and_loaded_weight(
param_data,
loaded_weight,
0, # param_data_start
tp_rank * shard_size,
self.output_dim,
shard_size,
not use_presharded_weights,
)
else:
if not use_presharded_weights:
loaded_weight = loaded_weight.narrow(
self.output_dim, tp_rank * shard_size, shard_size
)

assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)

@@ -182,9 +218,27 @@ def load_row_parallel_weight(
):
if not use_presharded_weights:
shard_size = self.data.shape[self.input_dim]
loaded_weight = loaded_weight.narrow(
self.input_dim, tp_rank * shard_size, shard_size
)

from sglang.srt.utils import narrow_padded_param_and_loaded_weight

if _is_cpu:
param_data, loaded_weight = narrow_padded_param_and_loaded_weight(
self.data,
loaded_weight,
0, # param_data_start
tp_rank * shard_size,
self.input_dim,
shard_size,
)

assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)

return
else:
loaded_weight = loaded_weight.narrow(
self.input_dim, tp_rank * shard_size, shard_size
)

if len(loaded_weight.shape) == 0:
loaded_weight = loaded_weight.reshape(1)
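One fixup commit in this series is titled "do not overwrite self.data": the CPU branches copy into a narrowed view of the parameter instead of rebinding `self.data`, because `Tensor.narrow` returns a view and an in-place `copy_` on that view writes through to the registered parameter's storage. A standalone illustration of that behavior:

```python
import torch

full = torch.zeros(6, 4)          # stands in for a padded parameter
shard = full.narrow(0, 0, 5)      # view over the 5 "real" rows
shard.copy_(torch.ones(5, 4))     # in-place copy writes through the view

assert full[:5].eq(1).all() and full[5:].eq(0).all()
# Rebinding instead (e.g. `self.data = shard`) would shrink the parameter to
# the unpadded shape and break the padded layout the rest of the model
# expects, which is exactly what the fix avoids.
```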
10 changes: 9 additions & 1 deletion python/sglang/srt/layers/vocab_parallel_embedding.py
@@ -250,8 +250,16 @@ def __init__(
self.tp_size = 1

self.num_embeddings = num_embeddings
self.padding_size = padding_size
self.org_vocab_size = org_num_embeddings or num_embeddings

# Support the case where the vocab size is not divisible by the TP size.
if (
_is_cpu
and pad_vocab_size(self.org_vocab_size, padding_size) % self.tp_size != 0
):
padding_size *= self.tp_size
self.padding_size = padding_size

num_added_embeddings = num_embeddings - self.org_vocab_size
self.use_presharded_weights = use_presharded_weights
if use_presharded_weights:
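The embedding change keeps every rank's vocab shard equal-sized: if the vocabulary padded to the default `padding_size` still does not divide evenly by `tp_size`, the padding granularity is scaled by `tp_size` so that it always does. A small check with a Qwen2-sized vocabulary and TP=6 (the helper below is a minimal stand-in for the layer's `pad_vocab_size`):

```python
def pad_vocab_size(vocab_size: int, pad_to: int) -> int:
    # Round up to the next multiple of pad_to.
    return -(-vocab_size // pad_to) * pad_to


org_vocab_size, padding_size, tp_size = 151936, 64, 6

padded = pad_vocab_size(org_vocab_size, padding_size)
print(padded, padded % tp_size)      # 151936 4  -> uneven shards across ranks

if padded % tp_size != 0:
    padding_size *= tp_size          # 64 * 6 = 384
    padded = pad_vocab_size(org_vocab_size, padding_size)
print(padded, padded % tp_size)      # 152064 0  -> every rank gets 25344 rows
```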
7 changes: 6 additions & 1 deletion python/sglang/srt/model_executor/model_runner.py
@@ -113,6 +113,7 @@
monkey_patch_vllm_gguf_config,
set_cpu_offload_max_bytes,
set_cuda_arch,
update_config,
)

_is_hip = is_hip()
@@ -159,7 +160,6 @@ def __init__(
token_to_kv_pool_allocator: Optional[BaseTokenToKVPoolAllocator] = None,
):
# Parse args
self.model_config = model_config
self.mem_fraction_static = mem_fraction_static
self.device = server_args.device
self.gpu_id = gpu_id
@@ -172,6 +172,7 @@
self.dp_size = server_args.dp_size
self.pp_rank = pp_rank
self.pp_size = pp_size
self.model_config = model_config
self.dist_port = nccl_port
self.server_args = server_args
self.is_draft_worker = is_draft_worker
@@ -555,6 +556,10 @@ def load_model(self):
download_dir=self.server_args.download_dir,
model_loader_extra_config=self.server_args.model_loader_extra_config,
)
if self.device == "cpu":
self.model_config = update_config(
self.model_config, self.load_config, self.tp_size
)
if self.server_args.load_format == "gguf":
monkey_patch_vllm_gguf_config()

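`update_config` itself lives in `srt/configs/update_config.py` and is not shown in this excerpt. The sketch below only illustrates the kind of CPU-only adjustment the call site implies (padding head counts to a multiple of the TP size while remembering the original values that the mllama4 loader reads back). Everything beyond the attribute names visible in the diff, including the padding rule and the `hf_config` access, is an assumption rather than the actual implementation:

```python
def _pad_to_multiple(x: int, multiple: int) -> int:
    # Round up to the next multiple.
    return -(-x // multiple) * multiple


def update_config(model_config, load_config, tp_size):
    """Hedged sketch of the CPU-only config padding; not the real function."""
    # load_config is accepted to match the call site; unused in this sketch.
    cfg = model_config.hf_config  # assumption about where the HF config lives

    if cfg.num_attention_heads % tp_size != 0:
        cfg.original_num_attention_heads = cfg.num_attention_heads
        cfg.num_attention_heads = _pad_to_multiple(cfg.num_attention_heads, tp_size)
        # head_dim must still be derived from the original head count
        # (see the qwen2 head_dim fix in this commit series).

    kv_heads = getattr(cfg, "num_key_value_heads", None)
    if kv_heads is not None and kv_heads % tp_size != 0:
        cfg.original_total_num_kv_heads = kv_heads
        cfg.num_key_value_heads = _pad_to_multiple(kv_heads, tp_size)

    return model_config
```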
28 changes: 21 additions & 7 deletions python/sglang/srt/models/mllama4.py
@@ -16,7 +16,9 @@
from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.utils import add_prefix
from sglang.srt.utils import add_prefix, is_cpu

_is_cpu = is_cpu()


class Llama4ForConditionalGeneration(nn.Module):
@@ -110,13 +112,25 @@ def permute(w: torch.Tensor, n_heads: int):

# rotary embeds should be sliced
if ("wk" in modules or "k_proj" in modules) and modules[-1] == "weight":
loaded_weight = permute(
loaded_weight, self.language_model.config.num_key_value_heads
)
if _is_cpu:
dim = getattr(
self.language_model.config,
"original_total_num_kv_heads",
self.language_model.config.num_key_value_heads,
)
else:
dim = self.language_model.config.num_key_value_heads
loaded_weight = permute(loaded_weight, dim)
elif ("wq" in modules or "q_proj" in modules) and modules[-1] == "weight":
loaded_weight = permute(
loaded_weight, self.language_model.config.num_attention_heads
)
if _is_cpu:
dim = getattr(
self.language_model.config,
"original_num_attention_heads",
self.language_model.config.num_attention_heads,
)
else:
dim = self.language_model.config.num_attention_heads
loaded_weight = permute(loaded_weight, dim)

return name, loaded_weight

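The mllama4 change matters because `permute` reshapes the projection weight by head count before re-interleaving the rotary halves; the checkpoint tensor still has the original number of heads, so reshaping by the padded count would not match the tensor's shape. A shape-only illustration with hypothetical sizes, using a simplified permutation of the same flavor (not the exact one in the model file):

```python
import torch

hidden_size, head_dim = 128, 16
orig_kv_heads = 5        # head count stored in the checkpoint (hypothetical)
padded_kv_heads = 6      # padded so it divides evenly by the TP size

wk = torch.randn(orig_kv_heads * head_dim, hidden_size)   # 80 x 128


def permute(w: torch.Tensor, n_heads: int) -> torch.Tensor:
    # Simplified stand-in for the rotary re-interleaving: view by head count,
    # swap the two interleaved halves, and flatten back.
    rows, cols = w.shape
    return (
        w.view(n_heads, rows // n_heads // 2, 2, cols)
        .transpose(1, 2)
        .reshape(rows, cols)
    )


permute(wk, orig_kv_heads)        # works: 80 rows == 5 heads * 16 dims
try:
    permute(wk, padded_kv_heads)  # 80 rows cannot be viewed as 6 padded heads
except RuntimeError as exc:
    print("padded head count fails:", exc)
```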