diff --git a/test/wenet/transformer/test_attention.py b/test/wenet/transformer/test_attention.py
index e1821cf3c..255f0352d 100644
--- a/test/wenet/transformer/test_attention.py
+++ b/test/wenet/transformer/test_attention.py
@@ -1,7 +1,8 @@
 import torch
 import pytest
-from wenet.transformer.attention import (MultiHeadedAttention,
-                                         RelPositionMultiHeadedAttention)
+from wenet.transformer.attention import (
+    IndicesRelPositionMultiHeadedAttention, MultiHeadedAttention,
+    RelPositionMultiHeadedAttention)
 from wenet.transformer.embedding import RelPositionalEncoding
 from wenet.transformer.encoder_layer import (ConformerEncoderLayer,
                                              TransformerEncoderLayer)
@@ -221,3 +222,26 @@ def test_rel_position_multi_head_attention_sdpa(args):
     )
     assert torch.allclose(cache, cache_with_sdpa)
     q = output
+
+
+def test_indices_rel_position_multihead_attention():
+    torch.manual_seed(777)
+    module = IndicesRelPositionMultiHeadedAttention(8,
+                                                    256,
+                                                    0.0,
+                                                    use_sdpa=False)
+
+    torch.manual_seed(777)
+    module_sdpa = IndicesRelPositionMultiHeadedAttention(8,
+                                                         256,
+                                                         0.0,
+                                                         use_sdpa=True)
+    q = torch.rand(2, 10, 256)
+    k = torch.rand(2, 10, 256)
+    v = torch.rand(2, 10, 256)
+    pos_emb = torch.zeros(0, 0, 0)
+    mask = torch.ones(2, 10, 10)
+    out, _ = module(q, k, v, mask, pos_emb)
+    out_sdpa, _ = module_sdpa(q, k, v, mask, pos_emb)
+
+    assert torch.allclose(out, out_sdpa, atol=9e-7)
diff --git a/wenet/transformer/attention.py b/wenet/transformer/attention.py
index d644b98f2..15fbc4782 100644
--- a/wenet/transformer/attention.py
+++ b/wenet/transformer/attention.py
@@ -443,3 +443,77 @@ def forward(
         output_shape = torch.Size([B * Beams]) + output.size()[2:]
         output = output.view(output_shape)
         return output, new_cache
+
+
+class IndicesRelPositionMultiHeadedAttention(MultiHeadedAttention):
+
+    def __init__(self,
+                 n_head: int,
+                 n_feat: int,
+                 dropout_rate: float,
+                 key_bias: bool = True,
+                 use_sdpa: bool = False):
+        super().__init__(n_head, n_feat, dropout_rate, key_bias, use_sdpa)
+        # TODO(Mddct): 64 8 1 as args
+        self.max_right_rel_pos = 64
+        self.max_left_rel_pos = 8
+        self.rel_k_embed = torch.nn.Embedding(
+            self.max_left_rel_pos + self.max_right_rel_pos + 1, self.d_k)
+
+    def _relative_indices(self, length: int, device: torch.device):
+        indices = torch.arange(length, device=device).unsqueeze(0)
+        rel_indices = indices - indices.transpose(0, 1)
+        rel_indices = torch.clamp(rel_indices, -self.max_left_rel_pos,
+                                  self.max_right_rel_pos)
+        return rel_indices + self.max_left_rel_pos
+
+    def forward(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
+        pos_emb: torch.Tensor = torch.empty(0),
+        cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        del pos_emb
+        q, k, v = self.forward_qkv(query, key, value)
+        if cache.size(0) > 0:
+            key_cache, value_cache = torch.split(cache,
+                                                 cache.size(-1) // 2,
+                                                 dim=-1)
+            k = torch.cat([key_cache, k], dim=2)
+            v = torch.cat([value_cache, v], dim=2)
+        new_cache = torch.cat((k, v), dim=-1)
+
+        rel_k = self.rel_k_embed(
+            self._relative_indices(k.size(2), query.device))  # (t2, t2, d_k)
+        rel_k = rel_k[-q.size(2):]  # (t1, t2, d_k)
+        # broadcast rel_k against q over batch and heads
+        rel_k = rel_k.unsqueeze(0).unsqueeze(0)  # (1, 1, t1, t2, d_k)
+        q_expand = q.unsqueeze(3)  # (batch, h, t1, 1, d_k)
+        rel_att_weights = (rel_k * q_expand).sum(-1)  # (batch, h, t1, t2)
+
+        if not self.use_sdpa:
+            scores = (torch.matmul(q, k.transpose(-2, -1)) +
+                      rel_att_weights) / math.sqrt(self.d_k)
+            return self.forward_attention(v, scores, mask), new_cache
+        else:
+            # NOTE(Mddct): we need mask bias, not boolean mask
+            assert mask.dtype != torch.bool
+            mask = mask.unsqueeze(1)
+            # matrix_bd as a mask bias
+            mask = torch.where(mask == get_dtype_min(mask.dtype), mask,
+                               rel_att_weights / math.sqrt(self.d_k))
+            output = torch.nn.functional.scaled_dot_product_attention(
+                q,
+                k,
+                v,
+                attn_mask=mask,
+                dropout_p=self.dropout_rate,
+                scale=1 / math.sqrt(self.d_k),
+            )
+            output = (output.transpose(1, 2).contiguous().view(
+                query.size(0), -1,
+                self.h * self.d_k))  # (batch, time1, d_model)
+            return self.linear_out(output), new_cache
diff --git a/wenet/transformer/subsampling.py b/wenet/transformer/subsampling.py
index 37588476a..f87a1f034 100644
--- a/wenet/transformer/subsampling.py
+++ b/wenet/transformer/subsampling.py
@@ -18,6 +18,8 @@
 
 
 import torch
 
+from wenet.utils.mask import make_non_pad_mask
+
 
 class BaseSubsampling(torch.nn.Module):
@@ -332,3 +334,56 @@ def forward(
         x = self.linear(x.transpose(1, 2).contiguous().view(b, t, c * f))
         x, pos_emb = self.pos_enc(x, offset)
         return x, pos_emb, x_mask[:, :, 2::2][:, :, 2::2][:, :, 2::2]
+
+
+class StackNFramesSubsampling(BaseSubsampling):
+
+    def __init__(self,
+                 idim: int,
+                 odim: int,
+                 dropout_rate: float,
+                 pos_enc_class: torch.nn.Module,
+                 stride: int = 2):
+
+        super().__init__()
+        del dropout_rate
+        self.pos_enc_class = pos_enc_class
+        self.stride = stride
+        self.idim = idim
+
+        self.layer_norm = torch.nn.LayerNorm(idim * stride, eps=1e-5)
+        self.out = torch.nn.Linear(idim * stride, odim)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        x_mask: torch.Tensor,
+        offset: Union[int, torch.Tensor] = 0
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Subsample x by stacking stride adjacent frames.
+
+        Args:
+            x (torch.Tensor): Input tensor (#batch, time, idim).
+            x_mask (torch.Tensor): Input mask (#batch, 1, time).
+
+        Returns:
+            torch.Tensor: Subsampled tensor (#batch, time', odim),
+                where time' = time // stride.
+            torch.Tensor: positional encoding.
+            torch.Tensor: Subsampled mask (#batch, 1, time'),
+                where time' = time // stride.
+        """
+        b, s, _ = x.size()
+        seq_len = x_mask.sum(-1).view(b)
+        r = s % self.stride
+        s -= r
+        x = x[:, :s, :]
+        seq_len = torch.where(seq_len > s, s, seq_len)
+        seq_len = seq_len // self.stride
+        num_frames = s // self.stride
+        x = x.view(b, num_frames, self.idim * self.stride)
+        new_mask = make_non_pad_mask(seq_len)
+        x = self.layer_norm(x)
+        _, pos_emb = self.pos_enc_class(x, offset)
+        x = self.out(x)
+        return x, pos_emb, new_mask.unsqueeze(1)
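
Usage sketch (not part of the diff above): a minimal forward pass wiring the two new modules together, assuming a WeNet checkout with this patch applied is importable. RelPositionalEncoding is reused here purely as an example pos_enc_class, and the 80-dim/256-dim shapes and batch sizes are illustrative values, not taken from the patch.

import torch

from wenet.transformer.attention import IndicesRelPositionMultiHeadedAttention
from wenet.transformer.embedding import RelPositionalEncoding
from wenet.transformer.subsampling import StackNFramesSubsampling

# Stack every 2 input frames: 80-dim fbank -> 160-dim stacked -> 256-dim output,
# halving the frame rate. dropout_rate is accepted but unused by this module.
subsampling = StackNFramesSubsampling(idim=80,
                                      odim=256,
                                      dropout_rate=0.0,
                                      pos_enc_class=RelPositionalEncoding(256, 0.0),
                                      stride=2)
# 8 heads over a 256-dim model; relative positions come from the module's own
# embedding table, so the pos_emb argument is ignored (del pos_emb in forward).
attention = IndicesRelPositionMultiHeadedAttention(8, 256, 0.0, use_sdpa=False)

feats = torch.rand(2, 100, 80)                         # (batch, time, idim)
feats_mask = torch.ones(2, 1, 100, dtype=torch.bool)   # (batch, 1, time)

x, pos_emb, mask = subsampling(feats, feats_mask)      # x: (2, 50, 256), mask: (2, 1, 50)
out, _ = attention(x, x, x, mask, pos_emb)             # self-attention: (2, 50, 256)
print(out.shape, mask.shape)

Note that with use_sdpa=True the attention module expects a float bias mask (dtype-min at masked positions) rather than a boolean mask, per the assert in its forward pass.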