2 changes: 2 additions & 0 deletions vllm_spyre_next/vllm_spyre_next/custom_ops/__init__.py
@@ -4,6 +4,7 @@
from . import silu_and_mul
from . import vocab_parallel_embedding
from . import linear
from . import rotary_embedding
from vllm.logger import init_logger

logger = init_logger(__name__)
@@ -15,3 +16,4 @@ def register_all():
silu_and_mul.register()
vocab_parallel_embedding.register()
linear.register()
rotary_embedding.register()
325 changes: 325 additions & 0 deletions vllm_spyre_next/vllm_spyre_next/custom_ops/rotary_embedding.py
@@ -0,0 +1,325 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Spyre-specific RotaryEmbedding implementation using out-of-tree (OOT) registration.

This module provides a custom RoPE (Rotary Position Embedding) layer for IBM's Spyre device,
replacing the upstream vLLM implementation (vllm/model_executor/layers/rotary_embedding.py)
when instantiated.

Architecture:
- OOT Registration: @RotaryEmbedding.register_oot() replaces upstream at instantiation
- Custom Op Boundary: torch.ops.vllm.spyre_rotary_embedding is opaque to torch.compile,
so _forward_spyre_impl runs eagerly outside the compiled graph
- Separate Compilation: forward_spyre is compiled independently via maybe_compile

Spyre Device Constraints:
- Device dtype: float16 (converted for Spyre)
- Output dtype: matches input dtype (converted on CPU)
- Cache management: cos/sin caches are built in float32 on the host; per-position slices are converted to the Spyre device at call time

Limitations:
- No dtype promotion (torch-spyre limitation)
- rope_scaling not yet implemented
- sin/cos calculations use float32 for accuracy

References:
- Upstream RotaryEmbedding: vllm/model_executor/layers/rotary_embedding.py
"""

import torch

from vllm.logger import init_logger
from vllm.utils.torch_utils import direct_register_custom_op
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
from functools import lru_cache
import torch.utils._pytree as pytree

from .utils import convert, register_layer, get_layer, _fake_impl

logger = init_logger(__name__)


@RotaryEmbedding.register_oot(name="RotaryEmbedding")
class SpyreRotaryEmbedding(RotaryEmbedding):
"""Out-of-tree (OOT) RotaryEmbedding implementation for IBM's Spyre device.

This replaces the upstream vLLM RotaryEmbedding when instantiated,
providing Spyre-specific optimizations and device handling.

Implements RoPE (Rotary Position Embedding) which applies position-dependent
rotations to query and key tensors in attention mechanisms.
"""

_dynamic_arg_dims = {"positions": [], "query": [], "key": []}

def __init__(self, *args, **kwargs):
"""Initialize SpyreRotaryEmbedding layer.
Compiles the Spyre kernel and registers this instance in static_forward_context.
Builds cos/sin cache for position embeddings.
"""
super().__init__(*args, **kwargs)
# Validate supported configurations
scaling_type = getattr(self, "scaling_type", "default")
rope_parameters = getattr(self, "rope_parameters", {}) or {}

is_supported = (
scaling_type == "default"
and "mrope_section" not in rope_parameters
and ("use_fope" not in rope_parameters or not rope_parameters["use_fope"])
)

if not is_supported:
raise NotImplementedError(
f"SpyreRotaryEmbedding only supports default scaling without mrope_section or fope."
f"Got scaling_type={scaling_type}, rope_parameters={rope_parameters}"
)

logger.debug("Building custom RotaryEmbedding")
self._target_device = torch.device("spyre")
self._target_dtype = torch.float16
# Build cos/sin cache on initialization
self._build_cos_sin_cache()
# Compile the forward kernel
self.maybe_compiled_forward_spyre = self.maybe_compile(self.forward_spyre)
self._layer_name = register_layer(self, "spyre_rotary_embedding")

Collaborator
We recently introduced additional logging, so please include something like:

logger.debug_once(
    "SpyreRotaryEmbedding: Dispatch: enabled=%s, Forward method=%s, Compiled=%s",
    self.enabled(),
    self._forward_method.__name__,
    self.maybe_compiled_forward_spyre is not self.forward_spyre,
)

def _build_cos_sin_cache(self):
"""Build cos/sin cache for position embeddings.

Creates frequency-based position encodings:
- frequencies = base^(-2i/rotary_dim) for i in [0, rotary_dim/2)
- cos_cache[pos] = cos(pos * frequencies)
- sin_cache[pos] = sin(pos * frequencies)

Note: sin/cos calculations use float32 for accuracy;
generated text differs from the baseline when float16 is used.
"""
logger.warning_once(
"Sin/cos computation uses float32 for accuracy. All computations are run on the CPU."
)
compute_dtype = torch.float32

""" Compute inverse frequencies: base^(-2i/rotary_dim)
Using negative exponent for numerical stability"""

i = torch.arange(0, self.rotary_dim, 2, dtype=compute_dtype)
ratio = i / self.rotary_dim

freq = torch.pow(self.base, ratio)

inv_freq = 1.0 / freq

# Create position indices [0, 1, 2, ..., max_position_embeddings-1]
pos_id = torch.arange(self.max_position_embeddings, dtype=compute_dtype)

# Compute frequencies for each position: pos * inv_freq
# Shape: [max_position_embeddings, rotary_dim // 2]

freqs = pos_id.unsqueeze(1) * inv_freq.unsqueeze(0)

# Duplicate frequencies (NeoX-style, non-interleaved layout)
# Shape: [max_position_embeddings, rotary_dim]
emb = torch.cat([freqs, freqs], dim=-1)

# Compute cos and sin in float32 on the host; slices are converted to the Spyre device per call
self.cos_cache = convert(emb.cos(), None, compute_dtype)
self.sin_cache = convert(emb.sin(), None, compute_dtype)
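# Worked example (illustrative values, not from the source): with base=10000 and rotary_dim=4,
# inv_freq = [1.0, 0.01], so for position p the cache rows are
#   cos_cache[p] = [cos(p), cos(0.01*p), cos(p), cos(0.01*p)]
#   sin_cache[p] = [sin(p), sin(0.01*p), sin(p), sin(0.01*p)]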

def forward_oot(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Forward pass using custom op to bypass torch.compile.

Delegates to torch.ops.vllm.spyre_rotary_embedding which retrieves this layer
from forward_context.no_compile_layers and calls forward_impl outside
the compilation graph.

Args:
positions: Position indices [num_tokens]
query: Query tensor [num_tokens, num_heads * head_size]
key: Key tensor [num_tokens, num_kv_heads * head_size]

Returns:
Tuple of (rotated_query, rotated_key) with same shapes as inputs
"""
rotated_query = torch.empty_like(query)
rotated_key = torch.empty_like(key)

# Custom op call - executes outside torch.compile graph
torch.ops.vllm.spyre_rotary_embedding(
positions, query, key, rotated_query, rotated_key, self._layer_name
)

return rotated_query, rotated_key

@staticmethod
def forward_spyre(
query: torch.Tensor,
query_half: torch.Tensor,
key: torch.Tensor,
key_half: torch.Tensor,
cos_q: torch.Tensor,
sin_q: torch.Tensor,
cos_k: torch.Tensor,
sin_k: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Spyre-optimized RoPE computation (active implementation).

Applies rotary position embeddings to query and key tensors:
output = x * cos + rotate_half(x) * sin

Args:
query: Rotary slice of the query tensor [num_tokens, num_heads, rotary_dim]
query_half: rotate_half of the query slice, same shape as query
key: Rotary slice of the key tensor [num_tokens, num_kv_heads, rotary_dim]
key_half: rotate_half of the key slice, same shape as key
cos_q, sin_q: Cos/sin values expanded to the query head count, same shape as query
cos_k, sin_k: Cos/sin values expanded to the key head count, same shape as key

Returns:
Tuple of (rotated_query, rotated_key)
"""

query_out = query * cos_q + query_half * sin_q
key_out = key * cos_k + key_half * sin_k

return query_out, key_out
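
# Worked example (illustrative): for rotary_dim = 2, position p and angle t = p * inv_freq[0],
# a pair x = [x1, x2] has rotate_half(x) = [-x2, x1], so
#   x * cos(t) + rotate_half(x) * sin(t) = [x1*cos(t) - x2*sin(t), x2*cos(t) + x1*sin(t)],
# i.e. a standard 2-D rotation of (x1, x2) by the angle t.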

def _forward_spyre_impl(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Spyre device execution with device transfer and dtype conversion.

Handles Spyre-specific constraints:
1. Device transfer: CPU -> Spyre, convert to float16
2. Kernel execution: Calls the compiled forward_spyre kernel
3. Result transfer: Spyre -> CPU, restore original dtype

Args:
positions: Position indices on CPU
query: Query tensor on CPU
key: Key tensor on CPU

Returns:
Tuple of (rotated_query, rotated_key) on CPU with original dtype
"""
query_dtype = query.dtype
query_device = query.device
key_dtype = key.dtype
key_device = key.device

# Execute compiled kernel on Spyre device

positions = convert(positions, None, torch.int64)

Tq, q_hidden = query.shape
Tk, k_hidden = key.shape

assert Tq == Tk, f"Query/Key sequence mismatch: {Tq} != {Tk}"
T = Tq

q_heads = q_hidden // self.head_size
k_heads = k_hidden // self.head_size

query = query.reshape(T, q_heads, self.head_size)
key = key.reshape(T, k_heads, self.head_size)

# get cos/sin
cos = self.cos_cache[positions] # [T, D]
sin = self.sin_cache[positions]

def expand_cos_sin(cos, sin, num_heads):
cos = cos.unsqueeze(1).expand(-1, num_heads, -1)
sin = sin.unsqueeze(1).expand(-1, num_heads, -1)
return cos.contiguous(), sin.contiguous()

# expand to heads
cos_q, sin_q = expand_cos_sin(cos, sin, q_heads)
cos_k, sin_k = expand_cos_sin(cos, sin, k_heads)

query_rot = query[..., : self.rotary_dim]
query_pass = query[..., self.rotary_dim :]

key_rot = key[..., : self.rotary_dim]
key_pass = key[..., self.rotary_dim :]

d = self.rotary_dim // 2

q1 = query_rot[..., :d]
q2 = query_rot[..., d:]
query_half = torch.cat([-q2, q1], dim=-1)

k1 = key_rot[..., :d]
k2 = key_rot[..., d:]
key_half = torch.cat([-k2, k1], dim=-1)
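# query_half/key_half implement rotate_half(x) = cat([-x2, x1]): each of the first d channels
# is paired with the channel d positions later, matching the duplicated cos/sin layout so that
# output = x * cos + rotate_half(x) * sin.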

assert cos_q.shape == query_rot.shape, f"{cos_q.shape} != {query_rot.shape}"
assert sin_q.shape == query_rot.shape
Comment on lines +262 to +263
Collaborator
Can we have more descriptive error messages here?

query_spyre = convert(query_rot, self._target_device, self._target_dtype)
query_half_spyre = convert(query_half, self._target_device, self._target_dtype)

key_spyre = convert(key_rot, self._target_device, self._target_dtype)
key_half_spyre = convert(key_half, self._target_device, self._target_dtype)

cos_q_spyre = convert(cos_q, self._target_device, self._target_dtype)
sin_q_spyre = convert(sin_q, self._target_device, self._target_dtype)

cos_k_spyre = convert(cos_k, self._target_device, self._target_dtype)
sin_k_spyre = convert(sin_k, self._target_device, self._target_dtype)

rotated_query, rotated_key = self.maybe_compiled_forward_spyre(
query_spyre,
query_half_spyre,
key_spyre,
key_half_spyre,
cos_q_spyre,
sin_q_spyre,
cos_k_spyre,
sin_k_spyre,
)

# Transfer back to CPU and restore original dtype
rotated_query = convert(rotated_query, query_device, query_dtype)
rotated_key = convert(rotated_key, key_device, key_dtype)

rotated_query = torch.cat([rotated_query, query_pass], dim=-1)
rotated_key = torch.cat([rotated_key, key_pass], dim=-1)

rotated_query = rotated_query.reshape(T, -1)
rotated_key = rotated_key.reshape(T, -1)

return rotated_query, rotated_key


def _op_func(
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
rotated_query: torch.Tensor,
rotated_key: torch.Tensor,
layer_name: str,
) -> None:
"""Custom op implementation — runs outside torch.compile graph."""
layer = get_layer(layer_name)
result = list(layer._forward_spyre_impl(positions, query, key))
outputs = [rotated_query, rotated_key]
pytree.tree_map(lambda out, res: out.copy_(res), outputs, result)


@lru_cache(maxsize=1)
def register():
"""Register the spyre_rotary_embedding custom op with vLLM."""
direct_register_custom_op(
op_name="spyre_rotary_embedding",
op_func=_op_func,
mutates_args=["rotated_query", "rotated_key"],
fake_impl=_fake_impl,
)
logger.info("Registered custom op: SpyreRotaryEmbedding")
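
# Note: register() is wrapped in @lru_cache(maxsize=1), so repeated calls from register_all()
# in custom_ops/__init__.py are no-ops after the first; the op is registered once per process.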