Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
023a033
Support BatchNorm in Hubert pos_conv_emb as in fairseq
gallilmaimon Oct 24, 2024
fb32b55
Correct the new defaults (#34377)
Cyrilvallez Oct 24, 2024
c4ab8a5
[auto. ping] Avoid sending empty info + add more team members (#34383)
ydshieh Oct 24, 2024
b2a7b11
Fix glm (#34388)
Cyrilvallez Oct 24, 2024
5289130
Use non nested images and batched text Idefics2/3 (#34222)
yonigozlan Oct 25, 2024
86468ad
Fix onnx non-expotable inplace aten op (#34376)
IlyasMoutawwakil Oct 25, 2024
cfe1e14
Fix right padding in LLaVA models (#34305)
zucchini-nlp Oct 25, 2024
337621a
no filter (#34391)
ydshieh Oct 25, 2024
3ae703b
SynthID: better example (#34372)
gante Oct 25, 2024
a0ccf20
Tests: upgrade `test_eager_matches_sdpa_generate` (#34386)
gante Oct 25, 2024
3c3e153
Fix bnb training test failure (#34414)
matthewdouglas Oct 25, 2024
2dded53
Avoid check expected exception when it is on CUDA (#34408)
ydshieh Oct 25, 2024
d8edfcb
Fix typos in agents_advanced.md (#34405)
rudydel Oct 25, 2024
3398913
[docs] Cache implementations (#34325)
stevhliu Oct 25, 2024
75f0689
[run-slow] hubert
gallilmaimon Oct 26, 2024
a3042a0
Support BatchNorm in Hubert pos_conv_emb as in fairseq
gallilmaimon Oct 26, 2024
a0a2731
Support BatchNorm in Hubert pos_conv_emb as in fairseq
gallilmaimon Oct 26, 2024
987d521
Merge branch 'huggingface:main' into add_hubert_conv_emb_batchnorm_su…
gallilmaimon Oct 26, 2024
ce40909
[run-slow] hubert
gallilmaimon Oct 26, 2024
391ea79
Support BatchNorm in Hubert pos_conv_emb as in fairseq
gallilmaimon Oct 24, 2024
7bbc7b4
[run-slow] hubert
gallilmaimon Oct 26, 2024
3e7f77e
Support BatchNorm in Hubert pos_conv_emb as in fairseq
gallilmaimon Oct 26, 2024
2ca473f
Support BatchNorm in Hubert pos_conv_emb as in fairseq
gallilmaimon Oct 26, 2024
9f167a2
[run-slow] hubert
gallilmaimon Oct 26, 2024
af1d65e
Merge branch 'main' into add_hubert_conv_emb_batchnorm_support
ylacombe Nov 26, 2024
eaed17f
Merge branch 'add_hubert_conv_emb_batchnorm_support' of https://githu…
gallilmaimon Nov 26, 2024
61d8ad0
[run-slow] hubert
gallilmaimon Nov 26, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/transformers/models/hubert/configuration_hubert.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,8 @@ class HubertConfig(PretrainedConfig):
embeddings layer.
num_conv_pos_embedding_groups (`int`, *optional*, defaults to 16):
Number of groups of 1D convolutional positional embeddings layer.
conv_pos_batch_norm (`bool`, *optional*, defaults to `False`):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we precise (for bf16 models) out of curiosity ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's remove it then

Whether to use batch norm instead of weight norm in conv_pos
do_stable_layer_norm (`bool`, *optional*, defaults to `False`):
Whether do apply *stable* layer norm architecture of the Transformer encoder. `do_stable_layer_norm is
True` corresponds to applying layer norm before the attention layer, whereas `do_stable_layer_norm is
Expand Down Expand Up @@ -182,6 +184,7 @@ def __init__(
conv_bias=False,
num_conv_pos_embeddings=128,
num_conv_pos_embedding_groups=16,
conv_pos_batch_norm=False,
do_stable_layer_norm=False,
apply_spec_augment=True,
mask_time_prob=0.05,
Expand Down Expand Up @@ -209,6 +212,7 @@ def __init__(
self.conv_bias = conv_bias
self.num_conv_pos_embeddings = num_conv_pos_embeddings
self.num_conv_pos_embedding_groups = num_conv_pos_embedding_groups
self.conv_pos_batch_norm = conv_pos_batch_norm
self.num_feat_extract_layers = len(self.conv_dim)
self.num_hidden_layers = num_hidden_layers
self.intermediate_size = intermediate_size
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@

MAPPING = {
"post_extract_proj": "feature_projection.projection",
"encoder.pos_conv.0": "encoder.pos_conv_embed.conv",
"encoder.pos_conv.0": "encoder.pos_conv_embed.batch_norm",
"encoder.pos_conv.1": "encoder.pos_conv_embed.conv",
"self_attn.k_proj": "encoder.layers.*.attention.k_proj",
"self_attn.v_proj": "encoder.layers.*.attention.v_proj",
"self_attn.q_proj": "encoder.layers.*.attention.q_proj",
Expand Down Expand Up @@ -76,6 +77,12 @@ def set_recursively(hf_pointer, key, value, full_name, weight_type):
hf_pointer.weight_v.data = value
elif weight_type == "bias":
hf_pointer.bias.data = value
elif weight_type == "running_mean":
hf_pointer.running_mean.data = value
elif weight_type == "running_var":
hf_pointer.running_var.data = value
elif weight_type == "num_batches_tracked":
hf_pointer.num_batches_tracked.data = value
else:
hf_pointer.data = value

Expand Down Expand Up @@ -116,6 +123,12 @@ def recursively_load_weights(fairseq_model, hf_model, is_finetuned):
weight_type = "weight"
elif "bias" in name:
weight_type = "bias"
elif "running_mean" in name:
weight_type = "running_mean"
elif "running_var" in name:
weight_type = "running_var"
elif "num_batches_tracked" in name:
weight_type = "num_batches_tracked"
else:
weight_type = None
set_recursively(hf_model, mapped_key, value, name, weight_type)
Expand Down
40 changes: 22 additions & 18 deletions src/transformers/models/hubert/modeling_hubert.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,6 @@ def forward(self, hidden_states):
return hidden_states


# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2PositionalConvEmbedding with Wav2Vec2->Hubert
class HubertPositionalConvEmbedding(nn.Module):
def __init__(self, config):
super().__init__()
Expand All @@ -272,32 +271,37 @@ def __init__(self, config):
groups=config.num_conv_pos_embedding_groups,
)

weight_norm = nn.utils.weight_norm
if hasattr(nn.utils.parametrizations, "weight_norm"):
weight_norm = nn.utils.parametrizations.weight_norm
self.batch_norm = None
if config.conv_pos_batch_norm:
self.batch_norm = nn.BatchNorm1d(config.hidden_size)
else:
weight_norm = nn.utils.weight_norm
if hasattr(nn.utils.parametrizations, "weight_norm"):
weight_norm = nn.utils.parametrizations.weight_norm

if is_deepspeed_zero3_enabled():
import deepspeed
if is_deepspeed_zero3_enabled():
import deepspeed

with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0):
self.conv = weight_norm(self.conv, name="weight", dim=2)
if hasattr(self.conv, "parametrizations"):
weight_g = self.conv.parametrizations.weight.original0
weight_v = self.conv.parametrizations.weight.original1
with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0):
self.conv = weight_norm(self.conv, name="weight", dim=2)
if hasattr(self.conv, "parametrizations"):
weight_g = self.conv.parametrizations.weight.original0
weight_v = self.conv.parametrizations.weight.original1
else:
weight_g = self.conv.weight_g
weight_v = self.conv.weight_v
deepspeed.zero.register_external_parameter(self, weight_v)
deepspeed.zero.register_external_parameter(self, weight_g)
else:
weight_g = self.conv.weight_g
weight_v = self.conv.weight_v
deepspeed.zero.register_external_parameter(self, weight_v)
deepspeed.zero.register_external_parameter(self, weight_g)
else:
self.conv = weight_norm(self.conv, name="weight", dim=2)
self.conv = weight_norm(self.conv, name="weight", dim=2)

self.padding = HubertSamePadLayer(config.num_conv_pos_embeddings)
self.activation = ACT2FN[config.feat_extract_activation]

def forward(self, hidden_states):
hidden_states = hidden_states.transpose(1, 2)

if self.batch_norm is not None:
hidden_states = self.batch_norm(hidden_states)
hidden_states = self.conv(hidden_states)
hidden_states = self.padding(hidden_states)
hidden_states = self.activation(hidden_states)
Expand Down
37 changes: 37 additions & 0 deletions tests/models/hubert/test_modeling_hubert.py
Original file line number Diff line number Diff line change
Expand Up @@ -943,3 +943,40 @@ def test_inference_distilhubert(self):
self.assertTrue(torch.allclose(outputs[:, :4, :4], expected_outputs_first, atol=5e-3))
self.assertTrue(torch.allclose(outputs[:, -4:, -4:], expected_outputs_last, atol=5e-3))
self.assertTrue(abs(outputs.sum() - expected_output_sum) < 0.1)

def test_inference_hubert_25hz(self):
model = HubertModel.from_pretrained("slprl/mhubert-base-25hz").to(torch_device)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would be nice to open a PR to the original repo and use pr branch revision in the mean time!


sample = self._load_datasamples(1)
input_speech = torch.tensor(sample[0], dtype=torch.float, device=torch_device).unsqueeze(0)

with torch.no_grad():
outputs = model(input_speech, output_hidden_states=True).hidden_states[11]

# expected outputs taken from the original textlesslib implementation by:
# model = SpeechEncoder.by_name(dense_model_name='mhubert-base-25hz', quantizer_model_name='kmeans',
# vocab_size=500, deduplicate=False, need_f0=False)
# model(wav)['dense']
expected_outputs_first = torch.tensor(
[
[0.0267, 0.1776, -0.1706, -0.4559],
[-0.2430, -0.2943, -0.1864, -0.1187],
[-0.1812, -0.4239, -0.1916, -0.0858],
[-0.1495, -0.4758, -0.4036, 0.0302],
],
device=torch_device,
)
expected_outputs_last = torch.tensor(
[
[0.3366, -0.2734, -0.1415, -0.3055],
[0.2329, -0.3580, -0.1421, -0.3197],
[0.1631, -0.4301, -0.1965, -0.2956],
[0.3342, -0.2185, -0.2253, -0.2363],
],
device=torch_device,
)
expected_output_sum = 1681.7603

self.assertTrue(torch.allclose(outputs[:, :4, :4], expected_outputs_first, atol=5e-3))
self.assertTrue(torch.allclose(outputs[:, -4:, -4:], expected_outputs_last, atol=5e-3))
self.assertTrue(abs(outputs.sum() - expected_output_sum) < 0.1)