diff --git a/egs/wsj/s5/steps/libs/nnet3/xconfig/composite_layers.py b/egs/wsj/s5/steps/libs/nnet3/xconfig/composite_layers.py index bf2a90916ae..928ca445ccc 100644 --- a/egs/wsj/s5/steps/libs/nnet3/xconfig/composite_layers.py +++ b/egs/wsj/s5/steps/libs/nnet3/xconfig/composite_layers.py @@ -80,7 +80,8 @@ def set_default_configs(self): 'time-stride':1, 'l2-regularize':0.0, 'max-change': 0.75, - 'self-repair-scale': 1.0e-05} + 'self-repair-scale': 1.0e-05, + 'context': 'default'} def set_derived_configs(self): pass @@ -104,6 +105,10 @@ def check_configs(self): raise RuntimeError('bypass-scale is nonzero but output-dim != input-dim: {0} != {1}' ''.format(output_dim, input_dim)) + if not self.config['context'] in ['default', 'left-only', 'shift-left', 'none']: + raise RuntimeError('context must be default, left-only shift-left or none, got {}'.format( + self.config['context'])) + def output_name(self, auxiliary_output=None): assert auxiliary_output is None @@ -142,9 +147,16 @@ def _generate_config(self): bypass_scale = self.config['bypass-scale'] dropout_proportion = self.config['dropout-proportion'] time_stride = self.config['time-stride'] - if time_stride != 0: + context = self.config['context'] + if time_stride != 0 and context != 'none': time_offsets1 = '{0},0'.format(-time_stride) - time_offsets2 = '0,{0}'.format(time_stride) + if context == 'default': + time_offsets2 = '0,{0}'.format(time_stride) + elif context == 'shift-left': + time_offsets2 = '{0},0'.format(-time_stride) + else: + assert context == 'left-only' + time_offsets2 = '0' else: time_offsets1 = '0' time_offsets2 = '0'