-
Notifications
You must be signed in to change notification settings - Fork 179
/
checkpoint.py
66 lines (55 loc) · 2.73 KB
/
checkpoint.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
# Copyright 2018 Dong-Hyun Lee, Kakao Brain.
""" Load a checkpoint file of pretrained transformer to a model in pytorch """
import numpy as np
import tensorflow as tf
import torch
#import ipdb
#from models import *
def load_param(checkpoint_file, conversion_table):
    """
    Load parameters in pytorch model from checkpoint file according to conversion_table

    checkpoint_file : pretrained checkpoint model file in tensorflow
    conversion_table : { pytorch tensor in a model : checkpoint variable name }

    Each checkpoint variable is read with tf.train.load_variable. Variables whose
    name ends with 'kernel' are transposed before use. The value is then copied
    into the existing pytorch tensor, so the tensor keeps its own device, dtype
    and memory layout.

    Raises ValueError if a (possibly transposed) checkpoint variable does not
    have exactly the shape of the pytorch tensor it is mapped to.
    """
    for pyt_param, tf_param_name in conversion_table.items():
        tf_param = tf.train.load_variable(checkpoint_file, tf_param_name)

        # for weight(kernel), we should do transpose
        # (TF dense kernels are presumably stored transposed relative to
        # torch Linear weights; the shape check below catches any mismatch)
        if tf_param_name.endswith('kernel'):
            tf_param = np.transpose(tf_param)

        # Explicit check instead of `assert`: it must survive `python -O`, and
        # copy_() below would otherwise silently broadcast a smaller tensor.
        if tuple(pyt_param.size()) != tuple(tf_param.shape):
            raise ValueError('Dim Mismatch: %s vs %s ; %s' %
                             (tuple(pyt_param.size()), tf_param.shape, tf_param_name))

        # Copy the values in place rather than rebinding `.data`: rebinding would
        # swap in a CPU tensor backed by the numpy array (non-contiguous after
        # the transpose) with the checkpoint's dtype, regardless of where the
        # model lives. no_grad() allows in-place writes to leaf Parameters.
        with torch.no_grad():
            pyt_param.copy_(torch.from_numpy(tf_param))
def load_model(model, checkpoint_file):
    """ Load the pytorch model from checkpoint file

    model : pytorch transformer exposing, as accessed below,
        - model.embed with tok_embed, pos_embed, seg_embed (each with .weight)
          and norm (with .gamma / .beta)
        - model.blocks, an iterable of layers each with attn.proj_q/proj_k/proj_v,
          proj, pwff.fc1, pwff.fc2 (each with .weight / .bias) and
          norm1 / norm2 (each with .gamma / .beta)
    checkpoint_file : pretrained checkpoint model file in tensorflow

    The embedding tensors are read from the 'bert/embeddings/' scope. The i-th
    block (in iteration order) is read from 'bert/encoder/layer_<i>/'. Only
    these embedding and encoder-layer variables are mapped; nothing else in
    the checkpoint is read here.
    """
    # Embedding layer
    embed, scope = model.embed, 'bert/embeddings/'
    load_param(checkpoint_file, {
        embed.tok_embed.weight: scope + "word_embeddings",
        embed.pos_embed.weight: scope + "position_embeddings",
        embed.seg_embed.weight: scope + "token_type_embeddings",
        embed.norm.gamma:       scope + "LayerNorm/gamma",
        embed.norm.beta:        scope + "LayerNorm/beta",
    })

    # Transformer blocks: the i-th block maps to checkpoint scope layer_<i>
    for i, block in enumerate(model.blocks):
        scope = "bert/encoder/layer_%d/" % i
        load_param(checkpoint_file, {
            # self-attention projections
            block.attn.proj_q.weight: scope + "attention/self/query/kernel",
            block.attn.proj_q.bias:   scope + "attention/self/query/bias",
            block.attn.proj_k.weight: scope + "attention/self/key/kernel",
            block.attn.proj_k.bias:   scope + "attention/self/key/bias",
            block.attn.proj_v.weight: scope + "attention/self/value/kernel",
            block.attn.proj_v.bias:   scope + "attention/self/value/bias",
            # attention output projection
            block.proj.weight:        scope + "attention/output/dense/kernel",
            block.proj.bias:          scope + "attention/output/dense/bias",
            # position-wise feed-forward
            block.pwff.fc1.weight:    scope + "intermediate/dense/kernel",
            block.pwff.fc1.bias:      scope + "intermediate/dense/bias",
            block.pwff.fc2.weight:    scope + "output/dense/kernel",
            block.pwff.fc2.bias:      scope + "output/dense/bias",
            # layer norms (after attention, after feed-forward)
            block.norm1.gamma:        scope + "attention/output/LayerNorm/gamma",
            block.norm1.beta:         scope + "attention/output/LayerNorm/beta",
            block.norm2.gamma:        scope + "output/LayerNorm/gamma",
            block.norm2.beta:         scope + "output/LayerNorm/beta",
        })