[whisper] fix decoding maxlen (#2380)
Mddct authored Mar 4, 2024
1 parent 8179fe1 commit 8f8cedc
Showing 2 changed files with 3 additions and 0 deletions.
2 changes: 2 additions & 0 deletions wenet/transformer/search.py
@@ -292,6 +292,8 @@ def attention_beam_search(
     cache: Optional[List[torch.Tensor]] = None
     if model.decoder.use_sdpa:
         encoder_mask = mask_to_bias(encoder_mask, encoder_out.dtype)
+    if hasattr(model, 'decode_maxlen'):
+        maxlen = model.decode_maxlen
     # 2. Decoder forward step by step
     for i in range(prefix_len, maxlen + 1):
         # Stop if all batch and all beam produce eos
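
With this hunk, attention_beam_search overrides its decoding length whenever the model exposes a decode_maxlen attribute. A minimal sketch of the resulting selection logic follows; the helper name resolve_maxlen and the assumption that the default maxlen is derived from the encoder output length are illustrative, not part of the commit.

    # Minimal sketch of the maxlen selection added in this commit.
    # Assumption: by default, beam search decodes at most one step per encoder
    # frame (maxlen = encoder_out.size(1)); the exact default in WeNet may differ.
    import torch

    def resolve_maxlen(model, encoder_out: torch.Tensor) -> int:
        maxlen = encoder_out.size(1)  # default: bounded by encoder length
        # Models with a hard decoder limit (e.g. Whisper's fixed positional
        # embedding table) publish it as `decode_maxlen`; use it instead.
        if hasattr(model, 'decode_maxlen'):
            maxlen = model.decode_maxlen
        return maxlen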
1 change: 1 addition & 0 deletions wenet/whisper/whisper.py
@@ -46,6 +46,7 @@ def __init__(
         assert reverse_weight == 0.0
         self.sos = special_tokens["sot"]
         self.eos = special_tokens["eot"]
+        self.decode_maxlen = self.decoder.embed[1].max_len
 
     # TODO(xcsong): time align
     def set_alignment_heads(self, dump: bytes):
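
Here decode_maxlen is read from the decoder's positional-embedding module: decoder.embed[1] is assumed to be the positional encoding, whose max_len mirrors the fixed text context of the original Whisper decoder (448 positions in OpenAI's released checkpoints). A rough illustration of why this cap matters, using a hypothetical stand-in module rather than WeNet's actual class, is below.

    # Rough illustration (hypothetical module): decoding past the learned
    # position table would index out of range, which is exactly what capping
    # maxlen with decode_maxlen prevents.
    import torch

    class LearnedPositionalEncoding(torch.nn.Module):
        """Stand-in for the decoder's embed[1] with a fixed-size position table."""

        def __init__(self, d_model: int, max_len: int = 448):
            super().__init__()
            self.max_len = max_len
            self.pe = torch.nn.Parameter(torch.zeros(max_len, d_model))

        def forward(self, x: torch.Tensor, offset: int = 0) -> torch.Tensor:
            length = x.size(1)
            if offset + length > self.max_len:
                raise ValueError("decoding step exceeds the position table")
            return x + self.pe[offset:offset + length]

    # The commit reads this limit back at model construction time:
    #     self.decode_maxlen = self.decoder.embed[1].max_len
    # so attention_beam_search can stop before the table is exhausted.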
