
[ssl/w2vbert] weight copy from meta w2vbert-2.0 #2392

Merged · 7 commits · Mar 11, 2024

Conversation

@Mddct (Collaborator) commented Mar 7, 2024

#2305

The w2vbert-2.0 weights open-sourced by Meta were pretrained on 4.5M hours of data (with streaming support); the model is a conformer, ~600M parameters, with fbank input.

The diff touches only three places;

everything else in the implementation is identical.

This port has three potential major benefits:

  • whisper can be compared against, models can be grafted onto it, etc., giving more finetune options
  • most current speech+LLM approaches use the whisper encoder; this encoder is the same size, giving future speech+LLM work another choice
  • wenet's ssl/ already supports w2vbert pretraining, so with this port wenet effectively gains a pretrained model and more possibilities (e.g. quickly building a pretrained model for emotion embeddings on top of wenet, continued pretraining on in-house data, etc.)

https://twitter.com/reach_vb/status/1750225679898071232
[Screenshot: 2024-03-08 01:21:28]

TODO in this PR:

  • convert (rename checkpoint tensor names)
  • unit test

TODO in follow-up PRs:

  • ASR finetune recipes (egs)
  • port the codebooks to support continued pretraining

@Mddct force-pushed the Mddct-w2vbert-weights branch from a56659e to f3e9bf1 on March 7, 2024 17:06
@Mddct mentioned this pull request Mar 7, 2024
@Mddct force-pushed the Mddct-w2vbert-weights branch from 51ffe46 to 9d10ad1 on March 8, 2024 18:02
@Mddct force-pushed the Mddct-w2vbert-weights branch from bb5b9e2 to 9b0d9f1 on March 9, 2024 19:21
@Mddct (Collaborator, Author) commented Mar 9, 2024

# From wenet/dataset/processor.py, where torch and compute_fbank are in scope.
def compute_w2vbert_fbank(sample,
                          num_mel_bins=23,
                          frame_length=25,
                          frame_shift=10,
                          dither=0.0):
    """ Extract pretrain w2vbert (4.5M hours) fbank
    """
    # Standard wenet fbank extraction, then per-utterance mean/variance
    # normalization over the time axis, matching fairseq2's
    # WaveformToFbankConverter(standardize=True).
    sample = compute_fbank(sample, num_mel_bins, frame_length, frame_shift,
                           dither)
    mat = sample['feat']
    std, mean = torch.std_mean(mat, dim=0)
    mat = mat.subtract(mean).divide(std)
    sample['feat'] = mat
    return sample

This is deleted because the normalization can be folded into the frame-stacking subsampling of w2vbert-conformer-600; when the model is exported later, the runtime then supports it automatically, so no dedicated fbank variant is needed. A rough sketch of the idea follows.
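
A minimal sketch of what folding the normalization into the subsampling front-end could look like; the module name, constructor arguments, and stacking stride here are hypothetical illustrations, not wenet's actual implementation:

import torch


class StackNFramesSubsampling(torch.nn.Module):
    """Hypothetical sketch: per-utterance CMVN folded into frame-stacking
    subsampling, so exported models apply it automatically at runtime."""

    def __init__(self, idim: int, odim: int, stride: int = 2):
        super().__init__()
        self.stride = stride
        self.proj = torch.nn.Linear(idim * stride, odim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, time, idim) raw fbank
        # 1) utterance-level mean/variance normalization, replacing the
        #    deleted compute_w2vbert_fbank post-processing (dim=1 is time)
        std, mean = torch.std_mean(x, dim=1, keepdim=True)
        x = (x - mean) / (std + 1e-10)
        # 2) stack `stride` consecutive frames, then project to model dim
        b, t, f = x.size()
        t = t - t % self.stride
        x = x[:, :t, :].reshape(b, t // self.stride, f * self.stride)
        return self.proj(x)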

@kobenaxie (Contributor) commented:

(quoting the compute_w2vbert_fbank snippet and the deletion rationale above)

Do we need to consider the effect of padding within a batch on the mean and variance computation?

@Mddct changed the title [ssl/w2vbert] weight copy from meta w2vbert-2.0 → [WIP][ssl/w2vbert] weight copy from meta w2vbert-2.0 Mar 10, 2024
@Mddct (Collaborator, Author) commented Mar 11, 2024

This is indeed a problem; I'll run an experiment later to check, or simply replace the frame stacking with convolution. (A padding-aware variant of the statistics is sketched below.)
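
For reference, a minimal sketch of padding-aware statistics that only count valid frames; the helper name and signature are assumptions for illustration, not part of this PR:

import torch


def masked_std_mean(x: torch.Tensor, lengths: torch.Tensor, eps: float = 1e-10):
    """x: (batch, time, dim) padded fbank; lengths: (batch,) valid frame counts."""
    # (batch, time, 1) mask: 1.0 on valid frames, 0.0 on padding
    mask = (torch.arange(x.size(1), device=x.device)[None, :]
            < lengths[:, None]).unsqueeze(-1).to(x.dtype)
    n = lengths.to(x.dtype).view(-1, 1, 1)
    # mean/variance computed over valid frames only, so padding
    # cannot skew the per-utterance statistics
    mean = (x * mask).sum(dim=1, keepdim=True) / n
    var = (((x - mean) * mask)**2).sum(dim=1, keepdim=True) / n
    return (var + eps).sqrt(), mean

Normalizing with these statistics ((x - mean) / std per utterance) gives the same result with or without batch padding.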

@Mddct (Collaborator, Author) commented Mar 11, 2024

fairseq2 has versioning problems, and putting the unit test under wenet/test is too much hassle, so for now the UT lives in this comment:

# pip install fairseq2
from pathlib import Path

import torch
import torchaudio

import fairseq2  # noqa
from fairseq2.data.audio import AudioDecoder, WaveformToFbankConverter
from fairseq2.memory import MemoryBlock

from wenet.dataset import processor
from wenet.transformer.encoder import ConformerEncoder
from wenet.utils.checkpoint import load_checkpoint

# Reference fbank from fairseq2 (standardize=True applies per-utterance CMVN)
fbank_convert = WaveformToFbankConverter(
    num_mel_bins=80,
    waveform_scale=2**15,
    channel_last=True,
    standardize=True,
)
wav_file = '/Users/mddct/Downloads/ckpt/w2vbert2/test.wav'
audio_decoder = AudioDecoder(dtype=torch.float32)
with Path(wav_file).open("rb") as fb:
    block = MemoryBlock(fb.read())
decode_audio = audio_decoder(block)
w2vbert_waveform = decode_audio['waveform']
w2vbert_mat = fbank_convert(decode_audio)['fbank']

# Same fbank via wenet's compute_w2vbert_fbank; both should match closely
wenet_waveform, _ = torchaudio.load(wav_file)
fbank_args = {
    "num_mel_bins": 80,
    "frame_length": 25,
    "frame_shift": 10,
    "dither": 0.0,
}
sample = {'sample_rate': 16000, "wav": wenet_waveform, 'key': wav_file}
wenet_mat = processor.compute_w2vbert_fbank(sample, **fbank_args)['feat']
assert torch.allclose(w2vbert_waveform.transpose(0, 1), wenet_waveform)
assert torch.allclose(w2vbert_mat, wenet_mat, atol=9e-5, rtol=9e-4)

# Reference encoder output from seamless_communication's conformer_shaw
from fairseq2.data import Collater
from fairseq2.nn.padding import get_seqs_and_padding_mask
from seamless_communication.models.conformer_shaw import load_conformer_shaw_model

collater = Collater(pad_value=1)

model = load_conformer_shaw_model("conformer_shaw",
                                  device=torch.device('cpu'),
                                  dtype=torch.float32)
model.eval()
src = collater(fbank_convert(decode_audio))["fbank"]
seqs, padding_mask = get_seqs_and_padding_mask(src)

with torch.inference_mode():
    fairseq_input, fairseq_mask = model.encoder_frontend(seqs, padding_mask)
    fairseq_out, _ = model.encoder(fairseq_input, fairseq_mask)

# wenet encoder config matching the converted w2vbert-2.0 conformer (600M)
configs = {
    'input_dim': 80,
    'output_dim': 1024,
    'encoder': 'conformer',
    'encoder_conf': {
        'causal': True,
        'gradient_checkpointing': True,
        'input_layer': 'stack_n_frames',
        'output_size': 1024,
        'attention_heads': 16,
        'linear_units': 4096,
        'num_blocks': 24,
        'dropout_rate': 0.1,
        'positional_dropout_rate': 0.0,
        'attention_dropout_rate': 0.0,
        'normalize_before': True,
        'use_dynamic_chunk': False,
        'use_dynamic_left_chunk': False,
        'pos_enc_layer_type': 'no_pos',
        'static_chunk_size': -1,
        'activation_type': 'swish',
        'conv_bias': False,
        'selfattention_layer_type': 'shaw_rel_selfattn',
        'cnn_module_kernel': 31,
        'cnn_module_norm': 'layer_norm',
    },
}


class AsrModel(torch.nn.Module):

    def __init__(self, encoder) -> None:
        super().__init__()
        self.encoder = encoder


encoder = ConformerEncoder(input_size=80, **configs['encoder_conf'])
model = AsrModel(encoder)
load_checkpoint(
    model,
    '/Users/mddct/Downloads/ckpt/w2vbert2/wenet_w2vbert/wenet_w2vbert_conformer_600m.pt'
)
model.eval()

# Run the fairseq2 frontend output through wenet's conformer layers and
# compare against the fairseq2 encoder output. fairseq_out is
# (batch, time, dim), so the all-ones mask covers the time axis (size(1)).
mask = torch.ones(1, 1, fairseq_out.size(1), dtype=torch.bool)
wenet_out = fairseq_input
with torch.inference_mode():
    for layer in encoder.encoders:
        wenet_out, wenet_out_mask, _, _ = layer(wenet_out, mask,
                                                torch.zeros(0))
assert torch.allclose(fairseq_out, wenet_out, atol=9e-6)
[Screenshot: 2024-03-11 16:22:51]

@Mddct force-pushed the Mddct-w2vbert-weights branch from f6571d9 to 4fba433 on March 11, 2024 08:26
@Mddct (Collaborator, Author) commented Mar 11, 2024

Four things worth noting:

  • w2vbert2's conformer uses frame stacking for subsampling; later we could simply switch to a conv2d4 front-end, as wenet's whisper does (see the sketch after this list)
  • when extracting fbank, they compute mean and variance over dim=0 (the whole utterance), which is bad for streaming (officially the model is said to be streaming, so this is a problem); later finetuning should switch to wenet's compute_fbank
  • later we could consider bringing over the whisper decoder
  • unlike the various other conformer implementations, there is no final after-norm here
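
A minimal sketch of the conv2d4-style front-end mentioned in the first point, assuming a generic two-layer stride-2 design (an illustration, not wenet's exact Conv2dSubsampling4, which also handles masks and positional encoding):

import torch


class Conv2dSubsampling4Sketch(torch.nn.Module):
    """Hypothetical sketch: two stride-2 convolutions give 4x time
    reduction, a drop-in alternative to frame stacking."""

    def __init__(self, idim: int = 80, odim: int = 1024):
        super().__init__()
        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(1, odim, 3, 2),
            torch.nn.ReLU(),
            torch.nn.Conv2d(odim, odim, 3, 2),
            torch.nn.ReLU(),
        )
        # frequency axis shrinks the same way as time: ((idim-1)//2 - 1)//2
        self.proj = torch.nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, time, idim) fbank -> (batch, ~time/4, odim)
        x = self.conv(x.unsqueeze(1))  # (batch, odim, time', freq')
        b, c, t, f = x.size()
        return self.proj(x.transpose(1, 2).reshape(b, t, c * f))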

@Mddct requested review from robin1001 and xingchensong March 11, 2024 08:38
@Mddct changed the title [WIP][ssl/w2vbert] weight copy from meta w2vbert-2.0 → [ssl/w2vbert] weight copy from meta w2vbert-2.0 Mar 11, 2024
@xingchensong merged commit 3258fb6 into main Mar 11, 2024
5 of 6 checks passed
@xingchensong deleted the Mddct-w2vbert-weights branch March 11, 2024 09:03
@thsxbw commented Jun 21, 2024

Is anyone still working on this PR?
