[Spyre-Next] Reworked forward call chain #872
Conversation
Signed-off-by: Thomas Ortner <boh@zurich.ibm.com>
👋 Hi! Thank you for contributing to vLLM support on Spyre. We also recommend installing prek and configuring it to check your code before every local commit.
```python
if x.shape[-1] != hidden_size:
    raise ValueError(
        f"Expected hidden_size to be {hidden_size}, but found: {x.shape[-1]}")

x = x.transpose(-1, -2).contiguous()
```
I tried running the rmsnorm repo tests on this branch and they failed for me. I am not able to pinpoint why that might be. On main they are passing, and apart from the function name changes, the only change I see is this.
@romitjain you need a very recent commit of torch-spyre for this to work. I reverted this rework from this PR and I'll move it into a separate PR.
@bohnstingl
My bad - the tests pass if I check only one of forward_oot, forward, or forward_native. But if we check all three, I get a test failure (tensors are not close) in the residual path. I think it's because we mutate the residual tensor: https://github.com/vllm-project/vllm-spyre/blob/main/vllm_spyre_next/vllm_spyre_next/custom_ops/rms_norm.py#L234
So this change: 4b6b2e9 might not be needed. I was able to pass tests locally with and without this change. I have made a small PR here from my fork: bohnstingl#1.
PR is for demonstration - feel free to merge/close and make changes directly here.
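For context, a minimal sketch of the in-place residual problem described above (function and tensor names are hypothetical, not the repo's actual code):

```python
import torch

def rms_norm_with_residual(x, weight, residual, eps=1e-6):
    residual += x  # in-place add: mutates the caller's residual tensor
    var = residual.float().pow(2).mean(dim=-1, keepdim=True)
    out = (residual.float() * torch.rsqrt(var + eps)).to(x.dtype) * weight
    return out, residual

x = torch.randn(4, 64)
weight = torch.ones(64)
residual = torch.randn(4, 64)

# Checking two forward paths back to back against the *same* residual tensor:
out1, _ = rms_norm_with_residual(x, weight, residual)  # residual is now residual + x
out2, _ = rms_norm_with_residual(x, weight, residual)  # sees the mutated residual
torch.testing.assert_close(out1, out2)  # raises: tensors are not close
```

Cloning the residual per checked path (`residual.clone()`), or replacing the in-place add with `residual = residual + x`, keeps the checks independent.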
Thanks, I merged it in.
I will still leave the rms_norm rework for a separate PR, so as not to dilute the intent of this one.
Signed-off-by: Thomas Ortner <boh@zurich.ibm.com>
Updated rmsnorm tests
@joerunde or @tjohnson31415 could you also take a look at it to see if the rework makes sense to you?
@romitjain fyi DCO fails with: Commit sha: 8fc82cc, Author: romit, Committer: romit; The sign-off is missing.
As this PR is mainly about renaming: could we also rename self._fwd, or is that name fixed? It was not obvious at first glance that self._fwd is the compiled self.forward_spyre. I don't have the perfect name, but something along the lines of self._fwd_spyre_comp, or just self._fwd_spyre, or self._forward_spyre / self._forward_spyre_comp.
I renamed it to maybe_compiled_forward_spyre to indicate that 1) it is maybe compiled and 2) that it is the forward_spyre. I did the same for silu as well.
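For illustration, a sketch of the resulting pattern (class structure and flag name are illustrative, not the exact repo code):

```python
import torch

class SiluSpyre(torch.nn.Module):
    def __init__(self, compile_forward: bool = True):
        super().__init__()
        # The name encodes both facts: it *may* be compiled,
        # and it wraps forward_spyre.
        self.maybe_compiled_forward_spyre = (
            torch.compile(self.forward_spyre) if compile_forward
            else self.forward_spyre
        )

    def forward_spyre(self, x: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.silu(x)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.maybe_compiled_forward_spyre(x)
```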
For consistency: why not call this forward_spyre, as in the RMS op?
```diff
@@ -64,11 +66,11 @@ def __init__(self, *args, **kwargs):
```
Same argument as above: consider renaming self._fwd to something more descriptive if possible.
Done, see above.
Signed-off-by: Thomas Ortner <boh@zurich.ibm.com>
bot:next-test
@dilipgb, @R3hankhan123, @coderfornow and @GOavi101 please feel free to also comment on this rework PR, as you are also working on layer wrappings. I am happy to incorporate any feedback you may have.
- `forward()`: custom op dispatch (no-compile path via `torch.ops.vllm.spyre_rmsnorm`)
- `forward_native()`: direct Spyre device execution
- `forward_oot()`: OOT dispatch via custom op (`torch.ops.vllm.spyre_rmsnorm`)
- `reference_rms_norm`: golden reference, similar to vLLM upstream pure PyTorch (ground truth)
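For reference, a sketch of what such a golden-reference function typically looks like; this mirrors the upstream vLLM pure-PyTorch formulation, and the exact signature used in the repo's tests may differ:

```python
import torch

def reference_rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float,
                       residual: torch.Tensor | None = None):
    # Pure-PyTorch ground truth, computed in float32 as upstream does.
    orig_dtype = x.dtype
    x = x.to(torch.float32)
    if residual is not None:
        x = x + residual.to(torch.float32)  # out of place: caller's tensor untouched
        residual_out = x.to(orig_dtype)
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    out = x.to(orig_dtype) * weight
    return out if residual is None else (out, residual_out)
```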
nit: reference_rms_norm -> reference_rms_norm()
… into forward_oot_rework
Signed-off-by: Thomas Ortner <boh@zurich.ibm.com>
bot:next-test
Signed-off-by: Thomas Ortner <boh@zurich.ibm.com>
bot test is broken for an unrelated reason, but tests are passing in dev pods. Merging!
Description
This PR reworks the call chain of the forward functions for the layer wrappings for torch-spyre. In particular, it avoids overriding `forward_native` in favor of using `forward_oot`.

cc @romitjain
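As a rough sketch of the pattern this moves to (the subclassing shown here is illustrative, and the custom-op call signature is assumed for demonstration):

```python
import torch
from vllm.model_executor.layers.layernorm import RMSNorm

class SpyreRMSNorm(RMSNorm):
    # forward_native stays untouched: it remains the portable reference path
    # that upstream vLLM tests exercise. Device-specific execution lives in
    # forward_oot, which vLLM's CustomOp dispatch selects on out-of-tree
    # platforms such as Spyre.
    def forward_oot(self, x: torch.Tensor,
                    residual: torch.Tensor | None = None):
        # Argument order of the custom op is assumed here for illustration.
        return torch.ops.vllm.spyre_rmsnorm(x, residual, self.weight,
                                            self.variance_epsilon)
```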
Related Issues
This PR was triggered by #863
Test Plan
This change is not user-facing, so all existing tests should continue to pass. Moreover, it should enable the upstream vLLM tests to be supported as well.
Checklist
- Formatted code (`bash format.sh`)
- Commits include a `Signed-off-by:` line (DCO compliance)