diff --git a/docs/contributing/model/adding_diffusion_model.md b/docs/contributing/model/adding_diffusion_model.md index 35ff6dae202..63575168f6d 100644 --- a/docs/contributing/model/adding_diffusion_model.md +++ b/docs/contributing/model/adding_diffusion_model.md @@ -776,6 +776,41 @@ turning them on. For Qwen-Image-style serving examples, document `--step-execution` as the feature gate and `--max-num-seqs N` as the companion batching knob. +### Micro-Step Execution + +See detailed design guide: [How to add micro-step execution support](../../design/feature/diffusion_micro_step_execution.md) + +Use this only when your pipeline is built for *streaming chunked* output +(e.g. video chunks) and you want stream batch — +at each tick every PP rank denoises a different chunk at a different +timestep, then chunks shift one rank downstream. + +Micro-step is a superset of step execution. On top of the four +step-execution methods, the pipeline must also implement: + +1. `set_pp_recv_dict_buffers()` to pre-register PP recv buffers the request will use. +2. `encode_chunk_inputs()` to build per-chunk initial latents and any + per-chunk conditioning. +3. `prefetch_tensors()` to pre-post the next-step recv (latents on the + first rank, intermediate tensors elsewhere) so it overlaps with compute. + +`denoise_step()` and `step_scheduler()` are also redesigned to operate on a +row-batched mix of chunks at different denoising step indices. +`post_decode()` becomes incremental — it runs on rank 0 every tick that has +freshly finished chunks, not just once at the end. + +Prerequisites: + +- The transformer is PP-partitioned (`make_layers`, `PPMissingLayer`) — see + [Pipeline Parallel](../../design/feature/pipeline_parallel.md). +- The pipeline inherits `PipelineParallelMixin` and `CFGParallelMixin`r. +- The pipeline declares `supports_micro_step_execution: ClassVar[bool] = + True`. +- Each request sets `chunk_frames`, `num_chunks`, and + `num_inference_steps` in `OmniDiffusionSamplingParams`. + +Reference implementation: `LingbotWorldFastPipeline` + ### Cache Acceleration #### TeaCache diff --git a/docs/design/feature/diffusion_micro_step_execution.md b/docs/design/feature/diffusion_micro_step_execution.md new file mode 100644 index 00000000000..f04a641b6c1 --- /dev/null +++ b/docs/design/feature/diffusion_micro_step_execution.md @@ -0,0 +1,239 @@ +# Adding Micro-Step Execution Support for Diffusion Pipelines + +This guide documents vLLM-Omni's micro-step diffusion contract for model +authors and contributors implementing `stream_batch=True` support for a +diffusion pipeline. + +For end-user enablement, supported models, and current limitations, see +[Micro-Step Execution](../../user_guide/diffusion/micro_step_execution.md). + +This document describes the micro-step execution contract only. It builds on +the request-/step-level contract in +[Step Execution](diffusion_step_execution.md) and the PP partitioning rules in +[Pipeline Parallel](pipeline_parallel.md). Read those first. + +## Current Support Scope + +`stream_batch` is **not** a generic diffusion toggle. It only works for +pipelines that implement the segmented stateful contract in +[`vllm_omni/diffusion/models/interface.py`](gh-file:vllm_omni/diffusion/models/interface.py) +as `SupportsMicroStepExecution`. + +This page is intentionally author-facing. Treat runtime enablement +(`stream_batch=True` when constructing `Omni`) as an opt-in user knob layered +on top of the implementation contract below. + +Current in-tree support: + +| Pipeline | Example models | Micro-step execution | +|----------|----------------|----------------------| +| `LingbotWorldFastPipeline` | `lingbot_world/lingbot-world-base-cam/Lingbot-World-Fast` | Yes | +| All other diffusion pipelines | — | No | + +Current engine/runtime limitations: + +- `max_num_seqs == 1` — exactly one in-flight request per engine. +- `cache_backend` is not supported. +- Unsupported pipelines fail early during model loading instead of + failing on the first request. + +## Execution Contract + +Micro-step mode is driven by seven pipeline methods plus the shared mutable +request state object: + +- `prepare_encode(state)`: one-time request preparation (inherited from + step execution). +- `set_pp_recv_dict_buffers(state)`: register PP recv buffers and schema + cache for every `(name, segment_idx, batch_size)` this request will use. +- `encode_chunk_inputs(state, new_idxs)`: per-chunk latent initialization. + Returns a tensor stacked along dim 0 over `new_idxs`; the runner stitches + it onto `state.latents` and into each chunk's `chunk.latents`. +- `denoise_step(state, batch_size)`: row-batched noise prediction over + `batch_size` chunks at different denoising step indices. +- `step_scheduler(state, noise_pred, per_request_scheduler, batch_size)`: + per-row scheduler update on the last rank; sends the updated latents + back to rank 0 via the ring (rank 0 picks them up via `prefetch_tensors`, + not inside this call). Every rank increments `state.step_index`. +- `prefetch_tensors(state, batch_size)`: pre-post the next-step recv on the + comms stream so it overlaps with this rank's compute. +- `post_decode(state)`: incremental decode of one or more freshly-finished + chunks (called whenever the previous tick produced `finished_idxs`). + +The state lives in +[`vllm_omni/diffusion/worker/utils.py`](gh-file:vllm_omni/diffusion/worker/utils.py) +as `DiffusionRequestState` plus per-chunk `ChunkState` entries under +`state.extra["chunks"]`. + +The worker-side micro-step loop lives in +[`vllm_omni/diffusion/worker/diffusion_model_runner.py`](gh-file:vllm_omni/diffusion/worker/diffusion_model_runner.py) +under `execute_micro_step`: + +1. `prepare_encode()` runs once for a new request. +2. `set_pp_recv_dict_buffers()` runs immediately after, before any P2P. +3. Each micro-step: + - Rank 0 calls `post_decode()` for any chunks the previous tick + reported as finished, and accumulates the decoded output. + - Rank 0 and rank N-1 call `encode_chunk_inputs()` for their layout's + `new_idxs`. On rank 0 those are chunks freshly admitted this tick; + on rank N-1 they are the same chunks arriving at the back of the + ring N-1 ticks later — both ranks must produce identical initial + noise so the scheduler step on the last rank starts from the same + latents the first rank started from. + - All ranks with `chunk_indices` non-empty call `denoise_step()` then + `step_scheduler()`. The last rank also snapshots + `chunk.latents = state.latents[i:i+1]` per row so the next time those + chunks reach the last rank they can resume. + - `prefetch_tensors()` runs sized to the previous rank's load so the + next recv is posted before the next micro-step's compute. + +## Per-Rank Chunk Layout + +`StreamBatchScheduler` builds one `RankTask` per PP rank per micro-step: + +| Field | Meaning | +|-------|---------| +| `chunk_indices` | Chunks this rank will denoise this tick | +| `layout.circulating_idxs` | Chunks that drained from rank N-1 last tick still needing more steps, looping back to rank 0 | +| `layout.finished_idxs` | Chunks that completed `num_inference_steps` at rank N-1 last tick, ready for decode | +| `layout.new_idxs` | Chunks freshly admitted at rank 0 (up to SLO `B_target`, capped by `num_chunks - admitted_so_far`) | + +Layouts travel with their chunks: at rank R the current layout was built at +rank 0 R ticks ago, so `new_idxs` at rank R names the chunks admitted R ticks +ago and now reaching this rank for the first time on their first lap. + +The runner uses rank 0's layout to assemble `state.latents` along dim 0 from +the circulating snapshot + fresh-noise rows for `new_idxs`, and to +incrementally decode `finished_idxs`. The last rank does the same assembly +when it owns `new_idxs` so step_scheduler has the matching initial latents. + + +## Recommended Split + +| Request-level phase | Micro-step method | What belongs there | +|---------------------|-------------------|--------------------| +| Input validation, prompt encoding, timestep prep, per-request scheduler | `prepare_encode()` | Anything that should happen once per request | +| PP recv buffer / schema registration for every `(name, segment_idx, B)` | `set_pp_recv_dict_buffers()` | Iterate `1..slo_max_batch * num_inference_steps` | +| Per-chunk latent init (fresh randn, V2V VAE encode, anchor latents, plucker, etc.) | `encode_chunk_inputs()` | Build per-chunk initial latents (RNG must match across rank 0 and rank N-1); write per-chunk conditioning into `state.extra["chunks"][idx].extra` only on the rank that will read it | +| Row-batched transformer forward | `denoise_step()` | Row-aware kwargs, `predict_noise_maybe_with_cfg(buf_idx=step_index % 2, batch_size=B, preposted_its=...)` | +| Per-row `scheduler.step` and `state.step_index += 1` | `step_scheduler()` | `scheduler_step_maybe_with_cfg(..., receive_latents=False, batch_size=B)` | +| Pre-post next-step recv | `prefetch_tensors()` | `prefetch_tensors_maybe_with_cfg(buf_idx=step_index % 2, batch_size=B)` and stash on state | +| Per-chunk VAE decode | `post_decode()` | Decode the leading `len(finished_idxs)` rows of `state.latents` (runner narrows the slice for you) | + +Keep the micro-step path reusing the same helpers as the request-level path +whenever possible. Reimplementing the denoise loop from scratch is the easiest +way to introduce behavioral drift. + +## PP Communication + +`PipelineGroupCoordinator` provides three primitives the micro-step path +leans on: + +| Primitive | Purpose | +|-----------|---------| +| `set_recv_dict_buffer(name, segment_idx, template_dict, batch_size)` | Register the schema and pre-allocate a double-buffer pair (slots 0 and 1) for one logical channel | +| `pipeline_isend_tensor_dict(...)` | Async send of an arbitrary dict to the next rank | +| `pipeline_irecv_tensor_dict(..., buf_idx)` | Posts async recv into the pre-allocated buffer slot; returns an `AsyncIntermediateTensors`/`AsyncLatents` that defers `.wait()` until consumed | + +[`PipelineParallelMixin`](gh-file:vllm_omni/diffusion/distributed/pipeline_parallel.py) +already wraps these in `predict_noise_maybe_with_cfg`, +`scheduler_step_maybe_with_cfg`, and `prefetch_tensors_maybe_with_cfg`. +Pipelines should call those, not the coordinator primitives directly. + +### Why schemas must be pre-registered + +The first call to `pipeline_isend_tensor_dict` on a previously unseen +`(name, segment_idx, batch_size)` triggers a blocking schema exchange. +`set_pp_recv_dict_buffers` populates the cache identically on all ranks so the +schema path is never entered during the data loop. + +Enumerate every `B` the request can hit. For SLO-driven admission the upper +bound is `slo_max_batch * num_inference_steps`. + +### Double buffering + +Caller picks `buf_idx = state.step_index % 2` consistently across +`denoise_step`, `step_scheduler`, and `prefetch_tensors` on the same +micro-step. Alternating slots keeps the previous result readable while the +next recv is in flight. + +## Row-Batched Computation + +`state.batched_timesteps` is a 1-D tensor of length `B`; row `i` carries +`state.timesteps[chunks[i].step_index]`. Inside `denoise_step` and +`step_scheduler`, treat the leading dim as a mix of independent chunks at +*different* progress points. + +## Lingbot Reference + +[`pipeline_lingbot_world_fast.py`](gh-file:vllm_omni/diffusion/models/lingbot_world_fast/pipeline_lingbot_world_fast.py) +is the reference for the *self-forcing* pattern and is split +correctly for the current contract: + +- `prepare_encode()` wraps `self.scheduler` in `LingbotFlowScheduler` so the + last denoise step returns the cached x0 and intermediate steps re-noise to + the next `t`. Two `torch.Generator`s are created on every rank: `seed_g` + for chunk noise (consumed identically on every rank that calls + `encode_chunk_inputs`) and `seed_g_addnoise` for the re-noise step + (consumed only on the last rank). +- `set_pp_recv_dict_buffers()` registers `("latents", -1, B)` and + `("intermediate", 0, B)` templates for every B in + `1..slo_max_batch * num_inference_steps`. +- `encode_chunk_inputs()` builds per-chunk noise on every rank using + `seed_g`. Only rank 0 (first stage) additionally stream-encodes per-chunk + `y` (with anchor-frame handling on the first chunk) and computes Plucker + embeddings, stashing both into `state.extra["chunks"][idx].extra` for + `denoise_step` to read. +- `denoise_step()` slices per-row `current_starts`, `y`, and + `c2ws_plucker_emb` from `state.extra["chunks"][idx]` keyed by the current + micro-step's `chunk_idxs`, then calls + `predict_noise_maybe_with_cfg(...)`. The per-chunk conditioning is only + read on the first stage; the last stage receives processed hidden states + via intermediate tensors. +- `step_scheduler()` rides the shared `scheduler_step_maybe_with_cfg(..., + receive_latents=False, batch_size=B, generator=state.extra["seed_g_addnoise"])` + and bumps `state.step_index`. +- `prefetch_tensors()` calls + `prefetch_tensors_maybe_with_cfg(buf_idx=state.step_index % 2, + batch_size=B)` and stashes results into `state.latents` (rank 0) or + `state.extra["preposted_its"]` (others). + +That decomposition is the target pattern for future micro-step models. + +## Rules For New Pipelines + +- Inherit `PipelineParallelMixin` and `CFGParallelMixin`. +- Declare `supports_micro_step_execution: ClassVar[bool] = True` on the + pipeline class. +- Pre-populate every `(name, segment_idx, batch_size)` in + `set_pp_recv_dict_buffers`. Skipping a `B` triggers the blocking schema + path and risks PP deadlock. +- Use `state.extra["chunks"][idx]` (a `ChunkState`) for per-chunk persistent + state: latents snapshot at the last rank, per-chunk scheduler, conditioning + slices. +- Do not put request-scoped scheduler state on `self.scheduler`. Deep-copy + it into `state.scheduler` during `prepare_encode` (the runner then + deep-copies that into each new `ChunkState.scheduler` on admission). +- Do not mutate `state.step_index` inside `denoise_step`. Only + `step_scheduler` should advance it. +- Use `buf_idx = state.step_index % 2` across `denoise_step`, + `step_scheduler`, and `prefetch_tensors`. + +## Validation Checklist + +Before marking a pipeline `supports_micro_step_execution = True`, verify: + +- `pipeline_parallel_size=2` and `pipeline_parallel_size>=3` both complete. +- `B=1` and `B>1` outputs match — verifies per-row scheduler / cache / + conditioning slicing. +- CFG-parallel and non-CFG paths both work if the pipeline supports them. + +## Related Files + +- Contract: [`vllm_omni/diffusion/models/interface.py`](gh-file:vllm_omni/diffusion/models/interface.py) +- State: [`vllm_omni/diffusion/worker/utils.py`](gh-file:vllm_omni/diffusion/worker/utils.py) +- Runner loop: [`vllm_omni/diffusion/worker/diffusion_model_runner.py`](gh-file:vllm_omni/diffusion/worker/diffusion_model_runner.py) +- Scheduler: [`vllm_omni/diffusion/sched/stream_batch_scheduler.py`](gh-file:vllm_omni/diffusion/sched/stream_batch_scheduler.py) +- PP coordinator: [`vllm_omni/diffusion/distributed/group_coordinator.py`](gh-file:vllm_omni/diffusion/distributed/group_coordinator.py) +- PP mixin: [`vllm_omni/diffusion/distributed/pipeline_parallel.py`](gh-file:vllm_omni/diffusion/distributed/pipeline_parallel.py) +- Reference pipeline: [`vllm_omni/diffusion/models/lingbot_world_fast/pipeline_lingbot_world_fast.py`](gh-file:vllm_omni/diffusion/models/lingbot_world_fast/pipeline_lingbot_world_fast.py) diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index d26a579edfd..168b91d7343 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -73,5 +73,6 @@ th { |`DyninOmniForConditionalGeneration` | Dynin-Omni | `snu-aidas/Dynin-Omni` | ✅︎ | | | | | `ErnieImagePipeline` | ERNIE-Image | `baidu/ERNIE-Image`, `baidu/ERNIE-Image-Turbo` | ✅︎ | ✅︎ | ✅︎ | ✅︎ | |`HiDreamImagePipeline` | HiDream-I1-Full | `HiDream-ai/HiDream-I1-Full` | ✅︎ | ✅︎ | | | +|`LingbotWorldFastPipeline`| Lingbot-World-Fast | `robbyant/lingbot-world-fast`|✅︎ | | | | ✅︎ indicates the model is supported on that backend. Empty cells mean not listed as supported on that backend. diff --git a/docs/user_guide/diffusion/micro_step_execution.md b/docs/user_guide/diffusion/micro_step_execution.md new file mode 100644 index 00000000000..273a3964c63 --- /dev/null +++ b/docs/user_guide/diffusion/micro_step_execution.md @@ -0,0 +1,105 @@ +# Micro-Step Execution + +Micro-step execution is an opt-in diffusion execution mode enabled with +`stream_batch=True` when constructing `Omni`. It runs *temporal pipeline +parallelism* on streaming chunked diffusion: at each tick every PP rank +denoises a different chunk at a different timestep, then chunks shift one +rank downstream. One tick = one micro-step. + +It is not a generic diffusion toggle for every pipeline. Only pipelines that +implement the micro-step contract support it today. + +## Quick Start + +```python +import PIL.Image +import numpy as np + +from vllm_omni import Omni +from vllm_omni.diffusion.data import DiffusionParallelConfig +from vllm_omni.inputs.data import OmniDiffusionSamplingParams + +omni = Omni( + model="lingbot_world/lingbot-world-base-cam/Lingbot-World-Fast", + model_class_name="LingbotWorldFastPipeline", + stream_batch=True, + parallel_config=DiffusionParallelConfig(pipeline_parallel_size=4), + enforce_eager=True, +) + +outputs = omni.generate( + { + "prompt": "A sweeping cinematic journey along the Great Wall of China", + "multi_modal_data": { + "image": PIL.Image.open("anchor.jpg"), + "camera": { + "poses": np.load("poses.npy"), + "intrinsics": np.load("intrinsics.npy"), + }, + }, + }, + OmniDiffusionSamplingParams( + height=480, + width=832, + num_chunks=20, + chunk_frames=12, + num_inference_steps=5, + slo_fps=16.0, + slo_max_batch=4, + extra_args={"session_id": "demo"}, + ), +) +``` + +## Sampling Parameters + +| Parameter | Required | Description | +|-----------|----------|-------------| +| `chunk_frames` | yes | Pixel frames produced per chunk | +| `num_chunks` | yes | Total number of chunks per request. Output frames = `num_chunks * chunk_frames` after VAE decode | +| `num_inference_steps` | yes | Denoising steps per chunk | +| `slo_fps` | no | Frames-per-second target. Enables SLO-adaptive batching that grows or shrinks per-step admission `B` to meet the budget | +| `slo_max_batch` | no, default 8 | Upper bound on per-step admission `B` | + +When `slo_fps` is set, the scheduler observes the wall-clock latency of each +micro-step and adjusts `B_target` for the next admission tick. If latency +exceeds the budget, `B` decreases; if it is comfortably under, `B` grows up +to `slo_max_batch`. + +## Supported Pipelines + +| Pipeline | Example models | Micro-step execution | +|----------|----------------|----------------------| +| `LingbotWorldFastPipeline` | `lingbot_world/lingbot-world-base-cam/Lingbot-World-Fast` | Yes | +| All other diffusion pipelines | — | No | + +## Current Limitations + +- `max_num_seqs == 1` — exactly one in-flight request per engine. +- `cache_backend` is not supported together with `stream_batch`. +- Unsupported pipelines fail early during model loading. + +## When To Use It + +Use micro-step execution when: + +- The pipeline is built for streaming chunked output (video chunks, audio + segments) and you want temporal PP to overlap per-chunk denoising across + ranks. +- You want SLO-aware admission control to keep up with a real-time + frame-rate budget under variable load. + +For single-request stepwise execution without temporal PP, use +[Step Execution](step_execution.md) instead. + +For non-streaming PP (memory scaling on a normal diffusion pipeline), see +[Pipeline Parallelism Guide](parallelism/pipeline_parallel.md). + +## For Model Authors + +If you want to add micro-step execution support to a new diffusion pipeline, +see the implementation guide: +[Diffusion Micro-Step Execution Design](../../design/feature/diffusion_micro_step_execution.md). + +The pipeline must already support PP partitioning. See +[Pipeline Parallel Design](../../design/feature/pipeline_parallel.md). diff --git a/docs/user_guide/examples/offline_inference/lingbot_world_fast.md b/docs/user_guide/examples/offline_inference/lingbot_world_fast.md new file mode 100644 index 00000000000..2669f6e6ecd --- /dev/null +++ b/docs/user_guide/examples/offline_inference/lingbot_world_fast.md @@ -0,0 +1,49 @@ +# Lingbot World Fast Offline Inference + +Lingbot World Fast is an autoregressive diffusion model that uses a reference image, a text prompt and a set of camera positions to generate a video. + +## Video Generation + +First, download the model weights using `examples/offline_inference/lingbot_world_fast/download_lingbot_world_fast.py`. + +The simplest way to run offline generation is to use the script on `examples/offline_inference/lingbot_world_fast/end2end.py`. The core of this script is done by: + +```python +from vllm_omni.entrypoints.omni import Omni + +if __name__ == "__main__": + omni = Omni(model="lingbot_world/lingbot-world-base-cam/Lingbot-World-Fast", model_class_name="LingbotWorldFastPipeline") + outputs = omni.generate( + { + "prompt": "A journey along the Great Wall of China", + "multi_modal_data": { + "image": "input.png", + "camera": { + "poses": np.load("path/to/poses.npy") + "intrinsics": np.load("path/to/intrinsics.npy") + } + }, + }, + OmniDiffusionSamplingParams( + height=height, + width=width, + num_frames=num_frames, + frame_rate=fps, + ), + ) + export_to_video(outputs[0], "output.png") +``` + +## Generation Parameters + +| Parameter | Type | Default | Description | +| --------------------- | ----- | ------- | ----------------------------------- | +| `height` | int | None (computed from image) | Image height in pixels | +| `width` | int | None (computed from image) | Image width in pixels | +| `num_frames` | int | 81 | Number of frames to generate | +| `fps` | int | 16 | Frames per second | +| `seed` | int | 42 | Optional random seed | +| `prompt` | str | "" | Text prompt | +| `negative_prompt` | str | None | Negative prompt | +| `image` | str | Required | Path to reference image | +|`camera-path` | str | Required | Path to folder with `poses.npy` and `intrinsics.npy`| diff --git a/docs/user_guide/examples/online_serving/lingbot_world_fast.md b/docs/user_guide/examples/online_serving/lingbot_world_fast.md new file mode 100644 index 00000000000..106f76a6a05 --- /dev/null +++ b/docs/user_guide/examples/online_serving/lingbot_world_fast.md @@ -0,0 +1,51 @@ +# Lingbot World Fast Offline Inference + +Lingbot World Fast is an autoregressive diffusion model that uses a reference image, a text prompt and a set of camera positions to generate a video. The online serving model of this model adds a feature that is not implemented in the original model: video extension. + +## Quickstart + +The easiest way to launch a server running the Lingbot World Fast model is by using the script `examples/online_serving/lingbot_world_fast/run_server.sh`. + +Once the server is launched, the client can send requests to its websocket at `/v1/realtime/world/camera`. The easiest way to interact with the server is using the script `examples/online_serving/lingbot_world_fast/openai_client.py`. Its command line options are described below. + +| Parameter | Type | Default | Description | +| --------------------- | ----- | ------- | ----------------------------------- | +| `height` | int | None (computed from image) | Image height in pixels | +| `width` | int | None (computed from image) | Image width in pixels | +| `num_frames` | int | 81 | Number of frames to generate | +| `fps` | int | 16 | Frames per second | +| `seed` | int | 42 | Optional random seed | +| `prompt` | str | "" | Text prompt | +| `negative_prompt` | str | None | Negative prompt | +| `image` | str | Required | Path to reference image | +|`camera-path` | str | Required | Path to folder with `poses.npy` and `intrinsics.npy`| +| `num-calls` | int | 1 | Makes an additional `num-calls - 1` video extension calls with `num_frames` frames | +| `num-skip-frames` | int | 4 | Extension calls have artifacts on the first couple frames. Discard them. | +| `session-id` | str | None | Session id to control whether to trigger a video extension call | + +## Video Extension + +The idea of video extension is to allow the user to generate further frames for the same video efficiently. This is done by the vllm-omni implementation by storing the KV-cache of the generated video by default. This way, if the next request uses the same session-id, the pipeline will enter extension mode. So, the newly generated frames will use the previously generated frames as context. This is done by storing the KV-cache as mentioned above. No frame information, whether in latent space or RGB values, is kept in the server. + +This feature is limited by the fact that the model has not been trained to perform this task. So, the steering capacity of the user is limited. Namely, the reference image and changes to the text prompt are ignored. The best tool the user has is to provide camera positions. In the end, video extension is more of a demonstration of the power and features of VLLM-Omni than of Lingbot World in itself. + +## API + +The server uses a websocket endpoint located at `/v1/realtime/world/camera`. It makes available two tasks: `infer` and `reset` which can be controlled by the "endpoint" key of the request. + +By default, the server uses the `infer` task, which checks the `session-id` field and compares it to the one used on the last infer call. If they are the same, it triggers an extension call at the pipeline level. Note that only the KV-cache of the last request is stored to mitigate Out of Memory problems at the GPU level. Otherwise, it generates the video from scratch. Notice that when doing an extension task, no reference image should be provided (it would be ignored anyway). + +The `reset` endpoint does not immediately evict the KV cache in the GPU, but instead it forces a reset on the next `infer` call independently of the value of `session-id`. + +The endpoint sends the resulting frames in groups of 4 to mitigate package loss problems. It is the client's role to concatenate the different frames to obtain the final video. + +## Example materials + +??? abstract "run_server.sh" + ``````sh + --8<-- "examples/online_serving/lingbot_world_fast/run_server.sh" + `````` +??? abstract "openai_client.py" + ``````sh + --8<-- "examples/online_serving/lingbot_world_fast/openai_client.py" + `````` diff --git a/examples/offline_inference/lingbot_world_fast/download_lingbot_world_fast.py b/examples/offline_inference/lingbot_world_fast/download_lingbot_world_fast.py new file mode 100644 index 00000000000..6db8ebb5c6b --- /dev/null +++ b/examples/offline_inference/lingbot_world_fast/download_lingbot_world_fast.py @@ -0,0 +1,93 @@ +import argparse +import json +import os +import site +import tempfile +import time +from pathlib import Path + +from huggingface_hub import snapshot_download + +DEPENDENCY_REPO = "https://github.com/Robbyant/lingbot-world" +DEPENDENCY_BRANCH = "main" +CACHE_DIR = Path(tempfile.gettempdir()) / "vllm-omni-dependency" +LOCK_FILE = CACHE_DIR / ".install.lock" +DEPENDENCY_DIR = CACHE_DIR / "Lingbot-World" + + +def download_dependency(): + CACHE_DIR.mkdir(parents=True, exist_ok=True) + + # write .pth to site-packages + site_packages = Path(site.getsitepackages()[0]) + pth_file = site_packages / "vllm_omni_dependency.pth" + pth_file.write_text(str(DEPENDENCY_DIR)) + print(f"Added {DEPENDENCY_DIR} to site-packages via {pth_file}") + + +def timed_download(repo_id: str, local_dir: str, allow_patterns: list | None = None): + """Download files from HF repo and log time + destination.""" + if os.path.exists(local_dir): + print(f"Directory {local_dir} already exists. Skipping download.") + return + print(f"Starting download from {repo_id} into {local_dir}") + start_time = time.time() + + snapshot_download( + repo_id=repo_id, + local_dir=local_dir, + local_dir_use_symlinks=False, + allow_patterns=allow_patterns, + ) + + elapsed = time.time() - start_time + print(f"✅ Finished downloading {repo_id} in {elapsed:.2f} seconds. Files saved at: {local_dir}") + + +def main(output_dir: str): + lingbot_base_dir = os.path.join(output_dir, "lingbot-world-base-cam") + + # Base Model + timed_download( + repo_id="robbyant/lingbot-world-base-cam", + local_dir=lingbot_base_dir, + allow_patterns=["google/*", "models_t5_umt5-xxl-enc-bf16.pth", "Wan2.1_VAE.pth"], + ) + + lingbot_fast_dir = os.path.join(lingbot_base_dir, "Lingbot-World-Fast") + + timed_download(repo_id="robbyant/lingbot-world-fast", local_dir=lingbot_fast_dir) + + # Lingbot World does not come with config.json which is required by diffusers + config = { + "_class_name": "WanModel", + "_diffusers_version": "0.33.0", + "dim": 5120, + "eps": 1e-06, + "ffn_dim": 13824, + "freq_dim": 256, + "in_dim": 36, + "model_type": "i2v", + "num_heads": 40, + "num_layers": 40, + "out_dim": 16, + "text_len": 512, + } + + with open( + os.path.join(output_dir, "lingbot-world-base-cam", "Lingbot-World-Fast", "config.json"), "w", encoding="utf-8" + ) as f: + json.dump(config, f, indent=2) + + print(f"model_index.json created at {os.path.join(output_dir, 'model_index.json')}") + + download_dependency() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Download models from Hugging Face") + parser.add_argument( + "--output-dir", type=str, default="./lingbot_world", help="Base directory to save downloaded models" + ) + args = parser.parse_args() + main(args.output_dir) diff --git a/examples/offline_inference/lingbot_world_fast/end2end.py b/examples/offline_inference/lingbot_world_fast/end2end.py new file mode 100644 index 00000000000..44d3df40c04 --- /dev/null +++ b/examples/offline_inference/lingbot_world_fast/end2end.py @@ -0,0 +1,295 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +""" +Image-Camera to Video generation example using Lingbot World Fast + +Usage example: + python end2end.py --model path/to/lingbot-fast --image path/to/image --camera-path path/to/camera + --width 832 --height 480 --prompt "Walk in the Great Wall of China" --output output.mp4 +""" + +import argparse +import os +import time +from pathlib import Path + +import numpy as np +import PIL.Image +import torch + +from vllm_omni.entrypoints.omni import Omni +from vllm_omni.inputs.data import OmniDiffusionSamplingParams +from vllm_omni.outputs import OmniRequestOutput +from vllm_omni.platforms import current_omni_platform + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Generate a video from an image (Wan2.2, LTX2, HunyuanVideo-1.5).") + parser.add_argument( + "--model", + default="lingbot_world/lingbot-world-base-cam/Lingbot-World-Fast", + help="Diffusers I2V model ID or local path (Wan2.2 or HunyuanVideo-1.5).", + ) + parser.add_argument("--image", required=True, help="Path to input image.") + parser.add_argument("--camera-path", default=None, help="Path to input camera positions") + parser.add_argument("--prompt", default="", help="Text prompt describing the desired motion.") + parser.add_argument("--negative-prompt", default="", help="Negative prompt.") + parser.add_argument("--seed", type=int, default=42, help="Random seed.") + parser.add_argument( + "--height", type=int, default=None, help="Video height (auto-calculated from image if not set)." + ) + parser.add_argument("--width", type=int, default=None, help="Video width (auto-calculated from image if not set).") + parser.add_argument("--num-frames", type=int, default=81, help="Number of frames.") + parser.add_argument("--output", type=str, default="i2v_output.mp4", help="Path to save the video (mp4).") + parser.add_argument("--fps", type=int, default=16, help="Frames per second for the output video.") + parser.add_argument( + "--enable-diffusion-pipeline-profiler", + action="store_true", + help="Enable diffusion pipeline profiler to display stage durations.", + ) + return parser.parse_args() + + +def calculate_dimensions( + image: PIL.Image.Image, + max_area: int = 480 * 832, + mod_value: int = 16, +) -> tuple[int, int]: + """Calculate output dimensions maintaining aspect ratio.""" + aspect_ratio = image.height / image.width + + height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value + width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value + + return height, width + + +def main(): + args = parse_args() + generator = torch.Generator(device=current_omni_platform.device_type).manual_seed(args.seed) + model_class_name = "LingbotWorldFastPipeline" + + # Load input image + image = PIL.Image.open(args.image).convert("RGB") + + num_inference_steps = 40 + + # Calculate dimensions if not provided + height = args.height + width = args.width + if height is None or width is None: + max_area = 480 * 832 + mod_value = 16 + calc_height, calc_width = calculate_dimensions(image, max_area=max_area, mod_value=mod_value) + height = height or calc_height + width = width or calc_width + + # Resize image to target dimensions + image = image.resize((width, height), PIL.Image.Resampling.LANCZOS) + + # Check if profiling is requested via environment variable + profiler_enabled = bool(os.getenv("VLLM_TORCH_PROFILER_DIR")) + + omni = Omni( + model=args.model, + parallel_config=None, + model_class_name=model_class_name, + stage_init_timeout=6000, + init_timeout=6000, + ) + + if profiler_enabled: + print("[Profiler] Starting profiling...") + omni.start_profile() + + # Print generation configuration + print(f"\n{'=' * 60}") + print("Generation Configuration:") + print(f" Model: {args.model}") + print(f" Inference steps: {num_inference_steps}") + print(f" Frames: {args.num_frames}") + print(f" Video size: {width}x{height}") + print(f"{'=' * 60}\n") + + # omni.generate() returns Generator[OmniRequestOutput, None, None] + + multi_modal_data = {"image": image} + + if args.camera_path is not None: + poses = np.load(os.path.join(args.camera_path, "poses.npy")) + intrinsics = np.load(os.path.join(args.camera_path, "intrinsics.npy")) + + multi_modal_data["camera"] = {"poses": poses, "intrinsics": intrinsics} + + generation_start = time.perf_counter() + frames = omni.generate( + { + "prompt": args.prompt, + "negative_prompt": args.negative_prompt, + "multi_modal_data": multi_modal_data, + }, + OmniDiffusionSamplingParams( + height=height, + width=width, + generator=generator, + num_frames=args.num_frames, + frame_rate=args.fps, + extra_args={"session_id": "offline_generation"}, + ), + ) + generation_end = time.perf_counter() + generation_time = generation_end - generation_start + + # Print profiling results + print(f"Total generation time: {generation_time:.4f} seconds ({generation_time * 1000:.2f} ms)") + + if isinstance(frames, list): + frames = frames[0] if frames else None + + if isinstance(frames, OmniRequestOutput): + if frames.final_output_type != "image": + raise ValueError( + f"Unexpected output type '{frames.final_output_type}', expected 'image' for video generation." + ) + if frames.is_pipeline_output and frames.request_output is not None: + inner_output = frames.request_output + if isinstance(inner_output, OmniRequestOutput): + frames = inner_output + if isinstance(frames, OmniRequestOutput): + if frames.images: + if len(frames.images) == 1 and isinstance(frames.images[0], tuple) and len(frames.images[0]) == 2: + frames = frames.images[0] + elif len(frames.images) == 1 and isinstance(frames.images[0], dict): + frames = frames.images[0].get("frames") or frames.images[0].get("video") + else: + frames = frames.images + else: + raise ValueError("No video frames found in OmniRequestOutput.") + + if isinstance(frames, list) and frames: + first_item = frames[0] + if isinstance(first_item, tuple) and len(first_item) == 2: + frames = first_item + elif isinstance(first_item, dict): + frames = first_item.get("frames") or first_item.get("video") + elif isinstance(first_item, list): + frames = first_item + + if isinstance(frames, tuple) and len(frames) == 2: + frames = frames + elif isinstance(frames, dict): + frames = frames.get("frames") or frames.get("video") + + if frames is None: + raise ValueError("No video frames found in output.") + + output_path = Path(args.output) + output_path.parent.mkdir(parents=True, exist_ok=True) + + try: + from diffusers.utils import export_to_video + except ImportError: + raise ImportError("diffusers is required for export_to_video.") + + def _normalize_frame(frame): + if isinstance(frame, torch.Tensor): + frame_tensor = frame.detach().cpu() + if frame_tensor.dim() == 4 and frame_tensor.shape[0] == 1: + frame_tensor = frame_tensor[0] + if frame_tensor.dim() == 3 and frame_tensor.shape[0] in (3, 4): + frame_tensor = frame_tensor.permute(1, 2, 0) + if frame_tensor.is_floating_point(): + frame_tensor = frame_tensor.clamp(-1, 1) * 0.5 + 0.5 + return frame_tensor.float().numpy() + if isinstance(frame, np.ndarray): + frame_array = frame + if frame_array.ndim == 4 and frame_array.shape[0] == 1: + frame_array = frame_array[0] + if np.issubdtype(frame_array.dtype, np.integer): + frame_array = frame_array.astype(np.float32) / 255.0 + return frame_array + try: + from PIL import Image + except ImportError: + Image = None + if Image is not None and isinstance(frame, Image.Image): + return np.asarray(frame).astype(np.float32) / 255.0 + return frame + + def _ensure_frame_list(video_array): + if isinstance(video_array, list): + if len(video_array) == 0: + return video_array + first_item = video_array[0] + if isinstance(first_item, np.ndarray): + if first_item.ndim == 5: + return list(first_item[0]) + if first_item.ndim == 4: + if len(video_array) == 1: + return list(first_item) + return list(first_item) + if first_item.ndim == 3: + return video_array + return video_array + if isinstance(video_array, np.ndarray): + if video_array.ndim == 5: + return list(video_array[0]) + if video_array.ndim == 4: + return list(video_array) + if video_array.ndim == 3: + return [video_array] + return video_array + + # frames may be np.ndarray, torch.Tensor, or list of tensors/arrays/images + # export_to_video expects a list of frames with values in [0, 1] + if isinstance(frames, torch.Tensor): + video_tensor = frames.detach().cpu() + if video_tensor.dim() == 5: + if video_tensor.shape[1] in (3, 4): + video_tensor = video_tensor[0].permute(1, 2, 3, 0) + else: + video_tensor = video_tensor[0] + elif video_tensor.dim() == 4 and video_tensor.shape[0] in (3, 4): + video_tensor = video_tensor.permute(1, 2, 3, 0) + if video_tensor.is_floating_point(): + video_tensor = video_tensor.clamp(-1, 1) * 0.5 + 0.5 + video_array = video_tensor.float().numpy() + elif isinstance(frames, np.ndarray): + video_array = frames + if video_array.ndim == 5: + video_array = video_array[0] + if np.issubdtype(video_array.dtype, np.integer): + video_array = video_array.astype(np.float32) / 255.0 + elif isinstance(frames, list): + if len(frames) == 0: + raise ValueError("No video frames found in output.") + video_array = [_normalize_frame(frame) for frame in frames] + else: + video_array = frames + + video_array = _ensure_frame_list(video_array) + + export_to_video(video_array, str(output_path), fps=args.fps) + print(f"Saved generated video to {output_path}") + + if profiler_enabled: + print("\n[Profiler] Stopping profiler and collecting results...") + profile_results = omni.stop_profile() + if profile_results and isinstance(profile_results, dict): + traces = profile_results.get("traces", []) + print("\n" + "=" * 60) + print("PROFILING RESULTS:") + for rank, trace in enumerate(traces): + print(f"\nRank {rank}:") + if trace: + print(f" • Trace: {trace}") + if not traces: + print(" No traces collected.") + print("=" * 60) + else: + print("[Profiler] No valid profiling data returned.") + + +if __name__ == "__main__": + main() diff --git a/examples/online_serving/lingbot_world_fast/openai_client.py b/examples/online_serving/lingbot_world_fast/openai_client.py new file mode 100644 index 00000000000..81beeec2d54 --- /dev/null +++ b/examples/online_serving/lingbot_world_fast/openai_client.py @@ -0,0 +1,176 @@ +#!/usr/bin/env python3 +""" +Lingbot World Fast realtime camera client. + +Talks to the WebSocket endpoint ``/v1/realtime/world/camera`` exposed by +``vllm serve --omni`` when the loaded pipeline is ``LingbotWorldFastPipeline``. + +The endpoint speaks the OpenPI policy protocol on the wire: + 1. Connect -> server sends msgpack(CameraServerConfig) + 2. Client send msgpack(request) + 3. Server send msgpack(ndarray) # generated frames + +The ``request`` payload sent here contains: + - "image": numpy array, the input image + - "prompt": str, the text prompt describing the desired motion + - "camera": {"poses": ndarray, "intrinsics": ndarray} + +Usage: + python openai_chat_client.py \\ + --image path/to/image.png \\ + --camera-path path/to/camera_dir \\ + --prompt "Walk along the Great Wall of China" \\ + --output frames.npy +""" + +import argparse +from argparse import Namespace +from pathlib import Path + +import numpy as np +import PIL.Image +import websockets.sync.client as ws_sync +from diffusers.utils import export_to_video + +try: + from openpi_client import msgpack_numpy +except ImportError as exc: + raise SystemExit("This example requires `openpi-client`. Install it with `pip install openpi-client`.") from exc + + +def _pack(obj): + return msgpack_numpy.packb(obj) + + +def _unpack(data): + return msgpack_numpy.unpackb(data) + + +def _load_image(path: str | None) -> np.ndarray | None: + image = PIL.Image.open(path).convert("RGB") + return np.asarray(image) + + +def _load_camera(camera_dir: str) -> dict: + camera_path = Path(camera_dir) + poses = np.load(camera_path / "poses.npy") + intrinsics = np.load(camera_path / "intrinsics.npy") + return {"poses": poses, "intrinsics": intrinsics} + + +def generate_video(args: Namespace) -> list[np.ndarray]: + """Send inference requests and return the generated frames.""" + image = _load_image(args.image) + full_camera = _load_camera(args.camera_path) + + extra_body = { + "height": args.height, + "width": args.width, + "num_frames": args.num_frames, + "fps": args.fps, + "session_id": args.session_id, + "seed": args.seed, + } + + video = [] + starting_frame = 0 + + for i in range(args.num_calls): + camera = { + "poses": full_camera["poses"][starting_frame : starting_frame + args.num_frames], + "intrinsics": full_camera["intrinsics"][starting_frame : starting_frame + args.num_frames], + } + + request: dict = {"prompt": args.prompt, "camera": camera, "extra_body": extra_body} + if i == 0: + request["image"] = image + + request["session_id"] = args.session_id + + endpoint = f"{args.server.rstrip('/')}/v1/realtime/world/camera" + print(f"Connecting to {endpoint} ...") + + with ws_sync.connect(endpoint, max_size=None, ping_interval=None, ping_timeout=None) as ws: + # 1. Server sends CameraServerConfig on connect. + _unpack(ws.recv()) + + # 2. Send request. + print( + f"Sending request image= ({str(image.shape) if request.get('image', None) is not None else 'None'}, " + f"poses={camera['poses'].shape}, intrinsics={camera['intrinsics'].shape})..." + ) + ws.send(_pack(request)) + + # 3. Receive generated frames. + chunks: list[np.ndarray] = [] + total = None + while total is None or len(chunks) < total: + msg = _unpack(ws.recv()) + if isinstance(msg, dict) and msg.get("type") == "error": + raise RuntimeError(f"Server error: {msg.get('message')}") + if not isinstance(msg, dict) or msg.get("type") != "frame": + continue # ignore anything unexpected + total = msg["total"] + chunks.append(msg["video"]) + print(f" received chunk {msg['index'] + 1}/{total}") + + clip = np.concatenate(chunks, axis=0) + # The first chunk of frames returned was used to condition the video continuation but they are not useful + if i != 0: + clip = clip[args.num_skip_frames :] + for frame in clip: + video.append(frame) + + starting_frame += args.num_frames + + return video + + +def main(): + parser = argparse.ArgumentParser(description="Lingbot World Fast realtime camera client") + parser.add_argument("--image", "-i", required=True, help="Path to input image.") + parser.add_argument( + "--camera-path", + "-c", + required=True, + help="Directory containing poses.npy and intrinsics.npy.", + ) + parser.add_argument( + "--prompt", + "-p", + default="Walk along the Great Wall of China", + help="Text prompt describing the desired motion.", + ) + parser.add_argument( + "--server", + "-s", + default="ws://localhost:8091", + help="WebSocket server URL (ws:// or wss://).", + ) + parser.add_argument("--session-id", default=None, help="Optional session id.") + parser.add_argument( + "--output", + "-o", + default="lingbot-video.mp4", + help="Path to save the returned frames (npy).", + ) + parser.add_argument("--width", type=int, default=832) + parser.add_argument("--height", type=int, default=480) + parser.add_argument("--fps", type=int, default=16) + parser.add_argument("--num-frames", type=int, default=24) + parser.add_argument("--num-calls", type=int, default=1) + parser.add_argument("--num-skip-frames", type=int, default=4) + parser.add_argument("--seed", type=int, default=42, help="Random seed.") + args = parser.parse_args() + + frames = generate_video(args) + + output_path = Path(args.output) + output_path.parent.mkdir(parents=True, exist_ok=True) + + export_to_video(frames, str(output_path), fps=args.fps) + print(f"Saved generated video to {output_path}") + + +if __name__ == "__main__": + main() diff --git a/examples/online_serving/lingbot_world_fast/run_server.sh b/examples/online_serving/lingbot_world_fast/run_server.sh new file mode 100755 index 00000000000..ef3e2caab93 --- /dev/null +++ b/examples/online_serving/lingbot_world_fast/run_server.sh @@ -0,0 +1,17 @@ +#!/bin/bash +# Bagel online serving startup script + +MODEL="${MODEL:-../../offline_inference/lingbot_world_fast/lingbot_world/lingbot-world-base-cam/Lingbot-World-Fast}" +PORT="${PORT:-8091}" + +echo "Starting Lingbot World server..." +echo "Model: $MODEL" +echo "Port: $PORT" + +vllm serve "$MODEL" --omni \ + --port "$PORT" \ + --model-class-name LingbotWorldFastPipeline \ + --stage-init-timeout 6000 \ + --init-timeout 6000 \ + --ws-max-size 268435456 \ + --ws wsproto diff --git a/tests/diffusion/models/lingbot_world_fast/__init__.py b/tests/diffusion/models/lingbot_world_fast/__init__.py new file mode 100644 index 00000000000..208f01a7cb5 --- /dev/null +++ b/tests/diffusion/models/lingbot_world_fast/__init__.py @@ -0,0 +1,2 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project diff --git a/tests/diffusion/models/lingbot_world_fast/conftest.py b/tests/diffusion/models/lingbot_world_fast/conftest.py new file mode 100644 index 00000000000..0e130d2e678 --- /dev/null +++ b/tests/diffusion/models/lingbot_world_fast/conftest.py @@ -0,0 +1,173 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""Shared stubs and dummy-input helpers for Lingbot World Fast L1 tests. + +The real pipeline pulls in T5-XXL, the Wan VAE and a 5B-parameter transformer +on construction. Tests exercise only the state container, msgpack protocol and +scheduler, so these stubs replace the heavy dependencies with the smallest +implementations that match the call sites in +``vllm_omni/diffusion/models/lingbot_world_fast/pipeline_lingbot_world_fast.py``. +""" + +from __future__ import annotations + +from types import SimpleNamespace +from typing import TYPE_CHECKING + +import numpy as np +import torch +from PIL import Image +from torch import nn + +if TYPE_CHECKING: + from vllm_omni.diffusion.models.lingbot_world_fast.pipeline_lingbot_world_fast import ( + LingbotWorldFastPipeline, + ) + + +class StubT5Encoder: + """Minimal stand-in for ``T5EncoderModel``. + + The pipeline calls ``self.text_encoder([prompt], device)`` and expects a + list of token-embedding tensors, one per prompt. + """ + + def __init__(self, text_len: int = 512, dim: int = 32, dtype: torch.dtype = torch.float32) -> None: + self.text_len = text_len + self.dim = dim + self.dtype = dtype + + def __call__(self, prompts: list[str], device: torch.device) -> list[torch.Tensor]: + return [torch.zeros(self.text_len, self.dim, dtype=self.dtype, device=device) for _ in prompts] + + +class StubVAE: + """Stand-in for ``Wan2_1_VAE``. + + ``encode([pixels])`` returns a list with one latent tensor shaped + ``[16, F_lat, lat_h, lat_w]`` where ``F_lat = (F + 3) // 4`` so the + pipeline's masking / slicing math is exercised normally. + ``decode([latents])`` returns the latents unchanged (caller indexes [0]). + """ + + vae_stride = (4, 8, 8) + + def encode(self, video_list: list[torch.Tensor]) -> list[torch.Tensor]: + out: list[torch.Tensor] = [] + for v in video_list: + # v: [C, F, H, W] + _, f, h, w = v.shape + lat_f = (f + self.vae_stride[0] - 1) // self.vae_stride[0] + lat_h = h // self.vae_stride[1] + lat_w = w // self.vae_stride[2] + out.append(torch.zeros(16, lat_f, lat_h, lat_w, dtype=v.dtype, device=v.device)) + return out + + def decode(self, latents_list: list[torch.Tensor]) -> list[torch.Tensor]: + out: list[torch.Tensor] = [] + for latents in latents_list: + # latents: [16, F_lat, lat_h, lat_w]; produce pixels at the inverse stride. + _, f_lat, lat_h, lat_w = latents.shape + f = f_lat * self.vae_stride[0] + h = lat_h * self.vae_stride[1] + w = lat_w * self.vae_stride[2] + out.append(torch.zeros(3, f, h, w, dtype=latents.dtype, device=latents.device)) + return out + + +class StubWanModelFast(nn.Module): + """Stand-in for ``WanModelFast``. + + Returns zeros shaped like the input latent, and bumps the local/global + index tensors so chunk-boundary arithmetic is exercised by the pipeline. + """ + + def __init__(self, *, dim: int = 16, num_heads: int = 4, num_layers: int = 2) -> None: + super().__init__() + self.config = SimpleNamespace( + dim=dim, + num_heads=num_heads, + num_layers=num_layers, + local_attn_size=-1, + ) + + def forward(self, *, x, t, **kwargs): # noqa: ARG002 — matches pipeline call site + del t, kwargs + return [torch.zeros_like(x[0])] + + @classmethod + def from_pretrained(cls, *args, **kwargs): # noqa: ARG003 + return cls() + + +def make_dummy_camera_inputs(num_frames: int) -> dict[str, np.ndarray]: + """Camera payload matching the shape the pipeline expects.""" + intrinsics = np.eye(4, dtype=np.float32) + poses = np.tile(np.eye(4, dtype=np.float32), (num_frames, 1, 1)) + return {"intrinsics": intrinsics, "poses": poses} + + +def make_dummy_image(width: int = 64, height: int = 64) -> Image.Image: + return Image.new("RGB", (width, height), color=(128, 128, 128)) + + +def make_stubbed_pipeline( + *, + device: torch.device | None = None, + dim: int = 16, + num_heads: int = 4, + num_layers: int = 2, + target_dtype: torch.dtype = torch.float32, +) -> LingbotWorldFastPipeline: + """Build a ``LingbotWorldFastPipeline`` backed by the conftest stubs. + + Skips the real ``__init__`` (which loads umt5-xxl, Wan VAE and a 5B + transformer) via ``object.__new__`` and assigns the stubs directly, + mirroring ``_make_i2v_pipeline`` in ``tests/diffusion/models/wan2_2``. + The returned pipeline is suitable for driving ``.forward(req)`` end-to-end + against ``LingbotWorldFastState`` without touching real weights. + """ + from vllm_omni.diffusion.models.lingbot_world_fast.fm_solvers_unipc import ( + FlowUniPCMultistepScheduler, + ) + from vllm_omni.diffusion.models.lingbot_world_fast.pipeline_lingbot_world_fast import ( + CONFIG, + LingbotWorldFastPipeline, + ) + from vllm_omni.diffusion.models.lingbot_world_fast.state_lingbot_world_fast import ( + LingbotWorldFastState, + ) + + if device is None: + device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu", 0) + + parallel_config = SimpleNamespace(world_size=1) + od_config = SimpleNamespace( + model="stub/Lingbot-World-Fast", + parallel_config=parallel_config, + dtype=target_dtype, + ) + + pipeline = object.__new__(LingbotWorldFastPipeline) + nn.Module.__init__(pipeline) + pipeline.od_config = od_config + pipeline.parallel_config = parallel_config + pipeline.device = device + pipeline.target_dtype = target_dtype + pipeline.control_type = "cam" + pipeline.num_train_timesteps = CONFIG["num_train_timesteps"] + pipeline.sp_size = parallel_config.world_size + pipeline.state = LingbotWorldFastState() + pipeline.text_encoder = StubT5Encoder(dim=dim, dtype=target_dtype) + pipeline.vae = StubVAE() + pipeline.vae_stride = CONFIG["vae_stride"] + pipeline.patch_size = CONFIG["patch_size"] + pipeline.model = StubWanModelFast(dim=dim, num_heads=num_heads, num_layers=num_layers).to(device) + pipeline.scheduler = FlowUniPCMultistepScheduler( + num_train_timesteps=CONFIG["num_train_timesteps"], + shift=1, + use_dynamic_shifting=False, + ) + pipeline.sample_neg_prompt = CONFIG["negative_prompt_sample"] + return pipeline diff --git a/tests/diffusion/models/lingbot_world_fast/test_protocol_validation.py b/tests/diffusion/models/lingbot_world_fast/test_protocol_validation.py new file mode 100644 index 00000000000..902a4afe711 --- /dev/null +++ b/tests/diffusion/models/lingbot_world_fast/test_protocol_validation.py @@ -0,0 +1,363 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""L1 protocol-validation tests for ``/v1/realtime/world/camera``.""" + +from __future__ import annotations + +import asyncio +import contextlib +from collections.abc import Iterable +from typing import Any + +import numpy as np +import pytest + +from tests.diffusion.models.lingbot_world_fast.conftest import make_dummy_camera_inputs +from vllm_omni.entrypoints.openai.realtime.world.camera_connection import CHUNK_FRAMES, WorldCameraRealtimeConnection +from vllm_omni.entrypoints.openai.realtime.world.camera_serving import ( + CameraServerConfig, + ServingRealtimeWorldCamera, +) + +# The endpoint's wire codec is provided by the optional openpi-client dep. +msgpack_numpy = pytest.importorskip("openpi_client.msgpack_numpy") + +pytestmark = [pytest.mark.core_model, pytest.mark.cpu, pytest.mark.diffusion] + + +# --------------------------------------------------------------------------- +# Mock infrastructure +# --------------------------------------------------------------------------- + + +class MockWebSocket: + def __init__(self, incoming: Iterable[dict[str, Any]] | None = None) -> None: + self._incoming: list[dict[str, Any]] = list(incoming or []) + self._idx = 0 + self.sent_bytes: list[bytes] = [] + self.sent_text: list[str] = [] + self.accepted = False + self.closed = False + + async def accept(self) -> None: + self.accepted = True + + async def receive(self) -> dict[str, Any]: + if self._idx >= len(self._incoming): + return {"type": "websocket.disconnect"} + msg = self._incoming[self._idx] + self._idx += 1 + return msg + + async def send_bytes(self, data: bytes) -> None: + self.sent_bytes.append(data) + + async def send_text(self, data: str) -> None: + self.sent_text.append(data) + + async def close(self) -> None: + self.closed = True + + +def _bytes_frame(payload: Any) -> dict[str, Any]: + return {"type": "websocket.receive", "bytes": msgpack_numpy.packb(payload)} + + +def _raw_bytes_frame(data: bytes) -> dict[str, Any]: + return {"type": "websocket.receive", "bytes": data} + + +class _AsyncIter: + def __init__(self, items: list[Any]) -> None: + self._items = list(items) + + def __aiter__(self): + return self + + async def __anext__(self): + if not self._items: + raise StopAsyncIteration + return self._items.pop(0) + + +class FakeResult: + def __init__(self, frames: np.ndarray) -> None: + self.images = [frames] + + +class FakeEngineClient: + """Stand-in for ``AsyncOmni`` engine client. + + Captures the ``generate(...)`` arguments and yields a single fake result + so the connection's framing logic can be exercised without a real engine. + """ + + def __init__(self, frames: np.ndarray | None = None) -> None: + if frames is None: + # Default: CHUNK_FRAMES*(1+1/2) RGB frames so we exercise the chunk split. + frames = np.zeros((CHUNK_FRAMES * 3 // 2, 16, 16, 3), dtype=np.uint8) + self._frames = frames + self.calls: list[dict[str, Any]] = [] + self.fail_with: Exception | None = None + # Attributes consulted by ``CameraServerConfig.from_model_config``. + self.model_config = {"pipeline": "lingbot_world_fast", "resolution": [480, 832], "fps": 16} + + def generate(self, *, prompt, request_id, sampling_params_list): + self.calls.append( + { + "prompt": prompt, + "request_id": request_id, + "sampling_params_list": sampling_params_list, + } + ) + if self.fail_with is not None: + raise self.fail_with + return _AsyncIter([FakeResult(self._frames)]) + + +def _make_serving(engine_client: FakeEngineClient | None = None) -> ServingRealtimeWorldCamera: + return ServingRealtimeWorldCamera(engine_client=engine_client or FakeEngineClient(), model_name="lingbot") + + +# --------------------------------------------------------------------------- +# CameraServerConfig +# --------------------------------------------------------------------------- + + +def test_camera_server_config_round_trip_through_dict() -> None: + cfg = CameraServerConfig.from_model_config({"pipeline": "lingbot", "fps": 16}) + out = cfg.to_dict() + assert isinstance(out, dict) + assert out["pipeline"] == "lingbot" + assert out["fps"] == 16 + + +# --------------------------------------------------------------------------- +# msgpack-numpy round-trip +# --------------------------------------------------------------------------- + + +def test_msgpack_camera_payload_round_trip() -> None: + camera = make_dummy_camera_inputs(num_frames=8) + payload = { + "image": np.random.randint(0, 255, size=(8, 8, 3), dtype=np.uint8), + "prompt": "walk forward", + "camera": camera, + "session_id": "sess-1", + "extra_body": {"height": 240, "width": 416, "num_frames": 25, "fps": 16}, + } + packed = msgpack_numpy.packb(payload) + decoded = msgpack_numpy.unpackb(packed) + + assert decoded["prompt"] == "walk forward" + assert decoded["session_id"] == "sess-1" + assert decoded["extra_body"] == payload["extra_body"] + + image_out = decoded["image"] + assert isinstance(image_out, np.ndarray) + assert image_out.shape == payload["image"].shape + assert image_out.dtype == payload["image"].dtype + np.testing.assert_array_equal(image_out, payload["image"]) + + for key in ("intrinsics", "poses"): + arr_in = camera[key] + arr_out = decoded["camera"][key] + assert arr_out.dtype == arr_in.dtype + assert arr_out.shape == arr_in.shape + np.testing.assert_array_equal(arr_out, arr_in) + + +# --------------------------------------------------------------------------- +# Connection-level framing +# --------------------------------------------------------------------------- + + +def test_handshake_sends_camera_server_config_on_connect() -> None: + serving = _make_serving() + ws = MockWebSocket(incoming=[]) # client disconnects immediately after handshake + conn = WorldCameraRealtimeConnection(ws, serving) + asyncio.run(conn.handle_connection()) + + assert ws.accepted is True + assert len(ws.sent_bytes) == 1 + handshake = msgpack_numpy.unpackb(ws.sent_bytes[0]) + assert isinstance(handshake, dict) + assert handshake["pipeline"] == "lingbot_world_fast" + + +def test_invalid_msgpack_returns_error_frame_and_keeps_connection_open() -> None: + serving = _make_serving() + ws = MockWebSocket( + incoming=[ + _raw_bytes_frame(b"\x99not-msgpack"), # malformed + _bytes_frame({"endpoint": "reset"}), + ] + ) + conn = WorldCameraRealtimeConnection(ws, serving) + asyncio.run(conn.handle_connection()) + + # First sent message is the handshake, next is the error frame, then the + # "reset successful" text reply — proving the connection stayed open. + assert len(ws.sent_bytes) >= 2 + error = msgpack_numpy.unpackb(ws.sent_bytes[1]) + assert error == {"type": "error", "message": "Invalid request payload"} + assert ws.sent_text == ["reset successful"] + + +def test_non_dict_payload_is_rejected_with_error_frame() -> None: + serving = _make_serving() + ws = MockWebSocket(incoming=[_bytes_frame([1, 2, 3])]) # list, not dict + conn = WorldCameraRealtimeConnection(ws, serving) + asyncio.run(conn.handle_connection()) + + assert len(ws.sent_bytes) >= 2 + error = msgpack_numpy.unpackb(ws.sent_bytes[1]) + assert error["type"] == "error" + + +def test_reset_endpoint_clears_session_and_returns_text_ack() -> None: + engine_client = FakeEngineClient() + serving = ServingRealtimeWorldCamera(engine_client=engine_client, model_name="lingbot") + # Pre-populate as if a prior session were active. + serving._current_session_id = "session-a" + + ws = MockWebSocket(incoming=[_bytes_frame({"endpoint": "reset"})]) + conn = WorldCameraRealtimeConnection(ws, serving) + asyncio.run(conn.handle_connection()) + + assert serving._current_session_id is None + assert ws.sent_text == ["reset successful"] + + +def test_infer_frames_are_chunked() -> None: + num_frames = CHUNK_FRAMES * 3 // 2 + + frames = np.arange(num_frames * 4 * 4 * 3, dtype=np.uint8).reshape(num_frames, 4, 4, 3) + engine_client = FakeEngineClient(frames=frames) + serving = ServingRealtimeWorldCamera(engine_client=engine_client, model_name="lingbot") + + request = { + "prompt": "p", + "camera": make_dummy_camera_inputs(num_frames=6), + "session_id": "s1", + "extra_body": {"num_frames": 6, "height": 16, "width": 16, "fps": 16}, + "image": np.zeros((16, 16, 3), dtype=np.uint8), + } + ws = MockWebSocket(incoming=[_bytes_frame(request)]) + conn = WorldCameraRealtimeConnection(ws, serving) + asyncio.run(conn.handle_connection()) + + # Drop the handshake; the remaining sent_bytes are frame chunks. + chunks = [msgpack_numpy.unpackb(b) for b in ws.sent_bytes[1:]] + assert [c["type"] for c in chunks] == ["frame", "frame"] + assert [c["index"] for c in chunks] == [0, 1] + assert {c["total"] for c in chunks} == {2} + + assert len(chunks[0]["video"]) == CHUNK_FRAMES + assert len(chunks[1]["video"]) == num_frames - CHUNK_FRAMES + + for chunk in chunks: + assert chunk["video"][0].shape == (4, 4, 3) + + +def test_session_id_churn_flips_current_session_id() -> None: + engine_client = FakeEngineClient() + serving = ServingRealtimeWorldCamera(engine_client=engine_client, model_name="lingbot") + + base_obs = { + "prompt": "p", + "camera": make_dummy_camera_inputs(num_frames=4), + "image": np.zeros((16, 16, 3), dtype=np.uint8), + "extra_body": {"num_frames": 4, "height": 16, "width": 16, "fps": 16}, + } + obs_a = {**base_obs, "session_id": "session-a"} + obs_b = {**base_obs, "session_id": "session-b"} + + ws = MockWebSocket(incoming=[_bytes_frame(obs_a), _bytes_frame(obs_b)]) + conn = WorldCameraRealtimeConnection(ws, serving) + asyncio.run(conn.handle_connection()) + + assert serving._current_session_id == "session-b" + assert len(engine_client.calls) == 2 + # Each engine call observes the active session id via extra_args. + seen_session_ids = [call["sampling_params_list"][0].extra_args["session_id"] for call in engine_client.calls] + assert seen_session_ids == ["session-a", "session-b"] + + +def test_engine_failure_surfaces_as_error_frame_not_close() -> None: + engine_client = FakeEngineClient() + engine_client.fail_with = RuntimeError("kaboom") + serving = ServingRealtimeWorldCamera(engine_client=engine_client, model_name="lingbot") + + request = { + "prompt": "p", + "camera": make_dummy_camera_inputs(num_frames=4), + "session_id": "s1", + "image": np.zeros((16, 16, 3), dtype=np.uint8), + "extra_body": {"num_frames": 4, "height": 16, "width": 16, "fps": 16}, + } + ws = MockWebSocket(incoming=[_bytes_frame(request)]) + conn = WorldCameraRealtimeConnection(ws, serving) + asyncio.run(conn.handle_connection()) + + error = msgpack_numpy.unpackb(ws.sent_bytes[-1]) + assert error == {"type": "error", "message": "Internal inference error"} + + +# --------------------------------------------------------------------------- +# Required-field validation: a missing ``camera`` propagates to the pipeline +# layer's ValueError. At the serving layer we exercise this by giving the +# fake engine a side-effect that raises like the pipeline would, then assert +# the connection responds with an error frame and keeps running. +# --------------------------------------------------------------------------- + + +def test_missing_camera_surfaces_as_error_frame() -> None: + engine_client = FakeEngineClient() + # Pipeline's actual ValueError text — useful to keep this in sync. + engine_client.fail_with = ValueError("A path to camera positions must be passed to this model") + serving = ServingRealtimeWorldCamera(engine_client=engine_client, model_name="lingbot") + + request = { + "prompt": "p", + "session_id": "s1", + "image": np.zeros((16, 16, 3), dtype=np.uint8), + "extra_body": {"num_frames": 4, "height": 16, "width": 16, "fps": 16}, + } + ws = MockWebSocket(incoming=[_bytes_frame(request)]) + conn = WorldCameraRealtimeConnection(ws, serving) + asyncio.run(conn.handle_connection()) + + error = msgpack_numpy.unpackb(ws.sent_bytes[-1]) + assert error["type"] == "error" + + +# --------------------------------------------------------------------------- +# Dtype/rank guard: the wire codec preserves bit patterns, so a malformed +# camera entry (float64 vs float32, wrong rank) passes through the connection +# unchanged. The pipeline-layer assertions are exercised in the L2 offline +# test; here we just confirm the codec doesn't silently coerce. +# --------------------------------------------------------------------------- + + +def test_msgpack_does_not_silently_coerce_camera_dtypes() -> None: + payload = { + "intrinsics": np.eye(3, dtype=np.float64), # wrong dtype on purpose + "poses": np.tile(np.eye(4, dtype=np.float32), (2, 1, 1))[None], # extra leading dim + } + decoded = msgpack_numpy.unpackb(msgpack_numpy.packb(payload)) + assert decoded["intrinsics"].dtype == np.float64 + assert decoded["poses"].shape == (1, 2, 4, 4) + + +# --------------------------------------------------------------------------- +# Suppress event-loop teardown noise in some environments +# --------------------------------------------------------------------------- + + +@pytest.fixture(autouse=True) +def _silence_runtime_warnings(recwarn): # noqa: PT004 + yield + with contextlib.suppress(Exception): + recwarn.clear() diff --git a/tests/diffusion/models/lingbot_world_fast/test_schedule.py b/tests/diffusion/models/lingbot_world_fast/test_schedule.py new file mode 100644 index 00000000000..0adc1589c91 --- /dev/null +++ b/tests/diffusion/models/lingbot_world_fast/test_schedule.py @@ -0,0 +1,106 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""L1 tests for the Lingbot World Fast scheduler.""" + +from __future__ import annotations + +import math + +import pytest +import torch + +from vllm_omni.diffusion.models.lingbot_world_fast.fm_solvers_unipc import FlowUniPCMultistepScheduler +from vllm_omni.diffusion.models.lingbot_world_fast.pipeline_lingbot_world_fast import ( + CONFIG, + LingbotWorldFastPipeline, +) + +pytestmark = [pytest.mark.core_model, pytest.mark.cpu, pytest.mark.diffusion] + + +def _make_scheduler() -> FlowUniPCMultistepScheduler: + # Same construction as ``LingbotWorldFastPipeline.__init__``. + scheduler = FlowUniPCMultistepScheduler( + num_train_timesteps=CONFIG["num_train_timesteps"], + shift=1, + use_dynamic_shifting=False, + ) + scheduler.set_timesteps(CONFIG["num_train_timesteps"], shift=CONFIG["sample_shift"]) + return scheduler + + +def test_timesteps_index_selects_exactly_four_steps() -> None: + scheduler = _make_scheduler() + selected = scheduler.timesteps[CONFIG["timesteps_index"]] + + assert selected.shape == (4,) + # Monotonically decreasing — flow matching schedulers walk t from high to low. + diffs = selected[1:] - selected[:-1] + assert torch.all(diffs < 0), f"timesteps must be strictly decreasing, got {selected.tolist()}" + + +def test_timesteps_full_schedule_length_matches_num_train_timesteps() -> None: + scheduler = _make_scheduler() + assert scheduler.num_inference_steps == CONFIG["num_train_timesteps"] + assert scheduler.timesteps.shape == (CONFIG["num_train_timesteps"],) + + +def _convert(flow_pred: torch.Tensor, xt: torch.Tensor, timestep: torch.Tensor, scheduler) -> torch.Tensor: + # Bind ``_convert_flow_pred_to_x0`` as an unbound method to avoid + # constructing the full pipeline (which loads T5 / VAE). + return LingbotWorldFastPipeline._convert_flow_pred_to_x0( + None, # type: ignore[arg-type] + flow_pred=flow_pred, + xt=xt, + timestep=timestep, + scheduler=scheduler, + ) + + +def test_convert_flow_pred_to_x0_passthrough_when_pred_is_zero() -> None: + scheduler = _make_scheduler() + xt = torch.randn(1, 4, 1, 4, 4, dtype=torch.float32) + timestep = scheduler.timesteps[0] + flow_pred = torch.zeros_like(xt) + + x0 = _convert(flow_pred, xt, timestep, scheduler) + assert torch.allclose(x0, xt, atol=1e-6) + + +def test_convert_flow_pred_to_x0_recovers_x0_from_synthesized_pair() -> None: + scheduler = _make_scheduler() + timestep = scheduler.timesteps[CONFIG["timesteps_index"][1]] + + sigmas = scheduler.sigmas + timesteps = scheduler.timesteps + timestep_id = torch.argmin((timesteps - timestep).abs()) + sigma_t = sigmas[timestep_id].item() + + x0 = torch.randn(1, 4, 1, 4, 4, dtype=torch.float32) + noise = torch.randn_like(x0) + flow_pred = noise - x0 + xt = (1.0 - sigma_t) * x0 + sigma_t * noise + + recovered = _convert(flow_pred, xt, timestep, scheduler) + + # The function does the math in float64 internally, casts back to the + # input dtype. float32 inputs ⇒ ~1e-5 absolute tolerance is plenty. + assert torch.allclose(recovered, x0, atol=1e-4) + + +def test_timesteps_index_is_within_schedule_bounds() -> None: + """Defensive guard: an out-of-range index would silently wrap.""" + assert isinstance(CONFIG["timesteps_index"], list) + assert len(CONFIG["timesteps_index"]) == 4 + for idx in CONFIG["timesteps_index"]: + assert 0 <= idx < CONFIG["num_train_timesteps"] + + +def test_sample_shift_constant_is_positive() -> None: + """``sample_shift`` controls the timestep curve; a non-positive value + would corrupt the flow-matching trajectory.""" + assert CONFIG["sample_shift"] > 0 + # Reasonable upper bound — Wan models use shift ~5–10. + assert math.isfinite(CONFIG["sample_shift"]) + assert CONFIG["sample_shift"] <= 100 diff --git a/tests/diffusion/models/lingbot_world_fast/test_session_state.py b/tests/diffusion/models/lingbot_world_fast/test_session_state.py new file mode 100644 index 00000000000..e10298332b4 --- /dev/null +++ b/tests/diffusion/models/lingbot_world_fast/test_session_state.py @@ -0,0 +1,219 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""L1 unit tests for ``LingbotWorldFastState``. + +The state container is the load-bearing structure for chunk-streamed +generation: it owns the KV cache, the cross-attention cache, the +``current_lat_f`` cursor used to derive ``current_start`` RoPE offsets, +and the session-id that decides between fresh vs extension semantics. +""" + +from __future__ import annotations + +import pytest +import torch + +from vllm_omni.diffusion.models.lingbot_world_fast.state_lingbot_world_fast import ( + LingbotWorldFastState, +) + +pytestmark = [pytest.mark.core_model, pytest.mark.cpu, pytest.mark.diffusion] + + +BATCH_SIZE = 1 +NUM_LAYERS = 3 +NUM_HEADS = 4 +HEAD_DIM = 8 +KV_SIZE = 16 +DTYPE = torch.float32 +DEVICE = torch.device("cpu") + + +def _fresh_state_with_caches(kv_size: int = KV_SIZE) -> LingbotWorldFastState: + state = LingbotWorldFastState() + state.create_kv_caches( + batch_size=BATCH_SIZE, + dtype=DTYPE, + device=DEVICE, + kv_size=kv_size, + num_layers=NUM_LAYERS, + num_heads=NUM_HEADS, + head_dim=HEAD_DIM, + ) + return state + + +def test_reset_initializes_all_fields() -> None: + state = LingbotWorldFastState() + assert state.kv_cache is None + assert state.crossattn_cache is None + assert state.current_start_frame == 0 + assert state.local_end_index is None + assert state.global_end_index is None + assert state.is_initialized is False + assert state.current_lat_f == 0 + assert state.session_id is None + assert state.batch_size is None + assert state.num_layers is None + assert state.num_heads is None + assert state.head_dim is None + assert state.h is None + assert state.w is None + assert state.lat_h is None + assert state.lat_w is None + assert state.frame_seqlen is None + assert state.last_decoded_latent is None + + +def test_create_kv_caches_allocates_expected_shapes() -> None: + state = _fresh_state_with_caches() + + assert state.is_initialized is True + assert state.batch_size == BATCH_SIZE + assert state.num_layers == NUM_LAYERS + assert state.num_heads == NUM_HEADS + assert state.head_dim == HEAD_DIM + + assert state.kv_cache is not None + assert len(state.kv_cache) == NUM_LAYERS + for layer in state.kv_cache: + assert layer.shape == (2, BATCH_SIZE, KV_SIZE, NUM_HEADS, HEAD_DIM) + assert layer.dtype == DTYPE + assert torch.all(layer == 0) + + assert state.local_end_index is not None and state.global_end_index is not None + for idx_list in (state.local_end_index, state.global_end_index): + assert len(idx_list) == NUM_LAYERS + for idx in idx_list: + assert idx.shape == (1,) + assert idx.dtype == torch.long + assert int(idx.item()) == 0 + + assert state.crossattn_cache is not None + assert len(state.crossattn_cache) == NUM_LAYERS + for entry in state.crossattn_cache: + assert entry == {"is_init": False, "k": None, "v": None} + + +def test_extend_kv_caches_grows_tensor_and_zeros_new_slots() -> None: + state = _fresh_state_with_caches() + extra = 7 + # Mark the existing slots so we can confirm they aren't disturbed. + for layer in state.kv_cache: + layer.fill_(1.0) + + state.extend_kv_caches(extra_kv_size=extra) + + for layer in state.kv_cache: + assert layer.shape == (2, BATCH_SIZE, KV_SIZE + extra, NUM_HEADS, HEAD_DIM) + assert torch.all(layer[:, :, :KV_SIZE] == 1.0) + # Newly grown trailing slice is fresh zeros. + assert torch.all(layer[:, :, KV_SIZE:] == 0.0) + + +def test_extend_kv_caches_requires_initialization() -> None: + state = LingbotWorldFastState() + with pytest.raises(AssertionError): + state.extend_kv_caches(extra_kv_size=4) + + +def test_get_accessors_require_initialization() -> None: + state = LingbotWorldFastState() + with pytest.raises(AssertionError): + state.get_kv_caches() + with pytest.raises(AssertionError): + state.get_crossattn_caches() + + +def test_get_kv_caches_returns_underlying_list() -> None: + state = _fresh_state_with_caches() + assert state.get_kv_caches() is state.kv_cache + + +def test_advance_moves_cursor_by_delta() -> None: + state = LingbotWorldFastState() + state.advance(3) + assert state.current_lat_f == 3 + state.advance(5) + assert state.current_lat_f == 8 + + +def test_reset_clears_all_session_state() -> None: + state = _fresh_state_with_caches() + state.session_id = "abc" + state.advance(4) + state.h, state.w, state.lat_h, state.lat_w, state.frame_seqlen = 480, 832, 60, 104, 1560 + state.last_decoded_latent = torch.zeros(16, 2, 60, 104) + + state.reset() + + assert state.kv_cache is None + assert state.crossattn_cache is None + assert state.local_end_index is None + assert state.global_end_index is None + assert state.is_initialized is False + assert state.current_lat_f == 0 + assert state.current_start_frame == 0 + assert state.session_id is None + assert state.h is None and state.w is None + assert state.lat_h is None and state.lat_w is None + assert state.frame_seqlen is None + assert state.last_decoded_latent is None + assert state.batch_size is None + assert state.num_layers is None + + +def test_reset_is_idempotent() -> None: + state = LingbotWorldFastState() + state.reset() + state.reset() + assert state.is_initialized is False + assert state.current_lat_f == 0 + + +# --------------------------------------------------------------------------- +# Reset is triggered only by session-id change, not prompt change. +# +# Mirrors the conditional in ``LingbotWorldFastPipeline.forward`` (pipeline +# file, around the ``if self.state.session_id is None or +# self.state.session_id != session_id`` block). We assert the contract on +# the state container so the test does not depend on instantiating the +# heavy pipeline. +# --------------------------------------------------------------------------- + + +def _should_reset(state: LingbotWorldFastState, incoming_session_id: str) -> bool: + """Replicates the pipeline's reset trigger.""" + return state.session_id is None or state.session_id != incoming_session_id + + +def test_first_call_with_any_session_id_triggers_reset() -> None: + state = LingbotWorldFastState() + assert _should_reset(state, "session-a") is True + + +def test_same_session_id_does_not_reset() -> None: + state = _fresh_state_with_caches() + state.session_id = "session-a" + state.advance(4) + + assert _should_reset(state, "session-a") is False + # ... and a prompt-only change must not trigger a reset either. + assert _should_reset(state, "session-a") is False + # Pipeline would proceed in extension mode → state still alive. + assert state.is_initialized is True + assert state.current_lat_f == 4 + + +def test_different_session_id_triggers_reset() -> None: + state = _fresh_state_with_caches() + state.session_id = "session-a" + state.advance(4) + + assert _should_reset(state, "session-b") is True + + state.reset() + assert state.session_id is None + assert state.current_lat_f == 0 + assert state.kv_cache is None diff --git a/tests/diffusion/test_diffusion_micro_step_pipeline.py b/tests/diffusion/test_diffusion_micro_step_pipeline.py new file mode 100644 index 00000000000..ebaaa6afa1e --- /dev/null +++ b/tests/diffusion/test_diffusion_micro_step_pipeline.py @@ -0,0 +1,556 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Tests for micro-step level diffusion execution across runner / worker / executor / engine.""" + +from contextlib import contextmanager +from types import SimpleNamespace + +import pytest +import torch +from pytest_mock import MockerFixture + +import vllm_omni.diffusion.worker.diffusion_model_runner as model_runner_module +from vllm_omni.diffusion.data import DiffusionOutput +from vllm_omni.diffusion.executor.multiproc_executor import MultiprocDiffusionExecutor +from vllm_omni.diffusion.sched.interface import ( + CachedRequestData, + DiffusionSchedulerOutput, + Layout, + NewRequestData, + RankTask, +) +from vllm_omni.diffusion.worker.diffusion_model_runner import DiffusionModelRunner +from vllm_omni.diffusion.worker.diffusion_worker import DiffusionWorker +from vllm_omni.diffusion.worker.utils import RunnerOutput + +pytestmark = [pytest.mark.core_model, pytest.mark.cpu, pytest.mark.diffusion] + + +@contextmanager +def _noop_forward_context(*args, **kwargs): + del args, kwargs + yield + + +class _FakePPGroup: + def __init__(self, rank_in_group: int = 0, world_size: int = 1): + self.rank_in_group = rank_in_group + self.world_size = world_size + self.is_first_rank = rank_in_group == 0 + self.is_last_rank = rank_in_group == world_size - 1 + self.prev_rank = (rank_in_group - 1) % world_size + self.next_rank = (rank_in_group + 1) % world_size + self.group_prev_rank = (rank_in_group - 1) % world_size + self.group_next_rank = (rank_in_group + 1) % world_size + self.reset_calls = 0 + + def reset_buffer(self) -> None: + self.reset_calls += 1 + + +class _MicroStepPipeline: + supports_step_execution = True + supports_micro_step_execution = True + + def __init__(self, num_steps: int = 1): + self.num_steps = num_steps + self.prepare_calls = 0 + self.set_buffer_calls = 0 + self.denoise_calls = 0 + self.scheduler_calls = 0 + self.decode_calls = 0 + self.prefetch_calls = 0 + self.encode_calls = 0 + + def prepare_encode(self, state, **kwargs): + del kwargs + self.prepare_calls += 1 + state.timesteps = [torch.tensor(float(i)) for i in range(self.num_steps)] + state.latents = torch.zeros((1, 1, 1, 1, 1)) + state.scheduler = SimpleNamespace(timesteps=list(state.timesteps)) + return state + + def encode_chunk_inputs(self, state, new_idxs): + del state + self.encode_calls += 1 + return torch.zeros((len(new_idxs), 1, 1, 1, 1)) + + def set_pp_recv_dict_buffers(self, state, **kwargs): + del state, kwargs + self.set_buffer_calls += 1 + + def denoise_step(self, state, **kwargs): + del state, kwargs + self.denoise_calls += 1 + return torch.tensor([1.0]) + + def step_scheduler(self, state, noise_pred, **kwargs): + del noise_pred, kwargs + self.scheduler_calls += 1 + state.step_index += 1 + + def post_decode(self, state, **kwargs): + del kwargs + self.decode_calls += 1 + # One batched decode covers all rows on this rank; output keeps the + # per-row layout so _merge_chunk_outputs can stitch the temporal axis. + b = state.latents.shape[0] if state.latents.ndim > 0 else 1 + return DiffusionOutput(output=torch.ones(b, 1, 1, 1, 1, dtype=torch.float32)) + + def prefetch_tensors(self, state, **kwargs): + del state, kwargs + self.prefetch_calls += 1 + + +class _InterruptingMicroStepPipeline(_MicroStepPipeline): + interrupt = True + + def denoise_step(self, state, **kwargs): + del state, kwargs + self.denoise_calls += 1 + return None + + def step_scheduler(self, state, noise_pred, **kwargs): + del state, noise_pred, kwargs + raise AssertionError("step_scheduler should not run after interrupt") + + def post_decode(self, state, **kwargs): + del state, kwargs + raise AssertionError("post_decode should not run after interrupt") + + +def _make_micro_request( + req_id: str = "req-1", + *, + num_inference_steps: int = 1, + num_chunks: int = 1, + chunk_frames: int = 1, +): + return SimpleNamespace( + prompts=["a prompt"], + request_ids=[req_id], + sampling_params=SimpleNamespace( + generator=None, + seed=None, + generator_device=None, + num_inference_steps=num_inference_steps, + chunk_frames=chunk_frames, + num_chunks=num_chunks, + num_frames=num_chunks * chunk_frames, + slo_fps=None, + slo_max_batch=8, + lora_request=None, + ), + ) + + +def _make_runner(pp_size: int = 1, num_steps: int = 1): + runner = object.__new__(DiffusionModelRunner) + runner.vllm_config = object() + runner.od_config = SimpleNamespace( + cache_backend=None, + parallel_config=SimpleNamespace(use_hsdp=False), + ) + runner.device = torch.device("cpu") + runner.pipeline = _MicroStepPipeline(num_steps=num_steps) + runner.cache_backend = None + runner.offload_backend = None + runner.state_cache = {} + runner.kv_transfer_manager = SimpleNamespace() + runner._fake_pp_group = _FakePPGroup(world_size=pp_size) + return runner + + +def _make_layout( + *, + circulating_idxs: list[int] | None = None, + finished_idxs: list[int] | None = None, + new_idxs: list[int] | None = None, +) -> Layout: + return Layout( + circulating_idxs=circulating_idxs or [], + finished_idxs=finished_idxs or [], + new_idxs=new_idxs or [], + ) + + +def _make_micro_scheduler_output( + *, + req=None, + sched_req_id: str = "req-1", + step_id: int = 0, + chunk_indices: list[int] | None = None, + is_new: bool = True, + finished_req_ids=None, + layout: Layout | None = None, +): + if layout is None: + layout = _make_layout() + if chunk_indices is None: + chunk_indices = [0] + assignment = [RankTask(sched_req_id=sched_req_id, chunk_indices=chunk_indices, layout=layout)] + if is_new and req is not None: + new_reqs = [NewRequestData(sched_req_id=sched_req_id, req=req)] + cached_reqs = CachedRequestData.make_empty() + else: + new_reqs = [] + cached_reqs = CachedRequestData(sched_req_ids=[sched_req_id]) + return DiffusionSchedulerOutput( + step_id=step_id, + scheduled_new_reqs=new_reqs, + scheduled_cached_reqs=cached_reqs, + finished_req_ids=set() if finished_req_ids is None else set(finished_req_ids), + num_running_reqs=1, + num_waiting_reqs=0, + assignment=assignment, + ) + + +def _patch_runtime(monkeypatch, runner) -> None: + monkeypatch.setattr(model_runner_module, "set_forward_context", _noop_forward_context) + monkeypatch.setattr(model_runner_module, "get_pp_group", lambda: runner._fake_pp_group) + + +# --------------------------------------------------------------------------- +# Runner +# --------------------------------------------------------------------------- + + +class TestRunner: + """DiffusionModelRunner.execute_micro_step (PP=1).""" + + def test_completes_single_chunk_request(self, monkeypatch): + runner = _make_runner(pp_size=1, num_steps=1) + _patch_runtime(monkeypatch, runner) + req = _make_micro_request(num_inference_steps=1, num_chunks=1) + + out0 = DiffusionModelRunner.execute_micro_step( + runner, + _make_micro_scheduler_output( + req=req, step_id=0, + chunk_indices=[0], + layout=_make_layout(new_idxs=[0]), + ), + ) + assert out0.req_id == "req-1" + assert out0.finished is False + assert "req-1" in runner.state_cache + + out1 = DiffusionModelRunner.execute_micro_step( + runner, + _make_micro_scheduler_output( + sched_req_id="req-1", step_id=1, + chunk_indices=[], is_new=False, + layout=_make_layout(finished_idxs=[0]), + ), + ) + assert out1.finished is True + assert out1.result is not None + assert out1.result.output is not None + assert "req-1" not in runner.state_cache + + assert runner.pipeline.prepare_calls == 1 + assert runner.pipeline.denoise_calls == 1 + assert runner.pipeline.scheduler_calls == 1 + assert runner.pipeline.decode_calls == 1 + + def test_completes_multi_chunk_request(self, monkeypatch): + runner = _make_runner(pp_size=1, num_steps=1) + _patch_runtime(monkeypatch, runner) + req = _make_micro_request(num_inference_steps=1, num_chunks=2) + + DiffusionModelRunner.execute_micro_step( + runner, + _make_micro_scheduler_output( + req=req, step_id=0, + chunk_indices=[0], + layout=_make_layout(new_idxs=[0]), + ), + ) + out1 = DiffusionModelRunner.execute_micro_step( + runner, + _make_micro_scheduler_output( + sched_req_id="req-1", step_id=1, + chunk_indices=[1], + is_new=False, + layout=_make_layout(finished_idxs=[0], new_idxs=[1]), + ), + ) + assert out1.finished is False + + out2 = DiffusionModelRunner.execute_micro_step( + runner, + _make_micro_scheduler_output( + sched_req_id="req-1", step_id=2, + chunk_indices=[], is_new=False, + layout=_make_layout(finished_idxs=[1]), + ), + ) + assert out2.finished is True + assert out2.result is not None + assert "req-1" not in runner.state_cache + + assert runner.pipeline.prepare_calls == 1 + assert runner.pipeline.denoise_calls == 2 + assert runner.pipeline.decode_calls == 2 + + def test_re_admits_circulating_chunk(self, monkeypatch): + runner = _make_runner(pp_size=1, num_steps=2) + _patch_runtime(monkeypatch, runner) + req = _make_micro_request(num_inference_steps=2, num_chunks=1) + + out0 = DiffusionModelRunner.execute_micro_step( + runner, + _make_micro_scheduler_output( + req=req, step_id=0, + chunk_indices=[0], + layout=_make_layout(new_idxs=[0]), + ), + ) + assert out0.finished is False + + out1 = DiffusionModelRunner.execute_micro_step( + runner, + _make_micro_scheduler_output( + sched_req_id="req-1", step_id=1, + chunk_indices=[0], + is_new=False, + layout=_make_layout(circulating_idxs=[0]), + ), + ) + assert out1.finished is False + assert runner.pipeline.denoise_calls == 2 + + out2 = DiffusionModelRunner.execute_micro_step( + runner, + _make_micro_scheduler_output( + sched_req_id="req-1", step_id=2, + chunk_indices=[], is_new=False, + layout=_make_layout(finished_idxs=[0]), + ), + ) + assert out2.finished is True + assert runner.pipeline.decode_calls == 1 + + def test_empty_layout_is_a_no_op(self, monkeypatch): + runner = _make_runner(pp_size=1, num_steps=1) + _patch_runtime(monkeypatch, runner) + req = _make_micro_request(num_inference_steps=1, num_chunks=1) + + DiffusionModelRunner.execute_micro_step( + runner, + _make_micro_scheduler_output( + req=req, step_id=0, + chunk_indices=[0], + layout=_make_layout(new_idxs=[0]), + ), + ) + denoise_calls_before = runner.pipeline.denoise_calls + + out = DiffusionModelRunner.execute_micro_step( + runner, + _make_micro_scheduler_output( + sched_req_id="req-1", step_id=1, + chunk_indices=[], is_new=False, + layout=_make_layout(), + ), + ) + assert out.req_id == "req-1" + assert out.finished is False + assert runner.pipeline.denoise_calls == denoise_calls_before + assert runner.pipeline.decode_calls == 0 + + def test_interrupt_marks_request_as_aborted(self, monkeypatch): + runner = _make_runner(pp_size=1, num_steps=1) + runner.pipeline = _InterruptingMicroStepPipeline(num_steps=1) + _patch_runtime(monkeypatch, runner) + req = _make_micro_request(num_inference_steps=1, num_chunks=1) + + out = DiffusionModelRunner.execute_micro_step( + runner, + _make_micro_scheduler_output( + req=req, step_id=0, + chunk_indices=[0], + layout=_make_layout(new_idxs=[0]), + ), + ) + assert out.req_id == "req-1" + assert out.result is not None + assert out.result.error == "micro-step denoise interrupted" + assert runner.pipeline.denoise_calls == 1 + assert runner.pipeline.scheduler_calls == 0 + assert runner.pipeline.decode_calls == 0 + + def test_rejects_missing_assignment(self): + runner = _make_runner(pp_size=1) + req = _make_micro_request() + sched_output = _make_micro_scheduler_output(req=req) + sched_output.assignment = None + + with pytest.raises(ValueError, match="assignment"): + DiffusionModelRunner.execute_micro_step(runner, sched_output) + + def test_rejects_cache_backend(self): + runner = _make_runner(pp_size=1) + runner.od_config = SimpleNamespace( + cache_backend="teacache", + parallel_config=SimpleNamespace(use_hsdp=False), + ) + req = _make_micro_request() + + with pytest.raises(ValueError, match="cache_backend"): + DiffusionModelRunner.execute_micro_step(runner, _make_micro_scheduler_output(req=req)) + + def test_stamps_micro_step_wall_ns_on_rank0(self, monkeypatch): + runner = _make_runner(pp_size=1, num_steps=1) + _patch_runtime(monkeypatch, runner) + req = _make_micro_request(num_inference_steps=1, num_chunks=1) + + out = DiffusionModelRunner.execute_micro_step( + runner, + _make_micro_scheduler_output( + req=req, step_id=0, + chunk_indices=[0], + layout=_make_layout(new_idxs=[0]), + ), + ) + assert out.micro_step_wall_ns is not None + assert out.micro_step_wall_ns >= 0 + + def test_batch_two_runs_one_fused_forward(self, monkeypatch): + runner = _make_runner(pp_size=1, num_steps=1) + _patch_runtime(monkeypatch, runner) + req = _make_micro_request(num_inference_steps=1, num_chunks=2) + + out = DiffusionModelRunner.execute_micro_step( + runner, + _make_micro_scheduler_output( + req=req, step_id=0, + chunk_indices=[0, 1], + layout=_make_layout(new_idxs=[0, 1]), + ), + ) + + assert runner.pipeline.denoise_calls == 1 + assert runner.pipeline.scheduler_calls == 1 + assert out.req_id == "req-1" + assert out.finished is False + assert out.micro_step_wall_ns is not None + + def test_batch_two_decodes_both_chunks_when_drain_completes(self, monkeypatch): + runner = _make_runner(pp_size=1, num_steps=1) + _patch_runtime(monkeypatch, runner) + req = _make_micro_request(num_inference_steps=1, num_chunks=2) + + DiffusionModelRunner.execute_micro_step( + runner, + _make_micro_scheduler_output( + req=req, step_id=0, + chunk_indices=[0, 1], + layout=_make_layout(new_idxs=[0, 1]), + ), + ) + out = DiffusionModelRunner.execute_micro_step( + runner, + _make_micro_scheduler_output( + sched_req_id="req-1", step_id=1, + chunk_indices=[], is_new=False, + layout=_make_layout(finished_idxs=[0, 1]), + ), + ) + + assert out.finished is True + assert runner.pipeline.decode_calls == 1 + assert "req-1" not in runner.state_cache + + +# --------------------------------------------------------------------------- +# Worker +# --------------------------------------------------------------------------- + + +class TestWorker: + """DiffusionWorker.execute_micro_step""" + + def test_delegates_to_model_runner(self): + worker = object.__new__(DiffusionWorker) + expected = RunnerOutput(req_id="req-1") + scheduler_output = SimpleNamespace( + scheduled_new_reqs=[ + SimpleNamespace(req=SimpleNamespace(sampling_params=SimpleNamespace(lora_request=None))) + ] + ) + worker.lora_manager = None + worker.model_runner = SimpleNamespace( + execute_micro_step=lambda arg: expected if arg is scheduler_output else None + ) + worker._get_profiler = lambda: None + + output = DiffusionWorker.execute_micro_step(worker, scheduler_output) + assert output is expected + + def test_clears_active_lora(self): + worker = object.__new__(DiffusionWorker) + scheduler_output = SimpleNamespace( + scheduled_new_reqs=[ + SimpleNamespace(req=SimpleNamespace(sampling_params=SimpleNamespace(lora_request=None))) + ] + ) + calls: list = [] + + class _FakeLoRAManager: + def set_active_adapter(self, adapter): + calls.append(adapter) + + worker.lora_manager = _FakeLoRAManager() + worker.model_runner = SimpleNamespace(execute_micro_step=lambda _: RunnerOutput(req_id="req-1")) + worker._get_profiler = lambda: None + + DiffusionWorker.execute_micro_step(worker, scheduler_output) + assert calls == [None] + + def test_rejects_lora_requests(self): + worker = object.__new__(DiffusionWorker) + scheduler_output = SimpleNamespace( + scheduled_new_reqs=[ + SimpleNamespace(req=SimpleNamespace(sampling_params=SimpleNamespace(lora_request=object()))) + ] + ) + worker.lora_manager = None + worker.model_runner = SimpleNamespace(execute_micro_step=lambda _: RunnerOutput(req_id="req-1")) + worker._get_profiler = lambda: None + + with pytest.raises(ValueError, match="does not support LoRA"): + DiffusionWorker.execute_micro_step(worker, scheduler_output) + + +# --------------------------------------------------------------------------- +# Executor +# --------------------------------------------------------------------------- + +class TestExecutor: + """MultiprocDiffusionExecutor.execute_micro_step collects rank-0's reply.""" + + def test_passes_through_runner_output(self, mocker: MockerFixture): + executor = object.__new__(MultiprocDiffusionExecutor) + executor._ensure_open = lambda: None + expected = RunnerOutput(req_id="req-1", finished=True) + rpc = mocker.Mock(return_value=expected) + executor.collective_rpc = rpc + + sched_output = _make_micro_scheduler_output(req=_make_micro_request()) + output = MultiprocDiffusionExecutor.execute_micro_step(executor, sched_output) + + assert output is expected + _, kwargs = rpc.call_args + assert kwargs.get("unique_reply_rank") == 0 + assert kwargs.get("exec_all_ranks") is True + + def test_rejects_unexpected_reply_type(self, mocker: MockerFixture): + executor = object.__new__(MultiprocDiffusionExecutor) + executor._ensure_open = lambda: None + executor.collective_rpc = mocker.Mock(return_value="not a runner output") + + sched_output = _make_micro_scheduler_output(req=_make_micro_request()) + with pytest.raises(RuntimeError, match="Unexpected response type"): + MultiprocDiffusionExecutor.execute_micro_step(executor, sched_output) diff --git a/tests/diffusion/test_diffusion_scheduler.py b/tests/diffusion/test_diffusion_scheduler.py index ca28b26294e..52dbe7bb560 100644 --- a/tests/diffusion/test_diffusion_scheduler.py +++ b/tests/diffusion/test_diffusion_scheduler.py @@ -19,6 +19,7 @@ Scheduler, SchedulerInterface, StepScheduler, + StreamBatchScheduler, ) from vllm_omni.diffusion.sched.interface import CachedRequestData, NewRequestData from vllm_omni.diffusion.worker.utils import RunnerOutput @@ -965,3 +966,352 @@ def test_rejects_invalid_initial_step_state(self, sampling_params: OmniDiffusion with pytest.raises(ValueError): self.scheduler.add_request(request) + + +def _make_stream_request( + req_id: str, + *, + num_inference_steps: int = 2, + num_chunks: int = 1, + chunk_frames: int = 1, +) -> OmniDiffusionRequest: + num_frames = num_chunks * chunk_frames + video = [torch.zeros(3, 8, 8) for _ in range(num_frames)] + return OmniDiffusionRequest( + prompts=[{ + "prompt": f"prompt_{req_id}", + "multi_modal_data": {"video": video}, + }], + sampling_params=OmniDiffusionSamplingParams( + num_inference_steps=num_inference_steps, + chunk_frames=chunk_frames, + num_chunks=num_chunks, + num_frames=num_frames, + ), + request_ids=[req_id], + ) + + +def _make_stream_output( + req_id: str, + *, + finished: bool = False, + error: str | None = None, + micro_step_wall_ns: int | None = None, +): + return SimpleNamespace( + req_id=req_id, + step_index=None, + finished=finished, + result=DiffusionOutput(output=None, error=error) if error is not None else None, + micro_step_wall_ns=micro_step_wall_ns, + ) + + +def _make_od_config(pp_size: int) -> SimpleNamespace: + return SimpleNamespace(parallel_config=SimpleNamespace(pipeline_parallel_size=pp_size)) + + +def _ranks(sched_output) -> list[tuple[str, list[int]] | None]: + if sched_output.assignment is None: + return [] + return [ + (t.sched_req_id, list(t.chunk_indices)) if t.chunk_indices else None + for t in sched_output.assignment + ] + + +def _layout(sched_output, sched_req_id: str) -> tuple[int, int, int] | None: + if sched_output.assignment is None: + return None + for task in sched_output.assignment: + if task.sched_req_id == sched_req_id: + layout = task.layout + return (len(layout.finished_idxs), len(layout.circulating_idxs), len(layout.new_idxs)) + return None + + +class TestStreamBatchScheduler: + def _make_scheduler(self, pp_size: int = 2) -> StreamBatchScheduler: + sched = StreamBatchScheduler() + sched.initialize(_make_od_config(pp_size)) + return sched + + def test_add_request_rejects_invalid_num_chunks(self) -> None: + scheduler = self._make_scheduler() + with pytest.raises(ValueError): + scheduler.add_request(_make_stream_request("bad-chunks", num_chunks=0)) + + def test_add_request_rejects_invalid_num_inference_steps(self) -> None: + scheduler = self._make_scheduler() + with pytest.raises(ValueError): + scheduler.add_request(_make_stream_request("bad-steps", num_inference_steps=0)) + + def test_pp1_single_chunk_single_step(self) -> None: + scheduler = self._make_scheduler(pp_size=1) + req_id = scheduler.add_request(_make_stream_request("a", num_inference_steps=1, num_chunks=1)) + + out0 = scheduler.schedule() + assert _new_ids(out0) == [req_id] + assert _ranks(out0) == [(req_id, [0])] + assert _layout(out0, req_id) == (0, 0, 1) + assert scheduler.update_from_output(out0, _make_stream_output(req_id)) == set() + + out1 = scheduler.schedule() + assert _ranks(out1) == [None] + assert _layout(out1, req_id) == (1, 0, 0) + finished = scheduler.update_from_output(out1, _make_stream_output(req_id, finished=True)) + assert finished == {req_id} + assert scheduler.get_request_state(req_id).status == DiffusionRequestStatus.FINISHED_COMPLETED + assert scheduler.has_requests() is False + + def test_pp1_single_chunk_multi_step_re_admits_same_chunk(self) -> None: + scheduler = self._make_scheduler(pp_size=1) + req_id = scheduler.add_request(_make_stream_request("multi", num_inference_steps=3, num_chunks=1)) + + out0 = scheduler.schedule() + assert _ranks(out0) == [(req_id, [0])] + assert _layout(out0, req_id) == (0, 0, 1) + scheduler.update_from_output(out0, _make_stream_output(req_id)) + + out1 = scheduler.schedule() + assert _ranks(out1) == [(req_id, [0])] + assert _layout(out1, req_id) == (0, 1, 0) + scheduler.update_from_output(out1, _make_stream_output(req_id)) + + out2 = scheduler.schedule() + assert _ranks(out2) == [(req_id, [0])] + assert _layout(out2, req_id) == (0, 1, 0) + scheduler.update_from_output(out2, _make_stream_output(req_id)) + + out3 = scheduler.schedule() + assert _ranks(out3) == [None] + assert _layout(out3, req_id) == (1, 0, 0) + assert scheduler.update_from_output(out3, _make_stream_output(req_id, finished=True)) == {req_id} + + def test_pp1_multi_chunk_admits_in_order(self) -> None: + scheduler = self._make_scheduler(pp_size=1) + req_id = scheduler.add_request(_make_stream_request("multi", num_inference_steps=1, num_chunks=2)) + + out0 = scheduler.schedule() + assert _ranks(out0) == [(req_id, [0])] + assert _layout(out0, req_id) == (0, 0, 1) + scheduler.update_from_output(out0, _make_stream_output(req_id)) + + out1 = scheduler.schedule() + assert _ranks(out1) == [(req_id, [1])] + assert _layout(out1, req_id) == (1, 0, 1) + scheduler.update_from_output(out1, _make_stream_output(req_id)) + + out2 = scheduler.schedule() + assert _ranks(out2) == [None] + assert _layout(out2, req_id) == (1, 0, 0) + assert scheduler.update_from_output(out2, _make_stream_output(req_id, finished=True)) == {req_id} + + def test_pp2_pipelined_chunks_advance_through_ranks(self) -> None: + scheduler = self._make_scheduler(pp_size=2) + req_id = scheduler.add_request(_make_stream_request("pp2", num_inference_steps=1, num_chunks=2)) + + out0 = scheduler.schedule() + assert _ranks(out0) == [(req_id, [0]), None] + assert _layout(out0, req_id) == (0, 0, 1) + scheduler.update_from_output(out0, _make_stream_output(req_id)) + + out1 = scheduler.schedule() + assert _ranks(out1) == [(req_id, [1]), (req_id, [0])] + assert _layout(out1, req_id) == (0, 0, 1) + scheduler.update_from_output(out1, _make_stream_output(req_id)) + + out2 = scheduler.schedule() + assert _ranks(out2) == [None, (req_id, [1])] + assert _layout(out2, req_id) == (1, 0, 0) + scheduler.update_from_output(out2, _make_stream_output(req_id)) + + out3 = scheduler.schedule() + assert _ranks(out3) == [None, None] + assert _layout(out3, req_id) == (1, 0, 0) + assert scheduler.update_from_output(out3, _make_stream_output(req_id, finished=True)) == {req_id} + + def test_pp3_three_chunks_two_steps_each(self) -> None: + scheduler = self._make_scheduler(pp_size=3) + req_id = scheduler.add_request(_make_stream_request("pp3", num_inference_steps=2, num_chunks=3)) + + out0 = scheduler.schedule() + assert _ranks(out0) == [(req_id, [0]), None, None] + assert _layout(out0, req_id) == (0, 0, 1) + scheduler.update_from_output(out0, _make_stream_output(req_id)) + + out1 = scheduler.schedule() + assert _ranks(out1) == [(req_id, [1]), (req_id, [0]), None] + assert _layout(out1, req_id) == (0, 0, 1) + scheduler.update_from_output(out1, _make_stream_output(req_id)) + + out2 = scheduler.schedule() + assert _ranks(out2) == [(req_id, [2]), (req_id, [1]), (req_id, [0])] + assert _layout(out2, req_id) == (0, 0, 1) + scheduler.update_from_output(out2, _make_stream_output(req_id)) + + out3 = scheduler.schedule() + assert _ranks(out3) == [(req_id, [0]), (req_id, [2]), (req_id, [1])] + assert _layout(out3, req_id) == (0, 1, 0) + scheduler.update_from_output(out3, _make_stream_output(req_id)) + + out4 = scheduler.schedule() + assert _ranks(out4) == [(req_id, [1]), (req_id, [0]), (req_id, [2])] + assert _layout(out4, req_id) == (0, 1, 0) + scheduler.update_from_output(out4, _make_stream_output(req_id)) + + out5 = scheduler.schedule() + assert _ranks(out5) == [(req_id, [2]), (req_id, [1]), (req_id, [0])] + assert _layout(out5, req_id) == (0, 1, 0) + scheduler.update_from_output(out5, _make_stream_output(req_id)) + + out6 = scheduler.schedule() + assert _ranks(out6) == [None, (req_id, [2]), (req_id, [1])] + assert _layout(out6, req_id) == (1, 0, 0) + scheduler.update_from_output(out6, _make_stream_output(req_id)) + + out7 = scheduler.schedule() + assert _ranks(out7) == [None, None, (req_id, [2])] + assert _layout(out7, req_id) == (1, 0, 0) + scheduler.update_from_output(out7, _make_stream_output(req_id)) + + out8 = scheduler.schedule() + assert _ranks(out8) == [None, None, None] + assert _layout(out8, req_id) == (1, 0, 0) + assert scheduler.update_from_output(out8, _make_stream_output(req_id, finished=True)) == {req_id} + assert scheduler.has_requests() is False + + def test_returning_chunk_leads_fresh_admits_in_fifo(self) -> None: + # Re-admit prepends; new admits append. Order: [returning..., new...]. + scheduler = self._make_scheduler(pp_size=1) + req_id = scheduler.add_request(_make_stream_request("prio", num_inference_steps=2, num_chunks=2)) + + out0 = scheduler.schedule() + assert _ranks(out0) == [(req_id, [0])] + assert _layout(out0, req_id) == (0, 0, 1) + scheduler.update_from_output(out0, _make_stream_output(req_id)) + + out1 = scheduler.schedule() + assert _ranks(out1) == [(req_id, [0, 1])] + assert _layout(out1, req_id) == (0, 1, 1) + + def test_chunk_progress_cleared_after_request_finishes(self) -> None: + scheduler = self._make_scheduler(pp_size=1) + req_id = scheduler.add_request(_make_stream_request("cleanup", num_inference_steps=1, num_chunks=1)) + + out0 = scheduler.schedule() + scheduler.update_from_output(out0, _make_stream_output(req_id)) + out1 = scheduler.schedule() + scheduler.update_from_output(out1, _make_stream_output(req_id, finished=True)) + + scheduler.pop_request_state(req_id) + assert req_id not in scheduler._progress + assert scheduler.has_requests() is False + + def test_schedule_with_no_requests_emits_no_assignment(self) -> None: + scheduler = self._make_scheduler(pp_size=2) + out = scheduler.schedule() + assert out.assignment is None + assert out.scheduled_req_ids == [] + + def test_fifo_two_requests(self) -> None: + scheduler = self._make_scheduler(pp_size=1) + req_a = scheduler.add_request(_make_stream_request("a", num_inference_steps=1, num_chunks=1)) + req_b = scheduler.add_request(_make_stream_request("b", num_inference_steps=1, num_chunks=1)) + + out0 = scheduler.schedule() + assert _new_ids(out0) == [req_a] + assert _ranks(out0) == [(req_a, [0])] + scheduler.update_from_output(out0, _make_stream_output(req_a)) + + out1 = scheduler.schedule() + assert _new_ids(out1) == [] + scheduler.update_from_output(out1, _make_stream_output(req_a, finished=True)) + scheduler.pop_request_state(req_a) + + out2 = scheduler.schedule() + assert _new_ids(out2) == [req_b] + assert _ranks(out2) == [(req_b, [0])] + + def test_has_requests_state_transition(self) -> None: + scheduler = self._make_scheduler(pp_size=1) + assert scheduler.has_requests() is False + + req_id = scheduler.add_request(_make_stream_request("has", num_inference_steps=1, num_chunks=1)) + assert scheduler.has_requests() is True + + out0 = scheduler.schedule() + assert scheduler.has_requests() is True + scheduler.update_from_output(out0, _make_stream_output(req_id)) + + out1 = scheduler.schedule() + assert scheduler.update_from_output(out1, _make_stream_output(req_id, finished=True)) == {req_id} + assert scheduler.has_requests() is False + + def test_abort_waiting_and_running_requests(self) -> None: + scheduler = self._make_scheduler(pp_size=1) + req_a = scheduler.add_request(_make_stream_request("a", num_inference_steps=1, num_chunks=1)) + req_b = scheduler.add_request(_make_stream_request("b", num_inference_steps=1, num_chunks=1)) + + scheduler.finish_requests(req_b, DiffusionRequestStatus.FINISHED_ABORTED) + assert scheduler.get_request_state(req_b).status == DiffusionRequestStatus.FINISHED_ABORTED + + out = scheduler.schedule() + assert _new_ids(out) == [req_a] + + scheduler.finish_requests(req_a, DiffusionRequestStatus.FINISHED_ABORTED) + assert scheduler.get_request_state(req_a).status == DiffusionRequestStatus.FINISHED_ABORTED + assert scheduler.has_requests() is False + + def test_error_output_marks_finished_error(self) -> None: + scheduler = self._make_scheduler(pp_size=1) + req_id = scheduler.add_request(_make_stream_request("err", num_inference_steps=2, num_chunks=1)) + + out = scheduler.schedule() + finished = scheduler.update_from_output(out, _make_stream_output(req_id, error="worker failed")) + + assert finished == {req_id} + state = scheduler.get_request_state(req_id) + assert state.status == DiffusionRequestStatus.FINISHED_ERROR + assert state.error == "worker failed" + assert scheduler.has_requests() is False + + def test_preempt_request_preserves_chunk_progress(self) -> None: + scheduler = self._make_scheduler(pp_size=2) + req_id = scheduler.add_request(_make_stream_request("preempt", num_inference_steps=2, num_chunks=2)) + + out0 = scheduler.schedule() + assert _ranks(out0) == [(req_id, [0]), None] + scheduler.update_from_output(out0, _make_stream_output(req_id)) + + out1 = scheduler.schedule() + assert _ranks(out1) == [(req_id, [1]), (req_id, [0])] + scheduler.update_from_output(out1, _make_stream_output(req_id)) + + before = scheduler._progress[req_id] + assert before.next_chunk_idx == 2 + snapshot = [[c.chunk_idx for c in q] for q in before.chunks_at] + assert snapshot == [[1], [0]] + + assert scheduler.preempt_request(req_id) is True + assert scheduler.get_request_state(req_id).status == DiffusionRequestStatus.PREEMPTED + + after = scheduler._progress[req_id] + assert after.next_chunk_idx == 2 + assert [[c.chunk_idx for c in q] for q in after.chunks_at] == snapshot + + def test_b_admission(self) -> None: + scheduler = self._make_scheduler(pp_size=1) + req_id = scheduler.add_request(_make_stream_request("b", num_inference_steps=2, num_chunks=4)) + scheduler._slo.register(req_id, slo_fps=30.0, max_batch=4, chunk_frames=1) + scheduler._slo._reqs[req_id].batch_size = 2 + + out0 = scheduler.schedule() + assert _ranks(out0) == [(req_id, [0, 1])] + assert _layout(out0, req_id) == (0, 0, 2) + scheduler.update_from_output(out0, _make_stream_output(req_id)) + + out1 = scheduler.schedule() + assert _ranks(out1) == [(req_id, [0, 1, 2, 3])] + assert _layout(out1, req_id) == (0, 2, 2) diff --git a/tests/e2e/offline_inference/test_lingbot_world_fast.py b/tests/e2e/offline_inference/test_lingbot_world_fast.py new file mode 100644 index 00000000000..f80576d8bb9 --- /dev/null +++ b/tests/e2e/offline_inference/test_lingbot_world_fast.py @@ -0,0 +1,225 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""L2 offline smoke for the Lingbot World Fast pipeline.""" + +from __future__ import annotations + +import pytest +import torch + +from tests.diffusion.models.lingbot_world_fast.conftest import ( + make_dummy_camera_inputs, + make_dummy_image, + make_stubbed_pipeline, +) +from tests.helpers.mark import hardware_test +from vllm_omni.diffusion.data import DiffusionOutput +from vllm_omni.diffusion.models.lingbot_world_fast.pipeline_lingbot_world_fast import ( + CONFIG, + get_lingbot_world_fast_post_process_func, +) +from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.inputs.data import OmniDiffusionSamplingParams + +pytestmark = [pytest.mark.core_model, pytest.mark.diffusion] + + +# ``torch.amp.autocast("cuda", …)`` inside the pipeline requires CUDA at import +# time on hosts where PyTorch is compiled without CUDA support. +if not torch.cuda.is_available(): + pytest.skip( + 'Lingbot World Fast pipeline requires CUDA (torch.amp.autocast("cuda", …))', + allow_module_level=True, + ) + + +# Keep the spatial resolution tiny so the KV-cache stays small (frame_seqlen +# is derived from ``lat_h * lat_w``); the stub pipeline still exercises every +# size-related code path with full fidelity. +_TINY_MAX_AREA = 64 * 64 + +# Default ``num_frames`` argument; the pipeline floors to ``25`` internally on +# a fresh call (smallest length that maps to a non-empty latent). +_FRESH_NUM_FRAMES = 25 +_EXTENSION_NUM_FRAMES = 24 + +_DIM = 16 +_NUM_HEADS = 4 +_NUM_LAYERS = 2 +_HEAD_DIM = _DIM // _NUM_HEADS + + +def _build_request( + *, + image, + camera, + session_id: str, + num_frames: int, + prompt: str = "walk forward", +) -> OmniDiffusionRequest: + multi_modal_data: dict = {"camera": camera} + if image is not None: + multi_modal_data["image"] = image + return OmniDiffusionRequest( + prompts=[{"prompt": prompt, "multi_modal_data": multi_modal_data}], + sampling_params=OmniDiffusionSamplingParams( + height=None, + width=None, + num_frames=num_frames, + seed=42, + extra_args={"session_id": session_id}, + ), + request_ids=[f"req-{session_id}"], + ) + + +@pytest.fixture +def stubbed_pipeline(monkeypatch): + """Build a stub-backed pipeline and shrink CONFIG['max_area'] for speed.""" + monkeypatch.setitem(CONFIG, "max_area", _TINY_MAX_AREA) + pipeline = make_stubbed_pipeline( + dim=_DIM, + num_heads=_NUM_HEADS, + num_layers=_NUM_LAYERS, + target_dtype=torch.float32, + ) + yield pipeline + + +@pytest.mark.diffusion +@hardware_test(res={"cuda": "L4"}, num_cards={"cuda": 1}) +def test_session_lifecycle_fresh_then_extension(stubbed_pipeline) -> None: + """Drive a fresh + extension pair through the pipeline and assert that + ``LingbotWorldFastState`` advances exactly as the chunk arithmetic prescribes.""" + pipeline = stubbed_pipeline + session_id = "session-l2-offline" + + # --- Fresh call --------------------------------------------------------- + camera_fresh = make_dummy_camera_inputs(num_frames=_FRESH_NUM_FRAMES) + image = make_dummy_image() + req_fresh = _build_request( + image=image, + camera=camera_fresh, + session_id=session_id, + num_frames=_FRESH_NUM_FRAMES, + ) + + out_fresh = pipeline.forward(req_fresh) + + assert isinstance(out_fresh, DiffusionOutput) + assert out_fresh.output is not None + assert torch.isfinite(out_fresh.output).all(), "Fresh-call video contains NaN/Inf." + + state = pipeline.state + assert state.is_initialized is True + assert state.session_id == session_id + assert state.current_lat_f > 0 + assert state.kv_cache is not None + assert state.crossattn_cache is not None + assert state.last_decoded_latent is not None + + fresh_lat_f = state.current_lat_f + fresh_kv_size = state.kv_cache[0].shape[2] + frame_seqlen = state.frame_seqlen + assert frame_seqlen == state.lat_h * state.lat_w // 4 + assert fresh_kv_size == frame_seqlen * fresh_lat_f + + # The spatial dims must come from the input image on the fresh call so the + # extension branch can later reuse them — make sure they were captured. + assert state.h is not None and state.w is not None + assert state.lat_h is not None and state.lat_w is not None + + # --- Extension call ----------------------------------------------------- + camera_ext = make_dummy_camera_inputs(num_frames=_EXTENSION_NUM_FRAMES) + req_ext = _build_request( + image=None, + camera=camera_ext, + session_id=session_id, + num_frames=_EXTENSION_NUM_FRAMES, + ) + + out_ext = pipeline.forward(req_ext) + + assert isinstance(out_ext, DiffusionOutput) + assert out_ext.output is not None + assert torch.isfinite(out_ext.output).all(), "Extension-call video contains NaN/Inf." + + assert state.session_id == session_id, "Same session_id must not trigger a reset." + assert state.is_initialized is True + assert state.current_lat_f > fresh_lat_f, "current_lat_f must advance on extension." + ext_lat_f = state.current_lat_f - fresh_lat_f + assert ext_lat_f > 0 + + # ``extend_kv_caches`` allocates a fresh tensor of size old + frame_seqlen * + # new_lat_f and concatenates; assert the trailing slice grew by exactly + # ``frame_seqlen * ext_lat_f`` for every layer. + for layer_idx, layer in enumerate(state.kv_cache): + assert layer.shape == ( + 2, + 1, + fresh_kv_size + frame_seqlen * ext_lat_f, + _NUM_HEADS, + _HEAD_DIM, + ), f"layer {layer_idx} KV cache did not grow by exactly frame_seqlen * ext_lat_f" + + +@pytest.mark.diffusion +@hardware_test(res={"cuda": "L4"}, num_cards={"cuda": 1}) +def test_different_session_id_resets_state(stubbed_pipeline) -> None: + """A new ``session_id`` must hard-reset the cached spatial dims and KV cache""" + pipeline = stubbed_pipeline + image = make_dummy_image() + + req_a = _build_request( + image=image, + camera=make_dummy_camera_inputs(num_frames=_FRESH_NUM_FRAMES), + session_id="session-a", + num_frames=_FRESH_NUM_FRAMES, + ) + pipeline.forward(req_a) + + assert pipeline.state.session_id == "session-a" + lat_f_after_a = pipeline.state.current_lat_f + assert lat_f_after_a > 0 + + # A different session_id on the next call must drop the prior KV cache — + # ``current_lat_f`` resets to ``new_lat_f`` of the second call, not to + # ``lat_f_after_a + new_lat_f``. + req_b = _build_request( + image=make_dummy_image(), + camera=make_dummy_camera_inputs(num_frames=_FRESH_NUM_FRAMES), + session_id="session-b", + num_frames=_FRESH_NUM_FRAMES, + ) + out_b = pipeline.forward(req_b) + assert torch.isfinite(out_b.output).all() + + assert pipeline.state.session_id == "session-b" + assert pipeline.state.current_lat_f == lat_f_after_a, ( + "Stub fresh-call advances by the same fresh new_lat_f, so the reset must " + "have wiped the prior cumulative count rather than added to it." + ) + + +@pytest.mark.diffusion +@hardware_test(res={"cuda": "L4"}, num_cards={"cuda": 1}) +def test_post_process_shapes_videos_for_external_output(stubbed_pipeline) -> None: + """The model-specific post-process flips ``[C, F, H, W]`` to ``[F, H, W, C]``; + that's what diffusion engine + serving code downstream expects.""" + pipeline = stubbed_pipeline + req = _build_request( + image=make_dummy_image(), + camera=make_dummy_camera_inputs(num_frames=_FRESH_NUM_FRAMES), + session_id="session-postprocess", + num_frames=_FRESH_NUM_FRAMES, + ) + + out = pipeline.forward(req) + post = get_lingbot_world_fast_post_process_func(pipeline.od_config) + framed = post(out.output) + + # [C, F, H, W] → [F, H, W, C] + assert framed.ndim == 4 + assert framed.shape[-1] == out.output.shape[0] + assert framed.shape[0] == out.output.shape[1] diff --git a/tests/e2e/offline_inference/test_lingbot_world_fast_expansion.py b/tests/e2e/offline_inference/test_lingbot_world_fast_expansion.py new file mode 100644 index 00000000000..98a932f5680 --- /dev/null +++ b/tests/e2e/offline_inference/test_lingbot_world_fast_expansion.py @@ -0,0 +1,153 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""L3 real-weight offline expansion for Lingbot World Fast.""" + +from __future__ import annotations + +from typing import Any + +import numpy as np +import PIL.Image +import pytest +import torch + +from tests.helpers.lingbot_world_fast import ( + FPS, + GREAT_WALL_PROMPT, + HEIGHT, + LONG_NUM_FRAMES, + SEED, + SHORT_NUM_FRAMES, + SSIM_THRESHOLD, + WIDTH, + find_lingbot_world_fast_assets, + frame_ssim, + golden_frames_dir, + load_camera_trajectory, + normalize_to_uint8_rgb, +) +from tests.helpers.mark import hardware_test +from vllm_omni.entrypoints.omni import Omni +from vllm_omni.inputs.data import OmniDiffusionSamplingParams +from vllm_omni.outputs import OmniRequestOutput +from vllm_omni.platforms import current_omni_platform + +pytestmark = [ + pytest.mark.advanced_model, + pytest.mark.core_model, + pytest.mark.diffusion, +] + + +def _extract_frames_from_output(output: Any) -> np.ndarray: + """Pull a ``[N, H, W, 3]`` numpy array out of an ``OmniRequestOutput``.""" + if isinstance(output, list) and output: + output = output[0] + if isinstance(output, OmniRequestOutput): + if output.is_pipeline_output and output.request_output is not None: + inner = output.request_output + if isinstance(inner, OmniRequestOutput): + output = inner + if isinstance(output, OmniRequestOutput) and output.images: + entry = output.images[0] + if isinstance(entry, tuple) and len(entry) >= 1: + output = entry[0] + elif isinstance(entry, dict): + output = entry.get("frames") or entry.get("video") + else: + output = entry + if isinstance(output, torch.Tensor): + output = output.detach().cpu().numpy() + if not isinstance(output, np.ndarray): + raise AssertionError(f"Could not extract frames from output: {type(output)}") + return normalize_to_uint8_rgb(output) + + +@pytest.fixture(scope="module") +def lingbot_world_fast_assets(): + assets = find_lingbot_world_fast_assets() + if assets is None: + pytest.skip( + "Lingbot-World-Fast L3 assets not available. Set LINGBOT_WORLD_FAST_PATH " + "(model dir) + LINGBOT_WORLD_FAST_CAMERA_PATH (poses.npy/intrinsics.npy) " + "+ LINGBOT_WORLD_FAST_IMAGE (input image) to enable.", + ) + return assets + + +@pytest.fixture(scope="module") +def lingbot_world_fast_omni(lingbot_world_fast_assets): + omni = Omni( + model=str(lingbot_world_fast_assets.weights_path), + parallel_config=None, + model_class_name="LingbotWorldFastPipeline", + stage_init_timeout=6000, + init_timeout=6000, + ) + try: + yield omni + finally: + if hasattr(omni, "close"): + omni.close() + + +@hardware_test(res={"cuda": "H100"}, num_cards={"cuda": 1}) +@pytest.mark.parametrize("num_frames, length", [(SHORT_NUM_FRAMES, "short"), (LONG_NUM_FRAMES, "long")]) +def test_lingbot_world_offline_video( + num_frames, + length, + lingbot_world_fast_assets, + lingbot_world_fast_omni, +): + image = ( + PIL.Image.open(lingbot_world_fast_assets.image_path) + .convert("RGB") + .resize((WIDTH, HEIGHT), PIL.Image.Resampling.LANCZOS) + ) + poses, intrinsics = load_camera_trajectory(lingbot_world_fast_assets.camera_dir) + poses = poses[:num_frames] + intrinsics = intrinsics[:num_frames] + + generator = torch.Generator(device=current_omni_platform.device_type).manual_seed(SEED) + sampling = OmniDiffusionSamplingParams( + height=HEIGHT, + width=WIDTH, + generator=generator, + num_frames=num_frames, + frame_rate=FPS, + extra_args={"session_id": f"SESSION_ID-{length}"}, + ) + + multi_modal_data: dict = {"image": image, "camera": {"poses": poses, "intrinsics": intrinsics}} + + prompt = { + "prompt": GREAT_WALL_PROMPT, + "negative_prompt": "", + "multi_modal_data": multi_modal_data, + } + + output = lingbot_world_fast_omni.generate(prompt, sampling) + + video = _extract_frames_from_output(output) + + first_frame = video[0] + last_frame = video[-1] + + first_path = golden_frames_dir() / f"golden_frame_{length}_first.npy" + last_path = golden_frames_dir() / f"golden_frame_{length}_last.npy" + + first_golden = np.load(first_path) + last_golden = np.load(last_path) + + ssim_first = frame_ssim(first_frame, first_golden) + ssim_last = frame_ssim(last_frame, last_golden) + print( + f"[lingbot-world-fast L3] SSIM(first)={ssim_first:.4f} SSIM(last)={ssim_last:.4f} (threshold {SSIM_THRESHOLD})" + ) + assert ssim_first >= SSIM_THRESHOLD, ( + f"First-frame SSIM {ssim_first:.4f} below {SSIM_THRESHOLD}: regression in first-call path." + ) + assert ssim_last >= SSIM_THRESHOLD, ( + f"Last-frame SSIM {ssim_last:.4f} below {SSIM_THRESHOLD}: regression in last-call path." + ) diff --git a/tests/e2e/online_serving/test_lingbot_world_fast.py b/tests/e2e/online_serving/test_lingbot_world_fast.py new file mode 100644 index 00000000000..a51cb2f542f --- /dev/null +++ b/tests/e2e/online_serving/test_lingbot_world_fast.py @@ -0,0 +1,252 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""L2 online smoke for ``/v1/realtime/world/camera``.""" + +from __future__ import annotations + +import asyncio +from collections.abc import Iterable +from typing import Any + +import numpy as np +import pytest + +from vllm_omni.entrypoints.openai.realtime.world.camera_connection import ( + CHUNK_FRAMES, + WorldCameraRealtimeConnection, +) +from vllm_omni.entrypoints.openai.realtime.world.camera_serving import ServingRealtimeWorldCamera + +msgpack_numpy = pytest.importorskip("openpi_client.msgpack_numpy") + +pytestmark = [pytest.mark.core_model, pytest.mark.diffusion] + + +# --------------------------------------------------------------------------- +# Test plumbing +# --------------------------------------------------------------------------- + + +class _MockWebSocket: + """ASGI-shaped mock matching ``WorldCameraRealtimeConnection``'s call sites. + + ``receive`` is the lowest-level ASGI hook the connection uses (it pulls + ``{"type": "websocket.receive", "bytes": ...}`` dicts directly, not the + higher-level ``receive_bytes``). After the scripted messages are + exhausted, ``receive`` returns a disconnect frame so the connection's + main loop exits cleanly instead of timing out. + """ + + def __init__(self, incoming: Iterable[dict[str, Any]] | None = None) -> None: + self._incoming: list[dict[str, Any]] = list(incoming or []) + self._idx = 0 + self.sent_bytes: list[bytes] = [] + self.sent_text: list[str] = [] + self.accepted = False + self.closed = False + + async def accept(self) -> None: + self.accepted = True + + async def receive(self) -> dict[str, Any]: + if self._idx >= len(self._incoming): + return {"type": "websocket.disconnect"} + msg = self._incoming[self._idx] + self._idx += 1 + return msg + + async def send_bytes(self, data: bytes) -> None: + self.sent_bytes.append(data) + + async def send_text(self, data: str) -> None: + self.sent_text.append(data) + + async def close(self) -> None: + self.closed = True + + +class _FakeAsyncIter: + """Async iterable for the canned engine output.""" + + def __init__(self, items: list[Any]) -> None: + self._items = list(items) + + def __aiter__(self): + return self + + async def __anext__(self): + if not self._items: + raise StopAsyncIteration + return self._items.pop(0) + + +class _FakeResult: + """Stand-in for ``OmniRequestOutput`` — only ``.images`` is consulted.""" + + def __init__(self, frames: np.ndarray) -> None: + self.images = [frames] + + +class _FakeEngineClient: + """Stand-in for ``AsyncOmni``: records calls, returns a per-call frame buffer. + + The connection's framing logic calls ``generate(...)`` once per ``infer`` + request. Tests pre-load ``self.queued_frames`` with one buffer per + expected call, in order. + """ + + def __init__(self, queued_frames: list[np.ndarray]) -> None: + self.queued_frames = list(queued_frames) + self.calls: list[dict[str, Any]] = [] + # Attributes consulted by ``CameraServerConfig.from_model_config``. + self.model_config = {"pipeline": "lingbot_world_fast", "resolution": [480, 832], "fps": 16} + + def generate(self, *, prompt, request_id, sampling_params_list): + self.calls.append( + { + "prompt": prompt, + "request_id": request_id, + "session_id": sampling_params_list[0].extra_args.get("session_id"), + } + ) + if not self.queued_frames: + raise AssertionError("FakeEngineClient ran out of queued frames") + frames = self.queued_frames.pop(0) + return _FakeAsyncIter([_FakeResult(frames)]) + + +def _pack_frame(payload: Any) -> dict[str, Any]: + return {"type": "websocket.receive", "bytes": msgpack_numpy.packb(payload)} + + +def _camera_payload(num_frames: int) -> dict[str, np.ndarray]: + return { + "intrinsics": np.eye(3, dtype=np.float32), + "poses": np.tile(np.eye(4, dtype=np.float32), (num_frames, 1, 1)), + } + + +def _infer_req(*, session_id: str, num_frames: int, include_image: bool) -> dict[str, Any]: + req: dict[str, Any] = { + "prompt": "walk along the Great Wall of China", + "camera": _camera_payload(num_frames), + "session_id": session_id, + "extra_body": {"num_frames": num_frames, "height": 480, "width": 832, "fps": 16}, + } + if include_image: + req["image"] = np.zeros((480, 832, 3), dtype=np.uint8) + return req + + +# --------------------------------------------------------------------------- +# Lifecycle smoke test +# --------------------------------------------------------------------------- + + +def test_camera_session_lifecycle_handshake_infer_reset_infer() -> None: + """End-to-end client session: handshake → infer → reset → infer (new + session_id). Mirrors what ``examples/online_serving/lingbot_world_fast/openai_client.py`` + does on the wire, minus the actual model.""" + # Distinct, non-divisible-by-CHUNK_FRAMES buffer sizes so both calls + # exercise the boundary case (final chunk shorter than CHUNK_FRAMES) and + # the fill-value lets us prove chunks aren't swapped between requests. + first_frames = np.full((CHUNK_FRAMES * 2 + 1, 8, 8, 3), 3, dtype=np.uint8) + second_frames = np.full((CHUNK_FRAMES + 1, 8, 8, 3), 7, dtype=np.uint8) + + engine_client = _FakeEngineClient(queued_frames=[first_frames, second_frames]) + serving = ServingRealtimeWorldCamera(engine_client=engine_client, model_name="lingbot") + + # First infer uses ``4N+1`` to mimic the openai_client's fresh-call shape; + # second infer uses ``4N`` for the extension shape (modelled on the + # client's branch at ``openai_client.py:84``). + fresh_req = _infer_req(session_id="session-1", num_frames=25, include_image=True) + ext_req = _infer_req(session_id="session-2", num_frames=24, include_image=False) + + ws = _MockWebSocket( + incoming=[ + _pack_frame(fresh_req), + _pack_frame({"endpoint": "reset"}), + _pack_frame(ext_req), + ] + ) + conn = WorldCameraRealtimeConnection(ws, serving) + asyncio.run(conn.handle_connection()) + + # ---------------- Handshake ------------------ + assert ws.accepted is True, "Connection must accept() before sending the handshake." + assert len(ws.sent_bytes) >= 1 + handshake = msgpack_numpy.unpackb(ws.sent_bytes[0]) + assert isinstance(handshake, dict) and handshake, "Handshake must be a non-empty msgpack dict." + assert handshake.get("pipeline") == "lingbot_world_fast" + + # ---------------- Frame chunks --------------- + decoded = [msgpack_numpy.unpackb(b) for b in ws.sent_bytes[1:]] + frame_chunks = [d for d in decoded if isinstance(d, dict) and d.get("type") == "frame"] + + first_total = (len(first_frames) + CHUNK_FRAMES - 1) // CHUNK_FRAMES + second_total = (len(second_frames) + CHUNK_FRAMES - 1) // CHUNK_FRAMES + + assert len(frame_chunks) == first_total + second_total, ( + f"Expected {first_total} + {second_total} frame chunks, got {len(frame_chunks)}." + ) + + for chunk in frame_chunks: + assert chunk.keys() >= {"type", "index", "total", "video"} + video = chunk["video"] + for frame in video: + assert frame.dtype == np.float32 + assert (frame >= 0).all() and (frame <= 1).all() + assert frame.ndim == 3 # [h, w, 3] + assert frame.shape[-1] == 3 + + # Send order is first-call chunks, then second-call chunks. Index runs + # 0..total-1 per request, and the per-chunk fill-value proves the chunker + # didn't leak state between requests. + for i, chunk in enumerate(frame_chunks[:first_total]): + assert chunk["index"] == i + assert chunk["total"] == first_total + for i, chunk in enumerate(frame_chunks[first_total:]): + assert chunk["index"] == i + assert chunk["total"] == second_total + + # ---------------- reset ---------------------- + assert "reset successful" in ws.sent_text, "Reset endpoint must reply with a text ack." + # ``ServingRealtimeWorldCamera.reset`` wipes ``_current_session_id``; + # the third request then sets it to ``session-2``. + assert serving._current_session_id == "session-2" + + # ---------------- Engine call accounting ----- + # Exactly two ``generate`` invocations; the reset request must not call + # the engine. + assert len(engine_client.calls) == 2 + assert [c["session_id"] for c in engine_client.calls] == ["session-1", "session-2"] + + +def test_camera_session_handshake_does_not_repeat_within_connection() -> None: + """The handshake is sent **once** at connect, even if many infer/reset + operations follow. Other diffusion clients depend on this invariant to + avoid double-initialising their config.""" + first_frames = np.full((CHUNK_FRAMES + 1, 4, 4, 3), 1, dtype=np.uint8) + second_frames = np.full((CHUNK_FRAMES + 2, 4, 4, 3), 2, dtype=np.uint8) + engine_client = _FakeEngineClient(queued_frames=[first_frames, second_frames]) + serving = ServingRealtimeWorldCamera(engine_client=engine_client, model_name="lingbot") + + ws = _MockWebSocket( + incoming=[ + _pack_frame(_infer_req(session_id="s", num_frames=25, include_image=True)), + _pack_frame({"endpoint": "reset"}), + _pack_frame(_infer_req(session_id="s", num_frames=24, include_image=False)), + ] + ) + conn = WorldCameraRealtimeConnection(ws, serving) + asyncio.run(conn.handle_connection()) + + # Exactly one msgpack-encoded non-frame, non-error dict in the entire + # outbound stream: the handshake. + handshakes = [] + for b in ws.sent_bytes: + decoded = msgpack_numpy.unpackb(b) + if isinstance(decoded, dict) and decoded.get("type") not in ("frame", "error"): + handshakes.append(decoded) + assert len(handshakes) == 1, f"Expected exactly one handshake, got {len(handshakes)}: {handshakes}" diff --git a/tests/e2e/online_serving/test_lingbot_world_fast_expansion.py b/tests/e2e/online_serving/test_lingbot_world_fast_expansion.py new file mode 100644 index 00000000000..3de37072f55 --- /dev/null +++ b/tests/e2e/online_serving/test_lingbot_world_fast_expansion.py @@ -0,0 +1,344 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""L3 real-weight online expansion for ``/v1/realtime/world/camera``.""" + +from __future__ import annotations + +from typing import Any + +import numpy as np +import PIL.Image +import pytest + +from tests.helpers.lingbot_world_fast import ( + FPS, + GREAT_WALL_PROMPT, + HEIGHT, + LONG_NUM_FRAMES, + SEED, + SHORT_NUM_FRAMES, + SSIM_THRESHOLD, + WIDTH, + find_lingbot_world_fast_assets, + frame_ssim, + golden_frames_dir, + load_camera_trajectory, + slice_camera_chunk, +) +from tests.helpers.mark import hardware_test +from tests.helpers.runtime import OmniServer + +# Optional protocol deps mirror what the connection itself imports lazily. +msgpack_numpy = pytest.importorskip("openpi_client.msgpack_numpy") +ws_sync = pytest.importorskip("websockets.sync.client") + +pytestmark = [ + pytest.mark.advanced_model, + pytest.mark.core_model, + pytest.mark.diffusion, +] + +_CONNECT_KWARGS = {"max_size": None, "ping_interval": None, "ping_timeout": None} + + +# --------------------------------------------------------------------------- +# Asset / golden fixtures (module-scoped to amortize file IO) +# --------------------------------------------------------------------------- + + +@pytest.fixture(scope="module") +def lingbot_world_fast_assets(): + assets = find_lingbot_world_fast_assets() + if assets is None: + pytest.skip( + "Lingbot-World-Fast L3 assets not available. Set LINGBOT_WORLD_FAST_PATH, " + "LINGBOT_WORLD_FAST_CAMERA_PATH and LINGBOT_WORLD_FAST_IMAGE.", + ) + return assets + + +_LINGBOT_SERVER_ARGS = [ + "--model-class-name", + "LingbotWorldFastPipeline", + "--ws-max-size", + "16777216", # 16 MiB — matches run_server.sh; large enough for a 480×832 image + "--ws", + "wsproto", + "--stage-init-timeout", + "6000", + "--init-timeout", + "6000", +] + + +@pytest.fixture(scope="module") +def lingbot_world_fast_server(lingbot_world_fast_assets): + """Module-scoped real-weight server; amortizes the multi-minute cold load + across the four protocol tests in this file.""" + with OmniServer( + str(lingbot_world_fast_assets.weights_path), + list(_LINGBOT_SERVER_ARGS), + use_omni=True, + ) as server: + yield server + + +def _ws_url(server: OmniServer) -> str: + return f"ws://{server.host}:{server.port}/v1/realtime/world/camera" + + +# --------------------------------------------------------------------------- +# WebSocket helpers +# --------------------------------------------------------------------------- + + +def _drain_handshake(ws) -> dict[str, Any]: + handshake = msgpack_numpy.unpackb(ws.recv()) + return handshake + + +def _send_request(ws, req: dict[str, Any]) -> None: + ws.send(msgpack_numpy.packb(req)) + + +def _drain_frames_or_error(ws) -> tuple[list[np.ndarray] | None, dict[str, Any] | None, str | None]: + """Return ``(frames, error, text)``. Exactly one of the three is non-None. + + * ``frames``: list of per-chunk uint8 arrays once ``total`` frames arrive. + * ``error``: parsed ``{"type": "error", "message": ...}`` payload. + * ``text``: text-frame reply (e.g. ``"reset successful"``). + """ + chunks: list[np.ndarray] = [] + total: int | None = None + while total is None or len(chunks) < total: + msg = ws.recv() + if isinstance(msg, str): + return None, None, msg + decoded = msgpack_numpy.unpackb(msg) + if isinstance(decoded, dict) and decoded.get("type") == "error": + return None, decoded, None + if not isinstance(decoded, dict) or decoded.get("type") != "frame": + continue # ignore unknown + total = decoded["total"] + chunks.append(np.asarray(decoded["video"])) + return chunks, None, None + + +def _build_request( + *, + session_id: str, + image: np.ndarray | None, + camera_chunk: dict[str, np.ndarray], + num_frames: int, +) -> dict[str, Any]: + req: dict[str, Any] = { + "prompt": GREAT_WALL_PROMPT, + "camera": camera_chunk, + "session_id": session_id, + "extra_body": { + "num_frames": num_frames, + "height": HEIGHT, + "width": WIDTH, + "fps": FPS, + "session_id": session_id, + "seed": SEED, + }, + } + if image is not None: + req["image"] = image + return req + + +# --------------------------------------------------------------------------- +# Test 1: Single session generation +# --------------------------------------------------------------------------- + + +@hardware_test(res={"cuda": "H100"}, num_cards={"cuda": 1}) +@pytest.mark.parametrize("num_frames, length", [(SHORT_NUM_FRAMES, "short"), (LONG_NUM_FRAMES, "long")]) +def test_lingbot_world_online_video( + num_frames, + length, + lingbot_world_fast_server, + lingbot_world_fast_assets, +): + with ws_sync.connect(_ws_url(lingbot_world_fast_server), **_CONNECT_KWARGS) as ws: + _drain_handshake(ws) + + image = ( + PIL.Image.open(lingbot_world_fast_assets.image_path) + .convert("RGB") + .resize((WIDTH, HEIGHT), PIL.Image.Resampling.LANCZOS) + ) + image = np.asarray(image) + poses, intrinsics = load_camera_trajectory(lingbot_world_fast_assets.camera_dir) + poses = poses[:num_frames] + intrinsics = intrinsics[:num_frames] + + camera = {"poses": poses, "intrinsics": intrinsics} + + req = _build_request( + session_id=f"SESSION-ID-{length}", + image=image, + camera_chunk=camera, + num_frames=num_frames, + ) + + _send_request(ws, req) + + chunks, error, text = _drain_frames_or_error(ws) + assert error is None and text is None, f"Got unexpected control reply: error={error} text={text}" + assert chunks is not None and chunks, "Returned no frames" + + reassembled = np.concatenate(chunks, axis=0) + + assert reassembled.ndim == 4 and reassembled.shape[0] >= 2, ( + f"Reassembled video has too few frames: {reassembled.shape}" + ) + + first_frame = (reassembled[0] * 255.0).round().astype(np.uint8) + last_frame = (reassembled[-1] * 255.0).round().astype(np.uint8) + + first_path = golden_frames_dir() / f"golden_frame_{length}_first.npy" + last_path = golden_frames_dir() / f"golden_frame_{length}_last.npy" + + first_golden = np.load(first_path) + last_golden = np.load(last_path) + + ssim_first = frame_ssim(first_frame, first_golden) + ssim_last = frame_ssim(last_frame, last_golden) + print( + f"[lingbot-world-fast L3 online] SSIM(first)={ssim_first:.4f} " + f"SSIM(last)={ssim_last:.4f} (threshold {SSIM_THRESHOLD})" + ) + assert ssim_first >= SSIM_THRESHOLD, ( + f"First-frame SSIM {ssim_first:.4f} below {SSIM_THRESHOLD}: regression in fresh-call path." + ) + assert ssim_last >= SSIM_THRESHOLD, ( + f"Last-frame SSIM {ssim_last:.4f} below {SSIM_THRESHOLD}: regression in extension-call path." + ) + + +# --------------------------------------------------------------------------- +# Test 2: Session-id churn mid-stream +# --------------------------------------------------------------------------- + + +@hardware_test(res={"cuda": "H100"}, num_cards={"cuda": 1}) +def test_websocket_session_id_churn_resets_state( + lingbot_world_fast_server, + lingbot_world_fast_assets, +): + """A new ``session_id`` mid-stream resets pipeline state. The next ``infer`` + that omits the image (i.e. an "extension-style" payload) must error + because the new session is fresh.""" + poses, intrinsics = load_camera_trajectory(lingbot_world_fast_assets.camera_dir) + camera_a = slice_camera_chunk(poses, intrinsics, call_index=0) + camera_b = slice_camera_chunk(poses, intrinsics, call_index=1) + + image = ( + PIL.Image.open(lingbot_world_fast_assets.image_path) + .convert("RGB") + .resize((WIDTH, HEIGHT), PIL.Image.Resampling.LANCZOS) + ) + image = np.asarray(image) + + with ws_sync.connect(_ws_url(lingbot_world_fast_server), **_CONNECT_KWARGS) as ws: + _drain_handshake(ws) + + _send_request( + ws, + _build_request( + session_id="churn-session-a", + image=image, + camera_chunk=camera_a, + num_frames=SHORT_NUM_FRAMES, + ), + ) + chunks, error, text = _drain_frames_or_error(ws) + assert chunks is not None and not error and not text, ( + f"First infer on session-a should succeed; got error={error} text={text}" + ) + + # Switch session_id mid-stream WITHOUT sending an image. The pipeline + # treats this as a fresh call (new session) and rejects. + _send_request( + ws, + _build_request( + session_id="churn-session-b", + image=None, + camera_chunk=camera_b, + num_frames=SHORT_NUM_FRAMES, + ), + ) + chunks2, error2, text2 = _drain_frames_or_error(ws) + assert error2 is not None, ( + "Server must reject a fresh session that omits ``image``; got " + f"frames={None if chunks2 is None else len(chunks2)} text={text2}" + ) + assert error2.get("type") == "error" + + +# --------------------------------------------------------------------------- +# Test 3: Mid-session ``reset`` RPC re-initializes +# --------------------------------------------------------------------------- + + +@hardware_test(res={"cuda": "H100"}, num_cards={"cuda": 1}) +def test_websocket_mid_session_reset_reinitializes( + lingbot_world_fast_server, + lingbot_world_fast_assets, +): + """After a ``reset`` text ack, the next ``infer`` with the same ``session_id`` + is a brand-new fresh call. We verify this by asserting that the + follow-up ``infer`` *without* an image errors (same logic as the + session-id churn test).""" + poses, intrinsics = load_camera_trajectory(lingbot_world_fast_assets.camera_dir) + camera_a = slice_camera_chunk(poses, intrinsics, call_index=0) + camera_b = slice_camera_chunk(poses, intrinsics, call_index=1) + + image = ( + PIL.Image.open(lingbot_world_fast_assets.image_path) + .convert("RGB") + .resize((WIDTH, HEIGHT), PIL.Image.Resampling.LANCZOS) + ) + image = np.asarray(image) + + with ws_sync.connect(_ws_url(lingbot_world_fast_server), **_CONNECT_KWARGS) as ws: + _drain_handshake(ws) + + _send_request( + ws, + _build_request( + session_id="reset-session", + image=image, + camera_chunk=camera_a, + num_frames=SHORT_NUM_FRAMES, + ), + ) + chunks, error, text = _drain_frames_or_error(ws) + assert chunks is not None and not error and not text, ( + f"Initial infer must succeed; got error={error} text={text}" + ) + + # Mid-session reset RPC. + ws.send(msgpack_numpy.packb({"endpoint": "reset"})) + _, _, reset_text = _drain_frames_or_error(ws) + assert reset_text == "reset successful", f"Expected 'reset successful' text frame, got {reset_text!r}" + + # Same session_id, no image → fresh-call branch → server error. + _send_request( + ws, + _build_request( + session_id="reset-session", + image=None, + camera_chunk=camera_b, + num_frames=SHORT_NUM_FRAMES, + ), + ) + _, post_reset_error, post_reset_text = _drain_frames_or_error(ws) + assert post_reset_error is not None, ( + "After mid-session reset the server must treat the next infer as a fresh call; " + f"missing-image payload should error. Got text={post_reset_text!r}" + ) diff --git a/tests/entrypoints/openai_api/test_realtime_world_camera.py b/tests/entrypoints/openai_api/test_realtime_world_camera.py new file mode 100644 index 00000000000..4b600e9bfae --- /dev/null +++ b/tests/entrypoints/openai_api/test_realtime_world_camera.py @@ -0,0 +1,84 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import msgspec +import pytest +from fastapi import FastAPI, WebSocket +from starlette.testclient import TestClient +from starlette.websockets import WebSocketDisconnect + +from vllm_omni.entrypoints.openai.realtime.world.camera_serving import CameraServerConfig + +pytestmark = [pytest.mark.core_model, pytest.mark.cpu] + + +mock_model_config = {"model_name": "/m/foo", "image_width": 832, "image_height": 480, "extra_args": {"foo": "bar"}} + + +def test_from_model_config_loads_correctly(): + cfg = CameraServerConfig.from_model_config(mock_model_config) + + assert cfg.to_dict() == mock_model_config + + +def test_msgpack_roundtrip(): + cfg = CameraServerConfig.from_model_config(mock_model_config) + encoded = msgspec.msgpack.encode(cfg) + decoded = msgspec.msgpack.decode(encoded, type=CameraServerConfig) + assert decoded == cfg + + +def _build_camera_app(*, supports: bool, cfg: CameraServerConfig | None): + """Build a minimal FastAPI app that mirrors the api_server.py handler.""" + app = FastAPI() + + @app.websocket("/v1/realtime/world/camera") + async def realtime_world_camera(websocket: WebSocket): + await websocket.accept() + if cfg is None or not supports: + await websocket.send_json( + {"type": "error", "error": "Camera realtime API is not available", "code": "unsupported"} + ) + await websocket.close() + return + await websocket.send_bytes(msgspec.msgpack.encode(cfg)) + try: + while True: + msg = await websocket.receive() + if msg.get("type") == "websocket.disconnect": + break + except WebSocketDisconnect: + return + + return app + + +class TestRealtimeWorldCameraEndpoint: + def test_sends_msgpack_config_on_connect(self): + cfg = CameraServerConfig.from_model_config(mock_model_config) + app = _build_camera_app(supports=True, cfg=cfg) + with TestClient(app) as client: + with client.websocket_connect("/v1/realtime/world/camera") as ws: + payload = ws.receive_bytes() + decoded = msgspec.msgpack.decode(payload, type=CameraServerConfig) + assert decoded == cfg + + def test_keeps_socket_open_after_initial_send(self): + cfg = CameraServerConfig.from_model_config(mock_model_config) + app = _build_camera_app(supports=True, cfg=cfg) + with TestClient(app) as client: + with client.websocket_connect("/v1/realtime/world/camera") as ws: + ws.receive_bytes() + # Client-initiated messages are accepted (currently ignored). + ws.send_text("ping") + # Closing from the client side must not raise on the server. + + def test_unsupported_path_sends_error_and_closes(self): + app = _build_camera_app(supports=False, cfg=None) + with TestClient(app) as client: + with client.websocket_connect("/v1/realtime/world/camera") as ws: + err = ws.receive_json() + assert err["type"] == "error" + assert err["code"] == "unsupported" + with pytest.raises(WebSocketDisconnect): + ws.receive_bytes() diff --git a/tests/helpers/lingbot_world_fast.py b/tests/helpers/lingbot_world_fast.py new file mode 100644 index 00000000000..da5e9c270d0 --- /dev/null +++ b/tests/helpers/lingbot_world_fast.py @@ -0,0 +1,243 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""Shared L3-fixture helpers for Lingbot World Fast expansion tests.""" + +from __future__ import annotations + +import os +from dataclasses import dataclass +from pathlib import Path + +import numpy as np + +# --------------------------------------------------------------------------- +# Constants — single source of truth across the two expansion tests +# --------------------------------------------------------------------------- + +# Mirrors ``examples/offline_inference/lingbot_world_fast/end2end.py`` and +# ``examples/online_serving/lingbot_world_fast/openai_client.py``. +GREAT_WALL_PROMPT = "A sweeping cinematic journey along the Great Wall of China, winding through golden autumn hills under a brilliant blue sky — stone pathways stretch into the distance, watchtowers stand sentinel, and vibrant foliage blankets the mountainsides as the camera glides smoothly forward, capturing the grandeur and timeless majesty of this ancient wonder." +SEED = 42 +WIDTH = 832 +HEIGHT = 480 +FPS = 16 + +SHORT_NUM_FRAMES = 25 +LONG_NUM_FRAMES = 81 + +EXTENSION_WARMUP_DROP = 4 + +SSIM_THRESHOLD = 0.95 + + +@dataclass(frozen=True) +class LingbotWorldFastAssets: + """All external assets a real-weight Lingbot World Fast test needs.""" + + weights_path: Path + camera_dir: Path + image_path: Path + + +# --------------------------------------------------------------------------- +# Resolution helpers (no pytest imports here — callers decide whether to skip) +# --------------------------------------------------------------------------- + + +def _hf_cache_root() -> Path: + return Path(os.environ.get("HF_HOME", str(Path.home() / ".cache" / "huggingface"))) + + +def _hf_model_snapshot_dirs(repo_id: str) -> list[Path]: + snapshots = _hf_cache_root() / "hub" / f"models--{repo_id.replace('/', '--')}" / "snapshots" + if not snapshots.exists(): + return [] + return sorted( + (p for p in snapshots.iterdir() if p.is_dir()), + key=lambda p: p.stat().st_mtime, + reverse=True, + ) + + +def _repo_root() -> Path: + # tests/helpers/lingbot_world_fast.py → repo root + return Path(__file__).resolve().parents[2] + + +def _example_checkpoint_root() -> Path: + return ( + _repo_root() + / "examples" + / "offline_inference" + / "lingbot_world_fast" + / "lingbot_world" + / "lingbot-world-base-cam" + / "Lingbot-World-Fast" + ) + + +def _example_camera_root_candidates() -> list[Path]: + base = ( + _repo_root() + / "examples" + / "offline_inference" + / "lingbot_world_fast" + / "lingbot_world" + / "lingbot-world-base-cam" + ) + return [base / "examples" / "04", base / "04", base] + + +def find_lingbot_world_fast_weights() -> Path | None: + """Return a path to the Lingbot World Fast model directory, or ``None``. + + Resolution order: ``LINGBOT_WORLD_FAST_PATH`` env var → committed example + checkpoint → HF cache snapshot of ``robbyant/lingbot-world-base-cam``. + A *path is real* only when the required ``config.json`` plus at least + one ``model-*.safetensors`` shard are present, so a half-pulled snapshot + doesn't masquerade as a usable checkpoint. + """ + override = os.environ.get("LINGBOT_WORLD_FAST_PATH") + candidates: list[Path] = [] + if override: + candidates.append(Path(override)) + + candidates.append(_example_checkpoint_root()) + + for snapshot in _hf_model_snapshot_dirs("robbyant/lingbot-world-base-cam"): + candidates.append(snapshot / "Lingbot-World-Fast") + + for path in candidates: + if not path.exists() or not path.is_dir(): + continue + if not (path / "config.json").exists(): + continue + if not any(path.glob("model-*.safetensors")): + continue + return path + return None + + +def find_lingbot_world_fast_camera_dir() -> Path | None: + """Locate a directory with ``poses.npy`` + ``intrinsics.npy``.""" + override = os.environ.get("LINGBOT_WORLD_FAST_CAMERA_PATH") + candidates: list[Path] = [] + if override: + candidates.append(Path(override)) + candidates.extend(_example_camera_root_candidates()) + for path in candidates: + if path.exists() and (path / "poses.npy").exists() and (path / "intrinsics.npy").exists(): + return path + return None + + +def find_lingbot_world_fast_image() -> Path | None: + """Locate the example input image used by ``run_fast.sh`` case 2.""" + override = os.environ.get("LINGBOT_WORLD_FAST_IMAGE") + candidates: list[Path] = [] + if override: + candidates.append(Path(override)) + for camera_root in _example_camera_root_candidates(): + for name in ("image.jpg", "image.png", "input.jpg", "input.png"): + candidates.append(camera_root / name) + for path in candidates: + if path.exists() and path.is_file(): + return path + return None + + +def find_lingbot_world_fast_assets() -> LingbotWorldFastAssets | None: + """Resolve weights + camera trajectory + image in one call, or return None + if any of the three is missing — callers can then ``pytest.skip`` with a + single specific reason.""" + weights = find_lingbot_world_fast_weights() + camera = find_lingbot_world_fast_camera_dir() + image = find_lingbot_world_fast_image() + with open("output.txt", "w+") as f: + f.write(f"{weights is None} {camera is None} {image is None}") + if not (weights and camera and image): + return None + return LingbotWorldFastAssets(weights_path=weights, camera_dir=camera, image_path=image) + + +def golden_frames_dir() -> Path: + return _repo_root() / "tests" / "data" / "lingbot_world_fast" + + +# --------------------------------------------------------------------------- +# Per-call payload helpers +# --------------------------------------------------------------------------- + + +def load_camera_trajectory(camera_dir: Path) -> tuple[np.ndarray, np.ndarray]: + poses = np.load(camera_dir / "poses.npy") + intrinsics = np.load(camera_dir / "intrinsics.npy") + return poses, intrinsics + + +def slice_camera_chunk( + poses: np.ndarray, + intrinsics: np.ndarray, + *, + call_index: int, + chunk_stride: int = SHORT_NUM_FRAMES, +) -> dict[str, np.ndarray]: + """Mirrors the slicing in ``examples/online_serving/lingbot_world_fast/openai_client.py``. + + Each call consumes ``chunk_stride`` poses. The model will floor internally if the slice has fewer + poses than requested. + """ + start = call_index * chunk_stride + end = start + chunk_stride + poses_slice = poses[start:end] + intrinsics_slice = intrinsics[start:end] if intrinsics.ndim > 2 else intrinsics + return {"poses": poses_slice, "intrinsics": intrinsics_slice} + + +# --------------------------------------------------------------------------- +# Frame post-processing helpers +# --------------------------------------------------------------------------- + + +def reassemble_chunked_video( + per_call_frames: list[np.ndarray], + *, + drop_warmup: int = EXTENSION_WARMUP_DROP, +) -> np.ndarray: + """Concatenate per-call frame chunks, dropping ``drop_warmup`` leading + frames on every extension call (call index >= 1).""" + assembled: list[np.ndarray] = [] + for i, frames in enumerate(per_call_frames): + clip = frames[drop_warmup:] if i > 0 else frames + assembled.append(clip) + return np.concatenate(assembled, axis=0) + + +def normalize_to_uint8_rgb(frames: np.ndarray) -> np.ndarray: + """Coerce a generated frames tensor to ``[N, H, W, 3]`` ``uint8``. + + The diffusion engine emits either an unsigned-int chunk (pre-encoded by + ``_normalize_frames``) or a float tensor in ``[-1, 1]``. We accept both + so the SSIM helper sees a single canonical shape. + """ + arr = frames + if arr.dtype.kind == "f": + arr = np.clip((arr + 1.0) * 0.5, 0.0, 1.0) + arr = (arr * 255.0).round().astype(np.uint8) + if arr.ndim == 5 and arr.shape[0] == 1: + arr = arr[0] + return arr + + +def frame_ssim(prediction: np.ndarray, reference: np.ndarray) -> float: + """Per-frame SSIM with ``data_range=1``. Uses ``torchmetrics`` (already a + transitive dep) and accepts ``[H, W, 3]`` uint8 arrays. + """ + import torch + from torchmetrics.image import StructuralSimilarityIndexMeasure + + pred_t = (torch.from_numpy(prediction.astype(np.float32) / 255.0)).permute(2, 0, 1).unsqueeze(0) + ref_t = (torch.from_numpy(reference.astype(np.float32) / 255.0)).permute(2, 0, 1).unsqueeze(0) + metric = StructuralSimilarityIndexMeasure(data_range=1.0) + return float(metric(pred_t, ref_t).item()) diff --git a/tests/worker/test_omni_connector_mixin.py b/tests/worker/test_omni_connector_mixin.py index b8eb899fa32..a11ac588a35 100644 --- a/tests/worker/test_omni_connector_mixin.py +++ b/tests/worker/test_omni_connector_mixin.py @@ -797,7 +797,7 @@ def mock_process(transfer_manager, pooling_output, request): class TestLocalPayloadCacheLifecycle(unittest.TestCase): - """Unit tests for the local payload cache API (RFC §2.4).""" + """Unit tests for the local payload cache API""" def _make_host(self) -> MixinHost: host = MixinHost() diff --git a/vllm_omni/diffusion/data.py b/vllm_omni/diffusion/data.py index a6fe1e4e9c7..31660b2482e 100644 --- a/vllm_omni/diffusion/data.py +++ b/vllm_omni/diffusion/data.py @@ -589,6 +589,10 @@ class OmniDiffusionConfig: # sleep mode enable_sleep_mode: bool = False + # Temporal pipeline parallelism (StreamBatchScheduler-driven streaming chunks). + # When True, the engine uses ``StreamBatchScheduler`` and routes execution + # through ``executor.execute_micro_step``. + stream_batch: bool = False # Maximum number of sequences to generate in a batch max_num_seqs: int = 1 @@ -901,6 +905,9 @@ def enrich_config(self) -> None: self.model_class_name = "WanS2VPipeline" self.tf_model_config = TransformerConfig() self.update_multimodal_support() + elif self.model_class_name == "LingbotWorldFastPipeline": + self.tf_config_dict = get_hf_file_to_dict("config.json", self.model) + self.tf_model_config = TransformerConfig.from_dict(self.tf_config_dict) elif architectures and len(architectures) == 1: self.model_class_name = architectures[0] else: diff --git a/vllm_omni/diffusion/diffusion_engine.py b/vllm_omni/diffusion/diffusion_engine.py index c13bd3c0c37..075581b30f9 100644 --- a/vllm_omni/diffusion/diffusion_engine.py +++ b/vllm_omni/diffusion/diffusion_engine.py @@ -31,11 +31,17 @@ get_diffusion_pre_process_func, ) from vllm_omni.diffusion.request import OmniDiffusionRequest -from vllm_omni.diffusion.sched import RequestScheduler, SchedulerInterface, StepScheduler +from vllm_omni.diffusion.sched import ( + RequestScheduler, + SchedulerInterface, + StepScheduler, + StreamBatchScheduler, +) from vllm_omni.diffusion.sched.interface import DiffusionRequestStatus from vllm_omni.diffusion.worker.utils import BatchRunnerOutput, RunnerOutput from vllm_omni.inputs.data import OmniDiffusionSamplingParams, OmniTextPrompt from vllm_omni.outputs import OmniRequestOutput +from vllm_omni.diffusion.registry import apply_required_sampling_overrides logger = init_logger(__name__) @@ -71,6 +77,13 @@ def supports_multimodal_input(od_config: OmniDiffusionConfig) -> tuple[bool, boo return supports_image_input, supports_audio_input +def supports_camera_pos_input(model_class_name: str) -> bool: + model_cls = DiffusionModelRegistry._try_load_model_cls(model_class_name) + if model_cls is None: + return False + return bool(getattr(model_cls, "support_camera_pos_input", False)) + + def image_color_format(model_class_name: str) -> str: model_cls = DiffusionModelRegistry._try_load_model_cls(model_class_name) return getattr(model_cls, "color_format", "RGB") @@ -138,9 +151,16 @@ def __init__( executor_class = DiffusionExecutor.get_class(od_config) self.executor = executor_class(od_config) self.step_execution = bool(getattr(od_config, "step_execution", False)) - self.scheduler: SchedulerInterface = scheduler or ( - StepScheduler() if self.step_execution else RequestScheduler() - ) + self.stream_batch = bool(getattr(od_config, "stream_batch", False)) + + if scheduler is not None: + self.scheduler: SchedulerInterface = scheduler + elif self.stream_batch: + self.scheduler = StreamBatchScheduler() + elif self.step_execution: + self.scheduler = StepScheduler() + else: + self.scheduler = RequestScheduler() self.scheduler.initialize(od_config) if self.scheduler.max_num_running_reqs > 1 and not self.step_execution: max_num_seqs = self.scheduler.max_num_running_reqs @@ -162,7 +182,12 @@ def __init__( self._shutdown_complete = False self.abort_queue: queue.Queue[str] = queue.Queue() self._rpc_queue: queue.Queue[_RpcTask] = queue.Queue() - self.execute_fn = self.executor.execute_step if self.step_execution else self.executor.execute_request + if self.stream_batch: + self.execute_fn = self.executor.execute_micro_step + elif self.step_execution: + self.execute_fn = self.executor.execute_step + else: + self.execute_fn = self.executor.execute_request try: self._dummy_run() @@ -575,6 +600,10 @@ def make_engine( return DiffusionEngine(config, scheduler=scheduler) def add_request(self, request: OmniDiffusionRequest) -> str: + apply_required_sampling_overrides( + request.sampling_params, self.od_config.model_class_name, + ) + with self._cv: if self._closed: raise RuntimeError("DiffusionEngine is closed.") @@ -603,6 +632,9 @@ async def async_add_req_and_wait_for_response(self, request: OmniDiffusionReques return await self.get_result(sched_req_id) def add_req_and_wait_for_response(self, request: OmniDiffusionRequest) -> DiffusionOutput: + apply_required_sampling_overrides( + request.sampling_params, self.od_config.model_class_name, + ) with self._rpc_lock: if self._closed: raise RuntimeError("DiffusionEngine is closed.") @@ -691,6 +723,27 @@ def _dummy_run(self): dummy_audio = np.random.randn(audio_sr * 2).astype(np.float32) prompt.setdefault("multi_modal_data", {})["audio"] = dummy_audio + audio_duration_sec = 4 + audio_array = np.random.randn(audio_sr * audio_duration_sec).astype(np.float32) + dummy_audio = audio_array[audio_sr * 1 : audio_sr * 3] + else: + dummy_audio = None + + if supports_camera_pos_input(self.od_config.model_class_name): + camera_pos_len = 64 + # Shape [N x 4] + intrinsics = np.random.rand(camera_pos_len, 4) + # Shape [N x 4 x 4] + poses = np.array([np.identity(4) for _ in range(camera_pos_len)]) + + dummy_camera_pos = {"intrinsics": intrinsics, "poses": poses} + else: + dummy_camera_pos = None + + prompt: OmniTextPrompt = { + "prompt": "dummy run", + "multi_modal_data": {"image": dummy_image, "audio": dummy_audio, "camera": dummy_camera_pos}, + } req = OmniDiffusionRequest( prompts=[prompt], request_ids=["dummy_req_id"], @@ -932,4 +985,4 @@ def _finalize_finished_request( if runner_output is not None and runner_output.result is not None: return runner_output.result - return DiffusionOutput(error=missing_result_error) + return DiffusionOutput(error=missing_result_error) \ No newline at end of file diff --git a/vllm_omni/diffusion/distributed/group_coordinator.py b/vllm_omni/diffusion/distributed/group_coordinator.py index 5c5b092ef21..2e79341ab8f 100644 --- a/vllm_omni/diffusion/distributed/group_coordinator.py +++ b/vllm_omni/diffusion/distributed/group_coordinator.py @@ -5,6 +5,7 @@ # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. import pickle from collections import namedtuple +from contextlib import nullcontext from typing import Any import torch @@ -699,6 +700,16 @@ def __init__( self.send_shape: dict[str, dict[int, torch.Size]] = {} self.recv_buffer: dict[str, dict[int, torch.Size]] = {} + # Cached dict schema and pre-allocated recv buffers for + # `pipeline_isend_tensor_dict` / `pipeline_irecv_tensor_dict`. + # Keyed by (name, segment_idx). Recv buffer leaf is a length-2 list + # for double buffering. Caller picks the slot via buf_idx. + self.dict_schema_cache: dict[tuple[str, int], list[tuple[str, Any]]] = {} + self.dict_recv_buffer: dict[tuple[str, int], list[dict[str, torch.Tensor]]] = {} + self._comms_stream: Any = None # Dedicated comms stream for PP P2P. None on CPU. + + self.dict_schema_keepalive: list[torch.Tensor] = [] + self.skip_tensor_recv_buffer_set: bool = False self.recv_skip_tasks_queue: list[int | tuple[str, int]] = [] self.receiving_skip_tasks: list[tuple[torch.distributed.Work, str, int]] = [] @@ -710,6 +721,45 @@ def __init__( self.skip_device_group = skip_device_group assert self.skip_device_group is not None + self._warmup_nccl_comms() + + def _warmup_nccl_comms(self) -> None: + """Force eager ncclCommInit on every P2P group while all ranks are + synchronized at __init__. Otherwise the first real P2P op would + trigger a collective comm-init that blocks the early-arriving + rank — breaks temporal-PP where one rank deliberately runs ahead. + """ + if self.world_size == 1: + return + + dummy = torch.zeros(1, device=self.device, dtype=torch.uint8) + + if self.world_size == 2: + for group_idx in (0, 1): + group = self.device_groups[group_idx] + if self.rank_in_group == group_idx: + op = torch.distributed.P2POp(torch.distributed.isend, dummy, self.next_rank, group) + else: + op = torch.distributed.P2POp(torch.distributed.irecv, dummy, self.prev_rank, group) + for req in torch.distributed.batch_isend_irecv([op]): + req.wait() + else: + for req in torch.distributed.batch_isend_irecv( + [ + torch.distributed.P2POp(torch.distributed.isend, dummy, self.next_rank, self.device_group), + torch.distributed.P2POp(torch.distributed.irecv, dummy, self.prev_rank, self.device_group), + ] + ): + req.wait() + + for req in torch.distributed.batch_isend_irecv( + [ + torch.distributed.P2POp(torch.distributed.isend, dummy, self.skip_rank, self.skip_device_group), + torch.distributed.P2POp(torch.distributed.irecv, dummy, self.skip_rank, self.skip_device_group), + ] + ): + req.wait() + def reset_buffer(self): self.recv_tasks_queue = [] self.receiving_tasks = [] @@ -717,10 +767,40 @@ def reset_buffer(self): self.send_shape = {} self.recv_buffer = {} + self.dict_schema_cache = {} + self.dict_recv_buffer = {} + self.dict_schema_keepalive = [] + self.recv_skip_tasks_queue = [] self.receiving_skip_tasks = [] self.skip_tensor_recv_buffer = {} + @property + def comms_stream(self): + """Dedicated stream for PP P2P comms.""" + if self._comms_stream is None and self.device.type != "cpu": + mod = getattr(torch, self.device.type, None) + if mod is not None and hasattr(mod, "Stream"): + self._comms_stream = mod.Stream(device=self.device) + return self._comms_stream + + def _comms_stream_ctx(self): + """Context manager that makes ``comms_stream`` the current stream.""" + stream = self.comms_stream + if stream is None: + return nullcontext() + return getattr(torch, self.device.type).stream(stream) + + def _record_compute_event(self): + """Record an event on the default (compute) stream for later + ``comms_stream.wait_event``.""" + if self.comms_stream is None: + return None + mod = getattr(torch, self.device.type) + ev = mod.Event() + ev.record(mod.current_stream(self.device)) + return ev + def set_config(self, dtype: torch.dtype): self.dtype = dtype @@ -808,6 +888,12 @@ def _communicate_shapes(self, tensor_send_to_next=None, recv_prev=False): recv_prev: boolean for whether tensor should be received from previous rank. """ + send_group = ( + self.device_groups[self.rank_in_group % 2] if self.world_size == 2 else self.device_group + ) + recv_group = ( + self.device_groups[(self.rank_in_group + 1) % 2] if self.world_size == 2 else self.device_group + ) ops = [] if recv_prev: @@ -816,7 +902,7 @@ def _communicate_shapes(self, tensor_send_to_next=None, recv_prev=False): torch.distributed.irecv, recv_prev_dim_tensor, self.prev_rank, - self.device_group, + recv_group, ) ops.append(recv_prev_dim_op) @@ -826,7 +912,7 @@ def _communicate_shapes(self, tensor_send_to_next=None, recv_prev=False): torch.distributed.isend, send_next_dim_tensor, self.next_rank, - self.device_group, + send_group, ) ops.append(send_next_dim_op) @@ -849,7 +935,7 @@ def _communicate_shapes(self, tensor_send_to_next=None, recv_prev=False): torch.distributed.irecv, recv_prev_shape_tensor, self.prev_rank, - self.device_group, + recv_group, ) ops.append(recv_prev_shape_op) @@ -859,7 +945,7 @@ def _communicate_shapes(self, tensor_send_to_next=None, recv_prev=False): torch.distributed.isend, send_next_shape_tensor, self.next_rank, - self.device_group, + send_group, ) ops.append(send_next_shape_op) @@ -875,15 +961,55 @@ def _communicate_shapes(self, tensor_send_to_next=None, recv_prev=False): recv_prev_shape = recv_prev_shape_tensor return torch.Size(recv_prev_shape) + def _isend_dict_schema( + self, send_metadata: list[tuple[str, Any]] + ) -> tuple[list[torch.distributed.Work], list[torch.Tensor]]: + """Non-blocking schema send. Returns (handles, keepalive_tensors). + Caller must keep the tensors alive until the handles complete. + """ + send_group = ( + self.device_groups[self.rank_in_group % 2] if self.world_size == 2 else self.device_group + ) + payload_bytes = pickle.dumps(send_metadata) + payload_array = bytearray(payload_bytes) + payload_tensor = torch.frombuffer(payload_array, dtype=torch.uint8).to(self.device) + send_size_tensor = torch.tensor( + [payload_tensor.numel()], device=self.device, dtype=torch.int64 + ) + handles = [ + torch.distributed.isend(send_size_tensor, dst=self.next_rank, group=send_group), + torch.distributed.isend(payload_tensor, dst=self.next_rank, group=send_group), + ] + return handles, [send_size_tensor, payload_tensor] + + def _recv_dict_schema(self) -> list[tuple[str, Any]]: + """Blocking schema recv - must wait because the size value is + needed before allocating the payload buffer. + """ + recv_group = ( + self.device_groups[(self.rank_in_group + 1) % 2] if self.world_size == 2 else self.device_group + ) + recv_size_tensor = torch.empty(1, device=self.device, dtype=torch.int64) + torch.distributed.recv(recv_size_tensor, src=self.prev_rank, group=recv_group) + recv_payload = torch.empty(int(recv_size_tensor.item()), device=self.device, dtype=torch.uint8) + torch.distributed.recv(recv_payload, src=self.prev_rank, group=recv_group) + return pickle.loads(recv_payload.cpu().numpy().tobytes()) + def pipeline_send(self, tensor: torch.Tensor, name: str = "latent", segment_idx: int = -1) -> None: tensor = tensor.contiguous() self._check_shape_and_buffer(tensor_send_to_next=tensor, name=name, segment_idx=segment_idx) self._pipeline_isend(tensor).wait() - def pipeline_isend(self, tensor: torch.Tensor, name: str = "latent", segment_idx: int = -1) -> None: + def pipeline_isend( + self, tensor: torch.Tensor, name: str = "latent", segment_idx: int = -1 + ) -> torch.distributed.Work: tensor = tensor.contiguous() self._check_shape_and_buffer(tensor_send_to_next=tensor, name=name, segment_idx=segment_idx) - self._pipeline_isend(tensor) + handle = self._pipeline_isend(tensor) + if tensor.is_cuda: + # Keep allocator from reusing this CUDA buffer before the async send finishes. + tensor.record_stream(torch.cuda.current_stream(tensor.device)) + return handle def pipeline_recv(self, idx: int = -1, name: str = "latent") -> torch.Tensor: name = name or "latent" @@ -891,6 +1017,132 @@ def pipeline_recv(self, idx: int = -1, name: str = "latent") -> torch.Tensor: self._pipeline_irecv(self.recv_buffer[name][idx]).wait() return self.recv_buffer[name][idx] + def set_recv_dict_buffer( + self, + name: str, + segment_idx: int, + template_dict: dict[str, torch.Tensor | Any], + batch_size: int = 1, + ) -> None: + """Pre-populate schema cache + a double-buffer pair (indices 0/1) for + ``(name, segment_idx, batch_size)``. + """ + metadata_list, _ = _split_tensor_dict(template_dict) + key = (name, segment_idx, batch_size) + self.dict_schema_cache[key] = metadata_list + buffer_pair: list[dict[str, torch.Tensor]] = [] + for _ in range(2): + buffers: dict[str, torch.Tensor] = {} + for key_, value in metadata_list: + if isinstance(value, TensorMetadata): + if torch.Size(value.size).numel() == 0: + continue + device = torch.device(value.device) if value.device == "cpu" else self.device + buffers[key_] = torch.empty(value.size, dtype=value.dtype, device=device) + buffer_pair.append(buffers) + self.dict_recv_buffer[key] = buffer_pair + + def pipeline_isend_tensor_dict( + self, + tensor_dict: dict[str, torch.Tensor | Any], + name: str = "dict", + segment_idx: int = -1, + batch_size: int = 1, + ) -> list[torch.distributed.Work]: + """Non-blocking dict send keyed by ``(name, segment_idx, batch_size)``.""" + metadata_list, tensor_list = _split_tensor_dict(tensor_dict) + + key = (name, segment_idx, batch_size) + handles: list[torch.distributed.Work] = [] + if key not in self.dict_schema_cache: + schema_handles, keepalive = self._isend_dict_schema(metadata_list) + handles.extend(schema_handles) + self.dict_schema_keepalive.extend(keepalive) + self.dict_schema_cache[key] = metadata_list + + compute_done = self._record_compute_event() + comms = self.comms_stream + group = ( + self.device_groups[self.rank_in_group % 2] + if self.world_size == 2 + else self.device_group + ) + with self._comms_stream_ctx(): + if comms is not None and compute_done is not None: + comms.wait_event(compute_done) + for tensor in tensor_list: + if tensor.numel() == 0: + continue + tensor = tensor.contiguous() + if tensor.is_cuda and comms is not None: + tensor.record_stream(comms) + handles.append( + torch.distributed.isend(tensor, dst=self.next_rank, group=group) + ) + return handles + + def pipeline_irecv_tensor_dict( + self, + name: str = "dict", + segment_idx: int = -1, + buf_idx: int = 0, + batch_size: int = 1, + ) -> tuple[dict[str, torch.Tensor | Any], list[torch.distributed.Work], list]: + """Async tensor-dict recv into the ``buf_idx`` slot (0 or 1) of the + double-buffer pair for ``(name, segment_idx, batch_size)``. Caller picks + the slot — typically ``micro_step % 2`` — so consecutive recvs alternate + and the previous result stays readable until its consumer is done. + Posts irecvs on ``comms_stream``. + """ + key = (name, segment_idx, batch_size) + if key not in self.dict_schema_cache: + metadata_list = self._recv_dict_schema() + self.dict_schema_cache[key] = metadata_list + buffer_pair: list[dict[str, torch.Tensor]] = [] + for _ in range(2): + buffers: dict[str, torch.Tensor] = {} + for k, value in metadata_list: + if isinstance(value, TensorMetadata): + if torch.Size(value.size).numel() == 0: + continue + device = torch.device(value.device) if value.device == "cpu" else self.device + buffers[k] = torch.empty(value.size, dtype=value.dtype, device=device) + buffer_pair.append(buffers) + self.dict_recv_buffer[key] = buffer_pair + + metadata_list = self.dict_schema_cache[key] + buffers = self.dict_recv_buffer[key][buf_idx] + comms = self.comms_stream + + tensor_dict: dict[str, Any] = {} + handles: list[torch.distributed.Work] = [] + group = ( + self.device_groups[(self.rank_in_group + 1) % 2] + if self.world_size == 2 + else self.device_group + ) + with self._comms_stream_ctx(): + for k, value in metadata_list: + if isinstance(value, TensorMetadata): + if torch.Size(value.size).numel() == 0: + _update_nested_dict( + tensor_dict, + k, + torch.empty(value.size, dtype=value.dtype, device=self.device), + ) + continue + tensor = buffers[k] + if tensor.is_cuda and comms is not None: + tensor.record_stream(comms) + handles.append( + torch.distributed.irecv(tensor, src=self.prev_rank, group=group) + ) + _update_nested_dict(tensor_dict, k, tensor) + else: + _update_nested_dict(tensor_dict, k, value) + + return tensor_dict, handles, [] + def add_pipeline_recv_task(self, idx: int = -1, name: str = "latent"): name = name or "latent" self.recv_tasks_queue.append((name, idx)) @@ -911,18 +1163,12 @@ def get_pipeline_recv_data(self, idx: int = -1, name: str = "latent") -> torch.T return self.recv_buffer[name][idx] def _pipeline_irecv(self, tensor: torch.tensor): - return torch.distributed.irecv( - tensor, - src=self.prev_rank, - group=(self.device_groups[(self.rank_in_group + 1) % 2] if self.world_size == 2 else self.device_group), - ) + group = self.device_groups[(self.rank_in_group + 1) % 2] if self.world_size == 2 else self.device_group + return torch.distributed.irecv(tensor, src=self.prev_rank, group=group) def _pipeline_isend(self, tensor: torch.tensor): - return torch.distributed.isend( - tensor, - dst=self.next_rank, - group=(self.device_groups[self.rank_in_group % 2] if self.world_size == 2 else self.device_group), - ) + group = self.device_groups[self.rank_in_group % 2] if self.world_size == 2 else self.device_group + return torch.distributed.isend(tensor, dst=self.next_rank, group=group) def set_skip_tensor_recv_buffer( self, diff --git a/vllm_omni/diffusion/distributed/pipeline_parallel.py b/vllm_omni/diffusion/distributed/pipeline_parallel.py index b39d8052e48..9f4407ef7f8 100644 --- a/vllm_omni/diffusion/distributed/pipeline_parallel.py +++ b/vllm_omni/diffusion/distributed/pipeline_parallel.py @@ -55,6 +55,9 @@ def _resolve(self) -> torch.Tensor: def __getattr__(self, name: str): return getattr(self._resolve(), name) + def __getitem__(self, key): + return self._resolve()[key] + # Torch function protocol: any torch op involving an AsyncLatents resolves it first. @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): @@ -184,6 +187,9 @@ def predict_noise_maybe_with_cfg( negative_kwargs: dict[str, Any] | None, cfg_normalize: bool = True, output_slice: int | None = None, + buf_idx: int = 0, + batch_size: int = 1, + preposted_its: list[AsyncIntermediateTensors] | None = None, ) -> torch.Tensor | tuple[torch.Tensor, ...] | None: """ Drop-in replacement for predict_noise_maybe_with_cfg that also handles PP. @@ -217,18 +223,30 @@ def predict_noise_maybe_with_cfg( # Sequential CFG (or no CFG): this PP pipeline handles all branches. all_kwargs = [positive_kwargs] + ([negative_kwargs] if do_true_cfg else []) - # Non-first ranks receive intermediate tensors asynchronously n = len(all_kwargs) its: list[AsyncIntermediateTensors | None] = [None] * n if not pp_group.is_first_rank: - for i in range(n): - its[i] = AsyncIntermediateTensors(*pp_group.irecv_tensor_dict()) + # Use recvs pre-posted by the previous step's scheduler_step. + # Caller owns the lifecycle in state.extra; we just consume here. + if preposted_its is not None and len(preposted_its) == n: + its = list(preposted_its) + else: + for i in range(n): + its[i] = AsyncIntermediateTensors( + *pp_group.pipeline_irecv_tensor_dict( + name="intermediate", segment_idx=i, buf_idx=buf_idx, batch_size=batch_size, + ) + ) if not pp_group.is_last_rank: # First / middle rank: run partial forwards and propagate ITs downstream. - for kwargs, it in zip(all_kwargs, its): + for i, (kwargs, it) in enumerate(zip(all_kwargs, its)): result = self.predict_noise(**kwargs, intermediate_tensors=it) - self._pp_send_work.extend(pp_group.isend_tensor_dict(result.tensors)) + self._pp_send_work.extend( + pp_group.pipeline_isend_tensor_dict( + result.tensors, name="intermediate", segment_idx=i, batch_size=batch_size, + ) + ) return None # Last rank: run full forward @@ -261,31 +279,96 @@ def scheduler_step_maybe_with_cfg( t: torch.Tensor | tuple[torch.Tensor, ...], latents: torch.Tensor | tuple[torch.Tensor, ...], do_true_cfg: bool, - per_request_scheduler: Any | None = None, + per_request_scheduler: Any | list[Any] | None = None, generator: torch.Generator | None = None, - ) -> torch.Tensor | tuple[torch.Tensor, ...] | AsyncLatents: + batch_size: int = 1, + receive_latents: bool = True, + buf_idx: int = 0, + ) -> torch.Tensor | tuple[torch.Tensor, ...]: """ Drop-in replacement for scheduler_step_maybe_with_cfg that also handles PP. Only the last rank runs the scheduler (it already has noise_pred); the result is sent to rank 0 which needs it for the next forward pass. - Returns a ``AsyncLatents`` on rank 0 that transparently defers + If `receive_latents` is True, returns a ``AsyncLatents`` on rank 0 that transparently defers ``handle.wait()`` until the tensor is actually consumed (via attribute access or a torch operation), keeping the rank non-blocking after the ``irecv`` is posted. """ if get_pipeline_parallel_world_size() == 1: - return super().scheduler_step_maybe_with_cfg( - noise_pred, t, latents, do_true_cfg, per_request_scheduler, generator + return self._scheduler_step_local( + noise_pred, t, latents, do_true_cfg, per_request_scheduler, generator, ) pp_group = get_pp_group() if pp_group.is_last_rank: - latents = super().scheduler_step_maybe_with_cfg( - noise_pred, t, latents, do_true_cfg, per_request_scheduler, generator + latents = self._scheduler_step_local( + noise_pred, t, latents, do_true_cfg, per_request_scheduler, generator, + ) + self._pp_send_work = pp_group.pipeline_isend_tensor_dict( + {"latents": latents}, name="latents", batch_size=batch_size, + ) + if pp_group.is_first_rank and receive_latents: + latents = AsyncLatents( + *pp_group.pipeline_irecv_tensor_dict( + name="latents", buf_idx=buf_idx, batch_size=batch_size, + ) ) - self._pp_send_work = pp_group.isend_tensor_dict({"latents": latents}, dst=0) - elif pp_group.is_first_rank: - latents = AsyncLatents(*pp_group.irecv_tensor_dict(src=pp_group.world_size - 1)) return latents + + def _scheduler_step_local( + self, + noise_pred: torch.Tensor, + t: torch.Tensor, + latents: torch.Tensor, + do_true_cfg: bool, + per_request_scheduler: Any | list[Any] | None, + generator: torch.Generator | None = None, + ) -> torch.Tensor: + """Run scheduler.step on this rank — single call or per-chunk loop.""" + if not isinstance(per_request_scheduler, list): + return super().scheduler_step_maybe_with_cfg( + noise_pred, t, latents, do_true_cfg, per_request_scheduler, generator=generator, + ) + new_rows: list[torch.Tensor] = [] + for i, sched in enumerate(per_request_scheduler): + t_i = t[i] if t.ndim > 0 else t + new_rows.append( + super().scheduler_step_maybe_with_cfg( + noise_pred[i:i + 1], t_i, latents[i:i + 1], do_true_cfg, sched, generator=generator, + ) + ) + return torch.cat(new_rows, dim=0) + + def prefetch_tensors_maybe_with_cfg( + self, + do_true_cfg: bool, + buf_idx: int, + batch_size: int = 1, + ) -> list[AsyncIntermediateTensors] | AsyncLatents | None: + """Pre-post the next-step recv on this rank's comms stream. + + First rank pre-posts the latents irecv from the last rank. + Non-first ranks pre-post the intermediate-tensor irecv from the previous rank. + """ + if get_pipeline_parallel_world_size() == 1: + return None + pp_group = get_pp_group() + if pp_group.is_first_rank: + return AsyncLatents( + *pp_group.pipeline_irecv_tensor_dict( + name="latents", buf_idx=buf_idx, batch_size=batch_size, + ) + ) + + cfg_parallel_ready = do_true_cfg and get_classifier_free_guidance_world_size() > 1 + n = 1 if cfg_parallel_ready else (2 if do_true_cfg else 1) + return [ + AsyncIntermediateTensors( + *pp_group.pipeline_irecv_tensor_dict( + name="intermediate", segment_idx=i, buf_idx=buf_idx, batch_size=batch_size, + ) + ) + for i in range(n) + ] \ No newline at end of file diff --git a/vllm_omni/diffusion/executor/abstract.py b/vllm_omni/diffusion/executor/abstract.py index c5abaf59cc7..5d51bbe50a3 100644 --- a/vllm_omni/diffusion/executor/abstract.py +++ b/vllm_omni/diffusion/executor/abstract.py @@ -84,6 +84,11 @@ def execute_step(self, scheduler_output: DiffusionSchedulerOutput) -> BaseRunner """Execute step-mode work from a scheduler output.""" pass + @abstractmethod + def execute_micro_step(self, scheduler_output: DiffusionSchedulerOutput) -> BaseRunnerOutput: + """Execute one temporal-PP micro-step from a scheduler output.""" + pass + @abstractmethod def collective_rpc( self, diff --git a/vllm_omni/diffusion/executor/multiproc_executor.py b/vllm_omni/diffusion/executor/multiproc_executor.py index 6727122d462..043833c87fb 100644 --- a/vllm_omni/diffusion/executor/multiproc_executor.py +++ b/vllm_omni/diffusion/executor/multiproc_executor.py @@ -344,6 +344,36 @@ def execute_step(self, scheduler_output: DiffusionSchedulerOutput) -> BaseRunner else: raise RuntimeError(f"Unexpected response type for execute_step: {type(result)!r}") + def execute_micro_step(self, scheduler_output: DiffusionSchedulerOutput) -> BaseRunnerOutput: + """Forward a temporal-PP micro-step to worker ``execute_micro_step`` RPC. + + Assumes worker rank == PP rank (true for PP-only layouts; revisit when + introducing TP/DP combinations). + """ + from vllm_omni.diffusion.worker.utils import BaseRunnerOutput, RunnerOutput + + self._ensure_open() + + result = self.collective_rpc( + "execute_micro_step", + args=(scheduler_output,), + unique_reply_rank=0, + exec_all_ranks=True, + ) + + if isinstance(result, BaseRunnerOutput): + return result + if isinstance(result, DiffusionOutput): + req_id = scheduler_output.scheduled_req_ids[0] if scheduler_output.scheduled_req_ids else "" + return RunnerOutput( + req_id=req_id, + step_index=None, + finished=True, + result=result, + ) + else: + raise RuntimeError(f"Unexpected response type for execute_step: {type(result)!r}") + def collective_rpc( self, method: str, diff --git a/vllm_omni/diffusion/models/interface.py b/vllm_omni/diffusion/models/interface.py index f0ded9cdb0a..b12a5e8420d 100644 --- a/vllm_omni/diffusion/models/interface.py +++ b/vllm_omni/diffusion/models/interface.py @@ -28,6 +28,11 @@ class SupportAudioInput(Protocol): support_audio_input: ClassVar[bool] = True +@runtime_checkable +class SupportCameraPosInput(Protocol): + support_camera_pos_input: ClassVar[bool] = True + + @runtime_checkable class SupportAudioOutput(Protocol): support_audio_output: ClassVar[bool] = True @@ -87,3 +92,32 @@ def supports_step_execution(pipeline: object) -> bool: """Return whether `pipeline` implements :class:`SupportsStepExecution`.""" return isinstance(pipeline, SupportsStepExecution) + + +@runtime_checkable +class SupportsMicroStepExecution(SupportsStepExecution, Protocol): + """Temporal-PP micro-step execution protocol. + + Extends :class:`SupportsStepExecution` with the per-micro-step hooks + used by ``DiffusionModelRunner.execute_micro_step``: + + - ``set_pp_recv_dict_buffers`` pre-registers PPGC dict channels for + this request to skip the blocking first-call schema exchange. + - ``prefetch_tensors`` pre-posts the next-step recv on the comms stream + so it overlaps with the current micro-step's compute (latents on the + first PP rank, intermediate tensors on the others). + """ + + supports_micro_step_execution: ClassVar[bool] = True + + def set_pp_recv_dict_buffers(self, state: DiffusionRequestState, **kwargs: Any) -> None: + """Pre-register PP dict recv buffers and schema cache for this request.""" + + def prefetch_tensors(self, state: DiffusionRequestState, **kwargs: Any) -> None: + """Pre-post the next-step recv.""" + + +def supports_micro_step_execution(pipeline: object) -> bool: + """Return whether `pipeline` implements :class:`SupportsMicroStepExecution`.""" + + return isinstance(pipeline, SupportsMicroStepExecution) diff --git a/vllm_omni/diffusion/models/lingbot_world_fast/__init__.py b/vllm_omni/diffusion/models/lingbot_world_fast/__init__.py new file mode 100644 index 00000000000..20513d34e60 --- /dev/null +++ b/vllm_omni/diffusion/models/lingbot_world_fast/__init__.py @@ -0,0 +1,4 @@ +from .pipeline_lingbot_world_fast import LingbotWorldFastPipeline, get_lingbot_world_fast_post_process_func +from .wan_fast import WanModelFast + +__all__ = ["LingbotWorldFastPipeline", "get_lingbot_world_fast_post_process_func", "WanModelFast"] diff --git a/vllm_omni/diffusion/models/lingbot_world_fast/cam_utils.py b/vllm_omni/diffusion/models/lingbot_world_fast/cam_utils.py new file mode 100644 index 00000000000..cbc0da6889e --- /dev/null +++ b/vllm_omni/diffusion/models/lingbot_world_fast/cam_utils.py @@ -0,0 +1,153 @@ +# Adapted from Lingbot-World/wan/utils/cam_utils.py +import numpy as np +import torch +from scipy.interpolate import interp1d +from scipy.spatial.transform import Rotation, Slerp + + +def interpolate_camera_poses( + src_indices: np.ndarray, + src_rot_mat: np.ndarray, + src_trans_vec: np.ndarray, + tgt_indices: np.ndarray, +) -> torch.Tensor: + # interpolate translation + interp_func_trans = interp1d( + src_indices, + src_trans_vec, + axis=0, + kind="linear", + bounds_error=False, + fill_value="extrapolate", + ) + interpolated_trans_vec = interp_func_trans(tgt_indices) + + # interpolate rotation + src_quat_vec = Rotation.from_matrix(src_rot_mat) + # ensure there is no sudden change in qw + quats = src_quat_vec.as_quat().copy() # [N, 4] + for i in range(1, len(quats)): + if np.dot(quats[i], quats[i - 1]) < 0: + quats[i] = -quats[i] + src_quat_vec = Rotation.from_quat(quats) + slerp_func_rot = Slerp(src_indices, src_quat_vec) + interpolated_rot_quat = slerp_func_rot(tgt_indices) + interpolated_rot_mat = interpolated_rot_quat.as_matrix() + + poses = np.zeros((len(tgt_indices), 4, 4)) + poses[:, :3, :3] = interpolated_rot_mat + poses[:, :3, 3] = interpolated_trans_vec + poses[:, 3, 3] = 1.0 + return torch.from_numpy(poses).float() + + +def SE3_inverse(t: torch.Tensor) -> torch.Tensor: + Rot = t[:, :3, :3] # [B,3,3] + trans = t[:, :3, 3:] # [B,3,1] + R_inv = Rot.transpose(-1, -2) + t_inv = -torch.bmm(R_inv, trans) + T_inv = torch.eye(4, device=t.device, dtype=t.dtype)[None, :, :].repeat(t.shape[0], 1, 1) + T_inv[:, :3, :3] = R_inv + T_inv[:, :3, 3:] = t_inv + return T_inv + + +def compute_relative_poses( + c2ws_mat: torch.Tensor, + framewise: bool = False, + normalize_trans: bool = True, +) -> torch.Tensor: + ref_w2cs = SE3_inverse(c2ws_mat[0:1]) + relative_poses = torch.matmul(ref_w2cs, c2ws_mat) + # ensure identity matrix for 1st frame + relative_poses[0] = torch.eye(4, device=c2ws_mat.device, dtype=c2ws_mat.dtype) + if framewise: + # compute pose between i and i+1 + relative_poses_framewise = torch.bmm(SE3_inverse(relative_poses[:-1]), relative_poses[1:]) + relative_poses[1:] = relative_poses_framewise + if normalize_trans: + # scale the coordinate inputs to roughly 1 standard deviation to simplify model learning (camctrl2). + translations = relative_poses[:, :3, 3] # [f, 3] + max_norm = torch.norm(translations, dim=-1).max() + # only normalize when moving + if max_norm > 0: + relative_poses[:, :3, 3] = translations / max_norm + return relative_poses + + +@torch.no_grad() +def create_meshgrid( + n_frames: int, height: int, width: int, bias: float = 0.5, device="cuda", dtype=torch.float32 +) -> torch.Tensor: + x_range = torch.arange(width, device=device, dtype=dtype) + y_range = torch.arange(height, device=device, dtype=dtype) + grid_y, grid_x = torch.meshgrid(y_range, x_range, indexing="ij") + grid_xy = torch.stack([grid_x, grid_y], dim=-1).view([-1, 2]) + bias # [h*w, 2] + grid_xy = grid_xy[None, ...].repeat(n_frames, 1, 1) # [f, h*w, 2] + return grid_xy + + +def get_plucker_embeddings( + c2ws_mat: torch.Tensor, + k: torch.Tensor, + height: int, + width: int, + only_rays_d: bool = False, +): + n_frames = c2ws_mat.shape[0] + grid_xy = create_meshgrid(n_frames, height, width, device=c2ws_mat.device, dtype=c2ws_mat.dtype) # [f, h*w, 2] + fx, fy, cx, cy = k.chunk(4, dim=-1) # [f, 1] + + i = grid_xy[..., 0] # [f, h*w] + j = grid_xy[..., 1] # [f, h*w] + zs = torch.ones_like(i) # [f, h*w] + xs = (i - cx) / fx * zs + ys = (j - cy) / fy * zs + + directions = torch.stack([xs, ys, zs], dim=-1) # [f, h*w, 3] + directions = directions / directions.norm(dim=-1, keepdim=True) # [f, h*w, 3] + + rays_d = directions @ c2ws_mat[:, :3, :3].transpose(-1, -2) # [f, h*w, 3] + if only_rays_d: + plucker_embeddings = rays_d # [f, h*w, 3] + plucker_embeddings = plucker_embeddings.view([n_frames, height, width, 3]) # [f*h*w, 3] + else: + rays_o = c2ws_mat[:, :3, 3] # [f, 3] + rays_o = rays_o[:, None, :].expand_as(rays_d) # [f, h*w, 3] + plucker_embeddings = torch.cat([rays_o, rays_d], dim=-1) # [f, h*w, 6] + plucker_embeddings = plucker_embeddings.view([n_frames, height, width, 6]) # [f*h*w, 6] + return plucker_embeddings + + +def get_Ks_transformed( + k: torch.Tensor, + height_org: int, + width_org: int, + height_resize: int, + width_resize: int, + height_final: int, + width_final: int, +): + fx, fy, cx, cy = k.chunk(4, dim=-1) # [f, 1] + + scale_x = width_resize / width_org + scale_y = height_resize / height_org + + fx_resize = fx * scale_x + fy_resize = fy * scale_y + cx_resize = cx * scale_x + cy_resize = cy * scale_y + + crop_offset_x = (width_resize - width_final) / 2 + crop_offset_y = (height_resize - height_final) / 2 + + cx_final = cx_resize - crop_offset_x + cy_final = cy_resize - crop_offset_y + + Ks_transformed = torch.zeros_like(k) + Ks_transformed[:, 0:1] = fx_resize + Ks_transformed[:, 1:2] = fy_resize + Ks_transformed[:, 2:3] = cx_final + Ks_transformed[:, 3:4] = cy_final + + return Ks_transformed diff --git a/vllm_omni/diffusion/models/lingbot_world_fast/flow_scheduler.py b/vllm_omni/diffusion/models/lingbot_world_fast/flow_scheduler.py new file mode 100644 index 00000000000..9f30b20876b --- /dev/null +++ b/vllm_omni/diffusion/models/lingbot_world_fast/flow_scheduler.py @@ -0,0 +1,67 @@ +from __future__ import annotations + +import torch + +from .fm_solvers_unipc import FlowUniPCMultistepScheduler + + +class LingbotFlowScheduler: + def __init__( + self, + inner: FlowUniPCMultistepScheduler, + timesteps5: torch.Tensor, + ) -> None: + self._inner = inner + # Length-5 schedule: [t0, t1, t2, t3, 0]. + self.timesteps = timesteps5 + # Used by `_convert_flow_pred_to_x0` to look up sigma_t. + self.sigmas = inner.sigmas + self._full_timesteps = inner.timesteps + + def step( + self, + noise_pred: torch.Tensor, + t: torch.Tensor, + latents: torch.Tensor, + return_dict: bool = False, + generator: torch.Generator | None = None, + ) -> tuple[torch.Tensor]: + # `t` is a per-row scalar (`_scheduler_step_local` loops per row). + if float(t.item()) == 0.0: + return (latents,) + + x0 = self._convert_flow_pred_to_x0(noise_pred, latents, t) + + ts_eq = (self.timesteps == t).nonzero(as_tuple=False) + chunk_step = int(ts_eq[0].item()) if ts_eq.numel() > 0 else 0 + + if chunk_step + 1 < self.timesteps.shape[0] - 1: + next_t = self.timesteps[chunk_step + 1] + noise = torch.randn( + x0.shape, generator=generator, device=x0.device, dtype=x0.dtype + ) + return (self._inner.add_noise(x0, noise, next_t),) + return (x0,) + + def add_noise( + self, + original_samples: torch.Tensor, + noise: torch.Tensor, + timesteps: torch.Tensor, + ) -> torch.Tensor: + return self._inner.add_noise(original_samples, noise, timesteps) + + def _convert_flow_pred_to_x0( + self, + flow_pred: torch.Tensor, + xt: torch.Tensor, + timestep: torch.Tensor, + ) -> torch.Tensor: + original_dtype = flow_pred.dtype + flow_pred, xt, sigmas, timesteps = map( + lambda x: x.double().to(flow_pred.device), + [flow_pred, xt, self.sigmas, self._full_timesteps], + ) + timestep_id = torch.argmin((timesteps - timestep).abs()) + sigma_t = sigmas[timestep_id].reshape(-1, 1, 1, 1) + return (xt - sigma_t * flow_pred).to(original_dtype) \ No newline at end of file diff --git a/vllm_omni/diffusion/models/lingbot_world_fast/fm_solvers_unipc.py b/vllm_omni/diffusion/models/lingbot_world_fast/fm_solvers_unipc.py new file mode 100644 index 00000000000..d9a18899da7 --- /dev/null +++ b/vllm_omni/diffusion/models/lingbot_world_fast/fm_solvers_unipc.py @@ -0,0 +1,576 @@ +# Adapted from Lingbot-World/wan/utils/fm_solvers_unipc.py +# Originally derived from https://github.com/huggingface/diffusers/blob/v0.31.0/src/diffusers/schedulers/scheduling_unipc_multistep.py +# Converted to flow matching. + + +import math + +import numpy as np +import torch +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.schedulers.scheduling_utils import ( + KarrasDiffusionSchedulers, + SchedulerMixin, + SchedulerOutput, +) +from diffusers.utils import deprecate, is_scipy_available + +if is_scipy_available(): + import scipy.stats # noqa: F401 + + +class FlowUniPCMultistepScheduler(SchedulerMixin, ConfigMixin): + """ + `UniPCMultistepScheduler` is a training-free framework designed for the fast sampling of diffusion models. + """ + + _compatibles = [e.name for e in KarrasDiffusionSchedulers] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + solver_order: int = 2, + prediction_type: str = "flow_prediction", + shift: float | None = 1.0, + use_dynamic_shifting=False, + thresholding: bool = False, + dynamic_thresholding_ratio: float = 0.995, + sample_max_value: float = 1.0, + predict_x0: bool = True, + solver_type: str = "bh2", + lower_order_final: bool = True, + disable_corrector: list[int] = [], + solver_p: SchedulerMixin = None, + timestep_spacing: str = "linspace", + steps_offset: int = 0, + final_sigmas_type: str | None = "zero", # "zero", "sigma_min" + ): + if solver_type not in ["bh1", "bh2"]: + if solver_type in ["midpoint", "heun", "logrho"]: + self.register_to_config(solver_type="bh2") + else: + raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}") + + self.predict_x0 = predict_x0 + # set table values + self.num_inference_steps = None + alphas = np.linspace(1, 1 / num_train_timesteps, num_train_timesteps)[::-1].copy() + sigmas = 1.0 - alphas + sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32) + + if not use_dynamic_shifting: + # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution + sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) + + self.sigmas = sigmas + self.timesteps = sigmas * num_train_timesteps + + self.model_outputs = [None] * solver_order + self.timestep_list = [None] * solver_order + self.lower_order_nums = 0 + self.disable_corrector = disable_corrector + self.solver_p = solver_p + self.last_sample = None + self._step_index = None + self._begin_index = None + + self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication + self.sigma_min = self.sigmas[-1].item() + self.sigma_max = self.sigmas[0].item() + + @property + def step_index(self): + return self._step_index + + @property + def begin_index(self): + return self._begin_index + + def set_begin_index(self, begin_index: int = 0): + self._begin_index = begin_index + + def set_timesteps( + self, + num_inference_steps: int | None = None, + device: str | torch.device = None, + sigmas: list[float] | None = None, + mu: float | None = None, + shift: float | None = None, + ): + if self.config.use_dynamic_shifting and mu is None: + raise ValueError(" you have to pass a value for `mu` when `use_dynamic_shifting` is set to be `True`") + + if sigmas is None: + sigmas = np.linspace(self.sigma_max, self.sigma_min, num_inference_steps + 1).copy()[:-1] + + if self.config.use_dynamic_shifting: + sigmas = self.time_shift(mu, 1.0, sigmas) + else: + if shift is None: + shift = self.config.shift + sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) + + if self.config.final_sigmas_type == "sigma_min": + sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5 + elif self.config.final_sigmas_type == "zero": + sigma_last = 0 + else: + raise ValueError( + f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}" + ) + + timesteps = sigmas * self.config.num_train_timesteps + sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32) + + self.sigmas = torch.from_numpy(sigmas) + self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64) + + self.num_inference_steps = len(timesteps) + + self.model_outputs = [None] * self.config.solver_order + self.lower_order_nums = 0 + self.last_sample = None + if self.solver_p: + self.solver_p.set_timesteps(self.num_inference_steps, device=device) + + # add an index counter for schedulers that allow duplicated timesteps + self._step_index = None + self._begin_index = None + self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication + + def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor: + dtype = sample.dtype + batch_size, channels, *remaining_dims = sample.shape + + if dtype not in (torch.float32, torch.float64): + sample = sample.float() + + sample = sample.reshape(batch_size, channels * np.prod(remaining_dims)) + + abs_sample = sample.abs() + + s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1) + s = torch.clamp(s, min=1, max=self.config.sample_max_value) + s = s.unsqueeze(1) + sample = torch.clamp(sample, -s, s) / s + + sample = sample.reshape(batch_size, channels, *remaining_dims) + sample = sample.to(dtype) + + return sample + + def _sigma_to_t(self, sigma): + return sigma * self.config.num_train_timesteps + + def _sigma_to_alpha_sigma_t(self, sigma): + return 1 - sigma, sigma + + def time_shift(self, mu: float, sigma: float, t: torch.Tensor): + return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) + + def convert_model_output( + self, + model_output: torch.Tensor, + *args, + sample: torch.Tensor = None, + **kwargs, + ) -> torch.Tensor: + timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) + if sample is None: + if len(args) > 1: + sample = args[1] + else: + raise ValueError("missing `sample` as a required keyword argument") + if timestep is not None: + deprecate( + "timesteps", + "1.0.0", + "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled " + "via an internal counter `self.step_index`", + ) + + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + + if self.predict_x0: + if self.config.prediction_type == "flow_prediction": + sigma_t = self.sigmas[self.step_index] + x0_pred = sample - sigma_t * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`," + " `v_prediction` or `flow_prediction` for the UniPCMultistepScheduler." + ) + + if self.config.thresholding: + x0_pred = self._threshold_sample(x0_pred) + + return x0_pred + else: + if self.config.prediction_type == "flow_prediction": + sigma_t = self.sigmas[self.step_index] + epsilon = sample - (1 - sigma_t) * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`," + " `v_prediction` or `flow_prediction` for the UniPCMultistepScheduler." + ) + + if self.config.thresholding: + sigma_t = self.sigmas[self.step_index] + x0_pred = sample - sigma_t * model_output + x0_pred = self._threshold_sample(x0_pred) + epsilon = model_output + x0_pred + + return epsilon + + def multistep_uni_p_bh_update( + self, + model_output: torch.Tensor, + *args, + sample: torch.Tensor = None, + order: int = None, + **kwargs, + ) -> torch.Tensor: + prev_timestep = args[0] if len(args) > 0 else kwargs.pop("prev_timestep", None) + if sample is None: + if len(args) > 1: + sample = args[1] + else: + raise ValueError(" missing `sample` as a required keyword argument") + if order is None: + if len(args) > 2: + order = args[2] + else: + raise ValueError(" missing `order` as a required keyword argument") + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled " + "via an internal counter `self.step_index`", + ) + model_output_list = self.model_outputs + + s0 = self.timestep_list[-1] + m0 = model_output_list[-1] + x = sample + + if self.solver_p: + x_t = self.solver_p.step(model_output, s0, x).prev_sample + return x_t + + sigma_t, sigma_s0 = self.sigmas[self.step_index + 1], self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + + h = lambda_t - lambda_s0 + device = sample.device + + rks = [] + D1s = [] + for i in range(1, order): + si = self.step_index - i + mi = model_output_list[-(i + 1)] + alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si]) + lambda_si = torch.log(alpha_si) - torch.log(sigma_si) + rk = (lambda_si - lambda_s0) / h + rks.append(rk) + D1s.append((mi - m0) / rk) + + rks.append(1.0) + rks = torch.tensor(rks, device=device) + + R = [] + b = [] + + hh = -h if self.predict_x0 else h + h_phi_1 = torch.expm1(hh) + h_phi_k = h_phi_1 / hh - 1 + + factorial_i = 1 + + if self.config.solver_type == "bh1": + B_h = hh + elif self.config.solver_type == "bh2": + B_h = torch.expm1(hh) + else: + raise NotImplementedError() + + for i in range(1, order + 1): + R.append(torch.pow(rks, i - 1)) + b.append(h_phi_k * factorial_i / B_h) + factorial_i *= i + 1 + h_phi_k = h_phi_k / hh - 1 / factorial_i + + R = torch.stack(R) + b = torch.tensor(b, device=device) + + if len(D1s) > 0: + D1s = torch.stack(D1s, dim=1) # (B, K) + if order == 2: + rhos_p = torch.tensor([0.5], dtype=x.dtype, device=device) + else: + rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1]).to(device).to(x.dtype) + else: + D1s = None + + if self.predict_x0: + x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0 + if D1s is not None: + pred_res = torch.einsum("k,bkc...->bc...", rhos_p, D1s) + else: + pred_res = 0 + x_t = x_t_ - alpha_t * B_h * pred_res + else: + x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0 + if D1s is not None: + pred_res = torch.einsum("k,bkc...->bc...", rhos_p, D1s) + else: + pred_res = 0 + x_t = x_t_ - sigma_t * B_h * pred_res + + x_t = x_t.to(x.dtype) + return x_t + + def multistep_uni_c_bh_update( + self, + this_model_output: torch.Tensor, + *args, + last_sample: torch.Tensor = None, + this_sample: torch.Tensor = None, + order: int = None, + **kwargs, + ) -> torch.Tensor: + this_timestep = args[0] if len(args) > 0 else kwargs.pop("this_timestep", None) + if last_sample is None: + if len(args) > 1: + last_sample = args[1] + else: + raise ValueError(" missing`last_sample` as a required keyword argument") + if this_sample is None: + if len(args) > 2: + this_sample = args[2] + else: + raise ValueError(" missing`this_sample` as a required keyword argument") + if order is None: + if len(args) > 3: + order = args[3] + else: + raise ValueError(" missing`order` as a required keyword argument") + if this_timestep is not None: + deprecate( + "this_timestep", + "1.0.0", + "Passing `this_timestep` is deprecated and has no effect as model output conversion is now handled " + "via an internal counter `self.step_index`", + ) + + model_output_list = self.model_outputs + + m0 = model_output_list[-1] + x = last_sample + x_t = this_sample + model_t = this_model_output + + sigma_t, sigma_s0 = self.sigmas[self.step_index], self.sigmas[self.step_index - 1] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + + h = lambda_t - lambda_s0 + device = this_sample.device + + rks = [] + D1s = [] + for i in range(1, order): + si = self.step_index - (i + 1) + mi = model_output_list[-(i + 1)] + alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si]) + lambda_si = torch.log(alpha_si) - torch.log(sigma_si) + rk = (lambda_si - lambda_s0) / h + rks.append(rk) + D1s.append((mi - m0) / rk) + + rks.append(1.0) + rks = torch.tensor(rks, device=device) + + R = [] + b = [] + + hh = -h if self.predict_x0 else h + h_phi_1 = torch.expm1(hh) + h_phi_k = h_phi_1 / hh - 1 + + factorial_i = 1 + + if self.config.solver_type == "bh1": + B_h = hh + elif self.config.solver_type == "bh2": + B_h = torch.expm1(hh) + else: + raise NotImplementedError() + + for i in range(1, order + 1): + R.append(torch.pow(rks, i - 1)) + b.append(h_phi_k * factorial_i / B_h) + factorial_i *= i + 1 + h_phi_k = h_phi_k / hh - 1 / factorial_i + + R = torch.stack(R) + b = torch.tensor(b, device=device) + + if len(D1s) > 0: + D1s = torch.stack(D1s, dim=1) + else: + D1s = None + + if order == 1: + rhos_c = torch.tensor([0.5], dtype=x.dtype, device=device) + else: + rhos_c = torch.linalg.solve(R, b).to(device).to(x.dtype) + + if self.predict_x0: + x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0 + if D1s is not None: + corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s) + else: + corr_res = 0 + D1_t = model_t - m0 + x_t = x_t_ - alpha_t * B_h * (corr_res + rhos_c[-1] * D1_t) + else: + x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0 + if D1s is not None: + corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s) + else: + corr_res = 0 + D1_t = model_t - m0 + x_t = x_t_ - sigma_t * B_h * (corr_res + rhos_c[-1] * D1_t) + x_t = x_t.to(x.dtype) + return x_t + + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps + + indices = (schedule_timesteps == timestep).nonzero() + + pos = 1 if len(indices) > 1 else 0 + + return indices[pos].item() + + def _init_step_index(self, timestep): + if self.begin_index is None: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + self._step_index = self.index_for_timestep(timestep) + else: + self._step_index = self._begin_index + + def step( + self, + model_output: torch.Tensor, + timestep: int | torch.Tensor, + sample: torch.Tensor, + return_dict: bool = True, + generator=None, + ) -> SchedulerOutput | tuple: + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + if self.step_index is None: + self._init_step_index(timestep) + + use_corrector = ( + self.step_index > 0 and self.step_index - 1 not in self.disable_corrector and self.last_sample is not None + ) + + model_output_convert = self.convert_model_output(model_output, sample=sample) + if use_corrector: + sample = self.multistep_uni_c_bh_update( + this_model_output=model_output_convert, + last_sample=self.last_sample, + this_sample=sample, + order=self.this_order, + ) + + for i in range(self.config.solver_order - 1): + self.model_outputs[i] = self.model_outputs[i + 1] + self.timestep_list[i] = self.timestep_list[i + 1] + + self.model_outputs[-1] = model_output_convert + self.timestep_list[-1] = timestep + + if self.config.lower_order_final: + this_order = min(self.config.solver_order, len(self.timesteps) - self.step_index) + else: + this_order = self.config.solver_order + + self.this_order = min(this_order, self.lower_order_nums + 1) + assert self.this_order > 0 + + self.last_sample = sample + prev_sample = self.multistep_uni_p_bh_update( + model_output=model_output, + sample=sample, + order=self.this_order, + ) + + if self.lower_order_nums < self.config.solver_order: + self.lower_order_nums += 1 + + self._step_index += 1 + + if not return_dict: + return (prev_sample,) + + return SchedulerOutput(prev_sample=prev_sample) + + def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor: + return sample + + def add_noise( + self, + original_samples: torch.Tensor, + noise: torch.Tensor, + timesteps: torch.IntTensor, + ) -> torch.Tensor: + sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) + if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): + schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) + timesteps = timesteps.to(original_samples.device, dtype=torch.float32) + else: + schedule_timesteps = self.timesteps.to(original_samples.device) + if isinstance(timesteps, list): + timesteps = [timestep.to(original_samples.device) for timestep in timesteps] + else: + timesteps = timesteps.to(original_samples.device) + + if self.begin_index is None: + if not isinstance(timesteps, list): + timesteps = [timesteps] + step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps] + elif self.step_index is not None: + step_indices = [self.step_index] * timesteps.shape[0] + else: + step_indices = [self.begin_index] * timesteps.shape[0] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < len(original_samples.shape): + sigma = sigma.unsqueeze(-1) + + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + noisy_samples = alpha_t * original_samples + sigma_t * noise + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps diff --git a/vllm_omni/diffusion/models/lingbot_world_fast/pipeline_lingbot_world_fast.py b/vllm_omni/diffusion/models/lingbot_world_fast/pipeline_lingbot_world_fast.py new file mode 100644 index 00000000000..2637276dd3f --- /dev/null +++ b/vllm_omni/diffusion/models/lingbot_world_fast/pipeline_lingbot_world_fast.py @@ -0,0 +1,972 @@ +import copy +import logging +import math +import os +import random +import sys +from contextlib import contextmanager +from typing import Any, ClassVar + +import numpy as np +import torch +import torch.distributed as dist +import torchvision.transforms.functional as TF +from einops import rearrange +from torch import nn +from tqdm import tqdm +from vllm.sequence import IntermediateTensors + +from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig +from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin +from vllm_omni.diffusion.distributed.parallel_state import ( + get_pipeline_parallel_world_size, + get_pp_group, + is_pipeline_first_stage, + is_pipeline_last_stage, +) +from vllm_omni.diffusion.distributed.pipeline_parallel import ( + AsyncLatents, + PipelineParallelMixin, +) +from vllm_omni.diffusion.distributed.utils import get_local_device +from vllm_omni.diffusion.models.interface import SupportCameraPosInput, SupportImageInput +from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.diffusion.worker.utils import DiffusionRequestState + +from .cam_utils import ( + compute_relative_poses, + get_Ks_transformed, + get_plucker_embeddings, + interpolate_camera_poses, +) +from .flow_scheduler import LingbotFlowScheduler +from .fm_solvers_unipc import FlowUniPCMultistepScheduler +from .state_lingbot_world_fast import LingbotWorldFastState +from .stream_vae import StreamVAE +from .t5 import T5EncoderModel +from .vae2_1 import Wan2_1_VAE +from .wan_fast import WanModelFast + +logger = logging.getLogger(__name__) + + +CONFIG = { + "text_len": 512, + "num_train_timesteps": 1000, + "vae_stride": (4, 8, 8), + "patch_size": (1, 2, 2), + "timesteps_index": [0, 179, 358, 679], + "sample_shift": 10.0, + "max_area": 480 * 832, + "max_sequence_length": 512, + "chunk_size": 3, + "t5_checkpoint": "models_t5_umt5-xxl-enc-bf16.pth", + "t5_tokenizer": "google/umt5-xxl", + "vae_checkpoint": "Wan2.1_VAE.pth", + "fast_noise_checkpoint": "Lingbot-World-Fast", + "negative_prompt_sample": ( + "画面突变,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止," + "整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部," + "画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景," + "三条腿,背景人很多,倒着走,镜头晃动,画面闪烁,模糊,噪点,水印,签名,文字,变形," + "扭曲,液化,不合逻辑的结构,卡顿,PPT幻灯片感,过暗,欠曝,低对比度,霓虹灯光感," + "过度锐化,3D渲染感,人物,行人,游客,身体,皮肤,肢体,面部特征,汽车,电线" + ), +} + + +def get_lingbot_world_fast_post_process_func( + od_config: OmniDiffusionConfig, +): + def post_process_func( + video: torch.Tensor, + ): + outputs = video.permute(1, 2, 3, 0) + return outputs + + return post_process_func + + +class LingbotWorldFastPipeline( + nn.Module, + SupportImageInput, + SupportCameraPosInput, + PipelineParallelMixin, + CFGParallelMixin, +): + supports_step_execution: ClassVar[bool] = True + supports_micro_step_execution: ClassVar[bool] = True + + STREAM_BATCH_CHUNK_FRAMES: ClassVar[int] = 12 + STREAM_BATCH_NUM_INFERENCE_STEPS: ClassVar[int] = 5 + + def __init__(self, *, od_config: OmniDiffusionConfig): + super().__init__() + self.od_config = od_config + self.parallel_config = od_config.parallel_config + + self.device = get_local_device() + + self.target_dtype = od_config.dtype + + self.control_type = "cam" + self.num_train_timesteps = CONFIG["num_train_timesteps"] + + self.sp_size = od_config.parallel_config.sequence_parallel_size + + self.state = LingbotWorldFastState() + + checkpoint_path = os.path.dirname(self.od_config.model) + assert checkpoint_path is not None, "lingbot_dir is None" + + self.text_encoder = T5EncoderModel( + text_len=CONFIG["text_len"], + dtype=self.target_dtype, + device=self.device, + checkpoint_path=os.path.join(checkpoint_path, CONFIG["t5_checkpoint"]), + tokenizer_path=os.path.join(checkpoint_path, CONFIG["t5_tokenizer"]), + ) + + self.vae_stride = CONFIG["vae_stride"] + self.patch_size = CONFIG["patch_size"] + base_vae = Wan2_1_VAE( + vae_pth=os.path.join(checkpoint_path, CONFIG["vae_checkpoint"]), device=self.device + ) + self.vae = StreamVAE(base_vae) if od_config.stream_batch else base_vae + + logger.info(f"Creating WanModelFast from {checkpoint_path}") + self.model = WanModelFast.from_pretrained( + checkpoint_path, + subfolder=CONFIG["fast_noise_checkpoint"], + torch_dtype=torch.bfloat16, + control_type=self.control_type, + ).to(self.device) + # Partition transformer across PP ranks (no-op at PP=1). + self.model.apply_pp_split() + + self.scheduler = FlowUniPCMultistepScheduler( + num_train_timesteps=self.num_train_timesteps, shift=1, use_dynamic_shifting=False + ) + + self.sample_neg_prompt = CONFIG["negative_prompt_sample"] + + def _configure_model(self, model): + """ + Configures a model object. This includes setting evaluation modes, + applying distributed parallel strategy, and handling device placement. + + Args: + model (torch.nn.Module): + The model instance to configure. + + Returns: + torch.nn.Module: + The configured model. + """ + model.eval().requires_grad_(False) + + def _convert_flow_pred_to_x0( + self, flow_pred: torch.Tensor, xt: torch.Tensor, timestep: torch.Tensor, scheduler + ) -> torch.Tensor: + """ + Convert flow matching's prediction to x0 prediction. + flow_pred: the prediction with shape [B, C, F, H, W] + xt: the input noisy data with shape [B, C, F, H, W] + timestep: the timestep with shape [B] + + pred = noise - x0 + x_t = (1-sigma_t) * x0 + sigma_t * noise + we have x0 = x_t - sigma_t * pred + """ + # use higher precision for calculations + original_dtype = flow_pred.dtype + flow_pred, xt, sigmas, timesteps = map( + lambda x: x.double().to(flow_pred.device), [flow_pred, xt, scheduler.sigmas, scheduler.timesteps] + ) + timestep_id = torch.argmin((timesteps - timestep).abs()) + sigma_t = sigmas[timestep_id].reshape(-1, 1, 1, 1) + x0_pred = xt - sigma_t * flow_pred + + return x0_pred.to(original_dtype) + + def forward( + self, + req: OmniDiffusionRequest, + ) -> DiffusionOutput: + if len(req.prompts) > 1: + raise ValueError( + """This model only supports a single prompt, not a batched request.""", + """Please pass in a single prompt object or string, or a single-item list.""", + ) + prompt = req.prompts[0].get("prompt") + multi_modal_data = req.prompts[0].get("multi_modal_data", {}) + + session_id = str(req.sampling_params.extra_args.get("session_id") or None) + + force_reset = req.sampling_params.extra_args.get("force_reset") or False + + extension = True + + if force_reset or self.state.session_id is None or self.state.session_id != session_id: + self.state.reset() + self.state.session_id = session_id + extension = False + else: + extension = True + + camera = multi_modal_data.get("camera", None) + if camera is None: + self.od_config.model + raise ValueError("A path to camera positions must be passed to this model through action_path.") + + if extension: + assert multi_modal_data.get("image") is None, ( + "image must not be provided on extension calls; it is only used on the first call of a session" + ) + assert self.model.config.local_attn_size == -1, ( + "video extension requires the model to be configured with local_attn_size == -1" + ) + + batch_size = 1 + num_frames = req.sampling_params.num_frames + # In order to generate something num_frames must be at least 5 since it expects 4*n + 1 as input + # 25 is the smallest length supported by the model. Smaller values generate tensors with dimension zero/negative + num_frames = max(25, num_frames) + + c2ws = camera.get("poses") + chunk_size = CONFIG["chunk_size"] + max_area = CONFIG["max_area"] + + # Fresh: 4N+1 pixel frames → N+1 latents, the first slot is the anchor. + # Extension: 4N pixel frames → N regular latents, no anchor. + if extension: + len_c2ws = (len(c2ws) // 4) * 4 + num_frames = (num_frames // 4) * 4 + num_frames = min(num_frames, len_c2ws) + new_lat_f = num_frames // 4 + else: + len_c2ws = ((len(c2ws) - 1) // 4) * 4 + 1 + num_frames = ((num_frames - 1) // 4) * 4 + 1 + num_frames = min(num_frames, len_c2ws) + new_lat_f = (num_frames - 1) // 4 + 1 + c2ws = c2ws[:num_frames] + + # 1. Derive spatial shape: from the input image on fresh start, from cache on extension. + if not extension: + img = multi_modal_data.get("image") + img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(self.device) + h, w = img.shape[1:] + aspect_ratio = h / w + lat_h = round( + np.sqrt(max_area * aspect_ratio) // self.vae_stride[1] // self.patch_size[1] * self.patch_size[1] + ) + lat_w = round( + np.sqrt(max_area / aspect_ratio) // self.vae_stride[2] // self.patch_size[2] * self.patch_size[2] + ) + h = lat_h * self.vae_stride[1] + w = lat_w * self.vae_stride[2] + else: + img = None + h, w, lat_h, lat_w = self.state.h, self.state.w, self.state.lat_h, self.state.lat_w + + new_lat_f = int(new_lat_f - (new_lat_f % chunk_size)) + new_lat_f = max(new_lat_f, 1) + max_seq_len = chunk_size * lat_h * lat_w // (self.patch_size[1] * self.patch_size[2]) + max_seq_len = int(math.ceil(max_seq_len / self.sp_size)) * self.sp_size + seed_g = req.sampling_params.generator + if seed_g is None: + seed = req.sampling_params.seed + if seed is None: + seed = random.randint(0, sys.maxsize) + seed_g = torch.Generator(device=self.device) + seed_g.manual_seed(seed) + noise = torch.randn(16, new_lat_f, lat_h, lat_w, dtype=torch.float32, generator=seed_g, device=self.device) + + # Fresh: msk[0] = 1 (anchor) and the rest = 0, replicated into 4 channels grouped + # by latent frame to give shape [4, new_lat_f, lat_h, lat_w]. + # Extension: no anchor, all zeros, already in the [4, new_lat_f, ...] layout. + if not extension: + F = (new_lat_f - 1) * 4 + 1 + msk = torch.zeros(1, F, lat_h, lat_w, device=self.device) + msk[:, 0] = 1 + msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1) + msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w) + msk = msk.transpose(1, 2)[0] + else: + msk = torch.zeros(4, new_lat_f, lat_h, lat_w, device=self.device) + + # 2. Prepare timesteps + self.scheduler.set_timesteps(self.num_train_timesteps, shift=CONFIG["sample_shift"]) + timesteps = self.scheduler.timesteps[CONFIG["timesteps_index"]] + + context = self.text_encoder([prompt], self.device) + + dit_cond_dict = None + Ks = torch.from_numpy(camera.get("intrinsics")) + + # Transform the provided intrinsics from the original 480p according to the new image size (h, w). + Ks = get_Ks_transformed( + Ks, height_org=480, width_org=832, height_resize=h, width_resize=w, height_final=h, width_final=w + ) + Ks = Ks[0] + + # One target pose per output latent — must match the f= in the rearrange below. + len_c2ws = len(c2ws) + len_c2ws_ = new_lat_f + c2ws_infer = interpolate_camera_poses( + src_indices=np.linspace(0, len_c2ws - 1, len_c2ws), + src_rot_mat=c2ws[:, :3, :3], + src_trans_vec=c2ws[:, :3, 3], + tgt_indices=np.linspace(0, len_c2ws - 1, len_c2ws_), + ) + c2ws_infer = compute_relative_poses(c2ws_infer, framewise=True) + Ks = Ks.repeat(len(c2ws_infer), 1) + + c2ws_infer = c2ws_infer.to(self.device).to(torch.float32) + Ks = Ks.to(self.device).to(torch.float32) + only_rays_d = False + c2ws_plucker_emb = get_plucker_embeddings(c2ws_infer, Ks, h, w, only_rays_d=only_rays_d) + c2ws_plucker_emb = rearrange( + c2ws_plucker_emb, + "f (h c1) (w c2) c -> (f h w) (c c1 c2)", + c1=int(h // lat_h), + c2=int(w // lat_w), + ) + c2ws_plucker_emb = c2ws_plucker_emb[None, ...] # [b, f*h*w, c] + c2ws_plucker_emb = rearrange(c2ws_plucker_emb, "b (f h w) c -> b c f h w", f=new_lat_f, h=lat_h, w=lat_w).to( + self.target_dtype + ) + + # Fresh: pixels = [anchor_image, zeros...] of shape [3, 4N+1, h, w]. + # VAE produces N+1 latents; latent[0] is the anchor encoding. + # Extension: pixels = zeros [3, 4N+1, h, w]. VAE produces N+1 latents, + # of which latent[0] is the special "1-frame init" encoding + # (biased differently than the regular 4-frame-group latents). + # Slice it off so the N conditioning slots are all regular — + # this drops a CONDITIONING slot, not an output latent. + if not extension: + F = (new_lat_f - 1) * 4 + 1 + pixels = torch.concat( + [ + torch.nn.functional.interpolate(img[None].cpu(), size=(h, w), mode="bicubic").transpose(0, 1), + torch.zeros(3, F - 1, h, w), + ], + dim=1, + ).to(self.device) + y = self.vae.encode([pixels])[0] + else: + pixels = torch.zeros(3, 4 * new_lat_f + 1, h, w, device=self.device) + y = self.vae.encode([pixels])[0][:, 1:] + y = torch.concat([msk, y]) + + @contextmanager + def noop_no_sync(): + yield + + no_sync_model = getattr(self.model, "no_sync", noop_no_sync) + + # Initialize (fresh) or grow (extension) the KV cache. Cross-attn cache is + # left untouched on extension so text-context k/v computed on the first call + # are reused via crossattn_cache[i]["is_init"] == True. + model_args = self.model.config + transformer_dtype = self.target_dtype + frame_seqlen = int(noise.shape[-2] * noise.shape[-1] // 4) + extra_kv_size = frame_seqlen * new_lat_f + head_dim = model_args.dim // model_args.num_heads + local_num_heads = model_args.num_heads // self.sp_size + + if not extension: + self.state.create_kv_caches( + batch_size, + transformer_dtype, + self.device, + extra_kv_size, + model_args.num_layers, + local_num_heads, + head_dim, + ) + else: + self.state.extend_kv_caches(extra_kv_size) + + # Total cache size after this call, used both as the per-query attention + # window and as the absolute-token offset base for the chunk loop. + prev_lat_f = self.state.current_lat_f + total_kv_size = frame_seqlen * (prev_lat_f + new_lat_f) + start_token_offset = prev_lat_f * frame_seqlen + + # evaluation mode + with ( + torch.amp.autocast("cuda", dtype=self.target_dtype), + torch.no_grad(), + no_sync_model(), + ): + # sample videos + latent = noise + latents_chunk = latent.split(chunk_size, dim=1) # [c, f, h, w] + condition_chunk = y.split(chunk_size, dim=1) + c2ws_plucker_emb_chunk = c2ws_plucker_emb.split(chunk_size, dim=2) + num_inference_chunk = len(latents_chunk) + pred_latent_chunks = [] + for chunk_id in tqdm(range(num_inference_chunk)): + current_latent = latents_chunk[chunk_id] + current_condition = condition_chunk[chunk_id] + current_c2ws_plucker_emb = c2ws_plucker_emb_chunk[chunk_id] + + dit_cond_dict = { + "c2ws_plucker_emb": current_c2ws_plucker_emb.chunk(1, dim=0), + } + + kwargs = { + "context": [context[0]], + "seq_len": max_seq_len, + "y": [current_condition], + "dit_cond_dict": dit_cond_dict, + "kv_cache": self.state.get_kv_caches(), + "local_end_index": self.state.local_end_index, + "global_end_index": self.state.global_end_index, + "crossattn_cache": self.state.get_crossattn_caches(), + "current_start": start_token_offset + chunk_id * chunk_size * frame_seqlen, + "max_attention_size": total_kv_size, + } + + for timestep_idx in range(len(timesteps)): + latent_model_input = [current_latent.to(self.device)] + current_timestep = [timesteps[timestep_idx]] + + timestep = torch.stack(current_timestep).to(self.device) + + noise_pred = self.model(x=latent_model_input, t=timestep, **kwargs)[0] + + x0 = self._convert_flow_pred_to_x0( + flow_pred=noise_pred, + xt=current_latent, + timestep=current_timestep[0], + scheduler=self.scheduler, + ) + + if timestep_idx < len(timesteps) - 1: + next_timestep = timesteps[timestep_idx + 1] + current_latent = self.scheduler.add_noise( + x0, torch.randn(x0.shape, generator=seed_g, device=x0.device, dtype=x0.dtype), next_timestep + ) + else: + # note return x0 + break + + pred_latent_chunks.append(x0) + + # Update kv cache + context_timestep = [timesteps[-1] * 0.0] + timestep = torch.stack(context_timestep).to(self.device) + self.model(x=[x0], t=timestep, **kwargs) + + pred_latent_chunks = torch.cat(pred_latent_chunks, dim=1) + + if self.device.index == 0: + # Wan VAE decode() calls clear_cache() internally, so the very + # first latent always runs the i==0 path (no temporal upsample, + # single-frame output) and leaves feat_map polluted with that + # bias. The decoder's stacked temporal-causal layers also need + # ~2 latents of streaming context before deeper feat_map slots + # match a true mid-stream decode. On extension, prepend the + # prior chunk's last 2 latents so warmup_0 absorbs the i==0 + # bias and warmup_1 fully primes the cache. Then discard the + # 4*K - 3 leading pixels (re-decodes of already-shown frames). + if extension and self.state.last_decoded_latent is not None: + warmup = self.state.last_decoded_latent.to(pred_latent_chunks.device, pred_latent_chunks.dtype) + k = warmup.shape[1] + drop = 4 * k - 3 + to_decode = torch.cat([warmup, pred_latent_chunks], dim=1) + videos = self.vae.decode([to_decode]) + videos = [v[:, drop:] for v in videos] + else: + videos = self.vae.decode([pred_latent_chunks]) + + self.state.last_decoded_latent = pred_latent_chunks[:, -2:].detach().clone() + + if dist.is_initialized(): + dist.barrier() + + if not extension: + self.state.h = h + self.state.w = w + self.state.lat_h = lat_h + self.state.lat_w = lat_w + self.state.frame_seqlen = frame_seqlen + self.state.advance(new_lat_f) + + return DiffusionOutput(output=videos[0]) + + # ------------------------------------------------------------------ + # micro-step execution + # ------------------------------------------------------------------ + + def predict_noise( + self, + intermediate_tensors: IntermediateTensors | None = None, + **kwargs: Any, + ) -> torch.Tensor | IntermediateTensors: + """Single transformer forward; returns IntermediateTensors on non-last PP stages.""" + with torch.amp.autocast("cuda", dtype=self.target_dtype): + result = self.model(**kwargs, intermediate_tensors=intermediate_tensors) + if isinstance(result, IntermediateTensors): + return result + # Last stage returns List[Tensor] (one per row); stack along dim 0. + return torch.stack(result, dim=0) + + def prepare_encode( + self, + state: DiffusionRequestState, + **kwargs: Any, + ) -> DiffusionRequestState: + """One-time request setup mirroring forward()'s prep up to the chunk loop. + + Stashes per-chunk noise / conditioning / Plucker tensors in state.extra, + initializes (or extends) the model's persistent KV caches sized to this + rank's owned layer slice, and exposes state.timesteps as length 5 + (4 denoise + 1 t=0 KV-update). + """ + if not state.prompts or len(state.prompts) > 1: + raise ValueError("LingbotWorldFastPipeline only supports a single prompt.") + + sampling = state.sampling + if sampling.chunk_frames != 3*4: + logger.warning( + "LingbotWorldFastPipeline requires chunk_frames=3*4=12, got %s. Overriding.", + sampling.chunk_frames, + ) + sampling.chunk_frames = 12 + if sampling.num_inference_steps != 5: + logger.warning( + "LingbotWorldFastPipeline requires num_inference_steps=4+1=5, got %s. Overriding.", + sampling.num_inference_steps, + ) + sampling.num_inference_steps = 5 + + prompt = state.prompts[0].get("prompt") + multi_modal_data = state.prompts[0].get("multi_modal_data", {}) or {} + + extra_args = state.sampling.extra_args or {} + session_id = str(extra_args.get("session_id") or None) + force_reset = bool(extra_args.get("force_reset") or False) + + if force_reset or self.state.session_id is None or self.state.session_id != session_id: + self.state.reset() + self.state.session_id = session_id + extension = False + else: + extension = True + + camera = multi_modal_data.get("camera", None) + if camera is None: + raise ValueError("LingbotWorldFastPipeline requires camera poses in multi_modal_data['camera'].") + + if extension: + assert multi_modal_data.get("image") is None, ( + "image must not be provided on extension calls; it is only used on the first call of a session" + ) + assert self.model.config.local_attn_size == -1, ( + "video extension requires the model to be configured with local_attn_size == -1" + ) + + batch_size = 1 + c2ws = camera.get("poses") + chunk_size = CONFIG["chunk_size"] + max_area = CONFIG["max_area"] + + new_lat_f = max(sampling.num_chunks * chunk_size, 1) + if extension: + num_frames = new_lat_f * 4 + else: + num_frames = (new_lat_f - 1) * 4 + 1 + if len(c2ws) < num_frames: + raise ValueError( + f"camera trajectory has {len(c2ws)} poses; need >= {num_frames} " + f"for {sampling.num_chunks} chunks (chunk_size={chunk_size})." + ) + c2ws = c2ws[:num_frames] + + if not extension: + img = multi_modal_data.get("image") + img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(self.device) + h, w = img.shape[1:] + aspect_ratio = h / w + lat_h = round( + np.sqrt(max_area * aspect_ratio) // self.vae_stride[1] // self.patch_size[1] * self.patch_size[1] + ) + lat_w = round( + np.sqrt(max_area / aspect_ratio) // self.vae_stride[2] // self.patch_size[2] * self.patch_size[2] + ) + h = lat_h * self.vae_stride[1] + w = lat_w * self.vae_stride[2] + else: + img = None + h, w, lat_h, lat_w = self.state.h, self.state.w, self.state.lat_h, self.state.lat_w + + max_seq_len = chunk_size * lat_h * lat_w // (self.patch_size[1] * self.patch_size[2]) + max_seq_len = int(math.ceil(max_seq_len / self.sp_size)) * self.sp_size + + seed = state.sampling.seed + if seed is None: + seed = random.randint(0, sys.maxsize) + # Two separate generators to keep noise consistent across PP ranks: + # - seed_g: chunk-initial noise. + # - seed_g_addnoise: scheduler.add_noise consumed on last rank only. + seed_g = torch.Generator(device=self.device).manual_seed(seed) + seed_g_addnoise = torch.Generator(device=self.device).manual_seed(seed + 1) + + # Sampler timesteps (4 denoise + 1 t=0 KV-update) + self.scheduler.set_timesteps(self.num_train_timesteps, shift=CONFIG["sample_shift"]) + denoise_timesteps = self.scheduler.timesteps[CONFIG["timesteps_index"]].to(self.device) + timesteps5 = torch.cat([denoise_timesteps, denoise_timesteps.new_zeros(1)], dim=0) + + # Text + camera Plucker + context_list = self.text_encoder([prompt], self.device) + + Ks_raw = torch.from_numpy(camera.get("intrinsics")) + Ks_t = get_Ks_transformed( + Ks_raw, height_org=480, width_org=832, height_resize=h, width_resize=w, height_final=h, width_final=w + )[0] + len_c2ws_orig = len(c2ws) + tgt_indices_full = np.linspace(0, len_c2ws_orig - 1, new_lat_f) + c2ws_infer_full = interpolate_camera_poses( + src_indices=np.linspace(0, len_c2ws_orig - 1, len_c2ws_orig), + src_rot_mat=c2ws[:, :3, :3], + src_trans_vec=c2ws[:, :3, 3], + tgt_indices=tgt_indices_full, + ) + c2ws_infer_full = compute_relative_poses(c2ws_infer_full, framewise=True) + c2ws_infer_full = c2ws_infer_full.to(self.device).to(torch.float32) + Ks_t = Ks_t.to(self.device).to(torch.float32) + + anchor_latent: torch.Tensor | None = None + if is_pipeline_first_stage(): + self.vae.reset() + if not extension: + anchor_pixels = ( + torch.nn.functional.interpolate(img[None].cpu(), size=(h, w), mode="bicubic") + .transpose(0, 1) + .to(self.device) + ) + anchor_latent = self.vae.init(anchor_pixels) # [16, 1, lat_h, lat_w] + else: + zero_frame = torch.zeros(3, 1, h, w, device=self.device) + self.vae.init(zero_frame) + + # KV cache sizing — per this rank's owned layer slice. + model_args = self.model.config + transformer_dtype = self.target_dtype + frame_seqlen = int(lat_h * lat_w // 4) + extra_kv_size = frame_seqlen * new_lat_f + head_dim = model_args.dim // model_args.num_heads + local_num_heads = model_args.num_heads // self.sp_size + owned_num_layers = self.model.end_layer - self.model.start_layer + + if not extension: + self.state.create_kv_caches( + batch_size, + transformer_dtype, + self.device, + extra_kv_size, + owned_num_layers, + local_num_heads, + head_dim, + ) + else: + self.state.extend_kv_caches(extra_kv_size) + + prev_lat_f = self.state.current_lat_f + total_kv_size = frame_seqlen * (prev_lat_f + new_lat_f) + start_token_offset = prev_lat_f * frame_seqlen + + # State population. + state.prompt_embeds = None # unused; lingbot keeps text as raw list[Tensor] + state.latents = None # per-chunk latents are stacked by encode_chunk_inputs + state.timesteps = timesteps5 + state.step_index = 0 + state.scheduler = LingbotFlowScheduler(self.scheduler, timesteps5) + state.do_true_cfg = False + + state.extra["context"] = context_list + state.extra["anchor_latent"] = anchor_latent + state.extra["start_token_offset"] = start_token_offset + state.extra["max_attention_size"] = total_kv_size + state.extra["frame_seqlen"] = frame_seqlen + state.extra["max_seq_len"] = max_seq_len + state.extra["chunk_size"] = chunk_size + state.extra["lat_h"] = lat_h + state.extra["lat_w"] = lat_w + state.extra["h"] = h + state.extra["w"] = w + state.extra["new_lat_f"] = new_lat_f + state.extra["extension"] = extension + state.extra["seed_g"] = seed_g + state.extra["seed_g_addnoise"] = seed_g_addnoise + + state.extra["c2ws_infer_full"] = c2ws_infer_full + state.extra["Ks_transformed"] = Ks_t + + return state + + def encode_chunk_inputs( + self, + state: DiffusionRequestState, + new_idxs: list[int], + ) -> torch.Tensor: + """Build per-chunk noise, plus VAE-encoded y and Plucker on first stage.""" + seed_g = state.extra["seed_g"] + chunk_size = state.extra["chunk_size"] + lat_h = state.extra["lat_h"] + lat_w = state.extra["lat_w"] + h = state.extra["h"] + w = state.extra["w"] + chunks = state.extra["chunks"] + B = len(new_idxs) + + # noise + noise = torch.randn( + B, 16, chunk_size, lat_h, lat_w, + dtype=torch.float32, generator=seed_g, device=self.device, + ) + + if not is_pipeline_first_stage(): + return noise + + c2ws_infer_full = state.extra["c2ws_infer_full"] + Ks_t = state.extra["Ks_transformed"] + anchor_latent: torch.Tensor | None = state.extra["anchor_latent"] + extension: bool = state.extra["extension"] + + # per-chunk stream-encode y + per-chunk msk + for idx in new_idxs: + is_anchor_chunk = (not extension) and idx == 0 + if is_anchor_chunk: + tail_frames = 4 * (chunk_size - 1) + if tail_frames > 0: + zeros = torch.zeros(3, tail_frames, h, w, device=self.device) + tail_lat = self.vae.encode(zeros) + assert anchor_latent is not None + vae_lat = torch.cat([anchor_latent, tail_lat], dim=1) + else: + assert anchor_latent is not None + vae_lat = anchor_latent + else: + zeros = torch.zeros(3, 4 * chunk_size, h, w, device=self.device) + vae_lat = self.vae.encode(zeros) + + msk_chunk = torch.zeros(4, chunk_size, lat_h, lat_w, device=self.device) + if is_anchor_chunk: + msk_chunk[:, 0] = 1 + chunks[idx].extra["y"] = torch.cat([msk_chunk, vae_lat], dim=0) + + # plucker + frame_indices = torch.tensor( + [ci * chunk_size + f for ci in new_idxs for f in range(chunk_size)], + device=c2ws_infer_full.device, dtype=torch.long, + ) + batched_c2ws = c2ws_infer_full[frame_indices] # [B*chunk_size, 3, 4] + batched_Ks = Ks_t.repeat(B * chunk_size, 1) # [B*chunk_size, 4] + batched_plucker = get_plucker_embeddings(batched_c2ws, batched_Ks, h, w, only_rays_d=False) + batched_plucker = rearrange( + batched_plucker, + "f (h c1) (w c2) c -> f h w (c c1 c2)", + c1=int(h // lat_h), + c2=int(w // lat_w), + ) + batched_plucker = batched_plucker.view(B, chunk_size, lat_h, lat_w, -1) + batched_plucker = batched_plucker.permute(0, 4, 1, 2, 3).contiguous().to(self.target_dtype) + + for i, idx in enumerate(new_idxs): + chunks[idx].extra["plucker"] = batched_plucker[i : i + 1] + + return noise + + def set_pp_recv_dict_buffers(self, state: DiffusionRequestState) -> None: + if get_pipeline_parallel_world_size() == 1: + return + + pp_group = get_pp_group() + slo_fps = getattr(state.sampling, "slo_fps", None) + slo_max_batch = getattr(state.sampling, "slo_max_batch", 1) + slo_max_batch = max(1, slo_max_batch if slo_fps else 1) + + chunk_size = state.extra["chunk_size"] + lat_h = state.extra["lat_h"] + lat_w = state.extra["lat_w"] + max_seq_len = state.extra["max_seq_len"] + n_steps = int(state.timesteps.shape[0]) + + latents_dtype = torch.float32 + it_dtype = self.target_dtype + + for batch_size in range(1, slo_max_batch * n_steps + 1): + latents_template = { + "latents": torch.empty( + batch_size, 16, chunk_size, lat_h, lat_w, dtype=latents_dtype, device="meta" + ) + } + it_template = { + "hidden_states": torch.empty( + batch_size, max_seq_len, self.model.dim, dtype=it_dtype, device="meta" + ), + "grid_sizes": torch.empty(batch_size, 3, dtype=torch.long, device="meta"), + "seq_lens": torch.empty(batch_size, dtype=torch.long, device="meta"), + "c2ws_plucker_emb": torch.empty( + batch_size, max_seq_len, self.model.dim, dtype=it_dtype, device="meta" + ), + } + pp_group.set_recv_dict_buffer("latents", -1, latents_template, batch_size=batch_size) + pp_group.set_recv_dict_buffer("intermediate", 0, it_template, batch_size=batch_size) + + def denoise_step( + self, + state: DiffusionRequestState, + batch_size: int = 1, + **kwargs: Any, + ) -> torch.Tensor | None: + """Fused transformer forward for the batch of chunks. + + Each row's per-chunk metadata (``current_starts``, ``y``, ``c2ws_plucker_emb``) + is read from state.extra keyed by chunk index. Rows whose timestep is 0 + carry the KV-update payload (the chunk's saved x0) — their output is + ignored by ``step_scheduler``. + """ + chunk_idxs: list[int] = state.extra["current_chunk_idxs"] + assert len(chunk_idxs) == batch_size + + chunk_size = state.extra["chunk_size"] + frame_seqlen = state.extra["frame_seqlen"] + start_token_offset = state.extra["start_token_offset"] + chunks = state.extra["chunks"] + context_list = state.extra["context"] + + x_list, y_list, plucker_list = None, None, None + if is_pipeline_first_stage(): + x_list = [state.latents[i] for i in range(batch_size)] + y_list = [chunks[ci].extra["y"] for ci in chunk_idxs] + plucker_list = [chunks[ci].extra["plucker"] for ci in chunk_idxs] + + current_starts = [ + start_token_offset + ci * chunk_size * frame_seqlen for ci in chunk_idxs + ] + + positive_kwargs = { + "x": x_list, + "t": state.batched_timesteps, + "context": [context_list[0]] * batch_size, + "seq_len": state.extra["max_seq_len"], + "y": y_list, + "dit_cond_dict": {"c2ws_plucker_emb": plucker_list}, + "kv_cache": self.state.get_kv_caches(), + "local_end_index": self.state.local_end_index, + "global_end_index": self.state.global_end_index, + "crossattn_cache": self.state.get_crossattn_caches(), + "current_starts": current_starts, + "max_attention_size": state.extra["max_attention_size"], + } + + preposted_its = state.extra.pop("preposted_its", None) + return self.predict_noise_maybe_with_cfg( + do_true_cfg=False, + true_cfg_scale=1.0, + positive_kwargs=positive_kwargs, + negative_kwargs=None, + buf_idx=state.step_index % 2, + batch_size=batch_size, + preposted_its=preposted_its, + ) + + def step_scheduler( + self, + state: DiffusionRequestState, + noise_pred: torch.Tensor, + *, + per_request_scheduler: Any | list[Any] | None = None, + batch_size: int = 1, + **kwargs: Any, + ) -> None: + if per_request_scheduler is None: + per_request_scheduler = state.scheduler + + state.latents = self.scheduler_step_maybe_with_cfg( + noise_pred, + state.batched_timesteps, + state.latents, + do_true_cfg=False, + per_request_scheduler=per_request_scheduler, + generator=state.extra["seed_g_addnoise"], + batch_size=batch_size, + receive_latents=False, + ) + state.step_index += 1 + + def prefetch_tensors( + self, + state: DiffusionRequestState, + batch_size: int = 1, + **kwargs: Any, + ) -> None: + if get_pipeline_parallel_world_size() == 1: + return + buf_idx = state.step_index % 2 + preposted = self.prefetch_tensors_maybe_with_cfg( + do_true_cfg=False, buf_idx=buf_idx, batch_size=batch_size, + ) + if isinstance(preposted, AsyncLatents): + state.latents = preposted + elif preposted is not None: + state.extra["preposted_its"] = preposted + + def post_decode( + self, + state: DiffusionRequestState, + **kwargs: Any, + ) -> DiffusionOutput: + """VAE-decode the finished chunks with prior-tail warmup. + + Mirrors forward()'s decode block: on extension calls prepend the prior + chunk's last 2 latents to prime the temporal-causal feat_map, then drop + ``4*k - 3`` leading pixels. After decoding, refresh ``last_decoded_latent`` + with the tail of the new latents so the next call's decode is warm. + """ + self._sync_pp_send() + pred_latent_chunks = state.latents.transpose(0, 1).reshape( + state.latents.shape[1], + state.latents.shape[0] * state.latents.shape[2], + state.latents.shape[3], + state.latents.shape[4], + ) + # pred_latent_chunks: [16, B*chunk_size, lat_h, lat_w] + + extension = state.extra["extension"] + if self.state.last_decoded_latent is not None: + warmup = self.state.last_decoded_latent.to( + pred_latent_chunks.device, pred_latent_chunks.dtype + ) + k = warmup.shape[1] + drop = 4 * k - 3 + to_decode = torch.cat([warmup, pred_latent_chunks], dim=1) + videos = self.vae.decode([to_decode]) + videos = [v[:, drop:] for v in videos] + else: + videos = self.vae.decode([pred_latent_chunks]) + + self.state.last_decoded_latent = pred_latent_chunks[:, -2:].detach().clone() + + sampling = state.sampling + chunks_so_far = state.extra.get("chunks_decoded", 0) + chunks_this_call = state.latents.shape[0] + is_final = chunks_so_far + chunks_this_call >= sampling.num_chunks + if is_final: + if not extension: + self.state.h = state.extra["h"] + self.state.w = state.extra["w"] + self.state.lat_h = state.extra["lat_h"] + self.state.lat_w = state.extra["lat_w"] + self.state.frame_seqlen = state.extra["frame_seqlen"] + self.state.advance(state.extra["new_lat_f"]) + + return DiffusionOutput(output=videos[0]) + + def load_weights(self, weights): + pass diff --git a/vllm_omni/diffusion/models/lingbot_world_fast/state_lingbot_world_fast.py b/vllm_omni/diffusion/models/lingbot_world_fast/state_lingbot_world_fast.py new file mode 100644 index 00000000000..35e00fc03fc --- /dev/null +++ b/vllm_omni/diffusion/models/lingbot_world_fast/state_lingbot_world_fast.py @@ -0,0 +1,152 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""Lingbot World Fast pipeline persistent state.""" + +from __future__ import annotations + +import logging +from enum import IntEnum + +import torch + +logger = logging.getLogger(__name__) + + +class CacheIndex(IntEnum): + K = 0 + V = 1 + + +class LingbotWorldFastState: + """Pipeline persistent state across forward() calls. + + Lifecycle: + - Created once in LingbotWorldFastPipeline.__init__() + - Mutated every forward() call (frame append, KV cache grow) + - reset() on new session / local_attn_size exceeded + """ + + def __init__(self) -> None: + self.is_initialized = False + self.reset() + + # ------------------------------------------------------------------ + # Reset / should_reset + # ------------------------------------------------------------------ + + def reset(self) -> None: + """Clear all state.""" + + if self.is_initialized: + for cache in self.kv_cache: + del cache + for cache in self.crossattn_cache: + if isinstance(cache["k"], torch.Tensor): + del cache["k"] + del cache["v"] + + self.kv_cache: list[torch.Tensor] | None = None + self.crossattn_cache: list[dict[str, bool | torch.Tensor | None]] | None = None + self.current_start_frame: int = 0 + self.local_end_index: list[torch.Tensor] | None = None + self.global_end_index: list[torch.Tensor] | None = None + + self.is_initialized: bool = False + self.current_lat_f: int = 0 + self.session_id: str | None = None + + self.batch_size: int | None = None + self.num_layers: int | None = None + self.num_heads: int | None = None + self.head_dim: int | None = None + + # Shape constants captured on the first call of a session and reused + # on extension calls, where multi_modal_data["image"] is absent. + self.h: int | None = None + self.w: int | None = None + self.lat_h: int | None = None + self.lat_w: int | None = None + self.frame_seqlen: int | None = None + + # Last few latents emitted by the diffusion loop on the previous call. + # Prepended to pred_latent_chunks on extension so the Wan VAE decoder's + # stacked temporal feat_maps are fully warmed before the first NEW + # latent is decoded. The decoder's temporal receptive field spans + # ~2 latents, so we cache the last 2. + self.last_decoded_latent: torch.Tensor | None = None + + # ------------------------------------------------------------------ + # KV cache management + # ------------------------------------------------------------------ + + def create_kv_caches( + self, + batch_size: int, + dtype: torch.dtype, + device: torch.device, + kv_size: int, + num_layers: int, + num_heads: int, + head_dim: int, + ) -> None: + self.batch_size = batch_size + self.num_layers = num_layers + self.num_heads = num_heads + self.head_dim = head_dim + + """Initialize empty KV caches and cross-attention caches.""" + self.kv_cache = [ + torch.zeros(2, batch_size, kv_size, num_heads, head_dim, dtype=dtype, device=device) + for _ in range(num_layers) + ] + + self.local_end_index = [torch.tensor([0], dtype=torch.long, device=device) for _ in range(num_layers)] + self.global_end_index = [torch.tensor([0], dtype=torch.long, device=device) for _ in range(num_layers)] + + self.crossattn_cache = [{"is_init": False, "k": None, "v": None} for _ in range(num_layers)] + + self.is_initialized = True + + def extend_kv_caches(self, extra_kv_size: int): + assert self.is_initialized, "Cannot extend uninitialized kv cache" + + dtype = self.kv_cache[0].dtype + device = self.kv_cache[0].device + + self.kv_cache = [ + torch.cat( + [ + self.kv_cache[i], + torch.zeros( + 2, self.batch_size, extra_kv_size, self.num_heads, self.head_dim, dtype=dtype, device=device + ), + ], + dim=2, + ) + for i in range(self.num_layers) + ] + + def update_kv_cache( + self, + layer_index: int, + updated_kv: torch.Tensor, + is_negative: bool = False, + ) -> None: + """Update a single layer's KV cache after prefill.""" + cache = self.kv_cache_neg if is_negative else self.kv_cache + assert cache is not None, "KV caches not initialized, call create_kv_caches first" + cache[layer_index] = updated_kv.clone() + + def get_kv_caches(self) -> list[torch.Tensor]: + """Get KV caches for the specified branch.""" + assert self.kv_cache is not None, "KV caches not initialized" + return self.kv_cache + + def get_crossattn_caches(self, is_negative: bool = False) -> list[dict[str, bool | torch.Tensor | None]]: + """Get cross-attention caches for the specified branch.""" + assert self.crossattn_cache is not None, "Cross-attn caches not initialized" + return self.crossattn_cache + + def advance(self, delta: int): + self.current_lat_f += delta diff --git a/vllm_omni/diffusion/models/lingbot_world_fast/stream_vae.py b/vllm_omni/diffusion/models/lingbot_world_fast/stream_vae.py new file mode 100644 index 00000000000..8c6c7557a06 --- /dev/null +++ b/vllm_omni/diffusion/models/lingbot_world_fast/stream_vae.py @@ -0,0 +1,82 @@ +"""Per-chunk streaming VAE encode wrapper around ``Wan2_1_VAE``.""" + +from __future__ import annotations + +import torch +import torch.cuda.amp as amp + +from .vae2_1 import Wan2_1_VAE + + +class StreamVAE: + def __init__(self, vae: Wan2_1_VAE) -> None: + self._vae = vae + self._model = vae.model + self._scale = vae.scale + self._dtype = vae.dtype + + def reset(self) -> None: + """Clear encoder feat_map cache. Call at the start of each new request.""" + self._model.clear_cache() + + def decode(self, zs): + return self._vae.decode(zs) + + @property + def dtype(self) -> torch.dtype: + return self._vae.dtype + + @torch.no_grad() + def init(self, frame: torch.Tensor) -> torch.Tensor: + """Encode the single init pixel frame and return its latent. + + Caller keeps the latent for fresh starts (anchor encoding) or + discards it for extension starts (just sets up the init bias). + + Args: + frame: ``[C, 1, H, W]`` or ``[B, C, 1, H, W]``. + Returns: + ``[z_dim, 1, H, W]`` latent. + """ + with amp.autocast(dtype=self._dtype): + pixels = frame.unsqueeze(0) if frame.dim() == 4 else frame + out = self._encode_group(pixels) + mu = self._apply_conv1_and_normalize(out) + return mu.float().squeeze(0) + + @torch.no_grad() + def encode(self, pixels: torch.Tensor) -> torch.Tensor: + """Encode ``4*N`` pixel frames using the preserved state. + + Args: + pixels: ``[C, 4N, H, W]`` or ``[B, C, 4N, H, W]``. + Returns: + ``[z_dim, N, H, W]`` latents. + """ + pixels = pixels.unsqueeze(0) if pixels.dim() == 4 else pixels + T = pixels.shape[2] + assert T % 4 == 0, f"StreamVAE.encode expects a multiple of 4 pixel frames, got {T}" + N = T // 4 + with amp.autocast(dtype=self._dtype): + outs = [self._encode_group(pixels[:, :, i * 4 : (i + 1) * 4]) for i in range(N)] + out = torch.cat(outs, dim=2) + mu = self._apply_conv1_and_normalize(out) + return mu.float().squeeze(0) + + # ── internals ────────────────────────────────────────────────────────── + + def _encode_group(self, pixels: torch.Tensor) -> torch.Tensor: + """One ``(1|4)``-frame encoder pass using the live ``_enc_feat_map``.""" + self._model._enc_conv_idx = [0] + return self._model.encoder( + pixels, feat_cache=self._model._enc_feat_map, feat_idx=self._model._enc_conv_idx + ) + + def _apply_conv1_and_normalize(self, out: torch.Tensor) -> torch.Tensor: + mu, _ = self._model.conv1(out).chunk(2, dim=1) + z = self._model.z_dim + if isinstance(self._scale[0], torch.Tensor): + mu = (mu - self._scale[0].view(1, z, 1, 1, 1)) * self._scale[1].view(1, z, 1, 1, 1) + else: + mu = (mu - self._scale[0]) * self._scale[1] + return mu \ No newline at end of file diff --git a/vllm_omni/diffusion/models/lingbot_world_fast/t5.py b/vllm_omni/diffusion/models/lingbot_world_fast/t5.py new file mode 100644 index 00000000000..0863d121b4b --- /dev/null +++ b/vllm_omni/diffusion/models/lingbot_world_fast/t5.py @@ -0,0 +1,451 @@ +# Adapted from Lingbot-World/wan/modules/t5.py +import logging +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .tokenizers import HuggingfaceTokenizer + +__all__ = [ + "T5Model", + "T5Encoder", + "T5Decoder", + "T5EncoderModel", +] + + +def fp16_clamp(x): + if x.dtype == torch.float16 and torch.isinf(x).any(): + clamp = torch.finfo(x.dtype).max - 1000 + x = torch.clamp(x, min=-clamp, max=clamp) + return x + + +def init_weights(m): + if isinstance(m, T5LayerNorm): + nn.init.ones_(m.weight) + elif isinstance(m, T5Model): + nn.init.normal_(m.token_embedding.weight, std=1.0) + elif isinstance(m, T5FeedForward): + nn.init.normal_(m.gate[0].weight, std=m.dim**-0.5) + nn.init.normal_(m.fc1.weight, std=m.dim**-0.5) + nn.init.normal_(m.fc2.weight, std=m.dim_ffn**-0.5) + elif isinstance(m, T5Attention): + nn.init.normal_(m.q.weight, std=(m.dim * m.dim_attn) ** -0.5) + nn.init.normal_(m.k.weight, std=m.dim**-0.5) + nn.init.normal_(m.v.weight, std=m.dim**-0.5) + nn.init.normal_(m.o.weight, std=(m.num_heads * m.dim_attn) ** -0.5) + elif isinstance(m, T5RelativeEmbedding): + nn.init.normal_(m.embedding.weight, std=(2 * m.num_buckets * m.num_heads) ** -0.5) + + +class GELU(nn.Module): + def forward(self, x): + return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0)))) + + +class T5LayerNorm(nn.Module): + def __init__(self, dim, eps=1e-6): + super().__init__() + self.dim = dim + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def forward(self, x): + x = x * torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) + self.eps) + if self.weight.dtype in [torch.float16, torch.bfloat16]: + x = x.type_as(self.weight) + return self.weight * x + + +class T5Attention(nn.Module): + def __init__(self, dim, dim_attn, num_heads, dropout=0.1): + assert dim_attn % num_heads == 0 + super().__init__() + self.dim = dim + self.dim_attn = dim_attn + self.num_heads = num_heads + self.head_dim = dim_attn // num_heads + + # layers + self.q = nn.Linear(dim, dim_attn, bias=False) + self.k = nn.Linear(dim, dim_attn, bias=False) + self.v = nn.Linear(dim, dim_attn, bias=False) + self.o = nn.Linear(dim_attn, dim, bias=False) + self.dropout = nn.Dropout(dropout) + + def forward(self, x, context=None, mask=None, pos_bias=None): + """ + x: [B, L1, C]. + context: [B, L2, C] or None. + mask: [B, L2] or [B, L1, L2] or None. + """ + # check inputs + context = x if context is None else context + b, n, c = x.size(0), self.num_heads, self.head_dim + + # compute query, key, value + q = self.q(x).view(b, -1, n, c) + k = self.k(context).view(b, -1, n, c) + v = self.v(context).view(b, -1, n, c) + + # attention bias + attn_bias = x.new_zeros(b, n, q.size(1), k.size(1)) + if pos_bias is not None: + attn_bias += pos_bias + if mask is not None: + assert mask.ndim in [2, 3] + mask = mask.view(b, 1, 1, -1) if mask.ndim == 2 else mask.unsqueeze(1) + attn_bias.masked_fill_(mask == 0, torch.finfo(x.dtype).min) + + # compute attention (T5 does not use scaling) + attn = torch.einsum("binc,bjnc->bnij", q, k) + attn_bias + attn = F.softmax(attn.float(), dim=-1).type_as(attn) + x = torch.einsum("bnij,bjnc->binc", attn, v) + + # output + x = x.reshape(b, -1, n * c) + x = self.o(x) + x = self.dropout(x) + return x + + +class T5FeedForward(nn.Module): + def __init__(self, dim, dim_ffn, dropout=0.1): + super().__init__() + self.dim = dim + self.dim_ffn = dim_ffn + + # layers + self.gate = nn.Sequential(nn.Linear(dim, dim_ffn, bias=False), GELU()) + self.fc1 = nn.Linear(dim, dim_ffn, bias=False) + self.fc2 = nn.Linear(dim_ffn, dim, bias=False) + self.dropout = nn.Dropout(dropout) + + def forward(self, x): + x = self.fc1(x) * self.gate(x) + x = self.dropout(x) + x = self.fc2(x) + x = self.dropout(x) + return x + + +class T5SelfAttention(nn.Module): + def __init__(self, dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos=True, dropout=0.1): + super().__init__() + self.dim = dim + self.dim_attn = dim_attn + self.dim_ffn = dim_ffn + self.num_heads = num_heads + self.num_buckets = num_buckets + self.shared_pos = shared_pos + + # layers + self.norm1 = T5LayerNorm(dim) + self.attn = T5Attention(dim, dim_attn, num_heads, dropout) + self.norm2 = T5LayerNorm(dim) + self.ffn = T5FeedForward(dim, dim_ffn, dropout) + self.pos_embedding = None if shared_pos else T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True) + + def forward(self, x, mask=None, pos_bias=None): + e = pos_bias if self.shared_pos else self.pos_embedding(x.size(1), x.size(1)) + x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e)) + x = fp16_clamp(x + self.ffn(self.norm2(x))) + return x + + +class T5CrossAttention(nn.Module): + def __init__(self, dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos=True, dropout=0.1): + super().__init__() + self.dim = dim + self.dim_attn = dim_attn + self.dim_ffn = dim_ffn + self.num_heads = num_heads + self.num_buckets = num_buckets + self.shared_pos = shared_pos + + # layers + self.norm1 = T5LayerNorm(dim) + self.self_attn = T5Attention(dim, dim_attn, num_heads, dropout) + self.norm2 = T5LayerNorm(dim) + self.cross_attn = T5Attention(dim, dim_attn, num_heads, dropout) + self.norm3 = T5LayerNorm(dim) + self.ffn = T5FeedForward(dim, dim_ffn, dropout) + self.pos_embedding = None if shared_pos else T5RelativeEmbedding(num_buckets, num_heads, bidirectional=False) + + def forward(self, x, mask=None, encoder_states=None, encoder_mask=None, pos_bias=None): + e = pos_bias if self.shared_pos else self.pos_embedding(x.size(1), x.size(1)) + x = fp16_clamp(x + self.self_attn(self.norm1(x), mask=mask, pos_bias=e)) + x = fp16_clamp(x + self.cross_attn(self.norm2(x), context=encoder_states, mask=encoder_mask)) + x = fp16_clamp(x + self.ffn(self.norm3(x))) + return x + + +class T5RelativeEmbedding(nn.Module): + def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128): + super().__init__() + self.num_buckets = num_buckets + self.num_heads = num_heads + self.bidirectional = bidirectional + self.max_dist = max_dist + + # layers + self.embedding = nn.Embedding(num_buckets, num_heads) + + def forward(self, lq, lk): + device = self.embedding.weight.device + rel_pos = torch.arange(lk, device=device).unsqueeze(0) - torch.arange(lq, device=device).unsqueeze(1) + rel_pos = self._relative_position_bucket(rel_pos) + rel_pos_embeds = self.embedding(rel_pos) + rel_pos_embeds = rel_pos_embeds.permute(2, 0, 1).unsqueeze(0) # [1, N, Lq, Lk] + return rel_pos_embeds.contiguous() + + def _relative_position_bucket(self, rel_pos): + # preprocess + if self.bidirectional: + num_buckets = self.num_buckets // 2 + rel_buckets = (rel_pos > 0).long() * num_buckets + rel_pos = torch.abs(rel_pos) + else: + num_buckets = self.num_buckets + rel_buckets = 0 + rel_pos = -torch.min(rel_pos, torch.zeros_like(rel_pos)) + + # embeddings for small and large positions + max_exact = num_buckets // 2 + rel_pos_large = ( + max_exact + + ( + torch.log(rel_pos.float() / max_exact) / math.log(self.max_dist / max_exact) * (num_buckets - max_exact) + ).long() + ) + rel_pos_large = torch.min(rel_pos_large, torch.full_like(rel_pos_large, num_buckets - 1)) + rel_buckets += torch.where(rel_pos < max_exact, rel_pos, rel_pos_large) + return rel_buckets + + +class T5Encoder(nn.Module): + def __init__(self, vocab, dim, dim_attn, dim_ffn, num_heads, num_layers, num_buckets, shared_pos=True, dropout=0.1): + super().__init__() + self.dim = dim + self.dim_attn = dim_attn + self.dim_ffn = dim_ffn + self.num_heads = num_heads + self.num_layers = num_layers + self.num_buckets = num_buckets + self.shared_pos = shared_pos + + # layers + self.token_embedding = vocab if isinstance(vocab, nn.Embedding) else nn.Embedding(vocab, dim) + self.pos_embedding = T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True) if shared_pos else None + self.dropout = nn.Dropout(dropout) + self.blocks = nn.ModuleList( + [ + T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos, dropout) + for _ in range(num_layers) + ] + ) + self.norm = T5LayerNorm(dim) + + # initialize weights + self.apply(init_weights) + + def forward(self, ids, mask=None): + x = self.token_embedding(ids) + x = self.dropout(x) + e = self.pos_embedding(x.size(1), x.size(1)) if self.shared_pos else None + for block in self.blocks: + x = block(x, mask, pos_bias=e) + x = self.norm(x) + x = self.dropout(x) + return x + + +class T5Decoder(nn.Module): + def __init__(self, vocab, dim, dim_attn, dim_ffn, num_heads, num_layers, num_buckets, shared_pos=True, dropout=0.1): + super().__init__() + self.dim = dim + self.dim_attn = dim_attn + self.dim_ffn = dim_ffn + self.num_heads = num_heads + self.num_layers = num_layers + self.num_buckets = num_buckets + self.shared_pos = shared_pos + + # layers + self.token_embedding = vocab if isinstance(vocab, nn.Embedding) else nn.Embedding(vocab, dim) + self.pos_embedding = T5RelativeEmbedding(num_buckets, num_heads, bidirectional=False) if shared_pos else None + self.dropout = nn.Dropout(dropout) + self.blocks = nn.ModuleList( + [ + T5CrossAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos, dropout) + for _ in range(num_layers) + ] + ) + self.norm = T5LayerNorm(dim) + + # initialize weights + self.apply(init_weights) + + def forward(self, ids, mask=None, encoder_states=None, encoder_mask=None): + b, s = ids.size() + + # causal mask + if mask is None: + mask = torch.tril(torch.ones(1, s, s).to(ids.device)) + elif mask.ndim == 2: + mask = torch.tril(mask.unsqueeze(1).expand(-1, s, -1)) + + # layers + x = self.token_embedding(ids) + x = self.dropout(x) + e = self.pos_embedding(x.size(1), x.size(1)) if self.shared_pos else None + for block in self.blocks: + x = block(x, mask, encoder_states, encoder_mask, pos_bias=e) + x = self.norm(x) + x = self.dropout(x) + return x + + +class T5Model(nn.Module): + def __init__( + self, + vocab_size, + dim, + dim_attn, + dim_ffn, + num_heads, + encoder_layers, + decoder_layers, + num_buckets, + shared_pos=True, + dropout=0.1, + ): + super().__init__() + self.vocab_size = vocab_size + self.dim = dim + self.dim_attn = dim_attn + self.dim_ffn = dim_ffn + self.num_heads = num_heads + self.encoder_layers = encoder_layers + self.decoder_layers = decoder_layers + self.num_buckets = num_buckets + + # layers + self.token_embedding = nn.Embedding(vocab_size, dim) + self.encoder = T5Encoder( + self.token_embedding, dim, dim_attn, dim_ffn, num_heads, encoder_layers, num_buckets, shared_pos, dropout + ) + self.decoder = T5Decoder( + self.token_embedding, dim, dim_attn, dim_ffn, num_heads, decoder_layers, num_buckets, shared_pos, dropout + ) + self.head = nn.Linear(dim, vocab_size, bias=False) + + # initialize weights + self.apply(init_weights) + + def forward(self, encoder_ids, encoder_mask, decoder_ids, decoder_mask): + x = self.encoder(encoder_ids, encoder_mask) + x = self.decoder(decoder_ids, decoder_mask, x, encoder_mask) + x = self.head(x) + return x + + +def _t5( + name, + encoder_only=False, + decoder_only=False, + return_tokenizer=False, + tokenizer_kwargs={}, + dtype=torch.float32, + device="cpu", + **kwargs, +): + # sanity check + assert not (encoder_only and decoder_only) + + # params + if encoder_only: + model_cls = T5Encoder + kwargs["vocab"] = kwargs.pop("vocab_size") + kwargs["num_layers"] = kwargs.pop("encoder_layers") + _ = kwargs.pop("decoder_layers") + elif decoder_only: + model_cls = T5Decoder + kwargs["vocab"] = kwargs.pop("vocab_size") + kwargs["num_layers"] = kwargs.pop("decoder_layers") + _ = kwargs.pop("encoder_layers") + else: + model_cls = T5Model + + # init model + with torch.device(device): + model = model_cls(**kwargs) + + # set device + model = model.to(dtype=dtype, device=device) + + # init tokenizer + if return_tokenizer: + tokenizer = HuggingfaceTokenizer(f"google/{name}", **tokenizer_kwargs) + return model, tokenizer + else: + return model + + +def umt5_xxl(**kwargs): + cfg = dict( + vocab_size=256384, + dim=4096, + dim_attn=4096, + dim_ffn=10240, + num_heads=64, + encoder_layers=24, + decoder_layers=24, + num_buckets=32, + shared_pos=False, + dropout=0.1, + ) + cfg.update(**kwargs) + return _t5("umt5-xxl", **cfg) + + +class T5EncoderModel: + def __init__( + self, + text_len, + dtype=torch.bfloat16, + device=torch.accelerator.current_device_index(), + checkpoint_path=None, + tokenizer_path=None, + shard_fn=None, + ): + self.text_len = text_len + self.dtype = dtype + self.device = device + self.checkpoint_path = checkpoint_path + self.tokenizer_path = tokenizer_path + + # init model + model = ( + umt5_xxl(encoder_only=True, return_tokenizer=False, dtype=dtype, device=device).eval().requires_grad_(False) + ) + logging.info(f"loading {checkpoint_path}") + model.load_state_dict(torch.load(checkpoint_path, map_location="cpu")) + self.model = model + if shard_fn is not None: + self.model = shard_fn(self.model, sync_module_states=False) + else: + self.model.to(self.device) + # init tokenizer + self.tokenizer = HuggingfaceTokenizer(name=tokenizer_path, seq_len=text_len, clean="whitespace") + + def __call__(self, texts, device): + ids, mask = self.tokenizer(texts, return_mask=True, add_special_tokens=True) + ids = ids.to(device) + mask = mask.to(device) + seq_lens = mask.gt(0).sum(dim=1).long() + context = self.model(ids, mask) + return [u[:v] for u, v in zip(context, seq_lens)] diff --git a/vllm_omni/diffusion/models/lingbot_world_fast/tokenizers.py b/vllm_omni/diffusion/models/lingbot_world_fast/tokenizers.py new file mode 100644 index 00000000000..e939b191c12 --- /dev/null +++ b/vllm_omni/diffusion/models/lingbot_world_fast/tokenizers.py @@ -0,0 +1,78 @@ +# Adapted from Lingbot-World/wan/modules/tokenizers.py +import html +import string + +import ftfy +import regex as re +from transformers import AutoTokenizer + +__all__ = ["HuggingfaceTokenizer"] + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r"\s+", " ", text) + text = text.strip() + return text + + +def canonicalize(text, keep_punctuation_exact_string=None): + text = text.replace("_", " ") + if keep_punctuation_exact_string: + text = keep_punctuation_exact_string.join( + part.translate(str.maketrans("", "", string.punctuation)) + for part in text.split(keep_punctuation_exact_string) + ) + else: + text = text.translate(str.maketrans("", "", string.punctuation)) + text = text.lower() + text = re.sub(r"\s+", " ", text) + return text.strip() + + +class HuggingfaceTokenizer: + def __init__(self, name, seq_len=None, clean=None, **kwargs): + assert clean in (None, "whitespace", "lower", "canonicalize") + self.name = name + self.seq_len = seq_len + self.clean = clean + + # init tokenizer + self.tokenizer = AutoTokenizer.from_pretrained(name, **kwargs) + self.vocab_size = self.tokenizer.vocab_size + + def __call__(self, sequence, **kwargs): + return_mask = kwargs.pop("return_mask", False) + + # arguments + _kwargs = {"return_tensors": "pt"} + if self.seq_len is not None: + _kwargs.update({"padding": "max_length", "truncation": True, "max_length": self.seq_len}) + _kwargs.update(**kwargs) + + # tokenization + if isinstance(sequence, str): + sequence = [sequence] + if self.clean: + sequence = [self._clean(u) for u in sequence] + ids = self.tokenizer(sequence, **_kwargs) + + # output + if return_mask: + return ids.input_ids, ids.attention_mask + else: + return ids.input_ids + + def _clean(self, text): + if self.clean == "whitespace": + text = whitespace_clean(basic_clean(text)) + elif self.clean == "lower": + text = whitespace_clean(basic_clean(text)).lower() + elif self.clean == "canonicalize": + text = canonicalize(basic_clean(text)) + return text diff --git a/vllm_omni/diffusion/models/lingbot_world_fast/vae2_1.py b/vllm_omni/diffusion/models/lingbot_world_fast/vae2_1.py new file mode 100644 index 00000000000..6017abe3476 --- /dev/null +++ b/vllm_omni/diffusion/models/lingbot_world_fast/vae2_1.py @@ -0,0 +1,610 @@ +# Adapted from Lingbot-World/wan/modules/vae2_1.py +import logging + +import torch +import torch.cuda.amp as amp +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + +__all__ = ["Wan2_1_VAE"] + +CACHE_T = 2 + + +class CausalConv3d(nn.Conv3d): + """ + Causal 3d convolution. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._padding = ( + self.padding[2], + self.padding[2], + self.padding[1], + self.padding[1], + 2 * self.padding[0], + 0, + ) + self.padding = (0, 0, 0) + + def forward(self, x, cache_x=None): + padding = list(self._padding) + if cache_x is not None and self._padding[4] > 0: + cache_x = cache_x.to(x.device) + x = torch.cat([cache_x, x], dim=2) + padding[4] -= cache_x.shape[2] + x = F.pad(x, padding) + + return super().forward(x) + + +class RMS_norm(nn.Module): + def __init__(self, dim, channel_first=True, images=True, bias=False): + super().__init__() + broadcastable_dims = (1, 1, 1) if not images else (1, 1) + shape = (dim, *broadcastable_dims) if channel_first else (dim,) + + self.channel_first = channel_first + self.scale = dim**0.5 + self.gamma = nn.Parameter(torch.ones(shape)) + self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0 + + def forward(self, x): + return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias + + +class Upsample(nn.Upsample): + def forward(self, x): + """ + Fix bfloat16 support for nearest neighbor interpolation. + """ + return super().forward(x.float()).type_as(x) + + +class Resample(nn.Module): + def __init__(self, dim, mode): + assert mode in ("none", "upsample2d", "upsample3d", "downsample2d", "downsample3d") + super().__init__() + self.dim = dim + self.mode = mode + + # layers + if mode == "upsample2d": + self.resample = nn.Sequential( + Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), + nn.Conv2d(dim, dim // 2, 3, padding=1), + ) + elif mode == "upsample3d": + self.resample = nn.Sequential( + Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), + nn.Conv2d(dim, dim // 2, 3, padding=1), + ) + self.time_conv = CausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0)) + + elif mode == "downsample2d": + self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))) + elif mode == "downsample3d": + self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))) + self.time_conv = CausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)) + + else: + self.resample = nn.Identity() + + def forward(self, x, feat_cache=None, feat_idx=[0]): + b, c, t, h, w = x.size() + if self.mode == "upsample3d": + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = "Rep" + feat_idx[0] += 1 + else: + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep": + # cache last frame of last two chunk + cache_x = torch.cat( + [feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2 + ) + if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] == "Rep": + cache_x = torch.cat([torch.zeros_like(cache_x).to(cache_x.device), cache_x], dim=2) + if feat_cache[idx] == "Rep": + x = self.time_conv(x) + else: + x = self.time_conv(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + + x = x.reshape(b, 2, c, t, h, w) + x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3) + x = x.reshape(b, c, t * 2, h, w) + t = x.shape[2] + x = rearrange(x, "b c t h w -> (b t) c h w") + x = self.resample(x) + x = rearrange(x, "(b t) c h w -> b c t h w", t=t) + + if self.mode == "downsample3d": + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = x.clone() + feat_idx[0] += 1 + else: + cache_x = x[:, :, -1:, :, :].clone() + x = self.time_conv(torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2)) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + return x + + def init_weight(self, conv): + conv_weight = conv.weight + nn.init.zeros_(conv_weight) + c1, c2, t, h, w = conv_weight.size() + one_matrix = torch.eye(c1, c2) + init_matrix = one_matrix + nn.init.zeros_(conv_weight) + conv_weight.data[:, :, 1, 0, 0] = init_matrix + conv.weight.data.copy_(conv_weight) + nn.init.zeros_(conv.bias.data) + + def init_weight2(self, conv): + conv_weight = conv.weight.data + nn.init.zeros_(conv_weight) + c1, c2, t, h, w = conv_weight.size() + init_matrix = torch.eye(c1 // 2, c2) + conv_weight[: c1 // 2, :, -1, 0, 0] = init_matrix + conv_weight[c1 // 2 :, :, -1, 0, 0] = init_matrix + conv.weight.data.copy_(conv_weight) + nn.init.zeros_(conv.bias.data) + + +class ResidualBlock(nn.Module): + def __init__(self, in_dim, out_dim, dropout=0.0): + super().__init__() + self.in_dim = in_dim + self.out_dim = out_dim + + # layers + self.residual = nn.Sequential( + RMS_norm(in_dim, images=False), + nn.SiLU(), + CausalConv3d(in_dim, out_dim, 3, padding=1), + RMS_norm(out_dim, images=False), + nn.SiLU(), + nn.Dropout(dropout), + CausalConv3d(out_dim, out_dim, 3, padding=1), + ) + self.shortcut = CausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity() + + def forward(self, x, feat_cache=None, feat_idx=[0]): + h = self.shortcut(x) + for layer in self.residual: + if isinstance(layer, CausalConv3d) and feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat( + [feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2 + ) + x = layer(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = layer(x) + return x + h + + +class AttentionBlock(nn.Module): + """ + Causal self-attention with a single head. + """ + + def __init__(self, dim): + super().__init__() + self.dim = dim + + # layers + self.norm = RMS_norm(dim) + self.to_qkv = nn.Conv2d(dim, dim * 3, 1) + self.proj = nn.Conv2d(dim, dim, 1) + + # zero out the last layer params + nn.init.zeros_(self.proj.weight) + + def forward(self, x): + identity = x + b, c, t, h, w = x.size() + x = rearrange(x, "b c t h w -> (b t) c h w") + x = self.norm(x) + # compute query, key, value + q, k, v = self.to_qkv(x).reshape(b * t, 1, c * 3, -1).permute(0, 1, 3, 2).contiguous().chunk(3, dim=-1) + + # apply attention + x = F.scaled_dot_product_attention(q, k, v) + x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w) + + # output + x = self.proj(x) + x = rearrange(x, "(b t) c h w-> b c t h w", t=t) + return x + identity + + +class Encoder3d(nn.Module): + def __init__( + self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_downsample=[True, True, False], + dropout=0.0, + ): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_downsample = temperal_downsample + + # dimensions + dims = [dim * u for u in [1] + dim_mult] + scale = 1.0 + + # init block + self.conv1 = CausalConv3d(3, dims[0], 3, padding=1) + + # downsample blocks + downsamples = [] + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + # residual (+attention) blocks + for _ in range(num_res_blocks): + downsamples.append(ResidualBlock(in_dim, out_dim, dropout)) + if scale in attn_scales: + downsamples.append(AttentionBlock(out_dim)) + in_dim = out_dim + + # downsample block + if i != len(dim_mult) - 1: + mode = "downsample3d" if temperal_downsample[i] else "downsample2d" + downsamples.append(Resample(out_dim, mode=mode)) + scale /= 2.0 + self.downsamples = nn.Sequential(*downsamples) + + # middle blocks + self.middle = nn.Sequential( + ResidualBlock(out_dim, out_dim, dropout), + AttentionBlock(out_dim), + ResidualBlock(out_dim, out_dim, dropout), + ) + + # output blocks + self.head = nn.Sequential( + RMS_norm(out_dim, images=False), nn.SiLU(), CausalConv3d(out_dim, z_dim, 3, padding=1) + ) + + def forward(self, x, feat_cache=None, feat_idx=[0]): + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + x = self.conv1(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv1(x) + + for layer in self.downsamples: + if feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + for layer in self.middle: + if isinstance(layer, ResidualBlock) and feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + for layer in self.head: + if isinstance(layer, CausalConv3d) and feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat( + [feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2 + ) + x = layer(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = layer(x) + return x + + +class Decoder3d(nn.Module): + def __init__( + self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temporal_upsample=[False, True, True], + dropout=0.0, + ): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temporal_upsample = temporal_upsample + + # dimensions + dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] + scale = 1.0 / 2 ** (len(dim_mult) - 2) + + # init block + self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1) + + # middle blocks + self.middle = nn.Sequential( + ResidualBlock(dims[0], dims[0], dropout), + AttentionBlock(dims[0]), + ResidualBlock(dims[0], dims[0], dropout), + ) + + # upsample blocks + upsamples = [] + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + # residual (+attention) blocks + if i == 1 or i == 2 or i == 3: + in_dim = in_dim // 2 + for _ in range(num_res_blocks + 1): + upsamples.append(ResidualBlock(in_dim, out_dim, dropout)) + if scale in attn_scales: + upsamples.append(AttentionBlock(out_dim)) + in_dim = out_dim + + # upsample block + if i != len(dim_mult) - 1: + mode = "upsample3d" if temporal_upsample[i] else "upsample2d" + upsamples.append(Resample(out_dim, mode=mode)) + scale *= 2.0 + self.upsamples = nn.Sequential(*upsamples) + + # output blocks + self.head = nn.Sequential(RMS_norm(out_dim, images=False), nn.SiLU(), CausalConv3d(out_dim, 3, 3, padding=1)) + + def forward(self, x, feat_cache=None, feat_idx=[0]): + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + x = self.conv1(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv1(x) + + for layer in self.middle: + if isinstance(layer, ResidualBlock) and feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + for layer in self.upsamples: + if feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + for layer in self.head: + if isinstance(layer, CausalConv3d) and feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat( + [feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2 + ) + x = layer(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = layer(x) + return x + + +def count_conv3d(model): + count = 0 + for m in model.modules(): + if isinstance(m, CausalConv3d): + count += 1 + return count + + +class WanVAE_(nn.Module): + def __init__( + self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_downsample=[True, True, False], + dropout=0.0, + ): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_downsample = temperal_downsample + self.temporal_upsample = temperal_downsample[::-1] + + # modules + self.encoder = Encoder3d( + dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout + ) + self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1) + self.conv2 = CausalConv3d(z_dim, z_dim, 1) + self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temporal_upsample, dropout) + + def forward(self, x): + mu, log_var = self.encode(x) + z = self.reparameterize(mu, log_var) + x_recon = self.decode(z) + return x_recon, mu, log_var + + def encode(self, x, scale): + self.clear_cache() + t = x.shape[2] + iter_ = 1 + (t - 1) // 4 + for i in range(iter_): + self._enc_conv_idx = [0] + if i == 0: + out = self.encoder(x[:, :, :1, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx) + else: + out_ = self.encoder( + x[:, :, 1 + 4 * (i - 1) : 1 + 4 * i, :, :], + feat_cache=self._enc_feat_map, + feat_idx=self._enc_conv_idx, + ) + out = torch.cat([out, out_], 2) + mu, log_var = self.conv1(out).chunk(2, dim=1) + if isinstance(scale[0], torch.Tensor): + mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(1, self.z_dim, 1, 1, 1) + else: + mu = (mu - scale[0]) * scale[1] + self.clear_cache() + return mu + + def decode(self, z, scale): + self.clear_cache() + if isinstance(scale[0], torch.Tensor): + z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(1, self.z_dim, 1, 1, 1) + else: + z = z / scale[1] + scale[0] + iter_ = z.shape[2] + x = self.conv2(z) + for i in range(iter_): + self._conv_idx = [0] + if i == 0: + out = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx) + else: + out_ = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx) + out = torch.cat([out, out_], 2) + self.clear_cache() + return out + + def reparameterize(self, mu, log_var): + std = torch.exp(0.5 * log_var) + eps = torch.randn_like(std) + return eps * std + mu + + def sample(self, imgs, deterministic=False): + mu, log_var = self.encode(imgs) + if deterministic: + return mu + std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0)) + return mu + std * torch.randn_like(std) + + def clear_cache(self): + self._conv_num = count_conv3d(self.decoder) + self._conv_idx = [0] + self._feat_map = [None] * self._conv_num + # cache encode + self._enc_conv_num = count_conv3d(self.encoder) + self._enc_conv_idx = [0] + self._enc_feat_map = [None] * self._enc_conv_num + + +def _video_vae(pretrained_path=None, z_dim=None, device="cpu", **kwargs): + """ + Autoencoder3d adapted from Stable Diffusion 1.x, 2.x and XL. + """ + cfg = dict( + dim=96, + z_dim=z_dim, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_downsample=[False, True, True], + dropout=0.0, + ) + cfg.update(**kwargs) + + # init model + with torch.device("meta"): + model = WanVAE_(**cfg) + + logging.info(f"loading {pretrained_path}") + model.load_state_dict(torch.load(pretrained_path, map_location=device), assign=True) + + return model + + +class Wan2_1_VAE: + def __init__(self, z_dim=16, vae_pth="cache/vae_step_411000.pth", dtype=torch.float, device="cuda"): + self.dtype = dtype + self.device = device + + mean = [ + -0.7571, + -0.7089, + -0.9113, + 0.1075, + -0.1745, + 0.9653, + -0.1517, + 1.5508, + 0.4134, + -0.0715, + 0.5517, + -0.3632, + -0.1922, + -0.9497, + 0.2503, + -0.2921, + ] + std = [ + 2.8184, + 1.4541, + 2.3275, + 2.6558, + 1.2196, + 1.7708, + 2.6052, + 2.0743, + 3.2687, + 2.1526, + 2.8652, + 1.5579, + 1.6382, + 1.1253, + 2.8251, + 1.9160, + ] + self.mean = torch.tensor(mean, dtype=dtype, device=device) + self.std = torch.tensor(std, dtype=dtype, device=device) + self.scale = [self.mean, 1.0 / self.std] + + # init model + self.model = _video_vae(pretrained_path=vae_pth, z_dim=z_dim).eval().requires_grad_(False).to(device) + + def encode(self, videos): + """ + videos: A list of videos each with shape [C, T, H, W]. + """ + with amp.autocast(dtype=self.dtype): + return [self.model.encode(u.unsqueeze(0), self.scale).float().squeeze(0) for u in videos] + + def decode(self, zs): + with amp.autocast(dtype=self.dtype): + return [self.model.decode(u.unsqueeze(0), self.scale).float().clamp_(-1, 1).squeeze(0) for u in zs] diff --git a/vllm_omni/diffusion/models/lingbot_world_fast/wan_fast.py b/vllm_omni/diffusion/models/lingbot_world_fast/wan_fast.py new file mode 100644 index 00000000000..89b05b65fe9 --- /dev/null +++ b/vllm_omni/diffusion/models/lingbot_world_fast/wan_fast.py @@ -0,0 +1,753 @@ +"""Some of the functions are borrowed from SelfForcing (https://github.com/guandeh17/Self-Forcing).""" + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as torch_F +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.modeling_utils import ModelMixin +from einops import rearrange +from vllm.model_executor.models.utils import PPMissingLayer +from vllm.sequence import IntermediateTensors + +from vllm_omni.diffusion.attention.layer import Attention +from vllm_omni.diffusion.distributed.parallel_state import ( + get_pipeline_parallel_rank, + get_pipeline_parallel_world_size, + is_pipeline_first_stage, + is_pipeline_last_stage, +) + +from .state_lingbot_world_fast import CacheIndex +from .wan_model import WanLayerNorm, WanRMSNorm, WanSelfAttention, rope_params, sinusoidal_embedding_1d + + +def causal_rope_apply(x, grid_sizes, freqs, start_frames=0): + """Apply causal rotary position embedding per batch row. + + start_frames: int or list[int] of per-row frame offsets. + An int broadcasts to all rows. + """ + n, c = x.size(2), x.size(3) // 2 + + # split freqs + freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1) + + if isinstance(start_frames, int): + start_frames = [start_frames] * grid_sizes.shape[0] + + # loop over samples + output = [] + + for i, (f, h, w) in enumerate(grid_sizes.tolist()): + sf = start_frames[i] + seq_len = f * h * w + + # precompute multipliers + x_i = torch.view_as_complex(x[i, :seq_len].to(torch.float64).reshape(seq_len, n, -1, 2)) + freqs_i = torch.cat( + [ + freqs[0][sf : sf + f].view(f, 1, 1, -1).expand(f, h, w, -1), + freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), + freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1), + ], + dim=-1, + ).reshape(seq_len, 1, -1) + + # apply rotary embedding + x_i = torch.view_as_real(x_i * freqs_i).flatten(2) + x_i = torch.cat([x_i, x[i, seq_len:]]) + + # append to collection + output.append(x_i) + return torch.stack(output).type_as(x) + + +class CausalWanSelfAttention(nn.Module): + def __init__(self, dim, num_heads, local_attn_size=-1, sink_size=0, qk_norm=True, eps=1e-6): + assert dim % num_heads == 0 + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.local_attn_size = local_attn_size + self.sink_size = sink_size + self.qk_norm = qk_norm + self.eps = eps + + # layers + self.q = nn.Linear(dim, dim) + self.k = nn.Linear(dim, dim) + self.v = nn.Linear(dim, dim) + self.o = nn.Linear(dim, dim) + self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity() + self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity() + + self.attn = Attention( + num_heads=self.num_heads, + head_size=self.head_dim, + num_kv_heads=self.num_heads, + softmax_scale=1.0 / (self.head_dim**0.5), + causal=False, + ) + + def forward( + self, + x, + seq_lens, + grid_sizes, + freqs, + kv_cache=None, + local_end_index=None, + global_end_index=None, + current_starts=0, + max_attention_size=1_000_000, + ): + r""" + Args: + x(Tensor): Shape [B, L, num_heads, C / num_heads] + grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W) + freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2] + current_starts(int | list[int]): per-row absolute token offset; an int + broadcasts to all rows. + """ + b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim + + if isinstance(current_starts, int): + current_starts = [current_starts] * b + assert len(current_starts) == b + + # query, key, value function + def qkv_fn(x): + q = self.norm_q(self.q(x)).view(b, s, n, d) + k = self.norm_k(self.k(x)).view(b, s, n, d) + v = self.v(x).view(b, s, n, d) + return q, k, v + + q, k, v = qkv_fn(x) + + frame_seqlen = math.prod(grid_sizes[0][1:]).item() + start_frames = [cs // frame_seqlen for cs in current_starts] + roped_query = causal_rope_apply(q, grid_sizes, freqs, start_frames=start_frames).type_as(v) + roped_key = causal_rope_apply(k, grid_sizes, freqs, start_frames=start_frames).type_as(v) + num_new_tokens = roped_query.shape[1] + + if self.local_attn_size != -1: + # Cache-rolling path only supports single-row processing. + assert b == 1, "local_attn_size != -1 requires batch_size=1" + current_start = current_starts[0] + current_end = current_start + num_new_tokens + sink_tokens = self.sink_size * frame_seqlen + kv_cache_size = kv_cache[CacheIndex.K].shape[1] + + if (current_end > global_end_index.item()) and ( + num_new_tokens + local_end_index.item() > kv_cache_size + ): + num_evicted_tokens = num_new_tokens + local_end_index.item() - kv_cache_size + num_rolled_tokens = local_end_index.item() - num_evicted_tokens - sink_tokens + kv_cache[CacheIndex.K][:, sink_tokens : sink_tokens + num_rolled_tokens] = kv_cache[CacheIndex.K][ + :, sink_tokens + num_evicted_tokens : sink_tokens + num_evicted_tokens + num_rolled_tokens + ].clone() + kv_cache[CacheIndex.V][:, sink_tokens : sink_tokens + num_rolled_tokens] = kv_cache[CacheIndex.V][ + :, sink_tokens + num_evicted_tokens : sink_tokens + num_evicted_tokens + num_rolled_tokens + ].clone() + new_local_end_index = ( + local_end_index.item() + current_end - global_end_index.item() - num_evicted_tokens + ) + else: + new_local_end_index = local_end_index.item() + current_end - global_end_index.item() + + local_start_index = new_local_end_index - num_new_tokens + kv_cache[CacheIndex.K][:, local_start_index:new_local_end_index] = roped_key + kv_cache[CacheIndex.V][:, local_start_index:new_local_end_index] = v + + k_cache = kv_cache[CacheIndex.K][:, max(0, new_local_end_index - max_attention_size) : new_local_end_index] + v_cache = kv_cache[CacheIndex.V][:, max(0, new_local_end_index - max_attention_size) : new_local_end_index] + out = self.attn(roped_query, k_cache, v_cache) + + global_end_index.fill_(current_end) + local_end_index.fill_(new_local_end_index) + else: + # local_attn_size == -1: per-row writes to non-overlapping cache slots, + # per-row attention reads sized by max_attention_size. Loops once per + # batch row inside attention to avoid needing a key-padding mask. + outs = [] + max_end = 0 + for i in range(b): + cs_i = current_starts[i] + ce_i = cs_i + num_new_tokens + kv_cache[CacheIndex.K][:, cs_i:ce_i] = roped_key[i : i + 1] + kv_cache[CacheIndex.V][:, cs_i:ce_i] = v[i : i + 1] + + kv_start_i = max(0, ce_i - max_attention_size) + k_cache_i = kv_cache[CacheIndex.K][:, kv_start_i:ce_i] + v_cache_i = kv_cache[CacheIndex.V][:, kv_start_i:ce_i] + + outs.append(self.attn(roped_query[i : i + 1], k_cache_i, v_cache_i)) + if ce_i > max_end: + max_end = ce_i + + out = torch.cat(outs, dim=0) + global_end_index.fill_(max_end) + local_end_index.fill_(max_end) + + # output + out = out.flatten(2) + out = self.o(out) + return out + + +class WanCrossAttention(WanSelfAttention): + def __init__(self, dim, num_heads, window_size=(-1, -1), qk_norm=True, eps=1e-6): + super().__init__(dim, num_heads, window_size, qk_norm, eps) + + self.attn = Attention( + num_heads=self.num_heads, + head_size=self.head_dim, + num_kv_heads=self.num_heads, + softmax_scale=1.0 / (self.head_dim**0.5), + causal=False, + ) + + def forward(self, x, context, context_lens, crossattn_cache=None): + r""" + Args: + x(Tensor): Shape [B, L1, C] + context(Tensor): Shape [B, L2, C] + context_lens(Tensor): Shape [B] + """ + b, n, d = x.size(0), self.num_heads, self.head_dim + + # compute query, key, value + q = self.norm_q(self.q(x)).view(b, -1, n, d) + + if crossattn_cache is not None: + if not crossattn_cache.get("is_init", False): + crossattn_cache["is_init"] = True + # Cache at B=1 (text context is shared across chunks in a batch); + # expand on retrieval to match q's batch size for variable-B calls. + k = self.norm_k(self.k(context[:1])).view(1, -1, n, d) + v = self.v(context[:1]).view(1, -1, n, d) + crossattn_cache[CacheIndex.K] = k + crossattn_cache[CacheIndex.V] = v + else: + k = crossattn_cache[CacheIndex.K] + v = crossattn_cache[CacheIndex.V] + if k.shape[0] != b: + k = k.expand(b, *k.shape[1:]) + v = v.expand(b, *v.shape[1:]) + else: + k = self.norm_k(self.k(context)).view(b, -1, n, d) + v = self.v(context).view(b, -1, n, d) + + # compute attention + x = self.attn(q, k, v) + + # output + x = x.flatten(2) + x = self.o(x) + return x + + +class CausalWanAttentionBlock(nn.Module): + def __init__( + self, dim, ffn_dim, num_heads, local_attn_size=-1, sink_size=0, qk_norm=True, cross_attn_norm=False, eps=1e-6 + ): + super().__init__() + self.dim = dim + self.ffn_dim = ffn_dim + self.num_heads = num_heads + self.local_attn_size = local_attn_size + self.qk_norm = qk_norm + self.cross_attn_norm = cross_attn_norm + self.eps = eps + + # layers + self.norm1 = WanLayerNorm(dim, eps) + self.self_attn = CausalWanSelfAttention( + dim=dim, num_heads=num_heads, local_attn_size=local_attn_size, sink_size=sink_size, qk_norm=qk_norm, eps=eps + ) + self.norm3 = WanLayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity() + self.cross_attn = WanCrossAttention(dim, num_heads, (-1, -1), qk_norm, eps) + self.norm2 = WanLayerNorm(dim, eps) + self.ffn = nn.Sequential(nn.Linear(dim, ffn_dim), nn.GELU(approximate="tanh"), nn.Linear(ffn_dim, dim)) + + # modulation + self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5) + + self.cam_injector_layer1 = nn.Linear(dim, dim) + self.cam_injector_layer2 = nn.Linear(dim, dim) + self.cam_scale_layer = nn.Linear(dim, dim) + self.cam_shift_layer = nn.Linear(dim, dim) + + def forward( + self, + x, + e, + seq_lens, + grid_sizes, + freqs, + context, + context_lens, + dit_cond_dict=None, + kv_cache=None, + local_end_index=None, + global_end_index=None, + crossattn_cache=None, + current_starts=0, + max_attention_size=1_000_000, + ): + r""" + Args: + x(Tensor): Shape [B, L, C] + e(Tensor): Shape [B, F, 6, C] + grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W) + freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2] + current_starts(int | list[int]): per-row absolute token offset; int broadcasts. + """ + assert e.dtype == torch.float32 + with torch.amp.autocast("cuda", dtype=torch.float32): + e = (self.modulation.unsqueeze(0) + e).chunk(6, dim=2) + assert e[0].dtype == torch.float32 + # self-attention + y = self.self_attn( + self.norm1(x).float() * (1 + e[1].squeeze(2)) + e[0].squeeze(2), + seq_lens, + grid_sizes, + freqs, + kv_cache, + local_end_index, + global_end_index, + current_starts, + max_attention_size, + ) + with torch.amp.autocast("cuda", dtype=torch.float32): + x = x + y * e[2].squeeze(2) + + # cam injection (only if dit_cond_dict is provided and contains c2ws_plucker_emb) + if dit_cond_dict is not None and "c2ws_plucker_emb" in dit_cond_dict: + c2ws_plucker_emb = dit_cond_dict["c2ws_plucker_emb"] + c2ws_hidden_states = self.cam_injector_layer2(torch_F.silu(self.cam_injector_layer1(c2ws_plucker_emb))) + c2ws_hidden_states = c2ws_hidden_states + c2ws_plucker_emb + cam_scale = self.cam_scale_layer(c2ws_hidden_states) + cam_shift = self.cam_shift_layer(c2ws_hidden_states) + x = (1.0 + cam_scale) * x + cam_shift + + # cross-attention & ffn function + def cross_attn_ffn(x, context, context_lens, e, crossattn_cache=None): + x = x + self.cross_attn(self.norm3(x), context, context_lens, crossattn_cache=crossattn_cache) + y = self.ffn(self.norm2(x).float() * (1 + e[4].squeeze(2)) + e[3].squeeze(2)) + with torch.amp.autocast("cuda", dtype=torch.float32): + x = x + y * e[5].squeeze(2) + return x + + x = cross_attn_ffn(x, context, context_lens, e, crossattn_cache) + return x + + +class CausalHead(nn.Module): + def __init__(self, dim, out_dim, patch_size, eps=1e-6): + super().__init__() + self.dim = dim + self.out_dim = out_dim + self.patch_size = patch_size + self.eps = eps + + # layers + out_dim = math.prod(patch_size) * out_dim + self.norm = WanLayerNorm(dim, eps) + self.head = nn.Linear(dim, out_dim) + + # modulation + self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5) + + def forward(self, x, e): + r""" + Args: + x(Tensor): Shape [B, L1, C] + e(Tensor): Shape [B, L1, C] + """ + assert e.dtype == torch.float32 + with torch.amp.autocast("cuda", dtype=torch.float32): + e = (self.modulation.unsqueeze(0) + e.unsqueeze(2)).chunk(2, dim=2) + x = self.head(self.norm(x) * (1 + e[1].squeeze(2)) + e[0].squeeze(2)) + return x + + +class WanModelFast(ModelMixin, ConfigMixin): + r""" + Wan diffusion backbone supporting both text-to-video and image-to-video. + """ + + ignore_for_config = ["patch_size", "cross_attn_norm", "qk_norm", "text_dim"] + _no_split_modules = ["WanAttentionBlock"] + + @register_to_config + def __init__( + self, + model_type="t2v", + control_type="cam", + patch_size=(1, 2, 2), + text_len=512, + in_dim=16, + dim=2048, + ffn_dim=8192, + freq_dim=256, + text_dim=4096, + out_dim=16, + num_heads=16, + num_layers=32, + local_attn_size=-1, + sink_size=0, + qk_norm=True, + cross_attn_norm=True, + eps=1e-6, + ): + r""" + Initialize the diffusion model backbone. + + Args: + model_type (`str`, *optional*, defaults to 't2v'): + Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video) + control_type (`str`, *optional*, defaults to 'cam'): + Type of conditioning control signal - 'cam' (6-dim camera Plucker + embeddings) or 'act' (7-dim action embeddings including WASD movement) + patch_size (`tuple`, *optional*, defaults to (1, 2, 2)): + 3D patch dimensions for video embedding (t_patch, h_patch, w_patch) + text_len (`int`, *optional*, defaults to 512): + Fixed length for text embeddings + in_dim (`int`, *optional*, defaults to 16): + Input video channels (C_in) + dim (`int`, *optional*, defaults to 2048): + Hidden dimension of the transformer + ffn_dim (`int`, *optional*, defaults to 8192): + Intermediate dimension in feed-forward network + freq_dim (`int`, *optional*, defaults to 256): + Dimension for sinusoidal time embeddings + text_dim (`int`, *optional*, defaults to 4096): + Input dimension for text embeddings + out_dim (`int`, *optional*, defaults to 16): + Output video channels (C_out) + num_heads (`int`, *optional*, defaults to 16): + Number of attention heads + num_layers (`int`, *optional*, defaults to 32): + Number of transformer blocks + local_attn_size (`int`, *optional*, defaults to -1): + Window size for temporal local attention (-1 indicates global attention) + sink_size (`int`, *optional*, defaults to 0): + Size of the attention sink, we keep the first `sink_size` frames unchanged when rolling the KV cache + qk_norm (`bool`, *optional*, defaults to True): + Enable query/key normalization + cross_attn_norm (`bool`, *optional*, defaults to False): + Enable cross-attention normalization + eps (`float`, *optional*, defaults to 1e-6): + Epsilon value for normalization layers + """ + + super().__init__() + + assert model_type in ["t2v", "i2v"] + self.model_type = model_type + + self.patch_size = patch_size + self.text_len = text_len + self.in_dim = in_dim + self.dim = dim + self.ffn_dim = ffn_dim + self.freq_dim = freq_dim + self.text_dim = text_dim + self.out_dim = out_dim + self.num_heads = num_heads + self.num_layers = num_layers + self.local_attn_size = local_attn_size + self.qk_norm = qk_norm + self.cross_attn_norm = cross_attn_norm + self.eps = eps + + if control_type == "cam": + control_dim = 6 + elif control_type == "act": + control_dim = 7 + + # embeddings + self.patch_embedding = nn.Conv3d(in_dim, dim, kernel_size=patch_size, stride=patch_size) + + self.patch_embedding_wancamctrl = nn.Linear( + control_dim * 64 * patch_size[0] * patch_size[1] * patch_size[2], dim + ) + self.c2ws_hidden_states_layer1 = nn.Linear(dim, dim) + self.c2ws_hidden_states_layer2 = nn.Linear(dim, dim) + + self.text_embedding = nn.Sequential(nn.Linear(text_dim, dim), nn.GELU(approximate="tanh"), nn.Linear(dim, dim)) + + self.time_embedding = nn.Sequential(nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim)) + self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6)) + + # blocks + self.blocks = nn.ModuleList( + [ + CausalWanAttentionBlock( + dim, ffn_dim, num_heads, local_attn_size, sink_size, qk_norm, cross_attn_norm, eps + ) + for _ in range(num_layers) + ] + ) + + # head + self.head = CausalHead(dim, out_dim, patch_size, eps) + + # PP layout — defaults to single-stage; apply_pp_split() refines after loading. + self.start_layer = 0 + self.end_layer = num_layers + + # buffers (don't use register_buffer otherwise dtype will be changed in to()) + assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0 + d = dim // num_heads + self.freqs = torch.cat( + [rope_params(1024, d - 4 * (d // 6)), rope_params(1024, 2 * (d // 6)), rope_params(1024, 2 * (d // 6))], + dim=1, + ) + + # initialize weights + self.init_weights() + + def apply_pp_split(self) -> None: + """Partition the model across PP ranks. Called after weight loading. + + After this returns, blocks outside this rank's [start_layer, end_layer) + slice are replaced with PPMissingLayer(); embeddings/head are kept only + on the first/last stage. KV-cache sizing (in the pipeline state) reads + end_layer - start_layer to allocate just for the owned slice. + """ + pp_world = get_pipeline_parallel_world_size() + if pp_world <= 1: + self.start_layer = 0 + self.end_layer = self.num_layers + return + + rank = get_pipeline_parallel_rank() + per_rank = self.num_layers // pp_world + rem = self.num_layers % pp_world + # Even split: extra layers go to the first `rem` ranks. + self.start_layer = rank * per_rank + min(rank, rem) + self.end_layer = self.start_layer + per_rank + (1 if rank < rem else 0) + + for i in range(self.num_layers): + if not (self.start_layer <= i < self.end_layer): + self.blocks[i] = PPMissingLayer() + + if not is_pipeline_first_stage(): + self.patch_embedding = PPMissingLayer() + self.patch_embedding_wancamctrl = PPMissingLayer() + self.c2ws_hidden_states_layer1 = PPMissingLayer() + self.c2ws_hidden_states_layer2 = PPMissingLayer() + + if not is_pipeline_last_stage(): + self.head = PPMissingLayer() + + def forward( + self, + x, + t, + context, + seq_len, + y=None, + dit_cond_dict=None, + kv_cache=None, + local_end_index=None, + global_end_index=None, + crossattn_cache=None, + current_starts=0, + max_attention_size=1_000_000, + intermediate_tensors: IntermediateTensors | None = None, + ): + r""" + Run the diffusion model with kv caching. + + On the first PP stage, ``x``/``y``/``dit_cond_dict`` are consumed to build + the token sequence; non-first stages take ``hidden_states`` (and the + camera-conditioned ``c2ws_plucker_emb`` if used) from ``intermediate_tensors``. + Non-last stages return an ``IntermediateTensors`` carrying ``hidden_states`` + (plus ``c2ws_plucker_emb`` so downstream stages can do cam injection). + + Args: + current_starts (int | list[int]): per-row absolute token offset. + int broadcasts to all rows. + intermediate_tensors: per-stage hidden state from the previous PP rank. + + Returns: + list[Tensor] on last PP stage; IntermediateTensors elsewhere. + """ + + if self.model_type == "i2v" and is_pipeline_first_stage(): + assert y is not None + + # params + first_stage = is_pipeline_first_stage() + last_stage = is_pipeline_last_stage() + # `freqs` lives as a plain attribute (not a buffer) — move it to the + # device of the first parameter we can find on this stage. + first_param = next(self.parameters()) + device = first_param.device + if self.freqs.device != device: + self.freqs = self.freqs.to(device) + + if first_stage: + if y is not None: + x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)] + x = [self.patch_embedding(u.unsqueeze(0)) for u in x] + grid_sizes = torch.stack( + [torch.tensor(u.shape[2:], dtype=torch.long, device=device) for u in x] + ) + x = [u.flatten(2).transpose(1, 2) for u in x] + seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long, device=device) + assert seq_lens.max() <= seq_len + x = torch.cat(x) + else: + assert intermediate_tensors is not None, "non-first PP stage requires intermediate_tensors" + x = intermediate_tensors["hidden_states"] + grid_sizes = intermediate_tensors["grid_sizes"] + seq_lens = intermediate_tensors["seq_lens"] + + B = x.shape[0] + s = x.shape[1] + + # Per-row time embeddings: same timestep replicated across this row's tokens. + with torch.amp.autocast("cuda", dtype=torch.float32): + if t.dim() == 1: + t_full = t.unsqueeze(1).expand(B, s).contiguous() + else: + t_full = t + bt, btn = t_full.shape + t_flat = t_full.flatten() + e = self.time_embedding( + sinusoidal_embedding_1d(self.freq_dim, t_flat).unflatten(0, (bt, btn)).float() + ) + e0 = self.time_projection(e).unflatten(2, (6, self.dim)) + assert e.dtype == torch.float32 and e0.dtype == torch.float32 + + # context — text embedding runs on every stage (each block has cross-attn). + context_lens = None + context = self.text_embedding( + torch.stack([torch.cat([u, u.new_zeros(self.text_len - u.size(0), u.size(1))]) for u in context]) + ) + + # cam Plucker — processed on first stage, then forwarded via intermediate_tensors + # so downstream stages re-use the same embedding for in-block cam injection. + if first_stage: + if dit_cond_dict is not None and "c2ws_plucker_emb" in dit_cond_dict: + c2ws_plucker_emb = dit_cond_dict["c2ws_plucker_emb"] + c2ws_plucker_emb = [ + rearrange( + i, + "1 c (f c1) (h c2) (w c3) -> 1 (f h w) (c c1 c2 c3)", + c1=self.patch_size[0], + c2=self.patch_size[1], + c3=self.patch_size[2], + ) + for i in c2ws_plucker_emb + ] + c2ws_plucker_emb = torch.cat(c2ws_plucker_emb, dim=0) + + c2ws_plucker_emb = self.patch_embedding_wancamctrl(c2ws_plucker_emb) + c2ws_hidden_states = self.c2ws_hidden_states_layer2( + torch_F.silu(self.c2ws_hidden_states_layer1(c2ws_plucker_emb)) + ) + dit_cond_dict = dict(dit_cond_dict) + dit_cond_dict["c2ws_plucker_emb"] = c2ws_plucker_emb + c2ws_hidden_states + else: + if "c2ws_plucker_emb" in intermediate_tensors.tensors: + dit_cond_dict = {"c2ws_plucker_emb": intermediate_tensors["c2ws_plucker_emb"]} + else: + dit_cond_dict = None + + kwargs = dict( + e=e0, + seq_lens=seq_lens, + grid_sizes=grid_sizes, + freqs=self.freqs, + context=context, + context_lens=context_lens, + dit_cond_dict=dit_cond_dict, + max_attention_size=max_attention_size, + ) + + # Iterate this rank's blocks. kv_cache / crossattn_cache / *_end_index are + # sized to (end_layer - start_layer) — index locally. + for local_idx, block in enumerate(self.blocks[self.start_layer : self.end_layer]): + kwargs.update( + { + "kv_cache": kv_cache[local_idx], + "crossattn_cache": crossattn_cache[local_idx], + "local_end_index": local_end_index[local_idx], + "global_end_index": global_end_index[local_idx], + "current_starts": current_starts, + } + ) + x = block(x, **kwargs) + + if not last_stage: + model_dtype = next(self.parameters()).dtype + it = { + "hidden_states": x.to(model_dtype), + "grid_sizes": grid_sizes, + "seq_lens": seq_lens, + } + if dit_cond_dict is not None and "c2ws_plucker_emb" in dit_cond_dict: + it["c2ws_plucker_emb"] = dit_cond_dict["c2ws_plucker_emb"].to(model_dtype) + return IntermediateTensors(it) + + # head + unpatchify only on the last PP stage + x = self.head(x, e) + x = self.unpatchify(x, grid_sizes) + return [u.float() for u in x] + + def unpatchify(self, x, grid_sizes): + r""" + Reconstruct video tensors from patch embeddings. + + Args: + x (List[Tensor]): + List of patchified features, each with shape [L, C_out * prod(patch_size)] + grid_sizes (Tensor): + Original spatial-temporal grid dimensions before patching, + shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches) + + Returns: + List[Tensor]: + Reconstructed video tensors with shape [C_out, F, H / 8, W / 8] + """ + + c = self.out_dim + out = [] + for u, v in zip(x, grid_sizes.tolist()): + u = u[: math.prod(v)].view(*v, *self.patch_size, c) + u = torch.einsum("fhwpqrc->cfphqwr", u) + u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)]) + out.append(u) + return out + + def init_weights(self): + r""" + Initialize model parameters using Xavier initialization. + """ + + # basic init + for m in self.modules(): + if isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.zeros_(m.bias) + + # init embeddings + nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1)) + for m in self.text_embedding.modules(): + if isinstance(m, nn.Linear): + nn.init.normal_(m.weight, std=0.02) + for m in self.time_embedding.modules(): + if isinstance(m, nn.Linear): + nn.init.normal_(m.weight, std=0.02) + + # init output layer + nn.init.zeros_(self.head.head.weight) diff --git a/vllm_omni/diffusion/models/lingbot_world_fast/wan_model.py b/vllm_omni/diffusion/models/lingbot_world_fast/wan_model.py new file mode 100644 index 00000000000..4a21edd4585 --- /dev/null +++ b/vllm_omni/diffusion/models/lingbot_world_fast/wan_model.py @@ -0,0 +1,95 @@ +# Adapted from Lingbot-World/wan/modules/model.py +# +# Only the building blocks used by wan_fast.py are kept here: norm layers, +# the self-attention __init__ shape (used as a base class for the local +# WanCrossAttention that overrides forward), and the rope / time-embedding +# helpers. The original `flash_attention`-based forward paths are not used by +# the Fast pipeline, so this file does not depend on wan.modules.attention. + +import torch +import torch.nn as nn + +__all__ = [ + "WanLayerNorm", + "WanRMSNorm", + "WanSelfAttention", + "rope_params", + "sinusoidal_embedding_1d", +] + + +def sinusoidal_embedding_1d(dim, position): + assert dim % 2 == 0 + half = dim // 2 + position = position.type(torch.float64) + + sinusoid = torch.outer(position, torch.pow(10000, -torch.arange(half).to(position).div(half))) + x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1) + return x + + +@torch.amp.autocast("cuda", enabled=False) +def rope_params(max_seq_len, dim, theta=10000): + assert dim % 2 == 0 + freqs = torch.outer( + torch.arange(max_seq_len), 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float64).div(dim)) + ) + freqs = torch.polar(torch.ones_like(freqs), freqs) + return freqs + + +class WanRMSNorm(nn.Module): + def __init__(self, dim, eps=1e-5): + super().__init__() + self.dim = dim + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def forward(self, x): + r""" + Args: + x(Tensor): Shape [B, L, C] + """ + return self._norm(x.float()).type_as(x) * self.weight + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps) + + +class WanLayerNorm(nn.LayerNorm): + def __init__(self, dim, eps=1e-6, elementwise_affine=False): + super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps) + + def forward(self, x): + r""" + Args: + x(Tensor): Shape [B, L, C] + """ + return super().forward(x.float()).type_as(x) + + +class WanSelfAttention(nn.Module): + """Base class providing the Q/K/V/O linear layers and (optional) QK RMSNorm. + + `wan_fast.py` only consumes this class via inheritance and `super().__init__()` + — the attention forward path is always overridden — so this trimmed copy + intentionally omits the flash-attention-based forward. + """ + + def __init__(self, dim, num_heads, window_size=(-1, -1), qk_norm=True, eps=1e-6): + assert dim % num_heads == 0 + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.window_size = window_size + self.qk_norm = qk_norm + self.eps = eps + + # layers + self.q = nn.Linear(dim, dim) + self.k = nn.Linear(dim, dim) + self.v = nn.Linear(dim, dim) + self.o = nn.Linear(dim, dim) + self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity() + self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity() diff --git a/vllm_omni/diffusion/models/schedulers/scheduling_flow_unipc_multistep.py b/vllm_omni/diffusion/models/schedulers/scheduling_flow_unipc_multistep.py index 3efe564bc61..fbf32b5c9f1 100644 --- a/vllm_omni/diffusion/models/schedulers/scheduling_flow_unipc_multistep.py +++ b/vllm_omni/diffusion/models/schedulers/scheduling_flow_unipc_multistep.py @@ -738,4 +738,4 @@ def add_noise( return noisy_samples def __len__(self) -> int: - return self.config.num_train_timesteps + return self.config.num_train_timesteps \ No newline at end of file diff --git a/vllm_omni/diffusion/models/wan2_2/__init__.py b/vllm_omni/diffusion/models/wan2_2/__init__.py index 4059c4fb568..2b8c37b47d0 100644 --- a/vllm_omni/diffusion/models/wan2_2/__init__.py +++ b/vllm_omni/diffusion/models/wan2_2/__init__.py @@ -53,4 +53,4 @@ "WanVACETransformer3DModel", ] -patch_wan_rms_norm() +patch_wan_rms_norm() \ No newline at end of file diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py index 2d8c752a4eb..0ee9e20fc8b 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py @@ -996,4 +996,4 @@ class WanT2VDMD2Pipeline(DMD2PipelineMixin, Wan22Pipeline): def __init__(self, *, od_config: OmniDiffusionConfig, prefix: str = ""): super().__init__(od_config=od_config, prefix=prefix) - self.__init_dmd2__() + self.__init_dmd2__() \ No newline at end of file diff --git a/vllm_omni/diffusion/models/wan2_2/scheduling_wan_euler.py b/vllm_omni/diffusion/models/wan2_2/scheduling_wan_euler.py index 25444044c2d..4a285b8b12f 100644 --- a/vllm_omni/diffusion/models/wan2_2/scheduling_wan_euler.py +++ b/vllm_omni/diffusion/models/wan2_2/scheduling_wan_euler.py @@ -144,4 +144,4 @@ def step( return WanEulerSchedulerOutput(prev_sample=prev_sample) def __len__(self) -> int: - return self.num_train_timesteps + return self.num_train_timesteps \ No newline at end of file diff --git a/vllm_omni/diffusion/registry.py b/vllm_omni/diffusion/registry.py index d8302c11501..e268b0c40d1 100644 --- a/vllm_omni/diffusion/registry.py +++ b/vllm_omni/diffusion/registry.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import importlib +from typing import Any import torch.nn as nn from vllm.logger import init_logger @@ -261,6 +262,11 @@ "pipeline_diffusers_adapter", "DiffusersAdapterPipeline", ), + "LingbotWorldFastPipeline": ( + "lingbot_world_fast", + "pipeline_lingbot_world_fast", + "LingbotWorldFastPipeline", + ), "HiDreamImagePipeline": ( "hidream_image", "pipeline_hidream_image", @@ -487,6 +493,7 @@ def _apply_sequence_parallel_if_enabled(model, od_config: OmniDiffusionConfig) - "OmniVoicePipeline": "get_omnivoice_post_process_func", "DreamIDOmniPipeline": "get_dreamid_omni_post_process_func", "SenseNovaU1Pipeline": "get_sensenova_u1_post_process_func", + "LingbotWorldFastPipeline": "get_lingbot_world_fast_post_process_func", "HiDreamImagePipeline": "get_hidream_image_post_process_func", } @@ -590,3 +597,26 @@ def get_diffusion_pre_process_func(od_config: OmniDiffusionConfig): return None # Return None if no pre-processing function is registered (for backward compatibility) func_name = _DIFFUSION_PRE_PROCESS_FUNCS[od_config.model_class_name] return _load_process_func(od_config, func_name) + + +_STREAM_BATCH_OVERRIDE_ATTRS: dict[str, str] = { + "chunk_frames": "STREAM_BATCH_CHUNK_FRAMES", + "num_inference_steps": "STREAM_BATCH_NUM_INFERENCE_STEPS", +} + +def apply_required_sampling_overrides(sampling: Any, model_class_name: str) -> None: + """Overwrite sampling-param fields that the model has hard requirements on.""" + pipeline_cls = DiffusionModelRegistry._try_load_model_cls(model_class_name) + if pipeline_cls is None: + return + for field, attr in _STREAM_BATCH_OVERRIDE_ATTRS.items(): + required = getattr(pipeline_cls, attr, None) + if required is None: + continue + current = getattr(sampling, field, None) + if current != required: + logger.warning( + "%s requires sampling.%s=%s, got %r. Overriding.", + model_class_name, field, required, current, + ) + setattr(sampling, field, required) diff --git a/vllm_omni/diffusion/sched/__init__.py b/vllm_omni/diffusion/sched/__init__.py index e0263733847..bbea1b417d3 100644 --- a/vllm_omni/diffusion/sched/__init__.py +++ b/vllm_omni/diffusion/sched/__init__.py @@ -11,6 +11,7 @@ ) from vllm_omni.diffusion.sched.request_scheduler import RequestScheduler from vllm_omni.diffusion.sched.step_scheduler import StepScheduler +from vllm_omni.diffusion.sched.stream_batch_scheduler import StreamBatchScheduler Scheduler = RequestScheduler @@ -23,5 +24,6 @@ "SchedulerInterface", "RequestScheduler", "StepScheduler", + "StreamBatchScheduler", "Scheduler", ] diff --git a/vllm_omni/diffusion/sched/base_scheduler.py b/vllm_omni/diffusion/sched/base_scheduler.py index 0bd16597f3b..22d0a25083c 100644 --- a/vllm_omni/diffusion/sched/base_scheduler.py +++ b/vllm_omni/diffusion/sched/base_scheduler.py @@ -51,6 +51,7 @@ def __init__(self) -> None: self._waiting: deque[str] = deque() self._running: list[str] = [] self._running_sampling_params_key: SamplingParamsKey | None = None + self._blocked: set[str] = set() self._finished_req_ids: set[str] = set() self.max_num_running_reqs: int = 1 @@ -62,6 +63,7 @@ def initialize(self, od_config: OmniDiffusionConfig) -> None: self._waiting.clear() self._running.clear() self._running_sampling_params_key = None + self._blocked.clear() self._finished_req_ids.clear() max_num_seqs = getattr(od_config, "max_num_seqs", 1) try: @@ -129,7 +131,27 @@ def schedule(self) -> DiffusionSchedulerOutput: return scheduler_output def has_requests(self) -> bool: - return bool(self._waiting or self._running) + return bool(self._waiting or self._running or self._blocked) + + def block_request(self, sched_req_id: str) -> bool: + """Move a RUNNING request to BLOCKED. In-flight work continues.""" + + if sched_req_id not in self._running: + return False + self._running.remove(sched_req_id) + self._blocked.add(sched_req_id) + self._request_states[sched_req_id].status = DiffusionRequestStatus.BLOCKED + return True + + def unblock_request(self, sched_req_id: str) -> bool: + """Move a BLOCKED request to WAITING.""" + + if sched_req_id not in self._blocked: + return False + self._blocked.discard(sched_req_id) + self._waiting.append(sched_req_id) + self._request_states[sched_req_id].status = DiffusionRequestStatus.WAITING + return True def get_request_state(self, sched_req_id: str) -> DiffusionRequestState | None: return self._request_states.get(sched_req_id) @@ -168,6 +190,7 @@ def close(self) -> None: self._waiting.clear() self._running.clear() self._running_sampling_params_key = None + self._blocked.clear() self._finished_req_ids.clear() self._reset_scheduler_state() @@ -182,6 +205,7 @@ def _finish_requests( finished_req_ids: set[str] = set() running_to_remove: set[str] = set() waiting_to_remove: set[str] = set() + blocked_to_remove: set[str] = set() for sched_req_id, status in statuses.items(): assert DiffusionRequestStatus.is_finished(status) @@ -194,6 +218,8 @@ def _finish_requests( running_to_remove.add(sched_req_id) if sched_req_id in self._waiting: waiting_to_remove.add(sched_req_id) + if sched_req_id in self._blocked: + blocked_to_remove.add(sched_req_id) if running_to_remove: self._running = [sched_req_id for sched_req_id in self._running if sched_req_id not in running_to_remove] @@ -203,6 +229,8 @@ def _finish_requests( self._waiting = deque( sched_req_id for sched_req_id in self._waiting if sched_req_id not in waiting_to_remove ) + if blocked_to_remove: + self._blocked -= blocked_to_remove for sched_req_id in finished_req_ids: state = self._request_states[sched_req_id] diff --git a/vllm_omni/diffusion/sched/interface.py b/vllm_omni/diffusion/sched/interface.py index 1bfb6945889..b486e9803ed 100644 --- a/vllm_omni/diffusion/sched/interface.py +++ b/vllm_omni/diffusion/sched/interface.py @@ -27,6 +27,7 @@ class DiffusionRequestStatus(enum.IntEnum): WAITING = enum.auto() RUNNING = enum.auto() PREEMPTED = enum.auto() + BLOCKED = enum.auto() # if any status is after or equal to FINISHED_COMPLETED, it is considered finished FINISHED_COMPLETED = enum.auto() @@ -111,6 +112,30 @@ def make_empty(cls) -> CachedRequestData: return cls(sched_req_ids=[]) +@dataclass +class Layout: + """How the previous latent should be sliced. + + - head [0:len(finished_idxs)] are chunks completing denoising (to decode) + - next [len(finished_idxs) : len(finished_idxs)+len(circulating_idxs)] are + re-admitted chunks + - rank 0 appends len(new_idxs) fresh randn rows at the tail before forwarding. + """ + + circulating_idxs: list[int] + finished_idxs: list[int] + new_idxs: list[int] + + +@dataclass +class RankTask: + """One unit of work for a rank in a stream-batch micro-step.""" + + sched_req_id: str + chunk_indices: list[int] + layout: Layout + + @dataclass class DiffusionSchedulerOutput: """Output of a single scheduling cycle.""" @@ -122,6 +147,9 @@ class DiffusionSchedulerOutput: num_running_reqs: int num_waiting_reqs: int + # stream-batch scheduling fields + assignment: list[RankTask] | None = None + @cached_property def scheduled_req_ids(self) -> list[str]: """ diff --git a/vllm_omni/diffusion/sched/stream_batch_scheduler.py b/vllm_omni/diffusion/sched/stream_batch_scheduler.py new file mode 100644 index 00000000000..9970372c976 --- /dev/null +++ b/vllm_omni/diffusion/sched/stream_batch_scheduler.py @@ -0,0 +1,318 @@ +"""Temporal-pipeline-parallel scheduler for streaming chunked diffusion. + +Each ``schedule()`` call corresponds to one micro-step. The pipeline is modeled +as ``pp_size`` per-rank chunk queues. At each schedule(), chunks at rank N-1 +drain (finished -> Layout finished slice, otherwise -> circulating back to +rank 0), queues shift one rank, and rank 0 receives the circulating chunks +plus B fresh admits up to the request's output chunk target. +""" + +from __future__ import annotations + +from collections import deque +from dataclasses import dataclass, field +from typing import TYPE_CHECKING + +from vllm.logger import init_logger + +from vllm_omni.diffusion.data import OmniDiffusionConfig +from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.diffusion.sched.base_scheduler import _BaseScheduler +from vllm_omni.diffusion.sched.interface import ( + DiffusionRequestStatus, + DiffusionSchedulerOutput, + Layout, + RankTask, +) + +if TYPE_CHECKING: + from vllm_omni.diffusion.worker.utils import RunnerOutput + +logger = init_logger(__name__) + + +@dataclass +class _InFlightChunk: + chunk_idx: int + steps_done: int = 0 + + +@dataclass +class _Progress: + sched_req_id: str + pp_size: int + chunk_frames: int + num_chunks: int + num_steps: int + + next_chunk_idx: int = 0 + batch_size: int = 0 + + # chunks that will be processed by rank r at the current micro-step + chunks_at: list[deque[_InFlightChunk]] = field(default_factory=list) + # rank r's layout — constructed at rank 0 and shifted forward each step + layouts_at: list[Layout] = field(default_factory=list) + + +@dataclass +class _SLOReqState: + slo_fps: float + max_batch: int + chunk_frames: int + batch_size: int = 1 + warmed_up: bool = False + + +class _SLOController: + """Per-step B_target adjustment for per-request chunk admission.""" + + SLACK_HEADROOM = 0.2 + + def __init__(self) -> None: + self._reqs: dict[str, _SLOReqState] = {} + + def register( + self, + sched_req_id: str, + slo_fps: float | None, + max_batch: int, + chunk_frames: int, + ) -> None: + if slo_fps is None or slo_fps <= 0: + return + self._reqs[sched_req_id] = _SLOReqState( + slo_fps=float(slo_fps), + max_batch=max(1, max_batch), + chunk_frames=max(1, chunk_frames), + ) + + def get_target(self, sched_req_id: str) -> int: + st = self._reqs.get(sched_req_id) + return st.batch_size if st is not None else 1 + + def mark_warmed_up(self, sched_req_id: str) -> None: + st = self._reqs.get(sched_req_id) + if st is not None: + st.warmed_up = True + + def observe(self, sched_req_id: str, latency_ns: int | None, b_current: int | None) -> None: + st = self._reqs.get(sched_req_id) + if st is None or not st.warmed_up or latency_ns is None or latency_ns <= 0 or b_current is None or b_current <= 0: + return + + budget = (b_current * st.chunk_frames / st.slo_fps) * 1e9 + if latency_ns > budget: + new_b = max(1, st.batch_size - 1) + elif latency_ns < budget * (1.0 - self.SLACK_HEADROOM) and st.batch_size < st.max_batch: + new_b = st.batch_size + 1 + else: + return + + if new_b != st.batch_size: + logger.info( + "SLO[%s]: B_target %d -> %d (latency=%.2fms budget=%.2fms)", + sched_req_id, st.batch_size, new_b, latency_ns / 1e6, budget / 1e6, + ) + st.batch_size = new_b + + def unregister(self, sched_req_id: str) -> None: + self._reqs.pop(sched_req_id, None) + + +class StreamBatchScheduler(_BaseScheduler): + """Temporal-PP scheduler driving chunked-streaming diffusion requests. + + Per micro-step: + 1. Promote waiting requests (handled by the base class). + 2. Drain rank N-1: finished chunks -> finished slice in + Layout, otherwise -> circulating back to rank 0. + 3. Shift per-rank queues by one (rank r <- rank r-1). + 4. Rank 0 = circulating + B fresh admits, where + `B = min(B_target, output_chunks_remaining)`. + 5. Emit per-rank assignment with Layout attached to every RankTask. + """ + + def __init__(self) -> None: + super().__init__() + self.pp_size: int = 1 + self._progress: dict[str, _Progress] = {} + self._slo: _SLOController = _SLOController() + + # ── Lifecycle ────────────────────────────────────────────────────────── + + def initialize(self, od_config: OmniDiffusionConfig) -> None: + super().initialize(od_config) + self.pp_size = od_config.parallel_config.pipeline_parallel_size + # TODO: support multiple requests + self.max_num_running_reqs = 1 + + def _reset_scheduler_state(self) -> None: + self._progress.clear() + self._slo = _SLOController() + + def _pop_extra_request_state(self, sched_req_id: str) -> None: + self._progress.pop(sched_req_id, None) + self._slo.unregister(sched_req_id) + + # ── Request admission ────────────────────────────────────────────────── + + def add_request(self, request: OmniDiffusionRequest) -> str: + sampling = request.sampling_params + if sampling.chunk_frames is None or sampling.chunk_frames <= 0: + raise ValueError( + f"chunk_frames must be a positive int when stream_batch=True, got {sampling.chunk_frames}" + ) + if sampling.num_chunks is None or sampling.num_chunks <= 0: + raise ValueError(f"num_chunks must be a positive int, got {sampling.num_chunks}") + if sampling.num_inference_steps is None or sampling.num_inference_steps <= 0: + raise ValueError( + f"num_inference_steps must be a positive int, got {sampling.num_inference_steps}" + ) + return super().add_request(request) + + # ── Scheduling ───────────────────────────────────────────────────────── + + def schedule(self) -> DiffusionSchedulerOutput: + base_output = super().schedule() + + for new_req in base_output.scheduled_new_reqs: + self._init_progress(new_req.sched_req_id, new_req.req) + + for progress in self._progress.values(): + self._advance_chunk_pipeline(progress) + + if self._progress: + base_output.assignment = self._build_assignment() + + logger.info( + "StreamBatchScheduler schedule: %d running req(s), assignment=%s", + len(self._running), base_output.assignment, + ) + + return base_output + + def _init_progress(self, sched_req_id: str, req: OmniDiffusionRequest) -> None: + sampling = req.sampling_params + chunk_frames = sampling.chunk_frames + num_chunks = sampling.num_chunks + num_steps = sampling.num_inference_steps + + self._progress[sched_req_id] = _Progress( + sched_req_id=sched_req_id, + chunk_frames=chunk_frames, + num_chunks=num_chunks, + num_steps=num_steps, + pp_size=self.pp_size, + chunks_at=[deque() for _ in range(self.pp_size)], + layouts_at=[ + Layout(circulating_idxs=[], finished_idxs=[], new_idxs=[]) + for _ in range(self.pp_size) + ], + ) + + self._slo.register( + sched_req_id=sched_req_id, + slo_fps=sampling.slo_fps, + max_batch=sampling.slo_max_batch, + chunk_frames=chunk_frames, + ) + + logger.debug( + "StreamBatchScheduler initialized progress for %s " + "(chunk_frames=%d, num_chunks=%d, num_steps=%d, slo_fps=%s, pp_size=%d)", + sched_req_id, chunk_frames, num_chunks, num_steps, sampling.slo_fps, self.pp_size, + ) + + def _advance_chunk_pipeline(self, progress: _Progress) -> None: + """Advance the per-rank queues and layouts by one micro-step.""" + + pp = progress.pp_size + + # 1. Drain last rank from previous step + finished_idxs: list[int] = [] + circulating: list[_InFlightChunk] = [] + last = progress.chunks_at[pp - 1] + while last: + chunk = last.popleft() + chunk.steps_done += 1 + if chunk.steps_done >= progress.num_steps: + finished_idxs.append(chunk.chunk_idx) + else: + circulating.append(chunk) + + # 2. Shift chunks and layouts: rank r receives what rank r-1 had + for r in range(pp - 1, 0, -1): + progress.chunks_at[r] = progress.chunks_at[r - 1] + progress.layouts_at[r] = progress.layouts_at[r - 1] + progress.chunks_at[0] = deque() + + # 3. Rank 0 = circulating + B fresh admits + for chunk in circulating: + progress.chunks_at[0].append(chunk) + + output_chunks_remaining = progress.num_chunks - progress.next_chunk_idx + b_target = self._slo.get_target(progress.sched_req_id) + batch_size = min(b_target, output_chunks_remaining) + + new_idxs: list[int] = [] + for _ in range(batch_size): + chunk_idx = progress.next_chunk_idx + progress.next_chunk_idx += 1 + progress.chunks_at[0].append(_InFlightChunk(chunk_idx=chunk_idx)) + new_idxs.append(chunk_idx) + progress.batch_size = batch_size + + # 4. Set rank 0's layout for this step. + progress.layouts_at[0] = Layout( + circulating_idxs=[c.chunk_idx for c in circulating], + finished_idxs=finished_idxs, + new_idxs=new_idxs, + ) + + if finished_idxs: + self._slo.mark_warmed_up(progress.sched_req_id) + + def _build_assignment(self) -> list[RankTask]: + assert len(self._progress) <= 1 #TODO: support multiple requests + assignment: list[RankTask] = [] + for progress in self._progress.values(): + for r in range(self.pp_size): + queue = progress.chunks_at[r] + assignment.append(RankTask( + sched_req_id=progress.sched_req_id, + chunk_indices=[c.chunk_idx for c in queue], + layout=progress.layouts_at[r], + )) + return assignment + + # ── Output processing ────────────────────────────────────────────────── + + def update_from_output( + self, sched_output: DiffusionSchedulerOutput, output: RunnerOutput + ) -> set[str]: + sched_req_ids = sched_output.scheduled_req_ids + if not sched_req_ids: + return set() + + assert len(sched_req_ids) == 1, "Multiple scheduled requests not supported" + + sched_req_id = output.req_id + + assert sched_req_id == sched_req_ids[0] + + progress = self._progress.get(sched_req_id) + if progress is not None and output.micro_step_wall_ns is not None: + self._slo.observe(sched_req_id, output.micro_step_wall_ns, progress.batch_size) + + terminal: dict[str, DiffusionRequestStatus] = {} + terminal_errors: dict[str, str | None] = {} + + if progress is not None: + err = output.result.error if output.result is not None else None + if err is not None: + terminal[sched_req_id] = DiffusionRequestStatus.FINISHED_ERROR + terminal_errors[sched_req_id] = err + elif output.finished: + terminal[sched_req_id] = DiffusionRequestStatus.FINISHED_COMPLETED + + return self._finalize_update_from_output(sched_output, terminal, terminal_errors) \ No newline at end of file diff --git a/vllm_omni/diffusion/worker/diffusion_model_runner.py b/vllm_omni/diffusion/worker/diffusion_model_runner.py index ab6dbb79afd..0219cd73a57 100644 --- a/vllm_omni/diffusion/worker/diffusion_model_runner.py +++ b/vllm_omni/diffusion/worker/diffusion_model_runner.py @@ -28,13 +28,14 @@ from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig from vllm_omni.diffusion.forward_context import set_forward_context from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader -from vllm_omni.diffusion.models.interface import supports_step_execution +from vllm_omni.diffusion.models.interface import supports_step_execution, supports_micro_step_execution from vllm_omni.diffusion.offloader import get_offload_backend from vllm_omni.diffusion.registry import _NO_CACHE_ACCELERATION from vllm_omni.diffusion.request import OmniDiffusionRequest -from vllm_omni.diffusion.sched.interface import DiffusionSchedulerOutput +from vllm_omni.diffusion.distributed.parallel_state import get_pp_group +from vllm_omni.diffusion.sched.interface import DiffusionSchedulerOutput, Layout from vllm_omni.diffusion.worker.input_batch import InputBatch, scatter_latents -from vllm_omni.diffusion.worker.utils import BatchRunnerOutput, DiffusionRequestState, RunnerOutput +from vllm_omni.diffusion.worker.utils import BatchRunnerOutput, DiffusionRequestState, RunnerOutput, ChunkState from vllm_omni.distributed.omni_connectors.kv_transfer_manager import OmniKVTransferManager from vllm_omni.platforms import current_omni_platform from vllm_omni.worker.omni_connector_model_runner_mixin import OmniConnectorModelRunnerMixin @@ -154,6 +155,13 @@ def get_memory_context(): "prepare_encode(), denoise_step(), step_scheduler(), and post_decode(); " f"{self.od_config.model_class_name} does not support that contract." ) + if getattr(self.od_config, "stream_batch", False) and not self.supports_micro_step_mode(): + raise ValueError( + "stream_batch=True requires a pipeline implementing the micro-step " + "execution protocol (prepare_encode, set_pp_recv_dict_buffers, " + "denoise_step, prefetch_tensors, step_scheduler, post_decode, encode_chunk_inputs); " + f"{self.od_config.model_class_name} does not support that contract." + ) # Apply CPU offloading self.offload_backend = get_offload_backend(self.od_config, device=self.device) @@ -486,3 +494,191 @@ def execute_stepwise(self, scheduler_output: DiffusionSchedulerOutput) -> BatchR self._update_states_after(states, input_batch, pipeline_interrupted) return BatchRunnerOutput.from_list(runner_output_list) + # ------------------------------------------------------------------ + # Temporal-PP micro-step execution + # ------------------------------------------------------------------ + + def supports_micro_step_mode(self) -> bool: + """Return whether current pipeline supports micro-step execution.""" + + return self.pipeline is not None and supports_micro_step_execution(self.pipeline) + + def execute_micro_step(self, sched_output: DiffusionSchedulerOutput) -> RunnerOutput: + """Execute one temporal-PP micro-step.""" + + assert self.pipeline is not None, "Model not loaded. Call load_model() first." + if not self.supports_micro_step_mode(): + raise ValueError("Current pipeline does not support micro-step execution.") + if self.od_config.cache_backend not in (None, "none"): + raise ValueError("Micro-step mode does not support cache_backend yet.") + + assignment = sched_output.assignment + if assignment is None: + raise ValueError("execute_micro_step requires assignment in sched_output.") + + use_hsdp = self.od_config.parallel_config.use_hsdp + grad_context = torch.no_grad() if use_hsdp else torch.inference_mode() + + with grad_context: + states, new_request_ids = self._update_states(sched_output) + if len(states) != 1: + raise ValueError( + f"Micro-step mode supports exactly one running request, got {len(states)}." + ) + state = states[0] + is_new_request = state.req_id in new_request_ids + + if is_new_request: + if state.sampling.generator is None and state.sampling.seed is not None: + gen_device = state.sampling.generator_device or ( + "cpu" if self.device.type == "cpu" else self.device + ) + state.sampling.generator = torch.Generator(device=gen_device).manual_seed(state.sampling.seed) + + with set_forward_context(vllm_config=self.vllm_config, omni_diffusion_config=self.od_config): + pp_group = get_pp_group() + task = assignment[pp_group.rank_in_group] + chunk_idxs = list(task.chunk_indices) + layout = task.layout + + if is_new_request: + state.extra.pop("chunks", None) + pp_group.reset_buffer() + self.pipeline.prepare_encode(state) + self.pipeline.set_pp_recv_dict_buffers(state) + + t_start_ns = time.perf_counter_ns() if pp_group.is_first_rank else None + result: DiffusionOutput | None = None + finished = False + + if pp_group.is_first_rank: + result = self._update_decoded_chunks(state, layout) + finished = result is not None + + if pp_group.is_first_rank or pp_group.is_last_rank: + self._prepare_chunk_latents(state, layout, is_first_rank=pp_group.is_first_rank) + + if chunk_idxs: + state.extra["current_chunk_idxs"] = chunk_idxs + + chunks: list[ChunkState] = [ + self._get_or_create_chunk(state, idx)[0] for idx in chunk_idxs + ] + + # Per-row timesteps + state.batched_timesteps = torch.stack( + [state.timesteps[c.step_index] for c in chunks] + ) + + batch_size = len(chunks) + noise_pred = self.pipeline.denoise_step(state, batch_size=batch_size) + + if noise_pred is None and getattr(self.pipeline, "interrupt", False): + self._update_state_after(state, layout, finished=True) + return RunnerOutput( + req_id=state.req_id, + finished=True, + result=DiffusionOutput(error="micro-step denoise interrupted"), + ) + + schedulers = [c.scheduler for c in chunks] + self.pipeline.step_scheduler( + state, noise_pred, per_request_scheduler=schedulers, batch_size=batch_size, + ) + + for c in chunks: + c.step_index += 1 + + if pp_group.is_last_rank: + for i, c in enumerate(chunks): + c.latents = state.latents[i : i + 1] + + prev_task = assignment[pp_group.group_prev_rank] + if prev_task.chunk_indices: + self.pipeline.prefetch_tensors(state, batch_size=len(prev_task.chunk_indices)) + + self._update_state_after(state, layout, finished=finished) + return RunnerOutput( + req_id=state.req_id, + finished=finished, + result=result, + micro_step_wall_ns=( + time.perf_counter_ns() - t_start_ns if t_start_ns is not None else None + ), + ) + + @staticmethod + def _get_or_create_chunk(state: DiffusionRequestState, chunk_idx: int) -> tuple[ChunkState, bool]: + chunks: dict[int, ChunkState] = state.extra.setdefault("chunks", {}) + chunk = chunks.get(chunk_idx) + if chunk is not None: + return chunk, chunk.step_index == 0 + chunk = ChunkState(idx=chunk_idx) + chunk.scheduler = copy.deepcopy(state.scheduler) + chunks[chunk_idx] = chunk + return chunk, True + + def _prepare_chunk_latents(self, state: DiffusionRequestState, layout: Layout, is_first_rank: bool): + pieces: list[torch.Tensor] = [] + + n_finished = len(layout.finished_idxs) + + for i, idx in enumerate(layout.circulating_idxs): + chunk, _ = self._get_or_create_chunk(state, idx) + if is_first_rank: + chunk.latents = state.latents[n_finished + i : n_finished + i + 1] + pieces.append(chunk.latents) + + if layout.new_idxs: + for idx in layout.new_idxs: + self._get_or_create_chunk(state, idx) + encoded = self.pipeline.encode_chunk_inputs(state, layout.new_idxs) + for i, idx in enumerate(layout.new_idxs): + state.extra["chunks"][idx].latents = encoded[i : i + 1] + pieces.append(encoded) + + state.latents = torch.cat(pieces, dim=0) if pieces else None + + def _update_decoded_chunks(self, state: DiffusionRequestState, layout: Layout) -> DiffusionOutput | None: + n_finished = len(layout.finished_idxs) + if n_finished > 0: + saved = state.latents + state.latents = state.latents[:n_finished] + decoded = self.pipeline.post_decode(state) + state.latents = saved + + state.extra.setdefault("decoded_chunks", []).append(decoded) + state.extra["chunks_decoded"] = ( + state.extra.get("chunks_decoded", 0) + n_finished + ) + + if state.extra.get("chunks_decoded", 0) >= state.sampling.num_chunks: + return self._merge_chunk_outputs(state.extra["decoded_chunks"]) + + return None + + def _update_state_after(self, state: DiffusionRequestState, layout: Layout, finished: bool = False): + for idx in layout.finished_idxs: + state.extra.get("chunks", {}).pop(idx, None) + + if finished: + self.state_cache.pop(state.req_id, None) + + @staticmethod + def _merge_chunk_outputs(chunks: list[DiffusionOutput]) -> DiffusionOutput: + """Merge decoded chunk outputs along the temporal axis. + + Supports both: + - 5D ``[B, C, T, H, W]`` (Wan-style): time axis = dim 2. + - 4D ``[C, T, H, W]`` (lingbot-style): time axis = dim 1. + + NOTE: This is a temporary solution until streaming output is supported. + """ + + try: + outputs = [c.output for c in chunks] + time_dim = outputs[0].dim() - 3 + merged = torch.cat(outputs, dim=time_dim) + except Exception as e: + return DiffusionOutput(error=f"Failed to merge {len(chunks)} chunk outputs: {e}") + return DiffusionOutput(output=merged) \ No newline at end of file diff --git a/vllm_omni/diffusion/worker/diffusion_worker.py b/vllm_omni/diffusion/worker/diffusion_worker.py index cd5fe246d94..a3855887f4f 100644 --- a/vllm_omni/diffusion/worker/diffusion_worker.py +++ b/vllm_omni/diffusion/worker/diffusion_worker.py @@ -380,6 +380,20 @@ def execute_stepwise(self, scheduler_output: DiffusionSchedulerOutput) -> BaseRu profiler.step() return output + def execute_micro_step(self, scheduler_output: DiffusionSchedulerOutput) -> BaseRunnerOutput: + """Execute one temporal-PP micro-step by delegating to the model runner.""" + assert self.model_runner is not None, "Model runner not initialized" + if self.lora_manager is not None: + self.lora_manager.set_active_adapter(None) + if any(new_req.req.sampling_params.lora_request is not None for new_req in scheduler_output.scheduled_new_reqs): + raise ValueError("Stream-batch mode does not support LoRA yet.") + profiler = self._get_profiler() + ctx = profiler.annotate_context_manager("diffusion_micro_step") if profiler else nullcontext() + with ctx: + output = self.model_runner.execute_micro_step(scheduler_output) + if profiler: + profiler.step() + return output def _activate_step_lora(self, scheduler_output: DiffusionSchedulerOutput) -> None: """Activate the LoRA adapter for the scheduled step batch. @@ -986,6 +1000,10 @@ def execute_stepwise(self, scheduler_output: DiffusionSchedulerOutput) -> BaseRu """Execute one diffusion step.""" return self.worker.execute_stepwise(scheduler_output) + def execute_micro_step(self, scheduler_output: DiffusionSchedulerOutput) -> BaseRunnerOutput: + """Execute one temporal-PP micro-step.""" + return self.worker.execute_micro_step(scheduler_output) + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: """ Load model weights. diff --git a/vllm_omni/diffusion/worker/utils.py b/vllm_omni/diffusion/worker/utils.py index 1e98a7784ab..a214a6aa6dd 100644 --- a/vllm_omni/diffusion/worker/utils.py +++ b/vllm_omni/diffusion/worker/utils.py @@ -5,6 +5,7 @@ from __future__ import annotations from abc import ABC, abstractmethod +from collections.abc import Iterator from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any @@ -55,6 +56,8 @@ class DiffusionRequestState: timesteps: torch.Tensor | list[torch.Tensor] | None = None step_index: int = 0 + batched_timesteps: torch.Tensor | None = None + # ── Per-request scheduler instance (set once by prepare_encode) ── scheduler: Any | None = None @@ -110,6 +113,21 @@ def new_request(self) -> bool: return self.step_index == 0 or self.timesteps is None +@dataclass +class ChunkState: + """Per-chunk state for one in-flight chunk of a streaming request. + + Lives inside ``DiffusionRequestState.extra["chunks"]`` (keyed by + ``chunk_idx``). + """ + + idx: int + latents: torch.Tensor | None = None + step_index: int = 0 + scheduler: Any | None = None + extra: dict[str, Any] = field(default_factory=dict) + + class BaseRunnerOutput(ABC): @abstractmethod def get_req_output(self, sched_req_id: str) -> RunnerOutput | None: @@ -118,11 +136,15 @@ def get_req_output(self, sched_req_id: str) -> RunnerOutput | None: @dataclass class RunnerOutput(BaseRunnerOutput): - """Output of a single denoising step for a request. + """Output of a single execution step for a request. - NOTE: `latents` may be None when returned through IPC to avoid - serialization overhead. The actual latents are kept in Worker's - _request_state_cache. + Each scheduler reads the fields it needs: + + - ``StepScheduler`` reads ``step_index`` / ``finished``. + - ``StreamBatchScheduler`` reads ``finished`` / ``result`` / + ``micro_step_wall_ns``. + + Fields not relevant to an execution path are left as ``None`` / ``False``. """ req_id: str @@ -130,6 +152,9 @@ class RunnerOutput(BaseRunnerOutput): finished: bool = False result: DiffusionOutput | None = None + # ── Temporal-PP micro-step fields ── + micro_step_wall_ns: int | None = None + def get_req_output(self, sched_req_id: str) -> RunnerOutput | None: return self if self.req_id == sched_req_id else None @@ -159,4 +184,4 @@ def __len__(self) -> int: @classmethod def from_list(cls, runner_output_list: list[RunnerOutput]) -> BatchRunnerOutput: - return cls(runner_outputs=runner_output_list) + return cls(runner_outputs=runner_output_list) \ No newline at end of file diff --git a/vllm_omni/engine/async_omni_engine.py b/vllm_omni/engine/async_omni_engine.py index 23ef9b85567..90384333772 100644 --- a/vllm_omni/engine/async_omni_engine.py +++ b/vllm_omni/engine/async_omni_engine.py @@ -1922,6 +1922,7 @@ def _create_default_diffusion_stage_cfg(kwargs: dict[str, Any]) -> list: "model_class_name": kwargs.get("model_class_name", None), "additional_config": kwargs.get("additional_config", None), "step_execution": kwargs.get("step_execution", False), + "stream_batch": kwargs.get("stream_batch", False), "vae_use_slicing": kwargs.get("vae_use_slicing", False), "vae_use_tiling": kwargs.get("vae_use_tiling", False), "cache_backend": cache_backend, diff --git a/vllm_omni/entrypoints/cli/serve.py b/vllm_omni/entrypoints/cli/serve.py index cc6e9a4dabb..a91bf5345eb 100644 --- a/vllm_omni/entrypoints/cli/serve.py +++ b/vllm_omni/entrypoints/cli/serve.py @@ -651,6 +651,17 @@ def subparser_init(self, subparsers: argparse._SubParsersAction) -> FlexibleArgu help="Enable AR stage profiler to include AR stage timing in stage_durations.", ) + omni_config_group.add_argument( + "--ws-max-size", + type=int, + default=1_048_576, # 1MB + help="Change max size of a websocket payload that is accepted by the server", + ) + omni_config_group.add_argument( + "--ws", + default="auto", + help="Set the websocket Protocol type", + ) # Supplementary auxiliary text encoder parameters # (e.g., the meta llama/meta llama-3.1-8b-instrument used by hidream) omni_config_group.add_argument( diff --git a/vllm_omni/entrypoints/openai/api_server.py b/vllm_omni/entrypoints/openai/api_server.py index 69a6dd603ed..3520e5f301c 100644 --- a/vllm_omni/entrypoints/openai/api_server.py +++ b/vllm_omni/entrypoints/openai/api_server.py @@ -115,6 +115,7 @@ VideoListResponse, VideoResponse, ) +from vllm_omni.entrypoints.openai.realtime.world.camera_serving import ServingRealtimeWorldCamera from vllm_omni.entrypoints.openai.realtime_connection import RealtimeConnection from vllm_omni.entrypoints.openai.serving_audio_generate import OmniOpenAIServingAudioGenerate from vllm_omni.entrypoints.openai.serving_chat import OmniOpenAIServingChat @@ -377,6 +378,11 @@ async def omni_run_server_worker(listen_address, sock, args, client_config=None, if log_config is not None: uvicorn_kwargs["log_config"] = log_config + if args.ws_max_size is not None: + uvicorn_kwargs["ws_max_size"] = args.ws_max_size + if args.ws is not None: + uvicorn_kwargs["ws"] = args.ws + async with build_async_omni( args, client_config=client_config, @@ -631,6 +637,10 @@ async def omni_init_app_state( state.openai_streaming_speech = None state.openai_streaming_video = None + state.openai_serving_world_camera = ServingRealtimeWorldCamera.create_policy_server( + engine_client=engine_client, model_name=model_name + ) + state.enable_server_load_tracking = getattr(args, "enable_server_load_tracking", False) state.server_load_metrics = 0 logger.info("Pure diffusion API server initialized for model: %s", model_name) @@ -948,6 +958,10 @@ async def omni_init_app_state( stage_configs=state.stage_configs, ) + state.openai_serving_world_camera = ServingRealtimeWorldCamera.create_policy_server( + engine_client=engine_client, model_name=model_name + ) + state.enable_server_load_tracking = args.enable_server_load_tracking state.server_load_metrics = 0 @@ -1406,6 +1420,21 @@ async def realtime_websocket(websocket: WebSocket): await connection.handle_connection() +@router.websocket("/v1/realtime/world/camera") +async def realtime_world_camera_openpi(websocket: WebSocket): + from vllm_omni.entrypoints.openai.realtime.world.camera_connection import WorldCameraRealtimeConnection + + serving = getattr(websocket.app.state, "openai_serving_world_camera", None) + + if serving is None: + await websocket.accept() + await websocket.send_json({"type": "error", "error": "World Model policy not available", "code": "unsupported"}) + await websocket.close() + return + connection = WorldCameraRealtimeConnection(websocket, serving) + await connection.handle_connection() + + # Health and Model endpoints for diffusion mode diff --git a/vllm_omni/entrypoints/openai/realtime/world/camera_connection.py b/vllm_omni/entrypoints/openai/realtime/world/camera_connection.py new file mode 100644 index 00000000000..01c99d67a03 --- /dev/null +++ b/vllm_omni/entrypoints/openai/realtime/world/camera_connection.py @@ -0,0 +1,164 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""WebSocket connection for robot policy inference (OpenPI protocol). + +Protocol (compatible with DreamZero test_client_AR.py): + Connect -> server sends msgpack(PolicyServerConfig fields) + Infer -> client sends msgpack(req), server sends msgpack(ndarray) + Reset -> client sends msgpack({endpoint:reset}), server sends "reset successful" +""" + +from __future__ import annotations + +import asyncio +from typing import Any + +import torch +from fastapi import WebSocket +from starlette.websockets import WebSocketDisconnect +from vllm.logger import init_logger + +from vllm_omni.entrypoints.openai.realtime.world.camera_serving import ServingRealtimeWorldCamera +from vllm_omni.entrypoints.openai.video_api_utils import _normalize_frames + +logger = init_logger(__name__) +_DEFAULT_IDLE_TIMEOUT = 30.0 +CHUNK_FRAMES = 4 + + +def _get_msgpack_numpy() -> Any: + try: + from openpi_client import msgpack_numpy + except ImportError as exc: + raise ImportError( + "The `/v1/realtime/world/camera` endpoint requires the optional " + "`openpi-client` dependency. Install it with `pip install openpi-client`." + ) from exc + + return msgpack_numpy + + +def _pack(obj: Any) -> bytes: + return _get_msgpack_numpy().packb(obj) + + +def _unpack(data: bytes) -> Any: + return _get_msgpack_numpy().unpackb(data) + + +class WorldCameraRealtimeConnection: + """WebSocket connection for world model inference.""" + + def __init__( + self, + websocket: WebSocket, + serving: ServingRealtimeWorldCamera, + idle_timeout: float = _DEFAULT_IDLE_TIMEOUT, + ) -> None: + self.websocket = websocket + self.serving = serving + self._idle_timeout = idle_timeout + + async def _send_error(self, message: str) -> None: + await self.websocket.send_bytes(_pack({"type": "error", "message": message})) + + def _unpack_request(self, data: bytes) -> dict[str, Any]: + req = _unpack(data) + if not isinstance(req, dict): + raise ValueError("Invalid request payload") + return req + + async def handle_connection(self) -> None: + """Main loop.""" + await self.websocket.accept() + + try: + # Send model-specific PolicyServerConfig resolved by serving from + # diffusion od_config.model_config. + metadata = self.serving.policy_server_config.to_dict() + await self.websocket.send_bytes(_pack(metadata)) + + while True: + try: + msg = await asyncio.wait_for( + self.websocket.receive(), + timeout=self._idle_timeout, + ) + except asyncio.TimeoutError: + logger.info("World Model OpenPI connection idle timeout after %.1f seconds", self._idle_timeout) + try: + await self.websocket.close() + except Exception: + logger.debug("Failed to close idle World Model websocket", exc_info=True) + return + + if msg.get("type") == "websocket.disconnect": + break + + if "bytes" not in msg or not msg["bytes"]: + continue + + try: + req = self._unpack_request(msg["bytes"]) + except Exception: + logger.exception("Invalid world model OpenPI request payload") + try: + await self._send_error("Invalid request payload") + except Exception: + break + continue + + try: + endpoint = req.pop("endpoint", "infer") + + if endpoint == "reset": + self.serving.reset(req) + await self.websocket.send_text("reset successful") + else: + result = await self.serving.infer(req) + + if ( + len(result.images) == 1 + and isinstance(result.images[0], tuple) + and len(result.images[0]) == 1 + ): + frames = result.images[0] + elif len(result.images) == 1 and isinstance(result.images[0], dict): + frames = result.images[0].get("frames") or result.images[0].get("video") + else: + frames = result.images + + if len(frames) == 1: + frames = frames[0] + + if isinstance(frames, torch.Tensor): + frames = frames.numpy(force=True) + + frames = _normalize_frames(frames) + + total = (len(frames) + CHUNK_FRAMES - 1) // CHUNK_FRAMES + for i in range(total): + chunk = frames[i * CHUNK_FRAMES : (i + 1) * CHUNK_FRAMES] + await self.websocket.send_bytes( + _pack( + { + "type": "frame", + "index": i, + "total": total, + "video": chunk, + } + ) + ) + + except Exception: + logger.exception("Error handling request") + try: + await self._send_error("Internal inference error") + except Exception: + break + + except WebSocketDisconnect: + pass + except Exception: + logger.exception("Connection error") diff --git a/vllm_omni/entrypoints/openai/realtime/world/camera_serving.py b/vllm_omni/entrypoints/openai/realtime/world/camera_serving.py new file mode 100644 index 00000000000..4ac8a6c4a1f --- /dev/null +++ b/vllm_omni/entrypoints/openai/realtime/world/camera_serving.py @@ -0,0 +1,186 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Protocol structs for the /v1/realtime/world/* family of endpoints. + +These are msgpack-serialised over the WebSocket wire via ``msgspec.msgpack``. +""" + +from __future__ import annotations + +from collections.abc import Mapping +from dataclasses import dataclass +from typing import Any + +import numpy as np +from omegaconf import OmegaConf +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +def _to_builtin_container(value: Any) -> Any: + if OmegaConf.is_config(value): + return OmegaConf.to_container(value, resolve=True) + if isinstance(value, Mapping): + return {key: _to_builtin_container(item) for key, item in value.items()} + if isinstance(value, (list, tuple)): + return [_to_builtin_container(item) for item in value] + return value + + +@dataclass(frozen=True) +class CameraServerConfig: + """Static server-side camera/pipeline parameters sent to a client on connect.""" + + values: dict[str, Any] + + @classmethod + def from_model_config(cls, model_config: Any) -> CameraServerConfig: + return cls(_to_builtin_container(model_config)) + + def to_dict(self) -> dict[str, Any]: + return _to_builtin_container(self.values) + + +class ServingRealtimeWorldCamera: + """World Model Camera serving layer for OpenPI protocol. + + Model-specific transform/state lives in the diffusion pipeline. + """ + + def __init__( + self, + engine_client: Any, + model_name: str | None = None, + ) -> None: + self.engine_client = engine_client + self.model_name = model_name + self._current_session_id: str | None = None + self._call_count = 0 + self.policy_server_config = self._get_policy_server_config(engine_client) + self._force_reset = False + + @classmethod + def create_policy_server( + cls, + engine_client: Any, + model_name: str | None = None, + ) -> ServingRealtimeWorldCamera | None: + try: + return cls(engine_client=engine_client, model_name=model_name) + except ValueError as exc: + if "policy_server_config" not in str(exc): + raise + logger.info("World Model OpenPI serving disabled for model %s", model_name) + return None + + @staticmethod + def _get_policy_server_config(engine_client: Any) -> CameraServerConfig: + model_config = None + get_od_config = getattr(engine_client, "get_diffusion_od_config", None) + if callable(get_od_config): + od_config = get_od_config() + model_config = getattr(od_config, "model_config", None) + + if model_config is None: + for stage_config in getattr(engine_client, "stage_configs", []) or []: + if getattr(stage_config, "stage_type", None) != "diffusion": + continue + engine_args = getattr(stage_config, "engine_args", None) + model_config = getattr(engine_args, "model_config", None) + if model_config is not None: + break + + if model_config is None: + od_config = getattr(engine_client, "od_config", None) + model_config = getattr(od_config, "model_config", None) + + if model_config is None: + model_config = getattr(engine_client, "model_config", None) + + return CameraServerConfig.from_model_config(model_config) + + def reset(self, req: dict) -> None: + """Reset serving state. + + Engine-side Lingbot state is reset on the next inference request via + `extra_args["reset"]`, not by an immediate websocket-side RPC. + """ + self._current_session_id = None + self._force_reset = True + + async def infer(self, req: dict) -> np.ndarray: + """raw req → engine → video.""" + # Session tracking + + session_id = req.get("session_id") + if session_id is not None and session_id != self._current_session_id: + if self._current_session_id is not None: + logger.info("Session changed %s → %s", self._current_session_id, session_id) + self.reset({}) + self._current_session_id = session_id + + self._call_count += 1 + + # Build request, run inference through AsyncOmni + request = self._build_request(req) + + # After an inference call we reset the _force_reset argument + self._force_reset = False + + result = None + # OpenPI policy serving is one request -> one action reply. AsyncOmni + # exposes an async iterator, so consume it to completion and use the + # final output, matching other non-streaming OpenAI serving paths. + async for output in self.engine_client.generate( + prompt=request.prompts[0], + request_id=request.request_ids[0], + sampling_params_list=[request.sampling_params], + ): + result = output + if result is None: + raise RuntimeError("World Model Camera OpenPI request produced no output.") + + return result + + def _build_request(self, req: dict) -> Any: + """Build engine request from raw robot req. + + Returns an `OmniDiffusionRequest` payload consumed by + `AsyncOmni.generate()` and routed to the diffusion stage. + """ + from vllm_omni.diffusion.request import OmniDiffusionRequest + from vllm_omni.inputs.data import OmniDiffusionSamplingParams + + extra_args = {"session_id": self._current_session_id or "default", "force_reset": self._force_reset} + + camera = req.get("camera", None) + + multi_modal_data = { + "image": req.get("image", None), + "camera": camera, + } + + prompt = req.get("prompt", "") + + extra_body = req.get("extra_body", {}) + + height = extra_body.get("height", None) + width = extra_body.get("width", None) + num_frames = extra_body.get("num_frames", None) + fps = extra_body.get("fps", None) + seed = extra_body.get("seed", None) + + sampling_params = OmniDiffusionSamplingParams( + height=height, width=width, num_frames=num_frames, frame_rate=fps, extra_args=extra_args, seed=seed + ) + return OmniDiffusionRequest( + prompts=[ + { + "prompt": prompt, + "multi_modal_data": multi_modal_data, + } + ], + sampling_params=sampling_params, + request_ids=[f"camera-{self._current_session_id or 'default'}"], + ) diff --git a/vllm_omni/inputs/data.py b/vllm_omni/inputs/data.py index 1b80f4b1b77..dd10d36a6b8 100644 --- a/vllm_omni/inputs/data.py +++ b/vllm_omni/inputs/data.py @@ -220,6 +220,12 @@ class OmniDiffusionSamplingParams: width_latents: list[int] | int | None = None num_frames: int = 1 # Default for image models num_frames_round_down: bool = False # Whether to round down num_frames if it's not divisible by num_gpus + chunk_frames: int = 5 # Used when stream_batch=True + num_chunks: int = 1 + + # SLO-adaptive stream batching. ``slo_fps=None`` keeps B_target fixed at 1. + slo_fps: float | None = None + slo_max_batch: int = 8 # Original dimensions (before VAE scaling) height: int | None = None