Add Speech Recognition Task (Wav2Vec) (#586)
* Base files for wav2vec integration

* Format code with autopep8

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Closer to working

* Format code with autopep8

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Refactors

* Refactors

* Cleanups

* Refactor to allow files

* Get predictions working

* Add licence

* Fix loads

* Add check

* Fix imports

* Cleanups

* Add backbone API

* Cleanups

* Fix

* Add tests

* Docs, requirements

* topic thing

* Doc fix

* test

* Add serve

* Fix path

* Swap to audio available

* Small fix

* Some fixes

* Small fix

* Small fix

* Fix

* Updates

* Fix docs

* Remove duplicate

* Add check for audio

Co-authored-by: deepsource-autofix[bot] <62050782+deepsource-autofix[bot]@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Ethan Harris <[email protected]>
4 people authored Jul 19, 2021
1 parent ea4604f commit b8b4ebc
Showing 24 changed files with 922 additions and 15 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -30,6 +30,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added `AudioClassificationData` and an example for classifying audio spectrograms ([#594](https://github.com/PyTorchLightning/lightning-flash/pull/594))

- Added a `SpeechRecognition` task for speech to text using Wav2Vec ([#586](https://github.com/PyTorchLightning/lightning-flash/pull/586))

### Changed

- Changed how pretrained flag works for loading weights for ImageClassifier task ([#560](https://github.com/PyTorchLightning/lightning-flash/pull/560))
22 changes: 22 additions & 0 deletions docs/source/api/audio.rst
@@ -19,3 +19,25 @@ ______________

~classification.data.AudioClassificationData
~classification.data.AudioClassificationPreprocess

Speech Recognition
__________________

.. autosummary::
    :toctree: generated/
    :nosignatures:
    :template: classtemplate.rst

    ~speech_recognition.model.SpeechRecognition
    ~speech_recognition.data.SpeechRecognitionData
    ~speech_recognition.data.SpeechRecognitionPreprocess
    ~speech_recognition.data.SpeechRecognitionBackboneState
    ~speech_recognition.data.SpeechRecognitionPostprocess
    ~speech_recognition.data.SpeechRecognitionCSVDataSource
    ~speech_recognition.data.SpeechRecognitionJSONDataSource
    ~speech_recognition.data.BaseSpeechRecognition
    ~speech_recognition.data.SpeechRecognitionFileDataSource
    ~speech_recognition.data.SpeechRecognitionPathsDataSource
    ~speech_recognition.data.SpeechRecognitionDatasetDataSource
    ~speech_recognition.data.SpeechRecognitionDeserializer
1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -45,6 +45,7 @@ Lightning Flash
:caption: Audio

reference/audio_classification
reference/speech_recognition

.. toctree::
:maxdepth: 1
59 changes: 59 additions & 0 deletions docs/source/reference/speech_recognition.rst
@@ -0,0 +1,59 @@
.. _speech_recognition:

##################
Speech Recognition
##################

********
The Task
********

Speech recognition is the task of transcribing speech audio into text. We rely on `Wav2Vec <https://ai.facebook.com/blog/wav2vec-20-learning-the-structure-of-speech-from-raw-audio/>`_ as our backbone, fine-tuned on labeled transcriptions for speech to text.

-----

*******
Example
*******

Let's fine-tune the model on our own labeled audio transcription data.

Here's the structure of our CSV file:

.. code-block::

    file,text
    "/path/to/file_1.wav ... ","what was said in file 1."
    "/path/to/file_2.wav ... ","what was said in file 2."
    "/path/to/file_3.wav ... ","what was said in file 3."
    ...

Once we've downloaded the data using :func:`~flash.core.data.download_data`, we create the :class:`~flash.audio.speech_recognition.data.SpeechRecognitionData`.
We select a pre-trained Wav2Vec backbone to use for our :class:`~flash.audio.speech_recognition.model.SpeechRecognition` and finetune on a subset of the `TIMIT corpus <https://catalog.ldc.upenn.edu/LDC93S1>`__.
The backbone can be any Wav2Vec model from `HuggingFace transformers <https://huggingface.co/models?search=wav2vec>`__.
Next, we use the trained :class:`~flash.audio.speech_recognition.model.SpeechRecognition` for inference and save the model.
Here's the full example:

.. literalinclude:: ../../../flash_examples/speech_recognition.py
:language: python
:lines: 14-
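For orientation, a minimal sketch of this workflow is shown below. The dataset URL, file paths, and the exact ``from_csv`` argument names are illustrative assumptions based on Flash's standard data-module API, not the verbatim example file:

.. code-block:: python

    import flash
    from flash.audio import SpeechRecognition, SpeechRecognitionData
    from flash.core.data.utils import download_data

    # 1. Download and extract a subset of the TIMIT corpus (URL is illustrative).
    download_data("https://pl-flash-data.s3.amazonaws.com/timit_data.zip", "./data")

    # 2. Load the CSV files described above into a DataModule.
    datamodule = SpeechRecognitionData.from_csv(
        input_fields="file",
        target_fields="text",
        train_file="data/timit/train.csv",
        test_file="data/timit/test.csv",
    )

    # 3. Build the task with a pre-trained Wav2Vec backbone.
    model = SpeechRecognition(backbone="facebook/wav2vec2-base-960h")

    # 4. Fine-tune, predict, and save a checkpoint.
    trainer = flash.Trainer(max_epochs=1)
    trainer.finetune(model, datamodule=datamodule, strategy="no_freeze")
    predictions = model.predict(["data/timit/example.wav"])
    print(predictions)
    trainer.save_checkpoint("speech_recognition_model.pt")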

------

*******
Serving
*******

The :class:`~flash.audio.speech_recognition.model.SpeechRecognition` is servable.
This means you can call ``.serve`` to serve your :class:`~flash.core.model.Task`.
Here's an example:

.. literalinclude:: ../../../flash_examples/serve/speech_recognition/inference_server.py
:language: python
:lines: 14-
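A server script in this style can be as small as loading a checkpoint and calling ``.serve``; in the sketch below, the checkpoint URL is an illustrative assumption:

.. code-block:: python

    from flash.audio import SpeechRecognition

    # Load a fine-tuned checkpoint and expose it over HTTP (URL is illustrative).
    model = SpeechRecognition.load_from_checkpoint(
        "https://flash-weights.s3.amazonaws.com/speech_recognition_model.pt"
    )
    model.serve()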

You can now perform inference from your client like this:

.. literalinclude:: ../../../flash_examples/serve/speech_recognition/client.py
:language: python
:lines: 14-
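In outline, the client base64-encodes an audio file and posts it to the server. The endpoint and payload shape below follow the pattern of Flash's other serve clients and should be treated as assumptions:

.. code-block:: python

    import base64
    from pathlib import Path

    import requests

    import flash

    # Read the bundled example wav and base64-encode it for the JSON payload.
    with (Path(flash.ASSETS_ROOT) / "example.wav").open("rb") as f:
        audio_str = base64.b64encode(f.read()).decode("UTF-8")

    # Assumed endpoint and payload shape for a local ``.serve`` instance.
    body = {"session": "UUID", "payload": {"inputs": {"data": audio_str}}}
    resp = requests.post("http://127.0.0.1:8000/predict", json=body)
    print(resp.json())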
Binary file added flash/assets/example.wav
1 change: 1 addition & 0 deletions flash/audio/__init__.py
@@ -1 +1,2 @@
from flash.audio.classification import AudioClassificationData, AudioClassificationPreprocess # noqa: F401
from flash.audio.speech_recognition import SpeechRecognition, SpeechRecognitionData # noqa: F401
15 changes: 15 additions & 0 deletions flash/audio/speech_recognition/__init__.py
@@ -0,0 +1,15 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from flash.audio.speech_recognition.data import SpeechRecognitionData # noqa: F401
from flash.audio.speech_recognition.model import SpeechRecognition # noqa: F401
30 changes: 30 additions & 0 deletions flash/audio/speech_recognition/backbone.py
@@ -0,0 +1,30 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial

from flash.core.registry import FlashRegistry
from flash.core.utilities.imports import _AUDIO_AVAILABLE

SPEECH_RECOGNITION_BACKBONES = FlashRegistry("backbones")

if _AUDIO_AVAILABLE:
    from transformers import Wav2Vec2ForCTC

    WAV2VEC_MODELS = ["facebook/wav2vec2-base-960h", "facebook/wav2vec2-large-960h-lv60"]

    for model_name in WAV2VEC_MODELS:
        SPEECH_RECOGNITION_BACKBONES(
            fn=partial(Wav2Vec2ForCTC.from_pretrained, model_name),
            name=model_name,
        )
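As a usage note, registering through ``FlashRegistry`` makes backbones resolvable by name. A hedged sketch follows; the extra checkpoint is illustrative, and ``FlashRegistry.get`` returning the registered constructor is an assumption:

from functools import partial

from transformers import Wav2Vec2ForCTC

from flash.audio.speech_recognition.backbone import SPEECH_RECOGNITION_BACKBONES

# Register an additional public Wav2Vec checkpoint under its model name.
SPEECH_RECOGNITION_BACKBONES(
    fn=partial(Wav2Vec2ForCTC.from_pretrained, "facebook/wav2vec2-large-xlsr-53"),
    name="facebook/wav2vec2-large-xlsr-53",
)

# Look up a registered constructor by name and instantiate the backbone.
backbone_fn = SPEECH_RECOGNITION_BACKBONES.get("facebook/wav2vec2-base-960h")
backbone = backbone_fn()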
101 changes: 101 additions & 0 deletions flash/audio/speech_recognition/collate.py
@@ -0,0 +1,101 @@
# Copyright 2020 The PyTorch Lightning team and The HuggingFace Team. All rights reserved.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union

import torch

from flash.core.data.data_source import DefaultDataKeys
from flash.core.utilities.imports import _AUDIO_AVAILABLE

if _AUDIO_AVAILABLE:
    from transformers import Wav2Vec2Processor
else:
    Wav2Vec2Processor = object


@dataclass
class DataCollatorCTCWithPadding:
    """
    Data collator that will dynamically pad the inputs received.

    Args:
        processor (:class:`~transformers.Wav2Vec2Processor`)
            The processor used for processing the data.
        padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`,
            `optional`, defaults to :obj:`True`):
            Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
            among:

            * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a
              single sequence is provided).
            * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
              maximum acceptable input length for the model if that argument is not provided.
            * :obj:`False` or :obj:`'do_not_pad'`: No padding (i.e., can output a batch with sequences of
              different lengths).
        max_length (:obj:`int`, `optional`):
            Maximum length of the ``input_values`` of the returned list and optionally padding length (see above).
        max_length_labels (:obj:`int`, `optional`):
            Maximum length of the ``labels`` returned list and optionally padding length (see above).
        pad_to_multiple_of (:obj:`int`, `optional`):
            If set will pad the sequence to a multiple of the provided value.
            This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
            >= 7.5 (Volta).
    """

    processor: Wav2Vec2Processor
    padding: Union[bool, str] = True
    max_length: Optional[int] = None
    max_length_labels: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    pad_to_multiple_of_labels: Optional[int] = None

    def __call__(self, samples: List[Dict[str, Any]], metadata: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        inputs = [sample[DefaultDataKeys.INPUT] for sample in samples]
        sampling_rates = [sample["sampling_rate"] for sample in metadata]

        assert (
            len(set(sampling_rates)) == 1
        ), f"Make sure all inputs have the same sampling rate of {self.processor.feature_extractor.sampling_rate}."

        inputs = self.processor(inputs, sampling_rate=sampling_rates[0]).input_values

        # split inputs and labels since they have to be of different lengths and need
        # different padding methods
        input_features = [{"input_values": input} for input in inputs]

        batch = self.processor.pad(
            input_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )

        labels = [sample.get(DefaultDataKeys.TARGET, None) for sample in samples]
        # check to ensure labels exist to collate
        if None not in labels:
            with self.processor.as_target_processor():
                label_features = self.processor(labels).input_ids
            label_features = [{"input_ids": feature} for feature in label_features]
            labels_batch = self.processor.pad(
                label_features,
                padding=self.padding,
                max_length=self.max_length_labels,
                pad_to_multiple_of=self.pad_to_multiple_of_labels,
                return_tensors="pt",
            )

            # replace padding with -100 to ignore loss correctly
            batch["labels"] = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        return batch
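To make the collator's contract concrete, here is a hedged usage sketch; the processor checkpoint and dummy waveforms are illustrative:

import numpy as np
from transformers import Wav2Vec2Processor

from flash.core.data.data_source import DefaultDataKeys

processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
collator = DataCollatorCTCWithPadding(processor=processor, padding=True)

# Two raw waveforms of different lengths, each with an uppercase transcription
# (the base-960h vocabulary uses uppercase characters).
samples = [
    {DefaultDataKeys.INPUT: np.random.randn(16000), DefaultDataKeys.TARGET: "HELLO"},
    {DefaultDataKeys.INPUT: np.random.randn(8000), DefaultDataKeys.TARGET: "WORLD"},
]
metadata = [{"sampling_rate": 16000}, {"sampling_rate": 16000}]

batch = collator(samples, metadata)
# batch["input_values"] is padded to the longest waveform in the batch;
# batch["labels"] uses -100 on padded positions so CTC loss ignores them.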