[asynctp] Optimize agmm lastdim via addmm_ #190
Conversation
Force-pushed from 89626c4 to 0ce9c0c
stack-info: PR: #190, branch: IvanKobzarev/stack/7
Force-pushed from 0ce9c0c to 357dd7e
fmassa left a comment:
I don't have all the context on this file yet, but changes LGTM in general.
    outputs[idx] += output_partials[idx]
    out = outputs[idx]
    if first:
        torch.ops.aten.mm.out(shard, B_shards[idx][rank], **kwargs, out=out)
Should we prefer using the torch.mm version instead of the torch.ops.aten.mm version? I'm not sure there is effectively a difference, but maybe for consistency?
Yeah, I think there should not be much difference, we can use torch.mm.
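For reference, a minimal sketch of the two spellings being compared; the tensor names and shapes below are placeholders for illustration, not the actual shard buffers from this file:

```python
import torch

# Placeholder tensors; the real `shard` / `B_shards[idx][rank]` come from the
# decomposed all-gather + matmul, not from this sketch.
shard = torch.randn(4, 8)
b_shard = torch.randn(8, 16)
out = torch.empty(4, 16)

# Both spellings dispatch to the same aten::mm.out kernel and write into `out`:
torch.ops.aten.mm.out(shard, b_shard, out=out)  # current code in the diff
torch.mm(shard, b_shard, out=out)               # suggested equivalent
```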
eellison left a comment:
test?
Oh, yeah, I want to add an e2e test, but on torchtitan/autoparallel with asynctp/bucketing/overlap configs, once the configs land (pytorch/torchtitan#1838).

Yea - thought this was in the pytorch repo at first / more standalone.. less easy here.
Stacked PRs:
[asynctp] Optimize agmm lastdim via addmm_
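As I read the diff above, the optimization in the title replaces the "mm into a per-shard partial, then `outputs[idx] += output_partials[idx]`" pattern with an in-place `addmm_` accumulation for every shard after the first. A rough, self-contained sketch under that assumption (all names and shapes here are illustrative, not the actual kernel code):

```python
import torch

def agmm_lastdim_sketch(A_shards, B_shards):
    """Accumulate sum_i A_shards[i] @ B_shards[i] without per-shard temporaries.

    A_shards: list of [m, k_i] tensors (gathered shards of A along its last dim)
    B_shards: list of [k_i, n] tensors (matching row-shards of B)
    """
    out = torch.empty(
        A_shards[0].shape[0],
        B_shards[0].shape[1],
        dtype=A_shards[0].dtype,
        device=A_shards[0].device,
    )
    for i, (a, b) in enumerate(zip(A_shards, B_shards)):
        if i == 0:
            # First shard: write the product directly into the output buffer.
            torch.mm(a, b, out=out)
        else:
            # Later shards: fused in-place out += a @ b, avoiding a separate
            # partial-product allocation followed by an elementwise add.
            out.addmm_(a, b)
    return out
```

The point of `addmm_` here is that the add and the matmul happen in one in-place kernel call, so no per-shard partial-output buffer has to be allocated and then summed.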