Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
9164c61
adding easyconfigs: jax-0.4.35-gfbf-2024a-CUDA-12.6.0.eb and patches:…
ThomasHoffmann77 Nov 28, 2024
179493f
add ml_dtypes v0.5.0
ThomasHoffmann77 Nov 28, 2024
d9668f2
fix style
ThomasHoffmann77 Nov 28, 2024
794c15d
checksums; add missing patch
ThomasHoffmann77 Nov 28, 2024
5724407
borrow pybind11/2.13.6 from PR #21864
ThomasHoffmann77 Nov 28, 2024
b910ecb
temporarily add pytest-xdist from #21879
ThomasHoffmann77 Nov 29, 2024
aa1ab42
Update jax-0.4.35-gfbf-2024a-CUDA-12.6.0.eb
ThomasHoffmann77 Nov 29, 2024
44ebc27
Update easyconfigs.py
ThomasHoffmann77 Nov 29, 2024
31eeb75
Update easyconfigs.py
ThomasHoffmann77 Nov 29, 2024
39e03a3
Update jax-0.4.35-gfbf-2024a-CUDA-12.6.0.eb
ThomasHoffmann77 Nov 30, 2024
a9ad131
Merge branch 'easybuilders:develop' into 20241128144208_new_pr_jax0435
ThomasHoffmann77 Dec 1, 2024
e6bb0e0
Update jax-0.4.35-gfbf-2024a-CUDA-12.6.0.eb
ThomasHoffmann77 Dec 2, 2024
086c5ec
temporarily add SciPy-bundle with pybind11 builddependency
ThomasHoffmann77 Dec 2, 2024
54916df
Update jax-0.4.35-gfbf-2024a-CUDA-12.6.0.eb
ThomasHoffmann77 Dec 2, 2024
fc5b969
Update jax-0.4.35-gfbf-2024a-CUDA-12.6.0.eb
ThomasHoffmann77 Dec 2, 2024
849f9fb
test v0.4.34 with pybind11/2.12.0 builddep
ThomasHoffmann77 Dec 20, 2024
b0afceb
Delete easybuild/easyconfigs/s/SciPy-bundle/SciPy-bundle-2024.05-gfbf…
ThomasHoffmann77 Dec 20, 2024
2e75e9f
Delete easybuild/easyconfigs/p/pybind11/pybind11-2.13.6-GCC-13.3.0.eb
ThomasHoffmann77 Dec 20, 2024
b988776
Delete easybuild/easyconfigs/j/jax/jax-0.4.35-gfbf-2024a-CUDA-12.6.0.eb
ThomasHoffmann77 Dec 20, 2024
6db6d36
Update jax-0.4.34-gfbf-2024a-CUDA-12.6.0.eb
ThomasHoffmann77 Dec 20, 2024
f3fc230
revert SciPy-bundle-2024.05-gfbf-2024a.eb
ThomasHoffmann77 Dec 20, 2024
45153cf
Update jax-0.4.34-gfbf-2024a-CUDA-12.6.0.eb
ThomasHoffmann77 Jan 16, 2025
259590c
Update jax-0.4.34-gfbf-2024a-CUDA-12.6.0.eb
ThomasHoffmann77 Jan 16, 2025
ef36815
also build jax_cuda12_plugin and jax_cuda12_pjrt
ThomasHoffmann77 Jan 17, 2025
26adb0e
set XLA_FLAGS xla_gpu_cuda_data_dir to $CUDA_HOME
ThomasHoffmann77 Jan 22, 2025
3695aff
fix style
ThomasHoffmann77 Jan 22, 2025
8fc8f35
add EC for Bazel v6.5.0
ThomasHoffmann77 Jan 24, 2025
65dd435
Merge branch 'easybuilders:develop' into 20241128144208_new_pr_jax0435
ThomasHoffmann77 Feb 14, 2025
f95379c
Merge branch 'develop' of https://github.com/easybuilders/easybuild-e…
boegel Feb 24, 2025
96874a4
Delete easybuild/easyconfigs/m/ml_dtypes/ml_dtypes-0.5.0-gfbf-2024a.eb
ThomasHoffmann77 Mar 6, 2025
caed690
Merge branch 'easybuilders:develop' into 20241128144208_new_pr_jax0435
ThomasHoffmann77 Mar 13, 2025
8acb285
switch to 0.4.35; read xla_gpu_cuda_data_dir from
ThomasHoffmann77 May 20, 2025
2ac7d2e
Delete easybuild/easyconfigs/j/jax/jax-0.4.34-gfbf-2024a-CUDA-12.6.0.eb
ThomasHoffmann77 May 20, 2025
65356d7
Delete easybuild/easyconfigs/j/jax/jax-0.4.35_fix-pybind11-systemlib_…
ThomasHoffmann77 May 20, 2025
a2ec79c
whitespaces
ThomasHoffmann77 May 20, 2025
2172e30
whitepace
ThomasHoffmann77 May 20, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions easybuild/easyconfigs/b/Bazel/Bazel-6.5.0-GCCcore-13.3.0.eb
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
name = 'Bazel'
version = '6.5.0'

