Skip to content

Commit

Permalink
add general testing
Browse files Browse the repository at this point in the history
  • Loading branch information
SkafteNicki committed Sep 3, 2024
1 parent 7e9dd08 commit 3d3d7b5
Showing 1 changed file with 23 additions and 2 deletions.
25 changes: 23 additions & 2 deletions tests/unittests/_helpers/testers.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,25 @@ def _assert_allclose(tm_result: Any, ref_result: Any, atol: float = 1e-8, key: O
"""Recursively assert that two results are within a certain tolerance."""
# single output compare
if isinstance(tm_result, Tensor):
assert np.allclose(tm_result.detach().cpu().numpy(), ref_result, atol=atol, equal_nan=True)
assert np.allclose(
tm_result.detach().cpu().numpy() if isinstance(tm_result, Tensor) else tm_result,
ref_result.detach().cpu().numpy() if isinstance(ref_result, Tensor) else ref_result,
atol=atol,
equal_nan=True,
)
# multi output compare
elif isinstance(tm_result, Sequence):
for pl_res, ref_res in zip(tm_result, ref_result):
_assert_allclose(pl_res, ref_res, atol=atol)
elif isinstance(tm_result, Dict):
if key is None:
raise KeyError("Provide Key for Dict based metric results.")
assert np.allclose(tm_result[key].detach().cpu().numpy(), ref_result, atol=atol, equal_nan=True)
assert np.allclose(
tm_result[key].detach().cpu().numpy() if isinstance(tm_result, Tensor) else tm_result[key],
ref_result.detach().cpu().numpy() if isinstance(ref_result, Tensor) else ref_result,
atol=atol,
equal_nan=True,
)
else:
raise ValueError("Unknown format for comparison")

Expand Down Expand Up @@ -147,13 +157,24 @@ def _class_test(
# verify metrics work after being loaded from pickled state
pickled_metric = pickle.dumps(metric)
metric = pickle.loads(pickled_metric)
metric_clone = deepcopy(metric)

for i in range(rank, num_batches, world_size):
batch_kwargs_update = {k: v[i] if isinstance(v, Tensor) else v for k, v in kwargs_update.items()}

# compute batch stats and aggregate for global stats
batch_result = metric(preds[i], target[i], **batch_kwargs_update)

if rank == 0 and world_size == 1 and i == 0: # check only in non-ddp mode and first batch
# dummy check to make sure that forward/update works as expected
metric_clone.update(preds[i], target[i], **batch_kwargs_update)
update_result = metric_clone.compute()
if isinstance(batch_result, dict):
for key in batch_result:
_assert_allclose(batch_result, update_result[key], key=key)
else:
_assert_allclose(batch_result, update_result)

if metric.dist_sync_on_step and check_dist_sync_on_step and rank == 0:
if isinstance(preds, Tensor):
ddp_preds = torch.cat([preds[i + r] for r in range(world_size)]).cpu()
Expand Down

0 comments on commit 3d3d7b5

Please sign in to comment.