diff --git a/vllm_spyre_next/tests/test_rms_norm.py b/vllm_spyre_next/tests/test_rms_norm.py
index 60560b7ab..3964f2b96 100644
--- a/vllm_spyre_next/tests/test_rms_norm.py
+++ b/vllm_spyre_next/tests/test_rms_norm.py
@@ -4,6 +4,7 @@
 
 import pytest
 import torch
+import sys
 
 
 def reference_rms_norm(
@@ -38,8 +39,8 @@ def test_spyre_rmsnorm_matches_reference(
     """SpyreRMSNorm output matches golden reference.
 
     Tests both paths:
-    - forward(): custom op dispatch (no-compile path via torch.ops.vllm.spyre_rmsnorm)
-    - forward_native(): direct Spyre device execution
+    - forward_oot(): OOT dispatch via custom op (torch.ops.vllm.spyre_rmsnorm)
+    - reference_rms_norm(): golden reference in pure PyTorch, similar to upstream vLLM (ground truth)
     """
     from vllm_spyre_next.custom_ops.rms_norm import SpyreRMSNorm
 
@@ -51,7 +52,9 @@ def test_spyre_rmsnorm_matches_reference(
     residual = torch.randn(batch_size, hidden_size, dtype=torch.float32) if use_residual else None
 
     expected = reference_rms_norm(x, layer.weight.data, eps, residual)
-    actual = layer.forward_native(x, residual)
+
+    # Test forward_oot (Spyre device execution via custom op)
+    actual = layer.forward_oot(x, residual)
 
     if use_residual:
         expected_norm, expected_resid = expected
@@ -63,30 +66,18 @@ def test_spyre_rmsnorm_matches_reference(
     else:
         torch.testing.assert_close(actual.float(), expected.float(), atol=1e-2, rtol=1e-2)
 
-    actual_forward = layer.forward(x, residual)
-    if use_residual:
-        actual_fwd_norm, actual_fwd_resid = actual_forward
-        torch.testing.assert_close(
-            actual_fwd_norm.float(), expected_norm.float(), atol=1e-2, rtol=1e-2
-        )
-        torch.testing.assert_close(
-            actual_fwd_resid.float(), expected_resid.float(), atol=1e-2, rtol=1e-2
-        )
-    else:
-        torch.testing.assert_close(actual_forward.float(), expected.float(), atol=1e-2, rtol=1e-2)
-
 
 @pytest.fixture
 def dummy_tensor():
     return torch.randn(4, 128, dtype=torch.float32)
 
 
-def mock_forward_native_no_residual(x, residual=None):
+def mock_forward_oot(x, residual=None):
     """Mock: return x + 1 (no residual path)."""
     return x + 1
 
 
-def mock_forward_native_with_residual(x, residual=None):
+def mock_forward_oot_with_residual(x, residual=None):
     """Mock: return (2 * x, 2 * residual) (residual path)."""
     return 2 * x, 2 * residual
 
@@ -109,15 +100,21 @@ def test_rmsnorm_oot_dispatch(default_vllm_config, monkeypatch, dummy_tensor, use_residual):
 
     residual = torch.randn(4, 128, dtype=torch.float32) if use_residual else None
 
-    # Mock forward_native (called by forward_oot) with a known transform
+    # Mock _forward_spyre_impl (called by the custom op) with a known transform
     if residual is not None:
-        monkeypatch.setattr(layer, "forward_native", mock_forward_native_with_residual)
-        out_x, out_residual = layer.forward_oot(dummy_tensor, residual)
+        monkeypatch.setattr(layer, "_forward_spyre_impl", mock_forward_oot_with_residual)
+        out_x, out_residual = layer.forward(dummy_tensor, residual)
         assert torch.allclose(out_x, 2 * dummy_tensor)
+
+        # The residual comes back via the new residual_out buffer, not in-place
         assert torch.allclose(out_residual, 2 * residual)
     else:
-        monkeypatch.setattr(layer, "forward_native", mock_forward_native_no_residual)
-        out_x = layer.forward_oot(dummy_tensor, residual)
+        monkeypatch.setattr(layer, "_forward_spyre_impl", mock_forward_oot)
+        out_x = layer.forward(dummy_tensor, residual)
        assert torch.allclose(out_x, dummy_tensor + 1)
+
+
+if __name__ == "__main__":
+    sys.exit(pytest.main([__file__, "-k", "test_rmsnorm_oot_dispatch", "-v"]))
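For orientation, the golden baseline that test_spyre_rmsnorm_matches_reference compares against stays in pure PyTorch on CPU. The sketch below is only an assumption of what reference_rms_norm computes, inferred from its call site in the test (reference_rms_norm(x, layer.weight.data, eps, residual)) and from standard RMSNorm-with-residual semantics; the actual helper in this file may differ in details such as dtype handling.

    import torch

    def reference_rms_norm(x, weight, eps, residual=None):
        # Assumed behavior: the residual (if given) is added before
        # normalization, and the post-add tensor is returned alongside
        # the normalized output, matching the (norm, resid) unpacking
        # done in the test.
        if residual is not None:
            x = x + residual
            residual = x
        variance = x.pow(2).mean(dim=-1, keepdim=True)
        out = x * torch.rsqrt(variance + eps) * weight
        return out if residual is None else (out, residual)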
diff --git a/vllm_spyre_next/vllm_spyre_next/custom_ops/rms_norm.py b/vllm_spyre_next/vllm_spyre_next/custom_ops/rms_norm.py
index b17f866d7..4ba164b2d 100644
--- a/vllm_spyre_next/vllm_spyre_next/custom_ops/rms_norm.py
+++ b/vllm_spyre_next/vllm_spyre_next/custom_ops/rms_norm.py
@@ -8,9 +8,11 @@
 
 Architecture:
     - OOT Registration: @RMSNorm.register_oot() replaces upstream at instantiation
+    - forward_oot(): Entry point for OOT dispatch, calls custom op for
+      torch.compile opacity
     - Custom Op Boundary: torch.ops.vllm.spyre_rmsnorm is opaque to torch.compile,
-      so forward_native runs eagerly outside the compiled graph
-    - Separate Compilation: forward_static is compiled independently via maybe_compile
+      so _forward_spyre_impl runs eagerly outside the compiled graph
+    - Separate Compilation: forward_spyre is compiled independently via maybe_compile
 
 Spyre Device Constraints:
     - Minimum batch size: 64 (due to spyre constraint, automatically padded)
@@ -19,8 +21,8 @@
     - Algorithm: Transpose-based computation with torch.ops.spyre.full()
 
 Limitations:
-    Currently the implementation in `_forward_vLLM_native` is similar to the
-    upstream implementation in `forward_static` from llm/model_executor/layers/layernorm.py,
+    Currently the implementation in `forward_spyre` is similar to the
+    upstream implementation in `forward_static` from vllm/model_executor/layers/layernorm.py,
     but it DOES NOT use the promotion of the data types, as this is not yet
     supported in torch-spyre.
 
@@ -66,7 +68,7 @@ def __init__(self, *args, **kwargs):
         self._target_device = torch.device("spyre")
         self._target_dtype = torch.float16
 
-        self._fwd = self.maybe_compile(self.forward_spyre)
+        self.maybe_compiled_forward_spyre = self.maybe_compile(self.forward_spyre)
 
         self._layer_name = register_layer(self, "spyre_rmsnorm")
 
@@ -75,15 +77,15 @@ def __init__(self, *args, **kwargs):
            "expect numerical differences to upstream vLLM."
        )
 
-    def forward(
+    def forward_oot(
         self,
         x: torch.Tensor,
         residual: torch.Tensor | None = None,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
-        """Forward pass using custom op to bypass torch.compile.
+        """OOT forward pass using custom op to bypass torch.compile.
 
         Delegates to torch.ops.vllm.spyre_rmsnorm which retrieves this layer
-        from forward_context.no_compile_layers and calls forward_impl outside
+        from the layer registry and calls _forward_spyre_impl outside
         the compilation graph. This prevents torch.compile from inlining
         the Spyre-specific operations.
 
@@ -95,12 +97,13 @@ def forward(
             Normalized output, or (output, residual) tuple if residual provided
         """
         output = torch.empty_like(x)
+        residual_out = torch.empty_like(residual) if residual is not None else None
 
         # Custom op call - executes outside torch.compile graph
-        torch.ops.vllm.spyre_rmsnorm(x, output, self._layer_name, residual)
+        torch.ops.vllm.spyre_rmsnorm(x, output, self._layer_name, residual, residual_out)
 
         if residual is not None:
-            return output, residual
+            return output, residual_out
         return output
 
     @staticmethod
@@ -145,7 +148,7 @@ def forward_spyre(
         else:
             return x, residual
 
-    def forward_native(
+    def _forward_spyre_impl(
         self,
         x: torch.Tensor,
         residual: torch.Tensor | None = None,
@@ -155,7 +158,7 @@ def forward_native(
         Handles Spyre-specific constraints:
         1. Minimum batch size: Pads to 64 if needed
         2. Device transfer: CPU -> Spyre convert to float16
-        3. Kernel execution: Calls compiled _fwd
+        3. Kernel execution: Calls compiled maybe_compiled_forward_spyre
         4. Result transfer: Spyre -> CPU, trim padding, convert to bfloat16
 
         Limitations:
@@ -185,7 +188,7 @@ def forward_native(
             residual = torch.nn.functional.pad(residual, (0, 0, 0, pad_amount))
 
         # Execute compiled kernel on Spyre device
-        outs = self._fwd(
+        outs = self.maybe_compiled_forward_spyre(
             convert(x, self._target_device, self._target_dtype),
             self.variance_epsilon,
             self.hidden_size,
@@ -207,15 +210,16 @@ def _op_func(
     output: torch.Tensor,
     layer_name: str,
     residual: torch.Tensor | None = None,
+    residual_out: torch.Tensor | None = None,
 ) -> None:
     """Custom op implementation — runs outside torch.compile graph."""
     layer = get_layer(layer_name)
-    result = layer.forward_native(x, residual)
+    result = layer._forward_spyre_impl(x, residual)
 
     if residual is not None:
         output_data, residual_data = result
         output.copy_(output_data)
-        residual.copy_(residual_data)
+        residual_out.copy_(residual_data)
     else:
         output.copy_(result)
 
@@ -226,7 +230,7 @@ def register():
     direct_register_custom_op(
         op_name="spyre_rmsnorm",
         op_func=_op_func,
-        mutates_args=["output"],
+        mutates_args=["output", "residual_out"],
         fake_impl=_fake_impl,
     )
     logger.info("Registered custom op: SpyreRMSNorm")
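Since every result of the RMSNorm op is now written into a caller-allocated buffer (output, plus residual_out on the residual path) and both buffers are listed in mutates_args, the op itself returns nothing. The fake implementation used during tracing can therefore be a pure no-op with the widened signature. The body below is a guess for illustration only, since _fake_impl is not shown in this diff; only its name and the registration kwargs appear above.

    import torch

    def _fake_impl(
        x: torch.Tensor,
        output: torch.Tensor,
        layer_name: str,
        residual: torch.Tensor | None = None,
        residual_out: torch.Tensor | None = None,
    ) -> None:
        # Nothing to compute under fake tensors: the real op only mutates the
        # preallocated `output` and `residual_out` buffers and returns None.
        return None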
diff --git a/vllm_spyre_next/vllm_spyre_next/custom_ops/silu_and_mul.py b/vllm_spyre_next/vllm_spyre_next/custom_ops/silu_and_mul.py
index 481fc33cb..2774e8dfa 100644
--- a/vllm_spyre_next/vllm_spyre_next/custom_ops/silu_and_mul.py
+++ b/vllm_spyre_next/vllm_spyre_next/custom_ops/silu_and_mul.py
@@ -8,9 +8,11 @@
 
 Architecture:
     - OOT Registration: @SiluAndMul.register_oot() replaces upstream at instantiation
+    - forward_oot(): Entry point for OOT dispatch, calls custom op for
+      torch.compile opacity
     - Custom Op Boundary: torch.ops.vllm.spyre_siluandmul is opaque to torch.compile,
-      so forward_native runs eagerly outside the compiled graph
-    - Separate Compilation: forward_static is compiled independently via maybe_compile
+      so _forward_spyre_impl runs eagerly outside the compiled graph
+    - Separate Compilation: forward_spyre is compiled independently via maybe_compile
 
 Spyre Device Constraints:
     - Device dtype: float16 (via convert_for_spyre)
@@ -60,15 +62,15 @@ def __init__(self, *args, **kwargs):
         self._target_device = torch.device("spyre")
         self._target_dtype = torch.float16
 
-        self._fwd = self.maybe_compile(self.forward_static)
+        self.maybe_compiled_forward_spyre = self.maybe_compile(self.forward_spyre)
 
         self._layer_name = register_layer(self, "spyre_siluandmul")
 
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """Forward pass using custom op to bypass torch.compile.
+    def forward_oot(self, x: torch.Tensor) -> torch.Tensor:
+        """OOT forward pass using custom op to bypass torch.compile.
 
         Delegates to torch.ops.vllm.spyre_siluandmul which retrieves this layer
-        from forward_context.no_compile_layers and calls forward_impl outside
+        from the layer registry and calls _forward_spyre_impl outside
         the compilation graph.
 
         Args:
@@ -86,14 +88,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return output
 
     @staticmethod
-    def forward_static(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
+    def forward_spyre(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
         """Spyre-optimized silu+multiply kernel compiled via torch.compile.
 
         Computes silu(x1) * x2 on the Spyre device, relying on torch-spyre's
         registered aten::silu.out kernel.
 
         The two halves are passed in as separate tensors because the Spyre
         device does not yet support tensor slicing (strided views); the split
         is therefore performed on CPU before
-        this method is called (see forward_native).
+        this method is called (see _forward_spyre_impl).
 
         Args:
             x1: First half of the gated input, shape [..., d], on Spyre device
@@ -106,7 +108,7 @@ def forward_static(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
         """
         return F.silu(x1) * x2
 
-    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
+    def _forward_spyre_impl(self, x: torch.Tensor) -> torch.Tensor:
         """Spyre device execution: CPU slicing workaround, device transfer, kernel call.
 
         The Spyre device does not currently support strided tensor views (slicing),
@@ -118,7 +120,7 @@ def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         1. Slice on CPU: split x into x1 = x[..., :d] and x2 = x[..., d:]
         2. Device transfer: convert x1 and x2 independently to Spyre (float16)
            via convert_for_spyre
-        3. Kernel execution: call compiled _fwd_spyre(x1_spyre, x2_spyre)
+        3. Kernel execution: call compiled maybe_compiled_forward_spyre(x1_spyre, x2_spyre)
         4. Result transfer: Spyre -> original device, restore original dtype
 
         Args:
@@ -135,7 +137,7 @@ def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         d = x.shape[-1] // 2
         x1 = x[..., :d]
         x2 = x[..., d:]
-        out = self._fwd(
+        out = self.maybe_compiled_forward_spyre(
             convert(x1, self._target_device, self._target_dtype),
             convert(x2, self._target_device, self._target_dtype),
         )
@@ -151,7 +153,7 @@ def forward_native(self, x: torch.Tensor) -> torch.Tensor:
 ) -> None:
     """Custom op implementation — runs outside torch.compile graph."""
     layer = get_layer(layer_name)
-    result = layer.forward_native(x)
+    result = layer._forward_spyre_impl(x)
     output.copy_(result)
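For comparison with the Spyre-specific path above, the operation SpyreSiluAndMul implements is easy to state in plain PyTorch. The snippet below is a hypothetical CPU-only reference helper (reference_silu_and_mul is not part of the patch); it mirrors the slicing convention documented in _forward_spyre_impl (x1 = x[..., :d], x2 = x[..., d:]) and the silu(x1) * x2 computation in forward_spyre.

    import torch
    import torch.nn.functional as F

    def reference_silu_and_mul(x: torch.Tensor) -> torch.Tensor:
        # Split the gated input on the last dimension, as _forward_spyre_impl
        # does on CPU, then apply the same silu(x1) * x2 as forward_spyre.
        d = x.shape[-1] // 2
        return F.silu(x[..., :d]) * x[..., d:]

    # Example: 4 tokens with a 2 * 128 gated hidden size -> [4, 128] output.
    x = torch.randn(4, 256)
    assert reference_silu_and_mul(x).shape == (4, 128)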