diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index 03a9a57896c..ebebf554b69 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -33,6 +33,7 @@ th { | `ZImagePipeline` | Z-Image | `Tongyi-MAI/Z-Image-Turbo` | ✅︎ | ✅︎ | ✅︎ | ✅︎ | | `WanPipeline` | Wan2.1-T2V, Wan2.2-T2V, Wan2.2-TI2V | `Wan-AI/Wan2.1-T2V-1.3B-Diffusers`, `Wan-AI/Wan2.1-T2V-14B-Diffusers`, `Wan-AI/Wan2.2-T2V-A14B-Diffusers`, `Wan-AI/Wan2.2-TI2V-5B-Diffusers` | ✅︎ | ✅︎ | ✅︎ | ✅︎ | | `WanImageToVideoPipeline` | Wan2.2-I2V | `Wan-AI/Wan2.2-I2V-A14B-Diffusers` | ✅︎ | ✅︎ | ✅︎ | ✅︎ | +| `Cosmos3OmniDiffusersPipeline` | Cosmos3 T2I, T2V, I2V, T2V with sound, action policy | `nvidia/Cosmos3-Nano` | ✅︎ | | | | | `WanSpeechToVideoPipeline` | Wan2.2-S2V | `Wan-AI/Wan2.2-S2V-14B` | ✅︎ | ✅︎ | ✅︎ | ✅︎ | | `Wan22VACEPipeline` | Wan2.1-VACE | `Wan-AI/Wan2.1-VACE-1.3B-diffusers`, `Wan-AI/Wan2.1-VACE-14B-diffusers` | ✅︎ | ✅︎ | ✅︎ | ✅︎ | | `LTX2Pipeline` | LTX-2-T2V | `Lightricks/LTX-2` | ✅︎ | ✅︎ | | | diff --git a/docs/user_guide/diffusion_features.md b/docs/user_guide/diffusion_features.md index 9400c828e08..f87fd09ac1f 100644 --- a/docs/user_guide/diffusion_features.md +++ b/docs/user_guide/diffusion_features.md @@ -133,10 +133,12 @@ The following tables show which models support each feature: | **Stable-Diffusion3.5** | ❌ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ (decode) | ❌ | ❌ | | **Z-Image** | ✅ | ✅ | ✅ | ❓ | ✅ (TP=2 only) | ❌ | ✅ | ❌ | ✅ (decode) | ✅ | ❌ | | **ERNIE-Image** | ❌ | ✅ | ✅ | ❓ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | +| **Cosmos3** | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ✅ (decode) | ✅ | ❌ | > Notes: > 1. Nextstep_1(T2I) does not support cache acceleration methods such as TeaCache or Cache-DiT. > 2. `Tongyi-MAI/Z-Image-Turbo` and `SII-GAIR/daVinci-MagiHuman-Base-1080p` are distilled models with minimal NFEs; CFG-Parallel is not necessary. +> 3. Cosmos3 T2I uses `Cosmos3OmniDiffusersPipeline` with `modalities=["image"]`. Model-level CPU offload is not supported; use layerwise offload. ### VideoGen @@ -149,6 +151,8 @@ The following tables show which models support each feature: | **Helios** | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | | **HunyuanVideo-1.5 T2V I2V** | ❌ | ✅ | ❌ | ✅ | ✅ | ❌ | ✅ | ✅ | ✅ (decode) | ✅ | ❌ | | **DreamID-Omni** | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | +| **Cosmos3** | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ✅ (encode/decode) | ✅ | ❌ | + **Frame Interpolation Support** diff --git a/tests/diffusion/cache/test_cache_dit.py b/tests/diffusion/cache/test_cache_dit.py index 0fd07393d1b..146084a508c 100644 --- a/tests/diffusion/cache/test_cache_dit.py +++ b/tests/diffusion/cache/test_cache_dit.py @@ -25,6 +25,7 @@ cd_backend.enable_cache_for_helios, cd_backend.enable_cache_for_wan22, cd_backend.enable_cache_for_longcat_image, + cd_backend.enable_cache_for_cosmos3, ] SAMPLE_CACHE_CONFIG = DiffusionCacheConfig() @@ -47,6 +48,24 @@ def test_separate_cfg(mock_cache_dit, mock_block_adapter, enabler): assert adapter_kwargs["has_separate_cfg"] is True +@patch("vllm_omni.diffusion.cache.cache_dit_backend.BlockAdapter") +@patch("vllm_omni.diffusion.cache.cache_dit_backend.cache_dit") +def test_cosmos3_cache_dit_wraps_gen_layers(mock_cache_dit, mock_block_adapter): + """Cosmos3 should cache only the repeated GEN pathway blocks.""" + mock_pipeline = Mock() + gen_layers = object() + mock_pipeline.transformer.gen_layers = gen_layers + + cd_backend.enable_cache_for_cosmos3(mock_pipeline, SAMPLE_CACHE_CONFIG) + + mock_cache_dit.enable_cache.assert_called_once() + adapter_kwargs = mock_block_adapter.call_args.kwargs + assert adapter_kwargs["transformer"] is mock_pipeline.transformer + assert adapter_kwargs["blocks"] == [gen_layers] + assert adapter_kwargs["has_separate_cfg"] is True + assert adapter_kwargs["check_forward_pattern"] is False + + # This test is skipped on ROCm since rocm_unquantized_gemm doesn't support CPU backend @pytest.mark.skipif( current_omni_platform.is_rocm(), diff --git a/tests/diffusion/models/cosmos3/__init__.py b/tests/diffusion/models/cosmos3/__init__.py new file mode 100644 index 00000000000..208f01a7cb5 --- /dev/null +++ b/tests/diffusion/models/cosmos3/__init__.py @@ -0,0 +1,2 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project diff --git a/tests/diffusion/models/cosmos3/conftest.py b/tests/diffusion/models/cosmos3/conftest.py new file mode 100644 index 00000000000..80a7105d2ca --- /dev/null +++ b/tests/diffusion/models/cosmos3/conftest.py @@ -0,0 +1,192 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from __future__ import annotations + +import sys +import types +from types import SimpleNamespace +from typing import Any + +import pytest +import torch +from torch import nn + + +class StubScheduler: + def __init__(self, timesteps: list[int] | None = None, *, flow_shift: float = 1.0) -> None: + self.timesteps = torch.tensor(timesteps or [9, 3], dtype=torch.int64) + self.config = SimpleNamespace(num_train_timesteps=1000, flow_shift=flow_shift) + self.set_timesteps_calls: list[tuple[int, torch.device]] = [] + self.step_calls: list[tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = [] + + def set_timesteps(self, num_steps: int, device: torch.device) -> None: + self.set_timesteps_calls.append((num_steps, device)) + self.timesteps = torch.arange(num_steps, 0, -1, dtype=torch.int64, device=device) + + def step(self, noise_pred: torch.Tensor, timestep: torch.Tensor, latents: torch.Tensor, **kwargs): + del kwargs + self.step_calls.append((noise_pred.clone(), timestep.clone(), latents.clone())) + return (latents + noise_pred,) + + +class _ModeLatentDist: + def __init__(self, latents: torch.Tensor) -> None: + self._latents = latents + + def mode(self) -> torch.Tensor: + return self._latents + + +class StubCosmos3VAE: + dtype = torch.float32 + + def __init__(self, z_dim: int = 2, *, temporal: int = 4, spatial: int = 8) -> None: + self.config = SimpleNamespace( + z_dim=z_dim, + scale_factor_temporal=temporal, + scale_factor_spatial=spatial, + latents_mean=[0.0] * z_dim, + latents_std=[1.0] * z_dim, + ) + + def encode(self, video: torch.Tensor): + latent_frames = (video.shape[2] - 1) // self.config.scale_factor_temporal + 1 + latent_height = video.shape[-2] // self.config.scale_factor_spatial + latent_width = video.shape[-1] // self.config.scale_factor_spatial + latents = torch.ones( + video.shape[0], + self.config.z_dim, + latent_frames, + latent_height, + latent_width, + dtype=video.dtype, + device=video.device, + ) + return SimpleNamespace(latent_dist=_ModeLatentDist(latents)) + + def decode(self, latents: torch.Tensor, return_dict: bool = False): + del return_dict + return (latents,) + + +class StubCosmos3Transformer(nn.Module): + def __init__( + self, + *, + latent_channel_size: int = 2, + sound_gen: bool = False, + sound_dim: int = 3, + action_gen: bool = False, + action_dim: int = 4, + ) -> None: + super().__init__() + self.latent_channel_size = latent_channel_size + self.sound_gen = sound_gen + self.sound_dim = sound_dim + self.action_gen = action_gen + self.action_dim = action_dim + self.cached_kv: Any | None = None + self.cached_freqs_gen: Any | None = None + self.calls: list[dict[str, Any]] = [] + self.reset_calls = 0 + + def reset_cache(self) -> None: + self.reset_calls += 1 + self.cached_kv = None + self.cached_freqs_gen = None + + def forward( + self, + *, + hidden_states: torch.Tensor, + timestep: torch.Tensor, + text_ids: torch.Tensor, + text_mask: torch.Tensor, + **kwargs: Any, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + token = int(text_ids.reshape(-1)[0].item()) if text_ids.numel() else 0 + sound_latents = kwargs.get("sound_latents") + self.calls.append( + { + "token": token, + "timestep": timestep.clone(), + "text_mask": text_mask.clone(), + "cache_before": self.cached_kv, + "kwargs": dict(kwargs), + } + ) + if self.cached_kv is None: + marker = torch.tensor([token], dtype=torch.float32) + self.cached_kv = [(marker, marker + 100)] + self.cached_freqs_gen = (marker + 200, marker + 300) + action_latents = kwargs.get("action_latents") + outputs: list[torch.Tensor] = [torch.full_like(hidden_states, float(token))] + if action_latents is not None: + outputs.append(torch.full_like(action_latents, float(token + 20))) + if sound_latents is not None: + outputs.append(torch.full_like(sound_latents, float(token + 10))) + return outputs[0] if len(outputs) == 1 else tuple(outputs) + + +def passthrough_progress_bar(iterable): + return iterable + + +@pytest.fixture(autouse=True) +def fake_cosmos3_guardrails(monkeypatch: pytest.MonkeyPatch): + module = types.ModuleType("vllm_omni.diffusion.models.cosmos3.guardrails") + module.is_guardrails_enabled = lambda od_config, sampling_params=None: False + module.ensure_initialized = lambda od_config: None + module.check_text_safety = lambda text: None + module.check_video_safety = lambda video: video + monkeypatch.setitem(sys.modules, module.__name__, module) + return module + + +@pytest.fixture +def make_cosmos3_pipeline(): + def _make(): + from vllm_omni.diffusion.models.cosmos3.pipeline_cosmos3 import ( + Cosmos3OmniDiffusersPipeline, + ) + + pipeline = object.__new__(Cosmos3OmniDiffusersPipeline) + nn.Module.__init__(pipeline) + pipeline.od_config = SimpleNamespace() + pipeline.device = torch.device("cpu") + pipeline.dtype = torch.float32 + pipeline.transformer = StubCosmos3Transformer(latent_channel_size=2) + pipeline.vae = StubCosmos3VAE(z_dim=2) + pipeline.vae_scale_factor_temporal = 4 + pipeline.vae_scale_factor_spatial = 8 + pipeline.scheduler = StubScheduler([9, 3], flow_shift=1.0) + pipeline._base_scheduler_config = pipeline.scheduler.config + pipeline._engine_init_flow_shift = 1.0 + pipeline._current_flow_shift = 1.0 + pipeline._guidance_scale = None + pipeline._num_timesteps = None + pipeline.progress_bar = passthrough_progress_bar + pipeline._sound_tokenizer = None + return pipeline + + return _make + + +def make_sampling_params(**overrides: Any) -> SimpleNamespace: + values = { + "height": None, + "width": None, + "num_frames": None, + "num_inference_steps": None, + "guidance_scale": None, + "generator": None, + "seed": 123, + "num_outputs_per_prompt": 1, + "frame_rate": None, + "resolved_frame_rate": None, + "max_sequence_length": None, + "extra_args": {}, + } + values.update(overrides) + return SimpleNamespace(**values) diff --git a/tests/diffusion/models/cosmos3/test_cosmos3_pipeline.py b/tests/diffusion/models/cosmos3/test_cosmos3_pipeline.py new file mode 100644 index 00000000000..52d47f8a2ed --- /dev/null +++ b/tests/diffusion/models/cosmos3/test_cosmos3_pipeline.py @@ -0,0 +1,394 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from __future__ import annotations + +from types import SimpleNamespace + +import pytest +import torch +from PIL import Image +from torch import nn + +from tests.diffusion.models.cosmos3.conftest import make_sampling_params + +pytestmark = [pytest.mark.core_model, pytest.mark.cpu, pytest.mark.diffusion] + + +def _ids(value: int) -> torch.Tensor: + return torch.tensor([[value]], dtype=torch.long) + + +def _mask() -> torch.Tensor: + return torch.ones(1, 1, dtype=torch.long) + + +def test_pipeline_registered_and_exported() -> None: + from vllm_omni.diffusion.cache.cache_dit_backend import CUSTOM_DIT_ENABLERS + from vllm_omni.diffusion.models import cosmos3 + from vllm_omni.diffusion.models.cosmos3.pipeline_cosmos3 import Cosmos3OmniDiffusersPipeline + from vllm_omni.diffusion.models.progress_bar import ProgressBarMixin + from vllm_omni.diffusion.registry import ( + _DIFFUSION_MODELS, + _DIFFUSION_POST_PROCESS_FUNCS, + _DIFFUSION_PRE_PROCESS_FUNCS, + ) + + assert issubclass(Cosmos3OmniDiffusersPipeline, nn.Module) + assert issubclass(Cosmos3OmniDiffusersPipeline, ProgressBarMixin) + assert Cosmos3OmniDiffusersPipeline.support_image_input is True + assert _DIFFUSION_MODELS["Cosmos3OmniDiffusersPipeline"] == ( + "cosmos3", + "pipeline_cosmos3", + "Cosmos3OmniDiffusersPipeline", + ) + assert _DIFFUSION_PRE_PROCESS_FUNCS["Cosmos3OmniDiffusersPipeline"] == "get_cosmos3_pre_process_func" + assert _DIFFUSION_POST_PROCESS_FUNCS["Cosmos3OmniDiffusersPipeline"] == "get_cosmos3_post_process_func" + assert "Cosmos3OmniDiffusersPipeline" in CUSTOM_DIT_ENABLERS + assert "Cosmos3OmniDiffusersPipeline" in cosmos3.__all__ + + +def test_preprocess_i2v_image_and_action_video_inputs() -> None: + from vllm_omni.diffusion.models.cosmos3.pipeline_cosmos3 import get_cosmos3_pre_process_func + + preprocess = get_cosmos3_pre_process_func(SimpleNamespace()) + i2v = SimpleNamespace( + prompts=[{"prompt": "A slow camera push.", "multi_modal_data": {"image": Image.new("RGB", (320, 160))}}], + sampling_params=SimpleNamespace(height=None, width=None, extra_args={}), + ) + + result = preprocess(i2v) + assert (result.sampling_params.height, result.sampling_params.width) == (672, 1344) + assert tuple(result.prompts[0]["additional_information"]["preprocessed_image"].shape[-2:]) == (672, 1344) + + frames = [Image.new("RGB", (8, 4), color) for color in ("red", "green", "blue")] + action = SimpleNamespace( + prompts=[{"prompt": "Move.", "multi_modal_data": {"video": frames}}], + sampling_params=SimpleNamespace(height=16, width=32, extra_args={"action_mode": "forward_dynamics"}), + ) + + additional = preprocess(action).prompts[0]["additional_information"] + assert tuple(additional["preprocessed_image"].shape) == (1, 3, 16, 32) + assert tuple(additional["preprocessed_video"].shape) == (1, 3, 3, 16, 32) + + +def test_postprocess_handles_image_video_audio_and_validation() -> None: + from vllm_omni.diffusion.models.cosmos3.pipeline_cosmos3 import get_cosmos3_post_process_func + + func = get_cosmos3_post_process_func(SimpleNamespace()) + video = torch.zeros(1, 3, 1, 4, 4) + + assert func(video, output_type="latent") is video + assert func({"image": video})[0].size == (4, 4) + assert "video" in func({"video": video}) + assert ( + func( + {"video": video, "audio": torch.ones(1, 2, 16), "audio_sample_rate": 48000}, + sampling_params=SimpleNamespace(extra_args={"resolved_frame_rate": 12}), + )["audio_sample_rate"] + == 48000 + ) + + with pytest.raises(ValueError, match="text-to-image postprocess expects"): + func({"image": torch.zeros(1, 3, 2, 4, 4)}) + with pytest.raises(ValueError, match="both image and video"): + func({"image": video, "video": video}) + + +def test_prompt_formatting_and_checkpoint_key_remap(make_cosmos3_pipeline) -> None: + from vllm_omni.diffusion.models.cosmos3.pipeline_cosmos3 import Cosmos3OmniDiffusersPipeline + + pipeline = make_cosmos3_pipeline() + captured: list[str] = [] + pipeline._tokenize_prompt = lambda text, *args, **kwargs: (captured.append(text) or _ids(len(captured)), _mask()) + + pipeline._format_and_tokenize_prompts( + "A robot", + "bad", + num_frames=48, + frame_rate=24, + height=720, + width=1280, + max_sequence_length=32, + sp=SimpleNamespace(extra_args={"negative_metadata_mode": "inverse"}), + use_system_prompt=True, + is_t2i=False, + ) + assert "The video is 2.0 seconds long" in captured[0] + assert "The video is not 2.0 seconds long" in captured[1] + + remaps = { + "embed_tokens.weight": "transformer.language_model.embed_tokens.weight", + "model.embed_tokens.weight": "transformer.language_model.embed_tokens.weight", + "norm.weight": "transformer.language_model.norm.weight", + "norm_moe_gen.weight": "transformer.norm_moe_gen.weight", + "proj_in.weight": "transformer.proj_in.weight", + "proj_out.bias": "transformer.proj_out.bias", + "layers.3.self_attn.to_q.weight": "transformer.language_model.layers.3.self_attn.to_q.weight", + "layers.3.self_attn.to_out.weight": "transformer.language_model.layers.3.self_attn.to_out.weight", + "layers.3.self_attn.norm_q.weight": "transformer.language_model.layers.3.self_attn.norm_q.weight", + "layers.3.self_attn.add_q_proj.weight": "transformer.gen_layers.3.cross_attention.to_q.weight", + "layers.3.self_attn.to_add_out.weight": "transformer.gen_layers.3.cross_attention.to_out.weight", + "layers.3.self_attn.norm_added_q.weight": "transformer.gen_layers.3.cross_attention.norm_q.weight", + "transformer.model.layers.3.self_attn.add_k_proj.weight": ( + "transformer.gen_layers.3.cross_attention.to_k.weight" + ), + } + assert {key: Cosmos3OmniDiffusersPipeline._remap_ckpt_key(key) for key in remaps} == remaps + + +def test_prepare_latents_for_video_image_sound_and_action(make_cosmos3_pipeline) -> None: + pipeline = make_cosmos3_pipeline() + latents = pipeline._prepare_latents(16, 24, 5, torch.Generator(device="cpu").manual_seed(0)) + assert latents.shape == (1, 2, 2, 2, 3) + + pipeline._encode_conditioning_video = lambda *args, **kwargs: torch.full((1, 2, 2, 2, 3), 5.0) + i2v_latents, velocity_mask, image_latent = pipeline._prepare_latents_i2v( + torch.zeros(1, 3, 16, 24), 16, 24, 5, torch.Generator(device="cpu").manual_seed(0) + ) + torch.testing.assert_close(i2v_latents[:, :, 0], torch.full((1, 2, 2, 3), 5.0)) + assert velocity_mask.tolist() == [[[[[0.0]], [[1.0]]]]] + assert image_latent.shape == (1, 2, 1, 2, 3) + + pipeline.transformer = pipeline.transformer.__class__(latent_channel_size=2, sound_gen=True, sound_dim=3) + pipeline._sound_tokenizer = SimpleNamespace( + sample_rate=10, + latent_ch=3, + hop_size=4, + decode=lambda x: torch.ones(x.shape[0], 2, 24), + ) + assert pipeline._resolve_sound_target_samples(SimpleNamespace(extra_args={"sound_duration": 2.0}), 9, 3.0) == ( + 20, + 2.0, + 10, + ) + sound_latents, latent_frames = pipeline._prepare_sound_latents(21, torch.Generator(device="cpu").manual_seed(0)) + assert (sound_latents.shape, latent_frames) == (torch.Size([1, 3, 6]), 6) + assert pipeline._decode_sound_latents(torch.zeros(1, 3, 6), target_audio_samples=21).shape == (1, 2, 21) + + pipeline.transformer = pipeline.transformer.__class__(action_gen=True, action_dim=4) + action, action_mask, clean, raw_dim = pipeline._prepare_action_latents( + mode="forward_dynamics", + action_chunk_size=2, + raw_action_dim=None, + generator=torch.Generator(device="cpu").manual_seed(0), + sp=SimpleNamespace(extra_args={"action": [[1.0, 2.0], [3.0, 4.0]]}), + ) + assert raw_dim == 2 + assert action_mask.tolist() == [[[0.0], [0.0]]] + torch.testing.assert_close(action, clean) + + +def test_diffuse_covers_cfg_i2v_and_multimodal_steps(make_cosmos3_pipeline) -> None: + pipeline = make_cosmos3_pipeline() + latents = torch.zeros(1, 2, 1, 1, 1) + + result = pipeline.diffuse( + latents=latents, + timesteps=torch.tensor([900, 100]), + cond_ids=_ids(2), + cond_mask=_mask(), + uncond_ids=_ids(1), + uncond_mask=_mask(), + guidance_scale=3.0, + shared_kwargs={"video_shape": (1, 1, 1), "fps": 24.0}, + guidance_interval=(500.0, 1000.0), + ) + assert [call["token"] for call in pipeline.transformer.calls] == [2, 1, 2] + torch.testing.assert_close(result, torch.full_like(latents, 6.0)) + + i2v = pipeline.diffuse( + latents=torch.zeros(1, 2, 2, 1, 1), + timesteps=torch.tensor([7]), + cond_ids=_ids(2), + cond_mask=_mask(), + uncond_ids=_ids(1), + uncond_mask=_mask(), + guidance_scale=1.0, + shared_kwargs={"video_shape": (2, 1, 1), "fps": 24.0}, + velocity_mask=torch.tensor([[[[[0.0]], [[1.0]]]]]), + image_latent=torch.full((1, 2, 1, 1, 1), 7.0), + ) + torch.testing.assert_close(i2v[:, :, 0:1], torch.full((1, 2, 1, 1, 1), 7.0)) + + pipeline.transformer = pipeline.transformer.__class__(latent_channel_size=2, action_gen=True, action_dim=4) + video_result, action_result = pipeline.diffuse( + latents=latents, + action_latents=torch.zeros(1, 3, 4), + action_velocity_mask=torch.ones(1, 3, 1), + action_condition_latents=torch.zeros(1, 3, 4), + timesteps=torch.tensor([7, 3]), + cond_ids=_ids(2), + cond_mask=_mask(), + uncond_ids=_ids(1), + uncond_mask=_mask(), + guidance_scale=1.0, + shared_kwargs={"video_shape": (1, 1, 1), "fps": 24.0, "action_domain_ids": torch.tensor([0])}, + ) + torch.testing.assert_close(video_result, torch.full_like(latents, 4.0)) + torch.testing.assert_close(action_result, torch.full((), 44.0).expand_as(action_result)) + + +class TestForwardRouting: + def _install_forward_stubs(self, pipeline): + captured: dict[str, object] = {"diffuse_calls": [], "prepare_calls": []} + + def fake_format(prompt, negative_prompt, num_frames, frame_rate, height, width, *args, **kwargs): + captured["format"] = { + "prompt": prompt, + "negative_prompt": negative_prompt, + "num_frames": num_frames, + "frame_rate": frame_rate, + "height": height, + "width": width, + "is_t2i": kwargs["is_t2i"], + } + return _ids(2), _mask(), _ids(1), _mask() + + def fake_prepare(height, width, num_frames, generator): + captured["prepare_calls"].append((height, width, num_frames, generator.initial_seed())) + return torch.zeros(1, 2, 1, 1, 1) + + def fake_diffuse(**kwargs): + captured["diffuse_calls"].append(kwargs) + outputs = [kwargs["latents"] + len(captured["diffuse_calls"])] + if kwargs.get("action_latents") is not None: + outputs.append(kwargs["action_latents"] + 3.0) + if kwargs.get("sound_latents") is not None: + outputs.append(kwargs["sound_latents"] + 2.0) + return outputs[0] if len(outputs) == 1 else tuple(outputs) + + pipeline._format_and_tokenize_prompts = fake_format + pipeline._prepare_latents = fake_prepare + pipeline._set_flow_shift = lambda target: captured.setdefault("flow_shifts", []).append(target) + + def fake_set_scheduler_timesteps(steps): + captured.setdefault("scheduler_steps", []).append(steps) + pipeline.scheduler.timesteps = torch.tensor([7]) + + pipeline._set_scheduler_timesteps = fake_set_scheduler_timesteps + pipeline.diffuse = fake_diffuse + pipeline._decode_latents = lambda latents: latents + return captured + + @pytest.mark.parametrize( + ("prompt", "sampling_params", "expected"), + [ + ( + {"prompt": "A painted robot", "modalities": ["image"]}, + make_sampling_params(num_outputs_per_prompt=2), + {"key": "image", "is_t2i": True, "flow": [3.0], "steps": [50, 50], "frames": 1}, + ), + ( + "A warehouse robot", + make_sampling_params(), + {"key": "video", "is_t2i": False, "flow": [1.0], "steps": [35], "frames": 189}, + ), + ], + ) + def test_forward_defaults_and_mode_selection( + self, + make_cosmos3_pipeline, + prompt, + sampling_params, + expected, + ) -> None: + pipeline = make_cosmos3_pipeline() + captured = self._install_forward_stubs(pipeline) + + output = pipeline.forward(SimpleNamespace(prompts=[prompt], sampling_params=sampling_params)) + + assert expected["key"] in output.output + assert captured["format"]["is_t2i"] is expected["is_t2i"] + assert captured["format"]["num_frames"] == expected["frames"] + assert captured["flow_shifts"] == expected["flow"] + assert captured["scheduler_steps"] == expected["steps"] + + def test_forward_i2v_sound_and_action_routes(self, make_cosmos3_pipeline) -> None: + pipeline = make_cosmos3_pipeline() + captured = self._install_forward_stubs(pipeline) + image_tensor = torch.zeros(1, 3, 16, 16) + velocity_mask = torch.ones(1, 1, 1, 1, 1) + + pipeline._prepare_latents_i2v = lambda *args, **kwargs: ( + torch.zeros(1, 2, 1, 1, 1), + velocity_mask, + torch.zeros(1, 2, 1, 1, 1), + ) + pipeline.forward( + SimpleNamespace( + prompts=[ + { + "prompt": "move", + "modalities": ["video"], + "additional_information": {"preprocessed_image": image_tensor}, + } + ], + sampling_params=make_sampling_params(height=16, width=16, num_frames=5), + ) + ) + assert captured["diffuse_calls"][-1]["shared_kwargs"]["noisy_frame_mask"] is velocity_mask + + pipeline.transformer = pipeline.transformer.__class__(latent_channel_size=2, sound_gen=True, sound_dim=3) + sound_latents = torch.zeros(1, 3, 4) + pipeline._resolve_sound_target_samples = lambda *args: (20, 2.0, 10) + pipeline._prepare_sound_latents = lambda *args: (sound_latents, 4) + pipeline._decode_sound_latents = lambda *args: torch.ones(1, 2, 20) + output = pipeline.forward( + SimpleNamespace( + prompts=[{"prompt": "A robot", "modalities": ["video"], "generate_sound": True}], + sampling_params=make_sampling_params(num_frames=9, frame_rate=3.0), + ) + ) + assert captured["diffuse_calls"][-1]["sound_latents"] is sound_latents + assert output.output["audio_sample_rate"] == 10 + + pipeline.transformer = pipeline.transformer.__class__(latent_channel_size=2, action_gen=True, action_dim=4) + output = pipeline.forward( + SimpleNamespace( + prompts=[ + { + "prompt": "Pick the block.", + "modalities": ["video"], + "additional_information": {"preprocessed_image": image_tensor}, + } + ], + sampling_params=make_sampling_params( + height=16, + width=16, + extra_args={ + "action_mode": "policy", + "action_chunk_size": 2, + "raw_action_dim": 2, + "domain_name": "bridge_orig_lerobot", + }, + ), + ) + ) + assert captured["diffuse_calls"][-1]["shared_kwargs"]["action_domain_ids"].tolist() == [7] + assert output.custom_output["action"].shape == (1, 2, 2) + + @pytest.mark.parametrize( + ("prompt", "sampling_params", "message"), + [ + (["one", "two"], make_sampling_params(), "single prompt"), + ([{"prompt": "one", "modalities": ["image", "video"]}], make_sampling_params(), "both image and video"), + ( + [{"prompt": "x", "modalities": ["image"], "generate_sound": True}], + make_sampling_params(), + "only for video", + ), + ], + ) + def test_forward_rejects_invalid_public_requests( + self, + make_cosmos3_pipeline, + prompt, + sampling_params, + message, + ) -> None: + pipeline = make_cosmos3_pipeline() + pipeline.transformer = pipeline.transformer.__class__(latent_channel_size=2, sound_gen=True, sound_dim=3) + + with pytest.raises(ValueError, match=message): + pipeline.forward(SimpleNamespace(prompts=prompt, sampling_params=sampling_params)) diff --git a/tests/diffusion/models/cosmos3/test_cosmos3_sound_tokenizer.py b/tests/diffusion/models/cosmos3/test_cosmos3_sound_tokenizer.py new file mode 100644 index 00000000000..47664c59e77 --- /dev/null +++ b/tests/diffusion/models/cosmos3/test_cosmos3_sound_tokenizer.py @@ -0,0 +1,226 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from __future__ import annotations + +import json +from pathlib import Path +from types import SimpleNamespace + +import pytest +import torch + +pytestmark = [pytest.mark.core_model, pytest.mark.cpu, pytest.mark.diffusion] + +DIFFUSERS_SOUND_TOKENIZER_CHECKPOINT_NAME = "diffusion_pytorch_model.safetensors" + + +class _FakeAVAEAudioTokenizer: + def __init__(self, **kwargs) -> None: + self.kwargs = kwargs + self.sample_rate = int(kwargs["sample_rate"]) + self.audio_channels = int(kwargs["audio_channels"]) + self.latent_ch = int(kwargs["io_channels"]) + self.temporal_compression_factor = int(kwargs["hop_size"]) + + def get_latent_num_samples(self, num_audio_samples: int) -> int: + return int(num_audio_samples) // self.temporal_compression_factor + + def get_audio_num_samples(self, num_latent_samples: int) -> int: + return int(num_latent_samples) * self.temporal_compression_factor + + def decode(self, latents: torch.Tensor) -> torch.Tensor: + return torch.zeros(latents.shape[0], self.audio_channels, 8) + + +def _write_component(root: Path, config: dict | None = None, checkpoint_name: str | None = None) -> Path: + tokenizer_dir = root / "sound_tokenizer" + tokenizer_dir.mkdir(parents=True) + if checkpoint_name: + (tokenizer_dir / checkpoint_name).write_bytes(b"stub") + (tokenizer_dir / "config.json").write_text(json.dumps(config or {}), encoding="utf-8") + return tokenizer_dir + + +def _patch_fake_avae(monkeypatch: pytest.MonkeyPatch, created: dict) -> None: + from vllm_omni.diffusion.models.cosmos3 import sound_tokenizer + + class FakeAVAE(_FakeAVAEAudioTokenizer): + def __init__(self, **kwargs) -> None: + created.update(kwargs) + super().__init__(**kwargs) + + monkeypatch.setattr(sound_tokenizer, "Cosmos3AVAEAudioTokenizer", FakeAVAE) + monkeypatch.setattr(sound_tokenizer, "get_local_device", lambda: torch.device("cpu")) + + +def test_from_config_loads_local_diffusers_component(tmp_path, monkeypatch: pytest.MonkeyPatch) -> None: + from vllm_omni.diffusion.models.cosmos3 import sound_tokenizer + + model_dir = tmp_path / "model" + tokenizer_dir = _write_component(model_dir, checkpoint_name=DIFFUSERS_SOUND_TOKENIZER_CHECKPOINT_NAME) + created = {} + _patch_fake_avae(monkeypatch, created) + + tokenizer = sound_tokenizer.Cosmos3SoundTokenizer.from_config( + SimpleNamespace( + model=str(model_dir), + custom_pipeline_args={"sound_sample_rate": 32000, "sound_hop_size": 800, "sound_dim": 3}, + dtype=torch.float32, + ) + ) + + assert created["checkpoint_path"] == str(tokenizer_dir / DIFFUSERS_SOUND_TOKENIZER_CHECKPOINT_NAME) + assert created["config_path"] == str(tokenizer_dir / "config.json") + assert (tokenizer.sample_rate, tokenizer.latent_ch, tokenizer.hop_size) == (32000, 3, 800) + + +def test_from_config_downloads_component_from_hf_repo(tmp_path, monkeypatch: pytest.MonkeyPatch) -> None: + import huggingface_hub + + from vllm_omni.diffusion.models.cosmos3 import sound_tokenizer + + cache_dir = tmp_path / "hf" + _write_component(cache_dir, checkpoint_name=DIFFUSERS_SOUND_TOKENIZER_CHECKPOINT_NAME) + calls = [] + created = {} + _patch_fake_avae(monkeypatch, created) + + def fake_snapshot_download(repo_id: str, *, revision: str | None, allow_patterns: list[str]) -> str: + calls.append((repo_id, revision, allow_patterns)) + return str(cache_dir) + + monkeypatch.setattr(huggingface_hub, "snapshot_download", fake_snapshot_download) + + sound_tokenizer.Cosmos3SoundTokenizer.from_config( + SimpleNamespace( + model="nvidia/cosmos3", + revision="test-rev", + custom_pipeline_args={"sound_sample_rate": 32000, "sound_hop_size": 800, "sound_dim": 3}, + dtype=torch.float32, + ) + ) + + assert created["checkpoint_path"].endswith(DIFFUSERS_SOUND_TOKENIZER_CHECKPOINT_NAME) + assert calls == [ + ( + "nvidia/cosmos3", + "test-rev", + ["sound_tokenizer/config.json", f"sound_tokenizer/{DIFFUSERS_SOUND_TOKENIZER_CHECKPOINT_NAME}"], + ) + ] + + +@pytest.mark.parametrize( + ("checkpoint_name", "message"), + [ + (None, "no AVAE sound tokenizer checkpoint"), + ("model.safetensors", DIFFUSERS_SOUND_TOKENIZER_CHECKPOINT_NAME), + ], +) +def test_default_component_requires_diffusers_checkpoint_name(tmp_path, checkpoint_name, message) -> None: + from vllm_omni.diffusion.models.cosmos3 import sound_tokenizer + + model_dir = tmp_path / "model" + _write_component(model_dir, checkpoint_name=checkpoint_name) + + with pytest.raises(ValueError, match=message): + sound_tokenizer.Cosmos3SoundTokenizer.from_config( + SimpleNamespace(model=str(model_dir), custom_pipeline_args={}, dtype=torch.float32) + ) + + +def test_component_config_precedence_and_conflict_detection(tmp_path, monkeypatch: pytest.MonkeyPatch) -> None: + from vllm_omni.diffusion.models.cosmos3 import sound_tokenizer + + component_config = { + "sampling_rate": 48000, + "dec_out_channels": 2, + "vocoder_input_dim": 64, + "hop_size": 1920, + } + model_dir = tmp_path / "model" + _write_component(model_dir, component_config, DIFFUSERS_SOUND_TOKENIZER_CHECKPOINT_NAME) + created = {} + _patch_fake_avae(monkeypatch, created) + + tokenizer = sound_tokenizer.Cosmos3SoundTokenizer.from_config( + SimpleNamespace( + model=str(model_dir), + custom_pipeline_args={ + "sound_normalize_latents": True, + "sound_normalization_type": "tanh", + "sound_tanh_input_scale": 2.0, + }, + model_config={ + "sound_tokenizer": { + "sample_rate": 32000, + "audio_channels": 1, + "io_channels": 3, + "hop_size": 800, + "normalize_latents": False, + "normalization_type": "none", + } + }, + dtype=torch.float32, + ) + ) + + assert (created["sample_rate"], created["audio_channels"], created["io_channels"], created["hop_size"]) == ( + 48000, + 2, + 64, + 1920, + ) + assert (created["normalize_latents"], created["normalization_type"], created["tanh_input_scale"]) == ( + True, + "tanh", + 2.0, + ) + assert (tokenizer.sample_rate, tokenizer.latent_ch, tokenizer.hop_size) == (48000, 64, 1920) + + with pytest.raises(ValueError, match=r"sample_rate.*48000.*32000"): + sound_tokenizer.Cosmos3SoundTokenizer.from_config( + SimpleNamespace( + model=str(model_dir), + custom_pipeline_args={"sound_sample_rate": 32000}, + dtype=torch.float32, + ) + ) + + +def test_avae_uses_diffusers_decoder_state_dict_layout(tmp_path) -> None: + from safetensors.torch import save_file + + from vllm_omni.diffusion.models.cosmos3.audio_tokenizer import avae + + config = { + "sampling_rate": 8000, + "hop_size": 2, + "dec_dim": 4, + "dec_c_mults": [1], + "dec_strides": [2], + "dec_out_channels": 1, + "vocoder_input_dim": 2, + "normalization_type": "none", + } + checkpoint_path = tmp_path / DIFFUSERS_SOUND_TOKENIZER_CHECKPOINT_NAME + config_path = tmp_path / "config.json" + config_path.write_text(json.dumps(config), encoding="utf-8") + + decoder = avae.OobleckDecoder(4, 2, 1, [2], [1]) + save_file({f"decoder.{key}": value for key, value in decoder.state_dict().items()}, str(checkpoint_path)) + + tokenizer = avae.Cosmos3AVAEAudioTokenizer( + checkpoint_path=checkpoint_path, + config_path=config_path, + dtype=torch.float32, + device="cpu", + ) + + keys = set(tokenizer.state_dict()) + assert {"decoder.conv1.weight_g", "decoder.block.0.conv_t1.weight_g", "decoder.conv2.weight_g"} <= keys + assert not any(key.startswith(("decoder.layers.", "model.decoder.")) for key in keys) + assert tokenizer.decode(torch.zeros(1, 2, 3)).shape == (1, 1, 6) + with pytest.raises(NotImplementedError, match="decoder-only"): + tokenizer.encode(torch.zeros(1, 1, 6)) diff --git a/tests/diffusion/models/cosmos3/test_cosmos3_transformer.py b/tests/diffusion/models/cosmos3/test_cosmos3_transformer.py new file mode 100644 index 00000000000..d2f22b81760 --- /dev/null +++ b/tests/diffusion/models/cosmos3/test_cosmos3_transformer.py @@ -0,0 +1,292 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from __future__ import annotations + +from types import SimpleNamespace + +import pytest +import torch +from torch import nn + +pytestmark = [pytest.mark.core_model, pytest.mark.cpu, pytest.mark.diffusion] + + +def _tiny_cosmos3_config(**overrides): + config = { + "hidden_size": 8, + "num_hidden_layers": 0, + "num_attention_heads": 2, + "num_key_value_heads": 2, + "head_dim": 4, + "intermediate_size": 16, + "vocab_size": 32, + "latent_patch_size": 1, + "latent_channel": 2, + "rope_scaling": {"mrope_section": [1, 1, 0]}, + } + config.update(overrides) + return config + + +def test_mrope_position_ids_cover_text_video_sound_and_action() -> None: + from vllm_omni.diffusion.models.cosmos3.transformer_cosmos3 import ( + compute_mrope_position_ids_action, + compute_mrope_position_ids_sound, + compute_mrope_position_ids_text, + compute_mrope_position_ids_vision, + ) + + text_ids, text_offset = compute_mrope_position_ids_text(num_tokens=3, temporal_offset=5) + assert text_ids.tolist() == [[5, 6, 7], [5, 6, 7], [5, 6, 7]] + assert text_offset == 8 + + vision_ids, vision_offset = compute_mrope_position_ids_vision(2, 2, 3, temporal_offset=10, fps=None) + assert vision_ids.shape == (3, 12) + assert vision_ids[0].tolist() == [10] * 6 + [11] * 6 + assert vision_offset == 12 + + modulated_ids, modulated_offset = compute_mrope_position_ids_vision( + 2, + 1, + 1, + temporal_offset=10, + fps=12.0, + base_fps=24.0, + temporal_compression_factor=4, + ) + torch.testing.assert_close(modulated_ids[0], torch.tensor([10.0, 12.0])) + assert modulated_offset == 13 + + sound_ids, sound_offset = compute_mrope_position_ids_sound(3, temporal_offset=10, sound_latent_fps=25.0) + torch.testing.assert_close(sound_ids[0], torch.tensor([10.0, 10.96, 11.92])) + assert sound_offset == 12 + + action_ids, action_offset = compute_mrope_position_ids_action(3, temporal_offset=10, action_fps=None) + assert action_ids.tolist() == [[11, 12, 13], [0, 0, 0], [0, 0, 0]] + assert action_offset == 14 + + +@pytest.mark.parametrize( + ("key", "value"), + [ + ("qk_norm_for_diffusion", False), + ("qk_norm_for_text", False), + ("position_embedding_type", "rotary"), + ("unified_3d_mrope_reset_spatial_ids", False), + ("joint_attn_implementation", "one_way"), + ], +) +def test_validate_supported_config_rejects_unsupported_flags(key: str, value) -> None: + from vllm_omni.diffusion.models.cosmos3.transformer_cosmos3 import Cosmos3VFMTransformer + + with pytest.raises(ValueError, match=f"{key}="): + Cosmos3VFMTransformer._validate_supported_config({key: value}) + Cosmos3VFMTransformer._validate_supported_config({}) + Cosmos3VFMTransformer._validate_supported_config(None) + + +def test_transformer_sharding_offload_and_patch_round_trip_contracts() -> None: + from vllm_omni.diffusion.models.cosmos3.transformer_cosmos3 import Cosmos3VFMTransformer + + model = object.__new__(Cosmos3VFMTransformer) + nn.Module.__init__(model) + model.language_model = nn.Module() + model.language_model.layers = nn.ModuleList([nn.Linear(2, 2) for _ in range(2)]) + model.gen_layers = nn.ModuleList([nn.Linear(2, 2)]) + model.norm_moe_gen = nn.LayerNorm(2) + + matched = [ + name + for name, module in model.named_modules() + if any(condition(name, module) for condition in model._hsdp_shard_conditions) + ] + assert matched == ["language_model.layers.0", "language_model.layers.1", "gen_layers.0"] + assert Cosmos3VFMTransformer._layerwise_offload_blocks_attrs == ["gen_layers"] + assert Cosmos3VFMTransformer._repeated_blocks == ["Cosmos3GenDecoderLayer"] + + model.latent_patch_size = 2 + model.latent_channel_size = 3 + latents = torch.arange(1 * 3 * 1 * 3 * 5, dtype=torch.float32).reshape(1, 3, 1, 3, 5) + torch.testing.assert_close(model.unpatchify(model.patchify(latents, t=1, h=3, w=5), t=1, h=3, w=5), latents) + + +def test_forward_returns_video_prediction(monkeypatch: pytest.MonkeyPatch) -> None: + from vllm_omni.diffusion.models.cosmos3 import transformer_cosmos3 + + monkeypatch.setattr(transformer_cosmos3, "_get_ulysses_state", lambda: (1, 0, None)) + + output = transformer_cosmos3.Cosmos3VFMTransformer( + SimpleNamespace(tf_model_config=_tiny_cosmos3_config(), dtype=torch.float32) + )( + hidden_states=torch.zeros(1, 2, 1, 2, 2), + timestep=torch.tensor([1.0]), + text_ids=torch.tensor([[1, 2]], dtype=torch.long), + text_mask=torch.ones(1, 2, dtype=torch.long), + video_shape=(1, 2, 2), + fps=24.0, + sound_latents=torch.zeros(1, 3, 4), + ) + + assert tuple(output.shape) == (1, 2, 1, 2, 2) + + +def test_sound_and_action_modules_follow_config() -> None: + from vllm_omni.diffusion.models.cosmos3.transformer_cosmos3 import Cosmos3VFMTransformer + + tiny = _tiny_cosmos3_config() + no_modal = Cosmos3VFMTransformer(SimpleNamespace(tf_model_config=tiny, dtype=torch.float32)) + with_sound = Cosmos3VFMTransformer( + SimpleNamespace( + tf_model_config={**tiny, "sound_gen": True}, + model_config={"sound_tokenizer": {"io_channels": 5, "sample_rate": 32000, "hop_size": 800}}, + custom_pipeline_args={}, + dtype=torch.float32, + ) + ) + with_action = Cosmos3VFMTransformer( + SimpleNamespace( + tf_model_config={**tiny, "action_gen": True, "max_action_dim": 6, "num_embodiment_domains": 9}, + dtype=torch.float32, + ) + ) + + assert no_modal.sound_gen is False + assert no_modal.action_gen is False + assert not hasattr(no_modal, "audio_proj_in") + assert not hasattr(no_modal, "action_proj_in") + assert with_sound.sound_dim == 5 + assert with_sound.sound_latent_fps == 40.0 + assert with_sound.audio_proj_in.in_features == 5 + assert with_action.action_dim == 6 + assert with_action.action_proj_in.num_domains == 9 + + +def test_sound_and_action_pack_unpack_validate_shapes() -> None: + from vllm_omni.diffusion.models.cosmos3.transformer_cosmos3 import Cosmos3VFMTransformer + + model = object.__new__(Cosmos3VFMTransformer) + nn.Module.__init__(model) + model.sound_dim = 3 + model.action_dim = 3 + + sound = torch.arange(2 * 3 * 4, dtype=torch.float32).reshape(2, 3, 4) + action = torch.arange(2 * 4 * 3, dtype=torch.float32).reshape(2, 4, 3) + torch.testing.assert_close(model.unpack_sound(model.pack_sound(sound)), sound) + torch.testing.assert_close(model.unpack_action(model.pack_action(action)), action) + + with pytest.raises(ValueError, match="channel mismatch"): + model.pack_sound(torch.zeros(1, 4, 2)) + with pytest.raises(ValueError, match="dimension mismatch"): + model.pack_action(torch.zeros(1, 2, 4)) + + +@pytest.mark.parametrize( + ("config", "extra_kwargs", "expected_shapes"), + [ + ( + _tiny_cosmos3_config(sound_gen=True, sound_dim=3, sound_latent_fps=24.0), + {"sound_latents": torch.zeros(1, 3, 4)}, + [(1, 2, 1, 2, 2), (1, 3, 4)], + ), + ( + _tiny_cosmos3_config(action_gen=True, max_action_dim=3, num_embodiment_domains=4), + {"action_latents": torch.zeros(1, 5, 3), "action_domain_ids": torch.tensor([2])}, + [(1, 2, 1, 2, 2), (1, 5, 3)], + ), + ], +) +def test_forward_returns_video_plus_optional_modality_predictions(config, extra_kwargs, expected_shapes) -> None: + from vllm_omni.diffusion.models.cosmos3.transformer_cosmos3 import Cosmos3VFMTransformer + + output = Cosmos3VFMTransformer(SimpleNamespace(tf_model_config=config, dtype=torch.float32))( + hidden_states=torch.zeros(1, 2, 1, 2, 2), + timestep=torch.tensor([1.0]), + text_ids=torch.tensor([[1, 2]], dtype=torch.long), + text_mask=torch.ones(1, 2, dtype=torch.long), + video_shape=(1, 2, 2), + fps=24.0, + action_noisy_mask=torch.ones(1, 5, 1), + **extra_kwargs, + ) + + assert isinstance(output, tuple) + assert [tuple(tensor.shape) for tensor in output] == expected_shapes + + +def test_forward_with_sound_ulysses_error_mentions_combined_sequence(monkeypatch: pytest.MonkeyPatch) -> None: + import vllm_omni.diffusion.models.cosmos3.transformer_cosmos3 as cosmos3_module + + model = cosmos3_module.Cosmos3VFMTransformer( + SimpleNamespace(tf_model_config=_tiny_cosmos3_config(sound_gen=True, sound_dim=3), dtype=torch.float32) + ) + monkeypatch.setattr(cosmos3_module, "_get_ulysses_state", lambda: (2, 0, None)) + + with pytest.raises(ValueError, match=r"GEN sequence length \(3 = video tokens 2 \+ sound tokens 1\)"): + model( + hidden_states=torch.zeros(1, 2, 1, 1, 2), + timestep=torch.tensor([1.0]), + text_ids=torch.tensor([[1, 2]], dtype=torch.long), + text_mask=torch.ones(1, 2, dtype=torch.long), + video_shape=(1, 1, 2), + fps=24.0, + sound_latents=torch.zeros(1, 3, 1), + ) + + +def test_compute_rope_freqs_places_text_video_action_and_sound_positions() -> None: + from vllm_omni.diffusion.models.cosmos3.transformer_cosmos3 import Cosmos3VFMTransformer + + class FakeRotary: + def __init__(self) -> None: + self.position_ids: list[torch.Tensor] = [] + + def __call__(self, x, position_ids): + del x + self.position_ids.append(position_ids.detach().cpu()) + batch, seq = position_ids.shape[1], position_ids.shape[2] + return torch.zeros(batch, seq, 4), torch.ones(batch, seq, 4) + + rotary = FakeRotary() + model = object.__new__(Cosmos3VFMTransformer) + nn.Module.__init__(model) + model.language_model = SimpleNamespace(rotary_emb=rotary) + model.temporal_modality_margin = 100 + model.base_fps = 24.0 + model.temporal_compression_factor = 4 + model.temporal_compression_factor_sound = 1 + model.sound_latent_fps = 25.0 + model.enable_fps_modulation = False + + freqs_und, freqs_gen = model._compute_rope_freqs( + text_mask=torch.tensor([[1, 1, 0], [1, 0, 0]], dtype=torch.long), + t=2, + hp=1, + wp=1, + fps=24.0, + device=torch.device("cpu"), + dtype=torch.float32, + ) + text_pos, vision_pos = rotary.position_ids + assert text_pos[:, 0, :].tolist() == [[0, 1, 0], [0, 1, 0], [0, 1, 0]] + assert vision_pos[0, 0].tolist() == [102, 103] + assert freqs_und[0].shape == (2, 3, 1, 4) + assert freqs_gen[0].shape == (2, 2, 1, 4) + + rotary.position_ids.clear() + model._compute_rope_freqs( + text_mask=torch.tensor([[1, 1]], dtype=torch.long), + t=2, + hp=1, + wp=1, + fps=24.0, + device=torch.device("cpu"), + dtype=torch.float32, + t_action=2, + action_start_frame_offset=1, + t_sound=1, + ) + + _, gen_pos = rotary.position_ids + assert gen_pos.shape == (3, 1, 5) + assert gen_pos[0, 0].tolist() == [102, 103, 103, 104, 102] diff --git a/tests/diffusion/test_diffusion_engine.py b/tests/diffusion/test_diffusion_engine.py index 2d44f6a6815..eb862a91ae0 100644 --- a/tests/diffusion/test_diffusion_engine.py +++ b/tests/diffusion/test_diffusion_engine.py @@ -10,6 +10,9 @@ from typing import Any import pytest +import torch + +from vllm_omni.diffusion.diffusion_engine import _move_tensor_tree_to_cpu @dataclass @@ -63,6 +66,80 @@ def update_from_output(self, sched_output, runner_output): return [req.request_id for req in sched_output.scheduled_new_reqs] +@pytest.mark.core_model +@pytest.mark.diffusion +@pytest.mark.cpu +def test_move_tensor_tree_keeps_cpu_tensor_identity() -> None: + tensor = torch.arange(8, dtype=torch.float32) + + moved = _move_tensor_tree_to_cpu(tensor) + + assert moved is tensor + + +@pytest.mark.core_model +@pytest.mark.diffusion +@pytest.mark.cpu +def test_move_tensor_tree_preserves_nested_structure_without_mutating_input() -> None: + tensor = torch.arange(4, dtype=torch.float32) + nested_tensor = torch.arange(6, dtype=torch.float32).reshape(2, 3) + sentinel = object() + payload = { + "tensor": tensor, + "list": [nested_tensor, sentinel], + "tuple": ({"inner": tensor}, "metadata"), + "scalar": 3, + } + + moved = _move_tensor_tree_to_cpu(payload) + + assert moved is not payload + assert set(moved) == {"tensor", "list", "tuple", "scalar"} + assert moved["list"] is not payload["list"] + assert moved["tuple"] is not payload["tuple"] + assert moved["tuple"][0] is not payload["tuple"][0] + assert moved["tensor"] is tensor + assert moved["list"][0] is nested_tensor + assert moved["list"][1] is sentinel + assert moved["tuple"][0]["inner"] is tensor + assert moved["tuple"][1] == "metadata" + assert moved["scalar"] == 3 + assert payload["list"][0] is nested_tensor + assert payload["list"][1] is sentinel + assert payload["tuple"][0]["inner"] is tensor + assert payload["tuple"][1] == "metadata" + + +@pytest.mark.core_model +@pytest.mark.diffusion +@pytest.mark.cpu +def test_move_tensor_tree_returns_non_tensor_values_unchanged() -> None: + value = object() + + moved = _move_tensor_tree_to_cpu(value) + + assert moved is value + + +@pytest.mark.diffusion +@pytest.mark.cuda +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") +def test_move_tensor_tree_moves_nested_cuda_tensors_to_cpu() -> None: + tensor = torch.arange(8, dtype=torch.float32, device="cuda") + other = torch.arange(4, dtype=torch.int64, device="cuda") + payload = {"tensor": tensor, "items": [other, ("keep", tensor)]} + + moved = _move_tensor_tree_to_cpu(payload) + + assert moved["tensor"].device.type == "cpu" + assert moved["items"][0].device.type == "cpu" + assert moved["items"][1][1].device.type == "cpu" + torch.testing.assert_close(moved["tensor"], tensor.cpu()) + torch.testing.assert_close(moved["items"][0], other.cpu()) + torch.testing.assert_close(moved["items"][1][1], tensor.cpu()) + assert moved["items"][1][0] == "keep" + + @pytest.mark.asyncio async def test_async_add_req_and_wait_for_response(): from vllm_omni.diffusion.diffusion_engine import DiffusionEngine diff --git a/tests/diffusion/test_diffusion_ipc.py b/tests/diffusion/test_diffusion_ipc.py new file mode 100644 index 00000000000..b7995e51601 --- /dev/null +++ b/tests/diffusion/test_diffusion_ipc.py @@ -0,0 +1,143 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import contextlib + +import pytest +import torch + +from vllm_omni.diffusion.data import DiffusionOutput +from vllm_omni.diffusion.ipc import ( + _SHM_TENSOR_THRESHOLD, + _pack_value_if_large, + _unpack_if_shm_handle, + pack_diffusion_output_shm, + unpack_diffusion_output_shm, +) + +pytestmark = [pytest.mark.core_model, pytest.mark.diffusion, pytest.mark.cpu] + + +def _large_numel(dtype: torch.dtype) -> int: + return (_SHM_TENSOR_THRESHOLD // torch.empty((), dtype=dtype).element_size()) + 1 + + +def _cleanup_shm_handle(value: object) -> None: + if isinstance(value, dict) and value.get("__tensor_shm__"): + with contextlib.suppress(FileNotFoundError): + _unpack_if_shm_handle(value) + + +def test_diffusion_output_dict_tensors_round_trip_through_shm() -> None: + image = torch.arange(300_000, dtype=torch.float32) + video = torch.arange(300_000, dtype=torch.float32) * 2 + output = DiffusionOutput(output={"image": image, "video": video, "metadata": {"keep": "inline"}}) + + pack_diffusion_output_shm(output) + + assert output.output["image"]["__tensor_shm__"] is True + assert output.output["video"]["__tensor_shm__"] is True + assert output.output["metadata"] == {"keep": "inline"} + + unpack_diffusion_output_shm(output) + + torch.testing.assert_close(output.output["image"], image) + torch.testing.assert_close(output.output["video"], video) + assert output.output["metadata"] == {"keep": "inline"} + + +def test_pack_value_keeps_tensor_at_threshold_inline() -> None: + tensor = torch.arange( + _SHM_TENSOR_THRESHOLD // torch.empty((), dtype=torch.float32).element_size(), + dtype=torch.float32, + ) + + packed = _pack_value_if_large(tensor) + + assert packed is tensor + + +def test_pack_value_packs_large_tensor_and_round_trips() -> None: + tensor = torch.arange(_large_numel(torch.float32), dtype=torch.float32) + packed = _pack_value_if_large(tensor) + + try: + assert isinstance(packed, dict) + assert packed["__tensor_shm__"] is True + assert packed["shape"] == [tensor.numel()] + assert packed["torch_dtype"] == "torch.float32" + + unpacked = _unpack_if_shm_handle(packed) + assert isinstance(unpacked, torch.Tensor) + torch.testing.assert_close(unpacked, tensor) + finally: + _cleanup_shm_handle(packed) + + +def test_pack_value_recurses_nested_dicts_without_mutating_other_values() -> None: + large = torch.arange(_large_numel(torch.float32), dtype=torch.float32) + small = torch.arange(8, dtype=torch.float32) + list_tensor = torch.arange(_large_numel(torch.float32), dtype=torch.float32) + payload = { + "media": { + "large": large, + "small": small, + }, + "list_value": [list_tensor], + "metadata": {"prompt": "keep inline"}, + } + + packed = _pack_value_if_large(payload) + + try: + assert packed is not payload + assert packed["media"] is not payload["media"] + assert packed["media"]["large"]["__tensor_shm__"] is True + assert packed["media"]["small"] is small + assert packed["list_value"] is payload["list_value"] + assert packed["list_value"][0] is list_tensor + assert packed["metadata"] == {"prompt": "keep inline"} + + unpacked_large = _unpack_if_shm_handle(packed["media"]["large"]) + assert isinstance(unpacked_large, torch.Tensor) + torch.testing.assert_close(unpacked_large, large) + finally: + handle = packed.get("media", {}).get("large") if isinstance(packed, dict) else None + _cleanup_shm_handle(handle) + + +def test_pack_value_preserves_dtype_shape_and_values_for_bfloat16() -> None: + tensor = torch.arange(_large_numel(torch.bfloat16), dtype=torch.float32).to(torch.bfloat16).reshape(1, -1) + packed = _pack_value_if_large(tensor) + + try: + assert isinstance(packed, dict) + assert packed["__tensor_shm__"] is True + assert packed["shape"] == list(tensor.shape) + assert packed["torch_dtype"] == "torch.bfloat16" + assert packed["numpy_dtype"] == "float32" + + unpacked = _unpack_if_shm_handle(packed) + assert isinstance(unpacked, torch.Tensor) + assert unpacked.dtype == torch.bfloat16 + torch.testing.assert_close(unpacked, tensor) + finally: + _cleanup_shm_handle(packed) + + +def test_pack_value_packs_non_contiguous_large_tensor_values() -> None: + tensor = torch.arange(_large_numel(torch.float32) * 2, dtype=torch.float32).reshape(-1, 2)[:, 0] + assert not tensor.is_contiguous() + + packed = _pack_value_if_large(tensor) + + try: + assert isinstance(packed, dict) + assert packed["__tensor_shm__"] is True + assert packed["shape"] == list(tensor.shape) + + unpacked = _unpack_if_shm_handle(packed) + assert isinstance(unpacked, torch.Tensor) + torch.testing.assert_close(unpacked, tensor) + finally: + _cleanup_shm_handle(packed) diff --git a/tests/e2e/accuracy/test_cosmos3_similarity.py b/tests/e2e/accuracy/test_cosmos3_similarity.py new file mode 100644 index 00000000000..ff2350096e0 --- /dev/null +++ b/tests/e2e/accuracy/test_cosmos3_similarity.py @@ -0,0 +1,132 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from __future__ import annotations + +import base64 +import io +import json +import os +from pathlib import Path + +import pytest +import requests +import torch +from PIL import Image + +from tests.e2e.accuracy.helpers import model_output_dir +from tests.helpers.mark import hardware_test +from tests.helpers.runtime import OmniServer + +pytestmark = [pytest.mark.full_model, pytest.mark.diffusion] + +MODEL_ENV_VAR = "VLLM_TEST_COSMOS3_MODEL" +MODEL_ID = "cosmos3" +PROMPT = "A small warehouse robot moves a blue box across a clean floor." +NEGATIVE_PROMPT = "blurry, distorted, low quality" +SEED = 42 +WIDTH = HEIGHT = 256 +NUM_INFERENCE_STEPS = 2 + + +def _model_name() -> str: + model = os.environ.get(MODEL_ENV_VAR) + if not model: + pytest.skip(f"Set {MODEL_ENV_VAR} to run Cosmos3 full-model smoke tests.") + if not torch.cuda.is_available(): + pytest.skip("Cosmos3 full-model smoke tests require CUDA.") + return model + + +def _server_args() -> list[str]: + return [ + "--num-gpus", + "1", + "--model-class-name", + "Cosmos3OmniDiffusersPipeline", + "--stage-init-timeout", + "900", + "--init-timeout", + "1200", + ] + + +def _image_data_url(image: Image.Image) -> str: + buf = io.BytesIO() + image.save(buf, format="PNG") + return f"data:image/png;base64,{base64.b64encode(buf.getvalue()).decode('ascii')}" + + +@pytest.mark.benchmark +@hardware_test(res={"cuda": "H100"}, num_cards=1) +def test_cosmos3_t2i_serving_smoke(accuracy_artifact_root: Path) -> None: + output_dir = model_output_dir(accuracy_artifact_root, MODEL_ID) + with OmniServer(_model_name(), _server_args(), use_omni=True) as server: + response = requests.post( + f"http://{server.host}:{server.port}/v1/images/generations", + json={ + "model": server.model, + "prompt": PROMPT, + "negative_prompt": NEGATIVE_PROMPT, + "size": f"{WIDTH}x{HEIGHT}", + "n": 1, + "response_format": "b64_json", + "num_inference_steps": NUM_INFERENCE_STEPS, + "guidance_scale": 1.0, + "seed": SEED, + }, + timeout=1800, + ) + + response.raise_for_status() + payload = response.json() + assert len(payload["data"]) == 1 + image = Image.open(io.BytesIO(base64.b64decode(payload["data"][0]["b64_json"]))).convert("RGB") + image.save(output_dir / "cosmos3_t2i.png") + assert image.size == (WIDTH, HEIGHT) + + +@pytest.mark.parametrize( + ("name", "prompt", "num_frames", "image_reference"), + [ + ("t2v", PROMPT, "1", None), + ( + "i2v", + "The blue rectangle moves slowly forward.", + "5", + Image.new("RGB", (96, 64), color=(40, 80, 160)), + ), + ], +) +@pytest.mark.benchmark +@hardware_test(res={"cuda": "H100"}, num_cards=1) +def test_cosmos3_video_serving_smoke( + accuracy_artifact_root: Path, + name: str, + prompt: str, + num_frames: str, + image_reference: Image.Image | None, +) -> None: + output_dir = model_output_dir(accuracy_artifact_root, MODEL_ID) + data = { + "model": "", + "prompt": prompt, + "negative_prompt": NEGATIVE_PROMPT, + "size": f"{WIDTH}x{HEIGHT}", + "num_frames": num_frames, + "fps": "1", + "num_inference_steps": str(NUM_INFERENCE_STEPS), + "guidance_scale": "1.0", + "seed": str(SEED), + } + if image_reference is not None: + data["image_reference"] = json.dumps({"image_url": _image_data_url(image_reference)}) + + with OmniServer(_model_name(), _server_args(), use_omni=True) as server: + data["model"] = server.model + response = requests.post(f"http://{server.host}:{server.port}/v1/videos/sync", data=data, timeout=1800) + + response.raise_for_status() + assert response.headers["content-type"].startswith("video/mp4") + assert response.content + (output_dir / f"cosmos3_{name}.mp4").write_bytes(response.content) diff --git a/tests/entrypoints/openai_api/test_image_server.py b/tests/entrypoints/openai_api/test_image_server.py index 2556ac71d34..f85b43721cf 100644 --- a/tests/entrypoints/openai_api/test_image_server.py +++ b/tests/entrypoints/openai_api/test_image_server.py @@ -584,6 +584,7 @@ def test_generate_single_image(test_client): img_bytes = base64.b64decode(data["data"][0]["b64_json"]) img = Image.open(io.BytesIO(img_bytes)) assert img.size == (64, 64) # Our mock returns 64x64 images + assert test_client.app.state.engine_client.captured_prompt["modalities"] == ["image"] def test_generate_images_async_omni_sampling_params(async_omni_test_client): diff --git a/tests/entrypoints/openai_api/test_video_server.py b/tests/entrypoints/openai_api/test_video_server.py index a29f4493c28..dcff3054b38 100644 --- a/tests/entrypoints/openai_api/test_video_server.py +++ b/tests/entrypoints/openai_api/test_video_server.py @@ -243,6 +243,7 @@ def _fake_encode(video, fps, audio=None, audio_sample_rate=None, **kwargs): _wait_for_status(test_client, video_id, VideoGenerationStatus.COMPLETED.value) engine = test_client.app.state.openai_serving_video._engine_client + assert engine.captured_prompt["modalities"] == ["video"] captured = engine.captured_sampling_params_list[0] assert captured.num_outputs_per_prompt == 1 assert captured.width == 640 @@ -398,6 +399,8 @@ def test_sampling_params_pass_through(test_client, mocker: MockerFixture): "true_cfg_scale": "4.0", "boundary_ratio": "0.7", "flow_shift": "0.25", + "generate_sound": "true", + "sound_duration": "2.5", }, ) @@ -412,6 +415,8 @@ def test_sampling_params_pass_through(test_client, mocker: MockerFixture): assert captured.true_cfg_scale == 4.0 assert captured.boundary_ratio == 0.7 assert captured.extra_args["flow_shift"] == 0.25 + assert captured.extra_args["generate_sound"] is True + assert captured.extra_args["sound_duration"] == 2.5 def test_frame_interpolation_params_pass_to_diffusion_sampling_params(test_client, mocker: MockerFixture): @@ -622,6 +627,115 @@ async def _generate(prompt, request_id, sampling_params_list): assert completed["stage_durations"] == {"diffuse": 2.5, "vae.decode": 0.3} assert completed["peak_memory_mb"] == 4096.5 + assert completed["action"] is None + + +def test_video_generation_response_exposes_action_payload(mocker: MockerFixture): + engine = FakeAsyncOmni() + handler = OmniOpenAIServingVideo.for_diffusion( + diffusion_engine=engine, + model_name="Cosmos3-8B-UVA", + ) + + async def _generate(prompt, request_id, sampling_params_list): + del prompt, request_id, sampling_params_list + import numpy as np + + yield MockVideoResult( + [object()], + custom_output={ + "action": np.array([[[1.5, 2.5], [3.5, 4.5]]], dtype=np.float32), + "raw_action_dim": 2, + "action_mode": "policy", + "domain_id": 7, + }, + ) + + engine.generate = _generate + mocker.patch( + "vllm_omni.entrypoints.openai.serving_video.encode_video_base64", + return_value="encoded-video", + ) + + response = asyncio.run( + handler.generate_videos( + VideoGenerationRequest(prompt="predict actions"), + "action-json", + ) + ) + + action = response.data[0].action + assert action is not None + assert action.data == [[1.5, 2.5], [3.5, 4.5]] + assert action.shape == [2, 2] + assert action.dtype == "float32" + assert action.raw_action_dim == 2 + assert action.action_mode == "policy" + assert action.domain_id == 7 + assert response.model_dump(mode="json")["data"][0]["action"]["data"] == [[1.5, 2.5], [3.5, 4.5]] + + +def test_video_job_persists_action_metadata(test_client, mocker: MockerFixture): + engine = test_client.app.state.openai_serving_video._engine_client + + async def _generate(prompt, request_id, sampling_params_list): + import numpy as np + + engine.captured_prompt = prompt + engine.captured_sampling_params_list = sampling_params_list + yield MockVideoResult( + [object()], + custom_output={ + "action": np.array([[[1.0, 2.0], [3.0, 4.0]]], dtype=np.float32), + "raw_action_dim": 2, + "action_mode": "policy", + "domain_id": 7, + }, + ) + + engine.generate = _generate + mocker.patch( + "vllm_omni.entrypoints.openai.serving_video._encode_video_bytes", + return_value=b"fake-video", + ) + + response = test_client.post("/v1/videos", data={"prompt": "profile me"}) + assert response.status_code == 200 + video_id = response.json()["id"] + completed = _wait_for_status(test_client, video_id, VideoGenerationStatus.COMPLETED.value) + + expected_action = { + "data": [[1.0, 2.0], [3.0, 4.0]], + "shape": [2, 2], + "dtype": "float32", + "raw_action_dim": 2, + "action_mode": "policy", + "domain_id": 7, + } + assert completed["action"] == expected_action + + listed = test_client.get("/v1/videos").json() + assert listed["data"][0]["action"] == expected_action + + +def test_action_extraction_accepts_unbatched_action(): + import numpy as np + + result = MockVideoResult( + [object()], + custom_output={ + "action": np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32), + "raw_action_dim": 2, + "action_mode": "policy", + "domain_id": 7, + }, + ) + + actions = OmniOpenAIServingVideo._extract_action_outputs(result, expected_count=1) + + assert actions[0] is not None + assert actions[0].data == [[1.0, 2.0], [3.0, 4.0]] + assert actions[0].shape == [2, 2] def test_missing_handler_returns_503(): @@ -755,6 +869,9 @@ def test_invalid_uploaded_input_reference_returns_400(test_client): def test_video_request_validation(): req = VideoGenerationRequest(prompt="test") assert req.prompt == "test" + assert req.generate_sound is False + assert req.sound_duration is None + assert VideoGenerationRequest(prompt="test", generate_sound=True, sound_duration=1.5).generate_sound is True with pytest.raises(ValueError): VideoGenerationRequest(prompt="test", size="invalid") @@ -767,6 +884,8 @@ def test_video_request_validation(): VideoGenerationRequest(prompt="test", frame_interpolation_exp=0) with pytest.raises(ValueError): VideoGenerationRequest(prompt="test", frame_interpolation_scale=0) + with pytest.raises(ValueError): + VideoGenerationRequest(prompt="test", sound_duration=0) def test_list_videos_supports_order_after_and_limit(test_client, mocker: MockerFixture): @@ -1063,6 +1182,8 @@ def test_sync_t2v_returns_video_bytes(test_client, mocker: MockerFixture): assert float(response.headers["x-inference-time-s"]) >= 0 assert json.loads(response.headers["x-stage-durations"]) == {} assert float(response.headers["x-peak-memory-mb"]) == 0.0 + engine = test_client.app.state.openai_serving_video._engine_client + assert engine.captured_prompt["modalities"] == ["video"] def test_sync_t2v_returns_profiler_headers(test_client, mocker: MockerFixture): diff --git a/tests/entrypoints/test_omni_entrypoints.py b/tests/entrypoints/test_omni_entrypoints.py index 3612020d4fd..5aa79e0580a 100644 --- a/tests/entrypoints/test_omni_entrypoints.py +++ b/tests/entrypoints/test_omni_entrypoints.py @@ -851,7 +851,10 @@ def _enqueue_stage_error( """Enqueue a stage error output, optionally killing the engine.""" if kill_engine: engine._alive = False - engine_output = OmniRequestOutput.from_error(msg["request_id"], error_text) + engine_output = OmniRequestOutput.from_error( + msg["request_id"], + error_text, + ) engine_output.payload = "" engine.output_q.put_nowait( OutputMessage( diff --git a/vllm_omni/config/pipeline_registry.py b/vllm_omni/config/pipeline_registry.py index 4a8f48de6fa..ff640b410ab 100644 --- a/vllm_omni/config/pipeline_registry.py +++ b/vllm_omni/config/pipeline_registry.py @@ -33,6 +33,10 @@ # --- Multi-stage omni pipelines (LLM-centric; audio / video I/O) --- _OMNI_PIPELINES: dict[str, tuple[str, str]] = { # model_type -> (module_path, variable_name) + "cosmos3_omni": ( + "vllm_omni.diffusion.models.cosmos3.pipeline", + "COSMOS3_PIPELINE", + ), "qwen2_5_omni": ( "vllm_omni.model_executor.models.qwen2_5_omni.pipeline", "QWEN2_5_OMNI_PIPELINE", diff --git a/vllm_omni/deploy/cosmos3_omni.yaml b/vllm_omni/deploy/cosmos3_omni.yaml new file mode 100644 index 00000000000..2f3ed85a797 --- /dev/null +++ b/vllm_omni/deploy/cosmos3_omni.yaml @@ -0,0 +1,14 @@ +# Cosmos3 single-stage diffusion deploy config. +# +# This config is auto-loaded for Diffusers repos whose model_index.json has +# _class_name: Cosmos3OmniDiffusersPipeline. Pass --deploy-config only for +# local overrides such as disabling guardrails. + +async_chunk: false +trust_remote_code: true + +stages: + - stage_id: 0 + max_num_seqs: 1 + enforce_eager: true + model_class_name: Cosmos3OmniDiffusersPipeline diff --git a/vllm_omni/diffusion/attention/backends/cudnn_attn.py b/vllm_omni/diffusion/attention/backends/cudnn_attn.py index f27fe18706f..44026c56910 100644 --- a/vllm_omni/diffusion/attention/backends/cudnn_attn.py +++ b/vllm_omni/diffusion/attention/backends/cudnn_attn.py @@ -51,6 +51,8 @@ def __init__( ) -> None: self.causal = causal self.softmax_scale = softmax_scale + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads def forward_cuda( self, @@ -84,6 +86,7 @@ def forward_cuda( dropout_p=0.0, is_causal=self.causal, scale=self.softmax_scale, + enable_gqa=self.num_heads != self.num_kv_heads, ) except RuntimeError as e: if "No available kernel" not in str(e): diff --git a/vllm_omni/diffusion/attention/backends/sdpa.py b/vllm_omni/diffusion/attention/backends/sdpa.py index ab71e753b25..c650313698d 100644 --- a/vllm_omni/diffusion/attention/backends/sdpa.py +++ b/vllm_omni/diffusion/attention/backends/sdpa.py @@ -91,6 +91,8 @@ def __init__( self.softmax_scale = softmax_scale if backend_kwargs: logger.warning("SDPAImpl ignoring backend_kwargs: %s", list(backend_kwargs.keys())) + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads def _forward_impl( self, @@ -115,6 +117,7 @@ def _forward_impl( dropout_p=0.0, is_causal=self.causal, scale=self.softmax_scale, + enable_gqa=self.num_heads != self.num_kv_heads, ) out = output.permute(0, 2, 1, 3) return out diff --git a/vllm_omni/diffusion/cache/cache_dit_backend.py b/vllm_omni/diffusion/cache/cache_dit_backend.py index 7c27c31eb0b..7c13df6ef92 100644 --- a/vllm_omni/diffusion/cache/cache_dit_backend.py +++ b/vllm_omni/diffusion/cache/cache_dit_backend.py @@ -1568,6 +1568,76 @@ def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool return refresh_cache_context +def enable_cache_for_cosmos3(pipeline: Any, cache_config: Any) -> Callable[[int], None]: + """Enable cache-dit for Cosmos3 (T2V and I2V). + + Cosmos3 has a dual-pathway architecture (UND + GEN) but only the GEN + pathway (``gen_layers``) runs at every denoising step. The UND pathway + computes once and its K/V are cached by the pipeline itself; no cache-dit + needed there. We wrap only ``gen_layers`` via ``BlockAdapter``. + + Args: + pipeline: The Cosmos3 pipeline instance. + cache_config: DiffusionCacheConfig instance with cache configuration. + + Returns: + A refresh function that can be called to update cache context with new num_inference_steps. + """ + db_cache_config = _build_db_cache_config(cache_config) + + calibrator_config = None + if cache_config.enable_taylorseer: + taylorseer_order = cache_config.taylorseer_order + calibrator_config = TaylorSeerCalibratorConfig(taylorseer_order=taylorseer_order) + logger.info(f"TaylorSeer enabled with order={taylorseer_order}") + + logger.info( + f"Enabling cache-dit on Cosmos3 gen_layers: " + f"Fn={db_cache_config.Fn_compute_blocks}, " + f"Bn={db_cache_config.Bn_compute_blocks}, " + f"W={db_cache_config.max_warmup_steps}, " + ) + + cache_dit.enable_cache( + BlockAdapter( + transformer=pipeline.transformer, + blocks=[pipeline.transformer.gen_layers], + # Cosmos3 GEN blocks return only hidden_states. Per-layer UND K/V + # conditioning uses the transformer's cache-dit fallback path. + forward_pattern=[ForwardPattern.Pattern_3], + params_modifiers=[ + ParamsModifier( + cache_config=db_cache_config, + calibrator_config=calibrator_config, + ), + ], + check_forward_pattern=False, + has_separate_cfg=True, + ), + cache_config=db_cache_config, + calibrator_config=calibrator_config, + ) + + def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool = True) -> None: + if cache_config.scm_steps_mask_policy is None: + cache_dit.refresh_context(pipeline.transformer, num_inference_steps=num_inference_steps, verbose=verbose) + else: + cache_dit.refresh_context( + pipeline.transformer, + cache_config=DBCacheConfig().reset( + num_inference_steps=num_inference_steps, + steps_computation_mask=cache_dit.steps_mask( + mask_policy=cache_config.scm_steps_mask_policy, + total_steps=num_inference_steps, + ), + steps_computation_policy=cache_config.scm_steps_policy, + ), + verbose=verbose, + ) + + return refresh_cache_context + + # Register custom cache-dit enablers after function definitions CUSTOM_DIT_ENABLERS.update( { @@ -1594,6 +1664,7 @@ def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool "HunyuanVideo15Pipeline": enable_cache_for_hunyuan_video_15, "HunyuanVideo15I2VPipeline": enable_cache_for_hunyuan_video_15, "HeliosPipeline": enable_cache_for_helios, + "Cosmos3OmniDiffusersPipeline": enable_cache_for_cosmos3, } ) diff --git a/vllm_omni/diffusion/diffusion_engine.py b/vllm_omni/diffusion/diffusion_engine.py index 8cb0d52967a..8d53eb6fddd 100644 --- a/vllm_omni/diffusion/diffusion_engine.py +++ b/vllm_omni/diffusion/diffusion_engine.py @@ -83,6 +83,18 @@ def supports_audio_output(model_class_name: str) -> bool: return bool(getattr(model_cls, "support_audio_output", False)) +def _move_tensor_tree_to_cpu(value: object) -> object: + if isinstance(value, torch.Tensor): + return value.cpu() if value.device.type != "cpu" else value + if isinstance(value, dict): + return {key: _move_tensor_tree_to_cpu(item) for key, item in value.items()} + if isinstance(value, list): + return [_move_tensor_tree_to_cpu(item) for item in value] + if isinstance(value, tuple): + return tuple(_move_tensor_tree_to_cpu(item) for item in value) + return value + + def get_extra_body_params(model_class_name: str) -> frozenset[str]: """Return the set of extra_body keys accepted by a pipeline. @@ -230,12 +242,8 @@ async def step(self, request: OmniDiffusionRequest) -> list[OmniRequestOutput]: # post-processing to avoid device OOM — model weights may still # reside on the device and leave no headroom for intermediates. output_data = output.output - if ( - self.od_config.enable_cpu_offload - and isinstance(output_data, torch.Tensor) - and output_data.device.type != "cpu" - ): - output_data = output_data.cpu() + if self.od_config.enable_cpu_offload: + output_data = _move_tensor_tree_to_cpu(output_data) postprocess_start_time = time.perf_counter() if self.post_process_func is not None: @@ -256,7 +264,10 @@ async def step(self, request: OmniDiffusionRequest) -> list[OmniRequestOutput]: custom_output.update(outputs.get("custom_output") or {}) model_audio_sample_rate = outputs.get("audio_sample_rate") model_fps = outputs.get("fps") - outputs = outputs.get("video", outputs) + if "image" in outputs: + outputs = outputs["image"] + elif "video" in outputs: + outputs = outputs["video"] postprocess_time = time.perf_counter() - postprocess_start_time logger.debug("Post-processing completed in %.4f seconds", postprocess_time) diff --git a/vllm_omni/diffusion/ipc.py b/vllm_omni/diffusion/ipc.py index 6a96533fd40..d4989da3d9e 100644 --- a/vllm_omni/diffusion/ipc.py +++ b/vllm_omni/diffusion/ipc.py @@ -85,16 +85,26 @@ def _pack_tensor_if_large(val: torch.Tensor) -> torch.Tensor | dict: return val +def _pack_value_if_large(val: object) -> object: + if isinstance(val, torch.Tensor): + return _pack_tensor_if_large(val) + if isinstance(val, dict): + return {key: _pack_value_if_large(value) for key, value in val.items()} + return val + + def _unpack_if_shm_handle(val: object) -> object: """Reconstruct a tensor from an SHM handle dict, or return as-is.""" if isinstance(val, dict) and val.get("__tensor_shm__"): return _tensor_from_shm(val) + if isinstance(val, dict): + return {key: _unpack_if_shm_handle(value) for key, value in val.items()} return val def _pack_diffusion_fields(output: DiffusionOutput) -> DiffusionOutput: - if output.output is not None and isinstance(output.output, torch.Tensor): - output.output = _pack_tensor_if_large(output.output) + if output.output is not None: + output.output = _pack_value_if_large(output.output) if output.trajectory_latents is not None and isinstance(output.trajectory_latents, torch.Tensor): output.trajectory_latents = _pack_tensor_if_large(output.trajectory_latents) if output.trajectory_timesteps is not None and isinstance(output.trajectory_timesteps, torch.Tensor): diff --git a/vllm_omni/diffusion/models/cosmos3/__init__.py b/vllm_omni/diffusion/models/cosmos3/__init__.py new file mode 100644 index 00000000000..6df062b5c0d --- /dev/null +++ b/vllm_omni/diffusion/models/cosmos3/__init__.py @@ -0,0 +1,16 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from .pipeline_cosmos3 import ( + Cosmos3OmniDiffusersPipeline, + get_cosmos3_post_process_func, + get_cosmos3_pre_process_func, +) +from .transformer_cosmos3 import Cosmos3VFMTransformer + +__all__ = [ + "Cosmos3OmniDiffusersPipeline", + "get_cosmos3_post_process_func", + "get_cosmos3_pre_process_func", + "Cosmos3VFMTransformer", +] diff --git a/vllm_omni/diffusion/models/cosmos3/action.py b/vllm_omni/diffusion/models/cosmos3/action.py new file mode 100644 index 00000000000..e2572bbb733 --- /dev/null +++ b/vllm_omni/diffusion/models/cosmos3/action.py @@ -0,0 +1,217 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Action-token helpers for Cosmos3 UVA/action generation.""" + +from __future__ import annotations + +import json +from pathlib import Path +from typing import Any + +import numpy as np +import torch + +ACTION_MODE_POLICY = "policy" +ACTION_MODE_FORWARD_DYNAMICS = "forward_dynamics" +ACTION_MODE_INVERSE_DYNAMICS = "inverse_dynamics" +ACTION_MODES = { + ACTION_MODE_POLICY, + ACTION_MODE_FORWARD_DYNAMICS, + ACTION_MODE_INVERSE_DYNAMICS, +} + + +EMBODIMENT_TO_DOMAIN_ID: dict[str, int] = { + "no_action": 0, + "av": 1, + "camera_pose": 2, + "hand_pose": 3, + "pusht": 4, + "libero": 5, + "umi": 6, + "bridge_orig_lerobot": 7, + "droid_lerobot": 8, + "robomind-franka": 8, + "galbot": 9, + "robomind-franka-dual": 12, + "robomind-ur": 13, + "agibotworld": 15, + "agibot_gear_gripper": 15, + "agibot_gear_gripper_ext": 15, + "fractal": 20, +} + + +VIDEO_RES_SIZE_INFO: dict[str, dict[str, tuple[int, int]]] = { + "256": { + "1,1": (256, 256), + "4,3": (320, 256), + "3,4": (256, 320), + "16,9": (320, 192), + "9,16": (192, 320), + }, + "480": { + "1,1": (640, 640), + "4,3": (736, 544), + "3,4": (544, 736), + "16,9": (832, 480), + "9,16": (480, 832), + }, + "704": { + "1,1": (960, 960), + "4,3": (1088, 832), + "3,4": (832, 1088), + "16,9": (1280, 704), + "9,16": (704, 1280), + }, + "720": { + "1,1": (960, 960), + "4,3": (1104, 832), + "3,4": (832, 1104), + "16,9": (1280, 720), + "9,16": (720, 1280), + }, +} + + +def normalize_action_mode(mode: Any) -> str | None: + if mode is None: + return None + normalized = str(mode).strip().lower() + if not normalized: + return None + if normalized not in ACTION_MODES: + raise ValueError(f"Unsupported Cosmos3 action_mode={mode!r}; expected one of {sorted(ACTION_MODES)}.") + return normalized + + +def resolve_domain_id( + *, + domain_id: Any = None, + domain_name: Any = None, + require_explicit: bool = False, +) -> int: + if domain_id is not None: + resolved = int(domain_id) + if resolved < 0: + raise ValueError(f"Cosmos3 domain_id must be non-negative, got {resolved}.") + return resolved + + if domain_name is None or str(domain_name).strip() == "": + if require_explicit: + raise ValueError( + "Cosmos3 action generation requires extra_args['domain_id'] or non-empty extra_args['domain_name']." + ) + return 0 + + key = str(domain_name).strip().lower() + if key not in EMBODIMENT_TO_DOMAIN_ID: + raise ValueError( + f"Unknown Cosmos3 action domain_name={domain_name!r}; " + f"expected one of {sorted(EMBODIMENT_TO_DOMAIN_ID)} or pass domain_id directly." + ) + return EMBODIMENT_TO_DOMAIN_ID[key] + + +def action_condition_indexes(mode: str, action_length: int) -> list[int]: + mode = normalize_action_mode(mode) + if mode == ACTION_MODE_FORWARD_DYNAMICS: + return list(range(action_length)) + if mode in {ACTION_MODE_POLICY, ACTION_MODE_INVERSE_DYNAMICS}: + return [] + raise AssertionError(f"Unexpected action mode: {mode!r}") + + +def vision_condition_indexes(mode: str, video_length: int, temporal_compression_factor: int) -> list[int]: + mode = normalize_action_mode(mode) + latent_frames = (video_length - 1) // temporal_compression_factor + 1 + if mode in {ACTION_MODE_POLICY, ACTION_MODE_FORWARD_DYNAMICS}: + return [0] + if mode == ACTION_MODE_INVERSE_DYNAMICS: + return list(range(latent_frames)) + raise AssertionError(f"Unexpected action mode: {mode!r}") + + +def action_start_frame_offset(mode: str, action_length: int, video_length: int) -> int: + del mode + if action_length == video_length - 1: + return 1 + if action_length == video_length: + return 0 + raise ValueError( + "Cosmos3 action_chunk_size must equal num_frames - 1 or num_frames; " + f"got action_chunk_size={action_length}, num_frames={video_length}." + ) + + +def build_action_condition_mask( + mode: str, + action_length: int, + *, + device: torch.device, + dtype: torch.dtype, +) -> torch.Tensor: + mask = torch.zeros(1, action_length, 1, device=device, dtype=dtype) + for idx in action_condition_indexes(mode, action_length): + mask[:, idx, :] = 1.0 + return mask + + +def build_vision_condition_mask( + mode: str, + video_length: int, + temporal_compression_factor: int, + *, + device: torch.device, + dtype: torch.dtype, +) -> torch.Tensor: + latent_frames = (video_length - 1) // temporal_compression_factor + 1 + mask = torch.zeros(1, 1, latent_frames, 1, 1, device=device, dtype=dtype) + for idx in vision_condition_indexes(mode, video_length, temporal_compression_factor): + mask[:, :, idx, :, :] = 1.0 + return mask + + +def pad_action_to_dim(action: torch.Tensor, action_dim: int) -> torch.Tensor: + if action.shape[-1] > action_dim: + raise ValueError(f"Cosmos3 action dimension {action.shape[-1]} exceeds model action_dim={action_dim}.") + if action.shape[-1] == action_dim: + return action + padding = torch.zeros(*action.shape[:-1], action_dim - action.shape[-1], dtype=action.dtype, device=action.device) + return torch.cat([action, padding], dim=-1) + + +def load_action_tensor(action: Any = None, action_path: str | Path | None = None) -> torch.Tensor: + if action is None and action_path is None: + raise ValueError( + "Cosmos3 forward_dynamics action mode requires extra_args['action'] or extra_args['action_path']." + ) + if action is None: + action = json.loads(Path(str(action_path)).read_text()) + if isinstance(action, torch.Tensor): + tensor = action.detach().to(dtype=torch.float32) + else: + tensor = torch.as_tensor(np.asarray(action), dtype=torch.float32) + if tensor.ndim == 3 and tensor.shape[0] == 1: + tensor = tensor.squeeze(0) + if tensor.ndim != 2: + raise ValueError(f"Cosmos3 action must have shape [T, D], got {tuple(tensor.shape)}.") + return tensor + + +def find_closest_target_size(h: int, w: int, resolution: str | int) -> tuple[int, int]: + key = str(resolution) + if key not in VIDEO_RES_SIZE_INFO: + raise ValueError( + f"Unknown Cosmos3 action resolution={resolution!r}; expected one of {sorted(VIDEO_RES_SIZE_INFO)}." + ) + input_ratio = h / w + best_size = None + best_diff = float("inf") + for cand_w, cand_h in VIDEO_RES_SIZE_INFO[key].values(): + diff = abs(input_ratio - cand_h / cand_w) + if diff < best_diff: + best_diff = diff + best_size = (cand_w, cand_h) + assert best_size is not None + return best_size diff --git a/vllm_omni/diffusion/models/cosmos3/audio_tokenizer/__init__.py b/vllm_omni/diffusion/models/cosmos3/audio_tokenizer/__init__.py new file mode 100644 index 00000000000..cfb794705ba --- /dev/null +++ b/vllm_omni/diffusion/models/cosmos3/audio_tokenizer/__init__.py @@ -0,0 +1,6 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from .avae import Cosmos3AVAEAudioTokenizer + +__all__ = ["Cosmos3AVAEAudioTokenizer"] diff --git a/vllm_omni/diffusion/models/cosmos3/audio_tokenizer/avae.py b/vllm_omni/diffusion/models/cosmos3/audio_tokenizer/avae.py new file mode 100644 index 00000000000..4ddb8d41527 --- /dev/null +++ b/vllm_omni/diffusion/models/cosmos3/audio_tokenizer/avae.py @@ -0,0 +1,323 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Diffusers-format AVAE audio tokenizer used by Cosmos3 sound generation.""" + +from __future__ import annotations + +import json +import math +from pathlib import Path +from typing import Any + +import torch +from torch import nn +from torch.nn.utils import weight_norm +from vllm.logger import init_logger + +from vllm_omni.diffusion.models.progress_bar import _is_rank_zero + +logger = init_logger(__name__) + + +def _default_avae_config( + *, + sample_rate: int, + audio_channels: int, + io_channels: int, + hop_size: int, +) -> dict[str, Any]: + return { + "sampling_rate": sample_rate, + "hop_size": hop_size, + "dec_dim": 320, + "dec_c_mults": [1, 2, 4, 8, 16], + "dec_strides": [2, 4, 5, 6, 8], + "dec_out_channels": audio_channels, + "vocoder_input_dim": io_channels, + "normalization_type": "none", + "normalize_latents": False, + "tanh_input_scale": 1.5, + "tanh_output_scale": 3.5, + "tanh_clamp": 0.995, + } + + +def _config_get(config: dict[str, Any], *keys: str, default: Any = None) -> Any: + for key in keys: + value = config.get(key) + if value is not None: + return value + return default + + +def _load_config( + config_path: str | Path | None, + *, + sample_rate: int, + audio_channels: int, + io_channels: int, + hop_size: int, +) -> dict[str, Any]: + if config_path: + with open(config_path, encoding="utf-8") as f: + config = json.load(f) + if not isinstance(config, dict): + raise TypeError(f"Cosmos3 AVAE config must be a JSON object, got {type(config)!r}.") + return config + return _default_avae_config( + sample_rate=sample_rate, + audio_channels=audio_channels, + io_channels=io_channels, + hop_size=hop_size, + ) + + +def _load_checkpoint(path: str | Path, map_location: torch.device | str) -> dict[str, torch.Tensor]: + path = Path(path) + if path.suffix == ".safetensors": + try: + from safetensors.torch import load_file + except ImportError as exc: + raise ImportError("Loading AVAE .safetensors checkpoints requires safetensors.") from exc + checkpoint = load_file(str(path), device=str(map_location)) + else: + checkpoint = torch.load(path, map_location=map_location) + + if not isinstance(checkpoint, dict): + raise TypeError(f"AVAE checkpoint must be a flat state dict, got {type(checkpoint)!r}.") + if not all(isinstance(value, torch.Tensor) for value in checkpoint.values()): + raise TypeError("AVAE checkpoint must be a flat tensor state dict.") + return checkpoint + + +def _validate_diffusers_state_dict(state_dict: dict[str, torch.Tensor]) -> None: + if not state_dict: + raise RuntimeError("AVAE checkpoint is empty.") + + if not any(key.startswith("decoder.") for key in state_dict): + raise RuntimeError("Cosmos3 AVAE checkpoint must contain diffusers-format decoder.* keys.") + + +class Snake1d(nn.Module): + """One-dimensional Snake activation matching diffusers' Oobleck layout.""" + + def __init__(self, hidden_dim: int, logscale: bool = True) -> None: + super().__init__() + self.alpha = nn.Parameter(torch.zeros(1, hidden_dim, 1)) + self.beta = nn.Parameter(torch.zeros(1, hidden_dim, 1)) + self.logscale = logscale + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + shape = hidden_states.shape + alpha = torch.exp(self.alpha) if self.logscale else self.alpha + beta = torch.exp(self.beta) if self.logscale else self.beta + hidden_states = hidden_states.reshape(shape[0], shape[1], -1) + hidden_states = hidden_states + (beta + 1e-9).reciprocal() * torch.sin(alpha * hidden_states).pow(2) + return hidden_states.reshape(shape) + + +class OobleckResidualUnit(nn.Module): + """Residual unit used by the diffusers Oobleck decoder.""" + + def __init__(self, dimension: int = 16, dilation: int = 1) -> None: + super().__init__() + pad = ((7 - 1) * dilation) // 2 + self.snake1 = Snake1d(dimension) + self.conv1 = weight_norm(nn.Conv1d(dimension, dimension, kernel_size=7, dilation=dilation, padding=pad)) + self.snake2 = Snake1d(dimension) + self.conv2 = weight_norm(nn.Conv1d(dimension, dimension, kernel_size=1)) + + def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: + output_tensor = self.conv1(self.snake1(hidden_state)) + output_tensor = self.conv2(self.snake2(output_tensor)) + padding = (hidden_state.shape[-1] - output_tensor.shape[-1]) // 2 + if padding > 0: + hidden_state = hidden_state[..., padding:-padding] + return hidden_state + output_tensor + + +class OobleckDecoderBlock(nn.Module): + """Decoder block used by the diffusers Oobleck decoder.""" + + def __init__(self, input_dim: int, output_dim: int, stride: int = 1, output_padding: int = 0) -> None: + super().__init__() + self.snake1 = Snake1d(input_dim) + self.conv_t1 = weight_norm( + nn.ConvTranspose1d( + input_dim, + output_dim, + kernel_size=2 * stride, + stride=stride, + padding=math.ceil(stride / 2), + output_padding=output_padding, + ) + ) + self.res_unit1 = OobleckResidualUnit(output_dim, dilation=1) + self.res_unit2 = OobleckResidualUnit(output_dim, dilation=3) + self.res_unit3 = OobleckResidualUnit(output_dim, dilation=9) + + def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: + hidden_state = self.snake1(hidden_state) + hidden_state = self.conv_t1(hidden_state) + hidden_state = self.res_unit1(hidden_state) + hidden_state = self.res_unit2(hidden_state) + return self.res_unit3(hidden_state) + + +class OobleckDecoder(nn.Module): + """Diffusers-compatible Oobleck decoder for Cosmos3 AVAE latents.""" + + def __init__( + self, + channels: int, + input_channels: int, + audio_channels: int, + upsampling_ratios: list[int], + channel_multiples: list[int], + ) -> None: + super().__init__() + strides = upsampling_ratios + channel_multiples = [1] + channel_multiples + + self.conv1 = weight_norm(nn.Conv1d(input_channels, channels * channel_multiples[-1], kernel_size=7, padding=3)) + + block = [] + for stride_index, stride in enumerate(strides): + block.append( + OobleckDecoderBlock( + input_dim=channels * channel_multiples[len(strides) - stride_index], + output_dim=channels * channel_multiples[len(strides) - stride_index - 1], + stride=stride, + output_padding=stride % 2, + ) + ) + self.block = nn.ModuleList(block) + self.snake1 = Snake1d(channels) + self.conv2 = weight_norm(nn.Conv1d(channels, audio_channels, kernel_size=7, padding=3, bias=False)) + + def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: + hidden_state = self.conv1(hidden_state) + for layer in self.block: + hidden_state = layer(hidden_state) + hidden_state = self.snake1(hidden_state) + return self.conv2(hidden_state) + + +class Cosmos3AVAEAudioTokenizer(nn.Module): + """Decoder-only AVAE tokenizer for Cosmos3 audio latents.""" + + def __init__( + self, + *, + checkpoint_path: str | Path, + config_path: str | Path | None = None, + sample_rate: int = 48000, + audio_channels: int = 2, + io_channels: int = 64, + hop_size: int = 1920, + normalize_latents: bool = False, + normalization_type: str = "none", + tanh_input_scale: float = 1.5, + tanh_output_scale: float = 3.5, + tanh_clamp: float = 0.995, + dtype: torch.dtype = torch.bfloat16, + device: torch.device | str = "cuda", + ) -> None: + super().__init__() + self.dtype = dtype + self.device = torch.device(device) + + config = _load_config( + config_path, + sample_rate=sample_rate, + audio_channels=audio_channels, + io_channels=io_channels, + hop_size=hop_size, + ) + self.sample_rate = int(_config_get(config, "sampling_rate", "sample_rate", default=sample_rate)) + self.audio_channels = int( + _config_get( + config, + "dec_out_channels", + "audio_channels", + default=2 if bool(config.get("stereo", audio_channels == 2)) else 1, + ) + ) + self.latent_ch = int(_config_get(config, "vocoder_input_dim", "io_channels", "latent_ch", default=io_channels)) + dec_strides = [int(stride) for stride in _config_get(config, "dec_strides", default=[2, 4, 5, 6, 8])] + self.hop_size = int( + _config_get(config, "hop_size", default=math.prod(dec_strides) if dec_strides else hop_size) + ) + dec_stride_product = math.prod(dec_strides) + if dec_stride_product != self.hop_size: + raise ValueError( + "Cosmos3 AVAE config dec_strides product must equal hop_size " + f"for correct latent/audio duration math: product={dec_stride_product}, hop_size={self.hop_size}." + ) + + normalization_type = str(_config_get(config, "normalization_type", default=normalization_type)) + normalize_latents = bool(_config_get(config, "normalize_latents", default=normalize_latents)) + if normalization_type == "none" and normalize_latents: + normalization_type = "tanh" + self.normalization_type = normalization_type + self.tanh_input_scale = float(_config_get(config, "tanh_input_scale", default=tanh_input_scale)) + self.tanh_output_scale = float(_config_get(config, "tanh_output_scale", default=tanh_output_scale)) + self.tanh_clamp = float(_config_get(config, "tanh_clamp", default=tanh_clamp)) + + self.decoder = OobleckDecoder( + channels=int(_config_get(config, "dec_dim", default=320)), + input_channels=self.latent_ch, + audio_channels=self.audio_channels, + upsampling_ratios=list(reversed(dec_strides)), + channel_multiples=list(_config_get(config, "dec_c_mults", default=[1, 2, 4, 8, 16])), + ) + state_dict = _load_checkpoint(checkpoint_path, self.device) + _validate_diffusers_state_dict(state_dict) + + # The checkpoint also contains encoder weights, which we do not support here, hence strict=False + self.load_state_dict(state_dict, strict=False) + + self.eval() + for param in self.parameters(): + param.requires_grad = False + self.to(device=self.device, dtype=self.dtype) + if _is_rank_zero(): + logger.info("Loaded diffusers-format Cosmos3 AVAE checkpoint from %s", checkpoint_path) + + @property + def temporal_compression_factor(self) -> int: + return self.hop_size + + def get_latent_num_samples(self, num_audio_samples: int) -> int: + return int(num_audio_samples) // self.temporal_compression_factor + + def get_audio_num_samples(self, num_latent_samples: int) -> int: + return int(num_latent_samples) * self.temporal_compression_factor + + def _denormalize_latent(self, latent: torch.Tensor) -> torch.Tensor: + if self.normalization_type == "tanh": + in_dtype = latent.dtype + latent = torch.clamp( + latent.float() / self.tanh_output_scale, + -self.tanh_clamp, + self.tanh_clamp, + ) + return (torch.atanh(latent) * self.tanh_input_scale).to(in_dtype) + if self.normalization_type != "none": + raise ValueError(f"Unsupported AVAE normalization_type={self.normalization_type!r}.") + return latent + + @torch.no_grad() + def encode(self, audio: torch.Tensor, force_pad: bool = False) -> torch.Tensor: + del audio, force_pad + raise NotImplementedError("Cosmos3AVAEAudioTokenizer is decoder-only for diffusers-format sound_tokenizer/.") + + @torch.no_grad() + def decode(self, latent: torch.Tensor) -> torch.Tensor: + in_dtype = latent.dtype + squeeze = latent.ndim == 2 + if squeeze: + latent = latent.unsqueeze(0) + z = self._denormalize_latent(latent.to(self.device)).to(self.dtype) + audio = self.decoder(z).clamp(-1.0, 1.0).to(in_dtype) + return audio.squeeze(0) if squeeze else audio diff --git a/vllm_omni/diffusion/models/cosmos3/guardrails.py b/vllm_omni/diffusion/models/cosmos3/guardrails.py new file mode 100644 index 00000000000..71525a6272e --- /dev/null +++ b/vllm_omni/diffusion/models/cosmos3/guardrails.py @@ -0,0 +1,179 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Cosmos3 guardrail hooks for vllm-omni. + +Thin adapter around the ``cosmos_guardrail`` package's ``CosmosSafetyChecker`` +(Blocklist + Qwen3Guard for text, RetinaFace face-blur for video). + +Enable via custom_pipeline_args or the test script: + python test_cosmos3.py --model ... +Disable explicitly with ``--no-guardrails``. +""" + +from __future__ import annotations + +from collections.abc import Callable +from typing import Any + +import numpy as np +import torch +from vllm.logger import init_logger + +from vllm_omni.diffusion.models.progress_bar import _is_rank_zero + +logger = init_logger(__name__) + + +try: + from cosmos_guardrail import CosmosSafetyChecker + + _COSMOS_GUARDRAIL_AVAILABLE = True +except ImportError: + _COSMOS_GUARDRAIL_AVAILABLE = False + + class CosmosSafetyChecker: # type: ignore[no-redef] + # Raised at runtime (not import time) so guardrail-less inference + # continues to work when ``cosmos_guardrail`` is not installed and + # ``model_config["guardrails"]`` is False. + def __init__(self, *args, **kwargs): + raise ValueError( + f"You have disabled the safety checker for {self.__class__}. This is in violation of the " + "[NVIDIA Open Model License Agreement]" + "(https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-open-model-license). " + f"Please ensure that you are compliant with the license agreement." + f"Please install cosmos-guardrail package to enable safety checks." + ) + + +TextGuardrailFn = Callable[[str], None] +VideoGuardrailFn = Callable[[np.ndarray], np.ndarray] + +_text_guardrail: TextGuardrailFn | None = None +_video_guardrail: VideoGuardrailFn | None = None + + +# --------------------------------------------------------------------------- +# Default guardrail builders +# --------------------------------------------------------------------------- +def _nn_models(runner: Any) -> list[torch.nn.Module]: + return [m for m in runner.models if isinstance(m, torch.nn.Module)] + + +def _build_text_guardrail(checker: Any) -> TextGuardrailFn: + def text_guardrail(prompt: str) -> None: + if not checker.check_text_safety(prompt): + # CosmosSafetyChecker logs the specific reason at CRITICAL. + raise ValueError("Guardrail blocked prompt") + + return text_guardrail + + +def _build_video_guardrail(checker: Any, offload_to_cpu: bool) -> VideoGuardrailFn: + video_models = _nn_models(checker.video_guardrail) + compute_device = "cuda" + + def video_guardrail(frames: np.ndarray) -> np.ndarray: + if offload_to_cpu: + for m in video_models: + m.to(compute_device) + try: + result = checker.check_video_safety(frames) + finally: + if offload_to_cpu: + for m in video_models: + m.to("cpu") + # ``check_video_safety`` returns ``None`` when the content safety + # filter blocks the frames. The face-blur postprocessor (the only + # video module enabled by default) does not block, so in practice + # ``result`` is always an ndarray here. + return result if result is not None else frames + + return video_guardrail + + +# --------------------------------------------------------------------------- +# Initialization +# --------------------------------------------------------------------------- +def _init_default_guardrails(offload_to_cpu: bool = False) -> None: + global _text_guardrail, _video_guardrail + if _text_guardrail is not None: + return + if _is_rank_zero(): + logger.info("Initializing Cosmos3 guardrails (offload_to_cpu=%s)...", offload_to_cpu) + + # Instantiation raises ValueError when ``cosmos_guardrail`` is not + # installed - this is the right moment to fail loudly because the + # caller has opted in to guardrails. + checker = CosmosSafetyChecker() + + # Place text models on their resting device permanently. Video models + # idle on CPU when offload is on and move to GPU per-call (handled in + # the video guardrail closure). + idle_device = "cpu" if offload_to_cpu else "cuda" + for m in _nn_models(checker.text_guardrail): + m.to(idle_device) + for m in _nn_models(checker.video_guardrail): + m.to(idle_device) + + _text_guardrail = _build_text_guardrail(checker) + _video_guardrail = _build_video_guardrail(checker, offload_to_cpu) + if _is_rank_zero(): + logger.info("Cosmos3 guardrails initialized.") + + +# --------------------------------------------------------------------------- +# Public API +# --------------------------------------------------------------------------- +def ensure_initialized(od_config: Any) -> None: + if not is_guardrails_enabled(od_config): + return + cfg = getattr(od_config, "model_config", None) or {} + _init_default_guardrails(offload_to_cpu=bool(cfg.get("offload_guardrail_models", False))) + + +def check_text_safety(prompt: str) -> None: + if _text_guardrail is not None: + _text_guardrail(prompt) + + +def check_video_safety(video_tensor: torch.Tensor) -> torch.Tensor: + if _video_guardrail is None: + return video_tensor + + v = video_tensor.detach().cpu().float() + if v.dim() == 5: + v = v[0] + v = v.clamp(-1, 1) * 0.5 + 0.5 + frames_np = (v.permute(1, 2, 3, 0).numpy() * 255).round().astype(np.uint8) + + frames_np = _video_guardrail(frames_np) + + # Convert back to [-1, 1] to match the VAE output range. + result = torch.from_numpy(frames_np.copy()).float() / 127.5 - 1.0 + result = result.permute(3, 0, 1, 2) + if video_tensor.dim() == 5: + result = result.unsqueeze(0) + return result.to(video_tensor.device) + + +def is_guardrails_enabled(od_config: Any, sampling_params: Any = None) -> bool: + """Resolve the active guardrail gate. + + Server-level ``od_config.model_config["guardrails"]`` decides whether the + guardrail models are loaded at all (eager load at pipeline build time). + When that is False, no per-request override can turn checks back on, + because the singletons in this module are never populated. + + When the server gate is on, ``sampling_params.extra_args["guardrails"]`` + may override on a per-request basis: ``False`` skips the check for that + request, anything else (or missing) keeps the default behavior. + """ + cfg = getattr(od_config, "model_config", None) or {} + if not bool(cfg.get("guardrails", True)): + return False + if sampling_params is not None: + extra = getattr(sampling_params, "extra_args", None) or {} + per_request = extra.get("guardrails") + if per_request is not None: + return bool(per_request) + return True diff --git a/vllm_omni/diffusion/models/cosmos3/pipeline.py b/vllm_omni/diffusion/models/cosmos3/pipeline.py new file mode 100644 index 00000000000..23bd47f0a3d --- /dev/null +++ b/vllm_omni/diffusion/models/cosmos3/pipeline.py @@ -0,0 +1,26 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Cosmos3 deploy-schema topology.""" + +from vllm_omni.config.stage_config import ( + PipelineConfig, + StageExecutionType, + StagePipelineConfig, +) + +COSMOS3_PIPELINE = PipelineConfig( + model_type="cosmos3_omni", + model_arch="Cosmos3ForConditionalGeneration", + hf_architectures=("Cosmos3ForConditionalGeneration",), + diffusers_class_name="Cosmos3OmniDiffusersPipeline", + stages=( + StagePipelineConfig( + stage_id=0, + model_stage="diffusion", + execution_type=StageExecutionType.DIFFUSION, + input_sources=(), + final_output=True, + final_output_type="image", + ), + ), +) diff --git a/vllm_omni/diffusion/models/cosmos3/pipeline_cosmos3.py b/vllm_omni/diffusion/models/cosmos3/pipeline_cosmos3.py new file mode 100644 index 00000000000..45ef135086c --- /dev/null +++ b/vllm_omni/diffusion/models/cosmos3/pipeline_cosmos3.py @@ -0,0 +1,1834 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Cosmos3 text/image-to-video and text-to-image pipeline for vllm-omni. + +Single pipeline class supports T2V, I2V, and T2I; the mode is selected at +runtime by: + +* ``prompt["modalities"]`` contains ``"image"``: **T2I** (text-to-image). +* ``prompt["modalities"]`` contains ``"video"`` or is omitted: **T2V** + (text-to-video). +* ``multi_modal_data['image']`` present on the prompt: **I2V** + (handled by :func:`get_cosmos3_pre_process_func`) + +""" + +from __future__ import annotations + +import math +import os +import time +from collections.abc import Iterable +from typing import Any, ClassVar + +import numpy as np +import PIL.Image +import torch +from diffusers import UniPCMultistepScheduler +from diffusers.utils.torch_utils import randn_tensor +from diffusers.video_processor import VideoProcessor +from torch import nn +from transformers import AutoTokenizer +from vllm.logger import init_logger +from vllm.model_executor.models.utils import AutoWeightsLoader + +from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig +from vllm_omni.diffusion.distributed.autoencoders.autoencoder_kl_wan import DistributedAutoencoderKLWan +from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin +from vllm_omni.diffusion.distributed.parallel_state import ( + get_classifier_free_guidance_world_size, +) +from vllm_omni.diffusion.distributed.utils import get_local_device +from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader +from vllm_omni.diffusion.models.interface import SupportImageInput +from vllm_omni.diffusion.models.progress_bar import ProgressBarMixin, _is_rank_zero +from vllm_omni.diffusion.profiler.diffusion_pipeline_profiler import DiffusionPipelineProfilerMixin +from vllm_omni.diffusion.request import OmniDiffusionRequest + +from .action import ( + ACTION_MODE_FORWARD_DYNAMICS, + ACTION_MODE_INVERSE_DYNAMICS, + ACTION_MODE_POLICY, + action_start_frame_offset, + build_action_condition_mask, + build_vision_condition_mask, + find_closest_target_size, + load_action_tensor, + normalize_action_mode, + pad_action_to_dim, + resolve_domain_id, +) +from .transformer_cosmos3 import Cosmos3VFMTransformer + +logger = init_logger(__name__) + +COSMOS3_DURATION_TEMPLATE = "The video is {duration:.1f} seconds long and is of {fps:.0f} FPS." +COSMOS3_RESOLUTION_TEMPLATE = "This video is of {height}x{width} resolution." +COSMOS3_IMAGE_RESOLUTION_TEMPLATE = "This image is of {height}x{width} resolution." +COSMOS3_INVERSE_DURATION_TEMPLATE = "The video is not {duration:.1f} seconds long and is not of {fps:.0f} FPS." +COSMOS3_INVERSE_RESOLUTION_TEMPLATE = "This video is not of {height}x{width} resolution." +COSMOS3_INVERSE_IMAGE_RESOLUTION_TEMPLATE = "This image is not of {height}x{width} resolution." +COSMOS3_SYSTEM_PROMPT = "You are a helpful assistant who will generate videos from a given prompt." +COSMOS3_T2I_SYSTEM_PROMPT = "You are a helpful assistant who will generate images from a given prompt." + + +# --------------------------------------------------------------------------- +# Post-process function (registered in registry.py) +# --------------------------------------------------------------------------- +def get_cosmos3_pre_process_func(od_config: OmniDiffusionConfig): + """Pre-process function for both T2V and I2V. + + For T2V (no image in ``multi_modal_data``), the request is returned + unchanged after the optional guardrails check. For I2V (image present), + the conditioning image is loaded, aspect-resized + center-cropped, and + stored back on the prompt as ``additional_information.preprocessed_image``. + """ + from .guardrails import check_text_safety, ensure_initialized, is_guardrails_enabled + + video_processor = VideoProcessor(vae_scale_factor=16) + # Eager-load guardrail models at pipeline build time when the server-level + # gate is on. Per-request overrides only decide whether the loaded models + # are *invoked* — they cannot turn checks on without a server-side preload. + if is_guardrails_enabled(od_config): + ensure_initialized(od_config) + + def _extra_args(request: OmniDiffusionRequest) -> dict[str, Any]: + extra = getattr(getattr(request, "sampling_params", None), "extra_args", None) + return extra if isinstance(extra, dict) else {} + + def _request_action_mode(request: OmniDiffusionRequest) -> str | None: + return normalize_action_mode(_extra_args(request).get("action_mode")) + + def _set_action_size_from_image(request: OmniDiffusionRequest, image: PIL.Image.Image) -> tuple[int, int]: + sp = request.sampling_params + if sp.height is not None and sp.width is not None: + return int(sp.height), int(sp.width) + + extra = _extra_args(request) + resolution = extra.get("resolution", extra.get("image_size", 480)) + target_w, target_h = find_closest_target_size(image.height, image.width, resolution) + if sp.height is None: + sp.height = target_h + if sp.width is None: + sp.width = target_w + return int(sp.height), int(sp.width) + + def _pil_to_rgb(value: Any) -> PIL.Image.Image: + if isinstance(value, str): + return PIL.Image.open(value).convert("RGB") + if isinstance(value, PIL.Image.Image): + return value.convert("RGB") + raise TypeError(f"Cosmos3 action preprocessing expected PIL image or image path, got {type(value)!r}.") + + def _resize_and_pad_action_image(image: PIL.Image.Image, target_h: int, target_w: int) -> PIL.Image.Image: + scale = min(target_w / image.width, target_h / image.height, 1.0) + resize_w = max(1, int(scale * image.width + 0.5)) + resize_h = max(1, int(scale * image.height + 0.5)) + if (resize_w, resize_h) != image.size: + image = image.resize((resize_w, resize_h), PIL.Image.Resampling.BICUBIC) + + array = np.asarray(image) + pad_h = target_h - resize_h + pad_w = target_w - resize_w + if pad_h < 0 or pad_w < 0: + raise ValueError( + f"Cosmos3 action image resize exceeded target size: resized={(resize_h, resize_w)}, " + f"target={(target_h, target_w)}." + ) + if pad_h == 0 and pad_w == 0: + return image + pad_mode = "reflect" if pad_h < resize_h and pad_w < resize_w else "edge" + padded = np.pad(array, ((0, pad_h), (0, pad_w), (0, 0)), mode=pad_mode) + return PIL.Image.fromarray(padded) + + def _preprocess_action_image(image: PIL.Image.Image, target_h: int, target_w: int) -> torch.Tensor: + image = _resize_and_pad_action_image(image, target_h, target_w) + return video_processor.preprocess(image, height=target_h, width=target_w) + + def _preprocess_action_video(frames: list[Any], target_h: int, target_w: int) -> torch.Tensor: + if not frames: + raise ValueError("Cosmos3 action video input must contain at least one frame.") + processed = [_preprocess_action_image(_pil_to_rgb(frame), target_h, target_w).squeeze(0) for frame in frames] + return torch.stack(processed, dim=1).unsqueeze(0).contiguous() + + def pre_process_func(request: OmniDiffusionRequest) -> OmniDiffusionRequest: + action_mode = _request_action_mode(request) + if is_guardrails_enabled(od_config, request.sampling_params): + for prompt in request.prompts: + text = prompt if isinstance(prompt, str) else prompt.get("prompt", "") + check_text_safety(text) + + for i, prompt in enumerate(request.prompts): + if isinstance(prompt, str): + continue + multi_modal_data = prompt.get("multi_modal_data", {}) or {} + raw_image = multi_modal_data.get("image") + raw_video = multi_modal_data.get("video") + if raw_image is None and not (action_mode is not None and raw_video is not None): + continue + + if "additional_information" not in prompt: + prompt["additional_information"] = {} + + if raw_image is None: + if not isinstance(raw_video, list) or not raw_video: + raise TypeError("Cosmos3 action video input must be a non-empty list of PIL images or image paths.") + image = _pil_to_rgb(raw_video[0]) + else: + image = _pil_to_rgb(raw_image) + + # Auto-calculate H/W from aspect ratio (720p max area) + if request.sampling_params.height is None or request.sampling_params.width is None: + if action_mode is not None: + _set_action_size_from_image(request, image) + else: + max_area = 720 * 1280 + aspect_ratio = image.height / image.width + mod_value = 16 + height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value + width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value + if request.sampling_params.height is None: + request.sampling_params.height = height + if request.sampling_params.width is None: + request.sampling_params.width = width + + target_w = request.sampling_params.width + target_h = request.sampling_params.height + if action_mode is not None: + prompt["additional_information"]["preprocessed_image"] = _preprocess_action_image( + image, + int(target_h), + int(target_w), + ) + else: + scale = max(target_w / image.width, target_h / image.height) + resize_w = int(np.ceil(scale * image.width)) + resize_h = int(np.ceil(scale * image.height)) + image = image.resize((resize_w, resize_h), PIL.Image.Resampling.LANCZOS) + left = (resize_w - target_w) // 2 + top = (resize_h - target_h) // 2 + image = image.crop((left, top, left + target_w, top + target_h)) + + prompt["additional_information"]["preprocessed_image"] = video_processor.preprocess( + image, height=target_h, width=target_w + ) + if action_mode is not None and raw_video is not None: + if not isinstance(raw_video, list): + raise TypeError("Cosmos3 action video input must be a list of PIL images or image paths.") + prompt["additional_information"]["preprocessed_video"] = _preprocess_action_video( + raw_video, + int(target_h), + int(target_w), + ) + request.prompts[i] = prompt + + return request + + return pre_process_func + + +def get_cosmos3_post_process_func(od_config: OmniDiffusionConfig): + from .guardrails import check_video_safety, is_guardrails_enabled + + video_processor = VideoProcessor(vae_scale_factor=16) + + def _sampling_param(sampling_params, key: str, default=None): + extra = getattr(sampling_params, "extra_args", None) + if isinstance(extra, dict) and extra.get(key) is not None: + return extra[key] + value = getattr(sampling_params, key, None) + return default if value is None else value + + def _resolve_output_fps(sampling_params): + fps = ( + _sampling_param(sampling_params, "resolved_frame_rate") + or _sampling_param(sampling_params, "frame_rate") + or _sampling_param(sampling_params, "fps") + or 24.0 + ) + try: + fps_value = float(fps) + except (TypeError, ValueError): + fps_value = 24.0 + if fps_value <= 0: + fps_value = 24.0 + return int(fps_value) if fps_value.is_integer() else fps_value + + def post_process_func( + output: torch.Tensor | dict[str, torch.Tensor] | tuple, + output_type: str = "np", + sampling_params=None, + ): + if output_type == "latent": + return output + + audio = None + audio_sample_rate = None + if isinstance(output, dict): + if "image" in output and "video" in output: + raise ValueError("Cosmos3 output cannot contain both image and video payloads.") + if "image" in output: + video = output["image"] + elif "video" in output: + video = output["video"] + else: + raise ValueError("Cosmos3 postprocess expected an 'image' or 'video' output payload.") + audio = output.get("audio") + audio_sample_rate = output.get("audio_sample_rate") + elif isinstance(output, tuple): + if len(output) == 3: + video, audio, audio_sample_rate = output + elif len(output) == 2: + video, audio = output + else: + raise ValueError( + "Cosmos3 postprocess expects output tensor, output dict, or (video, audio[, sample_rate]) tuple." + ) + else: + video = output + + if isinstance(output, dict) and "image" in output: + if audio is not None: + raise ValueError("Cosmos3 text-to-image postprocess does not support audio output.") + if video.ndim != 5 or video.shape[2] != 1: + raise ValueError( + "Cosmos3 text-to-image postprocess expects decoded output " + f"with shape [B, C, 1, H, W], got {tuple(video.shape)}." + ) + image = video.squeeze(2) # [B, 3, H, W] + if is_guardrails_enabled(od_config, sampling_params): + # check_video_safety expects a 5D tensor; re-add T axis. + checked = check_video_safety(image.unsqueeze(2)) + image = checked.squeeze(2) + return video_processor.postprocess(image, output_type="pil") + if is_guardrails_enabled(od_config, sampling_params): + video = check_video_safety(video) + result = {"video": video_processor.postprocess_video(video, output_type=output_type)} + if audio is None: + return result + if isinstance(audio, torch.Tensor): + audio = audio.detach().cpu() + result["audio"] = audio + result["fps"] = _resolve_output_fps(sampling_params) + if audio_sample_rate is not None: + result["audio_sample_rate"] = int(audio_sample_rate) + return result + + return post_process_func + + +# --------------------------------------------------------------------------- +# Pipeline +# --------------------------------------------------------------------------- +class Cosmos3OmniDiffusersPipeline( + nn.Module, CFGParallelMixin, SupportImageInput, ProgressBarMixin, DiffusionPipelineProfilerMixin +): + """Cosmos3 text/image-to-video / text-to-image pipeline. + + Architecture: Mixture-of-Transformers with Qwen3-VL backbone. + - Understanding pathway: causal self-attention on text (runs once, K/V cached) + - Generation pathway: cross-attention on noisy visual latents (runs each step) + + Supports T2V, I2V, and T2I from the same class. Mode is selected at + runtime: + + * **T2I** when ``prompt["modalities"]`` contains ``"image"``. Latent + T-dim is forced to 1, T2I-specific scheduler defaults are applied (50 steps, + flow_shift=3.0, guidance_interval=[400, 1000]), the duration + template is suppressed, and post-process emits PIL images. + * **I2V** when the request supplies a preprocessed image via + ``multi_modal_data['image']`` (handled by + :func:`get_cosmos3_pre_process_func`) and the requested output modality + is not image. + Frame 0 of the initial latent is set to the VAE-encoded conditioning + image, frame-0 noise predictions are masked to zero, and the clean + image latent is re-injected at frame 0 after each scheduler step. + * **T2V** otherwise (default video generation). + """ + + support_image_input: ClassVar[bool] = True + color_format: ClassVar[str] = "RGB" + + def __init__( + self, + *, + od_config: OmniDiffusionConfig, + prefix: str = "", + ) -> None: + super().__init__() + if od_config.enable_cpu_offload: + raise ValueError( + "Cosmos3 has no separate text encoder, so CPU offloading " + "(transformer↔encoder swapping) is not supported. " + "Use --enable-layerwise-offload instead." + ) + self.od_config = od_config + self.device = get_local_device() + self.dtype = od_config.dtype + + model_path = od_config.model + local_files_only = os.path.exists(model_path) + + # --- Tokenizer --- + self.tokenizer = AutoTokenizer.from_pretrained( + model_path, + subfolder="text_tokenizer", + local_files_only=local_files_only, + ) + + # --- VAE --- + self.vae = DistributedAutoencoderKLWan.from_pretrained( + model_path, + subfolder="vae", + torch_dtype=self.dtype, + local_files_only=local_files_only, + ).to(self.device) + + if not hasattr(self.vae.config, "scale_factor_temporal"): + raise ValueError( + "Cosmos3 Diffusers VAE config must define scale_factor_temporal " + "so transformer mRoPE temporal positions can be computed correctly." + ) + self.vae_scale_factor_temporal = int(self.vae.config.scale_factor_temporal) + self.vae_scale_factor_spatial = getattr(self.vae.config, "scale_factor_spatial", 16) + + # --- Transformer (weights loaded later via weights_sources) --- + self.transformer = Cosmos3VFMTransformer( + od_config=od_config, + temporal_compression_factor=self.vae_scale_factor_temporal, + ) + + # --- Scheduler --- + # Load from checkpoint to preserve solver_order, timestep_spacing, + # beta_schedule, sigma bounds, flow_shift, etc. Only override + # flow_shift when explicitly requested by the user. + self.scheduler = UniPCMultistepScheduler.from_pretrained( + model_path, + subfolder="scheduler", + local_files_only=local_files_only, + ) + if od_config.flow_shift is not None: + self.scheduler = UniPCMultistepScheduler.from_config(self.scheduler.config, flow_shift=od_config.flow_shift) + + # --- Video processor for post-decode --- + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + # --- Weight sources for DiffusersPipelineLoader --- + self.weights_sources = [ + DiffusersPipelineLoader.ComponentSource( + model_or_path=model_path, + subfolder=None, + revision=None, + prefix="transformer.", + fall_back_to_pt=True, + allow_patterns_overrides=["transformer/*.safetensors"], + ), + ] + + # Snapshot the loaded scheduler config so we can rebuild the + # scheduler at request time when a per-request flow_shift override + # is supplied (T2I uses shift=3.0; T2V/I2V use the engine default). + self._base_scheduler_config = self.scheduler.config + self._engine_init_flow_shift = float(getattr(self.scheduler.config, "flow_shift", 1.0) or 1.0) + self._current_flow_shift = self._engine_init_flow_shift + + self._guidance_scale = None + self._num_timesteps = None + self._sound_tokenizer = None + if getattr(self.transformer, "sound_gen", False): + self._get_sound_tokenizer() + + self.setup_diffusion_pipeline_profiler( + enable_diffusion_pipeline_profiler=self.od_config.enable_diffusion_pipeline_profiler + ) + + # -- Weight loading -------------------------------------------------------- + + @staticmethod + def _remap_ckpt_key(key: str) -> str | None: + """Remap a Diffusers transformer key to the model parameter namespace. + + Checkpoint keys arrive with a synthetic ``transformer.`` prefix from + ``weights_sources``. The source checkpoint itself uses the prefixless + Diffusers transformer namespace: top-level projections plus Qwen3-VL + backbone keys. UND and GEN components share each layer in the source + and are split into separate module lists here. Some sources wrap the + transformer namespace under ``model.``; that wrapper is structural and + is stripped before applying the Cosmos3 leaf-name remap. + + Returns the remapped name under ``transformer.``, or None to skip. + """ + k = key + # Strip the weights_sources prefix + if k.startswith("transformer."): + k = k[len("transformer.") :] + if k.startswith("model."): + k = k[len("model.") :] + + # Top-level generation components. + if k.startswith( + ( + "proj_in.", + "proj_out.", + "time_embedder.", + "sound2llm.", + "llm2sound.", + "action_proj_in.", + "action_proj_out.", + ) + ): + return f"transformer.{k}" + if k in ("sound_modality_embed", "sound_modality_embed.weight"): + return "transformer.sound_modality_embed" + if k in ("action_modality_embed", "action_modality_embed.weight"): + return "transformer.action_modality_embed" + if k.startswith("action_pos_embed."): + return None + + # Skip lm_head + if k.startswith("lm_head."): + return None + + # embed_tokens / norm -> language_model.* + if k.startswith("embed_tokens."): + return f"transformer.language_model.{k}" + if k.startswith("norm."): + return f"transformer.language_model.{k}" + + # norm_moe_gen -> top level + if k.startswith("norm_moe_gen."): + return f"transformer.{k}" + + if not k.startswith("layers."): + return None + + parts = k.split(".", 2) # ['layers', '{i}', '{rest}'] + if len(parts) != 3: + return None + layer_idx = parts[1] + rest = parts[2] + + und_lp = f"transformer.language_model.layers.{layer_idx}" + gen_lp = f"transformer.gen_layers.{layer_idx}" + + _LAYER_MAP = { + # UND attention + "self_attn.to_q.": f"{und_lp}.self_attn.to_q.", + "self_attn.to_k.": f"{und_lp}.self_attn.to_k.", + "self_attn.to_v.": f"{und_lp}.self_attn.to_v.", + "self_attn.to_out.": f"{und_lp}.self_attn.to_out.", + "self_attn.norm_q.": f"{und_lp}.self_attn.norm_q.", + "self_attn.norm_k.": f"{und_lp}.self_attn.norm_k.", + # GEN attention + "self_attn.add_q_proj.": f"{gen_lp}.cross_attention.to_q.", + "self_attn.add_k_proj.": f"{gen_lp}.cross_attention.to_k.", + "self_attn.add_v_proj.": f"{gen_lp}.cross_attention.to_v.", + "self_attn.to_add_out.": f"{gen_lp}.cross_attention.to_out.", + "self_attn.norm_added_q.": f"{gen_lp}.cross_attention.norm_q.", + "self_attn.norm_added_k.": f"{gen_lp}.cross_attention.norm_k.", + # Norms + "input_layernorm.": f"{und_lp}.input_layernorm.", + "post_attention_layernorm.": f"{und_lp}.post_attention_layernorm.", + "input_layernorm_moe_gen.": f"{gen_lp}.input_layernorm.", + "post_attention_layernorm_moe_gen.": f"{gen_lp}.post_attention_layernorm.", + # UND MLP + "mlp.gate_proj.": f"{und_lp}.mlp.gate_proj.", + "mlp.up_proj.": f"{und_lp}.mlp.up_proj.", + "mlp.down_proj.": f"{und_lp}.mlp.down_proj.", + # GEN MLP + "mlp_moe_gen.gate_proj.": f"{gen_lp}.mlp.gate_proj.", + "mlp_moe_gen.up_proj.": f"{gen_lp}.mlp.up_proj.", + "mlp_moe_gen.down_proj.": f"{gen_lp}.mlp.down_proj.", + } + + for pattern, replacement in _LAYER_MAP.items(): + if rest.startswith(pattern): + suffix = rest[len(pattern) :] + return replacement + suffix + + return None + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + """Stream-remap checkpoint weights and load via AutoWeightsLoader. + + Handles quantization, TP-aware weight_loader, and buffer loading. + Returns the set of loaded parameter names for strict validation. + """ + state = self.state_dict() + allowed = set(state.keys()) + tp_aware = {n for n, p in self.named_parameters() if hasattr(p, "weight_loader")} + + def _remapped_weights() -> Iterable[tuple[str, torch.Tensor]]: + total = kept = 0 + for name, tensor in weights: + total += 1 + remapped = self._remap_ckpt_key(name) + if remapped is not None and (remapped in allowed or remapped in tp_aware): + kept += 1 + yield remapped, tensor + if _is_rank_zero(): + logger.info( + "Cosmos3 weight remap: kept %d/%d tensors", + kept, + total, + ) + + loader = AutoWeightsLoader(self) + loaded = loader.load_weights(_remapped_weights()) + self.transformer.post_load_weights() + self.transformer.eval() + if getattr(self.transformer, "sound_gen", False): + sound_markers = ("audio_proj_in.", "audio_proj_out.", "audio_modality_embed") + missing = [marker.rstrip(".") for marker in sound_markers if not any(marker in name for name in loaded)] + if missing: + raise ValueError( + "Cosmos3 transformer config enables sound generation, but " + f"the checkpoint is missing sound weights for {missing}. " + "Use a sound-capable transformer checkpoint." + ) + if getattr(self.transformer, "action_gen", False): + action_markers = ("action_proj_in.", "action_proj_out.", "action_modality_embed") + missing = [marker.rstrip(".") for marker in action_markers if not any(marker in name for name in loaded)] + if missing: + raise ValueError( + "Cosmos3 transformer config enables action generation, but " + f"the checkpoint is missing action weights for {missing}. " + "Use an action-capable transformer checkpoint." + ) + return loaded + + def predict_noise(self, **kwargs) -> torch.Tensor | tuple[torch.Tensor, ...]: + """Override CFGParallelMixin.predict_noise for Cosmos3. + + The transformer returns the raw prediction: video-only as a tensor, + or a tuple in video, action, sound order for multimodal generation. + """ + return self.transformer(**kwargs) + + @staticmethod + def _cfg_parallel_active() -> bool: + try: + return get_classifier_free_guidance_world_size() > 1 + except Exception: + return False + + @staticmethod + def _get_sp_param(sp, key: str, default=None): + """Read a runtime control from sampling params. + + Order of precedence: + 1. ``sp.extra_args[key]`` - preferred path; the OpenAI image/video + endpoints surface custom controls here (see e.g. + ``serving_video.py`` writing ``extra_args['flow_shift']``). + 2. direct attribute on ``sp`` - backward compat for callers that + set attributes directly. + 3. ``default``. + + Skipping this helper would cause API-driven overrides like + ``request.flow_shift`` (forwarded as ``extra_args['flow_shift']``) to + be silently ignored. + """ + extra = getattr(sp, "extra_args", None) + if isinstance(extra, dict) and extra.get(key) is not None: + return extra[key] + val = getattr(sp, key, None) + if val is not None: + return val + return default + + @staticmethod + def _truthy(value) -> bool: + if isinstance(value, str): + return value.strip().lower() in {"1", "true", "yes", "on"} + return bool(value) + + @classmethod + def _get_prompt_param(cls, prompt_data, key: str, default=None): + if not isinstance(prompt_data, dict): + return default + if prompt_data.get(key) is not None: + return prompt_data[key] + additional = prompt_data.get("additional_information") + if isinstance(additional, dict) and additional.get(key) is not None: + return additional[key] + return default + + @classmethod + def _is_sound_request(cls, prompt_data, sp) -> bool: + keys = ( + "sound_gen", + "generate_sound", + "enable_sound_generation", + "return_audio", + "output_audio", + "generate_audio", + ) + for key in keys: + if cls._truthy(cls._get_prompt_param(prompt_data, key, None)): + return True + if cls._truthy(cls._get_sp_param(sp, key, None)): + return True + return False + + @classmethod + def _get_action_mode(cls, prompt_data, sp) -> str | None: + return normalize_action_mode( + cls._get_sp_param(sp, "action_mode", cls._get_prompt_param(prompt_data, "action_mode", None)) + ) + + def _get_sound_tokenizer(self): + if not hasattr(self, "_sound_tokenizer"): + self._sound_tokenizer = None + if self._sound_tokenizer is None: + from .sound_tokenizer import Cosmos3SoundTokenizer + + self._sound_tokenizer = Cosmos3SoundTokenizer.from_config(self.od_config) + return self._sound_tokenizer + + @staticmethod + def _is_t2i_request(req: OmniDiffusionRequest) -> bool: + """Detect text-to-image mode from request-level prompt modalities.""" + if not req.prompts: + return False + first_prompt = req.prompts[0] + modalities = first_prompt.get("modalities", []) if isinstance(first_prompt, dict) else [] + if modalities is None: + modalities = [] + if isinstance(modalities, str): + modalities = [modalities] + if "image" in modalities and "video" in modalities: + raise ValueError("Cosmos3 prompt modalities cannot request both image and video output.") + + accepted_modalities = ["image", "video", "text", "audio"] + if any([x not in accepted_modalities for x in modalities]): + raise ValueError(f"Incorrect modality value in {modalities}, expected one of {accepted_modalities}.") + return "image" in modalities + + def _set_flow_shift(self, target_shift: float) -> None: + """Set the UniPC ``flow_shift`` to a concrete target value. + + The scheduler is rebuilt from the saved base config if + the target differs from the current shift. Tracking + ``self._current_flow_shift`` explicitly is required because the + previous mode may have rebuilt the scheduler - we cannot rely on + ``self.scheduler.config.flow_shift`` reflecting the last requested + target if a rebuild was skipped via the equality check. + """ + target = float(target_shift) + if target == float(self._current_flow_shift): + return + self.scheduler = UniPCMultistepScheduler.from_config(self._base_scheduler_config, flow_shift=target) + self._current_flow_shift = target + + def _set_scheduler_timesteps(self, num_inference_steps: int) -> None: + for name, value in vars(self.scheduler).items(): + if isinstance(value, torch.Tensor) and value.device.type != "cpu": + setattr(self.scheduler, name, value.cpu()) + self.scheduler.set_timesteps(num_inference_steps, device=self.device) + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale is not None and self._guidance_scale > 1.0 + + @property + def num_timesteps(self): + return self._num_timesteps + + # -- Prompt formatting ----------------------------------------------------- + + @staticmethod + def _apply_metadata_templates( + prompt: str, + num_frames: int, + frame_rate: float, + height: int, + width: int, + duration_template: str | None = COSMOS3_DURATION_TEMPLATE, + resolution_template: str | None = COSMOS3_RESOLUTION_TEMPLATE, + force_duration_template: bool = False, + ) -> str: + """ + Append duration and resolution metadata to a prompt. + """ + parts: list[str] = [] + head = prompt.rstrip(".").strip() + if head: + parts.append(head) + if duration_template is not None and (num_frames > 1 or force_duration_template): + duration = num_frames / frame_rate + parts.append(duration_template.format(duration=duration, fps=frame_rate).rstrip(".")) + if resolution_template is not None: + parts.append(resolution_template.format(height=height, width=width).rstrip(".")) + if not parts: + return "" + return ". ".join(parts) + "." + + # -- Tokenization -------------------------------------------------------- + + def _tokenize_prompt( + self, + text: str, + max_sequence_length: int, + use_system_prompt: bool = False, + system_prompt: str | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Tokenize a prompt using the Qwen2 chat template. + + Returns (input_ids, attention_mask) as [1, S] tensors on device. + """ + conversations = [] + if use_system_prompt: + conversations.append( + { + "role": "system", + "content": system_prompt or COSMOS3_SYSTEM_PROMPT, + } + ) + conversations.append({"role": "user", "content": text}) + + token_ids = self._normalize_token_ids( + self.tokenizer.apply_chat_template(conversations, tokenize=True, add_generation_prompt=True) + ) + original_token_count = len(token_ids) + if original_token_count > max_sequence_length and _is_rank_zero(): + logger.warning( + "Cosmos3 prompt token_ids shortened to max_sequence_length: " + "original_token_count=%d, max_sequence_length=%d, removed_token_count=%d", + original_token_count, + max_sequence_length, + original_token_count - max_sequence_length, + ) + token_ids = token_ids[:max_sequence_length] + token_ids.append(self.tokenizer.eos_token_id) # 151645 + token_ids.append(self.tokenizer.convert_tokens_to_ids("<|vision_start|>")) # 151652 + seq_len = len(token_ids) + + pad_len = max_sequence_length - seq_len + attention_mask = [1] * seq_len + [0] * pad_len + token_ids = token_ids + [self.tokenizer.pad_token_id or 0] * pad_len + + input_ids = torch.tensor([token_ids], dtype=torch.long, device=self.device) + attention_mask = torch.tensor([attention_mask], dtype=torch.long, device=self.device) + return input_ids, attention_mask + + @staticmethod + def _normalize_token_ids(tokenized_output: object) -> list[int]: + """Normalize tokenizer outputs into a flat ``list[int]``. + + Different Transformers/tokenizers versions can return ``list[int]``, + a mapping/BatchEncoding with ``input_ids``, tensors, or + ``tokenizers.Encoding`` objects from ``apply_chat_template``. + """ + token_ids = tokenized_output + while True: + if isinstance(token_ids, dict) and "input_ids" in token_ids: + token_ids = token_ids["input_ids"] + elif hasattr(token_ids, "input_ids"): + token_ids = token_ids.input_ids + elif hasattr(token_ids, "ids"): + token_ids = token_ids.ids + elif hasattr(token_ids, "tolist"): + token_ids = token_ids.tolist() + elif isinstance(token_ids, tuple): + token_ids = list(token_ids) + elif isinstance(token_ids, list) and len(token_ids) == 1: + first = token_ids[0] + if isinstance(first, list | tuple): + token_ids = list(first) + elif hasattr(first, "ids") or hasattr(first, "input_ids"): + token_ids = first + elif hasattr(first, "tolist"): + first_list = first.tolist() + if isinstance(first_list, list | tuple): + token_ids = list(first_list) + else: + break + else: + break + else: + break + + if not isinstance(token_ids, list): + raise TypeError( + "Cosmos3 tokenizer must return token IDs as a list-like value; " + f"got {type(token_ids).__name__}: {token_ids!r}" + ) + + normalized_ids = [] + for idx, token_id in enumerate(token_ids): + if hasattr(token_id, "item"): + token_id = token_id.item() + try: + normalized_ids.append(int(token_id)) + except (TypeError, ValueError) as exc: + raise TypeError( + "Cosmos3 tokenizer returned a non-integer token at " + f"index {idx}: {type(token_id).__name__}: {token_id!r}" + ) from exc + return normalized_ids + + # -- Latent preparation -------------------------------------------------- + + def _prepare_latents( + self, + height: int, + width: int, + num_frames: int, + generator: torch.Generator, + ) -> torch.Tensor: + num_channels_latents = self.transformer.latent_channel_size + num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + shape = ( + 1, + num_channels_latents, + num_latent_frames, + height // self.vae_scale_factor_spatial, + width // self.vae_scale_factor_spatial, + ) + return randn_tensor(shape, generator=generator, device=self.device, dtype=self.dtype) + + def _prepare_sound_latents( + self, + target_audio_samples: int, + generator: torch.Generator, + ) -> tuple[torch.Tensor, int]: + sound_tokenizer = self._get_sound_tokenizer() + hop_size = int( + getattr(sound_tokenizer, "hop_size", None) or getattr(sound_tokenizer, "temporal_compression_factor") + ) + latent_frames = max(1, math.ceil(max(1, int(target_audio_samples)) / hop_size)) + sound_dim = int(getattr(sound_tokenizer, "latent_ch", 64)) + transformer_sound_dim = int(getattr(self.transformer, "sound_dim", sound_dim)) + if sound_dim != transformer_sound_dim: + raise ValueError( + "Cosmos3 sound tokenizer latent channels do not match transformer " + f"sound_dim: tokenizer={sound_dim}, transformer={transformer_sound_dim}." + ) + latents = randn_tensor( + (1, sound_dim, latent_frames), + generator=generator, + device=self.device, + dtype=self.dtype, + ) + return latents, latent_frames + + def _resolve_sound_target_samples( + self, + sp, + num_frames: int, + frame_rate: float, + ) -> tuple[int, float, int]: + sound_tokenizer = self._get_sound_tokenizer() + duration = self._get_sp_param(sp, "sound_duration", None) + if duration is None: + duration = self._get_sp_param(sp, "audio_duration", None) + if duration is None: + duration = num_frames / frame_rate + duration = max(float(duration), 1.0 / max(float(frame_rate), 1.0)) + sample_rate = int(getattr(sound_tokenizer, "sample_rate", 48000)) + return max(1, int(round(duration * sample_rate))), duration, sample_rate + + # -- VAE decode ---------------------------------------------------------- + + def _decode_latents(self, latents: torch.Tensor) -> torch.Tensor: + latents = latents.to(self.vae.dtype) + + if hasattr(self.vae.config, "latents_mean") and hasattr(self.vae.config, "latents_std"): + if not hasattr(self, "_latents_mean"): + self._latents_mean = ( + torch.tensor(self.vae.config.latents_mean).view(1, -1, 1, 1, 1).to(self.device, self.vae.dtype) + ) + self._latents_std = ( + torch.tensor(self.vae.config.latents_std).view(1, -1, 1, 1, 1).to(self.device, self.vae.dtype) + ) + latents = (latents * self._latents_std) + self._latents_mean + else: + scaling_factor = getattr(self.vae.config, "scaling_factor", 1.0) + latents = latents / scaling_factor + + video = self.vae.decode(latents, return_dict=False)[0] + return video + + def _decode_sound_latents( + self, + sound_latents: torch.Tensor, + target_audio_samples: int, + ) -> torch.Tensor: + sound_tokenizer = self._get_sound_tokenizer() + audio = sound_tokenizer.decode(sound_latents.to(self.dtype)) + if audio.shape[-1] > target_audio_samples: + audio = audio[..., :target_audio_samples] + elif audio.shape[-1] < target_audio_samples: + audio = torch.nn.functional.pad(audio, (0, target_audio_samples - audio.shape[-1])) + return audio.detach().cpu() + + # -- Prompt formatting + tokenization (shared by T2V and I2V) ------------ + + def _format_and_tokenize_prompts( + self, + prompt: str, + negative_prompt: str, + num_frames: int, + frame_rate: float, + height: int, + width: int, + max_sequence_length: int, + sp, + use_system_prompt: bool = False, + is_t2i: bool = False, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Format prompts with metadata templates and tokenize. + + Returns (cond_ids, cond_mask, uncond_ids, uncond_mask). + + For T2I (``is_t2i=True``) the duration template is suppressed (no FPS + or duration concept for a single image) and the image-flavored + resolution template is used. + """ + # Route cosmos3-specific controls through ``_get_sp_param`` so they + # are picked up from ``extra_args`` (OpenAI endpoint path) as well + # as from direct attributes. + use_duration_template = bool(self._get_sp_param(sp, "use_duration_template", True)) and not is_t2i + dur_tmpl = COSMOS3_DURATION_TEMPLATE if use_duration_template else None + if bool(self._get_sp_param(sp, "use_resolution_template", True)): + res_tmpl = COSMOS3_IMAGE_RESOLUTION_TEMPLATE if is_t2i else COSMOS3_RESOLUTION_TEMPLATE + else: + res_tmpl = None + prompt = self._apply_metadata_templates( + prompt, + num_frames, + frame_rate, + height, + width, + duration_template=dur_tmpl, + resolution_template=res_tmpl, + ) + if _is_rank_zero(): + logger.info("Final prompt: '%s'", prompt) + + # Negative prompt: inverse templates ("not {duration}...", "not {height}x{width}..."). + # Applied whenever the matching positive template is enabled; an empty + # negative_prompt yields output that starts with the template, not a dot. + inv_dur = COSMOS3_INVERSE_DURATION_TEMPLATE if dur_tmpl else None + if res_tmpl is None: + inv_res = None + elif is_t2i: + inv_res = COSMOS3_INVERSE_IMAGE_RESOLUTION_TEMPLATE + else: + inv_res = COSMOS3_INVERSE_RESOLUTION_TEMPLATE + negative_prompt = self._apply_metadata_templates( + negative_prompt, + num_frames, + frame_rate, + height, + width, + duration_template=inv_dur, + resolution_template=inv_res, + force_duration_template=True, + ) + + default_sys_prompt = COSMOS3_T2I_SYSTEM_PROMPT if is_t2i else COSMOS3_SYSTEM_PROMPT + sys_prompt = self._get_sp_param(sp, "system_prompt", default_sys_prompt) or default_sys_prompt + cond_ids, cond_mask = self._tokenize_prompt( + prompt, max_sequence_length, use_system_prompt, system_prompt=sys_prompt + ) + uncond_ids, uncond_mask = self._tokenize_prompt( + negative_prompt, max_sequence_length, use_system_prompt, system_prompt=sys_prompt + ) + return cond_ids, cond_mask, uncond_ids, uncond_mask + + # -- I2V latent preparation --------------------------------------------- + + def _encode_conditioning_video( + self, + image_tensor: torch.Tensor, + num_frames: int, + height: int, + width: int, + ) -> torch.Tensor: + """VAE-encode a conditioning image as a full-length video. + + The WAN VAE has temporal compression (factor 4), so encoding a single + frame produces degenerate temporal features. We fill the entire + pixel-space video with the conditioning image (repeating it across all + frames) so the temporal encoder sees plausible content everywhere. + The caller keeps only the conditioned latent frame(s) and replaces + the rest with noise. + """ + # image_tensor: [1, 3, H, W] -> [1, 3, num_frames, H, W] + video = image_tensor.unsqueeze(2).expand(-1, -1, num_frames, -1, -1).contiguous() + video = video.to(device=self.device, dtype=self.vae.dtype) + + latent = self.vae.encode(video).latent_dist.mode() + + # Normalize (inverse of _decode_latents denormalization) + if hasattr(self.vae.config, "latents_mean") and hasattr(self.vae.config, "latents_std"): + latents_mean = ( + torch.tensor(self.vae.config.latents_mean).view(1, -1, 1, 1, 1).to(latent.device, latent.dtype) + ) + latents_std = torch.tensor(self.vae.config.latents_std).view(1, -1, 1, 1, 1).to(latent.device, latent.dtype) + latent = (latent - latents_mean) / latents_std + else: + scaling_factor = getattr(self.vae.config, "scaling_factor", 1.0) + latent = latent * scaling_factor + + return latent.to(self.dtype) + + def _encode_video_tensor(self, video_tensor: torch.Tensor) -> torch.Tensor: + """VAE-encode a preprocessed pixel video [1, 3, T, H, W].""" + if video_tensor.ndim == 4: + video_tensor = video_tensor.unsqueeze(0) + if video_tensor.ndim != 5: + raise ValueError(f"Cosmos3 video tensor must have shape [1, 3, T, H, W], got {tuple(video_tensor.shape)}.") + if video_tensor.shape[0] != 1 or video_tensor.shape[1] != 3: + raise ValueError(f"Cosmos3 video tensor must have shape [1, 3, T, H, W], got {tuple(video_tensor.shape)}.") + + video = video_tensor.to(device=self.device, dtype=self.vae.dtype) + latent = self.vae.encode(video).latent_dist.mode() + + if hasattr(self.vae.config, "latents_mean") and hasattr(self.vae.config, "latents_std"): + latents_mean = ( + torch.tensor(self.vae.config.latents_mean).view(1, -1, 1, 1, 1).to(latent.device, latent.dtype) + ) + latents_std = torch.tensor(self.vae.config.latents_std).view(1, -1, 1, 1, 1).to(latent.device, latent.dtype) + latent = (latent - latents_mean) / latents_std + else: + scaling_factor = getattr(self.vae.config, "scaling_factor", 1.0) + latent = latent * scaling_factor + + return latent.to(self.dtype) + + def _prepare_latents_i2v( + self, + image_tensor: torch.Tensor, + height: int, + width: int, + num_frames: int, + generator: torch.Generator, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Prepare initial latents with frame 0 conditioned on the input image. + + Returns: + latents: [1, C, T_lat, H_lat, W_lat] with frame 0 = image, rest = noise + velocity_mask: [1, 1, T_lat, 1, 1] with frame 0 = 0, rest = 1 + image_latent: [1, C, 1, H_lat, W_lat] clean frame 0 for re-injection + """ + C = self.transformer.latent_channel_size + T_lat = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + H_lat = height // self.vae_scale_factor_spatial + W_lat = width // self.vae_scale_factor_spatial + + noise = randn_tensor( + (1, C, T_lat, H_lat, W_lat), + generator=generator, + device=self.device, + dtype=self.dtype, + ) + + cond_latent = self._encode_conditioning_video(image_tensor, num_frames, height, width) + image_latent = cond_latent[:, :, 0:1, :, :] + + condition_mask = torch.zeros(1, 1, T_lat, 1, 1, device=self.device, dtype=self.dtype) + condition_mask[:, :, 0, :, :] = 1.0 + latents = condition_mask * cond_latent + (1.0 - condition_mask) * noise + velocity_mask = 1.0 - condition_mask + return latents, velocity_mask, image_latent + + def _prepare_latents_action_video( + self, + video_tensor: torch.Tensor, + mode: str, + height: int, + width: int, + num_frames: int, + generator: torch.Generator, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Prepare video latents for action modes with mode-specific conditioning.""" + del height, width + C = self.transformer.latent_channel_size + T_lat = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + H_lat = video_tensor.shape[-2] // self.vae_scale_factor_spatial + W_lat = video_tensor.shape[-1] // self.vae_scale_factor_spatial + + noise = randn_tensor( + (1, C, T_lat, H_lat, W_lat), + generator=generator, + device=self.device, + dtype=self.dtype, + ) + cond_latent = self._encode_video_tensor(video_tensor) + if cond_latent.shape[2:] != noise.shape[2:]: + raise ValueError( + "Cosmos3 action video latent shape mismatch: " + f"encoded={tuple(cond_latent.shape)}, expected={tuple(noise.shape)}." + ) + condition_mask = build_vision_condition_mask( + mode, + num_frames, + self.vae_scale_factor_temporal, + device=self.device, + dtype=self.dtype, + ) + latents = condition_mask * cond_latent + (1.0 - condition_mask) * noise + velocity_mask = 1.0 - condition_mask + return latents, velocity_mask, cond_latent + + def _prepare_action_latents( + self, + *, + mode: str, + action_chunk_size: int, + raw_action_dim: int | None, + generator: torch.Generator, + sp, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]: + action_dim = int(getattr(self.transformer, "action_dim", 64)) + if mode == ACTION_MODE_FORWARD_DYNAMICS: + action = load_action_tensor( + self._get_sp_param(sp, "action", None), + self._get_sp_param(sp, "action_path", None), + ) + if action.shape[0] < action_chunk_size: + pad = action[-1:].repeat(action_chunk_size - action.shape[0], 1) + action = torch.cat([action, pad], dim=0) + elif action.shape[0] > action_chunk_size: + action = action[:action_chunk_size] + if raw_action_dim is None: + raw_action_dim = int(action.shape[-1]) + clean_action = pad_action_to_dim(action, action_dim) + else: + if raw_action_dim is None: + raise ValueError( + "Cosmos3 action_mode='policy' and 'inverse_dynamics' require extra_args['raw_action_dim']." + ) + clean_action = torch.zeros(action_chunk_size, action_dim, dtype=torch.float32) + + raw_action_dim = int(raw_action_dim) + if raw_action_dim <= 0 or raw_action_dim > action_dim: + raise ValueError(f"Cosmos3 raw_action_dim must be in [1, {action_dim}], got {raw_action_dim}.") + + clean_action = clean_action.to(device=self.device, dtype=self.dtype).unsqueeze(0) + condition_mask = build_action_condition_mask( + mode, + action_chunk_size, + device=self.device, + dtype=self.dtype, + ) + noise = randn_tensor( + (1, action_chunk_size, action_dim), + generator=generator, + device=self.device, + dtype=self.dtype, + ) + noise[:, :, raw_action_dim:] = 0 + clean_action[:, :, raw_action_dim:] = 0 + action_latents = condition_mask * clean_action + (1.0 - condition_mask) * noise + action_velocity_mask = 1.0 - condition_mask + return action_latents, action_velocity_mask, clean_action, raw_action_dim + + # -- Denoising loop (shared by T2V and I2V) ----------------------------- + + def diffuse( + self, + latents: torch.Tensor, + timesteps: torch.Tensor, + cond_ids: torch.Tensor, + cond_mask: torch.Tensor, + uncond_ids: torch.Tensor, + uncond_mask: torch.Tensor, + guidance_scale: float, + shared_kwargs: dict, + *, + action_latents: torch.Tensor | None = None, + action_velocity_mask: torch.Tensor | None = None, + action_condition_latents: torch.Tensor | None = None, + sound_latents: torch.Tensor | None = None, + velocity_mask: torch.Tensor | None = None, + image_latent: torch.Tensor | None = None, + condition_latents: torch.Tensor | None = None, + guidance_interval: tuple[float, float] | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, ...]: + """Denoising loop with 3-mode CFG support (parallel, sequential, none). + + Cosmos3's UND pathway is text-dependent, so CFG needs separate K/V + caches for conditional and unconditional text. + + Two modes: + 1. CFG parallel (multi-GPU): each rank handles one condition via + predict_noise_maybe_with_cfg; caching is rank-local. + 2. Sequential CFG (single-GPU or cfg_size=1): two separate + forward passes with explicit cache swapping. We cannot + batch B=2 because different text lengths would cause the + shorter branch to attend to padding in cross-attention. + + I2V conditioning (when both arguments are supplied): + * ``velocity_mask`` zeros frame-0 noise predictions before stepping. + * ``image_latent`` is re-injected into frame 0 after each scheduler + step, since UniPC's predictor-corrector update rescales the + sample (sigma-dependent), so even zero velocity does not preserve + frame 0. + + ``guidance_interval`` (T2I) restricts CFG to + timesteps inside the closed interval ``[lo, hi]``. The interval is + compared against the raw scheduler timestep value; works for both + the [0, 1000] discrete scale and normalized flow-matching scales. + Outside the interval the cond/uncond delta is zeroed so all ranks + continue to execute identical control flow (CFG-Parallel safe). + """ + do_cfg = guidance_scale > 1.0 + cfg_parallel = self._cfg_parallel_active() and do_cfg + self.transformer.reset_cache() + + def _cfg_active_at(t: torch.Tensor) -> bool: + if guidance_interval is None: + return True + t_scalar = float(t.item()) if torch.is_tensor(t) else float(t) + lo, hi = guidance_interval + return lo <= t_scalar <= hi + + def _pack_joint( + video_tensor: torch.Tensor, + action_tensor: torch.Tensor | None = None, + sound_tensor: torch.Tensor | None = None, + ): + batch = video_tensor.shape[0] + tensors = [video_tensor] + if action_tensor is not None: + tensors.append(action_tensor) + if sound_tensor is not None: + tensors.append(sound_tensor) + flats = [tensor.reshape(batch, -1) for tensor in tensors] + return torch.cat(flats, dim=1), [tensor.shape for tensor in tensors], [flat.shape[1] for flat in flats] + + def _unpack_joint( + packed: torch.Tensor, + shapes: list[torch.Size], + numels: list[int], + ) -> tuple[torch.Tensor, ...]: + outputs = [] + offset = 0 + for shape, numel in zip(shapes, numels, strict=True): + outputs.append(packed[:, offset : offset + numel].reshape(shape)) + offset += numel + return tuple(outputs) + + def _split_noise_pred( + noise_pred: torch.Tensor | tuple[torch.Tensor, ...], + ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]: + has_action = action_latents is not None + has_sound = sound_latents is not None + if not has_action and not has_sound: + if isinstance(noise_pred, tuple): + raise ValueError("Cosmos3 video-only diffusion received tuple predictions.") + return noise_pred, None, None + if not isinstance(noise_pred, tuple): + raise ValueError("Cosmos3 multimodal diffusion expects transformer predictions as a tuple.") + expected = 1 + int(has_action) + int(has_sound) + if len(noise_pred) != expected: + raise ValueError( + f"Cosmos3 multimodal diffusion expected {expected} predictions, got {len(noise_pred)}." + ) + video_pred = noise_pred[0] + idx = 1 + action_pred = noise_pred[idx] if has_action else None + if has_action: + idx += 1 + sound_pred = noise_pred[idx] if has_sound else None + return video_pred, action_pred, sound_pred + + def _step( + noise_pred: torch.Tensor | tuple[torch.Tensor, ...], + t: torch.Tensor, + latents: torch.Tensor, + action_latents: torch.Tensor | None, + sound_latents: torch.Tensor | None, + ) -> torch.Tensor | tuple[torch.Tensor, ...]: + video_pred, action_pred, sound_pred = _split_noise_pred(noise_pred) + if velocity_mask is not None: + video_pred = video_pred * velocity_mask + if action_pred is not None and action_velocity_mask is not None: + action_pred = action_pred * action_velocity_mask + if action_latents is None and sound_latents is None: + latents = self.scheduler.step(video_pred, t, latents, return_dict=False)[0] + else: + packed_noise, shapes, numels = _pack_joint(video_pred, action_pred, sound_pred) + packed_latents, _, _ = _pack_joint(latents, action_latents, sound_latents) + packed_next = self.scheduler.step(packed_noise, t, packed_latents, return_dict=False)[0] + unpacked = _unpack_joint(packed_next, shapes, numels) + latents = unpacked[0] + idx = 1 + if action_latents is not None: + action_latents = unpacked[idx] + idx += 1 + if sound_latents is not None: + sound_latents = unpacked[idx] + if condition_latents is not None and velocity_mask is not None: + latents = velocity_mask * latents + (1.0 - velocity_mask) * condition_latents + elif image_latent is not None: + latents[:, :, 0:1, :, :] = image_latent + if action_latents is not None and action_condition_latents is not None and action_velocity_mask is not None: + action_latents = ( + action_velocity_mask * action_latents + (1.0 - action_velocity_mask) * action_condition_latents + ) + outputs = [latents] + if action_latents is not None: + outputs.append(action_latents) + if sound_latents is not None: + outputs.append(sound_latents) + return outputs[0] if len(outputs) == 1 else tuple(outputs) + + def _assign_step_out(step_out: torch.Tensor | tuple[torch.Tensor, ...]) -> None: + nonlocal latents, action_latents, sound_latents + if action_latents is None and sound_latents is None: + assert isinstance(step_out, torch.Tensor) + latents = step_out + return + if not isinstance(step_out, tuple): + raise ValueError("Cosmos3 multimodal diffusion step returned a non-tuple result.") + latents = step_out[0] + idx = 1 + if action_latents is not None: + action_latents = step_out[idx] + idx += 1 + if sound_latents is not None: + sound_latents = step_out[idx] + + if cfg_parallel: + for t in self.progress_bar(timesteps): + timestep = t.unsqueeze(0) + # Out-of-interval steps run with effective scale 1.0 so the + # combined output equals the cond branch (uncond is dropped). + # All ranks still execute both branches; no CFG-Parallel + # divergence. + step_scale = guidance_scale if _cfg_active_at(t) else 1.0 + noise_pred = self.predict_noise_maybe_with_cfg( + do_true_cfg=True, + true_cfg_scale=step_scale, + positive_kwargs=dict( + hidden_states=latents, + timestep=timestep, + text_ids=cond_ids, + text_mask=cond_mask, + action_latents=action_latents, + sound_latents=sound_latents, + **shared_kwargs, + ), + negative_kwargs=dict( + hidden_states=latents, + timestep=timestep, + text_ids=uncond_ids, + text_mask=uncond_mask, + action_latents=action_latents, + sound_latents=sound_latents, + **shared_kwargs, + ), + cfg_normalize=False, + ) + _assign_step_out(_step(noise_pred, t, latents, action_latents, sound_latents)) + + elif do_cfg: + cond_cache: tuple = (None, None) + uncond_cache: tuple = (None, None) + + for t in self.progress_bar(timesteps): + timestep = t.unsqueeze(0) + cfg_active = _cfg_active_at(t) + + self.transformer.cached_kv, self.transformer.cached_freqs_gen = cond_cache + noise_cond = self.transformer( + hidden_states=latents, + timestep=timestep, + text_ids=cond_ids, + text_mask=cond_mask, + action_latents=action_latents, + sound_latents=sound_latents, + **shared_kwargs, + ) + if cond_cache[0] is None: + cond_cache = (self.transformer.cached_kv, self.transformer.cached_freqs_gen) + + if cfg_active: + self.transformer.cached_kv, self.transformer.cached_freqs_gen = uncond_cache + noise_uncond = self.transformer( + hidden_states=latents, + timestep=timestep, + text_ids=uncond_ids, + text_mask=uncond_mask, + action_latents=action_latents, + sound_latents=sound_latents, + **shared_kwargs, + ) + if uncond_cache[0] is None: + uncond_cache = (self.transformer.cached_kv, self.transformer.cached_freqs_gen) + noise_pred = self.combine_cfg_noise(noise_cond, noise_uncond, guidance_scale, cfg_normalize=False) + else: + # Skip uncond forward entirely outside the interval; this + # is correctness-preserving (CFG with scale=1 reduces to + # the cond branch) and gives a free speedup for T2I. + noise_pred = noise_cond + + _assign_step_out(_step(noise_pred, t, latents, action_latents, sound_latents)) + + else: + for t in self.progress_bar(timesteps): + timestep = t.unsqueeze(0) + noise_pred = self.transformer( + hidden_states=latents, + timestep=timestep, + text_ids=cond_ids, + text_mask=cond_mask, + action_latents=action_latents, + sound_latents=sound_latents, + **shared_kwargs, + ) + _assign_step_out(_step(noise_pred, t, latents, action_latents, sound_latents)) + + outputs = [latents] + if action_latents is not None: + outputs.append(action_latents) + if sound_latents is not None: + outputs.append(sound_latents) + return outputs[0] if len(outputs) == 1 else tuple(outputs) + + # -- Forward (main generation entry point) ------------------------------- + + def forward( + self, + req: OmniDiffusionRequest, + ) -> DiffusionOutput: + pipeline_start = time.time() + + # --- Parse request --- + if len(req.prompts) > 1: + raise ValueError("Cosmos3OmniDiffusersPipeline currently supports a single prompt per request.") + + prompt_data = req.prompts[0] + if isinstance(prompt_data, str): + prompt = prompt_data + negative_prompt = None + image_tensor = None + action_video_tensor = None + else: + prompt = prompt_data.get("prompt", "") + negative_prompt = prompt_data.get("negative_prompt") + additional_info = prompt_data.get("additional_information", {}) or {} + image_tensor = additional_info.get("preprocessed_image") + action_video_tensor = additional_info.get("preprocessed_video") + + sp = req.sampling_params + is_t2i = self._is_t2i_request(req) + sound_enabled = self._is_sound_request(prompt_data, sp) + action_mode = self._get_action_mode(prompt_data, sp) + action_enabled = action_mode is not None + if action_enabled and is_t2i: + raise ValueError("Cosmos3 action generation is supported only for video outputs.") + if action_enabled and sound_enabled: + raise ValueError("Cosmos3 action+sound joint generation is not supported in this phase.") + if action_enabled and not getattr(self.transformer, "action_gen", False): + raise ValueError( + "Cosmos3 action generation was requested, but the transformer was " + "initialized without action modules. Check that the checkpoint config " + "enables action_gen and includes action weights." + ) + if sound_enabled and is_t2i: + raise ValueError( + "Cosmos3 sound generation is supported only for video outputs in " + "this phase; text-to-image with sound is unsupported." + ) + if sound_enabled and not getattr(self.transformer, "sound_gen", False): + raise ValueError( + "Cosmos3 sound generation was requested, but the transformer was " + "initialized without sound modules. Check that the checkpoint config " + "enables sound_gen or defines sound_dim and includes sound weights." + ) + if negative_prompt is None: + negative_prompt = "" + + # T2I and T2V share the same model + forward path; only defaults + # differ: + # T2I: 1024x1024, 50 steps, shift=3.0, guidance_interval=[400, 1000] + # T2V: 720x1280, 35 steps, shift=engine-init, no interval + if is_t2i: + height = sp.height or 1024 + width = sp.width or 1024 + num_frames = 1 + num_inference_steps = sp.num_inference_steps or 50 + guidance_scale = sp.guidance_scale if sp.guidance_scale else 7.0 + default_flow_shift = 3.0 + default_guidance_interval: tuple[float, float] | None = (400.0, 1000.0) + batch_size = max(1, int(getattr(sp, "num_outputs_per_prompt", None) or 1)) + else: + height = sp.height or 720 + width = sp.width or 1280 + num_frames = sp.num_frames or 189 + num_inference_steps = sp.num_inference_steps or 35 + guidance_scale = sp.guidance_scale if sp.guidance_scale else 6.0 + # Fall back to the engine-init shift, NOT None: passing None + # to ``_set_flow_shift`` would leak a prior T2I rebuild + # (shift=3.0) into a subsequent video request. + default_flow_shift = self._engine_init_flow_shift + default_guidance_interval = None + batch_size = 1 # Existing video pipeline assumes B=1. + + if action_enabled: + action_chunk_param = self._get_sp_param(sp, "action_chunk_size", None) + if action_chunk_param is not None: + action_chunk_size = int(action_chunk_param) + if sp.num_frames is None: + num_frames = action_chunk_size + 1 + elif sp.num_frames is None: + action_chunk_size = 16 + num_frames = action_chunk_size + 1 + else: + action_chunk_size = int(num_frames) - 1 + if action_chunk_size <= 0: + raise ValueError(f"Cosmos3 action_chunk_size must be positive, got {action_chunk_size}.") + if num_frames not in (action_chunk_size, action_chunk_size + 1): + raise ValueError( + "Cosmos3 action requests require num_frames to equal action_chunk_size " + f"or action_chunk_size + 1; got num_frames={num_frames}, action_chunk_size={action_chunk_size}." + ) + num_inference_steps = sp.num_inference_steps or 30 + guidance_scale = sp.guidance_scale if sp.guidance_scale is not None else 1.0 + default_flow_shift = 5.0 + + domain_id = None + if action_enabled: + domain_id = resolve_domain_id( + domain_id=self._get_sp_param(sp, "domain_id", None), + domain_name=self._get_sp_param(sp, "domain_name", None), + require_explicit=True, + ) + + # Runtime controls: prefer ``extra_args`` (OpenAI endpoints write + # there) over direct attrs. + flow_shift_target = float(self._get_sp_param(sp, "flow_shift", default_flow_shift)) + guidance_interval = self._get_sp_param(sp, "guidance_interval", default_guidance_interval) + + frame_rate = self._get_sp_param(sp, "resolved_frame_rate") or self._get_sp_param(sp, "frame_rate") or 24.0 + max_sequence_length = self._get_sp_param(sp, "max_sequence_length", 512) or 512 + use_system_prompt = bool(self._get_sp_param(sp, "use_system_prompt", False)) + + if action_enabled and action_video_tensor is None: + extra_action_video = self._get_sp_param(sp, "action_video", None) + if isinstance(extra_action_video, torch.Tensor): + action_video_tensor = extra_action_video + if action_enabled and isinstance(action_video_tensor, torch.Tensor): + if action_video_tensor.ndim == 4: + action_video_tensor = action_video_tensor.unsqueeze(0) + if action_video_tensor.ndim != 5: + raise ValueError( + "Cosmos3 extra_args['action_video'] must have shape [1, 3, T, H, W] " + f"or [3, T, H, W], got {tuple(action_video_tensor.shape)}." + ) + if sp.height is None: + height = int(action_video_tensor.shape[-2]) + if sp.width is None: + width = int(action_video_tensor.shape[-1]) + + self._guidance_scale = guidance_scale + self._num_timesteps = num_inference_steps + + # Always resolve to a concrete target shift for this request, then + # update the scheduler. This is what guarantees mode-to-mode + # transitions restore the right schedule (no T2I to T2V leak). + self._set_flow_shift(flow_shift_target) + + generator = sp.generator + if generator is None: + seed = sp.seed if sp.seed is not None else 42 + generator = torch.Generator(device=self.device).manual_seed(seed) + + # --- Format prompts & tokenize (B=1; reused across loop iterations + # for T2I num_outputs_per_prompt > 1) --- + cond_ids, cond_mask, uncond_ids, uncond_mask = self._format_and_tokenize_prompts( + prompt, + negative_prompt, + num_frames, + frame_rate, + height, + width, + max_sequence_length, + sp, + use_system_prompt, + is_t2i=is_t2i, + ) + + # --- Prepare latents (T2I, T2V, or I2V) --- + # T2I shares _prepare_latents with T2V; the math collapses cleanly + # at num_frames=1 ((1-1)//4 + 1 = 1 latent frame). For T2I with + # ``num_outputs_per_prompt > 1`` we loop the diffusion below; + # batching B=N together would require expanding text K/V (UND + # pathway is text-only and cached) and is left as a future + # optimization. + action_latents = None + action_velocity_mask = None + action_condition_latents = None + raw_action_dim = None + action_offset = 1 + if action_enabled: + if action_video_tensor is not None and action_video_tensor.ndim == 4: + action_video_tensor = action_video_tensor.unsqueeze(0) + if action_video_tensor is not None and action_video_tensor.ndim != 5: + raise ValueError( + "Cosmos3 action video tensor must have shape [1, 3, T, H, W] " + f"or [3, T, H, W], got {tuple(action_video_tensor.shape)}." + ) + if action_video_tensor is not None and action_video_tensor.shape[2] < num_frames: + pad = action_video_tensor[:, :, -1:].repeat(1, 1, num_frames - action_video_tensor.shape[2], 1, 1) + action_video_tensor = torch.cat([action_video_tensor, pad], dim=2) + elif action_video_tensor is not None and action_video_tensor.shape[2] > num_frames: + action_video_tensor = action_video_tensor[:, :, :num_frames] + + if action_mode == ACTION_MODE_INVERSE_DYNAMICS and action_video_tensor is None: + raise ValueError("Cosmos3 inverse_dynamics action mode requires multi_modal_data['video'].") + if action_mode in {ACTION_MODE_POLICY, ACTION_MODE_FORWARD_DYNAMICS} and image_tensor is None: + if action_video_tensor is None: + raise ValueError( + f"Cosmos3 action_mode={action_mode!r} requires multi_modal_data['image'] " + "or multi_modal_data['video']." + ) + image_tensor = action_video_tensor[:, :, 0] + + raw_action_dim_param = self._get_sp_param(sp, "raw_action_dim", None) + raw_action_dim = int(raw_action_dim_param) if raw_action_dim_param is not None else None + action_prepared = self._prepare_action_latents( + mode=action_mode, + action_chunk_size=action_chunk_size, + raw_action_dim=raw_action_dim, + generator=generator, + sp=sp, + ) + action_latents, action_velocity_mask, action_condition_latents, raw_action_dim = action_prepared + action_offset = action_start_frame_offset(action_mode, action_chunk_size, num_frames) + + if action_enabled and action_video_tensor is not None: + latents, velocity_mask, condition_latents = self._prepare_latents_action_video( + action_video_tensor, + action_mode, + height, + width, + num_frames, + generator, + ) + image_latent = condition_latents[:, :, 0:1] + elif image_tensor is not None and not is_t2i: + latents, velocity_mask, image_latent = self._prepare_latents_i2v( + image_tensor, + height, + width, + num_frames, + generator, + ) + condition_latents = None + else: + latents = self._prepare_latents(height, width, num_frames, generator) + velocity_mask = None + image_latent = None + condition_latents = None + + sound_latents = None + target_audio_samples = None + sound_sample_rate = None + if sound_enabled: + target_audio_samples, _, sound_sample_rate = self._resolve_sound_target_samples(sp, num_frames, frame_rate) + sound_latents, _ = self._prepare_sound_latents(target_audio_samples, generator) + + T_latent = latents.shape[2] + H_latent = latents.shape[3] + W_latent = latents.shape[4] + video_shape = (T_latent, H_latent, W_latent) + + # --- Denoising loop --- + shared_kwargs = dict(video_shape=video_shape, fps=frame_rate) + if velocity_mask is not None: + shared_kwargs["noisy_frame_mask"] = velocity_mask + if action_enabled: + shared_kwargs.update( + action_domain_ids=torch.tensor([domain_id], dtype=torch.long, device=self.device), + action_noisy_mask=action_velocity_mask, + action_start_frame_offset=action_offset, + action_fps=float(self._get_sp_param(sp, "action_fps", frame_rate) or frame_rate), + ) + + def _run_diffusion(start_latents): + self._set_scheduler_timesteps(num_inference_steps) + return self.diffuse( + latents=start_latents, + timesteps=self.scheduler.timesteps, + cond_ids=cond_ids, + cond_mask=cond_mask, + uncond_ids=uncond_ids, + uncond_mask=uncond_mask, + guidance_scale=guidance_scale, + shared_kwargs=shared_kwargs, + action_latents=action_latents, + action_velocity_mask=action_velocity_mask, + action_condition_latents=action_condition_latents, + sound_latents=sound_latents, + velocity_mask=velocity_mask, + image_latent=image_latent, + condition_latents=condition_latents, + guidance_interval=guidance_interval, + ) + + if is_t2i and batch_size > 1: + # Generate N independent images by re-running the full diffusion + # loop with different noise seeds. The first sample reuses + # ``latents`` already drawn from ``generator``; subsequent + # samples draw fresh noise from the same generator (state + # advances per call), giving distinct outputs from a single + # user-provided seed. Batched B=N would be more efficient but + # requires expanding cached UND text K/V to match. + samples = [_run_diffusion(latents)] + for _ in range(batch_size - 1): + next_latents = self._prepare_latents(height, width, num_frames, generator) + samples.append(_run_diffusion(next_latents)) + latents = torch.cat(samples, dim=0) + else: + diffusion_output = _run_diffusion(latents) + if action_enabled and sound_enabled: + latents, action_latents, sound_latents = diffusion_output + elif action_enabled: + latents, action_latents = diffusion_output + elif sound_enabled: + latents, sound_latents = diffusion_output + else: + latents = diffusion_output + + # --- Decode --- + if _is_rank_zero(): + logger.info("Decoding video...") + decode_start = time.time() + video = self._decode_latents(latents) + if _is_rank_zero(): + logger.info("Video decoded in %.2fs", time.time() - decode_start) + logger.info("Total pipeline time: %.2fs", time.time() - pipeline_start) + + if sound_enabled: + if sound_latents is None or target_audio_samples is None or sound_sample_rate is None: + raise ValueError("Cosmos3 sound generation finished without sound latents.") + if _is_rank_zero(): + logger.info("Decoding sound...") + audio = self._decode_sound_latents(sound_latents, target_audio_samples) + return DiffusionOutput(output={"video": video, "audio": audio, "audio_sample_rate": sound_sample_rate}) + + if action_enabled: + if action_latents is None or raw_action_dim is None or domain_id is None: + raise ValueError("Cosmos3 action generation finished without action latents.") + action = action_latents[:, :, :raw_action_dim].detach().cpu() + return DiffusionOutput( + output={"video": video}, + custom_output={ + "action": action, + "raw_action_dim": raw_action_dim, + "action_mode": action_mode, + "domain_id": domain_id, + }, + ) + + return DiffusionOutput(output={"image": video} if is_t2i else {"video": video}) diff --git a/vllm_omni/diffusion/models/cosmos3/sound_tokenizer.py b/vllm_omni/diffusion/models/cosmos3/sound_tokenizer.py new file mode 100644 index 00000000000..281b7e1d9f0 --- /dev/null +++ b/vllm_omni/diffusion/models/cosmos3/sound_tokenizer.py @@ -0,0 +1,537 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Cosmos3 sound tokenizer integration.""" + +from __future__ import annotations + +import json +import os +from pathlib import Path +from typing import Any + +import torch +from vllm.logger import init_logger + +from vllm_omni.diffusion.data import OmniDiffusionConfig +from vllm_omni.diffusion.distributed.utils import get_local_device +from vllm_omni.diffusion.models.progress_bar import _is_rank_zero + +from .audio_tokenizer import Cosmos3AVAEAudioTokenizer + +logger = init_logger(__name__) + +DEFAULT_SOUND_SAMPLE_RATE = 48000 +DEFAULT_SOUND_CHANNELS = 2 +DEFAULT_SOUND_DIM = 64 +DEFAULT_SOUND_HOP_SIZE = 1920 +DEFAULT_SOUND_LATENT_FPS = DEFAULT_SOUND_SAMPLE_RATE / DEFAULT_SOUND_HOP_SIZE +DEFAULT_SOUND_NORMALIZE_LATENTS = False +DEFAULT_SOUND_NORMALIZATION_TYPE = "none" +DEFAULT_SOUND_TANH_INPUT_SCALE = 1.5 +DEFAULT_SOUND_TANH_OUTPUT_SCALE = 3.5 +DEFAULT_SOUND_TANH_CLAMP = 0.995 +SOUND_TOKENIZER_COMPONENT_NAME = "sound_tokenizer" +SOUND_TOKENIZER_CHECKPOINT_NAME = "diffusion_pytorch_model.safetensors" + + +def _pipeline_args(od_config: OmniDiffusionConfig) -> dict[str, Any]: + return dict(getattr(od_config, "custom_pipeline_args", None) or {}) + + +def _config_get(config: Any, key: str, default: Any = None) -> Any: + if config is None: + return default + if isinstance(config, dict): + return config.get(key, default) + if hasattr(config, "get"): + value = config.get(key, None) + return default if value is None else value + return getattr(config, key, default) + + +def _config_path_get(config: Any, *keys: str) -> Any: + value = config + for key in keys: + value = _config_get(value, key, None) + if value is None: + return None + return value + + +def _sound_tokenizer_config_from(config: Any) -> Any: + """Return nested ``sound_tokenizer`` config from Cosmos3 config shapes.""" + for path in ( + ("sound_tokenizer",), + ("model", "config", "sound_tokenizer"), + ("config", "sound_tokenizer"), + ("model_config", "sound_tokenizer"), + ): + value = _config_path_get(config, *path) + if value is not None: + return value + return None + + +def _nested_sound_tokenizer_configs(od_config: OmniDiffusionConfig | None) -> tuple[Any, ...]: + if od_config is None: + return () + configs = [] + for source in ( + getattr(od_config, "model_config", None), + getattr(od_config, "tf_model_config", None), + ): + config = _sound_tokenizer_config_from(source) + if config is not None: + configs.append(config) + return tuple(configs) + + +def _first_value_from_configs(configs: tuple[Any, ...], keys: tuple[str, ...]) -> Any: + for config in configs: + for key in keys: + value = _config_get(config, key, None) + if value is not None: + return value + return None + + +def _top_level_model_value(od_config: OmniDiffusionConfig | None, keys: tuple[str, ...]) -> Any: + if od_config is None: + return None + for source in ( + getattr(od_config, "model_config", None), + getattr(od_config, "tf_model_config", None), + ): + for key in keys: + for path in ((key,), ("model", "config", key), ("config", key), ("model_config", key)): + value = _config_path_get(source, *path) + if value is not None: + return value + return None + + +def _custom_arg_value(args: dict[str, Any], keys: tuple[str, ...]) -> Any: + for key in keys: + value = args.get(key) + if value is not None: + return value + return None + + +def _as_bool(value: Any) -> bool: + if isinstance(value, str): + return value.strip().lower() in {"1", "true", "yes", "on"} + return bool(value) + + +def _as_audio_channels(value: Any) -> int: + if isinstance(value, bool): + return 2 if value else 1 + if isinstance(value, str) and value.strip().lower() in { + "1", + "0", + "true", + "false", + "yes", + "no", + "on", + "off", + }: + return 2 if _as_bool(value) else 1 + return int(value) + + +def _resolve_model_file(path: Any, model_root: str | None) -> str | None: + if not path: + return None + path = str(path) + if "://" in path or os.path.isabs(path) or os.path.exists(path) or not model_root: + return path + return str(Path(model_root) / path) + + +def _load_sound_tokenizer_component_config(config_path: str | None) -> dict[str, Any]: + if not config_path: + return {} + with open(config_path, encoding="utf-8") as f: + config = json.load(f) + if not isinstance(config, dict): + raise TypeError(f"Cosmos3 sound tokenizer config must be a JSON object, got {type(config)!r}.") + return config + + +def _component_audio_channels(config: dict[str, Any]) -> Any: + if config.get("dec_out_channels") is not None: + return config["dec_out_channels"] + if config.get("audio_channels") is not None: + return config["audio_channels"] + if config.get("stereo") is not None: + return 2 if _as_bool(config["stereo"]) else 1 + return None + + +def _component_arch_values(config: dict[str, Any]) -> dict[str, Any]: + values = { + "sample_rate": config.get("sampling_rate", config.get("sample_rate")), + "audio_channels": _component_audio_channels(config), + "io_channels": config.get("vocoder_input_dim", config.get("io_channels", config.get("latent_ch"))), + "hop_size": config.get("hop_size"), + } + return {key: value for key, value in values.items() if value is not None} + + +def _resolve_arch_value( + od_config: OmniDiffusionConfig, + args: dict[str, Any], + component_values: dict[str, Any], + *, + field: str, + custom_keys: tuple[str, ...], + nested_keys: tuple[str, ...], + top_level_keys: tuple[str, ...], + default: Any, + cast, +) -> Any: + custom_value = _custom_arg_value(args, custom_keys) + component_value = component_values.get(field) + if component_value is not None: + resolved = cast(component_value) + if custom_value is not None and cast(custom_value) != resolved: + raise ValueError( + "Conflicting Cosmos3 sound tokenizer architecture override for " + f"{field}: component config has {resolved!r}, custom args have {cast(custom_value)!r}." + ) + return resolved + + if custom_value is not None: + return cast(custom_value) + + nested_value = _first_value_from_configs(_nested_sound_tokenizer_configs(od_config), nested_keys) + if nested_value is not None: + return cast(nested_value) + + top_value = _top_level_model_value(od_config, top_level_keys) + if top_value is not None: + return cast(top_value) + + return cast(default) + + +def _resolve_normalization_value( + od_config: OmniDiffusionConfig, + args: dict[str, Any], + *, + name: str, + default: Any, + aliases: tuple[str, ...] = (), +) -> Any: + keys = (f"sound_{name}", name, *aliases) + custom_value = _custom_arg_value(args, keys) + if custom_value is not None: + return custom_value + nested_value = _first_value_from_configs(_nested_sound_tokenizer_configs(od_config), (name, *aliases)) + return default if nested_value is None else nested_value + + +def get_sound_config_value( + od_config: OmniDiffusionConfig, + name: str, + default: Any, + aliases: tuple[str, ...] = (), +) -> Any: + # Backward-compatible generic accessor. Prefer the more specific helpers + # below for Cosmos3 sound tokenizer fields so precedence stays explicit. + keys = (name, *aliases) + for config in ( + _pipeline_args(od_config), + getattr(od_config, "model_config", None), + getattr(od_config, "tf_model_config", None), + ): + if config is None: + continue + for key in keys: + if hasattr(config, "get"): + value = config.get(key, None) + else: + value = getattr(config, key, None) + if value is not None: + return value + return default + + +def get_sound_sample_rate(od_config: OmniDiffusionConfig) -> int: + args = _pipeline_args(od_config) + return _resolve_arch_value( + od_config, + args, + {}, + field="sample_rate", + custom_keys=("sound_sample_rate", "sample_rate"), + nested_keys=("sample_rate", "sampling_rate"), + top_level_keys=("sound_sample_rate", "sample_rate"), + default=DEFAULT_SOUND_SAMPLE_RATE, + cast=int, + ) + + +def get_sound_channels(od_config: OmniDiffusionConfig) -> int: + args = _pipeline_args(od_config) + return _resolve_arch_value( + od_config, + args, + {}, + field="audio_channels", + custom_keys=("sound_audio_channels", "audio_channels", "stereo"), + nested_keys=("audio_channels", "dec_out_channels", "stereo"), + top_level_keys=("sound_audio_channels", "audio_channels", "stereo"), + default=DEFAULT_SOUND_CHANNELS, + cast=_as_audio_channels, + ) + + +def get_sound_dim(od_config: OmniDiffusionConfig | None) -> int: + if od_config is None: + return DEFAULT_SOUND_DIM + args = _pipeline_args(od_config) + custom_value = _custom_arg_value(args, ("sound_dim", "io_channels", "latent_ch")) + if custom_value is not None: + return int(custom_value) + top_value = _top_level_model_value(od_config, ("sound_dim",)) + if top_value is not None: + return int(top_value) + nested_value = _first_value_from_configs( + _nested_sound_tokenizer_configs(od_config), + ("io_channels", "vocoder_input_dim", "latent_ch"), + ) + return int(DEFAULT_SOUND_DIM if nested_value is None else nested_value) + + +def get_sound_hop_size(od_config: OmniDiffusionConfig) -> int: + args = _pipeline_args(od_config) + return _resolve_arch_value( + od_config, + args, + {}, + field="hop_size", + custom_keys=("sound_hop_size", "hop_size"), + nested_keys=("hop_size",), + top_level_keys=("sound_hop_size", "hop_size"), + default=DEFAULT_SOUND_HOP_SIZE, + cast=int, + ) + + +def get_sound_latent_fps(od_config: OmniDiffusionConfig | None) -> float: + if od_config is None: + return DEFAULT_SOUND_LATENT_FPS + args = _pipeline_args(od_config) + custom_value = _custom_arg_value(args, ("sound_latent_fps",)) + if custom_value is not None: + return float(custom_value) + top_value = _top_level_model_value(od_config, ("sound_latent_fps",)) + if top_value is not None: + return float(top_value) + nested_configs = _nested_sound_tokenizer_configs(od_config) + nested_fps = _first_value_from_configs(nested_configs, ("sound_latent_fps", "latent_fps")) + if nested_fps is not None: + return float(nested_fps) + sample_rate = _first_value_from_configs(nested_configs, ("sample_rate", "sampling_rate")) + hop_size = _first_value_from_configs(nested_configs, ("hop_size",)) + if sample_rate is not None and hop_size is not None: + return float(sample_rate) / float(hop_size) + return float(DEFAULT_SOUND_LATENT_FPS) + + +class Cosmos3SoundTokenizer: + """Thin adapter around the local AVAE tokenizer implementation.""" + + def __init__(self, tokenizer: Any) -> None: + self.tokenizer = tokenizer + self.sample_rate = int(getattr(tokenizer, "sample_rate", DEFAULT_SOUND_SAMPLE_RATE)) + self.audio_channels = int(getattr(tokenizer, "audio_channels", DEFAULT_SOUND_CHANNELS)) + self.latent_ch = int(getattr(tokenizer, "latent_ch", DEFAULT_SOUND_DIM)) + self.hop_size = int(getattr(tokenizer, "temporal_compression_factor", DEFAULT_SOUND_HOP_SIZE)) + + @classmethod + def from_config(cls, od_config: OmniDiffusionConfig) -> Cosmos3SoundTokenizer: + args = _pipeline_args(od_config) + model_path = getattr(od_config, "model", None) + explicit_avae_path = ( + args.get("sound_tokenizer_path") + or args.get("avae_path") + or args.get("cosmos3_avae_path") + or os.environ.get("COSMOS3_SOUND_TOKENIZER_PATH") + ) + explicit_config_path = args.get("sound_tokenizer_config_path") or os.environ.get( + "COSMOS3_SOUND_TOKENIZER_CONFIG_PATH" + ) + + model_root = str(model_path) if model_path and os.path.isdir(model_path) else None + if model_root is None and model_path and not explicit_avae_path: + from huggingface_hub import snapshot_download + + model_root = snapshot_download( + repo_id=str(model_path), + revision=getattr(od_config, "revision", None), + allow_patterns=[ + f"{SOUND_TOKENIZER_COMPONENT_NAME}/config.json", + f"{SOUND_TOKENIZER_COMPONENT_NAME}/{SOUND_TOKENIZER_CHECKPOINT_NAME}", + ], + ) + + if explicit_avae_path: + avae_path = _resolve_model_file(explicit_avae_path, model_root) + else: + tokenizer_dir = Path(model_root) / SOUND_TOKENIZER_COMPONENT_NAME if model_root else None + candidate = tokenizer_dir / SOUND_TOKENIZER_CHECKPOINT_NAME if tokenizer_dir else None + avae_path = str(candidate) if candidate and candidate.exists() else None + + if not avae_path: + raise ValueError( + "Cosmos3 sound generation was requested, but no AVAE sound " + "tokenizer checkpoint was provided. Set " + "custom_pipeline_args['sound_tokenizer_path'] or " + "COSMOS3_SOUND_TOKENIZER_PATH, or include " + f"{SOUND_TOKENIZER_COMPONENT_NAME}/{SOUND_TOKENIZER_CHECKPOINT_NAME} under the model path." + ) + + config_path = _resolve_model_file(explicit_config_path, model_root) + if config_path is None and model_root: + candidate = Path(model_root) / SOUND_TOKENIZER_COMPONENT_NAME / "config.json" + config_path = str(candidate) if candidate.exists() else None + component_config = _load_sound_tokenizer_component_config(config_path) + component_values = _component_arch_values(component_config) + + sample_rate = _resolve_arch_value( + od_config, + args, + component_values, + field="sample_rate", + custom_keys=("sound_sample_rate", "sample_rate"), + nested_keys=("sample_rate", "sampling_rate"), + top_level_keys=("sound_sample_rate", "sample_rate"), + default=DEFAULT_SOUND_SAMPLE_RATE, + cast=int, + ) + audio_channels = _resolve_arch_value( + od_config, + args, + component_values, + field="audio_channels", + custom_keys=("sound_audio_channels", "audio_channels", "stereo"), + nested_keys=("audio_channels", "dec_out_channels", "stereo"), + top_level_keys=("sound_audio_channels", "audio_channels", "stereo"), + default=DEFAULT_SOUND_CHANNELS, + cast=_as_audio_channels, + ) + sound_dim = _resolve_arch_value( + od_config, + args, + component_values, + field="io_channels", + custom_keys=("sound_dim", "io_channels", "latent_ch"), + nested_keys=("io_channels", "vocoder_input_dim", "latent_ch"), + top_level_keys=("sound_dim",), + default=DEFAULT_SOUND_DIM, + cast=int, + ) + hop_size = _resolve_arch_value( + od_config, + args, + component_values, + field="hop_size", + custom_keys=("sound_hop_size", "hop_size"), + nested_keys=("hop_size",), + top_level_keys=("sound_hop_size", "hop_size"), + default=DEFAULT_SOUND_HOP_SIZE, + cast=int, + ) + normalize_latents = _as_bool( + _resolve_normalization_value( + od_config, + args, + name="normalize_latents", + default=DEFAULT_SOUND_NORMALIZE_LATENTS, + ) + ) + normalization_type = str( + _resolve_normalization_value( + od_config, + args, + name="normalization_type", + default=DEFAULT_SOUND_NORMALIZATION_TYPE, + ) + ) + tanh_input_scale = float( + _resolve_normalization_value( + od_config, + args, + name="tanh_input_scale", + default=DEFAULT_SOUND_TANH_INPUT_SCALE, + ) + ) + tanh_output_scale = float( + _resolve_normalization_value( + od_config, + args, + name="tanh_output_scale", + default=DEFAULT_SOUND_TANH_OUTPUT_SCALE, + ) + ) + tanh_clamp = float( + _resolve_normalization_value( + od_config, + args, + name="tanh_clamp", + default=DEFAULT_SOUND_TANH_CLAMP, + ) + ) + tokenizer = Cosmos3AVAEAudioTokenizer( + checkpoint_path=str(avae_path), + config_path=config_path, + sample_rate=sample_rate, + audio_channels=audio_channels, + io_channels=sound_dim, + hop_size=hop_size, + normalize_latents=normalize_latents, + normalization_type=normalization_type, + tanh_input_scale=tanh_input_scale, + tanh_output_scale=tanh_output_scale, + tanh_clamp=tanh_clamp, + dtype=getattr(od_config, "dtype", torch.bfloat16), + device=get_local_device(), + ) + if _is_rank_zero(): + logger.info( + "Loaded Cosmos3 AVAE sound tokenizer from %s (sr=%d, channels=%d, latent_ch=%d, hop=%d)", + avae_path, + sample_rate, + audio_channels, + sound_dim, + hop_size, + ) + return cls(tokenizer) + + def get_latent_num_samples(self, num_audio_samples: int) -> int: + return int(self.tokenizer.get_latent_num_samples(num_audio_samples)) + + def get_audio_num_samples(self, num_latent_samples: int) -> int: + return int(self.tokenizer.get_audio_num_samples(num_latent_samples)) + + @torch.no_grad() + def decode(self, latents: torch.Tensor) -> torch.Tensor: + """Decode sound latents. + + Args: + latents: ``[B, C, T]`` or ``[C, T]`` tensor. + + Returns: + ``[B, audio_channels, N]`` tensor for batched input, or + ``[audio_channels, N]`` for unbatched input. + """ + squeeze = latents.ndim == 2 + if squeeze: + latents = latents.unsqueeze(0) + audio = self.tokenizer.decode(latents) + audio = audio.clamp(-1.0, 1.0) + return audio.squeeze(0) if squeeze else audio diff --git a/vllm_omni/diffusion/models/cosmos3/transformer_cosmos3.py b/vllm_omni/diffusion/models/cosmos3/transformer_cosmos3.py new file mode 100644 index 00000000000..52a52f8d042 --- /dev/null +++ b/vllm_omni/diffusion/models/cosmos3/transformer_cosmos3.py @@ -0,0 +1,1556 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Cosmos3 VFM Transformer for vllm-omni. + +Implements the Mixture-of-Transformers architecture with two pathways: +- Understanding (UND): causal self-attention on text tokens (Qwen3-VL backbone) +- Generation (GEN): cross-attention where visual Q attends to [K_und, K_gen] + +Ported from the TRT-LLM integration (tekit branch user/shreyasm/cosmos3). +""" + +from __future__ import annotations + +import math +from typing import Any + +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + RowParallelLinear, +) +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig, +) + +from vllm_omni.diffusion.attention.backends.abstract import AttentionMetadata +from vllm_omni.diffusion.attention.layer import Attention as FrameworkAttention +from vllm_omni.diffusion.data import OmniDiffusionConfig +from vllm_omni.diffusion.distributed.sp_plan import SequenceParallelInput, SequenceParallelOutput +from vllm_omni.diffusion.forward_context import get_forward_context, is_forward_context_available +from vllm_omni.diffusion.layers.norm import RMSNorm + +logger = init_logger(__name__) + + +def _get_ulysses_state() -> tuple[int, int, dist.ProcessGroup | None]: + """Return (ulysses_size, ulysses_rank, ulysses_pg) from vllm-omni parallel state. + + Returns (1, 0, None) when sequence parallelism is not active. + """ + from vllm_omni.diffusion.distributed.parallel_state import ( + get_sp_group, + get_ulysses_parallel_rank, + get_ulysses_parallel_world_size, + ) + + size = get_ulysses_parallel_world_size() + if size <= 1: + return 1, 0, None + return size, get_ulysses_parallel_rank(), get_sp_group().ulysses_group + + +def _is_sp_active() -> bool: + """Check whether sequence parallelism is active in the current forward context. + + Follows the Bagel pattern: read ``forward_context.sp_active`` which returns + True when ``sequence_parallel_size > 1`` even without ``_sp_plan`` hooks. + """ + + if not is_forward_context_available(): + return False + return get_forward_context().sp_active + + +def _tf_config_get(config: Any, key: str, default: Any) -> Any: + """Read a value from TransformerConfig, dict, or simple namespace.""" + if config is None: + return default + if hasattr(config, "get"): + return config.get(key, default) + return getattr(config, key, default) + + +def _nested_get(value: Any, key: str) -> Any: + if isinstance(value, dict): + if key in value: + return value[key] + for child in value.values(): + found = _nested_get(child, key) + if found is not None: + return found + elif isinstance(value, list | tuple): + for child in value: + found = _nested_get(child, key) + if found is not None: + return found + return None + + +def _od_config_get(od_config: Any, key: str, default: Any = None) -> Any: + """Read Cosmos3 options from runtime, model, or transformer config.""" + if od_config is None: + return default + for attr in ("custom_pipeline_args", "model_config"): + source = getattr(od_config, attr, None) or {} + if isinstance(source, dict): + if key in source: + return source[key] + found = _nested_get(source, key) + if found is not None: + return found + tf_model_config = getattr(od_config, "tf_model_config", None) + if isinstance(tf_model_config, dict): + if key in tf_model_config: + return tf_model_config[key] + found = _nested_get(tf_model_config, key) + if found is not None: + return found + value = _tf_config_get(tf_model_config, key, None) + return default if value is None else value + + +def _as_bool(value: Any) -> bool: + if isinstance(value, str): + return value.strip().lower() in {"1", "true", "yes", "on"} + return bool(value) + + +class DomainAwareLinear(nn.Module): + """Linear projection with one weight/bias pair per action embodiment domain.""" + + def __init__( + self, + input_size: int, + output_size: int, + num_domains: int, + *, + dtype: torch.dtype = torch.bfloat16, + ) -> None: + super().__init__() + self.input_size = int(input_size) + self.output_size = int(output_size) + self.num_domains = int(num_domains) + self.fc = nn.Embedding(self.num_domains, self.output_size * self.input_size, dtype=dtype) + self.bias = nn.Embedding(self.num_domains, self.output_size, dtype=dtype) + nn.init.xavier_uniform_(self.fc.weight) + nn.init.zeros_(self.bias.weight) + + def forward(self, x: torch.Tensor, domain_id: torch.Tensor) -> torch.Tensor: + if domain_id.ndim == 0: + domain_id = domain_id.unsqueeze(0) + domain_id = domain_id.to(device=x.device, dtype=torch.long).reshape(-1) + if x.shape[0] != domain_id.shape[0]: + raise ValueError( + "Cosmos3 action domain_id batch size must match action tokens: " + f"tokens={x.shape[0]}, domain_id={domain_id.shape[0]}." + ) + if torch.any((domain_id < 0) | (domain_id >= self.num_domains)): + raise ValueError(f"Cosmos3 action domain_id must be in [0, {self.num_domains}), got {domain_id.tolist()}.") + + weight = self.fc(domain_id).view(domain_id.shape[0], self.input_size, self.output_size) + bias = self.bias(domain_id).view(domain_id.shape[0], self.output_size) + if x.ndim == 2: + return torch.bmm(x.unsqueeze(1), weight).squeeze(1) + bias + if x.ndim == 3: + return torch.bmm(x, weight) + bias.unsqueeze(1) + raise ValueError(f"Cosmos3 DomainAwareLinear expected rank-2 or rank-3 input, got {tuple(x.shape)}.") + + +# --------------------------------------------------------------------------- +# Rotary Position Embeddings (mRoPE) +# --------------------------------------------------------------------------- +def compute_mrope_position_ids_text( + num_tokens: int, + temporal_offset: int, +) -> tuple[torch.Tensor, int]: + """Generate 3D mRoPE position IDs for text tokens. + + Text tokens: all three axes (t, h, w) share the same monotonically + increasing position IDs. + """ + ids = torch.arange(num_tokens, dtype=torch.long) + temporal_offset + mrope_ids = ids.unsqueeze(0).expand(3, -1).contiguous() + return mrope_ids, temporal_offset + num_tokens + + +def compute_mrope_position_ids_vision( + grid_t: int, + grid_h: int, + grid_w: int, + temporal_offset: int | float, + fps: float | None = None, + base_fps: float = 24.0, + temporal_compression_factor: int = 4, + base_temporal_compression_factor: int | None = None, + enable_fps_modulation: bool = True, +) -> tuple[torch.Tensor, int | float]: + """Generate 3D mRoPE position IDs for vision tokens. + + Creates a (t, h, w) position grid with spatial indices reset per segment + (Qwen3VL-style). Flattened in t-major order. + """ + fps_modulation = enable_fps_modulation and fps is not None + + if fps_modulation: + tps = fps / temporal_compression_factor + effective_base_tcf = ( + base_temporal_compression_factor + if base_temporal_compression_factor is not None + else temporal_compression_factor + ) + base_tps = base_fps / effective_base_tcf + frame_indices = torch.arange(grid_t, dtype=torch.float32) + t_index = (frame_indices / tps * base_tps + temporal_offset).view(-1, 1).expand(-1, grid_h * grid_w).flatten() + else: + t_index = torch.arange(grid_t, dtype=torch.long).view(-1, 1).expand(-1, grid_h * grid_w).flatten() + int( + temporal_offset + ) + + h_index = torch.arange(grid_h, dtype=torch.long).view(1, -1, 1).expand(grid_t, -1, grid_w).flatten() + w_index = torch.arange(grid_w, dtype=torch.long).view(1, 1, -1).expand(grid_t, grid_h, -1).flatten() + + if fps_modulation: + mrope_ids = torch.stack([t_index, h_index.to(torch.float32), w_index.to(torch.float32)], dim=0) + else: + mrope_ids = torch.stack([t_index, h_index, w_index], dim=0) + + next_offset = math.floor(mrope_ids.max().item()) + 1 + return mrope_ids, next_offset + + +def compute_mrope_position_ids_sound( + grid_t: int, + temporal_offset: int | float, + sound_latent_fps: float, + base_fps: float = 24.0, + temporal_compression_factor_sound: int = 1, + enable_fps_modulation: bool = True, + base_temporal_compression_factor: int | None = None, +) -> tuple[torch.Tensor, int | float]: + """Generate mRoPE IDs for sound tokens as a (T, 1, 1) grid.""" + del base_temporal_compression_factor + return compute_mrope_position_ids_vision( + grid_t=grid_t, + grid_h=1, + grid_w=1, + temporal_offset=temporal_offset, + fps=sound_latent_fps, + base_fps=base_fps, + temporal_compression_factor=temporal_compression_factor_sound, + base_temporal_compression_factor=temporal_compression_factor_sound, + enable_fps_modulation=enable_fps_modulation, + ) + + +def compute_mrope_position_ids_action( + grid_t: int, + temporal_offset: int | float, + action_fps: float | None, + base_fps: float = 24.0, + base_temporal_compression_factor: int = 4, + enable_fps_modulation: bool = True, + start_frame_offset: int = 1, +) -> tuple[torch.Tensor, int | float]: + """Generate mRoPE IDs for action tokens as a frame-rate (T, 1, 1) grid.""" + return compute_mrope_position_ids_vision( + grid_t=grid_t, + grid_h=1, + grid_w=1, + temporal_offset=temporal_offset, + fps=action_fps, + base_fps=base_fps, + temporal_compression_factor=1, + base_temporal_compression_factor=base_temporal_compression_factor, + enable_fps_modulation=enable_fps_modulation, + start_frame_offset=start_frame_offset, + ) + + +class Qwen3VLTextRotaryEmbedding(nn.Module): + """Multi-dimensional rotary position embedding for Qwen3-VL.""" + + def __init__( + self, + *, + head_dim: int, + rope_theta: float, + mrope_section: list[int], + ) -> None: + super().__init__() + self.head_dim = head_dim + self.rope_theta = rope_theta + self.mrope_section = mrope_section + inv_freq = 1.0 / ( + rope_theta ** (torch.arange(0, head_dim, 2, dtype=torch.int64).to(dtype=torch.float) / head_dim) + ) + self.attention_scaling = 1.0 + self.register_buffer("inv_freq", inv_freq, persistent=False) + + def apply_interleaved_mrope(self, freqs: torch.Tensor, mrope_section: list[int]) -> torch.Tensor: + """Reorganize from chunked [TTT...HHH...WWW] to interleaved [THTHW...].""" + freqs_t = freqs[0] + for dim, offset in enumerate((1, 2), start=1): + length = mrope_section[dim] * 3 + idx = slice(offset, length, 3) + freqs_t[..., idx] = freqs[dim, ..., idx] + return freqs_t + + @torch.no_grad() + def forward(self, x: torch.Tensor, position_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + if position_ids.ndim == 2: + position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) + inv_freq_expanded = ( + self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1).to(x.device) + ) + position_ids_expanded = position_ids[:, :, None, :].float() + + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3) + freqs = self.apply_interleaved_mrope(freqs, self.mrope_section) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +# --------------------------------------------------------------------------- +# RoPE application (Qwen3/Llama style) +# --------------------------------------------------------------------------- +def _rotate_half(x: torch.Tensor) -> torch.Tensor: + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def _apply_rotary_pos_emb( + q: torch.Tensor, + k: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + """Qwen3-style RoPE: (x * cos) + (rotate_half(x) * sin). + + Args: + q: [B, S, h, D] + k: [B, S, H_kv, D] + cos: [1, S, 1, D] or broadcastable + sin: [1, S, 1, D] or broadcastable + """ + q_embed = (q * cos) + (_rotate_half(q) * sin) + k_embed = (k * cos) + (_rotate_half(k) * sin) + return q_embed, k_embed + + +# --------------------------------------------------------------------------- +# Timestep Embedder +# --------------------------------------------------------------------------- +class TimestepEmbedder(nn.Module): + """Embeds scalar timesteps into vector representations.""" + + def __init__( + self, + hidden_size: int, + frequency_embedding_size: int = 256, + max_period: int = 10000, + target_dtype: torch.dtype = torch.bfloat16, + ) -> None: + super().__init__() + # Following diffusers naming pattern here for checkpoint compatibility. + self.linear_1 = nn.Linear(frequency_embedding_size, hidden_size, bias=True) + self.act = nn.SiLU() + self.linear_2 = nn.Linear(hidden_size, hidden_size, bias=True) + self.frequency_embedding_size = frequency_embedding_size + self.hidden_size = hidden_size + + half = frequency_embedding_size // 2 + freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=target_dtype) / half) + self.register_buffer("freqs", freqs, persistent=False) + + def forward(self, t: torch.Tensor) -> torch.Tensor: + args = t[:, None] * self.freqs[None] + t_freq = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + return self.linear_2(self.act(self.linear_1(t_freq))) + + +# --------------------------------------------------------------------------- +# GatedMLP (replaces TRT-LLM GatedMLP) +# --------------------------------------------------------------------------- +class Cosmos3GatedMLP(nn.Module): + def __init__( + self, + hidden_size: int = 4096, + intermediate_size: int = 12288, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + ) -> None: + super().__init__() + self.gate_proj = ColumnParallelLinear( + hidden_size, + intermediate_size, + bias=False, + gather_output=False, + return_bias=False, + quant_config=quant_config, + prefix=f"{prefix}.gate_proj", + ) + self.up_proj = ColumnParallelLinear( + hidden_size, + intermediate_size, + bias=False, + gather_output=False, + return_bias=False, + quant_config=quant_config, + prefix=f"{prefix}.up_proj", + ) + self.down_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=False, + input_is_parallel=True, + return_bias=False, + quant_config=quant_config, + prefix=f"{prefix}.down_proj", + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x)) + + +# --------------------------------------------------------------------------- +# Attention Modules +# --------------------------------------------------------------------------- +class Cosmos3CausalAttention(nn.Module): + """Understanding pathway: causal self-attention on text tokens. + + Returns (output, K, V) where K/V are post-norm, post-RoPE for the + generation pathway's cross-attention. + """ + + def __init__( + self, + *, + hidden_size: int, + num_attention_heads: int, + num_key_value_heads: int, + head_dim: int, + rms_norm_eps: float, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = hidden_size + self.num_heads = num_attention_heads + self.num_kv_heads = num_key_value_heads + self.head_dim = head_dim + + tp_size = get_tensor_model_parallel_world_size() + self.num_heads_local = self.num_heads // tp_size + self.num_kv_heads_local = self.num_kv_heads // tp_size + + self.to_q = ColumnParallelLinear( + hidden_size, + self.num_heads * self.head_dim, + bias=False, + gather_output=False, + return_bias=False, + quant_config=quant_config, + prefix=f"{prefix}.to_q", + ) + self.to_k = ColumnParallelLinear( + hidden_size, + self.num_kv_heads * self.head_dim, + bias=False, + gather_output=False, + return_bias=False, + quant_config=quant_config, + prefix=f"{prefix}.to_k", + ) + self.to_v = ColumnParallelLinear( + hidden_size, + self.num_kv_heads * self.head_dim, + bias=False, + gather_output=False, + return_bias=False, + quant_config=quant_config, + prefix=f"{prefix}.to_v", + ) + self.to_out = RowParallelLinear( + self.num_heads * self.head_dim, + hidden_size, + bias=False, + input_is_parallel=True, + return_bias=False, + quant_config=quant_config, + prefix=f"{prefix}.to_out", + ) + + self.norm_q = RMSNorm(self.head_dim, eps=rms_norm_eps) + self.norm_k = RMSNorm(self.head_dim, eps=rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + freqs_cos: torch.Tensor, + freqs_sin: torch.Tensor, + text_mask: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + B, S, _ = hidden_states.shape + + q = self.to_q(hidden_states).view(B, S, self.num_heads_local, self.head_dim) + k = self.to_k(hidden_states).view(B, S, self.num_kv_heads_local, self.head_dim) + v = self.to_v(hidden_states).view(B, S, self.num_kv_heads_local, self.head_dim) + + # Per-head QK norm + q = F.rms_norm(q, (q.shape[-1],), self.norm_q.weight, self.norm_q.variance_epsilon) + k = F.rms_norm(k, (k.shape[-1],), self.norm_k.weight, self.norm_k.variance_epsilon) + + # Qwen3-style RoPE + q, k = _apply_rotary_pos_emb(q, k, freqs_cos, freqs_sin) + + # Transpose to [B, h, S, D] for SDPA + q_t = q.transpose(1, 2) + k_t = k.transpose(1, 2) + v_t = v.transpose(1, 2) + + if text_mask is not None: + causal = torch.tril(torch.ones(S, S, device=hidden_states.device, dtype=torch.bool)) + padding = text_mask[:, None, None, :].bool() # [B, 1, 1, S] + combined = causal[None, None] & padding # [B, 1, S, S] + out = F.scaled_dot_product_attention(q_t, k_t, v_t, attn_mask=combined, enable_gqa=True) + else: + out = F.scaled_dot_product_attention(q_t, k_t, v_t, is_causal=True, enable_gqa=True) + + out = out.transpose(1, 2).contiguous().view(B, S, -1) + return self.to_out(out), k, v + + +class Cosmos3CrossAttention(nn.Module): + """Generation pathway: full attention where visual Q attends to all K/V. + + Dual-path implementation: + + * **Non-SP path** (single GPU): the framework ``Attention`` layer with + explicit ``cat([k_und, k_gen])`` concatenation. Text conditioning is + always present because K/V are physically concatenated. + + * **SP path** (Ulysses active): the framework ``Attention`` layer with + ``joint_key/joint_value`` in ``AttentionMetadata``. The Ulysses + wrapper head-slices the replicated UND K/V and performs all-to-all + on the sharded GEN Q/K/V so every query sees the full context. + """ + + def __init__( + self, + *, + hidden_size: int, + num_attention_heads: int, + num_key_value_heads: int, + head_dim: int, + rms_norm_eps: float, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = hidden_size + self.num_heads = num_attention_heads + self.num_kv_heads = num_key_value_heads + self.head_dim = head_dim + + tp_size = get_tensor_model_parallel_world_size() + self.num_heads_local = self.num_heads // tp_size + self.num_kv_heads_local = self.num_kv_heads // tp_size + + self.to_q = ColumnParallelLinear( + hidden_size, + self.num_heads * self.head_dim, + bias=False, + gather_output=False, + return_bias=False, + quant_config=quant_config, + prefix=f"{prefix}.to_q", + ) + self.to_k = ColumnParallelLinear( + hidden_size, + self.num_kv_heads * self.head_dim, + bias=False, + gather_output=False, + return_bias=False, + quant_config=quant_config, + prefix=f"{prefix}.to_k", + ) + self.to_v = ColumnParallelLinear( + hidden_size, + self.num_kv_heads * self.head_dim, + bias=False, + gather_output=False, + return_bias=False, + quant_config=quant_config, + prefix=f"{prefix}.to_v", + ) + self.to_out = RowParallelLinear( + self.num_heads * self.head_dim, + hidden_size, + bias=False, + input_is_parallel=True, + return_bias=False, + quant_config=quant_config, + prefix=f"{prefix}.to_out", + ) + + self.norm_q = RMSNorm(self.head_dim, eps=rms_norm_eps) + self.norm_k = RMSNorm(self.head_dim, eps=rms_norm_eps) + + self.local_attn = FrameworkAttention( + num_heads=self.num_heads_local, + head_size=self.head_dim, + causal=False, + softmax_scale=1.0 / (self.head_dim**0.5), + num_kv_heads=self.num_kv_heads_local, + skip_sequence_parallel=True, + ) + + # Lazy-created on first SP forward so it captures the active SP context. + self._sp_attn: nn.Module | None = None + + def _get_sp_attn(self) -> nn.Module: + if self._sp_attn is None: + self._sp_attn = FrameworkAttention( + num_heads=self.num_heads, + head_size=self.head_dim, + causal=False, + softmax_scale=1.0 / (self.head_dim**0.5), + num_kv_heads=self.num_kv_heads, + ) + return self._sp_attn + + # -- Non-SP path: explicit K/V concatenation + framework Attention -------- + + def _forward_local( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + k_und: torch.Tensor, + v_und: torch.Tensor, + ) -> torch.Tensor: + B, S_gen = q.shape[:2] + k_all = torch.cat([k_und, k], dim=1) + v_all = torch.cat([v_und, v], dim=1) + + out = self.local_attn(q, k_all, v_all) + return out.reshape(B, S_gen, -1) + + # -- SP path: framework Attention with joint_key/value ------------------- + + def _forward_sp( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + k_und: torch.Tensor, + v_und: torch.Tensor, + ) -> torch.Tensor: + B, S_gen = q.shape[:2] + + # Zero-length joint_query satisfies the Ulysses contract + # (joint_query, joint_key, joint_value must all be present) without + # adding text tokens to Q. joint_len=0 keeps post_attention on the + # standard reverse-all-to-all path (no joint-output splitting). + joint_q = q.new_empty(B, 0, self.num_heads_local, self.head_dim) + + attn_metadata = AttentionMetadata( + joint_query=joint_q, + joint_key=k_und, + joint_value=v_und, + joint_strategy="front", + ) + out = self._get_sp_attn()(q, k, v, attn_metadata) + return out.reshape(B, S_gen, -1) + + # -- Public forward: routes to the appropriate path ---------------------- + + def forward( + self, + hidden_states: torch.Tensor, + k_und: torch.Tensor, + v_und: torch.Tensor, + freqs_cos: torch.Tensor, + freqs_sin: torch.Tensor, + ) -> torch.Tensor: + """ + Args: + hidden_states: [B, S_gen_local, hidden_size] (may be sequence-sharded) + k_und: [B, S_und, H_kv_local, D] pre-computed UND keys (TP-sharded, post-norm/RoPE) + v_und: [B, S_und, H_kv_local, D] pre-computed UND values (TP-sharded) + freqs_cos: [B, S_gen_local, 1, D] + freqs_sin: [B, S_gen_local, 1, D] + """ + B, S_gen, _ = hidden_states.shape + + q = self.to_q(hidden_states).view(B, S_gen, self.num_heads_local, self.head_dim) + k = self.to_k(hidden_states).view(B, S_gen, self.num_kv_heads_local, self.head_dim) + v = self.to_v(hidden_states).view(B, S_gen, self.num_kv_heads_local, self.head_dim) + + # Per-head QK norm + q = F.rms_norm(q, (q.shape[-1],), self.norm_q.weight, self.norm_q.variance_epsilon) + k = F.rms_norm(k, (k.shape[-1],), self.norm_k.weight, self.norm_k.variance_epsilon) + + # Qwen3-style RoPE + q, k = _apply_rotary_pos_emb(q, k, freqs_cos, freqs_sin) + + if _is_sp_active(): + out = self._forward_sp(q, k, v, k_und, v_und) + else: + out = self._forward_local(q, k, v, k_und, v_und) + + return self.to_out(out) + + +# --------------------------------------------------------------------------- +# Decoder Layers +# --------------------------------------------------------------------------- +class Cosmos3UndDecoderLayer(nn.Module): + """Understanding pathway decoder layer: causal self-attention + MLP.""" + + def __init__( + self, + *, + hidden_size: int, + intermediate_size: int, + num_attention_heads: int, + num_key_value_heads: int, + head_dim: int, + rms_norm_eps: float, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + ) -> None: + super().__init__() + self.self_attn = Cosmos3CausalAttention( + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + num_key_value_heads=num_key_value_heads, + head_dim=head_dim, + rms_norm_eps=rms_norm_eps, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) + self.input_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps) + self.post_attention_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps) + self.mlp = Cosmos3GatedMLP( + hidden_size=hidden_size, + intermediate_size=intermediate_size, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + + def forward( + self, + hidden_states: torch.Tensor, + freqs: tuple[torch.Tensor, torch.Tensor], + text_mask: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Returns (hidden_states, K, V) where K/V are for GEN cross-attention.""" + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + cos, sin = freqs + attn_out, k, v = self.self_attn(hidden_states, cos, sin, text_mask) + hidden_states = residual + attn_out + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = residual + self.mlp(hidden_states) + + return hidden_states, k, v + + +class Cosmos3GenDecoderLayer(nn.Module): + """Generation pathway decoder layer: cross-attention (to UND K/V) + MLP.""" + + def __init__( + self, + *, + layer_idx: int | None = None, + hidden_size: int, + intermediate_size: int, + num_attention_heads: int, + num_key_value_heads: int, + head_dim: int, + rms_norm_eps: float, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + ) -> None: + super().__init__() + self.layer_idx = layer_idx + self.cross_attention = Cosmos3CrossAttention( + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + num_key_value_heads=num_key_value_heads, + head_dim=head_dim, + rms_norm_eps=rms_norm_eps, + quant_config=quant_config, + prefix=f"{prefix}.cross_attention", + ) + self.input_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps) + self.post_attention_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps) + self.mlp = Cosmos3GatedMLP( + hidden_size=hidden_size, + intermediate_size=intermediate_size, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + + def forward( + self, + hidden_states: torch.Tensor, + *, + k_und: torch.Tensor | None = None, + v_und: torch.Tensor | None = None, + freqs_cos: torch.Tensor | None = None, + freqs_sin: torch.Tensor | None = None, + cached_kv: list[tuple[torch.Tensor, torch.Tensor]] | None = None, + freqs_gen: tuple[torch.Tensor, torch.Tensor] | None = None, + ) -> torch.Tensor: + if cached_kv is not None: + if self.layer_idx is None: + raise ValueError("Cosmos3 GEN layer requires layer_idx when cached_kv is provided.") + k_und, v_und = cached_kv[self.layer_idx] + if freqs_gen is not None: + freqs_cos, freqs_sin = freqs_gen + if k_und is None or v_und is None or freqs_cos is None or freqs_sin is None: + raise ValueError("Cosmos3 GEN layer requires k_und/v_und/freqs_cos/freqs_sin.") + + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + hidden_states = self.cross_attention( + hidden_states, k_und=k_und, v_und=v_und, freqs_cos=freqs_cos, freqs_sin=freqs_sin + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = residual + self.mlp(hidden_states) + + return hidden_states + + +# --------------------------------------------------------------------------- +# Language Model (Understanding pathway) +# --------------------------------------------------------------------------- +class Cosmos3LanguageModel(nn.Module): + """Understanding pathway: a standard causal LM that processes text tokens. + + Returns per-layer K/V tensors for the generation pathway's cross-attention. + The UND pathway is independent of the denoising step, so its K/V can be + computed once and reused across all sampling steps. + """ + + def __init__( + self, + *, + hidden_size: int, + intermediate_size: int, + num_hidden_layers: int, + num_attention_heads: int, + num_key_value_heads: int, + head_dim: int, + vocab_size: int, + rms_norm_eps: float, + rope_theta: float, + mrope_section: list[int], + quant_config: QuantizationConfig | None = None, + prefix: str = "", + ) -> None: + super().__init__() + self.embed_tokens = nn.Embedding(vocab_size, hidden_size) + self.rotary_emb = Qwen3VLTextRotaryEmbedding( + head_dim=head_dim, + rope_theta=rope_theta, + mrope_section=mrope_section, + ) + self.layers = nn.ModuleList( + [ + Cosmos3UndDecoderLayer( + hidden_size=hidden_size, + intermediate_size=intermediate_size, + num_attention_heads=num_attention_heads, + num_key_value_heads=num_key_value_heads, + head_dim=head_dim, + rms_norm_eps=rms_norm_eps, + quant_config=quant_config, + prefix=f"{prefix}.layers.{i}", + ) + for i in range(num_hidden_layers) + ] + ) + # TODO: Not used right now, will be used in the future for prompt upsampler. + self.norm = RMSNorm(hidden_size, eps=rms_norm_eps) + + def forward( + self, + text_ids: torch.Tensor, + freqs: tuple[torch.Tensor, torch.Tensor], + ) -> list[tuple[torch.Tensor, torch.Tensor]]: + """ + Args: + text_ids: [B, S] token IDs + freqs: (cos, sin) each [B, S, 1, D] + + Returns: + List of (K, V) per layer, each [B, S, H_kv, D]. + + No padding mask is applied: with right-padding + causal self-attention, + real query positions only attend to real keys, and the caller trims pad + K/V via ``max_real_len`` before the GEN cross-attention sees them. + """ + hidden = self.embed_tokens(text_ids) + + cached_kv: list[tuple[torch.Tensor, torch.Tensor]] = [] + for layer in self.layers: + hidden, k, v = layer(hidden, freqs) + cached_kv.append((k, v)) + + return cached_kv + + +# --------------------------------------------------------------------------- +# Main Transformer +# --------------------------------------------------------------------------- +class Cosmos3GenSPPrepare(nn.Module): + """Module boundary used by _sp_plan to shard GEN states and RoPE together.""" + + def forward( + self, + hidden_gen: torch.Tensor, + freqs_cos: torch.Tensor, + freqs_sin: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + return hidden_gen, freqs_cos, freqs_sin + + +class Cosmos3VFMTransformer(nn.Module): + """Cosmos3 VFM Transformer: UND language model + GEN denoising layers. + + The UND pathway runs once per generation (K/V cached). The GEN pathway + runs at each denoising step. + + Layerwise offloading uses ``gen_layers`` as the block container. + + Sequence parallelism uses ``_sp_plan`` to shard/gather the GEN pathway at + module boundaries. ``Cosmos3CrossAttention`` checks + ``forward_context.sp_active`` at runtime and routes to the framework + ``Attention`` layer (with Ulysses all-to-all) or plain SDPA accordingly. + """ + + _repeated_blocks = ["Cosmos3GenDecoderLayer"] + + _layerwise_offload_blocks_attrs = ["gen_layers"] + + packed_modules_mapping = {} + + @staticmethod + def _is_transformer_block(name: str, module) -> bool: + return ("gen_layers" in name or "language_model.layers" in name) and name.split(".")[-1].isdigit() + + _hsdp_shard_conditions = [_is_transformer_block] + + _sp_plan = { + "gen_sp_prepare": { + 0: SequenceParallelInput(split_dim=1, expected_dims=3, split_output=True), + 1: SequenceParallelInput(split_dim=1, expected_dims=4, split_output=True), + 2: SequenceParallelInput(split_dim=1, expected_dims=4, split_output=True), + }, + "gen_sp_gather": SequenceParallelOutput(gather_dim=1, expected_dims=3), + } + + @staticmethod + def _validate_supported_config(model_config: Any) -> None: + """Fail loudly when a checkpoint requests an unsupported architecture.""" + expected_values = { + "qk_norm_for_diffusion": True, + "qk_norm_for_text": True, + "position_embedding_type": "unified_3d_mrope", + "unified_3d_mrope_reset_spatial_ids": True, + "joint_attn_implementation": "two_way", + } + for key, expected in expected_values.items(): + actual = _tf_config_get(model_config, key, expected) + if actual != expected: + raise ValueError(f"Unsupported Cosmos3 transformer config: {key}={actual!r}; expected {expected!r}.") + + def __init__( + self, + od_config: OmniDiffusionConfig, + *, + temporal_compression_factor: int | None = None, + ) -> None: + super().__init__() + model_config = od_config.tf_model_config + self._validate_supported_config(model_config) + rope_scaling = _tf_config_get(model_config, "rope_scaling", {}) or {} + + self.hidden_size = int(_tf_config_get(model_config, "hidden_size", 4096)) + self.num_hidden_layers = int(_tf_config_get(model_config, "num_hidden_layers", 36)) + self.num_attention_heads = int(_tf_config_get(model_config, "num_attention_heads", 32)) + self.num_key_value_heads = int(_tf_config_get(model_config, "num_key_value_heads", 8)) + self.head_dim = int(_tf_config_get(model_config, "head_dim", 128)) + self.intermediate_size = int(_tf_config_get(model_config, "intermediate_size", 12288)) + self.vocab_size = int(_tf_config_get(model_config, "vocab_size", 151936)) + self.rms_norm_eps = float(_tf_config_get(model_config, "rms_norm_eps", 1e-6)) + self.rope_theta = float(_tf_config_get(model_config, "rope_theta", 5_000_000)) + self.mrope_section = list(rope_scaling.get("mrope_section", [24, 20, 20])) + self.latent_patch_size = int(_tf_config_get(model_config, "latent_patch_size", 2)) + self.latent_channel_size = int(_tf_config_get(model_config, "latent_channel", 48)) + self.timestep_scale = float(_tf_config_get(model_config, "timestep_scale", 0.001)) + self.base_fps = float(_tf_config_get(model_config, "base_fps", 24.0)) + sound_gen_value = _od_config_get(od_config, "sound_gen", None) + sound_dim_value = _od_config_get(od_config, "sound_dim", None) + if sound_dim_value is None: + sound_dim_value = _od_config_get(od_config, "io_channels", None) + if sound_dim_value is None: + sound_dim_value = _od_config_get(od_config, "vocoder_input_dim", None) + if sound_dim_value is None: + sound_dim_value = _od_config_get(od_config, "latent_ch", None) + self.sound_gen = _as_bool(sound_gen_value) if sound_gen_value is not None else sound_dim_value is not None + from .sound_tokenizer import get_sound_dim, get_sound_latent_fps + + self.sound_dim = int(sound_dim_value if sound_dim_value is not None else get_sound_dim(od_config)) + action_gen_value = _od_config_get(od_config, "action_gen", None) + action_dim_value = _od_config_get(od_config, "action_dim", None) + if action_dim_value is None: + action_dim_value = _od_config_get(od_config, "max_action_dim", None) + self.action_gen = _as_bool(action_gen_value) if action_gen_value is not None else False + self.action_dim = int(action_dim_value if action_dim_value is not None else 64) + self.num_embodiment_domains = int(_od_config_get(od_config, "num_embodiment_domains", 32)) + self.sound_latent_fps = float(get_sound_latent_fps(od_config)) + if temporal_compression_factor is None: + temporal_compression_factor = _tf_config_get(model_config, "temporal_compression_factor", 4) + self.temporal_compression_factor = int(temporal_compression_factor) + self.temporal_compression_factor_sound = int( + _tf_config_get(model_config, "temporal_compression_factor_sound", 1) + ) + self.enable_fps_modulation = bool(_tf_config_get(model_config, "enable_fps_modulation", True)) + self.temporal_modality_margin = int( + _tf_config_get( + model_config, + "unified_3d_mrope_temporal_modality_margin", + 15000, + ) + ) + self.patch_latent_dim = (self.latent_patch_size**2) * self.latent_channel_size + + dtype = od_config.dtype + quant_config = getattr(od_config, "quantization_config", None) if od_config else None + + self.language_model = Cosmos3LanguageModel( + hidden_size=self.hidden_size, + intermediate_size=self.intermediate_size, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + num_key_value_heads=self.num_key_value_heads, + head_dim=self.head_dim, + vocab_size=self.vocab_size, + rms_norm_eps=self.rms_norm_eps, + rope_theta=self.rope_theta, + mrope_section=self.mrope_section, + quant_config=quant_config, + prefix="language_model", + ) + + # Video projection layers are small; not worth quantizing. + self.proj_in = nn.Linear(self.patch_latent_dim, self.hidden_size) + self.proj_out = nn.Linear(self.hidden_size, self.patch_latent_dim) + self.time_embedder = TimestepEmbedder(self.hidden_size, target_dtype=dtype) + if self.action_gen: + self.action_proj_in = DomainAwareLinear( + self.action_dim, + self.hidden_size, + self.num_embodiment_domains, + dtype=dtype, + ) + self.action_proj_out = DomainAwareLinear( + self.hidden_size, + self.action_dim, + self.num_embodiment_domains, + dtype=dtype, + ) + self.action_modality_embed = nn.Parameter(torch.zeros(self.hidden_size, dtype=dtype)) + if self.sound_gen: + self.audio_proj_in = nn.Linear(self.sound_dim, self.hidden_size) + self.audio_proj_out = nn.Linear(self.hidden_size, self.sound_dim) + self.audio_modality_embed = nn.Parameter(torch.zeros(self.hidden_size)) + + self.time_embedder = TimestepEmbedder(self.hidden_size, target_dtype=torch.bfloat16) + + self.gen_layers = nn.ModuleList( + [ + Cosmos3GenDecoderLayer( + layer_idx=i, + hidden_size=self.hidden_size, + intermediate_size=self.intermediate_size, + num_attention_heads=self.num_attention_heads, + num_key_value_heads=self.num_key_value_heads, + head_dim=self.head_dim, + rms_norm_eps=self.rms_norm_eps, + quant_config=quant_config, + prefix=f"gen_layers.{i}", + ) + for i in range(self.num_hidden_layers) + ] + ) + + self.norm_moe_gen = RMSNorm(self.hidden_size, eps=self.rms_norm_eps) + self.gen_sp_prepare = Cosmos3GenSPPrepare() + self.gen_sp_gather = nn.Identity() + + # SDPA backend selection for torch.nn.attention.sdpa_kernel context. + # Default: allow all backends; override to restrict (e.g. FlashAttention only). + self.sdpa_backends = [ + torch.nn.attention.SDPBackend.CUDNN_ATTENTION, + torch.nn.attention.SDPBackend.FLASH_ATTENTION, + torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION, + torch.nn.attention.SDPBackend.MATH, + ] + + # Cached state (populated on first forward, reused across denoising steps) + self.cached_kv: list[tuple[torch.Tensor, torch.Tensor]] | None = None + self.cached_freqs_gen: tuple[torch.Tensor, torch.Tensor] | None = None + + @property + def device(self) -> torch.device: + return next(self.parameters()).device + + # -- Patchify / Unpatchify ----------------------------------------------- + + def _pad_to_patch_size(self, h: int, w: int) -> tuple[int, int, int, int]: + """Returns (hp, wp, H_padded, W_padded).""" + p = self.latent_patch_size + H_padded = ((h + p - 1) // p) * p + W_padded = ((w + p - 1) // p) * p + return H_padded // p, W_padded // p, H_padded, W_padded + + def patchify(self, latents: torch.Tensor, t: int, h: int, w: int) -> torch.Tensor: + """[B, C, t, h, w] -> [B, t*hp*wp, p*p*C], padding h/w if needed.""" + B = latents.shape[0] + p = self.latent_patch_size + C = self.latent_channel_size + hp, wp, H_padded, W_padded = self._pad_to_patch_size(h, w) + + if H_padded != h or W_padded != w: + latents = F.pad(latents, (0, W_padded - w, 0, H_padded - h)) + + x = latents.reshape(B, C, t, hp, p, wp, p) + x = x.permute(0, 2, 3, 5, 4, 6, 1) # [B, t, hp, wp, p, p, C] + return x.reshape(B, t * hp * wp, p * p * C) + + def unpatchify(self, tokens: torch.Tensor, t: int, h: int, w: int) -> torch.Tensor: + """[B, t*hp*wp, p*p*C] -> [B, C, t, h, w], cropping padding if needed.""" + B = tokens.shape[0] + p = self.latent_patch_size + C = self.latent_channel_size + hp, wp, H_padded, W_padded = self._pad_to_patch_size(h, w) + + x = tokens.reshape(B, t, hp, wp, p, p, C) + x = x.permute(0, 6, 1, 2, 4, 3, 5) # [B, C, t, hp, p, wp, p] + x = x.reshape(B, C, t, H_padded, W_padded) + + if H_padded != h or W_padded != w: + x = x[:, :, :, :h, :w] + return x + + def pack_sound(self, sound_latents: torch.Tensor) -> torch.Tensor: + """[B, C_sound, T_sound] -> [B, T_sound, C_sound].""" + if sound_latents.ndim != 3: + raise ValueError(f"Cosmos3 sound latents must have shape [B, C, T], got {tuple(sound_latents.shape)}.") + if sound_latents.shape[1] != self.sound_dim: + raise ValueError( + f"Cosmos3 sound latent channel mismatch: expected {self.sound_dim}, got {sound_latents.shape[1]}." + ) + return sound_latents.permute(0, 2, 1).contiguous() + + @staticmethod + def unpack_sound(tokens: torch.Tensor) -> torch.Tensor: + """[B, T_sound, C_sound] -> [B, C_sound, T_sound].""" + return tokens.permute(0, 2, 1).contiguous() + + def pack_action(self, action_latents: torch.Tensor) -> torch.Tensor: + """Validate and return action latents as [B, T_action, D_action] tokens.""" + if action_latents.ndim != 3: + raise ValueError(f"Cosmos3 action latents must have shape [B, T, D], got {tuple(action_latents.shape)}.") + if action_latents.shape[-1] != self.action_dim: + raise ValueError( + f"Cosmos3 action latent dimension mismatch: expected {self.action_dim}, got {action_latents.shape[-1]}." + ) + return action_latents.contiguous() + + @staticmethod + def unpack_action(tokens: torch.Tensor) -> torch.Tensor: + """Return [B, T_action, D_action] action predictions.""" + return tokens.contiguous() + + # -- RoPE computation ---------------------------------------------------- + + def _compute_rope_freqs( + self, + text_mask: torch.Tensor, + t: int, + hp: int, + wp: int, + fps: float | None, + device: torch.device, + dtype: torch.dtype, + t_action: int | None = None, + action_start_frame_offset: int = 1, + action_fps: float | None = None, + t_sound: int | None = None, + ) -> tuple[tuple[torch.Tensor, torch.Tensor], tuple[torch.Tensor, torch.Tensor]]: + """Compute mRoPE cos/sin for UND text and GEN media pathways.""" + B = text_mask.shape[0] + S_text = text_mask.shape[1] + text_lengths = text_mask.sum(dim=1).long() + effective_fps = fps if fps is not None and t > 1 else None + action_frames = int(t_action or 0) + sound_frames = int(t_sound or 0) + + text_pos_list = [] + gen_pos_list = [] + for b in range(B): + real_len = int(text_lengths[b].item()) + t_pos, t_offset = compute_mrope_position_ids_text(real_len, temporal_offset=0) + media_temporal_offset = t_offset + self.temporal_modality_margin + v_pos, _ = compute_mrope_position_ids_vision( + t, + hp, + wp, + temporal_offset=media_temporal_offset, + fps=effective_fps, + base_fps=self.base_fps, + temporal_compression_factor=self.temporal_compression_factor, + enable_fps_modulation=self.enable_fps_modulation, + ) + gen_positions = [v_pos] + if action_frames > 0: + a_pos, _ = compute_mrope_position_ids_action( + action_frames, + temporal_offset=media_temporal_offset, + action_fps=action_fps if action_fps is not None else fps, + base_fps=self.base_fps, + base_temporal_compression_factor=self.temporal_compression_factor, + enable_fps_modulation=self.enable_fps_modulation, + start_frame_offset=action_start_frame_offset, + ) + gen_positions.append(a_pos) + if sound_frames > 0: + s_pos, _ = compute_mrope_position_ids_sound( + sound_frames, + temporal_offset=media_temporal_offset, + sound_latent_fps=self.sound_latent_fps, + base_fps=self.base_fps, + temporal_compression_factor_sound=getattr(self, "temporal_compression_factor_sound", 1), + enable_fps_modulation=self.enable_fps_modulation, + ) + gen_positions.append(s_pos) + pos_dtype = gen_positions[0].dtype + for pos in gen_positions[1:]: + pos_dtype = torch.promote_types(pos_dtype, pos.dtype) + v_pos = torch.cat([pos.to(pos_dtype) for pos in gen_positions], dim=1) + if real_len < S_text: + t_pos = torch.cat( + [t_pos, torch.zeros(3, S_text - real_len, dtype=t_pos.dtype)], + dim=1, + ) + text_pos_list.append(t_pos) + gen_pos_list.append(v_pos) + + text_pos_ids = torch.stack(text_pos_list, dim=1).to(device) # [3, B, S_text] + gen_pos_ids = torch.stack(gen_pos_list, dim=1).to(device) # [3, B, S_gen] + + rotary_emb = self.language_model.rotary_emb + _dummy = torch.tensor([], dtype=dtype, device=device) + cos_und, sin_und = rotary_emb(_dummy, position_ids=text_pos_ids) + cos_gen, sin_gen = rotary_emb(_dummy, position_ids=gen_pos_ids) + + freqs_und = (cos_und.unsqueeze(2), sin_und.unsqueeze(2)) # (B, S, 1, D) + freqs_gen = (cos_gen.unsqueeze(2), sin_gen.unsqueeze(2)) + return freqs_und, freqs_gen + + # -- Cache management ---------------------------------------------------- + + def reset_cache(self) -> None: + self.cached_kv = None + self.cached_freqs_gen = None + + @staticmethod + def _validate_gen_sequence_parallel( + *, + s_gen: int, + s_video: int, + s_action: int, + s_sound: int, + has_action: bool, + has_sound: bool, + ulysses_size: int, + ) -> None: + if ulysses_size <= 1 or s_gen % ulysses_size == 0: + return + + detail_parts = [f"video tokens {s_video}"] + if has_action: + detail_parts.append(f"action tokens {s_action}") + if has_sound: + detail_parts.append(f"sound tokens {s_sound}") + detail = " = " + " + ".join(detail_parts) if len(detail_parts) > 1 else "" + adjust_detail = ( + "Adjust the spatial resolution, frame count, action chunk size, " + "sound duration, or sound latent FPS so the combined media sequence is a " + "multiple of ulysses_degree." + if has_action or has_sound + else ( + "Adjust the spatial resolution so that " + "t * ceil(h/patch) * ceil(w/patch) is a multiple " + "of ulysses_degree." + ) + ) + raise ValueError( + f"GEN sequence length ({s_gen}{detail}) must be divisible by " + f"ulysses_degree ({ulysses_size}). {adjust_detail}" + ) + + # -- Forward ------------------------------------------------------------- + + def forward( + self, + hidden_states: torch.Tensor, + timestep: torch.Tensor, + text_ids: torch.Tensor, + text_mask: torch.Tensor, + video_shape: tuple[int, int, int], + fps: float | None = None, + action_latents: torch.Tensor | None = None, + action_domain_ids: torch.Tensor | None = None, + action_noisy_mask: torch.Tensor | None = None, + action_start_frame_offset: int = 1, + action_fps: float | None = None, + sound_latents: torch.Tensor | None = None, + noisy_frame_mask: torch.Tensor | None = None, + **kwargs, + ) -> torch.Tensor | tuple[torch.Tensor, ...]: + """ + Args: + hidden_states: [B, C, t, h, w] noisy latents + timestep: [B] diffusion timestep + text_ids: [B, S_text] tokenized text + text_mask: [B, S_text] attention mask (1=real, 0=pad) + video_shape: (t, h, w) in latent space + fps: video frame rate for temporal mRoPE modulation + action_latents: Optional [B, T_action, D_action] noisy action latents. + action_domain_ids: Optional [B] embodiment domain IDs for action projections. + action_noisy_mask: Optional [B, T_action, 1] mask where 1=noisy + action token and 0=clean conditioned token. + sound_latents: Optional [B, C_sound, T_sound] noisy sound latents. + noisy_frame_mask: Optional [B, 1, t, 1, 1] mask where 1=noisy (add + timestep embedding, predict velocity) and 0=conditioned (clean + context, skip timestep embedding). None means all frames noisy + (T2V mode). + + Returns: + [B, C, t, h, w] velocity prediction, or + tuple outputs in video, action, sound order when extra modalities are provided. + """ + t, h, w = video_shape + hp, wp, _, _ = self._pad_to_patch_size(h, w) + text_lengths = text_mask.sum(dim=1) + min_real_len = int(text_lengths.min().item()) + max_real_len = int(text_lengths.max().item()) + if min_real_len != max_real_len: + raise ValueError( + f"Cosmos3 requires identical real text lengths within a batch " + f"(got min={min_real_len}, max={max_real_len})." + ) + has_action = action_latents is not None + has_sound = sound_latents is not None + if has_action and not self.action_gen: + raise ValueError( + "Cosmos3 action generation was requested, but this transformer " + "was initialized without action modules. Check that the " + "transformer config enables action_gen." + ) + if has_sound and not self.sound_gen: + raise ValueError( + "Cosmos3 sound generation was requested, but this transformer " + "was initialized without sound modules. Check that the " + "transformer config enables sound_gen or defines sound_dim." + ) + + # Query Ulysses state at runtime + ulysses_size, _, _ = _get_ulysses_state() + + # Patchify latents and project to hidden space + hidden_video = self.proj_in(self.patchify(hidden_states, t, h, w)) + s_video = hidden_video.shape[1] + s_action = 0 + hidden_action = None + s_sound = 0 + hidden_sound = None + if action_latents is not None: + if action_latents.shape[0] != hidden_states.shape[0]: + raise ValueError( + "Cosmos3 action and video batch sizes must match: " + f"video={hidden_states.shape[0]}, action={action_latents.shape[0]}." + ) + if action_domain_ids is None: + action_domain_ids = torch.zeros(action_latents.shape[0], dtype=torch.long, device=action_latents.device) + hidden_action = self.action_proj_in(self.pack_action(action_latents), action_domain_ids) + hidden_action = hidden_action + self.action_modality_embed.to(hidden_action.dtype) + s_action = hidden_action.shape[1] + if sound_latents is not None: + if sound_latents.shape[0] != hidden_states.shape[0]: + raise ValueError( + "Cosmos3 sound and video batch sizes must match: " + f"video={hidden_states.shape[0]}, sound={sound_latents.shape[0]}." + ) + hidden_sound = self.audio_proj_in(self.pack_sound(sound_latents)) + hidden_sound = hidden_sound + self.audio_modality_embed.to(hidden_sound.dtype) + s_sound = hidden_sound.shape[1] + + # Timestep embedding (fp32 for precision). + # For I2V: only add to noisy tokens, not conditioned ones. + # Conditioned frames are clean context and should not receive + # the diffusion timestep signal. + with torch.autocast("cuda", enabled=True, dtype=torch.float32): + time_embed = self.time_embedder(timestep * self.timestep_scale) + time_embed = time_embed.to(hidden_states.dtype) + + if noisy_frame_mask is not None: + # Build per-token mask from per-frame mask. + # noisy_frame_mask: [B, 1, t, 1, 1] → token mask: [B, t*hp*wp, 1] + token_noisy_mask = ( + noisy_frame_mask[:, 0, :, 0, 0] # [B, t] + .unsqueeze(-1) # [B, t, 1] + .expand(-1, -1, hp * wp) # [B, t, hp*wp] + .reshape(hidden_video.shape[0], -1, 1) # [B, t*hp*wp, 1] + ) + hidden_video = hidden_video + time_embed.unsqueeze(1) * token_noisy_mask + else: + hidden_video = hidden_video + time_embed.unsqueeze(1) + + if hidden_action is not None: + if action_noisy_mask is None: + hidden_action = hidden_action + time_embed.unsqueeze(1) + else: + if action_noisy_mask.shape != (hidden_action.shape[0], hidden_action.shape[1], 1): + raise ValueError( + "Cosmos3 action_noisy_mask must have shape [B, T_action, 1], " + f"got {tuple(action_noisy_mask.shape)}." + ) + hidden_action = hidden_action + time_embed.unsqueeze(1) * action_noisy_mask.to(hidden_action.dtype) + + if hidden_sound is not None: + hidden_sound = hidden_sound + time_embed.unsqueeze(1) + hidden_parts = [hidden_video] + if hidden_action is not None: + hidden_parts.append(hidden_action) + if hidden_sound is not None: + hidden_parts.append(hidden_sound) + hidden_gen = torch.cat(hidden_parts, dim=1) + + with torch.nn.attention.sdpa_kernel(self.sdpa_backends, set_priority=True): + # Run UND pathway once and cache K/V (replicated across all ranks) + if self.cached_kv is None: + freqs_und, freqs_gen = self._compute_rope_freqs( + text_mask, + t, + hp, + wp, + fps, + hidden_states.device, + hidden_states.dtype, + t_action=s_action, + action_start_frame_offset=action_start_frame_offset, + action_fps=action_fps, + t_sound=s_sound, + ) + cached_kv_full = self.language_model(text_ids, freqs_und) + self.cached_freqs_gen = freqs_gen + + # Trim to real text length (remove padding). K/V stay replicated; + # the framework Attention layer head-slices them via joint_key/value. + self.cached_kv = [(k[:, :max_real_len], v[:, :max_real_len]) for k, v in cached_kv_full] + + # Run GEN layers. UND K/V (replicated) is passed to each layer; + # the Cosmos3CrossAttention forwards them as joint_key/value so the + # framework Attention handles the Ulysses head-slicing internally. + if self.cached_kv is None or self.cached_freqs_gen is None: + raise RuntimeError("Cosmos3 GEN cache was not initialized before running GEN layers.") + self._validate_gen_sequence_parallel( + s_gen=hidden_gen.shape[1], + s_video=s_video, + s_action=s_action, + s_sound=s_sound, + has_action=has_action, + has_sound=has_sound, + ulysses_size=ulysses_size, + ) + freqs_cos, freqs_sin = self.cached_freqs_gen + hidden_gen, freqs_cos, freqs_sin = self.gen_sp_prepare(hidden_gen, freqs_cos, freqs_sin) + freqs_gen = (freqs_cos, freqs_sin) + + if len(self.gen_layers) == len(self.cached_kv): + for layer, (k_und, v_und) in zip(self.gen_layers, self.cached_kv, strict=True): + hidden_gen = layer( + hidden_gen, + k_und=k_und, + v_und=v_und, + freqs_cos=freqs_cos, + freqs_sin=freqs_sin, + ) + # Cache-dit's block wrapper may return a tuple; unwrap it. + if isinstance(hidden_gen, tuple): + hidden_gen = hidden_gen[0] + else: + # Cache-dit patches gen_layers to a grouped wrapper. + for layer in self.gen_layers: + hidden_gen = layer( + hidden_gen, + cached_kv=self.cached_kv, + freqs_gen=freqs_gen, + ) + if isinstance(hidden_gen, tuple): + hidden_gen = hidden_gen[0] + + hidden_gen = self.gen_sp_gather(hidden_gen) + + # Final norm and project back to latent space + hidden_gen = self.norm_moe_gen(hidden_gen) + if not has_action and not has_sound: + return self.unpatchify(self.proj_out(hidden_gen), t, h, w) + + split_sizes = [s_video] + if has_action: + split_sizes.append(s_action) + if has_sound: + split_sizes.append(s_sound) + split_hidden = hidden_gen.split(split_sizes, dim=1) + hidden_video = split_hidden[0] + video_pred = self.unpatchify(self.proj_out(hidden_video), t, h, w) + outputs: list[torch.Tensor] = [video_pred] + split_idx = 1 + if has_action: + hidden_action = split_hidden[split_idx] + split_idx += 1 + assert action_domain_ids is not None + outputs.append(self.unpack_action(self.action_proj_out(hidden_action, action_domain_ids))) + if has_sound: + hidden_sound = split_hidden[split_idx] + outputs.append(self.unpack_sound(self.audio_proj_out(hidden_sound))) + return tuple(outputs) + + def post_load_weights(self) -> None: + """Post-load processing: ensure correct dtypes.""" + self.time_embedder.to(torch.float32) diff --git a/vllm_omni/diffusion/registry.py b/vllm_omni/diffusion/registry.py index fe7d00c77dd..cd6efdb5c65 100644 --- a/vllm_omni/diffusion/registry.py +++ b/vllm_omni/diffusion/registry.py @@ -261,6 +261,11 @@ "pipeline_omnivoice", "OmniVoicePipeline", ), + "Cosmos3OmniDiffusersPipeline": ( + "cosmos3", + "pipeline_cosmos3", + "Cosmos3OmniDiffusersPipeline", + ), "DiffusersAdapterPipeline": ( "diffusers_adapter", "pipeline_diffusers_adapter", @@ -493,6 +498,7 @@ def _apply_sequence_parallel_if_enabled(model, od_config: OmniDiffusionConfig) - "OmniVoicePipeline": "get_omnivoice_post_process_func", "DreamIDOmniPipeline": "get_dreamid_omni_post_process_func", "SenseNovaU1Pipeline": "get_sensenova_u1_post_process_func", + "Cosmos3OmniDiffusersPipeline": "get_cosmos3_post_process_func", "HiDreamImagePipeline": "get_hidream_image_post_process_func", } @@ -517,6 +523,7 @@ def _apply_sequence_parallel_if_enabled(model, od_config: OmniDiffusionConfig) - "HunyuanVideo15ImageToVideoPipeline": "get_hunyuan_video_15_i2v_pre_process_func", "HunyuanImage3ForCausalMM": "get_hunyuan_image_3_pre_process_func", "MagiHumanPipeline": "get_magi_human_pre_process_func", + "Cosmos3OmniDiffusersPipeline": "get_cosmos3_pre_process_func", } diff --git a/vllm_omni/engine/async_omni_engine.py b/vllm_omni/engine/async_omni_engine.py index 4986eae63c9..05c6345277a 100644 --- a/vllm_omni/engine/async_omni_engine.py +++ b/vllm_omni/engine/async_omni_engine.py @@ -1929,6 +1929,7 @@ def _create_default_diffusion_stage_cfg(kwargs: dict[str, Any]) -> list: "max_num_seqs": kwargs.get("max_num_seqs") or 1, "parallel_config": parallel_config, "model_class_name": kwargs.get("model_class_name", None), + "model_config": kwargs.get("model_config", None), "additional_config": kwargs.get("additional_config", None), "step_execution": kwargs.get("step_execution", False), "vae_use_slicing": kwargs.get("vae_use_slicing", False), diff --git a/vllm_omni/entrypoints/openai/api_server.py b/vllm_omni/entrypoints/openai/api_server.py index d8d59285d93..ac8a7ad5449 100644 --- a/vllm_omni/entrypoints/openai/api_server.py +++ b/vllm_omni/entrypoints/openai/api_server.py @@ -1559,7 +1559,7 @@ async def generate_images(request: ImageGenerationRequest, raw_request: Request) return ImageGenerationResponse(created=int(time.time()), data=image_data) # Build params - pass through user values directly - prompt: OmniTextPrompt = {"prompt": request.prompt} + prompt: OmniTextPrompt = {"prompt": request.prompt, "modalities": ["image"]} if request.negative_prompt is not None: prompt["negative_prompt"] = request.negative_prompt gen_params = OmniDiffusionSamplingParams(num_outputs_per_prompt=request.n) @@ -1744,7 +1744,7 @@ async def edit_images( try: # 2. Build prompt & images params cot_output = None - prompt: OmniTextPrompt = {"prompt": prompt} + prompt: OmniTextPrompt = {"prompt": prompt, "modalities": ["image"]} if negative_prompt is not None: prompt["negative_prompt"] = negative_prompt input_images_list = [] @@ -2435,7 +2435,7 @@ async def _run_video_generation_job( started_at = time.perf_counter() output_path = None try: - video_bytes, stage_durations, peak_memory_mb = await handler.generate_video_bytes( + video_bytes, stage_durations, peak_memory_mb, action = await handler.generate_video_bytes( request, video_id, reference_image=reference_image ) @@ -2453,6 +2453,7 @@ async def _run_video_generation_job( "inference_time_s": time.perf_counter() - started_at, "stage_durations": stage_durations, "peak_memory_mb": peak_memory_mb, + "action": action, }, ) except (EngineGenerateError, EngineDeadError) as exc: @@ -2517,6 +2518,8 @@ async def _parse_video_form( flow_shift: float | None = Form(default=None), true_cfg_scale: float | None = Form(default=None), seed: int | None = Form(default=None), + generate_sound: bool | None = Form(default=None), + sound_duration: float | None = Form(default=None, gt=0.0), negative_prompt: str | None = Form(default=None), enable_frame_interpolation: bool | None = Form(default=None), frame_interpolation_exp: int | None = Form(default=None, ge=1), @@ -2557,6 +2560,8 @@ async def _parse_video_form( "flow_shift": flow_shift, "true_cfg_scale": true_cfg_scale, "seed": seed, + "generate_sound": generate_sound, + "sound_duration": sound_duration, "negative_prompt": negative_prompt, "enable_frame_interpolation": enable_frame_interpolation, "frame_interpolation_exp": frame_interpolation_exp, @@ -2660,7 +2665,7 @@ async def create_video_sync( raw_request.state.request_metadata = RequestResponseMetadata(request_id=request_id) started_at = time.perf_counter() try: - video_bytes, stage_durations, peak_memory_mb = await asyncio.wait_for( + video_bytes, stage_durations, peak_memory_mb, _action = await asyncio.wait_for( handler.generate_video_bytes(request, request_id, reference_image=reference_image), timeout=VIDEO_SYNC_TIMEOUT_S, ) diff --git a/vllm_omni/entrypoints/openai/protocol/__init__.py b/vllm_omni/entrypoints/openai/protocol/__init__.py index c68f6f59879..0d8ddd82d90 100644 --- a/vllm_omni/entrypoints/openai/protocol/__init__.py +++ b/vllm_omni/entrypoints/openai/protocol/__init__.py @@ -13,6 +13,7 @@ ResponseFormat, ) from vllm_omni.entrypoints.openai.protocol.videos import ( + VideoAction, VideoData, VideoGenerationRequest, VideoGenerationResponse, @@ -27,6 +28,7 @@ "ImageGenerationRequest", "ImageGenerationResponse", "ResponseFormat", + "VideoAction", "VideoData", "VideoGenerationRequest", "VideoGenerationResponse", diff --git a/vllm_omni/entrypoints/openai/protocol/videos.py b/vllm_omni/entrypoints/openai/protocol/videos.py index d46c8d43d6b..ec5ab14e8d8 100644 --- a/vllm_omni/entrypoints/openai/protocol/videos.py +++ b/vllm_omni/entrypoints/openai/protocol/videos.py @@ -149,6 +149,15 @@ class VideoGenerationRequest(BaseModel): description="True CFG scale (model-specific parameter, may be ignored if not supported)", ) seed: int | None = Field(default=None, description="Random seed for reproducibility") + generate_sound: bool = Field( + default=False, + description="Request model-generated audio for video models that support sound generation.", + ) + sound_duration: float | None = Field( + default=None, + gt=0.0, + description="Duration in seconds for model-generated audio. Defaults to the generated video duration.", + ) # vllm-omni extensions for post-generation frame interpolation. enable_frame_interpolation: bool = Field( @@ -211,12 +220,24 @@ def resolve_video_params(self) -> VideoParams: return vp +class VideoAction(BaseModel): + """Generated action sequence returned by action-capable video models.""" + + data: list[Any] = Field(..., description="JSON-serializable nested action values") + shape: list[int] = Field(..., description="Shape of the returned action data") + dtype: str | None = Field(default=None, description="Source action dtype, if available") + raw_action_dim: int | None = Field(default=None, description="Raw action dimension requested by the model") + action_mode: str | None = Field(default=None, description="Action generation mode") + domain_id: int | None = Field(default=None, description="Action embodiment domain id") + + class VideoData(BaseModel): """Single generated video data.""" b64_json: str | None = Field(default=None, description="Base64-encoded MP4 video") url: str | None = Field(default=None, description="Video URL (not implemented)") revised_prompt: str | None = Field(default=None, description="Revised prompt (OpenAI compatibility, always null)") + action: VideoAction | None = Field(default=None, description="Generated action sequence metadata, if any") class VideoGenerationResponse(BaseModel): @@ -289,6 +310,7 @@ class VideoResponse(BaseModel): default=0.0, description="Peak device memory usage in MB reported by the diffusion pipeline.", ) + action: VideoAction | None = Field(default=None, description="Generated action sequence metadata, if any") @property def file_extension(self) -> str: diff --git a/vllm_omni/entrypoints/openai/serving_chat.py b/vllm_omni/entrypoints/openai/serving_chat.py index 06ccfcd70fe..2a8676a4145 100644 --- a/vllm_omni/entrypoints/openai/serving_chat.py +++ b/vllm_omni/entrypoints/openai/serving_chat.py @@ -2521,6 +2521,7 @@ def _prepare_diffusion_image_request( gen_prompt: OmniTextPrompt = { "prompt": prompt, "negative_prompt": negative_prompt, + "modalities": ["image"], } if pil_images: if len(pil_images) == 1: @@ -2831,6 +2832,7 @@ async def _create_diffusion_chat_completion( gen_prompt: OmniTextPrompt = { "prompt": prompt, "negative_prompt": negative_prompt, + "modalities": ["image"], } gen_params = OmniDiffusionSamplingParams( height=height, diff --git a/vllm_omni/entrypoints/openai/serving_video.py b/vllm_omni/entrypoints/openai/serving_video.py index 043ccd98322..9b173f9c0d1 100644 --- a/vllm_omni/entrypoints/openai/serving_video.py +++ b/vllm_omni/entrypoints/openai/serving_video.py @@ -16,6 +16,7 @@ from vllm_omni.entrypoints.async_omni import AsyncOmni from vllm_omni.entrypoints.openai.protocol.videos import ( + VideoAction, VideoData, VideoGenerationRequest, VideoGenerationResponse, @@ -44,6 +45,7 @@ class VideoGenerationArtifacts: videos: list[Any] audios: list[Any | None] + actions: list[VideoAction | None] audio_sample_rate: int output_fps: int stage_durations: dict[str, float] @@ -96,7 +98,7 @@ async def _run_and_extract( reference_image: ReferenceImage | None = None, ) -> VideoGenerationArtifacts: """Run the generation pipeline and extract video/audio/profiler outputs.""" - prompt: OmniTextPrompt = OmniTextPrompt(prompt=request.prompt) + prompt: OmniTextPrompt = OmniTextPrompt(prompt=request.prompt, modalities=["video"]) if request.negative_prompt is not None: prompt["negative_prompt"] = request.negative_prompt @@ -148,6 +150,10 @@ async def _run_and_extract( ) if "flow_shift" in provided_fields and request.flow_shift is not None: gen_params.extra_args["flow_shift"] = request.flow_shift + if "generate_sound" in provided_fields: + gen_params.extra_args["generate_sound"] = request.generate_sound + if "sound_duration" in provided_fields and request.sound_duration is not None: + gen_params.extra_args["sound_duration"] = request.sound_duration # Apply model-specific extra parameters if request.extra_params is not None: @@ -173,11 +179,13 @@ async def _run_and_extract( result = await self._run_generation(prompt, gen_params, reference_id) videos = self._extract_video_outputs(result) audios = self._extract_audio_outputs(result, expected_count=len(videos)) + actions = self._extract_action_outputs(result, expected_count=len(videos)) audio_sample_rate = self._resolve_audio_sample_rate(result) output_fps = (vp.fps or self._resolve_fps(result) or 24) * self._resolve_video_fps_multiplier(result) return VideoGenerationArtifacts( videos=videos, audios=audios, + actions=actions, audio_sample_rate=audio_sample_rate, output_fps=output_fps, stage_durations=self._extract_stage_durations(result), @@ -211,7 +219,8 @@ async def generate_videos( audio_sample_rate=artifacts.audio_sample_rate, video_codec_options=video_codec_options, ) - ) + ), + action=artifacts.actions[idx], ) for idx, video in enumerate(artifacts.videos) ] @@ -230,7 +239,7 @@ async def generate_video_bytes( reference_id: str, *, reference_image: ReferenceImage | None = None, - ) -> tuple[bytes, dict[str, float], float]: + ) -> tuple[bytes, dict[str, float], float, VideoAction | None]: """Generate a video and return raw MP4 bytes, bypassing base64 encoding.""" artifacts = await self._run_and_extract(request, reference_id, reference_image=reference_image) if len(artifacts.videos) > 1: @@ -255,22 +264,15 @@ async def generate_video_bytes( ) _t_encode_ms = (time.perf_counter() - _t_encode_start) * 1000 logger.info("Video response encoding (MP4 bytes): %.2f ms", _t_encode_ms) - return video_bytes, artifacts.stage_durations, artifacts.peak_memory_mb + return video_bytes, artifacts.stage_durations, artifacts.peak_memory_mb, artifacts.actions[0] @staticmethod def _resolve_video_fps_multiplier(result: Any) -> int: - custom_output = getattr(result, "custom_output", None) + custom_output = OmniOpenAIServingVideo._extract_custom_output(result) if isinstance(custom_output, dict): multiplier = custom_output.get("video_fps_multiplier") if multiplier is not None: return int(multiplier) - request_output = getattr(result, "request_output", None) - if request_output is not None: - custom_output = getattr(request_output, "custom_output", None) - if isinstance(custom_output, dict): - multiplier = custom_output.get("video_fps_multiplier") - if multiplier is not None: - return int(multiplier) return 1 def _resolve_default_sampling_params(self) -> OmniDiffusionSamplingParams: @@ -443,6 +445,132 @@ def _resolve_audio_sample_rate(self, result: Any) -> int: return 24000 + @classmethod + def _extract_action_outputs(cls, result: Any, expected_count: int) -> list[VideoAction | None]: + custom_output = cls._extract_custom_output(result) + if not custom_output or "action" not in custom_output: + return [None] * expected_count + + action_items = cls._split_action_payload(custom_output["action"], expected_count) + return [ + cls._make_video_action(action_item, custom_output) if action_item is not None else None + for action_item in action_items + ] + + @staticmethod + def _extract_custom_output(result: Any) -> dict[str, Any]: + custom_output = getattr(result, "custom_output", None) + if isinstance(custom_output, dict): + return custom_output + + request_output = getattr(result, "request_output", None) + if isinstance(request_output, dict): + custom_output = request_output.get("custom_output") + if custom_output is None: + custom_output = request_output.get("_custom_output") + elif request_output is not None: + custom_output = getattr(request_output, "custom_output", None) + if custom_output is None: + custom_output = getattr(request_output, "_custom_output", None) + + return custom_output if isinstance(custom_output, dict) else {} + + @classmethod + def _split_action_payload(cls, action: Any, expected_count: int) -> list[Any | None]: + if expected_count <= 0: + return [] + + shape = cls._shape_of(action) + if len(shape) >= 3: + count = min(shape[0], expected_count) + actions = [cls._index_action(action, i) for i in range(count)] + actions.extend([None] * (expected_count - count)) + return actions + + return [action] + [None] * (expected_count - 1) + + @classmethod + def _make_video_action(cls, action: Any, custom_output: dict[str, Any]) -> VideoAction: + data = cls._to_jsonable(action) + if not isinstance(data, list): + data = [data] + + action_mode = custom_output.get("action_mode") + return VideoAction( + data=data, + shape=cls._shape_of(action), + dtype=cls._dtype_of(action), + raw_action_dim=cls._coerce_optional_int(custom_output.get("raw_action_dim")), + action_mode=str(action_mode) if action_mode is not None else None, + domain_id=cls._coerce_optional_int(custom_output.get("domain_id")), + ) + + @staticmethod + def _index_action(action: Any, index: int) -> Any: + try: + return action[index] + except (IndexError, KeyError, TypeError): + return None + + @classmethod + def _to_jsonable(cls, value: Any) -> Any: + if hasattr(value, "detach"): + value = value.detach() + if hasattr(value, "cpu"): + value = value.cpu() + if hasattr(value, "tolist"): + return cls._to_jsonable(value.tolist()) + if isinstance(value, (list, tuple)): + return [cls._to_jsonable(item) for item in value] + if hasattr(value, "item"): + try: + return value.item() + except (TypeError, ValueError): + pass + return value + + @classmethod + def _shape_of(cls, value: Any) -> list[int]: + shape = getattr(value, "shape", None) + if shape is not None: + try: + return [int(dim) for dim in shape] + except (TypeError, ValueError): + pass + if isinstance(value, (list, tuple)): + if not value: + return [0] + return [len(value)] + cls._shape_of(value[0]) + return [] + + @staticmethod + def _dtype_of(value: Any) -> str | None: + dtype = getattr(value, "dtype", None) + return str(dtype) if dtype is not None else None + + @staticmethod + def _coerce_optional_int(value: Any) -> int | None: + if value is None: + return None + try: + value = value.item() if hasattr(value, "item") else value + return int(value) + except (TypeError, ValueError): + return None + + def _resolve_audio_sample_rate(self, result: Any) -> int: + result_sample_rate = self._extract_audio_sample_rate_from_result(result) + if result_sample_rate is not None: + return result_sample_rate + + model_config = getattr(self._engine_client, "model_config", None) + hf_config = getattr(model_config, "hf_config", None) + config_sample_rate = self._extract_audio_sample_rate_from_config(hf_config) + if config_sample_rate is not None: + return config_sample_rate + + return 24000 + @staticmethod def _resolve_fps(result: Any) -> int | None: """Extract fps from multimodal_output if the model reported it.""" diff --git a/vllm_omni/entrypoints/utils.py b/vllm_omni/entrypoints/utils.py index 86a0d818268..cdb262d0239 100644 --- a/vllm_omni/entrypoints/utils.py +++ b/vllm_omni/entrypoints/utils.py @@ -36,6 +36,7 @@ def _warn_deprecated_explicit_keys(kwargs: dict[str, Any]) -> None: _DIFFUSERS_CLASS_TO_CONFIG: dict[str, str] = { + "Cosmos3OmniDiffusersPipeline": "cosmos3_omni", "GlmImagePipeline": "glm_image", } diff --git a/vllm_omni/inputs/data.py b/vllm_omni/inputs/data.py index 1b80f4b1b77..a2b6bf722d5 100644 --- a/vllm_omni/inputs/data.py +++ b/vllm_omni/inputs/data.py @@ -33,6 +33,9 @@ class OmniTextPrompt(TextPrompt): """ negative_prompt: NotRequired[str] + # Using modalities field to differentiate between different tasks for the same pipeline + # for example Cosmos3OmniDiffusersPipeline handles t2i and t2v in the same pipeline. + modalities: NotRequired[list[str]] prompt_embeds: NotRequired[torch.Tensor] negative_prompt_embeds: NotRequired[torch.Tensor] additional_information: NotRequired[dict[str, Any]]