Fix speech recognition example (apache#12291)
vandanavk authored and anirudh2290 committed Sep 19, 2018
1 parent f6aa533 commit 77e173f
Showing 6 changed files with 29 additions and 38 deletions.
4 changes: 2 additions & 2 deletions example/speech_recognition/README.md
@@ -19,9 +19,9 @@ With rich functionalities and convenience explained above, you can build your ow
 ## **Environments**
 - MXNet version: 0.9.5+
 - GPU memory size: 2.4GB+
-- Install tensorboard for logging
+- Install mxboard for logging
 <pre>
-<code>pip install tensorboard</code>
+<code>pip install mxboard</code>
 </pre>
 
 - [SoundFile](https://pypi.python.org/pypi/SoundFile/0.8.1) for audio preprocessing (If encounter errors about libsndfile, follow [this tutorial](http://www.linuxfromscratch.org/blfs/view/svn/multimedia/libsndfile.html).)
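
As a quick check of the new dependency, a minimal mxboard smoke test (a sketch, assuming `pip install mxboard` succeeded; the `./mxlog` directory name and `dummy_loss` tag are arbitrary):

# Write a single scalar event in TensorBoard format to ./mxlog.
from mxboard import SummaryWriter

with SummaryWriter(logdir='./mxlog') as sw:
    sw.add_scalar(tag='dummy_loss', value=1.0, global_step=0)
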
5 changes: 3 additions & 2 deletions example/speech_recognition/deepspeech.cfg
@@ -26,14 +26,15 @@ prefix = deep_bucket
 # when mode is load or predict, model will be loaded from the file name with model_file under checkpoints
 model_file = deep_bucketn_epoch0n_batch-0018
 batch_size = 12
-#batch_size=4
+#use batch_size 4 with single GPU
+#batch_size = 4
 # log will be saved by the log_filename
 log_filename = deep_bucket.log
 # checkpoint set n to save checkpoints after n epoch
 save_checkpoint_every_n_epoch = 1
 save_checkpoint_every_n_batch = 3000
 is_bi_graphemes = True
-tensorboard_log_dir = tblog/deep_bucket
+mxboard_log_dir = mxlog/deep_bucket
 # if random_seed is -1 then it gets random seed from timestamp
 mx_random_seed = -1
 random_seed = -1
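
train.py reads the renamed key via `args.config.get('common', 'mxboard_log_dir')` (see its diff further down). A minimal sketch of the same lookup with Python's standard configparser, assuming the file name and `[common]` section implied by this commit:

# Read the mxboard log directory from deepspeech.cfg.
import configparser

config = configparser.ConfigParser()
config.read('deepspeech.cfg')
mxlog_dir = config.get('common', 'mxboard_log_dir')  # 'mxlog/deep_bucket'
print(mxlog_dir)
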
2 changes: 1 addition & 1 deletion example/speech_recognition/default.cfg
@@ -31,7 +31,7 @@ log_filename = test.log
 save_checkpoint_every_n_epoch = 20
 save_checkpoint_every_n_batch = 1000
 is_bi_graphemes = False
-tensorboard_log_dir = tblog/libri_sample
+mxboard_log_dir = mxlog/libri_sample
 # if random_seed is -1 then it gets random seed from timestamp
 mx_random_seed = 1234
 random_seed = 1234
26 changes: 6 additions & 20 deletions example/speech_recognition/singleton.py
@@ -19,9 +19,9 @@
 import logging as log
 
 class Singleton:
-    def __init__(self, decrated):
-        log.debug("Singleton Init %s" % decrated)
-        self._decorated = decrated
+    def __init__(self, decorated):
+        log.debug("Singleton Init %s" % decorated)
+        self._decorated = decorated
 
     def getInstance(self):
         try:
@@ -30,25 +30,11 @@ def getInstance(self):
             self._instance = self._decorated()
         return self._instance
 
-    def __new__(class_, *args, **kwargs):
+    def __new__(cls, *args, **kwargs):
         print("__new__")
-        class_.instances[class_] = super(Singleton, class_).__new__(class_, *args, **kwargs)
-        return class_.instances[class_]
+        cls._instance = super(Singleton, cls).__new__(cls, *args, **kwargs)
+        return cls._instance
 
     def __call__(self):
         raise TypeError("Singletons must be accessed through 'getInstance()'")
-
-
-class SingletonInstane:
-    __instance = None
-
-    @classmethod
-    def __getInstance(cls):
-        return cls.__instance
-
-    @classmethod
-    def instance(cls, *args, **kargs):
-        cls.__instance = cls(*args, **kargs)
-        cls.instance = cls.__getInstance
-        return cls.__instance

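After this change, a decorated class is accessed as `MyClass.getInstance()`, and calling `MyClass()` directly raises the TypeError above. For comparison, a generic function-based singleton decorator that gives the same one-instance-per-class behavior (a standalone sketch, not the repo's implementation; `AppConfig` is a hypothetical example class):

# Generic singleton sketch: cache one instance per decorated class.
def singleton(cls):
    instances = {}
    def get_instance(*args, **kwargs):
        if cls not in instances:
            instances[cls] = cls(*args, **kwargs)
        return instances[cls]
    return get_instance

@singleton
class AppConfig:  # hypothetical example class
    def __init__(self):
        self.mode = 'train'

assert AppConfig() is AppConfig()  # both calls return the cached instance
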
9 changes: 7 additions & 2 deletions example/speech_recognition/stt_metric.py
@@ -47,6 +47,7 @@ def __init__(self, batch_size, num_gpu, is_epoch_end=False, is_logging=True):
         self.total_ctc_loss = 0.
         self.batch_loss = 0.
         self.is_logging = is_logging
+
     def update(self, labels, preds):
         check_label_shapes(labels, preds)
         if self.is_logging:
@@ -83,10 +84,15 @@ def update(self, labels, preds):
         if self.is_logging:
             log.info("loss: %f " % loss)
         self.total_ctc_loss += self.batch_loss
+
     def get_batch_loss(self):
         return self.batch_loss
+
     def get_name_value(self):
-        total_cer = float(self.total_l_dist) / float(self.total_n_label)
+        try:
+            total_cer = float(self.total_l_dist) / float(self.total_n_label)
+        except ZeroDivisionError:
+            total_cer = float('inf')
+
         return total_cer, self.total_n_label, self.total_l_dist, self.total_ctc_loss
@@ -244,4 +250,3 @@ def char_match_2way(label, pred):
     val = val1_max if val1_max > val2_max else val2_max
     val_matched = val1_max_matched if val1_max > val2_max else val2_max_matched
     return val, val_matched, n_whole_label
-
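The new try/except keeps `get_name_value()` from crashing when no label characters have been counted yet (`total_n_label == 0`). The same guard in isolation, as a standalone sketch:

# CER = total edit distance / total label length; a zero label
# count maps to infinity instead of raising ZeroDivisionError.
def safe_cer(total_l_dist, total_n_label):
    try:
        return float(total_l_dist) / float(total_n_label)
    except ZeroDivisionError:
        return float('inf')

print(safe_cer(12, 100))  # 0.12
print(safe_cer(12, 0))    # inf
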
21 changes: 10 additions & 11 deletions example/speech_recognition/train.py
@@ -16,15 +16,14 @@
 # under the License.
 
 import sys
-
+import json
 sys.path.insert(0, "../../python")
 import os.path
+#mxboard setting
+from mxboard import SummaryWriter
 import mxnet as mx
 from config_util import get_checkpoint_path, parse_contexts
 from stt_metric import STTMetric
-#tensorboard setting
-from tensorboard import SummaryWriter
-import json
 from stt_bucketing_module import STTBucketingModule


@@ -65,7 +64,7 @@ def do_training(args, module, data_train, data_val, begin_epoch=0):
     contexts = parse_contexts(args)
     num_gpu = len(contexts)
     eval_metric = STTMetric(batch_size=batch_size, num_gpu=num_gpu, is_logging=enable_logging_validation_metric,is_epoch_end=True)
-    # tensorboard setting
+    # mxboard setting
     loss_metric = STTMetric(batch_size=batch_size, num_gpu=num_gpu, is_logging=enable_logging_train_metric,is_epoch_end=False)
 
     optimizer = args.config.get('optimizer', 'optimizer')
@@ -131,9 +130,9 @@ def reset_optimizer(force_init=False):
         data_train.reset()
         data_train.is_first_epoch = True
 
-    #tensorboard setting
-    tblog_dir = args.config.get('common', 'tensorboard_log_dir')
-    summary_writer = SummaryWriter(tblog_dir)
+    #mxboard setting
+    mxlog_dir = args.config.get('common', 'mxboard_log_dir')
+    summary_writer = SummaryWriter(mxlog_dir)
 
     while True:
 
@@ -144,7 +143,7 @@ def reset_optimizer(force_init=False):
         for nbatch, data_batch in enumerate(data_train):
             module.forward_backward(data_batch)
             module.update()
-            # tensorboard setting
+            # mxboard setting
             if (nbatch + 1) % show_every == 0:
                 module.update_metric(loss_metric, data_batch.label)
                 #summary_writer.add_scalar('loss batch', loss_metric.get_batch_loss(), nbatch)
Expand All @@ -160,7 +159,7 @@ def reset_optimizer(force_init=False):
module.forward(data_batch, is_train=True)
module.update_metric(eval_metric, data_batch.label)

# tensorboard setting
# mxboard setting
val_cer, val_n_label, val_l_dist, _ = eval_metric.get_name_value()
log.info("Epoch[%d] val cer=%f (%d / %d)", n_epoch, val_cer, int(val_n_label - val_l_dist), val_n_label)
curr_acc = val_cer
@@ -170,7 +169,7 @@ def reset_optimizer(force_init=False):
             data_train.reset()
             data_train.is_first_epoch = False
 
-            # tensorboard setting
+            # mxboard setting
             train_cer, train_n_label, train_l_dist, train_ctc_loss = loss_metric.get_name_value()
             summary_writer.add_scalar('loss epoch', train_ctc_loss, n_epoch)
             summary_writer.add_scalar('CER train', train_cer, n_epoch)
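
The loop above writes one scalar per metric per epoch to the configured directory. A condensed sketch of that logging pattern (the loss/CER values here are dummies; in train.py they come from `STTMetric.get_name_value()`):

from mxboard import SummaryWriter

summary_writer = SummaryWriter('mxlog/deep_bucket')
for n_epoch in range(3):
    train_ctc_loss = 100.0 / (n_epoch + 1)  # dummy value
    train_cer = 0.5 / (n_epoch + 1)         # dummy value
    summary_writer.add_scalar('loss epoch', train_ctc_loss, n_epoch)
    summary_writer.add_scalar('CER train', train_cer, n_epoch)
summary_writer.close()

Since mxboard writes TensorBoard-format event files, the curves can then be viewed with `tensorboard --logdir=mxlog/deep_bucket`.
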
