[Spyre-Next] Simplify SiluAndMul by removing custom op barrier #906
@@ -8,19 +8,17 @@
 Architecture:
 - OOT Registration: @SiluAndMul.register_oot() replaces upstream at instantiation
-- forward_oot(): Entry point for OOT dispatch, calls custom op for
-  torch.compile opacity
-- Custom Op Boundary: torch.ops.vllm.spyre_siluandmul is opaque to torch.compile,
-  so _forward_spyre_impl runs eagerly outside the compiled graph
-- Separate Compilation: forward_spyre is compiled independently via maybe_compile
+- forward_oot(): Entry point for OOT dispatch, fully transparent to the outer
+  torch.compile graph (no opaque custom-op boundary).
+- Halves use .contiguous() to ensure zero storage offset before device transfer.
+- convert() utility handles device/dtype transfers efficiently.

 Spyre Device Constraints:
 - Computations performed in torch.float16:
   Input (dtype defined by model / user) converted to torch.float16 for
   operations on Spyre and then converted back to the original dtype for CPU.
+- Splitting (aten.slice.Tensor) inside compiled Spyre graphs is unsupported —
+  non-zero storage offsets are rejected by the Flex backend.
+- The .contiguous() call ensures zero storage offset before transfer to Spyre.

 Output Shape Note:
 Unlike RMSNorm (same input/output shape), SiluAndMul halves the last dimension:
 input shape: [..., 2*d] -> output shape: [..., d]

 References:
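Both constraints in the docstring above, the halved last dimension and the zero-storage-offset requirement, can be observed in plain PyTorch on any device. A minimal illustration (not part of the diff; runs on CPU):

```python
import torch
import torch.nn.functional as F

x = torch.randn(2, 8)            # input of shape [..., 2*d] with d = 4
d = x.shape[-1] // 2
x1, x2 = x[..., :d], x[..., d:]  # slices are views into x's storage
print(x2.storage_offset())       # 4: a non-zero offset, the kind Flex rejects
x1c, x2c = x1.contiguous(), x2.contiguous()
print(x2c.storage_offset())      # 0: fresh storage, safe to transfer
out = F.silu(x1c) * x2c          # SiluAndMul / SwiGLU on the two halves
print(out.shape)                 # torch.Size([2, 4]): last dim halved
```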
@@ -31,11 +29,9 @@
 import torch.nn.functional as F

 from vllm.logger import init_logger
-from vllm.utils.torch_utils import direct_register_custom_op
 from vllm.model_executor.layers.activation import SiluAndMul
-from functools import lru_cache

-from .utils import convert, register_layer, get_layer, _fake_impl
+from .utils import convert

 logger = init_logger(__name__)
@@ -44,134 +40,61 @@
 class SpyreSiluAndMul(SiluAndMul):
     """Out-of-tree (OOT) SiluAndMul implementation for IBM's Spyre device.

     This replaces the upstream vLLM SiluAndMul (vllm/model_executor/layers/activation.py)
     when instantiated, providing Spyre-specific optimizations and device handling.

     Computes: x -> silu(x[..., :d]) * x[..., d:] where d = x.shape[-1] // 2

+    Fully transparent to the outer torch.compile graph — no opaque custom-op
+    boundary. Uses .contiguous() to ensure zero storage offset before device
+    transfer, as Spyre's Flex backend rejects non-zero offsets
+    (aten.slice.Tensor unsupported in compiled Spyre graphs).
     """

-    _dynamic_arg_dims = {"x1": [], "x2": []}
+    _dynamic_arg_dims = {"x": []}

     def __init__(self, *args, **kwargs):
         """Initialize SpyreSiluAndMul layer.

-        Compiles the Spyre kernel and registers this instance in static_forward_context.
-        Sets up the target device (Spyre) and dtype (float16) for computation.
+        The simplified implementation computes directly in forward_oot() without
+        requiring layer registry or custom op registration.
         """
         super().__init__(*args, **kwargs)

         logger.debug("Building custom SiluAndMul")

         self._target_device = torch.device("spyre")
         self._target_dtype = torch.float16
-        self.maybe_compiled_forward_spyre = self.maybe_compile(self.forward_spyre)
-
-        self._layer_name = register_layer(self, "spyre_siluandmul")
-
-        logger.debug_once(
-            "SpyreSiluAndMul: Dispatch: enabled=%s, Forward method=%s, Compiled=%s",
-            self.enabled(),
-            self._forward_method.__name__,
-            self.maybe_compiled_forward_spyre is not self.forward_spyre,
-        )
Comment on lines -70 to -75

Collaborator: Would a logging like this not make sense anymore? Like, the CustomOp can still be …

Collaborator (Author): Good point! The old logging referenced attributes that no longer exist (_forward_method, maybe_compiled_forward_spyre). In the simplified implementation we could add simpler logging, something like this? It won't show compilation status, since that's now handled by the outer graph, but it would still indicate if the CustomOp is active.
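For reference, the simpler logging being proposed would look something like the sketch below; it matches the debug_once call that the new forward_oot() adds in the diff that follows (logger is the module-level vLLM logger):

```python
# Reports only whether the OOT CustomOp dispatch is active; compilation
# status is now owned by the outer torch.compile graph, so it is not logged.
logger.debug_once(
    "SpyreSiluAndMul: enabled=%s",
    self.enabled(),
)
```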
     def forward_oot(self, x: torch.Tensor) -> torch.Tensor:
-        """OOT forward pass using custom op to bypass torch.compile.
-
-        Delegates to torch.ops.vllm.spyre_siluandmul which retrieves this layer
-        from the layer registry and calls _forward_spyre_impl outside
-        the compilation graph.
-
-        Args:
-            x: Input tensor [..., 2*d]
-
-        Returns:
-            Activated output tensor [..., d]
-        """
-        d = x.shape[-1] // 2
-        output = torch.empty(x.shape[:-1] + (d,), dtype=x.dtype, device=x.device)
-
-        # Custom op call - executes outside torch.compile graph
-        torch.ops.vllm.spyre_siluandmul(x, output, self._layer_name)
-
-        return output
-
-    @staticmethod
-    def forward_spyre(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
-        """Spyre-optimized silu+multiply kernel compiled via torch.compile.
-
-        Computes silu(x1) * x2 on the Spyre device, relying on torch-spyre's
-        registered aten::silu.out kernel. The two halves are passed in as
-        separate tensors because the Spyre device does not yet support tensor
-        slicing (strided views); the split is therefore performed on CPU before
-        this method is called (see _forward_spyre_impl).
-
-        Args:
-            x1: First half of the gated input, shape [..., d], on Spyre device
-                (float16). silu is applied to this half.
-            x2: Second half of the gated input, shape [..., d], on Spyre device
-                (float16). Acts as the multiplicative gate.
-
-        Returns:
-            Output tensor of shape [..., d] on the Spyre device (float16).
-        """
-        return F.silu(x1) * x2
-
-    def _forward_spyre_impl(self, x: torch.Tensor) -> torch.Tensor:
-        """Spyre device execution: CPU slicing workaround, device transfer, kernel call.
-
-        The Spyre device does not currently support strided tensor views (slicing),
-        so the input is split into its two halves on the CPU before being
-        transferred to the device. Once tensor slicing is supported this method
-        should revert to the simpler single-tensor path (see commented-out block).
-
-        Execution steps:
-        1. Slice on CPU: split x into x1 = x[..., :d] and x2 = x[..., d:]
-        2. Device transfer: convert x1 and x2 independently to Spyre (float16)
-           via convert_for_spyre
-        3. Kernel execution: call compiled maybe_compiled_forward_spyre(x1_spyre, x2_spyre)
-        4. Result transfer: Spyre -> original device, restore original dtype
-
-        Args:
-            x: Input tensor of shape [..., 2*d] on CPU with arbitrary float dtype.
-
-        Returns:
-            Activated output tensor of shape [..., d] on the original device with
-            the original dtype.
+        """Spyre-optimized SiLU and multiply activation (SwiGLU).
+
+        Computes silu(x[..., :d]) * x[..., d:] where d = x.shape[-1] // 2.
+        The input tensor is split into two halves, with .contiguous() ensuring
+        zero storage offset (Spyre's Flex backend rejects non-zero offsets).
+
+        The convert() utility handles device/dtype transfers efficiently.
+
+        Args:
+            x: Input tensor of shape [..., 2*d] containing concatenated gate halves.
+
+        Returns:
+            Activated output tensor of shape [..., d] on the original device
+            with the original dtype.
         """
+        logger.debug_once(
+            "SpyreSiluAndMul: enabled=%s",
+            self.enabled(),
+        )
         x_dtype = x.dtype
         x_device = x.device

-        # Note: Workaround with tensor slicing on CPU
         d = x.shape[-1] // 2
-        x1 = x[..., :d]
-        x2 = x[..., d:]
-        out = self.maybe_compiled_forward_spyre(
-            convert(x1, self._target_device, self._target_dtype),
-            convert(x2, self._target_device, self._target_dtype),
-        )
-
-        # Transfer back to original device and restore original dtype
-        return convert(out, x_device, x_dtype)
+        # Call .contiguous() to ensure zero storage offset (Spyre's Flex backend
+        # rejects non-zero offsets). convert() then transfers to Spyre device/dtype.
+        x1 = convert(x[..., :d].contiguous(), self._target_device, self._target_dtype)
Collaborator: This looks like an elegant solution. But out of curiosity, do we know how this slicing is handled by torch-spyre?

Collaborator (Author): The slicing happens on the CPU before transfer to Spyre. The split cannot happen on Spyre inside a compiled graph — Spyre's Flex backend rejects non-zero storage offsets (aten.slice.Tensor unsupported) (torch-spyre/torch-spyre#1333).

Collaborator: OK, but if the preceding op is also executed on Spyre, how should this work? (but maybe out of scope for this PR)

Collaborator (Author): Spyre tensor → CPU (convert detects device mismatch) → slice on CPU → contiguous → back to Spyre. It isn't optimal, but it's the current workaround for Spyre's aten.slice.Tensor limitation.
+        x2 = convert(x[..., d:].contiguous(), self._target_device, self._target_dtype)
+        return convert(F.silu(x1) * x2, x_device, x_dtype)
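convert() is imported from the plugin's .utils module and its body is not part of this diff. Based on the behavior described in the thread above (detects device mismatch, no-op when nothing changes), a hypothetical minimal version might look like:

```python
import torch

# Hypothetical sketch only: the real convert() lives in .utils and may handle
# additional cases (non-blocking copies, unsupported dtypes, etc.).
def convert(t: torch.Tensor, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
    # Tensor.to() returns the input unchanged when device and dtype already
    # match, so redundant calls are cheap.
    return t.to(device=device, dtype=dtype)
```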

-def _op_func(
-    x: torch.Tensor,
-    output: torch.Tensor,
-    layer_name: str,
-) -> None:
-    """Custom op implementation — runs outside torch.compile graph."""
-    layer = get_layer(layer_name)
-    result = layer._forward_spyre_impl(x)
-    output.copy_(result)
-
-
-@lru_cache(maxsize=1)
 def register():
-    """Register the spyre_siluandmul custom op with vLLM."""
-    direct_register_custom_op(
-        op_name="spyre_siluandmul",
-        op_func=_op_func,
-        mutates_args=["output"],
-        fake_impl=_fake_impl,
-    )
-    logger.info("Registered custom op: SpyreSiluAndMul")
+    """No-op: the custom-op barrier has been removed.
+
+    Retained so register_all() in custom_ops/__init__.py needs no changes.
+    """
+    logger.debug("SpyreSiluAndMul: no custom op to register (barrier removed)")
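For context on what the deleted barrier did: an op registered with PyTorch's custom-op machinery is kept as a single opaque node by torch.compile, and its Python body runs eagerly outside the compiled graph. A standalone sketch using the public torch.library API rather than vLLM's direct_register_custom_op wrapper:

```python
import torch
import torch.nn.functional as F

# Registering a custom op makes torch.compile treat calls to it as one
# opaque node; the body below always executes eagerly.
@torch.library.custom_op("demo::silu_and_mul", mutates_args=())
def silu_and_mul(x: torch.Tensor) -> torch.Tensor:
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]

# A fake (meta) implementation lets the compiler infer output shapes
# without running the eager body.
@silu_and_mul.register_fake
def _(x: torch.Tensor) -> torch.Tensor:
    return x.new_empty(*x.shape[:-1], x.shape[-1] // 2)
```

Removing the registration means torch.compile now traces straight through the slicing and F.silu calls, which is the simplification this PR makes.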
Collaborator: So we don't need this anymore and this will be triggered by the upstream implementation through enforce_eager, right? Can you check the difference when using enforce_eager=True and without?

Collaborator (Author): Sure, I will run a test comparing the two modes.

Collaborator (Author): I tested both modes to verify the behavior:

With enforce_eager=True (eager execution): …

With enforce_eager=False (compilation enabled): …
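The two modes being compared map to vLLM's enforce_eager engine argument. A sketch of such a comparison (the model name is a placeholder):

```python
from vllm import LLM

# enforce_eager=True skips torch.compile, so forward_oot runs as plain Python.
llm_eager = LLM(model="<model-id>", enforce_eager=True)

# enforce_eager=False allows forward_oot to be traced into the outer
# compiled graph, now that the custom-op barrier is gone.
llm_compiled = LLM(model="<model-id>", enforce_eager=False)
```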