Skip to content

Commit

Permalink
[wenet/squeezeformer] compat with websocket server (#1544)
Browse files Browse the repository at this point in the history
* [update] compat with websocket server

* compat with decoder chunk size -1

* fix format issues
  • Loading branch information
yygle authored Nov 8, 2022
1 parent 693a216 commit 1d9b0bf
Showing 1 changed file with 4 additions and 5 deletions.
9 changes: 4 additions & 5 deletions wenet/squeezeformer/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
from wenet.squeezeformer.convolution import ConvolutionModule
from wenet.utils.mask import make_pad_mask, add_optional_chunk_mask
from wenet.utils.common import get_activation
import math


class SqueezeformerEncoder(nn.Module):
Expand Down Expand Up @@ -381,16 +380,16 @@ def forward_chunk(

xs, _, new_att_cache, new_cnn_cache = layer(
xs, att_mask, pos_emb,
att_cache=att_cache[i:i + 1][:, :, ::factor, :] if
elayers > 0 and att_cache.size(2) != 0 else
att_cache[:, :, ::factor, :],
att_cache=att_cache[i:i + 1][:, :, ::factor, :]
[:, :, :pos_emb.size(1) - xs.size(1), :] if
elayers > 0 else att_cache[:, :, ::factor, :],
cnn_cache=cnn_cache[i] if cnn_cache.size(0) > 0 else cnn_cache
)
# NOTE(xcsong): After layer.forward
# shape(new_att_cache) is (1, head, attention_key_size, d_k * 2),
# shape(new_cnn_cache) is (b=1, hidden-dim, cache_t2)
cached_att \
= new_att_cache[:, :, math.ceil(next_cache_start / factor):, :]
= new_att_cache[:, :, next_cache_start // factor:, :]
cached_cnn = new_cnn_cache.unsqueeze(0)
cached_att = cached_att.repeat_interleave(repeats=factor, dim=2)
if i == 0:
Expand Down

0 comments on commit 1d9b0bf

Please sign in to comment.