Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
168 changes: 48 additions & 120 deletions vllm_spyre_next/vllm_spyre_next/custom_ops/rms_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,19 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Spyre-specific RMSNorm implementation using out-of-tree (OOT) registration.

This module provides a custom RMSNorm layer optimized for IBM's Spyre device,
This module provides a custom RMSNorm layer for IBM's Spyre device,
replacing the upstream vLLM implementation (vllm/model_executor/layers/layernorm.py)
when instantiated.

Architecture Overview:
1. OOT Registration: @RMSNorm.register_oot() replaces upstream class at instantiation
2. Custom Op Pattern: Uses torch.ops.vllm.spyre_rmsnorm to bypass torch.compile
3. Static Forward Context: Registers in compilation_config.static_forward_context
4. No-Compile Execution: Retrieved via forward_context.no_compile_layers during forward

Key Components:
- SpyreRMSNorm: Main layer class with Spyre-specific optimizations
- spyre_rmsnorm: Custom op implementation (executes outside torch.compile)
- spyre_rmsnorm_fake: Fake implementation for shape inference
- register(): Registers the custom op with vLLM
Architecture:
- OOT Registration: @RMSNorm.register_oot() replaces upstream at instantiation
- Custom Op Boundary: torch.ops.vllm.spyre_rmsnorm is opaque to torch.compile,
so forward_native runs eagerly outside the compiled graph
- Separate Compilation: forward_static is compiled independently via maybe_compile

Spyre Device Constraints:
- Minimum batch size: 64 (a Spyre hardware constraint; smaller batches are automatically padded)
- Device dtype: float16 (via prepare_inputs_on_spyre)
- Device dtype: float16 (converted for CPU)
- Output dtype: bfloat16 (converted on CPU)
- Algorithm: Transpose-based computation with torch.ops.spyre.full()

Expand All @@ -39,15 +33,16 @@

from vllm.logger import init_logger
from vllm.utils.torch_utils import direct_register_custom_op
from vllm.config import get_current_vllm_config
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.layernorm import RMSNorm
from functools import lru_cache

from .utils import convert_for_spyre, convert_from_spyre
from .utils import convert, register_layer, get_layer, _fake_impl

logger = init_logger(__name__)

# Minimum batch size required by Spyre hardware.
_SPYRE_MIN_BATCH_SIZE = 64


@RMSNorm.register_oot(name="RMSNorm")
class SpyreRMSNorm(RMSNorm):
Expand All @@ -57,6 +52,8 @@ class SpyreRMSNorm(RMSNorm):
when instantiated, providing Spyre-specific optimizations and device handling.
"""

_dynamic_arg_dims = {"x": [], "residual": []}

def __init__(self, *args, **kwargs):
"""Initialize SpyreRMSNorm layer.

Expand All @@ -67,25 +64,17 @@ def __init__(self, *args, **kwargs):

logger.debug("Building custom RMS norm")

self._fwd_spyre = torch.compile(self.forward_spyre, dynamic=False)
self._target_device = torch.device("spyre")
self._target_dtype = torch.float16
self._fwd = self.maybe_compile(self.forward_spyre)

self._layer_name = register_layer(self, "spyre_rmsnorm")

logger.warning(
"SpyreRMSNorm: no dtype promotion is performed, \
expect numerical differences to upstream vLLM."
"SpyreRMSNorm: no dtype promotion is performed, "
"expect numerical differences to upstream vLLM."
)

# Register in static_forward_context for custom op access
# Pattern: Each instance gets unique name via counter to avoid collisions
compilation_config = get_current_vllm_config().compilation_config
if not hasattr(SpyreRMSNorm, "_instance_counter"):
SpyreRMSNorm._instance_counter = 0
self.prefix = f"spyre_rmsnorm_{SpyreRMSNorm._instance_counter}"
SpyreRMSNorm._instance_counter += 1

if self.prefix in compilation_config.static_forward_context:
raise ValueError(f"Duplicate layer name: {self.prefix}")
compilation_config.static_forward_context[self.prefix] = self

