From 4db7a2a4747b861b780e5659fe80027e0651f18f Mon Sep 17 00:00:00 2001
From: Mddct
Date: Sun, 8 Oct 2023 20:09:01 +0800
Subject: [PATCH] [ssl/wav2vec2] fix lint

---
 wenet/ssl/wav2vec2/wav2vec2_model.py | 63 +++++++++++++++-------------
 1 file changed, 35 insertions(+), 28 deletions(-)

diff --git a/wenet/ssl/wav2vec2/wav2vec2_model.py b/wenet/ssl/wav2vec2/wav2vec2_model.py
index 80347529c..09d059e0c 100644
--- a/wenet/ssl/wav2vec2/wav2vec2_model.py
+++ b/wenet/ssl/wav2vec2/wav2vec2_model.py
@@ -12,41 +12,47 @@
 from wenet.utils.mask import make_non_pad_mask
 
 
-def _sample_negative_indices(features_shape: torch.Size,
+def _sample_negative_indices(features_shape: Tuple,
                              num_negatives: int,
-                             mask: Optional[torch.Tensor] = None):
+                             device: torch.device,
+                             mask_time_indices: Optional[torch.Tensor] = None):
     """
     Sample `num_negatives` vectors from feature vectors.
     """
-    batch_size, sequence_length, _ = features_shape
-    assert sequence_length > 1
+    batch_size, sequence_length = features_shape
+
+    sequence_length_range = torch.arange(sequence_length)
+
     # get `num_negatives` random vector indices from the same utterance
-    sampled_negative_indices = []
+    sampled_negative_indices = torch.zeros(
+        (batch_size, sequence_length, num_negatives),
+        dtype=sequence_length_range.dtype,
+        device=device)
+
+    mask_time_indices = (mask_time_indices.bool() if mask_time_indices
+                         is not None else torch.ones(features_shape,
+                                                     dtype=torch.bool))
+
     for batch_idx in range(batch_size):
-        high = mask[batch_idx].sum(
-        ) - 1 if mask is not None else sequence_length - 1
-        sampled_indices_slice = torch.randint(0,
-                                              high,
-                                              size=(num_negatives *
-                                                    sequence_length, ))
-        sampled_negative_indices.append(sampled_indices_slice)
-
-    sampled_negative_indices = torch.stack(sampled_negative_indices, dim=0).to(
-        torch.int32)  # [B, num_negatives * sequence_length]
-
-    # generate indices of the positive vectors themselves,
-    # repeat them `num_negatives` times
-    feature_indices = torch.arange(sequence_length)[:, None].repeat(
-        1, num_negatives).flatten()  # [B x num_negatives x sequence_length]
-
-    # avoid sampling the same positive vector, but keep the distribution uniform
-    sampled_negative_indices[sampled_negative_indices >= feature_indices] += 1
-
-    # correct for batch size
-    for batch_idx in range(1, batch_size):
+        high = mask_time_indices[batch_idx].sum() - 1
+        mapped_masked_indices = sequence_length_range[
+            mask_time_indices[batch_idx]]
+
+        feature_indices = torch.arange(high + 1).unsqueeze(1).expand(
+            high + 1, num_negatives)
+        sampled_indices = torch.randint(0,
+                                        high,
+                                        size=(high + 1, num_negatives))
+        sampled_indices[sampled_indices >= feature_indices] += 1
+
+        # remap to actual indices
+        sampled_negative_indices[batch_idx][mask_time_indices[
+            batch_idx]] = mapped_masked_indices[sampled_indices]
+
+        # correct for batch size
        sampled_negative_indices[batch_idx] += batch_idx * sequence_length
 
-    return sampled_negative_indices
+    return sampled_negative_indices.reshape(batch_size, -1)
 
 
 def _compute_contrastive_loss(quantized_features: torch.Tensor,
@@ -241,7 +247,8 @@ def forward(
             unmasked_xs, masks.squeeze(1), gumbel_temperature)
 
         sampled_negative_indices = _sample_negative_indices(
-            xs.size(), self.num_negatives, masks.squeeze(1))
+            xs.size()[:-1], self.num_negatives, masked_masks.device,
+            masked_masks)
 
         loss_contrastive = _compute_contrastive_loss(
            quantized_features, out, sampled_negative_indices, masked_masks,
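
Note for reviewers: below is a minimal standalone sketch of the sampling trick the reworked `_sample_negative_indices` relies on, using a toy mask and illustrative names (it is not part of the diff). Negatives are drawn only from the masked frames of the same utterance; a draw that lands on the positive frame's own slot is shifted up by one so the positive is never picked while the other candidates stay uniformly likely, and the result is remapped from "k-th masked frame" back to real frame indices.

import torch

# Toy settings (illustrative only; the real helper loops over a batch
# and additionally offsets indices by batch_idx * sequence_length).
num_negatives = 4
mask_time_indices = torch.tensor([1, 1, 0, 1, 1, 0], dtype=torch.bool)
sequence_length = mask_time_indices.numel()

# Real positions of the masked (candidate) frames, here [0, 1, 3, 4].
mapped_masked_indices = torch.arange(sequence_length)[mask_time_indices]
num_masked = int(mask_time_indices.sum())

# Row i repeats the positive slot i, `num_negatives` times.
feature_indices = torch.arange(num_masked).unsqueeze(1).expand(
    num_masked, num_negatives)

# Draw from [0, num_masked - 1) and bump draws >= the positive slot,
# so the positive is excluded but the remaining slots stay uniform.
sampled_indices = torch.randint(0, num_masked - 1,
                                size=(num_masked, num_negatives))
sampled_indices[sampled_indices >= feature_indices] += 1

# Remap slot numbers to actual frame indices; no negative equals its positive.
negatives = mapped_masked_indices[sampled_indices]
assert not (negatives == mapped_masked_indices.unsqueeze(1)).any()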