homepage = 'https://bazel.io/'
description = """Bazel is a build tool that builds code quickly and reliably.
It is used to build the majority of Google's software."""

toolchain = {'name': 'GCCcore', 'version': '13.3.0'}

source_urls = ['https://github.com/bazelbuild/%(namelower)s/releases/download/%(version)s']
sources = ['%(namelower)s-%(version)s-dist.zip']
patches = ['Bazel-6.5.0_py3.12_pytest_assertEqual.patch']
checksums = [
{'bazel-6.5.0-dist.zip': 'fc89da919415289f29e4ff18a5e01270ece9a6fe83cb60967218bac4a3bb3ed2'},
{'Bazel-6.5.0_py3.12_pytest_assertEqual.patch': '2670dd5c393970ba20db2c98cf0208df7190ff339ccb66fee9a6d48aaaf3ede6'},
]

builddependencies = [
('binutils', '2.42'),
('Python', '3.12.3'),
('Zip', '3.0'),
]

dependencies = [
('Java', '11.0.20', '', SYSTEM),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we use Java/21 here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@boegel Build of Bazel 6.5.0 fails for both, Java/21 and Java/21.0.2.

]

runtest = True
testopts = "-- //examples/cpp:hello-success_test //examples/py/... //examples/py_native:test //examples/shell/..."

moduleclass = 'devel'
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# Thomas Hoffmann, EMBL Heidelberg, structures-it@embl.de, 2025/01
# replace assertEquals by assertEqual
# https://docs.python.org/3/whatsnew/3.12.html#id3
diff -ru bazel-6.5.0/examples/py_native/fail.py bazel-6.5.0_pytest_assertEqual/examples/py_native/fail.py
--- bazel-6.5.0/examples/py_native/fail.py 1980-01-01 00:00:00.000000000 +0100
+++ bazel-6.5.0_pytest_assertEqual/examples/py_native/fail.py 2025-01-24 14:27:22.973336188 +0100
@@ -6,7 +6,7 @@
class TestGetNumber(unittest.TestCase):

def test_fail(self):
- self.assertEquals(GetNumber(), 0)
+ self.assertEqual(GetNumber(), 0)


if __name__ == '__main__':
diff -ru bazel-6.5.0/examples/py_native/test.py bazel-6.5.0_pytest_assertEqual/examples/py_native/test.py
--- bazel-6.5.0/examples/py_native/test.py 1980-01-01 00:00:00.000000000 +0100
+++ bazel-6.5.0_pytest_assertEqual/examples/py_native/test.py 2025-01-24 14:27:22.973336188 +0100
@@ -8,10 +8,10 @@
class TestGetNumber(unittest.TestCase):

def test_ok(self):
- self.assertEquals(GetNumber(), 42)
+ self.assertEqual(GetNumber(), 42)

def test_fib(self):
- self.assertEquals(Fib(5), 8)
+ self.assertEqual(Fib(5), 8)

if __name__ == '__main__':
unittest.main()
216 changes: 216 additions & 0 deletions easybuild/easyconfigs/j/jax/jax-0.4.35-gfbf-2024a-CUDA-12.6.0.eb
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
# This file is an EasyBuild reciPY as per https://github.com/easybuilders/easybuild
# Author: Denis Kristak
# Updated by: Alex Domingo (Vrije Universiteit Brussel)
# Updated by: Pavel Tománek (INUITS)
# Updated by: Thomas Hoffmann (EMBL Heidelberg)
easyblock = 'PythonBundle'

name = 'jax'
version = '0.4.35'
versionsuffix = '-CUDA-%(cudaver)s'

homepage = 'https://jax.readthedocs.io/'
description = """Composable transformations of Python+NumPy programs:
differentiate, vectorize, JIT to GPU/TPU, and more"""

