Merged

Commits (20)
44a270c  Fix DeepSeekV3ForCausalLM.post_load_weights for multiple update weight (zhuzilin, May 13, 2025)
3b6c997  use set instead of list (zhuzilin, May 14, 2025)
76ee056  Merge branch 'main' into feature/fix_deepseek_v2_loader (zhaochenyang20, May 17, 2025)
15751b6  Merge branch 'main' into feature/fix_deepseek_v2_loader (zhaochenyang20, May 20, 2025)
1d1bf24  Merge branch 'main' into feature/fix_deepseek_v2_loader (zhaochenyang20, May 20, 2025)
8a36ff6  Merge branch 'main' into feature/fix_deepseek_v2_loader (zhyncs, May 20, 2025)
af7a806  Merge branch 'main' into feature/fix_deepseek_v2_loader (zhyncs, May 21, 2025)
d6b2c62  Merge branch 'main' into feature/fix_deepseek_v2_loader (zhaochenyang20, May 21, 2025)
f3fa6de  add util function (zhuzilin, May 23, 2025)
7b6f462  fix bug with nextn layer (zhuzilin, May 23, 2025)
d8806a6  Merge branch 'main' into feature/fix_deepseek_v2_loader (zhaochenyang20, May 23, 2025)
32f2e8d  bugfix (zhuzilin, May 23, 2025)
1a26b13  Merge branch 'main' into feature/fix_deepseek_v2_loader (zhyncs, May 23, 2025)
00c9722  Merge branch 'main' into feature/fix_deepseek_v2_loader (zhyncs, May 23, 2025)
fa8236f  bugfix (zhuzilin, May 25, 2025)
9499f94  Merge branch 'main' into feature/fix_deepseek_v2_loader (zhaochenyang20, May 25, 2025)
194b671  Merge branch 'main' into feature/fix_deepseek_v2_loader (zhyncs, May 26, 2025)
d8c047f  bugfix (zhuzilin, May 27, 2025)
1d68017  Merge branch 'main' into feature/fix_deepseek_v2_loader (zhuzilin, May 27, 2025)
3bed57a  Merge branch 'main' into feature/fix_deepseek_v2_loader (zhaochenyang20, May 29, 2025)
55 changes: 41 additions & 14 deletions python/sglang/srt/models/deepseek_v2.py
@@ -1533,14 +1533,23 @@ def forward(
             input_ids, hidden_states, self.lm_head, forward_batch
         )

-    def post_load_weights(self, is_nextn=False):
+    def post_load_weights(self, is_nextn=False, weight_names=None):

         # Perform post-processing after loading weights
-        layer_ids = (
-            range(self.config.num_hidden_layers)
-            if not is_nextn
-            else [self.config.num_hidden_layers]
-        )
+        if weight_names is None:
+            layer_ids = (
+                range(self.config.num_hidden_layers)
+                if not is_nextn
+                else [self.config.num_hidden_layers]
+            )
+        else:
+            layer_ids = []
Edenzzzz (Contributor) commented on May 13, 2025:
layer_ids = set()?

zhuzilin (Collaborator, Author) replied:
fixed.
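For reference, the follow-up commit "use set instead of list" presumably lands on something like the sketch below; the variable names come from the diff above, and the example weight name is only an illustration, not taken from the PR:

    # Assumed shape of the set-based fix (not the exact committed code).
    layer_ids = set()
    for name in weight_names:
        if "kv_b_proj" in name:
            # e.g. "model.layers.3.self_attn.kv_b_proj.weight" -> layer id 3
            layer_ids.add(int(name.split(".")[2]))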

+            for name in weight_names:
+                if "kv_b_proj" in name:
+                    layer_id = int(name.split(".")[2])
+                    if layer_id not in layer_ids:
+                        layer_ids.append(layer_id)

         for layer_id in layer_ids:
             self_attn = (
                 self.model.layers[layer_id].self_attn
@@ -1640,13 +1649,22 @@ def post_load_weights(self, is_nextn=False):
                 0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
             ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
             if not use_deep_gemm_bmm:
-                self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
-                self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
+                if self_attn.w_kc is None:
+                    self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
+                    self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
+                else:
+                    self_attn.w_kc.copy_(
+                        w_kc.transpose(1, 2).contiguous().transpose(1, 2)
+                    )
+                    self_attn.w_vc.copy_(w_vc.contiguous().transpose(1, 2))
Contributor:
By corrupt do you mean the weight tensor is incorrect or the output?

zhuzilin (Author):
The output will be incorrect.

Contributor:
I was wondering if the weight is not assigned or deepgemm is not using the new weights. Perhaps check the w_kc.data_ptr() here if you want.

zhuzilin (Author):
I kind of suspect it's an issue with cuda graph, as cuda graph will remember the pointer of the tensors.

Contributor:
agreed
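To make the suspicion above concrete (purely illustrative Python, not sglang code): a captured CUDA graph replays against the tensor storage it recorded at capture time, so an in-place copy_ keeps feeding that storage, while rebinding the attribute allocates new storage the graph never sees. Here captured_ptr is a hypothetical stand-in for the pointer a captured graph would hold:

    import torch

    w_kc = torch.zeros(2, 2)
    captured_ptr = w_kc.data_ptr()  # what a captured graph would keep referring to

    w_kc.copy_(torch.ones(2, 2))    # in-place update: same storage, new values visible
    assert w_kc.data_ptr() == captured_ptr

    w_kc = torch.ones(2, 2)         # rebinding: new storage; the old pointer now serves stale data
    assert w_kc.data_ptr() != captured_ptr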

                 if (
                     hasattr(self_attn.kv_b_proj, "weight_scale")
                     and self_attn.w_scale is None
                 ):
-                    self_attn.w_scale = self_attn.kv_b_proj.weight_scale
+                    if self_attn.w_scale is None:
+                        self_attn.w_scale = self_attn.kv_b_proj.weight_scale
+                    else:
+                        self_attn.w_scale.copy_(self_attn.kv_b_proj.weight_scale)
Collaborator:
nit: handle these tensors with a new method in srt/utils.py:

    def bind_or_assign(target, source):
        if target is not None:
            target.copy_(source)
            return target
        else:
            return source

zhuzilin (Author):
fixed.
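Presumably the "add util function" commit applies the helper at the call sites above roughly as follows; this is a sketch of the suggestion, not the exact committed code:

    self_attn.w_kc = bind_or_assign(
        self_attn.w_kc, w_kc.transpose(1, 2).contiguous().transpose(1, 2)
    )
    self_attn.w_vc = bind_or_assign(self_attn.w_vc, w_vc.contiguous().transpose(1, 2))
    self_attn.w_scale = bind_or_assign(
        self_attn.w_scale, self_attn.kv_b_proj.weight_scale
    )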

                     if _is_hip:
                         self_attn.w_scale *= 2.0
             else:
@@ -1655,10 +1673,16 @@
                 ws_kc, ws_vc = block_scale.unflatten(
                     0, (-1, (num_tiles_k + num_tiles_n))
                 ).split([num_tiles_k, num_tiles_n], dim=1)
-                self_attn.w_scale_k = ws_kc.transpose(1, 2).contiguous()
-                self_attn.w_scale_v = ws_vc.contiguous()
-                self_attn.w_kc = w_kc.transpose(1, 2).contiguous()
-                self_attn.w_vc = w_vc.contiguous()
+                if self_attn.w_kc is None:
+                    self_attn.w_scale_k = ws_kc.transpose(1, 2).contiguous()
+                    self_attn.w_scale_v = ws_vc.contiguous()
+                    self_attn.w_kc = w_kc.transpose(1, 2).contiguous()
+                    self_attn.w_vc = w_vc.contiguous()
+                else:
+                    self_attn.w_scale_k.copy_(ws_kc.transpose(1, 2).contiguous())
+                    self_attn.w_scale_v.copy_(ws_vc.contiguous())
+                    self_attn.w_kc.copy_(w_kc.transpose(1, 2).contiguous())
+                    self_attn.w_vc.copy_(w_vc.contiguous())
                 self_attn.use_deep_gemm_bmm = True

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False):
@@ -1765,7 +1789,10 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False):
         ]

         params_dict = dict(self.named_parameters())
+        weight_names = []
         for name, loaded_weight in weights:
+            weight_names.append(name)
+
             if not is_nextn:
                 if hasattr(self.config, "num_nextn_predict_layers"):
                     num_nextn_layers = self.config.num_nextn_predict_layers
@@ -1883,7 +1910,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False):
                     )
                     weight_loader(param, loaded_weight)

-        self.post_load_weights(is_nextn=is_nextn)
+        self.post_load_weights(is_nextn=is_nextn, weight_names=weight_names)

     def get_embed_and_head(self):
         return self.model.embed_tokens.weight, self.lm_head.weight