Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update rmse, ssim, top_k_categorical_accuracy in test for generating data with different rank #2673

Merged
merged 9 commits into from
Sep 2, 2022
28 changes: 14 additions & 14 deletions tests/ignite/metrics/test_root_mean_squared_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,17 +58,19 @@ def _test_distrib_integration(device, tol=1e-6):
from ignite.engine import Engine

rank = idist.get_rank()
n_iters = 100
s = 10
offset = n_iters * s

y_true = torch.arange(0, offset * idist.get_world_size(), dtype=torch.float).to(device)
y_preds = (rank + 1) * torch.ones(offset, dtype=torch.float).to(device)
def _test(metric_device):
n_iters = 2
batch_size = 3

def update(engine, i):
return y_preds[i * s : (i + 1) * s], y_true[i * s + offset * rank : (i + 1) * s + offset * rank]
torch.manual_seed(12 + rank)

y_true = torch.arange(0, n_iters * batch_size, dtype=torch.float).to(device)
y_preds = (rank + 1) * torch.ones(n_iters * batch_size, dtype=torch.float).to(device)

def update(engine, i):
return y_preds[i * batch_size : (i + 1) * batch_size], y_true[i * batch_size : (i + 1) * batch_size]

def _test(metric_device):
engine = Engine(update)

m = RootMeanSquaredError(device=metric_device)
Expand All @@ -77,15 +79,13 @@ def _test(metric_device):
data = list(range(n_iters))
engine.run(data=data, max_epochs=1)

y_preds = idist.all_gather(y_preds)
y_true = idist.all_gather(y_true)

assert "rmse" in engine.state.metrics
res = engine.state.metrics["rmse"]

y_preds_full = []
for i in range(idist.get_world_size()):
y_preds_full.append((i + 1) * torch.ones(offset))
y_preds_full = torch.stack(y_preds_full).to(device).flatten()

true_res = np.sqrt(np.mean(np.square((y_true - y_preds_full).cpu().numpy())))
true_res = np.sqrt(np.mean(np.square((y_true - y_preds).cpu().numpy())))

assert pytest.approx(res, rel=tol) == true_res

Expand Down
13 changes: 8 additions & 5 deletions tests/ignite/metrics/test_ssim.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,18 +130,18 @@ def _test_distrib_integration(device, tol=1e-4):
from ignite.engine import Engine

rank = idist.get_rank()
torch.manual_seed(12 + rank)
n_iters = 100
s = 10
offset = n_iters * s
batch_size = 10

def _test(metric_device):
y_pred = torch.rand(offset * idist.get_world_size(), 3, 28, 28, dtype=torch.float, device=device)
y_pred = torch.rand(n_iters * batch_size, 3, 28, 28, dtype=torch.float, device=device)
y = y_pred * 0.65

def update(engine, i):
return (
y_pred[i * s + offset * rank : (i + 1) * s + offset * rank],
y[i * s + offset * rank : (i + 1) * s + offset * rank],
y_pred[i * batch_size : (i + 1) * batch_size],
y[i * batch_size : (i + 1) * batch_size],
vfdev-5 marked this conversation as resolved.
Show resolved Hide resolved
)

engine = Engine(update)
Expand All @@ -150,6 +150,9 @@ def update(engine, i):
data = list(range(n_iters))
engine.run(data=data, max_epochs=1)

y_pred = idist.all_gather(y_pred)
y = idist.all_gather(y)

assert "ssim" in engine.state.metrics
res = engine.state.metrics["ssim"]

Expand Down
21 changes: 11 additions & 10 deletions tests/ignite/metrics/test_top_k_categorical_accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,22 +58,18 @@ def top_k_accuracy(y_true, y_pred, k=5, normalize=True):
def _test_distrib_integration(device):
from ignite.engine import Engine

rank = idist.get_rank()
torch.manual_seed(12)

def _test(n_epochs, metric_device):
n_iters = 100
s = 16
batch_size = 16
n_classes = 10

offset = n_iters * s
y_true = torch.randint(0, n_classes, size=(offset * idist.get_world_size(),)).to(device)
y_preds = torch.rand(offset * idist.get_world_size(), n_classes).to(device)
y_true = torch.randint(0, n_classes, size=(n_iters * batch_size,)).to(device)
y_preds = torch.rand(n_iters * batch_size, n_classes).to(device)

def update(engine, i):
return (
y_preds[i * s + rank * offset : (i + 1) * s + rank * offset, :],
y_true[i * s + rank * offset : (i + 1) * s + rank * offset],
y_preds[i * batch_size : (i + 1) * batch_size],
vfdev-5 marked this conversation as resolved.
Show resolved Hide resolved
y_true[i * batch_size : (i + 1) * batch_size],
)

engine = Engine(update)
Expand All @@ -85,6 +81,9 @@ def update(engine, i):
data = list(range(n_iters))
engine.run(data=data, max_epochs=n_epochs)

y_preds = idist.all_gather(y_preds)
y_true = idist.all_gather(y_true)

assert "acc" in engine.state.metrics
res = engine.state.metrics["acc"]
if isinstance(res, torch.Tensor):
Expand All @@ -97,7 +96,9 @@ def update(engine, i):
metric_devices = ["cpu"]
if device.type != "xla":
metric_devices.append(idist.device())
for _ in range(3):
rank = idist.get_rank()
for i in range(3):
puhuk marked this conversation as resolved.
Show resolved Hide resolved
torch.manual_seed(12 + rank + i)
for metric_device in metric_devices:
_test(n_epochs=1, metric_device=metric_device)
_test(n_epochs=2, metric_device=metric_device)
Expand Down