Skip to content

Commit

Permalink
Cache handling without input tensors mutation (NVIDIA#6980)
Browse files Browse the repository at this point in the history
* Cache handling without input tensors mutation

Signed-off-by: Boris Fomitchev <[email protected]>

* Cleanup

Signed-off-by: Boris Fomitchev <[email protected]>

* Cleanup#2

Signed-off-by: Boris Fomitchev <[email protected]>

* Cleanup#3

Signed-off-by: Boris Fomitchev <[email protected]>

---------

Signed-off-by: Boris Fomitchev <[email protected]>
Co-authored-by: Somshubra Majumdar <[email protected]>
Signed-off-by: zhehuaichen <[email protected]>
  • Loading branch information
2 people authored and zhehuaichen committed Oct 4, 2023
1 parent 0632f9f commit 91fb316
Show file tree
Hide file tree
Showing 6 changed files with 118 additions and 161 deletions.
64 changes: 20 additions & 44 deletions nemo/collections/asr/models/asr_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ def output_module(self):
@property
def output_names(self):
otypes = self.output_module.output_types
if hasattr(self.input_module, 'export_cache_support') and self.input_module.export_cache_support:
if getattr(self.input_module, 'export_cache_support', False):
in_types = self.input_module.output_types
otypes = {n: t for (n, t) in list(otypes.items())[:1]}
for (n, t) in list(in_types.items())[1:]:
Expand All @@ -174,7 +174,6 @@ def forward_for_export(
"""
This forward is used when we need to export the model to ONNX format.
Inputs cache_last_channel and cache_last_time are needed to be passed for exporting streaming models.
When they are passed, it just passes the inputs through the encoder part and currently the ONNX conversion does not fully work for this case.
Args:
input: Tensor that represents a batch of raw audio signals,
of shape [B, T]. T here represents timesteps.
Expand All @@ -187,49 +186,26 @@ def forward_for_export(
Returns:
the output of the model
"""
if hasattr(self.input_module, 'forward_for_export'):
if cache_last_channel is None and cache_last_time is None:
encoder_output = self.input_module.forward_for_export(audio_signal=input, length=length)
else:
encoder_output = self.input_module.forward_for_export(
audio_signal=input,
length=length,
cache_last_channel=cache_last_channel,
cache_last_time=cache_last_time,
cache_last_channel_len=cache_last_channel_len,
)
enc_fun = getattr(self.input_module, 'forward_for_export', self.input_module.forward)
if cache_last_channel is None:
encoder_output = enc_fun(audio_signal=input, length=length)
if isinstance(encoder_output, tuple):
encoder_output = encoder_output[0]
else:
if cache_last_channel is None and cache_last_time is None:
encoder_output = self.input_module(audio_signal=input, length=length)
else:
encoder_output = self.input_module(
audio_signal=input,
length=length,
cache_last_channel=cache_last_channel,
cache_last_time=cache_last_time,
cache_last_channel_len=cache_last_channel_len,
)
if isinstance(encoder_output, tuple):
decoder_input = encoder_output[0]
else:
decoder_input = encoder_output
if hasattr(self.output_module, 'forward_for_export'):
if cache_last_channel is None and cache_last_time is None:
ret = self.output_module.forward_for_export(encoder_output=decoder_input)
else:
ret = self.output_module.forward_for_export(encoder_output=decoder_input)
else:
if cache_last_channel is None and cache_last_time is None:
ret = self.output_module(encoder_output=decoder_input)
else:
ret = self.output_module(encoder_output=decoder_input)
if cache_last_channel is None and cache_last_time is None:
pass
else:
if isinstance(ret, tuple):
ret = (ret[0], encoder_output[1], encoder_output[2], encoder_output[3], encoder_output[4])
else:
ret = (ret, encoder_output[1], encoder_output[2], encoder_output[3], encoder_output[4])
encoder_output, length, cache_last_channel, cache_last_time, cache_last_channel_len = enc_fun(
audio_signal=input,
length=length,
cache_last_channel=cache_last_channel,
cache_last_time=cache_last_time,
cache_last_channel_len=cache_last_channel_len,
)

dec_fun = getattr(self.output_module, 'forward_for_export', self.output_module.forward)
ret = dec_fun(encoder_output=encoder_output)
if isinstance(ret, tuple):
ret = ret[0]
if cache_last_channel is not None:
ret = (ret, length, cache_last_channel, cache_last_time, cache_last_channel_len)
return cast_all(ret, from_dtype=torch.float16, to_dtype=torch.float32)

@property
Expand Down
48 changes: 21 additions & 27 deletions nemo/collections/asr/modules/conformer_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,11 +506,6 @@ def forward_internal(
(audio_signal.size(0),), audio_signal.size(-1), dtype=torch.int64, device=audio_signal.device
)

if cache_last_time is not None:
cache_last_time_next = torch.zeros_like(cache_last_time)
else:
cache_last_time_next = None

# select a random att_context_size with the distribution specified by att_context_probs during training
# for non-validation cases like test, validation or inference, it uses the first mode in self.att_context_size
if self.training and len(self.att_context_size_all) > 1:
Expand All @@ -537,7 +532,6 @@ def forward_internal(
if cache_last_channel is not None:
cache_len = self.streaming_cfg.last_channel_cache_size
cache_keep_size = max_audio_length - self.streaming_cfg.cache_drop_size
cache_last_channel_next = torch.zeros_like(cache_last_channel)
max_audio_length = max_audio_length + cache_len
padding_length = length + cache_len
offset = torch.neg(cache_last_channel_len) + cache_len
Expand All @@ -562,19 +556,32 @@ def forward_internal(
pad_mask = pad_mask[:, cache_len:]
if att_mask is not None:
att_mask = att_mask[:, cache_len:]
# Convert caches from the tensor to list
cache_last_time_next = []
cache_last_channel_next = []

for lth, (drop_prob, layer) in enumerate(zip(self.layer_drop_probs, self.layers)):
original_signal = audio_signal
if cache_last_channel is not None:
cache_last_channel_cur = cache_last_channel[lth]
cache_last_time_cur = cache_last_time[lth]
else:
cache_last_channel_cur = None
cache_last_time_cur = None
audio_signal = layer(
x=audio_signal,
att_mask=att_mask,
pos_emb=pos_emb,
pad_mask=pad_mask,
cache_last_channel=cache_last_channel,
cache_last_time=cache_last_time,
cache_last_channel_next=cache_last_channel_next,
cache_last_time_next=cache_last_time_next,
cache_last_channel=cache_last_channel_cur,
cache_last_time=cache_last_time_cur,
)

if cache_last_channel_cur is not None:
(audio_signal, cache_last_channel_cur, cache_last_time_cur) = audio_signal
cache_last_channel_next.append(cache_last_channel_cur)
cache_last_time_next.append(cache_last_time_cur)

# applying stochastic depth logic from https://arxiv.org/abs/2102.03216
if self.training and drop_prob > 0.0:
should_drop = torch.rand(1) < drop_prob
Expand Down Expand Up @@ -627,6 +634,8 @@ def forward_internal(
length = length.to(dtype=torch.int64)

if cache_last_channel is not None:
cache_last_channel_next = torch.stack(cache_last_channel_next, dim=0)
cache_last_time_next = torch.stack(cache_last_time_next, dim=0)
return (
audio_signal,
length,
Expand Down Expand Up @@ -861,20 +870,12 @@ def setup_streaming_params(
else:
streaming_cfg.drop_extra_pre_encoded = streaming_cfg.pre_encode_cache_size // self.subsampling_factor

# counting the number of the layers need caching
streaming_cfg.last_channel_num = 0
streaming_cfg.last_time_num = 0
for m in self.layers.modules():
if hasattr(m, "_max_cache_len"):
if isinstance(m, MultiHeadAttention):
m._cache_id = streaming_cfg.last_channel_num
m.cache_drop_size = streaming_cfg.cache_drop_size
streaming_cfg.last_channel_num += 1

if isinstance(m, CausalConv1D):
m._cache_id = streaming_cfg.last_time_num
m.cache_drop_size = streaming_cfg.cache_drop_size
streaming_cfg.last_time_num += 1

self.streaming_cfg = streaming_cfg

Expand All @@ -887,19 +888,12 @@ def get_initial_cache_state(self, batch_size=1, dtype=torch.float32, device=None
create_tensor = torch.zeros
last_time_cache_size = self.conv_context_size[0]
cache_last_channel = create_tensor(
(
self.streaming_cfg.last_channel_num,
batch_size,
self.streaming_cfg.last_channel_cache_size,
self.d_model,
),
(len(self.layers), batch_size, self.streaming_cfg.last_channel_cache_size, self.d_model,),
device=device,
dtype=dtype,
)
cache_last_time = create_tensor(
(self.streaming_cfg.last_time_num, batch_size, self.d_model, last_time_cache_size),
device=device,
dtype=dtype,
(len(self.layers), batch_size, self.d_model, last_time_cache_size), device=device, dtype=dtype,
)
if max_dim > 0:
cache_last_channel_len = torch.randint(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,26 +147,26 @@ def __init__(
# reset parameters for Q to be identity operation
self.reset_parameters()

def forward(self, query, key, value, mask, pos_emb=None, cache=None, cache_next=None):
def forward(self, query, key, value, mask, pos_emb=None, cache=None):
"""Compute 'Scaled Dot Product Attention'.
Args:
query (torch.Tensor): (batch, time1, size)
key (torch.Tensor): (batch, time2, size)
value(torch.Tensor): (batch, time2, size)
mask (torch.Tensor): (batch, time1, time2)
cache (torch.Tensor) : (cache_nums, batch, time_cache, size)
cache_next (torch.Tensor) : (cache_nums, batch, time_cache_next, size)
cache (torch.Tensor) : (batch, time_cache, size)
returns:
output (torch.Tensor): transformed `value` (batch, time1, d_model) weighted by the query dot key attention
cache (torch.Tensor) : (batch, time_cache_next, size)
"""
# Need to perform duplicate computations as at this point the tensors have been
# separated by the adapter forward
query = self.pre_norm(query)
key = self.pre_norm(key)
value = self.pre_norm(value)

return super().forward(query, key, value, mask, pos_emb, cache=cache, cache_next=cache_next)
return super().forward(query, key, value, mask, pos_emb, cache=cache)

def reset_parameters(self):
with torch.no_grad():
Expand Down Expand Up @@ -242,26 +242,26 @@ def __init__(
# reset parameters for Q to be identity operation
self.reset_parameters()

def forward(self, query, key, value, mask, pos_emb, cache=None, cache_next=None):
def forward(self, query, key, value, mask, pos_emb, cache=None):
"""Compute 'Scaled Dot Product Attention' with rel. positional encoding.
Args:
query (torch.Tensor): (batch, time1, size)
key (torch.Tensor): (batch, time2, size)
value(torch.Tensor): (batch, time2, size)
mask (torch.Tensor): (batch, time1, time2)
pos_emb (torch.Tensor) : (batch, time1, size)
cache (torch.Tensor) : (cache_nums, batch, time_cache, size)
cache_next (torch.Tensor) : (cache_nums, batch, time_cache_next, size)
cache (torch.Tensor) : (batch, time_cache, size)
Returns:
output (torch.Tensor): transformed `value` (batch, time1, d_model) weighted by the query dot key attention
cache_next (torch.Tensor) : (batch, time_cache_next, size)
"""
# Need to perform duplicate computations as at this point the tensors have been
# separated by the adapter forward
query = self.pre_norm(query)
key = self.pre_norm(key)
value = self.pre_norm(value)

return super().forward(query, key, value, mask, pos_emb, cache=cache, cache_next=cache_next)
return super().forward(query, key, value, mask, pos_emb, cache=cache)

def reset_parameters(self):
with torch.no_grad():
Expand Down
28 changes: 13 additions & 15 deletions nemo/collections/asr/parts/submodules/causal_convs.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ def __init__(
raise ValueError("Argument padding should be set to None for CausalConv2D.")
self._left_padding = kernel_size - 1
self._right_padding = stride - 1
self._cache_id = None

padding = 0
super(CausalConv2D, self).__init__(
Expand Down Expand Up @@ -113,7 +112,6 @@ def __init__(
raise ValueError(f"Invalid padding param: {padding}!")

self._max_cache_len = self._left_padding
self._cache_id = None

super(CausalConv1D, self).__init__(
in_channels=in_channels,
Expand All @@ -129,21 +127,21 @@ def __init__(
dtype=dtype,
)

def update_cache(self, x, cache=None, cache_next=None):
def update_cache(self, x, cache=None):
if cache is None:
new_x = F.pad(x, pad=(self._left_padding, self._right_padding))
else:
new_x = F.pad(x, pad=(0, self._right_padding))
new_x = torch.cat([cache[self._cache_id], new_x], dim=-1)
# todo: we should know input_x.size(-1) at config time
if cache_next is not None:
cache_keep_size = torch.tensor(x.size(-1) - self.cache_drop_size, dtype=torch.int64, device=x.device)
cache_keep_size = torch.clip(cache_keep_size, min=1, max=cache_next.size(-1))
cache_next[self._cache_id, :, :, :-cache_keep_size] = cache[self._cache_id, :, :, cache_keep_size:]
cache_next[self._cache_id, :, :, -cache_keep_size:] = x[:, :, :cache_keep_size]
return new_x

def forward(self, x, cache=None, cache_next=None):
x = self.update_cache(x, cache=cache, cache_next=cache_next)
new_x = torch.cat([cache, new_x], dim=-1)
if self.cache_drop_size > 0:
x = x[:, :, : -self.cache_drop_size]
cache = torch.cat([cache[:, :, x.size(-1) :], x], dim=-1)
return new_x, cache

def forward(self, x, cache=None):
x, cache = self.update_cache(x, cache=cache)
x = super().forward(x)
return x
if cache is None:
return x
else:
return x, cache
Loading

0 comments on commit 91fb316

Please sign in to comment.