[Feat]add cpu-offload/layerwise-offload for stable-audio-open & fix output inconsistency with same seed#2909

Merged: linyueqian merged 19 commits into vllm-project:main from sphinxkkkbc:feature/add-cpu-offloading on May 5, 2026

@sphinxkkkbc
Contributor

@sphinxkkkbc sphinxkkkbc commented Apr 19, 2026

Purpose

1. Add CPU offload and layerwise offload for stable-audio-open.
2. Fix output inconsistency with the same seed.

Test Plan

python vllm-omni/examples/offline_inference/text_to_audio/text_to_audio.py \
  --model stabilityai/stable-audio-open-1.0 \
  --prompt "The sound of a dog barking" \
  --enable-cpu-offload \
  --audio-length 10.0 \
  --num-inference-steps 100 \
  --guidance-scale 7.0 \
  --seed 42 \
  --output dog_barking_cpu_offload.wav

python vllm-omni/examples/offline_inference/text_to_audio/text_to_audio.py \
  --model stabilityai/stable-audio-open-1.0 \
  --prompt "The sound of a dog barking" \
  --enable-layerwise-offload \
  --audio-length 10.0 \
  --num-inference-steps 100 \
  --guidance-scale 7.0 \
  --seed 42 \
  --output dog_barking_layerwise_offload.wav

python vllm-omni/examples/offline_inference/text_to_audio/text_to_audio.py \
  --model stabilityai/stable-audio-open-1.0 \
  --prompt "The sound of a dog barking" \
  --audio-length 10.0 \
  --num-inference-steps 100 \
  --guidance-scale 7.0 \
  --seed 42 \
  --output dog_barking.wav

Test Result

| Offload Strategy | Peak Memory | Generation Time | Output Wav |
| --- | --- | --- | --- |
| LayerWise Offload | 11.00 GB reserved, 5.70 GB allocated | 18.59 s | dog_barking_layerwise_offload.wav |
| CPU Offload (ModelWise) | 11.70 GB reserved, 7.39 GB allocated | 12.20 s | dog_barking_cpu_offload.wav |
| No Offload | 12.81 GB reserved, 7.60 GB allocated | 9.30 s | dog_barking.wav |
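
To put the table in relative terms, the sketch below computes the allocated-memory savings and the slowdown of each strategy against the no-offload baseline. The numbers are copied from the table; the helper name `relative_cost` and the dictionary layout are illustrative only and are not part of the PR.

```python
# Figures copied from the Test Result table: (allocated GB, generation time in s).
RESULTS = {
    "layerwise_offload": (5.70, 18.59),
    "cpu_offload": (7.39, 12.20),
    "no_offload": (7.60, 9.30),
}


def relative_cost(name, baseline="no_offload"):
    """Return (fraction of allocated memory saved, slowdown factor) vs. the baseline."""
    mem, secs = RESULTS[name]
    base_mem, base_secs = RESULTS[baseline]
    return (base_mem - mem) / base_mem, secs / base_secs


for name in ("layerwise_offload", "cpu_offload"):
    saved, slowdown = relative_cost(name)
    print(f"{name}: {saved:.1%} less allocated memory, {slowdown:.2f}x generation time")
```

On this single run, layerwise offload trades roughly a quarter of the allocated memory for about twice the generation time, while model-wise CPU offload saves under 3% of allocated memory at roughly 1.3x the time.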

@sphinxkkkbc sphinxkkkbc changed the title [Fear]add cpu-offload/layerwise-offload for stable-audio-open [Feat]add cpu-offload/layerwise-offload for stable-audio-open Apr 19, 2026
Signed-off-by: sphinxkkkbc <binchengkang8@gmail.com>
@sphinxkkkbc
Contributor Author

@hsliuustc0106 PTAL, any other test to do?

@sphinxkkkbc sphinxkkkbc force-pushed the feature/add-cpu-offloading branch from 9b09453 to 831e96e Compare April 21, 2026 03:17
@sphinxkkkbc
Contributor Author

sphinxkkkbc commented Apr 21, 2026

The previous implementation of stable-audio-open doesn't generate the same output with the same seed. I added eval() and a generator in the denoising loop, but the output still differs between runs: it regularly produces one of two outputs with subtle differences, as shown below. @linyueqian has this happened before? These are the results of my three experiments:

[Screenshots: CleanShot 2026-04-21 at 11 27 47, CleanShot 2026-04-21 at 14 15 02, CleanShot 2026-04-21 at 14 15 47]

@sphinxkkkbc
Contributor Author

@hsliuustc0106 @linyueqian I've implemented CPU offloading for stable-audio-open. During testing, I noticed that even without offloading, the output with the same seed can be inconsistent across runs (see screenshots above). I tried adding

self.transformer.eval() 

and a generator

 latents = self.scheduler.step(noise_pred, t, latents, generator).prev_sample 

in the denoising loop, but the issue persists. Any advice would be appreciated. Thanks!

@sphinxkkkbc sphinxkkkbc force-pushed the feature/add-cpu-offloading branch 3 times, most recently from 831e96e to 791aedf Compare April 23, 2026 02:07
@linyueqian
Collaborator

Can you check against the original HF implementation? A side-by-side comparison of the embeddings at each step may help.

@sphinxkkkbc
Copy link
Copy Markdown
Contributor Author

> Can you check against the original HF implementation? A side-by-side comparison of the embeddings at each step may help.

thanks, I'll check it

@sphinxkkkbc
Contributor Author

sphinxkkkbc commented Apr 25, 2026

> > Can you check against the original HF implementation? A side-by-side comparison of the embeddings at each step may help.
>
> Thanks, I'll check it.

@linyueqian I've compared against the HF implementation. The difference is that our implementation did not pass the generator to scheduler.step.
before:

latents = self.scheduler.step(noise_pred, t, latents).prev_sample

after:

latents = self.scheduler.step(noise_pred, t, latents, generator).prev_sample

Should I include it in this PR? Also, the CPU offloading code is ready, could you help review it? Thanks
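
Why the missing generator matters can be shown with a toy stochastic sampler. This is a simplified stand-in written with Python's standard library, not the diffusers scheduler: `toy_step` and `denoise` are hypothetical names, and the only point is that a step drawing noise from a global RNG depends on whatever else has consumed that RNG, whereas a step given a dedicated seeded generator does not.

```python
import random


def toy_step(sample, noise_pred, rng=None):
    """One stochastic update; draws noise from `rng`, or from the global RNG if None."""
    source = rng if rng is not None else random
    return [s - 0.1 * n + 0.01 * source.gauss(0.0, 1.0) for s, n in zip(sample, noise_pred)]


def denoise(seed, steps=5, use_generator=True):
    """Run a few toy steps from a seeded starting point."""
    rng = random.Random(seed)
    # Initial latents always come from the dedicated, seeded generator.
    latents = [rng.gauss(0.0, 1.0) for _ in range(4)]
    for _ in range(steps):
        noise_pred = [0.5 * x for x in latents]  # stand-in for the model output
        latents = toy_step(latents, noise_pred, rng if use_generator else None)
    return latents


# With the generator threaded through every step, the same seed reproduces exactly.
assert denoise(42) == denoise(42)

# Without it, the per-step noise comes from global state, so two runs diverge.
print(denoise(42, use_generator=False) == denoise(42, use_generator=False))
```

Seeding only the initial latents is therefore not enough for a sampler that injects noise at each step; the same generator has to reach every stochastic call.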

@linyueqian
Collaborator

Yes, please include it in this PR. You can revise the description and title a bit. Thanks!

@linyueqian linyueqian added the ready label to trigger buildkite CI label Apr 25, 2026
Signed-off-by: sphinxkkkbc <binchengkang8@gmail.com>
@sphinxkkkbc sphinxkkkbc changed the title [Feat]add cpu-offload/layerwise-offload for stable-audio-open [Feat]add cpu-offload/layerwise-offload for stable-audio-open & fix output mismatch with same seed Apr 25, 2026
@sphinxkkkbc sphinxkkkbc changed the title [Feat]add cpu-offload/layerwise-offload for stable-audio-open & fix output mismatch with same seed [Feat]add cpu-offload/layerwise-offload for stable-audio-open & fix output inconsistency with same seed Apr 25, 2026
@sphinxkkkbc
Contributor Author

sphinxkkkbc commented Apr 25, 2026

> Yes, please include it in this PR. You can revise the description and title a bit. Thanks!

Done, CI has passed.

@hsliuustc0106 hsliuustc0106 removed the ready label to trigger buildkite CI label Apr 29, 2026
@sphinxkkkbc
Contributor Author

Can this PR move forward? If there are any remaining issues, please let me know. Thanks!

@linyueqian linyueqian added the ready label to trigger buildkite CI label May 3, 2026

  # Scheduler step
- latents = self.scheduler.step(noise_pred, t, latents).prev_sample
+ latents = self.scheduler.step(noise_pred, t, latents, generator).prev_sample
Collaborator

[suggestion] Worth adding a small regression test that pins this fix. Run the pipeline twice with the same torch.Generator(...).manual_seed(42) and assert the audio tensors are bitwise equal (or torch.allclose with tight tolerance). Without it, a future contributor could drop generator again and we'd silently regress to non-deterministic outputs.

The existing tests/e2e/offline_inference/test_diffusion_layerwise_offload.py and test_diffusion_cpu_offload.py are good neighbors for this; they only parametrize riverclouds/qwen_image_random today. Adding stable-audio-open there with a determinism assertion would cover both this fix and the new offload paths in one shot.
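
A determinism check along these lines could look like the sketch below. It is framework-agnostic and uses a stub in place of the real pipeline: `assert_deterministic` and `fake_pipeline` are hypothetical names, and a real test would call the stable-audio-open pipeline and compare audio tensors (for example with `torch.equal`) instead of Python lists.

```python
import random


def assert_deterministic(generate, seed, runs=2):
    """Call `generate(seed)` several times and require identical outputs."""
    reference = generate(seed)
    for _ in range(runs - 1):
        result = generate(seed)
        assert result == reference, f"outputs differ for seed {seed}"
    return reference


def fake_pipeline(seed):
    """Stub standing in for the audio pipeline; all randomness comes from `seed`."""
    rng = random.Random(seed)
    return [rng.gauss(0.0, 1.0) for _ in range(8)]


audio = assert_deterministic(fake_pipeline, seed=42)
print(len(audio))
```

Comparing for exact equality is the strictest form of the check; a tolerance-based comparison is the fallback if some kernels turn out not to be bitwise reproducible.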

@linyueqian
Collaborator

Please add a test as suggested, thanks.

Signed-off-by: sphinxkkkbc <binchengkang8@gmail.com>
@Gaohan123 Gaohan123 removed this from the v0.20.0 milestone May 4, 2026
@linyueqian
Collaborator

Please fix CI and DCO.

sphinxkkkbc added 2 commits May 4, 2026 23:25
Signed-off-by: sphinxkkkbc <binchengkang8@gmail.com>
Signed-off-by: sphinxkkkbc <binchengkang8@gmail.com>
@sphinxkkkbc sphinxkkkbc force-pushed the feature/add-cpu-offloading branch 3 times, most recently from 3235e53 to 9d2c592 Compare May 4, 2026 15:49
Signed-off-by: sphinxkkkbc <binchengkang8@gmail.com>
@sphinxkkkbc sphinxkkkbc force-pushed the feature/add-cpu-offloading branch from 1010b82 to 64e0aa1 Compare May 4, 2026 15:55
sphinxkkkbc and others added 2 commits May 4, 2026 23:57
Signed-off-by: sphinxkkkbc <binchengkang8@gmail.com>
@sphinxkkkbc
Contributor Author

The new CI run failed at the weight-size assertion in the weight-loading stage, which looks like a recently introduced bug. I'll fix it later.

sphinxkkkbc and others added 10 commits May 5, 2026 11:52
Signed-off-by: sphinxkkkbc <binchengkang8@gmail.com>
Signed-off-by: sphinxkkkbc <binchengkang8@gmail.com>
…erence

Signed-off-by: sphinxkkkbc <binchengkang8@gmail.com>
Signed-off-by: sphinxkkkbc <binchengkang8@gmail.com>
Signed-off-by: sphinxkkkbc <binchengkang8@gmail.com>
Signed-off-by: sphinxkkkbc <binchengkang8@gmail.com>
Signed-off-by: sphinxkkkbc <binchengkang8@gmail.com>
@sphinxkkkbc
Contributor Author

@linyueqian The latest version works as expected. The AMD CI failure seems unrelated to this PR. The changes are listed below; if there's a better way to fix these, please let me know. Thanks.

  1. [Diffusion] [Model] Support AudioX #2077 added a GaussianFourierProjection module to the Stable Audio transformer, but its parameter shape did not match the checkpoint weight shape. I added a narrow preprocessing helper to restore the trailing singleton dimension when needed.

  2. The official Stable Audio scheduler uses final_sigmas_type="zero", so the final CosineDPM step asks torchsde for Brownian noise over the sigma_min -> 0 interval. On CUDA, this out-of-range interval only emits a warning and produces zero noise, while on ROCm it can raise a RecursionError. I added a scheduler wrapper that keeps the official schedule unchanged and intercepts only this final sigma_min -> 0 step, substituting zero noise to match CUDA behavior.
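
The interception described in item 2 can be sketched with a small wrapper around a noise sampler. This is a toy under stated assumptions, not the wrapper added in the PR: `ZeroFinalNoise` and `strict_sampler` are hypothetical names, noise is modelled as a plain list of floats, and the stand-in sampler simply raises `RecursionError` for an out-of-range interval to mimic the ROCm failure.

```python
class ZeroFinalNoise:
    """Wrap a noise sampler and return zero noise for any step that ends at sigma 0."""

    def __init__(self, sampler, size):
        self.sampler = sampler  # callable: (sigma, sigma_next) -> list of floats
        self.size = size

    def __call__(self, sigma, sigma_next):
        if sigma_next <= 0.0:
            # Final sigma_min -> 0 step: skip the underlying sampler entirely.
            return [0.0] * self.size
        return self.sampler(sigma, sigma_next)


def strict_sampler(sigma, sigma_next, sigma_min=0.3):
    """Stand-in sampler that fails outside its valid range, like the ROCm case."""
    if sigma_next < sigma_min:
        raise RecursionError("interval outside the sampler's range")
    return [sigma - sigma_next] * 4


safe = ZeroFinalNoise(strict_sampler, size=4)
print(safe(1.0, 0.5))  # delegated to the wrapped sampler
print(safe(0.3, 0.0))  # intercepted: zero noise
```

Because only the final step is intercepted, every earlier step still goes through the wrapped sampler unchanged, which is the property the comment above relies on to keep the official schedule intact.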

One remaining question: in AMD CI, reserved memory appears to be an outlier, so I temporarily set its threshold to None, while allocated memory matches the expected CPU-offload savings. Should we use allocated memory instead of reserved memory for the CPU-offload memory assertion? Otherwise, we may need to re-examine the memory activity during model-wise CPU offloading.
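
The idea of disabling one threshold while still asserting the other can be sketched as follows. The helper `check_peak_memory` and the limit values are hypothetical and are not the test code in the repository; the measured values are the CPU-offload figures from the Test Result table in the PR description. In PyTorch, the two figures typically come from `torch.cuda.max_memory_reserved()` and `torch.cuda.max_memory_allocated()`, and reserved memory also counts blocks cached by the allocator, which may be one reason it is noisier across platforms.

```python
def check_peak_memory(measured_gb, limits_gb):
    """Compare measured peaks to limits; a limit of None disables that check."""
    failures = []
    for metric, limit in limits_gb.items():
        if limit is None:
            continue  # metric is reported but not asserted
        if measured_gb[metric] > limit:
            failures.append(f"{metric}: {measured_gb[metric]:.2f} GB > {limit:.2f} GB")
    return failures


# Measured values from the CPU-offload row of the table; the limits are made up.
measured = {"reserved": 11.70, "allocated": 7.39}
limits = {"reserved": None, "allocated": 7.60}
print(check_peak_memory(measured, limits))
```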

@linyueqian linyueqian merged commit a0918ce into vllm-project:main May 5, 2026
7 of 8 checks passed
@linyueqian
Collaborator

Thanks! I have merged it.

clodaghwalsh17 pushed a commit to clodaghwalsh17/nm-vllm-omni-ent that referenced this pull request May 12, 2026
…utput inconsistency with same seed (vllm-project#2909)

Signed-off-by: sphinxkkkbc <binchengkang8@gmail.com>
Co-authored-by: sphinxkkkbc <binchengkang8@gmail.com>