Added distributed tests with Horovod and XLA for early_stopping #2165

Merged · 12 commits · Aug 19, 2021
60 changes: 52 additions & 8 deletions tests/ignite/handlers/test_early_stopping.py
@@ -249,7 +249,10 @@ def evaluation(engine):
 
 def _test_distrib_with_engine_early_stopping(device):
 
-    import torch.distributed as dist
+    if device is None:
+        device = idist.device()
+    if isinstance(device, str):
+        device = torch.device(device)
 
     torch.manual_seed(12)
 
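Note (not part of the diff): the added lines follow ignite's usual device-normalization pattern, so the test body works whether it receives None, a string, or a torch.device. A minimal sketch of that pattern as a standalone helper, assuming ignite.distributed is imported as idist as elsewhere in this file; the helper name _normalize_device is illustrative, not from this file:

# Illustrative helper, not from the test file: normalize a device argument the
# same way the added lines do. idist.device() returns the default device for
# the current process (CPU, CUDA, or XLA, depending on the active backend).
import torch
import ignite.distributed as idist

def _normalize_device(device=None):
    if device is None:
        device = idist.device()
    if isinstance(device, str):
        device = torch.device(device)
    return device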
@@ -264,8 +267,8 @@ def __init__(self, count=0):
     def score_function(engine):
         i = trainer.state.epoch - 1
         v = scores[i]
-        dist.all_reduce(v)
-        v /= dist.get_world_size()
+        idist.all_reduce(v)
+        v /= idist.get_world_size()
         return v.item()
 
     trainer = Engine(do_nothing_update_fn)
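Note (not part of the diff): switching from torch.distributed to ignite.distributed lets the same score_function run under native DDP, Horovod, and XLA backends. A hedged sketch of the averaging idiom in isolation; the function name average_score is illustrative:

# Illustrative only: average a per-process score across all workers.
# idist.all_reduce sums the tensor over the process group by default,
# so dividing by the world size yields the mean.
import torch
import ignite.distributed as idist

def average_score(local_score: float) -> float:
    v = torch.tensor(local_score, dtype=torch.float32, device=idist.device())
    v = idist.all_reduce(v)
    v /= idist.get_world_size()
    return v.item()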
@@ -285,12 +288,18 @@ def evaluation(engine):
 
 def _test_distrib_integration_engine_early_stopping(device):
 
-    import torch.distributed as dist
-
     from ignite.metrics import Accuracy
 
-    rank = dist.get_rank()
-    ws = dist.get_world_size()
+    if device is None:
+        device = idist.device()
+    if isinstance(device, str):
+        device = torch.device(device)
+    metric_device = device
+    if device.type == "xla":
+        metric_device = "cpu"
+
+    rank = idist.get_rank()
+    ws = idist.get_world_size()
     torch.manual_seed(12)
 
     n_epochs = 10
@@ -314,7 +323,7 @@ def update(engine, _):
         return y_preds[e][i, rank], y_true[e][i, rank]
 
     evaluator = Engine(update)
-    acc = Accuracy(device=metric_device)
+    acc = Accuracy(device=metric_device)
     acc.attach(evaluator, "acc")
 
     def score_function(engine):
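Note (not part of the diff): the collapsed remainder of this hunk presumably keeps the existing EarlyStopping wiring around the evaluator. A minimal sketch of that typical wiring, assuming an "acc" metric has been attached as above; the engines here are placeholders for illustration, not the file's exact code:

# Sketch of the usual EarlyStopping setup around an evaluator.
from ignite.engine import Engine, Events
from ignite.handlers import EarlyStopping

trainer = Engine(lambda engine, batch: None)    # placeholder update functions
evaluator = Engine(lambda engine, batch: None)

def score_function(engine):
    # EarlyStopping treats higher scores as better; read the attached metric.
    return engine.state.metrics.get("acc", 0.0)

handler = EarlyStopping(patience=3, score_function=score_function, trainer=trainer)
evaluator.add_event_handler(Events.COMPLETED, handler)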
@@ -352,6 +361,18 @@ def test_distrib_gloo_cpu_or_gpu(distributed_context_single_node_gloo):
     _test_distrib_integration_engine_early_stopping(device)
 
 
+@pytest.mark.distributed
+@pytest.mark.skipif(not idist.has_hvd_support, reason="Skip if no Horovod dist support")
+@pytest.mark.skipif("WORLD_SIZE" in os.environ, reason="Skip if launched as multiproc")
+def test_distrib_hvd(gloo_hvd_executor):
+
+    device = torch.device("cpu" if not torch.cuda.is_available() else "cuda")
+    nproc = 4 if not torch.cuda.is_available() else torch.cuda.device_count()
+
+    gloo_hvd_executor(_test_distrib_with_engine_early_stopping, (device,), np=nproc, do_init=True)
+    gloo_hvd_executor(_test_distrib_integration_engine_early_stopping, (device,), np=nproc, do_init=True)
+
+
 @pytest.mark.multinode_distributed
 @pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
 @pytest.mark.skipif("MULTINODE_DISTRIB" not in os.environ, reason="Skip if not multi-node distributed")
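Note (not part of the diff): gloo_hvd_executor is a fixture from ignite's test configuration that spawns np Horovod processes (Gloo controller) and runs the given function in each of them. Outside pytest, the same per-process test body could be launched with ignite's own Parallel helper; a hedged sketch, reusing this module's _test_* helpers:

# Illustrative launcher, not the fixture used above: run the early-stopping
# test body in 4 Horovod worker processes on CPU.
import ignite.distributed as idist

def _worker(local_rank, device):
    _test_distrib_with_engine_early_stopping(device)
    _test_distrib_integration_engine_early_stopping(device)

if __name__ == "__main__":
    with idist.Parallel(backend="horovod", nproc_per_node=4) as parallel:
        parallel.run(_worker, "cpu")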
@@ -370,3 +391,26 @@ def test_multinode_distrib_nccl_gpu(distributed_context_multi_node_nccl):
     device = idist.device()
     _test_distrib_with_engine_early_stopping(device)
     _test_distrib_integration_engine_early_stopping(device)
+
+
+@pytest.mark.tpu
+@pytest.mark.skipif("NUM_TPU_WORKERS" in os.environ, reason="Skip if NUM_TPU_WORKERS is in env vars")
+@pytest.mark.skipif(not idist.has_xla_support, reason="Skip if no PyTorch XLA package")
+def test_distrib_single_device_xla():
+    device = idist.device()
+    _test_distrib_with_engine_early_stopping(device)
+    _test_distrib_integration_engine_early_stopping(device)
+
+
+def _test_distrib_xla_nprocs(index):
+    device = idist.device()
+    _test_distrib_with_engine_early_stopping(device)
+    _test_distrib_integration_engine_early_stopping(device)
+
+
+@pytest.mark.tpu
+@pytest.mark.skipif("NUM_TPU_WORKERS" not in os.environ, reason="Skip if no NUM_TPU_WORKERS in env vars")
+@pytest.mark.skipif(not idist.has_xla_support, reason="Skip if no PyTorch XLA package")
+def test_distrib_xla_nprocs(xmp_executor):
+    n = int(os.environ["NUM_TPU_WORKERS"])
+    xmp_executor(_test_distrib_xla_nprocs, args=(), nprocs=n)
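Note (not part of the diff): xmp_executor is another fixture from ignite's test configuration; on TPU it presumably wraps torch_xla's multiprocessing spawner, which also explains why _test_distrib_xla_nprocs takes an index argument. A minimal sketch of that launch pattern, stated as an assumption rather than the fixture's actual code:

# Illustrative only: spawn NUM_TPU_WORKERS XLA processes and run the per-process
# test body in each; xmp.spawn passes the process index as the first argument.
import os
import torch_xla.distributed.xla_multiprocessing as xmp

def _run(index):
    _test_distrib_xla_nprocs(index)   # helper defined in the diff above

if __name__ == "__main__":
    n = int(os.environ.get("NUM_TPU_WORKERS", "1"))
    xmp.spawn(_run, args=(), nprocs=n)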