From 1f703d5f8c0bb8f06f4ff42d8e30b2efc05856e1 Mon Sep 17 00:00:00 2001
From: Gal Rotem
Date: Fri, 1 Sep 2023 13:11:43 -0700
Subject: [PATCH] GPU UT - enable for torchrec example (#527)

Summary:
Pull Request resolved: https://github.com/pytorch/tnt/pull/527

Replacing the way we launch distributed UTs for the torchrec example, as well as getting rid of a pyre-fixme annotation.

Reviewed By: anshulverma

Differential Revision: D48912943

fbshipit-source-id: 96955f1aa64891fdf9ffc671a99bc5aa472034d7
---
 .../torchrec/tests/torchrec_example_test.py | 29 ++++---------
 1 file changed, 6 insertions(+), 23 deletions(-)

diff --git a/examples/torchrec/tests/torchrec_example_test.py b/examples/torchrec/tests/torchrec_example_test.py
index a1335be4c6..352e6fc958 100644
--- a/examples/torchrec/tests/torchrec_example_test.py
+++ b/examples/torchrec/tests/torchrec_example_test.py
@@ -6,38 +6,21 @@
 # LICENSE file in the root directory of this source tree.
 
 import unittest
-import uuid
 
 import torch
-from torch.distributed import launcher
-from torchtnt.utils.test_utils import skip_if_asan
+from torchtnt.utils.test_utils import skip_if_asan, spawn_multi_process
 
 from ..main import main
 
-MIN_NODES = 1
-MAX_NODES = 1
-PROC_PER_NODE = 2
 
 
+class TorchrecExampleTest(unittest.TestCase):
+    cuda_available: bool = torch.cuda.is_available()
 
-class TorchrecExampleTest(unittest.TestCase):
     @skip_if_asan
-    # pyre-fixme[56]: Pyre was not able to infer the type of argument `not
-    #  torch.cuda.is_available()` to decorator factory `unittest.skipIf`.
-    @unittest.skipIf(
-        not torch.cuda.is_available(),
+    @unittest.skipUnless(
+        cuda_available,
         "Skip when CUDA is not available",
     )
     def test_torchrec_example(self) -> None:
-        lc = launcher.LaunchConfig(
-            min_nodes=MIN_NODES,
-            max_nodes=MAX_NODES,
-            nproc_per_node=PROC_PER_NODE,
-            run_id=str(uuid.uuid4()),
-            rdzv_backend="c10d",
-            rdzv_endpoint="localhost:0",
-            max_restarts=0,
-            monitor_interval=1,
-        )
-
-        launcher.elastic_launch(config=lc, entrypoint=main)([])
+        spawn_multi_process(2, "nccl", main, [])
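
Note on the new helper: the call `spawn_multi_process(2, "nccl", main, [])` in the diff above spawns two worker processes that each run `main([])` inside a NCCL process group, replacing the heavier torchelastic LaunchConfig/elastic_launch setup. As a rough sketch only, here is what such a helper could look like if built on `torch.multiprocessing.spawn` and `torch.distributed.init_process_group`; the names `_worker` and `run_multi_process` are hypothetical, the fixed port is an assumption, and the real torchtnt `spawn_multi_process` may differ in its details:

import os
from typing import Any, Callable

import torch.distributed as dist
import torch.multiprocessing as mp


def _worker(
    rank: int,
    world_size: int,
    backend: str,
    fn: Callable[..., Any],
    args: tuple,
) -> None:
    # Every spawned rank joins the same rendezvous (address/port) and
    # initializes its own process group before running the test body.
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")  # assumed fixed port
    dist.init_process_group(backend=backend, rank=rank, world_size=world_size)
    try:
        fn(*args)
    finally:
        dist.destroy_process_group()


def run_multi_process(
    world_size: int, backend: str, fn: Callable[..., Any], *args: Any
) -> None:
    # torch.multiprocessing.spawn calls _worker(rank, *args) once per rank
    # and joins all processes; an exception in any worker propagates here,
    # which is what makes the unit test fail on error.
    mp.spawn(
        _worker,
        args=(world_size, backend, fn, args),
        nprocs=world_size,
        join=True,
    )


# Usage mirroring the test above: two ranks on the NCCL backend,
# each running main([]).
# run_multi_process(2, "nccl", main, [])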