fixed some webui bugs and released new features
shadowcun committed Jun 12, 2023
1 parent a810cbe commit f141edb
Showing 8 changed files with 146 additions and 99 deletions.
123 changes: 55 additions & 68 deletions app.py
@@ -9,64 +9,26 @@
except:
in_webui = False

# mimetypes.init()
# mimetypes.add_type('application/javascript', '.js')

# script_path = os.path.dirname(os.path.realpath(__file__))

# def webpath(fn):
# if fn.startswith(script_path):
# web_path = os.path.relpath(fn, script_path).replace('\\', '/')
# else:
# web_path = os.path.abspath(fn)

# return f'file={web_path}?{os.path.getmtime(fn)}'

# def javascript_html():
# # Ensure localization is in `window` before scripts
# # head = f'<script type="text/javascript">{localization.localization_js(opts.localization)}</script>\n'
# head = 'somehead'

# script_js = os.path.join(script_path, "assets", "script.js")
# head += f'<script type="text/javascript" src="{webpath(script_js)}"></script>\n'

# script_js = os.path.join(script_path, "assets", "aspectRatioOverlay.js")
# head += f'<script type="text/javascript" src="{webpath(script_js)}"></script>\n'

# return head

# def resize_from_to_html(width, height, scale_by):
# target_width = int(width * scale_by)
# target_height = int(height * scale_by)

# if not target_width or not target_height:
# return "no image selected"

# return f"resize: from <span class='resolution'>{width}x{height}</span> to <span class='resolution'>{target_width}x{target_height}</span>"

# def get_source_image(image):
# return image

# def reload_javascript():
# js = javascript_html()

# def template_response(*args, **kwargs):
# res = shared.GradioTemplateResponseOriginal(*args, **kwargs)
# res.body = res.body.replace(b'</head>', f'{js}</head>'.encode("utf8"))
# res.init_headers()
# return res

# gradio.routes.templates.TemplateResponse = template_response

# if not hasattr(shared, 'GradioTemplateResponseOriginal'):
# shared.GradioTemplateResponseOriginal = gradio.routes.templates.TemplateResponse
def toggle_audio_file(choice):
if choice == False:
return gr.update(visible=True), gr.update(visible=False)
else:
return gr.update(visible=False), gr.update(visible=True)

def ref_video_fn(path_of_ref_video):
if path_of_ref_video is not None:
return gr.update(value=True)
else:
return gr.update(value=False)


def sadtalker_demo(checkpoint_path='checkpoints', config_path='src/config', warpfn=None):

sad_talker = SadTalker(checkpoint_path, config_path, lazy_load=True)

with gr.Blocks(analytics_enabled=False) as sadtalker_interface:

gr.Markdown("<div align='center'> <h2> 😭 SadTalker: Learning Realistic 3D Motion Coefficients for Stylized Audio-Driven Single Image Talking Face Animation (CVPR 2023) </span> </h2> \
<a style='font-size:18px;color: #efefef' href='https://arxiv.org/abs/2211.12194'>Arxiv</a> &nbsp;&nbsp;&nbsp;&nbsp;&nbsp; \
<a style='font-size:18px;color: #efefef' href='https://sadtalker.github.io'>Homepage</a> &nbsp;&nbsp;&nbsp;&nbsp;&nbsp; \
@@ -75,24 +37,43 @@ def sadtalker_demo(checkpoint_path='checkpoints', config_path='src/config', warp
with gr.Row().style(equal_height=False):
with gr.Column(variant='panel'):
with gr.Tabs(elem_id="sadtalker_source_image"):
with gr.TabItem('Upload image'):
with gr.TabItem('Source image'):
with gr.Row():
source_image = gr.Image(label="Source image", source="upload", type="filepath", elem_id="img2img_image").style(width=512)


with gr.Tabs(elem_id="sadtalker_driven_audio"):
with gr.TabItem('Upload OR TTS'):
with gr.Column(variant='panel'):
with gr.TabItem('Driving Methods'):
gr.Markdown("Possible driving combinations: <br> 1. Audio only 2. Audio/IDLE Mode + Ref Video(pose, blink, pose+blink) 3. IDLE Mode only 4. Ref Video only (all) ")

with gr.Row():
driven_audio = gr.Audio(label="Input audio", source="upload", type="filepath")

if sys.platform != 'win32' and not in_webui:
from src.utils.text2speech import TTSTalker
tts_talker = TTSTalker()
with gr.Column(variant='panel'):
input_text = gr.Textbox(label="Generating audio from text", lines=5, placeholder="please enter some text here, we genreate the audio from text using @Coqui.ai TTS.")
tts = gr.Button('Generate audio',elem_id="sadtalker_audio_generate", variant='primary')
tts.click(fn=tts_talker.test, inputs=[input_text], outputs=[driven_audio])

driven_audio_no = gr.Audio(label="Use IDLE mode, no audio is required", source="upload", type="filepath", visible=False)

with gr.Column():
use_idle_mode = gr.Checkbox(label="Use Idle Animation")
length_of_audio = gr.Number(value=5, label="The length(seconds) of the generated video.")
use_idle_mode.change(toggle_audio_file, inputs=use_idle_mode, outputs=[driven_audio, driven_audio_no]) # todo

if sys.platform != 'win32' and not in_webui:
with gr.Accordion('Generate Audio From TTS', open=False):
from src.utils.text2speech import TTSTalker
tts_talker = TTSTalker()
with gr.Column(variant='panel'):
input_text = gr.Textbox(label="Generating audio from text", lines=5, placeholder="please enter some text here, we genreate the audio from text using @Coqui.ai TTS.")
tts = gr.Button('Generate audio',elem_id="sadtalker_audio_generate", variant='primary')
tts.click(fn=tts_talker.test, inputs=[input_text], outputs=[driven_audio])

with gr.Row():
ref_video = gr.Video(label="Reference Video", source="upload", type="filepath", elem_id="vidref").style(width=512)

with gr.Column():
use_ref_video = gr.Checkbox(label="Use Reference Video")
ref_info = gr.Radio(['pose', 'blink','pose+blink', 'all'], value='pose', label='Reference Video',info="How to borrow from reference Video?((fully transfer, aka, video driving mode))")

ref_video.change(ref_video_fn, inputs=ref_video, outputs=[use_ref_video]) # todo


with gr.Column(variant='panel'):
with gr.Tabs(elem_id="sadtalker_checkbox"):
with gr.TabItem('Settings'):
@@ -101,21 +82,21 @@ def sadtalker_demo(checkpoint_path='checkpoints', config_path='src/config', warp
# width = gr.Slider(minimum=64, elem_id="img2img_width", maximum=2048, step=8, label="Manually Crop Width", value=512) # img2img_width
# height = gr.Slider(minimum=64, elem_id="img2img_height", maximum=2048, step=8, label="Manually Crop Height", value=512) # img2img_width
with gr.Row():
pose_style = gr.Slider(minimum=0, maximum=46, step=1, label="Pose style", value=0) #
pose_style = gr.Slider(minimum=0, maximum=45, step=1, label="Pose style", value=0) #
exp_weight = gr.Slider(minimum=0, maximum=3, step=0.1, label="expression scale", value=1) #
blink_every = gr.Checkbox(label="use eye blink", value=True)

with gr.Row():
size_of_image = gr.Radio([256, 512], value=256, label='face model resolution', info="use 256/512 model?") #
preprocess_type = gr.Radio(['crop', 'resize','full', 'extcrop', 'extfull'], value='crop', label='preprocess', info="How to handle input image?")

with gr.Row():
is_still_mode = gr.Checkbox(label="Still Mode (fewer hand motion, works with preprocess `full`)")
batch_size = gr.Slider(label="batch size in generation", step=1, maximum=10, value=2)
is_still_mode = gr.Checkbox(label="Still Mode (fewer head motion, works with preprocess `full`)")
batch_size = gr.Slider(label="batch size in generation", step=1, maximum=10, value=1)
enhancer = gr.Checkbox(label="GFPGAN as Face enhancer")

submit = gr.Button('Generate', elem_id="sadtalker_generate", variant='primary')


with gr.Tabs(elem_id="sadtalker_genearted"):
gen_video = gr.Video(label="Generated video", format="mp4").style(width=256)

@@ -129,7 +110,13 @@ def sadtalker_demo(checkpoint_path='checkpoints', config_path='src/config', warp
batch_size,
size_of_image,
pose_style,
exp_weight
exp_weight,
use_ref_video,
ref_video,
ref_info,
use_idle_mode,
length_of_audio,
blink_every
],
outputs=[gen_video]
)
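
Aside: a minimal, self-contained sketch (not from this commit) of the visibility-toggle pattern the new toggle_audio_file callback uses, assuming a Gradio 3.x-style API (gr.update, Checkbox.change); labels are illustrative.

```python
# Illustrative sketch only: the idle-mode checkbox swaps which audio input
# is visible, mirroring toggle_audio_file in the diff above.
import gradio as gr

def toggle_audio_file(choice):
    # unchecked -> show the real audio input; checked -> show the IDLE placeholder
    if choice == False:
        return gr.update(visible=True), gr.update(visible=False)
    else:
        return gr.update(visible=False), gr.update(visible=True)

with gr.Blocks() as demo:
    use_idle_mode = gr.Checkbox(label="Use Idle Animation")
    driven_audio = gr.Audio(label="Input audio", type="filepath")
    driven_audio_no = gr.Audio(label="Use IDLE mode, no audio is required",
                               type="filepath", visible=False)
    use_idle_mode.change(toggle_audio_file,
                         inputs=use_idle_mode,
                         outputs=[driven_audio, driven_audio_no])

if __name__ == "__main__":
    demo.launch()
```
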
1 change: 0 additions & 1 deletion req.txt
@@ -18,6 +18,5 @@ basicsr==1.4.2
facexlib==0.3.0
gradio
gfpgan
dlib-bin
av
safetensors
1 change: 0 additions & 1 deletion requirements.txt
@@ -17,6 +17,5 @@ basicsr==1.4.2
facexlib==0.3.0
gradio
gfpgan
dlib-bin
av
safetensors
3 changes: 1 addition & 2 deletions requirements3d.txt
@@ -14,9 +14,8 @@ pyyaml
joblib==1.1.0
scikit-image==0.19.3
basicsr==1.4.2
facexlib==0.2.5
facexlib==0.3.0
trimesh==3.9.20
dlib-bin
gradio
gfpgan
safetensors
2 changes: 2 additions & 0 deletions src/facerender/modules/make_animation.py
@@ -111,6 +111,8 @@ def make_animation(source_image, source_semantics, target_semantics,
kp_source = keypoint_transformation(kp_canonical, he_source)

for frame_idx in tqdm(range(target_semantics.shape[1]), 'Face Renderer:'):
# still check the dimension
# print(target_semantics.shape, source_semantics.shape)
target_semantics_frame = target_semantics[:, frame_idx]
he_driving = mapping(target_semantics_frame)
if yaw_c_seq is not None:
45 changes: 25 additions & 20 deletions src/generate_batch.py
@@ -48,30 +48,35 @@ def generate_blink_seq_randomly(num_frames):
break
return ratio

def get_data(first_coeff_path, audio_path, device, ref_eyeblink_coeff_path, still=False):
def get_data(first_coeff_path, audio_path, device, ref_eyeblink_coeff_path, still=False, idlemode=False, length_of_audio=False, use_blink=True):

syncnet_mel_step_size = 16
fps = 25

pic_name = os.path.splitext(os.path.split(first_coeff_path)[-1])[0]
audio_name = os.path.splitext(os.path.split(audio_path)[-1])[0]

wav = audio.load_wav(audio_path, 16000)
wav_length, num_frames = parse_audio_length(len(wav), 16000, 25)
wav = crop_pad_audio(wav, wav_length)
orig_mel = audio.melspectrogram(wav).T
spec = orig_mel.copy() # nframes 80
indiv_mels = []

for i in tqdm(range(num_frames), 'mel:'):
start_frame_num = i-2
start_idx = int(80. * (start_frame_num / float(fps)))
end_idx = start_idx + syncnet_mel_step_size
seq = list(range(start_idx, end_idx))
seq = [ min(max(item, 0), orig_mel.shape[0]-1) for item in seq ]
m = spec[seq, :]
indiv_mels.append(m.T)
indiv_mels = np.asarray(indiv_mels) # T 80 16

if idlemode:
num_frames = int(length_of_audio * 25)
indiv_mels = np.zeros((num_frames, 80, 16))
else:
wav = audio.load_wav(audio_path, 16000)
wav_length, num_frames = parse_audio_length(len(wav), 16000, 25)
wav = crop_pad_audio(wav, wav_length)
orig_mel = audio.melspectrogram(wav).T
spec = orig_mel.copy() # nframes 80
indiv_mels = []

for i in tqdm(range(num_frames), 'mel:'):
start_frame_num = i-2
start_idx = int(80. * (start_frame_num / float(fps)))
end_idx = start_idx + syncnet_mel_step_size
seq = list(range(start_idx, end_idx))
seq = [ min(max(item, 0), orig_mel.shape[0]-1) for item in seq ]
m = spec[seq, :]
indiv_mels.append(m.T)
indiv_mels = np.asarray(indiv_mels) # T 80 16

ratio = generate_blink_seq_randomly(num_frames) # T
source_semantics_path = first_coeff_path
@@ -96,10 +101,10 @@ def get_data(first_coeff_path, audio_path, device, ref_eyeblink_coeff_path, stil

indiv_mels = torch.FloatTensor(indiv_mels).unsqueeze(1).unsqueeze(0) # bs T 1 80 16

if still:
ratio = torch.FloatTensor(ratio).unsqueeze(0).fill_(0.) # bs T
if use_blink:
ratio = torch.FloatTensor(ratio).unsqueeze(0) # bs T
else:
ratio = torch.FloatTensor(ratio).unsqueeze(0)
ratio = torch.FloatTensor(ratio).unsqueeze(0).fill_(0.)
# bs T
ref_coeff = torch.FloatTensor(ref_coeff).unsqueeze(0) # bs 1 70
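
Aside: the new idlemode branch above swaps real mel-spectrogram windows for zeros so a clip of length_of_audio seconds can be rendered without audio. A rough standalone sketch of that shape logic (not part of the diff), assuming 25 fps and 80x16 mel windows as in the hunk above:

```python
import numpy as np

def idle_mel_features(length_of_audio, fps=25, mel_bins=80, mel_step_size=16):
    # one all-zero mel window per video frame -> shape (T, 80, 16),
    # matching the (T, 80, 16) array built from real audio in get_data
    num_frames = int(length_of_audio * fps)
    return np.zeros((num_frames, mel_bins, mel_step_size))

print(idle_mel_features(5).shape)  # (125, 80, 16)
```
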

6 changes: 4 additions & 2 deletions src/generate_facerender_batch.py
@@ -25,20 +25,22 @@ def get_facerender_data(coeff_path, pic_path, first_coeff_path, audio_path,
data['source_image'] = source_image_ts

source_semantics_dict = scio.loadmat(first_coeff_path)
generated_dict = scio.loadmat(coeff_path)

if 'full' not in preprocess.lower():
source_semantics = source_semantics_dict['coeff_3dmm'][:1,:70] #1 70
generated_3dmm = generated_dict['coeff_3dmm'][:,:70]

else:
source_semantics = source_semantics_dict['coeff_3dmm'][:1,:73] #1 70
generated_3dmm = generated_dict['coeff_3dmm'][:,:70]

source_semantics_new = transform_semantic_1(source_semantics, semantic_radius)
source_semantics_ts = torch.FloatTensor(source_semantics_new).unsqueeze(0)
source_semantics_ts = source_semantics_ts.repeat(batch_size, 1, 1)
data['source_semantics'] = source_semantics_ts

# target
generated_dict = scio.loadmat(coeff_path)
generated_3dmm = generated_dict['coeff_3dmm']
generated_3dmm[:, :64] = generated_3dmm[:, :64] * expression_scale

if 'full' in preprocess.lower():
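
Aside: the hunk above defers loading the generated coefficients until after the source semantics are built, then scales their expression columns. A minimal sketch (not from the commit) of that scaling step, assuming, as the slicing suggests, that the first 64 of the 70 coefficient columns are expression terms:

```python
import numpy as np

# fake (T, 70) coefficient matrix standing in for generated_dict['coeff_3dmm']
generated_3dmm = np.random.randn(100, 70).astype(np.float32)
expression_scale = 1.2  # illustrative value of the "expression scale" slider

# only the expression columns (first 64) are scaled; pose columns stay untouched
generated_3dmm[:, :64] = generated_3dmm[:, :64] * expression_scale
```
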
64 changes: 59 additions & 5 deletions src/gradio_demo.py
@@ -34,7 +34,14 @@ def __init__(self, checkpoint_path='checkpoints', config_path='src/config', lazy


def test(self, source_image, driven_audio, preprocess='crop',
still_mode=False, use_enhancer=False, batch_size=1, size=256, pose_style = 0, exp_scale=1.0, result_dir='./results/'):
still_mode=False, use_enhancer=False, batch_size=1, size=256,
pose_style = 0, exp_scale=1.0,
use_ref_video = False,
ref_video = None,
ref_info = None,
use_idle_mode = False,
length_of_audio = 0, use_blink=True,
result_dir='./results/'):

self.sadtalker_paths = init_path(self.checkpoint_path, self.config_path, size, False, preprocess)
print(self.sadtalker_paths)
@@ -54,7 +61,7 @@ def test(self, source_image, driven_audio, preprocess='crop',
pic_path = os.path.join(input_dir, os.path.basename(source_image))
shutil.move(source_image, input_dir)

if os.path.isfile(driven_audio):
if driven_audio is not None and os.path.isfile(driven_audio):
audio_path = os.path.join(input_dir, os.path.basename(driven_audio))

#### mp3 to wav
@@ -63,9 +70,23 @@ def test(self, source_image, driven_audio, preprocess='crop',
audio_path = audio_path.replace('.mp3', '.wav')
else:
shutil.move(driven_audio, input_dir)

elif use_idle_mode:
audio_path = os.path.join(input_dir, 'idlemode_'+str(length_of_audio)+'.wav') ## generate audio from this new audio_path
from pydub import AudioSegment
one_sec_segment = AudioSegment.silent(duration=1000*length_of_audio) #duration in milliseconds
one_sec_segment.export(audio_path, format="wav")
else:
raise AttributeError("error audio")
print(use_ref_video, ref_info)
assert use_ref_video == True and ref_info == 'all'

if use_ref_video and ref_info == 'all': # full ref mode
ref_video_videoname = os.path.basename(ref_video)
audio_path = os.path.join(save_dir, ref_video_videoname+'.wav')
print('new audiopath:',audio_path)
# if ref_video contains audio, set the audio from ref_video.
cmd = r"ffmpeg -y -hide_banner -loglevel error -i %s %s"%(ref_video, audio_path)
os.system(cmd)

os.makedirs(save_dir, exist_ok=True)

@@ -77,9 +98,42 @@ def test(self, source_image, driven_audio, preprocess='crop',
if first_coeff_path is None:
raise AttributeError("No face is detected")

if use_ref_video:
print('using ref video for genreation')
ref_video_videoname = os.path.splitext(os.path.split(ref_video)[-1])[0]
ref_video_frame_dir = os.path.join(save_dir, ref_video_videoname)
os.makedirs(ref_video_frame_dir, exist_ok=True)
print('3DMM Extraction for the reference video providing pose')
ref_video_coeff_path, _, _ = self.preprocess_model.generate(ref_video, ref_video_frame_dir, preprocess, source_image_flag=False)
else:
ref_video_coeff_path = None

if use_ref_video:
if ref_info == 'pose':
ref_pose_coeff_path = ref_video_coeff_path
ref_eyeblink_coeff_path = None
elif ref_info == 'blink':
ref_pose_coeff_path = None
ref_eyeblink_coeff_path = ref_video_coeff_path
elif ref_info == 'pose+blink':
ref_pose_coeff_path = ref_video_coeff_path
ref_eyeblink_coeff_path = ref_video_coeff_path
elif ref_info == 'all':
ref_pose_coeff_path = None
ref_eyeblink_coeff_path = None
else:
raise('error in refinfo')
else:
ref_pose_coeff_path = None
ref_eyeblink_coeff_path = None

#audio2ceoff
batch = get_data(first_coeff_path, audio_path, self.device, ref_eyeblink_coeff_path=None, still=still_mode) # longer audio?
coeff_path = self.audio_to_coeff.generate(batch, save_dir, pose_style)
if use_ref_video and ref_info == 'all':
coeff_path = ref_video_coeff_path # self.audio_to_coeff.generate(batch, save_dir, pose_style, ref_pose_coeff_path)
else:
batch = get_data(first_coeff_path, audio_path, self.device, ref_eyeblink_coeff_path=ref_eyeblink_coeff_path, still=still_mode, idlemode=use_idle_mode, length_of_audio=length_of_audio, use_blink=use_blink) # longer audio?
coeff_path = self.audio_to_coeff.generate(batch, save_dir, pose_style, ref_pose_coeff_path)

#coeff2video
data = get_facerender_data(coeff_path, crop_pic_path, first_coeff_path, audio_path, batch_size, still_mode=still_mode, preprocess=preprocess, size=size, expression_scale = exp_scale)
return_path = self.animate_from_coeff.generate(data, save_dir, pic_path, crop_info, enhancer='gfpgan' if use_enhancer else None, preprocess=preprocess, img_size=size)
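
Aside: a hedged usage sketch of the extended test() signature with the new reference-video arguments. The checkpoint and config paths follow the defaults in this file; the example asset paths are assumptions and may not exist in every checkout.

```python
# Import path follows the repo layout shown above (src/gradio_demo.py).
from src.gradio_demo import SadTalker

sad_talker = SadTalker('checkpoints', 'src/config', lazy_load=True)

# Drive the portrait with its own audio, borrowing pose and blink from a reference video.
result_mp4 = sad_talker.test(
    source_image='examples/source_image/art_0.png',        # assumed example asset
    driven_audio='examples/driven_audio/bus_chinese.wav',   # assumed example asset
    preprocess='crop',
    still_mode=False,
    use_enhancer=False,
    batch_size=1,
    size=256,
    pose_style=0,
    exp_scale=1.0,
    use_ref_video=True,
    ref_video='examples/ref_video/WDA_KatieHill_000.mp4',   # assumed example asset
    ref_info='pose+blink',
    use_idle_mode=False,
    length_of_audio=0,
    use_blink=True,
)
print(result_mp4)  # path of the generated mp4 under ./results/
```
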
