-
Notifications
You must be signed in to change notification settings - Fork 54
[Spyre-Next] [Feature] Wrap RoPE layer on Spyre #881
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
11b7807
3311554
12ed4c9
494ee99
fe46602
8cbb942
f720b63
b29a6f7
8e585df
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,325 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||
| """Spyre-specific RotaryEmbedding implementation using out-of-tree (OOT) registration. | ||
|
|
||
| This module provides a custom RoPE (Rotary Position Embedding) layer for IBM's Spyre device, | ||
| replacing the upstream vLLM implementation (vllm/model_executor/layers/rotary_embedding.py) | ||
| when instantiated. | ||
|
|
||
| Architecture: | ||
| - OOT Registration: @RotaryEmbedding.register_oot() replaces upstream at instantiation | ||
| - Custom Op Boundary: torch.ops.vllm.spyre_rotary_embedding is opaque to torch.compile, | ||
| so _forward_spyre_impl runs eagerly outside the compiled graph | ||
| - Separate Compilation: forward_spyre is compiled independently via maybe_compile | ||
|
|
||
| Spyre Device Constraints: | ||
| - Device dtype: float16 (converted for Spyre) | ||
| - Output dtype: matches input dtype (converted on CPU) | ||
| - Cache management: cos/sin caches stored on Spyre device | ||
|
|
||
| Limitations: | ||
| - No dtype promotion (torch-spyre limitation) | ||
| - rope_scaling not yet implemented | ||
| - sin/cos calculation use float32 dtype for accuracy | ||
|
|
||
| References: | ||
| - Upstream RotaryEmbedding: vllm/model_executor/layers/rotary_embedding.py | ||
| """ | ||
|
|
||
| import torch | ||
|
|
||
| from vllm.logger import init_logger | ||
| from vllm.utils.torch_utils import direct_register_custom_op | ||
| from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding | ||
| from functools import lru_cache | ||
| import torch.utils._pytree as pytree | ||
|
|
||
| from .utils import convert, register_layer, get_layer, _fake_impl | ||
|
|
||
| logger = init_logger(__name__) | ||
|
|
||
|
|
||
@RotaryEmbedding.register_oot(name="RotaryEmbedding")
class SpyreRotaryEmbedding(RotaryEmbedding):
    """Out-of-tree (OOT) RotaryEmbedding implementation for IBM's Spyre device.

    This replaces the upstream vLLM RotaryEmbedding when instantiated,
    providing Spyre-specific optimizations and device handling.

    Implements RoPE (Rotary Position Embedding) which applies position-dependent
    rotations to query and key tensors in attention mechanisms.
    """

    # Per-argument dynamic-dimension spec consumed by the compilation
    # machinery; empty lists presumably mean "no dynamic dims" for these
    # arguments -- TODO confirm against maybe_compile's contract.
    _dynamic_arg_dims = {"positions": [], "query": [], "key": []}
|
|
||
def __init__(self, *args, **kwargs):
    """Initialize SpyreRotaryEmbedding layer.

    Validates that only the supported RoPE configuration is requested
    (default scaling, no mrope/fope), builds the cos/sin cache, compiles
    the Spyre forward kernel, and registers this instance so the custom
    op can look it up by name at runtime.

    Raises:
        NotImplementedError: if a non-default scaling type, mrope_section,
            or use_fope is configured (not supported on Spyre).
    """
    super().__init__(*args, **kwargs)

    # Validate supported configurations before any expensive setup.
    scaling_type = getattr(self, "scaling_type", "default")
    rope_parameters = getattr(self, "rope_parameters", {}) or {}

    is_supported = (
        scaling_type == "default"
        and "mrope_section" not in rope_parameters
        and not rope_parameters.get("use_fope", False)
    )

    if not is_supported:
        # Fixed: the two message segments previously ran together
        # ("...or fope.Got scaling_type=..."); also dropped the stray
        # f-prefix on the literal with no placeholders.
        raise NotImplementedError(
            "SpyreRotaryEmbedding only supports default scaling without "
            "mrope_section or fope. "
            f"Got scaling_type={scaling_type}, rope_parameters={rope_parameters}"
        )

    logger.debug("Building custom RotaryEmbedding")
    # Spyre device constraint: kernels run in float16 on the device.
    self._target_device = torch.device("spyre")
    self._target_dtype = torch.float16
    # Build cos/sin cache on initialization (computed on CPU in float32).
    self._build_cos_sin_cache()
    # Compile the Spyre kernel independently of the main graph.
    self.maybe_compiled_forward_spyre = self.maybe_compile(self.forward_spyre)
    # Register so torch.ops.vllm.spyre_rotary_embedding can retrieve us.
    self._layer_name = register_layer(self, "spyre_rotary_embedding")
