Skip to content

Commit

Permalink
Merge pull request #303 from hirofumi0810/safeguard
Browse files Browse the repository at this point in the history
Add safeguard for state reset during streaming inference
  • Loading branch information
hirofumi0810 authored Mar 26, 2021
2 parents ff940a9 + 61ba30c commit 2b10b9c
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 60 deletions.
21 changes: 15 additions & 6 deletions neural_sp/models/seq2seq/decoders/las.py
Original file line number Diff line number Diff line change
Expand Up @@ -1454,7 +1454,7 @@ def batchfy_beam(self, hyps, i):

def beam_search_block_sync(self, eouts, params, helper, idx2token,
hyps, hyps_nobd, lm, ctc_log_probs=None, speaker=None,
ignore_eos=False):
ignore_eos=False, dualhyp=True):
assert eouts.size(0) == 1
assert self.attn_type == 'mocha'

Expand Down Expand Up @@ -1594,13 +1594,19 @@ def beam_search_block_sync(self, eouts, params, helper, idx2token,
'no_boundary': no_boundary})

# Local pruning
new_hyps += hyps_nobd
new_hyps = sorted(new_hyps, key=lambda x: x['score'], reverse=True)[:beam_width]
if not dualhyp:
new_hyps += hyps_nobd
new_hyps = sorted(new_hyps, key=lambda x: x['score'], reverse=True)

# Remove complete hypotheses
new_hyps, end_hyps, is_finish = helper.remove_complete_hyp(new_hyps, end_hyps)
hyps_nobd = [beam for beam in new_hyps if beam['no_boundary']]
hyps = [beam for beam in new_hyps if not beam['no_boundary']]

if dualhyp:
hyps = new_hyps[:]
else:
hyps_nobd = [beam for beam in new_hyps if beam['no_boundary']]
hyps = [beam for beam in new_hyps if not beam['no_boundary']]

if is_finish:
break

