diff --git a/examples/pytorch/_tests_requirements.txt b/examples/pytorch/_tests_requirements.txt
index ce2394d8968a..f5b4b5170d5d 100644
--- a/examples/pytorch/_tests_requirements.txt
+++ b/examples/pytorch/_tests_requirements.txt
@@ -3,6 +3,7 @@ scikit-learn
 seqeval
 psutil
 sacrebleu >= 1.4.12
+accelerate >= 0.5.0
 rouge-score
 tensorflow_datasets
 matplotlib
diff --git a/examples/pytorch/speech-pretraining/README.md b/examples/pytorch/speech-pretraining/README.md
new file mode 100644
index 000000000000..0e6795a61a68
--- /dev/null
+++ b/examples/pytorch/speech-pretraining/README.md
@@ -0,0 +1,124 @@
+
+
+# Speech Recognition Pre-Training
+
+
+## Wav2Vec2 Speech Pre-Training
+
+The script [`run_wav2vec2_pretraining_no_trainer.py`](https://github.com/huggingface/transformers/blob/master/examples/pytorch/speech-pretraining/run_wav2vec2_pretraining_no_trainer.py) can be used to pre-train a [Wav2Vec2](https://huggingface.co/transformers/model_doc/wav2vec2.html?highlight=wav2vec2) model from scratch.
+
+In the script [`run_wav2vec2_pretraining_no_trainer.py`](https://github.com/huggingface/transformers/blob/master/examples/pytorch/speech-pretraining/run_wav2vec2_pretraining_no_trainer.py), a Wav2Vec2 model is pre-trained on audio data alone using [Wav2Vec2's contrastive loss objective](https://arxiv.org/abs/2006.11477).
+
+The following examples show how to pre-train a `"base"`-sized Wav2Vec2 model as well as a `"large"`-sized Wav2Vec2 model using [`accelerate`](https://github.com/huggingface/accelerate).
+
+
+---
+**NOTE 1**
+
+Wav2Vec2's pre-training is known to be quite unstable.
+It is advised to do a couple of test runs with a smaller dataset,
+*e.g.* `--dataset_config_names clean clean`, `--dataset_split_names validation test`,
+to find good hyper-parameters for `learning_rate`, `batch_size`, `num_warmup_steps`,
+and the optimizer.
+A good metric to observe during training is the gradient norm, which should ideally lie between 0.5 and 2.
+
+---
+
+---
+**NOTE 2**
+
+When training on large datasets, it is recommended to first run the data preprocessing
+in a **non-distributed** mode via `--preprocessing_only`, so that the cached, preprocessed
+data can easily be loaded on each device when the script is subsequently run in
+**distributed** mode. A minimal sketch of such a preprocessing-only run is shown in the
+demo section below.
+
+---
+
+### Demo
+
+In this demo run we pre-train a `"base"`-sized Wav2Vec2 model only on the validation
+and test data of [librispeech_asr](https://huggingface.co/datasets/librispeech_asr).
+
+The demo is run on two Titan RTX (24 GB RAM each). In case you have less RAM available
+per device, consider reducing `--per_device_train_batch_size` and/or `--max_duration_in_seconds`.
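+
+Before launching the full demo command below, the preprocessing pass described in
+**NOTE 2** can optionally be cached first in non-distributed mode. The following is only
+a minimal sketch that reuses the demo's dataset arguments; the number of preprocessing
+workers is an arbitrary example value and should be adapted to your machine:
+
+```bash
+python run_wav2vec2_pretraining_no_trainer.py \
+	--dataset_name="librispeech_asr" \
+	--dataset_config_names clean clean \
+	--dataset_split_names validation test \
+	--model_name_or_path="patrickvonplaten/wav2vec2-base-v2" \
+	--output_dir="./wav2vec2-pretrained-demo" \
+	--max_duration_in_seconds="20.0" \
+	--min_duration_in_seconds="2.0" \
+	--preprocessing_num_workers="16" \
+	--preprocessing_only
+```
+
+The full pre-training command for the demo is then: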
+
+
+```bash
+accelerate launch run_wav2vec2_pretraining_no_trainer.py \
+	--dataset_name="librispeech_asr" \
+	--dataset_config_names clean clean \
+	--dataset_split_names validation test \
+	--model_name_or_path="patrickvonplaten/wav2vec2-base-v2" \
+	--output_dir="./wav2vec2-pretrained-demo" \
+	--max_train_steps="20000" \
+	--num_warmup_steps="32000" \
+	--gradient_accumulation_steps="8" \
+	--learning_rate="0.005" \
+	--weight_decay="0.01" \
+	--max_duration_in_seconds="20.0" \
+	--min_duration_in_seconds="2.0" \
+	--logging_steps="1" \
+	--saving_steps="10000" \
+	--per_device_train_batch_size="8" \
+	--per_device_eval_batch_size="8" \
+	--adam_beta1="0.9" \
+	--adam_beta2="0.98" \
+	--adam_epsilon="1e-06" \
+	--gradient_checkpointing
+```
+
+The results of this run can be seen [here](https://wandb.ai/patrickvonplaten/wav2vec2-pretrained-demo/reports/Wav2Vec2-PreTraining-Demo-Run--VmlldzoxMDk3MjAw?accessToken=oa05s1y57lizo2ocxy3k01g6db1u4pt8m6ur2n8nl4cb0ug02ms2cw313kb8ruch).
+
+### Base
+
+TODO (currently running...)
+
+
+### Large
+
+To pre-train a `"large"`-sized Wav2Vec2 model, *e.g.* [facebook/wav2vec2-large-lv60](https://huggingface.co/facebook/wav2vec2-large-lv60),
+on [librispeech_asr](https://huggingface.co/datasets/librispeech_asr), the following command can be run:
+
+```bash
+accelerate launch run_wav2vec2_pretraining_no_trainer.py \
+	--dataset_name=librispeech_asr \
+	--dataset_config_names clean clean other \
+	--dataset_split_names train.100 train.360 train.500 \
+	--output_dir=./test \
+	--max_train_steps=200000 \
+	--num_warmup_steps=32000 \
+	--gradient_accumulation_steps=8 \
+	--learning_rate=0.001 \
+	--weight_decay=0.01 \
+	--max_duration_in_seconds=20.0 \
+	--min_duration_in_seconds=2.0 \
+	--model_name_or_path=./ \
+	--logging_steps=1 \
+	--saving_steps=10000 \
+	--per_device_train_batch_size=2 \
+	--per_device_eval_batch_size=4 \
+	--adam_beta1=0.9 \
+	--adam_beta2=0.98 \
+	--adam_epsilon=1e-06 \
+	--gradient_checkpointing
+```
+
+The experiment was run on 8 V100 GPUs (16 GB RAM each) for 7 days.
+In case you have more than 8 GPUs available, and hence a higher effective `batch_size`,
+it is recommended to increase the `learning_rate` to `0.005` for faster convergence.
+
+The results of this run can be seen [here](https://wandb.ai/patrickvonplaten/pretraining-wav2vec2/reports/Wav2Vec2-Large--VmlldzoxMTAwODM4?accessToken=wm3qzcnldrwsa31tkvf2pdmilw3f63d4twtffs86ou016xjbyilh55uoi3mo1qzc) and the checkpoint pre-trained for 120,000 steps can be accessed [here](https://huggingface.co/patrickvonplaten/wav2vec2-large-repro-960h-libri-120k-steps).
diff --git a/examples/pytorch/speech-pretraining/requirements.txt b/examples/pytorch/speech-pretraining/requirements.txt
new file mode 100644
index 000000000000..ea09a02c2d6e
--- /dev/null
+++ b/examples/pytorch/speech-pretraining/requirements.txt
@@ -0,0 +1,4 @@
+datasets >= 1.12.0
+torch >= 1.5
+torchaudio
+accelerate >= 0.5.0
diff --git a/examples/pytorch/speech-pretraining/run_wav2vec2_pretraining_no_trainer.py b/examples/pytorch/speech-pretraining/run_wav2vec2_pretraining_no_trainer.py
new file mode 100755
index 000000000000..657c6e844b35
--- /dev/null
+++ b/examples/pytorch/speech-pretraining/run_wav2vec2_pretraining_no_trainer.py
@@ -0,0 +1,700 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and + +""" Pre-Training a 🤗 Wav2Vec2 model on unlabeled audio data """ + +import argparse +import logging +import math +import os +from dataclasses import dataclass +from pathlib import Path +from typing import Dict, List, Optional, Union + +import datasets +import torch +import torchaudio +from datasets import DatasetDict, concatenate_datasets, load_dataset +from torch.utils.data.dataloader import DataLoader +from tqdm.auto import tqdm + +import transformers +from accelerate import Accelerator +from huggingface_hub import Repository +from transformers import ( + AdamW, + SchedulerType, + Wav2Vec2Config, + Wav2Vec2FeatureExtractor, + Wav2Vec2ForPreTraining, + get_scheduler, + is_wandb_available, + set_seed, +) +from transformers.file_utils import get_full_repo_name +from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices, _sample_negative_indices + + +logger = logging.getLogger(__name__) + + +def parse_args(): + parser = argparse.ArgumentParser(description="Finetune a transformers model on a text classification task") + parser.add_argument( + "--dataset_name", + type=str, + default=None, + help="The name of the dataset to use (via the datasets library).", + ) + parser.add_argument( + "--dataset_config_names", + nargs="+", + type=str, + required=True, + help="The configuration names of the dataset to use (via the datasets library).", + ) + parser.add_argument( + "--dataset_split_names", + nargs="+", + type=str, + required=True, + help="The names of the training data set splits to use (via the datasets library).", + ) + parser.add_argument( + "--preprocessing_num_workers", + type=int, + default=None, + help="The number of processes to use for the preprocessing.", + ) + parser.add_argument( + "--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets" + ) + parser.add_argument( + "--preprocessing_only", + action="store_true", + help="Only run the preprocessing script to be cached for future use", + ) + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="Where do you want to store the pretrained models downloaded from huggingface.co", + ) + parser.add_argument( + "--validation_split_percentage", + type=int, + default=1, + help="Percentage of training data that should be used for validation if no validation is present in dataset.", + ) + parser.add_argument( + "--logging_steps", + type=int, + default=500, + help="Number of steps between each logging", + ) + parser.add_argument( + "--saving_steps", + type=int, + default=500, + help="Number of steps between each logging", + ) + parser.add_argument( + "--audio_column_name", + type=str, + default="file", + help="Column in the dataset that contains speech file path. 
Defaults to 'file'", + ) + parser.add_argument( + "--model_name_or_path", + type=str, + help="Path to pretrained model or model identifier from huggingface.co/models.", + required=True, + ) + parser.add_argument( + "--config_name", + type=str, + default=None, + help="Pretrained config name or path if not the same as model_name", + ) + parser.add_argument( + "--per_device_train_batch_size", + type=int, + default=8, + help="Batch size (per device) for the training dataloader.", + ) + parser.add_argument( + "--per_device_eval_batch_size", + type=int, + default=8, + help="Batch size (per device) for the evaluation dataloader.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=5e-5, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.") + parser.add_argument("--num_train_epochs", type=int, default=3, help="Total number of training epochs to perform.") + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="If True, use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--lr_scheduler_type", + type=SchedulerType, + default="linear", + help="The scheduler type to use.", + choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"], + ) + parser.add_argument( + "--num_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument("--output_dir", type=str, default=None, help="Where to store the final model.") + parser.add_argument("--seed", type=int, default=0, help="A seed for reproducible training.") + parser.add_argument( + "--max_gumbel_temperature", + type=float, + default=2.0, + help="Maximum temperature for gumbel softmax.", + ) + parser.add_argument( + "--min_gumbel_temperature", + type=float, + default=0.5, + help="Minimum temperature for gumbel softmax.", + ) + parser.add_argument( + "--gumbel_temperature_decay", type=float, default=0.999995, help="Decay of gumbel temperature during training." + ) + parser.add_argument( + "--max_duration_in_seconds", + type=float, + default=5.0, + help="Filter out audio files that are longer than `max_duration_in_seconds` seconds", + ) + parser.add_argument( + "--min_duration_in_seconds", + type=float, + default=3.0, + help="Filter out audio files that are shorter than `min_duration_in_seconds` seconds", + ) + parser.add_argument( + "--pad_to_multiple_of", + type=int, + default=None, + help="If set will pad the sequence to a multiple of the provided value. 
This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >= 7.5 (Volta).", + ) + parser.add_argument( + "--adam_beta1", + type=float, + default=0.9, + help="Beta1 for AdamW optimizer", + ) + parser.add_argument( + "--adam_beta2", + type=float, + default=0.999, + help="Beta2 for AdamW optimizer", + ) + parser.add_argument( + "--adam_epsilon", + type=float, + default=1e-8, + help="Epsilon for AdamW optimizer", + ) + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument( + "--hub_model_id", type=str, help="The name of the repository to keep in sync with the local `output_dir`." + ) + parser.add_argument("--hub_token", type=str, help="The token to use to push to the Model Hub.") + args = parser.parse_args() + + if args.push_to_hub: + assert args.output_dir is not None, "Need an `output_dir` to create a repo when `--push_to_hub` is passed." + + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + return args + + +@dataclass +class DataCollatorForWav2Vec2Pretraining: + """ + Data collator that will dynamically pad the inputs received and prepare masked indices + for self-supervised pretraining. + + Args: + model (:class:`~transformers.Wav2Vec2ForPreTraining`): + The Wav2Vec2 model used for pretraining. The data collator needs to have access + to config and ``_get_feat_extract_output_lengths`` function for correct padding. + feature_extractor (:class:`~transformers.Wav2Vec2FeatureExtractor`): + The processor used for proccessing the data. + padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`): + Select a strategy to pad the returned sequences (according to the model's padding side and padding index) + among: + * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single + sequence if provided). + * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the + maximum acceptable input length for the model if that argument is not provided. + * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of + different lengths). + max_length (:obj:`int`, `optional`): + Maximum length of the ``input_values`` of the returned list and optionally padding length (see above). + pad_to_multiple_of (:obj:`int`, `optional`): + If set will pad the sequence to a multiple of the provided value. + This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >= + 7.5 (Volta). 
+ """ + + model: Wav2Vec2ForPreTraining + feature_extractor: Wav2Vec2FeatureExtractor + padding: Union[bool, str] = "longest" + pad_to_multiple_of: Optional[int] = None + + def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]: + # reformat list to dict and set to pytorch format + batch = self.feature_extractor.pad( + features, + padding=self.padding, + pad_to_multiple_of=self.pad_to_multiple_of, + return_tensors="pt", + ) + + device = batch["input_values"].device + batch_size = batch["input_values"].shape[0] + + mask_indices_seq_length = self.model._get_feat_extract_output_lengths(batch["input_values"].shape[-1]) + + # make sure that no loss is computed on padded inputs + if batch.get("attention_mask") is not None: + # compute real output lengths according to convolution formula + batch["sub_attention_mask"] = self.model._get_feature_vector_attention_mask( + mask_indices_seq_length, batch["attention_mask"] + ) + + features_shape = (batch_size, mask_indices_seq_length) + + # sample randomly masked indices + mask_time_indices = _compute_mask_indices( + features_shape, + self.model.config.mask_time_prob, + self.model.config.mask_time_length, + attention_mask=batch.get("sub_attention_mask"), + ) + + # sample negative indices + sampled_negative_indices = _sample_negative_indices( + features_shape, + self.model.config.num_negatives, + mask_time_indices=mask_time_indices, + ) + batch["mask_time_indices"] = torch.tensor(mask_time_indices, dtype=torch.long, device=device) + batch["sampled_negative_indices"] = torch.tensor(sampled_negative_indices, dtype=torch.long, device=device) + + return batch + + +def multiply_grads(params, c): + """Multiplies grads by a constant *c*.""" + for p in params: + if p.grad is not None: + if torch.is_tensor(c): + c = c.to(p.grad.device) + p.grad.data.mul_(c) + + +def get_grad_norm(params, scale=1): + """Compute grad norm given a gradient scale.""" + total_norm = 0.0 + for p in params: + if p.grad is not None: + param_norm = (p.grad.detach().data / scale).norm(2) + total_norm += param_norm.item() ** 2 + total_norm = total_norm ** 0.5 + return total_norm + + +def main(): + # See all possible arguments in src/transformers/args.py + # or by passing the --help flag to this script. + # We now keep distinct sets of args, for a cleaner separation of concerns. + args = parse_args() + + # Initialize the accelerator. We will let the accelerator handle device placement for us in this example. + accelerator = Accelerator() + logger.info(accelerator.state) + + # Setup logging, we only want one process per machine to log things on the screen. + # accelerator.is_local_main_process is only True for one process per machine. + logger.setLevel(logging.INFO if accelerator.is_local_main_process else logging.ERROR) + if accelerator.is_local_main_process: + datasets.utils.logging.set_verbosity_warning() + transformers.utils.logging.set_verbosity_info() + + # set up weights and biases if available + if is_wandb_available(): + import wandb + + wandb.init(project=args.output_dir.split("/")[-1]) + else: + datasets.utils.logging.set_verbosity_error() + transformers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. 
+ if args.seed is not None: + set_seed(args.seed) + + # Handle the repository creation + if accelerator.is_main_process: + if args.push_to_hub and not args.preprocessing_only: + if args.hub_model_id is None: + repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) + else: + repo_name = args.hub_model_id + repo = Repository(args.output_dir, clone_from=repo_name) + elif args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + accelerator.wait_for_everyone() + + # 1. Download and create train, validation dataset + # We load all dataset configuration and datset split pairs passed in + # ``args.dataset_config_names`` and ``args.dataset_split_names`` + datasets_splits = [] + for dataset_config_name, train_split_name in zip(args.dataset_config_names, args.dataset_split_names): + # load dataset + dataset_split = load_dataset( + args.dataset_name, dataset_config_name, split=train_split_name, cache_dir=args.cache_dir + ) + datasets_splits.append(dataset_split) + + # Next, we concatenate all configurations and splits into a single training dataset + raw_datasets = DatasetDict() + if len(datasets_splits) > 1: + raw_datasets["train"] = concatenate_datasets(datasets_splits).shuffle(seed=args.seed) + else: + raw_datasets["train"] = datasets_splits[0] + + # Take ``args.validation_split_percentage`` from the training dataset for the validation_split_percentage + num_validation_samples = raw_datasets["train"].num_rows * args.validation_split_percentage // 100 + + if num_validation_samples == 0: + raise ValueError( + "`args.validation_split_percentage` is less than a single sample " + f"for {len(raw_datasets['train'])} training samples. Increase " + "`args.num_validation_split_percentage`. " + ) + + raw_datasets["validation"] = raw_datasets["train"].select(range(num_validation_samples)) + raw_datasets["train"] = raw_datasets["train"].select(range(num_validation_samples, raw_datasets["train"].num_rows)) + + # 2. Preprocess audio: load, resample, normalize and truncate + feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(args.model_name_or_path) + + # only normalized-inputs-training is supported + if not feature_extractor.do_normalize: + raise ValueError( + "Training is only supported for normalized inputs. 
" "Make sure ``feature_extractor.do_normalize == True``" + ) + + # set max & min audio length in number of samples + max_length = int(args.max_duration_in_seconds * feature_extractor.sampling_rate) + min_length = int(args.min_duration_in_seconds * feature_extractor.sampling_rate) + + resampler = None + if raw_datasets["train"][args.audio_column_name][0].split(".")[-1] == "mp3": + # TODO(PVP) - remove hard-coded 48_000 after audio feature is merged + resampler = torchaudio.transforms.Resample(48_000, feature_extractor.sampling_rate) + + def prepare_dataset(batch): + speech_array, sampling_rate = torchaudio.load(batch[args.audio_column_name]) + speech_array = speech_array.squeeze() + + # if necessary resample audio + if resampler is not None: + # TODO(PVP) - remove hard-coded 48_000 after audio feature is merged + speech_array = resampler(speech_array) + sampling_rate = resampler.new_freq + + speech_array = speech_array.numpy() + inputs = feature_extractor(speech_array, sampling_rate=sampling_rate, max_length=max_length, truncation=True) + batch["input_values"] = inputs.input_values[0] + return batch + + # load audio files into numpy arrays + with accelerator.main_process_first(): + vectorized_datasets = raw_datasets.map( + prepare_dataset, + num_proc=args.preprocessing_num_workers, + remove_columns=raw_datasets["train"].column_names, + load_from_cache_file=not args.overwrite_cache, + ) + vectorized_datasets = vectorized_datasets.filter( + lambda x: len(x["input_values"]) > min_length, load_from_cache_file=not args.overwrite_cache + ) + + # for large datasets it is advised to run the preprocessing on a + # single machine first with ``args.preprocessing_only`` since there will mostly likely + # be a timeout when running the script in distributed mode. + # In a second step ``args.preprocessing_only`` can then be set to `False` to load the + # cached dataset + if args.preprocessing_only: + return + + # 3. Load model + config = Wav2Vec2Config.from_pretrained(args.model_name_or_path) + + # pretraining is only supported for "newer" stable layer norm architecture + # apply_spec_augment has to be True, mask_feature_prob has to be 0.0 + if not config.do_stable_layer_norm or config.feat_extract_norm != "layer": + raise ValueError( + "PreTraining is only supported for ``config.do_stable_layer_norm=True`` and ``config.feat_extract_norm='layer'" + ) + + # initialize random model + model = Wav2Vec2ForPreTraining(config) + + # Activate gradient checkpointing if needed + if args.gradient_checkpointing: + model.gradient_checkpointing_enable() + + # 4. Define data collator, optimizer and scheduler + data_collator = DataCollatorForWav2Vec2Pretraining( + model=model, feature_extractor=feature_extractor, pad_to_multiple_of=args.pad_to_multiple_of + ) + train_dataloader = DataLoader( + vectorized_datasets["train"], + shuffle=True, + collate_fn=data_collator, + batch_size=args.per_device_train_batch_size, + ) + eval_dataloader = DataLoader( + vectorized_datasets["validation"], collate_fn=data_collator, batch_size=args.per_device_eval_batch_size + ) + + # Optimizer + optimizer = AdamW( + list(model.parameters()), + lr=args.learning_rate, + betas=[args.adam_beta1, args.adam_beta2], + eps=args.adam_epsilon, + ) + + # Prepare everything with our `accelerator`. + model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare( + model, optimizer, train_dataloader, eval_dataloader + ) + + # Scheduler and math around the number of training steps. 
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + else: + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + lr_scheduler = get_scheduler( + name=args.lr_scheduler_type, + optimizer=optimizer, + num_warmup_steps=args.num_warmup_steps, + num_training_steps=args.max_train_steps, + ) + + # 5. Train + total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(vectorized_datasets['train'])}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + completed_steps = 0 + + # Only show the progress bar once on each machine. + progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) + completed_steps = 0 + for epoch in range(args.num_train_epochs): + model.train() + for step, batch in enumerate(train_dataloader): + # compute num of losses + num_losses = batch["mask_time_indices"].sum() + sub_attention_mask = batch.pop("sub_attention_mask", None) + sub_attention_mask = ( + sub_attention_mask if sub_attention_mask is not None else torch.ones_like(batch["mask_time_indices"]) + ) + percent_masked = num_losses / sub_attention_mask.sum() + + # forward + outputs = model(**batch) + + # divide loss by gradient accumulation steps since gradients + # are accumulated for multiple backward passes in PyTorch + loss = outputs.loss / args.gradient_accumulation_steps + accelerator.backward(loss) + + # make sure that `num_losses` is summed for distributed training + # and average gradients over losses of all devices + if accelerator.state.num_processes > 1: + num_losses = accelerator.gather(num_losses).sum() + gradient_multiplier = accelerator.state.num_processes / num_losses + multiply_grads(model.module.parameters(), gradient_multiplier) + else: + multiply_grads(model.parameters(), 1 / num_losses) + + # update step + if (step + 1) % args.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1: + + # compute grad norm for monitoring + scale = ( + accelerator.scaler._scale.item() + if hasattr(accelerator, "scaler") and accelerator.scaler is not None + else 1 + ) + if accelerator.state.num_processes > 1: + grad_norm = get_grad_norm(model.module.parameters(), scale) + else: + grad_norm = get_grad_norm(model.parameters(), scale) + + # update parameters + optimizer.step() + optimizer.zero_grad() + + if not accelerator.optimizer_step_was_skipped: + lr_scheduler.step() + elif accelerator.is_local_main_process: + progress_bar.write( + "Gradients have overflown - skipping update step... " f"Updating gradient scale to {scale}..." 
+ ) + + # update gumbel temperature + gumbel_temperature = max( + args.max_gumbel_temperature * args.gumbel_temperature_decay ** completed_steps, + args.min_gumbel_temperature, + ) + if hasattr(model, "module"): + model.module.set_gumbel_temperature(gumbel_temperature) + else: + model.set_gumbel_temperature(gumbel_temperature) + + progress_bar.update(1) + completed_steps += 1 + + # 6. Log all results + if (step + 1) % (args.gradient_accumulation_steps * args.logging_steps) == 0: + loss.detach() + outputs.contrastive_loss.detach() + outputs.diversity_loss.detach() + + if accelerator.state.num_processes > 1: + loss = accelerator.gather(loss).sum() + outputs.contrastive_loss = accelerator.gather(outputs.contrastive_loss).sum() + outputs.diversity_loss = accelerator.gather(outputs.diversity_loss).sum() + percent_masked = accelerator.gather(percent_masked).sum() + + train_logs = { + "loss": (loss * args.gradient_accumulation_steps) / num_losses, + "constrast_loss": outputs.contrastive_loss / num_losses, + "div_loss": outputs.diversity_loss / num_losses, + "%_mask_idx": percent_masked / accelerator.num_processes, + "ppl": outputs.codevector_perplexity, + "lr": torch.tensor(optimizer.param_groups[0]["lr"]), + "temp": torch.tensor(gumbel_temperature), + "grad_norm": torch.tensor(grad_norm), + } + log_str = "" + for k, v in train_logs.items(): + log_str += "| {}: {:.3e}".format(k, v.item()) + + if accelerator.is_local_main_process: + progress_bar.write(log_str) + if is_wandb_available(): + wandb.log(train_logs) + + # save model every `args.saving_steps` steps + if (step + 1) % (args.gradient_accumulation_steps * args.saving_steps) == 0: + if (args.push_to_hub and epoch < args.num_train_epochs - 1) or args.output_dir is not None: + accelerator.wait_for_everyone() + unwrapped_model = accelerator.unwrap_model(model) + unwrapped_model.save_pretrained(args.output_dir, save_function=accelerator.save) + + if (args.push_to_hub and epoch < args.num_train_epochs - 1) and accelerator.is_main_process: + repo.push_to_hub(commit_message=f"Training in progress step {completed_steps}", blocking=False) + + # if completed steps > `args.max_train_steps` stop + if completed_steps >= args.max_train_steps: + break + + # 7. Validate! 
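+        # Run a full pass over the validation set: accumulate the total, contrastive and
+        # diversity losses, sum them across processes and normalize by the number of
+        # masked time steps before logging (and, if available, reporting to wandb).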
+ model.eval() + + # init logs + val_logs = { + "val_loss": 0, + "val_contrastive_loss": 0, + "val_diversity_loss": 0, + "val_num_losses": 0, + } + for step, batch in enumerate(eval_dataloader): + with torch.no_grad(): + batch.pop("sub_attention_mask", None) + outputs = model(**batch) + + val_logs["val_loss"] += outputs.loss + val_logs["val_contrastive_loss"] += outputs.contrastive_loss + val_logs["val_diversity_loss"] += outputs.diversity_loss + val_logs["val_num_losses"] += batch["mask_time_indices"].sum() + + # sum over devices in multi-processing + if accelerator.num_processes > 1: + val_logs = {k: accelerator.gather(v).sum() for k, v in val_logs.items()} + + val_logs = {k: v / val_logs["val_num_losses"] for k, v in val_logs.items()} + + log_str = "" + for k, v in val_logs.items(): + log_str += "| {}: {:.3e}".format(k, v.item()) + + if accelerator.is_local_main_process: + progress_bar.write(log_str) + if is_wandb_available(): + wandb.log(val_logs) + + if args.output_dir is not None: + accelerator.wait_for_everyone() + unwrapped_model = accelerator.unwrap_model(model) + unwrapped_model.save_pretrained(args.output_dir, save_function=accelerator.save) + if accelerator.is_main_process: + if args.push_to_hub: + repo.push_to_hub(commit_message="End of training") + + +if __name__ == "__main__": + main() diff --git a/examples/pytorch/test_examples.py b/examples/pytorch/test_examples.py index 75c46dd8f441..0a9839915550 100644 --- a/examples/pytorch/test_examples.py +++ b/examples/pytorch/test_examples.py @@ -23,6 +23,7 @@ import torch +from transformers import Wav2Vec2ForPreTraining from transformers.file_utils import is_apex_available from transformers.testing_utils import TestCasePlus, get_gpu_count, slow, torch_device @@ -41,6 +42,7 @@ "image-classification", "speech-recognition", "audio-classification", + "speech-pretraining", ] ] sys.path.extend(SRC_DIRS) @@ -59,6 +61,7 @@ import run_summarization import run_swag import run_translation + import run_wav2vec2_pretraining_no_trainer logging.basicConfig(level=logging.DEBUG) @@ -447,3 +450,32 @@ def test_run_audio_classification(self): run_audio_classification.main() result = get_results(tmp_dir) self.assertLess(result["eval_loss"], result["train_loss"]) + + def test_run_wav2vec2_pretraining(self): + stream_handler = logging.StreamHandler(sys.stdout) + logger.addHandler(stream_handler) + + tmp_dir = self.get_auto_remove_tmp_dir() + testargs = f""" + run_wav2vec2_pretraining_no_trainer.py + --output_dir {tmp_dir} + --model_name_or_path hf-internal-testing/tiny-random-wav2vec2 + --dataset_name patrickvonplaten/librispeech_asr_dummy + --dataset_config_names clean + --dataset_split_names validation + --learning_rate 1e-4 + --per_device_train_batch_size 2 + --per_device_eval_batch_size 2 + --preprocessing_num_workers 16 + --max_train_steps 5 + --validation_split_percentage 5 + --seed 42 + """.split() + + if is_cuda_and_apex_available(): + testargs.append("--fp16") + + with patch.object(sys, "argv", testargs): + run_wav2vec2_pretraining_no_trainer.main() + model = Wav2Vec2ForPreTraining.from_pretrained(tmp_dir) + self.assertIsNotNone(model) diff --git a/src/transformers/models/hubert/modeling_hubert.py b/src/transformers/models/hubert/modeling_hubert.py index 00cfb896006f..5bc1fc2345d8 100755 --- a/src/transformers/models/hubert/modeling_hubert.py +++ b/src/transformers/models/hubert/modeling_hubert.py @@ -48,13 +48,13 @@ def _compute_mask_indices( shape: Tuple[int, int], mask_prob: float, mask_length: int, - device: torch.device, - attention_mask: 
Optional[torch.tensor] = None, + attention_mask: Optional[torch.LongTensor] = None, min_masks: int = 0, -) -> torch.tensor: +) -> np.ndarray: """ Computes random mask spans for a given shape. Used to implement `SpecAugment: A Simple Data Augmentation Method for - ASR `__. + ASR `__. Note that this method is not optimized to run on TPU and should be run + on CPU as part of the preprocessing during training. Args: shape: the the shape for which to compute masks. @@ -64,7 +64,6 @@ def _compute_mask_indices( however due to overlaps, the actual number will be smaller (unless no_overlap is True) mask_length: size of the mask min_masks: minimum number of masked spans - """ batch_size, sequence_length = shape @@ -76,42 +75,64 @@ def _compute_mask_indices( f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length} and `sequence_length`: {sequence_length}`" ) - # compute number of masked spans in batch - num_masked_spans = int(mask_prob * sequence_length / mask_length + torch.rand((1,)).item()) - num_masked_spans = max(num_masked_spans, min_masks) + epsilon = np.random.rand(1).item() + + def compute_num_masked_span(input_length): + """Given input length, compute how many spans should be masked""" + num_masked_span = int(mask_prob * input_length / mask_length + epsilon) + num_masked_span = max(num_masked_span, min_masks) - # make sure num masked indices <= sequence_length - if num_masked_spans * mask_length > sequence_length: - num_masked_spans = sequence_length // mask_length + # make sure num masked indices <= sequence_length + if num_masked_span * mask_length > sequence_length: + num_masked_span = sequence_length // mask_length + + return num_masked_span + + # compute number of masked spans in batch + input_lengths = ( + attention_mask.sum(-1).detach().tolist() + if attention_mask is not None + else [sequence_length for _ in range(batch_size)] + ) # SpecAugment mask to fill - spec_aug_mask = torch.zeros((batch_size, sequence_length), device=device, dtype=torch.bool) + spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=np.bool) + spec_aug_mask_idxs = [] + + max_num_masked_span = compute_num_masked_span(sequence_length) + + for input_length in input_lengths: + # compute num of masked spans for this input + num_masked_span = compute_num_masked_span(input_length) + # get random indices to mask + spec_aug_mask_idx = np.random.choice( + np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False + ) + + # pick first sampled index that will serve as a dummy index to pad vector + dummy_mask_idx = spec_aug_mask_idx[0] - # uniform distribution to sample from, make sure that offset samples are < sequence_length - uniform_dist = torch.ones((batch_size, sequence_length - (mask_length - 1)), device=device) + spec_aug_mask_idx = np.concatenate( + [spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx] + ) + spec_aug_mask_idxs.append(spec_aug_mask_idx) - # get random indices to mask - spec_aug_mask_idxs = torch.multinomial(uniform_dist, num_masked_spans) + spec_aug_mask_idxs = np.array(spec_aug_mask_idxs) # expand masked indices to masked spans - spec_aug_mask_idxs = ( - spec_aug_mask_idxs.unsqueeze(dim=-1) - .expand((batch_size, num_masked_spans, mask_length)) - .reshape(batch_size, num_masked_spans * mask_length) + spec_aug_mask_idxs = np.broadcast_to( + spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length) ) - offsets = ( - torch.arange(mask_length, device=device)[None, None, :] 
- .expand((batch_size, num_masked_spans, mask_length)) - .reshape(batch_size, num_masked_spans * mask_length) + spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length) + + offsets = np.arange(mask_length)[None, None, :] + offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape( + batch_size, max_num_masked_span * mask_length ) spec_aug_mask_idxs = spec_aug_mask_idxs + offsets # scatter indices to mask - spec_aug_mask = spec_aug_mask.scatter(1, spec_aug_mask_idxs, True) - - if attention_mask is not None: - # make sure padded input ids cannot be masked - spec_aug_mask = torch.where(attention_mask.bool(), spec_aug_mask, False) + np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1) return spec_aug_mask @@ -257,6 +278,7 @@ def __init__(self, config): f"`config.feat_extract_norm` is {config.feat_extract_norm}, but has to be one of ['group', 'layer']" ) self.conv_layers = nn.ModuleList(conv_layers) + self.gradient_checkpointing = False def _freeze_parameters(self): for param in self.parameters(): @@ -264,8 +286,26 @@ def _freeze_parameters(self): def forward(self, input_values): hidden_states = input_values[:, None] + + # make sure hidden_states require grad for gradient_checkpointing + if self.training: + hidden_states.requires_grad = True + for conv_layer in self.conv_layers: - hidden_states = conv_layer(hidden_states) + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(conv_layer), + hidden_states, + ) + else: + hidden_states = conv_layer(hidden_states) return hidden_states @@ -864,10 +904,10 @@ def _mask_hidden_states( (batch_size, sequence_length), mask_prob=self.config.mask_time_prob, mask_length=self.config.mask_time_length, - device=hidden_states.device, attention_mask=attention_mask, min_masks=2, ) + mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.long) hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype) if self.config.mask_feature_prob > 0 and self.training: @@ -876,9 +916,11 @@ def _mask_hidden_states( (batch_size, hidden_size), mask_prob=self.config.mask_feature_prob, mask_length=self.config.mask_feature_length, - device=hidden_states.device, ) - hidden_states[mask_feature_indices[:, None].expand(-1, sequence_length, -1)] = 0 + mask_feature_indices = torch.tensor(mask_feature_indices, device=hidden_states.device, dtype=torch.long)[ + :, None + ].expand(-1, sequence_length, -1) + hidden_states[mask_feature_indices] = 0 return hidden_states diff --git a/src/transformers/models/wav2vec2/modeling_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_wav2vec2.py index 6b4a2522823e..7a566654ffb5 100755 --- a/src/transformers/models/wav2vec2/modeling_wav2vec2.py +++ b/src/transformers/models/wav2vec2/modeling_wav2vec2.py @@ -14,6 +14,7 @@ # limitations under the License. """ PyTorch Wav2Vec2 model. """ +import math import warnings from dataclasses import dataclass from typing import Optional, Tuple, Union @@ -87,7 +88,7 @@ class Wav2Vec2ForPreTrainingOutput(ModelOutput): Output type of :class:`~transformers.Wav2Vec2ForPreTrainingOutput`, with potential hidden states and attentions. 
Args: - loss (`optional`, returned when model is in train mode, ``torch.FloatTensor`` of shape :obj:`(1,)`): + loss (`optional`, returned when :obj:`sample_negative_indices` are passed, ``torch.FloatTensor`` of shape :obj:`(1,)`): Total loss as the sum of the contrastive loss (L_m) and the diversity loss (L_d) as stated in the `official paper `__ . (classification) loss. projected_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.proj_codevector_dim)`): @@ -107,6 +108,10 @@ class Wav2Vec2ForPreTrainingOutput(ModelOutput): Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. + contrastive_loss (`optional`, returned when :obj:`sample_negative_indices` are passed, ``torch.FloatTensor`` of shape :obj:`(1,)`): + The contrastive loss (L_m) as stated in the `official paper `__ . + diversity_loss (`optional`, returned when :obj:`sample_negative_indices` are passed, ``torch.FloatTensor`` of shape :obj:`(1,)`): + The diversity loss (L_d) as stated in the `official paper `__ . """ loss: Optional[torch.FloatTensor] = None @@ -115,19 +120,21 @@ class Wav2Vec2ForPreTrainingOutput(ModelOutput): codevector_perplexity: torch.FloatTensor = None hidden_states: Optional[Tuple[torch.FloatTensor]] = None attentions: Optional[Tuple[torch.FloatTensor]] = None + contrastive_loss: Optional[torch.FloatTensor] = None + diversity_loss: Optional[torch.FloatTensor] = None def _compute_mask_indices( shape: Tuple[int, int], mask_prob: float, mask_length: int, - device: torch.device, - attention_mask: Optional[torch.tensor] = None, + attention_mask: Optional[torch.LongTensor] = None, min_masks: int = 0, -) -> torch.tensor: +) -> np.ndarray: """ Computes random mask spans for a given shape. Used to implement `SpecAugment: A Simple Data Augmentation Method for - ASR `__. + ASR `__. Note that this method is not optimized to run on TPU and should be run + on CPU as part of the preprocessing during training. Args: shape: the the shape for which to compute masks. 
@@ -137,7 +144,6 @@ def _compute_mask_indices( however due to overlaps, the actual number will be smaller (unless no_overlap is True) mask_length: size of the mask min_masks: minimum number of masked spans - """ batch_size, sequence_length = shape @@ -149,46 +155,104 @@ def _compute_mask_indices( f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length} and `sequence_length`: {sequence_length}`" ) - # compute number of masked spans in batch - num_masked_spans = int(mask_prob * sequence_length / mask_length + torch.rand((1,)).item()) - num_masked_spans = max(num_masked_spans, min_masks) + epsilon = np.random.rand(1).item() + + def compute_num_masked_span(input_length): + """Given input length, compute how many spans should be masked""" + num_masked_span = int(mask_prob * input_length / mask_length + epsilon) + num_masked_span = max(num_masked_span, min_masks) - # make sure num masked indices <= sequence_length - if num_masked_spans * mask_length > sequence_length: - num_masked_spans = sequence_length // mask_length + # make sure num masked indices <= sequence_length + if num_masked_span * mask_length > sequence_length: + num_masked_span = sequence_length // mask_length + + return num_masked_span + + # compute number of masked spans in batch + input_lengths = ( + attention_mask.sum(-1).detach().tolist() + if attention_mask is not None + else [sequence_length for _ in range(batch_size)] + ) # SpecAugment mask to fill - spec_aug_mask = torch.zeros((batch_size, sequence_length), device=device, dtype=torch.bool) + spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=np.bool) + spec_aug_mask_idxs = [] + + max_num_masked_span = compute_num_masked_span(sequence_length) - # uniform distribution to sample from, make sure that offset samples are < sequence_length - uniform_dist = torch.ones((batch_size, sequence_length - (mask_length - 1)), device=device) + for input_length in input_lengths: + # compute num of masked spans for this input + num_masked_span = compute_num_masked_span(input_length) + # get random indices to mask + spec_aug_mask_idx = np.random.choice( + np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False + ) - # get random indices to mask - spec_aug_mask_idxs = torch.multinomial(uniform_dist, num_masked_spans) + # pick first sampled index that will serve as a dummy index to pad vector + dummy_mask_idx = spec_aug_mask_idx[0] + + spec_aug_mask_idx = np.concatenate( + [spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx] + ) + spec_aug_mask_idxs.append(spec_aug_mask_idx) + + spec_aug_mask_idxs = np.array(spec_aug_mask_idxs) # expand masked indices to masked spans - spec_aug_mask_idxs = ( - spec_aug_mask_idxs.unsqueeze(dim=-1) - .expand((batch_size, num_masked_spans, mask_length)) - .reshape(batch_size, num_masked_spans * mask_length) + spec_aug_mask_idxs = np.broadcast_to( + spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length) ) - offsets = ( - torch.arange(mask_length, device=device)[None, None, :] - .expand((batch_size, num_masked_spans, mask_length)) - .reshape(batch_size, num_masked_spans * mask_length) + spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length) + + offsets = np.arange(mask_length)[None, None, :] + offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape( + batch_size, max_num_masked_span * mask_length ) spec_aug_mask_idxs = spec_aug_mask_idxs + offsets # 
scatter indices to mask - spec_aug_mask = spec_aug_mask.scatter(1, spec_aug_mask_idxs, True) - - if attention_mask is not None: - # make sure padded input ids cannot be masked - spec_aug_mask = torch.where(attention_mask.bool(), spec_aug_mask, False) + np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1) return spec_aug_mask +def _sample_negative_indices( + features_shape: Tuple, num_negatives: int, mask_time_indices: Optional[np.ndarray] = None +): + """ + Sample `num_negatives` vectors from feature vectors. + """ + batch_size, sequence_length = features_shape + + # generate indices of the positive vectors themselves, repeat them `num_negatives` times + sequence_length_range = np.arange(sequence_length) + + # get `num_negatives` random vector indices from the same utterance + sampled_negative_indices = np.zeros(shape=(batch_size, sequence_length, num_negatives), dtype=np.int32) + + mask_time_indices = ( + mask_time_indices.astype(np.bool) if mask_time_indices is not None else np.ones(features_shape, dtype=np.bool) + ) + + for batch_idx in range(batch_size): + high = mask_time_indices[batch_idx].sum() - 1 + mapped_masked_indices = sequence_length_range[mask_time_indices[batch_idx]] + + feature_indices = np.broadcast_to(np.arange(high + 1)[:, None], (high + 1, num_negatives)) + sampled_indices = np.random.randint(0, high, size=(high + 1, num_negatives)) + # avoid sampling the same positive vector, but keep the distribution uniform + sampled_indices[sampled_indices >= feature_indices] += 1 + + # remap to actual indices + sampled_negative_indices[batch_idx][mask_time_indices[batch_idx]] = mapped_masked_indices[sampled_indices] + + # correct for batch size + sampled_negative_indices[batch_idx] += batch_idx * sequence_length + + return sampled_negative_indices + + class Wav2Vec2NoLayerNormConvLayer(nn.Module): def __init__(self, config, layer_id=0): super().__init__() @@ -326,6 +390,7 @@ def __init__(self, config): f"`config.feat_extract_norm` is {config.feat_extract_norm}, but has to be one of ['group', 'layer']" ) self.conv_layers = nn.ModuleList(conv_layers) + self.gradient_checkpointing = False def _freeze_parameters(self): for param in self.parameters(): @@ -333,8 +398,26 @@ def _freeze_parameters(self): def forward(self, input_values): hidden_states = input_values[:, None] + + # make sure hidden_states require grad for gradient_checkpointing + if self.training: + hidden_states.requires_grad = True + for conv_layer in self.conv_layers: - hidden_states = conv_layer(hidden_states) + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(conv_layer), + hidden_states, + ) + else: + hidden_states = conv_layer(hidden_states) return hidden_states @@ -778,7 +861,7 @@ def __init__(self, config): self.weight_proj = nn.Linear(config.conv_dim[-1], self.num_groups * self.num_vars) # can be decayed for training - self.temperature = 1 + self.temperature = 2 def set_temperature(self, temperature: int): self.temperature = temperature @@ -844,8 +927,8 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel): config_class = Wav2Vec2Config base_model_prefix = "wav2vec2" - supports_gradient_checkpointing = True _keys_to_ignore_on_load_missing = [r"position_ids"] + supports_gradient_checkpointing = True def _init_weights(self, module): """Initialize the weights""" @@ -854,22 +937,31 @@ def _init_weights(self, module): 
module.weight_proj.weight.data.normal_(mean=0.0, std=1) module.weight_proj.bias.data.zero_() nn.init.uniform_(module.codevectors) + elif isinstance(module, Wav2Vec2PositionalConvEmbedding): + nn.init.normal_( + module.conv.weight, + mean=0, + std=2 * math.sqrt(1 / (module.conv.kernel_size[0] * module.conv.in_channels)), + ) + nn.init.constant_(module.conv.bias, 0) + elif isinstance(module, Wav2Vec2FeatureProjection): + k = math.sqrt(1 / module.projection.in_features) + nn.init.uniform_(module.projection.weight, a=-k, b=k) + nn.init.uniform_(module.projection.bias, a=-k, b=k) elif isinstance(module, nn.Linear): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + + if module.bias is not None: + module.bias.data.zero_() elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)): module.bias.data.zero_() module.weight.data.fill_(1.0) elif isinstance(module, nn.Conv1d): - nn.init.kaiming_normal_(module.weight.data) + nn.init.kaiming_normal_(module.weight) - if isinstance(module, (nn.Linear, nn.Conv1d)) and module.bias is not None: - module.bias.data.zero_() - - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, (Wav2Vec2Encoder, Wav2Vec2EncoderStableLayerNorm)): - module.gradient_checkpointing = value + if module.bias is not None: + k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0])) + nn.init.uniform_(module.bias, a=-k, b=k) def _get_feat_extract_output_lengths(self, input_lengths: Union[torch.LongTensor, int]): """ @@ -898,6 +990,10 @@ def _get_feature_vector_attention_mask(self, feature_vector_length: int, attenti attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool() return attention_mask + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (Wav2Vec2Encoder, Wav2Vec2EncoderStableLayerNorm, Wav2Vec2FeatureExtractor)): + module.gradient_checkpointing = value + WAV_2_VEC_2_START_DOCSTRING = r""" Wav2Vec2 was proposed in `wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations @@ -1001,10 +1097,10 @@ def _mask_hidden_states( (batch_size, sequence_length), mask_prob=self.config.mask_time_prob, mask_length=self.config.mask_time_length, - device=hidden_states.device, attention_mask=attention_mask, min_masks=2, ) + mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.long) hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype) if self.config.mask_feature_prob > 0 and self.training: @@ -1013,9 +1109,11 @@ def _mask_hidden_states( (batch_size, hidden_size), mask_prob=self.config.mask_feature_prob, mask_length=self.config.mask_feature_length, - device=hidden_states.device, ) - hidden_states[mask_feature_indices[:, None].expand(-1, sequence_length, -1)] = 0 + mask_feature_indices = torch.tensor(mask_feature_indices, device=hidden_states.device, dtype=torch.long)[ + :, None + ].expand(-1, sequence_length, -1) + hidden_states[mask_feature_indices] = 0 return hidden_states @@ -1101,11 +1199,13 @@ def __init__(self, config: Wav2Vec2Config): self.dropout_features = nn.Dropout(config.feat_quantizer_dropout) self.quantizer = Wav2Vec2GumbelVectorQuantizer(config) - self.project_q = nn.Linear(config.codevector_dim, config.proj_codevector_dim) - self.project_hid = nn.Linear(config.hidden_size, config.proj_codevector_dim) self.init_weights() + # make sure that 
project_hid & project_q are initialized like normal linear layers + self.project_hid = nn.Linear(config.hidden_size, config.proj_codevector_dim) + self.project_q = nn.Linear(config.codevector_dim, config.proj_codevector_dim) + def set_gumbel_temperature(self, temperature: int): """ Set the Gumbel softmax temperature to a given value. Only necessary for training @@ -1119,61 +1219,12 @@ def freeze_feature_extractor(self): """ self.wav2vec2.feature_extractor._freeze_parameters() - @staticmethod - def _sample_negatives( - features: torch.FloatTensor, num_negatives: int, attention_mask: Optional[torch.LongTensor] = None - ): - """ - Sample `num_negatives` vectors from feature vectors. - """ - batch_size, sequence_length, hidden_size = features.shape - if sequence_length <= 1: - raise ValueError( - f"`features should have `sequence_length` > 1, but are of shape (batch_size, sequence_length, hidden_size) = ({batch_size, sequence_length, hidden_size})." - ) - - features = features.view(-1, hidden_size) # BTC => (BxT)C - - with torch.no_grad(): - # get `num_negatives` random vector indices from the same utterance - sampled_negative_indices = [] - for batch_idx in range(batch_size): - high = attention_mask[batch_idx].sum() - 1 if attention_mask is not None else sequence_length - 1 - sampled_indices_slice = torch.randint( - 0, high, size=(num_negatives * sequence_length,), device=features.device - ) - sampled_negative_indices.append(sampled_indices_slice) - - sampled_negative_indices = torch.stack(sampled_negative_indices) - - # generate indices of the positive vectors themselves, repeat them `num_negatives` times - feature_indices = ( - torch.arange(sequence_length, device=features.device)[:, None] - .expand(sequence_length, num_negatives) - .flatten() - ) - - # avoid sampling the same positive vector, but keep the distribution uniform - sampled_negative_indices[sampled_negative_indices >= feature_indices] += 1 - - # correct for batch size - for batch_idx in range(1, batch_size): - sampled_negative_indices[batch_idx] += batch_idx * sequence_length - - # take negative vectors from sampled indices - sampled_negatives = features[sampled_negative_indices.view(-1)] - sampled_negatives = sampled_negatives.view(batch_size, sequence_length, num_negatives, hidden_size).permute( - 2, 0, 1, 3 - ) - - return sampled_negatives - @staticmethod def compute_contrastive_logits( target_features: torch.FloatTensor, negative_features: torch.FloatTensor, predicted_features: torch.FloatTensor, - temperature: int = 1, + temperature: int = 0.1, ): """ Compute logits for contrastive loss based using cosine similarity as the distance measure between @@ -1196,6 +1247,7 @@ def forward( input_values, attention_mask=None, mask_time_indices=None, + sampled_negative_indices=None, output_attentions=None, output_hidden_states=None, return_dict=None, @@ -1204,6 +1256,9 @@ def forward( mask_time_indices (:obj:`torch.BoolTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): Indices to mask extracted features for contrastive loss. When in training mode, model learns to predict masked extracted features in `config.proj_codevector_dim` space. + sampled_negative_indices (:obj:`torch.BoolTensor` of shape :obj:`(batch_size, sequence_length, num_negatives)`, `optional`): + Indices indicating which quantized target vectors are used as negative sampled vectors in contrastive loss. + Required input for pre-training. Returns: @@ -1270,21 +1325,30 @@ def forward( # 2. 
quantize all (unmasked) extracted features and project to final vq dim extract_features = self.dropout_features(outputs[1]) - quantized_features, codevector_perplexity = self.quantizer(extract_features, mask_time_indices) + + if attention_mask is not None: + # compute reduced attention_mask correponding to feature vectors + attention_mask = self._get_feature_vector_attention_mask(extract_features.shape[1], attention_mask) + + quantized_features, codevector_perplexity = self.quantizer( + extract_features, mask_time_indices=mask_time_indices + ) quantized_features = self.project_q(quantized_features) - loss = None - if self.training: + loss = contrastive_loss = diversity_loss = None + if sampled_negative_indices is not None: + batch_size, sequence_length, hidden_size = quantized_features.shape + # for training, we sample negatives # 3. sample K negatives (distractors) quantized states for contrastive loss # if attention_mask is passed, make sure that padded feature vectors cannot be sampled - if attention_mask is not None: - # compute reduced attention_mask correponding to feature vectors - attention_mask = self._get_feature_vector_attention_mask(extract_features.shape[1], attention_mask) - - negative_quantized_features = self._sample_negatives( - quantized_features, self.config.num_negatives, attention_mask=attention_mask - ) + # sample negative quantized vectors BTC => (BxT)C + negative_quantized_features = quantized_features.view(-1, hidden_size)[ + sampled_negative_indices.long().view(-1) + ] + negative_quantized_features = negative_quantized_features.view( + batch_size, sequence_length, -1, hidden_size + ).permute(2, 0, 1, 3) # 4. compute logits, corresponding to `logs = sim(c_t, [q_t, \sim{q}_t]) / \kappa` # of equation (3) in https://arxiv.org/pdf/2006.11477.pdf @@ -1298,18 +1362,19 @@ def forward( # 5. if a negative vector is identical to the positive (i.e. when codebook utilization is low), # its cosine similarity will be masked neg_is_pos = (quantized_features == negative_quantized_features).all(-1) + if neg_is_pos.any(): logits[1:][neg_is_pos] = float("-inf") # 6. compute contrastive loss \mathbf{L}_m = cross_entropy(logs) = # -log(exp(sim(c_t, q_t)/\kappa) / \sum_{\sim{q}} exp(sim(c_t, \sim{q})/\kappa)) - preds = logits.transpose(0, 2).reshape(-1, logits.size(0)) + logits = logits.transpose(0, 2).reshape(-1, logits.size(0)) target = ((1 - mask_time_indices.long()) * -100).transpose(0, 1).flatten() - contrastive_loss = nn.functional.cross_entropy(preds.float(), target, reduction="sum") + contrastive_loss = nn.functional.cross_entropy(logits.float(), target, reduction="sum") # 7. compute diversity loss: \mathbf{L}_d num_codevectors = self.config.num_codevectors_per_group * self.config.num_codevector_groups - diversity_loss = (num_codevectors - codevector_perplexity) / num_codevectors + diversity_loss = ((num_codevectors - codevector_perplexity) / num_codevectors) * mask_time_indices.sum() # 8. 
@@ -1326,6 +1391,8 @@ def forward(
             codevector_perplexity=codevector_perplexity,
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
+            contrastive_loss=contrastive_loss,
+            diversity_loss=diversity_loss,
         )


diff --git a/tests/test_modeling_hubert.py b/tests/test_modeling_hubert.py
index 6faed43ed61c..38e47103c491 100644
--- a/tests/test_modeling_hubert.py
+++ b/tests/test_modeling_hubert.py
@@ -586,7 +586,8 @@ def test_compute_mask_indices(self):
         mask_prob = 0.5
         mask_length = 1

-        mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length, torch_device)
+        mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length)
+        mask = torch.from_numpy(mask).to(torch_device)

         self.assertListEqual(mask.sum(axis=-1).tolist(), [mask_prob * sequence_length for _ in range(batch_size)])

@@ -596,7 +597,8 @@ def test_compute_mask_indices_overlap(self):
         mask_prob = 0.5
         mask_length = 4

-        mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length, torch_device)
+        mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length)
+        mask = torch.from_numpy(mask).to(torch_device)

         # because of overlap mask don't have to add up exactly to `mask_prob * sequence_length`, but have to be smaller or equal
         for batch_sum in mask.sum(axis=-1):
diff --git a/tests/test_modeling_wav2vec2.py b/tests/test_modeling_wav2vec2.py
index 9c0dc7ee9e0c..2a62ea97f13b 100644
--- a/tests/test_modeling_wav2vec2.py
+++ b/tests/test_modeling_wav2vec2.py
@@ -40,7 +40,11 @@
         Wav2Vec2Model,
         Wav2Vec2Processor,
     )
-    from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2GumbelVectorQuantizer, _compute_mask_indices
+    from transformers.models.wav2vec2.modeling_wav2vec2 import (
+        Wav2Vec2GumbelVectorQuantizer,
+        _compute_mask_indices,
+        _sample_negative_indices,
+    )


 class Wav2Vec2ModelTester:
@@ -405,6 +409,12 @@ def test_initialization(self):
                     "masked_spec_embed",
                     "codevectors",
                     "quantizer.weight_proj.weight",
+                    "project_hid.weight",
+                    "project_hid.bias",
+                    "project_q.weight",
+                    "project_q.bias",
+                    "feature_projection.projection.weight",
+                    "feature_projection.projection.bias",
                 ]
                 if param.requires_grad:
                     if any([x in name for x in uniform_init_parms]):
@@ -605,6 +615,12 @@ def test_initialization(self):
                     "masked_spec_embed",
                     "codevectors",
                     "quantizer.weight_proj.weight",
+                    "project_hid.weight",
+                    "project_hid.bias",
+                    "project_q.weight",
+                    "project_q.bias",
+                    "feature_projection.projection.weight",
+                    "feature_projection.projection.bias",
                 ]
                 if param.requires_grad:
                     if any([x in name for x in uniform_init_parms]):
@@ -640,28 +656,37 @@ def test_model_for_pretraining(self):

         features_shape = (
             inputs_dict["input_values"].shape[0],
-            model._get_feat_extract_output_lengths(torch.tensor(inputs_dict["input_values"].shape[1])),
+            model._get_feat_extract_output_lengths(inputs_dict["input_values"].shape[1]),
         )

         mask_time_indices = _compute_mask_indices(
             features_shape,
             model.config.mask_time_prob,
             model.config.mask_time_length,
-            device=inputs_dict["input_values"].device,
             min_masks=2,
-        ).to(torch_device)
+        )
+        sampled_negative_indices = _sample_negative_indices(features_shape, 10, mask_time_indices)
+
+        mask_time_indices = torch.from_numpy(mask_time_indices).to(torch_device)
+        sampled_negative_indices = torch.from_numpy(sampled_negative_indices).to(torch_device)

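Both test files now follow the same pattern: `_compute_mask_indices` is a pure numpy helper without a `device` argument, and the caller converts its result (and the output of `_sample_negative_indices`) with `torch.from_numpy(...).to(device)`. A tiny before/after sketch of that change, with purely illustrative shape and masking parameters:

```python
import torch
from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices

batch_size, sequence_length = 4, 60  # illustrative frame-level shape
device = "cuda" if torch.cuda.is_available() else "cpu"

# old pattern (removed in this diff): the helper built a torch tensor directly on `device`
# mask = _compute_mask_indices((batch_size, sequence_length), 0.5, 4, device)

# new pattern: numpy in, numpy out; conversion is the caller's job
mask = _compute_mask_indices((batch_size, sequence_length), mask_prob=0.5, mask_length=4, min_masks=2)
mask = torch.from_numpy(mask).to(device)
print(mask.shape, mask.dtype)  # torch.Size([4, 60]) torch.bool
```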
         loss = model(
             inputs_dict["input_values"],
             attention_mask=inputs_dict["attention_mask"],
             mask_time_indices=mask_time_indices,
+            sampled_negative_indices=sampled_negative_indices,
         ).loss

+        # more losses
         mask_time_indices[:, : mask_time_indices.shape[-1] // 2] = True
+
+        sampled_negative_indices = _sample_negative_indices(features_shape, 10, mask_time_indices.cpu().numpy())
+        sampled_negative_indices = torch.from_numpy(sampled_negative_indices).to(torch_device)
         loss_more_masked = model(
             inputs_dict["input_values"],
             attention_mask=inputs_dict["attention_mask"],
             mask_time_indices=mask_time_indices,
+            sampled_negative_indices=sampled_negative_indices,
         ).loss

         # loss_more_masked has to be bigger or equal loss since more masked inputs have to be predicted
@@ -727,7 +752,8 @@ def test_compute_mask_indices(self):
         mask_prob = 0.5
         mask_length = 1

-        mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length, torch_device)
+        mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length)
+        mask = torch.from_numpy(mask).to(torch_device)

         self.assertListEqual(mask.sum(axis=-1).tolist(), [mask_prob * sequence_length for _ in range(batch_size)])

@@ -737,7 +763,8 @@ def test_compute_mask_indices_overlap(self):
         mask_prob = 0.5
         mask_length = 4

-        mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length, torch_device)
+        mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length)
+        mask = torch.from_numpy(mask).to(torch_device)

         # because of overlap mask don't have to add up exactly to `mask_prob * sequence_length`, but have to be smaller or equal
         for batch_sum in mask.sum(axis=-1):
@@ -753,8 +780,9 @@ def test_compute_mask_indices_attn_mask_overlap(self):
         attention_mask[:2, sequence_length // 2 :] = 0

         mask = _compute_mask_indices(
-            (batch_size, sequence_length), mask_prob, mask_length, device=torch_device, attention_mask=attention_mask
+            (batch_size, sequence_length), mask_prob, mask_length, attention_mask=attention_mask
         )
+        mask = torch.from_numpy(mask).to(torch_device)

         for batch_sum in mask.sum(axis=-1):
             self.assertTrue(int(batch_sum) <= mask_prob * sequence_length)
@@ -785,8 +813,11 @@ def test_sample_negatives(self):
         )  # each value in vector consits of same value
         features = features[None, :].expand(batch_size, sequence_length, hidden_size).contiguous()

-        negatives = Wav2Vec2ForPreTraining._sample_negatives(features, num_negatives)
-
+        # sample negative indices
+        sampled_negative_indices = _sample_negative_indices((batch_size, sequence_length), num_negatives, None)
+        sampled_negative_indices = torch.from_numpy(sampled_negative_indices).to(torch_device)
+        negatives = features.view(-1, hidden_size)[sampled_negative_indices.long().view(-1)]
+        negatives = negatives.view(batch_size, sequence_length, -1, hidden_size).permute(2, 0, 1, 3)
         self.assertTrue(negatives.shape == (num_negatives, batch_size, sequence_length, hidden_size))

         # make sure no negatively sampled vector is actually a positive one
@@ -796,15 +827,15 @@ def test_sample_negatives(self):
         # make sure that full vectors are sampled and not values of vectors => this means that `unique()` yields a single value for `hidden_size` dim
         self.assertTrue(negatives.unique(dim=-1).shape, (num_negatives, batch_size, sequence_length, 1))

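The renamed `test_sample_negatives_with_mask` that follows exercises the padded case: when a frame-level mask is passed as the third argument of `_sample_negative_indices`, distractors are only drawn from non-padded positions. A small sketch of that property, mirroring the test below with made-up sizes and toy features:

```python
import numpy as np
import torch
from transformers.models.wav2vec2.modeling_wav2vec2 import _sample_negative_indices

batch_size, sequence_length, hidden_size, num_negatives = 2, 10, 4, 3

# frame-level mask: the second half of the last example is padding
mask = np.ones((batch_size, sequence_length), dtype=np.int64)
mask[-1, sequence_length // 2:] = 0

# toy features; padded frames are filled with -100 so we can detect if they ever get sampled
features = (torch.arange(batch_size * sequence_length * hidden_size) // hidden_size).view(
    batch_size, sequence_length, hidden_size
).float()
features = features.masked_fill(~torch.from_numpy(mask).bool()[:, :, None], -100.0)

sampled_negative_indices = _sample_negative_indices((batch_size, sequence_length), num_negatives, mask)
sampled_negative_indices = torch.from_numpy(sampled_negative_indices)

# gather exactly as the model does: flatten over (batch, time), then reshape to (K, B, T, C)
negatives = features.view(-1, hidden_size)[sampled_negative_indices.long().view(-1)]
negatives = negatives.view(batch_size, sequence_length, -1, hidden_size).permute(2, 0, 1, 3)

assert (negatives >= 0).all()  # no distractor comes from a padded frame
```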
-    def test_sample_negatives_with_attn_mask(self):
+    def test_sample_negatives_with_mask(self):
         batch_size = 2
         sequence_length = 10
         hidden_size = 4
         num_negatives = 3

         # second half of last input tensor is padded
-        attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long, device=torch_device)
-        attention_mask[-1, sequence_length // 2 :] = 0
+        mask = torch.ones((batch_size, sequence_length), dtype=torch.long, device=torch_device)
+        mask[-1, sequence_length // 2 :] = 0

         features = (torch.arange(sequence_length * hidden_size, device=torch_device) // hidden_size).view(
             sequence_length, hidden_size
@@ -812,9 +843,15 @@ def test_sample_negatives_with_attn_mask(self):
         features = features[None, :].expand(batch_size, sequence_length, hidden_size).contiguous()

         # replace masked feature vectors with -100 to test that those are not sampled
-        features = torch.where(attention_mask[:, :, None].expand(features.shape).bool(), features, -100)
+        features = torch.where(mask[:, :, None].expand(features.shape).bool(), features, -100)

-        negatives = Wav2Vec2ForPreTraining._sample_negatives(features, num_negatives, attention_mask=attention_mask)
+        # sample negative indices
+        sampled_negative_indices = _sample_negative_indices(
+            (batch_size, sequence_length), num_negatives, mask.cpu().numpy()
+        )
+        sampled_negative_indices = torch.from_numpy(sampled_negative_indices).to(torch_device)
+        negatives = features.view(-1, hidden_size)[sampled_negative_indices.long().view(-1)]
+        negatives = negatives.view(batch_size, sequence_length, -1, hidden_size).permute(2, 0, 1, 3)

         self.assertTrue((negatives >= 0).all().item())

@@ -924,16 +961,11 @@ def test_inference_ctc_robust_batched(self):
         ]
         self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)

-    # Wav2Vec2 pretraining seems to be broken. TODO(PVP) - reenable test once pretraining works
-    # correctly
+    @unittest.skipIf(torch_device != "cpu", "cannot make deterministic on GPU")
     def test_inference_integration(self):
-        return
-
         model = Wav2Vec2ForPreTraining.from_pretrained("facebook/wav2vec2-base")
         model.to(torch_device)
-        feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
-            "facebook/wav2vec2-base", return_attention_mask=True
-        )
+        feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-base")
         input_speech = self._load_datasamples(2)

         inputs_dict = feature_extractor(input_speech, return_tensors="pt", padding=True)

@@ -943,19 +975,18 @@ def test_inference_integration(self):
             model._get_feat_extract_output_lengths(torch.tensor(inputs_dict["input_values"].shape[1])),
         )

-        torch.manual_seed(0)
+        np.random.seed(4)
         mask_time_indices = _compute_mask_indices(
             features_shape,
             model.config.mask_time_prob,
             model.config.mask_time_length,
-            device=inputs_dict["input_values"].device,
             min_masks=2,
-        ).to(torch_device)
+        )
+        mask_time_indices = torch.from_numpy(mask_time_indices).to(torch_device)

         with torch.no_grad():
             outputs = model(
                 inputs_dict.input_values.to(torch_device),
-                attention_mask=inputs_dict.attention_mask.to(torch_device),
                 mask_time_indices=mask_time_indices,
             )

@@ -965,14 +996,16 @@ def test_inference_integration(self):
         # retrieve cosine sim of masked features
         cosine_sim_masked = cosine_sim[mask_time_indices]

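`cosine_sim` in the re-enabled integration test is, in the part of the test that falls outside these hunks, the cosine similarity between the transformer states and the quantized targets after both are projected to codevector space. A hedged, self-contained sketch of that check; the `projected_states` / `projected_quantized_states` output fields come from the modeling file rather than from this diff, and the random input stands in for the LibriSpeech samples used by the test:

```python
import numpy as np
import torch
from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2ForPreTraining
from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices

model = Wav2Vec2ForPreTraining.from_pretrained("facebook/wav2vec2-base")
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-base")

np.random.seed(4)
# one second of noise stands in for the two LibriSpeech samples used in the test
inputs = feature_extractor([np.random.randn(16000)], sampling_rate=16000, return_tensors="pt")

features_shape = (1, int(model._get_feat_extract_output_lengths(inputs.input_values.shape[1])))
mask_time_indices = _compute_mask_indices(
    features_shape, model.config.mask_time_prob, model.config.mask_time_length, min_masks=2
)
mask_time_indices = torch.from_numpy(mask_time_indices)

with torch.no_grad():
    outputs = model(inputs.input_values, mask_time_indices=mask_time_indices)

# cosine similarity between projected transformer states and projected quantized targets
cosine_sim = torch.cosine_similarity(
    outputs.projected_states, outputs.projected_quantized_states, dim=-1
)
# the test asserts that masked positions score clearly above chance for the pre-trained checkpoint
print(cosine_sim[mask_time_indices].mean())
```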
+        # cosine similarity of model is all > 0.5 as model is
+        # pre-trained on contrastive loss
         # fmt: off
-        expected_cosine_sim_masked = torch.tensor(
-            [0.7458, 0.7188, 0.6418, 0.3729, 0.3741, 0.3694, 0.3110, 0.2257, 0.4403, 0.5415, 0.3950, 0.3701, 0.8831,
-             0.8613, 0.5229, 0.6696, 0.7206, 0.7877, 0.6758, 0.8746, 0.6596, 0.6282, 0.6178, 0.5839, 0.5926, 0.6651,
-             0.4635, 0.6332, 0.6572, 0.8776, 0.4999, 0.7001, 0.7257, 0.5098, 0.6229, 0.4566, 0.5261, 0.6363, 0.5371,
-             0.6997],
-            device=torch_device,
-        )
+        expected_cosine_sim_masked = torch.tensor([
+            0.8523, 0.5860, 0.6905, 0.5557, 0.7456, 0.5249, 0.6639, 0.7654, 0.7565,
+            0.8167, 0.8222, 0.7960, 0.8034, 0.8166, 0.8310, 0.8263, 0.8274, 0.8258,
+            0.8179, 0.8412, 0.8536, 0.5098, 0.4728, 0.6461, 0.4498, 0.6002, 0.5774,
+            0.6457, 0.7123, 0.5668, 0.6866, 0.4960, 0.6293, 0.7423, 0.7419, 0.7526,
+            0.7768, 0.4898, 0.5393, 0.8183
+        ], device=torch_device)
         # fmt: on

         self.assertTrue(torch.allclose(cosine_sim_masked, expected_cosine_sim_masked, atol=1e-3))
@@ -997,9 +1030,9 @@ def test_inference_pretrained(self):
             features_shape,
             model.config.mask_time_prob,
             model.config.mask_time_length,
-            device=inputs_dict["input_values"].device,
             min_masks=2,
-        ).to(torch_device)
+        )
+        mask_time_indices = torch.from_numpy(mask_time_indices).to(torch_device)

         with torch.no_grad():
             outputs = model(
@@ -1064,28 +1097,36 @@ def test_loss_pretraining(self):
         )

         torch.manual_seed(0)
+        np.random.seed(0)
+
         mask_time_indices = _compute_mask_indices(
             features_shape,
             model.config.mask_time_prob,
             model.config.mask_time_length,
-            device=inputs_dict["input_values"].device,
             min_masks=2,
-        ).to(torch_device)
+        )
+        sampled_negative_indices = _sample_negative_indices(
+            mask_time_indices.shape, model.config.num_negatives, mask_time_indices
+        )
+
+        mask_time_indices = torch.from_numpy(mask_time_indices).to(torch_device)
+        sampled_negative_indices = torch.from_numpy(sampled_negative_indices).to(torch_device)

         with torch.no_grad():
             outputs = model(
                 inputs_dict.input_values.to(torch_device),
                 attention_mask=inputs_dict.attention_mask.to(torch_device),
                 mask_time_indices=mask_time_indices,
+                sampled_negative_indices=sampled_negative_indices,
             )

         # check diversity loss
         num_codevectors = model.config.num_codevectors_per_group * model.config.num_codevector_groups
         diversity_loss = (num_codevectors - outputs.codevector_perplexity) / num_codevectors
-        self.assertTrue(abs(diversity_loss.item() - 0.8859) < 1e-3)
+        self.assertTrue(abs(diversity_loss.item() - 0.9538) < 1e-3)

         # check overall loss (contrastive loss + diversity loss)
-        expected_loss = 62.5170
+        expected_loss = 116.7094
         self.assertTrue(abs(outputs.loss.item() - expected_loss) < 1e-3)
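The new expected values in `test_loss_pretraining` follow directly from the bookkeeping in steps 7 and 8 of the forward pass: the diversity ratio `(G * V - perplexity) / (G * V)` is what the test checks (the 0.9538 assertion above), while the loss returned by the model first scales that ratio by the number of masked time steps and then weights it with `config.diversity_loss_weight`. A toy calculation with made-up perplexity and mask counts; the 2 x 320 codebook and the 0.1 weight are the `wav2vec2-base` defaults:

```python
# G groups of V codevectors each, as in the facebook/wav2vec2-base config
num_codevector_groups, num_codevectors_per_group = 2, 320
num_codevectors = num_codevector_groups * num_codevectors_per_group  # G * V = 640

codevector_perplexity = 29.57   # made-up value returned by the quantizer
num_masked_time_steps = 40      # made-up mask_time_indices.sum()
diversity_loss_weight = 0.1     # alpha in L = L_m + alpha * L_d
contrastive_loss = 110.0        # made-up summed cross-entropy over masked positions

# ratio checked by the test (compare the 0.9538 assertion above)
diversity_ratio = (num_codevectors - codevector_perplexity) / num_codevectors

# what the forward pass actually adds to the loss: the ratio scaled by the masked count
diversity_loss = diversity_ratio * num_masked_time_steps
loss = contrastive_loss + diversity_loss_weight * diversity_loss

print(round(diversity_ratio, 4))  # ~0.9538
print(round(loss, 4))
```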