diff --git a/docs/design/feature/cfg_parallel.md b/docs/design/feature/cfg_parallel.md index e31e64eddd..64decbe956 100644 --- a/docs/design/feature/cfg_parallel.md +++ b/docs/design/feature/cfg_parallel.md @@ -34,7 +34,7 @@ vLLM-omni provides `CFGParallelMixin` that encapsulates all CFG parallel logic. | Method | Purpose | Automatic Behavior | |--------|---------|-------------------| | [`predict_noise_maybe_with_cfg()`](https://docs.vllm.ai/projects/vllm-omni/en/latest/api/vllm_omni/diffusion/distributed/cfg_parallel/) | Predict noise with CFG | Detects parallel mode, distributes computation, gathers results | -| [`scheduler_step_maybe_with_cfg()`](https://docs.vllm.ai/projects/vllm-omni/en/latest/api/vllm_omni/diffusion/distributed/cfg_parallel/) | Step scheduler with sync | Rank 0 steps, broadcasts latents to all ranks | +| [`scheduler_step_maybe_with_cfg()`](https://docs.vllm.ai/projects/vllm-omni/en/latest/api/vllm_omni/diffusion/distributed/cfg_parallel/) | Step scheduler | All ranks step locally (no broadcast needed) | | [`combine_cfg_noise()`](https://docs.vllm.ai/projects/vllm-omni/en/latest/api/vllm_omni/diffusion/distributed/cfg_parallel/) | Combine positive/negative | Applies CFG formula with optional normalization | | [`predict_noise()`](https://docs.vllm.ai/projects/vllm-omni/en/latest/api/vllm_omni/diffusion/distributed/cfg_parallel/) | Forward pass wrapper | Override for custom transformer calls | | [`cfg_normalize_function()`](https://docs.vllm.ai/projects/vllm-omni/en/latest/api/vllm_omni/diffusion/distributed/cfg_parallel/) | Normalize CFG output | Override for custom normalization | @@ -47,7 +47,7 @@ vLLM-omni provides `CFGParallelMixin` that encapsulates all CFG parallel logic. 
- Rank 0 computes positive prompt prediction - Rank 1 computes negative prompt prediction - Results are gathered via `all_gather()` - - Combined on rank 0 using CFG formula + - All ranks compute CFG combine locally (deterministic, identical results) - **Sequential mode** (when `cfg_world_size == 1`): - Single rank computes both positive and negative predictions @@ -55,14 +55,7 @@ vLLM-omni provides `CFGParallelMixin` that encapsulates all CFG parallel logic. `scheduler_step_maybe_with_cfg()` ensures consistent latent states across all ranks: -- **CFG-Parallel mode**: - - Only rank 0 performs the scheduler step (applies noise prediction to update latents) - - Updated latents are broadcast to all other ranks via `broadcast()` - - All ranks maintain synchronized latent states for the next iteration - -- **Sequential mode**: - - Single rank directly performs the scheduler step - - No synchronization needed +- All ranks compute the scheduler step locally — no broadcast needed because `predict_noise_maybe_with_cfg` already ensures all ranks have identical noise predictions after `all_gather` + local combine. --- @@ -177,6 +170,70 @@ class LongCatImagePipeline(nn.Module, CFGParallelMixin): # return noise_pred ``` + +### Override `combine_cfg_noise()` for Multi-Output Models + +When `predict_noise()` returns a tuple (e.g., video + audio), the default `combine_cfg_noise()` applies CFG to every element. 
Override it to apply different logic per element — for example, CFG on video but positive-only on audio: + +```python +class MyVideoAudioPipeline(nn.Module, CFGParallelMixin): + def combine_cfg_noise(self, positive_noise_pred, negative_noise_pred, scale, normalize): + (video_pos, audio_pos) = positive_noise_pred + (video_neg, audio_neg) = negative_noise_pred + video_combined = super().combine_cfg_noise(video_pos, video_neg, scale, normalize) + return (video_combined, audio_pos) # audio: positive only, no CFG +``` + +This also requires `predict_noise()` to return a tuple (see [Override predict_noise](#override-predict_noise-for-custom-transformer-calls) above). + +### Implement a Composite Scheduler for Multi-Output Models + +When each output has its own denoising schedule, implement a composite scheduler that dispatches to per-output schedulers. Assign it to `self.scheduler` so the default `scheduler_step()` works without override. + +**Complete example (video + audio with separate schedulers and diffuse loop):** + +```python +class VideoAudioScheduler: + """Composite scheduler dispatching to video and audio schedulers.""" + def __init__(self, video_scheduler, audio_scheduler): + self.video_scheduler = video_scheduler + self.audio_scheduler = audio_scheduler + + def step(self, noise_pred, t, latents, return_dict=False, generator=None): + video_out = self.video_scheduler.step(noise_pred[0], t[0], latents[0], return_dict=False, generator=generator)[0] + audio_out = self.audio_scheduler.step(noise_pred[1], t[1], latents[1], return_dict=False, generator=generator)[0] + return ((video_out, audio_out),) + +class MyVideoAudioPipeline(nn.Module, CFGParallelMixin): + def __init__(self, ...): + self.scheduler = VideoAudioScheduler(video_sched, audio_sched) + + def predict_noise(self, **kwargs): + video_pred, audio_pred = self.transformer(**kwargs) + return (video_pred, audio_pred) + + def combine_cfg_noise(self, positive_noise_pred, negative_noise_pred, scale, normalize): + 
# ... (as above) + + def diffuse(self, video_latents, audio_latents, timesteps_video, timesteps_audio, ...): + for t_v, t_a in zip(timesteps_video, timesteps_audio): + positive_kwargs = {...} + negative_kwargs = {...} if do_true_cfg else None + + video_pred, audio_pred = self.predict_noise_maybe_with_cfg( + do_true_cfg=do_true_cfg, true_cfg_scale=self.guidance_scale, + positive_kwargs=positive_kwargs, negative_kwargs=negative_kwargs, + ) + video_latents, audio_latents = self.scheduler_step_maybe_with_cfg( + (video_pred, audio_pred), (t_v, t_a), + (video_latents, audio_latents), do_true_cfg=do_true_cfg, + generator=generator, + ) + return video_latents, audio_latents +``` + +> **Note:** If you use a non-deterministic scheduler (e.g., DDPM), explicitly pass `self.scheduler_step_maybe_with_cfg(..., generator=torch.Generator(device).manual_seed(seed))` with the same seed on every CFG rank, so that the scheduler step draws identical random samples across ranks. + --- ## Testing diff --git a/tests/diffusion/distributed/test_cfg_parallel.py b/tests/diffusion/distributed/test_cfg_parallel.py index 24e4559de3..79dbe9e6dd 100644 --- a/tests/diffusion/distributed/test_cfg_parallel.py +++ b/tests/diffusion/distributed/test_cfg_parallel.py @@ -209,12 +209,9 @@ def _test_cfg_parallel_worker( cfg_normalize=test_config["cfg_normalize"], ) - # Only rank 0 has valid output in CFG parallel mode - if cfg_rank == 0: - assert noise_pred is not None - result_queue.put(noise_pred.cpu()) - else: - assert noise_pred is None + # CFG parallel returns the combined prediction on every rank.
+ assert noise_pred is not None + result_queue.put((cfg_rank, noise_pred.cpu())) destroy_distributed_env() @@ -348,7 +345,18 @@ def test_predict_noise_maybe_with_cfg(cfg_parallel_size: int, dtype: torch.dtype # Get results from queues baseline_output = baseline_queue.get() - cfg_parallel_output = cfg_parallel_queue.get() + cfg_parallel_outputs = [cfg_parallel_queue.get() for _ in range(cfg_parallel_size)] + cfg_parallel_outputs.sort(key=lambda item: item[0]) + cfg_parallel_output = cfg_parallel_outputs[0][1] + + for cfg_rank, rank_output in cfg_parallel_outputs[1:]: + torch.testing.assert_close( + rank_output, + cfg_parallel_output, + rtol=0, + atol=0, + msg=f"CFG parallel ranks produced different outputs (rank 0 vs rank {cfg_rank})", + ) # Verify shapes match assert baseline_output.shape == cfg_parallel_output.shape, ( diff --git a/vllm_omni/diffusion/distributed/cfg_parallel.py b/vllm_omni/diffusion/distributed/cfg_parallel.py index 0743a00d4a..a8b0012f66 100644 --- a/vllm_omni/diffusion/distributed/cfg_parallel.py +++ b/vllm_omni/diffusion/distributed/cfg_parallel.py @@ -17,12 +17,38 @@ ) +def _wrap(pred: torch.Tensor | tuple[torch.Tensor, ...]) -> tuple[torch.Tensor, ...]: + """Normalize prediction to tuple form.""" + return pred if isinstance(pred, tuple) else (pred,) + + +def _unwrap(pred: tuple[torch.Tensor, ...]) -> torch.Tensor | tuple[torch.Tensor, ...]: + """Unwrap single-element tuple to plain tensor; keep multi-element as tuple.""" + return pred[0] if len(pred) == 1 else pred + + +def _slice_pred(pred: tuple[torch.Tensor, ...], output_slice: int) -> tuple[torch.Tensor, ...]: + """Slice each element along dim 1.""" + return tuple(p[:, :output_slice] for p in pred) + + class CFGParallelMixin(metaclass=ABCMeta): """ Base Mixin class for Diffusion pipelines providing shared CFG methods. All pipelines should inherit from this class to reuse classifier-free guidance logic. 
+ + CFG Parallel Architecture: + When cfg_world_size > 1, each rank computes one branch (positive or + negative), then all_gather exchanges results. All ranks then compute + the CFG combine and scheduler step locally — no broadcast needed + because the operations are deterministic. + + Multi-output models: + Models that return tuple from predict_noise() (e.g., video + audio) + should override combine_cfg_noise() to define per-element combine logic, + and set self.scheduler to a composite scheduler that handles tuples. """ def predict_noise_maybe_with_cfg( @@ -33,7 +59,7 @@ def predict_noise_maybe_with_cfg( negative_kwargs: dict[str, Any] | None, cfg_normalize: bool = True, output_slice: int | None = None, - ) -> torch.Tensor | None: + ) -> torch.Tensor | tuple[torch.Tensor, ...]: """ Predict noise with optional classifier-free guidance. @@ -43,57 +69,64 @@ def predict_noise_maybe_with_cfg( positive_kwargs: Kwargs for positive/conditional prediction negative_kwargs: Kwargs for negative/unconditional prediction cfg_normalize: Whether to normalize CFG output (default: True) - output_slice: If set, slice output to [:, :output_slice] for image editing + output_slice: If set, slice each output to [:, :output_slice] for image editing Returns: - Predicted noise tensor (only valid on rank 0 in CFG parallel mode) + Predicted noise tensor or tuple of tensors. + In CFG parallel mode, result is valid on ALL ranks (not just rank 0). + + Note: + For multi-output models (e.g., video + audio where predict_noise + returns a tuple), override combine_cfg_noise() for per-element CFG + logic and set self.scheduler to a composite scheduler. """ if do_true_cfg: # Automatically detect CFG parallel configuration cfg_parallel_ready = get_classifier_free_guidance_world_size() > 1 if cfg_parallel_ready: - # Enable CFG-parallel: rank0 computes positive, rank1 computes negative. 
cfg_group = get_cfg_group() cfg_rank = get_classifier_free_guidance_rank() - if cfg_rank == 0: - local_pred = self.predict_noise(**positive_kwargs) - else: - local_pred = self.predict_noise(**negative_kwargs) + # Each rank computes one branch + kwargs = positive_kwargs if cfg_rank == 0 else negative_kwargs + local_pred = _wrap(self.predict_noise(**kwargs)) - # Slice output for image editing pipelines (remove condition latents) if output_slice is not None: - local_pred = local_pred[:, :output_slice] - - gathered = cfg_group.all_gather(local_pred, separate_tensors=True) - - if cfg_rank == 0: - noise_pred = gathered[0] - neg_noise_pred = gathered[1] - noise_pred = self.combine_cfg_noise(noise_pred, neg_noise_pred, true_cfg_scale, cfg_normalize) - return noise_pred - else: - return None + local_pred = _slice_pred(local_pred, output_slice) + + # All-gather each element, reconstruct positive/negative tuples + gathered = [cfg_group.all_gather(p, separate_tensors=True) for p in local_pred] + positive_noise_pred = tuple(g[0] for g in gathered) + negative_noise_pred = tuple(g[1] for g in gathered) + + # All ranks compute combine (deterministic, same result) + return self.combine_cfg_noise( + positive_noise_pred, + negative_noise_pred, + true_cfg_scale, + cfg_normalize, + ) else: # Sequential CFG: compute both positive and negative - positive_noise_pred = self.predict_noise(**positive_kwargs) - negative_noise_pred = self.predict_noise(**negative_kwargs) + positive_noise_pred = _wrap(self.predict_noise(**positive_kwargs)) + negative_noise_pred = _wrap(self.predict_noise(**negative_kwargs)) - # Slice output for image editing pipelines if output_slice is not None: - positive_noise_pred = positive_noise_pred[:, :output_slice] - negative_noise_pred = negative_noise_pred[:, :output_slice] - - noise_pred = self.combine_cfg_noise( - positive_noise_pred, negative_noise_pred, true_cfg_scale, cfg_normalize + positive_noise_pred = _slice_pred(positive_noise_pred, output_slice) + 
negative_noise_pred = _slice_pred(negative_noise_pred, output_slice) + + return self.combine_cfg_noise( + positive_noise_pred, + negative_noise_pred, + true_cfg_scale, + cfg_normalize, ) - return noise_pred else: # No CFG: only compute positive/conditional prediction pred = self.predict_noise(**positive_kwargs) if output_slice is not None: - pred = pred[:, :output_slice] + pred = _unwrap(_slice_pred(_wrap(pred), output_slice)) return pred def cfg_normalize_function(self, noise_pred: torch.Tensor, comb_pred: torch.Tensor) -> torch.Tensor: @@ -113,35 +146,61 @@ def cfg_normalize_function(self, noise_pred: torch.Tensor, comb_pred: torch.Tens return noise_pred def combine_cfg_noise( - self, noise_pred: torch.Tensor, neg_noise_pred: torch.Tensor, true_cfg_scale: float, cfg_normalize: bool = False - ) -> torch.Tensor: + self, + positive_noise_pred: torch.Tensor | tuple[torch.Tensor, ...], + negative_noise_pred: torch.Tensor | tuple[torch.Tensor, ...], + true_cfg_scale: float, + cfg_normalize: bool = False, + ) -> torch.Tensor | tuple[torch.Tensor, ...]: """ Combine conditional and unconditional noise predictions with CFG. + Accepts both plain tensors (backward-compatible, used by LTX2 etc.) + and tuples (multi-output models). Default implementation applies the + standard CFG formula to every element. + + Multi-output models can override this to apply different logic per element. 
+ + Example override for a model returning (video_pred, audio_pred):: + + def combine_cfg_noise(self, positive_noise_pred, negative_noise_pred, scale, normalize): + (video_pos, audio_pos) = positive_noise_pred + (video_neg, audio_neg) = negative_noise_pred + video_combined = super().combine_cfg_noise(video_pos, video_neg, scale, normalize) + return (video_combined, audio_pos) # audio: positive only, no CFG + Args: - noise_pred: Conditional noise prediction - neg_noise_pred: Unconditional noise prediction + positive_noise_pred: Positive/conditional prediction(s) — Tensor or tuple + negative_noise_pred: Negative/unconditional prediction(s) — Tensor or tuple true_cfg_scale: CFG scale factor cfg_normalize: Whether to normalize the combined prediction (default: False) Returns: - Combined noise prediction tensor + Combined noise prediction(s) — same type as inputs """ - comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) - - if cfg_normalize: - noise_pred = self.cfg_normalize_function(noise_pred, comb_pred) - else: - noise_pred = comb_pred - - return noise_pred - - def predict_noise(self, *args: Any, **kwargs: Any) -> torch.Tensor: + pos_t = _wrap(positive_noise_pred) + neg_t = _wrap(negative_noise_pred) + + results = [] + for p, n in zip(pos_t, neg_t): + comb = n + true_cfg_scale * (p - n) + if cfg_normalize: + comb = self.cfg_normalize_function(p, comb) + results.append(comb) + return _unwrap(tuple(results)) + + def predict_noise(self, *args: Any, **kwargs: Any) -> torch.Tensor | tuple[torch.Tensor, ...]: """ Forward pass through transformer to predict noise. Subclasses should override this if they need custom behavior, but the default implementation calls self.transformer. + + Returns: + Single Tensor for standard models, or tuple of Tensors for + multi-output models (e.g., video + audio). Multi-output models + must also override combine_cfg_noise() and set self.scheduler + to a composite scheduler that handles tuples. 
""" return self.transformer(*args, **kwargs)[0] @@ -156,52 +215,89 @@ def diffuse( Subclasses MUST implement this method to define the complete diffusion/denoising loop for their specific model. - Typical implementation pattern: + Typical implementation pattern (single output): ```python def diffuse(self, latents, timesteps, prompt_embeds, negative_embeds, ...): for t in timesteps: - # Prepare kwargs for positive and negative predictions positive_kwargs = {...} negative_kwargs = {...} - # Predict noise with automatic CFG handling noise_pred = self.predict_noise_maybe_with_cfg( - do_true_cfg=True, + do_true_cfg=do_true_cfg, true_cfg_scale=self.guidance_scale, positive_kwargs=positive_kwargs, negative_kwargs=negative_kwargs, ) - # Step scheduler with automatic CFG sync latents = self.scheduler_step_maybe_with_cfg( - noise_pred, t, latents, do_true_cfg=True + noise_pred, t, latents, do_true_cfg=do_true_cfg ) return latents ``` + + Multi-output models (e.g., video + audio) should: + 1. Override ``predict_noise()`` to return a tuple + 2. Override ``combine_cfg_noise()`` for per-element CFG logic + 3. 
Set ``self.scheduler`` to a composite scheduler that handles tuples + + ```python + def diffuse(self, video_latents, audio_latents, timesteps_video, timesteps_audio, ...): + for t_v, t_a in zip(timesteps_video, timesteps_audio): + positive_kwargs = {...} + negative_kwargs = {...} + + # Returns tuple: (video_pred, audio_pred) + video_pred, audio_pred = self.predict_noise_maybe_with_cfg( + do_true_cfg=do_true_cfg, + true_cfg_scale=self.guidance_scale, + positive_kwargs=positive_kwargs, + negative_kwargs=negative_kwargs, + ) + + # self.scheduler = VideoAudioScheduler(video_sched, audio_sched) + # which accepts and returns tuples + video_latents, audio_latents = self.scheduler_step_maybe_with_cfg( + (video_pred, audio_pred), + (t_v, t_a), + (video_latents, audio_latents), + do_true_cfg=do_true_cfg, + ) + + return video_latents, audio_latents + ``` """ raise NotImplementedError("Subclasses must implement diffuse") def scheduler_step( self, - noise_pred: torch.Tensor, - t: torch.Tensor, - latents: torch.Tensor, + noise_pred: torch.Tensor | tuple[torch.Tensor, ...], + t: torch.Tensor | tuple[torch.Tensor, ...], + latents: torch.Tensor | tuple[torch.Tensor, ...], per_request_scheduler: Any | None = None, - ) -> torch.Tensor: + generator: torch.Generator | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, ...]: """ Step the scheduler. + Default implementation passes inputs directly to ``self.scheduler.step()``. + For multi-output models, set ``self.scheduler`` to a composite scheduler + that handles tuples (e.g., ``VideoAudioScheduler``). + Args: - noise_pred: Predicted noise - t: Current timestep - latents: Current latents + noise_pred: Predicted noise (Tensor or tuple for multi-output) + t: Current timestep (Tensor or tuple when schedulers differ per output) + latents: Current latents (Tensor or tuple for multi-output) per_request_scheduler: Optional request-scoped scheduler that overrides ``self.scheduler`` for this call. 
This is primarily used by step-wise execution, where each request may keep scheduler state in its own runner-managed state object. Request-level execution should usually leave this as ``None`` and continue using ``self.scheduler``. + generator: Optional torch Generator for reproducible sampling. + When using CFG parallel, both ranks should receive generators + initialized with the same seed so that non-deterministic + schedulers (e.g., DDPM) produce identical results. Returns: Updated latents after scheduler step @@ -211,62 +307,49 @@ def scheduler_step( raise ValueError("No scheduler is available. Set self.scheduler or pass per_request_scheduler.") if not callable(getattr(sched, "step", None)): raise TypeError("per_request_scheduler must provide a callable step(...) method.") - return sched.step(noise_pred, t, latents, return_dict=False)[0] + step_kwargs = dict(return_dict=False) + if generator is not None: + step_kwargs["generator"] = generator + return sched.step(noise_pred, t, latents, **step_kwargs)[0] def scheduler_step_maybe_with_cfg( self, - noise_pred: torch.Tensor, - t: torch.Tensor, - latents: torch.Tensor, + noise_pred: torch.Tensor | tuple[torch.Tensor, ...], + t: torch.Tensor | tuple[torch.Tensor, ...], + latents: torch.Tensor | tuple[torch.Tensor, ...], do_true_cfg: bool, per_request_scheduler: Any | None = None, - ) -> torch.Tensor: + generator: torch.Generator | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, ...]: """ - Step the scheduler with (maybe) automatic CFG parallel synchronization. + Step the scheduler with automatic CFG parallel handling. - In CFG parallel mode, only rank 0 computes the scheduler step, - then broadcasts the result to other ranks. + All ranks compute the scheduler step locally — no broadcast needed + because predict_noise_maybe_with_cfg already ensures all ranks have + identical noise_pred after all_gather + local combine. 
Args: - noise_pred: Predicted noise (only valid on rank 0 in CFG parallel) - t: Current timestep - latents: Current latents + noise_pred: Predicted noise (Tensor or tuple, valid on all ranks) + t: Current timestep (Tensor or tuple when schedulers differ per output) + latents: Current latents (Tensor or tuple) do_true_cfg: Whether CFG is enabled per_request_scheduler: Optional request-scoped scheduler that overrides ``self.scheduler`` for this call. This is mainly needed by step-wise execution, where scheduler state may be stored per request. Request-level execution should normally leave this as ``None``. + generator: Optional torch Generator for reproducible sampling. + When using CFG parallel, both ranks should receive generators + initialized with the same seed so that non-deterministic + schedulers (e.g., DDPM) produce identical results. Returns: - Updated latents (synchronized across all CFG ranks) + Updated latents (identical across all CFG ranks) """ - # Automatically detect CFG parallel configuration - cfg_parallel_ready = do_true_cfg and get_classifier_free_guidance_world_size() > 1 - - if cfg_parallel_ready: - cfg_group = get_cfg_group() - cfg_rank = get_classifier_free_guidance_rank() - - # Only rank 0 computes the scheduler step - if cfg_rank == 0: - latents = self.scheduler_step( - noise_pred, - t, - latents, - per_request_scheduler=per_request_scheduler, - ) - - # Broadcast the updated latents to all ranks - latents = latents.contiguous() - cfg_group.broadcast(latents, src=0) - else: - # No CFG parallel: directly compute scheduler step - latents = self.scheduler_step( - noise_pred, - t, - latents, - per_request_scheduler=per_request_scheduler, - ) - - return latents + return self.scheduler_step( + noise_pred, + t, + latents, + per_request_scheduler=per_request_scheduler, + generator=generator, + )