Skip to content

Commit ca2f4d0

Browse files
mikekgfbmalfet
authored andcommitted
add init file, as per Jez (pytorch#201)
* add init file, as per Jez * split fused weights
1 parent f17d2dd commit ca2f4d0

File tree

2 files changed

+12
-0
lines changed

2 files changed

+12
-0
lines changed

build/__init__.py

Whitespace-only changes.

build/model.py

+12
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,18 @@ def load_hook(self, state_dict, prefix, *args):
285285
# wv = state_dict.pop(prefix + "wv.weight")
286286
# state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])
287287

288+
if prefix + "wqkv.weight" in state_dict:
289+
wqkv = state_dict.pop(prefix + "wqkv.weight")
290+
q_size = self.n_heads * self.head_dim
291+
kv_size = self.n_local_heads * self.head_dim
292+
wq, wk, wv = torch.split(wqkv, (q_size, kv_size, kv_size), dim=0)
293+
state_dict[prefix + "wq.weight"] = wq
294+
state_dict[prefix + "wk.weight"] = wk
295+
state_dict[prefix + "wv.weight"] = wv
296+
297+
return
298+
299+
288300
def _unfuse_wqkv_state_dict(
289301
state_dict: Dict[str, torch.Tensor],
290302
dim: int,

0 commit comments

Comments
 (0)