[Spyre-Next] Simplify SiluAndMul by removing custom op barrier #906

Open
GOavi101 wants to merge 1 commit into torch-spyre:main from GOavi101:remove-custom-op-barrier-siluandmul

Conversation

@GOavi101
Collaborator

@GOavi101 GOavi101 commented Apr 9, 2026

Description

This PR simplifies the SpyreSiluAndMul implementation by removing the custom op barrier pattern.

Changes

  • Remove opaque custom-op boundary in SpyreSiluAndMul
  • Simplify forward_oot() to compute directly without layer registry
  • Add .contiguous() calls to ensure zero storage offset for Spyre

Technical Details

Background:
The split cannot happen on Spyre inside a compiled graph — Spyre's Flex backend rejects non-zero storage offsets (aten.slice.Tensor is unsupported), confirmed with a standalone repro script (torch-spyre/torch-spyre#1333).
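A minimal standalone illustration (not from the PR) of the offset behavior: slicing the last dimension produces a view with a non-zero storage offset, and .contiguous() resets it to zero.

```python
import torch

# Slicing the tail of the last dim yields a view whose data begins
# partway into the underlying storage -- a non-zero storage offset,
# which Spyre's Flex backend rejects inside a compiled graph.
x = torch.arange(16, dtype=torch.float32).reshape(2, 8)
tail = x[..., 4:]
print(tail.storage_offset())  # 4

# .contiguous() copies into fresh storage that starts at offset 0.
tail_c = tail.contiguous()
print(tail_c.storage_offset())  # 0
```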

Implementation:

  • Split input tensor with .contiguous() to ensure zero storage offset
  • Compute silu(x1) * x2 on Spyre device
  • Convert result back to original device/dtype
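The steps above can be sketched roughly as follows. This is a hypothetical reconstruction, not the PR's actual code: convert() here is a simplified stand-in for the torch-spyre helper, CPU stands in for the Spyre device, and the signatures are assumptions.

```python
import torch
import torch.nn.functional as F

def convert(t, device, dtype):
    # Simplified stand-in for the torch-spyre convert() helper; a no-op
    # when the tensor already matches the target device/dtype.
    return t.to(device=device, dtype=dtype)

def forward_oot(x, target_device="cpu", target_dtype=torch.float32):
    orig_device, orig_dtype = x.device, x.dtype
    d = x.shape[-1] // 2
    # .contiguous() guarantees a zero storage offset before the halves
    # reach the Spyre device (Flex rejects non-zero offsets).
    x1 = convert(x[..., :d].contiguous(), target_device, target_dtype)
    x2 = convert(x[..., d:].contiguous(), target_device, target_dtype)
    out = F.silu(x1) * x2                          # computed on the Spyre device
    return convert(out, orig_device, orig_dtype)   # back to caller's device/dtype
```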

Related Issues

Related to #887

Test Plan

Tested on ibm-granite/granite-3.3-8b-instruct:

  • Warmup completes successfully
  • Inference produces correct output

Checklist

  • I have read the contributing guidelines
  • My code follows the project's code style (run bash format.sh)
  • I have added tests for my changes (if applicable)
  • I have updated the documentation (if applicable)
  • My commits include a Signed-off-by: line (DCO compliance)

@github-actions

github-actions bot commented Apr 9, 2026

👋 Hi! Thank you for contributing to vLLM support on Spyre.
Just a reminder: Make sure that your code passes all the linting checks, otherwise your PR won't be able to be merged. To do so, run ./format.sh.
Now you are good to go 🚀.

We also recommend installing prek and configuring it to check your code before every local commit.

@github-actions github-actions bot changed the title Simplify SiluAndMul by removing custom op barrier [Spyre-Next] Simplify SiluAndMul by removing custom op barrier Apr 9, 2026
@GOavi101 GOavi101 force-pushed the remove-custom-op-barrier-siluandmul branch 2 times, most recently from c7a190f to bc06b5e Compare April 9, 2026 11:30
Collaborator

@bohnstingl bohnstingl left a comment

Looks good to me in principle.

Comment on lines -66 to -67
self.maybe_compiled_forward_spyre = self.maybe_compile(self.forward_spyre)

Collaborator

So we don't need this anymore and this will be triggered by the upstream implementation through enforce_eager, right? Can you check the difference when using enforce_eager=True and without?

Collaborator Author

Sure, I'll run a test comparing the two modes.

Collaborator Author

@GOavi101 GOavi101 Apr 13, 2026

I tested both modes to verify the behavior:

With enforce_eager=True (eager execution):

  • Compilation mode: NONE
  • Custom ops: 'all'
  • Direct eager execution through forward_oot()

With enforce_eager=False (compilation enabled):

  • Compilation mode: DYNAMO_TRACE_ONCE
  • Custom ops: 'none'
  • torch.compile wraps the entire model including SpyreSiluAndMul
  • Longer warmup (130s vs 68s) due to compilation overhead
  • ~2x faster inference (92s vs 180s)

Comment on lines -70 to -75
logger.debug_once(
"SpyreSiluAndMul: Dispatch: enabled=%s, Forward method=%s, Compiled=%s",
self.enabled(),
self._forward_method.__name__,
self.maybe_compiled_forward_spyre is not self.forward_spyre,
)
Collaborator

Would a logging like this not make sense anymore? Like, the CustomOp can still be enabled or not, right?

Collaborator Author

@GOavi101 GOavi101 Apr 10, 2026

Good point! The old logging referenced attributes that no longer exist (_forward_method, maybe_compiled_forward_spyre). In the simplified implementation, we could add simpler logging like:

logger.debug_once(
    "SpyreSiluAndMul: enabled=%s",
    self.enabled(),
)

Something like this? It won't show compilation status since that's now handled by the outer graph, but it would still indicate whether the CustomOp is active.

@GOavi101 GOavi101 force-pushed the remove-custom-op-barrier-siluandmul branch from bc06b5e to 2734cec Compare April 13, 2026 11:53
Collaborator

@bringlein bringlein left a comment

Thanks @GOavi101, looks good.

I have two questions. But I think if there is no dependency, we should be good to merge.

this method is called (see _forward_spyre_impl).
The convert() utility handles device/dtype transfers efficiently,
skipping data movement when the tensor is already on the correct
device/dtype (vllm-spyre#863 optimization).
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so we have a dependency on #863? Or can we merge it independently?

Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The convert() utility function already exists and works correctly. PR #863 added an optimization to convert() that skips redundant data movement when a tensor is already on the correct device/dtype, but it's not a hard dependency.
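For context, the skip-if-already-correct behavior described here could look roughly like this. This is a sketch of the idea, not the actual vllm-spyre implementation:

```python
import torch

def convert(t: torch.Tensor, device, dtype) -> torch.Tensor:
    # The #863-style optimization: return the tensor untouched when it
    # already matches the target device/dtype, avoiding any data movement.
    if t.device == torch.device(device) and t.dtype == dtype:
        return t
    return t.to(device=device, dtype=dtype)
```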

Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, makes sense. Can we then just remove the comment that links it to PR 863?

# Call .contiguous() to ensure zero storage offset (Spyre's Flex backend
# rejects non-zero offsets). convert() then transfers to Spyre device/dtype,
# skipping transfer if already correct (vllm-spyre#863 optimization).
x1 = convert(x[..., :d].contiguous(), self._target_device, self._target_dtype)
Collaborator

This looks like an elegant solution. But out of curiosity, do we know how this slicing is handled by torch-spyre?

Collaborator Author

@GOavi101 GOavi101 Apr 13, 2026

The slicing happens on the CPU before transfer to Spyre.

The split cannot happen on Spyre inside a compiled graph — Spyre's Flex backend rejects non-zero storage offsets (aten.slice.Tensor unsupported) (torch-spyre/torch-spyre#1333).

Collaborator

ok, but if the preceding op is also executed on spyre, how should this work? (but maybe out-of-scope for this PR)

Collaborator Author

Spyre tensor → CPU (convert detects device mismatch) → slice on CPU → contiguous → back to Spyre

It isn't optimal, but it's the current workaround for Spyre's aten.slice.Tensor limitation.

- Remove opaque custom-op boundary in SpyreSiluAndMul
- Simplify forward_oot() to compute directly without layer registry
- Add .contiguous() calls to ensure zero storage offset for Spyre

Signed-off-by: Avishek Goswami <avishek.goswami@ibm.com>
@GOavi101 GOavi101 force-pushed the remove-custom-op-barrier-siluandmul branch from 2734cec to a4c4a28 Compare April 13, 2026 15:34
Collaborator

@bringlein bringlein left a comment

LGTM
