Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions openfold3/core/utils/chunk_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,9 @@ def test_chunk_size(chunk_size):
return candidates[lo]

def _compare_arg_caches(self, ac1, ac2):
# When recursing this tests that tensors have the same rank
if len(ac1) != len(ac2):
return False
consistent = True
for a1, a2 in zip(ac1, ac2, strict=True):
assert type(a1) is type(a2)
Expand All @@ -412,12 +415,11 @@ def tune_chunk_size(
max_chunk_size=DEFAULT_MAX_CHUNK_SIZE,
) -> int:
def remove_tensors(a):
return a.shape if type(a) is torch.Tensor else a
return (a.shape, a.dtype.itemsize) if type(a) is torch.Tensor else a

arg_data = tree_map(remove_tensors, args, object)
if self.cached_arg_data is not None:
# If args have changed shape/value, we need to re-tune
assert len(self.cached_arg_data) == len(arg_data)
consistent = self._compare_arg_caches(self.cached_arg_data, arg_data)
else:
# Otherwise, we can reuse the precomputed value
Expand Down
75 changes: 75 additions & 0 deletions openfold3/tests/utils/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,3 +244,78 @@ def fn(arg, chunk_size, _max=max_viable):
fn, args=(None,), min_chunk_size=4, max_chunk_size=1024
)
self.assertEqual(result, expected)

def test_chunk_size_tuner_handles_arg_rank_change(self):
tuner = ChunkSizeTuner()

def fn(t, chunk_size):
if chunk_size > 2 ** t.dim() * t.dtype.itemsize:
raise RuntimeError("Chunk size too large")
return t

first = tuner.tune_chunk_size(
representative_fn=fn,
args=(torch.zeros(2, 3, 4, 5),),
min_chunk_size=4,
max_chunk_size=256,
)
second = tuner.tune_chunk_size(
representative_fn=fn,
args=(torch.zeros(2, 3, 4, 5, 6),),
min_chunk_size=4,
max_chunk_size=256,
)

self.assertNotEqual(
first, second, "Chunk size should have been re-tuned for new arg rank"
)

def test_chunk_size_tuner_handles_dtype_bytes_change(self):
tuner = ChunkSizeTuner()

def fn(t, chunk_size):
if chunk_size > 2 ** t.dim() * t.dtype.itemsize:
raise RuntimeError("Chunk size too large")
return t

first = tuner.tune_chunk_size(
representative_fn=fn,
args=(torch.zeros(2, 3, 4, 5, dtype=torch.float32),),
min_chunk_size=4,
max_chunk_size=256,
)
second = tuner.tune_chunk_size(
representative_fn=fn,
args=(torch.zeros(2, 3, 4, 5, dtype=torch.bfloat16),),
min_chunk_size=4,
max_chunk_size=256,
)

self.assertNotEqual(
first, second, "Chunk size should have been re-tuned for new dtype bytes"
)

def test_chunk_size_tuner_handles_arg_count_change(self):
tuner = ChunkSizeTuner()

def fn(*args, chunk_size):
if chunk_size > 2 ** len(args):
raise RuntimeError("Chunk size too large")
return args

first = tuner.tune_chunk_size(
representative_fn=fn,
args=(1, 2, 3, 4, 5),
min_chunk_size=4,
max_chunk_size=256,
)
second = tuner.tune_chunk_size(
representative_fn=fn,
args=(1, 2, 3, 4, 5, 6),
min_chunk_size=4,
max_chunk_size=256,
)

self.assertNotEqual(
first, second, "Chunk size should have been re-tuned for new arg count"
)