From 232bed8a41379a6b76cae8b5fa5e2486f866e65b Mon Sep 17 00:00:00 2001
From: Louis Kirsch
Date: Tue, 2 May 2017 21:53:05 +0200
Subject: [PATCH] Added option to export weights and biases as numpy arrays

---
Review notes (this section is ignored when the patch is applied):
- Line structure and indentation were reconstructed from a
  whitespace-collapsed copy of this patch; verify the context lines
  against the tree before applying.
- speecht/exporting.py was revised in review: export directories are
  created with parents=True/exist_ok=True, and nothing is restored when no
  export is requested. The blob id on its index line predates that revision.

 speecht-cli             | 17 +++++++++++++-
 speecht/exporting.py    | 51 +++++++++++++++++++++++++++++++++++++++++
 speecht/speech_model.py |  3 +++
 3 files changed, 70 insertions(+), 1 deletion(-)
 create mode 100644 speecht/exporting.py

diff --git a/speecht-cli b/speecht-cli
index a2661f3..3a068aa 100755
--- a/speecht-cli
+++ b/speecht-cli
@@ -32,6 +32,7 @@ class CLI:
     self._add_recording_parser()
     self._add_parameter_search_parser()
     self._add_preprocess_parser()
+    self._add_export_parser()
 
   def _create_base_parser(self):
     base_parser = argparse.ArgumentParser(add_help=False)
@@ -52,6 +53,14 @@ class CLI:
     base_parser.set_defaults(feature_type='power')
     return base_parser
 
+  def _add_export_parser(self):
+    export_parser = self.subparsers.add_parser('export', help='Export network details',
+                                               parents=[self.base_parser])
+    export_parser.add_argument('--weights', dest='export_weights_dir', type=str,
+                               help='Store the weights in numpy arrays')
+    export_parser.add_argument('--input-size', dest='input_size', type=int, default=128,
+                               help='The input size of each sample, depending on what preprocessing was used')
+
   def _add_training_parser(self):
     training_parser = self.subparsers.add_parser('train', help='Train the wav2letter weights.',
                                                  parents=[self.base_parser])
@@ -174,6 +183,11 @@ class CLI:
     import speecht.preprocessing
     return speecht.preprocessing.Preprocessing(flags)
 
+  @staticmethod
+  def _get_export_executor(flags):
+    import speecht.exporting
+    return speecht.exporting.Exporting(flags)
+
   @lazy
   def command_executor(self):
     return {
@@ -181,7 +195,8 @@ class CLI:
       'evaluate': self._get_evaluation_executor,
       'record': self._get_recording_executor,
       'search': self._get_search_executor,
-      'preprocess': self._get_preprocessing_executor
+      'preprocess': self._get_preprocessing_executor,
+      'export': self._get_export_executor
     }[self.parsed.command](self.parsed)
 
   def run(self):
diff --git a/speecht/exporting.py b/speecht/exporting.py
new file mode 100644
index 0000000..320e2b9
--- /dev/null
+++ b/speecht/exporting.py
@@ -0,0 +1,51 @@
+from pathlib import Path
+
+from speecht.speech_input import SingleInputLoader
+from speecht.speech_model import create_default_model
+
+import tensorflow as tf
+import numpy as np
+
+
+class Exporting:
+  """Executor of the `export` command: writes details of a trained network to disk."""
+
+  def __init__(self, flags):
+    # Parsed command line flags; read later by create_model() and run()
+    self.flags = flags
+
+  def create_model(self, sess: tf.Session):
+    """Build the default model and restore it from `flags.run_train_dir` into `sess`."""
+    input_loader = SingleInputLoader(self.flags.input_size)
+    model = create_default_model(self.flags, self.flags.input_size, input_loader)
+    model.restore(sess, self.flags.run_train_dir)
+    return model
+
+  def run(self):
+    """Export what the flags request, or report that nothing was requested."""
+    if not self.flags.export_weights_dir:
+      # Decided before the model is restored, so a no-op run stays cheap
+      print('Nothing to do.')
+      return
+
+    with tf.Session() as sess:
+      self.create_model(sess)
+      self._export_weights(sess, Path(self.flags.export_weights_dir))
+
+  @staticmethod
+  def _export_weights(sess: tf.Session, path: Path):
+    """Save each trainable variable as a numpy array below the directory `path`."""
+    # parents=True / exist_ok=True: the target may be nested or already exist
+    path.mkdir(parents=True, exist_ok=True)
+
+    variables = tf.trainable_variables()
+    values = sess.run(variables)
+
+    for variable, value in zip(variables, values):
+      # A variable name may contain '/' (one per scope), which maps to
+      # nested sub-directories; create every missing level, not just one
+      file_path = path / variable.name
+      file_path.parent.mkdir(parents=True, exist_ok=True)
+
+      # np.save appends the '.npy' extension to the file name itself
+      np.save(str(file_path), value)
diff --git a/speecht/speech_model.py b/speecht/speech_model.py
index f4691e6..bef11e6 100644
--- a/speecht/speech_model.py
+++ b/speecht/speech_model.py
@@ -307,6 +307,9 @@ def create_default_model(flags, input_size: int, speech_input: BaseInputLoader)
                            max_gradient_norm=flags.max_gradient_norm,
                            momentum=flags.momentum)
     model.add_decoding_ops()
+  elif flags.command == 'export':
+    model.add_training_ops()
+    model.add_decoding_ops()
   else:
     model.add_training_ops()
     model.add_decoding_ops(language_model=flags.language_model,