[fix][RL] Fix DeepSeekV3ForCausalLM.post_load_weights for multiple update weight #6265
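In short (a summary of the diff below, not the author's own description): when weights are pushed into a running model more than once, e.g. during RL training, `post_load_weights` should only re-derive the per-layer `w_kc`/`w_vc`/`w_scale` tensors for the layers whose `kv_b_proj` weights actually changed, and it should copy new values into the existing tensors instead of rebinding the attributes to fresh ones. A minimal standalone sketch (not code from the repository) of the layer-id extraction the patch adds, assuming weight names follow the `model.layers.<id>.self_attn.kv_b_proj....` convention visible in the diff:

```python
from typing import Iterable, List


def layers_to_postprocess(weight_names: Iterable[str]) -> List[int]:
    """Collect the layer ids whose kv_b_proj weights were just loaded.

    Assumes names like "model.layers.<id>.self_attn.kv_b_proj.weight",
    so the layer id is the third dot-separated field.
    """
    layer_ids: List[int] = []
    for name in weight_names:
        if "kv_b_proj" in name:
            layer_id = int(name.split(".")[2])
            if layer_id not in layer_ids:
                layer_ids.append(layer_id)
    return layer_ids


# Example: two names for layer 0 and one for layer 3 -> [0, 3]
print(layers_to_postprocess([
    "model.layers.0.self_attn.kv_b_proj.weight",
    "model.layers.0.self_attn.kv_b_proj.weight_scale_inv",
    "model.layers.3.self_attn.kv_b_proj.weight",
]))
```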
Changes from 1 commit
@@ -1533,14 +1533,23 @@ def forward(
             input_ids, hidden_states, self.lm_head, forward_batch
         )
 
-    def post_load_weights(self, is_nextn=False):
+    def post_load_weights(self, is_nextn=False, weight_names=None):
         # Perform post-processing after loading weights
-        layer_ids = (
-            range(self.config.num_hidden_layers)
-            if not is_nextn
-            else [self.config.num_hidden_layers]
-        )
+        if weight_names is None:
+            layer_ids = (
+                range(self.config.num_hidden_layers)
+                if not is_nextn
+                else [self.config.num_hidden_layers]
+            )
+        else:
+            layer_ids = []
+            for name in weight_names:
+                if "kv_b_proj" in name:
+                    layer_id = int(name.split(".")[2])
+                    if layer_id not in layer_ids:
+                        layer_ids.append(layer_id)
+
         for layer_id in layer_ids:
             self_attn = (
                 self.model.layers[layer_id].self_attn
@@ -1640,13 +1649,22 @@ def post_load_weights(self, is_nextn=False):
                 0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
             ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
             if not use_deep_gemm_bmm:
-                self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
-                self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
+                if self_attn.w_kc is None:
+                    self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
+                    self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
+                else:
+                    self_attn.w_kc.copy_(
+                        w_kc.transpose(1, 2).contiguous().transpose(1, 2)
+                    )
+                    self_attn.w_vc.copy_(w_vc.contiguous().transpose(1, 2))
                 if (
                     hasattr(self_attn.kv_b_proj, "weight_scale")
                     and self_attn.w_scale is None
                 ):
-                    self_attn.w_scale = self_attn.kv_b_proj.weight_scale
+                    if self_attn.w_scale is None:
+                        self_attn.w_scale = self_attn.kv_b_proj.weight_scale
+                    else:
+                        self_attn.w_scale.copy_(self_attn.kv_b_proj.weight_scale)
                     if _is_hip:
                         self_attn.w_scale *= 2.0
             else:
@@ -1655,10 +1673,16 @@ def post_load_weights(self, is_nextn=False):
                 ws_kc, ws_vc = block_scale.unflatten(
                     0, (-1, (num_tiles_k + num_tiles_n))
                 ).split([num_tiles_k, num_tiles_n], dim=1)
-                self_attn.w_scale_k = ws_kc.transpose(1, 2).contiguous()
-                self_attn.w_scale_v = ws_vc.contiguous()
-                self_attn.w_kc = w_kc.transpose(1, 2).contiguous()
-                self_attn.w_vc = w_vc.contiguous()
+                if self_attn.w_kc is None:
+                    self_attn.w_scale_k = ws_kc.transpose(1, 2).contiguous()
+                    self_attn.w_scale_v = ws_vc.contiguous()
+                    self_attn.w_kc = w_kc.transpose(1, 2).contiguous()
+                    self_attn.w_vc = w_vc.contiguous()
+                else:
+                    self_attn.w_scale_k.copy_(ws_kc.transpose(1, 2).contiguous())
+                    self_attn.w_scale_v.copy_(ws_vc.contiguous())
+                    self_attn.w_kc.copy_(w_kc.transpose(1, 2).contiguous())
+                    self_attn.w_vc.copy_(w_vc.contiguous())
                 self_attn.use_deep_gemm_bmm = True
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False):
@@ -1765,7 +1789,10 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False):
         ]
 
         params_dict = dict(self.named_parameters())
+        weight_names = []
         for name, loaded_weight in weights:
+            weight_names.append(name)
+
             if not is_nextn:
                 if hasattr(self.config, "num_nextn_predict_layers"):
                     num_nextn_layers = self.config.num_nextn_predict_layers
@@ -1883,7 +1910,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False):
                 )
                 weight_loader(param, loaded_weight)
 
-        self.post_load_weights(is_nextn=is_nextn)
+        self.post_load_weights(is_nextn=is_nextn, weight_names=weight_names)
 
     def get_embed_and_head(self):
         return self.model.embed_tokens.weight, self.lm_head.weight
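One detail worth calling out: the new else-branches use `Tensor.copy_` rather than reassigning the attribute. A toy illustration of the difference (not code from the repository; the assumption that other components may still hold references to the original tensors is mine):

```python
import torch


class Attn:
    """Stand-in for the attention module; only the attribute matters here."""

    def __init__(self) -> None:
        self.w_kc = torch.zeros(2, 2)


attn = Attn()
alias = attn.w_kc                       # e.g. a reference captured elsewhere

# Rebinding the attribute: the alias still points at the stale tensor.
attn.w_kc = torch.ones(2, 2)
print(torch.equal(alias, attn.w_kc))    # False

# Copying in place (what the else-branches do): the shared storage is updated,
# so every holder of the tensor sees the new values.
attn = Attn()
alias = attn.w_kc
attn.w_kc.copy_(torch.ones(2, 2))
print(torch.equal(alias, attn.w_kc))    # True
```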
Review comment: `layer_ids = set()`?

Reply: fixed.
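For reference, a minimal sketch of the reviewer's suggestion, assuming the intent is simply to make deduplication implicit (note that, unlike a list, a `set` does not preserve insertion order):

```python
def layers_to_postprocess(weight_names):
    """Set-based variant of the layer-id collection from the diff above."""
    layer_ids = set()
    for name in weight_names:
        if "kv_b_proj" in name:
            layer_ids.add(int(name.split(".")[2]))
    return layer_ids


print(layers_to_postprocess(["model.layers.5.self_attn.kv_b_proj.weight"]))  # {5}
```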