From 91fb3164a2898e4292023784c3b47b064f7915d8 Mon Sep 17 00:00:00 2001
From: Boris Fomitchev
Date: Fri, 7 Jul 2023 14:34:55 -0700
Subject: [PATCH] Cache handling without input tensors mutation (#6980)

* Cache handling without input tensors mutation

Signed-off-by: Boris Fomitchev

* Cleanup

Signed-off-by: Boris Fomitchev

* Cleanup#2

Signed-off-by: Boris Fomitchev

* Cleanup#3

Signed-off-by: Boris Fomitchev

---------

Signed-off-by: Boris Fomitchev
Co-authored-by: Somshubra Majumdar
Signed-off-by: zhehuaichen
---
 nemo/collections/asr/models/asr_model.py      | 64 ++++++-----------
 .../asr/modules/conformer_encoder.py          | 48 ++++++-------
 .../multi_head_attention_adapter_module.py    | 16 ++---
 .../asr/parts/submodules/causal_convs.py      | 28 ++++----
 .../asr/parts/submodules/conformer_modules.py | 70 +++++++------------
 .../parts/submodules/multi_head_attention.py  | 53 ++++++++------
 6 files changed, 118 insertions(+), 161 deletions(-)

diff --git a/nemo/collections/asr/models/asr_model.py b/nemo/collections/asr/models/asr_model.py
index 20be6cc162037..7e03d587139f1 100644
--- a/nemo/collections/asr/models/asr_model.py
+++ b/nemo/collections/asr/models/asr_model.py
@@ -161,7 +161,7 @@ def output_module(self):
     @property
     def output_names(self):
         otypes = self.output_module.output_types
-        if hasattr(self.input_module, 'export_cache_support') and self.input_module.export_cache_support:
+        if getattr(self.input_module, 'export_cache_support', False):
             in_types = self.input_module.output_types
             otypes = {n: t for (n, t) in list(otypes.items())[:1]}
             for (n, t) in list(in_types.items())[1:]:
@@ -174,7 +174,6 @@ def forward_for_export(
         """
         This forward is used when we need to export the model to ONNX format. Inputs cache_last_channel and
         cache_last_time are needed to be passed for exporting streaming models.
-        When they are passed, it just passes the inputs through the encoder part and currently the ONNX conversion does not fully work for this case.
 
         Args:
             input: Tensor that represents a batch of raw audio signals, of shape [B, T]. T here represents timesteps.
@@ -187,49 +186,26 @@ def forward_for_export(
         Returns:
             the output of the model
         """
-        if hasattr(self.input_module, 'forward_for_export'):
-            if cache_last_channel is None and cache_last_time is None:
-                encoder_output = self.input_module.forward_for_export(audio_signal=input, length=length)
-            else:
-                encoder_output = self.input_module.forward_for_export(
-                    audio_signal=input,
-                    length=length,
-                    cache_last_channel=cache_last_channel,
-                    cache_last_time=cache_last_time,
-                    cache_last_channel_len=cache_last_channel_len,
-                )
+        enc_fun = getattr(self.input_module, 'forward_for_export', self.input_module.forward)
+        if cache_last_channel is None:
+            encoder_output = enc_fun(audio_signal=input, length=length)
+            if isinstance(encoder_output, tuple):
+                encoder_output = encoder_output[0]
         else:
-            if cache_last_channel is None and cache_last_time is None:
-                encoder_output = self.input_module(audio_signal=input, length=length)
-            else:
-                encoder_output = self.input_module(
-                    audio_signal=input,
-                    length=length,
-                    cache_last_channel=cache_last_channel,
-                    cache_last_time=cache_last_time,
-                    cache_last_channel_len=cache_last_channel_len,
-                )
-        if isinstance(encoder_output, tuple):
-            decoder_input = encoder_output[0]
-        else:
-            decoder_input = encoder_output
-        if hasattr(self.output_module, 'forward_for_export'):
-            if cache_last_channel is None and cache_last_time is None:
-                ret = self.output_module.forward_for_export(encoder_output=decoder_input)
-            else:
-                ret = self.output_module.forward_for_export(encoder_output=decoder_input)
-        else:
-            if cache_last_channel is None and cache_last_time is None:
-                ret = self.output_module(encoder_output=decoder_input)
-            else:
-                ret = self.output_module(encoder_output=decoder_input)
-        if cache_last_channel is None and cache_last_time is None:
-            pass
-        else:
-            if isinstance(ret, tuple):
-                ret = (ret[0], encoder_output[1], encoder_output[2], encoder_output[3], encoder_output[4])
-            else:
-                ret = (ret, encoder_output[1], encoder_output[2], encoder_output[3], encoder_output[4])
+            encoder_output, length, cache_last_channel, cache_last_time, cache_last_channel_len = enc_fun(
+                audio_signal=input,
+                length=length,
+                cache_last_channel=cache_last_channel,
+                cache_last_time=cache_last_time,
+                cache_last_channel_len=cache_last_channel_len,
+            )
+
+        dec_fun = getattr(self.output_module, 'forward_for_export', self.output_module.forward)
+        ret = dec_fun(encoder_output=encoder_output)
+        if isinstance(ret, tuple):
+            ret = ret[0]
+        if cache_last_channel is not None:
+            ret = (ret, length, cache_last_channel, cache_last_time, cache_last_channel_len)
         return cast_all(ret, from_dtype=torch.float16, to_dtype=torch.float32)
 
     @property
diff --git a/nemo/collections/asr/modules/conformer_encoder.py b/nemo/collections/asr/modules/conformer_encoder.py
index ba73f39a25fb6..27b2ea6e4d45b 100644
--- a/nemo/collections/asr/modules/conformer_encoder.py
+++ b/nemo/collections/asr/modules/conformer_encoder.py
@@ -506,11 +506,6 @@ def forward_internal(
                 (audio_signal.size(0),), audio_signal.size(-1), dtype=torch.int64, device=audio_signal.device
             )
 
-        if cache_last_time is not None:
-            cache_last_time_next = torch.zeros_like(cache_last_time)
-        else:
-            cache_last_time_next = None
-
        # select a random att_context_size with the distribution specified by att_context_probs during training
        # for non-validation cases like test, validation or inference, it uses the first mode in self.att_context_size
        if self.training and len(self.att_context_size_all) > 1:
@@ -537,7 +532,6 @@ def forward_internal(
         if cache_last_channel is not None:
             cache_len = self.streaming_cfg.last_channel_cache_size
             cache_keep_size = max_audio_length - self.streaming_cfg.cache_drop_size
-            cache_last_channel_next = torch.zeros_like(cache_last_channel)
             max_audio_length = max_audio_length + cache_len
             padding_length = length + cache_len
             offset = torch.neg(cache_last_channel_len) + cache_len
@@ -562,19 +556,32 @@ def forward_internal(
             pad_mask = pad_mask[:, cache_len:]
             if att_mask is not None:
                 att_mask = att_mask[:, cache_len:]
+            # Convert caches from the tensor to list
+            cache_last_time_next = []
+            cache_last_channel_next = []
 
         for lth, (drop_prob, layer) in enumerate(zip(self.layer_drop_probs, self.layers)):
             original_signal = audio_signal
+            if cache_last_channel is not None:
+                cache_last_channel_cur = cache_last_channel[lth]
+                cache_last_time_cur = cache_last_time[lth]
+            else:
+                cache_last_channel_cur = None
+                cache_last_time_cur = None
             audio_signal = layer(
                 x=audio_signal,
                 att_mask=att_mask,
                 pos_emb=pos_emb,
                 pad_mask=pad_mask,
-                cache_last_channel=cache_last_channel,
-                cache_last_time=cache_last_time,
-                cache_last_channel_next=cache_last_channel_next,
-                cache_last_time_next=cache_last_time_next,
+                cache_last_channel=cache_last_channel_cur,
+                cache_last_time=cache_last_time_cur,
             )
+
+            if cache_last_channel_cur is not None:
+                (audio_signal, cache_last_channel_cur, cache_last_time_cur) = audio_signal
+                cache_last_channel_next.append(cache_last_channel_cur)
+                cache_last_time_next.append(cache_last_time_cur)
+
             # applying stochastic depth logic from https://arxiv.org/abs/2102.03216
             if self.training and drop_prob > 0.0:
                 should_drop = torch.rand(1) < drop_prob
@@ -627,6 +634,8 @@ def forward_internal(
         length = length.to(dtype=torch.int64)
 
         if cache_last_channel is not None:
+            cache_last_channel_next = torch.stack(cache_last_channel_next, dim=0)
+            cache_last_time_next = torch.stack(cache_last_time_next, dim=0)
             return (
                 audio_signal,
                 length,
@@ -861,20 +870,12 @@ def setup_streaming_params(
         else:
             streaming_cfg.drop_extra_pre_encoded = streaming_cfg.pre_encode_cache_size // self.subsampling_factor
 
-        # counting the number of the layers need caching
-        streaming_cfg.last_channel_num = 0
-        streaming_cfg.last_time_num = 0
         for m in self.layers.modules():
             if hasattr(m, "_max_cache_len"):
                 if isinstance(m, MultiHeadAttention):
-                    m._cache_id = streaming_cfg.last_channel_num
                     m.cache_drop_size = streaming_cfg.cache_drop_size
-                    streaming_cfg.last_channel_num += 1
-
                 if isinstance(m, CausalConv1D):
-                    m._cache_id = streaming_cfg.last_time_num
                     m.cache_drop_size = streaming_cfg.cache_drop_size
-                    streaming_cfg.last_time_num += 1
 
         self.streaming_cfg = streaming_cfg
 
@@ -887,19 +888,12 @@ def get_initial_cache_state(self, batch_size=1, dtype=torch.float32, device=None
             create_tensor = torch.zeros
         last_time_cache_size = self.conv_context_size[0]
         cache_last_channel = create_tensor(
-            (
-                self.streaming_cfg.last_channel_num,
-                batch_size,
-                self.streaming_cfg.last_channel_cache_size,
-                self.d_model,
-            ),
+            (len(self.layers), batch_size, self.streaming_cfg.last_channel_cache_size, self.d_model,),
             device=device,
             dtype=dtype,
         )
         cache_last_time = create_tensor(
-            (self.streaming_cfg.last_time_num, batch_size, self.d_model, last_time_cache_size),
-            device=device,
-            dtype=dtype,
+            (len(self.layers), batch_size, self.d_model, last_time_cache_size), device=device, dtype=dtype,
         )
         if max_dim > 0:
             cache_last_channel_len = torch.randint(
diff --git a/nemo/collections/asr/parts/submodules/adapters/multi_head_attention_adapter_module.py b/nemo/collections/asr/parts/submodules/adapters/multi_head_attention_adapter_module.py
index 169dde48602f0..563d4219baa78 100644
--- a/nemo/collections/asr/parts/submodules/adapters/multi_head_attention_adapter_module.py
+++ b/nemo/collections/asr/parts/submodules/adapters/multi_head_attention_adapter_module.py
@@ -147,18 +147,18 @@ def __init__(
         # reset parameters for Q to be identity operation
         self.reset_parameters()
 
-    def forward(self, query, key, value, mask, pos_emb=None, cache=None, cache_next=None):
+    def forward(self, query, key, value, mask, pos_emb=None, cache=None):
         """Compute 'Scaled Dot Product Attention'.
         Args:
             query (torch.Tensor): (batch, time1, size)
             key (torch.Tensor): (batch, time2, size)
             value(torch.Tensor): (batch, time2, size)
             mask (torch.Tensor): (batch, time1, time2)
-            cache (torch.Tensor) : (cache_nums, batch, time_cache, size)
-            cache_next (torch.Tensor) : (cache_nums, batch, time_cache_next, size)
+            cache (torch.Tensor) : (batch, time_cache, size)
 
         returns:
             output (torch.Tensor): transformed `value` (batch, time1, d_model) weighted by the query dot key attention
+            cache (torch.Tensor) : (batch, time_cache_next, size)
         """
         # Need to perform duplicate computations as at this point the tensors have been
         # separated by the adapter forward
@@ -166,7 +166,7 @@ def forward(self, query, key, value, mask, pos_emb=None, cache=None, cache_next=
         key = self.pre_norm(key)
         value = self.pre_norm(value)
 
-        return super().forward(query, key, value, mask, pos_emb, cache=cache, cache_next=cache_next)
+        return super().forward(query, key, value, mask, pos_emb, cache=cache)
 
     def reset_parameters(self):
         with torch.no_grad():
@@ -242,7 +242,7 @@ def __init__(
         # reset parameters for Q to be identity operation
         self.reset_parameters()
 
-    def forward(self, query, key, value, mask, pos_emb, cache=None, cache_next=None):
+    def forward(self, query, key, value, mask, pos_emb, cache=None):
         """Compute 'Scaled Dot Product Attention' with rel. positional encoding.
         Args:
             query (torch.Tensor): (batch, time1, size)
@@ -250,10 +250,10 @@ def forward(self, query, key, value, mask, pos_emb, cache=None, cache_next=None)
             value(torch.Tensor): (batch, time2, size)
             mask (torch.Tensor): (batch, time1, time2)
             pos_emb (torch.Tensor) : (batch, time1, size)
-            cache (torch.Tensor) : (cache_nums, batch, time_cache, size)
-            cache_next (torch.Tensor) : (cache_nums, batch, time_cache_next, size)
+            cache (torch.Tensor) : (batch, time_cache, size)
         Returns:
             output (torch.Tensor): transformed `value` (batch, time1, d_model) weighted by the query dot key attention
+            cache_next (torch.Tensor) : (batch, time_cache_next, size)
         """
         # Need to perform duplicate computations as at this point the tensors have been
         # separated by the adapter forward
@@ -261,7 +261,7 @@ def forward(self, query, key, value, mask, pos_emb, cache=None, cache_next=None)
         key = self.pre_norm(key)
         value = self.pre_norm(value)
 
-        return super().forward(query, key, value, mask, pos_emb, cache=cache, cache_next=cache_next)
+        return super().forward(query, key, value, mask, pos_emb, cache=cache)
 
     def reset_parameters(self):
         with torch.no_grad():
diff --git a/nemo/collections/asr/parts/submodules/causal_convs.py b/nemo/collections/asr/parts/submodules/causal_convs.py
index 25f8418021541..c6251690b1b12 100644
--- a/nemo/collections/asr/parts/submodules/causal_convs.py
+++ b/nemo/collections/asr/parts/submodules/causal_convs.py
@@ -45,7 +45,6 @@ def __init__(
             raise ValueError("Argument padding should be set to None for CausalConv2D.")
         self._left_padding = kernel_size - 1
         self._right_padding = stride - 1
-        self._cache_id = None
 
         padding = 0
         super(CausalConv2D, self).__init__(
@@ -113,7 +112,6 @@ def __init__(
             raise ValueError(f"Invalid padding param: {padding}!")
 
         self._max_cache_len = self._left_padding
-        self._cache_id = None
 
         super(CausalConv1D, self).__init__(
             in_channels=in_channels,
@@ -129,21 +127,21 @@ def __init__(
             dtype=dtype,
         )
 
-    def update_cache(self, x, cache=None, cache_next=None):
+    def update_cache(self, x, cache=None):
         if cache is None:
             new_x = F.pad(x, pad=(self._left_padding, self._right_padding))
         else:
             new_x = F.pad(x, pad=(0, self._right_padding))
-            new_x = torch.cat([cache[self._cache_id], new_x], dim=-1)
-            # todo: we should know input_x.size(-1) at config time
-            if cache_next is not None:
-                cache_keep_size = torch.tensor(x.size(-1) - self.cache_drop_size, dtype=torch.int64, device=x.device)
-                cache_keep_size = torch.clip(cache_keep_size, min=1, max=cache_next.size(-1))
-                cache_next[self._cache_id, :, :, :-cache_keep_size] = cache[self._cache_id, :, :, cache_keep_size:]
-                cache_next[self._cache_id, :, :, -cache_keep_size:] = x[:, :, :cache_keep_size]
-        return new_x
-
-    def forward(self, x, cache=None, cache_next=None):
-        x = self.update_cache(x, cache=cache, cache_next=cache_next)
+            new_x = torch.cat([cache, new_x], dim=-1)
+            if self.cache_drop_size > 0:
+                x = x[:, :, : -self.cache_drop_size]
+            cache = torch.cat([cache[:, :, x.size(-1) :], x], dim=-1)
+        return new_x, cache
+
+    def forward(self, x, cache=None):
+        x, cache = self.update_cache(x, cache=cache)
         x = super().forward(x)
-        return x
+        if cache is None:
+            return x
+        else:
+            return x, cache
diff --git a/nemo/collections/asr/parts/submodules/conformer_modules.py b/nemo/collections/asr/parts/submodules/conformer_modules.py
index 579b78a8f5a89..677d2acd9f2e7 100644
--- a/nemo/collections/asr/parts/submodules/conformer_modules.py
+++ b/nemo/collections/asr/parts/submodules/conformer_modules.py
@@ -138,29 +138,19 @@ def __init__(
         self.dropout = nn.Dropout(dropout)
         self.norm_out = LayerNorm(d_model)
 
-    def forward(
-        self,
-        x,
-        att_mask=None,
-        pos_emb=None,
-        pad_mask=None,
-        cache_last_channel=None,
-        cache_last_time=None,
-        cache_last_channel_next=None,
-        cache_last_time_next=None,
-    ):
+    def forward(self, x, att_mask=None, pos_emb=None, pad_mask=None, cache_last_channel=None, cache_last_time=None):
         """
         Args:
             x (torch.Tensor): input signals (B, T, d_model)
             att_mask (torch.Tensor): attention masks(B, T, T)
             pos_emb (torch.Tensor): (L, 1, d_model)
             pad_mask (torch.tensor): padding mask
-            cache_last_channel (torch.tensor) : cache for MHA layers (N, B, T_cache, d_model)
-            cache_last_time (torch.tensor) : cache for convolutional layers (N, B, d_model, T_cache)
-            cache_last_channel_next (torch.tensor) : next cache for MHA layers (N, B, T_cache, d_model)
-            cache_last_time_next (torch.tensor) : next cache for convolutional layers (N, B, d_model, T_cache)
+            cache_last_channel (torch.tensor) : cache for MHA layers (B, T_cache, d_model)
+            cache_last_time (torch.tensor) : cache for convolutional layers (B, d_model, T_cache)
         Returns:
             x (torch.Tensor): (B, T, d_model)
+            cache_last_channel (torch.tensor) : next cache for MHA layers (B, T_cache, d_model)
+            cache_last_time (torch.tensor) : next cache for convolutional layers (B, d_model, T_cache)
         """
         residual = x
         x = self.norm_feed_forward1(x)
@@ -169,31 +159,17 @@ def forward(
 
         x = self.norm_self_att(residual)
         if self.self_attention_model == 'rel_pos':
-            x = self.self_attn(
-                query=x,
-                key=x,
-                value=x,
-                mask=att_mask,
-                pos_emb=pos_emb,
-                cache=cache_last_channel,
-                cache_next=cache_last_channel_next,
-            )
+            x = self.self_attn(query=x, key=x, value=x, mask=att_mask, pos_emb=pos_emb, cache=cache_last_channel)
         elif self.self_attention_model == 'rel_pos_local_attn':
-            x = self.self_attn(
-                query=x,
-                key=x,
-                value=x,
-                pad_mask=pad_mask,
-                pos_emb=pos_emb,
-                cache=cache_last_channel,
-                cache_next=cache_last_channel_next,
-            )
+            x = self.self_attn(query=x, key=x, value=x, pad_mask=pad_mask, pos_emb=pos_emb, cache=cache_last_channel)
         elif self.self_attention_model == 'abs_pos':
-            x = self.self_attn(
-                query=x, key=x, value=x, mask=att_mask, cache=cache_last_channel, cache_next=cache_last_channel_next
-            )
+            x = self.self_attn(query=x, key=x, value=x, mask=att_mask, cache=cache_last_channel)
         else:
             x = None
+
+        if x is not None and cache_last_channel is not None:
+            (x, cache_last_channel) = x
+
         residual = residual + self.dropout(x)
 
         if self.is_adapter_available():
@@ -208,7 +184,9 @@ def forward(
             residual = pack_ip['x']
 
         x = self.norm_conv(residual)
-        x = self.conv(x, pad_mask=pad_mask, cache=cache_last_time, cache_next=cache_last_time_next)
+        x = self.conv(x, pad_mask=pad_mask, cache=cache_last_time)
+        if cache_last_time is not None:
+            (x, cache_last_time) = x
         residual = residual + self.dropout(x)
 
         x = self.norm_feed_forward2(residual)
@@ -228,8 +206,10 @@ def forward(
 
         if self.is_access_enabled() and self.access_cfg.get('save_encoder_tensors', False):
             self.register_accessible_tensor(name='encoder', tensor=x)
-
-        return x
+        if cache_last_channel is None:
+            return x
+        else:
+            return x, cache_last_channel, cache_last_time
 
     def forward_single_enabled_adapter_(
         self,
@@ -355,7 +335,7 @@ def __init__(
             in_channels=dw_conv_input_dim, out_channels=d_model, kernel_size=1, stride=1, padding=0, bias=True
         )
 
-    def forward(self, x, pad_mask=None, cache=None, cache_next=None):
+    def forward(self, x, pad_mask=None, cache=None):
         x = x.transpose(1, 2)
         x = self.pointwise_conv1(x)
 
@@ -368,10 +348,9 @@ def forward(self, x, pad_mask=None, cache=None, cache_next=None):
         if pad_mask is not None:
             x = x.float().masked_fill(pad_mask.unsqueeze(1), 0.0)
 
+        x = self.depthwise_conv(x, cache=cache)
         if cache is not None:
-            x = self.depthwise_conv(x, cache=cache, cache_next=cache_next)
-        else:
-            x = self.depthwise_conv(x)
+            x, cache = x
 
         if self.norm_type == "layer_norm":
             x = x.transpose(1, 2)
@@ -383,7 +362,10 @@ def forward(self, x, pad_mask=None, cache=None, cache_next=None):
         x = self.activation(x)
         x = self.pointwise_conv2(x)
         x = x.transpose(1, 2)
-        return x
+        if cache is None:
+            return x
+        else:
+            return x, cache
 
     def reset_parameters_conv(self):
         pw1_max = pw2_max = self.d_model ** -0.5
diff --git a/nemo/collections/asr/parts/submodules/multi_head_attention.py b/nemo/collections/asr/parts/submodules/multi_head_attention.py
index b7356ffe87e4b..a0253524419ef 100644
--- a/nemo/collections/asr/parts/submodules/multi_head_attention.py
+++ b/nemo/collections/asr/parts/submodules/multi_head_attention.py
@@ -73,7 +73,6 @@ def __init__(self, n_head, n_feat, dropout_rate, max_cache_len=0):
         self.dropout = nn.Dropout(p=dropout_rate)
 
         self._max_cache_len = max_cache_len
-        self._cache_id = None
 
     def forward_qkv(self, query, key, value):
         """Transforms query, key and value.
@@ -119,20 +118,20 @@ def forward_attention(self, value, scores, mask):
 
         return self.linear_out(x)  # (batch, time1, d_model)
 
-    def forward(self, query, key, value, mask, pos_emb=None, cache=None, cache_next=None):
+    def forward(self, query, key, value, mask, pos_emb=None, cache=None):
         """Compute 'Scaled Dot Product Attention'.
         Args:
             query (torch.Tensor): (batch, time1, size)
             key (torch.Tensor): (batch, time2, size)
             value(torch.Tensor): (batch, time2, size)
             mask (torch.Tensor): (batch, time1, time2)
-            cache (torch.Tensor) : (cache_nums, batch, time_cache, size)
-            cache_next (torch.Tensor) : (cache_nums, batch, time_cache_next, size)
+            cache (torch.Tensor) : (batch, time_cache, size)
 
         returns:
             output (torch.Tensor): transformed `value` (batch, time1, d_model) weighted by the query dot key attention
+            cache (torch.Tensor) : (batch, time_cache_next, size)
         """
-        key, value, query = self.update_cache(key=key, value=value, query=query, cache=cache, cache_next=cache_next)
+        key, value, query, cache = self.update_cache(key=key, value=value, query=query, cache=cache)
 
         if torch.is_autocast_enabled():
             query, key, value = query.to(torch.float32), key.to(torch.float32), value.to(torch.float32)
@@ -142,17 +141,17 @@ def forward(self, query, key, value, mask, pos_emb=None, cache=None, cache_next=
             q, k, v = self.forward_qkv(query, key, value)
             scores = torch.matmul(q, k.transpose(-2, -1)) / self.s_d_k
             out = self.forward_attention(v, scores, mask)
+        if cache is None:
+            return out
+        else:
+            return out, cache
 
-        return out
-
-    def update_cache(self, key, value, query, cache, cache_next):
+    def update_cache(self, key, value, query, cache):
         if cache is not None:
-            key = value = torch.cat([cache[self._cache_id], key], dim=1)
+            key = value = torch.cat([cache, key], dim=1)
             q_keep_size = query.shape[1] - self.cache_drop_size
-            if cache_next is not None:
-                cache_next[self._cache_id, :, :-q_keep_size, :] = cache[self._cache_id, :, q_keep_size:, :]
-                cache_next[self._cache_id, :, -q_keep_size:, :] = query[:, :q_keep_size, :]
-        return key, value, query
+            cache = torch.cat([cache[:, q_keep_size:, :], query[:, :q_keep_size, :]], dim=1)
+        return key, value, query, cache
 
 
 class RelPositionMultiHeadAttention(MultiHeadAttention):
@@ -195,7 +194,7 @@ def rel_shift(self, x):
         x = x[:, :, 1:].view(b, h, qlen, pos_len)  # (b, h, t1, t2)
         return x
 
-    def forward(self, query, key, value, mask, pos_emb, cache=None, cache_next=None):
+    def forward(self, query, key, value, mask, pos_emb, cache=None):
         """Compute 'Scaled Dot Product Attention' with rel. positional encoding.
         Args:
             query (torch.Tensor): (batch, time1, size)
@@ -203,12 +202,13 @@ def forward(self, query, key, value, mask, pos_emb, cache=None, cache_next=None)
             value(torch.Tensor): (batch, time2, size)
             mask (torch.Tensor): (batch, time1, time2)
             pos_emb (torch.Tensor) : (batch, time1, size)
-            cache (torch.Tensor) : (cache_nums, batch, time_cache, size)
-            cache_next (torch.Tensor) : (cache_nums, batch, time_cache_next, size)
+            cache (torch.Tensor) : (batch, time_cache, size)
+
         Returns:
             output (torch.Tensor): transformed `value` (batch, time1, d_model) weighted by the query dot key attention
+            cache (torch.Tensor) : (batch, time_cache_next, size)
         """
-        key, value, query = self.update_cache(key=key, value=value, query=query, cache=cache, cache_next=cache_next)
+        key, value, query, cache = self.update_cache(key=key, value=value, query=query, cache=cache)
 
         if torch.is_autocast_enabled():
             query, key, value = query.to(torch.float32), key.to(torch.float32), value.to(torch.float32)
@@ -244,7 +244,10 @@ def forward(self, query, key, value, mask, pos_emb, cache=None, cache_next=None)
 
         out = self.forward_attention(v, scores, mask)
 
-        return out
+        if cache is None:
+            return out
+        else:
+            return out, cache
 
 
 class RelPositionMultiHeadAttentionLongformer(RelPositionMultiHeadAttention):
@@ -298,7 +301,7 @@ def __init__(
             self.global_k = nn.Linear(n_feat, n_feat)
             self.global_v = nn.Linear(n_feat, n_feat)
 
-    def forward(self, query, key, value, pad_mask, pos_emb, cache=None, cache_next=None):
+    def forward(self, query, key, value, pad_mask, pos_emb, cache=None):
         """Compute Scaled Dot Product Local Attention with rel. positional encoding. using overlapping chunks
         Args:
             query (torch.Tensor): (batch, time, size)
@@ -306,13 +309,13 @@ def forward(self, query, key, value, pad_mask, pos_emb, cache=None, cache_next=N
             value(torch.Tensor): (batch, time, size)
             pad_mask (torch.Tensor): (batch, time)
             pos_emb (torch.Tensor) : (batch, 2w + 1, size)
-            cache (torch.Tensor) : (cache_nums, batch, time_cache, size)
-            cache_next (torch.Tensor) : (cache_nums, batch, time_cache_next, size)
+            cache (torch.Tensor) : (batch, time_cache, size)
 
         Returns:
             output (torch.Tensor): transformed `value` (batch, time1, d_model) weighted by the query dot key attention
+            cache (torch.Tensor) : (batch, time_cache_next, size)
         """
-        key, value, query = self.update_cache(key=key, value=value, query=query, cache=cache, cache_next=cache_next)
+        key, value, query, cache = self.update_cache(key=key, value=value, query=query, cache=cache)
 
         if torch.is_autocast_enabled():
             query, key, value = query.to(torch.float32), key.to(torch.float32), value.to(torch.float32)
@@ -453,7 +456,11 @@ def forward(self, query, key, value, pad_mask, pos_emb, cache=None, cache_next=N
 
         out[is_index_global_attn_nonzero] += out_global_to_all
 
-        return self.linear_out(out.reshape(n_batch, -1, self.h * self.d_k)[:, :T])
+        ret = self.linear_out(out.reshape(n_batch, -1, self.h * self.d_k)[:, :T])
+        if cache is None:
+            return ret
+        else:
+            return ret, cache
 
     def _get_global_attn_indices(self, is_index_global_attn: torch.Tensor) -> Tuple:
         """
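
For reference, the following is a minimal standalone sketch (editor's illustration with made-up shapes, not part of the patch itself) of the rolling-cache update that the refactored MultiHeadAttention.update_cache performs: instead of writing into a shared cache_next tensor indexed by a per-layer _cache_id, each layer now receives only its own cache slice and returns the updated cache.

    import torch

    def update_cache(key, value, query, cache, cache_drop_size=0):
        # Mirrors the torch.cat-based logic added in the patch above.
        if cache is not None:
            # Prepend the cached frames so attention can see past context.
            key = value = torch.cat([cache, key], dim=1)
            # Roll the cache: drop the oldest frames, append the newest query frames.
            q_keep_size = query.shape[1] - cache_drop_size
            cache = torch.cat([cache[:, q_keep_size:, :], query[:, :q_keep_size, :]], dim=1)
        return key, value, query, cache

    # Illustrative shapes only: batch=2, cache length=16, chunk=8, d_model=4.
    cache = torch.zeros(2, 16, 4)
    chunk = torch.randn(2, 8, 4)
    key, value, query, cache = update_cache(chunk, chunk, chunk, cache)
    print(key.shape, cache.shape)  # torch.Size([2, 24, 4]) torch.Size([2, 16, 4])

Because the caller receives the new cache as a return value (and the encoder stacks the per-layer caches with torch.stack), the exported streaming graph no longer mutates its input tensors in place.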