diff --git a/openfold3/core/utils/chunk_utils.py b/openfold3/core/utils/chunk_utils.py
index 4c9b850c..4037d519 100644
--- a/openfold3/core/utils/chunk_utils.py
+++ b/openfold3/core/utils/chunk_utils.py
@@ -372,17 +372,20 @@ def test_chunk_size(chunk_size):
             except RuntimeError:
                 return False
 
-        min_viable_chunk_size_index = 0
+        # Binary search for largest viable chunk size (min_chunk_size is assumed
+        # to be viable).
+        lo = 0
+        hi = len(candidates)
         i = len(candidates) - 1
-        while i > min_viable_chunk_size_index:
+        while lo < i < hi:
             viable = test_chunk_size(candidates[i])
-            if not viable:
-                i = (min_viable_chunk_size_index + i) // 2
+            if viable:
+                lo = i
             else:
-                min_viable_chunk_size_index = i
-                i = (i + len(candidates) - 1) // 2
+                hi = i
+            i = (lo + hi) // 2
 
-        return candidates[min_viable_chunk_size_index]
+        return candidates[lo]
 
     def _compare_arg_caches(self, ac1, ac2):
         consistent = True
diff --git a/openfold3/tests/utils/test_utils.py b/openfold3/tests/utils/test_utils.py
index 63a33b28..ed3907db 100644
--- a/openfold3/tests/utils/test_utils.py
+++ b/openfold3/tests/utils/test_utils.py
@@ -18,7 +18,7 @@
 import torch
 
 from openfold3.core.model.primitives import Linear
-from openfold3.core.utils.chunk_utils import _chunk_slice, chunk_layer
+from openfold3.core.utils.chunk_utils import ChunkSizeTuner, _chunk_slice, chunk_layer
 from openfold3.core.utils.rigid_utils import (
     Rigid,
     Rotation,
@@ -196,3 +196,51 @@ def test_chunk_slice_dict(self):
                 chunked_flattened = x_flat[i:j]
 
                 self.assertTrue(torch.all(chunked == chunked_flattened))
+
+    def test_chunk_size_tuner_does_not_retest_candidates(self):
+        # Based on previous bug: the binary search forgot which candidates it
+        # had already proven non-viable and re-tested them.
+        for max_viable in (128, 64, 256, 512):
+            with self.subTest(max_viable=max_viable):
+                tested = []
+
+                def fn(arg, chunk_size, _max=max_viable, tested=tested):
+                    tested.append(chunk_size)
+                    if chunk_size > _max:
+                        raise RuntimeError("simulated OOM")
+
+                ChunkSizeTuner._determine_favorable_chunk_size(
+                    fn, args=(None,), min_chunk_size=4, max_chunk_size=1024
+                )
+
+                self.assertEqual(
+                    len(tested),
+                    len(set(tested)),
+                    f"Some candidate was tested more than once: {tested}",
+                )
+
+    def test_chunk_size_tuner_picks_largest_viable(self):
+        # When the cutoff sits between two power-of-2 candidates, the tuner
+        # should pick the largest viable power of 2.
+        cases = [
+            # (max_viable, expected_chunk_size)
+            (1024, 1024),
+            (512, 512),
+            (511, 256),
+            (256, 256),
+            (255, 128),
+            (128, 128),
+            (4, 4),
+            (3, 4),  # nothing viable above min, returns min
+        ]
+        for max_viable, expected in cases:
+            with self.subTest(max_viable=max_viable):
+
+                def fn(arg, chunk_size, _max=max_viable):
+                    if chunk_size > _max:
+                        raise RuntimeError("simulated OOM")
+
+                result = ChunkSizeTuner._determine_favorable_chunk_size(
+                    fn, args=(None,), min_chunk_size=4, max_chunk_size=1024
+                )
+                self.assertEqual(result, expected)