Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added: Improved tests using pytest parametrize #2543

Merged
merged 12 commits into from
Apr 14, 2022
Merged
68 changes: 31 additions & 37 deletions tests/ignite/metrics/test_root_mean_squared_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,46 +17,40 @@ def test_zero_sample():
rmse.compute()


@pytest.fixture(params=[0, 1, 2, 3])
def test_data(request):
    """Return one (y_pred, y, batch_size) case of fresh random tensors.

    Cases 0-1 are small tensors consumed in a single update (batch_size=1);
    cases 2-3 are larger tensors consumed in batches of 16.
    """
    # Build all cases in a fixed order so the torch RNG stream is stable.
    flat_small = torch.empty(10).uniform_(0, 10), torch.empty(10).uniform_(0, 10)
    col_small = torch.empty(10, 1).uniform_(-10, 10), torch.empty(10, 1).uniform_(-10, 10)
    # updated batches
    flat_large = torch.empty(50).uniform_(0, 10), torch.empty(50).uniform_(0, 10)
    col_large = torch.empty(50, 1).uniform_(-10, 10), torch.empty(50, 1).uniform_(-10, 10)
    cases = [flat_small + (1,), col_small + (1,), flat_large + (16,), col_large + (16,)]
    return cases[request.param]


@pytest.mark.parametrize("n_times", range(3))
def test_compute(n_times, test_data):
    """Check RootMeanSquaredError against a NumPy reference on random data.

    ``n_times`` only re-runs the check with fresh random tensors from the
    ``test_data`` fixture, because exact coincidences between ``y`` and
    ``y_pred`` are rare and we want several independent draws.
    """
    rmse = RootMeanSquaredError()

    y_pred, y, batch_size = test_data
    rmse.reset()
    if batch_size > 1:
        # Ceil division: the old `shape[0] // batch_size + 1` issued one
        # extra empty-slice update whenever shape[0] was an exact multiple
        # of batch_size.
        n_iters = (y.shape[0] + batch_size - 1) // batch_size
        for i in range(n_iters):
            idx = i * batch_size
            rmse.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size]))
    else:
        rmse.update((y_pred, y))

    np_y = y.numpy().ravel()
    np_y_pred = y_pred.numpy().ravel()

    # Reference RMSE computed with NumPy on the flattened tensors.
    np_res = np.sqrt(np.power(np_y - np_y_pred, 2.0).sum() / np_y.shape[0])
    res = rmse.compute()

    assert isinstance(res, float)
    assert pytest.approx(res) == np_res


def _test_distrib_integration(device, tol=1e-6):
Expand Down