Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 125 additions & 0 deletions easybuild/easyconfigs/j/jax/jax-0.6.2-gfbf-2024a-CUDA-12.6.0.eb
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
# This file is an EasyBuild reciPY as per https://github.com/easybuilders/easybuild
# Author: Denis Kristak
# Updated by: Alex Domingo (Vrije Universiteit Brussel)
# Updated by: Thomas Hoffmann (EMBL Heidelberg)
# Updated by: Pavel Tománek (INUITS)
easyblock = 'PythonBundle'

name = 'jax'
version = '0.6.2'
versionsuffix = '-CUDA-%(cudaver)s'

homepage = 'https://jax.readthedocs.io/'
description = """Composable transformations of Python+NumPy programs:
differentiate, vectorize, JIT to GPU/TPU, and more"""

toolchain = {'name': 'gfbf', 'version': '2024a'}
cuda_compute_capabilities = ["7.0", "7.5", "8.0", "8.6", "9.0"]

builddependencies = [
('Bazel', '7.4.1', '-Java-21'),
('Clang', '18.1.8', versionsuffix),
]

dependencies = [
('CUDA', '12.6.0', '', SYSTEM),
('cuDNN', '9.5.0.50', versionsuffix, SYSTEM),
('NCCL', '2.22.3', versionsuffix),
('Python', '3.12.3'),
('SciPy-bundle', '2024.05'),
('absl-py', '2.1.0'),
('flatbuffers-python', '24.3.25'),
('ml_dtypes', '0.5.0'),
('hypothesis', '6.103.1'),
('zlib', '1.3.1'),
]

# downloading xla tarball to avoid that Bazel downloads it during the build
# note: following commits *must* be the exact same onces used upstream
# XLA_COMMIT from jax-jaxlib: third_party/xla/workspace.bzl
_xla_commit = '3d5ece64321630dade7ff733ae1353fc3c83d9cc'
# Use sources downloaded by EasyBuild
_jaxlib_buildopts = '--bazel_options="--distdir=%(builddir)s/archives" '
# create wheels and install them
_jaxlib_buildopts += '--wheels=jaxlib,jax-cuda-plugin,jax-cuda-pjrt '
# fix jaxlib version
_jaxlib_buildopts += '--bazel_options="--action_env=JAXLIB_RELEASE" '

components = [
('jaxlib', version, {
'sources': [
{
'source_urls': ['https://github.com/google/jax/archive/'],
'filename': 'jax-v%(version)s.tar.gz',
},
{
'source_urls': ['https://github.com/openxla/xla/archive'],
'download_filename': f'{_xla_commit}.tar.gz',
'filename': f'xla-{_xla_commit[:8]}.tar.gz',
'extract_cmd': 'mkdir -p %(builddir)s/archives && cp %s %(builddir)s/archives',
},
],
'checksums': [
{'jax-v0.6.2.tar.gz':
'd46cb98795f2c1ccdf2b081e02d9d74b659063679a80beb001ad17d482a60e17'},
{'xla-3d5ece64.tar.gz':
'fbd20cf83bad78f66977fa7ff67a12e52964abae0b107ddd5486a0355643ec8a'},
],
'start_dir': f'jax-jax-v{version}',
# fix jaxlib version - removes .dev suffix
'prebuildopts': 'export JAXLIB_RELEASE=%(version)s && ',
'buildopts': _jaxlib_buildopts,
}),
]

