Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions openfold3/core/utils/chunk_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,9 @@ def test_chunk_size(chunk_size):
return candidates[lo]

def _compare_arg_caches(self, ac1, ac2):
# When recursing, this checks that tensors have the same rank
if len(ac1) != len(ac2):
return False
consistent = True
for a1, a2 in zip(ac1, ac2, strict=True):
assert type(a1) is type(a2)
Expand All @@ -412,12 +415,11 @@ def tune_chunk_size(
max_chunk_size=DEFAULT_MAX_CHUNK_SIZE,
) -> int:
def remove_tensors(a):
return a.shape if type(a) is torch.Tensor else a
return (a.shape, a.dtype.itemsize) if type(a) is torch.Tensor else a

arg_data = tree_map(remove_tensors, args, object)
if self.cached_arg_data is not None:
# If args have changed shape/value, we need to re-tune
assert len(self.cached_arg_data) == len(arg_data)
consistent = self._compare_arg_caches(self.cached_arg_data, arg_data)
else:
# Otherwise, we can reuse the precomputed value
Expand Down
111 changes: 111 additions & 0 deletions openfold3/tests/utils/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,42 @@ def test_chunk_slice_dict(self):

self.assertTrue(torch.all(chunked == chunked_flattened))

def test_chunk_size_tuner_caches(self):
    """Tuning twice with identically shaped/typed args must reuse the result.

    The representative function is wrapped in a mock so the test can
    observe that the second tuning pass does not invoke it again.
    """

    def limited_fn(t, chunk_size):
        # Largest viable chunk size is tied to tensor rank and element width.
        ceiling = 2 ** t.dim() * t.dtype.itemsize
        if chunk_size > ceiling:
            raise RuntimeError("Chunk size too large")
        return t

    spy_fn = unittest.mock.Mock(side_effect=limited_fn)
    tuner = ChunkSizeTuner()
    shared_kwargs = {
        "representative_fn": spy_fn,
        "min_chunk_size": 4,
        "max_chunk_size": 256,
    }

    first = tuner.tune_chunk_size(
        args=(torch.randn(2, 3, 4, 5),), **shared_kwargs
    )
    # Snapshot the call count before tuning again with equivalent args.
    calls_after_first = spy_fn.call_count

    second = tuner.tune_chunk_size(
        args=(torch.randn(2, 3, 4, 5),), **shared_kwargs
    )

    self.assertEqual(
        first,
        second,
        "Chunk size should have been cached for identical arg shapes and dtypes",
    )
    self.assertEqual(
        calls_after_first,
        spy_fn.call_count,
        "Representative function should not have been called again for identical arg shapes and dtypes",
    )

def test_chunk_size_tuner_does_not_retest_candidates(self):
# Based on previous bug: the binary search forgot which candidates it
# had already proven non-viable and re-tested them.
Expand Down Expand Up @@ -244,3 +280,78 @@ def fn(arg, chunk_size, _max=max_viable):
fn, args=(None,), min_chunk_size=4, max_chunk_size=1024
)
self.assertEqual(result, expected)

def test_chunk_size_tuner_handles_arg_rank_change(self):
    """A tensor argument of a different rank must yield a different chunk size."""

    def limited_fn(t, chunk_size):
        # The viability ceiling grows with tensor rank, so the two ranks
        # used below have different limits.
        ceiling = 2 ** t.dim() * t.dtype.itemsize
        if chunk_size > ceiling:
            raise RuntimeError("Chunk size too large")
        return t

    tuner = ChunkSizeTuner()
    shared_kwargs = {
        "representative_fn": limited_fn,
        "min_chunk_size": 4,
        "max_chunk_size": 256,
    }

    rank4_result = tuner.tune_chunk_size(
        args=(torch.zeros(2, 3, 4, 5),), **shared_kwargs
    )
    rank5_result = tuner.tune_chunk_size(
        args=(torch.zeros(2, 3, 4, 5, 6),), **shared_kwargs
    )

    self.assertNotEqual(
        rank4_result,
        rank5_result,
        "Chunk size should have been re-tuned for new arg rank",
    )

def test_chunk_size_tuner_handles_dtype_bytes_change(self):
    """Same shape with a different element width must yield a different chunk size."""

    def limited_fn(t, chunk_size):
        # The viability ceiling scales with the dtype's item size, so
        # float32 and bfloat16 inputs have different limits.
        ceiling = 2 ** t.dim() * t.dtype.itemsize
        if chunk_size > ceiling:
            raise RuntimeError("Chunk size too large")
        return t

    tuner = ChunkSizeTuner()
    shared_kwargs = {
        "representative_fn": limited_fn,
        "min_chunk_size": 4,
        "max_chunk_size": 256,
    }

    float32_result = tuner.tune_chunk_size(
        args=(torch.zeros(2, 3, 4, 5, dtype=torch.float32),), **shared_kwargs
    )
    bfloat16_result = tuner.tune_chunk_size(
        args=(torch.zeros(2, 3, 4, 5, dtype=torch.bfloat16),), **shared_kwargs
    )

    self.assertNotEqual(
        float32_result,
        bfloat16_result,
        "Chunk size should have been re-tuned for new dtype bytes",
    )

def test_chunk_size_tuner_handles_arg_count_change(self):
    """A different number of positional args must yield a different chunk size."""

    def limited_fn(*args, chunk_size):
        # The viability ceiling depends only on how many args were passed.
        ceiling = 2 ** len(args)
        if chunk_size > ceiling:
            raise RuntimeError("Chunk size too large")
        return args

    tuner = ChunkSizeTuner()
    shared_kwargs = {
        "representative_fn": limited_fn,
        "min_chunk_size": 4,
        "max_chunk_size": 256,
    }

    five_arg_result = tuner.tune_chunk_size(args=(1, 2, 3, 4, 5), **shared_kwargs)
    six_arg_result = tuner.tune_chunk_size(
        args=(1, 2, 3, 4, 5, 6), **shared_kwargs
    )

    self.assertNotEqual(
        five_arg_result,
        six_arg_result,
        "Chunk size should have been re-tuned for new arg count",
    )