diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index 8a39b4ba4b7..a0588a5c336 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -33,7 +33,7 @@ th { | `ZImagePipeline` | Z-Image | `Tongyi-MAI/Z-Image-Turbo` | ✅︎ | ✅︎ | ✅︎ | ✅︎ | | `WanPipeline` | Wan2.1-T2V, Wan2.2-T2V, Wan2.2-TI2V | `Wan-AI/Wan2.1-T2V-1.3B-Diffusers`, `Wan-AI/Wan2.1-T2V-14B-Diffusers`, `Wan-AI/Wan2.2-T2V-A14B-Diffusers`, `Wan-AI/Wan2.2-TI2V-5B-Diffusers` | ✅︎ | ✅︎ | ✅︎ | ✅︎ | | `WanImageToVideoPipeline` | Wan2.2-I2V | `Wan-AI/Wan2.2-I2V-A14B-Diffusers` | ✅︎ | ✅︎ | ✅︎ | ✅︎ | -| `Cosmos3OmniDiffusersPipeline` | Cosmos3 T2I, T2V, I2V, T2V with sound | `nvidia/Cosmos3-Nano` | ✅︎ | | | | +| `Cosmos3OmniDiffusersPipeline` | Cosmos3 T2I, T2V, I2V, T2V with sound, action policy | `nvidia/Cosmos3-Nano` | ✅︎ | | | | | `WanSpeechToVideoPipeline` | Wan2.2-S2V | `Wan-AI/Wan2.2-S2V-14B` | ✅︎ | ✅︎ | ✅︎ | ✅︎ | | `Wan22VACEPipeline` | Wan2.1-VACE | `Wan-AI/Wan2.1-VACE-1.3B-diffusers`, `Wan-AI/Wan2.1-VACE-14B-diffusers` | ✅︎ | ✅︎ | ✅︎ | ✅︎ | | `LTX2Pipeline` | LTX-2-T2V | `Lightricks/LTX-2` | ✅︎ | ✅︎ | | | diff --git a/recipes/README.md b/recipes/README.md index 161bcdd5edc..4648dbd6e5d 100644 --- a/recipes/README.md +++ b/recipes/README.md @@ -36,8 +36,8 @@ recipes/ | [`LTX/LTX-2.md`](./LTX/LTX-2.md) | Text-to-video and image-to-video serving | 1x H200 141GB | | [`LTX/LTX-2.3.md`](./LTX/LTX-2.3.md) | Text-to-video with audio generation (22B) | 1x GPU (96GB VRAM) | | [`mistralai/Voxtral-TTS.md`](./mistralai/Voxtral-TTS.md) | Online serving for TTS | 1x RTX 4090 24GB | -| [`nvidia/Cosmos3-Nano.md`](./nvidia/Cosmos3-Nano.md) | Text-to-image, text-to-video, image-to-video generation, text to video with sound | 1x H200 141GB / B300 | -| [`nvidia/Cosmos3-Super.md`](./nvidia/Cosmos3-Super.md) | 64B T2I / T2V / I2V generation (+ optional audio) | 8x H200/H100/A100 / 2x H200 / B300 | +| [`cosmos3/Cosmos3-Nano.md`](./cosmos3/Cosmos3-Nano.md) | Text-to-image, 
text-to-video, image-to-video generation, text to video with sound, action policy | 1x H200 141GB / B300 | +| [`cosmos3/Cosmos3-Super.md`](./cosmos3/Cosmos3-Super.md) | 64B T2I / T2V / I2V generation (+ optional audio) / Action policy | 8x H200/H100/A100 / 2x H200 / B300 | | [`OpenBMB/MiniCPM-o-4_5.md`](./OpenBMB/MiniCPM-o-4_5.md) | Online serving for omni multimodal chat (text / image / audio / video → text + 24 kHz speech) | 2x A100/H100 80GB / 3x mid-tier GPU / 8x RTX 4090 24GB | | [`OpenBMB/VoxCPM2.md`](./OpenBMB/VoxCPM2.md) | Online + offline TTS with native AR pipeline (48 kHz, 30+ languages) | 1x RTX 4090 24GB | | [`Qwen/Qwen-Image.md`](./Qwen/Qwen-Image.md) | Text-to-image serving with step-wise continuous batching replay and ModelOpt mixed FP8/NVFP4 | 1x A100 80GB / 2x B200 | diff --git a/recipes/nvidia/Cosmos3-Nano.md b/recipes/cosmos3/Cosmos3-Nano.md similarity index 75% rename from recipes/nvidia/Cosmos3-Nano.md rename to recipes/cosmos3/Cosmos3-Nano.md index 57f6b983cda..57eb9048c51 100644 --- a/recipes/nvidia/Cosmos3-Nano.md +++ b/recipes/cosmos3/Cosmos3-Nano.md @@ -6,7 +6,7 @@ - Vendor: NVIDIA - Model: `nvidia/Cosmos3-Nano` -- Task: Text-to-image (T2I), text-to-video (T2V), and image-to-video (I2V) generation, with optional synchronized audio (video + sound) +- Task: Text-to-image (T2I), text-to-video (T2V), and image-to-video (I2V) generation, with optional synchronized audio (video + sound), action policy - Mode: Online serving with the OpenAI-compatible image/video APIs, plus offline generation via the `Omni` API - Maintainer: Community @@ -23,6 +23,23 @@ the mode is selected per request: - **T2VS / I2VS** — add `generate_sound=true` (and optional `sound_duration`) to a T2V/I2V `/v1/videos/sync` request to also generate synchronized audio, muxed into the mp4 as AAC 48 kHz stereo. See the official model card's "Video + Audio" examples. 
+- **Action** — pass `extra_params={"action_mode": ...}` to drive Physical-AI tasks: + - `forward_dynamics` — given a first frame **and** an action trajectory, roll out + the resulting video. Synchronous: `POST /v1/videos/sync`. + - `policy` — given a first frame and a language instruction, **predict** the action + trajectory (and a rollout video). Use the async `POST /v1/videos` endpoint and + read the predicted action from the top-level `action` field + (`{data, shape, dtype, raw_action_dim, domain_id}`). + + Action requests also take `domain_name` (e.g. `av`, `bridge_orig_lerobot`, + `droid_lerobot`, `agibotworld`, …; or a numeric `domain_id`), `raw_action_dim`, + and `action_chunk_size` (must equal `num_frames` or `num_frames - 1`). For + `forward_dynamics` also pass the `action` array. The dedicated policy checkpoint + **`nvidia/Cosmos3-Nano-Policy-DROID`** is served the same way + (`domain_name=droid_lerobot`). + + `inverse_dynamics` (recover the action from a given video) is supported by the + pipeline; **online inference of inverse dynamics will be added in a follow-up MR.** ## References @@ -144,6 +161,35 @@ curl -sS -X POST http://localhost:8000/v1/videos/sync \ -F "sound_duration=7.875" \ -F 'extra_params={"use_resolution_template":false,"use_duration_template":false,"guardrails":true}' \ -o cosmos3_t2v_with_sound.mp4 + +# Action — forward dynamics (first frame + action trajectory -> rollout video). +# Synchronous; `action` is a JSON array shaped [action_chunk_size, raw_action_dim]. +curl -sS -X POST http://localhost:8000/v1/videos/sync \ + -H "Accept: video/mp4" \ + --form-string "model=nvidia/Cosmos3-Nano" \ + --form-string "prompt=You are an autonomous vehicle. This video is captured from a first-person perspective." 
\ + -F "input_reference=@first_frame.jpg;type=image/jpeg" \ + -F "size=640x480" -F "num_frames=61" -F "fps=10" \ + -F "num_inference_steps=30" -F "guidance_scale=1.0" -F "flow_shift=5.0" \ + --form-string "extra_params={\"action_mode\":\"forward_dynamics\",\"domain_name\":\"av\",\"raw_action_dim\":9,\"action_chunk_size\":60,\"action\":$(cat action.json)}" \ + -F "seed=0" \ + -o cosmos3_forward_dynamics.mp4 + +# Action — policy (first frame + instruction -> predicted action trajectory + video). +# Asynchronous: POST returns a job id; poll, then read the predicted action from +# the top-level `action` field ({data, shape, dtype, raw_action_dim, domain_id}). +VIDEO_ID=$(curl -sS -X POST http://localhost:8000/v1/videos \ + -H "Accept: application/json" \ + --form-string "model=nvidia/Cosmos3-Nano" \ + --form-string "prompt=Pick up the banana and place it in the bowl." \ + -F "input_reference=@first_frame.jpg;type=image/jpeg" \ + -F "size=640x480" -F "num_frames=17" -F "fps=5" \ + -F "num_inference_steps=30" -F "guidance_scale=1.0" -F "flow_shift=5.0" \ + --form-string 'extra_params={"action_mode":"policy","domain_name":"bridge_orig_lerobot","raw_action_dim":10,"action_chunk_size":16}' \ + -F "seed=0" | jq -r '.id') +# poll until status == completed, then: +curl -sS "http://localhost:8000/v1/videos/$VIDEO_ID" | jq '.action | {shape, dtype, raw_action_dim, domain_id}' +curl -sS -L "http://localhost:8000/v1/videos/$VIDEO_ID/content" -o cosmos3_policy.mp4 ``` #### Notes @@ -152,6 +198,7 @@ curl -sS -X POST http://localhost:8000/v1/videos/sync \ - T2I 1024² — 10 / 25 / 50 steps → ~0.4 / 0.7 / **1.3 s** - T2V 1280×720 @ 35 steps — 25 / 49 / 93 / **189** frames → ~7 / 15 / 33 / **~93 s** - I2V 1280×720, 189 frames @ 35 steps → ~**99 s** + - Action 640×480 @ 30 steps — forward-dynamics 61f ~**4 s**, policy 17f ~**1–3 s**. - Guardrails-on overhead: ~8% on T2I, negligible on video. 
- **Memory:** transformer ~17 GiB (bf16); peak ~46 GiB for 720p video on 1 GPU; full repo (transformer + Wan VAE + Qwen3-VL vision encoder + audio tokenizer) @@ -173,8 +220,10 @@ curl -sS -X POST http://localhost:8000/v1/videos/sync \ the server fails at pipeline build with a gated-repo / safety-checker error. - A guardrail-blocked prompt currently returns HTTP 500 (`"Guardrail blocked prompt"`). - - Action (policy / forward- / inverse-dynamics) modalities are not part of - this integration yet. + - Action `forward_dynamics` (sync `/v1/videos/sync`) and `policy` (async + `/v1/videos`, returns the predicted action under the top-level `action` + field) are supported online. **Online inference of inverse dynamics will be + added in a follow-up MR.** ### 1x GPU (Offline generation) diff --git a/recipes/nvidia/Cosmos3-Super.md b/recipes/cosmos3/Cosmos3-Super.md similarity index 90% rename from recipes/nvidia/Cosmos3-Super.md rename to recipes/cosmos3/Cosmos3-Super.md index 528b7a77393..1dc88d0c3de 100644 --- a/recipes/nvidia/Cosmos3-Super.md +++ b/recipes/cosmos3/Cosmos3-Super.md @@ -104,5 +104,9 @@ curl -sS -X POST http://localhost:8000/v1/videos/sync -H "Accept: video/mp4" \ - (NVIDIA's reference: 8×H200 @ 50 steps ≈ 55 s/video; 2×H200 @ 35 steps ≈ 3 min/video.) - **Memory:** ~61.5 GiB per GPU when sharded across 2 GPUs (HSDP shard 2); repo ~135 GB on disk. - Same generation defaults, supported sizes, and `generate_sound`/`sound_duration` - semantics as Nano. Action (policy / forward- / inverse-dynamics) modalities are - not part of this integration yet. + semantics as Nano, including the **action** modality: `forward_dynamics` + (sync `/v1/videos/sync`) and `policy` (async `/v1/videos`, predicted action under + the top-level `action` field) — see the Cosmos3-Nano recipe for the request shape. + Online inference of inverse dynamics will be added in a follow-up MR. 
Verified on + the 64B Super under `--cfg-parallel-size 2`: async `policy` returns the predicted + action (`[16, 10]`) and the rollout video reliably. diff --git a/tests/diffusion/models/cosmos3/test_cosmos3_pipeline.py b/tests/diffusion/models/cosmos3/test_cosmos3_pipeline.py index 0e441766a97..087d1a67344 100644 --- a/tests/diffusion/models/cosmos3/test_cosmos3_pipeline.py +++ b/tests/diffusion/models/cosmos3/test_cosmos3_pipeline.py @@ -99,12 +99,16 @@ def __init__( sound_gen: bool = False, sound_dim: int = 3, sound_latent_fps: float = 25.0, + action_gen: bool = False, + action_dim: int = 4, ) -> None: super().__init__() self.latent_channel_size = latent_channel_size self.sound_gen = sound_gen self.sound_dim = sound_dim self.sound_latent_fps = sound_latent_fps + self.action_gen = action_gen + self.action_dim = action_dim self.cached_kv: Any | None = None self.cached_freqs_gen: Any | None = None self.calls: list[dict[str, Any]] = [] @@ -139,7 +143,10 @@ def forward( marker = torch.tensor([token], dtype=torch.float32) self.cached_kv = [(marker, marker + 100)] self.cached_freqs_gen = (marker + 200, marker + 300) + action_latents = kwargs.get("action_latents") outputs: list[torch.Tensor] = [torch.full_like(hidden_states, float(token))] + if action_latents is not None: + outputs.append(torch.full_like(action_latents, float(token + 20))) if sound_latents is not None: outputs.append(torch.full_like(sound_latents, float(token + 10))) return outputs[0] if len(outputs) == 1 else tuple(outputs) @@ -344,7 +351,7 @@ def test_pipeline_init_passes_tokenizer_attrs_into_transformer( assert pipeline.transformer.audio_proj_out.out_features == 5 -def test_preprocess_i2v_image_input() -> None: +def test_preprocess_i2v_image_and_action_video_inputs() -> None: from vllm_omni.diffusion.models.cosmos3.pipeline_cosmos3 import get_cosmos3_pre_process_func preprocess = get_cosmos3_pre_process_func(SimpleNamespace()) @@ -357,6 +364,16 @@ def test_preprocess_i2v_image_input() -> None: assert 
(result.sampling_params.height, result.sampling_params.width) == (672, 1344) assert tuple(result.prompts[0]["additional_information"]["preprocessed_image"].shape[-2:]) == (672, 1344) + frames = [Image.new("RGB", (8, 4), color) for color in ("red", "green", "blue")] + action = SimpleNamespace( + prompts=[{"prompt": "Move.", "multi_modal_data": {"video": frames}}], + sampling_params=SimpleNamespace(height=16, width=32, extra_args={"action_mode": "forward_dynamics"}), + ) + + additional = preprocess(action).prompts[0]["additional_information"] + assert tuple(additional["preprocessed_image"].shape) == (1, 3, 16, 32) + assert tuple(additional["preprocessed_video"].shape) == (1, 3, 3, 16, 32) + def test_postprocess_handles_image_video_audio_and_validation() -> None: from vllm_omni.diffusion.models.cosmos3.pipeline_cosmos3 import get_cosmos3_post_process_func @@ -434,7 +451,7 @@ def test_prompt_formatting_and_checkpoint_key_remap(make_cosmos3_pipeline) -> No assert {key: Cosmos3OmniDiffusersPipeline._remap_ckpt_key(key) for key in remaps} == remaps -def test_prepare_latents_for_video_image_and_sound(make_cosmos3_pipeline) -> None: +def test_prepare_latents_for_video_image_sound_and_action(make_cosmos3_pipeline) -> None: pipeline = make_cosmos3_pipeline() latents = pipeline._prepare_latents(16, 24, 5, torch.Generator(device="cpu").manual_seed(0)) assert latents.shape == (1, 2, 2, 2, 3) @@ -463,8 +480,20 @@ def test_prepare_latents_for_video_image_and_sound(make_cosmos3_pipeline) -> Non assert (sound_latents.shape, latent_frames) == (torch.Size([1, 3, 6]), 6) assert pipeline._decode_sound_latents(torch.zeros(1, 3, 6), target_audio_samples=21).shape == (1, 2, 21) + pipeline.transformer = pipeline.transformer.__class__(action_gen=True, action_dim=4) + action, action_mask, clean, raw_dim = pipeline._prepare_action_latents( + mode="forward_dynamics", + action_chunk_size=2, + raw_action_dim=None, + generator=torch.Generator(device="cpu").manual_seed(0), + 
sp=SimpleNamespace(extra_args={"action": [[1.0, 2.0], [3.0, 4.0]]}), + ) + assert raw_dim == 2 + assert action_mask.tolist() == [[[0.0], [0.0]]] + torch.testing.assert_close(action, clean) + -def test_diffuse_covers_cfg_i2v_and_sound_steps(make_cosmos3_pipeline) -> None: +def test_diffuse_covers_cfg_i2v_and_multimodal_steps(make_cosmos3_pipeline) -> None: pipeline = make_cosmos3_pipeline() latents = torch.zeros(1, 2, 1, 1, 1) @@ -496,20 +525,22 @@ def test_diffuse_covers_cfg_i2v_and_sound_steps(make_cosmos3_pipeline) -> None: ) torch.testing.assert_close(i2v[:, :, 0:1], torch.full((1, 2, 1, 1, 1), 7.0)) - pipeline.transformer = pipeline.transformer.__class__(latent_channel_size=2, sound_gen=True, sound_dim=3) - video_result, sound_result = pipeline.diffuse( + pipeline.transformer = pipeline.transformer.__class__(latent_channel_size=2, action_gen=True, action_dim=4) + video_result, action_result = pipeline.diffuse( latents=latents, - sound_latents=torch.zeros(1, 3, 4), + action_latents=torch.zeros(1, 3, 4), + action_velocity_mask=torch.ones(1, 3, 1), + action_condition_latents=torch.zeros(1, 3, 4), timesteps=torch.tensor([7, 3]), cond_ids=_ids(2), cond_mask=_mask(), uncond_ids=_ids(1), uncond_mask=_mask(), guidance_scale=1.0, - shared_kwargs={"video_shape": (1, 1, 1), "fps": 24.0}, + shared_kwargs={"video_shape": (1, 1, 1), "fps": 24.0, "action_domain_ids": torch.tensor([0])}, ) torch.testing.assert_close(video_result, torch.full_like(latents, 4.0)) - torch.testing.assert_close(sound_result, torch.full((), 24.0).expand_as(sound_result)) + torch.testing.assert_close(action_result, torch.full((), 44.0).expand_as(action_result)) def test_diffuse_keeps_paired_cfg_when_cache_dit_active(make_cosmos3_pipeline) -> None: @@ -568,6 +599,8 @@ def fake_prepare(height, width, num_frames, generator): def fake_diffuse(**kwargs): captured["diffuse_calls"].append(kwargs) outputs = [kwargs["latents"] + len(captured["diffuse_calls"])] + if kwargs.get("action_latents") is not None: + 
outputs.append(kwargs["action_latents"] + 3.0) if kwargs.get("sound_latents") is not None: outputs.append(kwargs["sound_latents"] + 2.0) return outputs[0] if len(outputs) == 1 else tuple(outputs) @@ -612,7 +645,7 @@ def test_forward_defaults_and_mode_selection( assert captured["flow_shifts"] == expected["flow"] assert [call[0] for call in pipeline.scheduler.set_timesteps_calls] == expected["steps"] - def test_forward_i2v_and_sound_routes(self, make_cosmos3_pipeline) -> None: + def test_forward_i2v_sound_and_action_routes(self, make_cosmos3_pipeline) -> None: pipeline = make_cosmos3_pipeline() captured = self._install_forward_stubs(pipeline) image_tensor = torch.zeros(1, 3, 16, 16) @@ -651,6 +684,31 @@ def test_forward_i2v_and_sound_routes(self, make_cosmos3_pipeline) -> None: assert captured["diffuse_calls"][-1]["sound_latents"] is sound_latents assert output.output["audio_sample_rate"] == 10 + pipeline.transformer = pipeline.transformer.__class__(latent_channel_size=2, action_gen=True, action_dim=4) + output = pipeline.forward( + SimpleNamespace( + prompts=[ + { + "prompt": "Pick the block.", + "modalities": ["video"], + "additional_information": {"preprocessed_image": image_tensor}, + } + ], + sampling_params=make_sampling_params( + height=16, + width=16, + extra_args={ + "action_mode": "policy", + "action_chunk_size": 2, + "raw_action_dim": 2, + "domain_name": "bridge_orig_lerobot", + }, + ), + ) + ) + assert captured["diffuse_calls"][-1]["shared_kwargs"]["action_domain_ids"].tolist() == [7] + assert output.custom_output["action"].shape == (1, 2, 2) + @pytest.mark.parametrize( ("prompt", "sampling_params", "message"), [ diff --git a/tests/diffusion/models/cosmos3/test_cosmos3_transformer.py b/tests/diffusion/models/cosmos3/test_cosmos3_transformer.py index 062cd8abf98..50a8d22adf5 100644 --- a/tests/diffusion/models/cosmos3/test_cosmos3_transformer.py +++ b/tests/diffusion/models/cosmos3/test_cosmos3_transformer.py @@ -30,8 +30,9 @@ def 
_tiny_cosmos3_config(**overrides): return config -def test_mrope_position_ids_cover_text_video_and_sound() -> None: +def test_mrope_position_ids_cover_text_video_sound_and_action() -> None: from vllm_omni.diffusion.models.cosmos3.transformer_cosmos3 import ( + compute_mrope_position_ids_action, compute_mrope_position_ids_sound, compute_mrope_position_ids_text, compute_mrope_position_ids_vision, @@ -62,6 +63,10 @@ def test_mrope_position_ids_cover_text_video_and_sound() -> None: torch.testing.assert_close(sound_ids[0], torch.tensor([10.0, 10.96, 11.92])) assert sound_offset == 12 + action_ids, action_offset = compute_mrope_position_ids_action(3, temporal_offset=10, action_fps=None) + assert action_ids.tolist() == [[11, 12, 13], [0, 0, 0], [0, 0, 0]] + assert action_offset == 14 + @pytest.mark.parametrize( ("key", "value"), @@ -126,7 +131,7 @@ def test_forward_returns_video_prediction(monkeypatch: pytest.MonkeyPatch) -> No assert tuple(output.shape) == (1, 2, 1, 2, 2) -def test_sound_modules_follow_injected_sound_dim() -> None: +def test_sound_and_action_modules_follow_config() -> None: from vllm_omni.diffusion.models.cosmos3.transformer_cosmos3 import Cosmos3VFMTransformer tiny = _tiny_cosmos3_config() @@ -137,12 +142,22 @@ def test_sound_modules_follow_injected_sound_dim() -> None: sound_dim=5, sound_latent_fps=40.0, ) + with_action = Cosmos3VFMTransformer( + SimpleNamespace( + tf_model_config={**tiny, "action_gen": True, "max_action_dim": 6, "num_embodiment_domains": 9}, + dtype=torch.float32, + ) + ) assert no_modal.sound_gen is False + assert no_modal.action_gen is False assert not hasattr(no_modal, "audio_proj_in") + assert not hasattr(no_modal, "action_proj_in") assert with_sound.sound_dim == 5 assert with_sound.sound_latent_fps == 40.0 assert with_sound.audio_proj_in.in_features == 5 + assert with_action.action_dim == 6 + assert with_action.action_proj_in.num_domains == 9 @pytest.mark.parametrize( @@ -163,33 +178,56 @@ def 
def test_sound_and_action_pack_unpack_validate_shapes() -> None:
    """Pack/unpack must round-trip sound and action tensors and reject bad widths."""
    from vllm_omni.diffusion.models.cosmos3.transformer_cosmos3 import Cosmos3VFMTransformer

    # Build a bare module instance: only the attributes read by the pack/unpack
    # helpers are set, the real constructor is bypassed on purpose.
    transformer = object.__new__(Cosmos3VFMTransformer)
    nn.Module.__init__(transformer)
    transformer.sound_dim = 3
    transformer.action_dim = 3

    sound_in = torch.arange(24, dtype=torch.float32).reshape(2, 3, 4)
    action_in = torch.arange(24, dtype=torch.float32).reshape(2, 4, 3)

    # Round trips are lossless.
    sound_out = transformer.unpack_sound(transformer.pack_sound(sound_in))
    torch.testing.assert_close(sound_out, sound_in)
    action_out = transformer.unpack_action(transformer.pack_action(action_in))
    torch.testing.assert_close(action_out, action_in)

    # Tensors whose width disagrees with the configured dims are rejected.
    bad_sound = torch.zeros(1, 4, 2)
    with pytest.raises(ValueError, match="channel mismatch"):
        transformer.pack_sound(bad_sound)
    bad_action = torch.zeros(1, 2, 4)
    with pytest.raises(ValueError, match="dimension mismatch"):
        transformer.pack_action(bad_action)
sound_gen=True, - sound_dim=3, - sound_latent_fps=40.0, + SimpleNamespace(tf_model_config=config, dtype=torch.float32), + **transformer_kwargs, )( hidden_states=torch.zeros(1, 2, 1, 2, 2), timestep=torch.tensor([1.0]), @@ -197,11 +235,12 @@ def test_forward_returns_video_and_sound_predictions(monkeypatch: pytest.MonkeyP text_mask=torch.ones(1, 2, dtype=torch.long), video_shape=(1, 2, 2), fps=24.0, - sound_latents=torch.zeros(1, 3, 4), + action_noisy_mask=torch.ones(1, 5, 1), + **extra_kwargs, ) assert isinstance(output, tuple) - assert [tuple(tensor.shape) for tensor in output] == [(1, 2, 1, 2, 2), (1, 3, 4)] + assert [tuple(tensor.shape) for tensor in output] == expected_shapes def test_forward_with_sound_ulysses_error_mentions_combined_sequence(monkeypatch: pytest.MonkeyPatch) -> None: @@ -227,7 +266,7 @@ def test_forward_with_sound_ulysses_error_mentions_combined_sequence(monkeypatch ) -def test_compute_rope_freqs_places_text_video_and_sound_positions() -> None: +def test_compute_rope_freqs_places_text_video_action_and_sound_positions() -> None: from vllm_omni.diffusion.models.cosmos3.transformer_cosmos3 import Cosmos3VFMTransformer class FakeRotary: @@ -275,9 +314,11 @@ def __call__(self, x, position_ids): fps=24.0, device=torch.device("cpu"), dtype=torch.float32, + t_action=2, + action_start_frame_offset=1, t_sound=1, ) _, gen_pos = rotary.position_ids - assert gen_pos.shape == (3, 1, 3) - assert gen_pos[0, 0].tolist() == [102, 103, 102] + assert gen_pos.shape == (3, 1, 5) + assert gen_pos[0, 0].tolist() == [102, 103, 103, 104, 102] diff --git a/tests/entrypoints/openai_api/test_video_server.py b/tests/entrypoints/openai_api/test_video_server.py index de1f14c7455..072fd25f191 100644 --- a/tests/entrypoints/openai_api/test_video_server.py +++ b/tests/entrypoints/openai_api/test_video_server.py @@ -627,6 +627,115 @@ async def _generate(prompt, request_id, sampling_params_list): assert completed["stage_durations"] == {"diffuse": 2.5, "vae.decode": 0.3} assert 
completed["peak_memory_mb"] == 4096.5 + assert completed["action"] is None + + +def test_video_generation_response_exposes_action_payload(mocker: MockerFixture): + engine = FakeAsyncOmni() + handler = OmniOpenAIServingVideo.for_diffusion( + diffusion_engine=engine, + model_name="Cosmos3-8B-UVA", + ) + + async def _generate(prompt, request_id, sampling_params_list): + del prompt, request_id, sampling_params_list + import numpy as np + + yield MockVideoResult( + [object()], + custom_output={ + "action": np.array([[[1.5, 2.5], [3.5, 4.5]]], dtype=np.float32), + "raw_action_dim": 2, + "action_mode": "policy", + "domain_id": 7, + }, + ) + + engine.generate = _generate + mocker.patch( + "vllm_omni.entrypoints.openai.serving_video.encode_video_base64", + return_value="encoded-video", + ) + + response = asyncio.run( + handler.generate_videos( + VideoGenerationRequest(prompt="predict actions"), + "action-json", + ) + ) + + action = response.data[0].action + assert action is not None + assert action.data == [[1.5, 2.5], [3.5, 4.5]] + assert action.shape == [2, 2] + assert action.dtype == "float32" + assert action.raw_action_dim == 2 + assert action.action_mode == "policy" + assert action.domain_id == 7 + assert response.model_dump(mode="json")["data"][0]["action"]["data"] == [[1.5, 2.5], [3.5, 4.5]] + + +def test_video_job_persists_action_metadata(test_client, mocker: MockerFixture): + engine = test_client.app.state.openai_serving_video._engine_client + + async def _generate(prompt, request_id, sampling_params_list): + import numpy as np + + engine.captured_prompt = prompt + engine.captured_sampling_params_list = sampling_params_list + yield MockVideoResult( + [object()], + custom_output={ + "action": np.array([[[1.0, 2.0], [3.0, 4.0]]], dtype=np.float32), + "raw_action_dim": 2, + "action_mode": "policy", + "domain_id": 7, + }, + ) + + engine.generate = _generate + mocker.patch( + "vllm_omni.entrypoints.openai.serving_video._encode_video_bytes", + return_value=b"fake-video", 
def test_action_extraction_accepts_unbatched_action():
    """A 2-D action array (no leading batch axis) is extracted with its rows intact."""
    import numpy as np

    rows = [[1.0, 2.0], [3.0, 4.0]]
    custom_output = {
        "action": np.array(rows, dtype=np.float32),
        "raw_action_dim": 2,
        "action_mode": "policy",
        "domain_id": 7,
    }
    result = MockVideoResult([object()], custom_output=custom_output)

    extracted = OmniOpenAIServingVideo._extract_action_outputs(result, expected_count=1)

    first = extracted[0]
    assert first is not None
    assert first.data == rows
    assert first.shape == [2, 2]
test_client.app.state.openai_serving_video._engine_client + captured = engine.captured_sampling_params_list[0] + assert captured.extra_args["action"] == action + assert captured.extra_args["action_mode"] == "forward_dynamics" + + # --------------------------------------------------------------------------- # Sync endpoint tests (POST /v1/videos/sync) # --------------------------------------------------------------------------- diff --git a/vllm_omni/diffusion/models/cosmos3/action.py b/vllm_omni/diffusion/models/cosmos3/action.py new file mode 100644 index 00000000000..80a1bc7a9b8 --- /dev/null +++ b/vllm_omni/diffusion/models/cosmos3/action.py @@ -0,0 +1,214 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Action-token helpers for Cosmos3 UVA/action generation.""" + +from __future__ import annotations + +from typing import Any + +import numpy as np +import torch + +ACTION_MODE_POLICY = "policy" +ACTION_MODE_FORWARD_DYNAMICS = "forward_dynamics" +ACTION_MODE_INVERSE_DYNAMICS = "inverse_dynamics" +ACTION_MODES = { + ACTION_MODE_POLICY, + ACTION_MODE_FORWARD_DYNAMICS, + ACTION_MODE_INVERSE_DYNAMICS, +} + + +EMBODIMENT_TO_DOMAIN_ID: dict[str, int] = { + "no_action": 0, + "av": 1, + "camera_pose": 2, + "hand_pose": 3, + "pusht": 4, + "libero": 5, + "umi": 6, + "bridge_orig_lerobot": 7, + "droid_lerobot": 8, + "robomind-franka": 8, + "galbot": 9, + "robomind-franka-dual": 12, + "robomind-ur": 13, + "agibotworld": 15, + "agibot_gear_gripper": 15, + "agibot_gear_gripper_ext": 15, + "fractal": 20, +} + + +VIDEO_RES_SIZE_INFO: dict[str, dict[str, tuple[int, int]]] = { + "256": { + "1,1": (256, 256), + "4,3": (320, 256), + "3,4": (256, 320), + "16,9": (320, 192), + "9,16": (192, 320), + }, + "480": { + "1,1": (640, 640), + "4,3": (736, 544), + "3,4": (544, 736), + "16,9": (832, 480), + "9,16": (480, 832), + }, + "704": { + "1,1": (960, 960), + "4,3": (1088, 832), + "3,4": (832, 1088), + "16,9": (1280, 704), + 
"9,16": (704, 1280), + }, + "720": { + "1,1": (960, 960), + "4,3": (1104, 832), + "3,4": (832, 1104), + "16,9": (1280, 720), + "9,16": (720, 1280), + }, +} + + +def normalize_action_mode(mode: Any) -> str | None: + if mode is None: + return None + normalized = str(mode).strip().lower() + if not normalized: + return None + if normalized not in ACTION_MODES: + raise ValueError(f"Unsupported Cosmos3 action_mode={mode!r}; expected one of {sorted(ACTION_MODES)}.") + return normalized + + +def resolve_domain_id( + *, + domain_id: Any = None, + domain_name: Any = None, + require_explicit: bool = False, +) -> int: + if domain_id is not None: + resolved = int(domain_id) + if resolved < 0: + raise ValueError(f"Cosmos3 domain_id must be non-negative, got {resolved}.") + return resolved + + if domain_name is None or str(domain_name).strip() == "": + if require_explicit: + raise ValueError( + "Cosmos3 action generation requires extra_args['domain_id'] or non-empty extra_args['domain_name']." + ) + return 0 + + key = str(domain_name).strip().lower() + if key not in EMBODIMENT_TO_DOMAIN_ID: + raise ValueError( + f"Unknown Cosmos3 action domain_name={domain_name!r}; " + f"expected one of {sorted(EMBODIMENT_TO_DOMAIN_ID)} or pass domain_id directly." 
def vision_condition_indexes(mode: str, video_length: int, temporal_compression_factor: int) -> list[int]:
    """Return the latent-frame indexes that act as conditioning for ``mode``.

    The number of latent frames is ``(video_length - 1) // temporal_compression_factor + 1``.
    ``policy`` and ``forward_dynamics`` condition on latent frame 0 only, while
    ``inverse_dynamics`` conditions on every latent frame.

    Raises:
        ValueError: propagated from ``normalize_action_mode`` for an unsupported mode.
        AssertionError: if the normalized mode is none of the three handled modes.
    """
    resolved = normalize_action_mode(mode)
    # Computed up front (before branching), exactly like the other modes need it.
    num_latent = 1 + (video_length - 1) // temporal_compression_factor
    if resolved == ACTION_MODE_INVERSE_DYNAMICS:
        return [*range(num_latent)]
    if resolved == ACTION_MODE_POLICY or resolved == ACTION_MODE_FORWARD_DYNAMICS:
        return [0]
    raise AssertionError(f"Unexpected action mode: {resolved!r}")
def load_action_tensor(action: Any = None) -> torch.Tensor:
    """Convert a user-supplied action trajectory into a float32 ``[T, D]`` tensor.

    Args:
        action: A ``torch.Tensor`` or anything ``numpy.asarray`` can turn into a
            numeric array (for example the nested list decoded from a JSON
            array). A 3-D input whose first axis has size 1 is accepted and
            that axis is dropped.

    Returns:
        A 2-D ``torch.float32`` tensor. Tensor inputs are detached and cast;
        other inputs are converted through NumPy.

    Raises:
        ValueError: if ``action`` is ``None``, cannot be converted to a numeric
            tensor (non-numeric or ragged data), or is not 2-D after the
            optional squeeze.
    """
    if action is None:
        raise ValueError(
            "Cosmos3 forward_dynamics action mode requires extra_args['action'] "
            "(a JSON array / nested list / tensor of shape [T, D])."
        )
    if isinstance(action, torch.Tensor):
        tensor = action.detach().to(dtype=torch.float32)
    else:
        # Request payloads are untrusted: strings, dicts or ragged lists would
        # otherwise surface as a raw TypeError / NumPy-internal error. Report
        # them with the same ValueError contract as the other checks here.
        try:
            tensor = torch.as_tensor(np.asarray(action), dtype=torch.float32)
        except (TypeError, ValueError, RuntimeError) as err:
            raise ValueError(
                "Cosmos3 action must be a numeric array of shape [T, D]; "
                f"could not convert input of type {type(action).__name__} to a float tensor."
            ) from err
    # Tolerate a single-sample batch axis: [1, T, D] -> [T, D].
    if tensor.ndim == 3 and tensor.shape[0] == 1:
        tensor = tensor.squeeze(0)
    if tensor.ndim != 2:
        raise ValueError(f"Cosmos3 action must have shape [T, D], got {tuple(tensor.shape)}.")
    return tensor
OmniDiffusionRequest) -> dict[str, Any]: + extra = getattr(getattr(request, "sampling_params", None), "extra_args", None) + return extra if isinstance(extra, dict) else {} + + def _request_action_mode(request: OmniDiffusionRequest) -> str | None: + return normalize_action_mode(_extra_args(request).get("action_mode")) + + def _set_action_size_from_image(request: OmniDiffusionRequest, image: PIL.Image.Image) -> tuple[int, int]: + sp = request.sampling_params + if sp.height is not None and sp.width is not None: + return int(sp.height), int(sp.width) + + extra = _extra_args(request) + resolution = extra.get("resolution", extra.get("image_size", 480)) + target_w, target_h = find_closest_target_size(image.height, image.width, resolution) + if sp.height is None: + sp.height = target_h + if sp.width is None: + sp.width = target_w + return int(sp.height), int(sp.width) + def _pil_to_rgb(value: Any) -> PIL.Image.Image: if isinstance(value, str): return PIL.Image.open(value).convert("RGB") if isinstance(value, PIL.Image.Image): return value.convert("RGB") - raise TypeError(f"Cosmos3 preprocessing expected PIL image or image path, got {type(value)!r}.") + raise TypeError(f"Cosmos3 action preprocessing expected PIL image or image path, got {type(value)!r}.") + + def _resize_and_pad_action_image(image: PIL.Image.Image, target_h: int, target_w: int) -> PIL.Image.Image: + scale = min(target_w / image.width, target_h / image.height, 1.0) + resize_w = max(1, int(scale * image.width + 0.5)) + resize_h = max(1, int(scale * image.height + 0.5)) + if (resize_w, resize_h) != image.size: + image = image.resize((resize_w, resize_h), PIL.Image.Resampling.BICUBIC) + + array = np.asarray(image) + pad_h = target_h - resize_h + pad_w = target_w - resize_w + if pad_h < 0 or pad_w < 0: + raise ValueError( + f"Cosmos3 action image resize exceeded target size: resized={(resize_h, resize_w)}, " + f"target={(target_h, target_w)}." 
+ ) + if pad_h == 0 and pad_w == 0: + return image + pad_mode = "reflect" if pad_h < resize_h and pad_w < resize_w else "edge" + padded = np.pad(array, ((0, pad_h), (0, pad_w), (0, 0)), mode=pad_mode) + return PIL.Image.fromarray(padded) + + def _preprocess_action_image(image: PIL.Image.Image, target_h: int, target_w: int) -> torch.Tensor: + image = _resize_and_pad_action_image(image, target_h, target_w) + return video_processor.preprocess(image, height=target_h, width=target_w) + + def _preprocess_action_video(frames: list[Any], target_h: int, target_w: int) -> torch.Tensor: + if not frames: + raise ValueError("Cosmos3 action video input must contain at least one frame.") + processed = [_preprocess_action_image(_pil_to_rgb(frame), target_h, target_w).squeeze(0) for frame in frames] + return torch.stack(processed, dim=1).unsqueeze(0).contiguous() def pre_process_func(request: OmniDiffusionRequest) -> OmniDiffusionRequest: + action_mode = _request_action_mode(request) if is_guardrails_enabled(od_config, request.sampling_params): for prompt in request.prompts: text = prompt if isinstance(prompt, str) else prompt.get("prompt", "") @@ -116,39 +182,63 @@ def pre_process_func(request: OmniDiffusionRequest) -> OmniDiffusionRequest: continue multi_modal_data = prompt.get("multi_modal_data", {}) or {} raw_image = multi_modal_data.get("image") - if raw_image is None: + raw_video = multi_modal_data.get("video") + if raw_image is None and not (action_mode is not None and raw_video is not None): continue if "additional_information" not in prompt: prompt["additional_information"] = {} - image = _pil_to_rgb(raw_image) + if raw_image is None: + if not isinstance(raw_video, list) or not raw_video: + raise TypeError("Cosmos3 action video input must be a non-empty list of PIL images or image paths.") + image = _pil_to_rgb(raw_video[0]) + else: + image = _pil_to_rgb(raw_image) # Auto-calculate H/W from aspect ratio (720p max area) if request.sampling_params.height is None or 
request.sampling_params.width is None: - max_area = 720 * 1280 - aspect_ratio = image.height / image.width - mod_value = 16 - height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value - width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value - if request.sampling_params.height is None: - request.sampling_params.height = height - if request.sampling_params.width is None: - request.sampling_params.width = width + if action_mode is not None: + _set_action_size_from_image(request, image) + else: + max_area = 720 * 1280 + aspect_ratio = image.height / image.width + mod_value = 16 + height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value + width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value + if request.sampling_params.height is None: + request.sampling_params.height = height + if request.sampling_params.width is None: + request.sampling_params.width = width target_w = request.sampling_params.width target_h = request.sampling_params.height - scale = max(target_w / image.width, target_h / image.height) - resize_w = int(np.ceil(scale * image.width)) - resize_h = int(np.ceil(scale * image.height)) - image = image.resize((resize_w, resize_h), PIL.Image.Resampling.LANCZOS) - left = (resize_w - target_w) // 2 - top = (resize_h - target_h) // 2 - image = image.crop((left, top, left + target_w, top + target_h)) - - prompt["additional_information"]["preprocessed_image"] = video_processor.preprocess( - image, height=target_h, width=target_w - ) + if action_mode is not None: + prompt["additional_information"]["preprocessed_image"] = _preprocess_action_image( + image, + int(target_h), + int(target_w), + ) + else: + scale = max(target_w / image.width, target_h / image.height) + resize_w = int(np.ceil(scale * image.width)) + resize_h = int(np.ceil(scale * image.height)) + image = image.resize((resize_w, resize_h), PIL.Image.Resampling.LANCZOS) + left = (resize_w - target_w) // 2 + top = (resize_h - target_h) // 2 
+ image = image.crop((left, top, left + target_w, top + target_h)) + + prompt["additional_information"]["preprocessed_image"] = video_processor.preprocess( + image, height=target_h, width=target_w + ) + if action_mode is not None and raw_video is not None: + if not isinstance(raw_video, list): + raise TypeError("Cosmos3 action video input must be a list of PIL images or image paths.") + prompt["additional_information"]["preprocessed_video"] = _preprocess_action_video( + raw_video, + int(target_h), + int(target_w), + ) request.prompts[i] = prompt return request @@ -421,11 +511,17 @@ def _remap_ckpt_key(key: str) -> str | None: "time_embedder.", "audio_proj_in.", "audio_proj_out.", + "action_proj_in.", + "action_proj_out.", ) ): return f"transformer.{k}" if k in ("audio_modality_embed", "audio_modality_embed.weight"): return "transformer.audio_modality_embed" + if k in ("action_modality_embed", "action_modality_embed.weight"): + return "transformer.action_modality_embed" + if k.startswith("action_pos_embed."): + return None # Skip lm_head if k.startswith("lm_head."): @@ -528,13 +624,22 @@ def _remapped_weights() -> Iterable[tuple[str, torch.Tensor]]: f"the checkpoint is missing sound weights for {missing}. " "Use a sound-capable transformer checkpoint." ) + if getattr(self.transformer, "action_gen", False): + action_markers = ("action_proj_in.", "action_proj_out.", "action_modality_embed") + missing = [marker.rstrip(".") for marker in action_markers if not any(marker in name for name in loaded)] + if missing: + raise ValueError( + "Cosmos3 transformer config enables action generation, but " + f"the checkpoint is missing action weights for {missing}. " + "Use an action-capable transformer checkpoint." + ) return loaded def predict_noise(self, **kwargs) -> torch.Tensor | tuple[torch.Tensor, ...]: """Override CFGParallelMixin.predict_noise for Cosmos3. 
The transformer returns the raw prediction: video-only as a tensor, - or a tuple in video, sound order for sound generation. + or a tuple in video, action, sound order for multimodal generation. """ return self.transformer(**kwargs) @@ -611,6 +716,12 @@ def _is_sound_request(cls, prompt_data, sp) -> bool: return True return False + @classmethod + def _get_action_mode(cls, prompt_data, sp) -> str | None: + return normalize_action_mode( + cls._get_sp_param(sp, "action_mode", cls._get_prompt_param(prompt_data, "action_mode", None)) + ) + def _get_sound_tokenizer(self): if self._sound_tokenizer is None: from .sound_tokenizer import Cosmos3SoundTokenizer @@ -1017,6 +1128,30 @@ def _encode_conditioning_video( return latent.to(self.dtype) + def _encode_video_tensor(self, video_tensor: torch.Tensor) -> torch.Tensor: + """VAE-encode a preprocessed pixel video [1, 3, T, H, W].""" + if video_tensor.ndim == 4: + video_tensor = video_tensor.unsqueeze(0) + if video_tensor.ndim != 5: + raise ValueError(f"Cosmos3 video tensor must have shape [1, 3, T, H, W], got {tuple(video_tensor.shape)}.") + if video_tensor.shape[0] != 1 or video_tensor.shape[1] != 3: + raise ValueError(f"Cosmos3 video tensor must have shape [1, 3, T, H, W], got {tuple(video_tensor.shape)}.") + + video = video_tensor.to(device=self.device, dtype=self.vae.dtype) + latent = self.vae.encode(video).latent_dist.mode() + + if hasattr(self.vae.config, "latents_mean") and hasattr(self.vae.config, "latents_std"): + latents_mean = ( + torch.tensor(self.vae.config.latents_mean).view(1, -1, 1, 1, 1).to(latent.device, latent.dtype) + ) + latents_std = torch.tensor(self.vae.config.latents_std).view(1, -1, 1, 1, 1).to(latent.device, latent.dtype) + latent = (latent - latents_mean) / latents_std + else: + scaling_factor = getattr(self.vae.config, "scaling_factor", 1.0) + latent = latent * scaling_factor + + return latent.to(self.dtype) + def _prepare_latents_i2v( self, image_tensor: torch.Tensor, @@ -1053,6 +1188,95 @@ def 
_prepare_latents_i2v( velocity_mask = 1.0 - condition_mask return latents, velocity_mask, image_latent + def _prepare_latents_action_video( + self, + video_tensor: torch.Tensor, + mode: str, + height: int, + width: int, + num_frames: int, + generator: torch.Generator, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Prepare video latents for action modes with mode-specific conditioning.""" + del height, width + C = self.transformer.latent_channel_size + T_lat = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + H_lat = video_tensor.shape[-2] // self.vae_scale_factor_spatial + W_lat = video_tensor.shape[-1] // self.vae_scale_factor_spatial + + noise = randn_tensor( + (1, C, T_lat, H_lat, W_lat), + generator=generator, + device=self.device, + dtype=self.dtype, + ) + cond_latent = self._encode_video_tensor(video_tensor) + if cond_latent.shape[2:] != noise.shape[2:]: + raise ValueError( + "Cosmos3 action video latent shape mismatch: " + f"encoded={tuple(cond_latent.shape)}, expected={tuple(noise.shape)}." 
+ ) + condition_mask = build_vision_condition_mask( + mode, + num_frames, + self.vae_scale_factor_temporal, + device=self.device, + dtype=self.dtype, + ) + latents = condition_mask * cond_latent + (1.0 - condition_mask) * noise + velocity_mask = 1.0 - condition_mask + return latents, velocity_mask, cond_latent + + def _prepare_action_latents( + self, + *, + mode: str, + action_chunk_size: int, + raw_action_dim: int | None, + generator: torch.Generator, + sp, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]: + action_dim = int(getattr(self.transformer, "action_dim", 64)) + if mode == ACTION_MODE_FORWARD_DYNAMICS: + action = load_action_tensor(self._get_sp_param(sp, "action", None)) + if action.shape[0] < action_chunk_size: + pad = action[-1:].repeat(action_chunk_size - action.shape[0], 1) + action = torch.cat([action, pad], dim=0) + elif action.shape[0] > action_chunk_size: + action = action[:action_chunk_size] + if raw_action_dim is None: + raw_action_dim = int(action.shape[-1]) + clean_action = pad_action_to_dim(action, action_dim) + else: + if raw_action_dim is None: + raise ValueError( + "Cosmos3 action_mode='policy' and 'inverse_dynamics' require extra_args['raw_action_dim']." 
+ ) + clean_action = torch.zeros(action_chunk_size, action_dim, dtype=torch.float32) + + raw_action_dim = int(raw_action_dim) + if raw_action_dim <= 0 or raw_action_dim > action_dim: + raise ValueError(f"Cosmos3 raw_action_dim must be in [1, {action_dim}], got {raw_action_dim}.") + + clean_action = clean_action.to(device=self.device, dtype=self.dtype).unsqueeze(0) + condition_mask = build_action_condition_mask( + mode, + action_chunk_size, + device=self.device, + dtype=self.dtype, + ) + noise = randn_tensor( + (1, action_chunk_size, action_dim), + generator=generator, + device=self.device, + dtype=self.dtype, + ) + noise[:, :, raw_action_dim:] = 0 + clean_action[:, :, raw_action_dim:] = 0 + action_latents = condition_mask * clean_action + (1.0 - condition_mask) * noise + action_velocity_mask = 1.0 - condition_mask + return action_latents, action_velocity_mask, clean_action, raw_action_dim + # -- Denoising loop (shared by T2V and I2V) ----------------------------- def diffuse( @@ -1066,11 +1290,15 @@ def diffuse( guidance_scale: float, shared_kwargs: dict, *, + action_latents: torch.Tensor | None = None, + action_velocity_mask: torch.Tensor | None = None, + action_condition_latents: torch.Tensor | None = None, sound_latents: torch.Tensor | None = None, velocity_mask: torch.Tensor | None = None, image_latent: torch.Tensor | None = None, condition_latents: torch.Tensor | None = None, guidance_interval: tuple[float, float] | None = None, + raw_action_dim: int | None = None, ) -> torch.Tensor | tuple[torch.Tensor, ...]: """Denoising loop with 3-mode CFG support (parallel, sequential, none). @@ -1115,10 +1343,13 @@ def _cfg_active_at(t: torch.Tensor) -> bool: # scheduler with cross-element dependencies (e.g. per-modality timestep). 
def _pack_joint( video_tensor: torch.Tensor, + action_tensor: torch.Tensor | None = None, sound_tensor: torch.Tensor | None = None, ): batch = video_tensor.shape[0] tensors = [video_tensor] + if action_tensor is not None: + tensors.append(action_tensor) if sound_tensor is not None: tensors.append(sound_tensor) flats = [tensor.reshape(batch, -1) for tensor in tensors] @@ -1138,57 +1369,86 @@ def _unpack_joint( def _split_noise_pred( noise_pred: torch.Tensor | tuple[torch.Tensor, ...], - ) -> tuple[torch.Tensor, torch.Tensor | None]: + ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]: + has_action = action_latents is not None has_sound = sound_latents is not None - if not has_sound: + if not has_action and not has_sound: if isinstance(noise_pred, tuple): raise ValueError("Cosmos3 video-only diffusion received tuple predictions.") - return noise_pred, None + return noise_pred, None, None if not isinstance(noise_pred, tuple): raise ValueError("Cosmos3 multimodal diffusion expects transformer predictions as a tuple.") - if len(noise_pred) != 2: - raise ValueError(f"Cosmos3 sound diffusion expected 2 predictions, got {len(noise_pred)}.") - return noise_pred[0], noise_pred[1] + expected = 1 + int(has_action) + int(has_sound) + if len(noise_pred) != expected: + raise ValueError( + f"Cosmos3 multimodal diffusion expected {expected} predictions, got {len(noise_pred)}." 
+ ) + video_pred = noise_pred[0] + idx = 1 + action_pred = noise_pred[idx] if has_action else None + if has_action: + idx += 1 + sound_pred = noise_pred[idx] if has_sound else None + return video_pred, action_pred, sound_pred def _step( noise_pred: torch.Tensor | tuple[torch.Tensor, ...], t: torch.Tensor, latents: torch.Tensor, + action_latents: torch.Tensor | None, sound_latents: torch.Tensor | None, ) -> torch.Tensor | tuple[torch.Tensor, ...]: - video_pred, sound_pred = _split_noise_pred(noise_pred) + video_pred, action_pred, sound_pred = _split_noise_pred(noise_pred) if velocity_mask is not None: video_pred = video_pred * velocity_mask - if sound_latents is None: + if action_pred is not None and action_velocity_mask is not None: + action_pred = action_pred * action_velocity_mask + if raw_action_dim is not None and 0 < raw_action_dim < action_pred.shape[-1]: + action_pred[..., raw_action_dim:] = 0 + if action_latents is None and sound_latents is None: latents = self.scheduler.step(video_pred, t, latents, return_dict=False)[0] else: - packed_noise, shapes, numels = _pack_joint(video_pred, sound_pred) - packed_latents, _, _ = _pack_joint(latents, sound_latents) + packed_noise, shapes, numels = _pack_joint(video_pred, action_pred, sound_pred) + packed_latents, _, _ = _pack_joint(latents, action_latents, sound_latents) packed_next = self.scheduler.step(packed_noise, t, packed_latents, return_dict=False)[0] unpacked = _unpack_joint(packed_next, shapes, numels) latents = unpacked[0] + idx = 1 + if action_latents is not None: + action_latents = unpacked[idx] + idx += 1 if sound_latents is not None: - sound_latents = unpacked[1] + sound_latents = unpacked[idx] if condition_latents is not None and velocity_mask is not None: latents = velocity_mask * latents + (1.0 - velocity_mask) * condition_latents elif image_latent is not None: latents[:, :, 0:1, :, :] = image_latent + if action_latents is not None and action_condition_latents is not None and action_velocity_mask is 
not None: + action_latents = ( + action_velocity_mask * action_latents + (1.0 - action_velocity_mask) * action_condition_latents + ) outputs = [latents] + if action_latents is not None: + outputs.append(action_latents) if sound_latents is not None: outputs.append(sound_latents) return outputs[0] if len(outputs) == 1 else tuple(outputs) def _assign_step_out(step_out: torch.Tensor | tuple[torch.Tensor, ...]) -> None: - nonlocal latents, sound_latents - if sound_latents is None: + nonlocal latents, action_latents, sound_latents + if action_latents is None and sound_latents is None: assert isinstance(step_out, torch.Tensor) latents = step_out return if not isinstance(step_out, tuple): raise ValueError("Cosmos3 multimodal diffusion step returned a non-tuple result.") latents = step_out[0] + idx = 1 + if action_latents is not None: + action_latents = step_out[idx] + idx += 1 if sound_latents is not None: - sound_latents = step_out[1] + sound_latents = step_out[idx] if cfg_parallel: for t in self.progress_bar(timesteps): @@ -1206,6 +1466,7 @@ def _assign_step_out(step_out: torch.Tensor | tuple[torch.Tensor, ...]) -> None: timestep=timestep, text_ids=cond_ids, text_mask=cond_mask, + action_latents=action_latents, sound_latents=sound_latents, **shared_kwargs, ), @@ -1214,12 +1475,13 @@ def _assign_step_out(step_out: torch.Tensor | tuple[torch.Tensor, ...]) -> None: timestep=timestep, text_ids=uncond_ids, text_mask=uncond_mask, + action_latents=action_latents, sound_latents=sound_latents, **shared_kwargs, ), cfg_normalize=False, ) - _assign_step_out(_step(noise_pred, t, latents, sound_latents)) + _assign_step_out(_step(noise_pred, t, latents, action_latents, sound_latents)) elif do_cfg: cond_cache: tuple = (None, None) @@ -1237,6 +1499,7 @@ def _assign_step_out(step_out: torch.Tensor | tuple[torch.Tensor, ...]) -> None: timestep=timestep, text_ids=cond_ids, text_mask=cond_mask, + action_latents=action_latents, sound_latents=sound_latents, **shared_kwargs, ) @@ -1250,6 
+1513,7 @@ def _assign_step_out(step_out: torch.Tensor | tuple[torch.Tensor, ...]) -> None: timestep=timestep, text_ids=uncond_ids, text_mask=uncond_mask, + action_latents=action_latents, sound_latents=sound_latents, **shared_kwargs, ) @@ -1263,7 +1527,7 @@ def _assign_step_out(step_out: torch.Tensor | tuple[torch.Tensor, ...]) -> None: else: noise_pred = noise_cond - _assign_step_out(_step(noise_pred, t, latents, sound_latents)) + _assign_step_out(_step(noise_pred, t, latents, action_latents, sound_latents)) else: for t in self.progress_bar(timesteps): @@ -1273,12 +1537,15 @@ def _assign_step_out(step_out: torch.Tensor | tuple[torch.Tensor, ...]) -> None: timestep=timestep, text_ids=cond_ids, text_mask=cond_mask, + action_latents=action_latents, sound_latents=sound_latents, **shared_kwargs, ) - _assign_step_out(_step(noise_pred, t, latents, sound_latents)) + _assign_step_out(_step(noise_pred, t, latents, action_latents, sound_latents)) outputs = [latents] + if action_latents is not None: + outputs.append(action_latents) if sound_latents is not None: outputs.append(sound_latents) return outputs[0] if len(outputs) == 1 else tuple(outputs) @@ -1300,15 +1567,29 @@ def forward( prompt = prompt_data negative_prompt = None image_tensor = None + action_video_tensor = None else: prompt = prompt_data.get("prompt", "") negative_prompt = prompt_data.get("negative_prompt") additional_info = prompt_data.get("additional_information", {}) or {} image_tensor = additional_info.get("preprocessed_image") + action_video_tensor = additional_info.get("preprocessed_video") sp = req.sampling_params is_t2i = self._is_t2i_request(req) sound_enabled = self._is_sound_request(prompt_data, sp) + action_mode = self._get_action_mode(prompt_data, sp) + action_enabled = action_mode is not None + if action_enabled and is_t2i: + raise ValueError("Cosmos3 action generation is supported only for video outputs.") + if action_enabled and sound_enabled: + raise ValueError("Cosmos3 action+sound joint 
generation is not supported in this phase.") + if action_enabled and not getattr(self.transformer, "action_gen", False): + raise ValueError( + "Cosmos3 action generation was requested, but the transformer was " + "initialized without action modules. Check that the checkpoint config " + "enables action_gen and includes action weights." + ) if sound_enabled and is_t2i: raise ValueError( "Cosmos3 sound generation is supported only for video outputs in " @@ -1349,6 +1630,36 @@ def forward( default_guidance_interval = None batch_size = 1 # Existing video pipeline assumes B=1. + if action_enabled: + action_chunk_param = self._get_sp_param(sp, "action_chunk_size", None) + if action_chunk_param is not None: + action_chunk_size = int(action_chunk_param) + if sp.num_frames is None: + num_frames = action_chunk_size + 1 + elif sp.num_frames is None: + action_chunk_size = 16 + num_frames = action_chunk_size + 1 + else: + action_chunk_size = int(num_frames) - 1 + if action_chunk_size <= 0: + raise ValueError(f"Cosmos3 action_chunk_size must be positive, got {action_chunk_size}.") + if num_frames not in (action_chunk_size, action_chunk_size + 1): + raise ValueError( + "Cosmos3 action requests require num_frames to equal action_chunk_size " + f"or action_chunk_size + 1; got num_frames={num_frames}, action_chunk_size={action_chunk_size}." + ) + num_inference_steps = sp.num_inference_steps or 30 + guidance_scale = sp.guidance_scale if sp.guidance_scale is not None else 1.0 + default_flow_shift = 5.0 + + domain_id = None + if action_enabled: + domain_id = resolve_domain_id( + domain_id=self._get_sp_param(sp, "domain_id", None), + domain_name=self._get_sp_param(sp, "domain_name", None), + require_explicit=True, + ) + # Runtime controls: prefer ``extra_args`` (OpenAI endpoints write # there) over direct attrs. 
flow_shift_target = float(self._get_sp_param(sp, "flow_shift", default_flow_shift)) @@ -1361,6 +1672,23 @@ def forward( ) use_system_prompt = bool(self._get_sp_param(sp, "use_system_prompt", False)) + if action_enabled and action_video_tensor is None: + extra_action_video = self._get_sp_param(sp, "action_video", None) + if isinstance(extra_action_video, torch.Tensor): + action_video_tensor = extra_action_video + if action_enabled and isinstance(action_video_tensor, torch.Tensor): + if action_video_tensor.ndim == 4: + action_video_tensor = action_video_tensor.unsqueeze(0) + if action_video_tensor.ndim != 5: + raise ValueError( + "Cosmos3 extra_args['action_video'] must have shape [1, 3, T, H, W] " + f"or [3, T, H, W], got {tuple(action_video_tensor.shape)}." + ) + if sp.height is None: + height = int(action_video_tensor.shape[-2]) + if sp.width is None: + width = int(action_video_tensor.shape[-1]) + self._guidance_scale = guidance_scale self._num_timesteps = num_inference_steps @@ -1396,7 +1724,58 @@ def forward( # batching B=N together would require expanding text K/V (UND # pathway is text-only and cached) and is left as a future # optimization. - if image_tensor is not None and not is_t2i: + action_latents = None + action_velocity_mask = None + action_condition_latents = None + raw_action_dim = None + action_offset = 1 + if action_enabled: + if action_video_tensor is not None and action_video_tensor.ndim == 4: + action_video_tensor = action_video_tensor.unsqueeze(0) + if action_video_tensor is not None and action_video_tensor.ndim != 5: + raise ValueError( + "Cosmos3 action video tensor must have shape [1, 3, T, H, W] " + f"or [3, T, H, W], got {tuple(action_video_tensor.shape)}." 
+ ) + if action_video_tensor is not None and action_video_tensor.shape[2] < num_frames: + pad = action_video_tensor[:, :, -1:].repeat(1, 1, num_frames - action_video_tensor.shape[2], 1, 1) + action_video_tensor = torch.cat([action_video_tensor, pad], dim=2) + elif action_video_tensor is not None and action_video_tensor.shape[2] > num_frames: + action_video_tensor = action_video_tensor[:, :, :num_frames] + + if action_mode == ACTION_MODE_INVERSE_DYNAMICS and action_video_tensor is None: + raise ValueError("Cosmos3 inverse_dynamics action mode requires multi_modal_data['video'].") + if action_mode in {ACTION_MODE_POLICY, ACTION_MODE_FORWARD_DYNAMICS} and image_tensor is None: + if action_video_tensor is None: + raise ValueError( + f"Cosmos3 action_mode={action_mode!r} requires multi_modal_data['image'] " + "or multi_modal_data['video']." + ) + image_tensor = action_video_tensor[:, :, 0] + + raw_action_dim_param = self._get_sp_param(sp, "raw_action_dim", None) + raw_action_dim = int(raw_action_dim_param) if raw_action_dim_param is not None else None + action_prepared = self._prepare_action_latents( + mode=action_mode, + action_chunk_size=action_chunk_size, + raw_action_dim=raw_action_dim, + generator=generator, + sp=sp, + ) + action_latents, action_velocity_mask, action_condition_latents, raw_action_dim = action_prepared + action_offset = action_start_frame_offset(action_mode, action_chunk_size, num_frames) + + if action_enabled and action_video_tensor is not None: + latents, velocity_mask, condition_latents = self._prepare_latents_action_video( + action_video_tensor, + action_mode, + height, + width, + num_frames, + generator, + ) + image_latent = condition_latents[:, :, 0:1] + elif image_tensor is not None and not is_t2i: latents, velocity_mask, image_latent = self._prepare_latents_i2v( image_tensor, height, @@ -1427,6 +1806,13 @@ def forward( shared_kwargs = dict(video_shape=video_shape, fps=frame_rate) if velocity_mask is not None: 
shared_kwargs["noisy_frame_mask"] = velocity_mask + if action_enabled: + shared_kwargs.update( + action_domain_ids=torch.tensor([domain_id], dtype=torch.long, device=self.device), + action_noisy_mask=action_velocity_mask, + action_start_frame_offset=action_offset, + action_fps=float(self._get_sp_param(sp, "action_fps", frame_rate) or frame_rate), + ) def _run_diffusion(start_latents): self.scheduler.set_timesteps(num_inference_steps, device=self.device) @@ -1439,11 +1825,15 @@ def _run_diffusion(start_latents): uncond_mask=uncond_mask, guidance_scale=guidance_scale, shared_kwargs=shared_kwargs, + action_latents=action_latents, + action_velocity_mask=action_velocity_mask, + action_condition_latents=action_condition_latents, sound_latents=sound_latents, velocity_mask=velocity_mask, image_latent=image_latent, condition_latents=condition_latents, guidance_interval=guidance_interval, + raw_action_dim=raw_action_dim, ) if is_t2i and batch_size > 1: @@ -1461,7 +1851,11 @@ def _run_diffusion(start_latents): latents = torch.cat(samples, dim=0) else: diffusion_output = _run_diffusion(latents) - if sound_enabled: + if action_enabled and sound_enabled: + latents, action_latents, sound_latents = diffusion_output + elif action_enabled: + latents, action_latents = diffusion_output + elif sound_enabled: latents, sound_latents = diffusion_output else: latents = diffusion_output @@ -1483,4 +1877,18 @@ def _run_diffusion(start_latents): audio = self._decode_sound_latents(sound_latents, target_audio_samples) return DiffusionOutput(output={"video": video, "audio": audio, "audio_sample_rate": sound_sample_rate}) + if action_enabled: + if action_latents is None or raw_action_dim is None or domain_id is None: + raise ValueError("Cosmos3 action generation finished without action latents.") + action = action_latents[:, :, :raw_action_dim].detach().cpu() + return DiffusionOutput( + output={"video": video}, + custom_output={ + "action": action, + "raw_action_dim": raw_action_dim, + 
"action_mode": action_mode, + "domain_id": domain_id, + }, + ) + return DiffusionOutput(output={"image": video} if is_t2i else {"video": video}) diff --git a/vllm_omni/diffusion/models/cosmos3/transformer_cosmos3.py b/vllm_omni/diffusion/models/cosmos3/transformer_cosmos3.py index 5ff2683fdda..c777776947f 100644 --- a/vllm_omni/diffusion/models/cosmos3/transformer_cosmos3.py +++ b/vllm_omni/diffusion/models/cosmos3/transformer_cosmos3.py @@ -136,6 +136,47 @@ def resolve_sound_gen(od_config: Any) -> bool: return False +class DomainAwareLinear(nn.Module): + """Linear projection with one weight/bias pair per action embodiment domain.""" + + def __init__( + self, + input_size: int, + output_size: int, + num_domains: int, + *, + dtype: torch.dtype = torch.bfloat16, + ) -> None: + super().__init__() + self.input_size = int(input_size) + self.output_size = int(output_size) + self.num_domains = int(num_domains) + self.fc = nn.Embedding(self.num_domains, self.output_size * self.input_size, dtype=dtype) + self.bias = nn.Embedding(self.num_domains, self.output_size, dtype=dtype) + nn.init.xavier_uniform_(self.fc.weight) + nn.init.zeros_(self.bias.weight) + + def forward(self, x: torch.Tensor, domain_id: torch.Tensor) -> torch.Tensor: + if domain_id.ndim == 0: + domain_id = domain_id.unsqueeze(0) + domain_id = domain_id.to(device=x.device, dtype=torch.long).reshape(-1) + if x.shape[0] != domain_id.shape[0]: + raise ValueError( + "Cosmos3 action domain_id batch size must match action tokens: " + f"tokens={x.shape[0]}, domain_id={domain_id.shape[0]}." 
+ ) + if torch.any((domain_id < 0) | (domain_id >= self.num_domains)): + raise ValueError(f"Cosmos3 action domain_id must be in [0, {self.num_domains}), got {domain_id.tolist()}.") + + weight = self.fc(domain_id).view(domain_id.shape[0], self.input_size, self.output_size) + bias = self.bias(domain_id).view(domain_id.shape[0], self.output_size) + if x.ndim == 2: + return torch.bmm(x.unsqueeze(1), weight).squeeze(1) + bias + if x.ndim == 3: + return torch.bmm(x, weight) + bias.unsqueeze(1) + raise ValueError(f"Cosmos3 DomainAwareLinear expected rank-2 or rank-3 input, got {tuple(x.shape)}.") + + # --------------------------------------------------------------------------- # Rotary Position Embeddings (mRoPE) # --------------------------------------------------------------------------- @@ -163,6 +204,7 @@ def compute_mrope_position_ids_vision( temporal_compression_factor: int = 4, base_temporal_compression_factor: int | None = None, enable_fps_modulation: bool = True, + start_frame_offset: int = 0, ) -> tuple[torch.Tensor, int | float]: """Generate 3D mRoPE position IDs for vision tokens. 
@@ -180,10 +222,17 @@ def compute_mrope_position_ids_vision( ) base_tps = base_fps / effective_base_tcf frame_indices = torch.arange(grid_t, dtype=torch.float32) - t_index = (frame_indices / tps * base_tps + temporal_offset).view(-1, 1).expand(-1, grid_h * grid_w).flatten() + t_index = ( + ((frame_indices + start_frame_offset) / tps * base_tps + temporal_offset) + .view(-1, 1) + .expand(-1, grid_h * grid_w) + .flatten() + ) else: - t_index = torch.arange(grid_t, dtype=torch.long).view(-1, 1).expand(-1, grid_h * grid_w).flatten() + int( - temporal_offset + t_index = ( + torch.arange(grid_t, dtype=torch.long).view(-1, 1).expand(-1, grid_h * grid_w).flatten() + + int(temporal_offset) + + start_frame_offset ) h_index = torch.arange(grid_h, dtype=torch.long).view(1, -1, 1).expand(grid_t, -1, grid_w).flatten() @@ -220,6 +269,30 @@ def compute_mrope_position_ids_sound( ) +def compute_mrope_position_ids_action( + grid_t: int, + temporal_offset: int | float, + action_fps: float | None, + base_fps: float = 24.0, + base_temporal_compression_factor: int = 4, + enable_fps_modulation: bool = True, + start_frame_offset: int = 1, +) -> tuple[torch.Tensor, int | float]: + """Generate mRoPE IDs for action tokens as a frame-rate (T, 1, 1) grid.""" + return compute_mrope_position_ids_vision( + grid_t=grid_t, + grid_h=1, + grid_w=1, + temporal_offset=temporal_offset, + fps=action_fps, + base_fps=base_fps, + temporal_compression_factor=1, + base_temporal_compression_factor=base_temporal_compression_factor, + enable_fps_modulation=enable_fps_modulation, + start_frame_offset=start_frame_offset, + ) + + class Qwen3VLTextRotaryEmbedding(nn.Module): """Multi-dimensional rotary position embedding for Qwen3-VL.""" @@ -947,19 +1020,28 @@ def __init__( self.sound_gen = sound_gen self.sound_dim = sound_dim self.sound_latent_fps = sound_latent_fps - if self.sound_gen and (sound_dim is None or sound_latent_fps is None): raise ValueError( "Cosmos3VFMTransformer requires an explicit sound_dim and 
sound_latent_fps when sound_gen is True; " "the pipeline must pass Cosmos3SoundTokenizer.latent_ch so the audio projection " "layers are sized from the authoritative AVAE latent width." ) + action_gen_value = _od_config_get(od_config, "action_gen", None) + action_dim_value = _od_config_get(od_config, "action_dim", None) + if action_dim_value is None: + action_dim_value = _od_config_get(od_config, "max_action_dim", None) + self.action_gen = _as_bool(action_gen_value) if action_gen_value is not None else False + self.action_dim = int(action_dim_value if action_dim_value is not None else 64) + self.num_embodiment_domains = int(_od_config_get(od_config, "num_embodiment_domains", 32)) if temporal_compression_factor is None: temporal_compression_factor = _tf_config_get(model_config, "temporal_compression_factor", 4) self.temporal_compression_factor = int(temporal_compression_factor) self.temporal_compression_factor_sound = int( _tf_config_get(model_config, "temporal_compression_factor_sound", 1) ) + self.temporal_compression_factor_sound = int( + _tf_config_get(model_config, "temporal_compression_factor_sound", 1) + ) self.enable_fps_modulation = bool(_tf_config_get(model_config, "enable_fps_modulation", True)) self.temporal_modality_margin = int( _tf_config_get( @@ -992,6 +1074,20 @@ def __init__( self.proj_in = nn.Linear(self.patch_latent_dim, self.hidden_size) self.proj_out = nn.Linear(self.hidden_size, self.patch_latent_dim) self.time_embedder = TimestepEmbedder(self.hidden_size, target_dtype=dtype) + if self.action_gen: + self.action_proj_in = DomainAwareLinear( + self.action_dim, + self.hidden_size, + self.num_embodiment_domains, + dtype=dtype, + ) + self.action_proj_out = DomainAwareLinear( + self.hidden_size, + self.action_dim, + self.num_embodiment_domains, + dtype=dtype, + ) + self.action_modality_embed = nn.Parameter(torch.zeros(self.hidden_size, dtype=dtype)) if self.sound_gen: self.audio_proj_in = nn.Linear(self.sound_dim, self.hidden_size) 
self.audio_proj_out = nn.Linear(self.hidden_size, self.sound_dim) @@ -1079,6 +1175,21 @@ def unpack_sound(tokens: torch.Tensor) -> torch.Tensor: """[B, T_sound, C_sound] -> [B, C_sound, T_sound].""" return tokens.permute(0, 2, 1).contiguous() + def pack_action(self, action_latents: torch.Tensor) -> torch.Tensor: + """Validate and return action latents as [B, T_action, D_action] tokens.""" + if action_latents.ndim != 3: + raise ValueError(f"Cosmos3 action latents must have shape [B, T, D], got {tuple(action_latents.shape)}.") + if action_latents.shape[-1] != self.action_dim: + raise ValueError( + f"Cosmos3 action latent dimension mismatch: expected {self.action_dim}, got {action_latents.shape[-1]}." + ) + return action_latents.contiguous() + + @staticmethod + def unpack_action(tokens: torch.Tensor) -> torch.Tensor: + """Return [B, T_action, D_action] action predictions.""" + return tokens.contiguous() + # -- RoPE computation ---------------------------------------------------- def _compute_rope_freqs( @@ -1090,6 +1201,9 @@ def _compute_rope_freqs( fps: float | None, device: torch.device, dtype: torch.dtype, + t_action: int | None = None, + action_start_frame_offset: int = 1, + action_fps: float | None = None, t_sound: int | None = None, ) -> tuple[tuple[torch.Tensor, torch.Tensor], tuple[torch.Tensor, torch.Tensor]]: """Compute mRoPE cos/sin for UND text and GEN media pathways.""" @@ -1097,6 +1211,7 @@ def _compute_rope_freqs( S_text = text_mask.shape[1] text_lengths = text_mask.sum(dim=1).long() effective_fps = fps if fps is not None and t > 1 else None + action_frames = int(t_action or 0) sound_frames = int(t_sound or 0) text_pos_list = [] @@ -1116,6 +1231,17 @@ def _compute_rope_freqs( enable_fps_modulation=self.enable_fps_modulation, ) gen_positions = [v_pos] + if action_frames > 0: + a_pos, _ = compute_mrope_position_ids_action( + action_frames, + temporal_offset=media_temporal_offset, + action_fps=action_fps if action_fps is not None else fps, + 
base_fps=self.base_fps, + base_temporal_compression_factor=self.temporal_compression_factor, + enable_fps_modulation=self.enable_fps_modulation, + start_frame_offset=action_start_frame_offset, + ) + gen_positions.append(a_pos) if sound_frames > 0: s_pos, _ = compute_mrope_position_ids_sound( sound_frames, @@ -1161,7 +1287,9 @@ def _validate_gen_sequence_parallel( *, s_gen: int, s_video: int, + s_action: int, s_sound: int, + has_action: bool, has_sound: bool, ulysses_size: int, ) -> None: @@ -1169,14 +1297,16 @@ def _validate_gen_sequence_parallel( return detail_parts = [f"video tokens {s_video}"] + if has_action: + detail_parts.append(f"action tokens {s_action}") if has_sound: detail_parts.append(f"sound tokens {s_sound}") detail = " = " + " + ".join(detail_parts) if len(detail_parts) > 1 else "" adjust_detail = ( - "Adjust the spatial resolution, frame count, sound duration, " - "or sound latent FPS so the combined media sequence is a " + "Adjust the spatial resolution, frame count, action chunk size, " + "sound duration, or sound latent FPS so the combined media sequence is a " "multiple of ulysses_degree." - if has_sound + if has_action or has_sound else ( "Adjust the spatial resolution so that " "t * ceil(h/patch) * ceil(w/patch) is a multiple " @@ -1198,6 +1328,11 @@ def forward( text_mask: torch.Tensor, video_shape: tuple[int, int, int], fps: float | None = None, + action_latents: torch.Tensor | None = None, + action_domain_ids: torch.Tensor | None = None, + action_noisy_mask: torch.Tensor | None = None, + action_start_frame_offset: int = 1, + action_fps: float | None = None, sound_latents: torch.Tensor | None = None, noisy_frame_mask: torch.Tensor | None = None, **kwargs, @@ -1210,6 +1345,10 @@ def forward( text_mask: [B, S_text] attention mask (1=real, 0=pad) video_shape: (t, h, w) in latent space fps: video frame rate for temporal mRoPE modulation + action_latents: Optional [B, T_action, D_action] noisy action latents. 
+ action_domain_ids: Optional [B] embodiment domain IDs for action projections. + action_noisy_mask: Optional [B, T_action, 1] mask where 1=noisy + action token and 0=clean conditioned token. sound_latents: Optional [B, C_sound, T_sound] noisy sound latents. noisy_frame_mask: Optional [B, 1, t, 1, 1] mask where 1=noisy (add timestep embedding, predict velocity) and 0=conditioned (clean @@ -1218,7 +1357,7 @@ def forward( Returns: [B, C, t, h, w] velocity prediction, or - tuple outputs in video, sound order when sound latents are provided. + tuple outputs in video, action, sound order when extra modalities are provided. """ t, h, w = video_shape hp, wp, _, _ = self._pad_to_patch_size(h, w) @@ -1230,7 +1369,14 @@ def forward( f"Cosmos3 requires identical real text lengths within a batch " f"(got min={min_real_len}, max={max_real_len})." ) + has_action = action_latents is not None has_sound = sound_latents is not None + if has_action and not self.action_gen: + raise ValueError( + "Cosmos3 action generation was requested, but this transformer " + "was initialized without action modules. Check that the " + "transformer config enables action_gen." + ) if has_sound and not self.sound_gen: raise ValueError( "Cosmos3 sound generation was requested, but this transformer " @@ -1244,8 +1390,21 @@ def forward( # Patchify latents and project to hidden space hidden_video = self.proj_in(self.patchify(hidden_states, t, h, w)) s_video = hidden_video.shape[1] + s_action = 0 + hidden_action = None s_sound = 0 hidden_sound = None + if action_latents is not None: + if action_latents.shape[0] != hidden_states.shape[0]: + raise ValueError( + "Cosmos3 action and video batch sizes must match: " + f"video={hidden_states.shape[0]}, action={action_latents.shape[0]}." 
+ ) + if action_domain_ids is None: + action_domain_ids = torch.zeros(action_latents.shape[0], dtype=torch.long, device=action_latents.device) + hidden_action = self.action_proj_in(self.pack_action(action_latents), action_domain_ids) + hidden_action = hidden_action + self.action_modality_embed.to(hidden_action.dtype) + s_action = hidden_action.shape[1] if sound_latents is not None: if sound_latents.shape[0] != hidden_states.shape[0]: raise ValueError( @@ -1277,9 +1436,22 @@ def forward( else: hidden_video = hidden_video + time_embed.unsqueeze(1) + if hidden_action is not None: + if action_noisy_mask is None: + hidden_action = hidden_action + time_embed.unsqueeze(1) + else: + if action_noisy_mask.shape != (hidden_action.shape[0], hidden_action.shape[1], 1): + raise ValueError( + "Cosmos3 action_noisy_mask must have shape [B, T_action, 1], " + f"got {tuple(action_noisy_mask.shape)}." + ) + hidden_action = hidden_action + time_embed.unsqueeze(1) * action_noisy_mask.to(hidden_action.dtype) + if hidden_sound is not None: hidden_sound = hidden_sound + time_embed.unsqueeze(1) hidden_parts = [hidden_video] + if hidden_action is not None: + hidden_parts.append(hidden_action) if hidden_sound is not None: hidden_parts.append(hidden_sound) hidden_gen = torch.cat(hidden_parts, dim=1) @@ -1294,6 +1466,9 @@ def forward( fps, hidden_states.device, hidden_states.dtype, + t_action=s_action, + action_start_frame_offset=action_start_frame_offset, + action_fps=action_fps, t_sound=s_sound, ) cached_kv_full = self.language_model(text_ids, freqs_und) @@ -1311,7 +1486,9 @@ def forward( self._validate_gen_sequence_parallel( s_gen=hidden_gen.shape[1], s_video=s_video, + s_action=s_action, s_sound=s_sound, + has_action=has_action, has_sound=has_sound, ulysses_size=ulysses_size, ) @@ -1346,10 +1523,12 @@ def forward( # Final norm and project back to latent space hidden_gen = self.norm_moe_gen(hidden_gen) - if not has_sound: + if not has_action and not has_sound: return 
self.unpatchify(self.proj_out(hidden_gen), t, h, w) split_sizes = [s_video] + if has_action: + split_sizes.append(s_action) if has_sound: split_sizes.append(s_sound) split_hidden = hidden_gen.split(split_sizes, dim=1) @@ -1357,6 +1536,11 @@ def forward( video_pred = self.unpatchify(self.proj_out(hidden_video), t, h, w) outputs: list[torch.Tensor] = [video_pred] split_idx = 1 + if has_action: + hidden_action = split_hidden[split_idx] + split_idx += 1 + assert action_domain_ids is not None + outputs.append(self.unpack_action(self.action_proj_out(hidden_action, action_domain_ids))) if has_sound: hidden_sound = split_hidden[split_idx] outputs.append(self.unpack_sound(self.audio_proj_out(hidden_sound))) diff --git a/vllm_omni/entrypoints/cli/serve.py b/vllm_omni/entrypoints/cli/serve.py index c7f9bb5f3cf..723e1f2360b 100644 --- a/vllm_omni/entrypoints/cli/serve.py +++ b/vllm_omni/entrypoints/cli/serve.py @@ -100,6 +100,9 @@ def cmd(args: TrackingNamespace) -> None: model_config = dict(existing) if isinstance(existing, dict) else {} model_config["guardrails"] = False args.model_config = model_config + explicit_keys = getattr(args, "explicit_keys", None) + if explicit_keys is not None: + args.explicit_keys = explicit_keys | {"model_config"} if args.headless: run_headless(args) diff --git a/vllm_omni/entrypoints/openai/api_server.py b/vllm_omni/entrypoints/openai/api_server.py index 28acc7379c9..b1936062473 100644 --- a/vllm_omni/entrypoints/openai/api_server.py +++ b/vllm_omni/entrypoints/openai/api_server.py @@ -2545,7 +2545,7 @@ async def _run_video_generation_job( started_at = time.perf_counter() output_path = None try: - video_bytes, stage_durations, peak_memory_mb = await handler.generate_video_bytes( + video_bytes, stage_durations, peak_memory_mb, action = await handler.generate_video_bytes( request, video_id, reference_image=reference_image ) @@ -2563,6 +2563,7 @@ async def _run_video_generation_job( "inference_time_s": time.perf_counter() - started_at, 
"stage_durations": stage_durations, "peak_memory_mb": peak_memory_mb, + "action": action, }, ) except (EngineGenerateError, EngineDeadError) as exc: @@ -2774,7 +2775,7 @@ async def create_video_sync( raw_request.state.request_metadata = RequestResponseMetadata(request_id=request_id) started_at = time.perf_counter() try: - video_bytes, stage_durations, peak_memory_mb = await asyncio.wait_for( + video_bytes, stage_durations, peak_memory_mb, _action = await asyncio.wait_for( handler.generate_video_bytes(request, request_id, reference_image=reference_image), timeout=VIDEO_SYNC_TIMEOUT_S, ) diff --git a/vllm_omni/entrypoints/openai/protocol/__init__.py b/vllm_omni/entrypoints/openai/protocol/__init__.py index c68f6f59879..0d8ddd82d90 100644 --- a/vllm_omni/entrypoints/openai/protocol/__init__.py +++ b/vllm_omni/entrypoints/openai/protocol/__init__.py @@ -13,6 +13,7 @@ ResponseFormat, ) from vllm_omni.entrypoints.openai.protocol.videos import ( + VideoAction, VideoData, VideoGenerationRequest, VideoGenerationResponse, @@ -27,6 +28,7 @@ "ImageGenerationRequest", "ImageGenerationResponse", "ResponseFormat", + "VideoAction", "VideoData", "VideoGenerationRequest", "VideoGenerationResponse", diff --git a/vllm_omni/entrypoints/openai/protocol/videos.py b/vllm_omni/entrypoints/openai/protocol/videos.py index 887e3ce67ea..ec5ab14e8d8 100644 --- a/vllm_omni/entrypoints/openai/protocol/videos.py +++ b/vllm_omni/entrypoints/openai/protocol/videos.py @@ -220,12 +220,24 @@ def resolve_video_params(self) -> VideoParams: return vp +class VideoAction(BaseModel): + """Generated action sequence returned by action-capable video models.""" + + data: list[Any] = Field(..., description="JSON-serializable nested action values") + shape: list[int] = Field(..., description="Shape of the returned action data") + dtype: str | None = Field(default=None, description="Source action dtype, if available") + raw_action_dim: int | None = Field(default=None, description="Raw action dimension requested by 
the model") + action_mode: str | None = Field(default=None, description="Action generation mode") + domain_id: int | None = Field(default=None, description="Action embodiment domain id") + + class VideoData(BaseModel): """Single generated video data.""" b64_json: str | None = Field(default=None, description="Base64-encoded MP4 video") url: str | None = Field(default=None, description="Video URL (not implemented)") revised_prompt: str | None = Field(default=None, description="Revised prompt (OpenAI compatibility, always null)") + action: VideoAction | None = Field(default=None, description="Generated action sequence metadata, if any") class VideoGenerationResponse(BaseModel): @@ -298,6 +310,7 @@ class VideoResponse(BaseModel): default=0.0, description="Peak device memory usage in MB reported by the diffusion pipeline.", ) + action: VideoAction | None = Field(default=None, description="Generated action sequence metadata, if any") @property def file_extension(self) -> str: diff --git a/vllm_omni/entrypoints/openai/serving_video.py b/vllm_omni/entrypoints/openai/serving_video.py index 57a76594a0f..8673e192386 100644 --- a/vllm_omni/entrypoints/openai/serving_video.py +++ b/vllm_omni/entrypoints/openai/serving_video.py @@ -16,6 +16,7 @@ from vllm_omni.entrypoints.async_omni import AsyncOmni from vllm_omni.entrypoints.openai.protocol.videos import ( + VideoAction, VideoData, VideoGenerationRequest, VideoGenerationResponse, @@ -44,6 +45,7 @@ class VideoGenerationArtifacts: videos: list[Any] audios: list[Any | None] + actions: list[VideoAction | None] audio_sample_rate: int output_fps: int stage_durations: dict[str, float] @@ -162,7 +164,20 @@ async def _run_and_extract( ) # Merge extra_params into extra_args gen_params.extra_args.update(request.extra_params) - logger.info("Applied extra_params: %s", request.extra_params) + + # Redact the inline ``action`` array (hundreds of floats) when + # logging so it doesn't flood the logs; everything else is logged + # verbatim. 
+ loggable = request.extra_params + action_val = loggable.get("action") + if action_val is not None: + summary = ( + f"<{type(action_val).__name__} len={len(action_val)}>" + if hasattr(action_val, "__len__") + else f"<{type(action_val).__name__}>" + ) + loggable = {**loggable, "action": summary} + logger.info("Applied extra_params: %s", loggable) self._apply_lora(request.lora, gen_params) @@ -177,11 +192,13 @@ async def _run_and_extract( result = await self._run_generation(prompt, gen_params, reference_id) videos = self._extract_video_outputs(result) audios = self._extract_audio_outputs(result, expected_count=len(videos)) + actions = self._extract_action_outputs(result, expected_count=len(videos)) audio_sample_rate = self._resolve_audio_sample_rate(result) output_fps = (vp.fps or self._resolve_fps(result) or 24) * self._resolve_video_fps_multiplier(result) return VideoGenerationArtifacts( videos=videos, audios=audios, + actions=actions, audio_sample_rate=audio_sample_rate, output_fps=output_fps, stage_durations=self._extract_stage_durations(result), @@ -215,7 +232,8 @@ async def generate_videos( audio_sample_rate=artifacts.audio_sample_rate, video_codec_options=video_codec_options, ) - ) + ), + action=artifacts.actions[idx], ) for idx, video in enumerate(artifacts.videos) ] @@ -234,7 +252,7 @@ async def generate_video_bytes( reference_id: str, *, reference_image: ReferenceImage | None = None, - ) -> tuple[bytes, dict[str, float], float]: + ) -> tuple[bytes, dict[str, float], float, VideoAction | None]: """Generate a video and return raw MP4 bytes, bypassing base64 encoding.""" artifacts = await self._run_and_extract(request, reference_id, reference_image=reference_image) if len(artifacts.videos) > 1: @@ -259,7 +277,7 @@ async def generate_video_bytes( ) _t_encode_ms = (time.perf_counter() - _t_encode_start) * 1000 logger.info("Video response encoding (MP4 bytes): %.2f ms", _t_encode_ms) - return video_bytes, artifacts.stage_durations, artifacts.peak_memory_mb + 
return video_bytes, artifacts.stage_durations, artifacts.peak_memory_mb, artifacts.actions[0] @staticmethod def _resolve_video_fps_multiplier(result: Any) -> int: @@ -440,6 +458,18 @@ def _resolve_audio_sample_rate(self, result: Any) -> int: return 24000 + @classmethod + def _extract_action_outputs(cls, result: Any, expected_count: int) -> list[VideoAction | None]: + custom_output = cls._extract_custom_output(result) + if not custom_output or "action" not in custom_output: + return [None] * expected_count + + action_items = cls._split_action_payload(custom_output["action"], expected_count) + return [ + cls._make_video_action(action_item, custom_output) if action_item is not None else None + for action_item in action_items + ] + @staticmethod def _extract_custom_output(result: Any) -> dict[str, Any]: custom_output = getattr(result, "custom_output", None) @@ -458,6 +488,89 @@ def _extract_custom_output(result: Any) -> dict[str, Any]: return custom_output if isinstance(custom_output, dict) else {} + @classmethod + def _split_action_payload(cls, action: Any, expected_count: int) -> list[Any | None]: + if expected_count <= 0: + return [] + + shape = cls._shape_of(action) + if len(shape) >= 3: + count = min(shape[0], expected_count) + actions = [cls._index_action(action, i) for i in range(count)] + actions.extend([None] * (expected_count - count)) + return actions + + return [action] + [None] * (expected_count - 1) + + @classmethod + def _make_video_action(cls, action: Any, custom_output: dict[str, Any]) -> VideoAction: + data = cls._to_jsonable(action) + if not isinstance(data, list): + data = [data] + + action_mode = custom_output.get("action_mode") + return VideoAction( + data=data, + shape=cls._shape_of(action), + dtype=cls._dtype_of(action), + raw_action_dim=cls._coerce_optional_int(custom_output.get("raw_action_dim")), + action_mode=str(action_mode) if action_mode is not None else None, + domain_id=cls._coerce_optional_int(custom_output.get("domain_id")), + ) + + 
@staticmethod + def _index_action(action: Any, index: int) -> Any: + try: + return action[index] + except (IndexError, KeyError, TypeError): + return None + + @classmethod + def _to_jsonable(cls, value: Any) -> Any: + if hasattr(value, "detach"): + value = value.detach() + if hasattr(value, "cpu"): + value = value.cpu() + if hasattr(value, "tolist"): + return cls._to_jsonable(value.tolist()) + if isinstance(value, (list, tuple)): + return [cls._to_jsonable(item) for item in value] + if hasattr(value, "item"): + try: + return value.item() + except (TypeError, ValueError): + pass + return value + + @classmethod + def _shape_of(cls, value: Any) -> list[int]: + shape = getattr(value, "shape", None) + if shape is not None: + try: + return [int(dim) for dim in shape] + except (TypeError, ValueError): + pass + if isinstance(value, (list, tuple)): + if not value: + return [0] + return [len(value)] + cls._shape_of(value[0]) + return [] + + @staticmethod + def _dtype_of(value: Any) -> str | None: + dtype = getattr(value, "dtype", None) + return str(dtype) if dtype is not None else None + + @staticmethod + def _coerce_optional_int(value: Any) -> int | None: + if value is None: + return None + try: + value = value.item() if hasattr(value, "item") else value + return int(value) + except (TypeError, ValueError): + return None + @staticmethod def _resolve_fps(result: Any) -> int | None: """Extract fps from multimodal_output if the model reported it."""