Skip to content

Commit

Permalink
[ssl/wav2vec2] add more info
Browse files: browse the repository at this point in the history
  • Loading branch information
Mddct committed Oct 8, 2023
1 parent a1b0a29 commit d616dfe
Showing 1 changed file with 9 additions and 0 deletions.
9 changes: 9 additions & 0 deletions wenet/ssl/wav2vec2/wav2vec2_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,22 +248,31 @@ def forward(
self.contrastive_logits_temp, self.num_negatives)
loss = loss_contrastive

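# Scale by sample size: the contrastive loss is `sum`-reduced rather than
# averaged, so the diversity loss is first multiplied by `sample_size`
# (presumably the number of masked positions; confirm against `masked_masks`)
# and the combined loss is then divided by `sample_size` below.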
sample_size = masked_masks.sum()
# higher codevector_perplexity leads to lower diversity loss
loss_diversity: Optional[torch.Tensor] = None
if self.diversity_weight != 0.0:
loss_diversity = (
self.num_codevector_groups * self.num_codevectors_per_group -
codevector_perplexity) / (self.num_codevectors_per_group *
self.num_codevector_groups)
loss_diversity = loss_diversity * sample_size
loss = loss + self.diversity_weight * loss_diversity
loss = loss / sample_size

features_pen: Optional[torch.Tensor] = None
if self.features_regularization_weight != 0.0:
features_pen = xs.pow(2).mean()
loss = loss + self.features_regularization_weight * features_pen

return {
"code_ppl": codevector_perplexity.detach(),
"features_l2": features_pen,
"loss": loss,
"losss_constrastive": loss_contrastive / sample_size,
"loss_diversity": loss_diversity,
}

Expand Down

0 comments on commit d616dfe

Please sign in to comment.