{ai}[gfbf/2024a] jax v0.4.35, ml_dtypes v0.5.0 w/ CUDA 12.6.0 WIP#21924

Draft
ThomasHoffmann77 wants to merge 36 commits into easybuilders:develop from ThomasHoffmann77:20241128144208_new_pr_jax0435


@ThomasHoffmann77 ThomasHoffmann77 commented Nov 28, 2024

(created using eb --new-pr)
requires:

TODO:

  • disable Bazel downloads
  • Use Bazel 7.4.1
  • Tests
  • Plugins: python -c "import jax; print(jax.devices())" does not list GPU devices, probably because jaxlib_cuda12_plugin is missing -> use the additional build parameters --build_gpu_plugin, --gpu_plugin_cuda_version=12, --build_gpu_pjrt_plugin, and --build_gpu_kernel_plugin=cuda
  • fix: E jaxlib.xla_extension.XlaRuntimeError: INTERNAL: libdevice not found at ./libdevice.10.bc
    -> export XLA_FLAGS=--xla_gpu_cuda_data_dir=$CUDA_HOME

… jax-0.4.35_easyblock_compat.patch, jax-0.4.35_fix-pybind11-systemlib_cupti.patch
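
The libdevice item in the TODO list above works around the error by pointing XLA at the CUDA installation through XLA_FLAGS. A minimal sketch of doing the same from Python before importing jax is shown here. The helper name `with_cuda_data_dir` is hypothetical, and it assumes CUDA_HOME is set by the loaded CUDA module.

```python
import os


def with_cuda_data_dir(env, cuda_home_var="CUDA_HOME"):
    """Return a copy of env with --xla_gpu_cuda_data_dir appended to XLA_FLAGS.

    Hypothetical helper: it mirrors
    `export XLA_FLAGS=--xla_gpu_cuda_data_dir=$CUDA_HOME`
    but keeps any flags that are already present.
    """
    new_env = dict(env)
    cuda_home = new_env.get(cuda_home_var)
    if not cuda_home:
        # Nothing to point XLA at; leave the environment unchanged.
        return new_env
    flag = "--xla_gpu_cuda_data_dir=" + cuda_home
    existing = new_env.get("XLA_FLAGS", "")
    if flag not in existing.split():
        new_env["XLA_FLAGS"] = (existing + " " + flag).strip()
    return new_env


if __name__ == "__main__":
    # Apply to the current process; this has to happen before `import jax`.
    os.environ.update(with_cuda_data_dir(os.environ))
    print(os.environ.get("XLA_FLAGS", "<XLA_FLAGS not set>"))
```

Appending instead of overwriting keeps any XLA flags a user has already exported.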

github-actions bot commented Nov 28, 2024

Updated software Bazel-6.5.0-GCCcore-13.3.0.eb

Diff against Bazel-7.4.1-GCCcore-13.3.0-Java-21.eb

easybuild/easyconfigs/b/Bazel/Bazel-7.4.1-GCCcore-13.3.0-Java-21.eb

diff --git a/easybuild/easyconfigs/b/Bazel/Bazel-7.4.1-GCCcore-13.3.0-Java-21.eb b/easybuild/easyconfigs/b/Bazel/Bazel-6.5.0-GCCcore-13.3.0.eb
index 5455effbb3..c4c77a8a7d 100644
--- a/easybuild/easyconfigs/b/Bazel/Bazel-7.4.1-GCCcore-13.3.0-Java-21.eb
+++ b/easybuild/easyconfigs/b/Bazel/Bazel-6.5.0-GCCcore-13.3.0.eb
@@ -1,6 +1,5 @@
 name = 'Bazel'
-version = '7.4.1'
-versionsuffix = '-Java-%(javaver)s'
+version = '6.5.0'
 
 homepage = 'https://bazel.io/'
 description = """Bazel is a build tool that builds code quickly and reliably.
@@ -10,19 +9,23 @@ toolchain = {'name': 'GCCcore', 'version': '13.3.0'}
 
 source_urls = ['https://github.com/bazelbuild/%(namelower)s/releases/download/%(version)s']
 sources = ['%(namelower)s-%(version)s-dist.zip']
-checksums = ['83386618bc489f4da36266ef2620ec64a526c686cf07041332caff7c953afaf5']
+patches = ['Bazel-6.5.0_py3.12_pytest_assertEqual.patch']
+checksums = [
+    {'bazel-6.5.0-dist.zip': 'fc89da919415289f29e4ff18a5e01270ece9a6fe83cb60967218bac4a3bb3ed2'},
+    {'Bazel-6.5.0_py3.12_pytest_assertEqual.patch': '2670dd5c393970ba20db2c98cf0208df7190ff339ccb66fee9a6d48aaaf3ede6'},
+]
 
 builddependencies = [
     ('binutils', '2.42'),
     ('Python', '3.12.3'),
     ('Zip', '3.0'),
 ]
+
 dependencies = [
-    ('Java', '21', '', SYSTEM),
+    ('Java', '11.0.20', '', SYSTEM),
 ]
 
 runtest = True
-testopts = "--sandbox_add_mount_pair=$TMPDIR "
-testopts += "-- //examples/cpp:hello-success_test //examples/py/... //examples/py_native:test //examples/shell/..."
+testopts = "-- //examples/cpp:hello-success_test //examples/py/... //examples/py_native:test //examples/shell/..."
 
 moduleclass = 'devel'
Diff against Bazel-6.1.0-GCCcore-12.3.0.eb

easybuild/easyconfigs/b/Bazel/Bazel-6.1.0-GCCcore-12.3.0.eb

diff --git a/easybuild/easyconfigs/b/Bazel/Bazel-6.1.0-GCCcore-12.3.0.eb b/easybuild/easyconfigs/b/Bazel/Bazel-6.5.0-GCCcore-13.3.0.eb
index 1bacc7b936..c4c77a8a7d 100644
--- a/easybuild/easyconfigs/b/Bazel/Bazel-6.1.0-GCCcore-12.3.0.eb
+++ b/easybuild/easyconfigs/b/Bazel/Bazel-6.5.0-GCCcore-13.3.0.eb
@@ -1,27 +1,29 @@
 name = 'Bazel'
-version = '6.1.0'
+version = '6.5.0'
 
 homepage = 'https://bazel.io/'
 description = """Bazel is a build tool that builds code quickly and reliably.
 It is used to build the majority of Google's software."""
 
-toolchain = {'name': 'GCCcore', 'version': '12.3.0'}
+toolchain = {'name': 'GCCcore', 'version': '13.3.0'}
 
 source_urls = ['https://github.com/bazelbuild/%(namelower)s/releases/download/%(version)s']
 sources = ['%(namelower)s-%(version)s-dist.zip']
-patches = ['Bazel-6.3.1_add-symlinks-in-runfiles.patch']
+patches = ['Bazel-6.5.0_py3.12_pytest_assertEqual.patch']
 checksums = [
-    {'bazel-6.1.0-dist.zip': 'c4b85675541cf66ee7cb71514097fdd6c5fc0e02527243617a4f20ca6b4f2932'},
-    {'Bazel-6.3.1_add-symlinks-in-runfiles.patch': '81db53aa87229557480b6f719c99a0f1af9c69dfec12185451e520b0128c3ae2'},
+    {'bazel-6.5.0-dist.zip': 'fc89da919415289f29e4ff18a5e01270ece9a6fe83cb60967218bac4a3bb3ed2'},
+    {'Bazel-6.5.0_py3.12_pytest_assertEqual.patch': '2670dd5c393970ba20db2c98cf0208df7190ff339ccb66fee9a6d48aaaf3ede6'},
 ]
 
 builddependencies = [
-    ('binutils', '2.40'),
-    ('Python', '3.11.3'),
+    ('binutils', '2.42'),
+    ('Python', '3.12.3'),
     ('Zip', '3.0'),
 ]
 
-dependencies = [('Java', '11', '', SYSTEM)]
+dependencies = [
+    ('Java', '11.0.20', '', SYSTEM),
+]
 
 runtest = True
 testopts = "-- //examples/cpp:hello-success_test //examples/py/... //examples/py_native:test //examples/shell/..."
Diff against Bazel-6.3.1-GCCcore-12.2.0.eb

easybuild/easyconfigs/b/Bazel/Bazel-6.3.1-GCCcore-12.2.0.eb

diff --git a/easybuild/easyconfigs/b/Bazel/Bazel-6.3.1-GCCcore-12.2.0.eb b/easybuild/easyconfigs/b/Bazel/Bazel-6.5.0-GCCcore-13.3.0.eb
index 8c284f50a4..c4c77a8a7d 100644
--- a/easybuild/easyconfigs/b/Bazel/Bazel-6.3.1-GCCcore-12.2.0.eb
+++ b/easybuild/easyconfigs/b/Bazel/Bazel-6.5.0-GCCcore-13.3.0.eb
@@ -1,27 +1,29 @@
 name = 'Bazel'
-version = '6.3.1'
+version = '6.5.0'
 
 homepage = 'https://bazel.io/'
 description = """Bazel is a build tool that builds code quickly and reliably.
 It is used to build the majority of Google's software."""
 
-toolchain = {'name': 'GCCcore', 'version': '12.2.0'}
+toolchain = {'name': 'GCCcore', 'version': '13.3.0'}
 
 source_urls = ['https://github.com/bazelbuild/%(namelower)s/releases/download/%(version)s']
 sources = ['%(namelower)s-%(version)s-dist.zip']
-patches = ['Bazel-6.3.1_add-symlinks-in-runfiles.patch']
+patches = ['Bazel-6.5.0_py3.12_pytest_assertEqual.patch']
 checksums = [
-    {'bazel-6.3.1-dist.zip': '2676319e86c5aeab142dccd42434364a33aa330a091c13562b7de87a10e68775'},
-    {'Bazel-6.3.1_add-symlinks-in-runfiles.patch': '81db53aa87229557480b6f719c99a0f1af9c69dfec12185451e520b0128c3ae2'},
+    {'bazel-6.5.0-dist.zip': 'fc89da919415289f29e4ff18a5e01270ece9a6fe83cb60967218bac4a3bb3ed2'},
+    {'Bazel-6.5.0_py3.12_pytest_assertEqual.patch': '2670dd5c393970ba20db2c98cf0208df7190ff339ccb66fee9a6d48aaaf3ede6'},
 ]
 
 builddependencies = [
-    ('binutils', '2.39'),
-    ('Python', '3.10.8'),
+    ('binutils', '2.42'),
+    ('Python', '3.12.3'),
     ('Zip', '3.0'),
 ]
 
-dependencies = [('Java', '11', '', SYSTEM)]
+dependencies = [
+    ('Java', '11.0.20', '', SYSTEM),
+]
 
 runtest = True
 testopts = "-- //examples/cpp:hello-success_test //examples/py/... //examples/py_native:test //examples/shell/..."
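
Each `checksums` entry in the Bazel diffs above maps a source or patch filename to a 64-character hex digest. The sketch below shows how such a list could be checked by hand, assuming the digests are SHA-256. The function names are hypothetical and this is not EasyBuild's own verification code.

```python
import hashlib


def sha256_hex(data):
    """Return the SHA-256 hex digest of a bytes object."""
    return hashlib.sha256(data).hexdigest()


def find_mismatches(checksums, read_bytes):
    """Return the filenames whose digest differs from the expected one.

    checksums  -- list of {filename: expected_hex_digest} dicts, the same
                  shape as the `checksums` list in the easyconfig
    read_bytes -- callable mapping a filename to its content as bytes
    """
    bad = []
    for entry in checksums:
        for filename, expected in entry.items():
            if sha256_hex(read_bytes(filename)) != expected:
                bad.append(filename)
    return bad


def read_from_disk(filename):
    """Read a file from the current directory (assumed download location)."""
    with open(filename, "rb") as fh:
        return fh.read()
```

Calling `find_mismatches(checksums, read_from_disk)` in the directory holding the downloaded files returns an empty list when every digest matches.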

Updated software jax-0.4.35-gfbf-2024a-CUDA-12.6.0.eb

Diff against jax-0.4.25-gfbf-2023a.eb

easybuild/easyconfigs/j/jax/jax-0.4.25-gfbf-2023a.eb

diff --git a/easybuild/easyconfigs/j/jax/jax-0.4.25-gfbf-2023a.eb b/easybuild/easyconfigs/j/jax/jax-0.4.35-gfbf-2024a-CUDA-12.6.0.eb
index e61bc4719d..5a431a6807 100644
--- a/easybuild/easyconfigs/j/jax/jax-0.4.25-gfbf-2023a.eb
+++ b/easybuild/easyconfigs/j/jax/jax-0.4.35-gfbf-2024a-CUDA-12.6.0.eb
@@ -6,54 +6,83 @@
 easyblock = 'PythonBundle'
 
 name = 'jax'
-version = '0.4.25'
+version = '0.4.35'
+versionsuffix = '-CUDA-%(cudaver)s'
 
 homepage = 'https://jax.readthedocs.io/'
 description = """Composable transformations of Python+NumPy programs:
 differentiate, vectorize, JIT to GPU/TPU, and more"""
 
-toolchain = {'name': 'gfbf', 'version': '2023a'}
+toolchain = {'name': 'gfbf', 'version': '2024a'}
+cuda_compute_capabilities = ["5.0", "6.0", "6.1", "7.0", "7.5", "8.0", "8.6", "9.0"]
 
 builddependencies = [
-    ('Bazel', '6.3.1'),
-    ('pytest-xdist', '3.3.1'),
-    ('git', '2.41.0', '-nodocs'),  # bazel uses git to fetch repositories
-    ('matplotlib', '3.7.2'),  # required for tests/lobpcg_test.py
-    ('poetry', '1.5.1'),
-    ('pybind11', '2.11.1'),
+    # ('Bazel', '7.4.1'),  TODO: problems with @@local_config_python//:py3_runtime:
+    # Error in fail: interpreter_path must be an absolute path
+    # Bazel 6.5.0 (download) works.
+    ('pybind11', '2.13.6'),  # 2.12.0 ? SciPy-bundle has pybind/2.12.0.
+    ('pytest-xdist', '3.6.1'),
+    ('git', '2.45.1'),  # bazel uses git to fetch repositories
+    ('matplotlib', '3.9.2'),  # required for tests/lobpcg_test.py
+    ('poetry', '1.8.3'),
+    ('Clang', '18.1.8', versionsuffix)
 ]
 
 dependencies = [
-    ('Python', '3.11.3'),
-    ('SciPy-bundle', '2023.07'),
+    ('CUDA', '12.6.0', '', SYSTEM),  # 12.6.2 ?
+    ('cuDNN', '9.5.0.50', versionsuffix, SYSTEM),
+    ('NCCL', '2.22.3', versionsuffix),
+    ('Python', '3.12.3'),
+    ('SciPy-bundle', '2024.05'),  # 2024.11 ?
     ('absl-py', '2.1.0'),
-    ('flatbuffers-python', '23.5.26'),
-    ('ml_dtypes', '0.3.2'),
-    ('zlib', '1.2.13'),
+    ('flatbuffers-python', '24.3.25'),
+    ('ml_dtypes', '0.5.0'),
+    ('zlib', '1.3.1'),
 ]
 
 # downloading xla and other tarballs to avoid that Bazel downloads it during the build
 local_extract_cmd = 'mkdir -p %(builddir)s/archives && cp %s %(builddir)s/archives'
 # note: following commits *must* be the exact same onces used upstream
 # XLA_COMMIT from jax-jaxlib: third_party/xla/workspace.bzl
-local_xla_commit = '4ccfe33c71665ddcbca5b127fefe8baa3ed632d4'
+local_xla_commit = '76da730179313b3bebad6dea6861768421b7358c'
 # TFRT_COMMIT from xla: third_party/tsl/third_party/tf_runtime/workspace.bzl
-local_tfrt_commit = '0aeefb1660d7e37964b2bb71b1f518096bda9a25'
+local_tfrt_commit = '0aeefb1660d7e37964b2bb71b1f518096bda9a25'  # TODO: still required?
+# TODO: add other downloads
 
 # Use sources downloaded by EasyBuild
 _jaxlib_buildopts = '--bazel_options="--distdir=%(builddir)s/archives" '
 # Use dependencies from EasyBuild
 _jaxlib_buildopts += '--bazel_options="--action_env=TF_SYSTEM_LIBS=pybind11" '
-_jaxlib_buildopts += '--bazel_options="--action_env=CPATH=$EBROOTPYBIND11/include" '
+_jaxlib_buildopts += '--bazel_options="--action_env=CPATH=$EBROOTPYBIND11/include:$EBROOTCUDA/extras/CUPTI/include" '
 # Avoid warning (treated as error) in upb/table.c
-_jaxlib_buildopts += '--bazel_options="--copt=-Wno-maybe-uninitialized" '
+_jaxlib_buildopts += '--bazel_options="--copt=-Wno-maybe-uninitialized" '  # TODO: still required?
+# _jaxlib_buildopts += '--nouse_clang '  #TODO: avoid clang (?)
+_jaxlib_buildopts += '--cuda_version=%(cudaver)s '
+_jaxlib_buildopts += '--python_bin_path=$EBROOTPYTHON/bin/python3 '
+# Do not use hermetic CUDA/cuDNN/NCCL: (requires action_env=CPATH=$EBROOTCUDA/extras/CUPTI/include";
+# requires patch of external/xla/xla/tsl/cuda/cupti_stub.cc and jaxlib/gpu/vendor.h (#include <cupti.h>):
+_jaxlib_buildopts += """--bazel_options=--repo_env=LOCAL_CUDNN_PATH="$EBROOTCUDNN" """
+_jaxlib_buildopts += """--bazel_options=--repo_env=LOCAL_NCCL_PATH="$EBROOTNCCL" """
+_jaxlib_buildopts += """--bazel_options=--repo_env=LOCAL_CUDA_PATH="$EBROOTCUDA" """
+_jaxlib_buildopts += """--bazel_options="--copt=-Ithird_party/gpus/cuda/extras/CUPTI/include" """
+
+_plugins_buildopts = """--enable_cuda """
+_plugins_buildopts += """--build_gpu_plugin """
+# _plugins_buildopts +="""--gpu_plugin_cuda_version=12 """
+_plugins_buildopts += """--build_gpu_pjrt_plugin """
+_plugins_buildopts += """--build_gpu_kernel_plugin=cuda """
+
+# get rid of .devDate versionsuffix:  TODO: find a better way
+# _no_devtag = """ export JAX_RELEASE && export JAXLIB_RELEASE && """  does not work (?)
+_no_devtag = """ sed -i "s/version=__version__/version='%(version)s'/g" setup.py && """
+_jaxlib_buildopts += """--bazel_options="--action_env=JAXLIB_RELEASE=1" """  # required?
 
 components = [
     ('jaxlib', version, {
         'sources': [
             {
                 'source_urls': ['https://github.com/google/jax/archive/'],
-                'filename': '%(name)s-v%(version)s.tar.gz',
+                'filename': 'jax-v%(version)s.tar.gz',
             },
             {
                 'source_urls': ['https://github.com/openxla/xla/archive'],
@@ -68,38 +97,120 @@ components = [
                 'extract_cmd': local_extract_cmd,
             },
         ],
-        'patches': ['jax-0.4.25_fix-pybind11-systemlib.patch'],
+        'patches': [
+            'jax-0.4.35_easyblock_compat.patch',
+            'jax-0.4.35_fix-pybind11-systemlib_cupti_CUDA_HOME.patch',
+            'jax-0.4.35_version.patch',
+        ],
         'checksums': [
-            {'jaxlib-v0.4.25.tar.gz':
-             'fc1197c401924942eb14185a61688d0c476e3e81ff71f9dc95e620b57c06eec8'},
-            {'xla-4ccfe33c.tar.gz':
-             '8a59b9af7d0850059d7043f7043c780066d61538f3af536e8a10d3d717f35089'},
+            {'jax-v0.4.35.tar.gz':
+             '65e086708ae56670676b7b2340ad82b901d8c9993d1241a839c8990bdb8d6212'},
+            {'xla-76da7301.tar.gz':
+             'd67ced09b69ab8d7b26fa4cd5f48b22db57eb330294a35f6e1d462ee17066757'},
             {'tf_runtime-0aeefb16.tar.gz':
              'a3df827d7896774cb1d80bf4e1c79ab05c268f29bd4d3db1fb5a4b9c2079d8e3'},
-            {'jax-0.4.25_fix-pybind11-systemlib.patch':
-             'daad5b726d1a138431b05eb60ecf4c89c7b5148eb939721800bdf43d804ca033'},
+            {'jax-0.4.35_easyblock_compat.patch':
+             'cbf4ad92b8438c4ce2a975efce1c47c57d4c3b117bceee071ab660f964057223'},
+            {'jax-0.4.35_fix-pybind11-systemlib_cupti_CUDA_HOME.patch':
+             'fa5273d31651579590f7291fc151836f43024f74f4c89243dc4c6a417284e7ce'},
+            {'jax-0.4.35_version.patch':
+             'cd2139a7802abf14b4b2cecee331aed80fff2ef91e16fa105093aea0795455e8'},
         ],
-        'start_dir': 'jax-jaxlib-v%(version)s',
-        'buildopts': _jaxlib_buildopts
+        'start_dir': 'jax-jax-v%(version)s',
+        'buildopts': _jaxlib_buildopts,
+        'prebuildopts': ' mkdir third_party/gpus/cuda/extras/ -p && ' +
+                        'ln -s $EBROOTCUDA/extras/CUPTI third_party/gpus/cuda/extras --relative &&' +
+                        _no_devtag
     }),
-]
-
-exts_list = [
-    (name, version, {
+    # build jaxlib first and then plugins in 2nd interation:
+    ('jaxlib', version, {
         'sources': [
             {
                 'source_urls': ['https://github.com/google/jax/archive/'],
-                'filename': '%(name)s-v%(version)s.tar.gz',
+                'filename': 'jax-v%(version)s.tar.gz',
+            },
+            {
+                'source_urls': ['https://github.com/openxla/xla/archive'],
+                'download_filename': '%s.tar.gz' % local_xla_commit,
+                'filename': 'xla-%s.tar.gz' % local_xla_commit[:8],
+                'extract_cmd': local_extract_cmd,
+            },
+            {
+                'source_urls': ['https://github.com/tensorflow/runtime/archive'],
+                'download_filename': '%s.tar.gz' % local_tfrt_commit,
+                'filename': 'tf_runtime-%s.tar.gz' % local_tfrt_commit[:8],
+                'extract_cmd': local_extract_cmd,
             },
         ],
-        'patches': ['jax-0.4.25_fix_env_test_no_log_spam.patch'],
         'checksums': [
-            {'jax-v0.4.25.tar.gz': '8b30af49688c0c13b82c6f5ce992727c00b5fc6d04a4c6962012f4246fa664eb'},
-            {'jax-0.4.25_fix_env_test_no_log_spam.patch':
-             'a18b5f147569d9ad41025124333a0f04fd0d0e0f9e4309658d7f6b9b838e2e2a'},
+            {'jax-v0.4.35.tar.gz':
+             '65e086708ae56670676b7b2340ad82b901d8c9993d1241a839c8990bdb8d6212'},
+            {'xla-76da7301.tar.gz':
+             'd67ced09b69ab8d7b26fa4cd5f48b22db57eb330294a35f6e1d462ee17066757'},
+            {'tf_runtime-0aeefb16.tar.gz':
+             'a3df827d7896774cb1d80bf4e1c79ab05c268f29bd4d3db1fb5a4b9c2079d8e3'},
+        ],
+        'start_dir': 'jax-jax-v%(version)s',
+        'buildopts': _jaxlib_buildopts + _plugins_buildopts,
+        'prebuildopts': _no_devtag
+    }),
+]
+# failing:
+# tests/lax_test.py::FunctionAccuracyTest::testSuccessOnComplexPlane_expm1_complex128 FAILED [ 98%]
+# tests/lax_test.py::FunctionAccuracyTest::testSuccessOnComplexPlane_expm1_complex64 FAILED [ 98%]
+# tests/lax_test.py::FunctionAccuracyTest::testSuccessOnComplexPlane_tan_complex128 FAILED [ 99%]
+# tests/lax_test.py::FunctionAccuracyTest::testSuccessOnComplexPlane_tan_complex64 FAILED [ 99%]
+# FAILED tests/lax_test.py::FunctionAccuracyTest::testSuccessOnComplexPlane_expm1_complex128 - AssertionError:
+# FAILED tests/lax_test.py::FunctionAccuracyTest::testSuccessOnComplexPlane_expm1_complex64 - AssertionError:
+# FAILED tests/lax_test.py::FunctionAccuracyTest::testSuccessOnComplexPlane_tan_complex128 - AssertionError:
+# FAILED tests/lax_test.py::FunctionAccuracyTest::testSuccessOnComplexPlane_tan_complex64 - AssertionError:
+# tests/nn_test.py::NNFunctionsTest::testDotProductAttentionMask7 FAILED   [ 10%]
+# FAILED tests/nn_test.py::NNFunctionsTest::testDotProductAttentionMask7 - AssertionError:
+#
+
+
+# Some tests require an isolated run:  TODO: still required?
+local_isolated_tests = [
+    'tests/host_callback_test.py::HostCallbackTapTest::test_tap_scan_custom_jvp',
+    'tests/host_callback_test.py::HostCallbackTapTest::test_tap_transforms_doc',
+    'tests/lax_scipy_special_functions_test.py::LaxScipySpcialFunctionsTest' +
+    '::testScipySpecialFun_gammainc_s_2x1x4_float32_float32',
+]
+# deliberately not testing in parallel, as that results in (additional) failing tests;
+# use XLA_PYTHON_CLIENT_ALLOCATOR=platform to allocate and deallocate GPU memory during testing,
+# see https://github.com/google/jax/issues/7323 and
+# https://github.com/google/jax/blob/main/docs/gpu_memory_allocation.rst;
+# use CUDA_VISIBLE_DEVICES=0 to avoid failing tests on systems with multiple GPUs;
+# use NVIDIA_TF32_OVERRIDE=0 to avoid loosing numerical precision by disabling TF32 Tensor Cores;
+local_test_exports = [
+    "NVIDIA_TF32_OVERRIDE=0",
+    "CUDA_VISIBLE_DEVICES=0",
+    "XLA_PYTHON_CLIENT_ALLOCATOR=platform",
+    "JAX_ENABLE_X64=true",
+]
+local_test = ''.join(['export %s;' % x for x in local_test_exports])
+# run all tests at once except for local_isolated_tests:
+local_test += "pytest -vv tests %s && " % ' '.join(['--deselect %s' % x for x in local_isolated_tests])
+# run remaining local_isolated_tests separately:
+local_test += ' && '.join(['pytest -vv %s' % x for x in local_isolated_tests])
+
+
+exts_list = [
+    (name, version, {
+        'source_tmpl': '%(name)s-v%(version)s.tar.gz',
+        'source_urls': ['https://github.com/google/jax/archive/'],
+        'patches': ['jax-0.4.35_version.patch'],
+        'checksums': [
+            {'jax-v0.4.35.tar.gz': '65e086708ae56670676b7b2340ad82b901d8c9993d1241a839c8990bdb8d6212'},
+            {'jax-0.4.35_version.patch': 'cd2139a7802abf14b4b2cecee331aed80fff2ef91e16fa105093aea0795455e8'},
         ],
-        'runtest': "pytest -n %(parallel)s tests",
+        'runtest': False,  # tmp
+        'preinstallopts': _no_devtag
     }),
 ]
+sanity_check_commands = [
+    """python -c "import jax_cuda"$(echo $EBVERSIONCUDA|awk -F '.' '{print $1}')"_plugin" """
+]
+
 
 moduleclass = 'ai'
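
The `local_test` string in the diff above is assembled from two lists. The sketch below wraps the same expressions in a function so the resulting shell command can be inspected. The function name `build_test_command` is hypothetical and the lists are shortened for illustration. In this revision `runtest` is set to `False`, so the command is defined but not executed.

```python
def build_test_command(exports, isolated_tests):
    """Compose the pytest shell command the same way the easyconfig does."""
    # export every variable first, separated by ';'
    cmd = ''.join(['export %s;' % x for x in exports])
    # run all tests at once, deselecting the ones that need isolation
    cmd += "pytest -vv tests %s && " % ' '.join(
        ['--deselect %s' % x for x in isolated_tests])
    # then run each isolated test in its own pytest invocation
    cmd += ' && '.join(['pytest -vv %s' % x for x in isolated_tests])
    return cmd


# Shortened copies of the lists from the easyconfig (illustration only).
local_test_exports = [
    "NVIDIA_TF32_OVERRIDE=0",
    "CUDA_VISIBLE_DEVICES=0",
]
local_isolated_tests = [
    'tests/host_callback_test.py::HostCallbackTapTest::test_tap_scan_custom_jvp',
    'tests/host_callback_test.py::HostCallbackTapTest::test_tap_transforms_doc',
]

if __name__ == "__main__":
    print(build_test_command(local_test_exports, local_isolated_tests))
```

Because the parts are chained with `&&`, the isolated tests only run if the main pytest invocation succeeds.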
Diff against jax-0.4.25-gfbf-2023a-CUDA-12.1.1.eb

easybuild/easyconfigs/j/jax/jax-0.4.25-gfbf-2023a-CUDA-12.1.1.eb

diff --git a/easybuild/easyconfigs/j/jax/jax-0.4.25-gfbf-2023a-CUDA-12.1.1.eb b/easybuild/easyconfigs/j/jax/jax-0.4.35-gfbf-2024a-CUDA-12.6.0.eb
index e0c599039d..5a431a6807 100644
--- a/easybuild/easyconfigs/j/jax/jax-0.4.25-gfbf-2023a-CUDA-12.1.1.eb
+++ b/easybuild/easyconfigs/j/jax/jax-0.4.35-gfbf-2024a-CUDA-12.6.0.eb
@@ -6,84 +6,128 @@
 easyblock = 'PythonBundle'
 
 name = 'jax'
-version = '0.4.25'
+version = '0.4.35'
 versionsuffix = '-CUDA-%(cudaver)s'
 
 homepage = 'https://jax.readthedocs.io/'
 description = """Composable transformations of Python+NumPy programs:
 differentiate, vectorize, JIT to GPU/TPU, and more"""
 
-toolchain = {'name': 'gfbf', 'version': '2023a'}
+toolchain = {'name': 'gfbf', 'version': '2024a'}
 cuda_compute_capabilities = ["5.0", "6.0", "6.1", "7.0", "7.5", "8.0", "8.6", "9.0"]
 
 builddependencies = [
-    ('Bazel', '6.3.1'),
-    ('pytest-xdist', '3.3.1'),
-    ('git', '2.41.0', '-nodocs'),  # bazel uses git to fetch repositories
-    ('matplotlib', '3.7.2'),  # required for tests/lobpcg_test.py
-    ('poetry', '1.5.1'),
-    ('pybind11', '2.11.1'),
+    # ('Bazel', '7.4.1'),  TODO: problems with @@local_config_python//:py3_runtime:
+    # Error in fail: interpreter_path must be an absolute path
+    # Bazel 6.5.0 (download) works.
+    ('pybind11', '2.13.6'),  # 2.12.0 ? SciPy-bundle has pybind/2.12.0.
+    ('pytest-xdist', '3.6.1'),
+    ('git', '2.45.1'),  # bazel uses git to fetch repositories
+    ('matplotlib', '3.9.2'),  # required for tests/lobpcg_test.py
+    ('poetry', '1.8.3'),
+    ('Clang', '18.1.8', versionsuffix)
 ]
 
 dependencies = [
-    ('CUDA', '12.1.1', '', SYSTEM),
-    ('cuDNN', '8.9.2.26', versionsuffix, SYSTEM),
-    ('NCCL', '2.18.3', versionsuffix),
-    ('Python', '3.11.3'),
-    ('SciPy-bundle', '2023.07'),
+    ('CUDA', '12.6.0', '', SYSTEM),  # 12.6.2 ?
+    ('cuDNN', '9.5.0.50', versionsuffix, SYSTEM),
+    ('NCCL', '2.22.3', versionsuffix),
+    ('Python', '3.12.3'),
+    ('SciPy-bundle', '2024.05'),  # 2024.11 ?
     ('absl-py', '2.1.0'),
-    ('flatbuffers-python', '23.5.26'),
-    ('ml_dtypes', '0.3.2'),
-    ('zlib', '1.2.13'),
+    ('flatbuffers-python', '24.3.25'),
+    ('ml_dtypes', '0.5.0'),
+    ('zlib', '1.3.1'),
 ]
 
 # downloading xla and other tarballs to avoid that Bazel downloads it during the build
 local_extract_cmd = 'mkdir -p %(builddir)s/archives && cp %s %(builddir)s/archives'
 # note: following commits *must* be the exact same onces used upstream
 # XLA_COMMIT from jax-jaxlib: third_party/xla/workspace.bzl
-local_xla_commit = '4ccfe33c71665ddcbca5b127fefe8baa3ed632d4'
+local_xla_commit = '76da730179313b3bebad6dea6861768421b7358c'
 # TFRT_COMMIT from xla: third_party/tsl/third_party/tf_runtime/workspace.bzl
-local_tfrt_commit = '0aeefb1660d7e37964b2bb71b1f518096bda9a25'
+local_tfrt_commit = '0aeefb1660d7e37964b2bb71b1f518096bda9a25'  # TODO: still required?
+# TODO: add other downloads
 
 # Use sources downloaded by EasyBuild
 _jaxlib_buildopts = '--bazel_options="--distdir=%(builddir)s/archives" '
 # Use dependencies from EasyBuild
 _jaxlib_buildopts += '--bazel_options="--action_env=TF_SYSTEM_LIBS=pybind11" '
-_jaxlib_buildopts += '--bazel_options="--action_env=CPATH=$EBROOTPYBIND11/include" '
+_jaxlib_buildopts += '--bazel_options="--action_env=CPATH=$EBROOTPYBIND11/include:$EBROOTCUDA/extras/CUPTI/include" '
 # Avoid warning (treated as error) in upb/table.c
-_jaxlib_buildopts += '--bazel_options="--copt=-Wno-maybe-uninitialized" '
+_jaxlib_buildopts += '--bazel_options="--copt=-Wno-maybe-uninitialized" '  # TODO: still required?
+# _jaxlib_buildopts += '--nouse_clang '  #TODO: avoid clang (?)
+_jaxlib_buildopts += '--cuda_version=%(cudaver)s '
+_jaxlib_buildopts += '--python_bin_path=$EBROOTPYTHON/bin/python3 '
+# Do not use hermetic CUDA/cuDNN/NCCL: (requires action_env=CPATH=$EBROOTCUDA/extras/CUPTI/include";
+# requires patch of external/xla/xla/tsl/cuda/cupti_stub.cc and jaxlib/gpu/vendor.h (#include <cupti.h>):
+_jaxlib_buildopts += """--bazel_options=--repo_env=LOCAL_CUDNN_PATH="$EBROOTCUDNN" """
+_jaxlib_buildopts += """--bazel_options=--repo_env=LOCAL_NCCL_PATH="$EBROOTNCCL" """
+_jaxlib_buildopts += """--bazel_options=--repo_env=LOCAL_CUDA_PATH="$EBROOTCUDA" """
+_jaxlib_buildopts += """--bazel_options="--copt=-Ithird_party/gpus/cuda/extras/CUPTI/include" """
 
-# Some tests require an isolated run:
-local_isolated_tests = [
-    'tests/host_callback_test.py::HostCallbackTapTest::test_tap_scan_custom_jvp',
-    'tests/host_callback_test.py::HostCallbackTapTest::test_tap_transforms_doc',
-    'tests/lax_scipy_special_functions_test.py::LaxScipySpcialFunctionsTest' +
-    '::testScipySpecialFun_gammainc_s_2x1x4_float32_float32',
-]
-# deliberately not testing in parallel, as that results in (additional) failing tests;
-# use XLA_PYTHON_CLIENT_ALLOCATOR=platform to allocate and deallocate GPU memory during testing,
-# see https://github.com/google/jax/issues/7323 and
-# https://github.com/google/jax/blob/main/docs/gpu_memory_allocation.rst;
-# use CUDA_VISIBLE_DEVICES=0 to avoid failing tests on systems with multiple GPUs;
-# use NVIDIA_TF32_OVERRIDE=0 to avoid loosing numerical precision by disabling TF32 Tensor Cores;
-local_test_exports = [
-    "NVIDIA_TF32_OVERRIDE=0",
-    "CUDA_VISIBLE_DEVICES=0",
-    "XLA_PYTHON_CLIENT_ALLOCATOR=platform",
-    "JAX_ENABLE_X64=true",
-]
-local_test = ''.join(['export %s;' % x for x in local_test_exports])
-# run all tests at once except for local_isolated_tests:
-local_test += "pytest -vv tests %s && " % ' '.join(['--deselect %s' % x for x in local_isolated_tests])
-# run remaining local_isolated_tests separately:
-local_test += ' && '.join(['pytest -vv %s' % x for x in local_isolated_tests])
+_plugins_buildopts = """--enable_cuda """
+_plugins_buildopts += """--build_gpu_plugin """
+# _plugins_buildopts +="""--gpu_plugin_cuda_version=12 """
+_plugins_buildopts += """--build_gpu_pjrt_plugin """
+_plugins_buildopts += """--build_gpu_kernel_plugin=cuda """
+
+# get rid of .devDate versionsuffix:  TODO: find a better way
+# _no_devtag = """ export JAX_RELEASE && export JAXLIB_RELEASE && """  does not work (?)
+_no_devtag = """ sed -i "s/version=__version__/version='%(version)s'/g" setup.py && """
+_jaxlib_buildopts += """--bazel_options="--action_env=JAXLIB_RELEASE=1" """  # required?
 
 components = [
     ('jaxlib', version, {
         'sources': [
             {
                 'source_urls': ['https://github.com/google/jax/archive/'],
-                'filename': '%(name)s-v%(version)s.tar.gz',
+                'filename': 'jax-v%(version)s.tar.gz',
+            },
+            {
+                'source_urls': ['https://github.com/openxla/xla/archive'],
+                'download_filename': '%s.tar.gz' % local_xla_commit,
+                'filename': 'xla-%s.tar.gz' % local_xla_commit[:8],
+                'extract_cmd': local_extract_cmd,
+            },
+            {
+                'source_urls': ['https://github.com/tensorflow/runtime/archive'],
+                'download_filename': '%s.tar.gz' % local_tfrt_commit,
+                'filename': 'tf_runtime-%s.tar.gz' % local_tfrt_commit[:8],
+                'extract_cmd': local_extract_cmd,
+            },
+        ],
+        'patches': [
+            'jax-0.4.35_easyblock_compat.patch',
+            'jax-0.4.35_fix-pybind11-systemlib_cupti_CUDA_HOME.patch',
+            'jax-0.4.35_version.patch',
+        ],
+        'checksums': [
+            {'jax-v0.4.35.tar.gz':
+             '65e086708ae56670676b7b2340ad82b901d8c9993d1241a839c8990bdb8d6212'},
+            {'xla-76da7301.tar.gz':
+             'd67ced09b69ab8d7b26fa4cd5f48b22db57eb330294a35f6e1d462ee17066757'},
+            {'tf_runtime-0aeefb16.tar.gz':
+             'a3df827d7896774cb1d80bf4e1c79ab05c268f29bd4d3db1fb5a4b9c2079d8e3'},
+            {'jax-0.4.35_easyblock_compat.patch':
+             'cbf4ad92b8438c4ce2a975efce1c47c57d4c3b117bceee071ab660f964057223'},
+            {'jax-0.4.35_fix-pybind11-systemlib_cupti_CUDA_HOME.patch':
+             'fa5273d31651579590f7291fc151836f43024f74f4c89243dc4c6a417284e7ce'},
+            {'jax-0.4.35_version.patch':
+             'cd2139a7802abf14b4b2cecee331aed80fff2ef91e16fa105093aea0795455e8'},
+        ],
+        'start_dir': 'jax-jax-v%(version)s',
+        'buildopts': _jaxlib_buildopts,
+        'prebuildopts': ' mkdir third_party/gpus/cuda/extras/ -p && ' +
+                        'ln -s $EBROOTCUDA/extras/CUPTI third_party/gpus/cuda/extras --relative &&' +
+                        _no_devtag
+    }),
+    # build jaxlib first and then plugins in 2nd interation:
+    ('jaxlib', version, {
+        'sources': [
+            {
+                'source_urls': ['https://github.com/google/jax/archive/'],
+                'filename': 'jax-v%(version)s.tar.gz',
             },
             {
                 'source_urls': ['https://github.com/openxla/xla/archive'],
@@ -98,34 +142,75 @@ components = [
                 'extract_cmd': local_extract_cmd,
             },
         ],
-        'patches': ['jax-0.4.25_fix-pybind11-systemlib.patch'],
         'checksums': [
-            {'jaxlib-v0.4.25.tar.gz':
-             'fc1197c401924942eb14185a61688d0c476e3e81ff71f9dc95e620b57c06eec8'},
-            {'xla-4ccfe33c.tar.gz':
-             '8a59b9af7d0850059d7043f7043c780066d61538f3af536e8a10d3d717f35089'},
+            {'jax-v0.4.35.tar.gz':
+             '65e086708ae56670676b7b2340ad82b901d8c9993d1241a839c8990bdb8d6212'},
+            {'xla-76da7301.tar.gz':
+             'd67ced09b69ab8d7b26fa4cd5f48b22db57eb330294a35f6e1d462ee17066757'},
             {'tf_runtime-0aeefb16.tar.gz':
              'a3df827d7896774cb1d80bf4e1c79ab05c268f29bd4d3db1fb5a4b9c2079d8e3'},
-            {'jax-0.4.25_fix-pybind11-systemlib.patch':
-             'daad5b726d1a138431b05eb60ecf4c89c7b5148eb939721800bdf43d804ca033'},
         ],
-        'start_dir': 'jax-jaxlib-v%(version)s',
-        'buildopts': _jaxlib_buildopts
+        'start_dir': 'jax-jax-v%(version)s',
+        'buildopts': _jaxlib_buildopts + _plugins_buildopts,
+        'prebuildopts': _no_devtag
     }),
 ]
+# failing:
+# tests/lax_test.py::FunctionAccuracyTest::testSuccessOnComplexPlane_expm1_complex128 FAILED [ 98%]
+# tests/lax_test.py::FunctionAccuracyTest::testSuccessOnComplexPlane_expm1_complex64 FAILED [ 98%]
+# tests/lax_test.py::FunctionAccuracyTest::testSuccessOnComplexPlane_tan_complex128 FAILED [ 99%]
+# tests/lax_test.py::FunctionAccuracyTest::testSuccessOnComplexPlane_tan_complex64 FAILED [ 99%]
+# FAILED tests/lax_test.py::FunctionAccuracyTest::testSuccessOnComplexPlane_expm1_complex128 - AssertionError:
+# FAILED tests/lax_test.py::FunctionAccuracyTest::testSuccessOnComplexPlane_expm1_complex64 - AssertionError:
+# FAILED tests/lax_test.py::FunctionAccuracyTest::testSuccessOnComplexPlane_tan_complex128 - AssertionError:
+# FAILED tests/lax_test.py::FunctionAccuracyTest::testSuccessOnComplexPlane_tan_complex64 - AssertionError:
+# tests/nn_test.py::NNFunctionsTest::testDotProductAttentionMask7 FAILED   [ 10%]
+# FAILED tests/nn_test.py::NNFunctionsTest::testDotProductAttentionMask7 - AssertionError:
+#
+
+
+# Some tests require an isolated run:  TODO: still required?
+local_isolated_tests = [
+    'tests/host_callback_test.py::HostCallbackTapTest::test_tap_scan_custom_jvp',
+    'tests/host_callback_test.py::HostCallbackTapTest::test_tap_transforms_doc',
+    'tests/lax_scipy_special_functions_test.py::LaxScipySpcialFunctionsTest' +
+    '::testScipySpecialFun_gammainc_s_2x1x4_float32_float32',
+]
+# deliberately not testing in parallel, as that results in (additional) failing tests;
+# use XLA_PYTHON_CLIENT_ALLOCATOR=platform to allocate and deallocate GPU memory during testing,
+# see https://github.com/google/jax/issues/7323 and
+# https://github.com/google/jax/blob/main/docs/gpu_memory_allocation.rst;
+# use CUDA_VISIBLE_DEVICES=0 to avoid failing tests on systems with multiple GPUs;
+# use NVIDIA_TF32_OVERRIDE=0 to avoid loosing numerical precision by disabling TF32 Tensor Cores;
+local_test_exports = [
+    "NVIDIA_TF32_OVERRIDE=0",
+    "CUDA_VISIBLE_DEVICES=0",
+    "XLA_PYTHON_CLIENT_ALLOCATOR=platform",
+    "JAX_ENABLE_X64=true",
+]
+local_test = ''.join(['export %s;' % x for x in local_test_exports])
+# run all tests at once except for local_isolated_tests:
+local_test += "pytest -vv tests %s && " % ' '.join(['--deselect %s' % x for x in local_isolated_tests])
+# run remaining local_isolated_tests separately:
+local_test += ' && '.join(['pytest -vv %s' % x for x in local_isolated_tests])
+
 
 exts_list = [
     (name, version, {
         'source_tmpl': '%(name)s-v%(version)s.tar.gz',
         'source_urls': ['https://github.com/google/jax/archive/'],
-        'patches': ['jax-0.4.25_fix_env_test_no_log_spam.patch'],
+        'patches': ['jax-0.4.35_version.patch'],
         'checksums': [
-            {'jax-v0.4.25.tar.gz': '8b30af49688c0c13b82c6f5ce992727c00b5fc6d04a4c6962012f4246fa664eb'},
-            {'jax-0.4.25_fix_env_test_no_log_spam.patch':
-             'a18b5f147569d9ad41025124333a0f04fd0d0e0f9e4309658d7f6b9b838e2e2a'},
+            {'jax-v0.4.35.tar.gz': '65e086708ae56670676b7b2340ad82b901d8c9993d1241a839c8990bdb8d6212'},
+            {'jax-0.4.35_version.patch': 'cd2139a7802abf14b4b2cecee331aed80fff2ef91e16fa105093aea0795455e8'},
         ],
-        'runtest': local_test,
+        'runtest': False,  # tmp
+        'preinstallopts': _no_devtag
     }),
 ]
+sanity_check_commands = [
+    """python -c "import jax_cuda"$(echo $EBVERSIONCUDA|awk -F '.' '{print $1}')"_plugin" """
+]
+
 
 moduleclass = 'ai'
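
The way `local_test` is assembled above can be hard to read from the string formatting alone. The following is a minimal sketch of the same composition with shortened, hypothetical test IDs and a reduced export list (the names `build_test_command`, `isolated_tests` and `test_exports` are illustrative and not part of the easyconfig). Note that `'runtest'` is currently set to `False` in the diff, so the command is not used by this revision.

```python
# Hypothetical, shortened stand-ins for local_isolated_tests and local_test_exports.
isolated_tests = [
    'tests/a_test.py::A::test_one',
    'tests/b_test.py::B::test_two',
]
test_exports = ['CUDA_VISIBLE_DEVICES=0', 'JAX_ENABLE_X64=true']


def build_test_command(exports, isolated):
    """Compose the shell command in the same way local_test is built."""
    # export every environment variable first
    cmd = ''.join('export %s;' % x for x in exports)
    # run the whole suite once, deselecting the tests that need an isolated run
    cmd += 'pytest -vv tests %s && ' % ' '.join('--deselect %s' % x for x in isolated)
    # then run each isolated test in its own pytest process
    cmd += ' && '.join('pytest -vv %s' % x for x in isolated)
    return cmd


command = build_test_command(test_exports, isolated_tests)
print(command)
```

Because the parts are chained with `&&`, the isolated tests only run if the main `pytest` invocation succeeds.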
Diff against jax-0.3.25-foss-2022a.eb

easybuild/easyconfigs/j/jax/jax-0.3.25-foss-2022a.eb

diff --git a/easybuild/easyconfigs/j/jax/jax-0.3.25-foss-2022a.eb b/easybuild/easyconfigs/j/jax/jax-0.4.35-gfbf-2024a-CUDA-12.6.0.eb
index 27458b716a..5a431a6807 100644
--- a/easybuild/easyconfigs/j/jax/jax-0.3.25-foss-2022a.eb
+++ b/easybuild/easyconfigs/j/jax/jax-0.4.35-gfbf-2024a-CUDA-12.6.0.eb
@@ -1,109 +1,216 @@
 # This file is an EasyBuild reciPY as per https://github.com/easybuilders/easybuild
 # Author: Denis Kristak
 # Updated by: Alex Domingo (Vrije Universiteit Brussel)
+# Updated by: Pavel Tománek (INUITS)
+# Updated by: Thomas Hoffmann (EMBL Heidelberg)
 easyblock = 'PythonBundle'
 
 name = 'jax'
-version = '0.3.25'
+version = '0.4.35'
+versionsuffix = '-CUDA-%(cudaver)s'
 
-homepage = 'https://pypi.python.org/pypi/jax'
+homepage = 'https://jax.readthedocs.io/'
 description = """Composable transformations of Python+NumPy programs:
 differentiate, vectorize, JIT to GPU/TPU, and more"""
 
-toolchain = {'name': 'foss', 'version': '2022a'}
+toolchain = {'name': 'gfbf', 'version': '2024a'}
+cuda_compute_capabilities = ["5.0", "6.0", "6.1", "7.0", "7.5", "8.0", "8.6", "9.0"]
 
 builddependencies = [
-    ('Bazel', '5.1.1'),
-    ('pytest-xdist', '2.5.0'),
-    # git 2.x required to fetch repository 'io_bazel_rules_docker'
-    ('git', '2.36.0', '-nodocs'),
-    ('matplotlib', '3.5.2'),  # required for tests/lobpcg_test.py
+    # ('Bazel', '7.4.1'),  TODO: problems with @@local_config_python//:py3_runtime:
+    # Error in fail: interpreter_path must be an absolute path
+    # Bazel 6.5.0 (download) works.
+    ('pybind11', '2.13.6'),  # 2.12.0 ? SciPy-bundle has pybind/2.12.0.
+    ('pytest-xdist', '3.6.1'),
+    ('git', '2.45.1'),  # bazel uses git to fetch repositories
+    ('matplotlib', '3.9.2'),  # required for tests/lobpcg_test.py
+    ('poetry', '1.8.3'),
+    ('Clang', '18.1.8', versionsuffix)
 ]
 
 dependencies = [
-    ('Python', '3.10.4'),
-    ('SciPy-bundle', '2022.05'),
-    ('flatbuffers-python', '2.0'),
+    ('CUDA', '12.6.0', '', SYSTEM),  # 12.6.2 ?
+    ('cuDNN', '9.5.0.50', versionsuffix, SYSTEM),
+    ('NCCL', '2.22.3', versionsuffix),
+    ('Python', '3.12.3'),
+    ('SciPy-bundle', '2024.05'),  # 2024.11 ?
+    ('absl-py', '2.1.0'),
+    ('flatbuffers-python', '24.3.25'),
+    ('ml_dtypes', '0.5.0'),
+    ('zlib', '1.3.1'),
 ]
 
-# downloading TensorFlow tarball to avoid that Bazel downloads it during the build
-# note: this *must* be the exact same commit as used in WORKSPACE
-local_tf_commit = 'f0fe8d4c04fab1f157854a1aa3c136377901cdef'
-local_tf_dir = 'tensorflow-%s' % local_tf_commit
-local_tf_builddir = '%(builddir)s/' + local_tf_dir
+# downloading xla and other tarballs to avoid that Bazel downloads them during the build
+local_extract_cmd = 'mkdir -p %(builddir)s/archives && cp %s %(builddir)s/archives'
+# note: following commits *must* be the exact same ones used upstream
+# XLA_COMMIT from jax-jaxlib: third_party/xla/workspace.bzl
+local_xla_commit = '76da730179313b3bebad6dea6861768421b7358c'
+# TFRT_COMMIT from xla: third_party/tsl/third_party/tf_runtime/workspace.bzl
+local_tfrt_commit = '0aeefb1660d7e37964b2bb71b1f518096bda9a25'  # TODO: still required?
+# TODO: add other downloads
 
-# replace remote TensorFlow repository with the local one from EB
-local_jax_prebuildopts = "sed -i -f jaxlib_local-tensorflow-repo.sed WORKSPACE && "
-local_jax_prebuildopts += "sed -i 's|EB_TF_REPOPATH|%s|' WORKSPACE && " % local_tf_builddir
+# Use sources downloaded by EasyBuild
+_jaxlib_buildopts = '--bazel_options="--distdir=%(builddir)s/archives" '
+# Use dependencies from EasyBuild
+_jaxlib_buildopts += '--bazel_options="--action_env=TF_SYSTEM_LIBS=pybind11" '
+_jaxlib_buildopts += '--bazel_options="--action_env=CPATH=$EBROOTPYBIND11/include:$EBROOTCUDA/extras/CUPTI/include" '
+# Avoid warning (treated as error) in upb/table.c
+_jaxlib_buildopts += '--bazel_options="--copt=-Wno-maybe-uninitialized" '  # TODO: still required?
+# _jaxlib_buildopts += '--nouse_clang '  #TODO: avoid clang (?)
+_jaxlib_buildopts += '--cuda_version=%(cudaver)s '
+_jaxlib_buildopts += '--python_bin_path=$EBROOTPYTHON/bin/python3 '
+# Do not use hermetic CUDA/cuDNN/NCCL; this requires action_env=CPATH=$EBROOTCUDA/extras/CUPTI/include
+# and a patch of external/xla/xla/tsl/cuda/cupti_stub.cc and jaxlib/gpu/vendor.h (#include <cupti.h>):
+_jaxlib_buildopts += """--bazel_options=--repo_env=LOCAL_CUDNN_PATH="$EBROOTCUDNN" """
+_jaxlib_buildopts += """--bazel_options=--repo_env=LOCAL_NCCL_PATH="$EBROOTNCCL" """
+_jaxlib_buildopts += """--bazel_options=--repo_env=LOCAL_CUDA_PATH="$EBROOTCUDA" """
+_jaxlib_buildopts += """--bazel_options="--copt=-Ithird_party/gpus/cuda/extras/CUPTI/include" """
 
-default_easyblock = 'PythonPackage'
-default_component_specs = {
-    'sources': [SOURCE_TAR_GZ],
-    'source_urls': [PYPI_SOURCE],
-    'start_dir': '%(name)s-%(version)s',
-}
+_plugins_buildopts = """--enable_cuda """
+_plugins_buildopts += """--build_gpu_plugin """
+# _plugins_buildopts +="""--gpu_plugin_cuda_version=12 """
+_plugins_buildopts += """--build_gpu_pjrt_plugin """
+_plugins_buildopts += """--build_gpu_kernel_plugin=cuda """
+
+# get rid of .devDate versionsuffix:  TODO: find a better way
+# _no_devtag = """ export JAX_RELEASE && export JAXLIB_RELEASE && """  does not work (?)
+_no_devtag = """ sed -i "s/version=__version__/version='%(version)s'/g" setup.py && """
+_jaxlib_buildopts += """--bazel_options="--action_env=JAXLIB_RELEASE=1" """  # required?
 
 components = [
-    ('absl-py', '1.3.0', {
-        'options': {'modulename': 'absl'},
-        'checksums': ['463c38a08d2e4cef6c498b76ba5bd4858e4c6ef51da1a5a1f27139a022e20248'],
-    }),
     ('jaxlib', version, {
         'sources': [
-            '%(name)s-v%(version)s.tar.gz',
             {
-                'download_filename': '%s.tar.gz' % local_tf_commit,
-                'filename': 'tensorflow-%s.tar.gz' % local_tf_commit,
-            }
-        ],
-        'source_urls': [
-            'https://github.com/google/jax/archive/',
-            'https://github.com/tensorflow/tensorflow/archive/'
+                'source_urls': ['https://github.com/google/jax/archive/'],
+                'filename': 'jax-v%(version)s.tar.gz',
+            },
+            {
+                'source_urls': ['https://github.com/openxla/xla/archive'],
+                'download_filename': '%s.tar.gz' % local_xla_commit,
+                'filename': 'xla-%s.tar.gz' % local_xla_commit[:8],
+                'extract_cmd': local_extract_cmd,
+            },
+            {
+                'source_urls': ['https://github.com/tensorflow/runtime/archive'],
+                'download_filename': '%s.tar.gz' % local_tfrt_commit,
+                'filename': 'tf_runtime-%s.tar.gz' % local_tfrt_commit[:8],
+                'extract_cmd': local_extract_cmd,
+            },
         ],
         'patches': [
-            ('jaxlib_local-tensorflow-repo.sed', '.'),
-            ('TensorFlow-2.7.0_cuda-noncanonical-include-paths.patch', '../' + local_tf_dir),
+            'jax-0.4.35_easyblock_compat.patch',
+            'jax-0.4.35_fix-pybind11-systemlib_cupti_CUDA_HOME.patch',
+            'jax-0.4.35_version.patch',
         ],
         'checksums': [
-            # jaxlib-v0.3.25.tar.gz
-            '73ebc7868631cd9d520385557bbd7f08762d748a5a6a1bebef0f3b8d7ba748ef',
-            # tensorflow-f0fe8d4c04fab1f157854a1aa3c136377901cdef.tar.gz
-            '9ebba3031e8a81993682e4b9e43891ebb8480b6287e635df8e7efaa45ab5ede7',
-            # jaxlib_local-tensorflow-repo.sed
-            'abb5c3b97f4e317bce9f22ed3eeea3b9715365818d8b50720d937e2d41d5c4e5',
-            # TensorFlow-2.7.0_cuda-noncanonical-include-paths.patch
-            '0a759010c253d49755955cd5f028e75de4a4c447dcc8f5a0d9f47cce6881a9db',
+            {'jax-v0.4.35.tar.gz':
+             '65e086708ae56670676b7b2340ad82b901d8c9993d1241a839c8990bdb8d6212'},
+            {'xla-76da7301.tar.gz':
+             'd67ced09b69ab8d7b26fa4cd5f48b22db57eb330294a35f6e1d462ee17066757'},
+            {'tf_runtime-0aeefb16.tar.gz':
+             'a3df827d7896774cb1d80bf4e1c79ab05c268f29bd4d3db1fb5a4b9c2079d8e3'},
+            {'jax-0.4.35_easyblock_compat.patch':
+             'cbf4ad92b8438c4ce2a975efce1c47c57d4c3b117bceee071ab660f964057223'},
+            {'jax-0.4.35_fix-pybind11-systemlib_cupti_CUDA_HOME.patch':
+             'fa5273d31651579590f7291fc151836f43024f74f4c89243dc4c6a417284e7ce'},
+            {'jax-0.4.35_version.patch':
+             'cd2139a7802abf14b4b2cecee331aed80fff2ef91e16fa105093aea0795455e8'},
         ],
-        'start_dir': 'jax-jaxlib-v%(version)s',
-        'prebuildopts': local_jax_prebuildopts,
-        # Avoid warning (treated as error) in upb/table.c
-        'buildopts': '--bazel_options="--copt=-Wno-maybe-uninitialized"',
+        'start_dir': 'jax-jax-v%(version)s',
+        'buildopts': _jaxlib_buildopts,
+        'prebuildopts': ' mkdir third_party/gpus/cuda/extras/ -p && ' +
+                        'ln -s $EBROOTCUDA/extras/CUPTI third_party/gpus/cuda/extras --relative &&' +
+                        _no_devtag
     }),
+    # build jaxlib first and then plugins in 2nd iteration:
+    ('jaxlib', version, {
+        'sources': [
+            {
+                'source_urls': ['https://github.com/google/jax/archive/'],
+                'filename': 'jax-v%(version)s.tar.gz',
+            },
+            {
+                'source_urls': ['https://github.com/openxla/xla/archive'],
+                'download_filename': '%s.tar.gz' % local_xla_commit,
+                'filename': 'xla-%s.tar.gz' % local_xla_commit[:8],
+                'extract_cmd': local_extract_cmd,
+            },
+            {
+                'source_urls': ['https://github.com/tensorflow/runtime/archive'],
+                'download_filename': '%s.tar.gz' % local_tfrt_commit,
+                'filename': 'tf_runtime-%s.tar.gz' % local_tfrt_commit[:8],
+                'extract_cmd': local_extract_cmd,
+            },
+        ],
+        'checksums': [
+            {'jax-v0.4.35.tar.gz':
+             '65e086708ae56670676b7b2340ad82b901d8c9993d1241a839c8990bdb8d6212'},
+            {'xla-76da7301.tar.gz':
+             'd67ced09b69ab8d7b26fa4cd5f48b22db57eb330294a35f6e1d462ee17066757'},
+            {'tf_runtime-0aeefb16.tar.gz':
+             'a3df827d7896774cb1d80bf4e1c79ab05c268f29bd4d3db1fb5a4b9c2079d8e3'},
+        ],
+        'start_dir': 'jax-jax-v%(version)s',
+        'buildopts': _jaxlib_buildopts + _plugins_buildopts,
+        'prebuildopts': _no_devtag
+    }),
+]
+# failing:
+# tests/lax_test.py::FunctionAccuracyTest::testSuccessOnComplexPlane_expm1_complex128 FAILED [ 98%]
+# tests/lax_test.py::FunctionAccuracyTest::testSuccessOnComplexPlane_expm1_complex64 FAILED [ 98%]
+# tests/lax_test.py::FunctionAccuracyTest::testSuccessOnComplexPlane_tan_complex128 FAILED [ 99%]
+# tests/lax_test.py::FunctionAccuracyTest::testSuccessOnComplexPlane_tan_complex64 FAILED [ 99%]
+# FAILED tests/lax_test.py::FunctionAccuracyTest::testSuccessOnComplexPlane_expm1_complex128 - AssertionError:
+# FAILED tests/lax_test.py::FunctionAccuracyTest::testSuccessOnComplexPlane_expm1_complex64 - AssertionError:
+# FAILED tests/lax_test.py::FunctionAccuracyTest::testSuccessOnComplexPlane_tan_complex128 - AssertionError:
+# FAILED tests/lax_test.py::FunctionAccuracyTest::testSuccessOnComplexPlane_tan_complex64 - AssertionError:
+# tests/nn_test.py::NNFunctionsTest::testDotProductAttentionMask7 FAILED   [ 10%]
+# FAILED tests/nn_test.py::NNFunctionsTest::testDotProductAttentionMask7 - AssertionError:
+#
+
+
+# Some tests require an isolated run:  TODO: still required?
+local_isolated_tests = [
+    'tests/host_callback_test.py::HostCallbackTapTest::test_tap_scan_custom_jvp',
+    'tests/host_callback_test.py::HostCallbackTapTest::test_tap_transforms_doc',
+    'tests/lax_scipy_special_functions_test.py::LaxScipySpcialFunctionsTest' +
+    '::testScipySpecialFun_gammainc_s_2x1x4_float32_float32',
+]
+# deliberately not testing in parallel, as that results in (additional) failing tests;
+# use XLA_PYTHON_CLIENT_ALLOCATOR=platform to allocate and deallocate GPU memory during testing,
+# see https://github.com/google/jax/issues/7323 and
+# https://github.com/google/jax/blob/main/docs/gpu_memory_allocation.rst;
+# use CUDA_VISIBLE_DEVICES=0 to avoid failing tests on systems with multiple GPUs;
+# use NVIDIA_TF32_OVERRIDE=0 to avoid losing numerical precision by disabling TF32 Tensor Cores;
+local_test_exports = [
+    "NVIDIA_TF32_OVERRIDE=0",
+    "CUDA_VISIBLE_DEVICES=0",
+    "XLA_PYTHON_CLIENT_ALLOCATOR=platform",
+    "JAX_ENABLE_X64=true",
 ]
+local_test = ''.join(['export %s;' % x for x in local_test_exports])
+# run all tests at once except for local_isolated_tests:
+local_test += "pytest -vv tests %s && " % ' '.join(['--deselect %s' % x for x in local_isolated_tests])
+# run remaining local_isolated_tests separately:
+local_test += ' && '.join(['pytest -vv %s' % x for x in local_isolated_tests])
+
 
 exts_list = [
-    ('opt_einsum', '3.3.0', {
-        'checksums': ['59f6475f77bbc37dcf7cd748519c0ec60722e91e63ca114e68821c0c54a46549'],
-    }),
-    ('etils', '0.8.0', {
-        'checksums': ['d1d5af7bd9c784a273c4e1eccfaa8feaca5e0481a08717b5313fa231da22a903'],
-    }),
     (name, version, {
         'source_tmpl': '%(name)s-v%(version)s.tar.gz',
         'source_urls': ['https://github.com/google/jax/archive/'],
-        'patches': [
-            'jax-0.3.23_relax-testPoly5-tolerance.patch',
-            'jax-0.3.25_skip-qdwh-test-rank-deficient-deficient.patch',
-        ],
+        'patches': ['jax-0.4.35_version.patch'],
         'checksums': [
-            {'jax-v0.3.25.tar.gz': '49e8ce88ddd7dd0de86116c9d75d98a577a9061377ec423493fbac5ea29f79f0'},
-            {'jax-0.3.23_relax-testPoly5-tolerance.patch':
-             'be64bf36dde4884a97b6c8bb22c6b14ab5b24033cd40bfe7ce18363c55c30e87'},
-            {'jax-0.3.25_skip-qdwh-test-rank-deficient-deficient.patch':
-             '70f16f2dba03ab162ce6e13ea61774524b485e9630209bbd4bec81fd16c8812f'},
+            {'jax-v0.4.35.tar.gz': '65e086708ae56670676b7b2340ad82b901d8c9993d1241a839c8990bdb8d6212'},
+            {'jax-0.4.35_version.patch': 'cd2139a7802abf14b4b2cecee331aed80fff2ef91e16fa105093aea0795455e8'},
         ],
-        'runtest': "pytest -n %(parallel)s tests",
+        'runtest': False,  # tmp
+        'preinstallopts': _no_devtag
     }),
 ]
+sanity_check_commands = [
+    """python -c "import jax_cuda"$(echo $EBVERSIONCUDA|awk -F '.' '{print $1}')"_plugin" """
+]
+
 
-moduleclass = 'tools'
+moduleclass = 'ai'

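The `sanity_check_commands` entry in the diff builds the plugin module name in the shell from the major part of `$EBVERSIONCUDA`. A rough Python equivalent, useful for checking what that one-liner resolves to, might look like the sketch below. The helper names `cuda_plugin_module` and `plugin_available` are hypothetical; only `importlib.util.find_spec` is a real standard-library call.

```python
import importlib.util


def cuda_plugin_module(cuda_version):
    """Return the plugin module name for a CUDA version string such as '12.6.0'."""
    major = cuda_version.split('.')[0]
    return 'jax_cuda%s_plugin' % major


def plugin_available(cuda_version):
    """Check whether the plugin module can be found, without importing it."""
    return importlib.util.find_spec(cuda_plugin_module(cuda_version)) is not None


module_name = cuda_plugin_module('12.6.0')
print(module_name)
print(plugin_available('12.6.0'))
```

With CUDA 12.6.0 the sanity check therefore amounts to `import jax_cuda12_plugin`.
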
@ThomasHoffmann77 ThomasHoffmann77 changed the title {ai}[gfbf/2024a] jax v0.4.35 w/ CUDA 12.6.0 {ai}[gfbf/2024a] jax v0.4.35 w/ CUDA 12.6.0 WIP Nov 28, 2024
@ThomasHoffmann77 ThomasHoffmann77 marked this pull request as draft November 28, 2024 13:43
@ThomasHoffmann77 ThomasHoffmann77 changed the title {ai}[gfbf/2024a] jax v0.4.35 w/ CUDA 12.6.0 WIP {ai}[gfbf/2024a] jax v0.4.35, ml_dtypes v0.5.0 w/ CUDA 12.6.0 WIP Nov 28, 2024
@github-actions github-actions bot added the change label Dec 2, 2024
@ThomasHoffmann77 ThomasHoffmann77 marked this pull request as draft January 24, 2025 13:59
@fizwit
Contributor

fizwit commented Feb 12, 2025

where is the easyconfig for Clang-18.1.8-gfbf-2024a-CUDA-12.6.0.eb

@ThomasHoffmann77
Contributor Author

ThomasHoffmann77 commented Feb 13, 2025

where is the easyconfig for Clang-18.1.8-gfbf-2024a-CUDA-12.6.0.eb

]

dependencies = [
('Java', '11.0.20', '', SYSTEM),
Member

Can we use Java/21 here?

Contributor Author

@boegel The build of Bazel 6.5.0 fails with both Java/21 and Java/21.0.2.

cuda_compute_capabilities = ["5.0", "6.0", "6.1", "7.0", "7.5", "8.0", "8.6", "9.0"]

builddependencies = [
# ('Bazel', '7.4.1'), TODO: problems with @@local_config_python//:py3_runtime:
Member

Hmm, it's unfortunate we can't use Bazel 7.4.1...

Do we fully understand what's going on here, is it a fundamental incompatibility?

Contributor Author

@boegel: I am not a Bazel expert, but I think it is a fundamental incompatibility, as the latest jax 0.5.2 still uses Bazel v6.5.0: https://github.com/jax-ml/jax/blob/ce224293b1a7d9b39b5d9194d429b54f38faf6fe/.bazelversion#L1

Contributor Author

@boegel JAX 0.6.0 uses Bazel 7.4.1. Since AF3 is not merged yet, it might be worth dropping jax 0.4.34 and using 0.6.0 instead.

Collaborator

@pavelToman pavelToman Jul 29, 2025

I have a PR for the v0.6.2 CPU version: #23526
Or v0.7.0 CPU: #23530

@VRehnberg
Contributor

VRehnberg commented Mar 14, 2025

I got the following error (when running AlphaFold3 test-suite):

jaxlib.xla_extension.XlaRuntimeError: INTERNAL: libdevice not found at ./libdevice.10.bc

libdevice.10.bc is located at $EBROOTCUDA/nvvm/libdevice/libdevice.10.bc. Relevant part of the changelog:
https://github.com/jax-ml/jax/blob/main/CHANGELOG.md#jax-0431-july-29-2024

I suspect this function: https://github.com/jax-ml/jax/blob/jax-v0.4.34/jax/_src/lib/__init__.py#L130-L138. ptxas is found at $EBROOTCUDA/bin/ptxas.
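
To confirm that a given CUDA installation has both files in the locations mentioned here, a small check along these lines could be run with the CUDA module loaded. This is only a sketch: the helper name `missing_xla_cuda_files` is hypothetical, and the two relative paths are the ones named in this comment.

```python
import os
from pathlib import Path


def missing_xla_cuda_files(cuda_root):
    """Return the expected files (relative paths) that are absent under cuda_root."""
    expected = [
        Path('bin') / 'ptxas',
        Path('nvvm') / 'libdevice' / 'libdevice.10.bc',
    ]
    root = Path(cuda_root)
    return [str(rel) for rel in expected if not (root / rel).is_file()]


cuda_root = os.environ.get('EBROOTCUDA')
if cuda_root is None:
    print('EBROOTCUDA is not set; load the CUDA module first')
else:
    print(missing_xla_cuda_files(cuda_root) or 'ptxas and libdevice.10.bc found')
```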

@VRehnberg
Contributor

This patch got rid of the libdevice error for me (for some reason I couldn't find Thomas' repo listed to open a PR against (commit)):

Find cuda path with EBROOTCUDA. See
https://github.com/easybuilders/easybuild-easyconfigs/pull/21924#issuecomment-2723932545

Author: Viktor Rehnberg (Chalmers University of Technology)

--- /apps/Test/software/jax/0.4.34-gfbf-2024a-CUDA-12.6.0/lib/python3.12/site-packages/jax/_src/lib/__init__.py 2025-02-20 10:12:23.704399435 +0000
+++ -   2025-03-14 08:22:49.679847748 +0000
@@ -18,6 +18,7 @@
 from __future__ import annotations

 import gc
+import os
 import pathlib
 import re
 from typing import Any
@@ -132,7 +133,7 @@
   # If the pip package nvidia-cuda-nvcc-cu11 is installed, it should have
   # both of the things XLA looks for in the cuda path, namely bin/ptxas and
   # nvvm/libdevice/libdevice.10.bc
-  path = _jaxlib_path.parent / "nvidia" / "cuda_nvcc"
+  path = pathlib.Path(os.getenv("EBROOTCUDA", _jaxlib_path.parent / "nvidia" / "cuda_nvcc"))
   if path.is_dir():
     return str(path)
   return None
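
The behaviour of the patched lookup can be tried outside of jax with a standalone sketch like the one below. It is not the jax function itself: `find_cuda_path`, its `pip_fallback` argument and the injectable `environ` mapping are hypothetical names used for illustration, and the fallback path in the example call is made up.

```python
import os
import pathlib


def find_cuda_path(pip_fallback, environ=None):
    """Mimic the patched logic: prefer $EBROOTCUDA, else the pip-style fallback."""
    environ = os.environ if environ is None else environ
    # the fallback is only used when EBROOTCUDA is absent from the environment
    path = pathlib.Path(environ.get('EBROOTCUDA', pip_fallback))
    if path.is_dir():
        return str(path)
    return None


# hypothetical fallback location, standing in for <site-packages>/nvidia/cuda_nvcc
print(find_cuda_path('/nonexistent/site-packages/nvidia/cuda_nvcc'))
```

One consequence of this form is that if `EBROOTCUDA` is set but does not point to a directory, the lookup returns `None` and the fallback is not tried.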

@ThomasHoffmann77
Contributor Author

@VRehnberg Is it sufficient to patch only the Python code? I have another patch that I have not uploaded yet. It modifies third_party/tsl/tsl/platform/default/cuda_root_path.cc so that either CUDA_HOME or XLA_FLAGS=--xla_gpu_cuda_data_dir can be set.
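
For reference, the `XLA_FLAGS` route can also be set up from Python before jax is imported. The sketch below assumes `CUDA_HOME` points at the CUDA installation; the helper `with_cuda_data_dir` is a hypothetical name, while `--xla_gpu_cuda_data_dir` is the XLA flag from the PR description.

```python
import os


def with_cuda_data_dir(xla_flags, cuda_home):
    """Append --xla_gpu_cuda_data_dir to an XLA_FLAGS string unless already present."""
    flag = '--xla_gpu_cuda_data_dir'
    parts = xla_flags.split()
    if any(p.startswith(flag + '=') for p in parts):
        # respect a value the user has already chosen
        return xla_flags
    parts.append('%s=%s' % (flag, cuda_home))
    return ' '.join(parts)


cuda_home = os.environ.get('CUDA_HOME')
if cuda_home:
    # set this before jax initialises its backend; doing it before `import jax` is the safe choice
    os.environ['XLA_FLAGS'] = with_cuda_data_dir(os.environ.get('XLA_FLAGS', ''), cuda_home)
```

In an easyconfig the same effect could presumably be achieved with a module environment variable, but that is not something this PR does in the diff shown here.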

@VRehnberg
Contributor

is it sufficient to patch the python code only? I have some other patch, which I did not upload yet. It modifies third_party/tsl/tsl/platform/default/cuda_root_path.cc allowing to either set CUDA_HOME or XLA_FLAGS=xla_gpu_cuda_data_dir.

Sorry, I don't understand what you're asking. The root issue is that path = _jaxlib_path.parent / "nvidia" / "cuda_nvcc" is not a path that exists as is. Either fix that path or use something other than that function to find the CUDA files. Those are the two classes of solutions I can imagine. If you've found something else that works, go for it.

@ThomasHoffmann77 ThomasHoffmann77 changed the title {ai}[gfbf/2024a] jax v0.4.34, ml_dtypes v0.5.0 w/ CUDA 12.6.0 WIP {ai}[gfbf/2024a] jax v0.4.35, ml_dtypes v0.5.0 w/ CUDA 12.6.0 WIP May 20, 2025
@ThomasHoffmann77
Contributor Author

is it sufficient to patch the python code only? I have some other patch, which I did not upload yet. It modifies third_party/tsl/tsl/platform/default/cuda_root_path.cc allowing to either set CUDA_HOME or XLA_FLAGS=xla_gpu_cuda_data_dir.

Sorry, I don't understand what you're asking. Root issue is that path = _jaxlib_path.parent / "nvidia" / "cuda_nvcc" is not a path that exists as is. Either fix that path or use something else than that function to find CUDA stuff. Those are the two classes of solutions I can imagine. If you've found something else that works, go for it.

@VRehnberg I switched to 0.4.35 and added a patch to find libdevice.10.bc relative to $CUDA_HOME.

@boegel
Member

boegel commented Jul 4, 2025

@ThomasHoffmann77 Any updates on this?

@ThomasHoffmann77
Contributor Author

@ThomasHoffmann77 Any updates on this?

@boegel We have this jax 0.4.35 running with AF3 at EMBL. I have not yet worked further on an update to jax 0.6.0. The current PR still downloads lots of Bazel packages at build time. Some more critical review and testing would be beneficial.

@pavelToman
Collaborator

pavelToman commented Jul 11, 2025

To me it looks like the same problem as with TensorFlow, CUPTI and XLA: easybuilders/easybuild-easyblocks#3765
What do you think, @Flamefire?

@Flamefire
Contributor

Flamefire commented Jul 20, 2025

@Thyre Thyre added the 2024a issues & PRs related to 2024a common toolchains label Aug 18, 2025
@pavelToman
Collaborator

I just created a PR for jax-0.6.2 with CUDA-12.6.0:

Labels

2024a issues & PRs related to 2024a common toolchains update
