Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ssl/wav2vec2] add more info #2035

Merged
merged 2 commits into from
Oct 9, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 44 additions & 28 deletions wenet/ssl/wav2vec2/wav2vec2_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,41 +12,47 @@
from wenet.utils.mask import make_non_pad_mask


def _sample_negative_indices(features_shape: torch.Size,
def _sample_negative_indices(features_shape: Tuple,
num_negatives: int,
mask: Optional[torch.Tensor] = None):
device: torch.device,
mask_time_indices: Optional[torch.Tensor] = None):
"""
Sample `num_negatives` vectors from feature vectors.
"""
batch_size, sequence_length, _ = features_shape
assert sequence_length > 1
batch_size, sequence_length = features_shape

sequence_length_range = torch.arange(sequence_length)

# get `num_negatives` random vector indices from the same utterance
sampled_negative_indices = []
sampled_negative_indices = torch.zeros(
(batch_size, sequence_length, num_negatives),
dtype=sequence_length_range.dtype,
device=device)

mask_time_indices = (mask_time_indices.bool() if mask_time_indices
is not None else torch.ones(features_shape,
dtype=torch.bool))

for batch_idx in range(batch_size):
high = mask[batch_idx].sum(
) - 1 if mask is not None else sequence_length - 1
sampled_indices_slice = torch.randint(0,
high,
size=(num_negatives *
sequence_length, ))
sampled_negative_indices.append(sampled_indices_slice)

sampled_negative_indices = torch.stack(sampled_negative_indices, dim=0).to(
torch.int32) # [B, num_negatives * sequence_length]

# generate indices of the positive vectors themselves,
# repeat them `num_negatives` times
feature_indices = torch.arange(sequence_length)[:, None].repeat(
1, num_negatives).flatten() # [B x num_negatives x sequence_length]

# avoid sampling the same positive vector, but keep the distribution uniform
sampled_negative_indices[sampled_negative_indices >= feature_indices] += 1

# correct for batch size
for batch_idx in range(1, batch_size):
high = mask_time_indices[batch_idx].sum() - 1
mapped_masked_indices = sequence_length_range[
mask_time_indices[batch_idx]]

feature_indices = torch.arange(high + 1).unsqueeze(1).expand(
high + 1, num_negatives)
sampled_indices = torch.randint(0,
high,
size=(high + 1, num_negatives))
sampled_indices[sampled_indices >= feature_indices] += 1

# remap to actual indices
sampled_negative_indices[batch_idx][mask_time_indices[
batch_idx]] = mapped_masked_indices[sampled_indices]

# correct for batch size
sampled_negative_indices[batch_idx] += batch_idx * sequence_length

return sampled_negative_indices
return sampled_negative_indices.reshape(batch_size, -1)


def _compute_contrastive_loss(quantized_features: torch.Tensor,
Expand Down Expand Up @@ -241,29 +247,39 @@ def forward(
unmasked_xs, masks.squeeze(1), gumbel_temperature)

sampled_negative_indices = _sample_negative_indices(
xs.size(), self.num_negatives, masks.squeeze(1))
xs.size()[:-1], self.num_negatives, masked_masks.device,
masked_masks)

loss_contrastive = _compute_contrastive_loss(
quantized_features, out, sampled_negative_indices, masked_masks,
self.contrastive_logits_temp, self.num_negatives)
loss = loss_contrastive

# scale by sample size
# make sure that diversity loss is multiplied by `sample_size`
# since contrastive_loss is `sum`-reduced instead of averaged
sample_size = masked_masks.sum()
# higher codevector_perplexity leads to lower diversity loss
loss_diversity: Optional[torch.Tensor] = None
if self.diversity_weight != 0.0:
loss_diversity = (
self.num_codevector_groups * self.num_codevectors_per_group -
codevector_perplexity) / (self.num_codevectors_per_group *
self.num_codevector_groups)
loss_diversity = loss_diversity * sample_size
loss = loss + self.diversity_weight * loss_diversity
loss = loss / sample_size

features_pen: Optional[torch.Tensor] = None
if self.features_regularization_weight != 0.0:
features_pen = xs.pow(2).mean()
loss = loss + self.features_regularization_weight * features_pen

return {
"code_ppl": codevector_perplexity.detach(),
"features_l2": features_pen,
"loss": loss,
"losss_constrastive": loss_contrastive / sample_size,
"loss_diversity": loss_diversity,
}

Expand Down