diff --git a/aiter/ops/triton/utils/_triton/arch_info.py b/aiter/ops/triton/utils/_triton/arch_info.py index 20bab95679..9acc2599a6 100644 --- a/aiter/ops/triton/utils/_triton/arch_info.py +++ b/aiter/ops/triton/utils/_triton/arch_info.py @@ -1,4 +1,5 @@ import triton +from functools import lru_cache # For now, there is 1-to-1 correspondence between arch and device _ARCH_TO_DEVICE = { @@ -7,6 +8,7 @@ } +@lru_cache(maxsize=1) def get_arch(): try: arch = ( @@ -21,6 +23,7 @@ def get_arch(): return arch +@lru_cache(maxsize=1) def get_device(): return _ARCH_TO_DEVICE[get_arch()]