155 changes: 39 additions & 116 deletions vllm_spyre_next/vllm_spyre_next/custom_ops/silu_and_mul.py
@@ -8,19 +8,17 @@

Architecture:
- OOT Registration: @SiluAndMul.register_oot() replaces upstream at instantiation
- forward_oot(): Entry point for OOT dispatch, calls custom op for
torch.compile opacity
- Custom Op Boundary: torch.ops.vllm.spyre_siluandmul is opaque to torch.compile,
so _forward_spyre_impl runs eagerly outside the compiled graph
- Separate Compilation: forward_spyre is compiled independently via maybe_compile
- forward_oot(): Entry point for OOT dispatch, fully transparent to the outer
torch.compile graph (no opaque custom-op boundary).
- Halves use .contiguous() to ensure zero storage offset before device transfer.
- convert() utility handles device/dtype transfers efficiently.

Spyre Device Constraints:
- Computations are performed in torch.float16:
The input (dtype defined by the model/user) is converted to torch.float16 for
operations on Spyre, then converted back to the original dtype on the CPU.
- Splitting (aten.slice.Tensor) inside compiled Spyre graphs is unsupported —
non-zero storage offsets are rejected by the Flex backend.
- The .contiguous() call ensures zero storage offset before transfer to Spyre.
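The offset constraint above can be verified in plain PyTorch, with no Spyre hardware involved: slicing off the second half of the last dimension produces a view with a non-zero storage offset, and `.contiguous()` materializes a fresh zero-offset tensor.

```python
import torch

x = torch.randn(4, 8)
x2 = x[..., 4:]                  # strided view into x's storage

# The view starts 4 elements into the shared storage.
assert x2.storage_offset() == 4

# .contiguous() copies into fresh storage with zero offset,
# which is what the Flex backend requires before transfer.
assert x2.contiguous().storage_offset() == 0
```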

Output Shape Note:
Unlike RMSNorm (same input/output shape), SiluAndMul halves the last dimension:
input shape: [..., 2*d] -> output shape: [..., d]
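A minimal PyTorch sketch of that shape contract, using the reference SwiGLU formula from this module's docstring:

```python
import torch
import torch.nn.functional as F

x = torch.randn(2, 8)            # input shape [..., 2*d] with d = 4
d = x.shape[-1] // 2

# silu gates the first half; the second half multiplies it.
out = F.silu(x[..., :d]) * x[..., d:]

assert out.shape == (2, 4)       # last dimension halved: [..., d]
```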

References:
@@ -31,11 +29,9 @@
import torch.nn.functional as F

from vllm.logger import init_logger
from vllm.utils.torch_utils import direct_register_custom_op
from vllm.model_executor.layers.activation import SiluAndMul
from functools import lru_cache

from .utils import convert, register_layer, get_layer, _fake_impl
from .utils import convert

logger = init_logger(__name__)

@@ -44,134 +40,61 @@
class SpyreSiluAndMul(SiluAndMul):
"""Out-of-tree (OOT) SiluAndMul implementation for IBM's Spyre device.

This replaces the upstream vLLM SiluAndMul (vllm/model_executor/layers/activation.py)
when instantiated, providing Spyre-specific optimizations and device handling.

Computes: x -> silu(x[..., :d]) * x[..., d:] where d = x.shape[-1] // 2

Fully transparent to the outer torch.compile graph — no opaque custom-op
boundary. Uses .contiguous() to ensure zero storage offset before device
transfer, as Spyre's Flex backend rejects non-zero offsets
(aten.slice.Tensor unsupported in compiled Spyre graphs).
"""

_dynamic_arg_dims = {"x1": [], "x2": []}
_dynamic_arg_dims = {"x": []}

def __init__(self, *args, **kwargs):
"""Initialize SpyreSiluAndMul layer.

Compiles the Spyre kernel and registers this instance in static_forward_context.
Sets up the target device (Spyre) and dtype (float16) for computation.
The simplified implementation computes directly in forward_oot() without
requiring layer registry or custom op registration.
"""
super().__init__(*args, **kwargs)

logger.debug("Building custom SiluAndMul")

self._target_device = torch.device("spyre")
self._target_dtype = torch.float16
self.maybe_compiled_forward_spyre = self.maybe_compile(self.forward_spyre)

Comment on lines -66 to -67
Collaborator:
So we don't need this anymore and this will be triggered by the upstream implementation through enforce_eager, right? Can you check the difference when using enforce_eager=True and without?

Collaborator Author:
sure I will run a test comparing the two modes

Collaborator Author (@GOavi101, Apr 13, 2026):
I tested both modes to verify the behavior:

With enforce_eager=True (eager execution):

  • Compilation mode: NONE
  • Custom ops: 'all'
  • Direct eager execution through forward_oot()

With enforce_eager=False (compilation enabled):

  • Compilation mode: DYNAMO_TRACE_ONCE
  • Custom ops: 'none'
  • torch.compile wraps the entire model including SpyreSiluAndMul
  • Longer warmup (130s vs 68s) due to compilation overhead
  • ~2x faster inference (92s vs 180s)

self._layer_name = register_layer(self, "spyre_siluandmul")

logger.debug_once(
"SpyreSiluAndMul: Dispatch: enabled=%s, Forward method=%s, Compiled=%s",
self.enabled(),
self._forward_method.__name__,
self.maybe_compiled_forward_spyre is not self.forward_spyre,
)
Comment on lines -70 to -75
Collaborator:
Would a logging like this not make sense anymore? Like, the CustomOp can still be enabled or not, right?

Collaborator Author (@GOavi101, Apr 10, 2026):
Good point! The old logging referenced attributes that no longer exist (`_forward_method`, `maybe_compiled_forward_spyre`). In the simplified implementation, we could add simpler logging like:

logger.debug_once(
    "SpyreSiluAndMul: enabled=%s",
    self.enabled(),
)

something like this? It won't show compilation status since that's now handled by the outer graph, but it would still indicate if the CustomOp is active.


def forward_oot(self, x: torch.Tensor) -> torch.Tensor:
"""OOT forward pass using custom op to bypass torch.compile.

Delegates to torch.ops.vllm.spyre_siluandmul which retrieves this layer
from the layer registry and calls _forward_spyre_impl outside
the compilation graph.

Args:
x: Input tensor [..., 2*d]

Returns:
Activated output tensor [..., d]
"""
d = x.shape[-1] // 2
output = torch.empty(x.shape[:-1] + (d,), dtype=x.dtype, device=x.device)

# Custom op call - executes outside torch.compile graph
torch.ops.vllm.spyre_siluandmul(x, output, self._layer_name)

return output
"""Spyre-optimized SiLU and multiply activation (SwiGLU).

@staticmethod
def forward_spyre(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
"""Spyre-optimized silu+multiply kernel compiled via torch.compile.
Computes silu(x[..., :d]) * x[..., d:] where d = x.shape[-1] // 2.
The input tensor is split into two halves, with .contiguous() ensuring
zero storage offset (Spyre's Flex backend rejects non-zero offsets).

Computes silu(x1) * x2 on the Spyre device, relying on torch-spyre's
registered aten::silu.out kernel. The two halves are passed in as
separate tensors because the Spyre device does not yet support tensor
slicing (strided views); the split is therefore performed on CPU before
this method is called (see _forward_spyre_impl).
The convert() utility handles device/dtype transfers efficiently.

Args:
x1: First half of the gated input, shape [..., d], on Spyre device
(float16). silu is applied to this half.
x2: Second half of the gated input, shape [..., d], on Spyre device
(float16). Acts as the multiplicative gate.
x: Input tensor of shape [..., 2*d] containing concatenated gate halves.

Returns:
Output tensor of shape [..., d] on the Spyre device (float16).
"""
return F.silu(x1) * x2

def _forward_spyre_impl(self, x: torch.Tensor) -> torch.Tensor:
"""Spyre device execution: CPU slicing workaround, device transfer, kernel call.

The Spyre device does not currently support strided tensor views (slicing),
so the input is split into its two halves on the CPU before being
transferred to the device. Once tensor slicing is supported this method
should revert to the simpler single-tensor path (see commented-out block).

Execution steps:
1. Slice on CPU: split x into x1 = x[..., :d] and x2 = x[..., d:]
2. Device transfer: convert x1 and x2 independently to Spyre (float16)
via convert_for_spyre
3. Kernel execution: call compiled maybe_compiled_forward_spyre(x1_spyre, x2_spyre)
4. Result transfer: Spyre -> original device, restore original dtype

Args:
x: Input tensor of shape [..., 2*d] on CPU with arbitrary float dtype.

Returns:
Activated output tensor of shape [..., d] on the original device with
the original dtype.
Activated output tensor of shape [..., d] on the original device
with the original dtype.
"""
logger.debug_once(
"SpyreSiluAndMul: enabled=%s",
self.enabled(),
)
x_dtype = x.dtype
x_device = x.device

# Note: Workaround with tensor slicing on CPU
d = x.shape[-1] // 2
x1 = x[..., :d]
x2 = x[..., d:]
out = self.maybe_compiled_forward_spyre(
convert(x1, self._target_device, self._target_dtype),
convert(x2, self._target_device, self._target_dtype),
)

# Transfer back to original device and restore original dtype
return convert(out, x_device, x_dtype)
# Call .contiguous() to ensure zero storage offset (Spyre's Flex backend
# rejects non-zero offsets). convert() then transfers to Spyre device/dtype.
x1 = convert(x[..., :d].contiguous(), self._target_device, self._target_dtype)
Collaborator:
this looks like an elegant solution. But out of curiosity, do we know how this slicing is handled by torch-spyre?

Collaborator Author (@GOavi101, Apr 13, 2026):
The slicing happens on the CPU before transfer to Spyre.

The split cannot happen on Spyre inside a compiled graph — Spyre's Flex backend rejects non-zero storage offsets (aten.slice.Tensor is unsupported; see torch-spyre/torch-spyre#1333)

Collaborator:
ok, but if the preceding op is also executed on spyre, how should this work? (but maybe out-of-scope for this PR)

Collaborator Author:
Spyre tensor → CPU (convert detects device mismatch) → slice on CPU → contiguous → back to Spyre

It isn't optimal, but it's the current workaround for Spyre's aten.slice.Tensor limitation.

x2 = convert(x[..., d:].contiguous(), self._target_device, self._target_dtype)
return convert(F.silu(x1) * x2, x_device, x_dtype)


def _op_func(
x: torch.Tensor,
output: torch.Tensor,
layer_name: str,
) -> None:
"""Custom op implementation — runs outside torch.compile graph."""
layer = get_layer(layer_name)
result = layer._forward_spyre_impl(x)
output.copy_(result)


@lru_cache(maxsize=1)
def register():
"""Register the spyre_siluandmul custom op with vLLM."""
direct_register_custom_op(
op_name="spyre_siluandmul",
op_func=_op_func,
mutates_args=["output"],
fake_impl=_fake_impl,
)
logger.info("Registered custom op: SpyreSiluAndMul")
"""No-op: the custom-op barrier has been removed.

Retained so register_all() in custom_ops/__init__.py needs no changes.
"""
logger.debug("SpyreSiluAndMul: no custom op to register (barrier removed)")