From 2fb7c099dd2c360608648a0a08cef7c5200e6241 Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Thu, 16 Sep 2021 14:37:07 -0700 Subject: [PATCH] [tensor comparisons] support pt-1.8, add torch_assert_close --- megatron/testing_utils.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/megatron/testing_utils.py b/megatron/testing_utils.py index 801d74d9b..bfcece1b4 100644 --- a/megatron/testing_utils.py +++ b/megatron/testing_utils.py @@ -208,9 +208,20 @@ def get_gpu_count(): else: return 0 -def torch_assert_equal(actual, expected): - """ emulates the removed torch.testing.assert_equal """ - torch.testing.assert_close(actual, expected, rtol=0.0, atol=0.0) +def torch_assert_equal(actual, expected, **kwargs): + # assert_equal was added around pt-1.9, it does better checks - e.g will check dimensions match + if hasattr(torch.testing, "assert_equal"): + return torch.testing.assert_equal(actual, expected, **kwargs) + else: + return torch.allclose(actual, expected, rtol=0.0, atol=0.0) + +def torch_assert_close(actual, expected, **kwargs): + # assert_close was added around pt-1.9, it does better checks - e.g will check dimensions match + if hasattr(torch.testing, "assert_close"): + return torch.testing.assert_close(actual, expected, **kwargs) + else: + kwargs.pop("msg", None) # doesn't have msg arg + return torch.allclose(actual, expected, **kwargs) def is_torch_bf16_available():