Adding custom model size for Whisper.cpp and Faster Whisper (#820)
raivisdejus authored Jul 2, 2024
1 parent 4d06273 commit 2eeb03a
Showing 6 changed files with 245 additions and 47 deletions.
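
At a glance, this change threads a new WhisperModelSize.CUSTOM value through model loading, settings, and transcription, so that an arbitrary Hugging Face model id can stand in for one of the named sizes. A minimal sketch of what the change enables, using only names introduced in this diff (the repo id shown is one that already appears below and is purely illustrative):

    from buzz.model_loader import ModelType, TranscriptionModel, WhisperModelSize

    # Any CTranslate2-converted Whisper repository id could go here.
    model = TranscriptionModel(
        model_type=ModelType.FASTER_WHISPER,
        whisper_model_size=WhisperModelSize.CUSTOM,
        hugging_face_model_id="Systran/faster-whisper-large-v3",
    )
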
110 changes: 84 additions & 26 deletions buzz/model_loader.py
@@ -20,6 +20,8 @@
import whisper
import huggingface_hub

from buzz.locale import _

# Catch exception from whisper.dll not getting loaded.
# TODO: Remove flag and try-except when issue with loading
# the DLL in some envs is fixed.
@@ -46,6 +48,7 @@ class WhisperModelSize(str, enum.Enum):
LARGE = "large"
LARGEV2 = "large-v2"
LARGEV3 = "large-v3"
CUSTOM = "custom"

def to_faster_whisper_model_size(self) -> str:
if self == WhisperModelSize.LARGE:
@@ -112,9 +115,15 @@ def is_manually_downloadable(self):

@dataclass()
class TranscriptionModel:
-    model_type: ModelType = ModelType.WHISPER
-    whisper_model_size: Optional[WhisperModelSize] = WhisperModelSize.TINY
-    hugging_face_model_id: Optional[str] = "openai/whisper-tiny"
+    def __init__(
+        self,
+        model_type: ModelType = ModelType.WHISPER,
+        whisper_model_size: Optional[WhisperModelSize] = WhisperModelSize.TINY,
+        hugging_face_model_id: Optional[str] = ""
+    ):
+        self.model_type = model_type
+        self.whisper_model_size = whisper_model_size
+        self.hugging_face_model_id = hugging_face_model_id

def __str__(self):
match self.model_type:
@@ -135,10 +144,16 @@ def is_deletable(self):
return (
self.model_type == ModelType.WHISPER
or self.model_type == ModelType.WHISPER_CPP
or self.model_type == ModelType.FASTER_WHISPER
) and self.get_local_model_path() is not None

def open_file_location(self):
    model_path = self.get_local_model_path()

    if model_path is None:
        return

    if (self.model_type == ModelType.HUGGING_FACE
            or self.model_type == ModelType.FASTER_WHISPER):
        model_path = os.path.dirname(model_path)

    self.open_path(path=os.path.dirname(model_path))
@@ -160,6 +175,17 @@ def open_path(path: str):

def delete_local_file(self):
    model_path = self.get_local_model_path()

    if (self.model_type == ModelType.HUGGING_FACE
            or self.model_type == ModelType.FASTER_WHISPER):
        model_path = os.path.dirname(os.path.dirname(model_path))

        logging.debug("Deleting model directory: %s", model_path)

        shutil.rmtree(model_path, ignore_errors=True)
        return

    logging.debug("Deleting model file: %s", model_path)
    os.remove(model_path)

def get_local_model_path(self) -> Optional[str]:
@@ -178,7 +204,7 @@ def get_local_model_path(self) -> Optional[str]:
if self.model_type == ModelType.FASTER_WHISPER:
try:
return download_faster_whisper_model(
-    size=self.whisper_model_size.value, local_files_only=True
+    model=self, local_files_only=True
)
except (ValueError, FileNotFoundError):
return None
@@ -208,6 +234,7 @@ def get_local_model_path(self) -> Optional[str]:
"large-v1": "7d99f41a10525d0206bddadd86760181fa920438b6b33237e3118ff6c83bb53d",
"large-v2": "9a423fe4d40c82774b6af34115b8b935f34152246eb19e80e376071d3f999487",
"large-v3": "64d182b440b98d5203c4f9bd541544d84c605196c4f7b845dfa11fb23594d1e2",
"custom": None,
}


@@ -217,6 +244,10 @@ def get_whisper_cpp_file_path(size: WhisperModelSize) -> str:

def get_whisper_file_path(size: WhisperModelSize) -> str:
root_dir = os.path.join(model_root_dir, "whisper")

if size == WhisperModelSize.CUSTOM:
return os.path.join(root_dir, "custom")

url = whisper._MODELS[size.value]
return os.path.join(root_dir, os.path.basename(url))

@@ -286,13 +317,17 @@ def download_from_huggingface(
allow_patterns: List[str],
progress: pyqtSignal(tuple),
):
-    progress.emit((1, 100))
+    progress.emit((0, 100))

-    model_root = huggingface_hub.snapshot_download(
-        repo_id,
-        allow_patterns=allow_patterns[1:], # all, but largest
-        cache_dir=model_root_dir
-    )
+    try:
+        model_root = huggingface_hub.snapshot_download(
+            repo_id,
+            allow_patterns=allow_patterns[1:], # all, but largest
+            cache_dir=model_root_dir
+        )
+    except Exception as exc:
+        logging.exception(exc)
+        return ""

progress.emit((1, 100))

@@ -302,29 +337,40 @@ def download_from_huggingface(
model_download_monitor = HuggingfaceDownloadMonitor(model_root, progress, total_file_size)
model_download_monitor.start_monitoring()

-    huggingface_hub.snapshot_download(
-        repo_id,
-        allow_patterns=allow_patterns[:1], # largest
-        cache_dir=model_root_dir
-    )
+    try:
+        huggingface_hub.snapshot_download(
+            repo_id,
+            allow_patterns=allow_patterns[:1], # largest
+            cache_dir=model_root_dir
+        )
+    except Exception as exc:
+        logging.exception(exc)
+        model_download_monitor.stop_monitoring()
+        return ""

model_download_monitor.stop_monitoring()

return model_root


def download_faster_whisper_model(
-    size: str, local_files_only=False, progress: pyqtSignal(tuple) = None
+    model: TranscriptionModel, local_files_only=False, progress: pyqtSignal(tuple) = None
):
-    if size not in faster_whisper.utils._MODELS:
+    size = model.whisper_model_size.to_faster_whisper_model_size()
+    custom_repo_id = model.hugging_face_model_id
+
+    if size != WhisperModelSize.CUSTOM and size not in faster_whisper.utils._MODELS:
raise ValueError(
"Invalid model size '%s', expected one of: %s"
% (size, ", ".join(faster_whisper.utils._MODELS))
)

logging.debug("Downloading Faster Whisper model: %s", size)
if size == WhisperModelSize.CUSTOM and custom_repo_id == "":
raise ValueError("Custom model id is not provided")

-    if size == WhisperModelSize.LARGEV3:
+    if size == WhisperModelSize.CUSTOM:
+        repo_id = custom_repo_id
+    elif size == WhisperModelSize.LARGEV3:
repo_id = "Systran/faster-whisper-large-v3"
else:
repo_id = "guillaumekln/faster-whisper-%s" % size
@@ -358,20 +404,28 @@ class Signals(QObject):
progress = pyqtSignal(tuple) # (current, total)
error = pyqtSignal(str)

-    def __init__(self, model: TranscriptionModel):
+    def __init__(self, model: TranscriptionModel, custom_model_url: Optional[str] = None):
super().__init__()

self.signals = self.Signals()
self.model = model
self.stopped = False
self.custom_model_url = custom_model_url

def run(self) -> None:
logging.debug("Downloading model: %s, %s", self.model, self.model.hugging_face_model_id)

if self.model.model_type == ModelType.WHISPER_CPP:
model_name = self.model.whisper_model_size.to_whisper_cpp_model_size()
-    url = huggingface_hub.hf_hub_url(
-        repo_id="ggerganov/whisper.cpp",
-        filename=f"ggml-{model_name}.bin",
-    )
+    if self.custom_model_url:
+        url = self.custom_model_url
+    else:
+        url = huggingface_hub.hf_hub_url(
+            repo_id="ggerganov/whisper.cpp",
+            filename=f"ggml-{model_name}.bin",
+        )

file_path = get_whisper_cpp_file_path(size=self.model.whisper_model_size)
expected_sha256 = WHISPER_CPP_MODELS_SHA256[model_name]
return self.download_model_to_path(
@@ -388,9 +442,13 @@ def run(self) -> None:

if self.model.model_type == ModelType.FASTER_WHISPER:
model_path = download_faster_whisper_model(
-    size=self.model.whisper_model_size.to_faster_whisper_model_size(),
+    model=self.model,
progress=self.signals.progress,
)

if model_path == "":
self.signals.error.emit(_("Error"))

self.signals.finished.emit(model_path)
return

@@ -417,7 +475,7 @@ def download_model_to_path(
if downloaded:
self.signals.finished.emit(file_path)
except requests.RequestException:
self.signals.error.emit("A connection error occurred")
self.signals.error.emit(_("A connection error occurred"))
logging.exception("")
except Exception as exc:
self.signals.error.emit(str(exc))
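
For Whisper.cpp, the custom model arrives as a direct download URL through the new custom_model_url parameter rather than a repo id, and the new "custom" entry in WHISPER_CPP_MODELS_SHA256 is None, so no checksum is enforced. A hedged sketch of driving the downloader; the enclosing class name is outside the visible hunks, so ModelDownloader is an assumption here, as are the placeholder URL and the QRunnable/QThreadPool usage that run() suggests:

    from PyQt6.QtCore import QThreadPool

    from buzz.model_loader import ModelDownloader, ModelType, TranscriptionModel, WhisperModelSize

    model = TranscriptionModel(
        model_type=ModelType.WHISPER_CPP,
        whisper_model_size=WhisperModelSize.CUSTOM,
    )

    # Hypothetical URL; any ggml-format model file reachable over HTTP would do.
    downloader = ModelDownloader(
        model=model,
        custom_model_url="https://example.com/ggml-custom.bin",
    )
    downloader.signals.finished.connect(lambda path: print("model saved to", path))
    downloader.signals.error.connect(print)
    QThreadPool.globalInstance().start(downloader)
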
32 changes: 32 additions & 0 deletions buzz/settings/settings.py
@@ -38,6 +38,8 @@ class Key(enum.Enum):

DEFAULT_EXPORT_FILE_NAME = "transcriber/default-export-file-name"
CUSTOM_OPENAI_BASE_URL = "transcriber/custom-openai-base-url"
CUSTOM_FASTER_WHISPER_ID = "transcriber/custom-faster-whisper-id"
HUGGINGFACE_MODEL_ID = "transcriber/huggingface-model-id"

SHORTCUTS = "shortcuts"

@@ -50,6 +52,36 @@ class Key(enum.Enum):
def set_value(self, key: Key, value: typing.Any) -> None:
self.settings.setValue(key.value, value)

def save_custom_model_id(self, model) -> None:
from buzz.model_loader import ModelType
match model.model_type:
case ModelType.FASTER_WHISPER:
self.set_value(
Settings.Key.CUSTOM_FASTER_WHISPER_ID,
model.hugging_face_model_id,
)
case ModelType.HUGGING_FACE:
self.set_value(
Settings.Key.HUGGINGFACE_MODEL_ID,
model.hugging_face_model_id,
)

def load_custom_model_id(self, model) -> str:
from buzz.model_loader import ModelType
match model.model_type:
case ModelType.FASTER_WHISPER:
return self.value(
Settings.Key.CUSTOM_FASTER_WHISPER_ID,
"",
)
case ModelType.HUGGING_FACE:
return self.value(
Settings.Key.HUGGINGFACE_MODEL_ID,
"",
)

return ""

def value(
self,
key: Key,
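
The two new helpers give each backend's custom model id its own settings key and fall back to an empty string for other model types. A usage round-trip, sketched under the assumption that Settings() can be constructed with its defaults:

    from buzz.model_loader import ModelType, TranscriptionModel, WhisperModelSize
    from buzz.settings.settings import Settings

    settings = Settings()
    model = TranscriptionModel(
        model_type=ModelType.FASTER_WHISPER,
        whisper_model_size=WhisperModelSize.CUSTOM,
        hugging_face_model_id="Systran/faster-whisper-large-v3",  # example id
    )

    # Persists under CUSTOM_FASTER_WHISPER_ID; HUGGING_FACE models use their own key.
    settings.save_custom_model_id(model)

    # Later, read it back; unknown model types return "".
    model.hugging_face_model_id = settings.load_custom_model_id(model)
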
9 changes: 7 additions & 2 deletions buzz/transcriber/whisper_file_transcriber.py
@@ -13,7 +13,7 @@
from PyQt6.QtCore import QObject

from buzz.conn import pipe_stderr
-    from buzz.model_loader import ModelType
+    from buzz.model_loader import ModelType, WhisperModelSize
from buzz.transformers_whisper import TransformersWhisper
from buzz.transcriber.file_transcriber import FileTranscriber
from buzz.transcriber.transcriber import FileTranscriptionTask, Segment
@@ -131,8 +131,13 @@ def transcribe_hugging_face(cls, task: FileTranscriptionTask) -> List[Segment]:

@classmethod
def transcribe_faster_whisper(cls, task: FileTranscriptionTask) -> List[Segment]:
if task.transcription_options.model.whisper_model_size == WhisperModelSize.CUSTOM:
model_size_or_path = task.transcription_options.model.hugging_face_model_id
else:
model_size_or_path = task.transcription_options.model.whisper_model_size.to_faster_whisper_model_size()

model = faster_whisper.WhisperModel(
-    model_size_or_path=task.transcription_options.model.whisper_model_size.to_faster_whisper_model_size()
+    model_size_or_path=model_size_or_path
)
whisper_segments, info = model.transcribe(
audio=task.file_path,
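
This branch works because faster_whisper.WhisperModel takes a single model_size_or_path argument that accepts a named size, a local path, or a Hugging Face repo id alike, so the custom case needs no separate loading path. A minimal sketch against the faster-whisper API; the audio file name is a placeholder:

    import faster_whisper

    # A named size like "tiny" and a repo id like the one below are equally valid;
    # a repo id is downloaded and cached on first use.
    model = faster_whisper.WhisperModel(model_size_or_path="Systran/faster-whisper-large-v3")

    segments, info = model.transcribe(audio="speech.wav")
    for segment in segments:
        print(f"[{segment.start:.2f} -> {segment.end:.2f}] {segment.text}")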