From 85f2b8b9a665228a01640a1261e383692682e9a0 Mon Sep 17 00:00:00 2001 From: Santino Ramos Date: Wed, 11 Mar 2026 18:57:40 +0000 Subject: [PATCH 01/12] wip Signed-off-by: Santino Ramos --- vllm/model_executor/models/hunyuan_vision.py | 4 +- vllm/v1/worker/gpu/mm/xdrope_utils.py | 128 +++++++++++++++++++ vllm/v1/worker/gpu/model_states/default.py | 62 ++++++--- 3 files changed, 177 insertions(+), 17 deletions(-) create mode 100644 vllm/v1/worker/gpu/mm/xdrope_utils.py diff --git a/vllm/model_executor/models/hunyuan_vision.py b/vllm/model_executor/models/hunyuan_vision.py index b6fda25ddfbb..861ff517682d 100644 --- a/vllm/model_executor/models/hunyuan_vision.py +++ b/vllm/model_executor/models/hunyuan_vision.py @@ -999,8 +999,8 @@ def forward( self, input_ids: torch.Tensor | None, positions: torch.Tensor, - intermediate_tensors: IntermediateTensors | None, - inputs_embeds: torch.Tensor | None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, **kwargs: object, ) -> torch.Tensor | IntermediateTensors: if intermediate_tensors is not None: diff --git a/vllm/v1/worker/gpu/mm/xdrope_utils.py b/vllm/v1/worker/gpu/mm/xdrope_utils.py new file mode 100644 index 000000000000..f753064ff5f3 --- /dev/null +++ b/vllm/v1/worker/gpu/mm/xdrope_utils.py @@ -0,0 +1,128 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import torch + +from vllm.model_executor.models.interfaces import SupportsXDRoPE +from vllm.triton_utils import tl, triton +from vllm.v1.worker.gpu.buffer_utils import StagedWriteTensor + + +class XDRopeState: + def __init__( + self, + uses_xdrope_dim: int, + max_num_reqs: int, + max_num_tokens: int, + max_model_len: int, + device: torch.device, + ): + self.uses_xdrope_dim = uses_xdrope_dim + self.max_num_reqs = max_num_reqs + self.max_num_tokens = max_num_tokens + self.max_model_len = max_model_len + self.device = device + + # NOTE(woosuk): This tensor can be extremely large (e.g., several GBs) + # wasting a lot of CPU memory. + self.prefill_xdrope_positions = StagedWriteTensor( + (max_num_reqs * uses_xdrope_dim, max_model_len), + dtype=torch.int32, + device=device, + uva_instead_of_gpu=True, + ) + self.xdrope_positions = torch.zeros( + (uses_xdrope_dim, max_num_tokens + 1), dtype=torch.int64, device=device + ) + + def init_prefill_xdrope_positions( + self, + req_idx: int, + xdrope_model: SupportsXDRoPE, + prefill_token_ids: list[int], + mm_features: list, + ) -> None: + prefill_xdrope_positions = xdrope_model.get_xdrope_input_positions( + prefill_token_ids, mm_features + ) + for i in range(self.uses_xdrope_dim): + pos = prefill_xdrope_positions[i].tolist() + self.prefill_xdrope_positions.stage_write( + self.uses_xdrope_dim * req_idx + i, 0, pos + ) + + def apply_staged_writes(self) -> None: + self.prefill_xdrope_positions.apply_write() + + def prepare_xdrope_positions( + self, + idx_mapping: torch.Tensor, + query_start_loc: torch.Tensor, + prefill_lens: torch.Tensor, + num_computed_tokens: torch.Tensor, + ) -> None: + num_reqs = idx_mapping.shape[0] + _prepare_xdrope_positions_kernel[(num_reqs,)]( + self.xdrope_positions, + self.xdrope_positions.stride(0), + self.prefill_xdrope_positions.gpu, + self.uses_xdrope_dim * self.max_model_len, + self.max_model_len, + idx_mapping, + query_start_loc, + prefill_lens, + num_computed_tokens, + BLOCK_SIZE=1024, + USES_XDROPE_DIM=self.uses_xdrope_dim, + ) + + +@triton.jit +def _prepare_xdrope_positions_kernel( + xdrope_positions_ptr, + xdrope_positions_stride, + prefill_xdrope_positions_ptr, + prefill_xdrope_positions_stride0, + prefill_xdrope_positions_stride1, + idx_mapping_ptr, + query_start_loc_ptr, + prefill_lens_ptr, + num_computed_tokens_ptr, + BLOCK_SIZE: tl.constexpr, + USES_XDROPE_DIM: tl.constexpr, +): + batch_idx = tl.program_id(0) + req_state_idx = tl.load(idx_mapping_ptr + batch_idx) + + prefill_len = tl.load(prefill_lens_ptr + req_state_idx) + num_computed = tl.load(num_computed_tokens_ptr + req_state_idx) + is_prefill = num_computed < prefill_len + + query_start = tl.load(query_start_loc_ptr + batch_idx) + query_end = tl.load(query_start_loc_ptr + batch_idx + 1) + query_len = query_end - query_start + + for i in range(0, query_len, BLOCK_SIZE): + block = i + tl.arange(0, BLOCK_SIZE) + mask = block < query_len + orig_pos = num_computed + block + + for j in tl.static_range(USES_XDROPE_DIM): + if is_prefill: + # Read from pre-computed XD-RoPE positions. + pos = tl.load( + prefill_xdrope_positions_ptr + + req_state_idx * prefill_xdrope_positions_stride0 + + j * prefill_xdrope_positions_stride1 + + orig_pos, + mask=mask, + ) + else: + pos = orig_pos + tl.store( + xdrope_positions_ptr + + j * xdrope_positions_stride + + query_start + + block, + pos, + mask=mask, + ) diff --git a/vllm/v1/worker/gpu/model_states/default.py b/vllm/v1/worker/gpu/model_states/default.py index e27916b40663..cbc789747249 100644 --- a/vllm/v1/worker/gpu/model_states/default.py +++ b/vllm/v1/worker/gpu/model_states/default.py @@ -13,6 +13,7 @@ from vllm.v1.worker.gpu.mm.encoder_cache import EncoderCache from vllm.v1.worker.gpu.mm.encoder_runner import EncoderRunner from vllm.v1.worker.gpu.mm.mrope_utils import MRopeState +from vllm.v1.worker.gpu.mm.xdrope_utils import XDRopeState from vllm.v1.worker.gpu.model_states.interface import ModelState from vllm.v1.worker.gpu.states import RequestState from vllm.v1.worker.utils import AttentionGroup @@ -59,6 +60,15 @@ def __init__( max_model_len=self.max_model_len, device=self.device, ) + self.uses_xdrope_dim = self.model_config.uses_xdrope_dim + if self.uses_xdrope_dim > 0: + self.xdrope_state = XDRopeState( + uses_xdrope_dim=self.uses_xdrope_dim, + max_num_reqs=self.max_num_reqs, + max_num_tokens=self.max_num_tokens, + max_model_len=self.max_model_len, + device=self.device, + ) def add_request(self, req_index: int, new_req_data: NewRequestData) -> None: if self.uses_mrope: @@ -70,10 +80,19 @@ def add_request(self, req_index: int, new_req_data: NewRequestData) -> None: new_req_data.prefill_token_ids, mm_features=new_req_data.mm_features, ) + elif self.uses_xdrope_dim > 0: + self.xdrope_state.init_prefill_xdrope_positions( + req_index, + self.model, + new_req_data.prefill_token_ids, + mm_features=new_req_data.mm_features, + ) def apply_staged_writes(self) -> None: if self.uses_mrope: self.mrope_state.apply_staged_writes() + elif self.uses_xdrope_dim > 0: + self.xdrope_state.apply_staged_writes() def get_mm_embeddings( self, @@ -106,21 +125,31 @@ def get_mm_embeddings( def prepare_inputs( self, input_batch: InputBatch, req_states: RequestState ) -> dict[str, torch.Tensor | None]: - if not self.uses_mrope: - # Common case (1D positions). - return {} - - # Prepare M-RoPE positions. - self.mrope_state.prepare_mrope_positions( - input_batch.idx_mapping, - input_batch.query_start_loc, - req_states.prefill_len.gpu, - req_states.num_computed_tokens.gpu, - ) - mrope_positions = self.mrope_state.mrope_positions[ - :, : input_batch.num_tokens_after_padding - ] - return {"positions": mrope_positions} + if self.uses_mrope: + # Prepare M-RoPE positions. + self.mrope_state.prepare_mrope_positions( + input_batch.idx_mapping, + input_batch.query_start_loc, + req_states.prefill_len.gpu, + req_states.num_computed_tokens.gpu, + ) + mrope_positions = self.mrope_state.mrope_positions[ + :, : input_batch.num_tokens_after_padding + ] + return {"positions": mrope_positions} + elif self.uses_xdrope_dim > 0: + # Prepare XD-RoPE positions. + self.xdrope_state.prepare_xdrope_positions( + input_batch.idx_mapping, + input_batch.query_start_loc, + req_states.prefill_len.gpu, + req_states.num_computed_tokens.gpu, + ) + xdrope_positions = self.xdrope_state.xdrope_positions[ + :, : input_batch.num_tokens_after_padding + ] + return {"positions": xdrope_positions} + return {} # Common case (1D positions). def prepare_dummy_inputs( self, num_reqs: int, num_tokens: int @@ -132,6 +161,9 @@ def prepare_dummy_inputs( if self.uses_mrope: mrope_positions = self.mrope_state.mrope_positions[:, :num_tokens] model_inputs["positions"] = mrope_positions + elif self.uses_xdrope_dim > 0: + xdrope_positions = self.xdrope_state.xdrope_positions[:, :num_tokens] + model_inputs["positions"] = xdrope_positions return model_inputs def prepare_attn( From 68e5caa5b8cdca5c3a1299aa4b34a8f54f6d775c Mon Sep 17 00:00:00 2001 From: Santino Ramos Date: Wed, 11 Mar 2026 20:06:51 +0000 Subject: [PATCH 02/12] add type ignore and comment Signed-off-by: Santino Ramos --- vllm/v1/worker/gpu/model_states/default.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/v1/worker/gpu/model_states/default.py b/vllm/v1/worker/gpu/model_states/default.py index 50717f3c130e..9bd75fba30d2 100644 --- a/vllm/v1/worker/gpu/model_states/default.py +++ b/vllm/v1/worker/gpu/model_states/default.py @@ -82,9 +82,11 @@ def add_request(self, req_index: int, new_req_data: NewRequestData) -> None: mm_features=new_req_data.mm_features, ) elif self.uses_xdrope_dim > 0: + # Pre-compute XD-RoPE positions for prefill. + assert new_req_data.prefill_token_ids is not None self.xdrope_state.init_prefill_xdrope_positions( req_index, - self.model, + self.model, # type: ignore new_req_data.prefill_token_ids, mm_features=new_req_data.mm_features, ) From 9db9d804a910e87cae2d8e52f5aae5ce13283085 Mon Sep 17 00:00:00 2001 From: Santino Ramos Date: Wed, 11 Mar 2026 20:23:27 +0000 Subject: [PATCH 03/12] undo changes to hunyuan model code Signed-off-by: Santino Ramos --- vllm/model_executor/models/hunyuan_vision.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/hunyuan_vision.py b/vllm/model_executor/models/hunyuan_vision.py index 861ff517682d..b6fda25ddfbb 100644 --- a/vllm/model_executor/models/hunyuan_vision.py +++ b/vllm/model_executor/models/hunyuan_vision.py @@ -999,8 +999,8 @@ def forward( self, input_ids: torch.Tensor | None, positions: torch.Tensor, - intermediate_tensors: IntermediateTensors | None = None, - inputs_embeds: torch.Tensor | None = None, + intermediate_tensors: IntermediateTensors | None, + inputs_embeds: torch.Tensor | None, **kwargs: object, ) -> torch.Tensor | IntermediateTensors: if intermediate_tensors is not None: From f4b6f6a92c4221ebb90aafde59485628455da8d7 Mon Sep 17 00:00:00 2001 From: Santino Ramos Date: Wed, 11 Mar 2026 20:45:38 +0000 Subject: [PATCH 04/12] move intermediate_tensors passing to mrv2 code Signed-off-by: Santino Ramos --- vllm/v1/worker/gpu/cudagraph_utils.py | 2 ++ vllm/v1/worker/gpu/model_runner.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/v1/worker/gpu/cudagraph_utils.py b/vllm/v1/worker/gpu/cudagraph_utils.py index 202470c7bac0..dc057f19d52e 100644 --- a/vllm/v1/worker/gpu/cudagraph_utils.py +++ b/vllm/v1/worker/gpu/cudagraph_utils.py @@ -320,6 +320,8 @@ def forward_fn(cg_mode: CUDAGraphMode) -> None: model_inputs = { "input_ids": input_buffers.input_ids[:num_tokens], "positions": input_buffers.positions[:num_tokens], + # NOTE(santinor): PP currently uses eager mode only. + "intermediate_tensors": None, **model_state.prepare_dummy_inputs(num_reqs, num_tokens), } model_output = model(**model_inputs) diff --git a/vllm/v1/worker/gpu/model_runner.py b/vllm/v1/worker/gpu/model_runner.py index c4fe833ff30e..e710e74b3205 100644 --- a/vllm/v1/worker/gpu/model_runner.py +++ b/vllm/v1/worker/gpu/model_runner.py @@ -952,6 +952,7 @@ def execute_model( "input_ids": input_batch.input_ids, "positions": input_batch.positions, "inputs_embeds": inputs_embeds, + "intermediate_tensors": intermediate_tensors, # NOTE: Values returned by `prepare_inputs` will override the default # values above. **self.model_state.prepare_inputs(input_batch, self.req_states), @@ -960,7 +961,6 @@ def execute_model( # Update for non-first PP ranks. model_inputs["input_ids"] = None model_inputs["inputs_embeds"] = None - model_inputs["intermediate_tensors"] = intermediate_tensors # Run model. if batch_desc.cg_mode == CUDAGraphMode.FULL: From 49c75170e9ab1fcebb7b7334082cdb8886f70803 Mon Sep 17 00:00:00 2001 From: Santino Ramos Date: Wed, 11 Mar 2026 21:29:12 +0000 Subject: [PATCH 05/12] early exit common case Signed-off-by: Santino Ramos --- vllm/v1/worker/gpu/model_states/default.py | 28 ++++++++++++---------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/vllm/v1/worker/gpu/model_states/default.py b/vllm/v1/worker/gpu/model_states/default.py index 9bd75fba30d2..8dd7a16c13ec 100644 --- a/vllm/v1/worker/gpu/model_states/default.py +++ b/vllm/v1/worker/gpu/model_states/default.py @@ -131,6 +131,9 @@ def get_mm_embeddings( def prepare_inputs( self, input_batch: InputBatch, req_states: RequestState ) -> dict[str, torch.Tensor | None]: + if not self.uses_mrope and not self.uses_xdrope_dim > 0: + return {} # Common case (1D positions). + if self.uses_mrope: # Prepare M-RoPE positions. self.mrope_state.prepare_mrope_positions( @@ -143,19 +146,18 @@ def prepare_inputs( :, : input_batch.num_tokens_after_padding ] return {"positions": mrope_positions} - elif self.uses_xdrope_dim > 0: - # Prepare XD-RoPE positions. - self.xdrope_state.prepare_xdrope_positions( - input_batch.idx_mapping, - input_batch.query_start_loc, - req_states.prefill_len.gpu, - req_states.num_computed_tokens.gpu, - ) - xdrope_positions = self.xdrope_state.xdrope_positions[ - :, : input_batch.num_tokens_after_padding - ] - return {"positions": xdrope_positions} - return {} # Common case (1D positions). + + # Prepare XD-RoPE positions. + self.xdrope_state.prepare_xdrope_positions( + input_batch.idx_mapping, + input_batch.query_start_loc, + req_states.prefill_len.gpu, + req_states.num_computed_tokens.gpu, + ) + xdrope_positions = self.xdrope_state.xdrope_positions[ + :, : input_batch.num_tokens_after_padding + ] + return {"positions": xdrope_positions} def prepare_dummy_inputs( self, num_reqs: int, num_tokens: int From 9d7fe65bc2822fb6f1ba844fa7fcdfdf6e980945 Mon Sep 17 00:00:00 2001 From: Santino Ramos Date: Wed, 11 Mar 2026 21:50:54 +0000 Subject: [PATCH 06/12] remove uses_xdrope_dim redundant bool Signed-off-by: Santino Ramos --- vllm/v1/worker/gpu/model_runner.py | 1 + vllm/v1/worker/gpu/model_states/default.py | 15 ++++++++------- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/vllm/v1/worker/gpu/model_runner.py b/vllm/v1/worker/gpu/model_runner.py index e710e74b3205..b95b8111f13e 100644 --- a/vllm/v1/worker/gpu/model_runner.py +++ b/vllm/v1/worker/gpu/model_runner.py @@ -961,6 +961,7 @@ def execute_model( # Update for non-first PP ranks. model_inputs["input_ids"] = None model_inputs["inputs_embeds"] = None + assert intermediate_tensors is not None # Run model. if batch_desc.cg_mode == CUDAGraphMode.FULL: diff --git a/vllm/v1/worker/gpu/model_states/default.py b/vllm/v1/worker/gpu/model_states/default.py index 8dd7a16c13ec..785c2c36c73e 100644 --- a/vllm/v1/worker/gpu/model_states/default.py +++ b/vllm/v1/worker/gpu/model_states/default.py @@ -61,15 +61,16 @@ def __init__( max_model_len=self.max_model_len, device=self.device, ) - self.uses_xdrope_dim = self.model_config.uses_xdrope_dim - if self.uses_xdrope_dim > 0: + if self.model_config.uses_xdrope_dim > 0: self.xdrope_state = XDRopeState( - uses_xdrope_dim=self.uses_xdrope_dim, + uses_xdrope_dim=self.model_config.uses_xdrope_dim, max_num_reqs=self.max_num_reqs, max_num_tokens=self.max_num_tokens, max_model_len=self.max_model_len, device=self.device, ) + else: + self.xdrope_state = None def add_request(self, req_index: int, new_req_data: NewRequestData) -> None: if self.uses_mrope: @@ -81,7 +82,7 @@ def add_request(self, req_index: int, new_req_data: NewRequestData) -> None: new_req_data.prefill_token_ids, mm_features=new_req_data.mm_features, ) - elif self.uses_xdrope_dim > 0: + elif self.xdrope_state is not None: # Pre-compute XD-RoPE positions for prefill. assert new_req_data.prefill_token_ids is not None self.xdrope_state.init_prefill_xdrope_positions( @@ -94,7 +95,7 @@ def add_request(self, req_index: int, new_req_data: NewRequestData) -> None: def apply_staged_writes(self) -> None: if self.uses_mrope: self.mrope_state.apply_staged_writes() - elif self.uses_xdrope_dim > 0: + elif self.xdrope_state is not None: self.xdrope_state.apply_staged_writes() def get_mm_embeddings( @@ -131,7 +132,7 @@ def get_mm_embeddings( def prepare_inputs( self, input_batch: InputBatch, req_states: RequestState ) -> dict[str, torch.Tensor | None]: - if not self.uses_mrope and not self.uses_xdrope_dim > 0: + if not self.uses_mrope and self.xdrope_state is None: return {} # Common case (1D positions). if self.uses_mrope: @@ -169,7 +170,7 @@ def prepare_dummy_inputs( if self.uses_mrope: mrope_positions = self.mrope_state.mrope_positions[:, :num_tokens] model_inputs["positions"] = mrope_positions - elif self.uses_xdrope_dim > 0: + elif self.xdrope_state is not None: xdrope_positions = self.xdrope_state.xdrope_positions[:, :num_tokens] model_inputs["positions"] = xdrope_positions return model_inputs From 7d790f91c9f740829ae88e473cbcb3aac065f559 Mon Sep 17 00:00:00 2001 From: Santino Ramos Date: Wed, 11 Mar 2026 22:27:14 +0000 Subject: [PATCH 07/12] fix mypy Signed-off-by: Santino Ramos --- vllm/v1/worker/gpu/model_states/default.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/v1/worker/gpu/model_states/default.py b/vllm/v1/worker/gpu/model_states/default.py index 8a43e4dc367c..3eca30c751bb 100644 --- a/vllm/v1/worker/gpu/model_states/default.py +++ b/vllm/v1/worker/gpu/model_states/default.py @@ -61,6 +61,7 @@ def __init__( max_model_len=self.max_model_len, device=self.device, ) + self.xdrope_state: XDRopeState | None = None if self.model_config.uses_xdrope_dim > 0: self.xdrope_state = XDRopeState( uses_xdrope_dim=self.model_config.uses_xdrope_dim, @@ -69,8 +70,6 @@ def __init__( max_model_len=self.max_model_len, device=self.device, ) - else: - self.xdrope_state = None def add_request(self, req_index: int, new_req_data: NewRequestData) -> None: if self.uses_mrope: @@ -149,6 +148,7 @@ def prepare_inputs( return {"positions": mrope_positions} # Prepare XD-RoPE positions. + assert self.xdrope_state is not None self.xdrope_state.prepare_xdrope_positions( input_batch.idx_mapping, input_batch.query_start_loc, From 19afa5faf1f40e39b3d9fbd6c60f336d969cb8df Mon Sep 17 00:00:00 2001 From: Santino Ramos Date: Thu, 12 Mar 2026 18:30:20 +0000 Subject: [PATCH 08/12] add rope interface and rope folder under mm Signed-off-by: Santino Ramos --- vllm/v1/worker/gpu/mm/rope/__init__.py | 2 + vllm/v1/worker/gpu/mm/rope/interface.py | 38 ++++++++++ .../gpu/mm/{mrope_utils.py => rope/mrope.py} | 16 +++- .../mm/{xdrope_utils.py => rope/xdrope.py} | 16 +++- vllm/v1/worker/gpu/model_states/default.py | 74 ++++++------------- 5 files changed, 86 insertions(+), 60 deletions(-) create mode 100644 vllm/v1/worker/gpu/mm/rope/__init__.py create mode 100644 vllm/v1/worker/gpu/mm/rope/interface.py rename vllm/v1/worker/gpu/mm/{mrope_utils.py => rope/mrope.py} (92%) rename vllm/v1/worker/gpu/mm/{xdrope_utils.py => rope/xdrope.py} (91%) diff --git a/vllm/v1/worker/gpu/mm/rope/__init__.py b/vllm/v1/worker/gpu/mm/rope/__init__.py new file mode 100644 index 000000000000..208f01a7cb5e --- /dev/null +++ b/vllm/v1/worker/gpu/mm/rope/__init__.py @@ -0,0 +1,2 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project diff --git a/vllm/v1/worker/gpu/mm/rope/interface.py b/vllm/v1/worker/gpu/mm/rope/interface.py new file mode 100644 index 000000000000..cd01fdf9e122 --- /dev/null +++ b/vllm/v1/worker/gpu/mm/rope/interface.py @@ -0,0 +1,38 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from abc import ABC, abstractmethod + +import torch +import torch.nn as nn + + +class RopeState(ABC): + """Shared interface for multi-dimensional RoPE variants (M-RoPE, XD-RoPE). + + Implementations pre-compute positions during prefill and prepare + per-step position tensors for the model forward pass. + """ + + @abstractmethod + def init_prefill_positions( + self, + req_idx: int, + model: nn.Module, + prefill_token_ids: list[int], + mm_features: list, + ) -> None: ... + + @abstractmethod + def apply_staged_writes(self) -> None: ... + + @abstractmethod + def prepare_positions( + self, + idx_mapping: torch.Tensor, + query_start_loc: torch.Tensor, + prefill_lens: torch.Tensor, + num_computed_tokens: torch.Tensor, + ) -> None: ... + + @abstractmethod + def get_positions(self, num_tokens: int) -> torch.Tensor: ... diff --git a/vllm/v1/worker/gpu/mm/mrope_utils.py b/vllm/v1/worker/gpu/mm/rope/mrope.py similarity index 92% rename from vllm/v1/worker/gpu/mm/mrope_utils.py rename to vllm/v1/worker/gpu/mm/rope/mrope.py index 7e27f28bab93..303947b109f9 100644 --- a/vllm/v1/worker/gpu/mm/mrope_utils.py +++ b/vllm/v1/worker/gpu/mm/rope/mrope.py @@ -1,13 +1,17 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import cast + import torch +import torch.nn as nn from vllm.model_executor.models.interfaces import SupportsMRoPE from vllm.triton_utils import tl, triton from vllm.v1.worker.gpu.buffer_utils import StagedWriteTensor, UvaBackedTensor +from vllm.v1.worker.gpu.mm.rope.interface import RopeState -class MRopeState: +class MRopeState(RopeState): def __init__( self, max_num_reqs: int, @@ -43,13 +47,14 @@ def __init__( (3, max_num_tokens + 1), dtype=torch.int64, device=device ) - def init_prefill_mrope_positions( + def init_prefill_positions( self, req_idx: int, - mrope_model: SupportsMRoPE, + model: nn.Module, prefill_token_ids: list[int], mm_features: list, ) -> None: + mrope_model = cast(SupportsMRoPE, model) prefill_mrope_positions, prefill_mrope_delta = ( mrope_model.get_mrope_input_positions(prefill_token_ids, mm_features) ) @@ -62,7 +67,10 @@ def apply_staged_writes(self) -> None: self.prefill_mrope_positions.apply_write() self.prefill_mrope_delta.copy_to_uva() - def prepare_mrope_positions( + def get_positions(self, num_tokens: int) -> torch.Tensor: + return self.mrope_positions[:, :num_tokens] + + def prepare_positions( self, idx_mapping: torch.Tensor, query_start_loc: torch.Tensor, diff --git a/vllm/v1/worker/gpu/mm/xdrope_utils.py b/vllm/v1/worker/gpu/mm/rope/xdrope.py similarity index 91% rename from vllm/v1/worker/gpu/mm/xdrope_utils.py rename to vllm/v1/worker/gpu/mm/rope/xdrope.py index f753064ff5f3..2d38a6846d3f 100644 --- a/vllm/v1/worker/gpu/mm/xdrope_utils.py +++ b/vllm/v1/worker/gpu/mm/rope/xdrope.py @@ -1,13 +1,17 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import cast + import torch +import torch.nn as nn from vllm.model_executor.models.interfaces import SupportsXDRoPE from vllm.triton_utils import tl, triton from vllm.v1.worker.gpu.buffer_utils import StagedWriteTensor +from vllm.v1.worker.gpu.mm.rope.interface import RopeState -class XDRopeState: +class XDRopeState(RopeState): def __init__( self, uses_xdrope_dim: int, @@ -34,13 +38,14 @@ def __init__( (uses_xdrope_dim, max_num_tokens + 1), dtype=torch.int64, device=device ) - def init_prefill_xdrope_positions( + def init_prefill_positions( self, req_idx: int, - xdrope_model: SupportsXDRoPE, + model: nn.Module, prefill_token_ids: list[int], mm_features: list, ) -> None: + xdrope_model = cast(SupportsXDRoPE, model) prefill_xdrope_positions = xdrope_model.get_xdrope_input_positions( prefill_token_ids, mm_features ) @@ -53,7 +58,10 @@ def init_prefill_xdrope_positions( def apply_staged_writes(self) -> None: self.prefill_xdrope_positions.apply_write() - def prepare_xdrope_positions( + def get_positions(self, num_tokens: int) -> torch.Tensor: + return self.xdrope_positions[:, :num_tokens] + + def prepare_positions( self, idx_mapping: torch.Tensor, query_start_loc: torch.Tensor, diff --git a/vllm/v1/worker/gpu/model_states/default.py b/vllm/v1/worker/gpu/model_states/default.py index 3eca30c751bb..eda54cba82e0 100644 --- a/vllm/v1/worker/gpu/model_states/default.py +++ b/vllm/v1/worker/gpu/model_states/default.py @@ -13,8 +13,9 @@ from vllm.v1.worker.gpu.input_batch import InputBatch from vllm.v1.worker.gpu.mm.encoder_cache import EncoderCache from vllm.v1.worker.gpu.mm.encoder_runner import EncoderRunner -from vllm.v1.worker.gpu.mm.mrope_utils import MRopeState -from vllm.v1.worker.gpu.mm.xdrope_utils import XDRopeState +from vllm.v1.worker.gpu.mm.rope.interface import RopeState +from vllm.v1.worker.gpu.mm.rope.mrope import MRopeState +from vllm.v1.worker.gpu.mm.rope.xdrope import XDRopeState from vllm.v1.worker.gpu.model_states.interface import ModelState from vllm.v1.worker.gpu.states import RequestState from vllm.v1.worker.utils import AttentionGroup @@ -53,17 +54,16 @@ def __init__( device=self.device, ) - self.uses_mrope = self.model_config.uses_mrope - if self.uses_mrope: - self.mrope_state = MRopeState( + self.rope_state: RopeState | None = None + if self.model_config.uses_mrope: + self.rope_state = MRopeState( max_num_reqs=self.max_num_reqs, max_num_tokens=self.max_num_tokens, max_model_len=self.max_model_len, device=self.device, ) - self.xdrope_state: XDRopeState | None = None - if self.model_config.uses_xdrope_dim > 0: - self.xdrope_state = XDRopeState( + elif self.model_config.uses_xdrope_dim > 0: + self.rope_state = XDRopeState( uses_xdrope_dim=self.model_config.uses_xdrope_dim, max_num_reqs=self.max_num_reqs, max_num_tokens=self.max_num_tokens, @@ -72,30 +72,18 @@ def __init__( ) def add_request(self, req_index: int, new_req_data: NewRequestData) -> None: - if self.uses_mrope: - # Pre-compute M-RoPE positions for prefill. + if self.rope_state is not None: assert new_req_data.prefill_token_ids is not None - self.mrope_state.init_prefill_mrope_positions( + self.rope_state.init_prefill_positions( req_index, - self.model, # type: ignore - new_req_data.prefill_token_ids, - mm_features=new_req_data.mm_features, - ) - elif self.xdrope_state is not None: - # Pre-compute XD-RoPE positions for prefill. - assert new_req_data.prefill_token_ids is not None - self.xdrope_state.init_prefill_xdrope_positions( - req_index, - self.model, # type: ignore + self.model, new_req_data.prefill_token_ids, mm_features=new_req_data.mm_features, ) def apply_staged_writes(self) -> None: - if self.uses_mrope: - self.mrope_state.apply_staged_writes() - elif self.xdrope_state is not None: - self.xdrope_state.apply_staged_writes() + if self.rope_state is not None: + self.rope_state.apply_staged_writes() def get_mm_embeddings( self, @@ -131,46 +119,28 @@ def get_mm_embeddings( def prepare_inputs( self, input_batch: InputBatch, req_states: RequestState ) -> dict[str, torch.Tensor | None]: - if not self.uses_mrope and self.xdrope_state is None: + if self.rope_state is None: return {} # Common case (1D positions). - if self.uses_mrope: - # Prepare M-RoPE positions. - self.mrope_state.prepare_mrope_positions( - input_batch.idx_mapping, - input_batch.query_start_loc, - req_states.prefill_len.gpu, - req_states.num_computed_tokens.gpu, - ) - mrope_positions = self.mrope_state.mrope_positions[ - :, : input_batch.num_tokens_after_padding - ] - return {"positions": mrope_positions} - - # Prepare XD-RoPE positions. - assert self.xdrope_state is not None - self.xdrope_state.prepare_xdrope_positions( + self.rope_state.prepare_positions( input_batch.idx_mapping, input_batch.query_start_loc, req_states.prefill_len.gpu, req_states.num_computed_tokens.gpu, ) - xdrope_positions = self.xdrope_state.xdrope_positions[ - :, : input_batch.num_tokens_after_padding - ] - return {"positions": xdrope_positions} + return { + "positions": self.rope_state.get_positions( + input_batch.num_tokens_after_padding + ) + } def prepare_dummy_inputs(self, num_reqs: int, num_tokens: int) -> dict[str, Any]: model_inputs = {} if self.supports_mm_inputs: inputs_embeds = self.encoder_runner.inputs_embeds[:num_tokens] model_inputs["inputs_embeds"] = inputs_embeds - if self.uses_mrope: - mrope_positions = self.mrope_state.mrope_positions[:, :num_tokens] - model_inputs["positions"] = mrope_positions - elif self.xdrope_state is not None: - xdrope_positions = self.xdrope_state.xdrope_positions[:, :num_tokens] - model_inputs["positions"] = xdrope_positions + if self.rope_state is not None: + model_inputs["positions"] = self.rope_state.get_positions(num_tokens) return model_inputs def prepare_attn( From 4f3b6fbeedcfbe30cee611ba49e5e9156356d45a Mon Sep 17 00:00:00 2001 From: Santino Ramos Date: Thu, 12 Mar 2026 20:59:28 +0000 Subject: [PATCH 09/12] incorporate pr feedback Signed-off-by: Santino Ramos --- vllm/v1/worker/gpu/cudagraph_utils.py | 3 ++- vllm/v1/worker/gpu/mm/rope/interface.py | 12 ++++++++++++ vllm/v1/worker/gpu/mm/rope/mrope.py | 5 +---- vllm/v1/worker/gpu/mm/rope/xdrope.py | 5 +---- vllm/v1/worker/gpu/model_states/default.py | 7 ++----- 5 files changed, 18 insertions(+), 14 deletions(-) diff --git a/vllm/v1/worker/gpu/cudagraph_utils.py b/vllm/v1/worker/gpu/cudagraph_utils.py index bbcf5c247bb5..2b94362a808f 100644 --- a/vllm/v1/worker/gpu/cudagraph_utils.py +++ b/vllm/v1/worker/gpu/cudagraph_utils.py @@ -320,7 +320,8 @@ def forward_fn(cg_mode: CUDAGraphMode) -> None: model_inputs = { "input_ids": input_buffers.input_ids[:num_tokens], "positions": input_buffers.positions[:num_tokens], - # NOTE(santinor): PP currently uses eager mode only. + # TODO: Pass intermediate_tensors for PP CUDA graph + # support (https://github.com/vllm-project/vllm/pull/35162). "intermediate_tensors": None, **model_state.prepare_dummy_inputs(num_reqs, num_tokens), } diff --git a/vllm/v1/worker/gpu/mm/rope/interface.py b/vllm/v1/worker/gpu/mm/rope/interface.py index cd01fdf9e122..4fd12dce0c80 100644 --- a/vllm/v1/worker/gpu/mm/rope/interface.py +++ b/vllm/v1/worker/gpu/mm/rope/interface.py @@ -13,6 +13,18 @@ class RopeState(ABC): per-step position tensors for the model forward pass. """ + def __init__( + self, + max_num_reqs: int, + max_num_tokens: int, + max_model_len: int, + device: torch.device, + ): + self.max_num_reqs = max_num_reqs + self.max_num_tokens = max_num_tokens + self.max_model_len = max_model_len + self.device = device + @abstractmethod def init_prefill_positions( self, diff --git a/vllm/v1/worker/gpu/mm/rope/mrope.py b/vllm/v1/worker/gpu/mm/rope/mrope.py index 303947b109f9..1b7f9f8cfc92 100644 --- a/vllm/v1/worker/gpu/mm/rope/mrope.py +++ b/vllm/v1/worker/gpu/mm/rope/mrope.py @@ -19,10 +19,7 @@ def __init__( max_model_len: int, device: torch.device, ): - self.max_num_reqs = max_num_reqs - self.max_num_tokens = max_num_tokens - self.max_model_len = max_model_len - self.device = device + super().__init__(max_num_reqs, max_num_tokens, max_model_len, device) # NOTE(woosuk): This tensor can be extremely large (e.g., several GBs) # wasting a lot of CPU memory. diff --git a/vllm/v1/worker/gpu/mm/rope/xdrope.py b/vllm/v1/worker/gpu/mm/rope/xdrope.py index 2d38a6846d3f..a9d0b340741e 100644 --- a/vllm/v1/worker/gpu/mm/rope/xdrope.py +++ b/vllm/v1/worker/gpu/mm/rope/xdrope.py @@ -20,11 +20,8 @@ def __init__( max_model_len: int, device: torch.device, ): + super().__init__(max_num_reqs, max_num_tokens, max_model_len, device) self.uses_xdrope_dim = uses_xdrope_dim - self.max_num_reqs = max_num_reqs - self.max_num_tokens = max_num_tokens - self.max_model_len = max_model_len - self.device = device # NOTE(woosuk): This tensor can be extremely large (e.g., several GBs) # wasting a lot of CPU memory. diff --git a/vllm/v1/worker/gpu/model_states/default.py b/vllm/v1/worker/gpu/model_states/default.py index eda54cba82e0..3693a091eea1 100644 --- a/vllm/v1/worker/gpu/model_states/default.py +++ b/vllm/v1/worker/gpu/model_states/default.py @@ -128,11 +128,8 @@ def prepare_inputs( req_states.prefill_len.gpu, req_states.num_computed_tokens.gpu, ) - return { - "positions": self.rope_state.get_positions( - input_batch.num_tokens_after_padding - ) - } + positions = self.rope_state.get_positions(input_batch.num_tokens_after_padding) + return {"positions": positions} def prepare_dummy_inputs(self, num_reqs: int, num_tokens: int) -> dict[str, Any]: model_inputs = {} From 728f40c6d9f410c3cf550b9645fb688d90772fac Mon Sep 17 00:00:00 2001 From: Santino Ramos Date: Fri, 13 Mar 2026 19:29:30 +0000 Subject: [PATCH 10/12] unified triton kernel impl Signed-off-by: Santino Ramos --- vllm/v1/worker/gpu/mm/rope.py | 167 +++++++++++++++++++++ vllm/v1/worker/gpu/mm/rope/__init__.py | 2 - vllm/v1/worker/gpu/mm/rope/interface.py | 50 ------ vllm/v1/worker/gpu/mm/rope/mrope.py | 141 ----------------- vllm/v1/worker/gpu/mm/rope/xdrope.py | 133 ---------------- vllm/v1/worker/gpu/model_states/default.py | 13 +- 6 files changed, 174 insertions(+), 332 deletions(-) create mode 100644 vllm/v1/worker/gpu/mm/rope.py delete mode 100644 vllm/v1/worker/gpu/mm/rope/__init__.py delete mode 100644 vllm/v1/worker/gpu/mm/rope/interface.py delete mode 100644 vllm/v1/worker/gpu/mm/rope/mrope.py delete mode 100644 vllm/v1/worker/gpu/mm/rope/xdrope.py diff --git a/vllm/v1/worker/gpu/mm/rope.py b/vllm/v1/worker/gpu/mm/rope.py new file mode 100644 index 000000000000..5e87d10f65bf --- /dev/null +++ b/vllm/v1/worker/gpu/mm/rope.py @@ -0,0 +1,167 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import cast + +import torch +import torch.nn as nn + +from vllm.model_executor.models.interfaces import SupportsMRoPE, SupportsXDRoPE +from vllm.triton_utils import tl, triton +from vllm.v1.worker.gpu.buffer_utils import StagedWriteTensor, UvaBackedTensor + + +class RopeState: + """Unified state for multi-dimensional RoPE variants (M-RoPE, XD-RoPE). + + M-RoPE: 3 dims, uses position delta for decode. + XD-RoPE: 3 or 4 dims, no delta (decode uses orig_pos for all dims). + + NOTE: `positions` is implemented with one additional dummy position on + purpose to make it non-contiguous so that it can work with torch compile. + See detailed explanation in + https://github.com/vllm-project/vllm/pull/12128#discussion_r1926431923 + + NOTE: When M-RoPE is enabled, position ids are 3D regardless of the + modality of inputs. For text-only inputs, each dimension has identical + position IDs, making M-RoPE functionally equivalent to 1D-RoPE. + See page 5 of https://arxiv.org/abs/2409.12191 + """ + + def __init__( + self, + num_dims: int, + has_delta: bool, + max_num_reqs: int, + max_num_tokens: int, + max_model_len: int, + device: torch.device, + ): + self.num_dims = num_dims + self.has_delta = has_delta + self.max_num_reqs = max_num_reqs + self.max_num_tokens = max_num_tokens + self.max_model_len = max_model_len + self.device = device + + # NOTE(woosuk): This tensor can be extremely large (e.g., several GBs) + # wasting a lot of CPU memory. + self.prefill_positions = StagedWriteTensor( + (max_num_reqs * num_dims, max_model_len), + dtype=torch.int32, + device=device, + uva_instead_of_gpu=True, + ) + self.positions = torch.zeros( + (num_dims, max_num_tokens + 1), dtype=torch.int64, device=device + ) + + if has_delta: + self.prefill_delta = UvaBackedTensor(max_num_reqs, dtype=torch.int32) + + def init_prefill_positions( + self, + req_idx: int, + model: nn.Module, + prefill_token_ids: list[int], + mm_features: list, + ) -> None: + if self.has_delta: + mrope_model = cast(SupportsMRoPE, model) + prefill_positions, delta = mrope_model.get_mrope_input_positions( + prefill_token_ids, mm_features + ) + self.prefill_delta.np[req_idx] = delta + else: + xdrope_model = cast(SupportsXDRoPE, model) + prefill_positions = xdrope_model.get_xdrope_input_positions( + prefill_token_ids, mm_features + ) + + for i in range(self.num_dims): + pos = prefill_positions[i].tolist() + self.prefill_positions.stage_write(self.num_dims * req_idx + i, 0, pos) + + def apply_staged_writes(self) -> None: + self.prefill_positions.apply_write() + if self.has_delta: + self.prefill_delta.copy_to_uva() + + def get_positions(self, num_tokens: int) -> torch.Tensor: + return self.positions[:, :num_tokens] + + def prepare_positions( + self, + idx_mapping: torch.Tensor, + query_start_loc: torch.Tensor, + prefill_lens: torch.Tensor, + num_computed_tokens: torch.Tensor, + ) -> None: + num_reqs = idx_mapping.shape[0] + _prepare_rope_positions_kernel[(num_reqs,)]( + self.positions, + self.positions.stride(0), + self.prefill_positions.gpu, + self.num_dims * self.max_model_len, + self.max_model_len, + self.prefill_delta.gpu if self.has_delta else idx_mapping, + idx_mapping, + query_start_loc, + prefill_lens, + num_computed_tokens, + BLOCK_SIZE=1024, + NUM_DIMS=self.num_dims, + HAS_DELTA=self.has_delta, + ) + + +@triton.jit +def _prepare_rope_positions_kernel( + positions_ptr, + positions_stride, + prefill_positions_ptr, + prefill_positions_stride0, + prefill_positions_stride1, + prefill_delta_ptr, + idx_mapping_ptr, + query_start_loc_ptr, + prefill_lens_ptr, + num_computed_tokens_ptr, + BLOCK_SIZE: tl.constexpr, + NUM_DIMS: tl.constexpr, + HAS_DELTA: tl.constexpr, +): + batch_idx = tl.program_id(0) + req_state_idx = tl.load(idx_mapping_ptr + batch_idx) + + prefill_len = tl.load(prefill_lens_ptr + req_state_idx) + num_computed = tl.load(num_computed_tokens_ptr + req_state_idx) + is_prefill = num_computed < prefill_len + + query_start = tl.load(query_start_loc_ptr + batch_idx) + query_end = tl.load(query_start_loc_ptr + batch_idx + 1) + query_len = query_end - query_start + + if HAS_DELTA: + delta = tl.load(prefill_delta_ptr + req_state_idx) + + for i in range(0, query_len, BLOCK_SIZE): + block = i + tl.arange(0, BLOCK_SIZE) + mask = block < query_len + orig_pos = num_computed + block + + for j in tl.static_range(NUM_DIMS): + if is_prefill: + pos = tl.load( + prefill_positions_ptr + + req_state_idx * prefill_positions_stride0 + + j * prefill_positions_stride1 + + orig_pos, + mask=mask, + ) + else: + pos = orig_pos + delta if HAS_DELTA else orig_pos + tl.store( + positions_ptr + j * positions_stride + query_start + block, + pos, + mask=mask, + ) diff --git a/vllm/v1/worker/gpu/mm/rope/__init__.py b/vllm/v1/worker/gpu/mm/rope/__init__.py deleted file mode 100644 index 208f01a7cb5e..000000000000 --- a/vllm/v1/worker/gpu/mm/rope/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project diff --git a/vllm/v1/worker/gpu/mm/rope/interface.py b/vllm/v1/worker/gpu/mm/rope/interface.py deleted file mode 100644 index 4fd12dce0c80..000000000000 --- a/vllm/v1/worker/gpu/mm/rope/interface.py +++ /dev/null @@ -1,50 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from abc import ABC, abstractmethod - -import torch -import torch.nn as nn - - -class RopeState(ABC): - """Shared interface for multi-dimensional RoPE variants (M-RoPE, XD-RoPE). - - Implementations pre-compute positions during prefill and prepare - per-step position tensors for the model forward pass. - """ - - def __init__( - self, - max_num_reqs: int, - max_num_tokens: int, - max_model_len: int, - device: torch.device, - ): - self.max_num_reqs = max_num_reqs - self.max_num_tokens = max_num_tokens - self.max_model_len = max_model_len - self.device = device - - @abstractmethod - def init_prefill_positions( - self, - req_idx: int, - model: nn.Module, - prefill_token_ids: list[int], - mm_features: list, - ) -> None: ... - - @abstractmethod - def apply_staged_writes(self) -> None: ... - - @abstractmethod - def prepare_positions( - self, - idx_mapping: torch.Tensor, - query_start_loc: torch.Tensor, - prefill_lens: torch.Tensor, - num_computed_tokens: torch.Tensor, - ) -> None: ... - - @abstractmethod - def get_positions(self, num_tokens: int) -> torch.Tensor: ... diff --git a/vllm/v1/worker/gpu/mm/rope/mrope.py b/vllm/v1/worker/gpu/mm/rope/mrope.py deleted file mode 100644 index 1b7f9f8cfc92..000000000000 --- a/vllm/v1/worker/gpu/mm/rope/mrope.py +++ /dev/null @@ -1,141 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import cast - -import torch -import torch.nn as nn - -from vllm.model_executor.models.interfaces import SupportsMRoPE -from vllm.triton_utils import tl, triton -from vllm.v1.worker.gpu.buffer_utils import StagedWriteTensor, UvaBackedTensor -from vllm.v1.worker.gpu.mm.rope.interface import RopeState - - -class MRopeState(RopeState): - def __init__( - self, - max_num_reqs: int, - max_num_tokens: int, - max_model_len: int, - device: torch.device, - ): - super().__init__(max_num_reqs, max_num_tokens, max_model_len, device) - - # NOTE(woosuk): This tensor can be extremely large (e.g., several GBs) - # wasting a lot of CPU memory. - self.prefill_mrope_positions = StagedWriteTensor( - (max_num_reqs * 3, max_model_len), - dtype=torch.int32, - device=device, - uva_instead_of_gpu=True, - ) - self.prefill_mrope_delta = UvaBackedTensor(max_num_reqs, dtype=torch.int32) - - # NOTE: `mrope_positions` is implemented with one additional dummy - # position on purpose to make it non-contiguous so that it can work - # with torch compile. - # See detailed explanation in https://github.com/vllm-project/vllm/pull/12128#discussion_r1926431923 - # NOTE: When M-RoPE is enabled, position ids are 3D regardless of - # the modality of inputs. For text-only inputs, each dimension has - # identical position IDs, making M-RoPE functionally equivalent to - # 1D-RoPE. - # See page 5 of https://arxiv.org/abs/2409.12191 - self.mrope_positions = torch.zeros( - (3, max_num_tokens + 1), dtype=torch.int64, device=device - ) - - def init_prefill_positions( - self, - req_idx: int, - model: nn.Module, - prefill_token_ids: list[int], - mm_features: list, - ) -> None: - mrope_model = cast(SupportsMRoPE, model) - prefill_mrope_positions, prefill_mrope_delta = ( - mrope_model.get_mrope_input_positions(prefill_token_ids, mm_features) - ) - for i in range(3): - pos = prefill_mrope_positions[i].tolist() - self.prefill_mrope_positions.stage_write(3 * req_idx + i, 0, pos) - self.prefill_mrope_delta.np[req_idx] = prefill_mrope_delta - - def apply_staged_writes(self) -> None: - self.prefill_mrope_positions.apply_write() - self.prefill_mrope_delta.copy_to_uva() - - def get_positions(self, num_tokens: int) -> torch.Tensor: - return self.mrope_positions[:, :num_tokens] - - def prepare_positions( - self, - idx_mapping: torch.Tensor, - query_start_loc: torch.Tensor, - prefill_lens: torch.Tensor, - num_computed_tokens: torch.Tensor, - ) -> None: - num_reqs = idx_mapping.shape[0] - _prepare_mrope_positions_kernel[(num_reqs,)]( - self.mrope_positions, - self.mrope_positions.stride(0), - self.prefill_mrope_positions.gpu, - 3 * self.max_model_len, - self.max_model_len, - self.prefill_mrope_delta.gpu, - idx_mapping, - query_start_loc, - prefill_lens, - num_computed_tokens, - BLOCK_SIZE=1024, - ) - - -@triton.jit -def _prepare_mrope_positions_kernel( - mrope_positions_ptr, - mrope_positions_stride, - prefill_mrope_positions_ptr, - prefill_mrope_positions_stride0, - prefill_mrope_positions_stride1, - prefill_mrope_delta_ptr, - idx_mapping_ptr, - query_start_loc_ptr, - prefill_lens_ptr, - num_computed_tokens_ptr, - BLOCK_SIZE: tl.constexpr, -): - batch_idx = tl.program_id(0) - req_state_idx = tl.load(idx_mapping_ptr + batch_idx) - - prefill_len = tl.load(prefill_lens_ptr + req_state_idx) - num_computed = tl.load(num_computed_tokens_ptr + req_state_idx) - is_prefill = num_computed < prefill_len - - query_start = tl.load(query_start_loc_ptr + batch_idx) - query_end = tl.load(query_start_loc_ptr + batch_idx + 1) - query_len = query_end - query_start - - mrope_delta = tl.load(prefill_mrope_delta_ptr + req_state_idx) - for i in range(0, query_len, BLOCK_SIZE): - block = i + tl.arange(0, BLOCK_SIZE) - mask = block < query_len - orig_pos = num_computed + block - - for j in tl.static_range(3): - if is_prefill: - # Read from pre-computed M-RoPE positions. - pos = tl.load( - prefill_mrope_positions_ptr - + req_state_idx * prefill_mrope_positions_stride0 - + j * prefill_mrope_positions_stride1 - + orig_pos, - mask=mask, - ) - else: - # Apply M-RoPE delta. - pos = orig_pos + mrope_delta - tl.store( - mrope_positions_ptr + j * mrope_positions_stride + query_start + block, - pos, - mask=mask, - ) diff --git a/vllm/v1/worker/gpu/mm/rope/xdrope.py b/vllm/v1/worker/gpu/mm/rope/xdrope.py deleted file mode 100644 index a9d0b340741e..000000000000 --- a/vllm/v1/worker/gpu/mm/rope/xdrope.py +++ /dev/null @@ -1,133 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import cast - -import torch -import torch.nn as nn - -from vllm.model_executor.models.interfaces import SupportsXDRoPE -from vllm.triton_utils import tl, triton -from vllm.v1.worker.gpu.buffer_utils import StagedWriteTensor -from vllm.v1.worker.gpu.mm.rope.interface import RopeState - - -class XDRopeState(RopeState): - def __init__( - self, - uses_xdrope_dim: int, - max_num_reqs: int, - max_num_tokens: int, - max_model_len: int, - device: torch.device, - ): - super().__init__(max_num_reqs, max_num_tokens, max_model_len, device) - self.uses_xdrope_dim = uses_xdrope_dim - - # NOTE(woosuk): This tensor can be extremely large (e.g., several GBs) - # wasting a lot of CPU memory. - self.prefill_xdrope_positions = StagedWriteTensor( - (max_num_reqs * uses_xdrope_dim, max_model_len), - dtype=torch.int32, - device=device, - uva_instead_of_gpu=True, - ) - self.xdrope_positions = torch.zeros( - (uses_xdrope_dim, max_num_tokens + 1), dtype=torch.int64, device=device - ) - - def init_prefill_positions( - self, - req_idx: int, - model: nn.Module, - prefill_token_ids: list[int], - mm_features: list, - ) -> None: - xdrope_model = cast(SupportsXDRoPE, model) - prefill_xdrope_positions = xdrope_model.get_xdrope_input_positions( - prefill_token_ids, mm_features - ) - for i in range(self.uses_xdrope_dim): - pos = prefill_xdrope_positions[i].tolist() - self.prefill_xdrope_positions.stage_write( - self.uses_xdrope_dim * req_idx + i, 0, pos - ) - - def apply_staged_writes(self) -> None: - self.prefill_xdrope_positions.apply_write() - - def get_positions(self, num_tokens: int) -> torch.Tensor: - return self.xdrope_positions[:, :num_tokens] - - def prepare_positions( - self, - idx_mapping: torch.Tensor, - query_start_loc: torch.Tensor, - prefill_lens: torch.Tensor, - num_computed_tokens: torch.Tensor, - ) -> None: - num_reqs = idx_mapping.shape[0] - _prepare_xdrope_positions_kernel[(num_reqs,)]( - self.xdrope_positions, - self.xdrope_positions.stride(0), - self.prefill_xdrope_positions.gpu, - self.uses_xdrope_dim * self.max_model_len, - self.max_model_len, - idx_mapping, - query_start_loc, - prefill_lens, - num_computed_tokens, - BLOCK_SIZE=1024, - USES_XDROPE_DIM=self.uses_xdrope_dim, - ) - - -@triton.jit -def _prepare_xdrope_positions_kernel( - xdrope_positions_ptr, - xdrope_positions_stride, - prefill_xdrope_positions_ptr, - prefill_xdrope_positions_stride0, - prefill_xdrope_positions_stride1, - idx_mapping_ptr, - query_start_loc_ptr, - prefill_lens_ptr, - num_computed_tokens_ptr, - BLOCK_SIZE: tl.constexpr, - USES_XDROPE_DIM: tl.constexpr, -): - batch_idx = tl.program_id(0) - req_state_idx = tl.load(idx_mapping_ptr + batch_idx) - - prefill_len = tl.load(prefill_lens_ptr + req_state_idx) - num_computed = tl.load(num_computed_tokens_ptr + req_state_idx) - is_prefill = num_computed < prefill_len - - query_start = tl.load(query_start_loc_ptr + batch_idx) - query_end = tl.load(query_start_loc_ptr + batch_idx + 1) - query_len = query_end - query_start - - for i in range(0, query_len, BLOCK_SIZE): - block = i + tl.arange(0, BLOCK_SIZE) - mask = block < query_len - orig_pos = num_computed + block - - for j in tl.static_range(USES_XDROPE_DIM): - if is_prefill: - # Read from pre-computed XD-RoPE positions. - pos = tl.load( - prefill_xdrope_positions_ptr - + req_state_idx * prefill_xdrope_positions_stride0 - + j * prefill_xdrope_positions_stride1 - + orig_pos, - mask=mask, - ) - else: - pos = orig_pos - tl.store( - xdrope_positions_ptr - + j * xdrope_positions_stride - + query_start - + block, - pos, - mask=mask, - ) diff --git a/vllm/v1/worker/gpu/model_states/default.py b/vllm/v1/worker/gpu/model_states/default.py index 3693a091eea1..27cff056b82f 100644 --- a/vllm/v1/worker/gpu/model_states/default.py +++ b/vllm/v1/worker/gpu/model_states/default.py @@ -13,9 +13,7 @@ from vllm.v1.worker.gpu.input_batch import InputBatch from vllm.v1.worker.gpu.mm.encoder_cache import EncoderCache from vllm.v1.worker.gpu.mm.encoder_runner import EncoderRunner -from vllm.v1.worker.gpu.mm.rope.interface import RopeState -from vllm.v1.worker.gpu.mm.rope.mrope import MRopeState -from vllm.v1.worker.gpu.mm.rope.xdrope import XDRopeState +from vllm.v1.worker.gpu.mm.rope import RopeState from vllm.v1.worker.gpu.model_states.interface import ModelState from vllm.v1.worker.gpu.states import RequestState from vllm.v1.worker.utils import AttentionGroup @@ -56,15 +54,18 @@ def __init__( self.rope_state: RopeState | None = None if self.model_config.uses_mrope: - self.rope_state = MRopeState( + self.rope_state = RopeState( + num_dims=3, + has_delta=True, max_num_reqs=self.max_num_reqs, max_num_tokens=self.max_num_tokens, max_model_len=self.max_model_len, device=self.device, ) elif self.model_config.uses_xdrope_dim > 0: - self.rope_state = XDRopeState( - uses_xdrope_dim=self.model_config.uses_xdrope_dim, + self.rope_state = RopeState( + num_dims=self.model_config.uses_xdrope_dim, + has_delta=False, max_num_reqs=self.max_num_reqs, max_num_tokens=self.max_num_tokens, max_model_len=self.max_model_len, From 1ba54046eaeececda5c1902c2830a3d9fcbaa9d1 Mon Sep 17 00:00:00 2001 From: Santino Ramos Date: Fri, 13 Mar 2026 21:49:53 +0000 Subject: [PATCH 11/12] cleanups Signed-off-by: Santino Ramos --- vllm/v1/worker/gpu/mm/rope.py | 51 +++++++++++++++++----- vllm/v1/worker/gpu/model_states/default.py | 29 ++++-------- 2 files changed, 49 insertions(+), 31 deletions(-) diff --git a/vllm/v1/worker/gpu/mm/rope.py b/vllm/v1/worker/gpu/mm/rope.py index 5e87d10f65bf..a598f306f1de 100644 --- a/vllm/v1/worker/gpu/mm/rope.py +++ b/vllm/v1/worker/gpu/mm/rope.py @@ -5,6 +5,7 @@ import torch import torch.nn as nn +from vllm.config import ModelConfig from vllm.model_executor.models.interfaces import SupportsMRoPE, SupportsXDRoPE from vllm.triton_utils import tl, triton from vllm.v1.worker.gpu.buffer_utils import StagedWriteTensor, UvaBackedTensor @@ -14,7 +15,7 @@ class RopeState: """Unified state for multi-dimensional RoPE variants (M-RoPE, XD-RoPE). M-RoPE: 3 dims, uses position delta for decode. - XD-RoPE: 3 or 4 dims, no delta (decode uses orig_pos for all dims). + XD-RoPE: 3 or 4 dims, delta is 0 (decode uses orig_pos for all dims). NOTE: `positions` is implemented with one additional dummy position on purpose to make it non-contiguous so that it can work with torch compile. @@ -55,8 +56,8 @@ def __init__( (num_dims, max_num_tokens + 1), dtype=torch.int64, device=device ) - if has_delta: - self.prefill_delta = UvaBackedTensor(max_num_reqs, dtype=torch.int32) + # Delta is non-zero for M-RoPE, always 0 for XD-RoPE. + self.prefill_delta = UvaBackedTensor(max_num_reqs, dtype=torch.int32) def init_prefill_positions( self, @@ -83,8 +84,7 @@ def init_prefill_positions( def apply_staged_writes(self) -> None: self.prefill_positions.apply_write() - if self.has_delta: - self.prefill_delta.copy_to_uva() + self.prefill_delta.copy_to_uva() def get_positions(self, num_tokens: int) -> torch.Tensor: return self.positions[:, :num_tokens] @@ -103,17 +103,48 @@ def prepare_positions( self.prefill_positions.gpu, self.num_dims * self.max_model_len, self.max_model_len, - self.prefill_delta.gpu if self.has_delta else idx_mapping, + self.prefill_delta.gpu, idx_mapping, query_start_loc, prefill_lens, num_computed_tokens, BLOCK_SIZE=1024, NUM_DIMS=self.num_dims, - HAS_DELTA=self.has_delta, ) +def get_rope_state( + model_config: ModelConfig, + model: nn.Module, + max_num_reqs: int, + max_num_tokens: int, + max_model_len: int, + device: torch.device, +) -> RopeState | None: + """Create a RopeState if the model uses multi-dimensional RoPE.""" + if model_config.uses_mrope: + assert isinstance(model, SupportsMRoPE) + return RopeState( + num_dims=3, + has_delta=True, + max_num_reqs=max_num_reqs, + max_num_tokens=max_num_tokens, + max_model_len=max_model_len, + device=device, + ) + elif model_config.uses_xdrope_dim > 0: + assert isinstance(model, SupportsXDRoPE) + return RopeState( + num_dims=model_config.uses_xdrope_dim, + has_delta=False, + max_num_reqs=max_num_reqs, + max_num_tokens=max_num_tokens, + max_model_len=max_model_len, + device=device, + ) + return None + + @triton.jit def _prepare_rope_positions_kernel( positions_ptr, @@ -128,7 +159,6 @@ def _prepare_rope_positions_kernel( num_computed_tokens_ptr, BLOCK_SIZE: tl.constexpr, NUM_DIMS: tl.constexpr, - HAS_DELTA: tl.constexpr, ): batch_idx = tl.program_id(0) req_state_idx = tl.load(idx_mapping_ptr + batch_idx) @@ -141,8 +171,7 @@ def _prepare_rope_positions_kernel( query_end = tl.load(query_start_loc_ptr + batch_idx + 1) query_len = query_end - query_start - if HAS_DELTA: - delta = tl.load(prefill_delta_ptr + req_state_idx) + delta = tl.load(prefill_delta_ptr + req_state_idx) for i in range(0, query_len, BLOCK_SIZE): block = i + tl.arange(0, BLOCK_SIZE) @@ -159,7 +188,7 @@ def _prepare_rope_positions_kernel( mask=mask, ) else: - pos = orig_pos + delta if HAS_DELTA else orig_pos + pos = orig_pos + delta tl.store( positions_ptr + j * positions_stride + query_start + block, pos, diff --git a/vllm/v1/worker/gpu/model_states/default.py b/vllm/v1/worker/gpu/model_states/default.py index 27cff056b82f..104e4c1948b5 100644 --- a/vllm/v1/worker/gpu/model_states/default.py +++ b/vllm/v1/worker/gpu/model_states/default.py @@ -13,7 +13,7 @@ from vllm.v1.worker.gpu.input_batch import InputBatch from vllm.v1.worker.gpu.mm.encoder_cache import EncoderCache from vllm.v1.worker.gpu.mm.encoder_runner import EncoderRunner -from vllm.v1.worker.gpu.mm.rope import RopeState +from vllm.v1.worker.gpu.mm.rope import get_rope_state from vllm.v1.worker.gpu.model_states.interface import ModelState from vllm.v1.worker.gpu.states import RequestState from vllm.v1.worker.utils import AttentionGroup @@ -52,25 +52,14 @@ def __init__( device=self.device, ) - self.rope_state: RopeState | None = None - if self.model_config.uses_mrope: - self.rope_state = RopeState( - num_dims=3, - has_delta=True, - max_num_reqs=self.max_num_reqs, - max_num_tokens=self.max_num_tokens, - max_model_len=self.max_model_len, - device=self.device, - ) - elif self.model_config.uses_xdrope_dim > 0: - self.rope_state = RopeState( - num_dims=self.model_config.uses_xdrope_dim, - has_delta=False, - max_num_reqs=self.max_num_reqs, - max_num_tokens=self.max_num_tokens, - max_model_len=self.max_model_len, - device=self.device, - ) + self.rope_state = get_rope_state( + self.model_config, + model, + max_num_reqs=self.max_num_reqs, + max_num_tokens=self.max_num_tokens, + max_model_len=self.max_model_len, + device=self.device, + ) def add_request(self, req_index: int, new_req_data: NewRequestData) -> None: if self.rope_state is not None: From 1209251007c61897025dcd6331c3dac287e1eaea Mon Sep 17 00:00:00 2001 From: Santino Ramos Date: Fri, 13 Mar 2026 23:09:27 +0000 Subject: [PATCH 12/12] dont apply staged writes to xdrope on delta tensor Signed-off-by: Santino Ramos --- vllm/v1/worker/gpu/mm/rope.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm/v1/worker/gpu/mm/rope.py b/vllm/v1/worker/gpu/mm/rope.py index a598f306f1de..712f58af578f 100644 --- a/vllm/v1/worker/gpu/mm/rope.py +++ b/vllm/v1/worker/gpu/mm/rope.py @@ -84,7 +84,8 @@ def init_prefill_positions( def apply_staged_writes(self) -> None: self.prefill_positions.apply_write() - self.prefill_delta.copy_to_uva() + if self.has_delta: + self.prefill_delta.copy_to_uva() def get_positions(self, num_tokens: int) -> torch.Tensor: return self.positions[:, :num_tokens] @@ -132,7 +133,7 @@ def get_rope_state( max_model_len=max_model_len, device=device, ) - elif model_config.uses_xdrope_dim > 0: + if model_config.uses_xdrope_dim > 0: assert isinstance(model, SupportsXDRoPE) return RopeState( num_dims=model_config.uses_xdrope_dim,