
Commit e6eb6ec

borisfomtitu1994 authored and committed
Cache handling without input tensors mutation (NVIDIA#6980)
* Cache handling without input tensors mutation
  Signed-off-by: Boris Fomitchev <[email protected]>
* Cleanup
  Signed-off-by: Boris Fomitchev <[email protected]>
* Cleanup#2
  Signed-off-by: Boris Fomitchev <[email protected]>
* Cleanup#3
  Signed-off-by: Boris Fomitchev <[email protected]>

Signed-off-by: Boris Fomitchev <[email protected]>
Co-authored-by: Somshubra Majumdar <[email protected]>
Signed-off-by: zhehuaichen <[email protected]>
1 parent 367db81 commit e6eb6ec

File tree: 6 files changed, +118 −161 lines


nemo/collections/asr/models/asr_model.py

+20 −44
@@ -161,7 +161,7 @@ def output_module(self):
     @property
     def output_names(self):
         otypes = self.output_module.output_types
-        if hasattr(self.input_module, 'export_cache_support') and self.input_module.export_cache_support:
+        if getattr(self.input_module, 'export_cache_support', False):
             in_types = self.input_module.output_types
             otypes = {n: t for (n, t) in list(otypes.items())[:1]}
             for (n, t) in list(in_types.items())[1:]:
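
For reference, the getattr form collapses the hasattr check and the attribute read into one call with a default. A minimal standalone illustration of the pattern (toy classes, not the NeMo modules):

class WithCacheSupport:
    export_cache_support = True

class WithoutCacheSupport:
    pass

# Old style: existence check followed by a truthiness test.
def old_check(m):
    return hasattr(m, 'export_cache_support') and m.export_cache_support

# New style: a single lookup, defaulting to False when the attribute is absent.
def new_check(m):
    return getattr(m, 'export_cache_support', False)

assert old_check(WithCacheSupport()) == new_check(WithCacheSupport()) == True
assert old_check(WithoutCacheSupport()) == new_check(WithoutCacheSupport()) == False
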
@@ -174,7 +174,6 @@ def forward_for_export(
         """
         This forward is used when we need to export the model to ONNX format.
         Inputs cache_last_channel and cache_last_time are needed to be passed for exporting streaming models.
-        When they are passed, it just passes the inputs through the encoder part and currently the ONNX conversion does not fully work for this case.
         Args:
             input: Tensor that represents a batch of raw audio signals,
                 of shape [B, T]. T here represents timesteps.
@@ -187,49 +186,26 @@ def forward_for_export(
         Returns:
             the output of the model
         """
-        if hasattr(self.input_module, 'forward_for_export'):
-            if cache_last_channel is None and cache_last_time is None:
-                encoder_output = self.input_module.forward_for_export(audio_signal=input, length=length)
-            else:
-                encoder_output = self.input_module.forward_for_export(
-                    audio_signal=input,
-                    length=length,
-                    cache_last_channel=cache_last_channel,
-                    cache_last_time=cache_last_time,
-                    cache_last_channel_len=cache_last_channel_len,
-                )
+        enc_fun = getattr(self.input_module, 'forward_for_export', self.input_module.forward)
+        if cache_last_channel is None:
+            encoder_output = enc_fun(audio_signal=input, length=length)
+            if isinstance(encoder_output, tuple):
+                encoder_output = encoder_output[0]
         else:
-            if cache_last_channel is None and cache_last_time is None:
-                encoder_output = self.input_module(audio_signal=input, length=length)
-            else:
-                encoder_output = self.input_module(
-                    audio_signal=input,
-                    length=length,
-                    cache_last_channel=cache_last_channel,
-                    cache_last_time=cache_last_time,
-                    cache_last_channel_len=cache_last_channel_len,
-                )
-        if isinstance(encoder_output, tuple):
-            decoder_input = encoder_output[0]
-        else:
-            decoder_input = encoder_output
-        if hasattr(self.output_module, 'forward_for_export'):
-            if cache_last_channel is None and cache_last_time is None:
-                ret = self.output_module.forward_for_export(encoder_output=decoder_input)
-            else:
-                ret = self.output_module.forward_for_export(encoder_output=decoder_input)
-        else:
-            if cache_last_channel is None and cache_last_time is None:
-                ret = self.output_module(encoder_output=decoder_input)
-            else:
-                ret = self.output_module(encoder_output=decoder_input)
-        if cache_last_channel is None and cache_last_time is None:
-            pass
-        else:
-            if isinstance(ret, tuple):
-                ret = (ret[0], encoder_output[1], encoder_output[2], encoder_output[3], encoder_output[4])
-            else:
-                ret = (ret, encoder_output[1], encoder_output[2], encoder_output[3], encoder_output[4])
+            encoder_output, length, cache_last_channel, cache_last_time, cache_last_channel_len = enc_fun(
+                audio_signal=input,
+                length=length,
+                cache_last_channel=cache_last_channel,
+                cache_last_time=cache_last_time,
+                cache_last_channel_len=cache_last_channel_len,
+            )
+
+        dec_fun = getattr(self.output_module, 'forward_for_export', self.output_module.forward)
+        ret = dec_fun(encoder_output=encoder_output)
+        if isinstance(ret, tuple):
+            ret = ret[0]
+        if cache_last_channel is not None:
+            ret = (ret, length, cache_last_channel, cache_last_time, cache_last_channel_len)
         return cast_all(ret, from_dtype=torch.float16, to_dtype=torch.float32)

     @property
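
For readers following the export path, here is a rough standalone sketch of the control flow the rewritten forward_for_export follows, with toy stand-ins for the encoder and decoder (ToyEncoder and ToyDecoder are illustrative, not NeMo classes):

import torch

class ToyEncoder(torch.nn.Module):
    # Stand-in for the ASR encoder: returns extra cache tensors only in streaming mode.
    def forward(self, audio_signal, length, cache_last_channel=None, cache_last_time=None, cache_last_channel_len=None):
        enc = audio_signal.mean(dim=-1, keepdim=True)
        if cache_last_channel is None:
            return enc, length
        return enc, length, cache_last_channel + 1, cache_last_time + 1, cache_last_channel_len + 1

class ToyDecoder(torch.nn.Module):
    def forward(self, encoder_output):
        return encoder_output * 2

encoder, decoder = ToyEncoder(), ToyDecoder()
audio = torch.randn(2, 16)
length = torch.tensor([16, 16])

# Offline path: caches absent, only the decoder output is returned.
enc_fun = getattr(encoder, 'forward_for_export', encoder.forward)
out = enc_fun(audio_signal=audio, length=length)
if isinstance(out, tuple):
    out = out[0]
dec_fun = getattr(decoder, 'forward_for_export', decoder.forward)
ret = dec_fun(encoder_output=out)

# Streaming path: caches go in as unmodified inputs and come back as new tensors.
cache_ch = torch.zeros(1, 2, 4, 8)
cache_t = torch.zeros(1, 2, 8, 3)
cache_len = torch.zeros(2, dtype=torch.int64)
out, length, cache_ch, cache_t, cache_len = enc_fun(
    audio_signal=audio, length=length,
    cache_last_channel=cache_ch, cache_last_time=cache_t, cache_last_channel_len=cache_len,
)
ret = (dec_fun(encoder_output=out), length, cache_ch, cache_t, cache_len)
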

nemo/collections/asr/modules/conformer_encoder.py

+21 −27
@@ -506,11 +506,6 @@ def forward_internal(
                 (audio_signal.size(0),), audio_signal.size(-1), dtype=torch.int64, device=audio_signal.device
             )

-        if cache_last_time is not None:
-            cache_last_time_next = torch.zeros_like(cache_last_time)
-        else:
-            cache_last_time_next = None
-
         # select a random att_context_size with the distribution specified by att_context_probs during training
         # for non-validation cases like test, validation or inference, it uses the first mode in self.att_context_size
         if self.training and len(self.att_context_size_all) > 1:
@@ -537,7 +532,6 @@ def forward_internal(
         if cache_last_channel is not None:
             cache_len = self.streaming_cfg.last_channel_cache_size
             cache_keep_size = max_audio_length - self.streaming_cfg.cache_drop_size
-            cache_last_channel_next = torch.zeros_like(cache_last_channel)
             max_audio_length = max_audio_length + cache_len
             padding_length = length + cache_len
             offset = torch.neg(cache_last_channel_len) + cache_len
@@ -562,19 +556,32 @@ def forward_internal(
             pad_mask = pad_mask[:, cache_len:]
             if att_mask is not None:
                 att_mask = att_mask[:, cache_len:]
+            # Convert caches from the tensor to list
+            cache_last_time_next = []
+            cache_last_channel_next = []

         for lth, (drop_prob, layer) in enumerate(zip(self.layer_drop_probs, self.layers)):
             original_signal = audio_signal
+            if cache_last_channel is not None:
+                cache_last_channel_cur = cache_last_channel[lth]
+                cache_last_time_cur = cache_last_time[lth]
+            else:
+                cache_last_channel_cur = None
+                cache_last_time_cur = None
             audio_signal = layer(
                 x=audio_signal,
                 att_mask=att_mask,
                 pos_emb=pos_emb,
                 pad_mask=pad_mask,
-                cache_last_channel=cache_last_channel,
-                cache_last_time=cache_last_time,
-                cache_last_channel_next=cache_last_channel_next,
-                cache_last_time_next=cache_last_time_next,
+                cache_last_channel=cache_last_channel_cur,
+                cache_last_time=cache_last_time_cur,
             )
+
+            if cache_last_channel_cur is not None:
+                (audio_signal, cache_last_channel_cur, cache_last_time_cur) = audio_signal
+                cache_last_channel_next.append(cache_last_channel_cur)
+                cache_last_time_next.append(cache_last_time_cur)
+
             # applying stochastic depth logic from https://arxiv.org/abs/2102.03216
             if self.training and drop_prob > 0.0:
                 should_drop = torch.rand(1) < drop_prob
@@ -627,6 +634,8 @@ def forward_internal(
         length = length.to(dtype=torch.int64)

         if cache_last_channel is not None:
+            cache_last_channel_next = torch.stack(cache_last_channel_next, dim=0)
+            cache_last_time_next = torch.stack(cache_last_time_next, dim=0)
             return (
                 audio_signal,
                 length,
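
The loop above replaces in-place writes into preallocated cache_next tensors with a functional pattern: slice the stacked cache per layer, let each layer return its updated slice, collect the slices in a list, and stack them at the end. A minimal sketch of that pattern (toy_layer is a stand-in, not the Conformer layer API):

import torch

def toy_layer(x, cache_cur):
    # Stand-in for a Conformer layer: consumes its cache slice and returns a new one.
    new_cache = torch.roll(cache_cur, shifts=-1, dims=-1)
    return x + 1, new_cache

num_layers, batch, cache_len, d_model = 4, 2, 8, 16
x = torch.randn(batch, 10, d_model)
cache = torch.zeros(num_layers, batch, cache_len, d_model)  # one slice per layer

cache_next = []
for lth in range(num_layers):
    cache_cur = cache[lth]                # read-only slice of the input cache
    x, cache_cur = toy_layer(x, cache_cur)
    cache_next.append(cache_cur)          # collect instead of mutating a cache_next tensor

cache_next = torch.stack(cache_next, dim=0)  # same layout as the input cache
assert cache_next.shape == cache.shape
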
@@ -861,20 +870,12 @@ def setup_streaming_params(
         else:
             streaming_cfg.drop_extra_pre_encoded = streaming_cfg.pre_encode_cache_size // self.subsampling_factor

-        # counting the number of the layers need caching
-        streaming_cfg.last_channel_num = 0
-        streaming_cfg.last_time_num = 0
         for m in self.layers.modules():
             if hasattr(m, "_max_cache_len"):
                 if isinstance(m, MultiHeadAttention):
-                    m._cache_id = streaming_cfg.last_channel_num
                     m.cache_drop_size = streaming_cfg.cache_drop_size
-                    streaming_cfg.last_channel_num += 1
-
                 if isinstance(m, CausalConv1D):
-                    m._cache_id = streaming_cfg.last_time_num
                     m.cache_drop_size = streaming_cfg.cache_drop_size
-                    streaming_cfg.last_time_num += 1

         self.streaming_cfg = streaming_cfg

@@ -887,19 +888,12 @@ def get_initial_cache_state(self, batch_size=1, dtype=torch.float32, device=None
         create_tensor = torch.zeros
         last_time_cache_size = self.conv_context_size[0]
         cache_last_channel = create_tensor(
-            (
-                self.streaming_cfg.last_channel_num,
-                batch_size,
-                self.streaming_cfg.last_channel_cache_size,
-                self.d_model,
-            ),
+            (len(self.layers), batch_size, self.streaming_cfg.last_channel_cache_size, self.d_model,),
             device=device,
             dtype=dtype,
         )
         cache_last_time = create_tensor(
-            (self.streaming_cfg.last_time_num, batch_size, self.d_model, last_time_cache_size),
-            device=device,
-            dtype=dtype,
+            (len(self.layers), batch_size, self.d_model, last_time_cache_size), device=device, dtype=dtype,
         )
         if max_dim > 0:
             cache_last_channel_len = torch.randint(
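
With the _cache_id counters gone, the initial caches are sized directly by the number of layers. A hedged sketch of the resulting shapes, using illustrative values rather than the module's real configuration:

import torch

num_layers = 17            # e.g. len(self.layers) for a Conformer stack
batch_size = 1
last_channel_cache_size = 40
d_model = 512
last_time_cache_size = 3   # conv_context_size[0] in the real module

cache_last_channel = torch.zeros(num_layers, batch_size, last_channel_cache_size, d_model)
cache_last_time = torch.zeros(num_layers, batch_size, d_model, last_time_cache_size)

print(cache_last_channel.shape)  # torch.Size([17, 1, 40, 512])
print(cache_last_time.shape)     # torch.Size([17, 1, 512, 3])
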

nemo/collections/asr/parts/submodules/adapters/multi_head_attention_adapter_module.py

+8 −8
@@ -147,26 +147,26 @@ def __init__(
         # reset parameters for Q to be identity operation
         self.reset_parameters()

-    def forward(self, query, key, value, mask, pos_emb=None, cache=None, cache_next=None):
+    def forward(self, query, key, value, mask, pos_emb=None, cache=None):
         """Compute 'Scaled Dot Product Attention'.
         Args:
             query (torch.Tensor): (batch, time1, size)
             key (torch.Tensor): (batch, time2, size)
             value(torch.Tensor): (batch, time2, size)
             mask (torch.Tensor): (batch, time1, time2)
-            cache (torch.Tensor) : (cache_nums, batch, time_cache, size)
-            cache_next (torch.Tensor) : (cache_nums, batch, time_cache_next, size)
+            cache (torch.Tensor) : (batch, time_cache, size)

         returns:
             output (torch.Tensor): transformed `value` (batch, time1, d_model) weighted by the query dot key attention
+            cache (torch.Tensor) : (batch, time_cache_next, size)
         """
         # Need to perform duplicate computations as at this point the tensors have been
         # separated by the adapter forward
         query = self.pre_norm(query)
         key = self.pre_norm(key)
         value = self.pre_norm(value)

-        return super().forward(query, key, value, mask, pos_emb, cache=cache, cache_next=cache_next)
+        return super().forward(query, key, value, mask, pos_emb, cache=cache)

     def reset_parameters(self):
         with torch.no_grad():
@@ -242,26 +242,26 @@ def __init__(
         # reset parameters for Q to be identity operation
         self.reset_parameters()

-    def forward(self, query, key, value, mask, pos_emb, cache=None, cache_next=None):
+    def forward(self, query, key, value, mask, pos_emb, cache=None):
         """Compute 'Scaled Dot Product Attention' with rel. positional encoding.
         Args:
             query (torch.Tensor): (batch, time1, size)
             key (torch.Tensor): (batch, time2, size)
             value(torch.Tensor): (batch, time2, size)
             mask (torch.Tensor): (batch, time1, time2)
             pos_emb (torch.Tensor) : (batch, time1, size)
-            cache (torch.Tensor) : (cache_nums, batch, time_cache, size)
-            cache_next (torch.Tensor) : (cache_nums, batch, time_cache_next, size)
+            cache (torch.Tensor) : (batch, time_cache, size)
         Returns:
             output (torch.Tensor): transformed `value` (batch, time1, d_model) weighted by the query dot key attention
+            cache_next (torch.Tensor) : (batch, time_cache_next, size)
         """
         # Need to perform duplicate computations as at this point the tensors have been
         # separated by the adapter forward
         query = self.pre_norm(query)
         key = self.pre_norm(key)
         value = self.pre_norm(value)

-        return super().forward(query, key, value, mask, pos_emb, cache=cache, cache_next=cache_next)
+        return super().forward(query, key, value, mask, pos_emb, cache=cache)

     def reset_parameters(self):
         with torch.no_grad():
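
With the shared cache_next tensor removed, an attention module that supports caching now receives only its own (batch, time_cache, size) slice and returns the updated slice together with its output. A toy sketch of that calling convention (not the actual MultiHeadAttention implementation):

import torch

def toy_attention_forward(query, key, value, mask=None, pos_emb=None, cache=None):
    # Stand-in: prepend the cached frames to key/value, then emit a trimmed cache.
    if cache is not None:
        key = torch.cat([cache, key], dim=1)
        value = torch.cat([cache, value], dim=1)
    out = value.mean(dim=1, keepdim=True).expand_as(query)
    if cache is None:
        return out
    new_cache = key[:, -cache.size(1):, :]   # keep the most recent time_cache frames
    return out, new_cache

q = torch.randn(2, 5, 8)
cache = torch.zeros(2, 4, 8)
out, cache = toy_attention_forward(q, q, q, cache=cache)
assert cache.shape == (2, 4, 8)
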

nemo/collections/asr/parts/submodules/causal_convs.py

+13 −15
@@ -45,7 +45,6 @@ def __init__(
             raise ValueError("Argument padding should be set to None for CausalConv2D.")
         self._left_padding = kernel_size - 1
         self._right_padding = stride - 1
-        self._cache_id = None

         padding = 0
         super(CausalConv2D, self).__init__(
@@ -113,7 +112,6 @@ def __init__(
             raise ValueError(f"Invalid padding param: {padding}!")

         self._max_cache_len = self._left_padding
-        self._cache_id = None

         super(CausalConv1D, self).__init__(
             in_channels=in_channels,
@@ -129,21 +127,21 @@ def __init__(
             dtype=dtype,
         )

-    def update_cache(self, x, cache=None, cache_next=None):
+    def update_cache(self, x, cache=None):
         if cache is None:
             new_x = F.pad(x, pad=(self._left_padding, self._right_padding))
         else:
             new_x = F.pad(x, pad=(0, self._right_padding))
-            new_x = torch.cat([cache[self._cache_id], new_x], dim=-1)
-            # todo: we should know input_x.size(-1) at config time
-            if cache_next is not None:
-                cache_keep_size = torch.tensor(x.size(-1) - self.cache_drop_size, dtype=torch.int64, device=x.device)
-                cache_keep_size = torch.clip(cache_keep_size, min=1, max=cache_next.size(-1))
-                cache_next[self._cache_id, :, :, :-cache_keep_size] = cache[self._cache_id, :, :, cache_keep_size:]
-                cache_next[self._cache_id, :, :, -cache_keep_size:] = x[:, :, :cache_keep_size]
-        return new_x
-
-    def forward(self, x, cache=None, cache_next=None):
-        x = self.update_cache(x, cache=cache, cache_next=cache_next)
+            new_x = torch.cat([cache, new_x], dim=-1)
+            if self.cache_drop_size > 0:
+                x = x[:, :, : -self.cache_drop_size]
+            cache = torch.cat([cache[:, :, x.size(-1) :], x], dim=-1)
+        return new_x, cache
+
+    def forward(self, x, cache=None):
+        x, cache = self.update_cache(x, cache=cache)
         x = super().forward(x)
-        return x
+        if cache is None:
+            return x
+        else:
+            return x, cache
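
The new CausalConv1D cache handling is purely functional: pad the current chunk on the right, prepend the cache along the time axis for the convolution, then build the next cache from the tail of the old cache plus the (possibly trimmed) chunk. A standalone sketch of that update, assuming a (batch, channels, time) layout and a chunk no longer than the cache after trimming:

import torch
import torch.nn.functional as F

def causal_conv_cache_update(x, cache, left_padding, right_padding, cache_drop_size):
    # x: (B, C, T) current chunk; cache: (B, C, T_cache) frames carried over from the last chunk.
    if cache is None:
        # Offline path: left padding provides the causal context, no cache is produced.
        return F.pad(x, pad=(left_padding, right_padding)), None
    new_x = F.pad(x, pad=(0, right_padding))
    new_x = torch.cat([cache, new_x], dim=-1)                   # convolve over cache + current chunk
    if cache_drop_size > 0:
        x = x[:, :, :-cache_drop_size]                          # frames the next chunk will recompute
    cache = torch.cat([cache[:, :, x.size(-1):], x], dim=-1)    # roll: drop oldest, append newest
    return new_x, cache

x = torch.randn(1, 4, 6)
cache = torch.zeros(1, 4, 8)
padded, cache = causal_conv_cache_update(x, cache, left_padding=8, right_padding=0, cache_drop_size=2)
print(padded.shape, cache.shape)   # torch.Size([1, 4, 14]) torch.Size([1, 4, 8])
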

0 commit comments