From 82dce6d5aeea47491966c7650ba07870deca64e3 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 8 Jun 2024 19:18:43 -0700 Subject: [PATCH] fix flaky test --- tests/test_sharded_state_loader.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_sharded_state_loader.py b/tests/test_sharded_state_loader.py index 022fb36b346f..de79c3b945d4 100644 --- a/tests/test_sharded_state_loader.py +++ b/tests/test_sharded_state_loader.py @@ -39,7 +39,8 @@ def test_filter_subtensors(): filtered_state_dict = ShardedStateLoader._filter_subtensors(state_dict) assert tuple(filtered_state_dict.keys()) == ("a", "b", "c") for key, tensor in filtered_state_dict.items(): - assert tensor.equal(state_dict[key]) + # NOTE: don't use `euqal` here, as the tensor might contain NaNs + assert tensor is state_dict[key] @pytest.fixture(scope="module")