diff --git a/easybuild/easyblocks/j/jaxlib.py b/easybuild/easyblocks/j/jaxlib.py index 6192f236300..72d6b3b07b2 100644 --- a/easybuild/easyblocks/j/jaxlib.py +++ b/easybuild/easyblocks/j/jaxlib.py @@ -77,9 +77,14 @@ def configure_step(self): # Collect options for the build script # Used only by the build script + options = [] + + # update build command for jaxlib-0.6 to build.py build + if LooseVersion(self.version) >= LooseVersion('0.6.0'): + options.append('build') # C++ flags are set through copt below - options = ['--target_cpu_features=default'] + options.append('--target_cpu_features=default') # Passed directly to bazel bazel_startup_options = [ @@ -125,13 +130,19 @@ def configure_step(self): options.append('--noenable_nccl') config_env_vars['GCC_HOST_COMPILER_PATH'] = which(os.getenv('CC')) - else: + elif LooseVersion(self.version) <= LooseVersion('0.6.0'): options.append('--noenable_cuda') if self.cfg['use_mkl_dnn']: - options.append('--enable_mkl_dnn') - else: + # --enable_mkl_dnn option was removed in jax(lib) v0.4.36, + # see https://github.com/jax-ml/jax/commit/676151265859f8b0dd8baf6f6ae50c3367ed0509 + if LooseVersion(self.version) < LooseVersion('0.4.36'): + options.append('--enable_mkl_dnn') + # if use_mkl_dnn is not enabled, use correct flag to disable use of MKL DNN + elif LooseVersion(self.version) < LooseVersion('0.4.36'): options.append('--noenable_mkl_dnn') + else: + options.append('--disable_mkl_dnn') # Prepend to buildopts so users can overwrite this self.cfg['buildopts'] = ' '.join(