@@ -13,19 +13,49 @@ def forward(self, x):
         x = self.conv(x)
         return torch.nn.functional.gelu(self.layer_norm(x.transpose(-2, -1)).transpose(-2, -1))
 
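+# GroupNorm with num_groups == num_channels normalizes each channel independently
+# over time; this is the "group" feat_extract_norm variant of wav2vec2, where only
+# the first conv layer is normalized.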
+class LayerGroupNormConv(nn.Module):
+    def __init__(self, in_channels, out_channels, kernel_size, stride, bias=False, dtype=None, device=None, operations=None):
+        super().__init__()
+        self.conv = operations.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, bias=bias, device=device, dtype=dtype)
+        self.layer_norm = operations.GroupNorm(num_groups=out_channels, num_channels=out_channels, affine=True, device=device, dtype=dtype)
+
+    def forward(self, x):
+        x = self.conv(x)
+        return torch.nn.functional.gelu(self.layer_norm(x))
+
+class ConvNoNorm(nn.Module):
+    def __init__(self, in_channels, out_channels, kernel_size, stride, bias=False, dtype=None, device=None, operations=None):
+        super().__init__()
+        self.conv = operations.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, bias=bias, device=device, dtype=dtype)
+
+    def forward(self, x):
+        x = self.conv(x)
+        return torch.nn.functional.gelu(x)
+
 
 class ConvFeatureEncoder(nn.Module):
-    def __init__(self, conv_dim, dtype=None, device=None, operations=None):
+    def __init__(self, conv_dim, conv_bias=False, conv_norm=True, dtype=None, device=None, operations=None):
         super().__init__()
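+        # conv_norm=True: LayerNorm after every conv ("layer" norm variant).
+        # conv_norm=False: GroupNorm on the first conv only, no norm on the rest ("group" variant).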
-        self.conv_layers = nn.ModuleList([
-            LayerNormConv(1, conv_dim, kernel_size=10, stride=5, bias=True, device=device, dtype=dtype, operations=operations),
-            LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=True, device=device, dtype=dtype, operations=operations),
-            LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=True, device=device, dtype=dtype, operations=operations),
-            LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=True, device=device, dtype=dtype, operations=operations),
-            LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=True, device=device, dtype=dtype, operations=operations),
-            LayerNormConv(conv_dim, conv_dim, kernel_size=2, stride=2, bias=True, device=device, dtype=dtype, operations=operations),
-            LayerNormConv(conv_dim, conv_dim, kernel_size=2, stride=2, bias=True, device=device, dtype=dtype, operations=operations),
-        ])
+        if conv_norm:
+            self.conv_layers = nn.ModuleList([
+                LayerNormConv(1, conv_dim, kernel_size=10, stride=5, bias=True, device=device, dtype=dtype, operations=operations),
+                LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
+                LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
+                LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
+                LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
+                LayerNormConv(conv_dim, conv_dim, kernel_size=2, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
+                LayerNormConv(conv_dim, conv_dim, kernel_size=2, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
+            ])
+        else:
+            self.conv_layers = nn.ModuleList([
+                LayerGroupNormConv(1, conv_dim, kernel_size=10, stride=5, bias=conv_bias, device=device, dtype=dtype, operations=operations),
+                ConvNoNorm(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
+                ConvNoNorm(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
+                ConvNoNorm(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
+                ConvNoNorm(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
+                ConvNoNorm(conv_dim, conv_dim, kernel_size=2, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
+                ConvNoNorm(conv_dim, conv_dim, kernel_size=2, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
+            ])
 
     def forward(self, x):
         x = x.unsqueeze(1)
@@ -76,6 +106,7 @@ def __init__(
         num_heads=12,
         num_layers=12,
         mlp_ratio=4.0,
+        do_stable_layer_norm=True,
         dtype=None, device=None, operations=None
     ):
         super().__init__()
@@ -86,20 +117,25 @@ def __init__(
                 embed_dim=embed_dim,
                 num_heads=num_heads,
                 mlp_ratio=mlp_ratio,
+                do_stable_layer_norm=do_stable_layer_norm,
                 device=device, dtype=dtype, operations=operations
             )
             for _ in range(num_layers)
         ])
 
         self.layer_norm = operations.LayerNorm(embed_dim, eps=1e-05, device=device, dtype=dtype)
+        self.do_stable_layer_norm = do_stable_layer_norm
 
     def forward(self, x, mask=None):
         x = x + self.pos_conv_embed(x)
         all_x = ()
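+        # Non-stable (post-LN) encoders normalize the input once before the
+        # transformer layers; stable (pre-LN) encoders normalize once after the
+        # final layer instead.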
+        if not self.do_stable_layer_norm:
+            x = self.layer_norm(x)
         for layer in self.layers:
             all_x += (x,)
             x = layer(x, mask)
-        x = self.layer_norm(x)
+        if self.do_stable_layer_norm:
+            x = self.layer_norm(x)
         all_x += (x,)
         return x, all_x
 
@@ -145,6 +181,7 @@ def __init__(
         embed_dim=768,
         num_heads=12,
         mlp_ratio=4.0,
+        do_stable_layer_norm=True,
         dtype=None, device=None, operations=None
     ):
         super().__init__()
@@ -154,15 +191,19 @@ def __init__(
         self.layer_norm = operations.LayerNorm(embed_dim, device=device, dtype=dtype)
         self.feed_forward = FeedForward(embed_dim, mlp_ratio, device=device, dtype=dtype, operations=operations)
         self.final_layer_norm = operations.LayerNorm(embed_dim, device=device, dtype=dtype)
+        self.do_stable_layer_norm = do_stable_layer_norm
 
     def forward(self, x, mask=None):
         residual = x
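+        # Stable layer norm = pre-LN: normalize before attention and feed-forward.
+        # Otherwise post-LN: normalize after each residual add (wav2vec2-base style).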
-        x = self.layer_norm(x)
+        if self.do_stable_layer_norm:
+            x = self.layer_norm(x)
         x = self.attention(x, mask=mask)
         x = residual + x
-
-        x = x + self.feed_forward(self.final_layer_norm(x))
-        return x
+        if not self.do_stable_layer_norm:
+            x = self.layer_norm(x)
+            return self.final_layer_norm(x + self.feed_forward(x))
+        else:
+            return x + self.feed_forward(self.final_layer_norm(x))
 
 
 class Wav2Vec2Model(nn.Module):
@@ -174,34 +215,38 @@ def __init__(
         final_dim=256,
         num_heads=16,
         num_layers=24,
+        conv_norm=True,
+        conv_bias=True,
+        do_normalize=True,
+        do_stable_layer_norm=True,
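+        # These defaults match the wav2vec2-large (lv60-style) configuration;
+        # base-style checkpoints typically use conv_norm=False, conv_bias=False and
+        # do_stable_layer_norm=False, and some disable waveform normalization.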
         dtype=None, device=None, operations=None
     ):
         super().__init__()
 
         conv_dim = 512
-        self.feature_extractor = ConvFeatureEncoder(conv_dim, device=device, dtype=dtype, operations=operations)
+        self.feature_extractor = ConvFeatureEncoder(conv_dim, conv_norm=conv_norm, conv_bias=conv_bias, device=device, dtype=dtype, operations=operations)
         self.feature_projection = FeatureProjection(conv_dim, embed_dim, device=device, dtype=dtype, operations=operations)
 
         self.masked_spec_embed = nn.Parameter(torch.empty(embed_dim, device=device, dtype=dtype))
+        self.do_normalize = do_normalize
 
         self.encoder = TransformerEncoder(
             embed_dim=embed_dim,
             num_heads=num_heads,
             num_layers=num_layers,
+            do_stable_layer_norm=do_stable_layer_norm,
             device=device, dtype=dtype, operations=operations
         )
 
-    def forward(self, x, mask_time_indices=None, return_dict=False):
-
+    def forward(self, x, sr=16000, mask_time_indices=None, return_dict=False):
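+        # x: raw waveform [batch, channels, samples]; sr is the input sample rate
+        # (wav2vec2 models expect 16 kHz audio). Multi-channel input is downmixed
+        # to mono by averaging over the channel dimension.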
         x = torch.mean(x, dim=1)
 
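+        # Optional zero-mean / unit-variance normalization of the waveform,
+        # mirroring the feature extractor's do_normalize preprocessing flag.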
-        x = (x - x.mean()) / torch.sqrt(x.var() + 1e-7)
+        if self.do_normalize:
+            x = (x - x.mean()) / torch.sqrt(x.var() + 1e-7)
 
         features = self.feature_extractor(x)
         features = self.feature_projection(features)
-
         batch_size, seq_len, _ = features.shape
 
         x, all_x = self.encoder(features)
-
         return x, all_x