Skip to content

Update rmse, ssim, top_k_categorical_accuracy in test for generating data with different rank #2673

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Sep 2, 2022
2 changes: 1 addition & 1 deletion tests/ignite/metrics/test_ssim.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,9 +130,9 @@ def _test_distrib_integration(device, tol=1e-4):
from ignite.engine import Engine

rank = idist.get_rank()
torch.manual_seed(12 + rank)
n_iters = 100
batch_size = 10
offset = n_iters * batch_size

def _test(metric_device):
y_pred = torch.rand(n_iters * batch_size, 3, 28, 28, dtype=torch.float, device=device)
Expand Down
5 changes: 2 additions & 3 deletions tests/ignite/metrics/test_top_k_categorical_accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,14 +59,12 @@ def _test_distrib_integration(device):
from ignite.engine import Engine

rank = idist.get_rank()
torch.manual_seed(12)

def _test(n_epochs, metric_device):
n_iters = 100
batch_size = 16
n_classes = 10

offset = n_iters * batch_size
y_true = torch.randint(0, n_classes, size=(n_iters * batch_size,)).to(device)
y_preds = torch.rand(n_iters * batch_size, n_classes).to(device)

Expand Down Expand Up @@ -100,7 +98,8 @@ def update(engine, i):
metric_devices = ["cpu"]
if device.type != "xla":
metric_devices.append(idist.device())
for _ in range(3):
for i in range(3):
torch.manual_seed(12 + rank + i)
for metric_device in metric_devices:
_test(n_epochs=1, metric_device=metric_device)
_test(n_epochs=2, metric_device=metric_device)
Expand Down