diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index d26a579edfd..168b91d7343 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -73,5 +73,6 @@ th { |`DyninOmniForConditionalGeneration` | Dynin-Omni | `snu-aidas/Dynin-Omni` | ✅︎ | | | | | `ErnieImagePipeline` | ERNIE-Image | `baidu/ERNIE-Image`, `baidu/ERNIE-Image-Turbo` | ✅︎ | ✅︎ | ✅︎ | ✅︎ | |`HiDreamImagePipeline` | HiDream-I1-Full | `HiDream-ai/HiDream-I1-Full` | ✅︎ | ✅︎ | | | +|`LingbotWorldFastPipeline`| Lingbot-World-Fast | `robbyant/lingbot-world-fast`|✅︎ | | | | ✅︎ indicates the model is supported on that backend. Empty cells mean not listed as supported on that backend. diff --git a/docs/user_guide/examples/offline_inference/lingbot_world_fast.md b/docs/user_guide/examples/offline_inference/lingbot_world_fast.md new file mode 100644 index 00000000000..2669f6e6ecd --- /dev/null +++ b/docs/user_guide/examples/offline_inference/lingbot_world_fast.md @@ -0,0 +1,49 @@ +# Lingbot World Fast Offline Inference + +Lingbot World Fast is an autoregressive diffusion model that uses a reference image, a text prompt and a set of camera positions to generate a video. + +## Video Generation + +First, download the model weights using `examples/offline_inference/lingbot_world_fast/download_lingbot_world_fast.py`. + +The simplest way to run offline generation is to use the script on `examples/offline_inference/lingbot_world_fast/end2end.py`. The core of this script is done by: + +```python +from vllm_omni.entrypoints.omni import Omni + +if __name__ == "__main__": + omni = Omni(model="lingbot_world/lingbot-world-base-cam/Lingbot-World-Fast", model_class_name="LingbotWorldFastPipeline") + outputs = omni.generate( + { + "prompt": "A journey along the Great Wall of China", + "multi_modal_data": { + "image": "input.png", + "camera": { + "poses": np.load("path/to/poses.npy") + "intrinsics": np.load("path/to/intrinsics.npy") + } + }, + }, + OmniDiffusionSamplingParams( + height=height, + width=width, + num_frames=num_frames, + frame_rate=fps, + ), + ) + export_to_video(outputs[0], "output.png") +``` + +## Generation Parameters + +| Parameter | Type | Default | Description | +| --------------------- | ----- | ------- | ----------------------------------- | +| `height` | int | None (computed from image) | Image height in pixels | +| `width` | int | None (computed from image) | Image width in pixels | +| `num_frames` | int | 81 | Number of frames to generate | +| `fps` | int | 16 | Frames per second | +| `seed` | int | 42 | Optional random seed | +| `prompt` | str | "" | Text prompt | +| `negative_prompt` | str | None | Negative prompt | +| `image` | str | Required | Path to reference image | +|`camera-path` | str | Required | Path to folder with `poses.npy` and `intrinsics.npy`| diff --git a/docs/user_guide/examples/online_serving/lingbot_world_fast.md b/docs/user_guide/examples/online_serving/lingbot_world_fast.md new file mode 100644 index 00000000000..106f76a6a05 --- /dev/null +++ b/docs/user_guide/examples/online_serving/lingbot_world_fast.md @@ -0,0 +1,51 @@ +# Lingbot World Fast Offline Inference + +Lingbot World Fast is an autoregressive diffusion model that uses a reference image, a text prompt and a set of camera positions to generate a video. The online serving model of this model adds a feature that is not implemented in the original model: video extension. + +## Quickstart + +The easiest way to launch a server running the Lingbot World Fast model is by using the script `examples/online_serving/lingbot_world_fast/run_server.sh`. + +Once the server is launched, the client can send requests to its websocket at `/v1/realtime/world/camera`. The easiest way to interact with the server is using the script `examples/online_serving/lingbot_world_fast/openai_client.py`. Its command line options are described below. + +| Parameter | Type | Default | Description | +| --------------------- | ----- | ------- | ----------------------------------- | +| `height` | int | None (computed from image) | Image height in pixels | +| `width` | int | None (computed from image) | Image width in pixels | +| `num_frames` | int | 81 | Number of frames to generate | +| `fps` | int | 16 | Frames per second | +| `seed` | int | 42 | Optional random seed | +| `prompt` | str | "" | Text prompt | +| `negative_prompt` | str | None | Negative prompt | +| `image` | str | Required | Path to reference image | +|`camera-path` | str | Required | Path to folder with `poses.npy` and `intrinsics.npy`| +| `num-calls` | int | 1 | Makes an additional `num-calls - 1` video extension calls with `num_frames` frames | +| `num-skip-frames` | int | 4 | Extension calls have artifacts on the first couple frames. Discard them. | +| `session-id` | str | None | Session id to control whether to trigger a video extension call | + +## Video Extension + +The idea of video extension is to allow the user to generate further frames for the same video efficiently. This is done by the vllm-omni implementation by storing the KV-cache of the generated video by default. This way, if the next request uses the same session-id, the pipeline will enter extension mode. So, the newly generated frames will use the previously generated frames as context. This is done by storing the KV-cache as mentioned above. No frame information, whether in latent space or RGB values, is kept in the server. + +This feature is limited by the fact that the model has not been trained to perform this task. So, the steering capacity of the user is limited. Namely, the reference image and changes to the text prompt are ignored. The best tool the user has is to provide camera positions. In the end, video extension is more of a demonstration of the power and features of VLLM-Omni than of Lingbot World in itself. + +## API + +The server uses a websocket endpoint located at `/v1/realtime/world/camera`. It makes available two tasks: `infer` and `reset` which can be controlled by the "endpoint" key of the request. + +By default, the server uses the `infer` task, which checks the `session-id` field and compares it to the one used on the last infer call. If they are the same, it triggers an extension call at the pipeline level. Note that only the KV-cache of the last request is stored to mitigate Out of Memory problems at the GPU level. Otherwise, it generates the video from scratch. Notice that when doing an extension task, no reference image should be provided (it would be ignored anyway). + +The `reset` endpoint does not immediately evict the KV cache in the GPU, but instead it forces a reset on the next `infer` call independently of the value of `session-id`. + +The endpoint sends the resulting frames in groups of 4 to mitigate package loss problems. It is the client's role to concatenate the different frames to obtain the final video. + +## Example materials + +??? abstract "run_server.sh" + ``````sh + --8<-- "examples/online_serving/lingbot_world_fast/run_server.sh" + `````` +??? abstract "openai_client.py" + ``````sh + --8<-- "examples/online_serving/lingbot_world_fast/openai_client.py" + `````` diff --git a/examples/offline_inference/lingbot_world_fast/download_lingbot_world_fast.py b/examples/offline_inference/lingbot_world_fast/download_lingbot_world_fast.py new file mode 100644 index 00000000000..255e1073815 --- /dev/null +++ b/examples/offline_inference/lingbot_world_fast/download_lingbot_world_fast.py @@ -0,0 +1,79 @@ +import argparse +import json +import os +import tempfile +import time +from pathlib import Path + +from huggingface_hub import snapshot_download + +DEPENDENCY_REPO = "https://github.com/Robbyant/lingbot-world" +DEPENDENCY_BRANCH = "main" +CACHE_DIR = Path(tempfile.gettempdir()) / "vllm-omni-dependency" +LOCK_FILE = CACHE_DIR / ".install.lock" +DEPENDENCY_DIR = CACHE_DIR / "Lingbot-World" + + +def timed_download(repo_id: str, local_dir: str, allow_patterns: list | None = None): + """Download files from HF repo and log time + destination.""" + if os.path.exists(local_dir): + print(f"Directory {local_dir} already exists. Skipping download.") + return + print(f"Starting download from {repo_id} into {local_dir}") + start_time = time.time() + + snapshot_download( + repo_id=repo_id, + local_dir=local_dir, + allow_patterns=allow_patterns, + ) + + elapsed = time.time() - start_time + print(f"✅ Finished downloading {repo_id} in {elapsed:.2f} seconds. Files saved at: {local_dir}") + + +def main(output_dir: str): + lingbot_base_dir = os.path.join(output_dir, "lingbot-world-base-cam") + + # Base Model + timed_download( + repo_id="robbyant/lingbot-world-base-cam", + local_dir=lingbot_base_dir, + allow_patterns=["google/*", "models_t5_umt5-xxl-enc-bf16.pth", "Wan2.1_VAE.pth"], + ) + + lingbot_fast_dir = os.path.join(lingbot_base_dir, "Lingbot-World-Fast") + + timed_download(repo_id="robbyant/lingbot-world-fast", local_dir=lingbot_fast_dir) + + # Lingbot World does not come with config.json which is required by diffusers + config = { + "_class_name": "WanModel", + "_diffusers_version": "0.33.0", + "dim": 5120, + "eps": 1e-06, + "ffn_dim": 13824, + "freq_dim": 256, + "in_dim": 36, + "model_type": "lingbot_world_fast", + "num_heads": 40, + "num_layers": 40, + "out_dim": 16, + "text_len": 512, + } + + config_path = os.path.join(output_dir, "lingbot-world-base-cam", "Lingbot-World-Fast", "config.json") + + with open(config_path, "w", encoding="utf-8") as f: + json.dump(config, f, indent=2) + + print(f"config.json created at {config_path}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Download models from Hugging Face") + parser.add_argument( + "--output-dir", type=str, default="./lingbot_world", help="Base directory to save downloaded models" + ) + args = parser.parse_args() + main(args.output_dir) diff --git a/examples/offline_inference/lingbot_world_fast/end2end.py b/examples/offline_inference/lingbot_world_fast/end2end.py new file mode 100644 index 00000000000..44d3df40c04 --- /dev/null +++ b/examples/offline_inference/lingbot_world_fast/end2end.py @@ -0,0 +1,295 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +""" +Image-Camera to Video generation example using Lingbot World Fast + +Usage example: + python end2end.py --model path/to/lingbot-fast --image path/to/image --camera-path path/to/camera + --width 832 --height 480 --prompt "Walk in the Great Wall of China" --output output.mp4 +""" + +import argparse +import os +import time +from pathlib import Path + +import numpy as np +import PIL.Image +import torch + +from vllm_omni.entrypoints.omni import Omni +from vllm_omni.inputs.data import OmniDiffusionSamplingParams +from vllm_omni.outputs import OmniRequestOutput +from vllm_omni.platforms import current_omni_platform + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Generate a video from an image (Wan2.2, LTX2, HunyuanVideo-1.5).") + parser.add_argument( + "--model", + default="lingbot_world/lingbot-world-base-cam/Lingbot-World-Fast", + help="Diffusers I2V model ID or local path (Wan2.2 or HunyuanVideo-1.5).", + ) + parser.add_argument("--image", required=True, help="Path to input image.") + parser.add_argument("--camera-path", default=None, help="Path to input camera positions") + parser.add_argument("--prompt", default="", help="Text prompt describing the desired motion.") + parser.add_argument("--negative-prompt", default="", help="Negative prompt.") + parser.add_argument("--seed", type=int, default=42, help="Random seed.") + parser.add_argument( + "--height", type=int, default=None, help="Video height (auto-calculated from image if not set)." + ) + parser.add_argument("--width", type=int, default=None, help="Video width (auto-calculated from image if not set).") + parser.add_argument("--num-frames", type=int, default=81, help="Number of frames.") + parser.add_argument("--output", type=str, default="i2v_output.mp4", help="Path to save the video (mp4).") + parser.add_argument("--fps", type=int, default=16, help="Frames per second for the output video.") + parser.add_argument( + "--enable-diffusion-pipeline-profiler", + action="store_true", + help="Enable diffusion pipeline profiler to display stage durations.", + ) + return parser.parse_args() + + +def calculate_dimensions( + image: PIL.Image.Image, + max_area: int = 480 * 832, + mod_value: int = 16, +) -> tuple[int, int]: + """Calculate output dimensions maintaining aspect ratio.""" + aspect_ratio = image.height / image.width + + height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value + width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value + + return height, width + + +def main(): + args = parse_args() + generator = torch.Generator(device=current_omni_platform.device_type).manual_seed(args.seed) + model_class_name = "LingbotWorldFastPipeline" + + # Load input image + image = PIL.Image.open(args.image).convert("RGB") + + num_inference_steps = 40 + + # Calculate dimensions if not provided + height = args.height + width = args.width + if height is None or width is None: + max_area = 480 * 832 + mod_value = 16 + calc_height, calc_width = calculate_dimensions(image, max_area=max_area, mod_value=mod_value) + height = height or calc_height + width = width or calc_width + + # Resize image to target dimensions + image = image.resize((width, height), PIL.Image.Resampling.LANCZOS) + + # Check if profiling is requested via environment variable + profiler_enabled = bool(os.getenv("VLLM_TORCH_PROFILER_DIR")) + + omni = Omni( + model=args.model, + parallel_config=None, + model_class_name=model_class_name, + stage_init_timeout=6000, + init_timeout=6000, + ) + + if profiler_enabled: + print("[Profiler] Starting profiling...") + omni.start_profile() + + # Print generation configuration + print(f"\n{'=' * 60}") + print("Generation Configuration:") + print(f" Model: {args.model}") + print(f" Inference steps: {num_inference_steps}") + print(f" Frames: {args.num_frames}") + print(f" Video size: {width}x{height}") + print(f"{'=' * 60}\n") + + # omni.generate() returns Generator[OmniRequestOutput, None, None] + + multi_modal_data = {"image": image} + + if args.camera_path is not None: + poses = np.load(os.path.join(args.camera_path, "poses.npy")) + intrinsics = np.load(os.path.join(args.camera_path, "intrinsics.npy")) + + multi_modal_data["camera"] = {"poses": poses, "intrinsics": intrinsics} + + generation_start = time.perf_counter() + frames = omni.generate( + { + "prompt": args.prompt, + "negative_prompt": args.negative_prompt, + "multi_modal_data": multi_modal_data, + }, + OmniDiffusionSamplingParams( + height=height, + width=width, + generator=generator, + num_frames=args.num_frames, + frame_rate=args.fps, + extra_args={"session_id": "offline_generation"}, + ), + ) + generation_end = time.perf_counter() + generation_time = generation_end - generation_start + + # Print profiling results + print(f"Total generation time: {generation_time:.4f} seconds ({generation_time * 1000:.2f} ms)") + + if isinstance(frames, list): + frames = frames[0] if frames else None + + if isinstance(frames, OmniRequestOutput): + if frames.final_output_type != "image": + raise ValueError( + f"Unexpected output type '{frames.final_output_type}', expected 'image' for video generation." + ) + if frames.is_pipeline_output and frames.request_output is not None: + inner_output = frames.request_output + if isinstance(inner_output, OmniRequestOutput): + frames = inner_output + if isinstance(frames, OmniRequestOutput): + if frames.images: + if len(frames.images) == 1 and isinstance(frames.images[0], tuple) and len(frames.images[0]) == 2: + frames = frames.images[0] + elif len(frames.images) == 1 and isinstance(frames.images[0], dict): + frames = frames.images[0].get("frames") or frames.images[0].get("video") + else: + frames = frames.images + else: + raise ValueError("No video frames found in OmniRequestOutput.") + + if isinstance(frames, list) and frames: + first_item = frames[0] + if isinstance(first_item, tuple) and len(first_item) == 2: + frames = first_item + elif isinstance(first_item, dict): + frames = first_item.get("frames") or first_item.get("video") + elif isinstance(first_item, list): + frames = first_item + + if isinstance(frames, tuple) and len(frames) == 2: + frames = frames + elif isinstance(frames, dict): + frames = frames.get("frames") or frames.get("video") + + if frames is None: + raise ValueError("No video frames found in output.") + + output_path = Path(args.output) + output_path.parent.mkdir(parents=True, exist_ok=True) + + try: + from diffusers.utils import export_to_video + except ImportError: + raise ImportError("diffusers is required for export_to_video.") + + def _normalize_frame(frame): + if isinstance(frame, torch.Tensor): + frame_tensor = frame.detach().cpu() + if frame_tensor.dim() == 4 and frame_tensor.shape[0] == 1: + frame_tensor = frame_tensor[0] + if frame_tensor.dim() == 3 and frame_tensor.shape[0] in (3, 4): + frame_tensor = frame_tensor.permute(1, 2, 0) + if frame_tensor.is_floating_point(): + frame_tensor = frame_tensor.clamp(-1, 1) * 0.5 + 0.5 + return frame_tensor.float().numpy() + if isinstance(frame, np.ndarray): + frame_array = frame + if frame_array.ndim == 4 and frame_array.shape[0] == 1: + frame_array = frame_array[0] + if np.issubdtype(frame_array.dtype, np.integer): + frame_array = frame_array.astype(np.float32) / 255.0 + return frame_array + try: + from PIL import Image + except ImportError: + Image = None + if Image is not None and isinstance(frame, Image.Image): + return np.asarray(frame).astype(np.float32) / 255.0 + return frame + + def _ensure_frame_list(video_array): + if isinstance(video_array, list): + if len(video_array) == 0: + return video_array + first_item = video_array[0] + if isinstance(first_item, np.ndarray): + if first_item.ndim == 5: + return list(first_item[0]) + if first_item.ndim == 4: + if len(video_array) == 1: + return list(first_item) + return list(first_item) + if first_item.ndim == 3: + return video_array + return video_array + if isinstance(video_array, np.ndarray): + if video_array.ndim == 5: + return list(video_array[0]) + if video_array.ndim == 4: + return list(video_array) + if video_array.ndim == 3: + return [video_array] + return video_array + + # frames may be np.ndarray, torch.Tensor, or list of tensors/arrays/images + # export_to_video expects a list of frames with values in [0, 1] + if isinstance(frames, torch.Tensor): + video_tensor = frames.detach().cpu() + if video_tensor.dim() == 5: + if video_tensor.shape[1] in (3, 4): + video_tensor = video_tensor[0].permute(1, 2, 3, 0) + else: + video_tensor = video_tensor[0] + elif video_tensor.dim() == 4 and video_tensor.shape[0] in (3, 4): + video_tensor = video_tensor.permute(1, 2, 3, 0) + if video_tensor.is_floating_point(): + video_tensor = video_tensor.clamp(-1, 1) * 0.5 + 0.5 + video_array = video_tensor.float().numpy() + elif isinstance(frames, np.ndarray): + video_array = frames + if video_array.ndim == 5: + video_array = video_array[0] + if np.issubdtype(video_array.dtype, np.integer): + video_array = video_array.astype(np.float32) / 255.0 + elif isinstance(frames, list): + if len(frames) == 0: + raise ValueError("No video frames found in output.") + video_array = [_normalize_frame(frame) for frame in frames] + else: + video_array = frames + + video_array = _ensure_frame_list(video_array) + + export_to_video(video_array, str(output_path), fps=args.fps) + print(f"Saved generated video to {output_path}") + + if profiler_enabled: + print("\n[Profiler] Stopping profiler and collecting results...") + profile_results = omni.stop_profile() + if profile_results and isinstance(profile_results, dict): + traces = profile_results.get("traces", []) + print("\n" + "=" * 60) + print("PROFILING RESULTS:") + for rank, trace in enumerate(traces): + print(f"\nRank {rank}:") + if trace: + print(f" • Trace: {trace}") + if not traces: + print(" No traces collected.") + print("=" * 60) + else: + print("[Profiler] No valid profiling data returned.") + + +if __name__ == "__main__": + main() diff --git a/examples/online_serving/lingbot_world_fast/openai_client.py b/examples/online_serving/lingbot_world_fast/openai_client.py new file mode 100644 index 00000000000..7be1489f314 --- /dev/null +++ b/examples/online_serving/lingbot_world_fast/openai_client.py @@ -0,0 +1,183 @@ +#!/usr/bin/env python3 +""" +Lingbot World Fast realtime camera client. + +Talks to the WebSocket endpoint ``/v1/realtime/world/camera`` exposed by +``vllm serve --omni`` when the loaded pipeline is ``LingbotWorldFastPipeline``. + +The endpoint speaks the OpenPI policy protocol on the wire: + 1. Connect -> server sends msgpack(CameraServerConfig) + 2. Client send msgpack(request) + 3. Server send msgpack(ndarray) # generated frames + +The ``request`` payload sent here contains: + - "image": numpy array, the input image + - "prompt": str, the text prompt describing the desired motion + - "camera": {"poses": ndarray, "intrinsics": ndarray} + +Usage: + python openai_chat_client.py \\ + --image path/to/image.png \\ + --camera-path path/to/camera_dir \\ + --prompt "Walk along the Great Wall of China" \\ + --output frames.npy +""" + +import argparse +from argparse import Namespace +from pathlib import Path + +import numpy as np +import PIL.Image +import websockets.sync.client as ws_sync +from diffusers.utils import export_to_video + +try: + from openpi_client import msgpack_numpy +except ImportError as exc: + raise SystemExit("This example requires `openpi-client`. Install it with `pip install openpi-client`.") from exc + + +def _pack(obj): + return msgpack_numpy.packb(obj) + + +def _unpack(data): + return msgpack_numpy.unpackb(data) + + +def _load_image(path: str | None) -> np.ndarray | None: + image = PIL.Image.open(path).convert("RGB") + return np.asarray(image) + + +def _load_camera(camera_dir: str) -> dict: + camera_path = Path(camera_dir) + poses = np.load(camera_path / "poses.npy") + intrinsics = np.load(camera_path / "intrinsics.npy") + return {"poses": poses, "intrinsics": intrinsics} + + +def generate_video(args: Namespace) -> list[np.ndarray]: + """Send inference requests and return the generated frames.""" + image = _load_image(args.image) + full_camera = _load_camera(args.camera_path) + + extra_body = { + "height": args.height, + "width": args.width, + "num_frames": args.num_frames, + "fps": args.fps, + "session_id": args.session_id, + "frames_per_chunk": args.frames_per_chunk, + "seed": args.seed, + } + + video = [] + starting_frame = 0 + + for i in range(args.num_calls): + camera = { + "poses": full_camera["poses"][starting_frame : starting_frame + args.num_frames], + "intrinsics": full_camera["intrinsics"][starting_frame : starting_frame + args.num_frames], + } + + request: dict = {"prompt": args.prompt, "camera": camera, "extra_body": extra_body} + if i == 0: + request["image"] = image + + request["session_id"] = args.session_id + + endpoint = f"{args.server.rstrip('/')}/v1/realtime/world/camera" + print(f"Connecting to {endpoint} ...") + + with ws_sync.connect(endpoint, max_size=None, ping_interval=None, ping_timeout=None) as ws: + # 1. Server sends CameraServerConfig on connect. + server_config: dict = _unpack(ws.recv()) + print("Server Configuration:") + for key, val in server_config.items(): + print(f"\t{key}: {val}") + + # 2. Send request. + print( + f"Sending request image= ({str(image.shape) if request.get('image', None) is not None else 'None'}, " + f"poses={camera['poses'].shape}, intrinsics={camera['intrinsics'].shape})..." + ) + ws.send(_pack(request)) + + # 3. Receive generated frames. + chunks: list[np.ndarray] = [] + total = None + while total is None or len(chunks) < total: + msg = _unpack(ws.recv()) + if isinstance(msg, dict) and msg.get("type") == "error": + raise RuntimeError(f"Server error: {msg.get('message')}") + if not isinstance(msg, dict) or msg.get("type") != "frame": + continue # ignore anything unexpected + total = msg["total"] + chunks.append(msg["video"]) + print(f" received chunk {msg['index'] + 1}/{total}") + + clip = np.concatenate(chunks, axis=0) + # The first chunk of frames returned was used to condition the video continuation but they are not useful + if i != 0: + clip = clip[args.num_skip_frames :] + for frame in clip: + video.append(frame) + + starting_frame += args.num_frames + + return video + + +def main(): + parser = argparse.ArgumentParser(description="Lingbot World Fast realtime camera client") + parser.add_argument("--image", "-i", required=True, help="Path to input image.") + parser.add_argument( + "--camera-path", + "-c", + required=True, + help="Directory containing poses.npy and intrinsics.npy.", + ) + parser.add_argument( + "--prompt", + "-p", + default="Walk along the Great Wall of China", + help="Text prompt describing the desired motion.", + ) + parser.add_argument( + "--server", + "-s", + default="ws://localhost:8091", + help="WebSocket server URL (ws:// or wss://).", + ) + parser.add_argument("--session-id", default=None, help="Optional session id.") + parser.add_argument( + "--output", + "-o", + default="lingbot-video.mp4", + help="Path to save the returned frames (npy).", + ) + parser.add_argument("--width", type=int, default=832) + parser.add_argument("--height", type=int, default=480) + parser.add_argument("--fps", type=int, default=16) + parser.add_argument("--num-frames", type=int, default=24) + parser.add_argument("--num-calls", type=int, default=1) + parser.add_argument("--num-skip-frames", type=int, default=4) + parser.add_argument( + "--frames-per-chunk", type=int, default=4, help="How many frames are sent in each package in the response" + ) + parser.add_argument("--seed", type=int, default=42, help="Random seed.") + args = parser.parse_args() + + frames = generate_video(args) + + output_path = Path(args.output) + output_path.parent.mkdir(parents=True, exist_ok=True) + + export_to_video(frames, str(output_path), fps=args.fps) + print(f"Saved generated video to {output_path}") + + +if __name__ == "__main__": + main() diff --git a/examples/online_serving/lingbot_world_fast/run_server.sh b/examples/online_serving/lingbot_world_fast/run_server.sh new file mode 100755 index 00000000000..5bf638e2a1c --- /dev/null +++ b/examples/online_serving/lingbot_world_fast/run_server.sh @@ -0,0 +1,17 @@ +#!/bin/bash +# Lingbot World Fast online serving startup script + +MODEL="${MODEL:-../../offline_inference/lingbot_world_fast/lingbot_world/lingbot-world-base-cam/Lingbot-World-Fast}" +PORT="${PORT:-8091}" + +echo "Starting Lingbot World server..." +echo "Model: $MODEL" +echo "Port: $PORT" + +vllm serve "$MODEL" --omni \ + --port "$PORT" \ + --model-class-name LingbotWorldFastPipeline \ + --stage-init-timeout 6000 \ + --init-timeout 6000 \ + --ws-max-size 268435456 \ + --ws wsproto diff --git a/tests/diffusion/models/lingbot_world_fast/__init__.py b/tests/diffusion/models/lingbot_world_fast/__init__.py new file mode 100644 index 00000000000..208f01a7cb5 --- /dev/null +++ b/tests/diffusion/models/lingbot_world_fast/__init__.py @@ -0,0 +1,2 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project diff --git a/tests/diffusion/models/lingbot_world_fast/conftest.py b/tests/diffusion/models/lingbot_world_fast/conftest.py new file mode 100644 index 00000000000..3ed76acb758 --- /dev/null +++ b/tests/diffusion/models/lingbot_world_fast/conftest.py @@ -0,0 +1,179 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""Shared stubs and dummy-input helpers for Lingbot World Fast L1 tests. + +The real pipeline pulls in T5-XXL, the Wan VAE and a 5B-parameter transformer +on construction. Tests exercise only the state container, msgpack protocol and +scheduler, so these stubs replace the heavy dependencies with the smallest +implementations that match the call sites in +``vllm_omni/diffusion/models/lingbot_world_fast/pipeline_lingbot_world_fast.py``. +""" + +from __future__ import annotations + +from types import SimpleNamespace +from typing import TYPE_CHECKING + +import numpy as np +import torch +from PIL import Image +from torch import nn + +if TYPE_CHECKING: + from vllm_omni.diffusion.models.lingbot_world_fast.pipeline_lingbot_world_fast import ( + LingbotWorldFastPipeline, + ) + + +class StubT5Encoder: + """Minimal stand-in for ``T5EncoderModel``. + + The pipeline calls ``self.text_encoder([prompt], device)`` and expects a + list of token-embedding tensors, one per prompt. + """ + + def __init__(self, text_len: int = 512, dim: int = 32, dtype: torch.dtype = torch.float32) -> None: + self.text_len = text_len + self.dim = dim + self.dtype = dtype + + def __call__(self, prompts: list[str], device: torch.device) -> list[torch.Tensor]: + return [torch.zeros(self.text_len, self.dim, dtype=self.dtype, device=device) for _ in prompts] + + +class StubVAE: + """Stand-in for ``Wan2_1_VAE``. + + ``encode([pixels])`` returns a list with one latent tensor shaped + ``[16, F_lat, lat_h, lat_w]`` where ``F_lat = (F + 3) // 4`` so the + pipeline's masking / slicing math is exercised normally. + ``decode([latents])`` returns the latents unchanged (caller indexes [0]). + """ + + vae_stride = (4, 8, 8) + + def encode(self, video_list: list[torch.Tensor]) -> list[torch.Tensor]: + out: list[torch.Tensor] = [] + for v in video_list: + # v: [C, F, H, W] + _, f, h, w = v.shape + lat_f = (f + self.vae_stride[0] - 1) // self.vae_stride[0] + lat_h = h // self.vae_stride[1] + lat_w = w // self.vae_stride[2] + out.append(torch.zeros(16, lat_f, lat_h, lat_w, dtype=v.dtype, device=v.device)) + return out + + def decode(self, latents_list: list[torch.Tensor]) -> list[torch.Tensor]: + out: list[torch.Tensor] = [] + for latents in latents_list: + # latents: [16, F_lat, lat_h, lat_w]; produce pixels at the inverse stride. + _, f_lat, lat_h, lat_w = latents.shape + f = f_lat * self.vae_stride[0] + h = lat_h * self.vae_stride[1] + w = lat_w * self.vae_stride[2] + out.append(torch.zeros(3, f, h, w, dtype=latents.dtype, device=latents.device)) + return out + + +class StubWanModelFast(nn.Module): + """Stand-in for ``WanModelFast``. + + Returns zeros shaped like the input latent, and bumps the local/global + index tensors so chunk-boundary arithmetic is exercised by the pipeline. + """ + + def __init__(self, *, dim: int = 16, num_heads: int = 4, num_layers: int = 2) -> None: + super().__init__() + self.config = SimpleNamespace( + dim=dim, + num_heads=num_heads, + num_layers=num_layers, + local_attn_size=-1, + ) + + def forward(self, *, x, t, **kwargs): # noqa: ARG002 — matches pipeline call site + del t, kwargs + return [torch.zeros_like(x[0])] + + @classmethod + def from_pretrained(cls, *args, **kwargs): # noqa: ARG003 + return cls() + + +def make_dummy_camera_inputs(num_frames: int) -> dict[str, np.ndarray]: + """Camera payload matching the shape the pipeline expects.""" + intrinsics = np.eye(4, dtype=np.float32) + poses = np.tile(np.eye(4, dtype=np.float32), (num_frames, 1, 1)) + return {"intrinsics": intrinsics, "poses": poses} + + +def make_dummy_image(width: int = 64, height: int = 64) -> Image.Image: + return Image.new("RGB", (width, height), color=(128, 128, 128)) + + +def make_stubbed_pipeline( + *, + device: torch.device | None = None, + dim: int = 16, + num_heads: int = 4, + num_layers: int = 2, + target_dtype: torch.dtype = torch.float32, +) -> LingbotWorldFastPipeline: + """Build a ``LingbotWorldFastPipeline`` backed by the conftest stubs. + + Skips the real ``__init__`` (which loads umt5-xxl, Wan VAE and a 5B + transformer) via ``object.__new__`` and assigns the stubs directly, + mirroring ``_make_i2v_pipeline`` in ``tests/diffusion/models/wan2_2``. + The returned pipeline is suitable for driving ``.forward(req)`` end-to-end + against ``LingbotWorldFastState`` without touching real weights. + """ + from vllm_omni.diffusion.models.lingbot_world_fast.fm_solvers_unipc import ( + FlowUniPCMultistepScheduler, + ) + from vllm_omni.diffusion.models.lingbot_world_fast.pipeline_lingbot_world_fast import ( + CONFIG, + LingbotWorldFastPipeline, + ) + from vllm_omni.diffusion.models.lingbot_world_fast.state_lingbot_world_fast import ( + LingbotWorldFastState, + ) + + if device is None: + device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu", 0) + + parallel_config = SimpleNamespace(world_size=1) + patch_size = [1, 2, 2] + vae_stride = [4, 8, 8] + + od_config = SimpleNamespace( + model="stub/Lingbot-World-Fast", + parallel_config=parallel_config, + dtype=target_dtype, + model_config={ + "latent_frames_per_chunk": 3, + "max_area": 64 * 64, + }, + ) + + pipeline = object.__new__(LingbotWorldFastPipeline) + nn.Module.__init__(pipeline) + pipeline.od_config = od_config + pipeline.parallel_config = parallel_config + pipeline.device = device + pipeline.target_dtype = target_dtype + pipeline.control_type = "cam" + pipeline.num_train_timesteps = CONFIG["num_train_timesteps"] + pipeline.sp_size = parallel_config.world_size + pipeline.state = LingbotWorldFastState() + pipeline.text_encoder = StubT5Encoder(dim=dim, dtype=target_dtype) + pipeline.vae = StubVAE() + pipeline.vae_stride = vae_stride + pipeline.patch_size = patch_size + pipeline.model = StubWanModelFast(dim=dim, num_heads=num_heads, num_layers=num_layers).to(device) + pipeline.scheduler = FlowUniPCMultistepScheduler( + num_train_timesteps=CONFIG["num_train_timesteps"], + shift=1, + use_dynamic_shifting=False, + ) + return pipeline diff --git a/tests/diffusion/models/lingbot_world_fast/test_protocol_validation.py b/tests/diffusion/models/lingbot_world_fast/test_protocol_validation.py new file mode 100644 index 00000000000..1985c8a7b4a --- /dev/null +++ b/tests/diffusion/models/lingbot_world_fast/test_protocol_validation.py @@ -0,0 +1,366 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""L1 protocol-validation tests for ``/v1/realtime/world/camera``.""" + +from __future__ import annotations + +import asyncio +import contextlib +from collections.abc import Iterable +from typing import Any + +import numpy as np +import pytest + +from tests.diffusion.models.lingbot_world_fast.conftest import make_dummy_camera_inputs +from vllm_omni.entrypoints.openai.realtime.world.camera_connection import ( + DEFAULT_FRAMES_PER_CHUNK, + WorldCameraRealtimeConnection, +) +from vllm_omni.entrypoints.openai.realtime.world.camera_serving import ( + CameraServerConfig, + ServingRealtimeWorldCamera, +) + +# The endpoint's wire codec is provided by the optional openpi-client dep. +msgpack_numpy = pytest.importorskip("openpi_client.msgpack_numpy") + +pytestmark = [pytest.mark.core_model, pytest.mark.cpu, pytest.mark.diffusion] + + +# --------------------------------------------------------------------------- +# Mock infrastructure +# --------------------------------------------------------------------------- + + +class MockWebSocket: + def __init__(self, incoming: Iterable[dict[str, Any]] | None = None) -> None: + self._incoming: list[dict[str, Any]] = list(incoming or []) + self._idx = 0 + self.sent_bytes: list[bytes] = [] + self.sent_text: list[str] = [] + self.accepted = False + self.closed = False + + async def accept(self) -> None: + self.accepted = True + + async def receive(self) -> dict[str, Any]: + if self._idx >= len(self._incoming): + return {"type": "websocket.disconnect"} + msg = self._incoming[self._idx] + self._idx += 1 + return msg + + async def send_bytes(self, data: bytes) -> None: + self.sent_bytes.append(data) + + async def send_text(self, data: str) -> None: + self.sent_text.append(data) + + async def close(self) -> None: + self.closed = True + + +def _bytes_frame(payload: Any) -> dict[str, Any]: + return {"type": "websocket.receive", "bytes": msgpack_numpy.packb(payload)} + + +def _raw_bytes_frame(data: bytes) -> dict[str, Any]: + return {"type": "websocket.receive", "bytes": data} + + +class _AsyncIter: + def __init__(self, items: list[Any]) -> None: + self._items = list(items) + + def __aiter__(self): + return self + + async def __anext__(self): + if not self._items: + raise StopAsyncIteration + return self._items.pop(0) + + +class FakeResult: + def __init__(self, frames: np.ndarray) -> None: + self.images = [frames] + + +class FakeEngineClient: + """Stand-in for ``AsyncOmni`` engine client. + + Captures the ``generate(...)`` arguments and yields a single fake result + so the connection's framing logic can be exercised without a real engine. + """ + + def __init__(self, frames: np.ndarray | None = None) -> None: + if frames is None: + # Default: CHUNK_FRAMES*(1+1/2) RGB frames so we exercise the chunk split. + frames = np.zeros((DEFAULT_FRAMES_PER_CHUNK * 3 // 2, 16, 16, 3), dtype=np.uint8) + self._frames = frames + self.calls: list[dict[str, Any]] = [] + self.fail_with: Exception | None = None + # Attributes consulted by ``CameraServerConfig.from_model_config``. + self.model_config = {"pipeline": "lingbot_world_fast", "resolution": [480, 832], "fps": 16} + + def generate(self, *, prompt, request_id, sampling_params_list): + self.calls.append( + { + "prompt": prompt, + "request_id": request_id, + "sampling_params_list": sampling_params_list, + } + ) + if self.fail_with is not None: + raise self.fail_with + return _AsyncIter([FakeResult(self._frames)]) + + +def _make_serving(engine_client: FakeEngineClient | None = None) -> ServingRealtimeWorldCamera: + return ServingRealtimeWorldCamera(engine_client=engine_client or FakeEngineClient(), model_name="lingbot") + + +# --------------------------------------------------------------------------- +# CameraServerConfig +# --------------------------------------------------------------------------- + + +def test_camera_server_config_round_trip_through_dict() -> None: + cfg = CameraServerConfig.from_model_config({"pipeline": "lingbot", "fps": 16}) + out = cfg.to_dict() + assert isinstance(out, dict) + assert out["pipeline"] == "lingbot" + assert out["fps"] == 16 + + +# --------------------------------------------------------------------------- +# msgpack-numpy round-trip +# --------------------------------------------------------------------------- + + +def test_msgpack_camera_payload_round_trip() -> None: + camera = make_dummy_camera_inputs(num_frames=8) + payload = { + "image": np.random.randint(0, 255, size=(8, 8, 3), dtype=np.uint8), + "prompt": "walk forward", + "camera": camera, + "session_id": "sess-1", + "extra_body": {"height": 240, "width": 416, "num_frames": 25, "fps": 16}, + } + packed = msgpack_numpy.packb(payload) + decoded = msgpack_numpy.unpackb(packed) + + assert decoded["prompt"] == "walk forward" + assert decoded["session_id"] == "sess-1" + assert decoded["extra_body"] == payload["extra_body"] + + image_out = decoded["image"] + assert isinstance(image_out, np.ndarray) + assert image_out.shape == payload["image"].shape + assert image_out.dtype == payload["image"].dtype + np.testing.assert_array_equal(image_out, payload["image"]) + + for key in ("intrinsics", "poses"): + arr_in = camera[key] + arr_out = decoded["camera"][key] + assert arr_out.dtype == arr_in.dtype + assert arr_out.shape == arr_in.shape + np.testing.assert_array_equal(arr_out, arr_in) + + +# --------------------------------------------------------------------------- +# Connection-level framing +# --------------------------------------------------------------------------- + + +def test_handshake_sends_camera_server_config_on_connect() -> None: + serving = _make_serving() + ws = MockWebSocket(incoming=[]) # client disconnects immediately after handshake + conn = WorldCameraRealtimeConnection(ws, serving) + asyncio.run(conn.handle_connection()) + + assert ws.accepted is True + assert len(ws.sent_bytes) == 1 + handshake = msgpack_numpy.unpackb(ws.sent_bytes[0]) + assert isinstance(handshake, dict) + assert handshake["pipeline"] == "lingbot_world_fast" + + +def test_invalid_msgpack_returns_error_frame_and_keeps_connection_open() -> None: + serving = _make_serving() + ws = MockWebSocket( + incoming=[ + _raw_bytes_frame(b"\x99not-msgpack"), # malformed + _bytes_frame({"endpoint": "reset"}), + ] + ) + conn = WorldCameraRealtimeConnection(ws, serving) + asyncio.run(conn.handle_connection()) + + # First sent message is the handshake, next is the error frame, then the + # "reset successful" text reply — proving the connection stayed open. + assert len(ws.sent_bytes) >= 2 + error = msgpack_numpy.unpackb(ws.sent_bytes[1]) + assert error == {"type": "error", "message": "Invalid request payload"} + assert ws.sent_text == ["reset successful"] + + +def test_non_dict_payload_is_rejected_with_error_frame() -> None: + serving = _make_serving() + ws = MockWebSocket(incoming=[_bytes_frame([1, 2, 3])]) # list, not dict + conn = WorldCameraRealtimeConnection(ws, serving) + asyncio.run(conn.handle_connection()) + + assert len(ws.sent_bytes) >= 2 + error = msgpack_numpy.unpackb(ws.sent_bytes[1]) + assert error["type"] == "error" + + +def test_reset_endpoint_clears_session_and_returns_text_ack() -> None: + engine_client = FakeEngineClient() + serving = ServingRealtimeWorldCamera(engine_client=engine_client, model_name="lingbot") + # Pre-populate as if a prior session were active. + serving._current_session_id = "session-a" + + ws = MockWebSocket(incoming=[_bytes_frame({"endpoint": "reset"})]) + conn = WorldCameraRealtimeConnection(ws, serving) + asyncio.run(conn.handle_connection()) + + assert serving._current_session_id is None + assert ws.sent_text == ["reset successful"] + + +def test_infer_frames_are_chunked() -> None: + num_frames = DEFAULT_FRAMES_PER_CHUNK * 3 // 2 + + frames = np.arange(num_frames * 4 * 4 * 3, dtype=np.uint8).reshape(num_frames, 4, 4, 3) + engine_client = FakeEngineClient(frames=frames) + serving = ServingRealtimeWorldCamera(engine_client=engine_client, model_name="lingbot") + + request = { + "prompt": "p", + "camera": make_dummy_camera_inputs(num_frames=6), + "session_id": "s1", + "extra_body": {"num_frames": 6, "height": 16, "width": 16, "fps": 16}, + "image": np.zeros((16, 16, 3), dtype=np.uint8), + } + ws = MockWebSocket(incoming=[_bytes_frame(request)]) + conn = WorldCameraRealtimeConnection(ws, serving) + asyncio.run(conn.handle_connection()) + + # Drop the handshake; the remaining sent_bytes are frame chunks. + chunks = [msgpack_numpy.unpackb(b) for b in ws.sent_bytes[1:]] + assert [c["type"] for c in chunks] == ["frame", "frame"] + assert [c["index"] for c in chunks] == [0, 1] + assert {c["total"] for c in chunks} == {2} + + assert len(chunks[0]["video"]) == DEFAULT_FRAMES_PER_CHUNK + assert len(chunks[1]["video"]) == num_frames - DEFAULT_FRAMES_PER_CHUNK + + for chunk in chunks: + assert chunk["video"][0].shape == (4, 4, 3) + + +def test_session_id_churn_flips_current_session_id() -> None: + engine_client = FakeEngineClient() + serving = ServingRealtimeWorldCamera(engine_client=engine_client, model_name="lingbot") + + base_obs = { + "prompt": "p", + "camera": make_dummy_camera_inputs(num_frames=4), + "image": np.zeros((16, 16, 3), dtype=np.uint8), + "extra_body": {"num_frames": 4, "height": 16, "width": 16, "fps": 16}, + } + obs_a = {**base_obs, "session_id": "session-a"} + obs_b = {**base_obs, "session_id": "session-b"} + + ws = MockWebSocket(incoming=[_bytes_frame(obs_a), _bytes_frame(obs_b)]) + conn = WorldCameraRealtimeConnection(ws, serving) + asyncio.run(conn.handle_connection()) + + assert serving._current_session_id == "session-b" + assert len(engine_client.calls) == 2 + # Each engine call observes the active session id via extra_args. + seen_session_ids = [call["sampling_params_list"][0].extra_args["session_id"] for call in engine_client.calls] + assert seen_session_ids == ["session-a", "session-b"] + + +def test_engine_failure_surfaces_as_error_frame_not_close() -> None: + engine_client = FakeEngineClient() + engine_client.fail_with = RuntimeError("kaboom") + serving = ServingRealtimeWorldCamera(engine_client=engine_client, model_name="lingbot") + + request = { + "prompt": "p", + "camera": make_dummy_camera_inputs(num_frames=4), + "session_id": "s1", + "image": np.zeros((16, 16, 3), dtype=np.uint8), + "extra_body": {"num_frames": 4, "height": 16, "width": 16, "fps": 16}, + } + ws = MockWebSocket(incoming=[_bytes_frame(request)]) + conn = WorldCameraRealtimeConnection(ws, serving) + asyncio.run(conn.handle_connection()) + + error = msgpack_numpy.unpackb(ws.sent_bytes[-1]) + assert error == {"type": "error", "message": "Internal inference error"} + + +# --------------------------------------------------------------------------- +# Required-field validation: a missing ``camera`` propagates to the pipeline +# layer's ValueError. At the serving layer we exercise this by giving the +# fake engine a side-effect that raises like the pipeline would, then assert +# the connection responds with an error frame and keeps running. +# --------------------------------------------------------------------------- + + +def test_missing_camera_surfaces_as_error_frame() -> None: + engine_client = FakeEngineClient() + # Pipeline's actual ValueError text — useful to keep this in sync. + engine_client.fail_with = ValueError("A path to camera positions must be passed to this model") + serving = ServingRealtimeWorldCamera(engine_client=engine_client, model_name="lingbot") + + request = { + "prompt": "p", + "session_id": "s1", + "image": np.zeros((16, 16, 3), dtype=np.uint8), + "extra_body": {"num_frames": 4, "height": 16, "width": 16, "fps": 16}, + } + ws = MockWebSocket(incoming=[_bytes_frame(request)]) + conn = WorldCameraRealtimeConnection(ws, serving) + asyncio.run(conn.handle_connection()) + + error = msgpack_numpy.unpackb(ws.sent_bytes[-1]) + assert error["type"] == "error" + + +# --------------------------------------------------------------------------- +# Dtype/rank guard: the wire codec preserves bit patterns, so a malformed +# camera entry (float64 vs float32, wrong rank) passes through the connection +# unchanged. The pipeline-layer assertions are exercised in the L2 offline +# test; here we just confirm the codec doesn't silently coerce. +# --------------------------------------------------------------------------- + + +def test_msgpack_does_not_silently_coerce_camera_dtypes() -> None: + payload = { + "intrinsics": np.eye(3, dtype=np.float64), # wrong dtype on purpose + "poses": np.tile(np.eye(4, dtype=np.float32), (2, 1, 1))[None], # extra leading dim + } + decoded = msgpack_numpy.unpackb(msgpack_numpy.packb(payload)) + assert decoded["intrinsics"].dtype == np.float64 + assert decoded["poses"].shape == (1, 2, 4, 4) + + +# --------------------------------------------------------------------------- +# Suppress event-loop teardown noise in some environments +# --------------------------------------------------------------------------- + + +@pytest.fixture(autouse=True) +def _silence_runtime_warnings(recwarn): # noqa: PT004 + yield + with contextlib.suppress(Exception): + recwarn.clear() diff --git a/tests/diffusion/models/lingbot_world_fast/test_schedule.py b/tests/diffusion/models/lingbot_world_fast/test_schedule.py new file mode 100644 index 00000000000..0adc1589c91 --- /dev/null +++ b/tests/diffusion/models/lingbot_world_fast/test_schedule.py @@ -0,0 +1,106 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""L1 tests for the Lingbot World Fast scheduler.""" + +from __future__ import annotations + +import math + +import pytest +import torch + +from vllm_omni.diffusion.models.lingbot_world_fast.fm_solvers_unipc import FlowUniPCMultistepScheduler +from vllm_omni.diffusion.models.lingbot_world_fast.pipeline_lingbot_world_fast import ( + CONFIG, + LingbotWorldFastPipeline, +) + +pytestmark = [pytest.mark.core_model, pytest.mark.cpu, pytest.mark.diffusion] + + +def _make_scheduler() -> FlowUniPCMultistepScheduler: + # Same construction as ``LingbotWorldFastPipeline.__init__``. + scheduler = FlowUniPCMultistepScheduler( + num_train_timesteps=CONFIG["num_train_timesteps"], + shift=1, + use_dynamic_shifting=False, + ) + scheduler.set_timesteps(CONFIG["num_train_timesteps"], shift=CONFIG["sample_shift"]) + return scheduler + + +def test_timesteps_index_selects_exactly_four_steps() -> None: + scheduler = _make_scheduler() + selected = scheduler.timesteps[CONFIG["timesteps_index"]] + + assert selected.shape == (4,) + # Monotonically decreasing — flow matching schedulers walk t from high to low. + diffs = selected[1:] - selected[:-1] + assert torch.all(diffs < 0), f"timesteps must be strictly decreasing, got {selected.tolist()}" + + +def test_timesteps_full_schedule_length_matches_num_train_timesteps() -> None: + scheduler = _make_scheduler() + assert scheduler.num_inference_steps == CONFIG["num_train_timesteps"] + assert scheduler.timesteps.shape == (CONFIG["num_train_timesteps"],) + + +def _convert(flow_pred: torch.Tensor, xt: torch.Tensor, timestep: torch.Tensor, scheduler) -> torch.Tensor: + # Bind ``_convert_flow_pred_to_x0`` as an unbound method to avoid + # constructing the full pipeline (which loads T5 / VAE). + return LingbotWorldFastPipeline._convert_flow_pred_to_x0( + None, # type: ignore[arg-type] + flow_pred=flow_pred, + xt=xt, + timestep=timestep, + scheduler=scheduler, + ) + + +def test_convert_flow_pred_to_x0_passthrough_when_pred_is_zero() -> None: + scheduler = _make_scheduler() + xt = torch.randn(1, 4, 1, 4, 4, dtype=torch.float32) + timestep = scheduler.timesteps[0] + flow_pred = torch.zeros_like(xt) + + x0 = _convert(flow_pred, xt, timestep, scheduler) + assert torch.allclose(x0, xt, atol=1e-6) + + +def test_convert_flow_pred_to_x0_recovers_x0_from_synthesized_pair() -> None: + scheduler = _make_scheduler() + timestep = scheduler.timesteps[CONFIG["timesteps_index"][1]] + + sigmas = scheduler.sigmas + timesteps = scheduler.timesteps + timestep_id = torch.argmin((timesteps - timestep).abs()) + sigma_t = sigmas[timestep_id].item() + + x0 = torch.randn(1, 4, 1, 4, 4, dtype=torch.float32) + noise = torch.randn_like(x0) + flow_pred = noise - x0 + xt = (1.0 - sigma_t) * x0 + sigma_t * noise + + recovered = _convert(flow_pred, xt, timestep, scheduler) + + # The function does the math in float64 internally, casts back to the + # input dtype. float32 inputs ⇒ ~1e-5 absolute tolerance is plenty. + assert torch.allclose(recovered, x0, atol=1e-4) + + +def test_timesteps_index_is_within_schedule_bounds() -> None: + """Defensive guard: an out-of-range index would silently wrap.""" + assert isinstance(CONFIG["timesteps_index"], list) + assert len(CONFIG["timesteps_index"]) == 4 + for idx in CONFIG["timesteps_index"]: + assert 0 <= idx < CONFIG["num_train_timesteps"] + + +def test_sample_shift_constant_is_positive() -> None: + """``sample_shift`` controls the timestep curve; a non-positive value + would corrupt the flow-matching trajectory.""" + assert CONFIG["sample_shift"] > 0 + # Reasonable upper bound — Wan models use shift ~5–10. + assert math.isfinite(CONFIG["sample_shift"]) + assert CONFIG["sample_shift"] <= 100 diff --git a/tests/diffusion/models/lingbot_world_fast/test_session_state.py b/tests/diffusion/models/lingbot_world_fast/test_session_state.py new file mode 100644 index 00000000000..e10298332b4 --- /dev/null +++ b/tests/diffusion/models/lingbot_world_fast/test_session_state.py @@ -0,0 +1,219 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""L1 unit tests for ``LingbotWorldFastState``. + +The state container is the load-bearing structure for chunk-streamed +generation: it owns the KV cache, the cross-attention cache, the +``current_lat_f`` cursor used to derive ``current_start`` RoPE offsets, +and the session-id that decides between fresh vs extension semantics. +""" + +from __future__ import annotations + +import pytest +import torch + +from vllm_omni.diffusion.models.lingbot_world_fast.state_lingbot_world_fast import ( + LingbotWorldFastState, +) + +pytestmark = [pytest.mark.core_model, pytest.mark.cpu, pytest.mark.diffusion] + + +BATCH_SIZE = 1 +NUM_LAYERS = 3 +NUM_HEADS = 4 +HEAD_DIM = 8 +KV_SIZE = 16 +DTYPE = torch.float32 +DEVICE = torch.device("cpu") + + +def _fresh_state_with_caches(kv_size: int = KV_SIZE) -> LingbotWorldFastState: + state = LingbotWorldFastState() + state.create_kv_caches( + batch_size=BATCH_SIZE, + dtype=DTYPE, + device=DEVICE, + kv_size=kv_size, + num_layers=NUM_LAYERS, + num_heads=NUM_HEADS, + head_dim=HEAD_DIM, + ) + return state + + +def test_reset_initializes_all_fields() -> None: + state = LingbotWorldFastState() + assert state.kv_cache is None + assert state.crossattn_cache is None + assert state.current_start_frame == 0 + assert state.local_end_index is None + assert state.global_end_index is None + assert state.is_initialized is False + assert state.current_lat_f == 0 + assert state.session_id is None + assert state.batch_size is None + assert state.num_layers is None + assert state.num_heads is None + assert state.head_dim is None + assert state.h is None + assert state.w is None + assert state.lat_h is None + assert state.lat_w is None + assert state.frame_seqlen is None + assert state.last_decoded_latent is None + + +def test_create_kv_caches_allocates_expected_shapes() -> None: + state = _fresh_state_with_caches() + + assert state.is_initialized is True + assert state.batch_size == BATCH_SIZE + assert state.num_layers == NUM_LAYERS + assert state.num_heads == NUM_HEADS + assert state.head_dim == HEAD_DIM + + assert state.kv_cache is not None + assert len(state.kv_cache) == NUM_LAYERS + for layer in state.kv_cache: + assert layer.shape == (2, BATCH_SIZE, KV_SIZE, NUM_HEADS, HEAD_DIM) + assert layer.dtype == DTYPE + assert torch.all(layer == 0) + + assert state.local_end_index is not None and state.global_end_index is not None + for idx_list in (state.local_end_index, state.global_end_index): + assert len(idx_list) == NUM_LAYERS + for idx in idx_list: + assert idx.shape == (1,) + assert idx.dtype == torch.long + assert int(idx.item()) == 0 + + assert state.crossattn_cache is not None + assert len(state.crossattn_cache) == NUM_LAYERS + for entry in state.crossattn_cache: + assert entry == {"is_init": False, "k": None, "v": None} + + +def test_extend_kv_caches_grows_tensor_and_zeros_new_slots() -> None: + state = _fresh_state_with_caches() + extra = 7 + # Mark the existing slots so we can confirm they aren't disturbed. + for layer in state.kv_cache: + layer.fill_(1.0) + + state.extend_kv_caches(extra_kv_size=extra) + + for layer in state.kv_cache: + assert layer.shape == (2, BATCH_SIZE, KV_SIZE + extra, NUM_HEADS, HEAD_DIM) + assert torch.all(layer[:, :, :KV_SIZE] == 1.0) + # Newly grown trailing slice is fresh zeros. + assert torch.all(layer[:, :, KV_SIZE:] == 0.0) + + +def test_extend_kv_caches_requires_initialization() -> None: + state = LingbotWorldFastState() + with pytest.raises(AssertionError): + state.extend_kv_caches(extra_kv_size=4) + + +def test_get_accessors_require_initialization() -> None: + state = LingbotWorldFastState() + with pytest.raises(AssertionError): + state.get_kv_caches() + with pytest.raises(AssertionError): + state.get_crossattn_caches() + + +def test_get_kv_caches_returns_underlying_list() -> None: + state = _fresh_state_with_caches() + assert state.get_kv_caches() is state.kv_cache + + +def test_advance_moves_cursor_by_delta() -> None: + state = LingbotWorldFastState() + state.advance(3) + assert state.current_lat_f == 3 + state.advance(5) + assert state.current_lat_f == 8 + + +def test_reset_clears_all_session_state() -> None: + state = _fresh_state_with_caches() + state.session_id = "abc" + state.advance(4) + state.h, state.w, state.lat_h, state.lat_w, state.frame_seqlen = 480, 832, 60, 104, 1560 + state.last_decoded_latent = torch.zeros(16, 2, 60, 104) + + state.reset() + + assert state.kv_cache is None + assert state.crossattn_cache is None + assert state.local_end_index is None + assert state.global_end_index is None + assert state.is_initialized is False + assert state.current_lat_f == 0 + assert state.current_start_frame == 0 + assert state.session_id is None + assert state.h is None and state.w is None + assert state.lat_h is None and state.lat_w is None + assert state.frame_seqlen is None + assert state.last_decoded_latent is None + assert state.batch_size is None + assert state.num_layers is None + + +def test_reset_is_idempotent() -> None: + state = LingbotWorldFastState() + state.reset() + state.reset() + assert state.is_initialized is False + assert state.current_lat_f == 0 + + +# --------------------------------------------------------------------------- +# Reset is triggered only by session-id change, not prompt change. +# +# Mirrors the conditional in ``LingbotWorldFastPipeline.forward`` (pipeline +# file, around the ``if self.state.session_id is None or +# self.state.session_id != session_id`` block). We assert the contract on +# the state container so the test does not depend on instantiating the +# heavy pipeline. +# --------------------------------------------------------------------------- + + +def _should_reset(state: LingbotWorldFastState, incoming_session_id: str) -> bool: + """Replicates the pipeline's reset trigger.""" + return state.session_id is None or state.session_id != incoming_session_id + + +def test_first_call_with_any_session_id_triggers_reset() -> None: + state = LingbotWorldFastState() + assert _should_reset(state, "session-a") is True + + +def test_same_session_id_does_not_reset() -> None: + state = _fresh_state_with_caches() + state.session_id = "session-a" + state.advance(4) + + assert _should_reset(state, "session-a") is False + # ... and a prompt-only change must not trigger a reset either. + assert _should_reset(state, "session-a") is False + # Pipeline would proceed in extension mode → state still alive. + assert state.is_initialized is True + assert state.current_lat_f == 4 + + +def test_different_session_id_triggers_reset() -> None: + state = _fresh_state_with_caches() + state.session_id = "session-a" + state.advance(4) + + assert _should_reset(state, "session-b") is True + + state.reset() + assert state.session_id is None + assert state.current_lat_f == 0 + assert state.kv_cache is None diff --git a/tests/e2e/offline_inference/test_lingbot_world_fast.py b/tests/e2e/offline_inference/test_lingbot_world_fast.py new file mode 100644 index 00000000000..f82f3bae03b --- /dev/null +++ b/tests/e2e/offline_inference/test_lingbot_world_fast.py @@ -0,0 +1,217 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""L2 offline smoke for the Lingbot World Fast pipeline.""" + +from __future__ import annotations + +import pytest +import torch + +from tests.diffusion.models.lingbot_world_fast.conftest import ( + make_dummy_camera_inputs, + make_dummy_image, + make_stubbed_pipeline, +) +from tests.helpers.mark import hardware_test +from vllm_omni.diffusion.data import DiffusionOutput +from vllm_omni.diffusion.models.lingbot_world_fast.pipeline_lingbot_world_fast import ( + get_lingbot_world_fast_post_process_func, +) +from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.inputs.data import OmniDiffusionSamplingParams + +pytestmark = [pytest.mark.core_model, pytest.mark.diffusion] + + +# ``torch.amp.autocast("cuda", …)`` inside the pipeline requires CUDA at import +# time on hosts where PyTorch is compiled without CUDA support. +if not torch.cuda.is_available(): + pytest.skip( + 'Lingbot World Fast pipeline requires CUDA (torch.amp.autocast("cuda", …))', + allow_module_level=True, + ) + +# Default ``num_frames`` argument; the pipeline floors to ``25`` internally on +# a fresh call (smallest length that maps to a non-empty latent). +_FRESH_NUM_FRAMES = 25 +_EXTENSION_NUM_FRAMES = 24 + +_DIM = 16 +_NUM_HEADS = 4 +_NUM_LAYERS = 2 +_HEAD_DIM = _DIM // _NUM_HEADS + + +def _build_request( + *, + image, + camera, + session_id: str, + num_frames: int, + prompt: str = "walk forward", +) -> OmniDiffusionRequest: + multi_modal_data: dict = {"camera": camera} + if image is not None: + multi_modal_data["image"] = image + return OmniDiffusionRequest( + prompts=[{"prompt": prompt, "multi_modal_data": multi_modal_data}], + sampling_params=OmniDiffusionSamplingParams( + height=None, + width=None, + num_frames=num_frames, + seed=42, + extra_args={"session_id": session_id}, + ), + request_ids=[f"req-{session_id}"], + ) + + +@pytest.fixture +def stubbed_pipeline(monkeypatch): + """Build a stub-backed pipeline and shrink CONFIG['max_area'] for speed.""" + pipeline = make_stubbed_pipeline( + dim=_DIM, + num_heads=_NUM_HEADS, + num_layers=_NUM_LAYERS, + target_dtype=torch.float32, + ) + yield pipeline + + +@pytest.mark.diffusion +@hardware_test(res={"cuda": "L4"}, num_cards={"cuda": 1}) +def test_session_lifecycle_fresh_then_extension(stubbed_pipeline) -> None: + """Drive a fresh + extension pair through the pipeline and assert that + ``LingbotWorldFastState`` advances exactly as the chunk arithmetic prescribes.""" + pipeline = stubbed_pipeline + session_id = "session-l2-offline" + + # --- Fresh call --------------------------------------------------------- + camera_fresh = make_dummy_camera_inputs(num_frames=_FRESH_NUM_FRAMES) + image = make_dummy_image() + req_fresh = _build_request( + image=image, + camera=camera_fresh, + session_id=session_id, + num_frames=_FRESH_NUM_FRAMES, + ) + + out_fresh = pipeline.forward(req_fresh) + + assert isinstance(out_fresh, DiffusionOutput) + assert out_fresh.output is not None + assert torch.isfinite(out_fresh.output).all(), "Fresh-call video contains NaN/Inf." + + state = pipeline.state + assert state.is_initialized is True + assert state.session_id == session_id + assert state.current_lat_f > 0 + assert state.kv_cache is not None + assert state.crossattn_cache is not None + assert state.last_decoded_latent is not None + + fresh_lat_f = state.current_lat_f + fresh_kv_size = state.kv_cache[0].shape[2] + frame_seqlen = state.frame_seqlen + assert frame_seqlen == state.lat_h * state.lat_w // 4 + assert fresh_kv_size == frame_seqlen * fresh_lat_f + + # The spatial dims must come from the input image on the fresh call so the + # extension branch can later reuse them — make sure they were captured. + assert state.h is not None and state.w is not None + assert state.lat_h is not None and state.lat_w is not None + + # --- Extension call ----------------------------------------------------- + camera_ext = make_dummy_camera_inputs(num_frames=_EXTENSION_NUM_FRAMES) + req_ext = _build_request( + image=None, + camera=camera_ext, + session_id=session_id, + num_frames=_EXTENSION_NUM_FRAMES, + ) + + out_ext = pipeline.forward(req_ext) + + assert isinstance(out_ext, DiffusionOutput) + assert out_ext.output is not None + assert torch.isfinite(out_ext.output).all(), "Extension-call video contains NaN/Inf." + + assert state.session_id == session_id, "Same session_id must not trigger a reset." + assert state.is_initialized is True + assert state.current_lat_f > fresh_lat_f, "current_lat_f must advance on extension." + ext_lat_f = state.current_lat_f - fresh_lat_f + assert ext_lat_f > 0 + + # ``extend_kv_caches`` allocates a fresh tensor of size old + frame_seqlen * + # new_lat_f and concatenates; assert the trailing slice grew by exactly + # ``frame_seqlen * ext_lat_f`` for every layer. + for layer_idx, layer in enumerate(state.kv_cache): + assert layer.shape == ( + 2, + 1, + fresh_kv_size + frame_seqlen * ext_lat_f, + _NUM_HEADS, + _HEAD_DIM, + ), f"layer {layer_idx} KV cache did not grow by exactly frame_seqlen * ext_lat_f" + + +@pytest.mark.diffusion +@hardware_test(res={"cuda": "L4"}, num_cards={"cuda": 1}) +def test_different_session_id_resets_state(stubbed_pipeline) -> None: + """A new ``session_id`` must hard-reset the cached spatial dims and KV cache""" + pipeline = stubbed_pipeline + image = make_dummy_image() + + req_a = _build_request( + image=image, + camera=make_dummy_camera_inputs(num_frames=_FRESH_NUM_FRAMES), + session_id="session-a", + num_frames=_FRESH_NUM_FRAMES, + ) + pipeline.forward(req_a) + + assert pipeline.state.session_id == "session-a" + lat_f_after_a = pipeline.state.current_lat_f + assert lat_f_after_a > 0 + + # A different session_id on the next call must drop the prior KV cache — + # ``current_lat_f`` resets to ``new_lat_f`` of the second call, not to + # ``lat_f_after_a + new_lat_f``. + req_b = _build_request( + image=make_dummy_image(), + camera=make_dummy_camera_inputs(num_frames=_FRESH_NUM_FRAMES), + session_id="session-b", + num_frames=_FRESH_NUM_FRAMES, + ) + out_b = pipeline.forward(req_b) + assert torch.isfinite(out_b.output).all() + + assert pipeline.state.session_id == "session-b" + assert pipeline.state.current_lat_f == lat_f_after_a, ( + "Stub fresh-call advances by the same fresh new_lat_f, so the reset must " + "have wiped the prior cumulative count rather than added to it." + ) + + +@pytest.mark.diffusion +@hardware_test(res={"cuda": "L4"}, num_cards={"cuda": 1}) +def test_post_process_shapes_videos_for_external_output(stubbed_pipeline) -> None: + """The model-specific post-process flips ``[C, F, H, W]`` to ``[F, H, W, C]``; + that's what diffusion engine + serving code downstream expects.""" + pipeline = stubbed_pipeline + req = _build_request( + image=make_dummy_image(), + camera=make_dummy_camera_inputs(num_frames=_FRESH_NUM_FRAMES), + session_id="session-postprocess", + num_frames=_FRESH_NUM_FRAMES, + ) + + out = pipeline.forward(req) + post = get_lingbot_world_fast_post_process_func(pipeline.od_config) + framed = post(out.output) + + # [C, F, H, W] → [F, H, W, C] + assert framed.ndim == 4 + assert framed.shape[-1] == out.output.shape[0] + assert framed.shape[0] == out.output.shape[1] diff --git a/tests/e2e/offline_inference/test_lingbot_world_fast_expansion.py b/tests/e2e/offline_inference/test_lingbot_world_fast_expansion.py new file mode 100644 index 00000000000..98a932f5680 --- /dev/null +++ b/tests/e2e/offline_inference/test_lingbot_world_fast_expansion.py @@ -0,0 +1,153 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""L3 real-weight offline expansion for Lingbot World Fast.""" + +from __future__ import annotations + +from typing import Any + +import numpy as np +import PIL.Image +import pytest +import torch + +from tests.helpers.lingbot_world_fast import ( + FPS, + GREAT_WALL_PROMPT, + HEIGHT, + LONG_NUM_FRAMES, + SEED, + SHORT_NUM_FRAMES, + SSIM_THRESHOLD, + WIDTH, + find_lingbot_world_fast_assets, + frame_ssim, + golden_frames_dir, + load_camera_trajectory, + normalize_to_uint8_rgb, +) +from tests.helpers.mark import hardware_test +from vllm_omni.entrypoints.omni import Omni +from vllm_omni.inputs.data import OmniDiffusionSamplingParams +from vllm_omni.outputs import OmniRequestOutput +from vllm_omni.platforms import current_omni_platform + +pytestmark = [ + pytest.mark.advanced_model, + pytest.mark.core_model, + pytest.mark.diffusion, +] + + +def _extract_frames_from_output(output: Any) -> np.ndarray: + """Pull a ``[N, H, W, 3]`` numpy array out of an ``OmniRequestOutput``.""" + if isinstance(output, list) and output: + output = output[0] + if isinstance(output, OmniRequestOutput): + if output.is_pipeline_output and output.request_output is not None: + inner = output.request_output + if isinstance(inner, OmniRequestOutput): + output = inner + if isinstance(output, OmniRequestOutput) and output.images: + entry = output.images[0] + if isinstance(entry, tuple) and len(entry) >= 1: + output = entry[0] + elif isinstance(entry, dict): + output = entry.get("frames") or entry.get("video") + else: + output = entry + if isinstance(output, torch.Tensor): + output = output.detach().cpu().numpy() + if not isinstance(output, np.ndarray): + raise AssertionError(f"Could not extract frames from output: {type(output)}") + return normalize_to_uint8_rgb(output) + + +@pytest.fixture(scope="module") +def lingbot_world_fast_assets(): + assets = find_lingbot_world_fast_assets() + if assets is None: + pytest.skip( + "Lingbot-World-Fast L3 assets not available. Set LINGBOT_WORLD_FAST_PATH " + "(model dir) + LINGBOT_WORLD_FAST_CAMERA_PATH (poses.npy/intrinsics.npy) " + "+ LINGBOT_WORLD_FAST_IMAGE (input image) to enable.", + ) + return assets + + +@pytest.fixture(scope="module") +def lingbot_world_fast_omni(lingbot_world_fast_assets): + omni = Omni( + model=str(lingbot_world_fast_assets.weights_path), + parallel_config=None, + model_class_name="LingbotWorldFastPipeline", + stage_init_timeout=6000, + init_timeout=6000, + ) + try: + yield omni + finally: + if hasattr(omni, "close"): + omni.close() + + +@hardware_test(res={"cuda": "H100"}, num_cards={"cuda": 1}) +@pytest.mark.parametrize("num_frames, length", [(SHORT_NUM_FRAMES, "short"), (LONG_NUM_FRAMES, "long")]) +def test_lingbot_world_offline_video( + num_frames, + length, + lingbot_world_fast_assets, + lingbot_world_fast_omni, +): + image = ( + PIL.Image.open(lingbot_world_fast_assets.image_path) + .convert("RGB") + .resize((WIDTH, HEIGHT), PIL.Image.Resampling.LANCZOS) + ) + poses, intrinsics = load_camera_trajectory(lingbot_world_fast_assets.camera_dir) + poses = poses[:num_frames] + intrinsics = intrinsics[:num_frames] + + generator = torch.Generator(device=current_omni_platform.device_type).manual_seed(SEED) + sampling = OmniDiffusionSamplingParams( + height=HEIGHT, + width=WIDTH, + generator=generator, + num_frames=num_frames, + frame_rate=FPS, + extra_args={"session_id": f"SESSION_ID-{length}"}, + ) + + multi_modal_data: dict = {"image": image, "camera": {"poses": poses, "intrinsics": intrinsics}} + + prompt = { + "prompt": GREAT_WALL_PROMPT, + "negative_prompt": "", + "multi_modal_data": multi_modal_data, + } + + output = lingbot_world_fast_omni.generate(prompt, sampling) + + video = _extract_frames_from_output(output) + + first_frame = video[0] + last_frame = video[-1] + + first_path = golden_frames_dir() / f"golden_frame_{length}_first.npy" + last_path = golden_frames_dir() / f"golden_frame_{length}_last.npy" + + first_golden = np.load(first_path) + last_golden = np.load(last_path) + + ssim_first = frame_ssim(first_frame, first_golden) + ssim_last = frame_ssim(last_frame, last_golden) + print( + f"[lingbot-world-fast L3] SSIM(first)={ssim_first:.4f} SSIM(last)={ssim_last:.4f} (threshold {SSIM_THRESHOLD})" + ) + assert ssim_first >= SSIM_THRESHOLD, ( + f"First-frame SSIM {ssim_first:.4f} below {SSIM_THRESHOLD}: regression in first-call path." + ) + assert ssim_last >= SSIM_THRESHOLD, ( + f"Last-frame SSIM {ssim_last:.4f} below {SSIM_THRESHOLD}: regression in last-call path." + ) diff --git a/tests/e2e/online_serving/test_lingbot_world_fast.py b/tests/e2e/online_serving/test_lingbot_world_fast.py new file mode 100644 index 00000000000..52ae08a003a --- /dev/null +++ b/tests/e2e/online_serving/test_lingbot_world_fast.py @@ -0,0 +1,252 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""L2 online smoke for ``/v1/realtime/world/camera``.""" + +from __future__ import annotations + +import asyncio +from collections.abc import Iterable +from typing import Any + +import numpy as np +import pytest + +from vllm_omni.entrypoints.openai.realtime.world.camera_connection import ( + DEFAULT_FRAMES_PER_CHUNK, + WorldCameraRealtimeConnection, +) +from vllm_omni.entrypoints.openai.realtime.world.camera_serving import ServingRealtimeWorldCamera + +msgpack_numpy = pytest.importorskip("openpi_client.msgpack_numpy") + +pytestmark = [pytest.mark.core_model, pytest.mark.diffusion] + + +# --------------------------------------------------------------------------- +# Test plumbing +# --------------------------------------------------------------------------- + + +class _MockWebSocket: + """ASGI-shaped mock matching ``WorldCameraRealtimeConnection``'s call sites. + + ``receive`` is the lowest-level ASGI hook the connection uses (it pulls + ``{"type": "websocket.receive", "bytes": ...}`` dicts directly, not the + higher-level ``receive_bytes``). After the scripted messages are + exhausted, ``receive`` returns a disconnect frame so the connection's + main loop exits cleanly instead of timing out. + """ + + def __init__(self, incoming: Iterable[dict[str, Any]] | None = None) -> None: + self._incoming: list[dict[str, Any]] = list(incoming or []) + self._idx = 0 + self.sent_bytes: list[bytes] = [] + self.sent_text: list[str] = [] + self.accepted = False + self.closed = False + + async def accept(self) -> None: + self.accepted = True + + async def receive(self) -> dict[str, Any]: + if self._idx >= len(self._incoming): + return {"type": "websocket.disconnect"} + msg = self._incoming[self._idx] + self._idx += 1 + return msg + + async def send_bytes(self, data: bytes) -> None: + self.sent_bytes.append(data) + + async def send_text(self, data: str) -> None: + self.sent_text.append(data) + + async def close(self) -> None: + self.closed = True + + +class _FakeAsyncIter: + """Async iterable for the canned engine output.""" + + def __init__(self, items: list[Any]) -> None: + self._items = list(items) + + def __aiter__(self): + return self + + async def __anext__(self): + if not self._items: + raise StopAsyncIteration + return self._items.pop(0) + + +class _FakeResult: + """Stand-in for ``OmniRequestOutput`` — only ``.images`` is consulted.""" + + def __init__(self, frames: np.ndarray) -> None: + self.images = [frames] + + +class _FakeEngineClient: + """Stand-in for ``AsyncOmni``: records calls, returns a per-call frame buffer. + + The connection's framing logic calls ``generate(...)`` once per ``infer`` + request. Tests pre-load ``self.queued_frames`` with one buffer per + expected call, in order. + """ + + def __init__(self, queued_frames: list[np.ndarray]) -> None: + self.queued_frames = list(queued_frames) + self.calls: list[dict[str, Any]] = [] + # Attributes consulted by ``CameraServerConfig.from_model_config``. + self.model_config = {"pipeline": "lingbot_world_fast", "resolution": [480, 832], "fps": 16} + + def generate(self, *, prompt, request_id, sampling_params_list): + self.calls.append( + { + "prompt": prompt, + "request_id": request_id, + "session_id": sampling_params_list[0].extra_args.get("session_id"), + } + ) + if not self.queued_frames: + raise AssertionError("FakeEngineClient ran out of queued frames") + frames = self.queued_frames.pop(0) + return _FakeAsyncIter([_FakeResult(frames)]) + + +def _pack_frame(payload: Any) -> dict[str, Any]: + return {"type": "websocket.receive", "bytes": msgpack_numpy.packb(payload)} + + +def _camera_payload(num_frames: int) -> dict[str, np.ndarray]: + return { + "intrinsics": np.eye(3, dtype=np.float32), + "poses": np.tile(np.eye(4, dtype=np.float32), (num_frames, 1, 1)), + } + + +def _infer_req(*, session_id: str, num_frames: int, include_image: bool) -> dict[str, Any]: + req: dict[str, Any] = { + "prompt": "walk along the Great Wall of China", + "camera": _camera_payload(num_frames), + "session_id": session_id, + "extra_body": {"num_frames": num_frames, "height": 480, "width": 832, "fps": 16}, + } + if include_image: + req["image"] = np.zeros((480, 832, 3), dtype=np.uint8) + return req + + +# --------------------------------------------------------------------------- +# Lifecycle smoke test +# --------------------------------------------------------------------------- + + +def test_camera_session_lifecycle_handshake_infer_reset_infer() -> None: + """End-to-end client session: handshake → infer → reset → infer (new + session_id). Mirrors what ``examples/online_serving/lingbot_world_fast/openai_client.py`` + does on the wire, minus the actual model.""" + # Distinct, non-divisible-by-CHUNK_FRAMES buffer sizes so both calls + # exercise the boundary case (final chunk shorter than CHUNK_FRAMES) and + # the fill-value lets us prove chunks aren't swapped between requests. + first_frames = np.full((DEFAULT_FRAMES_PER_CHUNK * 2 + 1, 8, 8, 3), 3, dtype=np.uint8) + second_frames = np.full((DEFAULT_FRAMES_PER_CHUNK + 1, 8, 8, 3), 7, dtype=np.uint8) + + engine_client = _FakeEngineClient(queued_frames=[first_frames, second_frames]) + serving = ServingRealtimeWorldCamera(engine_client=engine_client, model_name="lingbot") + + # First infer uses ``4N+1`` to mimic the openai_client's fresh-call shape; + # second infer uses ``4N`` for the extension shape (modelled on the + # client's branch at ``openai_client.py:84``). + fresh_req = _infer_req(session_id="session-1", num_frames=25, include_image=True) + ext_req = _infer_req(session_id="session-2", num_frames=24, include_image=False) + + ws = _MockWebSocket( + incoming=[ + _pack_frame(fresh_req), + _pack_frame({"endpoint": "reset"}), + _pack_frame(ext_req), + ] + ) + conn = WorldCameraRealtimeConnection(ws, serving) + asyncio.run(conn.handle_connection()) + + # ---------------- Handshake ------------------ + assert ws.accepted is True, "Connection must accept() before sending the handshake." + assert len(ws.sent_bytes) >= 1 + handshake = msgpack_numpy.unpackb(ws.sent_bytes[0]) + assert isinstance(handshake, dict) and handshake, "Handshake must be a non-empty msgpack dict." + assert handshake.get("pipeline") == "lingbot_world_fast" + + # ---------------- Frame chunks --------------- + decoded = [msgpack_numpy.unpackb(b) for b in ws.sent_bytes[1:]] + frame_chunks = [d for d in decoded if isinstance(d, dict) and d.get("type") == "frame"] + + first_total = (len(first_frames) + DEFAULT_FRAMES_PER_CHUNK - 1) // DEFAULT_FRAMES_PER_CHUNK + second_total = (len(second_frames) + DEFAULT_FRAMES_PER_CHUNK - 1) // DEFAULT_FRAMES_PER_CHUNK + + assert len(frame_chunks) == first_total + second_total, ( + f"Expected {first_total} + {second_total} frame chunks, got {len(frame_chunks)}." + ) + + for chunk in frame_chunks: + assert chunk.keys() >= {"type", "index", "total", "video"} + video = chunk["video"] + for frame in video: + assert frame.dtype == np.float32 + assert (frame >= 0).all() and (frame <= 1).all() + assert frame.ndim == 3 # [h, w, 3] + assert frame.shape[-1] == 3 + + # Send order is first-call chunks, then second-call chunks. Index runs + # 0..total-1 per request, and the per-chunk fill-value proves the chunker + # didn't leak state between requests. + for i, chunk in enumerate(frame_chunks[:first_total]): + assert chunk["index"] == i + assert chunk["total"] == first_total + for i, chunk in enumerate(frame_chunks[first_total:]): + assert chunk["index"] == i + assert chunk["total"] == second_total + + # ---------------- reset ---------------------- + assert "reset successful" in ws.sent_text, "Reset endpoint must reply with a text ack." + # ``ServingRealtimeWorldCamera.reset`` wipes ``_current_session_id``; + # the third request then sets it to ``session-2``. + assert serving._current_session_id == "session-2" + + # ---------------- Engine call accounting ----- + # Exactly two ``generate`` invocations; the reset request must not call + # the engine. + assert len(engine_client.calls) == 2 + assert [c["session_id"] for c in engine_client.calls] == ["session-1", "session-2"] + + +def test_camera_session_handshake_does_not_repeat_within_connection() -> None: + """The handshake is sent **once** at connect, even if many infer/reset + operations follow. Other diffusion clients depend on this invariant to + avoid double-initialising their config.""" + first_frames = np.full((DEFAULT_FRAMES_PER_CHUNK + 1, 4, 4, 3), 1, dtype=np.uint8) + second_frames = np.full((DEFAULT_FRAMES_PER_CHUNK + 2, 4, 4, 3), 2, dtype=np.uint8) + engine_client = _FakeEngineClient(queued_frames=[first_frames, second_frames]) + serving = ServingRealtimeWorldCamera(engine_client=engine_client, model_name="lingbot") + + ws = _MockWebSocket( + incoming=[ + _pack_frame(_infer_req(session_id="s", num_frames=25, include_image=True)), + _pack_frame({"endpoint": "reset"}), + _pack_frame(_infer_req(session_id="s", num_frames=24, include_image=False)), + ] + ) + conn = WorldCameraRealtimeConnection(ws, serving) + asyncio.run(conn.handle_connection()) + + # Exactly one msgpack-encoded non-frame, non-error dict in the entire + # outbound stream: the handshake. + handshakes = [] + for b in ws.sent_bytes: + decoded = msgpack_numpy.unpackb(b) + if isinstance(decoded, dict) and decoded.get("type") not in ("frame", "error"): + handshakes.append(decoded) + assert len(handshakes) == 1, f"Expected exactly one handshake, got {len(handshakes)}: {handshakes}" diff --git a/tests/e2e/online_serving/test_lingbot_world_fast_expansion.py b/tests/e2e/online_serving/test_lingbot_world_fast_expansion.py new file mode 100644 index 00000000000..3de37072f55 --- /dev/null +++ b/tests/e2e/online_serving/test_lingbot_world_fast_expansion.py @@ -0,0 +1,344 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""L3 real-weight online expansion for ``/v1/realtime/world/camera``.""" + +from __future__ import annotations + +from typing import Any + +import numpy as np +import PIL.Image +import pytest + +from tests.helpers.lingbot_world_fast import ( + FPS, + GREAT_WALL_PROMPT, + HEIGHT, + LONG_NUM_FRAMES, + SEED, + SHORT_NUM_FRAMES, + SSIM_THRESHOLD, + WIDTH, + find_lingbot_world_fast_assets, + frame_ssim, + golden_frames_dir, + load_camera_trajectory, + slice_camera_chunk, +) +from tests.helpers.mark import hardware_test +from tests.helpers.runtime import OmniServer + +# Optional protocol deps mirror what the connection itself imports lazily. +msgpack_numpy = pytest.importorskip("openpi_client.msgpack_numpy") +ws_sync = pytest.importorskip("websockets.sync.client") + +pytestmark = [ + pytest.mark.advanced_model, + pytest.mark.core_model, + pytest.mark.diffusion, +] + +_CONNECT_KWARGS = {"max_size": None, "ping_interval": None, "ping_timeout": None} + + +# --------------------------------------------------------------------------- +# Asset / golden fixtures (module-scoped to amortize file IO) +# --------------------------------------------------------------------------- + + +@pytest.fixture(scope="module") +def lingbot_world_fast_assets(): + assets = find_lingbot_world_fast_assets() + if assets is None: + pytest.skip( + "Lingbot-World-Fast L3 assets not available. Set LINGBOT_WORLD_FAST_PATH, " + "LINGBOT_WORLD_FAST_CAMERA_PATH and LINGBOT_WORLD_FAST_IMAGE.", + ) + return assets + + +_LINGBOT_SERVER_ARGS = [ + "--model-class-name", + "LingbotWorldFastPipeline", + "--ws-max-size", + "16777216", # 16 MiB — matches run_server.sh; large enough for a 480×832 image + "--ws", + "wsproto", + "--stage-init-timeout", + "6000", + "--init-timeout", + "6000", +] + + +@pytest.fixture(scope="module") +def lingbot_world_fast_server(lingbot_world_fast_assets): + """Module-scoped real-weight server; amortizes the multi-minute cold load + across the four protocol tests in this file.""" + with OmniServer( + str(lingbot_world_fast_assets.weights_path), + list(_LINGBOT_SERVER_ARGS), + use_omni=True, + ) as server: + yield server + + +def _ws_url(server: OmniServer) -> str: + return f"ws://{server.host}:{server.port}/v1/realtime/world/camera" + + +# --------------------------------------------------------------------------- +# WebSocket helpers +# --------------------------------------------------------------------------- + + +def _drain_handshake(ws) -> dict[str, Any]: + handshake = msgpack_numpy.unpackb(ws.recv()) + return handshake + + +def _send_request(ws, req: dict[str, Any]) -> None: + ws.send(msgpack_numpy.packb(req)) + + +def _drain_frames_or_error(ws) -> tuple[list[np.ndarray] | None, dict[str, Any] | None, str | None]: + """Return ``(frames, error, text)``. Exactly one of the three is non-None. + + * ``frames``: list of per-chunk uint8 arrays once ``total`` frames arrive. + * ``error``: parsed ``{"type": "error", "message": ...}`` payload. + * ``text``: text-frame reply (e.g. ``"reset successful"``). + """ + chunks: list[np.ndarray] = [] + total: int | None = None + while total is None or len(chunks) < total: + msg = ws.recv() + if isinstance(msg, str): + return None, None, msg + decoded = msgpack_numpy.unpackb(msg) + if isinstance(decoded, dict) and decoded.get("type") == "error": + return None, decoded, None + if not isinstance(decoded, dict) or decoded.get("type") != "frame": + continue # ignore unknown + total = decoded["total"] + chunks.append(np.asarray(decoded["video"])) + return chunks, None, None + + +def _build_request( + *, + session_id: str, + image: np.ndarray | None, + camera_chunk: dict[str, np.ndarray], + num_frames: int, +) -> dict[str, Any]: + req: dict[str, Any] = { + "prompt": GREAT_WALL_PROMPT, + "camera": camera_chunk, + "session_id": session_id, + "extra_body": { + "num_frames": num_frames, + "height": HEIGHT, + "width": WIDTH, + "fps": FPS, + "session_id": session_id, + "seed": SEED, + }, + } + if image is not None: + req["image"] = image + return req + + +# --------------------------------------------------------------------------- +# Test 1: Single session generation +# --------------------------------------------------------------------------- + + +@hardware_test(res={"cuda": "H100"}, num_cards={"cuda": 1}) +@pytest.mark.parametrize("num_frames, length", [(SHORT_NUM_FRAMES, "short"), (LONG_NUM_FRAMES, "long")]) +def test_lingbot_world_online_video( + num_frames, + length, + lingbot_world_fast_server, + lingbot_world_fast_assets, +): + with ws_sync.connect(_ws_url(lingbot_world_fast_server), **_CONNECT_KWARGS) as ws: + _drain_handshake(ws) + + image = ( + PIL.Image.open(lingbot_world_fast_assets.image_path) + .convert("RGB") + .resize((WIDTH, HEIGHT), PIL.Image.Resampling.LANCZOS) + ) + image = np.asarray(image) + poses, intrinsics = load_camera_trajectory(lingbot_world_fast_assets.camera_dir) + poses = poses[:num_frames] + intrinsics = intrinsics[:num_frames] + + camera = {"poses": poses, "intrinsics": intrinsics} + + req = _build_request( + session_id=f"SESSION-ID-{length}", + image=image, + camera_chunk=camera, + num_frames=num_frames, + ) + + _send_request(ws, req) + + chunks, error, text = _drain_frames_or_error(ws) + assert error is None and text is None, f"Got unexpected control reply: error={error} text={text}" + assert chunks is not None and chunks, "Returned no frames" + + reassembled = np.concatenate(chunks, axis=0) + + assert reassembled.ndim == 4 and reassembled.shape[0] >= 2, ( + f"Reassembled video has too few frames: {reassembled.shape}" + ) + + first_frame = (reassembled[0] * 255.0).round().astype(np.uint8) + last_frame = (reassembled[-1] * 255.0).round().astype(np.uint8) + + first_path = golden_frames_dir() / f"golden_frame_{length}_first.npy" + last_path = golden_frames_dir() / f"golden_frame_{length}_last.npy" + + first_golden = np.load(first_path) + last_golden = np.load(last_path) + + ssim_first = frame_ssim(first_frame, first_golden) + ssim_last = frame_ssim(last_frame, last_golden) + print( + f"[lingbot-world-fast L3 online] SSIM(first)={ssim_first:.4f} " + f"SSIM(last)={ssim_last:.4f} (threshold {SSIM_THRESHOLD})" + ) + assert ssim_first >= SSIM_THRESHOLD, ( + f"First-frame SSIM {ssim_first:.4f} below {SSIM_THRESHOLD}: regression in fresh-call path." + ) + assert ssim_last >= SSIM_THRESHOLD, ( + f"Last-frame SSIM {ssim_last:.4f} below {SSIM_THRESHOLD}: regression in extension-call path." + ) + + +# --------------------------------------------------------------------------- +# Test 2: Session-id churn mid-stream +# --------------------------------------------------------------------------- + + +@hardware_test(res={"cuda": "H100"}, num_cards={"cuda": 1}) +def test_websocket_session_id_churn_resets_state( + lingbot_world_fast_server, + lingbot_world_fast_assets, +): + """A new ``session_id`` mid-stream resets pipeline state. The next ``infer`` + that omits the image (i.e. an "extension-style" payload) must error + because the new session is fresh.""" + poses, intrinsics = load_camera_trajectory(lingbot_world_fast_assets.camera_dir) + camera_a = slice_camera_chunk(poses, intrinsics, call_index=0) + camera_b = slice_camera_chunk(poses, intrinsics, call_index=1) + + image = ( + PIL.Image.open(lingbot_world_fast_assets.image_path) + .convert("RGB") + .resize((WIDTH, HEIGHT), PIL.Image.Resampling.LANCZOS) + ) + image = np.asarray(image) + + with ws_sync.connect(_ws_url(lingbot_world_fast_server), **_CONNECT_KWARGS) as ws: + _drain_handshake(ws) + + _send_request( + ws, + _build_request( + session_id="churn-session-a", + image=image, + camera_chunk=camera_a, + num_frames=SHORT_NUM_FRAMES, + ), + ) + chunks, error, text = _drain_frames_or_error(ws) + assert chunks is not None and not error and not text, ( + f"First infer on session-a should succeed; got error={error} text={text}" + ) + + # Switch session_id mid-stream WITHOUT sending an image. The pipeline + # treats this as a fresh call (new session) and rejects. + _send_request( + ws, + _build_request( + session_id="churn-session-b", + image=None, + camera_chunk=camera_b, + num_frames=SHORT_NUM_FRAMES, + ), + ) + chunks2, error2, text2 = _drain_frames_or_error(ws) + assert error2 is not None, ( + "Server must reject a fresh session that omits ``image``; got " + f"frames={None if chunks2 is None else len(chunks2)} text={text2}" + ) + assert error2.get("type") == "error" + + +# --------------------------------------------------------------------------- +# Test 3: Mid-session ``reset`` RPC re-initializes +# --------------------------------------------------------------------------- + + +@hardware_test(res={"cuda": "H100"}, num_cards={"cuda": 1}) +def test_websocket_mid_session_reset_reinitializes( + lingbot_world_fast_server, + lingbot_world_fast_assets, +): + """After a ``reset`` text ack, the next ``infer`` with the same ``session_id`` + is a brand-new fresh call. We verify this by asserting that the + follow-up ``infer`` *without* an image errors (same logic as the + session-id churn test).""" + poses, intrinsics = load_camera_trajectory(lingbot_world_fast_assets.camera_dir) + camera_a = slice_camera_chunk(poses, intrinsics, call_index=0) + camera_b = slice_camera_chunk(poses, intrinsics, call_index=1) + + image = ( + PIL.Image.open(lingbot_world_fast_assets.image_path) + .convert("RGB") + .resize((WIDTH, HEIGHT), PIL.Image.Resampling.LANCZOS) + ) + image = np.asarray(image) + + with ws_sync.connect(_ws_url(lingbot_world_fast_server), **_CONNECT_KWARGS) as ws: + _drain_handshake(ws) + + _send_request( + ws, + _build_request( + session_id="reset-session", + image=image, + camera_chunk=camera_a, + num_frames=SHORT_NUM_FRAMES, + ), + ) + chunks, error, text = _drain_frames_or_error(ws) + assert chunks is not None and not error and not text, ( + f"Initial infer must succeed; got error={error} text={text}" + ) + + # Mid-session reset RPC. + ws.send(msgpack_numpy.packb({"endpoint": "reset"})) + _, _, reset_text = _drain_frames_or_error(ws) + assert reset_text == "reset successful", f"Expected 'reset successful' text frame, got {reset_text!r}" + + # Same session_id, no image → fresh-call branch → server error. + _send_request( + ws, + _build_request( + session_id="reset-session", + image=None, + camera_chunk=camera_b, + num_frames=SHORT_NUM_FRAMES, + ), + ) + _, post_reset_error, post_reset_text = _drain_frames_or_error(ws) + assert post_reset_error is not None, ( + "After mid-session reset the server must treat the next infer as a fresh call; " + f"missing-image payload should error. Got text={post_reset_text!r}" + ) diff --git a/tests/helpers/lingbot_world_fast.py b/tests/helpers/lingbot_world_fast.py new file mode 100644 index 00000000000..5d9afdabab3 --- /dev/null +++ b/tests/helpers/lingbot_world_fast.py @@ -0,0 +1,241 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""Shared L3-fixture helpers for Lingbot World Fast expansion tests.""" + +from __future__ import annotations + +import os +from dataclasses import dataclass +from pathlib import Path + +import numpy as np + +# --------------------------------------------------------------------------- +# Constants — single source of truth across the two expansion tests +# --------------------------------------------------------------------------- + +# Mirrors ``examples/offline_inference/lingbot_world_fast/end2end.py`` and +# ``examples/online_serving/lingbot_world_fast/openai_client.py``. +GREAT_WALL_PROMPT = "A sweeping cinematic journey along the Great Wall of China, winding through golden autumn hills under a brilliant blue sky — stone pathways stretch into the distance, watchtowers stand sentinel, and vibrant foliage blankets the mountainsides as the camera glides smoothly forward, capturing the grandeur and timeless majesty of this ancient wonder." +SEED = 42 +WIDTH = 832 +HEIGHT = 480 +FPS = 16 + +SHORT_NUM_FRAMES = 25 +LONG_NUM_FRAMES = 81 + +EXTENSION_WARMUP_DROP = 4 + +SSIM_THRESHOLD = 0.95 + + +@dataclass(frozen=True) +class LingbotWorldFastAssets: + """All external assets a real-weight Lingbot World Fast test needs.""" + + weights_path: Path + camera_dir: Path + image_path: Path + + +# --------------------------------------------------------------------------- +# Resolution helpers (no pytest imports here — callers decide whether to skip) +# --------------------------------------------------------------------------- + + +def _hf_cache_root() -> Path: + return Path(os.environ.get("HF_HOME", str(Path.home() / ".cache" / "huggingface"))) + + +def _hf_model_snapshot_dirs(repo_id: str) -> list[Path]: + snapshots = _hf_cache_root() / "hub" / f"models--{repo_id.replace('/', '--')}" / "snapshots" + if not snapshots.exists(): + return [] + return sorted( + (p for p in snapshots.iterdir() if p.is_dir()), + key=lambda p: p.stat().st_mtime, + reverse=True, + ) + + +def _repo_root() -> Path: + # tests/helpers/lingbot_world_fast.py → repo root + return Path(__file__).resolve().parents[2] + + +def _example_checkpoint_root() -> Path: + return ( + _repo_root() + / "examples" + / "offline_inference" + / "lingbot_world_fast" + / "lingbot_world" + / "lingbot-world-base-cam" + / "Lingbot-World-Fast" + ) + + +def _example_camera_root_candidates() -> list[Path]: + base = ( + _repo_root() + / "examples" + / "offline_inference" + / "lingbot_world_fast" + / "lingbot_world" + / "lingbot-world-base-cam" + ) + return [base / "examples" / "04", base / "04", base] + + +def find_lingbot_world_fast_weights() -> Path | None: + """Return a path to the Lingbot World Fast model directory, or ``None``. + + Resolution order: ``LINGBOT_WORLD_FAST_PATH`` env var → committed example + checkpoint → HF cache snapshot of ``robbyant/lingbot-world-base-cam``. + A *path is real* only when the required ``config.json`` plus at least + one ``model-*.safetensors`` shard are present, so a half-pulled snapshot + doesn't masquerade as a usable checkpoint. + """ + override = os.environ.get("LINGBOT_WORLD_FAST_PATH") + candidates: list[Path] = [] + if override: + candidates.append(Path(override)) + + candidates.append(_example_checkpoint_root()) + + for snapshot in _hf_model_snapshot_dirs("robbyant/lingbot-world-base-cam"): + candidates.append(snapshot / "Lingbot-World-Fast") + + for path in candidates: + if not path.exists() or not path.is_dir(): + continue + if not (path / "config.json").exists(): + continue + if not any(path.glob("model-*.safetensors")): + continue + return path + return None + + +def find_lingbot_world_fast_camera_dir() -> Path | None: + """Locate a directory with ``poses.npy`` + ``intrinsics.npy``.""" + override = os.environ.get("LINGBOT_WORLD_FAST_CAMERA_PATH") + candidates: list[Path] = [] + if override: + candidates.append(Path(override)) + candidates.extend(_example_camera_root_candidates()) + for path in candidates: + if path.exists() and (path / "poses.npy").exists() and (path / "intrinsics.npy").exists(): + return path + return None + + +def find_lingbot_world_fast_image() -> Path | None: + """Locate the example input image used by ``run_fast.sh`` case 2.""" + override = os.environ.get("LINGBOT_WORLD_FAST_IMAGE") + candidates: list[Path] = [] + if override: + candidates.append(Path(override)) + for camera_root in _example_camera_root_candidates(): + for name in ("image.jpg", "image.png", "input.jpg", "input.png"): + candidates.append(camera_root / name) + for path in candidates: + if path.exists() and path.is_file(): + return path + return None + + +def find_lingbot_world_fast_assets() -> LingbotWorldFastAssets | None: + """Resolve weights + camera trajectory + image in one call, or return None + if any of the three is missing — callers can then ``pytest.skip`` with a + single specific reason.""" + weights = find_lingbot_world_fast_weights() + camera = find_lingbot_world_fast_camera_dir() + image = find_lingbot_world_fast_image() + if not (weights and camera and image): + return None + return LingbotWorldFastAssets(weights_path=weights, camera_dir=camera, image_path=image) + + +def golden_frames_dir() -> Path: + return _repo_root() / "tests" / "data" / "lingbot_world_fast" + + +# --------------------------------------------------------------------------- +# Per-call payload helpers +# --------------------------------------------------------------------------- + + +def load_camera_trajectory(camera_dir: Path) -> tuple[np.ndarray, np.ndarray]: + poses = np.load(camera_dir / "poses.npy") + intrinsics = np.load(camera_dir / "intrinsics.npy") + return poses, intrinsics + + +def slice_camera_chunk( + poses: np.ndarray, + intrinsics: np.ndarray, + *, + call_index: int, + chunk_stride: int = SHORT_NUM_FRAMES, +) -> dict[str, np.ndarray]: + """Mirrors the slicing in ``examples/online_serving/lingbot_world_fast/openai_client.py``. + + Each call consumes ``chunk_stride`` poses. The model will floor internally if the slice has fewer + poses than requested. + """ + start = call_index * chunk_stride + end = start + chunk_stride + poses_slice = poses[start:end] + intrinsics_slice = intrinsics[start:end] if intrinsics.ndim > 2 else intrinsics + return {"poses": poses_slice, "intrinsics": intrinsics_slice} + + +# --------------------------------------------------------------------------- +# Frame post-processing helpers +# --------------------------------------------------------------------------- + + +def reassemble_chunked_video( + per_call_frames: list[np.ndarray], + *, + drop_warmup: int = EXTENSION_WARMUP_DROP, +) -> np.ndarray: + """Concatenate per-call frame chunks, dropping ``drop_warmup`` leading + frames on every extension call (call index >= 1).""" + assembled: list[np.ndarray] = [] + for i, frames in enumerate(per_call_frames): + clip = frames[drop_warmup:] if i > 0 else frames + assembled.append(clip) + return np.concatenate(assembled, axis=0) + + +def normalize_to_uint8_rgb(frames: np.ndarray) -> np.ndarray: + """Coerce a generated frames tensor to ``[N, H, W, 3]`` ``uint8``. + + The diffusion engine emits either an unsigned-int chunk (pre-encoded by + ``_normalize_frames``) or a float tensor in ``[-1, 1]``. We accept both + so the SSIM helper sees a single canonical shape. + """ + arr = frames + if arr.dtype.kind == "f": + arr = np.clip((arr + 1.0) * 0.5, 0.0, 1.0) + arr = (arr * 255.0).round().astype(np.uint8) + if arr.ndim == 5 and arr.shape[0] == 1: + arr = arr[0] + return arr + + +def frame_ssim(prediction: np.ndarray, reference: np.ndarray) -> float: + """Per-frame SSIM with ``data_range=1``. Uses ``torchmetrics`` (already a + transitive dep) and accepts ``[H, W, 3]`` uint8 arrays. + """ + import torch + from torchmetrics.image import StructuralSimilarityIndexMeasure + + pred_t = (torch.from_numpy(prediction.astype(np.float32) / 255.0)).permute(2, 0, 1).unsqueeze(0) + ref_t = (torch.from_numpy(reference.astype(np.float32) / 255.0)).permute(2, 0, 1).unsqueeze(0) + metric = StructuralSimilarityIndexMeasure(data_range=1.0) + return float(metric(pred_t, ref_t).item()) diff --git a/tests/worker/test_omni_connector_mixin.py b/tests/worker/test_omni_connector_mixin.py index b8eb899fa32..a11ac588a35 100644 --- a/tests/worker/test_omni_connector_mixin.py +++ b/tests/worker/test_omni_connector_mixin.py @@ -797,7 +797,7 @@ def mock_process(transfer_manager, pooling_output, request): class TestLocalPayloadCacheLifecycle(unittest.TestCase): - """Unit tests for the local payload cache API (RFC §2.4).""" + """Unit tests for the local payload cache API""" def _make_host(self) -> MixinHost: host = MixinHost() diff --git a/vllm_omni/deploy/lingbot_world_fast.yaml b/vllm_omni/deploy/lingbot_world_fast.yaml new file mode 100644 index 00000000000..5414c3c0797 --- /dev/null +++ b/vllm_omni/deploy/lingbot_world_fast.yaml @@ -0,0 +1,18 @@ +# Stage config for Lingbot World Fast online serving. + +stage_args: + - stage_id: 0 + stage_type: diffusion + runtime: + process: true + devices: "0" + final_output: true + final_output_type: image + engine_args: + model_class_name: LingbotWorldFastPipeline + dtype: bfloat16 + model_config: + patch_size: [1, 2, 2] + vae_stride: [4, 8, 8] + latent_frames_per_chunk: 3 + max_area: 399360 # 480 * 832 diff --git a/vllm_omni/diffusion/data.py b/vllm_omni/diffusion/data.py index a6fe1e4e9c7..d1f18c1ea49 100644 --- a/vllm_omni/diffusion/data.py +++ b/vllm_omni/diffusion/data.py @@ -901,6 +901,9 @@ def enrich_config(self) -> None: self.model_class_name = "WanS2VPipeline" self.tf_model_config = TransformerConfig() self.update_multimodal_support() + elif self.model_class_name == "LingbotWorldFastPipeline": + self.tf_config_dict = get_hf_file_to_dict("config.json", self.model) + self.tf_model_config = TransformerConfig.from_dict(self.tf_config_dict) elif architectures and len(architectures) == 1: self.model_class_name = architectures[0] else: diff --git a/vllm_omni/diffusion/diffusion_engine.py b/vllm_omni/diffusion/diffusion_engine.py index c13bd3c0c37..3633ea4c541 100644 --- a/vllm_omni/diffusion/diffusion_engine.py +++ b/vllm_omni/diffusion/diffusion_engine.py @@ -71,6 +71,13 @@ def supports_multimodal_input(od_config: OmniDiffusionConfig) -> tuple[bool, boo return supports_image_input, supports_audio_input +def supports_camera_pos_input(model_class_name: str) -> bool: + model_cls = DiffusionModelRegistry._try_load_model_cls(model_class_name) + if model_cls is None: + return False + return bool(getattr(model_cls, "support_camera_pos_input", False)) + + def image_color_format(model_class_name: str) -> str: model_cls = DiffusionModelRegistry._try_load_model_cls(model_class_name) return getattr(model_cls, "color_format", "RGB") @@ -690,6 +697,17 @@ def _dummy_run(self): audio_sr = 16000 dummy_audio = np.random.randn(audio_sr * 2).astype(np.float32) prompt.setdefault("multi_modal_data", {})["audio"] = dummy_audio + else: + dummy_audio = None + + if supports_camera_pos_input(self.od_config.model_class_name): + camera_pos_len = 64 + # Shape [N x 4] + intrinsics = np.random.rand(camera_pos_len, 4) + # Shape [N x 4 x 4] + poses = np.array([np.identity(4) for _ in range(camera_pos_len)]) + dummy_camera_pos = {"intrinsics": intrinsics, "poses": poses} + prompt.setdefault("multi_modal_data", {})["camera"] = dummy_camera_pos req = OmniDiffusionRequest( prompts=[prompt], diff --git a/vllm_omni/diffusion/models/interface.py b/vllm_omni/diffusion/models/interface.py index f0ded9cdb0a..21848031120 100644 --- a/vllm_omni/diffusion/models/interface.py +++ b/vllm_omni/diffusion/models/interface.py @@ -28,6 +28,11 @@ class SupportAudioInput(Protocol): support_audio_input: ClassVar[bool] = True +@runtime_checkable +class SupportCameraPosInput(Protocol): + support_camera_pos_input: ClassVar[bool] = True + + @runtime_checkable class SupportAudioOutput(Protocol): support_audio_output: ClassVar[bool] = True diff --git a/vllm_omni/diffusion/models/lingbot_world_fast/__init__.py b/vllm_omni/diffusion/models/lingbot_world_fast/__init__.py new file mode 100644 index 00000000000..20513d34e60 --- /dev/null +++ b/vllm_omni/diffusion/models/lingbot_world_fast/__init__.py @@ -0,0 +1,4 @@ +from .pipeline_lingbot_world_fast import LingbotWorldFastPipeline, get_lingbot_world_fast_post_process_func +from .wan_fast import WanModelFast + +__all__ = ["LingbotWorldFastPipeline", "get_lingbot_world_fast_post_process_func", "WanModelFast"] diff --git a/vllm_omni/diffusion/models/lingbot_world_fast/cam_utils.py b/vllm_omni/diffusion/models/lingbot_world_fast/cam_utils.py new file mode 100644 index 00000000000..cbc0da6889e --- /dev/null +++ b/vllm_omni/diffusion/models/lingbot_world_fast/cam_utils.py @@ -0,0 +1,153 @@ +# Adapted from Lingbot-World/wan/utils/cam_utils.py +import numpy as np +import torch +from scipy.interpolate import interp1d +from scipy.spatial.transform import Rotation, Slerp + + +def interpolate_camera_poses( + src_indices: np.ndarray, + src_rot_mat: np.ndarray, + src_trans_vec: np.ndarray, + tgt_indices: np.ndarray, +) -> torch.Tensor: + # interpolate translation + interp_func_trans = interp1d( + src_indices, + src_trans_vec, + axis=0, + kind="linear", + bounds_error=False, + fill_value="extrapolate", + ) + interpolated_trans_vec = interp_func_trans(tgt_indices) + + # interpolate rotation + src_quat_vec = Rotation.from_matrix(src_rot_mat) + # ensure there is no sudden change in qw + quats = src_quat_vec.as_quat().copy() # [N, 4] + for i in range(1, len(quats)): + if np.dot(quats[i], quats[i - 1]) < 0: + quats[i] = -quats[i] + src_quat_vec = Rotation.from_quat(quats) + slerp_func_rot = Slerp(src_indices, src_quat_vec) + interpolated_rot_quat = slerp_func_rot(tgt_indices) + interpolated_rot_mat = interpolated_rot_quat.as_matrix() + + poses = np.zeros((len(tgt_indices), 4, 4)) + poses[:, :3, :3] = interpolated_rot_mat + poses[:, :3, 3] = interpolated_trans_vec + poses[:, 3, 3] = 1.0 + return torch.from_numpy(poses).float() + + +def SE3_inverse(t: torch.Tensor) -> torch.Tensor: + Rot = t[:, :3, :3] # [B,3,3] + trans = t[:, :3, 3:] # [B,3,1] + R_inv = Rot.transpose(-1, -2) + t_inv = -torch.bmm(R_inv, trans) + T_inv = torch.eye(4, device=t.device, dtype=t.dtype)[None, :, :].repeat(t.shape[0], 1, 1) + T_inv[:, :3, :3] = R_inv + T_inv[:, :3, 3:] = t_inv + return T_inv + + +def compute_relative_poses( + c2ws_mat: torch.Tensor, + framewise: bool = False, + normalize_trans: bool = True, +) -> torch.Tensor: + ref_w2cs = SE3_inverse(c2ws_mat[0:1]) + relative_poses = torch.matmul(ref_w2cs, c2ws_mat) + # ensure identity matrix for 1st frame + relative_poses[0] = torch.eye(4, device=c2ws_mat.device, dtype=c2ws_mat.dtype) + if framewise: + # compute pose between i and i+1 + relative_poses_framewise = torch.bmm(SE3_inverse(relative_poses[:-1]), relative_poses[1:]) + relative_poses[1:] = relative_poses_framewise + if normalize_trans: + # scale the coordinate inputs to roughly 1 standard deviation to simplify model learning (camctrl2). + translations = relative_poses[:, :3, 3] # [f, 3] + max_norm = torch.norm(translations, dim=-1).max() + # only normalize when moving + if max_norm > 0: + relative_poses[:, :3, 3] = translations / max_norm + return relative_poses + + +@torch.no_grad() +def create_meshgrid( + n_frames: int, height: int, width: int, bias: float = 0.5, device="cuda", dtype=torch.float32 +) -> torch.Tensor: + x_range = torch.arange(width, device=device, dtype=dtype) + y_range = torch.arange(height, device=device, dtype=dtype) + grid_y, grid_x = torch.meshgrid(y_range, x_range, indexing="ij") + grid_xy = torch.stack([grid_x, grid_y], dim=-1).view([-1, 2]) + bias # [h*w, 2] + grid_xy = grid_xy[None, ...].repeat(n_frames, 1, 1) # [f, h*w, 2] + return grid_xy + + +def get_plucker_embeddings( + c2ws_mat: torch.Tensor, + k: torch.Tensor, + height: int, + width: int, + only_rays_d: bool = False, +): + n_frames = c2ws_mat.shape[0] + grid_xy = create_meshgrid(n_frames, height, width, device=c2ws_mat.device, dtype=c2ws_mat.dtype) # [f, h*w, 2] + fx, fy, cx, cy = k.chunk(4, dim=-1) # [f, 1] + + i = grid_xy[..., 0] # [f, h*w] + j = grid_xy[..., 1] # [f, h*w] + zs = torch.ones_like(i) # [f, h*w] + xs = (i - cx) / fx * zs + ys = (j - cy) / fy * zs + + directions = torch.stack([xs, ys, zs], dim=-1) # [f, h*w, 3] + directions = directions / directions.norm(dim=-1, keepdim=True) # [f, h*w, 3] + + rays_d = directions @ c2ws_mat[:, :3, :3].transpose(-1, -2) # [f, h*w, 3] + if only_rays_d: + plucker_embeddings = rays_d # [f, h*w, 3] + plucker_embeddings = plucker_embeddings.view([n_frames, height, width, 3]) # [f*h*w, 3] + else: + rays_o = c2ws_mat[:, :3, 3] # [f, 3] + rays_o = rays_o[:, None, :].expand_as(rays_d) # [f, h*w, 3] + plucker_embeddings = torch.cat([rays_o, rays_d], dim=-1) # [f, h*w, 6] + plucker_embeddings = plucker_embeddings.view([n_frames, height, width, 6]) # [f*h*w, 6] + return plucker_embeddings + + +def get_Ks_transformed( + k: torch.Tensor, + height_org: int, + width_org: int, + height_resize: int, + width_resize: int, + height_final: int, + width_final: int, +): + fx, fy, cx, cy = k.chunk(4, dim=-1) # [f, 1] + + scale_x = width_resize / width_org + scale_y = height_resize / height_org + + fx_resize = fx * scale_x + fy_resize = fy * scale_y + cx_resize = cx * scale_x + cy_resize = cy * scale_y + + crop_offset_x = (width_resize - width_final) / 2 + crop_offset_y = (height_resize - height_final) / 2 + + cx_final = cx_resize - crop_offset_x + cy_final = cy_resize - crop_offset_y + + Ks_transformed = torch.zeros_like(k) + Ks_transformed[:, 0:1] = fx_resize + Ks_transformed[:, 1:2] = fy_resize + Ks_transformed[:, 2:3] = cx_final + Ks_transformed[:, 3:4] = cy_final + + return Ks_transformed diff --git a/vllm_omni/diffusion/models/lingbot_world_fast/fm_solvers_unipc.py b/vllm_omni/diffusion/models/lingbot_world_fast/fm_solvers_unipc.py new file mode 100644 index 00000000000..d9a18899da7 --- /dev/null +++ b/vllm_omni/diffusion/models/lingbot_world_fast/fm_solvers_unipc.py @@ -0,0 +1,576 @@ +# Adapted from Lingbot-World/wan/utils/fm_solvers_unipc.py +# Originally derived from https://github.com/huggingface/diffusers/blob/v0.31.0/src/diffusers/schedulers/scheduling_unipc_multistep.py +# Converted to flow matching. + + +import math + +import numpy as np +import torch +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.schedulers.scheduling_utils import ( + KarrasDiffusionSchedulers, + SchedulerMixin, + SchedulerOutput, +) +from diffusers.utils import deprecate, is_scipy_available + +if is_scipy_available(): + import scipy.stats # noqa: F401 + + +class FlowUniPCMultistepScheduler(SchedulerMixin, ConfigMixin): + """ + `UniPCMultistepScheduler` is a training-free framework designed for the fast sampling of diffusion models. + """ + + _compatibles = [e.name for e in KarrasDiffusionSchedulers] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + solver_order: int = 2, + prediction_type: str = "flow_prediction", + shift: float | None = 1.0, + use_dynamic_shifting=False, + thresholding: bool = False, + dynamic_thresholding_ratio: float = 0.995, + sample_max_value: float = 1.0, + predict_x0: bool = True, + solver_type: str = "bh2", + lower_order_final: bool = True, + disable_corrector: list[int] = [], + solver_p: SchedulerMixin = None, + timestep_spacing: str = "linspace", + steps_offset: int = 0, + final_sigmas_type: str | None = "zero", # "zero", "sigma_min" + ): + if solver_type not in ["bh1", "bh2"]: + if solver_type in ["midpoint", "heun", "logrho"]: + self.register_to_config(solver_type="bh2") + else: + raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}") + + self.predict_x0 = predict_x0 + # set table values + self.num_inference_steps = None + alphas = np.linspace(1, 1 / num_train_timesteps, num_train_timesteps)[::-1].copy() + sigmas = 1.0 - alphas + sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32) + + if not use_dynamic_shifting: + # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution + sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) + + self.sigmas = sigmas + self.timesteps = sigmas * num_train_timesteps + + self.model_outputs = [None] * solver_order + self.timestep_list = [None] * solver_order + self.lower_order_nums = 0 + self.disable_corrector = disable_corrector + self.solver_p = solver_p + self.last_sample = None + self._step_index = None + self._begin_index = None + + self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication + self.sigma_min = self.sigmas[-1].item() + self.sigma_max = self.sigmas[0].item() + + @property + def step_index(self): + return self._step_index + + @property + def begin_index(self): + return self._begin_index + + def set_begin_index(self, begin_index: int = 0): + self._begin_index = begin_index + + def set_timesteps( + self, + num_inference_steps: int | None = None, + device: str | torch.device = None, + sigmas: list[float] | None = None, + mu: float | None = None, + shift: float | None = None, + ): + if self.config.use_dynamic_shifting and mu is None: + raise ValueError(" you have to pass a value for `mu` when `use_dynamic_shifting` is set to be `True`") + + if sigmas is None: + sigmas = np.linspace(self.sigma_max, self.sigma_min, num_inference_steps + 1).copy()[:-1] + + if self.config.use_dynamic_shifting: + sigmas = self.time_shift(mu, 1.0, sigmas) + else: + if shift is None: + shift = self.config.shift + sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) + + if self.config.final_sigmas_type == "sigma_min": + sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5 + elif self.config.final_sigmas_type == "zero": + sigma_last = 0 + else: + raise ValueError( + f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}" + ) + + timesteps = sigmas * self.config.num_train_timesteps + sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32) + + self.sigmas = torch.from_numpy(sigmas) + self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64) + + self.num_inference_steps = len(timesteps) + + self.model_outputs = [None] * self.config.solver_order + self.lower_order_nums = 0 + self.last_sample = None + if self.solver_p: + self.solver_p.set_timesteps(self.num_inference_steps, device=device) + + # add an index counter for schedulers that allow duplicated timesteps + self._step_index = None + self._begin_index = None + self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication + + def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor: + dtype = sample.dtype + batch_size, channels, *remaining_dims = sample.shape + + if dtype not in (torch.float32, torch.float64): + sample = sample.float() + + sample = sample.reshape(batch_size, channels * np.prod(remaining_dims)) + + abs_sample = sample.abs() + + s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1) + s = torch.clamp(s, min=1, max=self.config.sample_max_value) + s = s.unsqueeze(1) + sample = torch.clamp(sample, -s, s) / s + + sample = sample.reshape(batch_size, channels, *remaining_dims) + sample = sample.to(dtype) + + return sample + + def _sigma_to_t(self, sigma): + return sigma * self.config.num_train_timesteps + + def _sigma_to_alpha_sigma_t(self, sigma): + return 1 - sigma, sigma + + def time_shift(self, mu: float, sigma: float, t: torch.Tensor): + return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) + + def convert_model_output( + self, + model_output: torch.Tensor, + *args, + sample: torch.Tensor = None, + **kwargs, + ) -> torch.Tensor: + timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) + if sample is None: + if len(args) > 1: + sample = args[1] + else: + raise ValueError("missing `sample` as a required keyword argument") + if timestep is not None: + deprecate( + "timesteps", + "1.0.0", + "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled " + "via an internal counter `self.step_index`", + ) + + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + + if self.predict_x0: + if self.config.prediction_type == "flow_prediction": + sigma_t = self.sigmas[self.step_index] + x0_pred = sample - sigma_t * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`," + " `v_prediction` or `flow_prediction` for the UniPCMultistepScheduler." + ) + + if self.config.thresholding: + x0_pred = self._threshold_sample(x0_pred) + + return x0_pred + else: + if self.config.prediction_type == "flow_prediction": + sigma_t = self.sigmas[self.step_index] + epsilon = sample - (1 - sigma_t) * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`," + " `v_prediction` or `flow_prediction` for the UniPCMultistepScheduler." + ) + + if self.config.thresholding: + sigma_t = self.sigmas[self.step_index] + x0_pred = sample - sigma_t * model_output + x0_pred = self._threshold_sample(x0_pred) + epsilon = model_output + x0_pred + + return epsilon + + def multistep_uni_p_bh_update( + self, + model_output: torch.Tensor, + *args, + sample: torch.Tensor = None, + order: int = None, + **kwargs, + ) -> torch.Tensor: + prev_timestep = args[0] if len(args) > 0 else kwargs.pop("prev_timestep", None) + if sample is None: + if len(args) > 1: + sample = args[1] + else: + raise ValueError(" missing `sample` as a required keyword argument") + if order is None: + if len(args) > 2: + order = args[2] + else: + raise ValueError(" missing `order` as a required keyword argument") + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled " + "via an internal counter `self.step_index`", + ) + model_output_list = self.model_outputs + + s0 = self.timestep_list[-1] + m0 = model_output_list[-1] + x = sample + + if self.solver_p: + x_t = self.solver_p.step(model_output, s0, x).prev_sample + return x_t + + sigma_t, sigma_s0 = self.sigmas[self.step_index + 1], self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + + h = lambda_t - lambda_s0 + device = sample.device + + rks = [] + D1s = [] + for i in range(1, order): + si = self.step_index - i + mi = model_output_list[-(i + 1)] + alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si]) + lambda_si = torch.log(alpha_si) - torch.log(sigma_si) + rk = (lambda_si - lambda_s0) / h + rks.append(rk) + D1s.append((mi - m0) / rk) + + rks.append(1.0) + rks = torch.tensor(rks, device=device) + + R = [] + b = [] + + hh = -h if self.predict_x0 else h + h_phi_1 = torch.expm1(hh) + h_phi_k = h_phi_1 / hh - 1 + + factorial_i = 1 + + if self.config.solver_type == "bh1": + B_h = hh + elif self.config.solver_type == "bh2": + B_h = torch.expm1(hh) + else: + raise NotImplementedError() + + for i in range(1, order + 1): + R.append(torch.pow(rks, i - 1)) + b.append(h_phi_k * factorial_i / B_h) + factorial_i *= i + 1 + h_phi_k = h_phi_k / hh - 1 / factorial_i + + R = torch.stack(R) + b = torch.tensor(b, device=device) + + if len(D1s) > 0: + D1s = torch.stack(D1s, dim=1) # (B, K) + if order == 2: + rhos_p = torch.tensor([0.5], dtype=x.dtype, device=device) + else: + rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1]).to(device).to(x.dtype) + else: + D1s = None + + if self.predict_x0: + x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0 + if D1s is not None: + pred_res = torch.einsum("k,bkc...->bc...", rhos_p, D1s) + else: + pred_res = 0 + x_t = x_t_ - alpha_t * B_h * pred_res + else: + x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0 + if D1s is not None: + pred_res = torch.einsum("k,bkc...->bc...", rhos_p, D1s) + else: + pred_res = 0 + x_t = x_t_ - sigma_t * B_h * pred_res + + x_t = x_t.to(x.dtype) + return x_t + + def multistep_uni_c_bh_update( + self, + this_model_output: torch.Tensor, + *args, + last_sample: torch.Tensor = None, + this_sample: torch.Tensor = None, + order: int = None, + **kwargs, + ) -> torch.Tensor: + this_timestep = args[0] if len(args) > 0 else kwargs.pop("this_timestep", None) + if last_sample is None: + if len(args) > 1: + last_sample = args[1] + else: + raise ValueError(" missing`last_sample` as a required keyword argument") + if this_sample is None: + if len(args) > 2: + this_sample = args[2] + else: + raise ValueError(" missing`this_sample` as a required keyword argument") + if order is None: + if len(args) > 3: + order = args[3] + else: + raise ValueError(" missing`order` as a required keyword argument") + if this_timestep is not None: + deprecate( + "this_timestep", + "1.0.0", + "Passing `this_timestep` is deprecated and has no effect as model output conversion is now handled " + "via an internal counter `self.step_index`", + ) + + model_output_list = self.model_outputs + + m0 = model_output_list[-1] + x = last_sample + x_t = this_sample + model_t = this_model_output + + sigma_t, sigma_s0 = self.sigmas[self.step_index], self.sigmas[self.step_index - 1] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + + h = lambda_t - lambda_s0 + device = this_sample.device + + rks = [] + D1s = [] + for i in range(1, order): + si = self.step_index - (i + 1) + mi = model_output_list[-(i + 1)] + alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si]) + lambda_si = torch.log(alpha_si) - torch.log(sigma_si) + rk = (lambda_si - lambda_s0) / h + rks.append(rk) + D1s.append((mi - m0) / rk) + + rks.append(1.0) + rks = torch.tensor(rks, device=device) + + R = [] + b = [] + + hh = -h if self.predict_x0 else h + h_phi_1 = torch.expm1(hh) + h_phi_k = h_phi_1 / hh - 1 + + factorial_i = 1 + + if self.config.solver_type == "bh1": + B_h = hh + elif self.config.solver_type == "bh2": + B_h = torch.expm1(hh) + else: + raise NotImplementedError() + + for i in range(1, order + 1): + R.append(torch.pow(rks, i - 1)) + b.append(h_phi_k * factorial_i / B_h) + factorial_i *= i + 1 + h_phi_k = h_phi_k / hh - 1 / factorial_i + + R = torch.stack(R) + b = torch.tensor(b, device=device) + + if len(D1s) > 0: + D1s = torch.stack(D1s, dim=1) + else: + D1s = None + + if order == 1: + rhos_c = torch.tensor([0.5], dtype=x.dtype, device=device) + else: + rhos_c = torch.linalg.solve(R, b).to(device).to(x.dtype) + + if self.predict_x0: + x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0 + if D1s is not None: + corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s) + else: + corr_res = 0 + D1_t = model_t - m0 + x_t = x_t_ - alpha_t * B_h * (corr_res + rhos_c[-1] * D1_t) + else: + x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0 + if D1s is not None: + corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s) + else: + corr_res = 0 + D1_t = model_t - m0 + x_t = x_t_ - sigma_t * B_h * (corr_res + rhos_c[-1] * D1_t) + x_t = x_t.to(x.dtype) + return x_t + + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps + + indices = (schedule_timesteps == timestep).nonzero() + + pos = 1 if len(indices) > 1 else 0 + + return indices[pos].item() + + def _init_step_index(self, timestep): + if self.begin_index is None: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + self._step_index = self.index_for_timestep(timestep) + else: + self._step_index = self._begin_index + + def step( + self, + model_output: torch.Tensor, + timestep: int | torch.Tensor, + sample: torch.Tensor, + return_dict: bool = True, + generator=None, + ) -> SchedulerOutput | tuple: + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + if self.step_index is None: + self._init_step_index(timestep) + + use_corrector = ( + self.step_index > 0 and self.step_index - 1 not in self.disable_corrector and self.last_sample is not None + ) + + model_output_convert = self.convert_model_output(model_output, sample=sample) + if use_corrector: + sample = self.multistep_uni_c_bh_update( + this_model_output=model_output_convert, + last_sample=self.last_sample, + this_sample=sample, + order=self.this_order, + ) + + for i in range(self.config.solver_order - 1): + self.model_outputs[i] = self.model_outputs[i + 1] + self.timestep_list[i] = self.timestep_list[i + 1] + + self.model_outputs[-1] = model_output_convert + self.timestep_list[-1] = timestep + + if self.config.lower_order_final: + this_order = min(self.config.solver_order, len(self.timesteps) - self.step_index) + else: + this_order = self.config.solver_order + + self.this_order = min(this_order, self.lower_order_nums + 1) + assert self.this_order > 0 + + self.last_sample = sample + prev_sample = self.multistep_uni_p_bh_update( + model_output=model_output, + sample=sample, + order=self.this_order, + ) + + if self.lower_order_nums < self.config.solver_order: + self.lower_order_nums += 1 + + self._step_index += 1 + + if not return_dict: + return (prev_sample,) + + return SchedulerOutput(prev_sample=prev_sample) + + def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor: + return sample + + def add_noise( + self, + original_samples: torch.Tensor, + noise: torch.Tensor, + timesteps: torch.IntTensor, + ) -> torch.Tensor: + sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) + if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): + schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) + timesteps = timesteps.to(original_samples.device, dtype=torch.float32) + else: + schedule_timesteps = self.timesteps.to(original_samples.device) + if isinstance(timesteps, list): + timesteps = [timestep.to(original_samples.device) for timestep in timesteps] + else: + timesteps = timesteps.to(original_samples.device) + + if self.begin_index is None: + if not isinstance(timesteps, list): + timesteps = [timesteps] + step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps] + elif self.step_index is not None: + step_indices = [self.step_index] * timesteps.shape[0] + else: + step_indices = [self.begin_index] * timesteps.shape[0] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < len(original_samples.shape): + sigma = sigma.unsqueeze(-1) + + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + noisy_samples = alpha_t * original_samples + sigma_t * noise + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps diff --git a/vllm_omni/diffusion/models/lingbot_world_fast/pipeline_lingbot_world_fast.py b/vllm_omni/diffusion/models/lingbot_world_fast/pipeline_lingbot_world_fast.py new file mode 100644 index 00000000000..3c92f4219eb --- /dev/null +++ b/vllm_omni/diffusion/models/lingbot_world_fast/pipeline_lingbot_world_fast.py @@ -0,0 +1,480 @@ +import logging +import math +import os +import random +import sys +import time +from collections.abc import Iterable +from contextlib import contextmanager + +import numpy as np +import torch +import torch.distributed as dist +import torchvision.transforms.functional as TF +from einops import rearrange +from torch import nn +from tqdm import tqdm +from vllm.model_executor.models.utils import AutoWeightsLoader + +from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig +from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin +from vllm_omni.diffusion.distributed.utils import get_local_device +from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader +from vllm_omni.diffusion.models.interface import SupportCameraPosInput, SupportImageInput +from vllm_omni.diffusion.request import OmniDiffusionRequest + +from .cam_utils import ( + compute_relative_poses, + get_Ks_transformed, + get_plucker_embeddings, + interpolate_camera_poses, +) +from .fm_solvers_unipc import FlowUniPCMultistepScheduler +from .state_lingbot_world_fast import LingbotWorldFastState +from .t5 import T5EncoderModel +from .vae2_1 import Wan2_1_VAE +from .wan_fast import WanModelFast + +logger = logging.getLogger(__name__) + +CONFIG = { + "num_train_timesteps": 1000, + "timesteps_index": [0, 179, 358, 679], + "sample_shift": 10.0, + "t5_checkpoint": "models_t5_umt5-xxl-enc-bf16.pth", + "t5_tokenizer": "google/umt5-xxl", + "vae_checkpoint": "Wan2.1_VAE.pth", + "fast_noise_checkpoint": "Lingbot-World-Fast", +} + + +def get_lingbot_world_fast_post_process_func( + od_config: OmniDiffusionConfig, +): + def post_process_func( + video: torch.Tensor, + ): + outputs = video.permute(1, 2, 3, 0) + return outputs + + return post_process_func + + +class LingbotWorldFastPipeline(nn.Module, SupportImageInput, SupportCameraPosInput, CFGParallelMixin): + def __init__(self, *, od_config: OmniDiffusionConfig): + super().__init__() + self.od_config = od_config + self.parallel_config = od_config.parallel_config + + self.device = get_local_device() + + self.target_dtype = od_config.dtype + + self.control_type = "cam" + self.num_train_timesteps = CONFIG["num_train_timesteps"] + + self.sp_size = od_config.parallel_config.world_size + + self.state = LingbotWorldFastState() + + checkpoint_path = os.path.dirname(self.od_config.model) + assert checkpoint_path is not None, "lingbot_dir is None" + + self.text_encoder = T5EncoderModel( + text_len=od_config.tf_model_config.get("text_len"), + dtype=self.target_dtype, + device=self.device, + checkpoint_path=os.path.join(checkpoint_path, CONFIG["t5_checkpoint"]), + tokenizer_path=os.path.join(checkpoint_path, CONFIG["t5_tokenizer"]), + ) + + self.vae_stride = od_config.model_config["vae_stride"] + self.patch_size = od_config.model_config["patch_size"] + self.vae = Wan2_1_VAE(vae_pth=os.path.join(checkpoint_path, CONFIG["vae_checkpoint"]), device=self.device) + + logger.info(f"Creating WanModelFast from {checkpoint_path}") + + self.model = WanModelFast( + in_dim=od_config.tf_model_config.get("in_dim"), + dim=od_config.tf_model_config.get("dim"), + ffn_dim=od_config.tf_model_config.get("ffn_dim"), + num_layers=od_config.tf_model_config.get("num_layers"), + num_heads=od_config.tf_model_config.get("num_heads"), + ) + + # Tell the loader where the transformer weights live. Without this, + # get_all_weights() yields nothing and load_weights() runs empty. + self.weights_sources = [ + DiffusersPipelineLoader.ComponentSource( + model_or_path=checkpoint_path, + subfolder=CONFIG["fast_noise_checkpoint"], + revision=None, + prefix="model.", + fall_back_to_pt=True, + ), + ] + # self.model = WanModelFast.from_pretrained( + # checkpoint_path, + # subfolder=CONFIG["fast_noise_checkpoint"], + # torch_dtype=torch.bfloat16, + # control_type=self.control_type, + # ).to(self.device) + + self.scheduler = FlowUniPCMultistepScheduler( + num_train_timesteps=self.num_train_timesteps, shift=1, use_dynamic_shifting=False + ) + + def _configure_model(self, model): + """ + Configures a model object. This includes setting evaluation modes, + applying distributed parallel strategy, and handling device placement. + + Args: + model (torch.nn.Module): + The model instance to configure. + + Returns: + torch.nn.Module: + The configured model. + """ + model.eval().requires_grad_(False) + + def _convert_flow_pred_to_x0( + self, flow_pred: torch.Tensor, xt: torch.Tensor, timestep: torch.Tensor, scheduler + ) -> torch.Tensor: + """ + Convert flow matching's prediction to x0 prediction. + flow_pred: the prediction with shape [B, C, F, H, W] + xt: the input noisy data with shape [B, C, F, H, W] + timestep: the timestep with shape [B] + + pred = noise - x0 + x_t = (1-sigma_t) * x0 + sigma_t * noise + we have x0 = x_t - sigma_t * pred + """ + # use higher precision for calculations + original_dtype = flow_pred.dtype + flow_pred, xt, sigmas, timesteps = map( + lambda x: x.double().to(flow_pred.device), [flow_pred, xt, scheduler.sigmas, scheduler.timesteps] + ) + timestep_id = torch.argmin((timesteps - timestep).abs()) + sigma_t = sigmas[timestep_id].reshape(-1, 1, 1, 1) + x0_pred = xt - sigma_t * flow_pred + + return x0_pred.to(original_dtype) + + def forward( + self, + req: OmniDiffusionRequest, + ) -> DiffusionOutput: + if len(req.prompts) > 1: + raise ValueError( + """This model only supports a single prompt, not a batched request.""", + """Please pass in a single prompt object or string, or a single-item list.""", + ) + prompt = req.prompts[0].get("prompt") + multi_modal_data = req.prompts[0].get("multi_modal_data", {}) + + session_id = req.sampling_params.extra_args.get("session_id") + + if session_id is None: + # Create a unique id if none is specified without messing with RNG state + session_id = time.time() + + session_id = str(session_id) + + force_reset = req.sampling_params.extra_args.get("force_reset") or False + + extension = True + + if force_reset or self.state.session_id is None or self.state.session_id != session_id: + self.state.reset() + self.state.session_id = session_id + extension = False + else: + extension = True + + camera = multi_modal_data.get("camera", None) + if camera is None: + raise ValueError("Camera positions are required by this model.") + + if extension: + assert multi_modal_data.get("image") is None, ( + "image must not be provided on extension calls; it is only used on the first call of a session" + ) + assert self.model.config.local_attn_size == -1, ( + "video extension requires the model to be configured with local_attn_size == -1" + ) + + batch_size = 1 + num_frames = req.sampling_params.num_frames + # In order to generate something num_frames must be at least 5 since it expects 4*n + 1 as input + # 25 is the smallest length supported by the model. Smaller values generate tensors with dimension zero/negative + num_frames = max(25, num_frames) + + c2ws = camera.get("poses") + latent_frames_per_chunk = self.od_config.model_config["latent_frames_per_chunk"] + max_area = self.od_config.model_config["max_area"] + + # Fresh: 4N+1 pixel frames → N+1 latents, the first slot is the anchor. + # Extension: 4N pixel frames → N regular latents, no anchor. + if extension: + len_c2ws = (len(c2ws) // 4) * 4 + num_frames = (num_frames // 4) * 4 + num_frames = min(num_frames, len_c2ws) + new_lat_f = num_frames // 4 + else: + len_c2ws = ((len(c2ws) - 1) // 4) * 4 + 1 + num_frames = ((num_frames - 1) // 4) * 4 + 1 + num_frames = min(num_frames, len_c2ws) + new_lat_f = (num_frames - 1) // 4 + 1 + c2ws = c2ws[:num_frames] + + # 1. Derive spatial shape: from the input image on fresh start, from cache on extension. + if not extension: + img = multi_modal_data.get("image") + img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(self.device) + h, w = img.shape[1:] + aspect_ratio = h / w + lat_h = round( + np.sqrt(max_area * aspect_ratio) // self.vae_stride[1] // self.patch_size[1] * self.patch_size[1] + ) + lat_w = round( + np.sqrt(max_area / aspect_ratio) // self.vae_stride[2] // self.patch_size[2] * self.patch_size[2] + ) + h = lat_h * self.vae_stride[1] + w = lat_w * self.vae_stride[2] + else: + img = None + h, w, lat_h, lat_w = self.state.h, self.state.w, self.state.lat_h, self.state.lat_w + + new_lat_f = int(new_lat_f - (new_lat_f % latent_frames_per_chunk)) + new_lat_f = max(new_lat_f, 1) + max_seq_len = latent_frames_per_chunk * lat_h * lat_w // (self.patch_size[1] * self.patch_size[2]) + max_seq_len = int(math.ceil(max_seq_len / self.sp_size)) * self.sp_size + seed_g = req.sampling_params.generator + if seed_g is None: + seed = req.sampling_params.seed + if seed is None: + seed = random.randint(0, sys.maxsize) + seed_g = torch.Generator(device=self.device) + seed_g.manual_seed(seed) + noise = torch.randn(16, new_lat_f, lat_h, lat_w, dtype=torch.float32, generator=seed_g, device=self.device) + + # Fresh: msk[0] = 1 (anchor) and the rest = 0, replicated into 4 channels grouped + # by latent frame to give shape [4, new_lat_f, lat_h, lat_w]. + # Extension: no anchor, all zeros, already in the [4, new_lat_f, ...] layout. + if not extension: + F = (new_lat_f - 1) * 4 + 1 + msk = torch.zeros(1, F, lat_h, lat_w, device=self.device) + msk[:, 0] = 1 + msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1) + msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w) + msk = msk.transpose(1, 2)[0] + else: + msk = torch.zeros(4, new_lat_f, lat_h, lat_w, device=self.device) + + # 2. Prepare timesteps + self.scheduler.set_timesteps(self.num_train_timesteps, shift=CONFIG["sample_shift"]) + timesteps = self.scheduler.timesteps[CONFIG["timesteps_index"]] + + context = self.text_encoder([prompt], self.device) + context = [t.to(self.device) for t in context] + + dit_cond_dict = None + Ks = torch.from_numpy(camera.get("intrinsics")) + + # Transform the provided intrinsics from the original 480p according to the new image size (h, w). + Ks = get_Ks_transformed( + Ks, height_org=480, width_org=832, height_resize=h, width_resize=w, height_final=h, width_final=w + ) + Ks = Ks[0] + + # One target pose per output latent — must match the f= in the rearrange below. + len_c2ws = len(c2ws) + len_c2ws_ = new_lat_f + c2ws_infer = interpolate_camera_poses( + src_indices=np.linspace(0, len_c2ws - 1, len_c2ws), + src_rot_mat=c2ws[:, :3, :3], + src_trans_vec=c2ws[:, :3, 3], + tgt_indices=np.linspace(0, len_c2ws - 1, len_c2ws_), + ) + c2ws_infer = compute_relative_poses(c2ws_infer, framewise=True) + Ks = Ks.repeat(len(c2ws_infer), 1) + + c2ws_infer = c2ws_infer.to(self.device).to(torch.float32) + Ks = Ks.to(self.device).to(torch.float32) + only_rays_d = False + c2ws_plucker_emb = get_plucker_embeddings(c2ws_infer, Ks, h, w, only_rays_d=only_rays_d) + c2ws_plucker_emb = rearrange( + c2ws_plucker_emb, + "f (h c1) (w c2) c -> (f h w) (c c1 c2)", + c1=int(h // lat_h), + c2=int(w // lat_w), + ) + c2ws_plucker_emb = c2ws_plucker_emb[None, ...] # [b, f*h*w, c] + c2ws_plucker_emb = rearrange(c2ws_plucker_emb, "b (f h w) c -> b c f h w", f=new_lat_f, h=lat_h, w=lat_w).to( + self.target_dtype + ) + + # Fresh: pixels = [anchor_image, zeros...] of shape [3, 4N+1, h, w]. + # VAE produces N+1 latents; latent[0] is the anchor encoding. + # Extension: pixels = zeros [3, 4N+1, h, w]. VAE produces N+1 latents, + # of which latent[0] is the special "1-frame init" encoding + # (biased differently than the regular 4-frame-group latents). + # Slice it off so the N conditioning slots are all regular — + # this drops a CONDITIONING slot, not an output latent. + if not extension: + F = (new_lat_f - 1) * 4 + 1 + pixels = torch.concat( + [ + torch.nn.functional.interpolate(img[None].cpu(), size=(h, w), mode="bicubic").transpose(0, 1), + torch.zeros(3, F - 1, h, w), + ], + dim=1, + ).to(self.device) + y = self.vae.encode([pixels])[0] + else: + pixels = torch.zeros(3, 4 * new_lat_f + 1, h, w, device=self.device) + y = self.vae.encode([pixels])[0][:, 1:] + y = torch.concat([msk, y]) + + @contextmanager + def noop_no_sync(): + yield + + no_sync_model = getattr(self.model, "no_sync", noop_no_sync) + + # Initialize (fresh) or grow (extension) the KV cache. Cross-attn cache is + # left untouched on extension so text-context k/v computed on the first call + # are reused via crossattn_cache[i]["is_init"] == True. + model_args = self.model.config + transformer_dtype = self.target_dtype + frame_seqlen = int(noise.shape[-2] * noise.shape[-1] // 4) + extra_kv_size = frame_seqlen * new_lat_f + head_dim = model_args.dim // model_args.num_heads + local_num_heads = model_args.num_heads // self.sp_size + + if not extension: + self.state.create_kv_caches( + batch_size, + transformer_dtype, + self.device, + extra_kv_size, + model_args.num_layers, + local_num_heads, + head_dim, + ) + else: + self.state.extend_kv_caches(extra_kv_size) + + # Total cache size after this call, used both as the per-query attention + # window and as the absolute-token offset base for the chunk loop. + prev_lat_f = self.state.current_lat_f + total_kv_size = frame_seqlen * (prev_lat_f + new_lat_f) + start_token_offset = prev_lat_f * frame_seqlen + + # evaluation mode + with ( + torch.amp.autocast("cuda", dtype=self.target_dtype), + torch.no_grad(), + no_sync_model(), + ): + # sample videos + latent = noise + latents_chunk = latent.split(latent_frames_per_chunk, dim=1) # [c, f, h, w] + condition_chunk = y.split(latent_frames_per_chunk, dim=1) + c2ws_plucker_emb_chunk = c2ws_plucker_emb.split(latent_frames_per_chunk, dim=2) + num_inference_chunk = len(latents_chunk) + pred_latent_chunks = [] + for chunk_id in tqdm(range(num_inference_chunk)): + current_latent = latents_chunk[chunk_id] + current_condition = condition_chunk[chunk_id] + current_c2ws_plucker_emb = c2ws_plucker_emb_chunk[chunk_id] + + dit_cond_dict = { + "c2ws_plucker_emb": current_c2ws_plucker_emb.chunk(1, dim=0), + } + + kwargs = { + "context": [context[0]], + "seq_len": max_seq_len, + "y": [current_condition], + "dit_cond_dict": dit_cond_dict, + "kv_cache": self.state.get_kv_cache(), + "local_end_index": self.state.local_end_index, + "global_end_index": self.state.global_end_index, + "crossattn_cache": self.state.get_crossattn_caches(), + "current_start": start_token_offset + chunk_id * latent_frames_per_chunk * frame_seqlen, + "max_attention_size": total_kv_size, + } + + for timestep_idx in range(len(timesteps)): + latent_model_input = [current_latent.to(self.device)] + current_timestep = [timesteps[timestep_idx]] + + timestep = torch.stack(current_timestep).to(self.device) + + noise_pred = self.model(x=latent_model_input, t=timestep, **kwargs)[0] + + x0 = self._convert_flow_pred_to_x0( + flow_pred=noise_pred, + xt=current_latent, + timestep=current_timestep[0], + scheduler=self.scheduler, + ) + + if timestep_idx < len(timesteps) - 1: + next_timestep = timesteps[timestep_idx + 1] + current_latent = self.scheduler.add_noise( + x0, torch.randn(x0.shape, generator=seed_g, device=x0.device, dtype=x0.dtype), next_timestep + ) + else: + # note return x0 + break + + pred_latent_chunks.append(x0) + + # Update kv cache + context_timestep = [timesteps[-1] * 0.0] + timestep = torch.stack(context_timestep).to(self.device) + self.model(x=[x0], t=timestep, **kwargs) + + pred_latent_chunks = torch.cat(pred_latent_chunks, dim=1) + + videos = None + if self.device.index == 0: + # Wan VAE decode() calls clear_cache() internally, so the very first latent always runs the i==0 path + # (no temporal upsample, single-frame output) and leaves feat_map polluted with thatbias. + # The decoder's stacked temporal-causal layers also need ~2 latents of streaming context before deeper + # feat_map slots match a true mid-stream decode. On extension, prepend the prior chunk's last 2 latents + # so warmup_0 absorbs the i==0 bias and warmup_1 fully primes the cache. Then discard the + # 4*K - 3 leading pixels (re-decodes of already-shown frames). + if extension and self.state.last_decoded_latent is not None: + warmup = self.state.last_decoded_latent.to(pred_latent_chunks.device, pred_latent_chunks.dtype) + k = warmup.shape[1] + drop = 4 * k - 3 + to_decode = torch.cat([warmup, pred_latent_chunks], dim=1) + videos = self.vae.decode([to_decode]) + videos = [v[:, drop:] for v in videos] + else: + videos = self.vae.decode([pred_latent_chunks]) + + self.state.last_decoded_latent = pred_latent_chunks[:, -2:].detach().clone() + + if dist.is_initialized(): + dist.barrier() + + if not extension: + self.state.h = h + self.state.w = w + self.state.lat_h = lat_h + self.state.lat_w = lat_w + self.state.frame_seqlen = frame_seqlen + self.state.advance(new_lat_f) + + return DiffusionOutput(output=videos[0]) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) diff --git a/vllm_omni/diffusion/models/lingbot_world_fast/state_lingbot_world_fast.py b/vllm_omni/diffusion/models/lingbot_world_fast/state_lingbot_world_fast.py new file mode 100644 index 00000000000..0a13a8ad99f --- /dev/null +++ b/vllm_omni/diffusion/models/lingbot_world_fast/state_lingbot_world_fast.py @@ -0,0 +1,146 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""Lingbot World Fast pipeline persistent state.""" + +from __future__ import annotations + +import logging +from enum import IntEnum + +import torch + +logger = logging.getLogger(__name__) + + +class CacheIndex(IntEnum): + K = 0 + V = 1 + + +class LingbotWorldFastState: + """Pipeline persistent state across forward() calls. + + Lifecycle: + - Created once in LingbotWorldFastPipeline.__init__() + - Mutated every forward() call (frame append, KV cache grow) + - reset() on new session / local_attn_size exceeded + """ + + def __init__(self) -> None: + self.is_initialized = False + self.reset() + + # ------------------------------------------------------------------ + # Reset / should_reset + # ------------------------------------------------------------------ + + def reset(self) -> None: + """Clear all state.""" + + if self.is_initialized: + for cache in self.kv_cache: + del cache + for cache in self.crossattn_cache: + if isinstance(cache["k"], torch.Tensor): + del cache["k"] + del cache["v"] + + self.kv_cache: list[torch.Tensor] | None = None + self.crossattn_cache: list[dict[str, bool | torch.Tensor | None]] | None = None + self.current_start_frame: int = 0 + self.local_end_index: list[torch.Tensor] | None = None + self.global_end_index: list[torch.Tensor] | None = None + + self.is_initialized: bool = False + self.current_lat_f: int = 0 + self.session_id: str | None = None + + self.batch_size: int | None = None + self.num_layers: int | None = None + self.num_heads: int | None = None + self.head_dim: int | None = None + + # Shape constants captured on the first call of a session and reused + # on extension calls, where multi_modal_data["image"] is absent. + self.h: int | None = None + self.w: int | None = None + self.lat_h: int | None = None + self.lat_w: int | None = None + self.frame_seqlen: int | None = None + + # Last few latents emitted by the diffusion loop on the previous call. + # Prepended to pred_latent_chunks on extension so the Wan VAE decoder's + # stacked temporal feat_maps are fully warmed before the first NEW + # latent is decoded. The decoder's temporal receptive field spans + # ~2 latents, so we cache the last 2. + self.last_decoded_latent: torch.Tensor | None = None + + # ------------------------------------------------------------------ + # KV cache management + # ------------------------------------------------------------------ + + def create_kv_caches( + self, + batch_size: int, + dtype: torch.dtype, + device: torch.device, + kv_size: int, + num_layers: int, + num_heads: int, + head_dim: int, + ) -> None: + self.batch_size = batch_size + self.num_layers = num_layers + self.num_heads = num_heads + self.head_dim = head_dim + + """Initialize empty KV caches and cross-attention caches.""" + self.kv_cache = [ + torch.zeros(2, batch_size, kv_size, num_heads, head_dim, dtype=dtype, device=device) + for _ in range(num_layers) + ] + + self.local_end_index = [torch.tensor([0], dtype=torch.long, device=device) for _ in range(num_layers)] + self.global_end_index = [torch.tensor([0], dtype=torch.long, device=device) for _ in range(num_layers)] + + self.crossattn_cache = [{"is_init": False, "k": None, "v": None} for _ in range(num_layers)] + + self.is_initialized = True + + def extend_kv_caches(self, extra_kv_size: int): + assert self.is_initialized, "Cannot extend uninitialized kv cache" + + dtype = self.kv_cache[0].dtype + device = self.kv_cache[0].device + + self.kv_cache = [ + torch.cat( + [ + self.kv_cache[i], + torch.zeros( + 2, self.batch_size, extra_kv_size, self.num_heads, self.head_dim, dtype=dtype, device=device + ), + ], + dim=2, + ) + for i in range(self.num_layers) + ] + + def update_kv_cache(self, layer_index: int, updated_kv: torch.Tensor) -> None: + """Update a single layer's KV cache after prefill.""" + assert self.kv_cache is not None, "KV caches not initialized, call create_kv_caches first" + self.kv_cache[layer_index] = updated_kv.clone() + + def get_kv_cache(self) -> list[torch.Tensor]: + """Get KV caches for the specified branch.""" + assert self.kv_cache is not None, "KV caches not initialized" + return self.kv_cache + + def get_crossattn_caches(self) -> list[dict[str, bool | torch.Tensor | None]]: + """Get cross-attention caches for the specified branch.""" + assert self.crossattn_cache is not None, "Cross-attn caches not initialized" + return self.crossattn_cache + + def advance(self, delta: int): + self.current_lat_f += delta diff --git a/vllm_omni/diffusion/models/lingbot_world_fast/t5.py b/vllm_omni/diffusion/models/lingbot_world_fast/t5.py new file mode 100644 index 00000000000..0863d121b4b --- /dev/null +++ b/vllm_omni/diffusion/models/lingbot_world_fast/t5.py @@ -0,0 +1,451 @@ +# Adapted from Lingbot-World/wan/modules/t5.py +import logging +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .tokenizers import HuggingfaceTokenizer + +__all__ = [ + "T5Model", + "T5Encoder", + "T5Decoder", + "T5EncoderModel", +] + + +def fp16_clamp(x): + if x.dtype == torch.float16 and torch.isinf(x).any(): + clamp = torch.finfo(x.dtype).max - 1000 + x = torch.clamp(x, min=-clamp, max=clamp) + return x + + +def init_weights(m): + if isinstance(m, T5LayerNorm): + nn.init.ones_(m.weight) + elif isinstance(m, T5Model): + nn.init.normal_(m.token_embedding.weight, std=1.0) + elif isinstance(m, T5FeedForward): + nn.init.normal_(m.gate[0].weight, std=m.dim**-0.5) + nn.init.normal_(m.fc1.weight, std=m.dim**-0.5) + nn.init.normal_(m.fc2.weight, std=m.dim_ffn**-0.5) + elif isinstance(m, T5Attention): + nn.init.normal_(m.q.weight, std=(m.dim * m.dim_attn) ** -0.5) + nn.init.normal_(m.k.weight, std=m.dim**-0.5) + nn.init.normal_(m.v.weight, std=m.dim**-0.5) + nn.init.normal_(m.o.weight, std=(m.num_heads * m.dim_attn) ** -0.5) + elif isinstance(m, T5RelativeEmbedding): + nn.init.normal_(m.embedding.weight, std=(2 * m.num_buckets * m.num_heads) ** -0.5) + + +class GELU(nn.Module): + def forward(self, x): + return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0)))) + + +class T5LayerNorm(nn.Module): + def __init__(self, dim, eps=1e-6): + super().__init__() + self.dim = dim + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def forward(self, x): + x = x * torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) + self.eps) + if self.weight.dtype in [torch.float16, torch.bfloat16]: + x = x.type_as(self.weight) + return self.weight * x + + +class T5Attention(nn.Module): + def __init__(self, dim, dim_attn, num_heads, dropout=0.1): + assert dim_attn % num_heads == 0 + super().__init__() + self.dim = dim + self.dim_attn = dim_attn + self.num_heads = num_heads + self.head_dim = dim_attn // num_heads + + # layers + self.q = nn.Linear(dim, dim_attn, bias=False) + self.k = nn.Linear(dim, dim_attn, bias=False) + self.v = nn.Linear(dim, dim_attn, bias=False) + self.o = nn.Linear(dim_attn, dim, bias=False) + self.dropout = nn.Dropout(dropout) + + def forward(self, x, context=None, mask=None, pos_bias=None): + """ + x: [B, L1, C]. + context: [B, L2, C] or None. + mask: [B, L2] or [B, L1, L2] or None. + """ + # check inputs + context = x if context is None else context + b, n, c = x.size(0), self.num_heads, self.head_dim + + # compute query, key, value + q = self.q(x).view(b, -1, n, c) + k = self.k(context).view(b, -1, n, c) + v = self.v(context).view(b, -1, n, c) + + # attention bias + attn_bias = x.new_zeros(b, n, q.size(1), k.size(1)) + if pos_bias is not None: + attn_bias += pos_bias + if mask is not None: + assert mask.ndim in [2, 3] + mask = mask.view(b, 1, 1, -1) if mask.ndim == 2 else mask.unsqueeze(1) + attn_bias.masked_fill_(mask == 0, torch.finfo(x.dtype).min) + + # compute attention (T5 does not use scaling) + attn = torch.einsum("binc,bjnc->bnij", q, k) + attn_bias + attn = F.softmax(attn.float(), dim=-1).type_as(attn) + x = torch.einsum("bnij,bjnc->binc", attn, v) + + # output + x = x.reshape(b, -1, n * c) + x = self.o(x) + x = self.dropout(x) + return x + + +class T5FeedForward(nn.Module): + def __init__(self, dim, dim_ffn, dropout=0.1): + super().__init__() + self.dim = dim + self.dim_ffn = dim_ffn + + # layers + self.gate = nn.Sequential(nn.Linear(dim, dim_ffn, bias=False), GELU()) + self.fc1 = nn.Linear(dim, dim_ffn, bias=False) + self.fc2 = nn.Linear(dim_ffn, dim, bias=False) + self.dropout = nn.Dropout(dropout) + + def forward(self, x): + x = self.fc1(x) * self.gate(x) + x = self.dropout(x) + x = self.fc2(x) + x = self.dropout(x) + return x + + +class T5SelfAttention(nn.Module): + def __init__(self, dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos=True, dropout=0.1): + super().__init__() + self.dim = dim + self.dim_attn = dim_attn + self.dim_ffn = dim_ffn + self.num_heads = num_heads + self.num_buckets = num_buckets + self.shared_pos = shared_pos + + # layers + self.norm1 = T5LayerNorm(dim) + self.attn = T5Attention(dim, dim_attn, num_heads, dropout) + self.norm2 = T5LayerNorm(dim) + self.ffn = T5FeedForward(dim, dim_ffn, dropout) + self.pos_embedding = None if shared_pos else T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True) + + def forward(self, x, mask=None, pos_bias=None): + e = pos_bias if self.shared_pos else self.pos_embedding(x.size(1), x.size(1)) + x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e)) + x = fp16_clamp(x + self.ffn(self.norm2(x))) + return x + + +class T5CrossAttention(nn.Module): + def __init__(self, dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos=True, dropout=0.1): + super().__init__() + self.dim = dim + self.dim_attn = dim_attn + self.dim_ffn = dim_ffn + self.num_heads = num_heads + self.num_buckets = num_buckets + self.shared_pos = shared_pos + + # layers + self.norm1 = T5LayerNorm(dim) + self.self_attn = T5Attention(dim, dim_attn, num_heads, dropout) + self.norm2 = T5LayerNorm(dim) + self.cross_attn = T5Attention(dim, dim_attn, num_heads, dropout) + self.norm3 = T5LayerNorm(dim) + self.ffn = T5FeedForward(dim, dim_ffn, dropout) + self.pos_embedding = None if shared_pos else T5RelativeEmbedding(num_buckets, num_heads, bidirectional=False) + + def forward(self, x, mask=None, encoder_states=None, encoder_mask=None, pos_bias=None): + e = pos_bias if self.shared_pos else self.pos_embedding(x.size(1), x.size(1)) + x = fp16_clamp(x + self.self_attn(self.norm1(x), mask=mask, pos_bias=e)) + x = fp16_clamp(x + self.cross_attn(self.norm2(x), context=encoder_states, mask=encoder_mask)) + x = fp16_clamp(x + self.ffn(self.norm3(x))) + return x + + +class T5RelativeEmbedding(nn.Module): + def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128): + super().__init__() + self.num_buckets = num_buckets + self.num_heads = num_heads + self.bidirectional = bidirectional + self.max_dist = max_dist + + # layers + self.embedding = nn.Embedding(num_buckets, num_heads) + + def forward(self, lq, lk): + device = self.embedding.weight.device + rel_pos = torch.arange(lk, device=device).unsqueeze(0) - torch.arange(lq, device=device).unsqueeze(1) + rel_pos = self._relative_position_bucket(rel_pos) + rel_pos_embeds = self.embedding(rel_pos) + rel_pos_embeds = rel_pos_embeds.permute(2, 0, 1).unsqueeze(0) # [1, N, Lq, Lk] + return rel_pos_embeds.contiguous() + + def _relative_position_bucket(self, rel_pos): + # preprocess + if self.bidirectional: + num_buckets = self.num_buckets // 2 + rel_buckets = (rel_pos > 0).long() * num_buckets + rel_pos = torch.abs(rel_pos) + else: + num_buckets = self.num_buckets + rel_buckets = 0 + rel_pos = -torch.min(rel_pos, torch.zeros_like(rel_pos)) + + # embeddings for small and large positions + max_exact = num_buckets // 2 + rel_pos_large = ( + max_exact + + ( + torch.log(rel_pos.float() / max_exact) / math.log(self.max_dist / max_exact) * (num_buckets - max_exact) + ).long() + ) + rel_pos_large = torch.min(rel_pos_large, torch.full_like(rel_pos_large, num_buckets - 1)) + rel_buckets += torch.where(rel_pos < max_exact, rel_pos, rel_pos_large) + return rel_buckets + + +class T5Encoder(nn.Module): + def __init__(self, vocab, dim, dim_attn, dim_ffn, num_heads, num_layers, num_buckets, shared_pos=True, dropout=0.1): + super().__init__() + self.dim = dim + self.dim_attn = dim_attn + self.dim_ffn = dim_ffn + self.num_heads = num_heads + self.num_layers = num_layers + self.num_buckets = num_buckets + self.shared_pos = shared_pos + + # layers + self.token_embedding = vocab if isinstance(vocab, nn.Embedding) else nn.Embedding(vocab, dim) + self.pos_embedding = T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True) if shared_pos else None + self.dropout = nn.Dropout(dropout) + self.blocks = nn.ModuleList( + [ + T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos, dropout) + for _ in range(num_layers) + ] + ) + self.norm = T5LayerNorm(dim) + + # initialize weights + self.apply(init_weights) + + def forward(self, ids, mask=None): + x = self.token_embedding(ids) + x = self.dropout(x) + e = self.pos_embedding(x.size(1), x.size(1)) if self.shared_pos else None + for block in self.blocks: + x = block(x, mask, pos_bias=e) + x = self.norm(x) + x = self.dropout(x) + return x + + +class T5Decoder(nn.Module): + def __init__(self, vocab, dim, dim_attn, dim_ffn, num_heads, num_layers, num_buckets, shared_pos=True, dropout=0.1): + super().__init__() + self.dim = dim + self.dim_attn = dim_attn + self.dim_ffn = dim_ffn + self.num_heads = num_heads + self.num_layers = num_layers + self.num_buckets = num_buckets + self.shared_pos = shared_pos + + # layers + self.token_embedding = vocab if isinstance(vocab, nn.Embedding) else nn.Embedding(vocab, dim) + self.pos_embedding = T5RelativeEmbedding(num_buckets, num_heads, bidirectional=False) if shared_pos else None + self.dropout = nn.Dropout(dropout) + self.blocks = nn.ModuleList( + [ + T5CrossAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos, dropout) + for _ in range(num_layers) + ] + ) + self.norm = T5LayerNorm(dim) + + # initialize weights + self.apply(init_weights) + + def forward(self, ids, mask=None, encoder_states=None, encoder_mask=None): + b, s = ids.size() + + # causal mask + if mask is None: + mask = torch.tril(torch.ones(1, s, s).to(ids.device)) + elif mask.ndim == 2: + mask = torch.tril(mask.unsqueeze(1).expand(-1, s, -1)) + + # layers + x = self.token_embedding(ids) + x = self.dropout(x) + e = self.pos_embedding(x.size(1), x.size(1)) if self.shared_pos else None + for block in self.blocks: + x = block(x, mask, encoder_states, encoder_mask, pos_bias=e) + x = self.norm(x) + x = self.dropout(x) + return x + + +class T5Model(nn.Module): + def __init__( + self, + vocab_size, + dim, + dim_attn, + dim_ffn, + num_heads, + encoder_layers, + decoder_layers, + num_buckets, + shared_pos=True, + dropout=0.1, + ): + super().__init__() + self.vocab_size = vocab_size + self.dim = dim + self.dim_attn = dim_attn + self.dim_ffn = dim_ffn + self.num_heads = num_heads + self.encoder_layers = encoder_layers + self.decoder_layers = decoder_layers + self.num_buckets = num_buckets + + # layers + self.token_embedding = nn.Embedding(vocab_size, dim) + self.encoder = T5Encoder( + self.token_embedding, dim, dim_attn, dim_ffn, num_heads, encoder_layers, num_buckets, shared_pos, dropout + ) + self.decoder = T5Decoder( + self.token_embedding, dim, dim_attn, dim_ffn, num_heads, decoder_layers, num_buckets, shared_pos, dropout + ) + self.head = nn.Linear(dim, vocab_size, bias=False) + + # initialize weights + self.apply(init_weights) + + def forward(self, encoder_ids, encoder_mask, decoder_ids, decoder_mask): + x = self.encoder(encoder_ids, encoder_mask) + x = self.decoder(decoder_ids, decoder_mask, x, encoder_mask) + x = self.head(x) + return x + + +def _t5( + name, + encoder_only=False, + decoder_only=False, + return_tokenizer=False, + tokenizer_kwargs={}, + dtype=torch.float32, + device="cpu", + **kwargs, +): + # sanity check + assert not (encoder_only and decoder_only) + + # params + if encoder_only: + model_cls = T5Encoder + kwargs["vocab"] = kwargs.pop("vocab_size") + kwargs["num_layers"] = kwargs.pop("encoder_layers") + _ = kwargs.pop("decoder_layers") + elif decoder_only: + model_cls = T5Decoder + kwargs["vocab"] = kwargs.pop("vocab_size") + kwargs["num_layers"] = kwargs.pop("decoder_layers") + _ = kwargs.pop("encoder_layers") + else: + model_cls = T5Model + + # init model + with torch.device(device): + model = model_cls(**kwargs) + + # set device + model = model.to(dtype=dtype, device=device) + + # init tokenizer + if return_tokenizer: + tokenizer = HuggingfaceTokenizer(f"google/{name}", **tokenizer_kwargs) + return model, tokenizer + else: + return model + + +def umt5_xxl(**kwargs): + cfg = dict( + vocab_size=256384, + dim=4096, + dim_attn=4096, + dim_ffn=10240, + num_heads=64, + encoder_layers=24, + decoder_layers=24, + num_buckets=32, + shared_pos=False, + dropout=0.1, + ) + cfg.update(**kwargs) + return _t5("umt5-xxl", **cfg) + + +class T5EncoderModel: + def __init__( + self, + text_len, + dtype=torch.bfloat16, + device=torch.accelerator.current_device_index(), + checkpoint_path=None, + tokenizer_path=None, + shard_fn=None, + ): + self.text_len = text_len + self.dtype = dtype + self.device = device + self.checkpoint_path = checkpoint_path + self.tokenizer_path = tokenizer_path + + # init model + model = ( + umt5_xxl(encoder_only=True, return_tokenizer=False, dtype=dtype, device=device).eval().requires_grad_(False) + ) + logging.info(f"loading {checkpoint_path}") + model.load_state_dict(torch.load(checkpoint_path, map_location="cpu")) + self.model = model + if shard_fn is not None: + self.model = shard_fn(self.model, sync_module_states=False) + else: + self.model.to(self.device) + # init tokenizer + self.tokenizer = HuggingfaceTokenizer(name=tokenizer_path, seq_len=text_len, clean="whitespace") + + def __call__(self, texts, device): + ids, mask = self.tokenizer(texts, return_mask=True, add_special_tokens=True) + ids = ids.to(device) + mask = mask.to(device) + seq_lens = mask.gt(0).sum(dim=1).long() + context = self.model(ids, mask) + return [u[:v] for u, v in zip(context, seq_lens)] diff --git a/vllm_omni/diffusion/models/lingbot_world_fast/tokenizers.py b/vllm_omni/diffusion/models/lingbot_world_fast/tokenizers.py new file mode 100644 index 00000000000..e939b191c12 --- /dev/null +++ b/vllm_omni/diffusion/models/lingbot_world_fast/tokenizers.py @@ -0,0 +1,78 @@ +# Adapted from Lingbot-World/wan/modules/tokenizers.py +import html +import string + +import ftfy +import regex as re +from transformers import AutoTokenizer + +__all__ = ["HuggingfaceTokenizer"] + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r"\s+", " ", text) + text = text.strip() + return text + + +def canonicalize(text, keep_punctuation_exact_string=None): + text = text.replace("_", " ") + if keep_punctuation_exact_string: + text = keep_punctuation_exact_string.join( + part.translate(str.maketrans("", "", string.punctuation)) + for part in text.split(keep_punctuation_exact_string) + ) + else: + text = text.translate(str.maketrans("", "", string.punctuation)) + text = text.lower() + text = re.sub(r"\s+", " ", text) + return text.strip() + + +class HuggingfaceTokenizer: + def __init__(self, name, seq_len=None, clean=None, **kwargs): + assert clean in (None, "whitespace", "lower", "canonicalize") + self.name = name + self.seq_len = seq_len + self.clean = clean + + # init tokenizer + self.tokenizer = AutoTokenizer.from_pretrained(name, **kwargs) + self.vocab_size = self.tokenizer.vocab_size + + def __call__(self, sequence, **kwargs): + return_mask = kwargs.pop("return_mask", False) + + # arguments + _kwargs = {"return_tensors": "pt"} + if self.seq_len is not None: + _kwargs.update({"padding": "max_length", "truncation": True, "max_length": self.seq_len}) + _kwargs.update(**kwargs) + + # tokenization + if isinstance(sequence, str): + sequence = [sequence] + if self.clean: + sequence = [self._clean(u) for u in sequence] + ids = self.tokenizer(sequence, **_kwargs) + + # output + if return_mask: + return ids.input_ids, ids.attention_mask + else: + return ids.input_ids + + def _clean(self, text): + if self.clean == "whitespace": + text = whitespace_clean(basic_clean(text)) + elif self.clean == "lower": + text = whitespace_clean(basic_clean(text)).lower() + elif self.clean == "canonicalize": + text = canonicalize(basic_clean(text)) + return text diff --git a/vllm_omni/diffusion/models/lingbot_world_fast/vae2_1.py b/vllm_omni/diffusion/models/lingbot_world_fast/vae2_1.py new file mode 100644 index 00000000000..6017abe3476 --- /dev/null +++ b/vllm_omni/diffusion/models/lingbot_world_fast/vae2_1.py @@ -0,0 +1,610 @@ +# Adapted from Lingbot-World/wan/modules/vae2_1.py +import logging + +import torch +import torch.cuda.amp as amp +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + +__all__ = ["Wan2_1_VAE"] + +CACHE_T = 2 + + +class CausalConv3d(nn.Conv3d): + """ + Causal 3d convolution. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._padding = ( + self.padding[2], + self.padding[2], + self.padding[1], + self.padding[1], + 2 * self.padding[0], + 0, + ) + self.padding = (0, 0, 0) + + def forward(self, x, cache_x=None): + padding = list(self._padding) + if cache_x is not None and self._padding[4] > 0: + cache_x = cache_x.to(x.device) + x = torch.cat([cache_x, x], dim=2) + padding[4] -= cache_x.shape[2] + x = F.pad(x, padding) + + return super().forward(x) + + +class RMS_norm(nn.Module): + def __init__(self, dim, channel_first=True, images=True, bias=False): + super().__init__() + broadcastable_dims = (1, 1, 1) if not images else (1, 1) + shape = (dim, *broadcastable_dims) if channel_first else (dim,) + + self.channel_first = channel_first + self.scale = dim**0.5 + self.gamma = nn.Parameter(torch.ones(shape)) + self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0 + + def forward(self, x): + return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias + + +class Upsample(nn.Upsample): + def forward(self, x): + """ + Fix bfloat16 support for nearest neighbor interpolation. + """ + return super().forward(x.float()).type_as(x) + + +class Resample(nn.Module): + def __init__(self, dim, mode): + assert mode in ("none", "upsample2d", "upsample3d", "downsample2d", "downsample3d") + super().__init__() + self.dim = dim + self.mode = mode + + # layers + if mode == "upsample2d": + self.resample = nn.Sequential( + Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), + nn.Conv2d(dim, dim // 2, 3, padding=1), + ) + elif mode == "upsample3d": + self.resample = nn.Sequential( + Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), + nn.Conv2d(dim, dim // 2, 3, padding=1), + ) + self.time_conv = CausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0)) + + elif mode == "downsample2d": + self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))) + elif mode == "downsample3d": + self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))) + self.time_conv = CausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)) + + else: + self.resample = nn.Identity() + + def forward(self, x, feat_cache=None, feat_idx=[0]): + b, c, t, h, w = x.size() + if self.mode == "upsample3d": + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = "Rep" + feat_idx[0] += 1 + else: + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep": + # cache last frame of last two chunk + cache_x = torch.cat( + [feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2 + ) + if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] == "Rep": + cache_x = torch.cat([torch.zeros_like(cache_x).to(cache_x.device), cache_x], dim=2) + if feat_cache[idx] == "Rep": + x = self.time_conv(x) + else: + x = self.time_conv(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + + x = x.reshape(b, 2, c, t, h, w) + x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3) + x = x.reshape(b, c, t * 2, h, w) + t = x.shape[2] + x = rearrange(x, "b c t h w -> (b t) c h w") + x = self.resample(x) + x = rearrange(x, "(b t) c h w -> b c t h w", t=t) + + if self.mode == "downsample3d": + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = x.clone() + feat_idx[0] += 1 + else: + cache_x = x[:, :, -1:, :, :].clone() + x = self.time_conv(torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2)) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + return x + + def init_weight(self, conv): + conv_weight = conv.weight + nn.init.zeros_(conv_weight) + c1, c2, t, h, w = conv_weight.size() + one_matrix = torch.eye(c1, c2) + init_matrix = one_matrix + nn.init.zeros_(conv_weight) + conv_weight.data[:, :, 1, 0, 0] = init_matrix + conv.weight.data.copy_(conv_weight) + nn.init.zeros_(conv.bias.data) + + def init_weight2(self, conv): + conv_weight = conv.weight.data + nn.init.zeros_(conv_weight) + c1, c2, t, h, w = conv_weight.size() + init_matrix = torch.eye(c1 // 2, c2) + conv_weight[: c1 // 2, :, -1, 0, 0] = init_matrix + conv_weight[c1 // 2 :, :, -1, 0, 0] = init_matrix + conv.weight.data.copy_(conv_weight) + nn.init.zeros_(conv.bias.data) + + +class ResidualBlock(nn.Module): + def __init__(self, in_dim, out_dim, dropout=0.0): + super().__init__() + self.in_dim = in_dim + self.out_dim = out_dim + + # layers + self.residual = nn.Sequential( + RMS_norm(in_dim, images=False), + nn.SiLU(), + CausalConv3d(in_dim, out_dim, 3, padding=1), + RMS_norm(out_dim, images=False), + nn.SiLU(), + nn.Dropout(dropout), + CausalConv3d(out_dim, out_dim, 3, padding=1), + ) + self.shortcut = CausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity() + + def forward(self, x, feat_cache=None, feat_idx=[0]): + h = self.shortcut(x) + for layer in self.residual: + if isinstance(layer, CausalConv3d) and feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat( + [feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2 + ) + x = layer(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = layer(x) + return x + h + + +class AttentionBlock(nn.Module): + """ + Causal self-attention with a single head. + """ + + def __init__(self, dim): + super().__init__() + self.dim = dim + + # layers + self.norm = RMS_norm(dim) + self.to_qkv = nn.Conv2d(dim, dim * 3, 1) + self.proj = nn.Conv2d(dim, dim, 1) + + # zero out the last layer params + nn.init.zeros_(self.proj.weight) + + def forward(self, x): + identity = x + b, c, t, h, w = x.size() + x = rearrange(x, "b c t h w -> (b t) c h w") + x = self.norm(x) + # compute query, key, value + q, k, v = self.to_qkv(x).reshape(b * t, 1, c * 3, -1).permute(0, 1, 3, 2).contiguous().chunk(3, dim=-1) + + # apply attention + x = F.scaled_dot_product_attention(q, k, v) + x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w) + + # output + x = self.proj(x) + x = rearrange(x, "(b t) c h w-> b c t h w", t=t) + return x + identity + + +class Encoder3d(nn.Module): + def __init__( + self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_downsample=[True, True, False], + dropout=0.0, + ): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_downsample = temperal_downsample + + # dimensions + dims = [dim * u for u in [1] + dim_mult] + scale = 1.0 + + # init block + self.conv1 = CausalConv3d(3, dims[0], 3, padding=1) + + # downsample blocks + downsamples = [] + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + # residual (+attention) blocks + for _ in range(num_res_blocks): + downsamples.append(ResidualBlock(in_dim, out_dim, dropout)) + if scale in attn_scales: + downsamples.append(AttentionBlock(out_dim)) + in_dim = out_dim + + # downsample block + if i != len(dim_mult) - 1: + mode = "downsample3d" if temperal_downsample[i] else "downsample2d" + downsamples.append(Resample(out_dim, mode=mode)) + scale /= 2.0 + self.downsamples = nn.Sequential(*downsamples) + + # middle blocks + self.middle = nn.Sequential( + ResidualBlock(out_dim, out_dim, dropout), + AttentionBlock(out_dim), + ResidualBlock(out_dim, out_dim, dropout), + ) + + # output blocks + self.head = nn.Sequential( + RMS_norm(out_dim, images=False), nn.SiLU(), CausalConv3d(out_dim, z_dim, 3, padding=1) + ) + + def forward(self, x, feat_cache=None, feat_idx=[0]): + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + x = self.conv1(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv1(x) + + for layer in self.downsamples: + if feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + for layer in self.middle: + if isinstance(layer, ResidualBlock) and feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + for layer in self.head: + if isinstance(layer, CausalConv3d) and feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat( + [feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2 + ) + x = layer(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = layer(x) + return x + + +class Decoder3d(nn.Module): + def __init__( + self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temporal_upsample=[False, True, True], + dropout=0.0, + ): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temporal_upsample = temporal_upsample + + # dimensions + dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] + scale = 1.0 / 2 ** (len(dim_mult) - 2) + + # init block + self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1) + + # middle blocks + self.middle = nn.Sequential( + ResidualBlock(dims[0], dims[0], dropout), + AttentionBlock(dims[0]), + ResidualBlock(dims[0], dims[0], dropout), + ) + + # upsample blocks + upsamples = [] + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + # residual (+attention) blocks + if i == 1 or i == 2 or i == 3: + in_dim = in_dim // 2 + for _ in range(num_res_blocks + 1): + upsamples.append(ResidualBlock(in_dim, out_dim, dropout)) + if scale in attn_scales: + upsamples.append(AttentionBlock(out_dim)) + in_dim = out_dim + + # upsample block + if i != len(dim_mult) - 1: + mode = "upsample3d" if temporal_upsample[i] else "upsample2d" + upsamples.append(Resample(out_dim, mode=mode)) + scale *= 2.0 + self.upsamples = nn.Sequential(*upsamples) + + # output blocks + self.head = nn.Sequential(RMS_norm(out_dim, images=False), nn.SiLU(), CausalConv3d(out_dim, 3, 3, padding=1)) + + def forward(self, x, feat_cache=None, feat_idx=[0]): + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + x = self.conv1(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv1(x) + + for layer in self.middle: + if isinstance(layer, ResidualBlock) and feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + for layer in self.upsamples: + if feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + for layer in self.head: + if isinstance(layer, CausalConv3d) and feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat( + [feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2 + ) + x = layer(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = layer(x) + return x + + +def count_conv3d(model): + count = 0 + for m in model.modules(): + if isinstance(m, CausalConv3d): + count += 1 + return count + + +class WanVAE_(nn.Module): + def __init__( + self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_downsample=[True, True, False], + dropout=0.0, + ): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_downsample = temperal_downsample + self.temporal_upsample = temperal_downsample[::-1] + + # modules + self.encoder = Encoder3d( + dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout + ) + self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1) + self.conv2 = CausalConv3d(z_dim, z_dim, 1) + self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temporal_upsample, dropout) + + def forward(self, x): + mu, log_var = self.encode(x) + z = self.reparameterize(mu, log_var) + x_recon = self.decode(z) + return x_recon, mu, log_var + + def encode(self, x, scale): + self.clear_cache() + t = x.shape[2] + iter_ = 1 + (t - 1) // 4 + for i in range(iter_): + self._enc_conv_idx = [0] + if i == 0: + out = self.encoder(x[:, :, :1, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx) + else: + out_ = self.encoder( + x[:, :, 1 + 4 * (i - 1) : 1 + 4 * i, :, :], + feat_cache=self._enc_feat_map, + feat_idx=self._enc_conv_idx, + ) + out = torch.cat([out, out_], 2) + mu, log_var = self.conv1(out).chunk(2, dim=1) + if isinstance(scale[0], torch.Tensor): + mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(1, self.z_dim, 1, 1, 1) + else: + mu = (mu - scale[0]) * scale[1] + self.clear_cache() + return mu + + def decode(self, z, scale): + self.clear_cache() + if isinstance(scale[0], torch.Tensor): + z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(1, self.z_dim, 1, 1, 1) + else: + z = z / scale[1] + scale[0] + iter_ = z.shape[2] + x = self.conv2(z) + for i in range(iter_): + self._conv_idx = [0] + if i == 0: + out = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx) + else: + out_ = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx) + out = torch.cat([out, out_], 2) + self.clear_cache() + return out + + def reparameterize(self, mu, log_var): + std = torch.exp(0.5 * log_var) + eps = torch.randn_like(std) + return eps * std + mu + + def sample(self, imgs, deterministic=False): + mu, log_var = self.encode(imgs) + if deterministic: + return mu + std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0)) + return mu + std * torch.randn_like(std) + + def clear_cache(self): + self._conv_num = count_conv3d(self.decoder) + self._conv_idx = [0] + self._feat_map = [None] * self._conv_num + # cache encode + self._enc_conv_num = count_conv3d(self.encoder) + self._enc_conv_idx = [0] + self._enc_feat_map = [None] * self._enc_conv_num + + +def _video_vae(pretrained_path=None, z_dim=None, device="cpu", **kwargs): + """ + Autoencoder3d adapted from Stable Diffusion 1.x, 2.x and XL. + """ + cfg = dict( + dim=96, + z_dim=z_dim, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_downsample=[False, True, True], + dropout=0.0, + ) + cfg.update(**kwargs) + + # init model + with torch.device("meta"): + model = WanVAE_(**cfg) + + logging.info(f"loading {pretrained_path}") + model.load_state_dict(torch.load(pretrained_path, map_location=device), assign=True) + + return model + + +class Wan2_1_VAE: + def __init__(self, z_dim=16, vae_pth="cache/vae_step_411000.pth", dtype=torch.float, device="cuda"): + self.dtype = dtype + self.device = device + + mean = [ + -0.7571, + -0.7089, + -0.9113, + 0.1075, + -0.1745, + 0.9653, + -0.1517, + 1.5508, + 0.4134, + -0.0715, + 0.5517, + -0.3632, + -0.1922, + -0.9497, + 0.2503, + -0.2921, + ] + std = [ + 2.8184, + 1.4541, + 2.3275, + 2.6558, + 1.2196, + 1.7708, + 2.6052, + 2.0743, + 3.2687, + 2.1526, + 2.8652, + 1.5579, + 1.6382, + 1.1253, + 2.8251, + 1.9160, + ] + self.mean = torch.tensor(mean, dtype=dtype, device=device) + self.std = torch.tensor(std, dtype=dtype, device=device) + self.scale = [self.mean, 1.0 / self.std] + + # init model + self.model = _video_vae(pretrained_path=vae_pth, z_dim=z_dim).eval().requires_grad_(False).to(device) + + def encode(self, videos): + """ + videos: A list of videos each with shape [C, T, H, W]. + """ + with amp.autocast(dtype=self.dtype): + return [self.model.encode(u.unsqueeze(0), self.scale).float().squeeze(0) for u in videos] + + def decode(self, zs): + with amp.autocast(dtype=self.dtype): + return [self.model.decode(u.unsqueeze(0), self.scale).float().clamp_(-1, 1).squeeze(0) for u in zs] diff --git a/vllm_omni/diffusion/models/lingbot_world_fast/wan_fast.py b/vllm_omni/diffusion/models/lingbot_world_fast/wan_fast.py new file mode 100644 index 00000000000..92b9e1193f5 --- /dev/null +++ b/vllm_omni/diffusion/models/lingbot_world_fast/wan_fast.py @@ -0,0 +1,652 @@ +"""Some of the functions are borrowed from SelfForcing (https://github.com/guandeh17/Self-Forcing).""" + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as torch_F +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.modeling_utils import ModelMixin +from einops import rearrange + +from vllm_omni.diffusion.attention.layer import Attention + +from .state_lingbot_world_fast import CacheIndex +from .wan_model import WanLayerNorm, WanRMSNorm, WanSelfAttention, rope_params, sinusoidal_embedding_1d + + +def causal_rope_apply(x, grid_sizes, freqs, start_frame=0): + n, c = x.size(2), x.size(3) // 2 + + # split freqs + freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1) + + # loop over samples + output = [] + + for i, (f, h, w) in enumerate(grid_sizes.tolist()): + seq_len = f * h * w + + # precompute multipliers + x_i = torch.view_as_complex(x[i, :seq_len].to(torch.float64).reshape(seq_len, n, -1, 2)) + freqs_i = torch.cat( + [ + freqs[0][start_frame : start_frame + f].view(f, 1, 1, -1).expand(f, h, w, -1), + freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), + freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1), + ], + dim=-1, + ).reshape(seq_len, 1, -1) + + # apply rotary embedding + x_i = torch.view_as_real(x_i * freqs_i).flatten(2) + x_i = torch.cat([x_i, x[i, seq_len:]]) + + # append to collection + output.append(x_i) + return torch.stack(output).type_as(x) + + +class CausalWanSelfAttention(nn.Module): + def __init__(self, dim, num_heads, local_attn_size=-1, sink_size=0, qk_norm=True, eps=1e-6): + assert dim % num_heads == 0 + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.local_attn_size = local_attn_size + self.sink_size = sink_size + self.qk_norm = qk_norm + self.eps = eps + + # layers + self.q = nn.Linear(dim, dim) + self.k = nn.Linear(dim, dim) + self.v = nn.Linear(dim, dim) + self.o = nn.Linear(dim, dim) + self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity() + self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity() + + self.attn = Attention( + num_heads=self.num_heads, + head_size=self.head_dim, + num_kv_heads=self.num_heads, + softmax_scale=1.0 / (self.head_dim**0.5), + causal=False, + ) + + def forward( + self, + x, + seq_lens, + grid_sizes, + freqs, + kv_cache=None, + local_end_index=None, + global_end_index=None, + current_start=0, + max_attention_size=1_000_000, + ): + r""" + Args: + x(Tensor): Shape [B, L, num_heads, C / num_heads] + grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W) + freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2] + block_mask (BlockMask) + """ + b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim + + # query, key, value function + def qkv_fn(x): + q = self.norm_q(self.q(x)).view(b, s, n, d) + k = self.norm_k(self.k(x)).view(b, s, n, d) + v = self.v(x).view(b, s, n, d) + return q, k, v + + q, k, v = qkv_fn(x) + + frame_seqlen = math.prod(grid_sizes[0][1:]).item() + current_start_frame = current_start // frame_seqlen + roped_query = causal_rope_apply(q, grid_sizes, freqs, start_frame=current_start_frame).type_as(v) + roped_key = causal_rope_apply(k, grid_sizes, freqs, start_frame=current_start_frame).type_as(v) + current_end = current_start + roped_query.shape[1] + sink_tokens = self.sink_size * frame_seqlen + # If we are using local attention and the current KV cache size is larger than the local attention size, + # then we need to truncate the KV cache + kv_cache_size = kv_cache[CacheIndex.K].shape[1] + num_new_tokens = roped_query.shape[1] + if ( + self.local_attn_size != -1 + and (current_end > global_end_index.item()) + and (num_new_tokens + local_end_index.item() > kv_cache_size) + ): + # Calculate the number of new tokens added in this step + # Shift existing cache content left to discard oldest tokens + # Clone the source slice to avoid overlapping memory error + num_evicted_tokens = num_new_tokens + local_end_index.item() - kv_cache_size + num_rolled_tokens = local_end_index.item() - num_evicted_tokens - sink_tokens + kv_cache[CacheIndex.K][:, sink_tokens : sink_tokens + num_rolled_tokens] = kv_cache[CacheIndex.K][ + :, sink_tokens + num_evicted_tokens : sink_tokens + num_evicted_tokens + num_rolled_tokens + ].clone() + kv_cache[CacheIndex.V][:, sink_tokens : sink_tokens + num_rolled_tokens] = kv_cache[CacheIndex.V][ + :, sink_tokens + num_evicted_tokens : sink_tokens + num_evicted_tokens + num_rolled_tokens + ].clone() + # Insert the new keys/values at the end + new_local_end_index = local_end_index.item() + current_end - global_end_index.item() - num_evicted_tokens + local_start_index = new_local_end_index - num_new_tokens + kv_cache[CacheIndex.K][:, local_start_index:new_local_end_index] = roped_key + kv_cache[CacheIndex.V][:, local_start_index:new_local_end_index] = v + else: + # Assign new keys/values directly up to current_end + new_local_end_index = local_end_index.item() + current_end - global_end_index.item() + local_start_index = new_local_end_index - num_new_tokens + kv_cache[CacheIndex.K][:, local_start_index:new_local_end_index] = roped_key + kv_cache[CacheIndex.V][:, local_start_index:new_local_end_index] = v + + k_cache = kv_cache[CacheIndex.K][:, max(0, new_local_end_index - max_attention_size) : new_local_end_index] + v_cache = kv_cache[CacheIndex.V][:, max(0, new_local_end_index - max_attention_size) : new_local_end_index] + x = self.attn(roped_query, k_cache, v_cache) + + global_end_index.fill_(current_end) + local_end_index.fill_(new_local_end_index) + + # output + x = x.flatten(2) + x = self.o(x) + return x + + +class WanCrossAttention(WanSelfAttention): + def __init__(self, dim, num_heads, window_size=(-1, -1), qk_norm=True, eps=1e-6): + super().__init__(dim, num_heads, window_size, qk_norm, eps) + + self.attn = Attention( + num_heads=self.num_heads, + head_size=self.head_dim, + num_kv_heads=self.num_heads, + softmax_scale=1.0 / (self.head_dim**0.5), + causal=False, + ) + + def forward(self, x, context, context_lens, crossattn_cache=None): + r""" + Args: + x(Tensor): Shape [B, L1, C] + context(Tensor): Shape [B, L2, C] + context_lens(Tensor): Shape [B] + """ + b, n, d = x.size(0), self.num_heads, self.head_dim + + # compute query, key, value + q = self.norm_q(self.q(x)).view(b, -1, n, d) + + if crossattn_cache is not None: + if not crossattn_cache.get("is_init", False): + crossattn_cache["is_init"] = True + k = self.norm_k(self.k(context)).view(b, -1, n, d) + v = self.v(context).view(b, -1, n, d) + crossattn_cache[CacheIndex.K] = k + crossattn_cache[CacheIndex.V] = v + else: + k = crossattn_cache[CacheIndex.K] + v = crossattn_cache[CacheIndex.V] + else: + k = self.norm_k(self.k(context)).view(b, -1, n, d) + v = self.v(context).view(b, -1, n, d) + + # compute attention + x = self.attn(q, k, v) + + # output + x = x.flatten(2) + x = self.o(x) + return x + + +class CausalWanAttentionBlock(nn.Module): + def __init__( + self, dim, ffn_dim, num_heads, local_attn_size=-1, sink_size=0, qk_norm=True, cross_attn_norm=False, eps=1e-6 + ): + super().__init__() + self.dim = dim + self.ffn_dim = ffn_dim + self.num_heads = num_heads + self.local_attn_size = local_attn_size + self.qk_norm = qk_norm + self.cross_attn_norm = cross_attn_norm + self.eps = eps + + # layers + self.norm1 = WanLayerNorm(dim, eps) + self.self_attn = CausalWanSelfAttention( + dim=dim, num_heads=num_heads, local_attn_size=local_attn_size, sink_size=sink_size, qk_norm=qk_norm, eps=eps + ) + self.norm3 = WanLayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity() + self.cross_attn = WanCrossAttention(dim, num_heads, (-1, -1), qk_norm, eps) + self.norm2 = WanLayerNorm(dim, eps) + self.ffn = nn.Sequential(nn.Linear(dim, ffn_dim), nn.GELU(approximate="tanh"), nn.Linear(ffn_dim, dim)) + + # modulation + self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5) + + self.cam_injector_layer1 = nn.Linear(dim, dim) + self.cam_injector_layer2 = nn.Linear(dim, dim) + self.cam_scale_layer = nn.Linear(dim, dim) + self.cam_shift_layer = nn.Linear(dim, dim) + + def forward( + self, + x, + e, + seq_lens, + grid_sizes, + freqs, + context, + context_lens, + dit_cond_dict=None, + kv_cache=None, + local_end_index=None, + global_end_index=None, + crossattn_cache=None, + current_start=0, + max_attention_size=1_000_000, + ): + r""" + Args: + x(Tensor): Shape [B, L, C] + e(Tensor): Shape [B, F, 6, C] + grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W) + freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2] + """ + assert e.dtype == torch.float32 + with torch.amp.autocast("cuda", dtype=torch.float32): + e = (self.modulation.unsqueeze(0) + e).chunk(6, dim=2) + assert e[0].dtype == torch.float32 + # self-attention + y = self.self_attn( + self.norm1(x).float() * (1 + e[1].squeeze(2)) + e[0].squeeze(2), + seq_lens, + grid_sizes, + freqs, + kv_cache, + local_end_index, + global_end_index, + current_start, + max_attention_size, + ) + with torch.amp.autocast("cuda", dtype=torch.float32): + x = x + y * e[2].squeeze(2) + + # cam injection (only if dit_cond_dict is provided and contains c2ws_plucker_emb) + if dit_cond_dict is not None and "c2ws_plucker_emb" in dit_cond_dict: + c2ws_plucker_emb = dit_cond_dict["c2ws_plucker_emb"] + c2ws_hidden_states = self.cam_injector_layer2(torch_F.silu(self.cam_injector_layer1(c2ws_plucker_emb))) + c2ws_hidden_states = c2ws_hidden_states + c2ws_plucker_emb + cam_scale = self.cam_scale_layer(c2ws_hidden_states) + cam_shift = self.cam_shift_layer(c2ws_hidden_states) + x = (1.0 + cam_scale) * x + cam_shift + + # cross-attention & ffn function + def cross_attn_ffn(x, context, context_lens, e, crossattn_cache=None): + x = x + self.cross_attn(self.norm3(x), context, context_lens, crossattn_cache=crossattn_cache) + y = self.ffn(self.norm2(x).float() * (1 + e[4].squeeze(2)) + e[3].squeeze(2)) + with torch.amp.autocast("cuda", dtype=torch.float32): + x = x + y * e[5].squeeze(2) + return x + + x = cross_attn_ffn(x, context, context_lens, e, crossattn_cache) + return x + + +class CausalHead(nn.Module): + def __init__(self, dim, out_dim, patch_size, eps=1e-6): + super().__init__() + self.dim = dim + self.out_dim = out_dim + self.patch_size = patch_size + self.eps = eps + + # layers + out_dim = math.prod(patch_size) * out_dim + self.norm = WanLayerNorm(dim, eps) + self.head = nn.Linear(dim, out_dim) + + # modulation + self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5) + + def forward(self, x, e): + r""" + Args: + x(Tensor): Shape [B, L1, C] + e(Tensor): Shape [B, L1, C] + """ + assert e.dtype == torch.float32 + with torch.amp.autocast("cuda", dtype=torch.float32): + e = (self.modulation.unsqueeze(0) + e.unsqueeze(2)).chunk(2, dim=2) + x = self.head(self.norm(x) * (1 + e[1].squeeze(2)) + e[0].squeeze(2)) + return x + + +class WanModelFast(ModelMixin, ConfigMixin): + r""" + Wan diffusion backbone supporting both text-to-video and image-to-video. + """ + + ignore_for_config = ["patch_size", "cross_attn_norm", "qk_norm", "text_dim"] + _no_split_modules = ["WanAttentionBlock"] + + @register_to_config + def __init__( + self, + model_type="", + control_type="cam", + patch_size=(1, 2, 2), + text_len=512, + in_dim=16, + dim=2048, + ffn_dim=8192, + freq_dim=256, + text_dim=4096, + out_dim=16, + num_heads=16, + num_layers=32, + local_attn_size=-1, + sink_size=0, + qk_norm=True, + cross_attn_norm=True, + eps=1e-6, + ): + r""" + Initialize the diffusion model backbone. + + Args: + model_type (`str`, *optional*, defaults to ''): + control_type (`str`, *optional*, defaults to 'cam'): + Type of conditioning control signal - 'cam' (6-dim camera Plucker + embeddings) or 'act' (7-dim action embeddings including WASD movement) + patch_size (`tuple`, *optional*, defaults to (1, 2, 2)): + 3D patch dimensions for video embedding (t_patch, h_patch, w_patch) + text_len (`int`, *optional*, defaults to 512): + Fixed length for text embeddings + in_dim (`int`, *optional*, defaults to 16): + Input video channels (C_in) + dim (`int`, *optional*, defaults to 2048): + Hidden dimension of the transformer + ffn_dim (`int`, *optional*, defaults to 8192): + Intermediate dimension in feed-forward network + freq_dim (`int`, *optional*, defaults to 256): + Dimension for sinusoidal time embeddings + text_dim (`int`, *optional*, defaults to 4096): + Input dimension for text embeddings + out_dim (`int`, *optional*, defaults to 16): + Output video channels (C_out) + num_heads (`int`, *optional*, defaults to 16): + Number of attention heads + num_layers (`int`, *optional*, defaults to 32): + Number of transformer blocks + local_attn_size (`int`, *optional*, defaults to -1): + Window size for temporal local attention (-1 indicates global attention) + sink_size (`int`, *optional*, defaults to 0): + Size of the attention sink, we keep the first `sink_size` frames unchanged when rolling the KV cache + qk_norm (`bool`, *optional*, defaults to True): + Enable query/key normalization + cross_attn_norm (`bool`, *optional*, defaults to False): + Enable cross-attention normalization + eps (`float`, *optional*, defaults to 1e-6): + Epsilon value for normalization layers + """ + + super().__init__() + + self.model_type = model_type + + self.task_type = "i2v" + self.patch_size = patch_size + self.text_len = text_len + self.in_dim = in_dim + self.dim = dim + self.ffn_dim = ffn_dim + self.freq_dim = freq_dim + self.text_dim = text_dim + self.out_dim = out_dim + self.num_heads = num_heads + self.num_layers = num_layers + self.local_attn_size = local_attn_size + self.qk_norm = qk_norm + self.cross_attn_norm = cross_attn_norm + self.eps = eps + + if control_type == "cam": + control_dim = 6 + elif control_type == "act": + control_dim = 7 + + # embeddings + self.patch_embedding = nn.Conv3d(in_dim, dim, kernel_size=patch_size, stride=patch_size) + + self.patch_embedding_wancamctrl = nn.Linear( + control_dim * 64 * patch_size[0] * patch_size[1] * patch_size[2], dim + ) + self.c2ws_hidden_states_layer1 = nn.Linear(dim, dim) + self.c2ws_hidden_states_layer2 = nn.Linear(dim, dim) + + self.text_embedding = nn.Sequential(nn.Linear(text_dim, dim), nn.GELU(approximate="tanh"), nn.Linear(dim, dim)) + + self.time_embedding = nn.Sequential(nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim)) + self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6)) + + # blocks + self.blocks = nn.ModuleList( + [ + CausalWanAttentionBlock( + dim, ffn_dim, num_heads, local_attn_size, sink_size, qk_norm, cross_attn_norm, eps + ) + for _ in range(num_layers) + ] + ) + + # head + self.head = CausalHead(dim, out_dim, patch_size, eps) + + # buffers (don't use register_buffer otherwise dtype will be changed in to()) + assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0 + d = dim // num_heads + self.freqs = torch.cat( + [rope_params(1024, d - 4 * (d // 6)), rope_params(1024, 2 * (d // 6)), rope_params(1024, 2 * (d // 6))], + dim=1, + ) + + # initialize weights + self.init_weights() + + def forward( + self, + x, + t, + context, + seq_len, + y=None, + dit_cond_dict=None, + kv_cache=None, + local_end_index=None, + global_end_index=None, + crossattn_cache=None, + current_start=0, + max_attention_size=1_000_000, + ): + r""" + Run the diffusion model with kv caching. + See Algorithm 2 of CausVid paper https://arxiv.org/abs/2412.07772 for details. + This function will be run for num_frame times. + Process the latent frames one by one (1560 tokens each) + + Args: + x (List[Tensor]): + List of input video tensors, each with shape [C_in, F, H, W] + t (Tensor): + Diffusion timesteps tensor of shape [B] + context (List[Tensor]): + List of text embeddings each with shape [L, C] + seq_len (`int`): + Maximum sequence length for positional encoding + y (List[Tensor], *optional*): + Conditional video inputs for image-to-video mode, same shape as x + dit_cond_dict (`dict`, *optional*, defaults to None): + Dictionary of conditioning signals. May contain key ``c2ws_plucker_emb`` + with camera Plucker embeddings of shape [B, C, F, H, W] for camera control. + kv_cache (`list[dict]`, *optional*, defaults to None): + Per-layer self-attention KV cache. Each dict contains keys ``k``, ``v`` + (Tensor of shape [B, kv_size, num_heads, head_dim]), ``global_end_index``, + and ``local_end_index`` (scalar Tensors tracking cache position). + crossattn_cache (`list[dict]`, *optional*, defaults to None): + Per-layer cross-attention KV cache. Each dict contains keys ``k``, ``v`` + (Tensor of shape [B, text_len, num_heads, head_dim]) and ``is_init`` (bool). + current_start (`int`, *optional*, defaults to 0): + Token offset of the current chunk in the full sequence. Used to index + into the KV cache and compute positional embeddings correctly. + max_attention_size (`int`, *optional*, defaults to 1_000_000): + Maximum number of KV tokens each query can attend to. Limits the + effective context window of self-attention to control memory usage. + + Returns: + List[Tensor]: + List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8] + """ + + if self.task_type == "i2v": + assert y is not None + + # params + device = self.patch_embedding.weight.device + if self.freqs.device != device: + self.freqs = self.freqs.to(device) + + if y is not None: + x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)] + + # embeddings + x = [self.patch_embedding(u.unsqueeze(0)) for u in x] + grid_sizes = torch.stack([torch.tensor(u.shape[2:], dtype=torch.long) for u in x]) + x = [u.flatten(2).transpose(1, 2) for u in x] + seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long) + assert seq_lens.max() <= seq_len + x = torch.cat(x) + + # time embeddings + if t.dim() == 1: + t = t.expand(t.size(0), seq_lens) + with torch.amp.autocast("cuda", dtype=torch.float32): + bt = t.size(0) + t = t.flatten() + e = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, t).unflatten(0, (bt, seq_lens)).float()) + e0 = self.time_projection(e).unflatten(2, (6, self.dim)) + assert e.dtype == torch.float32 and e0.dtype == torch.float32 + + # context + context_lens = None + context = self.text_embedding( + torch.stack([torch.cat([u, u.new_zeros(self.text_len - u.size(0), u.size(1))]) for u in context]) + ) + + # cam + if dit_cond_dict is not None and "c2ws_plucker_emb" in dit_cond_dict: + c2ws_plucker_emb = dit_cond_dict["c2ws_plucker_emb"] + c2ws_plucker_emb = [ + rearrange( + i, + "1 c (f c1) (h c2) (w c3) -> 1 (f h w) (c c1 c2 c3)", + c1=self.patch_size[0], + c2=self.patch_size[1], + c3=self.patch_size[2], + ) + for i in c2ws_plucker_emb + ] + c2ws_plucker_emb = torch.cat(c2ws_plucker_emb, dim=1) # [1, (L1+...+Ln), C] + + c2ws_plucker_emb = self.patch_embedding_wancamctrl(c2ws_plucker_emb) + c2ws_hidden_states = self.c2ws_hidden_states_layer2( + torch_F.silu(self.c2ws_hidden_states_layer1(c2ws_plucker_emb)) + ) + dit_cond_dict = dict(dit_cond_dict) + dit_cond_dict["c2ws_plucker_emb"] = c2ws_plucker_emb + c2ws_hidden_states + + # arguments + kwargs = dict( + e=e0, + seq_lens=seq_lens, + grid_sizes=grid_sizes, + freqs=self.freqs, + context=context, + context_lens=context_lens, + dit_cond_dict=dit_cond_dict, + max_attention_size=max_attention_size, + ) + + for block_index, block in enumerate(self.blocks): + kwargs.update( + { + "kv_cache": kv_cache[block_index], + "crossattn_cache": crossattn_cache[block_index], + "local_end_index": local_end_index[block_index], + "global_end_index": global_end_index[block_index], + "current_start": current_start, + } + ) + x = block(x, **kwargs) + + # head + x = self.head(x, e) + + # unpatchify + x = self.unpatchify(x, grid_sizes) + + return [u.float() for u in x] + + def unpatchify(self, x, grid_sizes): + r""" + Reconstruct video tensors from patch embeddings. + + Args: + x (List[Tensor]): + List of patchified features, each with shape [L, C_out * prod(patch_size)] + grid_sizes (Tensor): + Original spatial-temporal grid dimensions before patching, + shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches) + + Returns: + List[Tensor]: + Reconstructed video tensors with shape [C_out, F, H / 8, W / 8] + """ + + c = self.out_dim + out = [] + for u, v in zip(x, grid_sizes.tolist()): + u = u[: math.prod(v)].view(*v, *self.patch_size, c) + u = torch.einsum("fhwpqrc->cfphqwr", u) + u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)]) + out.append(u) + return out + + def init_weights(self): + r""" + Initialize model parameters using Xavier initialization. + """ + + # basic init + for m in self.modules(): + if isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.zeros_(m.bias) + + # init embeddings + nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1)) + for m in self.text_embedding.modules(): + if isinstance(m, nn.Linear): + nn.init.normal_(m.weight, std=0.02) + for m in self.time_embedding.modules(): + if isinstance(m, nn.Linear): + nn.init.normal_(m.weight, std=0.02) + + # init output layer + nn.init.zeros_(self.head.head.weight) diff --git a/vllm_omni/diffusion/models/lingbot_world_fast/wan_model.py b/vllm_omni/diffusion/models/lingbot_world_fast/wan_model.py new file mode 100644 index 00000000000..4a21edd4585 --- /dev/null +++ b/vllm_omni/diffusion/models/lingbot_world_fast/wan_model.py @@ -0,0 +1,95 @@ +# Adapted from Lingbot-World/wan/modules/model.py +# +# Only the building blocks used by wan_fast.py are kept here: norm layers, +# the self-attention __init__ shape (used as a base class for the local +# WanCrossAttention that overrides forward), and the rope / time-embedding +# helpers. The original `flash_attention`-based forward paths are not used by +# the Fast pipeline, so this file does not depend on wan.modules.attention. + +import torch +import torch.nn as nn + +__all__ = [ + "WanLayerNorm", + "WanRMSNorm", + "WanSelfAttention", + "rope_params", + "sinusoidal_embedding_1d", +] + + +def sinusoidal_embedding_1d(dim, position): + assert dim % 2 == 0 + half = dim // 2 + position = position.type(torch.float64) + + sinusoid = torch.outer(position, torch.pow(10000, -torch.arange(half).to(position).div(half))) + x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1) + return x + + +@torch.amp.autocast("cuda", enabled=False) +def rope_params(max_seq_len, dim, theta=10000): + assert dim % 2 == 0 + freqs = torch.outer( + torch.arange(max_seq_len), 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float64).div(dim)) + ) + freqs = torch.polar(torch.ones_like(freqs), freqs) + return freqs + + +class WanRMSNorm(nn.Module): + def __init__(self, dim, eps=1e-5): + super().__init__() + self.dim = dim + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def forward(self, x): + r""" + Args: + x(Tensor): Shape [B, L, C] + """ + return self._norm(x.float()).type_as(x) * self.weight + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps) + + +class WanLayerNorm(nn.LayerNorm): + def __init__(self, dim, eps=1e-6, elementwise_affine=False): + super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps) + + def forward(self, x): + r""" + Args: + x(Tensor): Shape [B, L, C] + """ + return super().forward(x.float()).type_as(x) + + +class WanSelfAttention(nn.Module): + """Base class providing the Q/K/V/O linear layers and (optional) QK RMSNorm. + + `wan_fast.py` only consumes this class via inheritance and `super().__init__()` + — the attention forward path is always overridden — so this trimmed copy + intentionally omits the flash-attention-based forward. + """ + + def __init__(self, dim, num_heads, window_size=(-1, -1), qk_norm=True, eps=1e-6): + assert dim % num_heads == 0 + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.window_size = window_size + self.qk_norm = qk_norm + self.eps = eps + + # layers + self.q = nn.Linear(dim, dim) + self.k = nn.Linear(dim, dim) + self.v = nn.Linear(dim, dim) + self.o = nn.Linear(dim, dim) + self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity() + self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity() diff --git a/vllm_omni/diffusion/registry.py b/vllm_omni/diffusion/registry.py index d8302c11501..063a2105cd4 100644 --- a/vllm_omni/diffusion/registry.py +++ b/vllm_omni/diffusion/registry.py @@ -261,6 +261,11 @@ "pipeline_diffusers_adapter", "DiffusersAdapterPipeline", ), + "LingbotWorldFastPipeline": ( + "lingbot_world_fast", + "pipeline_lingbot_world_fast", + "LingbotWorldFastPipeline", + ), "HiDreamImagePipeline": ( "hidream_image", "pipeline_hidream_image", @@ -487,6 +492,7 @@ def _apply_sequence_parallel_if_enabled(model, od_config: OmniDiffusionConfig) - "OmniVoicePipeline": "get_omnivoice_post_process_func", "DreamIDOmniPipeline": "get_dreamid_omni_post_process_func", "SenseNovaU1Pipeline": "get_sensenova_u1_post_process_func", + "LingbotWorldFastPipeline": "get_lingbot_world_fast_post_process_func", "HiDreamImagePipeline": "get_hidream_image_post_process_func", } diff --git a/vllm_omni/entrypoints/cli/serve.py b/vllm_omni/entrypoints/cli/serve.py index cc6e9a4dabb..a91bf5345eb 100644 --- a/vllm_omni/entrypoints/cli/serve.py +++ b/vllm_omni/entrypoints/cli/serve.py @@ -651,6 +651,17 @@ def subparser_init(self, subparsers: argparse._SubParsersAction) -> FlexibleArgu help="Enable AR stage profiler to include AR stage timing in stage_durations.", ) + omni_config_group.add_argument( + "--ws-max-size", + type=int, + default=1_048_576, # 1MB + help="Change max size of a websocket payload that is accepted by the server", + ) + omni_config_group.add_argument( + "--ws", + default="auto", + help="Set the websocket Protocol type", + ) # Supplementary auxiliary text encoder parameters # (e.g., the meta llama/meta llama-3.1-8b-instrument used by hidream) omni_config_group.add_argument( diff --git a/vllm_omni/entrypoints/openai/api_server.py b/vllm_omni/entrypoints/openai/api_server.py index 69a6dd603ed..3520e5f301c 100644 --- a/vllm_omni/entrypoints/openai/api_server.py +++ b/vllm_omni/entrypoints/openai/api_server.py @@ -115,6 +115,7 @@ VideoListResponse, VideoResponse, ) +from vllm_omni.entrypoints.openai.realtime.world.camera_serving import ServingRealtimeWorldCamera from vllm_omni.entrypoints.openai.realtime_connection import RealtimeConnection from vllm_omni.entrypoints.openai.serving_audio_generate import OmniOpenAIServingAudioGenerate from vllm_omni.entrypoints.openai.serving_chat import OmniOpenAIServingChat @@ -377,6 +378,11 @@ async def omni_run_server_worker(listen_address, sock, args, client_config=None, if log_config is not None: uvicorn_kwargs["log_config"] = log_config + if args.ws_max_size is not None: + uvicorn_kwargs["ws_max_size"] = args.ws_max_size + if args.ws is not None: + uvicorn_kwargs["ws"] = args.ws + async with build_async_omni( args, client_config=client_config, @@ -631,6 +637,10 @@ async def omni_init_app_state( state.openai_streaming_speech = None state.openai_streaming_video = None + state.openai_serving_world_camera = ServingRealtimeWorldCamera.create_policy_server( + engine_client=engine_client, model_name=model_name + ) + state.enable_server_load_tracking = getattr(args, "enable_server_load_tracking", False) state.server_load_metrics = 0 logger.info("Pure diffusion API server initialized for model: %s", model_name) @@ -948,6 +958,10 @@ async def omni_init_app_state( stage_configs=state.stage_configs, ) + state.openai_serving_world_camera = ServingRealtimeWorldCamera.create_policy_server( + engine_client=engine_client, model_name=model_name + ) + state.enable_server_load_tracking = args.enable_server_load_tracking state.server_load_metrics = 0 @@ -1406,6 +1420,21 @@ async def realtime_websocket(websocket: WebSocket): await connection.handle_connection() +@router.websocket("/v1/realtime/world/camera") +async def realtime_world_camera_openpi(websocket: WebSocket): + from vllm_omni.entrypoints.openai.realtime.world.camera_connection import WorldCameraRealtimeConnection + + serving = getattr(websocket.app.state, "openai_serving_world_camera", None) + + if serving is None: + await websocket.accept() + await websocket.send_json({"type": "error", "error": "World Model policy not available", "code": "unsupported"}) + await websocket.close() + return + connection = WorldCameraRealtimeConnection(websocket, serving) + await connection.handle_connection() + + # Health and Model endpoints for diffusion mode diff --git a/vllm_omni/entrypoints/openai/realtime/__init__.py b/vllm_omni/entrypoints/openai/realtime/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/vllm_omni/entrypoints/openai/realtime/world/__init__.py b/vllm_omni/entrypoints/openai/realtime/world/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/vllm_omni/entrypoints/openai/realtime/world/camera_connection.py b/vllm_omni/entrypoints/openai/realtime/world/camera_connection.py new file mode 100644 index 00000000000..8ff864b3c5f --- /dev/null +++ b/vllm_omni/entrypoints/openai/realtime/world/camera_connection.py @@ -0,0 +1,168 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""WebSocket connection for robot policy inference (OpenPI protocol). + +Protocol (compatible with DreamZero test_client_AR.py): + Connect -> server sends msgpack(PolicyServerConfig fields) + Infer -> client sends msgpack(req), server sends msgpack(ndarray) + Reset -> client sends msgpack({endpoint:reset}), server sends "reset successful" +""" + +from __future__ import annotations + +import asyncio +from typing import Any + +import torch +from fastapi import WebSocket +from starlette.websockets import WebSocketDisconnect +from vllm.logger import init_logger + +from vllm_omni.entrypoints.openai.realtime.world.camera_serving import ServingRealtimeWorldCamera +from vllm_omni.entrypoints.openai.video_api_utils import _normalize_frames + +logger = init_logger(__name__) +_DEFAULT_IDLE_TIMEOUT = 300.0 +DEFAULT_FRAMES_PER_CHUNK = 4 + + +def _get_msgpack_numpy() -> Any: + try: + from openpi_client import msgpack_numpy + except ImportError as exc: + raise ImportError( + "The `/v1/realtime/world/camera` endpoint requires the optional " + "`openpi-client` dependency. Install it with `pip install openpi-client`." + ) from exc + + return msgpack_numpy + + +def _pack(obj: Any) -> bytes: + return _get_msgpack_numpy().packb(obj) + + +def _unpack(data: bytes) -> Any: + return _get_msgpack_numpy().unpackb(data) + + +class WorldCameraRealtimeConnection: + """WebSocket connection for world model inference.""" + + def __init__( + self, + websocket: WebSocket, + serving: ServingRealtimeWorldCamera, + idle_timeout: float = _DEFAULT_IDLE_TIMEOUT, + ) -> None: + self.websocket = websocket + self.serving = serving + self._idle_timeout = idle_timeout + + async def _send_error(self, message: str) -> None: + await self.websocket.send_bytes(_pack({"type": "error", "message": message})) + + def _unpack_request(self, data: bytes) -> dict[str, Any]: + req = _unpack(data) + if not isinstance(req, dict): + raise ValueError("Invalid request payload") + return req + + async def handle_connection(self) -> None: + """Main loop.""" + await self.websocket.accept() + + try: + # Send model-specific PolicyServerConfig resolved by serving from + # diffusion od_config.model_config. + metadata = self.serving.policy_server_config.to_dict() + await self.websocket.send_bytes(_pack(metadata)) + + while True: + try: + msg = await asyncio.wait_for( + self.websocket.receive(), + timeout=self._idle_timeout, + ) + except asyncio.TimeoutError: + logger.info("World Model OpenPI connection idle timeout after %.1f seconds", self._idle_timeout) + try: + await self.websocket.close() + except Exception: + logger.debug("Failed to close idle World Model websocket", exc_info=True) + return + + if msg.get("type") == "websocket.disconnect": + break + + if "bytes" not in msg or not msg["bytes"]: + continue + + try: + req = self._unpack_request(msg["bytes"]) + except Exception: + logger.exception("Invalid world model OpenPI request payload") + try: + await self._send_error("Invalid request payload") + except Exception: + break + continue + + try: + endpoint = req.pop("endpoint", "infer") + + if endpoint == "reset": + self.serving.reset(req) + await self.websocket.send_text("reset successful") + else: + result = await self.serving.infer(req) + + extra_body: dict = req.get("extra_body", {}) + + frames_per_chunk = extra_body.get("frames_per_chunk", DEFAULT_FRAMES_PER_CHUNK) + + if ( + len(result.images) == 1 + and isinstance(result.images[0], tuple) + and len(result.images[0]) == 1 + ): + frames = result.images[0] + elif len(result.images) == 1 and isinstance(result.images[0], dict): + frames = result.images[0].get("frames") or result.images[0].get("video") + else: + frames = result.images + + if len(frames) == 1: + frames = frames[0] + + if isinstance(frames, torch.Tensor): + frames = frames.numpy(force=True) + + frames = _normalize_frames(frames) + + total = (len(frames) + frames_per_chunk - 1) // frames_per_chunk + for i in range(total): + chunk = frames[i * frames_per_chunk : (i + 1) * frames_per_chunk] + await self.websocket.send_bytes( + _pack( + { + "type": "frame", + "index": i, + "total": total, + "video": chunk, + } + ) + ) + + except Exception: + logger.exception("Error handling request") + try: + await self._send_error("Internal inference error") + except Exception: + break + + except WebSocketDisconnect: + pass + except Exception: + logger.exception("Connection error") diff --git a/vllm_omni/entrypoints/openai/realtime/world/camera_serving.py b/vllm_omni/entrypoints/openai/realtime/world/camera_serving.py new file mode 100644 index 00000000000..4ac8a6c4a1f --- /dev/null +++ b/vllm_omni/entrypoints/openai/realtime/world/camera_serving.py @@ -0,0 +1,186 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Protocol structs for the /v1/realtime/world/* family of endpoints. + +These are msgpack-serialised over the WebSocket wire via ``msgspec.msgpack``. +""" + +from __future__ import annotations + +from collections.abc import Mapping +from dataclasses import dataclass +from typing import Any + +import numpy as np +from omegaconf import OmegaConf +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +def _to_builtin_container(value: Any) -> Any: + if OmegaConf.is_config(value): + return OmegaConf.to_container(value, resolve=True) + if isinstance(value, Mapping): + return {key: _to_builtin_container(item) for key, item in value.items()} + if isinstance(value, (list, tuple)): + return [_to_builtin_container(item) for item in value] + return value + + +@dataclass(frozen=True) +class CameraServerConfig: + """Static server-side camera/pipeline parameters sent to a client on connect.""" + + values: dict[str, Any] + + @classmethod + def from_model_config(cls, model_config: Any) -> CameraServerConfig: + return cls(_to_builtin_container(model_config)) + + def to_dict(self) -> dict[str, Any]: + return _to_builtin_container(self.values) + + +class ServingRealtimeWorldCamera: + """World Model Camera serving layer for OpenPI protocol. + + Model-specific transform/state lives in the diffusion pipeline. + """ + + def __init__( + self, + engine_client: Any, + model_name: str | None = None, + ) -> None: + self.engine_client = engine_client + self.model_name = model_name + self._current_session_id: str | None = None + self._call_count = 0 + self.policy_server_config = self._get_policy_server_config(engine_client) + self._force_reset = False + + @classmethod + def create_policy_server( + cls, + engine_client: Any, + model_name: str | None = None, + ) -> ServingRealtimeWorldCamera | None: + try: + return cls(engine_client=engine_client, model_name=model_name) + except ValueError as exc: + if "policy_server_config" not in str(exc): + raise + logger.info("World Model OpenPI serving disabled for model %s", model_name) + return None + + @staticmethod + def _get_policy_server_config(engine_client: Any) -> CameraServerConfig: + model_config = None + get_od_config = getattr(engine_client, "get_diffusion_od_config", None) + if callable(get_od_config): + od_config = get_od_config() + model_config = getattr(od_config, "model_config", None) + + if model_config is None: + for stage_config in getattr(engine_client, "stage_configs", []) or []: + if getattr(stage_config, "stage_type", None) != "diffusion": + continue + engine_args = getattr(stage_config, "engine_args", None) + model_config = getattr(engine_args, "model_config", None) + if model_config is not None: + break + + if model_config is None: + od_config = getattr(engine_client, "od_config", None) + model_config = getattr(od_config, "model_config", None) + + if model_config is None: + model_config = getattr(engine_client, "model_config", None) + + return CameraServerConfig.from_model_config(model_config) + + def reset(self, req: dict) -> None: + """Reset serving state. + + Engine-side Lingbot state is reset on the next inference request via + `extra_args["reset"]`, not by an immediate websocket-side RPC. + """ + self._current_session_id = None + self._force_reset = True + + async def infer(self, req: dict) -> np.ndarray: + """raw req → engine → video.""" + # Session tracking + + session_id = req.get("session_id") + if session_id is not None and session_id != self._current_session_id: + if self._current_session_id is not None: + logger.info("Session changed %s → %s", self._current_session_id, session_id) + self.reset({}) + self._current_session_id = session_id + + self._call_count += 1 + + # Build request, run inference through AsyncOmni + request = self._build_request(req) + + # After an inference call we reset the _force_reset argument + self._force_reset = False + + result = None + # OpenPI policy serving is one request -> one action reply. AsyncOmni + # exposes an async iterator, so consume it to completion and use the + # final output, matching other non-streaming OpenAI serving paths. + async for output in self.engine_client.generate( + prompt=request.prompts[0], + request_id=request.request_ids[0], + sampling_params_list=[request.sampling_params], + ): + result = output + if result is None: + raise RuntimeError("World Model Camera OpenPI request produced no output.") + + return result + + def _build_request(self, req: dict) -> Any: + """Build engine request from raw robot req. + + Returns an `OmniDiffusionRequest` payload consumed by + `AsyncOmni.generate()` and routed to the diffusion stage. + """ + from vllm_omni.diffusion.request import OmniDiffusionRequest + from vllm_omni.inputs.data import OmniDiffusionSamplingParams + + extra_args = {"session_id": self._current_session_id or "default", "force_reset": self._force_reset} + + camera = req.get("camera", None) + + multi_modal_data = { + "image": req.get("image", None), + "camera": camera, + } + + prompt = req.get("prompt", "") + + extra_body = req.get("extra_body", {}) + + height = extra_body.get("height", None) + width = extra_body.get("width", None) + num_frames = extra_body.get("num_frames", None) + fps = extra_body.get("fps", None) + seed = extra_body.get("seed", None) + + sampling_params = OmniDiffusionSamplingParams( + height=height, width=width, num_frames=num_frames, frame_rate=fps, extra_args=extra_args, seed=seed + ) + return OmniDiffusionRequest( + prompts=[ + { + "prompt": prompt, + "multi_modal_data": multi_modal_data, + } + ], + sampling_params=sampling_params, + request_ids=[f"camera-{self._current_session_id or 'default'}"], + )