
[Model Runner V2] Add Support for XD-RoPE#36817

Merged
njhill merged 16 commits into vllm-project:main from santiramos27:santinor/mrv2-xdrope
Mar 14, 2026

Conversation

@santiramos27
Contributor

@santiramos27 santiramos27 commented Mar 11, 2026

Summary

  • Adds XD-RoPE (variable-dimension RoPE) support to Model Runner V2, enabling models like HunyuanOCR that use 3- or 4-dimensional rotary embeddings
  • Unifies M-RoPE and XD-RoPE into a single RopeState class with one parameterized Triton kernel (NUM_DIMS: tl.constexpr), replacing the previous M-RoPE-only MRopeState
  • The kernel always loads and applies delta (non-zero for M-RoPE, zero for XD-RoPE), avoiding a separate code path per variant
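The delta unification described above can be sketched in plain Python (hypothetical helper and delta values; the real implementation is a single Triton kernel operating on GPU buffers):

```python
def prepare_positions(num_dims, query_positions, delta):
    """Sketch of the unified position prep: each of the num_dims rotary
    dimensions gets the base token position plus a per-dim delta.
    For M-RoPE the delta is generally non-zero; for XD-RoPE it is all
    zeros, so one code path covers both variants."""
    assert len(delta) == num_dims
    return [[d + p for p in query_positions] for d in delta]

# M-RoPE-style: 3 dims with distinct (hypothetical) deltas.
mrope = prepare_positions(3, range(4), [0, 2, 2])
# XD-RoPE-style: 4 dims, zero delta, so all rows are identical.
xdrope = prepare_positions(4, range(4), [0, 0, 0, 0])
```

Because the zero-delta case falls out of the same arithmetic, no per-variant branch is needed in the hot loop.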

Changes

  • vllm/v1/worker/gpu/mm/rope.py (new): Unified RopeState class and get_rope_state() factory. The factory validates the model interface (SupportsMRoPE / SupportsXDRoPE) and dispatches based on
    ModelConfig
  • vllm/v1/worker/gpu/mm/mrope_utils.py (deleted): Replaced by rope.py
  • vllm/v1/worker/gpu/model_states/default.py: Simplified to use get_rope_state() — all M-RoPE/XD-RoPE branching now lives in the factory
  • vllm/v1/worker/gpu/model_runner.py: Pass intermediate_tensors in model_inputs dict directly; non-first PP rank asserts it's not None
  • vllm/v1/worker/gpu/cudagraph_utils.py: Add intermediate_tensors: None for CUDA graph capture with TODO for PP support ([Model Runner V2] Enable piecewise & full CUDA graphs for pipeline parallelism #35162)

Test Result

  • Verified numerical equivalence between V1 and V2 model runners for both XD-RoPE (HunyuanOCR with image input) and M-RoPE (Qwen2-VL-2B-Instruct with image input)
  • Identical OCRBench result on V1 vs V2 model runner for HunyuanOCR and Qwen2-VL-2B-Instruct

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

Signed-off-by: Santino Ramos <elsantinoramos@gmail.com>
@mergify mergify bot added the v1 label Mar 11, 2026
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request adds support for XD-RoPE. The changes include a new XDRopeState class with a Triton kernel to manage XD-RoPE positions, and integration into the DefaultModelState. The core logic seems correct, but I've identified an opportunity to improve the clarity and maintainability of the new Triton kernel by refactoring its indexing logic to be more conventional. The rest of the changes for integrating XD-RoPE support are well-structured.

Comment on lines +56 to +128
    def prepare_xdrope_positions(
        self,
        idx_mapping: torch.Tensor,
        query_start_loc: torch.Tensor,
        prefill_lens: torch.Tensor,
        num_computed_tokens: torch.Tensor,
    ) -> None:
        num_reqs = idx_mapping.shape[0]
        _prepare_xdrope_positions_kernel[(num_reqs,)](
            self.xdrope_positions,
            self.xdrope_positions.stride(0),
            self.prefill_xdrope_positions.gpu,
            self.uses_xdrope_dim * self.max_model_len,
            self.max_model_len,
            idx_mapping,
            query_start_loc,
            prefill_lens,
            num_computed_tokens,
            BLOCK_SIZE=1024,
            USES_XDROPE_DIM=self.uses_xdrope_dim,
        )


@triton.jit
def _prepare_xdrope_positions_kernel(
    xdrope_positions_ptr,
    xdrope_positions_stride,
    prefill_xdrope_positions_ptr,
    prefill_xdrope_positions_stride0,
    prefill_xdrope_positions_stride1,
    idx_mapping_ptr,
    query_start_loc_ptr,
    prefill_lens_ptr,
    num_computed_tokens_ptr,
    BLOCK_SIZE: tl.constexpr,
    USES_XDROPE_DIM: tl.constexpr,
):
    batch_idx = tl.program_id(0)
    req_state_idx = tl.load(idx_mapping_ptr + batch_idx)

    prefill_len = tl.load(prefill_lens_ptr + req_state_idx)
    num_computed = tl.load(num_computed_tokens_ptr + req_state_idx)
    is_prefill = num_computed < prefill_len

    query_start = tl.load(query_start_loc_ptr + batch_idx)
    query_end = tl.load(query_start_loc_ptr + batch_idx + 1)
    query_len = query_end - query_start

    for i in range(0, query_len, BLOCK_SIZE):
        block = i + tl.arange(0, BLOCK_SIZE)
        mask = block < query_len
        orig_pos = num_computed + block

        for j in tl.static_range(USES_XDROPE_DIM):
            if is_prefill:
                # Read from pre-computed XD-RoPE positions.
                pos = tl.load(
                    prefill_xdrope_positions_ptr
                    + req_state_idx * prefill_xdrope_positions_stride0
                    + j * prefill_xdrope_positions_stride1
                    + orig_pos,
                    mask=mask,
                )
            else:
                pos = orig_pos
            tl.store(
                xdrope_positions_ptr
                + j * xdrope_positions_stride
                + query_start
                + block,
                pos,
                mask=mask,
            )
Contributor


high

The current implementation of _prepare_xdrope_positions_kernel and its call site in prepare_xdrope_positions is functionally correct but hard to understand due to unconventional argument passing for tensor indexing. The arguments prefill_xdrope_positions_stride0 and prefill_xdrope_positions_stride1 are not standard tensor strides, making the address calculation logic within the kernel confusing.

For better readability and maintainability, I suggest refactoring to use standard stride-based indexing. This involves passing the actual stride of the prefill_xdrope_positions tensor and adjusting the kernel to calculate the row index explicitly. This change does not alter the logic but makes the code much clearer and aligned with common practices for writing Triton kernels.

    def prepare_xdrope_positions(
        self,
        idx_mapping: torch.Tensor,
        query_start_loc: torch.Tensor,
        prefill_lens: torch.Tensor,
        num_computed_tokens: torch.Tensor,
    ) -> None:
        num_reqs = idx_mapping.shape[0]
        prefill_positions = self.prefill_xdrope_positions.gpu
        _prepare_xdrope_positions_kernel[(num_reqs,)](
            self.xdrope_positions,
            self.xdrope_positions.stride(0),
            prefill_positions,
            prefill_positions.stride(0),
            idx_mapping,
            query_start_loc,
            prefill_lens,
            num_computed_tokens,
            BLOCK_SIZE=1024,
            USES_XDROPE_DIM=self.uses_xdrope_dim,
        )


@triton.jit
def _prepare_xdrope_positions_kernel(
    xdrope_positions_ptr,
    xdrope_positions_stride,
    prefill_xdrope_positions_ptr,
    prefill_xdrope_positions_stride0,
    idx_mapping_ptr,
    query_start_loc_ptr,
    prefill_lens_ptr,
    num_computed_tokens_ptr,
    BLOCK_SIZE: tl.constexpr,
    USES_XDROPE_DIM: tl.constexpr,
):
    batch_idx = tl.program_id(0)
    req_state_idx = tl.load(idx_mapping_ptr + batch_idx)

    prefill_len = tl.load(prefill_lens_ptr + req_state_idx)
    num_computed = tl.load(num_computed_tokens_ptr + req_state_idx)
    is_prefill = num_computed < prefill_len

    query_start = tl.load(query_start_loc_ptr + batch_idx)
    query_end = tl.load(query_start_loc_ptr + batch_idx + 1)
    query_len = query_end - query_start

    for i in range(0, query_len, BLOCK_SIZE):
        block = i + tl.arange(0, BLOCK_SIZE)
        mask = block < query_len
        orig_pos = num_computed + block

        for j in tl.static_range(USES_XDROPE_DIM):
            if is_prefill:
                # Read from pre-computed XD-RoPE positions.
                row_idx = req_state_idx * USES_XDROPE_DIM + j
                pos = tl.load(
                    prefill_xdrope_positions_ptr
                    + row_idx * prefill_xdrope_positions_stride0
                    + orig_pos,
                    mask=mask,
                )
            else:
                pos = orig_pos
            tl.store(
                xdrope_positions_ptr
                + j * xdrope_positions_stride
                + query_start
                + block,
                pos,
                mask=mask,
            )
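
The equivalence behind this suggestion, baked-in pseudo-strides at the call site versus an explicit row index against the tensor's real stride, can be checked with a small pure-Python sketch (illustrative sizes, hypothetical helper names):

```python
num_reqs, num_dims, max_len = 2, 3, 5
# Flat row-major buffer standing in for a (num_reqs * num_dims, max_len)
# tensor; stride0 is the number of elements per row.
flat = list(range(num_reqs * num_dims * max_len))
stride0 = max_len

def load_pseudo_stride(req, dim, pos):
    # Original kernel: two "strides" pre-multiplied at the call site.
    return flat[req * (num_dims * max_len) + dim * max_len + pos]

def load_row_index(req, dim, pos):
    # Suggested refactor: explicit row index times the real stride.
    row = req * num_dims + dim
    return flat[row * stride0 + pos]

assert all(
    load_pseudo_stride(r, d, p) == load_row_index(r, d, p)
    for r in range(num_reqs)
    for d in range(num_dims)
    for p in range(max_len)
)
```

Both address the same flat offsets; the refactor only makes the row computation visible inside the kernel instead of hiding it in the launch arguments.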

@mergify mergify bot added the nvidia label Mar 11, 2026
@santiramos27 santiramos27 marked this pull request as ready for review March 11, 2026 20:56
Member

@njhill njhill left a comment


Thanks @santiramos27!

:, : input_batch.num_tokens_after_padding
]
return {"positions": mrope_positions}
if self.uses_mrope:
Member


May be good to keep common case at the top, i.e.

if not self.uses_mrope and not self.uses_xdrope:
    # Common case (1D positions).
    return {}

if self.uses_mrope:
    # ...
    return ...

# xdrope logic
return ...

Collaborator


+1. I intentionally put the early exit for the common case at the top. It'd be nice to keep it!

Contributor Author


Early exit makes sense. Updated.

max_model_len=self.max_model_len,
device=self.device,
)
self.uses_xdrope_dim = self.model_config.uses_xdrope_dim
Member


This can be a bool

Suggested change
self.uses_xdrope_dim = self.model_config.uses_xdrope_dim
self.uses_xdrope = self.model_config.uses_xdrope_dim > 0

I actually like using self.xdrope_state = XDRopeState | None for this rather than separate bool, but I don't think @WoosukKwon likes that :)

Or, I feel it might be better/clearer to instead have something like a pe_type enum (MROPE, XDROPE or None)
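A rough sketch of the pe_type enum idea (name and members are hypothetical, not from the PR):

```python
from enum import Enum, auto

class PEType(Enum):
    NONE = auto()    # standard 1-D positions (the common case)
    MROPE = auto()
    XDROPE = auto()

def pe_type_for(uses_mrope: bool, uses_xdrope_dim: int) -> PEType:
    # The two variants are mutually exclusive, so a single enum can
    # replace two booleans that must never both be true.
    if uses_mrope:
        return PEType.MROPE
    if uses_xdrope_dim > 0:
        return PEType.XDROPE
    return PEType.NONE
```

The enum makes the "can't both be true" invariant structural rather than a convention the caller has to remember.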

Collaborator

@WoosukKwon WoosukKwon Mar 11, 2026


+1. I'm also ok with if self.xdrope_state is not None.

Contributor Author


I will change to using if self.xdrope_state is not None. From what I can tell, these two PE variants are the only ones getting special treatment in the model runner. Should we wait to introduce an enum until we have more PEs to support? There is already a good amount of duplication between mrope/xdrope that can probably be consolidated once we have more of them to worry about.

Member


Sure, I just thought it's maybe clearer to have a single enum than two bools that can't both be true.

@@ -960,7 +961,6 @@ def execute_model(
# Update for non-first PP ranks.
model_inputs["input_ids"] = None
model_inputs["inputs_embeds"] = None
Member


we could consider adding this for clarity

Suggested change
model_inputs["inputs_embeds"] = None
model_inputs["inputs_embeds"] = None
assert intermediate_tensors is not None


mergify bot commented Mar 11, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @santiramos27.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Mar 11, 2026
@mergify mergify bot removed the needs-rebase label Mar 11, 2026

mergify bot commented Mar 11, 2026

Hi @santiramos27, the pre-commit checks have failed. Please run:

uv pip install pre-commit
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy failing?
mypy is run differently in CI. If the failure is related to this check, please use the following command to run it locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10

@santiramos27 santiramos27 requested a review from njhill March 11, 2026 22:37
from vllm.v1.worker.gpu.buffer_utils import StagedWriteTensor


class XDRopeState:
Member


Looking at the code now ... it would simplify things to have a shared RopeState interface with init_prefill_positions, prepare_positions and get_positions(num_tokens) methods
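A minimal version of the shared interface described here might look like the following (method names come from the comment; signatures are guesses, not the merged code):

```python
from typing import Any, Protocol, runtime_checkable

@runtime_checkable
class RopeStateProtocol(Protocol):
    """Shared interface for M-RoPE / XD-RoPE position state."""

    def init_prefill_positions(self, req_state_idx: int) -> None:
        """Precompute per-request multi-dim positions at prefill time."""
        ...

    def prepare_positions(self, idx_mapping: Any, query_start_loc: Any) -> None:
        """Fill the per-step position buffer for the current batch."""
        ...

    def get_positions(self, num_tokens: int) -> Any:
        """Return the (num_dims, num_tokens) slice consumed by the model."""
        ...
```

With a structural Protocol, each variant's state class satisfies the interface simply by defining the three methods; no inheritance is required.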

Contributor Author


Done. I followed the interface pattern under model_states/ and did a slight renaming for clarity. Let me know what you think.

Member

@njhill njhill left a comment


Thanks @santiramos27 it looks good to me now.

I can see that there's a natural generalization of both of these, where the mrope uses dim 3 and xdrope uses bias 0. We could have a single kernel and other state logic used by both which would eliminate more of the code. But I have a feeling that @WoosukKwon won't like that idea :-)

@santiramos27
Contributor Author

Thanks for the feedback @njhill! I incorporated your edits.

@WoosukKwon wdyt about consolidating the /mm/rope triton kernels?

@santiramos27
Contributor Author

@njhill Updated with the unified position prep Triton kernel -- this feels cleaner to me than the redundant util files + interface.

Member

@njhill njhill left a comment


Thanks @santiramos27 this looks great. Just have one more suggestion.

Member

@njhill njhill left a comment


Thanks @santiramos27, just a couple of last minor things!

prefill_token_ids: list[int],
mm_features: list,
) -> None:
if self.has_delta:
Member


I guess this should be renamed to e.g. is_mrope now?

Contributor Author


Since get_rope_state() already handles M-RoPE vs XD-RoPE dispatch, has_delta feels more natural here — it describes the behavioral difference rather than leaking model-type identity into the unified RopeState (i.e. mrope and xdrope are instances of RopeState). Wdyt?


def apply_staged_writes(self) -> None:
self.prefill_positions.apply_write()
self.prefill_delta.copy_to_uva()
Member


probably still worth has_delta (or is_mrope) check here?

Contributor Author


You're right -- for xdrope the UVA tensor is zero-initialized and we never write to it, so this copy is unnecessary.
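The resulting guard could look roughly like this sketch (the stub buffer classes are hypothetical stand-ins for the real staged-write tensors):

```python
class _StagedBuf:
    """Stand-in for a staged-write position buffer."""
    def __init__(self):
        self.applied = False
    def apply_write(self):
        self.applied = True

class _DeltaBuf:
    """Stand-in for the UVA-backed delta buffer."""
    def __init__(self):
        self.copied = False
    def copy_to_uva(self):
        self.copied = True

class RopeStateSketch:
    def __init__(self, has_delta: bool):
        self.has_delta = has_delta
        self.prefill_positions = _StagedBuf()
        self.prefill_delta = _DeltaBuf()

    def apply_staged_writes(self) -> None:
        self.prefill_positions.apply_write()
        if self.has_delta:
            # XD-RoPE's delta buffer stays zero-initialized and is never
            # written, so the copy is only needed for M-RoPE.
            self.prefill_delta.copy_to_uva()
```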

max_model_len=max_model_len,
device=device,
)
elif model_config.uses_xdrope_dim > 0:
Member


nit

Suggested change
elif model_config.uses_xdrope_dim > 0:
if model_config.uses_xdrope_dim > 0:

Contributor Author


They're mutually exclusive, but I forgot I added an early return here :')

Member


Yeah, it's not incorrect; maybe more a code style preference :)

Member

@njhill njhill left a comment


Thanks @santiramos27 for the nice work!

@github-project-automation github-project-automation bot moved this to Ready in NVIDIA Mar 14, 2026
@njhill njhill added the ready ONLY add when PR is ready to merge/full CI is needed label Mar 14, 2026
@njhill njhill merged commit 3ed46f3 into vllm-project:main Mar 14, 2026
50 checks passed
@github-project-automation github-project-automation bot moved this from Ready to Done in NVIDIA Mar 14, 2026
Lucaskabela pushed a commit to Lucaskabela/vllm that referenced this pull request Mar 17, 2026
wendyliu235 pushed a commit to wendyliu235/vllm-public that referenced this pull request Mar 18, 2026
fxdawnn pushed a commit to fxdawnn/vllm that referenced this pull request Mar 19, 2026
khairulkabir1661 pushed a commit to khairulkabir1661/vllm that referenced this pull request Mar 27, 2026
Monishver11 pushed a commit to Monishver11/vllm that referenced this pull request Mar 27, 2026

Labels

nvidia ready ONLY add when PR is ready to merge/full CI is needed v1

Projects

Status: Done


3 participants