From 56159c95a00c70357dbf093f1d6528db3175470f Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Thu, 3 Aug 2023 18:34:37 -0700 Subject: [PATCH] [Bugfix][CUTLASS] CUTLASS path finding (#15476) This PR fixes the path finding for `3rdparty/cutlass` by trying out different combinations. It is helpful when TVM is packaged differently. --- python/tvm/_ffi/libinfo.py | 9 +++++++-- python/tvm/contrib/cutlass/build.py | 15 ++++++++------- 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/python/tvm/_ffi/libinfo.py b/python/tvm/_ffi/libinfo.py index 19b8b50bbf29..3ec78fe278ae 100644 --- a/python/tvm/_ffi/libinfo.py +++ b/python/tvm/_ffi/libinfo.py @@ -15,8 +15,8 @@ # specific language governing permissions and limitations # under the License. """Library information.""" -import sys import os +import sys def split_env_var(env_var, split): @@ -169,7 +169,12 @@ def find_include_path(name=None, search_path=None, optional=False): source_dir = os.environ["TVM_HOME"] else: ffi_dir = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) - source_dir = os.path.join(ffi_dir, "..", "..", "..") + for source_dir in ["..", "../..", "../../.."]: + source_dir = os.path.join(ffi_dir, source_dir) + if os.path.isdir(os.path.join(source_dir, "include")): + break + else: + raise AssertionError("Cannot find the source directory given ffi_dir: {ffi_dir}") third_party_dir = os.path.join(source_dir, "3rdparty") header_path = [] diff --git a/python/tvm/contrib/cutlass/build.py b/python/tvm/contrib/cutlass/build.py index 7b5a6b97b907..cd990d0eeae1 100644 --- a/python/tvm/contrib/cutlass/build.py +++ b/python/tvm/contrib/cutlass/build.py @@ -37,13 +37,14 @@ def has_cutlass(): def _get_cutlass_path(): - tvm_root = os.path.join(os.path.dirname(os.path.realpath(__file__)), "../../../../") - cutlass_path = os.path.join(tvm_root, "3rdparty/cutlass") - assert os.path.exists(cutlass_path), ( - f"The CUTLASS root directory not found in {cutlass_path}. Currently, using CUTLASS " - f"requires building TVM from source." - ) - return cutlass_path + invalid_paths = [] + for rel in ["../../../../", "../../../", "../../"]: + tvm_root = os.path.join(os.path.dirname(os.path.realpath(__file__)), rel) + cutlass_path = os.path.join(tvm_root, "3rdparty/cutlass") + if os.path.exists(cutlass_path): + return cutlass_path + invalid_paths.append(cutlass_path) + raise AssertionError(f"The CUTLASS root directory not found in: {invalid_paths}") def _get_cutlass_compile_options(sm, threads, use_fast_math=False):