Merged
41 changes: 19 additions & 22 deletions vllm_spyre_next/tests/test_rms_norm.py
@@ -4,6 +4,7 @@

import pytest
import torch
import sys


def reference_rms_norm(
@@ -38,8 +39,8 @@ def test_spyre_rmsnorm_matches_reference(
"""SpyreRMSNorm output matches golden reference.

Tests both paths:
- forward(): custom op dispatch (no-compile path via torch.ops.vllm.spyre_rmsnorm)
- forward_native(): direct Spyre device execution
- forward_oot(): OOT dispatch via custom op (torch.ops.vllm.spyre_rmsnorm)
- reference_rms_norm(): golden reference mirroring vLLM's upstream pure-PyTorch implementation (ground truth)
"""
from vllm_spyre_next.custom_ops.rms_norm import SpyreRMSNorm

@@ -51,7 +52,9 @@
residual = torch.randn(batch_size, hidden_size, dtype=torch.float32) if use_residual else None

expected = reference_rms_norm(x, layer.weight.data, eps, residual)
actual = layer.forward_native(x, residual)

# Test forward_oot (Spyre device execution via custom op)
actual = layer.forward_oot(x, residual)

if use_residual:
expected_norm, expected_resid = expected
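The body of `reference_rms_norm` is collapsed out of this diff. For orientation, a minimal pure-PyTorch sketch consistent with how the test calls it (signature `(x, weight, eps, residual)`, with the pre-norm residual add used by upstream vLLM) could look like this:

```python
import torch

def reference_rms_norm_sketch(x, weight, eps, residual=None):
    """Minimal RMSNorm reference (sketch; the real helper is collapsed above)."""
    if residual is not None:
        x = x + residual      # pre-norm residual add
        residual = x          # the summed activation becomes the new residual
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    out = x * torch.rsqrt(variance + eps) * weight
    return (out, residual) if residual is not None else out
```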
@@ -63,30 +66,18 @@
else:
torch.testing.assert_close(actual.float(), expected.float(), atol=1e-2, rtol=1e-2)

actual_forward = layer.forward(x, residual)
if use_residual:
actual_fwd_norm, actual_fwd_resid = actual_forward
torch.testing.assert_close(
actual_fwd_norm.float(), expected_norm.float(), atol=1e-2, rtol=1e-2
)
torch.testing.assert_close(
actual_fwd_resid.float(), expected_resid.float(), atol=1e-2, rtol=1e-2
)
else:
torch.testing.assert_close(actual_forward.float(), expected.float(), atol=1e-2, rtol=1e-2)


@pytest.fixture
def dummy_tensor():
return torch.randn(4, 128, dtype=torch.float32)


def mock_forward_native_no_residual(x, residual=None):
def mock_forward_oot(x, residual=None):
"""Mock: return x + 1 (no residual path)."""
return x + 1


def mock_forward_native_with_residual(x, residual=None):
def mock_forward_oot_with_residual(x, residual=None):
"""Mock: return (2 * x, 2 * residual) (residual path)."""
return 2 * x, 2 * residual

@@ -109,15 +100,21 @@ def test_rmsnorm_oot_dispatch(default_vllm_config, monkeypatch, dummy_tensor, use_residual):

residual = torch.randn(4, 128, dtype=torch.float32) if use_residual else None

# Mock forward_native (called by forward_oot) with a known transform
# Mock _forward_spyre_impl (called by the custom op) with a known transform
if residual is not None:
monkeypatch.setattr(layer, "forward_native", mock_forward_native_with_residual)
out_x, out_residual = layer.forward_oot(dummy_tensor, residual)
monkeypatch.setattr(layer, "_forward_spyre_impl", mock_forward_oot_with_residual)
out_x, out_residual = layer.forward(dummy_tensor, residual)

assert torch.allclose(out_x, 2 * dummy_tensor)

# The transformed residual is returned as a new tensor (residual_out), not written in-place
assert torch.allclose(out_residual, 2 * residual)
else:
monkeypatch.setattr(layer, "forward_native", mock_forward_native_no_residual)
out_x = layer.forward_oot(dummy_tensor, residual)
monkeypatch.setattr(layer, "_forward_spyre_impl", mock_forward_oot)
out_x = layer.forward(dummy_tensor, residual)

assert torch.allclose(out_x, dummy_tensor + 1)


if __name__ == "__main__":
sys.exit(pytest.main([__file__, "-k", "test_rmsnorm_oot_dispatch", "-v"]))
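The dispatch test above relies on vLLM's `CustomOp` routing `forward` to `forward_oot` on out-of-tree platforms, which is why monkeypatching `_forward_spyre_impl` and calling `layer.forward(...)` still exercises the custom-op path end to end. A simplified sketch of that mechanism (illustrative only, not the exact upstream code):

```python
import torch.nn as nn

class CustomOpSketch(nn.Module):
    """Illustrative stand-in for vLLM's CustomOp dispatch."""

    def __init__(self):
        super().__init__()
        # vLLM resolves the forward method once per platform; on an
        # out-of-tree (OOT) platform the chosen method is forward_oot.
        self._forward_method = self.forward_oot

    def forward(self, *args, **kwargs):
        return self._forward_method(*args, **kwargs)

    def forward_oot(self, *args, **kwargs):
        raise NotImplementedError
```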
36 changes: 20 additions & 16 deletions vllm_spyre_next/vllm_spyre_next/custom_ops/rms_norm.py
@@ -8,9 +8,11 @@

Architecture:
- OOT Registration: @RMSNorm.register_oot() replaces upstream at instantiation
- forward_oot(): Entry point for OOT dispatch, calls custom op for
torch.compile opacity
- Custom Op Boundary: torch.ops.vllm.spyre_rmsnorm is opaque to torch.compile,
so forward_native runs eagerly outside the compiled graph
- Separate Compilation: forward_static is compiled independently via maybe_compile
so _forward_spyre_impl runs eagerly outside the compiled graph
- Separate Compilation: forward_spyre is compiled independently via maybe_compile

Spyre Device Constraints:
- Minimum batch size: 64 (Spyre device constraint; inputs are padded automatically)
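A sketch of the batch padding this constraint implies (assuming 2-D [batch, hidden] inputs, matching the `F.pad` call visible later in this diff):

```python
import torch

MIN_BATCH = 64  # Spyre minimum batch size, per the constraint above

def pad_to_min_batch(x: torch.Tensor) -> tuple[torch.Tensor, int]:
    """Pad dim 0 up to MIN_BATCH; the caller trims the padding after the kernel."""
    pad_amount = max(0, MIN_BATCH - x.shape[0])
    if pad_amount:
        # pad spec is (last-dim left, last-dim right, batch top, batch bottom)
        x = torch.nn.functional.pad(x, (0, 0, 0, pad_amount))
    return x, pad_amount
```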
@@ -19,8 +21,8 @@
- Algorithm: Transpose-based computation with torch.ops.spyre.full()

Limitations:
Currently the implementation in `_forward_vLLM_native` is similar to the
upstream implementation in `forward_static` from llm/model_executor/layers/layernorm.py,
Currently the implementation in `forward_spyre` is similar to the
upstream implementation in `forward_static` from vllm/model_executor/layers/layernorm.py,
but it DOES NOT promote data types the way the upstream version does, as
dtype promotion is not yet supported in torch-spyre.
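Concretely, the promotion that the Spyre path skips looks like this in upstream-style code (simplified sketch; the marked lines are the ones omitted here):

```python
import torch

def rms_norm_with_promotion(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    orig_dtype = x.dtype
    x = x.float()                                 # promotion: omitted on Spyre
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    return (x * weight.float()).to(orig_dtype)    # demotion: omitted on Spyre
```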

@@ -66,7 +68,7 @@ def __init__(self, *args, **kwargs):

self._target_device = torch.device("spyre")
self._target_dtype = torch.float16
self._fwd = self.maybe_compile(self.forward_spyre)
self.maybe_compiled_forward_spyre = self.maybe_compile(self.forward_spyre)

self._layer_name = register_layer(self, "spyre_rmsnorm")
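`maybe_compile` is defined elsewhere in the package; a plausible sketch of such a helper (hypothetical, with the gating flag something the real code presumably derives from platform or config):

```python
import torch

def maybe_compile_sketch(fn, enabled: bool = True):
    """Return fn wrapped in torch.compile when enabled, else unchanged."""
    return torch.compile(fn) if enabled else fn
```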

@@ -75,15 +77,15 @@
"expect numerical differences to upstream vLLM."
)

def forward(
def forward_oot(
self,
x: torch.Tensor,
residual: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
"""Forward pass using custom op to bypass torch.compile.
"""OOT forward pass using custom op to bypass torch.compile.

Delegates to torch.ops.vllm.spyre_rmsnorm which retrieves this layer
from forward_context.no_compile_layers and calls forward_impl outside
from the layer registry and calls _forward_spyre_impl outside
the compilation graph. This prevents torch.compile from inlining the
Spyre-specific operations.

@@ -95,12 +97,13 @@ def forward(
Normalized output, or (output, residual) tuple if residual provided
"""
output = torch.empty_like(x)
residual_out = torch.empty_like(residual) if residual is not None else None

# Custom op call - executes outside torch.compile graph
torch.ops.vllm.spyre_rmsnorm(x, output, self._layer_name, residual)
torch.ops.vllm.spyre_rmsnorm(x, output, self._layer_name, residual, residual_out)

if residual is not None:
return output, residual
return output, residual_out
return output

@staticmethod
@@ -145,7 +148,7 @@ def forward_spyre(
else:
return x, residual

def forward_native(
def _forward_spyre_impl(
self,
x: torch.Tensor,
residual: torch.Tensor | None = None,
@@ -155,7 +158,7 @@ def forward_native(
Handles Spyre-specific constraints:
1. Minimum batch size: Pads to 64 if needed
2. Device transfer: CPU -> Spyre, convert to float16
3. Kernel execution: Calls compiled _fwd
3. Kernel execution: Calls compiled maybe_compiled_forward_spyre
4. Result transfer: Spyre -> CPU, trim padding, convert to bfloat16

Limitations:
@@ -185,7 +188,7 @@ def forward_native(
residual = torch.nn.functional.pad(residual, (0, 0, 0, pad_amount))

# Execute compiled kernel on Spyre device
outs = self._fwd(
outs = self.maybe_compiled_forward_spyre(
convert(x, self._target_device, self._target_dtype),
self.variance_epsilon,
self.hidden_size,
@@ -207,15 +210,16 @@ def _op_func(
output: torch.Tensor,
layer_name: str,
residual: torch.Tensor | None = None,
residual_out: torch.Tensor | None = None,
) -> None:
"""Custom op implementation — runs outside torch.compile graph."""
layer = get_layer(layer_name)
result = layer.forward_native(x, residual)
result = layer._forward_spyre_impl(x, residual)

if residual is not None:
output_data, residual_data = result
output.copy_(output_data)
residual.copy_(residual_data)
residual_out.copy_(residual_data)
else:
output.copy_(result)
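`register_layer`/`get_layer` are not shown in the diff; the custom op receives only tensors plus a string name, so the layer is recovered from a registry at call time. A minimal sketch of such a registry (hypothetical; the real helpers live elsewhere in vllm_spyre_next):

```python
import itertools

_LAYERS: dict[str, object] = {}
_ids = itertools.count()

def register_layer(layer: object, prefix: str) -> str:
    """Store the layer under a unique string key and return that key."""
    name = f"{prefix}.{next(_ids)}"
    _LAYERS[name] = layer
    return name

def get_layer(name: str) -> object:
    return _LAYERS[name]
```

Passing the name instead of the module keeps the op signature to schema-friendly types (tensors and strings).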

@@ -226,7 +230,7 @@ def register():
direct_register_custom_op(
op_name="spyre_rmsnorm",
op_func=_op_func,
mutates_args=["output"],
mutates_args=["output", "residual_out"],
fake_impl=_fake_impl,
)
logger.info("Registered custom op: SpyreRMSNorm")
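`_fake_impl` is collapsed out of the diff. Because the op returns nothing and writes into preallocated tensors, its fake impl for tracing can plausibly be a no-op with the matching signature (sketch under that assumption):

```python
import torch

def _fake_impl_sketch(
    x: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
    residual: torch.Tensor | None = None,
    residual_out: torch.Tensor | None = None,
) -> None:
    # Nothing to compute at trace time: output/residual_out are preallocated
    # by the caller and declared in mutates_args, so shapes are already known.
    return None
```

Listing residual_out in mutates_args is what tells torch.compile the op writes into it, which is why the argument was added alongside output above.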
26 changes: 14 additions & 12 deletions vllm_spyre_next/vllm_spyre_next/custom_ops/silu_and_mul.py
@@ -8,9 +8,11 @@

Architecture:
- OOT Registration: @SiluAndMul.register_oot() replaces upstream at instantiation
- forward_oot(): Entry point for OOT dispatch, calls custom op for
torch.compile opacity
- Custom Op Boundary: torch.ops.vllm.spyre_siluandmul is opaque to torch.compile,
so forward_native runs eagerly outside the compiled graph
- Separate Compilation: forward_static is compiled independently via maybe_compile
so _forward_spyre_impl runs eagerly outside the compiled graph
- Separate Compilation: forward_spyre is compiled independently via maybe_compile

Spyre Device Constraints:
- Device dtype: float16 (via convert_for_spyre)
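For reference, the operation being accelerated splits the last dimension in half and gates one half with the other, matching upstream SiluAndMul semantics:

```python
import torch
import torch.nn.functional as F

def silu_and_mul_reference(x: torch.Tensor) -> torch.Tensor:
    """out = silu(x[..., :d]) * x[..., d:], where d is half the last dim."""
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]
```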
@@ -60,15 +62,15 @@ def __init__(self, *args, **kwargs):

self._target_device = torch.device("spyre")
self._target_dtype = torch.float16
self._fwd = self.maybe_compile(self.forward_static)
self.maybe_compiled_forward_spyre = self.maybe_compile(self.forward_spyre)

self._layer_name = register_layer(self, "spyre_siluandmul")

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward pass using custom op to bypass torch.compile.
def forward_oot(self, x: torch.Tensor) -> torch.Tensor:
"""OOT forward pass using custom op to bypass torch.compile.

Delegates to torch.ops.vllm.spyre_siluandmul which retrieves this layer
from forward_context.no_compile_layers and calls forward_impl outside
from the layer registry and calls _forward_spyre_impl outside
the compilation graph.

Args:
@@ -86,14 +88,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return output

@staticmethod
def forward_static(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
def forward_spyre(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
"""Spyre-optimized silu+multiply kernel compiled via torch.compile.

Computes silu(x1) * x2 on the Spyre device, relying on torch-spyre's
registered aten::silu.out kernel. The two halves are passed in as
separate tensors because the Spyre device does not yet support tensor
slicing (strided views); the split is therefore performed on CPU before
this method is called (see forward_native).
this method is called (see _forward_spyre_impl).

Args:
x1: First half of the gated input, shape [..., d], on Spyre device
@@ -106,7 +108,7 @@ def forward_static(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
"""
return F.silu(x1) * x2

def forward_native(self, x: torch.Tensor) -> torch.Tensor:
def _forward_spyre_impl(self, x: torch.Tensor) -> torch.Tensor:
"""Spyre device execution: CPU slicing workaround, device transfer, kernel call.

The Spyre device does not currently support strided tensor views (slicing),
@@ -118,7 +120,7 @@ def forward_native(self, x: torch.Tensor) -> torch.Tensor:
1. Slice on CPU: split x into x1 = x[..., :d] and x2 = x[..., d:]
2. Device transfer: convert x1 and x2 independently to Spyre (float16)
via convert_for_spyre
3. Kernel execution: call compiled _fwd_spyre(x1_spyre, x2_spyre)
3. Kernel execution: call compiled maybe_compiled_forward_spyre(x1_spyre, x2_spyre)
4. Result transfer: Spyre -> original device, restore original dtype

Args:
@@ -135,7 +137,7 @@ def forward_native(self, x: torch.Tensor) -> torch.Tensor:
d = x.shape[-1] // 2
x1 = x[..., :d]
x2 = x[..., d:]
out = self._fwd(
out = self.maybe_compiled_forward_spyre(
convert(x1, self._target_device, self._target_dtype),
convert(x2, self._target_device, self._target_dtype),
)
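`convert` is imported from elsewhere in the package; from its use here it is a combined device-and-dtype transfer. A plausible sketch (hypothetical helper body):

```python
import torch

def convert_sketch(t: torch.Tensor, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
    """Move t to the target device and dtype in one step."""
    return t.to(device=device, dtype=dtype)
```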
@@ -151,7 +153,7 @@ def _op_func(
) -> None:
"""Custom op implementation — runs outside torch.compile graph."""
layer = get_layer(layer_name)
result = layer.forward_native(x)
result = layer._forward_spyre_impl(x)
output.copy_(result)
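A parity check for the SiLU layer in the same style as the RMSNorm test earlier in this PR (sketch: the class name SpyreSiluAndMul and the import path are assumptions based on this file, and running it requires a Spyre device):

```python
import torch
import torch.nn.functional as F

def test_spyre_siluandmul_matches_reference():
    # Assumed import path and class name, mirroring the rms_norm test above
    from vllm_spyre_next.custom_ops.silu_and_mul import SpyreSiluAndMul

    layer = SpyreSiluAndMul()
    x = torch.randn(64, 256, dtype=torch.float32)

    d = x.shape[-1] // 2
    expected = F.silu(x[..., :d]) * x[..., d:]
    actual = layer._forward_spyre_impl(x)

    torch.testing.assert_close(actual.float(), expected.float(), atol=1e-2, rtol=1e-2)
```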

