Skip to content

Commit

Permalink
ruff fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Hyoung-Kyu Song committed Feb 21, 2024
1 parent 0c39038 commit 9ee4536
Show file tree
Hide file tree
Showing 11 changed files with 64 additions and 62 deletions.
2 changes: 1 addition & 1 deletion app.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
subprocess.call(f"wget --no-check-certificate -O {hp.inference.model.wav2lip.checkpoint} {LRS_ORIGINAL_URL}", shell=True)
if not Path(hp.inference.model.nota_wav2lip.checkpoint).exists() and LRS_COMPRESSED_URL is not None:
subprocess.call(f"wget --no-check-certificate -O {hp.inference.model.nota_wav2lip.checkpoint} {LRS_COMPRESSED_URL}", shell=True)

path_inference_sample = "sample.tar.gz"
if not Path(path_inference_sample).exists() and LRS_INFERENCE_SAMPLE is not None:
subprocess.call(f"wget --no-check-certificate -O {path_inference_sample} {LRS_INFERENCE_SAMPLE}", shell=True)
Expand Down
2 changes: 1 addition & 1 deletion config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@

hparams: DictConfig = OmegaConf.load("config/nota_wav2lip.yaml")

hparams_gradio: DictConfig = OmegaConf.load("config/gradio.yaml")
hparams_gradio: DictConfig = OmegaConf.load("config/gradio.yaml")
9 changes: 5 additions & 4 deletions download.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from nota_wav2lip.preprocess import get_cropped_face_from_lrs3_label


def parse_args():

parser = argparse.ArgumentParser(description="NotaWav2Lip: Get LRS3 video sample with the label text file")
Expand All @@ -13,21 +14,21 @@ def parse_args():
required=True,
help="Path of the label text file downloaded from https://mmai.io/datasets/lip_reading"
)

parser.add_argument(
'-o',
'--output-dir',
type=str,
default="sample_video_lrs3",
help="Output directory to save the result. Defaults: sample_video_lrs3"
)

parser.add_argument(
'--ignore-cache',
action='store_true',
help="Whether to force downloading and resampling video and overwrite pre-existing files"
)

args = parser.parse_args()

return args
Expand All @@ -40,4 +41,4 @@ def parse_args():
args.input_file,
video_root_dir=args.output_dir,
ignore_cache = args.ignore_cache
)
)
19 changes: 9 additions & 10 deletions inference.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import argparse
import os
import subprocess
from pathlib import Path
import argparse

from config import hparams as hp
from nota_wav2lip import Wav2LipModelComparisonDemo


LRS_ORIGINAL_URL = os.getenv('LRS_ORIGINAL_URL', None)
LRS_COMPRESSED_URL = os.getenv('LRS_COMPRESSED_URL', None)

Expand All @@ -26,47 +25,47 @@ def parse_args():
required=True,
help="Path of the audio file"
)

parser.add_argument(
'-v',
'--video-frame-input',
type=str,
required=True,
help="Input directory with face image sequence. We recommend to extract the face image sequence with `preprocess.py`."
)

parser.add_argument(
'-b',
'--bbox-input',
type=str,
help="Path of the file with bbox coordinates. We recommend to extract the json file with `preprocess.py`."
"If None, it pretends that the json file is located at the same directory with face images: {VIDEO_FRAME_INPUT}.with_suffix('.json')."
)

parser.add_argument(
'-m',
'--model',
choices=['wav2lip', 'nota_wav2lip'],
default='nota_wav2ilp',
help="Model for generating talking video. Defaults: wav2lip"
)

parser.add_argument(
'-o',
'--output-dir',
type=str,
default="result",
help="Output directory to save the result. Defaults: result"
)

parser.add_argument(
'-d',
'--device',
choices=['cpu', 'cuda'],
default='cpu',
help="Device setting for model inference. Defaults: cpu"
)

args = parser.parse_args()

return args
Expand All @@ -75,9 +74,9 @@ def parse_args():
args = parse_args()
bbox_input = args.bbox_input if args.bbox_input is not None \
else Path(args.video_frame_input).with_suffix('.json')

servicer = Wav2LipModelComparisonDemo(device=args.device, result_dir=args.output_dir, model_list=args.model)
servicer.update_audio(args.audio_input, name='a0')
servicer.update_video(args.video_frame_input, bbox_input, name='v0')

servicer.save_as_video('a0', 'v0', args.model)
servicer.save_as_video('a0', 'v0', args.model)
4 changes: 2 additions & 2 deletions nota_wav2lip/demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@

from config import hparams as hp
from nota_wav2lip.inference import Wav2LipInferenceImpl
from nota_wav2lip.video import AudioSlicer, VideoSlicer
from nota_wav2lip.util import FFMPEG_LOGGING_MODE
from nota_wav2lip.video import AudioSlicer, VideoSlicer


