-
-
Notifications
You must be signed in to change notification settings - Fork 18.6k
[Model Runner V2] Add Support for XD-RoPE #36817
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 14 commits
85f2b8b
9d2fe4a
68e5caa
9db9d80
f4b6f6a
49c7517
9d7fe65
32155bb
7d790f9
19afa5f
4f3b6fb
728f40c
3dbfe9f
1ba5404
1209251
8215e71
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
This file was deleted.
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -0,0 +1,196 @@ | ||||||
| # SPDX-License-Identifier: Apache-2.0 | ||||||
| # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||||||
| from typing import cast | ||||||
|
|
||||||
| import torch | ||||||
| import torch.nn as nn | ||||||
|
|
||||||
| from vllm.config import ModelConfig | ||||||
| from vllm.model_executor.models.interfaces import SupportsMRoPE, SupportsXDRoPE | ||||||
| from vllm.triton_utils import tl, triton | ||||||
| from vllm.v1.worker.gpu.buffer_utils import StagedWriteTensor, UvaBackedTensor | ||||||
|
|
||||||
|
|
||||||
| class RopeState: | ||||||
| """Unified state for multi-dimensional RoPE variants (M-RoPE, XD-RoPE). | ||||||
|
|
||||||
| M-RoPE: 3 dims, uses position delta for decode. | ||||||
| XD-RoPE: 3 or 4 dims, delta is 0 (decode uses orig_pos for all dims). | ||||||
|
|
||||||
| NOTE: `positions` is implemented with one additional dummy position on | ||||||
| purpose to make it non-contiguous so that it can work with torch compile. | ||||||
| See detailed explanation in | ||||||
| https://github.com/vllm-project/vllm/pull/12128#discussion_r1926431923 | ||||||
|
|
||||||
| NOTE: When M-RoPE is enabled, position ids are 3D regardless of the | ||||||
| modality of inputs. For text-only inputs, each dimension has identical | ||||||
| position IDs, making M-RoPE functionally equivalent to 1D-RoPE. | ||||||
| See page 5 of https://arxiv.org/abs/2409.12191 | ||||||
| """ | ||||||
|
|
||||||
def __init__(
    self,
    num_dims: int,
    has_delta: bool,
    max_num_reqs: int,
    max_num_tokens: int,
    max_model_len: int,
    device: torch.device,
):
    """Store the sizing parameters and allocate the position buffers.

    Args:
        num_dims: Number of RoPE position dimensions kept per token.
        has_delta: True when a per-request position delta is used
            (M-RoPE); False when the delta stays 0 (XD-RoPE).
        max_num_reqs: Maximum number of requests tracked at once.
        max_num_tokens: Maximum number of tokens in one batch.
        max_model_len: Maximum sequence length of a single request.
        device: Device on which the buffers are allocated.
    """
    self.num_dims = num_dims
    self.has_delta = has_delta
    self.max_num_reqs = max_num_reqs
    self.max_num_tokens = max_num_tokens
    self.max_model_len = max_model_len
    self.device = device

    # One row per (request, dim) pair, one column per token position.
    # NOTE(woosuk): This buffer can be extremely large (e.g., several GBs),
    # wasting a lot of CPU memory.
    prefill_shape = (max_num_reqs * num_dims, max_model_len)
    self.prefill_positions = StagedWriteTensor(
        prefill_shape,
        dtype=torch.int32,
        device=device,
        uva_instead_of_gpu=True,
    )

    # One extra (dummy) column beyond max_num_tokens is allocated on purpose.
    positions_shape = (num_dims, max_num_tokens + 1)
    self.positions = torch.zeros(
        positions_shape, dtype=torch.int64, device=device
    )

    # Per-request delta: non-zero for M-RoPE, always 0 for XD-RoPE.
    self.prefill_delta = UvaBackedTensor(max_num_reqs, dtype=torch.int32)
