ADD load function for bucketing module
Soonhwan-Kwon committed Jul 6, 2017
1 parent 47b55a0 commit 58e9707
Showing 3 changed files with 38 additions and 25 deletions.
17 changes: 7 additions & 10 deletions arch_deepspeech.py
@@ -69,7 +69,8 @@ def arch(args, seq_len=None):
    """
    if isinstance(args, argparse.Namespace):
        mode = args.config.get("common", "mode")
-        if mode == "train":
+        is_bucketing = args.config.getboolean("arch", "is_bucketing")
+        if mode == "train" or is_bucketing:
            channel_num = args.config.getint("arch", "channel_num")
            conv_layer1_filter_dim = \
                tuple(json.loads(args.config.get("arch", "conv_layer1_filter_dim")))
@@ -83,7 +84,6 @@ def arch(args, seq_len=None):
            num_hidden_rnn_list = json.loads(args.config.get("arch", "num_hidden_rnn_list"))

            is_batchnorm = args.config.getboolean("arch", "is_batchnorm")
-            is_bucketing = args.config.getboolean("arch", "is_bucketing")

            if seq_len is None:
                seq_len = args.config.getint('arch', 'max_t_count')
@@ -106,18 +106,18 @@ def arch(args, seq_len=None):
                       no_bias=is_batchnorm,
                       name='conv1')
            if is_batchnorm:
-                # batch norm normalizes axis 1
-                net = batchnorm(net, name="conv1_batchnorm")
+                # batch norm normalizes axis 1
+                net = batchnorm(net, name="conv1_batchnorm")

            net = conv(net=net,
                       channels=channel_num,
                       filter_dimension=conv_layer2_filter_dim,
                       stride=conv_layer2_stride,
                       no_bias=is_batchnorm,
                       name='conv2')
-            if is_batchnorm:
-                # batch norm normalizes axis 1
-                net = batchnorm(net, name="conv2_batchnorm")
+            # if is_batchnorm:
+            #     # batch norm normalizes axis 1
+            #     net = batchnorm(net, name="conv2_batchnorm")

            net = mx.sym.transpose(data=net, axes=(0, 2, 1, 3))
            net = mx.sym.Reshape(data=net, shape=(0, 0, -3))
@@ -187,9 +187,6 @@ def arch(args, seq_len=None):
                args.config.set('arch', 'max_t_count', str(seq_len_after_conv_layer2))
        else:
            raise Exception('mode must be the one of the followings - train,predict,load')
-    else:
-        raise Exception('type of args should be one of the argparse.' +
-                        'Namespace for fixed length model or integer for variable length model')


class BucketingArch(object):
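For context on why arch() must now build the graph outside of pure train mode: MXNet's BucketingModule is driven by a sym_gen callable that maps a bucket key (here, a sequence length) to a (symbol, data_names, label_names) triple, and loading a bucketed checkpoint has to regenerate the symbol for every bucket; in train.py below, the module argument itself plays this role. A minimal sketch of that contract, assuming a hypothetical wrapper around arch() (the real glue lives in the BucketingArch class in this file):

def make_sym_gen(args):
    # Hypothetical helper: returns a sym_gen suitable for a BucketingModule.
    def sym_gen(seq_len):
        # Rebuild the network for this bucket's sequence length; with
        # is_bucketing = True this now also succeeds when mode == 'load'.
        symbol = arch(args, seq_len=seq_len)
        return symbol, ['data'], ['label']
    return sym_gen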
12 changes: 7 additions & 5 deletions deepspeech.cfg
@@ -1,20 +1,22 @@
[common]
# method can be one of the followings - train,predict,load
-mode = train
+mode = load
#ex: gpu0,gpu1,gpu2,gpu3
context = gpu0,gpu1,gpu2
#context = gpu0
# checkpoint prefix, check point will be saved under checkpoints folder with prefix
prefix = deep_bucket
# when mode is load or predict, model will be loaded from the file name with model_file under checkpoints
-model_file = deep_bucket-0001
+model_file = deep_bucketn_epoch0n_batch-0018
batch_size = 12
#batch_size=4
# log will be saved by the log_filename
log_filename = deep_bucket.log
# checkpoint set n to save checkpoints after n epoch
save_checkpoint_every_n_epoch = 1
+save_checkpoint_every_n_batch = 1000
is_bi_graphemes = True
-tensorboard_log_dir = tblog/deep
+tensorboard_log_dir = tblog/deep_bucket
# if random_seed is -1 then it gets random seed from timestamp
mx_random_seed = -1
random_seed = -1
@@ -66,9 +68,9 @@ factor_type = in
# show progress every how nth batches
show_every = 100
save_optimizer_states = True
-normalize_target_k = 230000
+normalize_target_k = 13000
# overwrite meta files(feats_mean,feats_std,unicode_en_baidu_bi_graphemes.csv)
-overwrite_meta_files = True
+overwrite_meta_files = False
enable_logging_train_metric = True
enable_logging_validation_metric = True
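As a quick illustration of how the new load branch in train.py (below) interprets the model_file value above: MXNet checkpoints are saved as <prefix>-NNNN.params, so the last four characters of the file stem give the checkpoint number and everything before the trailing '-NNNN' gives the path prefix. A small worked example using the value from this config:

import os

# Value taken from the config above; the slicing mirrors the load branch
# added to train.py in this commit.
model_file = 'deep_bucketn_epoch0n_batch-0018'
model_name = os.path.splitext(model_file)[0]   # stem, extension removed
model_num_epoch = int(model_name[-4:])         # -> 18
model_path = 'checkpoints/' + model_name[:-5]  # -> 'checkpoints/deep_bucketn_epoch0n_batch'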
34 changes: 24 additions & 10 deletions train.py
@@ -8,6 +8,7 @@
#tensorboard setting
from tensorboard import SummaryWriter
import json
+from stt_bucketing_module import STTBucketingModule



@@ -63,11 +64,31 @@ def do_training(args, module, data_train, data_val, begin_epoch=0):
    optimizer_params_dictionary = json.loads(args.config.get('optimizer', 'optimizer_params_dictionary'))
    kvstore_option = args.config.get('common', 'kvstore_option')
    n_epoch=begin_epoch
+    is_bucketing = args.config.getboolean('arch', 'is_bucketing')

    if clip_gradient == 0:
        clip_gradient = None
+    if is_bucketing and mode == 'load':
+        model_file = args.config.get('common', 'model_file')
+        model_name = os.path.splitext(model_file)[0]
+        model_num_epoch = int(model_name[-4:])
+
+        model_path = 'checkpoints/' + str(model_name[:-5])
+        symbol, data_names, label_names = module(1600)
+        model = STTBucketingModule(
+            sym_gen=module,
+            default_bucket_key=data_train.default_bucket_key,
+            context=contexts)
+        data_train.reset()
+
-    module.bind(data_shapes=data_train.provide_data,
+        model.bind(data_shapes=data_train.provide_data,
                   label_shapes=data_train.provide_label,
                   for_training=True)
+        _, arg_params, aux_params = mx.model.load_checkpoint(model_path, model_num_epoch)
+        model.set_params(arg_params, aux_params)
+        module = model
+    else:
+        module.bind(data_shapes=data_train.provide_data,
+                    label_shapes=data_train.provide_label,
+                    for_training=True)

@@ -90,20 +111,13 @@ def reset_optimizer(force_init=False):
        reset_optimizer(force_init=True)
    else:
        reset_optimizer(force_init=False)
+    data_train.reset()
+    data_train.is_first_epoch = True

    #tensorboard setting
    tblog_dir = args.config.get('common', 'tensorboard_log_dir')
    summary_writer = SummaryWriter(tblog_dir)


-    if mode == "train":
-        sort_by_duration = True
-    else:
-        sort_by_duration = False
-
-    if not sort_by_duration:
-        data_train.reset()

    while True:

        if n_epoch >= num_epoch:
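Taken together, the new branch follows the usual MXNet recipe for resuming a bucketing model: construct the module from a symbol generator, bind it against the training iterator's shapes, then copy in the parameters from the saved checkpoint. A condensed sketch of that recipe, shown here with the stock mx.mod.BucketingModule (STTBucketingModule is the project-local subclass imported above) and assuming sym_gen, data_train, contexts, model_path and model_num_epoch are set up as in do_training:

import mxnet as mx

# Build the bucketing module; the symbol is regenerated per bucket key.
model = mx.mod.BucketingModule(sym_gen=sym_gen,
                               default_bucket_key=data_train.default_bucket_key,
                               context=contexts)

# Allocate executors for the default bucket before loading weights.
model.bind(data_shapes=data_train.provide_data,
           label_shapes=data_train.provide_label,
           for_training=True)

# load_checkpoint returns (symbol, arg_params, aux_params); only the
# parameter dictionaries are needed, since sym_gen rebuilds the symbol.
_, arg_params, aux_params = mx.model.load_checkpoint(model_path, model_num_epoch)
model.set_params(arg_params, aux_params)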