|
|
||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We recently introduced an additional logging. Thus, please include something like |
||
def _build_cos_sin_cache(self):
    """Build cos/sin caches for position embeddings.

    Creates frequency-based position encodings:
      - inv_freq[i]  = base^(-2i/rotary_dim) for i in [0, rotary_dim/2)
      - cos_cache[p] = cos(p * inv_freq)  (duplicated along the last dim)
      - sin_cache[p] = sin(p * inv_freq)

    Note: sin/cos are computed in float32 for accuracy; generated text
    differs from the baseline when float16 is used. Conversion to the
    Spyre device/dtype happens later, per call, in _forward_spyre_impl.
    """
    # Fixed: message previously read "accuracy.All" with no separator.
    logger.warning_once(
        "Sin/Cos computation uses float32 for accuracy. "
        "All computations are run on CPU."
    )
    compute_dtype = torch.float32

    # Inverse frequencies base^(-2i/rotary_dim), computed as
    # 1 / base^(2i/rotary_dim) for numerical stability.
    i = torch.arange(0, self.rotary_dim, 2, dtype=compute_dtype)
    inv_freq = 1.0 / torch.pow(self.base, i / self.rotary_dim)

    # Position indices [0, 1, ..., max_position_embeddings - 1].
    pos_id = torch.arange(self.max_position_embeddings, dtype=compute_dtype)

    # Outer product pos * inv_freq.
    # Shape: [max_position_embeddings, rotary_dim // 2]
    freqs = pos_id.unsqueeze(1) * inv_freq.unsqueeze(0)

    # Duplicate frequencies for the rotate-half layout.
    # Shape: [max_position_embeddings, rotary_dim]
    emb = torch.cat([freqs, freqs], dim=-1)

    # Keep the caches in float32 (the earlier comment claiming float16
    # was wrong -- compute_dtype is float32 and device=None presumably
    # leaves placement unchanged; TODO confirm convert() semantics).
    self.cos_cache = convert(emb.cos(), None, compute_dtype)
    self.sin_cache = convert(emb.sin(), None, compute_dtype)
