[Helion] Add rotary positional embedding (RoPE) Helion kernels (neox + gptj) #38880

Open
meinie0826 wants to merge 1 commit into vllm-project:main from meinie0826:helion-rope-kernel

@meinie0826 meinie0826 commented Apr 3, 2026

Purpose

Add Helion/Triton-backed kernels for rotary positional embedding (RoPE), supporting both Neox-style and GPT-J-style rotation. The kernels register as implementations of the vLLM IR ops introduced in #33825 and are dispatched automatically on CUDA-capable hardware when Helion is installed.
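
For readers unfamiliar with the two layouts, the following is a minimal, dependency-free sketch of how Neox-style and GPT-J-style rotation differ for a single head vector. It is not the Helion kernel from this PR: the function names are hypothetical, and the `base=10000.0` frequency base is an assumption (the conventional RoPE default), not something stated in this description.

```python
import math


def cos_sin(position, rotary_dim, base=10000.0):
    """Per-pair cos/sin values for one position (base is an assumed default)."""
    half = rotary_dim // 2
    angles = [position * base ** (-2.0 * i / rotary_dim) for i in range(half)]
    return [math.cos(a) for a in angles], [math.sin(a) for a in angles]


def rope_neox(x, cos, sin):
    """Neox style: element i is paired with element i + len(x) // 2."""
    half = len(x) // 2
    out = [0.0] * len(x)
    for i in range(half):
        x1, x2 = x[i], x[i + half]
        out[i] = x1 * cos[i] - x2 * sin[i]
        out[i + half] = x2 * cos[i] + x1 * sin[i]
    return out


def rope_gptj(x, cos, sin):
    """GPT-J style: element 2i is paired with its neighbour 2i + 1."""
    half = len(x) // 2
    out = [0.0] * len(x)
    for i in range(half):
        x1, x2 = x[2 * i], x[2 * i + 1]
        out[2 * i] = x1 * cos[i] - x2 * sin[i]
        out[2 * i + 1] = x2 * cos[i] + x1 * sin[i]
    return out
```

Both functions apply the same 2D rotation to each pair; only the choice of which elements form a pair differs, which is why the GPT-J kernel needs an interleaved view of the data.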

Test Plan

# Autotune (run once on target GPU)
python scripts/autotune_helion_kernels.py --kernels _helion_rotary_embedding_neox _helion_rotary_embedding_gptj

# Unit tests
python -m pytest tests/kernels/helion/test_rotary_embedding.py -v

Test Result

cd /home/meiziyuan/vllm && python -m pytest tests/kernels/helion/test_rotary_embedding.py -v 2>&1
======================================================= test session starts =======================================================
platform linux -- Python 3.12.3, pytest-8.1.1, pluggy-1.6.0 -- /usr/bin/python
cachedir: .pytest_cache
hypothesis profile 'default' -> database=DirectoryBasedExampleDatabase(PosixPath('/home/meiziyuan/vllm/.hypothesis/examples'))
rootdir: /home/meiziyuan/vllm
configfile: pyproject.toml
plugins: anyio-4.9.0, hypothesis-6.130.8, flakefinder-1.1.0, rerunfailures-15.1, shard-0.1.2, xdist-3.6.1, xdoctest-1.0.2
collected 47 items                                                                                                                
Running 47 items in this shard: tests/kernels/helion/test_rotary_embedding.py::TestRotaryEmbeddingConfigPicker::test_exact_match, tests/kernels/helion/test_rotary_embedding.py::TestRotaryEmbeddingConfigPicker::test_ceiling_selection, tests/kernels/helion/test_rotary_embedding.py::TestRotaryEmbeddingConfigPicker::test_fallback_to_largest, tests/kernels/helion/test_rotary_embedding.py::TestRotaryEmbeddingConfigPicker::test_closest_rotary_dim, tests/kernels/helion/test_rotary_embedding.py::TestRotaryEmbeddingConfigPicker::test_fallback_to_default, tests/kernels/helion/test_rotary_embedding.py::TestRotaryEmbeddingConfigPicker::test_empty_config_keys, tests/kernels/helion/test_rotary_embedding.py::TestRotaryEmbeddingConfigPicker::test_malformed_key_raises, tests/kernels/helion/test_rotary_embedding.py::TestRotaryEmbeddingConfigPicker::test_default_skipped_when_valid_keys_exist, tests/kernels/helion/test_rotary_embedding.py::TestRotaryEmbeddingNeoxCorrectness::test_neox_vs_baseline[dtype0-64-1], tests/kernels/helion/test_rotary_embedding.py::TestRotaryEmbeddingNeoxCorrectness::test_neox_vs_baseline[dtype0-64-8], tests/kernels/helion/test_rotary_embedding.py::TestRotaryEmbeddingNeoxCorrectness::test_neox_vs_baseline[dtype0-64-32], tests/kernels/helion/test_rotary_embedding.py::TestRotaryEmbeddingNeoxCorrectness::test_neox_vs_baseline[dtype0-64-128], tests/kernels/helion/test_rotary_embedding.py::TestRotaryEmbeddingNeoxCorrectness::test_neox_vs_baseline[dtype0-128-1], tests/kernels/helion/test_rotary_embedding.py::TestRotaryEmbeddingNeoxCorrectness::test_neox_vs_baseline[dtype0-128-8], tests/kernels/helion/test_rotary_embedding.py::TestRotaryEmbeddingNeoxCorrectness::test_neox_vs_baseline[dtype0-128-32], tests/kernels/helion/test_rotary_embedding.py::TestRotaryEmbeddingNeoxCorrectness::test_neox_vs_baseline[dtype0-128-128], tests/kernels/helion/test_rotary_embedding.py::TestRotaryEmbeddingNeoxCorrectness::test_neox_vs_baseline[dtype1-64-1], 
tests/kernels/helion/test_rotary_embedding.py::TestRotaryEmbeddingNeoxCorrectness::test_neox_vs_baseline[dtype1-64-8], tests/kernels/helion/test_rotary_embedding.py::TestRotaryEmbeddingNeoxCorrectness::test_neox_vs_baseline[dtype1-64-32], tests/kernels/helion/test_rotary_embedding.py::TestRotaryEmbeddingNeoxCorrectness::test_neox_vs_baseline[dtype1-64-128], tests/kernels/helion/test_rotary_embedding.py::TestRotaryEmbeddingNeoxCorrectness::test_neox_vs_baseline[dtype1-128-1], tests/kernels/helion/test_rotary_embedding.py::TestRotaryEmbeddingNeoxCorrectness::test_neox_vs_baseline[dtype1-128-8], tests/kernels/helion/test_rotary_embedding.py::TestRotaryEmbeddingNeoxCorrectness::test_neox_vs_baseline[dtype1-128-32], tests/kernels/helion/test_rotary_embedding.py::TestRotaryEmbeddingNeoxCorrectness::test_neox_vs_baseline[dtype1-128-128], tests/kernels/helion/test_rotary_embedding.py::TestRotaryEmbeddingNeoxCorrectness::test_neox_vs_pytorch, tests/kernels/helion/test_rotary_embedding.py::TestRotaryEmbeddingNeoxCorrectness::test_neox_partial_rope, tests/kernels/helion/test_rotary_embedding.py::TestRotaryEmbeddingGptjCorrectness::test_gptj_vs_baseline[dtype0-64-1], tests/kernels/helion/test_rotary_embedding.py::TestRotaryEmbeddingGptjCorrectness::test_gptj_vs_baseline[dtype0-64-8], tests/kernels/helion/test_rotary_embedding.py::TestRotaryEmbeddingGptjCorrectness::test_gptj_vs_baseline[dtype0-64-32], tests/kernels/helion/test_rotary_embedding.py::TestRotaryEmbeddingGptjCorrectness::test_gptj_vs_baseline[dtype0-64-128], tests/kernels/helion/test_rotary_embedding.py::TestRotaryEmbeddingGptjCorrectness::test_gptj_vs_baseline[dtype0-128-1], tests/kernels/helion/test_rotary_embedding.py::TestRotaryEmbeddingGptjCorrectness::test_gptj_vs_baseline[dtype0-128-8], tests/kernels/helion/test_rotary_embedding.py::TestRotaryEmbeddingGptjCorrectness::test_gptj_vs_baseline[dtype0-128-32], 
tests/kernels/helion/test_rotary_embedding.py::TestRotaryEmbeddingGptjCorrectness::test_gptj_vs_baseline[dtype0-128-128], tests/kernels/helion/test_rotary_embedding.py::TestRotaryEmbeddingGptjCorrectness::test_gptj_vs_baseline[dtype1-64-1], tests/kernels/helion/test_rotary_embedding.py::TestRotaryEmbeddingGptjCorrectness::test_gptj_vs_baseline[dtype1-64-8], tests/kernels/helion/test_rotary_embedding.py::TestRotaryEmbeddingGptjCorrectness::test_gptj_vs_baseline[dtype1-64-32], tests/kernels/helion/test_rotary_embedding.py::TestRotaryEmbeddingGptjCorrectness::test_gptj_vs_baseline[dtype1-64-128], tests/kernels/helion/test_rotary_embedding.py::TestRotaryEmbeddingGptjCorrectness::test_gptj_vs_baseline[dtype1-128-1], tests/kernels/helion/test_rotary_embedding.py::TestRotaryEmbeddingGptjCorrectness::test_gptj_vs_baseline[dtype1-128-8], tests/kernels/helion/test_rotary_embedding.py::TestRotaryEmbeddingGptjCorrectness::test_gptj_vs_baseline[dtype1-128-32], tests/kernels/helion/test_rotary_embedding.py::TestRotaryEmbeddingGptjCorrectness::test_gptj_vs_baseline[dtype1-128-128], tests/kernels/helion/test_rotary_embedding.py::TestRotaryEmbeddingGptjCorrectness::test_gptj_vs_pytorch, tests/kernels/helion/test_rotary_embedding.py::TestRotaryEmbeddingRegistration::test_both_kernels_registered, tests/kernels/helion/test_rotary_embedding.py::TestRotaryEmbeddingRegistration::test_kernel_wrappers_have_config_picker, tests/kernels/helion/test_rotary_embedding.py::TestRotaryEmbeddingRegistration::test_kernel_wrappers_have_input_generator, tests/kernels/helion/test_rotary_embedding.py::TestRotaryEmbeddingRegistration::test_output_is_not_inplace

tests/kernels/helion/test_rotary_embedding.py::TestRotaryEmbeddingConfigPicker::test_exact_match PASSED                     [  2%]
tests/kernels/helion/test_rotary_embedding.py::TestRotaryEmbeddingConfigPicker::test_ceiling_selection PASSED               [  4%]
tests/kernels/helion/test_rotary_embedding.py::TestRotaryEmbeddingConfigPicker::test_fallback_to_largest PASSED             [  6%]
tests/kernels/helion/test_rotary_embedding.py::TestRotaryEmbeddingConfigPicker::test_closest_rotary_dim PASSED              [  8%]
tests/kernels/helion/test_rotary_embedding.py::TestRotaryEmbeddingConfigPicker::test_fallback_to_default PASSED             [ 10%]
tests/kernels/helion/test_rotary_embedding.py::TestRotaryEmbeddingConfigPicker::test_empty_config_keys PASSED               [ 12%]
tests/kernels/helion/test_rotary_embedding.py::TestRotaryEmbeddingConfigPicker::test_malformed_key_raises PASSED            [ 14%]
tests/kernels/helion/test_rotary_embedding.py::TestRotaryEmbeddingConfigPicker::test_default_skipped_when_valid_keys_exist PASSED [ 17%]
tests/kernels/helion/test_rotary_embedding.py::TestRotaryEmbeddingNeoxCorrectness::test_neox_vs_baseline[dtype0-64-1] PASSED [ 19%]
tests/kernels/helion/test_rotary_embedding.py::TestRotaryEmbeddingNeoxCorrectness::test_neox_vs_baseline[dtype0-64-8] PASSED [ 21%]
tests/kernels/helion/test_rotary_embedding.py::TestRotaryEmbeddingNeoxCorrectness::test_neox_vs_baseline[dtype0-64-32] PASSED [ 23%]
tests/kernels/helion/test_rotary_embedding.py::TestRotaryEmbeddingNeoxCorrectness::test_neox_vs_baseline[dtype0-64-128] PASSED [ 25%]
tests/kernels/helion/test_rotary_embedding.py::TestRotaryEmbeddingNeoxCorrectness::test_neox_vs_baseline[dtype0-128-1] PASSED [ 27%]
tests/kernels/helion/test_rotary_embedding.py::TestRotaryEmbeddingNeoxCorrectness::test_neox_vs_baseline[dtype0-128-8] PASSED [ 29%]
tests/kernels/helion/test_rotary_embedding.py::TestRotaryEmbeddingNeoxCorrectness::test_neox_vs_baseline[dtype0-128-32] PASSED [ 31%]
tests/kernels/helion/test_rotary_embedding.py::TestRotaryEmbeddingNeoxCorrectness::test_neox_vs_baseline[dtype0-128-128] PASSED [ 34%]
tests/kernels/helion/test_rotary_embedding.py::TestRotaryEmbeddingNeoxCorrectness::test_neox_vs_baseline[dtype1-64-1] PASSED [ 36%]
tests/kernels/helion/test_rotary_embedding.py::TestRotaryEmbeddingNeoxCorrectness::test_neox_vs_baseline[dtype1-64-8] PASSED [ 38%]
tests/kernels/helion/test_rotary_embedding.py::TestRotaryEmbeddingNeoxCorrectness::test_neox_vs_baseline[dtype1-64-32] PASSED [ 40%]
tests/kernels/helion/test_rotary_embedding.py::TestRotaryEmbeddingNeoxCorrectness::test_neox_vs_baseline[dtype1-64-128] PASSED [ 42%]
tests/kernels/helion/test_rotary_embedding.py::TestRotaryEmbeddingNeoxCorrectness::test_neox_vs_baseline[dtype1-128-1] PASSED [ 44%]
tests/kernels/helion/test_rotary_embedding.py::TestRotaryEmbeddingNeoxCorrectness::test_neox_vs_baseline[dtype1-128-8] PASSED [ 46%]
tests/kernels/helion/test_rotary_embedding.py::TestRotaryEmbeddingNeoxCorrectness::test_neox_vs_baseline[dtype1-128-32] PASSED [ 48%]
tests/kernels/helion/test_rotary_embedding.py::TestRotaryEmbeddingNeoxCorrectness::test_neox_vs_baseline[dtype1-128-128] PASSED [ 51%]
tests/kernels/helion/test_rotary_embedding.py::TestRotaryEmbeddingNeoxCorrectness::test_neox_vs_pytorch PASSED              [ 53%]
tests/kernels/helion/test_rotary_embedding.py::TestRotaryEmbeddingNeoxCorrectness::test_neox_partial_rope PASSED            [ 55%]
tests/kernels/helion/test_rotary_embedding.py::TestRotaryEmbeddingGptjCorrectness::test_gptj_vs_baseline[dtype0-64-1] PASSED [ 57%]
tests/kernels/helion/test_rotary_embedding.py::TestRotaryEmbeddingGptjCorrectness::test_gptj_vs_baseline[dtype0-64-8] PASSED [ 59%]
tests/kernels/helion/test_rotary_embedding.py::TestRotaryEmbeddingGptjCorrectness::test_gptj_vs_baseline[dtype0-64-32] PASSED [ 61%]
tests/kernels/helion/test_rotary_embedding.py::TestRotaryEmbeddingGptjCorrectness::test_gptj_vs_baseline[dtype0-64-128] PASSED [ 63%]
tests/kernels/helion/test_rotary_embedding.py::TestRotaryEmbeddingGptjCorrectness::test_gptj_vs_baseline[dtype0-128-1] PASSED [ 65%]
tests/kernels/helion/test_rotary_embedding.py::TestRotaryEmbeddingGptjCorrectness::test_gptj_vs_baseline[dtype0-128-8] PASSED [ 68%]
tests/kernels/helion/test_rotary_embedding.py::TestRotaryEmbeddingGptjCorrectness::test_gptj_vs_baseline[dtype0-128-32] PASSED [ 70%]
tests/kernels/helion/test_rotary_embedding.py::TestRotaryEmbeddingGptjCorrectness::test_gptj_vs_baseline[dtype0-128-128] PASSED [ 72%]
tests/kernels/helion/test_rotary_embedding.py::TestRotaryEmbeddingGptjCorrectness::test_gptj_vs_baseline[dtype1-64-1] PASSED [ 74%]
tests/kernels/helion/test_rotary_embedding.py::TestRotaryEmbeddingGptjCorrectness::test_gptj_vs_baseline[dtype1-64-8] PASSED [ 76%]
tests/kernels/helion/test_rotary_embedding.py::TestRotaryEmbeddingGptjCorrectness::test_gptj_vs_baseline[dtype1-64-32] PASSED [ 78%]
tests/kernels/helion/test_rotary_embedding.py::TestRotaryEmbeddingGptjCorrectness::test_gptj_vs_baseline[dtype1-64-128] PASSED [ 80%]
tests/kernels/helion/test_rotary_embedding.py::TestRotaryEmbeddingGptjCorrectness::test_gptj_vs_baseline[dtype1-128-1] PASSED [ 82%]
tests/kernels/helion/test_rotary_embedding.py::TestRotaryEmbeddingGptjCorrectness::test_gptj_vs_baseline[dtype1-128-8] PASSED [ 85%]
tests/kernels/helion/test_rotary_embedding.py::TestRotaryEmbeddingGptjCorrectness::test_gptj_vs_baseline[dtype1-128-32] PASSED [ 87%]
tests/kernels/helion/test_rotary_embedding.py::TestRotaryEmbeddingGptjCorrectness::test_gptj_vs_baseline[dtype1-128-128] PASSED [ 89%]
tests/kernels/helion/test_rotary_embedding.py::TestRotaryEmbeddingGptjCorrectness::test_gptj_vs_pytorch PASSED              [ 91%]
tests/kernels/helion/test_rotary_embedding.py::TestRotaryEmbeddingRegistration::test_both_kernels_registered PASSED         [ 93%]
tests/kernels/helion/test_rotary_embedding.py::TestRotaryEmbeddingRegistration::test_kernel_wrappers_have_config_picker PASSED [ 95%]
tests/kernels/helion/test_rotary_embedding.py::TestRotaryEmbeddingRegistration::test_kernel_wrappers_have_input_generator PASSED [ 97%]
tests/kernels/helion/test_rotary_embedding.py::TestRotaryEmbeddingRegistration::test_output_is_not_inplace SKIPPED (Cur...) [100%]

======================================================== warnings summary =========================================================
<frozen importlib._bootstrap>:488
  <frozen importlib._bootstrap>:488: DeprecationWarning: builtin type SwigPyPacked has no __module__ attribute

<frozen importlib._bootstrap>:488
  <frozen importlib._bootstrap>:488: DeprecationWarning: builtin type SwigPyObject has no __module__ attribute

tests/kernels/helion/test_rotary_embedding.py: 14 warnings
  /usr/local/lib/python3.12/dist-packages/torch/jit/_script.py:362: DeprecationWarning: `torch.jit.script_method` is deprecated. Please switch to `torch.compile` or `torch.export`.
    warnings.warn(

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
=========================================== 46 passed, 1 skipped, 16 warnings in 33.45s ===========================================
sys:1: DeprecationWarning: builtin type swigvarlink has no __module__ attribute

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

Copilot AI review requested due to automatic review settings April 3, 2026 05:34

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces Helion-optimized kernels for rotary positional embeddings (RoPE), supporting both Neox and GPT-J styles. It includes the IR operation definitions, the Triton-based implementations, autotuning configurations for NVIDIA H800, and a comprehensive test suite. Feedback focuses on the GPT-J implementation, specifically addressing a crash and incorrect output when rotary_dim is less than head_size (partial RoPE), and recommends adding a corresponding test case to ensure correctness for this scenario.

Comment on lines +212 to +215
q_rot4 = query.view(num_tokens, num_q_heads, rot_half, 2)
k_rot4 = key.view(num_tokens, num_kv_heads, rot_half, 2)
q_out4 = torch.empty_like(q_rot4)
k_out4 = torch.empty_like(k_rot4)

high

The GPT-J implementation will crash with a RuntimeError if rotary_dim < head_size. The view operation expects the total number of elements to match exactly, but when rotary_dim < head_size, the query and key tensors contain additional unrotated dimensions that are not accounted for in this view.

Furthermore, the implementation is missing the logic to copy the unrotated tail of the embeddings, which is required for partial RoPE support. The comment on line 211 stating that the IR wrapper handles the split is incorrect; the IR op passes the full tensors to the implementation.

To fix this, you should handle the partial RoPE case by copying the tail and applying the 4D view trick only to the rotary portion. Note that if rotary_dim < head_size, the rotary portion is not contiguous across heads, so you may need to ensure contiguity before applying the 4D view.

        q_3d = query.view(num_tokens, num_q_heads, head_size)
        k_3d = key.view(num_tokens, num_kv_heads, head_size)
        q_out_3d = torch.empty_like(q_3d)
        k_out_3d = torch.empty_like(k_3d)

        if rotary_dim < head_size:
            q_out_3d[:, :, rotary_dim:] = q_3d[:, :, rotary_dim:]
            k_out_3d[:, :, rotary_dim:] = k_3d[:, :, rotary_dim:]

        q_rot4 = q_3d[:, :, :rotary_dim].view(num_tokens, num_q_heads, rot_half, 2)
        k_rot4 = k_3d[:, :, :rotary_dim].view(num_tokens, num_kv_heads, rot_half, 2)
        q_out4 = q_out_3d[:, :, :rotary_dim].view(num_tokens, num_q_heads, rot_half, 2)
        k_out4 = k_out_3d[:, :, :rotary_dim].view(num_tokens, num_kv_heads, rot_half, 2)
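
The failure mode and the proposed fix can be illustrated without PyTorch. The sketch below is an illustration only, with hypothetical helper names: the first function checks whether the element counts of the full tensor and the paired 4D shape agree (the condition a reshape-style `view` needs), and the second applies GPT-J-style rotation to the first `rotary_dim` elements of one head while copying the tail unchanged.

```python
def pair_view_numel_matches(num_tokens, num_heads, head_size, rotary_dim):
    """True if (tokens, heads, head_size) has as many elements as
    (tokens, heads, rotary_dim // 2, 2)."""
    full = num_tokens * num_heads * head_size
    paired = num_tokens * num_heads * (rotary_dim // 2) * 2
    return full == paired


def rope_gptj_partial(head, cos, sin, rotary_dim):
    """Rotate interleaved pairs in head[:rotary_dim]; keep head[rotary_dim:]."""
    out = list(head)  # starts as a copy, so the unrotated tail is preserved
    for i in range(rotary_dim // 2):
        x1, x2 = head[2 * i], head[2 * i + 1]
        out[2 * i] = x1 * cos[i] - x2 * sin[i]
        out[2 * i + 1] = x2 * cos[i] + x1 * sin[i]
    return out
```

With `head_size=128` and `rotary_dim=64` the element counts differ by a factor of two, which is the mismatch the review comment describes.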

k_out4[tile_tok, :, tile_pair, 0] = k1 * cos_h2 - k2 * sin_h2
k_out4[tile_tok, :, tile_pair, 1] = k2 * cos_h2 + k1 * sin_h2

return q_out4.view_as(query), k_out4.view_as(key)

high

If rotary_dim < head_size, q_out4 will not have the same number of elements as query, causing view_as(query) to fail. You should return the full-sized output tensors.

Suggested change
- return q_out4.view_as(query), k_out4.view_as(key)
+ return q_out_3d.view_as(query), k_out_3d.view_as(key)

)
torch.testing.assert_close(q_helion.float(), q_ref.float(), rtol=2e-2, atol=2e-2)
torch.testing.assert_close(k_helion.float(), k_ref.float(), rtol=2e-2, atol=2e-2)

high

The test suite is missing a correctness test for GPT-J style RoPE with rotary_dim < head_size (partial RoPE). Given that the current implementation has a bug in this case, adding a test case similar to test_neox_partial_rope is highly recommended to ensure correctness and prevent regressions.

Copilot AI left a comment

Pull request overview

This PR adds Helion/Triton-backed rotary positional embedding (RoPE) kernels and wires them into the vLLM IR op dispatch system, enabling automatic selection on CUDA-capable hardware when Helion is installed.

Changes:

  • Introduces new vLLM IR ops: rotary_embedding_neox and rotary_embedding_gptj.
  • Adds Helion kernel implementations for both Neox-style and GPT-J-style RoPE, including autotune config picking/input generation.
  • Adds H800 pre-tuned Helion configs and a dedicated Helion kernel test suite for RoPE.
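
The config-picking step mentioned above is not shown in this conversation, so the following is only a guess at its behaviour, reconstructed from the test names in the log (exact match, ceiling selection, fallback to largest, closest `rotary_dim`, fallback to default, malformed key raises). The key format `rotary_dim_<d>_num_tokens_<n>` and the function name `pick_config` are hypothetical.

```python
import re

# Hypothetical key format; the real config keys in the PR may differ.
_KEY_RE = re.compile(r"^rotary_dim_(\d+)_num_tokens_(\d+)$")


def pick_config(config_keys, rotary_dim, num_tokens):
    """Return the best-matching key, "default", or None if nothing is usable."""
    parsed = []
    has_default = False
    for key in config_keys:
        if key == "default":
            has_default = True
            continue
        match = _KEY_RE.match(key)
        if match is None:
            raise ValueError(f"Malformed config key: {key!r}")
        parsed.append((int(match.group(1)), int(match.group(2)), key))

    if not parsed:
        # Only fall back to "default" when no shape-specific key exists.
        return "default" if has_default else None

    # Closest rotary_dim first (ties go to the smaller dimension).
    dims = {d for d, _, _ in parsed}
    best_dim = min(dims, key=lambda d: (abs(d - rotary_dim), d))
    candidates = sorted((n, k) for d, n, k in parsed if d == best_dim)

    # Ceiling selection: smallest tuned num_tokens >= the requested value.
    for n, key in candidates:
        if n >= num_tokens:
            return key
    # Otherwise fall back to the largest tuned num_tokens.
    return candidates[-1][1]
```

The ceiling rule reflects a common choice in autotuned kernels: a config tuned for a slightly larger problem is usually a safer match than one tuned for a smaller one.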

Reviewed changes

Copilot reviewed 5 out of 6 changed files in this pull request and generated 6 comments.

| File | Description |
| --- | --- |
| vllm/kernels/helion/ops/rotary_embedding.py | Implements Helion RoPE kernels (Neox + GPT-J) and registers them as IR op implementations. |
| vllm/kernels/helion/configs/_helion_rotary_embedding_neox/nvidia_h800.json | Adds pre-tuned configs for the Neox Helion RoPE kernel on H800. |
| vllm/kernels/helion/configs/_helion_rotary_embedding_gptj/nvidia_h800.json | Adds pre-tuned configs for the GPT-J Helion RoPE kernel on H800. |
| vllm/ir/ops/rotary_embedding.py | Defines the IR op semantic contract and native PyTorch implementations for RoPE (Neox + GPT-J). |
| vllm/ir/ops/__init__.py | Exposes the new IR ops via the vllm.ir.ops package exports. |
| tests/kernels/helion/test_rotary_embedding.py | Adds config-picker, correctness, and registration tests for the new Helion RoPE kernels. |

Comment on lines +206 to +214
# Reshape to 4D so even/odd pairs are the last dim (index 0/1).
# Helion does not support tile-index arithmetic (tile*2) or strided
# slices, so this 4D view is the only way to express interleaved pairs.
# Requires head_size == rotary_dim; when rotary_dim < head_size the
# caller is expected to pre-split query/key to the rotary portion and
# handle the tail separately (the IR wrapper does this).
q_rot4 = query.view(num_tokens, num_q_heads, rot_half, 2)
k_rot4 = key.view(num_tokens, num_kv_heads, rot_half, 2)
q_out4 = torch.empty_like(q_rot4)

Copilot AI Apr 3, 2026

The GPT-J Helion kernel reshapes query/key to (num_tokens, n_heads, rot_half, 2), which only works when head_size == rotary_dim. However, the IR op contract (and existing models like vllm/model_executor/models/gpt_j.py) allows partial RoPE with rotary_dim < head_size. With the current registration (supports_args=None), the Helion impl can be selected for partial RoPE and will error or produce incorrect results. Fix by either implementing the partial-rotation tail handling (rotate first rotary_dim, copy the remaining head_size-rotary_dim) or registering this impl with a supports_args predicate that returns false unless rotary_dim == head_size (and any other required constraints).
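
The second option (gating the implementation) could look roughly like the sketch below. The actual signature expected by `supports_args` is not shown in this conversation, so the predicate name, its parameters, and the `choose_impl` dispatcher are all hypothetical stand-ins used only to show the intended logic.

```python
def full_rotary_only(head_size: int, rotary_dim: int) -> bool:
    """Accept only shapes the 4D pair view can express (assumed constraint)."""
    return rotary_dim == head_size and rotary_dim % 2 == 0


def choose_impl(head_size: int, rotary_dim: int) -> str:
    """Toy dispatcher: use the fast path only when the predicate allows it."""
    return "helion" if full_rotary_only(head_size, rotary_dim) else "native"
```

Gating keeps the kernel simple at the cost of falling back to the native implementation for partial RoPE; implementing the tail copy removes that fallback.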

Copilot uses AI. Check for mistakes.

_HELION_AVAILABLE = has_helion()
_CUDA_ALIKE = current_platform.is_cuda_alike()

Copilot AI Apr 3, 2026

logger is defined but never used, which will trigger Ruff/Pyflakes unused-variable linting (F841) under the repo’s configured checks. Either remove the logger initialization or use it (e.g., for the Helion/CUDA gating paths).

Suggested change
if not _HELION_AVAILABLE:
    logger.debug(
        "Helion is not available; Helion rotary embedding kernels will not "
        "be registered."
    )
elif not _CUDA_ALIKE:
    logger.debug(
        "Current platform is not CUDA-alike; Helion rotary embedding kernels "
        "may not be usable."
    )

Comment on lines +68 to +82
def _skip_if_platform_unsupported(kernel_name: str) -> None:
    """Skip the test if no pre-tuned configs are present for this GPU."""
    if not torch.cuda.is_available():
        pytest.skip("CUDA not available")

    try:
        from vllm.kernels.helion.utils import get_canonical_gpu_name

        platform = get_canonical_gpu_name()
        try:
            config_manager = ConfigManager.get_instance()
        except RuntimeError:
            config_manager = ConfigManager()

        configs = config_manager.get_platform_configs(kernel_name, platform)

Copilot AI Apr 3, 2026

_skip_if_platform_unsupported references ConfigManager, but ConfigManager is only imported when _HELION_AVAILABLE is true. On a CUDA machine without Helion installed, this will raise NameError (not caught) before the test can skip. Add an early guard like if not _HELION_AVAILABLE: pytest.skip(...), or import ConfigManager inside the function within the try/except.
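
One way to structure the early guard is to compute a skip reason before touching any optional symbol. The sketch below uses only the standard library; `module_available` and `skip_reason` are hypothetical helpers, and in the real test the returned string would presumably be passed to `pytest.skip(...)`.

```python
import importlib.util


def module_available(name: str) -> bool:
    """Check whether a module can be found without executing it.

    For dotted names, find_spec imports the parent packages, so a missing
    parent is reported as "not available" rather than raised.
    """
    try:
        return importlib.util.find_spec(name) is not None
    except (ImportError, ValueError):
        return False


def skip_reason(helion_available: bool, cuda_available: bool):
    """Return a skip message, or None when the test can run."""
    if not helion_available:
        return "Helion not installed"
    if not cuda_available:
        return "CUDA not available"
    return None
```

Checking Helion availability first means names that are only imported under the Helion guard are never referenced on machines without Helion.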

Comment on lines +248 to +257
class TestRotaryEmbeddingNeoxCorrectness:
    """Compare rotary_embedding_neox against RotaryEmbedding.forward_static."""

    @pytest.mark.parametrize("num_tokens", [1, 8, 32, 128])
    @pytest.mark.parametrize("rotary_dim", [64, 128])
    @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
    def test_neox_vs_baseline(self, num_tokens, rotary_dim, dtype):
        _skip_if_platform_unsupported("_helion_rotary_embedding_neox")

        head_size = rotary_dim

Copilot AI Apr 3, 2026

Correctness tests call rotary_embedding_neox/gptj even when Helion is not installed (the imported symbols are stubs that raise). These tests should be skipped when _HELION_AVAILABLE is false (e.g., via pytest.mark.skipif on the classes or an early if not _HELION_AVAILABLE: pytest.skip(...) in each test) to avoid failing non-Helion test environments.

    def test_output_is_not_inplace(self):
        """The Helion kernel must return new tensors (out-of-place)."""
        _skip_if_platform_unsupported("rotary_embedding_neox")

Copilot AI Apr 3, 2026

This test passes kernel_name="rotary_embedding_neox" to _skip_if_platform_unsupported, but the pre-tuned configs in this PR are under the registered kernel name _helion_rotary_embedding_neox. As a result, the test will always skip even on supported GPUs. Use _NEOX_KERNEL_NAME (or the same string used by @register_kernel / config directory) here so the test actually exercises the out-of-place guarantee.

Suggested change
- _skip_if_platform_unsupported("rotary_embedding_neox")
+ _skip_if_platform_unsupported(_NEOX_KERNEL_NAME)

Comment on lines +430 to +444
    def test_kernel_wrappers_have_input_generator(self):
        if not _HELION_AVAILABLE:
            pytest.skip("Helion not installed")
        from vllm.kernels.helion.register import get_registered_kernels

        registered = get_registered_kernels()
        for name in [_NEOX_KERNEL_NAME, _GPTJ_KERNEL_NAME]:
            wrapper = registered[name]
            assert wrapper._input_generator is not None, (
                f"Kernel '{name}' has no input generator"
            )
            # Smoke-test: calling get_inputs() on CPU should not crash.
            inputs = wrapper.get_inputs()
            assert len(inputs) > 0

Copilot AI Apr 3, 2026

test_kernel_wrappers_have_input_generator unconditionally calls wrapper.get_inputs() and claims it should work on CPU, but _generate_rope_inputs() creates CUDA tensors. This will fail on hosts where Helion is installed but CUDA is unavailable, and it will also unnecessarily allocate GPU memory during a pure registration test. Either (a) skip this test when not torch.cuda.is_available(), or (b) change the input generator to use CPU tensors when CUDA isn’t available (and/or avoid calling get_inputs() here).


2 participants