diff --git a/src/op/builtin.cc b/src/op/builtin.cc
index 62ffd935e..0a70ec4a4 100644
--- a/src/op/builtin.cc
+++ b/src/op/builtin.cc
@@ -27,6 +27,8 @@ TVM_REGISTER_PASS_CONFIG_OPTION(kConfigIndexBitwidth, Integer);
 TVM_REGISTER_PASS_CONFIG_OPTION(kDisableDynamicTailSplit, Bool);
 TVM_REGISTER_PASS_CONFIG_OPTION(kDynamicAlignment, Integer);
 TVM_REGISTER_PASS_CONFIG_OPTION(kEnableAggressiveSharedMemoryMerge, Bool);
+TVM_REGISTER_PASS_CONFIG_OPTION(kDisableFastMath, Bool);
+TVM_REGISTER_PASS_CONFIG_OPTION(kEnablePTXASVerboseOutput, Bool);
 
 #define TIR_DEFINE_TL_BUILTIN(OpName)                                          \
   const Op &OpName() {                                                         \
diff --git a/src/op/builtin.h b/src/op/builtin.h
index 4234b6f4d..f2f5741a5 100644
--- a/src/op/builtin.h
+++ b/src/op/builtin.h
@@ -30,6 +30,9 @@ static constexpr const char *kDisableWarpSpecialized =
 static constexpr const char *kConfigIndexBitwidth = "tl.config_index_bitwidth";
 static constexpr const char *kEnableAggressiveSharedMemoryMerge =
     "tl.enable_aggressive_shared_memory_merge";
+static constexpr const char *kDisableFastMath = "tl.disable_fast_math";
+static constexpr const char *kEnablePTXASVerboseOutput =
+    "tl.enable_ptxas_verbose_output";
 
 /*!
  * \brief Whether to disable dynamic tail split
diff --git a/tilelang/jit/adapter/ctypes/adapter.py b/tilelang/jit/adapter/ctypes/adapter.py
index 5aa358f9f..37adf1ff1 100644
--- a/tilelang/jit/adapter/ctypes/adapter.py
+++ b/tilelang/jit/adapter/ctypes/adapter.py
@@ -90,6 +90,7 @@ def __init__(self,
         self.verbose = verbose
         self.wrapper = TLWrapper(self.target)
         self.lib_generator = LibraryGenerator(self.target)
+        self.lib_generator.assign_pass_configs(pass_configs)
 
         self.wrapper.assign_optimized_module(self.ir_module)
         self.wrapper.assign_pass_configs(pass_configs)
@@ -145,6 +146,7 @@ def from_database(cls,
         adapter.target = Target.canon_target(determine_target(target))
         adapter.verbose = verbose
         adapter.lib_generator = LibraryGenerator(adapter.target)
+        adapter.lib_generator.assign_pass_configs(pass_configs)
         adapter.lib = adapter.lib_generator.load_lib(lib_path=kernel_lib_path)
         adapter.lib.init()
 
diff --git a/tilelang/jit/adapter/cython/adapter.py b/tilelang/jit/adapter/cython/adapter.py
index d0bd53030..9861b7af5 100644
--- a/tilelang/jit/adapter/cython/adapter.py
+++ b/tilelang/jit/adapter/cython/adapter.py
@@ -246,6 +246,7 @@ def __init__(self,
         self.verbose = verbose
         self.wrapper = TLWrapper(self.target)
         self.lib_generator = LibraryGenerator(self.target)
+        self.lib_generator.assign_pass_configs(pass_configs)
 
         self.wrapper.assign_optimized_module(self.ir_module)
         self.wrapper.assign_pass_configs(pass_configs)
@@ -305,6 +306,7 @@ def from_database(cls,
 
         adapter.verbose = verbose
         adapter.lib_generator = LibraryGenerator(adapter.target)
+        adapter.lib_generator.assign_pass_configs(pass_configs)
         adapter.lib = adapter.lib_generator.load_lib(lib_path=kernel_lib_path)
 
         adapter.lib.get_last_error.restype = ctypes.c_char_p
diff --git a/tilelang/jit/adapter/libgen.py b/tilelang/jit/adapter/libgen.py
index 9778dc51d..d63118a09 100644
--- a/tilelang/jit/adapter/libgen.py
+++ b/tilelang/jit/adapter/libgen.py
@@ -7,11 +7,12 @@
 import os.path as osp
 import subprocess
 import tempfile
-from typing import Optional
+from typing import Any, Dict, Optional
 
 from tvm.target import Target
 
 from tilelang import tvm as tvm
+from tilelang.transform import PassConfigKey
 from tilelang.contrib.nvcc import get_nvcc_compiler, get_target_compute_version
 from tilelang.contrib.rocm import find_rocm_path, get_rocm_arch
 from tilelang.env import TILELANG_TEMPLATE_PATH
@@ -36,10 +37,15 @@ class LibraryGenerator(object):
     srcpath: Optional[str] = None
     libpath: Optional[str] = None
     lib_code: Optional[str] = None
+    pass_configs: Optional[Dict[str, Any]] = None
 
     def __init__(self, target: Target):
         self.target = target
 
+    def assign_pass_configs(self, pass_configs: Optional[Dict[str, Any]] = None):
+        """Store the pass configs that compile_lib reads; None selects the defaults."""
+        self.pass_configs = pass_configs
+
     def update_lib_code(self, lib_code: str):
         self.lib_code = lib_code
 
@@ -61,6 +67,13 @@ def compile_lib(self, timeout: float = None):
                 compute_version = "90a"
             libpath = src.name.replace(".cu", ".so")
 
+            # pass_configs is None until assign_pass_configs() supplies a dict;
+            # fall back to an empty mapping so the defaults below apply.
+            pass_configs = self.pass_configs or {}
+            disable_fast_math = pass_configs.get(PassConfigKey.TL_DISABLE_FAST_MATH, False)
+            verbose_ptxas_output = pass_configs.get(
+                PassConfigKey.TL_ENABLE_PTXAS_VERBOSE_OUTPUT, False)
+
             command = [
                 get_nvcc_compiler(),
                 "-std=c++17",
@@ -76,6 +89,10 @@ def compile_lib(self, timeout: float = None):
                 "-gencode",
                 f"arch=compute_{compute_version},code=sm_{compute_version}",
             ]
+            if not disable_fast_math:
+                command += ["--use_fast_math"]
+            if verbose_ptxas_output:
+                command += ["--ptxas_options", "-v"]
             command += [
                 "-I" + CUTLASS_INCLUDE_DIR,
             ]
diff --git a/tilelang/transform/pass_config.py b/tilelang/transform/pass_config.py
index cf426c6d2..6279c3630 100644
--- a/tilelang/transform/pass_config.py
+++ b/tilelang/transform/pass_config.py
@@ -20,6 +20,12 @@ class PassConfigKey(str, Enum):
     TL_DISABLE_WARP_SPECIALIZED = "tl.disable_warp_specialized"
     """Disable warp specialization optimization. Default: False"""
 
+    TL_DISABLE_FAST_MATH = "tl.disable_fast_math"
+    """Disable fast math optimization (omit nvcc --use_fast_math). Default: False"""
+
+    TL_ENABLE_PTXAS_VERBOSE_OUTPUT = "tl.enable_ptxas_verbose_output"
+    """Enable ptxas verbose output (nvcc --ptxas_options -v). Default: False"""
+
     TL_CONFIG_INDEX_BITWIDTH = "tl.config_index_bitwidth"
     """Bitwidth for configuration indices. Default: 32"""
 