WAVLM_BASE pipeline shows different hidden states when compared to the HuggingFace model. #3347

Closed
SevKod opened this issue May 19, 2023 · 7 comments
SevKod commented May 19, 2023

🐛 Describe the bug

When comparing the WAVLM_BASE model from the Torchaudio pipeline with the one HuggingFace provides (here: https://huggingface.co/microsoft/wavlm-base), it appears that the returned hidden states differ for every layer EXCEPT the first.
I provide a Colab example below:

https://colab.research.google.com/drive/1gEapCvAXCTCtauUjF_DvxW5nkiznbdsH?usp=sharing

What I find weird is that both models return the same first hidden state, even though the other layers differ.
I tested the HuBERT pipeline against the HuggingFace one and the layers were identical, so the issue seems to affect only some models (such as WavLM).
I tried tweaking the configuration files of both models to get the same outputs, but without success. Some parameters might differ between the architectures, although after inspecting them I did not see a major difference.
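For context, a minimal version of such a comparison might look like the sketch below (this is not the notebook itself; the file name, tolerance, and the offset into HF's hidden_states are assumptions):

# Rough comparison sketch (hypothetical "sample.wav", assumed to be 16 kHz mono audio).
import torch
import torchaudio
from transformers import WavLMModel

waveform, sr = torchaudio.load("sample.wav")

# torchaudio pipeline: extract_features returns one tensor per transformer layer
ta_model = torchaudio.pipelines.WAVLM_BASE.get_model()
with torch.inference_mode():
    ta_feats, _ = ta_model.extract_features(waveform)

# HuggingFace model: hidden_states[0] comes before the first transformer layer
hf_model = WavLMModel.from_pretrained("microsoft/wavlm-base")
with torch.inference_mode():
    hf_hidden = hf_model(waveform, output_hidden_states=True).hidden_states

for i, ta in enumerate(ta_feats):
    print(i, torch.allclose(ta, hf_hidden[i + 1], atol=1e-4))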

Versions

PyTorch version: 2.0.1+cu118
Is debug build: False
CUDA used to build PyTorch: 11.8
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.5 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0
Clang version: 10.0.0-4ubuntu1
CMake version: version 3.25.2
Libc version: glibc-2.31

Python version: 3.10.11 (main, Apr 5 2023, 14:15:10) [GCC 9.4.0] (64-bit runtime)
Python platform: Linux-5.15.107+-x86_64-with-glibc2.31
Is CUDA available: False
CUDA runtime version: 11.8.89
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: Could not collect
Nvidia driver version: Could not collect
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.7.0
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
Address sizes: 46 bits physical, 48 bits virtual
CPU(s): 2
On-line CPU(s) list: 0,1
Thread(s) per core: 2
Core(s) per socket: 1
Socket(s): 1
NUMA node(s): 1
Vendor ID: GenuineIntel
CPU family: 6
Model: 79
Model name: Intel(R) Xeon(R) CPU @ 2.20GHz
Stepping: 0
CPU MHz: 2200.216
BogoMIPS: 4400.43
Hypervisor vendor: KVM
Virtualization type: full
L1d cache: 32 KiB
L1i cache: 32 KiB
L2 cache: 256 KiB
L3 cache: 55 MiB
NUMA node0 CPU(s): 0,1
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Mitigation; PTE Inversion
Vulnerability Mds: Vulnerable; SMT Host state unknown
Vulnerability Meltdown: Vulnerable
Vulnerability Mmio stale data: Vulnerable
Vulnerability Retbleed: Vulnerable
Vulnerability Spec store bypass: Vulnerable
Vulnerability Spectre v1: Vulnerable: __user pointer sanitization and usercopy barriers only; no swapgs barriers
Vulnerability Spectre v2: Vulnerable, IBPB: disabled, STIBP: disabled, PBRSB-eIBRS: Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Vulnerable
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single ssbd ibrs ibpb stibp fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm rdseed adx smap xsaveopt arat md_clear arch_capabilities

Versions of relevant libraries:
[pip3] numpy==1.22.4
[pip3] torch==2.0.1+cu118
[pip3] torchaudio==2.0.2+cu118
[pip3] torchdata==0.6.1
[pip3] torchsummary==1.5.1
[pip3] torchtext==0.15.2
[pip3] torchvision==0.15.2+cu118
[pip3] triton==2.0.0
[conda] Could not collect

SevKod changed the title from "WAVLM_BASE pipeline shows different hidden states when compared to HuggingFace model." to "WAVLM_BASE pipeline shows different hidden states when compared to the HuggingFace model." on May 19, 2023
nateanl (Member) commented May 19, 2023

Hi @SevKod, thanks for reporting the issue. The cause is that in the get_intermediate_outputs method of the Transformer model, position_bias is ignored; however, it should be passed as input to the next transformer layer:

x, _ = layer(x, attention_mask) # Ignore position_bias

To get output identical to HuggingFace, you can try this script:

# Load the torchaudio model and run feature extraction manually,
# threading position_bias through every transformer layer.
# (Assumes `device` and `waveforms` are defined as in the Colab example.)
import torchaudio

bundle = torchaudio.pipelines.WAVLM_BASE
model_torch = bundle.get_model().to(device)

features = []
x, _ = model_torch.feature_extractor(waveforms, None)
x, masks = model_torch.encoder._preprocess(x, None)
position_bias = None
x = model_torch.encoder.transformer._preprocess(x)
for layer in model_torch.encoder.transformer.layers:
    x, position_bias = layer(x, masks, position_bias=position_bias)
    features.append(x)
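A quick way to check the workaround against HuggingFace could look like this (a sketch; the atol value and the i + 1 offset past HF's pre-transformer hidden state are assumptions):

# Hypothetical comparison with the HuggingFace model, reusing `waveforms` and `device` from above.
import torch
from transformers import WavLMModel

hf_model = WavLMModel.from_pretrained("microsoft/wavlm-base").to(device)
with torch.inference_mode():
    hf_hidden = hf_model(waveforms, output_hidden_states=True).hidden_states

for i, feat in enumerate(features):
    print(i, torch.allclose(feat, hf_hidden[i + 1], atol=1e-4))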

I will make a quick fix today. Thanks for looking into it.

SevKod (Author) commented May 20, 2023

Thanks a lot @nateanl for answering so quickly and for the fix!

I might have found another issue, this time regarding the WavLM-LARGE model (still about feature extraction). I don't know whether I should open a separate issue, but here it is:

When extracting features, the WavLM_Large feature extraction method returns the same weights in every column of the output tensors (past the 5th layer). Here is an example:

https://colab.research.google.com/drive/1I_zvmY3DomzQyGxe402BOAAZ5CnN4ToQ?usp=sharing

As seen in the example above, the computed weights appear to be the same in every column of the output tensor (this is mainly visible when printing the 5th layer and below). This does not appear to be normal (or maybe there is a different way to use this model?).
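A rough way to see whether the rows or columns of a layer's output are repeated (assuming `features` comes from the WAVLM_LARGE bundle's extract_features, as in the notebook) might be:

# Hypothetical check: compare every row/column of one layer's output to the first one.
import torch

layer_out = features[5][0]  # (frames, hidden_dim) for the first item in the batch
print("all rows identical:   ", torch.allclose(layer_out, layer_out[:1].expand_as(layer_out)))
print("all columns identical:", torch.allclose(layer_out, layer_out[:, :1].expand_as(layer_out)))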
In any case, thank you in advance for your help!

mthrok (Collaborator) commented May 24, 2023

What happens with #3350 applied, i.e., if we try the nightly build?

SevKod (Author) commented May 24, 2023

> What happens with #3350 applied, i.e., if we try the nightly build?

I just tested it, and it still appears to be the case:

https://colab.research.google.com/drive/1H_HLBUcyQS7u4lHMWXOwYIYN81rS5tIm#scrollTo=_c68RCbaBcAX

sgrigory (Contributor) commented Jun 25, 2023

The difference between the WavLM hidden states in TorchAudio and HF Transformers seems to come from the model configuration. To reproduce the HF result, one needs to set encoder_layer_norm_first = True when creating the WavLM model in TorchAudio (this parameter is False by default).

Indeed, in WavLMEncoderStableLayerNorm, which HF uses by default, the norm is applied at the beginning of the layer, whereas in TorchAudio this depends on EncoderLayer.layer_norm_first.

In this notebook (a slightly modified version of @SevKod's) I patched the bundle config of WAVLM_LARGE with encoder_layer_norm_first = True, and the hidden states produced by TorchAudio matched those of HF, except for the first and the last layer. This remaining difference comes from how "hidden state" is defined: in TorchAudio the value is saved after a layer, while in HF it is saved before a layer.
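A sketch of that patch (the `_params` dict is a private, version-dependent attribute of the pipeline bundle, so treat this as an assumption rather than a supported API):

# Hypothetical patch: build WAVLM_LARGE with encoder_layer_norm_first = True.
import copy
import torchaudio

bundle = copy.deepcopy(torchaudio.pipelines.WAVLM_LARGE)
bundle._params["encoder_layer_norm_first"] = True  # private attribute; may change between releases
model = bundle.get_model()  # hidden states should now line up with HF (up to the first/last-layer convention)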

Note that WAVLM_BASE and WAVLM_BASE_PLUS additionally set _normalize_waveform to False by default, which would also create a difference from the HF output (if preprocessing the waveform with transformers.Wav2Vec2FeatureExtractor).
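For illustration, the HF preprocessing step being referred to is roughly the following (hypothetical file name; whether normalization is actually applied depends on the checkpoint's do_normalize setting):

# Hypothetical preprocessing with the HF feature extractor ("sample.wav" is a placeholder, assumed 16 kHz mono).
# If do_normalize is True in the checkpoint's preprocessor config, input_values is zero-mean/unit-variance
# normalized, a step that a bundle with _normalize_waveform=False skips.
import torchaudio
from transformers import Wav2Vec2FeatureExtractor

waveform, sr = torchaudio.load("sample.wav")
fe = Wav2Vec2FeatureExtractor.from_pretrained("microsoft/wavlm-base")
inputs = fe(waveform.squeeze(0).numpy(), sampling_rate=sr, return_tensors="pt")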

@nateanl @mthrok, perhaps we could modify the default config for the WavLM bundles? Features generated with encoder_layer_norm_first = True seem to make more sense, since they don't show the repeated rows described above.

mthrok (Collaborator) commented Jun 26, 2023

@sgrigory Thanks for the investigation. Just to be sure, your suggestion is to modify the configuration defined in torchaudio.pipelines, right? Then we can go ahead and do it.

mthrok pushed a commit that referenced this issue Oct 17, 2023
The `encoder_layer_norm_first` should be set to True for the Large model of WavLM.
Address #3347
mthrok (Collaborator) commented Oct 17, 2023

Addressed via #3660

mthrok closed this as completed Oct 17, 2023