[Spyre-Next] RMSNorm tests and upstream tests framework #837
joerunde merged 24 commits into torch-spyre:main from
Conversation
This is largely for demonstration purposes; I can revert it based on the feedback.
@staticmethod
def forward_static(
def forward_spyre(
This is renamed to keep one upstream method as a reference.
I propose forward_static as the upstream golden reference implementation against which our custom implementation can be compared.
return

_REGISTERED = False
While running tests, these ops get registered multiple times, and we run into multiple-registration issues.
@functools.lru_cache(maxsize=1) would be a clean way to do this that prevents accidentally mutating this global variable and re-registering plugins.
Alternatively, a closure in a decorator would also work without global variables, and be easy to re-use across all the custom ops that we'll need to do this for.
from functools import wraps

def run_once(f):
    @wraps(f)
    def wrapper(*args, **kwargs):
        if not wrapper.has_run:
            wrapper.has_run = True
            return f(*args, **kwargs)
    wrapper.has_run = False
    return wrapper

@run_once
def register():
Makes sense @joerunde, I'll update it to use lru_cache instead.
add_residual: [true]
strided_input: [true]
override:
  num_tokens: [1, 16]
Only testing for known pass conditions for now
raise NotImplementedError("TODO: variance_size_override not yet implemented")

batch_padding = x.shape[0]
orig_batch_size = x.shape[0]
Padding was being done incorrectly here
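As a sketch of the corrected behavior, assuming the intent is to pad the batch dimension up to a multiple and remember the original size for slicing back afterwards (pad_batch and the multiple of 64 are illustrative, not the actual implementation):

```python
import math
import torch

def pad_batch(x: torch.Tensor, multiple: int = 64) -> tuple[torch.Tensor, int]:
    # Keep the original batch size separate from the padded size, so the
    # caller can slice the output back to out[:orig_batch_size] afterwards.
    orig_batch_size = x.shape[0]
    padded_size = math.ceil(orig_batch_size / multiple) * multiple
    if padded_size != orig_batch_size:
        pad = x.new_zeros((padded_size - orig_batch_size, *x.shape[1:]))
        x = torch.cat([x, pad], dim=0)
    return x, orig_batch_size
```

The key point is that the padded size and the original size are distinct values; assigning both from x.shape[0] conflates them.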
👋 Hi! Thank you for contributing to vLLM support on Spyre. We also recommend installing prek and configuring it to check your code before every local commit.
dev = [
    "pytest",
    "pyyaml",
what's this for?
edit- I always thought yaml was a builtin package because it's almost always installed by some other dependency 😆
@@ -0,0 +1,76 @@
"""Data models for the vllm-spyre-next test infrastructure."""
I really like this encapsulation within a plugin!
I don't have any experience writing pytest plugins, but I do think that we'd want to keep it in a separate package by itself instead of packaging it within vllm-spyre-next.
@romitjain I imagine we'll run into this situation a lot as we edit upstream tests :/ Would it be possible to have the pytest plugin clone vllm into cache like conftest.py does currently, and also allow a CLI arg like
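A hedged sketch of what such a CLI option could look like in the plugin's conftest.py (the option name --vllm-dir is hypothetical, not an existing flag):

```python
# Sketch of a conftest.py hook for the upstream-test plugin. When --vllm-dir
# is given, upstream tests are collected from that local checkout (e.g. an
# editable custom fork); otherwise the plugin clones vllm into its cache.
def pytest_addoption(parser) -> None:
    parser.addoption(
        "--vllm-dir",
        action="store",
        default=None,
        help="Path to a local vllm checkout to collect upstream tests from; "
        "when unset, the plugin clones vllm into its cache.",
    )
```

This would avoid editing pyproject.toml dependencies just to point tests at a local fork.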
# Load plugins early to register custom ops before test modules import RMSNorm
from vllm.plugins import load_general_plugins

load_general_plugins()
dependencies = [
    "torch-spyre",
    "vllm==0.15.1+cpu",
    # "vllm==0.15.1+cpu",
we for sure depend on vllm!
Why did this need to be removed for your testing?
I was testing using a custom fork, and I was trying to install that custom fork in an editable manner:
cd vllm-custom
pip install -e .
While installing the vllm-spyre-next plugin, I did not want it to override my custom fork.
But this change will be reversed once I use the method you used to run upstream tests (cloning vllm into cache and using that).
[tool.uv.sources]
# This is the unreleased v0.16.0 tag with 2.10 support
vllm = { git = "https://github.com/vllm-project/vllm", rev = "2d5be1dd5ce2e44dfea53ea03ff61143da5137eb" }
# vllm = { git = "https://github.com/vllm-project/vllm", rev = "2d5be1dd5ce2e44dfea53ea03ff61143da5137eb" }
Just FYI that this is here because we require building vllm from source so that we can compile cpu kernels which run the ops that aren't enabled on spyre yet
allow_list:
  - test: "test_rms_norm"
    mode: mandatory_pass
    tags: [rmsnorm, llama, granite]
It would be nice to have markers to use to select certain tests, since pretty quickly the full set of tests here will take far too long to run in one go and for local development we'll want to be able to run a set of upstream tests that cover what we're changing in vllm-spyre-next.
Maybe the right thing to do is to add markers to the upstream tests in vllm so that everybody benefits? But if we could mark tests with these tags that would allow us to run
pytest -m rmsnorm
to select the tests that we want
Agreed, will make the change.
def reference_rms_norm(
    x: torch.Tensor,
    weight: torch.Tensor | None,
    eps: float,
) -> torch.Tensor:
    """Golden reference: standard RMSNorm in PyTorch."""
    x_float = x.float()
    variance = x_float.pow(2).mean(dim=-1, keepdim=True)
    x_normed = x_float * torch.rsqrt(variance + eps)
    if weight is not None:
        x_normed = x_normed * weight.float()
    return x_normed
Should we directly use a reference method, such as RMSNorm.forward_static(), instead of reimplementing it?
Agreed. I had added it to make sure we are not blocked by the upstream vLLM merge, but let me check if we can replace this without the merge too.
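For reference, the two should compute the same math; here is a quick equivalence check where forward_static_like is my reading of what upstream's RMSNorm.forward_static does (an assumption — the real check should import the method from vllm once the upstream change lands), compared against the PR's local golden reference:

```python
import torch

def reference_rms_norm(x, weight, eps):
    # Local golden reference from this PR.
    x_float = x.float()
    variance = x_float.pow(2).mean(dim=-1, keepdim=True)
    x_normed = x_float * torch.rsqrt(variance + eps)
    if weight is not None:
        x_normed = x_normed * weight.float()
    return x_normed

def forward_static_like(x, weight, eps):
    # Assumed shape of upstream RMSNorm.forward_static: normalize in float32,
    # apply the weight, cast back to the input dtype.
    orig_dtype = x.dtype
    x = x.float()
    x = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    return (x * weight.float()).to(orig_dtype)
```

If the outputs agree within tolerance, the local reference can be dropped in favor of the upstream method with no change to the tests' expectations.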
return x_normed

@pytest.mark.cpu
Why is this marked as a cpu test? It should also run on spyre, from my understanding.
Yes, this is a spyre test. This is an artifact of my previous work where SpyreRMSNorm had a CPU path as well. I will merge it with the test below
@pytest.mark.cpu
@pytest.mark.parametrize("batch_size", [1])
@pytest.mark.parametrize("hidden_size", [128, 512])
I might be missing something, but I am also wondering why we are not testing the shapes that require padding here. Why not use the same shapes as in line 44?
# Mock forward_native (called by forward_oot) with a known transform
if residual is not None:
    monkeypatch.setattr(layer, "forward_native", mock_forward_native_with_residual)
    out_x, out_residual = layer.forward_oot(dummy_tensor, residual)

    assert torch.allclose(out_x, 2 * dummy_tensor)
    assert torch.allclose(out_residual, 2 * residual)
else:
    monkeypatch.setattr(layer, "forward_native", mock_forward_native_no_residual)
    out_x = layer.forward_oot(dummy_tensor, residual)

    assert torch.allclose(out_x, dummy_tensor + 1)
I love that, it's a really nice and interesting method to test that forward_oot is using forward_native
@pytest.mark.cpu
@pytest.mark.parametrize("batch_size", [1])
@pytest.mark.parametrize("hidden_size", [128, 512])
def test_spyre_rmsnorm_matches_reference(default_vllm_config, batch_size, hidden_size):
Can we also test the use of residual in these two tests? With your fixes to SpyreRMSNorm, my test that didn't use residual is passing, but the test using residual is failing.
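To cover that case, the golden reference could be extended with a residual-aware variant; my understanding of the fused add+norm contract (the residual is added to x before normalizing, and the sum is returned as the new residual) is an assumption to verify against upstream:

```python
import torch

def reference_rms_norm_residual(x, weight, eps, residual=None):
    # Fused add + RMSNorm: add the residual first, then normalize, and return
    # (normed_output, new_residual) so callers can chain layers.
    if residual is not None:
        x = x + residual
        residual = x
    x_float = x.float()
    x_normed = x_float * torch.rsqrt(
        x_float.pow(2).mean(dim=-1, keepdim=True) + eps
    )
    if weight is not None:
        x_normed = x_normed * weight.float()
    if residual is None:
        return x_normed
    return x_normed, residual
```

With this, both tests could be parametrized over residual in [None, tensor] and compared against SpyreRMSNorm's two return shapes.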
Thanks for the fix and the tests implementation @romitjain. I only reviewed the parts I was familiar with for now, but it looks like really good work to me. I closed my PR #830 as I think it has nothing to add.
@joerunde My next commit will merge the cache + single pytest command that was already implemented (#800) with the one that I am proposing.
I have updated my PR to address your comments.
This means we can run both upstream and local tests from
I have updated the description, too.
bot:next-test

bot:next-test

Digging in on final things to fix before merging:
Ah, additionally the extra failure of the basic model load test only seems to happen when it is run alongside other tests: and this fails with a device busy error: I think this is likely because vllm will load the model from worker process(es), which will fail if the main pytest process is still holding references to memory on the spyre devices that hasn't been cleaned up. I'd guess we need some global cleanup fixtures to ensure that we don't hold onto device-side data between tests.
Signed-off-by: Joe Runde <joe@joerun.de>
bot:next-test
@romit I'm not sure what's up with the size errors on those test cases; I tried running on dev images from the last week, but they all hit the same error. If you have the full environment, including the spyre runtime stack component versions used to run those tests, it'd be good to figure out how to get them to pass. In the meantime I've pushed changes here to:

This currently passes a simple
WDYT about merging?
joerunde
left a comment
There was a problem hiding this comment.
Tests are passing, we can consider merging this
Signed-off-by: Joe Runde <joe@joerun.de>
Update: Added an allowlist for test params; this will be useful so that we can select only a set of parametrizations to run in cases where there are more things to block than to allow. We also think this will be required to avoid pulling in new test cases without vetting them first. (See the huge diff in the yaml for the language/generation tests on 0e9c0be)
bot:next-test
@tjohnson31415 has some staged changes to move the testing plugin out of
bot:next-test
bot:next-test |
Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com>
Signed-off-by: Joe Runde <joe@joerun.de>
bot:next-test |
tjohnson31415
left a comment
There was a problem hiding this comment.
LGTM.
I'll do a follow up to move the pytest plugin out of the vllm_spyre_next module.
@joerunde @tjohnson31415. @joerunde, re:
These were failing for me, too. Let me debug this.
Description

This PR does 2 things:

vllm-spyre/vllm_spyre_next: There are 2 tests added - a unit test verifying the correctness of the layer on CPU/Spyre and an integration test to ensure forward_oot gets called when it is installed as a vLLM plugin. While writing down these tests, I saw a couple of issues in the SpyreRMSNorm implementation - which I have attempted to fix, but please correct me if I am wrong.

Building on vllm-project/vllm#36246, this PR also adds a framework that can be used to filter and update upstream tests and run them from the vllm repo.

Related Issues
#805
Test Plan

To test both the features of this PR:

vllm-spyre/vllm_spyre_next

This is expected to produce,

The tests are failing at boundaries of the hidden dim, which is expected as of now, since the hidden dim is not being padded to 64. (I can raise a separate PR to fix that, but I did not want to overload this PR.)

This makes use of my PR on vLLM: vllm-project/vllm#36246, which enables the RMSNorm test to run for OOT devices.

This is expected to produce

We can see that our YAML is being respected and:
param skipped)

Checklist

bash format.sh)
Signed-off-by: line (DCO compliance)