Added distributed tests with Horovod and XLA for early_stopping (#2165)
* added hvd tests for early stopping

* added xla tests for early stopping

* added hvd device

* autopep8 fix

* changed dist to idist (see the sketch after this list)

* replace dist with idist

* convert dist to idist

* changed device to CPU for metrics

* autopep8 fix
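
The dist-to-idist steps in the list above refer to replacing raw torch.distributed calls with the backend-agnostic ignite.distributed helpers, which dispatch to native torch.distributed, Horovod, or XLA depending on the backend that is active. Below is a minimal Python sketch of that pattern; it is not part of the commit, and the helper names average_score and pick_metric_device are illustrative only.

    import torch

    import ignite.distributed as idist


    def average_score(v):
        # Backend-agnostic equivalent of the old pattern
        # "dist.all_reduce(v); v /= dist.get_world_size()".
        idist.all_reduce(v)
        v /= idist.get_world_size()
        return v.item()


    def pick_metric_device():
        # Metrics stay on CPU under XLA ("changed device to CPU for metrics"
        # above); otherwise the current idist device is used.
        device = idist.device()
        return "cpu" if device.type == "xla" else device

In the diff below, the same pattern appears inside score_function and when constructing Accuracy(device=metric_device).
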

Co-authored-by: Ishan-Kumar2 <[email protected]>
Co-authored-by: vfdev <[email protected]>
3 people authored Aug 19, 2021
1 parent 0e0200d commit 5d4f869
1 changed file with 52 additions and 8 deletions: tests/ignite/handlers/test_early_stopping.py
@@ -249,7 +249,10 @@ def evaluation(engine):

 def _test_distrib_with_engine_early_stopping(device):

-    import torch.distributed as dist
+    if device is None:
+        device = idist.device()
+    if isinstance(device, str):
+        device = torch.device(device)

     torch.manual_seed(12)

@@ -264,8 +267,8 @@ def __init__(self, count=0):
     def score_function(engine):
         i = trainer.state.epoch - 1
         v = scores[i]
-        dist.all_reduce(v)
-        v /= dist.get_world_size()
+        idist.all_reduce(v)
+        v /= idist.get_world_size()
         return v.item()

     trainer = Engine(do_nothing_update_fn)
@@ -285,12 +288,18 @@ def evaluation(engine):

 def _test_distrib_integration_engine_early_stopping(device):

-    import torch.distributed as dist
-
     from ignite.metrics import Accuracy

-    rank = dist.get_rank()
-    ws = dist.get_world_size()
+    if device is None:
+        device = idist.device()
+    if isinstance(device, str):
+        device = torch.device(device)
+    metric_device = device
+    if device.type == "xla":
+        metric_device = "cpu"
+
+    rank = idist.get_rank()
+    ws = idist.get_world_size()
     torch.manual_seed(12)

     n_epochs = 10
@@ -314,7 +323,7 @@ def update(engine, _):
         return y_preds[e][i, rank], y_true[e][i, rank]

     evaluator = Engine(update)
-    acc = Accuracy(device=device)
+    acc = Accuracy(device=metric_device)
     acc.attach(evaluator, "acc")

     def score_function(engine):
@@ -352,6 +361,18 @@ def test_distrib_gloo_cpu_or_gpu(distributed_context_single_node_gloo):
     _test_distrib_integration_engine_early_stopping(device)


+@pytest.mark.distributed
+@pytest.mark.skipif(not idist.has_hvd_support, reason="Skip if no Horovod dist support")
+@pytest.mark.skipif("WORLD_SIZE" in os.environ, reason="Skip if launched as multiproc")
+def test_distrib_hvd(gloo_hvd_executor):
+
+    device = torch.device("cpu" if not torch.cuda.is_available() else "cuda")
+    nproc = 4 if not torch.cuda.is_available() else torch.cuda.device_count()
+
+    gloo_hvd_executor(_test_distrib_with_engine_early_stopping, (device,), np=nproc, do_init=True)
+    gloo_hvd_executor(_test_distrib_integration_engine_early_stopping, (device,), np=nproc, do_init=True)
+
+
 @pytest.mark.multinode_distributed
 @pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
 @pytest.mark.skipif("MULTINODE_DISTRIB" not in os.environ, reason="Skip if not multi-node distributed")
@@ -370,3 +391,26 @@ def test_multinode_distrib_nccl_gpu(distributed_context_multi_node_nccl):
     device = idist.device()
     _test_distrib_with_engine_early_stopping(device)
     _test_distrib_integration_engine_early_stopping(device)
+
+
+@pytest.mark.tpu
+@pytest.mark.skipif("NUM_TPU_WORKERS" in os.environ, reason="Skip if NUM_TPU_WORKERS is in env vars")
+@pytest.mark.skipif(not idist.has_xla_support, reason="Skip if no PyTorch XLA package")
+def test_distrib_single_device_xla():
+    device = idist.device()
+    _test_distrib_with_engine_early_stopping(device)
+    _test_distrib_integration_engine_early_stopping(device)
+
+
+def _test_distrib_xla_nprocs(index):
+    device = idist.device()
+    _test_distrib_with_engine_early_stopping(device)
+    _test_distrib_integration_engine_early_stopping(device)
+
+
+@pytest.mark.tpu
+@pytest.mark.skipif("NUM_TPU_WORKERS" not in os.environ, reason="Skip if no NUM_TPU_WORKERS in env vars")
+@pytest.mark.skipif(not idist.has_xla_support, reason="Skip if no PyTorch XLA package")
+def test_distrib_xla_nprocs(xmp_executor):
+    n = int(os.environ["NUM_TPU_WORKERS"])
+    xmp_executor(_test_distrib_xla_nprocs, args=(), nprocs=n)