Skip to content

Commit a7c4783

Browse files
kijaigmaOCR
authored andcommitted
Support wav2vec base models (comfyanonymous#9637)
* Support wav2vec base models * trim trailing whitespace * Do interpolation after
1 parent d620110 commit a7c4783

File tree

2 files changed

+99
-24
lines changed

2 files changed

+99
-24
lines changed

comfy/audio_encoders/audio_encoders.py

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,13 @@ def __init__(self, config):
1111
self.load_device = comfy.model_management.text_encoder_device()
1212
offload_device = comfy.model_management.text_encoder_offload_device()
1313
self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
14-
self.model = Wav2Vec2Model(dtype=self.dtype, device=offload_device, operations=comfy.ops.manual_cast)
14+
model_config = dict(config)
15+
model_config.update({
16+
"dtype": self.dtype,
17+
"device": offload_device,
18+
"operations": comfy.ops.manual_cast
19+
})
20+
self.model = Wav2Vec2Model(**model_config)
1521
self.model.eval()
1622
self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
1723
self.model_sample_rate = 16000
@@ -25,16 +31,40 @@ def get_sd(self):
2531
def encode_audio(self, audio, sample_rate):
2632
comfy.model_management.load_model_gpu(self.patcher)
2733
audio = torchaudio.functional.resample(audio, sample_rate, self.model_sample_rate)
28-
out, all_layers = self.model(audio.to(self.load_device))
34+
out, all_layers = self.model(audio.to(self.load_device), sr=self.model_sample_rate)
2935
outputs = {}
3036
outputs["encoded_audio"] = out
3137
outputs["encoded_audio_all_layers"] = all_layers
3238
return outputs
3339

3440

3541
def load_audio_encoder_from_sd(sd, prefix=""):
36-
audio_encoder = AudioEncoderModel(None)
3742
sd = comfy.utils.state_dict_prefix_replace(sd, {"wav2vec2.": ""})
43+
embed_dim = sd["encoder.layer_norm.bias"].shape[0]
44+
if embed_dim == 1024:# large
45+
config = {
46+
"embed_dim": 1024,
47+
"num_heads": 16,
48+
"num_layers": 24,
49+
"conv_norm": True,
50+
"conv_bias": True,
51+
"do_normalize": True,
52+
"do_stable_layer_norm": True
53+
}
54+
elif embed_dim == 768: # base
55+
config = {
56+
"embed_dim": 768,
57+
"num_heads": 12,
58+
"num_layers": 12,
59+
"conv_norm": False,
60+
"conv_bias": False,
61+
"do_normalize": False, # chinese-wav2vec2-base has this False
62+
"do_stable_layer_norm": False
63+
}
64+
else:
65+
raise RuntimeError("ERROR: audio encoder file is invalid or unsupported embed_dim: {}".format(embed_dim))
66+
67+
audio_encoder = AudioEncoderModel(config)
3868
m, u = audio_encoder.load_sd(sd)
3969
if len(m) > 0:
4070
logging.warning("missing audio encoder: {}".format(m))

comfy/audio_encoders/wav2vec2.py

Lines changed: 66 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -13,19 +13,49 @@ def forward(self, x):
1313
x = self.conv(x)
1414
return torch.nn.functional.gelu(self.layer_norm(x.transpose(-2, -1)).transpose(-2, -1))
1515

16+
class LayerGroupNormConv(nn.Module):
17+
def __init__(self, in_channels, out_channels, kernel_size, stride, bias=False, dtype=None, device=None, operations=None):
18+
super().__init__()
19+
self.conv = operations.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, bias=bias, device=device, dtype=dtype)
20+
self.layer_norm = operations.GroupNorm(num_groups=out_channels, num_channels=out_channels, affine=True, device=device, dtype=dtype)
21+
22+
def forward(self, x):
23+
x = self.conv(x)
24+
return torch.nn.functional.gelu(self.layer_norm(x))
25+
26+
class ConvNoNorm(nn.Module):
27+
def __init__(self, in_channels, out_channels, kernel_size, stride, bias=False, dtype=None, device=None, operations=None):
28+
super().__init__()
29+
self.conv = operations.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, bias=bias, device=device, dtype=dtype)
30+
31+
def forward(self, x):
32+
x = self.conv(x)
33+
return torch.nn.functional.gelu(x)
34+
1635

