Conversation

swong3-sc (Collaborator):

Scope of work done

  • Adds two DMP-related tests to models_test.py to verify that we can wrap a model with DMP (DistributedModelParallel): one test for the forward pass and one for gradient flow (a rough sketch of the idea is included below).
  • Small fix to handle tensors being returned as Awaitable objects, which previously prevented DMP wrapping.
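
For context, a minimal sketch of the forward-pass idea, using torchrec's stock EmbeddingBagCollection as a stand-in for the model under test (the table config, port, and names here are illustrative, not the actual models_test.py code):

import torch
import torch.distributed as dist
from torchrec import EmbeddingBagCollection, EmbeddingBagConfig, KeyedJaggedTensor
from torchrec.distributed.model_parallel import DistributedModelParallel as DMP
from torchrec.distributed.types import Awaitable

# Single-process ("gloo", world_size=1) group, mirroring the unit-test setup.
if not dist.is_initialized():
    dist.init_process_group(
        backend="gloo", init_method="tcp://localhost:29500", rank=0, world_size=1
    )

# Illustrative embedding table; the real tests wrap the repo's own model.
ebc = EmbeddingBagCollection(
    tables=[
        EmbeddingBagConfig(
            name="t1", embedding_dim=8, num_embeddings=100, feature_names=["f1"]
        )
    ],
    device=torch.device("meta"),
)
model = DMP(module=ebc, device=torch.device("cpu"))

# Two-sample batch for feature "f1": lengths [2, 1] index into the 3 values.
features = KeyedJaggedTensor.from_lengths_sync(
    keys=["f1"],
    values=torch.tensor([1, 2, 3]),
    lengths=torch.tensor([2, 1]),
)
out = model(features)
if isinstance(out, Awaitable):  # DMP can hand back a LazyAwaitable
    out = out.wait()

dist.destroy_process_group()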

Where is the documentation for this feature?: N/A

Did you add automated tests or write a test plan?

Yes, added two unit tests.

Updated Changelog.md? NO

Ready for code review?: YES

@kmontemayor2-sc (Collaborator) left a comment:

Thanks Sam!

"""
Test that DMP-wrapped LightGCN produces the same output as non-wrapped model. Note: We only test with a single process for unit test.
"""
from torchrec.distributed.model_parallel import DistributedModelParallel as DMP
kmontemayor2-sc (Collaborator):

nit. import at the top of the file?

swong3-sc (Author):

Yeah, changed.

Comment on lines 294 to 300
if not dist.is_initialized():
    dist.init_process_group(
        backend="gloo",
        init_method="tcp://localhost:29500",
        rank=0,
        world_size=1,  # Single process for unit test
    )
kmontemayor2-sc (Collaborator):

Let's clean up the process group after every test, like we do here?

kmontemayor2-sc (Collaborator):

This way we can get rid of the try/catch here.

swong3-sc (Author):

Thanks, I added a tearDown method.
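
For reference, a tearDown along these lines (class name is illustrative) releases the default group after each test:

import unittest

import torch.distributed as dist


class ModelsTest(unittest.TestCase):  # illustrative class name
    def tearDown(self) -> None:
        # Destroy the default process group so each test can initialize its
        # own group without hitting "process group already initialized".
        if dist.is_initialized():
            dist.destroy_process_group()
        super().tearDown()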

kmontemayor2-sc (Collaborator):

Can we follow the pattern here to test against world size > 1?

swong3-sc (Author):

Added a test to do so. I think we should discuss the nature of this test when you get back, i.e. CPU vs. CUDA. I went ahead with a world size of 2 on CPU, so we aren't really testing the sharding here, just that wrapping works with a larger world size. A rough sketch of the approach is below.
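
Roughly, the multi-process variant spawns two CPU/gloo ranks; the port and helper names below are illustrative, not the actual test code:

import torch.distributed as dist
import torch.multiprocessing as mp


def _run_forward(rank: int, world_size: int, init_method: str) -> None:
    # Each spawned process joins the same gloo group, runs the DMP-wrapped
    # forward pass (elided here), then cleans up its process group.
    dist.init_process_group(
        backend="gloo", init_method=init_method, rank=rank, world_size=world_size
    )
    try:
        assert dist.get_world_size() == world_size
        # ... build the model, wrap with DMP(device=torch.device("cpu")), run forward ...
    finally:
        dist.destroy_process_group()


def test_dmp_forward_world_size_two() -> None:
    world_size = 2
    mp.spawn(
        _run_forward,
        args=(world_size, "tcp://localhost:29501"),  # port is illustrative
        nprocs=world_size,
        join=True,
    )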

swong3-sc (Author):

This PR is not urgent, though, so we can discuss later.

Comment on lines +297 to +299
# When using DMP, EmbeddingBagCollection returns an Awaitable that needs to be resolved
if isinstance(embeddings_0, Awaitable):
    embeddings_0 = embeddings_0.wait()
kmontemayor2-sc (Collaborator):

It does seem unfortunate/surprising that the rest of our code and/or PyG code doesn't support this type of tensor.

Can we add a TODO to look into this?

swong3-sc (Author):

I will add a TODO. This seems to be expected behavior introduced by TorchRec's sharding:

https://docs.pytorch.org/tutorials/intermediate/torchrec_intro_tutorial.html#gpu-training-with-lazyawaitable

if not dist.is_initialized():
    dist.init_process_group(
        backend="gloo",
        init_method="tcp://localhost:29500",
kmontemayor2-sc (Collaborator):

Let's also use get_process_group_init_method so we can always have a free port?

swong3-sc (Author):

Changed to this.
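
For anyone following along, a hypothetical stand-in for such a helper (not the actual get_process_group_init_method implementation) simply asks the OS for an unused port:

import socket


def _free_port_init_method(host: str = "localhost") -> str:
    # Bind to port 0 so the OS assigns a free port, then build the tcp://
    # init_method string for torch.distributed.init_process_group.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind((host, 0))
        port = s.getsockname()[1]
    return f"tcp://{host}:{port}"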
