Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 95 additions & 0 deletions easybuild/easyconfigs/j/jax/jax-0.6.2-gfbf-2024a.eb
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
# This file is an EasyBuild reciPY as per https://github.com/easybuilders/easybuild
# Author: Denis Kristak
# Updated by: Alex Domingo (Vrije Universiteit Brussel)
# Updated by: Thomas Hoffmann (EMBL Heidelberg)
# Updated by: Pavel Tománek (INUITS)
easyblock = 'PythonBundle'

name = 'jax'
version = '0.6.2'

homepage = 'https://jax.readthedocs.io/'
description = """Composable transformations of Python+NumPy programs:
differentiate, vectorize, JIT to GPU/TPU, and more"""

toolchain = {'name': 'gfbf', 'version': '2024a'}

builddependencies = [
('Bazel', '7.4.1', '-Java-21'),
]

dependencies = [
('Python', '3.12.3'),
('SciPy-bundle', '2024.05'),
('absl-py', '2.1.0'),
('flatbuffers-python', '24.3.25'),
('ml_dtypes', '0.5.0'),
('hypothesis', '6.103.1'),
('zlib', '1.3.1'),
]

# downloading xla tarball to avoid that Bazel downloads it during the build
# note: following commits *must* be the exact same onces used upstream
# XLA_COMMIT from jax-jaxlib: third_party/xla/workspace.bzl
_xla_commit = '3d5ece64321630dade7ff733ae1353fc3c83d9cc'

# Use sources downloaded by EasyBuild
_jaxlib_buildopts = '--bazel_options="--distdir=%(builddir)s/archives" '
# create wheels for jaxlib and install it
_jaxlib_buildopts += '--wheels=jaxlib '
# use GCC instead of default Clang
_jaxlib_buildopts += '--use_clang=false --gcc_path=$CC '
# fix jaxlib version
_jaxlib_buildopts += '--bazel_options="--action_env=JAXLIB_RELEASE" '

components = [
('jaxlib', version, {
'sources': [
{
'source_urls': ['https://github.com/google/jax/archive/'],
'filename': 'jax-v%(version)s.tar.gz',
},
{
'source_urls': ['https://github.com/openxla/xla/archive'],
'download_filename': f'{_xla_commit}.tar.gz',
'filename': f'xla-{_xla_commit[:8]}.tar.gz',
'extract_cmd': 'mkdir -p %(builddir)s/archives && cp %s %(builddir)s/archives',
},
],
'checksums': [
{'jax-v0.6.2.tar.gz':
'd46cb98795f2c1ccdf2b081e02d9d74b659063679a80beb001ad17d482a60e17'},
{'xla-3d5ece64.tar.gz':
'fbd20cf83bad78f66977fa7ff67a12e52964abae0b107ddd5486a0355643ec8a'},
],
'start_dir': f'jax-jax-v{version}',
# fix jaxlib version - removes .dev suffix
'prebuildopts': 'export JAXLIB_RELEASE=%(version)s && ',
'buildopts': _jaxlib_buildopts,
}),
]

# BackendsTest is failing on systems with GPU detected when cpu version is built
_failing_backend_test = 'tests/api_test.py::BackendsTest::test_no_backend_warning_on_cpu_if_platform_specified'
# Failing after fix jaxlib version by JAXLIB_RELEASE
_failing_version_tests = 'tests/version_test.py'
_runtest_cmd1 = f"pytest -n %(parallel)s tests --deselect={_failing_backend_test} --deselect={_failing_version_tests}"
# retry randomly failing tests
_runtest_cmd = ' '.join([_runtest_cmd1, '||', _runtest_cmd1, '--last-failed'])

exts_list = [
(name, version, {
"patches": ["jax-0.6.2_jax-version-fix.patch"],
"testinstall": True,
"runtest": _runtest_cmd,
"sources": [
{"source_urls": ["https://github.com/google/jax/archive/"], "filename": "%(name)s-v%(version)s.tar.gz"}
],
"checksums": [
{"jax-v0.6.2.tar.gz": "d46cb98795f2c1ccdf2b081e02d9d74b659063679a80beb001ad17d482a60e17"},
{"jax-0.6.2_jax-version-fix.patch": "e15615fd9f4e1698f7c5fd384f146d7b2dbfde3d4657b69bd2d044d75c9fb1d4"},
],
}),
]

moduleclass = 'ai'
20 changes: 20 additions & 0 deletions easybuild/easyconfigs/j/jax/jax-0.6.2_jax-version-fix.patch
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
Fix jax version to not show .dev
diff -ru jax-jax-v0.6.2/jax/version.py jax-jax-v0.4.35_version/jax/version.py
--- jax-jax-v0.6.2/jax/version.py.orig 2024-10-22 21:00:23.000000000 +0200
+++ jax-jax-v0.6.2/jax/version.py 2024-11-28 13:10:52.508536023 +0100
@@ -33,6 +33,7 @@
def _get_version_string() -> str:
# The build/source distribution for jax & jaxlib overwrites _release_version.
# In this case we return it directly.
+ return _version
if _release_version is not None:
return _release_version
if os.getenv("WHEEL_VERSION_SUFFIX"):
@@ -82,6 +83,7 @@
- if JAX_NIGHTLY or JAXLIB_NIGHTLY are set: version looks like "0.4.16.dev20230906"
- if none are set: version looks like "0.4.16.dev20230906+ge58560fdc
"""
+ return _version
if _release_version is not None:
return _release_version
if os.environ.get('JAX_NIGHTLY') or os.environ.get('JAXLIB_NIGHTLY'):