class Wav2LipModelComparisonDemo:
Expand Down Expand Up @@ -88,4 +88,4 @@ def save_as_video(self, audio_name, video_name, model_type):
video_frames_num = len(self.audio_dict[audio_name])
inference_fps = video_frames_num / inference_time

return output_video_path, inference_time, inference_fps
return output_video_path, inference_time, inference_fps
4 changes: 2 additions & 2 deletions nota_wav2lip/models/util.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from typing import Type, Dict
from typing import Dict, Type

import torch

from nota_wav2lip.models import Wav2LipBase, Wav2Lip, NotaWav2Lip
from nota_wav2lip.models import NotaWav2Lip, Wav2Lip, Wav2LipBase

MODEL_REGISTRY: Dict[str, Type[Wav2LipBase]] = {
'wav2lip': Wav2Lip,
Expand Down
2 changes: 1 addition & 1 deletion nota_wav2lip/preprocess/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from nota_wav2lip.preprocess.core import get_preprocessed_data
from nota_wav2lip.preprocess.lrs3_download import get_cropped_face_from_lrs3_label
from nota_wav2lip.preprocess.core import get_preprocessed_data
10 changes: 5 additions & 5 deletions nota_wav2lip/preprocess/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@

import cv2
import numpy as np
from tqdm import tqdm
from loguru import logger
from tqdm import tqdm

import face_detection
from nota_wav2lip.util import FFMPEG_LOGGING_MODE
Expand Down Expand Up @@ -81,18 +81,18 @@ def save_bbox_file(video_path, bbox_dict, output_path=None):

def get_preprocessed_data(video_path: Path):
video_path = Path(video_path)

image_sequence_dir = video_path.with_suffix('')
audio_path = video_path.with_suffix('.wav')
face_bbox_json_path = video_path.with_suffix('.json')

logger.info(f"Save 25 FPS video frames as image files ... will be saved at {video_path}")
save_video_frame(video_path=video_path, output_dir=image_sequence_dir)

logger.info(f"Save the audio as wav file ... will be saved at {audio_path}")
save_audio_file(video_path=video_path, output_path=audio_path) # bonus

# Load images, extract bboxes and save the coords(to directly use as array indicies)
logger.info(f"Extract face boxes and save the coords with json format ... will be saved at {face_bbox_json_path}")
results = face_detect(sorted(image_sequence_dir.glob("*.jpg")), pads=PADDING)
save_bbox_file(video_path, results, output_path=face_bbox_json_path)
save_bbox_file(video_path, results, output_path=face_bbox_json_path)
65 changes: 33 additions & 32 deletions nota_wav2lip/preprocess/lrs3_download.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
from pathlib import Path
from typing import TypedDict, Union, List, Tuple, Dict
import subprocess
import yt_dlp
import platform
import subprocess
from pathlib import Path
from typing import Dict, List, Tuple, TypedDict, Union

import cv2
from tqdm import tqdm
import numpy as np
import yt_dlp
from loguru import logger
from tqdm import tqdm

from nota_wav2lip.util import FFMPEG_LOGGING_MODE


class LabelInfo(TypedDict):
text: str
conf: int
Expand All @@ -19,22 +20,22 @@ class LabelInfo(TypedDict):

def frame_to_time(frame_id: int, fps=25) -> str:
seconds = frame_id / fps

hours = int(seconds // 3600)
seconds -= 3600 * hours

minutes = int(seconds // 60)
seconds -= 60 * minutes

seconds_int = int(seconds)
seconds_milli = int((seconds - int(seconds)) * 1e3)

return f"{hours:02d}:{minutes:02d}:{seconds_int:02d}.{seconds_milli:03d}" # HH:MM:SS.mmm

def save_audio_file(input_path, start_frame_id, to_frame_id, output_path=None):
input_path = Path(input_path)
output_path = output_path if output_path is not None else input_path.with_suffix('.wav')

ss = frame_to_time(start_frame_id)
to = frame_to_time(to_frame_id)
subprocess.call(
Expand All @@ -51,44 +52,44 @@ def merge_video_audio(video_path, audio_path, output_path):
def parse_lrs3_label(label_path) -> LabelInfo:
label_text = Path(label_path).read_text()
label_splitted = label_text.split('\n')

# Label validation
assert label_splitted[0].startswith("Text:")
assert label_splitted[1].startswith("Conf:")
assert label_splitted[2].startswith("Ref:")
assert label_splitted[4].startswith("FRAME")

label_info = LabelInfo(bbox_xywhn={})
label_info['text'] = label_splitted[0][len("Text: "):].strip()
label_info['conf'] = int(label_splitted[1][len("Conf: "):])
label_info['url'] = label_splitted[2][len("Ref: "):].strip()

for label_line in label_splitted[5:]:
bbox_splitted = [x.strip() for x in label_line.split('\t')]
if len(bbox_splitted) != 5:
continue
frame_index = int(bbox_splitted[0])
bbox_xywhn = tuple(map(float, bbox_splitted[1:]))
label_info['bbox_xywhn'][frame_index] = bbox_xywhn

return label_info

def _get_cropped_bbox(bbox_info_xywhn, original_width, original_height):

bbox_info = bbox_info_xywhn
x = bbox_info[0] * original_width
y = bbox_info[1] * original_height
w = bbox_info[2] * original_width
h = bbox_info[3] * original_height

x_min = max(0, int(x - 0.5 * w))
y_min = max(0, int(y))
x_max = min(original_width, int(x + 1.5 * w))
y_max = min(original_height, int(y + 1.5 * h))

cropped_width = x_max - x_min
cropped_height = y_max - y_min

if cropped_height > cropped_width:
offset = cropped_height - cropped_width
offset_low = min(x_min, offset // 2)
Expand All @@ -101,15 +102,15 @@ def _get_cropped_bbox(bbox_info_xywhn, original_width, original_height):
offset_high = min(offset - offset_low, original_width - y_max)
y_min -= offset_low
y_max += offset_high

return x_min, y_min, x_max, y_max

def _get_smoothened_boxes(bbox_dict, bbox_smoothen_window):
boxes = [np.array(bbox_dict[frame_id]) for frame_id in sorted(bbox_dict)]
for i in range(len(boxes)):
window = boxes[len(boxes) - bbox_smoothen_window:] if i + bbox_smoothen_window > len(boxes) else boxes[i:i + bbox_smoothen_window]
boxes[i] = np.mean(window, axis=0)

for idx, frame_id in enumerate(sorted(bbox_dict)):
bbox_dict[frame_id] = (np.rint(boxes[idx])).astype(int).tolist()
return bbox_dict
Expand All @@ -136,18 +137,18 @@ def _get_smoothen_xyxy_bbox(
original_height: int,
bbox_smoothen_window: int = 5
) -> Dict[int, Tuple[float, float, float, float]]:

label_bbox_xyxy: Dict[int, Tuple[float, float, float, float]] = {}
for frame_id in sorted(label_bbox_xywhn):
frame_bbox_xywhn = label_bbox_xywhn[frame_id]
bbox_xyxy = _get_cropped_bbox(frame_bbox_xywhn, original_width, original_height)
label_bbox_xyxy[frame_id] = bbox_xyxy

label_bbox_xyxy = _get_smoothened_boxes(label_bbox_xyxy, bbox_smoothen_window=bbox_smoothen_window)
return label_bbox_xyxy

def get_start_end_frame_id(
label_bbox_xywhn: Dict[int, Tuple[float, float, float, float]],
label_bbox_xywhn: Dict[int, Tuple[float, float, float, float]],
) -> Tuple[int, int]:
frame_ids = list(label_bbox_xywhn.keys())
start_frame_id = min(frame_ids)
Expand All @@ -169,31 +170,31 @@ def crop_video_with_bbox(
def frame_generator(cap):
if not cap.isOpened():
raise IOError("Error: Could not open video.")

while True:
ret, frame = cap.read()
if not ret:
break
yield frame

cap.release()

cap = cv2.VideoCapture(str(input_path))
original_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
original_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
label_bbox_xyxy = _get_smoothen_xyxy_bbox(label_bbox_xywhn, original_width, original_height, bbox_smoothen_window=bbox_smoothen_window)

fourcc = cv2.VideoWriter_fourcc(*'mp4v')
out = cv2.VideoWriter(str(output_path), fourcc, fps, (frame_width, frame_height))

for frame_id, frame in tqdm(enumerate(frame_generator(cap))):
if start_frame_id <= frame_id <= to_frame_id:
x_min, y_min, x_max, y_max = label_bbox_xyxy[frame_id]
x_min, y_min, x_max, y_max = label_bbox_xyxy[frame_id]

frame_cropped = frame[y_min:y_max, x_min:x_max]
frame_cropped = cv2.resize(frame_cropped, (frame_width, frame_height), interpolation=interpolation)
out.write(frame_cropped)

out.release()


Expand All @@ -220,7 +221,7 @@ def get_cropped_face_from_lrs3_label(
output_cropped_audio: Path = output_video.with_name(f"{output_video.stem}-{label_text_path.stem}-cropped.wav")
output_cropped_video: Path = output_video.with_name(f"{output_video.stem}-{label_text_path.stem}-cropped.mp4")
output_cropped_with_audio: Path = video_root_dir / output_video.with_name(f"{output_video.stem}-{label_text_path.stem}.mp4").name

if not output_video.exists() or ignore_cache:
youtube_ref = label_info['url']
logger.info(f"Download Youtube video(https://www.youtube.com/watch?v={youtube_ref}) ... will be saved at {output_video}")
Expand Down
Loading

0 comments on commit 9ee4536

Please sign in to comment.