Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions testing/python/components/test_tilelang_env.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import tilelang
import os


def test_env_var():
# test default value
assert tilelang.env.TILELANG_PRINT_ON_COMPILATION == "1"
# test forced value
os.environ["TILELANG_PRINT_ON_COMPILATION"] = "0"
assert tilelang.env.TILELANG_PRINT_ON_COMPILATION == "0"
# test forced value with class method
tilelang.env.TILELANG_PRINT_ON_COMPILATION = "1"
assert tilelang.env.TILELANG_PRINT_ON_COMPILATION == "1"


if __name__ == "__main__":
test_env_var()
6 changes: 3 additions & 3 deletions tilelang/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ def _init_logger():

logger = logging.getLogger(__name__)

from .env import SKIP_LOADING_TILELANG_SO
from .env import enable_cache, disable_cache, is_cache_enabled # noqa: F401
from .env import env as env # noqa: F401

import tvm
import tvm.base
Expand All @@ -76,12 +76,12 @@ def _load_tile_lang_lib():


# only load once here
if SKIP_LOADING_TILELANG_SO == "0":
if env.SKIP_LOADING_TILELANG_SO == "0":
_LIB, _LIB_PATH = _load_tile_lang_lib()

from .jit import jit, JITKernel, compile # noqa: F401
from .profiler import Profiler # noqa: F401
from .cache import cached # noqa: F401
from .cache import clear_cache # noqa: F401

