
[Spyre-Next] Reworked forward call chain#872

Merged
tjohnson31415 merged 12 commits into torch-spyre:main from bohnstingl:forward_oot_rework on Mar 30, 2026

Conversation

@bohnstingl
Collaborator

Description

This PR reworks the forward call chain of the layer wrappers for torch-spyre. In particular, it avoids overriding forward_native in favor of implementing forward_oot.
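The dispatch pattern being adopted can be sketched roughly as follows. This is a minimal stand-in, not the actual vLLM CustomOp class; all names besides forward_native/forward_oot are illustrative:

```python
# Minimal sketch of a CustomOp-style dispatch: the base class resolves
# forward() to a backend-specific method once, so an out-of-tree backend
# only needs to override forward_oot() and can leave forward_native()
# untouched as the upstream reference path.

class CustomOpBase:
    """Hypothetical stand-in for vLLM's CustomOp dispatch."""

    def __init__(self, use_oot: bool):
        # Pick the implementation once at construction time.
        self._forward = self.forward_oot if use_oot else self.forward_native

    def forward(self, x):
        return self._forward(x)

    def forward_native(self, x):
        # Reference implementation; the plugin no longer overrides this.
        return [v * 2 for v in x]

    def forward_oot(self, x):
        raise NotImplementedError


class SpyreOp(CustomOpBase):
    # The plugin overrides only forward_oot; upstream tests can still
    # exercise forward_native unchanged.
    def forward_oot(self, x):
        return [v * 2 for v in x]  # would dispatch to the Spyre kernel


op = SpyreOp(use_oot=True)
print(op.forward([1, 2, 3]))  # [2, 4, 6]
```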

cc @romitjain

Related Issues

This PR was triggered by #863

Test Plan

This change is not user-facing, so all existing tests should continue to pass. Moreover, it should also enable the upstream vLLM tests to be supported.

Checklist

  • I have read the contributing guidelines
  • My code follows the project's code style (run bash format.sh)
  • I have added tests for my changes (if applicable)
  • I have updated the documentation (if applicable)
  • My commits include a Signed-off-by: line (DCO compliance)

Signed-off-by: Thomas Ortner <boh@zurich.ibm.com>
@github-actions

👋 Hi! Thank you for contributing to vLLM support on Spyre.
Just a reminder: make sure your code passes all of the linting checks, otherwise your PR cannot be merged. To do so, run ./format.sh.
Now you are good to go 🚀.

We also recommend installing prek and configuring it to check your code before every local commit.

@github-actions github-actions Bot changed the title Reworked forward call chain [Spyre-Next] Reworked forward call chain Mar 27, 2026
Comment thread vllm_spyre_next/tests/test_rms_norm.py Outdated
if x.shape[-1] != hidden_size:
raise ValueError(f"Expected hidden_size to be {hidden_size}, but found: {x.shape[-1]}")

x = x.transpose(-1, -2).contiguous()
Collaborator

@bohnstingl

I tried running the rmsnorm repo tests on this branch and they failed for me. I am not able to pinpoint why that might be. On main, they are passing, and apart from the function name changes, this is the only change I see.

Collaborator Author

@romitjain you need a very recent commit of torch-spyre for this to work. I reverted this rework from this PR and I'll move it into a separate PR.

Collaborator

@bohnstingl
My bad - the tests are passing if I check only one of forward_oot, forward, or forward_native. But if we check all three, I get a test failure (tensors are not close) in the residual path. I think it's because we mutate the residual tensor: https://github.com/vllm-project/vllm-spyre/blob/main/vllm_spyre_next/vllm_spyre_next/custom_ops/rms_norm.py#L234

So this change: 4b6b2e9 might not be needed. I was able to pass tests locally with and without this change. I have made a small PR here from my fork: bohnstingl#1.

PR is for demonstration - feel free to merge/close and make changes directly here.
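The in-place mutation issue described above can be illustrated with a small sketch (plain Python lists standing in for tensors; the function names are hypothetical): once one path updates the residual in place, the next path computes from the already-updated values, so the outputs can never match.

```python
# Why checking several forward paths on the same residual fails when one
# path mutates the residual in place.

def forward_inplace(x, residual):
    # Updates residual in place, analogous to the linked rms_norm code.
    for i in range(len(residual)):
        residual[i] += x[i]
    return list(residual)

def forward_pure(x, residual):
    # Functional variant: leaves residual untouched.
    return [a + b for a, b in zip(x, residual)]

x = [1.0, 2.0]
residual = [0.5, 0.5]

out_a = forward_inplace(x, residual)   # residual is now [1.5, 2.5]
out_b = forward_pure(x, residual)      # computed from the mutated residual
print(out_a == out_b)                  # False

# Cloning the inputs per path keeps the comparison fair:
residual = [0.5, 0.5]
out_a = forward_inplace(x, list(residual))
out_b = forward_pure(x, list(residual))
print(out_a == out_b)                  # True
```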

Collaborator Author

Thanks, I merged it in.

I will still leave the rms_norm rework for a separate PR, in order not to dilute the intention of this PR.

Collaborator Author

rms_norm rework PR: #873

Signed-off-by: Thomas Ortner <boh@zurich.ibm.com>
@bohnstingl bohnstingl requested a review from romitjain March 27, 2026 09:22
Collaborator

@romitjain romitjain left a comment

LGTM

@bohnstingl bohnstingl marked this pull request as ready for review March 27, 2026 10:41
@bohnstingl
Collaborator Author

@joerunde or @tjohnson31415 could you also take a look at it to see if the rework makes sense to you?

@yannicks1
Collaborator

@romitjain fyi DCO fails with: Commit sha: 8fc82cc, Author: romit, Committer: romit; The sign-off is missing.

Collaborator

@tjohnson31415 tjohnson31415 left a comment

A couple of nits, but otherwise I like this change! Much cleaner not to override the upstream forward() and forward_native() functions.

Comment thread vllm_spyre_next/tests/test_rms_norm.py Outdated
Comment thread vllm_spyre_next/tests/test_rms_norm.py Outdated
Collaborator

As this PR mainly does renaming: could we also rename self._fwd, or is that fixed? It was not obvious at first glance that self._fwd is the compiled self.forward_spyre. I don't have the perfect name, but something along the lines of self._fwd_spyre_comp, or just self._fwd_spyre, self._forward_spyre, or self._forward_spyre_comp.

Collaborator Author

I renamed it to maybe_compiled_forward_spyre to indicate that 1) it is maybe compiled and 2) that it is the forward_spyre. I did the same for silu as well.
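The "maybe compiled" naming pattern can be sketched roughly as follows. This is hypothetical stand-in code, not the actual wrapper; compile_fn stands in for torch.compile, and the SiLU math is just a concrete payload:

```python
import math

def compile_fn(fn):
    # Stand-in for a real compiler wrapper such as torch.compile.
    def wrapped(*args, **kwargs):
        return fn(*args, **kwargs)
    return wrapped

class SiluOp:
    """Illustrative op: the attribute name says it holds forward_spyre,
    possibly wrapped by a compiler, depending on configuration."""

    def __init__(self, compile_forward: bool):
        fwd = self.forward_spyre
        self.maybe_compiled_forward_spyre = (
            compile_fn(fwd) if compile_forward else fwd
        )

    def forward_spyre(self, x):
        # SiLU: x * sigmoid(x) == x / (1 + exp(-x))
        return [v / (1.0 + math.exp(-v)) for v in x]
```

Callers always invoke self.maybe_compiled_forward_spyre, so the compiled and eager paths share one call site.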

Collaborator

For consistency: why not call this forward_spyre as in the rms norm op?

Collaborator Author

Renamed.

@@ -64,11 +66,11 @@ def __init__(self, *args, **kwargs):
Collaborator

Same argument as above: consider renaming self._fwd to something more descriptive if possible.

Collaborator Author

Done, see above.

bohnstingl added a commit to bohnstingl/vllm-spyre that referenced this pull request Mar 27, 2026
Signed-off-by: Thomas Ortner <boh@zurich.ibm.com>
Signed-off-by: Thomas Ortner <boh@zurich.ibm.com>
Signed-off-by: Thomas Ortner <boh@zurich.ibm.com>
Comment thread vllm_spyre_next/vllm_spyre_next/custom_ops/rms_norm.py Outdated
Signed-off-by: Thomas Ortner <boh@zurich.ibm.com>
Collaborator

@tjohnson31415 tjohnson31415 left a comment

LGTM

@tjohnson31415
Collaborator

bot:next-test

@bohnstingl
Collaborator Author

@dilipgb, @R3hankhan123, @coderfornow and @GOavi101 please feel free to also comment on this rework PR, as you are also working on layer wrappings. I am happy to incorporate any feedback you may have.

Collaborator

@yannicks1 yannicks1 left a comment

lgtm. please keep in mind to merge in the changes from #873.

Comment thread vllm_spyre_next/tests/test_rms_norm.py Outdated
- forward(): custom op dispatch (no-compile path via torch.ops.vllm.spyre_rmsnorm)
- forward_native(): direct Spyre device execution
- forward_oot(): OOT dispatch via custom op (torch.ops.vllm.spyre_rmsnorm)
- reference_rms_norm: golden reference, similar to vLLM upstream pure PyTorch (ground truth)
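For reference, the golden path listed above can be sketched in pure Python. This is an assumption about what such a reference computes (y_i = x_i * w_i / sqrt(mean(x^2) + eps)), not the test's actual code:

```python
import math

def reference_rms_norm(x, weight, eps=1e-6):
    # Pure-Python RMSNorm over a single row: normalize by the root of the
    # mean square of x, then scale elementwise by the learned weight.
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    return [v * inv_rms * w for v, w in zip(x, weight)]

print(reference_rms_norm([1.0, 2.0, 3.0], [1.0, 1.0, 1.0]))
```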
Collaborator

nit: reference_rms_norm -> reference_rms_norm()

Collaborator Author

Changed

@bohnstingl
Collaborator Author

bot:next-test

Signed-off-by: Thomas Ortner <boh@zurich.ibm.com>
Signed-off-by: Thomas Ortner <boh@zurich.ibm.com>
@tjohnson31415
Collaborator

bot test is broken for an unrelated reason, but tests are passing in dev pods. Merging!

@tjohnson31415 tjohnson31415 enabled auto-merge (squash) March 30, 2026 19:53
@github-actions github-actions Bot added the ready Runs the full CI test suite. Only add to PRs once ready to merge to limit public GHA usage label Mar 30, 2026
@tjohnson31415 tjohnson31415 merged commit 74fa89e into torch-spyre:main Mar 30, 2026
22 checks passed