diff --git a/openfold3/core/model/latent/base_stacks.py b/openfold3/core/model/latent/base_stacks.py index 87a37f37..22896844 100644 --- a/openfold3/core/model/latent/base_stacks.py +++ b/openfold3/core/model/latent/base_stacks.py @@ -29,8 +29,8 @@ from openfold3.core.utils.checkpointing import checkpoint_blocks from openfold3.core.utils.chunk_utils import ( - CUEQ_MAX_CHUNK_SIZE, DEFAULT_MAX_CHUNK_SIZE, + FLASH_MAX_CHUNK_SIZE, ChunkSizeTuner, ) @@ -126,10 +126,13 @@ def block_with_cache_clear(block, *args, **kwargs): if chunk_size is not None and self.chunk_size_tuner is not None: assert not self.training + use_flash_kernels = ( + use_cueq_triangle_kernels + or use_triton_triangle_kernels + or use_deepspeed_evo_attention + ) max_chunk_size = ( - CUEQ_MAX_CHUNK_SIZE - if use_cueq_triangle_kernels - else DEFAULT_MAX_CHUNK_SIZE + FLASH_MAX_CHUNK_SIZE if use_flash_kernels else DEFAULT_MAX_CHUNK_SIZE ) tuned_chunk_size = self.chunk_size_tuner.tune_chunk_size( representative_fn=blocks[0], @@ -144,9 +147,7 @@ def block_with_cache_clear(block, *args, **kwargs): max_chunk_size=max_chunk_size, ) attn_chunk = ( - tuned_chunk_size - if use_cueq_triangle_kernels - else (tuned_chunk_size // 4) + tuned_chunk_size if use_flash_kernels else (tuned_chunk_size // 4) ) blocks = [ partial( diff --git a/openfold3/core/model/latent/pairformer.py b/openfold3/core/model/latent/pairformer.py index 6f75923d..8ef7b354 100644 --- a/openfold3/core/model/latent/pairformer.py +++ b/openfold3/core/model/latent/pairformer.py @@ -29,8 +29,8 @@ from openfold3.core.model.layers.transition import SwiGLUTransition from openfold3.core.utils.checkpointing import checkpoint_blocks from openfold3.core.utils.chunk_utils import ( - CUEQ_MAX_CHUNK_SIZE, DEFAULT_MAX_CHUNK_SIZE, + FLASH_MAX_CHUNK_SIZE, ChunkSizeTuner, ) from openfold3.core.utils.tensor_utils import add @@ -371,10 +371,13 @@ def block_with_cache_clear(block, *args, **kwargs): if chunk_size is not None and self.chunk_size_tuner is not None: assert not self.training + use_flash_kernels = ( + use_cueq_triangle_kernels + or use_triton_triangle_kernels + or use_deepspeed_evo_attention + ) max_chunk_size = ( - CUEQ_MAX_CHUNK_SIZE - if use_cueq_triangle_kernels - else DEFAULT_MAX_CHUNK_SIZE + FLASH_MAX_CHUNK_SIZE if use_flash_kernels else DEFAULT_MAX_CHUNK_SIZE ) tuned_chunk_size = self.chunk_size_tuner.tune_chunk_size( representative_fn=blocks[0], @@ -387,9 +390,7 @@ def block_with_cache_clear(block, *args, **kwargs): max_chunk_size=max_chunk_size, ) attn_chunk = ( - tuned_chunk_size - if use_cueq_triangle_kernels - else (tuned_chunk_size // 4) + tuned_chunk_size if use_flash_kernels else (tuned_chunk_size // 4) ) blocks = [ partial( diff --git a/openfold3/core/model/latent/template_module.py b/openfold3/core/model/latent/template_module.py index cbd00ef8..d0d040be 100644 --- a/openfold3/core/model/latent/template_module.py +++ b/openfold3/core/model/latent/template_module.py @@ -35,8 +35,8 @@ from openfold3.core.model.primitives import LayerNorm, Linear from openfold3.core.utils.checkpointing import checkpoint_blocks, checkpoint_section from openfold3.core.utils.chunk_utils import ( - CUEQ_MAX_CHUNK_SIZE, DEFAULT_MAX_CHUNK_SIZE, + FLASH_MAX_CHUNK_SIZE, ChunkSizeTuner, ) from openfold3.core.utils.tensor_utils import add @@ -387,10 +387,13 @@ def _prep_blocks( if chunk_size is not None and self.chunk_size_tuner is not None: assert not self.training + use_flash_kernels = ( + use_cueq_triangle_kernels + or use_triton_triangle_kernels + or use_deepspeed_evo_attention + ) max_chunk_size = ( - CUEQ_MAX_CHUNK_SIZE - if use_cueq_triangle_kernels - else DEFAULT_MAX_CHUNK_SIZE + FLASH_MAX_CHUNK_SIZE if use_flash_kernels else DEFAULT_MAX_CHUNK_SIZE ) tuned_chunk_size = self.chunk_size_tuner.tune_chunk_size( representative_fn=blocks[0], @@ -399,9 +402,7 @@ def _prep_blocks( max_chunk_size=max_chunk_size, ) attn_chunk = ( - tuned_chunk_size - if use_cueq_triangle_kernels - else (tuned_chunk_size // 4) + tuned_chunk_size if use_flash_kernels else (tuned_chunk_size // 4) ) blocks = [ partial( diff --git a/openfold3/core/utils/chunk_utils.py b/openfold3/core/utils/chunk_utils.py index c89e1d1c..461b56f6 100644 --- a/openfold3/core/utils/chunk_utils.py +++ b/openfold3/core/utils/chunk_utils.py @@ -26,7 +26,7 @@ ) DEFAULT_MAX_CHUNK_SIZE = 512 -CUEQ_MAX_CHUNK_SIZE = 1024 +FLASH_MAX_CHUNK_SIZE = 1024 def _fetch_dims(tree):