77 changes: 67 additions & 10 deletions docs/design/feature/cfg_parallel.md
@@ -34,7 +34,7 @@ vLLM-omni provides `CFGParallelMixin` that encapsulates all CFG parallel logic.
| Method | Purpose | Automatic Behavior |
|--------|---------|-------------------|
| [`predict_noise_maybe_with_cfg()`](https://docs.vllm.ai/projects/vllm-omni/en/latest/api/vllm_omni/diffusion/distributed/cfg_parallel/) | Predict noise with CFG | Detects parallel mode, distributes computation, gathers results |
| [`scheduler_step_maybe_with_cfg()`](https://docs.vllm.ai/projects/vllm-omni/en/latest/api/vllm_omni/diffusion/distributed/cfg_parallel/) | Step scheduler with sync | Rank 0 steps, broadcasts latents to all ranks |
| [`scheduler_step_maybe_with_cfg()`](https://docs.vllm.ai/projects/vllm-omni/en/latest/api/vllm_omni/diffusion/distributed/cfg_parallel/) | Step scheduler | All ranks step locally (no broadcast needed) |
| [`combine_cfg_noise()`](https://docs.vllm.ai/projects/vllm-omni/en/latest/api/vllm_omni/diffusion/distributed/cfg_parallel/) | Combine positive/negative | Applies CFG formula with optional normalization |
| [`predict_noise()`](https://docs.vllm.ai/projects/vllm-omni/en/latest/api/vllm_omni/diffusion/distributed/cfg_parallel/) | Forward pass wrapper | Override for custom transformer calls |
| [`cfg_normalize_function()`](https://docs.vllm.ai/projects/vllm-omni/en/latest/api/vllm_omni/diffusion/distributed/cfg_parallel/) | Normalize CFG output | Override for custom normalization |
@@ -47,22 +47,15 @@ vLLM-omni provides `CFGParallelMixin` that encapsulates all CFG parallel logic.
- Rank 0 computes positive prompt prediction
- Rank 1 computes negative prompt prediction
- Results are gathered via `all_gather()`
- Combined on rank 0 using CFG formula
- All ranks compute CFG combine locally (deterministic, identical results)

- **Sequential mode** (when `cfg_world_size == 1`):
- Single rank computes both positive and negative predictions
- Directly combines them with CFG formula
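
The CFG formula referenced in both modes can be sketched as follows. This is a minimal sketch, not vLLM-omni's exact implementation; in particular, the `normalize` branch is an assumption about one common choice of normalization:

```python
import torch

def combine_cfg_noise(positive, negative, scale, normalize=False):
    # Classifier-free guidance: start from the negative (unconditional)
    # prediction and move toward the positive one, scaled by `scale`.
    combined = negative + scale * (positive - negative)
    if normalize:
        # One common normalization (assumed here): rescale so the combined
        # prediction keeps the norm of the positive prediction.
        combined = combined * (positive.norm() / combined.norm())
    return combined
```

Because the formula is deterministic, any rank holding both predictions computes the same result, which is what makes the local-combine strategy safe.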

`scheduler_step_maybe_with_cfg()` ensures consistent latent states across all ranks:

- **CFG-Parallel mode**:
- Only rank 0 performs the scheduler step (applies noise prediction to update latents)
- Updated latents are broadcast to all other ranks via `broadcast()`
- All ranks maintain synchronized latent states for the next iteration

- **Sequential mode**:
- Single rank directly performs the scheduler step
- No synchronization needed
- All ranks compute the scheduler step locally; no broadcast is needed because `predict_noise_maybe_with_cfg` already guarantees that every rank holds identical noise predictions after the `all_gather` and local combine.
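
The no-broadcast invariant can be illustrated with a single-process sketch (illustrative only; real ranks run concurrently and exchange tensors via `all_gather`):

```python
import torch

def simulate_two_rank_cfg(scale: float = 5.0) -> bool:
    torch.manual_seed(0)
    pos = torch.randn(2, 4)  # rank 0's positive prediction
    neg = torch.randn(2, 4)  # rank 1's negative prediction
    # After all_gather, BOTH ranks hold the same (pos, neg) pair.
    gathered = (pos, neg)

    def local_combine(g):  # runs identically on every rank
        return g[1] + scale * (g[0] - g[1])

    combined_rank0 = local_combine(gathered)
    combined_rank1 = local_combine(gathered)
    # Identical inputs plus a deterministic combine yield identical
    # outputs, so each rank can step its scheduler locally.
    return torch.equal(combined_rank0, combined_rank1)
```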

---

@@ -177,6 +170,70 @@ class LongCatImagePipeline(nn.Module, CFGParallelMixin):
# return noise_pred
```


### Override `combine_cfg_noise()` for Multi-Output Models

When `predict_noise()` returns a tuple (e.g., video + audio), the default `combine_cfg_noise()` applies CFG to every element. Override it to apply different logic per element, for example CFG on the video stream but the positive-only prediction on audio:

```python
class MyVideoAudioPipeline(nn.Module, CFGParallelMixin):
    def combine_cfg_noise(self, positive_noise_pred, negative_noise_pred, scale, normalize):
        (video_pos, audio_pos) = positive_noise_pred
        (video_neg, audio_neg) = negative_noise_pred
        video_combined = super().combine_cfg_noise(video_pos, video_neg, scale, normalize)
        return (video_combined, audio_pos)  # audio: positive only, no CFG
```

This also requires `predict_noise()` to return a tuple (see [Override predict_noise](#override-predict_noise-for-custom-transformer-calls) above).

### Implement a Composite Scheduler for Multi-Output Models

When each output has its own denoising schedule, implement a composite scheduler that dispatches to per-output schedulers. Assign it to `self.scheduler` so that the default `scheduler_step()` works without an override.

**Complete example (video + audio with separate schedulers and diffuse loop):**

```python
class VideoAudioScheduler:
    """Composite scheduler dispatching to video and audio schedulers."""

    def __init__(self, video_scheduler, audio_scheduler):
        self.video_scheduler = video_scheduler
        self.audio_scheduler = audio_scheduler

    def step(self, noise_pred, t, latents, return_dict=False, generator=None):
        video_out = self.video_scheduler.step(noise_pred[0], t[0], latents[0], return_dict=False, generator=generator)[0]
        audio_out = self.audio_scheduler.step(noise_pred[1], t[1], latents[1], return_dict=False, generator=generator)[0]
        return ((video_out, audio_out),)


class MyVideoAudioPipeline(nn.Module, CFGParallelMixin):
    def __init__(self, ...):
        super().__init__()
        self.scheduler = VideoAudioScheduler(video_sched, audio_sched)

    def predict_noise(self, **kwargs):
        video_pred, audio_pred = self.transformer(**kwargs)
        return (video_pred, audio_pred)

    def combine_cfg_noise(self, positive_noise_pred, negative_noise_pred, scale, normalize):
        # ... (as above)

    def diffuse(self, video_latents, audio_latents, timesteps_video, timesteps_audio, ...):
        for t_v, t_a in zip(timesteps_video, timesteps_audio):
            positive_kwargs = {...}
            negative_kwargs = {...} if do_true_cfg else None

            video_pred, audio_pred = self.predict_noise_maybe_with_cfg(
                do_true_cfg=do_true_cfg,
                true_cfg_scale=self.guidance_scale,
                positive_kwargs=positive_kwargs,
                negative_kwargs=negative_kwargs,
            )
            video_latents, audio_latents = self.scheduler_step_maybe_with_cfg(
                (video_pred, audio_pred),
                (t_v, t_a),
                (video_latents, audio_latents),
                do_true_cfg=do_true_cfg,
                generator=generator,
            )
        return video_latents, audio_latents
```

> **Note:** If you use a non-deterministic scheduler (e.g., DDPM), pass a seeded generator explicitly, `self.scheduler_step_maybe_with_cfg(..., generator=torch.Generator(device).manual_seed(seed))`, so that the randomness of the scheduler step is identical across ranks.
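
The reason seeding works: identically seeded generators draw identical noise, so every rank's stochastic scheduler step stays in lockstep. A minimal sketch (the seed value and helper name are arbitrary illustrations, not part of the API):

```python
import torch

def make_rank_generator(seed: int = 1234, device: str = "cpu") -> torch.Generator:
    # Every rank constructs its generator from the SAME seed, so the noise
    # drawn inside a stochastic scheduler step (e.g., DDPM) matches across
    # ranks step for step.
    return torch.Generator(device=device).manual_seed(seed)

# Two "ranks" seeding identically draw the same noise:
noise_rank0 = torch.randn(4, generator=make_rank_generator())
noise_rank1 = torch.randn(4, generator=make_rank_generator())
```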

---

## Testing
22 changes: 15 additions & 7 deletions tests/diffusion/distributed/test_cfg_parallel.py
@@ -209,12 +209,9 @@ def _test_cfg_parallel_worker(
        cfg_normalize=test_config["cfg_normalize"],
    )

    # Only rank 0 has valid output in CFG parallel mode
    if cfg_rank == 0:
        assert noise_pred is not None
        result_queue.put(noise_pred.cpu())
    else:
        assert noise_pred is None
    # CFG parallel returns the combined prediction on every rank.
    assert noise_pred is not None
    result_queue.put((cfg_rank, noise_pred.cpu()))

    destroy_distributed_env()

@@ -348,7 +345,18 @@ def test_predict_noise_maybe_with_cfg(cfg_parallel_size: int, dtype: torch.dtype

    # Get results from queues
    baseline_output = baseline_queue.get()
    cfg_parallel_output = cfg_parallel_queue.get()
    cfg_parallel_outputs = [cfg_parallel_queue.get() for _ in range(cfg_parallel_size)]
    cfg_parallel_outputs.sort(key=lambda item: item[0])
    cfg_parallel_output = cfg_parallel_outputs[0][1]

    for cfg_rank, rank_output in cfg_parallel_outputs[1:]:
        torch.testing.assert_close(
            rank_output,
            cfg_parallel_output,
            rtol=0,
            atol=0,
            msg=f"CFG parallel ranks produced different outputs (rank 0 vs rank {cfg_rank})",
        )

    # Verify shapes match
    assert baseline_output.shape == cfg_parallel_output.shape, (