Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move transcriptions to individual cache files #519

Merged
merged 3 commits into from
Jul 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 42 additions & 8 deletions buzz/cache.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import logging
import json
import os
import pickle
from typing import List
Expand All @@ -11,22 +11,56 @@
class TasksCache:
def __init__(self, cache_dir=user_cache_dir('Buzz')):
os.makedirs(cache_dir, exist_ok=True)
self.file_path = os.path.join(cache_dir, 'tasks')
self.cache_dir = cache_dir
self.pickle_cache_file_path = os.path.join(cache_dir, 'tasks')
self.tasks_list_file_path = os.path.join(cache_dir, 'tasks.json')

def save(self, tasks: List[FileTranscriptionTask]):
with open(self.file_path, 'wb') as file:
pickle.dump(tasks, file)
self.save_json_tasks(tasks=tasks)

def load(self) -> List[FileTranscriptionTask]:
if os.path.exists(self.tasks_list_file_path):
return self.load_json_tasks()

try:
with open(self.file_path, 'rb') as file:
with open(self.pickle_cache_file_path, 'rb') as file:
return pickle.load(file)
except FileNotFoundError:
return []
except (pickle.UnpicklingError, AttributeError, ValueError): # delete corrupted cache
os.remove(self.file_path)
os.remove(self.pickle_cache_file_path)
return []

def load_json_tasks(self) -> List[FileTranscriptionTask]:
    """Load the tasks referenced by the JSON index file.

    The index at ``self.tasks_list_file_path`` holds a JSON list of task
    IDs. Each task is then deserialized from its own file, whose location
    is given by ``self.get_task_path``.

    Returns:
        The tasks that could be read, in index order. A task whose file is
        missing or is not valid JSON is skipped. An index that is not valid
        JSON, or does not contain a list, yields an empty list.
    """
    try:
        with open(self.tasks_list_file_path, 'r') as file:
            task_ids = json.load(file)
    except json.JSONDecodeError:
        # A truncated or corrupted index (for example after an interrupted
        # write) must not abort loading; treat it as an empty cache.
        return []

    if not isinstance(task_ids, list):
        # Valid JSON but not the expected list of IDs: nothing to iterate.
        return []

    tasks = []
    for task_id in task_ids:
        try:
            with open(self.get_task_path(task_id=task_id)) as file:
                contents = file.read()
            tasks.append(FileTranscriptionTask.from_json(contents))
        except (FileNotFoundError, json.JSONDecodeError):
            # Best effort: one missing or undecodable task file should not
            # prevent the remaining tasks from loading.
            continue

    return tasks

def save_json_tasks(self, tasks: List[FileTranscriptionTask]):
    """Write ``tasks`` as a JSON index plus one JSON file per task.

    The index file at ``self.tasks_list_file_path`` receives the list of
    task IDs. Each task's ``to_json()`` output is written to the path that
    ``self.get_task_path`` returns for its ID.
    """
    # Serialize before opening so a serialization error cannot leave a
    # truncated index behind.
    index_payload = json.dumps([task.id for task in tasks])
    with open(self.tasks_list_file_path, "w") as index_file:
        index_file.write(index_payload)

    for task in tasks:
        destination = self.get_task_path(task_id=task.id)
        serialized_task = task.to_json()
        with open(destination, "w") as task_file:
            task_file.write(serialized_task)

def get_task_path(self, task_id: int):
    """Return the JSON file path for ``task_id``, creating its directory.

    The file lives at ``<cache_dir>/transcriptions/<task_id>.json``. The
    ``transcriptions`` directory is created if it does not exist yet; the
    file itself is not created.
    """
    transcriptions_dir = os.path.join(self.cache_dir, 'transcriptions')
    os.makedirs(transcriptions_dir, exist_ok=True)
    return os.path.join(transcriptions_dir, f'{task_id}.json')

def clear(self):
if os.path.exists(self.file_path):
os.remove(self.file_path)
if os.path.exists(self.pickle_cache_file_path):
os.remove(self.pickle_cache_file_path)
8 changes: 6 additions & 2 deletions buzz/gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -866,9 +866,12 @@ def on_file_transcriber_triggered(self, options: Tuple[TranscriptionOptions, Fil
file_path, transcription_options, file_transcription_options, model_path)
self.add_task(task)

def update_task_table_row(self, task: FileTranscriptionTask):
def load_task(self, task: FileTranscriptionTask):
    """Show ``task`` in the table widget and record it in ``self.tasks``.

    The task is upserted into ``self.table_widget`` and stored under its
    ``id``. No signal is emitted here.
    """
    table = self.table_widget
    table.upsert_task(task)
    self.tasks[task.id] = task

def update_task_table_row(self, task: FileTranscriptionTask):
    """Load ``task`` through ``load_task`` and emit ``tasks_changed``."""
    # Register the task first so listeners observe the updated state.
    self.load_task(task=task)
    changed_signal = self.tasks_changed
    changed_signal.emit()

@staticmethod
Expand Down Expand Up @@ -965,6 +968,7 @@ def open_transcription_viewer(self, task_id: int):

transcription_viewer_widget = TranscriptionViewerWidget(
transcription_task=task, parent=self, flags=Qt.WindowType.Window)
transcription_viewer_widget.task_changed.connect(self.on_tasks_changed)
transcription_viewer_widget.show()

def add_task(self, task: FileTranscriptionTask):
Expand All @@ -978,7 +982,7 @@ def load_tasks_from_cache(self):
task.status = None
self.transcriber_worker.add_task(task)
else:
self.update_task_table_row(task)
self.load_task(task=task)

def save_tasks_to_cache(self):
self.tasks_cache.save(list(self.tasks.values()))
Expand Down
8 changes: 4 additions & 4 deletions buzz/model_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ class TranscriptionModel:
}


