-
Notifications
You must be signed in to change notification settings - Fork 6
/
app.py
105 lines (87 loc) · 5.47 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import os
import subprocess
from pathlib import Path
import gradio as gr
from config import hparams as hp
from config import hparams_gradio as hp_gradio
from nota_wav2lip import Wav2LipModelComparisonGradio
# The inference device comes from the Gradio config instead of being
# auto-detected at runtime.
device = hp_gradio.device
print(f'Using {device} for inference.')

# Sample label -> file path mappings; the labels are what the UI radios show.
video_label_dict = hp_gradio.sample.video
audio_label_dict = hp_gradio.sample.audio

# Optional download locations for the original checkpoint, the compressed
# checkpoint and the sample archive; each is None when its variable is unset.
LRS_ORIGINAL_URL = os.getenv('LRS_ORIGINAL_URL')
LRS_COMPRESSED_URL = os.getenv('LRS_COMPRESSED_URL')
LRS_INFERENCE_SAMPLE = os.getenv('LRS_INFERENCE_SAMPLE')
def _run(command):
    """Run *command* (an argument list, executed without a shell); return its exit code.

    No shell is involved, so characters such as ``&`` or ``;`` inside
    arguments (common in signed download URLs) are passed through literally.
    A missing executable is reported and mapped to 127, the exit code a POSIX
    shell uses for "command not found", so the start-up stays best-effort.
    """
    try:
        return subprocess.call(command)
    except FileNotFoundError:
        print(f"Command not found: {command[0]}")
        return 127


def _download(url, destination):
    """Download *url* into *destination* with wget.

    Returns:
        True if wget exited successfully, False otherwise.
    """
    Path(destination).parent.mkdir(parents=True, exist_ok=True)
    # NOTE(review): --no-check-certificate disables TLS certificate
    # verification; kept for compatibility with the current hosting, but it
    # should be dropped if the servers present valid certificates.
    returncode = _run(["wget", "--no-check-certificate", "-O", str(destination), url])
    if returncode == 0:
        return True
    # `wget -O` may leave an empty or partial output file behind on failure;
    # remove it so the exists() checks do not mistake it for a finished
    # download on the next start-up.
    try:
        Path(destination).unlink()
    except FileNotFoundError:
        pass
    # The URL itself is not printed: it may embed access tokens.
    print(f"Download to {destination} failed (exit code {returncode}).")
    return False


# Fetch model checkpoints that are missing locally, when a URL is configured.
if not Path(hp.inference.model.wav2lip.checkpoint).exists() and LRS_ORIGINAL_URL is not None:
    _download(LRS_ORIGINAL_URL, hp.inference.model.wav2lip.checkpoint)
if not Path(hp.inference.model.nota_wav2lip.checkpoint).exists() and LRS_COMPRESSED_URL is not None:
    _download(LRS_COMPRESSED_URL, hp.inference.model.nota_wav2lip.checkpoint)

path_inference_sample = "sample.tar.gz"
if not Path(path_inference_sample).exists() and LRS_INFERENCE_SAMPLE is not None:
    _download(LRS_INFERENCE_SAMPLE, path_inference_sample)
# Extract whenever the archive is present (freshly downloaded or already on
# disk); when it is absent there is nothing to extract, so tar is not run.
if Path(path_inference_sample).exists():
    _run(["tar", "-zxvf", path_inference_sample])
def _read_doc(filename):
    """Return the contents of ``docs/<filename>`` decoded as UTF-8.

    The encoding is pinned so the page renders identically everywhere;
    ``Path.read_text()`` without an encoding uses the locale default, which can
    garble or reject non-ASCII text in these files (e.g. on Windows).
    """
    return (Path('docs') / filename).read_text(encoding='utf-8')


