Skip to content

Commit 35acffc

Browse files
authored
[ssl/w2vbert] support part of w2vbert training (#2039)
* [ssl/w2vbert] support part of w2vbert training * [ssl/wa2vbert] fix typo * [ssl/wa2vbert] fix bias * [ssl/wa2vbert] add mlm weight scale * fix typo
1 parent 3790509 commit 35acffc

File tree

3 files changed

+335
-11
lines changed

3 files changed

+335
-11
lines changed

wenet/ssl/w2vbert/w2vbert_model.py

+320
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,320 @@
1+
import math
2+
from typing import Optional, Tuple, Union
3+
import torch
4+
5+
from wenet.ssl.bestrq.mask import compute_mask_indices_v2
6+
from wenet.ssl.wav2vec2.quantizer import Wav2vecGumbelVectorQuantizer
7+
from wenet.ssl.wav2vec2.wav2vec2_model import (_compute_contrastive_loss,
8+
_sample_negative_indices)
9+
from wenet.transformer.attention import RelPositionMultiHeadedAttention
10+
11+
from wenet.transformer.encoder import ConformerEncoder, TransformerEncoder
12+
from wenet.transformer.encoder_layer import ConformerEncoderLayer
13+
from wenet.utils.mask import make_non_pad_mask
14+
15+
16+
class W2VBERTModel(torch.nn.Module):
    """Wrap a wenet encoder for W2V-BERT style self-supervised training.

    The encoder layers are split in two consecutive groups:

    * the first ``contrastive_blocks`` layers produce the context vectors
      used by the wav2vec2-style contrastive loss;
    * the remaining ``masked_blocks`` layers produce the vectors used to
      predict the quantizer code ids (masked-LM loss).
    """

    def __init__(
        self,
        # Quoted so the annotation is not evaluated when the class is created.
        encoder: "Union[ConformerEncoder, TransformerEncoder]",
        embedding_dim: int = 256,
        num_embeddings: int = 320,
        num_codebooks: int = 1,
        mask_prob: float = 0.065,
        mask_length: int = 10,
        min_masks: int = 2,
        num_negatives: int = 100,
        features_regularization_weight: float = 0.01,
        max_gumbel_temperature: float = 2.0,
        min_gumbel_temperature: float = 0.1,
        gumbel_temperature_decay: float = 0.999995,
        contrastive_logits_temperature: float = 0.1,
        diversity_weight: float = 0.0,
        bias: bool = True,
        contrastive_blocks: int = 6,
        masked_blocks: int = 6,
        contrastive_weight: float = 1.0,
        mlm_weight: float = 1.0,
        warmup_steps: int = 25000,
    ) -> None:
        """ Wrap encoder to train using W2V-BERT's style

        Described in:
        https://arxiv.org/pdf/2108.06209v2.pdf

        Args:
            encoder: wenet's encoder,
                     only support conformer and transformer now
            embedding_dim: codebooks embedding dim
            num_embeddings: numbers of each codebook
            num_codebooks: numbers of codebooks i.e groups of codebook
            mask_prob: probs of mask
            mask_length: spans of masks
            min_masks: min masks for each audio
            num_negatives: numbers of negatives of each masks
            features_regularization_weight: l2 regularization weight
                (the penalty is skipped when this is 0.0)
            max_gumbel_temperature: maximum temperature for gumbel softmax
            min_gumbel_temperature: minimum temperature for gumbel softmax
            gumbel_temperature_decay:
                decay of gumbel temperature during training
            contrastive_logits_temperature:
                the temperature in the contrastive loss.
            diversity_weight: weight of the codebook diversity loss
                (the term is skipped when this is 0.0)
            bias: whether a learnable bias is added to the masked-LM logits
            contrastive_blocks: number of leading encoder layers whose
                output feeds the contrastive loss
            masked_blocks: number of trailing encoder layers whose output
                feeds the masked-LM loss; ``contrastive_blocks +
                masked_blocks`` must equal ``len(encoder.encoders)``
            contrastive_weight: weight of the contrastive branch in the
                final loss
            mlm_weight: weight of the masked-LM branch in the final loss
                once warm-up is over
            warmup_steps: number of steps over which the masked-LM weight
                is ramped up to ``mlm_weight``
        """
        super().__init__()
        assert mask_prob > 0.0
        assert (contrastive_blocks > 0 and masked_blocks > 0 and
                contrastive_blocks + masked_blocks == len(encoder.encoders))
        self.contrastive_blocks = contrastive_blocks
        self.masked_blocks = masked_blocks

        self.mask_prob = mask_prob
        self.mask_length = mask_length
        self.min_masks = min_masks
        self.num_negatives = num_negatives

        self.features_regularization_weight = features_regularization_weight
        self.diversity_weight = diversity_weight

        self.contrastive_weight = contrastive_weight
        self.mlm_weight = mlm_weight
        self.warmup_steps = warmup_steps
        # encoder
        self.encoder = encoder

        # quantizer
        self.num_codebooks = num_codebooks
        self.quantizer = Wav2vecGumbelVectorQuantizer(
            self.encoder.output_size(),
            num_codebooks=num_codebooks,
            num_embeddings=num_embeddings,
            embedding_dim=embedding_dim,
            hard=False,
        )
        self.max_gumbel_temp = max_gumbel_temperature
        self.min_gumbel_temp = min_gumbel_temperature
        self.gumbel_temp_decay = gumbel_temperature_decay

        self.num_codevectors_per_group = num_embeddings
        self.num_codevector_groups = num_codebooks

        self.contrastive_logits_temp = contrastive_logits_temperature

        # NOTE(Mddct): no learned mask embedding here; masked frames are
        # replaced by random values instead (see `_apply_mask`).
        # TODO(Mddct): support causal or lookahead mask or keep consistent with
        # wenet dynamic chunk training

        # One softmax projection per codebook:
        # [num_codebooks, encoder_dim, num_embeddings]
        self.encoder_top_n_out = torch.nn.parameter.Parameter(
            torch.empty(num_codebooks, self.encoder.output_size(),
                        num_embeddings))
        torch.nn.init.trunc_normal_(self.encoder_top_n_out, std=0.02)
        self.bias = bias
        if bias:
            self.encoder_top_n_out_bias = torch.nn.parameter.Parameter(
                torch.empty(num_codebooks, num_embeddings))
            torch.nn.init.zeros_(self.encoder_top_n_out_bias)

        # reset parameter
        self.reset_encoder_parameter()

    def reset_encoder_parameter(self):
        """Re-initialise attention and convolution weights of every layer.

        For each encoder layer the q/k/v/out attention projections are
        reset; relative-position biases are reset when the attention is a
        ``RelPositionMultiHeadedAttention``; and for conformer layers the
        first pointwise conv and the depthwise conv are reset.

        Raises:
            NotImplementedError: if asked to reset an unsupported object.
        """

        def _reset_parameter(module: torch.nn.Module):
            if isinstance(module, torch.nn.Linear):
                torch.nn.init.trunc_normal_(module.weight.data,
                                            mean=0.0,
                                            std=0.02)
                if module.bias is not None:
                    module.bias.data.zero_()
            elif isinstance(module, torch.nn.Conv1d):
                torch.nn.init.kaiming_normal_(module.weight)
                if module.bias is not None:
                    k = math.sqrt(module.groups /
                                  (module.in_channels * module.kernel_size[0]))
                    torch.nn.init.uniform_(module.bias, a=-k, b=k)
            elif isinstance(module, torch.Tensor):
                # Bare tensors/parameters, e.g. the relative position biases.
                torch.nn.init.trunc_normal_(module)
            else:
                raise NotImplementedError("other module not support now")

        for layer in self.encoder.encoders:
            self_attn = layer.self_attn
            _reset_parameter(self_attn.linear_q)
            _reset_parameter(self_attn.linear_k)
            _reset_parameter(self_attn.linear_v)
            _reset_parameter(self_attn.linear_out)
            if isinstance(self_attn, RelPositionMultiHeadedAttention):
                _reset_parameter(self_attn.pos_bias_u)
                _reset_parameter(self_attn.pos_bias_v)
            if isinstance(layer, ConformerEncoderLayer):
                conv1, conv2 = (layer.conv_module.pointwise_conv1,
                                layer.conv_module.depthwise_conv)
                _reset_parameter(conv1)
                _reset_parameter(conv2)

    @torch.jit.ignore(drop=True)
    def forward(
        self,
        xs: torch.Tensor,
        xs_lens: torch.Tensor,
        text: Optional[torch.Tensor] = None,
        text_length: Optional[torch.Tensor] = None,
        steps: Optional[int] = None,
    ):
        """Compute the combined contrastive + masked-LM training loss.

        Args:
            xs: input features; first dimension is the batch
            xs_lens: length of each utterance in the batch
            text: unused, kept for interface compatibility
            text_length: unused, kept for interface compatibility
            steps: current training step (required); drives the gumbel
                temperature decay and the masked-LM weight warm-up

        Returns:
            dict with keys ``code_ppl``, ``features_l2``, ``codes_acc``,
            ``loss``, ``loss_contrastive``, ``loss_diversity`` and
            ``loss_mlm``. ``features_l2`` / ``loss_diversity`` are ``None``
            when their weight is 0.0.
        """
        assert xs.size(0) == xs_lens.size(0)
        assert steps is not None

        # 1 forward subsampling
        # NOTE(Mddct): use subsampling as feature extraction
        xs, pos_emb, masks = self._forward_subsampling(xs, xs_lens)
        unmasked_xs = xs
        # 2 mask features
        masked_xs, masked_masks = self._apply_mask(xs, masks.squeeze(1))
        # 3 forward encoder blocks
        contrastive_vec, mlm_vec, out_mask = self._forward_encoder_blocks(
            masked_xs, masks, pos_emb, masks)

        # 4 contrastive branch
        gumbel_temperature = max(
            self.max_gumbel_temp * self.gumbel_temp_decay**steps,
            self.min_gumbel_temp)

        quantized_features, codevector_perplexity, targets_ids = self.quantizer(
            unmasked_xs, masks.squeeze(1), gumbel_temperature)

        sampled_negative_indices = _sample_negative_indices(
            xs.size()[:-1], self.num_negatives, masked_masks.device,
            masked_masks)

        loss_contrastive = _compute_contrastive_loss(
            quantized_features, contrastive_vec, sampled_negative_indices,
            masked_masks, self.contrastive_logits_temp, self.num_negatives)
        loss = loss_contrastive

        # scale by sample size
        # make sure that diversity loss is multiplied by `sample_size`
        # since contrastive_loss is `sum`-reduced instead of averaged
        sample_size = masked_masks.sum()
        # higher codevector_perplexity leads to lower diversity loss
        loss_diversity: Optional[torch.Tensor] = None
        if self.diversity_weight != 0.0:
            loss_diversity = (
                self.num_codevector_groups * self.num_codevectors_per_group -
                codevector_perplexity) / (self.num_codevectors_per_group *
                                          self.num_codevector_groups)
            loss_diversity = loss_diversity * sample_size
            loss = loss + self.diversity_weight * loss_diversity
        loss = loss / sample_size

        features_pen: Optional[torch.Tensor] = None
        if self.features_regularization_weight != 0.0:
            features_pen = xs.pow(2).mean()
            loss = loss + self.features_regularization_weight * features_pen

        # 5 masked lm branch
        out = mlm_vec.unsqueeze(1)
        top_n_out = self.encoder_top_n_out.unsqueeze(
            0)  # [1, num_codebooks, dim, num_embeddings]
        out = torch.matmul(out,
                           top_n_out)  # [B, num_codebooks, T', num_embeddings]
        if self.bias:
            out = out + self.encoder_top_n_out_bias.unsqueeze(0).unsqueeze(2)
        num_codes = masked_masks.sum() * self.num_codebooks
        loss_mlm = self._compute_mlm_loss(out,
                                          targets_ids,
                                          mask=out_mask.squeeze(1) *
                                          masked_masks)
        ids_corr = out.argmax(dim=-1,
                              keepdim=False).transpose(1, 2) == targets_ids
        codes_acc = (ids_corr * masked_masks.unsqueeze(2)).sum() / num_codes
        # TODO(Mddct): support num codes used in batch, unique num codes
        # used in batch like bestrq

        # 6 final loss
        mlm_weight = self._scheduled_mlm_weight(steps)
        loss = self.contrastive_weight * loss + mlm_weight * loss_mlm
        return {
            "code_ppl": codevector_perplexity.detach(),
            "features_l2": features_pen,
            "codes_acc": codes_acc.detach(),
            "loss": loss,
            "loss_contrastive": loss_contrastive / sample_size,
            "loss_diversity": loss_diversity,
            "loss_mlm": loss_mlm,
        }

    def _scheduled_mlm_weight(self, steps: int) -> float:
        """Return the masked-LM loss weight to use at training step `steps`.

        During warm-up the weight grows linearly from 10% to 100% of
        ``self.mlm_weight``; afterwards it is ``self.mlm_weight``. Scaling
        the ramp by the configured weight keeps the schedule continuous at
        the warm-up boundary and honours ``mlm_weight`` during warm-up
        (for ``mlm_weight == 1.0`` this is the plain 0.1 -> 1.0 ramp).
        """
        # `steps >= warmup_steps` is checked first, so warmup_steps == 0
        # never reaches the division below.
        if steps >= self.warmup_steps:
            return self.mlm_weight
        return self.mlm_weight * (0.1 + 0.9 * (steps / self.warmup_steps))

    def _apply_mask(
            self, xs: torch.Tensor,
            xs_masks: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Replace randomly chosen frames of `xs` by gaussian noise.

        Args:
            xs: features whose last dimension is the feature dimension
            xs_masks: non-padding mask; its inverse is passed to
                `compute_mask_indices_v2` as the padding mask

        Returns:
            (masked features, boolean mask of the replaced frames)
        """
        masks = compute_mask_indices_v2(xs.size()[:-1],
                                        ~xs_masks,
                                        self.mask_prob,
                                        self.mask_length,
                                        min_masks=self.min_masks,
                                        device=xs.device)
        masks_expand = masks.unsqueeze(-1)  # [B, T, 1]

        # Fresh N(0, 0.1) noise stands in for a learned mask embedding.
        mask_emb = torch.normal(mean=0,
                                std=0.1,
                                size=xs.size(),
                                device=xs.device)
        xs = torch.where(masks_expand, mask_emb, xs)

        return xs, masks

    def _compute_mlm_loss(self, input: torch.Tensor, target: torch.Tensor,
                          mask: torch.Tensor) -> torch.Tensor:
        """Masked cross-entropy averaged over masked frames and codebooks.

        Args:
            input: logits, [B, num_codebooks, T', num_embeddings]
            target: code ids, [B, T', num_codebooks]
            mask: [B, T'], non-zero where the loss is counted

        Returns:
            scalar loss tensor
        """
        log_probs = torch.log_softmax(input, dim=-1).transpose(
            1, 2)  # [B, T', num_codebooks, num_embeddings]

        per_example_n_loss = -log_probs.gather(3, target.unsqueeze(3)).squeeze(
            3)  # [B, T', num_codebooks]

        numerator = torch.sum(per_example_n_loss * mask.unsqueeze(2))
        # The epsilon guards against an all-zero mask.
        denominator = torch.sum(mask) + 1e-5
        loss = numerator / (denominator * self.num_codebooks)
        return loss

    def _forward_subsampling(
        self, xs: torch.Tensor, xs_lens: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run the encoder front-end: optional global CMVN then `embed`.

        Returns:
            (embedded features, positional embedding, mask) as returned by
            ``self.encoder.embed``.
        """
        masks = make_non_pad_mask(xs_lens).unsqueeze(1)  # (B, 1, T)
        if self.encoder.global_cmvn is not None:
            xs = self.encoder.global_cmvn(xs)
        xs, pos_emb, masks = self.encoder.embed(xs, masks)
        return xs, pos_emb, masks

    def _forward_encoder_blocks(
        self, xs: torch.Tensor, xs_masks: torch.Tensor, pos_emb: torch.Tensor,
        mask_pad: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run both groups of encoder layers.

        Returns:
            (output of the contrastive layers, output of the last layer
            after ``after_norm`` when ``normalize_before`` is set, mask
            returned by the last executed layer)
        """
        masks = xs_masks

        # Every layer is fed the incoming `xs_masks`; `masks` only records
        # what the most recent layer handed back.
        # forward contrastive layers get context vector for Contrastive Loss
        for layer in self.encoder.encoders[:self.contrastive_blocks]:
            xs, masks, _, _ = layer(xs, xs_masks, pos_emb, mask_pad)
        contrastive_vec = xs

        # remaining layers produce the vectors for the masked-LM branch
        for layer in self.encoder.encoders[self.contrastive_blocks:]:
            xs, masks, _, _ = layer(xs, xs_masks, pos_emb, mask_pad)

        if self.encoder.normalize_before:
            xs = self.encoder.after_norm(xs)
        masked_vec = xs
        return contrastive_vec, masked_vec, masks

wenet/ssl/wav2vec2/quantizer.py

+9-5
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from typing import Tuple
12
import torch
23

34

@@ -66,10 +67,12 @@ def _compute_perplexity(probs, mask=None):
6667
marginal_probs * torch.log(marginal_probs + 1e-7), dim=-1)).sum()
6768
return perplexity
6869

69-
def forward(self,
70-
input: torch.Tensor,
71-
input_mask: torch.Tensor,
72-
temperature: float = 1.):
70+
def forward(
71+
self,
72+
input: torch.Tensor,
73+
input_mask: torch.Tensor,
74+
temperature: float = 1.
75+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
7376

7477
b, t, _ = input.size()
7578

@@ -98,6 +101,7 @@ def forward(self,
98101
b * t, self.num_groups, -1)
99102
perplexity = self._compute_perplexity(codevector_probs, input_mask)
100103

104+
targets_idx = codevector_probs.argmax(-1).reshape(b, t, -1)
101105
codevector_probs = codevector_probs.reshape(b * t, -1)
102106
# use probs to retrieve codevectors
103107
codevectors_per_group = codevector_probs.unsqueeze(
@@ -106,4 +110,4 @@ def forward(self,
106110
b * t, self.num_groups, self.num_codevectors_per_group, -1)
107111

108112
codevectors = codevectors.sum(-2).reshape(b, t, -1)
109-
return codevectors, perplexity
113+
return codevectors, perplexity, targets_idx

wenet/ssl/wav2vec2/wav2vec2_model.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -21,17 +21,17 @@ def _sample_negative_indices(features_shape: Tuple,
2121
"""
2222
batch_size, sequence_length = features_shape
2323

24-
sequence_length_range = torch.arange(sequence_length)
24+
sequence_length_range = torch.arange(sequence_length, device=device)
2525

2626
# get `num_negatives` random vector indices from the same utterance
2727
sampled_negative_indices = torch.zeros(
2828
(batch_size, sequence_length, num_negatives),
2929
dtype=sequence_length_range.dtype,
3030
device=device)
3131

32-
mask_time_indices = (mask_time_indices.bool() if mask_time_indices
33-
is not None else torch.ones(features_shape,
34-
dtype=torch.bool))
32+
mask_time_indices = (mask_time_indices.bool()
33+
if mask_time_indices is not None else torch.ones(
34+
features_shape, dtype=torch.bool, device=device))
3535

3636
for batch_idx in range(batch_size):
3737
high = mask_time_indices[batch_idx].sum() - 1
@@ -243,7 +243,7 @@ def forward(
243243
self.max_gumbel_temp * self.gumbel_temp_decay**steps,
244244
self.min_gumbel_temp)
245245

246-
quantized_features, codevector_perplexity = self.quantizer(
246+
quantized_features, codevector_perplexity, _ = self.quantizer(
247247
unmasked_xs, masks.squeeze(1), gumbel_temperature)
248248

249249
sampled_negative_indices = _sample_negative_indices(
@@ -279,7 +279,7 @@ def forward(
279279
"code_ppl": codevector_perplexity.detach(),
280280
"features_l2": features_pen,
281281
"loss": loss,
282-
"losss_constrastive": loss_contrastive / sample_size,
282+
"loss_contrastive": loss_contrastive / sample_size,
283283
"loss_diversity": loss_diversity,
284284
}
285285

0 commit comments

Comments
 (0)