Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions easybuild/easyblocks/p/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,6 +448,7 @@ def add_enable_option(name, enabled):
elif pytorch_version >= '1.9.0' and get_software_root('BLIS'):
options.append('BLAS=BLIS')
options.append('BLIS_HOME=' + get_software_root('BLIS'))
options.append('USE_MKLDNN=ON')
options.append('USE_MKLDNN_CBLAS=ON')
elif get_software_root('OpenBLAS'):
# This is what PyTorch defaults to if no MKL is found.
Expand All @@ -463,6 +464,13 @@ def add_enable_option(name, enabled):
else:
raise EasyBuildError("Did not find a supported BLAS in dependencies. Don't know which BLAS lib to use")

if pytorch_version >= '1.10':
acl_root = get_software_root('ArmComputeLibrary')
if acl_root:
options.append('USE_MKLDNN=ON')
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't you also want options.append('USE_MKLDNN_CBLAS=ON') here (just from the comment in easybuilders/easybuild-easyconfigs#21309 (comment) by @migueldiascosta )?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't see this being set in the PyTorch source. It defaults to OFF, and CI doesn't set it:
https://github.com/pytorch/pytorch/blob/1d16a0978458457dc5c6b50bc19a37359a4bd822/.ci/pytorch/build.sh#L78-L82

Hence I left it disabled here too. It can be set using custom_opts e.g. via hooks

options.append('USE_MKLDNN_ACL=ON')
env.setvar('ACL_ROOT_DIR', acl_root)

available_dependency_options = EB_PyTorch.get_dependency_options_for_version(self.version)
dependency_names = self.cfg.dependency_names()
not_used_dep_names = []
Expand Down Expand Up @@ -534,7 +542,7 @@ def add_enable_option(name, enabled):

build_type = self.cfg.get('build_type')
if build_type is None:
build_type = 'Debug' if self.toolchain.options.get('debug', None) else 'Release'
build_type = 'Debug' if self.toolchain.options.get('debug') else 'Release'
else:
for name in ('prebuildopts', 'preinstallopts', 'custom_opts'):
if '-DCMAKE_BUILD_TYPE=' in self.cfg[name]:
Expand All @@ -551,7 +559,7 @@ def add_enable_option(name, enabled):

unique_options = self.cfg['custom_opts']
for option in options:
name = option.split('=')[0] + '=' # Include the equals sign to avoid partial matches
name = option.split('=', maxsplit=1)[0] + '=' # Include the equals sign to avoid partial matches
if not any(opt.startswith(name) for opt in unique_options):
unique_options.append(option)

Expand Down Expand Up @@ -627,6 +635,7 @@ def test_step(self):
env.setvar('SANDCASTLE', '1')
# Skip this test(s) which is very flaky
env.setvar('SKIP_TEST_BOTTLENECK', '1')
env.setvar('MAX_JOBS', str(self.cfg.parallel))
if self.has_xml_test_reports:
env.setvar(self.GENERATE_TEST_REPORT_VAR_NAME, '1')
# Parse excluded_tests and flatten into space separated string
Expand Down
Loading