Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 63 additions & 25 deletions egs/wsj/s5/steps/libs/nnet3/xconfig/lstm.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,8 @@ def set_default_configs(self):
'self-repair-scale-nonlinearity' : 0.00001,
'zeroing-interval' : 20,
'zeroing-threshold' : 15.0,
'dropout-proportion' : -1.0 # -1.0 stands for no dropout will be added
'dropout-proportion' : -1.0, # -1.0 stands for no dropout will be added
'dropout-per-frame' : 'false' # default normal dropout mode
}

def set_derived_configs(self):
Expand Down Expand Up @@ -286,6 +287,10 @@ def check_configs(self):
raise RuntimeError("dropout-proportion has invalid value {0}."
"".format(self.config['dropout-proportion']))

if (self.config['dropout-per-frame'] != 'false' and
self.config['dropout-per-frame'] != 'true'):
raise xparser_error("dropout-per-frame has invalid value {0}.".format(self.config['dropout-per-frame']))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please change all xparser_error raises to RuntimeError. IIRC the xparser_error exception type has been removed.


def auxiliary_outputs(self):
return ['c_t']

Expand Down Expand Up @@ -347,6 +352,8 @@ def generate_lstm_config(self):
pes_str = self.config['ng-per-element-scale-options']
lstm_dropout_value = self.config['dropout-proportion']
lstm_dropout_str = 'dropout-proportion='+str(self.config['dropout-proportion'])
lstm_dropout_per_frame_value = self.config['dropout-per-frame']
lstm_dropout_per_frame_str = 'dropout-per-frame='+str(self.config['dropout-per-frame'])

# Natural gradient per element scale parameters
# TODO: decide if we want to keep exposing these options
Expand Down Expand Up @@ -383,6 +390,8 @@ def generate_lstm_config(self):
configs.append("component name={0}.o type=SigmoidComponent dim={1} {2}".format(name, cell_dim, repair_nonlin_str))
configs.append("component name={0}.g type=TanhComponent dim={1} {2}".format(name, cell_dim, repair_nonlin_str))
configs.append("component name={0}.h type=TanhComponent dim={1} {2}".format(name, cell_dim, repair_nonlin_str))
if lstm_dropout_value != -1.0:
configs.append("component name={0}.dropout type=DropoutComponent dim={1} {2} {3}".format(name, cell_dim, lstm_dropout_str, lstm_dropout_per_frame_str))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

watch line length. we're not being strict, but don't be ridiculous.


configs.append("# Defining the components for other cell computations")
configs.append("component name={0}.c1 type=ElementwiseProductComponent input-dim={1} output-dim={2}".format(name, 2 * cell_dim, cell_dim))
Expand All @@ -398,17 +407,29 @@ def generate_lstm_config(self):
configs.append("# i_t")
configs.append("component-node name={0}.i1_t component={0}.W_i.xr input=Append({1}, IfDefined(Offset({2}, {3})))".format(name, input_descriptor, recurrent_connection, delay))
configs.append("component-node name={0}.i2_t component={0}.w_i.c input={1}".format(name, delayed_c_t_descriptor))
configs.append("component-node name={0}.i_t component={0}.i input=Sum({0}.i1_t, {0}.i2_t)".format(name))
if lstm_dropout_value != -1.0:
configs.append("component-node name={0}.i_t_predrop component={0}.i input=Sum({0}.i1_t, {0}.i2_t)".format(name))
configs.append("component-node name={0}.i_t component={0}.dropout input={0}.i_t_predrop".format(name))
else:
configs.append("component-node name={0}.i_t component={0}.i input=Sum({0}.i1_t, {0}.i2_t)".format(name))

configs.append("# f_t")
configs.append("component-node name={0}.f1_t component={0}.W_f.xr input=Append({1}, IfDefined(Offset({2}, {3})))".format(name, input_descriptor, recurrent_connection, delay))
configs.append("component-node name={0}.f2_t component={0}.w_f.c input={1}".format(name, delayed_c_t_descriptor))
configs.append("component-node name={0}.f_t component={0}.f input=Sum({0}.f1_t, {0}.f2_t)".format(name))
if lstm_dropout_value != -1.0:
configs.append("component-node name={0}.f_t_predrop component={0}.f input=Sum({0}.f1_t, {0}.f2_t)".format(name))
configs.append("component-node name={0}.f_t component={0}.dropout input={0}.f_t_predrop".format(name))
else:
configs.append("component-node name={0}.f_t component={0}.f input=Sum({0}.f1_t, {0}.f2_t)".format(name))

