diff --git a/CHANGELOG.md b/CHANGELOG.md
index cb7c1cb3b8..1fa497852c 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -30,6 +30,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `AudioClassificationData` and an example for classifying audio spectrograms ([#594](https://github.com/PyTorchLightning/lightning-flash/pull/594))
+- Added a `SpeechRecognition` task for speech to text using Wav2Vec ([#586](https://github.com/PyTorchLightning/lightning-flash/pull/586))
+
### Changed
- Changed how pretrained flag works for loading weights for ImageClassifier task ([#560](https://github.com/PyTorchLightning/lightning-flash/pull/560))
diff --git a/docs/source/api/audio.rst b/docs/source/api/audio.rst
index 79662fea87..706a364372 100644
--- a/docs/source/api/audio.rst
+++ b/docs/source/api/audio.rst
@@ -19,3 +19,25 @@ ______________
~classification.data.AudioClassificationData
~classification.data.AudioClassificationPreprocess
+
+Speech Recognition
+__________________
+
+.. autosummary::
+ :toctree: generated/
+ :nosignatures:
+ :template: classtemplate.rst
+
+ ~speech_recognition.model.SpeechRecognition
+ ~speech_recognition.data.SpeechRecognitionData
+
+ speech_recognition.data.SpeechRecognitionPreprocess
+ speech_recognition.data.SpeechRecognitionBackboneState
+ speech_recognition.data.SpeechRecognitionPostprocess
+ speech_recognition.data.SpeechRecognitionCSVDataSource
+ speech_recognition.data.SpeechRecognitionJSONDataSource
+ speech_recognition.data.BaseSpeechRecognition
+ speech_recognition.data.SpeechRecognitionFileDataSource
+ speech_recognition.data.SpeechRecognitionPathsDataSource
+ speech_recognition.data.SpeechRecognitionDatasetDataSource
+ speech_recognition.data.SpeechRecognitionDeserializer
diff --git a/docs/source/index.rst b/docs/source/index.rst
index d12099d884..8f56b56214 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -45,6 +45,7 @@ Lightning Flash
:caption: Audio
reference/audio_classification
+ reference/speech_recognition
.. toctree::
:maxdepth: 1
diff --git a/docs/source/reference/speech_recognition.rst b/docs/source/reference/speech_recognition.rst
new file mode 100644
index 0000000000..63816cba49
--- /dev/null
+++ b/docs/source/reference/speech_recognition.rst
@@ -0,0 +1,59 @@
+.. _speech_recognition:
+
+##################
+Speech Recognition
+##################
+
+********
+The Task
+********
+
+Speech recognition is the task of transcribing spoken audio into text. We rely on `Wav2Vec <https://arxiv.org/abs/2006.11477>`_ as our backbone, fine-tuned on labeled transcriptions for speech to text.
+
+-----
+
+*******
+Example
+*******
+
+Let's fine-tune the model on our own labeled audio transcription data:
+
+Here's the structure of our CSV file:
+
+.. code-block::
+
+ file,text
+ "/path/to/file_1.wav ... ","what was said in file 1."
+ "/path/to/file_2.wav ... ","what was said in file 2."
+ "/path/to/file_3.wav ... ","what was said in file 3."
+ ...
+
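+With a file in that format, a minimal sketch of creating the :class:`~flash.audio.speech_recognition.data.SpeechRecognitionData` (the CSV path below is illustrative) would be:
+
+.. code-block:: python
+
+    from flash.audio import SpeechRecognitionData
+
+    # "file" is the column of audio file paths, "text" the column of transcriptions
+    datamodule = SpeechRecognitionData.from_csv("file", "text", train_file="path/to/train.csv")
+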
+Once we've downloaded the data using :func:`~flash.core.data.utils.download_data`, we create the :class:`~flash.audio.speech_recognition.data.SpeechRecognitionData`.
+We select a pre-trained Wav2Vec backbone to use for our :class:`~flash.audio.speech_recognition.model.SpeechRecognition` and finetune on a subset of the `TIMIT corpus <https://catalog.ldc.upenn.edu/LDC93S1>`__.
+The backbone can be any Wav2Vec model from `HuggingFace transformers <https://huggingface.co/models>`__.
+Next, we use the trained :class:`~flash.audio.speech_recognition.model.SpeechRecognition` for inference and save the model.
+Here's the full example:
+
+.. literalinclude:: ../../../flash_examples/speech_recognition.py
+ :language: python
+ :lines: 14-
+
+------
+
+*******
+Serving
+*******
+
+The :class:`~flash.audio.speech_recognition.model.SpeechRecognition` task is servable.
+This means you can call ``.serve`` to serve your :class:`~flash.core.model.Task`.
+Here's an example:
+
+.. literalinclude:: ../../../flash_examples/serve/speech_recognition/inference_server.py
+ :language: python
+ :lines: 14-
+
+You can now perform inference from your client like this:
+
+.. literalinclude:: ../../../flash_examples/serve/speech_recognition/client.py
+ :language: python
+ :lines: 14-
diff --git a/flash/assets/example.wav b/flash/assets/example.wav
new file mode 100644
index 0000000000..8a1d66a36b
Binary files /dev/null and b/flash/assets/example.wav differ
diff --git a/flash/audio/__init__.py b/flash/audio/__init__.py
index 40eeaae124..b90bc6d06e 100644
--- a/flash/audio/__init__.py
+++ b/flash/audio/__init__.py
@@ -1 +1,2 @@
from flash.audio.classification import AudioClassificationData, AudioClassificationPreprocess # noqa: F401
+from flash.audio.speech_recognition import SpeechRecognition, SpeechRecognitionData # noqa: F401
diff --git a/flash/audio/speech_recognition/__init__.py b/flash/audio/speech_recognition/__init__.py
new file mode 100644
index 0000000000..00f1b6fa0c
--- /dev/null
+++ b/flash/audio/speech_recognition/__init__.py
@@ -0,0 +1,15 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from flash.audio.speech_recognition.data import SpeechRecognitionData # noqa: F401
+from flash.audio.speech_recognition.model import SpeechRecognition # noqa: F401
diff --git a/flash/audio/speech_recognition/backbone.py b/flash/audio/speech_recognition/backbone.py
new file mode 100644
index 0000000000..425ef2eb00
--- /dev/null
+++ b/flash/audio/speech_recognition/backbone.py
@@ -0,0 +1,30 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from functools import partial
+
+from flash.core.registry import FlashRegistry
+from flash.core.utilities.imports import _AUDIO_AVAILABLE
+
+SPEECH_RECOGNITION_BACKBONES = FlashRegistry("backbones")
+
+if _AUDIO_AVAILABLE:
+ from transformers import Wav2Vec2ForCTC
+
+ WAV2VEC_MODELS = ["facebook/wav2vec2-base-960h", "facebook/wav2vec2-large-960h-lv60"]
+
+ for model_name in WAV2VEC_MODELS:
+ SPEECH_RECOGNITION_BACKBONES(
+ fn=partial(Wav2Vec2ForCTC.from_pretrained, model_name),
+ name=model_name,
+ )
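+
+# Illustrative usage (not executed on import): a registered backbone can be looked up by name and
+# instantiated, e.g. ``SPEECH_RECOGNITION_BACKBONES.get("facebook/wav2vec2-base-960h")()``; this is
+# how the SpeechRecognition task resolves its model from the ``backbone`` string.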
diff --git a/flash/audio/speech_recognition/collate.py b/flash/audio/speech_recognition/collate.py
new file mode 100644
index 0000000000..9ee53a4686
--- /dev/null
+++ b/flash/audio/speech_recognition/collate.py
@@ -0,0 +1,101 @@
+# Copyright 2020 The PyTorch Lightning team and The HuggingFace Team. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Union
+
+import torch
+
+from flash.core.data.data_source import DefaultDataKeys
+from flash.core.utilities.imports import _AUDIO_AVAILABLE
+
+if _AUDIO_AVAILABLE:
+ from transformers import Wav2Vec2Processor
+else:
+ Wav2Vec2Processor = object
+
+
+@dataclass
+class DataCollatorCTCWithPadding:
+ """
+ Data collator that will dynamically pad the inputs received.
+ Args:
+ processor (:class:`~transformers.Wav2Vec2Processor`)
+            The processor used for processing the data.
+ padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`,
+ `optional`, defaults to :obj:`True`):
+ Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
+ among:
+ * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
+              sequence is provided).
+ * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
+ maximum acceptable input length for the model if that argument is not provided.
+            * :obj:`False` or :obj:`'do_not_pad'`: No padding (i.e., can output a batch with sequences of
+ different lengths).
+ max_length (:obj:`int`, `optional`):
+ Maximum length of the ``input_values`` of the returned list and optionally padding length (see above).
+ max_length_labels (:obj:`int`, `optional`):
+ Maximum length of the ``labels`` returned list and optionally padding length (see above).
+ pad_to_multiple_of (:obj:`int`, `optional`):
+ If set will pad the sequence to a multiple of the provided value.
+ This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
+ 7.5 (Volta).
+ """
+
+ processor: Wav2Vec2Processor
+ padding: Union[bool, str] = True
+ max_length: Optional[int] = None
+ max_length_labels: Optional[int] = None
+ pad_to_multiple_of: Optional[int] = None
+ pad_to_multiple_of_labels: Optional[int] = None
+
+ def __call__(self, samples: List[Dict[str, Any]], metadata: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
+ inputs = [sample[DefaultDataKeys.INPUT] for sample in samples]
+ sampling_rates = [sample["sampling_rate"] for sample in metadata]
+
+ assert (
+ len(set(sampling_rates)) == 1
+ ), f"Make sure all inputs have the same sampling rate of {self.processor.feature_extractor.sampling_rate}."
+
+ inputs = self.processor(inputs, sampling_rate=sampling_rates[0]).input_values
+
+ # split inputs and labels since they have to be of different lengths and need
+ # different padding methods
+ input_features = [{"input_values": input} for input in inputs]
+
+ batch = self.processor.pad(
+ input_features,
+ padding=self.padding,
+ max_length=self.max_length,
+ pad_to_multiple_of=self.pad_to_multiple_of,
+ return_tensors="pt",
+ )
+
+ labels = [sample.get(DefaultDataKeys.TARGET, None) for sample in samples]
+ # check to ensure labels exist to collate
+ if None not in labels:
+ with self.processor.as_target_processor():
+ label_features = self.processor(labels).input_ids
+ label_features = [{"input_ids": feature} for feature in label_features]
+ labels_batch = self.processor.pad(
+ label_features,
+ padding=self.padding,
+ max_length=self.max_length_labels,
+ pad_to_multiple_of=self.pad_to_multiple_of_labels,
+ return_tensors="pt",
+ )
+
+ # replace padding with -100 to ignore loss correctly
+ batch["labels"] = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
+
+ return batch
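+
+
+# Illustrative usage (assumes the audio extras are installed); in practice the task attaches this
+# collator via a ``CollateFn`` state rather than calling it directly:
+#
+#     processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
+#     collator = DataCollatorCTCWithPadding(processor=processor)
+#     batch = collator(samples, metadata)  # -> {"input_values": ..., "labels": ...} tensors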
diff --git a/flash/audio/speech_recognition/data.py b/flash/audio/speech_recognition/data.py
new file mode 100644
index 0000000000..97dfde0f26
--- /dev/null
+++ b/flash/audio/speech_recognition/data.py
@@ -0,0 +1,225 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import base64
+import io
+import os.path
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union
+
+import torch
+from torch.utils.data import Dataset
+
+import flash
+from flash.core.data.data_module import DataModule
+from flash.core.data.data_source import (
+ DatasetDataSource,
+ DataSource,
+ DefaultDataKeys,
+ DefaultDataSources,
+ PathsDataSource,
+)
+from flash.core.data.process import Deserializer, Postprocess, Preprocess
+from flash.core.data.properties import ProcessState
+from flash.core.utilities.imports import _AUDIO_AVAILABLE, requires_extras
+
+if _AUDIO_AVAILABLE:
+ import soundfile as sf
+ from datasets import Dataset as HFDataset
+ from datasets import load_dataset
+ from transformers import Wav2Vec2CTCTokenizer
+else:
+ HFDataset = object
+
+
+class SpeechRecognitionDeserializer(Deserializer):
+
+ def deserialize(self, sample: Any) -> Dict:
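+        # append "===" so base64 decoding succeeds even if the client stripped the padding characters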
+ encoded_with_padding = (sample + "===").encode("ascii")
+ audio = base64.b64decode(encoded_with_padding)
+ buffer = io.BytesIO(audio)
+ data, sampling_rate = sf.read(buffer)
+ return {
+ DefaultDataKeys.INPUT: data,
+ DefaultDataKeys.METADATA: {
+ "sampling_rate": sampling_rate
+ },
+ }
+
+ @property
+ def example_input(self) -> str:
+ with (Path(flash.ASSETS_ROOT) / "example.wav").open("rb") as f:
+ return base64.b64encode(f.read()).decode("UTF-8")
+
+
+class BaseSpeechRecognition:
+
+ def _load_sample(self, sample: Dict[str, Any]) -> Any:
+ path = sample[DefaultDataKeys.INPUT]
+        if (
+            not os.path.isabs(path) and DefaultDataKeys.METADATA in sample
+            and "root" in sample[DefaultDataKeys.METADATA]
+        ):
+            path = os.path.join(sample[DefaultDataKeys.METADATA]["root"], path)
+ speech_array, sampling_rate = sf.read(path)
+ sample[DefaultDataKeys.INPUT] = speech_array
+ sample[DefaultDataKeys.METADATA] = {"sampling_rate": sampling_rate}
+ return sample
+
+
+class SpeechRecognitionFileDataSource(DataSource, BaseSpeechRecognition):
+
+ def __init__(self, filetype: Optional[str] = None):
+ super().__init__()
+ self.filetype = filetype
+
+ def load_data(
+ self,
+ data: Tuple[str, Union[str, List[str]], Union[str, List[str]]],
+ dataset: Optional[Any] = None,
+ ) -> Union[Sequence[Mapping[str, Any]]]:
+ if self.filetype == 'json':
+ file, input_key, target_key, field = data
+ else:
+ file, input_key, target_key = data
+ stage = self.running_stage.value
+ if self.filetype == 'json' and field is not None:
+ dataset_dict = load_dataset(self.filetype, data_files={stage: str(file)}, field=field)
+ else:
+ dataset_dict = load_dataset(self.filetype, data_files={stage: str(file)})
+
+ dataset = dataset_dict[stage]
+ meta = {"root": os.path.dirname(file)}
+ return [{
+ DefaultDataKeys.INPUT: input_file,
+ DefaultDataKeys.TARGET: target,
+ DefaultDataKeys.METADATA: meta,
+ } for input_file, target in zip(dataset[input_key], dataset[target_key])]
+
+ def load_sample(self, sample: Dict[str, Any], dataset: Any = None) -> Any:
+ return self._load_sample(sample)
+
+
+class SpeechRecognitionCSVDataSource(SpeechRecognitionFileDataSource):
+
+ def __init__(self):
+ super().__init__(filetype='csv')
+
+
+class SpeechRecognitionJSONDataSource(SpeechRecognitionFileDataSource):
+
+ def __init__(self):
+ super().__init__(filetype='json')
+
+
+class SpeechRecognitionDatasetDataSource(DatasetDataSource, BaseSpeechRecognition):
+
+ def load_data(self, data: Dataset, dataset: Optional[Any] = None) -> Union[Sequence[Mapping[str, Any]]]:
+ if isinstance(data, HFDataset):
+ data = list(zip(data["file"], data["text"]))
+ return super().load_data(data, dataset)
+
+
+class SpeechRecognitionPathsDataSource(PathsDataSource, BaseSpeechRecognition):
+
+ def __init__(self):
+ super().__init__(("wav", "ogg", "flac", "mat"))
+
+ def load_sample(self, sample: Dict[str, Any], dataset: Any = None) -> Any:
+ return self._load_sample(sample)
+
+
+class SpeechRecognitionPreprocess(Preprocess):
+
+ @requires_extras("audio")
+ def __init__(
+ self,
+ train_transform: Optional[Dict[str, Callable]] = None,
+ val_transform: Optional[Dict[str, Callable]] = None,
+ test_transform: Optional[Dict[str, Callable]] = None,
+ predict_transform: Optional[Dict[str, Callable]] = None,
+ ):
+ super().__init__(
+ train_transform=train_transform,
+ val_transform=val_transform,
+ test_transform=test_transform,
+ predict_transform=predict_transform,
+ data_sources={
+ DefaultDataSources.CSV: SpeechRecognitionCSVDataSource(),
+ DefaultDataSources.JSON: SpeechRecognitionJSONDataSource(),
+ DefaultDataSources.FILES: SpeechRecognitionPathsDataSource(),
+ DefaultDataSources.DATASET: SpeechRecognitionDatasetDataSource(),
+ },
+ default_data_source=DefaultDataSources.FILES,
+ deserializer=SpeechRecognitionDeserializer(),
+ )
+
+ def get_state_dict(self) -> Dict[str, Any]:
+ return self.transforms
+
+ @classmethod
+ def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool = False):
+ return cls(**state_dict)
+
+
+@dataclass(unsafe_hash=True, frozen=True)
+class SpeechRecognitionBackboneState(ProcessState):
+ """The ``SpeechRecognitionBackboneState`` stores the backbone in use by the
+ :class:`~flash.audio.speech_recognition.data.SpeechRecognitionPostprocess`
+ """
+
+ backbone: str
+
+
+class SpeechRecognitionPostprocess(Postprocess):
+
+ @requires_extras("audio")
+ def __init__(self):
+ super().__init__()
+
+ self._backbone = None
+ self._tokenizer = None
+
+ @property
+ def backbone(self):
+ backbone_state = self.get_state(SpeechRecognitionBackboneState)
+ if backbone_state is not None:
+ return backbone_state.backbone
+
+ @property
+ def tokenizer(self):
+ if self.backbone is not None and self.backbone != self._backbone:
+ self._tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(self.backbone)
+ self._backbone = self.backbone
+ return self._tokenizer
+
+ def per_batch_transform(self, batch: Any) -> Any:
+ # converts logits into greedy transcription
+ pred_ids = torch.argmax(batch.logits, dim=-1)
+ transcriptions = self.tokenizer.batch_decode(pred_ids)
+ return transcriptions
+
+ def __getstate__(self): # TODO: Find out why this is being pickled
+ state = self.__dict__.copy()
+ state.pop("_tokenizer")
+ return state
+
+ def __setstate__(self, state):
+ self.__dict__.update(state)
+ self._tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(self.backbone)
+
+
+class SpeechRecognitionData(DataModule):
+    """Data Module for speech recognition tasks"""
+
+ preprocess_cls = SpeechRecognitionPreprocess
+ postprocess_cls = SpeechRecognitionPostprocess
diff --git a/flash/audio/speech_recognition/model.py b/flash/audio/speech_recognition/model.py
new file mode 100644
index 0000000000..588f4f89b2
--- /dev/null
+++ b/flash/audio/speech_recognition/model.py
@@ -0,0 +1,78 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import warnings
+from typing import Any, Callable, Dict, Mapping, Optional, Type, Union
+
+import torch
+import torch.nn as nn
+
+from flash import Task
+from flash.audio.speech_recognition.backbone import SPEECH_RECOGNITION_BACKBONES
+from flash.audio.speech_recognition.collate import DataCollatorCTCWithPadding
+from flash.audio.speech_recognition.data import SpeechRecognitionBackboneState
+from flash.core.data.process import Serializer
+from flash.core.data.states import CollateFn
+from flash.core.registry import FlashRegistry
+from flash.core.utilities.imports import _AUDIO_AVAILABLE
+
+if _AUDIO_AVAILABLE:
+ from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
+
+
+class SpeechRecognition(Task):
+
+ backbones: FlashRegistry = SPEECH_RECOGNITION_BACKBONES
+
+ required_extras = "audio"
+
+ def __init__(
+ self,
+ backbone: str = "facebook/wav2vec2-base-960h",
+ loss_fn: Optional[Callable] = None,
+ optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam,
+ learning_rate: float = 1e-5,
+ serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None,
+ ):
+ os.environ["TOKENIZERS_PARALLELISM"] = "TRUE"
+ # disable HF thousand warnings
+ warnings.simplefilter("ignore")
+ # set os environ variable for multiprocesses
+ os.environ["PYTHONWARNINGS"] = "ignore"
+
+        model = (
+            self.backbones.get(backbone)() if backbone in self.backbones else Wav2Vec2ForCTC.from_pretrained(backbone)
+        )
+ super().__init__(
+ model=model,
+ loss_fn=loss_fn,
+ optimizer=optimizer,
+ learning_rate=learning_rate,
+ serializer=serializer,
+ )
+
+ self.save_hyperparameters()
+
+ self.set_state(SpeechRecognitionBackboneState(backbone))
+ self.set_state(CollateFn(DataCollatorCTCWithPadding(Wav2Vec2Processor.from_pretrained(backbone))))
+
+ def forward(self, batch: Dict[str, torch.Tensor]):
+ return self.model(batch["input_values"])
+
+ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
+ return self(batch)
+
+ def step(self, batch: Any, batch_idx: int, metrics: nn.ModuleDict) -> Any:
+ out = self.model(batch["input_values"], labels=batch["labels"])
+ out["logs"] = {'loss': out.loss}
+ return out
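+
+
+# Illustrative usage (any Wav2Vec2 CTC checkpoint name from the HuggingFace Hub should also work in
+# place of a registered backbone):
+#
+#     model = SpeechRecognition(backbone="facebook/wav2vec2-base-960h")
+#     predictions = model.predict(["path/to/audio.wav"])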
diff --git a/flash/core/data/batch.py b/flash/core/data/batch.py
index 51d28d2a22..e7e9a30635 100644
--- a/flash/core/data/batch.py
+++ b/flash/core/data/batch.py
@@ -229,7 +229,10 @@ def forward(self, samples: Sequence[Any]) -> Any:
with self._collate_context:
samples, metadata = self._extract_metadata(samples)
- samples = self.collate_fn(samples)
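+            # some collate functions (such as the speech recognition CTC collator) also consume the
+            # extracted metadata; fall back to the single-argument form for collate functions that do not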
+ try:
+ samples = self.collate_fn(samples, metadata)
+ except TypeError:
+ samples = self.collate_fn(samples)
if metadata and isinstance(samples, dict):
samples[DefaultDataKeys.METADATA] = metadata
self.callback.on_collate(samples, self.stage)
diff --git a/flash/core/data/process.py b/flash/core/data/process.py
index 7020e32d36..a1d6e56085 100644
--- a/flash/core/data/process.py
+++ b/flash/core/data/process.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import inspect
import os
from abc import ABC, abstractclassmethod, abstractmethod
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence
@@ -24,7 +25,7 @@
import flash
from flash.core.data.batch import default_uncollate
from flash.core.data.callback import FlashCallback
-from flash.core.data.data_source import DatasetDataSource, DataSource, DefaultDataSources
+from flash.core.data.data_source import DatasetDataSource, DataSource, DefaultDataKeys, DefaultDataSources
from flash.core.data.properties import Properties
from flash.core.data.states import CollateFn
from flash.core.data.utils import _PREPROCESS_FUNCS, _STAGES_PREFIX, convert_to_modules, CurrentRunningStageFuncContext
@@ -360,18 +361,24 @@ def per_batch_transform(self, batch: Any) -> Any:
"""
return self.current_transform(batch)
- def collate(self, samples: Sequence) -> Any:
+ def collate(self, samples: Sequence, metadata=None) -> Any:
""" Transform to convert a sequence of samples to a collated batch. """
+ current_transform = self.current_transform
+ if current_transform is self._identity:
+ current_transform = self._default_collate
# the model can provide a custom ``collate_fn``.
collate_fn = self.get_state(CollateFn)
if collate_fn is not None:
- return collate_fn.collate_fn(samples)
-
- current_transform = self.current_transform
- if current_transform is self._identity:
- return self._default_collate(samples)
- return self.current_transform(samples)
+ collate_fn = collate_fn.collate_fn
+ else:
+ collate_fn = current_transform
+
+        # pass the extracted metadata through when the collate function declares a ``metadata`` argument
+ parameters = inspect.signature(collate_fn).parameters
+ if len(parameters) > 1 and DefaultDataKeys.METADATA in parameters:
+ return collate_fn(samples, metadata)
+ return collate_fn(samples)
def per_sample_transform_on_device(self, sample: Any) -> Any:
"""Transforms to apply to the data before the collation (per-sample basis).
diff --git a/flash/core/utilities/imports.py b/flash/core/utilities/imports.py
index 80c6b6188c..d1ba3388b6 100644
--- a/flash/core/utilities/imports.py
+++ b/flash/core/utilities/imports.py
@@ -87,15 +87,24 @@ def _compare_version(package: str, op, version) -> bool:
_OPEN3D_AVAILABLE = _module_available("open3d")
_ASTEROID_AVAILABLE = _module_available("asteroid")
_SEGMENTATION_MODELS_AVAILABLE = _module_available("segmentation_models_pytorch")
+_SOUNDFILE_AVAILABLE = _module_available("soundfile")
_TORCH_SCATTER_AVAILABLE = _module_available("torch_scatter")
_TORCH_SPARSE_AVAILABLE = _module_available("torch_sparse")
_TORCH_GEOMETRIC_AVAILABLE = _module_available("torch_geometric")
_TORCHAUDIO_AVAILABLE = _module_available("torchaudio")
+_ROUGE_SCORE_AVAILABLE = _module_available("rouge_score")
+_SENTENCEPIECE_AVAILABLE = _module_available("sentencepiece")
+_DATASETS_AVAILABLE = _module_available("datasets")
if Version:
_TORCHVISION_GREATER_EQUAL_0_9 = _compare_version("torchvision", operator.ge, "0.9.0")
-_TEXT_AVAILABLE = _TRANSFORMERS_AVAILABLE
+_TEXT_AVAILABLE = all([
+ _TRANSFORMERS_AVAILABLE,
+ _ROUGE_SCORE_AVAILABLE,
+ _SENTENCEPIECE_AVAILABLE,
+ _DATASETS_AVAILABLE,
+])
_TABULAR_AVAILABLE = _TABNET_AVAILABLE and _PANDAS_AVAILABLE
_VIDEO_AVAILABLE = _PYTORCHVIDEO_AVAILABLE
_IMAGE_AVAILABLE = all([
@@ -108,10 +117,7 @@ def _compare_version(package: str, op, version) -> bool:
])
_SERVE_AVAILABLE = _FASTAPI_AVAILABLE and _PYDANTIC_AVAILABLE and _CYTOOLZ_AVAILABLE and _UVICORN_AVAILABLE
_POINTCLOUD_AVAILABLE = _OPEN3D_AVAILABLE
-_AUDIO_AVAILABLE = all([
- _ASTEROID_AVAILABLE,
- _TORCHAUDIO_AVAILABLE,
-])
+_AUDIO_AVAILABLE = all([_ASTEROID_AVAILABLE, _TORCHAUDIO_AVAILABLE, _SOUNDFILE_AVAILABLE, _TRANSFORMERS_AVAILABLE])
_GRAPH_AVAILABLE = _TORCH_SCATTER_AVAILABLE and _TORCH_SPARSE_AVAILABLE and _TORCH_GEOMETRIC_AVAILABLE
_EXTRAS_AVAILABLE = {
diff --git a/flash_examples/serve/speech_recognition/client.py b/flash_examples/serve/speech_recognition/client.py
new file mode 100644
index 0000000000..c855a37204
--- /dev/null
+++ b/flash_examples/serve/speech_recognition/client.py
@@ -0,0 +1,27 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import base64
+from pathlib import Path
+
+import requests
+
+import flash
+
+with (Path(flash.ASSETS_ROOT) / "example.wav").open("rb") as f:
+ audio_str = base64.b64encode(f.read()).decode("UTF-8")
+
+body = {"session": "UUID", "payload": {"inputs": {"data": audio_str}}}
+resp = requests.post("http://127.0.0.1:8000/predict", json=body)
+
+print(resp.json())
diff --git a/flash_examples/serve/speech_recognition/inference_server.py b/flash_examples/serve/speech_recognition/inference_server.py
new file mode 100644
index 0000000000..bbc4479624
--- /dev/null
+++ b/flash_examples/serve/speech_recognition/inference_server.py
@@ -0,0 +1,17 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from flash.audio import SpeechRecognition
+
+model = SpeechRecognition.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/speech_recognition_model.pt")
+model.serve()
diff --git a/flash_examples/speech_recognition.py b/flash_examples/speech_recognition.py
new file mode 100644
index 0000000000..269148c60f
--- /dev/null
+++ b/flash_examples/speech_recognition.py
@@ -0,0 +1,40 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import flash
+from flash.audio import SpeechRecognition, SpeechRecognitionData
+from flash.core.data.utils import download_data
+
+# 1. Create the DataModule
+download_data("https://pl-flash-data.s3.amazonaws.com/timit_data.zip", "./data")
+
+datamodule = SpeechRecognitionData.from_json(
+ input_fields="file",
+ target_fields="text",
+ train_file="data/timit/train.json",
+ test_file="data/timit/test.json",
+)
+
+# 2. Build the task
+model = SpeechRecognition(backbone="facebook/wav2vec2-base-960h")
+
+# 3. Create the trainer and finetune the model
+trainer = flash.Trainer(max_epochs=1, limit_train_batches=1, limit_test_batches=1)
+trainer.finetune(model, datamodule=datamodule, strategy='no_freeze')
+
+# 4. Predict on audio files!
+predictions = model.predict(["data/timit/example.wav"])
+print(predictions)
+
+# 5. Save the model!
+trainer.save_checkpoint("speech_recognition_model.pt")
diff --git a/requirements/datatype_audio.txt b/requirements/datatype_audio.txt
index e608a13b78..570e7c89b8 100644
--- a/requirements/datatype_audio.txt
+++ b/requirements/datatype_audio.txt
@@ -1,2 +1,5 @@
asteroid>=0.5.1
torchaudio
+soundfile>=0.10.2
+transformers>=4.5
+datasets>=1.8
diff --git a/tests/audio/speech_recognition/__init__.py b/tests/audio/speech_recognition/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/audio/speech_recognition/test_data.py b/tests/audio/speech_recognition/test_data.py
new file mode 100644
index 0000000000..2b87129210
--- /dev/null
+++ b/tests/audio/speech_recognition/test_data.py
@@ -0,0 +1,89 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import json
+import os
+import re
+from pathlib import Path
+
+import pytest
+
+import flash
+from flash.audio import SpeechRecognitionData
+from flash.core.data.data_source import DefaultDataKeys
+from tests.helpers.utils import _AUDIO_TESTING
+
+path = str(Path(flash.ASSETS_ROOT) / "example.wav")
+sample = {'file': path, 'text': 'example input.'}
+
+TEST_CSV_DATA = f"""file,text
+{path},example input.
+{path},example input.
+{path},example input.
+{path},example input.
+{path},example input.
+"""
+
+
+def csv_data(tmpdir):
+ path = Path(tmpdir) / "data.csv"
+ path.write_text(TEST_CSV_DATA)
+ return path
+
+
+def json_data(tmpdir, n_samples=5):
+ path = Path(tmpdir) / "data.json"
+ with path.open('w') as f:
+ f.write('\n'.join([json.dumps(sample) for x in range(n_samples)]))
+ return path
+
+
+@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows")
+@pytest.mark.skipif(not _AUDIO_TESTING, reason="speech libraries aren't installed.")
+def test_from_csv(tmpdir):
+ csv_path = csv_data(tmpdir)
+ dm = SpeechRecognitionData.from_csv("file", "text", train_file=csv_path, batch_size=1, num_workers=0)
+ batch = next(iter(dm.train_dataloader()))
+ assert DefaultDataKeys.INPUT in batch
+ assert DefaultDataKeys.TARGET in batch
+
+
+@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows")
+@pytest.mark.skipif(not _AUDIO_TESTING, reason="speech libraries aren't installed.")
+def test_stage_test_and_valid(tmpdir):
+ csv_path = csv_data(tmpdir)
+ dm = SpeechRecognitionData.from_csv(
+ "file", "text", train_file=csv_path, val_file=csv_path, test_file=csv_path, batch_size=1, num_workers=0
+ )
+ batch = next(iter(dm.val_dataloader()))
+ assert DefaultDataKeys.INPUT in batch
+ assert DefaultDataKeys.TARGET in batch
+
+ batch = next(iter(dm.test_dataloader()))
+ assert DefaultDataKeys.INPUT in batch
+ assert DefaultDataKeys.TARGET in batch
+
+
+@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows")
+@pytest.mark.skipif(not _AUDIO_TESTING, reason="speech libraries aren't installed.")
+def test_from_json(tmpdir):
+ json_path = json_data(tmpdir)
+ dm = SpeechRecognitionData.from_json("file", "text", train_file=json_path, batch_size=1, num_workers=0)
+ batch = next(iter(dm.train_dataloader()))
+ assert DefaultDataKeys.INPUT in batch
+ assert DefaultDataKeys.TARGET in batch
+
+
+@pytest.mark.skipif(_AUDIO_TESTING, reason="audio libraries are installed.")
+def test_audio_module_not_found_error():
+    with pytest.raises(ModuleNotFoundError, match=re.escape("[audio]")):
+ SpeechRecognitionData.from_json("file", "text", train_file="", batch_size=1, num_workers=0)
diff --git a/tests/audio/speech_recognition/test_data_model_integration.py b/tests/audio/speech_recognition/test_data_model_integration.py
new file mode 100644
index 0000000000..0c9773022d
--- /dev/null
+++ b/tests/audio/speech_recognition/test_data_model_integration.py
@@ -0,0 +1,83 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import json
+import os
+from pathlib import Path
+
+import pytest
+from pytorch_lightning import Trainer
+
+import flash
+from flash.audio import SpeechRecognition, SpeechRecognitionData
+from tests.helpers.utils import _AUDIO_TESTING
+
+TEST_BACKBONE = "patrickvonplaten/wav2vec2_tiny_random_robust" # super small model for testing
+
+path = str(Path(flash.ASSETS_ROOT) / "example.wav")
+sample = {'file': path, 'text': 'example input.'}
+
+TEST_CSV_DATA = f"""file,text
+{path},example input.
+{path},example input.
+{path},example input.
+{path},example input.
+{path},example input.
+"""
+
+
+def csv_data(tmpdir):
+ path = Path(tmpdir) / "data.csv"
+ path.write_text(TEST_CSV_DATA)
+ return path
+
+
+def json_data(tmpdir, n_samples=5):
+ path = Path(tmpdir) / "data.json"
+ with path.open('w') as f:
+ f.write('\n'.join([json.dumps(sample) for x in range(n_samples)]))
+ return path
+
+
+@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows")
+@pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed.")
+def test_classification_csv(tmpdir):
+ csv_path = csv_data(tmpdir)
+
+ data = SpeechRecognitionData.from_csv(
+ "file",
+ "text",
+ train_file=csv_path,
+ num_workers=0,
+ batch_size=2,
+ )
+ model = SpeechRecognition(backbone=TEST_BACKBONE)
+ trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
+ trainer.fit(model, datamodule=data)
+
+
+@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows")
+@pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed.")
+def test_classification_json(tmpdir):
+ json_path = json_data(tmpdir)
+
+ data = SpeechRecognitionData.from_json(
+ "file",
+ "text",
+ train_file=json_path,
+ num_workers=0,
+ batch_size=2,
+ )
+ model = SpeechRecognition(backbone=TEST_BACKBONE)
+ trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
+ trainer.fit(model, datamodule=data)
diff --git a/tests/audio/speech_recognition/test_model.py b/tests/audio/speech_recognition/test_model.py
new file mode 100644
index 0000000000..69cf6a7aa3
--- /dev/null
+++ b/tests/audio/speech_recognition/test_model.py
@@ -0,0 +1,94 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import re
+from unittest import mock
+
+import numpy as np
+import pytest
+import torch
+
+from flash import Trainer
+from flash.audio import SpeechRecognition
+from flash.audio.speech_recognition.data import SpeechRecognitionPostprocess, SpeechRecognitionPreprocess
+from flash.core.data.data_source import DefaultDataKeys
+from tests.helpers.utils import _AUDIO_TESTING, _SERVE_TESTING
+
+# ======== Mock functions ========
+
+
+class DummyDataset(torch.utils.data.Dataset):
+
+ def __getitem__(self, index):
+ return {
+ DefaultDataKeys.INPUT: np.random.randn(86631),
+ DefaultDataKeys.TARGET: "some target text",
+ DefaultDataKeys.METADATA: {
+ "sampling_rate": 16000
+ },
+ }
+
+ def __len__(self) -> int:
+ return 100
+
+
+# ==============================
+
+TEST_BACKBONE = "patrickvonplaten/wav2vec2_tiny_random_robust" # super small model for testing
+
+
+@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows")
+@pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed.")
+def test_init_train(tmpdir):
+ model = SpeechRecognition(backbone=TEST_BACKBONE)
+ train_dl = torch.utils.data.DataLoader(DummyDataset())
+ trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
+ trainer.fit(model, train_dl)
+
+
+@pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed.")
+def test_jit(tmpdir):
+ sample_input = {"input_values": torch.randn(size=torch.Size([1, 86631])).float()}
+ path = os.path.join(tmpdir, "test.pt")
+
+ model = SpeechRecognition(backbone=TEST_BACKBONE)
+ model.eval()
+
+ # Huggingface model only supports `torch.jit.trace` with `strict=False`
+ model = torch.jit.trace(model, sample_input, strict=False)
+
+ torch.jit.save(model, path)
+ model = torch.jit.load(path)
+
+ out = model(sample_input)["logits"]
+ assert isinstance(out, torch.Tensor)
+ assert out.shape == torch.Size([1, 95, 12])
+
+
+@pytest.mark.skipif(not _SERVE_TESTING, reason="serve libraries aren't installed.")
+@pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed.")
+@mock.patch("flash._IS_TESTING", True)
+def test_serve():
+ model = SpeechRecognition(backbone=TEST_BACKBONE)
+ # TODO: Currently only servable once a preprocess and postprocess have been attached
+ model._preprocess = SpeechRecognitionPreprocess()
+ model._postprocess = SpeechRecognitionPostprocess()
+ model.eval()
+ model.serve()
+
+
+@pytest.mark.skipif(_AUDIO_TESTING, reason="audio libraries are installed.")
+def test_load_from_checkpoint_dependency_error():
+ with pytest.raises(ModuleNotFoundError, match=re.escape("'lightning-flash[audio]'")):
+ SpeechRecognition.load_from_checkpoint("not_a_real_checkpoint.pt")
diff --git a/tests/core/data/test_data_pipeline.py b/tests/core/data/test_data_pipeline.py
index 2b593cdd9e..b5ec52dec1 100644
--- a/tests/core/data/test_data_pipeline.py
+++ b/tests/core/data/test_data_pipeline.py
@@ -691,7 +691,7 @@ def test_step(self, batch, batch_idx):
assert len(batch) == 2
assert batch[0].shape == torch.Size([2, 1])
- def predict_step(self, batch, batch_idx, dataloader_idx):
+ def predict_step(self, batch, batch_idx, dataloader_idx=None):
assert batch[0][0] == 'a'
assert batch[0][1] == 'a'
assert batch[1][0] == 'b'
diff --git a/tests/examples/test_scripts.py b/tests/examples/test_scripts.py
index 56b729e36e..bc3260b1a8 100644
--- a/tests/examples/test_scripts.py
+++ b/tests/examples/test_scripts.py
@@ -42,6 +42,10 @@
"audio_classification.py",
marks=pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed")
),
+ pytest.param(
+ "speech_recognition.py",
+ marks=pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed")
+ ),
pytest.param(
"image_classification.py",
marks=pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed")