Add Distributed Model Parallel Tests for models_test.py #363
base: main
Conversation
Thanks Sam!
""" | ||
Test that DMP-wrapped LightGCN produces the same output as non-wrapped model. Note: We only test with a single process for unit test. | ||
""" | ||
from torchrec.distributed.model_parallel import DistributedModelParallel as DMP |
Nit: import at the top of the file?
Yeah, changed.
if not dist.is_initialized():
    dist.init_process_group(
        backend="gloo",
        init_method="tcp://localhost:29500",
        rank=0,
        world_size=1,  # Single process for unit test
    )
Let's clean up the process group after every test, like we do here?
That way we can get rid of the try/catch here.
Thanks, I added a tearDown method.
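The cleanup could look roughly like this (a hedged sketch; the actual test class name in models_test.py may differ):

```python
import unittest

import torch.distributed as dist


class DMPModelsTest(unittest.TestCase):
    """Illustrative test-case skeleton; the real class name may differ."""

    def tearDown(self):
        # Destroy the default process group after every test so the next
        # test can call init_process_group without a try/except guard.
        if dist.is_initialized():
            dist.destroy_process_group()
```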
Can we follow the pattern here to test against world size > 1?
Added a test to do so. I think we should discuss the nature of this test when you get back, i.e. CPU vs. CUDA. I went ahead with a world size of 2, but on CPU, so we weren't really testing the sharding here, just that the wrapping works with a larger world size.
This PR is not urgent, however, so we can discuss later.
# When using DMP, EmbeddingBagCollection returns an Awaitable that needs to be resolved
if isinstance(embeddings_0, Awaitable):
    embeddings_0 = embeddings_0.wait()
It does seem unfortunate/surprising that the rest of our code and/or pyg code doesn't support this type of tensor.
Can we add a TODO to look into this?
I will add a TODO. This seems like an expected result of TorchRec sharding, as introduced by TorchRec itself.
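The unwrap pattern can be illustrated with a minimal stand-in (`FakeAwaitable` and `resolve` are hypothetical, for illustration only; they are not torchrec's actual classes):

```python
class FakeAwaitable:
    """Minimal stand-in for torchrec's Awaitable, for illustration only."""

    def __init__(self, value):
        self._value = value

    def wait(self):
        # In torchrec, wait() blocks until the sharded lookup completes.
        return self._value


def resolve(maybe_awaitable):
    # Unwrap if the backend returned an Awaitable; pass plain values through.
    if isinstance(maybe_awaitable, FakeAwaitable):
        return maybe_awaitable.wait()
    return maybe_awaitable


print(resolve(FakeAwaitable(42)))  # → 42
print(resolve(7))                  # → 7
```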
if not dist.is_initialized():
    dist.init_process_group(
        backend="gloo",
        init_method="tcp://localhost:29500",
Let's also use get_process_group_init_method so we can always have a free port?
Changed to this.
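`get_process_group_init_method` is presumably a helper in this repo; the free-port idea behind it can be sketched as follows (`get_free_tcp_port` is an illustrative name, not the repo's actual helper):

```python
import socket


def get_free_tcp_port() -> int:
    # Bind to port 0 and let the OS pick an unused port, avoiding
    # collisions with a hardcoded port like 29500.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("localhost", 0))
        return s.getsockname()[1]


init_method = f"tcp://localhost:{get_free_tcp_port()}"
```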
Scope of work done
Added two unit tests to models_test.py, just to make sure we can use DMP to wrap a model: one test for forward, one for gradient flow.
Where is the documentation for this feature?: N/A
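The forward test's core check could be sketched like this (a hedged sketch; `assert_outputs_match` is a hypothetical helper, not the PR's actual code):

```python
import torch


def assert_outputs_match(plain_out: torch.Tensor, dmp_out: torch.Tensor) -> None:
    # The DMP-wrapped model should reproduce the plain model's forward
    # output within floating-point tolerance.
    torch.testing.assert_close(plain_out, dmp_out)
```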
Did you add automated tests or write a test plan?
Yes, added two unit tests.
Updated Changelog.md? NO
Ready for code review?: YES