Add tensor parallel support to stable audio open 1.0 #1406
Add Tensor Parallelism (TP) support for Stable Audio Open.
- Implement fused QKV and GLU weight loading for TP shards
- Add cross-rank SDE synchronization for deterministic sampling
- Preserve compatibility with the existing inference flow

Tested with multi-GPU (T4×2) TP configurations.

Signed-off-by: akshatvishu <akshatnayak197@gmail.com>
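For readers skimming the diff, a minimal sketch of the fused-QKV sharding idea (illustrative only, assuming equal-sized Q/K/V projections as in MHA; this is not the PR's actual weight loader):

```python
import torch

def shard_fused_qkv(qkv_weight: torch.Tensor, tp_rank: int, tp_size: int) -> torch.Tensor:
    # A checkpoint stores QKV fused as [3 * hidden, hidden]; under TP, each of
    # Q, K and V must be split along its own output dim before re-fusing,
    # otherwise a single narrow() would hand one rank all of Q and none of V.
    q, k, v = qkv_weight.chunk(3, dim=0)

    def shard(w: torch.Tensor) -> torch.Tensor:
        rows = w.shape[0] // tp_size
        return w[tp_rank * rows : (tp_rank + 1) * rows]

    return torch.cat([shard(q), shard(k), shard(v)], dim=0)

# hidden=8, tp_size=2: the [24, 8] fused weight becomes a [12, 8] shard per rank
shard = shard_fused_qkv(torch.randn(24, 8), tp_rank=0, tp_size=2)
assert shard.shape == (12, 8)
```

The GLU case follows the same pattern with two fused projections instead of three.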
Thanks for the contribution. The TP approach works, but broadcasting after every transformer block effectively defeats the purpose of tensor parallelism. I suspect there is a subtle numerical divergence elsewhere. A few high-level concerns and questions inline.
- Remove cross-rank hidden state broadcasts to restore true TP All-Reduce.
- Fix nn.Sequential tuple crash by using nn.Linear for cross-attention.
- Refactor Gaussian Fourier embeddings to avoid unsafe distributed init.
- Replace Python assert statements with explicit exceptions.
- Add architectural docstrings for MHA/GQA routing and SwiGLU fusion.
- Pass synchronized generator to SDE scheduler step to fix numerical drift.
- Sync unseeded generation via tp_group.broadcast instead of global RNG mutation.
- Reset scheduler.noise_sampler on forward pass.
- Remove sigma_min configuration override to restore native noise schedule.

Signed-off-by: akshatvishu <akshatnayak197@gmail.com>
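The generator-sync fix in the list above boils down to broadcasting a single seed across the TP group. A hedged sketch of the pattern (function and argument names are illustrative, not the PR's exact code):

```python
import torch
import torch.distributed as dist

def make_synced_generator(device: torch.device, tp_group=None) -> torch.Generator:
    # Rank 0 draws the seed for unseeded generation; broadcasting it over the
    # TP group keeps every rank's SDE noise identical without mutating the
    # global RNG state.
    seed = torch.empty(1, dtype=torch.int64, device=device)
    if dist.get_rank() == 0:
        seed.random_()
    dist.broadcast(seed, src=0, group=tp_group)
    gen = torch.Generator(device=device)
    gen.manual_seed(int(seed.item()))
    return gen
```

Passing this generator into every scheduler step is what removes the numerical drift between ranks.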
…latents Signed-off-by: akshatvishu <akshatnayak197@gmail.com>
…enable TP support. Signed-off-by: akshatvishu <akshatnayak197@gmail.com>
Signed-off-by: akshatvishu <akshatnayak197@gmail.com>
…without OOM Signed-off-by: akshatvishu <akshatnayak197@gmail.com>
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: b5dbdf6ca6
fix(transformer): restore legacy checkpoint key mapping in load_weight Signed-off-by: akshatvishu <akshatnayak197@gmail.com>
Thanks for the review! Just a heads-up: I accidentally performed a …
…checkpoint keys in `name_mapping` while retaining the legacy `.linear_x.` keys for backward compatibility. This prevents `timestep_proj` and `global_proj` from silently failing to load and using random initialization. Signed-off-by: akshatvishu <akshatnayak197@gmail.com>
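The shape of that fix, as a toy sketch (the real mapping entries aren't visible in this thread, so these keys are hypothetical):

```python
# Toy illustration: route both current and legacy checkpoint key spellings to
# the same module name so neither timestep_proj nor global_proj silently
# falls back to random initialization. All keys here are hypothetical.
name_mapping = {
    "timestep_proj.weight": "timestep_proj.weight",
    "timestep_proj.linear_x.weight": "timestep_proj.weight",  # legacy spelling
    "global_proj.weight": "global_proj.weight",
    "global_proj.linear_x.weight": "global_proj.weight",      # legacy spelling
}

def remap_checkpoint_key(key: str) -> str:
    return name_mapping.get(key, key)
```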
@vllm-omni-reviewer
Hi @akshatvishu 👋 This tensor parallel support PR for stable audio hasn't been updated for 16 days. Is this still being worked on? Thanks!
Hey @hsliuustc0106! Update: done with the style changes!
Signed-off-by: akshatvishu <akshatnayak197@gmail.com>
Signed-off-by: Yueqian Lin <70319226+linyueqian@users.noreply.github.com>
Signed-off-by: akshatvishu <akshatnayak197@gmail.com>
…atvishu/vllm-omni into feature/sao_tensor_parallel Signed-off-by: akshatvishu <akshatnayak197@gmail.com>
Signed-off-by: akshatvishu <akshatnayak197@gmail.com>
Any progress?
It's ready from my side! Any tests or benchmarks you'd like me to run?
Keep both the tensor-parallel-size arg (feature branch) and the cache-backend / tea-cache args (main) in the text_to_audio.py example. Pass parallel_config and cache params to the Omni constructor. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> Signed-off-by: akshatvishu <akshatnayak197@gmail.com>
Signed-off-by: akshatvishu <33392262+akshatvishu@users.noreply.github.com>
```python
# Memory Optimization: Decode latents in chunks to prevent VAE OOM spikes.
# Note: Safe default for 47s audio on T4.
chunk_size = 1
```
🟡 [important] VAE `chunk_size = 1` serializes decode for every caller, not just T4.

The earlier code did `self.vae.decode(latents_for_vae).sample` in a single call. Hardcoding `chunk_size = 1` here means that once `num_waveforms_per_prompt > 1` or batched inference lands, users on H100 / A100 pay a serial VAE decode that they don't need. The comment says "safe default for 47s audio on T4", but the default is now applied to everyone.

Could we either:
- Expose `vae_chunk_size` as a sampling param / config option, defaulting to `latents_for_vae.shape[0]` (whole-batch decode, matches old behavior), so T4 users can opt into `vae_chunk_size=1`, or
- Pick the chunk size dynamically from available VRAM or batch size, so the common single-waveform path stays a single decode call?

As-is, this is a silent perf regression for non-T4 users.
Took approach 1: `vae_chunk_size` is now an explicit sampling param with default `latents_for_vae.shape[0]`, so the common path stays a single VAE decode call. `vae_chunk_size=1` is still available for low-VRAM use.
I added a unit test for the chunking behavior. Should I keep that, or remove it?
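For context, the resolved behavior amounts to something like this (a sketch under the assumption that `vae_chunk_size=None` means whole-batch decode; not the exact merged code):

```python
import torch

def decode_latents(vae, latents_for_vae: torch.Tensor, vae_chunk_size: int | None = None):
    # Default preserves the old single-call behavior; vae_chunk_size=1 trades
    # speed for peak VRAM on low-memory GPUs such as the T4.
    chunk_size = vae_chunk_size or latents_for_vae.shape[0]
    chunks = [
        vae.decode(latents_for_vae[i : i + chunk_size]).sample
        for i in range(0, latents_for_vae.shape[0], chunk_size)
    ]
    return torch.cat(chunks, dim=0)
```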
@princepride PTAL as well, thanks!
Signed-off-by: akshatvishu <akshatnayak197@gmail.com>
Signed-off-by: akshatvishu <akshatnayak197@gmail.com>
Failed tests:

```
=========================== short test summary info ============================
FAILED tests/diffusion/cache/test_teacache_extractors.py::TestFluxExtractor::test_modulated_input_shape
FAILED tests/diffusion/cache/test_teacache_extractors.py::TestFluxExtractor::test_run_transformer_blocks_callable
FAILED tests/diffusion/cache/test_teacache_extractors.py::TestFluxExtractor::test_postprocess_callable
FAILED tests/diffusion/cache/test_teacache_extractors.py::TestFluxExtractor::test_postprocess_output_shape
FAILED tests/diffusion/cache/test_teacache_extractors.py::TestFluxExtractor::test_postprocess_return_tuple_when_return_dict_false
FAILED tests/diffusion/cache/test_teacache_extractors.py::TestFluxExtractor::test_without_guidance
```

error:

```
E NotImplementedError: Could not run 'vllm::rocm_unquantized_gemm' with arguments from the 'CPU' backend
```

The CI CPU tests for the FLUX extractors are crashing on ROCm builds, probably because … Happy to raise a PR to fix this if this is indeed the preferred way forward!

cc: @linyueqian
Part of #1217
Purpose
Add tensor parallel support to stable audio open
Test Plan
All experiments were conducted on Kaggle using 2× Tesla T4 GPUs.
Due to hardware constraints, the maximum tensor-parallel size (`tp_size`) tested was 2.
Environment
GPUs: 2 × Tesla T4
Precision: `float16` & `float32`
Maximum tensor parallel size (TP): 2
Global Test Configuration:
Model Initialization:
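The original configuration block is collapsed in this view; roughly, initialization looked like the following (a hypothetical sketch based on the text_to_audio.py example referenced in this PR — the import path and argument names are assumptions, not the verified API):

```python
# Hypothetical sketch; the actual Omni constructor signature may differ.
from vllm_omni import Omni  # assumed import path

omni = Omni(
    model="stabilityai/stable-audio-open-1.0",
    tensor_parallel_size=2,  # the tp_size exercised in this PR (2x T4)
    dtype="float16",
)
```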
Test Result
For `float16`:

For `float32`:

- Audio samples are attached as `.mp3`, since GitHub doesn't allow `.wav`.
- `hf_prefix` entries mean that we ran the Stable Audio Hugging Face Diffusers pipeline with the same parameters.
- `baseline_prefix` entries are the ones against which the speedup is calculated.

For extended testing and results, please refer to this kaggle-notebook.
Notes
`final_sigmas_type="zero"` in Diffusers produces harmless `torchsde` boundary warnings when using `CosineDPMSolverMultistepScheduler`. This is expected upstream behavior and not a vLLM Omni issue.