Skip to content

Commit

Permalink
Add black formatting (#571)
Browse files Browse the repository at this point in the history
  • Loading branch information
chidiwilliams authored Aug 18, 2023
1 parent f5f77b3 commit c498e60
Show file tree
Hide file tree
Showing 66 changed files with 2,556 additions and 1,384 deletions.
2 changes: 1 addition & 1 deletion build.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@


def build(setup_kwargs):
    """Poetry build hook: compile the whisper.cpp Python bindings via make.

    The scraped diff showed both the pre- and post-black version of the
    subprocess line; this keeps the final (double-quoted) form only.
    `setup_kwargs` is required by the Poetry build-script interface but unused.
    """
    # List form (shell=False) — no shell injection surface.
    subprocess.call(["make", "buzz/whisper_cpp.py"])


if __name__ == "__main__":
Expand Down
5 changes: 4 additions & 1 deletion buzz/action.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@


class Action(QAction):
def setShortcut(self, shortcut: typing.Union['QKeySequence', 'QKeySequence.StandardKey', str, int]) -> None:
def setShortcut(
self,
shortcut: typing.Union["QKeySequence", "QKeySequence.StandardKey", str, int],
) -> None:
super().setShortcut(shortcut)
self.setToolTip(Action.get_tooltip(self))

Expand Down
4 changes: 2 additions & 2 deletions buzz/assets.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,6 @@


def get_asset_path(path: str) -> str:
    """Resolve a bundled asset path relative to the running application.

    When frozen (e.g. a PyInstaller build sets ``sys.frozen``), assets live
    next to the executable; otherwise they are one directory above this
    module. The scraped diff duplicated each line in old/new quote styles;
    only the final black-formatted lines are kept here.
    """
    if getattr(sys, "frozen", False):
        return os.path.join(os.path.dirname(sys.executable), path)
    return os.path.join(os.path.dirname(__file__), "..", path)
25 changes: 15 additions & 10 deletions buzz/buzz.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from appdirs import user_log_dir

# Check for segfaults if not running in frozen mode
if getattr(sys, 'frozen', False) is False:
if getattr(sys, "frozen", False) is False:
faulthandler.enable()

# Sets stderr to no-op TextIO when None (run as Windows GUI).
Expand All @@ -19,30 +19,35 @@

# Adds the current directory to the PATH, so the ffmpeg binary get picked up:
# https://stackoverflow.com/a/44352931/9830227
app_dir = getattr(sys, '_MEIPASS', os.path.dirname(
os.path.abspath(__file__)))
app_dir = getattr(sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__)))
os.environ["PATH"] += os.pathsep + app_dir

# Add the app directory to the DLL list: https://stackoverflow.com/a/64303856
if platform.system() == 'Windows':
if platform.system() == "Windows":
os.add_dll_directory(app_dir)


def main():
if platform.system() == 'Linux':
multiprocessing.set_start_method('spawn')
if platform.system() == "Linux":
multiprocessing.set_start_method("spawn")

# Fixes opening new window when app has been frozen on Windows:
# https://stackoverflow.com/a/33979091
multiprocessing.freeze_support()

log_dir = user_log_dir(appname='Buzz')
log_dir = user_log_dir(appname="Buzz")
os.makedirs(log_dir, exist_ok=True)

log_format = "[%(asctime)s] %(module)s.%(funcName)s:%(lineno)d %(levelname)s -> %(message)s"
logging.basicConfig(filename=os.path.join(log_dir, 'logs.txt'), level=logging.DEBUG, format=log_format)
log_format = (
"[%(asctime)s] %(module)s.%(funcName)s:%(lineno)d %(levelname)s -> %(message)s"
)
logging.basicConfig(
filename=os.path.join(log_dir, "logs.txt"),
level=logging.DEBUG,
format=log_format,
)

if getattr(sys, 'frozen', False) is False:
if getattr(sys, "frozen", False) is False:
stdout_handler = logging.StreamHandler(sys.stdout)
stdout_handler.setLevel(logging.DEBUG)
stdout_handler.setFormatter(logging.Formatter(log_format))
Expand Down
18 changes: 11 additions & 7 deletions buzz/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@


