Skip to content

Commit

Permalink
Will download HF models to Buzz cache folder (#775)
Browse files Browse the repository at this point in the history
  • Loading branch information
raivisdejus authored Jun 4, 2024
1 parent 905716c commit 045bd21
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 6 deletions.
19 changes: 15 additions & 4 deletions buzz/model_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@
except ImportError:
logging.exception("")

model_root_dir = user_cache_dir("Buzz")
model_root_dir = os.path.join(model_root_dir, "models")
os.makedirs(model_root_dir, exist_ok=True)

logging.debug("Model root directory: %s", model_root_dir)

class WhisperModelSize(str, enum.Enum):
TINY = "tiny"
Expand Down Expand Up @@ -182,7 +187,8 @@ def get_local_model_path(self) -> Optional[str]:
return huggingface_hub.snapshot_download(
self.hugging_face_model_id,
allow_patterns=HUGGING_FACE_MODEL_ALLOW_PATTERNS,
local_files_only=True
local_files_only=True,
cache_dir=model_root_dir
)
except (ValueError, FileNotFoundError):
return None
Expand All @@ -200,8 +206,7 @@ def get_local_model_path(self) -> Optional[str]:


def get_whisper_cpp_file_path(size: WhisperModelSize) -> str:
root_dir = user_cache_dir("Buzz")
return os.path.join(root_dir, f"ggml-model-whisper-{size.value}.bin")
return os.path.join(model_root_dir, f"ggml-model-whisper-{size.value}.bin")


def get_whisper_file_path(size: WhisperModelSize) -> str:
Expand All @@ -223,8 +228,11 @@ def __init__(self, model_root: str, progress: pyqtSignal(tuple), total_file_size

@staticmethod
def get_tmp_download_root(model_root):

logging.debug(f"=============== model_root: {model_root}")

normalized_model_root = os.path.normpath(model_root)
normalized_hub_path = os.path.normpath("huggingface/hub/")
normalized_hub_path = os.path.normpath("/models/")
index = normalized_model_root.find(normalized_hub_path)
if index == -1:
raise ValueError(f"Invalid model_root, '{normalized_hub_path}' not found")
Expand Down Expand Up @@ -272,6 +280,7 @@ def download_from_huggingface(
model_root = huggingface_hub.snapshot_download(
repo_id,
allow_patterns=allow_patterns[1:], # all, but largest
cache_dir=model_root_dir
)

progress.emit((1, 100))
Expand All @@ -285,6 +294,7 @@ def download_from_huggingface(
huggingface_hub.snapshot_download(
repo_id,
allow_patterns=allow_patterns[:1], # largest
cache_dir=model_root_dir
)

model_download_monitor.stop_monitoring()
Expand Down Expand Up @@ -315,6 +325,7 @@ def download_faster_whisper_model(
repo_id,
allow_patterns=allow_patterns,
local_files_only=True,
cache_dir=model_root_dir
)

return download_from_huggingface(
Expand Down
3 changes: 1 addition & 2 deletions docs/docs/faq.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@ sidebar_position: 5
1. **Where are the models stored?**

The Whisper models are stored in `~/.cache/whisper`. The Whisper.cpp models are stored in `~/Library/Caches/Buzz`
(Mac OS), `~/.cache/Buzz` (Unix), or `C:\Users\<username>\AppData\Local\Buzz\Buzz\Cache` (Windows). The Hugging Face
models are stored in `~/.cache/huggingface/hub`.
(Mac OS), `~/.cache/Buzz` (Unix), or `C:\Users\<username>\AppData\Local\Buzz\Buzz\Cache` (Windows).

2. **What can I try if the transcription runs too slowly?**

Expand Down

0 comments on commit 045bd21

Please sign in to comment.