Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion src/target/rt_mod_cuda.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#include "runtime/cuda/cuda_module.h"
#include "runtime/pack_args.h"
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ir/transform.h>

namespace tvm {
namespace codegen {
Expand Down Expand Up @@ -66,7 +67,10 @@ ffi::Module BuildTileLangCUDA(IRModule mod, Target target) {
std::string ptx;
if (const auto f =
ffi::Function::GetGlobal("tilelang_callback_cuda_compile")) {
ptx = (*f)(code, target).cast<std::string>();
// Fetch current pass context config and pass into the compile callback
tvm::transform::PassContext pass_ctx =
tvm::transform::PassContext::Current();
ptx = (*f)(code, target, pass_ctx->config).cast<std::string>();
if (ptx[0] != '/')
fmt = "cubin";
} else {
Expand Down
9 changes: 1 addition & 8 deletions tilelang/contrib/nvcc.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def compile_cuda(code,
out_file.write(code)

file_target = path_target if path_target else temp_target
cmd = ["nvcc"]
cmd = [get_nvcc_compiler()]
cmd += [f"--{target_format}", "-O3"]
if kernels_output_dir is not None:
cmd += ["-lineinfo"]
Expand Down Expand Up @@ -332,13 +332,6 @@ def get_cuda_version(cuda_path=None):
raise RuntimeError("Cannot read cuda version file")


@tvm_ffi.register_global_func("tilelang_callback_cuda_compile", override=True)
def tilelang_callback_cuda_compile(code, target): # pylint: disable=unused-argument
"""use nvcc to generate fatbin code for better optimization"""
ptx = compile_cuda(code, target_format="fatbin")
return ptx


@tvm_ffi.register_global_func("tilelang_callback_libdevice_path", override=True)
def find_libdevice_path(arch):
"""Utility function to find libdevice
Expand Down
42 changes: 30 additions & 12 deletions tilelang/engine/lower.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from tvm.ir import CallingConv
from tvm.target import Target
from tilelang.contrib import hipcc, nvcc
from tilelang.transform import PassConfigKey
from tilelang.utils.deprecated import deprecated_warning
from tilelang.engine.param import KernelParam, CompiledArtifact
from tilelang.utils.target import determine_target
from tilelang.engine.phase import (
Expand Down Expand Up @@ -54,7 +56,7 @@ def get_host_call(is_device_c: bool = False) -> Callable[[tir.PrimFunc], bool]:


@tvm_ffi.register_global_func("tilelang_callback_cuda_compile", override=True)
def tilelang_callback_cuda_compile(code, target):
def tilelang_callback_cuda_compile(code, target, pass_config=None):
project_root = osp.join(osp.dirname(__file__), "../..")
if "TL_TEMPLATE_PATH" in os.environ:
tl_template_path = os.environ["TL_TEMPLATE_PATH"]
Expand All @@ -69,21 +71,37 @@ def tilelang_callback_cuda_compile(code, target):
target_arch = nvcc.get_target_arch(nvcc.get_target_compute_version(target))

arch = [f"-arch=sm_{target_arch}"]
format = "cubin"
compile_format = "cubin"

# Read pass-config keys (string-valued) like in jit.adapter.libgen.compile_lib
cfg = pass_config or {}
if cfg.get(PassConfigKey.TL_DISABLE_FAST_MATH.value, False):
deprecated_warning("TL_DISABLE_FAST_MATH", "TL_ENABLE_FAST_MATH", "0.1.7")
disable_fast_math = bool(cfg.get(PassConfigKey.TL_DISABLE_FAST_MATH.value, True))
enable_fast_math = not disable_fast_math
else:
enable_fast_math = bool(cfg.get(PassConfigKey.TL_ENABLE_FAST_MATH.value, False))

ptxas_usage_level = cfg.get(PassConfigKey.TL_PTXAS_REGISTER_USAGE_LEVEL.value, None)
verbose_ptxas_output = bool(cfg.get(PassConfigKey.TL_ENABLE_PTXAS_VERBOSE_OUTPUT.value, False))

options = [
"-std=c++17",
"-I" + tl_template_path,
"-I" + cutlass_path,
]
if enable_fast_math:
options.append("--use_fast_math")
if ptxas_usage_level is not None:
options.append(f"--ptxas-options=--register-usage-level={ptxas_usage_level}")
if verbose_ptxas_output:
options.append("--ptxas-options=--verbose")

# printing out number of registers
debug_option = "--ptxas-options=--verbose,--register-usage-level=10,--warn-on-local-memory-usage"
ptx = nvcc.compile_cuda(
code,
format,
compile_format,
arch,
options=[
"-std=c++17",
debug_option,
"--use_fast_math",
"-I" + tl_template_path,
"-I" + cutlass_path,
],
options=options,
verbose=False,
)

Expand Down
Loading