[Bug][Dynamic Embedding] improper optimizer state_dict momentum2 key while constructing PSCollection
#2177
Labels: bug
Describe the bug
A `PSCollection` should contain optimizer states in addition to weights. The optimizer state tensors are obtained directly from the `EmbeddingCollection` module.
However, `sharded_module.fused_optimizer.state_dict()['state']` does not contain the key `{table_name}.momentum2`, because in `PSCollection` this state dict will not return keys like `xxx.momentum1` or `xxx.momentum2`; those keys are customized by TBE. `xxx.momentum1` is handled explicitly, while the remaining keys are copied from the results retrieved above.

See the illustration below, where the optimizer is Adam. The expected number of state tensors is 2, but only `momentum1` is returned and `momentum2` (which is synonymous with `exp_avg_sq`) is left out. This affects every optimizer that has a `momentum2` state.
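The failure mode can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the actual `PSCollection` code: the functions `collect_states_buggy` and `collect_states_expected`, the dicts `TBE_STATES` and `FUSED_STATE`, and the string placeholders standing in for tensors are all hypothetical. It assumes, as described above, that only `momentum1` is fetched explicitly and every other state is copied from a state dict that lacks `momentum2`.

```python
from typing import Dict, List

# Hypothetical stand-ins: strings instead of tensors, plain dicts instead of
# the real TBE / fused optimizer objects.
TBE_STATES: Dict[str, str] = {
    # States an Adam-optimized table actually owns inside the TBE kernel.
    "table_0.momentum1": "momentum1 tensor",
    "table_0.momentum2": "momentum2 (exp_avg_sq) tensor",
}

# Hypothetical fused optimizer state dict: the momentum keys are customized
# by TBE, so neither momentum1 nor momentum2 shows up here.
FUSED_STATE: Dict[str, str] = {}


def collect_states_buggy(table_name: str) -> List[str]:
    """Mirror the reported behavior: special-case momentum1, copy the rest."""
    states = [TBE_STATES[f"{table_name}.momentum1"]]
    # The remaining keys come from FUSED_STATE, which has no momentum2 entry.
    states += [v for k, v in FUSED_STATE.items() if k.startswith(f"{table_name}.")]
    return states


def collect_states_expected(table_name: str) -> List[str]:
    """Expected behavior: gather every momentum state the TBE owns."""
    states: List[str] = []
    for suffix in ("momentum1", "momentum2"):
        key = f"{table_name}.{suffix}"
        if key in TBE_STATES:  # optimizers without this state are skipped
            states.append(TBE_STATES[key])
    return states


if __name__ == "__main__":
    print(len(collect_states_buggy("table_0")))     # 1: momentum2 is dropped
    print(len(collect_states_expected("table_0")))  # 2: what Adam needs
```

The expected variant iterates over both momentum suffixes instead of special-casing `momentum1`, so it would also cover any other optimizer whose TBE state includes `momentum2`.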