diff --git a/openfold3/core/kernels/triton/evoformer.py b/openfold3/core/kernels/triton/evoformer.py index 03a40ff1..297795e6 100644 --- a/openfold3/core/kernels/triton/evoformer.py +++ b/openfold3/core/kernels/triton/evoformer.py @@ -904,6 +904,22 @@ def forward(ctx, Q, K, V, res_mask, pair_bias, has_pair_bias=True): ).contiguous() # (BATCH_SIZE, N_SEQ, HEAD, SEQ_LEN, DIM) BATCH_SIZE, N_SEQ, HEAD, SEQ_LEN, DIM = Q.shape + + assert res_mask.shape == ( + BATCH_SIZE, + N_SEQ, + 1, + 1, + SEQ_LEN, + ), f"{tuple(res_mask.shape)} != {(BATCH_SIZE, N_SEQ, 1, 1, SEQ_LEN)}" + assert pair_bias.shape == ( + BATCH_SIZE, + 1, + HEAD, + SEQ_LEN, + SEQ_LEN, + ), f"{tuple(pair_bias.shape)} != {(BATCH_SIZE, 1, HEAD, SEQ_LEN, SEQ_LEN)}" + softmax_scale = DIM**-0.5 BLOCK_DIM = max(triton.next_power_of_2(DIM), 32) diff --git a/openfold3/core/model/heads/head_modules.py b/openfold3/core/model/heads/head_modules.py index 46a132ac..92830d48 100644 --- a/openfold3/core/model/heads/head_modules.py +++ b/openfold3/core/model/heads/head_modules.py @@ -187,8 +187,16 @@ def forward( apply_per_sample = ( not torch.is_grad_enabled() and num_samples > 1 - and self.per_sample_token_cutoff is not None - and repr_x_pred.shape[-2] > self.per_sample_token_cutoff + and ( + (self.per_sample_token_cutoff is not None + and repr_x_pred.shape[-2] > self.per_sample_token_cutoff) + # The optimized attention kernels do not support cross-sample + # chunking because it requires expanding the pair bias. For now + # we just always apply per sample if these kernels are in use. 
+ or use_deepspeed_evo_attention + or use_cueq_triangle_kernels + or use_triton_triangle_kernels + ) ) out_device = atom_positions_predicted.device diff --git a/openfold3/core/model/heads/prediction_heads.py b/openfold3/core/model/heads/prediction_heads.py index e957bc68..d89de0fd 100644 --- a/openfold3/core/model/heads/prediction_heads.py +++ b/openfold3/core/model/heads/prediction_heads.py @@ -209,10 +209,11 @@ def reshape_outputs(x: torch.Tensor, feat_dims: list): single_mask = reshape_inputs(x=single_mask, feat_dims=single_mask.shape[-1:]) pair_mask = reshape_inputs(x=pair_mask, feat_dims=pair_mask.shape[-2:]) - # Using the DS kernel with chunk tuning and multiple samples causes shape issues - # in the DS kernel. To avoid this, chunk tuning is disabled in this case. - # TODO: cuEq seems to fail comparison unit tests with the same settings, - # disable for now and verify behavior + # The optimized kernels all require that the pair bias have size 1 in + # its second dimension, whereas cross-sample chunking has to combine + # the batch dimensions and expand the pair bias. We mostly avoid this + # path entirely by splitting per-sample when using the optimized + # kernels, but the guard below avoids a potential correctness issue. 
use_kernels = ( use_deepspeed_evo_attention or use_cueq_triangle_kernels diff --git a/openfold3/core/utils/chunk_utils.py b/openfold3/core/utils/chunk_utils.py index c89e1d1c..af870cc7 100644 --- a/openfold3/core/utils/chunk_utils.py +++ b/openfold3/core/utils/chunk_utils.py @@ -363,7 +363,6 @@ def _determine_favorable_chunk_size(fn, args, min_chunk_size, max_chunk_size): candidates = [2**l for l in range(int(math.log(max_chunk_size, 2)) + 1)] candidates = [c for c in candidates if c > min_chunk_size] candidates = [min_chunk_size] + candidates - candidates[-1] += 4 def test_chunk_size(chunk_size): try: diff --git a/openfold3/tests/test_kernels.py b/openfold3/tests/test_kernels.py index ed5a0111..490068fc 100644 --- a/openfold3/tests/test_kernels.py +++ b/openfold3/tests/test_kernels.py @@ -719,13 +719,13 @@ def _compare_template_stack( chunk_size=None, ): """ - Compare Template Stack output with and without using DeepSpeed Evoformer - attention kernel. Kernel can be used for Triangle Attention in the Template Pair - Stack. + Compare Template Stack output with and without using different optimized + attention kernels. Kernel can be used for Triangle Attention in the + Template Pair Stack. 
""" batch_size = consts.batch_size - if chunk_size is not None and use_deepspeed_evo_attention: - # Chunk tuning is not supported with batch size > 1 for DeepSpeed kernel + if chunk_size is not None: + # Chunking is not supported with batch size > 1 for optimized kernels batch_size = 1 n_templ = 3 @@ -793,7 +793,6 @@ def to_device(t): def test_compare_template_stack_dsk_fp32(self): self._compare_template_stack( use_deepspeed_evo_attention=True, - use_cueq_triangle_kernels=False, dtype=torch.float32, ) @@ -801,7 +800,6 @@ def test_compare_template_stack_dsk_fp32(self): def test_compare_template_stack_dsk_bf16(self): self._compare_template_stack( use_deepspeed_evo_attention=True, - use_cueq_triangle_kernels=False, dtype=torch.bfloat16, ) @@ -809,7 +807,6 @@ def test_compare_template_stack_dsk_bf16(self): def test_compare_template_stack_dsk_fp32_chunk(self): self._compare_template_stack( use_deepspeed_evo_attention=True, - use_cueq_triangle_kernels=False, dtype=torch.float32, chunk_size=4, ) @@ -817,7 +814,6 @@ def test_compare_template_stack_dsk_fp32_chunk(self): @compare_utils.skip_unless_cueq_installed() def test_compare_template_stack_cueq_fp32(self): self._compare_template_stack( - use_deepspeed_evo_attention=False, use_cueq_triangle_kernels=True, dtype=torch.float32, ) @@ -825,7 +821,6 @@ def test_compare_template_stack_cueq_fp32(self): @compare_utils.skip_unless_cueq_installed() def test_compare_template_stack_cueq_bf16(self): self._compare_template_stack( - use_deepspeed_evo_attention=False, use_cueq_triangle_kernels=True, dtype=torch.bfloat16, ) @@ -833,7 +828,6 @@ def test_compare_template_stack_cueq_bf16(self): @compare_utils.skip_unless_cueq_installed() def test_compare_template_stack_cueq_fp32_chunk(self): self._compare_template_stack( - use_deepspeed_evo_attention=False, use_cueq_triangle_kernels=True, dtype=torch.float32, chunk_size=4, @@ -842,8 +836,6 @@ def test_compare_template_stack_cueq_fp32_chunk(self): 
@compare_utils.skip_unless_triton_installed() def test_compare_template_stack_triton_fp32_chunk(self): self._compare_template_stack( - use_deepspeed_evo_attention=False, - use_cueq_triangle_kernels=False, use_triton_triangle_kernels=True, dtype=torch.float32, chunk_size=4, @@ -852,8 +844,6 @@ def test_compare_template_stack_triton_fp32_chunk(self): @compare_utils.skip_unless_triton_installed() def test_compare_template_stack_triton_fp32(self): self._compare_template_stack( - use_deepspeed_evo_attention=False, - use_cueq_triangle_kernels=False, use_triton_triangle_kernels=True, dtype=torch.float32, ) @@ -861,8 +851,6 @@ def test_compare_template_stack_triton_fp32(self): @compare_utils.skip_unless_triton_installed() def test_compare_template_stack_triton_bf16(self): self._compare_template_stack( - use_deepspeed_evo_attention=False, - use_cueq_triangle_kernels=False, use_triton_triangle_kernels=True, dtype=torch.bfloat16, )