# JAX otherwise preallocates ~75% of GPU memory per process; disabling avoids OOM/fragmentation.
_runtest_cmd1 = "export XLA_PYTHON_CLIENT_PREALLOCATE=false && "
# force CUDA backend to be considered first (avoid accidental ROCm/TPU plugins or CPU fallback)
_runtest_cmd1 += "export JAX_PLATFORMS=cuda,cpu && "
# pin tests to a single GPU to avoid multiple workers competing for the same VRAM
_runtest_cmd1 += 'export CUDA_VISIBLE_DEVICES=0 && '
# make libdevice (nvvm/libdevice/libdevice.10.bc) discoverable at runtime for MLIR→PTX lowering during compilation
_runtest_cmd1 += 'export XLA_FLAGS="--xla_gpu_cuda_data_dir=$EBROOTCUDA" && '
# keep matmul/conv accuracy comparable to strict FP32 (no TF32)
_runtest_cmd1 += 'export NVIDIA_TF32_OVERRIDE=0 && '
# keep -n 1 for gpu tests - more xdist workers just pile multiple JAX/XLA processes onto the same device
_runtest_cmd1 += "pytest -n 1 tests "
# pallas GPU tests include kernels requiring >100 KiB shared memory per block
# TPU tests are irrelevant here
_runtest_cmd1 += '-k "not pallas and not tpu" '
# expect the literal string "INFO" from Python logging; with PJRT CUDA the logs often come via absl/glog ("I ...")
# or at DEBUG, so these are brittle and not a CUDA correctness check
_runtest_cmd1 += "--deselect=tests/logging_test.py::LoggingTest::test_subprocess_stderr_debug_logging "
_runtest_cmd1 += "--deselect=tests/logging_test.py::LoggingTest::test_subprocess_stderr_info_logging "
# exercises cuDNN Graph/Runtime fusion for a small/odd shape; engine availability is GPU+cuDNN-version dependent
# failing on sm_86
_runtest_cmd1 += "--deselect=tests/cudnn_fusion_test.py::CudnnFusionTest::test_cudnn_fusion0 "
# Checks a byte-for-byte serialized cuDNN RNN descriptor embedded in StableHLO;
# this legitimately changes across cuDNN minors. Skip for external builds.
_runtest_cmd1 += "--deselect=tests/experimental_rnn_test.py::RnnTest::test_struct_encoding_determinism "
# failing after fix jaxlib version by JAXLIB_RELEASE
_runtest_cmd1 += "--deselect=tests/version_test.py "
# retry randomly failing tests
_runtest_cmd = ' '.join([_runtest_cmd1, '||', _runtest_cmd1, '--last-failed'])

exts_list = [
(name, version, {
'patches': [
'jax-0.6.2_jax-version-fix.patch',
'jax-0.6.2_fix-slurm-vars.patch',
],
'runtest': _runtest_cmd,
'sources': [{
'source_urls': ['https://github.com/google/jax/archive/'],
'filename': '%(name)s-v%(version)s.tar.gz'
}],
'testinstall': True,
'checksums': [
{'jax-v0.6.2.tar.gz': 'd46cb98795f2c1ccdf2b081e02d9d74b659063679a80beb001ad17d482a60e17'},
{'jax-0.6.2_jax-version-fix.patch': 'e15615fd9f4e1698f7c5fd384f146d7b2dbfde3d4657b69bd2d044d75c9fb1d4'},
{'jax-0.6.2_fix-slurm-vars.patch': 'a551deb1723c091ad9c29ada7b9e52dcfaaff4098e03d4e1c7a3d1e75f129f42'},
],
}),
]

moduleclass = 'ai'
26 changes: 26 additions & 0 deletions easybuild/easyconfigs/j/jax/jax-0.6.2_fix-slurm-vars.patch
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
From d5c78b3abbfbccb9919be8600741409f16d0846e Mon Sep 17 00:00:00 2001
From: Alexander Grund <Flamefire@users.noreply.github.com>
Date: Wed, 22 Oct 2025 09:03:16 +0200
Subject: [PATCH] Check all environment variables for Slurm environment

In some environments only the SLURM_JOB_ID might be set, e.g. when using hooks for SSH to a node with an existing allocation

This causes a false positive in the detection and later `KeyError` on e.g. `SLURM_LOCALID`
---
jax/_src/clusters/slurm_cluster.py | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/jax/_src/clusters/slurm_cluster.py b/jax/_src/clusters/slurm_cluster.py
index 8cec07601094..a8bb1b8a287f 100644
--- a/jax/_src/clusters/slurm_cluster.py
+++ b/jax/_src/clusters/slurm_cluster.py
@@ -30,7 +30,8 @@ class SlurmCluster(clusters.ClusterEnv):

@classmethod
def is_env_present(cls) -> bool:
- return _JOBID_PARAM in os.environ
+ return all(var in os.environ for var in
+ (_JOBID_PARAM, _NODE_LIST, _PROCESS_COUNT, _PROCESS_ID, _LOCAL_PROCESS_ID))

@classmethod
def get_coordinator_address(cls, timeout_secs: int | None) -> str: