diff --git a/python/triton/compiler/compiler.py b/python/triton/compiler/compiler.py index 6a1c6356b1c1..135f71f6cd63 100644 --- a/python/triton/compiler/compiler.py +++ b/python/triton/compiler/compiler.py @@ -251,7 +251,7 @@ def convert_type_repr(x): return x -def make_hash(fn, **kwargs): +def make_hash(fn, arch, **kwargs): if isinstance(fn, triton.runtime.JITFunction): configs = kwargs["configs"] signature = kwargs["signature"] @@ -262,7 +262,7 @@ def make_hash(fn, **kwargs): # Get unique key for the compiled code get_conf_key = lambda conf: (sorted(conf.divisible_by_16), sorted(conf.equal_to_1)) configs_key = [get_conf_key(conf) for conf in configs] - key = f"{fn.cache_key}-{''.join(signature.values())}-{configs_key}-{constants}-{num_warps}-{num_stages}-{debug}" + key = f"{fn.cache_key}-{''.join(signature.values())}-{configs_key}-{constants}-{num_warps}-{num_stages}-{debug}-{arch}" return hashlib.md5(key.encode("utf-8")).hexdigest() assert isinstance(fn, str) return hashlib.md5((Path(fn).read_text() + triton.runtime.jit.version_key()).encode("utf-8")).hexdigest() @@ -418,7 +418,7 @@ def compile(fn, **kwargs): # cache manager so_path = make_stub(name, signature, constants) # create cache manager - fn_cache_manager = get_cache_manager(make_hash(fn, **kwargs)) + fn_cache_manager = get_cache_manager(make_hash(fn, arch, **kwargs)) # determine name and extension type of provided function if isinstance(fn, triton.runtime.JITFunction): name, ext = fn.__name__, "ast"