
Commit b66f9aa

[Math] Dispatch T.rsqrt(x) into cuda intrin instead of 1 / T.sqrt(x) (#781)
* Fix type hint for target_host parameter in compile function to allow None value
* Refactor target handling in compile function to utilize determine_target for improved clarity and consistency
* Update PrintConst function in codegen_cuda.cc to use hexfloat format for bfloat16 and float8/float4 types, while adding scientific notation comments for clarity. This change enhances the representation of floating-point constants in the generated code.
* Refactor PrintType function in codegen_cuda.cc to remove unnecessary failure conditions for floating-point types with lane counts greater than 4. This change simplifies the logic and improves code clarity.
* Enhance benchmark_matmul.py to conditionally print Reference TFlops only if ref_latency is not None. Update param.py to ensure target is converted to string for consistency. Refactor tuner.py to utilize determine_target for improved clarity in target handling.
* Remove automatic commit and push step from AMD and NVIDIA CI workflows to streamline the process and avoid unnecessary commits.
* Add intrin_rule source files to CMakeLists.txt and implement hrsqrt function for half_t in common.h
* lint fix
* remove cmake dep in pyproject as it may lead to different cmake paths in different stages
* lint fix
* Add cmake dependency to pyproject.toml and improve build logging in setup.py
1 parent 021e44e commit b66f9aa
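
The headline change means that on CUDA targets a TIR T.rsqrt(x) call now lowers to the dedicated reciprocal-square-root intrinsics (rsqrtf for float32, hrsqrt for float16) instead of being expanded to 1 / T.sqrt(x). The following is a minimal standalone CUDA sketch, not taken from this commit (kernel and variable names are illustrative), contrasting the two lowerings for float32:

// Illustrative comparison of the two lowerings of rsqrt on CUDA (float32).
// Compile with: nvcc rsqrt_compare.cu -o rsqrt_compare
#include <cstdio>
#include <cuda_runtime.h>

__global__ void rsqrt_compare(const float *x, float *via_div, float *via_intrin, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    via_div[i] = 1.0f / sqrtf(x[i]);  // previous lowering: reciprocal of a square root
    via_intrin[i] = rsqrtf(x[i]);     // new lowering: single rsqrtf intrinsic
  }
}

int main() {
  const int n = 4;
  float hx[n] = {0.25f, 1.0f, 4.0f, 16.0f};
  float *x, *a, *b;
  cudaMalloc(&x, n * sizeof(float));
  cudaMalloc(&a, n * sizeof(float));
  cudaMalloc(&b, n * sizeof(float));
  cudaMemcpy(x, hx, n * sizeof(float), cudaMemcpyHostToDevice);
  rsqrt_compare<<<1, n>>>(x, a, b, n);
  float ha[n], hb[n];
  cudaMemcpy(ha, a, n * sizeof(float), cudaMemcpyDeviceToHost);
  cudaMemcpy(hb, b, n * sizeof(float), cudaMemcpyDeviceToHost);
  for (int i = 0; i < n; ++i)
    printf("x=%g  1/sqrtf=%g  rsqrtf=%g\n", hx[i], ha[i], hb[i]);
  cudaFree(x);
  cudaFree(a);
  cudaFree(b);
  return 0;
}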

File tree

4 files changed: +177, -34 lines

CMakeLists.txt
setup.py
src/target/intrin_rule_cuda.cc
src/tl_templates/cuda/common.h

CMakeLists.txt

Lines changed: 2 additions & 0 deletions
@@ -124,6 +124,8 @@ tilelang_file_glob(GLOB TILE_LANG_SRCS
     src/target/rt_mod_cpp.cc
     # webgpu doesn't have system dependency
     src/target/codegen_webgpu.cc
+    # intrin_rule doesn't have system dependency
+    src/target/intrin_rule*.cc
 )

 # Include CUDA source files if CUDA is enabled

setup.py

Lines changed: 32 additions & 34 deletions
@@ -203,6 +203,7 @@ def get_cplus_compiler():
     return None


+@functools.lru_cache(maxsize=None)
 def get_cython_compiler() -> Optional[str]:
     """Return the path to the Cython compiler.

@@ -238,6 +239,17 @@ def get_cython_compiler() -> Optional[str]:
     return None


+@functools.lru_cache(maxsize=None)
+def get_cmake_path() -> str:
+    """Return the path to the CMake compiler.
+    """
+    # found which cmake is used
+    cmake_path = shutil.which("cmake")
+    if not os.path.exists(cmake_path):
+        raise Exception("CMake is not installed, please install it first.")
+    return cmake_path
+
+
 def get_system_info():
     system = platform.system().lower()
     if system == "linux":

@@ -338,33 +350,6 @@ def is_git_repo():
         raise RuntimeError("Failed to update submodules") from error


-def build_csrc(llvm_config_path):
-    """Configures and builds TVM."""
-
-    if not os.path.exists("build"):
-        os.makedirs("build")
-    os.chdir("build")
-    # Copy the config.cmake as a baseline
-    if not os.path.exists("config.cmake"):
-        shutil.copy("../3rdparty/tvm/cmake/config.cmake", "config.cmake")
-    # Set LLVM path and enable CUDA or ROCM in config.cmake
-    with open("config.cmake", "a") as config_file:
-        config_file.write(f"set(USE_LLVM {llvm_config_path})\n")
-        if USE_ROCM:
-            config_file.write(f"set(USE_ROCM {ROCM_HOME})\n")
-            config_file.write("set(USE_CUDA OFF)\n")
-        else:
-            config_file.write(f"set(USE_CUDA {CUDA_HOME})\n")
-            config_file.write("set(USE_ROCM OFF)\n")
-    # Run CMake and make
-    try:
-        subprocess.check_call(["cmake", ".."])
-        num_jobs = max(1, int(multiprocessing.cpu_count() * 0.75))
-        subprocess.check_call(["make", f"-j{num_jobs}"])
-    except subprocess.CalledProcessError as error:
-        raise RuntimeError("Failed to build TileLang C Source") from error
-
-
 def setup_llvm_for_tvm():
     """Downloads and extracts LLVM, then configures TVM to use it."""
     # Assume the download_and_extract_llvm function and its dependencies are defined elsewhere in this script

@@ -627,7 +612,10 @@ class TilelangExtensionBuild(build_ext):
     def run(self):
         # Check if CMake is installed and accessible by attempting to run 'cmake --version'.
         try:
-            subprocess.check_output(["cmake", "--version"])
+            cmake_path = get_cmake_path()
+            if not cmake_path:
+                raise Exception("CMake is not installed, please install it first.")
+            subprocess.check_output([cmake_path, "--version"])
         except OSError as error:
             # If CMake is not found, raise an error.
             raise RuntimeError(

@@ -830,15 +818,25 @@ def build_cmake(self, ext):
         else:
             print(f"[Config] No changes: {dst_config}")

+        cmake_path = get_cmake_path()
         # Run CMake to configure the project with the given arguments.
-        if not os.path.exists(build_temp + "/build.ninja"):
-            subprocess.check_call(["cmake", ext.sourcedir] + cmake_args, cwd=build_temp)
+        if not os.path.exists(os.path.join(build_temp, "build.ninja")):
+            logger.info(
+                f"[CMake] Generating build.ninja: {cmake_path} {ext.sourcedir} {' '.join(cmake_args)}"
+            )
+            subprocess.check_call([cmake_path, ext.sourcedir] + cmake_args, cwd=build_temp)
+        else:
+            logger.info(f"[CMake] build.ninja already exists in {build_temp}")

-        # Build the project in "Release" mode with all available CPU cores ("-j").
         num_jobs = max(1, int(multiprocessing.cpu_count() * 0.75))
-        subprocess.check_call(["cmake", "--build", ".", "--config", "Release", "-j",
-                               str(num_jobs)],
-                              cwd=build_temp)
+        logger.info(
+            f"[Build] Using {num_jobs} jobs | cmake: {cmake_path} (exists: {os.path.exists(cmake_path)}) | build dir: {build_temp}"
+        )
+
+        subprocess.check_call(
+            [cmake_path, "--build", ".", "--config", "Release", "-j",
+             str(num_jobs)],
+            cwd=build_temp)


 setup(

src/target/intrin_rule_cuda.cc

Lines changed: 138 additions & 0 deletions
@@ -0,0 +1,138 @@ (new file)

/*!
 * \file intrin_rule_cuda.cc
 * \brief CUDA intrinsic rules.
 */
#include <tvm/tir/builtin.h>
#include <tvm/tir/op_attr_types.h>

#include "target/intrin_rule.h"

namespace tvm {
namespace codegen {
namespace intrin {
// Add float suffix to the intrinsics, CUDA fast math.
using tir::FLowerIntrinsic;

struct CUDAMath {
  std::string operator()(DataType t, std::string name) const {
    if (t.is_float()) {
      switch (t.bits()) {
      case 64:
        return name;
      case 32:
        return name + 'f';
      case 16: {
        if (name == "fabs") {
          return "__habs";
        } else if (name == "round") {
          return "hrint";
        } else {
          return "h" + name;
        }
      }
      default:
        return "";
      }
    } else if (t.is_bfloat16()) {
      if (name == "fabs") {
        return "__habs";
      } else if (name == "round") {
        return "hrint";
      } else {
        return "h" + name;
      }
    } else if (t.is_int() || t.is_uint()) {
      switch (t.bits()) {
      case 32:
        return "__" + name;
      case 64:
        return "__" + name + "ll";
      default:
        return "";
      }
    }
    return "";
  }
};

struct CUDAFastMath : public CUDAMath {
  std::string operator()(DataType t, std::string name) const {
    if (t.is_float() && t.bits() == 32) {
      return "__" + name + 'f';
    } else {
      return CUDAMath::operator()(t, name);
    }
    return "";
  }
};

struct CUDAFastMathTan : public CUDAMath {
  std::string operator()(DataType t, std::string name) const {
    if (t.is_float()) {
      switch (t.bits()) {
      case 64:
        return name;
      // `__tanf` seems to produce some values too deviant from numpy tan
      // version. So, let's use just `tanf` instead.
      case 32:
        return name + 'f';
      case 16:
        return 'h' + name;
      default:
        return "";
      }
    }
    return "";
  }
};

struct CUDAPopcount {
  std::string operator()(DataType t, std::string name) const {
    if (t.is_uint()) {
      switch (t.bits()) {
      case 32:
        return "__popc";
      case 64:
        return "__popcll";
      default:
        return "";
      }
    }
    return "";
  }
};

struct CUDAWarpIntrinsic {
  const Op operator()(DataType t, const Op &orig_op) const {
    if (orig_op.same_as(builtin::tvm_warp_shuffle())) {
      return Op::Get("tir.cuda.__shfl_sync");
    } else if (orig_op.same_as(builtin::tvm_warp_shuffle_up())) {
      return Op::Get("tir.cuda.__shfl_up_sync");
    } else {
      ICHECK(orig_op.same_as(builtin::tvm_warp_shuffle_down()));
      return Op::Get("tir.cuda.__shfl_down_sync");
    }
  }
};

static PrimExpr DispatchCUDAWarpActiveMask(const PrimExpr &e) {
  const CallNode *call = e.as<CallNode>();
  return Call(call->dtype, Op::Get("tir.cuda.__activemask"), call->args);
}

template <typename T> static PrimExpr DispatchCUDAShuffle(const PrimExpr &e) {
  const CallNode *call = e.as<CallNode>();
  ICHECK(call != nullptr);
  ICHECK_EQ(call->args.size(), 5); // mask, value, warp_id, width, warp_size
  Array<PrimExpr> cuda_args{
      {call->args[0], call->args[1], call->args[2], call->args[3]}};
  return Call(call->dtype, T()(call->dtype, Downcast<Op>(call->op)), cuda_args);
}

TVM_REGISTER_OP("tir.rsqrt")
    .set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic",
                               DispatchPureExtern<CUDAMath>);

} // namespace intrin
} // namespace codegen
} // namespace tvm
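
For the name "rsqrt", the CUDAMath functor above resolves to rsqrt for float64, rsqrtf for float32, and hrsqrt for float16/bfloat16, which is what the tir.rsqrt registration at the end of the file emits. The same functor could in principle cover other math ops; the snippet below is a hypothetical illustration only (it is not part of this commit) of what additional registrations would look like inside the same namespaces:

// Hypothetical sketch (not in this commit): further TIR math ops routed
// through the same CUDAMath name-mangling functor. float64 keeps the libm
// name, float32 gains an 'f' suffix (e.g. expf), and fp16/bf16 gain an 'h'
// prefix (e.g. hexp), exactly as CUDAMath encodes above.
TVM_REGISTER_OP("tir.exp").set_attr<FLowerIntrinsic>(
    "cuda.FLowerIntrinsic", DispatchPureExtern<CUDAMath>);

TVM_REGISTER_OP("tir.sqrt").set_attr<FLowerIntrinsic>(
    "cuda.FLowerIntrinsic", DispatchPureExtern<CUDAMath>);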

src/tl_templates/cuda/common.h

Lines changed: 5 additions & 0 deletions
@@ -55,6 +55,11 @@ TL_PATCH TL_DEVICE half_t __habs(const half_t x) {
   return half_t(__habs(x.to_half()));
 }

+// hrsqrt function for half_t
+TL_PATCH TL_DEVICE half_t hrsqrt(const half_t x) {
+  return half_t(hrsqrt(x.to_half()));
+}
+
 // Pack two half values.
 TL_DEVICE unsigned __pack_half2(const half x, const half y) {
   unsigned v0 = *((unsigned short *)&x);
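
The new overload forwards half_t through to_half() to the hrsqrt intrinsic declared in cuda_fp16.h, so overload resolution picks the builtin for the unwrapped __half value. Below is a small standalone CUDA sketch (illustrative only, not from the repo) exercising that underlying intrinsic; it assumes a device with native fp16 math (compute capability 5.3 or newer):

// Illustrative use of the cuda_fp16.h hrsqrt intrinsic that the new half_t
// wrapper forwards to. Requires device fp16 support (sm_53+).
#include <cstdio>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

__global__ void half_rsqrt(const __half *x, __half *y, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) y[i] = hrsqrt(x[i]);  // half-precision reciprocal square root
}

int main() {
  const int n = 3;
  __half hx[n] = {__float2half(0.25f), __float2half(1.0f), __float2half(4.0f)};
  __half *x, *y;
  cudaMalloc(&x, n * sizeof(__half));
  cudaMalloc(&y, n * sizeof(__half));
  cudaMemcpy(x, hx, n * sizeof(__half), cudaMemcpyHostToDevice);
  half_rsqrt<<<1, n>>>(x, y, n);
  __half hy[n];
  cudaMemcpy(hy, y, n * sizeof(__half), cudaMemcpyDeviceToHost);
  for (int i = 0; i < n; ++i)
    printf("hrsqrt(%g) = %g\n", __half2float(hx[i]), __half2float(hy[i]));
  cudaFree(x);
  cudaFree(y);
  return 0;
}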
