Skip to content

Commit

Permalink
Clean up recording transcriber (#270)
Browse files Browse the repository at this point in the history
  • Loading branch information
chidiwilliams authored Dec 30, 2022
1 parent 8e643b4 commit 6e89684
Show file tree
Hide file tree
Showing 12 changed files with 280 additions and 226 deletions.
2 changes: 1 addition & 1 deletion .coveragerc
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@ omit =
directory = coverage/html

[report]
fail_under = 78
fail_under = 70
1 change: 1 addition & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ jobs:
path: |
~/Library/Caches/Buzz
~/.cache/whisper
~/AppData/Local/Buzz/Cache
key: whisper-models-${{ runner.os }}

- uses: FedericoCarboni/setup-ffmpeg@v1
Expand Down
7 changes: 1 addition & 6 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,7 @@ bundle_windows: dist/Buzz
iscc //DAppVersion=${version} installer.iss
cd dist && tar -czf ${windows_zip_path} Buzz/ && cd -

bundle_mac: dist/Buzz.app
make codesign_all_mac
make zip_mac
make notarize_zip
make staple_app_mac
make dmg_mac
bundle_mac: dist/Buzz.app codesign_all_mac zip_mac notarize_zip staple_app_mac dmg_mac

UNAME_S := $(shell uname -s)

Expand Down
140 changes: 53 additions & 87 deletions buzz/gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,10 @@
from .__version__ import VERSION
from .model_loader import ModelLoader, WhisperModelSize, ModelType, TranscriptionModel
from .transcriber import (SUPPORTED_OUTPUT_FORMATS, FileTranscriptionOptions, OutputFormat,
RecordingTranscriber, Task,
Task,
WhisperCppFileTranscriber, WhisperFileTranscriber,
get_default_output_file_path, segments_to_text, write_output, TranscriptionOptions,
FileTranscriberQueueWorker, FileTranscriptionTask)
FileTranscriberQueueWorker, FileTranscriptionTask, RecordingTranscriber)

APP_NAME = 'Buzz'

Expand Down Expand Up @@ -171,6 +171,7 @@ def on_click_record(self):

self.status_changed.emit(current_status)

# TODO: control the text and status from the caller
def on_status_changed(self, status: Status):
self.current_status = status
if status == self.Status.RECORDING:
Expand Down Expand Up @@ -205,40 +206,6 @@ def set_fraction_completed(self, fraction_completed: float) -> None:
f'Downloading model ({fraction_completed :.0%}, {humanize.naturaldelta(time_left)} remaining)')


class RecordingTranscriberObject(QObject):
"""
RecordingTranscriberObject wraps a RecordingTranscriber and re-emits its
event and model-download-progress callbacks as Qt signals to allow updating
the UI from a secondary thread.
"""

# Emitted with every RecordingTranscriber.Event received by event_callback.
event_changed = pyqtSignal(RecordingTranscriber.Event)
# Emitted with a (current, total) tuple by on_download_model_progress.
download_model_progress = pyqtSignal(tuple)
transcriber: RecordingTranscriber

def __init__(self, model_path: str, use_whisper_cpp, language: Optional[str],
task: Task, input_device_index: Optional[int], temperature: Tuple[float, ...], initial_prompt: str,
parent: Optional[QWidget], *args) -> None:
super().__init__(parent, *args)
# The wrapped transcriber is given this object's own methods as callbacks,
# so everything it reports is forwarded through the signals declared above.
self.transcriber = RecordingTranscriber(
model_path=model_path, use_whisper_cpp=use_whisper_cpp,
on_download_model_chunk=self.on_download_model_progress, language=language, temperature=temperature,
initial_prompt=initial_prompt,
event_callback=self.event_callback, task=task,
input_device_index=input_device_index)

# Delegates to the wrapped transcriber.
def start_recording(self):
self.transcriber.start_recording()

# Callback handed to the transcriber; forwards the event as a signal.
def event_callback(self, event: RecordingTranscriber.Event):
self.event_changed.emit(event)

# Callback handed to the transcriber; packs both values into one tuple
# because download_model_progress is declared as pyqtSignal(tuple).
def on_download_model_progress(self, current: int, total: int):
self.download_model_progress.emit((current, total))

# Delegates to the wrapped transcriber.
def stop_recording(self):
self.transcriber.stop_recording()


class TimerLabel(QLabel):
start_time: Optional[QDateTime]

Expand Down Expand Up @@ -463,9 +430,9 @@ class RecordingTranscriberWidget(QWidget):
transcription_options: TranscriptionOptions
selected_device_id: Optional[int]
model_download_progress_dialog: Optional[DownloadModelProgressDialog] = None
transcriber: Optional[RecordingTranscriberObject] = None
transcriber: Optional[RecordingTranscriber] = None
model_loader: Optional[ModelLoader] = None
model_loader_thread: Optional[QThread] = None
transcription_thread: Optional[QThread] = None

