From 34c59f8c86562f553a070b7ad968e49e9cc6502e Mon Sep 17 00:00:00 2001 From: Miguel Vieira Pereira Date: Wed, 6 May 2026 15:07:51 +0000 Subject: [PATCH 01/53] Adapt the Lingbot World Fast Pipeline to work in vllm-omni Signed-off-by: Miguel Vieira Pereira Implement Lingbot World Transformer into vllm-omni Signed-off-by: Miguel Vieira Pereira Implement KV cache abstraction for Lingbot World Fast Add script to offline generation using Lingbot World Fast Implement online serving for Lingbot World and camera-based world models in general --- .../download_lingbot_world_fast.py | 106 +++ .../lingbot_world_fast/end2end.py | 294 ++++++++ .../lingbot_world_fast/openai_client.py | 152 ++++ .../lingbot_world_fast/run_server.sh | 17 + vllm_omni/diffusion/diffusion_engine.py | 28 + vllm_omni/diffusion/models/interface.py | 5 + .../models/lingbot_world_fast/__init__.py | 4 + .../pipeline_lingbot_world_fast.py | 374 ++++++++++ .../state_lingbot_world_fast.py | 107 +++ .../models/lingbot_world_fast/wan_fast.py | 653 ++++++++++++++++++ vllm_omni/diffusion/registry.py | 6 + vllm_omni/entrypoints/cli/serve.py | 13 + vllm_omni/entrypoints/openai/api_server.py | 27 + .../realtime/world/camera_connection.py | 165 +++++ .../openai/realtime/world/camera_serving.py | 186 +++++ 15 files changed, 2137 insertions(+) create mode 100644 examples/offline_inference/lingbot_world_fast/download_lingbot_world_fast.py create mode 100644 examples/offline_inference/lingbot_world_fast/end2end.py create mode 100644 examples/online_serving/lingbot_world_fast/openai_client.py create mode 100755 examples/online_serving/lingbot_world_fast/run_server.sh create mode 100644 vllm_omni/diffusion/models/lingbot_world_fast/__init__.py create mode 100644 vllm_omni/diffusion/models/lingbot_world_fast/pipeline_lingbot_world_fast.py create mode 100644 vllm_omni/diffusion/models/lingbot_world_fast/state_lingbot_world_fast.py create mode 100644 vllm_omni/diffusion/models/lingbot_world_fast/wan_fast.py create mode 100644 vllm_omni/entrypoints/openai/realtime/world/camera_connection.py create mode 100644 vllm_omni/entrypoints/openai/realtime/world/camera_serving.py diff --git a/examples/offline_inference/lingbot_world_fast/download_lingbot_world_fast.py b/examples/offline_inference/lingbot_world_fast/download_lingbot_world_fast.py new file mode 100644 index 00000000000..2cd6b6f0538 --- /dev/null +++ b/examples/offline_inference/lingbot_world_fast/download_lingbot_world_fast.py @@ -0,0 +1,106 @@ +import argparse +import fcntl +import json +import os +import site +import subprocess +import tempfile +import time +from pathlib import Path + +from huggingface_hub import snapshot_download + +DEPENDENCY_REPO = "https://github.com/Robbyant/lingbot-world" +DEPENDENCY_BRANCH = "main" +CACHE_DIR = Path(tempfile.gettempdir()) / "vllm-omni-dependency" +LOCK_FILE = CACHE_DIR / ".install.lock" +DEPENDENCY_DIR = CACHE_DIR / "Lingbot-World" + + +def download_dependency(): + CACHE_DIR.mkdir(parents=True, exist_ok=True) + + with open(LOCK_FILE, "w") as f: + fcntl.flock(f, fcntl.LOCK_EX) + if not DEPENDENCY_DIR.exists(): + print(f"Downloading Lingbot World Fast to {DEPENDENCY_DIR} ...") + subprocess.run( + ["git", "clone", "--depth", "1", DEPENDENCY_REPO, "--branch", DEPENDENCY_BRANCH, str(DEPENDENCY_DIR)], + check=True, + ) + print("Download finished.") + fcntl.flock(f, fcntl.LOCK_UN) + + # write .pth to site-packages + site_packages = Path(site.getsitepackages()[0]) + pth_file = site_packages / "vllm_omni_dependency.pth" + pth_file.write_text(str(DEPENDENCY_DIR)) + print(f"Added {DEPENDENCY_DIR} to site-packages via {pth_file}") + + +def timed_download(repo_id: str, local_dir: str, allow_patterns: list | None = None): + """Download files from HF repo and log time + destination.""" + if os.path.exists(local_dir): + print(f"Directory {local_dir} already exists. Skipping download.") + return + print(f"Starting download from {repo_id} into {local_dir}") + start_time = time.time() + + snapshot_download( + repo_id=repo_id, + local_dir=local_dir, + local_dir_use_symlinks=False, + allow_patterns=allow_patterns, + ) + + elapsed = time.time() - start_time + print(f"✅ Finished downloading {repo_id} in {elapsed:.2f} seconds. Files saved at: {local_dir}") + + +def main(output_dir: str): + lingbot_base_dir = os.path.join(output_dir, "lingbot-world-base-cam") + + # Base Model + timed_download( + repo_id="robbyant/lingbot-world-base-cam", + local_dir=lingbot_base_dir, + allow_patterns=["google/*", "models_t5_umt5-xxl-enc-bf16.pth", "Wan2.1_VAE.pth"], + ) + + lingbot_fast_dir = os.path.join(lingbot_base_dir, "Lingbot-World-Fast") + + timed_download(repo_id="robbyant/lingbot-world-fast", local_dir=lingbot_fast_dir) + + # Lingbot World does not come with config.json which is required by diffusers + config = { + "_class_name": "WanModel", + "_diffusers_version": "0.33.0", + "dim": 5120, + "eps": 1e-06, + "ffn_dim": 13824, + "freq_dim": 256, + "in_dim": 36, + "model_type": "i2v", + "num_heads": 40, + "num_layers": 40, + "out_dim": 16, + "text_len": 512, + } + + with open( + os.path.join(output_dir, "lingbot-world-base-cam", "Lingbot-World-Fast", "config.json"), "w", encoding="utf-8" + ) as f: + json.dump(config, f, indent=2) + + print(f"model_index.json created at {os.path.join(output_dir, 'model_index.json')}") + + download_dependency() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Download models from Hugging Face") + parser.add_argument( + "--output-dir", type=str, default="./lingbot_world", help="Base directory to save downloaded models" + ) + args = parser.parse_args() + main(args.output_dir) diff --git a/examples/offline_inference/lingbot_world_fast/end2end.py b/examples/offline_inference/lingbot_world_fast/end2end.py new file mode 100644 index 00000000000..a5c44b4d58c --- /dev/null +++ b/examples/offline_inference/lingbot_world_fast/end2end.py @@ -0,0 +1,294 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +""" +Image-Camera to Video generation example using Lingbot World Fast + +Usage example: + python end2end.py --model path/to/lingbot-fast --image path/to/image --camera-path path/to/camera + --width 832 --height 480 --prompt "Walk in the Great Wall of China" --output output.mp4 +""" + +import argparse +import os +import time +from pathlib import Path + +import numpy as np +import PIL.Image +import torch + +from vllm_omni.entrypoints.omni import Omni +from vllm_omni.inputs.data import OmniDiffusionSamplingParams +from vllm_omni.outputs import OmniRequestOutput +from vllm_omni.platforms import current_omni_platform + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Generate a video from an image (Wan2.2, LTX2, HunyuanVideo-1.5).") + parser.add_argument( + "--model", + default="lingbot_world/lingbot-world-base-cam/Lingbot-World-Fast", + help="Diffusers I2V model ID or local path (Wan2.2 or HunyuanVideo-1.5).", + ) + parser.add_argument("--image", required=True, help="Path to input image.") + parser.add_argument("--camera-path", default=None, help="Path to input camera positions") + parser.add_argument("--prompt", default="", help="Text prompt describing the desired motion.") + parser.add_argument("--negative-prompt", default="", help="Negative prompt.") + parser.add_argument("--seed", type=int, default=42, help="Random seed.") + parser.add_argument( + "--height", type=int, default=None, help="Video height (auto-calculated from image if not set)." + ) + parser.add_argument("--width", type=int, default=None, help="Video width (auto-calculated from image if not set).") + parser.add_argument("--num-frames", type=int, default=81, help="Number of frames.") + parser.add_argument("--output", type=str, default="i2v_output.mp4", help="Path to save the video (mp4).") + parser.add_argument("--fps", type=int, default=16, help="Frames per second for the output video.") + parser.add_argument( + "--enable-diffusion-pipeline-profiler", + action="store_true", + help="Enable diffusion pipeline profiler to display stage durations.", + ) + return parser.parse_args() + + +def calculate_dimensions( + image: PIL.Image.Image, + max_area: int = 480 * 832, + mod_value: int = 16, +) -> tuple[int, int]: + """Calculate output dimensions maintaining aspect ratio.""" + aspect_ratio = image.height / image.width + + height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value + width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value + + return height, width + + +def main(): + args = parse_args() + generator = torch.Generator(device=current_omni_platform.device_type).manual_seed(args.seed) + model_class_name = "LingbotWorldFastPipeline" + + # Load input image + image = PIL.Image.open(args.image).convert("RGB") + + num_inference_steps = 40 + + # Calculate dimensions if not provided + height = args.height + width = args.width + if height is None or width is None: + max_area = 480 * 832 + mod_value = 16 + calc_height, calc_width = calculate_dimensions(image, max_area=max_area, mod_value=mod_value) + height = height or calc_height + width = width or calc_width + + # Resize image to target dimensions + image = image.resize((width, height), PIL.Image.Resampling.LANCZOS) + + # Check if profiling is requested via environment variable + profiler_enabled = bool(os.getenv("VLLM_TORCH_PROFILER_DIR")) + + omni = Omni( + model=args.model, + parallel_config=None, + model_class_name=model_class_name, + stage_init_timeout=6000, + init_timeout=6000, + ) + + if profiler_enabled: + print("[Profiler] Starting profiling...") + omni.start_profile() + + # Print generation configuration + print(f"\n{'=' * 60}") + print("Generation Configuration:") + print(f" Model: {args.model}") + print(f" Inference steps: {num_inference_steps}") + print(f" Frames: {args.num_frames}") + print(f" Video size: {args.width}x{args.height}") + print(f"{'=' * 60}\n") + + generation_start = time.perf_counter() + # omni.generate() returns Generator[OmniRequestOutput, None, None] + + multi_modal_data = {"image": image} + + if args.camera_path is not None: + poses = np.load(os.path.join(args.camera_path, "poses.npy")) + intrinsics = np.load(os.path.join(args.camera_path, "intrinsics.npy")) + + multi_modal_data["camera"] = {"poses": poses, "intrinsics": intrinsics} + + frames = omni.generate( + { + "prompt": args.prompt, + "negative_prompt": args.negative_prompt, + "multi_modal_data": multi_modal_data, + }, + OmniDiffusionSamplingParams( + height=height, + width=width, + generator=generator, + num_frames=args.num_frames, + frame_rate=args.fps, + ), + ) + generation_end = time.perf_counter() + generation_time = generation_end - generation_start + + # Print profiling results + print(f"Total generation time: {generation_time:.4f} seconds ({generation_time * 1000:.2f} ms)") + + if isinstance(frames, list): + frames = frames[0] if frames else None + + if isinstance(frames, OmniRequestOutput): + if frames.final_output_type != "image": + raise ValueError( + f"Unexpected output type '{frames.final_output_type}', expected 'image' for video generation." + ) + if frames.is_pipeline_output and frames.request_output is not None: + inner_output = frames.request_output + if isinstance(inner_output, OmniRequestOutput): + frames = inner_output + if isinstance(frames, OmniRequestOutput): + if frames.images: + if len(frames.images) == 1 and isinstance(frames.images[0], tuple) and len(frames.images[0]) == 2: + frames = frames.images[0] + elif len(frames.images) == 1 and isinstance(frames.images[0], dict): + frames = frames.images[0].get("frames") or frames.images[0].get("video") + else: + frames = frames.images + else: + raise ValueError("No video frames found in OmniRequestOutput.") + + if isinstance(frames, list) and frames: + first_item = frames[0] + if isinstance(first_item, tuple) and len(first_item) == 2: + frames = first_item + elif isinstance(first_item, dict): + frames = first_item.get("frames") or first_item.get("video") + elif isinstance(first_item, list): + frames = first_item + + if isinstance(frames, tuple) and len(frames) == 2: + frames = frames + elif isinstance(frames, dict): + frames = frames.get("frames") or frames.get("video") + + if frames is None: + raise ValueError("No video frames found in output.") + + output_path = Path(args.output) + output_path.parent.mkdir(parents=True, exist_ok=True) + + try: + from diffusers.utils import export_to_video + except ImportError: + raise ImportError("diffusers is required for export_to_video.") + + def _normalize_frame(frame): + if isinstance(frame, torch.Tensor): + frame_tensor = frame.detach().cpu() + if frame_tensor.dim() == 4 and frame_tensor.shape[0] == 1: + frame_tensor = frame_tensor[0] + if frame_tensor.dim() == 3 and frame_tensor.shape[0] in (3, 4): + frame_tensor = frame_tensor.permute(1, 2, 0) + if frame_tensor.is_floating_point(): + frame_tensor = frame_tensor.clamp(-1, 1) * 0.5 + 0.5 + return frame_tensor.float().numpy() + if isinstance(frame, np.ndarray): + frame_array = frame + if frame_array.ndim == 4 and frame_array.shape[0] == 1: + frame_array = frame_array[0] + if np.issubdtype(frame_array.dtype, np.integer): + frame_array = frame_array.astype(np.float32) / 255.0 + return frame_array + try: + from PIL import Image + except ImportError: + Image = None + if Image is not None and isinstance(frame, Image.Image): + return np.asarray(frame).astype(np.float32) / 255.0 + return frame + + def _ensure_frame_list(video_array): + if isinstance(video_array, list): + if len(video_array) == 0: + return video_array + first_item = video_array[0] + if isinstance(first_item, np.ndarray): + if first_item.ndim == 5: + return list(first_item[0]) + if first_item.ndim == 4: + if len(video_array) == 1: + return list(first_item) + return list(first_item) + if first_item.ndim == 3: + return video_array + return video_array + if isinstance(video_array, np.ndarray): + if video_array.ndim == 5: + return list(video_array[0]) + if video_array.ndim == 4: + return list(video_array) + if video_array.ndim == 3: + return [video_array] + return video_array + + # frames may be np.ndarray, torch.Tensor, or list of tensors/arrays/images + # export_to_video expects a list of frames with values in [0, 1] + if isinstance(frames, torch.Tensor): + video_tensor = frames.detach().cpu() + if video_tensor.dim() == 5: + if video_tensor.shape[1] in (3, 4): + video_tensor = video_tensor[0].permute(1, 2, 3, 0) + else: + video_tensor = video_tensor[0] + elif video_tensor.dim() == 4 and video_tensor.shape[0] in (3, 4): + video_tensor = video_tensor.permute(1, 2, 3, 0) + if video_tensor.is_floating_point(): + video_tensor = video_tensor.clamp(-1, 1) * 0.5 + 0.5 + video_array = video_tensor.float().numpy() + elif isinstance(frames, np.ndarray): + video_array = frames + if video_array.ndim == 5: + video_array = video_array[0] + if np.issubdtype(video_array.dtype, np.integer): + video_array = video_array.astype(np.float32) / 255.0 + elif isinstance(frames, list): + if len(frames) == 0: + raise ValueError("No video frames found in output.") + video_array = [_normalize_frame(frame) for frame in frames] + else: + video_array = frames + + video_array = _ensure_frame_list(video_array) + + export_to_video(video_array, str(output_path), fps=args.fps) + print(f"Saved generated video to {output_path}") + + if profiler_enabled: + print("\n[Profiler] Stopping profiler and collecting results...") + profile_results = omni.stop_profile() + if profile_results and isinstance(profile_results, dict): + traces = profile_results.get("traces", []) + print("\n" + "=" * 60) + print("PROFILING RESULTS:") + for rank, trace in enumerate(traces): + print(f"\nRank {rank}:") + if trace: + print(f" • Trace: {trace}") + if not traces: + print(" No traces collected.") + print("=" * 60) + else: + print("[Profiler] No valid profiling data returned.") + + +if __name__ == "__main__": + main() diff --git a/examples/online_serving/lingbot_world_fast/openai_client.py b/examples/online_serving/lingbot_world_fast/openai_client.py new file mode 100644 index 00000000000..e7cc22c0a9c --- /dev/null +++ b/examples/online_serving/lingbot_world_fast/openai_client.py @@ -0,0 +1,152 @@ +#!/usr/bin/env python3 +""" +Lingbot World Fast realtime camera client. + +Talks to the WebSocket endpoint ``/v1/realtime/world/camera`` exposed by +``vllm serve --omni`` when the loaded pipeline is ``LingbotWorldFastPipeline``. + +The endpoint speaks the OpenPI policy protocol on the wire: + 1. Connect -> server sends msgpack(PolicyServerConfig) + 2. Client send msgpack(obs) + 3. Server send msgpack(ndarray) # generated frames + +The ``obs`` payload sent here contains: + - "image": numpy array, the input image + - "prompt": str, the text prompt describing the desired motion + - "camera": {"poses": ndarray, "intrinsics": ndarray} + +Usage: + python openai_chat_client.py \\ + --image path/to/image.png \\ + --camera-path path/to/camera_dir \\ + --prompt "Walk along the Great Wall of China" \\ + --output frames.npy +""" + +import argparse +from argparse import Namespace +from pathlib import Path + +import numpy as np +import PIL.Image +import websockets.sync.client as ws_sync +from diffusers.utils import export_to_video + +try: + from openpi_client import msgpack_numpy +except ImportError as exc: + raise SystemExit("This example requires `openpi-client`. Install it with `pip install openpi-client`.") from exc + + +def _pack(obj): + return msgpack_numpy.packb(obj) + + +def _unpack(data): + return msgpack_numpy.unpackb(data) + + +def _load_image(path: str) -> np.ndarray: + image = PIL.Image.open(path).convert("RGB") + return np.asarray(image) + + +def _load_camera(camera_dir: str) -> dict: + camera_path = Path(camera_dir) + poses = np.load(camera_path / "poses.npy") + intrinsics = np.load(camera_path / "intrinsics.npy") + return {"poses": poses, "intrinsics": intrinsics} + + +def generate_video(args: Namespace) -> np.ndarray: + """Send a single inference request and return the generated frames.""" + image = _load_image(args.image) + camera = _load_camera(args.camera_path) + + extra_body = {"height": args.height, "width": args.width, "num_frames": args.num_frames, "fps": args.fps} + + obs: dict = {"prompt": args.prompt, "image": image, "camera": camera, "extra_body": extra_body} + + if args.session_id is not None: + obs["session_id"] = args.session_id + + endpoint = f"{args.server.rstrip('/')}/v1/realtime/world/camera" + print(f"Connecting to {endpoint} ...") + + with ws_sync.connect(endpoint, max_size=None, ping_interval=None, ping_timeout=None) as ws: + # 1. Server sends CameraServerConfig on connect. + server_config = _unpack(ws.recv()) + + # 2. Send obs. + print( + f"Sending obs (image={image.shape}, " + f"poses={camera['poses'].shape}, intrinsics={camera['intrinsics'].shape})..." + ) + ws.send(_pack(obs)) + + # 3. Receive generated frames. + chunks: list[np.ndarray] = [] + total = None + while total is None or len(chunks) < total: + msg = _unpack(ws.recv()) + if isinstance(msg, dict) and msg.get("type") == "error": + raise RuntimeError(f"Server error: {msg.get('message')}") + if not isinstance(msg, dict) or msg.get("type") != "frame": + continue # ignore anything unexpected + total = msg["total"] + chunks.append(msg["video"]) + print(f" received chunk {msg['index'] + 1}/{total}") + + video = np.concatenate(chunks, axis=0) + + return video + + +def main(): + parser = argparse.ArgumentParser(description="Lingbot World Fast realtime camera client") + parser.add_argument("--image", "-i", required=True, help="Path to input image.") + parser.add_argument( + "--camera-path", + "-c", + required=True, + help="Directory containing poses.npy and intrinsics.npy.", + ) + parser.add_argument( + "--prompt", + "-p", + default="Walk along the Great Wall of China", + help="Text prompt describing the desired motion.", + ) + parser.add_argument( + "--server", + "-s", + default="ws://localhost:8091", + help="WebSocket server URL (ws:// or wss://).", + ) + parser.add_argument("--session-id", default=None, help="Optional session id.") + parser.add_argument( + "--output", + "-o", + default="lingbot-video.mp4", + help="Path to save the returned frames (npy).", + ) + parser.add_argument("--width", type=int, default=832) + parser.add_argument("--height", type=int, default=480) + parser.add_argument("--fps", type=int, default=16) + parser.add_argument("--num-frames", type=int, default=81) + args = parser.parse_args() + + frames = generate_video(args) + + output_path = Path(args.output) + output_path.parent.mkdir(parents=True, exist_ok=True) + + print(frames.__class__) + print(frames[0].__class__) + + export_to_video(frames, str(output_path), fps=args.fps) + print(f"Saved generated video to {output_path}") + + +if __name__ == "__main__": + main() diff --git a/examples/online_serving/lingbot_world_fast/run_server.sh b/examples/online_serving/lingbot_world_fast/run_server.sh new file mode 100755 index 00000000000..ef3e2caab93 --- /dev/null +++ b/examples/online_serving/lingbot_world_fast/run_server.sh @@ -0,0 +1,17 @@ +#!/bin/bash +# Bagel online serving startup script + +MODEL="${MODEL:-../../offline_inference/lingbot_world_fast/lingbot_world/lingbot-world-base-cam/Lingbot-World-Fast}" +PORT="${PORT:-8091}" + +echo "Starting Lingbot World server..." +echo "Model: $MODEL" +echo "Port: $PORT" + +vllm serve "$MODEL" --omni \ + --port "$PORT" \ + --model-class-name LingbotWorldFastPipeline \ + --stage-init-timeout 6000 \ + --init-timeout 6000 \ + --ws-max-size 268435456 \ + --ws wsproto diff --git a/vllm_omni/diffusion/diffusion_engine.py b/vllm_omni/diffusion/diffusion_engine.py index c13bd3c0c37..0745930fd97 100644 --- a/vllm_omni/diffusion/diffusion_engine.py +++ b/vllm_omni/diffusion/diffusion_engine.py @@ -71,6 +71,13 @@ def supports_multimodal_input(od_config: OmniDiffusionConfig) -> tuple[bool, boo return supports_image_input, supports_audio_input +def supports_camera_pos_input(model_class_name: str) -> bool: + model_cls = DiffusionModelRegistry._try_load_model_cls(model_class_name) + if model_cls is None: + return False + return bool(getattr(model_cls, "support_camera_pos_input", False)) + + def image_color_format(model_class_name: str) -> str: model_cls = DiffusionModelRegistry._try_load_model_cls(model_class_name) return getattr(model_cls, "color_format", "RGB") @@ -691,6 +698,27 @@ def _dummy_run(self): dummy_audio = np.random.randn(audio_sr * 2).astype(np.float32) prompt.setdefault("multi_modal_data", {})["audio"] = dummy_audio + audio_duration_sec = 4 + audio_array = np.random.randn(audio_sr * audio_duration_sec).astype(np.float32) + dummy_audio = audio_array[audio_sr * 1 : audio_sr * 3] + else: + dummy_audio = None + + if supports_camera_pos_input(self.od_config.model_class_name): + camera_pos_len = 64 + # Shape [N x 4] + intrinsics = np.random.rand(camera_pos_len, 4) + # Shape [N x 4 x 4] + poses = np.array([np.identity(4) for _ in range(camera_pos_len)]) + + dummy_camera_pos = {"intrinsics": intrinsics, "poses": poses} + else: + dummy_camera_pos = None + + prompt: OmniTextPrompt = { + "prompt": "dummy run", + "multi_modal_data": {"image": dummy_image, "audio": dummy_audio, "camera": dummy_camera_pos}, + } req = OmniDiffusionRequest( prompts=[prompt], request_ids=["dummy_req_id"], diff --git a/vllm_omni/diffusion/models/interface.py b/vllm_omni/diffusion/models/interface.py index f0ded9cdb0a..21848031120 100644 --- a/vllm_omni/diffusion/models/interface.py +++ b/vllm_omni/diffusion/models/interface.py @@ -28,6 +28,11 @@ class SupportAudioInput(Protocol): support_audio_input: ClassVar[bool] = True +@runtime_checkable +class SupportCameraPosInput(Protocol): + support_camera_pos_input: ClassVar[bool] = True + + @runtime_checkable class SupportAudioOutput(Protocol): support_audio_output: ClassVar[bool] = True diff --git a/vllm_omni/diffusion/models/lingbot_world_fast/__init__.py b/vllm_omni/diffusion/models/lingbot_world_fast/__init__.py new file mode 100644 index 00000000000..20513d34e60 --- /dev/null +++ b/vllm_omni/diffusion/models/lingbot_world_fast/__init__.py @@ -0,0 +1,4 @@ +from .pipeline_lingbot_world_fast import LingbotWorldFastPipeline, get_lingbot_world_fast_post_process_func +from .wan_fast import WanModelFast + +__all__ = ["LingbotWorldFastPipeline", "get_lingbot_world_fast_post_process_func", "WanModelFast"] diff --git a/vllm_omni/diffusion/models/lingbot_world_fast/pipeline_lingbot_world_fast.py b/vllm_omni/diffusion/models/lingbot_world_fast/pipeline_lingbot_world_fast.py new file mode 100644 index 00000000000..4a31f6a23ed --- /dev/null +++ b/vllm_omni/diffusion/models/lingbot_world_fast/pipeline_lingbot_world_fast.py @@ -0,0 +1,374 @@ +import logging +import math +import os +import random +import sys +from contextlib import contextmanager + +import numpy as np +import torch +import torch.distributed as dist +import torchvision.transforms.functional as TF +from einops import rearrange +from torch import nn +from tqdm import tqdm + +# Load dependencies from Lingbot World source code +from wan.modules.t5 import T5EncoderModel +from wan.modules.vae2_1 import Wan2_1_VAE +from wan.utils.cam_utils import ( + compute_relative_poses, + get_Ks_transformed, + get_plucker_embeddings, + interpolate_camera_poses, +) +from wan.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler + +from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig +from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin +from vllm_omni.diffusion.distributed.utils import get_local_device +from vllm_omni.diffusion.models.interface import SupportCameraPosInput, SupportImageInput +from vllm_omni.diffusion.request import OmniDiffusionRequest + +from .state_lingbot_world_fast import LingbotWorldFastState +from .wan_fast import WanModelFast + +logger = logging.getLogger(__name__) + +CONFIG = { + "text_len": 512, + "num_train_timesteps": 1000, + "vae_stride": (4, 8, 8), + "patch_size": (1, 2, 2), + "timesteps_index": [0, 179, 358, 679], + "sample_shift": 10.0, + "max_area": 480 * 832, + "max_sequence_length": 512, + "chunk_size": 3, + "t5_checkpoint": "models_t5_umt5-xxl-enc-bf16.pth", + "t5_tokenizer": "google/umt5-xxl", + "vae_checkpoint": "Wan2.1_VAE.pth", + "fast_noise_checkpoint": "Lingbot-World-Fast", + "negative_prompt_sample": ( + "画面突变,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止," + "整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部," + "画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景," + "三条腿,背景人很多,倒着走,镜头晃动,画面闪烁,模糊,噪点,水印,签名,文字,变形," + "扭曲,液化,不合逻辑的结构,卡顿,PPT幻灯片感,过暗,欠曝,低对比度,霓虹灯光感," + "过度锐化,3D渲染感,人物,行人,游客,身体,皮肤,肢体,面部特征,汽车,电线" + ), +} + + +def get_lingbot_world_fast_post_process_func( + od_config: OmniDiffusionConfig, +): + def post_process_func( + video: torch.Tensor, + ): + outputs = video.permute(1, 2, 3, 0) + return outputs + + return post_process_func + + +class LingbotWorldFastPipeline(nn.Module, SupportImageInput, SupportCameraPosInput, CFGParallelMixin): + def __init__(self, *, od_config: OmniDiffusionConfig): + super().__init__() + self.od_config = od_config + self.parallel_config = od_config.parallel_config + + self.device = get_local_device() + + self.target_dtype = od_config.dtype + + self.control_type = "cam" + self.num_train_timesteps = CONFIG["num_train_timesteps"] + + self.sp_size = od_config.parallel_config.world_size + + self.state = LingbotWorldFastState() + + checkpoint_path = os.path.dirname(self.od_config.model) + assert checkpoint_path is not None, "lingbot_dir is None" + + self.text_encoder = T5EncoderModel( + text_len=CONFIG["text_len"], + dtype=self.target_dtype, + device=torch.device("cpu"), + checkpoint_path=os.path.join(checkpoint_path, CONFIG["t5_checkpoint"]), + tokenizer_path=os.path.join(checkpoint_path, CONFIG["t5_tokenizer"]), + ) + + self.vae_stride = CONFIG["vae_stride"] + self.patch_size = CONFIG["patch_size"] + self.vae = Wan2_1_VAE(vae_pth=os.path.join(checkpoint_path, CONFIG["vae_checkpoint"]), device=self.device) + + logger.info(f"Creating WanModelFast from {checkpoint_path}") + self.model = WanModelFast.from_pretrained( + checkpoint_path, + subfolder=CONFIG["fast_noise_checkpoint"], + torch_dtype=torch.bfloat16, + control_type=self.control_type, + ).to(self.device) + + self.scheduler = FlowUniPCMultistepScheduler( + num_train_timesteps=self.num_train_timesteps, shift=1, use_dynamic_shifting=False + ) + + self.sample_neg_prompt = CONFIG["negative_prompt_sample"] + + def _configure_model(self, model): + """ + Configures a model object. This includes setting evaluation modes, + applying distributed parallel strategy, and handling device placement. + + Args: + model (torch.nn.Module): + The model instance to configure. + + Returns: + torch.nn.Module: + The configured model. + """ + model.eval().requires_grad_(False) + + def _convert_flow_pred_to_x0( + self, flow_pred: torch.Tensor, xt: torch.Tensor, timestep: torch.Tensor, scheduler + ) -> torch.Tensor: + """ + Convert flow matching's prediction to x0 prediction. + flow_pred: the prediction with shape [B, C, F, H, W] + xt: the input noisy data with shape [B, C, F, H, W] + timestep: the timestep with shape [B] + + pred = noise - x0 + x_t = (1-sigma_t) * x0 + sigma_t * noise + we have x0 = x_t - sigma_t * pred + """ + # use higher precision for calculations + original_dtype = flow_pred.dtype + flow_pred, xt, sigmas, timesteps = map( + lambda x: x.double().to(flow_pred.device), [flow_pred, xt, scheduler.sigmas, scheduler.timesteps] + ) + timestep_id = torch.argmin((timesteps - timestep).abs()) + sigma_t = sigmas[timestep_id].reshape(-1, 1, 1, 1) + x0_pred = xt - sigma_t * flow_pred + + return x0_pred.to(original_dtype) + + def forward( + self, + req: OmniDiffusionRequest, + ) -> DiffusionOutput: + if len(req.prompts) > 1: + raise ValueError( + """This model only supports a single prompt, not a batched request.""", + """Please pass in a single prompt object or string, or a single-item list.""", + ) + prompt = req.prompts[0].get("prompt") + multi_modal_data = req.prompts[0].get("multi_modal_data", {}) + + # Always reset: Lingbot Fast does not support video continuation + self.state.reset() + + camera = multi_modal_data.get("camera", None) + if camera is None: + self.od_config.model + raise ValueError("A path to camera positions must be passed to this model through action_path.") + + batch_size = 1 + num_frames = req.sampling_params.num_frames + # In order to generate something num_frames must be at least 5 since it expects 4*n + 1 as input + # 25 is the smallest length supported by the model. Smaller values generate tensors with dimension zero/negative + num_frames = max(25, num_frames) + + c2ws = camera.get("poses") + + len_c2ws = ((len(c2ws) - 1) // 4) * 4 + 1 + num_frames = ((num_frames - 1) // 4) * 4 + 1 + num_frames = min(num_frames, len_c2ws) + c2ws = c2ws[:num_frames] + + # preprocess + img = multi_modal_data.get("image") + img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(self.device) + + max_area = CONFIG["max_area"] + chunk_size = CONFIG["chunk_size"] + + h, w = img.shape[1:] + aspect_ratio = h / w + lat_h = round(np.sqrt(max_area * aspect_ratio) // self.vae_stride[1] // self.patch_size[1] * self.patch_size[1]) + lat_w = round(np.sqrt(max_area / aspect_ratio) // self.vae_stride[2] // self.patch_size[2] * self.patch_size[2]) + h = lat_h * self.vae_stride[1] + w = lat_w * self.vae_stride[2] + lat_f = (num_frames - 1) // self.vae_stride[0] + 1 + lat_f = int(lat_f - (lat_f % chunk_size)) + lat_f = max(lat_f, 1) + F = (lat_f - 1) * 4 + 1 + max_seq_len = chunk_size * lat_h * lat_w // (self.patch_size[1] * self.patch_size[2]) + max_seq_len = int(math.ceil(max_seq_len / self.sp_size)) * self.sp_size + seed = random.randint(0, sys.maxsize) + seed_g = torch.Generator(device=self.device) + seed_g.manual_seed(seed) + noise = torch.randn(16, lat_f, lat_h, lat_w, dtype=torch.float32, generator=seed_g, device=self.device) + + msk = torch.ones(1, F, lat_h, lat_w, device=self.device) + msk[:, 1:] = 0 + msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1) + msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w) + msk = msk.transpose(1, 2)[0] + + # 2. Prepare timesteps + self.scheduler.set_timesteps(self.num_train_timesteps, shift=CONFIG["sample_shift"]) + timesteps = self.scheduler.timesteps[CONFIG["timesteps_index"]] + + context = self.text_encoder([prompt], torch.device("cpu")) + context = [t.to(self.device) for t in context] + + dit_cond_dict = None + Ks = torch.from_numpy(camera.get("intrinsics")) + + # Transform the provided intrinsics from the original 480p according to the new image size (h, w). + Ks = get_Ks_transformed( + Ks, height_org=480, width_org=832, height_resize=h, width_resize=w, height_final=h, width_final=w + ) + Ks = Ks[0] + + len_c2ws = len(c2ws) + len_c2ws_ = int((len_c2ws - 1) // 4) + 1 + len_c2ws_ = int(len_c2ws_ - (len_c2ws_ % chunk_size)) + c2ws_infer = interpolate_camera_poses( + src_indices=np.linspace(0, len_c2ws - 1, len_c2ws), + src_rot_mat=c2ws[:, :3, :3], + src_trans_vec=c2ws[:, :3, 3], + tgt_indices=np.linspace(0, len_c2ws - 1, len_c2ws_), + ) + c2ws_infer = compute_relative_poses(c2ws_infer, framewise=True) + Ks = Ks.repeat(len(c2ws_infer), 1) + + c2ws_infer = c2ws_infer.to(self.device).to(torch.float32) + Ks = Ks.to(self.device).to(torch.float32) + only_rays_d = False + c2ws_plucker_emb = get_plucker_embeddings(c2ws_infer, Ks, h, w, only_rays_d=only_rays_d) + c2ws_plucker_emb = rearrange( + c2ws_plucker_emb, + "f (h c1) (w c2) c -> (f h w) (c c1 c2)", + c1=int(h // lat_h), + c2=int(w // lat_w), + ) + c2ws_plucker_emb = c2ws_plucker_emb[None, ...] # [b, f*h*w, c] + c2ws_plucker_emb = rearrange(c2ws_plucker_emb, "b (f h w) c -> b c f h w", f=lat_f, h=lat_h, w=lat_w).to( + self.target_dtype + ) + + y = self.vae.encode( + [ + torch.concat( + [ + torch.nn.functional.interpolate(img[None].cpu(), size=(h, w), mode="bicubic").transpose(0, 1), + torch.zeros(3, F - 1, h, w), + ], + dim=1, + ).to(self.device) + ] + )[0] + y = torch.concat([msk, y]) + + @contextmanager + def noop_no_sync(): + yield + + no_sync_model = getattr(self.model, "no_sync", noop_no_sync) + + # Initialize KV cache to all zeros + model_args = self.model.config + transformer_dtype = self.target_dtype + frame_seqlen = int(noise.shape[-2] * noise.shape[-1] // 4) + kv_size = frame_seqlen * lat_f + head_dim = model_args.dim // model_args.num_heads + local_num_heads = model_args.num_heads // self.sp_size + + self.state.create_kv_caches( + batch_size, transformer_dtype, self.device, kv_size, model_args.num_layers, local_num_heads, head_dim + ) + + # evaluation mode + with ( + torch.amp.autocast("cuda", dtype=self.target_dtype), + torch.no_grad(), + no_sync_model(), + ): + # sample videos + latent = noise + latents_chunk = latent.split(chunk_size, dim=1) # [c, f, h, w] + condition_chunk = y.split(chunk_size, dim=1) + c2ws_plucker_emb_chunk = c2ws_plucker_emb.split(chunk_size, dim=2) + num_inference_chunk = len(latents_chunk) + pred_latent_chunks = [] + for chunk_id in tqdm(range(num_inference_chunk)): + current_latent = latents_chunk[chunk_id] + current_condition = condition_chunk[chunk_id] + current_c2ws_plucker_emb = c2ws_plucker_emb_chunk[chunk_id] + + dit_cond_dict = { + "c2ws_plucker_emb": current_c2ws_plucker_emb.chunk(1, dim=0), + } + + kwargs = { + "context": [context[0]], + "seq_len": max_seq_len, + "y": [current_condition], + "dit_cond_dict": dit_cond_dict, + "kv_cache": self.state.get_kv_caches(), + "local_end_index": self.state.local_end_index, + "global_end_index": self.state.global_end_index, + "crossattn_cache": self.state.get_crossattn_caches(), + "current_start": chunk_id * chunk_size * frame_seqlen, + "max_attention_size": kv_size, + } + + for timestep_idx in range(len(timesteps)): + latent_model_input = [current_latent.to(self.device)] + current_timestep = [timesteps[timestep_idx]] + + timestep = torch.stack(current_timestep).to(self.device) + + noise_pred = self.model(x=latent_model_input, t=timestep, **kwargs)[0] + + x0 = self._convert_flow_pred_to_x0( + flow_pred=noise_pred, + xt=current_latent, + timestep=current_timestep[0], + scheduler=self.scheduler, + ) + + if timestep_idx < len(timesteps) - 1: + next_timestep = timesteps[timestep_idx + 1] + current_latent = self.scheduler.add_noise( + x0, torch.randn(x0.shape, generator=seed_g, device=x0.device, dtype=x0.dtype), next_timestep + ) + else: + # note return x0 + break + + pred_latent_chunks.append(x0) + + # Update kv cache + context_timestep = [timesteps[-1] * 0.0] + timestep = torch.stack(context_timestep).to(self.device) + self.model(x=[x0], t=timestep, **kwargs) + + pred_latent_chunks = torch.cat(pred_latent_chunks, dim=1) + + if self.device.index == 0: + videos = self.vae.decode([pred_latent_chunks]) + + if dist.is_initialized(): + dist.barrier() + + return DiffusionOutput(output=videos[0]) + + def load_weights(self, weights): + pass diff --git a/vllm_omni/diffusion/models/lingbot_world_fast/state_lingbot_world_fast.py b/vllm_omni/diffusion/models/lingbot_world_fast/state_lingbot_world_fast.py new file mode 100644 index 00000000000..4b62b190ed5 --- /dev/null +++ b/vllm_omni/diffusion/models/lingbot_world_fast/state_lingbot_world_fast.py @@ -0,0 +1,107 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""Lingbot World Fast pipeline persistent state.""" + +from __future__ import annotations + +import logging +from enum import IntEnum + +import numpy as np +import torch + +logger = logging.getLogger(__name__) + + +class CacheIndex(IntEnum): + K = 0 + V = 1 + + +class LingbotWorldFastState: + """Pipeline persistent state across forward() calls. + + Lifecycle: + - Created once in LingbotWorldFastPipeline.__init__() + - Mutated every forward() call (frame append, KV cache grow) + - reset() on new session / local_attn_size exceeded + """ + + def __init__(self) -> None: + self.reset() + + # ------------------------------------------------------------------ + # Reset / should_reset + # ------------------------------------------------------------------ + + def reset(self) -> None: + """Clear all state.""" + self.kv_cache: list[torch.Tensor] | None = None + self.crossattn_cache: list[dict[str, bool | torch.Tensor | None]] | None = None + self.current_start_frame: int = 0 + self.local_end_index: list[torch.Tensor] | None = None + self.global_end_index: list[torch.Tensor] | None = None + + def should_reset(self, text_tokens: torch.Tensor | None, num_video_frames: int, local_attn_size: int) -> bool: + """Determine if state should be reset before this forward().""" + # NOTE: after accumulate_frames, num_video_frames is the accumulated T + # (1 for first call, 4 for subsequent). Only reset on true single-frame + # which happens when the stitched_buffer was cleared externally. + if num_video_frames == 1 and self.call_count > 1: + logger.info("single frame input after first call, resetting") + return True + + if local_attn_size != -1 and self.current_start_frame >= local_attn_size: + logger.info( + "current_start_frame %d >= local_attn_size %d, resetting", self.current_start_frame, local_attn_size + ) + return True + + return False + + # ------------------------------------------------------------------ + # KV cache management + # ------------------------------------------------------------------ + + def create_kv_caches( + self, + batch_size: int, + dtype: torch.dtype, + device: torch.device, + kv_size: int, + num_layers: int, + num_heads: int, + head_dim: int, + ) -> None: + """Initialize empty KV caches and cross-attention caches.""" + self.kv_cache = [ + torch.zeros(2, batch_size, kv_size, num_heads, head_dim, dtype=dtype, device=device) + for _ in range(num_layers) + ] + + self.local_end_index = [torch.tensor([0], dtype=torch.long, device=device) for _ in range(num_layers)] + self.global_end_index = [torch.tensor([0], dtype=torch.long, device=device) for _ in range(num_layers)] + + self.crossattn_cache = [{"is_init": False, "k": None, "v": None} for _ in range(num_layers)] + + def update_kv_cache( + self, + layer_index: int, + updated_kv: torch.Tensor, + is_negative: bool = False, + ) -> None: + """Update a single layer's KV cache after prefill.""" + cache = self.kv_cache_neg if is_negative else self.kv_cache + assert cache is not None, "KV caches not initialized, call create_kv_caches first" + cache[layer_index] = updated_kv.clone() + + def get_kv_caches(self) -> list[torch.Tensor]: + """Get KV caches for the specified branch.""" + assert self.kv_cache is not None, "KV caches not initialized" + return self.kv_cache + + def get_crossattn_caches(self, is_negative: bool = False) -> list[dict[str, bool | torch.Tensor | None]]: + """Get cross-attention caches for the specified branch.""" + assert self.crossattn_cache is not None, "Cross-attn caches not initialized" + return self.crossattn_cache diff --git a/vllm_omni/diffusion/models/lingbot_world_fast/wan_fast.py b/vllm_omni/diffusion/models/lingbot_world_fast/wan_fast.py new file mode 100644 index 00000000000..dc7d7961553 --- /dev/null +++ b/vllm_omni/diffusion/models/lingbot_world_fast/wan_fast.py @@ -0,0 +1,653 @@ +"""Some of the functions are borrowed from SelfForcing (https://github.com/guandeh17/Self-Forcing).""" + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as torch_F +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.modeling_utils import ModelMixin +from einops import rearrange +from wan.modules.model import WanLayerNorm, WanRMSNorm, WanSelfAttention, rope_params, sinusoidal_embedding_1d + +from vllm_omni.diffusion.attention.layer import Attention + +from .state_lingbot_world_fast import CacheIndex + + +def causal_rope_apply(x, grid_sizes, freqs, start_frame=0): + n, c = x.size(2), x.size(3) // 2 + + # split freqs + freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1) + + # loop over samples + output = [] + + for i, (f, h, w) in enumerate(grid_sizes.tolist()): + seq_len = f * h * w + + # precompute multipliers + x_i = torch.view_as_complex(x[i, :seq_len].to(torch.float64).reshape(seq_len, n, -1, 2)) + freqs_i = torch.cat( + [ + freqs[0][start_frame : start_frame + f].view(f, 1, 1, -1).expand(f, h, w, -1), + freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), + freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1), + ], + dim=-1, + ).reshape(seq_len, 1, -1) + + # apply rotary embedding + x_i = torch.view_as_real(x_i * freqs_i).flatten(2) + x_i = torch.cat([x_i, x[i, seq_len:]]) + + # append to collection + output.append(x_i) + return torch.stack(output).type_as(x) + + +class CausalWanSelfAttention(nn.Module): + def __init__(self, dim, num_heads, local_attn_size=-1, sink_size=0, qk_norm=True, eps=1e-6): + assert dim % num_heads == 0 + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.local_attn_size = local_attn_size + self.sink_size = sink_size + self.qk_norm = qk_norm + self.eps = eps + + # layers + self.q = nn.Linear(dim, dim) + self.k = nn.Linear(dim, dim) + self.v = nn.Linear(dim, dim) + self.o = nn.Linear(dim, dim) + self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity() + self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity() + + self.attn = Attention( + num_heads=self.num_heads, + head_size=self.head_dim, + num_kv_heads=self.num_heads, + softmax_scale=1.0 / (self.head_dim**0.5), + causal=False, + ) + + def forward( + self, + x, + seq_lens, + grid_sizes, + freqs, + kv_cache=None, + local_end_index=None, + global_end_index=None, + current_start=0, + max_attention_size=1_000_000, + ): + r""" + Args: + x(Tensor): Shape [B, L, num_heads, C / num_heads] + grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W) + freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2] + block_mask (BlockMask) + """ + b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim + + # query, key, value function + def qkv_fn(x): + q = self.norm_q(self.q(x)).view(b, s, n, d) + k = self.norm_k(self.k(x)).view(b, s, n, d) + v = self.v(x).view(b, s, n, d) + return q, k, v + + q, k, v = qkv_fn(x) + + frame_seqlen = math.prod(grid_sizes[0][1:]).item() + current_start_frame = current_start // frame_seqlen + roped_query = causal_rope_apply(q, grid_sizes, freqs, start_frame=current_start_frame).type_as(v) + roped_key = causal_rope_apply(k, grid_sizes, freqs, start_frame=current_start_frame).type_as(v) + current_end = current_start + roped_query.shape[1] + sink_tokens = self.sink_size * frame_seqlen + # If we are using local attention and the current KV cache size is larger than the local attention size, + # then we need to truncate the KV cache + kv_cache_size = kv_cache[CacheIndex.K].shape[1] + num_new_tokens = roped_query.shape[1] + if ( + self.local_attn_size != -1 + and (current_end > global_end_index.item()) + and (num_new_tokens + local_end_index.item() > kv_cache_size) + ): + # Calculate the number of new tokens added in this step + # Shift existing cache content left to discard oldest tokens + # Clone the source slice to avoid overlapping memory error + num_evicted_tokens = num_new_tokens + local_end_index.item() - kv_cache_size + num_rolled_tokens = local_end_index.item() - num_evicted_tokens - sink_tokens + kv_cache[CacheIndex.K][:, sink_tokens : sink_tokens + num_rolled_tokens] = kv_cache[CacheIndex.K][ + :, sink_tokens + num_evicted_tokens : sink_tokens + num_evicted_tokens + num_rolled_tokens + ].clone() + kv_cache[CacheIndex.V][:, sink_tokens : sink_tokens + num_rolled_tokens] = kv_cache[CacheIndex.V][ + :, sink_tokens + num_evicted_tokens : sink_tokens + num_evicted_tokens + num_rolled_tokens + ].clone() + # Insert the new keys/values at the end + new_local_end_index = local_end_index.item() + current_end - global_end_index.item() - num_evicted_tokens + local_start_index = new_local_end_index - num_new_tokens + kv_cache[CacheIndex.K][:, local_start_index:new_local_end_index] = roped_key + kv_cache[CacheIndex.V][:, local_start_index:new_local_end_index] = v + else: + # Assign new keys/values directly up to current_end + new_local_end_index = local_end_index.item() + current_end - global_end_index.item() + local_start_index = new_local_end_index - num_new_tokens + kv_cache[CacheIndex.K][:, local_start_index:new_local_end_index] = roped_key + kv_cache[CacheIndex.V][:, local_start_index:new_local_end_index] = v + + k_cache = kv_cache[CacheIndex.K][:, max(0, new_local_end_index - max_attention_size) : new_local_end_index] + v_cache = kv_cache[CacheIndex.V][:, max(0, new_local_end_index - max_attention_size) : new_local_end_index] + x = self.attn(roped_query, k_cache, v_cache) + + global_end_index.fill_(current_end) + local_end_index.fill_(new_local_end_index) + + # output + x = x.flatten(2) + x = self.o(x) + return x + + +class WanCrossAttention(WanSelfAttention): + def __init__(self, dim, num_heads, window_size=(-1, -1), qk_norm=True, eps=1e-6): + super().__init__(dim, num_heads, window_size, qk_norm, eps) + + self.attn = Attention( + num_heads=self.num_heads, + head_size=self.head_dim, + num_kv_heads=self.num_heads, + softmax_scale=1.0 / (self.head_dim**0.5), + causal=False, + ) + + def forward(self, x, context, context_lens, crossattn_cache=None): + r""" + Args: + x(Tensor): Shape [B, L1, C] + context(Tensor): Shape [B, L2, C] + context_lens(Tensor): Shape [B] + """ + b, n, d = x.size(0), self.num_heads, self.head_dim + + # compute query, key, value + q = self.norm_q(self.q(x)).view(b, -1, n, d) + + if crossattn_cache is not None: + if not crossattn_cache.get("is_init", False): + crossattn_cache["is_init"] = True + k = self.norm_k(self.k(context)).view(b, -1, n, d) + v = self.v(context).view(b, -1, n, d) + crossattn_cache[CacheIndex.K] = k + crossattn_cache[CacheIndex.V] = v + else: + k = crossattn_cache[CacheIndex.K] + v = crossattn_cache[CacheIndex.V] + else: + k = self.norm_k(self.k(context)).view(b, -1, n, d) + v = self.v(context).view(b, -1, n, d) + + # compute attention + x = self.attn(q, k, v) + + # output + x = x.flatten(2) + x = self.o(x) + return x + + +class CausalWanAttentionBlock(nn.Module): + def __init__( + self, dim, ffn_dim, num_heads, local_attn_size=-1, sink_size=0, qk_norm=True, cross_attn_norm=False, eps=1e-6 + ): + super().__init__() + self.dim = dim + self.ffn_dim = ffn_dim + self.num_heads = num_heads + self.local_attn_size = local_attn_size + self.qk_norm = qk_norm + self.cross_attn_norm = cross_attn_norm + self.eps = eps + + # layers + self.norm1 = WanLayerNorm(dim, eps) + self.self_attn = CausalWanSelfAttention( + dim=dim, num_heads=num_heads, local_attn_size=local_attn_size, sink_size=sink_size, qk_norm=qk_norm, eps=eps + ) + self.norm3 = WanLayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity() + self.cross_attn = WanCrossAttention(dim, num_heads, (-1, -1), qk_norm, eps) + self.norm2 = WanLayerNorm(dim, eps) + self.ffn = nn.Sequential(nn.Linear(dim, ffn_dim), nn.GELU(approximate="tanh"), nn.Linear(ffn_dim, dim)) + + # modulation + self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5) + + self.cam_injector_layer1 = nn.Linear(dim, dim) + self.cam_injector_layer2 = nn.Linear(dim, dim) + self.cam_scale_layer = nn.Linear(dim, dim) + self.cam_shift_layer = nn.Linear(dim, dim) + + def forward( + self, + x, + e, + seq_lens, + grid_sizes, + freqs, + context, + context_lens, + dit_cond_dict=None, + kv_cache=None, + local_end_index=None, + global_end_index=None, + crossattn_cache=None, + current_start=0, + max_attention_size=1_000_000, + ): + r""" + Args: + x(Tensor): Shape [B, L, C] + e(Tensor): Shape [B, F, 6, C] + grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W) + freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2] + """ + assert e.dtype == torch.float32 + with torch.amp.autocast("cuda", dtype=torch.float32): + e = (self.modulation.unsqueeze(0) + e).chunk(6, dim=2) + assert e[0].dtype == torch.float32 + # self-attention + y = self.self_attn( + self.norm1(x).float() * (1 + e[1].squeeze(2)) + e[0].squeeze(2), + seq_lens, + grid_sizes, + freqs, + kv_cache, + local_end_index, + global_end_index, + current_start, + max_attention_size, + ) + with torch.amp.autocast("cuda", dtype=torch.float32): + x = x + y * e[2].squeeze(2) + + # cam injection (only if dit_cond_dict is provided and contains c2ws_plucker_emb) + if dit_cond_dict is not None and "c2ws_plucker_emb" in dit_cond_dict: + c2ws_plucker_emb = dit_cond_dict["c2ws_plucker_emb"] + c2ws_hidden_states = self.cam_injector_layer2(torch_F.silu(self.cam_injector_layer1(c2ws_plucker_emb))) + c2ws_hidden_states = c2ws_hidden_states + c2ws_plucker_emb + cam_scale = self.cam_scale_layer(c2ws_hidden_states) + cam_shift = self.cam_shift_layer(c2ws_hidden_states) + x = (1.0 + cam_scale) * x + cam_shift + + # cross-attention & ffn function + def cross_attn_ffn(x, context, context_lens, e, crossattn_cache=None): + x = x + self.cross_attn(self.norm3(x), context, context_lens, crossattn_cache=crossattn_cache) + y = self.ffn(self.norm2(x).float() * (1 + e[4].squeeze(2)) + e[3].squeeze(2)) + with torch.amp.autocast("cuda", dtype=torch.float32): + x = x + y * e[5].squeeze(2) + return x + + x = cross_attn_ffn(x, context, context_lens, e, crossattn_cache) + return x + + +class CausalHead(nn.Module): + def __init__(self, dim, out_dim, patch_size, eps=1e-6): + super().__init__() + self.dim = dim + self.out_dim = out_dim + self.patch_size = patch_size + self.eps = eps + + # layers + out_dim = math.prod(patch_size) * out_dim + self.norm = WanLayerNorm(dim, eps) + self.head = nn.Linear(dim, out_dim) + + # modulation + self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5) + + def forward(self, x, e): + r""" + Args: + x(Tensor): Shape [B, L1, C] + e(Tensor): Shape [B, L1, C] + """ + assert e.dtype == torch.float32 + with torch.amp.autocast("cuda", dtype=torch.float32): + e = (self.modulation.unsqueeze(0) + e.unsqueeze(2)).chunk(2, dim=2) + x = self.head(self.norm(x) * (1 + e[1].squeeze(2)) + e[0].squeeze(2)) + return x + + +class WanModelFast(ModelMixin, ConfigMixin): + r""" + Wan diffusion backbone supporting both text-to-video and image-to-video. + """ + + ignore_for_config = ["patch_size", "cross_attn_norm", "qk_norm", "text_dim"] + _no_split_modules = ["WanAttentionBlock"] + + @register_to_config + def __init__( + self, + model_type="t2v", + control_type="cam", + patch_size=(1, 2, 2), + text_len=512, + in_dim=16, + dim=2048, + ffn_dim=8192, + freq_dim=256, + text_dim=4096, + out_dim=16, + num_heads=16, + num_layers=32, + local_attn_size=-1, + sink_size=0, + qk_norm=True, + cross_attn_norm=True, + eps=1e-6, + ): + r""" + Initialize the diffusion model backbone. + + Args: + model_type (`str`, *optional*, defaults to 't2v'): + Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video) + control_type (`str`, *optional*, defaults to 'cam'): + Type of conditioning control signal - 'cam' (6-dim camera Plucker + embeddings) or 'act' (7-dim action embeddings including WASD movement) + patch_size (`tuple`, *optional*, defaults to (1, 2, 2)): + 3D patch dimensions for video embedding (t_patch, h_patch, w_patch) + text_len (`int`, *optional*, defaults to 512): + Fixed length for text embeddings + in_dim (`int`, *optional*, defaults to 16): + Input video channels (C_in) + dim (`int`, *optional*, defaults to 2048): + Hidden dimension of the transformer + ffn_dim (`int`, *optional*, defaults to 8192): + Intermediate dimension in feed-forward network + freq_dim (`int`, *optional*, defaults to 256): + Dimension for sinusoidal time embeddings + text_dim (`int`, *optional*, defaults to 4096): + Input dimension for text embeddings + out_dim (`int`, *optional*, defaults to 16): + Output video channels (C_out) + num_heads (`int`, *optional*, defaults to 16): + Number of attention heads + num_layers (`int`, *optional*, defaults to 32): + Number of transformer blocks + local_attn_size (`int`, *optional*, defaults to -1): + Window size for temporal local attention (-1 indicates global attention) + sink_size (`int`, *optional*, defaults to 0): + Size of the attention sink, we keep the first `sink_size` frames unchanged when rolling the KV cache + qk_norm (`bool`, *optional*, defaults to True): + Enable query/key normalization + cross_attn_norm (`bool`, *optional*, defaults to False): + Enable cross-attention normalization + eps (`float`, *optional*, defaults to 1e-6): + Epsilon value for normalization layers + """ + + super().__init__() + + assert model_type in ["t2v", "i2v"] + self.model_type = model_type + + self.patch_size = patch_size + self.text_len = text_len + self.in_dim = in_dim + self.dim = dim + self.ffn_dim = ffn_dim + self.freq_dim = freq_dim + self.text_dim = text_dim + self.out_dim = out_dim + self.num_heads = num_heads + self.num_layers = num_layers + self.local_attn_size = local_attn_size + self.qk_norm = qk_norm + self.cross_attn_norm = cross_attn_norm + self.eps = eps + + if control_type == "cam": + control_dim = 6 + elif control_type == "act": + control_dim = 7 + + # embeddings + self.patch_embedding = nn.Conv3d(in_dim, dim, kernel_size=patch_size, stride=patch_size) + + self.patch_embedding_wancamctrl = nn.Linear( + control_dim * 64 * patch_size[0] * patch_size[1] * patch_size[2], dim + ) + self.c2ws_hidden_states_layer1 = nn.Linear(dim, dim) + self.c2ws_hidden_states_layer2 = nn.Linear(dim, dim) + + self.text_embedding = nn.Sequential(nn.Linear(text_dim, dim), nn.GELU(approximate="tanh"), nn.Linear(dim, dim)) + + self.time_embedding = nn.Sequential(nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim)) + self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6)) + + # blocks + self.blocks = nn.ModuleList( + [ + CausalWanAttentionBlock( + dim, ffn_dim, num_heads, local_attn_size, sink_size, qk_norm, cross_attn_norm, eps + ) + for _ in range(num_layers) + ] + ) + + # head + self.head = CausalHead(dim, out_dim, patch_size, eps) + + # buffers (don't use register_buffer otherwise dtype will be changed in to()) + assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0 + d = dim // num_heads + self.freqs = torch.cat( + [rope_params(1024, d - 4 * (d // 6)), rope_params(1024, 2 * (d // 6)), rope_params(1024, 2 * (d // 6))], + dim=1, + ) + + # initialize weights + self.init_weights() + + def forward( + self, + x, + t, + context, + seq_len, + y=None, + dit_cond_dict=None, + kv_cache=None, + local_end_index=None, + global_end_index=None, + crossattn_cache=None, + current_start=0, + max_attention_size=1_000_000, + ): + r""" + Run the diffusion model with kv caching. + See Algorithm 2 of CausVid paper https://arxiv.org/abs/2412.07772 for details. + This function will be run for num_frame times. + Process the latent frames one by one (1560 tokens each) + + Args: + x (List[Tensor]): + List of input video tensors, each with shape [C_in, F, H, W] + t (Tensor): + Diffusion timesteps tensor of shape [B] + context (List[Tensor]): + List of text embeddings each with shape [L, C] + seq_len (`int`): + Maximum sequence length for positional encoding + y (List[Tensor], *optional*): + Conditional video inputs for image-to-video mode, same shape as x + dit_cond_dict (`dict`, *optional*, defaults to None): + Dictionary of conditioning signals. May contain key ``c2ws_plucker_emb`` + with camera Plucker embeddings of shape [B, C, F, H, W] for camera control. + kv_cache (`list[dict]`, *optional*, defaults to None): + Per-layer self-attention KV cache. Each dict contains keys ``k``, ``v`` + (Tensor of shape [B, kv_size, num_heads, head_dim]), ``global_end_index``, + and ``local_end_index`` (scalar Tensors tracking cache position). + crossattn_cache (`list[dict]`, *optional*, defaults to None): + Per-layer cross-attention KV cache. Each dict contains keys ``k``, ``v`` + (Tensor of shape [B, text_len, num_heads, head_dim]) and ``is_init`` (bool). + current_start (`int`, *optional*, defaults to 0): + Token offset of the current chunk in the full sequence. Used to index + into the KV cache and compute positional embeddings correctly. + max_attention_size (`int`, *optional*, defaults to 1_000_000): + Maximum number of KV tokens each query can attend to. Limits the + effective context window of self-attention to control memory usage. + + Returns: + List[Tensor]: + List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8] + """ + + if self.model_type == "i2v": + assert y is not None + + # params + device = self.patch_embedding.weight.device + if self.freqs.device != device: + self.freqs = self.freqs.to(device) + + if y is not None: + x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)] + + # embeddings + x = [self.patch_embedding(u.unsqueeze(0)) for u in x] + grid_sizes = torch.stack([torch.tensor(u.shape[2:], dtype=torch.long) for u in x]) + x = [u.flatten(2).transpose(1, 2) for u in x] + seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long) + assert seq_lens.max() <= seq_len + x = torch.cat(x) + + # time embeddings + if t.dim() == 1: + t = t.expand(t.size(0), seq_lens) + with torch.amp.autocast("cuda", dtype=torch.float32): + bt = t.size(0) + t = t.flatten() + e = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, t).unflatten(0, (bt, seq_lens)).float()) + e0 = self.time_projection(e).unflatten(2, (6, self.dim)) + assert e.dtype == torch.float32 and e0.dtype == torch.float32 + + # context + context_lens = None + context = self.text_embedding( + torch.stack([torch.cat([u, u.new_zeros(self.text_len - u.size(0), u.size(1))]) for u in context]) + ) + + # cam + if dit_cond_dict is not None and "c2ws_plucker_emb" in dit_cond_dict: + c2ws_plucker_emb = dit_cond_dict["c2ws_plucker_emb"] + c2ws_plucker_emb = [ + rearrange( + i, + "1 c (f c1) (h c2) (w c3) -> 1 (f h w) (c c1 c2 c3)", + c1=self.patch_size[0], + c2=self.patch_size[1], + c3=self.patch_size[2], + ) + for i in c2ws_plucker_emb + ] + c2ws_plucker_emb = torch.cat(c2ws_plucker_emb, dim=1) # [1, (L1+...+Ln), C] + + c2ws_plucker_emb = self.patch_embedding_wancamctrl(c2ws_plucker_emb) + c2ws_hidden_states = self.c2ws_hidden_states_layer2( + torch_F.silu(self.c2ws_hidden_states_layer1(c2ws_plucker_emb)) + ) + dit_cond_dict = dict(dit_cond_dict) + dit_cond_dict["c2ws_plucker_emb"] = c2ws_plucker_emb + c2ws_hidden_states + + # arguments + kwargs = dict( + e=e0, + seq_lens=seq_lens, + grid_sizes=grid_sizes, + freqs=self.freqs, + context=context, + context_lens=context_lens, + dit_cond_dict=dit_cond_dict, + max_attention_size=max_attention_size, + ) + + for block_index, block in enumerate(self.blocks): + kwargs.update( + { + "kv_cache": kv_cache[block_index], + "crossattn_cache": crossattn_cache[block_index], + "local_end_index": local_end_index[block_index], + "global_end_index": global_end_index[block_index], + "current_start": current_start, + } + ) + x = block(x, **kwargs) + + # head + x = self.head(x, e) + + # unpatchify + x = self.unpatchify(x, grid_sizes) + + return [u.float() for u in x] + + def unpatchify(self, x, grid_sizes): + r""" + Reconstruct video tensors from patch embeddings. + + Args: + x (List[Tensor]): + List of patchified features, each with shape [L, C_out * prod(patch_size)] + grid_sizes (Tensor): + Original spatial-temporal grid dimensions before patching, + shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches) + + Returns: + List[Tensor]: + Reconstructed video tensors with shape [C_out, F, H / 8, W / 8] + """ + + c = self.out_dim + out = [] + for u, v in zip(x, grid_sizes.tolist()): + u = u[: math.prod(v)].view(*v, *self.patch_size, c) + u = torch.einsum("fhwpqrc->cfphqwr", u) + u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)]) + out.append(u) + return out + + def init_weights(self): + r""" + Initialize model parameters using Xavier initialization. + """ + + # basic init + for m in self.modules(): + if isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.zeros_(m.bias) + + # init embeddings + nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1)) + for m in self.text_embedding.modules(): + if isinstance(m, nn.Linear): + nn.init.normal_(m.weight, std=0.02) + for m in self.time_embedding.modules(): + if isinstance(m, nn.Linear): + nn.init.normal_(m.weight, std=0.02) + + # init output layer + nn.init.zeros_(self.head.head.weight) diff --git a/vllm_omni/diffusion/registry.py b/vllm_omni/diffusion/registry.py index b6a5a3700e8..0337c37a47f 100644 --- a/vllm_omni/diffusion/registry.py +++ b/vllm_omni/diffusion/registry.py @@ -261,6 +261,11 @@ "pipeline_diffusers_adapter", "DiffusersAdapterPipeline", ), + "LingbotWorldFastPipeline": ( + "lingbot_world_fast", + "pipeline_lingbot_world_fast", + "LingbotWorldFastPipeline", + ), } @@ -482,6 +487,7 @@ def _apply_sequence_parallel_if_enabled(model, od_config: OmniDiffusionConfig) - "OmniVoicePipeline": "get_omnivoice_post_process_func", "DreamIDOmniPipeline": "get_dreamid_omni_post_process_func", "SenseNovaU1Pipeline": "get_sensenova_u1_post_process_func", + "LingbotWorldFastPipeline": "get_lingbot_world_fast_post_process_func", } _DIFFUSION_PRE_PROCESS_FUNCS = { diff --git a/vllm_omni/entrypoints/cli/serve.py b/vllm_omni/entrypoints/cli/serve.py index f2689d4b50f..3e8e6ccc04f 100644 --- a/vllm_omni/entrypoints/cli/serve.py +++ b/vllm_omni/entrypoints/cli/serve.py @@ -544,6 +544,19 @@ def subparser_init(self, subparsers: argparse._SubParsersAction) -> FlexibleArgu action="store_true", help="Enable AR stage profiler to include AR stage timing in stage_durations.", ) + + omni_config_group.add_argument( + "--ws-max-size", + type=int, + default=1_048_576, # 1MB + help="Change max size of a websocket payload that is accepted by the server", + ) + omni_config_group.add_argument( + "ws", + default="auto", + help="Set the websocket Protocol type", + ) + # Stash via type(self) so the docs hook (which execs this function in a # sandboxed globals dict via ``DummySelf``) doesn't fail on a NameError. type(self)._parser = serve_parser diff --git a/vllm_omni/entrypoints/openai/api_server.py b/vllm_omni/entrypoints/openai/api_server.py index f00e6448a16..16b3367a16b 100644 --- a/vllm_omni/entrypoints/openai/api_server.py +++ b/vllm_omni/entrypoints/openai/api_server.py @@ -117,6 +117,7 @@ ) from vllm_omni.entrypoints.openai.realtime_connection import RealtimeConnection from vllm_omni.entrypoints.openai.serving_audio_generate import OmniOpenAIServingAudioGenerate +from vllm_omni.entrypoints.openai.realtime.world.camera_serving import ServingRealtimeWorldCamera from vllm_omni.entrypoints.openai.serving_chat import OmniOpenAIServingChat from vllm_omni.entrypoints.openai.serving_speech import OmniOpenAIServingSpeech from vllm_omni.entrypoints.openai.serving_speech_stream import OmniStreamingSpeechHandler @@ -377,6 +378,11 @@ async def omni_run_server_worker(listen_address, sock, args, client_config=None, if log_config is not None: uvicorn_kwargs["log_config"] = log_config + if args.ws_max_size is not None: + uvicorn_kwargs["ws_max_size"] = args.ws_max_size + if args.ws is not None: + uvicorn_kwargs["ws"] = args.ws + async with build_async_omni( args, client_config=client_config, @@ -630,6 +636,10 @@ async def omni_init_app_state( state.openai_streaming_speech = None state.openai_streaming_video = None + state.openai_serving_world_camera = ServingRealtimeWorldCamera.create_policy_server( + engine_client=engine_client, model_name=model_name + ) + state.enable_server_load_tracking = getattr(args, "enable_server_load_tracking", False) state.server_load_metrics = 0 logger.info("Pure diffusion API server initialized for model: %s", model_name) @@ -947,6 +957,10 @@ async def omni_init_app_state( stage_configs=state.stage_configs, ) + state.openai_serving_world_camera = ServingRealtimeWorldCamera.create_policy_server( + engine_client=engine_client, model_name=model_name + ) + state.enable_server_load_tracking = args.enable_server_load_tracking state.server_load_metrics = 0 state.sleeping_stages = set() @@ -1405,6 +1419,19 @@ async def realtime_websocket(websocket: WebSocket): connection = RealtimeConnection(websocket, serving) await connection.handle_connection() +@router.websocket("/v1/realtime/world/camera") +async def realtime_world_camera_openpi(websocket: WebSocket): + from vllm_omni.entrypoints.openai.realtime.world.camera_connection import WorldCameraRealtimeConnection + + serving = getattr(websocket.app.state, "openai_serving_world_camera", None) + + if serving is None: + await websocket.accept() + await websocket.send_json({"type": "error", "error": "World Model policy not available", "code": "unsupported"}) + await websocket.close() + return + connection = WorldCameraRealtimeConnection(websocket, serving) + await connection.handle_connection() # Health and Model endpoints for diffusion mode diff --git a/vllm_omni/entrypoints/openai/realtime/world/camera_connection.py b/vllm_omni/entrypoints/openai/realtime/world/camera_connection.py new file mode 100644 index 00000000000..7d4a500399c --- /dev/null +++ b/vllm_omni/entrypoints/openai/realtime/world/camera_connection.py @@ -0,0 +1,165 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""WebSocket connection for robot policy inference (OpenPI protocol). + +Protocol (compatible with DreamZero test_client_AR.py): + Connect -> server sends msgpack(PolicyServerConfig fields) + Infer -> client sends msgpack(obs), server sends msgpack(ndarray) + Reset -> client sends msgpack({endpoint:reset}), server sends "reset successful" +""" + +from __future__ import annotations + +import asyncio +from typing import Any + +import torch +from fastapi import WebSocket +from starlette.websockets import WebSocketDisconnect +from vllm.logger import init_logger + +from vllm_omni.entrypoints.openai.realtime.world.camera_serving import ServingRealtimeWorldCamera +from vllm_omni.entrypoints.openai.video_api_utils import _normalize_frames + +logger = init_logger(__name__) +_DEFAULT_IDLE_TIMEOUT = 30.0 + + +def _get_msgpack_numpy() -> Any: + try: + from openpi_client import msgpack_numpy + except ImportError as exc: + raise ImportError( + "The `/v1/realtime/world/camera` endpoint requires the optional " + "`openpi-client` dependency. Install it with `pip install openpi-client`." + ) from exc + + return msgpack_numpy + + +def _pack(obj: Any) -> bytes: + return _get_msgpack_numpy().packb(obj) + + +def _unpack(data: bytes) -> Any: + return _get_msgpack_numpy().unpackb(data) + + +class WorldCameraRealtimeConnection: + """WebSocket connection for world model inference.""" + + def __init__( + self, + websocket: WebSocket, + serving: ServingRealtimeWorldCamera, + idle_timeout: float = _DEFAULT_IDLE_TIMEOUT, + ) -> None: + self.websocket = websocket + self.serving = serving + self._idle_timeout = idle_timeout + + async def _send_error(self, message: str) -> None: + await self.websocket.send_bytes(_pack({"type": "error", "message": message})) + + def _unpack_request(self, data: bytes) -> dict[str, Any]: + obs = _unpack(data) + if not isinstance(obs, dict): + raise ValueError("Invalid request payload") + return obs + + async def handle_connection(self) -> None: + """Main loop.""" + await self.websocket.accept() + + try: + # Send model-specific PolicyServerConfig resolved by serving from + # diffusion od_config.model_config. + metadata = self.serving.policy_server_config.to_dict() + await self.websocket.send_bytes(_pack(metadata)) + + while True: + try: + msg = await asyncio.wait_for( + self.websocket.receive(), + timeout=self._idle_timeout, + ) + except asyncio.TimeoutError: + logger.info("World Model OpenPI connection idle timeout after %.1f seconds", self._idle_timeout) + try: + await self.websocket.close() + except Exception: + logger.debug("Failed to close idle World Model websocket", exc_info=True) + return + + if msg.get("type") == "websocket.disconnect": + break + + if "bytes" not in msg or not msg["bytes"]: + continue + + try: + obs = self._unpack_request(msg["bytes"]) + except Exception: + logger.exception("Invalid world model OpenPI request payload") + try: + await self._send_error("Invalid request payload") + except Exception: + break + continue + + try: + endpoint = obs.pop("endpoint", "infer") + + if endpoint == "reset": + self.serving.reset(obs) + await self.websocket.send_text("reset successful") + else: + result = await self.serving.infer(obs) + + if ( + len(result.images) == 1 + and isinstance(result.images[0], tuple) + and len(result.images[0]) == 1 + ): + frames = result.images[0] + elif len(result.images) == 1 and isinstance(result.images[0], dict): + frames = result.images[0].get("frames") or result.images[0].get("video") + else: + frames = result.images + + if len(frames) == 1: + frames = frames[0] + + if isinstance(frames, torch.Tensor): + frames = frames.numpy(force=True) + + frames = _normalize_frames(frames) + + CHUNK_FRAMES = 4 + + total = (len(frames) + CHUNK_FRAMES - 1) // CHUNK_FRAMES + for i in range(total): + chunk = frames[i * CHUNK_FRAMES : (i + 1) * CHUNK_FRAMES] + await self.websocket.send_bytes( + _pack( + { + "type": "frame", + "index": i, + "total": total, + "video": chunk, + } + ) + ) + + except Exception: + logger.exception("Error handling request") + try: + await self._send_error("Internal inference error") + except Exception: + break + + except WebSocketDisconnect: + pass + except Exception: + logger.exception("Connection error") diff --git a/vllm_omni/entrypoints/openai/realtime/world/camera_serving.py b/vllm_omni/entrypoints/openai/realtime/world/camera_serving.py new file mode 100644 index 00000000000..9544d5cd094 --- /dev/null +++ b/vllm_omni/entrypoints/openai/realtime/world/camera_serving.py @@ -0,0 +1,186 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Protocol structs for the /v1/realtime/world/* family of endpoints. + +These are msgpack-serialised over the WebSocket wire via ``msgspec.msgpack``. +""" + +from __future__ import annotations + +from collections.abc import Mapping +from dataclasses import dataclass +from typing import Any + +import numpy as np +from omegaconf import OmegaConf +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +def _to_builtin_container(value: Any) -> Any: + if OmegaConf.is_config(value): + return OmegaConf.to_container(value, resolve=True) + if isinstance(value, Mapping): + return {key: _to_builtin_container(item) for key, item in value.items()} + if isinstance(value, (list, tuple)): + return [_to_builtin_container(item) for item in value] + return value + + +@dataclass(frozen=True) +class CameraServerConfig: + """Static server-side camera/pipeline parameters sent to a client on connect. + + Fields are seeded from the Lingbot World Fast pipeline constants — the only + camera-capable pipeline today. When additional camera pipelines are added, + ``from_model_config`` should branch on the model identifier. + """ + + values: dict[str, Any] + + @classmethod + def from_model_config(cls, model_config: Any) -> CameraServerConfig: + return cls(_to_builtin_container(model_config)) + + def to_dict(self) -> dict[str, Any]: + return _to_builtin_container(self.values) + + +class ServingRealtimeWorldCamera: + """World Model Camera serving layer for OpenPI protocol. + + Model-specific transform/state lives in the diffusion pipeline. + """ + + def __init__( + self, + engine_client: Any, + model_name: str | None = None, + ) -> None: + self.engine_client = engine_client + self.model_name = model_name + self._current_session_id: str | None = None + self._call_count = 0 + self.policy_server_config = self._get_policy_server_config(engine_client) + + @classmethod + def create_policy_server( + cls, + engine_client: Any, + model_name: str | None = None, + ) -> ServingRealtimeWorldCamera | None: + try: + return cls(engine_client=engine_client, model_name=model_name) + except ValueError as exc: + if "policy_server_config" not in str(exc): + raise + logger.info("World Model OpenPI serving disabled for model %s", model_name) + return None + + @staticmethod + def _get_policy_server_config(engine_client: Any) -> CameraServerConfig: + model_config = None + get_od_config = getattr(engine_client, "get_diffusion_od_config", None) + if callable(get_od_config): + od_config = get_od_config() + model_config = getattr(od_config, "model_config", None) + + if model_config is None: + for stage_config in getattr(engine_client, "stage_configs", []) or []: + if getattr(stage_config, "stage_type", None) != "diffusion": + continue + engine_args = getattr(stage_config, "engine_args", None) + model_config = getattr(engine_args, "model_config", None) + if model_config is not None: + break + + if model_config is None: + od_config = getattr(engine_client, "od_config", None) + model_config = getattr(od_config, "model_config", None) + + if model_config is None: + model_config = getattr(engine_client, "model_config", None) + return CameraServerConfig.from_model_config(model_config) + + def reset(self, obs: dict) -> None: + """Reset serving state. + + Engine-side Lingbot state is reset on the next inference request via + `extra_args["reset"]`, not by an immediate websocket-side RPC. + """ + self._call_count = 0 + self._current_session_id = None + + async def infer(self, obs: dict) -> np.ndarray: + """raw obs → engine → video.""" + # Session tracking + + session_id = obs.get("session_id") + if session_id is not None and session_id != self._current_session_id: + if self._current_session_id is not None: + logger.info("Session changed %s → %s", self._current_session_id, session_id) + self.reset({}) + self._current_session_id = session_id + + self._call_count += 1 + + # Build request, run inference through AsyncOmni + request = self._build_request(obs) + result = None + # OpenPI policy serving is one request -> one action reply. AsyncOmni + # exposes an async iterator, so consume it to completion and use the + # final output, matching other non-streaming OpenAI serving paths. + async for output in self.engine_client.generate( + prompt=request.prompts[0], + request_id=request.request_ids[0], + sampling_params_list=[request.sampling_params], + ): + result = output + if result is None: + raise RuntimeError("World Model Camera OpenPI request produced no output.") + + return result + + def _build_request(self, obs: dict) -> Any: + """Build engine request from raw robot obs. + + Returns an `OmniDiffusionRequest` payload consumed by + `AsyncOmni.generate()` and routed to the diffusion stage. + """ + from vllm_omni.diffusion.request import OmniDiffusionRequest + from vllm_omni.inputs.data import OmniDiffusionSamplingParams + + extra_args = { + "session_id": self._current_session_id or "default", + } + + camera = obs.get("camera", None) + + multi_modal_data = { + "image": obs.get("image", None), + "camera": camera, + } + + prompt = obs.get("prompt", "") + + extra_body = obs.get("extra_body", {}) + + height = extra_body.get("height", None) + width = extra_body.get("width", None) + num_frames = extra_body.get("num_frames", None) + fps = extra_body.get("fps", None) + + sampling_params = OmniDiffusionSamplingParams( + height=height, width=width, num_frames=num_frames, frame_rate=fps, extra_args=extra_args, seed=42 + ) + return OmniDiffusionRequest( + prompts=[ + { + "prompt": prompt, + "multi_modal_data": multi_modal_data, + } + ], + sampling_params=sampling_params, + request_ids=[f"camera-{self._current_session_id or 'default'}"], + ) From fae96e2e7e19efdd3e076388eceff764450007dd Mon Sep 17 00:00:00 2001 From: Mahdi Nasser <94046147+mnasser02@users.noreply.github.com> Date: Wed, 15 Apr 2026 17:19:58 +0200 Subject: [PATCH 02/53] Add step execution and denoise_micro_step to Wan2.2 pipeline Implement SupportsStepExecution protocol on Wan22Pipeline, decomposing the monolithic forward() into prepare_encode, denoise_step,step_scheduler, and post_decode. Add denoise_micro_step for temporal PP. Signed-off-by: Mahdi Nasser <94046147+mnasser02@users.noreply.github.com> --- .../models/wan2_2/pipeline_wan2_2.py | 418 +++++++++++++++++- 1 file changed, 409 insertions(+), 9 deletions(-) diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py index b4438c79520..012f3d04f1a 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py @@ -3,12 +3,13 @@ from __future__ import annotations +import copy import json import logging import os import time from collections.abc import Iterable -from typing import Any, cast +from typing import TYPE_CHECKING, Any, ClassVar, cast import PIL.Image import torch @@ -38,6 +39,9 @@ from vllm_omni.inputs.data import OmniTextPrompt from vllm_omni.platforms import current_omni_platform +if TYPE_CHECKING: + from vllm_omni.diffusion.worker.utils import DiffusionRequestState + logger = logging.getLogger(__name__) DEBUG_PERF = False WAN_SAMPLE_SOLVER_CHOICES = {"unipc", "euler"} @@ -302,6 +306,8 @@ def pre_process_func(request: OmniDiffusionRequest) -> OmniDiffusionRequest: class Wan22Pipeline( nn.Module, PipelineParallelMixin, CFGParallelMixin, ProgressBarMixin, DiffusionPipelineProfilerMixin ): + supports_step_execution: ClassVar[bool] = True + def __init__( self, *, @@ -1028,15 +1034,409 @@ def check_inputs( if boundary_ratio is None and guidance_scale_2 is not None: raise ValueError("`guidance_scale_2` is only supported when `boundary_ratio` is set.") + # ── Step-execution protocol (SupportsStepExecution) ── + + def _extract_prompts( + self, + state: DiffusionRequestState, + ) -> tuple[str | None, str | None]: + """Extract prompt and negative prompt from *state*.""" + prompt: str | None = None + negative_prompt: str | None = None + if state.prompts: + p = state.prompts[0] + if isinstance(p, str): + prompt = p + else: + prompt = p.get("prompt") + negative_prompt = p.get("negative_prompt") + return prompt, negative_prompt + + def _resolve_generation_params( + self, + state: DiffusionRequestState, + ) -> dict[str, Any]: + """Extract and validate generation parameters from *state*. + + Returns a dict with resolved height, width, num_frames, guidance_low, + guidance_high, boundary_ratio, boundary_timestep, device, dtype, and + generator. + """ + sampling = state.sampling + + height = sampling.height or 480 + width = sampling.width or 832 + num_frames = sampling.num_frames or 81 + num_steps = sampling.num_inference_steps or 40 + + patch_size = self.transformer_config.patch_size + mod_value = self.vae_scale_factor_spatial * patch_size[1] + height = (height // mod_value) * mod_value + width = (width // mod_value) * mod_value + + if num_frames % self.vae_scale_factor_temporal != 1: + num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1 + num_frames = max(num_frames, 1) + + guidance_scale = sampling.guidance_scale if sampling.guidance_scale_provided else 4.0 + guidance_low = guidance_scale if isinstance(guidance_scale, (int, float)) else guidance_scale[0] + guidance_high = ( + sampling.guidance_scale_2 + if sampling.guidance_scale_2 is not None + else ( + guidance_scale[1] + if isinstance(guidance_scale, (list, tuple)) and len(guidance_scale) > 1 + else guidance_low + ) + ) + + boundary_ratio = self.boundary_ratio if self.boundary_ratio is not None else sampling.boundary_ratio + if boundary_ratio is None: + boundary_ratio = 0.875 + + device = self.device + if self.transformer is not None: + dtype = self.transformer.dtype + elif self.transformer_2 is not None: + dtype = self.transformer_2.dtype + else: + dtype = self.text_encoder.dtype + + generator = sampling.generator + if generator is None and sampling.seed is not None: + generator = torch.Generator(device=device).manual_seed(sampling.seed) + + return { + "height": height, + "width": width, + "num_frames": num_frames, + "num_steps": num_steps, + "guidance_low": guidance_low, + "guidance_high": guidance_high, + "boundary_ratio": boundary_ratio, + "boundary_timestep": boundary_ratio * 1000, # num_train_timesteps + "device": device, + "dtype": dtype, + "generator": generator, + } + + def _prepare_latent_input( + self, + state: DiffusionRequestState, + t: torch.Tensor, + dtype: torch.dtype, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Prepare latent_model_input and timestep tensor for one denoise step. + + Handles both T2V (passthrough) and I2V (condition blending + per-patch + timestep expansion) modes. + + Returns: + (latent_model_input, timestep_tensor) + """ + latent_condition = state.extra.get("latent_condition") + first_frame_mask = state.extra.get("first_frame_mask") + expand_timesteps = state.extra.get("expand_timesteps", False) + + if expand_timesteps and latent_condition is not None: + # I2V mode: blend condition with latents, expand timesteps per patch + latent_model_input = ((1 - first_frame_mask) * latent_condition + first_frame_mask * state.latents).to( + dtype + ) + patch_size = self.transformer_config.patch_size + patch_height = state.latents.shape[3] // patch_size[1] + patch_width = state.latents.shape[4] // patch_size[2] + patch_mask = first_frame_mask[:, :, :, :: patch_size[1], :: patch_size[2]] + patch_mask = patch_mask[:, :, :, :patch_height, :patch_width] + temp_ts = (patch_mask[0][0] * t).flatten() + timestep_tensor = temp_ts.unsqueeze(0).expand(state.latents.shape[0], -1) + else: + # T2V mode + latent_model_input = state.latents.to(dtype) + timestep_tensor = t.expand(state.latents.shape[0]) if t.ndim == 0 else t + + return latent_model_input, timestep_tensor + + def _select_model_for_timestep( + self, + t: torch.Tensor, + boundary_timestep: float | None, + ) -> tuple[nn.Module, float]: + """Return (transformer, guidance_scale) for a given timestep.""" + guidance_low = self._guidance_scale or 4.0 + guidance_high = self._guidance_scale_2 or guidance_low + if boundary_timestep is not None and t < boundary_timestep: + model = self.transformer_2 if self.transformer_2 is not None else self.transformer + return model, guidance_high + model = self.transformer if self.transformer is not None else self.transformer_2 + return model, guidance_low + + def prepare_encode( + self, + state: DiffusionRequestState, + **kwargs: Any, + ) -> DiffusionRequestState: + """One-time request setup: encode prompt, prepare latents, init scheduler.""" + # Extract prompts + prompt, negative_prompt = self._extract_prompts(state) + + params = self._resolve_generation_params(state) + height = params["height"] + width = params["width"] + num_frames = params["num_frames"] + device = params["device"] + dtype = params["dtype"] + generator = params["generator"] + guidance_low = params["guidance_low"] + guidance_high = params["guidance_high"] + + # Store guidance for properties and for denoise_step model selection + self._guidance_scale = guidance_low + self._guidance_scale_2 = guidance_high + + # Encode prompt + do_cfg = guidance_low > 1.0 or guidance_high > 1.0 + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=do_cfg, + num_videos_per_prompt=state.sampling.num_outputs_per_prompt or 1, + max_sequence_length=state.sampling.max_sequence_length or 512, + device=device, + dtype=dtype, + ) + + # Scheduler + self.scheduler.set_timesteps(params["num_steps"], device=device) + req_scheduler = copy.deepcopy(self.scheduler) + + # I2V conditioning + multi_modal_data = ( + state.prompts[0].get("multi_modal_data", {}) + if state.prompts and not isinstance(state.prompts[0], str) + else None + ) + raw_image = multi_modal_data.get("image", None) if multi_modal_data else None + if isinstance(raw_image, list): + raw_image = raw_image[0] + + latent_condition = None + first_frame_mask = None + + if self.expand_timesteps and raw_image is not None: + # I2V mode + from diffusers.video_processor import VideoProcessor + + video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + if isinstance(raw_image, str): + image = PIL.Image.open(raw_image) + else: + image = cast(PIL.Image.Image | torch.Tensor, raw_image) + + if isinstance(image, PIL.Image.Image): + image = image.resize((width, height), PIL.Image.Resampling.LANCZOS) + image_tensor = video_processor.preprocess(image, height=height, width=width) + else: + image_tensor = image + + num_channels_latents = self.transformer_config.out_channels + batch_size = prompt_embeds.shape[0] + + latents = self.prepare_latents( + batch_size=batch_size, + num_channels_latents=num_channels_latents, + height=height, + width=width, + num_frames=num_frames, + dtype=torch.float32, + device=device, + generator=generator, + latents=state.sampling.latents, + ) + + image_tensor = image_tensor.unsqueeze(2).to(device=device, dtype=self.vae.dtype) + latent_condition = retrieve_latents(self.vae.encode(image_tensor), sample_mode="argmax") + latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1) + + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latent_condition.device, latent_condition.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latent_condition.device, latent_condition.dtype + ) + latent_condition = ((latent_condition - latents_mean) * latents_std).to(torch.float32) + + num_latent_frames = latents.shape[2] + latent_height = latents.shape[3] + latent_width = latents.shape[4] + first_frame_mask = torch.ones( + 1, 1, num_latent_frames, latent_height, latent_width, dtype=torch.float32, device=device + ) + first_frame_mask[:, :, 0] = 0 + else: + # T2V mode + num_channels_latents = self.transformer_config.in_channels + latents = self.prepare_latents( + batch_size=prompt_embeds.shape[0], + num_channels_latents=num_channels_latents, + height=height, + width=width, + num_frames=num_frames, + dtype=torch.float32, + device=device, + generator=generator, + latents=state.sampling.latents, + ) + + # Populate state + state.prompt_embeds = prompt_embeds + state.negative_prompt_embeds = negative_prompt_embeds + state.latents = latents + state.timesteps = req_scheduler.timesteps + state.step_index = 0 + state.scheduler = req_scheduler + state.do_true_cfg = do_cfg and negative_prompt_embeds is not None + state.guidance = torch.tensor([guidance_low], device=device) + + state.extra["guidance_low"] = guidance_low + state.extra["guidance_high"] = guidance_high + state.extra["boundary_timestep"] = params["boundary_timestep"] + state.extra["expand_timesteps"] = self.expand_timesteps + state.extra["latent_condition"] = latent_condition + state.extra["first_frame_mask"] = first_frame_mask + state.extra["height"] = height + state.extra["width"] = width + + return state + + def denoise_step( + self, + state: DiffusionRequestState, + **kwargs: Any, + ) -> torch.Tensor | None: + """Run one denoising iteration.""" + t = state.current_timestep + self._current_timestep = t -# --------------------------------------------------------------------------- -# DMD2-distilled variant -# --------------------------------------------------------------------------- + boundary_timestep = state.extra.get("boundary_timestep") + current_model, current_guidance_scale = self._select_model_for_timestep(t, boundary_timestep) + latent_model_input, timestep = self._prepare_latent_input(state, t, current_model.dtype) -class WanT2VDMD2Pipeline(DMD2PipelineMixin, Wan22Pipeline): - """Wan 2.x T2V pipeline for FastGen DMD2-distilled models.""" + do_true_cfg = current_guidance_scale > 1.0 and state.negative_prompt_embeds is not None - def __init__(self, *, od_config: OmniDiffusionConfig, prefix: str = ""): - super().__init__(od_config=od_config, prefix=prefix) - self.__init_dmd2__() + positive_kwargs = { + "hidden_states": latent_model_input, + "timestep": timestep, + "encoder_hidden_states": state.prompt_embeds, + "attention_kwargs": {}, + "return_dict": False, + "current_model": current_model, + } + negative_kwargs = ( + { + "hidden_states": latent_model_input, + "timestep": timestep, + "encoder_hidden_states": state.negative_prompt_embeds, + "attention_kwargs": {}, + "return_dict": False, + "current_model": current_model, + } + if do_true_cfg + else None + ) + + return self.predict_noise_maybe_with_pp_and_cfg( + do_true_cfg=do_true_cfg, + true_cfg_scale=current_guidance_scale, + positive_kwargs=positive_kwargs, + negative_kwargs=negative_kwargs, + ) + + def step_scheduler( + self, + state: DiffusionRequestState, + noise_pred: torch.Tensor, + **kwargs: Any, + ) -> None: + """Run one scheduler step: update latents and advance step_index.""" + t = state.current_timestep + boundary_timestep = state.extra.get("boundary_timestep") + _, current_guidance_scale = self._select_model_for_timestep(t, boundary_timestep) + do_true_cfg = current_guidance_scale > 1.0 and state.negative_prompt_embeds is not None + + state.latents = self.scheduler_step_maybe_with_pp_and_cfg( + noise_pred, t, state.latents, do_true_cfg, per_request_scheduler=state.scheduler + ) + state.step_index += 1 + + def post_decode( + self, + state: DiffusionRequestState, + **kwargs: Any, + ) -> DiffusionOutput: + """Decode final latents after denoising completes.""" + self.sync_pp_send() + self._current_timestep = None + + if current_omni_platform.is_available(): + current_omni_platform.empty_cache() + + # I2V: blend final latents with condition + latent_condition = state.extra.get("latent_condition") + first_frame_mask = state.extra.get("first_frame_mask") + if state.extra.get("expand_timesteps") and latent_condition is not None: + state.latents = (1 - first_frame_mask) * latent_condition + first_frame_mask * state.latents + + latents = state.latents.to(self.vae.dtype) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + latents = latents / latents_std + latents_mean + output = self.vae.decode(latents, return_dict=False)[0] + + return DiffusionOutput( + output=output, + stage_durations=self.stage_durations if hasattr(self, "stage_durations") else None, + ) + + # ── Temporal PP: local-compute-only forward ── + + def denoise_micro_step( + self, + *, + state: DiffusionRequestState, + timestep: torch.Tensor, + intermediate_tensors: IntermediateTensors | None = None, + ) -> torch.Tensor | IntermediateTensors: + """Local-compute-only forward for one PP stage at one timestep. + + Unlike ``predict_noise_maybe_with_pp_and_cfg``, this does NOT handle + communication (no isend/irecv) or CFG (single branch only). The model + runner owns send/recv timing; this method just runs the transformer's + local blocks and returns the raw result. + + Returns: + ``IntermediateTensors`` on non-last PP ranks, ``noise_pred`` tensor + on the last PP rank. + """ + boundary_timestep = state.extra.get("boundary_timestep") + current_model, _ = self._select_model_for_timestep(timestep, boundary_timestep) + + latent_model_input, timestep_tensor = self._prepare_latent_input(state, timestep, current_model.dtype) + + return self.predict_noise( + current_model=current_model, + intermediate_tensors=intermediate_tensors, + hidden_states=latent_model_input, + timestep=timestep_tensor, + encoder_hidden_states=state.prompt_embeds, + return_dict=False, + ) From 94141ec21c6ae145db5b4335f7a2b31930acf1da Mon Sep 17 00:00:00 2001 From: Mahdi Nasser <94046147+mnasser02@users.noreply.github.com> Date: Wed, 15 Apr 2026 17:19:58 +0200 Subject: [PATCH 03/53] remove denoise_micro_step Signed-off-by: Mahdi Nasser <94046147+mnasser02@users.noreply.github.com> --- .../models/wan2_2/pipeline_wan2_2.py | 34 ------------------- 1 file changed, 34 deletions(-) diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py index 012f3d04f1a..95a3426c9dd 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py @@ -1406,37 +1406,3 @@ def post_decode( output=output, stage_durations=self.stage_durations if hasattr(self, "stage_durations") else None, ) - - # ── Temporal PP: local-compute-only forward ── - - def denoise_micro_step( - self, - *, - state: DiffusionRequestState, - timestep: torch.Tensor, - intermediate_tensors: IntermediateTensors | None = None, - ) -> torch.Tensor | IntermediateTensors: - """Local-compute-only forward for one PP stage at one timestep. - - Unlike ``predict_noise_maybe_with_pp_and_cfg``, this does NOT handle - communication (no isend/irecv) or CFG (single branch only). The model - runner owns send/recv timing; this method just runs the transformer's - local blocks and returns the raw result. - - Returns: - ``IntermediateTensors`` on non-last PP ranks, ``noise_pred`` tensor - on the last PP rank. - """ - boundary_timestep = state.extra.get("boundary_timestep") - current_model, _ = self._select_model_for_timestep(timestep, boundary_timestep) - - latent_model_input, timestep_tensor = self._prepare_latent_input(state, timestep, current_model.dtype) - - return self.predict_noise( - current_model=current_model, - intermediate_tensors=intermediate_tensors, - hidden_states=latent_model_input, - timestep=timestep_tensor, - encoder_hidden_states=state.prompt_embeds, - return_dict=False, - ) From e144462f782263c8cbdc374bd859125db51e7c23 Mon Sep 17 00:00:00 2001 From: Mahdi Nasser <94046147+mnasser02@users.noreply.github.com> Date: Wed, 15 Apr 2026 17:19:58 +0200 Subject: [PATCH 04/53] add unit tests for wan22 step exec Signed-off-by: Mahdi Nasser <94046147+mnasser02@users.noreply.github.com> --- .../wan2_2/test_wan22_step_execution.py | 505 ++++++++++++++++++ 1 file changed, 505 insertions(+) create mode 100644 tests/diffusion/models/wan2_2/test_wan22_step_execution.py diff --git a/tests/diffusion/models/wan2_2/test_wan22_step_execution.py b/tests/diffusion/models/wan2_2/test_wan22_step_execution.py new file mode 100644 index 00000000000..66409c982a6 --- /dev/null +++ b/tests/diffusion/models/wan2_2/test_wan22_step_execution.py @@ -0,0 +1,505 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""Unit tests for Wan2.2 SupportsStepExecution protocol implementation. + +Tests use lightweight mocks (no real model weights) to verify: +- Protocol compliance (class flag, method presence) +- Helper correctness (_resolve_generation_params, _select_model_for_timestep) +- Step execution decomposition matches monolithic forward() +- I2V mode latent input preparation +""" + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +import torch + +from vllm_omni.diffusion.worker.utils import DiffusionRequestState + +# --------------------------------------------------------------------------- +# Shared test utilities +# --------------------------------------------------------------------------- + + +def _make_sampling(**overrides): + """Create a mock sampling params object.""" + sampling = MagicMock() + sampling.height = overrides.get("height", 480) + sampling.width = overrides.get("width", 832) + sampling.num_frames = overrides.get("num_frames", 81) + sampling.num_inference_steps = overrides.get("num_inference_steps", 4) + sampling.guidance_scale = overrides.get("guidance_scale", 1.0) + sampling.guidance_scale_provided = overrides.get("guidance_scale_provided", True) + sampling.guidance_scale_2 = overrides.get("guidance_scale_2", None) + sampling.boundary_ratio = overrides.get("boundary_ratio", None) + sampling.num_outputs_per_prompt = overrides.get("num_outputs_per_prompt", 1) + sampling.max_sequence_length = overrides.get("max_sequence_length", 512) + sampling.seed = overrides.get("seed", 42) + sampling.generator = None + sampling.latents = None + return sampling + + +def _make_state(**overrides): + """Create a DiffusionRequestState with mock sampling.""" + return DiffusionRequestState( + req_id="test", + sampling=_make_sampling(**overrides), + prompts=overrides.get("prompts", ["test prompt"]), + ) + + +def _make_pipeline_stub(): + """Create a minimal Wan22Pipeline without __init__ (no model weights).""" + from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2 import Wan22Pipeline + + pipeline = object.__new__(Wan22Pipeline) + torch.nn.Module.__init__(pipeline) + pipeline.vae_scale_factor_spatial = 8 + pipeline.vae_scale_factor_temporal = 4 + pipeline.boundary_ratio = 0.875 + pipeline.expand_timesteps = False + pipeline._guidance_scale = None + pipeline._guidance_scale_2 = None + pipeline._current_timestep = None + pipeline._pp_send_work_list = [] + + config = MagicMock() + config.patch_size = (1, 2, 2) + config.in_channels = 16 + config.out_channels = 16 + pipeline.transformer_config = config + + pipeline.device = torch.device("cpu") + + return pipeline + + +# --------------------------------------------------------------------------- +# 1. Protocol compliance +# --------------------------------------------------------------------------- + + +class TestWan22SupportsStepExecution: + """Verify the class-level protocol flag and method signatures.""" + + def test_class_var_is_true(self): + from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2 import Wan22Pipeline + + assert hasattr(Wan22Pipeline, "supports_step_execution") + assert Wan22Pipeline.supports_step_execution is True + + def test_has_required_methods(self): + from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2 import Wan22Pipeline + + for method_name in ("prepare_encode", "denoise_step", "step_scheduler", "post_decode"): + assert hasattr(Wan22Pipeline, method_name), f"Missing method: {method_name}" + +# --------------------------------------------------------------------------- +# 2. _resolve_generation_params helper +# --------------------------------------------------------------------------- + + +class TestResolveGenerationParams: + """Verify parameter resolution and alignment logic.""" + + def test_dimensions_aligned_to_mod_value(self): + pipeline = _make_pipeline_stub() + pipeline.transformer = MagicMock(dtype=torch.bfloat16) + pipeline.transformer_2 = None + + for height in range(400, 600, 50): + for width in range(700, 900, 50): + state = _make_state(height=height, width=width) + params = pipeline._resolve_generation_params(state) + # mod_value = 8 * 2 = 16 + assert params["height"] % 16 == 0 + assert params["width"] % 16 == 0 + + def test_num_frames_aligned_to_vae_temporal(self): + pipeline = _make_pipeline_stub() + pipeline.transformer = MagicMock(dtype=torch.bfloat16) + pipeline.transformer_2 = None + + for num_frames in range(1, 200): + state = _make_state(num_frames=num_frames) + params = pipeline._resolve_generation_params(state) + assert params["num_frames"] % pipeline.vae_scale_factor_temporal == 1 or params["num_frames"] == 1 + + + +# --------------------------------------------------------------------------- +# 3. _select_model_for_timestep helper +# --------------------------------------------------------------------------- + + +class TestSelectModelForTimestep: + def _make(self): + pipeline = _make_pipeline_stub() + pipeline.transformer = MagicMock(name="transformer") + pipeline.transformer_2 = MagicMock(name="transformer_2") + pipeline._guidance_scale = 4.0 + pipeline._guidance_scale_2 = 7.0 + return pipeline + + def test_high_noise_uses_transformer(self): + pipeline = self._make() + model, scale = pipeline._select_model_for_timestep(torch.tensor(800.0), boundary_timestep=500.0) + assert model is pipeline.transformer + assert scale == 4.0 + + def test_low_noise_uses_transformer_2(self): + pipeline = self._make() + model, scale = pipeline._select_model_for_timestep(torch.tensor(300.0), boundary_timestep=500.0) + assert model is pipeline.transformer_2 + assert scale == 7.0 + + def test_no_boundary_uses_transformer(self): + pipeline = self._make() + model, _ = pipeline._select_model_for_timestep(torch.tensor(300.0), boundary_timestep=None) + assert model is pipeline.transformer + + def test_fallback_when_transformer_none(self): + pipeline = self._make() + pipeline.transformer = None + model, _ = pipeline._select_model_for_timestep(torch.tensor(800.0), boundary_timestep=500.0) + assert model is pipeline.transformer_2 + + +# --------------------------------------------------------------------------- +# 4.1 Step execution decomposition matches forward() +# --------------------------------------------------------------------------- + + +class _FakeScheduler: + """Minimal scheduler that applies a deterministic update: latents -= 0.1 * noise_pred.""" + + def __init__(self, timesteps: torch.Tensor): + self.timesteps = timesteps + self._step_index = 0 + self.config = MagicMock() + self.config.num_train_timesteps = 1000 + + def set_timesteps(self, _num_steps, device=None): + pass # timesteps already set + + def step(self, noise_pred, _t, latents, return_dict=False): + self._step_index += 1 + return (latents - 0.1 * noise_pred,) + + +class _FakeTransformer(torch.nn.Module): + """Deterministic transformer: output = input * 0.5 (applied to hidden_states).""" + + def __init__(self): + super().__init__() + self._dummy = torch.nn.Parameter(torch.tensor(1.0)) + + @property + def dtype(self): + return torch.float32 + + def forward(self, hidden_states, timestep, encoder_hidden_states, intermediate_tensors=None, **kwargs): + # Simulate noise prediction: scale down hidden_states + noise_pred = hidden_states * 0.5 + return (noise_pred,) + + +def _patch_parallel_state(): + """Context manager that patches PP and CFG parallel state to single-GPU (world_size=1).""" + from contextlib import ExitStack + + stack = ExitStack() + stack.enter_context( + patch("vllm_omni.diffusion.distributed.pp_parallel.get_pipeline_parallel_world_size", return_value=1) + ) + stack.enter_context( + patch("vllm_omni.diffusion.distributed.cfg_parallel.get_classifier_free_guidance_world_size", return_value=1) + ) + return stack + + +class TestDenoiseStepCorrectness: + """Verify prepare_encode → denoise_step x N → step_scheduler x N → post_decode + produces the same latent trajectory as running the equivalent loop manually.""" + + def _make_pipeline(self): + pipeline = _make_pipeline_stub() + pipeline.transformer = _FakeTransformer() + pipeline.transformer_2 = None + pipeline.expand_timesteps = False + + timesteps = torch.tensor([900.0, 700.0, 500.0, 300.0]) + pipeline.scheduler = _FakeScheduler(timesteps) + + # Mock encode_prompt to return fixed embeddings + prompt_embeds = torch.randn(1, 10, 64) + pipeline.encode_prompt = MagicMock(return_value=(prompt_embeds, None)) + + # Mock VAE decode (identity) + vae = MagicMock() + vae.dtype = torch.float32 + vae.config.latents_mean = [0.0] * 16 + vae.config.latents_std = [1.0] * 16 + vae.config.z_dim = 16 + vae.decode = MagicMock(side_effect=lambda x, **kw: (x,)) + pipeline.vae = vae + + # Mock prepare_latents to return seeded noise + torch.manual_seed(123) + fixed_latents = torch.randn(1, 16, 21, 30, 52) + pipeline.prepare_latents = MagicMock(return_value=fixed_latents.clone()) + + return pipeline, fixed_latents.clone(), prompt_embeds + + def test_latent_trajectory_matches(self): + """Step-by-step execution produces the same final latents as a manual loop.""" + pipeline, initial_latents, prompt_embeds = self._make_pipeline() + timesteps = pipeline.scheduler.timesteps + + # ── Manual baseline loop ── + latents = initial_latents.clone() + for t in timesteps: + latent_input = latents.to(torch.float32) + noise_pred = pipeline.transformer( + hidden_states=latent_input, + timestep=t.expand(1), + encoder_hidden_states=prompt_embeds, + )[0] + latents = latents - 0.1 * noise_pred + baseline_latents = latents + + # ── Step execution path ── + state = _make_state(num_inference_steps=4) + state = pipeline.prepare_encode(state) + + with _patch_parallel_state(): + while not state.denoise_completed: + noise_pred = pipeline.denoise_step(state) + pipeline.step_scheduler(state, noise_pred) + + assert state.step_index == len(timesteps) + torch.testing.assert_close( + state.latents, + baseline_latents, + rtol=1e-5, + atol=1e-5, + msg="Step execution latents diverged from manual baseline", + ) + + def test_post_decode_calls_vae(self): + """post_decode invokes VAE decode and returns DiffusionOutput.""" + from vllm_omni.diffusion.data import DiffusionOutput + + pipeline, _, _ = self._make_pipeline() + + state = _make_state(num_inference_steps=4) + state = pipeline.prepare_encode(state) + with _patch_parallel_state(): + while not state.denoise_completed: + noise_pred = pipeline.denoise_step(state) + pipeline.step_scheduler(state, noise_pred) + + mock_platform = MagicMock() + mock_platform.is_available.return_value = False + with ( + patch.object(type(pipeline), "sync_pp_send"), + patch("vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2.current_omni_platform", mock_platform), + ): + result = pipeline.post_decode(state) + + assert isinstance(result, DiffusionOutput) + assert result.output is not None + pipeline.vae.decode.assert_called_once() + + def test_step_count_matches_timesteps(self): + """Exactly len(timesteps) steps are executed.""" + pipeline, _, _ = self._make_pipeline() + state = _make_state(num_inference_steps=4) + state = pipeline.prepare_encode(state) + + step_count = 0 + with _patch_parallel_state(): + while not state.denoise_completed: + noise_pred = pipeline.denoise_step(state) + pipeline.step_scheduler(state, noise_pred) + step_count += 1 + + assert step_count == 4 + + def test_scheduler_is_deepcopied(self): + """Each request gets its own scheduler copy, not a shared reference.""" + pipeline, _, _ = self._make_pipeline() + original_scheduler = pipeline.scheduler + state = _make_state(num_inference_steps=4) + state = pipeline.prepare_encode(state) + assert state.scheduler is not original_scheduler + + +# --------------------------------------------------------------------------- +# 4.3 _prepare_latent_input I2V mode +# --------------------------------------------------------------------------- + + +class TestPrepareLatentInputI2V: + """Verify I2V mode latent blending and timestep expansion.""" + + def _make_pipeline_and_state(self): + pipeline = _make_pipeline_stub() + pipeline.transformer = MagicMock(dtype=torch.float32) + pipeline.transformer_2 = None + + # Latents: [B=1, C=16, T=5, H=8, W=10] + latents = torch.randn(1, 16, 5, 8, 10) + # Condition: same shape, different values + latent_condition = torch.randn(1, 16, 5, 8, 10) + # Mask: 0 for first frame, 1 for rest + first_frame_mask = torch.ones(1, 1, 5, 8, 10) + first_frame_mask[:, :, 0] = 0 + + state = _make_state() + state.latents = latents + state.extra["expand_timesteps"] = True + state.extra["latent_condition"] = latent_condition + state.extra["first_frame_mask"] = first_frame_mask + + return pipeline, state, latents, latent_condition, first_frame_mask + + def test_i2v_blends_condition_with_latents(self): + """First frame uses condition, remaining frames use latents.""" + pipeline, state, latents, condition, mask = self._make_pipeline_and_state() + t = torch.tensor(500.0) + + latent_input, _ = pipeline._prepare_latent_input(state, t, torch.float32) + + # First frame (mask=0): should be condition + expected_first = condition[:, :, 0] + torch.testing.assert_close(latent_input[:, :, 0], expected_first, rtol=1e-5, atol=1e-5) + + # Remaining frames (mask=1): should be latents + expected_rest = latents[:, :, 1:] + torch.testing.assert_close(latent_input[:, :, 1:], expected_rest, rtol=1e-5, atol=1e-5) + + def test_i2v_timestep_expansion(self): + """Timestep is expanded per-patch: 0 for condition patches, t for noise patches.""" + pipeline, state, _, _, mask = self._make_pipeline_and_state() + t = torch.tensor(500.0) + + _, timestep_tensor = pipeline._prepare_latent_input(state, t, torch.float32) + + # patch_size = (1, 2, 2) → patch dims: T=5, H=4, W=5 + # Sequence length = 5 * 4 * 5 = 100 + assert timestep_tensor.shape[0] == 1 # batch + assert timestep_tensor.shape[1] == 5 * 4 * 5 # flattened patch sequence + + # First frame patches (first 4*5=20) should have timestep 0 + first_frame_patches = timestep_tensor[0, : 4 * 5] + assert (first_frame_patches == 0).all(), "First frame patches should have timestep 0" + + # Remaining patches should have timestep = 500 + rest_patches = timestep_tensor[0, 4 * 5 :] + assert (rest_patches == 500.0).all(), "Non-first-frame patches should have timestep t" + + def test_t2v_mode_passthrough(self): + """T2V mode: latents pass through unchanged, timestep is broadcast.""" + pipeline = _make_pipeline_stub() + pipeline.transformer = MagicMock(dtype=torch.float32) + pipeline.transformer_2 = None + + latents = torch.randn(2, 16, 5, 8, 10) + state = _make_state() + state.latents = latents + # No I2V extras → T2V mode + + t = torch.tensor(500.0) + latent_input, timestep_tensor = pipeline._prepare_latent_input(state, t, torch.float32) + + torch.testing.assert_close(latent_input, latents) + assert timestep_tensor.shape == (2,) + assert (timestep_tensor == 500.0).all() + + +# --------------------------------------------------------------------------- +# 5. denoise_step vs Wan22Pipeline.forward() +# --------------------------------------------------------------------------- + + +def _make_req(num_inference_steps: int = 4): + """Create a minimal T2V OmniDiffusionRequest for forward() comparison tests.""" + from vllm_omni.diffusion.request import OmniDiffusionRequest + from vllm_omni.inputs.data import OmniDiffusionSamplingParams + + sampling = OmniDiffusionSamplingParams( + height=480, + width=832, + num_frames=81, + num_inference_steps=num_inference_steps, + guidance_scale=1.0, + max_sequence_length=512, + num_outputs_per_prompt=1, + seed=42, + ) + return OmniDiffusionRequest(prompts=["test prompt"], sampling_params=sampling) + + +class TestDenoiseStepMatchesForward: + """Verify denoise_step matches Wan22Pipeline.forward().""" + + def _make_pipeline(self): + pipeline = _make_pipeline_stub() + pipeline.transformer = _FakeTransformer() + pipeline.transformer_2 = None + pipeline.expand_timesteps = False + + timesteps = torch.tensor([900.0, 700.0, 500.0, 300.0]) + pipeline.scheduler = _FakeScheduler(timesteps) + + torch.manual_seed(17) + prompt_embeds = torch.randn(1, 10, 64) + pipeline.encode_prompt = MagicMock(return_value=(prompt_embeds, None)) + + vae = MagicMock() + vae.dtype = torch.float32 + vae.config.latents_mean = [0.0] * 16 + vae.config.latents_std = [1.0] * 16 + vae.config.z_dim = 16 + pipeline.vae = vae + + torch.manual_seed(13) + fixed_latents = torch.randn(1, 16, 21, 30, 52) + pipeline.prepare_latents = MagicMock(return_value=fixed_latents.clone()) + + return pipeline + + def _run_forward(self, pipeline): + """Run pipeline.forward() in T2V mode and return the final latents.""" + req = _make_req(num_inference_steps=4) + mock_platform = MagicMock() + mock_platform.is_available.return_value = False + with ( + patch("vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2.current_omni_platform", mock_platform), + _patch_parallel_state(), + ): + result = pipeline.forward(req, output_type="latent") + return result.output + + def test_denoise_step_matches_forward(self): + """Full step-execution loop (prepare_encode → denoise_step x N → step_scheduler x N) + produces the same final latents as Wan22Pipeline.forward().""" + pipeline = self._make_pipeline() + + # Reference: monolithic forward() + fwd_latents = self._run_forward(pipeline) + + # Step-execution path using the same pipeline (same latent / prompt_embed mocks) + state = _make_state(num_inference_steps=4) + state = pipeline.prepare_encode(state) + with _patch_parallel_state(): + while not state.denoise_completed: + noise_pred = pipeline.denoise_step(state) + pipeline.step_scheduler(state, noise_pred) + + torch.testing.assert_close(state.latents, fwd_latents, rtol=1e-5, atol=1e-5) + From a542a6c2754761012b13ac7ab38ca516aae72721 Mon Sep 17 00:00:00 2001 From: Mahdi Nasser <94046147+mnasser02@users.noreply.github.com> Date: Thu, 16 Apr 2026 14:37:40 +0200 Subject: [PATCH 05/53] add RankTask and track per-rank tasks in DiffusionSchedulerOutput Signed-off-by: Mahdi Nasser <94046147+mnasser02@users.noreply.github.com> --- vllm_omni/diffusion/sched/interface.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/vllm_omni/diffusion/sched/interface.py b/vllm_omni/diffusion/sched/interface.py index 8eef211ffa7..17a01482410 100644 --- a/vllm_omni/diffusion/sched/interface.py +++ b/vllm_omni/diffusion/sched/interface.py @@ -106,6 +106,20 @@ def make_empty(cls) -> CachedRequestData: return cls(sched_req_ids=[]) +@dataclass +class RankTask: + """One ``(request, chunk, denoising_step)`` triple assigned to a single PP rank for one micro-step. + + Used by temporal-PP schedulers (e.g. ``StreamBatchScheduler``) to tell each rank + which work to perform in the current micro-step. Spatial-PP schedulers leave + ``DiffusionSchedulerOutput.per_rank_assignment`` as ``None``. + """ + + sched_req_id: str + chunk_idx: int + denoising_step: int + + @dataclass class DiffusionSchedulerOutput: """Output of a single scheduling cycle.""" @@ -116,6 +130,9 @@ class DiffusionSchedulerOutput: finished_req_ids: set[str] num_running_reqs: int num_waiting_reqs: int + # Temporal-PP per-rank assignment table. Index = PP rank id. ``None`` entries + # mark idle ranks (warmup / cooldown). + per_rank_assignment: list[RankTask | None] | None = None @cached_property def scheduled_req_ids(self) -> list[str]: From a493fb31aabddf1bf673f1988501207427cd6b4a Mon Sep 17 00:00:00 2001 From: Mahdi Nasser <94046147+mnasser02@users.noreply.github.com> Date: Fri, 17 Apr 2026 13:46:52 +0200 Subject: [PATCH 06/53] add num_chunks sampling param Signed-off-by: Mahdi Nasser <94046147+mnasser02@users.noreply.github.com> --- vllm_omni/inputs/data.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/vllm_omni/inputs/data.py b/vllm_omni/inputs/data.py index 1b80f4b1b77..6cdf08fa3c4 100644 --- a/vllm_omni/inputs/data.py +++ b/vllm_omni/inputs/data.py @@ -220,6 +220,9 @@ class OmniDiffusionSamplingParams: width_latents: list[int] | int | None = None num_frames: int = 1 # Default for image models num_frames_round_down: bool = False # Whether to round down num_frames if it's not divisible by num_gpus + # Number of output chunks the request produces. Read by ``StreamBatchScheduler`` + # (temporal PP) to know how many chunks to admit through the pipeline. + num_chunks: int = 1 # Original dimensions (before VAE scaling) height: int | None = None From 5491d3838b4aad572f537a4b7cd033f85979ff80 Mon Sep 17 00:00:00 2001 From: Mahdi Nasser <94046147+mnasser02@users.noreply.github.com> Date: Fri, 17 Apr 2026 15:32:58 +0200 Subject: [PATCH 07/53] Add chunk related data structures to track its state. >> Different ranks work on different chunks. A context manager that views the req state of a rank as a chunk state allows benefitting from existing functionalities. Signed-off-by: Mahdi Nasser <94046147+mnasser02@users.noreply.github.com> --- vllm_omni/diffusion/worker/utils.py | 70 ++++++++++++++++++++++++++--- 1 file changed, 64 insertions(+), 6 deletions(-) diff --git a/vllm_omni/diffusion/worker/utils.py b/vllm_omni/diffusion/worker/utils.py index 1e98a7784ab..0c8ac1ca703 100644 --- a/vllm_omni/diffusion/worker/utils.py +++ b/vllm_omni/diffusion/worker/utils.py @@ -5,6 +5,8 @@ from __future__ import annotations from abc import ABC, abstractmethod +import contextlib +from collections.abc import Iterator from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any @@ -109,6 +111,54 @@ def new_request(self) -> bool: # A real "new request" signal should eventually come from scheduler/runner state transitions. return self.step_index == 0 or self.timesteps is None + @contextlib.contextmanager + def use_chunk(self, chunk: ChunkState) -> Iterator[None]: + """Temporarily alias per-chunk fields on ``self`` to a ``ChunkState``'s view. + + Swapped fields: ``latents``, ``step_index``, ``scheduler``. + + Lets ``prepare_encode`` / ``denoise_step`` / ``step_scheduler`` operate + per-chunk without any pipeline-side changes. Updates made inside the + context are written back to the chunk on exit; the request-level fields + are restored. + """ + saved_latents = self.latents + saved_step_index = self.step_index + saved_scheduler = self.scheduler + self.latents = chunk.latents + self.step_index = chunk.step_index + self.scheduler = chunk.scheduler + try: + yield + finally: + chunk.latents = self.latents + chunk.step_index = self.step_index + chunk.scheduler = self.scheduler + self.latents = saved_latents + self.step_index = saved_step_index + self.scheduler = saved_scheduler + + +@dataclass +class ChunkState: + """Per-chunk state for one in-flight chunk of a streaming request. + + Lives inside ``DiffusionRequestState.extra["chunks"]`` (keyed by + ``chunk_idx``). The runner swaps a chunk into the request state via + ``state.use_chunk(chunk)`` for the duration of one micro-step's + ``denoise_step + step_scheduler`` calls. + + Each chunk owns its own ``scheduler`` instance (deepcopied from the + pipeline's scheduler by ``prepare_encode``) because multi-step ODE solvers + (e.g. ``FlowUniPCMultistepScheduler``) are stateful — they track per-step + ``model_outputs`` that must not leak between chunks. + """ + + idx: int + latents: torch.Tensor | None = None + step_index: int = 0 + scheduler: Any | None = None + class BaseRunnerOutput(ABC): @abstractmethod @@ -117,12 +167,16 @@ def get_req_output(self, sched_req_id: str) -> RunnerOutput | None: @dataclass -class RunnerOutput(BaseRunnerOutput): - """Output of a single denoising step for a request. +class RunnerOutput: + """Output of a single execution step for a request. - NOTE: `latents` may be None when returned through IPC to avoid - serialization overhead. The actual latents are kept in Worker's - _request_state_cache. + Each scheduler reads the fields it needs: + + - ``StepScheduler`` reads ``step_index`` / ``finished``. + - ``StreamBatchScheduler`` reads ``chunk_idx`` / ``step_index`` / + ``chunk_completed`` / ``finished``. + + Fields not relevant to an execution path are left as ``None`` / ``False``. """ req_id: str @@ -130,6 +184,10 @@ class RunnerOutput(BaseRunnerOutput): finished: bool = False result: DiffusionOutput | None = None + # ── Temporal-PP micro-step fields (set by execute_micro_step) ── + chunk_idx: int | None = None + chunk_completed: bool = False + def get_req_output(self, sched_req_id: str) -> RunnerOutput | None: return self if self.req_id == sched_req_id else None @@ -159,4 +217,4 @@ def __len__(self) -> int: @classmethod def from_list(cls, runner_output_list: list[RunnerOutput]) -> BatchRunnerOutput: - return cls(runner_outputs=runner_output_list) + return cls(runner_outputs=runner_output_list) \ No newline at end of file From d9cb79825ab16c7b68da223ca9fa9d59c8075e3c Mon Sep 17 00:00:00 2001 From: Mahdi Nasser <94046147+mnasser02@users.noreply.github.com> Date: Fri, 17 Apr 2026 15:36:54 +0200 Subject: [PATCH 08/53] Add execute_micro_step execution path that works on a per-rank assigned task Signed-off-by: Mahdi Nasser <94046147+mnasser02@users.noreply.github.com> --- vllm_omni/diffusion/data.py | 4 + vllm_omni/diffusion/diffusion_engine.py | 7 +- vllm_omni/diffusion/executor/abstract.py | 5 + .../diffusion/executor/multiproc_executor.py | 26 ++++ vllm_omni/diffusion/sched/interface.py | 8 +- .../worker/diffusion_model_runner.py | 134 +++++++++++++++++- .../diffusion/worker/diffusion_worker.py | 19 +++ 7 files changed, 196 insertions(+), 7 deletions(-) diff --git a/vllm_omni/diffusion/data.py b/vllm_omni/diffusion/data.py index c45ee63e7f1..0fa53763d0a 100644 --- a/vllm_omni/diffusion/data.py +++ b/vllm_omni/diffusion/data.py @@ -589,6 +589,10 @@ class OmniDiffusionConfig: # sleep mode enable_sleep_mode: bool = False + # Temporal pipeline parallelism (StreamBatchScheduler-driven streaming chunks). + # When True, the engine uses ``StreamBatchScheduler`` and routes execution + # through ``executor.execute_micro_step``. + stream_batch: bool = False # Maximum number of sequences to generate in a batch max_num_seqs: int = 1 diff --git a/vllm_omni/diffusion/diffusion_engine.py b/vllm_omni/diffusion/diffusion_engine.py index c13bd3c0c37..be693aaf9c8 100644 --- a/vllm_omni/diffusion/diffusion_engine.py +++ b/vllm_omni/diffusion/diffusion_engine.py @@ -162,7 +162,12 @@ def __init__( self._shutdown_complete = False self.abort_queue: queue.Queue[str] = queue.Queue() self._rpc_queue: queue.Queue[_RpcTask] = queue.Queue() - self.execute_fn = self.executor.execute_step if self.step_execution else self.executor.execute_request + if self.stream_batch: + self.execute_fn = self.executor.execute_micro_step + elif self.step_execution: + self.execute_fn = self.executor.execute_step + else: + self.execute_fn = self.executor.execute_request try: self._dummy_run() diff --git a/vllm_omni/diffusion/executor/abstract.py b/vllm_omni/diffusion/executor/abstract.py index c5abaf59cc7..53715e2f7b1 100644 --- a/vllm_omni/diffusion/executor/abstract.py +++ b/vllm_omni/diffusion/executor/abstract.py @@ -84,6 +84,11 @@ def execute_step(self, scheduler_output: DiffusionSchedulerOutput) -> BaseRunner """Execute step-mode work from a scheduler output.""" pass + @abstractmethod + def execute_micro_step(self, scheduler_output: DiffusionSchedulerOutput) -> RunnerOutput: + """Execute one temporal-PP micro-step from a scheduler output.""" + pass + @abstractmethod def collective_rpc( self, diff --git a/vllm_omni/diffusion/executor/multiproc_executor.py b/vllm_omni/diffusion/executor/multiproc_executor.py index 6727122d462..21b6fe5f9ce 100644 --- a/vllm_omni/diffusion/executor/multiproc_executor.py +++ b/vllm_omni/diffusion/executor/multiproc_executor.py @@ -344,6 +344,32 @@ def execute_step(self, scheduler_output: DiffusionSchedulerOutput) -> BaseRunner else: raise RuntimeError(f"Unexpected response type for execute_step: {type(result)!r}") + def execute_micro_step(self, scheduler_output: DiffusionSchedulerOutput) -> RunnerOutput: + """Forward a temporal-PP micro-step to worker ``execute_micro_step`` RPC. + + The reply is collected from the last PP rank, which owns rank N-1's ODE + results and any chunk-finished decodes (carried in + ``RunnerOutput.chunk_events``). Other ranks' replies are discarded. + + Assumes worker rank == PP rank (true for PP-only layouts; revisit when + introducing TP/DP combinations). + """ + from vllm_omni.diffusion.worker.utils import RunnerOutput + + self._ensure_open() + last_pp_rank = self.od_config.parallel_config.pipeline_parallel_size - 1 + result = self.collective_rpc( + "execute_micro_step", + args=(scheduler_output,), + unique_reply_rank=last_pp_rank, + exec_all_ranks=True, + ) + if not isinstance(result, RunnerOutput): + raise RuntimeError( + f"Unexpected response type for execute_micro_step: {type(result)!r}" + ) + return result + def collective_rpc( self, method: str, diff --git a/vllm_omni/diffusion/sched/interface.py b/vllm_omni/diffusion/sched/interface.py index 17a01482410..c94cd657ae8 100644 --- a/vllm_omni/diffusion/sched/interface.py +++ b/vllm_omni/diffusion/sched/interface.py @@ -108,16 +108,14 @@ def make_empty(cls) -> CachedRequestData: @dataclass class RankTask: - """One ``(request, chunk, denoising_step)`` triple assigned to a single PP rank for one micro-step. + """One ``(request, chunk, step_index)`` triple assigned to a single PP rank for one micro-step. - Used by temporal-PP schedulers (e.g. ``StreamBatchScheduler``) to tell each rank - which work to perform in the current micro-step. Spatial-PP schedulers leave - ``DiffusionSchedulerOutput.per_rank_assignment`` as ``None``. + Used by ``StreamBatchScheduler`` to tell each rank which work to perform in the current micro-step. """ sched_req_id: str chunk_idx: int - denoising_step: int + step_index: int @dataclass diff --git a/vllm_omni/diffusion/worker/diffusion_model_runner.py b/vllm_omni/diffusion/worker/diffusion_model_runner.py index ab6dbb79afd..6a717c42d5a 100644 --- a/vllm_omni/diffusion/worker/diffusion_model_runner.py +++ b/vllm_omni/diffusion/worker/diffusion_model_runner.py @@ -32,9 +32,10 @@ from vllm_omni.diffusion.offloader import get_offload_backend from vllm_omni.diffusion.registry import _NO_CACHE_ACCELERATION from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.diffusion.distributed.parallel_state import get_pp_group from vllm_omni.diffusion.sched.interface import DiffusionSchedulerOutput from vllm_omni.diffusion.worker.input_batch import InputBatch, scatter_latents -from vllm_omni.diffusion.worker.utils import BatchRunnerOutput, DiffusionRequestState, RunnerOutput +from vllm_omni.diffusion.worker.utils import BatchRunnerOutput, DiffusionRequestState, RunnerOutput, ChunkState from vllm_omni.distributed.omni_connectors.kv_transfer_manager import OmniKVTransferManager from vllm_omni.platforms import current_omni_platform from vllm_omni.worker.omni_connector_model_runner_mixin import OmniConnectorModelRunnerMixin @@ -486,3 +487,134 @@ def execute_stepwise(self, scheduler_output: DiffusionSchedulerOutput) -> BatchR self._update_states_after(states, input_batch, pipeline_interrupted) return BatchRunnerOutput.from_list(runner_output_list) + # ------------------------------------------------------------------ + # Temporal-PP micro-step execution + # ------------------------------------------------------------------ + + @staticmethod + def _merge_chunk_outputs(chunks: list[DiffusionOutput]) -> DiffusionOutput: + """Merge K completed chunk outputs into a single ``DiffusionOutput``. + + Concatenates video tensors ``[B, C, T, H, W]`` along the temporal + dimension (dim 2). + """ + if len(chunks) == 1: + return chunks[0] + try: + merged = torch.cat([c.output for c in chunks], dim=2) + except Exception as e: + return DiffusionOutput(error=f"Failed to merge {len(chunks)} chunk outputs: {e}") + return DiffusionOutput(output=merged) + + @staticmethod + def _get_or_create_chunk(state: DiffusionRequestState, chunk_idx: int) -> tuple[ChunkState, bool]: + chunks: dict[int, ChunkState] = state.extra.setdefault("chunks", {}) + chunk = chunks.get(chunk_idx) + if chunk is not None: + return chunk, False + chunk = ChunkState(idx=chunk_idx) + chunks[chunk_idx] = chunk + return chunk, True + + def execute_micro_step(self, scheduler_output: DiffusionSchedulerOutput) -> RunnerOutput: + """Execute one temporal-PP micro-step. + + Each rank reads its own slot from ``scheduler_output.per_rank_assignment`` + and performs at most one local compute + send + recv via the existing + ``denoise_step`` / ``step_scheduler`` pipeline methods (with chunk-scoped + ``state.latents`` and ``state.step_index`` swapped in via + ``state.use_chunk``). Only the last PP rank emits ``chunk_events``. + """ + assert self.pipeline is not None, "Model not loaded. Call load_model() first." + if not self.supports_step_mode(): + raise ValueError("Current pipeline does not support step execution.") + if self.od_config.cache_backend not in (None, "none"): + raise ValueError("Stream-batch mode does not support cache_backend yet.") + + assignment = scheduler_output.per_rank_assignment + if assignment is None: + raise ValueError("execute_micro_step requires per_rank_assignment in scheduler_output.") + + use_hsdp = self.od_config.parallel_config.use_hsdp + grad_context = torch.no_grad() if use_hsdp else torch.inference_mode() + + with grad_context: + state, is_new_request = self._update_states(scheduler_output) + + # prepare_encode must run on ALL ranks when the request is first + # seen — even if this rank is idle this micro-step — so that shared + # state (prompt embeds, timesteps, scheduler) is available when the + # rank later receives its first chunk assignment. + if is_new_request: + if state.sampling.generator is None and state.sampling.seed is not None: + gen_device = state.sampling.generator_device or ( + "cpu" if self.device.type == "cpu" else self.device + ) + state.sampling.generator = torch.Generator(device=gen_device).manual_seed(state.sampling.seed) + with set_forward_context(vllm_config=self.vllm_config, omni_diffusion_config=self.od_config): + self.pipeline.prepare_encode(state) + + pp_group = get_pp_group() + pp_rank = pp_group.rank_in_group + task = assignment[pp_rank] + + # Idle rank (warmup / cooldown / no active request). Still drain prior sends. + if task is None: + if hasattr(self.pipeline, "sync_pp_send"): + self.pipeline.sync_pp_send() + return RunnerOutput(req_id=state.req_id) + + with set_forward_context(vllm_config=self.vllm_config, omni_diffusion_config=self.od_config): + + chunk, is_new_chunk = self._get_or_create_chunk(state, task.chunk_idx) + if is_new_chunk: + # First chunk reuses the noise sampled by prepare_encode; + # subsequent chunks draw fresh noise of the same shape from + # the request's generator (deterministic, advances state). + # Each chunk gets its own scheduler deepcopy so multi-step + # ODE solver state doesn't leak between chunks. + chunk.latents = ( + state.latents + if task.chunk_idx == 0 + else torch.randn_like(state.latents, generator=state.sampling.generator) + ) + chunk.scheduler = copy.deepcopy(state.scheduler) + + # Sanity guard: scheduler's per-rank assignment must match what this + # rank currently believes about the chunk's progress. + assert chunk.step_index == task.step_index, ( + f"Stale chunk state on rank {pp_rank}: chunk_idx={task.chunk_idx} " + f"local step_index={chunk.step_index}, scheduler said {task.step_index}" + ) + + with state.use_chunk(chunk): + noise_pred = self.pipeline.denoise_step(state) + if noise_pred is None and getattr(self.pipeline, "interrupt", False): + return RunnerOutput( + req_id=task.sched_req_id, + result=DiffusionOutput(error="micro-step denoise interrupted"), + ) + self.pipeline.step_scheduler(state, noise_pred) + chunk_done = state.denoise_completed + + output = RunnerOutput( + req_id=task.sched_req_id, + step_index=chunk.step_index, + chunk_idx=task.chunk_idx, + ) + + # Only rank N-1 runs post_decode and tracks chunk completion. + if pp_group.is_last_rank and chunk_done: + output.chunk_completed = True + with state.use_chunk(chunk): + chunk_output = self.pipeline.post_decode(state) + state.extra["chunks"].pop(task.chunk_idx, None) + + completed = state.extra.setdefault("completed_chunk_outputs", []) + completed.append(chunk_output) + if len(completed) >= state.sampling.num_chunks: + output.finished = True + output.result = self._merge_chunk_outputs(completed) + self._update_states_after(state, finished=True) + + return output diff --git a/vllm_omni/diffusion/worker/diffusion_worker.py b/vllm_omni/diffusion/worker/diffusion_worker.py index e2c4b97c101..da80798376e 100644 --- a/vllm_omni/diffusion/worker/diffusion_worker.py +++ b/vllm_omni/diffusion/worker/diffusion_worker.py @@ -382,6 +382,21 @@ def execute_stepwise(self, scheduler_output: DiffusionSchedulerOutput) -> BaseRu profiler.step() return output + def execute_micro_step(self, scheduler_output: DiffusionSchedulerOutput) -> RunnerOutput: + """Execute one temporal-PP micro-step by delegating to the model runner.""" + assert self.model_runner is not None, "Model runner not initialized" + if self.lora_manager is not None: + self.lora_manager.set_active_adapter(None) + if any(new_req.req.sampling_params.lora_request is not None for new_req in scheduler_output.scheduled_new_reqs): + raise ValueError("Stream-batch mode does not support LoRA yet.") + profiler = self._get_profiler() + ctx = profiler.annotate_context_manager("diffusion_micro_step") if profiler else nullcontext() + with ctx: + output = self.model_runner.execute_micro_step(scheduler_output) + if profiler: + profiler.step() + return output + def load_weights(self, weights) -> set[str]: """Load weights by delegating to the model runner.""" assert self.model_runner is not None, "Model runner not initialized" @@ -951,6 +966,10 @@ def execute_stepwise(self, scheduler_output: DiffusionSchedulerOutput) -> BaseRu """Execute one diffusion step.""" return self.worker.execute_stepwise(scheduler_output) + def execute_micro_step(self, scheduler_output: DiffusionSchedulerOutput) -> RunnerOutput: + """Execute one temporal-PP micro-step.""" + return self.worker.execute_micro_step(scheduler_output) + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: """ Load model weights. From 3414a6eac72d075d36cb5802c6fedef8bc31a10d Mon Sep 17 00:00:00 2001 From: Mahdi Nasser <94046147+mnasser02@users.noreply.github.com> Date: Fri, 17 Apr 2026 15:45:41 +0200 Subject: [PATCH 09/53] Add StreamBatchScheduler that controls the flow of chunks through the pipeline (B and T hardcoded for now) Signed-off-by: Mahdi Nasser <94046147+mnasser02@users.noreply.github.com> --- vllm_omni/diffusion/diffusion_engine.py | 22 +- vllm_omni/diffusion/sched/__init__.py | 2 + .../diffusion/sched/stream_batch_scheduler.py | 226 ++++++++++++++++++ 3 files changed, 246 insertions(+), 4 deletions(-) create mode 100644 vllm_omni/diffusion/sched/stream_batch_scheduler.py diff --git a/vllm_omni/diffusion/diffusion_engine.py b/vllm_omni/diffusion/diffusion_engine.py index be693aaf9c8..21a2ef16c2b 100644 --- a/vllm_omni/diffusion/diffusion_engine.py +++ b/vllm_omni/diffusion/diffusion_engine.py @@ -31,7 +31,12 @@ get_diffusion_pre_process_func, ) from vllm_omni.diffusion.request import OmniDiffusionRequest -from vllm_omni.diffusion.sched import RequestScheduler, SchedulerInterface, StepScheduler +from vllm_omni.diffusion.sched import ( + RequestScheduler, + SchedulerInterface, + StepScheduler, + StreamBatchScheduler, +) from vllm_omni.diffusion.sched.interface import DiffusionRequestStatus from vllm_omni.diffusion.worker.utils import BatchRunnerOutput, RunnerOutput from vllm_omni.inputs.data import OmniDiffusionSamplingParams, OmniTextPrompt @@ -138,9 +143,18 @@ def __init__( executor_class = DiffusionExecutor.get_class(od_config) self.executor = executor_class(od_config) self.step_execution = bool(getattr(od_config, "step_execution", False)) - self.scheduler: SchedulerInterface = scheduler or ( - StepScheduler() if self.step_execution else RequestScheduler() - ) + self.stream_batch = bool(getattr(od_config, "stream_batch", False)) + if self.stream_batch and not self.step_execution: + raise ValueError("stream_batch=True requires step_execution=True.") + + if scheduler is not None: + self.scheduler: SchedulerInterface = scheduler + elif self.stream_batch: + self.scheduler = StreamBatchScheduler() + elif self.step_execution: + self.scheduler = StepScheduler() + else: + self.scheduler = RequestScheduler() self.scheduler.initialize(od_config) if self.scheduler.max_num_running_reqs > 1 and not self.step_execution: max_num_seqs = self.scheduler.max_num_running_reqs diff --git a/vllm_omni/diffusion/sched/__init__.py b/vllm_omni/diffusion/sched/__init__.py index e0263733847..bbea1b417d3 100644 --- a/vllm_omni/diffusion/sched/__init__.py +++ b/vllm_omni/diffusion/sched/__init__.py @@ -11,6 +11,7 @@ ) from vllm_omni.diffusion.sched.request_scheduler import RequestScheduler from vllm_omni.diffusion.sched.step_scheduler import StepScheduler +from vllm_omni.diffusion.sched.stream_batch_scheduler import StreamBatchScheduler Scheduler = RequestScheduler @@ -23,5 +24,6 @@ "SchedulerInterface", "RequestScheduler", "StepScheduler", + "StreamBatchScheduler", "Scheduler", ] diff --git a/vllm_omni/diffusion/sched/stream_batch_scheduler.py b/vllm_omni/diffusion/sched/stream_batch_scheduler.py new file mode 100644 index 00000000000..3893a095178 --- /dev/null +++ b/vllm_omni/diffusion/sched/stream_batch_scheduler.py @@ -0,0 +1,226 @@ +"""Temporal-pipeline-parallel scheduler for streaming chunked diffusion. + +Each ``schedule()`` call corresponds to +one micro-step. At any micro-step, each PP rank processes a different +``(chunk, step_index)`` pair drawn from the active requests' in-flight +chunks. Chunks are admitted to rank 0 in order, propagate through ranks under +NCCL FIFO ordering, and exit at rank N-1 in the same order. +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import TYPE_CHECKING + +from vllm.logger import init_logger + +from vllm_omni.diffusion.data import OmniDiffusionConfig +from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.diffusion.sched.base_scheduler import _BaseScheduler +from vllm_omni.diffusion.sched.interface import ( + DiffusionRequestStatus, + DiffusionSchedulerOutput, + RankTask, +) + +if TYPE_CHECKING: + from vllm_omni.diffusion.worker.utils import RunnerOutput + +logger = init_logger(__name__) + + +@dataclass +class _InFlightChunk: + """One chunk of an active request, tracked through the temporal pipeline.""" + + chunk_idx: int + step_index: int = 0 + in_pipeline: bool = True # currently flowing through ranks (True between admission and exit) + entered_rank0_at: int = -1 # micro-step at which the chunk last entered rank 0 + + +@dataclass +class _ChunkProgress: + """Per-request chunk-level scheduling state.""" + sched_req_id: str + num_chunks: int # total chunks to produce for this request + num_steps: int # denoising steps per chunk + chunks_admitted: int = 0 + chunks_completed: int = 0 + in_flight: list[_InFlightChunk] = field(default_factory=list) + + +class StreamBatchScheduler(_BaseScheduler): + """Temporal-PP scheduler driving chunked-streaming diffusion requests. + + Per micro-step: + 1. Promote waiting requests up to ``max_num_running_reqs`` (handled by the base class). + 2. Re-admit at most one returning chunk to rank 0 (FIFO across all active requests). + 3. If rank 0 is still free and admission budget remains, admit a new chunk. + 4. Build the per-rank assignment table from in-pipeline chunks' positions. + + A chunk that entered rank 0 at micro-step ``m₀`` is at rank + ``r = current_micro_step - m₀`` while ``0 ≤ r < pp_size``. After ``r == + pp_size - 1``, rank N-1's ``step_scheduler`` runs the ODE; the chunk's + latents are sent back to rank 0; the chunk leaves the pipeline and may be + re-admitted on the next micro-step (until it has run all + ``num_steps`` denoising steps). + """ + + def __init__(self) -> None: + super().__init__() + self.pp_size: int = 1 # set in initialize() + self.B: int = 1 # intra-rank batch + self._global_micro_step: int = 0 + self._chunk_progress: dict[str, _ChunkProgress] = {} + + # ── Lifecycle ────────────────────────────────────────────────────────── + + def initialize(self, od_config: OmniDiffusionConfig) -> None: + super().initialize(od_config) + self.pp_size = od_config.parallel_config.pipeline_parallel_size + + def _reset_scheduler_state(self) -> None: + self._global_micro_step = 0 + self._chunk_progress.clear() + + def _pop_extra_request_state(self, sched_req_id: str) -> None: + self._chunk_progress.pop(sched_req_id, None) + + # ── Request admission ────────────────────────────────────────────────── + + def add_request(self, request: OmniDiffusionRequest) -> str: + num_chunks = request.sampling_params.num_chunks + num_steps = request.sampling_params.num_inference_steps + if num_chunks is None or num_chunks <= 0: + raise ValueError(f"num_chunks must be a positive int, got {num_chunks!r}") + if num_steps is None or num_steps <= 0: + raise ValueError( + f"num_inference_steps must be a positive int, got {num_steps!r}" + ) + return super().add_request(request) + + # ── Scheduling ───────────────────────────────────────────────────────── + + def schedule(self) -> DiffusionSchedulerOutput: + # Base class promotes waiting → running and fills scheduled_new_reqs / step_id. + base_output = super().schedule() + + # Initialize chunk-progress state for any newly promoted requests. + for new_req in base_output.scheduled_new_reqs: + self._init_chunk_progress(new_req.sched_req_id, new_req.req) + + # Re-admit a returning chunk; otherwise admit a new chunk if rank 0 is free. + self._advance_chunk_pipeline() + + # Build the per-rank assignment from current in-pipeline chunks. + if self._chunk_progress: + base_output.per_rank_assignment = self._build_assignment() + # else: no active request → executor sees per_rank_assignment=None and idles. + + self._global_micro_step += 1 + return base_output + + def _init_chunk_progress(self, sched_req_id: str, req: OmniDiffusionRequest) -> None: + num_chunks = req.sampling_params.num_chunks + num_steps = req.sampling_params.num_inference_steps + assert num_chunks is not None and num_steps is not None # validated in add_request() + self._chunk_progress[sched_req_id] = _ChunkProgress( + sched_req_id=sched_req_id, + num_chunks=num_chunks, + num_steps=num_steps, + ) + logger.debug( + "StreamBatchScheduler initialized chunk progress for %s " + "(num_chunks=%d, num_steps=%d, pp_size=%d)", + sched_req_id, num_chunks, num_steps, self.pp_size, + ) + + def _advance_chunk_pipeline(self) -> None: + """Admit at most one chunk to rank 0 this micro-step. + + Re-admission of a returning chunk takes priority over admitting a new + chunk so that FIFO order is preserved (an admitted chunk's latents + always re-enter rank 0 before any later-admitted chunk's first entry). + Admission order across requests follows ``_chunk_progress`` insertion + order, which matches the order the base scheduler promoted them. + """ + if not self._chunk_progress: + return + + # 1. Try to re-admit a returning chunk (FIFO oldest-first across requests). + for progress in self._chunk_progress.values(): + for chunk in progress.in_flight: + if not chunk.in_pipeline and chunk.step_index < progress.num_steps: + chunk.in_pipeline = True + chunk.entered_rank0_at = self._global_micro_step + return # rank 0 is now taken + + # 2. Otherwise admit a new chunk from the first request with budget. + for progress in self._chunk_progress.values(): + if progress.chunks_admitted < progress.num_chunks: + new_chunk = _InFlightChunk( + chunk_idx=progress.chunks_admitted, + step_index=0, + in_pipeline=True, + entered_rank0_at=self._global_micro_step, + ) + progress.in_flight.append(new_chunk) + progress.chunks_admitted += 1 + return + + def _build_assignment(self) -> list[RankTask | None]: + assignment: list[RankTask | None] = [None] * self.pp_size + for progress in self._chunk_progress.values(): + for chunk in progress.in_flight: + if not chunk.in_pipeline: + continue + r = self._global_micro_step - chunk.entered_rank0_at + if 0 <= r < self.pp_size: + assert assignment[r] is None, ( + f"two chunks would be assigned to rank {r} at micro-step " + f"{self._global_micro_step}: existing={assignment[r]}, " + f"new req={progress.sched_req_id} chunk_idx={chunk.chunk_idx}" + ) + assignment[r] = RankTask( + sched_req_id=progress.sched_req_id, + chunk_idx=chunk.chunk_idx, + step_index=chunk.step_index, + ) + return assignment + + # ── Output processing ────────────────────────────────────────────────── + + def update_from_output(self, sched_output: DiffusionSchedulerOutput, output: RunnerOutput) -> set[str]: + if not self._chunk_progress or sched_output.per_rank_assignment is None: + return set() + + # Read per-chunk fields from RunnerOutput (set by execute_micro_step) + # to advance chunk state — same pattern as StepScheduler reading + # step_index / finished. + terminal: dict[str, DiffusionRequestStatus] = {} + + if output.chunk_idx is not None: + progress = self._chunk_progress.get(output.req_id) + if progress is not None: + chunk = self._find_chunk(progress, output.chunk_idx) + if chunk is not None: + chunk.step_index = output.step_index + chunk.in_pipeline = False + if output.chunk_completed: + progress.in_flight = [ + c for c in progress.in_flight if c.chunk_idx != output.chunk_idx + ] + progress.chunks_completed += 1 + + if output.finished: + terminal[output.req_id] = DiffusionRequestStatus.FINISHED_COMPLETED + + return self._finalize_update_from_output(sched_output, terminal) + + @staticmethod + def _find_chunk(progress: _ChunkProgress, chunk_idx: int) -> _InFlightChunk | None: + for chunk in progress.in_flight: + if chunk.chunk_idx == chunk_idx: + return chunk + return None \ No newline at end of file From 778bc663d6cc0b36836787f1ef88085ddeb38f61 Mon Sep 17 00:00:00 2001 From: Mahdi Nasser <94046147+mnasser02@users.noreply.github.com> Date: Mon, 20 Apr 2026 10:17:09 +0200 Subject: [PATCH 10/53] Add StreamBatchScheduler tests Signed-off-by: Mahdi Nasser <94046147+mnasser02@users.noreply.github.com> --- .../diffusion/test_stream_batch_scheduler.py | 432 ++++++++++++++++++ 1 file changed, 432 insertions(+) create mode 100644 tests/diffusion/test_stream_batch_scheduler.py diff --git a/tests/diffusion/test_stream_batch_scheduler.py b/tests/diffusion/test_stream_batch_scheduler.py new file mode 100644 index 00000000000..c5a94440390 --- /dev/null +++ b/tests/diffusion/test_stream_batch_scheduler.py @@ -0,0 +1,432 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""Unit tests for StreamBatchScheduler (temporal PP chunk scheduling).""" + +from __future__ import annotations + +from types import SimpleNamespace + +import pytest + +from vllm_omni.diffusion.data import DiffusionParallelConfig +from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.diffusion.sched.stream_batch_scheduler import StreamBatchScheduler +from vllm_omni.inputs.data import OmniDiffusionSamplingParams + +pytestmark = [pytest.mark.core_model, pytest.mark.cpu, pytest.mark.diffusion] + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_config(pp_size: int = 2) -> SimpleNamespace: + """Minimal OmniDiffusionConfig stub with the fields StreamBatchScheduler reads.""" + return SimpleNamespace(parallel_config=DiffusionParallelConfig(pipeline_parallel_size=pp_size)) + + +def _make_request( + req_id: str, + num_chunks: int = 1, + num_inference_steps: int = 4, +) -> OmniDiffusionRequest: + return OmniDiffusionRequest( + prompts=[f"prompt_{req_id}"], + sampling_params=OmniDiffusionSamplingParams( + num_inference_steps=num_inference_steps, + num_chunks=num_chunks, + ), + request_ids=[req_id], + ) + + +def _make_runner_output( + req_id: str = "", + step_index: int | None = None, + chunk_idx: int | None = None, + chunk_completed: bool = False, + finished: bool = False, +) -> SimpleNamespace: + """Simulate a RunnerOutput from rank N-1.""" + return SimpleNamespace( + req_id=req_id, + step_index=step_index, + chunk_idx=chunk_idx, + chunk_completed=chunk_completed, + finished=finished, + result=None, + ) + + +def _simulate_last_rank_output(sched_output, pp_size: int, num_steps: int) -> SimpleNamespace: + """Build the RunnerOutput that rank N-1 would produce for a given schedule output. + + Simulates: if rank N-1 had a task, its chunk's step_index advances by 1. + chunk_completed is True if the new step_index reaches num_steps. + """ + assignment = sched_output.per_rank_assignment + if assignment is None: + return _make_runner_output() + task = assignment[pp_size - 1] + if task is None: + req_id = sched_output.scheduled_req_ids[0] if sched_output.scheduled_req_ids else "" + return _make_runner_output(req_id=req_id) + new_step = task.step_index + 1 + return _make_runner_output( + req_id=task.sched_req_id, + step_index=new_step, + chunk_idx=task.chunk_idx, + chunk_completed=(new_step >= num_steps), + ) + + +def _run_until_finished( + scheduler: StreamBatchScheduler, + pp_size: int, + num_steps: int, + num_chunks: int, + max_iters: int = 200, +) -> list[tuple[list, SimpleNamespace]]: + """Drive the scheduler loop, returning (assignment, runner_output) per micro-step. + + Simulates what the runner would set in ``RunnerOutput.finished``: True once + the total number of chunks produced equals ``num_chunks`` for that request. + """ + trace: list[tuple[list, SimpleNamespace]] = [] + completed_per_req: dict[str, int] = {} + for _ in range(max_iters): + sched_output = scheduler.schedule() + output = _simulate_last_rank_output(sched_output, pp_size, num_steps) + if output.chunk_completed: + completed_per_req[output.req_id] = completed_per_req.get(output.req_id, 0) + 1 + if completed_per_req[output.req_id] >= num_chunks: + output.finished = True + finished = scheduler.update_from_output(sched_output, output) + assignment = sched_output.per_rank_assignment or [None] * pp_size + trace.append((assignment, output)) + if finished: + break + return trace + + +# --------------------------------------------------------------------------- +# Validation +# --------------------------------------------------------------------------- + + +class TestAddRequestValidation: + def test_rejects_zero_chunks(self): + sched = StreamBatchScheduler() + sched.initialize(_make_config()) + with pytest.raises(ValueError, match="num_chunks"): + sched.add_request(_make_request("r1", num_chunks=0)) + + def test_rejects_negative_steps(self): + sched = StreamBatchScheduler() + sched.initialize(_make_config()) + with pytest.raises(ValueError, match="num_inference_steps"): + sched.add_request(_make_request("r1", num_inference_steps=-1)) + + def test_accepts_valid_request(self): + sched = StreamBatchScheduler() + sched.initialize(_make_config()) + req_id = sched.add_request(_make_request("r1", num_chunks=3, num_inference_steps=10)) + assert req_id == "r1" + + +# --------------------------------------------------------------------------- +# Single chunk, single rank (PP=1) +# --------------------------------------------------------------------------- + + +class TestSingleChunkSingleRank: + """PP=1, K=1 — degenerate case: one chunk, one rank, behaves like step scheduler.""" + + def test_completes_in_m_steps(self): + pp_size, num_chunks, num_steps = 1, 1, 4 + sched = StreamBatchScheduler() + sched.initialize(_make_config(pp_size)) + sched.add_request(_make_request("r1", num_chunks=num_chunks, num_inference_steps=num_steps)) + + trace = _run_until_finished(sched, pp_size, num_steps, num_chunks) + assert len(trace) == num_steps + + def test_assignment_is_always_rank_0(self): + pp_size, num_chunks, num_steps = 1, 1, 3 + sched = StreamBatchScheduler() + sched.initialize(_make_config(pp_size)) + sched.add_request(_make_request("r1", num_chunks=num_chunks, num_inference_steps=num_steps)) + + trace = _run_until_finished(sched, pp_size, num_steps, num_chunks) + for assignment, _ in trace: + assert assignment[0] is not None + assert assignment[0].sched_req_id == "r1" + assert assignment[0].chunk_idx == 0 + + def test_step_index_advances(self): + pp_size, num_chunks, num_steps = 1, 1, 3 + sched = StreamBatchScheduler() + sched.initialize(_make_config(pp_size)) + sched.add_request(_make_request("r1", num_chunks=num_chunks, num_inference_steps=num_steps)) + + trace = _run_until_finished(sched, pp_size, num_steps, num_chunks) + step_indices = [a[0].step_index for a, _ in trace] + assert step_indices == [0, 1, 2] + + +# --------------------------------------------------------------------------- +# Multi-chunk, single rank (PP=1) +# --------------------------------------------------------------------------- + + +class TestMultiChunkSingleRank: + """PP=1, K>1 — chunks are processed sequentially on one rank.""" + + def test_completes_in_k_times_m_steps(self): + pp_size, num_chunks, num_steps = 1, 3, 2 + sched = StreamBatchScheduler() + sched.initialize(_make_config(pp_size)) + sched.add_request(_make_request("r1", num_chunks=num_chunks, num_inference_steps=num_steps)) + + trace = _run_until_finished(sched, pp_size, num_steps, num_chunks) + assert len(trace) == num_chunks * num_steps + + def test_chunks_processed_in_order(self): + pp_size, num_chunks, num_steps = 1, 3, 2 + sched = StreamBatchScheduler() + sched.initialize(_make_config(pp_size)) + sched.add_request(_make_request("r1", num_chunks=num_chunks, num_inference_steps=num_steps)) + + trace = _run_until_finished(sched, pp_size, num_steps, num_chunks) + chunk_indices = [a[0].chunk_idx for a, _ in trace] + # Each chunk runs M steps before the next chunk starts. + assert chunk_indices == [0, 0, 1, 1, 2, 2] + + +# --------------------------------------------------------------------------- +# Pipeline warmup / assignment (PP > 1) +# --------------------------------------------------------------------------- + + +class TestPipelineAssignment: + """Verify per-rank assignment table for N=3, K=4, M=2.""" + + def _setup(self): + pp_size, num_chunks, num_steps = 3, 4, 2 + sched = StreamBatchScheduler() + sched.initialize(_make_config(pp_size)) + sched.add_request(_make_request("r1", num_chunks=num_chunks, num_inference_steps=num_steps)) + return sched, pp_size, num_steps, num_chunks + + def _extract_assignment(self, assignment): + """Convert assignment list to tuples of (chunk_idx, step_index) or None.""" + return [ + (t.chunk_idx, t.step_index) if t is not None else None + for t in assignment + ] + + def test_warmup_idles_trailing_ranks(self): + sched, pp_size, num_steps, num_chunks = self._setup() + trace = _run_until_finished(sched, pp_size, num_steps, num_chunks) + + # Micro-step 0: only rank 0 active. + a0 = self._extract_assignment(trace[0][0]) + assert a0[0] is not None + assert a0[1] is None + assert a0[2] is None + + def test_warmup_fills_pipeline(self): + sched, pp_size, num_steps, num_chunks = self._setup() + trace = _run_until_finished(sched, pp_size, num_steps, num_chunks) + + # Micro-step 0: rank 0 = chunk 0 + assert self._extract_assignment(trace[0][0]) == [(0, 0), None, None] + # Micro-step 1: rank 0 = chunk 1, rank 1 = chunk 0 + assert self._extract_assignment(trace[1][0]) == [(1, 0), (0, 0), None] + # Micro-step 2: all ranks busy + a2 = self._extract_assignment(trace[2][0]) + assert all(x is not None for x in a2) + + def test_chunk_propagates_through_ranks(self): + sched, pp_size, num_steps, num_chunks = self._setup() + trace = _run_until_finished(sched, pp_size, num_steps, num_chunks) + + # Chunk 0 should appear at rank 0, then rank 1, then rank 2. + chunk0_ranks = [] + for assignment, _ in trace: + for r, task in enumerate(assignment): + if task is not None and task.chunk_idx == 0: + chunk0_ranks.append(r) + break + # First 3 entries: ranks 0, 1, 2 (warmup propagation of chunk 0's first step). + assert chunk0_ranks[:3] == [0, 1, 2] + + def test_request_completes(self): + sched, pp_size, num_steps, num_chunks = self._setup() + trace = _run_until_finished(sched, pp_size, num_steps, num_chunks) + + # The last micro-step's output should have finished=True (set by _simulate). + # But finished is set by the runner, not the scheduler. In our simulation, + # we don't set finished=True. Instead, check that the scheduler reported + # the request in finished_req_ids (the loop exited). + assert len(trace) > 0 # loop exited → request finished + + +# --------------------------------------------------------------------------- +# Chunk re-admission +# --------------------------------------------------------------------------- + + +class TestChunkReAdmission: + """Verify that a chunk returning from rank N-1 is re-admitted to rank 0.""" + + def test_re_admission_priority_over_new_chunk(self): + """With K=2, M=2, N=2: after chunk 0 exits rank 1 at ms1, + it should re-enter rank 0 at ms2 (priority over admitting chunk 1).""" + pp_size, num_chunks, num_steps = 2, 2, 2 + sched = StreamBatchScheduler() + sched.initialize(_make_config(pp_size)) + sched.add_request(_make_request("r1", num_chunks=num_chunks, num_inference_steps=num_steps)) + + # ms0: [chunk0, idle] + out0 = sched.schedule() + assert out0.per_rank_assignment[0].chunk_idx == 0 + assert out0.per_rank_assignment[1] is None + sched.update_from_output(out0, _simulate_last_rank_output(out0, pp_size, num_steps)) + + # ms1: [chunk1, chunk0] — chunk 0 reaches rank 1 and completes step 0. + out1 = sched.schedule() + assert out1.per_rank_assignment[0].chunk_idx == 1 + assert out1.per_rank_assignment[1].chunk_idx == 0 + sched.update_from_output(out1, _simulate_last_rank_output(out1, pp_size, num_steps)) + + # ms2: chunk 0 should be re-admitted (step 1), NOT chunk 1 continuing. + out2 = sched.schedule() + assert out2.per_rank_assignment[0].chunk_idx == 0 + assert out2.per_rank_assignment[0].step_index == 1 + + def test_chunk_not_readmitted_after_completion(self): + """A chunk that finished all denoising steps should NOT be re-admitted.""" + pp_size, num_chunks, num_steps = 1, 1, 2 + sched = StreamBatchScheduler() + sched.initialize(_make_config(pp_size)) + sched.add_request(_make_request("r1", num_chunks=num_chunks, num_inference_steps=num_steps)) + + # Run step 0. + out0 = sched.schedule() + runner0 = _make_runner_output(req_id="r1", step_index=1, chunk_idx=0) + sched.update_from_output(out0, runner0) + + # Run step 1 (final). + out1 = sched.schedule() + runner1 = _make_runner_output( + req_id="r1", step_index=2, chunk_idx=0, chunk_completed=True, finished=True, + ) + finished = sched.update_from_output(out1, runner1) + assert "r1" in finished + + # No more requests. + assert not sched.has_requests() + + +# --------------------------------------------------------------------------- +# Completion ordering +# --------------------------------------------------------------------------- + + +class TestCompletionOrdering: + """Verify chunks complete in admission order (FIFO).""" + + def test_chunks_complete_in_fifo_order(self): + pp_size, num_chunks, num_steps = 2, 3, 1 + sched = StreamBatchScheduler() + sched.initialize(_make_config(pp_size)) + sched.add_request(_make_request("r1", num_chunks=num_chunks, num_inference_steps=num_steps)) + + completed_chunks: list[int] = [] + for _ in range(20): + out = sched.schedule() + runner = _simulate_last_rank_output(out, pp_size, num_steps) + if runner.chunk_completed: + completed_chunks.append(runner.chunk_idx) + finished = sched.update_from_output(out, runner) + if finished: + break + + assert completed_chunks == [0, 1, 2] + + +# --------------------------------------------------------------------------- +# Request finished signal +# --------------------------------------------------------------------------- + + +class TestRequestFinished: + def test_finished_after_all_chunks(self): + pp_size, num_chunks, num_steps = 1, 3, 1 + sched = StreamBatchScheduler() + sched.initialize(_make_config(pp_size)) + sched.add_request(_make_request("r1", num_chunks=num_chunks, num_inference_steps=num_steps)) + + finished_set = set() + for _ in range(20): + out = sched.schedule() + runner = _simulate_last_rank_output(out, pp_size, num_steps) + # On the last chunk, set finished=True (simulating the runner's merge). + if runner.chunk_completed: + progress = sched._chunk_progress.get("r1") + if progress and progress.chunks_completed + 1 >= num_chunks: + runner.finished = True + finished_set = sched.update_from_output(out, runner) + if finished_set: + break + + assert "r1" in finished_set + assert not sched.has_requests() + + def test_not_finished_before_all_chunks(self): + pp_size, num_chunks, num_steps = 1, 3, 1 + sched = StreamBatchScheduler() + sched.initialize(_make_config(pp_size)) + sched.add_request(_make_request("r1", num_chunks=num_chunks, num_inference_steps=num_steps)) + + # Process only 2 of 3 chunks. + for i in range(2): + out = sched.schedule() + runner = _make_runner_output( + req_id="r1", step_index=1, chunk_idx=i, chunk_completed=True, + ) + finished = sched.update_from_output(out, runner) + assert not finished + + assert sched.has_requests() + + +# --------------------------------------------------------------------------- +# Sequential requests +# --------------------------------------------------------------------------- + + +class TestSequentialRequests: + """Second request is processed after the first finishes.""" + + def test_second_request_starts_after_first(self): + pp_size, num_steps = 1, 1 + sched = StreamBatchScheduler() + sched.initialize(_make_config(pp_size)) + sched.add_request(_make_request("r1", num_chunks=1, num_inference_steps=num_steps)) + sched.add_request(_make_request("r2", num_chunks=1, num_inference_steps=num_steps)) + + # Process r1. + out1 = sched.schedule() + assert out1.per_rank_assignment[0].sched_req_id == "r1" + sched.update_from_output( + out1, + _make_runner_output(req_id="r1", step_index=1, chunk_idx=0, chunk_completed=True, finished=True), + ) + + # r1 finished. Next schedule should pick r2. + out2 = sched.schedule() + assert out2.per_rank_assignment[0].sched_req_id == "r2" \ No newline at end of file From 95a2bd153b2feb59258bc34a6b55fb05c64e4305 Mon Sep 17 00:00:00 2001 From: Mahdi Nasser <94046147+mnasser02@users.noreply.github.com> Date: Mon, 20 Apr 2026 10:17:31 +0200 Subject: [PATCH 11/53] Add micro-step execution pipeline tests Signed-off-by: Mahdi Nasser <94046147+mnasser02@users.noreply.github.com> --- .../test_diffusion_micro_step_pipeline.py | 486 ++++++++++++++++++ 1 file changed, 486 insertions(+) create mode 100644 tests/diffusion/test_diffusion_micro_step_pipeline.py diff --git a/tests/diffusion/test_diffusion_micro_step_pipeline.py b/tests/diffusion/test_diffusion_micro_step_pipeline.py new file mode 100644 index 00000000000..df87bdc10a4 --- /dev/null +++ b/tests/diffusion/test_diffusion_micro_step_pipeline.py @@ -0,0 +1,486 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Tests for micro-step (temporal PP) execution across runner / worker / executor / engine.""" + +from __future__ import annotations + +import copy +import queue +import threading +from contextlib import contextmanager +from types import SimpleNamespace + +import pytest +import torch +from pytest_mock import MockerFixture + +import vllm_omni.diffusion.worker.diffusion_model_runner as model_runner_module +from vllm_omni.diffusion.data import DiffusionOutput, DiffusionParallelConfig +from vllm_omni.diffusion.diffusion_engine import DiffusionEngine +from vllm_omni.diffusion.executor.multiproc_executor import MultiprocDiffusionExecutor +from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.diffusion.sched import StreamBatchScheduler +from vllm_omni.diffusion.sched.interface import ( + CachedRequestData, + DiffusionSchedulerOutput, + NewRequestData, + RankTask, +) +from vllm_omni.diffusion.worker.diffusion_model_runner import DiffusionModelRunner +from vllm_omni.diffusion.worker.diffusion_worker import DiffusionWorker +from vllm_omni.diffusion.worker.utils import RunnerOutput +from vllm_omni.inputs.data import OmniDiffusionSamplingParams + +pytestmark = [pytest.mark.core_model, pytest.mark.cpu, pytest.mark.diffusion] + + +# --------------------------------------------------------------------------- +# Helpers & fixtures +# --------------------------------------------------------------------------- + + +@contextmanager +def _noop_forward_context(*args, **kwargs): + del args, kwargs + yield + + +class _FakeScheduler: + """Minimal scheduler deepcopyable and tracks step_index.""" + + def __init__(self): + self._step_index = 0 + + def step(self, noise_pred, t, latents, return_dict=False): + del t, return_dict + self._step_index += 1 + return (latents + noise_pred,) + + +class _MicroStepPipeline: + """Minimal pipeline stub supporting micro-step execution.""" + + supports_step_execution = True + + def __init__(self): + self.prepare_calls = 0 + self.denoise_calls = 0 + self.scheduler_calls = 0 + self.decode_calls = 0 + self.sync_calls = 0 + self.scheduler = _FakeScheduler() + + def prepare_encode(self, state, **kwargs): + del kwargs + self.prepare_calls += 1 + n = state.sampling.num_inference_steps + state.timesteps = [torch.tensor(float(n - i)) for i in range(n)] + state.latents = torch.zeros((1, 1, 2, 2, 2)) # [B, C, T, H, W] video-like + state.step_index = 0 + state.scheduler = copy.deepcopy(self.scheduler) + return state + + def denoise_step(self, state, **kwargs): + del kwargs + self.denoise_calls += 1 + return torch.ones_like(state.latents) + + def step_scheduler(self, state, noise_pred, **kwargs): + del noise_pred, kwargs + self.scheduler_calls += 1 + state.step_index += 1 + + def post_decode(self, state, **kwargs): + del kwargs + self.decode_calls += 1 + # Produce a per-chunk video tensor uniquely tagged by decode call count + # so we can verify concatenation order downstream. + return DiffusionOutput(output=torch.full((1, 1, 2, 2, 2), float(self.decode_calls))) + + def sync_pp_send(self): + self.sync_calls += 1 + + +def _make_pp_group(rank: int, world_size: int) -> SimpleNamespace: + """Mock PP group for the runner's get_pp_group() call.""" + return SimpleNamespace( + rank_in_group=rank, + is_first_rank=(rank == 0), + is_last_rank=(rank == world_size - 1), + ) + + +def _make_runner(pp_size: int = 1, pp_rank: int = 0) -> DiffusionModelRunner: + runner = object.__new__(DiffusionModelRunner) + runner.vllm_config = object() + runner.od_config = SimpleNamespace( + cache_backend=None, + parallel_config=SimpleNamespace(use_hsdp=False, pipeline_parallel_size=pp_size), + ) + runner.device = torch.device("cpu") + runner.pipeline = _MicroStepPipeline() + runner.cache_backend = None + runner.offload_backend = None + runner.state_cache = {} + runner.kv_transfer_manager = SimpleNamespace() + runner._pp_group = _make_pp_group(pp_rank, pp_size) + return runner + + +def _install_pp_group_stub(monkeypatch, rank: int, world_size: int) -> None: + """Replace get_pp_group in the runner module with a constant stub.""" + monkeypatch.setattr( + model_runner_module, + "get_pp_group", + lambda: _make_pp_group(rank, world_size), + ) + + +def _make_micro_step_request( + num_chunks: int = 1, num_inference_steps: int = 2 +) -> OmniDiffusionRequest: + return OmniDiffusionRequest( + prompts=["a prompt"], + sampling_params=OmniDiffusionSamplingParams( + num_inference_steps=num_inference_steps, + num_chunks=num_chunks, + seed=42, + ), + request_ids=["req-1"], + ) + + +def _make_micro_step_scheduler_output( + task: RankTask | None, + pp_size: int, + req: OmniDiffusionRequest | None = None, + step_id: int = 0, + finished_req_ids: set[str] | None = None, +) -> DiffusionSchedulerOutput: + """Scheduler output with a single-rank assignment (the rest idle).""" + assignment: list[RankTask | None] = [None] * pp_size + if task is not None: + # For the runner we're simulating, the task is at the rank the runner reports. + # Tests set up pp_rank separately via monkeypatch; task is placed at rank 0 by default. + assignment[0] = task + + new_reqs = [] + cached = CachedRequestData.make_empty() + if req is not None: + new_reqs = [NewRequestData(sched_req_id="req-1", req=req)] + else: + cached = CachedRequestData(sched_req_ids=["req-1"]) + + return DiffusionSchedulerOutput( + step_id=step_id, + scheduled_new_reqs=new_reqs, + scheduled_cached_reqs=cached, + finished_req_ids=set() if finished_req_ids is None else set(finished_req_ids), + num_running_reqs=1, + num_waiting_reqs=0, + per_rank_assignment=assignment, + ) + + +def _make_engine(scheduler, execute_fn=None, stream_batch: bool = True) -> DiffusionEngine: + engine = object.__new__(DiffusionEngine) + engine.od_config = SimpleNamespace(model_class_name="Wan22Pipeline") + engine.pre_process_func = None + engine.post_process_func = None + engine.scheduler = scheduler + engine.execute_fn = execute_fn + engine.stream_batch = stream_batch + engine.step_execution = True + engine._rpc_lock = threading.RLock() + engine.abort_queue = queue.Queue() + return engine + + +# --------------------------------------------------------------------------- +# Runner +# --------------------------------------------------------------------------- + + +class TestMicroStepRunner: + """DiffusionModelRunner.execute_micro_step""" + + def test_single_chunk_completes_and_returns_merged(self, monkeypatch): + runner = _make_runner(pp_size=1, pp_rank=0) + _install_pp_group_stub(monkeypatch, rank=0, world_size=1) + monkeypatch.setattr(model_runner_module, "set_forward_context", _noop_forward_context) + req = _make_micro_step_request(num_chunks=1, num_inference_steps=2) + + # Step 0. + out0 = DiffusionModelRunner.execute_micro_step( + runner, + _make_micro_step_scheduler_output( + RankTask(sched_req_id="req-1", chunk_idx=0, step_index=0), pp_size=1, req=req, + ), + ) + assert out0.chunk_idx == 0 + assert out0.step_index == 1 + assert out0.chunk_completed is False + assert out0.finished is False + assert out0.result is None + + # Step 1 (completes the chunk — with single rank, single chunk, this finishes the request). + out1 = DiffusionModelRunner.execute_micro_step( + runner, + _make_micro_step_scheduler_output( + RankTask(sched_req_id="req-1", chunk_idx=0, step_index=1), pp_size=1, + ), + ) + assert out1.chunk_idx == 0 + assert out1.step_index == 2 + assert out1.chunk_completed is True + assert out1.finished is True + assert out1.result is not None + assert runner.pipeline.decode_calls == 1 + # State cache should be cleared once the request completes. + assert "req-1" not in runner.state_cache + + def test_multi_chunk_produces_concatenated_result(self, monkeypatch): + runner = _make_runner(pp_size=1, pp_rank=0) + _install_pp_group_stub(monkeypatch, rank=0, world_size=1) + monkeypatch.setattr(model_runner_module, "set_forward_context", _noop_forward_context) + req = _make_micro_step_request(num_chunks=2, num_inference_steps=1) + + # Chunk 0, step 0 (completes chunk 0). + DiffusionModelRunner.execute_micro_step( + runner, + _make_micro_step_scheduler_output( + RankTask(sched_req_id="req-1", chunk_idx=0, step_index=0), pp_size=1, req=req, + ), + ) + # Chunk 1, step 0 (completes chunk 1 → merges → finishes request). + final = DiffusionModelRunner.execute_micro_step( + runner, + _make_micro_step_scheduler_output( + RankTask(sched_req_id="req-1", chunk_idx=1, step_index=0), pp_size=1, + ), + ) + assert final.finished is True + assert final.result is not None + # Two chunks concatenated along time dim (dim 2): [1, 1, 4, 2, 2]. + assert final.result.output.shape == (1, 1, 4, 2, 2) + # First chunk's frames tagged 1.0, second chunk's tagged 2.0. + assert torch.all(final.result.output[:, :, :2] == 1.0) + assert torch.all(final.result.output[:, :, 2:] == 2.0) + assert runner.pipeline.decode_calls == 2 + + def test_idle_rank_returns_early_and_syncs(self, monkeypatch): + runner = _make_runner(pp_size=2, pp_rank=1) + _install_pp_group_stub(monkeypatch, rank=1, world_size=2) + monkeypatch.setattr(model_runner_module, "set_forward_context", _noop_forward_context) + req = _make_micro_step_request(num_chunks=1, num_inference_steps=2) + + # Assignment puts task at rank 0 only; rank 1 (this runner) is idle. + sched_output = _make_micro_step_scheduler_output( + RankTask(sched_req_id="req-1", chunk_idx=0, step_index=0), pp_size=2, req=req, + ) + out = DiffusionModelRunner.execute_micro_step(runner, sched_output) + + assert out.chunk_idx is None + assert out.step_index is None + assert out.finished is False + # Idle path still drains pending sends. + assert runner.pipeline.sync_calls == 1 + # prepare_encode must run on idle ranks so shared state is ready when + # the rank later receives a chunk. + assert runner.pipeline.prepare_calls == 1 + + def test_rejects_missing_per_rank_assignment(self, monkeypatch): + runner = _make_runner(pp_size=1, pp_rank=0) + _install_pp_group_stub(monkeypatch, rank=0, world_size=1) + req = _make_micro_step_request() + + sched_output = DiffusionSchedulerOutput( + step_id=0, + scheduled_new_reqs=[NewRequestData(sched_req_id="req-1", req=req)], + scheduled_cached_reqs=CachedRequestData.make_empty(), + finished_req_ids=set(), + num_running_reqs=1, + num_waiting_reqs=0, + per_rank_assignment=None, + ) + + with pytest.raises(ValueError, match="per_rank_assignment"): + DiffusionModelRunner.execute_micro_step(runner, sched_output) + + +# --------------------------------------------------------------------------- +# Worker +# --------------------------------------------------------------------------- + + +class TestMicroStepWorker: + """DiffusionWorker.execute_micro_step""" + + def test_delegates_to_model_runner(self): + worker = object.__new__(DiffusionWorker) + expected = RunnerOutput(req_id="req-1", chunk_idx=0, step_index=1) + scheduler_output = SimpleNamespace( + scheduled_new_reqs=[ + SimpleNamespace( + req=SimpleNamespace(sampling_params=SimpleNamespace(lora_request=None)) + ) + ] + ) + worker.lora_manager = None + worker.model_runner = SimpleNamespace( + execute_micro_step=lambda arg: expected if arg is scheduler_output else None + ) + + output = DiffusionWorker.execute_micro_step(worker, scheduler_output) + assert output is expected + + def test_rejects_lora_requests(self): + worker = object.__new__(DiffusionWorker) + scheduler_output = SimpleNamespace( + scheduled_new_reqs=[ + SimpleNamespace( + req=SimpleNamespace(sampling_params=SimpleNamespace(lora_request=object())) + ) + ] + ) + worker.lora_manager = None + worker.model_runner = SimpleNamespace(execute_micro_step=lambda arg: RunnerOutput(req_id="req-1")) + + with pytest.raises(ValueError, match="does not support LoRA"): + DiffusionWorker.execute_micro_step(worker, scheduler_output) + + +# --------------------------------------------------------------------------- +# Executor +# --------------------------------------------------------------------------- + + +class TestMicroStepExecutor: + """MultiprocDiffusionExecutor.execute_micro_step""" + + def test_passes_through_runner_output_and_uses_last_pp_rank(self, mocker: MockerFixture): + executor = object.__new__(MultiprocDiffusionExecutor) + executor._ensure_open = lambda: None + executor.od_config = SimpleNamespace( + parallel_config=DiffusionParallelConfig(pipeline_parallel_size=4), + ) + expected = RunnerOutput(req_id="req-1", chunk_idx=0, step_index=1, chunk_completed=False) + rpc_mock = mocker.Mock(return_value=expected) + executor.collective_rpc = rpc_mock + + sched_output = DiffusionSchedulerOutput( + step_id=0, + scheduled_new_reqs=[], + scheduled_cached_reqs=CachedRequestData(sched_req_ids=["req-1"]), + finished_req_ids=set(), + num_running_reqs=1, + num_waiting_reqs=0, + per_rank_assignment=[None, None, None, RankTask("req-1", 0, 0)], + ) + + result = MultiprocDiffusionExecutor.execute_micro_step(executor, sched_output) + + assert result is expected + # Reply is collected from the last PP rank (index pp_size - 1). + rpc_mock.assert_called_once() + kwargs = rpc_mock.call_args.kwargs + assert kwargs["unique_reply_rank"] == 3 # pp_size=4 → last rank=3 + assert kwargs["exec_all_ranks"] is True + + +# --------------------------------------------------------------------------- +# Engine (full loop: scheduler + executor + engine) +# --------------------------------------------------------------------------- + + +class TestMicroStepEngine: + """Full stream-batch flow through DiffusionEngine.add_req_and_wait_for_response.""" + + def _make_scheduler(self, pp_size: int = 1) -> StreamBatchScheduler: + scheduler = StreamBatchScheduler() + scheduler.initialize( + SimpleNamespace(parallel_config=DiffusionParallelConfig(pipeline_parallel_size=pp_size)) + ) + return scheduler + + def _make_execute_fn(self, num_chunks: int, num_steps: int): + """Simulate the executor: advance each micro-step's last-rank chunk.""" + completed = {"n": 0} + + def execute_fn(sched_output): + assignment = sched_output.per_rank_assignment + if assignment is None: + return RunnerOutput(req_id="") + task = assignment[-1] # last rank's slot + if task is None: + req_id = sched_output.scheduled_req_ids[0] if sched_output.scheduled_req_ids else "" + return RunnerOutput(req_id=req_id) + + new_step = task.step_index + 1 + chunk_completed = new_step >= num_steps + finished = False + result = None + if chunk_completed: + completed["n"] += 1 + if completed["n"] >= num_chunks: + finished = True + result = DiffusionOutput(output=torch.tensor([float(completed["n"])])) + return RunnerOutput( + req_id=task.sched_req_id, + step_index=new_step, + chunk_idx=task.chunk_idx, + chunk_completed=chunk_completed, + finished=finished, + result=result, + ) + + return execute_fn + + def test_single_chunk_completes(self): + scheduler = self._make_scheduler(pp_size=1) + engine = _make_engine(scheduler, execute_fn=self._make_execute_fn(num_chunks=1, num_steps=2)) + request = _make_micro_step_request(num_chunks=1, num_inference_steps=2) + + output = engine.add_req_and_wait_for_response(request) + + assert output.error is None + assert output.aborted is False + assert torch.equal(output.output, torch.tensor([1.0])) + + def test_multi_chunk_completes(self): + scheduler = self._make_scheduler(pp_size=1) + engine = _make_engine(scheduler, execute_fn=self._make_execute_fn(num_chunks=3, num_steps=2)) + request = _make_micro_step_request(num_chunks=3, num_inference_steps=2) + + output = engine.add_req_and_wait_for_response(request) + + assert output.error is None + assert torch.equal(output.output, torch.tensor([3.0])) # completed 3 chunks + + def test_execute_fn_exception_returns_error(self): + scheduler = self._make_scheduler(pp_size=1) + + def failing(_): + raise RuntimeError("gpu on fire") + + engine = _make_engine(scheduler, execute_fn=failing) + output = engine.add_req_and_wait_for_response(_make_micro_step_request()) + + assert output.output is None + assert "gpu on fire" in output.error + + def test_pipeline_fills_with_pp_gt_1(self): + """With PP>1, scheduler drives warmup/steady/cooldown; engine sees final merged output.""" + pp_size = 3 + num_chunks = 4 + num_steps = 2 + scheduler = self._make_scheduler(pp_size=pp_size) + engine = _make_engine( + scheduler, execute_fn=self._make_execute_fn(num_chunks=num_chunks, num_steps=num_steps) + ) + request = _make_micro_step_request(num_chunks=num_chunks, num_inference_steps=num_steps) + + output = engine.add_req_and_wait_for_response(request) + + assert output.error is None + assert torch.equal(output.output, torch.tensor([float(num_chunks)])) \ No newline at end of file From e18732da0d9456b4e065a2a5ec994f543b1c472b Mon Sep 17 00:00:00 2001 From: Mahdi Nasser <94046147+mnasser02@users.noreply.github.com> Date: Tue, 21 Apr 2026 09:01:29 +0200 Subject: [PATCH 12/53] Add stream_batch arg Signed-off-by: Mahdi Nasser <94046147+mnasser02@users.noreply.github.com> --- vllm_omni/engine/async_omni_engine.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm_omni/engine/async_omni_engine.py b/vllm_omni/engine/async_omni_engine.py index 420f808dc2f..b6828c7add9 100644 --- a/vllm_omni/engine/async_omni_engine.py +++ b/vllm_omni/engine/async_omni_engine.py @@ -1913,6 +1913,7 @@ def _create_default_diffusion_stage_cfg(kwargs: dict[str, Any]) -> list: "model_class_name": kwargs.get("model_class_name", None), "additional_config": kwargs.get("additional_config", None), "step_execution": kwargs.get("step_execution", False), + "stream_batch": kwargs.get("stream_batch", False), "vae_use_slicing": kwargs.get("vae_use_slicing", False), "vae_use_tiling": kwargs.get("vae_use_tiling", False), "cache_backend": cache_backend, From 90a9bfe8606c9739320595b774d0d2909da070df Mon Sep 17 00:00:00 2001 From: Mahdi Nasser <94046147+mnasser02@users.noreply.github.com> Date: Tue, 21 Apr 2026 09:02:16 +0200 Subject: [PATCH 13/53] Set reply rank to 0 Signed-off-by: Mahdi Nasser <94046147+mnasser02@users.noreply.github.com> --- tests/diffusion/test_diffusion_micro_step_pipeline.py | 6 +++--- vllm_omni/diffusion/executor/multiproc_executor.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/diffusion/test_diffusion_micro_step_pipeline.py b/tests/diffusion/test_diffusion_micro_step_pipeline.py index df87bdc10a4..f4b7af3f064 100644 --- a/tests/diffusion/test_diffusion_micro_step_pipeline.py +++ b/tests/diffusion/test_diffusion_micro_step_pipeline.py @@ -358,7 +358,7 @@ def test_rejects_lora_requests(self): class TestMicroStepExecutor: """MultiprocDiffusionExecutor.execute_micro_step""" - def test_passes_through_runner_output_and_uses_last_pp_rank(self, mocker: MockerFixture): + def test_passes_through_runner_output_and_uses_first_pp_rank(self, mocker: MockerFixture): executor = object.__new__(MultiprocDiffusionExecutor) executor._ensure_open = lambda: None executor.od_config = SimpleNamespace( @@ -381,10 +381,10 @@ def test_passes_through_runner_output_and_uses_last_pp_rank(self, mocker: Mocker result = MultiprocDiffusionExecutor.execute_micro_step(executor, sched_output) assert result is expected - # Reply is collected from the last PP rank (index pp_size - 1). + # Reply is collected from the first PP rank (index 0). rpc_mock.assert_called_once() kwargs = rpc_mock.call_args.kwargs - assert kwargs["unique_reply_rank"] == 3 # pp_size=4 → last rank=3 + assert kwargs["unique_reply_rank"] == 0 assert kwargs["exec_all_ranks"] is True diff --git a/vllm_omni/diffusion/executor/multiproc_executor.py b/vllm_omni/diffusion/executor/multiproc_executor.py index 21b6fe5f9ce..94ec2e0a18d 100644 --- a/vllm_omni/diffusion/executor/multiproc_executor.py +++ b/vllm_omni/diffusion/executor/multiproc_executor.py @@ -357,11 +357,11 @@ def execute_micro_step(self, scheduler_output: DiffusionSchedulerOutput) -> Runn from vllm_omni.diffusion.worker.utils import RunnerOutput self._ensure_open() - last_pp_rank = self.od_config.parallel_config.pipeline_parallel_size - 1 + result = self.collective_rpc( "execute_micro_step", args=(scheduler_output,), - unique_reply_rank=last_pp_rank, + unique_reply_rank=0, exec_all_ranks=True, ) if not isinstance(result, RunnerOutput): From 53742660651044db1a27032f3b98da15aeae8929 Mon Sep 17 00:00:00 2001 From: Mahdi Nasser <94046147+mnasser02@users.noreply.github.com> Date: Tue, 21 Apr 2026 13:43:16 +0200 Subject: [PATCH 14/53] Fix blocking send/recv resulting in deadlock by registering metadata instead of sync send/recv Signed-off-by: Mahdi Nasser <94046147+mnasser02@users.noreply.github.com> --- .../distributed/group_coordinator.py | 93 +++++++++++++++++++ .../distributed/pipeline_parallel.py | 57 ++++++++++-- vllm_omni/diffusion/sched/interface.py | 6 +- .../diffusion/sched/stream_batch_scheduler.py | 9 +- .../worker/diffusion_model_runner.py | 53 ++++------- vllm_omni/diffusion/worker/utils.py | 2 +- 6 files changed, 165 insertions(+), 55 deletions(-) diff --git a/vllm_omni/diffusion/distributed/group_coordinator.py b/vllm_omni/diffusion/distributed/group_coordinator.py index 5c5b092ef21..beb6b487315 100644 --- a/vllm_omni/diffusion/distributed/group_coordinator.py +++ b/vllm_omni/diffusion/distributed/group_coordinator.py @@ -5,6 +5,7 @@ # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. import pickle from collections import namedtuple +from dataclasses import dataclass from typing import Any import torch @@ -23,6 +24,14 @@ env_info = envs.PACKAGES_CHECKER.get_packages_info() +@dataclass +class _RegisteredRecvChannel: + """Pre-allocated recv buffer for a registered P2P channel.""" + + tensor_dict: dict[str, Any] + ordered_tensors: list[tuple[str, torch.Tensor]] + + def _split_tensor_dict( tensor_dict: dict[str, torch.Tensor | Any], prefix: str = "" ) -> tuple[list[tuple[str, Any]], list[torch.Tensor]]: @@ -123,6 +132,9 @@ def __init__( self.device = current_omni_platform.get_torch_device(local_rank) + # Registered recv buffers for zero-metadata P2P channels. + self._recv_channels: dict[str, _RegisteredRecvChannel] = {} + @property def first_rank(self): """Return the global rank of the first process in the group""" @@ -501,6 +513,87 @@ def irecv_tensor_dict( return tensor_dict, handles, [] + def register_recv_channel( + self, + name: str, + tensor_specs: list[tuple[str, torch.Size, torch.dtype, str]], + ) -> None: + """Pre-allocate recv buffers for a channel. Local only, no comm. + + Each spec is ``(flattened_key, shape, dtype, device_type)``; + ``device_type`` is ``"cpu"`` or any type accepted by ``torch.device`` + (the coordinator's device is used for non-cpu types). + """ + tensor_dict: dict[str, Any] = {} + ordered: list[tuple[str, torch.Tensor]] = [] + for key, shape, dtype, device_type in tensor_specs: + device = torch.device("cpu") if device_type == "cpu" else self.device + tensor = torch.empty(shape, dtype=dtype, device=device) + _update_nested_dict(tensor_dict, key, tensor) + ordered.append((key, tensor)) + self._recv_channels[name] = _RegisteredRecvChannel(tensor_dict=tensor_dict, ordered_tensors=ordered) + + def isend_registered( + self, + name: str, + tensor_dict: dict[str, torch.Tensor | Any], + dst: int | None = None, + ) -> list[torch.distributed.Work]: + """NCCL-only isend — no metadata handshake. ``name`` is a label for + symmetry with ``irecv_registered``; peer must have registered the + matching recv channel with the same name and specs. + + Caller must ``.wait()`` each returned handle before reusing the source tensors. + """ + if not torch.distributed.is_initialized() or self.world_size == 1: + return [] + + if dst is None: + dst = self.group_next_rank + assert dst < self.world_size, f"Invalid dst rank ({dst})" + + _, tensor_list = _split_tensor_dict(tensor_dict) + + handles: list[torch.distributed.Work] = [] + for tensor in tensor_list: + if tensor.numel() == 0: + continue + group = self.cpu_group if tensor.is_cpu else self.device_group + handle = torch.distributed.isend(tensor, dst=self.ranks[dst], group=group) + if tensor.is_cuda: + tensor.record_stream(torch.cuda.current_stream(tensor.device)) + handles.append(handle) + return handles + + def irecv_registered( + self, + name: str, + src: int | None = None, + ) -> tuple[dict[str, torch.Tensor | Any], list[torch.distributed.Work], list]: + """NCCL-only irecv into the pre-allocated buffer for a registered channel. + + Returns ``(tensor_dict, handles, postproc)`` compatible with + ``AsyncLatents`` / ``AsyncIntermediateTensors``. Same buffer dict is + returned on every call — caller resolves the wrapper before the next + ``irecv_registered`` on the same channel overwrites it (``sync_pp_send`` + invariant on the recv side). + """ + channel = self._recv_channels[name] + if not torch.distributed.is_initialized() or self.world_size == 1: + return channel.tensor_dict, [], [] + + if src is None: + src = self.group_prev_rank + assert src < self.world_size, f"Invalid src rank ({src})" + + handles: list[torch.distributed.Work] = [] + for _, tensor in channel.ordered_tensors: + if tensor.numel() == 0: + continue + group = self.cpu_group if tensor.is_cpu else self.device_group + handles.append(torch.distributed.irecv(tensor, src=self.ranks[src], group=group)) + return channel.tensor_dict, handles, [] + def send_tensor_dict( self, tensor_dict: dict[str, torch.Tensor | Any], diff --git a/vllm_omni/diffusion/distributed/pipeline_parallel.py b/vllm_omni/diffusion/distributed/pipeline_parallel.py index b39d8052e48..45c888ed2c4 100644 --- a/vllm_omni/diffusion/distributed/pipeline_parallel.py +++ b/vllm_omni/diffusion/distributed/pipeline_parallel.py @@ -176,6 +176,29 @@ def _sync_pp_send(self) -> None: handle.wait() self._pp_send_work = [] + def _pp_it_channel_specs(self, state: Any) -> list[tuple[str, torch.Size, torch.dtype, str]]: + """IT tensor specs for the ``pp_its`` registered recv channel. + + Default: one ``hidden_states`` tensor shaped like ``state.latents``. + Override for models whose IT layout differs. + """ + latents = state.latents + return [("hidden_states", latents.shape, latents.dtype, latents.device.type)] + + def register_pp_channels(self, state: Any) -> None: + """Pre-allocate recv buffers for the ``pp_its`` and ``pp_latents`` channels.""" + pp_group = get_pp_group() + if pp_group.world_size == 1: + return + if not pp_group.is_first_rank: + pp_group.register_recv_channel("pp_its", self._pp_it_channel_specs(state)) + if pp_group.is_first_rank: + latents = state.latents + pp_group.register_recv_channel( + "pp_latents", + [("latents", latents.shape, latents.dtype, latents.device.type)], + ) + def predict_noise_maybe_with_cfg( self, do_true_cfg: bool, @@ -217,18 +240,32 @@ def predict_noise_maybe_with_cfg( # Sequential CFG (or no CFG): this PP pipeline handles all branches. all_kwargs = [positive_kwargs] + ([negative_kwargs] if do_true_cfg else []) + registered_comms = getattr(self, "_registered_pp_comms", False) + # Non-first ranks receive intermediate tensors asynchronously n = len(all_kwargs) + if registered_comms and n > 1: + raise NotImplementedError( + "registered_comms currently supports a single branch (n=1). " + "Sequential CFG (n=2) requires per-branch channels." + ) + its: list[AsyncIntermediateTensors | None] = [None] * n if not pp_group.is_first_rank: for i in range(n): - its[i] = AsyncIntermediateTensors(*pp_group.irecv_tensor_dict()) + if registered_comms: + its[i] = AsyncIntermediateTensors(*pp_group.irecv_registered("pp_its")) + else: + its[i] = AsyncIntermediateTensors(*pp_group.irecv_tensor_dict()) if not pp_group.is_last_rank: # First / middle rank: run partial forwards and propagate ITs downstream. for kwargs, it in zip(all_kwargs, its): result = self.predict_noise(**kwargs, intermediate_tensors=it) - self._pp_send_work.extend(pp_group.isend_tensor_dict(result.tensors)) + if registered_comms: + self._pp_send_work.extend(pp_group.isend_registered("pp_its", result.tensors)) + else: + self._pp_send_work.extend(pp_group.isend_tensor_dict(result.tensors)) return None # Last rank: run full forward @@ -280,12 +317,18 @@ def scheduler_step_maybe_with_cfg( noise_pred, t, latents, do_true_cfg, per_request_scheduler, generator ) + registered_comms = getattr(self, "_registered_pp_comms", False) + pp_group = get_pp_group() if pp_group.is_last_rank: - latents = super().scheduler_step_maybe_with_cfg( - noise_pred, t, latents, do_true_cfg, per_request_scheduler, generator - ) - self._pp_send_work = pp_group.isend_tensor_dict({"latents": latents}, dst=0) + latents = super().scheduler_step_maybe_with_cfg(noise_pred, t, latents, do_true_cfg, per_request_scheduler) + if registered_comms: + self._pp_send_work = pp_group.isend_registered("pp_latents", {"latents": latents}, dst=0) + else: + self._pp_send_work = pp_group.isend_tensor_dict({"latents": latents}, dst=0) elif pp_group.is_first_rank: - latents = AsyncLatents(*pp_group.irecv_tensor_dict(src=pp_group.world_size - 1)) + if registered_comms: + latents = AsyncLatents(*pp_group.irecv_registered("pp_latents", src=pp_group.world_size - 1)) + else: + latents = AsyncLatents(*pp_group.irecv_tensor_dict(src=pp_group.world_size - 1)) return latents diff --git a/vllm_omni/diffusion/sched/interface.py b/vllm_omni/diffusion/sched/interface.py index c94cd657ae8..f8c81f77640 100644 --- a/vllm_omni/diffusion/sched/interface.py +++ b/vllm_omni/diffusion/sched/interface.py @@ -108,14 +108,10 @@ def make_empty(cls) -> CachedRequestData: @dataclass class RankTask: - """One ``(request, chunk, step_index)`` triple assigned to a single PP rank for one micro-step. - - Used by ``StreamBatchScheduler`` to tell each rank which work to perform in the current micro-step. - """ + """Used by ``StreamBatchScheduler`` to tell each rank which work to perform in the current micro-step.""" sched_req_id: str chunk_idx: int - step_index: int @dataclass diff --git a/vllm_omni/diffusion/sched/stream_batch_scheduler.py b/vllm_omni/diffusion/sched/stream_batch_scheduler.py index 3893a095178..ea2d67de31a 100644 --- a/vllm_omni/diffusion/sched/stream_batch_scheduler.py +++ b/vllm_omni/diffusion/sched/stream_batch_scheduler.py @@ -34,7 +34,6 @@ class _InFlightChunk: """One chunk of an active request, tracked through the temporal pipeline.""" chunk_idx: int - step_index: int = 0 in_pipeline: bool = True # currently flowing through ranks (True between admission and exit) entered_rank0_at: int = -1 # micro-step at which the chunk last entered rank 0 @@ -151,7 +150,7 @@ def _advance_chunk_pipeline(self) -> None: # 1. Try to re-admit a returning chunk (FIFO oldest-first across requests). for progress in self._chunk_progress.values(): for chunk in progress.in_flight: - if not chunk.in_pipeline and chunk.step_index < progress.num_steps: + if not chunk.in_pipeline: chunk.in_pipeline = True chunk.entered_rank0_at = self._global_micro_step return # rank 0 is now taken @@ -161,7 +160,6 @@ def _advance_chunk_pipeline(self) -> None: if progress.chunks_admitted < progress.num_chunks: new_chunk = _InFlightChunk( chunk_idx=progress.chunks_admitted, - step_index=0, in_pipeline=True, entered_rank0_at=self._global_micro_step, ) @@ -185,7 +183,6 @@ def _build_assignment(self) -> list[RankTask | None]: assignment[r] = RankTask( sched_req_id=progress.sched_req_id, chunk_idx=chunk.chunk_idx, - step_index=chunk.step_index, ) return assignment @@ -195,9 +192,6 @@ def update_from_output(self, sched_output: DiffusionSchedulerOutput, output: Run if not self._chunk_progress or sched_output.per_rank_assignment is None: return set() - # Read per-chunk fields from RunnerOutput (set by execute_micro_step) - # to advance chunk state — same pattern as StepScheduler reading - # step_index / finished. terminal: dict[str, DiffusionRequestStatus] = {} if output.chunk_idx is not None: @@ -205,7 +199,6 @@ def update_from_output(self, sched_output: DiffusionSchedulerOutput, output: Run if progress is not None: chunk = self._find_chunk(progress, output.chunk_idx) if chunk is not None: - chunk.step_index = output.step_index chunk.in_pipeline = False if output.chunk_completed: progress.in_flight = [ diff --git a/vllm_omni/diffusion/worker/diffusion_model_runner.py b/vllm_omni/diffusion/worker/diffusion_model_runner.py index 6a717c42d5a..94e983e6785 100644 --- a/vllm_omni/diffusion/worker/diffusion_model_runner.py +++ b/vllm_omni/diffusion/worker/diffusion_model_runner.py @@ -497,6 +497,8 @@ def _merge_chunk_outputs(chunks: list[DiffusionOutput]) -> DiffusionOutput: Concatenates video tensors ``[B, C, T, H, W]`` along the temporal dimension (dim 2). + + NOTE: This is a temporary solution until streaming output is supported. """ if len(chunks) == 1: return chunks[0] @@ -517,14 +519,7 @@ def _get_or_create_chunk(state: DiffusionRequestState, chunk_idx: int) -> tuple[ return chunk, True def execute_micro_step(self, scheduler_output: DiffusionSchedulerOutput) -> RunnerOutput: - """Execute one temporal-PP micro-step. - - Each rank reads its own slot from ``scheduler_output.per_rank_assignment`` - and performs at most one local compute + send + recv via the existing - ``denoise_step`` / ``step_scheduler`` pipeline methods (with chunk-scoped - ``state.latents`` and ``state.step_index`` swapped in via - ``state.use_chunk``). Only the last PP rank emits ``chunk_events``. - """ + """Execute one temporal-PP micro-step.""" assert self.pipeline is not None, "Model not loaded. Call load_model() first." if not self.supports_step_mode(): raise ValueError("Current pipeline does not support step execution.") @@ -541,36 +536,33 @@ def execute_micro_step(self, scheduler_output: DiffusionSchedulerOutput) -> Runn with grad_context: state, is_new_request = self._update_states(scheduler_output) - # prepare_encode must run on ALL ranks when the request is first - # seen — even if this rank is idle this micro-step — so that shared - # state (prompt embeds, timesteps, scheduler) is available when the - # rank later receives its first chunk assignment. if is_new_request: if state.sampling.generator is None and state.sampling.seed is not None: gen_device = state.sampling.generator_device or ( "cpu" if self.device.type == "cpu" else self.device ) state.sampling.generator = torch.Generator(device=gen_device).manual_seed(state.sampling.seed) - with set_forward_context(vllm_config=self.vllm_config, omni_diffusion_config=self.od_config): - self.pipeline.prepare_encode(state) - - pp_group = get_pp_group() - pp_rank = pp_group.rank_in_group - task = assignment[pp_rank] - # Idle rank (warmup / cooldown / no active request). Still drain prior sends. - if task is None: - if hasattr(self.pipeline, "sync_pp_send"): - self.pipeline.sync_pp_send() - return RunnerOutput(req_id=state.req_id) with set_forward_context(vllm_config=self.vllm_config, omni_diffusion_config=self.od_config): + if is_new_request: + self.pipeline.prepare_encode(state) + pp_size = get_pp_group().world_size + self.pipeline._registered_pp_comms = pp_size > 1 + if pp_size > 1: + self.pipeline.register_pp_channels(state) + + pp_group = get_pp_group() + pp_rank = pp_group.rank_in_group + task = assignment[pp_rank] + + if task is None: + return RunnerOutput(req_id=state.req_id) chunk, is_new_chunk = self._get_or_create_chunk(state, task.chunk_idx) if is_new_chunk: # First chunk reuses the noise sampled by prepare_encode; - # subsequent chunks draw fresh noise of the same shape from - # the request's generator (deterministic, advances state). + # subsequent chunks draw fresh noise. # Each chunk gets its own scheduler deepcopy so multi-step # ODE solver state doesn't leak between chunks. chunk.latents = ( @@ -580,13 +572,6 @@ def execute_micro_step(self, scheduler_output: DiffusionSchedulerOutput) -> Runn ) chunk.scheduler = copy.deepcopy(state.scheduler) - # Sanity guard: scheduler's per-rank assignment must match what this - # rank currently believes about the chunk's progress. - assert chunk.step_index == task.step_index, ( - f"Stale chunk state on rank {pp_rank}: chunk_idx={task.chunk_idx} " - f"local step_index={chunk.step_index}, scheduler said {task.step_index}" - ) - with state.use_chunk(chunk): noise_pred = self.pipeline.denoise_step(state) if noise_pred is None and getattr(self.pipeline, "interrupt", False): @@ -603,8 +588,7 @@ def execute_micro_step(self, scheduler_output: DiffusionSchedulerOutput) -> Runn chunk_idx=task.chunk_idx, ) - # Only rank N-1 runs post_decode and tracks chunk completion. - if pp_group.is_last_rank and chunk_done: + if chunk_done: output.chunk_completed = True with state.use_chunk(chunk): chunk_output = self.pipeline.post_decode(state) @@ -616,5 +600,6 @@ def execute_micro_step(self, scheduler_output: DiffusionSchedulerOutput) -> Runn output.finished = True output.result = self._merge_chunk_outputs(completed) self._update_states_after(state, finished=True) + self.pipeline._registered_pp_comms = False return output diff --git a/vllm_omni/diffusion/worker/utils.py b/vllm_omni/diffusion/worker/utils.py index 0c8ac1ca703..7353dab84c1 100644 --- a/vllm_omni/diffusion/worker/utils.py +++ b/vllm_omni/diffusion/worker/utils.py @@ -184,7 +184,7 @@ class RunnerOutput: finished: bool = False result: DiffusionOutput | None = None - # ── Temporal-PP micro-step fields (set by execute_micro_step) ── + # ── Temporal-PP micro-step fields ── chunk_idx: int | None = None chunk_completed: bool = False From 864ccd12fb0c9b745c5c28e17d4a6ed169aaa182 Mon Sep 17 00:00:00 2001 From: Mahdi Nasser <94046147+mnasser02@users.noreply.github.com> Date: Tue, 21 Apr 2026 13:54:51 +0200 Subject: [PATCH 15/53] bugfix Signed-off-by: Mahdi Nasser <94046147+mnasser02@users.noreply.github.com> --- vllm_omni/diffusion/distributed/group_coordinator.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm_omni/diffusion/distributed/group_coordinator.py b/vllm_omni/diffusion/distributed/group_coordinator.py index beb6b487315..f8554781d28 100644 --- a/vllm_omni/diffusion/distributed/group_coordinator.py +++ b/vllm_omni/diffusion/distributed/group_coordinator.py @@ -782,6 +782,8 @@ def __init__( self.device = current_omni_platform.get_torch_device(local_rank) + self._recv_channels: dict[str, _RegisteredRecvChannel] = {} + self.recv_buffer_set: bool = False self.recv_tasks_queue: list[tuple[str, int]] = [] self.receiving_tasks: list[tuple[torch.distributed.Work, str, int]] = [] From 18daf556a427c78c7c0ddd6ec414cf15f13a5979 Mon Sep 17 00:00:00 2001 From: Mahdi Nasser <94046147+mnasser02@users.noreply.github.com> Date: Tue, 21 Apr 2026 14:04:19 +0200 Subject: [PATCH 16/53] bugfix Signed-off-by: Mahdi Nasser <94046147+mnasser02@users.noreply.github.com> --- .../distributed/pipeline_parallel.py | 21 ++++++++----------- 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/vllm_omni/diffusion/distributed/pipeline_parallel.py b/vllm_omni/diffusion/distributed/pipeline_parallel.py index 45c888ed2c4..7c0a61e903f 100644 --- a/vllm_omni/diffusion/distributed/pipeline_parallel.py +++ b/vllm_omni/diffusion/distributed/pipeline_parallel.py @@ -186,12 +186,15 @@ def _pp_it_channel_specs(self, state: Any) -> list[tuple[str, torch.Size, torch. return [("hidden_states", latents.shape, latents.dtype, latents.device.type)] def register_pp_channels(self, state: Any) -> None: - """Pre-allocate recv buffers for the ``pp_its`` and ``pp_latents`` channels.""" + """Pre-allocate recv buffers for the ``pp_its_{i}`` and ``pp_latents`` channels.""" pp_group = get_pp_group() if pp_group.world_size == 1: return if not pp_group.is_first_rank: - pp_group.register_recv_channel("pp_its", self._pp_it_channel_specs(state)) + cfg_parallel = get_classifier_free_guidance_world_size() > 1 + n_branches = 2 if (getattr(state, "do_true_cfg", False) and not cfg_parallel) else 1 + for i in range(n_branches): + pp_group.register_recv_channel(f"pp_its_{i}", self._pp_it_channel_specs(state)) if pp_group.is_first_rank: latents = state.latents pp_group.register_recv_channel( @@ -242,28 +245,22 @@ def predict_noise_maybe_with_cfg( registered_comms = getattr(self, "_registered_pp_comms", False) - # Non-first ranks receive intermediate tensors asynchronously + # Non-first ranks receive intermediate tensors asynchronously. n = len(all_kwargs) - if registered_comms and n > 1: - raise NotImplementedError( - "registered_comms currently supports a single branch (n=1). " - "Sequential CFG (n=2) requires per-branch channels." - ) - its: list[AsyncIntermediateTensors | None] = [None] * n if not pp_group.is_first_rank: for i in range(n): if registered_comms: - its[i] = AsyncIntermediateTensors(*pp_group.irecv_registered("pp_its")) + its[i] = AsyncIntermediateTensors(*pp_group.irecv_registered(f"pp_its_{i}")) else: its[i] = AsyncIntermediateTensors(*pp_group.irecv_tensor_dict()) if not pp_group.is_last_rank: # First / middle rank: run partial forwards and propagate ITs downstream. - for kwargs, it in zip(all_kwargs, its): + for i, (kwargs, it) in enumerate(zip(all_kwargs, its)): result = self.predict_noise(**kwargs, intermediate_tensors=it) if registered_comms: - self._pp_send_work.extend(pp_group.isend_registered("pp_its", result.tensors)) + self._pp_send_work.extend(pp_group.isend_registered(f"pp_its_{i}", result.tensors)) else: self._pp_send_work.extend(pp_group.isend_tensor_dict(result.tensors)) return None From ea12e8bf4813142124452f9345dcdb30dbf53f3e Mon Sep 17 00:00:00 2001 From: Mahdi Nasser <94046147+mnasser02@users.noreply.github.com> Date: Tue, 21 Apr 2026 16:03:14 +0200 Subject: [PATCH 17/53] bugfix Signed-off-by: Mahdi Nasser <94046147+mnasser02@users.noreply.github.com> --- .../diffusion/sched/stream_batch_scheduler.py | 2 - .../worker/diffusion_model_runner.py | 41 ++++++++++++++----- 2 files changed, 30 insertions(+), 13 deletions(-) diff --git a/vllm_omni/diffusion/sched/stream_batch_scheduler.py b/vllm_omni/diffusion/sched/stream_batch_scheduler.py index ea2d67de31a..e3557adb26e 100644 --- a/vllm_omni/diffusion/sched/stream_batch_scheduler.py +++ b/vllm_omni/diffusion/sched/stream_batch_scheduler.py @@ -45,7 +45,6 @@ class _ChunkProgress: num_chunks: int # total chunks to produce for this request num_steps: int # denoising steps per chunk chunks_admitted: int = 0 - chunks_completed: int = 0 in_flight: list[_InFlightChunk] = field(default_factory=list) @@ -204,7 +203,6 @@ def update_from_output(self, sched_output: DiffusionSchedulerOutput, output: Run progress.in_flight = [ c for c in progress.in_flight if c.chunk_idx != output.chunk_idx ] - progress.chunks_completed += 1 if output.finished: terminal[output.req_id] = DiffusionRequestStatus.FINISHED_COMPLETED diff --git a/vllm_omni/diffusion/worker/diffusion_model_runner.py b/vllm_omni/diffusion/worker/diffusion_model_runner.py index 94e983e6785..a971cdb532e 100644 --- a/vllm_omni/diffusion/worker/diffusion_model_runner.py +++ b/vllm_omni/diffusion/worker/diffusion_model_runner.py @@ -556,6 +556,32 @@ def execute_micro_step(self, scheduler_output: DiffusionSchedulerOutput) -> Runn pp_rank = pp_group.rank_in_group task = assignment[pp_rank] + if pp_group.is_first_rank: + denoised_chunks = state.extra.pop("denoised_chunks", []) + decoded_chunks = state.extra.get("decoded_chunks", []) + + for chunk in denoised_chunks: + with state.use_chunk(chunk): + decoded_chunks.append(self.pipeline.post_decode(state)) + + if len(decoded_chunks) >= state.sampling.num_chunks: + assert len(denoised_chunks) == state.sampling.num_chunks, ( + f"Expected {state.sampling.num_chunks} denoised chunks but got {len(denoised_chunks)}" + ) + + output = RunnerOutput( + req_id=state.req_id, + step_index=state.step_index, + finished=True, + result=self._merge_chunk_outputs(decoded_chunks), + ) + + self._update_states_after(state, finished=True) + self.pipeline._registered_pp_comms = False + + return output + + if task is None: return RunnerOutput(req_id=state.req_id) @@ -588,18 +614,11 @@ def execute_micro_step(self, scheduler_output: DiffusionSchedulerOutput) -> Runn chunk_idx=task.chunk_idx, ) - if chunk_done: + if chunk_done and pp_group.is_first_rank: + state.extra.setdefault("denoised_chunks", []).append(chunk) + output.chunk_completed = True - with state.use_chunk(chunk): - chunk_output = self.pipeline.post_decode(state) state.extra["chunks"].pop(task.chunk_idx, None) - completed = state.extra.setdefault("completed_chunk_outputs", []) - completed.append(chunk_output) - if len(completed) >= state.sampling.num_chunks: - output.finished = True - output.result = self._merge_chunk_outputs(completed) - self._update_states_after(state, finished=True) - self.pipeline._registered_pp_comms = False - + return output From 78d80e96ce528c7d3472a0922b273bec6b04b644 Mon Sep 17 00:00:00 2001 From: Mahdi Nasser <94046147+mnasser02@users.noreply.github.com> Date: Wed, 22 Apr 2026 10:23:52 +0200 Subject: [PATCH 18/53] Implement fully async isend/recv_dicts and migrate PPMixin to use them Signed-off-by: Mahdi Nasser <94046147+mnasser02@users.noreply.github.com> --- .../distributed/group_coordinator.py | 276 +++++++++++------- .../distributed/pipeline_parallel.py | 56 +--- 2 files changed, 184 insertions(+), 148 deletions(-) diff --git a/vllm_omni/diffusion/distributed/group_coordinator.py b/vllm_omni/diffusion/distributed/group_coordinator.py index f8554781d28..ed447b1683f 100644 --- a/vllm_omni/diffusion/distributed/group_coordinator.py +++ b/vllm_omni/diffusion/distributed/group_coordinator.py @@ -5,7 +5,6 @@ # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. import pickle from collections import namedtuple -from dataclasses import dataclass from typing import Any import torch @@ -24,14 +23,6 @@ env_info = envs.PACKAGES_CHECKER.get_packages_info() -@dataclass -class _RegisteredRecvChannel: - """Pre-allocated recv buffer for a registered P2P channel.""" - - tensor_dict: dict[str, Any] - ordered_tensors: list[tuple[str, torch.Tensor]] - - def _split_tensor_dict( tensor_dict: dict[str, torch.Tensor | Any], prefix: str = "" ) -> tuple[list[tuple[str, Any]], list[torch.Tensor]]: @@ -132,9 +123,6 @@ def __init__( self.device = current_omni_platform.get_torch_device(local_rank) - # Registered recv buffers for zero-metadata P2P channels. - self._recv_channels: dict[str, _RegisteredRecvChannel] = {} - @property def first_rank(self): """Return the global rank of the first process in the group""" @@ -513,87 +501,6 @@ def irecv_tensor_dict( return tensor_dict, handles, [] - def register_recv_channel( - self, - name: str, - tensor_specs: list[tuple[str, torch.Size, torch.dtype, str]], - ) -> None: - """Pre-allocate recv buffers for a channel. Local only, no comm. - - Each spec is ``(flattened_key, shape, dtype, device_type)``; - ``device_type`` is ``"cpu"`` or any type accepted by ``torch.device`` - (the coordinator's device is used for non-cpu types). - """ - tensor_dict: dict[str, Any] = {} - ordered: list[tuple[str, torch.Tensor]] = [] - for key, shape, dtype, device_type in tensor_specs: - device = torch.device("cpu") if device_type == "cpu" else self.device - tensor = torch.empty(shape, dtype=dtype, device=device) - _update_nested_dict(tensor_dict, key, tensor) - ordered.append((key, tensor)) - self._recv_channels[name] = _RegisteredRecvChannel(tensor_dict=tensor_dict, ordered_tensors=ordered) - - def isend_registered( - self, - name: str, - tensor_dict: dict[str, torch.Tensor | Any], - dst: int | None = None, - ) -> list[torch.distributed.Work]: - """NCCL-only isend — no metadata handshake. ``name`` is a label for - symmetry with ``irecv_registered``; peer must have registered the - matching recv channel with the same name and specs. - - Caller must ``.wait()`` each returned handle before reusing the source tensors. - """ - if not torch.distributed.is_initialized() or self.world_size == 1: - return [] - - if dst is None: - dst = self.group_next_rank - assert dst < self.world_size, f"Invalid dst rank ({dst})" - - _, tensor_list = _split_tensor_dict(tensor_dict) - - handles: list[torch.distributed.Work] = [] - for tensor in tensor_list: - if tensor.numel() == 0: - continue - group = self.cpu_group if tensor.is_cpu else self.device_group - handle = torch.distributed.isend(tensor, dst=self.ranks[dst], group=group) - if tensor.is_cuda: - tensor.record_stream(torch.cuda.current_stream(tensor.device)) - handles.append(handle) - return handles - - def irecv_registered( - self, - name: str, - src: int | None = None, - ) -> tuple[dict[str, torch.Tensor | Any], list[torch.distributed.Work], list]: - """NCCL-only irecv into the pre-allocated buffer for a registered channel. - - Returns ``(tensor_dict, handles, postproc)`` compatible with - ``AsyncLatents`` / ``AsyncIntermediateTensors``. Same buffer dict is - returned on every call — caller resolves the wrapper before the next - ``irecv_registered`` on the same channel overwrites it (``sync_pp_send`` - invariant on the recv side). - """ - channel = self._recv_channels[name] - if not torch.distributed.is_initialized() or self.world_size == 1: - return channel.tensor_dict, [], [] - - if src is None: - src = self.group_prev_rank - assert src < self.world_size, f"Invalid src rank ({src})" - - handles: list[torch.distributed.Work] = [] - for _, tensor in channel.ordered_tensors: - if tensor.numel() == 0: - continue - group = self.cpu_group if tensor.is_cpu else self.device_group - handles.append(torch.distributed.irecv(tensor, src=self.ranks[src], group=group)) - return channel.tensor_dict, handles, [] - def send_tensor_dict( self, tensor_dict: dict[str, torch.Tensor | Any], @@ -782,8 +689,6 @@ def __init__( self.device = current_omni_platform.get_torch_device(local_rank) - self._recv_channels: dict[str, _RegisteredRecvChannel] = {} - self.recv_buffer_set: bool = False self.recv_tasks_queue: list[tuple[str, int]] = [] self.receiving_tasks: list[tuple[torch.distributed.Work, str, int]] = [] @@ -794,6 +699,17 @@ def __init__( self.send_shape: dict[str, dict[int, torch.Size]] = {} self.recv_buffer: dict[str, dict[int, torch.Size]] = {} + # Cached dict schema and pre-allocated recv buffers for + # `pipeline_isend_tensor_dict` / `pipeline_irecv_tensor_dict`. + # The pickled schema is exchanged once per (name, segment_idx) over + # NCCL; subsequent calls only post async tensor sends/recvs. + self.dict_schema_cache: dict[str, dict[int, list[tuple[str, Any]]]] = {} + self.dict_recv_buffer: dict[str, dict[int, dict[str, torch.Tensor]]] = {} + self.recv_dict_tasks_queue: list[tuple[str, int]] = [] + self.receiving_dict_tasks: list[ + tuple[dict[str, Any], list[torch.distributed.Work], str, int] + ] = [] + self.skip_tensor_recv_buffer_set: bool = False self.recv_skip_tasks_queue: list[int | tuple[str, int]] = [] self.receiving_skip_tasks: list[tuple[torch.distributed.Work, str, int]] = [] @@ -812,6 +728,11 @@ def reset_buffer(self): self.send_shape = {} self.recv_buffer = {} + self.dict_schema_cache = {} + self.dict_recv_buffer = {} + self.recv_dict_tasks_queue = [] + self.receiving_dict_tasks = [] + self.recv_skip_tasks_queue = [] self.receiving_skip_tasks = [] self.skip_tensor_recv_buffer = {} @@ -903,6 +824,12 @@ def _communicate_shapes(self, tensor_send_to_next=None, recv_prev=False): recv_prev: boolean for whether tensor should be received from previous rank. """ + send_group = ( + self.device_groups[self.rank_in_group % 2] if self.world_size == 2 else self.device_group + ) + recv_group = ( + self.device_groups[(self.rank_in_group + 1) % 2] if self.world_size == 2 else self.device_group + ) ops = [] if recv_prev: @@ -911,7 +838,7 @@ def _communicate_shapes(self, tensor_send_to_next=None, recv_prev=False): torch.distributed.irecv, recv_prev_dim_tensor, self.prev_rank, - self.device_group, + recv_group, ) ops.append(recv_prev_dim_op) @@ -921,7 +848,7 @@ def _communicate_shapes(self, tensor_send_to_next=None, recv_prev=False): torch.distributed.isend, send_next_dim_tensor, self.next_rank, - self.device_group, + send_group, ) ops.append(send_next_dim_op) @@ -944,7 +871,7 @@ def _communicate_shapes(self, tensor_send_to_next=None, recv_prev=False): torch.distributed.irecv, recv_prev_shape_tensor, self.prev_rank, - self.device_group, + recv_group, ) ops.append(recv_prev_shape_op) @@ -954,7 +881,7 @@ def _communicate_shapes(self, tensor_send_to_next=None, recv_prev=False): torch.distributed.isend, send_next_shape_tensor, self.next_rank, - self.device_group, + send_group, ) ops.append(send_next_shape_op) @@ -970,15 +897,86 @@ def _communicate_shapes(self, tensor_send_to_next=None, recv_prev=False): recv_prev_shape = recv_prev_shape_tensor return torch.Size(recv_prev_shape) + def _communicate_dict_schema( + self, send_metadata: list[tuple[str, Any]] | None = None, recv: bool = False + ) -> list[tuple[str, Any]] | None: + send_group = ( + self.device_groups[self.rank_in_group % 2] if self.world_size == 2 else self.device_group + ) + recv_group = ( + self.device_groups[(self.rank_in_group + 1) % 2] if self.world_size == 2 else self.device_group + ) + + # Phase 1: exchange payload sizes. + payload_tensor: torch.Tensor | None = None + recv_size_tensor: torch.Tensor | None = None + ops: list[torch.distributed.P2POp] = [] + if recv: + recv_size_tensor = torch.empty(1, device=self.device, dtype=torch.int64) + ops.append( + torch.distributed.P2POp( + torch.distributed.irecv, recv_size_tensor, self.prev_rank, recv_group + ) + ) + if send_metadata is not None: + payload_bytes = pickle.dumps(send_metadata) + payload_array = bytearray(payload_bytes) + payload_tensor = torch.frombuffer(payload_array, dtype=torch.uint8).to(self.device) + send_size_tensor = torch.tensor( + [payload_tensor.numel()], device=self.device, dtype=torch.int64 + ) + ops.append( + torch.distributed.P2POp( + torch.distributed.isend, send_size_tensor, self.next_rank, send_group + ) + ) + if ops: + for req in torch.distributed.batch_isend_irecv(ops): + req.wait() + current_omni_platform.synchronize() + + # Phase 2: exchange pickled bytes. + recv_payload: torch.Tensor | None = None + ops = [] + if recv: + assert recv_size_tensor is not None + recv_payload = torch.empty(int(recv_size_tensor.item()), device=self.device, dtype=torch.uint8) + ops.append( + torch.distributed.P2POp( + torch.distributed.irecv, recv_payload, self.prev_rank, recv_group + ) + ) + if payload_tensor is not None: + ops.append( + torch.distributed.P2POp( + torch.distributed.isend, payload_tensor, self.next_rank, send_group + ) + ) + if ops: + for req in torch.distributed.batch_isend_irecv(ops): + req.wait() + current_omni_platform.synchronize() + + if recv: + assert recv_payload is not None + return pickle.loads(recv_payload.cpu().numpy().tobytes()) + return None + def pipeline_send(self, tensor: torch.Tensor, name: str = "latent", segment_idx: int = -1) -> None: tensor = tensor.contiguous() self._check_shape_and_buffer(tensor_send_to_next=tensor, name=name, segment_idx=segment_idx) self._pipeline_isend(tensor).wait() - def pipeline_isend(self, tensor: torch.Tensor, name: str = "latent", segment_idx: int = -1) -> None: + def pipeline_isend( + self, tensor: torch.Tensor, name: str = "latent", segment_idx: int = -1 + ) -> torch.distributed.Work: tensor = tensor.contiguous() self._check_shape_and_buffer(tensor_send_to_next=tensor, name=name, segment_idx=segment_idx) - self._pipeline_isend(tensor) + handle = self._pipeline_isend(tensor) + if tensor.is_cuda: + # Keep allocator from reusing this CUDA buffer before the async send finishes. + tensor.record_stream(torch.cuda.current_stream(tensor.device)) + return handle def pipeline_recv(self, idx: int = -1, name: str = "latent") -> torch.Tensor: name = name or "latent" @@ -986,6 +984,82 @@ def pipeline_recv(self, idx: int = -1, name: str = "latent") -> torch.Tensor: self._pipeline_irecv(self.recv_buffer[name][idx]).wait() return self.recv_buffer[name][idx] + def pipeline_isend_tensor_dict( + self, + tensor_dict: dict[str, torch.Tensor | Any], + name: str = "dict", + segment_idx: int = -1, + ) -> list[torch.distributed.Work]: + """Async tensor-dict send. Schema (keys, scalars, per-tensor + dtype/shape) is exchanged once per (name, segment_idx) over NCCL; + steady state posts only per-tensor isend ops. Schema must be + stable across calls; mismatching shapes will fail at NCCL recv. + """ + metadata_list, tensor_list = _split_tensor_dict(tensor_dict) + + cache = self.dict_schema_cache.setdefault(name, {}) + if segment_idx not in cache: + self._communicate_dict_schema(send_metadata=metadata_list) + cache[segment_idx] = metadata_list + + handles: list[torch.distributed.Work] = [] + for tensor in tensor_list: + if tensor.numel() == 0: + continue + tensor = tensor.contiguous() + handle = self._pipeline_isend(tensor) + if tensor.is_cuda: + tensor.record_stream(torch.cuda.current_stream(tensor.device)) + handles.append(handle) + return handles + + def pipeline_irecv_tensor_dict( + self, + name: str = "dict", + segment_idx: int = -1, + ) -> tuple[dict[str, torch.Tensor | Any], list[torch.distributed.Work], list]: + """Async tensor-dict recv. Returns ``(tensor_dict, handles, [])`` + matching ``GroupCoordinator.irecv_tensor_dict``. First call + discovers schema over NCCL and pre-allocates persistent buffers; + the returned dict aliases those buffers, so the caller must + consume before the next recv on the same (name, segment_idx). + """ + cache = self.dict_schema_cache.setdefault(name, {}) + if segment_idx not in cache: + metadata_list = self._communicate_dict_schema(recv=True) + assert metadata_list is not None + cache[segment_idx] = metadata_list + buffers: dict[str, torch.Tensor] = {} + for key, value in metadata_list: + if isinstance(value, TensorMetadata): + if torch.Size(value.size).numel() == 0: + continue + device = self.device if value.device == "cuda" else torch.device(value.device) + buffers[key] = torch.empty(value.size, dtype=value.dtype, device=device) + self.dict_recv_buffer.setdefault(name, {})[segment_idx] = buffers + + metadata_list = cache[segment_idx] + buffers = self.dict_recv_buffer[name][segment_idx] + + tensor_dict: dict[str, Any] = {} + handles: list[torch.distributed.Work] = [] + for key, value in metadata_list: + if isinstance(value, TensorMetadata): + if torch.Size(value.size).numel() == 0: + _update_nested_dict( + tensor_dict, + key, + torch.empty(value.size, dtype=value.dtype, device=self.device), + ) + continue + tensor = buffers[key] + handles.append(self._pipeline_irecv(tensor)) + _update_nested_dict(tensor_dict, key, tensor) + else: + _update_nested_dict(tensor_dict, key, value) + + return tensor_dict, handles, [] + def add_pipeline_recv_task(self, idx: int = -1, name: str = "latent"): name = name or "latent" self.recv_tasks_queue.append((name, idx)) diff --git a/vllm_omni/diffusion/distributed/pipeline_parallel.py b/vllm_omni/diffusion/distributed/pipeline_parallel.py index 7c0a61e903f..55cbebd291d 100644 --- a/vllm_omni/diffusion/distributed/pipeline_parallel.py +++ b/vllm_omni/diffusion/distributed/pipeline_parallel.py @@ -176,32 +176,6 @@ def _sync_pp_send(self) -> None: handle.wait() self._pp_send_work = [] - def _pp_it_channel_specs(self, state: Any) -> list[tuple[str, torch.Size, torch.dtype, str]]: - """IT tensor specs for the ``pp_its`` registered recv channel. - - Default: one ``hidden_states`` tensor shaped like ``state.latents``. - Override for models whose IT layout differs. - """ - latents = state.latents - return [("hidden_states", latents.shape, latents.dtype, latents.device.type)] - - def register_pp_channels(self, state: Any) -> None: - """Pre-allocate recv buffers for the ``pp_its_{i}`` and ``pp_latents`` channels.""" - pp_group = get_pp_group() - if pp_group.world_size == 1: - return - if not pp_group.is_first_rank: - cfg_parallel = get_classifier_free_guidance_world_size() > 1 - n_branches = 2 if (getattr(state, "do_true_cfg", False) and not cfg_parallel) else 1 - for i in range(n_branches): - pp_group.register_recv_channel(f"pp_its_{i}", self._pp_it_channel_specs(state)) - if pp_group.is_first_rank: - latents = state.latents - pp_group.register_recv_channel( - "pp_latents", - [("latents", latents.shape, latents.dtype, latents.device.type)], - ) - def predict_noise_maybe_with_cfg( self, do_true_cfg: bool, @@ -243,26 +217,22 @@ def predict_noise_maybe_with_cfg( # Sequential CFG (or no CFG): this PP pipeline handles all branches. all_kwargs = [positive_kwargs] + ([negative_kwargs] if do_true_cfg else []) - registered_comms = getattr(self, "_registered_pp_comms", False) - - # Non-first ranks receive intermediate tensors asynchronously. + # Non-first ranks receive intermediate tensors asynchronously n = len(all_kwargs) its: list[AsyncIntermediateTensors | None] = [None] * n if not pp_group.is_first_rank: for i in range(n): - if registered_comms: - its[i] = AsyncIntermediateTensors(*pp_group.irecv_registered(f"pp_its_{i}")) - else: - its[i] = AsyncIntermediateTensors(*pp_group.irecv_tensor_dict()) + its[i] = AsyncIntermediateTensors( + *pp_group.pipeline_irecv_tensor_dict(name="intermediate", segment_idx=i) + ) if not pp_group.is_last_rank: # First / middle rank: run partial forwards and propagate ITs downstream. for i, (kwargs, it) in enumerate(zip(all_kwargs, its)): result = self.predict_noise(**kwargs, intermediate_tensors=it) - if registered_comms: - self._pp_send_work.extend(pp_group.isend_registered(f"pp_its_{i}", result.tensors)) - else: - self._pp_send_work.extend(pp_group.isend_tensor_dict(result.tensors)) + self._pp_send_work.extend( + pp_group.pipeline_isend_tensor_dict(result.tensors, name="intermediate", segment_idx=i) + ) return None # Last rank: run full forward @@ -314,18 +284,10 @@ def scheduler_step_maybe_with_cfg( noise_pred, t, latents, do_true_cfg, per_request_scheduler, generator ) - registered_comms = getattr(self, "_registered_pp_comms", False) - pp_group = get_pp_group() if pp_group.is_last_rank: latents = super().scheduler_step_maybe_with_cfg(noise_pred, t, latents, do_true_cfg, per_request_scheduler) - if registered_comms: - self._pp_send_work = pp_group.isend_registered("pp_latents", {"latents": latents}, dst=0) - else: - self._pp_send_work = pp_group.isend_tensor_dict({"latents": latents}, dst=0) + self._pp_send_work = pp_group.pipeline_isend_tensor_dict({"latents": latents}, name="latents") elif pp_group.is_first_rank: - if registered_comms: - latents = AsyncLatents(*pp_group.irecv_registered("pp_latents", src=pp_group.world_size - 1)) - else: - latents = AsyncLatents(*pp_group.irecv_tensor_dict(src=pp_group.world_size - 1)) + latents = AsyncLatents(*pp_group.pipeline_irecv_tensor_dict(name="latents")) return latents From 986571c17eb23d241779aaaea97d4e27c245d78f Mon Sep 17 00:00:00 2001 From: Mahdi Nasser <94046147+mnasser02@users.noreply.github.com> Date: Wed, 22 Apr 2026 10:24:08 +0200 Subject: [PATCH 19/53] Rmv stale code Signed-off-by: Mahdi Nasser <94046147+mnasser02@users.noreply.github.com> --- vllm_omni/diffusion/worker/diffusion_model_runner.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/vllm_omni/diffusion/worker/diffusion_model_runner.py b/vllm_omni/diffusion/worker/diffusion_model_runner.py index a971cdb532e..5a53105fc1f 100644 --- a/vllm_omni/diffusion/worker/diffusion_model_runner.py +++ b/vllm_omni/diffusion/worker/diffusion_model_runner.py @@ -547,10 +547,6 @@ def execute_micro_step(self, scheduler_output: DiffusionSchedulerOutput) -> Runn with set_forward_context(vllm_config=self.vllm_config, omni_diffusion_config=self.od_config): if is_new_request: self.pipeline.prepare_encode(state) - pp_size = get_pp_group().world_size - self.pipeline._registered_pp_comms = pp_size > 1 - if pp_size > 1: - self.pipeline.register_pp_channels(state) pp_group = get_pp_group() pp_rank = pp_group.rank_in_group @@ -565,8 +561,8 @@ def execute_micro_step(self, scheduler_output: DiffusionSchedulerOutput) -> Runn decoded_chunks.append(self.pipeline.post_decode(state)) if len(decoded_chunks) >= state.sampling.num_chunks: - assert len(denoised_chunks) == state.sampling.num_chunks, ( - f"Expected {state.sampling.num_chunks} denoised chunks but got {len(denoised_chunks)}" + assert len(decoded_chunks) == state.sampling.num_chunks, ( + f"Expected {state.sampling.num_chunks} denoised chunks but got {len(decoded_chunks)}" ) output = RunnerOutput( @@ -577,7 +573,6 @@ def execute_micro_step(self, scheduler_output: DiffusionSchedulerOutput) -> Runn ) self._update_states_after(state, finished=True) - self.pipeline._registered_pp_comms = False return output From f502e9623da9791ebfd51a8dbf1d76c5ad0b2a43 Mon Sep 17 00:00:00 2001 From: Mahdi Nasser <94046147+mnasser02@users.noreply.github.com> Date: Wed, 22 Apr 2026 10:29:16 +0200 Subject: [PATCH 20/53] bugfix Signed-off-by: Mahdi Nasser <94046147+mnasser02@users.noreply.github.com> --- vllm_omni/diffusion/worker/diffusion_model_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm_omni/diffusion/worker/diffusion_model_runner.py b/vllm_omni/diffusion/worker/diffusion_model_runner.py index 5a53105fc1f..72c078d14be 100644 --- a/vllm_omni/diffusion/worker/diffusion_model_runner.py +++ b/vllm_omni/diffusion/worker/diffusion_model_runner.py @@ -554,7 +554,7 @@ def execute_micro_step(self, scheduler_output: DiffusionSchedulerOutput) -> Runn if pp_group.is_first_rank: denoised_chunks = state.extra.pop("denoised_chunks", []) - decoded_chunks = state.extra.get("decoded_chunks", []) + decoded_chunks = state.extra.setdefault("decoded_chunks", []) for chunk in denoised_chunks: with state.use_chunk(chunk): From d111ca043cf7adadac5eae2c81a1e02d07079aaa Mon Sep 17 00:00:00 2001 From: Mahdi Nasser <94046147+mnasser02@users.noreply.github.com> Date: Wed, 22 Apr 2026 11:18:59 +0200 Subject: [PATCH 21/53] Make sending dict schema unblocking Signed-off-by: Mahdi Nasser <94046147+mnasser02@users.noreply.github.com> --- .../distributed/group_coordinator.py | 95 +++++++------------ 1 file changed, 33 insertions(+), 62 deletions(-) diff --git a/vllm_omni/diffusion/distributed/group_coordinator.py b/vllm_omni/diffusion/distributed/group_coordinator.py index ed447b1683f..a575e8bb857 100644 --- a/vllm_omni/diffusion/distributed/group_coordinator.py +++ b/vllm_omni/diffusion/distributed/group_coordinator.py @@ -705,6 +705,8 @@ def __init__( # NCCL; subsequent calls only post async tensor sends/recvs. self.dict_schema_cache: dict[str, dict[int, list[tuple[str, Any]]]] = {} self.dict_recv_buffer: dict[str, dict[int, dict[str, torch.Tensor]]] = {} + + self.dict_schema_keepalive: list[torch.Tensor] = [] self.recv_dict_tasks_queue: list[tuple[str, int]] = [] self.receiving_dict_tasks: list[ tuple[dict[str, Any], list[torch.distributed.Work], str, int] @@ -730,6 +732,7 @@ def reset_buffer(self): self.dict_schema_cache = {} self.dict_recv_buffer = {} + self.dict_schema_keepalive = [] self.recv_dict_tasks_queue = [] self.receiving_dict_tasks = [] @@ -897,70 +900,37 @@ def _communicate_shapes(self, tensor_send_to_next=None, recv_prev=False): recv_prev_shape = recv_prev_shape_tensor return torch.Size(recv_prev_shape) - def _communicate_dict_schema( - self, send_metadata: list[tuple[str, Any]] | None = None, recv: bool = False - ) -> list[tuple[str, Any]] | None: + def _isend_dict_schema( + self, send_metadata: list[tuple[str, Any]] + ) -> tuple[list[torch.distributed.Work], list[torch.Tensor]]: + """Non-blocking schema send. Returns (handles, keepalive_tensors). + Caller must keep the tensors alive until the handles complete. + """ send_group = ( self.device_groups[self.rank_in_group % 2] if self.world_size == 2 else self.device_group ) + payload_bytes = pickle.dumps(send_metadata) + payload_array = bytearray(payload_bytes) + payload_tensor = torch.frombuffer(payload_array, dtype=torch.uint8).to(self.device) + send_size_tensor = torch.tensor( + [payload_tensor.numel()], device=self.device, dtype=torch.int64 + ) + h_size = torch.distributed.isend(send_size_tensor, dst=self.next_rank, group=send_group) + h_bytes = torch.distributed.isend(payload_tensor, dst=self.next_rank, group=send_group) + return [h_size, h_bytes], [send_size_tensor, payload_tensor] + + def _recv_dict_schema(self) -> list[tuple[str, Any]]: + """Blocking schema recv - must wait because the size value is + needed before allocating the payload buffer. + """ recv_group = ( self.device_groups[(self.rank_in_group + 1) % 2] if self.world_size == 2 else self.device_group ) - - # Phase 1: exchange payload sizes. - payload_tensor: torch.Tensor | None = None - recv_size_tensor: torch.Tensor | None = None - ops: list[torch.distributed.P2POp] = [] - if recv: - recv_size_tensor = torch.empty(1, device=self.device, dtype=torch.int64) - ops.append( - torch.distributed.P2POp( - torch.distributed.irecv, recv_size_tensor, self.prev_rank, recv_group - ) - ) - if send_metadata is not None: - payload_bytes = pickle.dumps(send_metadata) - payload_array = bytearray(payload_bytes) - payload_tensor = torch.frombuffer(payload_array, dtype=torch.uint8).to(self.device) - send_size_tensor = torch.tensor( - [payload_tensor.numel()], device=self.device, dtype=torch.int64 - ) - ops.append( - torch.distributed.P2POp( - torch.distributed.isend, send_size_tensor, self.next_rank, send_group - ) - ) - if ops: - for req in torch.distributed.batch_isend_irecv(ops): - req.wait() - current_omni_platform.synchronize() - - # Phase 2: exchange pickled bytes. - recv_payload: torch.Tensor | None = None - ops = [] - if recv: - assert recv_size_tensor is not None - recv_payload = torch.empty(int(recv_size_tensor.item()), device=self.device, dtype=torch.uint8) - ops.append( - torch.distributed.P2POp( - torch.distributed.irecv, recv_payload, self.prev_rank, recv_group - ) - ) - if payload_tensor is not None: - ops.append( - torch.distributed.P2POp( - torch.distributed.isend, payload_tensor, self.next_rank, send_group - ) - ) - if ops: - for req in torch.distributed.batch_isend_irecv(ops): - req.wait() - current_omni_platform.synchronize() - - if recv: - assert recv_payload is not None - return pickle.loads(recv_payload.cpu().numpy().tobytes()) - return None + recv_size_tensor = torch.empty(1, device=self.device, dtype=torch.int64) + torch.distributed.recv(recv_size_tensor, src=self.prev_rank, group=recv_group) + recv_payload = torch.empty(int(recv_size_tensor.item()), device=self.device, dtype=torch.uint8) + torch.distributed.recv(recv_payload, src=self.prev_rank, group=recv_group) + return pickle.loads(recv_payload.cpu().numpy().tobytes()) def pipeline_send(self, tensor: torch.Tensor, name: str = "latent", segment_idx: int = -1) -> None: tensor = tensor.contiguous() @@ -998,11 +968,13 @@ def pipeline_isend_tensor_dict( metadata_list, tensor_list = _split_tensor_dict(tensor_dict) cache = self.dict_schema_cache.setdefault(name, {}) + handles: list[torch.distributed.Work] = [] if segment_idx not in cache: - self._communicate_dict_schema(send_metadata=metadata_list) + schema_handles, keepalive = self._isend_dict_schema(metadata_list) + handles.extend(schema_handles) + self.dict_schema_keepalive.extend(keepalive) cache[segment_idx] = metadata_list - handles: list[torch.distributed.Work] = [] for tensor in tensor_list: if tensor.numel() == 0: continue @@ -1026,8 +998,7 @@ def pipeline_irecv_tensor_dict( """ cache = self.dict_schema_cache.setdefault(name, {}) if segment_idx not in cache: - metadata_list = self._communicate_dict_schema(recv=True) - assert metadata_list is not None + metadata_list = self._recv_dict_schema() cache[segment_idx] = metadata_list buffers: dict[str, torch.Tensor] = {} for key, value in metadata_list: From 26e7b8e1cf675e99cfc28fcf8fcfc86e647594fb Mon Sep 17 00:00:00 2001 From: Mahdi Nasser <94046147+mnasser02@users.noreply.github.com> Date: Wed, 22 Apr 2026 11:51:57 +0200 Subject: [PATCH 22/53] Add warmup for nccl comms on init Signed-off-by: Mahdi Nasser <94046147+mnasser02@users.noreply.github.com> --- .../distributed/group_coordinator.py | 39 +++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/vllm_omni/diffusion/distributed/group_coordinator.py b/vllm_omni/diffusion/distributed/group_coordinator.py index a575e8bb857..8867582eed2 100644 --- a/vllm_omni/diffusion/distributed/group_coordinator.py +++ b/vllm_omni/diffusion/distributed/group_coordinator.py @@ -723,6 +723,45 @@ def __init__( self.skip_device_group = skip_device_group assert self.skip_device_group is not None + self._warmup_nccl_comms() + + def _warmup_nccl_comms(self) -> None: + """Force eager ncclCommInit on every P2P group while all ranks are + synchronized at __init__. Otherwise the first real P2P op would + trigger a collective comm-init that blocks the early-arriving + rank — breaks temporal-PP where one rank deliberately runs ahead. + """ + if self.world_size == 1: + return + + dummy = torch.zeros(1, device=self.device, dtype=torch.uint8) + + if self.world_size == 2: + for group_idx in (0, 1): + group = self.device_groups[group_idx] + if self.rank_in_group == group_idx: + op = torch.distributed.P2POp(torch.distributed.isend, dummy, self.next_rank, group) + else: + op = torch.distributed.P2POp(torch.distributed.irecv, dummy, self.prev_rank, group) + for req in torch.distributed.batch_isend_irecv([op]): + req.wait() + else: + for req in torch.distributed.batch_isend_irecv( + [ + torch.distributed.P2POp(torch.distributed.isend, dummy, self.next_rank, self.device_group), + torch.distributed.P2POp(torch.distributed.irecv, dummy, self.prev_rank, self.device_group), + ] + ): + req.wait() + + for req in torch.distributed.batch_isend_irecv( + [ + torch.distributed.P2POp(torch.distributed.isend, dummy, self.skip_rank, self.skip_device_group), + torch.distributed.P2POp(torch.distributed.irecv, dummy, self.skip_rank, self.skip_device_group), + ] + ): + req.wait() + def reset_buffer(self): self.recv_tasks_queue = [] self.receiving_tasks = [] From 34a80022703f49e5905be800037c933961b7b5ff Mon Sep 17 00:00:00 2001 From: Mahdi Nasser <94046147+mnasser02@users.noreply.github.com> Date: Wed, 22 Apr 2026 12:10:19 +0200 Subject: [PATCH 23/53] Use batch_isend_irecv instead of plain isend/recv Plain P2P on size-2 PG triggers lazy sub-comm creation that requires the peer present. Signed-off-by: Mahdi Nasser <94046147+mnasser02@users.noreply.github.com> --- .../distributed/group_coordinator.py | 39 ++++++++++++------- 1 file changed, 24 insertions(+), 15 deletions(-) diff --git a/vllm_omni/diffusion/distributed/group_coordinator.py b/vllm_omni/diffusion/distributed/group_coordinator.py index 8867582eed2..852f394aae3 100644 --- a/vllm_omni/diffusion/distributed/group_coordinator.py +++ b/vllm_omni/diffusion/distributed/group_coordinator.py @@ -954,9 +954,14 @@ def _isend_dict_schema( send_size_tensor = torch.tensor( [payload_tensor.numel()], device=self.device, dtype=torch.int64 ) - h_size = torch.distributed.isend(send_size_tensor, dst=self.next_rank, group=send_group) - h_bytes = torch.distributed.isend(payload_tensor, dst=self.next_rank, group=send_group) - return [h_size, h_bytes], [send_size_tensor, payload_tensor] + # batch_isend_irecv (not plain isend) — plain P2P on size-2 PG + # triggers lazy sub-comm creation that requires the peer present. + ops = [ + torch.distributed.P2POp(torch.distributed.isend, send_size_tensor, self.next_rank, send_group), + torch.distributed.P2POp(torch.distributed.isend, payload_tensor, self.next_rank, send_group), + ] + handles = list(torch.distributed.batch_isend_irecv(ops)) + return handles, [send_size_tensor, payload_tensor] def _recv_dict_schema(self) -> list[tuple[str, Any]]: """Blocking schema recv - must wait because the size value is @@ -966,9 +971,15 @@ def _recv_dict_schema(self) -> list[tuple[str, Any]]: self.device_groups[(self.rank_in_group + 1) % 2] if self.world_size == 2 else self.device_group ) recv_size_tensor = torch.empty(1, device=self.device, dtype=torch.int64) - torch.distributed.recv(recv_size_tensor, src=self.prev_rank, group=recv_group) + for req in torch.distributed.batch_isend_irecv( + [torch.distributed.P2POp(torch.distributed.irecv, recv_size_tensor, self.prev_rank, recv_group)] + ): + req.wait() recv_payload = torch.empty(int(recv_size_tensor.item()), device=self.device, dtype=torch.uint8) - torch.distributed.recv(recv_payload, src=self.prev_rank, group=recv_group) + for req in torch.distributed.batch_isend_irecv( + [torch.distributed.P2POp(torch.distributed.irecv, recv_payload, self.prev_rank, recv_group)] + ): + req.wait() return pickle.loads(recv_payload.cpu().numpy().tobytes()) def pipeline_send(self, tensor: torch.Tensor, name: str = "latent", segment_idx: int = -1) -> None: @@ -1090,18 +1101,16 @@ def get_pipeline_recv_data(self, idx: int = -1, name: str = "latent") -> torch.T return self.recv_buffer[name][idx] def _pipeline_irecv(self, tensor: torch.tensor): - return torch.distributed.irecv( - tensor, - src=self.prev_rank, - group=(self.device_groups[(self.rank_in_group + 1) % 2] if self.world_size == 2 else self.device_group), - ) + # batch_isend_irecv (not plain irecv) — plain P2P on size-2 PG + # triggers lazy sub-comm creation that requires the peer present. + group = self.device_groups[(self.rank_in_group + 1) % 2] if self.world_size == 2 else self.device_group + op = torch.distributed.P2POp(torch.distributed.irecv, tensor, self.prev_rank, group) + return torch.distributed.batch_isend_irecv([op])[0] def _pipeline_isend(self, tensor: torch.tensor): - return torch.distributed.isend( - tensor, - dst=self.next_rank, - group=(self.device_groups[self.rank_in_group % 2] if self.world_size == 2 else self.device_group), - ) + group = self.device_groups[self.rank_in_group % 2] if self.world_size == 2 else self.device_group + op = torch.distributed.P2POp(torch.distributed.isend, tensor, self.next_rank, group) + return torch.distributed.batch_isend_irecv([op])[0] def set_skip_tensor_recv_buffer( self, From 49f580769137033156cc45f7e0720565cdd08055 Mon Sep 17 00:00:00 2001 From: Mahdi Nasser <94046147+mnasser02@users.noreply.github.com> Date: Wed, 22 Apr 2026 14:23:44 +0200 Subject: [PATCH 24/53] Handle tensor receive as a task Signed-off-by: Mahdi Nasser <94046147+mnasser02@users.noreply.github.com> --- .../distributed/group_coordinator.py | 32 +++++++-- .../distributed/pipeline_parallel.py | 70 +++++++++---------- 2 files changed, 61 insertions(+), 41 deletions(-) diff --git a/vllm_omni/diffusion/distributed/group_coordinator.py b/vllm_omni/diffusion/distributed/group_coordinator.py index 852f394aae3..531c7fd0e6e 100644 --- a/vllm_omni/diffusion/distributed/group_coordinator.py +++ b/vllm_omni/diffusion/distributed/group_coordinator.py @@ -618,6 +618,27 @@ def destroy(self): self.cpu_group = None +class PipelineRecvDictHandle: + + __slots__ = ("_pp_group", "_name", "_segment_idx", "_resolved") + + def __init__(self, pp_group: "PipelineGroupCoordinator", name: str, segment_idx: int): + self._pp_group = pp_group + self._name = name + self._segment_idx = segment_idx + self._resolved: dict[str, Any] | None = None + + def resolve(self) -> dict[str, Any]: + if self._resolved is None: + tensor_dict, handles, _ = self._pp_group.pipeline_irecv_tensor_dict( + name=self._name, segment_idx=self._segment_idx + ) + for h in handles: + h.wait() + self._resolved = tensor_dict + return self._resolved + + class PipelineGroupCoordinator(GroupCoordinator): """ available attributes: @@ -707,10 +728,6 @@ def __init__( self.dict_recv_buffer: dict[str, dict[int, dict[str, torch.Tensor]]] = {} self.dict_schema_keepalive: list[torch.Tensor] = [] - self.recv_dict_tasks_queue: list[tuple[str, int]] = [] - self.receiving_dict_tasks: list[ - tuple[dict[str, Any], list[torch.distributed.Work], str, int] - ] = [] self.skip_tensor_recv_buffer_set: bool = False self.recv_skip_tasks_queue: list[int | tuple[str, int]] = [] @@ -772,8 +789,6 @@ def reset_buffer(self): self.dict_schema_cache = {} self.dict_recv_buffer = {} self.dict_schema_keepalive = [] - self.recv_dict_tasks_queue = [] - self.receiving_dict_tasks = [] self.recv_skip_tasks_queue = [] self.receiving_skip_tasks = [] @@ -1100,6 +1115,11 @@ def get_pipeline_recv_data(self, idx: int = -1, name: str = "latent") -> torch.T assert receiving_task[1] == name and receiving_task[2] == idx, "Received tensor does not match the requested" return self.recv_buffer[name][idx] + def add_pipeline_recv_dict_task( + self, name: str = "dict", segment_idx: int = -1 + ) -> PipelineRecvDictHandle: + return PipelineRecvDictHandle(self, name, segment_idx) + def _pipeline_irecv(self, tensor: torch.tensor): # batch_isend_irecv (not plain irecv) — plain P2P on size-2 PG # triggers lazy sub-comm creation that requires the peer present. diff --git a/vllm_omni/diffusion/distributed/pipeline_parallel.py b/vllm_omni/diffusion/distributed/pipeline_parallel.py index 55cbebd291d..80875a5ca32 100644 --- a/vllm_omni/diffusion/distributed/pipeline_parallel.py +++ b/vllm_omni/diffusion/distributed/pipeline_parallel.py @@ -5,8 +5,9 @@ from typing import Any import torch -from vllm.v1.worker.gpu_worker import AsyncIntermediateTensors +from vllm.sequence import IntermediateTensors +from vllm_omni.diffusion.distributed.group_coordinator import PipelineRecvDictHandle from vllm_omni.diffusion.distributed.parallel_state import ( get_cfg_group, get_classifier_free_guidance_rank, @@ -18,44 +19,23 @@ class AsyncLatents: - """Transparent async wrapper returned by scheduler_step on rank 0. - - Wraps a pending ``irecv_tensor_dict`` and defers ``handle.wait()`` until the - underlying tensor is actually consumed — either via attribute access - (e.g. ``latents.to(dtype)``, ``latents.shape``) or via a torch operation - (e.g. ``mask * latents``). This keeps the first PP rank non-blocking after - posting the receive, matching the async philosophy used everywhere else in - the PP communication layer. - """ + """Lazy-resolve wrapper around a ``PipelineRecvDictHandle`` for latents.""" - __slots__ = ("_tensor_dict", "_handles", "_postproc", "_tensor") + __slots__ = ("_handle", "_key", "_tensor") - def __init__( - self, - tensor_dict: dict[str, torch.Tensor], - handles: list[torch.distributed.Work], - postproc: list, - ): - self._tensor_dict = tensor_dict - self._handles = handles - self._postproc = postproc + def __init__(self, handle: PipelineRecvDictHandle, key: str = "latents"): + self._handle = handle + self._key = key self._tensor: torch.Tensor | None = None def _resolve(self) -> torch.Tensor: - if self._tensor is not None: - return self._tensor - for h in self._handles: - h.wait() - for fn in self._postproc: - fn() - self._tensor = self._tensor_dict["latents"] + if self._tensor is None: + self._tensor = self._handle.resolve()[self._key] return self._tensor - # Attribute access (e.g. .shape, .to(), .dtype) delegates to the resolved tensor. def __getattr__(self, name: str): return getattr(self._resolve(), name) - # Torch function protocol: any torch op involving an AsyncLatents resolves it first. @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): kwargs = kwargs or {} @@ -64,7 +44,7 @@ def _unwrap(x): if isinstance(x, AsyncLatents): return x._resolve() if isinstance(x, (list, tuple)): - return type(x)(_unwrap(item) for item in x) # type(x) return the class of x to preserve its type + return type(x)(_unwrap(item) for item in x) return x args = tuple(_unwrap(a) for a in args) @@ -72,6 +52,27 @@ def _unwrap(x): return func(*args, **kwargs) +class AsyncIntermediateTensors: + """Lazy-resolve wrapper around a ``PipelineRecvDictHandle`` for an IT.""" + + __slots__ = ("_handle", "_resolved") + + def __init__(self, handle: PipelineRecvDictHandle): + self._handle = handle + self._resolved: IntermediateTensors | None = None + + def _resolve(self) -> IntermediateTensors: + if self._resolved is None: + self._resolved = IntermediateTensors(self._handle.resolve()) + return self._resolved + + def __getitem__(self, key: str): + return self._resolve()[key] + + def __getattr__(self, name: str): + return getattr(self._resolve(), name) + + class PipelineParallelMixin: """ Mixin providing Pipeline Parallelism for diffusion pipelines. @@ -217,14 +218,12 @@ def predict_noise_maybe_with_cfg( # Sequential CFG (or no CFG): this PP pipeline handles all branches. all_kwargs = [positive_kwargs] + ([negative_kwargs] if do_true_cfg else []) - # Non-first ranks receive intermediate tensors asynchronously n = len(all_kwargs) its: list[AsyncIntermediateTensors | None] = [None] * n if not pp_group.is_first_rank: for i in range(n): - its[i] = AsyncIntermediateTensors( - *pp_group.pipeline_irecv_tensor_dict(name="intermediate", segment_idx=i) - ) + handle = pp_group.add_pipeline_recv_dict_task(name="intermediate", segment_idx=i) + its[i] = AsyncIntermediateTensors(handle) if not pp_group.is_last_rank: # First / middle rank: run partial forwards and propagate ITs downstream. @@ -289,5 +288,6 @@ def scheduler_step_maybe_with_cfg( latents = super().scheduler_step_maybe_with_cfg(noise_pred, t, latents, do_true_cfg, per_request_scheduler) self._pp_send_work = pp_group.pipeline_isend_tensor_dict({"latents": latents}, name="latents") elif pp_group.is_first_rank: - latents = AsyncLatents(*pp_group.pipeline_irecv_tensor_dict(name="latents")) + handle = pp_group.add_pipeline_recv_dict_task(name="latents") + latents = AsyncLatents(handle) return latents From 63bb687e3cd43a6f84cb6550e578c0c40091c923 Mon Sep 17 00:00:00 2001 From: Mahdi Nasser <94046147+mnasser02@users.noreply.github.com> Date: Wed, 22 Apr 2026 18:19:33 +0200 Subject: [PATCH 25/53] Chunk-aware buffers Signed-off-by: Mahdi Nasser <94046147+mnasser02@users.noreply.github.com> --- .../diffusion/distributed/pipeline_parallel.py | 13 ++++++++----- .../diffusion/models/wan2_2/pipeline_wan2_2.py | 12 +++++++++--- .../diffusion/worker/diffusion_model_runner.py | 10 +++++++--- vllm_omni/diffusion/worker/utils.py | 6 ++++++ 4 files changed, 30 insertions(+), 11 deletions(-) diff --git a/vllm_omni/diffusion/distributed/pipeline_parallel.py b/vllm_omni/diffusion/distributed/pipeline_parallel.py index 80875a5ca32..8a540871c9a 100644 --- a/vllm_omni/diffusion/distributed/pipeline_parallel.py +++ b/vllm_omni/diffusion/distributed/pipeline_parallel.py @@ -185,6 +185,7 @@ def predict_noise_maybe_with_cfg( negative_kwargs: dict[str, Any] | None, cfg_normalize: bool = True, output_slice: int | None = None, + chunk_idx: int | None = None, ) -> torch.Tensor | tuple[torch.Tensor, ...] | None: """ Drop-in replacement for predict_noise_maybe_with_cfg that also handles PP. @@ -220,9 +221,10 @@ def predict_noise_maybe_with_cfg( n = len(all_kwargs) its: list[AsyncIntermediateTensors | None] = [None] * n + it_name = f"{chunk_idx}_intermediate" if chunk_idx is not None else "intermediate" if not pp_group.is_first_rank: for i in range(n): - handle = pp_group.add_pipeline_recv_dict_task(name="intermediate", segment_idx=i) + handle = pp_group.add_pipeline_recv_dict_task(name=it_name, segment_idx=i) its[i] = AsyncIntermediateTensors(handle) if not pp_group.is_last_rank: @@ -230,7 +232,7 @@ def predict_noise_maybe_with_cfg( for i, (kwargs, it) in enumerate(zip(all_kwargs, its)): result = self.predict_noise(**kwargs, intermediate_tensors=it) self._pp_send_work.extend( - pp_group.pipeline_isend_tensor_dict(result.tensors, name="intermediate", segment_idx=i) + pp_group.pipeline_isend_tensor_dict(result.tensors, name=it_name, segment_idx=i) ) return None @@ -265,7 +267,7 @@ def scheduler_step_maybe_with_cfg( latents: torch.Tensor | tuple[torch.Tensor, ...], do_true_cfg: bool, per_request_scheduler: Any | None = None, - generator: torch.Generator | None = None, + chunk_idx: int | None = None, ) -> torch.Tensor | tuple[torch.Tensor, ...] | AsyncLatents: """ Drop-in replacement for scheduler_step_maybe_with_cfg that also handles PP. @@ -284,10 +286,11 @@ def scheduler_step_maybe_with_cfg( ) pp_group = get_pp_group() + latents_name = f"{chunk_idx}_latents" if chunk_idx is not None else "latents" if pp_group.is_last_rank: latents = super().scheduler_step_maybe_with_cfg(noise_pred, t, latents, do_true_cfg, per_request_scheduler) - self._pp_send_work = pp_group.pipeline_isend_tensor_dict({"latents": latents}, name="latents") + self._pp_send_work = pp_group.pipeline_isend_tensor_dict({"latents": latents}, name=latents_name) elif pp_group.is_first_rank: - handle = pp_group.add_pipeline_recv_dict_task(name="latents") + handle = pp_group.add_pipeline_recv_dict_task(name=latents_name) latents = AsyncLatents(handle) return latents diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py index 95a3426c9dd..8500e82996a 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py @@ -1353,6 +1353,7 @@ def denoise_step( true_cfg_scale=current_guidance_scale, positive_kwargs=positive_kwargs, negative_kwargs=negative_kwargs, + chunk_idx=state.current_chunk_idx, ) def step_scheduler( @@ -1368,7 +1369,12 @@ def step_scheduler( do_true_cfg = current_guidance_scale > 1.0 and state.negative_prompt_embeds is not None state.latents = self.scheduler_step_maybe_with_pp_and_cfg( - noise_pred, t, state.latents, do_true_cfg, per_request_scheduler=state.scheduler + noise_pred, + t, + state.latents, + do_true_cfg, + per_request_scheduler=state.scheduler, + chunk_idx=state.current_chunk_idx, ) state.step_index += 1 @@ -1381,8 +1387,8 @@ def post_decode( self.sync_pp_send() self._current_timestep = None - if current_omni_platform.is_available(): - current_omni_platform.empty_cache() + # if current_omni_platform.is_available(): + # current_omni_platform.empty_cache() # I2V: blend final latents with condition latent_condition = state.extra.get("latent_condition") diff --git a/vllm_omni/diffusion/worker/diffusion_model_runner.py b/vllm_omni/diffusion/worker/diffusion_model_runner.py index 72c078d14be..1d242362cfa 100644 --- a/vllm_omni/diffusion/worker/diffusion_model_runner.py +++ b/vllm_omni/diffusion/worker/diffusion_model_runner.py @@ -545,13 +545,17 @@ def execute_micro_step(self, scheduler_output: DiffusionSchedulerOutput) -> Runn with set_forward_context(vllm_config=self.vllm_config, omni_diffusion_config=self.od_config): - if is_new_request: - self.pipeline.prepare_encode(state) - pp_group = get_pp_group() pp_rank = pp_group.rank_in_group task = assignment[pp_rank] + if is_new_request: + pp_group.reset_buffer() + if current_omni_platform.is_available(): + current_omni_platform.empty_cache() + + self.pipeline.prepare_encode(state) + if pp_group.is_first_rank: denoised_chunks = state.extra.pop("denoised_chunks", []) decoded_chunks = state.extra.setdefault("decoded_chunks", []) diff --git a/vllm_omni/diffusion/worker/utils.py b/vllm_omni/diffusion/worker/utils.py index 7353dab84c1..f031784608f 100644 --- a/vllm_omni/diffusion/worker/utils.py +++ b/vllm_omni/diffusion/worker/utils.py @@ -60,6 +60,9 @@ class DiffusionRequestState: # ── Per-request scheduler instance (set once by prepare_encode) ── scheduler: Any | None = None + # Active chunk index while inside a ``use_chunk`` context (None otherwise). + current_chunk_idx: int | None = None + # ── CFG config (set once by prepare_encode) ── do_true_cfg: bool = False guidance: torch.Tensor | None = None @@ -125,9 +128,11 @@ def use_chunk(self, chunk: ChunkState) -> Iterator[None]: saved_latents = self.latents saved_step_index = self.step_index saved_scheduler = self.scheduler + saved_chunk_idx = self.current_chunk_idx self.latents = chunk.latents self.step_index = chunk.step_index self.scheduler = chunk.scheduler + self.current_chunk_idx = chunk.idx try: yield finally: @@ -137,6 +142,7 @@ def use_chunk(self, chunk: ChunkState) -> Iterator[None]: self.latents = saved_latents self.step_index = saved_step_index self.scheduler = saved_scheduler + self.current_chunk_idx = saved_chunk_idx @dataclass From e77972d9ff5d8b50df895aaf0480ac10eca0eee7 Mon Sep 17 00:00:00 2001 From: Mahdi Nasser <94046147+mnasser02@users.noreply.github.com> Date: Wed, 22 Apr 2026 18:19:43 +0200 Subject: [PATCH 26/53] fix Signed-off-by: Mahdi Nasser <94046147+mnasser02@users.noreply.github.com> --- .../diffusion/sched/stream_batch_scheduler.py | 48 ++++++++++++------- 1 file changed, 32 insertions(+), 16 deletions(-) diff --git a/vllm_omni/diffusion/sched/stream_batch_scheduler.py b/vllm_omni/diffusion/sched/stream_batch_scheduler.py index e3557adb26e..6cfd41b0766 100644 --- a/vllm_omni/diffusion/sched/stream_batch_scheduler.py +++ b/vllm_omni/diffusion/sched/stream_batch_scheduler.py @@ -34,8 +34,9 @@ class _InFlightChunk: """One chunk of an active request, tracked through the temporal pipeline.""" chunk_idx: int - in_pipeline: bool = True # currently flowing through ranks (True between admission and exit) - entered_rank0_at: int = -1 # micro-step at which the chunk last entered rank 0 + is_active: bool = True + is_completed: bool = False + entered_rank0_at: int = -1 @dataclass @@ -149,8 +150,8 @@ def _advance_chunk_pipeline(self) -> None: # 1. Try to re-admit a returning chunk (FIFO oldest-first across requests). for progress in self._chunk_progress.values(): for chunk in progress.in_flight: - if not chunk.in_pipeline: - chunk.in_pipeline = True + if not chunk.is_active: + chunk.is_active = True chunk.entered_rank0_at = self._global_micro_step return # rank 0 is now taken @@ -159,7 +160,7 @@ def _advance_chunk_pipeline(self) -> None: if progress.chunks_admitted < progress.num_chunks: new_chunk = _InFlightChunk( chunk_idx=progress.chunks_admitted, - in_pipeline=True, + is_active=True, entered_rank0_at=self._global_micro_step, ) progress.in_flight.append(new_chunk) @@ -170,7 +171,7 @@ def _build_assignment(self) -> list[RankTask | None]: assignment: list[RankTask | None] = [None] * self.pp_size for progress in self._chunk_progress.values(): for chunk in progress.in_flight: - if not chunk.in_pipeline: + if not chunk.is_active: continue r = self._global_micro_step - chunk.entered_rank0_at if 0 <= r < self.pp_size: @@ -193,16 +194,31 @@ def update_from_output(self, sched_output: DiffusionSchedulerOutput, output: Run terminal: dict[str, DiffusionRequestStatus] = {} - if output.chunk_idx is not None: - progress = self._chunk_progress.get(output.req_id) - if progress is not None: - chunk = self._find_chunk(progress, output.chunk_idx) - if chunk is not None: - chunk.in_pipeline = False - if output.chunk_completed: - progress.in_flight = [ - c for c in progress.in_flight if c.chunk_idx != output.chunk_idx - ] + progress = self._chunk_progress.get(output.req_id) + if progress is None: + return set() + + + chunk = self._find_chunk(progress, output.chunk_idx) if output.chunk_idx is not None else None + if chunk is not None: + chunk.is_completed = output.chunk_completed + + last_task = sched_output.per_rank_assignment[-1] + logger.debug( + "update_from_output: Processing output for micro-step %d: chunk=%s, last_chunk=%s, finished=%s", + self._global_micro_step, chunk, last_task, output.finished, + ) + if last_task is not None and last_task.chunk_idx is not None: + last_chunk = self._find_chunk(progress, last_task.chunk_idx) + if last_chunk is not None: + if last_chunk.is_completed: + progress.in_flight = [ + c for c in progress.in_flight if c.chunk_idx != last_chunk.chunk_idx + ] + else: + last_chunk.is_active = False + + if output.finished: terminal[output.req_id] = DiffusionRequestStatus.FINISHED_COMPLETED From b95fdfdbb4b5c27b9642bc3191cc465f3c843a0c Mon Sep 17 00:00:00 2001 From: Mahdi Nasser <94046147+mnasser02@users.noreply.github.com> Date: Thu, 23 Apr 2026 13:47:54 +0200 Subject: [PATCH 27/53] Improve and fix PP communication. - Separate comms stream - Double buffering - Set rcv buffers for a new req - Revert changes regarding Async structs Signed-off-by: Mahdi Nasser <94046147+mnasser02@users.noreply.github.com> --- .../distributed/group_coordinator.py | 191 ++++++++++-------- .../distributed/pipeline_parallel.py | 116 +++++++---- .../models/wan2_2/pipeline_wan2_2.py | 32 ++- .../worker/diffusion_model_runner.py | 4 +- vllm_omni/diffusion/worker/utils.py | 6 - 5 files changed, 217 insertions(+), 132 deletions(-) diff --git a/vllm_omni/diffusion/distributed/group_coordinator.py b/vllm_omni/diffusion/distributed/group_coordinator.py index 531c7fd0e6e..9defa10eb5d 100644 --- a/vllm_omni/diffusion/distributed/group_coordinator.py +++ b/vllm_omni/diffusion/distributed/group_coordinator.py @@ -5,6 +5,7 @@ # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. import pickle from collections import namedtuple +from contextlib import nullcontext from typing import Any import torch @@ -618,27 +619,6 @@ def destroy(self): self.cpu_group = None -class PipelineRecvDictHandle: - - __slots__ = ("_pp_group", "_name", "_segment_idx", "_resolved") - - def __init__(self, pp_group: "PipelineGroupCoordinator", name: str, segment_idx: int): - self._pp_group = pp_group - self._name = name - self._segment_idx = segment_idx - self._resolved: dict[str, Any] | None = None - - def resolve(self) -> dict[str, Any]: - if self._resolved is None: - tensor_dict, handles, _ = self._pp_group.pipeline_irecv_tensor_dict( - name=self._name, segment_idx=self._segment_idx - ) - for h in handles: - h.wait() - self._resolved = tensor_dict - return self._resolved - - class PipelineGroupCoordinator(GroupCoordinator): """ available attributes: @@ -722,10 +702,11 @@ def __init__( # Cached dict schema and pre-allocated recv buffers for # `pipeline_isend_tensor_dict` / `pipeline_irecv_tensor_dict`. - # The pickled schema is exchanged once per (name, segment_idx) over - # NCCL; subsequent calls only post async tensor sends/recvs. - self.dict_schema_cache: dict[str, dict[int, list[tuple[str, Any]]]] = {} - self.dict_recv_buffer: dict[str, dict[int, dict[str, torch.Tensor]]] = {} + # Keyed by (name, segment_idx). Recv buffer leaf is a length-2 list + # for double buffering. Caller picks the slot via buf_idx. + self.dict_schema_cache: dict[tuple[str, int], list[tuple[str, Any]]] = {} + self.dict_recv_buffer: dict[tuple[str, int], list[dict[str, torch.Tensor]]] = {} + self._comms_stream: Any = None # Dedicated comms stream for PP P2P. None on CPU. self.dict_schema_keepalive: list[torch.Tensor] = [] @@ -794,6 +775,32 @@ def reset_buffer(self): self.receiving_skip_tasks = [] self.skip_tensor_recv_buffer = {} + @property + def comms_stream(self): + """Dedicated stream for PP P2P comms.""" + if self._comms_stream is None and self.device.type != "cpu": + mod = getattr(torch, self.device.type, None) + if mod is not None and hasattr(mod, "Stream"): + self._comms_stream = mod.Stream(device=self.device) + return self._comms_stream + + def _comms_stream_ctx(self): + """Context manager that makes ``comms_stream`` the current stream.""" + stream = self.comms_stream + if stream is None: + return nullcontext() + return getattr(torch, self.device.type).stream(stream) + + def _record_compute_event(self): + """Record an event on the default (compute) stream for later + ``comms_stream.wait_event``.""" + if self.comms_stream is None: + return None + mod = getattr(torch, self.device.type) + ev = mod.Event() + ev.record(mod.current_stream(self.device)) + return ev + def set_config(self, dtype: torch.dtype): self.dtype = dtype @@ -1019,80 +1026,111 @@ def pipeline_recv(self, idx: int = -1, name: str = "latent") -> torch.Tensor: self._pipeline_irecv(self.recv_buffer[name][idx]).wait() return self.recv_buffer[name][idx] + def set_recv_dict_buffer( + self, + name: str, + segment_idx: int, + template_dict: dict[str, torch.Tensor | Any], + ) -> None: + """Pre-populate schema cache + a double-buffer pair (indices 0/1) for + (name, segment_idx). + """ + metadata_list, _ = _split_tensor_dict(template_dict) + key = (name, segment_idx) + self.dict_schema_cache[key] = metadata_list + buffer_pair: list[dict[str, torch.Tensor]] = [] + for _ in range(2): + buffers: dict[str, torch.Tensor] = {} + for key_, value in metadata_list: + if isinstance(value, TensorMetadata): + if torch.Size(value.size).numel() == 0: + continue + device = self.device if value.device == "cuda" else torch.device(value.device) + buffers[key_] = torch.empty(value.size, dtype=value.dtype, device=device) + buffer_pair.append(buffers) + self.dict_recv_buffer[key] = buffer_pair + def pipeline_isend_tensor_dict( self, tensor_dict: dict[str, torch.Tensor | Any], name: str = "dict", segment_idx: int = -1, ) -> list[torch.distributed.Work]: - """Async tensor-dict send. Schema (keys, scalars, per-tensor - dtype/shape) is exchanged once per (name, segment_idx) over NCCL; - steady state posts only per-tensor isend ops. Schema must be - stable across calls; mismatching shapes will fail at NCCL recv. - """ metadata_list, tensor_list = _split_tensor_dict(tensor_dict) - cache = self.dict_schema_cache.setdefault(name, {}) + key = (name, segment_idx) handles: list[torch.distributed.Work] = [] - if segment_idx not in cache: + if key not in self.dict_schema_cache: schema_handles, keepalive = self._isend_dict_schema(metadata_list) handles.extend(schema_handles) self.dict_schema_keepalive.extend(keepalive) - cache[segment_idx] = metadata_list + self.dict_schema_cache[key] = metadata_list - for tensor in tensor_list: - if tensor.numel() == 0: - continue - tensor = tensor.contiguous() - handle = self._pipeline_isend(tensor) - if tensor.is_cuda: - tensor.record_stream(torch.cuda.current_stream(tensor.device)) - handles.append(handle) + compute_done = self._record_compute_event() + comms = self.comms_stream + with self._comms_stream_ctx(): + if comms is not None and compute_done is not None: + comms.wait_event(compute_done) + for tensor in tensor_list: + if tensor.numel() == 0: + continue + tensor = tensor.contiguous() + if tensor.is_cuda and comms is not None: + tensor.record_stream(comms) + handles.append(self._pipeline_isend(tensor)) return handles def pipeline_irecv_tensor_dict( self, name: str = "dict", segment_idx: int = -1, + buf_idx: int = 0, ) -> tuple[dict[str, torch.Tensor | Any], list[torch.distributed.Work], list]: - """Async tensor-dict recv. Returns ``(tensor_dict, handles, [])`` - matching ``GroupCoordinator.irecv_tensor_dict``. First call - discovers schema over NCCL and pre-allocates persistent buffers; - the returned dict aliases those buffers, so the caller must - consume before the next recv on the same (name, segment_idx). + """Async tensor-dict recv into the ``buf_idx`` slot (0 or 1) of the + double-buffer pair for (name, segment_idx). Caller picks the slot + — typically ``micro_step % 2`` — so consecutive recvs alternate and + the previous result stays readable until its consumer is done. + Posts irecvs on ``comms_stream``. """ - cache = self.dict_schema_cache.setdefault(name, {}) - if segment_idx not in cache: + key = (name, segment_idx) + if key not in self.dict_schema_cache: metadata_list = self._recv_dict_schema() - cache[segment_idx] = metadata_list - buffers: dict[str, torch.Tensor] = {} - for key, value in metadata_list: - if isinstance(value, TensorMetadata): - if torch.Size(value.size).numel() == 0: - continue - device = self.device if value.device == "cuda" else torch.device(value.device) - buffers[key] = torch.empty(value.size, dtype=value.dtype, device=device) - self.dict_recv_buffer.setdefault(name, {})[segment_idx] = buffers - - metadata_list = cache[segment_idx] - buffers = self.dict_recv_buffer[name][segment_idx] + self.dict_schema_cache[key] = metadata_list + buffer_pair: list[dict[str, torch.Tensor]] = [] + for _ in range(2): + buffers: dict[str, torch.Tensor] = {} + for k, value in metadata_list: + if isinstance(value, TensorMetadata): + if torch.Size(value.size).numel() == 0: + continue + device = self.device if value.device == "cuda" else torch.device(value.device) + buffers[k] = torch.empty(value.size, dtype=value.dtype, device=device) + buffer_pair.append(buffers) + self.dict_recv_buffer[key] = buffer_pair + + metadata_list = self.dict_schema_cache[key] + buffers = self.dict_recv_buffer[key][buf_idx] + comms = self.comms_stream tensor_dict: dict[str, Any] = {} handles: list[torch.distributed.Work] = [] - for key, value in metadata_list: - if isinstance(value, TensorMetadata): - if torch.Size(value.size).numel() == 0: - _update_nested_dict( - tensor_dict, - key, - torch.empty(value.size, dtype=value.dtype, device=self.device), - ) - continue - tensor = buffers[key] - handles.append(self._pipeline_irecv(tensor)) - _update_nested_dict(tensor_dict, key, tensor) - else: - _update_nested_dict(tensor_dict, key, value) + with self._comms_stream_ctx(): + for k, value in metadata_list: + if isinstance(value, TensorMetadata): + if torch.Size(value.size).numel() == 0: + _update_nested_dict( + tensor_dict, + k, + torch.empty(value.size, dtype=value.dtype, device=self.device), + ) + continue + tensor = buffers[k] + if tensor.is_cuda and comms is not None: + tensor.record_stream(comms) + handles.append(self._pipeline_irecv(tensor)) + _update_nested_dict(tensor_dict, k, tensor) + else: + _update_nested_dict(tensor_dict, k, value) return tensor_dict, handles, [] @@ -1115,11 +1153,6 @@ def get_pipeline_recv_data(self, idx: int = -1, name: str = "latent") -> torch.T assert receiving_task[1] == name and receiving_task[2] == idx, "Received tensor does not match the requested" return self.recv_buffer[name][idx] - def add_pipeline_recv_dict_task( - self, name: str = "dict", segment_idx: int = -1 - ) -> PipelineRecvDictHandle: - return PipelineRecvDictHandle(self, name, segment_idx) - def _pipeline_irecv(self, tensor: torch.tensor): # batch_isend_irecv (not plain irecv) — plain P2P on size-2 PG # triggers lazy sub-comm creation that requires the peer present. diff --git a/vllm_omni/diffusion/distributed/pipeline_parallel.py b/vllm_omni/diffusion/distributed/pipeline_parallel.py index 8a540871c9a..6410be68e98 100644 --- a/vllm_omni/diffusion/distributed/pipeline_parallel.py +++ b/vllm_omni/diffusion/distributed/pipeline_parallel.py @@ -5,9 +5,8 @@ from typing import Any import torch -from vllm.sequence import IntermediateTensors +from vllm.v1.worker.gpu_worker import AsyncIntermediateTensors -from vllm_omni.diffusion.distributed.group_coordinator import PipelineRecvDictHandle from vllm_omni.diffusion.distributed.parallel_state import ( get_cfg_group, get_classifier_free_guidance_rank, @@ -19,23 +18,44 @@ class AsyncLatents: - """Lazy-resolve wrapper around a ``PipelineRecvDictHandle`` for latents.""" + """Transparent async wrapper returned by scheduler_step on rank 0. + + Wraps a pending ``irecv_tensor_dict`` and defers ``handle.wait()`` until the + underlying tensor is actually consumed — either via attribute access + (e.g. ``latents.to(dtype)``, ``latents.shape``) or via a torch operation + (e.g. ``mask * latents``). This keeps the first PP rank non-blocking after + posting the receive, matching the async philosophy used everywhere else in + the PP communication layer. + """ - __slots__ = ("_handle", "_key", "_tensor") + __slots__ = ("_tensor_dict", "_handles", "_postproc", "_tensor") - def __init__(self, handle: PipelineRecvDictHandle, key: str = "latents"): - self._handle = handle - self._key = key + def __init__( + self, + tensor_dict: dict[str, torch.Tensor], + handles: list[torch.distributed.Work], + postproc: list, + ): + self._tensor_dict = tensor_dict + self._handles = handles + self._postproc = postproc self._tensor: torch.Tensor | None = None def _resolve(self) -> torch.Tensor: - if self._tensor is None: - self._tensor = self._handle.resolve()[self._key] + if self._tensor is not None: + return self._tensor + for h in self._handles: + h.wait() + for fn in self._postproc: + fn() + self._tensor = self._tensor_dict["latents"] return self._tensor + # Attribute access (e.g. .shape, .to(), .dtype) delegates to the resolved tensor. def __getattr__(self, name: str): return getattr(self._resolve(), name) + # Torch function protocol: any torch op involving an AsyncLatents resolves it first. @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): kwargs = kwargs or {} @@ -44,7 +64,7 @@ def _unwrap(x): if isinstance(x, AsyncLatents): return x._resolve() if isinstance(x, (list, tuple)): - return type(x)(_unwrap(item) for item in x) + return type(x)(_unwrap(item) for item in x) # type(x) return the class of x to preserve its type return x args = tuple(_unwrap(a) for a in args) @@ -52,27 +72,6 @@ def _unwrap(x): return func(*args, **kwargs) -class AsyncIntermediateTensors: - """Lazy-resolve wrapper around a ``PipelineRecvDictHandle`` for an IT.""" - - __slots__ = ("_handle", "_resolved") - - def __init__(self, handle: PipelineRecvDictHandle): - self._handle = handle - self._resolved: IntermediateTensors | None = None - - def _resolve(self) -> IntermediateTensors: - if self._resolved is None: - self._resolved = IntermediateTensors(self._handle.resolve()) - return self._resolved - - def __getitem__(self, key: str): - return self._resolve()[key] - - def __getattr__(self, name: str): - return getattr(self._resolve(), name) - - class PipelineParallelMixin: """ Mixin providing Pipeline Parallelism for diffusion pipelines. @@ -164,6 +163,19 @@ def _pp_send_work(self) -> list[torch.distributed.Work]: def _pp_send_work(self, work: list[torch.distributed.Work]) -> None: self._pp_send_work_list = work + @property + def _preposted_its(self) -> list[AsyncIntermediateTensors] | None: + """Pre-posted IT recvs for the next micro-step (None if not primed).""" + return getattr(self, "_preposted_its_list", None) + + @_preposted_its.setter + def _preposted_its(self, value: list[AsyncIntermediateTensors] | None) -> None: + self._preposted_its_list = value + + def set_pp_recv_dict_buffers(self, state: Any) -> None: + """Override to pre-register PP dict channels for a new request.""" + return None + def _sync_pp_send(self) -> None: """ Wait on all pending non-blocking PP sends. @@ -185,7 +197,7 @@ def predict_noise_maybe_with_cfg( negative_kwargs: dict[str, Any] | None, cfg_normalize: bool = True, output_slice: int | None = None, - chunk_idx: int | None = None, + buf_idx: int = 0, ) -> torch.Tensor | tuple[torch.Tensor, ...] | None: """ Drop-in replacement for predict_noise_maybe_with_cfg that also handles PP. @@ -221,18 +233,26 @@ def predict_noise_maybe_with_cfg( n = len(all_kwargs) its: list[AsyncIntermediateTensors | None] = [None] * n - it_name = f"{chunk_idx}_intermediate" if chunk_idx is not None else "intermediate" if not pp_group.is_first_rank: - for i in range(n): - handle = pp_group.add_pipeline_recv_dict_task(name=it_name, segment_idx=i) - its[i] = AsyncIntermediateTensors(handle) + # Use recvs pre-posted by the previous step's scheduler_step + preposted = self._preposted_its + if preposted is not None and len(preposted) == n: + its = preposted + self._preposted_its = None + else: + for i in range(n): + its[i] = AsyncIntermediateTensors( + *pp_group.pipeline_irecv_tensor_dict( + name="intermediate", segment_idx=i, buf_idx=buf_idx + ) + ) if not pp_group.is_last_rank: # First / middle rank: run partial forwards and propagate ITs downstream. for i, (kwargs, it) in enumerate(zip(all_kwargs, its)): result = self.predict_noise(**kwargs, intermediate_tensors=it) self._pp_send_work.extend( - pp_group.pipeline_isend_tensor_dict(result.tensors, name=it_name, segment_idx=i) + pp_group.pipeline_isend_tensor_dict(result.tensors, name="intermediate", segment_idx=i) ) return None @@ -267,7 +287,8 @@ def scheduler_step_maybe_with_cfg( latents: torch.Tensor | tuple[torch.Tensor, ...], do_true_cfg: bool, per_request_scheduler: Any | None = None, - chunk_idx: int | None = None, + buf_idx: int = 0, + is_last_step: bool = False, ) -> torch.Tensor | tuple[torch.Tensor, ...] | AsyncLatents: """ Drop-in replacement for scheduler_step_maybe_with_cfg that also handles PP. @@ -286,11 +307,22 @@ def scheduler_step_maybe_with_cfg( ) pp_group = get_pp_group() - latents_name = f"{chunk_idx}_latents" if chunk_idx is not None else "latents" if pp_group.is_last_rank: latents = super().scheduler_step_maybe_with_cfg(noise_pred, t, latents, do_true_cfg, per_request_scheduler) - self._pp_send_work = pp_group.pipeline_isend_tensor_dict({"latents": latents}, name=latents_name) + self._pp_send_work = pp_group.pipeline_isend_tensor_dict({"latents": latents}, name="latents") elif pp_group.is_first_rank: - handle = pp_group.add_pipeline_recv_dict_task(name=latents_name) - latents = AsyncLatents(handle) + latents = AsyncLatents(*pp_group.pipeline_irecv_tensor_dict(name="latents", buf_idx=buf_idx)) + + if not pp_group.is_first_rank and not is_last_step: + cfg_parallel_ready = do_true_cfg and get_classifier_free_guidance_world_size() > 1 + n = 1 if cfg_parallel_ready else (2 if do_true_cfg else 1) + next_buf_idx = (buf_idx + 1) % 2 + self._preposted_its = [ + AsyncIntermediateTensors( + *pp_group.pipeline_irecv_tensor_dict( + name="intermediate", segment_idx=i, buf_idx=next_buf_idx + ) + ) + for i in range(n) + ] return latents diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py index 8500e82996a..c13e66ce593 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py @@ -1311,6 +1311,33 @@ def prepare_encode( return state + def set_pp_recv_dict_buffers(self, state: DiffusionRequestState) -> None: + from vllm_omni.diffusion.distributed.parallel_state import get_pp_group + + pp_group = get_pp_group() + if pp_group.world_size == 1: + return + + latents_template = {"latents": state.latents} + + # Intermediate tensor: [batch, seq_len, inner_dim] after patch_embed + flatten. + batch = state.latents.shape[0] + num_frames = state.latents.shape[2] + height = state.latents.shape[3] + width = state.latents.shape[4] + p_t, p_h, p_w = self.transformer_config.patch_size + seq_len = (num_frames // p_t) * (height // p_h) * (width // p_w) + inner_dim = self.transformer_config.num_attention_heads * self.transformer_config.attention_head_dim + dtype = (self.transformer or self.transformer_2).dtype + it_template = { + "hidden_states": torch.empty(batch, seq_len, inner_dim, dtype=dtype, device=self.device) + } + + cfg_branches = 2 if state.do_true_cfg else 1 + pp_group.set_recv_dict_buffer("latents", -1, latents_template) + for seg in range(cfg_branches): + pp_group.set_recv_dict_buffer("intermediate", seg, it_template) + def denoise_step( self, state: DiffusionRequestState, @@ -1353,7 +1380,7 @@ def denoise_step( true_cfg_scale=current_guidance_scale, positive_kwargs=positive_kwargs, negative_kwargs=negative_kwargs, - chunk_idx=state.current_chunk_idx, + buf_idx=state.step_index % 2, ) def step_scheduler( @@ -1374,7 +1401,8 @@ def step_scheduler( state.latents, do_true_cfg, per_request_scheduler=state.scheduler, - chunk_idx=state.current_chunk_idx, + buf_idx=state.step_index % 2, + is_last_step=state.step_index == state.total_steps - 1, ) state.step_index += 1 diff --git a/vllm_omni/diffusion/worker/diffusion_model_runner.py b/vllm_omni/diffusion/worker/diffusion_model_runner.py index 1d242362cfa..d8f21da9eb9 100644 --- a/vllm_omni/diffusion/worker/diffusion_model_runner.py +++ b/vllm_omni/diffusion/worker/diffusion_model_runner.py @@ -551,9 +551,7 @@ def execute_micro_step(self, scheduler_output: DiffusionSchedulerOutput) -> Runn if is_new_request: pp_group.reset_buffer() - if current_omni_platform.is_available(): - current_omni_platform.empty_cache() - + self.pipeline.set_pp_recv_dict_buffers(state) self.pipeline.prepare_encode(state) if pp_group.is_first_rank: diff --git a/vllm_omni/diffusion/worker/utils.py b/vllm_omni/diffusion/worker/utils.py index f031784608f..7353dab84c1 100644 --- a/vllm_omni/diffusion/worker/utils.py +++ b/vllm_omni/diffusion/worker/utils.py @@ -60,9 +60,6 @@ class DiffusionRequestState: # ── Per-request scheduler instance (set once by prepare_encode) ── scheduler: Any | None = None - # Active chunk index while inside a ``use_chunk`` context (None otherwise). - current_chunk_idx: int | None = None - # ── CFG config (set once by prepare_encode) ── do_true_cfg: bool = False guidance: torch.Tensor | None = None @@ -128,11 +125,9 @@ def use_chunk(self, chunk: ChunkState) -> Iterator[None]: saved_latents = self.latents saved_step_index = self.step_index saved_scheduler = self.scheduler - saved_chunk_idx = self.current_chunk_idx self.latents = chunk.latents self.step_index = chunk.step_index self.scheduler = chunk.scheduler - self.current_chunk_idx = chunk.idx try: yield finally: @@ -142,7 +137,6 @@ def use_chunk(self, chunk: ChunkState) -> Iterator[None]: self.latents = saved_latents self.step_index = saved_step_index self.scheduler = saved_scheduler - self.current_chunk_idx = saved_chunk_idx @dataclass From c9b622a70469ca355a515f0be3995e6e3a685bde Mon Sep 17 00:00:00 2001 From: Mahdi Nasser <94046147+mnasser02@users.noreply.github.com> Date: Thu, 23 Apr 2026 18:01:22 +0200 Subject: [PATCH 28/53] bugfixes Signed-off-by: Mahdi Nasser <94046147+mnasser02@users.noreply.github.com> --- .../distributed/pipeline_parallel.py | 12 ++--- .../models/wan2_2/pipeline_wan2_2.py | 25 ++++++++- .../worker/diffusion_model_runner.py | 52 ++++++++++++------- 3 files changed, 60 insertions(+), 29 deletions(-) diff --git a/vllm_omni/diffusion/distributed/pipeline_parallel.py b/vllm_omni/diffusion/distributed/pipeline_parallel.py index 6410be68e98..a39a591ef48 100644 --- a/vllm_omni/diffusion/distributed/pipeline_parallel.py +++ b/vllm_omni/diffusion/distributed/pipeline_parallel.py @@ -172,10 +172,6 @@ def _preposted_its(self) -> list[AsyncIntermediateTensors] | None: def _preposted_its(self, value: list[AsyncIntermediateTensors] | None) -> None: self._preposted_its_list = value - def set_pp_recv_dict_buffers(self, state: Any) -> None: - """Override to pre-register PP dict channels for a new request.""" - return None - def _sync_pp_send(self) -> None: """ Wait on all pending non-blocking PP sends. @@ -312,7 +308,10 @@ def scheduler_step_maybe_with_cfg( self._pp_send_work = pp_group.pipeline_isend_tensor_dict({"latents": latents}, name="latents") elif pp_group.is_first_rank: latents = AsyncLatents(*pp_group.pipeline_irecv_tensor_dict(name="latents", buf_idx=buf_idx)) - + return latents + + def prefetch_its_maybe_with_pp_and_cfg(self, do_true_cfg: bool, buf_idx: int, is_last_step: bool) -> None: + pp_group = get_pp_group() if not pp_group.is_first_rank and not is_last_step: cfg_parallel_ready = do_true_cfg and get_classifier_free_guidance_world_size() > 1 n = 1 if cfg_parallel_ready else (2 if do_true_cfg else 1) @@ -324,5 +323,4 @@ def scheduler_step_maybe_with_cfg( ) ) for i in range(n) - ] - return latents + ] \ No newline at end of file diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py index c13e66ce593..a1e9800f337 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py @@ -436,6 +436,8 @@ def __init__( enable_diffusion_pipeline_profiler=self.od_config.enable_diffusion_pipeline_profiler ) + self.is_buffer_setup = False + def _create_transformer(self, config: dict) -> WanTransformer3DModel: """Create a transformer from a config dict. Respects od_config.quantization_config.""" quant_config = getattr(self.od_config, "quantization_config", None) @@ -1338,6 +1340,8 @@ def set_pp_recv_dict_buffers(self, state: DiffusionRequestState) -> None: for seg in range(cfg_branches): pp_group.set_recv_dict_buffer("intermediate", seg, it_template) + self.is_buffer_setup = True + def denoise_step( self, state: DiffusionRequestState, @@ -1383,6 +1387,21 @@ def denoise_step( buf_idx=state.step_index % 2, ) + def prefetch_its(self, state: DiffusionRequestState) -> None: + """Prefetch intermediate tensors for the next step.""" + t = state.current_timestep + boundary_timestep = state.extra.get("boundary_timestep") + _, current_guidance_scale = self._select_model_for_timestep(t, boundary_timestep) + do_true_cfg = current_guidance_scale > 1.0 and state.negative_prompt_embeds is not None + buf_idx = state.step_index % 2 + is_last_step = state.step_index == state.total_steps - 1 + + self.prefetch_its_maybe_with_pp_and_cfg( + do_true_cfg=do_true_cfg, + buf_idx=buf_idx, + is_last_step=is_last_step, + ) + def step_scheduler( self, state: DiffusionRequestState, @@ -1394,6 +1413,8 @@ def step_scheduler( boundary_timestep = state.extra.get("boundary_timestep") _, current_guidance_scale = self._select_model_for_timestep(t, boundary_timestep) do_true_cfg = current_guidance_scale > 1.0 and state.negative_prompt_embeds is not None + buf_idx = state.step_index % 2 + is_last_step = state.step_index == state.total_steps - 1 state.latents = self.scheduler_step_maybe_with_pp_and_cfg( noise_pred, @@ -1401,8 +1422,8 @@ def step_scheduler( state.latents, do_true_cfg, per_request_scheduler=state.scheduler, - buf_idx=state.step_index % 2, - is_last_step=state.step_index == state.total_steps - 1, + buf_idx=buf_idx, + is_last_step=is_last_step, ) state.step_index += 1 diff --git a/vllm_omni/diffusion/worker/diffusion_model_runner.py b/vllm_omni/diffusion/worker/diffusion_model_runner.py index d8f21da9eb9..78a3ebde5bc 100644 --- a/vllm_omni/diffusion/worker/diffusion_model_runner.py +++ b/vllm_omni/diffusion/worker/diffusion_model_runner.py @@ -513,7 +513,7 @@ def _get_or_create_chunk(state: DiffusionRequestState, chunk_idx: int) -> tuple[ chunks: dict[int, ChunkState] = state.extra.setdefault("chunks", {}) chunk = chunks.get(chunk_idx) if chunk is not None: - return chunk, False + return chunk, chunk.step_index == 0 chunk = ChunkState(idx=chunk_idx) chunks[chunk_idx] = chunk return chunk, True @@ -548,39 +548,41 @@ def execute_micro_step(self, scheduler_output: DiffusionSchedulerOutput) -> Runn pp_group = get_pp_group() pp_rank = pp_group.rank_in_group task = assignment[pp_rank] + prev_task = assignment[pp_group.prev_rank] if is_new_request: pp_group.reset_buffer() - self.pipeline.set_pp_recv_dict_buffers(state) + self.pipeline.is_buffer_setup = False self.pipeline.prepare_encode(state) - if pp_group.is_first_rank: - denoised_chunks = state.extra.pop("denoised_chunks", []) + if pp_group.is_first_rank: # TODO: race condition + denoised_chunks = state.extra.get("denoised_chunks", []) decoded_chunks = state.extra.setdefault("decoded_chunks", []) + new_denoised_chunks = [] - for chunk in denoised_chunks: - with state.use_chunk(chunk): - decoded_chunks.append(self.pipeline.post_decode(state)) - - if len(decoded_chunks) >= state.sampling.num_chunks: - assert len(decoded_chunks) == state.sampling.num_chunks, ( - f"Expected {state.sampling.num_chunks} denoised chunks but got {len(decoded_chunks)}" - ) + for chunk, steps_left in denoised_chunks: + steps_left -= 1 + if steps_left == 0: + with state.use_chunk(chunk): + decoded_chunks.append(self.pipeline.post_decode(state)) + else: + new_denoised_chunks.append((chunk, steps_left)) + state.extra["denoised_chunks"] = new_denoised_chunks + if len(decoded_chunks) == state.sampling.num_chunks: output = RunnerOutput( req_id=state.req_id, step_index=state.step_index, finished=True, result=self._merge_chunk_outputs(decoded_chunks), ) - - self._update_states_after(state, finished=True) - + self._update_states_after(state, finished=True) # TODO: call properly on all ranks return output - + if task is None: return RunnerOutput(req_id=state.req_id) + chunk, is_new_chunk = self._get_or_create_chunk(state, task.chunk_idx) if is_new_chunk: @@ -596,6 +598,8 @@ def execute_micro_step(self, scheduler_output: DiffusionSchedulerOutput) -> Runn chunk.scheduler = copy.deepcopy(state.scheduler) with state.use_chunk(chunk): + if not self.pipeline.is_buffer_setup: + self.pipeline.set_pp_recv_dict_buffers(state) noise_pred = self.pipeline.denoise_step(state) if noise_pred is None and getattr(self.pipeline, "interrupt", False): return RunnerOutput( @@ -605,17 +609,25 @@ def execute_micro_step(self, scheduler_output: DiffusionSchedulerOutput) -> Runn self.pipeline.step_scheduler(state, noise_pred) chunk_done = state.denoise_completed + + # prefetch the chunk of the next micro-step + prev_chunk, _ = self._get_or_create_chunk(state, prev_task.chunk_idx) if prev_task is not None else (None, None) + if prev_chunk is not None: + with state.use_chunk(prev_chunk): + self.pipeline.prefetch_its(state) + + output = RunnerOutput( req_id=task.sched_req_id, step_index=chunk.step_index, chunk_idx=task.chunk_idx, ) - if chunk_done and pp_group.is_first_rank: - state.extra.setdefault("denoised_chunks", []).append(chunk) - + if chunk_done: output.chunk_completed = True state.extra["chunks"].pop(task.chunk_idx, None) - + if pp_group.is_first_rank: + steps_left = pp_group.world_size + state.extra.setdefault("denoised_chunks", []).append((chunk, steps_left)) return output From 826b50568a4d475570630342c341f819348f7864 Mon Sep 17 00:00:00 2001 From: Mahdi Nasser <94046147+mnasser02@users.noreply.github.com> Date: Fri, 24 Apr 2026 08:57:13 +0200 Subject: [PATCH 29/53] Remove stale tests Signed-off-by: Mahdi Nasser <94046147+mnasser02@users.noreply.github.com> --- .../test_diffusion_micro_step_pipeline.py | 486 ------------------ .../diffusion/test_stream_batch_scheduler.py | 432 ---------------- 2 files changed, 918 deletions(-) diff --git a/tests/diffusion/test_diffusion_micro_step_pipeline.py b/tests/diffusion/test_diffusion_micro_step_pipeline.py index f4b7af3f064..e69de29bb2d 100644 --- a/tests/diffusion/test_diffusion_micro_step_pipeline.py +++ b/tests/diffusion/test_diffusion_micro_step_pipeline.py @@ -1,486 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Tests for micro-step (temporal PP) execution across runner / worker / executor / engine.""" - -from __future__ import annotations - -import copy -import queue -import threading -from contextlib import contextmanager -from types import SimpleNamespace - -import pytest -import torch -from pytest_mock import MockerFixture - -import vllm_omni.diffusion.worker.diffusion_model_runner as model_runner_module -from vllm_omni.diffusion.data import DiffusionOutput, DiffusionParallelConfig -from vllm_omni.diffusion.diffusion_engine import DiffusionEngine -from vllm_omni.diffusion.executor.multiproc_executor import MultiprocDiffusionExecutor -from vllm_omni.diffusion.request import OmniDiffusionRequest -from vllm_omni.diffusion.sched import StreamBatchScheduler -from vllm_omni.diffusion.sched.interface import ( - CachedRequestData, - DiffusionSchedulerOutput, - NewRequestData, - RankTask, -) -from vllm_omni.diffusion.worker.diffusion_model_runner import DiffusionModelRunner -from vllm_omni.diffusion.worker.diffusion_worker import DiffusionWorker -from vllm_omni.diffusion.worker.utils import RunnerOutput -from vllm_omni.inputs.data import OmniDiffusionSamplingParams - -pytestmark = [pytest.mark.core_model, pytest.mark.cpu, pytest.mark.diffusion] - - -# --------------------------------------------------------------------------- -# Helpers & fixtures -# --------------------------------------------------------------------------- - - -@contextmanager -def _noop_forward_context(*args, **kwargs): - del args, kwargs - yield - - -class _FakeScheduler: - """Minimal scheduler deepcopyable and tracks step_index.""" - - def __init__(self): - self._step_index = 0 - - def step(self, noise_pred, t, latents, return_dict=False): - del t, return_dict - self._step_index += 1 - return (latents + noise_pred,) - - -class _MicroStepPipeline: - """Minimal pipeline stub supporting micro-step execution.""" - - supports_step_execution = True - - def __init__(self): - self.prepare_calls = 0 - self.denoise_calls = 0 - self.scheduler_calls = 0 - self.decode_calls = 0 - self.sync_calls = 0 - self.scheduler = _FakeScheduler() - - def prepare_encode(self, state, **kwargs): - del kwargs - self.prepare_calls += 1 - n = state.sampling.num_inference_steps - state.timesteps = [torch.tensor(float(n - i)) for i in range(n)] - state.latents = torch.zeros((1, 1, 2, 2, 2)) # [B, C, T, H, W] video-like - state.step_index = 0 - state.scheduler = copy.deepcopy(self.scheduler) - return state - - def denoise_step(self, state, **kwargs): - del kwargs - self.denoise_calls += 1 - return torch.ones_like(state.latents) - - def step_scheduler(self, state, noise_pred, **kwargs): - del noise_pred, kwargs - self.scheduler_calls += 1 - state.step_index += 1 - - def post_decode(self, state, **kwargs): - del kwargs - self.decode_calls += 1 - # Produce a per-chunk video tensor uniquely tagged by decode call count - # so we can verify concatenation order downstream. - return DiffusionOutput(output=torch.full((1, 1, 2, 2, 2), float(self.decode_calls))) - - def sync_pp_send(self): - self.sync_calls += 1 - - -def _make_pp_group(rank: int, world_size: int) -> SimpleNamespace: - """Mock PP group for the runner's get_pp_group() call.""" - return SimpleNamespace( - rank_in_group=rank, - is_first_rank=(rank == 0), - is_last_rank=(rank == world_size - 1), - ) - - -def _make_runner(pp_size: int = 1, pp_rank: int = 0) -> DiffusionModelRunner: - runner = object.__new__(DiffusionModelRunner) - runner.vllm_config = object() - runner.od_config = SimpleNamespace( - cache_backend=None, - parallel_config=SimpleNamespace(use_hsdp=False, pipeline_parallel_size=pp_size), - ) - runner.device = torch.device("cpu") - runner.pipeline = _MicroStepPipeline() - runner.cache_backend = None - runner.offload_backend = None - runner.state_cache = {} - runner.kv_transfer_manager = SimpleNamespace() - runner._pp_group = _make_pp_group(pp_rank, pp_size) - return runner - - -def _install_pp_group_stub(monkeypatch, rank: int, world_size: int) -> None: - """Replace get_pp_group in the runner module with a constant stub.""" - monkeypatch.setattr( - model_runner_module, - "get_pp_group", - lambda: _make_pp_group(rank, world_size), - ) - - -def _make_micro_step_request( - num_chunks: int = 1, num_inference_steps: int = 2 -) -> OmniDiffusionRequest: - return OmniDiffusionRequest( - prompts=["a prompt"], - sampling_params=OmniDiffusionSamplingParams( - num_inference_steps=num_inference_steps, - num_chunks=num_chunks, - seed=42, - ), - request_ids=["req-1"], - ) - - -def _make_micro_step_scheduler_output( - task: RankTask | None, - pp_size: int, - req: OmniDiffusionRequest | None = None, - step_id: int = 0, - finished_req_ids: set[str] | None = None, -) -> DiffusionSchedulerOutput: - """Scheduler output with a single-rank assignment (the rest idle).""" - assignment: list[RankTask | None] = [None] * pp_size - if task is not None: - # For the runner we're simulating, the task is at the rank the runner reports. - # Tests set up pp_rank separately via monkeypatch; task is placed at rank 0 by default. - assignment[0] = task - - new_reqs = [] - cached = CachedRequestData.make_empty() - if req is not None: - new_reqs = [NewRequestData(sched_req_id="req-1", req=req)] - else: - cached = CachedRequestData(sched_req_ids=["req-1"]) - - return DiffusionSchedulerOutput( - step_id=step_id, - scheduled_new_reqs=new_reqs, - scheduled_cached_reqs=cached, - finished_req_ids=set() if finished_req_ids is None else set(finished_req_ids), - num_running_reqs=1, - num_waiting_reqs=0, - per_rank_assignment=assignment, - ) - - -def _make_engine(scheduler, execute_fn=None, stream_batch: bool = True) -> DiffusionEngine: - engine = object.__new__(DiffusionEngine) - engine.od_config = SimpleNamespace(model_class_name="Wan22Pipeline") - engine.pre_process_func = None - engine.post_process_func = None - engine.scheduler = scheduler - engine.execute_fn = execute_fn - engine.stream_batch = stream_batch - engine.step_execution = True - engine._rpc_lock = threading.RLock() - engine.abort_queue = queue.Queue() - return engine - - -# --------------------------------------------------------------------------- -# Runner -# --------------------------------------------------------------------------- - - -class TestMicroStepRunner: - """DiffusionModelRunner.execute_micro_step""" - - def test_single_chunk_completes_and_returns_merged(self, monkeypatch): - runner = _make_runner(pp_size=1, pp_rank=0) - _install_pp_group_stub(monkeypatch, rank=0, world_size=1) - monkeypatch.setattr(model_runner_module, "set_forward_context", _noop_forward_context) - req = _make_micro_step_request(num_chunks=1, num_inference_steps=2) - - # Step 0. - out0 = DiffusionModelRunner.execute_micro_step( - runner, - _make_micro_step_scheduler_output( - RankTask(sched_req_id="req-1", chunk_idx=0, step_index=0), pp_size=1, req=req, - ), - ) - assert out0.chunk_idx == 0 - assert out0.step_index == 1 - assert out0.chunk_completed is False - assert out0.finished is False - assert out0.result is None - - # Step 1 (completes the chunk — with single rank, single chunk, this finishes the request). - out1 = DiffusionModelRunner.execute_micro_step( - runner, - _make_micro_step_scheduler_output( - RankTask(sched_req_id="req-1", chunk_idx=0, step_index=1), pp_size=1, - ), - ) - assert out1.chunk_idx == 0 - assert out1.step_index == 2 - assert out1.chunk_completed is True - assert out1.finished is True - assert out1.result is not None - assert runner.pipeline.decode_calls == 1 - # State cache should be cleared once the request completes. - assert "req-1" not in runner.state_cache - - def test_multi_chunk_produces_concatenated_result(self, monkeypatch): - runner = _make_runner(pp_size=1, pp_rank=0) - _install_pp_group_stub(monkeypatch, rank=0, world_size=1) - monkeypatch.setattr(model_runner_module, "set_forward_context", _noop_forward_context) - req = _make_micro_step_request(num_chunks=2, num_inference_steps=1) - - # Chunk 0, step 0 (completes chunk 0). - DiffusionModelRunner.execute_micro_step( - runner, - _make_micro_step_scheduler_output( - RankTask(sched_req_id="req-1", chunk_idx=0, step_index=0), pp_size=1, req=req, - ), - ) - # Chunk 1, step 0 (completes chunk 1 → merges → finishes request). - final = DiffusionModelRunner.execute_micro_step( - runner, - _make_micro_step_scheduler_output( - RankTask(sched_req_id="req-1", chunk_idx=1, step_index=0), pp_size=1, - ), - ) - assert final.finished is True - assert final.result is not None - # Two chunks concatenated along time dim (dim 2): [1, 1, 4, 2, 2]. - assert final.result.output.shape == (1, 1, 4, 2, 2) - # First chunk's frames tagged 1.0, second chunk's tagged 2.0. - assert torch.all(final.result.output[:, :, :2] == 1.0) - assert torch.all(final.result.output[:, :, 2:] == 2.0) - assert runner.pipeline.decode_calls == 2 - - def test_idle_rank_returns_early_and_syncs(self, monkeypatch): - runner = _make_runner(pp_size=2, pp_rank=1) - _install_pp_group_stub(monkeypatch, rank=1, world_size=2) - monkeypatch.setattr(model_runner_module, "set_forward_context", _noop_forward_context) - req = _make_micro_step_request(num_chunks=1, num_inference_steps=2) - - # Assignment puts task at rank 0 only; rank 1 (this runner) is idle. - sched_output = _make_micro_step_scheduler_output( - RankTask(sched_req_id="req-1", chunk_idx=0, step_index=0), pp_size=2, req=req, - ) - out = DiffusionModelRunner.execute_micro_step(runner, sched_output) - - assert out.chunk_idx is None - assert out.step_index is None - assert out.finished is False - # Idle path still drains pending sends. - assert runner.pipeline.sync_calls == 1 - # prepare_encode must run on idle ranks so shared state is ready when - # the rank later receives a chunk. - assert runner.pipeline.prepare_calls == 1 - - def test_rejects_missing_per_rank_assignment(self, monkeypatch): - runner = _make_runner(pp_size=1, pp_rank=0) - _install_pp_group_stub(monkeypatch, rank=0, world_size=1) - req = _make_micro_step_request() - - sched_output = DiffusionSchedulerOutput( - step_id=0, - scheduled_new_reqs=[NewRequestData(sched_req_id="req-1", req=req)], - scheduled_cached_reqs=CachedRequestData.make_empty(), - finished_req_ids=set(), - num_running_reqs=1, - num_waiting_reqs=0, - per_rank_assignment=None, - ) - - with pytest.raises(ValueError, match="per_rank_assignment"): - DiffusionModelRunner.execute_micro_step(runner, sched_output) - - -# --------------------------------------------------------------------------- -# Worker -# --------------------------------------------------------------------------- - - -class TestMicroStepWorker: - """DiffusionWorker.execute_micro_step""" - - def test_delegates_to_model_runner(self): - worker = object.__new__(DiffusionWorker) - expected = RunnerOutput(req_id="req-1", chunk_idx=0, step_index=1) - scheduler_output = SimpleNamespace( - scheduled_new_reqs=[ - SimpleNamespace( - req=SimpleNamespace(sampling_params=SimpleNamespace(lora_request=None)) - ) - ] - ) - worker.lora_manager = None - worker.model_runner = SimpleNamespace( - execute_micro_step=lambda arg: expected if arg is scheduler_output else None - ) - - output = DiffusionWorker.execute_micro_step(worker, scheduler_output) - assert output is expected - - def test_rejects_lora_requests(self): - worker = object.__new__(DiffusionWorker) - scheduler_output = SimpleNamespace( - scheduled_new_reqs=[ - SimpleNamespace( - req=SimpleNamespace(sampling_params=SimpleNamespace(lora_request=object())) - ) - ] - ) - worker.lora_manager = None - worker.model_runner = SimpleNamespace(execute_micro_step=lambda arg: RunnerOutput(req_id="req-1")) - - with pytest.raises(ValueError, match="does not support LoRA"): - DiffusionWorker.execute_micro_step(worker, scheduler_output) - - -# --------------------------------------------------------------------------- -# Executor -# --------------------------------------------------------------------------- - - -class TestMicroStepExecutor: - """MultiprocDiffusionExecutor.execute_micro_step""" - - def test_passes_through_runner_output_and_uses_first_pp_rank(self, mocker: MockerFixture): - executor = object.__new__(MultiprocDiffusionExecutor) - executor._ensure_open = lambda: None - executor.od_config = SimpleNamespace( - parallel_config=DiffusionParallelConfig(pipeline_parallel_size=4), - ) - expected = RunnerOutput(req_id="req-1", chunk_idx=0, step_index=1, chunk_completed=False) - rpc_mock = mocker.Mock(return_value=expected) - executor.collective_rpc = rpc_mock - - sched_output = DiffusionSchedulerOutput( - step_id=0, - scheduled_new_reqs=[], - scheduled_cached_reqs=CachedRequestData(sched_req_ids=["req-1"]), - finished_req_ids=set(), - num_running_reqs=1, - num_waiting_reqs=0, - per_rank_assignment=[None, None, None, RankTask("req-1", 0, 0)], - ) - - result = MultiprocDiffusionExecutor.execute_micro_step(executor, sched_output) - - assert result is expected - # Reply is collected from the first PP rank (index 0). - rpc_mock.assert_called_once() - kwargs = rpc_mock.call_args.kwargs - assert kwargs["unique_reply_rank"] == 0 - assert kwargs["exec_all_ranks"] is True - - -# --------------------------------------------------------------------------- -# Engine (full loop: scheduler + executor + engine) -# --------------------------------------------------------------------------- - - -class TestMicroStepEngine: - """Full stream-batch flow through DiffusionEngine.add_req_and_wait_for_response.""" - - def _make_scheduler(self, pp_size: int = 1) -> StreamBatchScheduler: - scheduler = StreamBatchScheduler() - scheduler.initialize( - SimpleNamespace(parallel_config=DiffusionParallelConfig(pipeline_parallel_size=pp_size)) - ) - return scheduler - - def _make_execute_fn(self, num_chunks: int, num_steps: int): - """Simulate the executor: advance each micro-step's last-rank chunk.""" - completed = {"n": 0} - - def execute_fn(sched_output): - assignment = sched_output.per_rank_assignment - if assignment is None: - return RunnerOutput(req_id="") - task = assignment[-1] # last rank's slot - if task is None: - req_id = sched_output.scheduled_req_ids[0] if sched_output.scheduled_req_ids else "" - return RunnerOutput(req_id=req_id) - - new_step = task.step_index + 1 - chunk_completed = new_step >= num_steps - finished = False - result = None - if chunk_completed: - completed["n"] += 1 - if completed["n"] >= num_chunks: - finished = True - result = DiffusionOutput(output=torch.tensor([float(completed["n"])])) - return RunnerOutput( - req_id=task.sched_req_id, - step_index=new_step, - chunk_idx=task.chunk_idx, - chunk_completed=chunk_completed, - finished=finished, - result=result, - ) - - return execute_fn - - def test_single_chunk_completes(self): - scheduler = self._make_scheduler(pp_size=1) - engine = _make_engine(scheduler, execute_fn=self._make_execute_fn(num_chunks=1, num_steps=2)) - request = _make_micro_step_request(num_chunks=1, num_inference_steps=2) - - output = engine.add_req_and_wait_for_response(request) - - assert output.error is None - assert output.aborted is False - assert torch.equal(output.output, torch.tensor([1.0])) - - def test_multi_chunk_completes(self): - scheduler = self._make_scheduler(pp_size=1) - engine = _make_engine(scheduler, execute_fn=self._make_execute_fn(num_chunks=3, num_steps=2)) - request = _make_micro_step_request(num_chunks=3, num_inference_steps=2) - - output = engine.add_req_and_wait_for_response(request) - - assert output.error is None - assert torch.equal(output.output, torch.tensor([3.0])) # completed 3 chunks - - def test_execute_fn_exception_returns_error(self): - scheduler = self._make_scheduler(pp_size=1) - - def failing(_): - raise RuntimeError("gpu on fire") - - engine = _make_engine(scheduler, execute_fn=failing) - output = engine.add_req_and_wait_for_response(_make_micro_step_request()) - - assert output.output is None - assert "gpu on fire" in output.error - - def test_pipeline_fills_with_pp_gt_1(self): - """With PP>1, scheduler drives warmup/steady/cooldown; engine sees final merged output.""" - pp_size = 3 - num_chunks = 4 - num_steps = 2 - scheduler = self._make_scheduler(pp_size=pp_size) - engine = _make_engine( - scheduler, execute_fn=self._make_execute_fn(num_chunks=num_chunks, num_steps=num_steps) - ) - request = _make_micro_step_request(num_chunks=num_chunks, num_inference_steps=num_steps) - - output = engine.add_req_and_wait_for_response(request) - - assert output.error is None - assert torch.equal(output.output, torch.tensor([float(num_chunks)])) \ No newline at end of file diff --git a/tests/diffusion/test_stream_batch_scheduler.py b/tests/diffusion/test_stream_batch_scheduler.py index c5a94440390..e69de29bb2d 100644 --- a/tests/diffusion/test_stream_batch_scheduler.py +++ b/tests/diffusion/test_stream_batch_scheduler.py @@ -1,432 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -"""Unit tests for StreamBatchScheduler (temporal PP chunk scheduling).""" - -from __future__ import annotations - -from types import SimpleNamespace - -import pytest - -from vllm_omni.diffusion.data import DiffusionParallelConfig -from vllm_omni.diffusion.request import OmniDiffusionRequest -from vllm_omni.diffusion.sched.stream_batch_scheduler import StreamBatchScheduler -from vllm_omni.inputs.data import OmniDiffusionSamplingParams - -pytestmark = [pytest.mark.core_model, pytest.mark.cpu, pytest.mark.diffusion] - - -# --------------------------------------------------------------------------- -# Helpers -# --------------------------------------------------------------------------- - - -def _make_config(pp_size: int = 2) -> SimpleNamespace: - """Minimal OmniDiffusionConfig stub with the fields StreamBatchScheduler reads.""" - return SimpleNamespace(parallel_config=DiffusionParallelConfig(pipeline_parallel_size=pp_size)) - - -def _make_request( - req_id: str, - num_chunks: int = 1, - num_inference_steps: int = 4, -) -> OmniDiffusionRequest: - return OmniDiffusionRequest( - prompts=[f"prompt_{req_id}"], - sampling_params=OmniDiffusionSamplingParams( - num_inference_steps=num_inference_steps, - num_chunks=num_chunks, - ), - request_ids=[req_id], - ) - - -def _make_runner_output( - req_id: str = "", - step_index: int | None = None, - chunk_idx: int | None = None, - chunk_completed: bool = False, - finished: bool = False, -) -> SimpleNamespace: - """Simulate a RunnerOutput from rank N-1.""" - return SimpleNamespace( - req_id=req_id, - step_index=step_index, - chunk_idx=chunk_idx, - chunk_completed=chunk_completed, - finished=finished, - result=None, - ) - - -def _simulate_last_rank_output(sched_output, pp_size: int, num_steps: int) -> SimpleNamespace: - """Build the RunnerOutput that rank N-1 would produce for a given schedule output. - - Simulates: if rank N-1 had a task, its chunk's step_index advances by 1. - chunk_completed is True if the new step_index reaches num_steps. - """ - assignment = sched_output.per_rank_assignment - if assignment is None: - return _make_runner_output() - task = assignment[pp_size - 1] - if task is None: - req_id = sched_output.scheduled_req_ids[0] if sched_output.scheduled_req_ids else "" - return _make_runner_output(req_id=req_id) - new_step = task.step_index + 1 - return _make_runner_output( - req_id=task.sched_req_id, - step_index=new_step, - chunk_idx=task.chunk_idx, - chunk_completed=(new_step >= num_steps), - ) - - -def _run_until_finished( - scheduler: StreamBatchScheduler, - pp_size: int, - num_steps: int, - num_chunks: int, - max_iters: int = 200, -) -> list[tuple[list, SimpleNamespace]]: - """Drive the scheduler loop, returning (assignment, runner_output) per micro-step. - - Simulates what the runner would set in ``RunnerOutput.finished``: True once - the total number of chunks produced equals ``num_chunks`` for that request. - """ - trace: list[tuple[list, SimpleNamespace]] = [] - completed_per_req: dict[str, int] = {} - for _ in range(max_iters): - sched_output = scheduler.schedule() - output = _simulate_last_rank_output(sched_output, pp_size, num_steps) - if output.chunk_completed: - completed_per_req[output.req_id] = completed_per_req.get(output.req_id, 0) + 1 - if completed_per_req[output.req_id] >= num_chunks: - output.finished = True - finished = scheduler.update_from_output(sched_output, output) - assignment = sched_output.per_rank_assignment or [None] * pp_size - trace.append((assignment, output)) - if finished: - break - return trace - - -# --------------------------------------------------------------------------- -# Validation -# --------------------------------------------------------------------------- - - -class TestAddRequestValidation: - def test_rejects_zero_chunks(self): - sched = StreamBatchScheduler() - sched.initialize(_make_config()) - with pytest.raises(ValueError, match="num_chunks"): - sched.add_request(_make_request("r1", num_chunks=0)) - - def test_rejects_negative_steps(self): - sched = StreamBatchScheduler() - sched.initialize(_make_config()) - with pytest.raises(ValueError, match="num_inference_steps"): - sched.add_request(_make_request("r1", num_inference_steps=-1)) - - def test_accepts_valid_request(self): - sched = StreamBatchScheduler() - sched.initialize(_make_config()) - req_id = sched.add_request(_make_request("r1", num_chunks=3, num_inference_steps=10)) - assert req_id == "r1" - - -# --------------------------------------------------------------------------- -# Single chunk, single rank (PP=1) -# --------------------------------------------------------------------------- - - -class TestSingleChunkSingleRank: - """PP=1, K=1 — degenerate case: one chunk, one rank, behaves like step scheduler.""" - - def test_completes_in_m_steps(self): - pp_size, num_chunks, num_steps = 1, 1, 4 - sched = StreamBatchScheduler() - sched.initialize(_make_config(pp_size)) - sched.add_request(_make_request("r1", num_chunks=num_chunks, num_inference_steps=num_steps)) - - trace = _run_until_finished(sched, pp_size, num_steps, num_chunks) - assert len(trace) == num_steps - - def test_assignment_is_always_rank_0(self): - pp_size, num_chunks, num_steps = 1, 1, 3 - sched = StreamBatchScheduler() - sched.initialize(_make_config(pp_size)) - sched.add_request(_make_request("r1", num_chunks=num_chunks, num_inference_steps=num_steps)) - - trace = _run_until_finished(sched, pp_size, num_steps, num_chunks) - for assignment, _ in trace: - assert assignment[0] is not None - assert assignment[0].sched_req_id == "r1" - assert assignment[0].chunk_idx == 0 - - def test_step_index_advances(self): - pp_size, num_chunks, num_steps = 1, 1, 3 - sched = StreamBatchScheduler() - sched.initialize(_make_config(pp_size)) - sched.add_request(_make_request("r1", num_chunks=num_chunks, num_inference_steps=num_steps)) - - trace = _run_until_finished(sched, pp_size, num_steps, num_chunks) - step_indices = [a[0].step_index for a, _ in trace] - assert step_indices == [0, 1, 2] - - -# --------------------------------------------------------------------------- -# Multi-chunk, single rank (PP=1) -# --------------------------------------------------------------------------- - - -class TestMultiChunkSingleRank: - """PP=1, K>1 — chunks are processed sequentially on one rank.""" - - def test_completes_in_k_times_m_steps(self): - pp_size, num_chunks, num_steps = 1, 3, 2 - sched = StreamBatchScheduler() - sched.initialize(_make_config(pp_size)) - sched.add_request(_make_request("r1", num_chunks=num_chunks, num_inference_steps=num_steps)) - - trace = _run_until_finished(sched, pp_size, num_steps, num_chunks) - assert len(trace) == num_chunks * num_steps - - def test_chunks_processed_in_order(self): - pp_size, num_chunks, num_steps = 1, 3, 2 - sched = StreamBatchScheduler() - sched.initialize(_make_config(pp_size)) - sched.add_request(_make_request("r1", num_chunks=num_chunks, num_inference_steps=num_steps)) - - trace = _run_until_finished(sched, pp_size, num_steps, num_chunks) - chunk_indices = [a[0].chunk_idx for a, _ in trace] - # Each chunk runs M steps before the next chunk starts. - assert chunk_indices == [0, 0, 1, 1, 2, 2] - - -# --------------------------------------------------------------------------- -# Pipeline warmup / assignment (PP > 1) -# --------------------------------------------------------------------------- - - -class TestPipelineAssignment: - """Verify per-rank assignment table for N=3, K=4, M=2.""" - - def _setup(self): - pp_size, num_chunks, num_steps = 3, 4, 2 - sched = StreamBatchScheduler() - sched.initialize(_make_config(pp_size)) - sched.add_request(_make_request("r1", num_chunks=num_chunks, num_inference_steps=num_steps)) - return sched, pp_size, num_steps, num_chunks - - def _extract_assignment(self, assignment): - """Convert assignment list to tuples of (chunk_idx, step_index) or None.""" - return [ - (t.chunk_idx, t.step_index) if t is not None else None - for t in assignment - ] - - def test_warmup_idles_trailing_ranks(self): - sched, pp_size, num_steps, num_chunks = self._setup() - trace = _run_until_finished(sched, pp_size, num_steps, num_chunks) - - # Micro-step 0: only rank 0 active. - a0 = self._extract_assignment(trace[0][0]) - assert a0[0] is not None - assert a0[1] is None - assert a0[2] is None - - def test_warmup_fills_pipeline(self): - sched, pp_size, num_steps, num_chunks = self._setup() - trace = _run_until_finished(sched, pp_size, num_steps, num_chunks) - - # Micro-step 0: rank 0 = chunk 0 - assert self._extract_assignment(trace[0][0]) == [(0, 0), None, None] - # Micro-step 1: rank 0 = chunk 1, rank 1 = chunk 0 - assert self._extract_assignment(trace[1][0]) == [(1, 0), (0, 0), None] - # Micro-step 2: all ranks busy - a2 = self._extract_assignment(trace[2][0]) - assert all(x is not None for x in a2) - - def test_chunk_propagates_through_ranks(self): - sched, pp_size, num_steps, num_chunks = self._setup() - trace = _run_until_finished(sched, pp_size, num_steps, num_chunks) - - # Chunk 0 should appear at rank 0, then rank 1, then rank 2. - chunk0_ranks = [] - for assignment, _ in trace: - for r, task in enumerate(assignment): - if task is not None and task.chunk_idx == 0: - chunk0_ranks.append(r) - break - # First 3 entries: ranks 0, 1, 2 (warmup propagation of chunk 0's first step). - assert chunk0_ranks[:3] == [0, 1, 2] - - def test_request_completes(self): - sched, pp_size, num_steps, num_chunks = self._setup() - trace = _run_until_finished(sched, pp_size, num_steps, num_chunks) - - # The last micro-step's output should have finished=True (set by _simulate). - # But finished is set by the runner, not the scheduler. In our simulation, - # we don't set finished=True. Instead, check that the scheduler reported - # the request in finished_req_ids (the loop exited). - assert len(trace) > 0 # loop exited → request finished - - -# --------------------------------------------------------------------------- -# Chunk re-admission -# --------------------------------------------------------------------------- - - -class TestChunkReAdmission: - """Verify that a chunk returning from rank N-1 is re-admitted to rank 0.""" - - def test_re_admission_priority_over_new_chunk(self): - """With K=2, M=2, N=2: after chunk 0 exits rank 1 at ms1, - it should re-enter rank 0 at ms2 (priority over admitting chunk 1).""" - pp_size, num_chunks, num_steps = 2, 2, 2 - sched = StreamBatchScheduler() - sched.initialize(_make_config(pp_size)) - sched.add_request(_make_request("r1", num_chunks=num_chunks, num_inference_steps=num_steps)) - - # ms0: [chunk0, idle] - out0 = sched.schedule() - assert out0.per_rank_assignment[0].chunk_idx == 0 - assert out0.per_rank_assignment[1] is None - sched.update_from_output(out0, _simulate_last_rank_output(out0, pp_size, num_steps)) - - # ms1: [chunk1, chunk0] — chunk 0 reaches rank 1 and completes step 0. - out1 = sched.schedule() - assert out1.per_rank_assignment[0].chunk_idx == 1 - assert out1.per_rank_assignment[1].chunk_idx == 0 - sched.update_from_output(out1, _simulate_last_rank_output(out1, pp_size, num_steps)) - - # ms2: chunk 0 should be re-admitted (step 1), NOT chunk 1 continuing. - out2 = sched.schedule() - assert out2.per_rank_assignment[0].chunk_idx == 0 - assert out2.per_rank_assignment[0].step_index == 1 - - def test_chunk_not_readmitted_after_completion(self): - """A chunk that finished all denoising steps should NOT be re-admitted.""" - pp_size, num_chunks, num_steps = 1, 1, 2 - sched = StreamBatchScheduler() - sched.initialize(_make_config(pp_size)) - sched.add_request(_make_request("r1", num_chunks=num_chunks, num_inference_steps=num_steps)) - - # Run step 0. - out0 = sched.schedule() - runner0 = _make_runner_output(req_id="r1", step_index=1, chunk_idx=0) - sched.update_from_output(out0, runner0) - - # Run step 1 (final). - out1 = sched.schedule() - runner1 = _make_runner_output( - req_id="r1", step_index=2, chunk_idx=0, chunk_completed=True, finished=True, - ) - finished = sched.update_from_output(out1, runner1) - assert "r1" in finished - - # No more requests. - assert not sched.has_requests() - - -# --------------------------------------------------------------------------- -# Completion ordering -# --------------------------------------------------------------------------- - - -class TestCompletionOrdering: - """Verify chunks complete in admission order (FIFO).""" - - def test_chunks_complete_in_fifo_order(self): - pp_size, num_chunks, num_steps = 2, 3, 1 - sched = StreamBatchScheduler() - sched.initialize(_make_config(pp_size)) - sched.add_request(_make_request("r1", num_chunks=num_chunks, num_inference_steps=num_steps)) - - completed_chunks: list[int] = [] - for _ in range(20): - out = sched.schedule() - runner = _simulate_last_rank_output(out, pp_size, num_steps) - if runner.chunk_completed: - completed_chunks.append(runner.chunk_idx) - finished = sched.update_from_output(out, runner) - if finished: - break - - assert completed_chunks == [0, 1, 2] - - -# --------------------------------------------------------------------------- -# Request finished signal -# --------------------------------------------------------------------------- - - -class TestRequestFinished: - def test_finished_after_all_chunks(self): - pp_size, num_chunks, num_steps = 1, 3, 1 - sched = StreamBatchScheduler() - sched.initialize(_make_config(pp_size)) - sched.add_request(_make_request("r1", num_chunks=num_chunks, num_inference_steps=num_steps)) - - finished_set = set() - for _ in range(20): - out = sched.schedule() - runner = _simulate_last_rank_output(out, pp_size, num_steps) - # On the last chunk, set finished=True (simulating the runner's merge). - if runner.chunk_completed: - progress = sched._chunk_progress.get("r1") - if progress and progress.chunks_completed + 1 >= num_chunks: - runner.finished = True - finished_set = sched.update_from_output(out, runner) - if finished_set: - break - - assert "r1" in finished_set - assert not sched.has_requests() - - def test_not_finished_before_all_chunks(self): - pp_size, num_chunks, num_steps = 1, 3, 1 - sched = StreamBatchScheduler() - sched.initialize(_make_config(pp_size)) - sched.add_request(_make_request("r1", num_chunks=num_chunks, num_inference_steps=num_steps)) - - # Process only 2 of 3 chunks. - for i in range(2): - out = sched.schedule() - runner = _make_runner_output( - req_id="r1", step_index=1, chunk_idx=i, chunk_completed=True, - ) - finished = sched.update_from_output(out, runner) - assert not finished - - assert sched.has_requests() - - -# --------------------------------------------------------------------------- -# Sequential requests -# --------------------------------------------------------------------------- - - -class TestSequentialRequests: - """Second request is processed after the first finishes.""" - - def test_second_request_starts_after_first(self): - pp_size, num_steps = 1, 1 - sched = StreamBatchScheduler() - sched.initialize(_make_config(pp_size)) - sched.add_request(_make_request("r1", num_chunks=1, num_inference_steps=num_steps)) - sched.add_request(_make_request("r2", num_chunks=1, num_inference_steps=num_steps)) - - # Process r1. - out1 = sched.schedule() - assert out1.per_rank_assignment[0].sched_req_id == "r1" - sched.update_from_output( - out1, - _make_runner_output(req_id="r1", step_index=1, chunk_idx=0, chunk_completed=True, finished=True), - ) - - # r1 finished. Next schedule should pick r2. - out2 = sched.schedule() - assert out2.per_rank_assignment[0].sched_req_id == "r2" \ No newline at end of file From 081818e725c1eebe2b49764befeb9356951b5b1f Mon Sep 17 00:00:00 2001 From: Mahdi Nasser <94046147+mnasser02@users.noreply.github.com> Date: Mon, 27 Apr 2026 11:51:41 +0200 Subject: [PATCH 30/53] Add unit tests for StreamBatchScheduler Signed-off-by: Mahdi Nasser <94046147+mnasser02@users.noreply.github.com> --- tests/diffusion/test_diffusion_scheduler.py | 367 ++++++++++++++++++ .../diffusion/test_stream_batch_scheduler.py | 0 .../diffusion/sched/stream_batch_scheduler.py | 12 +- 3 files changed, 375 insertions(+), 4 deletions(-) delete mode 100644 tests/diffusion/test_stream_batch_scheduler.py diff --git a/tests/diffusion/test_diffusion_scheduler.py b/tests/diffusion/test_diffusion_scheduler.py index f0bc1da35ac..7ba040d1da3 100644 --- a/tests/diffusion/test_diffusion_scheduler.py +++ b/tests/diffusion/test_diffusion_scheduler.py @@ -19,6 +19,7 @@ Scheduler, SchedulerInterface, StepScheduler, + StreamBatchScheduler, ) from vllm_omni.diffusion.sched.interface import CachedRequestData, NewRequestData from vllm_omni.diffusion.worker.utils import RunnerOutput @@ -809,3 +810,369 @@ def test_rejects_invalid_initial_step_state(self, sampling_params: OmniDiffusion with pytest.raises(ValueError): self.scheduler.add_request(request) + + +def _make_stream_request( + req_id: str, + *, + num_inference_steps: int = 2, + num_chunks: int = 1, +) -> OmniDiffusionRequest: + return OmniDiffusionRequest( + prompts=[f"prompt_{req_id}"], + sampling_params=OmniDiffusionSamplingParams( + num_inference_steps=num_inference_steps, + num_chunks=num_chunks, + ), + request_ids=[req_id], + ) + + +def _make_stream_output( + req_id: str, + *, + chunk_idx: int = 0, + chunk_completed: bool = False, + finished: bool = False, + error: str | None = None, +): + return SimpleNamespace( + req_id=req_id, + step_index=None, + finished=finished, + chunk_idx=chunk_idx, + chunk_completed=chunk_completed, + result=DiffusionOutput(output=None, error=error) if error is not None else None, + ) + + +def _make_od_config(pp_size: int) -> SimpleNamespace: + return SimpleNamespace(parallel_config=SimpleNamespace(pipeline_parallel_size=pp_size)) + + +def _ranks(sched_output) -> list[tuple[str, int] | None]: + """Compact view of per_rank_assignment for assertions.""" + if sched_output.per_rank_assignment is None: + return [] + return [ + (t.sched_req_id, t.chunk_idx) if t is not None else None + for t in sched_output.per_rank_assignment + ] + + +class TestStreamBatchScheduler: + def _make_scheduler(self, pp_size: int = 2) -> StreamBatchScheduler: + sched = StreamBatchScheduler() + sched.initialize(_make_od_config(pp_size)) + return sched + + def test_add_request_rejects_invalid_num_chunks(self) -> None: + scheduler = self._make_scheduler() + request = _make_stream_request("bad-chunks", num_chunks=0) + with pytest.raises(ValueError): + scheduler.add_request(request) + + def test_add_request_rejects_invalid_num_inference_steps(self) -> None: + scheduler = self._make_scheduler() + request = _make_stream_request("bad-steps", num_inference_steps=0) + with pytest.raises(ValueError): + scheduler.add_request(request) + + + def test_pp1_single_chunk_single_step(self) -> None: + scheduler = self._make_scheduler(pp_size=1) + req_id = scheduler.add_request(_make_stream_request("a", num_inference_steps=1, num_chunks=1)) + + out0 = scheduler.schedule() + assert _new_ids(out0) == [req_id] + assert _ranks(out0) == [(req_id, 0)] + assert scheduler.update_from_output( + out0, _make_stream_output(req_id, chunk_idx=0, chunk_completed=True) + ) == set() + + # nothing to admit; runner decodes chunk 0 and returns finished. + out1 = scheduler.schedule() + assert _ranks(out1) == [None] + finished = scheduler.update_from_output( + out1, _make_stream_output(req_id, finished=True) + ) + assert finished == {req_id} + assert scheduler.get_request_state(req_id).status == DiffusionRequestStatus.FINISHED_COMPLETED + assert scheduler.has_requests() is False + + def test_pp1_single_chunk_multi_step_re_admits_same_chunk(self) -> None: + scheduler = self._make_scheduler(pp_size=1) + req_id = scheduler.add_request(_make_stream_request("multi", num_inference_steps=3, num_chunks=1)) + + out0 = scheduler.schedule() + assert _ranks(out0) == [(req_id, 0)] + assert scheduler.update_from_output( + out0, _make_stream_output(req_id, chunk_idx=0, chunk_completed=False) + ) == set() + + out1 = scheduler.schedule() + assert _ranks(out1) == [(req_id, 0)] + assert scheduler.update_from_output( + out1, _make_stream_output(req_id, chunk_idx=0, chunk_completed=False) + ) == set() + + out2 = scheduler.schedule() + assert _ranks(out2) == [(req_id, 0)] + assert scheduler.update_from_output( + out2, _make_stream_output(req_id, chunk_idx=0, chunk_completed=True) + ) == set() + + # nothing to admit; runner decodes and returns finished. + out3 = scheduler.schedule() + assert _ranks(out3) == [None] + finished = scheduler.update_from_output( + out3, _make_stream_output(req_id, finished=True) + ) + assert finished == {req_id} + + def test_pp1_multi_chunk_admits_in_order(self) -> None: + scheduler = self._make_scheduler(pp_size=1) + req_id = scheduler.add_request(_make_stream_request("multi", num_inference_steps=1, num_chunks=2)) + + out0 = scheduler.schedule() + assert _ranks(out0) == [(req_id, 0)] + assert scheduler.update_from_output( + out0, _make_stream_output(req_id, chunk_idx=0, chunk_completed=True) + ) == set() + + out1 = scheduler.schedule() + assert _ranks(out1) == [(req_id, 1)] + assert scheduler.update_from_output( + out1, _make_stream_output(req_id, chunk_idx=1, chunk_completed=True) + ) == set() + + out2 = scheduler.schedule() + assert _ranks(out2) == [None] + finished = scheduler.update_from_output( + out2, _make_stream_output(req_id, finished=True) + ) + assert finished == {req_id} + + def test_pp2_pipelined_chunks_advance_through_ranks(self) -> None: + scheduler = self._make_scheduler(pp_size=2) + req_id = scheduler.add_request(_make_stream_request("pp2", num_inference_steps=1, num_chunks=2)) + + out0 = scheduler.schedule() + assert _ranks(out0) == [(req_id, 0), None] + assert scheduler.update_from_output( + out0, _make_stream_output(req_id, chunk_idx=0, chunk_completed=True) + ) == set() + + out1 = scheduler.schedule() + assert _ranks(out1) == [(req_id, 1), (req_id, 0)] + assert scheduler.update_from_output( + out1, _make_stream_output(req_id, chunk_idx=1, chunk_completed=True) + ) == set() + + out2 = scheduler.schedule() + assert _ranks(out2) == [None, (req_id, 1)] + assert scheduler.update_from_output( + out2, _make_stream_output(req_id, chunk_idx=None, finished=False) + ) == set() + + out3 = scheduler.schedule() + assert _ranks(out3) == [None, None] + finished = scheduler.update_from_output( + out3, _make_stream_output(req_id, finished=True) + ) + assert finished == {req_id} + + def test_pp3_three_chunks_two_steps_each(self) -> None: + scheduler = self._make_scheduler(pp_size=3) + req_id = scheduler.add_request(_make_stream_request("pp3", num_inference_steps=2, num_chunks=3)) + + out0 = scheduler.schedule() + assert _ranks(out0) == [(req_id, 0), None, None] + assert scheduler.update_from_output( + out0, _make_stream_output(req_id, chunk_idx=0, chunk_completed=False) + ) == set() + + out1 = scheduler.schedule() + assert _ranks(out1) == [(req_id, 1), (req_id, 0), None] + assert scheduler.update_from_output( + out1, _make_stream_output(req_id, chunk_idx=1, chunk_completed=False) + ) == set() + + out2 = scheduler.schedule() + assert _ranks(out2) == [(req_id, 2), (req_id, 1), (req_id, 0)] + assert scheduler.update_from_output( + out2, _make_stream_output(req_id, chunk_idx=2, chunk_completed=False) + ) == set() + + out3 = scheduler.schedule() + assert _ranks(out3) == [(req_id, 0), (req_id, 2), (req_id, 1)] + assert scheduler.update_from_output( + out3, _make_stream_output(req_id, chunk_idx=0, chunk_completed=True) + ) == set() + + out4 = scheduler.schedule() + assert _ranks(out4) == [(req_id, 1), (req_id, 0), (req_id, 2)] + assert scheduler.update_from_output( + out4, _make_stream_output(req_id, chunk_idx=1, chunk_completed=True) + ) == set() + + out5 = scheduler.schedule() + assert _ranks(out5) == [(req_id, 2), (req_id, 1), (req_id, 0)] + assert scheduler.update_from_output( + out5, _make_stream_output(req_id, chunk_idx=2, chunk_completed=True) + ) == set() + + out6 = scheduler.schedule() + assert _ranks(out6) == [None, (req_id, 2), (req_id, 1)] + assert scheduler.update_from_output( + out6, _make_stream_output(req_id, chunk_idx=None) + ) == set() + + out7 = scheduler.schedule() + assert _ranks(out7) == [None, None, (req_id, 2)] + assert scheduler.update_from_output( + out7, _make_stream_output(req_id, chunk_idx=None) + ) == set() + + # runner decodes and reports finished. + out8 = scheduler.schedule() + assert _ranks(out8) == [None, None, None] + finished = scheduler.update_from_output( + out8, _make_stream_output(req_id, finished=True) + ) + + assert finished == {req_id} + assert scheduler.has_requests() is False + + def test_re_admission_takes_priority_over_new_chunk(self) -> None: + scheduler = self._make_scheduler(pp_size=1) + req_id = scheduler.add_request(_make_stream_request("prio", num_inference_steps=2, num_chunks=2)) + + out0 = scheduler.schedule() + assert _ranks(out0) == [(req_id, 0)] + scheduler.update_from_output(out0, _make_stream_output(req_id, chunk_idx=0, chunk_completed=False)) + + out1 = scheduler.schedule() + assert _ranks(out1) == [(req_id, 0)] + + def test_chunk_progress_cleared_after_request_finishes(self) -> None: + scheduler = self._make_scheduler(pp_size=1) + req_id = scheduler.add_request(_make_stream_request("cleanup", num_inference_steps=1, num_chunks=1)) + + out0 = scheduler.schedule() + scheduler.update_from_output( + out0, _make_stream_output(req_id, chunk_idx=0, chunk_completed=True) + ) + # runner decodes and reports finished. + out1 = scheduler.schedule() + scheduler.update_from_output( + out1, _make_stream_output(req_id, finished=True) + ) + + scheduler.pop_request_state(req_id) + + assert req_id not in scheduler._chunk_progress + assert scheduler.has_requests() is False + + def test_schedule_with_no_requests_emits_no_assignment(self) -> None: + scheduler = self._make_scheduler(pp_size=2) + out = scheduler.schedule() + assert out.per_rank_assignment is None + assert out.scheduled_req_ids == [] + + def test_fifo_two_requests(self) -> None: + scheduler = self._make_scheduler(pp_size=1) + req_a = scheduler.add_request(_make_stream_request("a", num_inference_steps=1, num_chunks=1)) + req_b = scheduler.add_request(_make_stream_request("b", num_inference_steps=1, num_chunks=1)) + + out0 = scheduler.schedule() + assert _new_ids(out0) == [req_a] + assert _ranks(out0) == [(req_a, 0)] + scheduler.update_from_output(out0, _make_stream_output(req_a, chunk_idx=0, chunk_completed=True)) + + # B still waiting until A finishes. + out1 = scheduler.schedule() + assert _new_ids(out1) == [] + scheduler.update_from_output(out1, _make_stream_output(req_a, chunk_idx=None, finished=True)) + + out2 = scheduler.schedule() + assert _new_ids(out2) == [req_b] + assert _ranks(out2) == [(req_b, 0)] + + def test_has_requests_state_transition(self) -> None: + scheduler = self._make_scheduler(pp_size=1) + assert scheduler.has_requests() is False + + req_id = scheduler.add_request(_make_stream_request("has", num_inference_steps=1, num_chunks=1)) + assert scheduler.has_requests() is True + + out0 = scheduler.schedule() + assert scheduler.has_requests() is True + scheduler.update_from_output(out0, _make_stream_output(req_id, chunk_idx=0, chunk_completed=True)) + + out1 = scheduler.schedule() + finished = scheduler.update_from_output( + out1, _make_stream_output(req_id, chunk_idx=None, finished=True) + ) + assert finished == {req_id} + assert scheduler.has_requests() is False + + def test_abort_waiting_and_running_requests(self) -> None: + scheduler = self._make_scheduler(pp_size=1) + req_a = scheduler.add_request(_make_stream_request("a", num_inference_steps=1, num_chunks=1)) + req_b = scheduler.add_request(_make_stream_request("b", num_inference_steps=1, num_chunks=1)) + + scheduler.finish_requests(req_b, DiffusionRequestStatus.FINISHED_ABORTED) + assert scheduler.get_request_state(req_b).status == DiffusionRequestStatus.FINISHED_ABORTED + + out = scheduler.schedule() + assert _new_ids(out) == [req_a] + + scheduler.finish_requests(req_a, DiffusionRequestStatus.FINISHED_ABORTED) + assert scheduler.get_request_state(req_a).status == DiffusionRequestStatus.FINISHED_ABORTED + assert scheduler.has_requests() is False + + def test_error_output_marks_finished_error(self) -> None: + scheduler = self._make_scheduler(pp_size=1) + req_id = scheduler.add_request(_make_stream_request("err", num_inference_steps=2, num_chunks=1)) + + out = scheduler.schedule() + finished = scheduler.update_from_output( + out, _make_stream_output(req_id, chunk_idx=0, error="worker failed") + ) + + assert finished == {req_id} + state = scheduler.get_request_state(req_id) + assert state.status == DiffusionRequestStatus.FINISHED_ERROR + assert state.error == "worker failed" + assert scheduler.has_requests() is False + + def test_preempt_request_preserves_chunk_progress(self) -> None: + scheduler = self._make_scheduler(pp_size=2) + req_id = scheduler.add_request(_make_stream_request("preempt", num_inference_steps=2, num_chunks=2)) + + out0 = scheduler.schedule() + assert _ranks(out0) == [(req_id, 0), None] + scheduler.update_from_output(out0, _make_stream_output(req_id, chunk_idx=0, chunk_completed=False)) + + out1 = scheduler.schedule() + assert _ranks(out1) == [(req_id, 1), (req_id, 0)] + scheduler.update_from_output(out1, _make_stream_output(req_id, chunk_idx=1, chunk_completed=False)) + + before = scheduler._chunk_progress[req_id] + assert before.chunks_admitted == 2 + in_flight_before = {c.chunk_idx: (c.is_active, c.is_completed) for c in before.in_flight} + assert in_flight_before == {0: (False, False), 1: (True, False)} + + assert scheduler.preempt_request(req_id) is True + assert scheduler.get_request_state(req_id).status == DiffusionRequestStatus.PREEMPTED + + after = scheduler._chunk_progress[req_id] + assert after.chunks_admitted == 2 + in_flight_after = {c.chunk_idx: (c.is_active, c.is_completed) for c in after.in_flight} + assert in_flight_after == in_flight_before + + out2 = scheduler.schedule() + assert _new_ids(out2) == [] # not a fresh promotion + assert _ranks(out2) == [(req_id, 0), (req_id, 1)] + assert scheduler._chunk_progress[req_id].chunks_admitted == 2 diff --git a/tests/diffusion/test_stream_batch_scheduler.py b/tests/diffusion/test_stream_batch_scheduler.py deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/vllm_omni/diffusion/sched/stream_batch_scheduler.py b/vllm_omni/diffusion/sched/stream_batch_scheduler.py index 6cfd41b0766..e9fbc11ac22 100644 --- a/vllm_omni/diffusion/sched/stream_batch_scheduler.py +++ b/vllm_omni/diffusion/sched/stream_batch_scheduler.py @@ -193,11 +193,17 @@ def update_from_output(self, sched_output: DiffusionSchedulerOutput, output: Run return set() terminal: dict[str, DiffusionRequestStatus] = {} + terminal_errors: dict[str, str | None] = {} progress = self._chunk_progress.get(output.req_id) if progress is None: return set() + output_error = output.result.error if output.result is not None else None + if output_error is not None: + terminal[output.req_id] = DiffusionRequestStatus.FINISHED_ERROR + terminal_errors[output.req_id] = output_error + return self._finalize_update_from_output(sched_output, terminal, terminal_errors) chunk = self._find_chunk(progress, output.chunk_idx) if output.chunk_idx is not None else None if chunk is not None: @@ -216,14 +222,12 @@ def update_from_output(self, sched_output: DiffusionSchedulerOutput, output: Run c for c in progress.in_flight if c.chunk_idx != last_chunk.chunk_idx ] else: - last_chunk.is_active = False - - + last_chunk.is_active = False if output.finished: terminal[output.req_id] = DiffusionRequestStatus.FINISHED_COMPLETED - return self._finalize_update_from_output(sched_output, terminal) + return self._finalize_update_from_output(sched_output, terminal, terminal_errors) @staticmethod def _find_chunk(progress: _ChunkProgress, chunk_idx: int) -> _InFlightChunk | None: From a87a9ae7f49a8e1ec4b04dd270b5b486ce1d634d Mon Sep 17 00:00:00 2001 From: Mahdi Nasser <94046147+mnasser02@users.noreply.github.com> Date: Mon, 27 Apr 2026 12:45:41 +0200 Subject: [PATCH 31/53] add SupportsMicroStepExecution Signed-off-by: Mahdi Nasser <94046147+mnasser02@users.noreply.github.com> --- vllm_omni/diffusion/models/interface.py | 28 +++++++++++++++++++ .../models/wan2_2/pipeline_wan2_2.py | 3 +- 2 files changed, 30 insertions(+), 1 deletion(-) diff --git a/vllm_omni/diffusion/models/interface.py b/vllm_omni/diffusion/models/interface.py index f0ded9cdb0a..5e6aa56e095 100644 --- a/vllm_omni/diffusion/models/interface.py +++ b/vllm_omni/diffusion/models/interface.py @@ -87,3 +87,31 @@ def supports_step_execution(pipeline: object) -> bool: """Return whether `pipeline` implements :class:`SupportsStepExecution`.""" return isinstance(pipeline, SupportsStepExecution) + + +@runtime_checkable +class SupportsMicroStepExecution(SupportsStepExecution, Protocol): + """Temporal-PP micro-step execution protocol. + + Extends :class:`SupportsStepExecution` with the per-micro-step hooks + used by ``DiffusionModelRunner.execute_micro_step``: + + - ``set_pp_recv_dict_buffers`` pre-registers PPGC dict channels for + this request to skip the blocking first-call schema exchange. + - ``prefetch_its`` pre-posts the next-step IT recv on the comms stream + so it overlaps with the current micro-step's compute. + """ + + supports_micro_step_execution: ClassVar[bool] = True + + def set_pp_recv_dict_buffers(self, state: DiffusionRequestState, **kwargs: Any) -> None: + """Pre-register PP dict recv buffers and schema cache for this request.""" + + def prefetch_its(self, state: DiffusionRequestState, **kwargs: Any) -> None: + """Pre-post the next-step IT recv (no-op if not in temporal PP).""" + + +def supports_micro_step_execution(pipeline: object) -> bool: + """Return whether `pipeline` implements :class:`SupportsMicroStepExecution`.""" + + return isinstance(pipeline, SupportsMicroStepExecution) diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py index a1e9800f337..ee555c24cca 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py @@ -307,6 +307,7 @@ class Wan22Pipeline( nn.Module, PipelineParallelMixin, CFGParallelMixin, ProgressBarMixin, DiffusionPipelineProfilerMixin ): supports_step_execution: ClassVar[bool] = True + supports_micro_step_execution: ClassVar[bool] = True def __init__( self, @@ -1036,7 +1037,7 @@ def check_inputs( if boundary_ratio is None and guidance_scale_2 is not None: raise ValueError("`guidance_scale_2` is only supported when `boundary_ratio` is set.") - # ── Step-execution protocol (SupportsStepExecution) ── + # ── Step-execution protocol (SupportsStepExecution) + Micro-step execution (SupportsMicroStepExecution) ── def _extract_prompts( self, From 814a1a55333163b576301324b4b016fccd623600 Mon Sep 17 00:00:00 2001 From: Mahdi Nasser <94046147+mnasser02@users.noreply.github.com> Date: Mon, 27 Apr 2026 13:40:23 +0200 Subject: [PATCH 32/53] add unit tests for micro-step execution pipeline Signed-off-by: Mahdi Nasser <94046147+mnasser02@users.noreply.github.com> --- .../test_diffusion_micro_step_pipeline.py | 444 ++++++++++++++++++ 1 file changed, 444 insertions(+) diff --git a/tests/diffusion/test_diffusion_micro_step_pipeline.py b/tests/diffusion/test_diffusion_micro_step_pipeline.py index e69de29bb2d..bc9521fd4e6 100644 --- a/tests/diffusion/test_diffusion_micro_step_pipeline.py +++ b/tests/diffusion/test_diffusion_micro_step_pipeline.py @@ -0,0 +1,444 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Tests for micro-step level diffusion execution across runner / worker / executor / engine.""" + +from contextlib import contextmanager +from types import SimpleNamespace + +import pytest +import torch +from pytest_mock import MockerFixture + +import vllm_omni.diffusion.worker.diffusion_model_runner as model_runner_module +from vllm_omni.diffusion.data import DiffusionOutput +from vllm_omni.diffusion.executor.multiproc_executor import MultiprocDiffusionExecutor +from vllm_omni.diffusion.sched.interface import ( + CachedRequestData, + DiffusionSchedulerOutput, + NewRequestData, + RankTask, +) +from vllm_omni.diffusion.worker.diffusion_model_runner import DiffusionModelRunner +from vllm_omni.diffusion.worker.diffusion_worker import DiffusionWorker +from vllm_omni.diffusion.worker.utils import RunnerOutput + +pytestmark = [pytest.mark.core_model, pytest.mark.cpu, pytest.mark.diffusion] + + +@contextmanager +def _noop_forward_context(*args, **kwargs): + del args, kwargs + yield + + +class _FakePPGroup: + def __init__(self, rank_in_group: int = 0, world_size: int = 1): + self.rank_in_group = rank_in_group + self.world_size = world_size + self.is_first_rank = rank_in_group == 0 + self.is_last_rank = rank_in_group == world_size - 1 + self.prev_rank = (rank_in_group - 1) % world_size + self.next_rank = (rank_in_group + 1) % world_size + self.reset_calls = 0 + + def reset_buffer(self) -> None: + self.reset_calls += 1 + + +class _MicroStepPipeline: + supports_step_execution = True + supports_micro_step_execution = True + + def __init__(self, num_steps: int = 1): + self.num_steps = num_steps + self.prepare_calls = 0 + self.set_buffer_calls = 0 + self.denoise_calls = 0 + self.scheduler_calls = 0 + self.decode_calls = 0 + self.prefetch_calls = 0 + self.is_buffer_setup = False + + def prepare_encode(self, state, **kwargs): + del kwargs + self.prepare_calls += 1 + state.timesteps = [torch.tensor(float(i)) for i in range(self.num_steps)] + state.latents = torch.zeros((1,)) + state.scheduler = None + return state + + def set_pp_recv_dict_buffers(self, state, **kwargs): + del state, kwargs + self.set_buffer_calls += 1 + self.is_buffer_setup = True + + def denoise_step(self, state, **kwargs): + del state, kwargs + self.denoise_calls += 1 + return torch.tensor([1.0]) + + def step_scheduler(self, state, noise_pred, **kwargs): + del noise_pred, kwargs + self.scheduler_calls += 1 + state.step_index += 1 + + def post_decode(self, state, **kwargs): + del kwargs + self.decode_calls += 1 + return DiffusionOutput(output=torch.tensor([state.step_index], dtype=torch.float32)) + + def prefetch_its(self, state, **kwargs): + del state, kwargs + self.prefetch_calls += 1 + + +class _InterruptingMicroStepPipeline(_MicroStepPipeline): + interrupt = True + + def denoise_step(self, state, **kwargs): + del state, kwargs + self.denoise_calls += 1 + return None + + def step_scheduler(self, state, noise_pred, **kwargs): + del state, noise_pred, kwargs + raise AssertionError("step_scheduler should not run after interrupt") + + def post_decode(self, state, **kwargs): + del state, kwargs + raise AssertionError("post_decode should not run after interrupt") + + +def _make_micro_request( + req_id: str = "req-1", + *, + num_inference_steps: int = 1, + num_chunks: int = 1, +): + return SimpleNamespace( + prompts=["a prompt"], + request_ids=[req_id], + sampling_params=SimpleNamespace( + generator=None, + seed=None, + generator_device=None, + num_inference_steps=num_inference_steps, + num_chunks=num_chunks, + lora_request=None, + ), + ) + + +def _make_runner(pp_size: int = 1, num_steps: int = 1): + runner = object.__new__(DiffusionModelRunner) + runner.vllm_config = object() + runner.od_config = SimpleNamespace( + cache_backend=None, + parallel_config=SimpleNamespace(use_hsdp=False), + ) + runner.device = torch.device("cpu") + runner.pipeline = _MicroStepPipeline(num_steps=num_steps) + runner.cache_backend = None + runner.offload_backend = None + runner.state_cache = {} + runner.kv_transfer_manager = SimpleNamespace() + runner._fake_pp_group = _FakePPGroup(world_size=pp_size) + return runner + + +def _make_micro_scheduler_output( + *, + req=None, + sched_req_id: str = "req-1", + step_id: int = 0, + assignment=None, + is_new: bool = True, + finished_req_ids=None, +): + if assignment is None: + assignment = [RankTask(sched_req_id=sched_req_id, chunk_idx=0)] + if is_new and req is not None: + new_reqs = [NewRequestData(sched_req_id=sched_req_id, req=req)] + cached_reqs = CachedRequestData.make_empty() + else: + new_reqs = [] + cached_reqs = CachedRequestData(sched_req_ids=[sched_req_id]) + return DiffusionSchedulerOutput( + step_id=step_id, + scheduled_new_reqs=new_reqs, + scheduled_cached_reqs=cached_reqs, + finished_req_ids=set() if finished_req_ids is None else set(finished_req_ids), + num_running_reqs=1, + num_waiting_reqs=0, + per_rank_assignment=assignment, + ) + + +def _patch_runtime(monkeypatch, runner) -> None: + monkeypatch.setattr(model_runner_module, "set_forward_context", _noop_forward_context) + monkeypatch.setattr(model_runner_module, "get_pp_group", lambda: runner._fake_pp_group) + + +# --------------------------------------------------------------------------- +# Runner +# --------------------------------------------------------------------------- + + +class TestRunner: + """DiffusionModelRunner.execute_micro_step (PP=1).""" + + def test_completes_single_chunk_request(self, monkeypatch): + runner = _make_runner(pp_size=1, num_steps=1) + _patch_runtime(monkeypatch, runner) + req = _make_micro_request(num_inference_steps=1, num_chunks=1) + + # μ-step 0: admit chunk 0, denoise it, mark completed. + out0 = DiffusionModelRunner.execute_micro_step( + runner, + _make_micro_scheduler_output(req=req, step_id=0), + ) + assert out0.req_id == "req-1" + assert out0.chunk_idx == 0 + assert out0.chunk_completed is True + assert out0.finished is False + assert "req-1" in runner.state_cache + + # μ-step 1: runner decodes chunk 0 and returns finished. + out1 = DiffusionModelRunner.execute_micro_step( + runner, + _make_micro_scheduler_output(sched_req_id="req-1", step_id=1, assignment=[None], is_new=False), + ) + assert out1.finished is True + assert out1.result is not None + assert torch.equal(out1.result.output, torch.tensor([1.0])) + assert "req-1" not in runner.state_cache + + assert runner.pipeline.prepare_calls == 1 + assert runner.pipeline.set_buffer_calls == 1 + assert runner.pipeline.denoise_calls == 1 + assert runner.pipeline.scheduler_calls == 1 + assert runner.pipeline.decode_calls == 1 + + def test_completes_multi_chunk_request(self, monkeypatch): + runner = _make_runner(pp_size=1, num_steps=1) + _patch_runtime(monkeypatch, runner) + req = _make_micro_request(num_inference_steps=1, num_chunks=2) + + # μ-step 0: process chunk 0. + DiffusionModelRunner.execute_micro_step( + runner, + _make_micro_scheduler_output(req=req, step_id=0), + ) + # μ-step 1: decode chunk 0, process chunk 1. + out1 = DiffusionModelRunner.execute_micro_step( + runner, + _make_micro_scheduler_output( + sched_req_id="req-1", + step_id=1, + assignment=[RankTask(sched_req_id="req-1", chunk_idx=1)], + is_new=False, + ), + ) + assert out1.chunk_idx == 1 + assert out1.chunk_completed is True + assert out1.finished is False + + # μ-step 2: decode chunk 1; finished after both chunks decoded. + out2 = DiffusionModelRunner.execute_micro_step( + runner, + _make_micro_scheduler_output(sched_req_id="req-1", step_id=2, assignment=[None], is_new=False), + ) + assert out2.finished is True + assert out2.result is not None + assert "req-1" not in runner.state_cache + + assert runner.pipeline.prepare_calls == 1 + assert runner.pipeline.denoise_calls == 2 + assert runner.pipeline.decode_calls == 2 + + def test_idle_rank_returns_no_op(self, monkeypatch): + runner = _make_runner(pp_size=1, num_steps=2) + _patch_runtime(monkeypatch, runner) + req = _make_micro_request(num_inference_steps=2, num_chunks=1) + + # μ-step 0: process chunk 0 step 1/2 (not yet completed). + out0 = DiffusionModelRunner.execute_micro_step( + runner, + _make_micro_scheduler_output(req=req, step_id=0), + ) + assert out0.chunk_completed is False + + # μ-step 1 with assignment=[None] → no chunk processed, no decode. + out1 = DiffusionModelRunner.execute_micro_step( + runner, + _make_micro_scheduler_output(sched_req_id="req-1", step_id=1, assignment=[None], is_new=False), + ) + assert out1.req_id == "req-1" + assert out1.chunk_idx is None + assert out1.chunk_completed is False + assert out1.finished is False + assert runner.pipeline.denoise_calls == 1 + + def test_interrupt_marks_chunk_as_aborted(self, monkeypatch): + runner = _make_runner(pp_size=1, num_steps=1) + runner.pipeline = _InterruptingMicroStepPipeline(num_steps=1) + _patch_runtime(monkeypatch, runner) + req = _make_micro_request(num_inference_steps=1, num_chunks=1) + + out = DiffusionModelRunner.execute_micro_step( + runner, + _make_micro_scheduler_output(req=req, step_id=0), + ) + assert out.req_id == "req-1" + assert out.result is not None + assert out.result.error == "micro-step denoise interrupted" + assert runner.pipeline.denoise_calls == 1 + assert runner.pipeline.scheduler_calls == 0 + assert runner.pipeline.decode_calls == 0 + + def test_rejects_missing_per_rank_assignment(self): + runner = _make_runner(pp_size=1) + req = _make_micro_request() + sched_output = _make_micro_scheduler_output(req=req) + sched_output.per_rank_assignment = None + + with pytest.raises(ValueError, match="per_rank_assignment"): + DiffusionModelRunner.execute_micro_step(runner, sched_output) + + def test_rejects_cache_backend(self): + runner = _make_runner(pp_size=1) + runner.od_config = SimpleNamespace( + cache_backend="teacache", + parallel_config=SimpleNamespace(use_hsdp=False), + ) + req = _make_micro_request() + + with pytest.raises(ValueError, match="cache_backend"): + DiffusionModelRunner.execute_micro_step(runner, _make_micro_scheduler_output(req=req)) + + +# --------------------------------------------------------------------------- +# Worker +# --------------------------------------------------------------------------- + + +class TestWorker: + """DiffusionWorker.execute_micro_step""" + + def test_delegates_to_model_runner(self): + worker = object.__new__(DiffusionWorker) + expected = RunnerOutput(req_id="req-1", chunk_idx=0, chunk_completed=False) + scheduler_output = SimpleNamespace( + scheduled_new_reqs=[ + SimpleNamespace(req=SimpleNamespace(sampling_params=SimpleNamespace(lora_request=None))) + ] + ) + worker.lora_manager = None + worker.model_runner = SimpleNamespace( + execute_micro_step=lambda arg: expected if arg is scheduler_output else None + ) + worker._get_profiler = lambda: None + + output = DiffusionWorker.execute_micro_step(worker, scheduler_output) + assert output is expected + + def test_clears_active_lora(self): + worker = object.__new__(DiffusionWorker) + scheduler_output = SimpleNamespace( + scheduled_new_reqs=[ + SimpleNamespace(req=SimpleNamespace(sampling_params=SimpleNamespace(lora_request=None))) + ] + ) + calls: list = [] + + class _FakeLoRAManager: + def set_active_adapter(self, adapter): + calls.append(adapter) + + worker.lora_manager = _FakeLoRAManager() + worker.model_runner = SimpleNamespace(execute_micro_step=lambda _: RunnerOutput(req_id="req-1")) + worker._get_profiler = lambda: None + + DiffusionWorker.execute_micro_step(worker, scheduler_output) + assert calls == [None] + + def test_rejects_lora_requests(self): + worker = object.__new__(DiffusionWorker) + scheduler_output = SimpleNamespace( + scheduled_new_reqs=[ + SimpleNamespace(req=SimpleNamespace(sampling_params=SimpleNamespace(lora_request=object()))) + ] + ) + worker.lora_manager = None + worker.model_runner = SimpleNamespace(execute_micro_step=lambda _: RunnerOutput(req_id="req-1")) + worker._get_profiler = lambda: None + + with pytest.raises(ValueError, match="does not support LoRA"): + DiffusionWorker.execute_micro_step(worker, scheduler_output) + + +# --------------------------------------------------------------------------- +# Executor +# --------------------------------------------------------------------------- + + +class TestSupportedPipelines: + """Micro-step protocol membership checks.""" + + def test_stub_pipeline_satisfies_protocol(self): + from vllm_omni.diffusion.models.interface import ( + SupportsMicroStepExecution, + SupportsStepExecution, + supports_micro_step_execution, + supports_step_execution, + ) + + pipeline = _MicroStepPipeline() + assert isinstance(pipeline, SupportsMicroStepExecution) is True + assert supports_micro_step_execution(pipeline) is True + # Micro-step protocol extends step protocol. + assert isinstance(pipeline, SupportsStepExecution) is True + assert supports_step_execution(pipeline) is True + + def test_wan22_supports_micro_step_execution(self): + from vllm_omni.diffusion.models.interface import ( + SupportsMicroStepExecution, + supports_micro_step_execution, + ) + from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2 import Wan22Pipeline + + # Avoid loading weights; protocol membership is a class-contract check. + pipeline = object.__new__(Wan22Pipeline) + + assert pipeline.supports_step_execution is True + assert pipeline.supports_micro_step_execution is True + assert supports_micro_step_execution(pipeline) is True + assert isinstance(pipeline, SupportsMicroStepExecution) is True + + +class TestExecutor: + """MultiprocDiffusionExecutor.execute_micro_step collects rank-0's reply.""" + + def test_passes_through_runner_output(self, mocker: MockerFixture): + executor = object.__new__(MultiprocDiffusionExecutor) + executor._ensure_open = lambda: None + expected = RunnerOutput(req_id="req-1", chunk_idx=0, chunk_completed=True) + rpc = mocker.Mock(return_value=expected) + executor.collective_rpc = rpc + + sched_output = _make_micro_scheduler_output(req=_make_micro_request()) + output = MultiprocDiffusionExecutor.execute_micro_step(executor, sched_output) + + assert output is expected + _, kwargs = rpc.call_args + assert kwargs.get("unique_reply_rank") == 0 + assert kwargs.get("exec_all_ranks") is True + + def test_rejects_unexpected_reply_type(self, mocker: MockerFixture): + executor = object.__new__(MultiprocDiffusionExecutor) + executor._ensure_open = lambda: None + executor.collective_rpc = mocker.Mock(return_value="not a runner output") + + sched_output = _make_micro_scheduler_output(req=_make_micro_request()) + with pytest.raises(RuntimeError, match="Unexpected response type"): + MultiprocDiffusionExecutor.execute_micro_step(executor, sched_output) \ No newline at end of file From 8bc0ed593aa45532ee2f9a1a58acbe5321cdbf7a Mon Sep 17 00:00:00 2001 From: Mahdi Nasser <94046147+mnasser02@users.noreply.github.com> Date: Mon, 27 Apr 2026 13:40:34 +0200 Subject: [PATCH 33/53] edit Signed-off-by: Mahdi Nasser <94046147+mnasser02@users.noreply.github.com> --- vllm_omni/diffusion/models/wan2_2/__init__.py | 1 - .../worker/diffusion_model_runner.py | 27 ++++++++++--------- 2 files changed, 15 insertions(+), 13 deletions(-) diff --git a/vllm_omni/diffusion/models/wan2_2/__init__.py b/vllm_omni/diffusion/models/wan2_2/__init__.py index 4059c4fb568..bb2ec9b9d68 100644 --- a/vllm_omni/diffusion/models/wan2_2/__init__.py +++ b/vllm_omni/diffusion/models/wan2_2/__init__.py @@ -4,7 +4,6 @@ from .patch_diffusers import patch_wan_rms_norm from .pipeline_wan2_2 import ( Wan22Pipeline, - WanT2VDMD2Pipeline, create_transformer_from_config, get_wan22_post_process_func, get_wan22_pre_process_func, diff --git a/vllm_omni/diffusion/worker/diffusion_model_runner.py b/vllm_omni/diffusion/worker/diffusion_model_runner.py index 78a3ebde5bc..14db0f40f11 100644 --- a/vllm_omni/diffusion/worker/diffusion_model_runner.py +++ b/vllm_omni/diffusion/worker/diffusion_model_runner.py @@ -555,7 +555,7 @@ def execute_micro_step(self, scheduler_output: DiffusionSchedulerOutput) -> Runn self.pipeline.is_buffer_setup = False self.pipeline.prepare_encode(state) - if pp_group.is_first_rank: # TODO: race condition + if pp_group.is_first_rank: denoised_chunks = state.extra.get("denoised_chunks", []) decoded_chunks = state.extra.setdefault("decoded_chunks", []) new_denoised_chunks = [] @@ -576,7 +576,7 @@ def execute_micro_step(self, scheduler_output: DiffusionSchedulerOutput) -> Runn finished=True, result=self._merge_chunk_outputs(decoded_chunks), ) - self._update_states_after(state, finished=True) # TODO: call properly on all ranks + self._update_states_after(state, finished=True) return output @@ -600,14 +600,17 @@ def execute_micro_step(self, scheduler_output: DiffusionSchedulerOutput) -> Runn with state.use_chunk(chunk): if not self.pipeline.is_buffer_setup: self.pipeline.set_pp_recv_dict_buffers(state) + noise_pred = self.pipeline.denoise_step(state) if noise_pred is None and getattr(self.pipeline, "interrupt", False): + self._update_states_after(state, finished=True) return RunnerOutput( req_id=task.sched_req_id, result=DiffusionOutput(error="micro-step denoise interrupted"), ) + self.pipeline.step_scheduler(state, noise_pred) - chunk_done = state.denoise_completed + chunk_completed = state.denoise_completed # prefetch the chunk of the next micro-step @@ -617,17 +620,17 @@ def execute_micro_step(self, scheduler_output: DiffusionSchedulerOutput) -> Runn self.pipeline.prefetch_its(state) - output = RunnerOutput( + if chunk_completed: + state.extra["chunks"].pop(task.chunk_idx, None) + + if pp_group.is_first_rank: + steps_left = pp_group.world_size + state.extra.setdefault("denoised_chunks", []).append((chunk, steps_left)) + + return RunnerOutput( req_id=task.sched_req_id, step_index=chunk.step_index, chunk_idx=task.chunk_idx, + chunk_completed=chunk_completed, ) - if chunk_done: - output.chunk_completed = True - state.extra["chunks"].pop(task.chunk_idx, None) - if pp_group.is_first_rank: - steps_left = pp_group.world_size - state.extra.setdefault("denoised_chunks", []).append((chunk, steps_left)) - - return output From 5aee2a2bd6f6319b82fe41eea8baead4f3565911 Mon Sep 17 00:00:00 2001 From: Mahdi Nasser <94046147+mnasser02@users.noreply.github.com> Date: Wed, 6 May 2026 10:45:54 +0200 Subject: [PATCH 34/53] add slo-adaptive scheduling Signed-off-by: Mahdi Nasser <94046147+mnasser02@users.noreply.github.com> --- vllm_omni/diffusion/sched/interface.py | 7 +- .../diffusion/sched/stream_batch_scheduler.py | 274 ++++++++++++------ .../worker/diffusion_model_runner.py | 152 +++++----- vllm_omni/diffusion/worker/utils.py | 2 + vllm_omni/inputs/data.py | 5 + 5 files changed, 271 insertions(+), 169 deletions(-) diff --git a/vllm_omni/diffusion/sched/interface.py b/vllm_omni/diffusion/sched/interface.py index f8c81f77640..9efa419c40b 100644 --- a/vllm_omni/diffusion/sched/interface.py +++ b/vllm_omni/diffusion/sched/interface.py @@ -108,7 +108,7 @@ def make_empty(cls) -> CachedRequestData: @dataclass class RankTask: - """Used by ``StreamBatchScheduler`` to tell each rank which work to perform in the current micro-step.""" + """One unit of work for a rank in a stream-batch micro-step.""" sched_req_id: str chunk_idx: int @@ -124,9 +124,8 @@ class DiffusionSchedulerOutput: finished_req_ids: set[str] num_running_reqs: int num_waiting_reqs: int - # Temporal-PP per-rank assignment table. Index = PP rank id. ``None`` entries - # mark idle ranks (warmup / cooldown). - per_rank_assignment: list[RankTask | None] | None = None + # Per-rank task list. Index = PP rank id. + per_rank_assignment: list[list[RankTask]] | None = None @cached_property def scheduled_req_ids(self) -> list[str]: diff --git a/vllm_omni/diffusion/sched/stream_batch_scheduler.py b/vllm_omni/diffusion/sched/stream_batch_scheduler.py index e9fbc11ac22..cd2e5778dc4 100644 --- a/vllm_omni/diffusion/sched/stream_batch_scheduler.py +++ b/vllm_omni/diffusion/sched/stream_batch_scheduler.py @@ -1,10 +1,10 @@ """Temporal-pipeline-parallel scheduler for streaming chunked diffusion. -Each ``schedule()`` call corresponds to -one micro-step. At any micro-step, each PP rank processes a different -``(chunk, step_index)`` pair drawn from the active requests' in-flight -chunks. Chunks are admitted to rank 0 in order, propagate through ranks under -NCCL FIFO ordering, and exit at rank N-1 in the same order. +Each ``schedule()`` call corresponds to one micro-step. At any micro-step, +each PP rank processes the chunks at the denoising step ``r = current - +entered_rank0_at`` from the active requests' in-flight chunks. Chunks are +admitted to rank 0 in order, propagate through ranks under NCCL FIFO +ordering, and exit at rank N-1 in the same order. """ from __future__ import annotations @@ -33,45 +33,131 @@ class _InFlightChunk: """One chunk of an active request, tracked through the temporal pipeline.""" - chunk_idx: int - is_active: bool = True - is_completed: bool = False - entered_rank0_at: int = -1 + chunk_idx: int + is_active: bool = True + is_completed: bool = False + entered_rank0_at: int = -1 @dataclass class _ChunkProgress: """Per-request chunk-level scheduling state.""" sched_req_id: str - num_chunks: int # total chunks to produce for this request - num_steps: int # denoising steps per chunk + num_chunks: int + num_steps: int chunks_admitted: int = 0 in_flight: list[_InFlightChunk] = field(default_factory=list) +@dataclass +class _SLOReqState: + """Per-request SLO state, owned by ``_SLOController``.""" + + slo_fps: float + max_batch: int + ema_alpha: float + chunk_frames: int + batch_size: int = 1 + latency_ema_ns: float | None = None + slack_streak: int = 0 + violation_streak: int = 0 + + +class _SLOController: + """AIMD controller for the stream batch size, tracked per request. + + Driven by per-micro-step wall-clock latency observations on rank 0. + + Maintains an EMA of micro-step latency; halves B on sustained budget violations and + increments B by 1 on sustained slack. + """ + + SLACK_THRESHOLD_RATIO = 0.25 + SLACK_STREAK_TARGET = 4 + VIOLATION_STREAK_TARGET = 2 + + def __init__(self) -> None: + self._reqs: dict[str, _SLOReqState] = {} + + def register( + self, + sched_req_id: str, + slo_fps: float | None, + max_batch: int, + ema_alpha: float, + chunk_frames: int, + ) -> None: + if slo_fps is None or slo_fps <= 0: + return + self._reqs[sched_req_id] = _SLOReqState( + slo_fps=float(slo_fps), + max_batch=max(1, max_batch), + ema_alpha=ema_alpha, + chunk_frames=max(1, chunk_frames), + ) + + def unregister(self, sched_req_id: str) -> None: + self._reqs.pop(sched_req_id, None) + + def batch_size(self, sched_req_id: str) -> int: + st = self._reqs.get(sched_req_id) + return st.batch_size if st is not None else 1 + + def observe(self, sched_req_id: str, latency_ns: int | None) -> None: + st = self._reqs.get(sched_req_id) + if st is None or latency_ns is None or latency_ns <= 0: + return + + if st.latency_ema_ns is None: + st.latency_ema_ns = float(latency_ns) + else: + a = st.ema_alpha + st.latency_ema_ns = a * float(latency_ns) + (1.0 - a) * st.latency_ema_ns + + budget = (st.batch_size * st.chunk_frames / st.slo_fps) * 1e9 + ema = st.latency_ema_ns + + if ema > budget: + st.violation_streak += 1 + st.slack_streak = 0 + if st.violation_streak >= self.VIOLATION_STREAK_TARGET: + new_b = max(1, st.batch_size // 2) + if new_b != st.batch_size: + logger.info(f"SLO[{sched_req_id}]: halving batch_size {st.batch_size} -> {new_b} (ema={ema/1e6:.2}ms budget={budget/1e6:.2}ms)") + st.batch_size = new_b + st.violation_streak = 0 + return + + st.violation_streak = 0 + headroom_ratio = (budget - ema) / budget + if headroom_ratio >= self.SLACK_THRESHOLD_RATIO: + st.slack_streak += 1 + if st.slack_streak >= self.SLACK_STREAK_TARGET and st.batch_size < st.max_batch: + st.batch_size += 1 + logger.info(f"SLO[{sched_req_id}]: increasing batch_size -> {st.batch_size} (ema={ema/1e6:.2}ms budget={budget/1e6:.2}ms)") + st.slack_streak = 0 + else: + st.slack_streak = 0 + + class StreamBatchScheduler(_BaseScheduler): """Temporal-PP scheduler driving chunked-streaming diffusion requests. Per micro-step: - 1. Promote waiting requests up to ``max_num_running_reqs`` (handled by the base class). - 2. Re-admit at most one returning chunk to rank 0 (FIFO across all active requests). - 3. If rank 0 is still free and admission budget remains, admit a new chunk. - 4. Build the per-rank assignment table from in-pipeline chunks' positions. - - A chunk that entered rank 0 at micro-step ``m₀`` is at rank - ``r = current_micro_step - m₀`` while ``0 ≤ r < pp_size``. After ``r == - pp_size - 1``, rank N-1's ``step_scheduler`` runs the ODE; the chunk's - latents are sent back to rank 0; the chunk leaves the pipeline and may be - re-admitted on the next micro-step (until it has run all - ``num_steps`` denoising steps). + 1. Promote waiting requests (handled by the base class). + 2. Admit returning chunks to rank 0 first (FIFO across active requests), + then admit fresh chunks, until ``batch_size`` chunks have entered + rank 0 this micro-step or no admittable chunks remain. + 3. Build the per-rank assignment table from in-pipeline chunks' + positions ``r = current_micro_step - entered_rank0_at``. """ def __init__(self) -> None: super().__init__() - self.pp_size: int = 1 # set in initialize() - self.B: int = 1 # intra-rank batch + self.pp_size: int = 1 self._global_micro_step: int = 0 self._chunk_progress: dict[str, _ChunkProgress] = {} + self._slo: _SLOController = _SLOController() # ── Lifecycle ────────────────────────────────────────────────────────── @@ -82,9 +168,11 @@ def initialize(self, od_config: OmniDiffusionConfig) -> None: def _reset_scheduler_state(self) -> None: self._global_micro_step = 0 self._chunk_progress.clear() + self._slo = _SLOController() def _pop_extra_request_state(self, sched_req_id: str) -> None: self._chunk_progress.pop(sched_req_id, None) + self._slo.unregister(sched_req_id) # ── Request admission ────────────────────────────────────────────────── @@ -102,88 +190,83 @@ def add_request(self, request: OmniDiffusionRequest) -> str: # ── Scheduling ───────────────────────────────────────────────────────── def schedule(self) -> DiffusionSchedulerOutput: - # Base class promotes waiting → running and fills scheduled_new_reqs / step_id. base_output = super().schedule() - # Initialize chunk-progress state for any newly promoted requests. for new_req in base_output.scheduled_new_reqs: self._init_chunk_progress(new_req.sched_req_id, new_req.req) - # Re-admit a returning chunk; otherwise admit a new chunk if rank 0 is free. self._advance_chunk_pipeline() - # Build the per-rank assignment from current in-pipeline chunks. if self._chunk_progress: base_output.per_rank_assignment = self._build_assignment() - # else: no active request → executor sees per_rank_assignment=None and idles. self._global_micro_step += 1 return base_output def _init_chunk_progress(self, sched_req_id: str, req: OmniDiffusionRequest) -> None: - num_chunks = req.sampling_params.num_chunks - num_steps = req.sampling_params.num_inference_steps - assert num_chunks is not None and num_steps is not None # validated in add_request() + sampling = req.sampling_params + num_chunks = sampling.num_chunks + num_steps = sampling.num_inference_steps + assert num_chunks is not None and num_steps is not None self._chunk_progress[sched_req_id] = _ChunkProgress( sched_req_id=sched_req_id, num_chunks=num_chunks, num_steps=num_steps, ) - logger.debug( - "StreamBatchScheduler initialized chunk progress for %s " - "(num_chunks=%d, num_steps=%d, pp_size=%d)", - sched_req_id, num_chunks, num_steps, self.pp_size, + + chunk_frames = max(1, sampling.num_frames) + self._slo.register( + sched_req_id=sched_req_id, + slo_fps=sampling.slo_fps, + max_batch=sampling.slo_max_batch, + ema_alpha=sampling.slo_ema_alpha, + chunk_frames=chunk_frames, ) + + logger.debug(f"""StreamBatchScheduler initialized chunk progress for {sched_req_id} + (num_chunks={num_chunks}, num_steps={num_steps}, chunk_frames={chunk_frames}, slo_fps={sampling.slo_fps}, pp_size={self.pp_size})""") def _advance_chunk_pipeline(self) -> None: - """Admit at most one chunk to rank 0 this micro-step. - - Re-admission of a returning chunk takes priority over admitting a new - chunk so that FIFO order is preserved (an admitted chunk's latents - always re-enter rank 0 before any later-admitted chunk's first entry). - Admission order across requests follows ``_chunk_progress`` insertion - order, which matches the order the base scheduler promoted them. - """ + """Admit returning + new chunks to rank 0 this micro-step.""" if not self._chunk_progress: return - # 1. Try to re-admit a returning chunk (FIFO oldest-first across requests). + m = self._global_micro_step + + # 1. Re-admit every returning chunk. for progress in self._chunk_progress.values(): for chunk in progress.in_flight: if not chunk.is_active: chunk.is_active = True - chunk.entered_rank0_at = self._global_micro_step - return # rank 0 is now taken + chunk.entered_rank0_at = m - # 2. Otherwise admit a new chunk from the first request with budget. + # 2. Admit up to ``B_req`` fresh chunks per request. for progress in self._chunk_progress.values(): - if progress.chunks_admitted < progress.num_chunks: - new_chunk = _InFlightChunk( + budget = self._slo.batch_size(progress.sched_req_id) + admitted = 0 + while ( + progress.chunks_admitted < progress.num_chunks + and admitted < budget + ): + progress.in_flight.append(_InFlightChunk( chunk_idx=progress.chunks_admitted, - is_active=True, - entered_rank0_at=self._global_micro_step, - ) - progress.in_flight.append(new_chunk) + entered_rank0_at=m, + )) progress.chunks_admitted += 1 - return + admitted += 1 - def _build_assignment(self) -> list[RankTask | None]: - assignment: list[RankTask | None] = [None] * self.pp_size + def _build_assignment(self) -> list[list[RankTask]]: + assignment: list[list[RankTask]] = [[] for _ in range(self.pp_size)] for progress in self._chunk_progress.values(): for chunk in progress.in_flight: if not chunk.is_active: continue r = self._global_micro_step - chunk.entered_rank0_at if 0 <= r < self.pp_size: - assert assignment[r] is None, ( - f"two chunks would be assigned to rank {r} at micro-step " - f"{self._global_micro_step}: existing={assignment[r]}, " - f"new req={progress.sched_req_id} chunk_idx={chunk.chunk_idx}" - ) - assignment[r] = RankTask( + assignment[r].append(RankTask( sched_req_id=progress.sched_req_id, chunk_idx=chunk.chunk_idx, - ) + )) return assignment # ── Output processing ────────────────────────────────────────────────── @@ -192,40 +275,43 @@ def update_from_output(self, sched_output: DiffusionSchedulerOutput, output: Run if not self._chunk_progress or sched_output.per_rank_assignment is None: return set() - terminal: dict[str, DiffusionRequestStatus] = {} - terminal_errors: dict[str, str | None] = {} - - progress = self._chunk_progress.get(output.req_id) - if progress is None: - return set() + per_task = [output] + list(output.extra_task_outputs or []) - output_error = output.result.error if output.result is not None else None - if output_error is not None: - terminal[output.req_id] = DiffusionRequestStatus.FINISHED_ERROR - terminal_errors[output.req_id] = output_error - return self._finalize_update_from_output(sched_output, terminal, terminal_errors) + if output.micro_step_wall_ns is not None: + self._slo.observe(output.req_id, output.micro_step_wall_ns) - chunk = self._find_chunk(progress, output.chunk_idx) if output.chunk_idx is not None else None - if chunk is not None: - chunk.is_completed = output.chunk_completed + terminal: dict[str, DiffusionRequestStatus] = {} + terminal_errors: dict[str, str | None] = {} - last_task = sched_output.per_rank_assignment[-1] - logger.debug( - "update_from_output: Processing output for micro-step %d: chunk=%s, last_chunk=%s, finished=%s", - self._global_micro_step, chunk, last_task, output.finished, - ) - if last_task is not None and last_task.chunk_idx is not None: + for task_out in per_task: + progress = self._chunk_progress.get(task_out.req_id) + if progress is None: + continue + err = task_out.result.error if task_out.result is not None else None + if err is not None: + terminal[task_out.req_id] = DiffusionRequestStatus.FINISHED_ERROR + terminal_errors[task_out.req_id] = err + continue + chunk = self._find_chunk(progress, task_out.chunk_idx) if task_out.chunk_idx is not None else None + if chunk is not None: + chunk.is_completed = task_out.chunk_completed + if task_out.finished: + terminal[task_out.req_id] = DiffusionRequestStatus.FINISHED_COMPLETED + + # Roll last-rank chunks off the pipeline / mark them inactive. + for last_task in (sched_output.per_rank_assignment[-1] if sched_output.per_rank_assignment else []): + progress = self._chunk_progress.get(last_task.sched_req_id) + if progress is None: + continue last_chunk = self._find_chunk(progress, last_task.chunk_idx) - if last_chunk is not None: - if last_chunk.is_completed: - progress.in_flight = [ - c for c in progress.in_flight if c.chunk_idx != last_chunk.chunk_idx - ] - else: - last_chunk.is_active = False - - if output.finished: - terminal[output.req_id] = DiffusionRequestStatus.FINISHED_COMPLETED + if last_chunk is None: + continue + if last_chunk.is_completed: + progress.in_flight = [ + c for c in progress.in_flight if c.chunk_idx != last_chunk.chunk_idx + ] + else: + last_chunk.is_active = False return self._finalize_update_from_output(sched_output, terminal, terminal_errors) diff --git a/vllm_omni/diffusion/worker/diffusion_model_runner.py b/vllm_omni/diffusion/worker/diffusion_model_runner.py index 14db0f40f11..6a3386e9689 100644 --- a/vllm_omni/diffusion/worker/diffusion_model_runner.py +++ b/vllm_omni/diffusion/worker/diffusion_model_runner.py @@ -543,94 +543,104 @@ def execute_micro_step(self, scheduler_output: DiffusionSchedulerOutput) -> Runn ) state.sampling.generator = torch.Generator(device=gen_device).manual_seed(state.sampling.seed) - with set_forward_context(vllm_config=self.vllm_config, omni_diffusion_config=self.od_config): pp_group = get_pp_group() pp_rank = pp_group.rank_in_group - task = assignment[pp_rank] - prev_task = assignment[pp_group.prev_rank] + tasks = assignment[pp_rank] + prev_tasks = assignment[pp_group.prev_rank] if is_new_request: pp_group.reset_buffer() - self.pipeline.is_buffer_setup = False + self.pipeline.is_buffer_setup = False self.pipeline.prepare_encode(state) - if pp_group.is_first_rank: - denoised_chunks = state.extra.get("denoised_chunks", []) - decoded_chunks = state.extra.setdefault("decoded_chunks", []) - new_denoised_chunks = [] - - for chunk, steps_left in denoised_chunks: - steps_left -= 1 - if steps_left == 0: - with state.use_chunk(chunk): - decoded_chunks.append(self.pipeline.post_decode(state)) - else: - new_denoised_chunks.append((chunk, steps_left)) - - state.extra["denoised_chunks"] = new_denoised_chunks - if len(decoded_chunks) == state.sampling.num_chunks: - output = RunnerOutput( - req_id=state.req_id, - step_index=state.step_index, - finished=True, - result=self._merge_chunk_outputs(decoded_chunks), - ) - self._update_states_after(state, finished=True) - return output + t_start_ns = time.perf_counter_ns() if pp_group.is_first_rank else None + + if pp_group.is_first_rank: + finished_output = self._rank0_decode_due_chunks(state) + if finished_output is not None: + return finished_output - - if task is None: + if not tasks: return RunnerOutput(req_id=state.req_id) - - - chunk, is_new_chunk = self._get_or_create_chunk(state, task.chunk_idx) - if is_new_chunk: - # First chunk reuses the noise sampled by prepare_encode; - # subsequent chunks draw fresh noise. - # Each chunk gets its own scheduler deepcopy so multi-step - # ODE solver state doesn't leak between chunks. - chunk.latents = ( - state.latents - if task.chunk_idx == 0 - else torch.randn_like(state.latents, generator=state.sampling.generator) - ) - chunk.scheduler = copy.deepcopy(state.scheduler) - with state.use_chunk(chunk): - if not self.pipeline.is_buffer_setup: - self.pipeline.set_pp_recv_dict_buffers(state) - - noise_pred = self.pipeline.denoise_step(state) - if noise_pred is None and getattr(self.pipeline, "interrupt", False): - self._update_states_after(state, finished=True) - return RunnerOutput( - req_id=task.sched_req_id, - result=DiffusionOutput(error="micro-step denoise interrupted"), + task_outputs: list[RunnerOutput] = [] + for task in tasks: + chunk, is_new_chunk = self._get_or_create_chunk(state, task.chunk_idx) + if is_new_chunk: + chunk.latents = ( + state.latents + if task.chunk_idx == 0 + else torch.randn_like(state.latents, generator=state.sampling.generator) ) - - self.pipeline.step_scheduler(state, noise_pred) - chunk_completed = state.denoise_completed + chunk.scheduler = copy.deepcopy(state.scheduler) + + with state.use_chunk(chunk): + if not self.pipeline.is_buffer_setup: + self.pipeline.set_pp_recv_dict_buffers(state) + + noise_pred = self.pipeline.denoise_step(state) + if noise_pred is None and getattr(self.pipeline, "interrupt", False): + self._update_states_after(state, finished=True) + return RunnerOutput( + req_id=task.sched_req_id, + result=DiffusionOutput(error="micro-step denoise interrupted"), + ) + + self.pipeline.step_scheduler(state, noise_pred) + chunk_completed = state.denoise_completed + if chunk_completed: + state.extra["chunks"].pop(task.chunk_idx, None) + if pp_group.is_first_rank: + state.extra.setdefault("denoised_chunks", []).append((chunk, pp_group.world_size)) - # prefetch the chunk of the next micro-step - prev_chunk, _ = self._get_or_create_chunk(state, prev_task.chunk_idx) if prev_task is not None else (None, None) - if prev_chunk is not None: + task_outputs.append(RunnerOutput( + req_id=task.sched_req_id, + step_index=chunk.step_index, + chunk_idx=task.chunk_idx, + chunk_completed=chunk_completed, + )) + + for prev_task in prev_tasks: + prev_chunk, _ = self._get_or_create_chunk(state, prev_task.chunk_idx) with state.use_chunk(prev_chunk): self.pipeline.prefetch_its(state) + primary = task_outputs[0] + if len(task_outputs) > 1: + primary.extra_task_outputs = task_outputs[1:] + if t_start_ns is not None: + primary.micro_step_wall_ns = time.perf_counter_ns() - t_start_ns + + return primary - if chunk_completed: - state.extra["chunks"].pop(task.chunk_idx, None) - - if pp_group.is_first_rank: - steps_left = pp_group.world_size - state.extra.setdefault("denoised_chunks", []).append((chunk, steps_left)) + def _rank0_decode_due_chunks(self, state: DiffusionRequestState) -> RunnerOutput | None: + """Decode any chunks whose pipeline-drain delay has elapsed. - return RunnerOutput( - req_id=task.sched_req_id, - step_index=chunk.step_index, - chunk_idx=task.chunk_idx, - chunk_completed=chunk_completed, - ) + Returns a finished RunnerOutput when all of the request's chunks have + been decoded, otherwise ``None``. + """ + denoised_chunks = state.extra.get("denoised_chunks", []) + decoded_chunks = state.extra.setdefault("decoded_chunks", []) + remaining: list[tuple[ChunkState, int]] = [] + for chunk, steps_left in denoised_chunks: + steps_left -= 1 + if steps_left == 0: + with state.use_chunk(chunk): + decoded_chunks.append(self.pipeline.post_decode(state)) + else: + remaining.append((chunk, steps_left)) + state.extra["denoised_chunks"] = remaining + + if len(decoded_chunks) == state.sampling.num_chunks: + output = RunnerOutput( + req_id=state.req_id, + step_index=state.step_index, + finished=True, + result=self._merge_chunk_outputs(decoded_chunks), + ) + self._update_states_after(state, finished=True) + return output + return None diff --git a/vllm_omni/diffusion/worker/utils.py b/vllm_omni/diffusion/worker/utils.py index 7353dab84c1..0c7c3d1cdc7 100644 --- a/vllm_omni/diffusion/worker/utils.py +++ b/vllm_omni/diffusion/worker/utils.py @@ -187,6 +187,8 @@ class RunnerOutput: # ── Temporal-PP micro-step fields ── chunk_idx: int | None = None chunk_completed: bool = False + micro_step_wall_ns: int | None = None + extra_task_outputs: list["RunnerOutput"] | None = None # for B>1 def get_req_output(self, sched_req_id: str) -> RunnerOutput | None: return self if self.req_id == sched_req_id else None diff --git a/vllm_omni/inputs/data.py b/vllm_omni/inputs/data.py index 6cdf08fa3c4..96c65bb2e86 100644 --- a/vllm_omni/inputs/data.py +++ b/vllm_omni/inputs/data.py @@ -224,6 +224,11 @@ class OmniDiffusionSamplingParams: # (temporal PP) to know how many chunks to admit through the pipeline. num_chunks: int = 1 + # SLO-adaptive stream batching. ``slo_fps=None`` keeps B fixed at 1. + slo_fps: float | None = None + slo_max_batch: int = 8 + slo_ema_alpha: float = 0.3 + # Original dimensions (before VAE scaling) height: int | None = None width: int | None = None From 26962f71536d23cff062b1bc14f269b98effefc9 Mon Sep 17 00:00:00 2001 From: Mahdi Nasser <94046147+mnasser02@users.noreply.github.com> Date: Wed, 6 May 2026 11:43:17 +0200 Subject: [PATCH 35/53] bugfix Signed-off-by: Mahdi Nasser <94046147+mnasser02@users.noreply.github.com> --- vllm_omni/diffusion/worker/diffusion_model_runner.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm_omni/diffusion/worker/diffusion_model_runner.py b/vllm_omni/diffusion/worker/diffusion_model_runner.py index 6a3386e9689..527af929f05 100644 --- a/vllm_omni/diffusion/worker/diffusion_model_runner.py +++ b/vllm_omni/diffusion/worker/diffusion_model_runner.py @@ -547,7 +547,7 @@ def execute_micro_step(self, scheduler_output: DiffusionSchedulerOutput) -> Runn pp_group = get_pp_group() pp_rank = pp_group.rank_in_group tasks = assignment[pp_rank] - prev_tasks = assignment[pp_group.prev_rank] + prev_task = assignment[pp_group.prev_rank] if pp_group.world_size > 1 else None if is_new_request: pp_group.reset_buffer() @@ -612,7 +612,7 @@ def execute_micro_step(self, scheduler_output: DiffusionSchedulerOutput) -> Runn primary.extra_task_outputs = task_outputs[1:] if t_start_ns is not None: primary.micro_step_wall_ns = time.perf_counter_ns() - t_start_ns - + return primary def _rank0_decode_due_chunks(self, state: DiffusionRequestState) -> RunnerOutput | None: From 824975ee3bbe97e68287fea98fca340b5a30c156 Mon Sep 17 00:00:00 2001 From: Mahdi Nasser <94046147+mnasser02@users.noreply.github.com> Date: Wed, 6 May 2026 11:48:54 +0200 Subject: [PATCH 36/53] bugfix Signed-off-by: Mahdi Nasser <94046147+mnasser02@users.noreply.github.com> --- vllm_omni/diffusion/worker/diffusion_model_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm_omni/diffusion/worker/diffusion_model_runner.py b/vllm_omni/diffusion/worker/diffusion_model_runner.py index 527af929f05..a741d6199b9 100644 --- a/vllm_omni/diffusion/worker/diffusion_model_runner.py +++ b/vllm_omni/diffusion/worker/diffusion_model_runner.py @@ -547,7 +547,7 @@ def execute_micro_step(self, scheduler_output: DiffusionSchedulerOutput) -> Runn pp_group = get_pp_group() pp_rank = pp_group.rank_in_group tasks = assignment[pp_rank] - prev_task = assignment[pp_group.prev_rank] if pp_group.world_size > 1 else None + prev_tasks = assignment[pp_group.prev_rank] if pp_group.world_size > 1 else None if is_new_request: pp_group.reset_buffer() From 31ed1d592b3c5034594b13cff6d109ee90200f67 Mon Sep 17 00:00:00 2001 From: Mahdi Nasser <94046147+mnasser02@users.noreply.github.com> Date: Mon, 11 May 2026 18:49:40 +0200 Subject: [PATCH 37/53] fix and improve stream batching Signed-off-by: Mahdi Nasser <94046147+mnasser02@users.noreply.github.com> --- .../test_diffusion_micro_step_pipeline.py | 88 +++++-- .../distributed/group_coordinator.py | 18 +- .../distributed/pipeline_parallel.py | 104 +++++--- .../diffusion/executor/multiproc_executor.py | 4 - .../models/wan2_2/pipeline_wan2_2.py | 81 +++++-- vllm_omni/diffusion/sched/interface.py | 24 +- .../diffusion/sched/stream_batch_scheduler.py | 210 +++++++++-------- .../worker/diffusion_model_runner.py | 223 +++++++++++------- vllm_omni/diffusion/worker/utils.py | 9 +- 9 files changed, 474 insertions(+), 287 deletions(-) diff --git a/tests/diffusion/test_diffusion_micro_step_pipeline.py b/tests/diffusion/test_diffusion_micro_step_pipeline.py index bc9521fd4e6..77fdf778b87 100644 --- a/tests/diffusion/test_diffusion_micro_step_pipeline.py +++ b/tests/diffusion/test_diffusion_micro_step_pipeline.py @@ -64,7 +64,9 @@ def prepare_encode(self, state, **kwargs): self.prepare_calls += 1 state.timesteps = [torch.tensor(float(i)) for i in range(self.num_steps)] state.latents = torch.zeros((1,)) - state.scheduler = None + # Provide a scheduler stub so the runner can build the per-chunk + # batched_ts from ``chunk.scheduler.timesteps[chunk.step_index]``. + state.scheduler = SimpleNamespace(timesteps=list(state.timesteps)) return state def set_pp_recv_dict_buffers(self, state, **kwargs): @@ -124,6 +126,10 @@ def _make_micro_request( generator_device=None, num_inference_steps=num_inference_steps, num_chunks=num_chunks, + num_frames=1, + slo_fps=None, + slo_max_batch=8, + slo_ema_alpha=0.3, lora_request=None, ), ) @@ -156,7 +162,7 @@ def _make_micro_scheduler_output( finished_req_ids=None, ): if assignment is None: - assignment = [RankTask(sched_req_id=sched_req_id, chunk_idx=0)] + assignment = [RankTask(sched_req_id=sched_req_id, chunk_indices=[0])] if is_new and req is not None: new_reqs = [NewRequestData(sched_req_id=sched_req_id, req=req)] cached_reqs = CachedRequestData.make_empty() @@ -192,18 +198,15 @@ def test_completes_single_chunk_request(self, monkeypatch): _patch_runtime(monkeypatch, runner) req = _make_micro_request(num_inference_steps=1, num_chunks=1) - # μ-step 0: admit chunk 0, denoise it, mark completed. out0 = DiffusionModelRunner.execute_micro_step( runner, _make_micro_scheduler_output(req=req, step_id=0), ) assert out0.req_id == "req-1" - assert out0.chunk_idx == 0 - assert out0.chunk_completed is True + assert out0.chunk_completion_map == {0: True} assert out0.finished is False assert "req-1" in runner.state_cache - # μ-step 1: runner decodes chunk 0 and returns finished. out1 = DiffusionModelRunner.execute_micro_step( runner, _make_micro_scheduler_output(sched_req_id="req-1", step_id=1, assignment=[None], is_new=False), @@ -224,26 +227,22 @@ def test_completes_multi_chunk_request(self, monkeypatch): _patch_runtime(monkeypatch, runner) req = _make_micro_request(num_inference_steps=1, num_chunks=2) - # μ-step 0: process chunk 0. DiffusionModelRunner.execute_micro_step( runner, _make_micro_scheduler_output(req=req, step_id=0), ) - # μ-step 1: decode chunk 0, process chunk 1. out1 = DiffusionModelRunner.execute_micro_step( runner, _make_micro_scheduler_output( sched_req_id="req-1", step_id=1, - assignment=[RankTask(sched_req_id="req-1", chunk_idx=1)], + assignment=[RankTask(sched_req_id="req-1", chunk_indices=[1])], is_new=False, ), ) - assert out1.chunk_idx == 1 - assert out1.chunk_completed is True + assert out1.chunk_completion_map == {1: True} assert out1.finished is False - # μ-step 2: decode chunk 1; finished after both chunks decoded. out2 = DiffusionModelRunner.execute_micro_step( runner, _make_micro_scheduler_output(sched_req_id="req-1", step_id=2, assignment=[None], is_new=False), @@ -261,21 +260,18 @@ def test_idle_rank_returns_no_op(self, monkeypatch): _patch_runtime(monkeypatch, runner) req = _make_micro_request(num_inference_steps=2, num_chunks=1) - # μ-step 0: process chunk 0 step 1/2 (not yet completed). out0 = DiffusionModelRunner.execute_micro_step( runner, _make_micro_scheduler_output(req=req, step_id=0), ) - assert out0.chunk_completed is False + assert out0.chunk_completion_map == {0: False} - # μ-step 1 with assignment=[None] → no chunk processed, no decode. out1 = DiffusionModelRunner.execute_micro_step( runner, _make_micro_scheduler_output(sched_req_id="req-1", step_id=1, assignment=[None], is_new=False), ) assert out1.req_id == "req-1" - assert out1.chunk_idx is None - assert out1.chunk_completed is False + assert out1.chunk_completion_map is None assert out1.finished is False assert runner.pipeline.denoise_calls == 1 @@ -316,6 +312,60 @@ def test_rejects_cache_backend(self): with pytest.raises(ValueError, match="cache_backend"): DiffusionModelRunner.execute_micro_step(runner, _make_micro_scheduler_output(req=req)) + def test_stamps_micro_step_wall_ns_on_rank0(self, monkeypatch): + runner = _make_runner(pp_size=1, num_steps=1) + _patch_runtime(monkeypatch, runner) + req = _make_micro_request(num_inference_steps=1, num_chunks=1) + + out = DiffusionModelRunner.execute_micro_step( + runner, + _make_micro_scheduler_output(req=req, step_id=0), + ) + assert out.micro_step_wall_ns is not None + assert out.micro_step_wall_ns >= 0 + + def test_batch_two_runs_two_tasks_in_one_micro_step(self, monkeypatch): + runner = _make_runner(pp_size=1, num_steps=1) + _patch_runtime(monkeypatch, runner) + req = _make_micro_request(num_inference_steps=1, num_chunks=2) + + # μ-step 0: rank 0 admits chunks 0 and 1 in a single chunk-batch task (B=2). + assignment = [RankTask(sched_req_id="req-1", chunk_indices=[0, 1])] + out = DiffusionModelRunner.execute_micro_step( + runner, + _make_micro_scheduler_output(req=req, step_id=0, assignment=assignment), + ) + + # Phase 4 fuses the two chunks into ONE batched denoise_step call; + # Phase 5 runs scheduler.step per chunk (multistep history is per-chunk). + assert runner.pipeline.denoise_calls == 1 + assert runner.pipeline.scheduler_calls == 2 + assert out.req_id == "req-1" + assert out.chunk_completion_map == {0: True, 1: True} + assert out.micro_step_wall_ns is not None + + def test_batch_two_decodes_both_chunks_when_drain_completes(self, monkeypatch): + runner = _make_runner(pp_size=1, num_steps=1) + _patch_runtime(monkeypatch, runner) + req = _make_micro_request(num_inference_steps=1, num_chunks=2) + + DiffusionModelRunner.execute_micro_step( + runner, + _make_micro_scheduler_output( + req=req, + step_id=0, + assignment=[RankTask(sched_req_id="req-1", chunk_indices=[0, 1])], + ), + ) + out = DiffusionModelRunner.execute_micro_step( + runner, + _make_micro_scheduler_output(sched_req_id="req-1", step_id=1, assignment=[None], is_new=False), + ) + + assert out.finished is True + assert runner.pipeline.decode_calls == 2 + assert "req-1" not in runner.state_cache + # --------------------------------------------------------------------------- # Worker @@ -327,7 +377,7 @@ class TestWorker: def test_delegates_to_model_runner(self): worker = object.__new__(DiffusionWorker) - expected = RunnerOutput(req_id="req-1", chunk_idx=0, chunk_completed=False) + expected = RunnerOutput(req_id="req-1", chunk_completion_map={0: False}) scheduler_output = SimpleNamespace( scheduled_new_reqs=[ SimpleNamespace(req=SimpleNamespace(sampling_params=SimpleNamespace(lora_request=None))) @@ -422,7 +472,7 @@ class TestExecutor: def test_passes_through_runner_output(self, mocker: MockerFixture): executor = object.__new__(MultiprocDiffusionExecutor) executor._ensure_open = lambda: None - expected = RunnerOutput(req_id="req-1", chunk_idx=0, chunk_completed=True) + expected = RunnerOutput(req_id="req-1", chunk_completion_map={0: True}) rpc = mocker.Mock(return_value=expected) executor.collective_rpc = rpc diff --git a/vllm_omni/diffusion/distributed/group_coordinator.py b/vllm_omni/diffusion/distributed/group_coordinator.py index 9defa10eb5d..7bc6f40cf54 100644 --- a/vllm_omni/diffusion/distributed/group_coordinator.py +++ b/vllm_omni/diffusion/distributed/group_coordinator.py @@ -1031,12 +1031,13 @@ def set_recv_dict_buffer( name: str, segment_idx: int, template_dict: dict[str, torch.Tensor | Any], + batch_size: int = 1, ) -> None: """Pre-populate schema cache + a double-buffer pair (indices 0/1) for - (name, segment_idx). + ``(name, segment_idx, batch_size)``. """ metadata_list, _ = _split_tensor_dict(template_dict) - key = (name, segment_idx) + key = (name, segment_idx, batch_size) self.dict_schema_cache[key] = metadata_list buffer_pair: list[dict[str, torch.Tensor]] = [] for _ in range(2): @@ -1055,10 +1056,12 @@ def pipeline_isend_tensor_dict( tensor_dict: dict[str, torch.Tensor | Any], name: str = "dict", segment_idx: int = -1, + batch_size: int = 1, ) -> list[torch.distributed.Work]: + """Non-blocking dict send keyed by ``(name, segment_idx, batch_size)``.""" metadata_list, tensor_list = _split_tensor_dict(tensor_dict) - key = (name, segment_idx) + key = (name, segment_idx, batch_size) handles: list[torch.distributed.Work] = [] if key not in self.dict_schema_cache: schema_handles, keepalive = self._isend_dict_schema(metadata_list) @@ -1085,14 +1088,15 @@ def pipeline_irecv_tensor_dict( name: str = "dict", segment_idx: int = -1, buf_idx: int = 0, + batch_size: int = 1, ) -> tuple[dict[str, torch.Tensor | Any], list[torch.distributed.Work], list]: """Async tensor-dict recv into the ``buf_idx`` slot (0 or 1) of the - double-buffer pair for (name, segment_idx). Caller picks the slot - — typically ``micro_step % 2`` — so consecutive recvs alternate and - the previous result stays readable until its consumer is done. + double-buffer pair for ``(name, segment_idx, batch_size)``. Caller picks + the slot — typically ``micro_step % 2`` — so consecutive recvs alternate + and the previous result stays readable until its consumer is done. Posts irecvs on ``comms_stream``. """ - key = (name, segment_idx) + key = (name, segment_idx, batch_size) if key not in self.dict_schema_cache: metadata_list = self._recv_dict_schema() self.dict_schema_cache[key] = metadata_list diff --git a/vllm_omni/diffusion/distributed/pipeline_parallel.py b/vllm_omni/diffusion/distributed/pipeline_parallel.py index a39a591ef48..502e4d6e5dd 100644 --- a/vllm_omni/diffusion/distributed/pipeline_parallel.py +++ b/vllm_omni/diffusion/distributed/pipeline_parallel.py @@ -163,15 +163,6 @@ def _pp_send_work(self) -> list[torch.distributed.Work]: def _pp_send_work(self, work: list[torch.distributed.Work]) -> None: self._pp_send_work_list = work - @property - def _preposted_its(self) -> list[AsyncIntermediateTensors] | None: - """Pre-posted IT recvs for the next micro-step (None if not primed).""" - return getattr(self, "_preposted_its_list", None) - - @_preposted_its.setter - def _preposted_its(self, value: list[AsyncIntermediateTensors] | None) -> None: - self._preposted_its_list = value - def _sync_pp_send(self) -> None: """ Wait on all pending non-blocking PP sends. @@ -194,6 +185,8 @@ def predict_noise_maybe_with_cfg( cfg_normalize: bool = True, output_slice: int | None = None, buf_idx: int = 0, + batch_size: int = 1, + preposted_its: list[AsyncIntermediateTensors] | None = None, ) -> torch.Tensor | tuple[torch.Tensor, ...] | None: """ Drop-in replacement for predict_noise_maybe_with_cfg that also handles PP. @@ -230,16 +223,15 @@ def predict_noise_maybe_with_cfg( n = len(all_kwargs) its: list[AsyncIntermediateTensors | None] = [None] * n if not pp_group.is_first_rank: - # Use recvs pre-posted by the previous step's scheduler_step - preposted = self._preposted_its - if preposted is not None and len(preposted) == n: - its = preposted - self._preposted_its = None + # Use recvs pre-posted by the previous step's scheduler_step. + # Caller owns the lifecycle in state.extra; we just consume here. + if preposted_its is not None and len(preposted_its) == n: + its = list(preposted_its) else: for i in range(n): its[i] = AsyncIntermediateTensors( *pp_group.pipeline_irecv_tensor_dict( - name="intermediate", segment_idx=i, buf_idx=buf_idx + name="intermediate", segment_idx=i, buf_idx=buf_idx, batch_size=batch_size, ) ) @@ -248,7 +240,9 @@ def predict_noise_maybe_with_cfg( for i, (kwargs, it) in enumerate(zip(all_kwargs, its)): result = self.predict_noise(**kwargs, intermediate_tensors=it) self._pp_send_work.extend( - pp_group.pipeline_isend_tensor_dict(result.tensors, name="intermediate", segment_idx=i) + pp_group.pipeline_isend_tensor_dict( + result.tensors, name="intermediate", segment_idx=i, batch_size=batch_size, + ) ) return None @@ -282,9 +276,10 @@ def scheduler_step_maybe_with_cfg( t: torch.Tensor | tuple[torch.Tensor, ...], latents: torch.Tensor | tuple[torch.Tensor, ...], do_true_cfg: bool, - per_request_scheduler: Any | None = None, + per_request_scheduler: Any | list[Any] | None = None, buf_idx: int = 0, is_last_step: bool = False, + batch_size: int = 1, ) -> torch.Tensor | tuple[torch.Tensor, ...] | AsyncLatents: """ Drop-in replacement for scheduler_step_maybe_with_cfg that also handles PP. @@ -292,35 +287,66 @@ def scheduler_step_maybe_with_cfg( Only the last rank runs the scheduler (it already has noise_pred); the result is sent to rank 0 which needs it for the next forward pass. - Returns a ``AsyncLatents`` on rank 0 that transparently defers - ``handle.wait()`` until the tensor is actually consumed (via attribute - access or a torch operation), keeping the rank non-blocking after the - ``irecv`` is posted. + Returns ``AsyncLatents`` on rank 0 that defers wait() until the tensor + is actually consumed (via attribute access or a torch op), keeping the + rank non-blocking after the irecv. """ if get_pipeline_parallel_world_size() == 1: - return super().scheduler_step_maybe_with_cfg( - noise_pred, t, latents, do_true_cfg, per_request_scheduler, generator - ) + return self._scheduler_step_local(noise_pred, t, latents, do_true_cfg, per_request_scheduler) pp_group = get_pp_group() if pp_group.is_last_rank: - latents = super().scheduler_step_maybe_with_cfg(noise_pred, t, latents, do_true_cfg, per_request_scheduler) - self._pp_send_work = pp_group.pipeline_isend_tensor_dict({"latents": latents}, name="latents") + latents = self._scheduler_step_local(noise_pred, t, latents, do_true_cfg, per_request_scheduler) + self._pp_send_work = pp_group.pipeline_isend_tensor_dict( + {"latents": latents}, name="latents", batch_size=batch_size, + ) elif pp_group.is_first_rank: - latents = AsyncLatents(*pp_group.pipeline_irecv_tensor_dict(name="latents", buf_idx=buf_idx)) + latents = AsyncLatents( + *pp_group.pipeline_irecv_tensor_dict(name="latents", buf_idx=buf_idx, batch_size=batch_size) + ) return latents + + def _scheduler_step_local( + self, + noise_pred: torch.Tensor, + t: torch.Tensor, + latents: torch.Tensor, + do_true_cfg: bool, + per_request_scheduler: Any | list[Any] | None, + ) -> torch.Tensor: + """Run scheduler.step on this rank — single call or per-chunk loop.""" + if not isinstance(per_request_scheduler, list): + return super().scheduler_step_maybe_with_cfg( + noise_pred, t, latents, do_true_cfg, per_request_scheduler, + ) + new_rows: list[torch.Tensor] = [] + for i, sched in enumerate(per_request_scheduler): + t_i = t[i] if t.ndim > 0 else t + new_rows.append( + super().scheduler_step_maybe_with_cfg( + noise_pred[i:i + 1], t_i, latents[i:i + 1], do_true_cfg, sched, + ) + ) + return torch.cat(new_rows, dim=0) - def prefetch_its_maybe_with_pp_and_cfg(self, do_true_cfg: bool, buf_idx: int, is_last_step: bool) -> None: + def prefetch_its_maybe_with_pp_and_cfg( + self, + do_true_cfg: bool, + buf_idx: int, + is_last_step: bool, + batch_size: int = 1, + ) -> list[AsyncIntermediateTensors] | None: pp_group = get_pp_group() - if not pp_group.is_first_rank and not is_last_step: - cfg_parallel_ready = do_true_cfg and get_classifier_free_guidance_world_size() > 1 - n = 1 if cfg_parallel_ready else (2 if do_true_cfg else 1) - next_buf_idx = (buf_idx + 1) % 2 - self._preposted_its = [ - AsyncIntermediateTensors( - *pp_group.pipeline_irecv_tensor_dict( - name="intermediate", segment_idx=i, buf_idx=next_buf_idx - ) + if pp_group.is_first_rank or is_last_step: + return None + cfg_parallel_ready = do_true_cfg and get_classifier_free_guidance_world_size() > 1 + n = 1 if cfg_parallel_ready else (2 if do_true_cfg else 1) + next_buf_idx = (buf_idx + 1) % 2 + return [ + AsyncIntermediateTensors( + *pp_group.pipeline_irecv_tensor_dict( + name="intermediate", segment_idx=i, buf_idx=next_buf_idx, batch_size=batch_size, ) - for i in range(n) - ] \ No newline at end of file + ) + for i in range(n) + ] \ No newline at end of file diff --git a/vllm_omni/diffusion/executor/multiproc_executor.py b/vllm_omni/diffusion/executor/multiproc_executor.py index 94ec2e0a18d..d5583e36e01 100644 --- a/vllm_omni/diffusion/executor/multiproc_executor.py +++ b/vllm_omni/diffusion/executor/multiproc_executor.py @@ -347,10 +347,6 @@ def execute_step(self, scheduler_output: DiffusionSchedulerOutput) -> BaseRunner def execute_micro_step(self, scheduler_output: DiffusionSchedulerOutput) -> RunnerOutput: """Forward a temporal-PP micro-step to worker ``execute_micro_step`` RPC. - The reply is collected from the last PP rank, which owns rank N-1's ODE - results and any chunk-finished decodes (carried in - ``RunnerOutput.chunk_events``). Other ranks' replies are discarded. - Assumes worker rank == PP rank (true for PP-only layouts; revisit when introducing TP/DP combinations). """ diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py index ee555c24cca..297b7a4a95c 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py @@ -1321,39 +1321,55 @@ def set_pp_recv_dict_buffers(self, state: DiffusionRequestState) -> None: if pp_group.world_size == 1: return - latents_template = {"latents": state.latents} + # Pre-populate buffer pairs for every B in 1..slo_max_batch + slo_fps = getattr(state.sampling, "slo_fps", None) + slo_max_batch = getattr(state.sampling, "slo_max_batch", 1) + b_max = max(1, slo_max_batch if slo_fps else 1) - # Intermediate tensor: [batch, seq_len, inner_dim] after patch_embed + flatten. - batch = state.latents.shape[0] - num_frames = state.latents.shape[2] - height = state.latents.shape[3] - width = state.latents.shape[4] + _, channels, num_frames, height, width = state.latents.shape p_t, p_h, p_w = self.transformer_config.patch_size seq_len = (num_frames // p_t) * (height // p_h) * (width // p_w) inner_dim = self.transformer_config.num_attention_heads * self.transformer_config.attention_head_dim - dtype = (self.transformer or self.transformer_2).dtype - it_template = { - "hidden_states": torch.empty(batch, seq_len, inner_dim, dtype=dtype, device=self.device) - } + it_dtype = (self.transformer or self.transformer_2).dtype + latents_dtype = state.latents.dtype + device = state.latents.device cfg_branches = 2 if state.do_true_cfg else 1 - pp_group.set_recv_dict_buffer("latents", -1, latents_template) - for seg in range(cfg_branches): - pp_group.set_recv_dict_buffer("intermediate", seg, it_template) + for B in range(1, b_max + 1): + latents_template = { + "latents": torch.empty(B, channels, num_frames, height, width, dtype=latents_dtype, device=device) + } + it_template = { + "hidden_states": torch.empty(B, seq_len, inner_dim, dtype=it_dtype, device=self.device) + } + pp_group.set_recv_dict_buffer("latents", -1, latents_template, batch_size=B) + for seg in range(cfg_branches): + pp_group.set_recv_dict_buffer("intermediate", seg, it_template, batch_size=B) self.is_buffer_setup = True def denoise_step( self, state: DiffusionRequestState, + batch_size: int = 1, **kwargs: Any, ) -> torch.Tensor | None: - """Run one denoising iteration.""" - t = state.current_timestep + """Run one denoising iteration. + + When ``state.batched_timesteps`` is set (stream-batch path with + multiple chunks at different ``step_index`` fused along dim 0), + it overrides ``current_timestep`` and ``_prepare_latent_input`` + forwards the per-row timesteps directly to the transformer. + Model selection uses the per-row max so a batch straddling the + high/low noise boundary picks the high-noise transformer (correct + for single-stage Wan2.1; Wan2.2 boundary-straddling is a TODO). + """ + t = state.batched_timesteps if state.batched_timesteps is not None else state.current_timestep self._current_timestep = t boundary_timestep = state.extra.get("boundary_timestep") - current_model, current_guidance_scale = self._select_model_for_timestep(t, boundary_timestep) + t_select = t.max() if t.ndim > 0 else t + current_model, current_guidance_scale = self._select_model_for_timestep(t_select, boundary_timestep) latent_model_input, timestep = self._prepare_latent_input(state, t, current_model.dtype) @@ -1379,6 +1395,7 @@ def denoise_step( if do_true_cfg else None ) + preposted_its = state.extra.pop("preposted_its", None) return self.predict_noise_maybe_with_pp_and_cfg( do_true_cfg=do_true_cfg, @@ -1386,9 +1403,11 @@ def denoise_step( positive_kwargs=positive_kwargs, negative_kwargs=negative_kwargs, buf_idx=state.step_index % 2, + batch_size=batch_size, + preposted_its=preposted_its, ) - def prefetch_its(self, state: DiffusionRequestState) -> None: + def prefetch_its(self, state: DiffusionRequestState, batch_size: int = 1) -> None: """Prefetch intermediate tensors for the next step.""" t = state.current_timestep boundary_timestep = state.extra.get("boundary_timestep") @@ -1397,34 +1416,48 @@ def prefetch_its(self, state: DiffusionRequestState) -> None: buf_idx = state.step_index % 2 is_last_step = state.step_index == state.total_steps - 1 - self.prefetch_its_maybe_with_pp_and_cfg( + preposted = self.prefetch_its_maybe_with_pp_and_cfg( do_true_cfg=do_true_cfg, buf_idx=buf_idx, is_last_step=is_last_step, + batch_size=batch_size, ) + if preposted is not None: + state.extra["preposted_its"] = preposted def step_scheduler( self, state: DiffusionRequestState, noise_pred: torch.Tensor, + *, + per_request_scheduler: Any | list[Any] | None = None, + batch_size: int = 1, **kwargs: Any, ) -> None: - """Run one scheduler step: update latents and advance step_index.""" - t = state.current_timestep + """Run one scheduler step: update latents and advance step_index. + + ``per_request_scheduler`` may be a single scheduler (B=1) or a list of + per-chunk schedulers (B>1, last rank loops one ``step()`` per row). + ``batch_size`` keys the PP recv buffer pool on first rank. + """ + t = state.batched_timesteps if state.batched_timesteps is not None else state.current_timestep boundary_timestep = state.extra.get("boundary_timestep") - _, current_guidance_scale = self._select_model_for_timestep(t, boundary_timestep) + t_select = t.max() if t.ndim > 0 else t + _, current_guidance_scale = self._select_model_for_timestep(t_select, boundary_timestep) do_true_cfg = current_guidance_scale > 1.0 and state.negative_prompt_embeds is not None buf_idx = state.step_index % 2 - is_last_step = state.step_index == state.total_steps - 1 + + if per_request_scheduler is None: + per_request_scheduler = state.scheduler state.latents = self.scheduler_step_maybe_with_pp_and_cfg( noise_pred, t, state.latents, do_true_cfg, - per_request_scheduler=state.scheduler, + per_request_scheduler=per_request_scheduler, buf_idx=buf_idx, - is_last_step=is_last_step, + batch_size=batch_size, ) state.step_index += 1 diff --git a/vllm_omni/diffusion/sched/interface.py b/vllm_omni/diffusion/sched/interface.py index 9efa419c40b..c28ec5eb688 100644 --- a/vllm_omni/diffusion/sched/interface.py +++ b/vllm_omni/diffusion/sched/interface.py @@ -111,7 +111,23 @@ class RankTask: """One unit of work for a rank in a stream-batch micro-step.""" sched_req_id: str - chunk_idx: int + chunk_indices: list[int] + + +@dataclass +class Rank0Layout: + """How rank 0 should slice the [B_prev, ...] tensor it receives from last rank. + + - head [0:n_finished] are chunks completing denoising (to decode) + - next [n_finished : n_finished+n_circulating] are re-admitted chunks + - rank 0 appends n_new fresh randn rows at the tail before forwarding. + """ + + n_finished: int + n_circulating: int + n_new: int + finished_idxs: list[int] + new_idxs: list[int] @dataclass @@ -124,8 +140,10 @@ class DiffusionSchedulerOutput: finished_req_ids: set[str] num_running_reqs: int num_waiting_reqs: int - # Per-rank task list. Index = PP rank id. - per_rank_assignment: list[list[RankTask]] | None = None + + # stream-batch scheduling fields + per_rank_assignment: list[RankTask | None] | None = None + rank0_layouts: dict[str, Rank0Layout] | None = None @cached_property def scheduled_req_ids(self) -> list[str]: diff --git a/vllm_omni/diffusion/sched/stream_batch_scheduler.py b/vllm_omni/diffusion/sched/stream_batch_scheduler.py index cd2e5778dc4..b257d28e73e 100644 --- a/vllm_omni/diffusion/sched/stream_batch_scheduler.py +++ b/vllm_omni/diffusion/sched/stream_batch_scheduler.py @@ -1,14 +1,15 @@ """Temporal-pipeline-parallel scheduler for streaming chunked diffusion. -Each ``schedule()`` call corresponds to one micro-step. At any micro-step, -each PP rank processes the chunks at the denoising step ``r = current - -entered_rank0_at`` from the active requests' in-flight chunks. Chunks are -admitted to rank 0 in order, propagate through ranks under NCCL FIFO -ordering, and exit at rank N-1 in the same order. +Each ``schedule()`` call corresponds to one micro-step. The pipeline is modeled +as ``pp_size`` per-rank chunk queues plus a transient ``returning`` queue. +At each schedule(), chunks at rank N-1 drain (finished -> finished_head, +otherwise -> returning), queues shift one rank, and rank 0 receives the +returning chunks plus B fresh admits. """ from __future__ import annotations +from collections import deque from dataclasses import dataclass, field from typing import TYPE_CHECKING @@ -20,6 +21,7 @@ from vllm_omni.diffusion.sched.interface import ( DiffusionRequestStatus, DiffusionSchedulerOutput, + Rank0Layout, RankTask, ) @@ -31,22 +33,24 @@ @dataclass class _InFlightChunk: - """One chunk of an active request, tracked through the temporal pipeline.""" + """One chunk of an active request currently in the pipeline.""" chunk_idx: int - is_active: bool = True - is_completed: bool = False - entered_rank0_at: int = -1 + steps_done: int = 0 @dataclass class _ChunkProgress: """Per-request chunk-level scheduling state.""" + sched_req_id: str num_chunks: int num_steps: int + pp_size: int chunks_admitted: int = 0 - in_flight: list[_InFlightChunk] = field(default_factory=list) + # chunks_at[r] = chunks that will be processed by rank r at the current step. + chunks_at: list[deque[_InFlightChunk]] = field(default_factory=list) + returning: deque[_InFlightChunk] = field(default_factory=deque) @dataclass @@ -64,7 +68,7 @@ class _SLOReqState: class _SLOController: - """AIMD controller for the stream batch size, tracked per request. + """AIMD controller for the stream admission rate B, tracked per request. Driven by per-micro-step wall-clock latency observations on rank 0. @@ -123,7 +127,10 @@ def observe(self, sched_req_id: str, latency_ns: int | None) -> None: if st.violation_streak >= self.VIOLATION_STREAK_TARGET: new_b = max(1, st.batch_size // 2) if new_b != st.batch_size: - logger.info(f"SLO[{sched_req_id}]: halving batch_size {st.batch_size} -> {new_b} (ema={ema/1e6:.2}ms budget={budget/1e6:.2}ms)") + logger.info( + f"SLO[{sched_req_id}]: halving B {st.batch_size} -> {new_b} " + f"(ema={ema/1e6:.2f}ms budget={budget/1e6:.2f}ms)" + ) st.batch_size = new_b st.violation_streak = 0 return @@ -134,7 +141,10 @@ def observe(self, sched_req_id: str, latency_ns: int | None) -> None: st.slack_streak += 1 if st.slack_streak >= self.SLACK_STREAK_TARGET and st.batch_size < st.max_batch: st.batch_size += 1 - logger.info(f"SLO[{sched_req_id}]: increasing batch_size -> {st.batch_size} (ema={ema/1e6:.2}ms budget={budget/1e6:.2}ms)") + logger.info( + f"SLO[{sched_req_id}]: B -> {st.batch_size} " + f"(ema={ema/1e6:.2f}ms budget={budget/1e6:.2f}ms)" + ) st.slack_streak = 0 else: st.slack_streak = 0 @@ -145,17 +155,16 @@ class StreamBatchScheduler(_BaseScheduler): Per micro-step: 1. Promote waiting requests (handled by the base class). - 2. Admit returning chunks to rank 0 first (FIFO across active requests), - then admit fresh chunks, until ``batch_size`` chunks have entered - rank 0 this micro-step or no admittable chunks remain. - 3. Build the per-rank assignment table from in-pipeline chunks' - positions ``r = current_micro_step - entered_rank0_at``. + 2. Drain rank N-1 of last step: finished chunks -> finished_head (decode + layout for rank 0), others -> returning queue. + 3. Shift per-rank queues by one (rank r <- rank r-1). + 4. Rank 0 = returning + B fresh admits (unconditional re-admit). + 5. Emit per-rank assignment and the per-request Rank0Layout. """ def __init__(self) -> None: super().__init__() self.pp_size: int = 1 - self._global_micro_step: int = 0 self._chunk_progress: dict[str, _ChunkProgress] = {} self._slo: _SLOController = _SLOController() @@ -166,7 +175,6 @@ def initialize(self, od_config: OmniDiffusionConfig) -> None: self.pp_size = od_config.parallel_config.pipeline_parallel_size def _reset_scheduler_state(self) -> None: - self._global_micro_step = 0 self._chunk_progress.clear() self._slo = _SLOController() @@ -195,12 +203,14 @@ def schedule(self) -> DiffusionSchedulerOutput: for new_req in base_output.scheduled_new_reqs: self._init_chunk_progress(new_req.sched_req_id, new_req.req) - self._advance_chunk_pipeline() + rank0_layouts: dict[str, Rank0Layout] = {} + for progress in self._chunk_progress.values(): + rank0_layouts[progress.sched_req_id] = self._advance_chunk_pipeline_for(progress) if self._chunk_progress: base_output.per_rank_assignment = self._build_assignment() + base_output.rank0_layouts = rank0_layouts - self._global_micro_step += 1 return base_output def _init_chunk_progress(self, sched_req_id: str, req: OmniDiffusionRequest) -> None: @@ -212,6 +222,8 @@ def _init_chunk_progress(self, sched_req_id: str, req: OmniDiffusionRequest) -> sched_req_id=sched_req_id, num_chunks=num_chunks, num_steps=num_steps, + pp_size=self.pp_size, + chunks_at=[deque() for _ in range(self.pp_size)], ) chunk_frames = max(1, sampling.num_frames) @@ -222,102 +234,98 @@ def _init_chunk_progress(self, sched_req_id: str, req: OmniDiffusionRequest) -> ema_alpha=sampling.slo_ema_alpha, chunk_frames=chunk_frames, ) - - logger.debug(f"""StreamBatchScheduler initialized chunk progress for {sched_req_id} - (num_chunks={num_chunks}, num_steps={num_steps}, chunk_frames={chunk_frames}, slo_fps={sampling.slo_fps}, pp_size={self.pp_size})""") - def _advance_chunk_pipeline(self) -> None: - """Admit returning + new chunks to rank 0 this micro-step.""" - if not self._chunk_progress: - return - - m = self._global_micro_step + logger.debug( + "StreamBatchScheduler initialized chunk progress for %s " + "(num_chunks=%d, num_steps=%d, chunk_frames=%d, slo_fps=%s, pp_size=%d)", + sched_req_id, num_chunks, num_steps, chunk_frames, sampling.slo_fps, self.pp_size, + ) - # 1. Re-admit every returning chunk. - for progress in self._chunk_progress.values(): - for chunk in progress.in_flight: - if not chunk.is_active: - chunk.is_active = True - chunk.entered_rank0_at = m + def _advance_chunk_pipeline_for(self, progress: _ChunkProgress) -> Rank0Layout: + """Advance the per-rank queues by one micro-step and return rank 0's layout.""" + + pp = progress.pp_size + + # 1. Drain last rank from previous step + finished_idxs: list[int] = [] + n_finished = 0 + n_circulating = 0 + last = progress.chunks_at[pp - 1] + while last: + chunk = last.popleft() + chunk.steps_done += 1 + if chunk.steps_done >= progress.num_steps: + finished_idxs.append(chunk.chunk_idx) + n_finished += 1 + else: + progress.returning.append(chunk) + n_circulating += 1 + + # 2. Shift: rank r receives what rank r-1 had + for r in range(pp - 1, 0, -1): + progress.chunks_at[r] = progress.chunks_at[r - 1] + progress.chunks_at[0] = deque() + + # 3. Rank 0 = returning + B fresh admits + while progress.returning: + progress.chunks_at[0].append(progress.returning.popleft()) + + new_idxs: list[int] = [] + budget = self._slo.batch_size(progress.sched_req_id) + admitted = 0 + while admitted < budget and progress.chunks_admitted < progress.num_chunks: + idx = progress.chunks_admitted + progress.chunks_at[0].append(_InFlightChunk(chunk_idx=idx)) + progress.chunks_admitted += 1 + new_idxs.append(idx) + admitted += 1 + + return Rank0Layout( + n_finished=n_finished, + n_circulating=n_circulating, + n_new=len(new_idxs), + finished_idxs=finished_idxs, + new_idxs=new_idxs, + ) - # 2. Admit up to ``B_req`` fresh chunks per request. + def _build_assignment(self) -> list[RankTask | None]: + assignment: list[RankTask | None] = [None] * self.pp_size for progress in self._chunk_progress.values(): - budget = self._slo.batch_size(progress.sched_req_id) - admitted = 0 - while ( - progress.chunks_admitted < progress.num_chunks - and admitted < budget - ): - progress.in_flight.append(_InFlightChunk( - chunk_idx=progress.chunks_admitted, - entered_rank0_at=m, - )) - progress.chunks_admitted += 1 - admitted += 1 - - def _build_assignment(self) -> list[list[RankTask]]: - assignment: list[list[RankTask]] = [[] for _ in range(self.pp_size)] - for progress in self._chunk_progress.values(): - for chunk in progress.in_flight: - if not chunk.is_active: + for r in range(self.pp_size): + queue = progress.chunks_at[r] + if not queue: continue - r = self._global_micro_step - chunk.entered_rank0_at - if 0 <= r < self.pp_size: - assignment[r].append(RankTask( + indices = [c.chunk_idx for c in queue] + if assignment[r] is None: + assignment[r] = RankTask( sched_req_id=progress.sched_req_id, - chunk_idx=chunk.chunk_idx, - )) + chunk_indices=indices, + ) + else: + assignment[r].chunk_indices.extend(indices) return assignment # ── Output processing ────────────────────────────────────────────────── - def update_from_output(self, sched_output: DiffusionSchedulerOutput, output: RunnerOutput) -> set[str]: - if not self._chunk_progress or sched_output.per_rank_assignment is None: + def update_from_output( + self, sched_output: DiffusionSchedulerOutput, output: RunnerOutput + ) -> set[str]: + if not self._chunk_progress: return set() - per_task = [output] + list(output.extra_task_outputs or []) - if output.micro_step_wall_ns is not None: self._slo.observe(output.req_id, output.micro_step_wall_ns) terminal: dict[str, DiffusionRequestStatus] = {} terminal_errors: dict[str, str | None] = {} - for task_out in per_task: - progress = self._chunk_progress.get(task_out.req_id) - if progress is None: - continue - err = task_out.result.error if task_out.result is not None else None + progress = self._chunk_progress.get(output.req_id) + if progress is not None: + err = output.result.error if output.result is not None else None if err is not None: - terminal[task_out.req_id] = DiffusionRequestStatus.FINISHED_ERROR - terminal_errors[task_out.req_id] = err - continue - chunk = self._find_chunk(progress, task_out.chunk_idx) if task_out.chunk_idx is not None else None - if chunk is not None: - chunk.is_completed = task_out.chunk_completed - if task_out.finished: - terminal[task_out.req_id] = DiffusionRequestStatus.FINISHED_COMPLETED - - # Roll last-rank chunks off the pipeline / mark them inactive. - for last_task in (sched_output.per_rank_assignment[-1] if sched_output.per_rank_assignment else []): - progress = self._chunk_progress.get(last_task.sched_req_id) - if progress is None: - continue - last_chunk = self._find_chunk(progress, last_task.chunk_idx) - if last_chunk is None: - continue - if last_chunk.is_completed: - progress.in_flight = [ - c for c in progress.in_flight if c.chunk_idx != last_chunk.chunk_idx - ] - else: - last_chunk.is_active = False - - return self._finalize_update_from_output(sched_output, terminal, terminal_errors) + terminal[output.req_id] = DiffusionRequestStatus.FINISHED_ERROR + terminal_errors[output.req_id] = err + elif output.finished: + terminal[output.req_id] = DiffusionRequestStatus.FINISHED_COMPLETED - @staticmethod - def _find_chunk(progress: _ChunkProgress, chunk_idx: int) -> _InFlightChunk | None: - for chunk in progress.in_flight: - if chunk.chunk_idx == chunk_idx: - return chunk - return None \ No newline at end of file + return self._finalize_update_from_output(sched_output, terminal, terminal_errors) \ No newline at end of file diff --git a/vllm_omni/diffusion/worker/diffusion_model_runner.py b/vllm_omni/diffusion/worker/diffusion_model_runner.py index a741d6199b9..c5dfa75b508 100644 --- a/vllm_omni/diffusion/worker/diffusion_model_runner.py +++ b/vllm_omni/diffusion/worker/diffusion_model_runner.py @@ -493,17 +493,18 @@ def execute_stepwise(self, scheduler_output: DiffusionSchedulerOutput) -> BatchR @staticmethod def _merge_chunk_outputs(chunks: list[DiffusionOutput]) -> DiffusionOutput: - """Merge K completed chunk outputs into a single ``DiffusionOutput``. + """Merge decoded chunk outputs into a single video tensor. - Concatenates video tensors ``[B, C, T, H, W]`` along the temporal - dimension (dim 2). + Each entry's ``.output`` is ``[B_i, C, T, H, W]`` (one batched VAE + decode in the stream-batch path). Concat along the batch dim then + unroll into the temporal axis → ``[1, C, total_chunks * T, H, W]``. NOTE: This is a temporary solution until streaming output is supported. """ - if len(chunks) == 1: - return chunks[0] try: - merged = torch.cat([c.output for c in chunks], dim=2) + cat0 = torch.cat([c.output for c in chunks], dim=0) + B, C, T, H, W = cat0.shape + merged = cat0.permute(1, 0, 2, 3, 4).reshape(1, C, B * T, H, W) except Exception as e: return DiffusionOutput(error=f"Failed to merge {len(chunks)} chunk outputs: {e}") return DiffusionOutput(output=merged) @@ -545,102 +546,154 @@ def execute_micro_step(self, scheduler_output: DiffusionSchedulerOutput) -> Runn with set_forward_context(vllm_config=self.vllm_config, omni_diffusion_config=self.od_config): pp_group = get_pp_group() - pp_rank = pp_group.rank_in_group - tasks = assignment[pp_rank] - prev_tasks = assignment[pp_group.prev_rank] if pp_group.world_size > 1 else None + task = assignment[pp_group.rank_in_group] + chunk_idxs = list(task.chunk_indices) if task else [] if is_new_request: pp_group.reset_buffer() self.pipeline.is_buffer_setup = False self.pipeline.prepare_encode(state) + state.extra["initial_latent_template"] = state.latents + + state.batched_timesteps = None t_start_ns = time.perf_counter_ns() if pp_group.is_first_rank else None + result: DiffusionOutput | None = None + finished = False if pp_group.is_first_rank: - finished_output = self._rank0_decode_due_chunks(state) - if finished_output is not None: - return finished_output - - if not tasks: - return RunnerOutput(req_id=state.req_id) - - task_outputs: list[RunnerOutput] = [] - for task in tasks: - chunk, is_new_chunk = self._get_or_create_chunk(state, task.chunk_idx) - if is_new_chunk: - chunk.latents = ( - state.latents - if task.chunk_idx == 0 - else torch.randn_like(state.latents, generator=state.sampling.generator) - ) - chunk.scheduler = copy.deepcopy(state.scheduler) - - with state.use_chunk(chunk): - if not self.pipeline.is_buffer_setup: - self.pipeline.set_pp_recv_dict_buffers(state) - - noise_pred = self.pipeline.denoise_step(state) - if noise_pred is None and getattr(self.pipeline, "interrupt", False): - self._update_states_after(state, finished=True) - return RunnerOutput( - req_id=task.sched_req_id, - result=DiffusionOutput(error="micro-step denoise interrupted"), - ) + result, finished = self._rank0_assemble_input(state, scheduler_output) + + if not chunk_idxs: + return RunnerOutput( + req_id=state.req_id, + finished=finished, + result=result, + micro_step_wall_ns=( + time.perf_counter_ns() - t_start_ns if t_start_ns is not None else None + ), + ) - self.pipeline.step_scheduler(state, noise_pred) - chunk_completed = state.denoise_completed + template = state.extra["initial_latent_template"] + chunks: list[ChunkState] = [ + self._get_or_create_chunk(state, idx)[0] for idx in chunk_idxs + ] + + if pp_group.is_last_rank: + # Per-chunk schedulers regardless of pp_size — stateful step(). + for c in chunks: + if c.scheduler is None: + c.scheduler = copy.deepcopy(state.scheduler) + + if pp_group.is_first_rank: + # pp_size==1: state.latents was assembled by + # _rank0_assemble_input; take per-chunk views to avoid + # double randn for new admits on the shared generator. + for i, c in enumerate(chunks): + c.latents = state.latents[i:i + 1] + else: + # Multi-rank last rank: maintain per-chunk latents in + # lockstep with rank 0 via shared seed + matching + # randn_like call order. + for c in chunks: + if c.latents is None: + c.latents = ( + template + if c.idx == 0 + else torch.randn_like(template, generator=state.sampling.generator) + ) + state.latents = torch.cat([c.latents for c in chunks], dim=0) + elif not pp_group.is_first_rank: + # Middle rank: ITs carry the forward; state.latents shape only + # for buffer setup. Broadcast view, no extra memory. + state.latents = template.expand(len(chunks), *template.shape[1:]) + + if not self.pipeline.is_buffer_setup: + self.pipeline.set_pp_recv_dict_buffers(state) + + # Per-row timesteps + state.batched_timesteps = torch.stack( + [state.scheduler.timesteps[c.step_index] for c in chunks] + ) - if chunk_completed: - state.extra["chunks"].pop(task.chunk_idx, None) - if pp_group.is_first_rank: - state.extra.setdefault("denoised_chunks", []).append((chunk, pp_group.world_size)) + B = len(chunks) + noise_pred = self.pipeline.denoise_step(state, batch_size=B) - task_outputs.append(RunnerOutput( - req_id=task.sched_req_id, - step_index=chunk.step_index, - chunk_idx=task.chunk_idx, - chunk_completed=chunk_completed, - )) + if noise_pred is None and getattr(self.pipeline, "interrupt", False): + self._update_states_after(state, finished=True) + return RunnerOutput( + req_id=state.req_id, + finished=True, + result=DiffusionOutput(error="micro-step denoise interrupted"), + ) - for prev_task in prev_tasks: - prev_chunk, _ = self._get_or_create_chunk(state, prev_task.chunk_idx) - with state.use_chunk(prev_chunk): - self.pipeline.prefetch_its(state) + self.pipeline.prefetch_its(state, batch_size=B) - primary = task_outputs[0] - if len(task_outputs) > 1: - primary.extra_task_outputs = task_outputs[1:] - if t_start_ns is not None: - primary.micro_step_wall_ns = time.perf_counter_ns() - t_start_ns + schedulers = [c.scheduler for c in chunks] if pp_group.is_last_rank else None + self.pipeline.step_scheduler( + state, noise_pred, per_request_scheduler=schedulers, batch_size=B, + ) - return primary + if pp_group.is_last_rank: + for i, c in enumerate(chunks): + c.latents = state.latents[i:i + 1] - def _rank0_decode_due_chunks(self, state: DiffusionRequestState) -> RunnerOutput | None: - """Decode any chunks whose pipeline-drain delay has elapsed. + for c in chunks: + c.step_index += 1 - Returns a finished RunnerOutput when all of the request's chunks have - been decoded, otherwise ``None``. - """ - denoised_chunks = state.extra.get("denoised_chunks", []) - decoded_chunks = state.extra.setdefault("decoded_chunks", []) - remaining: list[tuple[ChunkState, int]] = [] - for chunk, steps_left in denoised_chunks: - steps_left -= 1 - if steps_left == 0: - with state.use_chunk(chunk): - decoded_chunks.append(self.pipeline.post_decode(state)) - else: - remaining.append((chunk, steps_left)) - state.extra["denoised_chunks"] = remaining - - if len(decoded_chunks) == state.sampling.num_chunks: - output = RunnerOutput( - req_id=state.req_id, - step_index=state.step_index, - finished=True, - result=self._merge_chunk_outputs(decoded_chunks), + return RunnerOutput( + req_id=state.req_id, + finished=finished, + result=result, + micro_step_wall_ns=( + time.perf_counter_ns() - t_start_ns if t_start_ns is not None else None + ), + ) + + def _rank0_assemble_input( + self, state: DiffusionRequestState, scheduler_output: DiffusionSchedulerOutput, + ) -> tuple[DiffusionOutput | None, bool]: + """Build rank 0's batched forward input.""" + + layouts = scheduler_output.rank0_layouts + layout = layouts.get(state.req_id) if layouts else None + if layout is None: + return None, False + + prev_latents = state.latents + pieces: list[torch.Tensor] = [] + + if layout.n_finished > 0 and prev_latents is not None: + saved = state.latents + state.latents = prev_latents[: layout.n_finished] + decoded = self.pipeline.post_decode(state) + state.latents = saved + state.extra.setdefault("decoded_chunks", []).append(decoded) + state.extra["chunks_decoded"] = ( + state.extra.get("chunks_decoded", 0) + layout.n_finished ) + for idx in layout.finished_idxs: + state.extra.get("chunks", {}).pop(idx, None) + + if layout.n_circulating > 0 and prev_latents is not None: + pieces.append( + prev_latents[layout.n_finished : layout.n_finished + layout.n_circulating] + ) + + if layout.n_new > 0: + template = state.extra["initial_latent_template"] # [1, C, T, H, W] + for idx in layout.new_idxs: + row = ( + template + if idx == 0 + else torch.randn_like(template, generator=state.sampling.generator) + ) + pieces.append(row) + + state.latents = torch.cat(pieces, dim=0) if pieces else None + + if state.extra.get("chunks_decoded", 0) >= state.sampling.num_chunks: self._update_states_after(state, finished=True) - return output - return None + return self._merge_chunk_outputs(state.extra["decoded_chunks"]), True + return None, False diff --git a/vllm_omni/diffusion/worker/utils.py b/vllm_omni/diffusion/worker/utils.py index 0c7c3d1cdc7..f914e92878a 100644 --- a/vllm_omni/diffusion/worker/utils.py +++ b/vllm_omni/diffusion/worker/utils.py @@ -57,6 +57,8 @@ class DiffusionRequestState: timesteps: torch.Tensor | list[torch.Tensor] | None = None step_index: int = 0 + batched_timesteps: torch.Tensor | None = None + # ── Per-request scheduler instance (set once by prepare_encode) ── scheduler: Any | None = None @@ -173,8 +175,8 @@ class RunnerOutput: Each scheduler reads the fields it needs: - ``StepScheduler`` reads ``step_index`` / ``finished``. - - ``StreamBatchScheduler`` reads ``chunk_idx`` / ``step_index`` / - ``chunk_completed`` / ``finished``. + - ``StreamBatchScheduler`` reads ``finished`` / ``result`` / + ``micro_step_wall_ns``. Fields not relevant to an execution path are left as ``None`` / ``False``. """ @@ -185,10 +187,7 @@ class RunnerOutput: result: DiffusionOutput | None = None # ── Temporal-PP micro-step fields ── - chunk_idx: int | None = None - chunk_completed: bool = False micro_step_wall_ns: int | None = None - extra_task_outputs: list["RunnerOutput"] | None = None # for B>1 def get_req_output(self, sched_req_id: str) -> RunnerOutput | None: return self if self.req_id == sched_req_id else None From 95b46e2bbe13e31dd2dc0add59623f48d494b06c Mon Sep 17 00:00:00 2001 From: Mahdi Nasser <94046147+mnasser02@users.noreply.github.com> Date: Tue, 12 May 2026 10:35:02 +0200 Subject: [PATCH 38/53] modify tests Signed-off-by: Mahdi Nasser <94046147+mnasser02@users.noreply.github.com> --- .../test_diffusion_micro_step_pipeline.py | 169 +++++++++--- tests/diffusion/test_diffusion_scheduler.py | 256 ++++++++---------- 2 files changed, 246 insertions(+), 179 deletions(-) diff --git a/tests/diffusion/test_diffusion_micro_step_pipeline.py b/tests/diffusion/test_diffusion_micro_step_pipeline.py index 77fdf778b87..83a42c08e81 100644 --- a/tests/diffusion/test_diffusion_micro_step_pipeline.py +++ b/tests/diffusion/test_diffusion_micro_step_pipeline.py @@ -16,6 +16,7 @@ CachedRequestData, DiffusionSchedulerOutput, NewRequestData, + Rank0Layout, RankTask, ) from vllm_omni.diffusion.worker.diffusion_model_runner import DiffusionModelRunner @@ -63,9 +64,7 @@ def prepare_encode(self, state, **kwargs): del kwargs self.prepare_calls += 1 state.timesteps = [torch.tensor(float(i)) for i in range(self.num_steps)] - state.latents = torch.zeros((1,)) - # Provide a scheduler stub so the runner can build the per-chunk - # batched_ts from ``chunk.scheduler.timesteps[chunk.step_index]``. + state.latents = torch.zeros((1, 1, 1, 1, 1)) state.scheduler = SimpleNamespace(timesteps=list(state.timesteps)) return state @@ -87,7 +86,10 @@ def step_scheduler(self, state, noise_pred, **kwargs): def post_decode(self, state, **kwargs): del kwargs self.decode_calls += 1 - return DiffusionOutput(output=torch.tensor([state.step_index], dtype=torch.float32)) + # One batched decode covers all rows on this rank; output keeps the + # per-row layout so _merge_chunk_outputs can stitch the temporal axis. + b = state.latents.shape[0] if state.latents.ndim > 0 else 1 + return DiffusionOutput(output=torch.ones(b, 1, 1, 1, 1, dtype=torch.float32)) def prefetch_its(self, state, **kwargs): del state, kwargs @@ -152,6 +154,23 @@ def _make_runner(pp_size: int = 1, num_steps: int = 1): return runner +def _make_layout( + *, + n_finished: int = 0, + n_circulating: int = 0, + n_new: int = 0, + finished_idxs: list[int] | None = None, + new_idxs: list[int] | None = None, +) -> Rank0Layout: + return Rank0Layout( + n_finished=n_finished, + n_circulating=n_circulating, + n_new=n_new, + finished_idxs=finished_idxs or [], + new_idxs=new_idxs or [], + ) + + def _make_micro_scheduler_output( *, req=None, @@ -160,9 +179,11 @@ def _make_micro_scheduler_output( assignment=None, is_new: bool = True, finished_req_ids=None, + rank0_layout: Rank0Layout | None = None, ): if assignment is None: assignment = [RankTask(sched_req_id=sched_req_id, chunk_indices=[0])] + rank0_layouts = {sched_req_id: rank0_layout} if rank0_layout is not None else None if is_new and req is not None: new_reqs = [NewRequestData(sched_req_id=sched_req_id, req=req)] cached_reqs = CachedRequestData.make_empty() @@ -177,6 +198,7 @@ def _make_micro_scheduler_output( num_running_reqs=1, num_waiting_reqs=0, per_rank_assignment=assignment, + rank0_layouts=rank0_layouts, ) @@ -200,24 +222,30 @@ def test_completes_single_chunk_request(self, monkeypatch): out0 = DiffusionModelRunner.execute_micro_step( runner, - _make_micro_scheduler_output(req=req, step_id=0), + _make_micro_scheduler_output( + req=req, step_id=0, + assignment=[RankTask(sched_req_id="req-1", chunk_indices=[0])], + rank0_layout=_make_layout(n_new=1, new_idxs=[0]), + ), ) assert out0.req_id == "req-1" - assert out0.chunk_completion_map == {0: True} assert out0.finished is False assert "req-1" in runner.state_cache out1 = DiffusionModelRunner.execute_micro_step( runner, - _make_micro_scheduler_output(sched_req_id="req-1", step_id=1, assignment=[None], is_new=False), + _make_micro_scheduler_output( + sched_req_id="req-1", step_id=1, + assignment=[None], is_new=False, + rank0_layout=_make_layout(n_finished=1, finished_idxs=[0]), + ), ) assert out1.finished is True assert out1.result is not None - assert torch.equal(out1.result.output, torch.tensor([1.0])) + assert out1.result.output is not None assert "req-1" not in runner.state_cache assert runner.pipeline.prepare_calls == 1 - assert runner.pipeline.set_buffer_calls == 1 assert runner.pipeline.denoise_calls == 1 assert runner.pipeline.scheduler_calls == 1 assert runner.pipeline.decode_calls == 1 @@ -229,23 +257,30 @@ def test_completes_multi_chunk_request(self, monkeypatch): DiffusionModelRunner.execute_micro_step( runner, - _make_micro_scheduler_output(req=req, step_id=0), + _make_micro_scheduler_output( + req=req, step_id=0, + assignment=[RankTask(sched_req_id="req-1", chunk_indices=[0])], + rank0_layout=_make_layout(n_new=1, new_idxs=[0]), + ), ) out1 = DiffusionModelRunner.execute_micro_step( runner, _make_micro_scheduler_output( - sched_req_id="req-1", - step_id=1, + sched_req_id="req-1", step_id=1, assignment=[RankTask(sched_req_id="req-1", chunk_indices=[1])], is_new=False, + rank0_layout=_make_layout(n_finished=1, n_new=1, finished_idxs=[0], new_idxs=[1]), ), ) - assert out1.chunk_completion_map == {1: True} assert out1.finished is False out2 = DiffusionModelRunner.execute_micro_step( runner, - _make_micro_scheduler_output(sched_req_id="req-1", step_id=2, assignment=[None], is_new=False), + _make_micro_scheduler_output( + sched_req_id="req-1", step_id=2, + assignment=[None], is_new=False, + rank0_layout=_make_layout(n_finished=1, finished_idxs=[1]), + ), ) assert out2.finished is True assert out2.result is not None @@ -255,27 +290,73 @@ def test_completes_multi_chunk_request(self, monkeypatch): assert runner.pipeline.denoise_calls == 2 assert runner.pipeline.decode_calls == 2 - def test_idle_rank_returns_no_op(self, monkeypatch): + def test_re_admits_circulating_chunk(self, monkeypatch): runner = _make_runner(pp_size=1, num_steps=2) _patch_runtime(monkeypatch, runner) req = _make_micro_request(num_inference_steps=2, num_chunks=1) out0 = DiffusionModelRunner.execute_micro_step( runner, - _make_micro_scheduler_output(req=req, step_id=0), + _make_micro_scheduler_output( + req=req, step_id=0, + assignment=[RankTask(sched_req_id="req-1", chunk_indices=[0])], + rank0_layout=_make_layout(n_new=1, new_idxs=[0]), + ), ) - assert out0.chunk_completion_map == {0: False} + assert out0.finished is False out1 = DiffusionModelRunner.execute_micro_step( runner, - _make_micro_scheduler_output(sched_req_id="req-1", step_id=1, assignment=[None], is_new=False), + _make_micro_scheduler_output( + sched_req_id="req-1", step_id=1, + assignment=[RankTask(sched_req_id="req-1", chunk_indices=[0])], + is_new=False, + rank0_layout=_make_layout(n_circulating=1), + ), ) - assert out1.req_id == "req-1" - assert out1.chunk_completion_map is None assert out1.finished is False - assert runner.pipeline.denoise_calls == 1 + assert runner.pipeline.denoise_calls == 2 + + out2 = DiffusionModelRunner.execute_micro_step( + runner, + _make_micro_scheduler_output( + sched_req_id="req-1", step_id=2, + assignment=[None], is_new=False, + rank0_layout=_make_layout(n_finished=1, finished_idxs=[0]), + ), + ) + assert out2.finished is True + assert runner.pipeline.decode_calls == 1 + + def test_empty_layout_is_a_no_op(self, monkeypatch): + runner = _make_runner(pp_size=1, num_steps=1) + _patch_runtime(monkeypatch, runner) + req = _make_micro_request(num_inference_steps=1, num_chunks=1) - def test_interrupt_marks_chunk_as_aborted(self, monkeypatch): + DiffusionModelRunner.execute_micro_step( + runner, + _make_micro_scheduler_output( + req=req, step_id=0, + assignment=[RankTask(sched_req_id="req-1", chunk_indices=[0])], + rank0_layout=_make_layout(n_new=1, new_idxs=[0]), + ), + ) + denoise_calls_before = runner.pipeline.denoise_calls + + out = DiffusionModelRunner.execute_micro_step( + runner, + _make_micro_scheduler_output( + sched_req_id="req-1", step_id=1, + assignment=[None], is_new=False, + rank0_layout=_make_layout(), + ), + ) + assert out.req_id == "req-1" + assert out.finished is False + assert runner.pipeline.denoise_calls == denoise_calls_before + assert runner.pipeline.decode_calls == 0 + + def test_interrupt_marks_request_as_aborted(self, monkeypatch): runner = _make_runner(pp_size=1, num_steps=1) runner.pipeline = _InterruptingMicroStepPipeline(num_steps=1) _patch_runtime(monkeypatch, runner) @@ -283,7 +364,11 @@ def test_interrupt_marks_chunk_as_aborted(self, monkeypatch): out = DiffusionModelRunner.execute_micro_step( runner, - _make_micro_scheduler_output(req=req, step_id=0), + _make_micro_scheduler_output( + req=req, step_id=0, + assignment=[RankTask(sched_req_id="req-1", chunk_indices=[0])], + rank0_layout=_make_layout(n_new=1, new_idxs=[0]), + ), ) assert out.req_id == "req-1" assert out.result is not None @@ -319,29 +404,33 @@ def test_stamps_micro_step_wall_ns_on_rank0(self, monkeypatch): out = DiffusionModelRunner.execute_micro_step( runner, - _make_micro_scheduler_output(req=req, step_id=0), + _make_micro_scheduler_output( + req=req, step_id=0, + assignment=[RankTask(sched_req_id="req-1", chunk_indices=[0])], + rank0_layout=_make_layout(n_new=1, new_idxs=[0]), + ), ) assert out.micro_step_wall_ns is not None assert out.micro_step_wall_ns >= 0 - def test_batch_two_runs_two_tasks_in_one_micro_step(self, monkeypatch): + def test_batch_two_runs_one_fused_forward(self, monkeypatch): runner = _make_runner(pp_size=1, num_steps=1) _patch_runtime(monkeypatch, runner) req = _make_micro_request(num_inference_steps=1, num_chunks=2) - # μ-step 0: rank 0 admits chunks 0 and 1 in a single chunk-batch task (B=2). - assignment = [RankTask(sched_req_id="req-1", chunk_indices=[0, 1])] out = DiffusionModelRunner.execute_micro_step( runner, - _make_micro_scheduler_output(req=req, step_id=0, assignment=assignment), + _make_micro_scheduler_output( + req=req, step_id=0, + assignment=[RankTask(sched_req_id="req-1", chunk_indices=[0, 1])], + rank0_layout=_make_layout(n_new=2, new_idxs=[0, 1]), + ), ) - # Phase 4 fuses the two chunks into ONE batched denoise_step call; - # Phase 5 runs scheduler.step per chunk (multistep history is per-chunk). assert runner.pipeline.denoise_calls == 1 - assert runner.pipeline.scheduler_calls == 2 + assert runner.pipeline.scheduler_calls == 1 assert out.req_id == "req-1" - assert out.chunk_completion_map == {0: True, 1: True} + assert out.finished is False assert out.micro_step_wall_ns is not None def test_batch_two_decodes_both_chunks_when_drain_completes(self, monkeypatch): @@ -352,18 +441,22 @@ def test_batch_two_decodes_both_chunks_when_drain_completes(self, monkeypatch): DiffusionModelRunner.execute_micro_step( runner, _make_micro_scheduler_output( - req=req, - step_id=0, + req=req, step_id=0, assignment=[RankTask(sched_req_id="req-1", chunk_indices=[0, 1])], + rank0_layout=_make_layout(n_new=2, new_idxs=[0, 1]), ), ) out = DiffusionModelRunner.execute_micro_step( runner, - _make_micro_scheduler_output(sched_req_id="req-1", step_id=1, assignment=[None], is_new=False), + _make_micro_scheduler_output( + sched_req_id="req-1", step_id=1, + assignment=[None], is_new=False, + rank0_layout=_make_layout(n_finished=2, finished_idxs=[0, 1]), + ), ) assert out.finished is True - assert runner.pipeline.decode_calls == 2 + assert runner.pipeline.decode_calls == 1 assert "req-1" not in runner.state_cache @@ -377,7 +470,7 @@ class TestWorker: def test_delegates_to_model_runner(self): worker = object.__new__(DiffusionWorker) - expected = RunnerOutput(req_id="req-1", chunk_completion_map={0: False}) + expected = RunnerOutput(req_id="req-1") scheduler_output = SimpleNamespace( scheduled_new_reqs=[ SimpleNamespace(req=SimpleNamespace(sampling_params=SimpleNamespace(lora_request=None))) @@ -472,7 +565,7 @@ class TestExecutor: def test_passes_through_runner_output(self, mocker: MockerFixture): executor = object.__new__(MultiprocDiffusionExecutor) executor._ensure_open = lambda: None - expected = RunnerOutput(req_id="req-1", chunk_completion_map={0: True}) + expected = RunnerOutput(req_id="req-1", finished=True) rpc = mocker.Mock(return_value=expected) executor.collective_rpc = rpc diff --git a/tests/diffusion/test_diffusion_scheduler.py b/tests/diffusion/test_diffusion_scheduler.py index 7ba040d1da3..f754dd456f3 100644 --- a/tests/diffusion/test_diffusion_scheduler.py +++ b/tests/diffusion/test_diffusion_scheduler.py @@ -831,18 +831,16 @@ def _make_stream_request( def _make_stream_output( req_id: str, *, - chunk_idx: int = 0, - chunk_completed: bool = False, finished: bool = False, error: str | None = None, + micro_step_wall_ns: int | None = None, ): return SimpleNamespace( req_id=req_id, step_index=None, finished=finished, - chunk_idx=chunk_idx, - chunk_completed=chunk_completed, result=DiffusionOutput(output=None, error=error) if error is not None else None, + micro_step_wall_ns=micro_step_wall_ns, ) @@ -850,16 +848,24 @@ def _make_od_config(pp_size: int) -> SimpleNamespace: return SimpleNamespace(parallel_config=SimpleNamespace(pipeline_parallel_size=pp_size)) -def _ranks(sched_output) -> list[tuple[str, int] | None]: - """Compact view of per_rank_assignment for assertions.""" +def _ranks(sched_output) -> list[tuple[str, list[int]] | None]: if sched_output.per_rank_assignment is None: return [] return [ - (t.sched_req_id, t.chunk_idx) if t is not None else None + (t.sched_req_id, list(t.chunk_indices)) if t is not None else None for t in sched_output.per_rank_assignment ] +def _layout(sched_output, sched_req_id: str) -> tuple[int, int, int] | None: + if sched_output.rank0_layouts is None: + return None + layout = sched_output.rank0_layouts.get(sched_req_id) + if layout is None: + return None + return (layout.n_finished, layout.n_circulating, layout.n_new) + + class TestStreamBatchScheduler: def _make_scheduler(self, pp_size: int = 2) -> StreamBatchScheduler: sched = StreamBatchScheduler() @@ -868,34 +874,28 @@ def _make_scheduler(self, pp_size: int = 2) -> StreamBatchScheduler: def test_add_request_rejects_invalid_num_chunks(self) -> None: scheduler = self._make_scheduler() - request = _make_stream_request("bad-chunks", num_chunks=0) with pytest.raises(ValueError): - scheduler.add_request(request) + scheduler.add_request(_make_stream_request("bad-chunks", num_chunks=0)) def test_add_request_rejects_invalid_num_inference_steps(self) -> None: scheduler = self._make_scheduler() - request = _make_stream_request("bad-steps", num_inference_steps=0) with pytest.raises(ValueError): - scheduler.add_request(request) + scheduler.add_request(_make_stream_request("bad-steps", num_inference_steps=0)) - def test_pp1_single_chunk_single_step(self) -> None: scheduler = self._make_scheduler(pp_size=1) req_id = scheduler.add_request(_make_stream_request("a", num_inference_steps=1, num_chunks=1)) out0 = scheduler.schedule() assert _new_ids(out0) == [req_id] - assert _ranks(out0) == [(req_id, 0)] - assert scheduler.update_from_output( - out0, _make_stream_output(req_id, chunk_idx=0, chunk_completed=True) - ) == set() + assert _ranks(out0) == [(req_id, [0])] + assert _layout(out0, req_id) == (0, 0, 1) + assert scheduler.update_from_output(out0, _make_stream_output(req_id)) == set() - # nothing to admit; runner decodes chunk 0 and returns finished. out1 = scheduler.schedule() assert _ranks(out1) == [None] - finished = scheduler.update_from_output( - out1, _make_stream_output(req_id, finished=True) - ) + assert _layout(out1, req_id) == (1, 0, 0) + finished = scheduler.update_from_output(out1, _make_stream_output(req_id, finished=True)) assert finished == {req_id} assert scheduler.get_request_state(req_id).status == DiffusionRequestStatus.FINISHED_COMPLETED assert scheduler.has_requests() is False @@ -905,172 +905,142 @@ def test_pp1_single_chunk_multi_step_re_admits_same_chunk(self) -> None: req_id = scheduler.add_request(_make_stream_request("multi", num_inference_steps=3, num_chunks=1)) out0 = scheduler.schedule() - assert _ranks(out0) == [(req_id, 0)] - assert scheduler.update_from_output( - out0, _make_stream_output(req_id, chunk_idx=0, chunk_completed=False) - ) == set() + assert _ranks(out0) == [(req_id, [0])] + assert _layout(out0, req_id) == (0, 0, 1) + scheduler.update_from_output(out0, _make_stream_output(req_id)) out1 = scheduler.schedule() - assert _ranks(out1) == [(req_id, 0)] - assert scheduler.update_from_output( - out1, _make_stream_output(req_id, chunk_idx=0, chunk_completed=False) - ) == set() + assert _ranks(out1) == [(req_id, [0])] + assert _layout(out1, req_id) == (0, 1, 0) + scheduler.update_from_output(out1, _make_stream_output(req_id)) out2 = scheduler.schedule() - assert _ranks(out2) == [(req_id, 0)] - assert scheduler.update_from_output( - out2, _make_stream_output(req_id, chunk_idx=0, chunk_completed=True) - ) == set() + assert _ranks(out2) == [(req_id, [0])] + assert _layout(out2, req_id) == (0, 1, 0) + scheduler.update_from_output(out2, _make_stream_output(req_id)) - # nothing to admit; runner decodes and returns finished. out3 = scheduler.schedule() assert _ranks(out3) == [None] - finished = scheduler.update_from_output( - out3, _make_stream_output(req_id, finished=True) - ) - assert finished == {req_id} + assert _layout(out3, req_id) == (1, 0, 0) + assert scheduler.update_from_output(out3, _make_stream_output(req_id, finished=True)) == {req_id} def test_pp1_multi_chunk_admits_in_order(self) -> None: scheduler = self._make_scheduler(pp_size=1) req_id = scheduler.add_request(_make_stream_request("multi", num_inference_steps=1, num_chunks=2)) out0 = scheduler.schedule() - assert _ranks(out0) == [(req_id, 0)] - assert scheduler.update_from_output( - out0, _make_stream_output(req_id, chunk_idx=0, chunk_completed=True) - ) == set() + assert _ranks(out0) == [(req_id, [0])] + assert _layout(out0, req_id) == (0, 0, 1) + scheduler.update_from_output(out0, _make_stream_output(req_id)) out1 = scheduler.schedule() - assert _ranks(out1) == [(req_id, 1)] - assert scheduler.update_from_output( - out1, _make_stream_output(req_id, chunk_idx=1, chunk_completed=True) - ) == set() + assert _ranks(out1) == [(req_id, [1])] + assert _layout(out1, req_id) == (1, 0, 1) + scheduler.update_from_output(out1, _make_stream_output(req_id)) out2 = scheduler.schedule() assert _ranks(out2) == [None] - finished = scheduler.update_from_output( - out2, _make_stream_output(req_id, finished=True) - ) - assert finished == {req_id} + assert _layout(out2, req_id) == (1, 0, 0) + assert scheduler.update_from_output(out2, _make_stream_output(req_id, finished=True)) == {req_id} def test_pp2_pipelined_chunks_advance_through_ranks(self) -> None: scheduler = self._make_scheduler(pp_size=2) req_id = scheduler.add_request(_make_stream_request("pp2", num_inference_steps=1, num_chunks=2)) out0 = scheduler.schedule() - assert _ranks(out0) == [(req_id, 0), None] - assert scheduler.update_from_output( - out0, _make_stream_output(req_id, chunk_idx=0, chunk_completed=True) - ) == set() + assert _ranks(out0) == [(req_id, [0]), None] + assert _layout(out0, req_id) == (0, 0, 1) + scheduler.update_from_output(out0, _make_stream_output(req_id)) out1 = scheduler.schedule() - assert _ranks(out1) == [(req_id, 1), (req_id, 0)] - assert scheduler.update_from_output( - out1, _make_stream_output(req_id, chunk_idx=1, chunk_completed=True) - ) == set() + assert _ranks(out1) == [(req_id, [1]), (req_id, [0])] + assert _layout(out1, req_id) == (0, 0, 1) + scheduler.update_from_output(out1, _make_stream_output(req_id)) out2 = scheduler.schedule() - assert _ranks(out2) == [None, (req_id, 1)] - assert scheduler.update_from_output( - out2, _make_stream_output(req_id, chunk_idx=None, finished=False) - ) == set() + assert _ranks(out2) == [None, (req_id, [1])] + assert _layout(out2, req_id) == (1, 0, 0) + scheduler.update_from_output(out2, _make_stream_output(req_id)) out3 = scheduler.schedule() assert _ranks(out3) == [None, None] - finished = scheduler.update_from_output( - out3, _make_stream_output(req_id, finished=True) - ) - assert finished == {req_id} + assert _layout(out3, req_id) == (1, 0, 0) + assert scheduler.update_from_output(out3, _make_stream_output(req_id, finished=True)) == {req_id} def test_pp3_three_chunks_two_steps_each(self) -> None: scheduler = self._make_scheduler(pp_size=3) req_id = scheduler.add_request(_make_stream_request("pp3", num_inference_steps=2, num_chunks=3)) out0 = scheduler.schedule() - assert _ranks(out0) == [(req_id, 0), None, None] - assert scheduler.update_from_output( - out0, _make_stream_output(req_id, chunk_idx=0, chunk_completed=False) - ) == set() + assert _ranks(out0) == [(req_id, [0]), None, None] + assert _layout(out0, req_id) == (0, 0, 1) + scheduler.update_from_output(out0, _make_stream_output(req_id)) out1 = scheduler.schedule() - assert _ranks(out1) == [(req_id, 1), (req_id, 0), None] - assert scheduler.update_from_output( - out1, _make_stream_output(req_id, chunk_idx=1, chunk_completed=False) - ) == set() + assert _ranks(out1) == [(req_id, [1]), (req_id, [0]), None] + assert _layout(out1, req_id) == (0, 0, 1) + scheduler.update_from_output(out1, _make_stream_output(req_id)) out2 = scheduler.schedule() - assert _ranks(out2) == [(req_id, 2), (req_id, 1), (req_id, 0)] - assert scheduler.update_from_output( - out2, _make_stream_output(req_id, chunk_idx=2, chunk_completed=False) - ) == set() + assert _ranks(out2) == [(req_id, [2]), (req_id, [1]), (req_id, [0])] + assert _layout(out2, req_id) == (0, 0, 1) + scheduler.update_from_output(out2, _make_stream_output(req_id)) out3 = scheduler.schedule() - assert _ranks(out3) == [(req_id, 0), (req_id, 2), (req_id, 1)] - assert scheduler.update_from_output( - out3, _make_stream_output(req_id, chunk_idx=0, chunk_completed=True) - ) == set() + assert _ranks(out3) == [(req_id, [0]), (req_id, [2]), (req_id, [1])] + assert _layout(out3, req_id) == (0, 1, 0) + scheduler.update_from_output(out3, _make_stream_output(req_id)) out4 = scheduler.schedule() - assert _ranks(out4) == [(req_id, 1), (req_id, 0), (req_id, 2)] - assert scheduler.update_from_output( - out4, _make_stream_output(req_id, chunk_idx=1, chunk_completed=True) - ) == set() + assert _ranks(out4) == [(req_id, [1]), (req_id, [0]), (req_id, [2])] + assert _layout(out4, req_id) == (0, 1, 0) + scheduler.update_from_output(out4, _make_stream_output(req_id)) out5 = scheduler.schedule() - assert _ranks(out5) == [(req_id, 2), (req_id, 1), (req_id, 0)] - assert scheduler.update_from_output( - out5, _make_stream_output(req_id, chunk_idx=2, chunk_completed=True) - ) == set() + assert _ranks(out5) == [(req_id, [2]), (req_id, [1]), (req_id, [0])] + assert _layout(out5, req_id) == (0, 1, 0) + scheduler.update_from_output(out5, _make_stream_output(req_id)) out6 = scheduler.schedule() - assert _ranks(out6) == [None, (req_id, 2), (req_id, 1)] - assert scheduler.update_from_output( - out6, _make_stream_output(req_id, chunk_idx=None) - ) == set() + assert _ranks(out6) == [None, (req_id, [2]), (req_id, [1])] + assert _layout(out6, req_id) == (1, 0, 0) + scheduler.update_from_output(out6, _make_stream_output(req_id)) out7 = scheduler.schedule() - assert _ranks(out7) == [None, None, (req_id, 2)] - assert scheduler.update_from_output( - out7, _make_stream_output(req_id, chunk_idx=None) - ) == set() + assert _ranks(out7) == [None, None, (req_id, [2])] + assert _layout(out7, req_id) == (1, 0, 0) + scheduler.update_from_output(out7, _make_stream_output(req_id)) - # runner decodes and reports finished. out8 = scheduler.schedule() assert _ranks(out8) == [None, None, None] - finished = scheduler.update_from_output( - out8, _make_stream_output(req_id, finished=True) - ) - - assert finished == {req_id} + assert _layout(out8, req_id) == (1, 0, 0) + assert scheduler.update_from_output(out8, _make_stream_output(req_id, finished=True)) == {req_id} assert scheduler.has_requests() is False - def test_re_admission_takes_priority_over_new_chunk(self) -> None: + def test_returning_chunk_leads_fresh_admits_in_fifo(self) -> None: + # Re-admit prepends; new admits append. Order: [returning..., new...]. scheduler = self._make_scheduler(pp_size=1) req_id = scheduler.add_request(_make_stream_request("prio", num_inference_steps=2, num_chunks=2)) out0 = scheduler.schedule() - assert _ranks(out0) == [(req_id, 0)] - scheduler.update_from_output(out0, _make_stream_output(req_id, chunk_idx=0, chunk_completed=False)) + assert _ranks(out0) == [(req_id, [0])] + assert _layout(out0, req_id) == (0, 0, 1) + scheduler.update_from_output(out0, _make_stream_output(req_id)) out1 = scheduler.schedule() - assert _ranks(out1) == [(req_id, 0)] + assert _ranks(out1) == [(req_id, [0, 1])] + assert _layout(out1, req_id) == (0, 1, 1) def test_chunk_progress_cleared_after_request_finishes(self) -> None: scheduler = self._make_scheduler(pp_size=1) req_id = scheduler.add_request(_make_stream_request("cleanup", num_inference_steps=1, num_chunks=1)) out0 = scheduler.schedule() - scheduler.update_from_output( - out0, _make_stream_output(req_id, chunk_idx=0, chunk_completed=True) - ) - # runner decodes and reports finished. + scheduler.update_from_output(out0, _make_stream_output(req_id)) out1 = scheduler.schedule() - scheduler.update_from_output( - out1, _make_stream_output(req_id, finished=True) - ) - - scheduler.pop_request_state(req_id) + scheduler.update_from_output(out1, _make_stream_output(req_id, finished=True)) + scheduler.pop_request_state(req_id) assert req_id not in scheduler._chunk_progress assert scheduler.has_requests() is False @@ -1078,6 +1048,7 @@ def test_schedule_with_no_requests_emits_no_assignment(self) -> None: scheduler = self._make_scheduler(pp_size=2) out = scheduler.schedule() assert out.per_rank_assignment is None + assert out.rank0_layouts is None assert out.scheduled_req_ids == [] def test_fifo_two_requests(self) -> None: @@ -1087,17 +1058,16 @@ def test_fifo_two_requests(self) -> None: out0 = scheduler.schedule() assert _new_ids(out0) == [req_a] - assert _ranks(out0) == [(req_a, 0)] - scheduler.update_from_output(out0, _make_stream_output(req_a, chunk_idx=0, chunk_completed=True)) + assert _ranks(out0) == [(req_a, [0])] + scheduler.update_from_output(out0, _make_stream_output(req_a)) - # B still waiting until A finishes. out1 = scheduler.schedule() assert _new_ids(out1) == [] - scheduler.update_from_output(out1, _make_stream_output(req_a, chunk_idx=None, finished=True)) + scheduler.update_from_output(out1, _make_stream_output(req_a, finished=True)) out2 = scheduler.schedule() assert _new_ids(out2) == [req_b] - assert _ranks(out2) == [(req_b, 0)] + assert _ranks(out2) == [(req_b, [0])] def test_has_requests_state_transition(self) -> None: scheduler = self._make_scheduler(pp_size=1) @@ -1108,13 +1078,10 @@ def test_has_requests_state_transition(self) -> None: out0 = scheduler.schedule() assert scheduler.has_requests() is True - scheduler.update_from_output(out0, _make_stream_output(req_id, chunk_idx=0, chunk_completed=True)) + scheduler.update_from_output(out0, _make_stream_output(req_id)) out1 = scheduler.schedule() - finished = scheduler.update_from_output( - out1, _make_stream_output(req_id, chunk_idx=None, finished=True) - ) - assert finished == {req_id} + assert scheduler.update_from_output(out1, _make_stream_output(req_id, finished=True)) == {req_id} assert scheduler.has_requests() is False def test_abort_waiting_and_running_requests(self) -> None: @@ -1137,9 +1104,7 @@ def test_error_output_marks_finished_error(self) -> None: req_id = scheduler.add_request(_make_stream_request("err", num_inference_steps=2, num_chunks=1)) out = scheduler.schedule() - finished = scheduler.update_from_output( - out, _make_stream_output(req_id, chunk_idx=0, error="worker failed") - ) + finished = scheduler.update_from_output(out, _make_stream_output(req_id, error="worker failed")) assert finished == {req_id} state = scheduler.get_request_state(req_id) @@ -1152,27 +1117,36 @@ def test_preempt_request_preserves_chunk_progress(self) -> None: req_id = scheduler.add_request(_make_stream_request("preempt", num_inference_steps=2, num_chunks=2)) out0 = scheduler.schedule() - assert _ranks(out0) == [(req_id, 0), None] - scheduler.update_from_output(out0, _make_stream_output(req_id, chunk_idx=0, chunk_completed=False)) + assert _ranks(out0) == [(req_id, [0]), None] + scheduler.update_from_output(out0, _make_stream_output(req_id)) out1 = scheduler.schedule() - assert _ranks(out1) == [(req_id, 1), (req_id, 0)] - scheduler.update_from_output(out1, _make_stream_output(req_id, chunk_idx=1, chunk_completed=False)) + assert _ranks(out1) == [(req_id, [1]), (req_id, [0])] + scheduler.update_from_output(out1, _make_stream_output(req_id)) before = scheduler._chunk_progress[req_id] assert before.chunks_admitted == 2 - in_flight_before = {c.chunk_idx: (c.is_active, c.is_completed) for c in before.in_flight} - assert in_flight_before == {0: (False, False), 1: (True, False)} + snapshot = [[c.chunk_idx for c in q] for q in before.chunks_at] + assert snapshot == [[1], [0]] assert scheduler.preempt_request(req_id) is True assert scheduler.get_request_state(req_id).status == DiffusionRequestStatus.PREEMPTED after = scheduler._chunk_progress[req_id] assert after.chunks_admitted == 2 - in_flight_after = {c.chunk_idx: (c.is_active, c.is_completed) for c in after.in_flight} - assert in_flight_after == in_flight_before + assert [[c.chunk_idx for c in q] for q in after.chunks_at] == snapshot - out2 = scheduler.schedule() - assert _new_ids(out2) == [] # not a fresh promotion - assert _ranks(out2) == [(req_id, 0), (req_id, 1)] - assert scheduler._chunk_progress[req_id].chunks_admitted == 2 + def test_b_admission(self) -> None: + scheduler = self._make_scheduler(pp_size=1) + req_id = scheduler.add_request(_make_stream_request("b", num_inference_steps=2, num_chunks=4)) + scheduler._slo.register(req_id, slo_fps=30.0, max_batch=4, ema_alpha=0.3, chunk_frames=1) + scheduler._slo._reqs[req_id].batch_size = 2 + + out0 = scheduler.schedule() + assert _ranks(out0) == [(req_id, [0, 1])] + assert _layout(out0, req_id) == (0, 0, 2) + scheduler.update_from_output(out0, _make_stream_output(req_id)) + + out1 = scheduler.schedule() + assert _ranks(out1) == [(req_id, [0, 1, 2, 3])] + assert _layout(out1, req_id) == (0, 2, 2) From 735eb5437d4882ba6999428d2cfd8e7b5189d406 Mon Sep 17 00:00:00 2001 From: Miguel Vieira Pereira Date: Tue, 19 May 2026 14:40:17 +0000 Subject: [PATCH 39/53] Implement video continuation Signed-off-by: Miguel Vieira Pereira --- .../lingbot_world_fast/openai_client.py | 113 +++++++---- vllm_omni/diffusion/data.py | 3 + .../pipeline_lingbot_world_fast.py | 191 +++++++++++++----- .../state_lingbot_world_fast.py | 69 +++++-- vllm_omni/entrypoints/cli/serve.py | 2 +- vllm_omni/entrypoints/openai/api_server.py | 4 +- .../openai/realtime/world/camera_serving.py | 1 - 7 files changed, 273 insertions(+), 110 deletions(-) diff --git a/examples/online_serving/lingbot_world_fast/openai_client.py b/examples/online_serving/lingbot_world_fast/openai_client.py index e7cc22c0a9c..01678db4ff3 100644 --- a/examples/online_serving/lingbot_world_fast/openai_client.py +++ b/examples/online_serving/lingbot_world_fast/openai_client.py @@ -46,7 +46,7 @@ def _unpack(data): return msgpack_numpy.unpackb(data) -def _load_image(path: str) -> np.ndarray: +def _load_image(path: str | None) -> np.ndarray | None: image = PIL.Image.open(path).convert("RGB") return np.asarray(image) @@ -61,45 +61,73 @@ def _load_camera(camera_dir: str) -> dict: def generate_video(args: Namespace) -> np.ndarray: """Send a single inference request and return the generated frames.""" image = _load_image(args.image) - camera = _load_camera(args.camera_path) + full_camera = _load_camera(args.camera_path) + + extra_body = { + "height": args.height, + "width": args.width, + "num_frames": args.num_frames, + "fps": args.fps, + "session_id": args.session_id, + } + + video = [] + starting_frame = 0 + + for i in range(args.num_calls): + camera = { + "poses": full_camera["poses"][starting_frame : starting_frame + args.num_frames], + "intrinsics": full_camera["intrinsics"][starting_frame : starting_frame + args.num_frames ], + } + + if i == 0: + extra_body["num_frames"] = (args.num_frames // 4) * 4 + 1 + else: + extra_body["num_frames"] = (args.num_frames // 4) * 4 + + obs: dict = {"prompt": args.prompt, "camera": camera, "extra_body": extra_body} + if i == 0: + obs["image"] = image - extra_body = {"height": args.height, "width": args.width, "num_frames": args.num_frames, "fps": args.fps} - - obs: dict = {"prompt": args.prompt, "image": image, "camera": camera, "extra_body": extra_body} - - if args.session_id is not None: obs["session_id"] = args.session_id - endpoint = f"{args.server.rstrip('/')}/v1/realtime/world/camera" - print(f"Connecting to {endpoint} ...") - - with ws_sync.connect(endpoint, max_size=None, ping_interval=None, ping_timeout=None) as ws: - # 1. Server sends CameraServerConfig on connect. - server_config = _unpack(ws.recv()) - - # 2. Send obs. - print( - f"Sending obs (image={image.shape}, " - f"poses={camera['poses'].shape}, intrinsics={camera['intrinsics'].shape})..." - ) - ws.send(_pack(obs)) - - # 3. Receive generated frames. - chunks: list[np.ndarray] = [] - total = None - while total is None or len(chunks) < total: - msg = _unpack(ws.recv()) - if isinstance(msg, dict) and msg.get("type") == "error": - raise RuntimeError(f"Server error: {msg.get('message')}") - if not isinstance(msg, dict) or msg.get("type") != "frame": - continue # ignore anything unexpected - total = msg["total"] - chunks.append(msg["video"]) - print(f" received chunk {msg['index'] + 1}/{total}") - - video = np.concatenate(chunks, axis=0) - - return video + endpoint = f"{args.server.rstrip('/')}/v1/realtime/world/camera" + print(f"Connecting to {endpoint} ...") + + with ws_sync.connect(endpoint, max_size=None, ping_interval=None, ping_timeout=None) as ws: + # 1. Server sends CameraServerConfig on connect. + server_config = _unpack(ws.recv()) + + # 2. Send obs. + print( + f"Sending obs ({'image=' + str(image.shape) if obs.get('image', None) is not None else 'None'}, " + f"poses={camera['poses'].shape}, intrinsics={camera['intrinsics'].shape})..." + ) + ws.send(_pack(obs)) + + # 3. Receive generated frames. + chunks: list[np.ndarray] = [] + total = None + while total is None or len(chunks) < total: + msg = _unpack(ws.recv()) + if isinstance(msg, dict) and msg.get("type") == "error": + raise RuntimeError(f"Server error: {msg.get('message')}") + if not isinstance(msg, dict) or msg.get("type") != "frame": + continue # ignore anything unexpected + total = msg["total"] + chunks.append(msg["video"]) + print(f" received chunk {msg['index'] + 1}/{total}") + + clip = np.concatenate(chunks, axis=0) + # The first chunk of frames returned was used to condition the video continuation but they are not useful + if i != 0: + clip = clip[3:] + for frame in clip: + video.append(frame) + + starting_frame += args.num_frames + + return video def main(): @@ -133,7 +161,8 @@ def main(): parser.add_argument("--width", type=int, default=832) parser.add_argument("--height", type=int, default=480) parser.add_argument("--fps", type=int, default=16) - parser.add_argument("--num-frames", type=int, default=81) + parser.add_argument("--num-frames", type=int, default=24) + parser.add_argument("--num-calls", type=int, default=2) args = parser.parse_args() frames = generate_video(args) @@ -142,7 +171,13 @@ def main(): output_path.parent.mkdir(parents=True, exist_ok=True) print(frames.__class__) - print(frames[0].__class__) + print(len(frames)) + print(frames[0].shape) + + for i, frame in enumerate(frames): + tmp_frame = (frame * 255).astype(np.uint8) + im = PIL.Image.fromarray(tmp_frame) + im.save(f"{args.output[:-4]}_{i}.png") export_to_video(frames, str(output_path), fps=args.fps) print(f"Saved generated video to {output_path}") diff --git a/vllm_omni/diffusion/data.py b/vllm_omni/diffusion/data.py index c45ee63e7f1..33648a7126b 100644 --- a/vllm_omni/diffusion/data.py +++ b/vllm_omni/diffusion/data.py @@ -867,6 +867,9 @@ def enrich_config(self) -> None: self.model_class_name = "WanS2VPipeline" self.tf_model_config = TransformerConfig() self.update_multimodal_support() + elif self.model_class_name == "LingbotWorldFastPipeline": + self.tf_config_dict = get_hf_file_to_dict("config.json", self.model) + self.tf_model_config = TransformerConfig.from_dict(self.tf_config_dict) elif architectures and len(architectures) == 1: self.model_class_name = architectures[0] else: diff --git a/vllm_omni/diffusion/models/lingbot_world_fast/pipeline_lingbot_world_fast.py b/vllm_omni/diffusion/models/lingbot_world_fast/pipeline_lingbot_world_fast.py index 4a31f6a23ed..2500cdbc540 100644 --- a/vllm_omni/diffusion/models/lingbot_world_fast/pipeline_lingbot_world_fast.py +++ b/vllm_omni/diffusion/models/lingbot_world_fast/pipeline_lingbot_world_fast.py @@ -169,14 +169,29 @@ def forward( prompt = req.prompts[0].get("prompt") multi_modal_data = req.prompts[0].get("multi_modal_data", {}) - # Always reset: Lingbot Fast does not support video continuation - self.state.reset() + session_id = str(req.sampling_params.extra_args.get("session_id") or None) + extension = True + + if self.state.session_id is None or self.state.session_id != session_id: + self.state.reset() + self.state.session_id = session_id + extension = False + else: + extension = True camera = multi_modal_data.get("camera", None) if camera is None: self.od_config.model raise ValueError("A path to camera positions must be passed to this model through action_path.") + if extension: + assert multi_modal_data.get("image") is None, ( + "image must not be provided on extension calls; it is only used on the first call of a session" + ) + assert self.model.config.local_attn_size == -1, ( + "video extension requires the model to be configured with local_attn_size == -1" + ) + batch_size = 1 num_frames = req.sampling_params.num_frames # In order to generate something num_frames must be at least 5 since it expects 4*n + 1 as input @@ -184,41 +199,62 @@ def forward( num_frames = max(25, num_frames) c2ws = camera.get("poses") + chunk_size = CONFIG["chunk_size"] + max_area = CONFIG["max_area"] - len_c2ws = ((len(c2ws) - 1) // 4) * 4 + 1 - num_frames = ((num_frames - 1) // 4) * 4 + 1 - num_frames = min(num_frames, len_c2ws) + # Fresh: 4N+1 pixel frames → N+1 latents, the first slot is the anchor. + # Extension: 4N pixel frames → N regular latents, no anchor. + if extension: + len_c2ws = (len(c2ws) // 4) * 4 + num_frames = (num_frames // 4) * 4 + num_frames = min(num_frames, len_c2ws) + new_lat_f = num_frames // 4 + else: + len_c2ws = ((len(c2ws) - 1) // 4) * 4 + 1 + num_frames = ((num_frames - 1) // 4) * 4 + 1 + num_frames = min(num_frames, len_c2ws) + new_lat_f = (num_frames - 1) // 4 + 1 c2ws = c2ws[:num_frames] - # preprocess - img = multi_modal_data.get("image") - img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(self.device) - - max_area = CONFIG["max_area"] - chunk_size = CONFIG["chunk_size"] - - h, w = img.shape[1:] - aspect_ratio = h / w - lat_h = round(np.sqrt(max_area * aspect_ratio) // self.vae_stride[1] // self.patch_size[1] * self.patch_size[1]) - lat_w = round(np.sqrt(max_area / aspect_ratio) // self.vae_stride[2] // self.patch_size[2] * self.patch_size[2]) - h = lat_h * self.vae_stride[1] - w = lat_w * self.vae_stride[2] - lat_f = (num_frames - 1) // self.vae_stride[0] + 1 - lat_f = int(lat_f - (lat_f % chunk_size)) - lat_f = max(lat_f, 1) - F = (lat_f - 1) * 4 + 1 + # 1. Derive spatial shape: from the input image on fresh start, from cache on extension. + if not extension: + img = multi_modal_data.get("image") + img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(self.device) + h, w = img.shape[1:] + aspect_ratio = h / w + lat_h = round( + np.sqrt(max_area * aspect_ratio) // self.vae_stride[1] // self.patch_size[1] * self.patch_size[1] + ) + lat_w = round( + np.sqrt(max_area / aspect_ratio) // self.vae_stride[2] // self.patch_size[2] * self.patch_size[2] + ) + h = lat_h * self.vae_stride[1] + w = lat_w * self.vae_stride[2] + else: + img = None + h, w, lat_h, lat_w = self.state.h, self.state.w, self.state.lat_h, self.state.lat_w + + new_lat_f = int(new_lat_f - (new_lat_f % chunk_size)) + new_lat_f = max(new_lat_f, 1) max_seq_len = chunk_size * lat_h * lat_w // (self.patch_size[1] * self.patch_size[2]) max_seq_len = int(math.ceil(max_seq_len / self.sp_size)) * self.sp_size seed = random.randint(0, sys.maxsize) seed_g = torch.Generator(device=self.device) seed_g.manual_seed(seed) - noise = torch.randn(16, lat_f, lat_h, lat_w, dtype=torch.float32, generator=seed_g, device=self.device) - - msk = torch.ones(1, F, lat_h, lat_w, device=self.device) - msk[:, 1:] = 0 - msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1) - msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w) - msk = msk.transpose(1, 2)[0] + noise = torch.randn(16, new_lat_f, lat_h, lat_w, dtype=torch.float32, generator=seed_g, device=self.device) + + # Fresh: msk[0] = 1 (anchor) and the rest = 0, replicated into 4 channels grouped + # by latent frame to give shape [4, new_lat_f, lat_h, lat_w]. + # Extension: no anchor, all zeros, already in the [4, new_lat_f, ...] layout. + if not extension: + F = (new_lat_f - 1) * 4 + 1 + msk = torch.zeros(1, F, lat_h, lat_w, device=self.device) + msk[:, 0] = 1 + msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1) + msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w) + msk = msk.transpose(1, 2)[0] + else: + msk = torch.zeros(4, new_lat_f, lat_h, lat_w, device=self.device) # 2. Prepare timesteps self.scheduler.set_timesteps(self.num_train_timesteps, shift=CONFIG["sample_shift"]) @@ -236,9 +272,9 @@ def forward( ) Ks = Ks[0] + # One target pose per output latent — must match the f= in the rearrange below. len_c2ws = len(c2ws) - len_c2ws_ = int((len_c2ws - 1) // 4) + 1 - len_c2ws_ = int(len_c2ws_ - (len_c2ws_ % chunk_size)) + len_c2ws_ = new_lat_f c2ws_infer = interpolate_camera_poses( src_indices=np.linspace(0, len_c2ws - 1, len_c2ws), src_rot_mat=c2ws[:, :3, :3], @@ -259,21 +295,30 @@ def forward( c2=int(w // lat_w), ) c2ws_plucker_emb = c2ws_plucker_emb[None, ...] # [b, f*h*w, c] - c2ws_plucker_emb = rearrange(c2ws_plucker_emb, "b (f h w) c -> b c f h w", f=lat_f, h=lat_h, w=lat_w).to( + c2ws_plucker_emb = rearrange(c2ws_plucker_emb, "b (f h w) c -> b c f h w", f=new_lat_f, h=lat_h, w=lat_w).to( self.target_dtype ) - y = self.vae.encode( - [ - torch.concat( - [ - torch.nn.functional.interpolate(img[None].cpu(), size=(h, w), mode="bicubic").transpose(0, 1), - torch.zeros(3, F - 1, h, w), - ], - dim=1, - ).to(self.device) - ] - )[0] + # Fresh: pixels = [anchor_image, zeros...] of shape [3, 4N+1, h, w]. + # VAE produces N+1 latents; latent[0] is the anchor encoding. + # Extension: pixels = zeros [3, 4N+1, h, w]. VAE produces N+1 latents, + # of which latent[0] is the special "1-frame init" encoding + # (biased differently than the regular 4-frame-group latents). + # Slice it off so the N conditioning slots are all regular — + # this drops a CONDITIONING slot, not an output latent. + if not extension: + F = (new_lat_f - 1) * 4 + 1 + pixels = torch.concat( + [ + torch.nn.functional.interpolate(img[None].cpu(), size=(h, w), mode="bicubic").transpose(0, 1), + torch.zeros(3, F - 1, h, w), + ], + dim=1, + ).to(self.device) + y = self.vae.encode([pixels])[0] + else: + pixels = torch.zeros(3, 4 * new_lat_f + 1, h, w, device=self.device) + y = self.vae.encode([pixels])[0][:, 1:] y = torch.concat([msk, y]) @contextmanager @@ -282,17 +327,34 @@ def noop_no_sync(): no_sync_model = getattr(self.model, "no_sync", noop_no_sync) - # Initialize KV cache to all zeros + # Initialize (fresh) or grow (extension) the KV cache. Cross-attn cache is + # left untouched on extension so text-context k/v computed on the first call + # are reused via crossattn_cache[i]["is_init"] == True. model_args = self.model.config transformer_dtype = self.target_dtype frame_seqlen = int(noise.shape[-2] * noise.shape[-1] // 4) - kv_size = frame_seqlen * lat_f + extra_kv_size = frame_seqlen * new_lat_f head_dim = model_args.dim // model_args.num_heads local_num_heads = model_args.num_heads // self.sp_size - self.state.create_kv_caches( - batch_size, transformer_dtype, self.device, kv_size, model_args.num_layers, local_num_heads, head_dim - ) + if not extension: + self.state.create_kv_caches( + batch_size, + transformer_dtype, + self.device, + extra_kv_size, + model_args.num_layers, + local_num_heads, + head_dim, + ) + else: + self.state.extend_kv_caches(extra_kv_size) + + # Total cache size after this call, used both as the per-query attention + # window and as the absolute-token offset base for the chunk loop. + prev_lat_f = self.state.current_lat_f + total_kv_size = frame_seqlen * (prev_lat_f + new_lat_f) + start_token_offset = prev_lat_f * frame_seqlen # evaluation mode with ( @@ -325,8 +387,8 @@ def noop_no_sync(): "local_end_index": self.state.local_end_index, "global_end_index": self.state.global_end_index, "crossattn_cache": self.state.get_crossattn_caches(), - "current_start": chunk_id * chunk_size * frame_seqlen, - "max_attention_size": kv_size, + "current_start": start_token_offset + chunk_id * chunk_size * frame_seqlen, + "max_attention_size": total_kv_size, } for timestep_idx in range(len(timesteps)): @@ -363,11 +425,38 @@ def noop_no_sync(): pred_latent_chunks = torch.cat(pred_latent_chunks, dim=1) if self.device.index == 0: - videos = self.vae.decode([pred_latent_chunks]) + # Wan VAE decode() calls clear_cache() internally, so the very + # first latent always runs the i==0 path (no temporal upsample, + # single-frame output) and leaves feat_map polluted with that + # bias. The decoder's stacked temporal-causal layers also need + # ~2 latents of streaming context before deeper feat_map slots + # match a true mid-stream decode. On extension, prepend the + # prior chunk's last 2 latents so warmup_0 absorbs the i==0 + # bias and warmup_1 fully primes the cache. Then discard the + # 4*K - 3 leading pixels (re-decodes of already-shown frames). + if extension and self.state.last_decoded_latent is not None: + warmup = self.state.last_decoded_latent.to(pred_latent_chunks.device, pred_latent_chunks.dtype) + k = warmup.shape[1] + drop = 4 * k - 3 + to_decode = torch.cat([warmup, pred_latent_chunks], dim=1) + videos = self.vae.decode([to_decode]) + videos = [v[:, drop:] for v in videos] + else: + videos = self.vae.decode([pred_latent_chunks]) + + self.state.last_decoded_latent = pred_latent_chunks[:, -2:].detach().clone() if dist.is_initialized(): dist.barrier() + if not extension: + self.state.h = h + self.state.w = w + self.state.lat_h = lat_h + self.state.lat_w = lat_w + self.state.frame_seqlen = frame_seqlen + self.state.advance(new_lat_f) + return DiffusionOutput(output=videos[0]) def load_weights(self, weights): diff --git a/vllm_omni/diffusion/models/lingbot_world_fast/state_lingbot_world_fast.py b/vllm_omni/diffusion/models/lingbot_world_fast/state_lingbot_world_fast.py index 4b62b190ed5..e2ae2174e6a 100644 --- a/vllm_omni/diffusion/models/lingbot_world_fast/state_lingbot_world_fast.py +++ b/vllm_omni/diffusion/models/lingbot_world_fast/state_lingbot_world_fast.py @@ -8,7 +8,6 @@ import logging from enum import IntEnum -import numpy as np import torch logger = logging.getLogger(__name__) @@ -43,22 +42,29 @@ def reset(self) -> None: self.local_end_index: list[torch.Tensor] | None = None self.global_end_index: list[torch.Tensor] | None = None - def should_reset(self, text_tokens: torch.Tensor | None, num_video_frames: int, local_attn_size: int) -> bool: - """Determine if state should be reset before this forward().""" - # NOTE: after accumulate_frames, num_video_frames is the accumulated T - # (1 for first call, 4 for subsequent). Only reset on true single-frame - # which happens when the stitched_buffer was cleared externally. - if num_video_frames == 1 and self.call_count > 1: - logger.info("single frame input after first call, resetting") - return True - - if local_attn_size != -1 and self.current_start_frame >= local_attn_size: - logger.info( - "current_start_frame %d >= local_attn_size %d, resetting", self.current_start_frame, local_attn_size - ) - return True - - return False + self.is_initialized: bool = False + self.current_lat_f: int = 0 + self.session_id: str | None = None + + self.batch_size: int | None = None + self.num_layers: int | None = None + self.num_heads: int | None = None + self.head_dim: int | None = None + + # Shape constants captured on the first call of a session and reused + # on extension calls, where multi_modal_data["image"] is absent. + self.h: int | None = None + self.w: int | None = None + self.lat_h: int | None = None + self.lat_w: int | None = None + self.frame_seqlen: int | None = None + + # Last few latents emitted by the diffusion loop on the previous call. + # Prepended to pred_latent_chunks on extension so the Wan VAE decoder's + # stacked temporal feat_maps are fully warmed before the first NEW + # latent is decoded. The decoder's temporal receptive field spans + # ~2 latents, so we cache the last 2. + self.last_decoded_latent: torch.Tensor | None = None # ------------------------------------------------------------------ # KV cache management @@ -74,6 +80,11 @@ def create_kv_caches( num_heads: int, head_dim: int, ) -> None: + self.batch_size = batch_size + self.num_layers = num_layers + self.num_heads = num_heads + self.head_dim = head_dim + """Initialize empty KV caches and cross-attention caches.""" self.kv_cache = [ torch.zeros(2, batch_size, kv_size, num_heads, head_dim, dtype=dtype, device=device) @@ -85,6 +96,27 @@ def create_kv_caches( self.crossattn_cache = [{"is_init": False, "k": None, "v": None} for _ in range(num_layers)] + self.is_initialized = True + + def extend_kv_caches(self, extra_kv_size: int): + assert self.is_initialized, "Cannot extend uninitialized kv cache" + + dtype = self.kv_cache[0].dtype + device = self.kv_cache[0].device + + self.kv_cache = [ + torch.cat( + [ + self.kv_cache[i], + torch.zeros( + 2, self.batch_size, extra_kv_size, self.num_heads, self.head_dim, dtype=dtype, device=device + ), + ], + dim=2, + ) + for i in range(self.num_layers) + ] + def update_kv_cache( self, layer_index: int, @@ -105,3 +137,6 @@ def get_crossattn_caches(self, is_negative: bool = False) -> list[dict[str, bool """Get cross-attention caches for the specified branch.""" assert self.crossattn_cache is not None, "Cross-attn caches not initialized" return self.crossattn_cache + + def advance(self, delta: int): + self.current_lat_f += delta diff --git a/vllm_omni/entrypoints/cli/serve.py b/vllm_omni/entrypoints/cli/serve.py index 3e8e6ccc04f..c49c984b345 100644 --- a/vllm_omni/entrypoints/cli/serve.py +++ b/vllm_omni/entrypoints/cli/serve.py @@ -552,7 +552,7 @@ def subparser_init(self, subparsers: argparse._SubParsersAction) -> FlexibleArgu help="Change max size of a websocket payload that is accepted by the server", ) omni_config_group.add_argument( - "ws", + "--ws", default="auto", help="Set the websocket Protocol type", ) diff --git a/vllm_omni/entrypoints/openai/api_server.py b/vllm_omni/entrypoints/openai/api_server.py index 16b3367a16b..1d2ea233f24 100644 --- a/vllm_omni/entrypoints/openai/api_server.py +++ b/vllm_omni/entrypoints/openai/api_server.py @@ -115,9 +115,9 @@ VideoListResponse, VideoResponse, ) +from vllm_omni.entrypoints.openai.realtime.world.camera_serving import ServingRealtimeWorldCamera from vllm_omni.entrypoints.openai.realtime_connection import RealtimeConnection from vllm_omni.entrypoints.openai.serving_audio_generate import OmniOpenAIServingAudioGenerate -from vllm_omni.entrypoints.openai.realtime.world.camera_serving import ServingRealtimeWorldCamera from vllm_omni.entrypoints.openai.serving_chat import OmniOpenAIServingChat from vllm_omni.entrypoints.openai.serving_speech import OmniOpenAIServingSpeech from vllm_omni.entrypoints.openai.serving_speech_stream import OmniStreamingSpeechHandler @@ -1419,6 +1419,7 @@ async def realtime_websocket(websocket: WebSocket): connection = RealtimeConnection(websocket, serving) await connection.handle_connection() + @router.websocket("/v1/realtime/world/camera") async def realtime_world_camera_openpi(websocket: WebSocket): from vllm_omni.entrypoints.openai.realtime.world.camera_connection import WorldCameraRealtimeConnection @@ -1433,6 +1434,7 @@ async def realtime_world_camera_openpi(websocket: WebSocket): connection = WorldCameraRealtimeConnection(websocket, serving) await connection.handle_connection() + # Health and Model endpoints for diffusion mode diff --git a/vllm_omni/entrypoints/openai/realtime/world/camera_serving.py b/vllm_omni/entrypoints/openai/realtime/world/camera_serving.py index 9544d5cd094..6c7be761762 100644 --- a/vllm_omni/entrypoints/openai/realtime/world/camera_serving.py +++ b/vllm_omni/entrypoints/openai/realtime/world/camera_serving.py @@ -109,7 +109,6 @@ def reset(self, obs: dict) -> None: Engine-side Lingbot state is reset on the next inference request via `extra_args["reset"]`, not by an immediate websocket-side RPC. """ - self._call_count = 0 self._current_session_id = None async def infer(self, obs: dict) -> np.ndarray: From db6bb72b917271970cced7425ac2fef236551ada Mon Sep 17 00:00:00 2001 From: Mahdi Nasser <94046147+mnasser02@users.noreply.github.com> Date: Tue, 19 May 2026 17:50:53 +0200 Subject: [PATCH 40/53] support v2v Signed-off-by: Mahdi Nasser <94046147+mnasser02@users.noreply.github.com> --- vllm_omni/diffusion/diffusion_engine.py | 39 ++- .../distributed/pipeline_parallel.py | 3 +- .../models/wan2_2/pipeline_wan2_2.py | 57 +++- vllm_omni/diffusion/sched/base_scheduler.py | 30 +- vllm_omni/diffusion/sched/interface.py | 5 +- .../diffusion/sched/stream_batch_scheduler.py | 287 +++++++++--------- .../worker/diffusion_model_runner.py | 147 ++++----- .../diffusion/worker/diffusion_worker.py | 4 +- vllm_omni/diffusion/worker/utils.py | 34 +-- vllm_omni/inputs/data.py | 7 +- 10 files changed, 344 insertions(+), 269 deletions(-) diff --git a/vllm_omni/diffusion/diffusion_engine.py b/vllm_omni/diffusion/diffusion_engine.py index 21a2ef16c2b..9e2ece9cac9 100644 --- a/vllm_omni/diffusion/diffusion_engine.py +++ b/vllm_omni/diffusion/diffusion_engine.py @@ -144,8 +144,6 @@ def __init__( self.executor = executor_class(od_config) self.step_execution = bool(getattr(od_config, "step_execution", False)) self.stream_batch = bool(getattr(od_config, "stream_batch", False)) - if self.stream_batch and not self.step_execution: - raise ValueError("stream_batch=True requires step_execution=True.") if scheduler is not None: self.scheduler: SchedulerInterface = scheduler @@ -710,22 +708,33 @@ def _dummy_run(self): dummy_audio = np.random.randn(audio_sr * 2).astype(np.float32) prompt.setdefault("multi_modal_data", {})["audio"] = dummy_audio + sampling_kwargs: dict[str, Any] = { + "height": height, + "width": width, + "num_inference_steps": num_inference_steps, + # Keep warmup path minimal and robust across text encoders. + # Some models may fail when warmup implicitly triggers + # classifier-free guidance with an empty negative prompt. + "guidance_scale": 0.0, + "num_outputs_per_prompt": 1, + # Disable CFG for warmup to avoid triggering CFG parallel + # validation when cfg_parallel_size > 1. + "extra_args": {"cfg_text_scale": 1.0, "cfg_img_scale": 1.0}, + } + + if self.stream_batch: + # Stream-batch requires chunk_frames/num_frames and a source video + chunk_frames = 8 + sampling_kwargs["chunk_frames"] = chunk_frames + sampling_kwargs["num_frames"] = chunk_frames + prompt.setdefault("multi_modal_data", {})["video"] = [ + torch.zeros(3, height, width, dtype=torch.float32) for _ in range(chunk_frames) + ] + req = OmniDiffusionRequest( prompts=[prompt], request_ids=["dummy_req_id"], - sampling_params=OmniDiffusionSamplingParams( - height=height, - width=width, - num_inference_steps=num_inference_steps, - # Keep warmup path minimal and robust across text encoders. - # Some models may fail when warmup implicitly triggers - # classifier-free guidance with an empty negative prompt. - guidance_scale=0.0, - num_outputs_per_prompt=1, - # Disable CFG for warmup to avoid triggering CFG parallel - # validation when cfg_parallel_size > 1. - extra_args={"cfg_text_scale": 1.0, "cfg_img_scale": 1.0}, - ), + sampling_params=OmniDiffusionSamplingParams(**sampling_kwargs), ) logger.info("dummy run to warm up the model") request = self.pre_process_func(req) if self.pre_process_func is not None else req diff --git a/vllm_omni/diffusion/distributed/pipeline_parallel.py b/vllm_omni/diffusion/distributed/pipeline_parallel.py index 502e4d6e5dd..2bc82c4c83e 100644 --- a/vllm_omni/diffusion/distributed/pipeline_parallel.py +++ b/vllm_omni/diffusion/distributed/pipeline_parallel.py @@ -278,7 +278,6 @@ def scheduler_step_maybe_with_cfg( do_true_cfg: bool, per_request_scheduler: Any | list[Any] | None = None, buf_idx: int = 0, - is_last_step: bool = False, batch_size: int = 1, ) -> torch.Tensor | tuple[torch.Tensor, ...] | AsyncLatents: """ @@ -329,7 +328,7 @@ def _scheduler_step_local( ) return torch.cat(new_rows, dim=0) - def prefetch_its_maybe_with_pp_and_cfg( + def prefetch_its_maybe_with_cfg( self, do_true_cfg: bool, buf_idx: int, diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py index 297b7a4a95c..6edf524c5c3 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py @@ -400,6 +400,10 @@ def __init__( model, subfolder="vae", torch_dtype=dtype, local_files_only=local_files_only ).to(self.device) + z_dim = self.vae.config.z_dim + self._latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, z_dim, 1, 1, 1) + self._latents_inv_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, z_dim, 1, 1, 1) + # Initialize transformers with correct config (weights loaded via load_weights) if load_transformer: transformer_config = load_transformer_config(model, "transformer", local_files_only) @@ -1416,7 +1420,7 @@ def prefetch_its(self, state: DiffusionRequestState, batch_size: int = 1) -> Non buf_idx = state.step_index % 2 is_last_step = state.step_index == state.total_steps - 1 - preposted = self.prefetch_its_maybe_with_pp_and_cfg( + preposted = self.prefetch_its_maybe_with_cfg( do_true_cfg=do_true_cfg, buf_idx=buf_idx, is_last_step=is_last_step, @@ -1495,3 +1499,54 @@ def post_decode( output=output, stage_durations=self.stage_durations if hasattr(self, "stage_durations") else None, ) + + def encode_chunk_inputs( + self, + state: DiffusionRequestState, + new_idxs: list[int], + ) -> list[torch.Tensor]: + """Streaming V2V initial latents (StreamDiffusionV2-style). + + For each newly admitted chunk: VAE-encode the source frames at + ``chunk_idx * chunk_frames : (idx+1) * chunk_frames``, normalize, and + linearly blend with Gaussian noise via ``extra_args["noise_scale"]`` + (default 0.8). The transformer then runs the full schedule starting at + ``step_index=0`` on this partially-noised latent. + """ + noise_scale = float((state.sampling.extra_args or {}).get("noise_scale", 0.8)) + chunk_frames = state.sampling.chunk_frames + prompt = state.prompts[0] if state.prompts else None + video = None + if prompt is not None and not isinstance(prompt, str): + video = (prompt.get("multi_modal_data") or {}).get("video") + if video is None: + raise ValueError( + "encode_chunk_inputs requires V2V source frames in prompts[0]['multi_modal_data']['video']" + ) + + latents_mean = self._latents_mean.to(self.device) + latents_inv_std = self._latents_inv_std.to(self.device) + + out: list[torch.Tensor] = [] + for idx in new_idxs: + start = idx * chunk_frames + end = start + chunk_frames + if isinstance(video, list): + frames = torch.stack(list(video[start:end]), dim=0) + else: + frames = video[start:end] + + if frames.shape[0] < chunk_frames: + pad = torch.zeros( + (chunk_frames - frames.shape[0], *frames.shape[1:]), + dtype=frames.dtype, device=frames.device, + ) + frames = torch.cat([frames, pad], dim=0) + + # [n, C, H, W] -> [1, C, n, H, W] + control = frames.permute(1, 0, 2, 3).unsqueeze(0).to(device=self.device, dtype=self.vae.dtype) + clean = retrieve_latents(self.vae.encode(control), sample_mode="argmax") + clean = ((clean.float() - latents_mean.to(clean.dtype)) * latents_inv_std.to(clean.dtype)).to(self.vae.dtype) + noise = torch.randn_like(clean) + out.append(noise * noise_scale + clean * (1.0 - noise_scale)) + return out diff --git a/vllm_omni/diffusion/sched/base_scheduler.py b/vllm_omni/diffusion/sched/base_scheduler.py index f21c03e983b..9cda554e0bd 100644 --- a/vllm_omni/diffusion/sched/base_scheduler.py +++ b/vllm_omni/diffusion/sched/base_scheduler.py @@ -45,6 +45,7 @@ def __init__(self) -> None: self._waiting: deque[str] = deque() self._running: list[str] = [] self._running_sampling_params_key: SamplingParamsKey | None = None + self._blocked: set[str] = set() self._finished_req_ids: set[str] = set() self.max_num_running_reqs: int = 1 @@ -56,6 +57,7 @@ def initialize(self, od_config: OmniDiffusionConfig) -> None: self._waiting.clear() self._running.clear() self._running_sampling_params_key = None + self._blocked.clear() self._finished_req_ids.clear() max_num_seqs = getattr(od_config, "max_num_seqs", 1) try: @@ -123,7 +125,27 @@ def schedule(self) -> DiffusionSchedulerOutput: return scheduler_output def has_requests(self) -> bool: - return bool(self._waiting or self._running) + return bool(self._waiting or self._running or self._blocked) + + def block_request(self, sched_req_id: str) -> bool: + """Move a RUNNING request to BLOCKED. In-flight work continues.""" + + if sched_req_id not in self._running: + return False + self._running.remove(sched_req_id) + self._blocked.add(sched_req_id) + self._request_states[sched_req_id].status = DiffusionRequestStatus.BLOCKED + return True + + def unblock_request(self, sched_req_id: str) -> bool: + """Move a BLOCKED request to WAITING.""" + + if sched_req_id not in self._blocked: + return False + self._blocked.discard(sched_req_id) + self._waiting.append(sched_req_id) + self._request_states[sched_req_id].status = DiffusionRequestStatus.WAITING + return True def get_request_state(self, sched_req_id: str) -> DiffusionRequestState | None: return self._request_states.get(sched_req_id) @@ -162,6 +184,7 @@ def close(self) -> None: self._waiting.clear() self._running.clear() self._running_sampling_params_key = None + self._blocked.clear() self._finished_req_ids.clear() self._reset_scheduler_state() @@ -176,6 +199,7 @@ def _finish_requests( finished_req_ids: set[str] = set() running_to_remove: set[str] = set() waiting_to_remove: set[str] = set() + blocked_to_remove: set[str] = set() for sched_req_id, status in statuses.items(): assert DiffusionRequestStatus.is_finished(status) @@ -188,6 +212,8 @@ def _finish_requests( running_to_remove.add(sched_req_id) if sched_req_id in self._waiting: waiting_to_remove.add(sched_req_id) + if sched_req_id in self._blocked: + blocked_to_remove.add(sched_req_id) if running_to_remove: self._running = [sched_req_id for sched_req_id in self._running if sched_req_id not in running_to_remove] @@ -197,6 +223,8 @@ def _finish_requests( self._waiting = deque( sched_req_id for sched_req_id in self._waiting if sched_req_id not in waiting_to_remove ) + if blocked_to_remove: + self._blocked -= blocked_to_remove for sched_req_id in finished_req_ids: state = self._request_states[sched_req_id] diff --git a/vllm_omni/diffusion/sched/interface.py b/vllm_omni/diffusion/sched/interface.py index c28ec5eb688..0b1910567f6 100644 --- a/vllm_omni/diffusion/sched/interface.py +++ b/vllm_omni/diffusion/sched/interface.py @@ -27,6 +27,7 @@ class DiffusionRequestStatus(enum.IntEnum): WAITING = enum.auto() RUNNING = enum.auto() PREEMPTED = enum.auto() + BLOCKED = enum.auto() # if any status is after or equal to FINISHED_COMPLETED, it is considered finished FINISHED_COMPLETED = enum.auto() @@ -123,9 +124,7 @@ class Rank0Layout: - rank 0 appends n_new fresh randn rows at the tail before forwarding. """ - n_finished: int n_circulating: int - n_new: int finished_idxs: list[int] new_idxs: list[int] @@ -142,7 +141,7 @@ class DiffusionSchedulerOutput: num_waiting_reqs: int # stream-batch scheduling fields - per_rank_assignment: list[RankTask | None] | None = None + assignment: list[RankTask | None] | None = None rank0_layouts: dict[str, Rank0Layout] | None = None @cached_property diff --git a/vllm_omni/diffusion/sched/stream_batch_scheduler.py b/vllm_omni/diffusion/sched/stream_batch_scheduler.py index b257d28e73e..c9b61ca9f2f 100644 --- a/vllm_omni/diffusion/sched/stream_batch_scheduler.py +++ b/vllm_omni/diffusion/sched/stream_batch_scheduler.py @@ -2,9 +2,10 @@ Each ``schedule()`` call corresponds to one micro-step. The pipeline is modeled as ``pp_size`` per-rank chunk queues plus a transient ``returning`` queue. -At each schedule(), chunks at rank N-1 drain (finished -> finished_head, -otherwise -> returning), queues shift one rank, and rank 0 receives the -returning chunks plus B fresh admits. +At each schedule(), chunks at rank N-1 drain (finished -> Rank0Layout finished +slice, otherwise -> returning), queues shift one rank, and rank 0 receives the +returning chunks plus B fresh admits drawn from the source video frames in +``prompts[0]["multi_modal_data"]["video"]``. """ from __future__ import annotations @@ -13,6 +14,7 @@ from dataclasses import dataclass, field from typing import TYPE_CHECKING +import torch from vllm.logger import init_logger from vllm_omni.diffusion.data import OmniDiffusionConfig @@ -31,54 +33,63 @@ logger = init_logger(__name__) +def _video_frame_count(request: OmniDiffusionRequest) -> int: + """Number of frames currently in ``prompts[0]["multi_modal_data"]["video"]``.""" + if not request.prompts: + return 0 + prompt = request.prompts[0] + if isinstance(prompt, str): + return 0 + multi_modal = prompt.get("multi_modal_data") or {} + video = multi_modal.get("video") + if video is None: + return 0 + if isinstance(video, torch.Tensor): + return int(video.shape[0]) + if isinstance(video, list): + return len(video) + raise TypeError( + f"multi_modal_data['video'] must be a Tensor or list of Tensors; got {type(video).__name__}." + ) + + @dataclass class _InFlightChunk: - """One chunk of an active request currently in the pipeline.""" - chunk_idx: int steps_done: int = 0 @dataclass -class _ChunkProgress: - """Per-request chunk-level scheduling state.""" - +class _Progress: sched_req_id: str - num_chunks: int - num_steps: int pp_size: int - chunks_admitted: int = 0 - # chunks_at[r] = chunks that will be processed by rank r at the current step. - chunks_at: list[deque[_InFlightChunk]] = field(default_factory=list) - returning: deque[_InFlightChunk] = field(default_factory=deque) + chunk_frames: int + num_frames: int + num_steps: int + + frames_committed: int = 0 + next_chunk_idx: int = 0 + batch_size: int = 0 + + chunks_at: list[deque[_InFlightChunk]] = field(default_factory=list) # chunks that will be processed by rank r at the current micro-step + + @property + def output_chunks_target(self) -> int: + return self.num_frames // self.chunk_frames @dataclass class _SLOReqState: - """Per-request SLO state, owned by ``_SLOController``.""" - slo_fps: float max_batch: int - ema_alpha: float chunk_frames: int batch_size: int = 1 - latency_ema_ns: float | None = None - slack_streak: int = 0 - violation_streak: int = 0 class _SLOController: - """AIMD controller for the stream admission rate B, tracked per request. - - Driven by per-micro-step wall-clock latency observations on rank 0. + """Per-step B_target adjustment for per-request chunk admission.""" - Maintains an EMA of micro-step latency; halves B on sustained budget violations and - increments B by 1 on sustained slack. - """ - - SLACK_THRESHOLD_RATIO = 0.25 - SLACK_STREAK_TARGET = 4 - VIOLATION_STREAK_TARGET = 2 + SLACK_HEADROOM = 0.2 def __init__(self) -> None: self._reqs: dict[str, _SLOReqState] = {} @@ -88,7 +99,6 @@ def register( sched_req_id: str, slo_fps: float | None, max_batch: int, - ema_alpha: float, chunk_frames: int, ) -> None: if slo_fps is None or slo_fps <= 0: @@ -96,58 +106,35 @@ def register( self._reqs[sched_req_id] = _SLOReqState( slo_fps=float(slo_fps), max_batch=max(1, max_batch), - ema_alpha=ema_alpha, chunk_frames=max(1, chunk_frames), ) - def unregister(self, sched_req_id: str) -> None: - self._reqs.pop(sched_req_id, None) - - def batch_size(self, sched_req_id: str) -> int: + def get_target(self, sched_req_id: int) -> int: st = self._reqs.get(sched_req_id) return st.batch_size if st is not None else 1 - def observe(self, sched_req_id: str, latency_ns: int | None) -> None: + def observe(self, sched_req_id: str, latency_ns: int | None, b_current: int | None) -> None: st = self._reqs.get(sched_req_id) - if st is None or latency_ns is None or latency_ns <= 0: + if st is None or latency_ns is None or latency_ns <= 0 or b_current is None or b_current <= 0: return - if st.latency_ema_ns is None: - st.latency_ema_ns = float(latency_ns) + budget = (b_current * st.chunk_frames / st.slo_fps) * 1e9 + if latency_ns > budget: + new_b = max(1, st.batch_size - 1) + elif latency_ns < budget * (1.0 - self.SLACK_HEADROOM) and st.batch_size < st.max_batch: + new_b = st.batch_size + 1 else: - a = st.ema_alpha - st.latency_ema_ns = a * float(latency_ns) + (1.0 - a) * st.latency_ema_ns - - budget = (st.batch_size * st.chunk_frames / st.slo_fps) * 1e9 - ema = st.latency_ema_ns - - if ema > budget: - st.violation_streak += 1 - st.slack_streak = 0 - if st.violation_streak >= self.VIOLATION_STREAK_TARGET: - new_b = max(1, st.batch_size // 2) - if new_b != st.batch_size: - logger.info( - f"SLO[{sched_req_id}]: halving B {st.batch_size} -> {new_b} " - f"(ema={ema/1e6:.2f}ms budget={budget/1e6:.2f}ms)" - ) - st.batch_size = new_b - st.violation_streak = 0 return - st.violation_streak = 0 - headroom_ratio = (budget - ema) / budget - if headroom_ratio >= self.SLACK_THRESHOLD_RATIO: - st.slack_streak += 1 - if st.slack_streak >= self.SLACK_STREAK_TARGET and st.batch_size < st.max_batch: - st.batch_size += 1 - logger.info( - f"SLO[{sched_req_id}]: B -> {st.batch_size} " - f"(ema={ema/1e6:.2f}ms budget={budget/1e6:.2f}ms)" - ) - st.slack_streak = 0 - else: - st.slack_streak = 0 + if new_b != st.batch_size: + logger.info( + "SLO[%s]: B_target %d -> %d (latency=%.2fms budget=%.2fms)", + sched_req_id, st.batch_size, new_b, latency_ns / 1e6, budget / 1e6, + ) + st.batch_size = new_b + + def unregister(self, sched_req_id: str) -> None: + self._reqs.pop(sched_req_id, None) class StreamBatchScheduler(_BaseScheduler): @@ -155,17 +142,19 @@ class StreamBatchScheduler(_BaseScheduler): Per micro-step: 1. Promote waiting requests (handled by the base class). - 2. Drain rank N-1 of last step: finished chunks -> finished_head (decode - layout for rank 0), others -> returning queue. + 2. Drain rank N-1: finished chunks -> finished slice in + Rank0Layout, otherwise -> returning queue. 3. Shift per-rank queues by one (rank r <- rank r-1). - 4. Rank 0 = returning + B fresh admits (unconditional re-admit). - 5. Emit per-rank assignment and the per-request Rank0Layout. + 4. Rank 0 = returning + B fresh admits, where + `B = min(B_target, queue_chunks_available, output_chunks_remaining)`. + 5. Emit per-rank assignment and the per-request Rank0Layout. Flip req state + RUNNING -> BLOCKED when admission is starved on input. """ def __init__(self) -> None: super().__init__() self.pp_size: int = 1 - self._chunk_progress: dict[str, _ChunkProgress] = {} + self._progress: dict[str, _Progress] = {} self._slo: _SLOController = _SLOController() # ── Lifecycle ────────────────────────────────────────────────────────── @@ -173,25 +162,30 @@ def __init__(self) -> None: def initialize(self, od_config: OmniDiffusionConfig) -> None: super().initialize(od_config) self.pp_size = od_config.parallel_config.pipeline_parallel_size + # TODO: support multiple requests + self.max_num_running_reqs = 1 def _reset_scheduler_state(self) -> None: - self._chunk_progress.clear() + self._progress.clear() self._slo = _SLOController() def _pop_extra_request_state(self, sched_req_id: str) -> None: - self._chunk_progress.pop(sched_req_id, None) + self._progress.pop(sched_req_id, None) self._slo.unregister(sched_req_id) # ── Request admission ────────────────────────────────────────────────── def add_request(self, request: OmniDiffusionRequest) -> str: - num_chunks = request.sampling_params.num_chunks - num_steps = request.sampling_params.num_inference_steps - if num_chunks is None or num_chunks <= 0: - raise ValueError(f"num_chunks must be a positive int, got {num_chunks!r}") - if num_steps is None or num_steps <= 0: + sampling = request.sampling_params + if sampling.chunk_frames is None or sampling.chunk_frames <= 0: raise ValueError( - f"num_inference_steps must be a positive int, got {num_steps!r}" + f"chunk_frames must be a positive int when stream_batch=True, got {sampling.chunk_frames}" + ) + if sampling.num_frames is None or sampling.num_frames <= 0: + raise ValueError(f"num_frames must be a positive int, got {sampling.num_frames}") + if sampling.num_inference_steps is None or sampling.num_inference_steps <= 0: + raise ValueError( + f"num_inference_steps must be a positive int, got {sampling.num_inference_steps}" ) return super().add_request(request) @@ -201,131 +195,150 @@ def schedule(self) -> DiffusionSchedulerOutput: base_output = super().schedule() for new_req in base_output.scheduled_new_reqs: - self._init_chunk_progress(new_req.sched_req_id, new_req.req) + self._init_progress(new_req.sched_req_id, new_req.req) rank0_layouts: dict[str, Rank0Layout] = {} - for progress in self._chunk_progress.values(): - rank0_layouts[progress.sched_req_id] = self._advance_chunk_pipeline_for(progress) + for progress in self._progress.values(): + rank0_layouts[progress.sched_req_id] = self._advance_chunk_pipeline(progress) - if self._chunk_progress: - base_output.per_rank_assignment = self._build_assignment() + if self._progress: + base_output.assignment = self._build_assignment() base_output.rank0_layouts = rank0_layouts return base_output - def _init_chunk_progress(self, sched_req_id: str, req: OmniDiffusionRequest) -> None: + def _init_progress(self, sched_req_id: str, req: OmniDiffusionRequest) -> None: sampling = req.sampling_params - num_chunks = sampling.num_chunks + chunk_frames = sampling.chunk_frames + num_frames = sampling.num_frames num_steps = sampling.num_inference_steps - assert num_chunks is not None and num_steps is not None - self._chunk_progress[sched_req_id] = _ChunkProgress( + + self._progress[sched_req_id] = _Progress( sched_req_id=sched_req_id, - num_chunks=num_chunks, + chunk_frames=chunk_frames, + num_frames=num_frames, num_steps=num_steps, pp_size=self.pp_size, chunks_at=[deque() for _ in range(self.pp_size)], ) - chunk_frames = max(1, sampling.num_frames) self._slo.register( sched_req_id=sched_req_id, slo_fps=sampling.slo_fps, max_batch=sampling.slo_max_batch, - ema_alpha=sampling.slo_ema_alpha, chunk_frames=chunk_frames, ) logger.debug( - "StreamBatchScheduler initialized chunk progress for %s " - "(num_chunks=%d, num_steps=%d, chunk_frames=%d, slo_fps=%s, pp_size=%d)", - sched_req_id, num_chunks, num_steps, chunk_frames, sampling.slo_fps, self.pp_size, + "StreamBatchScheduler initialized progress for %s " + "(chunk_frames=%d, num_frames=%d, num_steps=%d, slo_fps=%s, pp_size=%d)", + sched_req_id, chunk_frames, num_frames, num_steps, sampling.slo_fps, self.pp_size, ) - def _advance_chunk_pipeline_for(self, progress: _ChunkProgress) -> Rank0Layout: + def _advance_chunk_pipeline(self, progress: _Progress) -> Rank0Layout: """Advance the per-rank queues by one micro-step and return rank 0's layout.""" pp = progress.pp_size # 1. Drain last rank from previous step finished_idxs: list[int] = [] - n_finished = 0 - n_circulating = 0 + circulating = [] last = progress.chunks_at[pp - 1] while last: chunk = last.popleft() chunk.steps_done += 1 if chunk.steps_done >= progress.num_steps: finished_idxs.append(chunk.chunk_idx) - n_finished += 1 else: - progress.returning.append(chunk) - n_circulating += 1 + circulating.append(chunk) # 2. Shift: rank r receives what rank r-1 had for r in range(pp - 1, 0, -1): progress.chunks_at[r] = progress.chunks_at[r - 1] progress.chunks_at[0] = deque() - # 3. Rank 0 = returning + B fresh admits - while progress.returning: - progress.chunks_at[0].append(progress.returning.popleft()) + # 3. Rank 0 = circulating + B fresh admits + for chunk in circulating: + progress.chunks_at[0].append(chunk) + + state = self.get_request_state(progress.sched_req_id) + available_frames = _video_frame_count(state.req) if state is not None else 0 + queue_chunks = max(0, (available_frames - progress.frames_committed) // progress.chunk_frames) + output_chunks_remaining = progress.output_chunks_target - progress.next_chunk_idx + b_target = self._slo.get_target(progress.sched_req_id) + batch_size = min(b_target, queue_chunks, output_chunks_remaining) new_idxs: list[int] = [] - budget = self._slo.batch_size(progress.sched_req_id) - admitted = 0 - while admitted < budget and progress.chunks_admitted < progress.num_chunks: - idx = progress.chunks_admitted - progress.chunks_at[0].append(_InFlightChunk(chunk_idx=idx)) - progress.chunks_admitted += 1 - new_idxs.append(idx) - admitted += 1 + for _ in range(batch_size): + chunk_idx = progress.next_chunk_idx + progress.next_chunk_idx += 1 + progress.frames_committed += progress.chunk_frames + progress.chunks_at[0].append(_InFlightChunk(chunk_idx=chunk_idx)) + new_idxs.append(chunk_idx) + progress.batch_size = batch_size + + # 4. Flip RUNNING -> BLOCKED if input-starved and we still owe output. + if ( + batch_size == 0 + and output_chunks_remaining > 0 + and queue_chunks == 0 + and progress.sched_req_id in self._running + ): + self.block_request(progress.sched_req_id) + logger.debug( + "StreamBatchScheduler: %s BLOCKED on input " + "(committed_frames=%d, target_frames=%d, available_frames=%d)", + progress.sched_req_id, progress.frames_committed, progress.num_frames, available_frames, + ) return Rank0Layout( - n_finished=n_finished, - n_circulating=n_circulating, - n_new=len(new_idxs), + n_circulating=len(circulating), finished_idxs=finished_idxs, new_idxs=new_idxs, ) def _build_assignment(self) -> list[RankTask | None]: - assignment: list[RankTask | None] = [None] * self.pp_size - for progress in self._chunk_progress.values(): + assert len(self._progress) <= 1 #TODO: support multiple requests + for progress in self._progress.values(): + assignment: list[RankTask | None] = [None] * self.pp_size for r in range(self.pp_size): queue = progress.chunks_at[r] if not queue: continue - indices = [c.chunk_idx for c in queue] - if assignment[r] is None: - assignment[r] = RankTask( - sched_req_id=progress.sched_req_id, - chunk_indices=indices, - ) - else: - assignment[r].chunk_indices.extend(indices) - return assignment + assignment[r] = RankTask( + sched_req_id=progress.sched_req_id, + chunk_indices=[c.chunk_idx for c in queue], + ) + return assignment # ── Output processing ────────────────────────────────────────────────── def update_from_output( self, sched_output: DiffusionSchedulerOutput, output: RunnerOutput ) -> set[str]: - if not self._chunk_progress: + sched_req_ids = sched_output.scheduled_req_ids + if not sched_req_ids: return set() + + assert len(sched_req_ids) == 1, "Multiple scheduled requests not supported" + + sched_req_id = output.req_id + + assert sched_req_id == sched_req_ids[0] - if output.micro_step_wall_ns is not None: - self._slo.observe(output.req_id, output.micro_step_wall_ns) + progress = self._progress.get(sched_req_id) + if progress is not None and output.micro_step_wall_ns is not None: + self._slo.observe(sched_req_id, output.micro_step_wall_ns, progress.batch_size) terminal: dict[str, DiffusionRequestStatus] = {} terminal_errors: dict[str, str | None] = {} - progress = self._chunk_progress.get(output.req_id) if progress is not None: err = output.result.error if output.result is not None else None if err is not None: - terminal[output.req_id] = DiffusionRequestStatus.FINISHED_ERROR - terminal_errors[output.req_id] = err + terminal[sched_req_id] = DiffusionRequestStatus.FINISHED_ERROR + terminal_errors[sched_req_id] = err elif output.finished: - terminal[output.req_id] = DiffusionRequestStatus.FINISHED_COMPLETED + terminal[sched_req_id] = DiffusionRequestStatus.FINISHED_COMPLETED return self._finalize_update_from_output(sched_output, terminal, terminal_errors) \ No newline at end of file diff --git a/vllm_omni/diffusion/worker/diffusion_model_runner.py b/vllm_omni/diffusion/worker/diffusion_model_runner.py index c5dfa75b508..256800add06 100644 --- a/vllm_omni/diffusion/worker/diffusion_model_runner.py +++ b/vllm_omni/diffusion/worker/diffusion_model_runner.py @@ -28,7 +28,7 @@ from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig from vllm_omni.diffusion.forward_context import set_forward_context from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader -from vllm_omni.diffusion.models.interface import supports_step_execution +from vllm_omni.diffusion.models.interface import supports_step_execution, supports_micro_step_execution from vllm_omni.diffusion.offloader import get_offload_backend from vllm_omni.diffusion.registry import _NO_CACHE_ACCELERATION from vllm_omni.diffusion.request import OmniDiffusionRequest @@ -155,6 +155,13 @@ def get_memory_context(): "prepare_encode(), denoise_step(), step_scheduler(), and post_decode(); " f"{self.od_config.model_class_name} does not support that contract." ) + if getattr(self.od_config, "stream_batch", False) and not self.supports_micro_step_mode(): + raise ValueError( + "stream_batch=True requires a pipeline implementing the micro-step " + "execution protocol (prepare_encode, set_pp_recv_dict_buffers, " + "denoise_step, prefetch_its, step_scheduler, post_decode, encode_chunk_inputs); " + f"{self.od_config.model_class_name} does not support that contract." + ) # Apply CPU offloading self.offload_backend = get_offload_backend(self.od_config, device=self.device) @@ -491,51 +498,35 @@ def execute_stepwise(self, scheduler_output: DiffusionSchedulerOutput) -> BatchR # Temporal-PP micro-step execution # ------------------------------------------------------------------ - @staticmethod - def _merge_chunk_outputs(chunks: list[DiffusionOutput]) -> DiffusionOutput: - """Merge decoded chunk outputs into a single video tensor. + def supports_micro_step_mode(self) -> bool: + """Return whether current pipeline supports micro-step execution.""" - Each entry's ``.output`` is ``[B_i, C, T, H, W]`` (one batched VAE - decode in the stream-batch path). Concat along the batch dim then - unroll into the temporal axis → ``[1, C, total_chunks * T, H, W]``. + return self.pipeline is not None and supports_micro_step_execution(self.pipeline) - NOTE: This is a temporary solution until streaming output is supported. - """ - try: - cat0 = torch.cat([c.output for c in chunks], dim=0) - B, C, T, H, W = cat0.shape - merged = cat0.permute(1, 0, 2, 3, 4).reshape(1, C, B * T, H, W) - except Exception as e: - return DiffusionOutput(error=f"Failed to merge {len(chunks)} chunk outputs: {e}") - return DiffusionOutput(output=merged) - - @staticmethod - def _get_or_create_chunk(state: DiffusionRequestState, chunk_idx: int) -> tuple[ChunkState, bool]: - chunks: dict[int, ChunkState] = state.extra.setdefault("chunks", {}) - chunk = chunks.get(chunk_idx) - if chunk is not None: - return chunk, chunk.step_index == 0 - chunk = ChunkState(idx=chunk_idx) - chunks[chunk_idx] = chunk - return chunk, True - - def execute_micro_step(self, scheduler_output: DiffusionSchedulerOutput) -> RunnerOutput: + def execute_micro_step(self, sched_output: DiffusionSchedulerOutput) -> RunnerOutput: """Execute one temporal-PP micro-step.""" + assert self.pipeline is not None, "Model not loaded. Call load_model() first." - if not self.supports_step_mode(): - raise ValueError("Current pipeline does not support step execution.") + if not self.supports_micro_step_mode(): + raise ValueError("Current pipeline does not support micro-step execution.") if self.od_config.cache_backend not in (None, "none"): - raise ValueError("Stream-batch mode does not support cache_backend yet.") + raise ValueError("Micro-step mode does not support cache_backend yet.") - assignment = scheduler_output.per_rank_assignment + assignment = sched_output.assignment if assignment is None: - raise ValueError("execute_micro_step requires per_rank_assignment in scheduler_output.") + raise ValueError("execute_micro_step requires assignment in sched_output.") use_hsdp = self.od_config.parallel_config.use_hsdp grad_context = torch.no_grad() if use_hsdp else torch.inference_mode() with grad_context: - state, is_new_request = self._update_states(scheduler_output) + states, new_request_ids = self._update_states(sched_output) + if len(states) != 1: + raise ValueError( + f"Micro-step mode supports exactly one running request, got {len(states)}." + ) + state = states[0] + is_new_request = state.req_id in new_request_ids if is_new_request: if state.sampling.generator is None and state.sampling.seed is not None: @@ -562,7 +553,7 @@ def execute_micro_step(self, scheduler_output: DiffusionSchedulerOutput) -> Runn finished = False if pp_group.is_first_rank: - result, finished = self._rank0_assemble_input(state, scheduler_output) + result, finished = self._rank0_assemble_input(state, sched_output) if not chunk_idxs: return RunnerOutput( @@ -580,21 +571,10 @@ def execute_micro_step(self, scheduler_output: DiffusionSchedulerOutput) -> Runn ] if pp_group.is_last_rank: - # Per-chunk schedulers regardless of pp_size — stateful step(). - for c in chunks: - if c.scheduler is None: - c.scheduler = copy.deepcopy(state.scheduler) - - if pp_group.is_first_rank: - # pp_size==1: state.latents was assembled by - # _rank0_assemble_input; take per-chunk views to avoid - # double randn for new admits on the shared generator. + if pp_group.world_size == 1: for i, c in enumerate(chunks): c.latents = state.latents[i:i + 1] else: - # Multi-rank last rank: maintain per-chunk latents in - # lockstep with rank 0 via shared seed + matching - # randn_like call order. for c in chunks: if c.latents is None: c.latents = ( @@ -604,8 +584,6 @@ def execute_micro_step(self, scheduler_output: DiffusionSchedulerOutput) -> Runn ) state.latents = torch.cat([c.latents for c in chunks], dim=0) elif not pp_group.is_first_rank: - # Middle rank: ITs carry the forward; state.latents shape only - # for buffer setup. Broadcast view, no extra memory. state.latents = template.expand(len(chunks), *template.shape[1:]) if not self.pipeline.is_buffer_setup: @@ -616,22 +594,22 @@ def execute_micro_step(self, scheduler_output: DiffusionSchedulerOutput) -> Runn [state.scheduler.timesteps[c.step_index] for c in chunks] ) - B = len(chunks) - noise_pred = self.pipeline.denoise_step(state, batch_size=B) + batch_size = len(chunks) + noise_pred = self.pipeline.denoise_step(state, batch_size=batch_size) if noise_pred is None and getattr(self.pipeline, "interrupt", False): - self._update_states_after(state, finished=True) + self.state_cache.pop(state.req_id, None) return RunnerOutput( req_id=state.req_id, finished=True, result=DiffusionOutput(error="micro-step denoise interrupted"), ) - self.pipeline.prefetch_its(state, batch_size=B) + self.pipeline.prefetch_its(state, batch_size=batch_size) - schedulers = [c.scheduler for c in chunks] if pp_group.is_last_rank else None + schedulers = [c.scheduler for c in chunks] self.pipeline.step_scheduler( - state, noise_pred, per_request_scheduler=schedulers, batch_size=B, + state, noise_pred, per_request_scheduler=schedulers, batch_size=batch_size, ) if pp_group.is_last_rank: @@ -650,6 +628,17 @@ def execute_micro_step(self, scheduler_output: DiffusionSchedulerOutput) -> Runn ), ) + @staticmethod + def _get_or_create_chunk(state: DiffusionRequestState, chunk_idx: int) -> tuple[ChunkState, bool]: + chunks: dict[int, ChunkState] = state.extra.setdefault("chunks", {}) + chunk = chunks.get(chunk_idx) + if chunk is not None: + return chunk, chunk.step_index == 0 + chunk = ChunkState(idx=chunk_idx) + chunk.scheduler = copy.deepcopy(state.scheduler) + chunks[chunk_idx] = chunk + return chunk, True + def _rank0_assemble_input( self, state: DiffusionRequestState, scheduler_output: DiffusionSchedulerOutput, ) -> tuple[DiffusionOutput | None, bool]: @@ -663,37 +652,53 @@ def _rank0_assemble_input( prev_latents = state.latents pieces: list[torch.Tensor] = [] - if layout.n_finished > 0 and prev_latents is not None: + n_finished = len(layout.finished_idxs) + if n_finished > 0 and prev_latents is not None: saved = state.latents - state.latents = prev_latents[: layout.n_finished] + state.latents = prev_latents[: n_finished] decoded = self.pipeline.post_decode(state) state.latents = saved state.extra.setdefault("decoded_chunks", []).append(decoded) - state.extra["chunks_decoded"] = ( - state.extra.get("chunks_decoded", 0) + layout.n_finished + state.extra["num_chunks_decoded"] = ( + state.extra.get("num_chunks_decoded", 0) + n_finished ) for idx in layout.finished_idxs: state.extra.get("chunks", {}).pop(idx, None) if layout.n_circulating > 0 and prev_latents is not None: pieces.append( - prev_latents[layout.n_finished : layout.n_finished + layout.n_circulating] + prev_latents[n_finished : n_finished + layout.n_circulating] ) - if layout.n_new > 0: - template = state.extra["initial_latent_template"] # [1, C, T, H, W] - for idx in layout.new_idxs: - row = ( - template - if idx == 0 - else torch.randn_like(template, generator=state.sampling.generator) - ) - pieces.append(row) + if layout.new_idxs: + encoded = self.pipeline.encode_chunk_inputs(state, layout.new_idxs) + for latent, idx in zip(encoded, layout.new_idxs): + chunk, _ = self._get_or_create_chunk(state, idx) + chunk.latents = latent + pieces.append(latent) state.latents = torch.cat(pieces, dim=0) if pieces else None - if state.extra.get("chunks_decoded", 0) >= state.sampling.num_chunks: - self._update_states_after(state, finished=True) + output_chunks_target = state.sampling.num_frames // state.sampling.chunk_frames + if state.extra.get("num_chunks_decoded", 0) >= output_chunks_target: + self.state_cache.pop(state.req_id, None) return self._merge_chunk_outputs(state.extra["decoded_chunks"]), True return None, False + @staticmethod + def _merge_chunk_outputs(chunks: list[DiffusionOutput]) -> DiffusionOutput: + """Merge decoded chunk outputs into a single video tensor. + + Each entry's ``.output`` is ``[B_i, C, T, H, W]``. + Concat along the batch dim then unroll into the temporal axis ``[1, C, total_chunks * T, H, W]``. + + NOTE: This is a temporary solution until streaming output is supported. + """ + + try: + cat0 = torch.cat([c.output for c in chunks], dim=0) + B, C, T, H, W = cat0.shape + merged = cat0.permute(1, 0, 2, 3, 4).reshape(1, C, B * T, H, W) + except Exception as e: + return DiffusionOutput(error=f"Failed to merge {len(chunks)} chunk outputs: {e}") + return DiffusionOutput(output=merged) \ No newline at end of file diff --git a/vllm_omni/diffusion/worker/diffusion_worker.py b/vllm_omni/diffusion/worker/diffusion_worker.py index da80798376e..7cfcd9331eb 100644 --- a/vllm_omni/diffusion/worker/diffusion_worker.py +++ b/vllm_omni/diffusion/worker/diffusion_worker.py @@ -382,7 +382,7 @@ def execute_stepwise(self, scheduler_output: DiffusionSchedulerOutput) -> BaseRu profiler.step() return output - def execute_micro_step(self, scheduler_output: DiffusionSchedulerOutput) -> RunnerOutput: + def execute_micro_step(self, scheduler_output: DiffusionSchedulerOutput) -> BaseRunnerOutput: """Execute one temporal-PP micro-step by delegating to the model runner.""" assert self.model_runner is not None, "Model runner not initialized" if self.lora_manager is not None: @@ -966,7 +966,7 @@ def execute_stepwise(self, scheduler_output: DiffusionSchedulerOutput) -> BaseRu """Execute one diffusion step.""" return self.worker.execute_stepwise(scheduler_output) - def execute_micro_step(self, scheduler_output: DiffusionSchedulerOutput) -> RunnerOutput: + def execute_micro_step(self, scheduler_output: DiffusionSchedulerOutput) -> BaseRunnerOutput: """Execute one temporal-PP micro-step.""" return self.worker.execute_micro_step(scheduler_output) diff --git a/vllm_omni/diffusion/worker/utils.py b/vllm_omni/diffusion/worker/utils.py index f914e92878a..b2e9d241c16 100644 --- a/vllm_omni/diffusion/worker/utils.py +++ b/vllm_omni/diffusion/worker/utils.py @@ -5,7 +5,6 @@ from __future__ import annotations from abc import ABC, abstractmethod -import contextlib from collections.abc import Iterator from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any @@ -113,42 +112,13 @@ def new_request(self) -> bool: # A real "new request" signal should eventually come from scheduler/runner state transitions. return self.step_index == 0 or self.timesteps is None - @contextlib.contextmanager - def use_chunk(self, chunk: ChunkState) -> Iterator[None]: - """Temporarily alias per-chunk fields on ``self`` to a ``ChunkState``'s view. - - Swapped fields: ``latents``, ``step_index``, ``scheduler``. - - Lets ``prepare_encode`` / ``denoise_step`` / ``step_scheduler`` operate - per-chunk without any pipeline-side changes. Updates made inside the - context are written back to the chunk on exit; the request-level fields - are restored. - """ - saved_latents = self.latents - saved_step_index = self.step_index - saved_scheduler = self.scheduler - self.latents = chunk.latents - self.step_index = chunk.step_index - self.scheduler = chunk.scheduler - try: - yield - finally: - chunk.latents = self.latents - chunk.step_index = self.step_index - chunk.scheduler = self.scheduler - self.latents = saved_latents - self.step_index = saved_step_index - self.scheduler = saved_scheduler - @dataclass class ChunkState: """Per-chunk state for one in-flight chunk of a streaming request. Lives inside ``DiffusionRequestState.extra["chunks"]`` (keyed by - ``chunk_idx``). The runner swaps a chunk into the request state via - ``state.use_chunk(chunk)`` for the duration of one micro-step's - ``denoise_step + step_scheduler`` calls. + ``chunk_idx``). Each chunk owns its own ``scheduler`` instance (deepcopied from the pipeline's scheduler by ``prepare_encode``) because multi-step ODE solvers @@ -169,7 +139,7 @@ def get_req_output(self, sched_req_id: str) -> RunnerOutput | None: @dataclass -class RunnerOutput: +class RunnerOutput(BaseRunnerOutput): """Output of a single execution step for a request. Each scheduler reads the fields it needs: diff --git a/vllm_omni/inputs/data.py b/vllm_omni/inputs/data.py index 96c65bb2e86..3d7a4e689d2 100644 --- a/vllm_omni/inputs/data.py +++ b/vllm_omni/inputs/data.py @@ -220,14 +220,11 @@ class OmniDiffusionSamplingParams: width_latents: list[int] | int | None = None num_frames: int = 1 # Default for image models num_frames_round_down: bool = False # Whether to round down num_frames if it's not divisible by num_gpus - # Number of output chunks the request produces. Read by ``StreamBatchScheduler`` - # (temporal PP) to know how many chunks to admit through the pipeline. - num_chunks: int = 1 + chunk_frames: int = 8 # Used when stream_batch=True - # SLO-adaptive stream batching. ``slo_fps=None`` keeps B fixed at 1. + # SLO-adaptive stream batching. ``slo_fps=None`` keeps B_target fixed at 1. slo_fps: float | None = None slo_max_batch: int = 8 - slo_ema_alpha: float = 0.3 # Original dimensions (before VAE scaling) height: int | None = None From dde68499418e2b2f99fd5b9905971423ffb8488c Mon Sep 17 00:00:00 2001 From: Mahdi Nasser <94046147+mnasser02@users.noreply.github.com> Date: Tue, 19 May 2026 17:51:04 +0200 Subject: [PATCH 41/53] update tests Signed-off-by: Mahdi Nasser <94046147+mnasser02@users.noreply.github.com> --- .../test_diffusion_micro_step_pipeline.py | 74 ++++++++++--------- tests/diffusion/test_diffusion_scheduler.py | 35 +++++---- 2 files changed, 62 insertions(+), 47 deletions(-) diff --git a/tests/diffusion/test_diffusion_micro_step_pipeline.py b/tests/diffusion/test_diffusion_micro_step_pipeline.py index 83a42c08e81..718c02b7f49 100644 --- a/tests/diffusion/test_diffusion_micro_step_pipeline.py +++ b/tests/diffusion/test_diffusion_micro_step_pipeline.py @@ -58,6 +58,7 @@ def __init__(self, num_steps: int = 1): self.scheduler_calls = 0 self.decode_calls = 0 self.prefetch_calls = 0 + self.encode_calls = 0 self.is_buffer_setup = False def prepare_encode(self, state, **kwargs): @@ -68,6 +69,11 @@ def prepare_encode(self, state, **kwargs): state.scheduler = SimpleNamespace(timesteps=list(state.timesteps)) return state + def encode_chunk_inputs(self, state, new_idxs): + del state + self.encode_calls += 1 + return [torch.zeros((1, 1, 1, 1, 1)) for _ in new_idxs] + def set_pp_recv_dict_buffers(self, state, **kwargs): del state, kwargs self.set_buffer_calls += 1 @@ -117,7 +123,8 @@ def _make_micro_request( req_id: str = "req-1", *, num_inference_steps: int = 1, - num_chunks: int = 1, + num_frames: int = 1, + chunk_frames: int = 1, ): return SimpleNamespace( prompts=["a prompt"], @@ -127,11 +134,10 @@ def _make_micro_request( seed=None, generator_device=None, num_inference_steps=num_inference_steps, - num_chunks=num_chunks, - num_frames=1, + chunk_frames=chunk_frames, + num_frames=num_frames, slo_fps=None, slo_max_batch=8, - slo_ema_alpha=0.3, lora_request=None, ), ) @@ -156,16 +162,12 @@ def _make_runner(pp_size: int = 1, num_steps: int = 1): def _make_layout( *, - n_finished: int = 0, n_circulating: int = 0, - n_new: int = 0, finished_idxs: list[int] | None = None, new_idxs: list[int] | None = None, ) -> Rank0Layout: return Rank0Layout( - n_finished=n_finished, n_circulating=n_circulating, - n_new=n_new, finished_idxs=finished_idxs or [], new_idxs=new_idxs or [], ) @@ -197,7 +199,7 @@ def _make_micro_scheduler_output( finished_req_ids=set() if finished_req_ids is None else set(finished_req_ids), num_running_reqs=1, num_waiting_reqs=0, - per_rank_assignment=assignment, + assignment=assignment, rank0_layouts=rank0_layouts, ) @@ -218,14 +220,14 @@ class TestRunner: def test_completes_single_chunk_request(self, monkeypatch): runner = _make_runner(pp_size=1, num_steps=1) _patch_runtime(monkeypatch, runner) - req = _make_micro_request(num_inference_steps=1, num_chunks=1) + req = _make_micro_request(num_inference_steps=1, num_frames=1) out0 = DiffusionModelRunner.execute_micro_step( runner, _make_micro_scheduler_output( req=req, step_id=0, assignment=[RankTask(sched_req_id="req-1", chunk_indices=[0])], - rank0_layout=_make_layout(n_new=1, new_idxs=[0]), + rank0_layout=_make_layout(new_idxs=[0]), ), ) assert out0.req_id == "req-1" @@ -237,7 +239,7 @@ def test_completes_single_chunk_request(self, monkeypatch): _make_micro_scheduler_output( sched_req_id="req-1", step_id=1, assignment=[None], is_new=False, - rank0_layout=_make_layout(n_finished=1, finished_idxs=[0]), + rank0_layout=_make_layout(finished_idxs=[0]), ), ) assert out1.finished is True @@ -253,14 +255,14 @@ def test_completes_single_chunk_request(self, monkeypatch): def test_completes_multi_chunk_request(self, monkeypatch): runner = _make_runner(pp_size=1, num_steps=1) _patch_runtime(monkeypatch, runner) - req = _make_micro_request(num_inference_steps=1, num_chunks=2) + req = _make_micro_request(num_inference_steps=1, num_frames=2) DiffusionModelRunner.execute_micro_step( runner, _make_micro_scheduler_output( req=req, step_id=0, assignment=[RankTask(sched_req_id="req-1", chunk_indices=[0])], - rank0_layout=_make_layout(n_new=1, new_idxs=[0]), + rank0_layout=_make_layout(new_idxs=[0]), ), ) out1 = DiffusionModelRunner.execute_micro_step( @@ -269,7 +271,7 @@ def test_completes_multi_chunk_request(self, monkeypatch): sched_req_id="req-1", step_id=1, assignment=[RankTask(sched_req_id="req-1", chunk_indices=[1])], is_new=False, - rank0_layout=_make_layout(n_finished=1, n_new=1, finished_idxs=[0], new_idxs=[1]), + rank0_layout=_make_layout(finished_idxs=[0], new_idxs=[1]), ), ) assert out1.finished is False @@ -279,7 +281,7 @@ def test_completes_multi_chunk_request(self, monkeypatch): _make_micro_scheduler_output( sched_req_id="req-1", step_id=2, assignment=[None], is_new=False, - rank0_layout=_make_layout(n_finished=1, finished_idxs=[1]), + rank0_layout=_make_layout(finished_idxs=[1]), ), ) assert out2.finished is True @@ -293,14 +295,14 @@ def test_completes_multi_chunk_request(self, monkeypatch): def test_re_admits_circulating_chunk(self, monkeypatch): runner = _make_runner(pp_size=1, num_steps=2) _patch_runtime(monkeypatch, runner) - req = _make_micro_request(num_inference_steps=2, num_chunks=1) + req = _make_micro_request(num_inference_steps=2, num_frames=1) out0 = DiffusionModelRunner.execute_micro_step( runner, _make_micro_scheduler_output( req=req, step_id=0, assignment=[RankTask(sched_req_id="req-1", chunk_indices=[0])], - rank0_layout=_make_layout(n_new=1, new_idxs=[0]), + rank0_layout=_make_layout(new_idxs=[0]), ), ) assert out0.finished is False @@ -322,7 +324,7 @@ def test_re_admits_circulating_chunk(self, monkeypatch): _make_micro_scheduler_output( sched_req_id="req-1", step_id=2, assignment=[None], is_new=False, - rank0_layout=_make_layout(n_finished=1, finished_idxs=[0]), + rank0_layout=_make_layout(finished_idxs=[0]), ), ) assert out2.finished is True @@ -331,14 +333,14 @@ def test_re_admits_circulating_chunk(self, monkeypatch): def test_empty_layout_is_a_no_op(self, monkeypatch): runner = _make_runner(pp_size=1, num_steps=1) _patch_runtime(monkeypatch, runner) - req = _make_micro_request(num_inference_steps=1, num_chunks=1) + req = _make_micro_request(num_inference_steps=1, num_frames=1) DiffusionModelRunner.execute_micro_step( runner, _make_micro_scheduler_output( req=req, step_id=0, assignment=[RankTask(sched_req_id="req-1", chunk_indices=[0])], - rank0_layout=_make_layout(n_new=1, new_idxs=[0]), + rank0_layout=_make_layout(new_idxs=[0]), ), ) denoise_calls_before = runner.pipeline.denoise_calls @@ -360,14 +362,14 @@ def test_interrupt_marks_request_as_aborted(self, monkeypatch): runner = _make_runner(pp_size=1, num_steps=1) runner.pipeline = _InterruptingMicroStepPipeline(num_steps=1) _patch_runtime(monkeypatch, runner) - req = _make_micro_request(num_inference_steps=1, num_chunks=1) + req = _make_micro_request(num_inference_steps=1, num_frames=1) out = DiffusionModelRunner.execute_micro_step( runner, _make_micro_scheduler_output( req=req, step_id=0, assignment=[RankTask(sched_req_id="req-1", chunk_indices=[0])], - rank0_layout=_make_layout(n_new=1, new_idxs=[0]), + rank0_layout=_make_layout(new_idxs=[0]), ), ) assert out.req_id == "req-1" @@ -377,13 +379,13 @@ def test_interrupt_marks_request_as_aborted(self, monkeypatch): assert runner.pipeline.scheduler_calls == 0 assert runner.pipeline.decode_calls == 0 - def test_rejects_missing_per_rank_assignment(self): + def test_rejects_missing_assignment(self): runner = _make_runner(pp_size=1) req = _make_micro_request() sched_output = _make_micro_scheduler_output(req=req) - sched_output.per_rank_assignment = None + sched_output.assignment = None - with pytest.raises(ValueError, match="per_rank_assignment"): + with pytest.raises(ValueError, match="assignment"): DiffusionModelRunner.execute_micro_step(runner, sched_output) def test_rejects_cache_backend(self): @@ -400,14 +402,14 @@ def test_rejects_cache_backend(self): def test_stamps_micro_step_wall_ns_on_rank0(self, monkeypatch): runner = _make_runner(pp_size=1, num_steps=1) _patch_runtime(monkeypatch, runner) - req = _make_micro_request(num_inference_steps=1, num_chunks=1) + req = _make_micro_request(num_inference_steps=1, num_frames=1) out = DiffusionModelRunner.execute_micro_step( runner, _make_micro_scheduler_output( req=req, step_id=0, assignment=[RankTask(sched_req_id="req-1", chunk_indices=[0])], - rank0_layout=_make_layout(n_new=1, new_idxs=[0]), + rank0_layout=_make_layout(new_idxs=[0]), ), ) assert out.micro_step_wall_ns is not None @@ -416,14 +418,14 @@ def test_stamps_micro_step_wall_ns_on_rank0(self, monkeypatch): def test_batch_two_runs_one_fused_forward(self, monkeypatch): runner = _make_runner(pp_size=1, num_steps=1) _patch_runtime(monkeypatch, runner) - req = _make_micro_request(num_inference_steps=1, num_chunks=2) + req = _make_micro_request(num_inference_steps=1, num_frames=2) out = DiffusionModelRunner.execute_micro_step( runner, _make_micro_scheduler_output( req=req, step_id=0, assignment=[RankTask(sched_req_id="req-1", chunk_indices=[0, 1])], - rank0_layout=_make_layout(n_new=2, new_idxs=[0, 1]), + rank0_layout=_make_layout(new_idxs=[0, 1]), ), ) @@ -436,14 +438,14 @@ def test_batch_two_runs_one_fused_forward(self, monkeypatch): def test_batch_two_decodes_both_chunks_when_drain_completes(self, monkeypatch): runner = _make_runner(pp_size=1, num_steps=1) _patch_runtime(monkeypatch, runner) - req = _make_micro_request(num_inference_steps=1, num_chunks=2) + req = _make_micro_request(num_inference_steps=1, num_frames=2) DiffusionModelRunner.execute_micro_step( runner, _make_micro_scheduler_output( req=req, step_id=0, assignment=[RankTask(sched_req_id="req-1", chunk_indices=[0, 1])], - rank0_layout=_make_layout(n_new=2, new_idxs=[0, 1]), + rank0_layout=_make_layout(new_idxs=[0, 1]), ), ) out = DiffusionModelRunner.execute_micro_step( @@ -451,7 +453,7 @@ def test_batch_two_decodes_both_chunks_when_drain_completes(self, monkeypatch): _make_micro_scheduler_output( sched_req_id="req-1", step_id=1, assignment=[None], is_new=False, - rank0_layout=_make_layout(n_finished=2, finished_idxs=[0, 1]), + rank0_layout=_make_layout(finished_idxs=[0, 1]), ), ) @@ -548,7 +550,11 @@ def test_wan22_supports_micro_step_execution(self): SupportsMicroStepExecution, supports_micro_step_execution, ) - from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2 import Wan22Pipeline + + try: + from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2 import Wan22Pipeline + except (RuntimeError, ImportError) as exc: + pytest.skip(f"Wan22Pipeline import not available on this platform: {exc}") # Avoid loading weights; protocol membership is a class-contract check. pipeline = object.__new__(Wan22Pipeline) diff --git a/tests/diffusion/test_diffusion_scheduler.py b/tests/diffusion/test_diffusion_scheduler.py index f754dd456f3..f4aaa27b757 100644 --- a/tests/diffusion/test_diffusion_scheduler.py +++ b/tests/diffusion/test_diffusion_scheduler.py @@ -817,12 +817,20 @@ def _make_stream_request( *, num_inference_steps: int = 2, num_chunks: int = 1, + chunk_frames: int = 1, ) -> OmniDiffusionRequest: + """``num_chunks`` is a test-helper shorthand for ``num_frames / chunk_frames``.""" + num_frames = num_chunks * chunk_frames + video = [torch.zeros(3, 8, 8) for _ in range(num_frames)] return OmniDiffusionRequest( - prompts=[f"prompt_{req_id}"], + prompts=[{ + "prompt": f"prompt_{req_id}", + "multi_modal_data": {"video": video}, + }], sampling_params=OmniDiffusionSamplingParams( num_inference_steps=num_inference_steps, - num_chunks=num_chunks, + chunk_frames=chunk_frames, + num_frames=num_frames, ), request_ids=[req_id], ) @@ -849,11 +857,11 @@ def _make_od_config(pp_size: int) -> SimpleNamespace: def _ranks(sched_output) -> list[tuple[str, list[int]] | None]: - if sched_output.per_rank_assignment is None: + if sched_output.assignment is None: return [] return [ (t.sched_req_id, list(t.chunk_indices)) if t is not None else None - for t in sched_output.per_rank_assignment + for t in sched_output.assignment ] @@ -863,7 +871,7 @@ def _layout(sched_output, sched_req_id: str) -> tuple[int, int, int] | None: layout = sched_output.rank0_layouts.get(sched_req_id) if layout is None: return None - return (layout.n_finished, layout.n_circulating, layout.n_new) + return (len(layout.finished_idxs), layout.n_circulating, len(layout.new_idxs)) class TestStreamBatchScheduler: @@ -1041,13 +1049,13 @@ def test_chunk_progress_cleared_after_request_finishes(self) -> None: scheduler.update_from_output(out1, _make_stream_output(req_id, finished=True)) scheduler.pop_request_state(req_id) - assert req_id not in scheduler._chunk_progress + assert req_id not in scheduler._progress assert scheduler.has_requests() is False def test_schedule_with_no_requests_emits_no_assignment(self) -> None: scheduler = self._make_scheduler(pp_size=2) out = scheduler.schedule() - assert out.per_rank_assignment is None + assert out.assignment is None assert out.rank0_layouts is None assert out.scheduled_req_ids == [] @@ -1063,7 +1071,8 @@ def test_fifo_two_requests(self) -> None: out1 = scheduler.schedule() assert _new_ids(out1) == [] - scheduler.update_from_output(out1, _make_stream_output(req_a, finished=True)) + finished = scheduler.update_from_output(out1, _make_stream_output(req_a, finished=True)) + scheduler.pop_request_state(req_a) out2 = scheduler.schedule() assert _new_ids(out2) == [req_b] @@ -1124,22 +1133,22 @@ def test_preempt_request_preserves_chunk_progress(self) -> None: assert _ranks(out1) == [(req_id, [1]), (req_id, [0])] scheduler.update_from_output(out1, _make_stream_output(req_id)) - before = scheduler._chunk_progress[req_id] - assert before.chunks_admitted == 2 + before = scheduler._progress[req_id] + assert before.next_chunk_idx == 2 snapshot = [[c.chunk_idx for c in q] for q in before.chunks_at] assert snapshot == [[1], [0]] assert scheduler.preempt_request(req_id) is True assert scheduler.get_request_state(req_id).status == DiffusionRequestStatus.PREEMPTED - after = scheduler._chunk_progress[req_id] - assert after.chunks_admitted == 2 + after = scheduler._progress[req_id] + assert after.next_chunk_idx == 2 assert [[c.chunk_idx for c in q] for q in after.chunks_at] == snapshot def test_b_admission(self) -> None: scheduler = self._make_scheduler(pp_size=1) req_id = scheduler.add_request(_make_stream_request("b", num_inference_steps=2, num_chunks=4)) - scheduler._slo.register(req_id, slo_fps=30.0, max_batch=4, ema_alpha=0.3, chunk_frames=1) + scheduler._slo.register(req_id, slo_fps=30.0, max_batch=4, chunk_frames=1) scheduler._slo._reqs[req_id].batch_size = 2 out0 = scheduler.schedule() From 010cfc9463ebb24aac69ed9d06e39f71a559703e Mon Sep 17 00:00:00 2001 From: Miguel Vieira Pereira Date: Tue, 19 May 2026 16:02:28 +0000 Subject: [PATCH 42/53] Remove dependency on external code from Lingbot World repo Signed-off-by: Miguel Vieira Pereira --- .../download_lingbot_world_fast.py | 13 - .../lingbot_world_fast/end2end.py | 1 + .../lingbot_world_fast/openai_client.py | 15 +- .../models/lingbot_world_fast/cam_utils.py | 153 +++++ .../lingbot_world_fast/fm_solvers_unipc.py | 576 +++++++++++++++++ .../pipeline_lingbot_world_fast.py | 20 +- .../diffusion/models/lingbot_world_fast/t5.py | 451 +++++++++++++ .../models/lingbot_world_fast/tokenizers.py | 78 +++ .../models/lingbot_world_fast/vae2_1.py | 610 ++++++++++++++++++ .../models/lingbot_world_fast/wan_fast.py | 2 +- .../models/lingbot_world_fast/wan_model.py | 95 +++ vllm_omni/entrypoints/cli/serve.py | 2 +- 12 files changed, 1978 insertions(+), 38 deletions(-) create mode 100644 vllm_omni/diffusion/models/lingbot_world_fast/cam_utils.py create mode 100644 vllm_omni/diffusion/models/lingbot_world_fast/fm_solvers_unipc.py create mode 100644 vllm_omni/diffusion/models/lingbot_world_fast/t5.py create mode 100644 vllm_omni/diffusion/models/lingbot_world_fast/tokenizers.py create mode 100644 vllm_omni/diffusion/models/lingbot_world_fast/vae2_1.py create mode 100644 vllm_omni/diffusion/models/lingbot_world_fast/wan_model.py diff --git a/examples/offline_inference/lingbot_world_fast/download_lingbot_world_fast.py b/examples/offline_inference/lingbot_world_fast/download_lingbot_world_fast.py index 2cd6b6f0538..6db8ebb5c6b 100644 --- a/examples/offline_inference/lingbot_world_fast/download_lingbot_world_fast.py +++ b/examples/offline_inference/lingbot_world_fast/download_lingbot_world_fast.py @@ -1,9 +1,7 @@ import argparse -import fcntl import json import os import site -import subprocess import tempfile import time from pathlib import Path @@ -20,17 +18,6 @@ def download_dependency(): CACHE_DIR.mkdir(parents=True, exist_ok=True) - with open(LOCK_FILE, "w") as f: - fcntl.flock(f, fcntl.LOCK_EX) - if not DEPENDENCY_DIR.exists(): - print(f"Downloading Lingbot World Fast to {DEPENDENCY_DIR} ...") - subprocess.run( - ["git", "clone", "--depth", "1", DEPENDENCY_REPO, "--branch", DEPENDENCY_BRANCH, str(DEPENDENCY_DIR)], - check=True, - ) - print("Download finished.") - fcntl.flock(f, fcntl.LOCK_UN) - # write .pth to site-packages site_packages = Path(site.getsitepackages()[0]) pth_file = site_packages / "vllm_omni_dependency.pth" diff --git a/examples/offline_inference/lingbot_world_fast/end2end.py b/examples/offline_inference/lingbot_world_fast/end2end.py index a5c44b4d58c..565cd99a285 100644 --- a/examples/offline_inference/lingbot_world_fast/end2end.py +++ b/examples/offline_inference/lingbot_world_fast/end2end.py @@ -135,6 +135,7 @@ def main(): generator=generator, num_frames=args.num_frames, frame_rate=args.fps, + extra_args={"session_id": "offline_generation"}, ), ) generation_end = time.perf_counter() diff --git a/examples/online_serving/lingbot_world_fast/openai_client.py b/examples/online_serving/lingbot_world_fast/openai_client.py index 01678db4ff3..9731308bdb0 100644 --- a/examples/online_serving/lingbot_world_fast/openai_client.py +++ b/examples/online_serving/lingbot_world_fast/openai_client.py @@ -77,7 +77,7 @@ def generate_video(args: Namespace) -> np.ndarray: for i in range(args.num_calls): camera = { "poses": full_camera["poses"][starting_frame : starting_frame + args.num_frames], - "intrinsics": full_camera["intrinsics"][starting_frame : starting_frame + args.num_frames ], + "intrinsics": full_camera["intrinsics"][starting_frame : starting_frame + args.num_frames], } if i == 0: @@ -96,11 +96,11 @@ def generate_video(args: Namespace) -> np.ndarray: with ws_sync.connect(endpoint, max_size=None, ping_interval=None, ping_timeout=None) as ws: # 1. Server sends CameraServerConfig on connect. - server_config = _unpack(ws.recv()) + _unpack(ws.recv()) # 2. Send obs. print( - f"Sending obs ({'image=' + str(image.shape) if obs.get('image', None) is not None else 'None'}, " + f"Sending obs image= ({str(image.shape) if obs.get('image', None) is not None else 'None'}, " f"poses={camera['poses'].shape}, intrinsics={camera['intrinsics'].shape})..." ) ws.send(_pack(obs)) @@ -170,15 +170,6 @@ def main(): output_path = Path(args.output) output_path.parent.mkdir(parents=True, exist_ok=True) - print(frames.__class__) - print(len(frames)) - print(frames[0].shape) - - for i, frame in enumerate(frames): - tmp_frame = (frame * 255).astype(np.uint8) - im = PIL.Image.fromarray(tmp_frame) - im.save(f"{args.output[:-4]}_{i}.png") - export_to_video(frames, str(output_path), fps=args.fps) print(f"Saved generated video to {output_path}") diff --git a/vllm_omni/diffusion/models/lingbot_world_fast/cam_utils.py b/vllm_omni/diffusion/models/lingbot_world_fast/cam_utils.py new file mode 100644 index 00000000000..cbc0da6889e --- /dev/null +++ b/vllm_omni/diffusion/models/lingbot_world_fast/cam_utils.py @@ -0,0 +1,153 @@ +# Adapted from Lingbot-World/wan/utils/cam_utils.py +import numpy as np +import torch +from scipy.interpolate import interp1d +from scipy.spatial.transform import Rotation, Slerp + + +def interpolate_camera_poses( + src_indices: np.ndarray, + src_rot_mat: np.ndarray, + src_trans_vec: np.ndarray, + tgt_indices: np.ndarray, +) -> torch.Tensor: + # interpolate translation + interp_func_trans = interp1d( + src_indices, + src_trans_vec, + axis=0, + kind="linear", + bounds_error=False, + fill_value="extrapolate", + ) + interpolated_trans_vec = interp_func_trans(tgt_indices) + + # interpolate rotation + src_quat_vec = Rotation.from_matrix(src_rot_mat) + # ensure there is no sudden change in qw + quats = src_quat_vec.as_quat().copy() # [N, 4] + for i in range(1, len(quats)): + if np.dot(quats[i], quats[i - 1]) < 0: + quats[i] = -quats[i] + src_quat_vec = Rotation.from_quat(quats) + slerp_func_rot = Slerp(src_indices, src_quat_vec) + interpolated_rot_quat = slerp_func_rot(tgt_indices) + interpolated_rot_mat = interpolated_rot_quat.as_matrix() + + poses = np.zeros((len(tgt_indices), 4, 4)) + poses[:, :3, :3] = interpolated_rot_mat + poses[:, :3, 3] = interpolated_trans_vec + poses[:, 3, 3] = 1.0 + return torch.from_numpy(poses).float() + + +def SE3_inverse(t: torch.Tensor) -> torch.Tensor: + Rot = t[:, :3, :3] # [B,3,3] + trans = t[:, :3, 3:] # [B,3,1] + R_inv = Rot.transpose(-1, -2) + t_inv = -torch.bmm(R_inv, trans) + T_inv = torch.eye(4, device=t.device, dtype=t.dtype)[None, :, :].repeat(t.shape[0], 1, 1) + T_inv[:, :3, :3] = R_inv + T_inv[:, :3, 3:] = t_inv + return T_inv + + +def compute_relative_poses( + c2ws_mat: torch.Tensor, + framewise: bool = False, + normalize_trans: bool = True, +) -> torch.Tensor: + ref_w2cs = SE3_inverse(c2ws_mat[0:1]) + relative_poses = torch.matmul(ref_w2cs, c2ws_mat) + # ensure identity matrix for 1st frame + relative_poses[0] = torch.eye(4, device=c2ws_mat.device, dtype=c2ws_mat.dtype) + if framewise: + # compute pose between i and i+1 + relative_poses_framewise = torch.bmm(SE3_inverse(relative_poses[:-1]), relative_poses[1:]) + relative_poses[1:] = relative_poses_framewise + if normalize_trans: + # scale the coordinate inputs to roughly 1 standard deviation to simplify model learning (camctrl2). + translations = relative_poses[:, :3, 3] # [f, 3] + max_norm = torch.norm(translations, dim=-1).max() + # only normalize when moving + if max_norm > 0: + relative_poses[:, :3, 3] = translations / max_norm + return relative_poses + + +@torch.no_grad() +def create_meshgrid( + n_frames: int, height: int, width: int, bias: float = 0.5, device="cuda", dtype=torch.float32 +) -> torch.Tensor: + x_range = torch.arange(width, device=device, dtype=dtype) + y_range = torch.arange(height, device=device, dtype=dtype) + grid_y, grid_x = torch.meshgrid(y_range, x_range, indexing="ij") + grid_xy = torch.stack([grid_x, grid_y], dim=-1).view([-1, 2]) + bias # [h*w, 2] + grid_xy = grid_xy[None, ...].repeat(n_frames, 1, 1) # [f, h*w, 2] + return grid_xy + + +def get_plucker_embeddings( + c2ws_mat: torch.Tensor, + k: torch.Tensor, + height: int, + width: int, + only_rays_d: bool = False, +): + n_frames = c2ws_mat.shape[0] + grid_xy = create_meshgrid(n_frames, height, width, device=c2ws_mat.device, dtype=c2ws_mat.dtype) # [f, h*w, 2] + fx, fy, cx, cy = k.chunk(4, dim=-1) # [f, 1] + + i = grid_xy[..., 0] # [f, h*w] + j = grid_xy[..., 1] # [f, h*w] + zs = torch.ones_like(i) # [f, h*w] + xs = (i - cx) / fx * zs + ys = (j - cy) / fy * zs + + directions = torch.stack([xs, ys, zs], dim=-1) # [f, h*w, 3] + directions = directions / directions.norm(dim=-1, keepdim=True) # [f, h*w, 3] + + rays_d = directions @ c2ws_mat[:, :3, :3].transpose(-1, -2) # [f, h*w, 3] + if only_rays_d: + plucker_embeddings = rays_d # [f, h*w, 3] + plucker_embeddings = plucker_embeddings.view([n_frames, height, width, 3]) # [f*h*w, 3] + else: + rays_o = c2ws_mat[:, :3, 3] # [f, 3] + rays_o = rays_o[:, None, :].expand_as(rays_d) # [f, h*w, 3] + plucker_embeddings = torch.cat([rays_o, rays_d], dim=-1) # [f, h*w, 6] + plucker_embeddings = plucker_embeddings.view([n_frames, height, width, 6]) # [f*h*w, 6] + return plucker_embeddings + + +def get_Ks_transformed( + k: torch.Tensor, + height_org: int, + width_org: int, + height_resize: int, + width_resize: int, + height_final: int, + width_final: int, +): + fx, fy, cx, cy = k.chunk(4, dim=-1) # [f, 1] + + scale_x = width_resize / width_org + scale_y = height_resize / height_org + + fx_resize = fx * scale_x + fy_resize = fy * scale_y + cx_resize = cx * scale_x + cy_resize = cy * scale_y + + crop_offset_x = (width_resize - width_final) / 2 + crop_offset_y = (height_resize - height_final) / 2 + + cx_final = cx_resize - crop_offset_x + cy_final = cy_resize - crop_offset_y + + Ks_transformed = torch.zeros_like(k) + Ks_transformed[:, 0:1] = fx_resize + Ks_transformed[:, 1:2] = fy_resize + Ks_transformed[:, 2:3] = cx_final + Ks_transformed[:, 3:4] = cy_final + + return Ks_transformed diff --git a/vllm_omni/diffusion/models/lingbot_world_fast/fm_solvers_unipc.py b/vllm_omni/diffusion/models/lingbot_world_fast/fm_solvers_unipc.py new file mode 100644 index 00000000000..d9a18899da7 --- /dev/null +++ b/vllm_omni/diffusion/models/lingbot_world_fast/fm_solvers_unipc.py @@ -0,0 +1,576 @@ +# Adapted from Lingbot-World/wan/utils/fm_solvers_unipc.py +# Originally derived from https://github.com/huggingface/diffusers/blob/v0.31.0/src/diffusers/schedulers/scheduling_unipc_multistep.py +# Converted to flow matching. + + +import math + +import numpy as np +import torch +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.schedulers.scheduling_utils import ( + KarrasDiffusionSchedulers, + SchedulerMixin, + SchedulerOutput, +) +from diffusers.utils import deprecate, is_scipy_available + +if is_scipy_available(): + import scipy.stats # noqa: F401 + + +class FlowUniPCMultistepScheduler(SchedulerMixin, ConfigMixin): + """ + `UniPCMultistepScheduler` is a training-free framework designed for the fast sampling of diffusion models. + """ + + _compatibles = [e.name for e in KarrasDiffusionSchedulers] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + solver_order: int = 2, + prediction_type: str = "flow_prediction", + shift: float | None = 1.0, + use_dynamic_shifting=False, + thresholding: bool = False, + dynamic_thresholding_ratio: float = 0.995, + sample_max_value: float = 1.0, + predict_x0: bool = True, + solver_type: str = "bh2", + lower_order_final: bool = True, + disable_corrector: list[int] = [], + solver_p: SchedulerMixin = None, + timestep_spacing: str = "linspace", + steps_offset: int = 0, + final_sigmas_type: str | None = "zero", # "zero", "sigma_min" + ): + if solver_type not in ["bh1", "bh2"]: + if solver_type in ["midpoint", "heun", "logrho"]: + self.register_to_config(solver_type="bh2") + else: + raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}") + + self.predict_x0 = predict_x0 + # set table values + self.num_inference_steps = None + alphas = np.linspace(1, 1 / num_train_timesteps, num_train_timesteps)[::-1].copy() + sigmas = 1.0 - alphas + sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32) + + if not use_dynamic_shifting: + # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution + sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) + + self.sigmas = sigmas + self.timesteps = sigmas * num_train_timesteps + + self.model_outputs = [None] * solver_order + self.timestep_list = [None] * solver_order + self.lower_order_nums = 0 + self.disable_corrector = disable_corrector + self.solver_p = solver_p + self.last_sample = None + self._step_index = None + self._begin_index = None + + self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication + self.sigma_min = self.sigmas[-1].item() + self.sigma_max = self.sigmas[0].item() + + @property + def step_index(self): + return self._step_index + + @property + def begin_index(self): + return self._begin_index + + def set_begin_index(self, begin_index: int = 0): + self._begin_index = begin_index + + def set_timesteps( + self, + num_inference_steps: int | None = None, + device: str | torch.device = None, + sigmas: list[float] | None = None, + mu: float | None = None, + shift: float | None = None, + ): + if self.config.use_dynamic_shifting and mu is None: + raise ValueError(" you have to pass a value for `mu` when `use_dynamic_shifting` is set to be `True`") + + if sigmas is None: + sigmas = np.linspace(self.sigma_max, self.sigma_min, num_inference_steps + 1).copy()[:-1] + + if self.config.use_dynamic_shifting: + sigmas = self.time_shift(mu, 1.0, sigmas) + else: + if shift is None: + shift = self.config.shift + sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) + + if self.config.final_sigmas_type == "sigma_min": + sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5 + elif self.config.final_sigmas_type == "zero": + sigma_last = 0 + else: + raise ValueError( + f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}" + ) + + timesteps = sigmas * self.config.num_train_timesteps + sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32) + + self.sigmas = torch.from_numpy(sigmas) + self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64) + + self.num_inference_steps = len(timesteps) + + self.model_outputs = [None] * self.config.solver_order + self.lower_order_nums = 0 + self.last_sample = None + if self.solver_p: + self.solver_p.set_timesteps(self.num_inference_steps, device=device) + + # add an index counter for schedulers that allow duplicated timesteps + self._step_index = None + self._begin_index = None + self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication + + def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor: + dtype = sample.dtype + batch_size, channels, *remaining_dims = sample.shape + + if dtype not in (torch.float32, torch.float64): + sample = sample.float() + + sample = sample.reshape(batch_size, channels * np.prod(remaining_dims)) + + abs_sample = sample.abs() + + s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1) + s = torch.clamp(s, min=1, max=self.config.sample_max_value) + s = s.unsqueeze(1) + sample = torch.clamp(sample, -s, s) / s + + sample = sample.reshape(batch_size, channels, *remaining_dims) + sample = sample.to(dtype) + + return sample + + def _sigma_to_t(self, sigma): + return sigma * self.config.num_train_timesteps + + def _sigma_to_alpha_sigma_t(self, sigma): + return 1 - sigma, sigma + + def time_shift(self, mu: float, sigma: float, t: torch.Tensor): + return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) + + def convert_model_output( + self, + model_output: torch.Tensor, + *args, + sample: torch.Tensor = None, + **kwargs, + ) -> torch.Tensor: + timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) + if sample is None: + if len(args) > 1: + sample = args[1] + else: + raise ValueError("missing `sample` as a required keyword argument") + if timestep is not None: + deprecate( + "timesteps", + "1.0.0", + "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled " + "via an internal counter `self.step_index`", + ) + + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + + if self.predict_x0: + if self.config.prediction_type == "flow_prediction": + sigma_t = self.sigmas[self.step_index] + x0_pred = sample - sigma_t * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`," + " `v_prediction` or `flow_prediction` for the UniPCMultistepScheduler." + ) + + if self.config.thresholding: + x0_pred = self._threshold_sample(x0_pred) + + return x0_pred + else: + if self.config.prediction_type == "flow_prediction": + sigma_t = self.sigmas[self.step_index] + epsilon = sample - (1 - sigma_t) * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`," + " `v_prediction` or `flow_prediction` for the UniPCMultistepScheduler." + ) + + if self.config.thresholding: + sigma_t = self.sigmas[self.step_index] + x0_pred = sample - sigma_t * model_output + x0_pred = self._threshold_sample(x0_pred) + epsilon = model_output + x0_pred + + return epsilon + + def multistep_uni_p_bh_update( + self, + model_output: torch.Tensor, + *args, + sample: torch.Tensor = None, + order: int = None, + **kwargs, + ) -> torch.Tensor: + prev_timestep = args[0] if len(args) > 0 else kwargs.pop("prev_timestep", None) + if sample is None: + if len(args) > 1: + sample = args[1] + else: + raise ValueError(" missing `sample` as a required keyword argument") + if order is None: + if len(args) > 2: + order = args[2] + else: + raise ValueError(" missing `order` as a required keyword argument") + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled " + "via an internal counter `self.step_index`", + ) + model_output_list = self.model_outputs + + s0 = self.timestep_list[-1] + m0 = model_output_list[-1] + x = sample + + if self.solver_p: + x_t = self.solver_p.step(model_output, s0, x).prev_sample + return x_t + + sigma_t, sigma_s0 = self.sigmas[self.step_index + 1], self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + + h = lambda_t - lambda_s0 + device = sample.device + + rks = [] + D1s = [] + for i in range(1, order): + si = self.step_index - i + mi = model_output_list[-(i + 1)] + alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si]) + lambda_si = torch.log(alpha_si) - torch.log(sigma_si) + rk = (lambda_si - lambda_s0) / h + rks.append(rk) + D1s.append((mi - m0) / rk) + + rks.append(1.0) + rks = torch.tensor(rks, device=device) + + R = [] + b = [] + + hh = -h if self.predict_x0 else h + h_phi_1 = torch.expm1(hh) + h_phi_k = h_phi_1 / hh - 1 + + factorial_i = 1 + + if self.config.solver_type == "bh1": + B_h = hh + elif self.config.solver_type == "bh2": + B_h = torch.expm1(hh) + else: + raise NotImplementedError() + + for i in range(1, order + 1): + R.append(torch.pow(rks, i - 1)) + b.append(h_phi_k * factorial_i / B_h) + factorial_i *= i + 1 + h_phi_k = h_phi_k / hh - 1 / factorial_i + + R = torch.stack(R) + b = torch.tensor(b, device=device) + + if len(D1s) > 0: + D1s = torch.stack(D1s, dim=1) # (B, K) + if order == 2: + rhos_p = torch.tensor([0.5], dtype=x.dtype, device=device) + else: + rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1]).to(device).to(x.dtype) + else: + D1s = None + + if self.predict_x0: + x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0 + if D1s is not None: + pred_res = torch.einsum("k,bkc...->bc...", rhos_p, D1s) + else: + pred_res = 0 + x_t = x_t_ - alpha_t * B_h * pred_res + else: + x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0 + if D1s is not None: + pred_res = torch.einsum("k,bkc...->bc...", rhos_p, D1s) + else: + pred_res = 0 + x_t = x_t_ - sigma_t * B_h * pred_res + + x_t = x_t.to(x.dtype) + return x_t + + def multistep_uni_c_bh_update( + self, + this_model_output: torch.Tensor, + *args, + last_sample: torch.Tensor = None, + this_sample: torch.Tensor = None, + order: int = None, + **kwargs, + ) -> torch.Tensor: + this_timestep = args[0] if len(args) > 0 else kwargs.pop("this_timestep", None) + if last_sample is None: + if len(args) > 1: + last_sample = args[1] + else: + raise ValueError(" missing`last_sample` as a required keyword argument") + if this_sample is None: + if len(args) > 2: + this_sample = args[2] + else: + raise ValueError(" missing`this_sample` as a required keyword argument") + if order is None: + if len(args) > 3: + order = args[3] + else: + raise ValueError(" missing`order` as a required keyword argument") + if this_timestep is not None: + deprecate( + "this_timestep", + "1.0.0", + "Passing `this_timestep` is deprecated and has no effect as model output conversion is now handled " + "via an internal counter `self.step_index`", + ) + + model_output_list = self.model_outputs + + m0 = model_output_list[-1] + x = last_sample + x_t = this_sample + model_t = this_model_output + + sigma_t, sigma_s0 = self.sigmas[self.step_index], self.sigmas[self.step_index - 1] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + + h = lambda_t - lambda_s0 + device = this_sample.device + + rks = [] + D1s = [] + for i in range(1, order): + si = self.step_index - (i + 1) + mi = model_output_list[-(i + 1)] + alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si]) + lambda_si = torch.log(alpha_si) - torch.log(sigma_si) + rk = (lambda_si - lambda_s0) / h + rks.append(rk) + D1s.append((mi - m0) / rk) + + rks.append(1.0) + rks = torch.tensor(rks, device=device) + + R = [] + b = [] + + hh = -h if self.predict_x0 else h + h_phi_1 = torch.expm1(hh) + h_phi_k = h_phi_1 / hh - 1 + + factorial_i = 1 + + if self.config.solver_type == "bh1": + B_h = hh + elif self.config.solver_type == "bh2": + B_h = torch.expm1(hh) + else: + raise NotImplementedError() + + for i in range(1, order + 1): + R.append(torch.pow(rks, i - 1)) + b.append(h_phi_k * factorial_i / B_h) + factorial_i *= i + 1 + h_phi_k = h_phi_k / hh - 1 / factorial_i + + R = torch.stack(R) + b = torch.tensor(b, device=device) + + if len(D1s) > 0: + D1s = torch.stack(D1s, dim=1) + else: + D1s = None + + if order == 1: + rhos_c = torch.tensor([0.5], dtype=x.dtype, device=device) + else: + rhos_c = torch.linalg.solve(R, b).to(device).to(x.dtype) + + if self.predict_x0: + x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0 + if D1s is not None: + corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s) + else: + corr_res = 0 + D1_t = model_t - m0 + x_t = x_t_ - alpha_t * B_h * (corr_res + rhos_c[-1] * D1_t) + else: + x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0 + if D1s is not None: + corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s) + else: + corr_res = 0 + D1_t = model_t - m0 + x_t = x_t_ - sigma_t * B_h * (corr_res + rhos_c[-1] * D1_t) + x_t = x_t.to(x.dtype) + return x_t + + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps + + indices = (schedule_timesteps == timestep).nonzero() + + pos = 1 if len(indices) > 1 else 0 + + return indices[pos].item() + + def _init_step_index(self, timestep): + if self.begin_index is None: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + self._step_index = self.index_for_timestep(timestep) + else: + self._step_index = self._begin_index + + def step( + self, + model_output: torch.Tensor, + timestep: int | torch.Tensor, + sample: torch.Tensor, + return_dict: bool = True, + generator=None, + ) -> SchedulerOutput | tuple: + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + if self.step_index is None: + self._init_step_index(timestep) + + use_corrector = ( + self.step_index > 0 and self.step_index - 1 not in self.disable_corrector and self.last_sample is not None + ) + + model_output_convert = self.convert_model_output(model_output, sample=sample) + if use_corrector: + sample = self.multistep_uni_c_bh_update( + this_model_output=model_output_convert, + last_sample=self.last_sample, + this_sample=sample, + order=self.this_order, + ) + + for i in range(self.config.solver_order - 1): + self.model_outputs[i] = self.model_outputs[i + 1] + self.timestep_list[i] = self.timestep_list[i + 1] + + self.model_outputs[-1] = model_output_convert + self.timestep_list[-1] = timestep + + if self.config.lower_order_final: + this_order = min(self.config.solver_order, len(self.timesteps) - self.step_index) + else: + this_order = self.config.solver_order + + self.this_order = min(this_order, self.lower_order_nums + 1) + assert self.this_order > 0 + + self.last_sample = sample + prev_sample = self.multistep_uni_p_bh_update( + model_output=model_output, + sample=sample, + order=self.this_order, + ) + + if self.lower_order_nums < self.config.solver_order: + self.lower_order_nums += 1 + + self._step_index += 1 + + if not return_dict: + return (prev_sample,) + + return SchedulerOutput(prev_sample=prev_sample) + + def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor: + return sample + + def add_noise( + self, + original_samples: torch.Tensor, + noise: torch.Tensor, + timesteps: torch.IntTensor, + ) -> torch.Tensor: + sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) + if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): + schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) + timesteps = timesteps.to(original_samples.device, dtype=torch.float32) + else: + schedule_timesteps = self.timesteps.to(original_samples.device) + if isinstance(timesteps, list): + timesteps = [timestep.to(original_samples.device) for timestep in timesteps] + else: + timesteps = timesteps.to(original_samples.device) + + if self.begin_index is None: + if not isinstance(timesteps, list): + timesteps = [timesteps] + step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps] + elif self.step_index is not None: + step_indices = [self.step_index] * timesteps.shape[0] + else: + step_indices = [self.begin_index] * timesteps.shape[0] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < len(original_samples.shape): + sigma = sigma.unsqueeze(-1) + + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + noisy_samples = alpha_t * original_samples + sigma_t * noise + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps diff --git a/vllm_omni/diffusion/models/lingbot_world_fast/pipeline_lingbot_world_fast.py b/vllm_omni/diffusion/models/lingbot_world_fast/pipeline_lingbot_world_fast.py index 2500cdbc540..fb697fc88c8 100644 --- a/vllm_omni/diffusion/models/lingbot_world_fast/pipeline_lingbot_world_fast.py +++ b/vllm_omni/diffusion/models/lingbot_world_fast/pipeline_lingbot_world_fast.py @@ -13,24 +13,22 @@ from torch import nn from tqdm import tqdm -# Load dependencies from Lingbot World source code -from wan.modules.t5 import T5EncoderModel -from wan.modules.vae2_1 import Wan2_1_VAE -from wan.utils.cam_utils import ( - compute_relative_poses, - get_Ks_transformed, - get_plucker_embeddings, - interpolate_camera_poses, -) -from wan.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler - from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin from vllm_omni.diffusion.distributed.utils import get_local_device from vllm_omni.diffusion.models.interface import SupportCameraPosInput, SupportImageInput from vllm_omni.diffusion.request import OmniDiffusionRequest +from .cam_utils import ( + compute_relative_poses, + get_Ks_transformed, + get_plucker_embeddings, + interpolate_camera_poses, +) +from .fm_solvers_unipc import FlowUniPCMultistepScheduler from .state_lingbot_world_fast import LingbotWorldFastState +from .t5 import T5EncoderModel +from .vae2_1 import Wan2_1_VAE from .wan_fast import WanModelFast logger = logging.getLogger(__name__) diff --git a/vllm_omni/diffusion/models/lingbot_world_fast/t5.py b/vllm_omni/diffusion/models/lingbot_world_fast/t5.py new file mode 100644 index 00000000000..ccdb8160d6d --- /dev/null +++ b/vllm_omni/diffusion/models/lingbot_world_fast/t5.py @@ -0,0 +1,451 @@ +# Adapted from Lingbot-World/wan/modules/t5.py +import logging +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .tokenizers import HuggingfaceTokenizer + +__all__ = [ + "T5Model", + "T5Encoder", + "T5Decoder", + "T5EncoderModel", +] + + +def fp16_clamp(x): + if x.dtype == torch.float16 and torch.isinf(x).any(): + clamp = torch.finfo(x.dtype).max - 1000 + x = torch.clamp(x, min=-clamp, max=clamp) + return x + + +def init_weights(m): + if isinstance(m, T5LayerNorm): + nn.init.ones_(m.weight) + elif isinstance(m, T5Model): + nn.init.normal_(m.token_embedding.weight, std=1.0) + elif isinstance(m, T5FeedForward): + nn.init.normal_(m.gate[0].weight, std=m.dim**-0.5) + nn.init.normal_(m.fc1.weight, std=m.dim**-0.5) + nn.init.normal_(m.fc2.weight, std=m.dim_ffn**-0.5) + elif isinstance(m, T5Attention): + nn.init.normal_(m.q.weight, std=(m.dim * m.dim_attn) ** -0.5) + nn.init.normal_(m.k.weight, std=m.dim**-0.5) + nn.init.normal_(m.v.weight, std=m.dim**-0.5) + nn.init.normal_(m.o.weight, std=(m.num_heads * m.dim_attn) ** -0.5) + elif isinstance(m, T5RelativeEmbedding): + nn.init.normal_(m.embedding.weight, std=(2 * m.num_buckets * m.num_heads) ** -0.5) + + +class GELU(nn.Module): + def forward(self, x): + return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0)))) + + +class T5LayerNorm(nn.Module): + def __init__(self, dim, eps=1e-6): + super().__init__() + self.dim = dim + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def forward(self, x): + x = x * torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) + self.eps) + if self.weight.dtype in [torch.float16, torch.bfloat16]: + x = x.type_as(self.weight) + return self.weight * x + + +class T5Attention(nn.Module): + def __init__(self, dim, dim_attn, num_heads, dropout=0.1): + assert dim_attn % num_heads == 0 + super().__init__() + self.dim = dim + self.dim_attn = dim_attn + self.num_heads = num_heads + self.head_dim = dim_attn // num_heads + + # layers + self.q = nn.Linear(dim, dim_attn, bias=False) + self.k = nn.Linear(dim, dim_attn, bias=False) + self.v = nn.Linear(dim, dim_attn, bias=False) + self.o = nn.Linear(dim_attn, dim, bias=False) + self.dropout = nn.Dropout(dropout) + + def forward(self, x, context=None, mask=None, pos_bias=None): + """ + x: [B, L1, C]. + context: [B, L2, C] or None. + mask: [B, L2] or [B, L1, L2] or None. + """ + # check inputs + context = x if context is None else context + b, n, c = x.size(0), self.num_heads, self.head_dim + + # compute query, key, value + q = self.q(x).view(b, -1, n, c) + k = self.k(context).view(b, -1, n, c) + v = self.v(context).view(b, -1, n, c) + + # attention bias + attn_bias = x.new_zeros(b, n, q.size(1), k.size(1)) + if pos_bias is not None: + attn_bias += pos_bias + if mask is not None: + assert mask.ndim in [2, 3] + mask = mask.view(b, 1, 1, -1) if mask.ndim == 2 else mask.unsqueeze(1) + attn_bias.masked_fill_(mask == 0, torch.finfo(x.dtype).min) + + # compute attention (T5 does not use scaling) + attn = torch.einsum("binc,bjnc->bnij", q, k) + attn_bias + attn = F.softmax(attn.float(), dim=-1).type_as(attn) + x = torch.einsum("bnij,bjnc->binc", attn, v) + + # output + x = x.reshape(b, -1, n * c) + x = self.o(x) + x = self.dropout(x) + return x + + +class T5FeedForward(nn.Module): + def __init__(self, dim, dim_ffn, dropout=0.1): + super().__init__() + self.dim = dim + self.dim_ffn = dim_ffn + + # layers + self.gate = nn.Sequential(nn.Linear(dim, dim_ffn, bias=False), GELU()) + self.fc1 = nn.Linear(dim, dim_ffn, bias=False) + self.fc2 = nn.Linear(dim_ffn, dim, bias=False) + self.dropout = nn.Dropout(dropout) + + def forward(self, x): + x = self.fc1(x) * self.gate(x) + x = self.dropout(x) + x = self.fc2(x) + x = self.dropout(x) + return x + + +class T5SelfAttention(nn.Module): + def __init__(self, dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos=True, dropout=0.1): + super().__init__() + self.dim = dim + self.dim_attn = dim_attn + self.dim_ffn = dim_ffn + self.num_heads = num_heads + self.num_buckets = num_buckets + self.shared_pos = shared_pos + + # layers + self.norm1 = T5LayerNorm(dim) + self.attn = T5Attention(dim, dim_attn, num_heads, dropout) + self.norm2 = T5LayerNorm(dim) + self.ffn = T5FeedForward(dim, dim_ffn, dropout) + self.pos_embedding = None if shared_pos else T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True) + + def forward(self, x, mask=None, pos_bias=None): + e = pos_bias if self.shared_pos else self.pos_embedding(x.size(1), x.size(1)) + x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e)) + x = fp16_clamp(x + self.ffn(self.norm2(x))) + return x + + +class T5CrossAttention(nn.Module): + def __init__(self, dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos=True, dropout=0.1): + super().__init__() + self.dim = dim + self.dim_attn = dim_attn + self.dim_ffn = dim_ffn + self.num_heads = num_heads + self.num_buckets = num_buckets + self.shared_pos = shared_pos + + # layers + self.norm1 = T5LayerNorm(dim) + self.self_attn = T5Attention(dim, dim_attn, num_heads, dropout) + self.norm2 = T5LayerNorm(dim) + self.cross_attn = T5Attention(dim, dim_attn, num_heads, dropout) + self.norm3 = T5LayerNorm(dim) + self.ffn = T5FeedForward(dim, dim_ffn, dropout) + self.pos_embedding = None if shared_pos else T5RelativeEmbedding(num_buckets, num_heads, bidirectional=False) + + def forward(self, x, mask=None, encoder_states=None, encoder_mask=None, pos_bias=None): + e = pos_bias if self.shared_pos else self.pos_embedding(x.size(1), x.size(1)) + x = fp16_clamp(x + self.self_attn(self.norm1(x), mask=mask, pos_bias=e)) + x = fp16_clamp(x + self.cross_attn(self.norm2(x), context=encoder_states, mask=encoder_mask)) + x = fp16_clamp(x + self.ffn(self.norm3(x))) + return x + + +class T5RelativeEmbedding(nn.Module): + def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128): + super().__init__() + self.num_buckets = num_buckets + self.num_heads = num_heads + self.bidirectional = bidirectional + self.max_dist = max_dist + + # layers + self.embedding = nn.Embedding(num_buckets, num_heads) + + def forward(self, lq, lk): + device = self.embedding.weight.device + rel_pos = torch.arange(lk, device=device).unsqueeze(0) - torch.arange(lq, device=device).unsqueeze(1) + rel_pos = self._relative_position_bucket(rel_pos) + rel_pos_embeds = self.embedding(rel_pos) + rel_pos_embeds = rel_pos_embeds.permute(2, 0, 1).unsqueeze(0) # [1, N, Lq, Lk] + return rel_pos_embeds.contiguous() + + def _relative_position_bucket(self, rel_pos): + # preprocess + if self.bidirectional: + num_buckets = self.num_buckets // 2 + rel_buckets = (rel_pos > 0).long() * num_buckets + rel_pos = torch.abs(rel_pos) + else: + num_buckets = self.num_buckets + rel_buckets = 0 + rel_pos = -torch.min(rel_pos, torch.zeros_like(rel_pos)) + + # embeddings for small and large positions + max_exact = num_buckets // 2 + rel_pos_large = ( + max_exact + + ( + torch.log(rel_pos.float() / max_exact) / math.log(self.max_dist / max_exact) * (num_buckets - max_exact) + ).long() + ) + rel_pos_large = torch.min(rel_pos_large, torch.full_like(rel_pos_large, num_buckets - 1)) + rel_buckets += torch.where(rel_pos < max_exact, rel_pos, rel_pos_large) + return rel_buckets + + +class T5Encoder(nn.Module): + def __init__(self, vocab, dim, dim_attn, dim_ffn, num_heads, num_layers, num_buckets, shared_pos=True, dropout=0.1): + super().__init__() + self.dim = dim + self.dim_attn = dim_attn + self.dim_ffn = dim_ffn + self.num_heads = num_heads + self.num_layers = num_layers + self.num_buckets = num_buckets + self.shared_pos = shared_pos + + # layers + self.token_embedding = vocab if isinstance(vocab, nn.Embedding) else nn.Embedding(vocab, dim) + self.pos_embedding = T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True) if shared_pos else None + self.dropout = nn.Dropout(dropout) + self.blocks = nn.ModuleList( + [ + T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos, dropout) + for _ in range(num_layers) + ] + ) + self.norm = T5LayerNorm(dim) + + # initialize weights + self.apply(init_weights) + + def forward(self, ids, mask=None): + x = self.token_embedding(ids) + x = self.dropout(x) + e = self.pos_embedding(x.size(1), x.size(1)) if self.shared_pos else None + for block in self.blocks: + x = block(x, mask, pos_bias=e) + x = self.norm(x) + x = self.dropout(x) + return x + + +class T5Decoder(nn.Module): + def __init__(self, vocab, dim, dim_attn, dim_ffn, num_heads, num_layers, num_buckets, shared_pos=True, dropout=0.1): + super().__init__() + self.dim = dim + self.dim_attn = dim_attn + self.dim_ffn = dim_ffn + self.num_heads = num_heads + self.num_layers = num_layers + self.num_buckets = num_buckets + self.shared_pos = shared_pos + + # layers + self.token_embedding = vocab if isinstance(vocab, nn.Embedding) else nn.Embedding(vocab, dim) + self.pos_embedding = T5RelativeEmbedding(num_buckets, num_heads, bidirectional=False) if shared_pos else None + self.dropout = nn.Dropout(dropout) + self.blocks = nn.ModuleList( + [ + T5CrossAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos, dropout) + for _ in range(num_layers) + ] + ) + self.norm = T5LayerNorm(dim) + + # initialize weights + self.apply(init_weights) + + def forward(self, ids, mask=None, encoder_states=None, encoder_mask=None): + b, s = ids.size() + + # causal mask + if mask is None: + mask = torch.tril(torch.ones(1, s, s).to(ids.device)) + elif mask.ndim == 2: + mask = torch.tril(mask.unsqueeze(1).expand(-1, s, -1)) + + # layers + x = self.token_embedding(ids) + x = self.dropout(x) + e = self.pos_embedding(x.size(1), x.size(1)) if self.shared_pos else None + for block in self.blocks: + x = block(x, mask, encoder_states, encoder_mask, pos_bias=e) + x = self.norm(x) + x = self.dropout(x) + return x + + +class T5Model(nn.Module): + def __init__( + self, + vocab_size, + dim, + dim_attn, + dim_ffn, + num_heads, + encoder_layers, + decoder_layers, + num_buckets, + shared_pos=True, + dropout=0.1, + ): + super().__init__() + self.vocab_size = vocab_size + self.dim = dim + self.dim_attn = dim_attn + self.dim_ffn = dim_ffn + self.num_heads = num_heads + self.encoder_layers = encoder_layers + self.decoder_layers = decoder_layers + self.num_buckets = num_buckets + + # layers + self.token_embedding = nn.Embedding(vocab_size, dim) + self.encoder = T5Encoder( + self.token_embedding, dim, dim_attn, dim_ffn, num_heads, encoder_layers, num_buckets, shared_pos, dropout + ) + self.decoder = T5Decoder( + self.token_embedding, dim, dim_attn, dim_ffn, num_heads, decoder_layers, num_buckets, shared_pos, dropout + ) + self.head = nn.Linear(dim, vocab_size, bias=False) + + # initialize weights + self.apply(init_weights) + + def forward(self, encoder_ids, encoder_mask, decoder_ids, decoder_mask): + x = self.encoder(encoder_ids, encoder_mask) + x = self.decoder(decoder_ids, decoder_mask, x, encoder_mask) + x = self.head(x) + return x + + +def _t5( + name, + encoder_only=False, + decoder_only=False, + return_tokenizer=False, + tokenizer_kwargs={}, + dtype=torch.float32, + device="cpu", + **kwargs, +): + # sanity check + assert not (encoder_only and decoder_only) + + # params + if encoder_only: + model_cls = T5Encoder + kwargs["vocab"] = kwargs.pop("vocab_size") + kwargs["num_layers"] = kwargs.pop("encoder_layers") + _ = kwargs.pop("decoder_layers") + elif decoder_only: + model_cls = T5Decoder + kwargs["vocab"] = kwargs.pop("vocab_size") + kwargs["num_layers"] = kwargs.pop("decoder_layers") + _ = kwargs.pop("encoder_layers") + else: + model_cls = T5Model + + # init model + with torch.device(device): + model = model_cls(**kwargs) + + # set device + model = model.to(dtype=dtype, device=device) + + # init tokenizer + if return_tokenizer: + tokenizer = HuggingfaceTokenizer(f"google/{name}", **tokenizer_kwargs) + return model, tokenizer + else: + return model + + +def umt5_xxl(**kwargs): + cfg = dict( + vocab_size=256384, + dim=4096, + dim_attn=4096, + dim_ffn=10240, + num_heads=64, + encoder_layers=24, + decoder_layers=24, + num_buckets=32, + shared_pos=False, + dropout=0.1, + ) + cfg.update(**kwargs) + return _t5("umt5-xxl", **cfg) + + +class T5EncoderModel: + def __init__( + self, + text_len, + dtype=torch.bfloat16, + device=torch.accelerator.current_device_idx(), + checkpoint_path=None, + tokenizer_path=None, + shard_fn=None, + ): + self.text_len = text_len + self.dtype = dtype + self.device = device + self.checkpoint_path = checkpoint_path + self.tokenizer_path = tokenizer_path + + # init model + model = ( + umt5_xxl(encoder_only=True, return_tokenizer=False, dtype=dtype, device=device).eval().requires_grad_(False) + ) + logging.info(f"loading {checkpoint_path}") + model.load_state_dict(torch.load(checkpoint_path, map_location="cpu")) + self.model = model + if shard_fn is not None: + self.model = shard_fn(self.model, sync_module_states=False) + else: + self.model.to(self.device) + # init tokenizer + self.tokenizer = HuggingfaceTokenizer(name=tokenizer_path, seq_len=text_len, clean="whitespace") + + def __call__(self, texts, device): + ids, mask = self.tokenizer(texts, return_mask=True, add_special_tokens=True) + ids = ids.to(device) + mask = mask.to(device) + seq_lens = mask.gt(0).sum(dim=1).long() + context = self.model(ids, mask) + return [u[:v] for u, v in zip(context, seq_lens)] diff --git a/vllm_omni/diffusion/models/lingbot_world_fast/tokenizers.py b/vllm_omni/diffusion/models/lingbot_world_fast/tokenizers.py new file mode 100644 index 00000000000..e939b191c12 --- /dev/null +++ b/vllm_omni/diffusion/models/lingbot_world_fast/tokenizers.py @@ -0,0 +1,78 @@ +# Adapted from Lingbot-World/wan/modules/tokenizers.py +import html +import string + +import ftfy +import regex as re +from transformers import AutoTokenizer + +__all__ = ["HuggingfaceTokenizer"] + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r"\s+", " ", text) + text = text.strip() + return text + + +def canonicalize(text, keep_punctuation_exact_string=None): + text = text.replace("_", " ") + if keep_punctuation_exact_string: + text = keep_punctuation_exact_string.join( + part.translate(str.maketrans("", "", string.punctuation)) + for part in text.split(keep_punctuation_exact_string) + ) + else: + text = text.translate(str.maketrans("", "", string.punctuation)) + text = text.lower() + text = re.sub(r"\s+", " ", text) + return text.strip() + + +class HuggingfaceTokenizer: + def __init__(self, name, seq_len=None, clean=None, **kwargs): + assert clean in (None, "whitespace", "lower", "canonicalize") + self.name = name + self.seq_len = seq_len + self.clean = clean + + # init tokenizer + self.tokenizer = AutoTokenizer.from_pretrained(name, **kwargs) + self.vocab_size = self.tokenizer.vocab_size + + def __call__(self, sequence, **kwargs): + return_mask = kwargs.pop("return_mask", False) + + # arguments + _kwargs = {"return_tensors": "pt"} + if self.seq_len is not None: + _kwargs.update({"padding": "max_length", "truncation": True, "max_length": self.seq_len}) + _kwargs.update(**kwargs) + + # tokenization + if isinstance(sequence, str): + sequence = [sequence] + if self.clean: + sequence = [self._clean(u) for u in sequence] + ids = self.tokenizer(sequence, **_kwargs) + + # output + if return_mask: + return ids.input_ids, ids.attention_mask + else: + return ids.input_ids + + def _clean(self, text): + if self.clean == "whitespace": + text = whitespace_clean(basic_clean(text)) + elif self.clean == "lower": + text = whitespace_clean(basic_clean(text)).lower() + elif self.clean == "canonicalize": + text = canonicalize(basic_clean(text)) + return text diff --git a/vllm_omni/diffusion/models/lingbot_world_fast/vae2_1.py b/vllm_omni/diffusion/models/lingbot_world_fast/vae2_1.py new file mode 100644 index 00000000000..6017abe3476 --- /dev/null +++ b/vllm_omni/diffusion/models/lingbot_world_fast/vae2_1.py @@ -0,0 +1,610 @@ +# Adapted from Lingbot-World/wan/modules/vae2_1.py +import logging + +import torch +import torch.cuda.amp as amp +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + +__all__ = ["Wan2_1_VAE"] + +CACHE_T = 2 + + +class CausalConv3d(nn.Conv3d): + """ + Causal 3d convolution. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._padding = ( + self.padding[2], + self.padding[2], + self.padding[1], + self.padding[1], + 2 * self.padding[0], + 0, + ) + self.padding = (0, 0, 0) + + def forward(self, x, cache_x=None): + padding = list(self._padding) + if cache_x is not None and self._padding[4] > 0: + cache_x = cache_x.to(x.device) + x = torch.cat([cache_x, x], dim=2) + padding[4] -= cache_x.shape[2] + x = F.pad(x, padding) + + return super().forward(x) + + +class RMS_norm(nn.Module): + def __init__(self, dim, channel_first=True, images=True, bias=False): + super().__init__() + broadcastable_dims = (1, 1, 1) if not images else (1, 1) + shape = (dim, *broadcastable_dims) if channel_first else (dim,) + + self.channel_first = channel_first + self.scale = dim**0.5 + self.gamma = nn.Parameter(torch.ones(shape)) + self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0 + + def forward(self, x): + return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias + + +class Upsample(nn.Upsample): + def forward(self, x): + """ + Fix bfloat16 support for nearest neighbor interpolation. + """ + return super().forward(x.float()).type_as(x) + + +class Resample(nn.Module): + def __init__(self, dim, mode): + assert mode in ("none", "upsample2d", "upsample3d", "downsample2d", "downsample3d") + super().__init__() + self.dim = dim + self.mode = mode + + # layers + if mode == "upsample2d": + self.resample = nn.Sequential( + Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), + nn.Conv2d(dim, dim // 2, 3, padding=1), + ) + elif mode == "upsample3d": + self.resample = nn.Sequential( + Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), + nn.Conv2d(dim, dim // 2, 3, padding=1), + ) + self.time_conv = CausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0)) + + elif mode == "downsample2d": + self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))) + elif mode == "downsample3d": + self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))) + self.time_conv = CausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)) + + else: + self.resample = nn.Identity() + + def forward(self, x, feat_cache=None, feat_idx=[0]): + b, c, t, h, w = x.size() + if self.mode == "upsample3d": + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = "Rep" + feat_idx[0] += 1 + else: + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep": + # cache last frame of last two chunk + cache_x = torch.cat( + [feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2 + ) + if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] == "Rep": + cache_x = torch.cat([torch.zeros_like(cache_x).to(cache_x.device), cache_x], dim=2) + if feat_cache[idx] == "Rep": + x = self.time_conv(x) + else: + x = self.time_conv(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + + x = x.reshape(b, 2, c, t, h, w) + x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3) + x = x.reshape(b, c, t * 2, h, w) + t = x.shape[2] + x = rearrange(x, "b c t h w -> (b t) c h w") + x = self.resample(x) + x = rearrange(x, "(b t) c h w -> b c t h w", t=t) + + if self.mode == "downsample3d": + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = x.clone() + feat_idx[0] += 1 + else: + cache_x = x[:, :, -1:, :, :].clone() + x = self.time_conv(torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2)) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + return x + + def init_weight(self, conv): + conv_weight = conv.weight + nn.init.zeros_(conv_weight) + c1, c2, t, h, w = conv_weight.size() + one_matrix = torch.eye(c1, c2) + init_matrix = one_matrix + nn.init.zeros_(conv_weight) + conv_weight.data[:, :, 1, 0, 0] = init_matrix + conv.weight.data.copy_(conv_weight) + nn.init.zeros_(conv.bias.data) + + def init_weight2(self, conv): + conv_weight = conv.weight.data + nn.init.zeros_(conv_weight) + c1, c2, t, h, w = conv_weight.size() + init_matrix = torch.eye(c1 // 2, c2) + conv_weight[: c1 // 2, :, -1, 0, 0] = init_matrix + conv_weight[c1 // 2 :, :, -1, 0, 0] = init_matrix + conv.weight.data.copy_(conv_weight) + nn.init.zeros_(conv.bias.data) + + +class ResidualBlock(nn.Module): + def __init__(self, in_dim, out_dim, dropout=0.0): + super().__init__() + self.in_dim = in_dim + self.out_dim = out_dim + + # layers + self.residual = nn.Sequential( + RMS_norm(in_dim, images=False), + nn.SiLU(), + CausalConv3d(in_dim, out_dim, 3, padding=1), + RMS_norm(out_dim, images=False), + nn.SiLU(), + nn.Dropout(dropout), + CausalConv3d(out_dim, out_dim, 3, padding=1), + ) + self.shortcut = CausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity() + + def forward(self, x, feat_cache=None, feat_idx=[0]): + h = self.shortcut(x) + for layer in self.residual: + if isinstance(layer, CausalConv3d) and feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat( + [feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2 + ) + x = layer(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = layer(x) + return x + h + + +class AttentionBlock(nn.Module): + """ + Causal self-attention with a single head. + """ + + def __init__(self, dim): + super().__init__() + self.dim = dim + + # layers + self.norm = RMS_norm(dim) + self.to_qkv = nn.Conv2d(dim, dim * 3, 1) + self.proj = nn.Conv2d(dim, dim, 1) + + # zero out the last layer params + nn.init.zeros_(self.proj.weight) + + def forward(self, x): + identity = x + b, c, t, h, w = x.size() + x = rearrange(x, "b c t h w -> (b t) c h w") + x = self.norm(x) + # compute query, key, value + q, k, v = self.to_qkv(x).reshape(b * t, 1, c * 3, -1).permute(0, 1, 3, 2).contiguous().chunk(3, dim=-1) + + # apply attention + x = F.scaled_dot_product_attention(q, k, v) + x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w) + + # output + x = self.proj(x) + x = rearrange(x, "(b t) c h w-> b c t h w", t=t) + return x + identity + + +class Encoder3d(nn.Module): + def __init__( + self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_downsample=[True, True, False], + dropout=0.0, + ): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_downsample = temperal_downsample + + # dimensions + dims = [dim * u for u in [1] + dim_mult] + scale = 1.0 + + # init block + self.conv1 = CausalConv3d(3, dims[0], 3, padding=1) + + # downsample blocks + downsamples = [] + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + # residual (+attention) blocks + for _ in range(num_res_blocks): + downsamples.append(ResidualBlock(in_dim, out_dim, dropout)) + if scale in attn_scales: + downsamples.append(AttentionBlock(out_dim)) + in_dim = out_dim + + # downsample block + if i != len(dim_mult) - 1: + mode = "downsample3d" if temperal_downsample[i] else "downsample2d" + downsamples.append(Resample(out_dim, mode=mode)) + scale /= 2.0 + self.downsamples = nn.Sequential(*downsamples) + + # middle blocks + self.middle = nn.Sequential( + ResidualBlock(out_dim, out_dim, dropout), + AttentionBlock(out_dim), + ResidualBlock(out_dim, out_dim, dropout), + ) + + # output blocks + self.head = nn.Sequential( + RMS_norm(out_dim, images=False), nn.SiLU(), CausalConv3d(out_dim, z_dim, 3, padding=1) + ) + + def forward(self, x, feat_cache=None, feat_idx=[0]): + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + x = self.conv1(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv1(x) + + for layer in self.downsamples: + if feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + for layer in self.middle: + if isinstance(layer, ResidualBlock) and feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + for layer in self.head: + if isinstance(layer, CausalConv3d) and feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat( + [feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2 + ) + x = layer(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = layer(x) + return x + + +class Decoder3d(nn.Module): + def __init__( + self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temporal_upsample=[False, True, True], + dropout=0.0, + ): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temporal_upsample = temporal_upsample + + # dimensions + dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] + scale = 1.0 / 2 ** (len(dim_mult) - 2) + + # init block + self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1) + + # middle blocks + self.middle = nn.Sequential( + ResidualBlock(dims[0], dims[0], dropout), + AttentionBlock(dims[0]), + ResidualBlock(dims[0], dims[0], dropout), + ) + + # upsample blocks + upsamples = [] + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + # residual (+attention) blocks + if i == 1 or i == 2 or i == 3: + in_dim = in_dim // 2 + for _ in range(num_res_blocks + 1): + upsamples.append(ResidualBlock(in_dim, out_dim, dropout)) + if scale in attn_scales: + upsamples.append(AttentionBlock(out_dim)) + in_dim = out_dim + + # upsample block + if i != len(dim_mult) - 1: + mode = "upsample3d" if temporal_upsample[i] else "upsample2d" + upsamples.append(Resample(out_dim, mode=mode)) + scale *= 2.0 + self.upsamples = nn.Sequential(*upsamples) + + # output blocks + self.head = nn.Sequential(RMS_norm(out_dim, images=False), nn.SiLU(), CausalConv3d(out_dim, 3, 3, padding=1)) + + def forward(self, x, feat_cache=None, feat_idx=[0]): + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + x = self.conv1(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv1(x) + + for layer in self.middle: + if isinstance(layer, ResidualBlock) and feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + for layer in self.upsamples: + if feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + for layer in self.head: + if isinstance(layer, CausalConv3d) and feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat( + [feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2 + ) + x = layer(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = layer(x) + return x + + +def count_conv3d(model): + count = 0 + for m in model.modules(): + if isinstance(m, CausalConv3d): + count += 1 + return count + + +class WanVAE_(nn.Module): + def __init__( + self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_downsample=[True, True, False], + dropout=0.0, + ): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_downsample = temperal_downsample + self.temporal_upsample = temperal_downsample[::-1] + + # modules + self.encoder = Encoder3d( + dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout + ) + self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1) + self.conv2 = CausalConv3d(z_dim, z_dim, 1) + self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temporal_upsample, dropout) + + def forward(self, x): + mu, log_var = self.encode(x) + z = self.reparameterize(mu, log_var) + x_recon = self.decode(z) + return x_recon, mu, log_var + + def encode(self, x, scale): + self.clear_cache() + t = x.shape[2] + iter_ = 1 + (t - 1) // 4 + for i in range(iter_): + self._enc_conv_idx = [0] + if i == 0: + out = self.encoder(x[:, :, :1, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx) + else: + out_ = self.encoder( + x[:, :, 1 + 4 * (i - 1) : 1 + 4 * i, :, :], + feat_cache=self._enc_feat_map, + feat_idx=self._enc_conv_idx, + ) + out = torch.cat([out, out_], 2) + mu, log_var = self.conv1(out).chunk(2, dim=1) + if isinstance(scale[0], torch.Tensor): + mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(1, self.z_dim, 1, 1, 1) + else: + mu = (mu - scale[0]) * scale[1] + self.clear_cache() + return mu + + def decode(self, z, scale): + self.clear_cache() + if isinstance(scale[0], torch.Tensor): + z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(1, self.z_dim, 1, 1, 1) + else: + z = z / scale[1] + scale[0] + iter_ = z.shape[2] + x = self.conv2(z) + for i in range(iter_): + self._conv_idx = [0] + if i == 0: + out = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx) + else: + out_ = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx) + out = torch.cat([out, out_], 2) + self.clear_cache() + return out + + def reparameterize(self, mu, log_var): + std = torch.exp(0.5 * log_var) + eps = torch.randn_like(std) + return eps * std + mu + + def sample(self, imgs, deterministic=False): + mu, log_var = self.encode(imgs) + if deterministic: + return mu + std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0)) + return mu + std * torch.randn_like(std) + + def clear_cache(self): + self._conv_num = count_conv3d(self.decoder) + self._conv_idx = [0] + self._feat_map = [None] * self._conv_num + # cache encode + self._enc_conv_num = count_conv3d(self.encoder) + self._enc_conv_idx = [0] + self._enc_feat_map = [None] * self._enc_conv_num + + +def _video_vae(pretrained_path=None, z_dim=None, device="cpu", **kwargs): + """ + Autoencoder3d adapted from Stable Diffusion 1.x, 2.x and XL. + """ + cfg = dict( + dim=96, + z_dim=z_dim, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_downsample=[False, True, True], + dropout=0.0, + ) + cfg.update(**kwargs) + + # init model + with torch.device("meta"): + model = WanVAE_(**cfg) + + logging.info(f"loading {pretrained_path}") + model.load_state_dict(torch.load(pretrained_path, map_location=device), assign=True) + + return model + + +class Wan2_1_VAE: + def __init__(self, z_dim=16, vae_pth="cache/vae_step_411000.pth", dtype=torch.float, device="cuda"): + self.dtype = dtype + self.device = device + + mean = [ + -0.7571, + -0.7089, + -0.9113, + 0.1075, + -0.1745, + 0.9653, + -0.1517, + 1.5508, + 0.4134, + -0.0715, + 0.5517, + -0.3632, + -0.1922, + -0.9497, + 0.2503, + -0.2921, + ] + std = [ + 2.8184, + 1.4541, + 2.3275, + 2.6558, + 1.2196, + 1.7708, + 2.6052, + 2.0743, + 3.2687, + 2.1526, + 2.8652, + 1.5579, + 1.6382, + 1.1253, + 2.8251, + 1.9160, + ] + self.mean = torch.tensor(mean, dtype=dtype, device=device) + self.std = torch.tensor(std, dtype=dtype, device=device) + self.scale = [self.mean, 1.0 / self.std] + + # init model + self.model = _video_vae(pretrained_path=vae_pth, z_dim=z_dim).eval().requires_grad_(False).to(device) + + def encode(self, videos): + """ + videos: A list of videos each with shape [C, T, H, W]. + """ + with amp.autocast(dtype=self.dtype): + return [self.model.encode(u.unsqueeze(0), self.scale).float().squeeze(0) for u in videos] + + def decode(self, zs): + with amp.autocast(dtype=self.dtype): + return [self.model.decode(u.unsqueeze(0), self.scale).float().clamp_(-1, 1).squeeze(0) for u in zs] diff --git a/vllm_omni/diffusion/models/lingbot_world_fast/wan_fast.py b/vllm_omni/diffusion/models/lingbot_world_fast/wan_fast.py index dc7d7961553..1011c2bdf54 100644 --- a/vllm_omni/diffusion/models/lingbot_world_fast/wan_fast.py +++ b/vllm_omni/diffusion/models/lingbot_world_fast/wan_fast.py @@ -8,11 +8,11 @@ from diffusers.configuration_utils import ConfigMixin, register_to_config from diffusers.models.modeling_utils import ModelMixin from einops import rearrange -from wan.modules.model import WanLayerNorm, WanRMSNorm, WanSelfAttention, rope_params, sinusoidal_embedding_1d from vllm_omni.diffusion.attention.layer import Attention from .state_lingbot_world_fast import CacheIndex +from .wan_model import WanLayerNorm, WanRMSNorm, WanSelfAttention, rope_params, sinusoidal_embedding_1d def causal_rope_apply(x, grid_sizes, freqs, start_frame=0): diff --git a/vllm_omni/diffusion/models/lingbot_world_fast/wan_model.py b/vllm_omni/diffusion/models/lingbot_world_fast/wan_model.py new file mode 100644 index 00000000000..4a21edd4585 --- /dev/null +++ b/vllm_omni/diffusion/models/lingbot_world_fast/wan_model.py @@ -0,0 +1,95 @@ +# Adapted from Lingbot-World/wan/modules/model.py +# +# Only the building blocks used by wan_fast.py are kept here: norm layers, +# the self-attention __init__ shape (used as a base class for the local +# WanCrossAttention that overrides forward), and the rope / time-embedding +# helpers. The original `flash_attention`-based forward paths are not used by +# the Fast pipeline, so this file does not depend on wan.modules.attention. + +import torch +import torch.nn as nn + +__all__ = [ + "WanLayerNorm", + "WanRMSNorm", + "WanSelfAttention", + "rope_params", + "sinusoidal_embedding_1d", +] + + +def sinusoidal_embedding_1d(dim, position): + assert dim % 2 == 0 + half = dim // 2 + position = position.type(torch.float64) + + sinusoid = torch.outer(position, torch.pow(10000, -torch.arange(half).to(position).div(half))) + x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1) + return x + + +@torch.amp.autocast("cuda", enabled=False) +def rope_params(max_seq_len, dim, theta=10000): + assert dim % 2 == 0 + freqs = torch.outer( + torch.arange(max_seq_len), 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float64).div(dim)) + ) + freqs = torch.polar(torch.ones_like(freqs), freqs) + return freqs + + +class WanRMSNorm(nn.Module): + def __init__(self, dim, eps=1e-5): + super().__init__() + self.dim = dim + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def forward(self, x): + r""" + Args: + x(Tensor): Shape [B, L, C] + """ + return self._norm(x.float()).type_as(x) * self.weight + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps) + + +class WanLayerNorm(nn.LayerNorm): + def __init__(self, dim, eps=1e-6, elementwise_affine=False): + super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps) + + def forward(self, x): + r""" + Args: + x(Tensor): Shape [B, L, C] + """ + return super().forward(x.float()).type_as(x) + + +class WanSelfAttention(nn.Module): + """Base class providing the Q/K/V/O linear layers and (optional) QK RMSNorm. + + `wan_fast.py` only consumes this class via inheritance and `super().__init__()` + — the attention forward path is always overridden — so this trimmed copy + intentionally omits the flash-attention-based forward. + """ + + def __init__(self, dim, num_heads, window_size=(-1, -1), qk_norm=True, eps=1e-6): + assert dim % num_heads == 0 + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.window_size = window_size + self.qk_norm = qk_norm + self.eps = eps + + # layers + self.q = nn.Linear(dim, dim) + self.k = nn.Linear(dim, dim) + self.v = nn.Linear(dim, dim) + self.o = nn.Linear(dim, dim) + self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity() + self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity() diff --git a/vllm_omni/entrypoints/cli/serve.py b/vllm_omni/entrypoints/cli/serve.py index c49c984b345..241a1c02362 100644 --- a/vllm_omni/entrypoints/cli/serve.py +++ b/vllm_omni/entrypoints/cli/serve.py @@ -548,7 +548,7 @@ def subparser_init(self, subparsers: argparse._SubParsersAction) -> FlexibleArgu omni_config_group.add_argument( "--ws-max-size", type=int, - default=1_048_576, # 1MB + default=1_048_576, # 1MB help="Change max size of a websocket payload that is accepted by the server", ) omni_config_group.add_argument( From 47cdf31a53459f8e3efada2396c1da19546e9662 Mon Sep 17 00:00:00 2001 From: Mahdi Nasser <94046147+mnasser02@users.noreply.github.com> Date: Wed, 20 May 2026 10:39:31 +0200 Subject: [PATCH 43/53] bugfix Signed-off-by: Mahdi Nasser <94046147+mnasser02@users.noreply.github.com> --- vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py | 7 ++++--- vllm_omni/diffusion/sched/stream_batch_scheduler.py | 6 +++--- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py index 6edf524c5c3..6c75dd15ae2 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py @@ -1323,6 +1323,7 @@ def set_pp_recv_dict_buffers(self, state: DiffusionRequestState) -> None: pp_group = get_pp_group() if pp_group.world_size == 1: + self.is_buffer_setup = True return # Pre-populate buffer pairs for every B in 1..slo_max_batch @@ -1401,7 +1402,7 @@ def denoise_step( ) preposted_its = state.extra.pop("preposted_its", None) - return self.predict_noise_maybe_with_pp_and_cfg( + return self.predict_noise_maybe_with_cfg( do_true_cfg=do_true_cfg, true_cfg_scale=current_guidance_scale, positive_kwargs=positive_kwargs, @@ -1454,7 +1455,7 @@ def step_scheduler( if per_request_scheduler is None: per_request_scheduler = state.scheduler - state.latents = self.scheduler_step_maybe_with_pp_and_cfg( + state.latents = self.scheduler_step_maybe_with_cfg( noise_pred, t, state.latents, @@ -1471,7 +1472,7 @@ def post_decode( **kwargs: Any, ) -> DiffusionOutput: """Decode final latents after denoising completes.""" - self.sync_pp_send() + self._sync_pp_send() self._current_timestep = None # if current_omni_platform.is_available(): diff --git a/vllm_omni/diffusion/sched/stream_batch_scheduler.py b/vllm_omni/diffusion/sched/stream_batch_scheduler.py index c9b61ca9f2f..274fabf9967 100644 --- a/vllm_omni/diffusion/sched/stream_batch_scheduler.py +++ b/vllm_omni/diffusion/sched/stream_batch_scheduler.py @@ -109,7 +109,7 @@ def register( chunk_frames=max(1, chunk_frames), ) - def get_target(self, sched_req_id: int) -> int: + def get_target(self, sched_req_id: str) -> int: st = self._reqs.get(sched_req_id) return st.batch_size if st is not None else 1 @@ -299,8 +299,8 @@ def _advance_chunk_pipeline(self, progress: _Progress) -> Rank0Layout: def _build_assignment(self) -> list[RankTask | None]: assert len(self._progress) <= 1 #TODO: support multiple requests + assignment: list[RankTask | None] = [None] * self.pp_size for progress in self._progress.values(): - assignment: list[RankTask | None] = [None] * self.pp_size for r in range(self.pp_size): queue = progress.chunks_at[r] if not queue: @@ -309,7 +309,7 @@ def _build_assignment(self) -> list[RankTask | None]: sched_req_id=progress.sched_req_id, chunk_indices=[c.chunk_idx for c in queue], ) - return assignment + return assignment # ── Output processing ────────────────────────────────────────────────── From e9fadb44c6c22411b388c5de2f3ccda5b831be6a Mon Sep 17 00:00:00 2001 From: Mahdi Nasser <94046147+mnasser02@users.noreply.github.com> Date: Wed, 20 May 2026 10:59:17 +0200 Subject: [PATCH 44/53] fixes Signed-off-by: Mahdi Nasser <94046147+mnasser02@users.noreply.github.com> --- tests/diffusion/test_diffusion_scheduler.py | 2 +- vllm_omni/diffusion/executor/abstract.py | 2 +- vllm_omni/diffusion/executor/multiproc_executor.py | 6 +++--- vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py | 6 +++++- vllm_omni/diffusion/sched/stream_batch_scheduler.py | 5 +++-- vllm_omni/diffusion/worker/utils.py | 5 ----- 6 files changed, 13 insertions(+), 13 deletions(-) diff --git a/tests/diffusion/test_diffusion_scheduler.py b/tests/diffusion/test_diffusion_scheduler.py index f4aaa27b757..13ac9969244 100644 --- a/tests/diffusion/test_diffusion_scheduler.py +++ b/tests/diffusion/test_diffusion_scheduler.py @@ -1071,7 +1071,7 @@ def test_fifo_two_requests(self) -> None: out1 = scheduler.schedule() assert _new_ids(out1) == [] - finished = scheduler.update_from_output(out1, _make_stream_output(req_a, finished=True)) + scheduler.update_from_output(out1, _make_stream_output(req_a, finished=True)) scheduler.pop_request_state(req_a) out2 = scheduler.schedule() diff --git a/vllm_omni/diffusion/executor/abstract.py b/vllm_omni/diffusion/executor/abstract.py index 53715e2f7b1..5d51bbe50a3 100644 --- a/vllm_omni/diffusion/executor/abstract.py +++ b/vllm_omni/diffusion/executor/abstract.py @@ -85,7 +85,7 @@ def execute_step(self, scheduler_output: DiffusionSchedulerOutput) -> BaseRunner pass @abstractmethod - def execute_micro_step(self, scheduler_output: DiffusionSchedulerOutput) -> RunnerOutput: + def execute_micro_step(self, scheduler_output: DiffusionSchedulerOutput) -> BaseRunnerOutput: """Execute one temporal-PP micro-step from a scheduler output.""" pass diff --git a/vllm_omni/diffusion/executor/multiproc_executor.py b/vllm_omni/diffusion/executor/multiproc_executor.py index d5583e36e01..cc2c1ed21e4 100644 --- a/vllm_omni/diffusion/executor/multiproc_executor.py +++ b/vllm_omni/diffusion/executor/multiproc_executor.py @@ -344,13 +344,13 @@ def execute_step(self, scheduler_output: DiffusionSchedulerOutput) -> BaseRunner else: raise RuntimeError(f"Unexpected response type for execute_step: {type(result)!r}") - def execute_micro_step(self, scheduler_output: DiffusionSchedulerOutput) -> RunnerOutput: + def execute_micro_step(self, scheduler_output: DiffusionSchedulerOutput) -> BaseRunnerOutput: """Forward a temporal-PP micro-step to worker ``execute_micro_step`` RPC. Assumes worker rank == PP rank (true for PP-only layouts; revisit when introducing TP/DP combinations). """ - from vllm_omni.diffusion.worker.utils import RunnerOutput + from vllm_omni.diffusion.worker.utils import BaseRunnerOutput, RunnerOutput self._ensure_open() @@ -360,7 +360,7 @@ def execute_micro_step(self, scheduler_output: DiffusionSchedulerOutput) -> Runn unique_reply_rank=0, exec_all_ranks=True, ) - if not isinstance(result, RunnerOutput): + if not isinstance(result, BaseRunnerOutput): raise RuntimeError( f"Unexpected response type for execute_micro_step: {type(result)!r}" ) diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py index 6c75dd15ae2..ee2545bb1ae 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py @@ -1545,7 +1545,11 @@ def encode_chunk_inputs( frames = torch.cat([frames, pad], dim=0) # [n, C, H, W] -> [1, C, n, H, W] - control = frames.permute(1, 0, 2, 3).unsqueeze(0).to(device=self.device, dtype=self.vae.dtype) + control = ( + frames.permute(1, 0, 2, 3) + .unsqueeze(0) + .to(device=self.device, dtype=self.vae.dtype) + ) clean = retrieve_latents(self.vae.encode(control), sample_mode="argmax") clean = ((clean.float() - latents_mean.to(clean.dtype)) * latents_inv_std.to(clean.dtype)).to(self.vae.dtype) noise = torch.randn_like(clean) diff --git a/vllm_omni/diffusion/sched/stream_batch_scheduler.py b/vllm_omni/diffusion/sched/stream_batch_scheduler.py index 274fabf9967..fc68b3ef3a8 100644 --- a/vllm_omni/diffusion/sched/stream_batch_scheduler.py +++ b/vllm_omni/diffusion/sched/stream_batch_scheduler.py @@ -70,8 +70,9 @@ class _Progress: frames_committed: int = 0 next_chunk_idx: int = 0 batch_size: int = 0 - - chunks_at: list[deque[_InFlightChunk]] = field(default_factory=list) # chunks that will be processed by rank r at the current micro-step + + # chunks that will be processed by rank r at the current micro-step + chunks_at: list[deque[_InFlightChunk]] = field(default_factory=list) @property def output_chunks_target(self) -> int: diff --git a/vllm_omni/diffusion/worker/utils.py b/vllm_omni/diffusion/worker/utils.py index b2e9d241c16..ccf13d01fef 100644 --- a/vllm_omni/diffusion/worker/utils.py +++ b/vllm_omni/diffusion/worker/utils.py @@ -119,11 +119,6 @@ class ChunkState: Lives inside ``DiffusionRequestState.extra["chunks"]`` (keyed by ``chunk_idx``). - - Each chunk owns its own ``scheduler`` instance (deepcopied from the - pipeline's scheduler by ``prepare_encode``) because multi-step ODE solvers - (e.g. ``FlowUniPCMultistepScheduler``) are stateful — they track per-step - ``model_outputs`` that must not leak between chunks. """ idx: int From 31317c3662df438f56858579713f33b716df340b Mon Sep 17 00:00:00 2001 From: Mahdi Nasser <94046147+mnasser02@users.noreply.github.com> Date: Tue, 26 May 2026 01:06:58 +0200 Subject: [PATCH 45/53] fix the flow of latents aross ranks Signed-off-by: Mahdi Nasser <94046147+mnasser02@users.noreply.github.com> --- .../distributed/group_coordinator.py | 4 +- .../distributed/pipeline_parallel.py | 44 ++--- .../diffusion/executor/multiproc_executor.py | 18 +- vllm_omni/diffusion/models/interface.py | 9 +- .../scheduling_flow_unipc_multistep.py | 10 ++ .../models/wan2_2/pipeline_wan2_2.py | 149 +++++++++------- .../models/wan2_2/scheduling_wan_euler.py | 7 + vllm_omni/diffusion/sched/interface.py | 33 ++-- .../diffusion/sched/stream_batch_scheduler.py | 62 ++++--- .../worker/diffusion_model_runner.py | 163 ++++++++---------- 10 files changed, 271 insertions(+), 228 deletions(-) diff --git a/vllm_omni/diffusion/distributed/group_coordinator.py b/vllm_omni/diffusion/distributed/group_coordinator.py index 7bc6f40cf54..1bde474e6f4 100644 --- a/vllm_omni/diffusion/distributed/group_coordinator.py +++ b/vllm_omni/diffusion/distributed/group_coordinator.py @@ -1046,7 +1046,7 @@ def set_recv_dict_buffer( if isinstance(value, TensorMetadata): if torch.Size(value.size).numel() == 0: continue - device = self.device if value.device == "cuda" else torch.device(value.device) + device = torch.device(value.device) if value.device == "cpu" else self.device buffers[key_] = torch.empty(value.size, dtype=value.dtype, device=device) buffer_pair.append(buffers) self.dict_recv_buffer[key] = buffer_pair @@ -1107,7 +1107,7 @@ def pipeline_irecv_tensor_dict( if isinstance(value, TensorMetadata): if torch.Size(value.size).numel() == 0: continue - device = self.device if value.device == "cuda" else torch.device(value.device) + device = torch.device(value.device) if value.device == "cpu" else self.device buffers[k] = torch.empty(value.size, dtype=value.dtype, device=device) buffer_pair.append(buffers) self.dict_recv_buffer[key] = buffer_pair diff --git a/vllm_omni/diffusion/distributed/pipeline_parallel.py b/vllm_omni/diffusion/distributed/pipeline_parallel.py index 2bc82c4c83e..1cfb187b89a 100644 --- a/vllm_omni/diffusion/distributed/pipeline_parallel.py +++ b/vllm_omni/diffusion/distributed/pipeline_parallel.py @@ -55,6 +55,9 @@ def _resolve(self) -> torch.Tensor: def __getattr__(self, name: str): return getattr(self._resolve(), name) + def __getitem__(self, key): + return self._resolve()[key] + # Torch function protocol: any torch op involving an AsyncLatents resolves it first. @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): @@ -237,7 +240,7 @@ def predict_noise_maybe_with_cfg( if not pp_group.is_last_rank: # First / middle rank: run partial forwards and propagate ITs downstream. - for i, (kwargs, it) in enumerate(zip(all_kwargs, its)): + for kwargs, it in zip(all_kwargs, its): result = self.predict_noise(**kwargs, intermediate_tensors=it) self._pp_send_work.extend( pp_group.pipeline_isend_tensor_dict( @@ -277,18 +280,13 @@ def scheduler_step_maybe_with_cfg( latents: torch.Tensor | tuple[torch.Tensor, ...], do_true_cfg: bool, per_request_scheduler: Any | list[Any] | None = None, - buf_idx: int = 0, batch_size: int = 1, - ) -> torch.Tensor | tuple[torch.Tensor, ...] | AsyncLatents: + ) -> torch.Tensor | tuple[torch.Tensor, ...]: """ Drop-in replacement for scheduler_step_maybe_with_cfg that also handles PP. - Only the last rank runs the scheduler (it already has noise_pred); the result - is sent to rank 0 which needs it for the next forward pass. - - Returns ``AsyncLatents`` on rank 0 that defers wait() until the tensor - is actually consumed (via attribute access or a torch op), keeping the - rank non-blocking after the irecv. + Only the last rank runs the scheduler (it already has noise_pred) and + sends the result back to rank 0. """ if get_pipeline_parallel_world_size() == 1: return self._scheduler_step_local(noise_pred, t, latents, do_true_cfg, per_request_scheduler) @@ -299,10 +297,6 @@ def scheduler_step_maybe_with_cfg( self._pp_send_work = pp_group.pipeline_isend_tensor_dict( {"latents": latents}, name="latents", batch_size=batch_size, ) - elif pp_group.is_first_rank: - latents = AsyncLatents( - *pp_group.pipeline_irecv_tensor_dict(name="latents", buf_idx=buf_idx, batch_size=batch_size) - ) return latents def _scheduler_step_local( @@ -328,23 +322,33 @@ def _scheduler_step_local( ) return torch.cat(new_rows, dim=0) - def prefetch_its_maybe_with_cfg( + def prefetch_tensors_maybe_with_cfg( self, do_true_cfg: bool, buf_idx: int, - is_last_step: bool, batch_size: int = 1, - ) -> list[AsyncIntermediateTensors] | None: - pp_group = get_pp_group() - if pp_group.is_first_rank or is_last_step: + ) -> list[AsyncIntermediateTensors] | AsyncLatents | None: + """Pre-post the next-step recv on this rank's comms stream. + + First rank pre-posts the latents irecv from the last rank. + Non-first ranks pre-post the intermediate-tensor irecv from the previous rank. + """ + if get_pipeline_parallel_world_size() == 1: return None + pp_group = get_pp_group() + if pp_group.is_first_rank: + return AsyncLatents( + *pp_group.pipeline_irecv_tensor_dict( + name="latents", buf_idx=buf_idx, batch_size=batch_size, + ) + ) + cfg_parallel_ready = do_true_cfg and get_classifier_free_guidance_world_size() > 1 n = 1 if cfg_parallel_ready else (2 if do_true_cfg else 1) - next_buf_idx = (buf_idx + 1) % 2 return [ AsyncIntermediateTensors( *pp_group.pipeline_irecv_tensor_dict( - name="intermediate", segment_idx=i, buf_idx=next_buf_idx, batch_size=batch_size, + name="intermediate", segment_idx=i, buf_idx=buf_idx, batch_size=batch_size, ) ) for i in range(n) diff --git a/vllm_omni/diffusion/executor/multiproc_executor.py b/vllm_omni/diffusion/executor/multiproc_executor.py index cc2c1ed21e4..043833c87fb 100644 --- a/vllm_omni/diffusion/executor/multiproc_executor.py +++ b/vllm_omni/diffusion/executor/multiproc_executor.py @@ -360,12 +360,20 @@ def execute_micro_step(self, scheduler_output: DiffusionSchedulerOutput) -> Base unique_reply_rank=0, exec_all_ranks=True, ) - if not isinstance(result, BaseRunnerOutput): - raise RuntimeError( - f"Unexpected response type for execute_micro_step: {type(result)!r}" + + if isinstance(result, BaseRunnerOutput): + return result + if isinstance(result, DiffusionOutput): + req_id = scheduler_output.scheduled_req_ids[0] if scheduler_output.scheduled_req_ids else "" + return RunnerOutput( + req_id=req_id, + step_index=None, + finished=True, + result=result, ) - return result - + else: + raise RuntimeError(f"Unexpected response type for execute_step: {type(result)!r}") + def collective_rpc( self, method: str, diff --git a/vllm_omni/diffusion/models/interface.py b/vllm_omni/diffusion/models/interface.py index 5e6aa56e095..a104f7512ce 100644 --- a/vllm_omni/diffusion/models/interface.py +++ b/vllm_omni/diffusion/models/interface.py @@ -98,8 +98,9 @@ class SupportsMicroStepExecution(SupportsStepExecution, Protocol): - ``set_pp_recv_dict_buffers`` pre-registers PPGC dict channels for this request to skip the blocking first-call schema exchange. - - ``prefetch_its`` pre-posts the next-step IT recv on the comms stream - so it overlaps with the current micro-step's compute. + - ``prefetch_tensors`` pre-posts the next-step recv on the comms stream + so it overlaps with the current micro-step's compute (latents on the + first PP rank, intermediate tensors on the others). """ supports_micro_step_execution: ClassVar[bool] = True @@ -107,8 +108,8 @@ class SupportsMicroStepExecution(SupportsStepExecution, Protocol): def set_pp_recv_dict_buffers(self, state: DiffusionRequestState, **kwargs: Any) -> None: """Pre-register PP dict recv buffers and schema cache for this request.""" - def prefetch_its(self, state: DiffusionRequestState, **kwargs: Any) -> None: - """Pre-post the next-step IT recv (no-op if not in temporal PP).""" + def prefetch_tensors(self, state: DiffusionRequestState, **kwargs: Any) -> None: + """Pre-post the next-step recv.""" def supports_micro_step_execution(pipeline: object) -> bool: diff --git a/vllm_omni/diffusion/models/schedulers/scheduling_flow_unipc_multistep.py b/vllm_omni/diffusion/models/schedulers/scheduling_flow_unipc_multistep.py index 3efe564bc61..3cc6fe94833 100644 --- a/vllm_omni/diffusion/models/schedulers/scheduling_flow_unipc_multistep.py +++ b/vllm_omni/diffusion/models/schedulers/scheduling_flow_unipc_multistep.py @@ -161,6 +161,7 @@ def set_timesteps( sigmas: list[float] | None = None, mu: float | None = None, shift: float | None = None, + sigma_start: float = 1.0, ) -> None: """ Sets the discrete timesteps used for the diffusion chain (run before inference). @@ -176,10 +177,16 @@ def set_timesteps( Parameter for dynamic shifting. shift (`float`, *optional*): Override shift parameter. + sigma_start (`float`, defaults to 1.0): + Scales the post-shift sigmas so step 0 + lands at ``sigma_start`` instead of 1.0. """ if self.config.use_dynamic_shifting and mu is None: raise ValueError("Must pass a value for `mu` when `use_dynamic_shifting` is True") + if not 0.0 < sigma_start <= 1.0: + raise ValueError(f"sigma_start must be in (0, 1], got {sigma_start}") + if sigmas is None: assert num_inference_steps is not None sigmas = np.linspace(self.sigma_max, self.sigma_min, num_inference_steps + 1).copy()[:-1] @@ -193,6 +200,9 @@ def set_timesteps( assert isinstance(sigmas, np.ndarray) sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) + if sigma_start != 1.0: + sigmas = sigmas * float(sigma_start) + if self.config.final_sigmas_type == "sigma_min": sigma_last = self.sigma_min elif self.config.final_sigmas_type == "zero": diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py index ee2545bb1ae..6766d376e9b 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py @@ -441,8 +441,6 @@ def __init__( enable_diffusion_pipeline_profiler=self.od_config.enable_diffusion_pipeline_profiler ) - self.is_buffer_setup = False - def _create_transformer(self, config: dict) -> WanTransformer3DModel: """Create a transformer from a config dict. Respects od_config.quantization_config.""" quant_config = getattr(self.od_config, "quantization_config", None) @@ -1041,7 +1039,7 @@ def check_inputs( if boundary_ratio is None and guidance_scale_2 is not None: raise ValueError("`guidance_scale_2` is only supported when `boundary_ratio` is set.") - # ── Step-execution protocol (SupportsStepExecution) + Micro-step execution (SupportsMicroStepExecution) ── + # ── Micro-step execution (SupportsMicroStepExecution) Wan2.1 model (single-stage) ── def _extract_prompts( self, @@ -1191,6 +1189,7 @@ def prepare_encode( height = params["height"] width = params["width"] num_frames = params["num_frames"] + chunk_frames = params["chunk_frames"] device = params["device"] dtype = params["dtype"] generator = params["generator"] @@ -1213,11 +1212,7 @@ def prepare_encode( dtype=dtype, ) - # Scheduler - self.scheduler.set_timesteps(params["num_steps"], device=device) - req_scheduler = copy.deepcopy(self.scheduler) - - # I2V conditioning + # Multi-modal inputs (I2V image, V2V video) multi_modal_data = ( state.prompts[0].get("multi_modal_data", {}) if state.prompts and not isinstance(state.prompts[0], str) @@ -1227,9 +1222,31 @@ def prepare_encode( if isinstance(raw_image, list): raw_image = raw_image[0] + v2v_video = multi_modal_data.get("video", None) if multi_modal_data else None + noise_scale = float((state.sampling.extra_args or {}).get("noise_scale", 0.8)) + sigma_start = noise_scale if (v2v_video is not None and 0.0 < noise_scale < 1.0) else 1.0 + + # Scheduler + self.scheduler.set_timesteps(params["num_steps"], device=device, sigma_start=sigma_start) + req_scheduler = copy.deepcopy(self.scheduler) + latent_condition = None first_frame_mask = None + if v2v_video is not None: + # V2V mode + num_channels_latents = self.transformer_config.in_channels + latents = self.prepare_latents( + batch_size=prompt_embeds.shape[0], + num_channels_latents=num_channels_latents, + height=height, + width=width, + num_frames=chunk_frames, + dtype=torch.float32, + device="meta", # NOTE: stream mode; latents will be prepared per-chunk in `encode_chunk_inputs` + generator=generator, + latents=state.sampling.latents, + ) if self.expand_timesteps and raw_image is not None: # I2V mode from diffusers.video_processor import VideoProcessor @@ -1323,35 +1340,33 @@ def set_pp_recv_dict_buffers(self, state: DiffusionRequestState) -> None: pp_group = get_pp_group() if pp_group.world_size == 1: - self.is_buffer_setup = True return - # Pre-populate buffer pairs for every B in 1..slo_max_batch slo_fps = getattr(state.sampling, "slo_fps", None) slo_max_batch = getattr(state.sampling, "slo_max_batch", 1) - b_max = max(1, slo_max_batch if slo_fps else 1) + slo_max_batch = max(1, slo_max_batch if slo_fps else 1) + num_inference_steps = state.sampling.num_inference_steps or 4 - _, channels, num_frames, height, width = state.latents.shape + _, channels, latent_chunk_frames, height, width = state.latents.shape p_t, p_h, p_w = self.transformer_config.patch_size - seq_len = (num_frames // p_t) * (height // p_h) * (width // p_w) inner_dim = self.transformer_config.num_attention_heads * self.transformer_config.attention_head_dim it_dtype = (self.transformer or self.transformer_2).dtype latents_dtype = state.latents.dtype - device = state.latents.device + + seq_len = (latent_chunk_frames // p_t) * (height // p_h) * (width // p_w) cfg_branches = 2 if state.do_true_cfg else 1 - for B in range(1, b_max + 1): + + for batch_size in range(1, slo_max_batch * num_inference_steps + 1): latents_template = { - "latents": torch.empty(B, channels, num_frames, height, width, dtype=latents_dtype, device=device) + "latents": torch.empty(batch_size, channels, latent_chunk_frames, height, width, dtype=latents_dtype, device="meta") } it_template = { - "hidden_states": torch.empty(B, seq_len, inner_dim, dtype=it_dtype, device=self.device) + "hidden_states": torch.empty(batch_size, seq_len, inner_dim, dtype=it_dtype, device="meta") } - pp_group.set_recv_dict_buffer("latents", -1, latents_template, batch_size=B) + pp_group.set_recv_dict_buffer("latents", -1, latents_template, batch_size=batch_size) for seg in range(cfg_branches): - pp_group.set_recv_dict_buffer("intermediate", seg, it_template, batch_size=B) - - self.is_buffer_setup = True + pp_group.set_recv_dict_buffer("intermediate", seg, it_template, batch_size=batch_size) def denoise_step( self, @@ -1366,8 +1381,8 @@ def denoise_step( it overrides ``current_timestep`` and ``_prepare_latent_input`` forwards the per-row timesteps directly to the transformer. Model selection uses the per-row max so a batch straddling the - high/low noise boundary picks the high-noise transformer (correct - for single-stage Wan2.1; Wan2.2 boundary-straddling is a TODO). + high/low noise boundary picks the high-noise transformer (NOTE correct + for single-stage Wan2.1). """ t = state.batched_timesteps if state.batched_timesteps is not None else state.current_timestep self._current_timestep = t @@ -1380,10 +1395,19 @@ def denoise_step( do_true_cfg = current_guidance_scale > 1.0 and state.negative_prompt_embeds is not None + # Stream-batch path fuses B chunks along dim 0 + B = latent_model_input.shape[0] + encoder_pos = state.prompt_embeds + if encoder_pos is not None and encoder_pos.shape[0] == 1 and B > 1: + encoder_pos = encoder_pos.expand(B, *encoder_pos.shape[1:]).contiguous() + encoder_neg = state.negative_prompt_embeds + if encoder_neg is not None and encoder_neg.shape[0] == 1 and B > 1: + encoder_neg = encoder_neg.expand(B, *encoder_neg.shape[1:]).contiguous() + positive_kwargs = { "hidden_states": latent_model_input, "timestep": timestep, - "encoder_hidden_states": state.prompt_embeds, + "encoder_hidden_states": encoder_pos, "attention_kwargs": {}, "return_dict": False, "current_model": current_model, @@ -1392,7 +1416,7 @@ def denoise_step( { "hidden_states": latent_model_input, "timestep": timestep, - "encoder_hidden_states": state.negative_prompt_embeds, + "encoder_hidden_states": encoder_neg, "attention_kwargs": {}, "return_dict": False, "current_model": current_model, @@ -1412,22 +1436,20 @@ def denoise_step( preposted_its=preposted_its, ) - def prefetch_its(self, state: DiffusionRequestState, batch_size: int = 1) -> None: - """Prefetch intermediate tensors for the next step.""" - t = state.current_timestep - boundary_timestep = state.extra.get("boundary_timestep") - _, current_guidance_scale = self._select_model_for_timestep(t, boundary_timestep) - do_true_cfg = current_guidance_scale > 1.0 and state.negative_prompt_embeds is not None + def prefetch_tensors(self, state: DiffusionRequestState, batch_size: int = 1) -> None: + """Prefetch next-step tensors: latents on first rank, ITs on others.""" + do_true_cfg = state.do_true_cfg buf_idx = state.step_index % 2 - is_last_step = state.step_index == state.total_steps - 1 - preposted = self.prefetch_its_maybe_with_cfg( + preposted = self.prefetch_tensors_maybe_with_cfg( do_true_cfg=do_true_cfg, buf_idx=buf_idx, - is_last_step=is_last_step, batch_size=batch_size, ) - if preposted is not None: + + if isinstance(preposted, AsyncLatents): + state.latents = preposted + elif preposted is not None: state.extra["preposted_its"] = preposted def step_scheduler( @@ -1443,14 +1465,9 @@ def step_scheduler( ``per_request_scheduler`` may be a single scheduler (B=1) or a list of per-chunk schedulers (B>1, last rank loops one ``step()`` per row). - ``batch_size`` keys the PP recv buffer pool on first rank. """ t = state.batched_timesteps if state.batched_timesteps is not None else state.current_timestep - boundary_timestep = state.extra.get("boundary_timestep") - t_select = t.max() if t.ndim > 0 else t - _, current_guidance_scale = self._select_model_for_timestep(t_select, boundary_timestep) - do_true_cfg = current_guidance_scale > 1.0 and state.negative_prompt_embeds is not None - buf_idx = state.step_index % 2 + do_true_cfg = state.do_true_cfg if per_request_scheduler is None: per_request_scheduler = state.scheduler @@ -1461,9 +1478,9 @@ def step_scheduler( state.latents, do_true_cfg, per_request_scheduler=per_request_scheduler, - buf_idx=buf_idx, batch_size=batch_size, ) + state.step_index += 1 def post_decode( @@ -1475,9 +1492,6 @@ def post_decode( self._sync_pp_send() self._current_timestep = None - # if current_omni_platform.is_available(): - # current_omni_platform.empty_cache() - # I2V: blend final latents with condition latent_condition = state.extra.get("latent_condition") first_frame_mask = state.extra.get("first_frame_mask") @@ -1485,15 +1499,13 @@ def post_decode( state.latents = (1 - first_frame_mask) * latent_condition + first_frame_mask * state.latents latents = state.latents.to(self.vae.dtype) - latents_mean = ( - torch.tensor(self.vae.config.latents_mean) - .view(1, self.vae.config.z_dim, 1, 1, 1) - .to(latents.device, latents.dtype) - ) - latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + + latents_mean = self._latents_mean.to(latents.device, latents.dtype) + latents_inv_std = self._latents_inv_std.to( latents.device, latents.dtype ) - latents = latents / latents_std + latents_mean + + latents = latents / latents_inv_std + latents_mean output = self.vae.decode(latents, return_dict=False)[0] return DiffusionOutput( @@ -1505,7 +1517,7 @@ def encode_chunk_inputs( self, state: DiffusionRequestState, new_idxs: list[int], - ) -> list[torch.Tensor]: + ) -> torch.Tensor: """Streaming V2V initial latents (StreamDiffusionV2-style). For each newly admitted chunk: VAE-encode the source frames at @@ -1514,6 +1526,7 @@ def encode_chunk_inputs( (default 0.8). The transformer then runs the full schedule starting at ``step_index=0`` on this partially-noised latent. """ + batch_size = len(new_idxs) noise_scale = float((state.sampling.extra_args or {}).get("noise_scale", 0.8)) chunk_frames = state.sampling.chunk_frames prompt = state.prompts[0] if state.prompts else None @@ -1528,7 +1541,7 @@ def encode_chunk_inputs( latents_mean = self._latents_mean.to(self.device) latents_inv_std = self._latents_inv_std.to(self.device) - out: list[torch.Tensor] = [] + controls: list[torch.Tensor] = [] for idx in new_idxs: start = idx * chunk_frames end = start + chunk_frames @@ -1545,13 +1558,29 @@ def encode_chunk_inputs( frames = torch.cat([frames, pad], dim=0) # [n, C, H, W] -> [1, C, n, H, W] - control = ( + controls.append( frames.permute(1, 0, 2, 3) .unsqueeze(0) .to(device=self.device, dtype=self.vae.dtype) ) - clean = retrieve_latents(self.vae.encode(control), sample_mode="argmax") - clean = ((clean.float() - latents_mean.to(clean.dtype)) * latents_inv_std.to(clean.dtype)).to(self.vae.dtype) - noise = torch.randn_like(clean) - out.append(noise * noise_scale + clean * (1.0 - noise_scale)) - return out + control = torch.cat(controls, dim=0) + + latents = self.prepare_latents( + batch_size=batch_size, + num_channels_latents=self.transformer_config.in_channels, + height=state.extra["height"], + width=state.extra["width"], + num_frames=chunk_frames, + dtype=self.vae.dtype, + device=self.device, + generator=state.sampling.generator, + latents=None, + ) + + latent_condition = retrieve_latents(self.vae.encode(control), sample_mode="argmax") + latent_condition = ( + (latent_condition.float() - latents_mean.to(latent_condition.dtype)) + * latents_inv_std.to(latent_condition.dtype) + ).to(self.vae.dtype) + + return latents * noise_scale + latent_condition * (1.0 - noise_scale) diff --git a/vllm_omni/diffusion/models/wan2_2/scheduling_wan_euler.py b/vllm_omni/diffusion/models/wan2_2/scheduling_wan_euler.py index 25444044c2d..0a7651113c7 100644 --- a/vllm_omni/diffusion/models/wan2_2/scheduling_wan_euler.py +++ b/vllm_omni/diffusion/models/wan2_2/scheduling_wan_euler.py @@ -96,6 +96,7 @@ def set_timesteps( self, num_inference_steps: int, device: torch.device | str | int | None = None, + sigma_start: float = 1.0, **kwargs, # noqa: ARG002 - kept for scheduler API compatibility ) -> None: timesteps = _get_timesteps( @@ -107,6 +108,12 @@ def set_timesteps( device=device or self.device, ) self.set_shift(self._shift) + # scale shifted sigmas so step 0 lands at sigma_start + if sigma_start != 1.0: + if not 0.0 < sigma_start <= 1.0: + raise ValueError(f"sigma_start must be in (0, 1], got {sigma_start}") + self.sigmas = self.sigmas * float(sigma_start) + self.timesteps = self.sigmas[:-1] * self.num_train_timesteps self._step_index = None self._begin_index = None diff --git a/vllm_omni/diffusion/sched/interface.py b/vllm_omni/diffusion/sched/interface.py index 0b1910567f6..78f69c5e211 100644 --- a/vllm_omni/diffusion/sched/interface.py +++ b/vllm_omni/diffusion/sched/interface.py @@ -108,27 +108,29 @@ def make_empty(cls) -> CachedRequestData: @dataclass -class RankTask: - """One unit of work for a rank in a stream-batch micro-step.""" - - sched_req_id: str - chunk_indices: list[int] - - -@dataclass -class Rank0Layout: - """How rank 0 should slice the [B_prev, ...] tensor it receives from last rank. +class Layout: + """How the previous latent should be sliced. - - head [0:n_finished] are chunks completing denoising (to decode) - - next [n_finished : n_finished+n_circulating] are re-admitted chunks - - rank 0 appends n_new fresh randn rows at the tail before forwarding. + - head [0:len(finished_idxs)] are chunks completing denoising (to decode) + - next [len(finished_idxs) : len(finished_idxs)+len(circulating_idxs)] are + re-admitted chunks + - rank 0 appends len(new_idxs) fresh randn rows at the tail before forwarding. """ - n_circulating: int + circulating_idxs: list[int] finished_idxs: list[int] new_idxs: list[int] +@dataclass +class RankTask: + """One unit of work for a rank in a stream-batch micro-step.""" + + sched_req_id: str + chunk_indices: list[int] + layout: Layout + + @dataclass class DiffusionSchedulerOutput: """Output of a single scheduling cycle.""" @@ -141,8 +143,7 @@ class DiffusionSchedulerOutput: num_waiting_reqs: int # stream-batch scheduling fields - assignment: list[RankTask | None] | None = None - rank0_layouts: dict[str, Rank0Layout] | None = None + assignment: list[RankTask] | None = None @cached_property def scheduled_req_ids(self) -> list[str]: diff --git a/vllm_omni/diffusion/sched/stream_batch_scheduler.py b/vllm_omni/diffusion/sched/stream_batch_scheduler.py index fc68b3ef3a8..6c27ca39065 100644 --- a/vllm_omni/diffusion/sched/stream_batch_scheduler.py +++ b/vllm_omni/diffusion/sched/stream_batch_scheduler.py @@ -2,7 +2,7 @@ Each ``schedule()`` call corresponds to one micro-step. The pipeline is modeled as ``pp_size`` per-rank chunk queues plus a transient ``returning`` queue. -At each schedule(), chunks at rank N-1 drain (finished -> Rank0Layout finished +At each schedule(), chunks at rank N-1 drain (finished -> Layout finished slice, otherwise -> returning), queues shift one rank, and rank 0 receives the returning chunks plus B fresh admits drawn from the source video frames in ``prompts[0]["multi_modal_data"]["video"]``. @@ -23,7 +23,7 @@ from vllm_omni.diffusion.sched.interface import ( DiffusionRequestStatus, DiffusionSchedulerOutput, - Rank0Layout, + Layout, RankTask, ) @@ -72,7 +72,9 @@ class _Progress: batch_size: int = 0 # chunks that will be processed by rank r at the current micro-step - chunks_at: list[deque[_InFlightChunk]] = field(default_factory=list) + chunks_at: list[deque[_InFlightChunk]] = field(default_factory=list) + # rank r's layout — constructed at rank 0 and shifted forward each step + layouts_at: list[Layout] = field(default_factory=list) @property def output_chunks_target(self) -> int: @@ -144,12 +146,12 @@ class StreamBatchScheduler(_BaseScheduler): Per micro-step: 1. Promote waiting requests (handled by the base class). 2. Drain rank N-1: finished chunks -> finished slice in - Rank0Layout, otherwise -> returning queue. + Layout, otherwise -> returning queue. 3. Shift per-rank queues by one (rank r <- rank r-1). 4. Rank 0 = returning + B fresh admits, where `B = min(B_target, queue_chunks_available, output_chunks_remaining)`. - 5. Emit per-rank assignment and the per-request Rank0Layout. Flip req state - RUNNING -> BLOCKED when admission is starved on input. + 5. Emit per-rank assignment with Layout attached to every RankTask. Flip + req state RUNNING -> BLOCKED when admission is starved on input. """ def __init__(self) -> None: @@ -198,13 +200,16 @@ def schedule(self) -> DiffusionSchedulerOutput: for new_req in base_output.scheduled_new_reqs: self._init_progress(new_req.sched_req_id, new_req.req) - rank0_layouts: dict[str, Rank0Layout] = {} for progress in self._progress.values(): - rank0_layouts[progress.sched_req_id] = self._advance_chunk_pipeline(progress) + self._advance_chunk_pipeline(progress) if self._progress: base_output.assignment = self._build_assignment() - base_output.rank0_layouts = rank0_layouts + + logger.info( + "StreamBatchScheduler schedule: %d running req(s), assignment=%s", + len(self._running), base_output.assignment, + ) return base_output @@ -221,6 +226,10 @@ def _init_progress(self, sched_req_id: str, req: OmniDiffusionRequest) -> None: num_steps=num_steps, pp_size=self.pp_size, chunks_at=[deque() for _ in range(self.pp_size)], + layouts_at=[ + Layout(circulating_idxs=[], finished_idxs=[], new_idxs=[]) + for _ in range(self.pp_size) + ], ) self._slo.register( @@ -236,14 +245,14 @@ def _init_progress(self, sched_req_id: str, req: OmniDiffusionRequest) -> None: sched_req_id, chunk_frames, num_frames, num_steps, sampling.slo_fps, self.pp_size, ) - def _advance_chunk_pipeline(self, progress: _Progress) -> Rank0Layout: - """Advance the per-rank queues by one micro-step and return rank 0's layout.""" + def _advance_chunk_pipeline(self, progress: _Progress) -> None: + """Advance the per-rank queues and layouts by one micro-step.""" pp = progress.pp_size # 1. Drain last rank from previous step finished_idxs: list[int] = [] - circulating = [] + circulating: list[_InFlightChunk] = [] last = progress.chunks_at[pp - 1] while last: chunk = last.popleft() @@ -253,9 +262,10 @@ def _advance_chunk_pipeline(self, progress: _Progress) -> Rank0Layout: else: circulating.append(chunk) - # 2. Shift: rank r receives what rank r-1 had + # 2. Shift chunks and layouts: rank r receives what rank r-1 had for r in range(pp - 1, 0, -1): progress.chunks_at[r] = progress.chunks_at[r - 1] + progress.layouts_at[r] = progress.layouts_at[r - 1] progress.chunks_at[0] = deque() # 3. Rank 0 = circulating + B fresh admits @@ -278,7 +288,14 @@ def _advance_chunk_pipeline(self, progress: _Progress) -> Rank0Layout: new_idxs.append(chunk_idx) progress.batch_size = batch_size - # 4. Flip RUNNING -> BLOCKED if input-starved and we still owe output. + # 4. Set rank 0's layout for this step. + progress.layouts_at[0] = Layout( + circulating_idxs=[c.chunk_idx for c in circulating], + finished_idxs=finished_idxs, + new_idxs=new_idxs, + ) + + # 5. Flip RUNNING -> BLOCKED if input-starved and we still owe output. if ( batch_size == 0 and output_chunks_remaining > 0 @@ -292,24 +309,17 @@ def _advance_chunk_pipeline(self, progress: _Progress) -> Rank0Layout: progress.sched_req_id, progress.frames_committed, progress.num_frames, available_frames, ) - return Rank0Layout( - n_circulating=len(circulating), - finished_idxs=finished_idxs, - new_idxs=new_idxs, - ) - - def _build_assignment(self) -> list[RankTask | None]: + def _build_assignment(self) -> list[RankTask]: assert len(self._progress) <= 1 #TODO: support multiple requests - assignment: list[RankTask | None] = [None] * self.pp_size + assignment: list[RankTask] = [] for progress in self._progress.values(): for r in range(self.pp_size): queue = progress.chunks_at[r] - if not queue: - continue - assignment[r] = RankTask( + assignment.append(RankTask( sched_req_id=progress.sched_req_id, chunk_indices=[c.chunk_idx for c in queue], - ) + layout=progress.layouts_at[r], + )) return assignment # ── Output processing ────────────────────────────────────────────────── diff --git a/vllm_omni/diffusion/worker/diffusion_model_runner.py b/vllm_omni/diffusion/worker/diffusion_model_runner.py index 256800add06..42ee51f91fc 100644 --- a/vllm_omni/diffusion/worker/diffusion_model_runner.py +++ b/vllm_omni/diffusion/worker/diffusion_model_runner.py @@ -33,7 +33,7 @@ from vllm_omni.diffusion.registry import _NO_CACHE_ACCELERATION from vllm_omni.diffusion.request import OmniDiffusionRequest from vllm_omni.diffusion.distributed.parallel_state import get_pp_group -from vllm_omni.diffusion.sched.interface import DiffusionSchedulerOutput +from vllm_omni.diffusion.sched.interface import DiffusionSchedulerOutput, Layout from vllm_omni.diffusion.worker.input_batch import InputBatch, scatter_latents from vllm_omni.diffusion.worker.utils import BatchRunnerOutput, DiffusionRequestState, RunnerOutput, ChunkState from vllm_omni.distributed.omni_connectors.kv_transfer_manager import OmniKVTransferManager @@ -159,7 +159,7 @@ def get_memory_context(): raise ValueError( "stream_batch=True requires a pipeline implementing the micro-step " "execution protocol (prepare_encode, set_pp_recv_dict_buffers, " - "denoise_step, prefetch_its, step_scheduler, post_decode, encode_chunk_inputs); " + "denoise_step, prefetch_tensors, step_scheduler, post_decode, encode_chunk_inputs); " f"{self.od_config.model_class_name} does not support that contract." ) @@ -538,87 +538,63 @@ def execute_micro_step(self, sched_output: DiffusionSchedulerOutput) -> RunnerOu with set_forward_context(vllm_config=self.vllm_config, omni_diffusion_config=self.od_config): pp_group = get_pp_group() task = assignment[pp_group.rank_in_group] - chunk_idxs = list(task.chunk_indices) if task else [] + chunk_idxs = list(task.chunk_indices) + layout = task.layout if is_new_request: pp_group.reset_buffer() - self.pipeline.is_buffer_setup = False self.pipeline.prepare_encode(state) - state.extra["initial_latent_template"] = state.latents - - state.batched_timesteps = None + self.pipeline.set_pp_recv_dict_buffers(state) t_start_ns = time.perf_counter_ns() if pp_group.is_first_rank else None result: DiffusionOutput | None = None finished = False if pp_group.is_first_rank: - result, finished = self._rank0_assemble_input(state, sched_output) - - if not chunk_idxs: - return RunnerOutput( - req_id=state.req_id, - finished=finished, - result=result, - micro_step_wall_ns=( - time.perf_counter_ns() - t_start_ns if t_start_ns is not None else None - ), - ) + result = self._update_decoded_chunks(state, layout) + finished = result is not None - template = state.extra["initial_latent_template"] - chunks: list[ChunkState] = [ - self._get_or_create_chunk(state, idx)[0] for idx in chunk_idxs - ] + if pp_group.is_first_rank or pp_group.is_last_rank: + self._prepare_chunk_latents(state, layout, is_first_rank=pp_group.is_first_rank) - if pp_group.is_last_rank: - if pp_group.world_size == 1: - for i, c in enumerate(chunks): - c.latents = state.latents[i:i + 1] - else: - for c in chunks: - if c.latents is None: - c.latents = ( - template - if c.idx == 0 - else torch.randn_like(template, generator=state.sampling.generator) - ) - state.latents = torch.cat([c.latents for c in chunks], dim=0) - elif not pp_group.is_first_rank: - state.latents = template.expand(len(chunks), *template.shape[1:]) - - if not self.pipeline.is_buffer_setup: - self.pipeline.set_pp_recv_dict_buffers(state) + if chunk_idxs: + chunks: list[ChunkState] = [ + self._get_or_create_chunk(state, idx)[0] for idx in chunk_idxs + ] - # Per-row timesteps - state.batched_timesteps = torch.stack( - [state.scheduler.timesteps[c.step_index] for c in chunks] - ) + # Per-row timesteps + state.batched_timesteps = torch.stack( + [state.timesteps[c.step_index] for c in chunks] + ) - batch_size = len(chunks) - noise_pred = self.pipeline.denoise_step(state, batch_size=batch_size) + batch_size = len(chunks) + noise_pred = self.pipeline.denoise_step(state, batch_size=batch_size) - if noise_pred is None and getattr(self.pipeline, "interrupt", False): - self.state_cache.pop(state.req_id, None) - return RunnerOutput( - req_id=state.req_id, - finished=True, - result=DiffusionOutput(error="micro-step denoise interrupted"), - ) + if noise_pred is None and getattr(self.pipeline, "interrupt", False): + self._update_state_after(state, layout, finished=True) + return RunnerOutput( + req_id=state.req_id, + finished=True, + result=DiffusionOutput(error="micro-step denoise interrupted"), + ) - self.pipeline.prefetch_its(state, batch_size=batch_size) + schedulers = [c.scheduler for c in chunks] + self.pipeline.step_scheduler( + state, noise_pred, per_request_scheduler=schedulers, batch_size=batch_size, + ) - schedulers = [c.scheduler for c in chunks] - self.pipeline.step_scheduler( - state, noise_pred, per_request_scheduler=schedulers, batch_size=batch_size, - ) + for c in chunks: + c.step_index += 1 - if pp_group.is_last_rank: - for i, c in enumerate(chunks): - c.latents = state.latents[i:i + 1] + if pp_group.is_last_rank: + for i, c in enumerate(chunks): + c.latents = state.latents[i : i + 1] - for c in chunks: - c.step_index += 1 + prev_task = assignment[pp_group.group_prev_rank] + if prev_task.chunk_indices: + self.pipeline.prefetch_tensors(state, batch_size=len(prev_task.chunk_indices)) + self._update_state_after(state, layout, finished=finished) return RunnerOutput( req_id=state.req_id, finished=finished, @@ -639,52 +615,49 @@ def _get_or_create_chunk(state: DiffusionRequestState, chunk_idx: int) -> tuple[ chunks[chunk_idx] = chunk return chunk, True - def _rank0_assemble_input( - self, state: DiffusionRequestState, scheduler_output: DiffusionSchedulerOutput, - ) -> tuple[DiffusionOutput | None, bool]: - """Build rank 0's batched forward input.""" - - layouts = scheduler_output.rank0_layouts - layout = layouts.get(state.req_id) if layouts else None - if layout is None: - return None, False - - prev_latents = state.latents + def _prepare_chunk_latents(self, state: DiffusionRequestState, layout: Layout, is_first_rank: bool): pieces: list[torch.Tensor] = [] n_finished = len(layout.finished_idxs) - if n_finished > 0 and prev_latents is not None: - saved = state.latents - state.latents = prev_latents[: n_finished] - decoded = self.pipeline.post_decode(state) - state.latents = saved - state.extra.setdefault("decoded_chunks", []).append(decoded) - state.extra["num_chunks_decoded"] = ( - state.extra.get("num_chunks_decoded", 0) + n_finished - ) - for idx in layout.finished_idxs: - state.extra.get("chunks", {}).pop(idx, None) - if layout.n_circulating > 0 and prev_latents is not None: - pieces.append( - prev_latents[n_finished : n_finished + layout.n_circulating] - ) + for i, idx in enumerate(layout.circulating_idxs): + chunk, _ = self._get_or_create_chunk(state, idx) + if is_first_rank: + chunk.latents = state.latents[n_finished + i : n_finished + i + 1] + pieces.append(chunk.latents) if layout.new_idxs: encoded = self.pipeline.encode_chunk_inputs(state, layout.new_idxs) - for latent, idx in zip(encoded, layout.new_idxs): + for i, idx in enumerate(layout.new_idxs): chunk, _ = self._get_or_create_chunk(state, idx) - chunk.latents = latent - pieces.append(latent) + chunk.latents = encoded[i : i + 1] + pieces.append(encoded) state.latents = torch.cat(pieces, dim=0) if pieces else None + def _update_decoded_chunks(self, state: DiffusionRequestState, layout: Layout) -> DiffusionOutput | None: + n_finished = len(layout.finished_idxs) + if n_finished > 0: + decoded = self.pipeline.post_decode(state.latents[: n_finished]) + state.extra.setdefault("decoded_chunks", []).append(decoded) + state.extra["num_chunks_decoded"] = ( + state.extra.get("num_chunks_decoded", 0) + n_finished + ) + output_chunks_target = state.sampling.num_frames // state.sampling.chunk_frames + if state.extra.get("num_chunks_decoded", 0) >= output_chunks_target: + return self._merge_chunk_outputs(state.extra["decoded_chunks"]) + + return None + + def _update_state_after(self, state: DiffusionRequestState, layout: Layout, finished: bool = False): + for idx in layout.finished_idxs: + state.extra.get("chunks", {}).pop(idx, None) + + if finished: self.state_cache.pop(state.req_id, None) - return self._merge_chunk_outputs(state.extra["decoded_chunks"]), True - return None, False - + @staticmethod def _merge_chunk_outputs(chunks: list[DiffusionOutput]) -> DiffusionOutput: """Merge decoded chunk outputs into a single video tensor. From 0109533f18d64f2ea5a6059951f53f4a32e58953 Mon Sep 17 00:00:00 2001 From: Miguel Vieira Pereira Date: Tue, 26 May 2026 09:36:36 +0000 Subject: [PATCH 46/53] Implement tests for Lingbot World Fast Signed-off-by: Miguel Vieira Pereira --- .../lingbot_world_fast/end2end.py | 2 +- .../lingbot_world_fast/openai_client.py | 30 +- .../models/lingbot_world_fast/__init__.py | 2 + .../models/lingbot_world_fast/conftest.py | 173 +++++++++ .../test_protocol_validation.py | 363 ++++++++++++++++++ .../lingbot_world_fast/test_schedule.py | 106 +++++ .../lingbot_world_fast/test_session_state.py | 219 +++++++++++ .../test_lingbot_world_fast.py | 225 +++++++++++ .../test_lingbot_world_fast_expansion.py | 153 ++++++++ .../online_serving/test_lingbot_world_fast.py | 252 ++++++++++++ .../test_lingbot_world_fast_expansion.py | 344 +++++++++++++++++ .../openai_api/test_realtime_world_camera.py | 84 ++++ tests/helpers/lingbot_world_fast.py | 243 ++++++++++++ tests/worker/test_omni_connector_mixin.py | 2 +- .../pipeline_lingbot_world_fast.py | 15 +- .../diffusion/models/lingbot_world_fast/t5.py | 2 +- .../realtime/world/camera_connection.py | 19 +- .../openai/realtime/world/camera_serving.py | 43 ++- 18 files changed, 2222 insertions(+), 55 deletions(-) create mode 100644 tests/diffusion/models/lingbot_world_fast/__init__.py create mode 100644 tests/diffusion/models/lingbot_world_fast/conftest.py create mode 100644 tests/diffusion/models/lingbot_world_fast/test_protocol_validation.py create mode 100644 tests/diffusion/models/lingbot_world_fast/test_schedule.py create mode 100644 tests/diffusion/models/lingbot_world_fast/test_session_state.py create mode 100644 tests/e2e/offline_inference/test_lingbot_world_fast.py create mode 100644 tests/e2e/offline_inference/test_lingbot_world_fast_expansion.py create mode 100644 tests/e2e/online_serving/test_lingbot_world_fast.py create mode 100644 tests/e2e/online_serving/test_lingbot_world_fast_expansion.py create mode 100644 tests/entrypoints/openai_api/test_realtime_world_camera.py create mode 100644 tests/helpers/lingbot_world_fast.py diff --git a/examples/offline_inference/lingbot_world_fast/end2end.py b/examples/offline_inference/lingbot_world_fast/end2end.py index 565cd99a285..452fd10b349 100644 --- a/examples/offline_inference/lingbot_world_fast/end2end.py +++ b/examples/offline_inference/lingbot_world_fast/end2end.py @@ -109,7 +109,7 @@ def main(): print(f" Model: {args.model}") print(f" Inference steps: {num_inference_steps}") print(f" Frames: {args.num_frames}") - print(f" Video size: {args.width}x{args.height}") + print(f" Video size: {width}x{height}") print(f"{'=' * 60}\n") generation_start = time.perf_counter() diff --git a/examples/online_serving/lingbot_world_fast/openai_client.py b/examples/online_serving/lingbot_world_fast/openai_client.py index 9731308bdb0..e7123af2b59 100644 --- a/examples/online_serving/lingbot_world_fast/openai_client.py +++ b/examples/online_serving/lingbot_world_fast/openai_client.py @@ -6,11 +6,11 @@ ``vllm serve --omni`` when the loaded pipeline is ``LingbotWorldFastPipeline``. The endpoint speaks the OpenPI policy protocol on the wire: - 1. Connect -> server sends msgpack(PolicyServerConfig) - 2. Client send msgpack(obs) + 1. Connect -> server sends msgpack(CameraServerConfig) + 2. Client send msgpack(request) 3. Server send msgpack(ndarray) # generated frames -The ``obs`` payload sent here contains: +The ``request`` payload sent here contains: - "image": numpy array, the input image - "prompt": str, the text prompt describing the desired motion - "camera": {"poses": ndarray, "intrinsics": ndarray} @@ -58,8 +58,8 @@ def _load_camera(camera_dir: str) -> dict: return {"poses": poses, "intrinsics": intrinsics} -def generate_video(args: Namespace) -> np.ndarray: - """Send a single inference request and return the generated frames.""" +def generate_video(args: Namespace) -> list[np.ndarray]: + """Send inference requests and return the generated frames.""" image = _load_image(args.image) full_camera = _load_camera(args.camera_path) @@ -80,16 +80,11 @@ def generate_video(args: Namespace) -> np.ndarray: "intrinsics": full_camera["intrinsics"][starting_frame : starting_frame + args.num_frames], } + request: dict = {"prompt": args.prompt, "camera": camera, "extra_body": extra_body} if i == 0: - extra_body["num_frames"] = (args.num_frames // 4) * 4 + 1 - else: - extra_body["num_frames"] = (args.num_frames // 4) * 4 + request["image"] = image - obs: dict = {"prompt": args.prompt, "camera": camera, "extra_body": extra_body} - if i == 0: - obs["image"] = image - - obs["session_id"] = args.session_id + request["session_id"] = args.session_id endpoint = f"{args.server.rstrip('/')}/v1/realtime/world/camera" print(f"Connecting to {endpoint} ...") @@ -98,12 +93,12 @@ def generate_video(args: Namespace) -> np.ndarray: # 1. Server sends CameraServerConfig on connect. _unpack(ws.recv()) - # 2. Send obs. + # 2. Send request. print( - f"Sending obs image= ({str(image.shape) if obs.get('image', None) is not None else 'None'}, " + f"Sending request image= ({str(image.shape) if request.get('image', None) is not None else 'None'}, " f"poses={camera['poses'].shape}, intrinsics={camera['intrinsics'].shape})..." ) - ws.send(_pack(obs)) + ws.send(_pack(request)) # 3. Receive generated frames. chunks: list[np.ndarray] = [] @@ -121,7 +116,7 @@ def generate_video(args: Namespace) -> np.ndarray: clip = np.concatenate(chunks, axis=0) # The first chunk of frames returned was used to condition the video continuation but they are not useful if i != 0: - clip = clip[3:] + clip = clip[args.num_skip_frames :] for frame in clip: video.append(frame) @@ -163,6 +158,7 @@ def main(): parser.add_argument("--fps", type=int, default=16) parser.add_argument("--num-frames", type=int, default=24) parser.add_argument("--num-calls", type=int, default=2) + parser.add_argument("--num-skip-frames", type=int, default=4) args = parser.parse_args() frames = generate_video(args) diff --git a/tests/diffusion/models/lingbot_world_fast/__init__.py b/tests/diffusion/models/lingbot_world_fast/__init__.py new file mode 100644 index 00000000000..208f01a7cb5 --- /dev/null +++ b/tests/diffusion/models/lingbot_world_fast/__init__.py @@ -0,0 +1,2 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project diff --git a/tests/diffusion/models/lingbot_world_fast/conftest.py b/tests/diffusion/models/lingbot_world_fast/conftest.py new file mode 100644 index 00000000000..0e130d2e678 --- /dev/null +++ b/tests/diffusion/models/lingbot_world_fast/conftest.py @@ -0,0 +1,173 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""Shared stubs and dummy-input helpers for Lingbot World Fast L1 tests. + +The real pipeline pulls in T5-XXL, the Wan VAE and a 5B-parameter transformer +on construction. Tests exercise only the state container, msgpack protocol and +scheduler, so these stubs replace the heavy dependencies with the smallest +implementations that match the call sites in +``vllm_omni/diffusion/models/lingbot_world_fast/pipeline_lingbot_world_fast.py``. +""" + +from __future__ import annotations + +from types import SimpleNamespace +from typing import TYPE_CHECKING + +import numpy as np +import torch +from PIL import Image +from torch import nn + +if TYPE_CHECKING: + from vllm_omni.diffusion.models.lingbot_world_fast.pipeline_lingbot_world_fast import ( + LingbotWorldFastPipeline, + ) + + +class StubT5Encoder: + """Minimal stand-in for ``T5EncoderModel``. + + The pipeline calls ``self.text_encoder([prompt], device)`` and expects a + list of token-embedding tensors, one per prompt. + """ + + def __init__(self, text_len: int = 512, dim: int = 32, dtype: torch.dtype = torch.float32) -> None: + self.text_len = text_len + self.dim = dim + self.dtype = dtype + + def __call__(self, prompts: list[str], device: torch.device) -> list[torch.Tensor]: + return [torch.zeros(self.text_len, self.dim, dtype=self.dtype, device=device) for _ in prompts] + + +class StubVAE: + """Stand-in for ``Wan2_1_VAE``. + + ``encode([pixels])`` returns a list with one latent tensor shaped + ``[16, F_lat, lat_h, lat_w]`` where ``F_lat = (F + 3) // 4`` so the + pipeline's masking / slicing math is exercised normally. + ``decode([latents])`` returns the latents unchanged (caller indexes [0]). + """ + + vae_stride = (4, 8, 8) + + def encode(self, video_list: list[torch.Tensor]) -> list[torch.Tensor]: + out: list[torch.Tensor] = [] + for v in video_list: + # v: [C, F, H, W] + _, f, h, w = v.shape + lat_f = (f + self.vae_stride[0] - 1) // self.vae_stride[0] + lat_h = h // self.vae_stride[1] + lat_w = w // self.vae_stride[2] + out.append(torch.zeros(16, lat_f, lat_h, lat_w, dtype=v.dtype, device=v.device)) + return out + + def decode(self, latents_list: list[torch.Tensor]) -> list[torch.Tensor]: + out: list[torch.Tensor] = [] + for latents in latents_list: + # latents: [16, F_lat, lat_h, lat_w]; produce pixels at the inverse stride. + _, f_lat, lat_h, lat_w = latents.shape + f = f_lat * self.vae_stride[0] + h = lat_h * self.vae_stride[1] + w = lat_w * self.vae_stride[2] + out.append(torch.zeros(3, f, h, w, dtype=latents.dtype, device=latents.device)) + return out + + +class StubWanModelFast(nn.Module): + """Stand-in for ``WanModelFast``. + + Returns zeros shaped like the input latent, and bumps the local/global + index tensors so chunk-boundary arithmetic is exercised by the pipeline. + """ + + def __init__(self, *, dim: int = 16, num_heads: int = 4, num_layers: int = 2) -> None: + super().__init__() + self.config = SimpleNamespace( + dim=dim, + num_heads=num_heads, + num_layers=num_layers, + local_attn_size=-1, + ) + + def forward(self, *, x, t, **kwargs): # noqa: ARG002 — matches pipeline call site + del t, kwargs + return [torch.zeros_like(x[0])] + + @classmethod + def from_pretrained(cls, *args, **kwargs): # noqa: ARG003 + return cls() + + +def make_dummy_camera_inputs(num_frames: int) -> dict[str, np.ndarray]: + """Camera payload matching the shape the pipeline expects.""" + intrinsics = np.eye(4, dtype=np.float32) + poses = np.tile(np.eye(4, dtype=np.float32), (num_frames, 1, 1)) + return {"intrinsics": intrinsics, "poses": poses} + + +def make_dummy_image(width: int = 64, height: int = 64) -> Image.Image: + return Image.new("RGB", (width, height), color=(128, 128, 128)) + + +def make_stubbed_pipeline( + *, + device: torch.device | None = None, + dim: int = 16, + num_heads: int = 4, + num_layers: int = 2, + target_dtype: torch.dtype = torch.float32, +) -> LingbotWorldFastPipeline: + """Build a ``LingbotWorldFastPipeline`` backed by the conftest stubs. + + Skips the real ``__init__`` (which loads umt5-xxl, Wan VAE and a 5B + transformer) via ``object.__new__`` and assigns the stubs directly, + mirroring ``_make_i2v_pipeline`` in ``tests/diffusion/models/wan2_2``. + The returned pipeline is suitable for driving ``.forward(req)`` end-to-end + against ``LingbotWorldFastState`` without touching real weights. + """ + from vllm_omni.diffusion.models.lingbot_world_fast.fm_solvers_unipc import ( + FlowUniPCMultistepScheduler, + ) + from vllm_omni.diffusion.models.lingbot_world_fast.pipeline_lingbot_world_fast import ( + CONFIG, + LingbotWorldFastPipeline, + ) + from vllm_omni.diffusion.models.lingbot_world_fast.state_lingbot_world_fast import ( + LingbotWorldFastState, + ) + + if device is None: + device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu", 0) + + parallel_config = SimpleNamespace(world_size=1) + od_config = SimpleNamespace( + model="stub/Lingbot-World-Fast", + parallel_config=parallel_config, + dtype=target_dtype, + ) + + pipeline = object.__new__(LingbotWorldFastPipeline) + nn.Module.__init__(pipeline) + pipeline.od_config = od_config + pipeline.parallel_config = parallel_config + pipeline.device = device + pipeline.target_dtype = target_dtype + pipeline.control_type = "cam" + pipeline.num_train_timesteps = CONFIG["num_train_timesteps"] + pipeline.sp_size = parallel_config.world_size + pipeline.state = LingbotWorldFastState() + pipeline.text_encoder = StubT5Encoder(dim=dim, dtype=target_dtype) + pipeline.vae = StubVAE() + pipeline.vae_stride = CONFIG["vae_stride"] + pipeline.patch_size = CONFIG["patch_size"] + pipeline.model = StubWanModelFast(dim=dim, num_heads=num_heads, num_layers=num_layers).to(device) + pipeline.scheduler = FlowUniPCMultistepScheduler( + num_train_timesteps=CONFIG["num_train_timesteps"], + shift=1, + use_dynamic_shifting=False, + ) + pipeline.sample_neg_prompt = CONFIG["negative_prompt_sample"] + return pipeline diff --git a/tests/diffusion/models/lingbot_world_fast/test_protocol_validation.py b/tests/diffusion/models/lingbot_world_fast/test_protocol_validation.py new file mode 100644 index 00000000000..902a4afe711 --- /dev/null +++ b/tests/diffusion/models/lingbot_world_fast/test_protocol_validation.py @@ -0,0 +1,363 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""L1 protocol-validation tests for ``/v1/realtime/world/camera``.""" + +from __future__ import annotations + +import asyncio +import contextlib +from collections.abc import Iterable +from typing import Any + +import numpy as np +import pytest + +from tests.diffusion.models.lingbot_world_fast.conftest import make_dummy_camera_inputs +from vllm_omni.entrypoints.openai.realtime.world.camera_connection import CHUNK_FRAMES, WorldCameraRealtimeConnection +from vllm_omni.entrypoints.openai.realtime.world.camera_serving import ( + CameraServerConfig, + ServingRealtimeWorldCamera, +) + +# The endpoint's wire codec is provided by the optional openpi-client dep. +msgpack_numpy = pytest.importorskip("openpi_client.msgpack_numpy") + +pytestmark = [pytest.mark.core_model, pytest.mark.cpu, pytest.mark.diffusion] + + +# --------------------------------------------------------------------------- +# Mock infrastructure +# --------------------------------------------------------------------------- + + +class MockWebSocket: + def __init__(self, incoming: Iterable[dict[str, Any]] | None = None) -> None: + self._incoming: list[dict[str, Any]] = list(incoming or []) + self._idx = 0 + self.sent_bytes: list[bytes] = [] + self.sent_text: list[str] = [] + self.accepted = False + self.closed = False + + async def accept(self) -> None: + self.accepted = True + + async def receive(self) -> dict[str, Any]: + if self._idx >= len(self._incoming): + return {"type": "websocket.disconnect"} + msg = self._incoming[self._idx] + self._idx += 1 + return msg + + async def send_bytes(self, data: bytes) -> None: + self.sent_bytes.append(data) + + async def send_text(self, data: str) -> None: + self.sent_text.append(data) + + async def close(self) -> None: + self.closed = True + + +def _bytes_frame(payload: Any) -> dict[str, Any]: + return {"type": "websocket.receive", "bytes": msgpack_numpy.packb(payload)} + + +def _raw_bytes_frame(data: bytes) -> dict[str, Any]: + return {"type": "websocket.receive", "bytes": data} + + +class _AsyncIter: + def __init__(self, items: list[Any]) -> None: + self._items = list(items) + + def __aiter__(self): + return self + + async def __anext__(self): + if not self._items: + raise StopAsyncIteration + return self._items.pop(0) + + +class FakeResult: + def __init__(self, frames: np.ndarray) -> None: + self.images = [frames] + + +class FakeEngineClient: + """Stand-in for ``AsyncOmni`` engine client. + + Captures the ``generate(...)`` arguments and yields a single fake result + so the connection's framing logic can be exercised without a real engine. + """ + + def __init__(self, frames: np.ndarray | None = None) -> None: + if frames is None: + # Default: CHUNK_FRAMES*(1+1/2) RGB frames so we exercise the chunk split. + frames = np.zeros((CHUNK_FRAMES * 3 // 2, 16, 16, 3), dtype=np.uint8) + self._frames = frames + self.calls: list[dict[str, Any]] = [] + self.fail_with: Exception | None = None + # Attributes consulted by ``CameraServerConfig.from_model_config``. + self.model_config = {"pipeline": "lingbot_world_fast", "resolution": [480, 832], "fps": 16} + + def generate(self, *, prompt, request_id, sampling_params_list): + self.calls.append( + { + "prompt": prompt, + "request_id": request_id, + "sampling_params_list": sampling_params_list, + } + ) + if self.fail_with is not None: + raise self.fail_with + return _AsyncIter([FakeResult(self._frames)]) + + +def _make_serving(engine_client: FakeEngineClient | None = None) -> ServingRealtimeWorldCamera: + return ServingRealtimeWorldCamera(engine_client=engine_client or FakeEngineClient(), model_name="lingbot") + + +# --------------------------------------------------------------------------- +# CameraServerConfig +# --------------------------------------------------------------------------- + + +def test_camera_server_config_round_trip_through_dict() -> None: + cfg = CameraServerConfig.from_model_config({"pipeline": "lingbot", "fps": 16}) + out = cfg.to_dict() + assert isinstance(out, dict) + assert out["pipeline"] == "lingbot" + assert out["fps"] == 16 + + +# --------------------------------------------------------------------------- +# msgpack-numpy round-trip +# --------------------------------------------------------------------------- + + +def test_msgpack_camera_payload_round_trip() -> None: + camera = make_dummy_camera_inputs(num_frames=8) + payload = { + "image": np.random.randint(0, 255, size=(8, 8, 3), dtype=np.uint8), + "prompt": "walk forward", + "camera": camera, + "session_id": "sess-1", + "extra_body": {"height": 240, "width": 416, "num_frames": 25, "fps": 16}, + } + packed = msgpack_numpy.packb(payload) + decoded = msgpack_numpy.unpackb(packed) + + assert decoded["prompt"] == "walk forward" + assert decoded["session_id"] == "sess-1" + assert decoded["extra_body"] == payload["extra_body"] + + image_out = decoded["image"] + assert isinstance(image_out, np.ndarray) + assert image_out.shape == payload["image"].shape + assert image_out.dtype == payload["image"].dtype + np.testing.assert_array_equal(image_out, payload["image"]) + + for key in ("intrinsics", "poses"): + arr_in = camera[key] + arr_out = decoded["camera"][key] + assert arr_out.dtype == arr_in.dtype + assert arr_out.shape == arr_in.shape + np.testing.assert_array_equal(arr_out, arr_in) + + +# --------------------------------------------------------------------------- +# Connection-level framing +# --------------------------------------------------------------------------- + + +def test_handshake_sends_camera_server_config_on_connect() -> None: + serving = _make_serving() + ws = MockWebSocket(incoming=[]) # client disconnects immediately after handshake + conn = WorldCameraRealtimeConnection(ws, serving) + asyncio.run(conn.handle_connection()) + + assert ws.accepted is True + assert len(ws.sent_bytes) == 1 + handshake = msgpack_numpy.unpackb(ws.sent_bytes[0]) + assert isinstance(handshake, dict) + assert handshake["pipeline"] == "lingbot_world_fast" + + +def test_invalid_msgpack_returns_error_frame_and_keeps_connection_open() -> None: + serving = _make_serving() + ws = MockWebSocket( + incoming=[ + _raw_bytes_frame(b"\x99not-msgpack"), # malformed + _bytes_frame({"endpoint": "reset"}), + ] + ) + conn = WorldCameraRealtimeConnection(ws, serving) + asyncio.run(conn.handle_connection()) + + # First sent message is the handshake, next is the error frame, then the + # "reset successful" text reply — proving the connection stayed open. + assert len(ws.sent_bytes) >= 2 + error = msgpack_numpy.unpackb(ws.sent_bytes[1]) + assert error == {"type": "error", "message": "Invalid request payload"} + assert ws.sent_text == ["reset successful"] + + +def test_non_dict_payload_is_rejected_with_error_frame() -> None: + serving = _make_serving() + ws = MockWebSocket(incoming=[_bytes_frame([1, 2, 3])]) # list, not dict + conn = WorldCameraRealtimeConnection(ws, serving) + asyncio.run(conn.handle_connection()) + + assert len(ws.sent_bytes) >= 2 + error = msgpack_numpy.unpackb(ws.sent_bytes[1]) + assert error["type"] == "error" + + +def test_reset_endpoint_clears_session_and_returns_text_ack() -> None: + engine_client = FakeEngineClient() + serving = ServingRealtimeWorldCamera(engine_client=engine_client, model_name="lingbot") + # Pre-populate as if a prior session were active. + serving._current_session_id = "session-a" + + ws = MockWebSocket(incoming=[_bytes_frame({"endpoint": "reset"})]) + conn = WorldCameraRealtimeConnection(ws, serving) + asyncio.run(conn.handle_connection()) + + assert serving._current_session_id is None + assert ws.sent_text == ["reset successful"] + + +def test_infer_frames_are_chunked() -> None: + num_frames = CHUNK_FRAMES * 3 // 2 + + frames = np.arange(num_frames * 4 * 4 * 3, dtype=np.uint8).reshape(num_frames, 4, 4, 3) + engine_client = FakeEngineClient(frames=frames) + serving = ServingRealtimeWorldCamera(engine_client=engine_client, model_name="lingbot") + + request = { + "prompt": "p", + "camera": make_dummy_camera_inputs(num_frames=6), + "session_id": "s1", + "extra_body": {"num_frames": 6, "height": 16, "width": 16, "fps": 16}, + "image": np.zeros((16, 16, 3), dtype=np.uint8), + } + ws = MockWebSocket(incoming=[_bytes_frame(request)]) + conn = WorldCameraRealtimeConnection(ws, serving) + asyncio.run(conn.handle_connection()) + + # Drop the handshake; the remaining sent_bytes are frame chunks. + chunks = [msgpack_numpy.unpackb(b) for b in ws.sent_bytes[1:]] + assert [c["type"] for c in chunks] == ["frame", "frame"] + assert [c["index"] for c in chunks] == [0, 1] + assert {c["total"] for c in chunks} == {2} + + assert len(chunks[0]["video"]) == CHUNK_FRAMES + assert len(chunks[1]["video"]) == num_frames - CHUNK_FRAMES + + for chunk in chunks: + assert chunk["video"][0].shape == (4, 4, 3) + + +def test_session_id_churn_flips_current_session_id() -> None: + engine_client = FakeEngineClient() + serving = ServingRealtimeWorldCamera(engine_client=engine_client, model_name="lingbot") + + base_obs = { + "prompt": "p", + "camera": make_dummy_camera_inputs(num_frames=4), + "image": np.zeros((16, 16, 3), dtype=np.uint8), + "extra_body": {"num_frames": 4, "height": 16, "width": 16, "fps": 16}, + } + obs_a = {**base_obs, "session_id": "session-a"} + obs_b = {**base_obs, "session_id": "session-b"} + + ws = MockWebSocket(incoming=[_bytes_frame(obs_a), _bytes_frame(obs_b)]) + conn = WorldCameraRealtimeConnection(ws, serving) + asyncio.run(conn.handle_connection()) + + assert serving._current_session_id == "session-b" + assert len(engine_client.calls) == 2 + # Each engine call observes the active session id via extra_args. + seen_session_ids = [call["sampling_params_list"][0].extra_args["session_id"] for call in engine_client.calls] + assert seen_session_ids == ["session-a", "session-b"] + + +def test_engine_failure_surfaces_as_error_frame_not_close() -> None: + engine_client = FakeEngineClient() + engine_client.fail_with = RuntimeError("kaboom") + serving = ServingRealtimeWorldCamera(engine_client=engine_client, model_name="lingbot") + + request = { + "prompt": "p", + "camera": make_dummy_camera_inputs(num_frames=4), + "session_id": "s1", + "image": np.zeros((16, 16, 3), dtype=np.uint8), + "extra_body": {"num_frames": 4, "height": 16, "width": 16, "fps": 16}, + } + ws = MockWebSocket(incoming=[_bytes_frame(request)]) + conn = WorldCameraRealtimeConnection(ws, serving) + asyncio.run(conn.handle_connection()) + + error = msgpack_numpy.unpackb(ws.sent_bytes[-1]) + assert error == {"type": "error", "message": "Internal inference error"} + + +# --------------------------------------------------------------------------- +# Required-field validation: a missing ``camera`` propagates to the pipeline +# layer's ValueError. At the serving layer we exercise this by giving the +# fake engine a side-effect that raises like the pipeline would, then assert +# the connection responds with an error frame and keeps running. +# --------------------------------------------------------------------------- + + +def test_missing_camera_surfaces_as_error_frame() -> None: + engine_client = FakeEngineClient() + # Pipeline's actual ValueError text — useful to keep this in sync. + engine_client.fail_with = ValueError("A path to camera positions must be passed to this model") + serving = ServingRealtimeWorldCamera(engine_client=engine_client, model_name="lingbot") + + request = { + "prompt": "p", + "session_id": "s1", + "image": np.zeros((16, 16, 3), dtype=np.uint8), + "extra_body": {"num_frames": 4, "height": 16, "width": 16, "fps": 16}, + } + ws = MockWebSocket(incoming=[_bytes_frame(request)]) + conn = WorldCameraRealtimeConnection(ws, serving) + asyncio.run(conn.handle_connection()) + + error = msgpack_numpy.unpackb(ws.sent_bytes[-1]) + assert error["type"] == "error" + + +# --------------------------------------------------------------------------- +# Dtype/rank guard: the wire codec preserves bit patterns, so a malformed +# camera entry (float64 vs float32, wrong rank) passes through the connection +# unchanged. The pipeline-layer assertions are exercised in the L2 offline +# test; here we just confirm the codec doesn't silently coerce. +# --------------------------------------------------------------------------- + + +def test_msgpack_does_not_silently_coerce_camera_dtypes() -> None: + payload = { + "intrinsics": np.eye(3, dtype=np.float64), # wrong dtype on purpose + "poses": np.tile(np.eye(4, dtype=np.float32), (2, 1, 1))[None], # extra leading dim + } + decoded = msgpack_numpy.unpackb(msgpack_numpy.packb(payload)) + assert decoded["intrinsics"].dtype == np.float64 + assert decoded["poses"].shape == (1, 2, 4, 4) + + +# --------------------------------------------------------------------------- +# Suppress event-loop teardown noise in some environments +# --------------------------------------------------------------------------- + + +@pytest.fixture(autouse=True) +def _silence_runtime_warnings(recwarn): # noqa: PT004 + yield + with contextlib.suppress(Exception): + recwarn.clear() diff --git a/tests/diffusion/models/lingbot_world_fast/test_schedule.py b/tests/diffusion/models/lingbot_world_fast/test_schedule.py new file mode 100644 index 00000000000..0adc1589c91 --- /dev/null +++ b/tests/diffusion/models/lingbot_world_fast/test_schedule.py @@ -0,0 +1,106 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""L1 tests for the Lingbot World Fast scheduler.""" + +from __future__ import annotations + +import math + +import pytest +import torch + +from vllm_omni.diffusion.models.lingbot_world_fast.fm_solvers_unipc import FlowUniPCMultistepScheduler +from vllm_omni.diffusion.models.lingbot_world_fast.pipeline_lingbot_world_fast import ( + CONFIG, + LingbotWorldFastPipeline, +) + +pytestmark = [pytest.mark.core_model, pytest.mark.cpu, pytest.mark.diffusion] + + +def _make_scheduler() -> FlowUniPCMultistepScheduler: + # Same construction as ``LingbotWorldFastPipeline.__init__``. + scheduler = FlowUniPCMultistepScheduler( + num_train_timesteps=CONFIG["num_train_timesteps"], + shift=1, + use_dynamic_shifting=False, + ) + scheduler.set_timesteps(CONFIG["num_train_timesteps"], shift=CONFIG["sample_shift"]) + return scheduler + + +def test_timesteps_index_selects_exactly_four_steps() -> None: + scheduler = _make_scheduler() + selected = scheduler.timesteps[CONFIG["timesteps_index"]] + + assert selected.shape == (4,) + # Monotonically decreasing — flow matching schedulers walk t from high to low. + diffs = selected[1:] - selected[:-1] + assert torch.all(diffs < 0), f"timesteps must be strictly decreasing, got {selected.tolist()}" + + +def test_timesteps_full_schedule_length_matches_num_train_timesteps() -> None: + scheduler = _make_scheduler() + assert scheduler.num_inference_steps == CONFIG["num_train_timesteps"] + assert scheduler.timesteps.shape == (CONFIG["num_train_timesteps"],) + + +def _convert(flow_pred: torch.Tensor, xt: torch.Tensor, timestep: torch.Tensor, scheduler) -> torch.Tensor: + # Bind ``_convert_flow_pred_to_x0`` as an unbound method to avoid + # constructing the full pipeline (which loads T5 / VAE). + return LingbotWorldFastPipeline._convert_flow_pred_to_x0( + None, # type: ignore[arg-type] + flow_pred=flow_pred, + xt=xt, + timestep=timestep, + scheduler=scheduler, + ) + + +def test_convert_flow_pred_to_x0_passthrough_when_pred_is_zero() -> None: + scheduler = _make_scheduler() + xt = torch.randn(1, 4, 1, 4, 4, dtype=torch.float32) + timestep = scheduler.timesteps[0] + flow_pred = torch.zeros_like(xt) + + x0 = _convert(flow_pred, xt, timestep, scheduler) + assert torch.allclose(x0, xt, atol=1e-6) + + +def test_convert_flow_pred_to_x0_recovers_x0_from_synthesized_pair() -> None: + scheduler = _make_scheduler() + timestep = scheduler.timesteps[CONFIG["timesteps_index"][1]] + + sigmas = scheduler.sigmas + timesteps = scheduler.timesteps + timestep_id = torch.argmin((timesteps - timestep).abs()) + sigma_t = sigmas[timestep_id].item() + + x0 = torch.randn(1, 4, 1, 4, 4, dtype=torch.float32) + noise = torch.randn_like(x0) + flow_pred = noise - x0 + xt = (1.0 - sigma_t) * x0 + sigma_t * noise + + recovered = _convert(flow_pred, xt, timestep, scheduler) + + # The function does the math in float64 internally, casts back to the + # input dtype. float32 inputs ⇒ ~1e-5 absolute tolerance is plenty. + assert torch.allclose(recovered, x0, atol=1e-4) + + +def test_timesteps_index_is_within_schedule_bounds() -> None: + """Defensive guard: an out-of-range index would silently wrap.""" + assert isinstance(CONFIG["timesteps_index"], list) + assert len(CONFIG["timesteps_index"]) == 4 + for idx in CONFIG["timesteps_index"]: + assert 0 <= idx < CONFIG["num_train_timesteps"] + + +def test_sample_shift_constant_is_positive() -> None: + """``sample_shift`` controls the timestep curve; a non-positive value + would corrupt the flow-matching trajectory.""" + assert CONFIG["sample_shift"] > 0 + # Reasonable upper bound — Wan models use shift ~5–10. + assert math.isfinite(CONFIG["sample_shift"]) + assert CONFIG["sample_shift"] <= 100 diff --git a/tests/diffusion/models/lingbot_world_fast/test_session_state.py b/tests/diffusion/models/lingbot_world_fast/test_session_state.py new file mode 100644 index 00000000000..e10298332b4 --- /dev/null +++ b/tests/diffusion/models/lingbot_world_fast/test_session_state.py @@ -0,0 +1,219 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""L1 unit tests for ``LingbotWorldFastState``. + +The state container is the load-bearing structure for chunk-streamed +generation: it owns the KV cache, the cross-attention cache, the +``current_lat_f`` cursor used to derive ``current_start`` RoPE offsets, +and the session-id that decides between fresh vs extension semantics. +""" + +from __future__ import annotations + +import pytest +import torch + +from vllm_omni.diffusion.models.lingbot_world_fast.state_lingbot_world_fast import ( + LingbotWorldFastState, +) + +pytestmark = [pytest.mark.core_model, pytest.mark.cpu, pytest.mark.diffusion] + + +BATCH_SIZE = 1 +NUM_LAYERS = 3 +NUM_HEADS = 4 +HEAD_DIM = 8 +KV_SIZE = 16 +DTYPE = torch.float32 +DEVICE = torch.device("cpu") + + +def _fresh_state_with_caches(kv_size: int = KV_SIZE) -> LingbotWorldFastState: + state = LingbotWorldFastState() + state.create_kv_caches( + batch_size=BATCH_SIZE, + dtype=DTYPE, + device=DEVICE, + kv_size=kv_size, + num_layers=NUM_LAYERS, + num_heads=NUM_HEADS, + head_dim=HEAD_DIM, + ) + return state + + +def test_reset_initializes_all_fields() -> None: + state = LingbotWorldFastState() + assert state.kv_cache is None + assert state.crossattn_cache is None + assert state.current_start_frame == 0 + assert state.local_end_index is None + assert state.global_end_index is None + assert state.is_initialized is False + assert state.current_lat_f == 0 + assert state.session_id is None + assert state.batch_size is None + assert state.num_layers is None + assert state.num_heads is None + assert state.head_dim is None + assert state.h is None + assert state.w is None + assert state.lat_h is None + assert state.lat_w is None + assert state.frame_seqlen is None + assert state.last_decoded_latent is None + + +def test_create_kv_caches_allocates_expected_shapes() -> None: + state = _fresh_state_with_caches() + + assert state.is_initialized is True + assert state.batch_size == BATCH_SIZE + assert state.num_layers == NUM_LAYERS + assert state.num_heads == NUM_HEADS + assert state.head_dim == HEAD_DIM + + assert state.kv_cache is not None + assert len(state.kv_cache) == NUM_LAYERS + for layer in state.kv_cache: + assert layer.shape == (2, BATCH_SIZE, KV_SIZE, NUM_HEADS, HEAD_DIM) + assert layer.dtype == DTYPE + assert torch.all(layer == 0) + + assert state.local_end_index is not None and state.global_end_index is not None + for idx_list in (state.local_end_index, state.global_end_index): + assert len(idx_list) == NUM_LAYERS + for idx in idx_list: + assert idx.shape == (1,) + assert idx.dtype == torch.long + assert int(idx.item()) == 0 + + assert state.crossattn_cache is not None + assert len(state.crossattn_cache) == NUM_LAYERS + for entry in state.crossattn_cache: + assert entry == {"is_init": False, "k": None, "v": None} + + +def test_extend_kv_caches_grows_tensor_and_zeros_new_slots() -> None: + state = _fresh_state_with_caches() + extra = 7 + # Mark the existing slots so we can confirm they aren't disturbed. + for layer in state.kv_cache: + layer.fill_(1.0) + + state.extend_kv_caches(extra_kv_size=extra) + + for layer in state.kv_cache: + assert layer.shape == (2, BATCH_SIZE, KV_SIZE + extra, NUM_HEADS, HEAD_DIM) + assert torch.all(layer[:, :, :KV_SIZE] == 1.0) + # Newly grown trailing slice is fresh zeros. + assert torch.all(layer[:, :, KV_SIZE:] == 0.0) + + +def test_extend_kv_caches_requires_initialization() -> None: + state = LingbotWorldFastState() + with pytest.raises(AssertionError): + state.extend_kv_caches(extra_kv_size=4) + + +def test_get_accessors_require_initialization() -> None: + state = LingbotWorldFastState() + with pytest.raises(AssertionError): + state.get_kv_caches() + with pytest.raises(AssertionError): + state.get_crossattn_caches() + + +def test_get_kv_caches_returns_underlying_list() -> None: + state = _fresh_state_with_caches() + assert state.get_kv_caches() is state.kv_cache + + +def test_advance_moves_cursor_by_delta() -> None: + state = LingbotWorldFastState() + state.advance(3) + assert state.current_lat_f == 3 + state.advance(5) + assert state.current_lat_f == 8 + + +def test_reset_clears_all_session_state() -> None: + state = _fresh_state_with_caches() + state.session_id = "abc" + state.advance(4) + state.h, state.w, state.lat_h, state.lat_w, state.frame_seqlen = 480, 832, 60, 104, 1560 + state.last_decoded_latent = torch.zeros(16, 2, 60, 104) + + state.reset() + + assert state.kv_cache is None + assert state.crossattn_cache is None + assert state.local_end_index is None + assert state.global_end_index is None + assert state.is_initialized is False + assert state.current_lat_f == 0 + assert state.current_start_frame == 0 + assert state.session_id is None + assert state.h is None and state.w is None + assert state.lat_h is None and state.lat_w is None + assert state.frame_seqlen is None + assert state.last_decoded_latent is None + assert state.batch_size is None + assert state.num_layers is None + + +def test_reset_is_idempotent() -> None: + state = LingbotWorldFastState() + state.reset() + state.reset() + assert state.is_initialized is False + assert state.current_lat_f == 0 + + +# --------------------------------------------------------------------------- +# Reset is triggered only by session-id change, not prompt change. +# +# Mirrors the conditional in ``LingbotWorldFastPipeline.forward`` (pipeline +# file, around the ``if self.state.session_id is None or +# self.state.session_id != session_id`` block). We assert the contract on +# the state container so the test does not depend on instantiating the +# heavy pipeline. +# --------------------------------------------------------------------------- + + +def _should_reset(state: LingbotWorldFastState, incoming_session_id: str) -> bool: + """Replicates the pipeline's reset trigger.""" + return state.session_id is None or state.session_id != incoming_session_id + + +def test_first_call_with_any_session_id_triggers_reset() -> None: + state = LingbotWorldFastState() + assert _should_reset(state, "session-a") is True + + +def test_same_session_id_does_not_reset() -> None: + state = _fresh_state_with_caches() + state.session_id = "session-a" + state.advance(4) + + assert _should_reset(state, "session-a") is False + # ... and a prompt-only change must not trigger a reset either. + assert _should_reset(state, "session-a") is False + # Pipeline would proceed in extension mode → state still alive. + assert state.is_initialized is True + assert state.current_lat_f == 4 + + +def test_different_session_id_triggers_reset() -> None: + state = _fresh_state_with_caches() + state.session_id = "session-a" + state.advance(4) + + assert _should_reset(state, "session-b") is True + + state.reset() + assert state.session_id is None + assert state.current_lat_f == 0 + assert state.kv_cache is None diff --git a/tests/e2e/offline_inference/test_lingbot_world_fast.py b/tests/e2e/offline_inference/test_lingbot_world_fast.py new file mode 100644 index 00000000000..f80576d8bb9 --- /dev/null +++ b/tests/e2e/offline_inference/test_lingbot_world_fast.py @@ -0,0 +1,225 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""L2 offline smoke for the Lingbot World Fast pipeline.""" + +from __future__ import annotations + +import pytest +import torch + +from tests.diffusion.models.lingbot_world_fast.conftest import ( + make_dummy_camera_inputs, + make_dummy_image, + make_stubbed_pipeline, +) +from tests.helpers.mark import hardware_test +from vllm_omni.diffusion.data import DiffusionOutput +from vllm_omni.diffusion.models.lingbot_world_fast.pipeline_lingbot_world_fast import ( + CONFIG, + get_lingbot_world_fast_post_process_func, +) +from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.inputs.data import OmniDiffusionSamplingParams + +pytestmark = [pytest.mark.core_model, pytest.mark.diffusion] + + +# ``torch.amp.autocast("cuda", …)`` inside the pipeline requires CUDA at import +# time on hosts where PyTorch is compiled without CUDA support. +if not torch.cuda.is_available(): + pytest.skip( + 'Lingbot World Fast pipeline requires CUDA (torch.amp.autocast("cuda", …))', + allow_module_level=True, + ) + + +# Keep the spatial resolution tiny so the KV-cache stays small (frame_seqlen +# is derived from ``lat_h * lat_w``); the stub pipeline still exercises every +# size-related code path with full fidelity. +_TINY_MAX_AREA = 64 * 64 + +# Default ``num_frames`` argument; the pipeline floors to ``25`` internally on +# a fresh call (smallest length that maps to a non-empty latent). +_FRESH_NUM_FRAMES = 25 +_EXTENSION_NUM_FRAMES = 24 + +_DIM = 16 +_NUM_HEADS = 4 +_NUM_LAYERS = 2 +_HEAD_DIM = _DIM // _NUM_HEADS + + +def _build_request( + *, + image, + camera, + session_id: str, + num_frames: int, + prompt: str = "walk forward", +) -> OmniDiffusionRequest: + multi_modal_data: dict = {"camera": camera} + if image is not None: + multi_modal_data["image"] = image + return OmniDiffusionRequest( + prompts=[{"prompt": prompt, "multi_modal_data": multi_modal_data}], + sampling_params=OmniDiffusionSamplingParams( + height=None, + width=None, + num_frames=num_frames, + seed=42, + extra_args={"session_id": session_id}, + ), + request_ids=[f"req-{session_id}"], + ) + + +@pytest.fixture +def stubbed_pipeline(monkeypatch): + """Build a stub-backed pipeline and shrink CONFIG['max_area'] for speed.""" + monkeypatch.setitem(CONFIG, "max_area", _TINY_MAX_AREA) + pipeline = make_stubbed_pipeline( + dim=_DIM, + num_heads=_NUM_HEADS, + num_layers=_NUM_LAYERS, + target_dtype=torch.float32, + ) + yield pipeline + + +@pytest.mark.diffusion +@hardware_test(res={"cuda": "L4"}, num_cards={"cuda": 1}) +def test_session_lifecycle_fresh_then_extension(stubbed_pipeline) -> None: + """Drive a fresh + extension pair through the pipeline and assert that + ``LingbotWorldFastState`` advances exactly as the chunk arithmetic prescribes.""" + pipeline = stubbed_pipeline + session_id = "session-l2-offline" + + # --- Fresh call --------------------------------------------------------- + camera_fresh = make_dummy_camera_inputs(num_frames=_FRESH_NUM_FRAMES) + image = make_dummy_image() + req_fresh = _build_request( + image=image, + camera=camera_fresh, + session_id=session_id, + num_frames=_FRESH_NUM_FRAMES, + ) + + out_fresh = pipeline.forward(req_fresh) + + assert isinstance(out_fresh, DiffusionOutput) + assert out_fresh.output is not None + assert torch.isfinite(out_fresh.output).all(), "Fresh-call video contains NaN/Inf." + + state = pipeline.state + assert state.is_initialized is True + assert state.session_id == session_id + assert state.current_lat_f > 0 + assert state.kv_cache is not None + assert state.crossattn_cache is not None + assert state.last_decoded_latent is not None + + fresh_lat_f = state.current_lat_f + fresh_kv_size = state.kv_cache[0].shape[2] + frame_seqlen = state.frame_seqlen + assert frame_seqlen == state.lat_h * state.lat_w // 4 + assert fresh_kv_size == frame_seqlen * fresh_lat_f + + # The spatial dims must come from the input image on the fresh call so the + # extension branch can later reuse them — make sure they were captured. + assert state.h is not None and state.w is not None + assert state.lat_h is not None and state.lat_w is not None + + # --- Extension call ----------------------------------------------------- + camera_ext = make_dummy_camera_inputs(num_frames=_EXTENSION_NUM_FRAMES) + req_ext = _build_request( + image=None, + camera=camera_ext, + session_id=session_id, + num_frames=_EXTENSION_NUM_FRAMES, + ) + + out_ext = pipeline.forward(req_ext) + + assert isinstance(out_ext, DiffusionOutput) + assert out_ext.output is not None + assert torch.isfinite(out_ext.output).all(), "Extension-call video contains NaN/Inf." + + assert state.session_id == session_id, "Same session_id must not trigger a reset." + assert state.is_initialized is True + assert state.current_lat_f > fresh_lat_f, "current_lat_f must advance on extension." + ext_lat_f = state.current_lat_f - fresh_lat_f + assert ext_lat_f > 0 + + # ``extend_kv_caches`` allocates a fresh tensor of size old + frame_seqlen * + # new_lat_f and concatenates; assert the trailing slice grew by exactly + # ``frame_seqlen * ext_lat_f`` for every layer. + for layer_idx, layer in enumerate(state.kv_cache): + assert layer.shape == ( + 2, + 1, + fresh_kv_size + frame_seqlen * ext_lat_f, + _NUM_HEADS, + _HEAD_DIM, + ), f"layer {layer_idx} KV cache did not grow by exactly frame_seqlen * ext_lat_f" + + +@pytest.mark.diffusion +@hardware_test(res={"cuda": "L4"}, num_cards={"cuda": 1}) +def test_different_session_id_resets_state(stubbed_pipeline) -> None: + """A new ``session_id`` must hard-reset the cached spatial dims and KV cache""" + pipeline = stubbed_pipeline + image = make_dummy_image() + + req_a = _build_request( + image=image, + camera=make_dummy_camera_inputs(num_frames=_FRESH_NUM_FRAMES), + session_id="session-a", + num_frames=_FRESH_NUM_FRAMES, + ) + pipeline.forward(req_a) + + assert pipeline.state.session_id == "session-a" + lat_f_after_a = pipeline.state.current_lat_f + assert lat_f_after_a > 0 + + # A different session_id on the next call must drop the prior KV cache — + # ``current_lat_f`` resets to ``new_lat_f`` of the second call, not to + # ``lat_f_after_a + new_lat_f``. + req_b = _build_request( + image=make_dummy_image(), + camera=make_dummy_camera_inputs(num_frames=_FRESH_NUM_FRAMES), + session_id="session-b", + num_frames=_FRESH_NUM_FRAMES, + ) + out_b = pipeline.forward(req_b) + assert torch.isfinite(out_b.output).all() + + assert pipeline.state.session_id == "session-b" + assert pipeline.state.current_lat_f == lat_f_after_a, ( + "Stub fresh-call advances by the same fresh new_lat_f, so the reset must " + "have wiped the prior cumulative count rather than added to it." + ) + + +@pytest.mark.diffusion +@hardware_test(res={"cuda": "L4"}, num_cards={"cuda": 1}) +def test_post_process_shapes_videos_for_external_output(stubbed_pipeline) -> None: + """The model-specific post-process flips ``[C, F, H, W]`` to ``[F, H, W, C]``; + that's what diffusion engine + serving code downstream expects.""" + pipeline = stubbed_pipeline + req = _build_request( + image=make_dummy_image(), + camera=make_dummy_camera_inputs(num_frames=_FRESH_NUM_FRAMES), + session_id="session-postprocess", + num_frames=_FRESH_NUM_FRAMES, + ) + + out = pipeline.forward(req) + post = get_lingbot_world_fast_post_process_func(pipeline.od_config) + framed = post(out.output) + + # [C, F, H, W] → [F, H, W, C] + assert framed.ndim == 4 + assert framed.shape[-1] == out.output.shape[0] + assert framed.shape[0] == out.output.shape[1] diff --git a/tests/e2e/offline_inference/test_lingbot_world_fast_expansion.py b/tests/e2e/offline_inference/test_lingbot_world_fast_expansion.py new file mode 100644 index 00000000000..98a932f5680 --- /dev/null +++ b/tests/e2e/offline_inference/test_lingbot_world_fast_expansion.py @@ -0,0 +1,153 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""L3 real-weight offline expansion for Lingbot World Fast.""" + +from __future__ import annotations + +from typing import Any + +import numpy as np +import PIL.Image +import pytest +import torch + +from tests.helpers.lingbot_world_fast import ( + FPS, + GREAT_WALL_PROMPT, + HEIGHT, + LONG_NUM_FRAMES, + SEED, + SHORT_NUM_FRAMES, + SSIM_THRESHOLD, + WIDTH, + find_lingbot_world_fast_assets, + frame_ssim, + golden_frames_dir, + load_camera_trajectory, + normalize_to_uint8_rgb, +) +from tests.helpers.mark import hardware_test +from vllm_omni.entrypoints.omni import Omni +from vllm_omni.inputs.data import OmniDiffusionSamplingParams +from vllm_omni.outputs import OmniRequestOutput +from vllm_omni.platforms import current_omni_platform + +pytestmark = [ + pytest.mark.advanced_model, + pytest.mark.core_model, + pytest.mark.diffusion, +] + + +def _extract_frames_from_output(output: Any) -> np.ndarray: + """Pull a ``[N, H, W, 3]`` numpy array out of an ``OmniRequestOutput``.""" + if isinstance(output, list) and output: + output = output[0] + if isinstance(output, OmniRequestOutput): + if output.is_pipeline_output and output.request_output is not None: + inner = output.request_output + if isinstance(inner, OmniRequestOutput): + output = inner + if isinstance(output, OmniRequestOutput) and output.images: + entry = output.images[0] + if isinstance(entry, tuple) and len(entry) >= 1: + output = entry[0] + elif isinstance(entry, dict): + output = entry.get("frames") or entry.get("video") + else: + output = entry + if isinstance(output, torch.Tensor): + output = output.detach().cpu().numpy() + if not isinstance(output, np.ndarray): + raise AssertionError(f"Could not extract frames from output: {type(output)}") + return normalize_to_uint8_rgb(output) + + +@pytest.fixture(scope="module") +def lingbot_world_fast_assets(): + assets = find_lingbot_world_fast_assets() + if assets is None: + pytest.skip( + "Lingbot-World-Fast L3 assets not available. Set LINGBOT_WORLD_FAST_PATH " + "(model dir) + LINGBOT_WORLD_FAST_CAMERA_PATH (poses.npy/intrinsics.npy) " + "+ LINGBOT_WORLD_FAST_IMAGE (input image) to enable.", + ) + return assets + + +@pytest.fixture(scope="module") +def lingbot_world_fast_omni(lingbot_world_fast_assets): + omni = Omni( + model=str(lingbot_world_fast_assets.weights_path), + parallel_config=None, + model_class_name="LingbotWorldFastPipeline", + stage_init_timeout=6000, + init_timeout=6000, + ) + try: + yield omni + finally: + if hasattr(omni, "close"): + omni.close() + + +@hardware_test(res={"cuda": "H100"}, num_cards={"cuda": 1}) +@pytest.mark.parametrize("num_frames, length", [(SHORT_NUM_FRAMES, "short"), (LONG_NUM_FRAMES, "long")]) +def test_lingbot_world_offline_video( + num_frames, + length, + lingbot_world_fast_assets, + lingbot_world_fast_omni, +): + image = ( + PIL.Image.open(lingbot_world_fast_assets.image_path) + .convert("RGB") + .resize((WIDTH, HEIGHT), PIL.Image.Resampling.LANCZOS) + ) + poses, intrinsics = load_camera_trajectory(lingbot_world_fast_assets.camera_dir) + poses = poses[:num_frames] + intrinsics = intrinsics[:num_frames] + + generator = torch.Generator(device=current_omni_platform.device_type).manual_seed(SEED) + sampling = OmniDiffusionSamplingParams( + height=HEIGHT, + width=WIDTH, + generator=generator, + num_frames=num_frames, + frame_rate=FPS, + extra_args={"session_id": f"SESSION_ID-{length}"}, + ) + + multi_modal_data: dict = {"image": image, "camera": {"poses": poses, "intrinsics": intrinsics}} + + prompt = { + "prompt": GREAT_WALL_PROMPT, + "negative_prompt": "", + "multi_modal_data": multi_modal_data, + } + + output = lingbot_world_fast_omni.generate(prompt, sampling) + + video = _extract_frames_from_output(output) + + first_frame = video[0] + last_frame = video[-1] + + first_path = golden_frames_dir() / f"golden_frame_{length}_first.npy" + last_path = golden_frames_dir() / f"golden_frame_{length}_last.npy" + + first_golden = np.load(first_path) + last_golden = np.load(last_path) + + ssim_first = frame_ssim(first_frame, first_golden) + ssim_last = frame_ssim(last_frame, last_golden) + print( + f"[lingbot-world-fast L3] SSIM(first)={ssim_first:.4f} SSIM(last)={ssim_last:.4f} (threshold {SSIM_THRESHOLD})" + ) + assert ssim_first >= SSIM_THRESHOLD, ( + f"First-frame SSIM {ssim_first:.4f} below {SSIM_THRESHOLD}: regression in first-call path." + ) + assert ssim_last >= SSIM_THRESHOLD, ( + f"Last-frame SSIM {ssim_last:.4f} below {SSIM_THRESHOLD}: regression in last-call path." + ) diff --git a/tests/e2e/online_serving/test_lingbot_world_fast.py b/tests/e2e/online_serving/test_lingbot_world_fast.py new file mode 100644 index 00000000000..a51cb2f542f --- /dev/null +++ b/tests/e2e/online_serving/test_lingbot_world_fast.py @@ -0,0 +1,252 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""L2 online smoke for ``/v1/realtime/world/camera``.""" + +from __future__ import annotations + +import asyncio +from collections.abc import Iterable +from typing import Any + +import numpy as np +import pytest + +from vllm_omni.entrypoints.openai.realtime.world.camera_connection import ( + CHUNK_FRAMES, + WorldCameraRealtimeConnection, +) +from vllm_omni.entrypoints.openai.realtime.world.camera_serving import ServingRealtimeWorldCamera + +msgpack_numpy = pytest.importorskip("openpi_client.msgpack_numpy") + +pytestmark = [pytest.mark.core_model, pytest.mark.diffusion] + + +# --------------------------------------------------------------------------- +# Test plumbing +# --------------------------------------------------------------------------- + + +class _MockWebSocket: + """ASGI-shaped mock matching ``WorldCameraRealtimeConnection``'s call sites. + + ``receive`` is the lowest-level ASGI hook the connection uses (it pulls + ``{"type": "websocket.receive", "bytes": ...}`` dicts directly, not the + higher-level ``receive_bytes``). After the scripted messages are + exhausted, ``receive`` returns a disconnect frame so the connection's + main loop exits cleanly instead of timing out. + """ + + def __init__(self, incoming: Iterable[dict[str, Any]] | None = None) -> None: + self._incoming: list[dict[str, Any]] = list(incoming or []) + self._idx = 0 + self.sent_bytes: list[bytes] = [] + self.sent_text: list[str] = [] + self.accepted = False + self.closed = False + + async def accept(self) -> None: + self.accepted = True + + async def receive(self) -> dict[str, Any]: + if self._idx >= len(self._incoming): + return {"type": "websocket.disconnect"} + msg = self._incoming[self._idx] + self._idx += 1 + return msg + + async def send_bytes(self, data: bytes) -> None: + self.sent_bytes.append(data) + + async def send_text(self, data: str) -> None: + self.sent_text.append(data) + + async def close(self) -> None: + self.closed = True + + +class _FakeAsyncIter: + """Async iterable for the canned engine output.""" + + def __init__(self, items: list[Any]) -> None: + self._items = list(items) + + def __aiter__(self): + return self + + async def __anext__(self): + if not self._items: + raise StopAsyncIteration + return self._items.pop(0) + + +class _FakeResult: + """Stand-in for ``OmniRequestOutput`` — only ``.images`` is consulted.""" + + def __init__(self, frames: np.ndarray) -> None: + self.images = [frames] + + +class _FakeEngineClient: + """Stand-in for ``AsyncOmni``: records calls, returns a per-call frame buffer. + + The connection's framing logic calls ``generate(...)`` once per ``infer`` + request. Tests pre-load ``self.queued_frames`` with one buffer per + expected call, in order. + """ + + def __init__(self, queued_frames: list[np.ndarray]) -> None: + self.queued_frames = list(queued_frames) + self.calls: list[dict[str, Any]] = [] + # Attributes consulted by ``CameraServerConfig.from_model_config``. + self.model_config = {"pipeline": "lingbot_world_fast", "resolution": [480, 832], "fps": 16} + + def generate(self, *, prompt, request_id, sampling_params_list): + self.calls.append( + { + "prompt": prompt, + "request_id": request_id, + "session_id": sampling_params_list[0].extra_args.get("session_id"), + } + ) + if not self.queued_frames: + raise AssertionError("FakeEngineClient ran out of queued frames") + frames = self.queued_frames.pop(0) + return _FakeAsyncIter([_FakeResult(frames)]) + + +def _pack_frame(payload: Any) -> dict[str, Any]: + return {"type": "websocket.receive", "bytes": msgpack_numpy.packb(payload)} + + +def _camera_payload(num_frames: int) -> dict[str, np.ndarray]: + return { + "intrinsics": np.eye(3, dtype=np.float32), + "poses": np.tile(np.eye(4, dtype=np.float32), (num_frames, 1, 1)), + } + + +def _infer_req(*, session_id: str, num_frames: int, include_image: bool) -> dict[str, Any]: + req: dict[str, Any] = { + "prompt": "walk along the Great Wall of China", + "camera": _camera_payload(num_frames), + "session_id": session_id, + "extra_body": {"num_frames": num_frames, "height": 480, "width": 832, "fps": 16}, + } + if include_image: + req["image"] = np.zeros((480, 832, 3), dtype=np.uint8) + return req + + +# --------------------------------------------------------------------------- +# Lifecycle smoke test +# --------------------------------------------------------------------------- + + +def test_camera_session_lifecycle_handshake_infer_reset_infer() -> None: + """End-to-end client session: handshake → infer → reset → infer (new + session_id). Mirrors what ``examples/online_serving/lingbot_world_fast/openai_client.py`` + does on the wire, minus the actual model.""" + # Distinct, non-divisible-by-CHUNK_FRAMES buffer sizes so both calls + # exercise the boundary case (final chunk shorter than CHUNK_FRAMES) and + # the fill-value lets us prove chunks aren't swapped between requests. + first_frames = np.full((CHUNK_FRAMES * 2 + 1, 8, 8, 3), 3, dtype=np.uint8) + second_frames = np.full((CHUNK_FRAMES + 1, 8, 8, 3), 7, dtype=np.uint8) + + engine_client = _FakeEngineClient(queued_frames=[first_frames, second_frames]) + serving = ServingRealtimeWorldCamera(engine_client=engine_client, model_name="lingbot") + + # First infer uses ``4N+1`` to mimic the openai_client's fresh-call shape; + # second infer uses ``4N`` for the extension shape (modelled on the + # client's branch at ``openai_client.py:84``). + fresh_req = _infer_req(session_id="session-1", num_frames=25, include_image=True) + ext_req = _infer_req(session_id="session-2", num_frames=24, include_image=False) + + ws = _MockWebSocket( + incoming=[ + _pack_frame(fresh_req), + _pack_frame({"endpoint": "reset"}), + _pack_frame(ext_req), + ] + ) + conn = WorldCameraRealtimeConnection(ws, serving) + asyncio.run(conn.handle_connection()) + + # ---------------- Handshake ------------------ + assert ws.accepted is True, "Connection must accept() before sending the handshake." + assert len(ws.sent_bytes) >= 1 + handshake = msgpack_numpy.unpackb(ws.sent_bytes[0]) + assert isinstance(handshake, dict) and handshake, "Handshake must be a non-empty msgpack dict." + assert handshake.get("pipeline") == "lingbot_world_fast" + + # ---------------- Frame chunks --------------- + decoded = [msgpack_numpy.unpackb(b) for b in ws.sent_bytes[1:]] + frame_chunks = [d for d in decoded if isinstance(d, dict) and d.get("type") == "frame"] + + first_total = (len(first_frames) + CHUNK_FRAMES - 1) // CHUNK_FRAMES + second_total = (len(second_frames) + CHUNK_FRAMES - 1) // CHUNK_FRAMES + + assert len(frame_chunks) == first_total + second_total, ( + f"Expected {first_total} + {second_total} frame chunks, got {len(frame_chunks)}." + ) + + for chunk in frame_chunks: + assert chunk.keys() >= {"type", "index", "total", "video"} + video = chunk["video"] + for frame in video: + assert frame.dtype == np.float32 + assert (frame >= 0).all() and (frame <= 1).all() + assert frame.ndim == 3 # [h, w, 3] + assert frame.shape[-1] == 3 + + # Send order is first-call chunks, then second-call chunks. Index runs + # 0..total-1 per request, and the per-chunk fill-value proves the chunker + # didn't leak state between requests. + for i, chunk in enumerate(frame_chunks[:first_total]): + assert chunk["index"] == i + assert chunk["total"] == first_total + for i, chunk in enumerate(frame_chunks[first_total:]): + assert chunk["index"] == i + assert chunk["total"] == second_total + + # ---------------- reset ---------------------- + assert "reset successful" in ws.sent_text, "Reset endpoint must reply with a text ack." + # ``ServingRealtimeWorldCamera.reset`` wipes ``_current_session_id``; + # the third request then sets it to ``session-2``. + assert serving._current_session_id == "session-2" + + # ---------------- Engine call accounting ----- + # Exactly two ``generate`` invocations; the reset request must not call + # the engine. + assert len(engine_client.calls) == 2 + assert [c["session_id"] for c in engine_client.calls] == ["session-1", "session-2"] + + +def test_camera_session_handshake_does_not_repeat_within_connection() -> None: + """The handshake is sent **once** at connect, even if many infer/reset + operations follow. Other diffusion clients depend on this invariant to + avoid double-initialising their config.""" + first_frames = np.full((CHUNK_FRAMES + 1, 4, 4, 3), 1, dtype=np.uint8) + second_frames = np.full((CHUNK_FRAMES + 2, 4, 4, 3), 2, dtype=np.uint8) + engine_client = _FakeEngineClient(queued_frames=[first_frames, second_frames]) + serving = ServingRealtimeWorldCamera(engine_client=engine_client, model_name="lingbot") + + ws = _MockWebSocket( + incoming=[ + _pack_frame(_infer_req(session_id="s", num_frames=25, include_image=True)), + _pack_frame({"endpoint": "reset"}), + _pack_frame(_infer_req(session_id="s", num_frames=24, include_image=False)), + ] + ) + conn = WorldCameraRealtimeConnection(ws, serving) + asyncio.run(conn.handle_connection()) + + # Exactly one msgpack-encoded non-frame, non-error dict in the entire + # outbound stream: the handshake. + handshakes = [] + for b in ws.sent_bytes: + decoded = msgpack_numpy.unpackb(b) + if isinstance(decoded, dict) and decoded.get("type") not in ("frame", "error"): + handshakes.append(decoded) + assert len(handshakes) == 1, f"Expected exactly one handshake, got {len(handshakes)}: {handshakes}" diff --git a/tests/e2e/online_serving/test_lingbot_world_fast_expansion.py b/tests/e2e/online_serving/test_lingbot_world_fast_expansion.py new file mode 100644 index 00000000000..3de37072f55 --- /dev/null +++ b/tests/e2e/online_serving/test_lingbot_world_fast_expansion.py @@ -0,0 +1,344 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""L3 real-weight online expansion for ``/v1/realtime/world/camera``.""" + +from __future__ import annotations + +from typing import Any + +import numpy as np +import PIL.Image +import pytest + +from tests.helpers.lingbot_world_fast import ( + FPS, + GREAT_WALL_PROMPT, + HEIGHT, + LONG_NUM_FRAMES, + SEED, + SHORT_NUM_FRAMES, + SSIM_THRESHOLD, + WIDTH, + find_lingbot_world_fast_assets, + frame_ssim, + golden_frames_dir, + load_camera_trajectory, + slice_camera_chunk, +) +from tests.helpers.mark import hardware_test +from tests.helpers.runtime import OmniServer + +# Optional protocol deps mirror what the connection itself imports lazily. +msgpack_numpy = pytest.importorskip("openpi_client.msgpack_numpy") +ws_sync = pytest.importorskip("websockets.sync.client") + +pytestmark = [ + pytest.mark.advanced_model, + pytest.mark.core_model, + pytest.mark.diffusion, +] + +_CONNECT_KWARGS = {"max_size": None, "ping_interval": None, "ping_timeout": None} + + +# --------------------------------------------------------------------------- +# Asset / golden fixtures (module-scoped to amortize file IO) +# --------------------------------------------------------------------------- + + +@pytest.fixture(scope="module") +def lingbot_world_fast_assets(): + assets = find_lingbot_world_fast_assets() + if assets is None: + pytest.skip( + "Lingbot-World-Fast L3 assets not available. Set LINGBOT_WORLD_FAST_PATH, " + "LINGBOT_WORLD_FAST_CAMERA_PATH and LINGBOT_WORLD_FAST_IMAGE.", + ) + return assets + + +_LINGBOT_SERVER_ARGS = [ + "--model-class-name", + "LingbotWorldFastPipeline", + "--ws-max-size", + "16777216", # 16 MiB — matches run_server.sh; large enough for a 480×832 image + "--ws", + "wsproto", + "--stage-init-timeout", + "6000", + "--init-timeout", + "6000", +] + + +@pytest.fixture(scope="module") +def lingbot_world_fast_server(lingbot_world_fast_assets): + """Module-scoped real-weight server; amortizes the multi-minute cold load + across the four protocol tests in this file.""" + with OmniServer( + str(lingbot_world_fast_assets.weights_path), + list(_LINGBOT_SERVER_ARGS), + use_omni=True, + ) as server: + yield server + + +def _ws_url(server: OmniServer) -> str: + return f"ws://{server.host}:{server.port}/v1/realtime/world/camera" + + +# --------------------------------------------------------------------------- +# WebSocket helpers +# --------------------------------------------------------------------------- + + +def _drain_handshake(ws) -> dict[str, Any]: + handshake = msgpack_numpy.unpackb(ws.recv()) + return handshake + + +def _send_request(ws, req: dict[str, Any]) -> None: + ws.send(msgpack_numpy.packb(req)) + + +def _drain_frames_or_error(ws) -> tuple[list[np.ndarray] | None, dict[str, Any] | None, str | None]: + """Return ``(frames, error, text)``. Exactly one of the three is non-None. + + * ``frames``: list of per-chunk uint8 arrays once ``total`` frames arrive. + * ``error``: parsed ``{"type": "error", "message": ...}`` payload. + * ``text``: text-frame reply (e.g. ``"reset successful"``). + """ + chunks: list[np.ndarray] = [] + total: int | None = None + while total is None or len(chunks) < total: + msg = ws.recv() + if isinstance(msg, str): + return None, None, msg + decoded = msgpack_numpy.unpackb(msg) + if isinstance(decoded, dict) and decoded.get("type") == "error": + return None, decoded, None + if not isinstance(decoded, dict) or decoded.get("type") != "frame": + continue # ignore unknown + total = decoded["total"] + chunks.append(np.asarray(decoded["video"])) + return chunks, None, None + + +def _build_request( + *, + session_id: str, + image: np.ndarray | None, + camera_chunk: dict[str, np.ndarray], + num_frames: int, +) -> dict[str, Any]: + req: dict[str, Any] = { + "prompt": GREAT_WALL_PROMPT, + "camera": camera_chunk, + "session_id": session_id, + "extra_body": { + "num_frames": num_frames, + "height": HEIGHT, + "width": WIDTH, + "fps": FPS, + "session_id": session_id, + "seed": SEED, + }, + } + if image is not None: + req["image"] = image + return req + + +# --------------------------------------------------------------------------- +# Test 1: Single session generation +# --------------------------------------------------------------------------- + + +@hardware_test(res={"cuda": "H100"}, num_cards={"cuda": 1}) +@pytest.mark.parametrize("num_frames, length", [(SHORT_NUM_FRAMES, "short"), (LONG_NUM_FRAMES, "long")]) +def test_lingbot_world_online_video( + num_frames, + length, + lingbot_world_fast_server, + lingbot_world_fast_assets, +): + with ws_sync.connect(_ws_url(lingbot_world_fast_server), **_CONNECT_KWARGS) as ws: + _drain_handshake(ws) + + image = ( + PIL.Image.open(lingbot_world_fast_assets.image_path) + .convert("RGB") + .resize((WIDTH, HEIGHT), PIL.Image.Resampling.LANCZOS) + ) + image = np.asarray(image) + poses, intrinsics = load_camera_trajectory(lingbot_world_fast_assets.camera_dir) + poses = poses[:num_frames] + intrinsics = intrinsics[:num_frames] + + camera = {"poses": poses, "intrinsics": intrinsics} + + req = _build_request( + session_id=f"SESSION-ID-{length}", + image=image, + camera_chunk=camera, + num_frames=num_frames, + ) + + _send_request(ws, req) + + chunks, error, text = _drain_frames_or_error(ws) + assert error is None and text is None, f"Got unexpected control reply: error={error} text={text}" + assert chunks is not None and chunks, "Returned no frames" + + reassembled = np.concatenate(chunks, axis=0) + + assert reassembled.ndim == 4 and reassembled.shape[0] >= 2, ( + f"Reassembled video has too few frames: {reassembled.shape}" + ) + + first_frame = (reassembled[0] * 255.0).round().astype(np.uint8) + last_frame = (reassembled[-1] * 255.0).round().astype(np.uint8) + + first_path = golden_frames_dir() / f"golden_frame_{length}_first.npy" + last_path = golden_frames_dir() / f"golden_frame_{length}_last.npy" + + first_golden = np.load(first_path) + last_golden = np.load(last_path) + + ssim_first = frame_ssim(first_frame, first_golden) + ssim_last = frame_ssim(last_frame, last_golden) + print( + f"[lingbot-world-fast L3 online] SSIM(first)={ssim_first:.4f} " + f"SSIM(last)={ssim_last:.4f} (threshold {SSIM_THRESHOLD})" + ) + assert ssim_first >= SSIM_THRESHOLD, ( + f"First-frame SSIM {ssim_first:.4f} below {SSIM_THRESHOLD}: regression in fresh-call path." + ) + assert ssim_last >= SSIM_THRESHOLD, ( + f"Last-frame SSIM {ssim_last:.4f} below {SSIM_THRESHOLD}: regression in extension-call path." + ) + + +# --------------------------------------------------------------------------- +# Test 2: Session-id churn mid-stream +# --------------------------------------------------------------------------- + + +@hardware_test(res={"cuda": "H100"}, num_cards={"cuda": 1}) +def test_websocket_session_id_churn_resets_state( + lingbot_world_fast_server, + lingbot_world_fast_assets, +): + """A new ``session_id`` mid-stream resets pipeline state. The next ``infer`` + that omits the image (i.e. an "extension-style" payload) must error + because the new session is fresh.""" + poses, intrinsics = load_camera_trajectory(lingbot_world_fast_assets.camera_dir) + camera_a = slice_camera_chunk(poses, intrinsics, call_index=0) + camera_b = slice_camera_chunk(poses, intrinsics, call_index=1) + + image = ( + PIL.Image.open(lingbot_world_fast_assets.image_path) + .convert("RGB") + .resize((WIDTH, HEIGHT), PIL.Image.Resampling.LANCZOS) + ) + image = np.asarray(image) + + with ws_sync.connect(_ws_url(lingbot_world_fast_server), **_CONNECT_KWARGS) as ws: + _drain_handshake(ws) + + _send_request( + ws, + _build_request( + session_id="churn-session-a", + image=image, + camera_chunk=camera_a, + num_frames=SHORT_NUM_FRAMES, + ), + ) + chunks, error, text = _drain_frames_or_error(ws) + assert chunks is not None and not error and not text, ( + f"First infer on session-a should succeed; got error={error} text={text}" + ) + + # Switch session_id mid-stream WITHOUT sending an image. The pipeline + # treats this as a fresh call (new session) and rejects. + _send_request( + ws, + _build_request( + session_id="churn-session-b", + image=None, + camera_chunk=camera_b, + num_frames=SHORT_NUM_FRAMES, + ), + ) + chunks2, error2, text2 = _drain_frames_or_error(ws) + assert error2 is not None, ( + "Server must reject a fresh session that omits ``image``; got " + f"frames={None if chunks2 is None else len(chunks2)} text={text2}" + ) + assert error2.get("type") == "error" + + +# --------------------------------------------------------------------------- +# Test 3: Mid-session ``reset`` RPC re-initializes +# --------------------------------------------------------------------------- + + +@hardware_test(res={"cuda": "H100"}, num_cards={"cuda": 1}) +def test_websocket_mid_session_reset_reinitializes( + lingbot_world_fast_server, + lingbot_world_fast_assets, +): + """After a ``reset`` text ack, the next ``infer`` with the same ``session_id`` + is a brand-new fresh call. We verify this by asserting that the + follow-up ``infer`` *without* an image errors (same logic as the + session-id churn test).""" + poses, intrinsics = load_camera_trajectory(lingbot_world_fast_assets.camera_dir) + camera_a = slice_camera_chunk(poses, intrinsics, call_index=0) + camera_b = slice_camera_chunk(poses, intrinsics, call_index=1) + + image = ( + PIL.Image.open(lingbot_world_fast_assets.image_path) + .convert("RGB") + .resize((WIDTH, HEIGHT), PIL.Image.Resampling.LANCZOS) + ) + image = np.asarray(image) + + with ws_sync.connect(_ws_url(lingbot_world_fast_server), **_CONNECT_KWARGS) as ws: + _drain_handshake(ws) + + _send_request( + ws, + _build_request( + session_id="reset-session", + image=image, + camera_chunk=camera_a, + num_frames=SHORT_NUM_FRAMES, + ), + ) + chunks, error, text = _drain_frames_or_error(ws) + assert chunks is not None and not error and not text, ( + f"Initial infer must succeed; got error={error} text={text}" + ) + + # Mid-session reset RPC. + ws.send(msgpack_numpy.packb({"endpoint": "reset"})) + _, _, reset_text = _drain_frames_or_error(ws) + assert reset_text == "reset successful", f"Expected 'reset successful' text frame, got {reset_text!r}" + + # Same session_id, no image → fresh-call branch → server error. + _send_request( + ws, + _build_request( + session_id="reset-session", + image=None, + camera_chunk=camera_b, + num_frames=SHORT_NUM_FRAMES, + ), + ) + _, post_reset_error, post_reset_text = _drain_frames_or_error(ws) + assert post_reset_error is not None, ( + "After mid-session reset the server must treat the next infer as a fresh call; " + f"missing-image payload should error. Got text={post_reset_text!r}" + ) diff --git a/tests/entrypoints/openai_api/test_realtime_world_camera.py b/tests/entrypoints/openai_api/test_realtime_world_camera.py new file mode 100644 index 00000000000..4b600e9bfae --- /dev/null +++ b/tests/entrypoints/openai_api/test_realtime_world_camera.py @@ -0,0 +1,84 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import msgspec +import pytest +from fastapi import FastAPI, WebSocket +from starlette.testclient import TestClient +from starlette.websockets import WebSocketDisconnect + +from vllm_omni.entrypoints.openai.realtime.world.camera_serving import CameraServerConfig + +pytestmark = [pytest.mark.core_model, pytest.mark.cpu] + + +mock_model_config = {"model_name": "/m/foo", "image_width": 832, "image_height": 480, "extra_args": {"foo": "bar"}} + + +def test_from_model_config_loads_correctly(): + cfg = CameraServerConfig.from_model_config(mock_model_config) + + assert cfg.to_dict() == mock_model_config + + +def test_msgpack_roundtrip(): + cfg = CameraServerConfig.from_model_config(mock_model_config) + encoded = msgspec.msgpack.encode(cfg) + decoded = msgspec.msgpack.decode(encoded, type=CameraServerConfig) + assert decoded == cfg + + +def _build_camera_app(*, supports: bool, cfg: CameraServerConfig | None): + """Build a minimal FastAPI app that mirrors the api_server.py handler.""" + app = FastAPI() + + @app.websocket("/v1/realtime/world/camera") + async def realtime_world_camera(websocket: WebSocket): + await websocket.accept() + if cfg is None or not supports: + await websocket.send_json( + {"type": "error", "error": "Camera realtime API is not available", "code": "unsupported"} + ) + await websocket.close() + return + await websocket.send_bytes(msgspec.msgpack.encode(cfg)) + try: + while True: + msg = await websocket.receive() + if msg.get("type") == "websocket.disconnect": + break + except WebSocketDisconnect: + return + + return app + + +class TestRealtimeWorldCameraEndpoint: + def test_sends_msgpack_config_on_connect(self): + cfg = CameraServerConfig.from_model_config(mock_model_config) + app = _build_camera_app(supports=True, cfg=cfg) + with TestClient(app) as client: + with client.websocket_connect("/v1/realtime/world/camera") as ws: + payload = ws.receive_bytes() + decoded = msgspec.msgpack.decode(payload, type=CameraServerConfig) + assert decoded == cfg + + def test_keeps_socket_open_after_initial_send(self): + cfg = CameraServerConfig.from_model_config(mock_model_config) + app = _build_camera_app(supports=True, cfg=cfg) + with TestClient(app) as client: + with client.websocket_connect("/v1/realtime/world/camera") as ws: + ws.receive_bytes() + # Client-initiated messages are accepted (currently ignored). + ws.send_text("ping") + # Closing from the client side must not raise on the server. + + def test_unsupported_path_sends_error_and_closes(self): + app = _build_camera_app(supports=False, cfg=None) + with TestClient(app) as client: + with client.websocket_connect("/v1/realtime/world/camera") as ws: + err = ws.receive_json() + assert err["type"] == "error" + assert err["code"] == "unsupported" + with pytest.raises(WebSocketDisconnect): + ws.receive_bytes() diff --git a/tests/helpers/lingbot_world_fast.py b/tests/helpers/lingbot_world_fast.py new file mode 100644 index 00000000000..da5e9c270d0 --- /dev/null +++ b/tests/helpers/lingbot_world_fast.py @@ -0,0 +1,243 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""Shared L3-fixture helpers for Lingbot World Fast expansion tests.""" + +from __future__ import annotations + +import os +from dataclasses import dataclass +from pathlib import Path + +import numpy as np + +# --------------------------------------------------------------------------- +# Constants — single source of truth across the two expansion tests +# --------------------------------------------------------------------------- + +# Mirrors ``examples/offline_inference/lingbot_world_fast/end2end.py`` and +# ``examples/online_serving/lingbot_world_fast/openai_client.py``. +GREAT_WALL_PROMPT = "A sweeping cinematic journey along the Great Wall of China, winding through golden autumn hills under a brilliant blue sky — stone pathways stretch into the distance, watchtowers stand sentinel, and vibrant foliage blankets the mountainsides as the camera glides smoothly forward, capturing the grandeur and timeless majesty of this ancient wonder." +SEED = 42 +WIDTH = 832 +HEIGHT = 480 +FPS = 16 + +SHORT_NUM_FRAMES = 25 +LONG_NUM_FRAMES = 81 + +EXTENSION_WARMUP_DROP = 4 + +SSIM_THRESHOLD = 0.95 + + +@dataclass(frozen=True) +class LingbotWorldFastAssets: + """All external assets a real-weight Lingbot World Fast test needs.""" + + weights_path: Path + camera_dir: Path + image_path: Path + + +# --------------------------------------------------------------------------- +# Resolution helpers (no pytest imports here — callers decide whether to skip) +# --------------------------------------------------------------------------- + + +def _hf_cache_root() -> Path: + return Path(os.environ.get("HF_HOME", str(Path.home() / ".cache" / "huggingface"))) + + +def _hf_model_snapshot_dirs(repo_id: str) -> list[Path]: + snapshots = _hf_cache_root() / "hub" / f"models--{repo_id.replace('/', '--')}" / "snapshots" + if not snapshots.exists(): + return [] + return sorted( + (p for p in snapshots.iterdir() if p.is_dir()), + key=lambda p: p.stat().st_mtime, + reverse=True, + ) + + +def _repo_root() -> Path: + # tests/helpers/lingbot_world_fast.py → repo root + return Path(__file__).resolve().parents[2] + + +def _example_checkpoint_root() -> Path: + return ( + _repo_root() + / "examples" + / "offline_inference" + / "lingbot_world_fast" + / "lingbot_world" + / "lingbot-world-base-cam" + / "Lingbot-World-Fast" + ) + + +def _example_camera_root_candidates() -> list[Path]: + base = ( + _repo_root() + / "examples" + / "offline_inference" + / "lingbot_world_fast" + / "lingbot_world" + / "lingbot-world-base-cam" + ) + return [base / "examples" / "04", base / "04", base] + + +def find_lingbot_world_fast_weights() -> Path | None: + """Return a path to the Lingbot World Fast model directory, or ``None``. + + Resolution order: ``LINGBOT_WORLD_FAST_PATH`` env var → committed example + checkpoint → HF cache snapshot of ``robbyant/lingbot-world-base-cam``. + A *path is real* only when the required ``config.json`` plus at least + one ``model-*.safetensors`` shard are present, so a half-pulled snapshot + doesn't masquerade as a usable checkpoint. + """ + override = os.environ.get("LINGBOT_WORLD_FAST_PATH") + candidates: list[Path] = [] + if override: + candidates.append(Path(override)) + + candidates.append(_example_checkpoint_root()) + + for snapshot in _hf_model_snapshot_dirs("robbyant/lingbot-world-base-cam"): + candidates.append(snapshot / "Lingbot-World-Fast") + + for path in candidates: + if not path.exists() or not path.is_dir(): + continue + if not (path / "config.json").exists(): + continue + if not any(path.glob("model-*.safetensors")): + continue + return path + return None + + +def find_lingbot_world_fast_camera_dir() -> Path | None: + """Locate a directory with ``poses.npy`` + ``intrinsics.npy``.""" + override = os.environ.get("LINGBOT_WORLD_FAST_CAMERA_PATH") + candidates: list[Path] = [] + if override: + candidates.append(Path(override)) + candidates.extend(_example_camera_root_candidates()) + for path in candidates: + if path.exists() and (path / "poses.npy").exists() and (path / "intrinsics.npy").exists(): + return path + return None + + +def find_lingbot_world_fast_image() -> Path | None: + """Locate the example input image used by ``run_fast.sh`` case 2.""" + override = os.environ.get("LINGBOT_WORLD_FAST_IMAGE") + candidates: list[Path] = [] + if override: + candidates.append(Path(override)) + for camera_root in _example_camera_root_candidates(): + for name in ("image.jpg", "image.png", "input.jpg", "input.png"): + candidates.append(camera_root / name) + for path in candidates: + if path.exists() and path.is_file(): + return path + return None + + +def find_lingbot_world_fast_assets() -> LingbotWorldFastAssets | None: + """Resolve weights + camera trajectory + image in one call, or return None + if any of the three is missing — callers can then ``pytest.skip`` with a + single specific reason.""" + weights = find_lingbot_world_fast_weights() + camera = find_lingbot_world_fast_camera_dir() + image = find_lingbot_world_fast_image() + with open("output.txt", "w+") as f: + f.write(f"{weights is None} {camera is None} {image is None}") + if not (weights and camera and image): + return None + return LingbotWorldFastAssets(weights_path=weights, camera_dir=camera, image_path=image) + + +def golden_frames_dir() -> Path: + return _repo_root() / "tests" / "data" / "lingbot_world_fast" + + +# --------------------------------------------------------------------------- +# Per-call payload helpers +# --------------------------------------------------------------------------- + + +def load_camera_trajectory(camera_dir: Path) -> tuple[np.ndarray, np.ndarray]: + poses = np.load(camera_dir / "poses.npy") + intrinsics = np.load(camera_dir / "intrinsics.npy") + return poses, intrinsics + + +def slice_camera_chunk( + poses: np.ndarray, + intrinsics: np.ndarray, + *, + call_index: int, + chunk_stride: int = SHORT_NUM_FRAMES, +) -> dict[str, np.ndarray]: + """Mirrors the slicing in ``examples/online_serving/lingbot_world_fast/openai_client.py``. + + Each call consumes ``chunk_stride`` poses. The model will floor internally if the slice has fewer + poses than requested. + """ + start = call_index * chunk_stride + end = start + chunk_stride + poses_slice = poses[start:end] + intrinsics_slice = intrinsics[start:end] if intrinsics.ndim > 2 else intrinsics + return {"poses": poses_slice, "intrinsics": intrinsics_slice} + + +# --------------------------------------------------------------------------- +# Frame post-processing helpers +# --------------------------------------------------------------------------- + + +def reassemble_chunked_video( + per_call_frames: list[np.ndarray], + *, + drop_warmup: int = EXTENSION_WARMUP_DROP, +) -> np.ndarray: + """Concatenate per-call frame chunks, dropping ``drop_warmup`` leading + frames on every extension call (call index >= 1).""" + assembled: list[np.ndarray] = [] + for i, frames in enumerate(per_call_frames): + clip = frames[drop_warmup:] if i > 0 else frames + assembled.append(clip) + return np.concatenate(assembled, axis=0) + + +def normalize_to_uint8_rgb(frames: np.ndarray) -> np.ndarray: + """Coerce a generated frames tensor to ``[N, H, W, 3]`` ``uint8``. + + The diffusion engine emits either an unsigned-int chunk (pre-encoded by + ``_normalize_frames``) or a float tensor in ``[-1, 1]``. We accept both + so the SSIM helper sees a single canonical shape. + """ + arr = frames + if arr.dtype.kind == "f": + arr = np.clip((arr + 1.0) * 0.5, 0.0, 1.0) + arr = (arr * 255.0).round().astype(np.uint8) + if arr.ndim == 5 and arr.shape[0] == 1: + arr = arr[0] + return arr + + +def frame_ssim(prediction: np.ndarray, reference: np.ndarray) -> float: + """Per-frame SSIM with ``data_range=1``. Uses ``torchmetrics`` (already a + transitive dep) and accepts ``[H, W, 3]`` uint8 arrays. + """ + import torch + from torchmetrics.image import StructuralSimilarityIndexMeasure + + pred_t = (torch.from_numpy(prediction.astype(np.float32) / 255.0)).permute(2, 0, 1).unsqueeze(0) + ref_t = (torch.from_numpy(reference.astype(np.float32) / 255.0)).permute(2, 0, 1).unsqueeze(0) + metric = StructuralSimilarityIndexMeasure(data_range=1.0) + return float(metric(pred_t, ref_t).item()) diff --git a/tests/worker/test_omni_connector_mixin.py b/tests/worker/test_omni_connector_mixin.py index b8eb899fa32..a11ac588a35 100644 --- a/tests/worker/test_omni_connector_mixin.py +++ b/tests/worker/test_omni_connector_mixin.py @@ -797,7 +797,7 @@ def mock_process(transfer_manager, pooling_output, request): class TestLocalPayloadCacheLifecycle(unittest.TestCase): - """Unit tests for the local payload cache API (RFC §2.4).""" + """Unit tests for the local payload cache API""" def _make_host(self) -> MixinHost: host = MixinHost() diff --git a/vllm_omni/diffusion/models/lingbot_world_fast/pipeline_lingbot_world_fast.py b/vllm_omni/diffusion/models/lingbot_world_fast/pipeline_lingbot_world_fast.py index fb697fc88c8..73c0b803126 100644 --- a/vllm_omni/diffusion/models/lingbot_world_fast/pipeline_lingbot_world_fast.py +++ b/vllm_omni/diffusion/models/lingbot_world_fast/pipeline_lingbot_world_fast.py @@ -168,9 +168,12 @@ def forward( multi_modal_data = req.prompts[0].get("multi_modal_data", {}) session_id = str(req.sampling_params.extra_args.get("session_id") or None) + + force_reset = req.sampling_params.extra_args.get("force_reset") or False + extension = True - if self.state.session_id is None or self.state.session_id != session_id: + if force_reset or self.state.session_id is None or self.state.session_id != session_id: self.state.reset() self.state.session_id = session_id extension = False @@ -236,9 +239,13 @@ def forward( new_lat_f = max(new_lat_f, 1) max_seq_len = chunk_size * lat_h * lat_w // (self.patch_size[1] * self.patch_size[2]) max_seq_len = int(math.ceil(max_seq_len / self.sp_size)) * self.sp_size - seed = random.randint(0, sys.maxsize) - seed_g = torch.Generator(device=self.device) - seed_g.manual_seed(seed) + seed_g = req.sampling_params.generator + if seed_g is None: + seed = req.sampling_params.seed + if seed is None: + seed = random.randint(0, sys.maxsize) + seed_g = torch.Generator(device=self.device) + seed_g.manual_seed(seed) noise = torch.randn(16, new_lat_f, lat_h, lat_w, dtype=torch.float32, generator=seed_g, device=self.device) # Fresh: msk[0] = 1 (anchor) and the rest = 0, replicated into 4 channels grouped diff --git a/vllm_omni/diffusion/models/lingbot_world_fast/t5.py b/vllm_omni/diffusion/models/lingbot_world_fast/t5.py index ccdb8160d6d..0863d121b4b 100644 --- a/vllm_omni/diffusion/models/lingbot_world_fast/t5.py +++ b/vllm_omni/diffusion/models/lingbot_world_fast/t5.py @@ -417,7 +417,7 @@ def __init__( self, text_len, dtype=torch.bfloat16, - device=torch.accelerator.current_device_idx(), + device=torch.accelerator.current_device_index(), checkpoint_path=None, tokenizer_path=None, shard_fn=None, diff --git a/vllm_omni/entrypoints/openai/realtime/world/camera_connection.py b/vllm_omni/entrypoints/openai/realtime/world/camera_connection.py index 7d4a500399c..01c99d67a03 100644 --- a/vllm_omni/entrypoints/openai/realtime/world/camera_connection.py +++ b/vllm_omni/entrypoints/openai/realtime/world/camera_connection.py @@ -5,7 +5,7 @@ Protocol (compatible with DreamZero test_client_AR.py): Connect -> server sends msgpack(PolicyServerConfig fields) - Infer -> client sends msgpack(obs), server sends msgpack(ndarray) + Infer -> client sends msgpack(req), server sends msgpack(ndarray) Reset -> client sends msgpack({endpoint:reset}), server sends "reset successful" """ @@ -24,6 +24,7 @@ logger = init_logger(__name__) _DEFAULT_IDLE_TIMEOUT = 30.0 +CHUNK_FRAMES = 4 def _get_msgpack_numpy() -> Any: @@ -63,10 +64,10 @@ async def _send_error(self, message: str) -> None: await self.websocket.send_bytes(_pack({"type": "error", "message": message})) def _unpack_request(self, data: bytes) -> dict[str, Any]: - obs = _unpack(data) - if not isinstance(obs, dict): + req = _unpack(data) + if not isinstance(req, dict): raise ValueError("Invalid request payload") - return obs + return req async def handle_connection(self) -> None: """Main loop.""" @@ -99,7 +100,7 @@ async def handle_connection(self) -> None: continue try: - obs = self._unpack_request(msg["bytes"]) + req = self._unpack_request(msg["bytes"]) except Exception: logger.exception("Invalid world model OpenPI request payload") try: @@ -109,13 +110,13 @@ async def handle_connection(self) -> None: continue try: - endpoint = obs.pop("endpoint", "infer") + endpoint = req.pop("endpoint", "infer") if endpoint == "reset": - self.serving.reset(obs) + self.serving.reset(req) await self.websocket.send_text("reset successful") else: - result = await self.serving.infer(obs) + result = await self.serving.infer(req) if ( len(result.images) == 1 @@ -136,8 +137,6 @@ async def handle_connection(self) -> None: frames = _normalize_frames(frames) - CHUNK_FRAMES = 4 - total = (len(frames) + CHUNK_FRAMES - 1) // CHUNK_FRAMES for i in range(total): chunk = frames[i * CHUNK_FRAMES : (i + 1) * CHUNK_FRAMES] diff --git a/vllm_omni/entrypoints/openai/realtime/world/camera_serving.py b/vllm_omni/entrypoints/openai/realtime/world/camera_serving.py index 6c7be761762..4ac8a6c4a1f 100644 --- a/vllm_omni/entrypoints/openai/realtime/world/camera_serving.py +++ b/vllm_omni/entrypoints/openai/realtime/world/camera_serving.py @@ -30,12 +30,7 @@ def _to_builtin_container(value: Any) -> Any: @dataclass(frozen=True) class CameraServerConfig: - """Static server-side camera/pipeline parameters sent to a client on connect. - - Fields are seeded from the Lingbot World Fast pipeline constants — the only - camera-capable pipeline today. When additional camera pipelines are added, - ``from_model_config`` should branch on the model identifier. - """ + """Static server-side camera/pipeline parameters sent to a client on connect.""" values: dict[str, Any] @@ -63,6 +58,7 @@ def __init__( self._current_session_id: str | None = None self._call_count = 0 self.policy_server_config = self._get_policy_server_config(engine_client) + self._force_reset = False @classmethod def create_policy_server( @@ -101,21 +97,23 @@ def _get_policy_server_config(engine_client: Any) -> CameraServerConfig: if model_config is None: model_config = getattr(engine_client, "model_config", None) + return CameraServerConfig.from_model_config(model_config) - def reset(self, obs: dict) -> None: + def reset(self, req: dict) -> None: """Reset serving state. Engine-side Lingbot state is reset on the next inference request via `extra_args["reset"]`, not by an immediate websocket-side RPC. """ self._current_session_id = None + self._force_reset = True - async def infer(self, obs: dict) -> np.ndarray: - """raw obs → engine → video.""" + async def infer(self, req: dict) -> np.ndarray: + """raw req → engine → video.""" # Session tracking - session_id = obs.get("session_id") + session_id = req.get("session_id") if session_id is not None and session_id != self._current_session_id: if self._current_session_id is not None: logger.info("Session changed %s → %s", self._current_session_id, session_id) @@ -125,7 +123,11 @@ async def infer(self, obs: dict) -> np.ndarray: self._call_count += 1 # Build request, run inference through AsyncOmni - request = self._build_request(obs) + request = self._build_request(req) + + # After an inference call we reset the _force_reset argument + self._force_reset = False + result = None # OpenPI policy serving is one request -> one action reply. AsyncOmni # exposes an async iterator, so consume it to completion and use the @@ -141,8 +143,8 @@ async def infer(self, obs: dict) -> np.ndarray: return result - def _build_request(self, obs: dict) -> Any: - """Build engine request from raw robot obs. + def _build_request(self, req: dict) -> Any: + """Build engine request from raw robot req. Returns an `OmniDiffusionRequest` payload consumed by `AsyncOmni.generate()` and routed to the diffusion stage. @@ -150,28 +152,27 @@ def _build_request(self, obs: dict) -> Any: from vllm_omni.diffusion.request import OmniDiffusionRequest from vllm_omni.inputs.data import OmniDiffusionSamplingParams - extra_args = { - "session_id": self._current_session_id or "default", - } + extra_args = {"session_id": self._current_session_id or "default", "force_reset": self._force_reset} - camera = obs.get("camera", None) + camera = req.get("camera", None) multi_modal_data = { - "image": obs.get("image", None), + "image": req.get("image", None), "camera": camera, } - prompt = obs.get("prompt", "") + prompt = req.get("prompt", "") - extra_body = obs.get("extra_body", {}) + extra_body = req.get("extra_body", {}) height = extra_body.get("height", None) width = extra_body.get("width", None) num_frames = extra_body.get("num_frames", None) fps = extra_body.get("fps", None) + seed = extra_body.get("seed", None) sampling_params = OmniDiffusionSamplingParams( - height=height, width=width, num_frames=num_frames, frame_rate=fps, extra_args=extra_args, seed=42 + height=height, width=width, num_frames=num_frames, frame_rate=fps, extra_args=extra_args, seed=seed ) return OmniDiffusionRequest( prompts=[ From 778e2ff02034ee013e4f502fcc73d4fd36d5c56e Mon Sep 17 00:00:00 2001 From: Mahdi Nasser <94046147+mnasser02@users.noreply.github.com> Date: Tue, 26 May 2026 13:41:05 +0200 Subject: [PATCH 47/53] bugfixes Signed-off-by: Mahdi Nasser <94046147+mnasser02@users.noreply.github.com> --- vllm_omni/diffusion/distributed/pipeline_parallel.py | 2 +- vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py | 12 +++++++----- vllm_omni/diffusion/worker/diffusion_model_runner.py | 6 +++++- 3 files changed, 13 insertions(+), 7 deletions(-) diff --git a/vllm_omni/diffusion/distributed/pipeline_parallel.py b/vllm_omni/diffusion/distributed/pipeline_parallel.py index 1cfb187b89a..576ea22082c 100644 --- a/vllm_omni/diffusion/distributed/pipeline_parallel.py +++ b/vllm_omni/diffusion/distributed/pipeline_parallel.py @@ -240,7 +240,7 @@ def predict_noise_maybe_with_cfg( if not pp_group.is_last_rank: # First / middle rank: run partial forwards and propagate ITs downstream. - for kwargs, it in zip(all_kwargs, its): + for i, (kwargs, it) in enumerate(zip(all_kwargs, its)): result = self.predict_noise(**kwargs, intermediate_tensors=it) self._pp_send_work.extend( pp_group.pipeline_isend_tensor_dict( diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py index 6766d376e9b..cf52e55c4ec 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py @@ -1072,6 +1072,7 @@ def _resolve_generation_params( height = sampling.height or 480 width = sampling.width or 832 num_frames = sampling.num_frames or 81 + chunk_frames = sampling.chunk_frames or 8 num_steps = sampling.num_inference_steps or 40 patch_size = self.transformer_config.patch_size @@ -1083,7 +1084,7 @@ def _resolve_generation_params( num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1 num_frames = max(num_frames, 1) - guidance_scale = sampling.guidance_scale if sampling.guidance_scale_provided else 4.0 + guidance_scale = sampling.guidance_scale if sampling.guidance_scale_provided else 1.0 guidance_low = guidance_scale if isinstance(guidance_scale, (int, float)) else guidance_scale[0] guidance_high = ( sampling.guidance_scale_2 @@ -1115,6 +1116,7 @@ def _resolve_generation_params( "height": height, "width": width, "num_frames": num_frames, + "chunk_frames": chunk_frames, "num_steps": num_steps, "guidance_low": guidance_low, "guidance_high": guidance_high, @@ -1242,12 +1244,12 @@ def prepare_encode( height=height, width=width, num_frames=chunk_frames, - dtype=torch.float32, - device="meta", # NOTE: stream mode; latents will be prepared per-chunk in `encode_chunk_inputs` - generator=generator, + dtype=self.vae.dtype, + device="meta", # NOTE: stream mode; latents are prepared per-chunk in `encode_chunk_inputs` + generator=None, latents=state.sampling.latents, ) - if self.expand_timesteps and raw_image is not None: + elif self.expand_timesteps and raw_image is not None: # I2V mode from diffusers.video_processor import VideoProcessor diff --git a/vllm_omni/diffusion/worker/diffusion_model_runner.py b/vllm_omni/diffusion/worker/diffusion_model_runner.py index 42ee51f91fc..9a70b6024c3 100644 --- a/vllm_omni/diffusion/worker/diffusion_model_runner.py +++ b/vllm_omni/diffusion/worker/diffusion_model_runner.py @@ -638,7 +638,11 @@ def _prepare_chunk_latents(self, state: DiffusionRequestState, layout: Layout, i def _update_decoded_chunks(self, state: DiffusionRequestState, layout: Layout) -> DiffusionOutput | None: n_finished = len(layout.finished_idxs) if n_finished > 0: - decoded = self.pipeline.post_decode(state.latents[: n_finished]) + saved = state.latents + state.latents = state.latents[:n_finished] + decoded = self.pipeline.post_decode(state) + state.latents = saved + state.extra.setdefault("decoded_chunks", []).append(decoded) state.extra["num_chunks_decoded"] = ( state.extra.get("num_chunks_decoded", 0) + n_finished From b68064ec0da0cfd624d910ad07800b3d6aa6d847 Mon Sep 17 00:00:00 2001 From: Miguel Vieira Pereira Date: Wed, 27 May 2026 12:23:08 +0000 Subject: [PATCH 48/53] Update documentation Signed-off-by: Miguel Vieira Pereira --- docs/models/supported_models.md | 1 + .../offline_inference/lingbot_world_fast.md | 49 ++++++++++++++++++ .../online_serving/lingbot_world_fast.md | 51 +++++++++++++++++++ .../lingbot_world_fast/end2end.py | 2 +- .../lingbot_world_fast/openai_client.py | 4 +- .../state_lingbot_world_fast.py | 10 ++++ 6 files changed, 115 insertions(+), 2 deletions(-) create mode 100644 docs/user_guide/examples/offline_inference/lingbot_world_fast.md create mode 100644 docs/user_guide/examples/online_serving/lingbot_world_fast.md diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index d26a579edfd..168b91d7343 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -73,5 +73,6 @@ th { |`DyninOmniForConditionalGeneration` | Dynin-Omni | `snu-aidas/Dynin-Omni` | ✅︎ | | | | | `ErnieImagePipeline` | ERNIE-Image | `baidu/ERNIE-Image`, `baidu/ERNIE-Image-Turbo` | ✅︎ | ✅︎ | ✅︎ | ✅︎ | |`HiDreamImagePipeline` | HiDream-I1-Full | `HiDream-ai/HiDream-I1-Full` | ✅︎ | ✅︎ | | | +|`LingbotWorldFastPipeline`| Lingbot-World-Fast | `robbyant/lingbot-world-fast`|✅︎ | | | | ✅︎ indicates the model is supported on that backend. Empty cells mean not listed as supported on that backend. diff --git a/docs/user_guide/examples/offline_inference/lingbot_world_fast.md b/docs/user_guide/examples/offline_inference/lingbot_world_fast.md new file mode 100644 index 00000000000..2669f6e6ecd --- /dev/null +++ b/docs/user_guide/examples/offline_inference/lingbot_world_fast.md @@ -0,0 +1,49 @@ +# Lingbot World Fast Offline Inference + +Lingbot World Fast is an autoregressive diffusion model that uses a reference image, a text prompt and a set of camera positions to generate a video. + +## Video Generation + +First, download the model weights using `examples/offline_inference/lingbot_world_fast/download_lingbot_world_fast.py`. + +The simplest way to run offline generation is to use the script on `examples/offline_inference/lingbot_world_fast/end2end.py`. The core of this script is done by: + +```python +from vllm_omni.entrypoints.omni import Omni + +if __name__ == "__main__": + omni = Omni(model="lingbot_world/lingbot-world-base-cam/Lingbot-World-Fast", model_class_name="LingbotWorldFastPipeline") + outputs = omni.generate( + { + "prompt": "A journey along the Great Wall of China", + "multi_modal_data": { + "image": "input.png", + "camera": { + "poses": np.load("path/to/poses.npy") + "intrinsics": np.load("path/to/intrinsics.npy") + } + }, + }, + OmniDiffusionSamplingParams( + height=height, + width=width, + num_frames=num_frames, + frame_rate=fps, + ), + ) + export_to_video(outputs[0], "output.png") +``` + +## Generation Parameters + +| Parameter | Type | Default | Description | +| --------------------- | ----- | ------- | ----------------------------------- | +| `height` | int | None (computed from image) | Image height in pixels | +| `width` | int | None (computed from image) | Image width in pixels | +| `num_frames` | int | 81 | Number of frames to generate | +| `fps` | int | 16 | Frames per second | +| `seed` | int | 42 | Optional random seed | +| `prompt` | str | "" | Text prompt | +| `negative_prompt` | str | None | Negative prompt | +| `image` | str | Required | Path to reference image | +|`camera-path` | str | Required | Path to folder with `poses.npy` and `intrinsics.npy`| diff --git a/docs/user_guide/examples/online_serving/lingbot_world_fast.md b/docs/user_guide/examples/online_serving/lingbot_world_fast.md new file mode 100644 index 00000000000..106f76a6a05 --- /dev/null +++ b/docs/user_guide/examples/online_serving/lingbot_world_fast.md @@ -0,0 +1,51 @@ +# Lingbot World Fast Offline Inference + +Lingbot World Fast is an autoregressive diffusion model that uses a reference image, a text prompt and a set of camera positions to generate a video. The online serving model of this model adds a feature that is not implemented in the original model: video extension. + +## Quickstart + +The easiest way to launch a server running the Lingbot World Fast model is by using the script `examples/online_serving/lingbot_world_fast/run_server.sh`. + +Once the server is launched, the client can send requests to its websocket at `/v1/realtime/world/camera`. The easiest way to interact with the server is using the script `examples/online_serving/lingbot_world_fast/openai_client.py`. Its command line options are described below. + +| Parameter | Type | Default | Description | +| --------------------- | ----- | ------- | ----------------------------------- | +| `height` | int | None (computed from image) | Image height in pixels | +| `width` | int | None (computed from image) | Image width in pixels | +| `num_frames` | int | 81 | Number of frames to generate | +| `fps` | int | 16 | Frames per second | +| `seed` | int | 42 | Optional random seed | +| `prompt` | str | "" | Text prompt | +| `negative_prompt` | str | None | Negative prompt | +| `image` | str | Required | Path to reference image | +|`camera-path` | str | Required | Path to folder with `poses.npy` and `intrinsics.npy`| +| `num-calls` | int | 1 | Makes an additional `num-calls - 1` video extension calls with `num_frames` frames | +| `num-skip-frames` | int | 4 | Extension calls have artifacts on the first couple frames. Discard them. | +| `session-id` | str | None | Session id to control whether to trigger a video extension call | + +## Video Extension + +The idea of video extension is to allow the user to generate further frames for the same video efficiently. This is done by the vllm-omni implementation by storing the KV-cache of the generated video by default. This way, if the next request uses the same session-id, the pipeline will enter extension mode. So, the newly generated frames will use the previously generated frames as context. This is done by storing the KV-cache as mentioned above. No frame information, whether in latent space or RGB values, is kept in the server. + +This feature is limited by the fact that the model has not been trained to perform this task. So, the steering capacity of the user is limited. Namely, the reference image and changes to the text prompt are ignored. The best tool the user has is to provide camera positions. In the end, video extension is more of a demonstration of the power and features of VLLM-Omni than of Lingbot World in itself. + +## API + +The server uses a websocket endpoint located at `/v1/realtime/world/camera`. It makes available two tasks: `infer` and `reset` which can be controlled by the "endpoint" key of the request. + +By default, the server uses the `infer` task, which checks the `session-id` field and compares it to the one used on the last infer call. If they are the same, it triggers an extension call at the pipeline level. Note that only the KV-cache of the last request is stored to mitigate Out of Memory problems at the GPU level. Otherwise, it generates the video from scratch. Notice that when doing an extension task, no reference image should be provided (it would be ignored anyway). + +The `reset` endpoint does not immediately evict the KV cache in the GPU, but instead it forces a reset on the next `infer` call independently of the value of `session-id`. + +The endpoint sends the resulting frames in groups of 4 to mitigate package loss problems. It is the client's role to concatenate the different frames to obtain the final video. + +## Example materials + +??? abstract "run_server.sh" + ``````sh + --8<-- "examples/online_serving/lingbot_world_fast/run_server.sh" + `````` +??? abstract "openai_client.py" + ``````sh + --8<-- "examples/online_serving/lingbot_world_fast/openai_client.py" + `````` diff --git a/examples/offline_inference/lingbot_world_fast/end2end.py b/examples/offline_inference/lingbot_world_fast/end2end.py index 452fd10b349..44d3df40c04 100644 --- a/examples/offline_inference/lingbot_world_fast/end2end.py +++ b/examples/offline_inference/lingbot_world_fast/end2end.py @@ -112,7 +112,6 @@ def main(): print(f" Video size: {width}x{height}") print(f"{'=' * 60}\n") - generation_start = time.perf_counter() # omni.generate() returns Generator[OmniRequestOutput, None, None] multi_modal_data = {"image": image} @@ -123,6 +122,7 @@ def main(): multi_modal_data["camera"] = {"poses": poses, "intrinsics": intrinsics} + generation_start = time.perf_counter() frames = omni.generate( { "prompt": args.prompt, diff --git a/examples/online_serving/lingbot_world_fast/openai_client.py b/examples/online_serving/lingbot_world_fast/openai_client.py index e7123af2b59..81beeec2d54 100644 --- a/examples/online_serving/lingbot_world_fast/openai_client.py +++ b/examples/online_serving/lingbot_world_fast/openai_client.py @@ -69,6 +69,7 @@ def generate_video(args: Namespace) -> list[np.ndarray]: "num_frames": args.num_frames, "fps": args.fps, "session_id": args.session_id, + "seed": args.seed, } video = [] @@ -157,8 +158,9 @@ def main(): parser.add_argument("--height", type=int, default=480) parser.add_argument("--fps", type=int, default=16) parser.add_argument("--num-frames", type=int, default=24) - parser.add_argument("--num-calls", type=int, default=2) + parser.add_argument("--num-calls", type=int, default=1) parser.add_argument("--num-skip-frames", type=int, default=4) + parser.add_argument("--seed", type=int, default=42, help="Random seed.") args = parser.parse_args() frames = generate_video(args) diff --git a/vllm_omni/diffusion/models/lingbot_world_fast/state_lingbot_world_fast.py b/vllm_omni/diffusion/models/lingbot_world_fast/state_lingbot_world_fast.py index e2ae2174e6a..35e00fc03fc 100644 --- a/vllm_omni/diffusion/models/lingbot_world_fast/state_lingbot_world_fast.py +++ b/vllm_omni/diffusion/models/lingbot_world_fast/state_lingbot_world_fast.py @@ -28,6 +28,7 @@ class LingbotWorldFastState: """ def __init__(self) -> None: + self.is_initialized = False self.reset() # ------------------------------------------------------------------ @@ -36,6 +37,15 @@ def __init__(self) -> None: def reset(self) -> None: """Clear all state.""" + + if self.is_initialized: + for cache in self.kv_cache: + del cache + for cache in self.crossattn_cache: + if isinstance(cache["k"], torch.Tensor): + del cache["k"] + del cache["v"] + self.kv_cache: list[torch.Tensor] | None = None self.crossattn_cache: list[dict[str, bool | torch.Tensor | None]] | None = None self.current_start_frame: int = 0 From c7c2cf839b2320b2ed51857e5c68404348b7ecb9 Mon Sep 17 00:00:00 2001 From: Mahdi Nasser <94046147+mnasser02@users.noreply.github.com> Date: Sat, 30 May 2026 01:06:08 +0200 Subject: [PATCH 49/53] revert wan changes Signed-off-by: Mahdi Nasser <94046147+mnasser02@users.noreply.github.com> --- .../scheduling_flow_unipc_multistep.py | 12 +- vllm_omni/diffusion/models/wan2_2/__init__.py | 3 +- .../models/wan2_2/pipeline_wan2_2.py | 564 +----------------- .../models/wan2_2/scheduling_wan_euler.py | 9 +- 4 files changed, 13 insertions(+), 575 deletions(-) diff --git a/vllm_omni/diffusion/models/schedulers/scheduling_flow_unipc_multistep.py b/vllm_omni/diffusion/models/schedulers/scheduling_flow_unipc_multistep.py index 3cc6fe94833..fbf32b5c9f1 100644 --- a/vllm_omni/diffusion/models/schedulers/scheduling_flow_unipc_multistep.py +++ b/vllm_omni/diffusion/models/schedulers/scheduling_flow_unipc_multistep.py @@ -161,7 +161,6 @@ def set_timesteps( sigmas: list[float] | None = None, mu: float | None = None, shift: float | None = None, - sigma_start: float = 1.0, ) -> None: """ Sets the discrete timesteps used for the diffusion chain (run before inference). @@ -177,16 +176,10 @@ def set_timesteps( Parameter for dynamic shifting. shift (`float`, *optional*): Override shift parameter. - sigma_start (`float`, defaults to 1.0): - Scales the post-shift sigmas so step 0 - lands at ``sigma_start`` instead of 1.0. """ if self.config.use_dynamic_shifting and mu is None: raise ValueError("Must pass a value for `mu` when `use_dynamic_shifting` is True") - if not 0.0 < sigma_start <= 1.0: - raise ValueError(f"sigma_start must be in (0, 1], got {sigma_start}") - if sigmas is None: assert num_inference_steps is not None sigmas = np.linspace(self.sigma_max, self.sigma_min, num_inference_steps + 1).copy()[:-1] @@ -200,9 +193,6 @@ def set_timesteps( assert isinstance(sigmas, np.ndarray) sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) - if sigma_start != 1.0: - sigmas = sigmas * float(sigma_start) - if self.config.final_sigmas_type == "sigma_min": sigma_last = self.sigma_min elif self.config.final_sigmas_type == "zero": @@ -748,4 +738,4 @@ def add_noise( return noisy_samples def __len__(self) -> int: - return self.config.num_train_timesteps + return self.config.num_train_timesteps \ No newline at end of file diff --git a/vllm_omni/diffusion/models/wan2_2/__init__.py b/vllm_omni/diffusion/models/wan2_2/__init__.py index bb2ec9b9d68..2b8c37b47d0 100644 --- a/vllm_omni/diffusion/models/wan2_2/__init__.py +++ b/vllm_omni/diffusion/models/wan2_2/__init__.py @@ -4,6 +4,7 @@ from .patch_diffusers import patch_wan_rms_norm from .pipeline_wan2_2 import ( Wan22Pipeline, + WanT2VDMD2Pipeline, create_transformer_from_config, get_wan22_post_process_func, get_wan22_pre_process_func, @@ -52,4 +53,4 @@ "WanVACETransformer3DModel", ] -patch_wan_rms_norm() +patch_wan_rms_norm() \ No newline at end of file diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py index 4cea061986d..0ee9e20fc8b 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py @@ -3,13 +3,12 @@ from __future__ import annotations -import copy import json import logging import os import time from collections.abc import Iterable -from typing import TYPE_CHECKING, Any, ClassVar, cast +from typing import Any, cast import PIL.Image import torch @@ -39,9 +38,6 @@ from vllm_omni.inputs.data import OmniTextPrompt from vllm_omni.platforms import current_omni_platform -if TYPE_CHECKING: - from vllm_omni.diffusion.worker.utils import DiffusionRequestState - logger = logging.getLogger(__name__) DEBUG_PERF = False WAN_SAMPLE_SOLVER_CHOICES = {"unipc", "euler"} @@ -263,9 +259,6 @@ def pre_process_func(request: OmniDiffusionRequest) -> OmniDiffusionRequest: class Wan22Pipeline( nn.Module, PipelineParallelMixin, CFGParallelMixin, ProgressBarMixin, DiffusionPipelineProfilerMixin ): - supports_step_execution: ClassVar[bool] = True - supports_micro_step_execution: ClassVar[bool] = True - def __init__( self, *, @@ -357,10 +350,6 @@ def __init__( model, subfolder="vae", torch_dtype=dtype, local_files_only=local_files_only ).to(self.device) - z_dim = self.vae.config.z_dim - self._latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, z_dim, 1, 1, 1) - self._latents_inv_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, z_dim, 1, 1, 1) - # Initialize transformers with correct config (weights loaded via load_weights) if load_transformer: transformer_config = load_transformer_config(model, "transformer", local_files_only) @@ -996,550 +985,15 @@ def check_inputs( if boundary_ratio is None and guidance_scale_2 is not None: raise ValueError("`guidance_scale_2` is only supported when `boundary_ratio` is set.") - # ── Micro-step execution (SupportsMicroStepExecution) Wan2.1 model (single-stage) ── - - def _extract_prompts( - self, - state: DiffusionRequestState, - ) -> tuple[str | None, str | None]: - """Extract prompt and negative prompt from *state*.""" - prompt: str | None = None - negative_prompt: str | None = None - if state.prompts: - p = state.prompts[0] - if isinstance(p, str): - prompt = p - else: - prompt = p.get("prompt") - negative_prompt = p.get("negative_prompt") - return prompt, negative_prompt - - def _resolve_generation_params( - self, - state: DiffusionRequestState, - ) -> dict[str, Any]: - """Extract and validate generation parameters from *state*. - - Returns a dict with resolved height, width, num_frames, guidance_low, - guidance_high, boundary_ratio, boundary_timestep, device, dtype, and - generator. - """ - sampling = state.sampling - - height = sampling.height or 480 - width = sampling.width or 832 - num_frames = sampling.num_frames or 81 - chunk_frames = sampling.chunk_frames or 8 - num_steps = sampling.num_inference_steps or 40 - - patch_size = self.transformer_config.patch_size - mod_value = self.vae_scale_factor_spatial * patch_size[1] - height = (height // mod_value) * mod_value - width = (width // mod_value) * mod_value - - if num_frames % self.vae_scale_factor_temporal != 1: - num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1 - num_frames = max(num_frames, 1) - - guidance_scale = sampling.guidance_scale if sampling.guidance_scale_provided else 1.0 - guidance_low = guidance_scale if isinstance(guidance_scale, (int, float)) else guidance_scale[0] - guidance_high = ( - sampling.guidance_scale_2 - if sampling.guidance_scale_2 is not None - else ( - guidance_scale[1] - if isinstance(guidance_scale, (list, tuple)) and len(guidance_scale) > 1 - else guidance_low - ) - ) - - boundary_ratio = self.boundary_ratio if self.boundary_ratio is not None else sampling.boundary_ratio - if boundary_ratio is None: - boundary_ratio = 0.875 - - device = self.device - if self.transformer is not None: - dtype = self.transformer.dtype - elif self.transformer_2 is not None: - dtype = self.transformer_2.dtype - else: - dtype = self.text_encoder.dtype - - generator = sampling.generator - if generator is None and sampling.seed is not None: - generator = torch.Generator(device=device).manual_seed(sampling.seed) - - return { - "height": height, - "width": width, - "num_frames": num_frames, - "chunk_frames": chunk_frames, - "num_steps": num_steps, - "guidance_low": guidance_low, - "guidance_high": guidance_high, - "boundary_ratio": boundary_ratio, - "boundary_timestep": boundary_ratio * 1000, # num_train_timesteps - "device": device, - "dtype": dtype, - "generator": generator, - } - - def _prepare_latent_input( - self, - state: DiffusionRequestState, - t: torch.Tensor, - dtype: torch.dtype, - ) -> tuple[torch.Tensor, torch.Tensor]: - """Prepare latent_model_input and timestep tensor for one denoise step. - Handles both T2V (passthrough) and I2V (condition blending + per-patch - timestep expansion) modes. - - Returns: - (latent_model_input, timestep_tensor) - """ - latent_condition = state.extra.get("latent_condition") - first_frame_mask = state.extra.get("first_frame_mask") - expand_timesteps = state.extra.get("expand_timesteps", False) - - if expand_timesteps and latent_condition is not None: - # I2V mode: blend condition with latents, expand timesteps per patch - latent_model_input = ((1 - first_frame_mask) * latent_condition + first_frame_mask * state.latents).to( - dtype - ) - patch_size = self.transformer_config.patch_size - patch_height = state.latents.shape[3] // patch_size[1] - patch_width = state.latents.shape[4] // patch_size[2] - patch_mask = first_frame_mask[:, :, :, :: patch_size[1], :: patch_size[2]] - patch_mask = patch_mask[:, :, :, :patch_height, :patch_width] - temp_ts = (patch_mask[0][0] * t).flatten() - timestep_tensor = temp_ts.unsqueeze(0).expand(state.latents.shape[0], -1) - else: - # T2V mode - latent_model_input = state.latents.to(dtype) - timestep_tensor = t.expand(state.latents.shape[0]) if t.ndim == 0 else t - - return latent_model_input, timestep_tensor - - def _select_model_for_timestep( - self, - t: torch.Tensor, - boundary_timestep: float | None, - ) -> tuple[nn.Module, float]: - """Return (transformer, guidance_scale) for a given timestep.""" - guidance_low = self._guidance_scale or 4.0 - guidance_high = self._guidance_scale_2 or guidance_low - if boundary_timestep is not None and t < boundary_timestep: - model = self.transformer_2 if self.transformer_2 is not None else self.transformer - return model, guidance_high - model = self.transformer if self.transformer is not None else self.transformer_2 - return model, guidance_low - - def prepare_encode( - self, - state: DiffusionRequestState, - **kwargs: Any, - ) -> DiffusionRequestState: - """One-time request setup: encode prompt, prepare latents, init scheduler.""" - # Extract prompts - prompt, negative_prompt = self._extract_prompts(state) - - params = self._resolve_generation_params(state) - height = params["height"] - width = params["width"] - num_frames = params["num_frames"] - chunk_frames = params["chunk_frames"] - device = params["device"] - dtype = params["dtype"] - generator = params["generator"] - guidance_low = params["guidance_low"] - guidance_high = params["guidance_high"] - - # Store guidance for properties and for denoise_step model selection - self._guidance_scale = guidance_low - self._guidance_scale_2 = guidance_high - - # Encode prompt - do_cfg = guidance_low > 1.0 or guidance_high > 1.0 - prompt_embeds, negative_prompt_embeds = self.encode_prompt( - prompt=prompt, - negative_prompt=negative_prompt, - do_classifier_free_guidance=do_cfg, - num_videos_per_prompt=state.sampling.num_outputs_per_prompt or 1, - max_sequence_length=state.sampling.max_sequence_length or 512, - device=device, - dtype=dtype, - ) - - # Multi-modal inputs (I2V image, V2V video) - multi_modal_data = ( - state.prompts[0].get("multi_modal_data", {}) - if state.prompts and not isinstance(state.prompts[0], str) - else None - ) - raw_image = multi_modal_data.get("image", None) if multi_modal_data else None - if isinstance(raw_image, list): - raw_image = raw_image[0] - - v2v_video = multi_modal_data.get("video", None) if multi_modal_data else None - noise_scale = float((state.sampling.extra_args or {}).get("noise_scale", 0.8)) - sigma_start = noise_scale if (v2v_video is not None and 0.0 < noise_scale < 1.0) else 1.0 - - # Scheduler - self.scheduler.set_timesteps(params["num_steps"], device=device, sigma_start=sigma_start) - req_scheduler = copy.deepcopy(self.scheduler) - - latent_condition = None - first_frame_mask = None - - if v2v_video is not None: - # V2V mode - num_channels_latents = self.transformer_config.in_channels - latents = self.prepare_latents( - batch_size=prompt_embeds.shape[0], - num_channels_latents=num_channels_latents, - height=height, - width=width, - num_frames=chunk_frames, - dtype=self.vae.dtype, - device="meta", # NOTE: stream mode; latents are prepared per-chunk in `encode_chunk_inputs` - generator=None, - latents=state.sampling.latents, - ) - elif self.expand_timesteps and raw_image is not None: - # I2V mode - from diffusers.video_processor import VideoProcessor - - video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) - if isinstance(raw_image, str): - image = PIL.Image.open(raw_image) - else: - image = cast(PIL.Image.Image | torch.Tensor, raw_image) - - if isinstance(image, PIL.Image.Image): - image = image.resize((width, height), PIL.Image.Resampling.LANCZOS) - image_tensor = video_processor.preprocess(image, height=height, width=width) - else: - image_tensor = image +# --------------------------------------------------------------------------- +# DMD2-distilled variant +# --------------------------------------------------------------------------- - num_channels_latents = self.transformer_config.out_channels - batch_size = prompt_embeds.shape[0] - - latents = self.prepare_latents( - batch_size=batch_size, - num_channels_latents=num_channels_latents, - height=height, - width=width, - num_frames=num_frames, - dtype=torch.float32, - device=device, - generator=generator, - latents=state.sampling.latents, - ) - - image_tensor = image_tensor.unsqueeze(2).to(device=device, dtype=self.vae.dtype) - latent_condition = retrieve_latents(self.vae.encode(image_tensor), sample_mode="argmax") - latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1) - - latents_mean = ( - torch.tensor(self.vae.config.latents_mean) - .view(1, self.vae.config.z_dim, 1, 1, 1) - .to(latent_condition.device, latent_condition.dtype) - ) - latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( - latent_condition.device, latent_condition.dtype - ) - latent_condition = ((latent_condition - latents_mean) * latents_std).to(torch.float32) - - num_latent_frames = latents.shape[2] - latent_height = latents.shape[3] - latent_width = latents.shape[4] - first_frame_mask = torch.ones( - 1, 1, num_latent_frames, latent_height, latent_width, dtype=torch.float32, device=device - ) - first_frame_mask[:, :, 0] = 0 - else: - # T2V mode - num_channels_latents = self.transformer_config.in_channels - latents = self.prepare_latents( - batch_size=prompt_embeds.shape[0], - num_channels_latents=num_channels_latents, - height=height, - width=width, - num_frames=num_frames, - dtype=torch.float32, - device=device, - generator=generator, - latents=state.sampling.latents, - ) - - # Populate state - state.prompt_embeds = prompt_embeds - state.negative_prompt_embeds = negative_prompt_embeds - state.latents = latents - state.timesteps = req_scheduler.timesteps - state.step_index = 0 - state.scheduler = req_scheduler - state.do_true_cfg = do_cfg and negative_prompt_embeds is not None - state.guidance = torch.tensor([guidance_low], device=device) - - state.extra["guidance_low"] = guidance_low - state.extra["guidance_high"] = guidance_high - state.extra["boundary_timestep"] = params["boundary_timestep"] - state.extra["expand_timesteps"] = self.expand_timesteps - state.extra["latent_condition"] = latent_condition - state.extra["first_frame_mask"] = first_frame_mask - state.extra["height"] = height - state.extra["width"] = width - - return state - - def set_pp_recv_dict_buffers(self, state: DiffusionRequestState) -> None: - from vllm_omni.diffusion.distributed.parallel_state import get_pp_group - - pp_group = get_pp_group() - if pp_group.world_size == 1: - return - - slo_fps = getattr(state.sampling, "slo_fps", None) - slo_max_batch = getattr(state.sampling, "slo_max_batch", 1) - slo_max_batch = max(1, slo_max_batch if slo_fps else 1) - num_inference_steps = state.sampling.num_inference_steps or 4 - - _, channels, latent_chunk_frames, height, width = state.latents.shape - p_t, p_h, p_w = self.transformer_config.patch_size - inner_dim = self.transformer_config.num_attention_heads * self.transformer_config.attention_head_dim - it_dtype = (self.transformer or self.transformer_2).dtype - latents_dtype = state.latents.dtype - - seq_len = (latent_chunk_frames // p_t) * (height // p_h) * (width // p_w) - - cfg_branches = 2 if state.do_true_cfg else 1 - - for batch_size in range(1, slo_max_batch * num_inference_steps + 1): - latents_template = { - "latents": torch.empty(batch_size, channels, latent_chunk_frames, height, width, dtype=latents_dtype, device="meta") - } - it_template = { - "hidden_states": torch.empty(batch_size, seq_len, inner_dim, dtype=it_dtype, device="meta") - } - pp_group.set_recv_dict_buffer("latents", -1, latents_template, batch_size=batch_size) - for seg in range(cfg_branches): - pp_group.set_recv_dict_buffer("intermediate", seg, it_template, batch_size=batch_size) - - def denoise_step( - self, - state: DiffusionRequestState, - batch_size: int = 1, - **kwargs: Any, - ) -> torch.Tensor | None: - """Run one denoising iteration. - - When ``state.batched_timesteps`` is set (stream-batch path with - multiple chunks at different ``step_index`` fused along dim 0), - it overrides ``current_timestep`` and ``_prepare_latent_input`` - forwards the per-row timesteps directly to the transformer. - Model selection uses the per-row max so a batch straddling the - high/low noise boundary picks the high-noise transformer (NOTE correct - for single-stage Wan2.1). - """ - t = state.batched_timesteps if state.batched_timesteps is not None else state.current_timestep - self._current_timestep = t - - boundary_timestep = state.extra.get("boundary_timestep") - t_select = t.max() if t.ndim > 0 else t - current_model, current_guidance_scale = self._select_model_for_timestep(t_select, boundary_timestep) - - latent_model_input, timestep = self._prepare_latent_input(state, t, current_model.dtype) - - do_true_cfg = current_guidance_scale > 1.0 and state.negative_prompt_embeds is not None - - # Stream-batch path fuses B chunks along dim 0 - B = latent_model_input.shape[0] - encoder_pos = state.prompt_embeds - if encoder_pos is not None and encoder_pos.shape[0] == 1 and B > 1: - encoder_pos = encoder_pos.expand(B, *encoder_pos.shape[1:]).contiguous() - encoder_neg = state.negative_prompt_embeds - if encoder_neg is not None and encoder_neg.shape[0] == 1 and B > 1: - encoder_neg = encoder_neg.expand(B, *encoder_neg.shape[1:]).contiguous() - - positive_kwargs = { - "hidden_states": latent_model_input, - "timestep": timestep, - "encoder_hidden_states": encoder_pos, - "attention_kwargs": {}, - "return_dict": False, - "current_model": current_model, - } - negative_kwargs = ( - { - "hidden_states": latent_model_input, - "timestep": timestep, - "encoder_hidden_states": encoder_neg, - "attention_kwargs": {}, - "return_dict": False, - "current_model": current_model, - } - if do_true_cfg - else None - ) - preposted_its = state.extra.pop("preposted_its", None) - - return self.predict_noise_maybe_with_cfg( - do_true_cfg=do_true_cfg, - true_cfg_scale=current_guidance_scale, - positive_kwargs=positive_kwargs, - negative_kwargs=negative_kwargs, - buf_idx=state.step_index % 2, - batch_size=batch_size, - preposted_its=preposted_its, - ) - - def prefetch_tensors(self, state: DiffusionRequestState, batch_size: int = 1) -> None: - """Prefetch next-step tensors: latents on first rank, ITs on others.""" - do_true_cfg = state.do_true_cfg - buf_idx = state.step_index % 2 - - preposted = self.prefetch_tensors_maybe_with_cfg( - do_true_cfg=do_true_cfg, - buf_idx=buf_idx, - batch_size=batch_size, - ) - - if isinstance(preposted, AsyncLatents): - state.latents = preposted - elif preposted is not None: - state.extra["preposted_its"] = preposted - - def step_scheduler( - self, - state: DiffusionRequestState, - noise_pred: torch.Tensor, - *, - per_request_scheduler: Any | list[Any] | None = None, - batch_size: int = 1, - **kwargs: Any, - ) -> None: - """Run one scheduler step: update latents and advance step_index. - - ``per_request_scheduler`` may be a single scheduler (B=1) or a list of - per-chunk schedulers (B>1, last rank loops one ``step()`` per row). - """ - t = state.batched_timesteps if state.batched_timesteps is not None else state.current_timestep - do_true_cfg = state.do_true_cfg - - if per_request_scheduler is None: - per_request_scheduler = state.scheduler - - state.latents = self.scheduler_step_maybe_with_cfg( - noise_pred, - t, - state.latents, - do_true_cfg, - per_request_scheduler=per_request_scheduler, - batch_size=batch_size, - ) - - state.step_index += 1 - - def post_decode( - self, - state: DiffusionRequestState, - **kwargs: Any, - ) -> DiffusionOutput: - """Decode final latents after denoising completes.""" - self._sync_pp_send() - self._current_timestep = None - - # I2V: blend final latents with condition - latent_condition = state.extra.get("latent_condition") - first_frame_mask = state.extra.get("first_frame_mask") - if state.extra.get("expand_timesteps") and latent_condition is not None: - state.latents = (1 - first_frame_mask) * latent_condition + first_frame_mask * state.latents - - latents = state.latents.to(self.vae.dtype) - - latents_mean = self._latents_mean.to(latents.device, latents.dtype) - latents_inv_std = self._latents_inv_std.to( - latents.device, latents.dtype - ) - - latents = latents / latents_inv_std + latents_mean - output = self.vae.decode(latents, return_dict=False)[0] - - return DiffusionOutput( - output=output, - stage_durations=self.stage_durations if hasattr(self, "stage_durations") else None, - ) - - def encode_chunk_inputs( - self, - state: DiffusionRequestState, - new_idxs: list[int], - ) -> torch.Tensor: - """Streaming V2V initial latents (StreamDiffusionV2-style). - - For each newly admitted chunk: VAE-encode the source frames at - ``chunk_idx * chunk_frames : (idx+1) * chunk_frames``, normalize, and - linearly blend with Gaussian noise via ``extra_args["noise_scale"]`` - (default 0.8). The transformer then runs the full schedule starting at - ``step_index=0`` on this partially-noised latent. - """ - batch_size = len(new_idxs) - noise_scale = float((state.sampling.extra_args or {}).get("noise_scale", 0.8)) - chunk_frames = state.sampling.chunk_frames - prompt = state.prompts[0] if state.prompts else None - video = None - if prompt is not None and not isinstance(prompt, str): - video = (prompt.get("multi_modal_data") or {}).get("video") - if video is None: - raise ValueError( - "encode_chunk_inputs requires V2V source frames in prompts[0]['multi_modal_data']['video']" - ) - - latents_mean = self._latents_mean.to(self.device) - latents_inv_std = self._latents_inv_std.to(self.device) - - controls: list[torch.Tensor] = [] - for idx in new_idxs: - start = idx * chunk_frames - end = start + chunk_frames - if isinstance(video, list): - frames = torch.stack(list(video[start:end]), dim=0) - else: - frames = video[start:end] - - if frames.shape[0] < chunk_frames: - pad = torch.zeros( - (chunk_frames - frames.shape[0], *frames.shape[1:]), - dtype=frames.dtype, device=frames.device, - ) - frames = torch.cat([frames, pad], dim=0) - - # [n, C, H, W] -> [1, C, n, H, W] - controls.append( - frames.permute(1, 0, 2, 3) - .unsqueeze(0) - .to(device=self.device, dtype=self.vae.dtype) - ) - control = torch.cat(controls, dim=0) - - latents = self.prepare_latents( - batch_size=batch_size, - num_channels_latents=self.transformer_config.in_channels, - height=state.extra["height"], - width=state.extra["width"], - num_frames=chunk_frames, - dtype=self.vae.dtype, - device=self.device, - generator=state.sampling.generator, - latents=None, - ) - latent_condition = retrieve_latents(self.vae.encode(control), sample_mode="argmax") - latent_condition = ( - (latent_condition.float() - latents_mean.to(latent_condition.dtype)) - * latents_inv_std.to(latent_condition.dtype) - ).to(self.vae.dtype) +class WanT2VDMD2Pipeline(DMD2PipelineMixin, Wan22Pipeline): + """Wan 2.x T2V pipeline for FastGen DMD2-distilled models.""" - return latents * noise_scale + latent_condition * (1.0 - noise_scale) + def __init__(self, *, od_config: OmniDiffusionConfig, prefix: str = ""): + super().__init__(od_config=od_config, prefix=prefix) + self.__init_dmd2__() \ No newline at end of file diff --git a/vllm_omni/diffusion/models/wan2_2/scheduling_wan_euler.py b/vllm_omni/diffusion/models/wan2_2/scheduling_wan_euler.py index 0a7651113c7..4a285b8b12f 100644 --- a/vllm_omni/diffusion/models/wan2_2/scheduling_wan_euler.py +++ b/vllm_omni/diffusion/models/wan2_2/scheduling_wan_euler.py @@ -96,7 +96,6 @@ def set_timesteps( self, num_inference_steps: int, device: torch.device | str | int | None = None, - sigma_start: float = 1.0, **kwargs, # noqa: ARG002 - kept for scheduler API compatibility ) -> None: timesteps = _get_timesteps( @@ -108,12 +107,6 @@ def set_timesteps( device=device or self.device, ) self.set_shift(self._shift) - # scale shifted sigmas so step 0 lands at sigma_start - if sigma_start != 1.0: - if not 0.0 < sigma_start <= 1.0: - raise ValueError(f"sigma_start must be in (0, 1], got {sigma_start}") - self.sigmas = self.sigmas * float(sigma_start) - self.timesteps = self.sigmas[:-1] * self.num_train_timesteps self._step_index = None self._begin_index = None @@ -151,4 +144,4 @@ def step( return WanEulerSchedulerOutput(prev_sample=prev_sample) def __len__(self) -> int: - return self.num_train_timesteps + return self.num_train_timesteps \ No newline at end of file From 21a1cbd0327468e441997d4777b916ee35f831aa Mon Sep 17 00:00:00 2001 From: Mahdi Nasser <94046147+mnasser02@users.noreply.github.com> Date: Sat, 30 May 2026 01:10:57 +0200 Subject: [PATCH 50/53] fixes Signed-off-by: Mahdi Nasser <94046147+mnasser02@users.noreply.github.com> --- vllm_omni/diffusion/diffusion_engine.py | 8 ++ .../distributed/group_coordinator.py | 45 ++++----- .../distributed/pipeline_parallel.py | 31 +++++-- vllm_omni/diffusion/registry.py | 24 +++++ .../diffusion/sched/stream_batch_scheduler.py | 91 ++++++------------- .../worker/diffusion_model_runner.py | 33 ++++--- vllm_omni/diffusion/worker/utils.py | 1 + vllm_omni/inputs/data.py | 3 +- 8 files changed, 128 insertions(+), 108 deletions(-) diff --git a/vllm_omni/diffusion/diffusion_engine.py b/vllm_omni/diffusion/diffusion_engine.py index b15b0d7bb29..075581b30f9 100644 --- a/vllm_omni/diffusion/diffusion_engine.py +++ b/vllm_omni/diffusion/diffusion_engine.py @@ -41,6 +41,7 @@ from vllm_omni.diffusion.worker.utils import BatchRunnerOutput, RunnerOutput from vllm_omni.inputs.data import OmniDiffusionSamplingParams, OmniTextPrompt from vllm_omni.outputs import OmniRequestOutput +from vllm_omni.diffusion.registry import apply_required_sampling_overrides logger = init_logger(__name__) @@ -599,6 +600,10 @@ def make_engine( return DiffusionEngine(config, scheduler=scheduler) def add_request(self, request: OmniDiffusionRequest) -> str: + apply_required_sampling_overrides( + request.sampling_params, self.od_config.model_class_name, + ) + with self._cv: if self._closed: raise RuntimeError("DiffusionEngine is closed.") @@ -627,6 +632,9 @@ async def async_add_req_and_wait_for_response(self, request: OmniDiffusionReques return await self.get_result(sched_req_id) def add_req_and_wait_for_response(self, request: OmniDiffusionRequest) -> DiffusionOutput: + apply_required_sampling_overrides( + request.sampling_params, self.od_config.model_class_name, + ) with self._rpc_lock: if self._closed: raise RuntimeError("DiffusionEngine is closed.") diff --git a/vllm_omni/diffusion/distributed/group_coordinator.py b/vllm_omni/diffusion/distributed/group_coordinator.py index 1bde474e6f4..2e79341ab8f 100644 --- a/vllm_omni/diffusion/distributed/group_coordinator.py +++ b/vllm_omni/diffusion/distributed/group_coordinator.py @@ -976,13 +976,10 @@ def _isend_dict_schema( send_size_tensor = torch.tensor( [payload_tensor.numel()], device=self.device, dtype=torch.int64 ) - # batch_isend_irecv (not plain isend) — plain P2P on size-2 PG - # triggers lazy sub-comm creation that requires the peer present. - ops = [ - torch.distributed.P2POp(torch.distributed.isend, send_size_tensor, self.next_rank, send_group), - torch.distributed.P2POp(torch.distributed.isend, payload_tensor, self.next_rank, send_group), + handles = [ + torch.distributed.isend(send_size_tensor, dst=self.next_rank, group=send_group), + torch.distributed.isend(payload_tensor, dst=self.next_rank, group=send_group), ] - handles = list(torch.distributed.batch_isend_irecv(ops)) return handles, [send_size_tensor, payload_tensor] def _recv_dict_schema(self) -> list[tuple[str, Any]]: @@ -993,15 +990,9 @@ def _recv_dict_schema(self) -> list[tuple[str, Any]]: self.device_groups[(self.rank_in_group + 1) % 2] if self.world_size == 2 else self.device_group ) recv_size_tensor = torch.empty(1, device=self.device, dtype=torch.int64) - for req in torch.distributed.batch_isend_irecv( - [torch.distributed.P2POp(torch.distributed.irecv, recv_size_tensor, self.prev_rank, recv_group)] - ): - req.wait() + torch.distributed.recv(recv_size_tensor, src=self.prev_rank, group=recv_group) recv_payload = torch.empty(int(recv_size_tensor.item()), device=self.device, dtype=torch.uint8) - for req in torch.distributed.batch_isend_irecv( - [torch.distributed.P2POp(torch.distributed.irecv, recv_payload, self.prev_rank, recv_group)] - ): - req.wait() + torch.distributed.recv(recv_payload, src=self.prev_rank, group=recv_group) return pickle.loads(recv_payload.cpu().numpy().tobytes()) def pipeline_send(self, tensor: torch.Tensor, name: str = "latent", segment_idx: int = -1) -> None: @@ -1071,6 +1062,11 @@ def pipeline_isend_tensor_dict( compute_done = self._record_compute_event() comms = self.comms_stream + group = ( + self.device_groups[self.rank_in_group % 2] + if self.world_size == 2 + else self.device_group + ) with self._comms_stream_ctx(): if comms is not None and compute_done is not None: comms.wait_event(compute_done) @@ -1080,7 +1076,9 @@ def pipeline_isend_tensor_dict( tensor = tensor.contiguous() if tensor.is_cuda and comms is not None: tensor.record_stream(comms) - handles.append(self._pipeline_isend(tensor)) + handles.append( + torch.distributed.isend(tensor, dst=self.next_rank, group=group) + ) return handles def pipeline_irecv_tensor_dict( @@ -1118,6 +1116,11 @@ def pipeline_irecv_tensor_dict( tensor_dict: dict[str, Any] = {} handles: list[torch.distributed.Work] = [] + group = ( + self.device_groups[(self.rank_in_group + 1) % 2] + if self.world_size == 2 + else self.device_group + ) with self._comms_stream_ctx(): for k, value in metadata_list: if isinstance(value, TensorMetadata): @@ -1131,7 +1134,9 @@ def pipeline_irecv_tensor_dict( tensor = buffers[k] if tensor.is_cuda and comms is not None: tensor.record_stream(comms) - handles.append(self._pipeline_irecv(tensor)) + handles.append( + torch.distributed.irecv(tensor, src=self.prev_rank, group=group) + ) _update_nested_dict(tensor_dict, k, tensor) else: _update_nested_dict(tensor_dict, k, value) @@ -1158,16 +1163,12 @@ def get_pipeline_recv_data(self, idx: int = -1, name: str = "latent") -> torch.T return self.recv_buffer[name][idx] def _pipeline_irecv(self, tensor: torch.tensor): - # batch_isend_irecv (not plain irecv) — plain P2P on size-2 PG - # triggers lazy sub-comm creation that requires the peer present. group = self.device_groups[(self.rank_in_group + 1) % 2] if self.world_size == 2 else self.device_group - op = torch.distributed.P2POp(torch.distributed.irecv, tensor, self.prev_rank, group) - return torch.distributed.batch_isend_irecv([op])[0] + return torch.distributed.irecv(tensor, src=self.prev_rank, group=group) def _pipeline_isend(self, tensor: torch.tensor): group = self.device_groups[self.rank_in_group % 2] if self.world_size == 2 else self.device_group - op = torch.distributed.P2POp(torch.distributed.isend, tensor, self.next_rank, group) - return torch.distributed.batch_isend_irecv([op])[0] + return torch.distributed.isend(tensor, dst=self.next_rank, group=group) def set_skip_tensor_recv_buffer( self, diff --git a/vllm_omni/diffusion/distributed/pipeline_parallel.py b/vllm_omni/diffusion/distributed/pipeline_parallel.py index 576ea22082c..9f4407ef7f8 100644 --- a/vllm_omni/diffusion/distributed/pipeline_parallel.py +++ b/vllm_omni/diffusion/distributed/pipeline_parallel.py @@ -280,23 +280,41 @@ def scheduler_step_maybe_with_cfg( latents: torch.Tensor | tuple[torch.Tensor, ...], do_true_cfg: bool, per_request_scheduler: Any | list[Any] | None = None, + generator: torch.Generator | None = None, batch_size: int = 1, + receive_latents: bool = True, + buf_idx: int = 0, ) -> torch.Tensor | tuple[torch.Tensor, ...]: """ Drop-in replacement for scheduler_step_maybe_with_cfg that also handles PP. - Only the last rank runs the scheduler (it already has noise_pred) and - sends the result back to rank 0. + Only the last rank runs the scheduler (it already has noise_pred); the result + is sent to rank 0 which needs it for the next forward pass. + + If `receive_latents` is True, returns a ``AsyncLatents`` on rank 0 that transparently defers + ``handle.wait()`` until the tensor is actually consumed (via attribute + access or a torch operation), keeping the rank non-blocking after the + ``irecv`` is posted. """ if get_pipeline_parallel_world_size() == 1: - return self._scheduler_step_local(noise_pred, t, latents, do_true_cfg, per_request_scheduler) + return self._scheduler_step_local( + noise_pred, t, latents, do_true_cfg, per_request_scheduler, generator, + ) pp_group = get_pp_group() if pp_group.is_last_rank: - latents = self._scheduler_step_local(noise_pred, t, latents, do_true_cfg, per_request_scheduler) + latents = self._scheduler_step_local( + noise_pred, t, latents, do_true_cfg, per_request_scheduler, generator, + ) self._pp_send_work = pp_group.pipeline_isend_tensor_dict( {"latents": latents}, name="latents", batch_size=batch_size, ) + if pp_group.is_first_rank and receive_latents: + latents = AsyncLatents( + *pp_group.pipeline_irecv_tensor_dict( + name="latents", buf_idx=buf_idx, batch_size=batch_size, + ) + ) return latents def _scheduler_step_local( @@ -306,18 +324,19 @@ def _scheduler_step_local( latents: torch.Tensor, do_true_cfg: bool, per_request_scheduler: Any | list[Any] | None, + generator: torch.Generator | None = None, ) -> torch.Tensor: """Run scheduler.step on this rank — single call or per-chunk loop.""" if not isinstance(per_request_scheduler, list): return super().scheduler_step_maybe_with_cfg( - noise_pred, t, latents, do_true_cfg, per_request_scheduler, + noise_pred, t, latents, do_true_cfg, per_request_scheduler, generator=generator, ) new_rows: list[torch.Tensor] = [] for i, sched in enumerate(per_request_scheduler): t_i = t[i] if t.ndim > 0 else t new_rows.append( super().scheduler_step_maybe_with_cfg( - noise_pred[i:i + 1], t_i, latents[i:i + 1], do_true_cfg, sched, + noise_pred[i:i + 1], t_i, latents[i:i + 1], do_true_cfg, sched, generator=generator, ) ) return torch.cat(new_rows, dim=0) diff --git a/vllm_omni/diffusion/registry.py b/vllm_omni/diffusion/registry.py index 063a2105cd4..e268b0c40d1 100644 --- a/vllm_omni/diffusion/registry.py +++ b/vllm_omni/diffusion/registry.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import importlib +from typing import Any import torch.nn as nn from vllm.logger import init_logger @@ -596,3 +597,26 @@ def get_diffusion_pre_process_func(od_config: OmniDiffusionConfig): return None # Return None if no pre-processing function is registered (for backward compatibility) func_name = _DIFFUSION_PRE_PROCESS_FUNCS[od_config.model_class_name] return _load_process_func(od_config, func_name) + + +_STREAM_BATCH_OVERRIDE_ATTRS: dict[str, str] = { + "chunk_frames": "STREAM_BATCH_CHUNK_FRAMES", + "num_inference_steps": "STREAM_BATCH_NUM_INFERENCE_STEPS", +} + +def apply_required_sampling_overrides(sampling: Any, model_class_name: str) -> None: + """Overwrite sampling-param fields that the model has hard requirements on.""" + pipeline_cls = DiffusionModelRegistry._try_load_model_cls(model_class_name) + if pipeline_cls is None: + return + for field, attr in _STREAM_BATCH_OVERRIDE_ATTRS.items(): + required = getattr(pipeline_cls, attr, None) + if required is None: + continue + current = getattr(sampling, field, None) + if current != required: + logger.warning( + "%s requires sampling.%s=%s, got %r. Overriding.", + model_class_name, field, required, current, + ) + setattr(sampling, field, required) diff --git a/vllm_omni/diffusion/sched/stream_batch_scheduler.py b/vllm_omni/diffusion/sched/stream_batch_scheduler.py index 6c27ca39065..9970372c976 100644 --- a/vllm_omni/diffusion/sched/stream_batch_scheduler.py +++ b/vllm_omni/diffusion/sched/stream_batch_scheduler.py @@ -1,11 +1,10 @@ """Temporal-pipeline-parallel scheduler for streaming chunked diffusion. Each ``schedule()`` call corresponds to one micro-step. The pipeline is modeled -as ``pp_size`` per-rank chunk queues plus a transient ``returning`` queue. -At each schedule(), chunks at rank N-1 drain (finished -> Layout finished -slice, otherwise -> returning), queues shift one rank, and rank 0 receives the -returning chunks plus B fresh admits drawn from the source video frames in -``prompts[0]["multi_modal_data"]["video"]``. +as ``pp_size`` per-rank chunk queues. At each schedule(), chunks at rank N-1 +drain (finished -> Layout finished slice, otherwise -> circulating back to +rank 0), queues shift one rank, and rank 0 receives the circulating chunks +plus B fresh admits up to the request's output chunk target. """ from __future__ import annotations @@ -14,7 +13,6 @@ from dataclasses import dataclass, field from typing import TYPE_CHECKING -import torch from vllm.logger import init_logger from vllm_omni.diffusion.data import OmniDiffusionConfig @@ -33,26 +31,6 @@ logger = init_logger(__name__) -def _video_frame_count(request: OmniDiffusionRequest) -> int: - """Number of frames currently in ``prompts[0]["multi_modal_data"]["video"]``.""" - if not request.prompts: - return 0 - prompt = request.prompts[0] - if isinstance(prompt, str): - return 0 - multi_modal = prompt.get("multi_modal_data") or {} - video = multi_modal.get("video") - if video is None: - return 0 - if isinstance(video, torch.Tensor): - return int(video.shape[0]) - if isinstance(video, list): - return len(video) - raise TypeError( - f"multi_modal_data['video'] must be a Tensor or list of Tensors; got {type(video).__name__}." - ) - - @dataclass class _InFlightChunk: chunk_idx: int @@ -64,22 +42,17 @@ class _Progress: sched_req_id: str pp_size: int chunk_frames: int - num_frames: int + num_chunks: int num_steps: int - frames_committed: int = 0 next_chunk_idx: int = 0 batch_size: int = 0 - + # chunks that will be processed by rank r at the current micro-step chunks_at: list[deque[_InFlightChunk]] = field(default_factory=list) # rank r's layout — constructed at rank 0 and shifted forward each step layouts_at: list[Layout] = field(default_factory=list) - @property - def output_chunks_target(self) -> int: - return self.num_frames // self.chunk_frames - @dataclass class _SLOReqState: @@ -87,6 +60,7 @@ class _SLOReqState: max_batch: int chunk_frames: int batch_size: int = 1 + warmed_up: bool = False class _SLOController: @@ -116,9 +90,14 @@ def get_target(self, sched_req_id: str) -> int: st = self._reqs.get(sched_req_id) return st.batch_size if st is not None else 1 + def mark_warmed_up(self, sched_req_id: str) -> None: + st = self._reqs.get(sched_req_id) + if st is not None: + st.warmed_up = True + def observe(self, sched_req_id: str, latency_ns: int | None, b_current: int | None) -> None: st = self._reqs.get(sched_req_id) - if st is None or latency_ns is None or latency_ns <= 0 or b_current is None or b_current <= 0: + if st is None or not st.warmed_up or latency_ns is None or latency_ns <= 0 or b_current is None or b_current <= 0: return budget = (b_current * st.chunk_frames / st.slo_fps) * 1e9 @@ -146,12 +125,11 @@ class StreamBatchScheduler(_BaseScheduler): Per micro-step: 1. Promote waiting requests (handled by the base class). 2. Drain rank N-1: finished chunks -> finished slice in - Layout, otherwise -> returning queue. + Layout, otherwise -> circulating back to rank 0. 3. Shift per-rank queues by one (rank r <- rank r-1). - 4. Rank 0 = returning + B fresh admits, where - `B = min(B_target, queue_chunks_available, output_chunks_remaining)`. - 5. Emit per-rank assignment with Layout attached to every RankTask. Flip - req state RUNNING -> BLOCKED when admission is starved on input. + 4. Rank 0 = circulating + B fresh admits, where + `B = min(B_target, output_chunks_remaining)`. + 5. Emit per-rank assignment with Layout attached to every RankTask. """ def __init__(self) -> None: @@ -184,8 +162,8 @@ def add_request(self, request: OmniDiffusionRequest) -> str: raise ValueError( f"chunk_frames must be a positive int when stream_batch=True, got {sampling.chunk_frames}" ) - if sampling.num_frames is None or sampling.num_frames <= 0: - raise ValueError(f"num_frames must be a positive int, got {sampling.num_frames}") + if sampling.num_chunks is None or sampling.num_chunks <= 0: + raise ValueError(f"num_chunks must be a positive int, got {sampling.num_chunks}") if sampling.num_inference_steps is None or sampling.num_inference_steps <= 0: raise ValueError( f"num_inference_steps must be a positive int, got {sampling.num_inference_steps}" @@ -216,13 +194,13 @@ def schedule(self) -> DiffusionSchedulerOutput: def _init_progress(self, sched_req_id: str, req: OmniDiffusionRequest) -> None: sampling = req.sampling_params chunk_frames = sampling.chunk_frames - num_frames = sampling.num_frames + num_chunks = sampling.num_chunks num_steps = sampling.num_inference_steps self._progress[sched_req_id] = _Progress( sched_req_id=sched_req_id, chunk_frames=chunk_frames, - num_frames=num_frames, + num_chunks=num_chunks, num_steps=num_steps, pp_size=self.pp_size, chunks_at=[deque() for _ in range(self.pp_size)], @@ -241,8 +219,8 @@ def _init_progress(self, sched_req_id: str, req: OmniDiffusionRequest) -> None: logger.debug( "StreamBatchScheduler initialized progress for %s " - "(chunk_frames=%d, num_frames=%d, num_steps=%d, slo_fps=%s, pp_size=%d)", - sched_req_id, chunk_frames, num_frames, num_steps, sampling.slo_fps, self.pp_size, + "(chunk_frames=%d, num_chunks=%d, num_steps=%d, slo_fps=%s, pp_size=%d)", + sched_req_id, chunk_frames, num_chunks, num_steps, sampling.slo_fps, self.pp_size, ) def _advance_chunk_pipeline(self, progress: _Progress) -> None: @@ -272,18 +250,14 @@ def _advance_chunk_pipeline(self, progress: _Progress) -> None: for chunk in circulating: progress.chunks_at[0].append(chunk) - state = self.get_request_state(progress.sched_req_id) - available_frames = _video_frame_count(state.req) if state is not None else 0 - queue_chunks = max(0, (available_frames - progress.frames_committed) // progress.chunk_frames) - output_chunks_remaining = progress.output_chunks_target - progress.next_chunk_idx + output_chunks_remaining = progress.num_chunks - progress.next_chunk_idx b_target = self._slo.get_target(progress.sched_req_id) - batch_size = min(b_target, queue_chunks, output_chunks_remaining) + batch_size = min(b_target, output_chunks_remaining) new_idxs: list[int] = [] for _ in range(batch_size): chunk_idx = progress.next_chunk_idx progress.next_chunk_idx += 1 - progress.frames_committed += progress.chunk_frames progress.chunks_at[0].append(_InFlightChunk(chunk_idx=chunk_idx)) new_idxs.append(chunk_idx) progress.batch_size = batch_size @@ -295,19 +269,8 @@ def _advance_chunk_pipeline(self, progress: _Progress) -> None: new_idxs=new_idxs, ) - # 5. Flip RUNNING -> BLOCKED if input-starved and we still owe output. - if ( - batch_size == 0 - and output_chunks_remaining > 0 - and queue_chunks == 0 - and progress.sched_req_id in self._running - ): - self.block_request(progress.sched_req_id) - logger.debug( - "StreamBatchScheduler: %s BLOCKED on input " - "(committed_frames=%d, target_frames=%d, available_frames=%d)", - progress.sched_req_id, progress.frames_committed, progress.num_frames, available_frames, - ) + if finished_idxs: + self._slo.mark_warmed_up(progress.sched_req_id) def _build_assignment(self) -> list[RankTask]: assert len(self._progress) <= 1 #TODO: support multiple requests diff --git a/vllm_omni/diffusion/worker/diffusion_model_runner.py b/vllm_omni/diffusion/worker/diffusion_model_runner.py index 9a70b6024c3..0219cd73a57 100644 --- a/vllm_omni/diffusion/worker/diffusion_model_runner.py +++ b/vllm_omni/diffusion/worker/diffusion_model_runner.py @@ -542,6 +542,7 @@ def execute_micro_step(self, sched_output: DiffusionSchedulerOutput) -> RunnerOu layout = task.layout if is_new_request: + state.extra.pop("chunks", None) pp_group.reset_buffer() self.pipeline.prepare_encode(state) self.pipeline.set_pp_recv_dict_buffers(state) @@ -558,6 +559,8 @@ def execute_micro_step(self, sched_output: DiffusionSchedulerOutput) -> RunnerOu self._prepare_chunk_latents(state, layout, is_first_rank=pp_group.is_first_rank) if chunk_idxs: + state.extra["current_chunk_idxs"] = chunk_idxs + chunks: list[ChunkState] = [ self._get_or_create_chunk(state, idx)[0] for idx in chunk_idxs ] @@ -627,10 +630,11 @@ def _prepare_chunk_latents(self, state: DiffusionRequestState, layout: Layout, i pieces.append(chunk.latents) if layout.new_idxs: + for idx in layout.new_idxs: + self._get_or_create_chunk(state, idx) encoded = self.pipeline.encode_chunk_inputs(state, layout.new_idxs) for i, idx in enumerate(layout.new_idxs): - chunk, _ = self._get_or_create_chunk(state, idx) - chunk.latents = encoded[i : i + 1] + state.extra["chunks"][idx].latents = encoded[i : i + 1] pieces.append(encoded) state.latents = torch.cat(pieces, dim=0) if pieces else None @@ -642,17 +646,15 @@ def _update_decoded_chunks(self, state: DiffusionRequestState, layout: Layout) - state.latents = state.latents[:n_finished] decoded = self.pipeline.post_decode(state) state.latents = saved - + state.extra.setdefault("decoded_chunks", []).append(decoded) - state.extra["num_chunks_decoded"] = ( - state.extra.get("num_chunks_decoded", 0) + n_finished + state.extra["chunks_decoded"] = ( + state.extra.get("chunks_decoded", 0) + n_finished ) - output_chunks_target = state.sampling.num_frames // state.sampling.chunk_frames - - if state.extra.get("num_chunks_decoded", 0) >= output_chunks_target: + if state.extra.get("chunks_decoded", 0) >= state.sampling.num_chunks: return self._merge_chunk_outputs(state.extra["decoded_chunks"]) - + return None def _update_state_after(self, state: DiffusionRequestState, layout: Layout, finished: bool = False): @@ -664,18 +666,19 @@ def _update_state_after(self, state: DiffusionRequestState, layout: Layout, fini @staticmethod def _merge_chunk_outputs(chunks: list[DiffusionOutput]) -> DiffusionOutput: - """Merge decoded chunk outputs into a single video tensor. + """Merge decoded chunk outputs along the temporal axis. - Each entry's ``.output`` is ``[B_i, C, T, H, W]``. - Concat along the batch dim then unroll into the temporal axis ``[1, C, total_chunks * T, H, W]``. + Supports both: + - 5D ``[B, C, T, H, W]`` (Wan-style): time axis = dim 2. + - 4D ``[C, T, H, W]`` (lingbot-style): time axis = dim 1. NOTE: This is a temporary solution until streaming output is supported. """ try: - cat0 = torch.cat([c.output for c in chunks], dim=0) - B, C, T, H, W = cat0.shape - merged = cat0.permute(1, 0, 2, 3, 4).reshape(1, C, B * T, H, W) + outputs = [c.output for c in chunks] + time_dim = outputs[0].dim() - 3 + merged = torch.cat(outputs, dim=time_dim) except Exception as e: return DiffusionOutput(error=f"Failed to merge {len(chunks)} chunk outputs: {e}") return DiffusionOutput(output=merged) \ No newline at end of file diff --git a/vllm_omni/diffusion/worker/utils.py b/vllm_omni/diffusion/worker/utils.py index ccf13d01fef..a214a6aa6dd 100644 --- a/vllm_omni/diffusion/worker/utils.py +++ b/vllm_omni/diffusion/worker/utils.py @@ -125,6 +125,7 @@ class ChunkState: latents: torch.Tensor | None = None step_index: int = 0 scheduler: Any | None = None + extra: dict[str, Any] = field(default_factory=dict) class BaseRunnerOutput(ABC): diff --git a/vllm_omni/inputs/data.py b/vllm_omni/inputs/data.py index 3d7a4e689d2..dd10d36a6b8 100644 --- a/vllm_omni/inputs/data.py +++ b/vllm_omni/inputs/data.py @@ -220,7 +220,8 @@ class OmniDiffusionSamplingParams: width_latents: list[int] | int | None = None num_frames: int = 1 # Default for image models num_frames_round_down: bool = False # Whether to round down num_frames if it's not divisible by num_gpus - chunk_frames: int = 8 # Used when stream_batch=True + chunk_frames: int = 5 # Used when stream_batch=True + num_chunks: int = 1 # SLO-adaptive stream batching. ``slo_fps=None`` keeps B_target fixed at 1. slo_fps: float | None = None From 234841c1daf95354cd365dd99ec0f41f0a0cb609 Mon Sep 17 00:00:00 2001 From: Mahdi Nasser <94046147+mnasser02@users.noreply.github.com> Date: Sat, 30 May 2026 01:12:44 +0200 Subject: [PATCH 51/53] add stream batch support to lingbot Signed-off-by: Mahdi Nasser <94046147+mnasser02@users.noreply.github.com> --- .../lingbot_world_fast/flow_scheduler.py | 67 +++ .../pipeline_lingbot_world_fast.py | 516 +++++++++++++++++- .../models/lingbot_world_fast/stream_vae.py | 82 +++ .../models/lingbot_world_fast/wan_fast.py | 360 +++++++----- 4 files changed, 889 insertions(+), 136 deletions(-) create mode 100644 vllm_omni/diffusion/models/lingbot_world_fast/flow_scheduler.py create mode 100644 vllm_omni/diffusion/models/lingbot_world_fast/stream_vae.py diff --git a/vllm_omni/diffusion/models/lingbot_world_fast/flow_scheduler.py b/vllm_omni/diffusion/models/lingbot_world_fast/flow_scheduler.py new file mode 100644 index 00000000000..9f30b20876b --- /dev/null +++ b/vllm_omni/diffusion/models/lingbot_world_fast/flow_scheduler.py @@ -0,0 +1,67 @@ +from __future__ import annotations + +import torch + +from .fm_solvers_unipc import FlowUniPCMultistepScheduler + + +class LingbotFlowScheduler: + def __init__( + self, + inner: FlowUniPCMultistepScheduler, + timesteps5: torch.Tensor, + ) -> None: + self._inner = inner + # Length-5 schedule: [t0, t1, t2, t3, 0]. + self.timesteps = timesteps5 + # Used by `_convert_flow_pred_to_x0` to look up sigma_t. + self.sigmas = inner.sigmas + self._full_timesteps = inner.timesteps + + def step( + self, + noise_pred: torch.Tensor, + t: torch.Tensor, + latents: torch.Tensor, + return_dict: bool = False, + generator: torch.Generator | None = None, + ) -> tuple[torch.Tensor]: + # `t` is a per-row scalar (`_scheduler_step_local` loops per row). + if float(t.item()) == 0.0: + return (latents,) + + x0 = self._convert_flow_pred_to_x0(noise_pred, latents, t) + + ts_eq = (self.timesteps == t).nonzero(as_tuple=False) + chunk_step = int(ts_eq[0].item()) if ts_eq.numel() > 0 else 0 + + if chunk_step + 1 < self.timesteps.shape[0] - 1: + next_t = self.timesteps[chunk_step + 1] + noise = torch.randn( + x0.shape, generator=generator, device=x0.device, dtype=x0.dtype + ) + return (self._inner.add_noise(x0, noise, next_t),) + return (x0,) + + def add_noise( + self, + original_samples: torch.Tensor, + noise: torch.Tensor, + timesteps: torch.Tensor, + ) -> torch.Tensor: + return self._inner.add_noise(original_samples, noise, timesteps) + + def _convert_flow_pred_to_x0( + self, + flow_pred: torch.Tensor, + xt: torch.Tensor, + timestep: torch.Tensor, + ) -> torch.Tensor: + original_dtype = flow_pred.dtype + flow_pred, xt, sigmas, timesteps = map( + lambda x: x.double().to(flow_pred.device), + [flow_pred, xt, self.sigmas, self._full_timesteps], + ) + timestep_id = torch.argmin((timesteps - timestep).abs()) + sigma_t = sigmas[timestep_id].reshape(-1, 1, 1, 1) + return (xt - sigma_t * flow_pred).to(original_dtype) \ No newline at end of file diff --git a/vllm_omni/diffusion/models/lingbot_world_fast/pipeline_lingbot_world_fast.py b/vllm_omni/diffusion/models/lingbot_world_fast/pipeline_lingbot_world_fast.py index 73c0b803126..2637276dd3f 100644 --- a/vllm_omni/diffusion/models/lingbot_world_fast/pipeline_lingbot_world_fast.py +++ b/vllm_omni/diffusion/models/lingbot_world_fast/pipeline_lingbot_world_fast.py @@ -1,9 +1,11 @@ +import copy import logging import math import os import random import sys from contextlib import contextmanager +from typing import Any, ClassVar import numpy as np import torch @@ -12,12 +14,24 @@ from einops import rearrange from torch import nn from tqdm import tqdm +from vllm.sequence import IntermediateTensors from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin +from vllm_omni.diffusion.distributed.parallel_state import ( + get_pipeline_parallel_world_size, + get_pp_group, + is_pipeline_first_stage, + is_pipeline_last_stage, +) +from vllm_omni.diffusion.distributed.pipeline_parallel import ( + AsyncLatents, + PipelineParallelMixin, +) from vllm_omni.diffusion.distributed.utils import get_local_device from vllm_omni.diffusion.models.interface import SupportCameraPosInput, SupportImageInput from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.diffusion.worker.utils import DiffusionRequestState from .cam_utils import ( compute_relative_poses, @@ -25,14 +39,17 @@ get_plucker_embeddings, interpolate_camera_poses, ) +from .flow_scheduler import LingbotFlowScheduler from .fm_solvers_unipc import FlowUniPCMultistepScheduler from .state_lingbot_world_fast import LingbotWorldFastState +from .stream_vae import StreamVAE from .t5 import T5EncoderModel from .vae2_1 import Wan2_1_VAE from .wan_fast import WanModelFast logger = logging.getLogger(__name__) + CONFIG = { "text_len": 512, "num_train_timesteps": 1000, @@ -70,7 +87,19 @@ def post_process_func( return post_process_func -class LingbotWorldFastPipeline(nn.Module, SupportImageInput, SupportCameraPosInput, CFGParallelMixin): +class LingbotWorldFastPipeline( + nn.Module, + SupportImageInput, + SupportCameraPosInput, + PipelineParallelMixin, + CFGParallelMixin, +): + supports_step_execution: ClassVar[bool] = True + supports_micro_step_execution: ClassVar[bool] = True + + STREAM_BATCH_CHUNK_FRAMES: ClassVar[int] = 12 + STREAM_BATCH_NUM_INFERENCE_STEPS: ClassVar[int] = 5 + def __init__(self, *, od_config: OmniDiffusionConfig): super().__init__() self.od_config = od_config @@ -83,7 +112,7 @@ def __init__(self, *, od_config: OmniDiffusionConfig): self.control_type = "cam" self.num_train_timesteps = CONFIG["num_train_timesteps"] - self.sp_size = od_config.parallel_config.world_size + self.sp_size = od_config.parallel_config.sequence_parallel_size self.state = LingbotWorldFastState() @@ -93,14 +122,17 @@ def __init__(self, *, od_config: OmniDiffusionConfig): self.text_encoder = T5EncoderModel( text_len=CONFIG["text_len"], dtype=self.target_dtype, - device=torch.device("cpu"), + device=self.device, checkpoint_path=os.path.join(checkpoint_path, CONFIG["t5_checkpoint"]), tokenizer_path=os.path.join(checkpoint_path, CONFIG["t5_tokenizer"]), ) self.vae_stride = CONFIG["vae_stride"] self.patch_size = CONFIG["patch_size"] - self.vae = Wan2_1_VAE(vae_pth=os.path.join(checkpoint_path, CONFIG["vae_checkpoint"]), device=self.device) + base_vae = Wan2_1_VAE( + vae_pth=os.path.join(checkpoint_path, CONFIG["vae_checkpoint"]), device=self.device + ) + self.vae = StreamVAE(base_vae) if od_config.stream_batch else base_vae logger.info(f"Creating WanModelFast from {checkpoint_path}") self.model = WanModelFast.from_pretrained( @@ -109,6 +141,8 @@ def __init__(self, *, od_config: OmniDiffusionConfig): torch_dtype=torch.bfloat16, control_type=self.control_type, ).to(self.device) + # Partition transformer across PP ranks (no-op at PP=1). + self.model.apply_pp_split() self.scheduler = FlowUniPCMultistepScheduler( num_train_timesteps=self.num_train_timesteps, shift=1, use_dynamic_shifting=False @@ -265,8 +299,7 @@ def forward( self.scheduler.set_timesteps(self.num_train_timesteps, shift=CONFIG["sample_shift"]) timesteps = self.scheduler.timesteps[CONFIG["timesteps_index"]] - context = self.text_encoder([prompt], torch.device("cpu")) - context = [t.to(self.device) for t in context] + context = self.text_encoder([prompt], self.device) dit_cond_dict = None Ks = torch.from_numpy(camera.get("intrinsics")) @@ -464,5 +497,476 @@ def noop_no_sync(): return DiffusionOutput(output=videos[0]) + # ------------------------------------------------------------------ + # micro-step execution + # ------------------------------------------------------------------ + + def predict_noise( + self, + intermediate_tensors: IntermediateTensors | None = None, + **kwargs: Any, + ) -> torch.Tensor | IntermediateTensors: + """Single transformer forward; returns IntermediateTensors on non-last PP stages.""" + with torch.amp.autocast("cuda", dtype=self.target_dtype): + result = self.model(**kwargs, intermediate_tensors=intermediate_tensors) + if isinstance(result, IntermediateTensors): + return result + # Last stage returns List[Tensor] (one per row); stack along dim 0. + return torch.stack(result, dim=0) + + def prepare_encode( + self, + state: DiffusionRequestState, + **kwargs: Any, + ) -> DiffusionRequestState: + """One-time request setup mirroring forward()'s prep up to the chunk loop. + + Stashes per-chunk noise / conditioning / Plucker tensors in state.extra, + initializes (or extends) the model's persistent KV caches sized to this + rank's owned layer slice, and exposes state.timesteps as length 5 + (4 denoise + 1 t=0 KV-update). + """ + if not state.prompts or len(state.prompts) > 1: + raise ValueError("LingbotWorldFastPipeline only supports a single prompt.") + + sampling = state.sampling + if sampling.chunk_frames != 3*4: + logger.warning( + "LingbotWorldFastPipeline requires chunk_frames=3*4=12, got %s. Overriding.", + sampling.chunk_frames, + ) + sampling.chunk_frames = 12 + if sampling.num_inference_steps != 5: + logger.warning( + "LingbotWorldFastPipeline requires num_inference_steps=4+1=5, got %s. Overriding.", + sampling.num_inference_steps, + ) + sampling.num_inference_steps = 5 + + prompt = state.prompts[0].get("prompt") + multi_modal_data = state.prompts[0].get("multi_modal_data", {}) or {} + + extra_args = state.sampling.extra_args or {} + session_id = str(extra_args.get("session_id") or None) + force_reset = bool(extra_args.get("force_reset") or False) + + if force_reset or self.state.session_id is None or self.state.session_id != session_id: + self.state.reset() + self.state.session_id = session_id + extension = False + else: + extension = True + + camera = multi_modal_data.get("camera", None) + if camera is None: + raise ValueError("LingbotWorldFastPipeline requires camera poses in multi_modal_data['camera'].") + + if extension: + assert multi_modal_data.get("image") is None, ( + "image must not be provided on extension calls; it is only used on the first call of a session" + ) + assert self.model.config.local_attn_size == -1, ( + "video extension requires the model to be configured with local_attn_size == -1" + ) + + batch_size = 1 + c2ws = camera.get("poses") + chunk_size = CONFIG["chunk_size"] + max_area = CONFIG["max_area"] + + new_lat_f = max(sampling.num_chunks * chunk_size, 1) + if extension: + num_frames = new_lat_f * 4 + else: + num_frames = (new_lat_f - 1) * 4 + 1 + if len(c2ws) < num_frames: + raise ValueError( + f"camera trajectory has {len(c2ws)} poses; need >= {num_frames} " + f"for {sampling.num_chunks} chunks (chunk_size={chunk_size})." + ) + c2ws = c2ws[:num_frames] + + if not extension: + img = multi_modal_data.get("image") + img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(self.device) + h, w = img.shape[1:] + aspect_ratio = h / w + lat_h = round( + np.sqrt(max_area * aspect_ratio) // self.vae_stride[1] // self.patch_size[1] * self.patch_size[1] + ) + lat_w = round( + np.sqrt(max_area / aspect_ratio) // self.vae_stride[2] // self.patch_size[2] * self.patch_size[2] + ) + h = lat_h * self.vae_stride[1] + w = lat_w * self.vae_stride[2] + else: + img = None + h, w, lat_h, lat_w = self.state.h, self.state.w, self.state.lat_h, self.state.lat_w + + max_seq_len = chunk_size * lat_h * lat_w // (self.patch_size[1] * self.patch_size[2]) + max_seq_len = int(math.ceil(max_seq_len / self.sp_size)) * self.sp_size + + seed = state.sampling.seed + if seed is None: + seed = random.randint(0, sys.maxsize) + # Two separate generators to keep noise consistent across PP ranks: + # - seed_g: chunk-initial noise. + # - seed_g_addnoise: scheduler.add_noise consumed on last rank only. + seed_g = torch.Generator(device=self.device).manual_seed(seed) + seed_g_addnoise = torch.Generator(device=self.device).manual_seed(seed + 1) + + # Sampler timesteps (4 denoise + 1 t=0 KV-update) + self.scheduler.set_timesteps(self.num_train_timesteps, shift=CONFIG["sample_shift"]) + denoise_timesteps = self.scheduler.timesteps[CONFIG["timesteps_index"]].to(self.device) + timesteps5 = torch.cat([denoise_timesteps, denoise_timesteps.new_zeros(1)], dim=0) + + # Text + camera Plucker + context_list = self.text_encoder([prompt], self.device) + + Ks_raw = torch.from_numpy(camera.get("intrinsics")) + Ks_t = get_Ks_transformed( + Ks_raw, height_org=480, width_org=832, height_resize=h, width_resize=w, height_final=h, width_final=w + )[0] + len_c2ws_orig = len(c2ws) + tgt_indices_full = np.linspace(0, len_c2ws_orig - 1, new_lat_f) + c2ws_infer_full = interpolate_camera_poses( + src_indices=np.linspace(0, len_c2ws_orig - 1, len_c2ws_orig), + src_rot_mat=c2ws[:, :3, :3], + src_trans_vec=c2ws[:, :3, 3], + tgt_indices=tgt_indices_full, + ) + c2ws_infer_full = compute_relative_poses(c2ws_infer_full, framewise=True) + c2ws_infer_full = c2ws_infer_full.to(self.device).to(torch.float32) + Ks_t = Ks_t.to(self.device).to(torch.float32) + + anchor_latent: torch.Tensor | None = None + if is_pipeline_first_stage(): + self.vae.reset() + if not extension: + anchor_pixels = ( + torch.nn.functional.interpolate(img[None].cpu(), size=(h, w), mode="bicubic") + .transpose(0, 1) + .to(self.device) + ) + anchor_latent = self.vae.init(anchor_pixels) # [16, 1, lat_h, lat_w] + else: + zero_frame = torch.zeros(3, 1, h, w, device=self.device) + self.vae.init(zero_frame) + + # KV cache sizing — per this rank's owned layer slice. + model_args = self.model.config + transformer_dtype = self.target_dtype + frame_seqlen = int(lat_h * lat_w // 4) + extra_kv_size = frame_seqlen * new_lat_f + head_dim = model_args.dim // model_args.num_heads + local_num_heads = model_args.num_heads // self.sp_size + owned_num_layers = self.model.end_layer - self.model.start_layer + + if not extension: + self.state.create_kv_caches( + batch_size, + transformer_dtype, + self.device, + extra_kv_size, + owned_num_layers, + local_num_heads, + head_dim, + ) + else: + self.state.extend_kv_caches(extra_kv_size) + + prev_lat_f = self.state.current_lat_f + total_kv_size = frame_seqlen * (prev_lat_f + new_lat_f) + start_token_offset = prev_lat_f * frame_seqlen + + # State population. + state.prompt_embeds = None # unused; lingbot keeps text as raw list[Tensor] + state.latents = None # per-chunk latents are stacked by encode_chunk_inputs + state.timesteps = timesteps5 + state.step_index = 0 + state.scheduler = LingbotFlowScheduler(self.scheduler, timesteps5) + state.do_true_cfg = False + + state.extra["context"] = context_list + state.extra["anchor_latent"] = anchor_latent + state.extra["start_token_offset"] = start_token_offset + state.extra["max_attention_size"] = total_kv_size + state.extra["frame_seqlen"] = frame_seqlen + state.extra["max_seq_len"] = max_seq_len + state.extra["chunk_size"] = chunk_size + state.extra["lat_h"] = lat_h + state.extra["lat_w"] = lat_w + state.extra["h"] = h + state.extra["w"] = w + state.extra["new_lat_f"] = new_lat_f + state.extra["extension"] = extension + state.extra["seed_g"] = seed_g + state.extra["seed_g_addnoise"] = seed_g_addnoise + + state.extra["c2ws_infer_full"] = c2ws_infer_full + state.extra["Ks_transformed"] = Ks_t + + return state + + def encode_chunk_inputs( + self, + state: DiffusionRequestState, + new_idxs: list[int], + ) -> torch.Tensor: + """Build per-chunk noise, plus VAE-encoded y and Plucker on first stage.""" + seed_g = state.extra["seed_g"] + chunk_size = state.extra["chunk_size"] + lat_h = state.extra["lat_h"] + lat_w = state.extra["lat_w"] + h = state.extra["h"] + w = state.extra["w"] + chunks = state.extra["chunks"] + B = len(new_idxs) + + # noise + noise = torch.randn( + B, 16, chunk_size, lat_h, lat_w, + dtype=torch.float32, generator=seed_g, device=self.device, + ) + + if not is_pipeline_first_stage(): + return noise + + c2ws_infer_full = state.extra["c2ws_infer_full"] + Ks_t = state.extra["Ks_transformed"] + anchor_latent: torch.Tensor | None = state.extra["anchor_latent"] + extension: bool = state.extra["extension"] + + # per-chunk stream-encode y + per-chunk msk + for idx in new_idxs: + is_anchor_chunk = (not extension) and idx == 0 + if is_anchor_chunk: + tail_frames = 4 * (chunk_size - 1) + if tail_frames > 0: + zeros = torch.zeros(3, tail_frames, h, w, device=self.device) + tail_lat = self.vae.encode(zeros) + assert anchor_latent is not None + vae_lat = torch.cat([anchor_latent, tail_lat], dim=1) + else: + assert anchor_latent is not None + vae_lat = anchor_latent + else: + zeros = torch.zeros(3, 4 * chunk_size, h, w, device=self.device) + vae_lat = self.vae.encode(zeros) + + msk_chunk = torch.zeros(4, chunk_size, lat_h, lat_w, device=self.device) + if is_anchor_chunk: + msk_chunk[:, 0] = 1 + chunks[idx].extra["y"] = torch.cat([msk_chunk, vae_lat], dim=0) + + # plucker + frame_indices = torch.tensor( + [ci * chunk_size + f for ci in new_idxs for f in range(chunk_size)], + device=c2ws_infer_full.device, dtype=torch.long, + ) + batched_c2ws = c2ws_infer_full[frame_indices] # [B*chunk_size, 3, 4] + batched_Ks = Ks_t.repeat(B * chunk_size, 1) # [B*chunk_size, 4] + batched_plucker = get_plucker_embeddings(batched_c2ws, batched_Ks, h, w, only_rays_d=False) + batched_plucker = rearrange( + batched_plucker, + "f (h c1) (w c2) c -> f h w (c c1 c2)", + c1=int(h // lat_h), + c2=int(w // lat_w), + ) + batched_plucker = batched_plucker.view(B, chunk_size, lat_h, lat_w, -1) + batched_plucker = batched_plucker.permute(0, 4, 1, 2, 3).contiguous().to(self.target_dtype) + + for i, idx in enumerate(new_idxs): + chunks[idx].extra["plucker"] = batched_plucker[i : i + 1] + + return noise + + def set_pp_recv_dict_buffers(self, state: DiffusionRequestState) -> None: + if get_pipeline_parallel_world_size() == 1: + return + + pp_group = get_pp_group() + slo_fps = getattr(state.sampling, "slo_fps", None) + slo_max_batch = getattr(state.sampling, "slo_max_batch", 1) + slo_max_batch = max(1, slo_max_batch if slo_fps else 1) + + chunk_size = state.extra["chunk_size"] + lat_h = state.extra["lat_h"] + lat_w = state.extra["lat_w"] + max_seq_len = state.extra["max_seq_len"] + n_steps = int(state.timesteps.shape[0]) + + latents_dtype = torch.float32 + it_dtype = self.target_dtype + + for batch_size in range(1, slo_max_batch * n_steps + 1): + latents_template = { + "latents": torch.empty( + batch_size, 16, chunk_size, lat_h, lat_w, dtype=latents_dtype, device="meta" + ) + } + it_template = { + "hidden_states": torch.empty( + batch_size, max_seq_len, self.model.dim, dtype=it_dtype, device="meta" + ), + "grid_sizes": torch.empty(batch_size, 3, dtype=torch.long, device="meta"), + "seq_lens": torch.empty(batch_size, dtype=torch.long, device="meta"), + "c2ws_plucker_emb": torch.empty( + batch_size, max_seq_len, self.model.dim, dtype=it_dtype, device="meta" + ), + } + pp_group.set_recv_dict_buffer("latents", -1, latents_template, batch_size=batch_size) + pp_group.set_recv_dict_buffer("intermediate", 0, it_template, batch_size=batch_size) + + def denoise_step( + self, + state: DiffusionRequestState, + batch_size: int = 1, + **kwargs: Any, + ) -> torch.Tensor | None: + """Fused transformer forward for the batch of chunks. + + Each row's per-chunk metadata (``current_starts``, ``y``, ``c2ws_plucker_emb``) + is read from state.extra keyed by chunk index. Rows whose timestep is 0 + carry the KV-update payload (the chunk's saved x0) — their output is + ignored by ``step_scheduler``. + """ + chunk_idxs: list[int] = state.extra["current_chunk_idxs"] + assert len(chunk_idxs) == batch_size + + chunk_size = state.extra["chunk_size"] + frame_seqlen = state.extra["frame_seqlen"] + start_token_offset = state.extra["start_token_offset"] + chunks = state.extra["chunks"] + context_list = state.extra["context"] + + x_list, y_list, plucker_list = None, None, None + if is_pipeline_first_stage(): + x_list = [state.latents[i] for i in range(batch_size)] + y_list = [chunks[ci].extra["y"] for ci in chunk_idxs] + plucker_list = [chunks[ci].extra["plucker"] for ci in chunk_idxs] + + current_starts = [ + start_token_offset + ci * chunk_size * frame_seqlen for ci in chunk_idxs + ] + + positive_kwargs = { + "x": x_list, + "t": state.batched_timesteps, + "context": [context_list[0]] * batch_size, + "seq_len": state.extra["max_seq_len"], + "y": y_list, + "dit_cond_dict": {"c2ws_plucker_emb": plucker_list}, + "kv_cache": self.state.get_kv_caches(), + "local_end_index": self.state.local_end_index, + "global_end_index": self.state.global_end_index, + "crossattn_cache": self.state.get_crossattn_caches(), + "current_starts": current_starts, + "max_attention_size": state.extra["max_attention_size"], + } + + preposted_its = state.extra.pop("preposted_its", None) + return self.predict_noise_maybe_with_cfg( + do_true_cfg=False, + true_cfg_scale=1.0, + positive_kwargs=positive_kwargs, + negative_kwargs=None, + buf_idx=state.step_index % 2, + batch_size=batch_size, + preposted_its=preposted_its, + ) + + def step_scheduler( + self, + state: DiffusionRequestState, + noise_pred: torch.Tensor, + *, + per_request_scheduler: Any | list[Any] | None = None, + batch_size: int = 1, + **kwargs: Any, + ) -> None: + if per_request_scheduler is None: + per_request_scheduler = state.scheduler + + state.latents = self.scheduler_step_maybe_with_cfg( + noise_pred, + state.batched_timesteps, + state.latents, + do_true_cfg=False, + per_request_scheduler=per_request_scheduler, + generator=state.extra["seed_g_addnoise"], + batch_size=batch_size, + receive_latents=False, + ) + state.step_index += 1 + + def prefetch_tensors( + self, + state: DiffusionRequestState, + batch_size: int = 1, + **kwargs: Any, + ) -> None: + if get_pipeline_parallel_world_size() == 1: + return + buf_idx = state.step_index % 2 + preposted = self.prefetch_tensors_maybe_with_cfg( + do_true_cfg=False, buf_idx=buf_idx, batch_size=batch_size, + ) + if isinstance(preposted, AsyncLatents): + state.latents = preposted + elif preposted is not None: + state.extra["preposted_its"] = preposted + + def post_decode( + self, + state: DiffusionRequestState, + **kwargs: Any, + ) -> DiffusionOutput: + """VAE-decode the finished chunks with prior-tail warmup. + + Mirrors forward()'s decode block: on extension calls prepend the prior + chunk's last 2 latents to prime the temporal-causal feat_map, then drop + ``4*k - 3`` leading pixels. After decoding, refresh ``last_decoded_latent`` + with the tail of the new latents so the next call's decode is warm. + """ + self._sync_pp_send() + pred_latent_chunks = state.latents.transpose(0, 1).reshape( + state.latents.shape[1], + state.latents.shape[0] * state.latents.shape[2], + state.latents.shape[3], + state.latents.shape[4], + ) + # pred_latent_chunks: [16, B*chunk_size, lat_h, lat_w] + + extension = state.extra["extension"] + if self.state.last_decoded_latent is not None: + warmup = self.state.last_decoded_latent.to( + pred_latent_chunks.device, pred_latent_chunks.dtype + ) + k = warmup.shape[1] + drop = 4 * k - 3 + to_decode = torch.cat([warmup, pred_latent_chunks], dim=1) + videos = self.vae.decode([to_decode]) + videos = [v[:, drop:] for v in videos] + else: + videos = self.vae.decode([pred_latent_chunks]) + + self.state.last_decoded_latent = pred_latent_chunks[:, -2:].detach().clone() + + sampling = state.sampling + chunks_so_far = state.extra.get("chunks_decoded", 0) + chunks_this_call = state.latents.shape[0] + is_final = chunks_so_far + chunks_this_call >= sampling.num_chunks + if is_final: + if not extension: + self.state.h = state.extra["h"] + self.state.w = state.extra["w"] + self.state.lat_h = state.extra["lat_h"] + self.state.lat_w = state.extra["lat_w"] + self.state.frame_seqlen = state.extra["frame_seqlen"] + self.state.advance(state.extra["new_lat_f"]) + + return DiffusionOutput(output=videos[0]) + def load_weights(self, weights): pass diff --git a/vllm_omni/diffusion/models/lingbot_world_fast/stream_vae.py b/vllm_omni/diffusion/models/lingbot_world_fast/stream_vae.py new file mode 100644 index 00000000000..8c6c7557a06 --- /dev/null +++ b/vllm_omni/diffusion/models/lingbot_world_fast/stream_vae.py @@ -0,0 +1,82 @@ +"""Per-chunk streaming VAE encode wrapper around ``Wan2_1_VAE``.""" + +from __future__ import annotations + +import torch +import torch.cuda.amp as amp + +from .vae2_1 import Wan2_1_VAE + + +class StreamVAE: + def __init__(self, vae: Wan2_1_VAE) -> None: + self._vae = vae + self._model = vae.model + self._scale = vae.scale + self._dtype = vae.dtype + + def reset(self) -> None: + """Clear encoder feat_map cache. Call at the start of each new request.""" + self._model.clear_cache() + + def decode(self, zs): + return self._vae.decode(zs) + + @property + def dtype(self) -> torch.dtype: + return self._vae.dtype + + @torch.no_grad() + def init(self, frame: torch.Tensor) -> torch.Tensor: + """Encode the single init pixel frame and return its latent. + + Caller keeps the latent for fresh starts (anchor encoding) or + discards it for extension starts (just sets up the init bias). + + Args: + frame: ``[C, 1, H, W]`` or ``[B, C, 1, H, W]``. + Returns: + ``[z_dim, 1, H, W]`` latent. + """ + with amp.autocast(dtype=self._dtype): + pixels = frame.unsqueeze(0) if frame.dim() == 4 else frame + out = self._encode_group(pixels) + mu = self._apply_conv1_and_normalize(out) + return mu.float().squeeze(0) + + @torch.no_grad() + def encode(self, pixels: torch.Tensor) -> torch.Tensor: + """Encode ``4*N`` pixel frames using the preserved state. + + Args: + pixels: ``[C, 4N, H, W]`` or ``[B, C, 4N, H, W]``. + Returns: + ``[z_dim, N, H, W]`` latents. + """ + pixels = pixels.unsqueeze(0) if pixels.dim() == 4 else pixels + T = pixels.shape[2] + assert T % 4 == 0, f"StreamVAE.encode expects a multiple of 4 pixel frames, got {T}" + N = T // 4 + with amp.autocast(dtype=self._dtype): + outs = [self._encode_group(pixels[:, :, i * 4 : (i + 1) * 4]) for i in range(N)] + out = torch.cat(outs, dim=2) + mu = self._apply_conv1_and_normalize(out) + return mu.float().squeeze(0) + + # ── internals ────────────────────────────────────────────────────────── + + def _encode_group(self, pixels: torch.Tensor) -> torch.Tensor: + """One ``(1|4)``-frame encoder pass using the live ``_enc_feat_map``.""" + self._model._enc_conv_idx = [0] + return self._model.encoder( + pixels, feat_cache=self._model._enc_feat_map, feat_idx=self._model._enc_conv_idx + ) + + def _apply_conv1_and_normalize(self, out: torch.Tensor) -> torch.Tensor: + mu, _ = self._model.conv1(out).chunk(2, dim=1) + z = self._model.z_dim + if isinstance(self._scale[0], torch.Tensor): + mu = (mu - self._scale[0].view(1, z, 1, 1, 1)) * self._scale[1].view(1, z, 1, 1, 1) + else: + mu = (mu - self._scale[0]) * self._scale[1] + return mu \ No newline at end of file diff --git a/vllm_omni/diffusion/models/lingbot_world_fast/wan_fast.py b/vllm_omni/diffusion/models/lingbot_world_fast/wan_fast.py index 1011c2bdf54..89b05b65fe9 100644 --- a/vllm_omni/diffusion/models/lingbot_world_fast/wan_fast.py +++ b/vllm_omni/diffusion/models/lingbot_world_fast/wan_fast.py @@ -8,30 +8,47 @@ from diffusers.configuration_utils import ConfigMixin, register_to_config from diffusers.models.modeling_utils import ModelMixin from einops import rearrange +from vllm.model_executor.models.utils import PPMissingLayer +from vllm.sequence import IntermediateTensors from vllm_omni.diffusion.attention.layer import Attention +from vllm_omni.diffusion.distributed.parallel_state import ( + get_pipeline_parallel_rank, + get_pipeline_parallel_world_size, + is_pipeline_first_stage, + is_pipeline_last_stage, +) from .state_lingbot_world_fast import CacheIndex from .wan_model import WanLayerNorm, WanRMSNorm, WanSelfAttention, rope_params, sinusoidal_embedding_1d -def causal_rope_apply(x, grid_sizes, freqs, start_frame=0): +def causal_rope_apply(x, grid_sizes, freqs, start_frames=0): + """Apply causal rotary position embedding per batch row. + + start_frames: int or list[int] of per-row frame offsets. + An int broadcasts to all rows. + """ n, c = x.size(2), x.size(3) // 2 # split freqs freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1) + if isinstance(start_frames, int): + start_frames = [start_frames] * grid_sizes.shape[0] + # loop over samples output = [] for i, (f, h, w) in enumerate(grid_sizes.tolist()): + sf = start_frames[i] seq_len = f * h * w # precompute multipliers x_i = torch.view_as_complex(x[i, :seq_len].to(torch.float64).reshape(seq_len, n, -1, 2)) freqs_i = torch.cat( [ - freqs[0][start_frame : start_frame + f].view(f, 1, 1, -1).expand(f, h, w, -1), + freqs[0][sf : sf + f].view(f, 1, 1, -1).expand(f, h, w, -1), freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1), ], @@ -84,7 +101,7 @@ def forward( kv_cache=None, local_end_index=None, global_end_index=None, - current_start=0, + current_starts=0, max_attention_size=1_000_000, ): r""" @@ -92,10 +109,15 @@ def forward( x(Tensor): Shape [B, L, num_heads, C / num_heads] grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W) freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2] - block_mask (BlockMask) + current_starts(int | list[int]): per-row absolute token offset; an int + broadcasts to all rows. """ b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim + if isinstance(current_starts, int): + current_starts = [current_starts] * b + assert len(current_starts) == b + # query, key, value function def qkv_fn(x): q = self.norm_q(self.q(x)).view(b, s, n, d) @@ -106,54 +128,74 @@ def qkv_fn(x): q, k, v = qkv_fn(x) frame_seqlen = math.prod(grid_sizes[0][1:]).item() - current_start_frame = current_start // frame_seqlen - roped_query = causal_rope_apply(q, grid_sizes, freqs, start_frame=current_start_frame).type_as(v) - roped_key = causal_rope_apply(k, grid_sizes, freqs, start_frame=current_start_frame).type_as(v) - current_end = current_start + roped_query.shape[1] - sink_tokens = self.sink_size * frame_seqlen - # If we are using local attention and the current KV cache size is larger than the local attention size, - # then we need to truncate the KV cache - kv_cache_size = kv_cache[CacheIndex.K].shape[1] + start_frames = [cs // frame_seqlen for cs in current_starts] + roped_query = causal_rope_apply(q, grid_sizes, freqs, start_frames=start_frames).type_as(v) + roped_key = causal_rope_apply(k, grid_sizes, freqs, start_frames=start_frames).type_as(v) num_new_tokens = roped_query.shape[1] - if ( - self.local_attn_size != -1 - and (current_end > global_end_index.item()) - and (num_new_tokens + local_end_index.item() > kv_cache_size) - ): - # Calculate the number of new tokens added in this step - # Shift existing cache content left to discard oldest tokens - # Clone the source slice to avoid overlapping memory error - num_evicted_tokens = num_new_tokens + local_end_index.item() - kv_cache_size - num_rolled_tokens = local_end_index.item() - num_evicted_tokens - sink_tokens - kv_cache[CacheIndex.K][:, sink_tokens : sink_tokens + num_rolled_tokens] = kv_cache[CacheIndex.K][ - :, sink_tokens + num_evicted_tokens : sink_tokens + num_evicted_tokens + num_rolled_tokens - ].clone() - kv_cache[CacheIndex.V][:, sink_tokens : sink_tokens + num_rolled_tokens] = kv_cache[CacheIndex.V][ - :, sink_tokens + num_evicted_tokens : sink_tokens + num_evicted_tokens + num_rolled_tokens - ].clone() - # Insert the new keys/values at the end - new_local_end_index = local_end_index.item() + current_end - global_end_index.item() - num_evicted_tokens - local_start_index = new_local_end_index - num_new_tokens - kv_cache[CacheIndex.K][:, local_start_index:new_local_end_index] = roped_key - kv_cache[CacheIndex.V][:, local_start_index:new_local_end_index] = v - else: - # Assign new keys/values directly up to current_end - new_local_end_index = local_end_index.item() + current_end - global_end_index.item() + + if self.local_attn_size != -1: + # Cache-rolling path only supports single-row processing. + assert b == 1, "local_attn_size != -1 requires batch_size=1" + current_start = current_starts[0] + current_end = current_start + num_new_tokens + sink_tokens = self.sink_size * frame_seqlen + kv_cache_size = kv_cache[CacheIndex.K].shape[1] + + if (current_end > global_end_index.item()) and ( + num_new_tokens + local_end_index.item() > kv_cache_size + ): + num_evicted_tokens = num_new_tokens + local_end_index.item() - kv_cache_size + num_rolled_tokens = local_end_index.item() - num_evicted_tokens - sink_tokens + kv_cache[CacheIndex.K][:, sink_tokens : sink_tokens + num_rolled_tokens] = kv_cache[CacheIndex.K][ + :, sink_tokens + num_evicted_tokens : sink_tokens + num_evicted_tokens + num_rolled_tokens + ].clone() + kv_cache[CacheIndex.V][:, sink_tokens : sink_tokens + num_rolled_tokens] = kv_cache[CacheIndex.V][ + :, sink_tokens + num_evicted_tokens : sink_tokens + num_evicted_tokens + num_rolled_tokens + ].clone() + new_local_end_index = ( + local_end_index.item() + current_end - global_end_index.item() - num_evicted_tokens + ) + else: + new_local_end_index = local_end_index.item() + current_end - global_end_index.item() + local_start_index = new_local_end_index - num_new_tokens kv_cache[CacheIndex.K][:, local_start_index:new_local_end_index] = roped_key kv_cache[CacheIndex.V][:, local_start_index:new_local_end_index] = v - k_cache = kv_cache[CacheIndex.K][:, max(0, new_local_end_index - max_attention_size) : new_local_end_index] - v_cache = kv_cache[CacheIndex.V][:, max(0, new_local_end_index - max_attention_size) : new_local_end_index] - x = self.attn(roped_query, k_cache, v_cache) + k_cache = kv_cache[CacheIndex.K][:, max(0, new_local_end_index - max_attention_size) : new_local_end_index] + v_cache = kv_cache[CacheIndex.V][:, max(0, new_local_end_index - max_attention_size) : new_local_end_index] + out = self.attn(roped_query, k_cache, v_cache) - global_end_index.fill_(current_end) - local_end_index.fill_(new_local_end_index) + global_end_index.fill_(current_end) + local_end_index.fill_(new_local_end_index) + else: + # local_attn_size == -1: per-row writes to non-overlapping cache slots, + # per-row attention reads sized by max_attention_size. Loops once per + # batch row inside attention to avoid needing a key-padding mask. + outs = [] + max_end = 0 + for i in range(b): + cs_i = current_starts[i] + ce_i = cs_i + num_new_tokens + kv_cache[CacheIndex.K][:, cs_i:ce_i] = roped_key[i : i + 1] + kv_cache[CacheIndex.V][:, cs_i:ce_i] = v[i : i + 1] + + kv_start_i = max(0, ce_i - max_attention_size) + k_cache_i = kv_cache[CacheIndex.K][:, kv_start_i:ce_i] + v_cache_i = kv_cache[CacheIndex.V][:, kv_start_i:ce_i] + + outs.append(self.attn(roped_query[i : i + 1], k_cache_i, v_cache_i)) + if ce_i > max_end: + max_end = ce_i + + out = torch.cat(outs, dim=0) + global_end_index.fill_(max_end) + local_end_index.fill_(max_end) # output - x = x.flatten(2) - x = self.o(x) - return x + out = out.flatten(2) + out = self.o(out) + return out class WanCrossAttention(WanSelfAttention): @@ -183,13 +225,18 @@ def forward(self, x, context, context_lens, crossattn_cache=None): if crossattn_cache is not None: if not crossattn_cache.get("is_init", False): crossattn_cache["is_init"] = True - k = self.norm_k(self.k(context)).view(b, -1, n, d) - v = self.v(context).view(b, -1, n, d) + # Cache at B=1 (text context is shared across chunks in a batch); + # expand on retrieval to match q's batch size for variable-B calls. + k = self.norm_k(self.k(context[:1])).view(1, -1, n, d) + v = self.v(context[:1]).view(1, -1, n, d) crossattn_cache[CacheIndex.K] = k crossattn_cache[CacheIndex.V] = v else: k = crossattn_cache[CacheIndex.K] v = crossattn_cache[CacheIndex.V] + if k.shape[0] != b: + k = k.expand(b, *k.shape[1:]) + v = v.expand(b, *v.shape[1:]) else: k = self.norm_k(self.k(context)).view(b, -1, n, d) v = self.v(context).view(b, -1, n, d) @@ -248,7 +295,7 @@ def forward( local_end_index=None, global_end_index=None, crossattn_cache=None, - current_start=0, + current_starts=0, max_attention_size=1_000_000, ): r""" @@ -257,6 +304,7 @@ def forward( e(Tensor): Shape [B, F, 6, C] grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W) freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2] + current_starts(int | list[int]): per-row absolute token offset; int broadcasts. """ assert e.dtype == torch.float32 with torch.amp.autocast("cuda", dtype=torch.float32): @@ -271,7 +319,7 @@ def forward( kv_cache, local_end_index, global_end_index, - current_start, + current_starts, max_attention_size, ) with torch.amp.autocast("cuda", dtype=torch.float32): @@ -449,6 +497,10 @@ def __init__( # head self.head = CausalHead(dim, out_dim, patch_size, eps) + # PP layout — defaults to single-stage; apply_pp_split() refines after loading. + self.start_layer = 0 + self.end_layer = num_layers + # buffers (don't use register_buffer otherwise dtype will be changed in to()) assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0 d = dim // num_heads @@ -460,6 +512,40 @@ def __init__( # initialize weights self.init_weights() + def apply_pp_split(self) -> None: + """Partition the model across PP ranks. Called after weight loading. + + After this returns, blocks outside this rank's [start_layer, end_layer) + slice are replaced with PPMissingLayer(); embeddings/head are kept only + on the first/last stage. KV-cache sizing (in the pipeline state) reads + end_layer - start_layer to allocate just for the owned slice. + """ + pp_world = get_pipeline_parallel_world_size() + if pp_world <= 1: + self.start_layer = 0 + self.end_layer = self.num_layers + return + + rank = get_pipeline_parallel_rank() + per_rank = self.num_layers // pp_world + rem = self.num_layers % pp_world + # Even split: extra layers go to the first `rem` ranks. + self.start_layer = rank * per_rank + min(rank, rem) + self.end_layer = self.start_layer + per_rank + (1 if rank < rem else 0) + + for i in range(self.num_layers): + if not (self.start_layer <= i < self.end_layer): + self.blocks[i] = PPMissingLayer() + + if not is_pipeline_first_stage(): + self.patch_embedding = PPMissingLayer() + self.patch_embedding_wancamctrl = PPMissingLayer() + self.c2ws_hidden_states_layer1 = PPMissingLayer() + self.c2ws_hidden_states_layer2 = PPMissingLayer() + + if not is_pipeline_last_stage(): + self.head = PPMissingLayer() + def forward( self, x, @@ -472,106 +558,110 @@ def forward( local_end_index=None, global_end_index=None, crossattn_cache=None, - current_start=0, + current_starts=0, max_attention_size=1_000_000, + intermediate_tensors: IntermediateTensors | None = None, ): r""" Run the diffusion model with kv caching. - See Algorithm 2 of CausVid paper https://arxiv.org/abs/2412.07772 for details. - This function will be run for num_frame times. - Process the latent frames one by one (1560 tokens each) + + On the first PP stage, ``x``/``y``/``dit_cond_dict`` are consumed to build + the token sequence; non-first stages take ``hidden_states`` (and the + camera-conditioned ``c2ws_plucker_emb`` if used) from ``intermediate_tensors``. + Non-last stages return an ``IntermediateTensors`` carrying ``hidden_states`` + (plus ``c2ws_plucker_emb`` so downstream stages can do cam injection). Args: - x (List[Tensor]): - List of input video tensors, each with shape [C_in, F, H, W] - t (Tensor): - Diffusion timesteps tensor of shape [B] - context (List[Tensor]): - List of text embeddings each with shape [L, C] - seq_len (`int`): - Maximum sequence length for positional encoding - y (List[Tensor], *optional*): - Conditional video inputs for image-to-video mode, same shape as x - dit_cond_dict (`dict`, *optional*, defaults to None): - Dictionary of conditioning signals. May contain key ``c2ws_plucker_emb`` - with camera Plucker embeddings of shape [B, C, F, H, W] for camera control. - kv_cache (`list[dict]`, *optional*, defaults to None): - Per-layer self-attention KV cache. Each dict contains keys ``k``, ``v`` - (Tensor of shape [B, kv_size, num_heads, head_dim]), ``global_end_index``, - and ``local_end_index`` (scalar Tensors tracking cache position). - crossattn_cache (`list[dict]`, *optional*, defaults to None): - Per-layer cross-attention KV cache. Each dict contains keys ``k``, ``v`` - (Tensor of shape [B, text_len, num_heads, head_dim]) and ``is_init`` (bool). - current_start (`int`, *optional*, defaults to 0): - Token offset of the current chunk in the full sequence. Used to index - into the KV cache and compute positional embeddings correctly. - max_attention_size (`int`, *optional*, defaults to 1_000_000): - Maximum number of KV tokens each query can attend to. Limits the - effective context window of self-attention to control memory usage. + current_starts (int | list[int]): per-row absolute token offset. + int broadcasts to all rows. + intermediate_tensors: per-stage hidden state from the previous PP rank. Returns: - List[Tensor]: - List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8] + list[Tensor] on last PP stage; IntermediateTensors elsewhere. """ - if self.model_type == "i2v": + if self.model_type == "i2v" and is_pipeline_first_stage(): assert y is not None # params - device = self.patch_embedding.weight.device + first_stage = is_pipeline_first_stage() + last_stage = is_pipeline_last_stage() + # `freqs` lives as a plain attribute (not a buffer) — move it to the + # device of the first parameter we can find on this stage. + first_param = next(self.parameters()) + device = first_param.device if self.freqs.device != device: self.freqs = self.freqs.to(device) - if y is not None: - x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)] + if first_stage: + if y is not None: + x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)] + x = [self.patch_embedding(u.unsqueeze(0)) for u in x] + grid_sizes = torch.stack( + [torch.tensor(u.shape[2:], dtype=torch.long, device=device) for u in x] + ) + x = [u.flatten(2).transpose(1, 2) for u in x] + seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long, device=device) + assert seq_lens.max() <= seq_len + x = torch.cat(x) + else: + assert intermediate_tensors is not None, "non-first PP stage requires intermediate_tensors" + x = intermediate_tensors["hidden_states"] + grid_sizes = intermediate_tensors["grid_sizes"] + seq_lens = intermediate_tensors["seq_lens"] + + B = x.shape[0] + s = x.shape[1] - # embeddings - x = [self.patch_embedding(u.unsqueeze(0)) for u in x] - grid_sizes = torch.stack([torch.tensor(u.shape[2:], dtype=torch.long) for u in x]) - x = [u.flatten(2).transpose(1, 2) for u in x] - seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long) - assert seq_lens.max() <= seq_len - x = torch.cat(x) - - # time embeddings - if t.dim() == 1: - t = t.expand(t.size(0), seq_lens) + # Per-row time embeddings: same timestep replicated across this row's tokens. with torch.amp.autocast("cuda", dtype=torch.float32): - bt = t.size(0) - t = t.flatten() - e = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, t).unflatten(0, (bt, seq_lens)).float()) + if t.dim() == 1: + t_full = t.unsqueeze(1).expand(B, s).contiguous() + else: + t_full = t + bt, btn = t_full.shape + t_flat = t_full.flatten() + e = self.time_embedding( + sinusoidal_embedding_1d(self.freq_dim, t_flat).unflatten(0, (bt, btn)).float() + ) e0 = self.time_projection(e).unflatten(2, (6, self.dim)) assert e.dtype == torch.float32 and e0.dtype == torch.float32 - # context + # context — text embedding runs on every stage (each block has cross-attn). context_lens = None context = self.text_embedding( torch.stack([torch.cat([u, u.new_zeros(self.text_len - u.size(0), u.size(1))]) for u in context]) ) - # cam - if dit_cond_dict is not None and "c2ws_plucker_emb" in dit_cond_dict: - c2ws_plucker_emb = dit_cond_dict["c2ws_plucker_emb"] - c2ws_plucker_emb = [ - rearrange( - i, - "1 c (f c1) (h c2) (w c3) -> 1 (f h w) (c c1 c2 c3)", - c1=self.patch_size[0], - c2=self.patch_size[1], - c3=self.patch_size[2], + # cam Plucker — processed on first stage, then forwarded via intermediate_tensors + # so downstream stages re-use the same embedding for in-block cam injection. + if first_stage: + if dit_cond_dict is not None and "c2ws_plucker_emb" in dit_cond_dict: + c2ws_plucker_emb = dit_cond_dict["c2ws_plucker_emb"] + c2ws_plucker_emb = [ + rearrange( + i, + "1 c (f c1) (h c2) (w c3) -> 1 (f h w) (c c1 c2 c3)", + c1=self.patch_size[0], + c2=self.patch_size[1], + c3=self.patch_size[2], + ) + for i in c2ws_plucker_emb + ] + c2ws_plucker_emb = torch.cat(c2ws_plucker_emb, dim=0) + + c2ws_plucker_emb = self.patch_embedding_wancamctrl(c2ws_plucker_emb) + c2ws_hidden_states = self.c2ws_hidden_states_layer2( + torch_F.silu(self.c2ws_hidden_states_layer1(c2ws_plucker_emb)) ) - for i in c2ws_plucker_emb - ] - c2ws_plucker_emb = torch.cat(c2ws_plucker_emb, dim=1) # [1, (L1+...+Ln), C] - - c2ws_plucker_emb = self.patch_embedding_wancamctrl(c2ws_plucker_emb) - c2ws_hidden_states = self.c2ws_hidden_states_layer2( - torch_F.silu(self.c2ws_hidden_states_layer1(c2ws_plucker_emb)) - ) - dit_cond_dict = dict(dit_cond_dict) - dit_cond_dict["c2ws_plucker_emb"] = c2ws_plucker_emb + c2ws_hidden_states + dit_cond_dict = dict(dit_cond_dict) + dit_cond_dict["c2ws_plucker_emb"] = c2ws_plucker_emb + c2ws_hidden_states + else: + if "c2ws_plucker_emb" in intermediate_tensors.tensors: + dit_cond_dict = {"c2ws_plucker_emb": intermediate_tensors["c2ws_plucker_emb"]} + else: + dit_cond_dict = None - # arguments kwargs = dict( e=e0, seq_lens=seq_lens, @@ -583,24 +673,34 @@ def forward( max_attention_size=max_attention_size, ) - for block_index, block in enumerate(self.blocks): + # Iterate this rank's blocks. kv_cache / crossattn_cache / *_end_index are + # sized to (end_layer - start_layer) — index locally. + for local_idx, block in enumerate(self.blocks[self.start_layer : self.end_layer]): kwargs.update( { - "kv_cache": kv_cache[block_index], - "crossattn_cache": crossattn_cache[block_index], - "local_end_index": local_end_index[block_index], - "global_end_index": global_end_index[block_index], - "current_start": current_start, + "kv_cache": kv_cache[local_idx], + "crossattn_cache": crossattn_cache[local_idx], + "local_end_index": local_end_index[local_idx], + "global_end_index": global_end_index[local_idx], + "current_starts": current_starts, } ) x = block(x, **kwargs) - # head + if not last_stage: + model_dtype = next(self.parameters()).dtype + it = { + "hidden_states": x.to(model_dtype), + "grid_sizes": grid_sizes, + "seq_lens": seq_lens, + } + if dit_cond_dict is not None and "c2ws_plucker_emb" in dit_cond_dict: + it["c2ws_plucker_emb"] = dit_cond_dict["c2ws_plucker_emb"].to(model_dtype) + return IntermediateTensors(it) + + # head + unpatchify only on the last PP stage x = self.head(x, e) - - # unpatchify x = self.unpatchify(x, grid_sizes) - return [u.float() for u in x] def unpatchify(self, x, grid_sizes): From 10a5e4084615d276cc83391d8339748cd0275720 Mon Sep 17 00:00:00 2001 From: Mahdi Nasser <94046147+mnasser02@users.noreply.github.com> Date: Sat, 30 May 2026 01:13:19 +0200 Subject: [PATCH 52/53] update tests Signed-off-by: Mahdi Nasser <94046147+mnasser02@users.noreply.github.com> --- .../wan2_2/test_wan22_step_execution.py | 505 ------------------ .../test_diffusion_micro_step_pipeline.py | 155 ++---- tests/diffusion/test_diffusion_scheduler.py | 16 +- 3 files changed, 67 insertions(+), 609 deletions(-) delete mode 100644 tests/diffusion/models/wan2_2/test_wan22_step_execution.py diff --git a/tests/diffusion/models/wan2_2/test_wan22_step_execution.py b/tests/diffusion/models/wan2_2/test_wan22_step_execution.py deleted file mode 100644 index 66409c982a6..00000000000 --- a/tests/diffusion/models/wan2_2/test_wan22_step_execution.py +++ /dev/null @@ -1,505 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -"""Unit tests for Wan2.2 SupportsStepExecution protocol implementation. - -Tests use lightweight mocks (no real model weights) to verify: -- Protocol compliance (class flag, method presence) -- Helper correctness (_resolve_generation_params, _select_model_for_timestep) -- Step execution decomposition matches monolithic forward() -- I2V mode latent input preparation -""" - -from __future__ import annotations - -from unittest.mock import MagicMock, patch - -import torch - -from vllm_omni.diffusion.worker.utils import DiffusionRequestState - -# --------------------------------------------------------------------------- -# Shared test utilities -# --------------------------------------------------------------------------- - - -def _make_sampling(**overrides): - """Create a mock sampling params object.""" - sampling = MagicMock() - sampling.height = overrides.get("height", 480) - sampling.width = overrides.get("width", 832) - sampling.num_frames = overrides.get("num_frames", 81) - sampling.num_inference_steps = overrides.get("num_inference_steps", 4) - sampling.guidance_scale = overrides.get("guidance_scale", 1.0) - sampling.guidance_scale_provided = overrides.get("guidance_scale_provided", True) - sampling.guidance_scale_2 = overrides.get("guidance_scale_2", None) - sampling.boundary_ratio = overrides.get("boundary_ratio", None) - sampling.num_outputs_per_prompt = overrides.get("num_outputs_per_prompt", 1) - sampling.max_sequence_length = overrides.get("max_sequence_length", 512) - sampling.seed = overrides.get("seed", 42) - sampling.generator = None - sampling.latents = None - return sampling - - -def _make_state(**overrides): - """Create a DiffusionRequestState with mock sampling.""" - return DiffusionRequestState( - req_id="test", - sampling=_make_sampling(**overrides), - prompts=overrides.get("prompts", ["test prompt"]), - ) - - -def _make_pipeline_stub(): - """Create a minimal Wan22Pipeline without __init__ (no model weights).""" - from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2 import Wan22Pipeline - - pipeline = object.__new__(Wan22Pipeline) - torch.nn.Module.__init__(pipeline) - pipeline.vae_scale_factor_spatial = 8 - pipeline.vae_scale_factor_temporal = 4 - pipeline.boundary_ratio = 0.875 - pipeline.expand_timesteps = False - pipeline._guidance_scale = None - pipeline._guidance_scale_2 = None - pipeline._current_timestep = None - pipeline._pp_send_work_list = [] - - config = MagicMock() - config.patch_size = (1, 2, 2) - config.in_channels = 16 - config.out_channels = 16 - pipeline.transformer_config = config - - pipeline.device = torch.device("cpu") - - return pipeline - - -# --------------------------------------------------------------------------- -# 1. Protocol compliance -# --------------------------------------------------------------------------- - - -class TestWan22SupportsStepExecution: - """Verify the class-level protocol flag and method signatures.""" - - def test_class_var_is_true(self): - from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2 import Wan22Pipeline - - assert hasattr(Wan22Pipeline, "supports_step_execution") - assert Wan22Pipeline.supports_step_execution is True - - def test_has_required_methods(self): - from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2 import Wan22Pipeline - - for method_name in ("prepare_encode", "denoise_step", "step_scheduler", "post_decode"): - assert hasattr(Wan22Pipeline, method_name), f"Missing method: {method_name}" - -# --------------------------------------------------------------------------- -# 2. _resolve_generation_params helper -# --------------------------------------------------------------------------- - - -class TestResolveGenerationParams: - """Verify parameter resolution and alignment logic.""" - - def test_dimensions_aligned_to_mod_value(self): - pipeline = _make_pipeline_stub() - pipeline.transformer = MagicMock(dtype=torch.bfloat16) - pipeline.transformer_2 = None - - for height in range(400, 600, 50): - for width in range(700, 900, 50): - state = _make_state(height=height, width=width) - params = pipeline._resolve_generation_params(state) - # mod_value = 8 * 2 = 16 - assert params["height"] % 16 == 0 - assert params["width"] % 16 == 0 - - def test_num_frames_aligned_to_vae_temporal(self): - pipeline = _make_pipeline_stub() - pipeline.transformer = MagicMock(dtype=torch.bfloat16) - pipeline.transformer_2 = None - - for num_frames in range(1, 200): - state = _make_state(num_frames=num_frames) - params = pipeline._resolve_generation_params(state) - assert params["num_frames"] % pipeline.vae_scale_factor_temporal == 1 or params["num_frames"] == 1 - - - -# --------------------------------------------------------------------------- -# 3. _select_model_for_timestep helper -# --------------------------------------------------------------------------- - - -class TestSelectModelForTimestep: - def _make(self): - pipeline = _make_pipeline_stub() - pipeline.transformer = MagicMock(name="transformer") - pipeline.transformer_2 = MagicMock(name="transformer_2") - pipeline._guidance_scale = 4.0 - pipeline._guidance_scale_2 = 7.0 - return pipeline - - def test_high_noise_uses_transformer(self): - pipeline = self._make() - model, scale = pipeline._select_model_for_timestep(torch.tensor(800.0), boundary_timestep=500.0) - assert model is pipeline.transformer - assert scale == 4.0 - - def test_low_noise_uses_transformer_2(self): - pipeline = self._make() - model, scale = pipeline._select_model_for_timestep(torch.tensor(300.0), boundary_timestep=500.0) - assert model is pipeline.transformer_2 - assert scale == 7.0 - - def test_no_boundary_uses_transformer(self): - pipeline = self._make() - model, _ = pipeline._select_model_for_timestep(torch.tensor(300.0), boundary_timestep=None) - assert model is pipeline.transformer - - def test_fallback_when_transformer_none(self): - pipeline = self._make() - pipeline.transformer = None - model, _ = pipeline._select_model_for_timestep(torch.tensor(800.0), boundary_timestep=500.0) - assert model is pipeline.transformer_2 - - -# --------------------------------------------------------------------------- -# 4.1 Step execution decomposition matches forward() -# --------------------------------------------------------------------------- - - -class _FakeScheduler: - """Minimal scheduler that applies a deterministic update: latents -= 0.1 * noise_pred.""" - - def __init__(self, timesteps: torch.Tensor): - self.timesteps = timesteps - self._step_index = 0 - self.config = MagicMock() - self.config.num_train_timesteps = 1000 - - def set_timesteps(self, _num_steps, device=None): - pass # timesteps already set - - def step(self, noise_pred, _t, latents, return_dict=False): - self._step_index += 1 - return (latents - 0.1 * noise_pred,) - - -class _FakeTransformer(torch.nn.Module): - """Deterministic transformer: output = input * 0.5 (applied to hidden_states).""" - - def __init__(self): - super().__init__() - self._dummy = torch.nn.Parameter(torch.tensor(1.0)) - - @property - def dtype(self): - return torch.float32 - - def forward(self, hidden_states, timestep, encoder_hidden_states, intermediate_tensors=None, **kwargs): - # Simulate noise prediction: scale down hidden_states - noise_pred = hidden_states * 0.5 - return (noise_pred,) - - -def _patch_parallel_state(): - """Context manager that patches PP and CFG parallel state to single-GPU (world_size=1).""" - from contextlib import ExitStack - - stack = ExitStack() - stack.enter_context( - patch("vllm_omni.diffusion.distributed.pp_parallel.get_pipeline_parallel_world_size", return_value=1) - ) - stack.enter_context( - patch("vllm_omni.diffusion.distributed.cfg_parallel.get_classifier_free_guidance_world_size", return_value=1) - ) - return stack - - -class TestDenoiseStepCorrectness: - """Verify prepare_encode → denoise_step x N → step_scheduler x N → post_decode - produces the same latent trajectory as running the equivalent loop manually.""" - - def _make_pipeline(self): - pipeline = _make_pipeline_stub() - pipeline.transformer = _FakeTransformer() - pipeline.transformer_2 = None - pipeline.expand_timesteps = False - - timesteps = torch.tensor([900.0, 700.0, 500.0, 300.0]) - pipeline.scheduler = _FakeScheduler(timesteps) - - # Mock encode_prompt to return fixed embeddings - prompt_embeds = torch.randn(1, 10, 64) - pipeline.encode_prompt = MagicMock(return_value=(prompt_embeds, None)) - - # Mock VAE decode (identity) - vae = MagicMock() - vae.dtype = torch.float32 - vae.config.latents_mean = [0.0] * 16 - vae.config.latents_std = [1.0] * 16 - vae.config.z_dim = 16 - vae.decode = MagicMock(side_effect=lambda x, **kw: (x,)) - pipeline.vae = vae - - # Mock prepare_latents to return seeded noise - torch.manual_seed(123) - fixed_latents = torch.randn(1, 16, 21, 30, 52) - pipeline.prepare_latents = MagicMock(return_value=fixed_latents.clone()) - - return pipeline, fixed_latents.clone(), prompt_embeds - - def test_latent_trajectory_matches(self): - """Step-by-step execution produces the same final latents as a manual loop.""" - pipeline, initial_latents, prompt_embeds = self._make_pipeline() - timesteps = pipeline.scheduler.timesteps - - # ── Manual baseline loop ── - latents = initial_latents.clone() - for t in timesteps: - latent_input = latents.to(torch.float32) - noise_pred = pipeline.transformer( - hidden_states=latent_input, - timestep=t.expand(1), - encoder_hidden_states=prompt_embeds, - )[0] - latents = latents - 0.1 * noise_pred - baseline_latents = latents - - # ── Step execution path ── - state = _make_state(num_inference_steps=4) - state = pipeline.prepare_encode(state) - - with _patch_parallel_state(): - while not state.denoise_completed: - noise_pred = pipeline.denoise_step(state) - pipeline.step_scheduler(state, noise_pred) - - assert state.step_index == len(timesteps) - torch.testing.assert_close( - state.latents, - baseline_latents, - rtol=1e-5, - atol=1e-5, - msg="Step execution latents diverged from manual baseline", - ) - - def test_post_decode_calls_vae(self): - """post_decode invokes VAE decode and returns DiffusionOutput.""" - from vllm_omni.diffusion.data import DiffusionOutput - - pipeline, _, _ = self._make_pipeline() - - state = _make_state(num_inference_steps=4) - state = pipeline.prepare_encode(state) - with _patch_parallel_state(): - while not state.denoise_completed: - noise_pred = pipeline.denoise_step(state) - pipeline.step_scheduler(state, noise_pred) - - mock_platform = MagicMock() - mock_platform.is_available.return_value = False - with ( - patch.object(type(pipeline), "sync_pp_send"), - patch("vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2.current_omni_platform", mock_platform), - ): - result = pipeline.post_decode(state) - - assert isinstance(result, DiffusionOutput) - assert result.output is not None - pipeline.vae.decode.assert_called_once() - - def test_step_count_matches_timesteps(self): - """Exactly len(timesteps) steps are executed.""" - pipeline, _, _ = self._make_pipeline() - state = _make_state(num_inference_steps=4) - state = pipeline.prepare_encode(state) - - step_count = 0 - with _patch_parallel_state(): - while not state.denoise_completed: - noise_pred = pipeline.denoise_step(state) - pipeline.step_scheduler(state, noise_pred) - step_count += 1 - - assert step_count == 4 - - def test_scheduler_is_deepcopied(self): - """Each request gets its own scheduler copy, not a shared reference.""" - pipeline, _, _ = self._make_pipeline() - original_scheduler = pipeline.scheduler - state = _make_state(num_inference_steps=4) - state = pipeline.prepare_encode(state) - assert state.scheduler is not original_scheduler - - -# --------------------------------------------------------------------------- -# 4.3 _prepare_latent_input I2V mode -# --------------------------------------------------------------------------- - - -class TestPrepareLatentInputI2V: - """Verify I2V mode latent blending and timestep expansion.""" - - def _make_pipeline_and_state(self): - pipeline = _make_pipeline_stub() - pipeline.transformer = MagicMock(dtype=torch.float32) - pipeline.transformer_2 = None - - # Latents: [B=1, C=16, T=5, H=8, W=10] - latents = torch.randn(1, 16, 5, 8, 10) - # Condition: same shape, different values - latent_condition = torch.randn(1, 16, 5, 8, 10) - # Mask: 0 for first frame, 1 for rest - first_frame_mask = torch.ones(1, 1, 5, 8, 10) - first_frame_mask[:, :, 0] = 0 - - state = _make_state() - state.latents = latents - state.extra["expand_timesteps"] = True - state.extra["latent_condition"] = latent_condition - state.extra["first_frame_mask"] = first_frame_mask - - return pipeline, state, latents, latent_condition, first_frame_mask - - def test_i2v_blends_condition_with_latents(self): - """First frame uses condition, remaining frames use latents.""" - pipeline, state, latents, condition, mask = self._make_pipeline_and_state() - t = torch.tensor(500.0) - - latent_input, _ = pipeline._prepare_latent_input(state, t, torch.float32) - - # First frame (mask=0): should be condition - expected_first = condition[:, :, 0] - torch.testing.assert_close(latent_input[:, :, 0], expected_first, rtol=1e-5, atol=1e-5) - - # Remaining frames (mask=1): should be latents - expected_rest = latents[:, :, 1:] - torch.testing.assert_close(latent_input[:, :, 1:], expected_rest, rtol=1e-5, atol=1e-5) - - def test_i2v_timestep_expansion(self): - """Timestep is expanded per-patch: 0 for condition patches, t for noise patches.""" - pipeline, state, _, _, mask = self._make_pipeline_and_state() - t = torch.tensor(500.0) - - _, timestep_tensor = pipeline._prepare_latent_input(state, t, torch.float32) - - # patch_size = (1, 2, 2) → patch dims: T=5, H=4, W=5 - # Sequence length = 5 * 4 * 5 = 100 - assert timestep_tensor.shape[0] == 1 # batch - assert timestep_tensor.shape[1] == 5 * 4 * 5 # flattened patch sequence - - # First frame patches (first 4*5=20) should have timestep 0 - first_frame_patches = timestep_tensor[0, : 4 * 5] - assert (first_frame_patches == 0).all(), "First frame patches should have timestep 0" - - # Remaining patches should have timestep = 500 - rest_patches = timestep_tensor[0, 4 * 5 :] - assert (rest_patches == 500.0).all(), "Non-first-frame patches should have timestep t" - - def test_t2v_mode_passthrough(self): - """T2V mode: latents pass through unchanged, timestep is broadcast.""" - pipeline = _make_pipeline_stub() - pipeline.transformer = MagicMock(dtype=torch.float32) - pipeline.transformer_2 = None - - latents = torch.randn(2, 16, 5, 8, 10) - state = _make_state() - state.latents = latents - # No I2V extras → T2V mode - - t = torch.tensor(500.0) - latent_input, timestep_tensor = pipeline._prepare_latent_input(state, t, torch.float32) - - torch.testing.assert_close(latent_input, latents) - assert timestep_tensor.shape == (2,) - assert (timestep_tensor == 500.0).all() - - -# --------------------------------------------------------------------------- -# 5. denoise_step vs Wan22Pipeline.forward() -# --------------------------------------------------------------------------- - - -def _make_req(num_inference_steps: int = 4): - """Create a minimal T2V OmniDiffusionRequest for forward() comparison tests.""" - from vllm_omni.diffusion.request import OmniDiffusionRequest - from vllm_omni.inputs.data import OmniDiffusionSamplingParams - - sampling = OmniDiffusionSamplingParams( - height=480, - width=832, - num_frames=81, - num_inference_steps=num_inference_steps, - guidance_scale=1.0, - max_sequence_length=512, - num_outputs_per_prompt=1, - seed=42, - ) - return OmniDiffusionRequest(prompts=["test prompt"], sampling_params=sampling) - - -class TestDenoiseStepMatchesForward: - """Verify denoise_step matches Wan22Pipeline.forward().""" - - def _make_pipeline(self): - pipeline = _make_pipeline_stub() - pipeline.transformer = _FakeTransformer() - pipeline.transformer_2 = None - pipeline.expand_timesteps = False - - timesteps = torch.tensor([900.0, 700.0, 500.0, 300.0]) - pipeline.scheduler = _FakeScheduler(timesteps) - - torch.manual_seed(17) - prompt_embeds = torch.randn(1, 10, 64) - pipeline.encode_prompt = MagicMock(return_value=(prompt_embeds, None)) - - vae = MagicMock() - vae.dtype = torch.float32 - vae.config.latents_mean = [0.0] * 16 - vae.config.latents_std = [1.0] * 16 - vae.config.z_dim = 16 - pipeline.vae = vae - - torch.manual_seed(13) - fixed_latents = torch.randn(1, 16, 21, 30, 52) - pipeline.prepare_latents = MagicMock(return_value=fixed_latents.clone()) - - return pipeline - - def _run_forward(self, pipeline): - """Run pipeline.forward() in T2V mode and return the final latents.""" - req = _make_req(num_inference_steps=4) - mock_platform = MagicMock() - mock_platform.is_available.return_value = False - with ( - patch("vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2.current_omni_platform", mock_platform), - _patch_parallel_state(), - ): - result = pipeline.forward(req, output_type="latent") - return result.output - - def test_denoise_step_matches_forward(self): - """Full step-execution loop (prepare_encode → denoise_step x N → step_scheduler x N) - produces the same final latents as Wan22Pipeline.forward().""" - pipeline = self._make_pipeline() - - # Reference: monolithic forward() - fwd_latents = self._run_forward(pipeline) - - # Step-execution path using the same pipeline (same latent / prompt_embed mocks) - state = _make_state(num_inference_steps=4) - state = pipeline.prepare_encode(state) - with _patch_parallel_state(): - while not state.denoise_completed: - noise_pred = pipeline.denoise_step(state) - pipeline.step_scheduler(state, noise_pred) - - torch.testing.assert_close(state.latents, fwd_latents, rtol=1e-5, atol=1e-5) - diff --git a/tests/diffusion/test_diffusion_micro_step_pipeline.py b/tests/diffusion/test_diffusion_micro_step_pipeline.py index 718c02b7f49..ebaaa6afa1e 100644 --- a/tests/diffusion/test_diffusion_micro_step_pipeline.py +++ b/tests/diffusion/test_diffusion_micro_step_pipeline.py @@ -1,4 +1,4 @@ -# SPDX-License-Identifier: Apache-2.0 +# SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Tests for micro-step level diffusion execution across runner / worker / executor / engine.""" @@ -15,8 +15,8 @@ from vllm_omni.diffusion.sched.interface import ( CachedRequestData, DiffusionSchedulerOutput, + Layout, NewRequestData, - Rank0Layout, RankTask, ) from vllm_omni.diffusion.worker.diffusion_model_runner import DiffusionModelRunner @@ -40,6 +40,8 @@ def __init__(self, rank_in_group: int = 0, world_size: int = 1): self.is_last_rank = rank_in_group == world_size - 1 self.prev_rank = (rank_in_group - 1) % world_size self.next_rank = (rank_in_group + 1) % world_size + self.group_prev_rank = (rank_in_group - 1) % world_size + self.group_next_rank = (rank_in_group + 1) % world_size self.reset_calls = 0 def reset_buffer(self) -> None: @@ -59,7 +61,6 @@ def __init__(self, num_steps: int = 1): self.decode_calls = 0 self.prefetch_calls = 0 self.encode_calls = 0 - self.is_buffer_setup = False def prepare_encode(self, state, **kwargs): del kwargs @@ -72,12 +73,11 @@ def prepare_encode(self, state, **kwargs): def encode_chunk_inputs(self, state, new_idxs): del state self.encode_calls += 1 - return [torch.zeros((1, 1, 1, 1, 1)) for _ in new_idxs] + return torch.zeros((len(new_idxs), 1, 1, 1, 1)) def set_pp_recv_dict_buffers(self, state, **kwargs): del state, kwargs self.set_buffer_calls += 1 - self.is_buffer_setup = True def denoise_step(self, state, **kwargs): del state, kwargs @@ -97,7 +97,7 @@ def post_decode(self, state, **kwargs): b = state.latents.shape[0] if state.latents.ndim > 0 else 1 return DiffusionOutput(output=torch.ones(b, 1, 1, 1, 1, dtype=torch.float32)) - def prefetch_its(self, state, **kwargs): + def prefetch_tensors(self, state, **kwargs): del state, kwargs self.prefetch_calls += 1 @@ -123,7 +123,7 @@ def _make_micro_request( req_id: str = "req-1", *, num_inference_steps: int = 1, - num_frames: int = 1, + num_chunks: int = 1, chunk_frames: int = 1, ): return SimpleNamespace( @@ -135,7 +135,8 @@ def _make_micro_request( generator_device=None, num_inference_steps=num_inference_steps, chunk_frames=chunk_frames, - num_frames=num_frames, + num_chunks=num_chunks, + num_frames=num_chunks * chunk_frames, slo_fps=None, slo_max_batch=8, lora_request=None, @@ -162,12 +163,12 @@ def _make_runner(pp_size: int = 1, num_steps: int = 1): def _make_layout( *, - n_circulating: int = 0, + circulating_idxs: list[int] | None = None, finished_idxs: list[int] | None = None, new_idxs: list[int] | None = None, -) -> Rank0Layout: - return Rank0Layout( - n_circulating=n_circulating, +) -> Layout: + return Layout( + circulating_idxs=circulating_idxs or [], finished_idxs=finished_idxs or [], new_idxs=new_idxs or [], ) @@ -178,14 +179,16 @@ def _make_micro_scheduler_output( req=None, sched_req_id: str = "req-1", step_id: int = 0, - assignment=None, + chunk_indices: list[int] | None = None, is_new: bool = True, finished_req_ids=None, - rank0_layout: Rank0Layout | None = None, + layout: Layout | None = None, ): - if assignment is None: - assignment = [RankTask(sched_req_id=sched_req_id, chunk_indices=[0])] - rank0_layouts = {sched_req_id: rank0_layout} if rank0_layout is not None else None + if layout is None: + layout = _make_layout() + if chunk_indices is None: + chunk_indices = [0] + assignment = [RankTask(sched_req_id=sched_req_id, chunk_indices=chunk_indices, layout=layout)] if is_new and req is not None: new_reqs = [NewRequestData(sched_req_id=sched_req_id, req=req)] cached_reqs = CachedRequestData.make_empty() @@ -200,7 +203,6 @@ def _make_micro_scheduler_output( num_running_reqs=1, num_waiting_reqs=0, assignment=assignment, - rank0_layouts=rank0_layouts, ) @@ -220,14 +222,14 @@ class TestRunner: def test_completes_single_chunk_request(self, monkeypatch): runner = _make_runner(pp_size=1, num_steps=1) _patch_runtime(monkeypatch, runner) - req = _make_micro_request(num_inference_steps=1, num_frames=1) + req = _make_micro_request(num_inference_steps=1, num_chunks=1) out0 = DiffusionModelRunner.execute_micro_step( runner, _make_micro_scheduler_output( req=req, step_id=0, - assignment=[RankTask(sched_req_id="req-1", chunk_indices=[0])], - rank0_layout=_make_layout(new_idxs=[0]), + chunk_indices=[0], + layout=_make_layout(new_idxs=[0]), ), ) assert out0.req_id == "req-1" @@ -238,8 +240,8 @@ def test_completes_single_chunk_request(self, monkeypatch): runner, _make_micro_scheduler_output( sched_req_id="req-1", step_id=1, - assignment=[None], is_new=False, - rank0_layout=_make_layout(finished_idxs=[0]), + chunk_indices=[], is_new=False, + layout=_make_layout(finished_idxs=[0]), ), ) assert out1.finished is True @@ -255,23 +257,23 @@ def test_completes_single_chunk_request(self, monkeypatch): def test_completes_multi_chunk_request(self, monkeypatch): runner = _make_runner(pp_size=1, num_steps=1) _patch_runtime(monkeypatch, runner) - req = _make_micro_request(num_inference_steps=1, num_frames=2) + req = _make_micro_request(num_inference_steps=1, num_chunks=2) DiffusionModelRunner.execute_micro_step( runner, _make_micro_scheduler_output( req=req, step_id=0, - assignment=[RankTask(sched_req_id="req-1", chunk_indices=[0])], - rank0_layout=_make_layout(new_idxs=[0]), + chunk_indices=[0], + layout=_make_layout(new_idxs=[0]), ), ) out1 = DiffusionModelRunner.execute_micro_step( runner, _make_micro_scheduler_output( sched_req_id="req-1", step_id=1, - assignment=[RankTask(sched_req_id="req-1", chunk_indices=[1])], + chunk_indices=[1], is_new=False, - rank0_layout=_make_layout(finished_idxs=[0], new_idxs=[1]), + layout=_make_layout(finished_idxs=[0], new_idxs=[1]), ), ) assert out1.finished is False @@ -280,8 +282,8 @@ def test_completes_multi_chunk_request(self, monkeypatch): runner, _make_micro_scheduler_output( sched_req_id="req-1", step_id=2, - assignment=[None], is_new=False, - rank0_layout=_make_layout(finished_idxs=[1]), + chunk_indices=[], is_new=False, + layout=_make_layout(finished_idxs=[1]), ), ) assert out2.finished is True @@ -295,14 +297,14 @@ def test_completes_multi_chunk_request(self, monkeypatch): def test_re_admits_circulating_chunk(self, monkeypatch): runner = _make_runner(pp_size=1, num_steps=2) _patch_runtime(monkeypatch, runner) - req = _make_micro_request(num_inference_steps=2, num_frames=1) + req = _make_micro_request(num_inference_steps=2, num_chunks=1) out0 = DiffusionModelRunner.execute_micro_step( runner, _make_micro_scheduler_output( req=req, step_id=0, - assignment=[RankTask(sched_req_id="req-1", chunk_indices=[0])], - rank0_layout=_make_layout(new_idxs=[0]), + chunk_indices=[0], + layout=_make_layout(new_idxs=[0]), ), ) assert out0.finished is False @@ -311,9 +313,9 @@ def test_re_admits_circulating_chunk(self, monkeypatch): runner, _make_micro_scheduler_output( sched_req_id="req-1", step_id=1, - assignment=[RankTask(sched_req_id="req-1", chunk_indices=[0])], + chunk_indices=[0], is_new=False, - rank0_layout=_make_layout(n_circulating=1), + layout=_make_layout(circulating_idxs=[0]), ), ) assert out1.finished is False @@ -323,8 +325,8 @@ def test_re_admits_circulating_chunk(self, monkeypatch): runner, _make_micro_scheduler_output( sched_req_id="req-1", step_id=2, - assignment=[None], is_new=False, - rank0_layout=_make_layout(finished_idxs=[0]), + chunk_indices=[], is_new=False, + layout=_make_layout(finished_idxs=[0]), ), ) assert out2.finished is True @@ -333,14 +335,14 @@ def test_re_admits_circulating_chunk(self, monkeypatch): def test_empty_layout_is_a_no_op(self, monkeypatch): runner = _make_runner(pp_size=1, num_steps=1) _patch_runtime(monkeypatch, runner) - req = _make_micro_request(num_inference_steps=1, num_frames=1) + req = _make_micro_request(num_inference_steps=1, num_chunks=1) DiffusionModelRunner.execute_micro_step( runner, _make_micro_scheduler_output( req=req, step_id=0, - assignment=[RankTask(sched_req_id="req-1", chunk_indices=[0])], - rank0_layout=_make_layout(new_idxs=[0]), + chunk_indices=[0], + layout=_make_layout(new_idxs=[0]), ), ) denoise_calls_before = runner.pipeline.denoise_calls @@ -349,8 +351,8 @@ def test_empty_layout_is_a_no_op(self, monkeypatch): runner, _make_micro_scheduler_output( sched_req_id="req-1", step_id=1, - assignment=[None], is_new=False, - rank0_layout=_make_layout(), + chunk_indices=[], is_new=False, + layout=_make_layout(), ), ) assert out.req_id == "req-1" @@ -362,14 +364,14 @@ def test_interrupt_marks_request_as_aborted(self, monkeypatch): runner = _make_runner(pp_size=1, num_steps=1) runner.pipeline = _InterruptingMicroStepPipeline(num_steps=1) _patch_runtime(monkeypatch, runner) - req = _make_micro_request(num_inference_steps=1, num_frames=1) + req = _make_micro_request(num_inference_steps=1, num_chunks=1) out = DiffusionModelRunner.execute_micro_step( runner, _make_micro_scheduler_output( req=req, step_id=0, - assignment=[RankTask(sched_req_id="req-1", chunk_indices=[0])], - rank0_layout=_make_layout(new_idxs=[0]), + chunk_indices=[0], + layout=_make_layout(new_idxs=[0]), ), ) assert out.req_id == "req-1" @@ -402,14 +404,14 @@ def test_rejects_cache_backend(self): def test_stamps_micro_step_wall_ns_on_rank0(self, monkeypatch): runner = _make_runner(pp_size=1, num_steps=1) _patch_runtime(monkeypatch, runner) - req = _make_micro_request(num_inference_steps=1, num_frames=1) + req = _make_micro_request(num_inference_steps=1, num_chunks=1) out = DiffusionModelRunner.execute_micro_step( runner, _make_micro_scheduler_output( req=req, step_id=0, - assignment=[RankTask(sched_req_id="req-1", chunk_indices=[0])], - rank0_layout=_make_layout(new_idxs=[0]), + chunk_indices=[0], + layout=_make_layout(new_idxs=[0]), ), ) assert out.micro_step_wall_ns is not None @@ -418,14 +420,14 @@ def test_stamps_micro_step_wall_ns_on_rank0(self, monkeypatch): def test_batch_two_runs_one_fused_forward(self, monkeypatch): runner = _make_runner(pp_size=1, num_steps=1) _patch_runtime(monkeypatch, runner) - req = _make_micro_request(num_inference_steps=1, num_frames=2) + req = _make_micro_request(num_inference_steps=1, num_chunks=2) out = DiffusionModelRunner.execute_micro_step( runner, _make_micro_scheduler_output( req=req, step_id=0, - assignment=[RankTask(sched_req_id="req-1", chunk_indices=[0, 1])], - rank0_layout=_make_layout(new_idxs=[0, 1]), + chunk_indices=[0, 1], + layout=_make_layout(new_idxs=[0, 1]), ), ) @@ -438,22 +440,22 @@ def test_batch_two_runs_one_fused_forward(self, monkeypatch): def test_batch_two_decodes_both_chunks_when_drain_completes(self, monkeypatch): runner = _make_runner(pp_size=1, num_steps=1) _patch_runtime(monkeypatch, runner) - req = _make_micro_request(num_inference_steps=1, num_frames=2) + req = _make_micro_request(num_inference_steps=1, num_chunks=2) DiffusionModelRunner.execute_micro_step( runner, _make_micro_scheduler_output( req=req, step_id=0, - assignment=[RankTask(sched_req_id="req-1", chunk_indices=[0, 1])], - rank0_layout=_make_layout(new_idxs=[0, 1]), + chunk_indices=[0, 1], + layout=_make_layout(new_idxs=[0, 1]), ), ) out = DiffusionModelRunner.execute_micro_step( runner, _make_micro_scheduler_output( sched_req_id="req-1", step_id=1, - assignment=[None], is_new=False, - rank0_layout=_make_layout(finished_idxs=[0, 1]), + chunk_indices=[], is_new=False, + layout=_make_layout(finished_idxs=[0, 1]), ), ) @@ -526,45 +528,6 @@ def test_rejects_lora_requests(self): # Executor # --------------------------------------------------------------------------- - -class TestSupportedPipelines: - """Micro-step protocol membership checks.""" - - def test_stub_pipeline_satisfies_protocol(self): - from vllm_omni.diffusion.models.interface import ( - SupportsMicroStepExecution, - SupportsStepExecution, - supports_micro_step_execution, - supports_step_execution, - ) - - pipeline = _MicroStepPipeline() - assert isinstance(pipeline, SupportsMicroStepExecution) is True - assert supports_micro_step_execution(pipeline) is True - # Micro-step protocol extends step protocol. - assert isinstance(pipeline, SupportsStepExecution) is True - assert supports_step_execution(pipeline) is True - - def test_wan22_supports_micro_step_execution(self): - from vllm_omni.diffusion.models.interface import ( - SupportsMicroStepExecution, - supports_micro_step_execution, - ) - - try: - from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2 import Wan22Pipeline - except (RuntimeError, ImportError) as exc: - pytest.skip(f"Wan22Pipeline import not available on this platform: {exc}") - - # Avoid loading weights; protocol membership is a class-contract check. - pipeline = object.__new__(Wan22Pipeline) - - assert pipeline.supports_step_execution is True - assert pipeline.supports_micro_step_execution is True - assert supports_micro_step_execution(pipeline) is True - assert isinstance(pipeline, SupportsMicroStepExecution) is True - - class TestExecutor: """MultiprocDiffusionExecutor.execute_micro_step collects rank-0's reply.""" @@ -590,4 +553,4 @@ def test_rejects_unexpected_reply_type(self, mocker: MockerFixture): sched_output = _make_micro_scheduler_output(req=_make_micro_request()) with pytest.raises(RuntimeError, match="Unexpected response type"): - MultiprocDiffusionExecutor.execute_micro_step(executor, sched_output) \ No newline at end of file + MultiprocDiffusionExecutor.execute_micro_step(executor, sched_output) diff --git a/tests/diffusion/test_diffusion_scheduler.py b/tests/diffusion/test_diffusion_scheduler.py index 554ba92bf4d..52dbe7bb560 100644 --- a/tests/diffusion/test_diffusion_scheduler.py +++ b/tests/diffusion/test_diffusion_scheduler.py @@ -975,7 +975,6 @@ def _make_stream_request( num_chunks: int = 1, chunk_frames: int = 1, ) -> OmniDiffusionRequest: - """``num_chunks`` is a test-helper shorthand for ``num_frames / chunk_frames``.""" num_frames = num_chunks * chunk_frames video = [torch.zeros(3, 8, 8) for _ in range(num_frames)] return OmniDiffusionRequest( @@ -986,6 +985,7 @@ def _make_stream_request( sampling_params=OmniDiffusionSamplingParams( num_inference_steps=num_inference_steps, chunk_frames=chunk_frames, + num_chunks=num_chunks, num_frames=num_frames, ), request_ids=[req_id], @@ -1016,18 +1016,19 @@ def _ranks(sched_output) -> list[tuple[str, list[int]] | None]: if sched_output.assignment is None: return [] return [ - (t.sched_req_id, list(t.chunk_indices)) if t is not None else None + (t.sched_req_id, list(t.chunk_indices)) if t.chunk_indices else None for t in sched_output.assignment ] def _layout(sched_output, sched_req_id: str) -> tuple[int, int, int] | None: - if sched_output.rank0_layouts is None: - return None - layout = sched_output.rank0_layouts.get(sched_req_id) - if layout is None: + if sched_output.assignment is None: return None - return (len(layout.finished_idxs), layout.n_circulating, len(layout.new_idxs)) + for task in sched_output.assignment: + if task.sched_req_id == sched_req_id: + layout = task.layout + return (len(layout.finished_idxs), len(layout.circulating_idxs), len(layout.new_idxs)) + return None class TestStreamBatchScheduler: @@ -1212,7 +1213,6 @@ def test_schedule_with_no_requests_emits_no_assignment(self) -> None: scheduler = self._make_scheduler(pp_size=2) out = scheduler.schedule() assert out.assignment is None - assert out.rank0_layouts is None assert out.scheduled_req_ids == [] def test_fifo_two_requests(self) -> None: From 7762df5f34a19da44835767ecef2f0f266cf4fc3 Mon Sep 17 00:00:00 2001 From: Mahdi Nasser <94046147+mnasser02@users.noreply.github.com> Date: Sat, 30 May 2026 01:59:59 +0200 Subject: [PATCH 53/53] add docs Signed-off-by: Mahdi Nasser <94046147+mnasser02@users.noreply.github.com> --- .../model/adding_diffusion_model.md | 35 +++ .../feature/diffusion_micro_step_execution.md | 239 ++++++++++++++++++ .../diffusion/micro_step_execution.md | 105 ++++++++ 3 files changed, 379 insertions(+) create mode 100644 docs/design/feature/diffusion_micro_step_execution.md create mode 100644 docs/user_guide/diffusion/micro_step_execution.md diff --git a/docs/contributing/model/adding_diffusion_model.md b/docs/contributing/model/adding_diffusion_model.md index 35ff6dae202..63575168f6d 100644 --- a/docs/contributing/model/adding_diffusion_model.md +++ b/docs/contributing/model/adding_diffusion_model.md @@ -776,6 +776,41 @@ turning them on. For Qwen-Image-style serving examples, document `--step-execution` as the feature gate and `--max-num-seqs N` as the companion batching knob. +### Micro-Step Execution + +See detailed design guide: [How to add micro-step execution support](../../design/feature/diffusion_micro_step_execution.md) + +Use this only when your pipeline is built for *streaming chunked* output +(e.g. video chunks) and you want stream batch — +at each tick every PP rank denoises a different chunk at a different +timestep, then chunks shift one rank downstream. + +Micro-step is a superset of step execution. On top of the four +step-execution methods, the pipeline must also implement: + +1. `set_pp_recv_dict_buffers()` to pre-register PP recv buffers the request will use. +2. `encode_chunk_inputs()` to build per-chunk initial latents and any + per-chunk conditioning. +3. `prefetch_tensors()` to pre-post the next-step recv (latents on the + first rank, intermediate tensors elsewhere) so it overlaps with compute. + +`denoise_step()` and `step_scheduler()` are also redesigned to operate on a +row-batched mix of chunks at different denoising step indices. +`post_decode()` becomes incremental — it runs on rank 0 every tick that has +freshly finished chunks, not just once at the end. + +Prerequisites: + +- The transformer is PP-partitioned (`make_layers`, `PPMissingLayer`) — see + [Pipeline Parallel](../../design/feature/pipeline_parallel.md). +- The pipeline inherits `PipelineParallelMixin` and `CFGParallelMixin`r. +- The pipeline declares `supports_micro_step_execution: ClassVar[bool] = + True`. +- Each request sets `chunk_frames`, `num_chunks`, and + `num_inference_steps` in `OmniDiffusionSamplingParams`. + +Reference implementation: `LingbotWorldFastPipeline` + ### Cache Acceleration #### TeaCache diff --git a/docs/design/feature/diffusion_micro_step_execution.md b/docs/design/feature/diffusion_micro_step_execution.md new file mode 100644 index 00000000000..f04a641b6c1 --- /dev/null +++ b/docs/design/feature/diffusion_micro_step_execution.md @@ -0,0 +1,239 @@ +# Adding Micro-Step Execution Support for Diffusion Pipelines + +This guide documents vLLM-Omni's micro-step diffusion contract for model +authors and contributors implementing `stream_batch=True` support for a +diffusion pipeline. + +For end-user enablement, supported models, and current limitations, see +[Micro-Step Execution](../../user_guide/diffusion/micro_step_execution.md). + +This document describes the micro-step execution contract only. It builds on +the request-/step-level contract in +[Step Execution](diffusion_step_execution.md) and the PP partitioning rules in +[Pipeline Parallel](pipeline_parallel.md). Read those first. + +## Current Support Scope + +`stream_batch` is **not** a generic diffusion toggle. It only works for +pipelines that implement the segmented stateful contract in +[`vllm_omni/diffusion/models/interface.py`](gh-file:vllm_omni/diffusion/models/interface.py) +as `SupportsMicroStepExecution`. + +This page is intentionally author-facing. Treat runtime enablement +(`stream_batch=True` when constructing `Omni`) as an opt-in user knob layered +on top of the implementation contract below. + +Current in-tree support: + +| Pipeline | Example models | Micro-step execution | +|----------|----------------|----------------------| +| `LingbotWorldFastPipeline` | `lingbot_world/lingbot-world-base-cam/Lingbot-World-Fast` | Yes | +| All other diffusion pipelines | — | No | + +Current engine/runtime limitations: + +- `max_num_seqs == 1` — exactly one in-flight request per engine. +- `cache_backend` is not supported. +- Unsupported pipelines fail early during model loading instead of + failing on the first request. + +## Execution Contract + +Micro-step mode is driven by seven pipeline methods plus the shared mutable +request state object: + +- `prepare_encode(state)`: one-time request preparation (inherited from + step execution). +- `set_pp_recv_dict_buffers(state)`: register PP recv buffers and schema + cache for every `(name, segment_idx, batch_size)` this request will use. +- `encode_chunk_inputs(state, new_idxs)`: per-chunk latent initialization. + Returns a tensor stacked along dim 0 over `new_idxs`; the runner stitches + it onto `state.latents` and into each chunk's `chunk.latents`. +- `denoise_step(state, batch_size)`: row-batched noise prediction over + `batch_size` chunks at different denoising step indices. +- `step_scheduler(state, noise_pred, per_request_scheduler, batch_size)`: + per-row scheduler update on the last rank; sends the updated latents + back to rank 0 via the ring (rank 0 picks them up via `prefetch_tensors`, + not inside this call). Every rank increments `state.step_index`. +- `prefetch_tensors(state, batch_size)`: pre-post the next-step recv on the + comms stream so it overlaps with this rank's compute. +- `post_decode(state)`: incremental decode of one or more freshly-finished + chunks (called whenever the previous tick produced `finished_idxs`). + +The state lives in +[`vllm_omni/diffusion/worker/utils.py`](gh-file:vllm_omni/diffusion/worker/utils.py) +as `DiffusionRequestState` plus per-chunk `ChunkState` entries under +`state.extra["chunks"]`. + +The worker-side micro-step loop lives in +[`vllm_omni/diffusion/worker/diffusion_model_runner.py`](gh-file:vllm_omni/diffusion/worker/diffusion_model_runner.py) +under `execute_micro_step`: + +1. `prepare_encode()` runs once for a new request. +2. `set_pp_recv_dict_buffers()` runs immediately after, before any P2P. +3. Each micro-step: + - Rank 0 calls `post_decode()` for any chunks the previous tick + reported as finished, and accumulates the decoded output. + - Rank 0 and rank N-1 call `encode_chunk_inputs()` for their layout's + `new_idxs`. On rank 0 those are chunks freshly admitted this tick; + on rank N-1 they are the same chunks arriving at the back of the + ring N-1 ticks later — both ranks must produce identical initial + noise so the scheduler step on the last rank starts from the same + latents the first rank started from. + - All ranks with `chunk_indices` non-empty call `denoise_step()` then + `step_scheduler()`. The last rank also snapshots + `chunk.latents = state.latents[i:i+1]` per row so the next time those + chunks reach the last rank they can resume. + - `prefetch_tensors()` runs sized to the previous rank's load so the + next recv is posted before the next micro-step's compute. + +## Per-Rank Chunk Layout + +`StreamBatchScheduler` builds one `RankTask` per PP rank per micro-step: + +| Field | Meaning | +|-------|---------| +| `chunk_indices` | Chunks this rank will denoise this tick | +| `layout.circulating_idxs` | Chunks that drained from rank N-1 last tick still needing more steps, looping back to rank 0 | +| `layout.finished_idxs` | Chunks that completed `num_inference_steps` at rank N-1 last tick, ready for decode | +| `layout.new_idxs` | Chunks freshly admitted at rank 0 (up to SLO `B_target`, capped by `num_chunks - admitted_so_far`) | + +Layouts travel with their chunks: at rank R the current layout was built at +rank 0 R ticks ago, so `new_idxs` at rank R names the chunks admitted R ticks +ago and now reaching this rank for the first time on their first lap. + +The runner uses rank 0's layout to assemble `state.latents` along dim 0 from +the circulating snapshot + fresh-noise rows for `new_idxs`, and to +incrementally decode `finished_idxs`. The last rank does the same assembly +when it owns `new_idxs` so step_scheduler has the matching initial latents. + + +## Recommended Split + +| Request-level phase | Micro-step method | What belongs there | +|---------------------|-------------------|--------------------| +| Input validation, prompt encoding, timestep prep, per-request scheduler | `prepare_encode()` | Anything that should happen once per request | +| PP recv buffer / schema registration for every `(name, segment_idx, B)` | `set_pp_recv_dict_buffers()` | Iterate `1..slo_max_batch * num_inference_steps` | +| Per-chunk latent init (fresh randn, V2V VAE encode, anchor latents, plucker, etc.) | `encode_chunk_inputs()` | Build per-chunk initial latents (RNG must match across rank 0 and rank N-1); write per-chunk conditioning into `state.extra["chunks"][idx].extra` only on the rank that will read it | +| Row-batched transformer forward | `denoise_step()` | Row-aware kwargs, `predict_noise_maybe_with_cfg(buf_idx=step_index % 2, batch_size=B, preposted_its=...)` | +| Per-row `scheduler.step` and `state.step_index += 1` | `step_scheduler()` | `scheduler_step_maybe_with_cfg(..., receive_latents=False, batch_size=B)` | +| Pre-post next-step recv | `prefetch_tensors()` | `prefetch_tensors_maybe_with_cfg(buf_idx=step_index % 2, batch_size=B)` and stash on state | +| Per-chunk VAE decode | `post_decode()` | Decode the leading `len(finished_idxs)` rows of `state.latents` (runner narrows the slice for you) | + +Keep the micro-step path reusing the same helpers as the request-level path +whenever possible. Reimplementing the denoise loop from scratch is the easiest +way to introduce behavioral drift. + +## PP Communication + +`PipelineGroupCoordinator` provides three primitives the micro-step path +leans on: + +| Primitive | Purpose | +|-----------|---------| +| `set_recv_dict_buffer(name, segment_idx, template_dict, batch_size)` | Register the schema and pre-allocate a double-buffer pair (slots 0 and 1) for one logical channel | +| `pipeline_isend_tensor_dict(...)` | Async send of an arbitrary dict to the next rank | +| `pipeline_irecv_tensor_dict(..., buf_idx)` | Posts async recv into the pre-allocated buffer slot; returns an `AsyncIntermediateTensors`/`AsyncLatents` that defers `.wait()` until consumed | + +[`PipelineParallelMixin`](gh-file:vllm_omni/diffusion/distributed/pipeline_parallel.py) +already wraps these in `predict_noise_maybe_with_cfg`, +`scheduler_step_maybe_with_cfg`, and `prefetch_tensors_maybe_with_cfg`. +Pipelines should call those, not the coordinator primitives directly. + +### Why schemas must be pre-registered + +The first call to `pipeline_isend_tensor_dict` on a previously unseen +`(name, segment_idx, batch_size)` triggers a blocking schema exchange. +`set_pp_recv_dict_buffers` populates the cache identically on all ranks so the +schema path is never entered during the data loop. + +Enumerate every `B` the request can hit. For SLO-driven admission the upper +bound is `slo_max_batch * num_inference_steps`. + +### Double buffering + +Caller picks `buf_idx = state.step_index % 2` consistently across +`denoise_step`, `step_scheduler`, and `prefetch_tensors` on the same +micro-step. Alternating slots keeps the previous result readable while the +next recv is in flight. + +## Row-Batched Computation + +`state.batched_timesteps` is a 1-D tensor of length `B`; row `i` carries +`state.timesteps[chunks[i].step_index]`. Inside `denoise_step` and +`step_scheduler`, treat the leading dim as a mix of independent chunks at +*different* progress points. + +## Lingbot Reference + +[`pipeline_lingbot_world_fast.py`](gh-file:vllm_omni/diffusion/models/lingbot_world_fast/pipeline_lingbot_world_fast.py) +is the reference for the *self-forcing* pattern and is split +correctly for the current contract: + +- `prepare_encode()` wraps `self.scheduler` in `LingbotFlowScheduler` so the + last denoise step returns the cached x0 and intermediate steps re-noise to + the next `t`. Two `torch.Generator`s are created on every rank: `seed_g` + for chunk noise (consumed identically on every rank that calls + `encode_chunk_inputs`) and `seed_g_addnoise` for the re-noise step + (consumed only on the last rank). +- `set_pp_recv_dict_buffers()` registers `("latents", -1, B)` and + `("intermediate", 0, B)` templates for every B in + `1..slo_max_batch * num_inference_steps`. +- `encode_chunk_inputs()` builds per-chunk noise on every rank using + `seed_g`. Only rank 0 (first stage) additionally stream-encodes per-chunk + `y` (with anchor-frame handling on the first chunk) and computes Plucker + embeddings, stashing both into `state.extra["chunks"][idx].extra` for + `denoise_step` to read. +- `denoise_step()` slices per-row `current_starts`, `y`, and + `c2ws_plucker_emb` from `state.extra["chunks"][idx]` keyed by the current + micro-step's `chunk_idxs`, then calls + `predict_noise_maybe_with_cfg(...)`. The per-chunk conditioning is only + read on the first stage; the last stage receives processed hidden states + via intermediate tensors. +- `step_scheduler()` rides the shared `scheduler_step_maybe_with_cfg(..., + receive_latents=False, batch_size=B, generator=state.extra["seed_g_addnoise"])` + and bumps `state.step_index`. +- `prefetch_tensors()` calls + `prefetch_tensors_maybe_with_cfg(buf_idx=state.step_index % 2, + batch_size=B)` and stashes results into `state.latents` (rank 0) or + `state.extra["preposted_its"]` (others). + +That decomposition is the target pattern for future micro-step models. + +## Rules For New Pipelines + +- Inherit `PipelineParallelMixin` and `CFGParallelMixin`. +- Declare `supports_micro_step_execution: ClassVar[bool] = True` on the + pipeline class. +- Pre-populate every `(name, segment_idx, batch_size)` in + `set_pp_recv_dict_buffers`. Skipping a `B` triggers the blocking schema + path and risks PP deadlock. +- Use `state.extra["chunks"][idx]` (a `ChunkState`) for per-chunk persistent + state: latents snapshot at the last rank, per-chunk scheduler, conditioning + slices. +- Do not put request-scoped scheduler state on `self.scheduler`. Deep-copy + it into `state.scheduler` during `prepare_encode` (the runner then + deep-copies that into each new `ChunkState.scheduler` on admission). +- Do not mutate `state.step_index` inside `denoise_step`. Only + `step_scheduler` should advance it. +- Use `buf_idx = state.step_index % 2` across `denoise_step`, + `step_scheduler`, and `prefetch_tensors`. + +## Validation Checklist + +Before marking a pipeline `supports_micro_step_execution = True`, verify: + +- `pipeline_parallel_size=2` and `pipeline_parallel_size>=3` both complete. +- `B=1` and `B>1` outputs match — verifies per-row scheduler / cache / + conditioning slicing. +- CFG-parallel and non-CFG paths both work if the pipeline supports them. + +## Related Files + +- Contract: [`vllm_omni/diffusion/models/interface.py`](gh-file:vllm_omni/diffusion/models/interface.py) +- State: [`vllm_omni/diffusion/worker/utils.py`](gh-file:vllm_omni/diffusion/worker/utils.py) +- Runner loop: [`vllm_omni/diffusion/worker/diffusion_model_runner.py`](gh-file:vllm_omni/diffusion/worker/diffusion_model_runner.py) +- Scheduler: [`vllm_omni/diffusion/sched/stream_batch_scheduler.py`](gh-file:vllm_omni/diffusion/sched/stream_batch_scheduler.py) +- PP coordinator: [`vllm_omni/diffusion/distributed/group_coordinator.py`](gh-file:vllm_omni/diffusion/distributed/group_coordinator.py) +- PP mixin: [`vllm_omni/diffusion/distributed/pipeline_parallel.py`](gh-file:vllm_omni/diffusion/distributed/pipeline_parallel.py) +- Reference pipeline: [`vllm_omni/diffusion/models/lingbot_world_fast/pipeline_lingbot_world_fast.py`](gh-file:vllm_omni/diffusion/models/lingbot_world_fast/pipeline_lingbot_world_fast.py) diff --git a/docs/user_guide/diffusion/micro_step_execution.md b/docs/user_guide/diffusion/micro_step_execution.md new file mode 100644 index 00000000000..273a3964c63 --- /dev/null +++ b/docs/user_guide/diffusion/micro_step_execution.md @@ -0,0 +1,105 @@ +# Micro-Step Execution + +Micro-step execution is an opt-in diffusion execution mode enabled with +`stream_batch=True` when constructing `Omni`. It runs *temporal pipeline +parallelism* on streaming chunked diffusion: at each tick every PP rank +denoises a different chunk at a different timestep, then chunks shift one +rank downstream. One tick = one micro-step. + +It is not a generic diffusion toggle for every pipeline. Only pipelines that +implement the micro-step contract support it today. + +## Quick Start + +```python +import PIL.Image +import numpy as np + +from vllm_omni import Omni +from vllm_omni.diffusion.data import DiffusionParallelConfig +from vllm_omni.inputs.data import OmniDiffusionSamplingParams + +omni = Omni( + model="lingbot_world/lingbot-world-base-cam/Lingbot-World-Fast", + model_class_name="LingbotWorldFastPipeline", + stream_batch=True, + parallel_config=DiffusionParallelConfig(pipeline_parallel_size=4), + enforce_eager=True, +) + +outputs = omni.generate( + { + "prompt": "A sweeping cinematic journey along the Great Wall of China", + "multi_modal_data": { + "image": PIL.Image.open("anchor.jpg"), + "camera": { + "poses": np.load("poses.npy"), + "intrinsics": np.load("intrinsics.npy"), + }, + }, + }, + OmniDiffusionSamplingParams( + height=480, + width=832, + num_chunks=20, + chunk_frames=12, + num_inference_steps=5, + slo_fps=16.0, + slo_max_batch=4, + extra_args={"session_id": "demo"}, + ), +) +``` + +## Sampling Parameters + +| Parameter | Required | Description | +|-----------|----------|-------------| +| `chunk_frames` | yes | Pixel frames produced per chunk | +| `num_chunks` | yes | Total number of chunks per request. Output frames = `num_chunks * chunk_frames` after VAE decode | +| `num_inference_steps` | yes | Denoising steps per chunk | +| `slo_fps` | no | Frames-per-second target. Enables SLO-adaptive batching that grows or shrinks per-step admission `B` to meet the budget | +| `slo_max_batch` | no, default 8 | Upper bound on per-step admission `B` | + +When `slo_fps` is set, the scheduler observes the wall-clock latency of each +micro-step and adjusts `B_target` for the next admission tick. If latency +exceeds the budget, `B` decreases; if it is comfortably under, `B` grows up +to `slo_max_batch`. + +## Supported Pipelines + +| Pipeline | Example models | Micro-step execution | +|----------|----------------|----------------------| +| `LingbotWorldFastPipeline` | `lingbot_world/lingbot-world-base-cam/Lingbot-World-Fast` | Yes | +| All other diffusion pipelines | — | No | + +## Current Limitations + +- `max_num_seqs == 1` — exactly one in-flight request per engine. +- `cache_backend` is not supported together with `stream_batch`. +- Unsupported pipelines fail early during model loading. + +## When To Use It + +Use micro-step execution when: + +- The pipeline is built for streaming chunked output (video chunks, audio + segments) and you want temporal PP to overlap per-chunk denoising across + ranks. +- You want SLO-aware admission control to keep up with a real-time + frame-rate budget under variable load. + +For single-request stepwise execution without temporal PP, use +[Step Execution](step_execution.md) instead. + +For non-streaming PP (memory scaling on a normal diffusion pipeline), see +[Pipeline Parallelism Guide](parallelism/pipeline_parallel.md). + +## For Model Authors + +If you want to add micro-step execution support to a new diffusion pipeline, +see the implementation guide: +[Diffusion Micro-Step Execution Design](../../design/feature/diffusion_micro_step_execution.md). + +The pipeline must already support PP partitioning. See +[Pipeline Parallel Design](../../design/feature/pipeline_parallel.md).