diff --git a/nemo/collections/asr/parts/submodules/causal_convs.py b/nemo/collections/asr/parts/submodules/causal_convs.py index c6251690b1b1..32f08a8d2feb 100644 --- a/nemo/collections/asr/parts/submodules/causal_convs.py +++ b/nemo/collections/asr/parts/submodules/causal_convs.py @@ -130,13 +130,16 @@ def __init__( def update_cache(self, x, cache=None): if cache is None: new_x = F.pad(x, pad=(self._left_padding, self._right_padding)) + next_cache = cache else: new_x = F.pad(x, pad=(0, self._right_padding)) new_x = torch.cat([cache, new_x], dim=-1) if self.cache_drop_size > 0: - x = x[:, :, : -self.cache_drop_size] - cache = torch.cat([cache[:, :, x.size(-1) :], x], dim=-1) - return new_x, cache + next_cache = new_x[:, :, : -self.cache_drop_size] + else: + next_cache = new_x + next_cache = next_cache[:, :, -cache.size(-1) :] + return new_x, next_cache def forward(self, x, cache=None): x, cache = self.update_cache(x, cache=cache)