Commit 20c4d96

Merge branch 'check-preprocess-func' into main

Hyoung-Kyu Song committed Feb 21, 2024
2 parents a0f1e11 + 81be537
Showing 14 changed files with 467 additions and 123 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -21,5 +21,6 @@ __pycache__
 
 results/
 temp/
+sample*
 
 app.sh
44 changes: 15 additions & 29 deletions app.py
@@ -5,12 +5,14 @@
 import gradio as gr
 
 from config import hparams as hp
+from config import hparams_gradio as hp_gradio
 from nota_wav2lip import Wav2LipModelComparisonGradio
 
-# device = 'cuda' if torch.cuda.is_available() else 'cpu'
-device = 'cpu'
+device = hp_gradio.device
 print(f'Using {device} for inference.')
 
+video_label_dict = hp_gradio.sample.video
+audio_label_dict = hp_gradio.sample.audio
 
 LRS_ORIGINAL_URL = os.getenv('LRS_ORIGINAL_URL', None)
 LRS_COMPRESSED_URL = os.getenv('LRS_COMPRESSED_URL', None)
@@ -21,46 +23,30 @@
 if not Path(hp.inference.model.nota_wav2lip.checkpoint).exists() and LRS_COMPRESSED_URL is not None:
     subprocess.call(f"wget --no-check-certificate -O {hp.inference.model.nota_wav2lip.checkpoint} {LRS_COMPRESSED_URL}", shell=True)
 
-path_inference_sample = "inference-sample.tar.gz"
+path_inference_sample = "sample.tar.gz"
 if not Path(path_inference_sample).exists() and LRS_INFERENCE_SAMPLE is not None:
     subprocess.call(f"wget --no-check-certificate -O {path_inference_sample} {LRS_INFERENCE_SAMPLE}", shell=True)
     subprocess.call(f"tar -zxvf {path_inference_sample}", shell=True)
 
 
-VIDEO_LABEL_DICT = {
-    'v1': "sample/2145_orig.mp4",
-    'v2': "sample/2942_orig.mp4",
-    'v3': "sample/4598_orig.mp4",
-    'v4': "sample/4653_orig.mp4",
-    'v5': "sample/13692_orig.mp4",
-}
-
-AUDIO_LABEL_DICT = {
-    'a1': "sample/1673_orig.wav",
-    'a2': "sample/9948_orig.wav",
-    'a3': "sample/11028_orig.wav",
-    'a4': "sample/12640_orig.wav",
-    'a5': "sample/5592_orig.wav",
-}
-
 if __name__ == "__main__":
 
     servicer = Wav2LipModelComparisonGradio(
         device=device,
-        video_label_dict=VIDEO_LABEL_DICT,
-        audio_label_list=AUDIO_LABEL_DICT,
+        video_label_dict=video_label_dict,
+        audio_label_list=audio_label_dict,
         default_video='v1',
         default_audio='a1'
     )
 
-    for video_name in sorted(VIDEO_LABEL_DICT):
-        video_path = Path(VIDEO_LABEL_DICT[video_name])
-        servicer.update_video(video_path.with_suffix(''), video_path.with_suffix('.json'),
-                              video_path=video_path,
+    for video_name in sorted(video_label_dict):
+        video_stem = Path(video_label_dict[video_name])
+        servicer.update_video(video_stem, video_stem.with_suffix('.json'),
+                              video_path=video_stem.with_suffix('.mp4'),
                               name=video_name)
 
-    for audio_name in sorted(AUDIO_LABEL_DICT):
-        audio_path = Path(AUDIO_LABEL_DICT[audio_name])
+    for audio_name in sorted(audio_label_dict):
+        audio_path = Path(audio_label_dict[audio_name])
         servicer.update_audio(audio_path, name=audio_name)
 
     with gr.Blocks(theme='nota-ai/theme', css=Path('docs/main.css').read_text()) as demo:
@@ -75,9 +61,9 @@
             sample_audio = gr.Audio(interactive=False, label="Input Audio")
 
             # Define radio inputs
-            video_selection = gr.components.Radio(VIDEO_LABEL_DICT,
+            video_selection = gr.components.Radio(video_label_dict,
                                                   type='value', label="Select an input video:")
-            audio_selection = gr.components.Radio(AUDIO_LABEL_DICT,
+            audio_selection = gr.components.Radio(audio_label_dict,
                                                   type='value', label="Select an input audio:")
             # Define button inputs
             with gr.Row(equal_height=True):
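The reworked loop relies on every sample sharing a single path stem: the extension-less values now kept in config/gradio.yaml. A quick sketch of the three artifact paths derived from one entry (this is plain pathlib behavior, not project code):

    from pathlib import Path

    video_stem = Path("sample/2145_orig")      # value of hp_gradio.sample.video.v1
    print(video_stem)                          # sample/2145_orig       -> extracted-frame directory
    print(video_stem.with_suffix('.json'))     # sample/2145_orig.json  -> face-bbox metadata
    print(video_stem.with_suffix('.mp4'))      # sample/2145_orig.mp4   -> the playable video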
2 changes: 2 additions & 0 deletions config/__init__.py
@@ -1,3 +1,5 @@
 from omegaconf import DictConfig, OmegaConf
 
 hparams: DictConfig = OmegaConf.load("config/nota_wav2lip.yaml")
+
+hparams_gradio: DictConfig = OmegaConf.load("config/gradio.yaml")
14 changes: 14 additions & 0 deletions config/gradio.yaml
@@ -0,0 +1,14 @@
+device: cpu
+sample:
+  video:
+    v1: "sample/2145_orig"
+    v2: "sample/2942_orig"
+    v3: "sample/4598_orig"
+    v4: "sample/4653_orig"
+    v5: "sample/13692_orig"
+  audio:
+    a1: "sample/1673_orig.wav"
+    a2: "sample/9948_orig.wav"
+    a3: "sample/11028_orig.wav"
+    a4: "sample/12640_orig.wav"
+    a5: "sample/5592_orig.wav"
43 changes: 43 additions & 0 deletions download.py
@@ -0,0 +1,43 @@
+import argparse
+
+from nota_wav2lip.preprocess import get_cropped_face_from_lrs3_label
+
+def parse_args():
+
+    parser = argparse.ArgumentParser(description="NotaWav2Lip: Get LRS3 video sample with the label text file")
+
+    parser.add_argument(
+        '-i',
+        '--input-file',
+        type=str,
+        required=True,
+        help="Path of the label text file downloaded from https://mmai.io/datasets/lip_reading"
+    )
+
+    parser.add_argument(
+        '-o',
+        '--output-dir',
+        type=str,
+        default="sample_video_lrs3",
+        help="Output directory to save the result. Default: sample_video_lrs3"
+    )
+
+    parser.add_argument(
+        '--ignore-cache',
+        action='store_true',
+        help="Whether to force downloading and resampling the video, overwriting pre-existing files"
+    )
+
+    args = parser.parse_args()
+
+    return args
+
+
+if __name__ == '__main__':
+    args = parse_args()
+
+    get_cropped_face_from_lrs3_label(
+        args.input_file,
+        video_root_dir=args.output_dir,
+        ignore_cache=args.ignore_cache
+    )
2 changes: 2 additions & 0 deletions download.sh
@@ -0,0 +1,2 @@
+python download.py \
+    -i 00003.txt
3 changes: 2 additions & 1 deletion nota_wav2lip/gradio.py
@@ -1,4 +1,5 @@
 import threading
+from pathlib import Path
 
 from nota_wav2lip.demo import Wav2LipModelComparisonDemo
 
@@ -18,7 +19,7 @@ def __init__(
         if video_label_dict is None:
             video_label_dict = {}
         super().__init__(device, result_dir)
-        self._video_label_dict = video_label_dict
+        self._video_label_dict = {k: Path(v).with_suffix('.mp4') for k, v in video_label_dict.items()}
         self._audio_label_dict = audio_label_list
         self._default_video = default_video
         self._default_audio = default_audio
2 changes: 2 additions & 0 deletions nota_wav2lip/preprocess/__init__.py
@@ -0,0 +1,2 @@
+from nota_wav2lip.preprocess.lrs3_download import get_cropped_face_from_lrs3_label
+from nota_wav2lip.preprocess.core import get_preprocessed_data
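Together these two exports form the sample-preparation pipeline: download and crop an LRS3 clip, then extract the frames, audio, and face boxes that the demo consumes. A hedged sketch of driving it end to end — the exact name of the .mp4 produced under the output directory is an assumption here, following the download.py defaults:

    from pathlib import Path

    from nota_wav2lip.preprocess import get_cropped_face_from_lrs3_label, get_preprocessed_data

    # Step 1: fetch the LRS3 sample described by a label file and crop the face region.
    get_cropped_face_from_lrs3_label("00003.txt", video_root_dir="sample_video_lrs3")

    # Step 2: write 25-FPS frames, 16 kHz mono audio, and face bboxes next to the video.
    # (assumed output name; adjust to whatever Step 1 actually produced)
    get_preprocessed_data(Path("sample_video_lrs3/00003.mp4"))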
98 changes: 98 additions & 0 deletions nota_wav2lip/preprocess/core.py
@@ -0,0 +1,98 @@
+import json
+import platform
+import subprocess
+from pathlib import Path
+
+import cv2
+import numpy as np
+from tqdm import tqdm
+from loguru import logger
+
+import face_detection
+from nota_wav2lip.preprocess.ffmpeg import FFMPEG_LOGGING_MODE
+
+detector = face_detection.FaceAlignment(face_detection.LandmarksType._2D, flip_input=False, device='cpu')
+PADDING = [0, 10, 0, 0]
+
+
+def get_smoothened_boxes(boxes, T):
+    for i in range(len(boxes)):
+        window = boxes[len(boxes) - T:] if i + T > len(boxes) else boxes[i:i + T]
+        boxes[i] = np.mean(window, axis=0)
+    return boxes
+
+
+def face_detect(images, pads, no_smooth=False, batch_size=1):
+
+    predictions = []
+    images_array = [cv2.imread(str(image)) for image in images]
+    for i in tqdm(range(0, len(images_array), batch_size)):
+        predictions.extend(detector.get_detections_for_batch(np.array(images_array[i:i + batch_size])))
+
+    results = []
+    pady1, pady2, padx1, padx2 = pads
+    for rect, image_array in zip(predictions, images_array):
+        if rect is None:
+            cv2.imwrite('temp/faulty_frame.jpg', image_array)  # check this frame where the face was not detected.
+            raise ValueError('Face not detected! Ensure the video contains a face in all the frames.')
+
+        y1 = max(0, rect[1] - pady1)
+        y2 = min(image_array.shape[0], rect[3] + pady2)
+        x1 = max(0, rect[0] - padx1)
+        x2 = min(image_array.shape[1], rect[2] + padx2)
+        results.append([x1, y1, x2, y2])
+
+    boxes = np.array(results)
+    bbox_format = "(y1, y2, x1, x2)"
+    if not no_smooth:
+        boxes = get_smoothened_boxes(boxes, T=5)
+    outputs = {
+        'bbox': {str(image_path): tuple(map(int, (y1, y2, x1, x2))) for image_path, (x1, y1, x2, y2) in zip(images, boxes)},
+        'format': bbox_format
+    }
+    return outputs
+
+
+def save_video_frame(video_path, output_dir=None):
+    video_path = Path(video_path)
+    output_dir = output_dir if output_dir is not None else video_path.with_suffix('')
+    output_dir.mkdir(exist_ok=True)
+    return subprocess.call(
+        f"ffmpeg {FFMPEG_LOGGING_MODE['ERROR']} -y -i {video_path} -r 25 -f image2 {output_dir}/%05d.jpg",
+        shell=platform.system() != 'Windows'
+    )
+
+
+def save_audio_file(video_path, output_path=None):
+    video_path = Path(video_path)
+    output_path = output_path if output_path is not None else video_path.with_suffix('.wav')
+    subprocess.call(
+        f"ffmpeg {FFMPEG_LOGGING_MODE['ERROR']} -y -i {video_path} -vn -acodec pcm_s16le -ar 16000 -ac 1 {output_path}",
+        shell=platform.system() != 'Windows'
+    )
+
+
+def save_bbox_file(video_path, bbox_dict, output_path=None):
+    video_path = Path(video_path)
+    output_path = output_path if output_path is not None else video_path.with_suffix('.json')
+
+    with open(output_path, 'w') as f:
+        json.dump(bbox_dict, f, indent=4)
+
+def get_preprocessed_data(video_path: Path):
+    video_path = Path(video_path)
+
+    image_sequence_dir = video_path.with_suffix('')
+    audio_path = video_path.with_suffix('.wav')
+    face_bbox_json_path = video_path.with_suffix('.json')
+
+    logger.info(f"Save 25 FPS video frames as image files ... will be saved at {image_sequence_dir}")
+    save_video_frame(video_path=video_path, output_dir=image_sequence_dir)
+
+    logger.info(f"Save the audio as a wav file ... will be saved at {audio_path}")
+    save_audio_file(video_path=video_path, output_path=audio_path)  # bonus
+
+    # Load images, extract bboxes and save the coords (to directly use as array indices)
+    logger.info(f"Extract face bboxes and save the coords in json format ... will be saved at {face_bbox_json_path}")
+    results = face_detect(sorted(image_sequence_dir.glob("*.jpg")), pads=PADDING)
+    save_bbox_file(video_path, results, output_path=face_bbox_json_path)
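To make the on-disk format concrete, here is a minimal sketch — assuming sample/2145_orig.mp4 has already been preprocessed on a POSIX system — that reloads the bbox JSON and crops the face from one extracted frame. The boxes are stored in (y1, y2, x1, x2) order precisely so they can index the image array directly:

    import json

    import cv2

    with open("sample/2145_orig.json") as f:        # written by save_bbox_file
        meta = json.load(f)

    frame_path = "sample/2145_orig/00001.jpg"       # written by save_video_frame
    y1, y2, x1, x2 = meta['bbox'][frame_path]       # documented as format '(y1, y2, x1, x2)'
    face = cv2.imread(frame_path)[y1:y2, x1:x2]     # the coords slice the array directly
    cv2.imwrite("face_crop.jpg", face)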
5 changes: 5 additions & 0 deletions nota_wav2lip/preprocess/ffmpeg.py
@@ -0,0 +1,5 @@
+FFMPEG_LOGGING_MODE = {
+    'DEBUG': "",
+    'INFO': "-v quiet -stats",
+    'ERROR': "-hide_banner -loglevel error",
+}
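These presets are spliced into the ffmpeg command strings, as core.py does above; a small usage sketch, with the input file name assumed:

    import subprocess

    from nota_wav2lip.preprocess.ffmpeg import FFMPEG_LOGGING_MODE

    # 'INFO' silences ffmpeg's banner and per-stream chatter but keeps the progress stats.
    subprocess.call(
        f"ffmpeg {FFMPEG_LOGGING_MODE['INFO']} -y -i sample/2145_orig.mp4 -vn -ar 16000 -ac 1 out.wav",
        shell=True
    )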