Skip to content

cuequivariance support in pixi (including aarch64)#181

Merged
jandom merged 31 commits into
mainfrom
pixi-beta-cuequivariance
Apr 28, 2026
Merged

cuequivariance support in pixi (including aarch64)#181
jandom merged 31 commits into
mainfrom
pixi-beta-cuequivariance

Conversation

@jandom

@jandom jandom commented Apr 15, 2026

Copy link
Copy Markdown
Collaborator

Summary

Adds cuequivariance support for linux-aarch64 (e.g. DGX Spark / GB10) and upgrades the version pin from <0.8 to >=0.8 across all platforms. This required adapting _cueq_triangle_attn for a breaking API change in cuequivariance 0.8+.

Changes

pixi.toml / pixi.lock

  • Added linux-aarch64 pypi-dependency targets for both cuequivariance-cuda12 and cuequivariance-cuda13 features
  • Unified all cuequivariance pins to >=0.8 (previously <0.8 on linux-64)
    • <0.8 was originally pinned because of a triangle_attention bias shape API change — now fixed in the attention code
    • 0.7.0 crashes on import on GB10/aarch64 due to an upstream pynvml bug (fixed in 0.8+)
  • All platforms now resolve to cuequivariance 0.9.1

openfold3/core/model/primitives/attention.py

  • Updated _cueq_triangle_attn for cuequivariance 0.8+ API: triangle_attention now requires bias shape (B, 1, H, Q, K) with exact batch dim match (no implicit broadcasting)
  • 4D chunked inputs: chunk_layer flattens batch dims, producing 4D tensors. These are promoted to 5D with N=1 so each chunk entry is an independent batch item. Bias batch dim is expanded when chunk_layer kept it at B=1.
  • 6D template inputs: unchanged (already worked — the view preserves the 1 in bias)

openfold3/core/kernels/cueq_utils.py

  • Added is_cuequivariance_installed() — package-only check (no CUDA requirement)
  • is_cuequivariance_available() refactored to call it internally (no behavior change)

openfold3/tests/compare_utils.py

  • skip_unless_cueq_installed() now produces distinct skip messages:
    • "Requires cuequivariance to be installed" (package missing)
    • "Requires CUDA (cuequivariance is installed but no GPU available)" (no GPU)
  • Fixes typo "Equivaraince" → "cuequivariance"

Testing

Known issue

N/A

sdvillal and others added 27 commits March 23, 2026 12:33
* Add initial pixi environment

all tests pass, predictions seem to be correct
corresponds to a modernized conda environment following best practices

* Reorder dependencies for easier read

* Add openfold3 as an editable dependency

* Sync cuda-python pin between pypi package and the conda environment

* Comments

Comments

Overcommenting issues

* Add explicitly a conda yml version of the pixi environment

* Improve some wordings

* Update pixi lockfile

* Vendoring pieces of deepspeed

incomplete, we might not need the native sources
from upstream commit df59f203f40c8a292dd019ae68c9e6c88f107026

* Swap ninja verification with pytorch's

* Vendoring pieces of deepspeed

incomplete, we might not need the native sources
from upstream commit df59f203f40c8a292dd019ae68c9e6c88f107026

* Use vendored deepspeed evoformer builder

Use vendored deepspeed in the attention primitives

* Add symlink to vendored deepspeed as in upstream

* Vendor also op_builder.__init__ from deepspeed

* Import explicitly EvoformerAttnBuilder, avoiding broken introspection magic

* Add a ignore mechanism for cutlass detection in vendored deepspeed

* Apply cutlass detection workaround and remove all nvidia-cutlass tricks from pixi environment

* Remove nvidia-cutlass from openfold-3 dependencies (fix later)

* Remove pypi ninja dependency in pixi workspace

* No need for cutlass hacks

* Add pixi config to .gitattributes

* Remove deepspeed hacks for good

* Update pixi lockfile

* Update pixi conda environment

* Remove MKL from pypi dependencies, as it is unused

* Remove aria2 from pypi dependencies, unused and not so much of a convenience

* Update lockfile

Update lockfile

* Re-enable pure PyPI install

* Disable hack when conda is active

* More comments on cutlass python API deprecation and pytorch

* Make pixi environments (CPU, CUDA12, CUDA13, for all major platforms)

* Increase LMDB map size to make test pass in osx-arm64

* Better comments of TODOs in pixi.toml

Better comments of TODOs in pixi.toml

Better comments of TODOs in pixi.toml

* Pin cuequivariance until test failure is investigated

* Move deepspeed to optional dependency also in pyproject

* Pyproject: extend python version support

* Pyproject: move dependencies table together with optional-dependencies

* Pyproject: document future decision on dependency-groups

* Pyproject: reformat to consolidate indent to 4 spaces

