diff --git a/CHANGELOG.md b/CHANGELOG.md index cb7c1cb3b8..1fa497852c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -30,6 +30,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `AudioClassificationData` and an example for classifying audio spectrograms ([#594](https://github.com/PyTorchLightning/lightning-flash/pull/594)) +- Added a `SpeechRecognition` task for speech to text using Wav2Vec ([#586](https://github.com/PyTorchLightning/lightning-flash/pull/586)) + ### Changed - Changed how pretrained flag works for loading weights for ImageClassifier task ([#560](https://github.com/PyTorchLightning/lightning-flash/pull/560)) diff --git a/docs/source/api/audio.rst b/docs/source/api/audio.rst index 79662fea87..706a364372 100644 --- a/docs/source/api/audio.rst +++ b/docs/source/api/audio.rst @@ -19,3 +19,25 @@ ______________ ~classification.data.AudioClassificationData ~classification.data.AudioClassificationPreprocess + +Speech Recognition +__________________ + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + ~speech_recognition.model.SpeechRecognition + ~speech_recognition.data.SpeechRecognitionData + + speech_recognition.data.SpeechRecognitionPreprocess + speech_recognition.data.SpeechRecognitionBackboneState + speech_recognition.data.SpeechRecognitionPostprocess + speech_recognition.data.SpeechRecognitionCSVDataSource + speech_recognition.data.SpeechRecognitionJSONDataSource + speech_recognition.data.BaseSpeechRecognition + speech_recognition.data.SpeechRecognitionFileDataSource + speech_recognition.data.SpeechRecognitionPathsDataSource + speech_recognition.data.SpeechRecognitionDatasetDataSource + speech_recognition.data.SpeechRecognitionDeserializer diff --git a/docs/source/index.rst b/docs/source/index.rst index d12099d884..8f56b56214 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -45,6 +45,7 @@ Lightning Flash :caption: Audio reference/audio_classification + reference/speech_recognition .. toctree:: :maxdepth: 1 diff --git a/docs/source/reference/speech_recognition.rst b/docs/source/reference/speech_recognition.rst new file mode 100644 index 0000000000..63816cba49 --- /dev/null +++ b/docs/source/reference/speech_recognition.rst @@ -0,0 +1,59 @@ +.. _speech_recognition: + +################## +Speech Recognition +################## + +******** +The Task +******** + +Speech recognition is the task of transcribing spoken audio into text. We rely on `Wav2Vec `_ as our backbone, fine-tuned on labeled transcriptions for speech to text. + +----- + +******* +Example +******* + +Let's fine-tune the model on our own labeled audio transcription data: + +Here's the structure of our CSV file: + +.. code-block:: + + file,text + "/path/to/file_1.wav ... ","what was said in file 1." + "/path/to/file_2.wav ... ","what was said in file 2." + "/path/to/file_3.wav ... ","what was said in file 3." + ... + +Once we've downloaded the data using :func:`~flash.core.data.download_data`, we create the :class:`~flash.audio.speech_recognition.data.SpeechRecognitionData`. +We select a pre-trained Wav2Vec backbone to use for our :class:`~flash.audio.speech_recognition.model.SpeechRecognition` and fine-tune on a subset of the `TIMIT corpus `__. +The backbone can be any Wav2Vec model from `HuggingFace transformers `__. +Next, we use the trained :class:`~flash.audio.speech_recognition.model.SpeechRecognition` for inference and save the model. +Here's the full example: + +.. 
literalinclude:: ../../../flash_examples/speech_recognition.py + :language: python + :lines: 14- + +------ + +******* +Serving +******* + +The :class:`~flash.audio.speech_recognition.model.SpeechRecognition` is servable. +This means you can call ``.serve`` to serve your :class:`~flash.core.model.Task`. +Here's an example: + +.. literalinclude:: ../../../flash_examples/serve/speech_recognition/inference_server.py + :language: python + :lines: 14- + +You can now perform inference from your client like this: + +.. literalinclude:: ../../../flash_examples/serve/speech_recognition/client.py + :language: python + :lines: 14- diff --git a/flash/assets/example.wav b/flash/assets/example.wav new file mode 100644 index 0000000000..8a1d66a36b Binary files /dev/null and b/flash/assets/example.wav differ diff --git a/flash/audio/__init__.py b/flash/audio/__init__.py index 40eeaae124..b90bc6d06e 100644 --- a/flash/audio/__init__.py +++ b/flash/audio/__init__.py @@ -1 +1,2 @@ from flash.audio.classification import AudioClassificationData, AudioClassificationPreprocess # noqa: F401 +from flash.audio.speech_recognition import SpeechRecognition, SpeechRecognitionData # noqa: F401 diff --git a/flash/audio/speech_recognition/__init__.py b/flash/audio/speech_recognition/__init__.py new file mode 100644 index 0000000000..00f1b6fa0c --- /dev/null +++ b/flash/audio/speech_recognition/__init__.py @@ -0,0 +1,15 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from flash.audio.speech_recognition.data import SpeechRecognitionData # noqa: F401 +from flash.audio.speech_recognition.model import SpeechRecognition # noqa: F401 diff --git a/flash/audio/speech_recognition/backbone.py b/flash/audio/speech_recognition/backbone.py new file mode 100644 index 0000000000..425ef2eb00 --- /dev/null +++ b/flash/audio/speech_recognition/backbone.py @@ -0,0 +1,30 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from functools import partial + +from flash.core.registry import FlashRegistry +from flash.core.utilities.imports import _AUDIO_AVAILABLE + +SPEECH_RECOGNITION_BACKBONES = FlashRegistry("backbones") + +if _AUDIO_AVAILABLE: + from transformers import Wav2Vec2ForCTC + + WAV2VEC_MODELS = ["facebook/wav2vec2-base-960h", "facebook/wav2vec2-large-960h-lv60"] + + for model_name in WAV2VEC_MODELS: + SPEECH_RECOGNITION_BACKBONES( + fn=partial(Wav2Vec2ForCTC.from_pretrained, model_name), + name=model_name, + ) diff --git a/flash/audio/speech_recognition/collate.py b/flash/audio/speech_recognition/collate.py new file mode 100644 index 0000000000..9ee53a4686 --- /dev/null +++ b/flash/audio/speech_recognition/collate.py @@ -0,0 +1,101 @@ +# Copyright 2020 The PyTorch Lightning team and The HuggingFace Team. All rights reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Union + +import torch + +from flash.core.data.data_source import DefaultDataKeys +from flash.core.utilities.imports import _AUDIO_AVAILABLE + +if _AUDIO_AVAILABLE: + from transformers import Wav2Vec2Processor +else: + Wav2Vec2Processor = object + + +@dataclass +class DataCollatorCTCWithPadding: + """ + Data collator that will dynamically pad the inputs received. + Args: + processor (:class:`~transformers.Wav2Vec2Processor`) + The processor used for processing the data. + padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, + `optional`, defaults to :obj:`True`): + Select a strategy to pad the returned sequences (according to the model's padding side and padding index) + among: + * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single + sequence is provided). + * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the + maximum acceptable input length for the model if that argument is not provided. + * :obj:`False` or :obj:`'do_not_pad'`: No padding (i.e., can output a batch with sequences of + different lengths). + max_length (:obj:`int`, `optional`): + Maximum length of the ``input_values`` of the returned list and optionally padding length (see above). + max_length_labels (:obj:`int`, `optional`): + Maximum length of the ``labels`` returned list and optionally padding length (see above). + pad_to_multiple_of (:obj:`int`, `optional`): + If set will pad the sequence to a multiple of the provided value. + This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >= + 7.5 (Volta). 
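+        pad_to_multiple_of_labels (:obj:`int`, `optional`): +            If set will pad the ``labels`` returned list to a multiple of the provided value.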
+ """ + + processor: Wav2Vec2Processor + padding: Union[bool, str] = True + max_length: Optional[int] = None + max_length_labels: Optional[int] = None + pad_to_multiple_of: Optional[int] = None + pad_to_multiple_of_labels: Optional[int] = None + + def __call__(self, samples: List[Dict[str, Any]], metadata: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]: + inputs = [sample[DefaultDataKeys.INPUT] for sample in samples] + sampling_rates = [sample["sampling_rate"] for sample in metadata] + + assert ( + len(set(sampling_rates)) == 1 + ), f"Make sure all inputs have the same sampling rate of {self.processor.feature_extractor.sampling_rate}." + + inputs = self.processor(inputs, sampling_rate=sampling_rates[0]).input_values + + # split inputs and labels since they have to be of different lengths and need + # different padding methods + input_features = [{"input_values": input} for input in inputs] + + batch = self.processor.pad( + input_features, + padding=self.padding, + max_length=self.max_length, + pad_to_multiple_of=self.pad_to_multiple_of, + return_tensors="pt", + ) + + labels = [sample.get(DefaultDataKeys.TARGET, None) for sample in samples] + # check to ensure labels exist to collate + if None not in labels: + with self.processor.as_target_processor(): + label_features = self.processor(labels).input_ids + label_features = [{"input_ids": feature} for feature in label_features] + labels_batch = self.processor.pad( + label_features, + padding=self.padding, + max_length=self.max_length_labels, + pad_to_multiple_of=self.pad_to_multiple_of_labels, + return_tensors="pt", + ) + + # replace padding with -100 to ignore loss correctly + batch["labels"] = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100) + + return batch diff --git a/flash/audio/speech_recognition/data.py b/flash/audio/speech_recognition/data.py new file mode 100644 index 0000000000..97dfde0f26 --- /dev/null +++ b/flash/audio/speech_recognition/data.py @@ -0,0 +1,225 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import base64 +import io +import os.path +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union + +import torch +from torch.utils.data import Dataset + +import flash +from flash.core.data.data_module import DataModule +from flash.core.data.data_source import ( + DatasetDataSource, + DataSource, + DefaultDataKeys, + DefaultDataSources, + PathsDataSource, +) +from flash.core.data.process import Deserializer, Postprocess, Preprocess +from flash.core.data.properties import ProcessState +from flash.core.utilities.imports import _AUDIO_AVAILABLE, requires_extras + +if _AUDIO_AVAILABLE: + import soundfile as sf + from datasets import Dataset as HFDataset + from datasets import load_dataset + from transformers import Wav2Vec2CTCTokenizer +else: + HFDataset = object + + +class SpeechRecognitionDeserializer(Deserializer): + + def deserialize(self, sample: Any) -> Dict: + encoded_with_padding = (sample + "===").encode("ascii") + audio = base64.b64decode(encoded_with_padding) + buffer = io.BytesIO(audio) + data, sampling_rate = sf.read(buffer) + return { + DefaultDataKeys.INPUT: data, + DefaultDataKeys.METADATA: { + "sampling_rate": sampling_rate + }, + } + + @property + def example_input(self) -> str: + with (Path(flash.ASSETS_ROOT) / "example.wav").open("rb") as f: + return base64.b64encode(f.read()).decode("UTF-8") + + +class BaseSpeechRecognition: + + def _load_sample(self, sample: Dict[str, Any]) -> Any: + path = sample[DefaultDataKeys.INPUT] + if not os.path.isabs(path) and DefaultDataKeys.METADATA in sample and "root" in sample[DefaultDataKeys.METADATA + ]: + path = os.path.join(sample[DefaultDataKeys.METADATA]["root"], path) + speech_array, sampling_rate = sf.read(path) + sample[DefaultDataKeys.INPUT] = speech_array + sample[DefaultDataKeys.METADATA] = {"sampling_rate": sampling_rate} + return sample + + +class SpeechRecognitionFileDataSource(DataSource, BaseSpeechRecognition): + + def __init__(self, filetype: Optional[str] = None): + super().__init__() + self.filetype = filetype + + def load_data( + self, + data: Tuple[str, Union[str, List[str]], Union[str, List[str]]], + dataset: Optional[Any] = None, + ) -> Union[Sequence[Mapping[str, Any]]]: + if self.filetype == 'json': + file, input_key, target_key, field = data + else: + file, input_key, target_key = data + stage = self.running_stage.value + if self.filetype == 'json' and field is not None: + dataset_dict = load_dataset(self.filetype, data_files={stage: str(file)}, field=field) + else: + dataset_dict = load_dataset(self.filetype, data_files={stage: str(file)}) + + dataset = dataset_dict[stage] + meta = {"root": os.path.dirname(file)} + return [{ + DefaultDataKeys.INPUT: input_file, + DefaultDataKeys.TARGET: target, + DefaultDataKeys.METADATA: meta, + } for input_file, target in zip(dataset[input_key], dataset[target_key])] + + def load_sample(self, sample: Dict[str, Any], dataset: Any = None) -> Any: + return self._load_sample(sample) + + +class SpeechRecognitionCSVDataSource(SpeechRecognitionFileDataSource): + + def __init__(self): + super().__init__(filetype='csv') + + +class SpeechRecognitionJSONDataSource(SpeechRecognitionFileDataSource): + + def __init__(self): + super().__init__(filetype='json') + + +class SpeechRecognitionDatasetDataSource(DatasetDataSource, BaseSpeechRecognition): + + def load_data(self, data: Dataset, dataset: Optional[Any] = None) -> Union[Sequence[Mapping[str, Any]]]: + if isinstance(data, HFDataset): + data = 
list(zip(data["file"], data["text"])) + return super().load_data(data, dataset) + + +class SpeechRecognitionPathsDataSource(PathsDataSource, BaseSpeechRecognition): + + def __init__(self): + super().__init__(("wav", "ogg", "flac", "mat")) + + def load_sample(self, sample: Dict[str, Any], dataset: Any = None) -> Any: + return self._load_sample(sample) + + +class SpeechRecognitionPreprocess(Preprocess): + + @requires_extras("audio") + def __init__( + self, + train_transform: Optional[Dict[str, Callable]] = None, + val_transform: Optional[Dict[str, Callable]] = None, + test_transform: Optional[Dict[str, Callable]] = None, + predict_transform: Optional[Dict[str, Callable]] = None, + ): + super().__init__( + train_transform=train_transform, + val_transform=val_transform, + test_transform=test_transform, + predict_transform=predict_transform, + data_sources={ + DefaultDataSources.CSV: SpeechRecognitionCSVDataSource(), + DefaultDataSources.JSON: SpeechRecognitionJSONDataSource(), + DefaultDataSources.FILES: SpeechRecognitionPathsDataSource(), + DefaultDataSources.DATASET: SpeechRecognitionDatasetDataSource(), + }, + default_data_source=DefaultDataSources.FILES, + deserializer=SpeechRecognitionDeserializer(), + ) + + def get_state_dict(self) -> Dict[str, Any]: + return self.transforms + + @classmethod + def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool = False): + return cls(**state_dict) + + +@dataclass(unsafe_hash=True, frozen=True) +class SpeechRecognitionBackboneState(ProcessState): + """The ``SpeechRecognitionBackboneState`` stores the backbone in use by the + :class:`~flash.audio.speech_recognition.data.SpeechRecognitionPostprocess` + """ + + backbone: str + + +class SpeechRecognitionPostprocess(Postprocess): + + @requires_extras("audio") + def __init__(self): + super().__init__() + + self._backbone = None + self._tokenizer = None + + @property + def backbone(self): + backbone_state = self.get_state(SpeechRecognitionBackboneState) + if backbone_state is not None: + return backbone_state.backbone + + @property + def tokenizer(self): + if self.backbone is not None and self.backbone != self._backbone: + self._tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(self.backbone) + self._backbone = self.backbone + return self._tokenizer + + def per_batch_transform(self, batch: Any) -> Any: + # converts logits into greedy transcription + pred_ids = torch.argmax(batch.logits, dim=-1) + transcriptions = self.tokenizer.batch_decode(pred_ids) + return transcriptions + + def __getstate__(self): # TODO: Find out why this is being pickled + state = self.__dict__.copy() + state.pop("_tokenizer") + return state + + def __setstate__(self, state): + self.__dict__.update(state) + self._tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(self.backbone) + + +class SpeechRecognitionData(DataModule): + """Data Module for text classification tasks""" + + preprocess_cls = SpeechRecognitionPreprocess + postprocess_cls = SpeechRecognitionPostprocess diff --git a/flash/audio/speech_recognition/model.py b/flash/audio/speech_recognition/model.py new file mode 100644 index 0000000000..588f4f89b2 --- /dev/null +++ b/flash/audio/speech_recognition/model.py @@ -0,0 +1,78 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import warnings +from typing import Any, Callable, Dict, Mapping, Optional, Type, Union + +import torch +import torch.nn as nn + +from flash import Task +from flash.audio.speech_recognition.backbone import SPEECH_RECOGNITION_BACKBONES +from flash.audio.speech_recognition.collate import DataCollatorCTCWithPadding +from flash.audio.speech_recognition.data import SpeechRecognitionBackboneState +from flash.core.data.process import Serializer +from flash.core.data.states import CollateFn +from flash.core.registry import FlashRegistry +from flash.core.utilities.imports import _AUDIO_AVAILABLE + +if _AUDIO_AVAILABLE: + from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor + + +class SpeechRecognition(Task): + + backbones: FlashRegistry = SPEECH_RECOGNITION_BACKBONES + + required_extras = "audio" + + def __init__( + self, + backbone: str = "facebook/wav2vec2-base-960h", + loss_fn: Optional[Callable] = None, + optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam, + learning_rate: float = 1e-5, + serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None, + ): + os.environ["TOKENIZERS_PARALLELISM"] = "TRUE" + # disable HF thousand warnings + warnings.simplefilter("ignore") + # set os environ variable for multiprocesses + os.environ["PYTHONWARNINGS"] = "ignore" + + model = self.backbones.get(backbone + )() if backbone in self.backbones else Wav2Vec2ForCTC.from_pretrained(backbone) + super().__init__( + model=model, + loss_fn=loss_fn, + optimizer=optimizer, + learning_rate=learning_rate, + serializer=serializer, + ) + + self.save_hyperparameters() + + self.set_state(SpeechRecognitionBackboneState(backbone)) + self.set_state(CollateFn(DataCollatorCTCWithPadding(Wav2Vec2Processor.from_pretrained(backbone)))) + + def forward(self, batch: Dict[str, torch.Tensor]): + return self.model(batch["input_values"]) + + def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: + return self(batch) + + def step(self, batch: Any, batch_idx: int, metrics: nn.ModuleDict) -> Any: + out = self.model(batch["input_values"], labels=batch["labels"]) + out["logs"] = {'loss': out.loss} + return out diff --git a/flash/core/data/batch.py b/flash/core/data/batch.py index 51d28d2a22..e7e9a30635 100644 --- a/flash/core/data/batch.py +++ b/flash/core/data/batch.py @@ -229,7 +229,10 @@ def forward(self, samples: Sequence[Any]) -> Any: with self._collate_context: samples, metadata = self._extract_metadata(samples) - samples = self.collate_fn(samples) + try: + samples = self.collate_fn(samples, metadata) + except TypeError: + samples = self.collate_fn(samples) if metadata and isinstance(samples, dict): samples[DefaultDataKeys.METADATA] = metadata self.callback.on_collate(samples, self.stage) diff --git a/flash/core/data/process.py b/flash/core/data/process.py index 7020e32d36..a1d6e56085 100644 --- a/flash/core/data/process.py +++ b/flash/core/data/process.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import inspect import os from abc import ABC, abstractclassmethod, abstractmethod from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence @@ -24,7 +25,7 @@ import flash from flash.core.data.batch import default_uncollate from flash.core.data.callback import FlashCallback -from flash.core.data.data_source import DatasetDataSource, DataSource, DefaultDataSources +from flash.core.data.data_source import DatasetDataSource, DataSource, DefaultDataKeys, DefaultDataSources from flash.core.data.properties import Properties from flash.core.data.states import CollateFn from flash.core.data.utils import _PREPROCESS_FUNCS, _STAGES_PREFIX, convert_to_modules, CurrentRunningStageFuncContext @@ -360,18 +361,24 @@ def per_batch_transform(self, batch: Any) -> Any: """ return self.current_transform(batch) - def collate(self, samples: Sequence) -> Any: + def collate(self, samples: Sequence, metadata=None) -> Any: """ Transform to convert a sequence of samples to a collated batch. """ + current_transform = self.current_transform + if current_transform is self._identity: + current_transform = self._default_collate # the model can provide a custom ``collate_fn``. collate_fn = self.get_state(CollateFn) if collate_fn is not None: - return collate_fn.collate_fn(samples) - - current_transform = self.current_transform - if current_transform is self._identity: - return self._default_collate(samples) - return self.current_transform(samples) + collate_fn = collate_fn.collate_fn + else: + collate_fn = current_transform + # return collate_fn.collate_fn(samples) + + parameters = inspect.signature(collate_fn).parameters + if len(parameters) > 1 and DefaultDataKeys.METADATA in parameters: + return collate_fn(samples, metadata) + return collate_fn(samples) def per_sample_transform_on_device(self, sample: Any) -> Any: """Transforms to apply to the data before the collation (per-sample basis). 
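A brief aside on the ``Preprocess.collate`` change above: the collate function is resolved first (a model-provided ``CollateFn`` state takes precedence over the current transform), and per-sample metadata is only passed to collate functions whose signature declares a ``metadata`` parameter, which is how ``DataCollatorCTCWithPadding`` receives the sampling rates. A minimal, self-contained sketch of that dispatch follows; the two collate functions here are hypothetical stand-ins, not Flash APIs.

import inspect
from typing import Any, Dict, List, Optional, Sequence


def plain_collate(samples: Sequence[Any]) -> Any:
    # stand-in for a default collate function that only sees the samples
    return list(samples)


def metadata_collate(samples: Sequence[Any], metadata: Optional[List[Dict[str, Any]]] = None) -> Any:
    # stand-in for a collator (e.g. CTC padding) that also needs per-sample metadata
    return {"samples": list(samples), "metadata": metadata}


def run_collate(collate_fn, samples, metadata=None):
    # pass metadata only to collate functions whose signature asks for it
    parameters = inspect.signature(collate_fn).parameters
    if len(parameters) > 1 and "metadata" in parameters:
        return collate_fn(samples, metadata)
    return collate_fn(samples)


print(run_collate(plain_collate, [1, 2]))  # [1, 2]
print(run_collate(metadata_collate, [1, 2], [{"sampling_rate": 16000}] * 2))  # dict including metadata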
diff --git a/flash/core/utilities/imports.py b/flash/core/utilities/imports.py index 80c6b6188c..d1ba3388b6 100644 --- a/flash/core/utilities/imports.py +++ b/flash/core/utilities/imports.py @@ -87,15 +87,24 @@ def _compare_version(package: str, op, version) -> bool: _OPEN3D_AVAILABLE = _module_available("open3d") _ASTEROID_AVAILABLE = _module_available("asteroid") _SEGMENTATION_MODELS_AVAILABLE = _module_available("segmentation_models_pytorch") +_SOUNDFILE_AVAILABLE = _module_available("soundfile") _TORCH_SCATTER_AVAILABLE = _module_available("torch_scatter") _TORCH_SPARSE_AVAILABLE = _module_available("torch_sparse") _TORCH_GEOMETRIC_AVAILABLE = _module_available("torch_geometric") _TORCHAUDIO_AVAILABLE = _module_available("torchaudio") +_ROUGE_SCORE_AVAILABLE = _module_available("rouge_score") +_SENTENCEPIECE_AVAILABLE = _module_available("sentencepiece") +_DATASETS_AVAILABLE = _module_available("datasets") if Version: _TORCHVISION_GREATER_EQUAL_0_9 = _compare_version("torchvision", operator.ge, "0.9.0") -_TEXT_AVAILABLE = _TRANSFORMERS_AVAILABLE +_TEXT_AVAILABLE = all([ + _TRANSFORMERS_AVAILABLE, + _ROUGE_SCORE_AVAILABLE, + _SENTENCEPIECE_AVAILABLE, + _DATASETS_AVAILABLE, +]) _TABULAR_AVAILABLE = _TABNET_AVAILABLE and _PANDAS_AVAILABLE _VIDEO_AVAILABLE = _PYTORCHVIDEO_AVAILABLE _IMAGE_AVAILABLE = all([ @@ -108,10 +117,7 @@ def _compare_version(package: str, op, version) -> bool: ]) _SERVE_AVAILABLE = _FASTAPI_AVAILABLE and _PYDANTIC_AVAILABLE and _CYTOOLZ_AVAILABLE and _UVICORN_AVAILABLE _POINTCLOUD_AVAILABLE = _OPEN3D_AVAILABLE -_AUDIO_AVAILABLE = all([ - _ASTEROID_AVAILABLE, - _TORCHAUDIO_AVAILABLE, -]) +_AUDIO_AVAILABLE = all([_ASTEROID_AVAILABLE, _TORCHAUDIO_AVAILABLE, _SOUNDFILE_AVAILABLE, _TRANSFORMERS_AVAILABLE]) _GRAPH_AVAILABLE = _TORCH_SCATTER_AVAILABLE and _TORCH_SPARSE_AVAILABLE and _TORCH_GEOMETRIC_AVAILABLE _EXTRAS_AVAILABLE = { diff --git a/flash_examples/serve/speech_recognition/client.py b/flash_examples/serve/speech_recognition/client.py new file mode 100644 index 0000000000..c855a37204 --- /dev/null +++ b/flash_examples/serve/speech_recognition/client.py @@ -0,0 +1,27 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import base64 +from pathlib import Path + +import requests + +import flash + +with (Path(flash.ASSETS_ROOT) / "example.wav").open("rb") as f: + audio_str = base64.b64encode(f.read()).decode("UTF-8") + +body = {"session": "UUID", "payload": {"inputs": {"data": audio_str}}} +resp = requests.post("http://127.0.0.1:8000/predict", json=body) + +print(resp.json()) diff --git a/flash_examples/serve/speech_recognition/inference_server.py b/flash_examples/serve/speech_recognition/inference_server.py new file mode 100644 index 0000000000..bbc4479624 --- /dev/null +++ b/flash_examples/serve/speech_recognition/inference_server.py @@ -0,0 +1,17 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from flash.audio import SpeechRecognition + +model = SpeechRecognition.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/speech_recognition_model.pt") +model.serve() diff --git a/flash_examples/speech_recognition.py b/flash_examples/speech_recognition.py new file mode 100644 index 0000000000..269148c60f --- /dev/null +++ b/flash_examples/speech_recognition.py @@ -0,0 +1,40 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import flash +from flash.audio import SpeechRecognition, SpeechRecognitionData +from flash.core.data.utils import download_data + +# # 1. Create the DataModule +download_data("https://pl-flash-data.s3.amazonaws.com/timit_data.zip", "./data") + +datamodule = SpeechRecognitionData.from_json( + input_fields="file", + target_fields="text", + train_file="data/timit/train.json", + test_file="data/timit/test.json", +) + +# 2. Build the task +model = SpeechRecognition(backbone="facebook/wav2vec2-base-960h") + +# 3. Create the trainer and finetune the model +trainer = flash.Trainer(max_epochs=1, limit_train_batches=1, limit_test_batches=1) +trainer.finetune(model, datamodule=datamodule, strategy='no_freeze') + +# 4. Predict on audio files! +predictions = model.predict(["data/timit/example.wav"]) +print(predictions) + +# 5. Save the model! +trainer.save_checkpoint("speech_recognition_model.pt") diff --git a/requirements/datatype_audio.txt b/requirements/datatype_audio.txt index e608a13b78..570e7c89b8 100644 --- a/requirements/datatype_audio.txt +++ b/requirements/datatype_audio.txt @@ -1,2 +1,5 @@ asteroid>=0.5.1 torchaudio +soundfile>=0.10.2 +transformers>=4.5 +datasets>=1.8 diff --git a/tests/audio/speech_recognition/__init__.py b/tests/audio/speech_recognition/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/audio/speech_recognition/test_data.py b/tests/audio/speech_recognition/test_data.py new file mode 100644 index 0000000000..2b87129210 --- /dev/null +++ b/tests/audio/speech_recognition/test_data.py @@ -0,0 +1,89 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +import json +import os +from pathlib import Path + +import pytest + +import flash +from flash.audio import SpeechRecognitionData +from flash.core.data.data_source import DefaultDataKeys +from tests.helpers.utils import _AUDIO_TESTING + +path = str(Path(flash.ASSETS_ROOT) / "example.wav") +sample = {'file': path, 'text': 'example input.'} + +TEST_CSV_DATA = f"""file,text +{path},example input. +{path},example input. +{path},example input. +{path},example input. +{path},example input. +""" + + +def csv_data(tmpdir): + path = Path(tmpdir) / "data.csv" + path.write_text(TEST_CSV_DATA) + return path + + +def json_data(tmpdir, n_samples=5): + path = Path(tmpdir) / "data.json" + with path.open('w') as f: + f.write('\n'.join([json.dumps(sample) for x in range(n_samples)])) + return path + + +@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") +@pytest.mark.skipif(not _AUDIO_TESTING, reason="speech libraries aren't installed.") +def test_from_csv(tmpdir): + csv_path = csv_data(tmpdir) + dm = SpeechRecognitionData.from_csv("file", "text", train_file=csv_path, batch_size=1, num_workers=0) + batch = next(iter(dm.train_dataloader())) + assert DefaultDataKeys.INPUT in batch + assert DefaultDataKeys.TARGET in batch + + +@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") +@pytest.mark.skipif(not _AUDIO_TESTING, reason="speech libraries aren't installed.") +def test_stage_test_and_valid(tmpdir): + csv_path = csv_data(tmpdir) + dm = SpeechRecognitionData.from_csv( + "file", "text", train_file=csv_path, val_file=csv_path, test_file=csv_path, batch_size=1, num_workers=0 + ) + batch = next(iter(dm.val_dataloader())) + assert DefaultDataKeys.INPUT in batch + assert DefaultDataKeys.TARGET in batch + + batch = next(iter(dm.test_dataloader())) + assert DefaultDataKeys.INPUT in batch + assert DefaultDataKeys.TARGET in batch + + +@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") +@pytest.mark.skipif(not _AUDIO_TESTING, reason="speech libraries aren't installed.") +def test_from_json(tmpdir): + json_path = json_data(tmpdir) + dm = SpeechRecognitionData.from_json("file", "text", train_file=json_path, batch_size=1, num_workers=0) + batch = next(iter(dm.train_dataloader())) + assert DefaultDataKeys.INPUT in batch + assert DefaultDataKeys.TARGET in batch + + +@pytest.mark.skipif(_AUDIO_TESTING, reason="audio libraries are installed.") +def test_audio_module_not_found_error(): + with pytest.raises(ModuleNotFoundError, match="[audio]"): + SpeechRecognitionData.from_json("file", "text", train_file="", batch_size=1, num_workers=0) diff --git a/tests/audio/speech_recognition/test_data_model_integration.py b/tests/audio/speech_recognition/test_data_model_integration.py new file mode 100644 index 0000000000..0c9773022d --- /dev/null +++ b/tests/audio/speech_recognition/test_data_model_integration.py @@ -0,0 +1,83 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +import json +import os +from pathlib import Path + +import pytest +from pytorch_lightning import Trainer + +import flash +from flash.audio import SpeechRecognition, SpeechRecognitionData +from tests.helpers.utils import _AUDIO_TESTING + +TEST_BACKBONE = "patrickvonplaten/wav2vec2_tiny_random_robust" # super small model for testing + +path = str(Path(flash.ASSETS_ROOT) / "example.wav") +sample = {'file': path, 'text': 'example input.'} + +TEST_CSV_DATA = f"""file,text +{path},example input. +{path},example input. +{path},example input. +{path},example input. +{path},example input. +""" + + +def csv_data(tmpdir): + path = Path(tmpdir) / "data.csv" + path.write_text(TEST_CSV_DATA) + return path + + +def json_data(tmpdir, n_samples=5): + path = Path(tmpdir) / "data.json" + with path.open('w') as f: + f.write('\n'.join([json.dumps(sample) for x in range(n_samples)])) + return path + + +@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") +@pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed.") +def test_classification_csv(tmpdir): + csv_path = csv_data(tmpdir) + + data = SpeechRecognitionData.from_csv( + "file", + "text", + train_file=csv_path, + num_workers=0, + batch_size=2, + ) + model = SpeechRecognition(backbone=TEST_BACKBONE) + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) + trainer.fit(model, datamodule=data) + + +@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") +@pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed.") +def test_classification_json(tmpdir): + json_path = json_data(tmpdir) + + data = SpeechRecognitionData.from_json( + "file", + "text", + train_file=json_path, + num_workers=0, + batch_size=2, + ) + model = SpeechRecognition(backbone=TEST_BACKBONE) + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) + trainer.fit(model, datamodule=data) diff --git a/tests/audio/speech_recognition/test_model.py b/tests/audio/speech_recognition/test_model.py new file mode 100644 index 0000000000..69cf6a7aa3 --- /dev/null +++ b/tests/audio/speech_recognition/test_model.py @@ -0,0 +1,94 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import os +import re +from unittest import mock + +import numpy as np +import pytest +import torch + +from flash import Trainer +from flash.audio import SpeechRecognition +from flash.audio.speech_recognition.data import SpeechRecognitionPostprocess, SpeechRecognitionPreprocess +from flash.core.data.data_source import DefaultDataKeys +from tests.helpers.utils import _AUDIO_TESTING, _SERVE_TESTING + +# ======== Mock functions ======== + + +class DummyDataset(torch.utils.data.Dataset): + + def __getitem__(self, index): + return { + DefaultDataKeys.INPUT: np.random.randn(86631), + DefaultDataKeys.TARGET: "some target text", + DefaultDataKeys.METADATA: { + "sampling_rate": 16000 + }, + } + + def __len__(self) -> int: + return 100 + + +# ============================== + +TEST_BACKBONE = "patrickvonplaten/wav2vec2_tiny_random_robust" # super small model for testing + + +@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") +@pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed.") +def test_init_train(tmpdir): + model = SpeechRecognition(backbone=TEST_BACKBONE) + train_dl = torch.utils.data.DataLoader(DummyDataset()) + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) + trainer.fit(model, train_dl) + + +@pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed.") +def test_jit(tmpdir): + sample_input = {"input_values": torch.randn(size=torch.Size([1, 86631])).float()} + path = os.path.join(tmpdir, "test.pt") + + model = SpeechRecognition(backbone=TEST_BACKBONE) + model.eval() + + # Huggingface model only supports `torch.jit.trace` with `strict=False` + model = torch.jit.trace(model, sample_input, strict=False) + + torch.jit.save(model, path) + model = torch.jit.load(path) + + out = model(sample_input)["logits"] + assert isinstance(out, torch.Tensor) + assert out.shape == torch.Size([1, 95, 12]) + + +@pytest.mark.skipif(not _SERVE_TESTING, reason="serve libraries aren't installed.") +@pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed.") +@mock.patch("flash._IS_TESTING", True) +def test_serve(): + model = SpeechRecognition(backbone=TEST_BACKBONE) + # TODO: Currently only servable once a preprocess and postprocess have been attached + model._preprocess = SpeechRecognitionPreprocess() + model._postprocess = SpeechRecognitionPostprocess() + model.eval() + model.serve() + + +@pytest.mark.skipif(_AUDIO_TESTING, reason="audio libraries are installed.") +def test_load_from_checkpoint_dependency_error(): + with pytest.raises(ModuleNotFoundError, match=re.escape("'lightning-flash[audio]'")): + SpeechRecognition.load_from_checkpoint("not_a_real_checkpoint.pt") diff --git a/tests/core/data/test_data_pipeline.py b/tests/core/data/test_data_pipeline.py index 2b593cdd9e..b5ec52dec1 100644 --- a/tests/core/data/test_data_pipeline.py +++ b/tests/core/data/test_data_pipeline.py @@ -691,7 +691,7 @@ def test_step(self, batch, batch_idx): assert len(batch) == 2 assert batch[0].shape == torch.Size([2, 1]) - def predict_step(self, batch, batch_idx, dataloader_idx): + def predict_step(self, batch, batch_idx, dataloader_idx=None): assert batch[0][0] == 'a' assert batch[0][1] == 'a' assert batch[1][0] == 'b' diff --git a/tests/examples/test_scripts.py b/tests/examples/test_scripts.py index 56b729e36e..bc3260b1a8 100644 --- a/tests/examples/test_scripts.py +++ b/tests/examples/test_scripts.py @@ -42,6 +42,10 @@ "audio_classification.py", marks=pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries 
aren't installed") ), + pytest.param( + "speech_recognition.py", + marks=pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed") + ), pytest.param( "image_classification.py", marks=pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed")