diff --git a/easybuild/easyconfigs/j/jax/jax-0.6.2-gfbf-2024a-CUDA-12.6.0.eb b/easybuild/easyconfigs/j/jax/jax-0.6.2-gfbf-2024a-CUDA-12.6.0.eb new file mode 100644 index 000000000000..8283ce1ac89e --- /dev/null +++ b/easybuild/easyconfigs/j/jax/jax-0.6.2-gfbf-2024a-CUDA-12.6.0.eb @@ -0,0 +1,125 @@ +# This file is an EasyBuild reciPY as per https://github.com/easybuilders/easybuild +# Author: Denis Kristak +# Updated by: Alex Domingo (Vrije Universiteit Brussel) +# Updated by: Thomas Hoffmann (EMBL Heidelberg) +# Updated by: Pavel Tománek (INUITS) +easyblock = 'PythonBundle' + +name = 'jax' +version = '0.6.2' +versionsuffix = '-CUDA-%(cudaver)s' + +homepage = 'https://jax.readthedocs.io/' +description = """Composable transformations of Python+NumPy programs: +differentiate, vectorize, JIT to GPU/TPU, and more""" + +toolchain = {'name': 'gfbf', 'version': '2024a'} +cuda_compute_capabilities = ["7.0", "7.5", "8.0", "8.6", "9.0"] + +builddependencies = [ + ('Bazel', '7.4.1', '-Java-21'), + ('Clang', '18.1.8', versionsuffix), +] + +dependencies = [ + ('CUDA', '12.6.0', '', SYSTEM), + ('cuDNN', '9.5.0.50', versionsuffix, SYSTEM), + ('NCCL', '2.22.3', versionsuffix), + ('Python', '3.12.3'), + ('SciPy-bundle', '2024.05'), + ('absl-py', '2.1.0'), + ('flatbuffers-python', '24.3.25'), + ('ml_dtypes', '0.5.0'), + ('hypothesis', '6.103.1'), + ('zlib', '1.3.1'), +] + +# downloading xla tarball to avoid that Bazel downloads it during the build +# note: following commits *must* be the exact same ones used upstream +# XLA_COMMIT from jax-jaxlib: third_party/xla/workspace.bzl +_xla_commit = '3d5ece64321630dade7ff733ae1353fc3c83d9cc' +# Use sources downloaded by EasyBuild +_jaxlib_buildopts = '--bazel_options="--distdir=%(builddir)s/archives" ' +# create wheels and install them +_jaxlib_buildopts += '--wheels=jaxlib,jax-cuda-plugin,jax-cuda-pjrt ' +# fix jaxlib version +_jaxlib_buildopts += '--bazel_options="--action_env=JAXLIB_RELEASE" ' + +components = [ + ('jaxlib', version, { + 
'sources': [ + { + 'source_urls': ['https://github.com/google/jax/archive/'], + 'filename': 'jax-v%(version)s.tar.gz', + }, + { + 'source_urls': ['https://github.com/openxla/xla/archive'], + 'download_filename': f'{_xla_commit}.tar.gz', + 'filename': f'xla-{_xla_commit[:8]}.tar.gz', + 'extract_cmd': 'mkdir -p %(builddir)s/archives && cp %s %(builddir)s/archives', + }, + ], + 'checksums': [ + {'jax-v0.6.2.tar.gz': + 'd46cb98795f2c1ccdf2b081e02d9d74b659063679a80beb001ad17d482a60e17'}, + {'xla-3d5ece64.tar.gz': + 'fbd20cf83bad78f66977fa7ff67a12e52964abae0b107ddd5486a0355643ec8a'}, + ], + 'start_dir': f'jax-jax-v{version}', + # fix jaxlib version - removes .dev suffix + 'prebuildopts': 'export JAXLIB_RELEASE=%(version)s && ', + 'buildopts': _jaxlib_buildopts, + }), +] + +# JAX otherwise preallocates ~75% of GPU memory per process; disabling avoids OOM/fragmentation. +_runtest_cmd1 = "export XLA_PYTHON_CLIENT_PREALLOCATE=false && " +# force CUDA backend to be considered first (avoid accidental ROCm/TPU plugins or CPU fallback) +_runtest_cmd1 += "export JAX_PLATFORMS=cuda,cpu && " +# pin tests to a single GPU to avoid multiple workers competing for the same VRAM +_runtest_cmd1 += 'export CUDA_VISIBLE_DEVICES=0 && ' +# make libdevice (nvvm/libdevice/libdevice.10.bc) discoverable at runtime for MLIR→PTX lowering during compilation +_runtest_cmd1 += 'export XLA_FLAGS="--xla_gpu_cuda_data_dir=$EBROOTCUDA" && ' +# keep matmul/conv accuracy comparable to strict FP32 (no TF32) +_runtest_cmd1 += 'export NVIDIA_TF32_OVERRIDE=0 && ' +# keep -n 1 for gpu tests - more xdist workers just pile multiple JAX/XLA processes onto the same device +_runtest_cmd1 += "pytest -n 1 tests " +# pallas GPU tests include kernels requiring >100 KiB shared memory per block +# TPU tests are irrelevant here +_runtest_cmd1 += '-k "not pallas and not tpu" ' +# expect the literal string "INFO" from Python logging; with PJRT CUDA the logs often come via absl/glog ("I ...") +# or at DEBUG, so these are 
brittle and not a CUDA correctness check +_runtest_cmd1 += "--deselect=tests/logging_test.py::LoggingTest::test_subprocess_stderr_debug_logging " +_runtest_cmd1 += "--deselect=tests/logging_test.py::LoggingTest::test_subprocess_stderr_info_logging " +# exercises cuDNN Graph/Runtime fusion for a small/odd shape; engine availability is GPU+cuDNN-version dependent +# failing on sm_86 +_runtest_cmd1 += "--deselect=tests/cudnn_fusion_test.py::CudnnFusionTest::test_cudnn_fusion0 " +# Checks a byte-for-byte serialized cuDNN RNN descriptor embedded in StableHLO; +# this legitimately changes across cuDNN minors. Skip for external builds. +_runtest_cmd1 += "--deselect=tests/experimental_rnn_test.py::RnnTest::test_struct_encoding_determinism " +# failing after fix jaxlib version by JAXLIB_RELEASE +_runtest_cmd1 += "--deselect=tests/version_test.py " +# retry randomly failing tests +_runtest_cmd = ' '.join([_runtest_cmd1, '||', _runtest_cmd1, '--last-failed']) + +exts_list = [ + (name, version, { + 'patches': [ + 'jax-0.6.2_jax-version-fix.patch', + 'jax-0.6.2_fix-slurm-vars.patch', + ], + 'runtest': _runtest_cmd, + 'sources': [{ + 'source_urls': ['https://github.com/google/jax/archive/'], + 'filename': '%(name)s-v%(version)s.tar.gz' + }], + 'testinstall': True, + 'checksums': [ + {'jax-v0.6.2.tar.gz': 'd46cb98795f2c1ccdf2b081e02d9d74b659063679a80beb001ad17d482a60e17'}, + {'jax-0.6.2_jax-version-fix.patch': 'e15615fd9f4e1698f7c5fd384f146d7b2dbfde3d4657b69bd2d044d75c9fb1d4'}, + {'jax-0.6.2_fix-slurm-vars.patch': 'a551deb1723c091ad9c29ada7b9e52dcfaaff4098e03d4e1c7a3d1e75f129f42'}, + ], + }), +] + +moduleclass = 'ai' diff --git a/easybuild/easyconfigs/j/jax/jax-0.6.2_fix-slurm-vars.patch b/easybuild/easyconfigs/j/jax/jax-0.6.2_fix-slurm-vars.patch new file mode 100644 index 000000000000..8b72ad65fd3a --- /dev/null +++ b/easybuild/easyconfigs/j/jax/jax-0.6.2_fix-slurm-vars.patch @@ -0,0 +1,26 @@ +From d5c78b3abbfbccb9919be8600741409f16d0846e Mon Sep 17 00:00:00 2001 +From: 
Alexander Grund +Date: Wed, 22 Oct 2025 09:03:16 +0200 +Subject: [PATCH] Check all environment variables for Slurm environment + +In some environments only the SLURM_JOB_ID might be set, e.g. when using hooks for SSH to a node with an existing allocation + +This causes a false positive in the detection and later `KeyError` on e.g. `SLURM_LOCALID` +--- + jax/_src/clusters/slurm_cluster.py | 3 ++- + 1 file changed, 2 insertions(+), 1 deletion(-) + +diff --git a/jax/_src/clusters/slurm_cluster.py b/jax/_src/clusters/slurm_cluster.py +index 8cec07601094..a8bb1b8a287f 100644 +--- a/jax/_src/clusters/slurm_cluster.py ++++ b/jax/_src/clusters/slurm_cluster.py +@@ -30,7 +30,8 @@ class SlurmCluster(clusters.ClusterEnv): + + @classmethod + def is_env_present(cls) -> bool: +- return _JOBID_PARAM in os.environ ++ return all(var in os.environ for var in ++ (_JOBID_PARAM, _NODE_LIST, _PROCESS_COUNT, _PROCESS_ID, _LOCAL_PROCESS_ID)) + + @classmethod + def get_coordinator_address(cls, timeout_secs: int | None) -> str: \ No newline at end of file