def _register_samples(servicer):
    """Register every configured sample video and audio clip with *servicer*.

    Samples are registered in sorted label order. Each video path is passed
    together with the same path using a ``.json`` suffix.
    """
    for video_name in sorted(video_label_dict):
        video_path = Path(video_label_dict[video_name])
        servicer.update_video(video_path, video_path.with_suffix('.json'),
                              name=video_name)
    for audio_name in sorted(audio_label_dict):
        audio_path = Path(audio_label_dict[audio_name])
        servicer.update_audio(audio_path, name=audio_name)


def _build_demo(servicer):
    """Build the Gradio UI comparing the original and the compressed Wav2Lip.

    Args:
        servicer: the ``Wav2LipModelComparisonGradio`` instance whose
            ``params``, sample-switching and generation methods back the UI.

    Returns:
        The assembled ``gr.Blocks`` app (not yet launched).
    """
    with gr.Blocks(theme='nota-ai/theme', css=_read_doc('main.css')) as demo:
        gr.Markdown(_read_doc('header.md'))
        gr.Markdown(_read_doc('description.md'))
        with gr.Row():
            with gr.Column(variant='panel'):
                gr.Markdown('## Select input video and audio', sanitize_html=False)
                # Preview players for the currently selected samples.
                sample_video = gr.Video(interactive=False, label="Input Video")
                sample_audio = gr.Audio(interactive=False, label="Input Audio")
                # Radio selectors listing the sample labels.
                video_selection = gr.components.Radio(video_label_dict,
                                                      type='value', label="Select an input video:")
                audio_selection = gr.components.Radio(audio_label_dict,
                                                      type='value', label="Select an input audio:")
                # One generate button per model.
                with gr.Row(equal_height=True):
                    generate_original_button = gr.Button(value="Generate with Original Model", variant="primary")
                    generate_compressed_button = gr.Button(value="Generate with Compressed Model", variant="primary")
            with gr.Column(variant='panel'):
                # Output components for the original model.
                gr.Markdown('## Original Wav2Lip')
                original_model_output = gr.Video(label="Original Model", interactive=False)
                with gr.Column():
                    with gr.Row(equal_height=True):
                        original_model_inference_time = gr.Textbox(value="", label="Total inference time (sec)")
                        original_model_fps = gr.Textbox(value="", label="FPS")
                    # Display-only; no event handler reads or writes it.
                    gr.Textbox(value=servicer.params['wav2lip'], label="# Parameters")
            with gr.Column(variant='panel'):
                # Output components for the compressed model.
                gr.Markdown('## Compressed Wav2Lip (Ours)')
                compressed_model_output = gr.Video(label="Compressed Model", interactive=False)
                with gr.Column():
                    with gr.Row(equal_height=True):
                        compressed_model_inference_time = gr.Textbox(value="", label="Total inference time (sec)")
                        compressed_model_fps = gr.Textbox(value="", label="FPS")
                    # Display-only; no event handler reads or writes it.
                    gr.Textbox(value=servicer.params['nota_wav2lip'], label="# Parameters")

        # Switch the preview players when a radio button is selected.
        video_selection.change(fn=servicer.switch_video_samples, inputs=video_selection, outputs=sample_video)
        audio_selection.change(fn=servicer.switch_audio_samples, inputs=audio_selection, outputs=sample_audio)

        # Each generate button feeds the selected samples to its model and
        # fills that model's video, inference-time and FPS outputs.
        generate_original_button.click(servicer.generate_original_model,
                                       inputs=[video_selection, audio_selection],
                                       outputs=[original_model_output, original_model_inference_time, original_model_fps])
        generate_compressed_button.click(servicer.generate_compressed_model,
                                         inputs=[video_selection, audio_selection],
                                         outputs=[compressed_model_output, compressed_model_inference_time, compressed_model_fps])

        gr.Markdown(_read_doc('footer.md'))
    return demo


def main():
    """Create the comparison servicer, register the samples and serve the demo."""
    servicer = Wav2LipModelComparisonGradio(
        device=device,
        video_label_dict=video_label_dict,
        audio_label_list=audio_label_dict,
        default_video='v1',
        default_audio='a1',
    )
    _register_samples(servicer)
    _build_demo(servicer).queue().launch()


if __name__ == "__main__":
    main()