Looking quickly, I think the key gap here is how inference and training differ with respect to devices and processes. In training, DMP assumes one Python process per (CUDA) device, and devices communicate implicitly through collectives. In inference, a single process drives all devices, and data moves between them through intra-device calls (i.e., `tensor.to(...)`).
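To make that distinction concrete, here is a toy, pure-Python model of row-wise sharding. The `Shard` class and all helper functions are hypothetical and are not TorchRec APIs; device strings are only labels. Under this model, the training layout gives each rank exactly one shard, while the single-process inference layout has one process referencing every shard, each placed on its own device. So seeing all shards from one process does not by itself imply that every shard is resident on every device.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Shard:
    """A contiguous block of table rows assigned to one device (toy model)."""
    device: str              # just a label here, e.g. "cuda:0"
    row_offset: int          # global index of the first row in this shard
    rows: List[List[float]]


def row_wise_shards(table: List[List[float]], devices: List[str]) -> List[Shard]:
    """Split the table into near-equal contiguous row blocks, one per device."""
    block = -(-len(table) // len(devices))  # ceiling division
    return [
        Shard(dev, i * block, table[i * block:(i + 1) * block])
        for i, dev in enumerate(devices)
    ]


def training_view(shards: List[Shard], rank: int) -> List[Shard]:
    """Training layout: the process for `rank` holds only its own shard."""
    return [shards[rank]]


def inference_view(shards: List[Shard]) -> List[Shard]:
    """Inference layout: one process references every shard, each on its own device."""
    return list(shards)


def lookup(shards: List[Shard], row_id: int) -> List[float]:
    """Route a global row id to the shard that owns it.

    In a real single-process setup this is where data would move between
    devices (e.g. via tensor.to(...)); here it is a plain list access.
    """
    for shard in shards:
        if shard.row_offset <= row_id < shard.row_offset + len(shard.rows):
            return shard.rows[row_id - shard.row_offset]
    raise KeyError(f"row {row_id} not found in the given shards")


if __name__ == "__main__":
    table = [[float(i), float(i)] for i in range(8)]
    shards = row_wise_shards(table, ["cuda:0", "cuda:1"])
    print([s.device for s in inference_view(shards)])         # every shard, one per device
    print([s.device for s in training_view(shards, rank=1)])  # only rank 1's shard
    print(lookup(inference_view(shards), 5))
```

In the training view, `lookup(training_view(shards, 0), 5)` raises `KeyError`. This stands in for the fact that rank 0 would need a collective to obtain a row owned by rank 1.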
Hi Torchrec Team,
I'm following the protocol in https://github.com/pytorch/torchrec/blob/main/TorchRec_Interactive_Tutorial_Notebook_OSS_version.ipynb to set up sharded quantized tables, but I have encountered the following issues:
In a multi-GPU setup, I'm trying to follow the standard protocol: convert an `EmbeddingBagCollection(device="meta", ...)` to a quantized version with `quant_dynamic`, then shard it across GPUs with the `DistributedModelParallel` (or `_shard_module`) wrappers. However, it always ends up allocating all shards onto every rank. The script I attached prints a comparison between the original and quantized embedding tables: the original sharded embedding bag collection keeps only one local shard of the weight tensor, but the quantized counterpart seems to hold all pieces locally, which creates a massive memory footprint.
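To put rough numbers on that footprint, here is a back-of-the-envelope sketch comparing per-rank memory under even row-wise sharding versus full replication. The table size (10M rows x 128 dims), the world size of 4, and the 8 bytes of per-row metadata (assumed here to be one fp32 scale plus one fp32 bias) are all hypothetical, and the helper functions are illustrative, not TorchRec APIs.

```python
import math


def bytes_per_row(dim: int, elem_bytes: int, row_overhead_bytes: int = 0) -> int:
    """Bytes for one embedding row plus any per-row metadata (e.g. scale/bias)."""
    return dim * elem_bytes + row_overhead_bytes


def per_rank_bytes(
    num_rows: int,
    dim: int,
    elem_bytes: int,
    world_size: int,
    replicated: bool,
    row_overhead_bytes: int = 0,
) -> int:
    """Bytes held by the most-loaded rank.

    replicated=False: even row-wise sharding, each rank keeps at most
    ceil(num_rows / world_size) rows.
    replicated=True: every rank keeps the whole table.
    """
    rows = num_rows if replicated else math.ceil(num_rows / world_size)
    return rows * bytes_per_row(dim, elem_bytes, row_overhead_bytes)


if __name__ == "__main__":
    NUM_ROWS, DIM, WORLD_SIZE = 10_000_000, 128, 4  # hypothetical table and cluster
    SCALE_BIAS_BYTES = 8                             # assumed: fp32 scale + fp32 bias per row

    fp32_sharded = per_rank_bytes(NUM_ROWS, DIM, 4, WORLD_SIZE, replicated=False)
    int8_sharded = per_rank_bytes(NUM_ROWS, DIM, 1, WORLD_SIZE, replicated=False,
                                  row_overhead_bytes=SCALE_BIAS_BYTES)
    int8_replicated = per_rank_bytes(NUM_ROWS, DIM, 1, WORLD_SIZE, replicated=True,
                                     row_overhead_bytes=SCALE_BIAS_BYTES)
    for name, b in [("fp32, sharded", fp32_sharded),
                    ("int8, sharded", int8_sharded),
                    ("int8, replicated", int8_replicated)]:
        print(f"{name:>18}: {b / 1e9:.2f} GB per rank")
```

Under these assumptions, the int8 table replicated on every rank (about 1.36 GB per rank) uses more memory per rank than the unquantized fp32 table sharded four ways (about 1.28 GB per rank). Replication can therefore more than cancel the savings from quantization.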
Another observation is that the quantized table loses all its values during sharding. I can't seem to recover the quantized `weight`/`weight_qscale`/`weight_qbias` without manually reloading them from a state_dict, which doesn't feel right and may point to an issue with the sharding process itself.
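For reference, `weight`, `weight_qscale` and `weight_qbias` together describe one table: 8-bit codes plus a per-row scale and bias. The sketch below is a simplified, generic row-wise asymmetric uint8 scheme, assumed to be similar in spirit to what the quantized tables store. It is not TorchRec's or FBGEMM's actual kernel or memory layout, and every function name is hypothetical. It shows why the three pieces only make sense together: codes cannot be dequantized without the matching rows of scale and bias, so a shard has to carry the same row range of all three.

```python
from typing import List, Tuple


def quantize_row(row: List[float]) -> Tuple[List[int], float, float]:
    """Asymmetric uint8 quantization of one row, so that x ~= q * scale + bias."""
    lo, hi = min(row), max(row)
    scale = (hi - lo) / 255.0 if hi > lo else 1.0  # avoid division by zero for constant rows
    bias = lo
    codes = [min(255, max(0, round((x - bias) / scale))) for x in row]
    return codes, scale, bias


def dequantize_row(codes: List[int], scale: float, bias: float) -> List[float]:
    """Reconstruct approximate float values from codes plus the row's scale and bias."""
    return [q * scale + bias for q in codes]


def quantize_table(
    table: List[List[float]],
) -> Tuple[List[List[int]], List[float], List[float]]:
    """Return parallel (weight, weight_qscale, weight_qbias)-style structures."""
    weight, qscale, qbias = [], [], []
    for row in table:
        codes, scale, bias = quantize_row(row)
        weight.append(codes)
        qscale.append(scale)
        qbias.append(bias)
    return weight, qscale, qbias


def slice_rows(
    weight: List[List[int]], qscale: List[float], qbias: List[float], start: int, end: int
) -> Tuple[List[List[int]], List[float], List[float]]:
    """A row-wise shard must take the same row range from all three structures."""
    return weight[start:end], qscale[start:end], qbias[start:end]


if __name__ == "__main__":
    table = [[0.0, 0.5, 1.0], [-2.0, 2.0, 0.0], [3.0, 3.0, 3.0]]
    weight, qscale, qbias = quantize_table(table)
    w, s, b = slice_rows(weight, qscale, qbias, 1, 3)  # a shard holding rows 1..2
    for codes, scale, bias in zip(w, s, b):
        print(dequantize_row(codes, scale, bias))
```

Each reconstructed value is within half a quantization step (`scale / 2`) of the original.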
Let me know if you have any suggestions, thanks!
I'm using the following versions:
A minimal code sample to reproduce the issue: