chore: minor cleanup (NVIDIA#6311)
* chore: minor cleanup

Signed-off-by: Greg Clark <[email protected]>

* Adding docs and missing param

Signed-off-by: Greg Clark <[email protected]>

* Clip cache_keep_size in causal_convs

Signed-off-by: Greg Clark <[email protected]>

---------

Signed-off-by: Greg Clark <[email protected]>
Signed-off-by: hsiehjackson <[email protected]>
messiaen authored and hsiehjackson committed Jun 2, 2023
1 parent fb78e79 commit b1b4752
Showing 5 changed files with 22 additions and 22 deletions.
@@ -43,6 +43,8 @@
 You may drop the '--debug_mode' and '--compare_vs_offline' to speedup the streaming evaluation.
 If compare_vs_offline is not used, then significantly larger batch_size can be used.
+To best compare output with offline output (i.e. `--compare_vs_offline` is set) `--pad-and-drop-preencoded` should also be set.
 
 ## Evaluate a model trained with full context for offline mode
 You may try the cache-aware streaming with a model trained with full context in offline mode.
@@ -252,7 +254,9 @@ def main():
         "--output_path", type=str, help="path to output file when manifest is used as input", default=None
     )
     parser.add_argument(
-        "--pad-and-drop-preencoded", action="store_true", help="pad first audio chunk and always drop preencoded"
+        "--pad-and-drop-preencoded",
+        action="store_true",
+        help="Enables padding the audio input and then dropping the extra steps after the pre-encoding for the first step. It makes the outputs of the downsampling exactly as the offline mode for some techniques like striding.",
     )
 
     args = parser.parse_args()
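For reference, a standalone sketch of how the new flag behaves once parsed; only the argparse pattern from the diff is exercised here (the help text is abridged and the rest of the streaming inference script is omitted). argparse converts the dashes to underscores, so the option is read back as args.pad_and_drop_preencoded.

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--pad-and-drop-preencoded",
        action="store_true",
        help="Pad the first chunk and drop the extra pre-encoded steps so downsampling matches offline mode.",
    )

    # Dashes in the option name become underscores on the parsed namespace.
    args = parser.parse_args(["--pad-and-drop-preencoded"])
    print(args.pad_and_drop_preencoded)  # True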
2 changes: 1 addition & 1 deletion nemo/collections/asr/parts/mixins/streaming.py
@@ -29,7 +29,7 @@ def setup_streaming_params(
         pass
 
     @abstractmethod
-    def get_initial_cache_state(self, batch_size, dtype, device):
+    def get_initial_cache_state(self, batch_size, dtype, device, max_dim):
         pass
 
     @staticmethod
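For context on the signature change, a hypothetical sketch of a concrete get_initial_cache_state that accepts the new max_dim argument. The class, tensor shapes, and the cap-by-max_dim behaviour are assumptions for illustration only, not NeMo's actual encoder code.

    import torch

    class DummyStreamingEncoder:
        # Hypothetical encoder: only illustrates pre-allocating caches bounded by max_dim.
        def __init__(self, num_layers=2, hidden=4, max_cache_len=8):
            self.num_layers = num_layers
            self.hidden = hidden
            self.max_cache_len = max_cache_len

        def get_initial_cache_state(self, batch_size, dtype=torch.float32, device="cpu", max_dim=0):
            # Assumption: max_dim (when > 0) caps the time dimension of the caches.
            cache_len = self.max_cache_len if max_dim <= 0 else min(self.max_cache_len, max_dim)
            cache_last_time = torch.zeros(
                (self.num_layers, batch_size, self.hidden, cache_len), dtype=dtype, device=device
            )
            cache_last_channel = torch.zeros(
                (self.num_layers, batch_size, cache_len, self.hidden), dtype=dtype, device=device
            )
            return cache_last_channel, cache_last_time

    enc = DummyStreamingEncoder()
    cache_ch, cache_t = enc.get_initial_cache_state(batch_size=2, max_dim=4)
    print(cache_ch.shape, cache_t.shape)  # torch.Size([2, 2, 4, 4]) torch.Size([2, 2, 4, 4])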
9 changes: 5 additions & 4 deletions nemo/collections/asr/parts/submodules/causal_convs.py
@@ -136,10 +136,11 @@ def update_cache(self, x, cache=None, cache_next=None):
             new_x = F.pad(x, pad=(0, self._right_padding))
             new_x = torch.cat([cache[self._cache_id], new_x], dim=-1)
             # todo: we should know input_x.size(-1) at config time
-            cache_keep_size = x.size(-1) - self.cache_drop_size
-            cache_next[self._cache_id, :, :, :-cache_keep_size] = cache[self._cache_id, :, :, cache_keep_size:]
-            # print("self._max_cache_len:", self._max_cache_len, "cache: size", cache.size(), "x:", x.size(), " new_x:", new_x.size(), ", cache_keep_size:", cache_keep_size)
-            cache_next[self._cache_id, :, :, -cache_keep_size:] = x[:, :, :cache_keep_size]
+            if cache_next is not None:
+                cache_keep_size = torch.tensor(x.size(-1) - self.cache_drop_size, dtype=torch.int64, device=x.device)
+                cache_keep_size = torch.clip(cache_keep_size, min=1, max=cache_next.size(-1))
+                cache_next[self._cache_id, :, :, :-cache_keep_size] = cache[self._cache_id, :, :, cache_keep_size:]
+                cache_next[self._cache_id, :, :, -cache_keep_size:] = x[:, :, :cache_keep_size]
         return new_x
 
     def forward(self, x, cache=None, cache_next=None):
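To see why the added clip matters, a toy re-enactment of the rolling-cache update above, with made-up shapes and values rather than NeMo code: when a chunk is shorter than cache_drop_size, the unclipped cache_keep_size would go negative and the slicing would silently go wrong.

    import torch

    # cache / cache_next: (num_modules, batch, channels, cache_len); x: (batch, channels, time)
    cache = torch.arange(8, dtype=torch.float32).reshape(1, 1, 1, 8)
    cache_next = torch.zeros_like(cache)
    x = torch.full((1, 1, 3), 100.0)  # a short 3-frame chunk
    cache_drop_size = 5
    cache_id = 0

    # Unclipped this would be 3 - 5 = -2, producing nonsensical slices below.
    cache_keep_size = torch.tensor(x.size(-1) - cache_drop_size, dtype=torch.int64)
    cache_keep_size = torch.clip(cache_keep_size, min=1, max=cache_next.size(-1))  # -> 1

    # Shift the old cache left and append the kept frames of the new chunk.
    cache_next[cache_id, :, :, :-cache_keep_size] = cache[cache_id, :, :, cache_keep_size:]
    cache_next[cache_id, :, :, -cache_keep_size:] = x[:, :, :cache_keep_size]
    print(cache_next)  # [1., 2., 3., 4., 5., 6., 7., 100.] in the last dimension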
6 changes: 3 additions & 3 deletions nemo/collections/asr/parts/submodules/multi_head_attention.py
@@ -148,10 +148,10 @@ def forward(self, query, key, value, mask, pos_emb=None, cache=None, cache_next=
     def update_cache(self, key, value, query, cache, cache_next):
         if cache is not None:
             key = value = torch.cat([cache[self._cache_id], key], dim=1)
-            # query.shape[1] is constant, should save it at init()
             q_keep_size = query.shape[1] - self.cache_drop_size
-            cache_next[self._cache_id, :, :-q_keep_size, :] = cache[self._cache_id, :, q_keep_size:, :]
-            cache_next[self._cache_id, :, -q_keep_size:, :] = query[:, :q_keep_size, :]
+            if cache_next is not None:
+                cache_next[self._cache_id, :, :-q_keep_size, :] = cache[self._cache_id, :, q_keep_size:, :]
+                cache_next[self._cache_id, :, -q_keep_size:, :] = query[:, :q_keep_size, :]
         return key, value, query
 
 
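Similarly, a toy illustration of the attention-side update_cache, again with made-up shapes rather than NeMo code: cached frames are prepended to the current keys/values along the time axis, and the cache is now only rolled forward when the caller actually passes a cache_next tensor.

    import torch

    batch, time, d_model, cache_len = 1, 4, 2, 3
    cache_drop_size = 1
    cache_id = 0

    cache = torch.zeros(1, batch, cache_len, d_model)        # (num_modules, B, T_cache, D)
    cache_next = torch.zeros_like(cache)
    query = key = value = torch.randn(batch, time, d_model)  # (B, T, D)

    key = value = torch.cat([cache[cache_id], key], dim=1)   # attend over cached + new frames
    q_keep_size = query.shape[1] - cache_drop_size           # = 3
    if cache_next is not None:
        cache_next[cache_id, :, :-q_keep_size, :] = cache[cache_id, :, q_keep_size:, :]
        cache_next[cache_id, :, -q_keep_size:, :] = query[:, :q_keep_size, :]

    print(key.shape)  # torch.Size([1, 7, 2]): cache_len + time frames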
21 changes: 8 additions & 13 deletions nemo/collections/asr/parts/utils/streaming_utils.py
@@ -1351,10 +1351,9 @@ def __iter__(self):
             )
 
             if self.buffer_idx == 0 and isinstance(self.streaming_cfg.shift_size, list):
+                shift_size = self.streaming_cfg.shift_size[0]
                 if self.pad_and_drop_preencoded:
                     shift_size = self.streaming_cfg.shift_size[1]
-                else:
-                    shift_size = self.streaming_cfg.shift_size[0]
             else:
                 shift_size = (
                     self.streaming_cfg.shift_size[1]
@@ -1379,18 +1378,14 @@ def __iter__(self):
             # if there is not enough frames to be used as the pre-encoding cache, zeros would be added
             zeros_pads = None
             if self.buffer_idx == 0 and isinstance(self.streaming_cfg.pre_encode_cache_size, list):
+                cache_pre_encode_num_frames = self.streaming_cfg.pre_encode_cache_size[0]
                 if self.pad_and_drop_preencoded:
-                    cache_pre_encode = torch.zeros(
-                        (audio_chunk.size(0), self.input_features, self.streaming_cfg.pre_encode_cache_size[1]),
-                        device=audio_chunk.device,
-                        dtype=audio_chunk.dtype,
-                    )
-                else:
-                    cache_pre_encode = torch.zeros(
-                        (audio_chunk.size(0), self.input_features, self.streaming_cfg.pre_encode_cache_size[0]),
-                        device=audio_chunk.device,
-                        dtype=audio_chunk.dtype,
-                    )
+                    cache_pre_encode_num_frames = self.streaming_cfg.pre_encode_cache_size[1]
+                cache_pre_encode = torch.zeros(
+                    (audio_chunk.size(0), self.input_features, cache_pre_encode_num_frames),
+                    device=audio_chunk.device,
+                    dtype=audio_chunk.dtype,
+                )
             else:
                 if isinstance(self.streaming_cfg.pre_encode_cache_size, list):
                     pre_encode_cache_size = self.streaming_cfg.pre_encode_cache_size[1]
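Finally, a standalone sketch of the simplified pre-encode-cache selection above, with toy sizes rather than NeMo's config values: choose the number of cache frames once, then build the zero-filled cache a single time instead of duplicating the torch.zeros call in both branches.

    import torch

    pre_encode_cache_size = [0, 7]   # [first-chunk frames, steady-state frames], toy values
    pad_and_drop_preencoded = True
    batch, input_features = 2, 80
    audio_chunk = torch.randn(batch, input_features, 16)

    cache_pre_encode_num_frames = pre_encode_cache_size[0]
    if pad_and_drop_preencoded:
        cache_pre_encode_num_frames = pre_encode_cache_size[1]

    cache_pre_encode = torch.zeros(
        (audio_chunk.size(0), input_features, cache_pre_encode_num_frames),
        device=audio_chunk.device,
        dtype=audio_chunk.dtype,
    )
    print(cache_pre_encode.shape)  # torch.Size([2, 80, 7])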