* Pyproject: reorder dependencies for easier read

* Pixi: add scipy

* Pixi: add comment on CUDA13

* Pixi: make cuequivariance CUDA generic for its conda packages

* Pixi: add reminder about devel install

* Pyproject: fix and improve readability, add URLs

* pixi.toml: make more readable by showing first envs, then base, then variants

* pixi.toml: pin deepspeed to 0.18.3, first one with ninja detection fixed

* pixi.toml: fully enable aarch64 and cuda13, revamp docs

* pixi.lock: update

* pixi.toml: add triton to cuequivariance dependencies for CUDA13

* pixi.lock: update

* pixi.toml: include pip to allow users to play

* pixi.toml: formatting for better readability

* pixi.toml: restrict cuequivariance-cu13 to linux-64 until we unpin to >=0.8

* pixi.toml: formatting for better readability

* pixi.toml: make pytorch-gpu an isolated environment feature

in this way we can more easily express when a package is not ready yet in CF

* pixi.toml: add environments that combine mostly pypi-based deps with CUDA from conda

* pixi.toml: add openfold3-editable-full and account for lack of cuequivariance for python=3.14

* pixi.toml: brief documentation of the pypi-dominant environments

* pixi.toml: add also the dev optional dependency group to openfold3-full

* pyproject.toml: pin cuequivariance to <0.8 until we adapt tests

* pixi.toml: add kalign to required non-pypi dependencies

* pixi.toml: add more bioinformatics tools to non-pypi

* pixi.toml: make env setup be part of the deepspeed-build feature

* pixi.toml: simplify management of pypi features

* pixi.lock: update, all tests pass A100,B300 x CUDA12,CUDA13

* pixi.toml: add table of what works and what needs test

* pixi.toml: add tasks for exporting to regular conda environment yamls

* conda environments: delete outdated modernized conda env, use new tasks instead

* pixi.toml: bump min pixi version

* pixi.toml: remove unnecessary comments

* pixi.toml: remove unnecessary envvar definition for isolating extension builds

* pixi.toml: better definition of maintenance environment

pixi.toml: better definition of maintenance environment

pixi.toml: better definition of maintenance environment

* pixi.toml: add simple task to run test and save rsults to an environment-specific dir

* of3: enable pickling regardless of forking strategy and platform

* of3: enable multiple data loader workers in osx mps backed

* Vendor improved deepspeed builder from upstream PR

See: deepspeedai/DeepSpeed#7760

* pixi.lock: update

* pixi.toml: remove some comment noise

* of3: fix multiprocessing configuration corner case in osx

* docker: move outdated example dockerfiles to docker/pixi-examples

* examples: add example runner for osx inference

* pixi.toml: ensure we get the right pytorch from pypi

something smilar should actually be supported in pyproject.toml

* pixi.lock: update, fixed torch cuda missmatch in pypi environments

* pixi.toml: fix lock export + make default environment be maintenance

* pixi.toml: use a more consitent name for environment arg

* pixi.lock: update

* pixi.toml: workaround for no-default-feature breaking the test task (pixi bug)

* pixi.toml: issue with pixi pypi resolution seems solved

* Revert "pixi.toml: issue with pixi pypi resolution seems solved"

This reverts commit ded3482.

* pixi.toml: better document problem and workaround

* pixi.toml: make the test task present in all relevant environments

this I feel makes less surprising its use, as opposed to passing the environment as an arg to a dependent task

* pixi.toml: let CUDA13 flow freely

* pixi.lock: update for initial pytorch 2.10, cuda 13.1 support

* pixi.toml: add safe cuda environments (no accelerators)

* of3: remove deepspeed hacks

note that there are still some in __init__.py

* of3: unvendor deepspeed

* pixi.toml: simplify deepspeed dependency after our changes made it to CF/pypi

* pixi.toml: remove safe environments as we are not maintaining them

* pixi.toml: enable pytorch-coda in cuda 13 env after 2.10 release

* pyproject.toml: pin deepspeed to >0.18.5, improved evoformer compilation

* Add awscrt to dependencies, missing from recent PR

* pixi.toml: setup correctly path to PTXAS_BLACKWELL for triton >=3.6.0

* pixi.toml: add -safe environments, at the moment just without cuequivariance

these are also conda-pure environments

* pixi.lock: update after consolidation (no vendor, pytorch 2.10 + CF cuda13)

* pixi.toml: update outdated comments

