Merged
20 commits
44a270c  Fix DeepSeekV3ForCausalLM.post_load_weights for multiple update weight (zhuzilin, May 13, 2025)
3b6c997  use set instead of list (zhuzilin, May 14, 2025)
76ee056  Merge branch 'main' into feature/fix_deepseek_v2_loader (zhaochenyang20, May 17, 2025)
15751b6  Merge branch 'main' into feature/fix_deepseek_v2_loader (zhaochenyang20, May 20, 2025)
1d1bf24  Merge branch 'main' into feature/fix_deepseek_v2_loader (zhaochenyang20, May 20, 2025)
8a36ff6  Merge branch 'main' into feature/fix_deepseek_v2_loader (zhyncs, May 20, 2025)
af7a806  Merge branch 'main' into feature/fix_deepseek_v2_loader (zhyncs, May 21, 2025)
d6b2c62  Merge branch 'main' into feature/fix_deepseek_v2_loader (zhaochenyang20, May 21, 2025)
f3fa6de  add util function (zhuzilin, May 23, 2025)
7b6f462  fix bug with nextn layer (zhuzilin, May 23, 2025)
d8806a6  Merge branch 'main' into feature/fix_deepseek_v2_loader (zhaochenyang20, May 23, 2025)
32f2e8d  bugfix (zhuzilin, May 23, 2025)
1a26b13  Merge branch 'main' into feature/fix_deepseek_v2_loader (zhyncs, May 23, 2025)
00c9722  Merge branch 'main' into feature/fix_deepseek_v2_loader (zhyncs, May 23, 2025)
fa8236f  bugfix (zhuzilin, May 25, 2025)
9499f94  Merge branch 'main' into feature/fix_deepseek_v2_loader (zhaochenyang20, May 25, 2025)
194b671  Merge branch 'main' into feature/fix_deepseek_v2_loader (zhyncs, May 26, 2025)
d8c047f  bugfix (zhuzilin, May 27, 2025)
1d68017  Merge branch 'main' into feature/fix_deepseek_v2_loader (zhuzilin, May 27, 2025)
3bed57a  Merge branch 'main' into feature/fix_deepseek_v2_loader (zhaochenyang20, May 29, 2025)
python/sglang/srt/models/deepseek_v2.py (53 changes: 39 additions & 14 deletions)

@@ -92,6 +92,7 @@
     BumpAllocator,
     DeepEPMode,
     add_prefix,
+    bind_or_assign,
     get_bool_env_var,
     get_int_env_var,
     is_cuda,

@@ -1713,14 +1714,23 @@ def forward(
            input_ids, hidden_states, self.lm_head, forward_batch
        )

-    def post_load_weights(self, is_nextn=False):
+    def post_load_weights(self, is_nextn=False, weight_names=None):

        # Perform post-processing after loading weights
-        layer_ids = (
-            range(self.config.num_hidden_layers)
-            if not is_nextn
-            else [self.config.num_hidden_layers]
-        )
+        if is_nextn:
+            layer_ids = [self.config.num_hidden_layers]
+        else:
+            if weight_names is None:
+                layer_ids = range(self.config.num_hidden_layers)
+            else:
+                layer_ids = set()
+                for name in weight_names:
+                    if "kv_b_proj" in name:
+                        layer_id = int(name.split(".")[2])
+                        # filter the nextn layer.
+                        if layer_id != self.config.num_hidden_layers:
+                            layer_ids.add(layer_id)

        for layer_id in layer_ids:
            self_attn = (
                self.model.layers[layer_id].self_attn

@@ -1830,13 +1840,19 @@ def post_load_weights(self, is_nextn=False):
                0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
            ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
            if not use_deep_gemm_bmm:
-                self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
-                self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
+                self_attn.w_kc = bind_or_assign(
+                    self_attn.w_kc, w_kc.transpose(1, 2).contiguous().transpose(1, 2)
+                )
+                self_attn.w_vc = bind_or_assign(
+                    self_attn.w_vc, w_vc.contiguous().transpose(1, 2)
+                )
                if (
                    hasattr(self_attn.kv_b_proj, "weight_scale")
                    and self_attn.w_scale is None
                ):
-                    self_attn.w_scale = self_attn.kv_b_proj.weight_scale
+                    self_attn.w_scale = bind_or_assign(
+                        self_attn.w_scale, self_attn.kv_b_proj.weight_scale
+                    )
                    if _is_hip:
                        self_attn.w_scale *= 2.0
            else:

@@ -1845,10 +1861,16 @@
                ws_kc, ws_vc = block_scale.unflatten(
                    0, (-1, (num_tiles_k + num_tiles_n))
                ).split([num_tiles_k, num_tiles_n], dim=1)
-                self_attn.w_scale_k = ws_kc.transpose(1, 2).contiguous()
-                self_attn.w_scale_v = ws_vc.contiguous()
-                self_attn.w_kc = w_kc.transpose(1, 2).contiguous()
-                self_attn.w_vc = w_vc.contiguous()
+                self_attn.w_scale_k = bind_or_assign(
+                    self_attn.w_scale_k, ws_kc.transpose(1, 2).contiguous()
+                )
+                self_attn.w_scale_v = bind_or_assign(
+                    self_attn.w_scale_v, ws_vc.contiguous()
+                )
+                self_attn.w_kc = bind_or_assign(
+                    self_attn.w_kc, w_kc.transpose(1, 2).contiguous()
+                )
+                self_attn.w_vc = bind_or_assign(self_attn.w_vc, w_vc.contiguous())
                self_attn.use_deep_gemm_bmm = True

        # TODO support nextn later

@@ -1958,7 +1980,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False):
        ]

        params_dict = dict(self.named_parameters())
+        weight_names = []
        for name, loaded_weight in weights:
+            weight_names.append(name)
+
            if not is_nextn:
                if hasattr(self.config, "num_nextn_predict_layers"):
                    num_nextn_layers = self.config.num_nextn_predict_layers

@@ -2075,7 +2100,7 @@
                )
                weight_loader(param, loaded_weight)

-        self.post_load_weights(is_nextn=is_nextn)
+        self.post_load_weights(is_nextn=is_nextn, weight_names=weight_names)

    def get_embed_and_head(self):
        return self.model.embed_tokens.weight, self.lm_head.weight
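
For reference, here is a minimal standalone sketch of the layer-selection logic the diff adds: given the names of the weights that were just loaded, it collects the decoder-layer ids whose kv_b_proj was touched and skips the nextn (MTP) layer. The helper name derive_layer_ids and the example weight names are illustrative only, not part of this PR.

# Minimal sketch (not from the PR): mirror the weight_names filtering added to
# post_load_weights so only layers whose kv_b_proj was updated are post-processed.
def derive_layer_ids(weight_names, num_hidden_layers):
    layer_ids = set()
    for name in weight_names:
        # Weight names are assumed to look like
        # "model.layers.<layer_id>.self_attn.kv_b_proj.weight".
        if "kv_b_proj" in name:
            layer_id = int(name.split(".")[2])
            # Skip the nextn (MTP) layer; it is handled by the is_nextn path.
            if layer_id != num_hidden_layers:
                layer_ids.add(layer_id)
    return layer_ids


if __name__ == "__main__":
    names = [
        "model.layers.0.self_attn.kv_b_proj.weight",
        "model.layers.0.self_attn.q_a_proj.weight",
        "model.layers.61.self_attn.kv_b_proj.weight",  # hypothetical nextn layer id
    ]
    print(derive_layer_ids(names, num_hidden_layers=61))  # {0}
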
python/sglang/srt/utils.py (8 changes: 8 additions & 0 deletions)

@@ -2217,3 +2217,11 @@ def read_system_prompt_from_file(model_name: str) -> str:
    except Exception:
        # If anything fails, return empty string
        return ""
+
+
+def bind_or_assign(target, source):
+    if target is not None:
+        target.copy_(source)
+        return target
+    else:
+        return source
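
A small usage sketch of the new bind_or_assign helper, assuming PyTorch tensors (the toy tensors below stand in for w_kc/w_vc): on the first load the source is simply assigned, while on later weight updates the data is copied in place, so any reference already bound to the old tensor keeps seeing the refreshed values.

# Usage sketch (assumes PyTorch; not part of the PR).
import torch


def bind_or_assign(target, source):
    if target is not None:
        # Copy in place so existing references to `target` keep seeing fresh data.
        target.copy_(source)
        return target
    else:
        return source


w_kc = None
w_kc = bind_or_assign(w_kc, torch.ones(2, 2))          # first load: plain assignment
view = w_kc                                             # a reference held elsewhere
w_kc = bind_or_assign(w_kc, torch.full((2, 2), 3.0))    # weight update: in-place copy
assert view.data_ptr() == w_kc.data_ptr()               # same storage, updated values
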