|
|
||||||
def init_prefill_positions(
    self,
    req_idx: int,
    model: nn.Module,
    prefill_token_ids: list[int],
    mm_features: list,
) -> None:
    """Compute a request's prefill positions and stage them for upload.

    Args:
        req_idx: Index of the request's slot in the per-request buffers.
        model: Model that provides the M-RoPE or XD-RoPE position helper.
        prefill_token_ids: Token ids of the request's prefill.
        mm_features: Multimodal features forwarded to the position helper.
    """
    if not self.has_delta:
        # XD-RoPE: the helper returns positions only; no delta is recorded.
        xdrope_model = cast(SupportsXDRoPE, model)
        prefill_positions = xdrope_model.get_xdrope_input_positions(
            prefill_token_ids, mm_features
        )
    else:
        # M-RoPE: the helper also returns the delta used after prefill.
        mrope_model = cast(SupportsMRoPE, model)
        prefill_positions, delta = mrope_model.get_mrope_input_positions(
            prefill_token_ids, mm_features
        )
        self.prefill_delta.np[req_idx] = delta

    # A request owns num_dims consecutive rows, starting at this row.
    first_row = self.num_dims * req_idx
    for dim in range(self.num_dims):
        dim_positions = prefill_positions[dim].tolist()
        self.prefill_positions.stage_write(first_row + dim, 0, dim_positions)
|
|
||||||
def apply_staged_writes(self) -> None:
    """Flush staged prefill positions and, when used, the position deltas.

    The delta buffer is only written when ``has_delta`` is true, so the
    delta copy is skipped otherwise: the buffer is never modified in that
    case and there is nothing new to publish.
    """
    self.prefill_positions.apply_write()
    if self.has_delta:
        self.prefill_delta.copy_to_uva()
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Probably still worth a `has_delta` (or `is_mrope`) check here?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. You're right -- for XD-RoPE the UVA tensor is zero-initialized and we never write to it, so this copy is unnecessary. |
||||||
|
|
||||||
def get_positions(self, num_tokens: int) -> torch.Tensor:
    """Return the first ``num_tokens`` columns of every position row."""
    all_positions = self.positions
    return all_positions[:, 0:num_tokens]
|
|
||||||
def prepare_positions(
    self,
    idx_mapping: torch.Tensor,
    query_start_loc: torch.Tensor,
    prefill_lens: torch.Tensor,
    num_computed_tokens: torch.Tensor,
) -> None:
    """Launch the kernel that fills ``self.positions`` for the batch.

    One kernel program is launched per entry of ``idx_mapping``.
    """
    grid = (idx_mapping.shape[0],)

    # Element offsets inside prefill_positions: a request owns num_dims
    # consecutive rows, and each row holds max_model_len entries.
    per_request_stride = self.num_dims * self.max_model_len
    per_dim_stride = self.max_model_len

    _prepare_rope_positions_kernel[grid](
        self.positions,
        self.positions.stride(0),
        self.prefill_positions.gpu,
        per_request_stride,
        per_dim_stride,
        self.prefill_delta.gpu,
        idx_mapping,
        query_start_loc,
        prefill_lens,
        num_computed_tokens,
        BLOCK_SIZE=1024,
        NUM_DIMS=self.num_dims,
    )
|
|
||||||
|
|
||||||
def get_rope_state(
    model_config: ModelConfig,
    model: nn.Module,
    max_num_reqs: int,
    max_num_tokens: int,
    max_model_len: int,
    device: torch.device,
) -> RopeState | None:
    """Build the RopeState for a multi-dimensional RoPE model, if any.

    Returns a 3-dim state with delta for M-RoPE models, a state with
    ``uses_xdrope_dim`` dims and no delta for XD-RoPE models, and ``None``
    when the model uses neither.
    """
    sizing = dict(
        max_num_reqs=max_num_reqs,
        max_num_tokens=max_num_tokens,
        max_model_len=max_model_len,
        device=device,
    )

    if model_config.uses_mrope:
        assert isinstance(model, SupportsMRoPE)
        return RopeState(num_dims=3, has_delta=True, **sizing)

    # The M-RoPE branch returned above, so a plain `if` is enough here.
    xdrope_dim = model_config.uses_xdrope_dim
    if xdrope_dim > 0:
        assert isinstance(model, SupportsXDRoPE)
        return RopeState(num_dims=xdrope_dim, has_delta=False, **sizing)

    return None
