diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index b81ff5b992c..459da61cd59 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -33,6 +33,8 @@ th { | `WanPipeline` | Wan2.1-T2V, Wan2.2-T2V, Wan2.2-TI2V | `Wan-AI/Wan2.1-T2V-1.3B-Diffusers`, `Wan-AI/Wan2.1-T2V-14B-Diffusers`, `Wan-AI/Wan2.2-T2V-A14B-Diffusers`, `Wan-AI/Wan2.2-TI2V-5B-Diffusers` | ✅︎ | ✅︎ | ✅︎ | ✅︎ | | `WanImageToVideoPipeline` | Wan2.2-I2V | `Wan-AI/Wan2.2-I2V-A14B-Diffusers` | ✅︎ | ✅︎ | ✅︎ | ✅︎ | | `Wan22VACEPipeline` | Wan2.1-VACE | `Wan-AI/Wan2.1-VACE-1.3B-diffusers`, `Wan-AI/Wan2.1-VACE-14B-diffusers` | ✅︎ | ✅︎ | ✅︎ | ✅︎ | +| `AniSoraI2VCogVideoXPipeline` | AniSora-I2V (5B) | `Disty0/Index-anisora-5B-diffusers` | ✅︎ | ✅︎ | | | +| `AniSoraV2I2VPipeline` | AniSora-I2V (14B) | `aardsoul-music/Wan2.1-Anisora-14B` | ✅︎ | ✅︎ | | | | `LTX2Pipeline` | LTX-2-T2V | `Lightricks/LTX-2` | ✅︎ | ✅︎ | | | | `LTX2ImageToVideoPipeline` | LTX-2-I2V | `Lightricks/LTX-2` | ✅︎ | ✅︎ | | | | `LTX2TwoStagesPipeline` | LTX-2-T2V | `rootonchair/LTX-2-19b-distilled` | ✅︎ | ✅︎ | | | diff --git a/docs/user_guide/diffusion_features.md b/docs/user_guide/diffusion_features.md index 3e9ddb94261..cf1d451d243 100644 --- a/docs/user_guide/diffusion_features.md +++ b/docs/user_guide/diffusion_features.md @@ -145,11 +145,13 @@ The following tables show which models support each feature: | **Helios** | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | | **HunyuanVideo-1.5 T2V I2V** | ❌ | ✅ | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ (decode) | ✅ | ❌ | | **DreamID-Omni** | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | +| **AniSora V1 (5B)** | ❌ | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | +| **AniSora V2 (14B)** | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | **Frame Interpolation Support** - **Supported**: Wan2.2 text-to-video, image-to-video, and TI2V pipelines -- **Not supported**: Wan2.1-VACE, LTX-2, LTX-2.3, Helios, HunyuanVideo-1.5, DreamID-Omni +- **Not supported**: Wan2.1-VACE, LTX-2, LTX-2.3, Helios, 
HunyuanVideo-1.5, DreamID-Omni, AniSora ### AudioGen diff --git a/tests/e2e/offline_inference/test_anisora_i2v.py b/tests/e2e/offline_inference/test_anisora_i2v.py new file mode 100644 index 00000000000..36542e12e81 --- /dev/null +++ b/tests/e2e/offline_inference/test_anisora_i2v.py @@ -0,0 +1,157 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +E2E offline inference tests for Index-AniSora I2V models. + +Covers the V1 (5B, CogVideoX-based) AniSoraI2VCogVideoXPipeline: +- single-GPU inference +- sequence parallelism, SP=2 Ulysses (requires 2 GPUs) +- FP8 quantization on a single GPU +""" + +import gc +import os +import sys +from pathlib import Path + +import PIL.Image +import pytest +import torch + +# ruff: noqa: E402 +REPO_ROOT = Path(__file__).resolve().parents[3] +if str(REPO_ROOT) not in sys.path: + sys.path.insert(0, str(REPO_ROOT)) + +from tests.utils import hardware_test +from vllm_omni import Omni +from vllm_omni.diffusion.data import DiffusionParallelConfig +from vllm_omni.inputs.data import OmniDiffusionSamplingParams +from vllm_omni.outputs import OmniRequestOutput +from vllm_omni.platforms import current_omni_platform + +os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" + +MODEL_V1 = "Disty0/Index-anisora-5B-diffusers" +MODEL_V2 = "aardsoul-music/Wan2.1-Anisora-14B" + +# V1: vae_scale_factor_temporal=4 → num_frames % 4 == 1, e.g. 5, 9, 13 ... 
def _assert_video_output(output, num_frames: int, height: int, width: int) -> None:
    """Check that a generation result carries a frame tensor of the requested size.

    ``output`` is either an ``OmniRequestOutput`` (in which case the frames are
    read from ``request_output.images[0]``) or the frame tensor itself.

    Args:
        output: Result returned by ``Omni.generate`` (already unwrapped from a list).
        num_frames: Minimum number of frames expected along dimension 1.
        height: Exact size expected along dimension 2.
        width: Exact size expected along dimension 3.

    Raises:
        AssertionError: If any of the checks fails.
    """
    assert output is not None

    frames = output
    if isinstance(output, OmniRequestOutput):
        assert output.final_output_type == "image"
        wrapped = output.request_output
        assert wrapped is not None
        frames = wrapped.images[0]

    assert frames is not None
    assert hasattr(frames, "shape"), f"Expected tensor, got {type(frames)}"

    # Layout is (batch, num_frames, height, width, channels).
    dims = frames.shape
    # The pipeline may round num_frames up for VAE temporal alignment, hence ">=".
    assert dims[1] >= num_frames, f"Expected >= {num_frames} frames, got {dims[1]}"
    assert dims[2] == height
    assert dims[3] == width
@pytest.mark.diffusion
@hardware_test(res={"cuda": "H100", "rocm": "MI325"}, num_cards=2)
def test_anisora_v1_offline_sp2():
    """V1 (5B) offline inference with sequence_parallel_size=2 (Ulysses).

    The model is shut down in a ``finally`` block so that a failing
    ``generate`` call or assertion does not leave worker processes holding
    GPU memory for the tests that run afterwards.
    """
    model = Omni(
        model=MODEL_V1,
        parallel_config=DiffusionParallelConfig(sequence_parallel_size=2, ulysses_degree=2),
    )
    try:
        image = _dummy_image()
        outputs = model.generate(
            {"prompt": "a cat sitting calmly", "multi_modal_data": {"image": image}},
            OmniDiffusionSamplingParams(
                height=HEIGHT,
                width=WIDTH,
                num_frames=NUM_FRAMES,
                num_inference_steps=2,
                guidance_scale=6.0,
                generator=torch.Generator(current_omni_platform.device_type).manual_seed(SEED),
            ),
        )
        # generate() may hand back a list of results or a single result.
        result = outputs[0] if isinstance(outputs, list) else outputs
        _assert_video_output(result, NUM_FRAMES, HEIGHT, WIDTH)
    finally:
        # Always release workers/GPU memory, even when the checks above fail.
        _cleanup(model)
_cleanup(model) diff --git a/tests/e2e/online_serving/test_anisora_online.py b/tests/e2e/online_serving/test_anisora_online.py new file mode 100644 index 00000000000..8354c53503a --- /dev/null +++ b/tests/e2e/online_serving/test_anisora_online.py @@ -0,0 +1,242 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +E2E online serving tests for Index-AniSora I2V models. + +Exercises the /v1/videos async job lifecycle (create → poll → download → delete) +for both V1 (5B CogVideoX) and V2 (14B Wan2.1) via the HTTP API. +""" + +import base64 +import io +import os +import time +import uuid +from pathlib import Path +from typing import Any + +import PIL.Image +import pytest +import requests + +from tests.conftest import OmniServer +from tests.utils import hardware_test + +os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" +os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "0" + +MODEL_V1 = "Disty0/Index-anisora-5B-diffusers" +MODEL_V2 = "aardsoul-music/Wan2.1-Anisora-14B" + +VIDEO_POLL_INTERVAL_S = 2.0 +VIDEO_TIMEOUT_S = 900.0 + +NUM_FRAMES = 5 +HEIGHT = 480 +WIDTH = 720 + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _dummy_image_b64() -> str: + """Return a small solid-color image as a base64-encoded JPEG string.""" + img = PIL.Image.new("RGB", (WIDTH, HEIGHT), color=(100, 149, 237)) + buf = io.BytesIO() + img.save(buf, format="JPEG") + return base64.b64encode(buf.getvalue()).decode() + + +def _video_api_url(server: OmniServer, suffix: str = "") -> str: + return f"http://{server.host}:{server.port}/v1/videos{suffix}" + + +def _multipart_fields(payload: dict[str, Any]) -> list[tuple[str, tuple[None, str]]]: + return [(k, (None, str(v))) for k, v in payload.items() if v is not None] + + +def _create_video_job(server: OmniServer, image_b64: str, **overrides: Any) -> requests.Response: 
def _wait_for_completion(server: OmniServer, video_id: str) -> dict[str, Any]:
    """Poll a video job until it completes, fails, or the overall deadline passes.

    Args:
        server: Running test server that owns the job.
        video_id: Identifier returned when the job was created.

    Returns:
        The job's JSON body once its ``status`` is ``"completed"``.

    Raises:
        AssertionError: If a poll returns a non-200 response, the job reports
            ``"failed"``, or ``VIDEO_TIMEOUT_S`` elapses first.
    """
    status_url = _video_api_url(server, f"/{video_id}")
    # Monotonic clock: wall-clock adjustments (NTP steps, DST) must not
    # shorten or stretch the overall wait.
    deadline = time.monotonic() + VIDEO_TIMEOUT_S
    while (remaining := deadline - time.monotonic()) > 0:
        # Cap the per-request timeout by the time left (floored at 1 s) so a
        # single hung poll cannot push the wait far beyond the deadline.
        resp = requests.get(status_url, timeout=max(remaining, 1.0))
        assert resp.status_code == 200, resp.text
        data = resp.json()
        if data["status"] == "completed":
            return data
        if data["status"] == "failed":
            raise AssertionError(f"Job {video_id} failed: {data}")
        time.sleep(VIDEO_POLL_INTERVAL_S)
    raise AssertionError(f"Timed out waiting for job {video_id}")
--------------------------------------------------------------------------- +# V2 (14B) server fixture — needs >80GB GPU; skip if not available +# --------------------------------------------------------------------------- + + +@pytest.fixture(scope="function") +def anisora_v2_server(): + # V2 (14B) requires 2× GPUs; single A100 80 GB runs OOM + with OmniServer(MODEL_V2, ["--num-gpus", "2", "--disable-log-stats"]) as server: + yield server + + +# --------------------------------------------------------------------------- +# V1 tests +# --------------------------------------------------------------------------- + + +@pytest.mark.diffusion +@hardware_test(res={"cuda": "H100", "rocm": "MI325"}) +def test_anisora_v1_online_create_poll_download_delete(anisora_v1_server: OmniServer, tmp_path: Path): + """V1: full job lifecycle — create → poll → download → delete.""" + image_b64 = _dummy_image_b64() + video_id: str | None = None + try: + resp = _create_video_job(anisora_v1_server, image_b64) + assert resp.status_code == 200, resp.text + created = resp.json() + video_id = created["id"] + assert created["status"] == "queued" + assert created["model"] == MODEL_V1 + + completed = _wait_for_completion(anisora_v1_server, video_id) + assert completed["file_name"] is not None + assert completed["progress"] == 100 + + dl = requests.get( + _video_api_url(anisora_v1_server, f"/{video_id}/content"), + timeout=VIDEO_TIMEOUT_S, + ) + assert dl.status_code == 200, dl.text + assert dl.headers["content-type"].startswith("video/mp4") + _assert_mp4(dl.content) + + out = tmp_path / completed["file_name"] + out.write_bytes(dl.content) + assert out.stat().st_size == len(dl.content) + finally: + if video_id: + _best_effort_delete(anisora_v1_server, video_id) + + +# --------------------------------------------------------------------------- +# V1 TP=2 tests +# --------------------------------------------------------------------------- + + +@pytest.mark.diffusion +@hardware_test(res={"cuda": 
@pytest.mark.diffusion
# The anisora_v2_server fixture launches the server with "--num-gpus 2", so the
# hardware requirement has to request two cards as well.
@hardware_test(res={"cuda": "H100", "rocm": "MI325"}, num_cards=2)
def test_anisora_v2_online_create_poll_download_delete(anisora_v2_server: OmniServer, tmp_path: Path):
    """V2: full job lifecycle — create → poll → download → delete.

    Creates an image-to-video job, waits for it to finish, downloads the
    result and checks it is an MP4. The job is deleted afterwards on a
    best-effort basis, whether or not the checks passed.
    """
    image_b64 = _dummy_image_b64()
    video_id: str | None = None
    try:
        resp = _create_video_job(anisora_v2_server, image_b64, guidance_scale=5.0)
        assert resp.status_code == 200, resp.text
        created = resp.json()
        video_id = created["id"]
        assert created["status"] == "queued"

        completed = _wait_for_completion(anisora_v2_server, video_id)
        assert completed["file_name"] is not None
        assert completed["progress"] == 100

        dl = requests.get(
            _video_api_url(anisora_v2_server, f"/{video_id}/content"),
            timeout=VIDEO_TIMEOUT_S,
        )
        assert dl.status_code == 200, dl.text
        _assert_mp4(dl.content)
    finally:
        # Only attempt cleanup if the job was actually created.
        if video_id:
            _best_effort_delete(anisora_v2_server, video_id)
+""" + +from __future__ import annotations + +from collections.abc import Iterable +from typing import TYPE_CHECKING + +import torch +import torch.nn.functional as F +from diffusers.models.embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps +from diffusers.models.normalization import AdaLayerNorm +from torch import nn +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader + +if TYPE_CHECKING: + from vllm.model_executor.layers.quantization.base_config import QuantizationConfig + +from vllm_omni.diffusion.attention.backends.abstract import AttentionMetadata +from vllm_omni.diffusion.attention.layer import Attention +from vllm_omni.diffusion.data import DiffusionParallelConfig, OmniDiffusionConfig +from vllm_omni.diffusion.distributed.sp_plan import SequenceParallelInput, SequenceParallelOutput +from vllm_omni.diffusion.forward_context import get_forward_context, is_forward_context_available +from vllm_omni.diffusion.layers.rope import RotaryEmbedding + +logger = init_logger(__name__) + + +# --------------------------------------------------------------------------- +# Config container — exposes the same attributes the pipeline reads via +# ``self.transformer.config.`` +# --------------------------------------------------------------------------- +class _CogVideoXTransformerConfig: + """Lightweight config object that mirrors the diffusers ConfigMixin API.""" + + def __init__(self, **kwargs): + for k, v in kwargs.items(): + setattr(self, k, v) + + +# --------------------------------------------------------------------------- +# Column-parallel GELU (same helper as in flux_transformer.py) +# --------------------------------------------------------------------------- +class ColumnParallelApproxGELU(nn.Module): + def __init__( + self, + dim_in: int, + dim_out: int, + *, + 
class CogVideoXFeedForward(nn.Module):
    """Feed-forward block: column-parallel GELU projection, then row-parallel output.

    Args:
        dim: Input and output feature size.
        inner_dim: Hidden feature size; falls back to ``4 * dim`` when falsy.
        bias: Whether both projections (input and output) carry a bias term.
        quant_config: Optional quantization config passed to both projections.
    """

    def __init__(
        self,
        dim: int,
        inner_dim: int | None = None,
        bias: bool = True,
        quant_config: QuantizationConfig | None = None,
    ):
        super().__init__()
        inner_dim = inner_dim or dim * 4
        # nn.ModuleList indices kept identical to diffusers for weight name
        # compatibility: net.0 = GELU proj, net.1 = (identity/dropout), net.2 = out proj.
        self.net = nn.ModuleList(
            [
                ColumnParallelApproxGELU(
                    dim,
                    inner_dim,
                    approximate="tanh",
                    bias=bias,
                    quant_config=quant_config,
                ),
                nn.Identity(),
                RowParallelLinear(
                    inner_dim,
                    dim,
                    # Forward ``bias`` so bias=False disables it on the output
                    # projection too, instead of silently keeping the layer default.
                    bias=bias,
                    input_is_parallel=True,
                    return_bias=False,
                    quant_config=quant_config,
                ),
            ]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the sub-modules of ``self.net`` in order and return the result."""
        for module in self.net:
            x = module(x)
        return x
int, + bias: bool = True, + out_bias: bool = True, + qk_norm: bool = True, + eps: float = 1e-6, + quant_config: QuantizationConfig | None = None, + parallel_config: DiffusionParallelConfig | None = None, + ): + super().__init__() + self.heads = num_attention_heads + self.head_dim = attention_head_dim + self.parallel_config = parallel_config + inner_dim = num_attention_heads * attention_head_dim + + # Fused Q/K/V projection + self.to_qkv = QKVParallelLinear( + hidden_size=query_dim, + head_size=attention_head_dim, + total_num_heads=num_attention_heads, + bias=bias, + quant_config=quant_config, + ) + + # Q/K normalization (LayerNorm, matching diffusers qk_norm="layer_norm") + if qk_norm: + self.norm_q = nn.LayerNorm(attention_head_dim, eps=eps, elementwise_affine=True) + self.norm_k = nn.LayerNorm(attention_head_dim, eps=eps, elementwise_affine=True) + else: + self.norm_q = None + self.norm_k = None + + # Rotary embeddings (interleaved style, same as diffusers CogVideoX) + self.rope = RotaryEmbedding(is_neox_style=False) + + # vLLM-Omni attention backend + self.attn = Attention( + num_heads=self.to_qkv.num_heads, + head_size=self.head_dim, + softmax_scale=1.0 / (self.head_dim**0.5), + causal=False, + num_kv_heads=self.to_qkv.num_kv_heads, + ) + + # Output projection + self.to_out = RowParallelLinear( + inner_dim, + query_dim, + bias=out_bias, + input_is_parallel=True, + return_bias=False, + quant_config=quant_config, + ) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + image_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + text_seq_length = encoder_hidden_states.size(1) + + # Concatenate text + image (CogVideoX does joint self-attention) + combined = torch.cat([encoder_hidden_states, hidden_states], dim=1) + + # Fused QKV projection + qkv, _ = self.to_qkv(combined) + q_size = self.to_qkv.num_heads * self.head_dim + kv_size = self.to_qkv.num_kv_heads * self.head_dim 
+ query, key, value = qkv.split([q_size, kv_size, kv_size], dim=-1) + + # Reshape to [B, S, num_heads, head_dim] + query = query.unflatten(-1, (self.to_qkv.num_heads, -1)) + key = key.unflatten(-1, (self.to_qkv.num_kv_heads, -1)) + value = value.unflatten(-1, (self.to_qkv.num_kv_heads, -1)) + + # Q/K normalization + if self.norm_q is not None: + query = self.norm_q(query) + if self.norm_k is not None: + key = self.norm_k(key) + + # Apply RoPE only to image tokens (not text) + if image_rotary_emb is not None: + cos, sin = image_rotary_emb + cos = cos.to(query.dtype) + sin = sin.to(query.dtype) + # vllm's RotaryEmbedding expects half-dim cos/sin (it doubles internally) + # diffusers' get_3d_rotary_pos_embed returns full-dim with repeat_interleave(2): + # (c1, c1, c2, c2, ...) — take every other element to get (c1, c2, ...) + cos = cos[..., ::2] + sin = sin[..., ::2] + # Split text and image portions + q_text, q_image = query[:, :text_seq_length], query[:, text_seq_length:] + k_text, k_image = key[:, :text_seq_length], key[:, text_seq_length:] + # Apply RoPE to image tokens only + q_image = self.rope(q_image, cos, sin) + k_image = self.rope(k_image, cos, sin) + # Recombine + query = torch.cat([q_text, q_image], dim=1) + key = torch.cat([k_text, k_image], dim=1) + + # Attention — SP-aware: text tokens are replicated (not sharded) across SP ranks, + # so pass them as joint (front) context while image tokens are the sharded sequence. + sp_size = self.parallel_config.sequence_parallel_size if self.parallel_config else None + use_sp_joint = ( + sp_size is not None + and sp_size > 1 + and is_forward_context_available() + and not get_forward_context().split_text_embed_in_sp + ) + if use_sp_joint: + # SP-aware attention: image shard attends to [text, image_full] via Ulysses all-to-all. + # Text tokens are replicated (not sharded) and passed as joint context. 
class CogVideoXBlock(nn.Module):
    """One transformer block: gated joint attention followed by a gated joint FFN.

    Image tokens (``hidden_states``) and text tokens (``encoder_hidden_states``)
    are kept as two streams. Each sub-layer normalises and modulates both
    streams from ``temb`` and adds its output back through a per-stream gate.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        time_embed_dim: int,
        ff_inner_dim: int | None = None,
        bias: bool = True,
        attention_bias: bool = True,
        attention_out_bias: bool = True,
        qk_norm: bool = True,
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
        quant_config: QuantizationConfig | None = None,
        parallel_config: DiffusionParallelConfig | None = None,
    ):
        super().__init__()

        # Attention sub-layer: adaptive norm, then joint text+image attention.
        self.norm1 = CogVideoXLayerNormZero(
            conditioning_dim=time_embed_dim,
            embedding_dim=dim,
            elementwise_affine=norm_elementwise_affine,
            eps=norm_eps,
            bias=True,
        )
        self.attn1 = CogVideoXAttention(
            query_dim=dim,
            num_attention_heads=num_attention_heads,
            attention_head_dim=attention_head_dim,
            bias=attention_bias,
            out_bias=attention_out_bias,
            qk_norm=qk_norm,
            quant_config=quant_config,
            parallel_config=parallel_config,
        )

        # Feed-forward sub-layer: adaptive norm, then the shared FFN.
        self.norm2 = CogVideoXLayerNormZero(
            conditioning_dim=time_embed_dim,
            embedding_dim=dim,
            elementwise_affine=norm_elementwise_affine,
            eps=norm_eps,
            bias=True,
        )
        self.ff = CogVideoXFeedForward(
            dim=dim,
            inner_dim=ff_inner_dim,
            bias=bias,
            quant_config=quant_config,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        temb: torch.Tensor,
        image_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the block and return ``(hidden_states, encoder_hidden_states)``."""
        # Number of text tokens; used to split the joint FFN output again.
        n_text = encoder_hidden_states.size(1)

        # --- attention sub-layer ---
        img_normed, txt_normed, img_gate, txt_gate = self.norm1(hidden_states, encoder_hidden_states, temb)
        img_attn, txt_attn = self.attn1(img_normed, txt_normed, image_rotary_emb=image_rotary_emb)
        hidden_states = hidden_states + img_gate * img_attn
        encoder_hidden_states = encoder_hidden_states + txt_gate * txt_attn

        # --- feed-forward sub-layer ---
        img_normed, txt_normed, img_gate, txt_gate = self.norm2(hidden_states, encoder_hidden_states, temb)
        # Text first, then image, so the split below lines up with n_text.
        ff_out = self.ff(torch.cat([txt_normed, img_normed], dim=1))
        img_ff = ff_out[:, n_text:]
        txt_ff = ff_out[:, :n_text]

        hidden_states = hidden_states + img_gate * img_ff
        encoder_hidden_states = encoder_hidden_states + txt_gate * txt_ff

        return hidden_states, encoder_hidden_states
# ---------------------------------------------------------------------------
class CogVideoXTransformer3DModel(nn.Module):
    """
    Optimized CogVideoX Transformer using vLLM-Omni parallelized layers.

    Uses QKVParallelLinear for fused Q/K/V projections, ColumnParallelLinear
    and RowParallelLinear for FFN, and vLLM-Omni's Attention backend.

    Sequence Parallelism:
        Supports non-intrusive SP via _sp_plan (Ulysses or hybrid Ulysses+Ring).
        The plan splits image tokens across SP ranks while keeping text tokens
        (encoder_hidden_states) replicated. The attention uses joint_strategy="front"
        to combine replicated text context with sharded image tokens correctly.

        - image_rotary_emb: split at model entry (root ""), dim=0
        - hidden_states (image tokens): split at transformer_blocks.0, dim=1
        - proj_out output: gathered back to full sequence for unpatchify
    """

    packed_modules_mapping = {
        "to_qkv": ["to_q", "to_k", "to_v"],
    }

    # Sequence Parallelism for CogVideoX (following diffusers' _cp_plan pattern).
    #
    # CogVideoX uses joint text+image self-attention in every block. For SP, only
    # image tokens are sharded; text tokens stay replicated on all ranks and are
    # passed as AttentionMetadata joint context (joint_strategy="front").
    #
    # - "": split image_rotary_emb (tuple of [img_seq, head_dim] tensors) before patch_embed
    # - "transformer_blocks.0": split hidden_states (image tokens [B, img_seq, D]) at first block
    # - "proj_out": gather image tokens after the final linear projection (before unpatchify)
    _sp_plan = {
        # Split both RoPE tensors (cos, sin) along the sequence dim at model entry.
        # image_rotary_emb is (cos, sin), each shaped [img_seq, head_dim].
        "": {
            "image_rotary_emb": [
                SequenceParallelInput(split_dim=0, expected_dims=2, auto_pad=True),
                SequenceParallelInput(split_dim=0, expected_dims=2, auto_pad=True),
            ],
        },
        # Split image hidden_states along sequence dim at the first block's input.
        # After patch_embed + manual split, hidden_states is [B, img_seq, D].
        "transformer_blocks.0": {
            "hidden_states": SequenceParallelInput(split_dim=1, expected_dims=3, auto_pad=True),
        },
        # Gather image tokens after proj_out; unpatchify then operates on the full sequence.
        "proj_out": SequenceParallelOutput(gather_dim=1, expected_dims=3),
    }

    def __init__(
        self,
        od_config: OmniDiffusionConfig,
        num_attention_heads: int = 30,
        attention_head_dim: int = 64,
        in_channels: int = 16,
        out_channels: int | None = 16,
        flip_sin_to_cos: bool = True,
        freq_shift: int = 0,
        time_embed_dim: int = 512,
        ofs_embed_dim: int | None = None,
        text_embed_dim: int = 4096,
        num_layers: int = 30,
        dropout: float = 0.0,
        attention_bias: bool = True,
        sample_width: int = 90,
        sample_height: int = 60,
        sample_frames: int = 49,
        patch_size: int = 2,
        patch_size_t: int | None = None,
        temporal_compression_ratio: int = 4,
        max_text_seq_length: int = 226,
        activation_fn: str = "gelu-approximate",
        timestep_activation_fn: str = "silu",
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
        spatial_interpolation_scale: float = 1.875,
        temporal_interpolation_scale: float = 1.0,
        use_rotary_positional_embeddings: bool = False,
        use_learned_positional_embeddings: bool = False,
        patch_bias: bool = True,
        quant_config: QuantizationConfig | None = None,
    ):
        """Build the patch embedding, timestep embeddings, blocks and output head.

        The keyword arguments mirror the keys of a diffusers CogVideoX
        transformer config. ``od_config`` supplies ``parallel_config``, which is
        stored on the module and handed to every transformer block.
        ``quant_config`` is forwarded to the blocks only. ``activation_fn`` is
        accepted for config compatibility but is not forwarded to any
        sub-module by this constructor.
        """
        super().__init__()
        inner_dim = num_attention_heads * attention_head_dim
        # A falsy out_channels (None or 0) falls back to in_channels.
        out_channels = out_channels or in_channels

        # Store parallel config for SP-aware attention
        self.parallel_config = od_config.parallel_config

        # Store config for pipeline access (self.transformer.config.*)
        self.config = _CogVideoXTransformerConfig(
            num_attention_heads=num_attention_heads,
            attention_head_dim=attention_head_dim,
            in_channels=in_channels,
            out_channels=out_channels,
            time_embed_dim=time_embed_dim,
            text_embed_dim=text_embed_dim,
            num_layers=num_layers,
            sample_width=sample_width,
            sample_height=sample_height,
            sample_frames=sample_frames,
            patch_size=patch_size,
            patch_size_t=patch_size_t,
            temporal_compression_ratio=temporal_compression_ratio,
            max_text_seq_length=max_text_seq_length,
            spatial_interpolation_scale=spatial_interpolation_scale,
            temporal_interpolation_scale=temporal_interpolation_scale,
            use_rotary_positional_embeddings=use_rotary_positional_embeddings,
        )

        # 1. Patch embedding (kept from diffusers — runs once, not a bottleneck)
        self.patch_embed = CogVideoXPatchEmbed(
            patch_size=patch_size,
            patch_size_t=patch_size_t,
            in_channels=in_channels,
            embed_dim=inner_dim,
            text_embed_dim=text_embed_dim,
            bias=patch_bias,
            sample_width=sample_width,
            sample_height=sample_height,
            sample_frames=sample_frames,
            temporal_compression_ratio=temporal_compression_ratio,
            max_text_seq_length=max_text_seq_length,
            spatial_interpolation_scale=spatial_interpolation_scale,
            temporal_interpolation_scale=temporal_interpolation_scale,
            # Absolute positional embeddings are only used when RoPE is off.
            use_positional_embeddings=not use_rotary_positional_embeddings,
            use_learned_positional_embeddings=use_learned_positional_embeddings,
        )
        self.embedding_dropout = nn.Dropout(dropout)

        # 2. Timestep embeddings
        self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
        self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)

        # Optional "ofs" embedding, created only when ofs_embed_dim is truthy.
        self.ofs_proj = None
        self.ofs_embedding = None
        if ofs_embed_dim:
            self.ofs_proj = Timesteps(ofs_embed_dim, flip_sin_to_cos, freq_shift)
            self.ofs_embedding = TimestepEmbedding(ofs_embed_dim, ofs_embed_dim, timestep_activation_fn)

        # 3. Transformer blocks (optimized with vLLM-Omni layers)
        self.transformer_blocks = nn.ModuleList(
            [
                CogVideoXBlock(
                    dim=inner_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    time_embed_dim=time_embed_dim,
                    attention_bias=attention_bias,
                    norm_elementwise_affine=norm_elementwise_affine,
                    norm_eps=norm_eps,
                    quant_config=quant_config,
                    parallel_config=self.parallel_config,
                )
                for _ in range(num_layers)
            ]
        )
        self.norm_final = nn.LayerNorm(inner_dim, norm_eps, norm_elementwise_affine)

        # 4. Output blocks (kept from diffusers — runs once)
        self.norm_out = AdaLayerNorm(
            embedding_dim=time_embed_dim,
            output_dim=2 * inner_dim,
            norm_elementwise_affine=norm_elementwise_affine,
            norm_eps=norm_eps,
            chunk_dim=1,
        )

        # One output token carries a full (p_t x) p x p patch of out_channels values.
        if patch_size_t is None:
            output_dim = patch_size * patch_size * out_channels
        else:
            output_dim = patch_size * patch_size * patch_size_t * out_channels
        self.proj_out = nn.Linear(inner_dim, output_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        timestep: torch.LongTensor | int | float,
        timestep_cond: torch.Tensor | None = None,
        ofs: torch.LongTensor | int | float | None = None,
        image_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
        return_dict: bool = True,
    ) -> tuple[torch.Tensor] | torch.Tensor:
        """Run the transformer on a latent video and return the unpatchified output.

        Args:
            hidden_states: 5-D latent tensor, unpacked here as
                (batch, frames, channels, height, width).
            encoder_hidden_states: Text embeddings; ``shape[1]`` is used as the
                text sequence length.
            timestep: Diffusion timestep(s) fed to the timestep projection.
            timestep_cond: Optional conditioning passed to the timestep embedding.
            ofs: Optional value for the "ofs" embedding; ignored when the model
                was built without ``ofs_embed_dim``.
            image_rotary_emb: Optional (cos, sin) rotary tables passed to every block.
            return_dict: When False the output is wrapped in a 1-tuple; when
                True the bare tensor is returned.

        Returns:
            The unpatchified output tensor, or ``(output,)`` if
            ``return_dict`` is False.
        """
        batch_size, num_frames, channels, height, width = hidden_states.shape

        # 1. Timestep embedding
        t_emb = self.time_proj(timestep)
        # The projection output is cast to the activation dtype before the MLP.
        t_emb = t_emb.to(dtype=hidden_states.dtype)
        emb = self.time_embedding(t_emb, timestep_cond)

        if self.ofs_embedding is not None and ofs is not None:
            ofs_emb = self.ofs_proj(ofs)
            ofs_emb = ofs_emb.to(dtype=hidden_states.dtype)
            ofs_emb = self.ofs_embedding(ofs_emb)
            emb = emb + ofs_emb

        # 2. Patch embedding
        hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
        hidden_states = self.embedding_dropout(hidden_states)

        # The embedded sequence is split back into its leading text tokens and
        # the remaining image tokens.
        text_seq_length = encoder_hidden_states.shape[1]
        encoder_hidden_states = hidden_states[:, :text_seq_length]
        hidden_states = hidden_states[:, text_seq_length:]

        # 3. Transformer blocks
        for block in self.transformer_blocks:
            hidden_states, encoder_hidden_states = block(
                hidden_states=hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                temb=emb,
                image_rotary_emb=image_rotary_emb,
            )

        hidden_states = self.norm_final(hidden_states)

        # 4. Output
        hidden_states = self.norm_out(hidden_states, temb=emb)
        hidden_states = self.proj_out(hidden_states)

        # 5. Unpatchify
        p = self.config.patch_size
        p_t = self.config.patch_size_t

        if p_t is None:
            output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
            output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
        else:
            # Frames are grouped into ceil(num_frames / p_t) temporal patches.
            output = hidden_states.reshape(
                batch_size, (num_frames + p_t - 1) // p_t, height // p, width // p, -1, p_t, p, p
            )
            output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2)

        if not return_dict:
            return (output,)
        return output

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors, fusing separate Q/K/V weights into ``to_qkv``.

        Checkpoint names containing ``.to_q`` / ``.to_k`` / ``.to_v`` are
        redirected to the fused ``.to_qkv`` parameter and loaded as the
        matching shard. All other names are loaded directly, after rewriting
        ``.to_out.0.`` to ``.to_out.``.

        Entries with no matching parameter are skipped with a debug message and
        are NOT reported as loaded. This holds for fused Q/K/V entries as well
        as plain ones.

        Args:
            weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

        Returns:
            The names that were actually loaded; both the checkpoint name and
            the resolved parameter name are recorded for each entry.
        """
        # Map diffusers separate Q/K/V weights to our fused QKV
        stacked_params_mapping = [
            (".to_qkv", ".to_q", "q"),
            (".to_qkv", ".to_k", "k"),
            (".to_qkv", ".to_v", "v"),
        ]

        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        for name, loaded_weight in weights:
            lookup_name = name
            shard_id = None

            # Resolve the first matching fused-QKV mapping, if any.
            for param_name, weight_name, candidate_shard in stacked_params_mapping:
                if weight_name in name:
                    lookup_name = name.replace(weight_name, param_name)
                    shard_id = candidate_shard
                    break

            if shard_id is None:
                # Diffusers uses attn1.to_out.0 (ModuleList); we use attn1.to_out (direct)
                lookup_name = lookup_name.replace(".to_out.0.", ".to_out.")

            param = params_dict.get(lookup_name)
            if param is None:
                # Nothing was copied, so the entry must not be recorded as loaded.
                logger.debug("Skipping weight %s (not in model)", name)
                continue

            if shard_id is not None:
                # Fused parameters carry their own shard-aware loader.
                param.weight_loader(param, loaded_weight, shard_id)
            else:
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)

            loaded_params.add(name)
            loaded_params.add(lookup_name)

        return loaded_params
+""" + +from __future__ import annotations + +import logging +import os +from collections.abc import Iterable + +import numpy as np +import PIL.Image +import torch +from diffusers import ( + AutoencoderKLCogVideoX, + CogVideoXDDIMScheduler, +) +from diffusers.models.embeddings import get_3d_rotary_pos_embed +from diffusers.models.transformers.cogvideox_transformer_3d import ( + CogVideoXTransformer3DModel as DiffusersCogVideoXTransformer, +) +from diffusers.utils.torch_utils import randn_tensor +from torch import nn +from transformers import AutoTokenizer, T5EncoderModel +from vllm.model_executor.models.utils import AutoWeightsLoader + +from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig +from vllm_omni.diffusion.distributed.utils import get_local_device +from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader +from vllm_omni.diffusion.models.anisora.cogvideox_transformer import ( + CogVideoXTransformer3DModel, +) +from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.inputs.data import OmniTextPrompt + +logger = logging.getLogger(__name__) + + +def get_anisora_i2v_post_process_func( + od_config: OmniDiffusionConfig, +): + from diffusers.video_processor import VideoProcessor + + video_processor = VideoProcessor(vae_scale_factor=8) + + def post_process_func( + video: torch.Tensor, + output_type: str = "np", + ): + if output_type == "latent": + return video + return video_processor.postprocess_video(video, output_type=output_type) + + return post_process_func + + +def get_anisora_i2v_pre_process_func( + od_config: OmniDiffusionConfig, +): + """Pre-process function for I2V: load and resize input image.""" + from diffusers.video_processor import VideoProcessor + + video_processor = VideoProcessor(vae_scale_factor=8) + + def pre_process_func(request: OmniDiffusionRequest) -> OmniDiffusionRequest: + for i, prompt in enumerate(request.prompts): + multi_modal_data = prompt.get("multi_modal_data", 
{}) if not isinstance(prompt, str) else None + raw_image = multi_modal_data.get("image", None) if multi_modal_data is not None else None + if isinstance(prompt, str): + prompt = OmniTextPrompt(prompt=prompt) + if "additional_information" not in prompt: + prompt["additional_information"] = {} + + if raw_image is None: + raise ValueError( + "No image is provided. This model requires an image to run. " + 'Please correctly set `"multi_modal_data": {"image": , …}`' + ) + if not isinstance(raw_image, (str, PIL.Image.Image)): + raise TypeError( + f"""Unsupported image format {raw_image.__class__}.""", + """Please correctly set `"multi_modal_data": {"image": , …}`""", + ) + image = PIL.Image.open(raw_image).convert("RGB") if isinstance(raw_image, str) else raw_image + + # Calculate dimensions based on aspect ratio if not provided + if request.sampling_params.height is None or request.sampling_params.width is None: + # Default max area for 480P + max_area = 480 * 832 + aspect_ratio = image.height / image.width + + # Calculate dimensions maintaining aspect ratio + mod_value = 16 # Must be divisible by 16 + height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value + width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value + + if request.sampling_params.height is None: + request.sampling_params.height = height + if request.sampling_params.width is None: + request.sampling_params.width = width + + # Resize image to target dimensions + image = image.resize( + (request.sampling_params.width, request.sampling_params.height), # type: ignore # height/width set above + PIL.Image.Resampling.LANCZOS, + ) + prompt["multi_modal_data"]["image"] = image # type: ignore # key existence checked above + + # Preprocess for VAE + prompt["additional_information"]["preprocessed_image"] = video_processor.preprocess( + image, + height=request.sampling_params.height, + width=request.sampling_params.width, + ) + request.prompts[i] = prompt + return request + + return 
pre_process_func + + +def get_resize_crop_region_for_grid(src, tgt_width, tgt_height): + tw = tgt_width + th = tgt_height + h, w = src + r = h / w + if r > (th / tw): + resize_height = th + resize_width = int(round(th / h * w)) + else: + resize_width = tw + resize_height = int(round(tw / w * h)) + + crop_top = int(round((th - resize_height) / 2.0)) + crop_left = int(round((tw - resize_width) / 2.0)) + + return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width) + + +class AniSoraI2VCogVideoXPipeline(nn.Module): + """ + AniSora Image-to-Video Pipeline using CogVideoX architecture. + + The transformer uses vLLM-Omni's optimized layers (QKVParallelLinear, etc.) + for tensor parallelism support. Compatible with the + Disty0/Index-anisora-5B-diffusers model. + """ + + # vLLM uses this flag to decide whether to feed dummy images in warmup + support_image_input = True + + def __init__( + self, + *, + od_config: OmniDiffusionConfig, + ): + super().__init__() + self.device = get_local_device() + self.dtype = od_config.dtype + model_path = od_config.model + + local_files_only = os.path.exists(model_path) if isinstance(model_path, str) else False + + logger.info("Loading tokenizer from %s...", model_path) + self.tokenizer = AutoTokenizer.from_pretrained( + model_path, + subfolder="tokenizer", + local_files_only=local_files_only, + ) + + logger.info("Loading text encoder from %s...", model_path) + self.text_encoder = T5EncoderModel.from_pretrained( + model_path, + subfolder="text_encoder", + torch_dtype=self.dtype, + local_files_only=local_files_only, + ) + + logger.info("Loading VAE from %s...", model_path) + self.vae = AutoencoderKLCogVideoX.from_pretrained( + model_path, + subfolder="vae", + torch_dtype=torch.float32, # VAE in float32 for precision + local_files_only=local_files_only, + ) + + # Load transformer config from pretrained, then build optimized model + logger.info("Loading transformer config from %s...", model_path) + tf_config = 
DiffusersCogVideoXTransformer.load_config( + model_path, subfolder="transformer", local_files_only=local_files_only + ) + self.transformer = CogVideoXTransformer3DModel( + od_config=od_config, + num_attention_heads=tf_config.get("num_attention_heads", 30), + attention_head_dim=tf_config.get("attention_head_dim", 64), + in_channels=tf_config.get("in_channels", 16), + out_channels=tf_config.get("out_channels", 16), + flip_sin_to_cos=tf_config.get("flip_sin_to_cos", True), + freq_shift=tf_config.get("freq_shift", 0), + time_embed_dim=tf_config.get("time_embed_dim", 512), + ofs_embed_dim=tf_config.get("ofs_embed_dim", None), + text_embed_dim=tf_config.get("text_embed_dim", 4096), + num_layers=tf_config.get("num_layers", 30), + dropout=tf_config.get("dropout", 0.0), + attention_bias=tf_config.get("attention_bias", True), + sample_width=tf_config.get("sample_width", 90), + sample_height=tf_config.get("sample_height", 60), + sample_frames=tf_config.get("sample_frames", 49), + patch_size=tf_config.get("patch_size", 2), + patch_size_t=tf_config.get("patch_size_t", None), + temporal_compression_ratio=tf_config.get("temporal_compression_ratio", 4), + max_text_seq_length=tf_config.get("max_text_seq_length", 226), + activation_fn=tf_config.get("activation_fn", "gelu-approximate"), + timestep_activation_fn=tf_config.get("timestep_activation_fn", "silu"), + norm_elementwise_affine=tf_config.get("norm_elementwise_affine", True), + norm_eps=tf_config.get("norm_eps", 1e-5), + spatial_interpolation_scale=tf_config.get("spatial_interpolation_scale", 1.875), + temporal_interpolation_scale=tf_config.get("temporal_interpolation_scale", 1.0), + use_rotary_positional_embeddings=tf_config.get("use_rotary_positional_embeddings", False), + use_learned_positional_embeddings=tf_config.get("use_learned_positional_embeddings", False), + patch_bias=tf_config.get("patch_bias", True), + ) + + # Tell the framework to load transformer weights from checkpoint + self.weights_sources = [ + 
DiffusersPipelineLoader.ComponentSource( + model_or_path=model_path, + subfolder="transformer", + revision=None, + prefix="transformer.", + fall_back_to_pt=True, + ), + ] + + logger.info("Loading scheduler from %s...", model_path) + self.scheduler = CogVideoXDDIMScheduler.from_pretrained( + model_path, + subfolder="scheduler", + local_files_only=local_files_only, + ) + + # Scale factors from VAE config + self.vae_scale_factor_spatial = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor_temporal = self.vae.config.temporal_compression_ratio + + self._current_timestep = None + logger.info("Pipeline loaded successfully!") + + def to(self, device): + """Move pipeline to device.""" + self.device = device + self.text_encoder = self.text_encoder.to(device) + self.vae = self.vae.to(device) + self.transformer = self.transformer.to(device) + return self + + @torch.no_grad() + def encode_prompt( + self, + prompt: str | list[str], + negative_prompt: str | list[str] | None = None, + max_sequence_length: int = 226, + ): + """Encode text prompts using T5.""" + if isinstance(prompt, str): + prompt = [prompt] + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids.to(self.device) + + prompt_embeds = self.text_encoder(text_input_ids)[0] + prompt_embeds = prompt_embeds.to(dtype=self.dtype, device=self.device) + + # Negative prompt + if negative_prompt is not None: + if isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] * batch_size + + uncond_input = self.tokenizer( + negative_prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + ) + negative_prompt_embeds = self.text_encoder(uncond_input.input_ids.to(self.device))[0] + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.dtype, device=self.device) + else: + 
negative_prompt_embeds = None + + return prompt_embeds, negative_prompt_embeds + + @torch.no_grad() + def encode_image(self, image: PIL.Image.Image | torch.Tensor, height: int, width: int): + """Encode input image to latent space.""" + from diffusers.video_processor import VideoProcessor + + video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + # Preprocess if PIL Image + if isinstance(image, PIL.Image.Image): + image_tensor = video_processor.preprocess(image, height=height, width=width) + else: + image_tensor = image + + # Move to device and ensure correct dtype (VAE expects float32 for encoding) + image_tensor = image_tensor.to(device=self.device, dtype=torch.float32) + + # Add frame dimension: [B, C, H, W] -> [B, C, F, H, W] with F=1 + # CogVideoX VAE expects [B, C, F, H, W] + image_tensor = image_tensor.unsqueeze(2) # [B, C, 1, H, W] + + # Encode + latent = self.vae.encode(image_tensor).latent_dist.sample() + + # Scale by latent std if available + if hasattr(self.vae.config, "scaling_factor"): + latent = latent * self.vae.config.scaling_factor + + return latent # [B, C, 1, H//8, W//8] + + def _prepare_rotary_positional_embeddings( + self, + height: int, + width: int, + num_frames: int, + ): + """Prepare 3D rotary positional embeddings.""" + grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + + p = self.transformer.config.patch_size + p_t = self.transformer.config.patch_size_t + + base_size_width = self.transformer.config.sample_width // p + base_size_height = self.transformer.config.sample_height // p + + if p_t is None: + # CogVideoX 1.0 style + grid_crops_coords = get_resize_crop_region_for_grid( + (grid_height, grid_width), base_size_width, base_size_height + ) + freqs_cos, freqs_sin = get_3d_rotary_pos_embed( + embed_dim=self.transformer.config.attention_head_dim, + 
crops_coords=grid_crops_coords, + grid_size=(grid_height, grid_width), + temporal_size=num_frames, + device=self.device, + ) + else: + # CogVideoX 1.5 style + base_num_frames = (num_frames + p_t - 1) // p_t + freqs_cos, freqs_sin = get_3d_rotary_pos_embed( + embed_dim=self.transformer.config.attention_head_dim, + crops_coords=None, + grid_size=(grid_height, grid_width), + temporal_size=base_num_frames, + grid_type="slice", + max_size=(base_size_height, base_size_width), + device=self.device, + ) + + return freqs_cos, freqs_sin + + @torch.no_grad() + def forward( + self, + req: OmniDiffusionRequest, + prompt: str | list[str] | None = None, + image: PIL.Image.Image | None = None, + negative_prompt: str | list[str] | None = None, + height: int | None = None, + width: int | None = None, + num_frames: int | None = None, + num_inference_steps: int | None = None, + guidance_scale: float | None = None, + generator: torch.Generator | None = None, + output_type: str | None = None, + **kwargs, + ) -> DiffusionOutput: + """ + Forward pass for vLLM framework integration. + Extracts parameters from OmniDiffusionRequest and generates video. + """ + # Ensure model weights are on the correct device before any compute + self.to(self.device) + if len(req.prompts) > 1: + raise ValueError( + "This model only supports a single prompt, not a batched request. " + "Please pass in a single prompt object or string, or a single-item list." 
+ ) + + # Extract text prompts from request if not explicitly provided + if prompt is None: + first_prompt = None + if getattr(req, "prompts", None): + first_prompt = req.prompts[0] + if isinstance(first_prompt, str): + prompt = first_prompt + elif isinstance(first_prompt, dict): + prompt = first_prompt.get("text") or first_prompt.get("prompt") or first_prompt.get("caption") + + if negative_prompt is None: + neg = None + first_prompt = None + if getattr(req, "prompts", None): + first_prompt = req.prompts[0] + if isinstance(first_prompt, dict): + neg = first_prompt.get("negative_text") or first_prompt.get("negative_prompt") + negative_prompt = neg + + # Use preprocessed_image if available, otherwise use image from multi_modal_data + if image is None: + first_prompt = None + if getattr(req, "prompts", None): + first_prompt = req.prompts[0] + if isinstance(first_prompt, dict): + additional_info = first_prompt.get("additional_information", {}) + if isinstance(additional_info, dict) and "preprocessed_image" in additional_info: + image = additional_info["preprocessed_image"] + elif "multi_modal_data" in first_prompt and isinstance(first_prompt["multi_modal_data"], dict): + image = first_prompt["multi_modal_data"].get("image") + + # Derive sampling parameters from explicit args, then from req.sampling_params, then defaults + sampling_params = getattr(req, "sampling_params", None) + + def _get_sp_attr(name, *aliases, default=None): + if sampling_params is None: + return default + for key in (name, *aliases): + value = getattr(sampling_params, key, None) + if value is not None: + return value + return default + + if height is None: + height = _get_sp_attr("height", default=480) + if width is None: + width = _get_sp_attr("width", default=832) + if num_frames is None: + num_frames = _get_sp_attr("num_frames", "frames", default=17) + if num_inference_steps is None: + num_inference_steps = _get_sp_attr("num_inference_steps", "steps", default=50) + + if getattr(sampling_params, 
"guidance_scale_provided", False): + guidance_scale = sampling_params.guidance_scale + elif guidance_scale is None: + guidance_scale = _get_sp_attr("guidance_scale", "cfg_scale", default=6.0) + output_type = output_type or "tensor" + + if prompt is None: + raise ValueError("Prompt is required") + if image is None: + raise ValueError("Image is required for I2V generation") + if isinstance(image, str): + image = PIL.Image.open(image).convert("RGB") + + return self._generate( + prompt=prompt, + image=image, + negative_prompt=negative_prompt, + height=height, + width=width, + num_frames=num_frames, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + generator=generator, + output_type=output_type, + ) + + @torch.no_grad() + def __call__( + self, + prompt: str | list[str], + image: PIL.Image.Image, + negative_prompt: str | list[str] | None = None, + height: int = 480, + width: int = 832, + num_frames: int = 17, + num_inference_steps: int = 50, + guidance_scale: float = 6.0, + generator: torch.Generator | None = None, + output_type: str = "tensor", + ): + """ + Direct call interface for standalone usage. 
+ + Args: + prompt: Text prompt(s) + image: Input image + negative_prompt: Negative prompt(s) + height: Output height + width: Output width + num_frames: Number of output frames + num_inference_steps: Denoising steps + guidance_scale: Classifier-free guidance scale + generator: Random generator for reproducibility + output_type: "tensor", "np", or "pil" + """ + return self._generate( + prompt=prompt, + image=image, + negative_prompt=negative_prompt, + height=height, + width=width, + num_frames=num_frames, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + generator=generator, + output_type=output_type, + ) + + def _generate( + self, + prompt: str | list[str], + image: PIL.Image.Image, + negative_prompt: str | list[str] | None = None, + height: int = 480, + width: int = 832, + num_frames: int = 17, + num_inference_steps: int = 50, + guidance_scale: float = 6.0, + generator: torch.Generator | None = None, + output_type: str = "tensor", + ) -> DiffusionOutput: + # Default to empty negative prompt so CFG is not silently skipped + if negative_prompt is None and guidance_scale > 1.0: + negative_prompt = "" + + # Encode prompt + logger.info("Encoding prompts...") + prompt_embeds, negative_prompt_embeds = self.encode_prompt(prompt, negative_prompt) + + do_classifier_free_guidance = guidance_scale > 1.0 and negative_prompt_embeds is not None + + # Encode image + logger.info("Encoding image...") + image_latents = self.encode_image(image, height, width) + + # Prepare latent dimensions + batch_size = prompt_embeds.shape[0] + num_channels_latents = self.transformer.config.in_channels // 2 # 16 for noise, 16 for image + latent_height = height // self.vae_scale_factor_spatial + latent_width = width // self.vae_scale_factor_spatial + latent_num_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + + # CogVideoX uses [B, F, C, H, W] format + latent_shape = ( + batch_size, + latent_num_frames, + num_channels_latents, + latent_height, + 
latent_width, + ) + + # Initial noise + logger.info("Preparing latents...") + latents = randn_tensor(latent_shape, generator=generator, device=self.device, dtype=self.dtype) + + # Prepare image latents for conditioning + # image_latents: [B, C, 1, H, W] -> [B, 1, C, H, W] + image_latents = image_latents.permute(0, 2, 1, 3, 4).to(dtype=self.dtype) + + # Pad image latents to match num_frames: first frame is image, rest are zeros + padding_shape = ( + batch_size, + latent_num_frames - 1, + num_channels_latents, + latent_height, + latent_width, + ) + latent_padding = torch.zeros(padding_shape, device=self.device, dtype=self.dtype) + image_latents_padded = torch.cat([image_latents, latent_padding], dim=1) # [B, F, C, H, W] + + # Prepare rotary embeddings + logger.info("Preparing rotary embeddings...") + image_rotary_emb = self._prepare_rotary_positional_embeddings(height, width, latent_num_frames) + + # Set timesteps + self.scheduler.set_timesteps(num_inference_steps, device=self.device) + timesteps = self.scheduler.timesteps + + # Scale initial noise + latents = latents * self.scheduler.init_noise_sigma + + logger.info("Starting denoising loop (%d steps)...", num_inference_steps) + for i, t in enumerate(timesteps): + self._current_timestep = t + + # Expand latents for CFG + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # Expand image latents for CFG + latent_image_input = ( + torch.cat([image_latents_padded] * 2) if do_classifier_free_guidance else image_latents_padded + ) + + # Concatenate noise and image latents along channel dimension + # [B, F, C, H, W] + [B, F, C, H, W] -> [B, F, 2C, H, W] + latent_model_input = torch.cat([latent_model_input, latent_image_input], dim=2) + + # Prepare prompt embeds for CFG + if do_classifier_free_guidance: + prompt_embeds_input = torch.cat([negative_prompt_embeds, prompt_embeds]) + else: + 
prompt_embeds_input = prompt_embeds + + # Expand timestep to match batch dimension (for CFG) + batch_size = latent_model_input.shape[0] + timestep = t.expand(batch_size).to(latent_model_input.dtype) + + # Predict noise + noise_pred = self.transformer( + hidden_states=latent_model_input, + encoder_hidden_states=prompt_embeds_input, + timestep=timestep, + image_rotary_emb=image_rotary_emb, + return_dict=False, + )[0] + + # CFG + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # Scheduler step + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if (i + 1) % 10 == 0: + logger.debug("Step %d/%d", i + 1, num_inference_steps) + + self._current_timestep = None + logger.info("Decoding latents...") + + # Decode latents + # CogVideoX VAE expects [B, C, F, H, W] + latents = latents.permute(0, 2, 1, 3, 4) # [B, F, C, H, W] -> [B, C, F, H, W] + + # Unscale + if hasattr(self.vae.config, "scaling_factor"): + latents = latents / self.vae.config.scaling_factor + + latents = latents.to(dtype=torch.float32) + video = self.vae.decode(latents).sample + + # video: [B, C, F, H, W] in range [-1, 1] + logger.info("Output shape: %s", video.shape) + logger.info("Output range: [%.3f, %.3f]", video.min().item(), video.max().item()) + + return DiffusionOutput(output=video) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + """Load weights using AutoWeightsLoader for vLLM integration.""" + loader = AutoWeightsLoader(self) + loaded_weights = loader.load_weights(weights) + # Record components already loaded via from_pretrained + loaded_weights |= {f"text_encoder.{name}" for name, _ in self.text_encoder.named_parameters()} + loaded_weights |= {f"vae.{name}" for name, _ in self.vae.named_parameters()} + return loaded_weights diff --git a/vllm_omni/diffusion/models/anisora/pipeline_anisora_v2_i2v.py 
b/vllm_omni/diffusion/models/anisora/pipeline_anisora_v2_i2v.py new file mode 100644 index 00000000000..a0eb55b502a --- /dev/null +++ b/vllm_omni/diffusion/models/anisora/pipeline_anisora_v2_i2v.py @@ -0,0 +1,836 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +AniSora V2/V3 Image-to-Video Pipeline using Wan2.1 architecture. + +The transformer uses vLLM-Omni's optimized WanTransformer3DModel with +QKVParallelLinear for tensor parallelism support. + +This pipeline uses a hybrid loading approach: +- VAE, Text Encoder, Scheduler from official Wan2.1-I2V-14B-Diffusers +- Transformer weights from AniSora (community conversion or official) + +Supports: +- IndexTeam/Index-anisora (V2, V3.1, V3.2, anymask) +- aardsoul-music/Wan2.1-Anisora-14B +- ikusa/anisorav2 + +All these use WanModel architecture with in_dim=36 (16 noise + 16 image + 4 mask). +""" + +from __future__ import annotations + +import logging +import os +from collections.abc import Iterable + +import numpy as np +import PIL.Image +import torch +from diffusers import AutoencoderKLWan +from diffusers.utils.torch_utils import randn_tensor +from torch import nn +from transformers import ( + AutoTokenizer, + CLIPImageProcessor, + CLIPVisionModel, + UMT5EncoderModel, +) + +from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig +from vllm_omni.diffusion.distributed.utils import get_local_device +from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader +from vllm_omni.diffusion.models.schedulers import FlowUniPCMultistepScheduler +from vllm_omni.diffusion.models.wan2_2.wan2_2_transformer import WanTransformer3DModel +from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.inputs.data import OmniTextPrompt + +logger = logging.getLogger(__name__) + +# Default Wan2.1 base repo for loading shared components (VAE, T5, CLIP). 
def get_anisora_v2_i2v_pre_process_func(
    od_config: OmniDiffusionConfig,
):
    """Build the request pre-processor for AniSora V2 image-to-video.

    The returned callable checks that every prompt carries an input image,
    fills in a missing ``height``/``width`` on ``request.sampling_params`` from
    the image aspect ratio (480x832 area budget, rounded down to multiples of
    16), resizes the image to that size and stores the tensor returned by
    ``VideoProcessor.preprocess`` under
    ``prompt["additional_information"]["preprocessed_image"]``.

    Args:
        od_config: Diffusion config. Accepted for interface parity; not read here.

    Returns:
        A function that takes an ``OmniDiffusionRequest``, mutates it in place
        and returns it.
    """
    from diffusers.video_processor import VideoProcessor

    video_processor = VideoProcessor(vae_scale_factor=8)
    image_hint = 'Please correctly set `"multi_modal_data": {"image": str | PIL.Image.Image, …}`'

    def pre_process_func(request: OmniDiffusionRequest) -> OmniDiffusionRequest:
        """Validate, resize and tensorize the input image of every prompt.

        Raises:
            ValueError: A prompt has no image attached.
            TypeError: The image is neither a path string nor a PIL image.
        """
        sampling_params = request.sampling_params

        for i, prompt in enumerate(request.prompts):
            # Plain string prompts cannot carry an image; they are still wrapped
            # so the error below is raised from a uniform code path.
            multi_modal_data = prompt.get("multi_modal_data", {}) if not isinstance(prompt, str) else None
            raw_image = multi_modal_data.get("image", None) if multi_modal_data is not None else None
            if isinstance(prompt, str):
                prompt = OmniTextPrompt(prompt=prompt)
            if "additional_information" not in prompt:
                prompt["additional_information"] = {}

            if raw_image is None:
                raise ValueError("No image is provided. This model requires an image to run. " + image_hint)
            if not isinstance(raw_image, (str, PIL.Image.Image)):
                raise TypeError(f"Unsupported image format {raw_image.__class__}. " + image_hint)

            if isinstance(raw_image, str):
                # Image.open is lazy and keeps the file open; convert() forces the
                # load, after which the handle can be released deterministically.
                with PIL.Image.open(raw_image) as opened:
                    image = opened.convert("RGB")
            elif raw_image.mode != "RGB":
                # RGBA / L / P inputs would otherwise reach the VAE pre-processing
                # with a channel count other than three.
                image = raw_image.convert("RGB")
            else:
                image = raw_image

            # Derive the output size from the aspect ratio when not fully specified.
            if sampling_params.height is None or sampling_params.width is None:
                max_area = 480 * 832  # default 480P pixel budget
                aspect_ratio = image.height / image.width
                mod_value = 16  # both sides must be divisible by 16
                height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
                width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value

                if sampling_params.height is None:
                    sampling_params.height = height
                if sampling_params.width is None:
                    sampling_params.width = width

            # Resize to the target size; the resized PIL image is kept on the
            # prompt and the tensor form is stored alongside it.
            image = image.resize(
                (sampling_params.width, sampling_params.height),  # type: ignore # height/width set above
                PIL.Image.Resampling.LANCZOS,
            )
            prompt["multi_modal_data"]["image"] = image  # type: ignore # key existence checked above

            prompt["additional_information"]["preprocessed_image"] = video_processor.preprocess(
                image,
                height=sampling_params.height,
                width=sampling_params.width,
            )
            request.prompts[i] = prompt
        return request

    return pre_process_func
