[Question/Bug] DP sharding parameters are inconsistent with others. #2563
Comments
Sorry for the late response. @JacoCheung, it looks like an invalid access to the weights. Could you please try using the weights in the `state_dict`?
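For reference, here is how `state_dict()` behaves in plain PyTorch. This does not cover torchrec's sharded modules, whose `state_dict` may contain wrapper types not shown here. The returned tensors are detached but still share memory with the parameters, so reading weights from them avoids autograd views entirely. A minimal sketch with an ordinary `nn.Embedding`:

```python
import torch
from torch import nn

emb = nn.Embedding(10, 4)  # ordinary PyTorch module, not a sharded torchrec one
sd = emb.state_dict()      # keep_vars=False by default, so values are detached

w = sd["weight"]
print(w.requires_grad, w.is_leaf)  # False True
# Detached from autograd, but still the same memory as the live parameter.
print(w.data_ptr() == emb.weight.data_ptr())  # True
```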
@JacoCheung, regarding question 1, "TableBatchedEmbeddingSlice is not a leaf tensor":
If you want the leaf weight, you can try using the parameters from the
or
Hope that answers your question.
@JacoCheung, as for question 2, it's somewhat similar to question 1.
My suggestion is to do the bfloat16 casting before calling DMP or, more directly, to set the embedding config with `dtype=bfloat16`.
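The cast-before-wrap ordering can be illustrated with a toy model. `ToyFusedTables` is hypothetical and is not torchrec's or DMP's API. It only assumes that the wrapper packs the table weights into one flat buffer at construction time and hands out per-table views. If the tables are cast to bf16 first, the buffer and the views are created in bf16 and stay tied:

```python
import torch


class ToyFusedTables:
    """Hypothetical stand-in (not torchrec/DMP API): packs several tables
    into one flat buffer at construction time and exposes per-table views."""

    def __init__(self, tables):
        # The buffer's dtype is fixed here, when the wrapper is built.
        self.flat = torch.cat([t.detach().reshape(-1) for t in tables])
        self.views = []
        offset = 0
        for t in tables:
            n = t.numel()
            # Each view aliases a contiguous slice of the flat buffer.
            self.views.append(self.flat[offset:offset + n].view(t.shape))
            offset += n


tables = [torch.randn(10, 4), torch.randn(6, 4)]

# Cast first, then wrap: buffer and views are all bf16 and share memory.
packed = ToyFusedTables([t.to(torch.bfloat16) for t in tables])
print(packed.flat.dtype, packed.views[1].dtype)  # torch.bfloat16 torch.bfloat16

# An in-place update through a view (as an optimizer would do) lands in the buffer.
packed.views[1].fill_(1.0)
print(bool((packed.flat[40:] == 1).all()))  # True
```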
Description:
Hi torchrec team,
I'm using `EmbeddingCollection` and constraining the sharding type to `DATA_PARALLEL`. I should then be able to get the parameters and pass them to my optimizers. However, I encountered several problems.
Looking forward to any input. Thanks!
TableBatchedEmbeddingSlice is not a leaf tensor
`.parameters()` or `embedding_collection._dmp_wrapped_module.embeddings['0'].weight` returns a `TableBatchedEmbeddingSlice`, which, unexpectedly, is not a leaf tensor. Unfortunately, my app requires this flag to perform some operations.
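Plain PyTorch shows why a per-table slice is not a leaf. This assumes, based on the name but not confirmed here, that `TableBatchedEmbeddingSlice` is a view into one fused weight tensor. Any view of a leaf that requires grad carries a `grad_fn`, so its `is_leaf` is `False`. The actual leaf is reachable through the view's `_base`:

```python
import torch
from torch import nn

# Hypothetical fused weight holding two tables: (10, 4) and (6, 4) -> 64 floats.
fused = nn.Parameter(torch.randn(64))

# Per-table views carved out of the fused leaf (stand-ins for table slices).
table0 = fused[0:40].view(10, 4)
table1 = fused[40:64].view(6, 4)

print(fused.is_leaf, table0.is_leaf)  # True False
print(table0.grad_fn is not None)     # True: the view records how to backprop into fused

# Every view keeps a reference to its root base, which is the real leaf.
print(table1._base is fused)          # True

# An optimizer would be handed the leaf, not the views.
opt = torch.optim.SGD([fused], lr=0.1)
```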
model.bfloat16() detaches the weight storage
When I convert the whole model to a lower precision such as bf16, the underlying DP table storage is not affected, while the weight/parameter accessors are converted as expected. As a result, the weights and the storage appear to be untied, and the optimizer has no effect on the original storage. A reproducible script is as below:
The `params` get updated while the underlying storage, `split_embedding_weights`, remains the same (and the next lookup does see the old storage).
params:
storage:
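The symptom can be reproduced in miniature with plain PyTorch. `ToyDPTable` is hypothetical and relies on an assumption not confirmed from torchrec's source: the real table storage lives in a tensor that `nn.Module` does not track, and the registered parameter only aliases it. `Module.bfloat16()` converts only registered parameters and buffers. For ordinary tensors it does this, by default, by assigning a new bf16 tensor to `param.data`, which breaks the alias:

```python
import torch
from torch import nn


class ToyDPTable(nn.Module):
    """Hypothetical toy (not torchrec's implementation): the real storage is a
    plain tensor attribute that nn.Module does not track, and the registered
    Parameter starts out as an alias of that storage."""

    def __init__(self, num_embeddings, dim):
        super().__init__()
        self.flat_weights = torch.randn(num_embeddings * dim)  # untracked storage
        self.weight = nn.Parameter(self.flat_weights.view(num_embeddings, dim))


m = ToyDPTable(8, 4)
print(m.weight.data_ptr() == m.flat_weights.data_ptr())  # True: tied before the cast

m.bfloat16()  # converts registered parameters/buffers only
print(m.weight.dtype, m.flat_weights.dtype)  # torch.bfloat16 torch.float32

before = m.flat_weights.clone()
with torch.no_grad():
    m.weight.add_(1.0)  # stand-in for an optimizer step
print(torch.equal(m.flat_weights, before))  # True: the update never reaches the storage
```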
grad_fn is AsStrided when TableBatchedEmbeddingSlice is an operand
Besides, I found that `next_functions` starts with an `AsStridedBackward0` ahead of `AccumulateGrad` when a `TableBatchedEmbeddingSlice` object is used as an operand, for example:
I would like to know why there is an `AsStridedBackward0`. One library that I depend on requires accessing the `AccumulateGrad` node in only one jump.
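Here is a plain-PyTorch sketch of where the extra node may come from. It assumes, without confirmation from torchrec's source, that the table slice behaves like a strided view of a fused leaf. Autograd accumulates gradients only into the leaf. An op that consumes a view therefore reaches `AccumulateGrad` through the view's backward node (`AsStridedBackward0` for `as_strided`), while an op that consumes the leaf reaches it in one hop. `hops_to_accumulate_grad` is a hypothetical helper written only for this illustration:

```python
from collections import deque

import torch
from torch import nn


def hops_to_accumulate_grad(node, max_depth=16):
    """Hypothetical helper: breadth-first search over ``next_functions``.

    Returns how many hops separate ``node`` from the nearest AccumulateGrad
    node, or -1 if none is found within ``max_depth`` hops.
    """
    frontier = deque([(node, 0)])
    while frontier:
        current, depth = frontier.popleft()
        if current is None:  # an input that does not require grad
            continue
        if type(current).__name__ == "AccumulateGrad":
            return depth
        if depth < max_depth:
            for nxt, _ in current.next_functions:
                frontier.append((nxt, depth + 1))
    return -1


# Hypothetical fused leaf; the first 16 elements form a (4, 4) table.
fused = nn.Parameter(torch.randn(32))
table = fused.as_strided((4, 4), (4, 1), 0)

print(type(table.grad_fn).__name__)                  # AsStridedBackward0
print(hops_to_accumulate_grad((table * 2).grad_fn))  # 2: Mul -> AsStrided -> AccumulateGrad
print(hops_to_accumulate_grad((fused * 2).grad_fn))  # 1: Mul -> AccumulateGrad
```

Under this assumption, the extra hop seems inherent to consuming a view. A library that needs `AccumulateGrad` in one jump would have to receive the fused leaf itself.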