|
|
||
def forward_oot(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Apply RoPE through the registered custom op.

    The call to torch.ops.vllm.spyre_rotary_embedding is opaque to
    torch.compile: it retrieves this layer by name from the forward
    context and runs _forward_spyre_impl eagerly, writing its results
    into the pre-allocated output tensors.

    Args:
        positions: Position indices [batch_size, seq_len] or [total_tokens]
        query: Query tensor [batch_size, seq_len, num_heads, head_dim]
        key: Key tensor [batch_size, seq_len, num_kv_heads, head_dim]

    Returns:
        Tuple of (rotated_query, rotated_key) with same shapes as inputs
    """
    # Outputs are allocated here and mutated in place by the op.
    out_query = torch.empty_like(query)
    out_key = torch.empty_like(key)

    torch.ops.vllm.spyre_rotary_embedding(
        positions, query, key, out_query, out_key, self._layer_name
    )

    return out_query, out_key
|
|
||
| @staticmethod | ||
| def forward_spyre( | ||
| query: torch.Tensor, | ||
| query_half: torch.Tensor, | ||
| key: torch.Tensor, | ||
| key_half: torch.Tensor, | ||
| cos_q: torch.Tensor, | ||
| sin_q: torch.Tensor, | ||
| cos_k: torch.Tensor, | ||
| sin_k: torch.Tensor, | ||
| ) -> tuple[torch.Tensor, torch.Tensor]: | ||
| """Spyre-optimized RoPE computation (active implementation). | ||
|
|
||
| Applies rotary position embeddings to query and key tensors: | ||
| output = x * cos + rotate_half(x) * sin | ||
|
|
||
| Args: | ||
| positions: Position indices [batch_size, seq_len] or [total_tokens] | ||
| query: Query tensor [batch_size, seq_len, num_heads, head_dim] | ||
| key: Key tensor [batch_size, seq_len, num_kv_heads, head_dim] | ||
| cos_cache: Cosine cache [max_position_embeddings, rotary_dim] | ||
| sin_cache: Sine cache [max_position_embeddings, rotary_dim] | ||
| rotary_dim: Dimension to apply rotation | ||
|
|
||
| Returns: | ||
| Tuple of (rotated_query, rotated_key) | ||
| """ | ||
|
|
||
| query_out = query * cos_q + query_half * sin_q | ||
| key_out = key * cos_k + key_half * sin_k | ||
|
|
||
| return query_out, key_out | ||
|
|
||
def _forward_spyre_impl(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Run RoPE on the Spyre device with device/dtype round-tripping.

    Steps:
      1. On CPU: gather cos/sin per position, split the rotary slice
         from the pass-through remainder, build rotate-half tensors.
      2. Transfer operands to the Spyre device as float16 and run the
         separately-compiled kernel.
      3. Transfer results back, restore the original device/dtype, and
         re-attach the pass-through dimensions.

    Args:
        positions: Position indices on CPU, one per token.
        query: Query tensor [T, num_heads * head_size] on CPU.
        key: Key tensor [T, num_kv_heads * head_size] on CPU.

    Returns:
        Tuple of (rotated_query, rotated_key) on the original device,
        with the original dtype and flattened [T, hidden] shape.
    """
    query_dtype = query.dtype
    query_device = query.device
    key_dtype = key.dtype
    key_device = key.device

    positions = convert(positions, None, torch.int64)

    Tq, q_hidden = query.shape
    Tk, k_hidden = key.shape

    # NOTE(review): upstream RotaryEmbedding does not require equal
    # query/key token counts; confirm whether this constraint is
    # genuinely Spyre-specific before relaxing it.
    assert Tq == Tk, (
        f"Query/Key token-count mismatch: query has {Tq} tokens, "
        f"key has {Tk} tokens"
    )

    T = Tq

    q_heads = q_hidden // self.head_size
    k_heads = k_hidden // self.head_size

    query = query.reshape(T, q_heads, self.head_size)
    key = key.reshape(T, k_heads, self.head_size)

    # Gather cos/sin for each token position: [T, rotary_dim].
    cos = self.cos_cache[positions]
    sin = self.sin_cache[positions]

    def expand_cos_sin(cos, sin, num_heads):
        # Broadcast per-position values across heads -> [T, H, rotary_dim].
        cos = cos.unsqueeze(1).expand(-1, num_heads, -1)
        sin = sin.unsqueeze(1).expand(-1, num_heads, -1)
        return cos.contiguous(), sin.contiguous()

    cos_q, sin_q = expand_cos_sin(cos, sin, q_heads)
    cos_k, sin_k = expand_cos_sin(cos, sin, k_heads)

    # Split the rotary slice from the pass-through remainder.
    query_rot = query[..., : self.rotary_dim]
    query_pass = query[..., self.rotary_dim :]

    key_rot = key[..., : self.rotary_dim]
    key_pass = key[..., self.rotary_dim :]

    # rotate_half: [x1, x2] -> [-x2, x1].
    d = self.rotary_dim // 2

    q1 = query_rot[..., :d]
    q2 = query_rot[..., d:]
    query_half = torch.cat([-q2, q1], dim=-1)

    k1 = key_rot[..., :d]
    k2 = key_rot[..., d:]
    key_half = torch.cat([-k2, k1], dim=-1)

    # Fixed: the first assert's message previously reported query.shape
    # (the full [T, H, head_size] tensor) instead of query_rot.shape;
    # also added the previously-missing key-side checks, with
    # descriptive messages throughout.
    assert cos_q.shape == query_rot.shape, (
        f"cos_q shape {cos_q.shape} does not match rotary query slice "
        f"shape {query_rot.shape}"
    )
    assert sin_q.shape == query_rot.shape, (
        f"sin_q shape {sin_q.shape} does not match rotary query slice "
        f"shape {query_rot.shape}"
    )
    assert cos_k.shape == key_rot.shape, (
        f"cos_k shape {cos_k.shape} does not match rotary key slice "
        f"shape {key_rot.shape}"
    )
    assert sin_k.shape == key_rot.shape, (
        f"sin_k shape {sin_k.shape} does not match rotary key slice "
        f"shape {key_rot.shape}"
    )

    # Transfer operands to the Spyre device in its float16 dtype.
    query_spyre = convert(query_rot, self._target_device, self._target_dtype)
    query_half_spyre = convert(query_half, self._target_device, self._target_dtype)

    key_spyre = convert(key_rot, self._target_device, self._target_dtype)
    key_half_spyre = convert(key_half, self._target_device, self._target_dtype)

    cos_q_spyre = convert(cos_q, self._target_device, self._target_dtype)
    sin_q_spyre = convert(sin_q, self._target_device, self._target_dtype)

    cos_k_spyre = convert(cos_k, self._target_device, self._target_dtype)
    sin_k_spyre = convert(sin_k, self._target_device, self._target_dtype)

    # Execute the separately-compiled kernel on the Spyre device.
    rotated_query, rotated_key = self.maybe_compiled_forward_spyre(
        query_spyre,
        query_half_spyre,
        key_spyre,
        key_half_spyre,
        cos_q_spyre,
        sin_q_spyre,
        cos_k_spyre,
        sin_k_spyre,
    )

    # Transfer back and restore the original device/dtype.
    rotated_query = convert(rotated_query, query_device, query_dtype)
    rotated_key = convert(rotated_key, key_device, key_dtype)

    # Re-attach pass-through dims and flatten back to [T, hidden].
    rotated_query = torch.cat([rotated_query, query_pass], dim=-1)
    rotated_key = torch.cat([rotated_key, key_pass], dim=-1)

    rotated_query = rotated_query.reshape(T, -1)
    rotated_key = rotated_key.reshape(T, -1)

    return rotated_query, rotated_key
|
|
||
|
|
||
def _op_func(
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    rotated_query: torch.Tensor,
    rotated_key: torch.Tensor,
    layer_name: str,
) -> None:
    """Custom op implementation — runs outside torch.compile graph."""
    # Look up the layer registered at construction time and run the
    # eager Spyre implementation.
    layer = get_layer(layer_name)
    new_query, new_key = layer._forward_spyre_impl(positions, query, key)
    # Write results into the caller-allocated outputs; the op is
    # registered with mutates_args=["rotated_query", "rotated_key"].
    rotated_query.copy_(new_query)
    rotated_key.copy_(new_key)
|
|
||
|
|
||
@lru_cache(maxsize=1)
def register():
    """Register the spyre_rotary_embedding custom op with vLLM.

    Wrapped in lru_cache(maxsize=1) so repeated calls are no-ops after
    the first successful registration.
    """
    direct_register_custom_op(
        op_name="spyre_rotary_embedding",
        op_func=_op_func,
        # The op writes its results into these caller-allocated tensors.
        mutates_args=["rotated_query", "rotated_key"],
        fake_impl=_fake_impl,
    )
    logger.info("Registered custom op: SpyreRotaryEmbedding")
Uh oh!
There was an error while loading. Please reload this page.