def __init__(self, parent: Optional[QWidget] = None, flags: Qt.WindowType = Qt.WindowType.Widget) -> None:
super().__init__(parent, flags)
Expand Down Expand Up @@ -531,52 +498,35 @@ def on_status_changed(self, status: RecordButton.Status):
def start_recording(self):
self.record_button.setDisabled(True)

use_whisper_cpp = self.transcription_options.model.model_type == ModelType.WHISPER_CPP and \
self.transcription_options.language is not None

def start_recording_transcription(model_path: str):
# Clear text box placeholder because the first chunk takes a while to process
self.text_box.setPlaceholderText('')
self.timer_label.start_timer()
self.record_button.setDisabled(False)
if self.model_download_progress_dialog is not None:
self.model_download_progress_dialog = None

self.transcriber = RecordingTranscriberObject(
model_path=model_path, use_whisper_cpp=use_whisper_cpp,
language=self.transcription_options.language, task=self.transcription_options.task,
input_device_index=self.selected_device_id,
temperature=self.transcription_options.temperature,
initial_prompt=self.transcription_options.initial_prompt,
parent=self
)
self.transcriber.event_changed.connect(
self.on_transcriber_event_changed)
self.transcriber.download_model_progress.connect(
self.on_download_model_progress)

self.transcriber.start_recording()

self.model_loader_thread = QThread()
self.transcription_thread = QThread()

self.model_loader = ModelLoader(model=self.transcription_options.model)
self.transcriber = RecordingTranscriber(input_device_index=self.selected_device_id,
transcription_options=self.transcription_options)

self.model_loader.moveToThread(self.model_loader_thread)
self.model_loader.moveToThread(self.transcription_thread)
self.transcriber.moveToThread(self.transcription_thread)

self.model_loader_thread.started.connect(self.model_loader.run)
self.model_loader.finished.connect(self.model_loader_thread.quit)
self.transcription_thread.started.connect(self.model_loader.run)
self.transcription_thread.finished.connect(
self.transcription_thread.deleteLater)

self.model_loader.finished.connect(self.reset_recording_controls)
self.model_loader.finished.connect(self.transcriber.start)
self.model_loader.finished.connect(self.model_loader.deleteLater)
self.model_loader_thread.finished.connect(
self.model_loader_thread.deleteLater)

self.model_loader.progress.connect(
self.on_download_model_progress)

self.model_loader.finished.connect(start_recording_transcription)
self.model_loader.error.connect(self.on_download_model_error)

self.model_loader_thread.start()
self.transcriber.transcription.connect(self.on_next_transcription)

self.transcriber.finished.connect(self.on_transcriber_finished)
self.transcriber.finished.connect(self.transcription_thread.quit)
self.transcriber.finished.connect(self.transcriber.deleteLater)

self.transcription_thread.start()

def on_download_model_progress(self, progress: Tuple[float, float]):
(current_size, total_size) = progress
Expand All @@ -596,19 +546,25 @@ def on_download_model_error(self, error: str):
self.record_button.force_stop()
self.record_button.setDisabled(False)

def on_transcriber_event_changed(self, event: RecordingTranscriber.Event):
# Handles events from the recording transcriber: for a
# TranscribedNextChunkEvent, the stripped text (when non-empty) is inserted
# at the end of the text box followed by a blank line.
# NOTE(review): indentation is lost in this view; the nesting of the
# statements below under the two `if`s is assumed from their use of
# `event.text` and `text`.
if isinstance(event, RecordingTranscriber.TranscribedNextChunkEvent):
text = event.text.strip()
if len(text) > 0:
# Move to the end first so the text is appended rather than inserted
# wherever the cursor currently is.
self.text_box.moveCursor(QTextCursor.MoveOperation.End)
self.text_box.insertPlainText(text + '\n\n')
self.text_box.moveCursor(QTextCursor.MoveOperation.End)
def on_next_transcription(self, text: str):
# Appends a transcribed chunk to the end of the text box. Surrounding
# whitespace is stripped and chunks that are empty after stripping are
# ignored.
text = text.strip()
if len(text) > 0:
self.text_box.moveCursor(QTextCursor.MoveOperation.End)
# Only separate with a blank line when the box already holds text, so
# the first chunk is not preceded by one and no trailing blank line is
# left after the last chunk.
if len(self.text_box.toPlainText()) > 0:
self.text_box.insertPlainText('\n\n')
self.text_box.insertPlainText(text)
self.text_box.moveCursor(QTextCursor.MoveOperation.End)

