diff --git a/torchaudio/models/wav2vec2/wavlm_attention.py b/torchaudio/models/wav2vec2/wavlm_attention.py
index 4fc723f78a..fafddfeb95 100644
--- a/torchaudio/models/wav2vec2/wavlm_attention.py
+++ b/torchaudio/models/wav2vec2/wavlm_attention.py
@@ -73,6 +73,7 @@ def __init__(
         self.head_dim = embed_dim // num_heads
         assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
 
+        self.dropout = dropout
         self.attention = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, bias=bias, batch_first=True)
 
         self.gru_rel_pos = gru_rel_pos
@@ -165,7 +166,7 @@ def forward(
 
         if self.rel_attn_embed is not None and position_bias is None:
             position_bias = self.compute_bias(seq_len, seq_len)
-            position_bias = position_bias.unsqueeze(0).repeat(bsz, 1, 1, 1).view(bsz * self.num_heads, seq_len, seq_len)
+            position_bias = position_bias.unsqueeze(0).repeat(bsz, 1, 1, 1)
 
         attn_mask_rel_pos: Optional[Tensor] = None
         if position_bias is not None:
@@ -178,11 +179,36 @@ def forward(
                     self.gru_rel_pos_linear(query_layer).view(bsz, self.num_heads, seq_len, 2, 4).sum(-1, keepdim=False)
                 ).chunk(2, dim=-1)
                 gate_a_1 = gate_a * (gate_b * self.gru_rel_pos_const - 1.0) + 2.0
-                attn_mask_rel_pos = gate_a_1.view(bsz * self.num_heads, -1, 1) * position_bias
-
-            attn_mask_rel_pos = attn_mask_rel_pos.view((-1, seq_len, seq_len))
-
-        attn_output, _ = self.attention(
-            query, query, query, key_padding_mask=key_padding_mask, attn_mask=attn_mask_rel_pos, need_weights=False
+                attn_mask_rel_pos = gate_a_1.view(bsz, self.num_heads, -1, 1) * position_bias
+
+            attn_mask_rel_pos = attn_mask_rel_pos.view((bsz, self.num_heads, seq_len, seq_len))
+
+        if attn_mask_rel_pos is not None and key_padding_mask is not None:
+            key_padding_mask = key_padding_mask.view(bsz, 1, 1, seq_len).expand(-1, self.num_heads, -1, -1)
+            key_padding_mask = torch.nn.functional._canonical_mask(
+                mask=key_padding_mask,
+                mask_name="key_padding_mask",
+                other_type=torch.nn.functional._none_or_dtype(attn_mask_rel_pos),
+                other_name="",
+                target_type=query.dtype,
+            )
+        if attn_mask_rel_pos is not None and key_padding_mask is not None:
+            attn_mask_rel_pos = attn_mask_rel_pos + key_padding_mask
+        query_projected = torch.nn.functional.linear(query, self.attention.in_proj_weight, self.attention.in_proj_bias)
+        query, key, value = query_projected.chunk(3, -1)
+        shape = (bsz, seq_len, self.num_heads, self.head_dim)
+        query = query.view(shape).transpose(2, 1)  # (batch, num_heads, seq_len, head_dim)
+        key = key.view(shape).transpose(2, 1)  # (batch, num_heads, seq_len, head_dim)
+        value = value.view(shape).transpose(2, 1)  # (batch, num_heads, seq_len, head_dim)
+        dropout = self.dropout if self.training else 0.0
+        attn_output = torch.nn.functional.scaled_dot_product_attention(
+            query,
+            key,
+            value,
+            attn_mask=attn_mask_rel_pos,
+            dropout_p=dropout,
+            is_causal=False,
         )
+        attn_output = attn_output.transpose(1, 2).reshape(bsz, -1, self.num_heads * self.head_dim)
+        attn_output = self.attention.out_proj(attn_output)
         return attn_output, position_bias
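
Note (not part of the diff): the change above replaces the nn.MultiheadAttention forward call with a direct torch.nn.functional.scaled_dot_product_attention call, reusing the module's in_proj/out_proj weights and folding the gated relative-position bias and the key padding mask into one additive 4-D attn_mask. The following is a minimal, self-contained sketch of that masking pattern, using only public APIs (masked_fill instead of the private _canonical_mask helper); all tensor names and shapes here are illustrative assumptions, not code from the patch.

import torch
import torch.nn.functional as F

bsz, num_heads, seq_len, head_dim = 2, 4, 10, 16

q = torch.randn(bsz, num_heads, seq_len, head_dim)
k = torch.randn(bsz, num_heads, seq_len, head_dim)
v = torch.randn(bsz, num_heads, seq_len, head_dim)

# Relative-position bias kept 4-D, matching the diff's (bsz, num_heads, seq_len, seq_len) layout.
position_bias = torch.randn(bsz, num_heads, seq_len, seq_len)

# Boolean padding mask: True marks padded key positions.
key_padding_mask = torch.zeros(bsz, seq_len, dtype=torch.bool)
key_padding_mask[:, -2:] = True

# Convert the boolean mask to an additive float mask (True -> -inf, False -> 0),
# shaped (bsz, 1, 1, seq_len) so it broadcasts over heads and query positions.
float_padding_mask = torch.zeros(bsz, 1, 1, seq_len).masked_fill(
    key_padding_mask.view(bsz, 1, 1, seq_len), float("-inf")
)

# One additive mask combining bias and padding, as in the patched forward().
attn_mask = position_bias + float_padding_mask

out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
print(out.shape)  # torch.Size([2, 4, 10, 16])

Because scaled_dot_product_attention accepts a float attn_mask broadcastable to (batch, num_heads, L, S), keeping every tensor 4-D avoids the old bsz * num_heads flattening and lets the fused attention kernels be selected when available.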