Commit

Fast Conformer global token fix (#7085)
* old way

Signed-off-by: sam1373 <[email protected]>

* fix

Signed-off-by: sam1373 <[email protected]>

* fix

Signed-off-by: sam1373 <[email protected]>

* fix

Signed-off-by: sam1373 <[email protected]>

* remove extra

Signed-off-by: sam1373 <[email protected]>

* clean

Signed-off-by: sam1373 <[email protected]>

* clean

Signed-off-by: sam1373 <[email protected]>

* clean

Signed-off-by: sam1373 <[email protected]>

* fix

Signed-off-by: sam1373 <[email protected]>

* fix

Signed-off-by: sam1373 <[email protected]>

* fix

Signed-off-by: sam1373 <[email protected]>

* fix

Signed-off-by: sam1373 <[email protected]>

* fix

Signed-off-by: sam1373 <[email protected]>

* fix

Signed-off-by: sam1373 <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: sam1373 <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Signed-off-by: jubick1337 <[email protected]>
2 people authored and jubick1337 committed Aug 8, 2023
1 parent 103d94f commit 2156be4
Showing 1 changed file with 42 additions and 24 deletions.
66 changes: 42 additions & 24 deletions nemo/collections/asr/parts/submodules/multi_head_attention.py
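Summary of the change, for orientation before the diff: the Longformer-style attention in this file mixes sliding-window (local) attention with a small set of global tokens. The old path normalized the local scores and the global-token key scores with two separate softmaxes and summed the resulting outputs; the fixed path concatenates the global-token columns onto the local scores, applies a single softmax, and only then splits the probabilities again. The snippet below is a minimal dense-attention sketch of that normalization difference, not the NeMo implementation: every name and size in it (B, H, T, D, w, n_global, local_band) is invented, the banded sliding_chunks_* machinery is replaced by a dense mask, and the handling of the global query rows themselves (out_global_to_all) is not shown.

import torch

torch.manual_seed(0)

B, H, T, D = 1, 2, 8, 4   # batch, heads, time, head_dim (toy sizes)
w = 2                     # half-width of the local window
n_global = 1              # the first n_global positions act as global tokens

q = torch.randn(B, H, T, D)
k = torch.randn(B, H, T, D)
v = torch.randn(B, H, T, D)

# Dense scores, restricted to a band to mimic sliding-window attention.
scores = torch.matmul(q, k.transpose(-2, -1)) / D ** 0.5   # (B, H, T, T)
idx = torch.arange(T)
local_band = (idx[None, :] - idx[:, None]).abs() <= w      # (T, T) band mask
global_cols = torch.zeros(T, dtype=torch.bool)
global_cols[:n_global] = True                              # every query may attend to the global tokens

# Old pattern (roughly what the removed lines did): two separate softmaxes,
# one over the local window and one over the global columns, outputs summed.
# Each softmax sums to 1 on its own, so the combined weights sum to 2.
local_probs = torch.softmax(scores.masked_fill(~local_band, float("-inf")), dim=-1)
global_probs = torch.softmax(scores.masked_fill(~global_cols, float("-inf")), dim=-1)
out_old = (local_probs + global_probs) @ v

# Fixed pattern (what concatenating the scores before normalization achieves):
# a single softmax over the union of local and global columns, so each
# query's attention weights sum to 1.
allowed = local_band | global_cols
joint_probs = torch.softmax(scores.masked_fill(~allowed, float("-inf")), dim=-1)
out_new = joint_probs @ v

print("old combined weight sum:", (local_probs + global_probs).sum(-1)[0, 0, -1].item())  # ~2.0
print("new combined weight sum:", joint_probs.sum(-1)[0, 0, -1].item())                   # 1.0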
@@ -377,12 +377,6 @@ def forward(self, query, key, value, pad_mask, pos_emb, cache=None):

scores += d_mask

attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0)
attn = self.dropout(attn)
# (batch, head, time, 2w + 1)

out = self.sliding_chunks_matmul_pv(attn, v, w).reshape(n_batch, -1, self.h * self.d_k)

if self.global_tokens > 0:

# create q, k, v for global attn
@@ -426,21 +420,34 @@ def forward(self, query, key, value, pad_mask, pos_emb, cache=None):
is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero,
).transpose(1, 2)

global_key_attn = torch.softmax(global_key_attn, dim=-1).masked_fill(mask, 0.0)
global_key_attn = self.dropout(global_key_attn)
# concat to local_attn_probs
# (batch, time, head, max_num_global_attn_indices + 2*w)
scores = torch.cat((global_key_attn, scores), dim=-1)

# compute outputs for global attention from all tokens to global
# (batch, time, head x head_dim)
out_all_to_global = self._compute_out_all_to_global(
value=global_v,
attn_probs=global_key_attn,
# free memory
del global_key_attn

attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0)
p_attn = self.dropout(attn)
# (batch, head, time, 2w + 1)

if self.global_tokens > 0:
# compute sum of global and local attn
out = self._compute_attn_output_with_global_indices(
value=v,
attn_probs=p_attn,
max_num_global_attn_indices=max_num_global_attn_indices,
is_index_global_attn_nonzero=is_index_global_attn_nonzero,
is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero,
w=w,
)
else:
# compute local attn only
out = self.sliding_chunks_matmul_pv(p_attn, v, w)

out = out.reshape(n_batch, -1, self.h * self.d_k)[:, :T]

# compute outputs for global attention from global tokens to all
# (batch, max_num_global_attn_indices, head x head_dim)
if self.global_tokens > 0:
out_global_to_all = self._compute_out_global_to_all(
query=global_q,
key=global_k,
@@ -452,11 +459,11 @@ def forward(self, query, key, value, pad_mask, pos_emb, cache=None):
is_index_masked=mask,
)

out += out_all_to_global
# overwrite values with global attention
out[is_index_global_attn_nonzero] = out_global_to_all

out[is_index_global_attn_nonzero] += out_global_to_all
ret = self.linear_out(out)

ret = self.linear_out(out.reshape(n_batch, -1, self.h * self.d_k)[:, :T])
if cache is None:
return ret
else:
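One detail in the hunk above that is easy to miss: at the global-token positions the output is now overwritten (=) with the global-to-all result instead of accumulated (+=) onto the local result. A standalone sketch of that indexing pattern, with invented names and shapes:

import torch

batch, time, dim = 2, 6, 8
out = torch.zeros(batch, time, dim)                 # combined local/global output, one row per token
out_from_global_queries = torch.ones(3, dim)        # one row per global token in the whole batch
# (batch index, time index) pairs of the global tokens, as a tuple of tensors
global_positions = (torch.tensor([0, 0, 1]), torch.tensor([0, 1, 0]))

out[global_positions] = out_from_global_queries     # overwrite, not add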
@@ -544,24 +551,25 @@ def _compute_global_key_attn(

return attn_probs_from_global_key

def _compute_out_all_to_global(
def _compute_attn_output_with_global_indices(
self,
value: torch.Tensor,
attn_probs: torch.Tensor,
max_num_global_attn_indices: int,
is_index_global_attn_nonzero: tuple,
is_local_index_global_attn_nonzero: tuple,
w: int,
) -> torch.Tensor:
"""
Compute the attention output of all tokens attending to global.
Compute the attention output with global indices.
Args:
value (torch.Tensor): (batch, head, time, head_dim) The value vectors for global attention.
attn_probs (torch.Tensor): (batch, time, head, 2w) The attention probabilities.
max_num_global_attn_indices (int): Maximum number of global attention indices in the batch.
is_index_global_attn_nonzero (tuple): Indices of global attention (non-zero elements).
is_local_index_global_attn_nonzero (tuple): Non-padding values within global attention indices.
w (int): Local context size
Returns:
torch.Tensor: (batch, time, head x head_dim) The attention output of all tokens attending to global.
"""
@@ -573,12 +581,22 @@ def _compute_out_all_to_global(
value_vectors_only_global = value.new_zeros(batch_size, max_num_global_attn_indices, self.h, self.d_k)
value_vectors_only_global[is_local_index_global_attn_nonzero] = value[is_index_global_attn_nonzero]

# cut local attn probs to global only
attn_probs_only_global = attn_probs.narrow(-1, 0, max_num_global_attn_indices)
# compute attn output only global
out_all_to_global = torch.matmul(attn_probs, value_vectors_only_global.transpose(1, 2)).transpose(1, 2)
attn_output_only_global = torch.matmul(
attn_probs_only_global.clone(), value_vectors_only_global.transpose(1, 2).clone()
).transpose(1, 2)

# reshape attn probs
attn_probs_without_global = attn_probs.narrow(
-1, max_num_global_attn_indices, attn_probs.size(-1) - max_num_global_attn_indices
).contiguous()

out_all_to_global = out_all_to_global.reshape(batch_size, time, -1)
# compute attn output with global
attn_output_without_global = self.sliding_chunks_matmul_pv(attn_probs_without_global, value.transpose(1, 2), w)

return out_all_to_global
return attn_output_only_global + attn_output_without_global
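The helper above splits the concatenated probabilities back apart with narrow(): the first max_num_global_attn_indices columns are multiplied against the global value vectors, and the remaining sliding-window columns go through sliding_chunks_matmul_pv with the full values. A standalone sketch of that split, with made-up sizes and assuming the last dimension is laid out as [global columns | local window columns]:

import torch

batch, time, heads, n_global, w = 2, 16, 4, 3, 2
attn_probs = torch.rand(batch, time, heads, n_global + 2 * w + 1)   # assumed layout and width

# first n_global columns: probabilities on the global tokens
probs_only_global = attn_probs.narrow(-1, 0, n_global)                                   # (2, 16, 4, 3)
# remaining columns: the usual sliding-window probabilities
probs_without_global = attn_probs.narrow(-1, n_global, attn_probs.size(-1) - n_global)   # (2, 16, 4, 5)

assert probs_only_global.size(-1) + probs_without_global.size(-1) == attn_probs.size(-1)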

def _compute_out_global_to_all(
self,
