Skip to content

Commit

Permalink
Export transcription before preview (#399)
Browse files Browse the repository at this point in the history
  • Loading branch information
chidiwilliams authored Apr 9, 2023
1 parent 7474ad4 commit b807302
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 44 deletions.
20 changes: 17 additions & 3 deletions buzz/gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,14 @@ def __init__(self, file_paths: List[str], openai_access_token: Optional[str] = N
file_transcription_layout = QFormLayout()
file_transcription_layout.addRow('', self.word_level_timings_checkbox)

export_format_layout = QHBoxLayout()
for output_format in OutputFormat:
export_format_checkbox = QCheckBox(f'{output_format.value.upper()}', parent=self)
export_format_checkbox.stateChanged.connect(self.get_on_checkbox_state_changed_callback(output_format))
export_format_layout.addWidget(export_format_checkbox)

file_transcription_layout.addRow('Export:', export_format_layout)

self.run_button = QPushButton(_('Run'), self)
self.run_button.setDefault(True)
self.run_button.clicked.connect(self.on_click_run)
Expand All @@ -260,6 +268,15 @@ def __init__(self, file_paths: List[str], openai_access_token: Optional[str] = N
self.setLayout(layout)
self.setFixedSize(self.sizeHint())

def get_on_checkbox_state_changed_callback(self, output_format: OutputFormat):
    """Build a slot for one export-format checkbox's stateChanged signal.

    Each checkbox toggles membership of its OutputFormat in
    self.file_transcription_options.output_formats.

    Args:
        output_format: the format this checkbox controls.

    Returns:
        A callable accepting the Qt check-state int, suitable for
        connecting to QCheckBox.stateChanged.
    """
    def on_checkbox_state_changed(state: int):
        if state == Qt.CheckState.Checked.value:
            self.file_transcription_options.output_formats.add(output_format)
        elif state == Qt.CheckState.Unchecked.value:
            # discard() instead of remove(): cannot raise KeyError if the
            # format is somehow already absent from the set.
            self.file_transcription_options.output_formats.discard(output_format)

    return on_checkbox_state_changed

def on_transcription_options_changed(self, transcription_options: TranscriptionOptions):
self.transcription_options = transcription_options
self.word_level_timings_checkbox.setDisabled(
Expand All @@ -286,9 +303,6 @@ def on_click_run(self):
self.model_loader.finished.connect(self.on_model_loaded)
self.model_loader.finished.connect(self.model_loader.deleteLater)

self.transcriber_thread.finished.connect(
self.transcriber_thread.deleteLater)

self.transcriber_thread.start()

def on_model_loaded(self, model_path: str):
Expand Down
95 changes: 54 additions & 41 deletions buzz/transcriber.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from multiprocessing.connection import Connection
from random import randint
from threading import Thread
from typing import Any, List, Optional, Tuple, Union
from typing import Any, List, Optional, Tuple, Union, Set
import openai

import ffmpeg
Expand Down Expand Up @@ -73,6 +73,7 @@ class TranscriptionOptions:
@dataclass()
class FileTranscriptionOptions:
file_paths: List[str]
output_formats: Set['OutputFormat'] = field(default_factory=set)


@dataclass
Expand Down Expand Up @@ -233,8 +234,27 @@ def __init__(self, task: FileTranscriptionTask,
super().__init__(parent)
self.transcription_task = task

@abstractmethod
@pyqtSlot()
def run(self):
    """Run the transcription, emit results, then export selected formats.

    On failure, emits `error` with the exception text and returns early.
    On success, emits `completed` with the segments and writes one output
    file (at its default path) per format selected in the task's
    file-transcription options.
    """
    try:
        segments = self.transcribe()
    except Exception as exc:
        self.error.emit(str(exc))
        logging.exception('')
        return

    self.completed.emit(segments)

    file_options = self.transcription_task.file_transcription_options
    for fmt in file_options.output_formats:
        output_path = get_default_output_file_path(
            task=self.transcription_task.transcription_options.task,
            input_file_path=self.transcription_task.file_path,
            output_format=fmt)
        write_output(path=output_path, segments=segments, output_format=fmt)

@abstractmethod
def transcribe(self) -> List[Segment]:
    """Transcribe the task's audio file and return the resulting segments.

    Implemented by each concrete transcriber subclass. Called from run(),
    which handles error signaling, the `completed` emission, and export.
    May raise; run() converts exceptions into the `error` signal.
    """
    ...

@abstractmethod
Expand Down Expand Up @@ -262,8 +282,7 @@ def __init__(self, task: FileTranscriptionTask,
self.process.readyReadStandardError.connect(self.read_std_err)
self.process.readyReadStandardOutput.connect(self.read_std_out)

@pyqtSlot()
def run(self):
def transcribe(self) -> List[Segment]:
self.running = True
model_path = self.model_path

Expand Down Expand Up @@ -303,8 +322,8 @@ def run(self):
self.progress.emit(
(self.duration_audio_ms, self.duration_audio_ms))

self.completed.emit(self.segments)
self.running = False
return self.segments

def stop(self):
if self.running:
Expand Down Expand Up @@ -358,37 +377,32 @@ def __init__(self, task: FileTranscriptionTask, parent: Optional['QObject'] = No
self.file_path = task.file_path
self.task = task.transcription_options.task

@pyqtSlot()
def run(self):
try:
logging.debug('Starting OpenAI Whisper API file transcription, file path = %s, task = %s', self.file_path,
self.task)

wav_file = tempfile.mktemp() + '.wav'
(
ffmpeg.input(self.file_path)
.output(wav_file, acodec="pcm_s16le", ac=1, ar=whisper.audio.SAMPLE_RATE)
.run(cmd=["ffmpeg", "-nostdin"], capture_stdout=True, capture_stderr=True)
)

# TODO: Check if file size is more than 25MB (2.5 minutes), then chunk
audio_file = open(wav_file, "rb")
openai.api_key = self.transcription_task.transcription_options.openai_access_token
language = self.transcription_task.transcription_options.language
response_format = "verbose_json"
if self.transcription_task.transcription_options.task == Task.TRANSLATE:
transcript = openai.Audio.translate("whisper-1", audio_file, response_format=response_format,
language=language)
else:
transcript = openai.Audio.transcribe("whisper-1", audio_file, response_format=response_format,
language=language)
def transcribe(self) -> List[Segment]:
    """Transcribe self.file_path using the OpenAI Whisper API.

    Converts the input to a 16 kHz mono PCM WAV via ffmpeg, uploads it to
    the `whisper-1` model (translate or transcribe depending on the task's
    options), and returns the response segments converted to millisecond
    Segment values.

    Raises on ffmpeg or API failure; run() turns that into an `error`
    signal.
    """
    logging.debug('Starting OpenAI Whisper API file transcription, file path = %s, task = %s', self.file_path,
                  self.task)

    # NOTE(review): tempfile.mktemp is deprecated/race-prone and the temp
    # WAV is never deleted. Kept as-is because ffmpeg must create the
    # output file itself; consider mkstemp + explicit cleanup.
    wav_file = tempfile.mktemp() + '.wav'
    (
        ffmpeg.input(self.file_path)
        .output(wav_file, acodec="pcm_s16le", ac=1, ar=whisper.audio.SAMPLE_RATE)
        .run(cmd=["ffmpeg", "-nostdin"], capture_stdout=True, capture_stderr=True)
    )

    # TODO: The API rejects uploads over 25MB (roughly 13 minutes of
    # 16 kHz mono 16-bit WAV) — chunk the file when it exceeds that.
    openai.api_key = self.transcription_task.transcription_options.openai_access_token
    language = self.transcription_task.transcription_options.language
    response_format = "verbose_json"
    # Fix: close the upload handle when done (it was previously opened
    # with open() and never closed, leaking the file descriptor).
    with open(wav_file, "rb") as audio_file:
        if self.transcription_task.transcription_options.task == Task.TRANSLATE:
            transcript = openai.Audio.translate("whisper-1", audio_file, response_format=response_format,
                                                language=language)
        else:
            transcript = openai.Audio.transcribe("whisper-1", audio_file, response_format=response_format,
                                                 language=language)

    # API start/end times are in seconds; Segment stores milliseconds.
    segments = [Segment(segment["start"] * 1000, segment["end"] * 1000, segment["text"]) for segment in
                transcript["segments"]]
    return segments

def stop(self):
    """Intentionally a no-op.

    NOTE(review): presumably the in-flight OpenAI API request cannot be
    cancelled, so there is nothing to stop — confirm against callers.
    """
    pass
Expand All @@ -410,8 +424,7 @@ def __init__(self, task: FileTranscriptionTask,
self.started_process = False
self.stopped = False

@pyqtSlot()
def run(self):
def transcribe(self) -> List[Segment]:
time_started = datetime.datetime.now()
logging.debug(
'Starting whisper file transcription, task = %s', self.transcription_task)
Expand Down Expand Up @@ -439,10 +452,10 @@ def run(self):
'whisper process completed with code = %s, time taken = %s, number of segments = %s',
self.current_process.exitcode, datetime.datetime.now() - time_started, len(self.segments))

if self.current_process.exitcode == 0:
self.completed.emit(self.segments)
else:
self.error.emit('Unknown error')
if self.current_process.exitcode != 0:
raise Exception('Unknown error')

return self.segments

def stop(self):
self.stopped = True
Expand Down

0 comments on commit b807302

Please sign in to comment.