1736
class ConvFeatureEncoder(nn.Module):
18-
def __init__(self, conv_dim, dtype=None, device=None, operations=None):
37+
def __init__(self, conv_dim, conv_bias=False, conv_norm=True, dtype=None, device=None, operations=None):
1938
super().__init__()
20-
self.conv_layers = nn.ModuleList([
21-
LayerNormConv(1, conv_dim, kernel_size=10, stride=5, bias=True, device=device, dtype=dtype, operations=operations),
22-
LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=True, device=device, dtype=dtype, operations=operations),
23-
LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=True, device=device, dtype=dtype, operations=operations),
24-
LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=True, device=device, dtype=dtype, operations=operations),
25-
LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=True, device=device, dtype=dtype, operations=operations),
26-
LayerNormConv(conv_dim, conv_dim, kernel_size=2, stride=2, bias=True, device=device, dtype=dtype, operations=operations),
27-
LayerNormConv(conv_dim, conv_dim, kernel_size=2, stride=2, bias=True, device=device, dtype=dtype, operations=operations),
28-
])
39+
if conv_norm:
40+
self.conv_layers = nn.ModuleList([
41+
LayerNormConv(1, conv_dim, kernel_size=10, stride=5, bias=True, device=device, dtype=dtype, operations=operations),
42+
LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
43+
LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
44+
LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
45+
LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
46+
LayerNormConv(conv_dim, conv_dim, kernel_size=2, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
47+
LayerNormConv(conv_dim, conv_dim, kernel_size=2, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
48+
])
49+
else:
50+
self.conv_layers = nn.ModuleList([
51+
LayerGroupNormConv(1, conv_dim, kernel_size=10, stride=5, bias=conv_bias, device=device, dtype=dtype, operations=operations),
52+
ConvNoNorm(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
53+
ConvNoNorm(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
54+
ConvNoNorm(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
55+
ConvNoNorm(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
56+
ConvNoNorm(conv_dim, conv_dim, kernel_size=2, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
57+
ConvNoNorm(conv_dim, conv_dim, kernel_size=2, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
58+
])
2959

3060
def forward(self, x):
3161
x = x.unsqueeze(1)
@@ -76,6 +106,7 @@ def __init__(
76106
num_heads=12,
77107
num_layers=12,
78108
mlp_ratio=4.0,
109+
do_stable_layer_norm=True,
79110
dtype=None, device=None, operations=None
80111
):
81112
super().__init__()
@@ -86,20 +117,25 @@ def __init__(
86117
embed_dim=embed_dim,
87118
num_heads=num_heads,
88119
mlp_ratio=mlp_ratio,
120+
do_stable_layer_norm=do_stable_layer_norm,
89121
device=device, dtype=dtype, operations=operations
90122
)
91123
for _ in range(num_layers)
92124
])
93125

94126
self.layer_norm = operations.LayerNorm(embed_dim, eps=1e-05, device=device, dtype=dtype)
127+
self.do_stable_layer_norm = do_stable_layer_norm
95128

96129
def forward(self, x, mask=None):
97130
x = x + self.pos_conv_embed(x)
98131
all_x = ()
132+
if not self.do_stable_layer_norm:
133+
x = self.layer_norm(x)
99134
for layer in self.layers:
100135
all_x += (x,)
101136
x = layer(x, mask)
102-
x = self.layer_norm(x)
137+
if self.do_stable_layer_norm:
138+
x = self.layer_norm(x)
103139
all_x += (x,)
104140
return x, all_x
105141

@@ -145,6 +181,7 @@ def __init__(
145181
embed_dim=768,
146182
num_heads=12,
147183
mlp_ratio=4.0,
184+
do_stable_layer_norm=True,
148185
dtype=None, device=None, operations=None
149186
):
150187
super().__init__()
@@ -154,15 +191,19 @@ def __init__(
154191
self.layer_norm = operations.LayerNorm(embed_dim, device=device, dtype=dtype)
155192
self.feed_forward = FeedForward(embed_dim, mlp_ratio, device=device, dtype=dtype, operations=operations)
156193
self.final_layer_norm = operations.LayerNorm(embed_dim, device=device, dtype=dtype)
194+
self.do_stable_layer_norm = do_stable_layer_norm
157195

158196
def forward(self, x, mask=None):
159197
residual = x
160-
x = self.layer_norm(x)
198+
if self.do_stable_layer_norm:
199+
x = self.layer_norm(x)
161200
x = self.attention(x, mask=mask)
162201
x = residual + x
163-
164-
x = x + self.feed_forward(self.final_layer_norm(x))
165-
return x
202+
if not self.do_stable_layer_norm:
203+
x = self.layer_norm(x)
204+
return self.final_layer_norm(x + self.feed_forward(x))
205+
else:
206+
return x + self.feed_forward(self.final_layer_norm(x))
166207

167208

168209
class Wav2Vec2Model(nn.Module):
@@ -174,34 +215,38 @@ def __init__(
174215
final_dim=256,
175216
num_heads=16,
176217
num_layers=24,
218+
conv_norm=True,
219+
conv_bias=True,
220+
do_normalize=True,
221+
do_stable_layer_norm=True,
177222
dtype=None, device=None, operations=None
178223
):
179224
super().__init__()
180225

181226
conv_dim = 512
182-
self.feature_extractor = ConvFeatureEncoder(conv_dim, device=device, dtype=dtype, operations=operations)
227+
self.feature_extractor = ConvFeatureEncoder(conv_dim, conv_norm=conv_norm, conv_bias=conv_bias, device=device, dtype=dtype, operations=operations)
183228
self.feature_projection = FeatureProjection(conv_dim, embed_dim, device=device, dtype=dtype, operations=operations)
184229

185230
self.masked_spec_embed = nn.Parameter(torch.empty(embed_dim, device=device, dtype=dtype))
231+
self.do_normalize = do_normalize
186232

187233
self.encoder = TransformerEncoder(
188234
embed_dim=embed_dim,
189235
num_heads=num_heads,
190236
num_layers=num_layers,
237+
do_stable_layer_norm=do_stable_layer_norm,
191238
device=device, dtype=dtype, operations=operations
192239
)
193240

194-
def forward(self, x, mask_time_indices=None, return_dict=False):
195-
241+
def forward(self, x, sr=16000, mask_time_indices=None, return_dict=False):
196242
x = torch.mean(x, dim=1)
197243

198-
x = (x - x.mean()) / torch.sqrt(x.var() + 1e-7)
244+
if self.do_normalize:
245+
x = (x - x.mean()) / torch.sqrt(x.var() + 1e-7)
199246

200247
features = self.feature_extractor(x)
201248
features = self.feature_projection(features)
202-
203249
batch_size, seq_len, _ = features.shape
204250

205251
x, all_x = self.encoder(features)
206-
207252
return x, all_x

0 commit comments

Comments
 (0)