diff --git a/easybuild/easyconfigs/j/jax/jax-0.6.2-gfbf-2024a.eb b/easybuild/easyconfigs/j/jax/jax-0.6.2-gfbf-2024a.eb new file mode 100644 index 000000000000..7302e39a56ee --- /dev/null +++ b/easybuild/easyconfigs/j/jax/jax-0.6.2-gfbf-2024a.eb @@ -0,0 +1,95 @@ +# This file is an EasyBuild reciPY as per https://github.com/easybuilders/easybuild +# Author: Denis Kristak +# Updated by: Alex Domingo (Vrije Universiteit Brussel) +# Updated by: Thomas Hoffmann (EMBL Heidelberg) +# Updated by: Pavel Tománek (INUITS) +easyblock = 'PythonBundle' + +name = 'jax' +version = '0.6.2' + +homepage = 'https://jax.readthedocs.io/' +description = """Composable transformations of Python+NumPy programs: +differentiate, vectorize, JIT to GPU/TPU, and more""" + +toolchain = {'name': 'gfbf', 'version': '2024a'} + +builddependencies = [ + ('Bazel', '7.4.1', '-Java-21'), +] + +dependencies = [ + ('Python', '3.12.3'), + ('SciPy-bundle', '2024.05'), + ('absl-py', '2.1.0'), + ('flatbuffers-python', '24.3.25'), + ('ml_dtypes', '0.5.0'), + ('hypothesis', '6.103.1'), + ('zlib', '1.3.1'), +] + +# downloading xla tarball to avoid that Bazel downloads it during the build +# note: following commits *must* be the exact same onces used upstream +# XLA_COMMIT from jax-jaxlib: third_party/xla/workspace.bzl +_xla_commit = '3d5ece64321630dade7ff733ae1353fc3c83d9cc' + +# Use sources downloaded by EasyBuild +_jaxlib_buildopts = '--bazel_options="--distdir=%(builddir)s/archives" ' +# create wheels for jaxlib and install it +_jaxlib_buildopts += '--wheels=jaxlib ' +# use GCC instead of default Clang +_jaxlib_buildopts += '--use_clang=false --gcc_path=$CC ' +# fix jaxlib version +_jaxlib_buildopts += '--bazel_options="--action_env=JAXLIB_RELEASE" ' + +components = [ + ('jaxlib', version, { + 'sources': [ + { + 'source_urls': ['https://github.com/google/jax/archive/'], + 'filename': 'jax-v%(version)s.tar.gz', + }, + { + 'source_urls': ['https://github.com/openxla/xla/archive'], + 'download_filename': f'{_xla_commit}.tar.gz', + 'filename': f'xla-{_xla_commit[:8]}.tar.gz', + 'extract_cmd': 'mkdir -p %(builddir)s/archives && cp %s %(builddir)s/archives', + }, + ], + 'checksums': [ + {'jax-v0.6.2.tar.gz': + 'd46cb98795f2c1ccdf2b081e02d9d74b659063679a80beb001ad17d482a60e17'}, + {'xla-3d5ece64.tar.gz': + 'fbd20cf83bad78f66977fa7ff67a12e52964abae0b107ddd5486a0355643ec8a'}, + ], + 'start_dir': f'jax-jax-v{version}', + # fix jaxlib version - removes .dev suffix + 'prebuildopts': 'export JAXLIB_RELEASE=%(version)s && ', + 'buildopts': _jaxlib_buildopts, + }), +] + +# BackendsTest is failing on systems with GPU detected when cpu version is built +_failing_backend_test = 'tests/api_test.py::BackendsTest::test_no_backend_warning_on_cpu_if_platform_specified' +# Failing after fix jaxlib version by JAXLIB_RELEASE +_failing_version_tests = 'tests/version_test.py' +_runtest_cmd1 = f"pytest -n %(parallel)s tests --deselect={_failing_backend_test} --deselect={_failing_version_tests}" +# retry randomly failing tests +_runtest_cmd = ' '.join([_runtest_cmd1, '||', _runtest_cmd1, '--last-failed']) + +exts_list = [ + (name, version, { + "patches": ["jax-0.6.2_jax-version-fix.patch"], + "testinstall": True, + "runtest": _runtest_cmd, + "sources": [ + {"source_urls": ["https://github.com/google/jax/archive/"], "filename": "%(name)s-v%(version)s.tar.gz"} + ], + "checksums": [ + {"jax-v0.6.2.tar.gz": "d46cb98795f2c1ccdf2b081e02d9d74b659063679a80beb001ad17d482a60e17"}, + {"jax-0.6.2_jax-version-fix.patch": "e15615fd9f4e1698f7c5fd384f146d7b2dbfde3d4657b69bd2d044d75c9fb1d4"}, + ], + }), +] + +moduleclass = 'ai' diff --git a/easybuild/easyconfigs/j/jax/jax-0.6.2_jax-version-fix.patch b/easybuild/easyconfigs/j/jax/jax-0.6.2_jax-version-fix.patch new file mode 100644 index 000000000000..4dcd40824a76 --- /dev/null +++ b/easybuild/easyconfigs/j/jax/jax-0.6.2_jax-version-fix.patch @@ -0,0 +1,20 @@ +Fix jax version to not show .dev +diff -ru jax-jax-v0.6.2/jax/version.py jax-jax-v0.4.35_version/jax/version.py +--- jax-jax-v0.6.2/jax/version.py.orig 2024-10-22 21:00:23.000000000 +0200 ++++ jax-jax-v0.6.2/jax/version.py 2024-11-28 13:10:52.508536023 +0100 +@@ -33,6 +33,7 @@ + def _get_version_string() -> str: + # The build/source distribution for jax & jaxlib overwrites _release_version. + # In this case we return it directly. ++ return _version + if _release_version is not None: + return _release_version + if os.getenv("WHEEL_VERSION_SUFFIX"): +@@ -82,6 +83,7 @@ + - if JAX_NIGHTLY or JAXLIB_NIGHTLY are set: version looks like "0.4.16.dev20230906" + - if none are set: version looks like "0.4.16.dev20230906+ge58560fdc + """ ++ return _version + if _release_version is not None: + return _release_version + if os.environ.get('JAX_NIGHTLY') or os.environ.get('JAXLIB_NIGHTLY'): \ No newline at end of file