def forward(
self,
x: torch.Tensor,
Expand All @@ -108,38 +97,12 @@ def forward(
output = torch.empty_like(x)

# Custom op call - executes outside torch.compile graph
torch.ops.vllm.spyre_rmsnorm(x, output, self.prefix, residual)
torch.ops.vllm.spyre_rmsnorm(x, output, self._layer_name, residual)

if residual is not None:
return output, residual
return output

def forward_impl(
    self,
    x: torch.Tensor,
    output: torch.Tensor,
    residual: torch.Tensor | None = None,
) -> None:
    """Run the layer outside torch.compile and write results in place.

    Invoked by the spyre_rmsnorm custom op (looked up through
    forward_context.no_compile_layers). Computes via forward_native and
    copies the results into the caller-provided, pre-allocated tensors.

    Args:
        x: Input tensor
        output: Pre-allocated tensor that receives the normalized result
            (modified in-place)
        residual: Optional residual tensor (modified in-place if provided)
    """
    computed = self.forward_native(x, residual)

    if residual is None:
        output.copy_(computed)
    else:
        normed, new_residual = computed
        output.copy_(normed)
        residual.copy_(new_residual)

@staticmethod
def forward_spyre(
x: torch.Tensor,
Expand Down Expand Up @@ -169,8 +132,8 @@ def forward_spyre(

x = x.transpose(-1, -2).contiguous()

variance_epsilon = torch.ops.spyre.full(
x.shape, variance_epsilon, dtype=torch.float16, device="spyre"
variance_epsilon = torch.full(
x.shape, variance_epsilon, dtype=torch.float16, device=x.device
)

if variance_size_override is None:
Expand Down Expand Up @@ -206,8 +169,8 @@ def forward_native(

Handles Spyre-specific constraints:
1. Minimum batch size: Pads to 64 if needed
2. Device transfer: CPU -> Spyre (float16) via prepare_inputs_on_spyre
3. Kernel execution: Calls compiled _fwd_spyre
2. Device transfer: CPU -> Spyre convert to float16
3. Kernel execution: Calls compiled _fwd
4. Result transfer: Spyre -> CPU, trim padding, convert to bfloat16

Limitations:
Expand All @@ -230,91 +193,56 @@ def forward_native(

# Pad to minimum batch size of 64 (Spyre constraint)
# Pad at END so original data stays at indices [0:orig_batch_size]
if x.shape[0] < 64:
pad_amount = 64 - x.shape[0]
if x.shape[0] < _SPYRE_MIN_BATCH_SIZE:
pad_amount = _SPYRE_MIN_BATCH_SIZE - x.shape[0]
x = torch.nn.functional.pad(x, (0, 0, 0, pad_amount))
if residual is not None:
residual = torch.nn.functional.pad(residual, (0, 0, 0, pad_amount))

# Execute compiled kernel on Spyre device
# convert_for_spyre: CPU tensor -> Spyre device (float16)
outs = self._fwd_spyre(
convert_for_spyre(x, dtype=torch.float16),
outs = self._fwd(
convert(x, self._target_device, self._target_dtype),
self.variance_epsilon,
self.hidden_size,
convert_for_spyre(self.weight.data, dtype=torch.float16) if self.has_weight else None,
convert_for_spyre(residual, dtype=torch.float16),
convert(self.weight.data, self._target_device, self._target_dtype)
if self.has_weight
else None,
convert(residual, self._target_device, self._target_dtype),
self.variance_size_override,
)

# Transfer back to CPU and restore original shape
return pytree.tree_map(
lambda el: el[:orig_batch_size, :],
convert_from_spyre(outs, dtype=x_dtype, device=x_device),
lambda el: convert(el, dtype=x_dtype, device=x_device)[:orig_batch_size, :],
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think there was a small bug here: `convert` wasn't applied inside the transform, so in the case where there is a residual, a tuple was being passed to `convert`, which started failing with the new unit tests.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good to me! Thank you @joerunde

outs,
)

def forward_oot(
    self,
    x: torch.Tensor,
    residual: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
    """OOT forward entry point - delegates to forward_native.

    Args:
        x: Input tensor
        residual: Optional residual tensor

    Returns:
        Whatever forward_native returns: the normalized tensor, or a
        (output, residual) tuple when a residual is provided.
    """
    return self.forward_native(x, residual)


# Custom op implementation (executed outside torch.compile)
def spyre_rmsnorm(
    x: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
    residual: torch.Tensor | None = None,
) -> None:
    """Body of the torch.ops.vllm.spyre_rmsnorm custom op.

    Looks up the layer instance registered under ``layer_name`` in the
    current forward context's no_compile_layers and runs its forward_impl,
    keeping the Spyre-specific computation outside any torch.compile graph.
    Same pattern as mamba_mixer2
    (vllm/model_executor/layers/mamba/mamba_mixer2.py).

    Args:
        x: Input tensor
        output: Pre-allocated output tensor (modified in-place)
        layer_name: Unique layer identifier in static_forward_context
        residual: Optional residual tensor
    """
    ctx = get_forward_context()
    ctx.no_compile_layers[layer_name].forward_impl(x, output, residual)


def spyre_rmsnorm_fake(
def _op_func(
x: torch.Tensor,
output: torch.Tensor,
layer_name: str,
residual: torch.Tensor | None = None,
) -> None:
"""Fake implementation for shape/dtype inference during torch.compile.
"""Custom op implementation — runs outside torch.compile graph."""
layer = get_layer(layer_name)
result = layer.forward_native(x, residual)

Provides metadata to torch.compile without executing actual computation.
"""
return
if residual is not None:
output_data, residual_data = result
output.copy_(output_data)
residual.copy_(residual_data)
else:
output.copy_(result)


@lru_cache(maxsize=1)
def register():
"""Register the spyre_rmsnorm custom op with vLLM.

Registers torch.ops.vllm.spyre_rmsnorm with:
- op_func: Actual implementation (spyre_rmsnorm)
- fake_impl: Shape inference implementation (spyre_rmsnorm_fake)
- mutates_args: Indicates 'output' is modified in-place
"""
"""Register the spyre_rmsnorm custom op with vLLM."""
direct_register_custom_op(
op_name="spyre_rmsnorm",
op_func=spyre_rmsnorm,
op_func=_op_func,
mutates_args=["output"],
fake_impl=spyre_rmsnorm_fake,
fake_impl=_fake_impl,
)
logger.info("Registered custom op: SpyreRMSNorm")
Loading
Loading