* updates with GB10 tests (#2)

* updates with GB10 tests

* cleanup

* harmonize

* linting data_module.py

* speculative changes

* pixi.toml: remove safe environments

* pixi.lock: update after removal of safe environments

* Remove pixi docker examples, to rework

* Comment-out workaround for hard to reproduce ABI mismatch problem

* pixi.toml: bump pixi, improve conda export by including all env variables

* pixi.toml: unpin biotite

* pixi.toml: python has its own feature

* pixi.toml: bump deepspeed

* pyproject.toml: bump deepspeed to version without Evoformer build bug

* pixi.toml: detail on workaround

* pixi.lock: update

* pixi.toml: add example task to update safely the lockfile

* pixi.toml: remove kalign2

* tests: fix test depending on unspecified glob return order

* pixi.toml: better metadata

* docs: wip

* pixi.lock: update

* Allow to configure multiprocessing start and set safe defaults

We would still need to document this for users

* Fix capitalization error

* Fix capitalization error

* Fix typo

* pixi.lock: update

---------

Co-authored-by: Tim Adler <tim.adler@bayer.com>
Co-authored-by: Jan Domański <jan.domanski@omsf.io>
@jandom jandom self-assigned this Apr 15, 2026
@jandom jandom requested a review from jnwei April 15, 2026 13:57
@jandom jandom added the safe-to-test Internal only label used to indicate PRs that are ready for automated CI testing. label Apr 15, 2026
@jandom jandom mentioned this pull request Apr 15, 2026
4 tasks
@jandom jandom changed the title cuequivariance support in pixi cuequivariance support in pixi (including aarch64) Apr 15, 2026
Base automatically changed from pixi-beta to main April 23, 2026 07:27
@jandom jandom added safe-to-test Internal only label used to indicate PRs that are ready for automated CI testing. and removed safe-to-test Internal only label used to indicate PRs that are ready for automated CI testing. labels Apr 23, 2026
@jandom jandom requested a review from christinaflo April 23, 2026 08:13
Comment on lines +775 to +796
# 4D → 5D: chunk_layer flattens batch dims and slices into chunks.
# Promote to 5D with N=1 so each chunk entry is an independent batch item.
# cuequivariance >=0.8 requires bias shape (B, 1, H, Q, K) with exact
# batch match — no implicit broadcasting.
is_chunked_input = len(q.shape) == 4
if is_chunked_input:
# q: (chunk, H, S, D) → (chunk, 1, H, S, D)
q = q.unsqueeze(1)
k = k.unsqueeze(1)
v = v.unsqueeze(1)
# mask_bias: (chunk, 1, 1, S) → (chunk, 1, 1, 1, S)
mask_bias = mask_bias.unsqueeze(1)
# triangle_bias: (chunk, H, S, S) → (chunk, 1, H, S, S)
# or: (1, H, S, S) → (1, 1, H, S, S) when chunk_layer kept B=1
triangle_bias = triangle_bias.unsqueeze(1)
# chunk_layer skips expanding bias when all its batch dims are 1,
# so bias may have B=1 while q has B=chunk. Expand to match.
if triangle_bias.shape[0] != q.shape[0]:
# (1, 1, H, S, S) → (chunk, 1, H, S, S)
triangle_bias = triangle_bias.expand(q.shape[0], *triangle_bias.shape[1:])
if mask_bias.shape[0] != q.shape[0]:
mask_bias = mask_bias.expand(q.shape[0], *mask_bias.shape[1:])

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@christinaflo this migrates how we call a recent cueq - does it look sane? (it obviously doesn't :D)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is OK. IMO we're not going to get much cleaner without having separate functions depending on the number of dimensions in the input.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well actually it's the right way to fix chunking when BS>1. But we do want to avoid expand() on bias for BS=1 (it becomes huge and we do allocate real memory for it on contiguous() call), so I suggest adding a few conditionals to process BS=1 case just like before :

diff --git a/openfold3/core/model/primitives/attention.py b/openfold3/core/model/primitives/attention.py
index 46c924aa..dbe74e87 100644
--- a/openfold3/core/model/primitives/attention.py
+++ b/openfold3/core/model/primitives/attention.py
@@ -773,10 +773,12 @@ def _cueq_triangle_attn(q, k, v, biases, scale):
         triangle_bias = triangle_bias.view(batch * n_tmpl, *triangle_bias.shape[2:])

     # 4D → 5D: chunk_layer flattens batch dims and slices into chunks.
+    # chunk_layer skips expanding bias when all its batch dims are 1,
+    # so bias may have B=1 while q has B=chunk. In this case, we're good - otherwise:
     # Promote to 5D with N=1 so each chunk entry is an independent batch item.
     # cuequivariance >=0.8 requires bias shape (B, 1, H, Q, K) with exact
     # batch match — no implicit broadcasting.
-    is_chunked_input = len(q.shape) == 4
+    is_chunked_input = len(q.shape) == 4 and triangle_bias.shape[0] > 1
     if is_chunked_input:
         # q: (chunk, H, S, D) → (chunk, 1, H, S, D)
         q = q.unsqueeze(1)
@@ -787,8 +789,7 @@ def _cueq_triangle_attn(q, k, v, biases, scale):
         # triangle_bias: (chunk, H, S, S) → (chunk, 1, H, S, S)
         #   or: (1, H, S, S) → (1, 1, H, S, S) when chunk_layer kept B=1
         triangle_bias = triangle_bias.unsqueeze(1)
-        # chunk_layer skips expanding bias when all its batch dims are 1,
-        # so bias may have B=1 while q has B=chunk. Expand to match.
+        # This should not happen. Just in case. Expand to match.
         if triangle_bias.shape[0] != q.shape[0]:
             # (1, 1, H, S, S) → (chunk, 1, H, S, S)
             triangle_bias = triangle_bias.expand(q.shape[0], *triangle_bias.shape[1:])
@@ -803,11 +804,14 @@ def _cueq_triangle_attn(q, k, v, biases, scale):
     o = triangle_attention(q, k, v, bias=triangle_bias, mask=mask_bias, scale=scale)

     # Undo the promotions in reverse order.
-    if is_chunked_input:
+    if len(q.shape) == 4:
+        ##VS: There's a bug in cueq where if the input is missing the batch dim
+        ## the outputs adds it in and so we need to remove it here
+        o = o.squeeze(0)
+    elif is_chunked_input:
         # (chunk, 1, H, S, D) → (chunk, H, S, D)
         o = o.squeeze(1)
-
-    if is_batched_input:
+    elif is_batched_input:
         # (batch*n_tmpl, N, H, S, D) → (batch, n_tmpl, N, H, S, D)
         o = o.view(batch, n_tmpl, *o.shape[1:])

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@christinaflo : I have tested that code with BS=1 and BS=2 using test_kernels, and it worked for me:

--- a/openfold3/tests/test_kernels.py
+++ b/openfold3/tests/test_kernels.py
@@ -437,8 +437,6 @@ class TestKernels(unittest.TestCase):
         batch_size = consts.batch_size
         if chunk_size is not None and (
             use_deepspeed_evo_attention
-            or use_cueq_triangle_kernels
-            or use_triton_triangle_kernels
         ):

(actually, triton_kernels also worked with BS>1, with minor accuracy error).

@jnwei jnwei left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM! Tested the following on A100:

  • openfold3-cuda12-pypi :
    • all kernel tests
    • all unit tests. Everything passes except for 3 attention regression tests, which are locked to the GB10 instance
    • sample ubiquitin prediction
  • openfold3-cuda13-pypi : all kernel tests pass

I added a few edits to the documentation for cuequivariance usage, namely use pixi run -e openfold3-cuda12-pypi to run with cuequivariance, feel free to edit

Comment on lines +775 to +796
# 4D → 5D: chunk_layer flattens batch dims and slices into chunks.
# Promote to 5D with N=1 so each chunk entry is an independent batch item.
# cuequivariance >=0.8 requires bias shape (B, 1, H, Q, K) with exact
# batch match — no implicit broadcasting.
is_chunked_input = len(q.shape) == 4
if is_chunked_input:
# q: (chunk, H, S, D) → (chunk, 1, H, S, D)
q = q.unsqueeze(1)
k = k.unsqueeze(1)
v = v.unsqueeze(1)
# mask_bias: (chunk, 1, 1, S) → (chunk, 1, 1, 1, S)
mask_bias = mask_bias.unsqueeze(1)
# triangle_bias: (chunk, H, S, S) → (chunk, 1, H, S, S)
# or: (1, H, S, S) → (1, 1, H, S, S) when chunk_layer kept B=1
triangle_bias = triangle_bias.unsqueeze(1)
# chunk_layer skips expanding bias when all its batch dims are 1,
# so bias may have B=1 while q has B=chunk. Expand to match.
if triangle_bias.shape[0] != q.shape[0]:
# (1, 1, H, S, S) → (chunk, 1, H, S, S)
triangle_bias = triangle_bias.expand(q.shape[0], *triangle_bias.shape[1:])
if mask_bias.shape[0] != q.shape[0]:
mask_bias = mask_bias.expand(q.shape[0], *mask_bias.shape[1:])

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is OK. IMO we're not going to get much cleaner without having separate functions depending on the number of dimensions in the input.

@jandom

jandom commented Apr 28, 2026

Copy link
Copy Markdown
Collaborator Author

LGTM! Tested the following on A100:

This is incredible, thank you so much!

@jandom jandom merged commit 8f1e988 into main Apr 28, 2026
4 checks passed
@jandom jandom deleted the pixi-beta-cuequivariance branch April 28, 2026 16:15
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

safe-to-test Internal only label used to indicate PRs that are ready for automated CI testing.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants