Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 8 additions & 7 deletions openfold3/core/model/latent/base_stacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@

from openfold3.core.utils.checkpointing import checkpoint_blocks
from openfold3.core.utils.chunk_utils import (
CUEQ_MAX_CHUNK_SIZE,
DEFAULT_MAX_CHUNK_SIZE,
FLASH_MAX_CHUNK_SIZE,
ChunkSizeTuner,
)

Expand Down Expand Up @@ -126,10 +126,13 @@ def block_with_cache_clear(block, *args, **kwargs):

if chunk_size is not None and self.chunk_size_tuner is not None:
assert not self.training
use_flash_kernels = (
use_cueq_triangle_kernels
or use_triton_triangle_kernels
or use_deepspeed_evo_attention
)
max_chunk_size = (
CUEQ_MAX_CHUNK_SIZE
if use_cueq_triangle_kernels
else DEFAULT_MAX_CHUNK_SIZE
FLASH_MAX_CHUNK_SIZE if use_flash_kernels else DEFAULT_MAX_CHUNK_SIZE
)
tuned_chunk_size = self.chunk_size_tuner.tune_chunk_size(
representative_fn=blocks[0],
Expand All @@ -144,9 +147,7 @@ def block_with_cache_clear(block, *args, **kwargs):
max_chunk_size=max_chunk_size,
)
attn_chunk = (
tuned_chunk_size
if use_cueq_triangle_kernels
else (tuned_chunk_size // 4)
tuned_chunk_size if use_flash_kernels else (tuned_chunk_size // 4)
)
blocks = [
partial(
Expand Down
15 changes: 8 additions & 7 deletions openfold3/core/model/latent/pairformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@
from openfold3.core.model.layers.transition import SwiGLUTransition
from openfold3.core.utils.checkpointing import checkpoint_blocks
from openfold3.core.utils.chunk_utils import (
CUEQ_MAX_CHUNK_SIZE,
DEFAULT_MAX_CHUNK_SIZE,
FLASH_MAX_CHUNK_SIZE,
ChunkSizeTuner,
)
from openfold3.core.utils.tensor_utils import add
Expand Down Expand Up @@ -371,10 +371,13 @@ def block_with_cache_clear(block, *args, **kwargs):

if chunk_size is not None and self.chunk_size_tuner is not None:
assert not self.training
use_flash_kernels = (
use_cueq_triangle_kernels
or use_triton_triangle_kernels
or use_deepspeed_evo_attention
)
max_chunk_size = (
CUEQ_MAX_CHUNK_SIZE
if use_cueq_triangle_kernels
else DEFAULT_MAX_CHUNK_SIZE
FLASH_MAX_CHUNK_SIZE if use_flash_kernels else DEFAULT_MAX_CHUNK_SIZE
)
tuned_chunk_size = self.chunk_size_tuner.tune_chunk_size(
representative_fn=blocks[0],
Expand All @@ -387,9 +390,7 @@ def block_with_cache_clear(block, *args, **kwargs):
max_chunk_size=max_chunk_size,
)
attn_chunk = (
tuned_chunk_size
if use_cueq_triangle_kernels
else (tuned_chunk_size // 4)
tuned_chunk_size if use_flash_kernels else (tuned_chunk_size // 4)
)
blocks = [
partial(
Expand Down
15 changes: 8 additions & 7 deletions openfold3/core/model/latent/template_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@
from openfold3.core.model.primitives import LayerNorm, Linear
from openfold3.core.utils.checkpointing import checkpoint_blocks, checkpoint_section
from openfold3.core.utils.chunk_utils import (
CUEQ_MAX_CHUNK_SIZE,
DEFAULT_MAX_CHUNK_SIZE,
FLASH_MAX_CHUNK_SIZE,
ChunkSizeTuner,
)
from openfold3.core.utils.tensor_utils import add
Expand Down Expand Up @@ -387,10 +387,13 @@ def _prep_blocks(

if chunk_size is not None and self.chunk_size_tuner is not None:
assert not self.training
use_flash_kernels = (
use_cueq_triangle_kernels
or use_triton_triangle_kernels
or use_deepspeed_evo_attention
)
max_chunk_size = (
CUEQ_MAX_CHUNK_SIZE
if use_cueq_triangle_kernels
else DEFAULT_MAX_CHUNK_SIZE
FLASH_MAX_CHUNK_SIZE if use_flash_kernels else DEFAULT_MAX_CHUNK_SIZE
)
tuned_chunk_size = self.chunk_size_tuner.tune_chunk_size(
representative_fn=blocks[0],
Expand All @@ -399,9 +402,7 @@ def _prep_blocks(
max_chunk_size=max_chunk_size,
)
attn_chunk = (
tuned_chunk_size
if use_cueq_triangle_kernels
else (tuned_chunk_size // 4)
tuned_chunk_size if use_flash_kernels else (tuned_chunk_size // 4)
)
blocks = [
partial(
Expand Down
2 changes: 1 addition & 1 deletion openfold3/core/utils/chunk_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
)

DEFAULT_MAX_CHUNK_SIZE = 512
CUEQ_MAX_CHUNK_SIZE = 1024
FLASH_MAX_CHUNK_SIZE = 1024


def _fetch_dims(tree):
Expand Down
Loading