diff --git a/docs/.nav.yml b/docs/.nav.yml index b7d08e77e91..bfa9365f6f6 100644 --- a/docs/.nav.yml +++ b/docs/.nav.yml @@ -60,6 +60,7 @@ nav: - FP8: user_guide/diffusion/quantization/fp8.md - Int8: user_guide/diffusion/quantization/int8.md - GGUF: user_guide/diffusion/quantization/gguf.md + - Step Execution: user_guide/diffusion/step_execution.md - Parallelism Acceleration: user_guide/diffusion/parallelism_acceleration.md - CPU Offloading: user_guide/diffusion/cpu_offload_diffusion.md - LoRA: user_guide/diffusion/lora.md @@ -91,6 +92,7 @@ nav: - design/feature/cache_dit.md - design/feature/teacache.md - design/feature/async_chunk_design.md + - design/feature/diffusion_step_execution.md - Module Design: - design/module/ar_module.md - design/module/dit_module.md diff --git a/docs/contributing/model/adding_diffusion_model.md b/docs/contributing/model/adding_diffusion_model.md index 3d04bf94269..8629b41d524 100644 --- a/docs/contributing/model/adding_diffusion_model.md +++ b/docs/contributing/model/adding_diffusion_model.md @@ -739,6 +739,22 @@ See detailed guide: [How to add Sequence Parallel support](../../design/feature/ omni = Omni(model="your-model", ulysses_degree=2, ring_degree=2) ``` +### Step Execution + +See detailed design guide: [How to add step execution support](../../design/feature/diffusion_step_execution.md) + +Use this only when your pipeline can be split into stable request-scoped and +step-scoped phases. The reference implementation is +`QwenImagePipeline`, which maps its request-level `forward()` into: + +1. `prepare_encode()` for prompt encoding, latent init, timestep prep, and per-request scheduler setup. +2. `denoise_step()` for one transformer/noise prediction. +3. `step_scheduler()` for one scheduler update and `step_index` advance. +4. `post_decode()` for the final VAE decode. + +Do not enable `step_execution=True` until those four methods are implemented +and validated against the request-level path. + ### Cache Acceleration #### TeaCache diff --git a/docs/design/feature/diffusion_step_execution.md b/docs/design/feature/diffusion_step_execution.md new file mode 100644 index 00000000000..b8c81f04f69 --- /dev/null +++ b/docs/design/feature/diffusion_step_execution.md @@ -0,0 +1,121 @@ +# Diffusion Step Execution + +This guide documents vLLM-Omni's stepwise diffusion contract for model authors +and contributors implementing `step_execution=True` support for a diffusion +pipeline. + +For end-user enablement, supported models, and current limitations, see +[Step Execution](../../user_guide/diffusion/step_execution.md). + +## Current Support Scope + +`step_execution` is **not** a generic diffusion toggle. It only works for +pipelines that implement the segmented stateful contract in +[`vllm_omni/diffusion/models/interface.py`](gh-file:vllm_omni/diffusion/models/interface.py). + +Current in-tree support: + +| Pipeline | Example models | Step execution | +|----------|----------------|----------------| +| `QwenImagePipeline` | `Qwen/Qwen-Image`, `Qwen/Qwen-Image-2512` | Yes | +| All other diffusion pipelines | `QwenImageEditPipeline`, `QwenImageEditPlusPipeline`, `QwenImageLayeredPipeline`, GLM-Image, Wan, Flux, etc. | No | + +Current engine/runtime limitations: + +- `StepScheduler` only schedules `batch_size=1`. +- `cache_backend` is not supported in step mode. +- Request-mode extras such as KV transfer are not wired into step mode yet. +- Unsupported pipelines now fail early during model loading instead of failing on the first request. + +## Execution Contract + +Step mode is driven by four pipeline methods plus the shared mutable request +state object: + +- `prepare_encode(state)`: one-time request preparation. +- `denoise_step(state)`: compute the noise prediction for the current step. +- `step_scheduler(state, noise_pred)`: mutate latents and advance step state. +- `post_decode(state)`: decode the final output after denoising is complete. + +The state lives in +[`vllm_omni/diffusion/worker/utils.py`](gh-file:vllm_omni/diffusion/worker/utils.py) +as `DiffusionRequestState`. Store request-scoped tensors there, or use +`state.extra` for model-specific fields that do not justify extending the core +dataclass. + +The worker-side step loop lives in +[`vllm_omni/diffusion/worker/diffusion_model_runner.py`](gh-file:vllm_omni/diffusion/worker/diffusion_model_runner.py): + +1. `prepare_encode()` runs once for a new request. +2. `denoise_step()` runs every scheduler tick. +3. `step_scheduler()` mutates `state.latents` and advances `state.step_index`. +4. `post_decode()` runs exactly once after `state.denoise_completed` becomes true. + +## Recommended Split + +When converting an existing request-level `forward()` pipeline, keep the split +strict and mechanical: + +| Request-level phase | Stepwise method | What belongs there | +|---------------------|-----------------|--------------------| +| Input validation, prompt encoding, latent init, timestep prep, per-request scheduler creation | `prepare_encode()` | Anything that should happen once per request | +| Transformer forward / noise prediction | `denoise_step()` | Pure denoise computation for the current timestep | +| `scheduler.step(...)` and `step_index += 1` | `step_scheduler()` | Only latent/state mutation for one step | +| VAE decode / postprocess | `post_decode()` | Final decode only | + +Keep the stepwise path reusing the same helpers as the request-level path +whenever possible. Reimplementing the denoise loop from scratch is the easiest +way to introduce behavioral drift. + +## Qwen-Image Reference + +[`pipeline_qwen_image.py`](gh-file:vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py) +is the reference implementation and is split correctly for the current +contract: + +- `prepare_encode()` reuses `_prepare_generation_context()` so prompt encoding, + latent init, timestep creation, CFG setup, and shape bookkeeping stay aligned + with `forward()`. +- `prepare_encode()` deep-copies `self.scheduler` **after** + `prepare_timesteps()` so request-specific scheduler state is isolated. +- `denoise_step()` reuses `_build_denoise_kwargs()` plus + `predict_noise_maybe_with_cfg()`, so sequential CFG, CFG-parallel, and + non-CFG behavior stay identical to the request-level path. +- `step_scheduler()` only calls + `scheduler_step_maybe_with_cfg(..., per_request_scheduler=state.scheduler)` + and increments `state.step_index`. +- `post_decode()` reuses `_decode_latents()`, so the final image decode matches + the normal `forward()` path. + +That decomposition is the target pattern for future models. + +## Rules For New Pipelines + +- Do not keep request-scoped scheduler state on `self.scheduler`. Copy it into + `state.scheduler` during `prepare_encode()`. +- Do not mutate `state.step_index` inside `denoise_step()`. Only + `step_scheduler()` should advance the step. +- Do not decode partial outputs in `denoise_step()` or `step_scheduler()`. +- If the request-level pipeline has condition latents, masks, or edit-specific + tensors, store them in `state` or `state.extra`, not in global pipeline + attributes. +- Preserve CFG behavior by sharing the same helper path used by `forward()`. +- Keep `post_decode()` equivalent to the tail of `forward()`. + +## Validation Checklist + +Before marking a pipeline as `supports_step_execution = True`, verify: + +- Stepwise output matches request-level output for the same seed and sampling params. +- Per-request scheduler state is isolated across concurrent requests. +- Abort during denoise does not leak cached state. +- `step_index` reported by `RunnerOutput` matches the scheduler progress. +- CFG-parallel and non-CFG paths both work if the request-level pipeline supports them. + +## Related Files + +- Contract: [`vllm_omni/diffusion/models/interface.py`](gh-file:vllm_omni/diffusion/models/interface.py) +- State: [`vllm_omni/diffusion/worker/utils.py`](gh-file:vllm_omni/diffusion/worker/utils.py) +- Runner loop: [`vllm_omni/diffusion/worker/diffusion_model_runner.py`](gh-file:vllm_omni/diffusion/worker/diffusion_model_runner.py) +- Scheduler transport: [`vllm_omni/diffusion/sched/interface.py`](gh-file:vllm_omni/diffusion/sched/interface.py) +- Reference pipeline: [`vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py`](gh-file:vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py) diff --git a/docs/user_guide/diffusion/step_execution.md b/docs/user_guide/diffusion/step_execution.md new file mode 100644 index 00000000000..99c2878506e --- /dev/null +++ b/docs/user_guide/diffusion/step_execution.md @@ -0,0 +1,61 @@ +# Step Execution + +Step execution is an opt-in diffusion execution mode enabled with +`step_execution=True` when constructing `Omni`. + +It is not a generic diffusion toggle for every pipeline. Only pipelines that +implement the stepwise contract support it today. + +## Quick Start + +```python +from vllm_omni import Omni +from vllm_omni.inputs.data import OmniDiffusionSamplingParams + +omni = Omni( + model="Qwen/Qwen-Image", + step_execution=True, +) + +outputs = omni.generate( + "A cat sitting on a windowsill", + OmniDiffusionSamplingParams( + num_inference_steps=50, + ), +) +``` + +## Supported Pipelines + +| Pipeline | Example models | Step execution | +|----------|----------------|----------------| +| `QwenImagePipeline` | `Qwen/Qwen-Image`, `Qwen/Qwen-Image-2512` | Yes | +| All other diffusion pipelines | `QwenImageEditPipeline`, `QwenImageEditPlusPipeline`, `QwenImageLayeredPipeline`, GLM-Image, Wan, Flux, etc. | No | + +## Current Limitations + +- `step_execution` currently supports `batch_size=1` only. +- `cache_backend` is not supported together with step execution. +- Unsupported pipelines fail early during model loading. +- Request-mode extras such as KV transfer are not wired into step mode yet. + +## When To Use It + +Use step execution only when you specifically need the pipeline to run through +its stepwise request state machine. For normal diffusion inference, leave it +disabled unless your workflow depends on this mode. + +If you are looking for general diffusion speedups, see +[Diffusion Acceleration Overview](../diffusion_acceleration.md). + +## Troubleshooting + +If model loading fails with a message mentioning `prepare_encode()`, +`denoise_step()`, `step_scheduler()`, and `post_decode()`, the selected +pipeline does not support step execution. + +## For Model Authors + +If you want to add step execution support to a new diffusion pipeline, see the +implementation guide: +[Diffusion Step Execution Design](../../design/feature/diffusion_step_execution.md). diff --git a/tests/diffusion/test_diffusion_step_pipeline.py b/tests/diffusion/test_diffusion_step_pipeline.py new file mode 100644 index 00000000000..121fbc0d5fa --- /dev/null +++ b/tests/diffusion/test_diffusion_step_pipeline.py @@ -0,0 +1,507 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Tests for step-level diffusion runner and worker execution.""" + +import os +from contextlib import contextmanager +from types import SimpleNamespace + +import pytest +import torch + +import vllm_omni.diffusion.worker.diffusion_model_runner as model_runner_module +from tests.utils import hardware_test +from vllm_omni.diffusion.data import DiffusionOutput +from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin +from vllm_omni.diffusion.distributed.comm import RingComm, SeqAllToAll4D +from vllm_omni.diffusion.distributed.parallel_state import ( + destroy_distributed_env, + get_sp_group, + init_distributed_environment, + initialize_model_parallel, +) +from vllm_omni.diffusion.ipc import ( + pack_diffusion_output_shm, + unpack_diffusion_output_shm, +) +from vllm_omni.diffusion.sched.interface import ( + DiffusionRequestState as SchedulerRequestState, +) +from vllm_omni.diffusion.sched.interface import ( + DiffusionSchedulerOutput, +) +from vllm_omni.diffusion.worker.diffusion_model_runner import DiffusionModelRunner +from vllm_omni.diffusion.worker.diffusion_worker import DiffusionWorker +from vllm_omni.diffusion.worker.utils import RunnerOutput +from vllm_omni.platforms import current_omni_platform + +pytestmark = [pytest.mark.core_model, pytest.mark.diffusion] + +# --------------------------------------------------------------------------- +# Helpers & fixtures +# --------------------------------------------------------------------------- + + +@contextmanager +def _noop_forward_context(*args, **kwargs): + del args, kwargs + yield + + +def _update_environment_variables(envs_dict: dict[str, str]) -> None: + for key, value in envs_dict.items(): + os.environ[key] = value + + +class _StepPipeline: + """Minimal pipeline stub that supports step-wise execution.""" + + supports_step_execution = True + + def __init__(self): + self.prepare_calls = 0 + self.denoise_calls = 0 + self.scheduler_calls = 0 + self.decode_calls = 0 + + def prepare_encode(self, state, **kwargs): + del kwargs + self.prepare_calls += 1 + state.timesteps = [torch.tensor(10), torch.tensor(5)] + state.latents = torch.tensor([0.0]) + return state + + def denoise_step(self, state, **kwargs): + del state, kwargs + self.denoise_calls += 1 + return torch.tensor([1.0]) + + def step_scheduler(self, state, noise_pred, **kwargs): + del noise_pred, kwargs + self.scheduler_calls += 1 + state.step_index += 1 + + def post_decode(self, state, **kwargs): + del kwargs + self.decode_calls += 1 + return DiffusionOutput(output=torch.tensor([state.step_index], dtype=torch.float32)) + + +class _IdentityNoiseTransformer(torch.nn.Module): + def forward(self, x: torch.Tensor, **kwargs): + del kwargs + return (x,) + + +class _AdditiveScheduler: + def step(self, noise_pred: torch.Tensor, t: torch.Tensor, latents: torch.Tensor, return_dict: bool = False): + del t, return_dict + return (latents + noise_pred,) + + +class _DistributedStepPipeline(CFGParallelMixin): + supports_step_execution = True + + def __init__(self, mode: str, device: torch.device): + self.mode = mode + self.device = device + self._interrupt = False + self.scheduler = _AdditiveScheduler() + self.transformer = _IdentityNoiseTransformer() + + @property + def interrupt(self): + return self._interrupt + + def prepare_encode(self, state, **kwargs): + del kwargs + state.timesteps = [torch.tensor(1.0, device=self.device)] + state.latents = torch.ones((1, 1), device=self.device) + state.step_index = 0 + state.scheduler = self.scheduler + state.do_true_cfg = self.mode == "cfg" + return state + + def denoise_step(self, state, **kwargs): + del kwargs + if self.mode == "ulysses": + sp_group = get_sp_group().ulysses_group + seq_world_size = torch.distributed.get_world_size(sp_group) + input_tensor = torch.randn(1, 2, 2 * seq_world_size, 2, device=self.device) + original = input_tensor.clone() + intermediate = SeqAllToAll4D.apply(sp_group, input_tensor, 2, 1, False) + output = SeqAllToAll4D.apply(sp_group, intermediate, 1, 2, False) + torch.testing.assert_close(output, original, rtol=1e-5, atol=1e-5) + return torch.ones_like(state.latents) + + if self.mode == "ring": + ring_group = get_sp_group().ring_group + rank = torch.distributed.get_rank(ring_group) + world_size = torch.distributed.get_world_size(ring_group) + comm = RingComm(ring_group) + input_tensor = torch.full((1, 2, 2), float(rank + 1), device=self.device) + recv_tensor = comm.send_recv(input_tensor) + comm.commit() + comm.wait() + expected = torch.full_like(recv_tensor, float(((rank - 1) % world_size) + 1)) + torch.testing.assert_close(recv_tensor, expected, rtol=1e-5, atol=1e-5) + return torch.ones_like(state.latents) + + positive_kwargs = {"x": state.latents + 1} + negative_kwargs = {"x": state.latents - 1} + return self.predict_noise_maybe_with_cfg( + do_true_cfg=True, + true_cfg_scale=1.0, + positive_kwargs=positive_kwargs, + negative_kwargs=negative_kwargs, + cfg_normalize=False, + ) + + def step_scheduler(self, state, noise_pred, **kwargs): + del kwargs + if self.mode == "cfg": + state.latents = self.scheduler_step_maybe_with_cfg( + noise_pred, + state.current_timestep, + state.latents, + do_true_cfg=True, + per_request_scheduler=state.scheduler, + ) + else: + state.latents = state.latents + noise_pred + state.step_index += 1 + + def post_decode(self, state, **kwargs): + del kwargs + return DiffusionOutput(output=state.latents.detach().cpu()) + + +def _make_step_request(num_inference_steps: int = 2): + return SimpleNamespace( + prompts=["a prompt"], + request_ids=["req-1"], + sampling_params=SimpleNamespace( + generator=None, + seed=None, + generator_device=None, + num_inference_steps=num_inference_steps, + ), + ) + + +def _make_runner(): + runner = object.__new__(DiffusionModelRunner) + runner.vllm_config = object() + runner.od_config = SimpleNamespace( + cache_backend=None, + parallel_config=SimpleNamespace(use_hsdp=False), + ) + runner.device = torch.device("cpu") + runner.pipeline = _StepPipeline() + runner.cache_backend = None + runner.offload_backend = None + runner.state_cache = {} + runner.kv_transfer_manager = SimpleNamespace() + return runner + + +def _make_distributed_runner(mode: str, device: torch.device): + runner = object.__new__(DiffusionModelRunner) + runner.vllm_config = object() + runner.od_config = SimpleNamespace( + cache_backend=None, + parallel_config=SimpleNamespace(use_hsdp=False), + ) + runner.device = device + runner.pipeline = _DistributedStepPipeline(mode=mode, device=device) + runner.cache_backend = None + runner.offload_backend = None + runner.state_cache = {} + runner.kv_transfer_manager = SimpleNamespace() + return runner + + +def _make_scheduler_output(req, sched_req_id="req-1", step_id=0, finished_req_ids=None): + return DiffusionSchedulerOutput( + step_id=step_id, + req_states=[SchedulerRequestState(sched_req_id=sched_req_id, req=req)], + finished_req_ids=set() if finished_req_ids is None else set(finished_req_ids), + num_running_reqs=1, + num_waiting_reqs=0, + ) + + +def _expected_output_for_mode(mode: str) -> torch.Tensor: + if mode == "cfg": + return torch.tensor([[3.0]]) + return torch.tensor([[2.0]]) + + +def _distributed_step_worker(local_rank: int, world_size: int, mode: str, master_port: str): + device = torch.device(f"{current_omni_platform.device_type}:{local_rank}") + current_omni_platform.set_device(device) + _update_environment_variables( + { + "RANK": str(local_rank), + "LOCAL_RANK": str(local_rank), + "WORLD_SIZE": str(world_size), + "MASTER_ADDR": "localhost", + "MASTER_PORT": master_port, + } + ) + model_runner_module.set_forward_context = _noop_forward_context + + try: + init_distributed_environment() + if mode == "ulysses": + initialize_model_parallel(ulysses_degree=world_size) + elif mode == "ring": + initialize_model_parallel(ring_degree=world_size) + elif mode == "cfg": + initialize_model_parallel(cfg_parallel_size=world_size) + else: + raise ValueError(f"Unsupported distributed test mode: {mode}") + + runner = _make_distributed_runner(mode, device) + output = DiffusionModelRunner.execute_stepwise( + runner, + _make_scheduler_output(_make_step_request(num_inference_steps=1), step_id=0), + ) + + assert output.finished is True + assert output.result is not None + torch.testing.assert_close(output.result.output, _expected_output_for_mode(mode), rtol=1e-5, atol=1e-5) + assert "req-1" not in runner.state_cache + finally: + destroy_distributed_env() + + +# --------------------------------------------------------------------------- +# Runner / Worker +# --------------------------------------------------------------------------- + + +@pytest.mark.cpu +class TestRunner: + """DiffusionModelRunner.execute_stepwise""" + + def test_completes_request_and_clears_state(self, monkeypatch): + runner = _make_runner() + req = _make_step_request() + monkeypatch.setattr(model_runner_module, "set_forward_context", _noop_forward_context) + + first = DiffusionModelRunner.execute_stepwise(runner, _make_scheduler_output(req, step_id=0)) + assert first.req_id == "req-1" + assert first.step_index == 1 + assert first.finished is False + assert first.result is None + assert "req-1" in runner.state_cache + + second = DiffusionModelRunner.execute_stepwise(runner, _make_scheduler_output(req, step_id=1)) + assert second.req_id == "req-1" + assert second.step_index == 2 + assert second.finished is True + assert second.result is not None + assert second.result.error is None + assert torch.equal(second.result.output, torch.tensor([2.0])) + assert "req-1" not in runner.state_cache + + assert runner.pipeline.prepare_calls == 1 + assert runner.pipeline.denoise_calls == 2 + assert runner.pipeline.scheduler_calls == 2 + assert runner.pipeline.decode_calls == 1 + + def test_load_model_rejects_unsupported_step_execution(self, monkeypatch): + class _RequestOnlyPipeline: + pass + + class _FakeLoader: + def __init__(self, *args, **kwargs): + del args, kwargs + + def load_model(self, **kwargs): + del kwargs + return _RequestOnlyPipeline() + + class _FakeProfiler: + consumed_memory = 0 + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + del exc_type, exc, tb + return False + + runner = object.__new__(DiffusionModelRunner) + runner.vllm_config = object() + runner.od_config = SimpleNamespace( + enable_cpu_offload=False, + enable_layerwise_offload=False, + enforce_eager=True, + cache_backend=None, + cache_config=None, + step_execution=True, + model_class_name="RequestOnlyPipeline", + parallel_config=SimpleNamespace(use_hsdp=False), + ) + runner.device = torch.device("cpu") + runner.pipeline = None + runner.cache_backend = None + runner.offload_backend = None + runner.state_cache = {} + runner.kv_transfer_manager = SimpleNamespace() + + monkeypatch.setattr(model_runner_module, "DiffusersPipelineLoader", _FakeLoader) + monkeypatch.setattr(model_runner_module, "DeviceMemoryProfiler", _FakeProfiler) + monkeypatch.setattr(model_runner_module, "get_offload_backend", lambda *args, **kwargs: None) + monkeypatch.setattr(model_runner_module, "get_cache_backend", lambda *args, **kwargs: None) + + with pytest.raises(ValueError, match="RequestOnlyPipeline"): + DiffusionModelRunner.load_model(runner) + + +@pytest.mark.cpu +class TestWorker: + """DiffusionWorker.execute_stepwise""" + + def test_delegates_to_model_runner(self): + worker = object.__new__(DiffusionWorker) + expected = RunnerOutput(req_id="req-1", step_index=1, finished=False, result=None) + scheduler_output = SimpleNamespace( + req_states=[ + SimpleNamespace( + req=SimpleNamespace( + sampling_params=SimpleNamespace(lora_request=None), + ) + ) + ] + ) + worker.lora_manager = None + worker.model_runner = SimpleNamespace( + execute_stepwise=lambda arg: expected if arg is scheduler_output else None + ) + + output = DiffusionWorker.execute_stepwise(worker, scheduler_output) + + assert output is expected + + def test_clears_active_lora_before_stepwise_execution(self): + worker = object.__new__(DiffusionWorker) + scheduler_output = SimpleNamespace( + req_states=[ + SimpleNamespace( + req=SimpleNamespace( + sampling_params=SimpleNamespace(lora_request=None), + ) + ) + ] + ) + calls: list[object | None] = [] + + class _FakeLoRAManager: + def set_active_adapter(self, adapter): + calls.append(adapter) + + worker.lora_manager = _FakeLoRAManager() + worker.model_runner = SimpleNamespace(execute_stepwise=lambda arg: RunnerOutput(req_id="req-1")) + + DiffusionWorker.execute_stepwise(worker, scheduler_output) + + assert calls == [None] + + def test_rejects_lora_requests_in_step_mode(self): + worker = object.__new__(DiffusionWorker) + scheduler_output = SimpleNamespace( + req_states=[ + SimpleNamespace( + req=SimpleNamespace( + sampling_params=SimpleNamespace(lora_request=object()), + ) + ) + ] + ) + worker.lora_manager = None + worker.model_runner = SimpleNamespace(execute_stepwise=lambda arg: RunnerOutput(req_id="req-1")) + + with pytest.raises(ValueError, match="does not support LoRA"): + DiffusionWorker.execute_stepwise(worker, scheduler_output) + + +@pytest.mark.cpu +class TestIPC: + def test_pack_unpack_runner_output_shm(self): + tensor = torch.zeros(300_000, dtype=torch.float32) + output = RunnerOutput(req_id="req-1", finished=True, result=DiffusionOutput(output=tensor)) + + packed = pack_diffusion_output_shm(output) + assert isinstance(packed.result.output, dict) + assert packed.result.output["__tensor_shm__"] is True + + unpacked = unpack_diffusion_output_shm(packed) + assert isinstance(unpacked.result.output, torch.Tensor) + torch.testing.assert_close(unpacked.result.output, tensor) + + +@pytest.mark.cpu +class TestSupportedPipelines: + """Step-execution protocol checks for supported pipelines.""" + + def test_qwen_image_supports_step_execution(self): + from vllm_omni.diffusion.models.interface import SupportsStepExecution, supports_step_execution + from vllm_omni.diffusion.models.qwen_image.pipeline_qwen_image import QwenImagePipeline + + # Avoid loading model weights; protocol membership depends on the class contract. + pipeline = object.__new__(QwenImagePipeline) + + assert pipeline.supports_step_execution is True + assert supports_step_execution(pipeline) is True + assert isinstance(pipeline, SupportsStepExecution) is True + + +@hardware_test( + res={"cuda": "L4"}, + num_cards=2, +) +def test_execute_stepwise_with_ulysses_parallel(): + world_size = 2 + if current_omni_platform.get_device_count() < world_size: + pytest.skip(f"Test requires {world_size} devices") + + torch.multiprocessing.spawn( + _distributed_step_worker, + args=(world_size, "ulysses", "29540"), + nprocs=world_size, + ) + + +@hardware_test( + res={"cuda": "L4"}, + num_cards=2, +) +def test_execute_stepwise_with_ring_parallel(): + world_size = 2 + if current_omni_platform.get_device_count() < world_size: + pytest.skip(f"Test requires {world_size} devices") + + torch.multiprocessing.spawn( + _distributed_step_worker, + args=(world_size, "ring", "29541"), + nprocs=world_size, + ) + + +@hardware_test( + res={"cuda": "L4"}, + num_cards=2, +) +def test_execute_stepwise_with_cfg_parallel(): + world_size = 2 + if current_omni_platform.get_device_count() < world_size: + pytest.skip(f"Test requires {world_size} devices") + + torch.multiprocessing.spawn( + _distributed_step_worker, + args=(world_size, "cfg", "29542"), + nprocs=world_size, + ) diff --git a/vllm_omni/diffusion/data.py b/vllm_omni/diffusion/data.py index c22b163f65d..f15aae6a026 100644 --- a/vllm_omni/diffusion/data.py +++ b/vllm_omni/diffusion/data.py @@ -464,6 +464,9 @@ class OmniDiffusionConfig: # Diffusion pipeline Profiling config enable_diffusion_pipeline_profiler: bool = False + # Step mode settings + step_execution: bool = False + @property def is_moe(self) -> bool: num_experts = self.tf_model_config.get("num_experts", None) diff --git a/vllm_omni/diffusion/distributed/cfg_parallel.py b/vllm_omni/diffusion/distributed/cfg_parallel.py index 9f86bce228b..0743a00d4af 100644 --- a/vllm_omni/diffusion/distributed/cfg_parallel.py +++ b/vllm_omni/diffusion/distributed/cfg_parallel.py @@ -182,7 +182,13 @@ def diffuse(self, latents, timesteps, prompt_embeds, negative_embeds, ...): """ raise NotImplementedError("Subclasses must implement diffuse") - def scheduler_step(self, noise_pred: torch.Tensor, t: torch.Tensor, latents: torch.Tensor) -> torch.Tensor: + def scheduler_step( + self, + noise_pred: torch.Tensor, + t: torch.Tensor, + latents: torch.Tensor, + per_request_scheduler: Any | None = None, + ) -> torch.Tensor: """ Step the scheduler. @@ -190,14 +196,30 @@ def scheduler_step(self, noise_pred: torch.Tensor, t: torch.Tensor, latents: tor noise_pred: Predicted noise t: Current timestep latents: Current latents + per_request_scheduler: Optional request-scoped scheduler that + overrides ``self.scheduler`` for this call. This is + primarily used by step-wise execution, where each request + may keep scheduler state in its own runner-managed state + object. Request-level execution should usually leave this + as ``None`` and continue using ``self.scheduler``. Returns: Updated latents after scheduler step """ - return self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + sched = per_request_scheduler if per_request_scheduler is not None else getattr(self, "scheduler", None) + if sched is None: + raise ValueError("No scheduler is available. Set self.scheduler or pass per_request_scheduler.") + if not callable(getattr(sched, "step", None)): + raise TypeError("per_request_scheduler must provide a callable step(...) method.") + return sched.step(noise_pred, t, latents, return_dict=False)[0] def scheduler_step_maybe_with_cfg( - self, noise_pred: torch.Tensor, t: torch.Tensor, latents: torch.Tensor, do_true_cfg: bool + self, + noise_pred: torch.Tensor, + t: torch.Tensor, + latents: torch.Tensor, + do_true_cfg: bool, + per_request_scheduler: Any | None = None, ) -> torch.Tensor: """ Step the scheduler with (maybe) automatic CFG parallel synchronization. @@ -210,6 +232,11 @@ def scheduler_step_maybe_with_cfg( t: Current timestep latents: Current latents do_true_cfg: Whether CFG is enabled + per_request_scheduler: Optional request-scoped scheduler that + overrides ``self.scheduler`` for this call. This is mainly + needed by step-wise execution, where scheduler state may be + stored per request. Request-level execution should normally + leave this as ``None``. Returns: Updated latents (synchronized across all CFG ranks) @@ -223,13 +250,23 @@ def scheduler_step_maybe_with_cfg( # Only rank 0 computes the scheduler step if cfg_rank == 0: - latents = self.scheduler_step(noise_pred, t, latents) + latents = self.scheduler_step( + noise_pred, + t, + latents, + per_request_scheduler=per_request_scheduler, + ) # Broadcast the updated latents to all ranks latents = latents.contiguous() cfg_group.broadcast(latents, src=0) else: # No CFG parallel: directly compute scheduler step - latents = self.scheduler_step(noise_pred, t, latents) + latents = self.scheduler_step( + noise_pred, + t, + latents, + per_request_scheduler=per_request_scheduler, + ) return latents diff --git a/vllm_omni/diffusion/ipc.py b/vllm_omni/diffusion/ipc.py index cac406ef4c0..d3d7b3aff35 100644 --- a/vllm_omni/diffusion/ipc.py +++ b/vllm_omni/diffusion/ipc.py @@ -65,12 +65,7 @@ def _tensor_from_shm(handle: dict[str, Any]) -> torch.Tensor: return tensor -def pack_diffusion_output_shm(output: DiffusionOutput) -> DiffusionOutput: - """Replace large tensors in *output* with shared-memory handles. - - The DiffusionOutput is modified **in-place** so that the (now lightweight) - object can be serialised cheaply through a MessageQueue. - """ +def _pack_diffusion_fields(output: DiffusionOutput) -> DiffusionOutput: if output.output is not None and isinstance(output.output, torch.Tensor): if output.output.nelement() * output.output.element_size() > _SHM_TENSOR_THRESHOLD: output.output = _tensor_to_shm(output.output) @@ -80,10 +75,35 @@ def pack_diffusion_output_shm(output: DiffusionOutput) -> DiffusionOutput: return output -def unpack_diffusion_output_shm(output: DiffusionOutput) -> DiffusionOutput: - """Reconstruct tensors from shared-memory handles produced by ``pack_diffusion_output_shm``.""" +def pack_diffusion_output_shm(output: object) -> object: + """Replace large tensors in diffusion worker outputs with SHM handles. + + Supports either a bare ``DiffusionOutput`` or a wrapper object carrying one + in ``.result`` (for example ``RunnerOutput``). + """ + if isinstance(output, DiffusionOutput): + return _pack_diffusion_fields(output) + + result = getattr(output, "result", None) + if isinstance(result, DiffusionOutput): + output.result = _pack_diffusion_fields(result) + return output + + +def _unpack_diffusion_fields(output: DiffusionOutput) -> DiffusionOutput: if isinstance(output.output, dict) and output.output.get("__tensor_shm__"): output.output = _tensor_from_shm(output.output) if isinstance(output.trajectory_latents, dict) and output.trajectory_latents.get("__tensor_shm__"): output.trajectory_latents = _tensor_from_shm(output.trajectory_latents) return output + + +def unpack_diffusion_output_shm(output: object) -> object: + """Reconstruct tensors from SHM handles in diffusion worker outputs.""" + if isinstance(output, DiffusionOutput): + return _unpack_diffusion_fields(output) + + result = getattr(output, "result", None) + if isinstance(result, DiffusionOutput): + output.result = _unpack_diffusion_fields(result) + return output diff --git a/vllm_omni/diffusion/models/interface.py b/vllm_omni/diffusion/models/interface.py index ef90e980414..ef906472bd0 100644 --- a/vllm_omni/diffusion/models/interface.py +++ b/vllm_omni/diffusion/models/interface.py @@ -1,11 +1,21 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from __future__ import annotations + from typing import ( + TYPE_CHECKING, + Any, ClassVar, Protocol, runtime_checkable, ) +if TYPE_CHECKING: + import torch + + from vllm_omni.diffusion.data import DiffusionOutput + from vllm_omni.diffusion.worker.utils import DiffusionRequestState + @runtime_checkable class SupportImageInput(Protocol): @@ -21,3 +31,34 @@ class SupportAudioInput(Protocol): @runtime_checkable class SupportAudioOutput(Protocol): support_audio_output: ClassVar[bool] = True + + +@runtime_checkable +class SupportsStepExecution(Protocol): + """State-driven step-level execution protocol for diffusion pipelines. + + Pipelines should split request-level ``forward()`` into: + ``prepare_encode()`` (one-time request setup), ``denoise_step()`` + (one denoise forward), ``step_scheduler()`` (one scheduler update), + and ``post_decode()`` (final decode). + """ + + supports_step_execution: ClassVar[bool] = True + + def prepare_encode(self, state: DiffusionRequestState, **kwargs: Any) -> DiffusionRequestState: + """Prepare request-level inputs and return initialized state.""" + + def denoise_step(self, state: DiffusionRequestState, **kwargs: Any) -> torch.Tensor | None: + """Run one denoise step.""" + + def step_scheduler(self, state: DiffusionRequestState, noise_pred: torch.Tensor, **kwargs: Any) -> None: + """Run one scheduler step.""" + + def post_decode(self, state: DiffusionRequestState, **kwargs: Any) -> DiffusionOutput: + """Decode output after denoise loop.""" + + +def supports_step_execution(pipeline: object) -> bool: + """Return whether `pipeline` implements :class:`SupportsStepExecution`.""" + + return isinstance(pipeline, SupportsStepExecution) diff --git a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py index 8ec3d6b9fd5..49e6f780bc6 100644 --- a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py +++ b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py @@ -1,13 +1,14 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import copy import inspect import json import logging import math import os from collections.abc import Iterable -from typing import Any +from typing import TYPE_CHECKING, Any, ClassVar import numpy as np import torch @@ -35,6 +36,10 @@ from vllm_omni.diffusion.quantization import get_vllm_quant_config_for_layers from vllm_omni.diffusion.request import OmniDiffusionRequest from vllm_omni.diffusion.utils.tf_utils import get_transformer_config_kwargs + +if TYPE_CHECKING: + from vllm_omni.diffusion.worker.utils import DiffusionRequestState + from vllm_omni.model_executor.model_loader.weight_utils import ( download_weights_from_hf_specific, ) @@ -238,6 +243,8 @@ def apply_rotary_emb_qwen( class QwenImagePipeline(nn.Module, QwenImageCFGParallelMixin, DiffusionPipelineProfilerMixin): + supports_step_execution: ClassVar[bool] = True + def __init__( self, *, @@ -539,57 +546,44 @@ def current_timestep(self): def interrupt(self): return self._interrupt - def forward( - self, - req: OmniDiffusionRequest, - prompt: str | list[str] | None = None, - negative_prompt: str | list[str] | None = None, - true_cfg_scale: float = 4.0, - height: int | None = None, - width: int | None = None, - num_inference_steps: int = 50, - sigmas: list[float] | None = None, - guidance_scale: float = 1.0, - num_images_per_prompt: int = 1, - generator: torch.Generator | list[torch.Generator] | None = None, - latents: torch.Tensor | None = None, - prompt_embeds: torch.Tensor | None = None, - prompt_embeds_mask: torch.Tensor | None = None, - negative_prompt_embeds: torch.Tensor | None = None, - negative_prompt_embeds_mask: torch.Tensor | None = None, - output_type: str | None = "pil", - attention_kwargs: dict[str, Any] | None = None, - callback_on_step_end_tensor_inputs: list[str] = ["latents"], - max_sequence_length: int = 512, - ) -> DiffusionOutput: - # TODO: In online mode, sometimes it receives [{"negative_prompt": None}, {...}], so cannot use .get("...", "") - # TODO: May be some data formatting operations on the API side. Hack for now. - prompt = [p if isinstance(p, str) else (p.get("prompt") or "") for p in req.prompts] or prompt - if all(isinstance(p, str) or p.get("negative_prompt") is None for p in req.prompts): + def _extract_prompts(self, prompts): + """Extract prompt and negative_prompt from OmniPromptType list.""" + prompt = [p if isinstance(p, str) else (p.get("prompt") or "") for p in prompts] or None + if all(isinstance(p, str) or p.get("negative_prompt") is None for p in prompts): + negative_prompt = None + elif prompts: + negative_prompt = ["" if isinstance(p, str) else (p.get("negative_prompt") or "") for p in prompts] + else: negative_prompt = None - elif req.prompts: - negative_prompt = ["" if isinstance(p, str) else (p.get("negative_prompt") or "") for p in req.prompts] + return prompt, negative_prompt - height = req.sampling_params.height or self.default_sample_size * self.vae_scale_factor - width = req.sampling_params.width or self.default_sample_size * self.vae_scale_factor - num_inference_steps = req.sampling_params.num_inference_steps or num_inference_steps - sigmas = req.sampling_params.sigmas or sigmas - max_sequence_length = req.sampling_params.max_sequence_length or max_sequence_length - generator = req.sampling_params.generator or generator - true_cfg_scale = req.sampling_params.true_cfg_scale or true_cfg_scale - if req.sampling_params.guidance_scale_provided: - guidance_scale = req.sampling_params.guidance_scale - num_images_per_prompt = ( - req.sampling_params.num_outputs_per_prompt - if req.sampling_params.num_outputs_per_prompt > 0 - else num_images_per_prompt - ) - # 1. check inputs - # 2. encode prompts - # 3. prepare latents and timesteps - # 4. diffusion process - # 5. decode latents - # 6. post-process outputs + def _prepare_generation_context( + self, + *, + prompt, + negative_prompt, + height, + width, + num_inference_steps, + sigmas, + guidance_scale, + num_images_per_prompt, + generator, + true_cfg_scale, + max_sequence_length, + prompt_embeds=None, + prompt_embeds_mask=None, + negative_prompt_embeds=None, + negative_prompt_embeds_mask=None, + latents=None, + attention_kwargs=None, + callback_on_step_end_tensor_inputs=None, + ): + """Shared preparation logic for forward() and prepare_encode(). + + Validates inputs, encodes prompts, prepares latents, computes timesteps, + and returns all intermediate values as a dict. + """ self.check_inputs( prompt, height, @@ -604,7 +598,7 @@ def forward( ) self._guidance_scale = guidance_scale - self._attention_kwargs = attention_kwargs + self._attention_kwargs = attention_kwargs or {} self._current_timestep = None self._interrupt = False @@ -612,13 +606,14 @@ def forward( batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) - else: + elif prompt_embeds is not None: batch_size = prompt_embeds.shape[0] + else: + batch_size = 1 has_neg_prompt = negative_prompt is not None or ( negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None ) - do_true_cfg = true_cfg_scale > 1 and has_neg_prompt self.check_cfg_parallel_validity(true_cfg_scale, has_neg_prompt) @@ -637,6 +632,9 @@ def forward( num_images_per_prompt=num_images_per_prompt, max_sequence_length=max_sequence_length, ) + else: + negative_prompt_embeds = None + negative_prompt_embeds_mask = None num_channels_latents = self.transformer.in_channels // 4 latents = self.prepare_latents( @@ -649,48 +647,345 @@ def forward( generator, latents, ) - img_shapes = [ - [ - ( - 1, - height // self.vae_scale_factor // 2, - width // self.vae_scale_factor // 2, - ) - ] - ] * batch_size - - timesteps, num_inference_steps = self.prepare_timesteps(num_inference_steps, sigmas, latents.shape[1]) - # num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + img_shapes = [[(1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2)]] * batch_size + + timesteps, num_inference_steps = self.prepare_timesteps( + num_inference_steps, + sigmas, + latents.shape[1], + ) self._num_timesteps = len(timesteps) - # handle guidance if self.transformer.guidance_embeds: guidance = torch.full([1], guidance_scale, dtype=torch.float32) guidance = guidance.expand(latents.shape[0]) else: guidance = None - if self.attention_kwargs is None: - self._attention_kwargs = {} - txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None negative_txt_seq_lens = ( negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None ) - # print inputp params + + return { + "prompt_embeds": prompt_embeds, + "prompt_embeds_mask": prompt_embeds_mask, + "negative_prompt_embeds": negative_prompt_embeds, + "negative_prompt_embeds_mask": negative_prompt_embeds_mask, + "latents": latents, + "img_shapes": img_shapes, + "timesteps": timesteps, + "do_true_cfg": do_true_cfg, + "guidance": guidance, + "txt_seq_lens": txt_seq_lens, + "negative_txt_seq_lens": negative_txt_seq_lens, + } + + def prepare_encode( + self, + state: "DiffusionRequestState", + **kwargs: Any, + ) -> "DiffusionRequestState": + """Populate *state* with encoded prompts, latents, timesteps, and CFG config.""" + sampling = state.sampling + prompt, negative_prompt = self._extract_prompts(state.prompts or []) + + ctx = self._prepare_generation_context( + prompt=prompt, + negative_prompt=negative_prompt, + height=sampling.height or self.default_sample_size * self.vae_scale_factor, + width=sampling.width or self.default_sample_size * self.vae_scale_factor, + num_inference_steps=sampling.num_inference_steps or 50, + sigmas=sampling.sigmas, + guidance_scale=sampling.guidance_scale if sampling.guidance_scale_provided else 1.0, + num_images_per_prompt=sampling.num_outputs_per_prompt if sampling.num_outputs_per_prompt > 0 else 1, + generator=sampling.generator, + true_cfg_scale=sampling.true_cfg_scale or 4.0, + max_sequence_length=sampling.max_sequence_length or 512, + attention_kwargs=kwargs.get("attention_kwargs"), + ) + + # prepare_timesteps() has already materialized request-specific timestep + # state on self.scheduler, so deepcopy preserves dynamic-shifting state + # without replaying set_timesteps() on the per-request scheduler. + # Per-request scheduler (must not share state with self.scheduler) + req_scheduler = copy.deepcopy(self.scheduler) + req_scheduler.set_begin_index(0) + + # Populate state from generation context + state.prompt_embeds = ctx["prompt_embeds"] + state.prompt_embeds_mask = ctx["prompt_embeds_mask"] + state.negative_prompt_embeds = ctx["negative_prompt_embeds"] + state.negative_prompt_embeds_mask = ctx["negative_prompt_embeds_mask"] + state.latents = ctx["latents"] + state.timesteps = ctx["timesteps"] + state.step_index = 0 + state.scheduler = req_scheduler + state.do_true_cfg = ctx["do_true_cfg"] + state.guidance = ctx["guidance"] + state.img_shapes = ctx["img_shapes"] + state.txt_seq_lens = ctx["txt_seq_lens"] + state.negative_txt_seq_lens = ctx["negative_txt_seq_lens"] + # QwenImage always normalizes CFG output (matching forward()) + state.sampling.cfg_normalize = True + + return state + + def _build_denoise_kwargs( + self, + latents: torch.Tensor, + timestep: torch.Tensor, + guidance: torch.Tensor | None, + prompt_embeds: torch.Tensor, + prompt_embeds_mask: torch.Tensor, + img_shapes: list, + txt_seq_lens: list[int] | None, + do_true_cfg: bool, + negative_prompt_embeds: torch.Tensor | None, + negative_prompt_embeds_mask: torch.Tensor | None, + negative_txt_seq_lens: list[int] | None, + image_latents: torch.Tensor | None = None, + extra_transformer_kwargs: dict[str, Any] | None = None, + ) -> tuple[dict[str, Any], dict[str, Any] | None, int | None]: + """Build positive/negative kwargs and output_slice for one denoise step. + + Returns: + (positive_kwargs, negative_kwargs, output_slice) + """ + extra_transformer_kwargs = extra_transformer_kwargs or {} + + # Broadcast timestep to match batch size + t_for_model = timestep.expand(latents.shape[0]).to( + device=latents.device, + dtype=latents.dtype, + ) + + # Concatenate image latents if available (editing pipelines) + latent_model_input = latents + if image_latents is not None: + latent_model_input = torch.cat([latents, image_latents], dim=1) + + positive_kwargs = { + "hidden_states": latent_model_input, + "timestep": t_for_model / 1000, + "guidance": guidance, + "encoder_hidden_states_mask": prompt_embeds_mask, + "encoder_hidden_states": prompt_embeds, + "img_shapes": img_shapes, + "txt_seq_lens": txt_seq_lens, + **extra_transformer_kwargs, + } + if do_true_cfg: + negative_kwargs = { + "hidden_states": latent_model_input, + "timestep": t_for_model / 1000, + "guidance": guidance, + "encoder_hidden_states_mask": negative_prompt_embeds_mask, + "encoder_hidden_states": negative_prompt_embeds, + "img_shapes": img_shapes, + "txt_seq_lens": negative_txt_seq_lens, + **extra_transformer_kwargs, + } + else: + negative_kwargs = None + + output_slice = latents.size(1) if image_latents is not None else None + return positive_kwargs, negative_kwargs, output_slice + + def _decode_latents( + self, + latents: torch.Tensor, + height: int, + width: int, + output_type: str = "pil", + ) -> DiffusionOutput: + """Unpack, normalize, and VAE-decode latents into a DiffusionOutput.""" + if output_type == "latent": + return DiffusionOutput( + output=latents, + stage_durations=self.stage_durations if hasattr(self, "stage_durations") else None, + ) + + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = latents.to(self.vae.dtype) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + latents = latents / latents_std + latents_mean + image = self.vae.decode(latents, return_dict=False)[0][:, :, 0] + return DiffusionOutput( + output=image, + stage_durations=self.stage_durations if hasattr(self, "stage_durations") else None, + ) + + def denoise_step( + self, + state: "DiffusionRequestState", + **kwargs: Any, + ) -> torch.Tensor | None: + """One denoise step: read from *state*, delegate to CFGParallelMixin. + + Reuses ``predict_noise_maybe_with_cfg`` so that CFG-parallel, + sequential-CFG, and no-CFG paths are handled identically to + ``diffuse()``. + """ + if self.interrupt: + return None + + t = state.current_timestep + self._current_timestep = t + self.transformer.do_true_cfg = state.do_true_cfg + + # Normalize timestep to [batch_size] tensor + if not torch.is_tensor(t): + t = torch.tensor(t, device=state.latents.device, dtype=state.latents.dtype) + + positive_kwargs, negative_kwargs, output_slice = self._build_denoise_kwargs( + latents=state.latents, + timestep=t, + guidance=state.guidance, + prompt_embeds=state.prompt_embeds, + prompt_embeds_mask=state.prompt_embeds_mask, + img_shapes=state.img_shapes, + txt_seq_lens=state.txt_seq_lens, + do_true_cfg=state.do_true_cfg, + negative_prompt_embeds=state.negative_prompt_embeds, + negative_prompt_embeds_mask=state.negative_prompt_embeds_mask, + negative_txt_seq_lens=state.negative_txt_seq_lens, + image_latents=state.sampling.image_latent, + extra_transformer_kwargs={ + "attention_kwargs": self.attention_kwargs, + "return_dict": False, + }, + ) + + true_cfg_scale = state.sampling.true_cfg_scale or 4.0 + cfg_normalize = state.sampling.cfg_normalize + + return self.predict_noise_maybe_with_cfg( + state.do_true_cfg, + true_cfg_scale, + positive_kwargs, + negative_kwargs, + cfg_normalize, + output_slice, + ) + + def step_scheduler( + self, + state: "DiffusionRequestState", + noise_pred: torch.Tensor, + **kwargs: Any, + ) -> None: + """One scheduler step: update ``state.latents`` and advance ``step_index``.""" + if self.interrupt: + return + + t = state.current_timestep + state.latents = self.scheduler_step_maybe_with_cfg( + noise_pred, + t, + state.latents, + state.do_true_cfg, + per_request_scheduler=state.scheduler, + ) + + state.step_index += 1 + + def post_decode( + self, + state: "DiffusionRequestState", + **kwargs: Any, + ) -> DiffusionOutput: + """Decode final latents from *state*.""" + self._current_timestep = None + + height = state.sampling.height or self.default_sample_size * self.vae_scale_factor + width = state.sampling.width or self.default_sample_size * self.vae_scale_factor + output_type = kwargs.get("output_type", "pil") + + return self._decode_latents(state.latents, height, width, output_type) + + def forward( + self, + req: OmniDiffusionRequest, + prompt: str | list[str] | None = None, + negative_prompt: str | list[str] | None = None, + true_cfg_scale: float = 4.0, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + sigmas: list[float] | None = None, + guidance_scale: float = 1.0, + num_images_per_prompt: int = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + prompt_embeds_mask: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds_mask: torch.Tensor | None = None, + output_type: str | None = "pil", + attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 512, + ) -> DiffusionOutput: + extracted_prompt, negative_prompt = self._extract_prompts(req.prompts) + prompt = extracted_prompt or prompt + + height = req.sampling_params.height or self.default_sample_size * self.vae_scale_factor + width = req.sampling_params.width or self.default_sample_size * self.vae_scale_factor + num_inference_steps = req.sampling_params.num_inference_steps or num_inference_steps + sigmas = req.sampling_params.sigmas or sigmas + max_sequence_length = req.sampling_params.max_sequence_length or max_sequence_length + generator = req.sampling_params.generator or generator + true_cfg_scale = req.sampling_params.true_cfg_scale or true_cfg_scale + if req.sampling_params.guidance_scale_provided: + guidance_scale = req.sampling_params.guidance_scale + num_images_per_prompt = ( + req.sampling_params.num_outputs_per_prompt + if req.sampling_params.num_outputs_per_prompt > 0 + else num_images_per_prompt + ) + + ctx = self._prepare_generation_context( + prompt=prompt, + negative_prompt=negative_prompt, + height=height, + width=width, + num_inference_steps=num_inference_steps, + sigmas=sigmas, + guidance_scale=guidance_scale, + num_images_per_prompt=num_images_per_prompt, + generator=generator, + true_cfg_scale=true_cfg_scale, + max_sequence_length=max_sequence_length, + prompt_embeds=prompt_embeds, + prompt_embeds_mask=prompt_embeds_mask, + negative_prompt_embeds=negative_prompt_embeds, + negative_prompt_embeds_mask=negative_prompt_embeds_mask, + latents=latents, + attention_kwargs=attention_kwargs, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + ) latents = self.diffuse( - prompt_embeds, - prompt_embeds_mask, - negative_prompt_embeds, - negative_prompt_embeds_mask, - latents, - img_shapes, - txt_seq_lens, - negative_txt_seq_lens, - timesteps, - do_true_cfg, - guidance, + ctx["prompt_embeds"], + ctx["prompt_embeds_mask"], + ctx["negative_prompt_embeds"], + ctx["negative_prompt_embeds_mask"], + ctx["latents"], + ctx["img_shapes"], + ctx["txt_seq_lens"], + ctx["negative_txt_seq_lens"], + ctx["timesteps"], + ctx["do_true_cfg"], + ctx["guidance"], true_cfg_scale, image_latents=None, cfg_normalize=True, @@ -701,26 +996,7 @@ def forward( ) self._current_timestep = None - if output_type == "latent": - image = latents - else: - latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) - latents = latents.to(self.vae.dtype) - latents_mean = ( - torch.tensor(self.vae.config.latents_mean) - .view(1, self.vae.config.z_dim, 1, 1, 1) - .to(latents.device, latents.dtype) - ) - latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( - latents.device, latents.dtype - ) - latents = latents / latents_std + latents_mean - image = self.vae.decode(latents, return_dict=False)[0][:, :, 0] - # processed_image = self.image_processor.postprocess(image, output_type=output_type) - - return DiffusionOutput( - output=image, stage_durations=self.stage_durations if hasattr(self, "stage_durations") else None - ) + return self._decode_latents(latents, height, width, output_type) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) diff --git a/vllm_omni/diffusion/sched/interface.py b/vllm_omni/diffusion/sched/interface.py new file mode 100644 index 00000000000..d8b3a39c54b --- /dev/null +++ b/vllm_omni/diffusion/sched/interface.py @@ -0,0 +1,54 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from __future__ import annotations + +import enum +from dataclasses import dataclass + +from vllm_omni.diffusion.request import OmniDiffusionRequest + + +class DiffusionRequestStatus(enum.IntEnum): + """Request status tracked by diffusion scheduler.""" + + WAITING = enum.auto() + RUNNING = enum.auto() + PREEMPTED = enum.auto() + + # if any status is after or equal to FINISHED_COMPLETED, it is considered finished + FINISHED_COMPLETED = enum.auto() + FINISHED_ABORTED = enum.auto() + FINISHED_ERROR = enum.auto() + + @staticmethod + def is_finished(status: DiffusionRequestStatus) -> bool: + return status >= DiffusionRequestStatus.FINISHED_COMPLETED + + +@dataclass +class DiffusionRequestState: + """Scheduler-owned state for one queued OmniDiffusionRequest.""" + + sched_req_id: str + req: OmniDiffusionRequest + status: DiffusionRequestStatus = DiffusionRequestStatus.WAITING + error: str | None = None + + def is_finished(self) -> bool: + return DiffusionRequestStatus.is_finished(self.status) + + +@dataclass +class DiffusionSchedulerOutput: + """Output of a single scheduling cycle. + + Kept intentionally small so step-execution components can share a stable + transport shape while scheduler policy continues to evolve. + """ + + step_id: int + req_states: list[DiffusionRequestState] + finished_req_ids: set[str] + num_running_reqs: int + num_waiting_reqs: int diff --git a/vllm_omni/diffusion/worker/diffusion_model_runner.py b/vllm_omni/diffusion/worker/diffusion_model_runner.py index 972c95c292c..e3c27f94545 100644 --- a/vllm_omni/diffusion/worker/diffusion_model_runner.py +++ b/vllm_omni/diffusion/worker/diffusion_model_runner.py @@ -10,6 +10,7 @@ from __future__ import annotations +import copy import time from collections.abc import Iterable from contextlib import nullcontext @@ -26,9 +27,12 @@ from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig from vllm_omni.diffusion.forward_context import set_forward_context from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader +from vllm_omni.diffusion.models.interface import supports_step_execution from vllm_omni.diffusion.offloader import get_offload_backend from vllm_omni.diffusion.registry import _NO_CACHE_ACCELERATION from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.diffusion.sched.interface import DiffusionSchedulerOutput +from vllm_omni.diffusion.worker.utils import DiffusionRequestState, RunnerOutput from vllm_omni.distributed.omni_connectors.kv_transfer_manager import OmniKVTransferManager from vllm_omni.platforms import current_omni_platform @@ -65,6 +69,9 @@ def __init__( self.cache_backend = None self.offload_backend = None + # Cache for per-request stepwise state. + self.state_cache: dict[str, DiffusionRequestState] = {} + # Initialize KV cache manager for connector management self.kv_transfer_manager = OmniKVTransferManager.from_od_config(od_config) @@ -138,6 +145,13 @@ def get_memory_context(): ) logger.info("Model runner: Model loaded successfully.") + if getattr(self.od_config, "step_execution", False) and not self.supports_step_mode(): + raise ValueError( + "step_execution=True requires a pipeline implementing " + "prepare_encode(), denoise_step(), step_scheduler(), and post_decode(); " + f"{self.od_config.model_class_name} does not support that contract." + ) + # Apply CPU offloading self.offload_backend = get_offload_backend(self.od_config, device=self.device) if self.offload_backend is not None: @@ -239,3 +253,106 @@ def execute_model(self, req: OmniDiffusionRequest) -> DiffusionOutput: cache_summary(self.pipeline, details=True) return output + + # ------------------------------------------------------------------ + # Step-wise execution + # ------------------------------------------------------------------ + + def supports_step_mode(self) -> bool: + """Return whether current pipeline supports step execution.""" + return self.pipeline is not None and supports_step_execution(self.pipeline) + + def _update_states(self, scheduler_output: DiffusionSchedulerOutput) -> DiffusionRequestState: + """Step-before update: cleanup finished requests and get/create one running state.""" + for req_id in scheduler_output.finished_req_ids: + self.state_cache.pop(req_id, None) + + req_states = scheduler_output.req_states + if len(req_states) != 1: + raise ValueError(f"Step mode currently supports batch_size=1, but got {len(req_states)} req_states.") + + # TODO: remove req state from SchedulerOutput + # Stepwise mode currently trusts runner-owned cached state more than + # re-validating scheduler-provided request content on every step. + sched_req_state = req_states[0] + req_id = sched_req_state.sched_req_id + if req_id in self.state_cache: + return self.state_cache[req_id] + + req = sched_req_state.req + request_ids = req.request_ids or [req_id] + if len(request_ids) != len(req.prompts): + raise ValueError( + f"request_ids length ({len(request_ids)}) does not match prompts length ({len(req.prompts)})" + ) + + state = DiffusionRequestState( + req_id=req_id, + sampling=copy.deepcopy(req.sampling_params), + prompts=req.prompts, + ) + self.state_cache[req_id] = state + return state + + def _update_states_after(self, state: DiffusionRequestState, finished: bool) -> None: + """Step-after update: clear cached state for completed request.""" + if finished: + self.state_cache.pop(state.req_id, None) + + def execute_stepwise(self, scheduler_output: DiffusionSchedulerOutput) -> RunnerOutput: + """Execute one step for one scheduled request and return runner output.""" + assert self.pipeline is not None, "Model not loaded. Call load_model() first." + if not self.supports_step_mode(): + raise ValueError("Current pipeline does not support step execution.") + # Stepwise mode only supports the basic state-driven denoise path for now. + # Request-mode extras such as cache backends, KV transfer, editing inputs, + # and similar features are not supported here yet. + if self.od_config.cache_backend not in (None, "none"): + raise ValueError("Step mode does not support cache_backend yet.") + + use_hsdp = self.od_config.parallel_config.use_hsdp + grad_context = torch.no_grad() if use_hsdp else torch.inference_mode() + with grad_context: + state = self._update_states(scheduler_output) + + if state.new_request: + # TODO: support kv manager recv + # TODO: support cache backend + if state.sampling.generator is None and state.sampling.seed is not None: + if state.sampling.generator_device is not None: + gen_device = state.sampling.generator_device + elif self.device.type == "cpu": + gen_device = "cpu" + else: + gen_device = self.device + state.sampling.generator = torch.Generator(device=gen_device).manual_seed(state.sampling.seed) + + with set_forward_context(vllm_config=self.vllm_config, omni_diffusion_config=self.od_config): + # step0/new request: encode + if state.new_request: + self.pipeline.prepare_encode(state) + + noise_pred = self.pipeline.denoise_step(state) + finished = False + + # In CFG parallel mode, only rank 0 gets the actual noise_pred; non-rank-0 workers receive None. + # A true interrupt (all ranks return None) is detected by checking self.pipeline.interrupt. + if noise_pred is None and getattr(self.pipeline, "interrupt", False): + finished = True + result = DiffusionOutput(error="stepwise denoise interrupted") + else: + self.pipeline.step_scheduler(state, noise_pred) + finished = state.denoise_completed + if finished: + result = self.pipeline.post_decode(state) + else: + result = None + + self._update_states_after(state, finished) + + return RunnerOutput( + req_id=state.req_id, + step_index=state.step_index, + finished=finished, + result=result, + ) diff --git a/vllm_omni/diffusion/worker/diffusion_worker.py b/vllm_omni/diffusion/worker/diffusion_worker.py index 1b9dc56673e..3ca12d8e6a2 100644 --- a/vllm_omni/diffusion/worker/diffusion_worker.py +++ b/vllm_omni/diffusion/worker/diffusion_worker.py @@ -38,7 +38,9 @@ from vllm_omni.diffusion.lora.manager import DiffusionLoRAManager from vllm_omni.diffusion.profiler import CurrentProfiler from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.diffusion.sched.interface import DiffusionSchedulerOutput from vllm_omni.diffusion.worker.diffusion_model_runner import DiffusionModelRunner +from vllm_omni.diffusion.worker.utils import RunnerOutput from vllm_omni.lora.request import LoRARequest from vllm_omni.platforms import current_omni_platform from vllm_omni.worker.gpu_memory_utils import get_process_gpu_memory @@ -194,6 +196,19 @@ def execute_model(self, req: OmniDiffusionRequest, od_config: OmniDiffusionConfi logger.warning("LoRA activation skipped: %s", exc) return self.model_runner.execute_model(req) + def execute_stepwise(self, scheduler_output: DiffusionSchedulerOutput) -> RunnerOutput: + """Execute one diffusion step by delegating to the model runner.""" + assert self.model_runner is not None, "Model runner not initialized" + if self.lora_manager is not None: + # Step mode does not support LoRA yet. Clear any previously active + # adapter first so worker-local LoRA state cannot leak in. + self.lora_manager.set_active_adapter(None) + + if any(req_state.req.sampling_params.lora_request is not None for req_state in scheduler_output.req_states): + raise ValueError("Step mode does not support LoRA yet.") + + return self.model_runner.execute_stepwise(scheduler_output) + def load_weights(self, weights) -> set[str]: """Load weights by delegating to the model runner.""" assert self.model_runner is not None, "Model runner not initialized" @@ -376,7 +391,7 @@ def _create_worker( ) return wrapper - def return_result(self, output: DiffusionOutput): + def return_result(self, output: object): """Reply to client, only on rank 0.""" if self.result_mq is not None: try: @@ -623,6 +638,10 @@ def execute_model(self, reqs: list[OmniDiffusionRequest], od_config: OmniDiffusi """ return self.worker.execute_model(reqs, od_config) + def execute_stepwise(self, scheduler_output: DiffusionSchedulerOutput) -> RunnerOutput: + """Execute one diffusion step.""" + return self.worker.execute_stepwise(scheduler_output) + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: """ Load model weights. diff --git a/vllm_omni/diffusion/worker/utils.py b/vllm_omni/diffusion/worker/utils.py new file mode 100644 index 00000000000..af88913e7df --- /dev/null +++ b/vllm_omni/diffusion/worker/utils.py @@ -0,0 +1,124 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Per-request mutable state for step-wise diffusion execution.""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any + +import torch + +if TYPE_CHECKING: + from vllm_omni.diffusion.data import DiffusionOutput + from vllm_omni.inputs.data import OmniDiffusionSamplingParams, OmniPromptType + + +@dataclass +class DiffusionRequestState: + """Per-request mutable state across all pipeline stages. + + Owned by Runner and passed through all step-execution stages: + ``prepare_encode()`` initializes/updates fields, ``denoise_step()`` and + ``step_scheduler()`` mutate per-step fields, and ``post_decode()`` + consumes final latents. This state object is also the cache unit for + future continuous batching. + + This dataclass keeps only the minimal cross-model state required by the + step-execution contract. Pipeline-specific state should be stored in + ``extra`` and promoted here only when it becomes shared across models. + + Examples: + - Wan-style pipelines may keep ``condition``, ``first_frame_mask``, or + ``image_embeds`` in ``extra``. + - Bagel-style pipelines may keep ``gen_context``, + ``cfg_text_context``, ``cfg_img_context``, or ``image_shape`` in + ``extra``. + """ + + # ── Identity / request-level inputs ── + req_id: str + sampling: OmniDiffusionSamplingParams + prompts: list[OmniPromptType] | None = None + + # ── Encoded prompts (set once by prepare_encode) ── + prompt_embeds: torch.Tensor | None = None + prompt_embeds_mask: torch.Tensor | None = None + negative_prompt_embeds: torch.Tensor | None = None + negative_prompt_embeds_mask: torch.Tensor | None = None + + # ── Latent state (mutated every step by step_scheduler) ── + latents: torch.Tensor | None = None + + # ── Timestep schedule (set once by prepare_encode) ── + timesteps: torch.Tensor | list[torch.Tensor] | None = None + step_index: int = 0 + + # ── Per-request scheduler instance (set once by prepare_encode) ── + scheduler: Any | None = None + + # ── CFG config (set once by prepare_encode) ── + do_true_cfg: bool = False + guidance: torch.Tensor | None = None + + # ── Spatial / sequence metadata (set once by prepare_encode) ── + img_shapes: list | None = None + txt_seq_lens: list[int] | None = None + negative_txt_seq_lens: list[int] | None = None + + # Pipeline-specific extras. Keep model-private fields here unless they + # become part of the shared step-execution contract. + # For example: Wan condition tensors / masks, or Bagel KV contexts. + extra: dict[str, Any] = field(default_factory=dict) + + # ── Properties ── + + @property + def current_timestep(self) -> torch.Tensor | None: + if self.timesteps is None: + return None + if self.step_index >= self.total_steps: + return None + if isinstance(self.timesteps, torch.Tensor): + if self.timesteps.ndim == 0: + return self.timesteps + return self.timesteps[self.step_index] + return self.timesteps[self.step_index] + + @property + def total_steps(self) -> int: + if self.timesteps is None: + return 0 + if isinstance(self.timesteps, torch.Tensor): + if self.timesteps.ndim == 0: + return 1 + return int(self.timesteps.shape[0]) + return len(self.timesteps) + + @property + def denoise_completed(self) -> bool: + total_steps = self.total_steps + if total_steps == 0: + return False + return self.step_index >= total_steps + + @property + def new_request(self) -> bool: + # TODO: this is only an approximation for current stepwise mode. + # A real "new request" signal should eventually come from scheduler/runner state transitions. + return self.step_index == 0 or self.timesteps is None + + +@dataclass +class RunnerOutput: + """Output of a single denoising step for a request. + + NOTE: `latents` may be None when returned through IPC to avoid + serialization overhead. The actual latents are kept in Worker's + _request_state_cache. + """ + + req_id: str + step_index: int | None = None + finished: bool = False + result: DiffusionOutput | None = None