Expand Down Expand Up @@ -1636,6 +1642,9 @@ def beam_search_block_sync(self, eouts, params, helper, idx2token,

self.n_frames += eouts.size(1)
self.score.reset()
self.key_tail = eouts[:, -(self.score.w - 1):]
if eouts.size(1) < self.score.w - 1 and self.key_tail is not None:
self.key_tail = torch.cat([self.key_tail, eouts], dim=1)[:, -(self.score.w - 1):]
else:
self.key_tail = eouts[:, -(self.score.w - 1):]

return end_hyps, hyps, hyps_nobd
40 changes: 19 additions & 21 deletions neural_sp/models/seq2seq/encoders/rnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,9 @@ def __init__(self, input_dim, enc_type, n_units, n_projs, last_proj_dim,
chunk_size_right = str(chunk_size_right)

# for latency-controlled
self.chunk_size_current = int(chunk_size_current.split('_')[0]) // n_stacks
self.chunk_size_right = int(chunk_size_right.split('_')[0]) // n_stacks
self.lc_bidir = self.chunk_size_current > 0 or self.chunk_size_right > 0 and self.bidirectional
self.N_c = int(chunk_size_current.split('_')[0]) // n_stacks
self.N_r = int(chunk_size_right.split('_')[0]) // n_stacks
self.lc_bidir = (self.N_c > 0 or self.N_r > 0) and self.bidirectional
if self.lc_bidir:
assert enc_type not in ['lstm', 'gru', 'conv_lstm', 'conv_gru']
assert n_layers_sub2 == 0
Expand Down Expand Up @@ -150,7 +150,7 @@ def __init__(self, input_dim, enc_type, n_units, n_projs, last_proj_dim,
self._odim = input_dim * n_splices * n_stacks
self.cnn_lookahead = cnn_lookahead
if not cnn_lookahead:
assert self.chunk_size_current > 0
assert self.N_c > 0
assert self.lc_bidir

if enc_type != 'conv':
Expand Down Expand Up @@ -217,22 +217,20 @@ def __init__(self, input_dim, enc_type, n_units, n_projs, last_proj_dim,
self._odim = last_proj_dim

# calculate subsampling factor
self._factor = 1
if self.conv is not None:
self._factor *= self.conv.subsampling_factor
self._factor_sub1 = self._factor
if n_layers_sub1 > 0 and np.prod(subsamples[:n_layers_sub1 - 1]) > 1:
self.conv_factor = self.conv.subsampling_factor if self.conv is not None else 1
self._factor = self.conv_factor
self._factor_sub1 = self.conv_factor
self._factor_sub2 = self.conv_factor
if n_layers_sub1 > 1:
self._factor_sub1 *= np.prod(subsamples[:n_layers_sub1 - 1])
self._factor_sub2 = self._factor
if n_layers_sub2 > 0 and np.prod(subsamples[:n_layers_sub2 - 1]) > 1:
if n_layers_sub2 > 1:
self._factor_sub1 *= np.prod(subsamples[:n_layers_sub2 - 1])
if np.prod(subsamples) > 1:
self._factor *= np.prod(subsamples)
self._factor *= np.prod(subsamples)
# NOTE: subsampling factor for frame stacking should not be included here
if self.chunk_size_current > 0:
assert self.chunk_size_current % self._factor == 0
if self.chunk_size_right > 0:
assert self.chunk_size_right % self._factor == 0
if self.N_c > 0:
assert self.N_c % self._factor == 0
if self.N_r > 0:
assert self.N_r % self._factor == 0

self.reset_parameters(param_init)

Expand Down Expand Up @@ -329,7 +327,7 @@ def forward(self, xs, xlens, task, streaming=False,
xs = self.dropout_in(xs)

bs, xmax, idim = xs.size()
N_c, N_r = self.chunk_size_current, self.chunk_size_right
N_c, N_r = self.N_c, self.N_r

if self.lc_bidir and not self.cnn_lookahead:
xs = chunkwise(xs, 0, N_c, 0) # `[B * n_chunks, N_c, idim]`
Expand All @@ -345,8 +343,8 @@ def forward(self, xs, xlens, task, streaming=False,
eouts['ys']['xlens'] = xlens
return eouts
if self.lc_bidir:
N_c = N_c // self.conv.subsampling_factor
N_r = N_r // self.conv.subsampling_factor
N_c = N_c // self.conv_factor
N_r = N_r // self.conv_factor

carry_over = self.rsp_prob > 0 and self.training and random.random() < self.rsp_prob
carry_over = carry_over and (bs == (self.hx_fwd[0][0].size(0) if self.hx_fwd[0] is not None else 0))
Expand All @@ -356,7 +354,7 @@ def forward(self, xs, xlens, task, streaming=False,

if self.lc_bidir:
# Flip the layer and time loop
if self.chunk_size_current <= 0:
if self.N_c <= 0:
xs, xlens, xs_sub1, xlens_sub1 = self._forward_full_context(
xs, xlens)
else:
Expand Down
31 changes: 14 additions & 17 deletions neural_sp/models/seq2seq/encoders/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,10 +131,10 @@ def __init__(self, input_dim, enc_type, n_heads,
self.lookaheads = lookaheads
if sum(lookaheads) > 0:
assert self.unidir
self.chunk_size_left = int(chunk_size_left.split('_')[-1]) // n_stacks
self.chunk_size_current = int(chunk_size_current.split('_')[-1]) // n_stacks
self.chunk_size_right = int(chunk_size_right.split('_')[-1]) // n_stacks
self.lc_bidir = self.chunk_size_current > 0 and enc_type != 'conv' and 'uni' not in enc_type
self.N_l = int(chunk_size_left.split('_')[-1]) // n_stacks
self.N_c = int(chunk_size_current.split('_')[-1]) // n_stacks
self.N_r = int(chunk_size_right.split('_')[-1]) // n_stacks
self.lc_bidir = self.N_c > 0 and enc_type != 'conv' and 'uni' not in enc_type
self.cnn_lookahead = self.unidir or enc_type == 'conv'
self.streaming_type = streaming_type if self.lc_bidir else ''
# -: past context
Expand All @@ -156,10 +156,10 @@ def __init__(self, input_dim, enc_type, n_heads,
# chunk4: -- --|**|
# chunk5: -- --|**|
if self.unidir:
assert self.chunk_size_left == self.chunk_size_current == self.chunk_size_right == 0
assert self.N_l == self.N_c == self.N_r == 0
if self.streaming_type == 'mask':
assert self.chunk_size_right == 0
assert self.chunk_size_left % self.chunk_size_current == 0
assert self.N_r == 0
assert self.N_l % self.N_c == 0
# NOTE: this is important to cache CNN output at each chunk
if self.lc_bidir:
assert n_layers_sub1 == 0
Expand Down Expand Up @@ -228,12 +228,9 @@ def __init__(self, input_dim, enc_type, n_heads,
else:
raise NotImplementedError(subsample_type)

if self.chunk_size_left > 0:
assert self.chunk_size_left % self._factor == 0
if self.chunk_size_current > 0:
assert self.chunk_size_current % self._factor == 0
if self.chunk_size_right > 0:
assert self.chunk_size_right % self._factor == 0
assert self.N_l % self._factor == 0
assert self.N_c % self._factor == 0
assert self.N_r % self._factor == 0

self.pos_enc, self.pos_emb = None, None
self.u_bias, self.v_bias = None, None
Expand Down Expand Up @@ -417,7 +414,7 @@ def truncate_cache(self, cache):
def calculate_cache_size(self):
"""Calculate the maximum cache size per layer."""
cache_size = self._total_chunk_size_left() # after CNN subsampling
N_l = self.chunk_size_left // self.conv_factor
N_l = self.N_l // self.conv_factor
cache_sizes = []
for lth in range(self.n_layers):
cache_sizes.append(cache_size)
Expand All @@ -433,9 +430,9 @@ def _total_chunk_size_left(self):
This corresponds to the frame length after CNN subsampling.
"""
if self.streaming_type == 'reshape':
return self.chunk_size_left // self.conv_factor
return self.N_l // self.conv_factor
elif self.streaming_type == 'mask':
return (self.chunk_size_left // self.conv_factor) * self.n_layers
return (self.N_l // self.conv_factor) * self.n_layers
elif self.unidir:
return 10000 // self.conv_factor
else:
Expand Down Expand Up @@ -466,7 +463,7 @@ def forward(self, xs, xlens, task, streaming=False,
n_chunks = 0
unidir = self.unidir
lc_bidir = self.lc_bidir
N_l, N_c, N_r = self.chunk_size_left, self.chunk_size_current, self.chunk_size_right
N_l, N_c, N_r = self.N_l, self.N_c, self.N_r

if streaming and self.streaming_type == 'mask':
assert xmax <= N_c
Expand Down
30 changes: 16 additions & 14 deletions neural_sp/models/seq2seq/frontends/streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,25 +39,25 @@ def __init__(self, x_whole, params, encoder, idx2token=None):

# latency related
self._factor = encoder.subsampling_factor
self.N_l = getattr(encoder, 'chunk_size_left', 0) # for LC-Transformer/Conformer
self.N_c = encoder.chunk_size_current
self.N_r = encoder.chunk_size_right
self.N_l = getattr(encoder, 'N_l', 0) # for LC-Transformer/Conformer
self.N_c = encoder.N_c
self.N_r = encoder.N_r
if self.streaming_type == 'mask':
self.N_l = 0
# NOTE: context in previous chunks is cached inside the encoder
if self.N_c <= 0 and self.N_r <= 0:
self.N_c = params['recog_block_sync_size'] # for unidirectional encoder
self.N_c = params.get('recog_block_sync_size') # for unidirectional encoder
assert self.N_c % self._factor == 0
# NOTE: these lengths are the ones before subsampling

# threshold for CTC-VAD
self.blank_id = 0
self.is_ctc_vad = params['recog_ctc_vad']
self.BLANK_THRESHOLD = params['recog_ctc_vad_blank_threshold']
self.SPIKE_THRESHOLD = params['recog_ctc_vad_spike_threshold']
self.MAX_N_ACCUM_FRAMES = params['recog_ctc_vad_n_accum_frames']
assert params['recog_ctc_vad_blank_threshold'] % self._factor == 0
assert params['recog_ctc_vad_n_accum_frames'] % self._factor == 0
self.is_ctc_vad = params.get('recog_ctc_vad')
self.BLANK_THRESHOLD = params.get('recog_ctc_vad_blank_threshold')
self.SPIKE_THRESHOLD = params.get('recog_ctc_vad_spike_threshold')
self.MAX_N_ACCUM_FRAMES = params.get('recog_ctc_vad_n_accum_frames')
assert self.BLANK_THRESHOLD % self._factor == 0
assert self.MAX_N_ACCUM_FRAMES % self._factor == 0
# NOTE: these parameters are based on 10ms/frame

self._offset = 0 # global time offset in the session
Expand Down Expand Up @@ -95,6 +95,10 @@ def bd_offset(self):
def n_cache_block(self):
return len(self._eout_blocks)

@property
def safeguard_reset(self):
# True while `_n_accum_frames` is still below `MAX_N_ACCUM_FRAMES`.
# While this holds, `ctc_vad` returns early with `is_reset = False`,
# so no reset is signalled for the current block.
return self._n_accum_frames < self.MAX_N_ACCUM_FRAMES

def reset(self, stdout=False):
self._eout_blocks = []
self._n_blanks = 0
Expand Down Expand Up @@ -172,7 +176,7 @@ def ctc_vad(self, ctc_probs_block, stdout=False):
"""
is_reset = False # detect the first boundary in the same block

if self._n_accum_frames < self.MAX_N_ACCUM_FRAMES:
if self.safeguard_reset:
return is_reset

assert ctc_probs_block is not None
Expand All @@ -182,8 +186,7 @@ def ctc_vad(self, ctc_probs_block, stdout=False):
# encoder states will be carried over to the next block.
# Otherwise, the current block is segmented at the point where
# _n_blanks surpasses the threshold.
topk_ids_block = torch.topk(ctc_probs_block, k=1, dim=-1, largest=True, sorted=True)[1]
topk_ids_block = topk_ids_block[0, :, 0] # `[T_block]`
topk_ids_block = ctc_probs_block[0].argmax(-1) # `[T_block]`
bs, xmax_block, vocab = ctc_probs_block.size()

# skip all blank segments
Expand Down Expand Up @@ -213,7 +216,6 @@ def ctc_vad(self, ctc_probs_block, stdout=False):
print('CTC (T:%d): %s' % (self._offset + (j + 1) * self._factor,
self.idx2token([topk_ids_block[j].item()])))

# if not is_reset and (self._n_blanks * self._factor >= self.BLANK_THRESHOLD):# NOTE: select the leftmost blank offset
if self._n_blanks * self._factor >= self.BLANK_THRESHOLD: # NOTE: select the rightmost blank offset
self._bd_offset = j
is_reset = True
Expand Down
14 changes: 12 additions & 2 deletions neural_sp/models/seq2seq/speech2text.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,6 +553,7 @@ def decode_streaming(self, xs, params, idx2token, exclude_eos=False,
block_size = params.get('recog_block_sync_size') # before subsampling
cache_emb = params.get('recog_cache_embedding')
ctc_weight = params.get('recog_ctc_weight')
backoff = True

assert task == 'ys'
assert self.input_type == 'speech'
Expand Down Expand Up @@ -593,11 +594,20 @@ def decode_streaming(self, xs, params, idx2token, exclude_eos=False,

self.enc.reset_cache()
eout_block_tail = None
x_block_prev, xlen_block_prev = None, None
while True:
# Encode input features block by block
x_block, is_last_block, cnn_lookback, cnn_lookahead, xlen_block = streaming.extract_feat()
if not is_transformer_enc and is_reset:
self.enc.reset_cache()
if backoff:
self.encode([x_block_prev], 'all',
streaming=True,
cnn_lookback=cnn_lookback,
cnn_lookahead=cnn_lookahead,
xlen_block=xlen_block_prev)
x_block_prev = x_block
xlen_block_prev = xlen_block
eout_block_dict = self.encode([x_block], 'all',
streaming=True,
cnn_lookback=cnn_lookback,
Expand Down Expand Up @@ -657,9 +667,9 @@ def decode_streaming(self, xs, params, idx2token, exclude_eos=False,
# Segmentation strategy 2:
# If <eos> is emitted from the decoder (not CTC),
# the current block is segmented.
if not is_reset:
if (not is_reset) and (not streaming.safeguard_reset):
streaming._bd_offset = eout_block.size(1) - 1 # TODO: fix later
is_reset = True
is_reset = True

if len(best_hyp_id_prefix_viz) > 0:
n_frames = self.dec_fwd.ctc.n_frames if ctc_weight == 1 or self.ctc_weight == 1 else self.dec_fwd.n_frames
Expand Down

0 comments on commit 2b10b9c

Please sign in to comment.