-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathadaptor.py
23 lines (19 loc) · 928 Bytes
/
adaptor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import re

import tensorflow as tf
# Convert a TF 0.12 checkpoint of a MultiRNNCell/BasicLSTMCell model to the
# variable names TF 1.0 expects, e.g.
#   rnn/rnn/MultiRNNCell/Cell0/BasicLSTMCell/Linear/Matrix
#   -> rnn/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/weights
#   rnn/rnn/MultiRNNCell/Cell0/BasicLSTMCell/Linear/Bias
#   -> rnn/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/biases
#
# The old checkpoint is read directly, so no model graph (and no knowledge of
# the number of layers) is required.

# Matches a TF 0.12 LSTM parameter name, capturing the enclosing scope
# (e.g. "rnn/rnn/", possibly absent), the layer index and the parameter kind.
_OLD_LSTM_VAR = re.compile(
    r"^(?P<scope>.*/)?MultiRNNCell/Cell(?P<index>\d+)"
    r"/BasicLSTMCell/Linear/(?P<param>Matrix|Bias)$"
)

# TF 1.0 names of the two BasicLSTMCell parameters.
_NEW_PARAM_NAME = {"Matrix": "weights", "Bias": "biases"}


def rename_variable(name):
    """Return the TF 1.0 name for the TF 0.12 variable ``name``.

    Names that are not MultiRNNCell/BasicLSTMCell parameters (for example
    ``global_step`` or an output projection) are returned unchanged.

    >>> rename_variable("rnn/rnn/MultiRNNCell/Cell1/BasicLSTMCell/Linear/Bias")
    'rnn/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/biases'
    """
    match = _OLD_LSTM_VAR.match(name)
    if match is None:
        return name
    return "{}multi_rnn_cell/cell_{}/basic_lstm_cell/{}".format(
        match.group("scope") or "",
        match.group("index"),
        _NEW_PARAM_NAME[match.group("param")],
    )


def build_rename_map(var_names):
    """Return ``{old_name: new_name}`` for each name in ``var_names`` that changes."""
    renames = {}
    for old_name in var_names:
        new_name = rename_variable(old_name)
        if new_name != old_name:
            renames[old_name] = new_name
    return renames


def convert_checkpoint(old_path="old.ckpt", new_path="new.ckpt"):
    """Copy the checkpoint at ``old_path`` to ``new_path`` using TF 1.0 RNN names.

    Every tensor in the old checkpoint is copied; only the LSTM weights and
    biases are renamed, all other variables keep their names.

    Returns:
        The ``{old_name: new_name}`` mapping that was applied.

    Raises:
        ValueError: if ``old_path`` contains no TF 0.12 LSTM variables
            (wrong file, or a checkpoint that was already converted).
    """
    reader = tf.train.NewCheckpointReader(old_path)
    old_names = sorted(reader.get_variable_to_shape_map())
    renames = build_rename_map(old_names)
    if not renames:
        raise ValueError(
            "no TF 0.12 MultiRNNCell/BasicLSTMCell variables found in "
            + old_path)

    # Recreate every stored tensor as a variable under its new name in a
    # fresh graph, so the caller's default graph is left untouched.
    with tf.Graph().as_default():
        new_vars = [
            tf.Variable(reader.get_tensor(name), name=renames.get(name, name))
            for name in old_names
        ]
        saver = tf.train.Saver(var_list=new_vars)
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            saver.save(sess, new_path)
    return renames


if __name__ == "__main__":
    convert_checkpoint("old.ckpt", "new.ckpt")