Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Add Speech Recognition Task (Wav2Vec) #586

Merged
merged 48 commits into from
Jul 19, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
48 commits
Select commit Hold shift + click to select a range
dd92d79
Base files for wav2vec integration
Jul 14, 2021
2a43fe7
Format code with autopep8
deepsource-autofix[bot] Jul 14, 2021
6a39b34
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 14, 2021
1b48bc1
Closer to working
Jul 14, 2021
c87dcc2
Format code with autopep8
deepsource-autofix[bot] Jul 14, 2021
091da56
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 14, 2021
2690e9c
Refactors
Jul 15, 2021
1531560
Refactors
Jul 15, 2021
e8664d6
Cleanups
Jul 15, 2021
6d0f1c3
Refactor to allow files
Jul 15, 2021
a9735b2
Get predictions working
Jul 15, 2021
0901d12
Add licence
Jul 15, 2021
bce0e10
Merge branch 'master' into feat/speech_recognition
Jul 15, 2021
1f18f05
Fix loads
Jul 15, 2021
71cb06d
Add check
Jul 15, 2021
50642f5
Fix imports
Jul 15, 2021
d271951
Cleanups
Jul 16, 2021
956ac8e
Add backbone API
Jul 16, 2021
6b132f2
Cleanups
Jul 16, 2021
3db4dad
Fix
Jul 16, 2021
c54acf1
Add tests
Jul 16, 2021
62175ae
Docs, requirements
Jul 16, 2021
dc2e72c
topic thing
Jul 16, 2021
8eccdf9
Doc fix
Jul 16, 2021
dcfa913
test
Jul 16, 2021
e4f0a69
Add serve
Jul 16, 2021
541c1fb
Merge branch 'master' into feat/speech_recognition
Jul 16, 2021
14795f3
Fix path
Jul 18, 2021
1b8eb08
Swap to audio available
Jul 18, 2021
ab3a437
Small fix
ethanwharris Jul 19, 2021
13eb84f
Some fixes
ethanwharris Jul 19, 2021
af9e0c1
Small fix
ethanwharris Jul 19, 2021
4bbc31c
Small fix
ethanwharris Jul 19, 2021
4336f61
Fix
ethanwharris Jul 19, 2021
51c640a
Updates
ethanwharris Jul 19, 2021
801b752
Fix docs
ethanwharris Jul 19, 2021
683f671
Remove duplicate
Jul 19, 2021
8590052
Add check for audio
Jul 19, 2021
1c98625
Updates
ethanwharris Jul 19, 2021
a208e17
Update CHANGELOG.md
ethanwharris Jul 19, 2021
d9d1a0a
Updates
ethanwharris Jul 19, 2021
9259f44
Update docs
ethanwharris Jul 19, 2021
70607a2
Update docs
ethanwharris Jul 19, 2021
4e6bce7
Update docs
ethanwharris Jul 19, 2021
2d08f21
Add example to CI
ethanwharris Jul 19, 2021
0052f1f
Fix some tests
ethanwharris Jul 19, 2021
0c87f04
Fix some broken tests
ethanwharris Jul 19, 2021
bfe8ea6
Fixes
ethanwharris Jul 19, 2021
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added `AudioClassificationData` and an example for classifying audio spectrograms ([#594](https://github.com/PyTorchLightning/lightning-flash/pull/594))

- Added a `SpeechRecognition` task for speech to text using Wav2Vec ([#586](https://github.com/PyTorchLightning/lightning-flash/pull/586))

### Changed

- Changed how pretrained flag works for loading weights for ImageClassifier task ([#560](https://github.com/PyTorchLightning/lightning-flash/pull/560))
Expand Down
22 changes: 22 additions & 0 deletions docs/source/api/audio.rst
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,25 @@ ______________

~classification.data.AudioClassificationData
~classification.data.AudioClassificationPreprocess

Speech Recognition
__________________

.. autosummary::
:toctree: generated/
:nosignatures:
:template: classtemplate.rst

~speech_recognition.model.SpeechRecognition
~speech_recognition.data.SpeechRecognitionData

speech_recognition.data.SpeechRecognitionPreprocess
speech_recognition.data.SpeechRecognitionBackboneState
speech_recognition.data.SpeechRecognitionPostprocess
speech_recognition.data.SpeechRecognitionCSVDataSource
speech_recognition.data.SpeechRecognitionJSONDataSource
speech_recognition.data.BaseSpeechRecognition
speech_recognition.data.SpeechRecognitionFileDataSource
speech_recognition.data.SpeechRecognitionPathsDataSource
speech_recognition.data.SpeechRecognitionDatasetDataSource
speech_recognition.data.SpeechRecognitionDeserializer
1 change: 1 addition & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ Lightning Flash
:caption: Audio

reference/audio_classification
reference/speech_recognition

.. toctree::
:maxdepth: 1
Expand Down
59 changes: 59 additions & 0 deletions docs/source/reference/speech_recognition.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
.. _speech_recognition:

##################
Speech Recognition
##################

********
The Task
********

Speech recognition is the task of classifying audio into a text transcription. We rely on `Wav2Vec <https://ai.facebook.com/blog/wav2vec-20-learning-the-structure-of-speech-from-raw-audio/>`_ as our backbone, fine-tuned on labeled transcriptions for speech to text.

-----

*******
Example
*******

Let's fine-tune the model onto our own labeled audio transcription data:

Here's the structure our CSV file:

.. code-block::

file,text
"/path/to/file_1.wav ... ","what was said in file 1."
"/path/to/file_2.wav ... ","what was said in file 2."
"/path/to/file_3.wav ... ","what was said in file 3."
...

Once we've downloaded the data using :func:`~flash.core.data.download_data`, we create the :class:`~flash.audio.speech_recognition.data.SpeechRecognitionData`.
We select a pre-trained Wav2Vec backbone to use for our :class:`~flash.audio.speech_recognition.model.SpeechRecognition` and finetune on a subset of the `TIMIT corpus <https://catalog.ldc.upenn.edu/LDC93S1>`__.
The backbone can be any Wav2Vec model from `HuggingFace transformers <https://huggingface.co/models?search=wav2vec>`__.
Next, we use the trained :class:`~flash.audio.speech_recognition.model.SpeechRecognition` for inference and save the model.
Here's the full example:

.. literalinclude:: ../../../flash_examples/speech_recognition.py
:language: python
:lines: 14-

------

*******
Serving
*******

The :class:`~flash.audio.speech_recognition.model.SpeechRecognition` is servable.
This means you can call ``.serve`` to serve your :class:`~flash.core.model.Task`.
Here's an example:

.. literalinclude:: ../../../flash_examples/serve/speech_recognition/inference_server.py
:language: python
:lines: 14-

You can now perform inference from your client like this:

.. literalinclude:: ../../../flash_examples/serve/speech_recognition/client.py
:language: python
:lines: 14-
Binary file added flash/assets/example.wav
Binary file not shown.
1 change: 1 addition & 0 deletions flash/audio/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from flash.audio.classification import AudioClassificationData, AudioClassificationPreprocess # noqa: F401
from flash.audio.speech_recognition import SpeechRecognition, SpeechRecognitionData # noqa: F401
15 changes: 15 additions & 0 deletions flash/audio/speech_recognition/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from flash.audio.speech_recognition.data import SpeechRecognitionData # noqa: F401
from flash.audio.speech_recognition.model import SpeechRecognition # noqa: F401
30 changes: 30 additions & 0 deletions flash/audio/speech_recognition/backbone.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial

from flash.core.registry import FlashRegistry
from flash.core.utilities.imports import _AUDIO_AVAILABLE

SPEECH_RECOGNITION_BACKBONES = FlashRegistry("backbones")

if _AUDIO_AVAILABLE:
from transformers import Wav2Vec2ForCTC

WAV2VEC_MODELS = ["facebook/wav2vec2-base-960h", "facebook/wav2vec2-large-960h-lv60"]

for model_name in WAV2VEC_MODELS:
SPEECH_RECOGNITION_BACKBONES(
fn=partial(Wav2Vec2ForCTC.from_pretrained, model_name),
name=model_name,
)
101 changes: 101 additions & 0 deletions flash/audio/speech_recognition/collate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
# Copyright 2020 The PyTorch Lightning team and The HuggingFace Team. All rights reserved.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union

import torch

from flash.core.data.data_source import DefaultDataKeys
from flash.core.utilities.imports import _AUDIO_AVAILABLE

if _AUDIO_AVAILABLE:
from transformers import Wav2Vec2Processor
else:
Wav2Vec2Processor = object


@dataclass
class DataCollatorCTCWithPadding:
"""
Data collator that will dynamically pad the inputs received.
Args:
processor (:class:`~transformers.Wav2Vec2Processor`)
The processor used for proccessing the data.
padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`,
`optional`, defaults to :obj:`True`):
Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
among:
* :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
sequence if provided).
* :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
maximum acceptable input length for the model if that argument is not provided.
* :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
different lengths).
max_length (:obj:`int`, `optional`):
Maximum length of the ``input_values`` of the returned list and optionally padding length (see above).
max_length_labels (:obj:`int`, `optional`):
Maximum length of the ``labels`` returned list and optionally padding length (see above).
pad_to_multiple_of (:obj:`int`, `optional`):
If set will pad the sequence to a multiple of the provided value.
This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
7.5 (Volta).
"""

processor: Wav2Vec2Processor
padding: Union[bool, str] = True
max_length: Optional[int] = None
max_length_labels: Optional[int] = None
pad_to_multiple_of: Optional[int] = None
pad_to_multiple_of_labels: Optional[int] = None

def __call__(self, samples: List[Dict[str, Any]], metadata: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
inputs = [sample[DefaultDataKeys.INPUT] for sample in samples]
sampling_rates = [sample["sampling_rate"] for sample in metadata]

assert (
len(set(sampling_rates)) == 1
), f"Make sure all inputs have the same sampling rate of {self.processor.feature_extractor.sampling_rate}."

inputs = self.processor(inputs, sampling_rate=sampling_rates[0]).input_values

# split inputs and labels since they have to be of different lengths and need
# different padding methods
input_features = [{"input_values": input} for input in inputs]

batch = self.processor.pad(
input_features,
padding=self.padding,
max_length=self.max_length,
pad_to_multiple_of=self.pad_to_multiple_of,
return_tensors="pt",
)

labels = [sample.get(DefaultDataKeys.TARGET, None) for sample in samples]
# check to ensure labels exist to collate
if None not in labels:
with self.processor.as_target_processor():
label_features = self.processor(labels).input_ids
label_features = [{"input_ids": feature} for feature in label_features]
labels_batch = self.processor.pad(
label_features,
padding=self.padding,
max_length=self.max_length_labels,
pad_to_multiple_of=self.pad_to_multiple_of_labels,
return_tensors="pt",
)

# replace padding with -100 to ignore loss correctly
batch["labels"] = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

return batch
Loading