[Question] pooling and aggregation operations #2257
Comments
The pooling is done in FBGEMM_GPU, not in torchrec itself. For example: https://pytorch.org/FBGEMM/fbgemm_gpu-python-api/table_batched_embedding_ops.html |
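To make "pooling" concrete, here is a minimal pure-Python sketch of a sum-pooled lookup over a jagged input (flat `values` plus per-sample `lengths`). It only illustrates the arithmetic; the function name `pooled_lookup` and the toy table are assumptions for this example, and the real work happens in FBGEMM_GPU CUDA kernels rather than in code like this.

```python
from typing import List


def pooled_lookup(
    table: List[List[float]], values: List[int], lengths: List[int]
) -> List[List[float]]:
    """Sum-pool embedding rows per sample (hypothetical helper, not an FBGEMM API).

    `values` holds all row indices back to back; `lengths[i]` says how many
    of them belong to sample i.
    """
    dim = len(table[0])
    out = []
    start = 0
    for n in lengths:
        acc = [0.0] * dim
        # Look up every row of this sample and accumulate it.
        for idx in values[start:start + n]:
            for d in range(dim):
                acc[d] += table[idx][d]
        out.append(acc)
        start += n
    return out


# Toy table with 4 rows of dimension 2.
table = [[1.0, 10.0], [2.0, 20.0], [3.0, 30.0], [4.0, 40.0]]
values = [0, 2, 1, 3, 3]   # sample 0 -> rows 0, 2; sample 2 -> rows 1, 3, 3
lengths = [2, 0, 3]        # sample 1 is empty
pooled = pooled_lookup(table, values, lengths)
```

An empty sample pools to a zero vector, and a repeated index (row 3 above) is counted once per occurrence.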
The backward key sorting and gradient reduction are also done inside fbgemm_gpu. |
It's written in C++, specifically CUDA source code.
This template file defines the autograd function that wraps the forward and backward entry points. An example of a file generated from the template. |
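As a rough illustration of what "key sorting and gradient reduction" means for a sum-pooled lookup, the sketch below expands the output gradient to one entry per looked-up row, sorts by row index so duplicates become adjacent, and then sums each run. The function name and data are assumptions made up for this example; the actual fbgemm_gpu kernels are generated CUDA code and are organized differently.

```python
from typing import List, Tuple


def backward_sorted_reduce(
    values: List[int], lengths: List[int], grad_out: List[List[float]]
) -> List[Tuple[int, List[float]]]:
    """Return one (row index, summed gradient) pair per distinct row.

    Hypothetical helper. For sum pooling, each looked-up row receives the
    gradient of the sample (bag) it was pooled into.
    """
    pairs = []
    start = 0
    for bag, n in enumerate(lengths):
        for idx in values[start:start + n]:
            pairs.append((idx, bag))
        start += n

    pairs.sort()  # sort by row index so duplicate rows are adjacent

    reduced: List[Tuple[int, List[float]]] = []
    for idx, bag in pairs:
        if reduced and reduced[-1][0] == idx:
            # Same row as the previous entry: accumulate into it.
            acc = reduced[-1][1]
            for d, g in enumerate(grad_out[bag]):
                acc[d] += g
        else:
            reduced.append((idx, list(grad_out[bag])))
    return reduced


values = [0, 2, 1, 3, 3]
lengths = [2, 0, 3]
grad_out = [[1.0, 1.0], [5.0, 5.0], [0.5, 2.0]]  # one gradient per sample
reduced = backward_sorted_reduce(values, lengths, grad_out)
```

After the reduction each table row is updated once, which is why duplicates are merged before the optimizer step.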
Sorry, a follow-up question about the forward-pass communication and pooling.
I profiled the forward pass with both Nsight and the torch profiler. I see the
all2all calls at the Python level, and the low-level NCCL calls as SendReceive
for all2all, but I didn't see any calls that map to the pooling. Does the
pooling actually happen within SendReceive? Or is it after SendReceive and
within all2all? Or is it after all2all, but just not properly profiled?
|
Another question: in the forward pass, before all2all, how do I remove the
duplicated embedding indices per table, and then run all2all with the unique
indices as input and the corresponding embedding vectors as output? Please
point out the code. Thanks.
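The deduplication idea in this question can be sketched as follows: keep the unique indices plus an inverse map, fetch one row per unique index, and expand back afterwards. This is only an illustration of the idea being asked about, with made-up names; the thread does not say whether or where TorchRec does this.

```python
from typing import Dict, List, Tuple


def dedup_indices(indices: List[int]) -> Tuple[List[int], List[int]]:
    """Return (unique indices, inverse map) so that
    unique[inverse[k]] == indices[k] for every k. Hypothetical helper."""
    unique = sorted(set(indices))
    position = {idx: pos for pos, idx in enumerate(unique)}
    inverse = [position[idx] for idx in indices]
    return unique, inverse


# Toy table stored as a dict so the row ids can be sparse.
table: Dict[int, List[float]] = {3: [0.3], 7: [0.7], 9: [0.9]}
indices = [7, 3, 7, 7, 3, 9]

unique, inverse = dedup_indices(indices)
# Only the unique rows would need to be fetched (or communicated) ...
fetched = [table[idx] for idx in unique]
# ... and the original order is restored locally through the inverse map.
restored = [fetched[pos] for pos in inverse]
```

Here six lookups shrink to three fetched rows; the saving depends entirely on how often indices repeat within a batch.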
|
Thanks for the explanation. It helps a lot.
Let me dig deeper into some details:
1. Can I assume the first all2all actually communicates the KJT per
batch? So after this all2all, each GPU holding a shard of an embedding table
launches an FBGEMM kernel to do the lookup and pooling, and then
uses the 2nd all2all to send the pooled embedding vectors back to the KJT
batch owner. If this assumption is correct: for row-wise parallelism,
all2all is still used in the forward pass, so if a sample has indices across
multiple shards (meaning the embedding vectors are stored on different GPUs), how
is the pooling done?
2. If the above assumption is correct, why not have only one all2all that
lets each KJT batch owner receive the individual embedding vectors (based on
their keys) and launch a CUDA kernel to do the pooling locally?
On Tue, Jul 30, 2024 at 11:23 PM, Junzhang wrote:
The pooling is done before the second all2all. Maybe this figure is clearer:
image.png (view on web)
<https://github.com/user-attachments/assets/ed784d80-0b2f-4067-b68a-fa62a29cd891>
The first all2all exchanges the lookup keys, which are fed into FBGEMM, and
FBGEMM does the pooling for you. The second all2all is performed on the pooled
embeddings.
So the pooling actually happens in CUDA kernels.
image.png (view on web)
<https://github.com/user-attachments/assets/80c28ad6-7199-4cf6-aa27-bf30734b7cb9>
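The three steps described here (key all2all, local lookup with pooling, all2all on pooled embeddings) can be simulated in a single process. The sketch below assumes table-wise sharding with table `t` living entirely on rank `t`; the helper names and the list-based `all_to_all` are stand-ins invented for this example, not TorchRec or NCCL APIs.

```python
from typing import List


def all_to_all(send: List[list]) -> List[list]:
    """Simulated all2all: what `src` addressed to `dst` (send[src][dst])
    arrives at `dst` as recv[dst][src]."""
    n = len(send)
    return [[send[src][dst] for src in range(n)] for dst in range(n)]


def sum_pool(table: List[List[float]], bags: List[List[int]]) -> List[List[float]]:
    """Sum the rows of `table` selected by each bag of indices."""
    dim = len(table[0])
    return [[sum((table[i][d] for i in bag), 0.0) for d in range(dim)] for bag in bags]


def forward_table_wise(tables, inputs):
    n = len(tables)
    # 1) First all2all: every rank sends its keys for table t to rank t.
    keys = all_to_all([[inputs[src][t] for t in range(n)] for src in range(n)])
    # 2) Local lookup + pooling on the rank that owns the table.
    pooled = [[sum_pool(tables[rank], keys[rank][src]) for src in range(n)]
              for rank in range(n)]
    # 3) Second all2all: pooled vectors go back to the rank that owns the samples.
    return all_to_all(pooled)


tables = [
    [[1.0], [2.0], [3.0]],      # table 0, stored on rank 0
    [[10.0], [20.0], [30.0]],   # table 1, stored on rank 1
]
# inputs[rank][table] = one bag of indices per local sample
inputs = [
    [[[0, 1], [2]], [[0], [1, 2]]],   # rank 0 owns two samples
    [[[1], []], [[2, 2], [0]]],       # rank 1 owns two samples
]
out = forward_table_wise(tables, inputs)   # out[rank][table][sample]
```

Because the pooling in step 2 is an ordinary compute kernel between the two collectives, it would appear in a profile as its own kernel rather than inside the communication calls.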
|
One correction: for RW there is no 2nd all2all; instead, it should be a ReduceScatter. (Table-wise should contain all2all.)
AFAIK, yes. But note that a KJT all2all is usually composed of several tensor all2alls.
There should be a collective. For RW sharding, it could be a reduce-scatter.
Initially, each rank has its DP input. The first all2all sends the keys to the corresponding shard for lookup. I don't think we can skip it. As I clarified, the 2nd should be an RS for RW. |
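For row-wise sharding, a sample whose indices span several shards gets a partial pooled sum from each shard, and the reduce-scatter adds those partials while delivering each sample to its owner. The single-process sketch below assumes equal contiguous row blocks per rank; `ROWS_PER_SHARD`, the helper names, and the simulated `reduce_scatter` are assumptions for illustration, not TorchRec code.

```python
from typing import List

ROWS_PER_SHARD = 2  # hypothetical: rank r holds rows [2r, 2r + 2)


def reduce_scatter(partials: List[list]) -> List[list]:
    """Simulated reduce-scatter: partials[src][dst] is src's contribution
    for the samples owned by dst; dst receives the elementwise sum over src."""
    n = len(partials)
    out = []
    for dst in range(n):
        chunks = [partials[src][dst] for src in range(n)]
        out.append([[sum(vals) for vals in zip(*rows)] for rows in zip(*chunks)])
    return out


def forward_row_wise(shards, inputs):
    n = len(shards)
    dim = len(shards[0][0])
    partials = []
    for rank in range(n):
        per_owner = []
        for owner in range(n):
            vecs = []
            for bag in inputs[owner]:
                # Keys routed to this shard (the role of the first all2all),
                # converted to local row offsets.
                local = [i - rank * ROWS_PER_SHARD
                         for i in bag if i // ROWS_PER_SHARD == rank]
                # Partial pooling over the rows this shard actually holds.
                vecs.append([sum((shards[rank][i][d] for i in local), 0.0)
                             for d in range(dim)])
            per_owner.append(vecs)
        partials.append(per_owner)
    return reduce_scatter(partials)


shards = [
    [[1.0], [2.0]],   # rank 0: global rows 0-1
    [[3.0], [4.0]],   # rank 1: global rows 2-3
]
inputs = [
    [[0, 3], [1]],      # rank 0: sample 0 touches both shards
    [[2, 3, 0], []],    # rank 1
]
out = forward_row_wise(shards, inputs)   # out[rank][sample]
```

Rank 0's first sample uses rows 0 and 3, so it receives 1.0 from shard 0 and 4.0 from shard 1, and the reduction produces the full pooled value 5.0.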
For |
thanks for the clarification! very helpful. |
Sorry, I forgot to ask about some details of the collective communication operators. Just for confirmation: in the forward pass, the communication of embedding indices uses multiple all2all calls, one per tensor of the jagged input (e.g. one call for values, one call for lengths, one call for lengths per key). In both the forward and backward passes, the communication of embedding vectors uses only one call for all batches across all keys and tables. If the 2nd statement is correct, then as the system scales up, the data size of that call will grow as n^2 (n is the number of ranks in the system). Why not split the single call into multiple calls, each communicating about 64MB (for best network bandwidth efficiency)? |
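The splitting proposed in this question amounts to cutting one large buffer into fixed-size byte ranges and issuing one collective per range. The sketch below only computes those ranges; the 64 MiB default comes from the figure in the question, the function name is made up, and whether this would actually help is not answered in the thread.

```python
from typing import List, Tuple

MIB = 1024 * 1024


def chunk_ranges(total_bytes: int, chunk_bytes: int = 64 * MIB) -> List[Tuple[int, int]]:
    """Split [0, total_bytes) into consecutive [start, end) ranges of at
    most `chunk_bytes` each. Hypothetical helper for illustration."""
    if chunk_bytes <= 0:
        raise ValueError("chunk_bytes must be positive")
    return [
        (start, min(start + chunk_bytes, total_bytes))
        for start in range(0, total_bytes, chunk_bytes)
    ]


# A 150 MiB payload becomes two full 64 MiB chunks plus a 22 MiB tail.
ranges = chunk_ranges(150 * MIB)
```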
In the forward pass, with table-wise sharding, when is pooling executed? Is it after the alltoall communication, and executed locally on the trainer? Where can I see the exact code in the torchrec code base?
In the backward pass, with table-wise sharding, when are sorting and aggregation executed? Can you please point to the lines in the code base?