|
|
||||||
|
|
||||||
@triton.jit
def _prepare_rope_positions_kernel(
    positions_ptr,
    positions_stride,
    prefill_positions_ptr,
    prefill_positions_stride0,
    prefill_positions_stride1,
    prefill_delta_ptr,
    idx_mapping_ptr,
    query_start_loc_ptr,
    prefill_lens_ptr,
    num_computed_tokens_ptr,
    BLOCK_SIZE: tl.constexpr,
    NUM_DIMS: tl.constexpr,
):
    # Writes the NUM_DIMS position rows for the query tokens of one batch
    # entry. Each program handles the batch entry given by program_id(0).
    #
    # positions_ptr / positions_stride: output buffer and the element
    #   offset between consecutive dimension rows.
    # prefill_positions_ptr: precomputed prefill positions, addressed as
    #   req_state_idx * stride0 + dim * stride1 + token position.
    # prefill_delta_ptr: per-request delta added to the token position
    #   once the request is past its prefill.
    # idx_mapping_ptr: maps the batch index to the request state index.
    # query_start_loc_ptr: entries [batch_idx] and [batch_idx + 1] bound
    #   this entry's token range in the output.
    # prefill_lens_ptr / num_computed_tokens_ptr: indexed by the request
    #   state index.
    # NOTE(review): the shapes and dtypes of the index tensors are not
    #   visible in this kernel -- confirm against the caller.
    batch_idx = tl.program_id(0)
    req_state_idx = tl.load(idx_mapping_ptr + batch_idx)

    prefill_len = tl.load(prefill_lens_ptr + req_state_idx)
    num_computed = tl.load(num_computed_tokens_ptr + req_state_idx)
    # The request is still in prefill while fewer tokens than its prefill
    # length have been computed.
    is_prefill = num_computed < prefill_len

    query_start = tl.load(query_start_loc_ptr + batch_idx)
    query_end = tl.load(query_start_loc_ptr + batch_idx + 1)
    query_len = query_end - query_start

    # Loaded unconditionally, but only used on the non-prefill path below.
    delta = tl.load(prefill_delta_ptr + req_state_idx)

    # Walk the query tokens in chunks of BLOCK_SIZE.
    for i in range(0, query_len, BLOCK_SIZE):
        block = i + tl.arange(0, BLOCK_SIZE)
        # Lanes past the end of the query are masked off.
        mask = block < query_len
        # Position of each token within its request.
        orig_pos = num_computed + block

        for j in tl.static_range(NUM_DIMS):
            if is_prefill:
                # Prefill: read the precomputed position for dimension j.
                pos = tl.load(
                    prefill_positions_ptr
                    + req_state_idx * prefill_positions_stride0
                    + j * prefill_positions_stride1
                    + orig_pos,
                    mask=mask,
                )
            else:
                # Past prefill: every dimension gets the same shifted value.
                pos = orig_pos + delta
            # The store uses the same mask as the load, so masked-off lanes
            # are never written.
            tl.store(
                positions_ptr + j * positions_stride + query_start + block,
                pos,
                mask=mask,
            )
| Original file line number | Diff line number | Diff line change | ||||||
|---|---|---|---|---|---|---|---|---|
|
|
@@ -992,6 +992,7 @@ def execute_model( | |||||||
| "input_ids": input_batch.input_ids, | ||||||||
| "positions": input_batch.positions, | ||||||||
| "inputs_embeds": inputs_embeds, | ||||||||
| "intermediate_tensors": intermediate_tensors, | ||||||||
| # NOTE: Values returned by `prepare_inputs` will override the default | ||||||||
| # values above. | ||||||||
| **self.model_state.prepare_inputs(input_batch, self.req_states), | ||||||||
|
|
@@ -1000,7 +1001,7 @@ def execute_model( | |||||||
| # Update for non-first PP ranks. | ||||||||
| model_inputs["input_ids"] = None | ||||||||
| model_inputs["inputs_embeds"] = None | ||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We could consider adding this for clarity:
Suggested change
|
||||||||
| model_inputs["intermediate_tensors"] = intermediate_tensors | ||||||||
| assert intermediate_tensors is not None | ||||||||
|
|
||||||||
| # Run model. | ||||||||
| if batch_desc.cg_mode == CUDAGraphMode.FULL: | ||||||||
|
|
||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess this should be renamed to e.g. `is_mrope` now?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since `get_rope_state()` already handles the M-RoPE vs XD-RoPE dispatch, `has_delta` feels more natural here — it describes the behavioral difference rather than leaking model-type identity into the unified RopeState (i.e., M-RoPE and XD-RoPE are both instances of RopeState). What do you think?