[ssl/w2vbert] weight copy from meta w2vbert-2.0
Mddct committed Mar 7, 2024
1 parent 1a6dcfe commit f3e9bf1
Showing 3 changed files with 157 additions and 2 deletions.
28 changes: 26 additions & 2 deletions test/wenet/transformer/test_attention.py
@@ -1,7 +1,8 @@
import torch
import pytest
from wenet.transformer.attention import (MultiHeadedAttention,
RelPositionMultiHeadedAttention)
from wenet.transformer.attention import (
IndicesRelPositionMultiHeadedAttention, MultiHeadedAttention,
RelPositionMultiHeadedAttention)
from wenet.transformer.embedding import RelPositionalEncoding
from wenet.transformer.encoder_layer import (ConformerEncoderLayer,
TransformerEncoderLayer)
@@ -221,3 +222,26 @@ def test_rel_position_multi_head_attention_sdpa(args):
)
assert torch.allclose(cache, cache_with_sdpa)
q = output


def test_indices_rel_position_multihead_attention():
torch.manual_seed(777)
module = IndicesRelPositionMultiHeadedAttention(8,
256,
0.0,
use_sdpa=False)

torch.manual_seed(777)
module_sdpa = IndicesRelPositionMultiHeadedAttention(8,
256,
0.0,
use_sdpa=True)
q = torch.rand(2, 10, 256)
k = torch.rand(2, 10, 256)
v = torch.rand(2, 10, 256)
pos_emb = torch.zeros(0, 0, 0)
mask = torch.ones(2, 10, 10)
out, _ = module(q, k, v, mask, pos_emb)
out_sdpa, _ = module_sdpa(q, k, v, mask, pos_emb)

assert torch.allclose(out, out_sdpa)
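
A note on the mask argument (an illustrative aside, not part of the commit): the use_sdpa=True branch of IndicesRelPositionMultiHeadedAttention asserts that the mask is not boolean, so a boolean padding mask would first be turned into an additive float bias, roughly as sketched below. The conversion shown here is an assumption for illustration, not a wenet utility.

import torch

bool_mask = torch.ones(2, 10, 10, dtype=torch.bool)  # True = attend, False = padded
bias_mask = torch.zeros(2, 10, 10)
bias_mask = bias_mask.masked_fill(~bool_mask, torch.finfo(torch.float32).min)
# bias_mask is 0.0 where attention is allowed and dtype-min where it is masked,
# which is the float-mask convention the use_sdpa=True path checks for
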
76 changes: 76 additions & 0 deletions wenet/transformer/attention.py
@@ -443,3 +443,79 @@ def forward(
output_shape = torch.Size([B * Beams]) + output.size()[2:]
output = output.view(output_shape)
return output, new_cache


class IndicesRelPositionMultiHeadedAttention(MultiHeadedAttention):

def __init__(self,
n_head: int,
n_feat: int,
dropout_rate: float,
key_bias: bool = True,
use_sdpa: bool = False):
super().__init__(n_head, n_feat, dropout_rate, key_bias, use_sdpa)
# TODO(Mddct): 64 8 1 as args
self.max_right_rel_pos = 64
self.max_left_rel_pos = 8
self.rel_k_embed = torch.nn.Embedding(
self.max_left_rel_pos + self.max_right_rel_pos + 1, self.d_k)

def _relative_indices(self, length: int, device: torch.device):
indices = torch.arange(length, device=device).unsqueeze(0)
rel_indices = indices - indices.transpose(0, 1)
rel_indices = torch.clamp(rel_indices, -self.max_left_rel_pos,
self.max_right_rel_pos)
return rel_indices + self.max_left_rel_pos

def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
pos_emb: torch.Tensor = torch.empty(0),
cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
) -> Tuple[torch.Tensor, torch.Tensor]:
del pos_emb
q, k, v = self.forward_qkv(query, key, value)
if cache.size(0) > 0:
key_cache, value_cache = torch.split(cache,
cache.size(-1) // 2,
dim=-1)
k = torch.cat([key_cache, k], dim=2)
v = torch.cat([value_cache, v], dim=2)
new_cache = torch.cat((k, v), dim=-1)

rel_k = self.rel_k_embed(
self._relative_indices(key.size(1), query.device)) # (t2, t2, d_k)
rel_k = rel_k[-q.size(2):] # (t1, t2, d_k)
rel_k = rel_k.unsqueeze(0).unsqueeze(0)  # (1, 1, t1, t2, d_k)
q_expand = q.unsqueeze(3)  # (batch, h, t1, 1, d_k)
# dot product of each query with its clipped relative-position embeddings
rel_att_weights = (rel_k * q_expand).sum(-1)  # (batch, h, t1, t2)

if not self.use_sdpa:
scores = (torch.matmul(q, k.transpose(-2, -1)) +
rel_att_weights) / math.sqrt(self.d_k)
return self.forward_attention(v, scores, mask), new_cache
else:
# NOTE(Mddct): the SDPA path expects an additive mask bias, not a boolean mask
assert mask.dtype != torch.bool
mask = mask.unsqueeze(1)
# keep the dtype-min bias at masked positions; elsewhere use the
# scaled relative-position term as the additive bias
mask = torch.where(mask == get_dtype_min(mask.dtype), mask,
rel_att_weights / math.sqrt(self.d_k))
output = torch.nn.functional.scaled_dot_product_attention(
q,
k,
v,
attn_mask=mask,
dropout_p=self.dropout_rate,
scale=1 / math.sqrt(self.d_k),
)
output = (output.transpose(1, 2).contiguous().view(
query.size(0), -1,
self.h * self.d_k)) # (batch, time1, d_model)
return self.linear_out(output), new_cache
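
For orientation, the clipped relative-position indices consumed by rel_k_embed can be reproduced in isolation. The snippet below is an illustrative aside, not part of the commit; the toy sizes (length 4, left clip 1, right clip 2, versus 8 and 64 in the class) are assumptions chosen to keep the matrix small.

import torch

max_left_rel_pos, max_right_rel_pos = 1, 2       # toy clip values; the class uses 8 and 64
length = 4
indices = torch.arange(length).unsqueeze(0)
rel_indices = indices - indices.transpose(0, 1)  # rel_indices[i][j] = j - i
rel_indices = torch.clamp(rel_indices, -max_left_rel_pos, max_right_rel_pos)
rel_indices = rel_indices + max_left_rel_pos     # shift into [0, left + right]
print(rel_indices)
# tensor([[1, 2, 3, 3],
#         [0, 1, 2, 3],
#         [0, 0, 1, 2],
#         [0, 0, 0, 1]])
# each entry indexes one row of the (left + right + 1, d_k) table in rel_k_embed
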
55 changes: 55 additions & 0 deletions wenet/transformer/subsampling.py
@@ -18,6 +18,8 @@

import torch

from wenet.utils.mask import make_non_pad_mask


class BaseSubsampling(torch.nn.Module):

@@ -332,3 +334,56 @@ def forward(
x = self.linear(x.transpose(1, 2).contiguous().view(b, t, c * f))
x, pos_emb = self.pos_enc(x, offset)
return x, pos_emb, x_mask[:, :, 2::2][:, :, 2::2][:, :, 2::2]


class StackNFramesSubsampling(BaseSubsampling):

def __init__(self,
idim: int,
odim: int,
dropout_rate: float,
pos_enc_class: torch.nn.Module,
stride: int = 2):

super().__init__()
del dropout_rate
self.pos_enc_class = pos_enc_class
self.stride = stride
self.idim = idim

self.layer_norm = torch.nn.LayerNorm(idim * stride, eps=1e-5)
self.out = torch.nn.Linear(idim * stride, odim)

def forward(
self,
x: torch.Tensor,
x_mask: torch.Tensor,
offset: Union[int, torch.Tensor] = 0
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Subsample x.
Args:
x (torch.Tensor): Input tensor (#batch, time, idim).
x_mask (torch.Tensor): Input mask (#batch, 1, time).
Returns:
torch.Tensor: Subsampled tensor (#batch, time', odim),
where time' = time // stride.
torch.Tensor: Subsampled mask (#batch, 1, time'),
where time' = time // stride.
torch.Tensor: positional encoding
"""
b, s, _ = x.size()
seq_len = x_mask.sum(-1).view(b)
r = s % self.stride
s -= r
x = x[:, :s, :]
seq_len = torch.where(seq_len > s, s, seq_len)
seq_len = seq_len // self.stride
num_frames = s // self.stride
# stack every `stride` consecutive frames along the feature dimension
x = x.reshape(b, num_frames, self.idim * self.stride)
new_mask = make_non_pad_mask(seq_len)
x = self.layer_norm(x)
_, pos_emb = self.pos_enc_class(x, offset)
x = self.out(x)
return x, pos_emb, new_mask.unsqueeze(1)
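
As a quick sanity check (an illustrative aside, not part of the commit), the core frame-stacking step of StackNFramesSubsampling can be exercised on its own; the sizes below (batch 2, 11 frames, idim 80, stride 2) are arbitrary.

import torch

b, t, idim, stride = 2, 11, 80, 2   # arbitrary sizes for illustration
x = torch.rand(b, t, idim)
s = t - t % stride                  # drop trailing frames so the length divides by stride
stacked = x[:, :s, :].reshape(b, s // stride, idim * stride)
print(stacked.shape)                # torch.Size([2, 5, 160])
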
