Effectively I'm doing DDP and have a setup where the {query | key | value} projection is of shape (40, 1024, 1024) and the batch size of my data is 512. It seems that for the attention projections, the weights are being broadcast over the data dimension for some reason, which leads to the OOM.
Strangely enough, that's not happening in other layers, such as the embedding, which would also OOM if it were broadcast.
I can verify that I'm sharding my data along the 0th axis and replicating my model correctly, so it's strange.
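For reference, this is roughly how I'm verifying it. A minimal sketch (shapes scaled down from the real 512 / (40, 1024, 1024) setup so it runs on a single host; the variable names are hypothetical): the batch is placed with a `NamedSharding` that partitions axis 0 over the `data` mesh axis, while the QKV weights are fully replicated, and I inspect the resulting shardings directly.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# One-axis mesh over all available devices, named 'data' (DDP-style).
mesh = Mesh(np.array(jax.devices()), axis_names=('data',))

# Batch: sharded along axis 0 (the batch axis). Scaled-down shapes.
x = jax.device_put(jnp.zeros((16, 128)),
                   NamedSharding(mesh, P('data', None)))

# QKV projection weights: replicated on every device.
w_qkv = jax.device_put(jnp.zeros((8, 128, 128)),
                       NamedSharding(mesh, P(None, None, None)))

# Both report the placement I expect.
print(x.sharding)       # NamedSharding(..., spec=PartitionSpec('data', None))
print(w_qkv.sharding)   # NamedSharding(..., spec=PartitionSpec(None, None, None))
```

On a multi-device mesh, `jax.debug.visualize_array_sharding(x)` also draws the per-device layout, which makes an accidental replication of the batch easy to spot.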
Any pointers on how I can approach debugging? I've been poking around the shardings and they seem alright, and my attention is definitely under vmap, so any broadcasting that happens is done by XLA behind the scenes.
What could be the possible points of failure I should check?
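One thing I've tried is looking at what XLA actually emits, since the broadcast isn't visible at the JAX level. A sketch (toy function and shapes, not my real model) that lowers a vmapped QKV-style projection and dumps the HLO text, where an unexpected `broadcast` on the batch dimension should show up; `memory_analysis()` on the compiled executable gives a buffer-size estimate where the backend supports it:

```python
import jax
import jax.numpy as jnp

def attn_proj(w_qkv, x):
    # w_qkv: (heads, d, d), x: (batch, d) -> (heads, batch, d),
    # vmapping over the heads axis as in the real attention layer.
    return jax.vmap(lambda w: x @ w)(w_qkv)

w = jnp.zeros((8, 128, 128))
x = jnp.zeros((16, 128))

lowered = jax.jit(attn_proj).lower(w, x)
hlo_text = lowered.as_text()   # grep this for broadcast ops with the batch dim
print(hlo_text[:500])

compiled = lowered.compile()
try:
    # Per-buffer memory estimate; not available on every backend.
    print(compiled.memory_analysis())
except Exception:
    pass

out = jax.jit(attn_proj)(w, x)
print(out.shape)  # (8, 16, 128)
```

Comparing the HLO of the attention projection against that of a layer that doesn't OOM (e.g. the embedding) might show where the extra broadcast gets introduced.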
Full traceback
Model `PyTree`