
perf(helios): replace strided RoPE with stack+flatten for contiguous memory #2474

Merged
hsliuustc0106 merged 1 commit into vllm-project:main from willamhou:refactor/unify-rope-helios
Apr 19, 2026

Conversation

@willamhou
Contributor

Summary

Phase 1 of #2436 — applies the stack+flatten RoPE optimization to helios_transformer, matching what PR #2393 did for Wan2.2 (25% speedup on NPU).

Before (strided slice assignment — non-sequential writes):

out = torch.empty_like(hidden_states)
out[..., 0::2] = x_1 * cos - x_2 * sin  # stride-2 writes: non-sequential memory access
out[..., 1::2] = x_1 * sin + x_2 * cos

After (stack+flatten — contiguous memory):

rotated = torch.stack((x_1 * cos - x_2 * sin, x_1 * sin + x_2 * cos), dim=-1)
return rotated.flatten(-2, -1)  # interleaves the pairs into one contiguous buffer

Math is identical. Output is bit-exact across float32/float16/bfloat16.
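For reference, the equivalence claim can be reproduced with a standalone sketch (the `rope_*` helpers below are illustrative stand-ins, not the actual `helios_transformer` function):

```python
import torch

def rope_strided(x, cos, sin):
    # Original pattern: two stride-2 slice assignments into an empty tensor
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out

def rope_stacked(x, cos, sin):
    # New pattern: stack the two halves, then flatten into one contiguous buffer
    x1, x2 = x[..., 0::2], x[..., 1::2]
    rotated = torch.stack((x1 * cos - x2 * sin, x1 * sin + x2 * cos), dim=-1)
    return rotated.flatten(-2, -1)

for dtype in (torch.float32, torch.float16, torch.bfloat16):
    x = torch.randn(2, 8, 64).to(dtype)
    freqs = torch.randn(32)
    cos, sin = freqs.cos().to(dtype), freqs.sin().to(dtype)
    # Both paths run the same elementwise ops, so results are bit-identical
    assert torch.equal(rope_strided(x, cos, sin), rope_stacked(x, cos, sin))
```

Bit-exactness holds because only the write pattern changes; the arithmetic is untouched.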

Change: 10 insertions, 4 deletions in helios_transformer.py.

Test plan

  • 11 unit tests verifying numerical equivalence:
    • Bit-identical across 3 dtypes (float32, float16, bfloat16)
    • 4 shape configs (minimal, typical, video-scale 8192 tokens, large head_dim)
    • Output contiguity, shape, and dtype preservation
    • Odd head_dim raises RuntimeError
  • CI GPU tests
  • Phase 2 (unified RoPE utility) to follow based on maintainer feedback

@willamhou willamhou requested a review from hsliuustc0106 as a code owner April 3, 2026 13:31
@hsliuustc0106 hsliuustc0106 added the `ready label to trigger buildkite CI` label Apr 3, 2026
@hsliuustc0106
Collaborator

can you help us check whether other models can benefit from this change?

@alex-jw-brooks
Contributor

alex-jw-brooks commented Apr 3, 2026

@willamhou thanks for doing this! Are you able to run performance benchmarks of some kind on this as well? I had looked into it a little yesterday, but didn't see as dramatic a change in helios (using this example at least).

It would also be nice to get a feel for how this affects different model architectures; Helios is a bit of a unique one with some of the pyramid / stage stuff.

@willamhou
Contributor Author

This PR's content was already merged to main via commit dbc0700. Closing as duplicate.

@willamhou willamhou closed this Apr 4, 2026
@willamhou
Contributor Author

Reopening — the commit exists only on this branch, not on main. Previous close was a mistake.

@willamhou willamhou reopened this Apr 4, 2026
@willamhou
Contributor Author

@alex-jw-brooks @hsliuustc0106 Thanks for looking into this!

Reference data from PR #2393 (same optimization, applied to Wan2.2):

  • NPU: 44s → 33s (25% speedup)
  • A100: 55s → 44s (20% speedup)

Helios uses the exact same empty_like + strided slice pattern that was optimized in Wan2.2, so the theoretical benefit should be comparable.

Why the improvement may be less visible in some setups:
The main win comes from eliminating non-contiguous memory writes (out[..., 0::2] = ...). This causes InplaceCopy/ViewCopy overhead that's particularly costly on NPU, and less dramatic on high-end CUDA GPUs where strided access is better optimized in hardware. The benefit also depends on the proportion of total time spent in RoPE vs. other ops (attention, MLP, etc.).
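For a rough first pass, the two variants can be timed head-to-head in isolation (a sketch; the helper names and tensor shapes are illustrative, not Helios's actual ones):

```python
import time
import torch

def rope_strided(x, cos, sin):
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin  # non-sequential writes
    out[..., 1::2] = x1 * sin + x2 * cos
    return out

def rope_stacked(x, cos, sin):
    x1, x2 = x[..., 0::2], x[..., 1::2]
    return torch.stack((x1 * cos - x2 * sin, x1 * sin + x2 * cos), dim=-1).flatten(-2, -1)

def bench(fn, x, cos, sin, iters=30):
    # Warm up, then average wall-clock time; sync if measuring on GPU
    for _ in range(5):
        fn(x, cos, sin)
    if x.is_cuda:
        torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(iters):
        fn(x, cos, sin)
    if x.is_cuda:
        torch.cuda.synchronize()
    return (time.perf_counter() - t0) / iters

device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(2, 4096, 8, 64, device=device)  # (batch, tokens, heads, head_dim), illustrative
cos = torch.randn(32, device=device)
sin = torch.randn(32, device=device)
print(f"strided: {bench(rope_strided, x, cos, sin) * 1e3:.3f} ms/iter")
print(f"stacked: {bench(rope_stacked, x, cos, sin) * 1e3:.3f} ms/iter")
```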

To better isolate the RoPE difference, we could profile just the apply_rotary_emb function rather than end-to-end generation time, where RoPE is a small fraction. Something like:

import torch
from torch.profiler import profile, ProfilerActivity

with profile(activities=[ProfilerActivity.CUDA], record_shapes=True) as prof:
    # run a few forward passes
    ...
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))

@alex-jw-brooks Could you share your benchmark config (model variant, resolution, steps, GPU type)? That would help narrow down whether the difference is masked by other ops or genuinely minimal on your hardware.

@hsliuustc0106 hsliuustc0106 added the `nightly-test label to trigger buildkite nightly test CI` label Apr 6, 2026
@willamhou
Contributor Author

Gentle ping @hsliuustc0106 @alex-jw-brooks — any further thoughts on the benchmark discussion above? Happy to provide more details if needed.

@willamhou
Contributor Author

@hsliuustc0106 @alex-jw-brooks Friendly follow-up — this has been open for ~12 days. If the benchmark concern is a blocker, I'm happy to add a micro-benchmark script that profiles just the RoPE function in isolation. Otherwise, if the scope is acceptable as-is (code + equivalence test), a review would be appreciated. Thanks!

lishunyang12
lishunyang12 previously approved these changes Apr 16, 2026
Collaborator

@lishunyang12 lishunyang12 left a comment


Review: perf(helios) — stack+flatten RoPE

Verdict: Approve

Correctness

The mathematical equivalence is sound. The original code writes interleaved results via strided slice assignment:

out[..., 0::2] = even_vals
out[..., 1::2] = odd_vals

The replacement torch.stack((even_vals, odd_vals), dim=-1).flatten(-2, -1) produces the identical interleaving: for each position i in the head_dim/2 axis, the stack creates [even_i, odd_i], and flatten merges them into [even_0, odd_0, even_1, odd_1, ...] — which matches the 0::2 / 1::2 layout exactly.

The cos/sin sub-indexing (cos[..., 0::2], sin[..., 1::2]) is preserved unchanged from the original, so no risk of swapped frequencies.
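The interleaving argument is easy to see on a toy tensor (illustrative sketch):

```python
import torch

even = torch.tensor([0., 2., 4.])  # stands in for x_1 * cos - x_2 * sin
odd = torch.tensor([1., 3., 5.])   # stands in for x_1 * sin + x_2 * cos

# stack makes pairs [even_i, odd_i]; flatten merges them in order
interleaved = torch.stack((even, odd), dim=-1).flatten(-2, -1)
print(interleaved)  # tensor([0., 1., 2., 3., 4., 5.])

# Matches the strided-write layout exactly
out = torch.empty(6)
out[0::2] = even
out[1::2] = odd
assert torch.equal(interleaved, out)
```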

Performance

Good change. The original pattern (torch.empty_like + two strided writes) produces non-contiguous write patterns that are unfriendly to GPU/NPU memory subsystems. The stack+flatten approach builds the result in a single contiguous allocation. The PR description cites a 25% speedup on NPU from the same pattern in PR #2393 for Wan2.2, which is plausible.

Tests

The 11 unit tests are well-structured:

  • Bit-exact (atol=0, rtol=0) across float32/float16/bfloat16 — good, this is a purely algebraic rearrangement so zero tolerance is correct.
  • Shape coverage includes a video-scale case (8192 tokens) which exercises realistic workloads.
  • Contiguity assertion directly validates the optimization goal.
  • Odd head_dim guard test is a nice edge-case check.
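The contiguity and odd-head_dim cases could be exercised as below (a sketch against a hypothetical `apply_rotary_emb_stacked` helper, not the PR's actual test file):

```python
import torch

def apply_rotary_emb_stacked(x, cos, sin):
    # Hypothetical stand-in mirroring the PR's stack+flatten pattern
    if x.shape[-1] % 2 != 0:
        raise RuntimeError("head_dim must be even for interleaved RoPE")
    x1, x2 = x[..., 0::2], x[..., 1::2]
    return torch.stack((x1 * cos - x2 * sin, x1 * sin + x2 * cos), dim=-1).flatten(-2, -1)

def test_output_properties():
    x = torch.randn(2, 8192, 64, dtype=torch.float16)  # video-scale token count
    cos = sin = torch.randn(32, dtype=torch.float16)
    out = apply_rotary_emb_stacked(x, cos, sin)
    assert out.is_contiguous()      # the optimization's stated goal
    assert out.shape == x.shape
    assert out.dtype == x.dtype

def test_odd_head_dim_raises():
    try:
        apply_rotary_emb_stacked(torch.randn(2, 4, 7), torch.randn(3), torch.randn(3))
    except RuntimeError:
        pass
    else:
        raise AssertionError("expected RuntimeError for odd head_dim")

test_output_properties()
test_odd_head_dim_raises()
```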

Minor observations (non-blocking)

  1. Test helper duplicates production code: _apply_rotary_emb_helios_optimized in the test file is a copy of the production function. If the production function later diverges (e.g., someone adds in-place ops), the test would still pass against the stale copy. Consider importing the production function directly instead, or add a comment noting this is intentional for isolation.

  2. type_as placement: .type_as(hidden_states) is now called on the flatten result, which is fine. Just noting that if hidden_states is already the same dtype (the common case), this is a no-op — no concern here.

LGTM. Clean, well-tested optimization.

@lishunyang12 lishunyang12 dismissed their stale review April 16, 2026 14:56

re-review

…memory (vllm-project#2436)

Replace strided slice assignment (`out[..., 0::2] = ...`) with
`torch.stack(..., dim=-1).flatten(-2, -1)` in Helios RoPE, matching the
optimization applied to Wan2.2 in PR vllm-project#2393 (25% speedup on NPU).

The stack+flatten pattern produces contiguous memory layout, avoiding
non-sequential write patterns that hurt GPU/NPU cache performance.
Math is identical — verified bit-exact across float32/float16/bfloat16.

This is Phase 1 of vllm-project#2436 (Helios only). Phase 2 (unified utility) to follow.

Signed-off-by: willamhou <willamhou@ceresman.com>
@willamhou willamhou force-pushed the refactor/unify-rope-helios branch from dbc0700 to ccd0c5d Compare April 19, 2026 10:34
@willamhou
Contributor Author

Rebased onto latest main. @hsliuustc0106 @alex-jw-brooks ready for review.

Collaborator

@hsliuustc0106 hsliuustc0106 left a comment


lgtm

@hsliuustc0106 hsliuustc0106 merged commit 1568451 into vllm-project:main Apr 19, 2026
7 of 8 checks passed
@Gaohan123
Collaborator

@willamhou Please check the L3 CI failure https://buildkite.com/vllm/vllm-omni/builds/7241/steps/canvas and resolve it. Thanks

@willamhou
Contributor Author

@Gaohan123 Thanks for flagging! The failing test is Qwen3-TTS Base E2E Test — this PR only modifies helios_transformer.py (RoPE pattern in Helios model), which is unrelated to Qwen3-TTS.

This looks like a pre-existing flaky test or an issue introduced by another commit merged around the same time. The Helios-specific tests all pass.

@willamhou
Contributor Author

Correction — there are actually two failing tests, not one:

  1. Qwen3-TTS Base E2E Test — test_qwen3_tts_base.py
  2. Omni Model Test with H100 — test_qwen3_omni.py + test_mimo_audio.py

I've checked all five failing test files: none import helios code, share conftest fixtures with helios, or have any indirect dependency path to helios_transformer.py. They use HuggingFace-compatible RoPE (_RotaryEmbedding), not Helios 3D RoPE. Our test file (test_rotary_emb_equivalence.py) has no module-level side effects.

This is a merge CI run on main (full test matrix). The failures are unrelated flaky tests or resource contention.

lvliang-intel pushed a commit to lvliang-intel/vllm-omni that referenced this pull request Apr 20, 2026
…memory (vllm-project#2474)

Signed-off-by: willamhou <willamhou@ceresman.com>
Co-authored-by: willamhou <willamhou@ceresman.com>
qinganrice pushed a commit to qinganrice/vllm-omni that referenced this pull request Apr 23, 2026
…memory (vllm-project#2474)

Signed-off-by: willamhou <willamhou@ceresman.com>
Co-authored-by: willamhou <willamhou@ceresman.com>
