diff --git a/megatron/testing_utils.py b/megatron/testing_utils.py
index d402b59df..62991c044 100644
--- a/megatron/testing_utils.py
+++ b/megatron/testing_utils.py
@@ -207,6 +207,10 @@ def get_gpu_count():
     else:
         return 0

+def torch_assert_equal(actual, expected):
+    """ emulates the removed torch.testing.assert_equal """
+    torch.testing.assert_close(actual, expected, rtol=0.0, atol=0.0)
+

 def get_tests_dir(append_path=None):
     """
diff --git a/tests/test_activations.py b/tests/test_activations.py
index 85c949f4a..98b91d376 100644
--- a/tests/test_activations.py
+++ b/tests/test_activations.py
@@ -5,7 +5,7 @@
 from torch.nn import functional as F

 from megatron.model.activations import liglu, geglu, reglu, swiglu
-from megatron.testing_utils import set_seed
+from megatron.testing_utils import set_seed, torch_assert_equal


 class TestActivations(unittest.TestCase):
@@ -27,16 +27,16 @@ def test_shapes(self):

     def test_liglu(self):
         expected = self.x1 * self.x2
-        torch.testing.assert_equal(liglu(self.x), expected)
+        torch_assert_equal(liglu(self.x), expected)

     def test_geglu(self):
         expected = self.x1 * F.gelu(self.x2)
-        torch.testing.assert_equal(geglu(self.x), expected)
+        torch_assert_equal(geglu(self.x), expected)

     def test_reglu(self):
         expected = self.x1 * F.relu(self.x2)
-        torch.testing.assert_equal(reglu(self.x), expected)
+        torch_assert_equal(reglu(self.x), expected)

     def test_swiglu(self):
         expected = self.x1 * F.silu(self.x2)
-        torch.testing.assert_equal(swiglu(self.x), expected)
+        torch_assert_equal(swiglu(self.x), expected)
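
Note: a minimal usage sketch of the new helper, assuming a PyTorch version where torch.testing.assert_equal has been removed in favor of torch.testing.assert_close (the tensor values below are illustrative only, not from the test suite). Passing rtol=0.0 and atol=0.0 leaves assert_close no tolerance, which reproduces the exact-match semantics of the removed assert_equal:

import torch

from megatron.testing_utils import torch_assert_equal

t = torch.tensor([1.0, 2.0, 3.0])

# Exactly equal tensors pass.
torch_assert_equal(t, t.clone())

# Any numerical difference fails, since rtol=atol=0.0 allows no tolerance.
try:
    torch_assert_equal(t, t + 1e-6)
except AssertionError as err:
    print(f"mismatch detected: {err}")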