toolchain = {'name': 'gfbf', 'version': '2024a'}
cuda_compute_capabilities = ["5.0", "6.0", "6.1", "7.0", "7.5", "8.0", "8.6", "9.0"]

builddependencies = [
# ('Bazel', '7.4.1'), TODO: problems with @@local_config_python//:py3_runtime:
# Error in fail: interpreter_path must be an absolute path
# Bazel 6.5.0 (download) works.
('pybind11', '2.13.6'), # 2.12.0 ? SciPy-bundle has pybind/2.12.0.
('pytest-xdist', '3.6.1'),
('git', '2.45.1'), # bazel uses git to fetch repositories
('matplotlib', '3.9.2'), # required for tests/lobpcg_test.py
('poetry', '1.8.3'),
('Clang', '18.1.8', versionsuffix)
]

dependencies = [
('CUDA', '12.6.0', '', SYSTEM), # 12.6.2 ?
('cuDNN', '9.5.0.50', versionsuffix, SYSTEM),
('NCCL', '2.22.3', versionsuffix),
('Python', '3.12.3'),
('SciPy-bundle', '2024.05'), # 2024.11 ?
('absl-py', '2.1.0'),
('flatbuffers-python', '24.3.25'),
('ml_dtypes', '0.5.0'),
('zlib', '1.3.1'),
]

# downloading xla and other tarballs to avoid that Bazel downloads it during the build
local_extract_cmd = 'mkdir -p %(builddir)s/archives && cp %s %(builddir)s/archives'
# note: following commits *must* be the exact same onces used upstream
# XLA_COMMIT from jax-jaxlib: third_party/xla/workspace.bzl
local_xla_commit = '76da730179313b3bebad6dea6861768421b7358c'
# TFRT_COMMIT from xla: third_party/tsl/third_party/tf_runtime/workspace.bzl
local_tfrt_commit = '0aeefb1660d7e37964b2bb71b1f518096bda9a25' # TODO: still required?
# TODO: add other downloads

# Use sources downloaded by EasyBuild
_jaxlib_buildopts = '--bazel_options="--distdir=%(builddir)s/archives" '
# Use dependencies from EasyBuild
_jaxlib_buildopts += '--bazel_options="--action_env=TF_SYSTEM_LIBS=pybind11" '
_jaxlib_buildopts += '--bazel_options="--action_env=CPATH=$EBROOTPYBIND11/include:$EBROOTCUDA/extras/CUPTI/include" '
# Avoid warning (treated as error) in upb/table.c
_jaxlib_buildopts += '--bazel_options="--copt=-Wno-maybe-uninitialized" ' # TODO: still required?
# _jaxlib_buildopts += '--nouse_clang ' #TODO: avoid clang (?)
_jaxlib_buildopts += '--cuda_version=%(cudaver)s '
_jaxlib_buildopts += '--python_bin_path=$EBROOTPYTHON/bin/python3 '
# Do not use hermetic CUDA/cuDNN/NCCL: (requires action_env=CPATH=$EBROOTCUDA/extras/CUPTI/include";
# requires patch of external/xla/xla/tsl/cuda/cupti_stub.cc and jaxlib/gpu/vendor.h (#include <cupti.h>):
_jaxlib_buildopts += """--bazel_options=--repo_env=LOCAL_CUDNN_PATH="$EBROOTCUDNN" """
_jaxlib_buildopts += """--bazel_options=--repo_env=LOCAL_NCCL_PATH="$EBROOTNCCL" """
_jaxlib_buildopts += """--bazel_options=--repo_env=LOCAL_CUDA_PATH="$EBROOTCUDA" """
_jaxlib_buildopts += """--bazel_options="--copt=-Ithird_party/gpus/cuda/extras/CUPTI/include" """

_plugins_buildopts = """--enable_cuda """
_plugins_buildopts += """--build_gpu_plugin """
# _plugins_buildopts +="""--gpu_plugin_cuda_version=12 """
_plugins_buildopts += """--build_gpu_pjrt_plugin """
_plugins_buildopts += """--build_gpu_kernel_plugin=cuda """

# get rid of .devDate versionsuffix: TODO: find a better way
# _no_devtag = """ export JAX_RELEASE && export JAXLIB_RELEASE && """ does not work (?)
_no_devtag = """ sed -i "s/version=__version__/version='%(version)s'/g" setup.py && """
_jaxlib_buildopts += """--bazel_options="--action_env=JAXLIB_RELEASE=1" """ # required?

