|
1 | 1 | #!/usr/bin/env python |
2 | | -# coding=utf-8 |
3 | | -# Copyright 2017 The Tensor2Tensor Authors. |
4 | | -# |
5 | | -# Licensed under the Apache License, Version 2.0 (the "License"); |
6 | | -# you may not use this file except in compliance with the License. |
7 | | -# You may obtain a copy of the License at |
8 | | -# |
9 | | -# http://www.apache.org/licenses/LICENSE-2.0 |
10 | | -# |
11 | | -# Unless required by applicable law or agreed to in writing, software |
12 | | -# distributed under the License is distributed on an "AS IS" BASIS, |
13 | | -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
14 | | -# See the License for the specific language governing permissions and |
15 | | -# limitations under the License. |
16 | | - |
17 | | -"""Script to continuously average last N checkpoints in a given directory.""" |
| 2 | +"""t2t-avg-all.""" |
18 | 3 | from __future__ import absolute_import |
19 | 4 | from __future__ import division |
20 | 5 | from __future__ import print_function |
21 | 6 |
|
22 | | -import os |
23 | | -import logging |
24 | | - |
25 | | -# Dependency imports |
| 7 | +from tensor2tensor.bin import t2t_avg_all |
26 | 8 |
|
27 | | -import numpy as np |
28 | | -import six |
29 | | -from six.moves import zip # pylint: disable=redefined-builtin |
30 | | -from collections import deque |
31 | | -import shutil |
32 | 9 | import tensorflow as tf |
33 | | -from tensor2tensor.utils import bleu_hook |
34 | | - |
35 | | -flags = tf.flags |
36 | | -FLAGS = flags.FLAGS |
37 | | - |
38 | | -flags.DEFINE_string("model_dir", "", "Directory to load model checkpoints from.") |
39 | | -flags.DEFINE_string("output_dir", "avg/", "Directory to output the averaged checkpoints to.") |
40 | | -flags.DEFINE_integer("n", 8, "How many checkpoints should be averaged?") |
41 | | -flags.DEFINE_integer("min_steps", 0, "Ignore checkpoints with less steps.") |
42 | | -flags.DEFINE_integer("wait_minutes", 0, "Wait upto N minutes for a new checkpoint.") |
43 | | - |
44 | | - |
45 | | -def main(_): |
46 | | - tf.logging._handler.setFormatter(logging.Formatter("%(asctime)s:" + logging.BASIC_FORMAT, None)) |
47 | | - tf.logging.set_verbosity(tf.logging.INFO) |
48 | | - |
49 | | - model_dir = os.path.expanduser(FLAGS.model_dir) |
50 | | - output_dir = os.path.expanduser(FLAGS.output_dir) |
51 | | - out_base_file = os.path.join(output_dir, 'model.ckpt') |
52 | | - |
53 | | - # Copy flags.txt with the original time, so t2t-bleu can report correct relative time. |
54 | | - os.makedirs(FLAGS.output_dir, exist_ok=True) |
55 | | - if not os.path.exists(os.path.join(output_dir, 'flags.txt')): |
56 | | - shutil.copy2(os.path.join(model_dir, 'flags.txt'), os.path.join(output_dir, 'flags.txt')) |
57 | | - |
58 | | - models_processed = 0 |
59 | | - queue = deque() |
60 | | - for model in bleu_hook.stepfiles_iterator(model_dir, FLAGS.wait_minutes, FLAGS.min_steps): |
61 | | - if models_processed == 0: |
62 | | - var_list = tf.contrib.framework.list_variables(model.filename) |
63 | | - avg_values = {} |
64 | | - for (name, shape) in var_list: |
65 | | - if not name.startswith("global_step"): |
66 | | - avg_values[name] = np.zeros(shape) |
67 | | - models_processed += 1 |
68 | | - |
69 | | - tf.logging.info("Loading [%d]: %s" % (models_processed, model.filename)) |
70 | | - reader = tf.contrib.framework.load_checkpoint(model.filename) |
71 | | - for name in avg_values: |
72 | | - avg_values[name] += reader.get_tensor(name) / FLAGS.n |
73 | | - queue.append(model) |
74 | | - if len(queue) < FLAGS.n: |
75 | | - continue |
76 | | - |
77 | | - out_file = "%s-%d" % (out_base_file, model.steps) |
78 | | - tf_vars = [] |
79 | | - tf.logging.info("Averaging %s" % (out_file)) |
80 | | - for (name, value) in six.iteritems(avg_values): |
81 | | - tf_vars.append(tf.get_variable(name, shape=value.shape)) # TODO: also pass dtype=var_dtypes[name] |
82 | | - placeholders = [tf.placeholder(v.dtype, shape=v.shape) for v in tf_vars] |
83 | | - assign_ops = [tf.assign(v, p) for (v, p) in zip(tf_vars, placeholders)] |
84 | | - |
85 | | - global_step = tf.Variable(model.steps, name="global_step", trainable=False, dtype=tf.int64) |
86 | | - saver = tf.train.Saver(tf.global_variables()) |
87 | | - |
88 | | - tf.logging.info("Running session for %s" % (out_file)) |
89 | | - with tf.Session() as sess: |
90 | | - sess.run(tf.global_variables_initializer()) |
91 | | - for p, assign_op, (name, value) in zip(placeholders, assign_ops, six.iteritems(avg_values)): |
92 | | - sess.run(assign_op, {p: value}) |
93 | | - tf.logging.info("Storing to %s" % out_file) |
94 | | - saver.save(sess, out_base_file, global_step=global_step) |
95 | | - os.utime(out_file + '.index', (model.mtime, model.mtime)) |
96 | | - |
97 | | - tf.reset_default_graph() |
98 | | - first_model = queue.popleft() |
99 | 10 |
|
100 | | - reader = tf.contrib.framework.load_checkpoint(first_model.filename) |
101 | | - for name in avg_values: |
102 | | - avg_values[name] -= reader.get_tensor(name) / FLAGS.n |
| 11 | +def main(argv): |
| 12 | + t2t_avg_all.main(argv) |
103 | 13 |
|
104 | 14 |
|
105 | 15 | if __name__ == "__main__": |
|
0 commit comments