diff --git a/CHANGELOG.md b/CHANGELOG.md
index f8cc8447d8..981654fc32 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -28,6 +28,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Updated `FlashFinetuning` callback to use separate hooks that lets users use the freezing logic provided out-of-the-box from flash, route FlashFinetuning through a registry. ([#830](https://github.com/PyTorchLightning/lightning-flash/pull/830))
 
+- Changed the `SpeechRecognition` task to use `AutoModelForCTC` rather than just `Wav2Vec2ForCTC` ([#874](https://github.com/PyTorchLightning/lightning-flash/pull/874))
+
 ### Deprecated
 
 - Deprecated `flash.core.data.process.Serializer` in favour of `flash.core.data.io.output.Output` ([#927](https://github.com/PyTorchLightning/lightning-flash/pull/927))
diff --git a/flash/audio/speech_recognition/backbone.py b/flash/audio/speech_recognition/backbone.py
index 013090e3a2..3c4298e1e5 100644
--- a/flash/audio/speech_recognition/backbone.py
+++ b/flash/audio/speech_recognition/backbone.py
@@ -20,7 +20,7 @@
 SPEECH_RECOGNITION_BACKBONES = FlashRegistry("backbones")
 
 if _AUDIO_AVAILABLE:
-    from transformers import Wav2Vec2ForCTC
+    from transformers import AutoModelForCTC, Wav2Vec2ForCTC
 
     WAV2VEC_MODELS = ["facebook/wav2vec2-base-960h", "facebook/wav2vec2-large-960h-lv60"]
 
@@ -31,6 +31,6 @@
         providers=[_HUGGINGFACE, _FAIRSEQ],
     )
 
-    HUGGINGFACE_BACKBONES = ExternalRegistry(Wav2Vec2ForCTC.from_pretrained, "backbones", providers=_HUGGINGFACE)
+    HUGGINGFACE_BACKBONES = ExternalRegistry(AutoModelForCTC.from_pretrained, "backbones", providers=_HUGGINGFACE)
 
     SPEECH_RECOGNITION_BACKBONES += HUGGINGFACE_BACKBONES
diff --git a/requirements/datatype_audio.txt b/requirements/datatype_audio.txt
index 2353d8d18f..bb6d7baba6 100644
--- a/requirements/datatype_audio.txt
+++ b/requirements/datatype_audio.txt
@@ -1,4 +1,4 @@
 torchaudio
 librosa>=0.8.1
-transformers>=4.5
+transformers>=4.11.0
 datasets>=1.8
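
A minimal sketch (not part of the patch) of what the switch to `AutoModelForCTC` enables: the Hugging Face backbone registry can now resolve any CTC-capable checkpoint rather than only Wav2Vec2 ones, which is also why the `transformers` requirement is bumped to 4.11.0, the first release that ships `AutoModelForCTC`. The HuBERT checkpoint name below is used purely for illustration.

```python
# Sketch: AutoModelForCTC dispatches on the checkpoint's config class, so the
# same registry getter can load non-Wav2Vec2 CTC models (e.g. HuBERT) by name.
from transformers import AutoModelForCTC

# Works as before for Wav2Vec2 checkpoints ...
wav2vec2 = AutoModelForCTC.from_pretrained("facebook/wav2vec2-base-960h")

# ... and also for other CTC architectures that Wav2Vec2ForCTC could not load
# (checkpoint name shown for illustration).
hubert = AutoModelForCTC.from_pretrained("facebook/hubert-large-ls960-ft")
```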