diff --git a/easybuild/easyblocks/p/pytorch.py b/easybuild/easyblocks/p/pytorch.py index 862afbf61de..7bb1b7a313b 100755 --- a/easybuild/easyblocks/p/pytorch.py +++ b/easybuild/easyblocks/p/pytorch.py @@ -448,6 +448,7 @@ def add_enable_option(name, enabled): elif pytorch_version >= '1.9.0' and get_software_root('BLIS'): options.append('BLAS=BLIS') options.append('BLIS_HOME=' + get_software_root('BLIS')) + options.append('USE_MKLDNN=ON') options.append('USE_MKLDNN_CBLAS=ON') elif get_software_root('OpenBLAS'): # This is what PyTorch defaults to if no MKL is found. @@ -463,6 +464,13 @@ def add_enable_option(name, enabled): else: raise EasyBuildError("Did not find a supported BLAS in dependencies. Don't know which BLAS lib to use") + if pytorch_version >= '1.10': + acl_root = get_software_root('ArmComputeLibrary') + if acl_root: + options.append('USE_MKLDNN=ON') + options.append('USE_MKLDNN_ACL=ON') + env.setvar('ACL_ROOT_DIR', acl_root) + available_dependency_options = EB_PyTorch.get_dependency_options_for_version(self.version) dependency_names = self.cfg.dependency_names() not_used_dep_names = [] @@ -534,7 +542,7 @@ def add_enable_option(name, enabled): build_type = self.cfg.get('build_type') if build_type is None: - build_type = 'Debug' if self.toolchain.options.get('debug', None) else 'Release' + build_type = 'Debug' if self.toolchain.options.get('debug') else 'Release' else: for name in ('prebuildopts', 'preinstallopts', 'custom_opts'): if '-DCMAKE_BUILD_TYPE=' in self.cfg[name]: @@ -551,7 +559,7 @@ def add_enable_option(name, enabled): unique_options = self.cfg['custom_opts'] for option in options: - name = option.split('=')[0] + '=' # Include the equals sign to avoid partial matches + name = option.split('=', maxsplit=1)[0] + '=' # Include the equals sign to avoid partial matches if not any(opt.startswith(name) for opt in unique_options): unique_options.append(option) @@ -627,6 +635,7 @@ def test_step(self): env.setvar('SANDCASTLE', '1') # Skip this test(s) which is very flaky env.setvar('SKIP_TEST_BOTTLENECK', '1') + env.setvar('MAX_JOBS', str(self.cfg.parallel)) if self.has_xml_test_reports: env.setvar(self.GENERATE_TEST_REPORT_VAR_NAME, '1') # Parse excluded_tests and flatten into space separated string