Skip to content

{ai}[gfbf/2024a] jax v0.6.2 w/ CUDA 12.6.0#24141

Merged
boegel merged 3 commits intoeasybuilders:developfrom
pavelToman:20251007135115_new_pr_jax062
Dec 14, 2025
Merged

{ai}[gfbf/2024a] jax v0.6.2 w/ CUDA 12.6.0#24141
boegel merged 3 commits intoeasybuilders:developfrom
pavelToman:20251007135115_new_pr_jax062

Conversation

@pavelToman
Copy link
Collaborator

@pavelToman pavelToman commented Oct 7, 2025

@github-actions
Copy link

github-actions bot commented Oct 7, 2025

Updated software jax-0.6.2-gfbf-2024a-CUDA-12.6.0.eb

Diff against jax-0.6.2-gfbf-2024a.eb

easybuild/easyconfigs/j/jax/jax-0.6.2-gfbf-2024a.eb

diff --git a/easybuild/easyconfigs/j/jax/jax-0.6.2-gfbf-2024a.eb b/easybuild/easyconfigs/j/jax/jax-0.6.2-gfbf-2024a-CUDA-12.6.0.eb
index 7302e39a56..8283ce1ac8 100644
--- a/easybuild/easyconfigs/j/jax/jax-0.6.2-gfbf-2024a.eb
+++ b/easybuild/easyconfigs/j/jax/jax-0.6.2-gfbf-2024a-CUDA-12.6.0.eb
@@ -7,18 +7,24 @@ easyblock = 'PythonBundle'
 
 name = 'jax'
 version = '0.6.2'
+versionsuffix = '-CUDA-%(cudaver)s'
 
 homepage = 'https://jax.readthedocs.io/'
 description = """Composable transformations of Python+NumPy programs:
 differentiate, vectorize, JIT to GPU/TPU, and more"""
 
 toolchain = {'name': 'gfbf', 'version': '2024a'}
+cuda_compute_capabilities = ["7.0", "7.5", "8.0", "8.6", "9.0"]
 
 builddependencies = [
     ('Bazel', '7.4.1', '-Java-21'),
+    ('Clang', '18.1.8', versionsuffix),
 ]
 
 dependencies = [
+    ('CUDA', '12.6.0', '', SYSTEM),
+    ('cuDNN', '9.5.0.50', versionsuffix, SYSTEM),
+    ('NCCL', '2.22.3', versionsuffix),
     ('Python', '3.12.3'),
     ('SciPy-bundle', '2024.05'),
     ('absl-py', '2.1.0'),
@@ -32,13 +38,10 @@ dependencies = [
 # note: following commits *must* be the exact same onces used upstream
 # XLA_COMMIT from jax-jaxlib: third_party/xla/workspace.bzl
 _xla_commit = '3d5ece64321630dade7ff733ae1353fc3c83d9cc'
-
 # Use sources downloaded by EasyBuild
 _jaxlib_buildopts = '--bazel_options="--distdir=%(builddir)s/archives" '
-# create wheels for jaxlib and install it
-_jaxlib_buildopts += '--wheels=jaxlib '
-# use GCC instead of default Clang
-_jaxlib_buildopts += '--use_clang=false --gcc_path=$CC '
+# create wheels and install them
+_jaxlib_buildopts += '--wheels=jaxlib,jax-cuda-plugin,jax-cuda-pjrt '
 # fix jaxlib version
 _jaxlib_buildopts += '--bazel_options="--action_env=JAXLIB_RELEASE" '
 
@@ -69,25 +72,52 @@ components = [
     }),
 ]
 
