diff --git a/vllm_spyre_next/vllm_spyre_next/custom_ops/silu_and_mul.py b/vllm_spyre_next/vllm_spyre_next/custom_ops/silu_and_mul.py index ffaa8a63b..e44804fc7 100644 --- a/vllm_spyre_next/vllm_spyre_next/custom_ops/silu_and_mul.py +++ b/vllm_spyre_next/vllm_spyre_next/custom_ops/silu_and_mul.py @@ -8,19 +8,17 @@ Architecture: - OOT Registration: @SiluAndMul.register_oot() replaces upstream at instantiation - - forward_oot(): Entry point for OOT dispatch, calls custom op for - torch.compile opacity - - Custom Op Boundary: torch.ops.vllm.spyre_siluandmul is opaque to torch.compile, - so _forward_spyre_impl runs eagerly outside the compiled graph - - Separate Compilation: forward_spyre is compiled independently via maybe_compile + - forward_oot(): Entry point for OOT dispatch, fully transparent to the outer + torch.compile graph (no opaque custom-op boundary). + - Halves use .contiguous() to ensure zero storage offset before device transfer. + - convert() utility handles device/dtype transfers efficiently. Spyre Device Constraints: - - Computations performed in torch.float16: - Input (dtype defined by model / user) converted to torch.float16 for - operations on spyre and then converted back to original dtype for cpu. + - Splitting (aten.slice.Tensor) inside compiled Spyre graphs is unsupported — + non-zero storage offsets are rejected by the Flex backend. + - The .contiguous() call ensures zero storage offset before transfer to Spyre. Output Shape Note: - Unlike RMSNorm (same input/output shape), SiluAndMul halves the last dimension: input shape: [..., 2*d] -> output shape: [..., d] References: @@ -31,11 +29,9 @@ import torch.nn.functional as F from vllm.logger import init_logger -from vllm.utils.torch_utils import direct_register_custom_op from vllm.model_executor.layers.activation import SiluAndMul -from functools import lru_cache -from .utils import convert, register_layer, get_layer, _fake_impl +from .utils import convert logger = init_logger(__name__) @@ -44,134 +40,61 @@ class SpyreSiluAndMul(SiluAndMul): """Out-of-tree (OOT) SiluAndMul implementation for IBM's Spyre device. - This replaces the upstream vLLM SiluAndMul (vllm/model_executor/layers/activation.py) - when instantiated, providing Spyre-specific optimizations and device handling. - Computes: x -> silu(x[..., :d]) * x[..., d:] where d = x.shape[-1] // 2 + + Fully transparent to the outer torch.compile graph — no opaque custom-op + boundary. Uses .contiguous() to ensure zero storage offset before device + transfer, as Spyre's Flex backend rejects non-zero offsets + (aten.slice.Tensor unsupported in compiled Spyre graphs). """ - _dynamic_arg_dims = {"x1": [], "x2": []} + _dynamic_arg_dims = {"x": []} def __init__(self, *args, **kwargs): """Initialize SpyreSiluAndMul layer. - Compiles the Spyre kernel and registers this instance in static_forward_context. + Sets up the target device (Spyre) and dtype (float16) for computation. + The simplified implementation computes directly in forward_oot() without + requiring layer registry or custom op registration. """ super().__init__(*args, **kwargs) - logger.debug("Building custom SiluAndMul") - self._target_device = torch.device("spyre") self._target_dtype = torch.float16 - self.maybe_compiled_forward_spyre = self.maybe_compile(self.forward_spyre) - - self._layer_name = register_layer(self, "spyre_siluandmul") - - logger.debug_once( - "SpyreSiluAndMul: Dispatch: enabled=%s, Forward method=%s, Compiled=%s", - self.enabled(), - self._forward_method.__name__, - self.maybe_compiled_forward_spyre is not self.forward_spyre, - ) def forward_oot(self, x: torch.Tensor) -> torch.Tensor: - """OOT forward pass using custom op to bypass torch.compile. - - Delegates to torch.ops.vllm.spyre_siluandmul which retrieves this layer - from the layer registry and calls _forward_spyre_impl outside - the compilation graph. - - Args: - x: Input tensor [..., 2*d] - - Returns: - Activated output tensor [..., d] - """ - d = x.shape[-1] // 2 - output = torch.empty(x.shape[:-1] + (d,), dtype=x.dtype, device=x.device) - - # Custom op call - executes outside torch.compile graph - torch.ops.vllm.spyre_siluandmul(x, output, self._layer_name) - - return output + """Spyre-optimized SiLU and multiply activation (SwiGLU). - @staticmethod - def forward_spyre(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: - """Spyre-optimized silu+multiply kernel compiled via torch.compile. + Computes silu(x[..., :d]) * x[..., d:] where d = x.shape[-1] // 2. + The input tensor is split into two halves, with .contiguous() ensuring + zero storage offset (Spyre's Flex backend rejects non-zero offsets). - Computes silu(x1) * x2 on the Spyre device, relying on torch-spyre's - registered aten::silu.out kernel. The two halves are passed in as - separate tensors because the Spyre device does not yet support tensor - slicing (strided views); the split is therefore performed on CPU before - this method is called (see _forward_spyre_impl). + The convert() utility handles device/dtype transfers efficiently. Args: - x1: First half of the gated input, shape [..., d], on Spyre device - (float16). silu is applied to this half. - x2: Second half of the gated input, shape [..., d], on Spyre device - (float16). Acts as the multiplicative gate. + x: Input tensor of shape [..., 2*d] containing concatenated gate halves. Returns: - Output tensor of shape [..., d] on the Spyre device (float16). - """ - return F.silu(x1) * x2 - - def _forward_spyre_impl(self, x: torch.Tensor) -> torch.Tensor: - """Spyre device execution: CPU slicing workaround, device transfer, kernel call. - - The Spyre device does not currently support strided tensor views (slicing), - so the input is split into its two halves on the CPU before being - transferred to the device. Once tensor slicing is supported this method - should revert to the simpler single-tensor path (see commented-out block). - - Execution steps: - 1. Slice on CPU: split x into x1 = x[..., :d] and x2 = x[..., d:] - 2. Device transfer: convert x1 and x2 independently to Spyre (float16) - via convert_for_spyre - 3. Kernel execution: call compiled maybe_compiled_forward_spyre(x1_spyre, x2_spyre) - 4. Result transfer: Spyre -> original device, restore original dtype - - Args: - x: Input tensor of shape [..., 2*d] on CPU with arbitrary float dtype. - - Returns: - Activated output tensor of shape [..., d] on the original device with - the original dtype. + Activated output tensor of shape [..., d] on the original device + with the original dtype. """ + logger.debug_once( + "SpyreSiluAndMul: enabled=%s", + self.enabled(), + ) x_dtype = x.dtype x_device = x.device - - # Note: Workaround with tensor slicing on CPU d = x.shape[-1] // 2 - x1 = x[..., :d] - x2 = x[..., d:] - out = self.maybe_compiled_forward_spyre( - convert(x1, self._target_device, self._target_dtype), - convert(x2, self._target_device, self._target_dtype), - ) - - # Transfer back to original device and restore original dtype - return convert(out, x_device, x_dtype) + # Call .contiguous() to ensure zero storage offset (Spyre's Flex backend + # rejects non-zero offsets). convert() then transfers to Spyre device/dtype. + x1 = convert(x[..., :d].contiguous(), self._target_device, self._target_dtype) + x2 = convert(x[..., d:].contiguous(), self._target_device, self._target_dtype) + return convert(F.silu(x1) * x2, x_device, x_dtype) -def _op_func( - x: torch.Tensor, - output: torch.Tensor, - layer_name: str, -) -> None: - """Custom op implementation — runs outside torch.compile graph.""" - layer = get_layer(layer_name) - result = layer._forward_spyre_impl(x) - output.copy_(result) - - -@lru_cache(maxsize=1) def register(): - """Register the spyre_siluandmul custom op with vLLM.""" - direct_register_custom_op( - op_name="spyre_siluandmul", - op_func=_op_func, - mutates_args=["output"], - fake_impl=_fake_impl, - ) - logger.info("Registered custom op: SpyreSiluAndMul") + """No-op: the custom-op barrier has been removed. + + Retained so register_all() in custom_ops/__init__.py needs no changes. + """ + logger.debug("SpyreSiluAndMul: no custom op to register (barrier removed)")