Skip to content

Commit

Permalink
Added option to export weights and biases as numpy arrays
Browse files Browse the repository at this point in the history
  • Loading branch information
louiskirsch committed May 2, 2017
1 parent e9b9d80 commit 232bed8
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 1 deletion.
17 changes: 16 additions & 1 deletion speecht-cli
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class CLI:
self._add_recording_parser()
self._add_parameter_search_parser()
self._add_preprocess_parser()
self._add_export_parser()

def _create_base_parser(self):
base_parser = argparse.ArgumentParser(add_help=False)
Expand All @@ -52,6 +53,14 @@ class CLI:
base_parser.set_defaults(feature_type='power')
return base_parser

def _add_export_parser(self):
export_parser = self.subparsers.add_parser('export', help='Export network details',
parents=[self.base_parser])
export_parser.add_argument('--weights', dest='export_weights_dir', type=str,
help='Store the weights in numpy arrays')
export_parser.add_argument('--input-size', dest='input_size', type=int, default=128,
help='The input size of each sample, depending on what preprocessing was used')

def _add_training_parser(self):
training_parser = self.subparsers.add_parser('train', help='Train the wav2letter weights.',
parents=[self.base_parser])
Expand Down Expand Up @@ -174,14 +183,20 @@ class CLI:
import speecht.preprocessing
return speecht.preprocessing.Preprocessing(flags)

@staticmethod
def _get_export_executor(flags):
import speecht.exporting
return speecht.exporting.Exporting(flags)

@lazy
def command_executor(self):
return {
'train': self._get_training_executor,
'evaluate': self._get_evaluation_executor,
'record': self._get_recording_executor,
'search': self._get_search_executor,
'preprocess': self._get_preprocessing_executor
'preprocess': self._get_preprocessing_executor,
'export': self._get_export_executor
}[self.parsed.command](self.parsed)

def run(self):
Expand Down
44 changes: 44 additions & 0 deletions speecht/exporting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from pathlib import Path

from speecht.speech_input import SingleInputLoader
from speecht.speech_model import create_default_model

import tensorflow as tf
import numpy as np


class Exporting:

def __init__(self, flags):
self.flags = flags

def create_model(self, sess: tf.Session):
input_loader = SingleInputLoader(self.flags.input_size)
model = create_default_model(self.flags, self.flags.input_size, input_loader)
model.restore(sess, self.flags.run_train_dir)
return model

def run(self):
with tf.Session() as sess:
self.create_model(sess)

if self.flags.export_weights_dir:
path = Path(self.flags.export_weights_dir)
if not path.exists():
path.mkdir()

variables = tf.trainable_variables()
values = sess.run(variables)

for variable, value in zip(variables, values):
file_path = path / variable.name
parent_dir = Path(file_path.parent)
if not parent_dir.exists():
parent_dir.mkdir()

# noinspection PyTypeChecker
np.save(file_path, value)

return

print('Nothing to do.')
3 changes: 3 additions & 0 deletions speecht/speech_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,9 @@ def create_default_model(flags, input_size: int, speech_input: BaseInputLoader)
max_gradient_norm=flags.max_gradient_norm,
momentum=flags.momentum)
model.add_decoding_ops()
elif flags.command == 'export':
model.add_training_ops()
model.add_decoding_ops()
else:
model.add_training_ops()
model.add_decoding_ops(language_model=flags.language_model,
Expand Down

0 comments on commit 232bed8

Please sign in to comment.