diff --git a/openfold3/core/kernels/triton/evoformer.py b/openfold3/core/kernels/triton/evoformer.py index 03a40ff1..297795e6 100644 --- a/openfold3/core/kernels/triton/evoformer.py +++ b/openfold3/core/kernels/triton/evoformer.py @@ -904,6 +904,22 @@ def forward(ctx, Q, K, V, res_mask, pair_bias, has_pair_bias=True): ).contiguous() # (BATCH_SIZE, N_SEQ, HEAD, SEQ_LEN, DIM) BATCH_SIZE, N_SEQ, HEAD, SEQ_LEN, DIM = Q.shape + + assert res_mask.shape == ( + BATCH_SIZE, + N_SEQ, + 1, + 1, + SEQ_LEN, + ), f"{tuple(res_mask.shape)} != {(BATCH_SIZE, N_SEQ, 1, 1, SEQ_LEN)}" + assert pair_bias.shape == ( + BATCH_SIZE, + 1, + HEAD, + SEQ_LEN, + SEQ_LEN, + ), f"{tuple(pair_bias.shape)} != {(BATCH_SIZE, 1, HEAD, SEQ_LEN, SEQ_LEN)}" + softmax_scale = DIM**-0.5 BLOCK_DIM = max(triton.next_power_of_2(DIM), 32) diff --git a/openfold3/core/model/heads/prediction_heads.py b/openfold3/core/model/heads/prediction_heads.py index f43cede3..c446f280 100644 --- a/openfold3/core/model/heads/prediction_heads.py +++ b/openfold3/core/model/heads/prediction_heads.py @@ -207,10 +207,9 @@ def reshape_outputs(x: torch.Tensor, feat_dims: list): single_mask = reshape_inputs(x=single_mask, feat_dims=single_mask.shape[-1:]) pair_mask = reshape_inputs(x=pair_mask, feat_dims=pair_mask.shape[-2:]) - # Using the DS kernel with chunk tuning and multiple samples causes shape issues - # in the DS kernel. To avoid this, chunk tuning is disabled in this case. - # TODO: cuEq seems to fail comparison unit tests with the same settings, - # disable for now and verify behavior + # The optimized kernels all require that pair bias have size 1 in the + # second dimension and cross-sample chunking has to combine the batch + # dimensions and expand it. use_kernels = ( use_deepspeed_evo_attention or use_cueq_triangle_kernels diff --git a/openfold3/tests/test_kernels.py b/openfold3/tests/test_kernels.py index ed5a0111..2c3486e8 100644 --- a/openfold3/tests/test_kernels.py +++ b/openfold3/tests/test_kernels.py @@ -438,8 +438,7 @@ def _compare_pairformer( if chunk_size is not None and ( use_deepspeed_evo_attention or use_triton_triangle_kernels ): - # Chunk tuning is not supported with batch size > 1 for DeepSpeed kernel - # Triton kernel works with these shapes but has some numerical differences + # Chunk tuning is not supported with batch size > 1 for these kernels batch_size = 1 n_res = 200 # Avoid cuEq seq len constraints @@ -719,13 +718,15 @@ def _compare_template_stack( chunk_size=None, ): """ - Compare Template Stack output with and without using DeepSpeed Evoformer - attention kernel. Kernel can be used for Triangle Attention in the Template Pair - Stack. + Compare Template Stack output with and without using different optimized + attention kernels. Kernel can be used for Triangle Attention in the + Template Pair Stack. """ batch_size = consts.batch_size - if chunk_size is not None and use_deepspeed_evo_attention: - # Chunk tuning is not supported with batch size > 1 for DeepSpeed kernel + if chunk_size is not None and ( + use_deepspeed_evo_attention or use_triton_triangle_kernels + ): + # Chunk tuning is not supported with batch size > 1 for these kernels batch_size = 1 n_templ = 3 @@ -793,7 +794,6 @@ def to_device(t): def test_compare_template_stack_dsk_fp32(self): self._compare_template_stack( use_deepspeed_evo_attention=True, - use_cueq_triangle_kernels=False, dtype=torch.float32, ) @@ -801,7 +801,6 @@ def test_compare_template_stack_dsk_fp32(self): def test_compare_template_stack_dsk_bf16(self): self._compare_template_stack( use_deepspeed_evo_attention=True, - use_cueq_triangle_kernels=False, dtype=torch.bfloat16, ) @@ -809,7 +808,6 @@ def test_compare_template_stack_dsk_bf16(self): def test_compare_template_stack_dsk_fp32_chunk(self): self._compare_template_stack( use_deepspeed_evo_attention=True, - use_cueq_triangle_kernels=False, dtype=torch.float32, chunk_size=4, ) @@ -817,7 +815,6 @@ def test_compare_template_stack_dsk_fp32_chunk(self): @compare_utils.skip_unless_cueq_installed() def test_compare_template_stack_cueq_fp32(self): self._compare_template_stack( - use_deepspeed_evo_attention=False, use_cueq_triangle_kernels=True, dtype=torch.float32, ) @@ -825,7 +822,6 @@ def test_compare_template_stack_cueq_fp32(self): @compare_utils.skip_unless_cueq_installed() def test_compare_template_stack_cueq_bf16(self): self._compare_template_stack( - use_deepspeed_evo_attention=False, use_cueq_triangle_kernels=True, dtype=torch.bfloat16, ) @@ -833,7 +829,6 @@ def test_compare_template_stack_cueq_bf16(self): @compare_utils.skip_unless_cueq_installed() def test_compare_template_stack_cueq_fp32_chunk(self): self._compare_template_stack( - use_deepspeed_evo_attention=False, use_cueq_triangle_kernels=True, dtype=torch.float32, chunk_size=4, @@ -842,8 +837,6 @@ def test_compare_template_stack_cueq_fp32_chunk(self): @compare_utils.skip_unless_triton_installed() def test_compare_template_stack_triton_fp32_chunk(self): self._compare_template_stack( - use_deepspeed_evo_attention=False, - use_cueq_triangle_kernels=False, use_triton_triangle_kernels=True, dtype=torch.float32, chunk_size=4, @@ -852,8 +845,6 @@ def test_compare_template_stack_triton_fp32_chunk(self): @compare_utils.skip_unless_triton_installed() def test_compare_template_stack_triton_fp32(self): self._compare_template_stack( - use_deepspeed_evo_attention=False, - use_cueq_triangle_kernels=False, use_triton_triangle_kernels=True, dtype=torch.float32, ) @@ -861,8 +852,6 @@ def test_compare_template_stack_triton_fp32(self): @compare_utils.skip_unless_triton_installed() def test_compare_template_stack_triton_bf16(self): self._compare_template_stack( - use_deepspeed_evo_attention=False, - use_cueq_triangle_kernels=False, use_triton_triangle_kernels=True, dtype=torch.bfloat16, )