Skip to content

Commit

Permalink
[ssl/wav2vec2] fix lint
Browse files Browse the repository at this point in the history
  • Loading branch information
Mddct committed Oct 9, 2023
1 parent d616dfe commit 4db7a2a
Showing 1 changed file with 35 additions and 28 deletions.
63 changes: 35 additions & 28 deletions wenet/ssl/wav2vec2/wav2vec2_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,41 +12,47 @@
from wenet.utils.mask import make_non_pad_mask


def _sample_negative_indices(features_shape: torch.Size,
def _sample_negative_indices(features_shape: Tuple,
num_negatives: int,
mask: Optional[torch.Tensor] = None):
device: torch.device,
mask_time_indices: Optional[torch.Tensor] = None):
"""
Sample `num_negatives` vectors from feature vectors.
"""
batch_size, sequence_length, _ = features_shape
assert sequence_length > 1
batch_size, sequence_length = features_shape

sequence_length_range = torch.arange(sequence_length)

# get `num_negatives` random vector indices from the same utterance
sampled_negative_indices = []
sampled_negative_indices = torch.zeros(
(batch_size, sequence_length, num_negatives),
dtype=sequence_length_range.dtype,
device=device)

mask_time_indices = (mask_time_indices.bool() if mask_time_indices
is not None else torch.ones(features_shape,
dtype=torch.bool))

for batch_idx in range(batch_size):
high = mask[batch_idx].sum(
) - 1 if mask is not None else sequence_length - 1
sampled_indices_slice = torch.randint(0,
high,
size=(num_negatives *
sequence_length, ))
sampled_negative_indices.append(sampled_indices_slice)

sampled_negative_indices = torch.stack(sampled_negative_indices, dim=0).to(
torch.int32) # [B, num_negatives * sequence_length]

# generate indices of the positive vectors themselves,
# repeat them `num_negatives` times
feature_indices = torch.arange(sequence_length)[:, None].repeat(
1, num_negatives).flatten() # [B x num_negatives x sequence_length]

# avoid sampling the same positive vector, but keep the distribution uniform
sampled_negative_indices[sampled_negative_indices >= feature_indices] += 1

# correct for batch size
for batch_idx in range(1, batch_size):
high = mask_time_indices[batch_idx].sum() - 1
mapped_masked_indices = sequence_length_range[
mask_time_indices[batch_idx]]

feature_indices = torch.arange(high + 1).unsqueeze(1).expand(
high + 1, num_negatives)
sampled_indices = torch.randint(0,
high,
size=(high + 1, num_negatives))
sampled_indices[sampled_indices >= feature_indices] += 1

# remap to actual indices
sampled_negative_indices[batch_idx][mask_time_indices[
batch_idx]] = mapped_masked_indices[sampled_indices]

# correct for batch size
sampled_negative_indices[batch_idx] += batch_idx * sequence_length

return sampled_negative_indices
return sampled_negative_indices.reshape(batch_size, -1)


def _compute_contrastive_loss(quantized_features: torch.Tensor,
Expand Down Expand Up @@ -241,7 +247,8 @@ def forward(
unmasked_xs, masks.squeeze(1), gumbel_temperature)

sampled_negative_indices = _sample_negative_indices(
xs.size(), self.num_negatives, masks.squeeze(1))
xs.size()[:-1], self.num_negatives, masked_masks.device,
masked_masks)

loss_contrastive = _compute_contrastive_loss(
quantized_features, out, sampled_negative_indices, masked_masks,
Expand Down

0 comments on commit 4db7a2a

Please sign in to comment.