[Model Runner V2] Add Support for XD-RoPE #36817
Conversation
Signed-off-by: Santino Ramos <elsantinoramos@gmail.com>
Code Review
This pull request adds support for XD-RoPE. The changes include a new XDRopeState class with a Triton kernel to manage XD-RoPE positions, and integration into the DefaultModelState. The core logic seems correct, but I've identified an opportunity to improve the clarity and maintainability of the new Triton kernel by refactoring its indexing logic to be more conventional. The rest of the changes for integrating XD-RoPE support are well-structured.
```python
def prepare_xdrope_positions(
    self,
    idx_mapping: torch.Tensor,
    query_start_loc: torch.Tensor,
    prefill_lens: torch.Tensor,
    num_computed_tokens: torch.Tensor,
) -> None:
    num_reqs = idx_mapping.shape[0]
    _prepare_xdrope_positions_kernel[(num_reqs,)](
        self.xdrope_positions,
        self.xdrope_positions.stride(0),
        self.prefill_xdrope_positions.gpu,
        self.uses_xdrope_dim * self.max_model_len,
        self.max_model_len,
        idx_mapping,
        query_start_loc,
        prefill_lens,
        num_computed_tokens,
        BLOCK_SIZE=1024,
        USES_XDROPE_DIM=self.uses_xdrope_dim,
    )


@triton.jit
def _prepare_xdrope_positions_kernel(
    xdrope_positions_ptr,
    xdrope_positions_stride,
    prefill_xdrope_positions_ptr,
    prefill_xdrope_positions_stride0,
    prefill_xdrope_positions_stride1,
    idx_mapping_ptr,
    query_start_loc_ptr,
    prefill_lens_ptr,
    num_computed_tokens_ptr,
    BLOCK_SIZE: tl.constexpr,
    USES_XDROPE_DIM: tl.constexpr,
):
    batch_idx = tl.program_id(0)
    req_state_idx = tl.load(idx_mapping_ptr + batch_idx)

    prefill_len = tl.load(prefill_lens_ptr + req_state_idx)
    num_computed = tl.load(num_computed_tokens_ptr + req_state_idx)
    is_prefill = num_computed < prefill_len

    query_start = tl.load(query_start_loc_ptr + batch_idx)
    query_end = tl.load(query_start_loc_ptr + batch_idx + 1)
    query_len = query_end - query_start

    for i in range(0, query_len, BLOCK_SIZE):
        block = i + tl.arange(0, BLOCK_SIZE)
        mask = block < query_len
        orig_pos = num_computed + block

        for j in tl.static_range(USES_XDROPE_DIM):
            if is_prefill:
                # Read from pre-computed XD-RoPE positions.
                pos = tl.load(
                    prefill_xdrope_positions_ptr
                    + req_state_idx * prefill_xdrope_positions_stride0
                    + j * prefill_xdrope_positions_stride1
                    + orig_pos,
                    mask=mask,
                )
            else:
                pos = orig_pos
            tl.store(
                xdrope_positions_ptr
                + j * xdrope_positions_stride
                + query_start
                + block,
                pos,
                mask=mask,
            )
```
The current implementation of _prepare_xdrope_positions_kernel and its call site in prepare_xdrope_positions is functionally correct but hard to understand due to unconventional argument passing for tensor indexing. The arguments prefill_xdrope_positions_stride0 and prefill_xdrope_positions_stride1 are not standard tensor strides, making the address calculation logic within the kernel confusing.
For better readability and maintainability, I suggest refactoring to use standard stride-based indexing. This involves passing the actual stride of the prefill_xdrope_positions tensor and adjusting the kernel to calculate the row index explicitly. This change does not alter the logic but makes the code much clearer and aligned with common practices for writing Triton kernels.
```python
def prepare_xdrope_positions(
    self,
    idx_mapping: torch.Tensor,
    query_start_loc: torch.Tensor,
    prefill_lens: torch.Tensor,
    num_computed_tokens: torch.Tensor,
) -> None:
    num_reqs = idx_mapping.shape[0]
    prefill_positions = self.prefill_xdrope_positions.gpu
    _prepare_xdrope_positions_kernel[(num_reqs,)](
        self.xdrope_positions,
        self.xdrope_positions.stride(0),
        prefill_positions,
        prefill_positions.stride(0),
        idx_mapping,
        query_start_loc,
        prefill_lens,
        num_computed_tokens,
        BLOCK_SIZE=1024,
        USES_XDROPE_DIM=self.uses_xdrope_dim,
    )


@triton.jit
def _prepare_xdrope_positions_kernel(
    xdrope_positions_ptr,
    xdrope_positions_stride,
    prefill_xdrope_positions_ptr,
    prefill_xdrope_positions_stride0,
    idx_mapping_ptr,
    query_start_loc_ptr,
    prefill_lens_ptr,
    num_computed_tokens_ptr,
    BLOCK_SIZE: tl.constexpr,
    USES_XDROPE_DIM: tl.constexpr,
):
    batch_idx = tl.program_id(0)
    req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
    prefill_len = tl.load(prefill_lens_ptr + req_state_idx)
    num_computed = tl.load(num_computed_tokens_ptr + req_state_idx)
    is_prefill = num_computed < prefill_len
    query_start = tl.load(query_start_loc_ptr + batch_idx)
    query_end = tl.load(query_start_loc_ptr + batch_idx + 1)
    query_len = query_end - query_start
    for i in range(0, query_len, BLOCK_SIZE):
        block = i + tl.arange(0, BLOCK_SIZE)
        mask = block < query_len
        orig_pos = num_computed + block
        for j in tl.static_range(USES_XDROPE_DIM):
            if is_prefill:
                # Read from pre-computed XD-RoPE positions.
                row_idx = req_state_idx * USES_XDROPE_DIM + j
                pos = tl.load(
                    prefill_xdrope_positions_ptr
                    + row_idx * prefill_xdrope_positions_stride0
                    + orig_pos,
                    mask=mask,
                )
            else:
                pos = orig_pos
            tl.store(
                xdrope_positions_ptr
                + j * xdrope_positions_stride
                + query_start
                + block,
                pos,
                mask=mask,
            )
```
```python
        :, : input_batch.num_tokens_after_padding
    ]
    return {"positions": mrope_positions}
if self.uses_mrope:
```
May be good to keep the common case at the top, i.e.

```python
if not self.uses_mrope and not self.uses_xdrope:
    # Common case (1D positions).
    return {}
if self.uses_mrope:
    # ...
    return ...
# xdrope logic
return ...
```
+1. I intentionally put the early exit for the common case at the top. It'd be nice to keep it!
Early exit makes sense. Updated.
```python
    max_model_len=self.max_model_len,
    device=self.device,
)
self.uses_xdrope_dim = self.model_config.uses_xdrope_dim
```
This can be a bool:

```diff
-self.uses_xdrope_dim = self.model_config.uses_xdrope_dim
+self.uses_xdrope = self.model_config.uses_xdrope_dim > 0
```
I actually like using `self.xdrope_state: XDRopeState | None` for this rather than a separate bool, but I don't think @WoosukKwon likes that :)
Or, I feel it might be better/clearer to instead have something like a `pe_type` enum (`MROPE`, `XDROPE`, or `None`).
+1. I'm also ok with `if self.xdrope_state is not None`.
I will change to using `if self.xdrope_state is not None`. From what I can tell, these two PE variants are the only ones getting special treatment in the model runner. Should we wait to use enums until we have more PEs to support? There's already a good amount of duplication between mrope/xdrope that can probably be consolidated once we have more of them to worry about.
Sure, I just thought it might be clearer to have a single enum than two bools which can't both be true.
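For illustration, a minimal sketch of what such an enum could look like. `PEType` and `pe_type_from_config` are hypothetical names, not code from this PR; the only property taken from the discussion is that M-RoPE and XD-RoPE are mutually exclusive:

```python
from enum import Enum, auto


class PEType(Enum):
    """Hypothetical enum for the positional-encoding variants."""

    NONE = auto()    # common case: plain 1D positions
    MROPE = auto()   # multimodal rotary embedding (M-RoPE)
    XDROPE = auto()  # XD-RoPE


def pe_type_from_config(uses_mrope: bool, uses_xdrope_dim: int) -> PEType:
    # The two variants are mutually exclusive, so one enum value can
    # replace two bools that must never both be true.
    if uses_mrope:
        return PEType.MROPE
    if uses_xdrope_dim > 0:
        return PEType.XDROPE
    return PEType.NONE
```

Dispatch in the model runner would then be a single comparison rather than two independent flags.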
```diff
@@ -960,7 +961,6 @@ def execute_model(
     # Update for non-first PP ranks.
     model_inputs["input_ids"] = None
     model_inputs["inputs_embeds"] = None
```
We could consider adding this for clarity:

```diff
 model_inputs["inputs_embeds"] = None
+assert intermediate_tensors is not None
```
This pull request has merge conflicts that must be resolved before it can be merged.
Hi @santiramos27, the pre-commit checks have failed. Please run:

```shell
uv pip install pre-commit
pre-commit install
pre-commit run --all-files
```

Then, commit the changes and push to your branch.
```python
from vllm.v1.worker.gpu.buffer_utils import StagedWriteTensor


class XDRopeState:
```
Looking at the code now... it would simplify things to have a shared `RopeState` interface with `init_prefill_positions`, `prepare_positions`, and `get_positions(num_tokens)` methods.
Done. I followed the interface pattern under model_states/ and did a slight renaming for clarity. Let me know what you think.
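As a rough illustration of the shared interface, here is a sketch using a structural `Protocol`. The method names follow the comment above; the argument lists and docstrings are assumptions, not the PR's actual signatures (which take GPU tensors):

```python
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class RopeState(Protocol):
    """Hypothetical shared interface for M-RoPE / XD-RoPE position state."""

    def init_prefill_positions(
        self, req_state_idx: int, prefill_token_ids: list[int]
    ) -> None:
        """Precompute per-request positions when a request is added."""
        ...

    def prepare_positions(self, idx_mapping: Any, query_start_loc: Any) -> None:
        """Fill the per-step positions buffer for the current batch."""
        ...

    def get_positions(self, num_tokens: int) -> Any:
        """Return the positions buffer sliced to the current token count."""
        ...
```

Any class providing these three methods satisfies the interface structurally, so the model state can hold a `RopeState | None` without caring which variant it is.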
njhill
left a comment
Thanks @santiramos27 it looks good to me now.
I can see that there's a natural generalization of both of these, where the mrope uses dim 3 and xdrope uses bias 0. We could have a single kernel and other state logic used by both which would eliminate more of the code. But I have a feeling that @WoosukKwon won't like that idea :-)
Thanks for the feedback @njhill! I incorporated your edits. @WoosukKwon wdyt about consolidating the
@njhill Updated with the unified position prep Triton kernel -- this feels cleaner to me than the redundant util files + interface.
njhill
left a comment
Thanks @santiramos27 this looks great. Just have one more suggestion.
njhill
left a comment
Thanks @santiramos27, just a couple of last minor things!
```python
    prefill_token_ids: list[int],
    mm_features: list,
) -> None:
    if self.has_delta:
```
I guess this should be renamed to e.g. `is_mrope` now?
Since `get_rope_state()` already handles M-RoPE vs XD-RoPE dispatch, `has_delta` feels more natural here: it describes the behavioral difference rather than leaking model-type identity into the unified `RopeState` (i.e. mrope and xdrope are both instances of `RopeState`). Wdyt?
vllm/v1/worker/gpu/mm/rope.py (outdated)

```python
def apply_staged_writes(self) -> None:
    self.prefill_positions.apply_write()
    self.prefill_delta.copy_to_uva()
```
Probably still worth a `has_delta` (or `is_mrope`) check here?
You're right -- for xdrope the UVA tensor is zero-initialized and we never write to it, so this copy is unnecessary.
vllm/v1/worker/gpu/mm/rope.py (outdated)

```python
    max_model_len=max_model_len,
    device=device,
)
elif model_config.uses_xdrope_dim > 0:
```
nit:

```diff
-elif model_config.uses_xdrope_dim > 0:
+if model_config.uses_xdrope_dim > 0:
```
They're mutually exclusive, but I forgot I added an early return here :')
Yeah, it's not incorrect; maybe more a code style preference :)
njhill
left a comment
Thanks @santiramos27 for the nice work!
Signed-off-by: Monishver Chandrasekaran <monishverchandrasekaran@gmail.com>
Summary
- Adds XD-RoPE support for models like HunyuanOCR that use 3 or 4-dim rotary embeddings
- Unified position-prep Triton kernel parameterized by the number of position dims (`NUM_DIMS: tl.constexpr`), replacing the previous M-RoPE-only `MRopeState`

Changes
- `vllm/v1/worker/gpu/mm/rope.py` (new): unified `RopeState` class and `get_rope_state()` factory. The factory validates the model interface (`SupportsMRoPE` / `SupportsXDRoPE`) and dispatches based on `ModelConfig`
- `vllm/v1/worker/gpu/mm/mrope_utils.py` (deleted): replaced by `rope.py`
- `vllm/v1/worker/gpu/model_states/default.py`: simplified to use `get_rope_state()`; all M-RoPE/XD-RoPE branching now lives in the factory
- `vllm/v1/worker/gpu/model_runner.py`: pass `intermediate_tensors` in the `model_inputs` dict directly; non-first PP ranks assert it's not None
- `vllm/v1/worker/gpu/cudagraph_utils.py`: add `intermediate_tensors: None` for CUDA graph capture, with a TODO for PP support ([Model Runner V2] Enable piecewise & full CUDA graphs for pipeline parallelism #35162)
Essential Elements of an Effective PR Description Checklist
supported_models.mdandexamplesfor a new model.