diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index 87cdccef19b..8cd65aafaff 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -52,6 +52,7 @@ th { |`Flux2Pipeline` | FLUX.2-dev | `black-forest-labs/FLUX.2-dev` | |`FishSpeechSlowARForConditionalGeneration` | Fish Speech S2 Pro | `fishaudio/s2-pro` | |`DreamIDOmniPipeline`| DreamID-Omni | `XuGuo699/DreamID-Omni` | +|`LingbotWorldPipeline` | LingBot-World Base (Cam) | `robbyant/lingbot-world-base-cam` | ## List of Supported Models for NPU diff --git a/examples/offline_inference/lingbot-world/README.md b/examples/offline_inference/lingbot-world/README.md new file mode 100644 index 00000000000..fc8afa4cc0a --- /dev/null +++ b/examples/offline_inference/lingbot-world/README.md @@ -0,0 +1,56 @@ +# LingBot-World + +Offline LingBot-World examples. + +## Download + +```bash +python download_lingbot_world.py \ + --model-id robbyant/lingbot-world-base-cam \ + --output-dir ./lingbot-world-base-cam +``` + +The prepared model directory looks like this: + +```text +lingbot-world-base-cam/ +├── configuration.json +├── google/ +├── high_noise_model/ +├── low_noise_model/ +├── models_t5_umt5-xxl-enc-bf16.pth +├── model_index.json +└── Wan2.1_VAE.pth +``` + +## Run With Control Signals + +```bash +PROMPT="$(cat /tmp/vllm-omni-dependency/lingbot-world/examples/00/prompt.txt)" + +python image_to_video.py \ + --model ./lingbot-world-base-cam \ + --image /tmp/vllm-omni-dependency/lingbot-world/examples/00/image.jpg \ + --action-path /tmp/vllm-omni-dependency/lingbot-world/examples/00 \ + --prompt "$PROMPT" \ + --output lingbot_world_base_cam_examples00.mp4 +``` + +## Run Without Control Signals + +```bash +PROMPT="$(cat /tmp/vllm-omni-dependency/lingbot-world/examples/00/prompt.txt)" + +python image_to_video.py \ + --model ./lingbot-world-base-cam \ + --image /tmp/vllm-omni-dependency/lingbot-world/examples/00/image.jpg \ + --prompt "$PROMPT" \ + --output lingbot_world_base_cam_no_control.mp4 +``` + +## Notes + +- `--action-path` is optional. +- For `LingBot-World-Base (Cam)`, control signals should contain `poses.npy` and `intrinsics.npy`. +- For `LingBot-World-Base (Act)`, `action.npy` is also required. +- `--enable-cpu-offload` is supported for offline inference. diff --git a/examples/offline_inference/lingbot-world/download_lingbot_world.py b/examples/offline_inference/lingbot-world/download_lingbot_world.py new file mode 100644 index 00000000000..da4bc10255b --- /dev/null +++ b/examples/offline_inference/lingbot-world/download_lingbot_world.py @@ -0,0 +1,147 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from __future__ import annotations + +import argparse +import fcntl +import json +import os +import site +import subprocess +import tempfile +import time +from pathlib import Path + +try: + from huggingface_hub import snapshot_download +except ImportError: + snapshot_download = None + + +DEFAULT_MODEL_ID = "robbyant/lingbot-world-base-cam" +DEFAULT_OUTPUT_DIR = "./lingbot-world-base-cam" +DEFAULT_CLASS_NAME = "LingbotWorldPipeline" +DEPENDENCY_REPO = "https://github.com/robbyant/lingbot-world.git" +DEPENDENCY_BRANCH = "main" +CACHE_DIR = Path(tempfile.gettempdir()) / "vllm-omni-dependency" +LOCK_FILE = CACHE_DIR / ".lingbot_world_install.lock" +DEPENDENCY_DIR = CACHE_DIR / "lingbot-world" +PTH_FILE_NAME = "vllm_omni_lingbot_world_dependency.pth" + +REQUIRED_FILES = ( + "configuration.json", + "models_t5_umt5-xxl-enc-bf16.pth", + "Wan2.1_VAE.pth", + "low_noise_model/config.json", + "high_noise_model/config.json", +) + + +def infer_control_type(model_ref: str) -> str: + model_ref = model_ref.lower() + if "act" in model_ref: + return "act" + return "cam" + + +def ensure_model_index( + output_dir: Path, + *, + class_name: str = DEFAULT_CLASS_NAME, + control_type: str | None = None, +) -> Path: + output_dir.mkdir(parents=True, exist_ok=True) + model_index_path = output_dir / "model_index.json" + payload = { + "_class_name": class_name, + "control_type": control_type or infer_control_type(str(output_dir)), + } + model_index_path.write_text(json.dumps(payload, indent=2) + "\n") + return model_index_path + + +def validate_model_directory(output_dir: Path) -> None: + missing = [rel_path for rel_path in REQUIRED_FILES if not (output_dir / rel_path).exists()] + if missing: + raise FileNotFoundError("LingBot-World download is incomplete. Missing files: " + ", ".join(sorted(missing))) + + +def download_dependency() -> Path: + CACHE_DIR.mkdir(parents=True, exist_ok=True) + + with open(LOCK_FILE, "w") as lock_file: + fcntl.flock(lock_file, fcntl.LOCK_EX) + if not DEPENDENCY_DIR.exists(): + print(f"Downloading LingBot-World to {DEPENDENCY_DIR} ...") + subprocess.run( + [ + "git", + "clone", + "--depth", + "1", + DEPENDENCY_REPO, + "--branch", + DEPENDENCY_BRANCH, + str(DEPENDENCY_DIR), + ], + check=True, + ) + print("Download finished.") + fcntl.flock(lock_file, fcntl.LOCK_UN) + + site_packages = Path(site.getsitepackages()[0]) + pth_file = site_packages / PTH_FILE_NAME + pth_file.write_text(f"{DEPENDENCY_DIR}\n", encoding="utf-8") + print(f"Added {DEPENDENCY_DIR} to site-packages via {pth_file}") + return pth_file + + +def timed_download(repo_id: str, local_dir: str) -> None: + if os.path.exists(local_dir): + print(f"Directory {local_dir} already exists. Skipping download.") + return + if snapshot_download is None: + raise ImportError( + "huggingface_hub is required to download LingBot-World. Install it before running this script." + ) + print(f"Starting download from {repo_id} into {local_dir}") + start_time = time.time() + + snapshot_download( + repo_id=repo_id, + local_dir=local_dir, + local_dir_use_symlinks=False, + ) + + elapsed = time.time() - start_time + print(f"Finished downloading {repo_id} in {elapsed:.2f} seconds. Files saved at: {local_dir}") + + +def download_lingbot_world(model_id: str, output_dir: Path) -> Path: + timed_download(repo_id=model_id, local_dir=str(output_dir)) + ensure_model_index(output_dir, control_type=infer_control_type(model_id)) + validate_model_directory(output_dir) + return output_dir + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Download LingBot-World from Hugging Face.") + parser.add_argument("--model-id", default=DEFAULT_MODEL_ID, help="Hugging Face model ID to download.") + parser.add_argument( + "--output-dir", + default=DEFAULT_OUTPUT_DIR, + help="Local directory for the prepared model.", + ) + return parser.parse_args() + + +def main(output_dir: str, model_id: str = DEFAULT_MODEL_ID) -> None: + model_dir = download_lingbot_world(model_id, Path(output_dir).expanduser().resolve()) + download_dependency() + print(f"Prepared LingBot-World model at: {model_dir}") + + +if __name__ == "__main__": + args = parse_args() + main(args.output_dir, args.model_id) diff --git a/examples/offline_inference/lingbot-world/image_to_video.py b/examples/offline_inference/lingbot-world/image_to_video.py new file mode 100644 index 00000000000..bd01eca3fce --- /dev/null +++ b/examples/offline_inference/lingbot-world/image_to_video.py @@ -0,0 +1,222 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from __future__ import annotations + +import argparse +import time +from pathlib import Path + +import numpy as np +import PIL.Image +import torch + +from vllm_omni.entrypoints.omni import Omni +from vllm_omni.inputs.data import OmniDiffusionSamplingParams +from vllm_omni.outputs import OmniRequestOutput +from vllm_omni.platforms import current_omni_platform + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Generate a video with LingBot-World.") + parser.add_argument("--model", default="./lingbot-world-base-cam", help="Model path or Hugging Face ID.") + parser.add_argument("--image", required=True, help="Path to the input image.") + parser.add_argument("--prompt", required=True, help="Prompt text.") + parser.add_argument("--negative-prompt", default="", help="Negative prompt.") + parser.add_argument("--seed", type=int, default=42, help="Random seed.") + parser.add_argument("--height", type=int, default=480, help="Video height.") + parser.add_argument("--width", type=int, default=832, help="Video width.") + parser.add_argument("--num-frames", type=int, default=161, help="Number of frames.") + parser.add_argument("--num-inference-steps", type=int, default=20, help="Sampling steps.") + parser.add_argument("--guidance-scale", type=float, default=5.0, help="Low-noise CFG scale.") + parser.add_argument("--guidance-scale-high", type=float, default=5.0, help="High-noise CFG scale.") + parser.add_argument("--boundary-ratio", type=float, default=0.875, help="Boundary split ratio.") + parser.add_argument("--flow-shift", type=float, default=10.0, help="Scheduler flow shift.") + parser.add_argument("--action-path", default=None, help="Optional path to control signals.") + parser.add_argument("--fps", type=int, default=16, help="Frames per second for the output video.") + parser.add_argument("--output", default="lingbot_world_output.mp4", help="Output video path.") + parser.add_argument( + "--enable-cpu-offload", + action="store_true", + help="Enable CPU offload.", + ) + return parser.parse_args() + + +def extract_frames(output: object) -> tuple[object, object | None]: + frames = output + audio = None + + if isinstance(frames, list): + frames = frames[0] if frames else None + + if isinstance(frames, OmniRequestOutput): + if frames.final_output_type != "image": + raise ValueError(f"Unexpected output type '{frames.final_output_type}', expected 'image'.") + if frames.multimodal_output and "audio" in frames.multimodal_output: + audio = frames.multimodal_output["audio"] + if frames.is_pipeline_output and frames.request_output is not None: + inner_output = frames.request_output + if isinstance(inner_output, OmniRequestOutput): + if inner_output.multimodal_output and "audio" in inner_output.multimodal_output: + audio = inner_output.multimodal_output["audio"] + frames = inner_output + if isinstance(frames, OmniRequestOutput): + if frames.images: + if len(frames.images) == 1 and isinstance(frames.images[0], tuple) and len(frames.images[0]) == 2: + frames, audio = frames.images[0] + elif len(frames.images) == 1 and isinstance(frames.images[0], dict): + audio = frames.images[0].get("audio") + frames = frames.images[0].get("frames") or frames.images[0].get("video") + else: + frames = frames.images + else: + raise ValueError("No video frames found in OmniRequestOutput.") + + if isinstance(frames, list) and frames: + first_item = frames[0] + if isinstance(first_item, tuple) and len(first_item) == 2: + frames, audio = first_item + elif isinstance(first_item, dict): + audio = first_item.get("audio") + frames = first_item.get("frames") or first_item.get("video") + elif isinstance(first_item, list): + frames = first_item + + if isinstance(frames, tuple) and len(frames) == 2: + frames, audio = frames + elif isinstance(frames, dict): + audio = frames.get("audio") + frames = frames.get("frames") or frames.get("video") + + if frames is None: + raise ValueError("No video frames found in output.") + + return frames, audio + + +def normalize_frame(frame: object) -> object: + if isinstance(frame, torch.Tensor): + frame_tensor = frame.detach().cpu() + if frame_tensor.dim() == 4 and frame_tensor.shape[0] == 1: + frame_tensor = frame_tensor[0] + if frame_tensor.dim() == 3 and frame_tensor.shape[0] in (3, 4): + frame_tensor = frame_tensor.permute(1, 2, 0) + if frame_tensor.is_floating_point(): + frame_tensor = frame_tensor.clamp(-1, 1) * 0.5 + 0.5 + return frame_tensor.float().numpy() + if isinstance(frame, np.ndarray): + frame_array = frame + if frame_array.ndim == 4 and frame_array.shape[0] == 1: + frame_array = frame_array[0] + if np.issubdtype(frame_array.dtype, np.integer): + frame_array = frame_array.astype(np.float32) / 255.0 + return frame_array + if isinstance(frame, PIL.Image.Image): + return np.asarray(frame).astype(np.float32) / 255.0 + return frame + + +def ensure_frame_list(video_array: object) -> object: + if isinstance(video_array, list): + if len(video_array) == 0: + return video_array + first_item = video_array[0] + if isinstance(first_item, np.ndarray): + if first_item.ndim == 5: + return list(first_item[0]) + if first_item.ndim == 4: + return list(first_item) + if first_item.ndim == 3: + return video_array + return video_array + if isinstance(video_array, np.ndarray): + if video_array.ndim == 5: + return list(video_array[0]) + if video_array.ndim == 4: + return list(video_array) + if video_array.ndim == 3: + return [video_array] + return video_array + + +def export_video(frames: object, output_path: Path, fps: int) -> None: + try: + from diffusers.utils import export_to_video + except ImportError as exc: + raise ImportError("diffusers is required for export_to_video.") from exc + + if isinstance(frames, torch.Tensor): + video_tensor = frames.detach().cpu() + if video_tensor.dim() == 5: + if video_tensor.shape[1] in (3, 4): + video_tensor = video_tensor[0].permute(1, 2, 3, 0) + else: + video_tensor = video_tensor[0] + elif video_tensor.dim() == 4 and video_tensor.shape[0] in (3, 4): + video_tensor = video_tensor.permute(1, 2, 3, 0) + if video_tensor.is_floating_point(): + video_tensor = video_tensor.clamp(-1, 1) * 0.5 + 0.5 + frames = video_tensor.numpy() + + frames = ensure_frame_list(frames) + if not isinstance(frames, list): + raise ValueError("Expected frames to be a list after normalization.") + + normalized_frames = [normalize_frame(frame) for frame in frames] + export_to_video(normalized_frames, output_video_path=str(output_path), fps=fps) + + +def main() -> None: + args = parse_args() + generator = torch.Generator(device=current_omni_platform.device_type).manual_seed(args.seed) + image = PIL.Image.open(args.image).convert("RGB") + image = image.resize((args.width, args.height), PIL.Image.Resampling.LANCZOS) + + omni = Omni( + model=args.model, + boundary_ratio=args.boundary_ratio, + flow_shift=args.flow_shift, + enable_cpu_offload=args.enable_cpu_offload, + ) + + print("Generation Configuration:") + print(f" Model: {args.model}") + print(f" Image: {args.image}") + print(f" Size: {args.width}x{args.height}") + print(f" Frames: {args.num_frames}") + print(f" Steps: {args.num_inference_steps}") + if args.action_path is not None: + print(f" Action path: {args.action_path}") + + start = time.perf_counter() + output = omni.generate( + { + "prompt": args.prompt, + "negative_prompt": args.negative_prompt, + "multi_modal_data": {"image": image}, + }, + OmniDiffusionSamplingParams( + height=args.height, + width=args.width, + generator=generator, + guidance_scale=args.guidance_scale, + guidance_scale_2=args.guidance_scale_high, + num_inference_steps=args.num_inference_steps, + num_frames=args.num_frames, + frame_rate=float(args.fps), + extra_args={"action_path": args.action_path} if args.action_path is not None else {}, + ), + ) + elapsed = time.perf_counter() - start + print(f"Generation time: {elapsed:.2f}s") + + frames, _ = extract_frames(output) + output_path = Path(args.output) + output_path.parent.mkdir(parents=True, exist_ok=True) + export_video(frames, output_path, args.fps) + print(f"Saved video to {output_path}") + + +if __name__ == "__main__": + main() diff --git a/examples/online_serving/lingbot-world/README.md b/examples/online_serving/lingbot-world/README.md new file mode 100644 index 00000000000..da973887d8d --- /dev/null +++ b/examples/online_serving/lingbot-world/README.md @@ -0,0 +1,43 @@ +# LingBot-World + +Online LingBot-World examples using the async `/v1/videos` API. + +## Download + +```bash +python ../../offline_inference/lingbot-world/download_lingbot_world.py \ + --model-id robbyant/lingbot-world-base-cam \ + --output-dir ./lingbot-world-base-cam +``` + +## Start Server + +```bash +MODEL=./lingbot-world-base-cam \ +bash run_server.sh +``` + +Or: + +```bash +vllm serve ./lingbot-world-base-cam --omni --port 8099 +``` + +## Submit a Job + +```bash +PROMPT="$(cat /tmp/vllm-omni-dependency/lingbot-world/examples/00/prompt.txt)" +INPUT_IMAGE=/tmp/vllm-omni-dependency/lingbot-world/examples/00/image.jpg \ +PROMPT="$PROMPT" \ +bash run_curl_image_to_video.sh +``` + +The script follows the generic `examples/online_serving/image_to_video` flow and uses the same 480P defaults as the upstream LingBot-World example. + +- `POST /v1/videos` +- `GET /v1/videos/{video_id}` +- `GET /v1/videos/{video_id}/content` + +## Current Limitation + +`action_path` is not exposed by the current online `/v1/videos` API. Control signals are offline-only for now. diff --git a/examples/online_serving/lingbot-world/run_curl_image_to_video.sh b/examples/online_serving/lingbot-world/run_curl_image_to_video.sh new file mode 100755 index 00000000000..617b598ef2b --- /dev/null +++ b/examples/online_serving/lingbot-world/run_curl_image_to_video.sh @@ -0,0 +1,75 @@ +#!/bin/bash +set -euo pipefail + +BASE_URL="${BASE_URL:-http://localhost:8099}" +EXAMPLE_DIR="${EXAMPLE_DIR:-/tmp/vllm-omni-dependency/lingbot-world/examples/00}" +INPUT_IMAGE="${INPUT_IMAGE:-${EXAMPLE_DIR}/image.jpg}" +PROMPT="${PROMPT:-$(cat "${EXAMPLE_DIR}/prompt.txt")}" +OUTPUT_PATH="${OUTPUT_PATH:-lingbot_world_output.mp4}" +NEGATIVE_PROMPT="${NEGATIVE_PROMPT:-}" +POLL_INTERVAL="${POLL_INTERVAL:-2}" + +if [ ! -f "$INPUT_IMAGE" ]; then + echo "Input image not found: $INPUT_IMAGE" + exit 1 +fi + +create_cmd=( + curl -sS -X POST "${BASE_URL}/v1/videos" + -H "Accept: application/json" + -F "prompt=${PROMPT}" + -F "input_reference=@${INPUT_IMAGE}" + -F "size=832x480" + -F "num_frames=161" + -F "fps=16" + -F "num_inference_steps=20" + -F "guidance_scale=5.0" + -F "guidance_scale_2=5.0" + -F "boundary_ratio=0.875" + -F "flow_shift=10.0" + -F "seed=42" +) + +if [ -n "${NEGATIVE_PROMPT}" ]; then + create_cmd+=(-F "negative_prompt=${NEGATIVE_PROMPT}") +fi + +create_response="$("${create_cmd[@]}")" +video_id="$(echo "${create_response}" | jq -r '.id')" +if [ -z "${video_id}" ] || [ "${video_id}" = "null" ]; then + echo "Failed to create video job:" + echo "${create_response}" | jq . + exit 1 +fi + +echo "Created video job ${video_id}" +echo "${create_response}" | jq . + +while true; do + status_response="$(curl -sS "${BASE_URL}/v1/videos/${video_id}")" + status="$(echo "${status_response}" | jq -r '.status')" + + case "${status}" in + queued|in_progress) + echo "Video job ${video_id} status: ${status}" + sleep "${POLL_INTERVAL}" + ;; + completed) + echo "${status_response}" | jq . + break + ;; + failed) + echo "Video generation failed:" + echo "${status_response}" | jq . + exit 1 + ;; + *) + echo "Unexpected status response:" + echo "${status_response}" | jq . + exit 1 + ;; + esac +done + +curl -sS -L "${BASE_URL}/v1/videos/${video_id}/content" -o "${OUTPUT_PATH}" +echo "Saved video to ${OUTPUT_PATH}" diff --git a/examples/online_serving/lingbot-world/run_server.sh b/examples/online_serving/lingbot-world/run_server.sh new file mode 100755 index 00000000000..b59fdc193c7 --- /dev/null +++ b/examples/online_serving/lingbot-world/run_server.sh @@ -0,0 +1,12 @@ +#!/bin/bash +# Lingbot-World online serving startup script + +MODEL="${MODEL:-./lingbot-world-base-cam}" +PORT="${PORT:-8099}" + +echo "Starting Lingbot-World server..." +echo "Model: $MODEL" +echo "Port: $PORT" + +vllm serve "$MODEL" --omni \ + --port "$PORT" diff --git a/vllm_omni/diffusion/models/lingbot_world/__init__.py b/vllm_omni/diffusion/models/lingbot_world/__init__.py new file mode 100644 index 00000000000..f42fa53137d --- /dev/null +++ b/vllm_omni/diffusion/models/lingbot_world/__init__.py @@ -0,0 +1,8 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from .pipeline_lingbot_world import LingbotWorldPipeline + +__all__ = [ + "LingbotWorldPipeline", +] diff --git a/vllm_omni/diffusion/models/lingbot_world/pipeline_lingbot_world.py b/vllm_omni/diffusion/models/lingbot_world/pipeline_lingbot_world.py new file mode 100644 index 00000000000..8eb024b38d8 --- /dev/null +++ b/vllm_omni/diffusion/models/lingbot_world/pipeline_lingbot_world.py @@ -0,0 +1,471 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from __future__ import annotations + +import os +import random +from dataclasses import dataclass +from typing import cast + +import numpy as np +import PIL.Image +import torch +import torchvision.transforms.functional as TF +from diffusers.video_processor import VideoProcessor +from einops import rearrange +from torch import nn + +from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig +from vllm_omni.diffusion.models.interface import SupportImageInput +from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.platforms import current_omni_platform + +try: + from wan.modules.t5 import T5EncoderModel + from wan.modules.vae2_1 import Wan2_1_VAE + from wan.utils.cam_utils import ( + compute_relative_poses, + get_Ks_transformed, + get_plucker_embeddings, + interpolate_camera_poses, + ) + from wan.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler +except ImportError as exc: + raise ImportError( + "Failed to import from dependency 'lingbot-world'. " + "Run examples/offline_inference/image_to_video/download_lingbot_world.py " + "or install the LingBot-World repository first." + ) from exc + +from .wan_model import WanModel + + +@dataclass(frozen=True) +class LingbotWorldConfig: + text_len: int = 512 + t5_dtype: torch.dtype = torch.bfloat16 + param_dtype: torch.dtype = torch.bfloat16 + num_train_timesteps: int = 1000 + frame_num: int = 81 + sample_steps: int = 70 + sample_shift: float = 10.0 + sample_guide_scale: tuple[float, float] = (5.0, 5.0) + boundary_ratio: float = 0.947 + sample_neg_prompt: str = ( + "画面突变,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止," + "整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部," + "画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景," + "三条腿,背景人很多,倒着走,镜头晃动,画面闪烁,模糊,噪点,水印,签名,文字,变形,扭曲," + "液化,不合逻辑的结构,卡顿,PPT幻灯片感,过暗,欠曝,低对比度,霓虹灯光感,过度锐化," + "3D渲染感,人物,行人,游客,身体,皮肤,肢体,面部特征,汽车,电线" + ) + t5_checkpoint: str = "models_t5_umt5-xxl-enc-bf16.pth" + t5_tokenizer: str = "google/umt5-xxl" + vae_checkpoint: str = "Wan2.1_VAE.pth" + vae_stride: tuple[int, int, int] = (4, 8, 8) + patch_size: tuple[int, int, int] = (1, 2, 2) + low_noise_checkpoint: str = "low_noise_model" + high_noise_checkpoint: str = "high_noise_model" + max_area: int = 480 * 832 + + +DEFAULT_CONFIG = LingbotWorldConfig() + + +def _resolve_model_path(model: str) -> str: + if os.path.isdir(model): + return model + try: + from huggingface_hub import snapshot_download + except ImportError as exc: + raise ImportError( + "huggingface_hub is required to download remote LingBot-World checkpoints. " + "Install it or pass a prepared local model directory." + ) from exc + return snapshot_download(repo_id=model) + + +def _infer_control_type(model_ref: str) -> str: + model_ref = model_ref.lower() + if "act" in model_ref: + return "act" + return "cam" + + +class LingbotWorldPipeline(nn.Module, SupportImageInput): + support_image_input = True + color_format = "RGB" + + def __init__(self, *, od_config: OmniDiffusionConfig, prefix: str = ""): + super().__init__() + self.od_config = od_config + self.device = torch.device(current_omni_platform.get_torch_device()) + self.config = DEFAULT_CONFIG + self.model_path = _resolve_model_path(cast(str, od_config.model)) + self.control_type = _infer_control_type(self.model_path) + self.enable_cpu_offload = bool(self.od_config.enable_cpu_offload) + + self.text_model = T5EncoderModel( + text_len=self.config.text_len, + dtype=self.config.t5_dtype, + device=self.device, + checkpoint_path=os.path.join(self.model_path, self.config.t5_checkpoint), + tokenizer_path=os.path.join(self.model_path, self.config.t5_tokenizer), + ) + self.text_encoder = self.text_model.model + self.vae_model = Wan2_1_VAE( + vae_pth=os.path.join(self.model_path, self.config.vae_checkpoint), + device=self.device, + ) + self.vae = self.vae_model.model + self.low_noise_model = ( + WanModel.from_pretrained( + self.model_path, + subfolder=self.config.low_noise_checkpoint, + torch_dtype=self.config.param_dtype, + control_type=self.control_type, + ) + .eval() + .requires_grad_(False) + ) + self.transformer = self.low_noise_model + self.high_noise_model = ( + WanModel.from_pretrained( + self.model_path, + subfolder=self.config.high_noise_checkpoint, + torch_dtype=self.config.param_dtype, + control_type=self.control_type, + ) + .eval() + .requires_grad_(False) + ) + self.transformer_2 = self.high_noise_model + + if self.enable_cpu_offload: + self.text_encoder.to("cpu") + self.transformer.to("cpu") + self.transformer_2.to("cpu") + self.vae.to("cpu") + else: + self.text_encoder.to(self.device) + self.transformer.to(self.device) + self.transformer_2.to(self.device) + self.vae.to(self.device) + + self.scheduler = FlowUniPCMultistepScheduler( + num_train_timesteps=self.config.num_train_timesteps, + shift=1.0, + use_dynamic_shifting=False, + ) + self.video_processor = VideoProcessor(vae_scale_factor=8) + + def load_weights(self, weights): + pass + + def _prepare_model_for_timestep(self, timestep: torch.Tensor, boundary: float) -> nn.Module: + if timestep.item() >= boundary: + required_model = self.high_noise_model + other_model = self.low_noise_model + else: + required_model = self.low_noise_model + other_model = self.high_noise_model + + if self.enable_cpu_offload: + if next(other_model.parameters()).device.type == self.device.type: + other_model.to("cpu") + if next(required_model.parameters()).device.type == "cpu": + required_model.to(self.device) + + return required_model + + def _encode_prompt(self, prompt: str, negative_prompt: str) -> tuple[list[torch.Tensor], list[torch.Tensor]]: + context = self.text_model([prompt], self.device) + context_null = self.text_model([negative_prompt], self.device) + return [context[0]], [context_null[0]] + + def _get_default_size(self, image: PIL.Image.Image) -> tuple[int, int]: + aspect_ratio = image.height / image.width + mod_value = 16 + height = round(np.sqrt(self.config.max_area * aspect_ratio)) // mod_value * mod_value + width = round(np.sqrt(self.config.max_area / aspect_ratio)) // mod_value * mod_value + return height, width + + def _load_control_inputs(self, action_path: str | None, num_frames: int) -> tuple[dict | None, int]: + if not action_path: + return None, num_frames + + poses_path = os.path.join(action_path, "poses.npy") + intrinsics_path = os.path.join(action_path, "intrinsics.npy") + if not os.path.exists(poses_path): + raise FileNotFoundError(f"LingBot-World control bundle is missing poses.npy: {poses_path}") + if not os.path.exists(intrinsics_path): + raise FileNotFoundError(f"LingBot-World control bundle is missing intrinsics.npy: {intrinsics_path}") + + c2ws = np.load(poses_path) + len_c2ws = ((len(c2ws) - 1) // 4) * 4 + 1 + num_frames = min(num_frames, len_c2ws) + c2ws = c2ws[:num_frames] + + wasd_action = None + if self.control_type == "act": + action_npy = os.path.join(action_path, "action.npy") + if not os.path.exists(action_npy): + raise FileNotFoundError(f"LingBot-World act control bundle is missing action.npy: {action_npy}") + wasd_action = np.load(action_npy)[:num_frames] + + return {"path": action_path, "c2ws": c2ws, "wasd_action": wasd_action}, num_frames + + def _build_dit_cond_dict( + self, + control_inputs: dict | None, + *, + lat_f: int, + lat_h: int, + lat_w: int, + resized_height: int, + resized_width: int, + ) -> dict | None: + if control_inputs is None: + return None + + action_path = cast(str, control_inputs["path"]) + c2ws = cast(np.ndarray, control_inputs["c2ws"]) + wasd_action = cast(np.ndarray | None, control_inputs["wasd_action"]) + + Ks = torch.from_numpy(np.load(os.path.join(action_path, "intrinsics.npy"))).float() + Ks = get_Ks_transformed( + Ks, + height_org=480, + width_org=832, + height_resize=resized_height, + width_resize=resized_width, + height_final=resized_height, + width_final=resized_width, + ) + Ks = Ks[0] + + len_c2ws = len(c2ws) + c2ws_infer = interpolate_camera_poses( + src_indices=np.linspace(0, len_c2ws - 1, len_c2ws), + src_rot_mat=c2ws[:, :3, :3], + src_trans_vec=c2ws[:, :3, 3], + tgt_indices=np.linspace(0, len_c2ws - 1, int((len_c2ws - 1) // 4) + 1), + ) + c2ws_infer = compute_relative_poses(c2ws_infer, framewise=True) + Ks = Ks.repeat(len(c2ws_infer), 1) + + c2ws_infer = c2ws_infer.to(self.device) + Ks = Ks.to(self.device) + wasd_action_tensor = None + if wasd_action is not None: + wasd_action_tensor = torch.from_numpy(wasd_action[::4]).float().to(self.device) + + only_rays_d = wasd_action_tensor is not None + c2ws_plucker_emb = get_plucker_embeddings( + c2ws_infer, + Ks, + resized_height, + resized_width, + only_rays_d=only_rays_d, + ) + c2ws_plucker_emb = rearrange( + c2ws_plucker_emb, + "f (h c1) (w c2) c -> (f h w) (c c1 c2)", + c1=int(resized_height // lat_h), + c2=int(resized_width // lat_w), + ) + c2ws_plucker_emb = c2ws_plucker_emb[None, ...] + c2ws_plucker_emb = rearrange( + c2ws_plucker_emb, + "b (f h w) c -> b c f h w", + f=lat_f, + h=lat_h, + w=lat_w, + ).to(self.config.param_dtype) + + if wasd_action_tensor is not None: + wasd_action_tensor = wasd_action_tensor[:, None, None, :].repeat(1, resized_height, resized_width, 1) + wasd_action_tensor = rearrange( + wasd_action_tensor, + "f (h c1) (w c2) c -> (f h w) (c c1 c2)", + c1=int(resized_height // lat_h), + c2=int(resized_width // lat_w), + ) + wasd_action_tensor = wasd_action_tensor[None, ...] + wasd_action_tensor = rearrange( + wasd_action_tensor, + "b (f h w) c -> b c f h w", + f=lat_f, + h=lat_h, + w=lat_w, + ).to(self.config.param_dtype) + c2ws_plucker_emb = torch.cat([c2ws_plucker_emb, wasd_action_tensor], dim=1) + + return {"c2ws_plucker_emb": c2ws_plucker_emb.chunk(1, dim=0)} + + def forward(self, req: OmniDiffusionRequest) -> DiffusionOutput: + if len(req.prompts) != 1: + raise ValueError("LingBot-World currently supports exactly one prompt per request.") + if req.sampling_params.num_outputs_per_prompt != 1: + raise ValueError("LingBot-World currently supports num_outputs_per_prompt=1 only.") + + prompt_data = req.prompts[0] + if isinstance(prompt_data, str): + prompt = prompt_data + negative_prompt = self.config.sample_neg_prompt + multi_modal_data = {} + else: + prompt = cast(str, prompt_data.get("prompt")) + negative_prompt = cast(str | None, prompt_data.get("negative_prompt")) or self.config.sample_neg_prompt + multi_modal_data = cast(dict, prompt_data.get("multi_modal_data") or {}) + + raw_image = multi_modal_data.get("image") + if raw_image is None: + raise ValueError("LingBot-World requires an input image.") + if isinstance(raw_image, list): + if len(raw_image) == 0: + raise ValueError("LingBot-World received an empty image list.") + raw_image = raw_image[0] + if isinstance(raw_image, str): + image = PIL.Image.open(raw_image).convert("RGB") + else: + image = cast(PIL.Image.Image, raw_image) + + height = req.sampling_params.height + width = req.sampling_params.width + if height is None or width is None: + default_height, default_width = self._get_default_size(image) + height = default_height if height is None else height + width = default_width if width is None else width + image = image.resize((width, height), PIL.Image.Resampling.LANCZOS) + + if req.sampling_params.guidance_scale_provided: + guidance_low = req.sampling_params.guidance_scale + else: + guidance_low = self.config.sample_guide_scale[0] + guidance_high = req.sampling_params.guidance_scale_2 + if guidance_high is None: + guidance_high = self.config.sample_guide_scale[1] + + num_frames = req.sampling_params.num_frames or self.config.frame_num + action_path = cast(str | None, req.sampling_params.extra_args.get("action_path")) + control_inputs, num_frames = self._load_control_inputs(action_path, num_frames) + num_steps = req.sampling_params.num_inference_steps or self.config.sample_steps + boundary_ratio = req.sampling_params.boundary_ratio + if boundary_ratio is None: + boundary_ratio = self.config.boundary_ratio + flow_shift = cast(float | None, req.sampling_params.extra_args.get("flow_shift")) + if flow_shift is None: + flow_shift = self.config.sample_shift + + img = TF.to_tensor(image).sub_(0.5).div_(0.5).to(self.device) + + aspect_ratio = img.shape[1] / img.shape[2] + lat_h = round( + np.sqrt(self.config.max_area * aspect_ratio) + // self.config.vae_stride[1] + // self.config.patch_size[1] + * self.config.patch_size[1] + ) + lat_w = round( + np.sqrt(self.config.max_area / aspect_ratio) + // self.config.vae_stride[2] + // self.config.patch_size[2] + * self.config.patch_size[2] + ) + resized_height = lat_h * self.config.vae_stride[1] + resized_width = lat_w * self.config.vae_stride[2] + + if num_frames % self.config.vae_stride[0] != 1: + num_frames = (num_frames - 1) // self.config.vae_stride[0] * self.config.vae_stride[0] + 1 + lat_f = (num_frames - 1) // self.config.vae_stride[0] + 1 + max_seq_len = lat_f * lat_h * lat_w // (self.config.patch_size[1] * self.config.patch_size[2]) + + generator = req.sampling_params.generator + if generator is None: + seed = req.sampling_params.seed if req.sampling_params.seed is not None else random.randint(0, 2**31 - 1) + generator = torch.Generator(device=self.device).manual_seed(seed) + + noise = torch.randn( + 16, + lat_f, + lat_h, + lat_w, + dtype=torch.float32, + generator=generator, + device=self.device, + ) + + msk = torch.ones(1, num_frames, lat_h, lat_w, device=self.device) + msk[:, 1:] = 0 + msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1) + msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w) + msk = msk.transpose(1, 2)[0] + + context, context_null = self._encode_prompt(prompt, negative_prompt) + + vae_input = torch.concat( + [ + torch.nn.functional.interpolate( + img[None].cpu(), + size=(resized_height, resized_width), + mode="bicubic", + ).transpose(0, 1), + torch.zeros(3, num_frames - 1, resized_height, resized_width), + ], + dim=1, + ).to(self.device) + y = self.vae_model.encode([vae_input])[0] + y = torch.concat([msk, y]) + + dit_cond_dict = self._build_dit_cond_dict( + control_inputs, + lat_f=lat_f, + lat_h=lat_h, + lat_w=lat_w, + resized_height=resized_height, + resized_width=resized_width, + ) + + self.scheduler.set_timesteps(num_steps, device=self.device, shift=flow_shift) + timesteps = self.scheduler.timesteps + boundary = boundary_ratio * self.config.num_train_timesteps + + arg_c = {"context": context, "seq_len": max_seq_len, "y": [y], "dit_cond_dict": dit_cond_dict} + arg_null = {"context": context_null, "seq_len": max_seq_len, "y": [y], "dit_cond_dict": dit_cond_dict} + + autocast_enabled = not self.od_config.disable_autocast and self.device.type == "cuda" + autocast_context = ( + torch.amp.autocast(self.device.type, dtype=self.config.param_dtype) + if autocast_enabled + else torch.autocast("cpu", enabled=False) + ) + + latent = noise + with autocast_context: + with torch.no_grad(): + for timestep in timesteps: + latent_model_input = [latent.to(self.device)] + timestep_tensor = torch.stack([timestep]).to(self.device) + model = self._prepare_model_for_timestep(timestep, boundary) + current_guidance = guidance_high if timestep.item() >= boundary else guidance_low + + noise_pred_cond = model(latent_model_input, t=timestep_tensor, **arg_c)[0] + noise_pred_uncond = model(latent_model_input, t=timestep_tensor, **arg_null)[0] + noise_pred = noise_pred_uncond + current_guidance * (noise_pred_cond - noise_pred_uncond) + + latent = self.scheduler.step( + noise_pred.unsqueeze(0), + timestep, + latent.unsqueeze(0), + return_dict=False, + generator=generator, + )[0].squeeze(0) + + videos = self.vae_model.decode([latent]) + output = videos[0].unsqueeze(0) if videos[0].ndim == 4 else videos[0] + output_type = req.sampling_params.output_type or "np" + if output_type != "latent": + output = self.video_processor.postprocess_video(output, output_type=output_type) + return DiffusionOutput(output=output) diff --git a/vllm_omni/diffusion/models/lingbot_world/wan_model.py b/vllm_omni/diffusion/models/lingbot_world/wan_model.py new file mode 100644 index 00000000000..abd056b0fe4 --- /dev/null +++ b/vllm_omni/diffusion/models/lingbot_world/wan_model.py @@ -0,0 +1,576 @@ +import math + +import torch +import torch.nn as nn +import torch.nn.functional as torch_F +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.modeling_utils import ModelMixin +from einops import rearrange + +from vllm_omni.diffusion.attention.layer import Attention + +__all__ = ["WanModel"] + + +def sinusoidal_embedding_1d(dim, position): + # preprocess + assert dim % 2 == 0 + half = dim // 2 + position = position.type(torch.float64) + + # calculation + sinusoid = torch.outer(position, torch.pow(10000, -torch.arange(half).to(position).div(half))) + x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1) + return x + + +@torch.amp.autocast("cuda", enabled=False) +def rope_params(max_seq_len, dim, theta=10000): + assert dim % 2 == 0 + freqs = torch.outer( + torch.arange(max_seq_len), 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float64).div(dim)) + ) + freqs = torch.polar(torch.ones_like(freqs), freqs) + return freqs + + +@torch.amp.autocast("cuda", enabled=False) +def rope_apply(x, grid_sizes, freqs): + n, c = x.size(2), x.size(3) // 2 + + # split freqs + freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1) + + # loop over samples + output = [] + for i, (f, h, w) in enumerate(grid_sizes.tolist()): + seq_len = f * h * w + + # precompute multipliers + x_i = torch.view_as_complex(x[i, :seq_len].to(torch.float64).reshape(seq_len, n, -1, 2)) + freqs_i = torch.cat( + [ + freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1), + freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), + freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1), + ], + dim=-1, + ).reshape(seq_len, 1, -1) + + # apply rotary embedding + x_i = torch.view_as_real(x_i * freqs_i).flatten(2) + x_i = torch.cat([x_i, x[i, seq_len:]]) + + # append to collection + output.append(x_i) + return torch.stack(output).float() + + +class WanRMSNorm(nn.Module): + def __init__(self, dim, eps=1e-5): + super().__init__() + self.dim = dim + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def forward(self, x): + r""" + Args: + x(Tensor): Shape [B, L, C] + """ + return self._norm(x.float()).type_as(x) * self.weight + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps) + + +class WanLayerNorm(nn.LayerNorm): + def __init__(self, dim, eps=1e-6, elementwise_affine=False): + super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps) + + def forward(self, x): + r""" + Args: + x(Tensor): Shape [B, L, C] + """ + return super().forward(x.float()).type_as(x) + + +class WanSelfAttention(nn.Module): + def __init__(self, dim, num_heads, window_size=(-1, -1), qk_norm=True, eps=1e-6): + assert dim % num_heads == 0 + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.window_size = window_size + self.qk_norm = qk_norm + self.eps = eps + + # layers + self.q = nn.Linear(dim, dim) + self.k = nn.Linear(dim, dim) + self.v = nn.Linear(dim, dim) + self.o = nn.Linear(dim, dim) + self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity() + self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity() + self.attn = Attention( + num_heads=self.num_heads, + head_size=self.head_dim, + num_kv_heads=self.num_heads, + softmax_scale=1.0 / (self.head_dim**0.5), + causal=False, + ) + + def qkv_fn(self, x): + b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim + q = self.norm_q(self.q(x)).view(b, s, n, d) + k = self.norm_k(self.k(x)).view(b, s, n, d) + v = self.v(x).view(b, s, n, d) + return q, k, v + + def forward(self, x, seq_lens, grid_sizes, freqs): + r""" + Args: + x(Tensor): Shape [B, L, num_heads, C / num_heads] + seq_lens(Tensor): Shape [B] + grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W) + freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2] + """ + q, k, v = self.qkv_fn(x) + x = self.attn( + rope_apply(q, grid_sizes, freqs), + rope_apply(k, grid_sizes, freqs), + v, + ) + + # output + x = x.flatten(2) + x = self.o(x) + return x + + +class WanCrossAttention(WanSelfAttention): + def qkv_fn(self, x, context): + b, n, d = x.size(0), self.num_heads, self.head_dim + q = self.norm_q(self.q(x)).view(b, -1, n, d) + k = self.norm_k(self.k(context)).view(b, -1, n, d) + v = self.v(context).view(b, -1, n, d) + return q, k, v + + def forward(self, x, context, context_lens): + r""" + Args: + x(Tensor): Shape [B, L1, C] + context(Tensor): Shape [B, L2, C] + context_lens(Tensor): Shape [B] + """ + q, k, v = self.qkv_fn(x, context) + x = self.attn(q, k, v) + + # output + x = x.flatten(2) + x = self.o(x) + return x + + +class WanAttentionBlock(nn.Module): + def __init__(self, dim, ffn_dim, num_heads, window_size=(-1, -1), qk_norm=True, cross_attn_norm=False, eps=1e-6): + super().__init__() + self.dim = dim + self.ffn_dim = ffn_dim + self.num_heads = num_heads + self.window_size = window_size + self.qk_norm = qk_norm + self.cross_attn_norm = cross_attn_norm + self.eps = eps + + # layers + self.norm1 = WanLayerNorm(dim, eps) + self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm, eps) + self.norm3 = WanLayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity() + self.cross_attn = WanCrossAttention(dim, num_heads, (-1, -1), qk_norm, eps) + self.norm2 = WanLayerNorm(dim, eps) + self.ffn = nn.Sequential(nn.Linear(dim, ffn_dim), nn.GELU(approximate="tanh"), nn.Linear(ffn_dim, dim)) + + # modulation + self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5) + + self.cam_injector_layer1 = nn.Linear(dim, dim) + self.cam_injector_layer2 = nn.Linear(dim, dim) + self.cam_scale_layer = nn.Linear(dim, dim) + self.cam_shift_layer = nn.Linear(dim, dim) + + def forward( + self, + x, + e, + seq_lens, + grid_sizes, + freqs, + context, + context_lens, + dit_cond_dict=None, + ): + r""" + Args: + x(Tensor): Shape [B, L, C] + e(Tensor): Shape [B, L1, 6, C] + seq_lens(Tensor): Shape [B], length of each sequence in batch + grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W) + freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2] + """ + assert e.dtype == torch.float32 + with torch.amp.autocast("cuda", dtype=torch.float32): + e = (self.modulation.unsqueeze(0) + e).chunk(6, dim=2) + assert e[0].dtype == torch.float32 + + # self-attention + y = self.self_attn(self.norm1(x).float() * (1 + e[1].squeeze(2)) + e[0].squeeze(2), seq_lens, grid_sizes, freqs) + with torch.amp.autocast("cuda", dtype=torch.float32): + x = x + y * e[2].squeeze(2) + + # cam injection (only if dit_cond_dict is provided and contains c2ws_plucker_emb) + if dit_cond_dict is not None and "c2ws_plucker_emb" in dit_cond_dict: + c2ws_plucker_emb = dit_cond_dict["c2ws_plucker_emb"] + c2ws_hidden_states = self.cam_injector_layer2(torch_F.silu(self.cam_injector_layer1(c2ws_plucker_emb))) + c2ws_hidden_states = c2ws_hidden_states + c2ws_plucker_emb + cam_scale = self.cam_scale_layer(c2ws_hidden_states) + cam_shift = self.cam_shift_layer(c2ws_hidden_states) + x = (1.0 + cam_scale) * x + cam_shift + + # cross-attention & ffn function + def cross_attn_ffn(x, context, context_lens, e): + x = x + self.cross_attn(self.norm3(x), context, context_lens) + y = self.ffn(self.norm2(x).float() * (1 + e[4].squeeze(2)) + e[3].squeeze(2)) + with torch.amp.autocast("cuda", dtype=torch.float32): + x = x + y * e[5].squeeze(2) + return x + + x = cross_attn_ffn(x, context, context_lens, e) + return x + + +class Head(nn.Module): + def __init__(self, dim, out_dim, patch_size, eps=1e-6): + super().__init__() + self.dim = dim + self.out_dim = out_dim + self.patch_size = patch_size + self.eps = eps + + # layers + out_dim = math.prod(patch_size) * out_dim + self.norm = WanLayerNorm(dim, eps) + self.head = nn.Linear(dim, out_dim) + + # modulation + self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5) + + def forward(self, x, e): + r""" + Args: + x(Tensor): Shape [B, L1, C] + e(Tensor): Shape [B, L1, C] + """ + assert e.dtype == torch.float32 + with torch.amp.autocast("cuda", dtype=torch.float32): + e = (self.modulation.unsqueeze(0) + e.unsqueeze(2)).chunk(2, dim=2) + x = self.head(self.norm(x) * (1 + e[1].squeeze(2)) + e[0].squeeze(2)) + return x + + +class WanModel(ModelMixin, ConfigMixin): + r""" + Wan diffusion backbone supporting both text-to-video and image-to-video. + """ + + ignore_for_config = ["patch_size", "cross_attn_norm", "qk_norm", "text_dim", "window_size"] + _no_split_modules = ["WanAttentionBlock"] + + @register_to_config + def __init__( + self, + model_type="t2v", + control_type="cam", + patch_size=(1, 2, 2), + text_len=512, + in_dim=16, + dim=2048, + ffn_dim=8192, + freq_dim=256, + text_dim=4096, + out_dim=16, + num_heads=16, + num_layers=32, + window_size=(-1, -1), + qk_norm=True, + cross_attn_norm=True, + eps=1e-6, + ): + r""" + Initialize the diffusion model backbone. + + Args: + model_type (`str`, *optional*, defaults to 't2v'): + Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video) + patch_size (`tuple`, *optional*, defaults to (1, 2, 2)): + 3D patch dimensions for video embedding (t_patch, h_patch, w_patch) + text_len (`int`, *optional*, defaults to 512): + Fixed length for text embeddings + in_dim (`int`, *optional*, defaults to 16): + Input video channels (C_in) + dim (`int`, *optional*, defaults to 2048): + Hidden dimension of the transformer + ffn_dim (`int`, *optional*, defaults to 8192): + Intermediate dimension in feed-forward network + freq_dim (`int`, *optional*, defaults to 256): + Dimension for sinusoidal time embeddings + text_dim (`int`, *optional*, defaults to 4096): + Input dimension for text embeddings + out_dim (`int`, *optional*, defaults to 16): + Output video channels (C_out) + num_heads (`int`, *optional*, defaults to 16): + Number of attention heads + num_layers (`int`, *optional*, defaults to 32): + Number of transformer blocks + window_size (`tuple`, *optional*, defaults to (-1, -1)): + Window size for local attention (-1 indicates global attention) + qk_norm (`bool`, *optional*, defaults to True): + Enable query/key normalization + cross_attn_norm (`bool`, *optional*, defaults to False): + Enable cross-attention normalization + eps (`float`, *optional*, defaults to 1e-6): + Epsilon value for normalization layers + """ + + super().__init__() + + assert model_type in ["t2v", "i2v", "ti2v", "s2v"] + self.model_type = model_type + + self.patch_size = patch_size + self.text_len = text_len + self.in_dim = in_dim + self.dim = dim + self.ffn_dim = ffn_dim + self.freq_dim = freq_dim + self.text_dim = text_dim + self.out_dim = out_dim + self.num_heads = num_heads + self.num_layers = num_layers + self.window_size = window_size + self.qk_norm = qk_norm + self.cross_attn_norm = cross_attn_norm + self.eps = eps + + if control_type == "cam": + control_dim = 6 + elif control_type == "act": + control_dim = 7 + + # embeddings + self.patch_embedding = nn.Conv3d(in_dim, dim, kernel_size=patch_size, stride=patch_size) + self.patch_embedding_wancamctrl = nn.Linear( + control_dim * 64 * patch_size[0] * patch_size[1] * patch_size[2], dim + ) + self.c2ws_hidden_states_layer1 = nn.Linear(dim, dim) + self.c2ws_hidden_states_layer2 = nn.Linear(dim, dim) + self.text_embedding = nn.Sequential(nn.Linear(text_dim, dim), nn.GELU(approximate="tanh"), nn.Linear(dim, dim)) + + self.time_embedding = nn.Sequential(nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim)) + self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6)) + + # blocks + self.blocks = nn.ModuleList( + [ + WanAttentionBlock(dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps) + for _ in range(num_layers) + ] + ) + + # head + self.head = Head(dim, out_dim, patch_size, eps) + + # buffers (don't use register_buffer otherwise dtype will be changed in to()) + assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0 + d = dim // num_heads + self.freqs = torch.cat( + [rope_params(1024, d - 4 * (d // 6)), rope_params(1024, 2 * (d // 6)), rope_params(1024, 2 * (d // 6))], + dim=1, + ) + + # initialize weights + self.init_weights() + + def forward( + self, + x, + t, + context, + seq_len, + y=None, + dit_cond_dict=None, + ): + r""" + Forward pass through the diffusion model + + Args: + x (List[Tensor]): + List of input video tensors, each with shape [C_in, F, H, W] + t (Tensor): + Diffusion timesteps tensor of shape [B] + context (List[Tensor]): + List of text embeddings each with shape [L, C] + seq_len (`int`): + Maximum sequence length for positional encoding + y (List[Tensor], *optional*): + Conditional video inputs for image-to-video mode, same shape as x + + Returns: + List[Tensor]: + List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8] + """ + if self.model_type == "i2v": + assert y is not None + # params + device = self.patch_embedding.weight.device + if self.freqs.device != device: + self.freqs = self.freqs.to(device) + + if y is not None: + x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)] + + # embeddings + x = [self.patch_embedding(u.unsqueeze(0)) for u in x] + grid_sizes = torch.stack([torch.tensor(u.shape[2:], dtype=torch.long) for u in x]) + x = [u.flatten(2).transpose(1, 2) for u in x] + seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long) + assert seq_lens.max() <= seq_len + x = torch.cat([torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1) for u in x]) + + # time embeddings + if t.dim() == 1: + t = t.expand(t.size(0), seq_len) + with torch.amp.autocast("cuda", dtype=torch.float32): + bt = t.size(0) + t = t.flatten() + e = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, t).unflatten(0, (bt, seq_len)).float()) + e0 = self.time_projection(e).unflatten(2, (6, self.dim)) + assert e.dtype == torch.float32 and e0.dtype == torch.float32 + + # context + context_lens = None + context = self.text_embedding( + torch.stack([torch.cat([u, u.new_zeros(self.text_len - u.size(0), u.size(1))]) for u in context]) + ) + + # cam + if dit_cond_dict is not None and "c2ws_plucker_emb" in dit_cond_dict: + c2ws_plucker_emb = dit_cond_dict["c2ws_plucker_emb"] + c2ws_plucker_emb = [ + rearrange( + i, + "1 c (f c1) (h c2) (w c3) -> 1 (f h w) (c c1 c2 c3)", + c1=self.patch_size[0], + c2=self.patch_size[1], + c3=self.patch_size[2], + ) + for i in c2ws_plucker_emb + ] + c2ws_plucker_emb = torch.cat(c2ws_plucker_emb, dim=1) # [1, (L1+...+Ln), C] + c2ws_plucker_emb = self.patch_embedding_wancamctrl(c2ws_plucker_emb) + c2ws_hidden_states = self.c2ws_hidden_states_layer2( + torch_F.silu(self.c2ws_hidden_states_layer1(c2ws_plucker_emb)) + ) + dit_cond_dict = dict(dit_cond_dict) + dit_cond_dict["c2ws_plucker_emb"] = c2ws_plucker_emb + c2ws_hidden_states + + # arguments + kwargs = dict( + e=e0, + seq_lens=seq_lens, + grid_sizes=grid_sizes, + freqs=self.freqs, + context=context, + context_lens=context_lens, + dit_cond_dict=dit_cond_dict, + ) + + for block in self.blocks: + x = block(x, **kwargs) + + # head + x = self.head(x, e) + + # unpatchify + x = self.unpatchify(x, grid_sizes) + return [u.float() for u in x] + + def unpatchify(self, x, grid_sizes): + r""" + Reconstruct video tensors from patch embeddings. + + Args: + x (List[Tensor]): + List of patchified features, each with shape [L, C_out * prod(patch_size)] + grid_sizes (Tensor): + Original spatial-temporal grid dimensions before patching, + shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches) + + Returns: + List[Tensor]: + Reconstructed video tensors with shape [C_out, F, H / 8, W / 8] + """ + + c = self.out_dim + out = [] + for u, v in zip(x, grid_sizes.tolist()): + u = u[: math.prod(v)].view(*v, *self.patch_size, c) + u = torch.einsum("fhwpqrc->cfphqwr", u) + u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)]) + out.append(u) + return out + + def init_weights(self): + r""" + Initialize model parameters using Xavier initialization. + """ + + # basic init + for m in self.modules(): + if isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.zeros_(m.bias) + + # init embeddings + nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1)) + for m in self.text_embedding.modules(): + if isinstance(m, nn.Linear): + nn.init.normal_(m.weight, std=0.02) + for m in self.time_embedding.modules(): + if isinstance(m, nn.Linear): + nn.init.normal_(m.weight, std=0.02) + + # init output layer + nn.init.zeros_(self.head.head.weight) + + # init cam control layers + nn.init.xavier_uniform_(self.patch_embedding_wancamctrl.weight) + nn.init.zeros_(self.patch_embedding_wancamctrl.bias) + nn.init.xavier_uniform_(self.c2ws_hidden_states_layer1.weight) + nn.init.zeros_(self.c2ws_hidden_states_layer1.bias) + nn.init.xavier_uniform_(self.c2ws_hidden_states_layer2.weight) + nn.init.zeros_(self.c2ws_hidden_states_layer2.bias) + + # init cam injector layers in blocks + for block in self.blocks: + nn.init.xavier_uniform_(block.cam_injector_layer1.weight) + nn.init.zeros_(block.cam_injector_layer1.bias) + nn.init.xavier_uniform_(block.cam_injector_layer2.weight) + nn.init.zeros_(block.cam_injector_layer2.bias) + nn.init.xavier_uniform_(block.cam_scale_layer.weight) + nn.init.zeros_(block.cam_scale_layer.bias) + nn.init.xavier_uniform_(block.cam_shift_layer.weight) + nn.init.zeros_(block.cam_shift_layer.bias) diff --git a/vllm_omni/diffusion/registry.py b/vllm_omni/diffusion/registry.py index 77e7885c01d..ee9474e9335 100644 --- a/vllm_omni/diffusion/registry.py +++ b/vllm_omni/diffusion/registry.py @@ -142,6 +142,11 @@ "pipeline_dreamid_omni", "DreamIDOmniPipeline", ), + "LingbotWorldPipeline": ( + "lingbot_world", + "pipeline_lingbot_world", + "LingbotWorldPipeline", + ), } diff --git a/vllm_omni/entrypoints/async_omni_diffusion.py b/vllm_omni/entrypoints/async_omni_diffusion.py index 52bd6031c64..ed34a92645f 100644 --- a/vllm_omni/entrypoints/async_omni_diffusion.py +++ b/vllm_omni/entrypoints/async_omni_diffusion.py @@ -104,8 +104,14 @@ def __init__( od_config.model_class_name = config_dict.get("_class_name", None) od_config.update_multimodal_support() - tf_config_dict = get_hf_file_to_dict("transformer/config.json", od_config.model) - od_config.tf_model_config = TransformerConfig.from_dict(tf_config_dict) + if od_config.model_class_name in ("DreamIDOmniPipeline", "LingbotWorldPipeline"): + # Custom pipelines such as DreamID-Omni and Lingbot-World + # may expose a minimal model_index.json without a + # diffusers-style transformer/ subdirectory. + od_config.tf_model_config = TransformerConfig() + else: + tf_config_dict = get_hf_file_to_dict("transformer/config.json", od_config.model) + od_config.tf_model_config = TransformerConfig.from_dict(tf_config_dict) else: raise FileNotFoundError("model_index.json not found") except (AttributeError, OSError, ValueError, FileNotFoundError):