forked from fishaudio/Bert-VITS2
-
Notifications
You must be signed in to change notification settings - Fork 110
/
Copy pathapp.py
64 lines (52 loc) · 2.14 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import argparse
from pathlib import Path
import gradio as gr
import torch
from config import get_path_config
from gradio_tabs.dataset import create_dataset_app
from gradio_tabs.inference import create_inference_app
from gradio_tabs.merge import create_merge_app
from gradio_tabs.style_vectors import create_style_vectors_app
from gradio_tabs.train import create_train_app
from style_bert_vits2.constants import GRADIO_THEME, VERSION
from style_bert_vits2.nlp.japanese import pyopenjtalk_worker
from style_bert_vits2.nlp.japanese.user_dict import update_dict
from style_bert_vits2.tts_model import TTSModelHolder
# This process needs to start the worker and use the dictionary, so initialize it here.
pyopenjtalk_worker.initialize_worker()
# Apply the dictionary data under dict_data/ to pyopenjtalk.
update_dict()
# Command-line options for the WebUI launcher.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--device",
    type=str,
    default="cuda",
    help="Device handed to the TTS model holder; 'cuda' falls back to 'cpu' "
    "when CUDA is not available.",
)
parser.add_argument(
    "--host",
    type=str,
    default="127.0.0.1",
    help="Host name passed to the Gradio server as server_name.",
)
parser.add_argument(
    "--port",
    type=int,
    default=None,
    help="Port passed to the Gradio server as server_port; "
    "when omitted, None is passed through.",
)
parser.add_argument(
    "--no_autolaunch",
    action="store_true",
    help="Do not open the WebUI in a browser on startup.",
)
parser.add_argument(
    "--share",
    action="store_true",
    help="Enable the Gradio share option when launching.",
)
# Disabled option, kept for reference:
# parser.add_argument("--skip_default_models", action="store_true")
args = parser.parse_args()

# Use the requested device, but fall back to CPU when CUDA was requested
# and torch reports that it is not available.
device = args.device
if device == "cuda" and not torch.cuda.is_available():
    device = "cpu"

# Disabled default-model download step, kept for reference:
# if not args.skip_default_models:
#     download_default_models()

# Build the shared model holder from the configured assets directory.
path_config = get_path_config()
model_holder = TTSModelHolder(Path(path_config.assets_root), device)
# Assemble the WebUI: a title followed by one tab per feature. Each tab's
# contents are built by the corresponding create_*_app() factory; the
# inference and merge tabs share the same model holder instance.
with gr.Blocks(theme=GRADIO_THEME) as app:
    gr.Markdown(f"# Style-Bert-VITS2 WebUI (version {VERSION})")
    with gr.Tabs():
        with gr.Tab("音声合成"):
            create_inference_app(model_holder=model_holder)
        with gr.Tab("データセット作成"):
            create_dataset_app()
        with gr.Tab("学習"):
            create_train_app()
        with gr.Tab("スタイル作成"):
            create_style_vectors_app()
        with gr.Tab("マージ"):
            create_merge_app(model_holder=model_holder)

# Start the server with the options taken from the command line.
# --no_autolaunch is inverted because Gradio's flag means "open a browser".
app.launch(
    server_name=args.host,
    server_port=args.port,
    inbrowser=not args.no_autolaunch,
    share=args.share,
)