
[Spyre-Next] [Feature] rms_norm rework with vLLM IR#877

Open
bohnstingl wants to merge 22 commits intotorch-spyre:mainfrom
bohnstingl:vllm_IR_integration

Conversation

@bohnstingl
Collaborator

Description

This PR reworks the custom wrapping of torch-spyre ops to use the upstream vLLM IR system.

Key changes

  • New Spyre IR provider (custom_ops/kernels/rms_norm.py): A standalone function spyre_rms_norm that implements RMSNorm for spyre, registered with the vLLM IR system via ir.ops.rms_norm.register_impl("spyre", ...), with spyre-specific workarounds.

  • Simplified OOT wrapper (custom_ops/rms_norm.py): SpyreRMSNorm._forward_spyre_impl() now calls ir.ops.rms_norm(...) directly and is significantly simplified. The wrapper is reduced to the device sandwich (CPU → Spyre → CPU), batch padding, and residual handling. Note: this wrapper is only needed temporarily and can be removed once there is a spyre-specific model runner and the upstream rework to a PluggableLayer lands.

  • Platform Changes (platform.py): New get_default_ir_op_priority() classmethod sets rms_norm=["spyre", "native"], telling the IR system to prefer our spyre provider and only fall back to the pure-PyTorch native implementation when needed. check_and_update_config() forces custom_ops=["all"] and ir_enable_torch_wrap=False.
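The priority-based selection these changes plug into can be illustrated with a toy registry. This is only a sketch of the mechanism: the real API (ir.ops.rms_norm.register_impl, get_default_ir_op_priority) lives in the pending vLLM IR PR and differs in detail, and all names in the snippet are stand-ins.

```python
# Toy model of priority-based provider dispatch, for illustration only.
# The real vLLM IR registry is part of the pending upstream PR and is
# more elaborate (op schemas, torch wrapping, etc.).

class IROp:
    def __init__(self, name):
        self.name = name
        self._impls = {}   # provider name -> (supports_fn, impl_fn)
        self.priority = []  # e.g. ["spyre", "native"]

    def register_impl(self, provider, supports, impl):
        self._impls[provider] = (supports, impl)

    def dispatch(self, x):
        # Walk the priority list and run the first provider that
        # declares support for the input (e.g. its device type).
        for provider in self.priority:
            supports, impl = self._impls[provider]
            if supports(x):
                return impl(x)
        raise RuntimeError(f"no provider supports {self.name}")

rms_norm = IROp("rms_norm")
rms_norm.register_impl("spyre",
                       lambda x: x["device"] == "spyre",
                       lambda x: "spyre kernel")
rms_norm.register_impl("native",
                       lambda x: True,          # pure-PyTorch fallback
                       lambda x: "native kernel")
# Mirrors get_default_ir_op_priority() returning rms_norm=["spyre", "native"]
rms_norm.priority = ["spyre", "native"]
```

With this priority list, inputs on a spyre device select the spyre provider and everything else falls back to the native implementation.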

How it works

The dispatch chain is:

model forward
  → forward_oot()                          # OOT entry point (replaces upstream RMSNorm)
    → torch.ops.vllm.spyre_rmsnorm(...)    # custom op boundary, opaque to torch.compile
      → _op_func()                         # runs eagerly outside any compiled graph
        → _forward_spyre_impl()
          → convert tensors to Spyre fp16
          → ir.ops.rms_norm(...)           # IR direct dispatch
            → dispatch() selects "spyre" provider (x.device.type == "spyre")
            → spyre_rms_norm() runs on Spyre
          → convert result back to CPU
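Ignoring the spyre-specific workarounds (epsilon as tensor, no fp32 upcast), the math the provider implements is standard RMSNorm. A minimal pure-Python reference, with rms_norm_ref being a name of my choosing:

```python
import math

def rms_norm_ref(x, weight, eps=1e-6):
    # y_i = x_i / sqrt(mean(x^2) + eps) * w_i
    variance = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(variance + eps)
    return [v * inv_rms * w for v, w in zip(x, weight)]
```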

Main benefits

  • Reuses the upstream vLLM kernel dispatch system. Instead of calling our own standalone kernel function directly from the OOT wrapper (maybe_compile), the spyre kernel is discovered and selected through the same priority-based dispatch that CUDA, ROCm, and XPU use.

  • Simpler OOT wrapper. The SpyreRMSNorm class no longer needs _dynamic_arg_dims, maybe_compile, etc. The wrapper can be simplified further when moving to a PluggableLayer or to a spyre-specific model_runner, which would keep inputs and outputs on spyre and eliminate the device transfers.

  • Decoupled provider. The spyre rms_norm kernel (kernels/rms_norm.py) is a standalone function with no dependency on the OOT class. It can be tested in isolation via ir.ops.rms_norm() and evolves independently of the wrapping logic.

Limitations

  • The approach relies on the current state of the vLLM IR PR (#33825), which imposes limitations on the switch to a PluggableLayer and on the residual path handling. More generally, the approach is bottlenecked by the availability of vLLM IR support for other operations (SiluAndMul, etc.), but this is expected to evolve quickly.
  • The opaque custom op boundary (torch.ops.vllm.spyre_rmsnorm) is still required: Dynamo cannot yet trace into the OOT wrapper, because it contains D2H and H2D tensor transfers and a padding operation that is not yet supported in torch-spyre. As mentioned above, these points can be addressed in the future and the overhead removed.

cc @tdoublep @bringlein

Related Issues

#733

Test Plan

The change is not user-facing, so in principle all current tests should pass. However, some rms_norm-specific tests currently use mocking; if we decide to keep them rather than reuse the tests from upstream vLLM, they may need adjusting. For example: https://github.com/vllm-project/vllm-spyre/blob/2d816a744f479025462bd97521c823c67b403173/vllm_spyre_next/tests/test_rms_norm.py#L97
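If the mock-based tests are replaced, a parity check against a reference implementation could look like the sketch below. check_parity and rms_norm_ref are hypothetical names; a real test would invoke the spyre provider through ir.ops.rms_norm on spyre tensors rather than the plain-Python stand-in used here.

```python
import math

def rms_norm_ref(x, weight, eps):
    # Pure-Python reference: y_i = x_i / sqrt(mean(x^2) + eps) * w_i
    variance = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(variance + eps)
    return [v * inv_rms * w for v, w in zip(x, weight)]

def check_parity(impl, eps=1e-6, tol=1e-3):
    # In a real test, `impl` would wrap ir.ops.rms_norm with
    # the spyre provider selected; here it is any callable with
    # the same (x, weight, eps) signature.
    x = [0.5, -1.25, 2.0, 0.75]
    w = [1.0, 0.9, 1.1, 1.0]
    got = impl(x, w, eps)
    want = rms_norm_ref(x, w, eps)
    return all(abs(g - e) <= tol for g, e in zip(got, want))
```

The tolerance would likely need to be loosened for the spyre path, since the provider stays in fp16 and skips the fp32 upcast.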

Checklist

  • I have read the contributing guidelines
  • My code follows the project's code style (run bash format.sh)
  • I have added tests for my changes (if applicable)
  • I have updated the documentation (if applicable)
  • My commits include a Signed-off-by: line (DCO compliance)

Signed-off-by: Thomas Ortner <boh@zurich.ibm.com>
@github-actions

👋 Hi! Thank you for contributing to vLLM support on Spyre.
Just a reminder: Make sure that your code passes all the linting checks, otherwise your PR won't be able to be merged. To do so, run ./format.sh.
Now you are good to go 🚀.

We also recommend installing prek and configuring it to check your code before every local commit.

@github-actions github-actions bot changed the title [Feature] rms_norm rework with vLLM IR [Spyre-Next] [Feature] rms_norm rework with vLLM IR Mar 28, 2026
@romitjain
Collaborator

@bohnstingl
Maybe an unrelated question:
You mentioned:

→ _op_func()                         # runs eagerly outside any compiled graph

Shouldn't these methods run inside the compiled graph? Is that the current limitation where we do CPU->Spyre transfer inside the method?

Collaborator

@bringlein bringlein left a comment


Thanks for the effort @bohnstingl ! Good starting point to leverage vLLM IR.

Comment on lines +21 to +25
Spyre-specific implementation details:
- Epsilon as tensor: scalar broadcast limited, expand via torch.full()
- No dtype promotion: torch-spyre limitation, stays in input dtype
- variance_size: parameter currently not used
"""
Collaborator


I understand and agree with these constraints and I think we should proceed with this PR as is.
However, if we want to reduce the custom code for spyre even further, couldn't we reuse the upstream RMS norm if we automate the first two steps?

  1. Could converting epsilon to a tensor be a custom vLLM compile pass (or be done in torch-spyre)?
  2. Same for removing the dtype cast?

Collaborator Author


We could consider this once we fully utilize torch.compile for our provider, see my comment below. Currently our provider function is not traced, so no graph is created and we can't apply any custom passes.

@bohnstingl
Collaborator Author

@romitjain The problem is that we need to perform the D2H and H2D transfers, and those cannot happen inside the graph. That's why we have this opaque layer. The layer then triggers layer._forward_spyre_impl, which does the transfers and dispatches ir.ops.rms_norm directly.

@bohnstingl
Collaborator Author

I updated the implementation and simplified the code infrastructure quite a bit. Currently, we only need to write the provider function and don't require any of the CustomOp wrapping anymore.

# torch.ops.vllm_ir.rms_norm (a registered custom op). This creates
# an opaque boundary that Dynamo captures without tracing inside, so
# the op is not compiled by Inductor. The provider therefore runs
# eagerly at each forward call.
compilation_config.ir_enable_torch_wrap = True
Collaborator Author


We need to use ir_enable_torch_wrap at the moment, because of our .to(device="spyre") calls and also because of an upstream issue in torch-spyre that needs to be investigated.
So, for the moment, our provider function won't be compiled and runs in eager mode.

@bohnstingl
Collaborator Author

Padding is no longer needed, since torch-spyre/torch-spyre#878 landed in torch-spyre. With that change, we no longer need the D2H and H2D transfers in the IR provider and can avoid using ir_enable_torch_wrap.

@bohnstingl bohnstingl marked this pull request as ready for review April 9, 2026 16:07
@bohnstingl
Collaborator Author

I updated this PR and simplified our code further. At this stage, our provider differs from the native implementation in only three ways:

  1. It does D2H and H2D transfers -> will be resolved with the SpyreModelRunner.
  2. It does not upcast to float32 -> known limitation in torch-spyre.
  3. It uses an epsilon tensor instead of a scalar -> minor issue.

@tdoublep @bringlein please have a look

@bohnstingl bohnstingl requested a review from bringlein April 9, 2026 16:09
@bohnstingl
Collaborator Author

Merged-in the latest main. Please have a look and provide feedback.
