From b2458ba007d7601bcd951210d5d39f30d523c7b8 Mon Sep 17 00:00:00 2001
From: Mddct
Date: Tue, 14 Mar 2023 23:15:35 +0800
Subject: [PATCH 1/2] [ssl] bestrq support

---
 wenet/ssl/bestrq/bestqr_model.py | 144 +++++++++++++++++++++++++++++++
 wenet/ssl/bestrq/mask.py | 53 ++++++++++++
 2 files changed, 197 insertions(+)
 create mode 100644 wenet/ssl/bestrq/bestqr_model.py
 create mode 100644 wenet/ssl/bestrq/mask.py

diff --git a/wenet/ssl/bestrq/bestqr_model.py b/wenet/ssl/bestrq/bestqr_model.py
new file mode 100644
index 000000000..7ecf3c642
--- /dev/null
+++ b/wenet/ssl/bestrq/bestqr_model.py
@@ -0,0 +1,144 @@
+from typing import Optional, Tuple
+import torch
+
+from wenet.ssl.bestrq.mask import compute_mask_indices
+from wenet.utils.mask import make_pad_mask
+
+
+class BestRQModel(torch.nn.Module):
+
+    def __init__(
+        self,
+        encoder: torch.nn.Module,
+        input_dim: int = 256,
+        embedding_dim: int = 256,
+        num_embeddings: int = 8192,
+        dropout_rate: float = 0.1,
+        mask_prob: float = 0.01,
+        mask_length: int = 10,
+        min_masks: int = 2,
+        layer_norm_epsilon=1e-5,
+    ) -> None:
+        super().__init__()
+
+        assert mask_prob > 0.0
+
+        self.mask_prob = mask_prob
+        # NOTE: should filter audio less than mask_length
+        self.mask_length = mask_length
+        self.min_masks = min_masks
+
+        self.input_dropout = torch.nn.Dropout(dropout_rate)
+
+        # [embedding_dim, num_embeddings]
+        random_embedding_weight = torch.empty(embedding_dim,
+                                              num_embeddings,
+                                              requires_grad=False)
+        self.embeddings = torch.nn.init.normal_(random_embedding_weight)
+
+        random_projection_weight = torch.empty(input_dim,
+                                               embedding_dim,
+                                               requires_grad=False)
+        self.projection = torch.nn.init.xavier_normal_(
+            random_projection_weight)
+
+        mask_emb_weight = torch.Tensor(input_dim)
+        mask_emb_weight.requires_grad = True
+        self.mask_emb = torch.nn.init.normal_(mask_emb_weight, mean=0, std=0.1)
+
+        self.input_layer_norm = torch.nn.LayerNorm(input_dim,
+                                                   layer_norm_epsilon)
+        self.encoder = encoder
+        self.encoder_top_linear = torch.nn.Linear(self.encoder.output_size(),
+                                                  num_embeddings)
+
+    def forward(
+        self,
+        xs: torch.Tensor,
+        xs_lens: torch.Tensor,
+        text: Optional[torch.Tensor] = None,
+        text_length: Optional[torch.Tensor] = None,
+    ):
+        # should support non-streaming and streaming
+        # TODO(Mddct): streaming future
+        # eg: full attention and chunk or dynamic chunk training
+        # 1 forward subsampling
+        xs, pos_emb, masks = self._forward_subsampling(xs, xs_lens)
+        unmasked_xs = xs
+        # 2 mask features
+        # 2.0 apply mask
+        masked_xs, masked_masks = self._apply_mask(xs)
+        # 2.1 get nearest embedding
+        target_ids = self._nearest_embedding_idx(unmasked_xs)
+        # 3 forward xxx-former block
+        out, out_mask = self._forward_encoder_blocks(masked_xs, masks, pos_emb,
+                                                     masks)
+        # 4 get logits
+        out = self.encoder_top_linear(out)  # [B, T', num_embedding]
+
+        # 5 compute loss
+        loss = self._compute_loss(out, target_ids,
+                                  out_mask.squeeze(1) * masked_masks)
+        return {"loss": loss}
+
+    def _compute_loss(self, input: torch.Tensor, target: torch.Tensor,
+                      mask: torch.Tensor):
+        input = input.transpose(1, 2)  # [B,C,T]
+        entropy = torch.nn.functional.cross_entropy(input,
+                                                    target,
+                                                    reduction='none')  # [B,T]
+        # stop gradient for non mask area
+        loss = entropy * mask
+        return loss.sum() / loss.size(0)
+
+    def _forward_encoder_blocks(self, xs: torch.Tensor, xs_masks: torch.Tensor,
+                                pos_emb: torch.Tensor, mask_pad: torch.Tensor):
+
+        masks = xs_masks
+        for layer in self.encoder.encoders:
+            xs, masks, _, _ = layer(xs, xs_masks, pos_emb, mask_pad)
+        if self.encoder.normalize_before:
+            xs = self.encoder.after_norm(xs)
+        # Here we assume the mask is not changed in encoder layers, so just
+        # return the masks before encoder layers, and the masks will be used
+        # for cross attention with decoder later
+        return xs, masks
+
+    def _nearest_embedding_idx(self, xs: torch.Tensor) -> torch.Tensor:
+        xs = self.input_layer_norm(xs)
+        xs = self.input_dropout(xs)
+        xs = torch.matmul(xs, self.projection.to(xs.device))
+
+        B, T, C = xs.size()
+        flattened_input = xs.view(-1, C)
+        embeddings = self.embeddings.to(xs.device)
+        distance = (torch.sum(flattened_input**2, dim=1, keepdim=True) +
+                    torch.sum(embeddings**2, dim=0, keepdim=False) -
+                    2 * torch.matmul(flattened_input, embeddings))
+
+        out = torch.argmin(distance, dim=-1)
+        return out.reshape(B, T)
+
+    def _apply_mask(self,
+                    xs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        masks = compute_mask_indices(xs.size()[:-1],
+                                     self.mask_prob,
+                                     self.mask_length,
+                                     self.min_masks,
+                                     device=xs.device)
+        masks_expand = masks.unsqueeze(-1)  # [B, T, 1]
+
+        mask_emb = self.mask_emb.to(xs.device).view(1, 1, -1)
+        xs = torch.where(masks_expand, mask_emb, xs)
+        return xs, masks
+
+    def _forward_subsampling(
+        self, xs: torch.Tensor, xs_lens: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+
+        T = xs.size(1)
+        masks = ~make_pad_mask(xs_lens, T).unsqueeze(1)  # (B, 1, T)
+        if self.encoder.global_cmvn is not None:
+            xs = self.encoder.global_cmvn(xs)
+        xs, pos_emb, masks = self.encoder.embed(xs, masks)
+        return xs, pos_emb, masks
diff --git a/wenet/ssl/bestrq/mask.py b/wenet/ssl/bestrq/mask.py
new file mode 100644
index 000000000..f5e62eea0
--- /dev/null
+++ b/wenet/ssl/bestrq/mask.py
@@ -0,0 +1,53 @@
+import torch
+
+
+def _sampler(pdf: torch.Tensor, num_samples: int,
+             device=torch.device('cpu')) -> torch.Tensor:
+    size = pdf.size()
+    z = -torch.log(torch.rand(size, device=device))
+    _, indices = torch.topk(pdf + z, num_samples)
+    return indices
+
+
+def compute_mask_indices(
+        size: torch.Size,
+        mask_prob: float,
+        mask_length: int,
+        min_masks: int = 0,
+        device=torch.device('cpu')) -> torch.Tensor:
+
+    assert len(size) == 2
+    batch_size, seq_length = size
+
+    # compute number of masked span in batch
+    num_masked_spans = mask_prob * float(seq_length) / float(
+        mask_length) + torch.rand(1)[0]
+    num_masked_spans = int(num_masked_spans)
+    num_masked_spans = max(num_masked_spans, min_masks)
+
+    # num_masked <= seq_length
+    if num_masked_spans * mask_length > seq_length:
+        num_masked_spans = seq_length // mask_length
+
+    pdf = torch.ones(batch_size, seq_length - (mask_length - 1), device=device)
+    mask_idxs = _sampler(pdf, num_masked_spans, device=device)
+
+    mask_idxs = mask_idxs.unsqueeze(-1).repeat(1, 1, mask_length).view(
+        batch_size,
+        num_masked_spans * mask_length)  # [B,num_masked_spans*mask_length]
+
+    offset = torch.arange(mask_length, device=device).view(1, 1, -1).repeat(
+        1, num_masked_spans, 1)  # [1,num_masked_spans,mask_length]
+    offset = offset.view(1, num_masked_spans * mask_length)
+
+    mask_idxs = mask_idxs + offset  # [B,num_masked_spans, mask_length]
+
+    ones = torch.ones(batch_size,
+                      seq_length,
+                      dtype=torch.bool,
+                      device=mask_idxs.device)
+    # masks to fill
+    full_mask = torch.zeros_like(ones,
+                                 dtype=torch.bool,
+                                 device=mask_idxs.device)
+    return torch.scatter(full_mask, dim=1, index=mask_idxs, src=ones)

From 32fdf4e07c24560b0156eff258b9d14512d32c29 Mon Sep 17 00:00:00 2001
From: Mddct
Date: Wed, 15 Mar 2023 10:59:36 +0800
Subject: [PATCH 2/2] fix lint

---
 wenet/ssl/bestrq/mask.py | 107 ++++++++++++++++++++-------------------
 1 file changed, 54 insertions(+), 53 deletions(-)

diff --git a/wenet/ssl/bestrq/mask.py b/wenet/ssl/bestrq/mask.py
index f5e62eea0..c8905953e 100644
--- a/wenet/ssl/bestrq/mask.py
+++ b/wenet/ssl/bestrq/mask.py
@@ -1,53 +1,54 @@
-import torch
-
-
-def _sampler(pdf: torch.Tensor, num_samples: int,
-             device=torch.device('cpu')) -> torch.Tensor:
-    size = pdf.size()
-    z = -torch.log(torch.rand(size, device=device))
-    _, indices = torch.topk(pdf + z, num_samples)
-    return indices
-
-
-def compute_mask_indices(
-        size: torch.Size,
-        mask_prob: float,
-        mask_length: int,
-        min_masks: int = 0,
-        device=torch.device('cpu')) -> torch.Tensor:
-
-    assert len(size) == 2
-    batch_size, seq_length = size
-
-    # compute number of masked span in batch
-    num_masked_spans = mask_prob * float(seq_length) / float(
-        mask_length) + torch.rand(1)[0]
-    num_masked_spans = int(num_masked_spans)
-    num_masked_spans = max(num_masked_spans, min_masks)
-
-    # num_masked <= seq_length
-    if num_masked_spans * mask_length > seq_length:
-        num_masked_spans = seq_length // mask_length
-
-    pdf = torch.ones(batch_size, seq_length - (mask_length - 1), device=device)
-    mask_idxs = _sampler(pdf, num_masked_spans, device=device)
-
-    mask_idxs = mask_idxs.unsqueeze(-1).repeat(1, 1, mask_length).view(
-        batch_size,
-        num_masked_spans * mask_length)  # [B,num_masked_spans*mask_length]
-
-    offset = torch.arange(mask_length, device=device).view(1, 1, -1).repeat(
-        1, num_masked_spans, 1)  # [1,num_masked_spans,mask_length]
-    offset = offset.view(1, num_masked_spans * mask_length)
-
-    mask_idxs = mask_idxs + offset  # [B,num_masked_spans, mask_length]
-
-    ones = torch.ones(batch_size,
-                      seq_length,
-                      dtype=torch.bool,
-                      device=mask_idxs.device)
-    # masks to fill
-    full_mask = torch.zeros_like(ones,
-                                 dtype=torch.bool,
-                                 device=mask_idxs.device)
-    return torch.scatter(full_mask, dim=1, index=mask_idxs, src=ones)
+import torch
+
+
+def _sampler(pdf: torch.Tensor, num_samples: int,
+             device=torch.device('cpu')) -> torch.Tensor:
+    size = pdf.size()
+    z = -torch.log(torch.rand(size, device=device))
+    _, indices = torch.topk(pdf + z, num_samples)
+    return indices
+
+
+def compute_mask_indices(
+        size: torch.Size,
+        mask_prob: float,
+        mask_length: int,
+        min_masks: int = 0,
+        device=torch.device('cpu'),
+) -> torch.Tensor:
+
+    assert len(size) == 2
+    batch_size, seq_length = size
+
+    # compute number of masked span in batch
+    num_masked_spans = mask_prob * float(seq_length) / float(
+        mask_length) + torch.rand(1)[0]
+    num_masked_spans = int(num_masked_spans)
+    num_masked_spans = max(num_masked_spans, min_masks)
+
+    # num_masked <= seq_length
+    if num_masked_spans * mask_length > seq_length:
+        num_masked_spans = seq_length // mask_length
+
+    pdf = torch.ones(batch_size, seq_length - (mask_length - 1), device=device)
+    mask_idxs = _sampler(pdf, num_masked_spans, device=device)
+
+    mask_idxs = mask_idxs.unsqueeze(-1).repeat(1, 1, mask_length).view(
+        batch_size,
+        num_masked_spans * mask_length)  # [B,num_masked_spans*mask_length]
+
+    offset = torch.arange(mask_length, device=device).view(1, 1, -1).repeat(
+        1, num_masked_spans, 1)  # [1,num_masked_spans,mask_length]
+    offset = offset.view(1, num_masked_spans * mask_length)
+
+    mask_idxs = mask_idxs + offset  # [B,num_masked_spans, mask_length]
+
+    ones = torch.ones(batch_size,
+                      seq_length,
+                      dtype=torch.bool,
+                      device=mask_idxs.device)
+    # masks to fill
+    full_mask = torch.zeros_like(ones,
+                                 dtype=torch.bool,
+                                 device=mask_idxs.device)
+    return torch.scatter(full_mask, dim=1, index=mask_idxs, src=ones)
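
The heart of the patch is the BEST-RQ target computation in _nearest_embedding_idx: subsampled features are projected by a frozen random matrix, compared against a frozen random codebook by squared euclidean distance, and the per-frame argmin index becomes the cross-entropy target. Below is a minimal standalone sketch of that lookup; the dimensions are illustrative, and the LayerNorm and dropout the model applies before projecting are omitted for brevity.

import torch

torch.manual_seed(0)

input_dim, embedding_dim, num_embeddings = 256, 256, 8192
# Frozen random projection and codebook, as in BestRQModel.__init__.
projection = torch.nn.init.xavier_normal_(torch.empty(input_dim, embedding_dim))
codebook = torch.nn.init.normal_(torch.empty(embedding_dim, num_embeddings))

# Dummy subsampled encoder features: batch of 2, 50 frames, input_dim channels.
xs = torch.randn(2, 50, input_dim)

projected = xs @ projection                      # [B, T, embedding_dim]
flat = projected.reshape(-1, embedding_dim)      # [B*T, embedding_dim]
# Squared euclidean distance to every codebook entry, expanded as
# |x|^2 + |e|^2 - 2*x.e, exactly like the model code.
distance = (flat.pow(2).sum(1, keepdim=True) +
            codebook.pow(2).sum(0, keepdim=True) -
            2 * flat @ codebook)                 # [B*T, num_embeddings]
target_ids = distance.argmin(-1).reshape(2, 50)  # discrete BEST-RQ targets
print(target_ids.shape)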
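
compute_mask_indices in wenet/ssl/bestrq/mask.py draws span start positions per utterance by adding random noise to a uniform score and taking top-k (sampling without replacement), then expands each start into mask_length consecutive frames. A small usage sketch, assuming the patch is applied; the mask_prob, mask_length, and random mask embedding below are illustrative values, not tuned ones.

import torch

from wenet.ssl.bestrq.mask import compute_mask_indices

B, T, C = 2, 100, 256
xs = torch.randn(B, T, C)

# Boolean [B, T] mask marking the frames to replace, sampled span-wise.
masks = compute_mask_indices(xs.size()[:-1],
                             mask_prob=0.01,
                             mask_length=10,
                             min_masks=2)
# Replace masked frames with a mask embedding, as in BestRQModel._apply_mask;
# the real model uses a learned embedding rather than this random one.
mask_emb = torch.randn(C).view(1, 1, -1)
masked_xs = torch.where(masks.unsqueeze(-1), mask_emb, xs)
print(masks.sum(dim=1))  # number of masked frames per utterance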
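
End to end, BestRQModel wraps any wenet encoder that exposes embed, encoders, normalize_before/after_norm, and output_size(). A hedged sketch of driving it with a stock ConformerEncoder follows; the encoder constructor arguments and the 80-dim fbank shapes are assumptions for illustration, not values taken from this patch, and the import path follows the bestqr_model.py spelling used by the patch.

import torch

from wenet.ssl.bestrq.bestqr_model import BestRQModel
from wenet.transformer.encoder import ConformerEncoder

# 80-dim fbank in, 256-dim encoder out (ConformerEncoder defaults assumed).
encoder = ConformerEncoder(input_size=80, output_size=256, num_blocks=2)
model = BestRQModel(encoder, input_dim=256)

xs = torch.randn(2, 200, 80)        # [B, T, feat] dummy fbank features
xs_lens = torch.tensor([200, 160])  # valid number of frames per utterance

out = model(xs, xs_lens)
print(out["loss"])                  # scalar BEST-RQ cross-entropy loss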