class TasksCache:
def __init__(self, cache_dir=user_cache_dir('Buzz')):
def __init__(self, cache_dir=user_cache_dir("Buzz")):
os.makedirs(cache_dir, exist_ok=True)
self.cache_dir = cache_dir
self.pickle_cache_file_path = os.path.join(cache_dir, 'tasks')
self.tasks_list_file_path = os.path.join(cache_dir, 'tasks.json')
self.pickle_cache_file_path = os.path.join(cache_dir, "tasks")
self.tasks_list_file_path = os.path.join(cache_dir, "tasks.json")

def save(self, tasks: List[FileTranscriptionTask]):
self.save_json_tasks(tasks=tasks)
Expand All @@ -23,16 +23,20 @@ def load(self) -> List[FileTranscriptionTask]:
return self.load_json_tasks()

try:
with open(self.pickle_cache_file_path, 'rb') as file:
with open(self.pickle_cache_file_path, "rb") as file:
return pickle.load(file)
except FileNotFoundError:
return []
except (pickle.UnpicklingError, AttributeError, ValueError): # delete corrupted cache
except (
pickle.UnpicklingError,
AttributeError,
ValueError,
): # delete corrupted cache
os.remove(self.pickle_cache_file_path)
return []

def load_json_tasks(self) -> List[FileTranscriptionTask]:
with open(self.tasks_list_file_path, 'r') as file:
with open(self.tasks_list_file_path, "r") as file:
task_ids = json.load(file)

tasks = []
Expand All @@ -57,7 +61,7 @@ def save_json_tasks(self, tasks: List[FileTranscriptionTask]):
file.write(json_str)

def get_task_path(self, task_id: int):
path = os.path.join(self.cache_dir, 'transcriptions', f'{task_id}.json')
path = os.path.join(self.cache_dir, "transcriptions", f"{task_id}.json")
os.makedirs(os.path.dirname(path), exist_ok=True)
return path

Expand Down
181 changes: 119 additions & 62 deletions buzz/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,14 @@
from buzz.gui import Application
from buzz.model_loader import ModelType, WhisperModelSize, TranscriptionModel
from buzz.store.keyring_store import KeyringStore
from buzz.transcriber import Task, FileTranscriptionTask, FileTranscriptionOptions, TranscriptionOptions, LANGUAGES, \
OutputFormat
from buzz.transcriber import (
Task,
FileTranscriptionTask,
FileTranscriptionOptions,
TranscriptionOptions,
LANGUAGES,
OutputFormat,
)


class CommandLineError(Exception):
Expand All @@ -17,25 +23,25 @@ def __init__(self, message: str):


class CommandLineModelType(enum.Enum):
WHISPER = 'whisper'
WHISPER_CPP = 'whispercpp'
HUGGING_FACE = 'huggingface'
FASTER_WHISPER = 'fasterwhisper'
OPEN_AI_WHISPER_API = 'openaiapi'
WHISPER = "whisper"
WHISPER_CPP = "whispercpp"
HUGGING_FACE = "huggingface"
FASTER_WHISPER = "fasterwhisper"
OPEN_AI_WHISPER_API = "openaiapi"


def parse_command_line(app: Application):
parser = QCommandLineParser()
try:
parse(app, parser)
except CommandLineError as exc:
print(f'Error: {str(exc)}\n', file=sys.stderr)
print(f"Error: {str(exc)}\n", file=sys.stderr)
print(parser.helpText())
sys.exit(1)


def parse(app: Application, parser: QCommandLineParser):
parser.addPositionalArgument('<command>', 'One of the following commands:\n- add')
parser.addPositionalArgument("<command>", "One of the following commands:\n- add")
parser.parse(app.arguments())

args = parser.positionalArguments()
Expand All @@ -50,36 +56,63 @@ def parse(app: Application, parser: QCommandLineParser):
if command == "add":
parser.clearPositionalArguments()

parser.addPositionalArgument('files', 'Input file paths', '[file file file...]')