def stop_recording(self):
# Asks the current transcriber (if there is one) to stop recording,
# disables the record button and stops the timer label.
# NOTE(review): indentation is lost in this view, so whether the
# statements after `self.transcriber.stop_recording()` sit inside the
# `if` cannot be confirmed from here.
if self.transcriber is not None:
self.transcriber.stop_recording()
# Disable record button until the transcription is actually stopped in the background
self.record_button.setDisabled(True)
self.timer_label.stop_timer()

def on_transcriber_finished(self):
# Connected to the transcriber's `finished` signal in start_recording;
# re-enables the record button that stop_recording disabled.
self.record_button.setEnabled(True)

def on_cancel_model_progress_dialog(self):
if self.model_loader is not None:
self.model_loader.stop()
Expand All @@ -621,6 +577,15 @@ def reset_model_download(self):
self.model_download_progress_dialog.close()
self.model_download_progress_dialog = None

def reset_recording_controls(self):
# Connected to the model loader's `finished` signal in start_recording:
# prepares the widget for live transcription by starting the timer,
# re-enabling the record button and dismissing any model download
# progress dialog.
# Clear text box placeholder because the first chunk takes a while to process
self.text_box.setPlaceholderText('')
self.timer_label.start_timer()
self.record_button.setDisabled(False)
if self.model_download_progress_dialog is not None:
# Close the dialog and drop the reference so it is not reused.
self.model_download_progress_dialog.close()
self.model_download_progress_dialog = None


def get_asset_path(path: str):
if getattr(sys, 'frozen', False):
Expand Down Expand Up @@ -1091,7 +1056,8 @@ class TranscriptionOptionsGroupBox(QGroupBox):
transcription_options: TranscriptionOptions
transcription_options_changed = pyqtSignal(TranscriptionOptions)

def __init__(self, default_transcription_options: TranscriptionOptions, parent: Optional[QWidget] = None):
def __init__(self, default_transcription_options: TranscriptionOptions = TranscriptionOptions(),
parent: Optional[QWidget] = None):
super().__init__(title='', parent=parent)
self.transcription_options = default_transcription_options

Expand All @@ -1115,10 +1081,10 @@ def __init__(self, default_transcription_options: TranscriptionOptions, parent:
self.hugging_face_search_line_edit = HuggingFaceSearchLineEdit()
self.hugging_face_search_line_edit.model_selected.connect(self.on_hugging_face_model_changed)

model_type_combo_box = QComboBox(self)
model_type_combo_box.addItems([model_type.value for model_type in ModelType])
model_type_combo_box.setCurrentText(default_transcription_options.model.model_type.value)
model_type_combo_box.currentTextChanged.connect(self.on_model_type_changed)
self.model_type_combo_box = QComboBox(self)
self.model_type_combo_box.addItems([model_type.value for model_type in ModelType])
self.model_type_combo_box.setCurrentText(default_transcription_options.model.model_type.value)
self.model_type_combo_box.currentTextChanged.connect(self.on_model_type_changed)

self.whisper_model_size_combo_box = QComboBox(self)
self.whisper_model_size_combo_box.addItems([size.value.title() for size in WhisperModelSize])
Expand All @@ -1129,7 +1095,7 @@ def __init__(self, default_transcription_options: TranscriptionOptions, parent:

self.form_layout.addRow('Task:', self.tasks_combo_box)
self.form_layout.addRow('Language:', self.languages_combo_box)
self.form_layout.addRow('Model:', model_type_combo_box)
self.form_layout.addRow('Model:', self.model_type_combo_box)
self.form_layout.addRow('', self.whisper_model_size_combo_box)
self.form_layout.addRow('', self.hugging_face_search_line_edit)

Expand Down Expand Up @@ -1171,7 +1137,7 @@ def on_model_type_changed(self, text: str):
self.form_layout.setRowVisible(self.hugging_face_search_line_edit, model_type == ModelType.HUGGING_FACE)
self.form_layout.setRowVisible(self.whisper_model_size_combo_box,
(model_type == ModelType.WHISPER) or (model_type == ModelType.WHISPER_CPP))
self.transcription_options.model_type = model_type
self.transcription_options.model.model_type = model_type
self.transcription_options_changed.emit(self.transcription_options)

def on_whisper_model_size_changed(self, text: str):
Expand All @@ -1180,7 +1146,7 @@ def on_whisper_model_size_changed(self, text: str):
self.transcription_options_changed.emit(self.transcription_options)

def on_hugging_face_model_changed(self, model: str):
self.transcription_options.hugging_face_model = model
self.transcription_options.model.hugging_face_model_id = model
self.transcription_options_changed.emit(self.transcription_options)


Expand Down
Loading

0 comments on commit 6e89684

Please sign in to comment.