-# BackendsTest is failing on systems with GPU detected when cpu version is built
-_failing_backend_test = 'tests/api_test.py::BackendsTest::test_no_backend_warning_on_cpu_if_platform_specified'
-# Failing after fix jaxlib version by JAXLIB_RELEASE
-_failing_version_tests = 'tests/version_test.py'
-_runtest_cmd1 = f"pytest -n %(parallel)s tests --deselect={_failing_backend_test} --deselect={_failing_version_tests}"
+# JAX otherwise preallocates ~75% of GPU memory per process; disabling avoids OOM/fragmentation.
+_runtest_cmd1 = "export XLA_PYTHON_CLIENT_PREALLOCATE=false && "
+# force CUDA backend to be considered first (avoid accidental ROCm/TPU plugins or CPU fallback)
+_runtest_cmd1 += "export JAX_PLATFORMS=cuda,cpu && "
+# pin tests to a single GPU to avoid multiple workers competing for the same VRAM
+_runtest_cmd1 += 'export CUDA_VISIBLE_DEVICES=0 && '
+# make libdevice (nvvm/libdevice/libdevice.10.bc) discoverable at runtime for MLIR→PTX lowering during compilation
+_runtest_cmd1 += 'export XLA_FLAGS="--xla_gpu_cuda_data_dir=$EBROOTCUDA" && '
+# keep matmul/conv accuracy comparable to strict FP32 (no TF32)
+_runtest_cmd1 += 'export NVIDIA_TF32_OVERRIDE=0 && '
+# keep -n 1 for gpu tests - more xdist workers just pile multiple JAX/XLA processes onto the same device
+_runtest_cmd1 += "pytest -n 1 tests "
+# pallas GPU tests include kernels requiring >100 KiB shared memory per block
+# TPU tests are irrelevant here
+_runtest_cmd1 += '-k "not pallas and not tpu" '
+# expect the literal string "INFO" from Python logging; with PJRT CUDA the logs often come via absl/glog ("I ...")
+# or at DEBUG, so these are brittle and not a CUDA correctness check
+_runtest_cmd1 += "--deselect=tests/logging_test.py::LoggingTest::test_subprocess_stderr_debug_logging "
+_runtest_cmd1 += "--deselect=tests/logging_test.py::LoggingTest::test_subprocess_stderr_info_logging "
+# exercises cuDNN Graph/Runtime fusion for a small/odd shape; engine availability is GPU+cuDNN-version dependent
+# failing on sm_86
+_runtest_cmd1 += "--deselect=tests/cudnn_fusion_test.py::CudnnFusionTest::test_cudnn_fusion0 "
+# Checks a byte-for-byte serialized cuDNN RNN descriptor embedded in StableHLO;
+# this legitimately changes across cuDNN minors. Skip for external builds.
+_runtest_cmd1 += "--deselect=tests/experimental_rnn_test.py::RnnTest::test_struct_encoding_determinism "
+# failing after fix jaxlib version by JAXLIB_RELEASE
+_runtest_cmd1 += "--deselect=tests/version_test.py "
 # retry randomly failing tests
 _runtest_cmd = ' '.join([_runtest_cmd1, '||', _runtest_cmd1, '--last-failed'])
 
 exts_list = [
     (name, version, {
-        "patches": ["jax-0.6.2_jax-version-fix.patch"],
-        "testinstall": True,
-        "runtest": _runtest_cmd,
-        "sources": [
-            {"source_urls": ["https://github.com/google/jax/archive/"], "filename": "%(name)s-v%(version)s.tar.gz"}
+        'patches': [
+            'jax-0.6.2_jax-version-fix.patch',
+            'jax-0.6.2_fix-slurm-vars.patch',
         ],
-        "checksums": [
-            {"jax-v0.6.2.tar.gz": "d46cb98795f2c1ccdf2b081e02d9d74b659063679a80beb001ad17d482a60e17"},
-            {"jax-0.6.2_jax-version-fix.patch": "e15615fd9f4e1698f7c5fd384f146d7b2dbfde3d4657b69bd2d044d75c9fb1d4"},
+        'runtest': _runtest_cmd,
+        'sources': [{
+            'source_urls': ['https://github.com/google/jax/archive/'],
+            'filename': '%(name)s-v%(version)s.tar.gz'
+        }],
+        'testinstall': True,
+        'checksums': [
+            {'jax-v0.6.2.tar.gz': 'd46cb98795f2c1ccdf2b081e02d9d74b659063679a80beb001ad17d482a60e17'},
+            {'jax-0.6.2_jax-version-fix.patch': 'e15615fd9f4e1698f7c5fd384f146d7b2dbfde3d4657b69bd2d044d75c9fb1d4'},
+            {'jax-0.6.2_fix-slurm-vars.patch': 'a551deb1723c091ad9c29ada7b9e52dcfaaff4098e03d4e1c7a3d1e75f129f42'},
         ],
     }),
 ]
Diff against jax-0.7.0-gfbf-2025a.eb

easybuild/easyconfigs/j/jax/jax-0.7.0-gfbf-2025a.eb

diff --git a/easybuild/easyconfigs/j/jax/jax-0.7.0-gfbf-2025a.eb b/easybuild/easyconfigs/j/jax/jax-0.6.2-gfbf-2024a-CUDA-12.6.0.eb
index 3dc61c2ca2..8283ce1ac8 100644
--- a/easybuild/easyconfigs/j/jax/jax-0.7.0-gfbf-2025a.eb
+++ b/easybuild/easyconfigs/j/jax/jax-0.6.2-gfbf-2024a-CUDA-12.6.0.eb
@@ -3,45 +3,47 @@
 # Updated by: Alex Domingo (Vrije Universiteit Brussel)
 # Updated by: Thomas Hoffmann (EMBL Heidelberg)
 # Updated by: Pavel Tománek (INUITS)
-
 easyblock = 'PythonBundle'
 
 name = 'jax'
-version = '0.7.0'
+version = '0.6.2'
+versionsuffix = '-CUDA-%(cudaver)s'
 
 homepage = 'https://jax.readthedocs.io/'
 description = """Composable transformations of Python+NumPy programs:
 differentiate, vectorize, JIT to GPU/TPU, and more"""
 
-toolchain = {'name': 'gfbf', 'version': '2025a'}
+toolchain = {'name': 'gfbf', 'version': '2024a'}
+cuda_compute_capabilities = ["7.0", "7.5", "8.0", "8.6", "9.0"]
 
 builddependencies = [
     ('Bazel', '7.4.1', '-Java-21'),
+    ('Clang', '18.1.8', versionsuffix),
 ]
 
 dependencies = [
-    ('Python', '3.13.1'),
-    ('SciPy-bundle', '2025.06'),
-    ('absl-py', '2.3.1'),
-    ('flatbuffers-python', '25.2.10'),
-    ('ml_dtypes', '0.5.1'),
-    ('hypothesis', '6.133.2'),
+    ('CUDA', '12.6.0', '', SYSTEM),
+    ('cuDNN', '9.5.0.50', versionsuffix, SYSTEM),
+    ('NCCL', '2.22.3', versionsuffix),
+    ('Python', '3.12.3'),
+    ('SciPy-bundle', '2024.05'),
+    ('absl-py', '2.1.0'),
+    ('flatbuffers-python', '24.3.25'),
+    ('ml_dtypes', '0.5.0'),
+    ('hypothesis', '6.103.1'),
     ('zlib', '1.3.1'),
 ]
 
 # downloading xla  tarball to avoid that Bazel downloads it during the build
 # note: following commits *must* be the exact same onces used upstream
 # XLA_COMMIT from jax-jaxlib: third_party/xla/workspace.bzl
-_xla_commit = '3d25fed5571304e446903bc00e4f457b2b0f73dc'
-
+_xla_commit = '3d5ece64321630dade7ff733ae1353fc3c83d9cc'
 # Use sources downloaded by EasyBuild
 _jaxlib_buildopts = '--bazel_options="--distdir=%(builddir)s/archives" '
-# create wheels for jaxlib and install it
-_jaxlib_buildopts += '--wheels=jaxlib '
-# use GCC instead of default Clang
-_jaxlib_buildopts += '--use_clang=false --gcc_path=$CC '
+# create wheels and install them
+_jaxlib_buildopts += '--wheels=jaxlib,jax-cuda-plugin,jax-cuda-pjrt '
 # fix jaxlib version
-_jaxlib_buildopts += "--bazel_options='--action_env=JAXLIB_RELEASE' "
+_jaxlib_buildopts += '--bazel_options="--action_env=JAXLIB_RELEASE" '
 
 components = [
     ('jaxlib', version, {
@@ -57,14 +59,11 @@ components = [
                 'extract_cmd': 'mkdir -p %(builddir)s/archives && cp %s %(builddir)s/archives',
             },
         ],
-        'patches': ['jax-0.7.0_fix-mosaic.patch'],
         'checksums': [
-            {'jax-v0.7.0.tar.gz':
-             '518966801e4402667e77915c2dc7cf1a178a80e22ff253204a837f207a87fcde'},
-            {'xla-3d25fed5.tar.gz':
-             '9efd7d303edab24fd8552d602045722f462870d17b888fa607e9b7143b9e0515'},
-            {'jax-0.7.0_fix-mosaic.patch':
-             '8ce4bcfd1cc6a0e42288c58ee28af5d29363ca72dbde3436ee60c336f1efc575'}
+            {'jax-v0.6.2.tar.gz':
+             'd46cb98795f2c1ccdf2b081e02d9d74b659063679a80beb001ad17d482a60e17'},
+            {'xla-3d5ece64.tar.gz':
+             'fbd20cf83bad78f66977fa7ff67a12e52964abae0b107ddd5486a0355643ec8a'},
         ],
         'start_dir': f'jax-jax-v{version}',
         # fix jaxlib version - removes .dev suffix
@@ -73,25 +72,52 @@ components = [
     }),
 ]
 
-# BackendsTest is failing on systems with GPU detected when cpu version is built
-_failing_backend_test = 'tests/api_test.py::BackendsTest::test_no_backend_warning_on_cpu_if_platform_specified'
-# Failing after fix jaxlib version by JAXLIB_RELEASE
-_failing_version_tests = 'tests/version_test.py'
-_runtest_cmd1 = f"pytest -n %(parallel)s tests --deselect={_failing_backend_test} --deselect={_failing_version_tests}"
+# JAX otherwise preallocates ~75% of GPU memory per process; disabling avoids OOM/fragmentation.
+_runtest_cmd1 = "export XLA_PYTHON_CLIENT_PREALLOCATE=false && "
+# force CUDA backend to be considered first (avoid accidental ROCm/TPU plugins or CPU fallback)
+_runtest_cmd1 += "export JAX_PLATFORMS=cuda,cpu && "
+# pin tests to a single GPU to avoid multiple workers competing for the same VRAM
+_runtest_cmd1 += 'export CUDA_VISIBLE_DEVICES=0 && '
+# make libdevice (nvvm/libdevice/libdevice.10.bc) discoverable at runtime for MLIR→PTX lowering during compilation
+_runtest_cmd1 += 'export XLA_FLAGS="--xla_gpu_cuda_data_dir=$EBROOTCUDA" && '
+# keep matmul/conv accuracy comparable to strict FP32 (no TF32)
+_runtest_cmd1 += 'export NVIDIA_TF32_OVERRIDE=0 && '
+# keep -n 1 for gpu tests - more xdist workers just pile multiple JAX/XLA processes onto the same device
+_runtest_cmd1 += "pytest -n 1 tests "
+# pallas GPU tests include kernels requiring >100 KiB shared memory per block
+# TPU tests are irrelevant here
+_runtest_cmd1 += '-k "not pallas and not tpu" '
+# expect the literal string "INFO" from Python logging; with PJRT CUDA the logs often come via absl/glog ("I ...")
+# or at DEBUG, so these are brittle and not a CUDA correctness check
+_runtest_cmd1 += "--deselect=tests/logging_test.py::LoggingTest::test_subprocess_stderr_debug_logging "
+_runtest_cmd1 += "--deselect=tests/logging_test.py::LoggingTest::test_subprocess_stderr_info_logging "
+# exercises cuDNN Graph/Runtime fusion for a small/odd shape; engine availability is GPU+cuDNN-version dependent
+# failing on sm_86
+_runtest_cmd1 += "--deselect=tests/cudnn_fusion_test.py::CudnnFusionTest::test_cudnn_fusion0 "
+# Checks a byte-for-byte serialized cuDNN RNN descriptor embedded in StableHLO;
+# this legitimately changes across cuDNN minors. Skip for external builds.
+_runtest_cmd1 += "--deselect=tests/experimental_rnn_test.py::RnnTest::test_struct_encoding_determinism "
+# failing after fix jaxlib version by JAXLIB_RELEASE
+_runtest_cmd1 += "--deselect=tests/version_test.py "
 # retry randomly failing tests
 _runtest_cmd = ' '.join([_runtest_cmd1, '||', _runtest_cmd1, '--last-failed'])
 
 exts_list = [
     (name, version, {
-        'patches': ['jax-0.7.0_jax-version-fix.patch'],
-        'testinstall': True,
-        "runtest": _runtest_cmd,
-        'sources': [
-            {'source_urls': ['https://github.com/google/jax/archive/'], 'filename': '%(name)s-v%(version)s.tar.gz'}
+        'patches': [
+            'jax-0.6.2_jax-version-fix.patch',
+            'jax-0.6.2_fix-slurm-vars.patch',
         ],
+        'runtest': _runtest_cmd,
+        'sources': [{
+            'source_urls': ['https://github.com/google/jax/archive/'],
+            'filename': '%(name)s-v%(version)s.tar.gz'
+        }],
+        'testinstall': True,
         'checksums': [
-            {'jax-v0.7.0.tar.gz': '518966801e4402667e77915c2dc7cf1a178a80e22ff253204a837f207a87fcde'},
-            {'jax-0.7.0_jax-version-fix.patch': 'c09ec85af1faa78146fa57ad2ec2324b643c9059458f86781f062ab4889a7e90'},
+            {'jax-v0.6.2.tar.gz': 'd46cb98795f2c1ccdf2b081e02d9d74b659063679a80beb001ad17d482a60e17'},
+            {'jax-0.6.2_jax-version-fix.patch': 'e15615fd9f4e1698f7c5fd384f146d7b2dbfde3d4657b69bd2d044d75c9fb1d4'},
+            {'jax-0.6.2_fix-slurm-vars.patch': 'a551deb1723c091ad9c29ada7b9e52dcfaaff4098e03d4e1c7a3d1e75f129f42'},
         ],
     }),
 ]
Diff against jax-0.4.25-gfbf-2023a.eb

easybuild/easyconfigs/j/jax/jax-0.4.25-gfbf-2023a.eb

diff --git a/easybuild/easyconfigs/j/jax/jax-0.4.25-gfbf-2023a.eb b/easybuild/easyconfigs/j/jax/jax-0.6.2-gfbf-2024a-CUDA-12.6.0.eb
index e61bc4719d..8283ce1ac8 100644
--- a/easybuild/easyconfigs/j/jax/jax-0.4.25-gfbf-2023a.eb
+++ b/easybuild/easyconfigs/j/jax/jax-0.6.2-gfbf-2024a-CUDA-12.6.0.eb
@@ -1,104 +1,124 @@
 # This file is an EasyBuild reciPY as per https://github.com/easybuilders/easybuild
 # Author: Denis Kristak
 # Updated by: Alex Domingo (Vrije Universiteit Brussel)
-# Updated by: Pavel Tománek (INUITS)
 # Updated by: Thomas Hoffmann (EMBL Heidelberg)
+# Updated by: Pavel Tománek (INUITS)
 easyblock = 'PythonBundle'
 
 name = 'jax'
-version = '0.4.25'
+version = '0.6.2'
+versionsuffix = '-CUDA-%(cudaver)s'
 
 homepage = 'https://jax.readthedocs.io/'
 description = """Composable transformations of Python+NumPy programs:
 differentiate, vectorize, JIT to GPU/TPU, and more"""
 
-toolchain = {'name': 'gfbf', 'version': '2023a'}
+toolchain = {'name': 'gfbf', 'version': '2024a'}
+cuda_compute_capabilities = ["7.0", "7.5", "8.0", "8.6", "9.0"]
 
 builddependencies = [
-    ('Bazel', '6.3.1'),
-    ('pytest-xdist', '3.3.1'),
-    ('git', '2.41.0', '-nodocs'),  # bazel uses git to fetch repositories
-    ('matplotlib', '3.7.2'),  # required for tests/lobpcg_test.py
-    ('poetry', '1.5.1'),
-    ('pybind11', '2.11.1'),
+    ('Bazel', '7.4.1', '-Java-21'),
+    ('Clang', '18.1.8', versionsuffix),
 ]
 
 dependencies = [
-    ('Python', '3.11.3'),
-    ('SciPy-bundle', '2023.07'),
+    ('CUDA', '12.6.0', '', SYSTEM),
+    ('cuDNN', '9.5.0.50', versionsuffix, SYSTEM),
+    ('NCCL', '2.22.3', versionsuffix),
+    ('Python', '3.12.3'),
+    ('SciPy-bundle', '2024.05'),
     ('absl-py', '2.1.0'),
-    ('flatbuffers-python', '23.5.26'),
-    ('ml_dtypes', '0.3.2'),
-    ('zlib', '1.2.13'),
+    ('flatbuffers-python', '24.3.25'),
+    ('ml_dtypes', '0.5.0'),
+    ('hypothesis', '6.103.1'),
+    ('zlib', '1.3.1'),
 ]
 
-# downloading xla and other tarballs to avoid that Bazel downloads it during the build
-local_extract_cmd = 'mkdir -p %(builddir)s/archives && cp %s %(builddir)s/archives'
+# downloading xla  tarball to avoid that Bazel downloads it during the build
 # note: following commits *must* be the exact same onces used upstream
 # XLA_COMMIT from jax-jaxlib: third_party/xla/workspace.bzl
-local_xla_commit = '4ccfe33c71665ddcbca5b127fefe8baa3ed632d4'
-# TFRT_COMMIT from xla: third_party/tsl/third_party/tf_runtime/workspace.bzl
-local_tfrt_commit = '0aeefb1660d7e37964b2bb71b1f518096bda9a25'
-
+_xla_commit = '3d5ece64321630dade7ff733ae1353fc3c83d9cc'
 # Use sources downloaded by EasyBuild
 _jaxlib_buildopts = '--bazel_options="--distdir=%(builddir)s/archives" '
-# Use dependencies from EasyBuild
-_jaxlib_buildopts += '--bazel_options="--action_env=TF_SYSTEM_LIBS=pybind11" '
-_jaxlib_buildopts += '--bazel_options="--action_env=CPATH=$EBROOTPYBIND11/include" '
-# Avoid warning (treated as error) in upb/table.c
-_jaxlib_buildopts += '--bazel_options="--copt=-Wno-maybe-uninitialized" '
+# create wheels and install them
+_jaxlib_buildopts += '--wheels=jaxlib,jax-cuda-plugin,jax-cuda-pjrt '
+# fix jaxlib version
+_jaxlib_buildopts += '--bazel_options="--action_env=JAXLIB_RELEASE" '
 
 components = [
     ('jaxlib', version, {
         'sources': [
             {
                 'source_urls': ['https://github.com/google/jax/archive/'],
-                'filename': '%(name)s-v%(version)s.tar.gz',
+                'filename': 'jax-v%(version)s.tar.gz',
             },
             {
                 'source_urls': ['https://github.com/openxla/xla/archive'],
-                'download_filename': '%s.tar.gz' % local_xla_commit,
-                'filename': 'xla-%s.tar.gz' % local_xla_commit[:8],
-                'extract_cmd': local_extract_cmd,
-            },
-            {
-                'source_urls': ['https://github.com/tensorflow/runtime/archive'],
-                'download_filename': '%s.tar.gz' % local_tfrt_commit,
-                'filename': 'tf_runtime-%s.tar.gz' % local_tfrt_commit[:8],
-                'extract_cmd': local_extract_cmd,
+                'download_filename': f'{_xla_commit}.tar.gz',
+                'filename': f'xla-{_xla_commit[:8]}.tar.gz',
+                'extract_cmd': 'mkdir -p %(builddir)s/archives && cp %s %(builddir)s/archives',
             },
         ],
-        'patches': ['jax-0.4.25_fix-pybind11-systemlib.patch'],
         'checksums': [
-            {'jaxlib-v0.4.25.tar.gz':
-             'fc1197c401924942eb14185a61688d0c476e3e81ff71f9dc95e620b57c06eec8'},
-            {'xla-4ccfe33c.tar.gz':
-             '8a59b9af7d0850059d7043f7043c780066d61538f3af536e8a10d3d717f35089'},
-            {'tf_runtime-0aeefb16.tar.gz':
-             'a3df827d7896774cb1d80bf4e1c79ab05c268f29bd4d3db1fb5a4b9c2079d8e3'},
-            {'jax-0.4.25_fix-pybind11-systemlib.patch':
-             'daad5b726d1a138431b05eb60ecf4c89c7b5148eb939721800bdf43d804ca033'},
+            {'jax-v0.6.2.tar.gz':
+             'd46cb98795f2c1ccdf2b081e02d9d74b659063679a80beb001ad17d482a60e17'},
+            {'xla-3d5ece64.tar.gz':
+             'fbd20cf83bad78f66977fa7ff67a12e52964abae0b107ddd5486a0355643ec8a'},
         ],
-        'start_dir': 'jax-jaxlib-v%(version)s',
-        'buildopts': _jaxlib_buildopts
+        'start_dir': f'jax-jax-v{version}',
+        # fix jaxlib version - removes .dev suffix
+        'prebuildopts': 'export JAXLIB_RELEASE=%(version)s && ',
+        'buildopts': _jaxlib_buildopts,
     }),
 ]
 
+# JAX otherwise preallocates ~75% of GPU memory per process; disabling avoids OOM/fragmentation.
+_runtest_cmd1 = "export XLA_PYTHON_CLIENT_PREALLOCATE=false && "
+# force CUDA backend to be considered first (avoid accidental ROCm/TPU plugins or CPU fallback)
+_runtest_cmd1 += "export JAX_PLATFORMS=cuda,cpu && "
+# pin tests to a single GPU to avoid multiple workers competing for the same VRAM
+_runtest_cmd1 += 'export CUDA_VISIBLE_DEVICES=0 && '
+# make libdevice (nvvm/libdevice/libdevice.10.bc) discoverable at runtime for MLIR→PTX lowering during compilation
+_runtest_cmd1 += 'export XLA_FLAGS="--xla_gpu_cuda_data_dir=$EBROOTCUDA" && '
+# keep matmul/conv accuracy comparable to strict FP32 (no TF32)
+_runtest_cmd1 += 'export NVIDIA_TF32_OVERRIDE=0 && '
+# keep -n 1 for gpu tests - more xdist workers just pile multiple JAX/XLA processes onto the same device
+_runtest_cmd1 += "pytest -n 1 tests "
+# pallas GPU tests include kernels requiring >100 KiB shared memory per block
+# TPU tests are irrelevant here
+_runtest_cmd1 += '-k "not pallas and not tpu" '
+# expect the literal string "INFO" from Python logging; with PJRT CUDA the logs often come via absl/glog ("I ...")
+# or at DEBUG, so these are brittle and not a CUDA correctness check
+_runtest_cmd1 += "--deselect=tests/logging_test.py::LoggingTest::test_subprocess_stderr_debug_logging "
+_runtest_cmd1 += "--deselect=tests/logging_test.py::LoggingTest::test_subprocess_stderr_info_logging "
+# exercises cuDNN Graph/Runtime fusion for a small/odd shape; engine availability is GPU+cuDNN-version dependent
+# failing on sm_86
+_runtest_cmd1 += "--deselect=tests/cudnn_fusion_test.py::CudnnFusionTest::test_cudnn_fusion0 "
+# Checks a byte-for-byte serialized cuDNN RNN descriptor embedded in StableHLO;
+# this legitimately changes across cuDNN minors. Skip for external builds.
+_runtest_cmd1 += "--deselect=tests/experimental_rnn_test.py::RnnTest::test_struct_encoding_determinism "
+# failing after fix jaxlib version by JAXLIB_RELEASE
+_runtest_cmd1 += "--deselect=tests/version_test.py "
+# retry randomly failing tests
+_runtest_cmd = ' '.join([_runtest_cmd1, '||', _runtest_cmd1, '--last-failed'])
+
 exts_list = [
     (name, version, {
-        'sources': [
-            {
-                'source_urls': ['https://github.com/google/jax/archive/'],
-                'filename': '%(name)s-v%(version)s.tar.gz',
-            },
+        'patches': [
+            'jax-0.6.2_jax-version-fix.patch',
+            'jax-0.6.2_fix-slurm-vars.patch',
         ],
-        'patches': ['jax-0.4.25_fix_env_test_no_log_spam.patch'],
+        'runtest': _runtest_cmd,
+        'sources': [{
+            'source_urls': ['https://github.com/google/jax/archive/'],
+            'filename': '%(name)s-v%(version)s.tar.gz'
+        }],
+        'testinstall': True,
         'checksums': [
-            {'jax-v0.4.25.tar.gz': '8b30af49688c0c13b82c6f5ce992727c00b5fc6d04a4c6962012f4246fa664eb'},
-            {'jax-0.4.25_fix_env_test_no_log_spam.patch':
-             'a18b5f147569d9ad41025124333a0f04fd0d0e0f9e4309658d7f6b9b838e2e2a'},
+            {'jax-v0.6.2.tar.gz': 'd46cb98795f2c1ccdf2b081e02d9d74b659063679a80beb001ad17d482a60e17'},
+            {'jax-0.6.2_jax-version-fix.patch': 'e15615fd9f4e1698f7c5fd384f146d7b2dbfde3d4657b69bd2d044d75c9fb1d4'},
+            {'jax-0.6.2_fix-slurm-vars.patch': 'a551deb1723c091ad9c29ada7b9e52dcfaaff4098e03d4e1c7a3d1e75f129f42'},
         ],
-        'runtest': "pytest -n %(parallel)s tests",
     }),
 ]
 

@pavelToman
Copy link
Collaborator Author

@boegelbot please test @ jsc-zen3-a100
EB_ARGS="--include-easyblocks-from-pr 3951"
CORE_CNT=16

@boegelbot
Copy link
Collaborator

@pavelToman: Request for testing this PR well received on jsczen3l1.int.jsc-zen3.fz-juelich.de

PR test command 'if [[ develop != 'develop' ]]; then EB_BRANCH=develop ./easybuild_develop.sh 2> /dev/null 1>&2; EB_PREFIX=/home/boegelbot/easybuild/develop source init_env_easybuild_develop.sh; fi; EB_PR=24141 EB_ARGS="--include-easyblocks-from-pr 3951" EB_CONTAINER= EB_REPO=easybuild-easyconfigs EB_BRANCH=develop /opt/software/slurm/bin/sbatch --job-name test_PR_24141 --ntasks="16" --partition=jsczen3g --gres=gpu:1 ~/boegelbot/eb_from_pr_upload_jsc-zen3.sh' executed!

  • exit code: 0
  • output:
Submitted batch job 8187

Test results coming soon (I hope)...

Details

- notification for comment with ID 3376569345 processed

Message to humans: this is just bookkeeping information for me,
it is of no use to you (unless you think I have a bug, which I don't).

@pavelToman pavelToman added the 2024a issues & PRs related to 2024a common toolchains label Oct 7, 2025
@boegelbot
Copy link
Collaborator

Test report by @boegelbot
Using easyblocks from PR(s) easybuilders/easybuild-easyblocks#3951
SUCCESS
Build succeeded for 1 out of 1 (1 easyconfigs in total)
jsczen3g1.int.jsc-zen3.fz-juelich.de - Linux Rocky Linux 9.6, x86_64, AMD EPYC-Milan Processor (zen3), 1 x NVIDIA NVIDIA A100 80GB PCIe, 580.82.07, Python 3.9.21
See https://gist.github.com/boegelbot/1a7a61a6ba06289f3070422aa11ee825 for a full test report.

@Flamefire
Copy link
Contributor

Test report by @Flamefire
Using easyblocks from PR(s) easybuilders/easybuild-easyblocks#3951
FAILED
Build succeeded for 1 out of 2 (1 easyconfigs in total)
i8022 - Linux Rocky Linux 9.6, x86_64, AMD EPYC 7352 24-Core Processor (zen2), 8 x NVIDIA NVIDIA A100-SXM4-40GB, 580.65.06, Python 3.9.21
See https://gist.github.com/Flamefire/669280b42a97cdac50d9ec2701354516 for a full test report.

@pavelToman
Copy link
Collaborator Author

Test report by @pavelToman
Using easyblocks from PR(s) easybuilders/easybuild-easyblocks#3951
SUCCESS
Build succeeded for 1 out of 1 (1 easyconfigs in total)
node4006.donphan.os - Linux RHEL 9.6, x86_64, Intel(R) Xeon(R) Gold 6240 CPU @ 2.60GHz, 1 x NVIDIA NVIDIA A2, 580.82.07, Python 3.9.21
See https://gist.github.com/pavelToman/5518ad35e321458a6dc43c393be43bc5 for a full test report.

@pavelToman
Copy link
Collaborator Author

pavelToman commented Oct 8, 2025

Test report by @Flamefire Using easyblocks from PR(s) easybuilders/easybuild-easyblocks#3951 FAILED Build succeeded for 1 out of 2 (1 easyconfigs in total) i8022 - Linux Rocky Linux 9.6, x86_64, AMD EPYC 7352 24-Core Processor (zen2), 8 x NVIDIA NVIDIA A100-SXM4-40GB, 580.65.06, Python 3.9.21 See https://gist.github.com/Flamefire/669280b42a97cdac50d9ec2701354516 for a full test report.

One test error - missing SLURM_LOCALID

@Flamefire
Copy link
Contributor

Hm, it detects SLURM_JOB_ID and hence assumes a SLURM job. Shall we unset that (or all SLURM vars) in pretestopts?

@pavelToman
Copy link
Collaborator Author

pavelToman commented Oct 8, 2025

Hm, it detects SLURM_JOB_ID and hence assumes a SLURM job. Shall we unset that (or all SLURM vars) in pretestopts?

The same problem had @boegel, what was the solution of "missing SLURM_LOCALID"?

EDIT:
I found it: #23530 (comment)
The solutions seems to be to unset SLURM_JOB_ID as you suggest, but other slurm envs could stay?

@Flamefire
Copy link
Contributor

The solutions seems to be to unset SLURM_JOB_ID as you suggest, but other slurm envs could stay?

I think so

@pavelToman
Copy link
Collaborator Author

Test report by @pavelToman
Using easyblocks from PR(s) easybuilders/easybuild-easyblocks#3951
FAILED
Build succeeded for 1 out of 2 (1 easyconfigs in total)
node4305.litleo.os - Linux RHEL 9.6, x86_64, AMD EPYC 9454P 48-Core Processor, 1 x NVIDIA NVIDIA H100 NVL, 580.82.07, Python 3.9.21
See https://gist.github.com/pavelToman/ada540027d6e6b00c65612b2cdb846e6 for a full test report.

@pavelToman
Copy link
Collaborator Author

Test report by @pavelToman
Using easyblocks from PR(s) easybuilders/easybuild-easyblocks#3951
FAILED
Build succeeded for 1 out of 2 (1 easyconfigs in total)
node3902.accelgor.os - Linux RHEL 9.6, x86_64, AMD EPYC 7413 24-Core Processor, 1 x NVIDIA NVIDIA A100-SXM4-80GB, 580.82.07, Python 3.9.21
See https://gist.github.com/pavelToman/cbf7890cf96da837537349f1715f9d89 for a full test report.

@pavelToman
Copy link
Collaborator Author

pavelToman commented Oct 9, 2025

Both failures on accelgor and litleo comes from 4 tests:

------------------------------------------------ Captured stderr call -------------------------------------------------
E1008 20:27:59.487242 3316452 pjrt_stream_executor_client.cc:2916] Execution of replica 0 failed: INTERNAL: NCCL operation ncclCommInitRankConfig(&comm, clique_key.num_devices(), nccl_unique_id, ranks[i].rank.value(), &comm_config) failed: unhandled cuda error (run with NCCL_DEBUG=INFO for details). Last NCCL warning(error) log entry (may be unrelated) 'Cuda failure 'named symbol not found''.
=============================================== short test summary info ================================================
FAILED tests/pmap_test.py::PythonPmapTest::testCollectivePermuteGrad - jaxlib._jax.XlaRuntimeError: INTERNAL: NCCL op...
FAILED tests/pmap_test.py::CppPmapTest::testCollectivePermuteGrad - jaxlib._jax.XlaRuntimeError: INTERNAL: NCCL opera...
FAILED tests/pmap_test.py::PythonPmapEagerTest::testCollectivePermuteGrad - jaxlib._jax.XlaRuntimeError: INTERNAL: NC...
FAILED tests/pmap_test.py::CppPmapEagerTest::testCollectivePermuteGrad - jaxlib._jax.XlaRuntimeError: INTERNAL: NCCL ...
================================================== 4 failed in 23.85s ==================================================

It seems as a problem with NCCL/CUDA on the clusters - will investigate what is going on there

EDIT: NCCL was build with cuda-compute-capabilities=8.6 - going to rebuild NCCL with right cuda-compute-capabilities for each cluster (8.0 and 9.0 for accelgor and litleo).

@pavelToman
Copy link
Collaborator Author

pavelToman commented Oct 9, 2025

Test report by @pavelToman
Using easyblocks from PR(s) easybuilders/easybuild-easyblocks#3951
SUCCESS
Build succeeded for 1 out of 1 (1 easyconfigs in total)
node3904.accelgor.os - Linux RHEL 9.6, x86_64, AMD EPYC 7413 24-Core Processor, 1 x NVIDIA NVIDIA A100-SXM4-80GB, 580.82.07, Python 3.9.21
See https://gist.github.com/pavelToman/5d32afe09d85bfbf8c0bcc86e917131d for a full test report.

Also on litleo all tests passed but installation failed just after COMPLETED: Installation ended successfully (took 3 hours 56 mins 22 secs) with:

EasyBuild crashed! Please consider reporting a bug, this should not happen...

Traceback (most recent call last):
  File "/usr/lib64/python3.9/runpy.py", line 197, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib64/python3.9/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/user/gent/470/vsc47063/easybuild/easybuild-framework/easybuild/main.py", line 833, in <module>
    main_with_hooks()
  File "/user/gent/470/vsc47063/easybuild/easybuild-framework/easybuild/main.py", line 819, in main_with_hooks
    main(args=args, prepared_cfg_data=(init_session_state, eb_go, cfg_settings))
  File "/user/gent/470/vsc47063/easybuild/easybuild-framework/easybuild/main.py", line 774, in main
    do_cleanup = process_eb_args(orig_paths, eb_go, cfg_settings, modtool, testing, init_session_state,
  File "/user/gent/470/vsc47063/easybuild/easybuild-framework/easybuild/main.py", line 610, in process_eb_args
    test_report_msg = overall_test_report(ecs_with_res, len(paths), overall_success, success_msg, init_session_state)
  File "/user/gent/470/vsc47063/easybuild/easybuild-framework/easybuild/tools/testing.py", line 455, in overall_test_report 
    txt = post_pr_test_report(pr_nrs, GITHUB_EASYCONFIGS_REPO, test_report, msg, init_session_state,
  File "/user/gent/470/vsc47063/easybuild/easybuild-framework/easybuild/tools/testing.py", line 358, in post_pr_test_report 
    gist_url = upload_test_report_as_gist(test_report['full'], descr=descr, fn=fn)
  File "/user/gent/470/vsc47063/easybuild/easybuild-framework/easybuild/tools/testing.py", line 331, in upload_test_report_as_gist
    gist_url = create_gist(test_report, descr=descr, fn=fn, github_user=github_user)
  File "/user/gent/470/vsc47063/easybuild/easybuild-framework/easybuild/tools/github.py", line 837, in create_gist
    status, data = g.gists.post(body=body)
  File "/user/gent/470/vsc47063/easybuild/easybuild-framework/easybuild/base/rest.py", line 140, in post
    return self.request(self.POST, url, json.dumps(body), headers, content_type='application/json')
  File "/user/gent/470/vsc47063/easybuild/easybuild-framework/easybuild/base/rest.py", line 177, in request
    conn = self.get_connection(method, url, body, headers)
  File "/user/gent/470/vsc47063/easybuild/easybuild-framework/easybuild/base/rest.py", line 216, in get_connection
    connection = self.opener.open(request)
  File "/usr/lib64/python3.9/urllib/request.py", line 523, in open
    response = meth(req, response)
  File "/usr/lib64/python3.9/urllib/request.py", line 632, in http_response
    response = self.parent.error(
  File "/usr/lib64/python3.9/urllib/request.py", line 561, in error
    return self._call_chain(*args)
  File "/usr/lib64/python3.9/urllib/request.py", line 494, in _call_chain
    result = func(*args)
  File "/usr/lib64/python3.9/urllib/request.py", line 641, in http_error_default
    raise HTTPError(req.full_url, code, msg, hdrs, fp)   
urllib.error.HTTPError: HTTP Error 500: Internal Server Error

@pavelToman
Copy link
Collaborator Author

Test report by @pavelToman
Using easyblocks from PR(s) easybuilders/easybuild-easyblocks#3951
SUCCESS
Build succeeded for 1 out of 1 (1 easyconfigs in total)
node4303.litleo.os - Linux RHEL 9.6, x86_64, AMD EPYC 9454P 48-Core Processor, 1 x NVIDIA NVIDIA H100 NVL, 580.82.07, Python 3.9.21
See https://gist.github.com/pavelToman/8b89a1d60bb5353949eec70976526359 for a full test report.

@smoors smoors assigned smoors and unassigned smoors Oct 13, 2025
@Flamefire
Copy link
Contributor

Test report by @Flamefire
Using easyblocks from PR(s) easybuilders/easybuild-easyblocks#3951
SUCCESS
Build succeeded for 1 out of 1 (1 easyconfigs in total)
c108 - Linux Rocky Linux 9.6, x86_64, AMD EPYC 9334 32-Core Processor (zen4), 4 x NVIDIA NVIDIA H100, 580.65.06, Python 3.9.21
See https://gist.github.com/Flamefire/fea8810a1cc9da93c751cc2f6f54c8a6 for a full test report.

@Flamefire
Copy link
Contributor

Test report by @Flamefire
Using easyblocks from PR(s) easybuilders/easybuild-easyblocks#3951
FAILED
Build succeeded for 0 out of 2 (1 easyconfigs in total)
i8020 - Linux Rocky Linux 9.6, x86_64, AMD EPYC 7352 24-Core Processor (zen2), 8 x NVIDIA NVIDIA A100-SXM4-40GB, 580.65.06, Python 3.9.21
See https://gist.github.com/Flamefire/d20339513e6d412f7df1c56d1ae5f17b for a full test report.

@boegel
Copy link
Member

boegel commented Oct 21, 2025

@Flamefire Can you try again after fixing the missing Clang/18.1.8-GCCcore-13.3.0-CUDA-12.6.0 ?

@Flamefire
Copy link
Contributor

Test report by @Flamefire
Using easyblocks from PR(s) easybuilders/easybuild-easyblocks#3951
FAILED
Build succeeded for 0 out of 1 (1 easyconfigs in total)
i8016 - Linux Rocky Linux 9.6, x86_64, AMD EPYC 7352 24-Core Processor (zen2), 8 x NVIDIA NVIDIA A100-SXM4-40GB, 580.65.06, Python 3.9.21
See https://gist.github.com/Flamefire/c8b332e781555ea7c7716bd7ef2e4267 for a full test report.

@Flamefire
Copy link
Contributor

Flamefire commented Oct 22, 2025

Test report by @Flamefire Using easyblocks from PR(s) easybuilders/easybuild-easyblocks#3951 FAILED Build succeeded for 1 out of 2 (1 easyconfigs in total) i8022 - Linux Rocky Linux 9.6, x86_64, AMD EPYC 7352 24-Core Processor (zen2), 8 x NVIDIA NVIDIA A100-SXM4-40GB, 580.65.06, Python 3.9.21 See https://gist.github.com/Flamefire/669280b42a97cdac50d9ec2701354516 for a full test report.

One test error - missing SLURM_LOCALID

Same error as before with the latest test.
Maybe we can just patch https://github.com/jax-ml/jax/blob/7963bcc31968aec712e27807ac25b397fdb7a3ee/jax/_src/clusters/slurm_cluster.py#L33 to check for _LOCAL_PROCESS_ID too?

Asking upstream: jax-ml/jax#32799

@pavelToman
Copy link
Collaborator Author

Test report by @Flamefire Using easyblocks from PR(s) easybuilders/easybuild-easyblocks#3951 FAILED Build succeeded for 1 out of 2 (1 easyconfigs in total) i8022 - Linux Rocky Linux 9.6, x86_64, AMD EPYC 7352 24-Core Processor (zen2), 8 x NVIDIA NVIDIA A100-SXM4-40GB, 580.65.06, Python 3.9.21 See https://gist.github.com/Flamefire/669280b42a97cdac50d9ec2701354516 for a full test report.

One test error - missing SLURM_LOCALID

Same error as before with the latest test. Maybe we can just patch https://github.com/jax-ml/jax/blob/7963bcc31968aec712e27807ac25b397fdb7a3ee/jax/_src/clusters/slurm_cluster.py#L33 to check for _LOCAL_PROCESS_ID too?

Asking upstream: jax-ml/jax#32799

Should I made a patch with this change in jax/_src/clusters/slurm_cluster.py ?

@Flamefire
Copy link
Contributor

Should I made a patch with this change in jax/_src/clusters/slurm_cluster.py ?

I think that's best, yes, as other might run into the same issue

@pavelToman
Copy link
Collaborator Author

@boegelbot please test @ jsc-zen3-a100
EB_ARGS="--include-easyblocks-from-pr 3951"
CORE_CNT=16

@boegelbot
Copy link
Collaborator

@pavelToman: Request for testing this PR well received on jsczen3l1.int.jsc-zen3.fz-juelich.de

PR test command 'if [[ develop != 'develop' ]]; then EB_BRANCH=develop ./easybuild_develop.sh 2> /dev/null 1>&2; EB_PREFIX=/home/boegelbot/easybuild/develop source init_env_easybuild_develop.sh; fi; EB_PR=24141 EB_ARGS="--include-easyblocks-from-pr 3951" EB_CONTAINER= EB_REPO=easybuild-easyconfigs EB_BRANCH=develop /opt/software/slurm/bin/sbatch --job-name test_PR_24141 --ntasks="16" --partition=jsczen3g --gres=gpu:1 ~/boegelbot/eb_from_pr_upload_jsc-zen3.sh' executed!

  • exit code: 0
  • output:
Submitted batch job 8631

Test results coming soon (I hope)...

Details

- notification for comment with ID 3484937147 processed

Message to humans: this is just bookkeeping information for me,
it is of no use to you (unless you think I have a bug, which I don't).

@boegelbot
Copy link
Collaborator

Test report by @boegelbot
Using easyblocks from PR(s) easybuilders/easybuild-easyblocks#3951
SUCCESS
Build succeeded for 1 out of 1 (total: 6 hours 46 mins 21 secs) (1 easyconfigs in total)
jsczen3g1.int.jsc-zen3.fz-juelich.de - Linux Rocky Linux 9.6, x86_64, AMD EPYC-Milan Processor (zen3), 1 x NVIDIA NVIDIA A100 80GB PCIe, 580.95.05, Python 3.9.21
See https://gist.github.com/boegelbot/65744bcd0846d181cfcd84b2773def5b for a full test report.

@boegel
Copy link
Member

boegel commented Nov 20, 2025

Test report by @boegel
Using easyblocks from PR(s) easybuilders/easybuild-easyblocks#3951
SUCCESS
Build succeeded for 3 out of 3 (1 easyconfigs in total)
node3908.accelgor.os - Linux RHEL 9.6, x86_64, AMD EPYC 7413 24-Core Processor (zen3), 1 x NVIDIA NVIDIA A100-SXM4-80GB, 580.95.05, Python 3.9.21
See https://gist.github.com/boegel/b09484a7ccd26b9413156ddd860fe4fd for a full test report.

@boegel boegel added this to the next release (5.2.0) milestone Dec 14, 2025
Copy link
Member

@boegel boegel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lgtm

@boegel
Copy link
Member

boegel commented Dec 14, 2025

Going in, thanks @pavelToman!

@boegel boegel merged commit 6e6abce into easybuilders:develop Dec 14, 2025
8 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

2024a issues & PRs related to 2024a common toolchains update

Projects

None yet

Development

Successfully merging this pull request may close these issues.

AlphaFold 3

5 participants