
Add tensor parallel support to stable audio open 1.0 #1406

Open

akshatvishu wants to merge 19 commits into vllm-project:main from akshatvishu:feature/sao_tensor_parallel

Conversation

akshatvishu (Contributor) commented Feb 19, 2026

Part of #1217

Purpose

Add tensor parallel support to Stable Audio Open 1.0.

Test Plan

All experiments were conducted on Kaggle using 2× Tesla T4 GPUs.
Due to hardware constraints, the maximum tensor parallel size (tp_size) tested was 2.

Environment

GPUs: 2 × Tesla T4

Precision: float16 & float32

Maximum tensor parallel size (TP): 2

Global Test Configuration:

STEPS = 100
AUDIO_LENGTH = 10.0  # seconds

Model Initialization:

parallel_config = DiffusionParallelConfig(
    tensor_parallel_size=tp_size
)

# Initialize Omni model
omni = Omni(
    model=MODEL_PATH,
    dtype="float16",
    parallel_config=parallel_config,
)

generator = torch.Generator(
    device=current_omni_platform.device_type
).manual_seed(SEED)

params = OmniDiffusionSamplingParams(
    num_inference_steps=STEPS,
    guidance_scale=7.0,
    generator=generator,
    extra_args={
        "audio_start_in_s": 0.0,
        "audio_end_in_s": AUDIO_LENGTH,
    },
)

omni.generate(
    {
        "prompt": "A danceable electronic track in the genre of dance.",
        "negative_prompt": "Low quality, noisy",
    },
    params,
)

Test Result

For float16:

| Name | tp_size | Time (s) | Speedup | File |
| --- | --- | --- | --- | --- |
| hf_fp16 | - | 28.67 | - | dance_track.mp3 |
| baseline_tp1_fp16 | 1 | 26.40 | - | TP_1_fp16.mp3 |
| tp2_fp16 | 2 | 20.01 | 1.32x | TP_2_fp16.mp3 |

For float32:

| Name | tp_size | Time (s) | Speedup | File |
| --- | --- | --- | --- | --- |
| hf_fp32 | - | 194.15 | - | dance_track_fp32.mp3 |
| baseline_tp_1_fp32 | 1 | 145.00 | - | TP_1_fp32.mp3 |
| tp_2_fp32 | 2 | 78.01 | 1.86x | TP_2_fp32.mp3 |
  • Files are attached as .mp3 because GitHub does not allow .wav uploads.
  • Entries with the hf_ prefix run the Stable Audio Hugging Face Diffusers pipeline with the same parameters.
  • Entries with the baseline_ prefix are the ones against which the speedup is calculated.

For extended testing and results, please refer to this kaggle-notebook.



akshatvishu force-pushed the feature/sao_tensor_parallel branch from c28f88a to 7e207dc on February 20, 2026 21:44
Add Tensor Parallelism (TP) support for Stable Audio Open.

- Implement fused QKV and GLU weight loading for TP shards
- Add cross-rank SDE synchronization for deterministic sampling
- Preserve compatibility with the existing inference flow

Tested with multi-GPU (2× T4) TP configurations.

Signed-off-by: akshatvishu <akshatnayak197@gmail.com>
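
For context on the first bullet, a minimal sketch of per-rank shard loading for a fused QKV projection (the function name and arguments are illustrative, not the actual vllm-omni API):

```python
import torch

def shard_fused_qkv(qkv_weight: torch.Tensor, tp_rank: int, tp_size: int) -> torch.Tensor:
    """Slice a fused [3 * hidden, in_features] QKV checkpoint weight for one TP rank.

    Q, K, and V are each column-parallel, so a rank must take a contiguous
    1/tp_size slice of *each* projection, not a naive slice of the fused matrix.
    """
    hidden = qkv_weight.shape[0] // 3
    q, k, v = qkv_weight.split(hidden, dim=0)
    shards = []
    for w in (q, k, v):
        shard = w.shape[0] // tp_size
        shards.append(w[tp_rank * shard : (tp_rank + 1) * shard])
    return torch.cat(shards, dim=0)
```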
akshatvishu force-pushed the feature/sao_tensor_parallel branch from 7e207dc to 3a5fa86 on February 20, 2026 21:51
lishunyang12 (Collaborator) left a comment


Thanks for the contribution. The TP approach works, but broadcasting after every transformer block effectively defeats the purpose of tensor parallelism. I suspect there is a subtle numerical divergence elsewhere. A few high-level concerns and questions inline.
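
For reference, the Megatron-style communication pattern the reviewer is describing (a conceptual sketch, not the code in this PR): shard the in-projection column-wise and the out-projection row-wise, so each transformer block needs exactly one all-reduce of partial sums instead of a broadcast of full hidden states.

```python
import torch
import torch.distributed as dist

def tp_mlp_block(x, w_in_shard, w_out_shard, tp_group):
    # Column-parallel in-projection: each rank computes its own slice of the
    # intermediate activation, with no communication.
    h = torch.nn.functional.silu(x @ w_in_shard.t())
    # Row-parallel out-projection: each rank produces a partial sum of the
    # full-width output.
    y = h @ w_out_shard.t()
    # One all-reduce per block combines the partial sums; broadcasting the
    # hidden states from a single rank would instead serialize the math.
    dist.all_reduce(y, group=tp_group)
    return y
```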

[10 review comment threads on vllm_omni/diffusion/models/stable_audio/stable_audio_transformer.py; 2 marked Outdated]
- Remove cross-rank hidden state broadcasts to restore true TP All-Reduce.
- Fix nn.Sequential tuple crash by using nn.Linear for cross-attention.
- Refactor Gaussian Fourier embeddings to avoid unsafe distributed init.
- Replace Python assert statements with explicit exceptions.
- Add architectural docstrings for MHA/GQA routing and SwiGLU fusion.
- Pass synchronized generator to SDE scheduler step to fix numerical drift.
- Sync unseeded generation via tp_group.broadcast instead of global RNG mutation.
- Reset scheduler.noise_sampler on forward pass.
- Remove sigma_min configuration override to restore native noise schedule.

Signed-off-by: akshatvishu <akshatnayak197@gmail.com>
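
A minimal sketch of the generator synchronization described above, assuming a torch.distributed-style process group for TP (the helper name and call sites are illustrative):

```python
import torch
import torch.distributed as dist

def make_synced_generator(device: torch.device, tp_group) -> torch.Generator:
    """Seed every TP rank's generator identically without mutating global RNG.

    Rank 0 draws a seed and broadcasts it, so the SDE noise sampled from the
    returned generators is bit-identical across ranks and the denoising
    trajectories cannot drift apart.
    """
    seed = torch.empty(1, dtype=torch.int64, device=device)
    if dist.get_rank(group=tp_group) == 0:
        seed.random_()
    dist.broadcast(seed, src=0, group=tp_group)
    return torch.Generator(device=device).manual_seed(int(seed.item()))
```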
…latents

Signed-off-by: akshatvishu <akshatnayak197@gmail.com>
…enable TP support.

Signed-off-by: akshatvishu <akshatnayak197@gmail.com>
Signed-off-by: akshatvishu <akshatnayak197@gmail.com>
…without OOM

Signed-off-by: akshatvishu <akshatnayak197@gmail.com>
akshatvishu marked this pull request as ready for review on February 21, 2026 20:14
chatgpt-codex-connector (Bot) left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: b5dbdf6ca6

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

[Review comment threads on vllm_omni/diffusion/models/stable_audio/stable_audio_transformer.py and vllm_omni/diffusion/models/stable_audio/pipeline_stable_audio.py]
fix(transformer): restore legacy checkpoint key mapping in load_weight

Signed-off-by: akshatvishu <akshatnayak197@gmail.com>
akshatvishu (Contributor, Author) commented Feb 21, 2026

Thanks for the review! Just a heads-up: I accidentally performed a force-with-lease push on a follow-up after the initial squash. This shifted the hunk headers, so my previous responses to your line-specific comments might point to the wrong code blocks now, but I've manually verified that all requested changes are addressed.

  checkpoint keys in `name_mapping` while retaining the legacy `.linear_x.`
  keys for backward compatibility. This prevents `timestep_proj` and
  `global_proj` from silently failing to load and using random initialization.

Signed-off-by: akshatvishu <akshatnayak197@gmail.com>
akshatvishu changed the title from "Add tensor parallel support to stable audio opem" to "Add tensor parallel support to stable audio open 1.0" on Feb 21, 2026
hsliuustc0106 (Collaborator) commented:

@vllm-omni-reviewer

hsliuustc0106 (Collaborator) commented:

Hi @akshatvishu 👋

This tensor parallel support PR for stable audio hasn't been updated for 16 days. Is this still being worked on?

Thanks!

akshatvishu (Contributor, Author) commented Mar 12, 2026

Hey @hsliuustc0106!
I think apart from a minor style fix in one of the comments, this is ready to go from my side! Happy to provide more tests if needed!

Update: Done with the style changes!

Signed-off-by: akshatvishu <akshatnayak197@gmail.com>
Gaohan123 added this to the v0.18.0 milestone on Mar 17, 2026
linyueqian and others added 4 commits March 24, 2026 01:14
Signed-off-by: Yueqian Lin <70319226+linyueqian@users.noreply.github.com>
Signed-off-by: akshatvishu <akshatnayak197@gmail.com>
…atvishu/vllm-omni into feature/sao_tensor_parallel

Signed-off-by: akshatvishu <akshatnayak197@gmail.com>
Signed-off-by: akshatvishu <akshatnayak197@gmail.com>
Gaohan123 added this to the v0.20.0 milestone on Apr 14, 2026
lishunyang12 (Collaborator) commented:

Any progress?

akshatvishu (Contributor, Author) commented Apr 20, 2026

> Any progress?

It's ready from my side! Any tests or benchmarks you want me to run?

akshatvishu and others added 2 commits April 20, 2026 20:43
Keep both tensor-parallel-size arg (feature branch) and cache-backend/
tea-cache args (main) in text_to_audio.py example. Pass parallel_config
and cache params to Omni constructor.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: akshatvishu <akshatnayak197@gmail.com>
Signed-off-by: akshatvishu <33392262+akshatvishu@users.noreply.github.com>
linyueqian added the "ready" label (triggers Buildkite CI) on Apr 24, 2026
linyueqian removed this from the v0.20.0 milestone on Apr 24, 2026

# Memory Optimization: Decode latents in chunks to prevent VAE OOM spikes.
# Note: Safe default for 47s audio on T4.
chunk_size = 1
A Collaborator left a comment:

🟡 [important] VAE chunk_size = 1 serializes decode for every caller, not just T4.

The earlier code did `self.vae.decode(latents_for_vae).sample` in a single call. Hardcoding `chunk_size = 1` here means that once `num_waveforms_per_prompt > 1` or batched inference lands, users on H100 / A100 pay a serial VAE decode that they do not need. The comment says "safe default for 47s audio on T4", but the default is now applied to everyone.

Could we either:

  1. Expose vae_chunk_size as a sampling param / config option, defaulting to latents_for_vae.shape[0] (whole-batch decode, matches old behavior), so T4 users can opt into vae_chunk_size=1, or
  2. Pick the chunk size dynamically from available VRAM or batch size so the common single-waveform path stays a single decode call?

As-is, this is a silent perf regression for non-T4 users.

akshatvishu (Contributor, Author) replied:

Took approach 1: `vae_chunk_size` is now an explicit sampling param with default `latents_for_vae.shape[0]`, so the common path stays a single VAE decode call. `vae_chunk_size=1` is still available for low-VRAM use.

I added a unit test for the chunking behavior. Should I keep that, or remove it?
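
A sketch of the agreed resolution, assuming the `vae_chunk_size` sampling param from the thread (the decode loop itself is illustrative, not the exact pipeline code):

```python
import torch

def decode_latents(vae, latents: torch.Tensor, vae_chunk_size: int | None = None) -> torch.Tensor:
    """Decode latents along the batch dim in chunks of vae_chunk_size.

    Defaults to the whole batch (one VAE call, matching the old behavior);
    low-VRAM users can pass vae_chunk_size=1 to trade speed for peak memory.
    """
    chunk = vae_chunk_size or latents.shape[0]
    pieces = [
        vae.decode(latents[i : i + chunk]).sample
        for i in range(0, latents.shape[0], chunk)
    ]
    return torch.cat(pieces, dim=0)
```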

[Outdated review comment threads on vllm_omni/diffusion/models/stable_audio/stable_audio_transformer.py and vllm_omni/diffusion/models/stable_audio/pipeline_stable_audio.py]
linyueqian (Collaborator) commented:

@princepride PTAL as well, thanks

akshatvishu and others added 3 commits April 24, 2026 23:15
Signed-off-by: akshatvishu <akshatnayak197@gmail.com>
Signed-off-by: akshatvishu <akshatnayak197@gmail.com>
linyueqian added this to the v0.20.0 milestone on Apr 28, 2026
linyueqian enabled auto-merge (squash) on April 28, 2026 17:36
akshatvishu (Contributor, Author) commented Apr 29, 2026

Failed tests:

=========================== short test summary info ============================
FAILED tests/diffusion/cache/test_teacache_extractors.py::TestFluxExtractor::test_modulated_input_shape
FAILED tests/diffusion/cache/test_teacache_extractors.py::TestFluxExtractor::test_run_transformer_blocks_callable
FAILED tests/diffusion/cache/test_teacache_extractors.py::TestFluxExtractor::test_postprocess_callable
FAILED tests/diffusion/cache/test_teacache_extractors.py::TestFluxExtractor::test_postprocess_output_shape
FAILED tests/diffusion/cache/test_teacache_extractors.py::TestFluxExtractor::test_postprocess_return_tuple_when_return_dict_false
FAILED tests/diffusion/cache/test_teacache_extractors.py::TestFluxExtractor::test_without_guidance

error:

E NotImplementedError: Could not run 'vllm::rocm_unquantized_gemm' with arguments from the 'CPU' backend

The CI CPU tests for the FLUX extractors are crashing on ROCm builds, probably because ReplicatedLinear dispatches to ROCm GPU kernels. I think we can fix this by using the same monkeypatch logic as in the adalayernorm tests to force the default CPU-compatible GEMM path during test execution.

Happy to raise a PR to fix this if that is indeed the preferred way forward!

ref:
https://github.com/vllm-project/vllm-omni/blob/main/tests/diffusion/layers/test_adalayernorm.py#L37

cc: @linyueqian
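
For illustration, a hypothetical pytest fixture in the spirit of the referenced adalayernorm test; the exact attribute to patch depends on vllm's linear-layer internals:

```python
import pytest
import torch.nn.functional as F

@pytest.fixture(autouse=True)
def force_cpu_gemm(monkeypatch):
    # Hypothetical: route ReplicatedLinear through plain F.linear so CPU-only
    # CI never reaches the ROCm-specific fused GEMM custom op.
    from vllm.model_executor.layers.linear import ReplicatedLinear

    def cpu_forward(self, x):
        bias = getattr(self, "bias", None)
        # vllm linear layers return an (output, output_bias) tuple.
        return F.linear(x, self.weight, bias), None

    monkeypatch.setattr(ReplicatedLinear, "forward", cpu_forward)
```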

Gaohan123 modified the milestones: v0.20.0 → v0.22.0 on May 9, 2026