components = [
('jaxlib', version, {
'sources': [
{
'source_urls': ['https://github.com/google/jax/archive/'],
'filename': 'jax-v%(version)s.tar.gz',
},
{
'source_urls': ['https://github.com/openxla/xla/archive'],
'download_filename': '%s.tar.gz' % local_xla_commit,
'filename': 'xla-%s.tar.gz' % local_xla_commit[:8],
'extract_cmd': local_extract_cmd,
},
{
'source_urls': ['https://github.com/tensorflow/runtime/archive'],
'download_filename': '%s.tar.gz' % local_tfrt_commit,
'filename': 'tf_runtime-%s.tar.gz' % local_tfrt_commit[:8],
'extract_cmd': local_extract_cmd,
},
],
'patches': [
'jax-0.4.35_easyblock_compat.patch',
'jax-0.4.35_fix-pybind11-systemlib_cupti_CUDA_HOME.patch',
'jax-0.4.35_version.patch',
],
'checksums': [
{'jax-v0.4.35.tar.gz':
'65e086708ae56670676b7b2340ad82b901d8c9993d1241a839c8990bdb8d6212'},
{'xla-76da7301.tar.gz':
'd67ced09b69ab8d7b26fa4cd5f48b22db57eb330294a35f6e1d462ee17066757'},
{'tf_runtime-0aeefb16.tar.gz':
'a3df827d7896774cb1d80bf4e1c79ab05c268f29bd4d3db1fb5a4b9c2079d8e3'},
{'jax-0.4.35_easyblock_compat.patch':
'cbf4ad92b8438c4ce2a975efce1c47c57d4c3b117bceee071ab660f964057223'},
{'jax-0.4.35_fix-pybind11-systemlib_cupti_CUDA_HOME.patch':
'fa5273d31651579590f7291fc151836f43024f74f4c89243dc4c6a417284e7ce'},
{'jax-0.4.35_version.patch':
'cd2139a7802abf14b4b2cecee331aed80fff2ef91e16fa105093aea0795455e8'},
],
'start_dir': 'jax-jax-v%(version)s',
'buildopts': _jaxlib_buildopts,
'prebuildopts': ' mkdir third_party/gpus/cuda/extras/ -p && ' +
'ln -s $EBROOTCUDA/extras/CUPTI third_party/gpus/cuda/extras --relative &&' +
_no_devtag
}),
# build jaxlib first and then plugins in 2nd interation:
('jaxlib', version, {
'sources': [
{
'source_urls': ['https://github.com/google/jax/archive/'],
'filename': 'jax-v%(version)s.tar.gz',
},
{
'source_urls': ['https://github.com/openxla/xla/archive'],
'download_filename': '%s.tar.gz' % local_xla_commit,
'filename': 'xla-%s.tar.gz' % local_xla_commit[:8],
'extract_cmd': local_extract_cmd,
},
{
'source_urls': ['https://github.com/tensorflow/runtime/archive'],
'download_filename': '%s.tar.gz' % local_tfrt_commit,
'filename': 'tf_runtime-%s.tar.gz' % local_tfrt_commit[:8],
'extract_cmd': local_extract_cmd,
},
],
'checksums': [
{'jax-v0.4.35.tar.gz':
'65e086708ae56670676b7b2340ad82b901d8c9993d1241a839c8990bdb8d6212'},
{'xla-76da7301.tar.gz':
'd67ced09b69ab8d7b26fa4cd5f48b22db57eb330294a35f6e1d462ee17066757'},
{'tf_runtime-0aeefb16.tar.gz':
'a3df827d7896774cb1d80bf4e1c79ab05c268f29bd4d3db1fb5a4b9c2079d8e3'},
],
'start_dir': 'jax-jax-v%(version)s',
'buildopts': _jaxlib_buildopts + _plugins_buildopts,
'prebuildopts': _no_devtag
}),
]
# failing:
# tests/lax_test.py::FunctionAccuracyTest::testSuccessOnComplexPlane_expm1_complex128 FAILED [ 98%]
# tests/lax_test.py::FunctionAccuracyTest::testSuccessOnComplexPlane_expm1_complex64 FAILED [ 98%]
# tests/lax_test.py::FunctionAccuracyTest::testSuccessOnComplexPlane_tan_complex128 FAILED [ 99%]
# tests/lax_test.py::FunctionAccuracyTest::testSuccessOnComplexPlane_tan_complex64 FAILED [ 99%]
# FAILED tests/lax_test.py::FunctionAccuracyTest::testSuccessOnComplexPlane_expm1_complex128 - AssertionError:
# FAILED tests/lax_test.py::FunctionAccuracyTest::testSuccessOnComplexPlane_expm1_complex64 - AssertionError:
# FAILED tests/lax_test.py::FunctionAccuracyTest::testSuccessOnComplexPlane_tan_complex128 - AssertionError:
# FAILED tests/lax_test.py::FunctionAccuracyTest::testSuccessOnComplexPlane_tan_complex64 - AssertionError:
# tests/nn_test.py::NNFunctionsTest::testDotProductAttentionMask7 FAILED [ 10%]
# FAILED tests/nn_test.py::NNFunctionsTest::testDotProductAttentionMask7 - AssertionError:
#


