[ssl] bestrq support #1750

Merged 2 commits on Mar 15, 2023
144 changes: 144 additions & 0 deletions wenet/ssl/bestrq/bestqr_model.py
@@ -0,0 +1,144 @@
from typing import Optional, Tuple
import torch

from wenet.ssl.bestrq.mask import compute_mask_indices
from wenet.utils.mask import make_pad_mask


class BestRQModel(torch.nn.Module):

def __init__(
self,
encoder: torch.nn.Module,
input_dim: int = 256,
embedding_dim: int = 256,
num_embeddings: int = 8192,
dropout_rate: float = 0.1,
mask_prob: float = 0.01,
mask_length: int = 10,
min_masks: int = 2,
        layer_norm_epsilon: float = 1e-5,
) -> None:
super().__init__()

assert mask_prob > 0.0

self.mask_prob = mask_prob
        # NOTE: audio shorter than mask_length frames should be filtered out beforehand
self.mask_length = mask_length
self.min_masks = min_masks

self.input_dropout = torch.nn.Dropout(dropout_rate)

        # frozen random codebook, stored as [embedding_dim, num_embeddings]
random_embedding_weight = torch.empty(embedding_dim,
num_embeddings,
requires_grad=False)
self.embeddings = torch.nn.init.normal_(random_embedding_weight)

        # frozen random projection [input_dim, embedding_dim], i.e. the
        # BEST-RQ random-projection quantizer
        random_projection_weight = torch.empty(input_dim,
                                               embedding_dim,
                                               requires_grad=False)
        self.projection = torch.nn.init.xavier_normal_(
            random_projection_weight)

        # learnable embedding that replaces the features at masked positions
        mask_emb_weight = torch.Tensor(input_dim)
        mask_emb_weight.requires_grad = True
        self.mask_emb = torch.nn.init.normal_(mask_emb_weight, mean=0, std=0.1)

self.input_layer_norm = torch.nn.LayerNorm(input_dim,
layer_norm_epsilon)
self.encoder = encoder
self.encoder_top_linear = torch.nn.Linear(self.encoder.output_size(),
num_embeddings)

def forward(
self,
xs: torch.Tensor,
xs_lens: torch.Tensor,
text: Optional[torch.Tensor] = None,
text_length: Optional[torch.Tensor] = None,
):
        # should support both non-streaming and streaming modes
        # TODO(Mddct): add streaming support in the future,
        # e.g. full attention vs. chunk or dynamic-chunk training
# 1 forward subsampling
xs, pos_emb, masks = self._forward_subsampling(xs, xs_lens)
unmasked_xs = xs
# 2 mask features
# 2.0 apply mask
masked_xs, masked_masks = self._apply_mask(xs)
# 2.1 get nearest embedding
target_ids = self._nearest_embedding_idx(unmasked_xs)
        # 3 forward the xxx-former (e.g. conformer/transformer) encoder blocks
out, out_mask = self._forward_encoder_blocks(masked_xs, masks, pos_emb,
masks)
# 4 get logits
        out = self.encoder_top_linear(out)  # [B, T', num_embeddings]

# 5 compute loss
loss = self._compute_loss(out, target_ids,
out_mask.squeeze(1) * masked_masks)
return {"loss": loss}

def _compute_loss(self, input: torch.Tensor, target: torch.Tensor,
mask: torch.Tensor):
input = input.transpose(1, 2) # [B,C,T]
entropy = torch.nn.functional.cross_entropy(input,
target,
reduction='none') # [B,T]
        # zero out the loss (and its gradients) on unmasked and padded positions
        loss = entropy * mask
        # normalize by batch size rather than by the number of masked frames
        return loss.sum() / loss.size(0)

def _forward_encoder_blocks(self, xs: torch.Tensor, xs_masks: torch.Tensor,
pos_emb: torch.Tensor, mask_pad: torch.Tensor):

masks = xs_masks
for layer in self.encoder.encoders:
xs, masks, _, _ = layer(xs, xs_masks, pos_emb, mask_pad)
if self.encoder.normalize_before:
xs = self.encoder.after_norm(xs)
        # Here we assume the mask is not changed in the encoder layers, so we
        # just return the mask computed before the encoder blocks.
return xs, masks

def _nearest_embedding_idx(self, xs: torch.Tensor) -> torch.Tensor:
xs = self.input_layer_norm(xs)
xs = self.input_dropout(xs)
xs = torch.matmul(xs, self.projection.to(xs.device))

B, T, C = xs.size()
flattened_input = xs.view(-1, C)
embeddings = self.embeddings.to(xs.device)
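        # squared Euclidean distance to every codebook entry, expanded as
        # ||x - e||^2 = ||x||^2 + ||e||^2 - 2 * x . e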
distance = (torch.sum(flattened_input**2, dim=1, keepdim=True) +
torch.sum(embeddings**2, dim=0, keepdim=False) -
2 * torch.matmul(flattened_input, embeddings))

out = torch.argmin(distance, dim=-1)
return out.reshape(B, T)

def _apply_mask(self,
xs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
masks = compute_mask_indices(xs.size()[:-1],
self.mask_prob,
self.mask_length,
self.min_masks,
device=xs.device)
masks_expand = masks.unsqueeze(-1) # [B, T, 1]

mask_emb = self.mask_emb.to(xs.device).view(1, 1, -1)
xs = torch.where(masks_expand, mask_emb, xs)
return xs, masks

def _forward_subsampling(
self, xs: torch.Tensor, xs_lens: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:

T = xs.size(1)
masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T)
if self.encoder.global_cmvn is not None:
xs = self.encoder.global_cmvn(xs)
xs, pos_emb, masks = self.encoder.embed(xs, masks)
return xs, pos_emb, masks
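For reference (not part of this diff): a minimal usage sketch of the new model. It assumes the existing wenet ConformerEncoder (default output_size of 256) and 80-dim fbank features; the shapes and numbers below are illustrative only.

import torch
from wenet.transformer.encoder import ConformerEncoder
from wenet.ssl.bestrq.bestqr_model import BestRQModel

encoder = ConformerEncoder(input_size=80)  # output_size() defaults to 256
# input_dim must match encoder.output_size(): the projection and the mask
# embedding are applied to the subsampled features, not the raw fbank input
model = BestRQModel(encoder=encoder, input_dim=encoder.output_size())

xs = torch.randn(4, 200, 80)                  # (B, T, feat_dim) dummy features
xs_lens = torch.tensor([200, 180, 160, 150])  # valid frames per utterance
loss = model(xs, xs_lens)["loss"]
loss.backward()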
54 changes: 54 additions & 0 deletions wenet/ssl/bestrq/mask.py
@@ -0,0 +1,54 @@
import torch


def _sampler(pdf: torch.Tensor, num_samples: int,
device=torch.device('cpu')) -> torch.Tensor:
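    # perturb the scores with random noise and keep the top-k, so each row yields
    # num_samples distinct indices (a uniform draw without replacement when pdf
    # is constant, as in compute_mask_indices below)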
size = pdf.size()
z = -torch.log(torch.rand(size, device=device))
_, indices = torch.topk(pdf + z, num_samples)
return indices


def compute_mask_indices(
size: torch.Size,
mask_prob: float,
mask_length: int,
min_masks: int = 0,
device=torch.device('cpu'),
) -> torch.Tensor:

assert len(size) == 2
batch_size, seq_length = size

    # compute the number of masked spans in the batch
num_masked_spans = mask_prob * float(seq_length) / float(
mask_length) + torch.rand(1)[0]
num_masked_spans = int(num_masked_spans)
num_masked_spans = max(num_masked_spans, min_masks)

    # ensure num_masked_spans * mask_length <= seq_length
if num_masked_spans * mask_length > seq_length:
num_masked_spans = seq_length // mask_length

pdf = torch.ones(batch_size, seq_length - (mask_length - 1), device=device)
mask_idxs = _sampler(pdf, num_masked_spans, device=device)

mask_idxs = mask_idxs.unsqueeze(-1).repeat(1, 1, mask_length).view(
batch_size,
num_masked_spans * mask_length) # [B,num_masked_spans*mask_length]

offset = torch.arange(mask_length, device=device).view(1, 1, -1).repeat(
1, num_masked_spans, 1) # [1,num_masked_spans,mask_length]
offset = offset.view(1, num_masked_spans * mask_length)

    mask_idxs = mask_idxs + offset  # [B, num_masked_spans * mask_length]

ones = torch.ones(batch_size,
seq_length,
dtype=torch.bool,
device=mask_idxs.device)
# masks to fill
full_mask = torch.zeros_like(ones,
dtype=torch.bool,
device=mask_idxs.device)
return torch.scatter(full_mask, dim=1, index=mask_idxs, src=ones)
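For reference (not part of this diff): a quick shape check of compute_mask_indices with arbitrary numbers.

import torch
from wenet.ssl.bestrq.mask import compute_mask_indices

# 4 utterances of 50 frames; spans of length 10, at least 2 spans per row
masks = compute_mask_indices(torch.Size((4, 50)),
                             mask_prob=0.065,
                             mask_length=10,
                             min_masks=2)
print(masks.shape, masks.dtype)  # torch.Size([4, 50]) torch.bool
print(masks.sum(dim=1))          # at most 2 * 10 masked frames per row (spans may overlap)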