from .utils import (
TensorSupplyType, # noqa: F401
Expand Down
20 changes: 7 additions & 13 deletions tilelang/autotuner/tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,7 @@
import traceback
from pathlib import Path

from tilelang.env import (
TILELANG_CACHE_DIR,
TILELANG_AUTO_TUNING_CPU_UTILITIES,
TILELANG_AUTO_TUNING_CPU_COUNTS,
TILELANG_AUTO_TUNING_MAX_CPU_COUNT,
is_cache_enabled,
)
from tilelang import env
from tilelang.autotuner.param import CompileArgs, ProfileArgs, AutotuneResult
from tilelang.autotuner.capture import get_autotune_inputs
from tilelang.jit.param import _P, _RProg
Expand Down Expand Up @@ -111,7 +105,7 @@ class AutoTuner:
_kernel_parameters: Optional[Tuple[str, ...]] = None
_lock = threading.Lock() # For thread safety
_memory_cache = {} # In-memory cache dictionary
cache_dir: Path = Path(TILELANG_CACHE_DIR) / "autotuner"
cache_dir: Path = Path(env.TILELANG_CACHE_DIR) / "autotuner"

def __init__(self, fn: Callable, configs):
self.fn = fn
Expand Down Expand Up @@ -285,7 +279,7 @@ def run(self, warmup: int = 25, rep: int = 100, timeout: int = 30):
key = self.generate_cache_key(parameters)

with self._lock:
if is_cache_enabled():
if env.is_cache_enabled():
# First check in-memory cache
if key in self._memory_cache:
logger.warning("Found kernel in memory cache. For better performance," \
Expand Down Expand Up @@ -437,9 +431,9 @@ def shape_equal(a, b):
return autotuner_result
# get the cpu count
available_cpu_count = get_available_cpu_count()
cpu_utilizations = float(TILELANG_AUTO_TUNING_CPU_UTILITIES)
cpu_counts = int(TILELANG_AUTO_TUNING_CPU_COUNTS)
max_cpu_count = int(TILELANG_AUTO_TUNING_MAX_CPU_COUNT)
cpu_utilizations = float(env.TILELANG_AUTO_TUNING_CPU_UTILITIES)
cpu_counts = int(env.TILELANG_AUTO_TUNING_CPU_COUNTS)
max_cpu_count = int(env.TILELANG_AUTO_TUNING_MAX_CPU_COUNT)
if cpu_counts > 0:
num_workers = min(cpu_counts, available_cpu_count)
logger.info(
Expand Down Expand Up @@ -543,7 +537,7 @@ def device_wrapper(func, device, **config_arg):
logger.warning("DLPack backend does not support cache saving to disk.")
else:
with self._lock:
if is_cache_enabled():
if env.is_cache_enabled():
self._save_result_to_disk(key, autotuner_result)

self._memory_cache[key] = autotuner_result
Expand Down
4 changes: 2 additions & 2 deletions tilelang/cache/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
from tvm.target import Target
from tvm.tir import PrimFunc
from tilelang.jit import JITKernel
from tilelang import env
from .kernel_cache import KernelCache
from tilelang.env import TILELANG_CLEAR_CACHE

# Create singleton instance of KernelCache
_kernel_cache_instance = KernelCache()
Expand Down Expand Up @@ -44,5 +44,5 @@ def clear_cache():
_kernel_cache_instance.clear_cache()


if TILELANG_CLEAR_CACHE.lower() in ("1", "true", "yes", "on"):
if env.TILELANG_CLEAR_CACHE.lower() in ("1", "true", "yes", "on"):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Ensure proper string comparison for environment variable

The .lower() call will fail if env.TILELANG_CLEAR_CACHE returns None. Consider adding a safety check.

-if env.TILELANG_CLEAR_CACHE.lower() in ("1", "true", "yes", "on"):
+clear_cache_value = env.TILELANG_CLEAR_CACHE
+if clear_cache_value and clear_cache_value.lower() in ("1", "true", "yes", "on"):
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
if env.TILELANG_CLEAR_CACHE.lower() in ("1", "true", "yes", "on"):
-clear_cache_value = env.TILELANG_CLEAR_CACHE
clear_cache_value = env.TILELANG_CLEAR_CACHE
if clear_cache_value and clear_cache_value.lower() in ("1", "true", "yes", "on"):
🤖 Prompt for AI Agents
In tilelang/cache/__init__.py around line 47, the code calls .lower() on
env.TILELANG_CLEAR_CACHE which will raise if that value is None; change the
check to guard against None by first normalizing or verifying the value (e.g.
use a fallback empty string or test truthiness before calling .lower()), then
perform the lowercased membership test against ("1","true","yes","on") so the
environment missing case does not cause an exception.

clear_cache()
16 changes: 8 additions & 8 deletions tilelang/cache/kernel_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from tvm.tir import PrimFunc

from tilelang.engine.param import KernelParam
from tilelang.env import TILELANG_CACHE_DIR, TILELANG_TMP_DIR, is_cache_enabled
from tilelang import env
from tilelang.jit import JITKernel
from tilelang.version import __version__

Expand Down Expand Up @@ -61,8 +61,8 @@ def __new__(cls):

@staticmethod
def _create_dirs():
os.makedirs(TILELANG_CACHE_DIR, exist_ok=True)
os.makedirs(TILELANG_TMP_DIR, exist_ok=True)
os.makedirs(env.TILELANG_CACHE_DIR, exist_ok=True)
os.makedirs(env.TILELANG_TMP_DIR, exist_ok=True)

def _generate_key(
self,
Expand Down Expand Up @@ -132,7 +132,7 @@ def cached(
Returns:
JITKernel: The compiled kernel, either freshly compiled or from cache
"""
if not is_cache_enabled():
if not env.is_cache_enabled():
return JITKernel(
func,
out_idx=out_idx,
Expand Down Expand Up @@ -190,7 +190,7 @@ def cached(
self.logger.warning("DLPack backend does not support cache saving to disk.")
else:
with self._lock:
if is_cache_enabled():
if env.is_cache_enabled():
self._save_kernel_to_disk(key, kernel, func, verbose)

# Store in memory cache after compilation
Expand All @@ -215,7 +215,7 @@ def _get_cache_path(self, key: str) -> str:
Returns:
str: Absolute path to the cache directory for this kernel.
"""
return os.path.join(TILELANG_CACHE_DIR, key)
return os.path.join(env.TILELANG_CACHE_DIR, key)

@staticmethod
def _load_binary(path: str):
Expand All @@ -226,7 +226,7 @@ def _load_binary(path: str):
@staticmethod
def _safe_write_file(path: str, mode: str, operation: Callable):
# Random a temporary file within the same FS as the cache directory
temp_path = os.path.join(TILELANG_TMP_DIR, f"{os.getpid()}_{uuid.uuid4()}")
temp_path = os.path.join(env.TILELANG_TMP_DIR, f"{os.getpid()}_{uuid.uuid4()}")
with open(temp_path, mode) as temp_file:
operation(temp_file)

Expand Down Expand Up @@ -396,7 +396,7 @@ def _clear_disk_cache(self):
"""
try:
# Delete the entire cache directory
shutil.rmtree(TILELANG_CACHE_DIR)
shutil.rmtree(env.TILELANG_CACHE_DIR)

# Re-create the cache directory
KernelCache._create_dirs()
Expand Down
2 changes: 1 addition & 1 deletion tilelang/contrib/nvcc.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import os
import subprocess
import warnings
from ..env import CUDA_HOME
from tilelang.env import CUDA_HOME

import tvm.ffi
from tvm.target import Target
Expand Down
Loading
Loading