+ + This pipeline uses a hybrid loading approach for diffusers compatibility: + - VAE, T5, CLIP, Scheduler from official Wan2.1 diffusers repo + - Transformer weights from AniSora community conversions + """ + + # vLLM uses this flag to decide whether to feed dummy images in warmup + support_image_input = True + + def __init__( + self, + *, + od_config: OmniDiffusionConfig, + prefix: str = "", + ): + """ + Args: + od_config: OmniDiffusionConfig with model id/path, dtype, and runtime options. + prefix: Reserved prefix string for compatibility (currently unused). + """ + super().__init__() + self.od_config = od_config + self.device = get_local_device() + self.dtype = od_config.dtype + + model_path = od_config.model + wan_base_path = od_config.model_paths.get("wan_base", DEFAULT_WAN_BASE) + + flow_shift = od_config.flow_shift if od_config.flow_shift is not None else 5.0 + + # Determine if local files + local_wan = os.path.exists(wan_base_path) + + logger.info("=== AniSora V2 I2V Pipeline (Hybrid Loading) ===") + logger.info("AniSora transformer: %s", model_path) + logger.info("Wan2.1 base (VAE/T5): %s", wan_base_path) + + # Load tokenizer from Wan base + logger.info("Loading tokenizer from Wan2.1...") + self.tokenizer = AutoTokenizer.from_pretrained( + wan_base_path, + subfolder="tokenizer", + local_files_only=local_wan, + ) + + # Load T5 text encoder from Wan base + logger.info("Loading T5 text encoder from Wan2.1...") + self.text_encoder = UMT5EncoderModel.from_pretrained( + wan_base_path, + subfolder="text_encoder", + torch_dtype=self.dtype, + local_files_only=local_wan, + ).to(self.device) + + # Load CLIP image encoder from Wan base (for I2V conditioning) + logger.info("Loading CLIP image encoder from Wan2.1...") + try: + self.image_processor = CLIPImageProcessor.from_pretrained( + wan_base_path, + subfolder="image_processor", + local_files_only=local_wan, + ) + self.image_encoder = CLIPVisionModel.from_pretrained( + wan_base_path, + subfolder="image_encoder", + 
torch_dtype=self.dtype, + local_files_only=local_wan, + ).to(self.device) + self.has_image_encoder = True + except Exception as e: + raise RuntimeError( + f"Failed to load CLIP image encoder from {wan_base_path}. " + "This is required for I2V (image-to-video) conditioning." + ) from e + + # Load VAE from Wan base + logger.info("Loading VAE from Wan2.1...") + self.vae = AutoencoderKLWan.from_pretrained( + wan_base_path, + subfolder="vae", + torch_dtype=torch.float32, # VAE in float32 for precision + local_files_only=local_wan, + ).to(self.device) + + # Load transformer using vLLM-Omni's optimized WanTransformer3DModel + logger.info("Loading transformer from AniSora: %s...", model_path) + + # Wan2.1 I2V config values (in_channels=36 for I2V: 16 noise + 16 image + 4 mask) + self.transformer = WanTransformer3DModel( + patch_size=(1, 2, 2), + num_attention_heads=40, + attention_head_dim=128, + in_channels=36, + out_channels=16, + text_dim=4096, + freq_dim=256, + ffn_dim=13824, + num_layers=40, + cross_attn_norm=True, + eps=1e-6, + image_dim=1280, + added_kv_proj_dim=5120, + rope_max_seq_len=1024, + ) + + # Tell the framework to load transformer weights from AniSora checkpoint + self.weights_sources = [ + DiffusersPipelineLoader.ComponentSource( + model_or_path=model_path, + subfolder=None, + revision=None, + prefix="transformer.", + fall_back_to_pt=True, + ), + ] + + # Initialize scheduler + logger.info("Initializing scheduler...") + self.scheduler = FlowUniPCMultistepScheduler( + num_train_timesteps=1000, + shift=flow_shift, + prediction_type="flow_prediction", + ) + + # VAE scale factors + self.vae_scale_factor_temporal = getattr(self.vae.config, "temporal_compression_ratio", 4) + self.vae_scale_factor_spatial = getattr(self.vae.config, "spatial_compression_ratio", 8) + + self._current_timestep = None + self._device_moved = False + logger.info("Pipeline loaded successfully!") + + def to(self, device): + """Move pipeline to device.""" + self.device = device + 
self.text_encoder = self.text_encoder.to(device) + self.vae = self.vae.to(device) + self.transformer = self.transformer.to(device) + self.image_encoder = self.image_encoder.to(device) + return self + + @staticmethod + def _prompt_clean(text: str) -> str: + """Clean prompt text.""" + return " ".join(text.strip().split()) + + @torch.no_grad() + def encode_prompt( + self, + prompt: str | list[str], + negative_prompt: str | list[str] | None = None, + max_sequence_length: int = 512, + ): + """Encode text prompts using T5.""" + if isinstance(prompt, str): + prompt = [prompt] + batch_size = len(prompt) + prompt_clean = [self._prompt_clean(p) for p in prompt] + + text_inputs = self.tokenizer( + prompt_clean, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_attention_mask=True, + return_tensors="pt", + ) + ids = text_inputs.input_ids.to(self.device) + mask = text_inputs.attention_mask.to(self.device) + seq_lens = mask.gt(0).sum(dim=1).long() + + prompt_embeds = self.text_encoder(ids, mask).last_hidden_state + prompt_embeds = prompt_embeds.to(dtype=self.dtype, device=self.device) + + # Trim and pad to consistent length + prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], + dim=0, + ) + + # Negative prompt + negative_prompt_embeds = None + if negative_prompt is not None: + if isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] * batch_size + + neg_text_inputs = self.tokenizer( + [self._prompt_clean(p) for p in negative_prompt], + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_attention_mask=True, + return_tensors="pt", + ) + ids_neg = neg_text_inputs.input_ids.to(self.device) + mask_neg = neg_text_inputs.attention_mask.to(self.device) + seq_lens_neg = mask_neg.gt(0).sum(dim=1).long() + + 
negative_prompt_embeds = self.text_encoder(ids_neg, mask_neg).last_hidden_state + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.dtype, device=self.device) + negative_prompt_embeds = [u[:v] for u, v in zip(negative_prompt_embeds, seq_lens_neg)] + negative_prompt_embeds = torch.stack( + [ + torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) + for u in negative_prompt_embeds + ], + dim=0, + ) + + return prompt_embeds, negative_prompt_embeds + + @torch.no_grad() + def encode_image_clip(self, image: PIL.Image.Image) -> torch.Tensor: + """Encode image using CLIP for conditioning.""" + pixel_values = self.image_processor(images=image, return_tensors="pt").pixel_values + pixel_values = pixel_values.to(device=self.device, dtype=self.dtype) + image_embeds = self.image_encoder(pixel_values, output_hidden_states=True) + return image_embeds.hidden_states[-2] + + def prepare_latents( + self, + image: torch.Tensor, + batch_size: int, + num_channels_latents: int, + height: int, + width: int, + num_frames: int, + dtype: torch.dtype, + device: torch.device, + generator: torch.Generator | None = None, + last_image: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Prepare latents for I2V generation. 
+ + Returns: + latents: Initial noise latents [B, C, F, H, W] + condition: Encoded image condition with mask [B, C+4, F, H, W] + first_frame_mask: Mask for conditioning + """ + num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + latent_height = height // self.vae_scale_factor_spatial + latent_width = width // self.vae_scale_factor_spatial + + shape = ( + batch_size, + num_channels_latents, + num_latent_frames, + latent_height, + latent_width, + ) + + # Generate noise + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + # Prepare image condition + image = image.unsqueeze(2) # [B, C, 1, H, W] + + if last_image is None: + # Pad with zeros for remaining frames + video_condition = torch.cat( + [ + image, + image.new_zeros(image.shape[0], image.shape[1], num_frames - 1, height, width), + ], + dim=2, + ) + else: + # First and last frame conditioning + last_image = last_image.unsqueeze(2) + video_condition = torch.cat( + [ + image, + image.new_zeros(image.shape[0], image.shape[1], num_frames - 2, height, width), + last_image, + ], + dim=2, + ) + + video_condition = video_condition.to(device=device, dtype=self.vae.dtype) + + # Encode through VAE + latent_condition = self.vae.encode(video_condition).latent_dist.mode() + latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1) + + # Normalize latents + if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None: + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, -1, 1, 1, 1) + .to(latent_condition.device, latent_condition.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, -1, 1, 1, 1).to( + latent_condition.device, latent_condition.dtype + ) + latent_condition = (latent_condition - latents_mean) * latents_std + + latent_condition = latent_condition.to(dtype) + + # Create mask: 1 for frames with condition, 0 for frames to generate + mask_lat_size = torch.ones(batch_size, 1, 
num_frames, latent_height, latent_width, device=device) + if last_image is None: + mask_lat_size[:, :, 1:] = 0 # Only first frame is conditioned + else: + mask_lat_size[:, :, 1:-1] = 0 # First and last frames are conditioned + + # Compress mask temporally + first_frame_mask = mask_lat_size[:, :, 0:1] + first_frame_mask = first_frame_mask.repeat(1, 1, self.vae_scale_factor_temporal, 1, 1) + mask_lat_size = torch.cat([first_frame_mask, mask_lat_size[:, :, 1:]], dim=2) + mask_lat_size = mask_lat_size.view(batch_size, -1, self.vae_scale_factor_temporal, latent_height, latent_width) + mask_lat_size = mask_lat_size.transpose(1, 2) + mask_lat_size = mask_lat_size.to(latent_condition.device) + + # Concatenate mask with condition [B, C+4, F, H, W] + condition = torch.cat([mask_lat_size, latent_condition], dim=1) + + # Return placeholder for first_frame_mask (not used in this mode) + first_frame_mask = torch.ones( + 1, + 1, + num_latent_frames, + latent_height, + latent_width, + dtype=dtype, + device=device, + ) + + return latents, condition, first_frame_mask + + @torch.no_grad() + def forward( + self, + req: OmniDiffusionRequest, + prompt: str | list[str] | None = None, + image: PIL.Image.Image | None = None, + negative_prompt: str | list[str] | None = None, + height: int | None = None, + width: int | None = None, + num_frames: int | None = None, + num_inference_steps: int | None = None, + guidance_scale: float | None = None, + generator: torch.Generator | None = None, + last_image: PIL.Image.Image | None = None, + output_type: str | None = None, + **kwargs, + ) -> DiffusionOutput: + """ + Forward pass for vLLM framework integration. + Extracts parameters from OmniDiffusionRequest and generates video. + """ + # Move to device once on first forward (after framework loads weights) + if not self._device_moved: + self.to(self.device) + self._device_moved = True + if len(req.prompts) > 1: + raise ValueError( + "This model only supports a single prompt, not a batched request. 
" + "Please pass in a single prompt object or string, or a single-item list." + ) + + # Extract text prompts from request if not explicitly provided + if prompt is None: + first_prompt = None + if getattr(req, "prompts", None): + first_prompt = req.prompts[0] + if isinstance(first_prompt, str): + prompt = first_prompt + elif isinstance(first_prompt, dict): + prompt = first_prompt.get("text") or first_prompt.get("prompt") or first_prompt.get("caption") + + if negative_prompt is None: + neg = None + first_prompt = None + if getattr(req, "prompts", None): + first_prompt = req.prompts[0] + if isinstance(first_prompt, dict): + neg = first_prompt.get("negative_text") or first_prompt.get("negative_prompt") + negative_prompt = neg + + # Prefer PIL image from multi_modal_data for CLIP conditioning. + if image is None: + first_prompt = None + if getattr(req, "prompts", None): + first_prompt = req.prompts[0] + if isinstance(first_prompt, dict): + # First try the PIL image from multi_modal_data (needed for CLIP) + if "multi_modal_data" in first_prompt and isinstance(first_prompt["multi_modal_data"], dict): + image = first_prompt["multi_modal_data"].get("image") + + # Derive sampling parameters from explicit args, then from req.sampling_params, then defaults + sampling_params = getattr(req, "sampling_params", None) + + def _get_sp_attr(name, *aliases, default=None): + if sampling_params is None: + return default + for key in (name, *aliases): + value = getattr(sampling_params, key, None) + if value is not None: + return value + return default + + if height is None: + height = _get_sp_attr("height", default=480) + if width is None: + width = _get_sp_attr("width", default=832) + if num_frames is None: + num_frames = _get_sp_attr("num_frames", "frames", default=81) + if num_inference_steps is None: + num_inference_steps = _get_sp_attr("num_inference_steps", "steps", default=40) + + if getattr(sampling_params, "guidance_scale_provided", False): + guidance_scale = 
sampling_params.guidance_scale + elif guidance_scale is None: + guidance_scale = _get_sp_attr("guidance_scale", "cfg_scale", default=5.0) + output_type = output_type or "tensor" + + if prompt is None: + raise ValueError("Prompt is required") + if image is None: + raise ValueError("Image is required for I2V generation") + if isinstance(image, str): + image = PIL.Image.open(image).convert("RGB") + + return self._generate( + prompt=prompt, + image=image, + negative_prompt=negative_prompt, + height=height, + width=width, + num_frames=num_frames, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + generator=generator, + last_image=last_image, + output_type=output_type, + ) + + @torch.no_grad() + def __call__( + self, + prompt: str | list[str], + image: PIL.Image.Image, + negative_prompt: str | list[str] | None = None, + height: int = 480, + width: int = 832, + num_frames: int = 81, + num_inference_steps: int = 40, + guidance_scale: float = 5.0, + generator: torch.Generator | None = None, + last_image: PIL.Image.Image | None = None, + output_type: str = "tensor", + ) -> DiffusionOutput: + """ + Direct call interface for standalone usage. 
+ + Args: + prompt: Text prompt(s) + image: Input image (first frame) + negative_prompt: Negative prompt(s) + height: Output height + width: Output width + num_frames: Number of output frames (should be 4n+1) + num_inference_steps: Denoising steps (40 recommended for I2V) + guidance_scale: Classifier-free guidance scale + generator: Random generator for reproducibility + last_image: Optional last frame for interpolation + output_type: "tensor", "np", or "pil" + """ + return self._generate( + prompt=prompt, + image=image, + negative_prompt=negative_prompt, + height=height, + width=width, + num_frames=num_frames, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + generator=generator, + last_image=last_image, + output_type=output_type, + ) + + def _generate( + self, + prompt: str | list[str], + image: PIL.Image.Image, + negative_prompt: str | list[str] | None = None, + height: int = 480, + width: int = 832, + num_frames: int = 81, + num_inference_steps: int = 40, + guidance_scale: float = 5.0, + generator: torch.Generator | None = None, + last_image: PIL.Image.Image | None = None, + output_type: str = "tensor", + ) -> DiffusionOutput: + # Ensure num_frames is compatible with VAE temporal scaling + if num_frames % self.vae_scale_factor_temporal != 1: + num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1 + num_frames = max(num_frames, 1) + + # Default to empty negative prompt so CFG is not silently skipped + if negative_prompt is None and guidance_scale > 1.0: + negative_prompt = "" + + # Encode prompt + logger.info("Encoding prompts...") + prompt_embeds, negative_prompt_embeds = self.encode_prompt(prompt, negative_prompt) + batch_size = prompt_embeds.shape[0] + + do_classifier_free_guidance = guidance_scale > 1.0 and negative_prompt_embeds is not None + + # Encode image with CLIP for additional conditioning + logger.info("Encoding image with CLIP...") + image_embeds = self.encode_image_clip(image) + 
image_embeds = image_embeds.repeat(batch_size, 1, 1).to(self.dtype) + + # Preprocess image for VAE + logger.info("Preprocessing image for VAE...") + from diffusers.video_processor import VideoProcessor + + video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + image_tensor = video_processor.preprocess(image, height=height, width=width) + image_tensor = image_tensor.to(device=self.device, dtype=torch.float32) + + # Handle last_image if provided + last_image_tensor = None + if last_image is not None: + last_image_tensor = video_processor.preprocess(last_image, height=height, width=width) + last_image_tensor = last_image_tensor.to(device=self.device, dtype=torch.float32) + + # Prepare latents + logger.info("Preparing latents...") + num_channels_latents = self.transformer.config.out_channels # 16 + + latents, condition, _ = self.prepare_latents( + image=image_tensor, + batch_size=batch_size, + num_channels_latents=num_channels_latents, + height=height, + width=width, + num_frames=num_frames, + dtype=torch.float32, + device=self.device, + generator=generator, + last_image=last_image_tensor, + ) + + # Set timesteps + self.scheduler.set_timesteps(num_inference_steps, device=self.device) + timesteps = self.scheduler.timesteps + + logger.info("Starting denoising loop (%d steps)...", num_inference_steps) + for i, t in enumerate(timesteps): + self._current_timestep = t + + # Concatenate noise latents with condition [B, C, F, H, W] + [B, C+4, F, H, W] + latent_model_input = torch.cat([latents, condition], dim=1).to(self.dtype) + + if do_classifier_free_guidance: + # Batch conditional and unconditional in a single forward pass + latent_model_input = torch.cat([latent_model_input] * 2) + timestep = t.expand(latent_model_input.shape[0]) + prompt_embeds_input = torch.cat([negative_prompt_embeds, prompt_embeds]) + image_embeds_input = torch.cat([image_embeds, image_embeds]) if image_embeds is not None else None + + noise_pred = self.transformer( + 
hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds_input, + encoder_hidden_states_image=image_embeds_input, + return_dict=False, + )[0] + + noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond) + else: + timestep = t.expand(latent_model_input.shape[0]) + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + encoder_hidden_states_image=image_embeds, + return_dict=False, + )[0] + + # Scheduler step + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if (i + 1) % 10 == 0: + logger.debug("Step %d/%d", i + 1, num_inference_steps) + + self._current_timestep = None + logger.info("Decoding latents...") + + # Decode latents + latents = latents.to(self.vae.dtype) + + # Denormalize + if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None: + latents_mean = ( + torch.tensor(self.vae.config.latents_mean).view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, -1, 1, 1, 1).to( + latents.device, latents.dtype + ) + latents = latents / latents_std + latents_mean + + video = self.vae.decode(latents, return_dict=False)[0] + + logger.info("Output shape: %s", video.shape) + logger.info("Output range: [%.3f, %.3f]", video.min().item(), video.max().item()) + + return DiffusionOutput(output=video) + + @staticmethod + def _convert_anisora_key(key: str) -> str: + """Convert a single AniSora weight key to diffusers naming convention.""" + new_key = key + + # Block-level: attention layer names + new_key = new_key.replace(".self_attn.", ".attn1.") + new_key = new_key.replace(".cross_attn.", ".attn2.") + + # Projection names within attention + if ".attn1." in new_key or ".attn2." 
in new_key: + new_key = new_key.replace(".k.", ".to_k.") + new_key = new_key.replace(".q.", ".to_q.") + new_key = new_key.replace(".v.", ".to_v.") + new_key = new_key.replace(".o.", ".to_out.0.") + + # Image key/value projections for cross-attention + new_key = new_key.replace(".k_img.", ".add_k_proj.") + new_key = new_key.replace(".v_img.", ".add_v_proj.") + new_key = new_key.replace(".norm_k_img.", ".norm_added_k.") + + # FFN: ffn.0 -> ffn.net.0.proj, ffn.2 -> ffn.net.2 + new_key = new_key.replace(".ffn.0.", ".ffn.net.0.proj.") + new_key = new_key.replace(".ffn.2.", ".ffn.net.2.") + + # Normalization: norm3 -> norm2 + new_key = new_key.replace(".norm3.", ".norm2.") + + # Modulation -> scale_shift_table + new_key = new_key.replace(".modulation", ".scale_shift_table") + + # Non-block (global) conversions + new_key = new_key.replace("text_embedding.0.", "condition_embedder.text_embedder.linear_1.") + new_key = new_key.replace("text_embedding.2.", "condition_embedder.text_embedder.linear_2.") + new_key = new_key.replace("time_embedding.0.", "condition_embedder.time_embedder.linear_1.") + new_key = new_key.replace("time_embedding.2.", "condition_embedder.time_embedder.linear_2.") + new_key = new_key.replace("time_projection.1.", "condition_embedder.time_proj.") + new_key = new_key.replace("head.head.", "proj_out.") + if new_key == "head.scale_shift_table": + new_key = "scale_shift_table" + + # Image embedder + new_key = new_key.replace("img_emb.proj.0.", "condition_embedder.image_embedder.norm1.") + new_key = new_key.replace("img_emb.proj.1.", "condition_embedder.image_embedder.ff.net.0.proj.") + new_key = new_key.replace("img_emb.proj.3.", "condition_embedder.image_embedder.ff.net.2.") + new_key = new_key.replace("img_emb.proj.4.", "condition_embedder.image_embedder.norm2.") + + return new_key + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + """Load weights with AniSora→diffusers key conversion. 
+ + The framework feeds weights from weights_sources (prefixed with + 'transformer.'). We strip the prefix, convert AniSora key names + to diffusers format, then delegate to the transformer's load_weights() + which handles QKV fusion and TP sharding. + """ + prefix = "transformer." + + def _convert(weights_iter): + for name, tensor in weights_iter: + if name.startswith(prefix): + name = name[len(prefix) :] + yield self._convert_anisora_key(name), tensor + + loaded = {f"transformer.{n}" for n in self.transformer.load_weights(_convert(weights))} + logger.info("Loaded %d transformer weight entries", len(loaded)) + + # Record components already loaded via from_pretrained + loaded |= {f"text_encoder.{n}" for n, _ in self.text_encoder.named_parameters()} + loaded |= {f"vae.{n}" for n, _ in self.vae.named_parameters()} + loaded |= {f"image_encoder.{n}" for n, _ in self.image_encoder.named_parameters()} + return loaded + + +# Note: This module is intended to be used via the OmniDiffusion entrypoints. +# A standalone __main__ test harness is intentionally omitted to avoid +# suggesting incorrect instantiation patterns for AniSoraV2I2VPipeline. 
diff --git a/vllm_omni/diffusion/registry.py b/vllm_omni/diffusion/registry.py index 37f5199447c..4d38a63b354 100644 --- a/vllm_omni/diffusion/registry.py +++ b/vllm_omni/diffusion/registry.py @@ -65,6 +65,18 @@ "pipeline_wan2_2_vace", "Wan22VACEPipeline", ), + # Index-AniSora V2/V3 (14B) - Wan2.1 architecture with hybrid loading + "AniSoraV2I2VPipeline": ( + "anisora", + "pipeline_anisora_v2_i2v", + "AniSoraV2I2VPipeline", + ), + # Index-AniSora V1 (5B) - CogVideoX-based I2V + "AniSoraI2VCogVideoXPipeline": ( + "anisora", + "pipeline_anisora_i2v_cogvideox", + "AniSoraI2VCogVideoXPipeline", + ), "LTX2Pipeline": ( "ltx2", "pipeline_ltx2", @@ -442,6 +454,8 @@ def _apply_sequence_parallel_if_enabled(model, od_config: OmniDiffusionConfig) - "MagiHumanPipeline": "get_magi_human_post_process_func", "OmniVoicePipeline": "get_omnivoice_post_process_func", "DreamIDOmniPipeline": "get_dreamid_omni_post_process_func", + "AniSoraI2VCogVideoXPipeline": "get_anisora_i2v_post_process_func", + "AniSoraV2I2VPipeline": "get_anisora_v2_i2v_post_process_func", } _DIFFUSION_PRE_PROCESS_FUNCS = { @@ -463,6 +477,8 @@ def _apply_sequence_parallel_if_enabled(model, od_config: OmniDiffusionConfig) - "HeliosPyramidPipeline": "get_helios_pre_process_func", "HunyuanVideo15ImageToVideoPipeline": "get_hunyuan_video_15_i2v_pre_process_func", "MagiHumanPipeline": "get_magi_human_pre_process_func", + "AniSoraI2VCogVideoXPipeline": "get_anisora_i2v_pre_process_func", + "AniSoraV2I2VPipeline": "get_anisora_v2_i2v_pre_process_func", }