task_option = QCommandLineOption(['t', 'task'],
f'The task to perform. Allowed: {join_values(Task)}. Default: {Task.TRANSCRIBE.value}.',
'task',
Task.TRANSCRIBE.value)
model_type_option = QCommandLineOption(['m', 'model-type'],
f'Model type. Allowed: {join_values(CommandLineModelType)}. Default: {CommandLineModelType.WHISPER.value}.',
'model-type',
CommandLineModelType.WHISPER.value)
model_size_option = QCommandLineOption(['s', 'model-size'],
f'Model size. Use only when --model-type is whisper, whispercpp, or fasterwhisper. Allowed: {join_values(WhisperModelSize)}. Default: {WhisperModelSize.TINY.value}.',
'model-size', WhisperModelSize.TINY.value)
hugging_face_model_id_option = QCommandLineOption(['hfid'],
f'Hugging Face model ID. Use only when --model-type is huggingface. Example: "openai/whisper-tiny"',
'id')
language_option = QCommandLineOption(['l', 'language'],
f'Language code. Allowed: {", ".join(sorted([k + " (" + LANGUAGES[k].title() + ")" for k in LANGUAGES]))}. Leave empty to detect language.',
'code', '')
initial_prompt_option = QCommandLineOption(['p', 'prompt'], f'Initial prompt', 'prompt', '')
open_ai_access_token_option = QCommandLineOption('openai-token',
f'OpenAI access token. Use only when --model-type is {CommandLineModelType.OPEN_AI_WHISPER_API.value}. Defaults to your previously saved access token, if one exists.',
'token')
srt_option = QCommandLineOption(['srt'], 'Output result in an SRT file.')
vtt_option = QCommandLineOption(['vtt'], 'Output result in a VTT file.')
txt_option = QCommandLineOption('txt', 'Output result in a TXT file.')
parser.addPositionalArgument("files", "Input file paths", "[file file file...]")

task_option = QCommandLineOption(
["t", "task"],
f"The task to perform. Allowed: {join_values(Task)}. Default: {Task.TRANSCRIBE.value}.",
"task",
Task.TRANSCRIBE.value,
)
model_type_option = QCommandLineOption(
["m", "model-type"],
f"Model type. Allowed: {join_values(CommandLineModelType)}. Default: {CommandLineModelType.WHISPER.value}.",
"model-type",
CommandLineModelType.WHISPER.value,
)
model_size_option = QCommandLineOption(
["s", "model-size"],
f"Model size. Use only when --model-type is whisper, whispercpp, or fasterwhisper. Allowed: {join_values(WhisperModelSize)}. Default: {WhisperModelSize.TINY.value}.",
"model-size",
WhisperModelSize.TINY.value,
)
hugging_face_model_id_option = QCommandLineOption(
["hfid"],
f'Hugging Face model ID. Use only when --model-type is huggingface. Example: "openai/whisper-tiny"',
"id",
)
language_option = QCommandLineOption(
["l", "language"],
f'Language code. Allowed: {", ".join(sorted([k + " (" + LANGUAGES[k].title() + ")" for k in LANGUAGES]))}. Leave empty to detect language.',
"code",
"",
)
initial_prompt_option = QCommandLineOption(
["p", "prompt"], f"Initial prompt", "prompt", ""
)
open_ai_access_token_option = QCommandLineOption(
"openai-token",
f"OpenAI access token. Use only when --model-type is {CommandLineModelType.OPEN_AI_WHISPER_API.value}. Defaults to your previously saved access token, if one exists.",
"token",
)
srt_option = QCommandLineOption(["srt"], "Output result in an SRT file.")
vtt_option = QCommandLineOption(["vtt"], "Output result in a VTT file.")
txt_option = QCommandLineOption("txt", "Output result in a TXT file.")

parser.addOptions(
[task_option, model_type_option, model_size_option, hugging_face_model_id_option, language_option,
initial_prompt_option, open_ai_access_token_option, srt_option, vtt_option, txt_option])
[
task_option,
model_type_option,
model_size_option,
hugging_face_model_id_option,
language_option,
initial_prompt_option,
open_ai_access_token_option,
srt_option,
vtt_option,
txt_option,
]
)

