Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 10 additions & 7 deletions openfold3/core/utils/chunk_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,17 +372,20 @@ def test_chunk_size(chunk_size):
except RuntimeError:
return False

min_viable_chunk_size_index = 0
# Binary search for largest viable chunk size (min_chunk_size is assumed
# to be viable).
lo = 0
hi = len(candidates)
i = len(candidates) - 1
while i > min_viable_chunk_size_index:
while lo < i < hi:
viable = test_chunk_size(candidates[i])
if not viable:
i = (min_viable_chunk_size_index + i) // 2
if viable:
lo = i
else:
min_viable_chunk_size_index = i
i = (i + len(candidates) - 1) // 2
hi = i
i = (lo + hi) // 2

return candidates[min_viable_chunk_size_index]
return candidates[lo]

def _compare_arg_caches(self, ac1, ac2):
consistent = True
Expand Down
50 changes: 49 additions & 1 deletion openfold3/tests/utils/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import torch

from openfold3.core.model.primitives import Linear
from openfold3.core.utils.chunk_utils import _chunk_slice, chunk_layer
from openfold3.core.utils.chunk_utils import ChunkSizeTuner, _chunk_slice, chunk_layer
from openfold3.core.utils.rigid_utils import (
Rigid,
Rotation,
Expand Down Expand Up @@ -196,3 +196,51 @@ def test_chunk_slice_dict(self):
chunked_flattened = x_flat[i:j]

self.assertTrue(torch.all(chunked == chunked_flattened))

def test_chunk_size_tuner_does_not_retest_candidates(self):
# Based on previous bug: the binary search forgot which candidates it
# had already proven non-viable and re-tested them.
for max_viable in (128, 64, 256, 512):
with self.subTest(max_viable=max_viable):
tested = []

def fn(arg, chunk_size, _max=max_viable, tested=tested):
tested.append(chunk_size)
if chunk_size > _max:
raise RuntimeError("simulated OOM")

ChunkSizeTuner._determine_favorable_chunk_size(
fn, args=(None,), min_chunk_size=4, max_chunk_size=1024
)

self.assertEqual(
len(tested),
len(set(tested)),
f"Some candidate was tested more than once: {tested}",
)

def test_chunk_size_tuner_picks_largest_viable(self):
# When the cutoff sits between two power-of-2 candidates, the tuner
# should pick the largest viable power of 2.
cases = [
# (max_viable, expected_chunk_size)
(1024, 1024),
(512, 512),
(511, 256),
(256, 256),
(255, 128),
(128, 128),
(4, 4),
(3, 4), # nothing viable above min, returns min
]
for max_viable, expected in cases:
with self.subTest(max_viable=max_viable):

def fn(arg, chunk_size, _max=max_viable):
if chunk_size > _max:
raise RuntimeError("simulated OOM")

result = ChunkSizeTuner._determine_favorable_chunk_size(
fn, args=(None,), min_chunk_size=4, max_chunk_size=1024
)
self.assertEqual(result, expected)