From 472bdc046502037dd664d82a5451e670a8d552d2 Mon Sep 17 00:00:00 2001 From: Chidi Williams Date: Tue, 4 Jul 2023 01:11:13 +0100 Subject: [PATCH] Move transcriptions to individual cache files --- buzz/cache.py | 50 ++++++++++-- buzz/gui.py | 8 +- buzz/transcriber.py | 6 +- buzz/widgets/transcription_viewer_widget.py | 10 ++- poetry.lock | 86 ++++++++++++++++++++- pyproject.toml | 1 + 6 files changed, 145 insertions(+), 16 deletions(-) diff --git a/buzz/cache.py b/buzz/cache.py index 468b08585..26ae35a9a 100644 --- a/buzz/cache.py +++ b/buzz/cache.py @@ -1,4 +1,4 @@ -import logging +import json import os import pickle from typing import List @@ -11,22 +11,56 @@ class TasksCache: def __init__(self, cache_dir=user_cache_dir('Buzz')): os.makedirs(cache_dir, exist_ok=True) - self.file_path = os.path.join(cache_dir, 'tasks') + self.cache_dir = cache_dir + self.pickle_cache_file_path = os.path.join(cache_dir, 'tasks') + self.tasks_list_file_path = os.path.join(cache_dir, 'tasks.json') def save(self, tasks: List[FileTranscriptionTask]): - with open(self.file_path, 'wb') as file: - pickle.dump(tasks, file) + self.save_json_tasks(tasks=tasks) def load(self) -> List[FileTranscriptionTask]: + if os.path.exists(self.tasks_list_file_path): + return self.load_json_tasks() + try: - with open(self.file_path, 'rb') as file: + with open(self.pickle_cache_file_path, 'rb') as file: return pickle.load(file) except FileNotFoundError: return [] except (pickle.UnpicklingError, AttributeError, ValueError): # delete corrupted cache - os.remove(self.file_path) + os.remove(self.pickle_cache_file_path) return [] + def load_json_tasks(self) -> List[FileTranscriptionTask]: + with open(self.tasks_list_file_path, 'r') as file: + task_ids = json.load(file) + + tasks = [] + for task_id in task_ids: + try: + with open(self.get_task_path(task_id=task_id)) as file: + tasks.append(FileTranscriptionTask.from_json(file.read())) + except FileNotFoundError: + pass + + return tasks + + def save_json_tasks(self, tasks: List[FileTranscriptionTask]): + json_str = json.dumps([task.id for task in tasks]) + with open(self.tasks_list_file_path, "w") as file: + file.write(json_str) + + for task in tasks: + file_path = self.get_task_path(task_id=task.id) + json_str = task.to_json() + with open(file_path, "w") as file: + file.write(json_str) + + def get_task_path(self, task_id: int): + path = os.path.join(self.cache_dir, 'transcriptions', f'{task_id}.json') + os.makedirs(os.path.dirname(path), exist_ok=True) + return path + def clear(self): - if os.path.exists(self.file_path): - os.remove(self.file_path) + if os.path.exists(self.pickle_cache_file_path): + os.remove(self.pickle_cache_file_path) diff --git a/buzz/gui.py b/buzz/gui.py index 6c98497ff..5780af28a 100644 --- a/buzz/gui.py +++ b/buzz/gui.py @@ -866,9 +866,12 @@ def on_file_transcriber_triggered(self, options: Tuple[TranscriptionOptions, Fil file_path, transcription_options, file_transcription_options, model_path) self.add_task(task) - def update_task_table_row(self, task: FileTranscriptionTask): + def load_task(self, task: FileTranscriptionTask): self.table_widget.upsert_task(task) self.tasks[task.id] = task + + def update_task_table_row(self, task: FileTranscriptionTask): + self.load_task(task=task) self.tasks_changed.emit() @staticmethod @@ -965,6 +968,7 @@ def open_transcription_viewer(self, task_id: int): transcription_viewer_widget = TranscriptionViewerWidget( transcription_task=task, parent=self, flags=Qt.WindowType.Window) + transcription_viewer_widget.task_changed.connect(self.on_tasks_changed) transcription_viewer_widget.show() def add_task(self, task: FileTranscriptionTask): @@ -978,7 +982,7 @@ def load_tasks_from_cache(self): task.status = None self.transcriber_worker.add_task(task) else: - self.update_task_table_row(task) + self.load_task(task=task) def save_tasks_to_cache(self): self.tasks_cache.save(list(self.tasks.values())) diff --git a/buzz/transcriber.py b/buzz/transcriber.py index 7b82e4b3d..1b9628435 100644 --- a/buzz/transcriber.py +++ b/buzz/transcriber.py @@ -22,6 +22,7 @@ import tqdm import whisper from PyQt6.QtCore import QObject, pyqtSignal, pyqtSlot +from dataclasses_json import dataclass_json, config, Exclude from whisper import tokenizer from . import transformers_whisper @@ -65,7 +66,7 @@ class TranscriptionOptions: word_level_timings: bool = False temperature: Tuple[float, ...] = DEFAULT_WHISPER_TEMPERATURE initial_prompt: str = '' - openai_access_token: str = '' + openai_access_token: str = field(default='', metadata=config(exclude=Exclude.ALWAYS)) @dataclass() @@ -74,6 +75,7 @@ class FileTranscriptionOptions: output_formats: Set['OutputFormat'] = field(default_factory=set) +@dataclass_json @dataclass class FileTranscriptionTask: class Status(enum.Enum): @@ -87,7 +89,7 @@ class Status(enum.Enum): transcription_options: TranscriptionOptions file_transcription_options: FileTranscriptionOptions model_path: str - id: int = field(default_factory=lambda: randint(0, 1_000_000)) + id: int = field(default_factory=lambda: randint(0, 100_000_000)) segments: List[Segment] = field(default_factory=list) status: Optional[Status] = None fraction_completed = 0.0 diff --git a/buzz/widgets/transcription_viewer_widget.py b/buzz/widgets/transcription_viewer_widget.py index e0e9c8cef..ce430aa34 100644 --- a/buzz/widgets/transcription_viewer_widget.py +++ b/buzz/widgets/transcription_viewer_widget.py @@ -1,6 +1,6 @@ from typing import List, Optional -from PyQt6.QtCore import Qt +from PyQt6.QtCore import Qt, pyqtSignal from PyQt6.QtGui import QUndoCommand, QUndoStack, QKeySequence, QAction from PyQt6.QtWidgets import QWidget, QHBoxLayout, QMenu, QPushButton, QVBoxLayout, QFileDialog @@ -16,16 +16,18 @@ class TranscriptionViewerWidget(QWidget): transcription_task: FileTranscriptionTask + task_changed = pyqtSignal() class ChangeSegmentTextCommand(QUndoCommand): def __init__(self, table_widget: TranscriptionSegmentsEditorWidget, segments: List[Segment], - segment_index: int, segment_text: str): + segment_index: int, segment_text: str, task_changed: pyqtSignal): super().__init__() self.table_widget = table_widget self.segments = segments self.segment_index = segment_index self.segment_text = segment_text + self.task_changed = task_changed self.previous_segment_text = self.segments[self.segment_index].text @@ -41,6 +43,7 @@ def set_segment_text(self, text: str): self.table_widget.set_segment_text(self.segment_index, text) self.table_widget.blockSignals(False) self.segments[self.segment_index].text = text + self.task_changed.emit() def __init__( self, transcription_task: FileTranscriptionTask, @@ -102,7 +105,8 @@ def on_segment_text_changed(self, event: tuple): segment_index, segment_text = event self.undo_stack.push( self.ChangeSegmentTextCommand(table_widget=self.table_widget, segments=self.transcription_task.segments, - segment_index=segment_index, segment_text=segment_text)) + segment_index=segment_index, segment_text=segment_text, + task_changed=self.task_changed)) def on_menu_triggered(self, action: QAction): output_format = OutputFormat[action.text()] diff --git a/poetry.lock b/poetry.lock index 6e07d1833..dd7f0fb68 100644 --- a/poetry.lock +++ b/poetry.lock @@ -680,6 +680,26 @@ files = [ {file = "ctypesgen-1.1.1.tar.gz", hash = "sha256:deaa2d64a95d90196a2e8a689cf9b952be6f3366f81e835245354bf9dbac92f6"}, ] +[[package]] +name = "dataclasses-json" +version = "0.5.9" +description = "Easily serialize dataclasses to and from JSON" +category = "main" +optional = false +python-versions = ">=3.6" +files = [ + {file = "dataclasses-json-0.5.9.tar.gz", hash = "sha256:e9ac87b73edc0141aafbce02b44e93553c3123ad574958f0fe52a534b6707e8e"}, + {file = "dataclasses_json-0.5.9-py3-none-any.whl", hash = "sha256:1280542631df1c375b7bc92e5b86d39e06c44760d7e3571a537b3b8acabf2f0c"}, +] + +[package.dependencies] +marshmallow = ">=3.3.0,<4.0.0" +marshmallow-enum = ">=1.5.1,<2.0.0" +typing-inspect = ">=0.4.0" + +[package.extras] +dev = ["flake8", "hypothesis", "ipython", "mypy (>=0.710)", "portray", "pytest (>=7.2.0)", "setuptools", "simplejson", "twine", "types-dataclasses", "wheel"] + [[package]] name = "dill" version = "0.3.6" @@ -1181,6 +1201,42 @@ files = [ [package.dependencies] altgraph = ">=0.17" +[[package]] +name = "marshmallow" +version = "3.19.0" +description = "A lightweight library for converting complex datatypes to and from native Python datatypes." +category = "main" +optional = false +python-versions = ">=3.7" +files = [ + {file = "marshmallow-3.19.0-py3-none-any.whl", hash = "sha256:93f0958568da045b0021ec6aeb7ac37c81bfcccbb9a0e7ed8559885070b3a19b"}, + {file = "marshmallow-3.19.0.tar.gz", hash = "sha256:90032c0fd650ce94b6ec6dc8dfeb0e3ff50c144586462c389b81a07205bedb78"}, +] + +[package.dependencies] +packaging = ">=17.0" + +[package.extras] +dev = ["flake8 (==5.0.4)", "flake8-bugbear (==22.10.25)", "mypy (==0.990)", "pre-commit (>=2.4,<3.0)", "pytest", "pytz", "simplejson", "tox"] +docs = ["alabaster (==0.7.12)", "autodocsumm (==0.2.9)", "sphinx (==5.3.0)", "sphinx-issues (==3.0.1)", "sphinx-version-warning (==1.1.2)"] +lint = ["flake8 (==5.0.4)", "flake8-bugbear (==22.10.25)", "mypy (==0.990)", "pre-commit (>=2.4,<3.0)"] +tests = ["pytest", "pytz", "simplejson"] + +[[package]] +name = "marshmallow-enum" +version = "1.5.1" +description = "Enum field for Marshmallow" +category = "main" +optional = false +python-versions = "*" +files = [ + {file = "marshmallow-enum-1.5.1.tar.gz", hash = "sha256:38e697e11f45a8e64b4a1e664000897c659b60aa57bfa18d44e226a9920b6e58"}, + {file = "marshmallow_enum-1.5.1-py2.py3-none-any.whl", hash = "sha256:57161ab3dbfde4f57adeb12090f39592e992b9c86d206d02f6bd03ebec60f072"}, +] + +[package.dependencies] +marshmallow = ">=2.0.0" + [[package]] name = "mccabe" version = "0.7.0" @@ -1307,6 +1363,18 @@ files = [ {file = "multidict-6.0.4.tar.gz", hash = "sha256:3666906492efb76453c0e7b97f2cf459b0682e7402c0489a95484965dbc1da49"}, ] +[[package]] +name = "mypy-extensions" +version = "1.0.0" +description = "Type system extensions for programs checked with the mypy type checker." +category = "main" +optional = false +python-versions = ">=3.5" +files = [ + {file = "mypy_extensions-1.0.0-py3-none-any.whl", hash = "sha256:4392f6c0eb8a5668a69e23d168ffa70f0be9ccfd32b5cc2d26a34ae5b844552d"}, + {file = "mypy_extensions-1.0.0.tar.gz", hash = "sha256:75dbf8955dc00442a438fc4d0666508a9a97b6bd41aa2f0ffe9d2f2725af0782"}, +] + [[package]] name = "nodeenv" version = "1.8.0" @@ -2340,6 +2408,22 @@ files = [ {file = "typing_extensions-4.6.3.tar.gz", hash = "sha256:d91d5919357fe7f681a9f2b5b4cb2a5f1ef0a1e9f59c4d8ff0d3491e05c0ffd5"}, ] +[[package]] +name = "typing-inspect" +version = "0.9.0" +description = "Runtime inspection utilities for typing module." +category = "main" +optional = false +python-versions = "*" +files = [ + {file = "typing_inspect-0.9.0-py3-none-any.whl", hash = "sha256:9ee6fc59062311ef8547596ab6b955e1b8aa46242d854bfc78f4f6b0eff35f9f"}, + {file = "typing_inspect-0.9.0.tar.gz", hash = "sha256:b23fc42ff6f6ef6954e4852c1fb512cdd18dbea03134f91f856a95ccc9461f78"}, +] + +[package.dependencies] +mypy-extensions = ">=0.3.0" +typing-extensions = ">=3.7.4" + [[package]] name = "urllib3" version = "2.0.3" @@ -2585,4 +2669,4 @@ testing = ["big-O", "flake8 (<5)", "jaraco.functools", "jaraco.itertools", "more [metadata] lock-version = "2.0" python-versions = ">=3.9.13,<3.11" -content-hash = "487051d5612787780f1d23aa1125ea8e8a43cb09024967113bd9fda9b8c42b3d" +content-hash = "ceb6ce6c7083882f1499bd36f5e98f6aa1e0a872d8268ccbda91d67ee81fdd1e" diff --git a/pyproject.toml b/pyproject.toml index 11748d4cc..43c0cf40f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,6 +20,7 @@ faster-whisper = "^0.4.1" keyring = "^23.13.1" openai-whisper = "v20230124" platformdirs = "^3.5.3" +dataclasses-json = "^0.5.9" [tool.poetry.group.dev.dependencies] autopep8 = "^1.7.0"