parser.addHelpOption()
parser.addVersionOption()
Expand All @@ -89,7 +122,7 @@ def parse(app: Application, parser: QCommandLineParser):
# slice after first argument, the command
file_paths = parser.positionalArguments()[1:]
if len(file_paths) == 0:
raise CommandLineError('No input files')
raise CommandLineError("No input files")

task = parse_enum_option(task_option, parser, Task)

Expand All @@ -98,21 +131,29 @@ def parse(app: Application, parser: QCommandLineParser):

hugging_face_model_id = parser.value(hugging_face_model_id_option)

if hugging_face_model_id == '' and model_type == CommandLineModelType.HUGGING_FACE:
raise CommandLineError('--hfid is required when --model-type is huggingface')

model = TranscriptionModel(model_type=ModelType[model_type.name], whisper_model_size=model_size,
hugging_face_model_id=hugging_face_model_id)
if (
hugging_face_model_id == ""
and model_type == CommandLineModelType.HUGGING_FACE
):
raise CommandLineError(
"--hfid is required when --model-type is huggingface"
)

model = TranscriptionModel(
model_type=ModelType[model_type.name],
whisper_model_size=model_size,
hugging_face_model_id=hugging_face_model_id,
)
model_path = model.get_local_model_path()

if model_path is None:
raise CommandLineError('Model not found')
raise CommandLineError("Model not found")

language = parser.value(language_option)
if language == '':
if language == "":
language = None
elif LANGUAGES.get(language) is None:
raise CommandLineError('Invalid language option')
raise CommandLineError("Invalid language option")

initial_prompt = parser.value(initial_prompt_option)

Expand All @@ -125,33 +166,49 @@ def parse(app: Application, parser: QCommandLineParser):
output_formats.add(OutputFormat.TXT)

openai_access_token = parser.value(open_ai_access_token_option)
if model.model_type == ModelType.OPEN_AI_WHISPER_API and openai_access_token == '':
openai_access_token = KeyringStore().get_password(key=KeyringStore.Key.OPENAI_API_KEY)

if openai_access_token == '':
raise CommandLineError('No OpenAI access token found')

transcription_options = TranscriptionOptions(model=model, task=task, language=language,
initial_prompt=initial_prompt,
openai_access_token=openai_access_token)
file_transcription_options = FileTranscriptionOptions(file_paths=file_paths, output_formats=output_formats)
if (
model.model_type == ModelType.OPEN_AI_WHISPER_API
and openai_access_token == ""
):
openai_access_token = KeyringStore().get_password(
key=KeyringStore.Key.OPENAI_API_KEY
)

if openai_access_token == "":
raise CommandLineError("No OpenAI access token found")

transcription_options = TranscriptionOptions(
model=model,
task=task,
language=language,
initial_prompt=initial_prompt,
openai_access_token=openai_access_token,
)
file_transcription_options = FileTranscriptionOptions(
file_paths=file_paths, output_formats=output_formats
)

for file_path in file_paths:
transcription_task = FileTranscriptionTask(file_path=file_path, model_path=model_path,
transcription_options=transcription_options,
file_transcription_options=file_transcription_options)
transcription_task = FileTranscriptionTask(
file_path=file_path,
model_path=model_path,
transcription_options=transcription_options,
file_transcription_options=file_transcription_options,
)
app.add_task(transcription_task)


T = typing.TypeVar("T", bound=enum.Enum)


def parse_enum_option(option: QCommandLineOption, parser: QCommandLineParser, enum_class: typing.Type[T]) -> T:
def parse_enum_option(
option: QCommandLineOption, parser: QCommandLineParser, enum_class: typing.Type[T]
) -> T:
try:
return enum_class(parser.value(option))
except ValueError:
raise CommandLineError(f'Invalid value for --{option.names()[-1]} option.')
raise CommandLineError(f"Invalid value for --{option.names()[-1]} option.")


def join_values(enum_class: typing.Type[enum.Enum]) -> str:
return ', '.join([v.value for v in enum_class])
return ", ".join([v.value for v in enum_class])
Loading

0 comments on commit c498e60

Please sign in to comment.