Skip to content

Commit

Permalink
Use scaled_dot_product_attention in Wav2vec2/HuBERT's SelfAttention (#…
Browse files Browse the repository at this point in the history
…3253) (#3261)

Summary:
Replace the attention computation with `torch.nn.functional.scaled_dot_product_attention` to improve running efficiency.

Pull Request resolved: #3253

Reviewed By: mthrok

Differential Revision: D44800353

Pulled By: nateanl

fbshipit-source-id: 41550d868c809099aadbe812b0ebe2c38121efb8
  • Loading branch information
nateanl authored Apr 11, 2023
1 parent d92216d commit e99de15
Showing 1 changed file with 8 additions and 19 deletions.
27 changes: 8 additions & 19 deletions torchaudio/models/wav2vec2/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ def __init__(

self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = torch.nn.Dropout(dropout)
self.dropout = dropout
self.head_dim = head_dim

self.scaling = self.head_dim**-0.5
Expand Down Expand Up @@ -304,25 +304,14 @@ def forward(

shape = (batch_size, length, self.num_heads, self.head_dim)
q = self.q_proj(x).view(*shape).transpose(2, 1) # B, nH, L, Hd
k = self.k_proj(x).view(*shape).permute(0, 2, 3, 1) # B, nH, Hd, L
k = self.k_proj(x).view(*shape).transpose(2, 1) # B, nH, L, Hd
v = self.v_proj(x).view(*shape).transpose(2, 1) # B, nH, L, Hd

# scale down q to avoid value overflow.
weights = (self.scaling * q) @ k # B, nH, L, L
if attention_mask is not None:
weights += attention_mask
# subtracting a constant value from the tensor won't change the output of softmax.
# apply the subtraction to avoid value overflow in torch.nn.functional.softmax.
# for more details, please see Equation 7 in https://arxiv.org/abs/2112.08778
weights = weights - weights.max(dim=-1, keepdim=True)[0]

weights = torch.nn.functional.softmax(weights, dim=-1)
weights = self.dropout(weights)

output = weights @ v # B, nH, L, Hd
output = output.transpose(2, 1).reshape(batch_size, length, embed_dim)

output = self.out_proj(output)
dropout = self.dropout if self.training else 0.0
attn_output = torch.nn.functional.scaled_dot_product_attention(
q, k, v, attn_mask=attention_mask, dropout_p=dropout, is_causal=False
)
attn_output = attn_output.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
output = self.out_proj(attn_output)
return output, None # Necessary for compatibility with WavLMSelAttention


Expand Down

0 comments on commit e99de15

Please sign in to comment.