Cache handling without input tensors mutation #6980

Merged · 7 commits · Jul 7, 2023
Changes from all commits
64 changes: 20 additions & 44 deletions nemo/collections/asr/models/asr_model.py
@@ -161,7 +161,7 @@ def output_module(self):
@property
def output_names(self):
otypes = self.output_module.output_types
if hasattr(self.input_module, 'export_cache_support') and self.input_module.export_cache_support:
if getattr(self.input_module, 'export_cache_support', False):
in_types = self.input_module.output_types
otypes = {n: t for (n, t) in list(otypes.items())[:1]}
for (n, t) in list(in_types.items())[1:]:
@@ -174,7 +174,6 @@ def forward_for_export(
"""
This forward is used when we need to export the model to ONNX format.
Inputs cache_last_channel and cache_last_time need to be passed when exporting streaming models.
When they are passed, it just passes the inputs through the encoder part and currently the ONNX conversion does not fully work for this case.
Args:
input: Tensor that represents a batch of raw audio signals,
of shape [B, T]. T here represents timesteps.
@@ -187,49 +186,26 @@
Returns:
the output of the model
"""
if hasattr(self.input_module, 'forward_for_export'):
if cache_last_channel is None and cache_last_time is None:
encoder_output = self.input_module.forward_for_export(audio_signal=input, length=length)
else:
encoder_output = self.input_module.forward_for_export(
audio_signal=input,
length=length,
cache_last_channel=cache_last_channel,
cache_last_time=cache_last_time,
cache_last_channel_len=cache_last_channel_len,
)
enc_fun = getattr(self.input_module, 'forward_for_export', self.input_module.forward)
if cache_last_channel is None:
encoder_output = enc_fun(audio_signal=input, length=length)
if isinstance(encoder_output, tuple):
encoder_output = encoder_output[0]
else:
if cache_last_channel is None and cache_last_time is None:
encoder_output = self.input_module(audio_signal=input, length=length)
else:
encoder_output = self.input_module(
audio_signal=input,
length=length,
cache_last_channel=cache_last_channel,
cache_last_time=cache_last_time,
cache_last_channel_len=cache_last_channel_len,
)
if isinstance(encoder_output, tuple):
decoder_input = encoder_output[0]
else:
decoder_input = encoder_output
if hasattr(self.output_module, 'forward_for_export'):
if cache_last_channel is None and cache_last_time is None:
ret = self.output_module.forward_for_export(encoder_output=decoder_input)
else:
ret = self.output_module.forward_for_export(encoder_output=decoder_input)
else:
if cache_last_channel is None and cache_last_time is None:
ret = self.output_module(encoder_output=decoder_input)
else:
ret = self.output_module(encoder_output=decoder_input)
if cache_last_channel is None and cache_last_time is None:
pass
else:
if isinstance(ret, tuple):
ret = (ret[0], encoder_output[1], encoder_output[2], encoder_output[3], encoder_output[4])
else:
ret = (ret, encoder_output[1], encoder_output[2], encoder_output[3], encoder_output[4])
encoder_output, length, cache_last_channel, cache_last_time, cache_last_channel_len = enc_fun(
audio_signal=input,
length=length,
cache_last_channel=cache_last_channel,
cache_last_time=cache_last_time,
cache_last_channel_len=cache_last_channel_len,
)

dec_fun = getattr(self.output_module, 'forward_for_export', self.output_module.forward)
ret = dec_fun(encoder_output=encoder_output)
if isinstance(ret, tuple):
ret = ret[0]
if cache_last_channel is not None:
ret = (ret, length, cache_last_channel, cache_last_time, cache_last_channel_len)
return cast_all(ret, from_dtype=torch.float16, to_dtype=torch.float32)

@property
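For orientation, here is a minimal sketch of how the reshaped forward_for_export can be driven for cache-aware streaming export. The encoder attribute, helper name, and tensor shapes are illustrative assumptions rather than part of this diff:

import torch

def streaming_export_step(asr_model, audio_signal, length):
    # Zero-initialized caches sized for every encoder layer, as produced by
    # get_initial_cache_state() after this change.
    cache_last_channel, cache_last_time, cache_last_channel_len = asr_model.encoder.get_initial_cache_state(
        batch_size=audio_signal.size(0), device=audio_signal.device
    )
    # With caches supplied, forward_for_export returns a flat tuple:
    # (decoder_output, length, cache_last_channel, cache_last_time, cache_last_channel_len);
    # the updated caches come back as outputs instead of being written into the inputs.
    return asr_model.forward_for_export(
        input=audio_signal,
        length=length,
        cache_last_channel=cache_last_channel,
        cache_last_time=cache_last_time,
        cache_last_channel_len=cache_last_channel_len,
    )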
48 changes: 21 additions & 27 deletions nemo/collections/asr/modules/conformer_encoder.py
@@ -505,11 +505,6 @@ def forward_internal(
(audio_signal.size(0),), audio_signal.size(-1), dtype=torch.int64, device=audio_signal.device
)

if cache_last_time is not None:
cache_last_time_next = torch.zeros_like(cache_last_time)
else:
cache_last_time_next = None

# select a random att_context_size with the distribution specified by att_context_probs during training
# for non-validation cases like test, validation or inference, it uses the first mode in self.att_context_size
if self.training and len(self.att_context_size_all) > 1:
@@ -536,7 +531,6 @@
if cache_last_channel is not None:
cache_len = self.streaming_cfg.last_channel_cache_size
cache_keep_size = max_audio_length - self.streaming_cfg.cache_drop_size
cache_last_channel_next = torch.zeros_like(cache_last_channel)
max_audio_length = max_audio_length + cache_len
padding_length = length + cache_len
offset = torch.neg(cache_last_channel_len) + cache_len
@@ -561,19 +555,32 @@
pad_mask = pad_mask[:, cache_len:]
if att_mask is not None:
att_mask = att_mask[:, cache_len:]
# Convert caches from tensors to lists
cache_last_time_next = []
cache_last_channel_next = []

for lth, (drop_prob, layer) in enumerate(zip(self.layer_drop_probs, self.layers)):
original_signal = audio_signal
if cache_last_channel is not None:
cache_last_channel_cur = cache_last_channel[lth]
cache_last_time_cur = cache_last_time[lth]
else:
cache_last_channel_cur = None
cache_last_time_cur = None
audio_signal = layer(
x=audio_signal,
att_mask=att_mask,
pos_emb=pos_emb,
pad_mask=pad_mask,
cache_last_channel=cache_last_channel,
cache_last_time=cache_last_time,
cache_last_channel_next=cache_last_channel_next,
cache_last_time_next=cache_last_time_next,
cache_last_channel=cache_last_channel_cur,
cache_last_time=cache_last_time_cur,
)

if cache_last_channel_cur is not None:
(audio_signal, cache_last_channel_cur, cache_last_time_cur) = audio_signal
cache_last_channel_next.append(cache_last_channel_cur)
cache_last_time_next.append(cache_last_time_cur)

# applying stochastic depth logic from https://arxiv.org/abs/2102.03216
if self.training and drop_prob > 0.0:
should_drop = torch.rand(1) < drop_prob
@@ -626,6 +633,8 @@ def forward_internal(
length = length.to(dtype=torch.int64)

if cache_last_channel is not None:
cache_last_channel_next = torch.stack(cache_last_channel_next, dim=0)
cache_last_time_next = torch.stack(cache_last_time_next, dim=0)
return (
audio_signal,
length,
@@ -860,20 +869,12 @@ def setup_streaming_params(
else:
streaming_cfg.drop_extra_pre_encoded = streaming_cfg.pre_encode_cache_size // self.subsampling_factor

# count the number of layers that need caching
streaming_cfg.last_channel_num = 0
streaming_cfg.last_time_num = 0
for m in self.layers.modules():
if hasattr(m, "_max_cache_len"):
if isinstance(m, MultiHeadAttention):
m._cache_id = streaming_cfg.last_channel_num
m.cache_drop_size = streaming_cfg.cache_drop_size
streaming_cfg.last_channel_num += 1

if isinstance(m, CausalConv1D):
m._cache_id = streaming_cfg.last_time_num
m.cache_drop_size = streaming_cfg.cache_drop_size
streaming_cfg.last_time_num += 1

self.streaming_cfg = streaming_cfg

@@ -886,19 +887,12 @@ def get_initial_cache_state(self, batch_size=1, dtype=torch.float32, device=None
create_tensor = torch.zeros
last_time_cache_size = self.conv_context_size[0]
cache_last_channel = create_tensor(
(
self.streaming_cfg.last_channel_num,
batch_size,
self.streaming_cfg.last_channel_cache_size,
self.d_model,
),
(len(self.layers), batch_size, self.streaming_cfg.last_channel_cache_size, self.d_model,),
device=device,
dtype=dtype,
)
cache_last_time = create_tensor(
(self.streaming_cfg.last_time_num, batch_size, self.d_model, last_time_cache_size),
device=device,
dtype=dtype,
(len(self.layers), batch_size, self.d_model, last_time_cache_size), device=device, dtype=dtype,
)
if max_dim > 0:
cache_last_channel_len = torch.randint(
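The central change in conformer_encoder.py is that per-layer caches are now collected into Python lists inside the layer loop and stacked once at the end, instead of being written into preallocated cache_*_next tensors. Below is a condensed sketch of that pattern, with the layer call signature simplified for illustration (the real call also passes att_mask, pos_emb, and pad_mask):

import torch

def run_layers_with_functional_cache(layers, audio_signal, cache_last_channel, cache_last_time):
    # Caches arrive stacked as (num_layers, batch, ...); collect per-layer updates
    # in lists rather than mutating a shared "next" tensor in place.
    cache_last_channel_next, cache_last_time_next = [], []
    for lth, layer in enumerate(layers):
        out = layer(
            x=audio_signal,
            cache_last_channel=cache_last_channel[lth],
            cache_last_time=cache_last_time[lth],
        )
        # Each cache-aware layer now returns its updated caches along with its output.
        audio_signal, channel_cache, time_cache = out
        cache_last_channel_next.append(channel_cache)
        cache_last_time_next.append(time_cache)
    # Restack to the (num_layers, batch, ...) layout expected by the caller.
    return (
        audio_signal,
        torch.stack(cache_last_channel_next, dim=0),
        torch.stack(cache_last_time_next, dim=0),
    )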
@@ -147,26 +147,26 @@ def __init__(
# reset parameters for Q to be identity operation
self.reset_parameters()

def forward(self, query, key, value, mask, pos_emb=None, cache=None, cache_next=None):
def forward(self, query, key, value, mask, pos_emb=None, cache=None):
"""Compute 'Scaled Dot Product Attention'.
Args:
query (torch.Tensor): (batch, time1, size)
key (torch.Tensor): (batch, time2, size)
value(torch.Tensor): (batch, time2, size)
mask (torch.Tensor): (batch, time1, time2)
cache (torch.Tensor) : (cache_nums, batch, time_cache, size)
cache_next (torch.Tensor) : (cache_nums, batch, time_cache_next, size)
cache (torch.Tensor) : (batch, time_cache, size)

returns:
output (torch.Tensor): transformed `value` (batch, time1, d_model) weighted by the query dot key attention
cache (torch.Tensor) : (batch, time_cache_next, size)
"""
# Need to perform duplicate computations as at this point the tensors have been
# separated by the adapter forward
query = self.pre_norm(query)
key = self.pre_norm(key)
value = self.pre_norm(value)

return super().forward(query, key, value, mask, pos_emb, cache=cache, cache_next=cache_next)
return super().forward(query, key, value, mask, pos_emb, cache=cache)

def reset_parameters(self):
with torch.no_grad():
@@ -242,26 +242,26 @@ def __init__(
# reset parameters for Q to be identity operation
self.reset_parameters()

def forward(self, query, key, value, mask, pos_emb, cache=None, cache_next=None):
def forward(self, query, key, value, mask, pos_emb, cache=None):
"""Compute 'Scaled Dot Product Attention' with rel. positional encoding.
Args:
query (torch.Tensor): (batch, time1, size)
key (torch.Tensor): (batch, time2, size)
value(torch.Tensor): (batch, time2, size)
mask (torch.Tensor): (batch, time1, time2)
pos_emb (torch.Tensor) : (batch, time1, size)
cache (torch.Tensor) : (cache_nums, batch, time_cache, size)
cache_next (torch.Tensor) : (cache_nums, batch, time_cache_next, size)
cache (torch.Tensor) : (batch, time_cache, size)
Returns:
output (torch.Tensor): transformed `value` (batch, time1, d_model) weighted by the query dot key attention
cache_next (torch.Tensor) : (batch, time_cache_next, size)
"""
# Need to perform duplicate computations as at this point the tensors have been
# separated by the adapter forward
query = self.pre_norm(query)
key = self.pre_norm(key)
value = self.pre_norm(value)

return super().forward(query, key, value, mask, pos_emb, cache=cache, cache_next=cache_next)
return super().forward(query, key, value, mask, pos_emb, cache=cache)

def reset_parameters(self):
with torch.no_grad():
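The adapter forwards above simply drop cache_next and pass cache through to the parent attention which, per the updated docstrings, hands the refreshed cache back alongside the attention output when a cache is provided. A minimal illustration of that calling convention, where attn stands in for any of the cache-aware attention modules and the shapes follow the docstrings:

def attention_streaming_step(attn, x, pos_emb, mask, cache):
    # cache: (batch, time_cache, size); no cache_next tensor is mutated anymore.
    # The module is assumed to return (output, updated_cache) when a cache is passed.
    out, cache = attn(query=x, key=x, value=x, mask=mask, pos_emb=pos_emb, cache=cache)
    return out, cache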
28 changes: 13 additions & 15 deletions nemo/collections/asr/parts/submodules/causal_convs.py
@@ -45,7 +45,6 @@ def __init__(
raise ValueError("Argument padding should be set to None for CausalConv2D.")
self._left_padding = kernel_size - 1
self._right_padding = stride - 1
self._cache_id = None

padding = 0
super(CausalConv2D, self).__init__(
@@ -113,7 +112,6 @@ def __init__(
raise ValueError(f"Invalid padding param: {padding}!")

self._max_cache_len = self._left_padding
self._cache_id = None

super(CausalConv1D, self).__init__(
in_channels=in_channels,
@@ -129,21 +127,21 @@
dtype=dtype,
)

def update_cache(self, x, cache=None, cache_next=None):
def update_cache(self, x, cache=None):
if cache is None:
new_x = F.pad(x, pad=(self._left_padding, self._right_padding))
else:
new_x = F.pad(x, pad=(0, self._right_padding))
new_x = torch.cat([cache[self._cache_id], new_x], dim=-1)
# todo: we should know input_x.size(-1) at config time
if cache_next is not None:
cache_keep_size = torch.tensor(x.size(-1) - self.cache_drop_size, dtype=torch.int64, device=x.device)
cache_keep_size = torch.clip(cache_keep_size, min=1, max=cache_next.size(-1))
cache_next[self._cache_id, :, :, :-cache_keep_size] = cache[self._cache_id, :, :, cache_keep_size:]
cache_next[self._cache_id, :, :, -cache_keep_size:] = x[:, :, :cache_keep_size]
return new_x

def forward(self, x, cache=None, cache_next=None):
x = self.update_cache(x, cache=cache, cache_next=cache_next)
new_x = torch.cat([cache, new_x], dim=-1)
if self.cache_drop_size > 0:
x = x[:, :, : -self.cache_drop_size]
cache = torch.cat([cache[:, :, x.size(-1) :], x], dim=-1)
return new_x, cache

def forward(self, x, cache=None):
x, cache = self.update_cache(x, cache=cache)
x = super().forward(x)
return x
if cache is None:
return x
else:
return x, cache
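From the caller's side, the reworked CausalConv1D contract looks roughly as below: the cache is passed in and a fresh cache is returned with the output, with nothing written into a shared cache_next tensor. The shapes and the manual cache_drop_size assignment are assumptions made so the example stands alone (normally setup_streaming_params() configures it):

import torch
from nemo.collections.asr.parts.submodules.causal_convs import CausalConv1D

# padding=None makes the convolution causal with a left context of kernel_size - 1.
conv = CausalConv1D(in_channels=8, out_channels=8, kernel_size=5, padding=None)
conv.cache_drop_size = 0  # normally set through setup_streaming_params()

x = torch.randn(2, 8, 16)                       # (batch, channels, chunk_time)
cache = torch.zeros(2, 8, conv._max_cache_len)  # left context carried between chunks

# forward(x, cache=...) returns the convolved chunk plus the cache for the next step.
y, next_cache = conv(x, cache=cache)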