def get_hugging_face_dataset_file_url(author: str, repository_name: str, filename: str):
    """Return the ``resolve/main`` URL of a file in a Hugging Face dataset.

    The URL has the form
    ``https://huggingface.co/datasets/<author>/<repository_name>/resolve/main/<filename>``.
    """
    repository = f'{author}/{repository_name}'
    return f'https://huggingface.co/datasets/{repository}/resolve/main/{filename}'
def get_hugging_face_file_url(author: str, repository_name: str, filename: str):
    """Return the ``resolve/main`` URL of a file in a Hugging Face repository.

    The URL has the form
    ``https://huggingface.co/<author>/<repository_name>/resolve/main/<filename>``.
    """
    repository = f'{author}/{repository_name}'
    return f'https://huggingface.co/{repository}/resolve/main/{filename}'


def get_whisper_cpp_file_path(size: WhisperModelSize) -> str:
Expand Down Expand Up @@ -132,8 +132,8 @@ def __init__(self, model: TranscriptionModel):
def run(self) -> None:
if self.model.model_type == ModelType.WHISPER_CPP:
model_name = self.model.whisper_model_size.value
url = get_hugging_face_dataset_file_url(author='ggerganov', repository_name='whisper.cpp',
filename=f'ggml-{model_name}.bin')
url = get_hugging_face_file_url(author='ggerganov', repository_name='whisper.cpp',
filename=f'ggml-{model_name}.bin')
file_path = get_whisper_cpp_file_path(
size=self.model.whisper_model_size)
expected_sha256 = WHISPER_CPP_MODELS_SHA256[model_name]
Expand Down
6 changes: 4 additions & 2 deletions buzz/transcriber.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import tqdm
import whisper
from PyQt6.QtCore import QObject, pyqtSignal, pyqtSlot
from dataclasses_json import dataclass_json, config, Exclude
from whisper import tokenizer

from . import transformers_whisper
Expand Down Expand Up @@ -65,7 +66,7 @@ class TranscriptionOptions:
word_level_timings: bool = False
temperature: Tuple[float, ...] = DEFAULT_WHISPER_TEMPERATURE
initial_prompt: str = ''
openai_access_token: str = ''
openai_access_token: str = field(default='', metadata=config(exclude=Exclude.ALWAYS))


@dataclass()
Expand All @@ -74,6 +75,7 @@ class FileTranscriptionOptions:
output_formats: Set['OutputFormat'] = field(default_factory=set)


@dataclass_json
@dataclass
class FileTranscriptionTask:
class Status(enum.Enum):
Expand All @@ -87,7 +89,7 @@ class Status(enum.Enum):
transcription_options: TranscriptionOptions
file_transcription_options: FileTranscriptionOptions
model_path: str
id: int = field(default_factory=lambda: randint(0, 1_000_000))
id: int = field(default_factory=lambda: randint(0, 100_000_000))
segments: List[Segment] = field(default_factory=list)
status: Optional[Status] = None
fraction_completed = 0.0
Expand Down
10 changes: 7 additions & 3 deletions buzz/widgets/transcription_viewer_widget.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import List, Optional

from PyQt6.QtCore import Qt
from PyQt6.QtCore import Qt, pyqtSignal
from PyQt6.QtGui import QUndoCommand, QUndoStack, QKeySequence, QAction
from PyQt6.QtWidgets import QWidget, QHBoxLayout, QMenu, QPushButton, QVBoxLayout, QFileDialog

Expand All @@ -16,16 +16,18 @@

class TranscriptionViewerWidget(QWidget):
transcription_task: FileTranscriptionTask
task_changed = pyqtSignal()

class ChangeSegmentTextCommand(QUndoCommand):
def __init__(self, table_widget: TranscriptionSegmentsEditorWidget, segments: List[Segment],
segment_index: int, segment_text: str):
segment_index: int, segment_text: str, task_changed: pyqtSignal):
super().__init__()

self.table_widget = table_widget
self.segments = segments
self.segment_index = segment_index
self.segment_text = segment_text
self.task_changed = task_changed

self.previous_segment_text = self.segments[self.segment_index].text

Expand All @@ -41,6 +43,7 @@ def set_segment_text(self, text: str):
self.table_widget.set_segment_text(self.segment_index, text)
self.table_widget.blockSignals(False)
self.segments[self.segment_index].text = text
self.task_changed.emit()

def __init__(
self, transcription_task: FileTranscriptionTask,
Expand Down Expand Up @@ -102,7 +105,8 @@ def on_segment_text_changed(self, event: tuple):
segment_index, segment_text = event
self.undo_stack.push(
self.ChangeSegmentTextCommand(table_widget=self.table_widget, segments=self.transcription_task.segments,
segment_index=segment_index, segment_text=segment_text))
segment_index=segment_index, segment_text=segment_text,
task_changed=self.task_changed))

def on_menu_triggered(self, action: QAction):
output_format = OutputFormat[action.text()]
Expand Down
86 changes: 85 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ faster-whisper = "^0.4.1"
keyring = "^23.13.1"
openai-whisper = "v20230124"
platformdirs = "^3.5.3"
dataclasses-json = "^0.5.9"

[tool.poetry.group.dev.dependencies]
autopep8 = "^1.7.0"
Expand Down