Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,8 @@ tilelang_file_glob(GLOB TILE_LANG_SRCS
src/target/rt_mod_cpp.cc
# webgpu doesn't have system dependency
src/target/codegen_webgpu.cc
# intrin_rule doesn't have system dependency
src/target/intrin_rule*.cc
)

# Include CUDA source files if CUDA is enabled
Expand Down
66 changes: 32 additions & 34 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@ def get_cplus_compiler():
return None


@functools.lru_cache(maxsize=None)
def get_cython_compiler() -> Optional[str]:
"""Return the path to the Cython compiler.

Expand Down Expand Up @@ -238,6 +239,17 @@ def get_cython_compiler() -> Optional[str]:
return None


@functools.lru_cache(maxsize=None)
def get_cmake_path() -> str:
    """Return the path to the CMake executable.

    Resolution order:
      1. ``CMAKE`` or ``CMAKE_PATH`` environment variable (explicit override).
      2. ``cmake`` found on PATH.
      3. ``cmake3`` found on PATH (common on RHEL/CentOS systems).

    Returns:
        Path to a usable CMake executable.

    Raises:
        RuntimeError: If no executable CMake can be located.
    """
    explicit = os.environ.get("CMAKE") or os.environ.get("CMAKE_PATH")
    # shutil.which returns None when the tool is absent, so every candidate
    # is re-checked for existence and executability before being returned.
    for candidate in (explicit, shutil.which("cmake"), shutil.which("cmake3")):
        if candidate and os.path.exists(candidate) and os.access(candidate, os.X_OK):
            return candidate
    raise RuntimeError(
        "CMake not found. Install CMake, ensure it is on PATH, or set CMAKE/CMAKE_PATH.")

Comment on lines +242 to +251
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Harden get_cmake_path: handle None from shutil.which, support env override, raise RuntimeError

Current code calls os.path.exists on a possible None and throws TypeError; also only checks cmake, not cmake3, and uses generic Exception.

-@functools.lru_cache(maxsize=None)
-def get_cmake_path() -> str:
-    """Return the path to the CMake compiler.
-    """
-    # found which cmake is used
-    cmake_path = shutil.which("cmake")
-    if not os.path.exists(cmake_path):
-        raise Exception("CMake is not installed, please install it first.")
-    return cmake_path
+@functools.lru_cache(maxsize=None)
+def get_cmake_path() -> str:
+    """Return the path to the CMake executable."""
+    explicit = os.environ.get("CMAKE") or os.environ.get("CMAKE_PATH")
+    candidates = [explicit, shutil.which("cmake"), shutil.which("cmake3")]
+    for p in candidates:
+        if p and os.path.exists(p) and os.access(p, os.X_OK):
+            return p
+    raise RuntimeError(
+        "CMake not found. Install CMake (>=3.22), ensure it is on PATH, or set CMAKE/CMAKE_PATH."
+    )
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
@functools.lru_cache(maxsize=None)
def get_cmake_path() -> str:
"""Return the path to the CMake compiler.
"""
# found which cmake is used
cmake_path = shutil.which("cmake")
if not os.path.exists(cmake_path):
raise Exception("CMake is not installed, please install it first.")
return cmake_path
@functools.lru_cache(maxsize=None)
def get_cmake_path() -> str:
"""Return the path to the CMake executable."""
explicit = os.environ.get("CMAKE") or os.environ.get("CMAKE_PATH")
candidates = [explicit, shutil.which("cmake"), shutil.which("cmake3")]
for p in candidates:
if p and os.path.exists(p) and os.access(p, os.X_OK):
return p
raise RuntimeError(
"CMake not found. Install CMake (>=3.22), ensure it is on PATH, or set CMAKE/CMAKE_PATH."
)
🧰 Tools
🪛 Ruff (0.12.2)

249-249: Create your own exception

(TRY002)


249-249: Avoid specifying long messages outside the exception class

(TRY003)


def get_system_info():
system = platform.system().lower()
if system == "linux":
Expand Down Expand Up @@ -338,33 +350,6 @@ def is_git_repo():
raise RuntimeError("Failed to update submodules") from error


def build_csrc(llvm_config_path):
    """Configure and build the TVM C++ sources under ``./build``.

    NOTE(review): this changes the process working directory to ``build``
    and never restores it — callers after this point run from ``build/``.

    Args:
        llvm_config_path: Value written into ``set(USE_LLVM ...)`` in
            config.cmake (e.g. a path to llvm-config).

    Raises:
        RuntimeError: If the CMake configure step or the make build fails.
    """

    if not os.path.exists("build"):
        os.makedirs("build")
    os.chdir("build")
    # Copy the config.cmake as a baseline
    if not os.path.exists("config.cmake"):
        shutil.copy("../3rdparty/tvm/cmake/config.cmake", "config.cmake")
    # Set LLVM path and enable CUDA or ROCM in config.cmake
    with open("config.cmake", "a") as config_file:
        config_file.write(f"set(USE_LLVM {llvm_config_path})\n")
        # USE_ROCM / ROCM_HOME / CUDA_HOME are module-level globals; exactly
        # one of CUDA or ROCM is enabled, the other is forced OFF.
        if USE_ROCM:
            config_file.write(f"set(USE_ROCM {ROCM_HOME})\n")
            config_file.write("set(USE_CUDA OFF)\n")
        else:
            config_file.write(f"set(USE_CUDA {CUDA_HOME})\n")
            config_file.write("set(USE_ROCM OFF)\n")
    # Run CMake and make
    try:
        subprocess.check_call(["cmake", ".."])
        # Use ~75% of the cores to leave headroom for the rest of the system.
        num_jobs = max(1, int(multiprocessing.cpu_count() * 0.75))
        subprocess.check_call(["make", f"-j{num_jobs}"])
    except subprocess.CalledProcessError as error:
        raise RuntimeError("Failed to build TileLang C Source") from error


def setup_llvm_for_tvm():
"""Downloads and extracts LLVM, then configures TVM to use it."""
# Assume the download_and_extract_llvm function and its dependencies are defined elsewhere in this script
Expand Down Expand Up @@ -627,7 +612,10 @@ class TilelangExtensionBuild(build_ext):
def run(self):
# Check if CMake is installed and accessible by attempting to run 'cmake --version'.
try:
subprocess.check_output(["cmake", "--version"])
cmake_path = get_cmake_path()
if not cmake_path:
raise Exception("CMake is not installed, please install it first.")
subprocess.check_output([cmake_path, "--version"])
except OSError as error:
# If CMake is not found, raise an error.
raise RuntimeError(
Expand Down Expand Up @@ -830,15 +818,25 @@ def build_cmake(self, ext):
else:
print(f"[Config] No changes: {dst_config}")

cmake_path = get_cmake_path()
# Run CMake to configure the project with the given arguments.
if not os.path.exists(build_temp + "/build.ninja"):
subprocess.check_call(["cmake", ext.sourcedir] + cmake_args, cwd=build_temp)
if not os.path.exists(os.path.join(build_temp, "build.ninja")):
logger.info(
f"[CMake] Generating build.ninja: {cmake_path} {ext.sourcedir} {' '.join(cmake_args)}"
)
subprocess.check_call([cmake_path, ext.sourcedir] + cmake_args, cwd=build_temp)
else:
logger.info(f"[CMake] build.ninja already exists in {build_temp}")

# Build the project in "Release" mode with all available CPU cores ("-j").
num_jobs = max(1, int(multiprocessing.cpu_count() * 0.75))
subprocess.check_call(["cmake", "--build", ".", "--config", "Release", "-j",
str(num_jobs)],
cwd=build_temp)
logger.info(
f"[Build] Using {num_jobs} jobs | cmake: {cmake_path} (exists: {os.path.exists(cmake_path)}) | build dir: {build_temp}"
)

subprocess.check_call(
[cmake_path, "--build", ".", "--config", "Release", "-j",
str(num_jobs)],
cwd=build_temp)


setup(
Expand Down
138 changes: 138 additions & 0 deletions src/target/intrin_rule_cuda.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
/*!
* \file intrin_rule_cuda.cc
* \brief CUDA intrinsic rules.
*/
#include <tvm/tir/builtin.h>
#include <tvm/tir/op_attr_types.h>

#include "target/intrin_rule.h"

namespace tvm {
namespace codegen {
namespace intrin {
// Add float suffix to the intrinsics, CUDA fast math.
using tir::FLowerIntrinsic;

// Maps a generic math intrinsic name to its CUDA library counterpart,
// suffixed/prefixed per dtype (precise math; see CUDAFastMath for the
// fast-math fp32 variants). Returns "" when no mapping exists.
struct CUDAMath {
  // fp16 and bf16 share the same `h`-prefixed intrinsic family.
  static std::string HalfPrecisionName(const std::string &name) {
    if (name == "fabs") {
      return "__habs";
    }
    if (name == "round") {
      return "hrint";
    }
    return "h" + name;
  }

  std::string operator()(DataType t, std::string name) const {
    if (t.is_float()) {
      const int bits = t.bits();
      if (bits == 64) {
        return name;  // double-precision uses the bare libm name
      }
      if (bits == 32) {
        return name + 'f';  // single-precision: e.g. sqrt -> sqrtf
      }
      if (bits == 16) {
        return HalfPrecisionName(name);
      }
      return "";
    }
    if (t.is_bfloat16()) {
      return HalfPrecisionName(name);
    }
    if (t.is_int() || t.is_uint()) {
      const int bits = t.bits();
      if (bits == 32) {
        return "__" + name;
      }
      if (bits == 64) {
        return "__" + name + "ll";
      }
      return "";
    }
    return "";
  }
};

// Fast-math variant: fp32 maps to the __*f device intrinsics
// (e.g. __expf); all other dtypes fall back to the precise CUDAMath
// mapping. (The unreachable trailing `return ""` after the if/else was
// removed — both branches already return.)
struct CUDAFastMath : public CUDAMath {
  std::string operator()(DataType t, std::string name) const {
    if (t.is_float() && t.bits() == 32) {
      return "__" + name + 'f';
    }
    return CUDAMath::operator()(t, name);
  }
};

// Special-cased tan dispatch: unlike CUDAFastMath, fp32 deliberately
// avoids `__tanf`, whose results deviate too far from the numpy tan
// reference — plain `tanf` is emitted instead.
struct CUDAFastMathTan : public CUDAMath {
  std::string operator()(DataType t, std::string name) const {
    if (!t.is_float()) {
      return "";
    }
    const int bits = t.bits();
    if (bits == 64) {
      return name;
    }
    if (bits == 32) {
      return name + 'f';
    }
    if (bits == 16) {
      return 'h' + name;
    }
    return "";
  }
};

// Maps population-count to the CUDA __popc/__popcll intrinsics for
// unsigned types; every other dtype yields "" (no mapping). The `name`
// parameter is accepted for signature uniformity but unused.
struct CUDAPopcount {
  std::string operator()(DataType t, std::string name) const {
    if (!t.is_uint()) {
      return "";
    }
    const int bits = t.bits();
    if (bits == 32) {
      return "__popc";
    }
    if (bits == 64) {
      return "__popcll";
    }
    return "";
  }
};

// Translates a tvm_warp_shuffle* builtin op to the corresponding
// registered CUDA __shfl_*_sync op. Any other op trips the ICHECK.
// `t` is accepted for signature uniformity with the math functors.
struct CUDAWarpIntrinsic {
  const Op operator()(DataType t, const Op &orig_op) const {
    if (orig_op.same_as(builtin::tvm_warp_shuffle())) {
      return Op::Get("tir.cuda.__shfl_sync");
    }
    if (orig_op.same_as(builtin::tvm_warp_shuffle_up())) {
      return Op::Get("tir.cuda.__shfl_up_sync");
    }
    // Only three shuffle builtins exist; the remaining one must be "down".
    ICHECK(orig_op.same_as(builtin::tvm_warp_shuffle_down()));
    return Op::Get("tir.cuda.__shfl_down_sync");
  }
};

// Lowers the warp-activemask builtin to the registered tir.cuda.__activemask
// op, preserving dtype and arguments.
static PrimExpr DispatchCUDAWarpActiveMask(const PrimExpr &e) {
  const CallNode *call = e.as<CallNode>();
  // Guard the cast before dereferencing, consistent with
  // DispatchCUDAShuffle below (previously a null deref on non-call input).
  ICHECK(call != nullptr);
  return Call(call->dtype, Op::Get("tir.cuda.__activemask"), call->args);
}

// Lowers a tvm_warp_shuffle* call to its CUDA __shfl_*_sync form. The
// functor T (e.g. CUDAWarpIntrinsic) maps the original op to the target op.
template <typename T> static PrimExpr DispatchCUDAShuffle(const PrimExpr &e) {
  const CallNode *call = e.as<CallNode>();
  ICHECK(call != nullptr);
  ICHECK_EQ(call->args.size(), 5);  // mask, value, warp_id, width, warp_size
  // Drop the trailing warp_size argument: the CUDA sync shuffle
  // intrinsics take only (mask, value, warp_id, width).
  Array<PrimExpr> cuda_args{
      {call->args[0], call->args[1], call->args[2], call->args[3]}};
  return Call(call->dtype, T()(call->dtype, Downcast<Op>(call->op)), cuda_args);
}

// Register the CUDA lowering for tir.rsqrt: dispatch to a pure extern
// call whose target name is chosen by CUDAMath per dtype
// (rsqrt / rsqrtf / hrsqrt).
TVM_REGISTER_OP("tir.rsqrt")
    .set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic",
                               DispatchPureExtern<CUDAMath>);

} // namespace intrin
} // namespace codegen
} // namespace tvm
5 changes: 5 additions & 0 deletions src/tl_templates/cuda/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,11 @@ TL_PATCH TL_DEVICE half_t __habs(const half_t x) {
return half_t(__habs(x.to_half()));
}

// rsqrt wrapper so the `hrsqrt` name emitted by the CUDA intrinsic rules
// also works on half_t. NOTE(review): this forwards to CUDA's
// hrsqrt(__half) intrinsic from cuda_fp16.h; that overload is an exact
// match for x.to_half(), so this wrapper should not be re-selected via an
// implicit half -> half_t conversion — confirm cuda_fp16.h is included
// ahead of this header.
TL_PATCH TL_DEVICE half_t hrsqrt(const half_t x) {
return half_t(hrsqrt(x.to_half()));
}

Comment on lines +58 to +62
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

⚠️ Potential issue

Fix infinite recursion in hrsqrt(half_t) and use device rsqrtf

hrsqrt(x.to_half()) can resolve back to this same overload via implicit conversion (half → half_t), causing infinite recursion at runtime. Also, relying on a non-existent hrsqrt(half) intrinsic is risky. Convert to float, use CUDA’s rsqrtf, then convert back.

-// hrsqrt function for half_t
-TL_PATCH TL_DEVICE half_t hrsqrt(const half_t x) {
-  return half_t(hrsqrt(x.to_half()));
-}
+// hrsqrt function for half_t
+TL_PATCH TL_DEVICE half_t hrsqrt(const half_t x) {
+  float xf = static_cast<float>(x);
+  float rf = rsqrtf(xf);
+  return half_t(rf);
+}
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
// hrsqrt function for half_t
TL_PATCH TL_DEVICE half_t hrsqrt(const half_t x) {
return half_t(hrsqrt(x.to_half()));
}
// hrsqrt function for half_t
TL_PATCH TL_DEVICE half_t hrsqrt(const half_t x) {
float xf = static_cast<float>(x);
float rf = rsqrtf(xf);
return half_t(rf);
}
🤖 Prompt for AI Agents
In src/tl_templates/cuda/common.h around lines 58 to 62, the hrsqrt(const half_t
x) overload currently calls hrsqrt(x.to_half()) which can re-resolve to the same
overload and recurse infinitely; instead convert the half_t to float, call
CUDA’s device rsqrtf on the float, and convert the result back to half_t for
return. Update the function to obtain a float from x (e.g., x.to_float() or
equivalent), compute rsqrtf(f), then construct and return a half_t from that
float so the intrinsic rsqrtf is used and recursion is avoided.

💡 Verification agent

🧩 Analysis chain

Add bfloat16_t overload (rsqrt) to match lowering for BF16

CUDAMath maps BF16 rsqrt to a symbol prefixed with h as well; without an overload, BF16 tir.rsqrt may fail to link. Provide a BF16 wrapper mirroring the FP16 path.

Proposed addition (outside the shown hunk):

TL_PATCH TL_DEVICE bfloat16_t hrsqrt(const bfloat16_t x) {
  float xf = static_cast<float>(x);
  float rf = rsqrtf(xf);
  return bfloat16_t(rf);
}

If you prefer distinct naming (e.g., brsqrt), update the CUDA lowering accordingly instead of reusing hrsqrt.


🏁 Script executed:

#!/bin/bash
# Search for existing bfloat16_t overloads or mentions in cuda common.h
rg -n "bfloat16_t" -C3 src/tl_templates/cuda/common.h
# Search specifically for TL_DEVICE functions taking bfloat16_t
rg -nP --type=cpp "TL_DEVICE.*bfloat16_t" -C3 src/tl_templates/cuda/common.h

Length of output: 2301


Add bfloat16_t hrsqrt overload
In src/tl_templates/cuda/common.h there’s a half-precision hrsqrt but no matching BF16 wrapper—calling tir.rsqrt on bfloat16_t will fail to link. Insert after the existing half_t overload:

TL_PATCH TL_DEVICE bfloat16_t hrsqrt(const bfloat16_t x) {
  float xf = static_cast<float>(x);
  float rf = rsqrtf(xf);
  return bfloat16_t(rf);
}
🤖 Prompt for AI Agents
In src/tl_templates/cuda/common.h around lines 58 to 62, there is a half_t
hrsqrt overload but no bfloat16_t overload, so calls using bfloat16_t fail to
link; add a TL_PATCH TL_DEVICE overload for bfloat16_t named hrsqrt that
converts the bfloat16 to float, calls rsqrtf on the float, and converts the
result back to bfloat16_t, matching the style and placement of the existing
half_t wrapper.

// Pack two half values.
TL_DEVICE unsigned __pack_half2(const half x, const half y) {
unsigned v0 = *((unsigned short *)&x);
Expand Down