Skip to content

Commit

Permalink
fix: disable whisper, faster_whisper, and hugging_face transcriptions…
Browse files Browse the repository at this point in the history
… in linux build (#659)
  • Loading branch information
chidiwilliams authored Jan 5, 2024
1 parent 5456774 commit c7be2f1
Show file tree
Hide file tree
Showing 26 changed files with 512 additions and 342 deletions.
3 changes: 0 additions & 3 deletions .coveragerc
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,3 @@ omit =

[html]
directory = coverage/html

[report]
fail_under = 75
2 changes: 0 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,6 @@ jobs:
include:
- os: macos-latest
- os: windows-latest
- os: ubuntu-20.04
steps:
- uses: actions/checkout@v3
with:
Expand Down Expand Up @@ -298,7 +297,6 @@ jobs:
include:
- os: macos-latest
- os: windows-latest
- os: ubuntu-20.04
needs: [ build, test ]
if: startsWith(github.ref, 'refs/tags/')
steps:
Expand Down
7 changes: 6 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,13 @@ clean:
rm -f buzz/whisper_cpp.py
rm -rf dist/* || true

COVERAGE_THRESHOLD := 75
ifeq ($(UNAME_S),Linux)
COVERAGE_THRESHOLD := 70
endif

test: buzz/whisper_cpp.py translation_mo
pytest -vv --cov=buzz --cov-report=xml --cov-report=html --benchmark-skip
pytest -vv --cov=buzz --cov-report=xml --cov-report=html --benchmark-skip --cov-fail-under=${COVERAGE_THRESHOLD}

benchmarks: buzz/whisper_cpp.py translation_mo
pytest -vv --benchmark-only --benchmark-json benchmarks.json
Expand Down
63 changes: 59 additions & 4 deletions buzz/model_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,38 @@
import hashlib
import logging
import os
import shutil
import subprocess
import sys
import tempfile
import warnings
from dataclasses import dataclass
from typing import Optional
import shutil

import faster_whisper
import huggingface_hub
import requests
import whisper
from PyQt6.QtCore import QObject, pyqtSignal, QRunnable
from platformdirs import user_cache_dir
from tqdm.auto import tqdm

whisper = None
faster_whisper = None
huggingface_hub = None
if sys.platform != "linux":
import faster_whisper
import whisper
import huggingface_hub

# Catch exception from whisper.dll not getting loaded.
# TODO: Remove flag and try-except when issue with loading
# the DLL in some envs is fixed.
LOADED_WHISPER_DLL = False
try:
import buzz.whisper_cpp as whisper_cpp # noqa: F401

LOADED_WHISPER_DLL = True
except ImportError:
logging.exception("")


class WhisperModelSize(str, enum.Enum):
TINY = "tiny"
Expand All @@ -42,6 +58,38 @@ class ModelType(enum.Enum):
FASTER_WHISPER = "Faster Whisper"
OPEN_AI_WHISPER_API = "OpenAI Whisper API"

def supports_recording(self):
    """Return whether this model type can be used for live recording."""
    # Live transcription with OpenAI Whisper API not supported
    api_only_type = ModelType.OPEN_AI_WHISPER_API
    return not (self == api_only_type)

def is_available(self):
    """Return whether this model type may be offered in the current process.

    A model type is unavailable when it is Whisper.cpp and the whisper_cpp
    bindings failed to import, or when running on Linux and the type is
    Whisper, Faster Whisper or Hugging Face.
    """
    # Hide Whisper.cpp option if whisper.dll did not load correctly.
    # See: https://github.com/chidiwilliams/buzz/issues/274,
    # https://github.com/chidiwilliams/buzz/issues/197
    if self == ModelType.WHISPER_CPP and not LOADED_WHISPER_DLL:
        return False

    # Disable Whisper, Faster Whisper and Hugging Face options
    # on Linux due to execstack errors on Snap
    if sys.platform == "linux" and self in (
        ModelType.WHISPER,
        ModelType.FASTER_WHISPER,
        ModelType.HUGGING_FACE,
    ):
        return False

    return True

def is_manually_downloadable(self):
    """Return True if this model type is Whisper, Whisper.cpp or Faster Whisper."""
    downloadable_types = (
        ModelType.WHISPER,
        ModelType.WHISPER_CPP,
        ModelType.FASTER_WHISPER,
    )
    return self in downloadable_types


@dataclass()
class TranscriptionModel:
Expand Down Expand Up @@ -76,6 +124,13 @@ def open_file_location(self):
return
self.open_path(path=os.path.dirname(model_path))

@staticmethod
def default():
    """Build a TranscriptionModel for the first available model type.

    ModelType members are tried in iteration order and the first one whose
    is_available() returns true is used. StopIteration propagates if no
    member is available.
    """
    available_types = (
        candidate for candidate in ModelType if candidate.is_available()
    )
    first_available = next(available_types)
    return TranscriptionModel(model_type=first_available)

@staticmethod
def open_path(path: str):
if sys.platform == "win32":
Expand Down
20 changes: 11 additions & 9 deletions buzz/recording_transcriber.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,29 @@
import datetime
import logging
import sys
import threading
from typing import Optional

import numpy as np
import sounddevice
import whisper
from PyQt6.QtCore import QObject, pyqtSignal
from sounddevice import PortAudioError

from buzz import transformers_whisper
from buzz import transformers_whisper, whisper_audio
from buzz.model_loader import ModelType
from buzz.transcriber import TranscriptionOptions, WhisperCpp, whisper_cpp_params
from buzz.transformers_whisper import TransformersWhisper

if sys.platform != "linux":
import whisper


class RecordingTranscriber(QObject):
transcription = pyqtSignal(str)
finished = pyqtSignal()
error = pyqtSignal(str)
is_running = False
SAMPLE_RATE = whisper_audio.SAMPLE_RATE
MAX_QUEUE_SIZE = 10

def __init__(
Expand Down Expand Up @@ -149,17 +153,15 @@ def get_device_sample_rate(device_id: Optional[int]) -> int:
provided by Whisper if the microphone supports it, or else it uses the device's default
sample rate.
"""
whisper_sample_rate = whisper.audio.SAMPLE_RATE
sample_rate = whisper_audio.SAMPLE_RATE
try:
sounddevice.check_input_settings(
device=device_id, samplerate=whisper_sample_rate
)
return whisper_sample_rate
sounddevice.check_input_settings(device=device_id, samplerate=sample_rate)
return sample_rate
except PortAudioError:
device_info = sounddevice.query_devices(device=device_id)
if isinstance(device_info, dict):
return int(device_info.get("default_samplerate", whisper_sample_rate))
return whisper_sample_rate
return int(device_info.get("default_samplerate", sample_rate))
return sample_rate

def stream_callback(self, in_data: np.ndarray, frame_count, time_info, status):
# Try to enqueue the next block. If the queue is already full, drop the block.
Expand Down
2 changes: 1 addition & 1 deletion buzz/store/keyring_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def get_password(self, key: Key) -> str:
return ""
return password
except (KeyringLocked, KeyringError) as exc:
logging.error("Unable to read from keyring: %s", exc)
logging.warning("Unable to read from keyring: %s", exc)
return ""

def set_password(self, username: Key, password: str) -> None:
Expand Down
135 changes: 114 additions & 21 deletions buzz/transcriber.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,31 +17,22 @@
from threading import Thread
from typing import Any, List, Optional, Tuple, Union, Set

import faster_whisper
import numpy as np
import openai
import stable_whisper
import tqdm
import whisper
from PyQt6.QtCore import QObject, pyqtSignal, pyqtSlot
from dataclasses_json import dataclass_json, config, Exclude
from whisper import tokenizer

from . import transformers_whisper
from buzz.model_loader import whisper_cpp
from . import transformers_whisper, whisper_audio
from .conn import pipe_stderr
from .locale import _
from .model_loader import TranscriptionModel, ModelType

# Catch exception from whisper.dll not getting loaded.
# TODO: Remove flag and try-except when issue with loading
# the DLL in some envs is fixed.
LOADED_WHISPER_DLL = False
try:
import buzz.whisper_cpp as whisper_cpp

LOADED_WHISPER_DLL = True
except ImportError:
logging.exception("")
if sys.platform != "linux":
import faster_whisper
import whisper
import stable_whisper

DEFAULT_WHISPER_TEMPERATURE = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0)

Expand All @@ -58,7 +49,108 @@ class Segment:
text: str


LANGUAGES = tokenizer.LANGUAGES
# Maps language codes (e.g. "en") to lowercase English language names.
# The table is spelled out here rather than taken from
# whisper.tokenizer.LANGUAGES; presumably so it stays usable when the
# whisper package is not imported (it is only imported when
# sys.platform != "linux") -- TODO confirm.
LANGUAGES = {
"en": "english",
"zh": "chinese",
"de": "german",
"es": "spanish",
"ru": "russian",
"ko": "korean",
"fr": "french",
"ja": "japanese",
"pt": "portuguese",
"tr": "turkish",
"pl": "polish",
"ca": "catalan",
"nl": "dutch",
"ar": "arabic",
"sv": "swedish",
"it": "italian",
"id": "indonesian",
"hi": "hindi",
"fi": "finnish",
"vi": "vietnamese",
"he": "hebrew",
"uk": "ukrainian",
"el": "greek",
"ms": "malay",
"cs": "czech",
"ro": "romanian",
"da": "danish",
"hu": "hungarian",
"ta": "tamil",
"no": "norwegian",
"th": "thai",
"ur": "urdu",
"hr": "croatian",
"bg": "bulgarian",
"lt": "lithuanian",
"la": "latin",
"mi": "maori",
"ml": "malayalam",
"cy": "welsh",
"sk": "slovak",
"te": "telugu",
"fa": "persian",
"lv": "latvian",
"bn": "bengali",
"sr": "serbian",
"az": "azerbaijani",
"sl": "slovenian",
"kn": "kannada",
"et": "estonian",
"mk": "macedonian",
"br": "breton",
"eu": "basque",
"is": "icelandic",
"hy": "armenian",
"ne": "nepali",
"mn": "mongolian",
"bs": "bosnian",
"kk": "kazakh",
"sq": "albanian",
"sw": "swahili",
"gl": "galician",
"mr": "marathi",
"pa": "punjabi",
"si": "sinhala",
"km": "khmer",
"sn": "shona",
"yo": "yoruba",
"so": "somali",
"af": "afrikaans",
"oc": "occitan",
"ka": "georgian",
"be": "belarusian",
"tg": "tajik",
"sd": "sindhi",
"gu": "gujarati",
"am": "amharic",
"yi": "yiddish",
"lo": "lao",
"uz": "uzbek",
"fo": "faroese",
"ht": "haitian creole",
"ps": "pashto",
"tk": "turkmen",
"nn": "nynorsk",
"mt": "maltese",
"sa": "sanskrit",
"lb": "luxembourgish",
"my": "myanmar",
"bo": "tibetan",
"tl": "tagalog",
"mg": "malagasy",
"as": "assamese",
"tt": "tatar",
"haw": "hawaiian",
"ln": "lingala",
"ha": "hausa",
"ba": "bashkir",
"jw": "javanese",
"su": "sundanese",
"yue": "cantonese",
}


@dataclass()
Expand Down Expand Up @@ -168,6 +260,7 @@ def run(self):
try:
segments = self.transcribe()
except Exception as exc:
logging.error(exc)
self.error.emit(exc)
return

Expand Down Expand Up @@ -230,17 +323,17 @@ def transcribe(self) -> List[Segment]:
model_path = self.model_path

logging.debug(
"Starting whisper_cpp file transcription, file path = %s, language = %s, task = %s, model_path = %s, "
"word level timings = %s",
"Starting whisper_cpp file transcription, file path = %s, language = %s, "
"task = %s, model_path = %s, word level timings = %s",
self.file_path,
self.language,
self.task,
model_path,
self.word_level_timings,
)

audio = whisper.audio.load_audio(self.file_path)
self.duration_audio_ms = len(audio) * 1000 / whisper.audio.SAMPLE_RATE
audio = whisper_audio.load_audio(self.file_path)
self.duration_audio_ms = len(audio) * 1000 / whisper_audio.SAMPLE_RATE

whisper_params = whisper_cpp_params(
language=self.language if self.language is not None else "",
Expand Down Expand Up @@ -722,7 +815,7 @@ def __init__(self, model: str) -> None:

def transcribe(self, audio: Union[np.ndarray, str], params: Any):
if isinstance(audio, str):
audio = whisper.audio.load_audio(audio)
audio = whisper_audio.load_audio(audio)

logging.debug("Loaded audio with length = %s", len(audio))

Expand Down
Loading

0 comments on commit c7be2f1

Please sign in to comment.