# Some tests require an isolated run: TODO: still required?
local_isolated_tests = [
'tests/host_callback_test.py::HostCallbackTapTest::test_tap_scan_custom_jvp',
'tests/host_callback_test.py::HostCallbackTapTest::test_tap_transforms_doc',
'tests/lax_scipy_special_functions_test.py::LaxScipySpcialFunctionsTest' +
'::testScipySpecialFun_gammainc_s_2x1x4_float32_float32',
]
# deliberately not testing in parallel, as that results in (additional) failing tests;
# use XLA_PYTHON_CLIENT_ALLOCATOR=platform to allocate and deallocate GPU memory during testing,
# see https://github.com/google/jax/issues/7323 and
# https://github.com/google/jax/blob/main/docs/gpu_memory_allocation.rst;
# use CUDA_VISIBLE_DEVICES=0 to avoid failing tests on systems with multiple GPUs;
# use NVIDIA_TF32_OVERRIDE=0 to avoid loosing numerical precision by disabling TF32 Tensor Cores;
local_test_exports = [
"NVIDIA_TF32_OVERRIDE=0",
"CUDA_VISIBLE_DEVICES=0",
"XLA_PYTHON_CLIENT_ALLOCATOR=platform",
"JAX_ENABLE_X64=true",
]
local_test = ''.join(['export %s;' % x for x in local_test_exports])
# run all tests at once except for local_isolated_tests:
local_test += "pytest -vv tests %s && " % ' '.join(['--deselect %s' % x for x in local_isolated_tests])
# run remaining local_isolated_tests separately:
local_test += ' && '.join(['pytest -vv %s' % x for x in local_isolated_tests])


exts_list = [
(name, version, {
'source_tmpl': '%(name)s-v%(version)s.tar.gz',
'source_urls': ['https://github.com/google/jax/archive/'],
'patches': ['jax-0.4.35_version.patch'],
'checksums': [
{'jax-v0.4.35.tar.gz': '65e086708ae56670676b7b2340ad82b901d8c9993d1241a839c8990bdb8d6212'},
{'jax-0.4.35_version.patch': 'cd2139a7802abf14b4b2cecee331aed80fff2ef91e16fa105093aea0795455e8'},
],
'runtest': False, # tmp
'preinstallopts': _no_devtag
}),
]
sanity_check_commands = [
"""python -c "import jax_cuda"$(echo $EBVERSIONCUDA|awk -F '.' '{print $1}')"_plugin" """
]


moduleclass = 'ai'
21 changes: 21 additions & 0 deletions easybuild/easyconfigs/j/jax/jax-0.4.35_easyblock_compat.patch
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Thomas Hoffmann, EMBL Heidelberg, structures-it@embl.de, 2024/11
# add dummy parameters to build/build.py for cudnn_path and cuda_path, which are set by default by the jaxlib easyblock.
diff -ru jax-jax-v0.4.35/build/build.py jax-jax-v0.4.35_easyblockcompat/build/build.py
--- jax-jax-v0.4.35/build/build.py 2024-10-22 21:00:23.000000000 +0200
+++ jax-jax-v0.4.35_easyblockcompat/build/build.py 2024-11-19 12:35:46.524479324 +0100
@@ -549,6 +549,15 @@
help_str="Same as update_requirements, but will consider dev, nightly "
"and pre-release versions of packages.")

+ parser.add_argument(
+ "--cuda_path",
+ default="dummy",
+ help="compatibility with jaxlib.py easyblock")
+ parser.add_argument(
+ "--cudnn_path",
+ default="dummy",
+ help="compatibility with jaxlib.py easyblock")
+
args = parser.parse_args()

logging.basicConfig()
Loading