[transformer] support multi query attention && multi grouped (wenet-e2e#2403)

* [transformer] support multi query attention

* fix dim

* fix dim

* fix comment and fix kv_head
Mddct authored and srdfjy committed Jul 5, 2024
1 parent 9c1835f commit b6b2a40
Showing 3 changed files with 189 additions and 21 deletions.
185 changes: 169 additions & 16 deletions wenet/transformer/attention.py
@@ -16,7 +16,7 @@
"""Multi-Head Attention layer definition."""

import math
from typing import Tuple
from typing import Optional, Tuple

import torch
from torch import nn
@@ -26,6 +26,14 @@

class MultiHeadedAttention(nn.Module):
"""Multi-Head Attention layer.
If n_kv_head is not None and n_kv_head != n_head, multi-query or grouped-query attention is used.
see: https://arxiv.org/pdf/1911.02150.pdf
https://arxiv.org/pdf/2305.13245.pdf
Example:
case 1: n_kv_head == None, head_dim == None, MultiHead attention (MHSA)
case 2: n_kv_head=1, n_head = 16, MultiQuery attention (MQA)
case 3: n_kv_head=2, n_head = 16, GroupedQuery attention (GQA)
Args:
n_head (int): The number of heads.
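
A minimal usage sketch of the three cases above (illustrative only, not part of this commit; the feature and head sizes are made-up values, and the constructor/forward defaults are assumed from the signatures shown in this diff):

import torch
from wenet.transformer.attention import MultiHeadedAttention

x = torch.randn(2, 10, 256)  # (batch, time, n_feat)

# case 1: plain multi-head attention (n_kv_head and head_dim left as None)
mha = MultiHeadedAttention(n_head=4, n_feat=256, dropout_rate=0.0)
# case 2: multi-query attention, a single K/V head shared by all query heads
mqa = MultiHeadedAttention(n_head=4, n_feat=256, dropout_rate=0.0,
                           n_kv_head=1, head_dim=64)
# case 3: grouped-query attention, two K/V heads shared by four query heads
gqa = MultiHeadedAttention(n_head=4, n_feat=256, dropout_rate=0.0,
                           n_kv_head=2, head_dim=64)

for attn in (mha, mqa, gqa):
    out, _ = attn(x, x, x)       # mask/pos_emb/cache keep their defaults
    assert out.shape == x.shape  # output is always (batch, time, n_feat)

Note that head_dim has to equal n_feat // n_head (64 here), since the constructor asserts self.d_k == self.inner_kv_dim // n_kv_head.
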
@@ -41,17 +49,30 @@ def __init__(self,
query_bias: bool = True,
key_bias: bool = True,
value_bias: bool = True,
use_sdpa: bool = False):
use_sdpa: bool = False,
n_kv_head: Optional[int] = None,
head_dim: Optional[int] = None):
"""Construct an MultiHeadedAttention object."""
super().__init__()
assert n_feat % n_head == 0

self.inner_dim = n_feat if head_dim is None else head_dim * n_head
if n_kv_head is not None:
assert head_dim is not None
self.inner_kv_dim = head_dim * n_kv_head
n_kv_head = n_kv_head
else:
self.inner_kv_dim = self.inner_dim
n_kv_head = n_head
# We assume d_v always equals d_k
self.d_k = n_feat // n_head
assert self.d_k == self.inner_kv_dim // n_kv_head
self.h = n_head
self.linear_q = nn.Linear(n_feat, n_feat, bias=query_bias)
self.linear_k = nn.Linear(n_feat, n_feat, bias=key_bias)
self.linear_v = nn.Linear(n_feat, n_feat, bias=value_bias)
self.linear_out = nn.Linear(n_feat, n_feat)
self.h_kv = n_kv_head

self.linear_q = nn.Linear(n_feat, self.inner_dim, bias=query_bias)
self.linear_k = nn.Linear(n_feat, self.inner_kv_dim, bias=key_bias)
self.linear_v = nn.Linear(n_feat, self.inner_kv_dim, bias=value_bias)
self.linear_out = nn.Linear(self.inner_dim, n_feat, bias=query_bias)
self.dropout = nn.Dropout(p=dropout_rate)

self.use_sdpa = use_sdpa
@@ -61,16 +82,21 @@ def _forward_linearx(self, name: str, x: torch.Tensor) -> torch.Tensor:
assert x.ndim >= 3
if name == 'query':
x = self.linear_q(x)
x_shape = x.size()
x_shape = x_shape[:-1] + torch.Size([self.h, self.d_k])
elif name == 'key':
x = self.linear_k(x)
x_shape = x.size()
x_shape = x_shape[:-1] + torch.Size([self.h_kv, self.d_k])
else:
assert name == 'value'
x = self.linear_v(x)
x_shape = x.size()
x_shape = x_shape[:-1] + torch.Size([self.h_kv, self.d_k])

# split last dim
x_shape = x.size()
x_shape = x_shape[:-1] + torch.Size([self.h, self.d_k])
x = x.view(x_shape)
x = x.transpose(-3, -2) # (batch, ..., head, time, d_k)
x = x.transpose(-3, -2) # (batch, ..., head or head_kv, time, d_k)
return x

def forward_qkv(
@@ -87,9 +113,9 @@ def forward_qkv(
torch.Tensor: Transformed query tensor, size
(#batch, ..., n_head, time1, d_k).
torch.Tensor: Transformed key tensor, size
(#batch, ..., n_head, time2, d_k).
(#batch, ..., n_head_kv, time2, d_k).
torch.Tensor: Transformed value tensor, size
(#batch, ..., n_head, time2, d_k).
(#batch, ..., n_head_kv, time2, d_k).
"""
q = self._forward_linearx('query', query)
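
A quick shape check of the docstring above (again illustrative, not code from this commit): with n_head=4 and n_kv_head=2, the query keeps four heads while key/value carry only two heads before any replication.

import torch
from wenet.transformer.attention import MultiHeadedAttention

attn = MultiHeadedAttention(n_head=4, n_feat=256, dropout_rate=0.0,
                            n_kv_head=2, head_dim=64)
x = torch.randn(2, 10, 256)
q, k, v = attn.forward_qkv(x, x, x)
print(q.shape)  # torch.Size([2, 4, 10, 64])  -> (batch, n_head, time1, d_k)
print(k.shape)  # torch.Size([2, 2, 10, 64])  -> (batch, n_kv_head, time2, d_k)
print(v.shape)  # torch.Size([2, 2, 10, 64])
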
@@ -210,6 +236,19 @@ def forward(
# non-trivial to calculate `next_cache_start` here.
new_cache = torch.cat((k, v), dim=-1)

# for multi-query or grouped-query attention
if self.h_kv != self.h:
k = torch.repeat_interleave(
k,
self.h // self.h_kv,
dim=-3,
)
v = torch.repeat_interleave(
v,
self.h // self.h_kv,
dim=-3,
)

if not self.use_sdpa:
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
return self.forward_attention(v, scores, mask), new_cache
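
The repeat_interleave call above is what maps the grouped K/V heads back onto the full set of query heads before the score computation. A standalone illustration with toy shapes (not code from this commit): with four query heads and two K/V heads, each K/V head is tiled twice along the head axis, so query heads 0-1 read K/V head 0 and query heads 2-3 read K/V head 1.

import torch

n_head, n_kv_head = 4, 2
k = torch.randn(2, n_kv_head, 10, 64)  # (batch, h_kv, time2, d_k)
k_rep = torch.repeat_interleave(k, n_head // n_kv_head, dim=-3)
print(k_rep.shape)  # torch.Size([2, 4, 10, 64])
assert torch.equal(k_rep[:, 0], k[:, 0]) and torch.equal(k_rep[:, 1], k[:, 0])
assert torch.equal(k_rep[:, 2], k[:, 1]) and torch.equal(k_rep[:, 3], k[:, 1])
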
@@ -244,10 +283,12 @@ def __init__(self,
query_bias: bool = True,
key_bias: bool = True,
value_bias: bool = True,
use_sdpa: bool = False):
use_sdpa: bool = False,
n_kv_head: Optional[int] = None,
head_dim: Optional[int] = None):
"""Construct an RelPositionMultiHeadedAttention object."""
super().__init__(n_head, n_feat, dropout_rate, query_bias, key_bias,
value_bias, use_sdpa)
value_bias, use_sdpa, n_kv_head, head_dim)
# linear transformation for positional encoding
self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
# these two learnable bias are used in matrix c and matrix d
@@ -335,10 +376,24 @@ def forward(
dim=-1)
k = torch.cat([key_cache, k], dim=2)
v = torch.cat([value_cache, v], dim=2)

# NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
# non-trivial to calculate `next_cache_start` here.
new_cache = torch.cat((k, v), dim=-1)

# for multi-query or grouped-query attention
if self.h_kv != self.h:
k = torch.repeat_interleave(
k,
self.h // self.h_kv,
dim=-3,
)
v = torch.repeat_interleave(
v,
self.h // self.h_kv,
dim=-3,
)

n_batch_pos = pos_emb.size(0)
p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
p = p.transpose(1, 2) # (batch, head, time1, d_k)
@@ -395,9 +450,11 @@ def __init__(self,
query_bias: bool = True,
key_bias: bool = True,
value_bias: bool = True,
use_sdpa: bool = False):
use_sdpa: bool = False,
n_kv_head: Optional[int] = None,
head_dim: Optional[int] = None):
super().__init__(n_head, n_feat, dropout_rate, query_bias, key_bias,
value_bias, use_sdpa)
value_bias, use_sdpa, n_kv_head, head_dim)

def forward(
self,
@@ -418,6 +475,19 @@
q, k, v = self.forward_qkv(query, key, value)
new_cache = torch.cat((k, v), dim=-1)

# for multi-query or grouped-query attention
if self.h_kv != self.h:
k = torch.repeat_interleave(
k,
self.h // self.h_kv,
dim=-3,
)
v = torch.repeat_interleave(
v,
self.h // self.h_kv,
dim=-3,
)

B = query.size(0)
Beams = 1
if B != k.size(0):
@@ -451,3 +521,86 @@ def forward(
output_shape = torch.Size([B * Beams]) + output.size()[2:]
output = output.view(output_shape)
return output, new_cache


class ShawRelPositionMultiHeadedAttention(MultiHeadedAttention):
""" https://arxiv.org/pdf/1803.02155.pdf
"""

def __init__(self,
n_head: int,
n_feat: int,
dropout_rate: float,
query_bias: bool = True,
key_bias: bool = True,
value_bias: bool = True,
use_sdpa: bool = False,
n_kv_head: Optional[int] = None,
head_dim: Optional[int] = None):
del n_kv_head, head_dim
super().__init__(n_head, n_feat, dropout_rate, query_bias, key_bias,
value_bias, use_sdpa, None, None)
# TODO(Mddct): 64 8 1 as args
self.max_right_rel_pos = 64
self.max_left_rel_pos = 8
self.rel_k_embed = torch.nn.Embedding(
self.max_left_rel_pos + self.max_right_rel_pos + 1, self.d_k)

def _relative_indices(self, length: int, device: torch.device):
indices = torch.arange(length, device=device).unsqueeze(0)
rel_indices = indices - indices.transpose(0, 1)
rel_indices = torch.clamp(rel_indices, -self.max_left_rel_pos,
self.max_right_rel_pos)
return rel_indices + self.max_left_rel_pos

def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
pos_emb: torch.Tensor = torch.empty(0),
cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
) -> Tuple[torch.Tensor, torch.Tensor]:
del pos_emb
q, k, v = self.forward_qkv(query, key, value)
if cache.size(0) > 0:
key_cache, value_cache = torch.split(cache,
cache.size(-1) // 2,
dim=-1)
k = torch.cat([key_cache, k], dim=2)
v = torch.cat([value_cache, v], dim=2)
new_cache = torch.cat((k, v), dim=-1)

rel_k = self.rel_k_embed(
self._relative_indices(k.size(2), query.device)) # (t2, t2, d_k)
rel_k = rel_k[-q.size(2):] # (t1, t2, d_k)
# b,h,t1,dk
rel_k = rel_k.unsqueeze(0).unsqueeze(0) # (1, 1, t1, t2, d_k)
q_expand = q.unsqueeze(3) # (batch, h, t1, 1, d_k)
rel_att_weights = (rel_k * q_expand).sum(-1).squeeze(
-1) # (batch, h, t1, t2)

if not self.use_sdpa:
scores = (torch.matmul(q, k.transpose(-2, -1)) +
rel_att_weights) / math.sqrt(self.d_k)
return self.forward_attention(v, scores, mask), new_cache
else:
# NOTE(Mddct): we need mask bias, not boolean mask
assert mask.dtype != torch.bool
mask = mask.unsqueeze(1)
# matrix_bd as a mask bias
mask = torch.where(mask == get_dtype_min(mask.dtype), mask,
rel_att_weights / math.sqrt(self.d_k))
output = torch.nn.functional.scaled_dot_product_attention(
q,
k,
v,
attn_mask=mask,
dropout_p=self.dropout_rate,
scale=1 / math.sqrt(self.d_k),
)
output = (output.transpose(1, 2).contiguous().view(
query.size(0), -1,
self.h * self.d_k)) # (batch, time1, d_model)
return self.linear_out(output), new_cache
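
The _relative_indices helper of ShawRelPositionMultiHeadedAttention builds the lookup table for rel_k_embed by clamping relative positions into a fixed window. A standalone sketch with a smaller window than the hard-coded max_left_rel_pos=8 / max_right_rel_pos=64, purely for illustration (the helper name below is made up for this sketch):

import torch

def relative_indices(length: int, max_left: int, max_right: int) -> torch.Tensor:
    idx = torch.arange(length).unsqueeze(0)
    rel = idx - idx.transpose(0, 1)        # rel[i, j] = j - i
    rel = torch.clamp(rel, -max_left, max_right)
    return rel + max_left                  # shift into [0, max_left + max_right]

print(relative_indices(5, max_left=2, max_right=3))
# tensor([[2, 3, 4, 5, 5],
#         [1, 2, 3, 4, 5],
#         [0, 1, 2, 3, 4],
#         [0, 0, 1, 2, 3],
#         [0, 0, 0, 1, 2]])

Each entry then indexes an embedding table with max_left + max_right + 1 rows, exactly like rel_k_embed above.
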
14 changes: 11 additions & 3 deletions wenet/transformer/decoder.py
@@ -84,6 +84,8 @@ def __init__(
mlp_type: str = 'position_wise_feed_forward',
layer_norm_type: str = 'layer_norm',
norm_eps: float = 1e-5,
n_kv_head: Optional[int] = None,
head_dim: Optional[int] = None,
):
super().__init__()
attention_dim = encoder_output_size
@@ -114,11 +116,11 @@ def __init__(
WENET_ATTENTION_CLASSES["selfattn"](
attention_heads, attention_dim,
self_attention_dropout_rate, query_bias, key_bias,
value_bias, use_sdpa),
value_bias, use_sdpa, n_kv_head, head_dim),
WENET_ATTENTION_CLASSES["crossattn"](
attention_heads, attention_dim, src_attention_dropout_rate,
query_bias, key_bias, value_bias, use_sdpa)
if src_attention else None,
query_bias, key_bias, value_bias, use_sdpa, n_kv_head,
head_dim) if src_attention else None,
mlp_class(attention_dim, linear_units, dropout_rate,
activation, mlp_bias),
dropout_rate,
@@ -334,6 +336,8 @@ def __init__(
use_sdpa: bool = False,
layer_norm_type: str = 'layer_norm',
norm_eps: float = 1e-5,
n_kv_head: Optional[int] = None,
head_dim: Optional[int] = None,
):

super().__init__()
@@ -360,6 +364,8 @@ def __init__(
use_sdpa=use_sdpa,
layer_norm_type=layer_norm_type,
norm_eps=norm_eps,
n_kv_head=n_kv_head,
head_dim=head_dim,
)

self.right_decoder = TransformerDecoder(
@@ -384,6 +390,8 @@ def __init__(
use_sdpa=use_sdpa,
layer_norm_type=layer_norm_type,
norm_eps=norm_eps,
n_kv_head=n_kv_head,
head_dim=head_dim,
)

def forward(
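
On the decoder side, n_kv_head and head_dim are simply threaded through to both the self-attention and the source-attention layers. A hedged construction sketch (the remaining TransformerDecoder arguments and their defaults are assumed, not shown in this diff; the sizes are illustrative):

from wenet.transformer.decoder import TransformerDecoder

decoder = TransformerDecoder(
    vocab_size=5000,
    encoder_output_size=256,
    attention_heads=4,
    n_kv_head=2,   # grouped-query attention in self- and source-attention
    head_dim=64,   # must equal encoder_output_size // attention_heads
)
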
11 changes: 9 additions & 2 deletions wenet/transformer/encoder.py
@@ -14,7 +14,7 @@
# limitations under the License.
# Modified from ESPnet(https://github.com/espnet/espnet)
"""Encoder definition."""
from typing import Tuple
from typing import Optional, Tuple

import torch
import torch.utils.checkpoint as ckpt
@@ -375,6 +375,8 @@ def __init__(
mlp_type: str = 'position_wise_feed_forward',
layer_norm_type: str = 'layer_norm',
norm_eps: float = 1e-5,
n_kv_head: Optional[int] = None,
head_dim: Optional[int] = None,
):
""" Construct TransformerEncoder
@@ -396,7 +398,8 @@ def __init__(
output_size,
attention_dropout_rate,
query_bias, key_bias,
value_bias, use_sdpa),
value_bias, use_sdpa,
n_kv_head, head_dim),
mlp_class(output_size, linear_units, dropout_rate, activation,
mlp_bias),
dropout_rate,
@@ -445,6 +448,8 @@ def __init__(
mlp_type: str = 'position_wise_feed_forward',
layer_norm_type: str = 'layer_norm',
norm_eps: float = 1e-5,
n_kv_head: Optional[int] = None,
head_dim: Optional[int] = None,
):
"""Construct ConformerEncoder
@@ -481,6 +486,8 @@
key_bias,
value_bias,
use_sdpa,
n_kv_head,
head_dim,
)
# feed-forward module definition
positionwise_layer_args = (
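
On the encoder side, the new arguments end up in encoder_selfattn_layer_args exactly as shown above. A construction sketch under the same assumptions (remaining ConformerEncoder arguments keep their defaults; the sizes are illustrative):

from wenet.transformer.encoder import ConformerEncoder

encoder = ConformerEncoder(
    input_size=80,   # e.g. 80-dim fbank features
    output_size=256,
    attention_heads=4,
    n_kv_head=2,     # grouped-query attention in every encoder layer
    head_dim=64,     # must equal output_size // attention_heads
)
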
