diff --git a/examples/InfiniteTalk/opea_text2video/Dockerfile b/examples/InfiniteTalk/opea_text2video/Dockerfile new file mode 100755 index 0000000000..10929f72b2 --- /dev/null +++ b/examples/InfiniteTalk/opea_text2video/Dockerfile @@ -0,0 +1,29 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +# HABANA environment +FROM vault.habana.ai/gaudi-docker/1.21.4/ubuntu22.04/habanalabs/pytorch-installer-2.6.0 AS hpu +RUN useradd -m -s /bin/bash user && \ + mkdir -p /home/user && \ + mkdir -p /home/user/video && \ + chown -R user /home/user/ + +COPY src /home/user/text2video + +RUN apt update && apt install -y ffmpeg +RUN cd /home/user && git clone https://github.com/HabanaAI/optimum-habana-fork.git -b aice/v1.22.0 +RUN cd /home/user && git clone https://github.com/opea-project/GenAIComps.git +RUN chown -R user /home/user/text2video + +# Set environment variables +ENV LANG=en_US.UTF-8 +ENV PYTHONPATH=/home/user/text2video:/usr/lib/habanalabs/:/home/user/optimum-habana-fork/examples/InfiniteTalk/infinitetalk/:/home/user/GenAIComps/ + +ARG uvpip='uv pip install --system --no-cache-dir' +RUN pip install --no-cache-dir --upgrade pip setuptools uv && \ + $uvpip -r /home/user/text2video/requirements.txt && \ + $uvpip -r /home/user/text2video/requirements-infinitetalk.txt && \ + $uvpip 'git+https://github.com/HabanaAI/optimum-habana-fork.git@aice/v1.22.0' + +USER user +WORKDIR /home/user/text2video diff --git a/examples/InfiniteTalk/opea_text2video/README.md b/examples/InfiniteTalk/opea_text2video/README.md new file mode 100644 index 0000000000..8cac5da27d --- /dev/null +++ b/examples/InfiniteTalk/opea_text2video/README.md @@ -0,0 +1,362 @@ +# Text2Video 服务 + +OPEA Text-to-Video (文本到视频) 微服务,用于根据文本提示和音频输入生成视频。 + +## 概述 + +本项目提供 OPEA Text2Video 组件的独立部署方案。它通过 REST API 提供先进的视频生成能力,并针对英特尔 ® Habana® Gaudi® 加速器进行了优化。 + +## 主要特性 + +- **文生视频**: 支持文本提示和音频条件输入,生成动态视频。 +- **任务队列管理**: 高效管理并发请求,确保服务稳定性。 +- **HPU/Gaudi 优化**: 充分利用 Habana Gaudi 加速器的高性能计算能力。 +- **RESTful API**: 提供标准化 RESTful 接口及 OpenAPI 类似接口。 +- **容器化部署**: 支持 Docker 快速部署和环境隔离。 + +## 安装部署 + +### 1. 构建 Docker 镜像 + +在构建镜像前,请根据您的网络环境设置代理(如果需要)。 + +```bash +# 设置代理(可选) +export http_proxy="http://your-proxy-address:port" +export https_proxy="http://your-proxy-address:port" + +# 执行构建命令 +docker build -t text2video-gaudi:latest \ + --build-arg https_proxy=$https_proxy \ + --build-arg http_proxy=$http_proxy \ + -f Dockerfile . +``` + +### 2. 创建 Docker 容器实例 + +此命令将创建一个配置好 Gaudi 环境的容器实例。 + +```bash +# 环境变量配置 +NAME="video-gaudi-service" +IMG_NAME="text2video-gaudi:latest" +HTTP_PROXY="http://your-proxy-address:port" +HTTPS_PROXY="http://your-proxy-address:port" +HF_ENDPOINT="https://hf-mirror.com" # Hugging Face 模型下载镜像地址 + +# Gaudi 相关运行参数 +RUN_ARG="-e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add SYS_PTRACE --cap-add=sys_nice --cap-add=CAP_IPC_LOCK --ulimit memlock=-1:-1 --ipc=host --net=host --device=/dev:/dev -v /dev:/dev -v /sys/kernel/debug:/sys/kernel/debug" + +# 创建并启动容器 +echo "正在创建 Docker 实例: ${NAME}" +docker run -it --name ${NAME} \ + -p 9389:9389 \ + -v /mnt/disk2/HF_models:/hf \ + -e http_proxy=$HTTP_PROXY \ + -e https_proxy=$HTTPS_PROXY \ + -e HF_ENDPOINT=$HF_ENDPOINT \ + ${RUN_ARG} \ + --user root \ + --workdir=/home/user/text2video \ + ${IMG_NAME} /bin/bash +``` + +## 服务使用 + +### 1. 启动 Web 服务 + +在容器内部执行以下命令,启动 API 服务。 + +```bash +# 进入容器 +docker exec -it video-gaudi-service bash + +# 切换到工作目录并启动服务 +cd /home/user/text2video +python3 web_service.py > web.log 2>&1 & +``` + +### 2. 启动 Gaudi 作业服务 + +此服务负责处理视频生成任务。 + +下面的示例使用 4 卡来启动 Gaudi 作业服务。 + +```bash +PT_HPU_SYNC_LAUNCH=1 PT_HPU_GPU_MIGRATION=1 PT_HPU_LAZY_MODE=1 torchrun --nproc_per_node=4 --master-port 29502 --standalone job_service.py \ + --size infinitetalk-480 \ + --mode streaming \ + --motion_frame 9 \ + --offload_model False \ + --ulysses_size=4 > job.log 2>&1 & +``` + +--- + +## API 端点 + +> **内部网络使用说明:** 在公司内部网络调用此 API 时,请确保已正确设置 `no_proxy` 环境变量,以避免代理问题: +> +> ```bash +> export no_proxy="localhost,10.239.15.41,127.0.0.1,::1" +> ``` + +### 1. 创建视频 + +此端点基于文本提示、参考图像/视频和音频文件的组合来生成一个新视频。 + +- **端点:** `POST /v1/videos` +- **内容类型:** `multipart/form-data` + +#### 请求参数 + +| 参数 | 类型 | 必需 | 默认值 | 描述 | +| ------------------- | ------------- | :----: | ------- | --------------------------------------------------------------------- | +| `input_reference` | 文件 | **是** | N/A | 源参考图像或视频文件。 | +| `audio` / `audio[]` | 文件 / [文件] | **是** | `[]` | 用于生成的单个或多个音频文件。 | +| `prompt` | 字符串 | 否 | `None` | 用于指导视频生成的描述性文本提示。 | +| `audio_guide_scale` | 浮点数 | 否 | `5.0` | 控制音频对生成过程的影响程度。 | +| `audio_type` | 字符串 | 否 | `"add"` | 定义多个音频文件的处理方式。有效选项:`add` (叠加) 或 `para` (并行)。 | +| `fps` | 整数 | 否 | `24` | 生成视频的帧率(每秒帧数)。 | +| `shift` | 浮点数 | 否 | `5.0` | 一个特定的生成参数,用于控制视频动态。 | +| `steps` | 整数 | 否 | `50` | 推理步数。 | +| `seed` | 整数 | 否 | `42` | 用于可复现结果的随机种子。 | +| `guide_scale` | 浮点数 | 否 | `5.0` | 控制生成视频与提示的贴合程度。 | +| `logo_video` | 布尔值 | 否 | `False` | 如果为 `True`,将自动附加一个 Intel 标志视频。 | +| `seconds` | 整数 | 否 | `20` | **(注意)** 期望的视频长度(秒)。目前,实际长度由音频输入决定。 | + +#### 响应体 + +成功的请求会将作业加入队列,并返回一个具有以下结构的 JSON 对象: + +| 参数 | 类型 | 描述 | +| ---------------- | ------ | -------------------------------------------------------------------------- | +| `id` | 字符串 | 视频生成作业的唯一标识符。 | +| `object` | 字符串 | 对象类型,始终为 `"video"`。 | +| `model` | 字符串 | 用于生成的模型 (例如, `"InfiniteTalk"`)。 | +| `status` | 字符串 | 作业的当前状态 (`queued`, `processing`, `completed`, `deleted`, `error`)。 | +| `progress` | 整数 | 任务的大致完成百分比。 | +| `created_at` | 整数 | 作业创建时的 Unix 时间戳(秒)。 | +| `estimated_time` | 整数 | 预计完成时间(分钟)。 | +| `queue_length` | 整数 | 在此作业之前排队的作业数量。 | +| `duration` | 整数 | 生成视频所花费的时间(秒)。 | +| `seconds` | 整数 | 生成视频的最终时长(秒)。 | +| `error` | 字符串 | 解释失败原因的消息(如果有)。 | + +
+响应示例 + +```json +{ + "id": "video_1766454718_2556", + "object": "video", + "model": "InfinteTalk", + "status": "queued", + "progress": 0, + "created_at": 1766454718, + "estimated_time": 14, + "queue_length": 1, + "duration": 0, + "seconds": "3", + "error": "" +} +``` + +
+ +--- + +### 2. 获取视频状态 + +检索视频生成作业的当前状态和进度。 + +- **端点:** `GET /v1/videos/{video_id}` + +#### 响应体 + +返回与创建端点相同的 JSON 对象,但 `status` 和 `progress` 字段会更新。当 `status` 为 `"completed"` 时,表示视频已准备就绪。 + +
+完成状态响应示例 + +```json +{ + "id": "video_1766454718_2556", + "object": "video", + "model": "InfinteTalk", + "status": "completed", + "progress": 100, + "created_at": 1766454718, + "estimated_time": 0, + "queue_length": 0, + "duration": 430, + "seconds": "3", + "error": "" +} +``` + +
+ +--- + +### 3. 获取视频内容 + +下载生成的视频文件。 + +- **端点:** `GET /v1/videos/{video_id}/content` + +此端点返回原始视频数据 (MIME 类型 `video/mp4`),可以直接保存到文件中。 + +--- + +### 4. 删除视频 + +从服务器删除视频生成作业及其关联文件。 + +- **端点:** `DELETE /v1/videos/{video_id}` + +> **注意:** 状态为 `processing` (处理中) 的作业无法被删除。 + +#### 响应体 + +成功删除后,服务器会返回一个包含作业最终元数据和 `"deleted"` 状态的 JSON 对象。 + +
+删除状态响应示例 + +```json +{ + "id": "video_1721105333_1234", + "model": "InfinteTalk", + "status": "deleted", + "progress": 0, + "created_at": 1721105333, + "seconds": "15", + "duration": 14, + "estimated_time": 0, + "queue_length": 0, + "error": "" +} +``` + +
+ +--- + +## API 使用示例 + +### 示例 1: 提示 + 音频 + 图像 + +```bash +curl -X POST "http://10.239.15.41:9389/v1/videos" \ + -H "Content-Type: multipart/form-data" \ + -F "prompt=一个女人在录音棚里对着专业麦克风热情地唱歌..." \ + -F "input_reference=@examples/single/ref_image.png" \ + -F "audio=@examples/single/1.wav" +``` + +### 示例 2: 提示 + 音频 + 视频 + +```bash +curl -X POST "http://10.239.15.41:9389/v1/videos" \ + -H "Content-Type: multipart/form-data" \ + -F "prompt=一个男人在说话" \ + -F "input_reference=@examples/single/ref_video.mp4" \ + -F "audio=@examples/single/1.wav" +``` + +### 示例 3: 提示 + 多个音频 + 图像 (并行模式) + +```bash +curl -X POST "http://10.239.15.41:9389/v1/videos" \ + -H "Content-Type: multipart/form-data" \ + -F "prompt=在一个轻松、亲密的环境中,一个男人和一个女人正在进行一场真诚的对话..." \ + -F "input_reference=@examples/multi/ref_img.png" \ + -F "audio_type=para" \ + -F "audio[]=@examples/multi/1-man.WAV" \ + -F "audio[]=@examples/multi/1-woman.WAV" +``` + +### 示例 4: 检查状态并下载 + +```bash +# 1. 使用创建请求返回的 ID 检查作业状态 +curl http://10.239.15.41:9389/v1/videos/video_1765526104_4523 + +# 2. 当状态变为 "completed" 后,下载视频 +curl http://10.239.15.41:9389/v1/videos/video_1765526104_4523/content -o video.mp4 +``` + +### 示例 5: 删除视频 + +```bash +curl -X DELETE http://10.239.15.41:9389/v1/videos/video_1765526104_4523 +``` + +--- + +## 错误处理 + +API 返回标准的 HTTP 状态码和一致的 JSON 错误体,以帮助诊断问题。 + +### 通用错误格式 + +```json +{ + "error": { + "message": "错误的详细描述。", + "code": "HTTP 状态码字符串 (例如, '400')。" + } +} +``` + +### 常见错误 + +| 状态码 | 错误类型 | 常见触发原因 | +| :----- | -------------------------- | ------------------------------------ | +| `400` | **错误请求 (Bad Request)** | 缺少必需参数、参数值无效或文件损坏。 | +| `404` | **未找到 (Not Found)** | 请求的 `video_id` 不存在。 | +| `500` | **内部服务器错误** | 处理过程中发生意外的服务器端故障。 | + +
+查看错误响应示例 + +**400 错误请求示例:** + +```json +{ + "error": { + "message": "无效的参数类型:'seconds' 参数必须大于 0。", + "code": "400" + } +} +``` + +**404 未找到示例:** + +```json +{ + "error": { + "message": "ID 为 video_1721105333_1234 的视频未找到。", + "code": "404" + } +} +``` + +**500 内部服务器错误示例:** + +```json +{ + "error": { + "message": "内部服务器错误:组件加载器未初始化。", + "code": "500" + } +} +``` + +
+ +## 相关链接 + +- [OPEA 项目](https://github.com/opea-project/GenAIComps) +- [项目文档](https://opea-project.github.io/) diff --git a/examples/InfiniteTalk/opea_text2video/src/__init__.py b/examples/InfiniteTalk/opea_text2video/src/__init__.py new file mode 100755 index 0000000000..e69de29bb2 diff --git a/examples/InfiniteTalk/opea_text2video/src/component.py b/examples/InfiniteTalk/opea_text2video/src/component.py new file mode 100755 index 0000000000..6739bfd978 --- /dev/null +++ b/examples/InfiniteTalk/opea_text2video/src/component.py @@ -0,0 +1,198 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import os +import time +import random +import json +import fcntl +import librosa + +from enum import Enum +from pydantic import BaseModel +from typing import Optional, List, Union +from fastapi import Form, File, UploadFile +from comps import CustomLogger, OpeaComponent, OpeaComponentRegistry + +logger = CustomLogger("opea_Text2Video") + + +class ServiceType(Enum): + """The enum of a service type.""" + TEXT2VIDEO = 1 + + +class Text2VideoInput: + def __init__( + self, + prompt: str = Form(None), + input_reference: Optional[UploadFile] = File(None), + audio: Union[UploadFile, List[UploadFile]] = File(None), + audio_guide_scale: Optional[float] = Form(5.0), + audio_type: Optional[str] = Form("add"), + model: Optional[str] = Form(None), + seconds: Optional[int] = Form(4), + fps: Optional[int] = Form(25), + shift: Optional[float] = Form(5.0), + steps: Optional[int] = Form(40), + seed: Optional[int] = Form(42), + guide_scale: Optional[float] = Form(5.0), + size: Optional[str] = Form("720x1280"), + logo_video: Optional[bool] = Form("False") + ): + self.prompt = prompt + self.input_reference = input_reference + self.audio = audio + self.audio_guide_scale = audio_guide_scale + self.audio_type = audio_type + self.model = model + self.seconds = seconds + self.fps = fps + self.shift = shift + self.steps = steps + self.seed = seed + self.guide_scale = guide_scale + self.size = size + self.logo_video = logo_video + + +class Text2VideoOutput(BaseModel): + id: str + object: str = "video" + model: str = None + status: str + progress: int + created_at: int + estimated_time: int + queue_length: int + duration: int + seconds: str + error: str = "" + + +def get_audio_duration(file_path): + return librosa.get_duration(path=file_path) + + +@OpeaComponentRegistry.register("OPEA_TEXT2VIDEO") +class OpeaText2Video(OpeaComponent): + """A specialized Text2Video component for video generation.""" + + def __init__( + self, + name: str, + description: str, + config: dict = None, + video_dir: str = "/home/user/video" + ): + """ + Initializes the OpeaText2Video component. + + Args: + name (str): The name of the component. + description (str): A description of the component. + config (dict, optional): Configuration dictionary. Defaults to None. + """ + super().__init__(name, ServiceType.TEXT2VIDEO.name.lower(), description, config) + self.video_dir = video_dir + os.makedirs(self.video_dir, exist_ok=True) + if not self.check_health(): + logger.error("OpeaText2Video health check failed upon initialization.") + + async def invoke(self, input: Text2VideoInput) -> Text2VideoOutput: + """ + Generates a video based on the provided text prompt. + + Args: + input (Text2VideoInput): The input data containing the prompt and other parameters. + """ + created = time.time() + job_id = f"video_{int(created)}_{random.randint(1000, 9999)}" + job_dir = os.path.join(self.video_dir, job_id) + os.makedirs(job_dir, exist_ok=True) + input_json = os.path.join(job_dir, "input.json") + input_json_content = {} + + if input.prompt and len(input.prompt) > 0: + input_json_content["prompt"] = input.prompt + + if input.audio_type: + input_json_content["audio_type"] = input.audio_type + + if input.input_reference: + image_file = os.path.join(job_dir, input.input_reference.filename) + input_json_content["cond_video"] = image_file + contents = await input.input_reference.read() + with open(image_file, "wb") as img_f: + img_f.write(contents) + + audio_durations = [] + if input.audio and isinstance(input.audio, list): + audio = {} + for idx, audio_file in enumerate(input.audio): + audio_path = os.path.join(job_dir, audio_file.filename) + audio[f"person{idx+1}"] = audio_path + contents = await audio_file.read() + with open(audio_path, "wb") as audio_f: + audio_f.write(contents) + audio_durations.append(get_audio_duration(audio_path)) + + input_json_content["cond_audio"] = audio + + with open(input_json, "w") as f: + json.dump(input_json_content, f, indent=4) + + seconds = int(min(audio_durations)) if audio_durations else 20 + logger.info(f"set audio seconds to {seconds} and audio durations for job {job_id}: {audio_durations}") + if seconds <= 0: + raise ValueError("The provided audio files have non-positive durations.") + + status = "queued" + quality = "standard" + generate_duration = 0 + start_time = 0 + end_time = 0 + job = [ + job_id, + status, + int(created), + input.prompt, + seconds, + input.size, + quality, + input.fps, + input.shift, + input.steps, + input.guide_scale, + input.audio_guide_scale, + input.seed, + input.logo_video, + generate_duration, + start_time, + end_time, + "" + ] + + sep = os.getenv("SEP", "##$##") + line = sep.join(map(str, job)) + "\n" + job_file = os.path.join(self.video_dir, "job.txt") + with open(job_file, "a") as f: + fcntl.flock(f, fcntl.LOCK_EX) + try: + f.write(line) + f.flush() + os.fsync(f.fileno()) + finally: + fcntl.flock(f, fcntl.LOCK_UN) + + logger.info(f"Job {job_id} queued with prompt: {input.prompt}") + return job_id + + def check_health(self) -> bool: + """ + Checks if the model pipeline is initialized. + + Returns: + bool: True if the pipeline is ready, False otherwise. + """ + return True diff --git a/examples/InfiniteTalk/opea_text2video/src/job_service.py b/examples/InfiniteTalk/opea_text2video/src/job_service.py new file mode 100755 index 0000000000..f3e7f72e66 --- /dev/null +++ b/examples/InfiniteTalk/opea_text2video/src/job_service.py @@ -0,0 +1,673 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +import re +import wan +import subprocess +import torch +import random +import argparse +import logging +import os +import sys +import json +import time +import warnings +import librosa +import numpy as np +import torch.distributed as dist +import soundfile as sf +import pyloudnorm as pyln +import fcntl +import imageio + +from tqdm import tqdm +from einops import rearrange +from kokoro import KPipeline +from src.audio_analysis.wav2vec2 import Wav2Vec2Model +from transformers import Wav2Vec2FeatureExtractor +from wan.utils.segvideo import shot_detect +from wan.utils.multitalk_utils import save_video_ffmpeg, cache_video +from wan.utils.utils import str2bool, is_video, split_wav_librosa +from wan.configs import SIZE_CONFIGS, SUPPORTED_SIZES, WAN_CONFIGS + + +warnings.filterwarnings("ignore") + + +def save_video_with_logo(gen_video_samples, save_path, vocal_audio_list, fps=25, quality=5, high_quality_save=False): + + def save_video(frames, save_path, fps, quality=9, ffmpeg_params=None): + writer = imageio.get_writer( + save_path, fps=fps, quality=quality, ffmpeg_params=ffmpeg_params + ) + for frame in tqdm(frames, desc="Saving video"): + frame = np.array(frame) + writer.append_data(frame) + writer.close() + save_path_tmp = save_path + "-temp.mp4" + + if high_quality_save: + cache_video( + tensor=gen_video_samples.unsqueeze(0), + save_file=save_path_tmp, + fps=fps, + nrow=1, + normalize=True, + value_range=(-1, 1) + ) + else: + video_audio = (gen_video_samples+1)/2 # C T H W + video_audio = video_audio.permute(1, 2, 3, 0).cpu().numpy() + video_audio = np.clip(video_audio * 255, 0, 255).astype(np.uint8) # to [0, 255] + save_video(video_audio, save_path_tmp, fps=fps, quality=quality) + + # crop audio according to video length + C, T, H, W = gen_video_samples.shape + duration = T / fps + save_path_crop_audio = save_path + "-cropaudio.wav" + final_command = [ + "ffmpeg", + "-i", + vocal_audio_list[0], + "-t", + f'{duration}', + save_path_crop_audio, + ] + subprocess.run(final_command, check=True) + logo_w = 1280 + logo_h = 720 + if W / H > logo_w / logo_h: + ratio = H / logo_h + resized_logo_h = H + resized_logo_w = logo_w * ratio + pad_w = (W - resized_logo_w) / 2 + pad_h = 0 + else: + ratio = W / logo_w + resized_logo_w = W + resized_logo_h = logo_h * ratio + pad_h = (H - resized_logo_h) / 2 + pad_w = 0 + save_path = save_path + ".mp4" + if high_quality_save: + final_command = [ + "ffmpeg", + "-y", + "-i", save_path_tmp, + "-i", save_path_crop_audio, + "-c:v", "libx264", + "-crf", "0", + "-preset", "veryslow", + "-c:a", "aac", + "-shortest", + save_path, + ] + subprocess.run(final_command, check=True) + os.remove(save_path_tmp) + os.remove(save_path_crop_audio) + else: + final_command = [ + "ffmpeg", + "-y", + "-i", + save_path_tmp, + "-i", + save_path_crop_audio, + "-i", + "/home/user/video/intel_logo.mp4", + "-filter_complex", + f"[2:v]scale=w={int(resized_logo_w)}:h={int(resized_logo_h)}," + f"setdar=0x0,pad={W}:{H}:{int(pad_w)}:{int(pad_h)}:black[2v],[0:v][1:a][2v][2:a]concat=n=2:v=1:a=1[v][a]", + "-map", "[v]", + "-map", "[a]", + "-c:v", + "libx264", + "-c:a", + "aac", + "-shortest", + "-vsync", + "passthrough", + save_path, + ] + subprocess.run(final_command, check=True) + os.remove(save_path_tmp) + os.remove(save_path_crop_audio) + + +def find_max_matching_frame(max_value: int, default_value: int) -> int: + """ + Finds the largest integer less than or equal to max_value + that can be expressed in the form 4*n + 1. + + Args: + max_value: The upper bound for the search. + + Returns: + The largest number matching the pattern, or None if no such + number exists within the given limit (e.g., if max_value < 1). + """ + # The smallest number of the form 4*n + 1 (for n>=0) is 1. + if max_value < 1: + return default_value + + # Start from max_value and check downwards. + for number in range(max_value, 0, -1): + # A number is of the form 4*n + 1 if its remainder when divided by 4 is 1. + if number % 4 == 1: + return number + + return default_value # Should not be reached if max_value >= 1 + + +def _validate_args(args): + # Basic check + assert args.ckpt_dir is not None, "Please specify the checkpoint directory." + assert args.task in WAN_CONFIGS, f"Unsupport task: {args.task}" + args.base_seed = args.base_seed if args.base_seed >= 0 else random.randint(0, 99999999) + # Size check + assert args.size in SUPPORTED_SIZES[args.task], (f"Unsupport size {args.size} for task {args.task}, supported sizes are: {', '.join(SUPPORTED_SIZES[args.task])}") + + +def _parse_args(): + parser = argparse.ArgumentParser(description="Generate a image or video from a text prompt or image using Wan") + parser.add_argument("--task", type=str, default="infinitetalk-14B", choices=list(WAN_CONFIGS.keys()), help="The task to run.") + parser.add_argument("--size", type=str, default="infinitetalk-480", choices=list(SIZE_CONFIGS.keys()), help="The buckget size of the generated video. The aspect ratio of the output video will follow that of the input image.",) + parser.add_argument("--max_frame_num", type=int, default=100000, help="The max frame lenght of the generated video.") + parser.add_argument("--ckpt_dir", type=str, default="/hf/Wan2.1-I2V-14B-480P", help="The path to the Wan checkpoint directory.") + parser.add_argument("--infinitetalk_dir", type=str, default="/hf/InfiniteTalk/single/infinitetalk.safetensors", help="The path to the InfiniteTalk checkpoint directory.") + parser.add_argument("--wav2vec_dir", type=str, default="/hf/chinese-wav2vec2-base", help="The path to the wav2vec checkpoint directory.") + parser.add_argument("--quant_dir", type=str, default=None, help="The path to the Wan quant checkpoint directory.") + parser.add_argument("--dit_path", type=str, default=None, help="The path to the Wan checkpoint directory.") + parser.add_argument("--base_seed", type=int, default=42, help="The seed to use for generating the image or video.") + parser.add_argument("--lora_dir", type=str, nargs="+", default=None, help="The paths to the LoRA checkpoint files.") + parser.add_argument("--lora_scale", type=float, nargs="+", default=[1.2], help="Controls how much to influence the outputs with the LoRA parameters. Accepts multiple float values.") + parser.add_argument("--offload_model", type=str2bool, default=None, help="Whether to offload the model to CPU after each model forward, reducing GPU memory usage.") + parser.add_argument("--ulysses_size", type=int, default=1, help="The size of the ulysses parallelism in DiT.") + parser.add_argument("--ring_size", type=int, default=1, help="The size of the ring attention parallelism in DiT.") + parser.add_argument("--t5_fsdp", action="store_true", default=False, help="Whether to use FSDP for T5.") + parser.add_argument("--t5_cpu", action="store_true", default=False, help="Whether to place T5 model on CPU.") + parser.add_argument("--dit_fsdp", action="store_true", default=False, help="Whether to use FSDP for DiT.") + parser.add_argument("--mode", type=str, default="clip", choices=["clip", "streaming"], help="clip: generate one video chunk, streaming: long video generation") + parser.add_argument("--audio_mode", type=str, default="localfile", choices=["localfile", "tts"], help="localfile: audio from local wav file, tts: audio from TTS") + parser.add_argument("--motion_frame", type=int, default=9, help="Driven frame length used in the mode of long video genration.") + parser.add_argument("--num_persistent_param_in_dit", type=int, default=None, required=False, help="Maximum parameter quantity retained in video memory, small number to reduce VRAM required") + parser.add_argument("--use_teacache", action="store_true", default=False, help="Enable teacache for video generation.") + parser.add_argument("--teacache_thresh", type=float, default=0.2, help="Threshold for teacache.") + parser.add_argument("--use_apg", action="store_true", default=False, help="Enable adaptive projected guidance for video generation (APG).") + parser.add_argument("--apg_momentum", type=float, default=-0.75, help="Momentum used in adaptive projected guidance (APG).") + parser.add_argument("--apg_norm_threshold", type=float, default=55, help="Norm threshold used in adaptive projected guidance (APG).") + parser.add_argument("--color_correction_strength", type=float, default=1.0, help="strength for color correction [0.0 -- 1.0].") + parser.add_argument("--scene_seg", action="store_true", default=False, help="Enable scene segmentation for input video.") + parser.add_argument("--quant", type=str, default=None, help="Quantization type, must be 'int8' or 'fp8'.") + parser.add_argument("--video_dir", type=str, default="/home/user/video", help="Video output directory.") + parser.add_argument("--sep", type=str, default="$###$", help="Video output directory.") + + args = parser.parse_args() + _validate_args(args) + return args + + +def custom_init(device, wav2vec): + audio_encoder = Wav2Vec2Model.from_pretrained(wav2vec, local_files_only=True, attn_implementation="eager").to(device) + audio_encoder.feature_extractor._freeze_parameters() + wav2vec_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(wav2vec, local_files_only=True) + return wav2vec_feature_extractor, audio_encoder + + +def loudness_norm(audio_array, sr=16000, lufs=-23): + meter = pyln.Meter(sr) + loudness = meter.integrated_loudness(audio_array) + if abs(loudness) > 100: + return audio_array + normalized_audio = pyln.normalize.loudness(audio_array, loudness, lufs) + return normalized_audio + + +def audio_prepare_multi(left_path, right_path, audio_type, sample_rate=16000): + if not (left_path == "None" or right_path == "None"): + human_speech_array1 = audio_prepare_single(left_path) + human_speech_array2 = audio_prepare_single(right_path) + elif left_path == "None": + human_speech_array2 = audio_prepare_single(right_path) + human_speech_array1 = np.zeros(human_speech_array2.shape[0]) + elif right_path == "None": + human_speech_array1 = audio_prepare_single(left_path) + human_speech_array2 = np.zeros(human_speech_array1.shape[0]) + + if audio_type == "para": + new_human_speech1 = human_speech_array1 + new_human_speech2 = human_speech_array2 + elif audio_type == "add": + new_human_speech1 = np.concatenate( + [human_speech_array1[: human_speech_array1.shape[0]], np.zeros(human_speech_array2.shape[0])] + ) + new_human_speech2 = np.concatenate( + [np.zeros(human_speech_array1.shape[0]), human_speech_array2[: human_speech_array2.shape[0]]] + ) + sum_human_speechs = new_human_speech1 + new_human_speech2 + return new_human_speech1, new_human_speech2, sum_human_speechs + + +def _init_logging(rank): + # logging + if rank == 0: + # set format + logging.basicConfig( + level=logging.INFO, + format="[%(asctime)s] %(levelname)s: %(message)s", + handlers=[logging.StreamHandler(stream=sys.stdout)], + ) + else: + logging.basicConfig(level=logging.ERROR) + + +def get_embedding(speech_array, wav2vec_feature_extractor, audio_encoder, sr=16000, device="cpu"): + audio_duration = len(speech_array) / sr + video_length = audio_duration * 25 # Assume the video fps is 25 + + # wav2vec_feature_extractor + audio_feature = np.squeeze(wav2vec_feature_extractor(speech_array, sampling_rate=sr).input_values) + audio_feature = torch.from_numpy(audio_feature).float().to(device=device) + audio_feature = audio_feature.unsqueeze(0) + + # audio encoder + with torch.no_grad(): + embeddings = audio_encoder(audio_feature, seq_len=int(video_length), output_hidden_states=True) + + if len(embeddings) == 0: + print("Fail to extract audio embedding") + return None + + audio_emb = torch.stack(embeddings.hidden_states[1:], dim=1).squeeze(0) + audio_emb = rearrange(audio_emb, "b s d -> s b d") + + audio_emb = audio_emb.cpu().detach() + return audio_emb + + +def extract_audio_from_video(filename, sample_rate): + raw_audio_path = filename.split("/")[-1].split(".")[0] + ".wav" + ffmpeg_command = [ + "ffmpeg", + "-y", + "-i", + str(filename), + "-vn", + "-acodec", + "pcm_s16le", + "-ar", + "16000", + "-ac", + "2", + str(raw_audio_path), + ] + subprocess.run(ffmpeg_command, check=True) + human_speech_array, sr = librosa.load(raw_audio_path, sr=sample_rate) + human_speech_array = loudness_norm(human_speech_array, sr) + os.remove(raw_audio_path) + + return human_speech_array + + +def audio_prepare_single(audio_path, sample_rate=16000): + ext = os.path.splitext(audio_path)[1].lower() + if ext in [".mp4", ".mov", ".avi", ".mkv"]: + human_speech_array = extract_audio_from_video(audio_path, sample_rate) + return human_speech_array + else: + human_speech_array, sr = librosa.load(audio_path, sr=sample_rate) + human_speech_array = loudness_norm(human_speech_array, sr) + return human_speech_array + + +def process_tts_single(text, save_dir, voice1): + s1_sentences = [] + + pipeline = KPipeline(lang_code="a", repo_id="weights/Kokoro-82M") + + voice_tensor = torch.load(voice1, weights_only=True) + generator = pipeline( + text, + voice=voice_tensor, # <= change voice here + speed=1, + split_pattern=r"\n+", + ) + audios = [] + for i, (gs, ps, audio) in enumerate(generator): + audios.append(audio) + audios = torch.concat(audios, dim=0) + s1_sentences.append(audios) + s1_sentences = torch.concat(s1_sentences, dim=0) + save_path1 = f"{save_dir}/s1.wav" + sf.write(save_path1, s1_sentences, 24000) # save each audio file + s1, _ = librosa.load(save_path1, sr=16000) + return s1, save_path1 + + +def process_tts_multi(text, save_dir, voice1, voice2): + pattern = r"\(s(\d+)\)\s*(.*?)(?=\s*\(s\d+\)|$)" + matches = re.findall(pattern, text, re.DOTALL) + + s1_sentences = [] + s2_sentences = [] + + pipeline = KPipeline(lang_code="a", repo_id="weights/Kokoro-82M") + for idx, (speaker, content) in enumerate(matches): + if speaker == "1": + voice_tensor = torch.load(voice1, weights_only=True) + generator = pipeline( + content, + voice=voice_tensor, # <= change voice here + speed=1, + split_pattern=r"\n+", + ) + audios = [] + for i, (gs, ps, audio) in enumerate(generator): + audios.append(audio) + audios = torch.concat(audios, dim=0) + s1_sentences.append(audios) + s2_sentences.append(torch.zeros_like(audios)) + elif speaker == "2": + voice_tensor = torch.load(voice2, weights_only=True) + generator = pipeline( + content, + voice=voice_tensor, # <= change voice here + speed=1, + split_pattern=r"\n+", + ) + audios = [] + for i, (gs, ps, audio) in enumerate(generator): + audios.append(audio) + audios = torch.concat(audios, dim=0) + s2_sentences.append(audios) + s1_sentences.append(torch.zeros_like(audios)) + + s1_sentences = torch.concat(s1_sentences, dim=0) + s2_sentences = torch.concat(s2_sentences, dim=0) + sum_sentences = s1_sentences + s2_sentences + save_path1 = f"{save_dir}/s1.wav" + save_path2 = f"{save_dir}/s2.wav" + save_path_sum = f"{save_dir}/sum.wav" + sf.write(save_path1, s1_sentences, 24000) # save each audio file + sf.write(save_path2, s2_sentences, 24000) + sf.write(save_path_sum, sum_sentences, 24000) + + s1, _ = librosa.load(save_path1, sr=16000) + s2, _ = librosa.load(save_path2, sr=16000) + # sum, _ = librosa.load(save_path_sum, sr=16000) + return s1, s2, save_path_sum + + +def update_job(job_processed, args): + # If a job was processed, rewrite the entire job file + job_file = os.path.join(args.video_dir, "job.txt") + sep = args.sep + if job_processed: + with open(job_file, "r+", encoding="utf-8") as f: + fcntl.flock(f, fcntl.LOCK_EX) + try: + # Re-read the file to get the latest content before writing + f.seek(0) + lines_before_write = [line.strip() for line in f if line.strip()] + + # Find the job by ID and update it + job_id_to_update = job_processed[0] + found = False + for i, line in enumerate(lines_before_write): + if line.startswith(job_id_to_update + sep): + lines_before_write[i] = sep.join(map(str, job_processed)) + found = True + break + + # If the job was somehow removed from the file, add the new status at the end + if not found: + lines_before_write.append(sep.join(map(str, job_processed))) + + # Write the updated content back to the file + f.seek(0) + f.truncate() + for line in lines_before_write: + f.write(line + "\n") + finally: + fcntl.flock(f, fcntl.LOCK_UN) + + +def generate(args): + rank = int(os.getenv("RANK", 0)) + world_size = int(os.getenv("WORLD_SIZE", 1)) + local_rank = int(os.getenv("LOCAL_RANK", 0)) + device = local_rank + _init_logging(rank) + + if args.offload_model is None: + args.offload_model = False if world_size > 1 else True + logging.info(f"offload_model is not specified, set to {args.offload_model}.") + if world_size > 1: + torch.cuda.set_device(local_rank) + dist.init_process_group(backend="nccl", init_method="env://", rank=rank, world_size=world_size) + else: + assert not (args.t5_fsdp or args.dit_fsdp), (f"t5_fsdp and dit_fsdp are not supported in non-distributed environments.") + assert not (args.ulysses_size > 1 or args.ring_size > 1), (f"context parallel are not supported in non-distributed environments.") + + if args.ulysses_size > 1 or args.ring_size > 1: + assert args.ulysses_size * args.ring_size == world_size, (f"The number of ulysses_size and ring_size should be equal to the world size.") + assert args.ulysses_size * args.ring_size <= 8, (f"Currently, sequence parallel degree should be no larger than 8.") # TODO: remove this limit in the future + from wan.distributed.parallel_state import ( + init_distributed_environment, + initialize_model_parallel, + ) + + init_distributed_environment(rank=dist.get_rank(), world_size=dist.get_world_size()) + + initialize_model_parallel( + sequence_parallel_degree=dist.get_world_size(), + ring_degree=args.ring_size, + ulysses_degree=args.ulysses_size, + ) + + cfg = WAN_CONFIGS[args.task] + if args.ulysses_size > 1: + assert cfg.num_heads % args.ulysses_size == 0, ( + f"`{cfg.num_heads=}` cannot be divided evenly by `{args.ulysses_size=}`." + ) + + logging.info(f"Generation job args: {args}") + logging.info(f"Generation model config: {cfg}") + + if dist.is_initialized(): + base_seed = [args.base_seed] if rank == 0 else [None] + dist.broadcast_object_list(base_seed, src=0) + args.base_seed = base_seed[0] + + assert args.task == "infinitetalk-14B", "You should choose infinitetalk in args.task." + + logging.info("Creating infinitetalk pipeline.") + wan_i2v = wan.InfiniteTalkPipeline( + config=cfg, + checkpoint_dir=args.ckpt_dir, + quant_dir=args.quant_dir, + device_id=device, + rank=rank, + t5_fsdp=args.t5_fsdp, + dit_fsdp=args.dit_fsdp, + use_usp=(args.ulysses_size > 1 or args.ring_size > 1), + t5_cpu=args.t5_cpu, + lora_dir=args.lora_dir, + lora_scales=args.lora_scale, + quant=args.quant, + dit_path=args.dit_path, + infinitetalk_dir=args.infinitetalk_dir, + ) + + if args.num_persistent_param_in_dit is not None: + wan_i2v.vram_management = True + wan_i2v.enable_vram_management(num_persistent_param_in_dit=args.num_persistent_param_in_dit) + + # Initialize models once before the loop to prevent race conditions + wav2vec_feature_extractor, audio_encoder = custom_init("cpu", args.wav2vec_dir) + + job_file = os.path.join(args.video_dir, "job.txt") + while True: + try: + time.sleep(10.0) + if not os.path.exists(job_file): + time.sleep(1.0) + continue + + job_to_process = None + if rank == 0: + with open(job_file, "r+", encoding="utf-8") as f: + fcntl.flock(f, fcntl.LOCK_EX) + try: + lines = [line.strip() for line in f if line.strip()] + updated_lines = [] + job_found = False + for line in lines: + parts = line.strip().split(args.sep) + if not job_found and len(parts) >= 17 and parts[1] == "queued": + job_found = True + parts[1] = "processing" # Mark as processing + parts[15] = str(int(time.time())) # Set start time + job_to_process = parts + updated_lines.append(args.sep.join(map(str, parts)) + "\n") + else: + updated_lines.append(line + "\n") + + if job_found: + f.seek(0) + f.truncate() + f.writelines(updated_lines) + finally: + fcntl.flock(f, fcntl.LOCK_UN) + + if world_size > 1: + job_list = [job_to_process] if rank == 0 else [None] + dist.broadcast_object_list(job_list, src=0) + job_to_process = job_list[0] + + if job_to_process: + try: + id, status, created_str, prompt, seconds, size, quality, fps, shift, steps, guide_scale, audio_guide_scale, seed, logo_video, generate_duration, start_time, end_time, *error_msg_parts = job_to_process + generate_start_time = float(start_time) + + fps = 25 + user_frames = int(seconds) * fps + 1 + num_frames = 81 if user_frames >= 81 else find_max_matching_frame(user_frames, 5) + generated_list = [] + job_dir = os.path.join(args.video_dir, id) + os.makedirs(job_dir, exist_ok=True) + input_json = os.path.join(job_dir, "input.json") + audio_save_dir = os.path.join(job_dir, "audio") + save_file = os.path.join(job_dir, "output") + with open(input_json, "r", encoding="utf-8") as f: + input_data = json.load(f) + + audio_save_dir = os.path.join(audio_save_dir, input_data["cond_video"].split("/")[-1].split(".")[0]) + os.makedirs(audio_save_dir, exist_ok=True) + + conds_list = [] + + if args.scene_seg and is_video(input_data["cond_video"]): + time_list, cond_list = shot_detect(input_data["cond_video"], audio_save_dir) + if len(time_list) == 0: + conds_list.append([input_data["cond_video"]]) + conds_list.append([input_data["cond_audio"]["person1"]]) + if len(input_data["cond_audio"]) == 2: + conds_list.append([input_data["cond_audio"]["person2"]]) + else: + audio1_list = split_wav_librosa(input_data["cond_audio"]["person1"], time_list, audio_save_dir) + conds_list.append(cond_list) + conds_list.append(audio1_list) + if len(input_data["cond_audio"]) == 2: + audio2_list = split_wav_librosa(input_data["cond_audio"]["person2"], time_list, audio_save_dir) + conds_list.append(audio2_list) + else: + conds_list.append([input_data["cond_video"]]) + conds_list.append([input_data["cond_audio"]["person1"]]) + if len(input_data["cond_audio"]) == 2: + conds_list.append([input_data["cond_audio"]["person2"]]) + + if len(input_data["cond_audio"]) == 2: + new_human_speech1, new_human_speech2, sum_human_speechs = audio_prepare_multi(input_data["cond_audio"]["person1"], input_data["cond_audio"]["person2"], input_data["audio_type"]) + sum_audio = os.path.join(audio_save_dir, "sum_all.wav") + sf.write(sum_audio, sum_human_speechs, 16000) + input_data["video_audio"] = sum_audio + else: + human_speech = audio_prepare_single(input_data["cond_audio"]["person1"]) + sum_audio = os.path.join(audio_save_dir, "sum_all.wav") + sf.write(sum_audio, human_speech, 16000) + input_data["video_audio"] = sum_audio + logging.info("Generating video ...") + + for idx, items in enumerate(zip(*conds_list)): + input_clip = {} + input_clip["prompt"] = input_data.get("prompt", " ") + input_clip["cond_video"] = items[0] + + if "audio_type" in input_data: + input_clip["audio_type"] = input_data["audio_type"] + if "bbox" in input_data: + input_clip["bbox"] = input_data["bbox"] + cond_audio = {} + if args.audio_mode == "localfile": + if len(input_data["cond_audio"]) == 2: + new_human_speech1, new_human_speech2, sum_human_speechs = audio_prepare_multi(items[1], items[2], input_data["audio_type"]) + audio_embedding_1 = get_embedding(new_human_speech1, wav2vec_feature_extractor, audio_encoder) + audio_embedding_2 = get_embedding(new_human_speech2, wav2vec_feature_extractor, audio_encoder) + sum_audio = os.path.join(audio_save_dir, "sum.wav") + sf.write(sum_audio, sum_human_speechs, 16000) + cond_audio["person1"] = audio_embedding_1 + cond_audio["person2"] = audio_embedding_2 + input_clip["video_audio"] = sum_audio + elif len(input_data["cond_audio"]) == 1: + human_speech = audio_prepare_single(items[1]) + audio_embedding = get_embedding(human_speech, wav2vec_feature_extractor, audio_encoder) + sum_audio = os.path.join(audio_save_dir, "sum.wav") + sf.write(sum_audio, human_speech, 16000) + cond_audio["person1"] = audio_embedding + input_clip["video_audio"] = sum_audio + + input_clip["cond_audio"] = cond_audio + + video = wan_i2v.generate_infinitetalk( + input_clip, + size_buckget=args.size, + motion_frame=args.motion_frame, + frame_num=num_frames, + shift=float(shift), + sampling_steps=int(steps), + text_guide_scale=float(guide_scale), + audio_guide_scale=float(audio_guide_scale), + seed=int(seed), + offload_model=args.offload_model, + max_frames_num=args.max_frame_num, + color_correction_strength=args.color_correction_strength, + extra_args=args, + ) + + generated_list.append(video) + + if rank == 0: + sum_video = torch.cat(generated_list, dim=1) + if logo_video.lower() == "true": + save_video_with_logo(sum_video, save_file, [input_data["video_audio"]], high_quality_save=False, fps=fps) + else: + save_video_ffmpeg(sum_video, save_file, [input_data["video_audio"]], high_quality_save=False, fps=fps) + + generate_end_time = time.time() + job_processed = [id, "completed", created_str, prompt, seconds, size, quality, fps, shift, steps, guide_scale, audio_guide_scale, seed, logo_video, max(0, int(generate_end_time - generate_start_time)), int(generate_start_time), int(generate_end_time), ""] + if rank == 0: + update_job(job_processed, args) + except Exception as e: + logging.error(f"error: {e}") + generate_end_time = time.time() + job_processed = [id, "error", created_str, prompt, seconds, size, quality, fps, shift, steps, guide_scale, audio_guide_scale, seed, logo_video, max(0, int(generate_end_time - generate_start_time)), int(generate_start_time), int(generate_end_time), str(e)] + if rank == 0: + update_job(job_processed, args) + + except Exception as e: + logging.error(f"Job worker encountered an error: {e}") + + +if __name__ == "__main__": + args = _parse_args() + generate(args) diff --git a/examples/InfiniteTalk/opea_text2video/src/requirements-infinitetalk.txt b/examples/InfiniteTalk/opea_text2video/src/requirements-infinitetalk.txt new file mode 100755 index 0000000000..ebda467ba1 --- /dev/null +++ b/examples/InfiniteTalk/opea_text2video/src/requirements-infinitetalk.txt @@ -0,0 +1,24 @@ +opencv-python>=4.9.0.80 +diffusers>=0.31.0 +transformers>=4.49.0 +tokenizers>=0.20.3 +accelerate>=1.1.1 +tqdm +imageio +easydict +ftfy +dashscope +imageio-ffmpeg +scikit-image +loguru +gradio>=5.0.0 +numpy>=1.23.5,<2 +pyloudnorm +optimum-quanto==0.2.6 +scenedetect +moviepy==1.0.3 +decord +misaki[en] +librosa +soundfile +peft==0.17.0 \ No newline at end of file diff --git a/examples/InfiniteTalk/opea_text2video/src/requirements.txt b/examples/InfiniteTalk/opea_text2video/src/requirements.txt new file mode 100755 index 0000000000..cef543251d --- /dev/null +++ b/examples/InfiniteTalk/opea_text2video/src/requirements.txt @@ -0,0 +1,17 @@ +accelerate +datasets +diffusers +docarray[full] +fastapi +opentelemetry-api +opentelemetry-exporter-otlp +opentelemetry-sdk +prometheus-fastapi-instrumentator +pydantic==2.7.2 +pydub +shortuuid +torch +transformers +uvicorn +python-multipart +librosa \ No newline at end of file diff --git a/examples/InfiniteTalk/opea_text2video/src/web_service.py b/examples/InfiniteTalk/opea_text2video/src/web_service.py new file mode 100755 index 0000000000..45e8752162 --- /dev/null +++ b/examples/InfiniteTalk/opea_text2video/src/web_service.py @@ -0,0 +1,359 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import argparse +import os +import time +import fcntl +import shutil +import math + +from fastapi import Depends, Request, status +from fastapi.responses import FileResponse, JSONResponse + +from comps import ( + CustomLogger, + OpeaComponentLoader, + opea_microservices, + register_microservice, + register_statistics, + statistics_dict, +) +from component import Text2VideoInput, Text2VideoOutput, ServiceType, OpeaText2Video + + +# Initialize logger and component loader +logger = CustomLogger("text2video") +component_loader = None +LOGFLAG = os.getenv("LOGFLAG", "False").lower() in ("true", "1", "t") + + +def validate_form_parameters(form): + """Validate and convert form parameters to their expected types.""" + try: + audio = [] + if "audio[]" in form: + audio += form.getlist("audio[]") + elif "audio" in form: + audio += form.getlist("audio") + + params = { + "prompt": form.get("prompt"), + "input_reference": form.get("input_reference"), + "audio": audio, + "audio_guide_scale": float(form.get("audio_guide_scale", 5.0)), + "audio_type": form.get("audio_type", "add"), + "model": form.get("model"), + "seconds": int(form.get("seconds", 4)), + "fps": int(form.get("fps", 25)), + "shift": float(form.get("shift", 5.0)), + "steps": int(form.get("steps", 40)), + "seed": int(form.get("seed", 42)), + "guide_scale": float(form.get("guide_scale", 5.0)), + "size": form.get("size", "720x1280"), + "logo_video": form.get("logo_video", "False") + } + + if params["seconds"] <= 0: + raise ValueError("The 'seconds' parameter must be greater than 0.") + + # Validate size format + width, height = params["size"].split("x") + if not (width.isdigit() and height.isdigit()): + raise ValueError("Invalid size format. Expected 'widthxheight'.") + + if not params["input_reference"] or len(params["audio"]) == 0: + raise ValueError("'input_reference' and 'audio' must be provided.") + + return params, None + except (ValueError, TypeError) as e: + error_content = {"error": {"message": f"Invalid parameter type: {e}", "code": "400"}} + return None, JSONResponse(status_code=status.HTTP_400_BAD_REQUEST, content=error_content) + + +async def resolve_request(request: Request): + form = await request.form() + validated_params, error_response = validate_form_parameters(form) + if error_response: + return error_response + return Text2VideoInput(**validated_params) + + +def calculate_progress(job_info): + estimated_time = estimate_queue_time(int(job_info[4]), int(job_info[9])) + start_time = int(job_info[15]) + elapsed_time = int(time.time()) - start_time + progress = int(min(int((elapsed_time / (estimated_time * 60)) * 100), 99)) + left_time = int(max(1, int(estimated_time - (elapsed_time / 60)))) + return progress, left_time + + +def estimate_queue_time(seconds, steps): + steps = max(steps, 1) + return math.ceil(seconds * 1.16 * steps / 20) if seconds <= 10 else math.ceil(int(seconds * steps / 20)) if seconds <= 15 else math.ceil(int(seconds * 0.83 * steps / 20)) + + +def generate_response(video_id) -> Text2VideoOutput: + job_file = os.path.join(os.getenv("VIDEO_DIR"), "job.txt") + if os.path.exists(job_file): + sep = os.getenv("SEP") + queue_estimated_time_in_minutes = 0 + queue_length = 0 + job_info = None + with open(job_file, "r") as f: + fcntl.flock(f, fcntl.LOCK_EX) + try: + lines = f.readlines() + for line in lines: + job = line.strip().split(sep) + + if len(job) < 17: + continue + + if job[0] == video_id: + job_info = job + queue_estimated_time_in_minutes += estimate_queue_time(int(job[4]), int(job[9])) + break + + if job[1] == "queued": + queue_length += 1 + queue_estimated_time_in_minutes += estimate_queue_time(int(job[4]), int(job[9])) + + if job[1] == "processing": + progress, left_time = calculate_progress(job) + queue_length += 1 + queue_estimated_time_in_minutes += left_time + finally: + fcntl.flock(f, fcntl.LOCK_UN) + + if job_info: + if job_info[1] == "processing": + progress, left_time = calculate_progress(job_info) + return Text2VideoOutput( + id=job_info[0], + model=os.getenv("MODEL"), + status=job_info[1], + progress=progress, + created_at=int(job_info[2]), + seconds=job_info[4], + duration=0, + estimated_time=left_time, + queue_length=0, + error=job_info[-1] if job_info[1] == "error" else "" + ) + else: + return Text2VideoOutput( + id=job_info[0], + model=os.getenv("MODEL"), + status=job_info[1], + progress=100 if job_info[1] == "completed" else 0, + created_at=int(job_info[2]), + seconds=job_info[4], + duration=job_info[14], + estimated_time=0 if job_info[1] == "completed" else int(queue_estimated_time_in_minutes), + queue_length=0 if job_info[1] == "completed" else queue_length, + error=job_info[-1] if job_info[1] == "error" else "" + ) + + content = { + "error": { + "message": f"Video with id {video_id} not found.", + "code": "404" + } + } + return JSONResponse(status_code=status.HTTP_404_NOT_FOUND, content=content) + + +@register_microservice( + name="opea_service@text2video", + service_type=ServiceType.TEXT2VIDEO, + endpoint="/v1/videos", + host="0.0.0.0", + port=9396, + input_datatype=Text2VideoInput, + output_datatype=Text2VideoOutput, +) +@register_statistics(names=["opea_service@text2video"]) +async def text2video(input_data: Text2VideoInput = Depends(resolve_request)) -> Text2VideoOutput: + """ + Process a text-to-video generation request. + + Args: + input_data (Text2VideoInput): The input data containing the prompt. + + Returns: + Text2VideoOutput: The result of the video generation. + """ + if isinstance(input_data, JSONResponse): + return input_data + start = time.time() + if component_loader: + try: + job_id = await component_loader.invoke(input_data) + results = generate_response(job_id) + except ValueError as ve: + error_content = {"error": {"message": str(ve), "code": "400"}} + return JSONResponse(status_code=status.HTTP_400_BAD_REQUEST, content=error_content) + except Exception as e: + error_content = {"error": {"message": f"Internal server error: {e}", "code": "500"}} + return JSONResponse(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content=error_content) + else: + raise RuntimeError("Component loader is not initialized.") + latency = time.time() - start + statistics_dict["opea_service@text2video"].append_latency(latency, None) + return results + + +@register_microservice( + name="opea_service@text2video", + service_type=ServiceType.TEXT2VIDEO, + endpoint="/v1/videos/{video_id}", + host="0.0.0.0", + port=9396, + methods=["GET"], +) +@register_statistics(names=["opea_service@text2video"]) +async def get_video(video_id: str): + try: + return generate_response(video_id) + except Exception as e: + error_content = {"error": {"message": f"Internal server error: {e}", "code": "500"}} + return JSONResponse(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content=error_content) + + +@register_microservice( + name="opea_service@text2video", + service_type=ServiceType.TEXT2VIDEO, + endpoint="/v1/videos/{video_id}", + host="0.0.0.0", + port=9396, + methods=["DELETE"], +) +@register_statistics(names=["opea_service@text2video"]) +async def delete_video(video_id: str): + try: + job_file = os.path.join(os.getenv("VIDEO_DIR"), "job.txt") + if not os.path.exists(job_file): + return JSONResponse( + status_code=status.HTTP_404_NOT_FOUND, + content={"error": {"message": f"Job queue is missing and video with id {video_id} not found.", "code": "404"}}, + ) + + sep = os.getenv("SEP") + deleted_job_info = None + updated_lines = [] + job_found = False + + with open(job_file, "r+") as f: + fcntl.flock(f, fcntl.LOCK_EX) + try: + lines = f.readlines() + for line in lines: + job = line.strip().split(sep) + if job[0] == video_id: + job_found = True + if job[1] == "processing": + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content={"error": {"message": f"Video with id {video_id} is processing and cannot be deleted.", "code": "400"}}, + ) + deleted_job_info = job + else: + updated_lines.append(line) + + if not job_found: + return JSONResponse( + status_code=status.HTTP_404_NOT_FOUND, + content={"error": {"message": f"Video with id {video_id} not found.", "code": "404"}}, + ) + + # Rewrite the file without the deleted line + f.seek(0) + f.truncate() + f.writelines(updated_lines) + finally: + fcntl.flock(f, fcntl.LOCK_UN) + + if deleted_job_info: + video_folder_path = os.path.join(os.getenv("VIDEO_DIR"), deleted_job_info[0]) + if os.path.isdir(video_folder_path): + shutil.rmtree(video_folder_path) + return Text2VideoOutput( + id=deleted_job_info[0], + model=os.getenv("MODEL"), + status="deleted", + progress=0, + created_at=int(deleted_job_info[2]), + seconds=deleted_job_info[4], + duration=int(deleted_job_info[14]), + estimated_time=0, + queue_length=0, + error="" + ) + + except Exception as e: + error_content = {"error": {"message": f"Internal server error: {e}", "code": "500"}} + return JSONResponse(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content=error_content) + + +@register_microservice( + name="opea_service@text2video", + service_type=ServiceType.TEXT2VIDEO, + endpoint="/v1/videos/{video_id}/content", + host="0.0.0.0", + port=9396, + methods=["GET"], +) +@register_statistics(names=["opea_service@text2video"]) +async def get_video_content(video_id: str): + try: + res = generate_response(video_id) + if isinstance(res, JSONResponse): + return res + if res.status == "completed": + video_path = os.path.join(os.getenv("VIDEO_DIR"), video_id, "output.mp4") + if os.path.exists(video_path): + return FileResponse(video_path, media_type="video/mp4", filename=f"{video_id}.mp4") + else: + error_content = {"error": {"message": f"Video file for id {video_id} not found.", "code": "404"}} + return JSONResponse(status_code=status.HTTP_404_NOT_FOUND, content=error_content) + else: + return res + except Exception as e: + error_content = {"error": {"message": f"Internal server error: {e}", "code": "500"}} + return JSONResponse(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content=error_content) + + +def main(): + """ + Main function to set up and run the text-to-video microservice. + """ + global component_loader + + parser = argparse.ArgumentParser(description="Text-to-Video Microservice") + parser.add_argument("--model_name_or_path", type=str, default="InfinteTalk", help="Model name or path.") + parser.add_argument("--video_dir", type=str, default="/home/user/video", help="Video output directory.") + + args = parser.parse_args() + os.environ["MODEL"] = args.model_name_or_path + os.environ["VIDEO_DIR"] = args.video_dir + os.environ["SEP"] = "$###$" + text2video_component_name = os.getenv("TEXT2VIDEO_COMPONENT_NAME", "OPEA_TEXT2VIDEO") + + try: + component_loader = OpeaComponentLoader( + component_name=text2video_component_name, + description=f"OPEA IMAGES_GENERATIONS Component: {text2video_component_name}", + config=args.__dict__, + video_dir=args.video_dir, + ) + except Exception as e: + logger.error(f"Failed to initialize component loader: {e}") + exit(1) + + logger.info("Text-to-video server started.") + opea_microservices["opea_service@text2video"].start() + + +if __name__ == "__main__": + main()