configs.append("# o_t")
configs.append("component-node name={0}.o1_t component={0}.W_o.xr input=Append({1}, IfDefined(Offset({2}, {3})))".format(name, input_descriptor, recurrent_connection, delay))
configs.append("component-node name={0}.o2_t component={0}.w_o.c input={0}.c_t".format(name))
configs.append("component-node name={0}.o_t component={0}.o input=Sum({0}.o1_t, {0}.o2_t)".format(name))
if lstm_dropout_value != -1.0:
configs.append("component-node name={0}.o_t_predrop component={0}.o input=Sum({0}.o1_t, {0}.o2_t)".format(name))
configs.append("component-node name={0}.o_t component={0}.dropout input={0}.o_t_predrop".format(name))
else:
configs.append("component-node name={0}.o_t component={0}.o input=Sum({0}.o1_t, {0}.o2_t)".format(name))

configs.append("# h_t")
configs.append("component-node name={0}.h_t component={0}.h input={0}.c_t".format(name))
Expand All @@ -426,21 +447,13 @@ def generate_lstm_config(self):

# add the recurrent connections
configs.append("# projection matrices : Wrm and Wpm")
if lstm_dropout_value != -1.0:
configs.append("component name={0}.W_rp.m.dropout type=DropoutComponent dim={1} {2}".format(name, cell_dim, lstm_dropout_str))
configs.append("component name={0}.W_rp.m type=NaturalGradientAffineComponent input-dim={1} output-dim={2} {3}".format(name, cell_dim, rec_proj_dim + nonrec_proj_dim, affine_str))
configs.append("component name={0}.r type=BackpropTruncationComponent dim={1} {2}".format(name, rec_proj_dim, bptrunc_str))

configs.append("# r_t and p_t : rp_t will be the output")
if lstm_dropout_value != -1.0:
configs.append("component-node name={0}.rp_t.dropout component={0}.W_rp.m.dropout input={0}.m_t".format(name))
configs.append("component-node name={0}.rp_t component={0}.W_rp.m input={0}.rp_t.dropout".format(name))
configs.append("dim-range-node name={0}.r_t_preclip input-node={0}.rp_t dim-offset=0 dim={1}".format(name, rec_proj_dim))
configs.append("component-node name={0}.r_t component={0}.r input={0}.r_t_preclip".format(name))
else:
configs.append("component-node name={0}.rp_t component={0}.W_rp.m input={0}.m_t".format(name))
configs.append("dim-range-node name={0}.r_t_preclip input-node={0}.rp_t dim-offset=0 dim={1}".format(name, rec_proj_dim))
configs.append("component-node name={0}.r_t component={0}.r input={0}.r_t_preclip".format(name))
configs.append("component-node name={0}.rp_t component={0}.W_rp.m input={0}.m_t".format(name))
configs.append("dim-range-node name={0}.r_t_preclip input-node={0}.rp_t dim-offset=0 dim={1}".format(name, rec_proj_dim))
configs.append("component-node name={0}.r_t component={0}.r input={0}.r_t_preclip".format(name))

return configs

Expand Down Expand Up @@ -760,8 +773,9 @@ def set_default_configs(self):
# larger max-change than the normal value of 0.75.
'ng-affine-options' : ' max-change=1.5',
'zeroing-interval' : 20,
'zeroing-threshold' : 15.0

'zeroing-threshold' : 15.0,
'dropout-proportion' : -1.0 ,# -1.0 stands for no dropout will be added
'dropout-per-frame' : 'false' # default normal dropout mode
}

def set_derived_configs(self):
Expand All @@ -775,6 +789,15 @@ def set_derived_configs(self):
self.config['non-recurrent-projection-dim'] = \
self.config['recurrent-projection-dim']

