[Spyre-Next] Add RowParallelLinear and ColumnParallelLinear (MLP) wrappers #869
bohnstingl merged 3 commits into torch-spyre:main
Conversation
👋 Hi! Thank you for contributing to vLLM support on Spyre. We also recommend installing prek and configuring it to check your code before every local commit.
bohnstingl
left a comment
Thank you @R3hankhan123 for the PR.
In principle this looks good to me. I'm just wondering whether we should simplify the code for the moment by de-duplicating identical functions until we actually have a need to specialize them.
Also, could you run an end-to-end test and see whether the Granite3.3-8B model works and produces tokens?
Force-pushed: f762452 to 86b50fb
Also @bohnstingl, I ran a test on the Granite3.3-8B model and here is the output:
39874e1 to
b17f7d0
Compare
bohnstingl
left a comment
LGTM. Can you please try the Granite3.3-8B model and check whether the token generation works?
It has been observed though that the current way of wrapping for torch-spyre interferes with the enablement of upstream vLLM tests; see #863. To address this, I've opened a PR (#872) that reworks the forward call chain a bit and uses forward_oot instead of forward_native. Maybe we could hold off on the merge a bit, get #872 merged first, and then apply the rework directly here as well?
@R3hankhan123 what do you think?
Sure @bohnstingl

@R3hankhan123, #872 has landed. Could you please adopt the modified forward call structure? I will then push for a quick merge.
Force-pushed: b17f7d0 to e098fcb
Review comment on:

    class _SpyreLinear:
        """Shared implementation for Spyre linear layers at TP=1."""
For all other oot implementations (e.g. rms, silu, ...) I see this line:
_dynamic_arg_dims = {"x": [], "residual": []}
Is it not needed here? Could it be removed in the other classes too? @bohnstingl
Okay, I see we do not have a residual here... Is not specifying anything the same as putting _dynamic_arg_dims = {"x": []}?
> is it not needed here? could it be removed in the other classes too? @bohnstingl

No, it can't be removed; in fact we need it here as well. The _dynamic_arg_dims = {"x": [], "residual": []} ensures that maybe_compile compiles with dynamic=False. Here it should be _dynamic_arg_dims = {"x": [], "weight": [], "bias": []}, I think.
@R3hankhan123 could you please confirm that:
The weight and bias tensors are internal to the layer implementation; they're accessed inside _forward_spyre_impl. Since they're not direct arguments to the custom op, I think they don't need to be in _dynamic_arg_dims. I think only {"x": [], "output": []} is sufficient.
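For illustration, the intended semantics can be sketched in a few lines (the helper name `resolve_dynamic` is made up for this sketch; the real gating lives in torch-spyre's maybe_compile machinery): an all-empty `_dynamic_arg_dims` declares every listed argument fully static, which is what allows compiling with dynamic=False.

```python
def resolve_dynamic(dynamic_arg_dims: dict) -> bool:
    """Return True only if some argument declares a dynamic dimension.

    An all-empty mapping such as {"x": [], "residual": []} marks every
    listed tensor argument as fully static, so the wrapper can compile
    with dynamic=False.
    """
    return any(dims for dims in dynamic_arg_dims.values())

# All-static declaration, as used by the other oot layers:
print(resolve_dynamic({"x": [], "residual": []}))  # False
# Declaring dim 0 of "x" as dynamic would flip the mode:
print(resolve_dynamic({"x": [0]}))  # True
```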
On second thought, I have to take my comment above back. MergedColumnParallelLinear and RowParallelLinear are PluggableLayer, not CustomOp. Thus the compilation path is different and there is no maybe_compile. For the moment we probably need to simply invoke torch.compile directly:
self.maybe_compiled_forward_spyre = torch.compile(self.forward_spyre, dynamic=False)
Please leave a note though that this should be changed in the future.
This also means you don't need to define _dynamic_arg_dims.
Force-pushed: e098fcb to c554470
bohnstingl
left a comment
In general looks good to me. @R3hankhan123 could you take a look at the small comments we had?
Force-pushed: c554470 to d4cdd1f
after running a quick test
Add RowParallelLinear and ColumnParallelLinear wrappers for torch-spyre, which will act as the up projection and down projection in the MLP layer.

Co-authored-by: Rehan Khan <Rehan.Khan7@ibm.com>
Co-authored-by: nikheal2 <suryawanshin74@gmail.com>
Signed-off-by: Rehan Khan <Rehan.Khan7@ibm.com>
Signed-off-by: nikheal2 <suryawanshin74@gmail.com>
Force-pushed: d4cdd1f to 5fc7429
Signed-off-by: Thomas Ortner <boh@zurich.ibm.com>
I changed the … I tested the latest commit E2E and I see that it runs on …
bot:next-test
@joerunde or @tjohnson31415, I looked at the CI results and the failures don't seem to be related to this PR. Could you confirm?
yannicks1
left a comment
LGTM, thanks for addressing all the feedback.
Signed-off-by: Thomas Ortner <boh@zurich.ibm.com>
tjohnson31415
left a comment
> @tjohnson31415 I looked at the CI results and they don't seem to be related. Could you confirm?

Can confirm. The spyre-ci tests should not block the PR because they are not working.
Description
Add RowParallelLinear and ColumnParallelLinear wrappers for torch-spyre, which will act as the up projection and down projection in the MLP layer.
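At TP=1 both wrappers degenerate to plain linear layers; the sketch below shows where they sit in the MLP. Everything here is hypothetical (names and sizes are made up, and real models such as Granite typically also have a gate projection feeding a SwiGLU): ColumnParallelLinear plays the role of the up projection (splitting output features across ranks) and RowParallelLinear the down projection (splitting input features and all-reducing the result).

```python
import torch


class MLPSketch(torch.nn.Module):
    """Hypothetical TP=1 sketch of the MLP block.

    With a single rank, ColumnParallelLinear (up projection) and
    RowParallelLinear (down projection) both collapse to ordinary
    nn.Linear layers, as below.
    """

    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        # ColumnParallelLinear role: hidden -> intermediate (up projection).
        self.up_proj = torch.nn.Linear(hidden_size, intermediate_size, bias=False)
        # RowParallelLinear role: intermediate -> hidden (down projection).
        self.down_proj = torch.nn.Linear(intermediate_size, hidden_size, bias=False)
        self.act = torch.nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down_proj(self.act(self.up_proj(x)))


mlp = MLPSketch(hidden_size=16, intermediate_size=32)
out = mlp(torch.randn(2, 16))
```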
Related Issues
Contributes towards #736
Test Plan
Test Result
examples/torch_spyre_inference.py
Checklist
- bash format.sh
- Signed-off-by: line (DCO compliance)