Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions openfold3/core/model/primitives/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -773,10 +773,12 @@ def _cueq_triangle_attn(q, k, v, biases, scale):
triangle_bias = triangle_bias.view(batch * n_tmpl, *triangle_bias.shape[2:])

# 4D → 5D: chunk_layer flattens batch dims and slices into chunks.
# chunk_layer skips expanding bias when all its batch dims are 1,
# so bias may have B=1 while q has B=chunk. In this case, we're good - otherwise:
# Promote to 5D with N=1 so each chunk entry is an independent batch item.
# cuequivariance >=0.8 requires bias shape (B, 1, H, Q, K) with exact
# batch match — no implicit broadcasting.
is_chunked_input = len(q.shape) == 4
is_chunked_input = len(q.shape) == 4 and triangle_bias.shape[0] > 1
if is_chunked_input:
# q: (chunk, H, S, D) → (chunk, 1, H, S, D)
q = q.unsqueeze(1)
Expand All @@ -787,8 +789,7 @@ def _cueq_triangle_attn(q, k, v, biases, scale):
# triangle_bias: (chunk, H, S, S) → (chunk, 1, H, S, S)
# or: (1, H, S, S) → (1, 1, H, S, S) when chunk_layer kept B=1
triangle_bias = triangle_bias.unsqueeze(1)
# chunk_layer skips expanding bias when all its batch dims are 1,
# so bias may have B=1 while q has B=chunk. Expand to match.
# This should not happen. Just in case. Expand to match.
if triangle_bias.shape[0] != q.shape[0]:
# (1, 1, H, S, S) → (chunk, 1, H, S, S)
triangle_bias = triangle_bias.expand(q.shape[0], *triangle_bias.shape[1:])
Expand All @@ -803,11 +804,14 @@ def _cueq_triangle_attn(q, k, v, biases, scale):
o = triangle_attention(q, k, v, bias=triangle_bias, mask=mask_bias, scale=scale)

# Undo the promotions in reverse order.
if is_chunked_input:
if len(q.shape) == 4:
##VS: There's a bug in cueq where if the input is missing the batch dim
## the outputs adds it in and so we need to remove it here
o = o.squeeze(0)
elif is_chunked_input:
# (chunk, 1, H, S, D) → (chunk, H, S, D)
o = o.squeeze(1)

if is_batched_input:
elif is_batched_input:
# (batch*n_tmpl, N, H, S, D) → (batch, n_tmpl, N, H, S, D)
o = o.view(batch, n_tmpl, *o.shape[1:])

Expand Down
5 changes: 2 additions & 3 deletions openfold3/tests/test_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,11 +436,10 @@ def _compare_pairformer(
"""
batch_size = consts.batch_size
if chunk_size is not None and (
use_deepspeed_evo_attention
or use_cueq_triangle_kernels
or use_triton_triangle_kernels
use_deepspeed_evo_attention or use_triton_triangle_kernels
):
# Chunk tuning is not supported with batch size > 1 for DeepSpeed kernel
# Triton kernel works with these shapes but has some numerical differences
batch_size = 1

n_res = 200 # Avoid cuEq seq len constraints
Expand Down
Loading