
[Spyre-Next] Reworked rms_norm#873

Merged
tjohnson31415 merged 2 commits into torch-spyre:main from bohnstingl:rms_rework
Mar 27, 2026
Conversation

Collaborator

@bohnstingl bohnstingl commented Mar 27, 2026

Description

This PR removes the transpose + .contiguous() operation from rms_norm, bringing it closer to the native upstream implementation. However, I currently observe small numerical differences between the old and the new approach, and I prepared a repro script to demonstrate them.

import torch

x = torch.randn(1024, 4096, device="cpu", dtype=torch.float16)
hidden_size = 4096
eps = 1e-05
weight = torch.randn(4096, device="cpu", dtype=torch.float16)

def rms_old(x, hidden_size, variance_epsilon, weight):
    if x.shape[-1] != hidden_size:
        raise ValueError(f"Expected hidden_size to be {hidden_size}, but found: {x.shape[-1]}")

    x = x.transpose(-1, -2).contiguous()

    variance_epsilon = torch.full(
        x.shape, variance_epsilon, dtype=torch.float16, device=x.device
    )

    x_var = x

    # After transpose, hidden dim is now dim=0
    variance = x_var.pow(2).mean(dim=0, keepdim=True)
    # variance = x_var.pow(2).mean(dim=-1, keepdim=True)

    x = x * torch.rsqrt(variance + variance_epsilon)
    x = x.transpose(-1, -2).contiguous()

    if weight is not None:
        x = x * weight
    return x

def rms_new(x, hidden_size, variance_epsilon, weight):
    if x.shape[-1] != hidden_size:
        raise ValueError(f"Expected hidden_size to be {hidden_size}, but found: {x.shape[-1]}")


    variance_epsilon = torch.full(
        x.shape, variance_epsilon, dtype=torch.float16, device=x.device
    )

    x_var = x

    variance = x_var.pow(2).mean(dim=-1, keepdim=True)

    x = x * torch.rsqrt(variance + variance_epsilon)

    if weight is not None:
        x = x * weight
    return x

x = x.to("spyre")
weight = weight.to("spyre")

out1 = torch.compile(rms_old, dynamic=False)(x, hidden_size, eps, weight).cpu()
out2 = torch.compile(rms_new, dynamic=False)(x, hidden_size, eps, weight).cpu()

torch.testing.assert_close(out1, out2, atol=0.001, rtol=0.001)

print('Tensors are close')
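As a side note, the float16 arithmetic itself already limits how closely either variant can match an exact RMSNorm. A minimal CPU-only sketch (using NumPy instead of torch so no Spyre device is needed; the `rms_norm_ref` helper is illustrative, not part of this PR) estimates that floor against a float64 reference:

```python
import numpy as np

def rms_norm_ref(x, weight, eps=1e-5):
    # Hypothetical float64 reference: normalize over the last (hidden) dim,
    # matching the upstream RMSNorm definition.
    x64 = x.astype(np.float64)
    variance = np.mean(x64 ** 2, axis=-1, keepdims=True)
    return x64 / np.sqrt(variance + eps) * weight.astype(np.float64)

rng = np.random.default_rng(0)
x = rng.standard_normal((1024, 4096)).astype(np.float16)
w = rng.standard_normal(4096).astype(np.float16)

# float16 computation, reducing over the last dim as in rms_new.
# NumPy accumulates the mean in float16 here, so this shows the
# precision floor of the dtype, not any device-specific effect.
var16 = np.mean(x ** 2, axis=-1, keepdims=True)
out16 = x / np.sqrt(var16 + np.float16(1e-5)) * w

err = np.max(np.abs(out16.astype(np.float64) - rms_norm_ref(x, w)))
print(f"max abs deviation from float64 reference: {err:.2e}")
```

Deviations of this order from low-precision accumulation are expected regardless of which kernel layout is used.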

cc @romitjain

Related Issues

Corresponding change in torch-spyre:
torch-spyre/torch-spyre#1236.

Test Plan

The change is not user-facing and all existing tests should pass.

Checklist

  • I have read the contributing guidelines
  • My code follows the project's code style (run bash format.sh)
  • I have added tests for my changes (if applicable)
  • I have updated the documentation (if applicable)
  • My commits include a Signed-off-by: line (DCO compliance)

Signed-off-by: Thomas Ortner <boh@zurich.ibm.com>
@bohnstingl bohnstingl self-assigned this Mar 27, 2026
@bohnstingl bohnstingl marked this pull request as ready for review March 27, 2026 10:52
@github-actions

👋 Hi! Thank you for contributing to vLLM support on Spyre.
Just a reminder: Make sure that your code passes all the linting checks, otherwise your PR won't be able to be merged. To do so, run ./format.sh.
Now you are good to go 🚀.

We also recommend installing prek and configuring it to check your code before every local commit.

@github-actions github-actions bot changed the title Reworked rms_norm [Spyre-Next] Reworked rms_norm Mar 27, 2026
@rafvasq
Collaborator

rafvasq commented Mar 27, 2026

bot:next-test

Signed-off-by: Thomas Ortner <boh@zurich.ibm.com>
@bohnstingl bohnstingl requested a review from yannicks1 March 27, 2026 16:29
@tjohnson31415
Collaborator

tjohnson31415 commented Mar 27, 2026

How significant are the numerical differences observed, and which method is closer to a baseline non-spyre implementation?

The simplification seems good but maybe the transpose().contiguous() serves a purpose; I've seen another case where a double transpose was needed for Spyre (REF)

Collaborator

@yannicks1 yannicks1 left a comment


lgtm code wise - fair point of @tjohnson31415 above ^^

@bohnstingl
Collaborator Author

@tjohnson31415 yes, indeed. It is a fair point and maybe @romitjain can comment also here.

The case that you reference is a bit specific to attention, I believe. There, a view operation was done first and the tensor then needed to be re-stickified via this trick. We also currently use it in our attention implementation, see https://github.com/jvlunteren/vllm-spyre/blob/80c7cc9e0c261059375780fc24cb3cd9861d3030/vllm_spyre_next/vllm_spyre_next/v1/attention/backends/spyre_attn.py#L698-L701.
For our case, this problem shouldn't exist and torch-spyre has also removed the transposition + contiguous() operation, see torch-spyre/torch-spyre#1236.

However, one does observe some numerical difference when comparing the old and the new approach. In particular, I created this small repro script, which fails at the 0.001 tolerance:

import torch

x = torch.randn(1024, 4096, device="cpu", dtype=torch.float16)
hidden_size = 4096
eps = 1e-05
weight = torch.randn(4096, device="cpu", dtype=torch.float16)

def rms_old(x, hidden_size, variance_epsilon, weight):
    x = x.transpose(-1, -2).contiguous()

    variance_epsilon = torch.full(
        x.shape, variance_epsilon, dtype=torch.float16, device=x.device
    )

    # After transpose, hidden dim is now dim=0
    variance = x.pow(2).mean(dim=0, keepdim=True)

    x = x * torch.rsqrt(variance + variance_epsilon)
    x = x.transpose(-1, -2).contiguous()

    if weight is not None:
        x = x * weight
    return x

def rms_new(x, hidden_size, variance_epsilon, weight):
    variance_epsilon = torch.full(
        x.shape, variance_epsilon, dtype=torch.float16, device=x.device
    )

    variance = x.pow(2).mean(dim=-1, keepdim=True)

    x = x * torch.rsqrt(variance + variance_epsilon)

    if weight is not None:
        x = x * weight
    return x

x = x.to("spyre")
weight = weight.to("spyre")

out1 = torch.compile(rms_old, dynamic=False)(x, hidden_size, eps, weight).cpu()
out2 = torch.compile(rms_new, dynamic=False)(x, hidden_size, eps, weight).cpu()

torch.testing.assert_close(out1, out2, atol=0.1, rtol=0.1, msg="FAILED with atol/rtol 0.1")
torch.testing.assert_close(out1, out2, atol=0.01, rtol=0.01, msg="FAILED with atol/rtol 0.01")
torch.testing.assert_close(out1, out2, atol=0.001, rtol=0.001, msg="FAILED with atol/rtol 0.001")

I had an offline communication with @romitjain and I think this was verified.
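For completeness: reducing over dim 0 of the transposed tensor and reducing over the last dim of the original tensor are the same reduction in exact arithmetic, so the deviation above presumably comes from how the compiled Spyre kernels order the accumulation rather than from the math. A minimal CPU sketch (NumPy as a stand-in, since it only demonstrates the mathematical equivalence, not the device behavior) makes this explicit:

```python
import numpy as np

rng = np.random.default_rng(42)
x = rng.standard_normal((1024, 4096)).astype(np.float32)

# Old-style reduction: transpose so the hidden dim becomes dim 0, reduce there.
var_old = np.mean(x.T ** 2, axis=0, keepdims=True).T   # shape (1024, 1)

# New-style reduction: reduce directly over the last (hidden) dim.
var_new = np.mean(x ** 2, axis=-1, keepdims=True)      # shape (1024, 1)

# Each output element sums exactly the same values in both paths, so on CPU
# they agree to floating-point rounding.
print(np.max(np.abs(var_old - var_new)))
```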

@tjohnson31415
Collaborator

Ah, if we are aligning with changes in torch-spyre then I think a little numerical change is ok. (When initially running in my dev env, I was getting big changes, but something was misconfigured.)

I added a link to torch-spyre/torch-spyre#1236 to the PR description.

Collaborator

@tjohnson31415 tjohnson31415 left a comment


LGTM

@tjohnson31415 tjohnson31415 merged commit 3a0460d into torch-spyre:main Mar 27, 2026
14 checks passed
yannicks1 added a commit that referenced this pull request Apr 9, 2026

## Description

Fix docstring inaccuracies, typos and typing.

Changes: 
- cleans up docstrings after #873 
- cleans up comment after #754 
- typos and typing 


## Test Plan
Documentation-only changes, no functional impact.

## Checklist

- [x] I have read the [contributing
guidelines](https://docs.vllm.ai/projects/spyre/en/latest/contributing)
- [x] My code follows the project's code style (run `bash format.sh`)
- [ ] I have added tests for my changes (if applicable)
- [ ] I have updated the documentation (if applicable)
- [x] My commits include a `Signed-off-by:` line (DCO compliance)

---------

Signed-off-by: Yannick Schnider <Yannick.Schnider1@ibm.com>
Signed-off-by: Thomas Ortner <boh@zurich.ibm.com>
Co-authored-by: Thomas Ortner <boh@zurich.ibm.com>