if ((self.config['dropout-proportion'] > 1.0 or
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this checking code belongs in check_configs.

self.config['dropout-proportion'] < 0.0) and
self.config['dropout-proportion'] != -1.0 ):
raise xparser_error("dropout-proportion has invalid value {0}.".format(self.config['dropout-proportion']))

if (self.config['dropout-per-frame'] != 'false' and
self.config['dropout-per-frame'] != 'true'):
raise xparser_error("dropout-per-frame has invalid value {0}.".format(self.config['dropout-per-frame']))

def check_configs(self):
for key in ['cell-dim', 'recurrent-projection-dim',
'non-recurrent-projection-dim']:
Expand Down Expand Up @@ -846,7 +869,10 @@ def generate_lstm_config(self):
abs(delay)))
affine_str = self.config['ng-affine-options']
lstm_str = self.config['lstm-nonlinearity-options']

lstm_dropout_value = self.config['dropout-proportion']
lstm_dropout_str = 'dropout-proportion='+str(self.config['dropout-proportion'])
lstm_dropout_per_frame_value = self.config['dropout-per-frame']
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this variable is unused, and since its name is longer than the right-hand side of the expression, there is not much need for it.

lstm_dropout_per_frame_str = 'dropout-per-frame='+str(self.config['dropout-per-frame'])

configs = []

Expand All @@ -865,6 +891,8 @@ def generate_lstm_config(self):
configs.append("# Component for backprop truncation, to avoid gradient blowup in long training examples.")
configs.append("component name={0}.cr_trunc type=BackpropTruncationComponent "
"dim={1} {2}".format(name, cell_dim + rec_proj_dim, bptrunc_str))
if lstm_dropout_value != -1.0:
configs.append("component name={0}.cr_trunc.dropout type=DropoutComponent dim={1} {2} {3}".format(name, cell_dim + rec_proj_dim, lstm_dropout_str, lstm_dropout_per_frame_str))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you should split this line.

configs.append("# Component specific to 'projected' LSTM (LSTMP), contains both recurrent");
configs.append("# and non-recurrent projections")
configs.append("component name={0}.W_rp type=NaturalGradientAffineComponent input-dim={1} "
Expand All @@ -886,11 +914,21 @@ def generate_lstm_config(self):
configs.append("# Note: it's not 100% efficient that we have to stitch the c")
configs.append("# and r back together to truncate them but it probably");
configs.append("# makes the deriv truncation more accurate .")
configs.append("component-node name={0}.cr_trunc component={0}.cr_trunc "
"input=Append({0}.c, {0}.r)".format(name))
configs.append("dim-range-node name={0}.c_trunc input-node={0}.cr_trunc "
"dim-offset=0 dim={1}".format(name, cell_dim))
configs.append("dim-range-node name={0}.r_trunc input-node={0}.cr_trunc "
"dim-offset={1} dim={2}".format(name, cell_dim, rec_proj_dim))
configs.append("### End LSTM Layer '{0}'".format(name))
if lstm_dropout_value != -1.0:
configs.append("component-node name={0}.cr_trunc component={0}.cr_trunc "
"input=Append({0}.c, {0}.r)".format(name))
configs.append("component-node name={0}.cr_trunc.dropout component={0}.cr_trunc.dropout input={0}.cr_trunc".format(name))
configs.append("dim-range-node name={0}.c_trunc input-node={0}.cr_trunc.dropout "
"dim-offset=0 dim={1}".format(name, cell_dim))
configs.append("dim-range-node name={0}.r_trunc input-node={0}.cr_trunc.dropout "
"dim-offset={1} dim={2}".format(name, cell_dim, rec_proj_dim))
configs.append("### End LSTM Layer '{0}'".format(name))
else:
configs.append("component-node name={0}.cr_trunc component={0}.cr_trunc "
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you could move the first and last of these lines out of the if-statement.

"input=Append({0}.c, {0}.r)".format(name))
configs.append("dim-range-node name={0}.c_trunc input-node={0}.cr_trunc "
"dim-offset=0 dim={1}".format(name, cell_dim))
configs.append("dim-range-node name={0}.r_trunc input-node={0}.cr_trunc "
"dim-offset={1} dim={2}".format(name, cell_dim, rec_proj_dim))
configs.append("### End LSTM Layer '{0}'".format(name))
return configs