Skip to content

Commit

Permalink
Fix caching bug in causal convolutions for cache-aware ASR models (#7034
Browse files Browse the repository at this point in the history
) (#7082)

Co-authored-by: Vahid Noroozi <[email protected]>
Signed-off-by: jubick1337 <[email protected]>
  • Loading branch information
2 people authored and jubick1337 committed Aug 8, 2023
1 parent 4957058 commit 103d94f
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions nemo/collections/asr/parts/submodules/causal_convs.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,13 +130,16 @@ def __init__(
def update_cache(self, x, cache=None):
if cache is None:
new_x = F.pad(x, pad=(self._left_padding, self._right_padding))
next_cache = cache
else:
new_x = F.pad(x, pad=(0, self._right_padding))
new_x = torch.cat([cache, new_x], dim=-1)
if self.cache_drop_size > 0:
x = x[:, :, : -self.cache_drop_size]
cache = torch.cat([cache[:, :, x.size(-1) :], x], dim=-1)
return new_x, cache
next_cache = new_x[:, :, : -self.cache_drop_size]
else:
next_cache = new_x
next_cache = next_cache[:, :, -cache.size(-1) :]
return new_x, next_cache

def forward(self, x, cache=None):
x, cache = self.update_cache(x, cache=cache)
Expand Down

0 comments on commit 103d94f

Please sign in to comment.