diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index b12f0592d..d7abaeb0f 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -370,8 +370,27 @@ jobs: pytest --verbose --color=yes --durations=0 --showlocals --cache-clear ) "${PYTEST[@]}" --maxfail=3 --numprocesses=4 \ + --ignore=./python/jit/test_tilelang_jit_cutedsl.py \ ./python + # CuTeDSL JIT tests require GEMM v1 (must be set before importing tilelang). + # Run them in a dedicated step to avoid changing the default GEMM selection + # (and to keep the rest of the CUDA tests on GEMM v2). + - name: Run CuTeDSL JIT tests (GEMM v1) with Python ${{ matrix.python-version }} (${{ matrix.runner.toolkit }}) + id: cutedsl-tests + if: contains(matrix.runner.toolkit, 'CUDA') + env: + TILELANG_USE_GEMM_V1: "1" + run: | + cd testing + PYTEST=( + uv run --no-project -m -- + pytest --verbose --color=yes --durations=0 --showlocals --cache-clear + ) + # Avoid xdist contention on a single GPU by running this file in one worker. + "${PYTEST[@]}" --maxfail=3 --numprocesses=1 \ + ./python/jit/test_tilelang_jit_cutedsl.py + # AMD ROCm tests - name: Run ROCm tests with Python ${{ matrix.python-version }} (${{ matrix.runner.toolkit }}) id: rocm-tests diff --git a/CMakeLists.txt b/CMakeLists.txt index 109f84518..7af7f854f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -215,7 +215,11 @@ elseif(USE_CUDA) src/runtime/runtime.cc src/target/ptx.cc src/target/codegen_cuda.cc + src/target/codegen_py.cc + src/target/codegen_utils.cc + src/target/codegen_cutedsl.cc src/target/rt_mod_cuda.cc + src/target/rt_mod_cutedsl.cc ) list(APPEND TILE_LANG_SRCS ${TILE_LANG_CUDA_SRCS}) diff --git a/maint/scripts/run_local_ci_test.sh b/maint/scripts/run_local_ci_test.sh index f8fe54384..ef560437a 100755 --- a/maint/scripts/run_local_ci_test.sh +++ b/maint/scripts/run_local_ci_test.sh @@ -14,7 +14,13 @@ cd examples python -m pytest -n 4 . --verbose --color=yes --durations=0 --showlocals --cache-clear cd .. -# Run pytest in parallel (4 workers) for all tests in the testing/python directory +# Run pytest in parallel (4 workers) for all tests in the testing/python directory. +# IMPORTANT: CuTeDSL backend currently requires GEMM v1 (TILELANG_USE_GEMM_V1=1). +# Do NOT export it globally here, or you'll silently change the default GEMM selection +# for unrelated tests. Run the CuTeDSL JIT tests in a separate pytest invocation. cd testing/python -python -m pytest -n 4 . --verbose --color=yes --durations=0 --showlocals --cache-clear +python -m pytest -n 4 . --ignore=jit/test_tilelang_jit_cutedsl.py --verbose --color=yes --durations=0 --showlocals --cache-clear + +# CuTeDSL JIT tests (isolate env + avoid xdist contention on a single GPU) +TILELANG_USE_GEMM_V1=1 python -m pytest -n 1 jit/test_tilelang_jit_cutedsl.py --verbose --color=yes --durations=0 --showlocals --cache-clear cd .. diff --git a/requirements-test-cuda.txt b/requirements-test-cuda.txt index 122320238..52a403aa9 100644 --- a/requirements-test-cuda.txt +++ b/requirements-test-cuda.txt @@ -7,3 +7,5 @@ # CUDA specific requirements flash-attn==2.5.8 cuda-python==12.9.4 +# CuTeDSL (CUTLASS Python DSL with CuTe support) +nvidia-cutlass-dsl>=4.3.1 diff --git a/src/target/codegen_cutedsl.cc b/src/target/codegen_cutedsl.cc new file mode 100644 index 000000000..8279710de --- /dev/null +++ b/src/target/codegen_cutedsl.cc @@ -0,0 +1,1355 @@ +/*! 
+ * \file target/codegen_cutedsl.cc + */ + +#include "codegen_cutedsl.h" +#include "codegen_utils.h" +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "../op/builtin.h" +#include "arith/pattern_match.h" + +namespace tvm { +namespace codegen { +namespace { + +// The threshold of the loop extent to use cutlass.range_constexpr +// Higher values would lead to DSLOptimizationWarning: +// This static loop has 128 iterations, which may be very slow to compile, +// consider using `cutlass.range(..., unroll_full=True)` instead. +const int64_t LOOP_UNROLL_THRESHOLD = 64; + +void ReplaceAll(std::string &str, const std::string &from, + const std::string &to) { + ICHECK(!from.empty()) << "ReplaceAll(): `from` must be non-empty"; + auto pos = str.find(from); + while (pos != std::string::npos) { + str.replace(pos, from.size(), to); + pos = str.find(from, pos + to.size()); + } +} + +} // namespace + +CodeGenTileLangCuTeDSL::CodeGenTileLangCuTeDSL() { + // Read fastmath configuration from current PassContext + auto pass_ctx = tvm::transform::PassContext::Current(); + + // Read tl.enable_fast_math config, default to false + enable_fastmath_ = + pass_ctx->GetConfig(tl::kEnableFastMath, Bool(false)).value(); +} + +std::string CodeGenTileLangCuTeDSL::CanonicalizeFastmathFunctionName_( + const std::string &func_name) const { + static const std::unordered_map kFastMathMap = { + {"divf", "tl.divf"}, {"exp", "tl.exp"}, {"expf", "tl.exp"}, + {"exp2", "tl.exp2"}, {"exp2f", "tl.exp2"}, {"log", "tl.log"}, + {"logf", "tl.log"}, {"log2", "tl.log2"}, {"log2f", "tl.log2"}, + {"log10", "tl.log10"}, {"tan", "tl.tan"}, {"cos", "tl.cos"}, + {"sin", "tl.sin"}, {"sqrt", "tl.sqrt"}, {"sqrtf", "tl.sqrt"}, + }; + + auto it = kFastMathMap.find(func_name); + if (it != kFastMathMap.end()) { + return it->second; + } + return ""; +} + +void CodeGenTileLangCuTeDSL::PrintFuncDecorator_( + std::ostream &os) { // NOLINT(*) + os << "@cute.kernel\n"; +} + +void CodeGenTileLangCuTeDSL::PreFunctionBody_(const PrimFunc &f) { + PrintIndent(); + stream << "threadIdx = tl.ThreadIdx()" << "\n"; + PrintIndent(); + stream << "blockIdx = tl.BlockIdx()" << "\n"; +} + +namespace { +std::string DTypeToString(DataType t) { + ICHECK(t.is_scalar()) << "unsupported type " << t; + + if (t.is_void()) { + return "void"; + } + if (t == tl::cuTensorMapType()) { + return "CUtensorMap"; + } + + int bits = t.bits(); + std::string elem_type; + if (t.is_float()) { + if (bits == 16 || bits == 32 || bits == 64) { + elem_type = "Float" + std::to_string(bits); + } + } else if (t.is_bfloat16()) { + elem_type = "BFloat16"; + } else if (t.is_float8()) { + if (t.is_float8_e3m4()) { + // unsupported + } else if (t.is_float8_e4m3()) { + elem_type = + "Float8E4M3FN"; // Only Float8E4M3FN is supported at the moment + } else if (t.is_float8_e4m3b11fnuz()) { + // unsupported + } else if (t.is_float8_e4m3fn()) { + elem_type = "Float8E4M3FN"; + } else if (t.is_float8_e4m3fnuz()) { + // unsupported + } else if (t.is_float8_e5m2()) { + elem_type = "Float8E5M2"; + } else if (t.is_float8_e5m2fnuz()) { + // unsupported + } else if (t.is_float8_e8m0fnu()) { + elem_type = "Float8E8M0FNU"; + } + } else if (t.is_float6()) { + if (t.is_float6_e3m2fn()) { + elem_type = "Float6E3M2FN"; + } else if (t.is_float6_e2m3fn()) { + elem_type = "Float6E2M3FN"; + } + } else if (t.is_float4()) { + if (t.is_float4_e2m1fn()) { + elem_type = "Float4E2M1FN"; + } + } else if (t.is_bool()) { + elem_type = "Boolean"; + } else if (t.is_uint()) { + if (bits == 8 
|| bits == 16 || bits == 32 || bits == 64 || bits == 128) { + elem_type = "Uint" + std::to_string(bits); + } + } else if (t.is_int()) { + if (bits == 4 || bits == 8 || bits == 16 || bits == 32 || bits == 64 || + bits == 128) { + elem_type = "Int" + std::to_string(bits); + } + } + + if (elem_type.empty()) { + LOG(FATAL) << "Cannot convert type " << t << " to CuTeDSL type!"; + } + + return "cutlass." + elem_type; +} +} // namespace + +void CodeGenTileLangCuTeDSL::PrintType(DataType t, + std::ostream &os) { // NOLINT(*) + CHECK(t.is_scalar()) << "Should not print a non-scalar type in CuTeDSL: " + << t; + os << DTypeToString(t); +} + +void CodeGenTileLangCuTeDSL::VisitExpr_(const BroadcastNode *op, + std::ostream &os) { // NOLINT(*) + os << "tl.make_filled_tensor((" << PrintExpr_(op->lanes) << ",), " + << PrintExpr_(op->value) << ").load()"; +} + +void CodeGenTileLangCuTeDSL::VisitExpr_(const FloatImmNode *op, + std::ostream &os) { // NOLINT(*) + switch (op->dtype.bits()) { + case 64: + case 32: + case 16: + case 8: + case 4: { + std::ostringstream temp; + if (std::isinf(op->value)) { + // For CuTeDSL, use Python's float('inf') instead of CUDA macros + PrintType(op->dtype, temp); + temp << "("; + if (op->value < 0) { + temp << "float('-inf')"; + } else { + temp << "float('inf')"; + } + temp << ")"; + } else if (std::isnan(op->value)) { + // For CuTeDSL, use Python's float('nan') + PrintType(op->dtype, temp); + temp << "(float('nan'))"; + } else { + // For CuTeDSL, use Python's float.fromhex() with hexfloat for full + // precision + PrintType(op->dtype, temp); + temp << "(float.fromhex('" << std::hexfloat << op->value << "'))"; + } + MarkConst(temp.str()); + os << temp.str(); + break; + } + default: + LOG(FATAL) << "Bad bit-width for float: " << op->dtype << "\n"; + } +} + +void CodeGenTileLangCuTeDSL::VisitExpr_(const CastNode *op, + std::ostream &os) { // NOLINT(*) + DataType from_ty = op->value.dtype(); + DataType target_ty = op->dtype; + ICHECK_EQ(target_ty.lanes(), from_ty.lanes()); + + if (from_ty.is_scalar()) + return CodeGenTileLangPY::VisitExpr_(op, os); + + // Emit this as vectorized unary ops. + std::string sret = name_supply_->FreshName("_"); + PrintIndent(); + stream << sret << " = tl.make_rmem_tensor((" << target_ty.lanes() << ",), "; + PrintType(target_ty.element_of(), stream); + stream << ")\n"; + + std::string src = SSAGetID(PrintExpr_(op->value), from_ty); + + PrintIndent(); + stream << sret << ".store(" << src << ".to("; + PrintType(target_ty.element_of(), stream); + stream << "))\n"; + os << sret << ".load()"; + return; +} + +void CodeGenTileLangCuTeDSL::VisitExpr_(const DivNode *op, + std::ostream &os) { // NOLINT(*) + if (op->dtype.is_int() || op->dtype.is_uint()) { + PrintBinaryExpr_("//", op->dtype, op->a, op->b, os); + } else { + if (enable_fastmath_) { + os << "tl.divf(" << PrintExpr_(op->a) << ", " << PrintExpr_(op->b) + << ", fastmath=True)"; + } else { + PrintBinaryExpr_("tl.divf", op->dtype, op->a, op->b, os); + } + } +} +void CodeGenTileLangCuTeDSL::VisitExpr_(const MinNode *op, + std::ostream &os) { // NOLINT(*) + PrintBinaryExpr_("tl.min", op->dtype, op->a, op->b, os); +} +void CodeGenTileLangCuTeDSL::VisitExpr_(const MaxNode *op, + std::ostream &os) { // NOLINT(*) + PrintBinaryExpr_("tl.max", op->dtype, op->a, op->b, os); +} + +/** + * @brief Emit CuTeDSL-specific code for a call expression. 
+ * + * This visitor handles CallNode intrinsics and builtins that require emitting + * CuTeDSL-specific code (inline PTX/ASM sequences, TensorLanguage runtime + * calls, WMMA/TMA helpers, barriers, cp.async primitives, index-map based + * stores, reinterpret/packing helpers, and various mma/ldmatrix patterns). The + * function writes the generated code to the provided output stream and falls + * back to the Python codegen for unrecognized calls. + * + * The method recognizes and emits code for (non-exhaustive): cp.async and its + * commit/wait variants, tma_load/store and im2col variants, ptX + * ldmatrix/stmatrix helpers, mbarrier APIs, cooperative grid sync, WMMA/legacy + * MMA intrinsics (fill/load/store/mma/bmma/ptx_mma/ptx_mma_sp), low-level PTX + * asm helpers (ldg32, cp_async bulk/init/arrive/wait barriers), reinterpret + * paths for special small-float encodings (e.g., float4 e2m1fn), tl::tl_gemm + * and related external calls, and other TL runtime calls. + * + * Side effects: + * - Emits to `os` and the internal codegen output stream. + * - May set internal feature flags (e.g., need_cooperative_groups_). + * - May open/close SSA scopes and mutate internal variable mappings. + * - May call LOG(FATAL) / CHECK / ICHECK on invalid or unsupported argument + * patterns. + * + * @param op The call node to generate code for; the function inspects op->op + * and op->args to determine the appropriate emission. + * @param os Output stream to receive expression-level output when the caller + * expects an expression result (some paths write directly to the + * member stream instead). + */ +void CodeGenTileLangCuTeDSL::VisitExpr_(const CallNode *op, + std::ostream &os) { // NOLINT(*) + auto print_extern_call_stmt = [&](std::string name, size_t start = 0, + size_t end = 0) { + // Cache context into a private ss, otherwise the let node may generate + // within the function call arguments. 
+ std::ostringstream ss; + for (size_t i = start; i < op->args.size() - end; i++) { + if (i > start) + ss << ", "; + ss << PrintExpr_(op->args[i]); + } + + PrintIndent(); + stream << name << "("; + stream << ss.str(); + stream << ")\n"; + }; + + auto print_mbarrier_obj = [&](PrimExpr barrier_id) { + std::ostringstream ss; + if (barrier_id.as()) { + // incase the barrier_id is an integer, we need to print the barrier_id as + // an integer + ss << "(" << mbarrier_name_ << "+" << barrier_id << ")"; + } else { + // otherwise may be a T.get_mbarrier() call or BufferLoad Node + // we need to print the barrier_id as a string + ss << PrintExpr_(barrier_id); + } + return ss.str(); + }; + + if (op->op.same_as(builtin::ptx_cp_async())) { + std::string dst = PrintExpr_(op->args[0]); + std::string dst_offset = PrintExpr_(op->args[1]); + std::string src = PrintExpr_(op->args[2]); + std::string src_offset = PrintExpr_(op->args[3]); + std::string size = PrintExpr_(op->args[4]); + // use size of argument list to indicate whether or not to use predicated + // cp.async + if (op->args.size() == 5) { + PrintIndent(); + stream << "tl.cp_async_gs(" << size << ", " << dst << ", " << dst_offset + << ", " << src << ", " << src_offset << ")\n"; + } else { + std::string condition = PrintExpr_(op->args[5]); + PrintIndent(); + stream << "tl.cp_async_gs_conditional(" << size << ", " << dst << ", " + << dst_offset << ", " << src << ", " << src_offset << ", " + << condition << ")\n"; + } + } else if (op->op.same_as(builtin::ptx_commit_group())) { + print_extern_call_stmt("tl.cp_async_commit"); + } else if (op->op.same_as(builtin::ptx_wait_group())) { + print_extern_call_stmt("tl.cp_async_wait"); + } else if (op->op.same_as(builtin::create_barriers())) { + PrintIndent(); + int barrier_count = Downcast(op->args[0])->value; + stream << mbarrier_name_ + << " = tl.alloc_smem(cutlass.Uint64, size_in_elems=" << barrier_count + << ")\n"; + } else if (op->op.same_as(tl::get_mbarrier())) { + ICHECK_EQ(op->args.size(), 1); + std::string barrier_id = PrintExpr_(op->args[0]); + os << "(" << mbarrier_name_ << "+" << barrier_id << ")"; + } else if (op->op.same_as(builtin::ptx_arrive_barrier())) { + if (op->args.size() == 1) { + PrintIndent(); + auto mbarrier_obj = print_mbarrier_obj(op->args[0]); + stream << "tl.mbarrier_arrive(" << mbarrier_obj << ")\n"; + } else if (op->args.size() == 3) { + PrintIndent(); + auto mbarrier_obj = print_mbarrier_obj(op->args[0]); + auto cta_id = PrintExpr_(op->args[1]); + auto pred = PrintExpr_(op->args[2]); + stream << "tl.mbarrier_arrive(" << mbarrier_obj << ", " << cta_id << ", " + << pred << ")\n"; + } else { + LOG(FATAL) << "Invalid parameter for tl::arrive_barrier " + << op->args.size(); + } + } else if (op->op.same_as(builtin::ptx_init_barrier_thread_count())) { + ICHECK_EQ(op->args.size(), 2); + PrintIndent(); + auto mbarrier_obj = print_mbarrier_obj(op->args[0]); + auto arrive_count = PrintExpr_(op->args[1]); + stream << "tl.mbarrier_init(" << mbarrier_obj << ", " << arrive_count + << ")\n"; + } else if (op->op.same_as(builtin::ptx_arrive_barrier_expect_tx())) { + if (op->args.size() == 2) { + PrintIndent(); + auto mbarrier_obj = print_mbarrier_obj(op->args[0]); + auto transaction_bytes = PrintExpr_(op->args[1]); + stream << "tl.arrive_and_expect_tx(" << mbarrier_obj << ", " + << transaction_bytes << ")\n"; + } else if (op->args.size() == 4) { + PrintIndent(); + auto mbarrier_obj = print_mbarrier_obj(op->args[0]); + auto transaction_bytes = PrintExpr_(op->args[1]); + auto cta_id = 
PrintExpr_(op->args[2]); + auto pred = PrintExpr_(op->args[3]); + stream << "tl.arrive_and_expect_tx(" << mbarrier_obj << ", " + << transaction_bytes << ", " << cta_id << ", " << pred << ")\n"; + } else { + LOG(FATAL) << "Invalid parameter for tl::arrive_barrier_expect_tx " + << op->args.size(); + } + } else if (op->op.same_as(builtin::ptx_cp_async_barrier())) { + print_extern_call_stmt("tl.mbarrier_cp_async_arrive"); + } else if (op->op.same_as(tl::ptx_fence_barrier_init())) { + print_extern_call_stmt("tl.fence_barrier_init"); + } else if (op->op.same_as(tl::ptx_cp_async_barrier_noinc())) { + print_extern_call_stmt("tl.mbarrier_cp_async_arrive_noinc"); + } else if (op->op.same_as(tl::mbarrier_expect_tx())) { + ICHECK_EQ(op->args.size(), 2); + PrintIndent(); + auto mbarrier_obj = print_mbarrier_obj(op->args[0]); + auto transaction_bytes = PrintExpr_(op->args[1]); + stream << "tl.mbarrier_expect_tx(" << mbarrier_obj << ", " + << transaction_bytes << ")\n"; + } else if (op->op.same_as(tl::mbarrier_wait_parity())) { + ICHECK_EQ(op->args.size(), 2); + PrintIndent(); + auto mbarrier_obj = print_mbarrier_obj(op->args[0]); + auto phase = PrintExpr_(op->args[1]); + stream << "tl.mbarrier_wait(" << mbarrier_obj << ", " << phase << ")\n"; + } else if (op->op.same_as(tl::ptx_init_tensor_memory())) { + LOG(FATAL) << "Currently unsupported op: " << op->op; + } else if (op->op.same_as(tl::ptx_deallocate_tensor_memory())) { + LOG(FATAL) << "Currently unsupported op: " << op->op; + } else if (op->op.same_as(tl::no_set_max_nreg())) { + // do nothing + } else if (op->op.same_as(tl::tma_load())) { + std::ostringstream ss; + ICHECK_GE(op->args.size(), 2); + auto pol = op->args[op->args.size() - 1].as(); + ICHECK(pol) << "Eviction policy must be IntImm"; + ICHECK_GE(pol->value, 0); + ICHECK_LT(static_cast(pol->value), eviction_policy_names_.size()); + auto eviction_policy = eviction_policy_names_[pol->value]; + // Simplify the code by using the default eviction policy + if (eviction_policy != "EVICT_NORMAL") { + LOG(FATAL) << "Eviction policy " << eviction_policy + << " is not supported currently"; + } else { + ss << "tl.tma_load("; + } + auto desc = op->args[0]; + ss << PrintExpr_(desc) << ", "; + ss << print_mbarrier_obj(op->args[1]) << ", "; + ss << PrintExpr_(op->args[2]) << ", ("; + for (size_t i = 3; i < op->args.size() - 1; i++) { + if (i > 3) + ss << ", "; + ss << PrintExpr_(op->args[i]); + } + ss << "))\n"; + PrintIndent(); + stream << ss.str(); + } else if (op->op.same_as(tl::tma_load_im2col())) { + LOG(FATAL) << "Currently unsupported op: " << op->op; + } else if (op->op.same_as(tl::tma_store())) { + std::stringstream ss; + // Check minimum argument count (desc, data, at least one coord, + // need_reduce, eviction) + ICHECK_GE(op->args.size(), 4) << "tma_store requires at least 4 arguments " + "(desc, data, coords..., need_reduce, " + "eviction_policy), got " + << op->args.size(); + + // Safely extract need_reduce flag + auto need_reduce_ptr = op->args[op->args.size() - 2].as(); + ICHECK(need_reduce_ptr) + << "tma_store need_reduce flag (args[-2]) must be IntImm, got " + << op->args[op->args.size() - 2]->GetTypeKey(); + auto need_reduce = need_reduce_ptr->value; + if (need_reduce) { + LOG(FATAL) << "Currently unsupported op: " << op->op; + } + + // Safely extract and validate eviction policy index + auto eviction_idx_ptr = op->args[op->args.size() - 1].as(); + ICHECK(eviction_idx_ptr) + << "tma_store eviction policy (args[-1]) must be IntImm, got " + << op->args[op->args.size() - 1]->GetTypeKey(); + 
ICHECK_GE(eviction_idx_ptr->value, 0) + << "tma_store eviction policy index must be >= 0, got " + << eviction_idx_ptr->value; + ICHECK_LT(static_cast(eviction_idx_ptr->value), + eviction_policy_names_.size()) + << "tma_store eviction policy index " << eviction_idx_ptr->value + << " out of bounds (max " << eviction_policy_names_.size() - 1 << ")"; + auto eviction_policy = eviction_policy_names_[eviction_idx_ptr->value]; + + ss << "tl.tma_store("; + auto desc = op->args[0]; + ss << PrintExpr_(desc) << ", "; + ss << PrintExpr_(op->args[1]) << ", ("; + for (size_t i = 2; i < op->args.size() - 2; i++) { + if (i > 2) + ss << ", "; + ss << PrintExpr_(op->args[i]); + } + ss << ")"; + if (eviction_policy != "EVICT_NORMAL") { + ss << ", eviction_kind = nvvm.EvictKind." << eviction_policy.substr(6); + } + ss << ")\n"; + PrintIndent(); + stream << ss.str(); + } else if (op->op.same_as(tl::ptx_ldmatrix())) { + int trans = Downcast(op->args[0])->value; + int num = Downcast(op->args[1])->value; + std::string func_name = "tl.ptx_ldmatrix_x" + std::to_string(num); + if (trans == 1) + func_name += "_trans"; + print_extern_call_stmt(func_name, 2); + } else if (op->op.same_as(tl::ptx_stmatrix())) { + int trans = Downcast(op->args[0])->value; + int num = Downcast(op->args[1])->value; + std::string func_name = "tl.ptx_stmatrix_x" + std::to_string(num); + if (trans == 1) + func_name += "_trans"; + print_extern_call_stmt(func_name, 2); + } else if (op->op.same_as(tl::fence_proxy_async())) { + print_extern_call_stmt("tl.fence_proxy_async"); + } else if (op->op.same_as(tl::tma_store_arrive())) { + print_extern_call_stmt("tl.tma_store_arrive"); + } else if (op->op.same_as(tl::tma_store_wait())) { + PrintIndent(); + stream << "tl.tma_store_wait(0)\n"; + } else if (op->op.same_as(tl::warpgroup_arrive())) { + LOG(FATAL) << "Currently unsupported op: " << op->op; + } else if (op->op.same_as(tl::warpgroup_commit_batch())) { + LOG(FATAL) << "Currently unsupported op: " << op->op; + } else if (op->op.same_as(tl::warpgroup_wait())) { + LOG(FATAL) << "Currently unsupported op: " << op->op; + } else if (op->op.same_as(tl::warpgroup_fence_operand())) { + LOG(FATAL) << "Currently unsupported op: " << op->op; + } else if (op->op.same_as(tl::set_max_nreg())) { + PrintIndent(); + int nreg = Downcast(op->args[0])->value; + int is_inc = Downcast(op->args[1])->value; + std::string func_name = + is_inc ? 
"tl.warpgroup_reg_alloc" : "tl.warpgroup_reg_dealloc"; + stream << func_name << "(" << nreg << ")\n"; + } else if (op->op.same_as(tl::wait_wgmma())) { + LOG(FATAL) << "Currently unsupported op: " << op->op; + } else if (op->op.same_as(tl::pack_b16())) { + os << "tl.pack_half2(" << PrintExpr_(op->args[0]) << ", " + << PrintExpr_(op->args[1]) << ")"; + } else if (op->op.same_as(tl::sync_grid())) { + LOG(FATAL) << "Currently unsupported op: " << op->op; + } else if (op->op.same_as(tl::loop_break())) { + PrintIndent(); + stream << "break\n"; + } else if (op->op.same_as(builtin::ptx_mma())) { + LOG(FATAL) << "Currently unsupported op: " << op->op; + } else if (op->op.same_as(tl::ptx_mma_sm70())) { + LOG(FATAL) << "Currently unsupported op: " << op->op; + } else if (op->op.same_as(builtin::ptx_mma_sp())) { + LOG(FATAL) << "Currently unsupported op: " << op->op; + } else if (op->op.same_as(tl::ptx_wgmma_ss())) { + LOG(FATAL) << "Currently unsupported op: " << op->op; + } else if (op->op.same_as(tl::ptx_wgmma_rs())) { + LOG(FATAL) << "Currently unsupported op: " << op->op; + } else if (op->op.same_as(tl::ptx_tcgen05_mma_ss())) { + LOG(FATAL) << "Currently unsupported op: " << op->op; + } else if (op->op.same_as(tl::ptx_tcgen05_mma_ts())) { + LOG(FATAL) << "Currently unsupported op: " << op->op; + } else if (op->op.same_as(tl::tcgen05_mma_arrive())) { + LOG(FATAL) << "Currently unsupported op: " << op->op; + } else if (op->op.same_as(builtin::ptx_ldmatrix())) { + LOG(FATAL) << "Currently unsupported op: " << op->op; + } else if (op->op.same_as(builtin::mma_store())) { + LOG(FATAL) << "Currently unsupported op: " << op->op; + } else if (op->op.same_as(builtin::mma_fill())) { + LOG(FATAL) << "Currently unsupported op: " << op->op; + } else if (op->op.same_as(builtin::ptx_cp_async_bulk())) { + LOG(FATAL) << "Currently unsupported op: " << op->op; + } else if (op->op.same_as(builtin::ptx_wait_barrier())) { + LOG(FATAL) << "Currently unsupported op: " << op->op; + } else if (op->op.same_as(builtin::ptx_ldg32())) { + LOG(FATAL) << "Currently unsupported op: " << op->op; + } else if (op->op.same_as(builtin::reinterpret())) { + DataType tgt_dtype = op->dtype; + DataType src_dtype = op->args[0]->dtype; + ICHECK_EQ(tgt_dtype.lanes() * tgt_dtype.bits(), + src_dtype.lanes() * src_dtype.bits()) + << "reinterpret expects source and target to have the same number of " + "bits"; + + const BufferLoadNode *load = op->args[0].as(); + ICHECK(op->args.size() == 1 && load); + ICHECK_EQ(load->indices.size(), 1) + << "CodeGenTileLangCuTeDSL only supports flat memory"; + + PrimExpr index = load->indices[0]; + if (const RampNode *node = index.as(); node) { + auto *p_stride = as_const_int(node->stride); + CHECK(p_stride); + ICHECK_EQ(*p_stride, 1) << "reinterpret expects contiguous elements"; + index = node->base; + } + + auto ptr_str = GetBufferPtr_(load->buffer.get(), index); + os << "tl.make_tensor(tl.recast_ptr(" << ptr_str << ", dtype="; + PrintType(tgt_dtype.element_of(), os); + os << "), (" << tgt_dtype.lanes() << ",)).load()"; + } else if (op->op.same_as(builtin::thread_return())) { + os << "return"; + } else if (op->op.same_as(tl::tl_gemm())) { + ICHECK(op->args.size() == 4) << "tl_gemm expects 4 arguments , but got " + << op->args.size(); + + auto op_instance = Downcast(op->args[0]); + PrintCallExtern_(GetType(tvm::ffi::GetRef(op)), + op_instance->value, op->args, true, os); + } else if (op->op.same_as(tl::tl_gemm_sp())) { + LOG(FATAL) << "Currently unsupported op: " << op->op; + } else if 
(op->op.same_as(tl::get_lane_idx())) { + LOG(FATAL) << "Currently unsupported op: " << op->op; + } else if (op->op.same_as(tl::get_warp_idx_sync())) { + LOG(FATAL) << "Currently unsupported op: " << op->op; + } else if (op->op.same_as(tl::get_warp_idx())) { + LOG(FATAL) << "Currently unsupported op: " << op->op; + } else if (op->op.same_as(tl::get_warp_group_idx())) { + LOG(FATAL) << "Currently unsupported op: " << op->op; + } else if (op->op.same_as(tl::tl_shuffle_elect())) { + os << "tl.shuffle_elect(" << PrintExpr_(op->args[0]) << ")"; + } else if (op->op.same_as(tl::initialize_wgmma_descriptor())) { + LOG(FATAL) << "Currently unsupported op: " << op->op; + } else if (op->op.same_as(tl::initialize_tcgen05_descriptor())) { + LOG(FATAL) << "Currently unsupported op: " << op->op; + } else if (op->op.same_as(tl::increase_descriptor_offset())) { + LOG(FATAL) << "Currently unsupported op: " << op->op; + } else if (op->op.same_as(tl::__exp())) { + os << "tl.exp2(" << PrintExpr_(op->args[0]) << ", fastmath=True)"; + } else if (op->op.same_as(tl::__exp10())) { + LOG(FATAL) << "Currently unsupported op: " << op->op; + } else if (op->op.same_as(tl::__log())) { + os << "tl.log(" << PrintExpr_(op->args[0]) << ", fastmath=True)"; + } else if (op->op.same_as(tl::__log2())) { + os << "tl.log2(" << PrintExpr_(op->args[0]) << ", fastmath=True)"; + } else if (op->op.same_as(tl::__log10())) { + os << "tl.log10(" << PrintExpr_(op->args[0]) << ", fastmath=True)"; + } else if (op->op.same_as(tl::__tan())) { + os << "tl.tan(" << PrintExpr_(op->args[0]) << ", fastmath=True)"; + } else if (op->op.same_as(tl::__cos())) { + os << "tl.cos(" << PrintExpr_(op->args[0]) << ", fastmath=True)"; + } else if (op->op.same_as(tl::__sin())) { + os << "tl.sin(" << PrintExpr_(op->args[0]) << ", fastmath=True)"; + } else if (op->op.same_as(tl::ieee_add())) { + LOG(FATAL) << "Currently unsupported op: " << op->op; + } else if (op->op.same_as(tl::ieee_sub())) { + LOG(FATAL) << "Currently unsupported op: " << op->op; + } else if (op->op.same_as(tl::ieee_mul())) { + LOG(FATAL) << "Currently unsupported op: " << op->op; + } else if (op->op.same_as(tl::ieee_fmaf())) { + LOG(FATAL) << "Currently unsupported op: " << op->op; + } else if (op->op.same_as(tl::ieee_frcp())) { + LOG(FATAL) << "Currently unsupported op: " << op->op; + } else if (op->op.same_as(tl::ieee_fsqrt())) { + LOG(FATAL) << "Currently unsupported op: " << op->op; + } else if (op->op.same_as(tl::ieee_frsqrt())) { + LOG(FATAL) << "Currently unsupported op: " << op->op; + } else if (op->op.same_as(tl::ieee_fdiv())) { + LOG(FATAL) << "Currently unsupported op: " << op->op; + } else if (op->op.same_as(tl::warp_reduce_sum())) { + LOG(FATAL) << "Currently unsupported op: " << op->op; + } else if (op->op.same_as(tl::warp_reduce_max())) { + LOG(FATAL) << "Currently unsupported op: " << op->op; + } else if (op->op.same_as(tl::warp_reduce_min())) { + LOG(FATAL) << "Currently unsupported op: " << op->op; + } else if (op->op.same_as(tl::warp_reduce_bitand())) { + LOG(FATAL) << "Currently unsupported op: " << op->op; + } else if (op->op.same_as(tl::warp_reduce_bitor())) { + LOG(FATAL) << "Currently unsupported op: " << op->op; + } else if (op->op.same_as(builtin::address_of())) { + const BufferLoadNode *load = op->args[0].as(); + ICHECK(op->args.size() == 1 && load); + ICHECK_EQ(load->indices.size(), 1) + << "CodeGenTileLangCuTeDSL only supports flat memory"; + os << GetBufferPtr_(load->buffer.get(), load->indices[0]); + } else { + CodeGenTileLangPY::VisitExpr_(op, os); + } +} + 
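The visitor above lowers TL intrinsics into calls to a small `tl.*` Python helper library inside the generated CuTeDSL kernel. As a rough, hedged illustration of the kind of Python it emits (not the exact output for any particular program: the kernel name, parameters, buffer sizes, and indices below are hypothetical, while the helper names mirror the string literals emitted by the visitor and by `PrintFuncDecorator_`/`PreFunctionBody_` above):

```python
# Illustrative sketch only. Helper names (tl.cp_async_gs, tl.cp_async_commit,
# tl.cp_async_wait, tl.sync_threads, tl.exp2) mirror the strings this backend
# emits; the kernel signature, buffer shapes, and indices are hypothetical.
@cute.kernel                                   # PrintFuncDecorator_
def main_kernel(A_ptr, B_ptr):                 # hypothetical signature
    threadIdx = tl.ThreadIdx()                 # PreFunctionBody_
    blockIdx = tl.BlockIdx()
    A_shared = tl.make_tensor(tl.alloc_smem(cutlass.Float16, 4096), (4096,))
    # builtin::ptx_cp_async -> tl.cp_async_gs(size, dst, dst_offset, src, src_offset)
    tl.cp_async_gs(16, A_shared, threadIdx.x * 8,
                   A_ptr, blockIdx.x * 4096 + threadIdx.x * 8)
    tl.cp_async_commit()                       # builtin::ptx_commit_group
    tl.cp_async_wait(0)                        # builtin::ptx_wait_group
    tl.sync_threads()                          # tvm_storage_sync("shared")
    v = tl.exp2(A_shared[threadIdx.x], fastmath=True)   # tl::__exp
```

Note that the `tl::__exp`-style fast intrinsics always pass `fastmath=True`, whereas for ordinary extern math calls the flag is only appended when the `tl::kEnableFastMath` PassContext option is enabled (see `PrintCallExtern_` below).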
+void CodeGenTileLangCuTeDSL::VisitExpr_(const BufferLoadNode *op, + std::ostream &os) { // NOLINT(*) + ICHECK_EQ(op->indices.size(), 1) + << "Load from non-flat memory not supported."; + ICHECK(!op->predicate.defined()) + << "Predicated buffer load is not supported."; + + DataType value_dtype = op->dtype; + PrimExpr index = op->indices[0]; + Var buffer_var = op->buffer->data; + DataType element_dtype = op->buffer->dtype; + + const int value_lanes = value_dtype.lanes(); + if (value_lanes == element_dtype.lanes()) { + std::string ref = GetBufferRef_(value_dtype, op->buffer.get(), index); + if (ref.back() == ')') { + ref += ".load()"; + } + os << ref; + } else { + ICHECK_GE(value_lanes, element_dtype.lanes()) + << "Unsupported load/store: value lanes < buffer element lanes"; + bool is_contiguous = false; + arith::PVar base; + if (arith::ramp(base, 1, value_lanes / element_dtype.lanes()) + .Match(index)) { + is_contiguous = true; + } + + if (is_contiguous) { + std::string ref = + GetBufferRef_(value_dtype, op->buffer.get(), base.Eval()); + if (ref.back() == ')') { + ref += ".load()"; + } + os << ref; + } else { + ICHECK(element_dtype.is_scalar()) + << "buffer element type for non-contiguous load must be scalar " + "currently"; + + std::string sret = name_supply_->FreshName("_"); + PrintIndent(); + stream << sret << " = tl.make_rmem_tensor((" << value_lanes << ",), "; + PrintType(element_dtype, stream); + stream << ")\n"; + + std::string vid = GetVarID(buffer_var.get()); + const RampNode *ramp = index.as(); + ICHECK(ramp) + << "Expected Ramp index for vectorized non-contiguous access"; + for (int i = 0; i < value_lanes; ++i) { + auto idx_expr = + arith::Analyzer().Simplify(ramp->base + ramp->stride * i); + + PrintIndent(); + stream << sret << "[" << i << "] = " + << GetBufferRef_(element_dtype, op->buffer.get(), idx_expr) + << "\n"; + } + os << sret << ".load()"; + } + } +} + +void CodeGenTileLangCuTeDSL::VisitStmt_(const BufferStoreNode *op) { + ICHECK_EQ(op->indices.size(), 1) << "Store to non-flat memory not supported."; + ICHECK(!op->predicate.defined()) + << "Predicated buffer store is not supported."; + + DataType value_dtype = op->value.dtype(); + DataType element_dtype = op->buffer->dtype; + PrimExpr index_expr = op->indices[0]; + Var buffer_var = op->buffer->data; + std::string value_str = PrintExpr_(op->value); + + int value_lanes = value_dtype.lanes(); + if (value_lanes == element_dtype.lanes()) { + std::string ref = GetBufferRef_(value_dtype, op->buffer.get(), index_expr); + PrintIndent(); + + if (ref.back() != ')') { + stream << ref << " = " << RemoveOutermostParentheses(value_str) << "\n"; + } else { + stream << ref << ".store(" << RemoveOutermostParentheses(value_str) + << ")\n"; + } + } else { + bool is_contiguous = false; + arith::PVar base; + if (arith::ramp(base, 1, value_lanes / element_dtype.lanes()) + .Match(index_expr)) { + is_contiguous = true; + } + + if (is_contiguous) { + PrintVecStore_(op->buffer.get(), value_dtype, base.Eval(), value_str); + } else { + ICHECK(element_dtype.is_scalar()) + << "buffer element type for non-contiguous store must be scalar " + "currently"; + + // store elements separately + value_str = SSAGetID(value_str, element_dtype); + for (int i = 0; i < value_lanes; ++i) { + const RampNode *ramp = index_expr.as(); + ICHECK(ramp); + auto idx_expr = + arith::Analyzer().Simplify(ramp->base + ramp->stride * i); + + PrintIndent(); + stream << GetBufferRef_(element_dtype, op->buffer.get(), idx_expr) + << " = "; + PrintVecElemLoad_(value_str, value_dtype, 
i, stream); + stream << "\n"; + } + } + } +} + +void CodeGenTileLangCuTeDSL::VisitStmt_(const AllocateNode *op) { + ICHECK(!is_zero(op->condition)); + std::string vid = AllocVarID(op->buffer_var.get()); + PrintIndent(); + std::string scope = GetPtrStorageScope(op->buffer_var); + alloc_storage_scope_[op->buffer_var.get()] = scope; + + if (scope == "local.descriptor.wgmma") { + stream << vid << " = tl.GmmaDescriptor()\n"; + } else if (scope == "local.descriptor.tcgen05_smem") { + LOG(FATAL) << "Currently unsupported scope: " << scope; + } else if (scope == "local.descriptor.tcgen05_instr") { + LOG(FATAL) << "Currently unsupported scope: " << scope; + } else if (scope == "shared.dyn") { + stream << vid << " = tl.make_tensor(tl.get_dyn_smem("; + PrintType(op->dtype, stream); + // there is no bound check for Tensor access, so just set shape to 1 + stream << ", alignment=1024), (1,))\n"; + } else { + size_t constant_size = op->ConstantAllocationSize(); + ICHECK_GT(constant_size, 0) + << "Can only handle constant size stack allocation for now, but get " + << constant_size << " for " << op->buffer_var->name_hint; + + if (scope == "shared") { + stream << vid << " = tl.make_tensor(tl.alloc_smem("; + PrintType(op->dtype, stream); + stream << ", " << constant_size << "), (" << constant_size << ",))\n"; + } else if (scope == "shared.barrier") { + ICHECK(false) << "Unsupported scope: " << scope; + } else if (scope == "local") { + stream << vid << " = tl.make_rmem_tensor((" << constant_size << "),"; + PrintType(op->dtype, stream); + stream << ")\n"; + } else if (scope == "local.var") { + PrimExpr init = tir::make_const(op->dtype, 0); + auto init_it = op->annotations.find(tl::attr::kLocalVarInit); + if (init_it != op->annotations.end()) { + PrimExpr user_init = Downcast((*init_it).second); + if (!user_init.dtype().is_void() && user_init.dtype() != op->dtype) { + user_init = tir::Cast(op->dtype, user_init); + } + init = user_init; + } + stream << vid << " = " << PrintExpr_(init) << "\n"; + } else { + ICHECK(false) << "Unsupported scope: " << scope; + } + } + + RegisterHandleType_(op->buffer_var.get(), op->dtype); + PrintStmt_(op->body); +} + +void CodeGenTileLangCuTeDSL::VisitStmt_(const AttrStmtNode *op) { + if (op->attr_key == tir::attr::thread_extent) { + IterVar iv = Downcast(op->node); + if (!iv->thread_tag.empty()) { + if (!var_idmap_.count(iv->var.get())) { + BindThreadIndex_(iv); + } + } + VisitStmt(op->body); + } else if (op->attr_key == tir::attr::async_commit_queue_scope) { + const IntImmNode *queue_id = op->value.as(); + ICHECK(queue_id && queue_id->value == 0) + << "For CUDA, the index of an async queue must be 0."; + VisitStmt(op->body); + auto commit_group = Call(DataType::Void(), builtin::ptx_commit_group(), {}); + VisitExpr(commit_group, stream); + } else if (op->attr_key == tir::attr::async_wait_queue_scope) { + auto wait_attrs = GetAsyncWaitAttributes(op); + auto queue_id = wait_attrs.first.as(); + ICHECK(queue_id && queue_id->value == 0) + << "For CUDA, the index of an async queue must be 0."; + auto wait_cnt = wait_attrs.second; + auto wait_group = + Call(DataType::Void(), builtin::ptx_wait_group(), {wait_cnt}); + VisitExpr(wait_group, stream); + auto inner = op->body.as(); + ICHECK(inner); + VisitStmt(inner->body); + } else if (op->attr_key == "threadblock_swizzle_pattern") { + this->PrintIndent(); + const StringImmNode *pattern = op->value.as(); + ICHECK(pattern); + std::string call_str = pattern->value; + // replace :: with . 
and replace < with ( and replace > with ) + ReplaceAll(call_str, "::", "."); + ReplaceAll(call_str, "<", "("); + ReplaceAll(call_str, ">", ")"); + this->stream << "blockIdx = " << call_str << "\n"; + this->VisitStmt(op->body); + } else if (op->attr_key == "pragma_unroll_factor") { + const IntImmNode *factor = op->value.as(); + ICHECK(factor); + unroll_factor_[op->node.as()] = Downcast(factor); + CodeGenTileLangPY::VisitStmt_(op); + } else { + CodeGenTileLangPY::VisitStmt_(op); + } +} + +void CodeGenTileLangCuTeDSL::VisitStmt_(const ForNode *op) { + if (op->kind != tir::ForKind::kUnrolled) { + CodeGenTileLangPY::VisitStmt_(op); + return; + } + + auto start_expr = arith::Analyzer().Simplify(op->min); + auto stop_expr = arith::Analyzer().Simplify(op->extent + op->min); + std::string unroll_factor; + if (auto it = unroll_factor_.find(op->loop_var.get()); + it != unroll_factor_.end()) { + unroll_factor = PrintExpr_(it->second); + } + bool use_range_constexpr = unroll_factor.empty() && + as_const_int(op->extent) != nullptr && + *as_const_int(op->extent) <= LOOP_UNROLL_THRESHOLD; + PrintIndent(); + std::string vid = AllocVarID(op->loop_var.get()); + stream << "for " << vid << " in cutlass.range"; + if (use_range_constexpr) { + stream << "_constexpr"; + } + stream << "("; + if (!is_zero(start_expr)) { + PrintExpr_(start_expr, stream); + stream << ", "; + } + PrintExpr_(stop_expr, stream); + if (!unroll_factor.empty()) { + stream << ", unroll=" << unroll_factor; + } else if (!use_range_constexpr) { + stream << ", unroll_full=True"; + } + stream << "):\n"; + int for_scope = BeginScope(); + PrintStmt_(op->body); + EndScope(for_scope); +} + +void CodeGenTileLangCuTeDSL::VisitStmt_(const IfThenElseNode *op) { + std::string cond = PrintExpr_(op->condition); + PrintIndent(); + stream << "if " << RemoveOutermostParentheses(cond) << ":\n"; + int then_scope = BeginScope(); + if (const CallNode *call = op->condition.as(); + call && call->op.same_as(tl::tl_shuffle_elect())) { + PrintIndent(); + stream << "with cute.arch.elect_one():\n"; + int with_scope = BeginScope(); + PrintStmt_(op->then_case); + EndScope(with_scope); + } else { + PrintStmt_(op->then_case); + } + EndScope(then_scope); + + if (op->else_case) { + PrintIndent(); + stream << "else:\n"; + int else_scope = BeginScope(); + PrintStmt_(op->else_case.value()); + EndScope(else_scope); + } +} + +void CodeGenTileLangCuTeDSL::VisitStmt_(const EvaluateNode *op) { + if (is_const_int(op->value)) + return; + const CallNode *call = op->value.as(); + if (call && call->op.same_as(builtin::tvm_global_barrier_kinit())) { + LOG(FATAL) << "Currently unsupported op: " << call->op; + } + if (call && (call->op.same_as(tvm::tl::device_assert()))) { + std::string cond = RemoveOutermostParentheses(PrintExpr_(call->args[0])); + PrintIndent(); + stream << "assert " << cond << "\n"; + } else if (call && call->op.same_as(tvm::tl::device_assert_with_msg())) { + std::string cond = RemoveOutermostParentheses(PrintExpr_(call->args[0])); + std::string msg_expr = PrintExpr_(call->args[1]); + PrintIndent(); + stream << "assert " << cond << ", " << msg_expr << "\n"; + } else if (call && call->op.same_as(builtin::tvm_storage_sync())) { + PrintStorageSync_(call); + } else { + CodeGenTileLangPY::VisitStmt_(op); + } +} + +void CodeGenTileLangCuTeDSL::PrintVecElemLoad_(const std::string &vec, + DataType t, int i, + std::ostream &os) { // NOLINT(*) + if (t.is_scalar()) { + os << vec; + return; + } + os << vec << "[" << i << "]"; +} + +void CodeGenTileLangCuTeDSL::PrintVecElemStore_(const 
std::string &vec, + DataType t, int i, + const std::string &value) { + PrintIndent(); + stream << vec << "[" << i << "] = " << value << "\n"; +} + +void CodeGenTileLangCuTeDSL::PrintVecStore_(const BufferNode *buffer, + DataType t, PrimExpr base, + const std::string &value) { + ICHECK(!t.is_scalar()) << "PrintVecStore_() should not be used for scalar"; + + std::string ref = GetBufferRef_(t, buffer, base); + PrintIndent(); + stream << ref << ".store(" << value << ")\n"; +} + +void CodeGenTileLangCuTeDSL::PrintVecBinaryOp_(const std::string &opstr, + DataType dtype, PrimExpr lhs, + PrimExpr rhs, + std::ostream &os) { // NOLINT(*) + // Declare the result. + std::string sret = name_supply_->FreshName("_"); + PrintIndent(); + stream << sret << " = tl.make_rmem_tensor((" << dtype.lanes() << ",), "; + PrintType(dtype.element_of(), stream); + stream << ")\n"; + + std::string vlhs = SSAGetID(PrintExpr_(lhs), lhs.dtype()); + std::string vrhs = SSAGetID(PrintExpr_(rhs), rhs.dtype()); + + const std::string one_char_op{"+-*%<>^|&"}; + const std::string two_char_op{"// == != <= >="}; + if ((opstr.size() == 1 && one_char_op.find(opstr) != std::string::npos) || + (opstr.size() == 2 && two_char_op.find(opstr) != std::string::npos)) { + PrintIndent(); + stream << sret << ".store(" << vlhs << " " << opstr << " " << vrhs << ")\n"; + } else { + // Unpack into individual ops. + for (int i = 0, lanes = dtype.lanes(); i < lanes; ++i) { + std::ostringstream value_temp; + if (isalpha(opstr[0])) { + value_temp << opstr << "("; + PrintVecElemLoad_(vlhs, lhs.dtype(), i, value_temp); + value_temp << ", "; + PrintVecElemLoad_(vrhs, rhs.dtype(), i, value_temp); + value_temp << ")"; + } else { + value_temp << "("; + PrintVecElemLoad_(vlhs, lhs.dtype(), i, value_temp); + value_temp << opstr; + PrintVecElemLoad_(vrhs, rhs.dtype(), i, value_temp); + value_temp << ")"; + } + PrintVecElemStore_(sret, dtype, i, value_temp.str()); + } + } + os << sret << ".load()"; +} + +void CodeGenTileLangCuTeDSL::PrintBinaryExpr_(const std::string &opstr, + DataType dtype, PrimExpr lhs, + PrimExpr rhs, + std::ostream &os) { // NOLINT(*) + if (dtype.is_scalar()) { + CodeGenTileLangPY::PrintBinaryExpr_(opstr, dtype, lhs, rhs, os); + } else { + PrintVecBinaryOp_(opstr, dtype, lhs, rhs, os); + } +} + +void CodeGenTileLangCuTeDSL::PrintBinaryIntrinsic_( + const CallNode *op, const char *opstr, + std::ostream &os) { // NOLINT(*) + if (op->dtype.is_scalar()) { + CodeGenTileLangPY::PrintBinaryIntrinsic_(op, opstr, os); + } else { + PrintVecBinaryOp_(opstr, op->dtype, op->args[0], op->args[1], os); + } +} + +void CodeGenTileLangCuTeDSL::PrintCallExtern_(Type ret_type, + ffi::String global_symbol, + const ffi::Array &args, + bool skip_first_arg, + std::ostream &os) { // NOLINT(*) + DataType ret_dtype = GetRuntimeDataType(ret_type); + + std::string global_symbol_str = global_symbol; + ReplaceAll(global_symbol_str, "::", "."); + + std::vector sargs; + // when the template arguments occurs at the end, merge them with function + // arguments + if (global_symbol_str.back() == '>') { + auto pos = global_symbol_str.rfind('<'); + ICHECK(pos != std::string::npos); + std::string template_args = + global_symbol_str.substr(pos + 1, global_symbol_str.size() - pos - 2); + ReplaceAll(template_args, "true", "True"); + ReplaceAll(template_args, "false", "False"); + sargs.push_back(template_args); + + global_symbol_str.resize(pos); + } + const size_t arg_begin = static_cast(skip_first_arg); + for (size_t i = arg_begin; i < args.size(); ++i) { + std::string sarg = 
PrintExpr_(args[i]); + if (ret_dtype.is_fixed_length_vector()) { + std::string val = SSAGetID(sarg, args[i].dtype()); + sargs.push_back(std::move(val)); + } else { + sargs.push_back(sarg); + } + } + + // Replace "<...>" with "(...)". Nested "<" is not supported + { + auto pos_left = global_symbol_str.find('<'); + while (pos_left != std::string::npos) { + auto pos_right = global_symbol_str.find('>', pos_left + 1); + if (pos_right != std::string::npos) { + auto args = + global_symbol_str.substr(pos_left + 1, pos_right - pos_left - 1); + ReplaceAll(args, "true", "True"); + ReplaceAll(args, "false", "False"); + global_symbol_str.replace(pos_left, args.size() + 2, "(" + args + ")"); + } + pos_left = global_symbol_str.find('<'); + } + } + + // Special cases: + // Map C math functions to Python/cutedsl equivalents + const auto canonicalized_global_symbol_str = + CanonicalizeFastmathFunctionName_(global_symbol_str); + const bool canonicalized = !canonicalized_global_symbol_str.empty(); + if (canonicalized) { + global_symbol_str = canonicalized_global_symbol_str; + } + + // Atomic Functions + if (global_symbol_str.substr(0, 6) == "Atomic") { + global_symbol_str = "tl." + global_symbol_str; + // Convert first argument (Buffer) to pointer for atomic operations + if (const BufferLoadNode *load = args[arg_begin].as()) { + ICHECK_EQ(load->indices.size(), 1) + << "CodeGenTileLangCuTeDSL only supports flat memory"; + sargs[0] = GetBufferPtr_(load->buffer.get(), load->indices[0]); + } + } + // some optional template arguments might be ommited, so add names explicitly + // for remain arguments + if (global_symbol_str == "tl.gemm_ss" || global_symbol_str == "tl.gemm_rs" || + global_symbol_str == "tl.gemm_sr" || global_symbol_str == "tl.gemm_rr") { + ICHECK(sargs.size() >= 3); + sargs[sargs.size() - 3] = "A_ptr=" + sargs[sargs.size() - 3]; + sargs[sargs.size() - 2] = "B_ptr=" + sargs[sargs.size() - 2]; + sargs[sargs.size() - 1] = "C_ptr=" + sargs[sargs.size() - 1]; + } + + if (ret_dtype.is_fixed_length_vector()) { + // maybe simplify this if TensorSSA suppports this OP + std::string sret = name_supply_->FreshName("_"); + PrintIndent(); + stream << sret << " = tl.make_rmem_tensor((" << ret_dtype.lanes() << ",), "; + PrintType(ret_dtype.element_of(), stream); + stream << ")\n"; + + // Emit a scalar call for each lane. 
+ bool has_template_arg = (sargs.size() > args.size() - arg_begin); + for (int i = 0; i < ret_dtype.lanes(); ++i) { + std::ostringstream scall; + scall << global_symbol_str << "("; + for (size_t j = 0; j < sargs.size(); ++j) { + if (j != 0) { + scall << ", "; + } + + if (j == 0 && has_template_arg) { + scall << sargs[j]; + } else { + PrintVecElemLoad_( + sargs[j], + args[arg_begin + j - static_cast(has_template_arg)] + .dtype(), + i, scall); + } + } + if (canonicalized && enable_fastmath_) { + if (!sargs.empty()) { + scall << ", "; + } + scall << "fastmath=True"; + } + scall << ")"; + PrintVecElemStore_(sret, ret_dtype, i, scall.str()); + } + os << sret << ".load()"; + } else { + os << global_symbol_str << "("; + for (size_t i = 0; i < sargs.size(); ++i) { + if (i != 0) { + os << ", "; + } + os << sargs[i]; + } + if (canonicalized && enable_fastmath_) { + if (!sargs.empty()) { + os << ", "; + } + os << "fastmath=True"; + } + os << ")"; + } +} + +std::string CodeGenTileLangCuTeDSL::GetBufferPtr_(const BufferNode *buffer, + PrimExpr index) { + const VarNode *buffer_var = buffer->data.get(); + const std::string vid = GetVarID(buffer_var); + + DataType buffer_element_dtype = buffer->dtype; + bool is_handle_type_match = + HandleTypeMatch_(buffer_var, buffer_element_dtype); + std::string ptr_str; + if (is_handle_type_match) { + ptr_str = vid + ".iterator"; + } else { + ptr_str = "tl.recast_ptr(" + vid + + ".iterator, dtype=" + DTypeToString(buffer_element_dtype) + ")"; + } + + std::string index_str = PrintExpr_(index); + return "(" + ptr_str + " + " + index_str + ")"; +} + +// The following forms can be returned: +// (1) vid +// (2) vid[i] +// (3) tl.make_tensor_at_offset(...)[0] +// (4) tl.make_tensor_at_offset(...) +// +// Form (4) is needed when the whole tensor is loaded or stored. +// It's the only form that ends with ")". Using this fact, BufferLoadNode will +// add ".load()" and BufferStoreNode will add ".store()". +std::string CodeGenTileLangCuTeDSL::GetBufferRef_(DataType t, + const BufferNode *buffer, + PrimExpr index) { + const VarNode *buffer_var = buffer->data.get(); + std::string vid = GetVarID(buffer_var); + std::string scope; + if (alloc_storage_scope_.count(buffer_var)) { + scope = alloc_storage_scope_.at(buffer_var); + } + if (scope.empty()) { + scope = GetPtrStorageScope(buffer->data); + } + if (scope == "local.var" || scope.find("local.descriptor") == 0) { + return vid; + } + + DataType buffer_element_dtype = buffer->dtype; + bool is_handle_type_match = + HandleTypeMatch_(buffer_var, buffer_element_dtype); + std::string ptr_str; + if (is_handle_type_match) { + ptr_str = vid + ".iterator"; + } else { + ptr_str = "tl.recast_ptr(" + vid + + ".iterator, dtype=" + DTypeToString(buffer_element_dtype) + ")"; + } + + const std::string index_str = PrintExpr_(index); + + if (t == buffer_element_dtype) { + if (is_handle_type_match && buffer_element_dtype.is_scalar() && + (scope == "local" || scope == "shared" || scope == "shared.dyn" || + scope == "shared.barrier")) { + // Tensors in these scopes are allocated as one-dimensional, so can be + // assessed via "[]" correctly. Other tensors may be multi-dimensional, + // and must be assessed via ptr, otherwise CuTeDSL will interpret "[]" + // access using its visiting order and layout. 
+ return vid + "[" + index_str + "]"; + } else { + std::ostringstream os; + os << "tl.make_tensor_at_offset(" << ptr_str << ", " << index_str + << ", (1,), div_by=" << buffer_element_dtype.lanes() << ")"; + // for vector data types, ".load()" (added by BufferLoadNode) is neeed + // instead of "[0]" + if (buffer_element_dtype.is_scalar()) { + os << "[0]"; + } + return os.str(); + } + } else { + const int num = t.bits() * t.lanes(); + const int den = buffer_element_dtype.bits() * buffer_element_dtype.lanes(); + ICHECK_EQ(num % den, 0) << "Cannot form view: bitwidth not divisible"; + int buffer_size = num / den; + + std::ostringstream os; + os << "tl.make_tensor_at_offset(" << ptr_str << ", " << index_str << ", (" + << buffer_size << ",), div_by=" << buffer_size << ")"; + return os.str(); + } +} + +void CodeGenTileLangCuTeDSL::BindThreadIndex_(const IterVar &iv) { + ICHECK(!var_idmap_.count(iv->var.get())); + + auto &thread_tag = iv->thread_tag; + ICHECK(thread_tag == "threadIdx.x" || thread_tag == "threadIdx.y" || + thread_tag == "threadIdx.z" || thread_tag == "blockIdx.x" || + thread_tag == "blockIdx.y" || thread_tag == "blockIdx.z"); + + // cute.arch.thread_idx() and block_idx() are Int32 + DataType from_dtype = DataType::Int(32); + var_idmap_[iv->var.get()] = + CastFromTo_(thread_tag, from_dtype, iv->var.dtype()); +} + +void CodeGenTileLangCuTeDSL::PrintStorageSync_(const CallNode *op) { + auto args = op->args; + const std::string &sync = args[0].as()->value; + if (sync == "warp") { + // do nothing + } else if (sync == "shared" || sync == "shared.dyn") { + PrintIndent(); + if (args.size() == 1) { + stream << "tl.sync_threads()\n"; + } else if (args.size() == 2) { + auto barrier_id_ptr = args[1].as(); + ICHECK(barrier_id_ptr) + << "storage_sync barrier_id (args[1]) must be IntImm, got " + << args[1]->GetTypeKey(); + auto barrier_id = barrier_id_ptr->value; + stream << "tl.sync_thread_partial(" << barrier_id << ")\n"; + } else if (args.size() == 3) { + auto barrier_id_ptr = args[1].as(); + ICHECK(barrier_id_ptr) + << "storage_sync barrier_id (args[1]) must be IntImm, got " + << args[1]->GetTypeKey(); + auto thread_count_ptr = args[2].as(); + ICHECK(thread_count_ptr) + << "storage_sync thread_count (args[2]) must be IntImm, got " + << args[2]->GetTypeKey(); + auto barrier_id = barrier_id_ptr->value; + auto thread_count = thread_count_ptr->value; + stream << "tl.sync_thread_partial(" << barrier_id << ", " << thread_count + << ")\n"; + } else { + LOG(FATAL) << "Invalid number of arguments for storage sync: " + << args.size(); + } + } else if (sync == "global") { + LOG(FATAL) << "PrintStorageSync_ for global is not supported for now"; + } else { + LOG(FATAL) << "Unknown storage sync scope: " << sync; + } +} + +} // namespace codegen +} // namespace tvm diff --git a/src/target/codegen_cutedsl.h b/src/target/codegen_cutedsl.h new file mode 100644 index 000000000..1d4edc538 --- /dev/null +++ b/src/target/codegen_cutedsl.h @@ -0,0 +1,102 @@ +/*! 
+ * \file target/codegen_cutedsl.h + * \brief Utility to generate CuTeDSL code + */ +#ifndef TVM_TL_TARGET_CODEGEN_CUTEDSL_H_ +#define TVM_TL_TARGET_CODEGEN_CUTEDSL_H_ + +#include +#include +#include + +#include +#include +#include + +#include "codegen_py.h" + +namespace tvm { +namespace codegen { + +class CodeGenTileLangCuTeDSL final : public CodeGenTileLangPY { +public: + CodeGenTileLangCuTeDSL(); + +protected: + void PrintFuncDecorator_(std::ostream &os) override; // NOLINT(*) + void PreFunctionBody_(const PrimFunc &f) override; + +protected: + void PrintType(DataType t, std::ostream &os) override; // NOLINT(*) + + void VisitExpr_(const BroadcastNode *op, + std::ostream &os) override; // NOLINT(*) + void VisitExpr_(const FloatImmNode *op, + std::ostream &os) override; // NOLINT(*) + void VisitExpr_(const CastNode *op, std::ostream &os) override; // NOLINT(*) + void VisitExpr_(const DivNode *op, std::ostream &os) override; // NOLINT(*) + void VisitExpr_(const MinNode *op, std::ostream &os) override; // NOLINT(*) + void VisitExpr_(const MaxNode *op, std::ostream &os) override; // NOLINT(*) + void VisitExpr_(const CallNode *op, std::ostream &os) override; // NOLINT(*) + void VisitExpr_(const BufferLoadNode *op, + std::ostream &os) override; // NOLINT(*) + + void VisitStmt_(const BufferStoreNode *op) override; + void VisitStmt_(const AllocateNode *op) override; + void VisitStmt_(const AttrStmtNode *op) override; + void VisitStmt_(const ForNode *op) override; + void VisitStmt_(const IfThenElseNode *op) override; + void VisitStmt_(const EvaluateNode *op) override; + +protected: + virtual void PrintVecElemLoad_(const std::string &vec, DataType t, int i, + std::ostream &os); // NOLINT(*) + virtual void PrintVecElemStore_(const std::string &vec, DataType t, int i, + const std::string &value); + virtual void PrintVecStore_(const BufferNode *buffer, DataType t, + PrimExpr base, const std::string &value); + void PrintVecBinaryOp_(const std::string &opstr, DataType dtype, PrimExpr lhs, + PrimExpr rhs, + std::ostream &os); // NOLINT(*) + void PrintBinaryExpr_(const std::string &opstr, DataType dtype, PrimExpr lhs, + PrimExpr rhs, + std::ostream &os) override; // NOLINT(*) + void PrintBinaryIntrinsic_(const CallNode *op, const char *opstr, + std::ostream &os) override; // NOLINT(*) + + void PrintCallExtern_(Type ret_type, ffi::String global_symbol, + const ffi::Array &args, bool skip_first_arg, + std::ostream &os) override; // NOLINT(*) + + std::string GetBufferPtr_(const BufferNode *buffer, PrimExpr index); + std::string GetBufferRef_(DataType t, const BufferNode *buffer, + PrimExpr index) override; + + /*! 
+ * \brief Print expr representing the thread tag + * \param IterVar iv The thread index to be binded; + */ + virtual void BindThreadIndex_(const IterVar &iv); // NOLINT(*) + + virtual void PrintStorageSync_(const CallNode *op); + + std::string + CanonicalizeFastmathFunctionName_(const std::string &func_name) const; + +private: + // The name of the mbarrier array in shared memory + const std::string mbarrier_name_ = "mbarrier"; + + std::unordered_map unroll_factor_; + + std::vector eviction_policy_names_ = { + "EVICT_NORMAL", "EVICT_FIRST", "EVICT_LAST"}; + + // Fastmath configuration (read from PassContext) + bool enable_fastmath_ = false; +}; + +} // namespace codegen +} // namespace tvm + +#endif // TVM_TL_TARGET_CODEGEN_CUTEDSL_H_ diff --git a/src/target/codegen_py.cc b/src/target/codegen_py.cc new file mode 100644 index 000000000..aa12eef09 --- /dev/null +++ b/src/target/codegen_py.cc @@ -0,0 +1,715 @@ +/*! + * \file codegen_py.cc + */ +#include "codegen_py.h" +#include "codegen_utils.h" + +#include +#include + +#include + +namespace tvm { +namespace codegen { + +void CodeGenTileLangPY::AddFunction(const GlobalVar &gvar, const PrimFunc &f) { + RegisterFunction_(gvar, f); + auto function_name = GetFunctionName_(gvar); + + // clear previous generated state. + InitFuncState_(f); + + PrintFuncDecorator_(stream); + PrintFunctionSignature_(function_name, f, stream); + stream << ":\n"; + + int func_scope = BeginScope(); + PreFunctionBody_(f); + PrintStmt_(f->body); + EndScope(func_scope); +} + +std::string CodeGenTileLangPY::Finish() { + std::ostringstream code; + code << decl_stream.str(); + code << stream.str(); + return code.str(); +} + +ffi::String CodeGenTileLangPY::GetFunctionName_(const GlobalVar &gvar) { + auto it = internal_functions_.find(gvar); + ICHECK(it != internal_functions_.end()) + << "Attempted to find name of " << gvar + << ", but no function with this GlobalVar has been declared"; + return it->second; +} + +void CodeGenTileLangPY::RegisterFunction_(const GlobalVar &gvar, + const PrimFunc &func) { + if (internal_functions_.count(gvar)) { + return; + } + + auto function_name = [&]() -> ffi::String { + if (auto global_symbol = + func->GetAttr(tvm::attr::kGlobalSymbol)) { + auto name = global_symbol.value(); + ICHECK(!func_name_supply_->ContainsName(name)) + << "Function " << gvar << " must use global symbol " << name + << ", but this name has already been used."; + func_name_supply_->ReserveName(name); + return name; + } else { + ICHECK(!func_name_supply_->ContainsName(gvar->name_hint)) + << "Function " << gvar << " must use name hint " << gvar->name_hint + << ", but this name has already been used."; + func_name_supply_->ReserveName(gvar->name_hint); + return gvar->name_hint; + } + }(); + internal_functions_.insert({gvar, function_name}); +} + +void CodeGenTileLangPY::InitFuncState_(const PrimFunc &f) { + alloc_storage_scope_.clear(); + handle_data_type_.clear(); + CodeGenSourceBase::ClearFuncState(); + ReserveKeywordsAsUnique_(); +} + +void CodeGenTileLangPY::PrintFunctionSignature_( + const ffi::String &function_name, const PrimFunc &func, + std::ostream &os) { // NOLINT(*) + os << "def " << function_name << "("; + for (size_t i = 0; i < func->params.size(); ++i) { + tir::Var v = func->params[i]; + if (i > 0) { + os << ", "; + } + os << AllocVarID(v.get()); + } + os << ")"; + + // Register handle data type + for (const auto ¶m : func->params) { + if (auto *ptr = param->type_annotation.as()) { + if (auto *prim = ptr->element_type.as()) { + RegisterHandleType_(param.get(), 
prim->dtype); + } + } + } +} + +void CodeGenTileLangPY::ReserveKeywordsAsUnique_() { + // skip the first underscore, so SSA variable starts from _1 + name_supply_->ReserveName("_"); + name_supply_->ReserveName("False"); + name_supply_->ReserveName("None"); + name_supply_->ReserveName("True"); + name_supply_->ReserveName("and"); + name_supply_->ReserveName("as"); + name_supply_->ReserveName("assert"); + name_supply_->ReserveName("async"); + name_supply_->ReserveName("await"); + name_supply_->ReserveName("break"); + name_supply_->ReserveName("class"); + name_supply_->ReserveName("continue"); + name_supply_->ReserveName("def"); + name_supply_->ReserveName("del"); + name_supply_->ReserveName("elif"); + name_supply_->ReserveName("else"); + name_supply_->ReserveName("except"); + name_supply_->ReserveName("finally"); + name_supply_->ReserveName("for"); + name_supply_->ReserveName("from"); + name_supply_->ReserveName("global"); + name_supply_->ReserveName("if"); + name_supply_->ReserveName("import"); + name_supply_->ReserveName("in"); + name_supply_->ReserveName("is"); + name_supply_->ReserveName("lambda"); + name_supply_->ReserveName("nonlocal"); + name_supply_->ReserveName("not"); + name_supply_->ReserveName("or"); + name_supply_->ReserveName("pass"); + name_supply_->ReserveName("raise"); + name_supply_->ReserveName("return"); + name_supply_->ReserveName("try"); + name_supply_->ReserveName("while"); + name_supply_->ReserveName("with"); + name_supply_->ReserveName("yield"); + + name_supply_->ReserveName("void"); + name_supply_->ReserveName("int"); + name_supply_->ReserveName("float"); + name_supply_->ReserveName("double"); + name_supply_->ReserveName("char"); + name_supply_->ReserveName("unsigned"); + name_supply_->ReserveName("short"); + name_supply_->ReserveName("long"); + + name_supply_->ReserveName("cutlass"); + name_supply_->ReserveName("cute"); + name_supply_->ReserveName("tl"); +} + +void CodeGenTileLangPY::PrintSSAAssign(const std::string &target, + const std::string &src, DataType t) { + stream << target << " = " << RemoveOutermostParentheses(src) << "\n"; +} + +void CodeGenTileLangPY::PrintType(DataType type, + std::ostream &os) { // NOLINT(*) + if (type.is_float()) { + if (type.bits() == 16 || type.bits() == 32 || type.bits() == 64) { + os << "float"; + } else { + LOG(FATAL) << "Cannot convert float" << type.bits() << " to Python type"; + } + } else if (type.is_uint()) { + switch (type.bits()) { + case 8: + case 16: + case 32: + case 64: { + os << "int"; + break; + } + case 1: + os << "bool"; + break; + default: + LOG(FATAL) << "Cannot convert uint" << type.bits() << " to Python type"; + } + } else if (type.is_int()) { + switch (type.bits()) { + case 8: + case 16: + case 32: + case 64: { + os << "int"; + break; + } + case 1: + os << "bool"; + break; + default: + LOG(FATAL) << "Cannot convert int" << type.bits() << " to Python type"; + } + } else { + LOG(FATAL) << "Cannot convert type " << type << " to Python type"; + } +} + +void CodeGenTileLangPY::VisitExpr_(const VarNode *op, + std::ostream &os) { // NOLINT(*) + os << GetVarID(op); +} + +void CodeGenTileLangPY::VisitExpr_(const IntImmNode *op, + std::ostream &os) { // NOLINT(*) + if (op->dtype == DataType::Bool()) { + os << (op->value ? 
"True" : "False"); + } else { + std::ostringstream temp; + temp << op->value; + MarkConst(temp.str()); + os << temp.str(); + } +} + +void CodeGenTileLangPY::VisitExpr_(const FloatImmNode *op, + std::ostream &os) { // NOLINT(*) + switch (op->dtype.bits()) { + case 64: + case 32: { + std::ostringstream temp; + temp << "float.fromhex('" << std::hexfloat << op->value << "')"; + MarkConst(temp.str()); + os << temp.str(); + break; + } + case 16: { + PrintType(op->dtype, os); + os << "(float.fromhex('" << std::hexfloat << op->value << "'))"; + break; + } + default: + LOG(FATAL) << "Bad bit-width for float: " << op->dtype << "\n"; + } +} + +void CodeGenTileLangPY::VisitExpr_(const StringImmNode *op, + std::ostream &os) { // NOLINT(*) + EscapeStringLiteral_(op->value, os); +} + +void CodeGenTileLangPY::VisitExpr_(const CastNode *op, + std::ostream &os) { // NOLINT(*) + std::stringstream value; + PrintExpr_(op->value, value); + os << CastFromTo_(value.str(), op->value.dtype(), op->dtype); +} + +void CodeGenTileLangPY::VisitExpr_(const AddNode *op, + std::ostream &os) { // NOLINT(*) + PrintBinaryExpr_("+", op->dtype, op->a, op->b, os); +} +void CodeGenTileLangPY::VisitExpr_(const SubNode *op, + std::ostream &os) { // NOLINT(*) + PrintBinaryExpr_("-", op->dtype, op->a, op->b, os); +} +void CodeGenTileLangPY::VisitExpr_(const MulNode *op, + std::ostream &os) { // NOLINT(*) + PrintBinaryExpr_("*", op->dtype, op->a, op->b, os); +} +void CodeGenTileLangPY::VisitExpr_(const DivNode *op, + std::ostream &os) { // NOLINT(*) + if (op->dtype.is_int() || op->dtype.is_uint()) { + PrintBinaryExpr_("//", op->dtype, op->a, op->b, os); + } else { + PrintBinaryExpr_("/", op->dtype, op->a, op->b, os); + } +} +void CodeGenTileLangPY::VisitExpr_(const ModNode *op, + std::ostream &os) { // NOLINT(*) + ICHECK(op->dtype.is_int() || op->dtype.is_uint() || op->dtype.is_float()) + << "Expected floating point or integer dtype in Mod, but got " + << op->dtype; + PrintBinaryExpr_("%", op->dtype, op->a, op->b, os); +} + +void CodeGenTileLangPY::VisitExpr_(const MinNode *op, + std::ostream &os) { // NOLINT(*) + PrintBinaryExpr_("min", op->dtype, op->a, op->b, os); +} +void CodeGenTileLangPY::VisitExpr_(const MaxNode *op, + std::ostream &os) { // NOLINT(*) + PrintBinaryExpr_("max", op->dtype, op->a, op->b, os); +} +void CodeGenTileLangPY::VisitExpr_(const EQNode *op, + std::ostream &os) { // NOLINT(*) + PrintBinaryExpr_("==", op->dtype, op->a, op->b, os); +} +void CodeGenTileLangPY::VisitExpr_(const NENode *op, + std::ostream &os) { // NOLINT(*) + PrintBinaryExpr_("!=", op->dtype, op->a, op->b, os); +} +void CodeGenTileLangPY::VisitExpr_(const LTNode *op, + std::ostream &os) { // NOLINT(*) + PrintBinaryExpr_("<", op->dtype, op->a, op->b, os); +} +void CodeGenTileLangPY::VisitExpr_(const LENode *op, + std::ostream &os) { // NOLINT(*) + PrintBinaryExpr_("<=", op->dtype, op->a, op->b, os); +} +void CodeGenTileLangPY::VisitExpr_(const GTNode *op, + std::ostream &os) { // NOLINT(*) + PrintBinaryExpr_(">", op->dtype, op->a, op->b, os); +} +void CodeGenTileLangPY::VisitExpr_(const GENode *op, + std::ostream &os) { // NOLINT(*) + PrintBinaryExpr_(">=", op->dtype, op->a, op->b, os); +} +void CodeGenTileLangPY::VisitExpr_(const AndNode *op, + std::ostream &os) { // NOLINT(*) + PrintBinaryExpr_("and", op->dtype, op->a, op->b, os); +} +void CodeGenTileLangPY::VisitExpr_(const OrNode *op, + std::ostream &os) { // NOLINT(*) + PrintBinaryExpr_("or", op->dtype, op->a, op->b, os); +} +void CodeGenTileLangPY::VisitExpr_(const NotNode *op, + 
std::ostream &os) { // NOLINT(*) + os << "(not "; + PrintExpr_(op->a, os); + os << ")"; +} + +void CodeGenTileLangPY::VisitExpr_(const SelectNode *op, + std::ostream &os) { // NOLINT(*) + os << "("; + PrintExpr_(op->true_value, os); + os << " if "; + PrintExpr_(op->condition, os); + os << " else "; + PrintExpr_(op->false_value, os); + os << ")"; +} + +void CodeGenTileLangPY::VisitExpr_(const RampNode *op, + std::ostream &os) { // NOLINT(*) + int lanes = op->dtype.lanes(); + os << "("; + for (int i = 0; i < lanes; i++) { + os << "(" << PrintExpr_(op->base) << ")" + << "+(" << PrintExpr_(op->stride) << "*" << i << ")"; + if (i != lanes - 1) + os << ", "; + } + os << ")"; +} + +void CodeGenTileLangPY::VisitExpr_(const CallNode *op, + std::ostream &os) { // NOLINT(*) + if (auto opt_call_op = op->op.as()) { + const auto &call_op = opt_call_op.value(); + + if (op->op.same_as(builtin::ret())) { + os << "return " << RemoveOutermostParentheses(PrintExpr_(op->args[0])); + } else if (op->op.same_as(builtin::continue_loop())) { + os << "continue"; + } else if (op->op.same_as(builtin::break_loop())) { + os << "break"; + } else if (op->op.same_as(builtin_call_extern_) || + op->op.same_as(builtin_call_pure_extern_)) { + ICHECK_GE(op->args.size(), 1U); + auto func = Downcast(op->args[0]); + PrintCallExtern_(GetType(ffi::GetRef(op)), func->value, + op->args, true, os); + } else if (op_attr_global_symbol_.count(call_op)) { + // call extern if the op itself have a global symbol. + PrintCallExtern_(GetType(ffi::GetRef(op)), + op_attr_global_symbol_[call_op], op->args, false, os); + } else if (op->op.same_as(builtin::large_uint_imm())) { + ICHECK_EQ(op->args.size(), 2U); + uint64_t low = + static_cast(Downcast(op->args[0])->value); + uint64_t high = + static_cast(Downcast(op->args[1])->value); + uint64_t val = (high << 32U) | low; + + if (op->dtype == DataType::UInt(32)) { + std::ostringstream temp; + temp << val; + MarkConst(temp.str()); + os << temp.str(); + } else { + PrintType(op->dtype, os); + os << "(" << val << ")"; + } + } else if (op->op.same_as(builtin::bitwise_and())) { + PrintBinaryIntrinsic_(op, "&", os); + } else if (op->op.same_as(builtin::bitwise_or())) { + PrintBinaryIntrinsic_(op, "|", os); + } else if (op->op.same_as(builtin::bitwise_xor())) { + PrintBinaryIntrinsic_(op, "^", os); + } else if (op->op.same_as(builtin::bitwise_not())) { + ICHECK_EQ(op->args.size(), 1U); + os << "~"; + PrintExpr_(op->args[0], os); + } else if (op->op.same_as(builtin::shift_left())) { + PrintBinaryIntrinsic_(op, "<<", os); + } else if (op->op.same_as(builtin::shift_right())) { + PrintBinaryIntrinsic_(op, ">>", os); + } else if (op->op.same_as(builtin::if_then_else())) { + + std::string cond = PrintExpr_(op->args[0]); + std::string true_val = PrintExpr_(op->args[1]); + std::string false_val = PrintExpr_(op->args[2]); + os << "(" << true_val << " if " << cond << " else " << false_val << ")"; + } else if (op->op.same_as(builtin::isnullptr())) { + ICHECK_EQ(op->args.size(), 1U); + os << "("; + PrintExpr_(op->args[0], os); + os << " is None)"; + } else if (op->op.same_as(builtin::isnan())) { + os << "("; + PrintExpr_(op->args[0], os); + os << " != "; + PrintExpr_(op->args[0], os); + os << ")"; + } else { + LOG(FATAL) << "Unresolved call " << op->op; + } + } else if (auto opt = op->op.as()) { + const auto &gvar = opt.value(); + auto callee_name = GetFunctionName_(gvar); + PrintCallExtern_(GetType(ffi::GetRef(op)), callee_name, op->args, + false, os); + } else { + LOG(FATAL) + << "CodeGenTileLangPY: Unknown operation " 
<< op->op + << " is neither a recognized built-in, " + << "nor a GlobalVar reference to another function in the IRModule"; + } +} + +void CodeGenTileLangPY::VisitExpr_(const BufferLoadNode *op, + std::ostream &os) { // NOLINT(*) + ICHECK_EQ(op->indices.size(), 1) + << "Load from non-flat memory not supported."; + ICHECK(!op->predicate.defined()) + << "Predicated buffer load is not supported."; + + DataType value_dtype = op->dtype; + PrimExpr index = op->indices[0]; + Var buffer_var = op->buffer->data; + DataType element_dtype = op->buffer->dtype; + + ICHECK_EQ(value_dtype, element_dtype) + << "value_dtype and element_dtype must be same for a BufferLoadNode"; + std::string ref = GetBufferRef_(op->dtype, op->buffer.get(), index); + os << ref; +} + +void CodeGenTileLangPY::VisitStmt_(const BufferStoreNode *op) { + ICHECK_EQ(op->indices.size(), 1) << "Store to non-flat memory not supported."; + ICHECK(!op->predicate.defined()) + << "Predicated buffer store is not supported."; + + DataType value_dtype = op->value.dtype(); + DataType element_dtype = op->buffer->dtype; + PrimExpr index_expr = op->indices[0]; + Var buffer_var = op->buffer->data; + + ICHECK_EQ(value_dtype, element_dtype) + << "value_dtype and element_dtype must be same for a BufferStoreNode"; + std::string value = PrintExpr_(op->value); + std::string ref = GetBufferRef_(value_dtype, op->buffer.get(), index_expr); + PrintIndent(); + stream << ref << " = " << RemoveOutermostParentheses(value) << "\n"; +} + +void CodeGenTileLangPY::VisitStmt_(const DeclBufferNode *op) { + PrintStmt_(op->body); +} + +void CodeGenTileLangPY::VisitStmt_(const LetStmtNode *op) { + std::string value = PrintExpr_(op->value); + PrintIndent(); + stream << AllocVarID(op->var.get()) << " = " << value << "\n"; + PrintStmt_(op->body); +} + +void CodeGenTileLangPY::VisitStmt_(const AllocateNode *op) { + ICHECK(!is_zero(op->condition)); + std::string vid = AllocVarID(op->buffer_var.get()); + + PrintIndent(); + size_t constant_size = op->ConstantAllocationSize(); + ICHECK_GT(constant_size, 0) + << "Can only handle constant size stack allocation for now"; + + auto scope = GetPtrStorageScope(op->buffer_var); + alloc_storage_scope_[op->buffer_var.get()] = scope; + + stream << vid << " = [None] * " << constant_size << "\n"; + + RegisterHandleType_(op->buffer_var.get(), op->dtype); + PrintStmt_(op->body); +} + +void CodeGenTileLangPY::VisitStmt_(const AttrStmtNode *op) { + PrintStmt_(op->body); +} + +void CodeGenTileLangPY::VisitStmt_(const ForNode *op) { + PrintIndent(); + std::string vid = AllocVarID(op->loop_var.get()); + stream << "for " << vid << " in range("; + if (is_zero(op->min)) { + PrintExpr_(op->extent, stream); + } else { + PrintExpr_(op->min, stream); + stream << ", "; + PrimExpr upper_bound = arith::Analyzer().Simplify(op->extent + op->min); + PrintExpr_(upper_bound, stream); + } + stream << "):\n"; + int for_scope = BeginScope(); + PrintStmt_(op->body); + EndScope(for_scope); +} + +void CodeGenTileLangPY::VisitStmt_(const WhileNode *op) { + std::string cond = PrintExpr_(op->condition); + PrintIndent(); + stream << "while " << RemoveOutermostParentheses(cond) << ":\n"; + int while_scope = BeginScope(); + PrintStmt_(op->body); + EndScope(while_scope); +} + +void CodeGenTileLangPY::VisitStmt_(const IfThenElseNode *op) { + std::string cond = PrintExpr_(op->condition); + PrintIndent(); + stream << "if " << RemoveOutermostParentheses(cond) << ":\n"; + int then_scope = BeginScope(); + PrintStmt_(op->then_case); + EndScope(then_scope); + + if (op->else_case) { + 
PrintIndent(); + stream << "else:\n"; + int else_scope = BeginScope(); + PrintStmt_(op->else_case.value()); + EndScope(else_scope); + } +} + +void CodeGenTileLangPY::VisitStmt_(const SeqStmtNode *op) { + for (Stmt stmt : op->seq) { + PrintStmt_(stmt); + } +} + +void CodeGenTileLangPY::VisitStmt_(const EvaluateNode *op) { + if (is_const_int(op->value)) + return; + + std::string vid = PrintExpr_(op->value); + if (!vid.empty()) { + PrintIndent(); + stream << vid << "\n"; + } +} + +void CodeGenTileLangPY::VisitStmt_(const AssertStmtNode *op) { + std::string cond = PrintExpr_(op->condition); + PrintIndent(); + if (const auto *str = op->message.as()) { + stream << "assert " << cond << ", "; + EscapeStringLiteral_(str->value, stream); + stream << "\n"; + } else { + stream << "assert " << cond << "\n"; + } + PrintStmt_(op->body); +} + +std::string CodeGenTileLangPY::CastFromTo_(const std::string &value, + DataType from, DataType target) { + if (from == target) + return value; + std::ostringstream os; + PrintType(target, os); + os << "(" << value << ")"; + return os.str(); +} + +void CodeGenTileLangPY::PrintBinaryExpr_(const std::string &opstr, + DataType dtype, PrimExpr lhs, + PrimExpr rhs, + std::ostream &os) { // NOLINT(*) + ICHECK_EQ(dtype.lanes(), 1); + if (isalpha(opstr[0]) && opstr != "and" && opstr != "or") { + os << opstr << '('; + PrintExpr_(lhs, os); + os << ", "; + PrintExpr_(rhs, os); + os << ')'; + } else { + os << '('; + PrintExpr_(lhs, os); + os << ' ' << opstr << ' '; + PrintExpr_(rhs, os); + os << ')'; + } +} + +void CodeGenTileLangPY::PrintBinaryIntrinsic_(const CallNode *op, + const char *opstr, + std::ostream &os) { // NOLINT(*) + ICHECK_EQ(op->dtype.lanes(), 1); + ICHECK_EQ(op->args.size(), 2U); + os << '('; + PrintExpr_(op->args[0], os); + os << ' ' << opstr << ' '; + PrintExpr_(op->args[1], os); + os << ')'; +} + +void CodeGenTileLangPY::PrintCallExtern_(Type ret_type, + ffi::String global_symbol, + const ffi::Array &args, + bool skip_first_arg, + std::ostream &os) { // NOLINT(*) + os << global_symbol << "("; + for (size_t i = static_cast(skip_first_arg); i < args.size(); ++i) { + PrintExpr_(args[i], os); + if (i < args.size() - 1) { + os << ", "; + } + } + os << ")"; +} + +// Print a reference expression to a buffer. 
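The expression and statement visitors above map TIR onto plain Python syntax: `Select` and `if_then_else` become conditional expressions, integer `Div` prints as floor division, `Min`/`Max` call the Python builtins, `And`/`Or` print as `and`/`or`, and `For`/`Allocate` lower to `range` loops over plain lists. A minimal hand-written sketch of that target surface (variable names are illustrative, not actual generated output):

```python
# Sketch only: mirrors the Python constructs the visitors emit, assuming
# illustrative buffers `a`, `acc` and bound `n` that are not real output.
n = 10
a = [float(v) for v in range(16)]
acc = [0.0]                           # AllocateNode lowers to a `[None] * size`-style list
for i in range(16):                   # ForNode: `range(extent)` when the loop min is 0
    x = (a[i] if (i < n) else 0.0)    # SelectNode / if_then_else -> conditional expression
    acc[0] = acc[0] + min(x, 8.0)     # MinNode prints as the builtin `min(...)`
    q = (i // 4)                      # integer Div prints as floor division
```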
+std::string CodeGenTileLangPY::GetBufferRef_(DataType t, + const BufferNode *buffer, + PrimExpr index) { + const VarNode *buffer_var = buffer->data.get(); + std::string vid = GetVarID(buffer_var); + DataType buffer_element_dtype = buffer->dtype; + + ICHECK(HandleTypeMatch_(buffer_var, buffer_element_dtype)); + ICHECK_EQ(t, buffer_element_dtype); + + std::string index_str = PrintExpr_(index); + return vid + "[" + index_str + "]"; +} + +void CodeGenTileLangPY::RegisterHandleType_(const VarNode *buf_var, + DataType t) { + auto it = handle_data_type_.find(buf_var); + if (it == handle_data_type_.end()) { + handle_data_type_[buf_var] = t; + } else { + ICHECK(it->second == t) << "conflicting buf var type"; + } +} + +bool CodeGenTileLangPY::HandleTypeMatch_(const VarNode *buf_var, + DataType t) const { + auto it = handle_data_type_.find(buf_var); + if (it == handle_data_type_.end()) + return false; + return it->second == t; +} + +void CodeGenTileLangPY::EscapeStringLiteral_(const std::string &s, + std::ostream &os) { + os << '"'; + for (unsigned char c : s) { + switch (c) { + case '\\': + os << "\\\\"; + break; + case '"': + os << "\\\""; + break; + case '\n': + os << "\\n"; + break; + case '\r': + os << "\\r"; + break; + case '\t': + os << "\\t"; + break; + case '\f': + os << "\\f"; + break; + case '\b': + os << "\\b"; + break; + default: + // Handle non-printable and non-ASCII characters + if (c < 32 || c == 127) { + // Output as \xHH + os << "\\x"; + const char hex[] = "0123456789abcdef"; + os << hex[(c >> 4) & 0xF]; + os << hex[c & 0xF]; + } else { + os << c; + } + break; + } + } + os << '"'; +} + +} // namespace codegen +} // namespace tvm diff --git a/src/target/codegen_py.h b/src/target/codegen_py.h new file mode 100644 index 000000000..431fe933d --- /dev/null +++ b/src/target/codegen_py.h @@ -0,0 +1,255 @@ +/*! + * \file codegen_py.h + * \brief Common utilities to generate simple Python code. + */ +#ifndef TVM_TL_TARGET_CODEGEN_PY_H_ +#define TVM_TL_TARGET_CODEGEN_PY_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +// from tvm/src/ +#include "target/source/codegen_source_base.h" +#include "tir/transforms/ir_utils.h" + +namespace tvm { +namespace codegen { + +using namespace tir; +/*! + * \brief A base class to generate simple Python code. + */ +class CodeGenTileLangPY + : public ExprFunctor, + public StmtFunctor, + public CodeGenSourceBase { +public: + /*! + * \brief Add the function definition to the generated module. + * \param gvar The GlobalVar representing the function. + * \param func The function to be compiled. + */ + virtual void AddFunction(const GlobalVar &gvar, const PrimFunc &func); + + /*! + * \brief Finalize the compilation and return the code. + * \return The code. + */ + virtual std::string Finish(); + +protected: + /*! + * \brief Get the name of a declared function + * \param gvar The GlobalVar of the function + * \returns The string name of the function + */ + ffi::String GetFunctionName_(const GlobalVar &gvar); + + /*! + * \brief Reserve the function name in the generated module. + * + * \param gvar The GlobalVar representing the function. + * \param func The function to be compiled. + * \param whether to append return 0 in the end. + */ + virtual void RegisterFunction_(const GlobalVar &gvar, const PrimFunc &func); + + /*! + * \brief Initialize codegen state for generating f. + * \param f The function to be compiled. + */ + virtual void InitFuncState_(const PrimFunc &f); + + /*! 
\brief Print the function signature before ":" + * \param function_name The name of the function + * \param func The function whose signature should be printed + * \param os The output stream + */ + virtual void PrintFunctionSignature_(const ffi::String &function_name, + const PrimFunc &func, + std::ostream &os); // NOLINT(*) + + /*! + * \brief Print the function decorator + * \param os The output stream + */ + virtual void PrintFuncDecorator_(std::ostream &os) {} // NOLINT(*) + + /*! + * \brief Insert statement before function body. + * \param f The function to be compiled. + */ + virtual void PreFunctionBody_(const PrimFunc &f) {} + +protected: + /*! \brief reserves common Python keywords */ + void ReserveKeywordsAsUnique_(); + + void PrintSSAAssign(const std::string &target, const std::string &src, + DataType t) override; + +protected: + /*! + * \brief Print Type representation of type type. + * \param t The type representation. + * \param os The output stream + */ + void PrintType(DataType type, std::ostream &os) override; // NOLINT(*) + + /*! + * \brief Print the Stmt n to CodeGenTileLangPY->stream + * \param n The statement to be printed. + */ + void PrintStmt_(const Stmt &n) { VisitStmt(n); } + /*! + * \brief Print the expression n into os + * \param n The expression to be printed. + * \param os The output stream + */ + void PrintExpr_(const PrimExpr &n, std::ostream &os) { // NOLINT(*) + VisitExpr(n, os); + } + /*! + * \brief Same as PrintExpr_, but simply returns result string + * \param n The expression to be printed. + */ + std::string PrintExpr_(const PrimExpr &n) { + std::ostringstream os; + PrintExpr_(n, os); + return os.str(); + } + + // expression + void VisitExpr_(const VarNode *op, std::ostream &os) override; // NOLINT(*) + void VisitExpr_(const IntImmNode *op, std::ostream &os) override; // NOLINT(*) + void VisitExpr_(const FloatImmNode *op, + std::ostream &os) override; // NOLINT(*) + void VisitExpr_(const StringImmNode *op, + std::ostream &os) override; // NOLINT(*) + void VisitExpr_(const CastNode *op, std::ostream &os) override; // NOLINT(*) + void VisitExpr_(const AddNode *op, std::ostream &os) override; // NOLINT(*) + void VisitExpr_(const SubNode *op, std::ostream &os) override; // NOLINT(*) + void VisitExpr_(const MulNode *op, std::ostream &os) override; // NOLINT(*) + void VisitExpr_(const DivNode *op, std::ostream &os) override; // NOLINT(*) + void VisitExpr_(const ModNode *op, std::ostream &os) override; // NOLINT(*) + void VisitExpr_(const MinNode *op, std::ostream &os) override; // NOLINT(*) + void VisitExpr_(const MaxNode *op, std::ostream &os) override; // NOLINT(*) + void VisitExpr_(const EQNode *op, std::ostream &os) override; // NOLINT(*) + void VisitExpr_(const NENode *op, std::ostream &os) override; // NOLINT(*) + void VisitExpr_(const LTNode *op, std::ostream &os) override; // NOLINT(*) + void VisitExpr_(const LENode *op, std::ostream &os) override; // NOLINT(*) + void VisitExpr_(const GTNode *op, std::ostream &os) override; // NOLINT(*) + void VisitExpr_(const GENode *op, std::ostream &os) override; // NOLINT(*) + void VisitExpr_(const AndNode *op, std::ostream &os) override; // NOLINT(*) + void VisitExpr_(const OrNode *op, std::ostream &os) override; // NOLINT(*) + void VisitExpr_(const NotNode *op, std::ostream &os) override; // NOLINT(*) + void VisitExpr_(const SelectNode *op, std::ostream &os) override; // NOLINT(*) + void VisitExpr_(const RampNode *op, std::ostream &os) override; // NOLINT(*) + void VisitExpr_(const CallNode *op, std::ostream 
&os) override; // NOLINT(*) + void VisitExpr_(const BufferLoadNode *op, + std::ostream &os) override; // NOLINT(*) + + // statment + void VisitStmt_(const BufferStoreNode *op) override; + void VisitStmt_(const DeclBufferNode *op) override; + void VisitStmt_(const LetStmtNode *op) override; + void VisitStmt_(const AllocateNode *op) override; + void VisitStmt_(const AttrStmtNode *op) override; + void VisitStmt_(const ForNode *op) override; + void VisitStmt_(const WhileNode *op) override; + void VisitStmt_(const IfThenElseNode *op) override; + void VisitStmt_(const SeqStmtNode *op) override; + void VisitStmt_(const EvaluateNode *op) override; + void VisitStmt_(const AssertStmtNode *op) override; + +protected: + // Get a string of type casting + virtual std::string CastFromTo_(const std::string &value, DataType from, + DataType target); + + virtual void PrintBinaryExpr_(const std::string &opstr, DataType dtype, + PrimExpr lhs, PrimExpr rhs, + std::ostream &os); // NOLINT(*) + virtual void PrintBinaryIntrinsic_(const CallNode *op, const char *opstr, + std::ostream &os); // NOLINT(*) + + /*! + * \brief Print external function call. + * \param ret_type The return type. + * \param global_symbol The symbolc of the target function. + * \param args The arguments to the function. + * \param skip_first_arg Whether to skip the first arguments. + * \param os The output stream. + */ + virtual void PrintCallExtern_(Type ret_type, ffi::String global_symbol, + const ffi::Array &args, + bool skip_first_arg, + std::ostream &os); // NOLINT(*) + + // Print reference to a buffer as type t in index. + virtual std::string GetBufferRef_(DataType t, const BufferNode *buffer, + PrimExpr index); + + /*! + * \brief Register the data type of buf_var + * \param buf_var The buffer variable. + * \param t The type to be checked. + */ + void RegisterHandleType_(const VarNode *buf_var, DataType t); + + /*! + * \brief If buffer is allocated as type t. + * \param buf_var The buffer variable. + * \param t The type to be checked. + */ + bool HandleTypeMatch_(const VarNode *buf_var, DataType t) const; + +protected: + /*! \brief the storage scope of allocation */ + std::unordered_map alloc_storage_scope_; + + /*! \brief Record of ops that have pre-defined global symbol. */ + OpAttrMap op_attr_global_symbol_ = + Op::GetAttrMap("TGlobalSymbol"); + + // cache commonly used ops + const Op &builtin_call_extern_ = builtin::call_extern(); + const Op &builtin_call_pure_extern_ = builtin::call_pure_extern(); + +private: + /*! \brief the data type of allocated buffers */ + std::unordered_map handle_data_type_; + + /* \brief Map of GlobalVar to their symbol. + * + * For externally-exposed functions, this is given by the + * tvm::attr::kTarget attribute of the PrimFunc. For internal + * functions, this is the name of the function's GlobalVar, possibly + * altered to prevent duplicate names. + */ + std::unordered_map internal_functions_; + + /* \brief Name supply to generate unique function names */ + NameSupply func_name_supply_; + + /*! + * \brief Escape a string to be a valid Python double-quoted string literal. + * \param s The input string to escape. + * \param os The output stream to write the escaped string to. 
+ */ + void EscapeStringLiteral_(const std::string &s, std::ostream &os); +}; + +} // namespace codegen +} // namespace tvm +#endif // TVM_TL_TARGET_CODEGEN_PY_H_ diff --git a/src/target/codegen_utils.cc b/src/target/codegen_utils.cc new file mode 100644 index 000000000..75d038d3a --- /dev/null +++ b/src/target/codegen_utils.cc @@ -0,0 +1,41 @@ +/*! + * \file target/codegen_utils.cc + * \brief Shared utility functions for code generation + */ + +#include "codegen_utils.h" + +namespace tvm { +namespace codegen { + +bool CheckOutermostParenthesesMatch(const std::string &s) { + if (!s.empty() && s.front() == '(' && s.back() == ')') { + size_t len = s.size(); + int n_unmatched = 0; + for (size_t i = 0; i < len; ++i) { + if (s[i] == '(') { + n_unmatched++; + } else if (s[i] == ')') { + n_unmatched--; + } + if (n_unmatched < 0) { + return false; + } + if (n_unmatched == 0) { + return i == len - 1; + } + } + } + return false; +} + +std::string RemoveOutermostParentheses(const std::string &s) { + if (CheckOutermostParenthesesMatch(s)) { + return s.substr(1, s.size() - 2); + } else { + return s; + } +} + +} // namespace codegen +} // namespace tvm diff --git a/src/target/codegen_utils.h b/src/target/codegen_utils.h new file mode 100644 index 000000000..1ef52d4b1 --- /dev/null +++ b/src/target/codegen_utils.h @@ -0,0 +1,33 @@ +/*! + * \file target/codegen_utils.h + * \brief Shared utility functions for code generation + */ + +#ifndef TVM_TARGET_CODEGEN_UTILS_H_ +#define TVM_TARGET_CODEGEN_UTILS_H_ + +#include + +namespace tvm { +namespace codegen { + +/*! + * \brief Check if the outermost parentheses match + * \param s The input string + * \return true if the first character is '(' and the last character is ')' + * and they form a matching pair + */ +bool CheckOutermostParenthesesMatch(const std::string &s); + +/*! 
+ * \brief Remove outermost parentheses if they match + * \param s The input string + * \return The string with outermost parentheses removed if they match, + * otherwise return the original string + */ +std::string RemoveOutermostParentheses(const std::string &s); + +} // namespace codegen +} // namespace tvm + +#endif // TVM_TARGET_CODEGEN_UTILS_H_ diff --git a/src/target/rt_mod_cutedsl.cc b/src/target/rt_mod_cutedsl.cc new file mode 100644 index 000000000..a2b6d05d1 --- /dev/null +++ b/src/target/rt_mod_cutedsl.cc @@ -0,0 +1,69 @@ +#include "codegen_cutedsl.h" +#include "runtime/cuda/cuda_module.h" +#include "runtime/pack_args.h" +#include + +namespace tvm { +namespace codegen { + +static std::unordered_map +ExtractFuncInfo(const IRModule &mod) { + std::unordered_map fmap; + + for (auto kv : mod->functions) { + ICHECK(kv.second->IsInstance()) + << "Can only lower IR Module with PrimFuncs"; + auto f = Downcast(kv.second); + + runtime::FunctionInfo info; + for (size_t i = 0; i < f->params.size(); ++i) { + if (f->params[i]->dtype.is_handle()) { + auto ptr = f->params[i]->type_annotation.as(); + if (ptr && ptr->storage_scope == "grid_constant") { + info.arg_types.push_back(DataType(runtime::kDLGridConstant, 64, 1)); + continue; + } + } + info.arg_types.push_back(f->params[i].dtype()); + } + if (auto opt = f->GetAttr>( + tir::attr::kKernelLaunchParams)) { + for (const auto &tag : opt.value()) { + info.launch_param_tags.push_back(tag); + } + } + auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); + fmap[static_cast(global_symbol.value())] = info; + } + return fmap; +} + +ffi::Module BuildTileLangCuTeDSLWithoutCompile(IRModule mod, Target target) { + CodeGenTileLangCuTeDSL cg; + + for (auto kv : mod->functions) { + ICHECK(kv.second->IsInstance()) + << "CodeGenTileLangCuTeDSL: Can only take PrimFunc"; + auto gvar = Downcast(kv.first); + auto f = Downcast(kv.second); + auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); + ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch); + cg.AddFunction(gvar, f); + } + + std::string code = cg.Finish(); + if (const auto f = + ffi::Function::GetGlobal("tilelang_callback_cutedsl_postproc")) { + code = (*f)(code, target).cast(); + } + return runtime::CUDAModuleCreate("ptx", "ptx", ExtractFuncInfo(mod), code); +} + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("target.build.tilelang_cutedsl_without_compile", + BuildTileLangCuTeDSLWithoutCompile); +} + +} // namespace codegen +} // namespace tvm diff --git a/src/tl_templates/cuda/nvrtc_std.h b/src/tl_templates/cuda/nvrtc_std.h index 1e6800e51..34cd58bb2 100644 --- a/src/tl_templates/cuda/nvrtc_std.h +++ b/src/tl_templates/cuda/nvrtc_std.h @@ -173,4 +173,4 @@ template inline constexpr size_t extent_v = extent::value; } // namespace std -#endif \ No newline at end of file +#endif // __CUDACC_RTC__ diff --git a/testing/python/jit/test_tilelang_jit_cutedsl.py b/testing/python/jit/test_tilelang_jit_cutedsl.py new file mode 100644 index 000000000..7c613c4d1 --- /dev/null +++ b/testing/python/jit/test_tilelang_jit_cutedsl.py @@ -0,0 +1,381 @@ +from tilelang import tvm as tvm +import tilelang.language as T +import tilelang.testing +import tilelang +import torch +from tilelang.utils.tensor import map_torch_type + + +def matmul( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + accum_dtype, + num_stages, + threads, +): + A_shape = (K, M) if trans_A else (M, K) + B_shape = (N, K) if trans_B else (K, N) + 
A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) + B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) + + @T.prim_func + def main( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, in_dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + T.clear(C_local) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + if trans_A: + T.copy(A[k * block_K, by * block_M], A_shared) + else: + T.copy(A[by * block_M, k * block_K], A_shared) + if trans_B: + T.copy(B[bx * block_N, k * block_K], B_shared) + else: + T.copy(B[k * block_K, bx * block_N], B_shared) + T.gemm(A_shared, B_shared, C_local, trans_A, trans_B) + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + + +def run_gemm( + M, + N, + K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + block_M, + block_N, + block_K, + num_stages=3, + num_threads=128, +): + program = matmul( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + num_stages, + num_threads, + ) + + stramp = "&*(XS)" + + @tvm.register_global_func("tilelang_callback_cutedsl_postproc", override=True) + def tilelang_callback_cutedsl_postproc(code, _): + code = f"# {stramp}\n" + code + return code + + matmul_kernel = tilelang.compile(program, out_idx=-1, target="cutedsl") + + kernel_source = matmul_kernel.get_kernel_source() + + assert stramp in kernel_source, f"Expected {stramp} in the kernel source" + + +def test_gemm_f16f16f16_nn(): + run_gemm( + 512, + 1024, + 768, + False, + False, + "float16", + "float16", + "float16", + 128, + 256, + 32, + 2, + ) + + +def matmul_jit_kernel( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + accum_dtype, + num_stages, + threads, +): + A_shape = (K, M) if trans_A else (M, K) + B_shape = (N, K) if trans_B else (K, N) + A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) + B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) + + @T.prim_func + def main( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, in_dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + T.clear(C_local) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + if trans_A: + T.copy(A[k * block_K, by * block_M], A_shared) + else: + T.copy(A[by * block_M, k * block_K], A_shared) + if trans_B: + T.copy(B[bx * block_N, k * block_K], B_shared) + else: + T.copy(B[k * block_K, bx * block_N], B_shared) + T.gemm(A_shared, B_shared, C_local, trans_A, trans_B) + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + + +def run_gemm_jit_kernel( + M, + N, + K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + block_M, + block_N, + block_K, + num_stages=3, + num_threads=128, +): + program = matmul_jit_kernel( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + num_stages, + num_threads, + ) + + matmul_kernel = tilelang.compile(program, out_idx=-1, 
target="cutedsl") + + in_dtype = map_torch_type(in_dtype) + out_dtype = map_torch_type(out_dtype) + + A = torch.randn(M, K, dtype=in_dtype).cuda() + B = torch.randn(K, N, dtype=in_dtype).cuda() + + if trans_A: + A = A.T + if trans_B: + B = B.T + + def ref_program(A, B): + import torch + + C = torch.matmul(A.to(torch.float), B.to(torch.float)) + C = C.to(out_dtype) + return C + + ref_C = ref_program(A, B) + C = matmul_kernel(A, B) + + tilelang.testing.torch_assert_close(C, ref_C, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05) + + +def test_gemm_jit_kernel(): + run_gemm_jit_kernel( + 512, + 1024, + 768, + False, + False, + "float16", + "float16", + "float16", + 128, + 256, + 32, + 2, + ) + + +def run_cutedsl_kernel_do_bench( + M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages=3, num_threads=128 +): + program = matmul( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + num_stages, + num_threads, + ) + + matmul_kernel = tilelang.compile(program, target="cutedsl") + + profiler = matmul_kernel.get_profiler() + + cutedsl_latency = profiler.do_bench(func=matmul_kernel) + print(f"CuTeDSL Latency: {cutedsl_latency} ms") + + assert cutedsl_latency is not None + + tvm_latency = profiler.do_bench() + print(f"TVM Latency: {tvm_latency} ms") + + assert tvm_latency is not None + + +def test_cutedsl_kernel_do_bench(): + run_cutedsl_kernel_do_bench(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2) + + +def run_cutedsl_kernel_multi_stream( + M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages=3, num_threads=128 +): + program = matmul( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + num_stages, + num_threads, + ) + + matmul_kernel = tilelang.compile(program, target="cutedsl") + in_dtype = map_torch_type(in_dtype) + out_dtype = map_torch_type(out_dtype) + tensor_a = torch.randn(M, K, dtype=in_dtype).cuda() + tensor_b = torch.randn(K, N, dtype=in_dtype).cuda() + + if trans_A: + tensor_a = tensor_a.T + if trans_B: + tensor_b = tensor_b.T + tensor_c = torch.randn(M, N, dtype=out_dtype).cuda() + + num_streams = 4 + for _ in range(num_streams): + stream = torch.cuda.Stream() + with torch.cuda.stream(stream): + matmul_kernel(tensor_a, tensor_b, tensor_c) + + +def test_cutedsl_kernel_multi_stream(): + run_cutedsl_kernel_multi_stream(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2) + + +def run_cutedsl_dynamic_shape( + M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages=3, num_threads=128 +): + program = matmul( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + num_stages, + num_threads, + ) + + matmul_kernel = tilelang.compile(program, target="cutedsl") + if isinstance(M, T.Var): + M = 1024 + if isinstance(N, T.Var): + N = 1024 + if isinstance(K, T.Var): + K = 768 + + in_dtype = map_torch_type(in_dtype) + out_dtype = map_torch_type(out_dtype) + + tensor_a = torch.randn(M, K, dtype=in_dtype).cuda() + tensor_b = torch.randn(K, N, dtype=in_dtype).cuda() + + if trans_A: + tensor_a = tensor_a.T + if trans_B: + tensor_b = tensor_b.T + tensor_c = torch.randn(M, N, dtype=out_dtype).cuda() + + matmul_kernel(tensor_a, tensor_b, tensor_c) + + tensor_ref_c = torch.matmul(tensor_a.to(torch.float), tensor_b.to(torch.float)).to(out_dtype) + 
tilelang.testing.torch_assert_close(tensor_c, tensor_ref_c, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05) + + +def test_cutedsl_dynamic_shape(): + run_cutedsl_dynamic_shape(T.dynamic("m"), 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2) + + run_cutedsl_dynamic_shape(T.dynamic("m"), T.dynamic("n"), 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2) + + run_cutedsl_dynamic_shape( + T.dynamic("m"), T.dynamic("n"), T.dynamic("k"), False, False, "float16", "float16", "float16", 128, 256, 32, 2 + ) + + +def check_hopper(): + if not torch.cuda.is_available(): + return False + props = torch.cuda.get_device_properties(0) + compute_capability = props.major, props.minor + return compute_capability == (9, 0) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/cache/kernel_cache.py b/tilelang/cache/kernel_cache.py index 58295406e..cf6a5591b 100644 --- a/tilelang/cache/kernel_cache.py +++ b/tilelang/cache/kernel_cache.py @@ -29,6 +29,11 @@ KERNEL_PY_PATH = "kernel.py" PARAMS_PATH = "params.pkl" +# CuTeDSL C++ launcher specific +LAUNCHER_LIB_PATH = "launcher_lib.so" +LAUNCHER_CPP_PATH = "launcher.cpp" +CUTEDSL_CUBIN_PATH = "kernel.cubin" + class KernelCache: """ @@ -43,7 +48,7 @@ class KernelCache: _instance = None # For implementing singleton pattern _lock = threading.Lock() # For thread safety _memory_cache = {} # In-memory cache dictionary - execution_backend: Literal["tvm_ffi", "ctypes", "cython", "nvrtc", "torch"] = "tvm_ffi" + execution_backend: Literal["tvm_ffi", "ctypes", "cython", "nvrtc", "torch", "cutedsl"] = "tvm_ffi" def __new__(cls): """ @@ -72,7 +77,7 @@ def _generate_key( self, func: Callable, out_idx: list[int], - execution_backend: Literal["tvm_ffi", "ctypes", "cython", "nvrtc", "torch"] = "tvm_ffi", + execution_backend: Literal["tvm_ffi", "ctypes", "cython", "nvrtc", "torch", "cutedsl"] = "tvm_ffi", args=None, target: str | Target = "auto", target_host: str | Target = None, @@ -85,7 +90,7 @@ def _generate_key( Args: func (Callable): The function to be compiled. out_idx (List[int]): Indices specifying which outputs to return. - execution_backend (Literal): Backend type for execution. Defaults to "cython". + execution_backend (Literal): Backend type for execution. Defaults to "tvm_ffi". args: Arguments passed to the function. target (Union[str, Target]): Compilation target platform. Defaults to "auto". target_host (Union[str, Target], optional): Host target platform. 
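For the `cutedsl` backend, the cached artifact of record is the generated Python module (`kernel.py`) plus a separately compiled C++ launcher (`launcher_lib.so`), rather than a single shared library or cubin; the cubin is only written back after the first execution. A small illustrative helper (not part of this patch) that mirrors the completeness check the loader changes below perform on a `cutedsl` cache entry:

```python
from pathlib import Path

# Illustrative helper, not part of this patch: file names follow the cache
# constants above (KERNEL_PY_PATH, PARAMS_PATH, LAUNCHER_LIB_PATH).
REQUIRED_CUTEDSL_FILES = ("kernel.py", "params.pkl", "launcher_lib.so")


def cutedsl_cache_entry_is_complete(cache_dir: str) -> bool:
    """Return True only if every artifact a cutedsl cache hit needs is on disk."""
    cache_path = Path(cache_dir)
    return all((cache_path / name).exists() for name in REQUIRED_CUTEDSL_FILES)
```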
@@ -118,7 +123,7 @@ def cached( *args, target: str | Target = "auto", target_host: str | Target = None, - execution_backend: Literal["auto", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch"] = "auto", + execution_backend: Literal["auto", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch", "cutedsl"] = "auto", verbose: bool = False, pass_configs: dict = None, compile_flags: list[str] | str | None = None, @@ -217,7 +222,11 @@ def cached( ) with self._lock: if env.is_cache_enabled(): + cache_path = self._get_cache_path(key) self._save_kernel_to_disk(key, kernel, func, verbose) + # Set cache path on adapter so it can save cubin after first execution + if hasattr(kernel, "adapter") and execution_backend == "cutedsl": + kernel.adapter._cache_path = cache_path # Store in memory cache after compilation self._memory_cache[key] = kernel @@ -287,59 +296,83 @@ def _save_kernel_to_disk(self, key: str, kernel: JITKernel, func: Callable = Non # Save kernel source code try: - device_kernel_path = os.path.join(cache_path, DEVICE_KERNEL_PATH) - if verbose: - self.logger.debug(f"Saving kernel source code to file: {device_kernel_path}") - if kernel.kernel_source is not None: - KernelCache._safe_write_file(device_kernel_path, "w", lambda file: file.write(kernel.kernel_source)) - except Exception as e: - self.logger.error(f"Error saving kernel source code to disk: {e}") + if self.execution_backend != "cutedsl": + device_kernel_path = os.path.join(cache_path, DEVICE_KERNEL_PATH) + if verbose: + self.logger.debug(f"Saving kernel source code to file: {device_kernel_path}") + if kernel.kernel_source is not None: + KernelCache._safe_write_file(device_kernel_path, "w", lambda file: file.write(kernel.kernel_source)) + except Exception: + self.logger.exception("Error saving kernel source code to disk") # Save wrapped kernel source code try: - host_kernel_path = os.path.join(cache_path, HOST_KERNEL_PATH) + host_kernel_path = os.path.join(cache_path, HOST_KERNEL_PATH if self.execution_backend != "cutedsl" else KERNEL_PY_PATH) if verbose: self.logger.debug(f"Saving wrapped kernel source code to file: {host_kernel_path}") if self.execution_backend == "tvm_ffi": KernelCache._safe_write_file(host_kernel_path, "w", lambda file: file.write(kernel.adapter.get_host_source())) else: KernelCache._safe_write_file(host_kernel_path, "w", lambda file: file.write(kernel.adapter.get_kernel_source())) - except Exception as e: - self.logger.error(f"Error saving host kernel source code to disk: {e}") + except Exception: + self.logger.exception("Error saving host kernel source code to disk") # Save the kernel library try: # Save CUBIN or SO file - if self.execution_backend == "nvrtc": - kernel_lib_path = KERNEL_CUBIN_PATH - elif self.execution_backend == "tvm_ffi": - kernel_lib_path = EXECUTABLE_PATH - else: - kernel_lib_path = KERNEL_LIB_PATH - - kernel_lib_path = os.path.join(cache_path, kernel_lib_path) + if self.execution_backend == "cutedsl": + # For CuTeDSL, kernel_lib_path is the Python module + kernel_lib_path = os.path.join(cache_path, KERNEL_PY_PATH) + + # Save C++ launcher library if it exists + lib_gen = getattr(kernel.adapter, "lib_generator", None) + if lib_gen and hasattr(lib_gen, "launcher_libpath") and lib_gen.launcher_libpath: + launcher_lib_path = os.path.join(cache_path, LAUNCHER_LIB_PATH) + src_launcher_path = lib_gen.launcher_libpath + if verbose: + self.logger.debug(f"Saving C++ launcher library to cache: {src_launcher_path}") + KernelCache._safe_write_file( + launcher_lib_path, "wb", lambda file: 
file.write(KernelCache._load_binary(src_launcher_path)) + ) + + # Optionally save launcher C++ source for debugging + if hasattr(kernel.adapter, "launcher_cpp_code") and kernel.adapter.launcher_cpp_code: + launcher_cpp_path = os.path.join(cache_path, LAUNCHER_CPP_PATH) + if verbose: + self.logger.debug(f"Saving C++ launcher source to: {launcher_cpp_path}") + KernelCache._safe_write_file(launcher_cpp_path, "w", lambda file: file.write(kernel.adapter.launcher_cpp_code)) - # Save an extra Python file for NVRTC - if self.execution_backend == "nvrtc": - src_lib_path = kernel.adapter.libpath - kernel_py_path = os.path.join(cache_path, KERNEL_PY_PATH) - src_lib_path = src_lib_path.replace(".cubin", ".py") - if verbose: - self.logger.debug(f"Saving kernel nvrtc python code to file: {kernel_py_path}") - KernelCache._safe_write_file(kernel_py_path, "wb", lambda file: file.write(KernelCache._load_binary(src_lib_path))) - elif self.execution_backend == "tvm_ffi": - executable = kernel.adapter.executable - if verbose: - self.logger.debug(f"Saving kernel executable to file: {executable}") - KernelCache._safe_write_executable(executable, kernel_lib_path) else: - src_lib_path = kernel.adapter.libpath - if verbose: - self.logger.debug(f"Saving kernel library to file: {kernel_lib_path}") - KernelCache._safe_write_file(kernel_lib_path, "wb", lambda file: file.write(KernelCache._load_binary(src_lib_path))) - - except Exception as e: - self.logger.error(f"Error saving kernel library to disk: {e}") + if self.execution_backend == "nvrtc": + kernel_lib_path = KERNEL_CUBIN_PATH + elif self.execution_backend == "tvm_ffi": + kernel_lib_path = EXECUTABLE_PATH + else: + kernel_lib_path = KERNEL_LIB_PATH + kernel_lib_path = os.path.join(cache_path, kernel_lib_path) + + # Save an extra Python file for NVRTC + if self.execution_backend == "nvrtc": + src_lib_path = kernel.adapter.libpath + kernel_py_path = os.path.join(cache_path, KERNEL_PY_PATH) + src_lib_path = src_lib_path.replace(".cubin", ".py") + if verbose: + self.logger.debug(f"Saving kernel nvrtc python code to file: {kernel_py_path}") + KernelCache._safe_write_file(kernel_py_path, "wb", lambda file: file.write(KernelCache._load_binary(src_lib_path))) + + if self.execution_backend == "tvm_ffi": + executable = kernel.adapter.executable + if verbose: + self.logger.debug(f"Saving kernel executable to file: {executable}") + KernelCache._safe_write_executable(executable, kernel_lib_path) + else: + src_lib_path = kernel.adapter.libpath + if verbose: + self.logger.debug(f"Saving kernel library to file: {kernel_lib_path}") + KernelCache._safe_write_file(kernel_lib_path, "wb", lambda file: file.write(KernelCache._load_binary(src_lib_path))) + + except Exception: + self.logger.exception("Error saving kernel library to disk") # Save kernel parameters try: @@ -347,19 +380,19 @@ def _save_kernel_to_disk(self, key: str, kernel: JITKernel, func: Callable = Non if verbose: self.logger.debug(f"Saving kernel parameters to disk: {params_path}") KernelCache._safe_write_file(params_path, "wb", lambda file: cloudpickle.dump(kernel.params, file)) - except Exception as e: - self.logger.error(f"Error saving kernel parameters to disk: {e}") + except Exception: + self.logger.exception("Error saving kernel parameters to disk") def _load_kernel_from_disk( self, key: str, target: str | Target = "auto", - target_host: str | Target = None, - out_idx: list[int] = None, - execution_backend: Literal["tvm_ffi", "ctypes", "cython", "nvrtc", "torch"] = "tvm_ffi", - pass_configs: dict = None, + 
target_host: str | Target | None = None, + out_idx: list[int] | None = None, + execution_backend: Literal["tvm_ffi", "ctypes", "cython", "nvrtc", "torch", "cutedsl"] = "tvm_ffi", + pass_configs: dict | None = None, compile_flags: list[str] | str | None = None, - func: Callable = None, + func: Callable | None = None, verbose: bool = False, ) -> JITKernel | None: """ @@ -370,7 +403,7 @@ def _load_kernel_from_disk( target (Union[str, Target]): Compilation target platform. Defaults to "auto". target_host (Union[str, Target], optional): Host target platform. out_idx (List[int], optional): Indices specifying which outputs to return. - execution_backend (Literal): Backend type for execution. Defaults to "cython". + execution_backend (Literal): Backend type for execution. Defaults to "tvm_ffi". pass_configs (dict, optional): Configuration for compiler passes. func (Callable, optional): The original function. verbose (bool): Enable verbose log messages. @@ -385,11 +418,21 @@ def _load_kernel_from_disk( kernel_lib_path = KERNEL_CUBIN_PATH elif self.execution_backend == "tvm_ffi": kernel_lib_path = EXECUTABLE_PATH + elif self.execution_backend == "cutedsl": + kernel_lib_path = KERNEL_PY_PATH else: kernel_lib_path = KERNEL_LIB_PATH kernel_lib_path = os.path.join(cache_path, kernel_lib_path) params_path = os.path.join(cache_path, PARAMS_PATH) - if not all([os.path.exists(file) for file in (kernel_lib_path, params_path)]): + + # Check required files exist + required_files = [kernel_lib_path, params_path] + + # For CuTeDSL, also check launcher library + if self.execution_backend == "cutedsl": + required_files.append(os.path.join(cache_path, LAUNCHER_LIB_PATH)) + + if not all([os.path.exists(file) for file in required_files]): return None device_kernel_source: str | None = None @@ -397,20 +440,25 @@ def _load_kernel_from_disk( kernel_params: list[KernelParam] | None = None # Load the kernel source file (optional) - try: - if verbose: - self.logger.debug(f"Loading kernel source code from file: {device_kernel_path}") - with open(device_kernel_path) as f: - device_kernel_source = f.read() - except Exception as e: - self.logger.error(f"Error loading kernel source code from disk: {e}") - try: - if verbose: - self.logger.debug(f"Loading wrapped kernel source code from file: {host_kernel_path}") - with open(host_kernel_path) as f: - host_kernel_source = f.read() - except Exception as e: - self.logger.error(f"Error loading host kernel source code from disk: {e}") + if self.execution_backend != "cutedsl": + try: + if verbose: + self.logger.debug(f"Loading kernel source code from file: {device_kernel_path}") + with open(device_kernel_path) as f: + device_kernel_source = f.read() + except Exception: + self.logger.exception("Error loading kernel source code from disk") + try: + if verbose: + self.logger.debug(f"Loading wrapped kernel source code from file: {host_kernel_path}") + with open(host_kernel_path) as f: + host_kernel_source = f.read() + except Exception: + self.logger.exception("Error loading host kernel source code from disk") + else: + # For CuTeDSL, set empty strings since sources aren't loaded from cache + device_kernel_source = "" + host_kernel_source = "" # Load kernel parameters try: @@ -418,10 +466,10 @@ def _load_kernel_from_disk( self.logger.debug(f"Loading kernel parameters from file: {params_path}") with open(params_path, "rb") as f: kernel_params = cloudpickle.load(f) - except Exception as e: - self.logger.error(f"Error loading kernel parameters from disk: {e}") + except Exception: + 
self.logger.exception("Error loading kernel parameters from disk") - if host_kernel_source and device_kernel_source and kernel_params: + if ((host_kernel_source and device_kernel_source) or self.execution_backend == "cutedsl") and kernel_params: return JITKernel.from_database( func=func, host_kernel_source=host_kernel_source, @@ -453,5 +501,5 @@ def _clear_disk_cache(self): # Re-create the cache directory KernelCache._create_dirs() - except Exception as e: - self.logger.error(f"Error clearing disk cache: {e}") + except Exception: + self.logger.exception("Error clearing disk cache") diff --git a/tilelang/contrib/cutedsl/__init__.py b/tilelang/contrib/cutedsl/__init__.py new file mode 100644 index 000000000..1028badea --- /dev/null +++ b/tilelang/contrib/cutedsl/__init__.py @@ -0,0 +1,128 @@ +import cutlass +import cutlass.cute as cute +from cutlass._mlir.dialects import nvvm +from cutlass.cutlass_dsl import T + +# re-export cutlass.cute.arch functions first +from cutlass.cute.arch import sync_threads # noqa: F401 +from cutlass.cute.arch import alloc_smem, get_dyn_smem # noqa: F401 +from cutlass.cute.arch import warpgroup_reg_alloc, warpgroup_reg_dealloc # noqa: F401 + +from cutlass.cute import make_tensor, make_rmem_tensor, recast_ptr # noqa: F401 +from cutlass.cute.typing import Numeric + +from cutlass.base_dsl.typing import as_numeric, Int32, Uint16, Uint32 # noqa: F401 +from cutlass._mlir.dialects import llvm, arith # noqa: F401 +from cutlass._mlir import ir as mlir_ir +from cutlass.cutlass_dsl import dsl_user_op + +# Import our custom implementations (will override if names conflict) +from .mbar import * +from .cpasync import * +from .gemm_V1 import * +from .reduce import * +from .ldsm import * +from .math import * +from .threadblock_swizzle import * + +# Forward nvvm enums +from cutlass._mlir.dialects.nvvm import ( + MemOrderKind, + MemScopeKind, + AtomicOpKind, +) + +BYTES_PER_TENSORMAP = 128 +BYTES_PER_POINTER = 8 + + +def make_filled_tensor(shape, value): + t = cute.make_rmem_tensor(shape, type(value)) + t.fill(value) + return t + + +def make_tensor_at_offset(ptr: cute.Pointer, offset, shape, div_by=1): + if div_by != 1: + offset = cute.assume(cutlass.as_numeric(offset), divby=div_by) + return cute.make_tensor(ptr + offset, shape) + + +def shuffle_elect(thread_extent): + # thread_extent is the number of threads of a warpgroup + warp_idx = cute.arch.warp_idx() + warp_idx = cute.arch.make_warp_uniform(warp_idx) + if thread_extent == 0: + return warp_idx == 0 + else: + return (warp_idx % (thread_extent // 32)) == 0 + + +def sync_thread_partial(barrier_id=None, thread_count=None): + bar_sync_ptx(barrier_id, thread_count) + + +# Packing functions +def pack_half2(x, y): + """ + Pack two half-precision (fp16) values into a single 32-bit value. + Corresponds to CUDA's __pack_half2 intrinsic. + + This packs two fp16 values into a single int32 by treating the fp16 bits + as raw data and concatenating them. 
+ """ + + @dsl_user_op + def pack_half2_impl(x_val, y_val, *, loc=None, ip=None): + # Cast fp16 to uint16 (bitcast) + x_ir = x_val.ir_value(loc=loc, ip=ip) if hasattr(x_val, "ir_value") else x_val + y_ir = y_val.ir_value(loc=loc, ip=ip) if hasattr(y_val, "ir_value") else y_val + + # Bitcast fp16 to i16 + i16_type = mlir_ir.IntegerType.get_signless(16) + x_i16 = llvm.bitcast(i16_type, x_ir, loc=loc, ip=ip) + y_i16 = llvm.bitcast(i16_type, y_ir, loc=loc, ip=ip) + + packed_xy = llvm.inline_asm( + Int32.mlir_type, + [x_i16, y_i16], + "mov.b32 $0, {$1, $2};", + "=r,h,h", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + return Int32(packed_xy) + + return pack_half2_impl(x, y) + + +def AtomicAdd(ptr: cute.Pointer, value: Numeric, *, loc=None, ip=None): + if ptr.dtype == cutlass.Float32: + ret = nvvm.atomicrmw( + T.f32(), + AtomicOpKind.FADD, + ptr.llvm_ptr, + ptr.dtype(value).ir_value(loc=loc, ip=ip), + mem_order=MemOrderKind.RELAXED, + syncscope=MemScopeKind.GPU, + loc=loc, + ip=ip, + ) + elif ptr.dtype == cutlass.Int32: + ret = nvvm.atomicrmw( + T.i32(), + AtomicOpKind.ADD, + ptr.llvm_ptr, + ptr.dtype(value).ir_value(loc=loc, ip=ip), + mem_order=MemOrderKind.RELAXED, + syncscope=MemScopeKind.GPU, + loc=loc, + ip=ip, + ) + else: + raise ValueError(f"Unsupported dtype: {ptr.dtype}") + return ptr.dtype(ret) diff --git a/tilelang/contrib/cutedsl/cpasync.py b/tilelang/contrib/cutedsl/cpasync.py new file mode 100644 index 000000000..6ddeb8933 --- /dev/null +++ b/tilelang/contrib/cutedsl/cpasync.py @@ -0,0 +1,215 @@ +from __future__ import annotations +from cutlass.cutlass_dsl import CuTeDSL, T, if_generate, dsl_user_op # noqa: F401 + +from cutlass._mlir.dialects import nvvm, cute_nvgpu # noqa: F401 +from cutlass._mlir import ir + +import cutlass._mlir.dialects.cute as _cute_ir +import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir + +import cutlass.cute as cute +from cutlass.cute.typing import Int, Boolean, Int32, Int16, Uint64, Union # noqa: F401 +from cutlass.impl_utils import check_value_in + +from cutlass.cute.arch import cp_async_commit_group as cp_async_commit # noqa: F401 +from cutlass.cute.arch import cp_async_wait_group as cp_async_wait # noqa: F401 + +BYTES_PER_TENSORMAP = 128 +BYTES_PER_POINTER = 8 + + +def cp_async_gs(size, dst, dst_offset, src, src_offset): + assert size in [16, 8, 4] + # use CG (cache global) to by pass L1 when loading contiguous 128B. + mode = nvvm.LoadCacheModifierKind.CG if size == 16 else nvvm.LoadCacheModifierKind.CA + if isinstance(src, cute.Tensor): + src_ptr = src.iterator + elif isinstance(src, cute.Pointer): + src_ptr = src + else: + raise ValueError(f"Invalid source type: {type(src)}") + if isinstance(dst, cute.Tensor): + dst_ptr = dst.iterator + elif isinstance(dst, cute.Pointer): + dst_ptr = dst + else: + raise ValueError(f"Invalid destination type: {type(dst)}") + cp_async_shared_global(dst_ptr + dst_offset, src_ptr + src_offset, size, mode) + + +@cute.jit +def cp_async_gs_conditional(size, dst, dst_offset, src, src_offset, cond): + if cond: + cp_async_gs(size, dst, dst_offset, src, src_offset) + + +@dsl_user_op +def extract_tensormap_ptr(tma_atom: cute.CopyAtom, *, loc=None, ip=None) -> cute.Pointer: + """ + extract the tensormap pointer from a TMA Copy Atom. 
+ :param tma_atom: The TMA Copy Atom + :type tma_atom: CopyAtom + """ + exec_value = _cute_nvgpu_ir.atom_make_exec_tma(tma_atom._trait.value, loc=loc, ip=ip) + ptr_type = _cute_ir.PtrType.get(Uint64.mlir_type, _cute_ir.AddressSpace.generic, 64) + tensormap_ptr = _cute_nvgpu_ir.get_tma_desc_addr(ptr_type, exec_value, loc=loc, ip=ip) + return tensormap_ptr + + +@dsl_user_op +def tma_load(tma_desc, mbar: cute.Pointer, smem_ptr: cute.Pointer, crd: Int | tuple[Int, ...], *, loc=None, ip=None) -> None: + """ + Load data from global memory to shared memory using TMA (Tensor Memory Access). + + :param tma_desc: TMA descriptor for the tensor + :type tma_desc: CopyAtom or tensormap_ptr or Tensor of tensormap_ptr + :param mbar: Mbarrier pointer in shared memory + :type mbar: Pointer + :param smem_ptr: Destination pointer in shared memory + :type smem_ptr: Pointer + :param crd: Coordinates tuple for the tensor access + :type crd: tuple[Int, ...] + """ + arch = CuTeDSL._get_dsl().envar.arch + check_value_in(arch, ["sm_90", "sm_90a", "sm_100a"], "arch") + + if not isinstance(crd, tuple) and isinstance(tma_desc, cute.Pointer): + # Legacy signature: tma_load(smem_ptr, gmem_ptr, mbar, size) + _smem_ptr = tma_desc + _gmem_ptr = mbar + _mbar = smem_ptr + nvvm.cp_async_bulk_shared_cluster_global( + dst_mem=_smem_ptr.llvm_ptr, + src_mem=_gmem_ptr.llvm_ptr, + mbar=_mbar.llvm_ptr, + size=Int32(crd).ir_value(loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + else: + if isinstance(tma_desc, cute.CopyAtom): + tma_desc_ptr = extract_tensormap_ptr(tma_desc) + elif isinstance(tma_desc, cute.Tensor): + tma_desc_ptr = tma_desc.iterator + else: + tma_desc_ptr = tma_desc + nvvm.cp_async_bulk_tensor_shared_cluster_global( + dst_mem=smem_ptr.llvm_ptr, + tma_descriptor=tma_desc_ptr.llvm_ptr, + coordinates=[Int32(i).ir_value(loc=loc, ip=ip) for i in crd], + mbar=mbar.llvm_ptr, + im2col_offsets=[], + load_mode=nvvm.CpAsyncBulkTensorLoadMode.TILE, + group=nvvm.Tcgen05GroupKind.CTA_1, + use_intrinsic=False, # set to True would lead to compile error + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def tma_store(tma_desc, smem_ptr: cute.Pointer, crd: Int | tuple[Int, ...], *, loc=None, ip=None) -> None: + """ + Store data from shared memory to global memory using TMA (Tensor Memory Access). + + :param tma_desc: TMA descriptor for the tensor + :type tma_desc: TMA descriptor + :param smem_ptr: Source pointer in shared memory + :type smem_ptr: Pointer + :param crd: Coordinates tuple for the tensor access + :type crd: tuple[Int, ...] 
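+
+    Illustrative sketch (hypothetical names; assumes a TMA atom built on the host
+    and a tile already staged in shared memory):
+
+        tma_store(c_tma_atom, sC_ptr, (coord_n, coord_m))
+        tma_store_arrive()
+        tma_store_wait(0)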
+ """ + arch = CuTeDSL._get_dsl().envar.arch + check_value_in(arch, ["sm_90", "sm_90a", "sm_100a"], "arch") + if not isinstance(crd, tuple): + if arch not in ("sm_90", "sm_90a"): + raise NotImplementedError("tma_store(size) path is only implemented for sm_90/sm_90a") + gmem_ptr = tma_desc.align(smem_ptr.alignment) + _cute_nvgpu_ir.arch_copy_SM90_bulk_copy_s2g( + dsmem_data_addr=smem_ptr.value, + gmem_data_addr=gmem_ptr.value, + size=ir.IntegerAttr.get(ir.IntegerType.get_signless(32), crd), + loc=loc, + ip=ip, + ) + else: + if isinstance(tma_desc, cute.CopyAtom): + tma_desc_ptr = extract_tensormap_ptr(tma_desc) + elif isinstance(tma_desc, cute.Tensor): + tma_desc_ptr = tma_desc.iterator + else: + tma_desc_ptr = tma_desc + nvvm.cp_async_bulk_tensor_global_shared_cta( + tma_descriptor=tma_desc_ptr.llvm_ptr, + src_mem=smem_ptr.llvm_ptr, + coordinates=[Int32(i).ir_value(loc=loc, ip=ip) for i in crd], + predicate=None, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def tma_store_arrive(*, loc=None, ip=None) -> None: + """ + Indicate arrival of warp issuing TMA_STORE. + Corresponds to PTX instruction: cp.async.bulk.commit_group; + """ + nvvm.cp_async_bulk_commit_group(loc=loc, ip=ip) + + +@dsl_user_op +def tma_store_wait(count: int, *, read=None, loc=None, ip=None) -> None: + """ + Wait for TMA_STORE operations to complete. + Corresponds to PTX instruction: cp.async.bulk.wait_group.read ; + + :param count: The number of outstanding bulk async groups to wait for + :type count: Int + """ + nvvm.cp_async_bulk_wait_group(group=count, read=read, loc=loc, ip=ip) + + +@dsl_user_op +def cp_async_shared_global( + dst: cute.Pointer, src: cute.Pointer, cp_size: Int, modifier: nvvm.LoadCacheModifierKind, *, src_size: Int = None, loc=None, ip=None +) -> None: + """ + Asynchronously copy data from global memory to shared memory. + + :param dst: Destination pointer in shared memory + :type dst: Pointer + :param src: Source pointer in global memory + :type src: Pointer + :param size: Size of the copy in bytes + :type size: Int + :param modifier: Cache modifier + :type modifier: Int + :param cp_size: Optional copy size override + :type cp_size: Int + """ + size = src_size if src_size else cp_size + nvvm.cp_async_shared_global( + dst=dst.llvm_ptr, + src=src.llvm_ptr, + size=ir.IntegerAttr.get(ir.IntegerType.get_signless(32), size), + modifier=modifier, + cp_size=Int32(cp_size).ir_value(loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def prefetch_tma_descriptor(tma_desc, *, loc=None, ip=None) -> None: + """ + Prefetch a TMA descriptor. 
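+    Accepts a ``CopyAtom``, a ``Tensor`` of tensormap pointers, or a raw pointer,
+    mirroring :func:`tma_load` and :func:`tma_store`.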
+ Corresponds to PTX instruction: prefetch.tensormap; + """ + if isinstance(tma_desc, cute.CopyAtom): + tma_desc_ptr = extract_tensormap_ptr(tma_desc) + elif isinstance(tma_desc, cute.Tensor): + tma_desc_ptr = tma_desc.iterator + else: + tma_desc_ptr = tma_desc + nvvm.prefetch_tensormap(tma_desc_ptr.llvm_ptr, loc=loc, ip=ip) diff --git a/tilelang/contrib/cutedsl/gemm_V1.py b/tilelang/contrib/cutedsl/gemm_V1.py new file mode 100644 index 000000000..0f6cc71e9 --- /dev/null +++ b/tilelang/contrib/cutedsl/gemm_V1.py @@ -0,0 +1,569 @@ +import cutlass +import cutlass.cute as cute +import cutlass.utils as utils # noqa: F401 +import math +import cutlass.utils.hopper_helpers as hopper_utils +from cutlass.utils import LayoutEnum +from cutlass.cute.nvgpu.warpgroup import OperandMajorMode, OperandSource, make_smem_layout_atom + + +def make_aligned_tensor(ptr: cute.Pointer, layout: cute.Layout, align_bytes: int, swizzle=False): + ptr = ptr.align(align_bytes) + if swizzle and isinstance(layout, cute.ComposedLayout): + ptr = cute.recast_ptr(ptr=ptr, swizzle_=layout.inner, dtype=ptr.dtype) + return cute.make_tensor(ptr, layout.outer) + return cute.make_tensor(ptr, layout) + + +def gemm_ss( + M, + N, + K, + warp_m, + warp_n, + trans_A, + trans_B, + clear_accum, + stride_A, + stride_B, + offset_A, + offset_B, + use_wgmma=None, + wg_wait=0, + A_ptr: cute.Pointer = None, + B_ptr: cute.Pointer = None, + C_ptr: cute.Pointer = None, +): + """GEMM with both A and B from shared memory""" + if use_wgmma: + gemm = Gemm_SM90( + M, + N, + K, + warp_m, + warp_n, + trans_A, + trans_B, + clear_accum, + stride_A, + stride_B, + offset_A, + offset_B, + A_ptr.dtype, + B_ptr.dtype, + C_ptr.dtype, + ) + gemm(A_ptr, B_ptr, C_ptr, wg_wait=wg_wait, clear_accum=clear_accum) + else: + gemm = Gemm_SM80( + M, + N, + K, + warp_m, + warp_n, + trans_A, + trans_B, + clear_accum, + stride_A, + stride_B, + offset_A, + offset_B, + A_ptr.dtype, + B_ptr.dtype, + C_ptr.dtype, + ) + gemm(A_ptr, B_ptr, C_ptr) + + +def gemm_rs( + M, + N, + K, + warp_m, + warp_n, + trans_A, + trans_B, + clear_accum, + stride_A, + stride_B, + offset_A, + offset_B, + use_wgmma=None, + wg_wait=0, + A_ptr: cute.Pointer = None, + B_ptr: cute.Pointer = None, + C_ptr: cute.Pointer = None, +): + """GEMM with A from register/fragment and B from shared memory""" + if use_wgmma: + gemm = Gemm_SM90( + M, + N, + K, + warp_m, + warp_n, + trans_A, + trans_B, + clear_accum, + stride_A, + stride_B, + offset_A, + offset_B, + A_ptr.dtype, + B_ptr.dtype, + C_ptr.dtype, + ) + gemm.body_rs(A_ptr, B_ptr, C_ptr, wg_wait=wg_wait, clear_accum=clear_accum) + else: + gemm = Gemm_SM80( + M, + N, + K, + warp_m, + warp_n, + trans_A, + trans_B, + clear_accum, + stride_A, + stride_B, + offset_A, + offset_B, + A_ptr.dtype, + B_ptr.dtype, + C_ptr.dtype, + ) + gemm.body_rs(A_ptr, B_ptr, C_ptr) + + +def gemm_sr( + M, + N, + K, + warp_m, + warp_n, + trans_A, + trans_B, + clear_accum, + stride_A, + stride_B, + offset_A, + offset_B, + use_wgmma=None, + wg_wait=0, + A_ptr: cute.Pointer = None, + B_ptr: cute.Pointer = None, + C_ptr: cute.Pointer = None, +): + """GEMM with A from shared memory and B from register/fragment""" + # wgmma doesn't support gemm_sr, only use SM80 + gemm = Gemm_SM80( + M, + N, + K, + warp_m, + warp_n, + trans_A, + trans_B, + clear_accum, + stride_A, + stride_B, + offset_A, + offset_B, + A_ptr.dtype, + B_ptr.dtype, + C_ptr.dtype, + ) + gemm.body_sr(A_ptr, B_ptr, C_ptr) + + +def gemm_rr( + M, + N, + K, + warp_m, + warp_n, + trans_A, + trans_B, + clear_accum, + stride_A, + stride_B, 
+ offset_A, + offset_B, + use_wgmma=None, + wg_wait=0, + A_ptr: cute.Pointer = None, + B_ptr: cute.Pointer = None, + C_ptr: cute.Pointer = None, +): + """GEMM with both A and B from register/fragment""" + # Both operands in register, no copy needed + gemm = Gemm_SM80( + M, + N, + K, + warp_m, + warp_n, + trans_A, + trans_B, + clear_accum, + stride_A, + stride_B, + offset_A, + offset_B, + A_ptr.dtype, + B_ptr.dtype, + C_ptr.dtype, + ) + # For gemm_rr, directly call _body_impl with copy_A=False, copy_B=False + gemm._body_impl(A_ptr, B_ptr, C_ptr, copy_A=False, copy_B=False) + + +class Gemm_SM80: + _instances = {} # cache instances for the same arguments + + def __new__(cls, *args): + key = args + if key not in cls._instances: + cls._instances[key] = super().__new__(cls) + return cls._instances[key] + + # in Tilelang, trans_A == 0 or trans_B == 1 means K major + # in Cute, trans == 0 means K major + def __init__( + self, M, N, K, warp_m, warp_n, trans_A, trans_B, clear_accum, stride_A, stride_B, offset_A, offset_B, A_type, B_type, C_type + ): + if not hasattr(self, "initialized"): + self.cta_tiler = (M, N, K) + self.mma_inst_shape = (16, 8, 16) + self.trans_A = trans_A != 0 # same with Tilelang + self.trans_B = trans_B == 0 # inverse with Tilelang + A_major_mode = LayoutEnum.COL_MAJOR if self.trans_A else LayoutEnum.ROW_MAJOR + B_major_mode = LayoutEnum.COL_MAJOR if self.trans_B else LayoutEnum.ROW_MAJOR + self.A_layout = self._make_smem_layout_AB(A_type, A_major_mode, 128, (M, K)) + self.B_layout = self._make_smem_layout_AB(B_type, B_major_mode, 128, (N, K)) + self.ab_dtype = A_type + self.acc_dtype = C_type + self.tiled_mma = self._make_tiled_mma(warp_m, warp_n) + self.clear_accum = clear_accum + + def _make_smem_layout_AB(self, dtype, major_mode, copy_bits, smem_tiler): + is_row_major = major_mode == LayoutEnum.ROW_MAJOR + major_mode_size = smem_tiler[1] if is_row_major else smem_tiler[0] + major_mode_size = 64 if major_mode_size >= 64 else major_mode_size + + swizzle_bits = int(math.log2(major_mode_size * dtype.width // copy_bits)) + swizzle_bits = min(swizzle_bits, 3) + + layout_atom_outer = ( + cute.make_layout((8, major_mode_size), stride=(major_mode_size, 1)) + if is_row_major + else cute.make_layout((major_mode_size, 8), stride=(1, major_mode_size)) + ) + layout_atom = cute.make_composed_layout( + cute.make_swizzle(swizzle_bits, 3, 3), + 0, + layout_atom_outer, + ) + layout = cute.tile_to_shape(layout_atom, smem_tiler, (0, 1) if is_row_major else (1, 0)) + return layout + + def _make_tiled_mma(self, warp_m, warp_n): + atom_layout_mnk = (warp_m, warp_n, 1) + op = cute.nvgpu.warp.MmaF16BF16Op(self.ab_dtype, self.acc_dtype, self.mma_inst_shape) + permutation_mnk = ( + atom_layout_mnk[0] * self.mma_inst_shape[0], + atom_layout_mnk[1] * self.mma_inst_shape[1] * 2, + atom_layout_mnk[2] * self.mma_inst_shape[2], + ) + tiled_mma = cute.make_tiled_mma(op, atom_layout_mnk, permutation_mnk) + return tiled_mma + + @cute.jit + def __call__( + self, + sA_ptr: cute.Pointer, + sB_ptr: cute.Pointer, + rC_ptr: cute.Pointer, + ): + """GEMM body: both A and B from shared memory""" + self._body_impl(sA_ptr, sB_ptr, rC_ptr, copy_A=True, copy_B=True) + + @cute.jit + def body_rs( + self, + rA_ptr: cute.Pointer, # A already in register + sB_ptr: cute.Pointer, # B from shared memory + rC_ptr: cute.Pointer, + ): + """GEMM body_rs: A from register, B from shared memory""" + self._body_impl(rA_ptr, sB_ptr, rC_ptr, copy_A=False, copy_B=True) + + @cute.jit + def body_sr( + self, + sA_ptr: cute.Pointer, # A from 
shared memory + rB_ptr: cute.Pointer, # B already in register + rC_ptr: cute.Pointer, + ): + """GEMM body_sr: A from shared memory, B from register""" + self._body_impl(sA_ptr, rB_ptr, rC_ptr, copy_A=True, copy_B=False) + + @cute.jit + def _body_impl( + self, + A_ptr: cute.Pointer, + B_ptr: cute.Pointer, + rC_ptr: cute.Pointer, + copy_A: cutlass.Constexpr = True, + copy_B: cutlass.Constexpr = True, + ): + """Internal implementation with configurable copy operations""" + tidx, _, _ = cute.arch.thread_idx() + thr_mma = self.tiled_mma.get_slice(tidx) + + tCrA = None + tCrB = None + tCrC = cute.make_tensor(rC_ptr, self.tiled_mma.partition_shape_C((self.cta_tiler[0], self.cta_tiler[1]))) + + # Create copy operations only for operands that need copying + if cutlass.const_expr(copy_A): + sA = make_aligned_tensor(A_ptr, self.A_layout, 16) + tCsA = thr_mma.partition_A(sA) + tCrA = self.tiled_mma.make_fragment_A(tCsA) + atom_copy_s2r_A = cute.make_copy_atom( + cute.nvgpu.warp.LdMatrix8x8x16bOp(self.trans_A, 4), + sA.element_type, + ) + tiled_copy_s2r_A = cute.make_tiled_copy( + atom_copy_s2r_A, + layout_tv=self.tiled_mma.tv_layout_A_tiled, + tiler_mn=(self.tiled_mma.get_tile_size(0), self.tiled_mma.get_tile_size(2)), + ) + thr_copy_ldmatrix_A = tiled_copy_s2r_A.get_slice(tidx) + tCsA_copy_view = thr_copy_ldmatrix_A.partition_S(sA) + tCrA_copy_view = thr_copy_ldmatrix_A.retile(tCrA) + else: + # A already in register + tCrA = cute.make_tensor(A_ptr, self.tiled_mma.partition_shape_A((self.cta_tiler[0], self.cta_tiler[2]))) + + if cutlass.const_expr(copy_B): + sB = make_aligned_tensor(B_ptr, self.B_layout, 16) + tCsB = thr_mma.partition_B(sB) + tCrB = self.tiled_mma.make_fragment_B(tCsB) + atom_copy_s2r_B = cute.make_copy_atom( + cute.nvgpu.warp.LdMatrix8x8x16bOp(self.trans_B, 4), + sB.element_type, + ) + tiled_copy_s2r_B = cute.make_tiled_copy( + atom_copy_s2r_B, + layout_tv=self.tiled_mma.tv_layout_B_tiled, + tiler_mn=(self.tiled_mma.get_tile_size(1), self.tiled_mma.get_tile_size(2)), + ) + thr_copy_ldmatrix_B = tiled_copy_s2r_B.get_slice(tidx) + tCsB_copy_view = thr_copy_ldmatrix_B.partition_S(sB) + tCrB_copy_view = thr_copy_ldmatrix_B.retile(tCrB) + else: + # B already in register + tCrB = cute.make_tensor(B_ptr, self.tiled_mma.partition_shape_B((self.cta_tiler[1], self.cta_tiler[2]))) + + if self.clear_accum: + tCrC.fill(0) + + for k in cutlass.range(cute.size(tCrA, mode=[2])): + if cutlass.const_expr(copy_A): + cute.copy(tiled_copy_s2r_A, tCsA_copy_view[None, None, k], tCrA_copy_view[None, None, k]) + if cutlass.const_expr(copy_B): + cute.copy(tiled_copy_s2r_B, tCsB_copy_view[None, None, k], tCrB_copy_view[None, None, k]) + cute.gemm(self.tiled_mma, tCrC, tCrA[None, None, k], tCrB[None, None, k], tCrC) + + +class Gemm_SM90: + _instances = {} # cache instances for the same arguments + + def __new__(cls, *args): + key = args + if key not in cls._instances: + cls._instances[key] = super().__new__(cls) + return cls._instances[key] + + # in Tilelang, trans_A == 0 or trans_B == 1 means K major + # in Cute, trans == 0 means K major + def __init__( + self, M, N, K, warp_m, warp_n, trans_A, trans_B, clear_accum, stride_A, stride_B, offset_A, offset_B, A_type, B_type, C_type + ): + if not hasattr(self, "initialized"): + self.cta_tiler = (M, N, K) + self.tiler_mn = (M, N) + self.atom_layout_mnk = (warp_m // 4, warp_n, 1) + self.trans_A = trans_A != 0 # same with Tilelang + self.trans_B = trans_B == 0 # inverse with Tilelang + self.a_leading_mode = OperandMajorMode.MN if self.trans_A else OperandMajorMode.K + 
self.b_leading_mode = OperandMajorMode.MN if self.trans_B else OperandMajorMode.K + A_major_mode = LayoutEnum.COL_MAJOR if self.trans_A else LayoutEnum.ROW_MAJOR + B_major_mode = LayoutEnum.COL_MAJOR if self.trans_B else LayoutEnum.ROW_MAJOR + self.A_layout = self.make_smem_layout_AB(A_type, A_major_mode, (M, K)) + self.B_layout = self.make_smem_layout_AB(B_type, B_major_mode, (N, K)) + self.a_dtype = A_type + self.b_dtype = B_type + self.acc_dtype = C_type + self.tiled_mma = None + self.A_source = None + self.clear_accum = clear_accum + + @staticmethod + def make_tma_atom( + tensor, + smem_layout_staged, + smem_tile, + mcast_dim, + ): + op = cute.nvgpu.cpasync.CopyBulkTensorTileG2SOp() if mcast_dim == 1 else cute.nvgpu.cpasync.CopyBulkTensorTileG2SMulticastOp() + + smem_layout = cute.slice_(smem_layout_staged, (None, None, 0)) + + tma_atom, tma_tensor = cute.nvgpu.cpasync.make_tiled_tma_atom( + op, + tensor, + smem_layout, + smem_tile, + num_multicast=mcast_dim, + ) + + return tma_atom + + @staticmethod + def get_tma_atom(tensor, tiler_mk, stages=1): + smem_layout_staged = Gemm_SM90.make_smem_layout_AB(tensor.element_type, LayoutEnum.from_tensor(tensor), tiler_mk, stages) + tma_atom = Gemm_SM90.make_tma_atom(tensor, smem_layout_staged, tiler_mk, 1) + return tma_atom + + @staticmethod + def make_smem_layout_AB(dtype, major_mode: LayoutEnum, tiler_mk, stages=1): + smem_shape = tiler_mk + # Determine if K is the major mode and get the major mode size + is_k_major = major_mode.sm90_mma_major_mode() == cute.nvgpu.warpgroup.OperandMajorMode.K + major_mode_size = tiler_mk[1] if is_k_major else tiler_mk[0] + + # Create SMEM layout atom for A tensor based on major mode and data type + smem_layout_atom = make_smem_layout_atom( + hopper_utils.get_smem_layout_atom(major_mode, dtype, major_mode_size), + dtype, + ) + # Tile the SMEM layout atom to the A tensor shape and add staging dimension + smem_layout = cute.tile_to_shape(smem_layout_atom, cute.append(smem_shape, stages), order=(0, 1, 2) if is_k_major else (1, 0, 2)) + return smem_layout + + def _make_tiled_mma(self, is_rsMode=False): + tiled_mma = hopper_utils.make_trivial_tiled_mma( + self.a_dtype, + self.b_dtype, + self.a_leading_mode, + self.b_leading_mode, + self.acc_dtype, + self.atom_layout_mnk, + (64, self.tiler_mn[1] // self.atom_layout_mnk[1]), + OperandSource.SMEM if not is_rsMode else OperandSource.RMEM, + ) + return tiled_mma + + @cute.jit + def __call__( + self, + sA_ptr: cute.Pointer, + sB_ptr: cute.Pointer, + rC_ptr: cute.Pointer, + wg_wait: cutlass.Constexpr = 0, + clear_accum: cutlass.Constexpr = False, + ): + tidx, _, _ = cute.arch.thread_idx() + self.tiled_mma = self._make_tiled_mma() + thr_mma = self.tiled_mma.get_slice(tidx) + + sA_ptr = cute.recast_ptr(sA_ptr, self.A_layout.inner, dtype=sA_ptr.dtype) + sB_ptr = cute.recast_ptr(sB_ptr, self.B_layout.inner, dtype=sB_ptr.dtype) + sA = cute.make_tensor(sA_ptr, self.A_layout.outer) + sB = cute.make_tensor(sB_ptr, self.B_layout.outer) + + tCsA = thr_mma.partition_A(sA) + tCsB = thr_mma.partition_B(sB) + + tCrA = self.tiled_mma.make_fragment_A(tCsA) + tCrB = self.tiled_mma.make_fragment_B(tCsB) + tCrC = cute.make_tensor(rC_ptr, self.tiled_mma.partition_shape_C((self.cta_tiler[0], self.cta_tiler[1]))) + + cute.nvgpu.warpgroup.fence() + if cutlass.const_expr(clear_accum): + self.tiled_mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, False) + else: + self.tiled_mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, True) + num_k_blocks = cute.size(tCrA, mode=[2]) + for k in 
cutlass.range(num_k_blocks): + tCrA_1phase = tCrA[None, None, k, 0] + tCrB_1phase = tCrB[None, None, k, 0] + cute.gemm(self.tiled_mma, tCrC, tCrA_1phase, tCrB_1phase, tCrC) + self.tiled_mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, True) + + cute.nvgpu.warpgroup.commit_group() + if cutlass.const_expr(wg_wait >= 0): + cute.nvgpu.warpgroup.wait_group(wg_wait) + + @cute.jit + def body_rs( + self, + rA_ptr: cute.Pointer, # A already in register (Fragment) + sB_ptr: cute.Pointer, # B from shared memory + rC_ptr: cute.Pointer, + wg_wait: cutlass.Constexpr = 0, + clear_accum: cutlass.Constexpr = False, + ): + """ + GEMM body_rs for SM90/Hopper: A from register, B from shared memory. + Based on cute::tl_wgmma::GemmTensorOp::body_rs from gemm_sm90.h + """ + tidx, _, _ = cute.arch.thread_idx() + self.tiled_mma = self._make_tiled_mma(is_rsMode=True) + # if self.A_source != OperandSource.RMEM or self.tiled_mma is None: + # self.tiled_mma = self._make_tiled_mma(is_rsMode = True) + # self.A_source = OperandSource.RMEM + # B from shared memory (with swizzle) + sB_ptr = cute.recast_ptr(sB_ptr, self.B_layout.inner, dtype=sB_ptr.dtype) + sB = cute.make_tensor(sB_ptr, self.B_layout.outer) + + # Use the existing tiled_mma + thr_mma = self.tiled_mma.get_slice(tidx) + + # Partition B from shared memory - standard path + tCsB = thr_mma.partition_B(sB) + tCrB = self.tiled_mma.make_fragment_B(tCsB) + + # A already in register + # For body_rs, A is NOT partitioned through thr_mma (it's already partitioned) + # We create the tensor directly with the full shape + # This matches C++: make_tensor(make_rmem_ptr(pA), partition_shape_A(...)) + tCrA = cute.make_tensor(rA_ptr, self.tiled_mma.partition_shape_A((self.cta_tiler[0], self.cta_tiler[2]))) + + # C accumulator + tCrC = cute.make_tensor(rC_ptr, self.tiled_mma.partition_shape_C((self.cta_tiler[0], self.cta_tiler[1]))) + + # Fence operands (prepare for wgmma) + cute.nvgpu.warpgroup.fence() + # Note: warpgroup_arrive() is called internally by wgmma + # Set accumulation mode + if cutlass.const_expr(clear_accum): + self.tiled_mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, False) + else: + self.tiled_mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, True) + # GEMM loop + num_k_blocks = cute.size(tCrB, mode=[2]) + for k_block in cutlass.range(num_k_blocks): + # Match the indexing pattern from __call__ + # If tCrB has 4 dimensions (with pipeline), use [None, None, k, 0] + # Otherwise use [None, None, k] + tCrB_k = tCrB[None, None, k_block, 0] if cute.rank(tCrB) >= 4 else tCrB[None, None, k_block] + tCrA_k = tCrA[None, None, k_block, 0] if cute.rank(tCrA) >= 4 else tCrA[None, None, k_block] + cute.gemm(self.tiled_mma, tCrC, tCrA_k, tCrB_k, tCrC) + self.tiled_mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, True) + + cute.nvgpu.warpgroup.commit_group() + if cutlass.const_expr(wg_wait >= 0): + cute.nvgpu.warpgroup.wait_group(wg_wait) diff --git a/tilelang/contrib/cutedsl/ldsm.py b/tilelang/contrib/cutedsl/ldsm.py new file mode 100644 index 000000000..4f3602697 --- /dev/null +++ b/tilelang/contrib/cutedsl/ldsm.py @@ -0,0 +1,127 @@ +""" +LDMATRIX and STMATRIX operations for CuTeDSL backend. +Based on tl_templates/cuda/ldsm.h + +These functions provide wrappers around PTX ldmatrix/stmatrix instructions +for loading/storing 8x8 matrix fragments between shared memory and registers. 
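+
+Illustrative sketch (hypothetical names; assumes 16-bit elements and pointers
+already set up inside a kernel body):
+
+    # Load four 8x8 fragments from shared memory into a local register buffer.
+    ptx_ldmatrix_x4(smem_ptr, local_ptr)
+    # Store two fragments back to shared memory, transposed.
+    ptx_stmatrix_x2_trans(smem_ptr, v0, v1)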
+""" + +from cutlass.cutlass_dsl import T, dsl_user_op +from cutlass._mlir.dialects import nvvm, llvm +from cutlass._mlir import ir # noqa: F401 +from cutlass.cute.typing import Pointer, Int32 # noqa: F401 +import cutlass.cute as cute + + +def _to_ir_value(v, loc=None, ip=None): + """Convert value to MLIR IR, handling both cutlass types and raw MLIR Values""" + if hasattr(v, "ir_value"): + return v.ir_value(loc=loc, ip=ip) + else: + # Already an MLIR Value + return v + + +def _ldmatrix(smem_ptr, local_ptr, num, transpose, loc=None, ip=None): + """Internal helper for ldmatrix operations""" + layout = nvvm.MMALayout.col if transpose else nvvm.MMALayout.row + assert num in [2, 4] + ret_type = llvm.StructType.get_literal([T.i32()] * num) + out_i32 = nvvm.ldmatrix(ret_type, smem_ptr.llvm_ptr, num=num, layout=layout, loc=loc, ip=ip) + out = cute.make_tensor(cute.recast_ptr(local_ptr, dtype=cute.Int32), num) + for i in range(num): + out[i] = cute.Int32(llvm.extractvalue(T.i32(), out_i32, [i], loc=loc, ip=ip)) + + +def _stmatrix(smem_ptr, values, transpose, loc=None, ip=None): + """Internal helper for stmatrix operations""" + layout = nvvm.MMALayout.col if transpose else nvvm.MMALayout.row + ir_values = [_to_ir_value(v, loc, ip) for v in values] + nvvm.stmatrix(smem_ptr.llvm_ptr, ir_values, layout=layout, loc=loc, ip=ip) + + +# ============================================================================ +# LDMATRIX operations (load from shared memory to registers) +# ============================================================================ + + +@dsl_user_op +def ptx_ldmatrix_x1(smem_ptr: Pointer, local_ptr: Pointer, *, loc=None, ip=None) -> None: + """Load 1 matrix (8x8) from shared memory""" + # _ldmatrix(smem_ptr, local_ptr, 1, False, loc, ip) + out_i32 = nvvm.ldmatrix(T.i32(), smem_ptr.llvm_ptr, num=1, layout=nvvm.MMALayout.row, loc=loc, ip=ip) + out = cute.make_tensor(cute.recast_ptr(local_ptr, dtype=cute.Int32), 1) + out[0] = cute.Int32(out_i32) + + +@dsl_user_op +def ptx_ldmatrix_x2(smem_ptr: Pointer, local_ptr: Pointer, *, loc=None, ip=None) -> None: + """Load 2 matrices (8x8 each) from shared memory""" + _ldmatrix(smem_ptr, local_ptr, 2, False, loc, ip) + + +@dsl_user_op +def ptx_ldmatrix_x4(smem_ptr: Pointer, local_ptr: Pointer, *, loc=None, ip=None) -> None: + """Load 4 matrices (8x8 each) from shared memory""" + _ldmatrix(smem_ptr, local_ptr, 4, False, loc, ip) + + +@dsl_user_op +def ptx_ldmatrix_x1_trans(smem_ptr: Pointer, local_ptr: Pointer, *, loc=None, ip=None) -> None: + """Load 1 matrix (8x8) with transpose from shared memory""" + out_i32 = nvvm.ldmatrix(T.i32(), smem_ptr.llvm_ptr, num=1, layout=nvvm.MMALayout.col, loc=loc, ip=ip) + out = cute.make_tensor(cute.recast_ptr(local_ptr, dtype=cute.Int32), 1) + out[0] = cute.Int32(out_i32) + + +@dsl_user_op +def ptx_ldmatrix_x2_trans(smem_ptr: Pointer, local_ptr: Pointer, *, loc=None, ip=None) -> None: + """Load 2 matrices (8x8 each) with transpose from shared memory""" + _ldmatrix(smem_ptr, local_ptr, 2, True, loc, ip) + + +@dsl_user_op +def ptx_ldmatrix_x4_trans(smem_ptr: Pointer, local_ptr: Pointer, *, loc=None, ip=None) -> None: + """Load 4 matrices (8x8 each) with transpose from shared memory""" + _ldmatrix(smem_ptr, local_ptr, 4, True, loc, ip) + + +# ============================================================================ +# STMATRIX operations (store from registers to shared memory) +# ============================================================================ + + +@dsl_user_op +def ptx_stmatrix_x1(smem_ptr: Pointer, 
value0, *, loc=None, ip=None) -> None: + """Store 1 matrix (8x8) to shared memory""" + _stmatrix(smem_ptr, [value0], False, loc, ip) + + +@dsl_user_op +def ptx_stmatrix_x2(smem_ptr: Pointer, value0, value1, *, loc=None, ip=None) -> None: + """Store 2 matrices (8x8 each) to shared memory""" + _stmatrix(smem_ptr, [value0, value1], False, loc, ip) + + +@dsl_user_op +def ptx_stmatrix_x4(smem_ptr: Pointer, value0, value1, value2, value3, *, loc=None, ip=None) -> None: + """Store 4 matrices (8x8 each) to shared memory""" + _stmatrix(smem_ptr, [value0, value1, value2, value3], False, loc, ip) + + +@dsl_user_op +def ptx_stmatrix_x1_trans(smem_ptr: Pointer, value0, *, loc=None, ip=None) -> None: + """Store 1 matrix (8x8) with transpose to shared memory""" + _stmatrix(smem_ptr, [value0], True, loc, ip) + + +@dsl_user_op +def ptx_stmatrix_x2_trans(smem_ptr: Pointer, value0, value1, *, loc=None, ip=None) -> None: + """Store 2 matrices (8x8 each) with transpose to shared memory""" + _stmatrix(smem_ptr, [value0, value1], True, loc, ip) + + +@dsl_user_op +def ptx_stmatrix_x4_trans(smem_ptr: Pointer, value0, value1, value2, value3, *, loc=None, ip=None) -> None: + """Store 4 matrices (8x8 each) with transpose to shared memory""" + _stmatrix(smem_ptr, [value0, value1, value2, value3], True, loc, ip) diff --git a/tilelang/contrib/cutedsl/math.py b/tilelang/contrib/cutedsl/math.py new file mode 100644 index 000000000..3f775091b --- /dev/null +++ b/tilelang/contrib/cutedsl/math.py @@ -0,0 +1,9 @@ +import cutlass.cute as cute +from cutlass.cute.typing import Union, Numeric +from cutlass.cute.tensor import TensorSSA +from cutlass._mlir.dialects import arith +from cutlass.cute.math import exp, exp2, log, log2, log10, tan, cos, sin, sqrt # noqa: F401 + + +def divf(x: Union[TensorSSA, Numeric], y: Union[TensorSSA, Numeric], fastmath: bool = False) -> Union[TensorSSA, Numeric]: + return cute.math._math_op(arith.divf, fastmath, x, y) diff --git a/tilelang/contrib/cutedsl/mbar.py b/tilelang/contrib/cutedsl/mbar.py new file mode 100644 index 000000000..ca956e2f4 --- /dev/null +++ b/tilelang/contrib/cutedsl/mbar.py @@ -0,0 +1,45 @@ +""" +Simple wrappers that delegate to cutlass.cute.arch implementations. +We use the existing implementations from cutlass rather than reinventing the wheel. 
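+
+Minimal sketch (hypothetical names; ``mbar`` is an mbarrier in shared memory):
+
+    mbarrier_init(mbar, arrive_count)
+    fence_barrier_init()
+    ...
+    mbarrier_wait(mbar, phase)  # phase alternates 0/1 each time the barrier is reused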
+""" + +from cutlass.cute.typing import Pointer, Int, Int32, Boolean # noqa: F401 +from cutlass.cutlass_dsl import CuTeDSL, dsl_user_op # noqa: F401 +from cutlass._mlir.dialects import nvvm + +from cutlass.cute.arch import mbarrier_init, mbarrier_expect_tx, mbarrier_arrive # noqa: F401 +from cutlass.cute.arch import mbarrier_arrive_and_expect_tx as arrive_and_expect_tx # noqa: F401 +from cutlass.cute.arch import cp_async_mbarrier_arrive_noinc as mbarrier_cp_async_arrive_noinc # noqa: F401 + +import cutlass.cute.arch as arch + + +@dsl_user_op +def mbarrier_wait(mbar_ptr: Pointer, phase: Int, timeout_ns: Int = 10000000, *, loc=None, ip=None) -> None: + """Waits on a mbarrier with a specified phase.""" + nvvm.mbarrier_try_wait_parity_shared( + mbar_ptr.llvm_ptr, + Int32(phase).ir_value(loc=loc, ip=ip), + Int32(timeout_ns).ir_value(loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def mbarrier_cp_async_arrive(mbar_ptr: Pointer, *, loc=None, ip=None) -> None: + mbar_llvm_ptr = mbar_ptr.llvm_ptr + nvvm.cp_async_mbarrier_arrive_shared( + mbar_llvm_ptr, + noinc=False, + loc=loc, + ip=ip, + ) + + +def fence_proxy_async(): + arch.fence_proxy(arch.ProxyKind.async_shared, space=arch.SharedSpace.shared_cta) + + +def fence_barrier_init(): + arch.mbarrier_init_fence() diff --git a/tilelang/contrib/cutedsl/reduce.py b/tilelang/contrib/cutedsl/reduce.py new file mode 100644 index 000000000..f835b149b --- /dev/null +++ b/tilelang/contrib/cutedsl/reduce.py @@ -0,0 +1,186 @@ +""" +Reduce operations for CuTeDSL backend. +Based on tl_templates/cuda/reduce.h +""" + +from __future__ import annotations + +import cutlass +import cutlass.cute as cute +from cutlass.cute.typing import Int32, Float32 +from cutlass.cutlass_dsl import dsl_user_op, T +from cutlass._mlir.dialects import nvvm +from cutlass.cute.arch.nvvm_wrappers import shuffle_sync_op + + +@dsl_user_op +def min(a: float | Float32, b: float | Float32, c: float | Float32 | None = None, *, loc=None, ip=None) -> Float32: + return Float32( + nvvm.fmin( + T.f32(), + Float32(a).ir_value(loc=loc, ip=ip), + Float32(b).ir_value(loc=loc, ip=ip), + c=Float32(c).ir_value(loc=loc, ip=ip) if c is not None else None, + loc=loc, + ip=ip, + ) + ) + + +@dsl_user_op +def max(a: float | Float32, b: float | Float32, c: float | Float32 | None = None, *, loc=None, ip=None) -> Float32: + return Float32( + nvvm.fmax( + T.f32(), + Float32(a).ir_value(loc=loc, ip=ip), + Float32(b).ir_value(loc=loc, ip=ip), + c=Float32(c).ir_value(loc=loc, ip=ip) if c is not None else None, + loc=loc, + ip=ip, + ) + ) + + +class SumOp: + """Sum reduction operator""" + + @staticmethod + def __call__(x, y): + return x + y + + +class MaxOp: + """Max reduction operator""" + + @staticmethod + def __call__(x, y): + return max(x, y) + + +class MinOp: + """Min reduction operator""" + + @staticmethod + def __call__(x, y): + # Use cutlass.min which is JIT-friendly + return min(x, y) + + +class BitAndOp: + """Bitwise AND reduction operator""" + + @staticmethod + def __call__(x, y): + return x & y + + +class BitOrOp: + """Bitwise OR reduction operator""" + + @staticmethod + def __call__(x, y): + return x | y + + +class BitXorOp: + """Bitwise XOR reduction operator""" + + @staticmethod + def __call__(x, y): + return x ^ y + + +def bar_sync(barrier_id, number_of_threads): + cute.arch.barrier(barrier_id=barrier_id, number_of_threads=number_of_threads) + + +def bar_sync_ptx(barrier_id, number_of_threads): + from cutlass._mlir.dialects import llvm + + llvm.inline_asm( + None, + [Int32(barrier_id).ir_value(), 
Int32(number_of_threads).ir_value()], + "bar.sync $0, $1;", + "r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + +def AllReduce(reducer, threads, scale, thread_offset, all_threads=None): + """ + AllReduce operation implementing warp/block-level reduction. + Based on tl::AllReduce from reduce.h + + Args: + reducer: Reducer operator class (SumOp, MaxOp, etc.) + threads: Number of threads participating in reduction + scale: Reduction scale factor + thread_offset: Thread ID offset + all_threads: Total number of threads in block + + Returns: + A callable object with run() and run_hopper() methods + """ + + class AllReduceInstance: + def __init__(self, reducer, threads, scale, thread_offset: cutlass.Constexpr[int], all_threads: cutlass.Constexpr[int]): + self.reducer = reducer + self.threads = threads + self.scale = scale + self.thread_offset = thread_offset + self.all_threads = all_threads if all_threads is not None else threads + + def run(self, x, red_buf: cute.Pointer = None): + """ + Perform all-reduce across threads. + Based on tl::AllReduce<...>::run from reduce.h + """ + offset = self.threads // 2 + + if offset >= 32: + # Use shared memory for large thread counts + cute.arch.sync_threads() + tidx, _, _ = cute.arch.thread_idx() + cute.make_tensor(red_buf + tidx - self.thread_offset, (1,))[0] = x + cute.arch.sync_threads() + x = self.reducer()(x, cute.make_tensor(red_buf + ((tidx - self.thread_offset) ^ offset), (1,))[0]) + else: + # Use warp shuffle for small thread counts + # Use the pre-existing shuffle_sync_op with butterfly (XOR) mode + other = shuffle_sync_op(x, offset, mask=0xFFFFFFFF, mask_and_clamp=0x1F, kind=nvvm.ShflKind.bfly) + x = self.reducer()(x, other) + + return ( + x + if offset == self.scale + else AllReduce(self.reducer, offset, self.scale, self.thread_offset, self.all_threads).run(x, red_buf) + ) + + def run_hopper(self, x, red_buf: cute.Pointer = None): + """ + Perform all-reduce on Hopper architecture using bar.sync. 
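+            Unlike ``run``, this path synchronizes with explicit ``bar.sync``
+            over ``all_threads`` threads rather than a full ``sync_threads()``.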
+ Based on tl::AllReduce<...>::run_hopper from reduce.h + """ + offset = self.threads // 2 + tidx, _, _ = cute.arch.thread_idx() + if offset >= 32: + # Use inlined asm for bar.sync to avoid instruction reordering + bar_sync_ptx(1, self.all_threads) + cute.make_tensor(red_buf + tidx - self.thread_offset, (1,))[0] = x + bar_sync_ptx(2, self.all_threads) + x = self.reducer()(x, cute.make_tensor(red_buf + ((tidx - self.thread_offset) ^ offset), (1,))[0]) + else: + # Use warp shuffle for small thread counts + # Use the pre-existing shuffle_sync_op with butterfly (XOR) mode + other = shuffle_sync_op(x, offset, mask=0xFFFFFFFF, mask_and_clamp=0x1F, kind=nvvm.ShflKind.bfly) + x = self.reducer()(x, other) + + return ( + x + if offset == self.scale + else AllReduce(self.reducer, offset, self.scale, self.thread_offset, self.all_threads).run_hopper(x, red_buf) + ) + + return AllReduceInstance(reducer, threads, scale, thread_offset, all_threads) diff --git a/tilelang/contrib/cutedsl/threadblock_swizzle.py b/tilelang/contrib/cutedsl/threadblock_swizzle.py new file mode 100644 index 000000000..1ce78eb86 --- /dev/null +++ b/tilelang/contrib/cutedsl/threadblock_swizzle.py @@ -0,0 +1,54 @@ +import cutlass.cute as cute +from cutlass.cute.typing import Constexpr +from dataclasses import dataclass + + +@dataclass(frozen=True) +class dim3: + x: int + y: int + z: int + + +def ThreadIdx() -> dim3: + return dim3(*cute.arch.thread_idx()) + + +def BlockIdx() -> dim3: + return dim3(*cute.arch.block_idx()) + + +def GridDim() -> dim3: + return dim3(*cute.arch.grid_dim()) + + +@cute.jit +def rasterization2DRow(panel_width: Constexpr[int]) -> dim3: + blockIdx = BlockIdx() + gridDim = GridDim() + block_idx = blockIdx.x + blockIdx.y * gridDim.x + grid_size = gridDim.x * gridDim.y + panel_size = panel_width * gridDim.x + panel_offset = block_idx % panel_size + panel_idx = block_idx // panel_size + total_panel = cute.ceil_div(grid_size, panel_size) + stride = panel_width if panel_idx + 1 < total_panel else (grid_size - panel_idx * panel_size) // gridDim.x + col_idx = (gridDim.x - 1 - panel_offset // stride) if (panel_idx & 1 != 0) else (panel_offset // stride) + row_idx = panel_offset % stride + panel_idx * panel_width + return dim3(col_idx, row_idx, blockIdx.z) + + +@cute.jit +def rasterization2DColumn(panel_width: Constexpr[int]) -> dim3: + blockIdx = BlockIdx() + gridDim = GridDim() + block_idx = blockIdx.x + blockIdx.y * gridDim.x + grid_size = gridDim.x * gridDim.y + panel_size = panel_width * gridDim.y + panel_offset = block_idx % panel_size + panel_idx = block_idx // panel_size + total_panel = cute.ceil_div(grid_size, panel_size) + stride = panel_width if panel_idx + 1 < total_panel else (grid_size - panel_idx * panel_size) // gridDim.y + row_idx = (gridDim.y - 1 - panel_offset // stride) if (panel_idx & 1 != 0) else (panel_offset // stride) + col_idx = panel_offset % stride + panel_idx * panel_width + return dim3(col_idx, row_idx, blockIdx.z) diff --git a/tilelang/engine/lower.py b/tilelang/engine/lower.py index 8b70f6d40..fda7f7509 100644 --- a/tilelang/engine/lower.py +++ b/tilelang/engine/lower.py @@ -197,7 +197,8 @@ def device_codegen(device_mod: tvm.IRModule, target: Target) -> tvm.IRModule: device_mod = tir.transform.Simplify()(device_mod) if target.kind.name == "cuda": - device_mod = tvm.ffi.get_global_func("target.build.tilelang_cuda")(device_mod, target) + global_func = "target.build.tilelang_" + ("cutedsl" if "cutedsl" in target.keys else "cuda") + device_mod = 
tvm.ffi.get_global_func(global_func)(device_mod, target) elif target.kind.name == "hip": device_mod = tvm.ffi.get_global_func("target.build.tilelang_hip")(device_mod, target) else: @@ -211,7 +212,8 @@ def device_codegen_without_compile(device_mod: tvm.IRModule, target: Target) -> device_mod = tilelang.transform.LowerIntrin()(device_mod) device_mod = tir.transform.Simplify()(device_mod) if target.kind.name == "cuda": - device_mod = tvm.ffi.get_global_func("target.build.tilelang_cuda_without_compile")(device_mod, target) + global_func = "target.build.tilelang_" + ("cutedsl" if "cutedsl" in target.keys else "cuda") + "_without_compile" + device_mod = tvm.ffi.get_global_func(global_func)(device_mod, target) elif target.kind.name == "hip": device_mod = tvm.ffi.get_global_func("target.build.tilelang_hip_without_compile")(device_mod, target) elif target.kind.name == "c": diff --git a/tilelang/jit/__init__.py b/tilelang/jit/__init__.py index a61c91d12..eac206f72 100644 --- a/tilelang/jit/__init__.py +++ b/tilelang/jit/__init__.py @@ -49,7 +49,7 @@ def compile( func: PrimFunc[_KP, _T] = None, out_idx: list[int] | int | None = None, - execution_backend: Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch"] = "auto", + execution_backend: Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch", "cutedsl"] = "auto", target: str | Target = "auto", target_host: str | Target | None = None, verbose: bool = False, @@ -64,7 +64,7 @@ def compile( The TileLang TIR function to compile and wrap. out_idx : Union[List[int], int], optional Index(es) of the output tensors to return (default: None). - execution_backend : Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch"], optional + execution_backend : Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch", "cutedsl"], optional Execution backend to use for kernel execution. Use "auto" to pick a sensible default per target (cuda->tvm_ffi, metal->torch, others->cython). target : Union[str, Target], optional @@ -118,7 +118,7 @@ def compile( def par_compile( funcs: Iterable[PrimFunc[_KP, _T]], out_idx: list[int] | int | None = None, - execution_backend: Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch"] = "auto", + execution_backend: Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch", "cutedsl"] = "auto", target: str | Target = "auto", target_host: str | Target | None = None, verbose: bool = False, @@ -135,7 +135,7 @@ def par_compile( The TileLang TIR functions to compile and wrap. out_idx : Union[List[int], int], optional Index(es) of the output tensors to return (default: None). - execution_backend : Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch"], optional + execution_backend : Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch", "cutedsl"], optional Execution backend to use for kernel execution. Use "auto" to pick a sensible default per target (cuda->tvm_ffi, metal->torch, others->cython). 
target : Union[str, Target], optional @@ -256,7 +256,7 @@ class JITImpl(Generic[_P, _KP, _T, _Ret]): """ out_idx: list[int] | int | None - execution_backend: Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch"] + execution_backend: Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch", "cutedsl"] target: str | Target target_host: str | Target verbose: bool @@ -424,7 +424,7 @@ def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _Ret: return kernel -ExecutionBackend = Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch"] +ExecutionBackend = Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch", "cutedsl"] @overload @@ -473,7 +473,7 @@ def jit( # This is the new public interface Compilation target for TVM (e.g., "cuda", "llvm"). Defaults to "auto". target_host : Union[str, Target], optional Target host for cross-compilation. Defaults to None. - execution_backend : Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch"], optional + execution_backend : Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch", "cutedsl"], optional Backend for kernel execution and argument passing. Use "auto" to pick a sensible default per target (cuda->tvm_ffi, metal->torch, others->cython). verbose : bool, optional diff --git a/tilelang/jit/adapter/__init__.py b/tilelang/jit/adapter/__init__.py index dcfdaf5bf..f511608fc 100644 --- a/tilelang/jit/adapter/__init__.py +++ b/tilelang/jit/adapter/__init__.py @@ -4,3 +4,4 @@ from .cython import CythonKernelAdapter # noqa: F401 from .nvrtc import NVRTCKernelAdapter # noqa: F401 from .torch import MetalKernelAdapter # noqa: F401 +from .cutedsl import CuTeDSLKernelAdapter # noqa: F401 diff --git a/tilelang/jit/adapter/cutedsl/__init__.py b/tilelang/jit/adapter/cutedsl/__init__.py new file mode 100644 index 000000000..e25899a1d --- /dev/null +++ b/tilelang/jit/adapter/cutedsl/__init__.py @@ -0,0 +1,16 @@ +"""CuTeDSL Backend for TileLang. + +This module provides runtime compilation support using NVIDIA's CuTeDSL API. 
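+
+The backend is selected by passing ``execution_backend="cutedsl"`` to
+``tilelang.compile`` / ``tilelang.jit``.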
+""" + +__all__ = [ + "CuTeDSLKernelAdapter", + "TLCuTeDSLSourceWrapper", + "CuTeDSLLibraryGenerator", + "check_cutedsl_available", +] + +from .checks import check_cutedsl_available # noqa: F401 +from .adapter import CuTeDSLKernelAdapter # noqa: F401 +from .wrapper import TLCuTeDSLSourceWrapper # noqa: F401 +from .libgen import CuTeDSLLibraryGenerator # noqa: F401 diff --git a/tilelang/jit/adapter/cutedsl/adapter.py b/tilelang/jit/adapter/cutedsl/adapter.py new file mode 100644 index 000000000..a0ab5db4d --- /dev/null +++ b/tilelang/jit/adapter/cutedsl/adapter.py @@ -0,0 +1,368 @@ +from __future__ import annotations +import logging +from typing import Any, Callable + +import torch +from tvm import tir +from tvm.target import Target + +from tilelang import tvm as tvm +from tilelang.engine.param import KernelParam +from tilelang.jit.adapter.wrapper import TLPyWrapper +from tilelang.jit.adapter.cutedsl.checks import check_cutedsl_available +from tilelang.jit.adapter.cutedsl.libgen import CuTeDSLLibraryGenerator +from tilelang.utils.language import retrieve_func_from_module +from tilelang.utils.target import determine_target +from tilelang.jit.adapter.base import BaseKernelAdapter + +logger = logging.getLogger(__name__) + + +class CuTeDSLKernelAdapter(BaseKernelAdapter): + pymodule = None + + def __init__( + self, + params: list[KernelParam], + result_idx: list[int], + target: str | Target, + func_or_mod: tir.PrimFunc | tvm.IRModule, + host_mod: tvm.IRModule | None = None, + device_mod: tvm.IRModule | None = None, + host_kernel_source: str | None = None, + device_kernel_source: str | None = None, + verbose: bool = False, + pass_configs: dict[str, Any] | None = None, + compile_flags: list[str] | None = None, + ): + check_cutedsl_available() + + self.params = params + self.result_idx = self._legalize_result_idx(result_idx) + self.host_kernel_source = host_kernel_source + self.device_kernel_source = device_kernel_source + + if isinstance(func_or_mod, tir.PrimFunc): + gsym = func_or_mod.attrs.get("global_symbol") + if gsym is None: + raise ValueError("PrimFunc is missing required attr 'global_symbol'") + self.ir_module = tvm.IRModule({gsym: func_or_mod}) + else: + self.ir_module = func_or_mod + + # Cache parameter information during initialization + self.param_dtypes = [param.torch_dtype() for param in params] + self.param_shapes = [] + for param in params: + native_shape = [] + for dim in param.shape: + if isinstance(dim, tir.IntImm): + native_shape.append(int(dim)) + elif isinstance(dim, tir.Var): + # Keep tir.Var for dynamic dimensions + native_shape.append(dim) + else: + native_shape.append(dim) + self.param_shapes.append(native_shape) + + self.dynamic_symbolic_map, self.dynamic_symbolic_order = self._process_dynamic_symbolic() + + self.target = Target.canon_target(determine_target(target)) + self.verbose = verbose + self.wrapper = TLPyWrapper(self.target) + self.wrapper.assign_optimized_module(self.ir_module) + self.wrapper.assign_pass_configs(pass_configs) + self.wrapper.assign_host_module(host_mod) + self.wrapper.assign_device_module(device_mod) + wrapper_result = self.wrapper.wrap(device_kernel_source) + self.host_func = wrapper_result["host_func"] + self.function_names = wrapper_result["function_names"] + self.tma_cpp_init_code = wrapper_result["tma_cpp_init_code"] + self.tma_lib_name = wrapper_result["tma_lib_name"] + self.launcher_cpp_code = wrapper_result.get("launcher_cpp_code", None) + self.launcher_lib_name = wrapper_result.get("launcher_lib_name", None) + + self.lib_generator = 
CuTeDSLLibraryGenerator(self.target, self.verbose) + self.lib_generator.update_lib_code(self.device_kernel_source) + self.lib_generator.update_host_func(self.host_func) + self.lib_generator.update_tma_cpp_init_code(self.tma_cpp_init_code) + self.lib_generator.update_tma_lib_name(self.tma_lib_name) + self.lib_generator.update_launcher_cpp_code(self.launcher_cpp_code) + self.lib_generator.update_launcher_lib_name(self.launcher_lib_name) + self.lib_generator.assign_compile_flags(compile_flags) + self.lib_generator.compile_lib() + self.lib_generator.load_lib() + self.libpath = self.lib_generator.libpath + self.device_kernel_source = open(self.libpath).read() + self.pymodule = self.lib_generator.pymodule + + self._post_init() + + @classmethod + def from_database( + cls, + params: list[KernelParam], + result_idx: list[int], + target: str, + func_or_mod: tir.PrimFunc | tvm.IRModule, + host_kernel_source: str, + device_kernel_source: str, + kernel_lib_path: str, + verbose: bool = False, + pass_configs: dict[str, Any] | None = None, + compile_flags: list[str] | None = None, + ): + adapter = cls.__new__(cls) + adapter.params = params + adapter.result_idx = adapter._legalize_result_idx(result_idx) + adapter.host_kernel_source = host_kernel_source + adapter.device_kernel_source = device_kernel_source + + if isinstance(func_or_mod, tir.PrimFunc): + gsym = func_or_mod.attrs.get("global_symbol") + if gsym is None: + raise ValueError("PrimFunc is missing required attr 'global_symbol'") + adapter.ir_module = tvm.IRModule({gsym: func_or_mod}) + else: + adapter.ir_module = func_or_mod + + # Cache parameter information during initialization + adapter.param_dtypes = [param.torch_dtype() for param in params] + adapter.param_shapes = [] + for param in params: + native_shape = [] + for dim in param.shape: + if isinstance(dim, tir.IntImm): + native_shape.append(int(dim)) + elif isinstance(dim, tir.Var): + # Keep tir.Var for dynamic dimensions + native_shape.append(dim) + else: + native_shape.append(dim) + adapter.param_shapes.append(native_shape) + + adapter.dynamic_symbolic_map, adapter.dynamic_symbolic_order = adapter._process_dynamic_symbolic() + + adapter.target = Target.canon_target(determine_target(target)) + adapter.verbose = verbose + adapter.lib_generator = CuTeDSLLibraryGenerator(adapter.target, adapter.verbose) + adapter.lib_generator.assign_compile_flags(compile_flags) + adapter.lib_generator.load_lib(lib_path=kernel_lib_path) + adapter.libpath = kernel_lib_path + adapter.kernel_global_source = open(adapter.libpath).read() + adapter.pymodule = adapter.lib_generator.pymodule + + adapter._post_init() + return adapter + + def _process_dynamic_symbolic(self) -> tuple[dict[tir.Var, tuple[int, int, int]], list[tir.Var]]: + """Extract information about dynamic symbols from the TIR function. 
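+
+        For example (illustrative), a shape variable ``m`` appearing as dim 0 of the
+        first buffer parameter is mapped to ``(0, 0, 0)``, and a stride variable at
+        stride index 1 of that buffer to ``(1, 0, 1)``; see the encoding below.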
+ + We follow the same ordering semantics as `TLCUDASourceWrapper.get_dynamic_symbolic_set()`: + 1) dynamic symbols in buffer shapes (in prim_func param order) + 2) then dynamic symbols in buffer strides + + The mapping encodes: + - id=0: shape var -> (0, buffer_param_index, dim_index) + - id=1: stride var -> (1, buffer_param_index, stride_index) + + Returns: + (dynamic_symbolic_map, dynamic_symbolic_order) + """ + func = self.prim_func + params = func.params + buffer_map = func.buffer_map + dynamic_symbolic_map: dict[tir.Var, tuple[int, int, int]] = {} + dynamic_symbolic_order: list[tir.Var] = [] + + def unique_push_back(v: tir.Var, entry: tuple[int, int, int]): + if v in dynamic_symbolic_map: + return + dynamic_symbolic_map[v] = entry + dynamic_symbolic_order.append(v) + + # 1) Shapes + for i, param in enumerate(params): + if param not in buffer_map: + continue + buffer = buffer_map[param] + for j, shape in enumerate(buffer.shape): + if isinstance(shape, tir.Var): + unique_push_back(shape, (0, i, j)) + + # 2) Strides + for i, param in enumerate(params): + if param not in buffer_map: + continue + buffer = buffer_map[param] + if buffer.strides is None: + continue + for j, stride in enumerate(buffer.strides): + if isinstance(stride, tir.Var): + unique_push_back(stride, (1, i, j)) + + return dynamic_symbolic_map, dynamic_symbolic_order + + def get_kernel_source(self, kernel_only: bool = True) -> str | None: + """Get the CUDA kernel source code. + + Returns + ------- + str | None + The kernel source code, or None if not available + """ + return self.device_kernel_source + + def _forward_from_prebuild_lib(self, *args, stream: int | None = None): + """Low-level function to call the compiled CUDA kernel.""" + result = self.pymodule.call(*args, stream=stream) + + # After first call, save cubin to cache if needed + self._save_cubin_to_cache_if_needed() + + return result + + def _save_cubin_to_cache_if_needed(self): + """Save cubin to cache directory after first execution. + + This is called after the first kernel execution to ensure the generated + cubin file is copied to the cache directory for future reuse. + """ + if getattr(self, "_cubin_saved_to_cache", False): + return + self._cubin_saved_to_cache = True + + # Check if we have a cache path (set by kernel_cache) + cache_path = getattr(self, "_cache_path", None) + if cache_path is None: + return + + import os + import shutil + + # Source cubin path (in temp directory) + src_py_path = self.libpath + src_py_stem = os.path.splitext(os.path.basename(src_py_path))[0] + src_dir = os.path.dirname(src_py_path) + src_cubin_path = os.path.join(src_dir, f"{src_py_stem}.cubin") + + if not os.path.exists(src_cubin_path): + return + + # Destination cubin path (in cache directory) + dst_cubin_path = os.path.join(cache_path, "kernel.cubin") + + if os.path.exists(dst_cubin_path): + return + + # Copy cubin to cache + try: + shutil.copy2(src_cubin_path, dst_cubin_path) + logger.debug(f"Saved CuTeDSL cubin to cache: {dst_cubin_path}") + except Exception as e: + logger.warning(f"Failed to save cubin to cache: {e}", exc_info=True) + + def _wrap_forward_from_prebuild_lib(self, *ins: Any, stream: int | None = None): + """High-level wrapper for kernel execution. + + Handles: + 1. Input validation + 2. Output tensor allocation + 3. Dynamic shape resolution + 4. 
CUDA stream management + + Args: + ins: Input arguments (may include scalars and tensors) + stream: Optional CUDA stream for asynchronous execution + + Returns: + Single tensor or list of tensors containing the kernel results + """ + if len(ins) + len(self.result_idx) != len(self.params): + raise ValueError( + f"Expected {len(self.params)} inputs, got {len(ins) + len(self.result_idx)} with {len(ins)} inputs and {len(self.result_idx)} outputs" + ) + + # Materialize args in PrimFunc param order (inputs + allocated outputs) + ins_idx = 0 + param_values: list[Any] = [None] * len(self.params) + for i in range(len(self.params)): + if i in self.result_idx: + continue + param_values[i] = ins[ins_idx] + ins_idx += 1 + + first_tensor = next((v for v in param_values if isinstance(v, torch.Tensor)), None) + if first_tensor is None: + raise ValueError("Expected at least one torch.Tensor argument to infer CUDA device") + + args: list[Any] = [] + + # tensor pointers + for i in range(len(self.params)): + if i in self.result_idx: + dtype = self.param_dtypes[i] + shape = [] + # Now working with native Python list, no FFI calls needed + for s in self.param_shapes[i]: + if isinstance(s, tir.Var): + ref_id, ref_param_idx, ref_dim_idx = self.dynamic_symbolic_map[s] + ref_val = param_values[ref_param_idx] + if not isinstance(ref_val, torch.Tensor): + raise TypeError(f"Dynamic shape/stride var {s} refers to a non-tensor param at index {ref_param_idx}") + if ref_id == 0: + shape.append(ref_val.shape[ref_dim_idx]) + elif ref_id == 1: + # Stride vars are not expected in output shapes, but handle defensively. + shape.append(ref_val.stride()[ref_dim_idx]) + else: + raise ValueError(f"Unknown dynamic symbol ref id: {ref_id}") + else: # Already converted to Python int during initialization + shape.append(s) + tensor = torch.empty(*shape, dtype=dtype, device=first_tensor.device) + param_values[i] = tensor + else: + tensor = param_values[i] + args.append(tensor) + + # dynamic symbolics + for sym in self.dynamic_symbolic_order: + ref_id, buffer_idx, dim_idx = self.dynamic_symbolic_map[sym] + ref_val = param_values[buffer_idx] + if not isinstance(ref_val, torch.Tensor): + raise TypeError(f"Dynamic symbolic var {sym} refers to a non-tensor param at index {buffer_idx}") + if ref_id == 0: + args.append(ref_val.shape[dim_idx]) + elif ref_id == 1: + args.append(ref_val.stride()[dim_idx]) + else: + raise ValueError(f"Unknown dynamic symbol ref id: {ref_id}") + + # if stream is not None, we need to pass the stream to the library + if stream is None: + if str(self.target).startswith("cuda") and torch.cuda.is_available(): + stream = torch.cuda.current_stream().cuda_stream + else: + stream = 0 + + self._forward_from_prebuild_lib(*args, stream=stream) + + if len(self.result_idx) == 1: + return args[self.result_idx[0]] + else: + return [args[i] for i in self.result_idx] + + def _convert_torch_func(self) -> Callable[..., torch.Tensor | list[torch.Tensor]]: + """Convert to a PyTorch-compatible function. 
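+
+        The returned callable takes the kernel's input tensors positionally;
+        tensors listed in ``result_idx`` are allocated automatically and returned
+        (a single tensor, or a list when there are multiple outputs).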
+ + Returns + ------- + Callable[..., torch.Tensor | list[torch.Tensor]] + A callable function that takes tensors and returns tensor(s) + """ + return self._wrap_forward_from_prebuild_lib + + @property + def prim_func(self) -> tir.PrimFunc: + """Returns the primary TIR function from the IR module.""" + return retrieve_func_from_module(self.ir_module) diff --git a/tilelang/jit/adapter/cutedsl/checks.py b/tilelang/jit/adapter/cutedsl/checks.py new file mode 100644 index 000000000..ced8ea7c3 --- /dev/null +++ b/tilelang/jit/adapter/cutedsl/checks.py @@ -0,0 +1,79 @@ +from __future__ import annotations + +import re +from importlib import metadata as _importlib_metadata +from importlib.util import find_spec as _find_spec +import os + +_CUTEDSL_PUBLIC_DIST = "nvidia-cutlass-dsl" +_CUTEDSL_MIN_VERSION = (4, 3, 1) +_VERSION_TRIPLE_RE = re.compile(r"(\d+)\.(\d+)\.(\d+)") + + +def _parse_version_triple(version_str: str) -> tuple[int, int, int] | None: + """Parse a best-effort (major, minor, patch) triple from a version string. + + We intentionally avoid importing heavy/optional version parsers. For our + minimum requirement (>= 4.3.1), a numeric triple comparison is sufficient. + """ + m = _VERSION_TRIPLE_RE.search(version_str) + if not m: + return None + return int(m.group(1)), int(m.group(2)), int(m.group(3)) + + +def _min_version_str() -> str: + return ".".join(map(str, _CUTEDSL_MIN_VERSION)) + + +def _requirement_spec() -> str: + return f"{_CUTEDSL_PUBLIC_DIST}>={_min_version_str()}" + + +def check_cutedsl_available() -> None: + """Fail fast if the CuTeDSL backend cannot be used in this Python environment. + + Policy: + - If the public distribution `nvidia-cutlass-dsl` is installed, require version >= a minimum supported version. + - Regardless of distribution metadata, require that `cutlass.cute` is importable. + + This intentionally does not mention or special-case any internal distributions. + """ + # 1) Version gate (only when the public dist metadata is present) + try: + dist_version = _importlib_metadata.version(_CUTEDSL_PUBLIC_DIST) + except _importlib_metadata.PackageNotFoundError: + dist_version = None + except Exception: + # Metadata is best-effort; don't block internal/nonstandard installs here. + dist_version = None + + if dist_version is not None: + parsed = _parse_version_triple(dist_version) + if parsed is None or parsed < _CUTEDSL_MIN_VERSION: + req = _requirement_spec() + raise ImportError( + f"CuTeDSL backend requires `{req}`, but found version `{dist_version}`. Please run: `pip install -U '{req}'`." + ) + + # 2) Capability probe: keep it cheap. + # Importing cutlass/cute can be expensive and defeats our lazy-import design, + # especially on cache hits. We only require that the module is importable. + cutlass_spec = _find_spec("cutlass") + if cutlass_spec is None: + req = _requirement_spec() + raise ImportError(f"CuTeDSL backend requires the CUTLASS Python DSL with CuTe support (install via `pip install -U '{req}'`).") + + # Avoid find_spec("cutlass.cute") which can be surprisingly expensive. + # Instead, check for a 'cute' submodule/package under cutlass's search locations. 
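+    # This is a best-effort filesystem probe: it looks for either a `cute/` package
+    # directory or a single-module `cute.py`, so cutlass is never imported here.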
+ locs = getattr(cutlass_spec, "submodule_search_locations", None) + has_cute = False + if locs: + for base in locs: + if os.path.isdir(os.path.join(base, "cute")) or os.path.isfile(os.path.join(base, "cute.py")): + has_cute = True + break + + if not has_cute: + req = _requirement_spec() + raise ImportError(f"CuTeDSL backend requires the CUTLASS Python DSL with CuTe support (install via `pip install -U '{req}'`).") diff --git a/tilelang/jit/adapter/cutedsl/libgen.py b/tilelang/jit/adapter/cutedsl/libgen.py new file mode 100644 index 000000000..3dac6b141 --- /dev/null +++ b/tilelang/jit/adapter/cutedsl/libgen.py @@ -0,0 +1,124 @@ +"""CuTeDSL Library Generator for TileLang. + +This module provides library generation functionality for the CuTeDSL backend. +""" + +from __future__ import annotations +import importlib.util +import os +import tempfile +import subprocess + +from tvm.target import Target + +from tilelang.jit.adapter.libgen import LibraryGenerator +from tilelang.jit.adapter.utils import is_cutedsl_target + + +class CuTeDSLLibraryGenerator(LibraryGenerator): + host_func: str | None = None + tma_cpp_init_code: str | None = None + tma_lib_name: str | None = None + launcher_cpp_code: str | None = None + launcher_lib_name: str | None = None + pymodule = None + + def __init__(self, target: Target, verbose: bool = False): + super().__init__(target, verbose) + + @staticmethod + def import_from_file(module_name, file_path): + spec = importlib.util.spec_from_file_location(module_name, file_path) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + + def update_host_func(self, host_func: str): + self.host_func = host_func + + def update_tma_cpp_init_code(self, tma_cpp_init_code: str): + self.tma_cpp_init_code = tma_cpp_init_code + + def update_tma_lib_name(self, tma_lib_name: str): + self.tma_lib_name = tma_lib_name + + def update_launcher_cpp_code(self, launcher_cpp_code: str): + self.launcher_cpp_code = launcher_cpp_code + + def update_launcher_lib_name(self, launcher_lib_name: str): + self.launcher_lib_name = launcher_lib_name + + def load_lib(self, lib_path: str | None = None): + if lib_path is None: + if self.libpath is None: + raise RuntimeError("CuTeDSLLibraryGenerator.libpath is not set; call compile_lib() first or pass lib_path explicitly.") + lib_path = self.libpath + + self.pymodule = self.import_from_file("kernel", lib_path) + + def compile_lib(self, timeout: float = None): + if self.host_func is None: + raise RuntimeError("CuTeDSLLibraryGenerator.host_func is not set; call update_host_func() before compile_lib().") + target = self.target + if is_cutedsl_target(target): + # Use a dedicated temp directory per kernel so CuTeDSL artifacts (e.g. kept .cubin) + # never pollute user CWD, and are easy to locate alongside the generated module. + work_dir = tempfile.mkdtemp(prefix="tilelang_cutedsl_") + src_path = os.path.join(work_dir, "kernel.py") + with open(src_path, "w") as f: + # Note: lib_code (containing @cute.kernel definitions) is embedded + # inside host_func's _generate_cubin_if_needed function, so we only + # write host_func here. This ensures cute imports are lazy-loaded. 
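+                # kernel.py is what load_lib() later imports via importlib
+                # (self.libpath is pointed at src_path below), so keeping the
+                # module free of cutlass/cute imports at module level is what
+                # keeps cache hits cheap.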
+ f.write(self.host_func) + + # Compile C++ launcher library if needed + if self.launcher_cpp_code is not None: + with tempfile.NamedTemporaryFile( + mode="w", + suffix=".cpp", + delete=False, + ) as launcher_src: + launcher_src.write(self.launcher_cpp_code) + launcher_src_path = launcher_src.name + + # Generate launcher lib under the same directory as the source file + launcher_lib_path = os.path.join(os.path.dirname(src_path), self.launcher_lib_name) + + # Get TVM FFI compiler flags using tvm_ffi.libinfo API + try: + import tvm_ffi.libinfo + + include_paths = tvm_ffi.libinfo.include_paths() + tvm_cxxflags = [f"-I{path}" for path in include_paths] + lib_path = tvm_ffi.libinfo.find_libtvm_ffi() + lib_dir = os.path.dirname(lib_path) + tvm_ldflags = [f"-L{lib_dir}", "-ltvm_ffi"] + except (ImportError, RuntimeError): + # tvm_ffi unavailable or libinfo functions failed + tvm_cxxflags = [] + tvm_ldflags = [] + + # Compile with nvcc (need CUDA driver API) + compile_cmd = [ + "nvcc", + "-shared", + "-Xcompiler=-fPIC", + "-lcuda", + *tvm_cxxflags, + *tvm_ldflags, + "-o", + launcher_lib_path, + launcher_src_path, + ] + + result = subprocess.run(compile_cmd, check=False, capture_output=True, text=True, timeout=timeout) + if result.returncode != 0: + raise RuntimeError(f"Failed to compile C++ launcher: {result.stderr}") + + self.launcher_libpath = launcher_lib_path + self.launcher_libname = self.launcher_lib_name + + self.srcpath = src_path + self.libpath = src_path + else: + raise ValueError(f"Unsupported target: {target}") diff --git a/tilelang/jit/adapter/cutedsl/wrapper.py b/tilelang/jit/adapter/cutedsl/wrapper.py new file mode 100644 index 000000000..c20d2ec67 --- /dev/null +++ b/tilelang/jit/adapter/cutedsl/wrapper.py @@ -0,0 +1,1354 @@ +"""CuTeDSL Source Wrapper for TileLang. + +This module provides C++ kernel launcher generation for the CuTeDSL backend. 
+ +Key features: +- Automatic C++ launcher generation with CUDA Driver API +- TMA descriptors on HOST memory, passed via __grid_constant__ (no device copy needed) +- cuLaunchKernel automatically copies 128-byte CUtensorMap to kernel param space +- Support for single and multiple kernel launches +- Complete cache system integration +""" + +from __future__ import annotations +from typing import Any, ClassVar + +from tvm import IRModule +from tvm.target import Target +from tvm.tir.stmt_functor import post_order_visit + +from tilelang import tvm as tvm +from tilelang.jit.adapter.wrapper import TLCUDASourceWrapper +from tilelang.jit.adapter.utils import ( + extract_python_func_declaration, + pythonic_expr, + parse_tma_descriptor_args, +) + +# ============================================================================= +# C++ LAUNCHER TEMPLATES (using named parameters for clarity) +# ============================================================================= + +# TMA single descriptor initialization template (writes to caller-provided host array) +# No device copy needed - cuLaunchKernel handles __grid_constant__ params automatically +CPP_TMA_DESC_INIT_TEMPLATE = """\ + // Descriptor {desc_idx}: {desc_name} (tensor: {tensor_name}) + {{ + uint64_t globalDim[{rank}] = {{{global_dim_values}}}; + uint64_t globalStrides[{stride_rank}] = {{{global_stride_values}}}; + uint32_t boxDim[{rank}] = {{{box_dim_values}}}; + uint32_t elemStrides[{rank}] = {{{elem_stride_values}}}; + + result = cuTensorMapEncodeTiled( + &tma_descs[{desc_idx}], + static_cast({dtype}), + {rank}, + reinterpret_cast({tensor_name}_ptr), + globalDim, + globalStrides, + boxDim, + elemStrides, + static_cast({interleave}), + static_cast({swizzle}), + static_cast({l2_promotion}), + static_cast({oob_fill}) + ); + + if (result != CUDA_SUCCESS) {{ + std::cerr << "Failed to encode TMA descriptor {desc_idx}: " << result << "\\n"; + return result; + }} + }} +""" + +# TMA single im2col descriptor initialization template (writes to caller-provided host array) +# Align field ordering with NVRTC wrapper (cuTensorMapEncodeIm2col signature). 
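+# Placeholders such as {desc_idx}, {rank} and the *_values lists are filled in by
+# _generate_tma_desc_init(); shape/stride expressions are rendered via _cxx_expr()
+# so symbolic (dynamic) dimensions stay valid C++.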
+CPP_TMA_IM2COL_DESC_INIT_TEMPLATE = """\ + // Descriptor {desc_idx}: {desc_name} (tensor: {tensor_name}) [im2col] + {{ + uint64_t globalDim[{rank}] = {{{global_dim_values}}}; + uint64_t globalStrides[{stride_rank}] = {{{global_stride_values}}}; + uint32_t elemStrides[{rank}] = {{{elem_stride_values}}}; + int32_t lowerCorner[{rank_minus_two}] = {{{lower_corner_values}}}; + int32_t upperCorner[{rank_minus_two}] = {{{upper_corner_values}}}; + + result = cuTensorMapEncodeIm2col( + &tma_descs[{desc_idx}], + static_cast({dtype}), + {rank}, + reinterpret_cast({tensor_name}_ptr), + globalDim, + globalStrides, + lowerCorner, + upperCorner, + static_cast({channels_per_pixel}), + static_cast({pixels_per_column}), + elemStrides, + static_cast({interleave}), + static_cast({swizzle}), + static_cast({l2_promotion}), + static_cast({oob_fill}) + ); + + if (result != CUDA_SUCCESS) {{ + std::cerr << "Failed to encode TMA im2col descriptor {desc_idx}: " << result << "\\n"; + return result; + }} + }} +""" + +# TMA initialization function template (writes to caller-provided host array) +# __grid_constant__ allows kernel to receive TMA descriptor by value via param space +CPP_TMA_INIT_FUNC_TEMPLATE = """\ +CUresult tma_init(CUtensorMap* tma_descs, {func_args}) {{ + // Initialize {num_descs} TMA descriptor(s) in caller-provided host array + // cuLaunchKernel will copy 128-byte CUtensorMap to kernel param space automatically + CUresult result; + +{desc_init_code} + + return CUDA_SUCCESS; +}} +""" + +# Kernel initialization template +CPP_KERNEL_INIT_TEMPLATE = """\ + // Find and configure kernel {kernel_idx}: {kernel_name} + result = find_kernel_by_pattern(g_module, "{kernel_name}", &g_kernels[{kernel_idx}]); + if (result != CUDA_SUCCESS) {{ + std::cerr << "Failed to find kernel {kernel_name}: " << result << "\\n"; + return result; + }} + + if ({smem_size} > 0) {{ + result = cuFuncSetAttribute(g_kernels[{kernel_idx}], + CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, + {smem_size}); + if (result != CUDA_SUCCESS) {{ + std::cerr << "Failed to set smem for {kernel_name}: " << result << "\\n"; + return result; + }} + }} +""" + +# TMA launch initialization template (host memory mode - uses __grid_constant__) +# Kernel receives TMA descriptor by value: .param .align 128 .b8 xxx_param[128] +CPP_TMA_LAUNCH_INIT_TEMPLATE = """\ + // Declare stack-local TMA descriptor array (eliminates concurrency race) + CUtensorMap tma_descs[{num_tma_descs}]; + + // Initialize TMA descriptors (HOST memory - passed via __grid_constant__) + // NOTE: We intentionally do NOT reuse/cached descriptors across launches. + // Pointer-only reuse is a correctness trap (shape/stride may change with same ptr), + // and correctness beats micro-optimizations. 
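+    // The stack array only has to outlive the cuLaunchKernel calls below:
+    // kernel parameters (including each 128-byte CUtensorMap) are copied into
+    // the launch's parameter space when the launch is enqueued.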
+ result = tma_init(tma_descs, {tma_tensor_args}); + if (result != CUDA_SUCCESS) {{ + std::cerr << "Failed to initialize TMA descriptors: " << result << "\\n"; + return result; + }} +""" + +# Kernel launch template +CPP_KERNEL_LAUNCH_TEMPLATE = """\ + // Launch kernel {kernel_idx}: {kernel_name} + {{ + void* args[] = {{{kernel_args}}}; + result = cuLaunchKernel( + g_kernels[{kernel_idx}], + {grid_x}, {grid_y}, {grid_z}, + {block_x}, {block_y}, {block_z}, + {smem_size}, + stream, + args, + nullptr + ); + if (result != CUDA_SUCCESS) {{ + std::cerr << "Failed to launch kernel {kernel_name}: " << result << "\\n"; + return result; + }} + }} +""" + +# Complete C++ launcher template +CPP_LAUNCHER_TEMPLATE = """\ +#include +#include +#include +#include +#include +#include +#include + +// TVM Headers +#include +#include +#include + +// Cached module handle +static CUmodule g_module = nullptr; +static bool g_module_initialized = false; + +// Cached kernel functions +static CUfunction g_kernels[{num_kernels}] = {{nullptr}}; +static bool g_kernels_initialized = false; + +// Find kernel by pattern (substring match, prefer base name over _N variants) +CUresult find_kernel_by_pattern(CUmodule module, const char* pattern, CUfunction* out_func) {{ + CUresult result; + unsigned int num_funcs = 0; + + result = cuModuleGetFunctionCount(&num_funcs, module); + if (result != CUDA_SUCCESS) {{ + std::cerr << "Failed to get function count: " << result << "\\n"; + return result; + }} + + std::vector func_list(num_funcs); + result = cuModuleEnumerateFunctions(func_list.data(), num_funcs, module); + if (result != CUDA_SUCCESS) {{ + std::cerr << "Failed to enumerate functions: " << result << "\\n"; + return result; + }} + + // Collect substring matches, separating base name from _N variants + std::vector> base_matches; // pattern not followed by _digit + std::vector> variant_matches; // pattern followed by _digit + + size_t pattern_len = std::strlen(pattern); + + for (unsigned int i = 0; i < num_funcs; i++) {{ + const char* func_name = nullptr; + result = cuFuncGetName(&func_name, func_list[i]); + if (result != CUDA_SUCCESS || func_name == nullptr) {{ + std::cerr << "Failed to get function name: " << result << "\\n"; + return result; + }} + + std::string name_str(func_name); + size_t pos = name_str.find(pattern); + + if (pos != std::string::npos) {{ + // Found substring match + size_t after_pattern = pos + pattern_len; + + // Check what follows the pattern + if (after_pattern < name_str.length() && + name_str[after_pattern] == '_' && + after_pattern + 1 < name_str.length() && + std::isdigit(name_str[after_pattern + 1])) {{ + // Pattern followed by _digit (e.g., "main_kernel_1") + variant_matches.push_back({{name_str, func_list[i]}}); + }} else {{ + // Pattern not followed by _digit (e.g., "main_kernel" itself) + base_matches.push_back({{name_str, func_list[i]}}); + }} + }} + }} + + // Decision logic: prefer base matches over variant matches + if (!base_matches.empty()) {{ + if (base_matches.size() == 1) {{ + *out_func = base_matches[0].second; + return CUDA_SUCCESS; + }} + + // Multiple base matches - ambiguous + std::cerr << "Error: Pattern '" << pattern << "' matched " << base_matches.size() + << " base kernels (ambiguous). 
Matches found:\\n"; + for (const auto& match : base_matches) {{ + std::cerr << " - " << match.first << "\\n"; + }} + std::cerr << "Please use a more specific pattern.\\n"; + return CUDA_ERROR_NOT_FOUND; + }} + + // No base matches, try variant matches + if (!variant_matches.empty()) {{ + if (variant_matches.size() == 1) {{ + *out_func = variant_matches[0].second; + return CUDA_SUCCESS; + }} + + // Multiple variant matches - ambiguous + std::cerr << "Error: Pattern '" << pattern << "' matched " << variant_matches.size() + << " variant kernels (ambiguous). Matches found:\\n"; + for (const auto& match : variant_matches) {{ + std::cerr << " - " << match.first << "\\n"; + }} + std::cerr << "Please use a more specific pattern (e.g., '" << pattern << "_1').\\n"; + return CUDA_ERROR_NOT_FOUND; + }} + + // No matches at all + std::cerr << "Failed to find kernel matching pattern '" << pattern << "'\\n"; + return CUDA_ERROR_NOT_FOUND; +}} + + +// Initialize CUDA module (called once on first launch) +static CUresult tilelang_init_cuda_module(const std::string& cubin_path) {{ + if (g_module_initialized) return CUDA_SUCCESS; + + CUresult result; + result = cuInit(0); + if (result != CUDA_SUCCESS) return result; + + std::ifstream cubin_file(cubin_path.c_str(), std::ios::binary); + if (!cubin_file) {{ + std::cerr << "Failed to open cubin file: " << cubin_path << "\\n"; + return CUDA_ERROR_FILE_NOT_FOUND; + }} + + std::vector cubin_data((std::istreambuf_iterator(cubin_file)), + std::istreambuf_iterator()); + cubin_file.close(); + + if (cubin_data.empty()) {{ + std::cerr << "Empty cubin file: " << cubin_path << "\\n"; + return CUDA_ERROR_INVALID_IMAGE; + }} + + result = cuModuleLoadData(&g_module, cubin_data.data()); + if (result != CUDA_SUCCESS) {{ + std::cerr << "Failed to load CUDA module: " << result << "\\n"; + return result; + }} + + g_module_initialized = true; + return CUDA_SUCCESS; +}} + +// Initialize all kernel functions (called once after module load) +static CUresult tilelang_init_kernels() {{ + if (g_kernels_initialized) return CUDA_SUCCESS; + CUresult result; + +{kernel_inits} + + g_kernels_initialized = true; + return CUDA_SUCCESS; +}} + +// TMA descriptor initialization (host-side) +{tma_init_func} + +// Main kernel launcher +extern "C" CUresult launch_kernel({launch_func_sig}, uint64_t _stream, tvm::ffi::Bytes cubin_path) {{ + CUresult result; + + std::string cubin_path_str(reinterpret_cast(cubin_path.data()), cubin_path.size()); + result = tilelang_init_cuda_module(cubin_path_str); + if (result != CUDA_SUCCESS) return result; + + result = tilelang_init_kernels(); + if (result != CUDA_SUCCESS) return result; + +{get_ptr_code} + CUstream stream = (CUstream)_stream; + +{tma_init_in_launch} + +{kernel_launches} + + return CUDA_SUCCESS; +}} + +// Cleanup function +extern "C" CUresult cleanup_module() {{ + if (g_module_initialized && g_module != nullptr) {{ + cuModuleUnload(g_module); + g_module = nullptr; + g_module_initialized = false; + }} + + g_kernels_initialized = false; + + return CUDA_SUCCESS; +}} + +TVM_FFI_DLL_EXPORT_TYPED_FUNC(launch_kernel, launch_kernel); +TVM_FFI_DLL_EXPORT_TYPED_FUNC(cleanup_module, cleanup_module); +""" + +# ============================================================================= +# PYTHON CUBIN GENERATION TEMPLATES +# ============================================================================= + +# TMA descriptor atom initialization template +CUBIN_TMA_ATOM_INIT_TEMPLATE = """\ + {desc_name} = tl.Gemm_SM90.get_tma_atom(__fake_tensor__, (32, 32))""" + +# 
Kernel launch call template +CUBIN_KERNEL_LAUNCH_TEMPLATE = """\ + {function_name}({call_args}).launch( + grid=[{grid_x}, {grid_y}, {grid_z}], + block=[{block_x}, {block_y}, {block_z}], + smem={smem_size}, + stream=stream, + )""" + +# Fake tensor creation template +CUBIN_FAKE_TENSOR_TEMPLATE = """\ + __fake_{arg_name}__ = make_fake_compact_tensor(_DTYPE_MAP[str({arg_name}.dtype)], {arg_name}.shape, stride_order={arg_name}.dim_order()[::-1], assumed_align=16)""" + +# Complete cubin generation code template +# {lib_code} contains the @cute.kernel definitions and is embedded here +CUBIN_GEN_CODE_TEMPLATE = """\ +{lib_code} + + @cute.jit + def kernel_wrapper({wrapper_args}): +{tma_init_code}{kernel_launches} + + # Compile kernels to generate cubin +{fake_tensor_code}{fake_tma_tensor_code} __fake_stream__ = make_fake_stream() + # Always generate cubin under a unique staging directory to avoid concurrent + # processes clobbering each other's intermediate artifacts. + _staging_dir = Path(tempfile.mkdtemp( + prefix=Path(__file__).stem + ".cubin.staging.", + dir=_module_dir, + )) + try: + _kernel_wrapper = cute.compile( + kernel_wrapper, + {compile_args}, + options=f"--enable-tvm-ffi --keep-cubin --dump-dir={{_staging_dir.as_posix()}}", + ) + + # CuTeDSL generates a long, mangled cubin filename that includes argument/type info, + # e.g. "cutlass_kernel_wrapper_FakeTensor...sm_90a.cubin". We expect exactly one cubin. + _cubin_files = sorted(_staging_dir.glob("*.cubin"), key=lambda p: p.stat().st_mtime) + if len(_cubin_files) != 1: + raise RuntimeError( + f"Expected exactly one .cubin under {{_staging_dir}}, got {{len(_cubin_files)}}: {{_cubin_files}}" + ) + os.replace(_cubin_files[0], _cubin_path) + finally: + shutil.rmtree(_staging_dir, ignore_errors=True)""" + +# ============================================================================= +# PYTHON HOST FUNCTION TEMPLATE +# ============================================================================= + +PYTHON_HOST_FUNC_TEMPLATE = """\ +import os +from pathlib import Path + +# Minimal imports for runtime (no cutlass/cute - only needed for cubin generation) +import tvm.runtime as runtime + +_cpp_launcher = None +_cpp_launcher_lib = None +_cubin_generated = False + +# Pre-compute paths - cubin is stored alongside the launcher .so +# Use module basename to avoid conflicts when multiple kernels run concurrently +# e.g., "/tmp/tmp8liu__ho.py" -> "/tmp/tmp8liu__ho.cubin" +# "kernel.py" (in cache) -> "kernel.cubin" +_module_dir = Path(os.path.dirname(__file__)) +_cubin_path = _module_dir / (Path(__file__).stem + ".cubin") +_cubin_path_bytes = _cubin_path.as_posix().encode('utf-8') +_cubin_needs_generation = not _cubin_path.exists() + +def _generate_cubin_if_needed({cubin_gen_params}): + \"\"\"Generate cubin file on first call. + + All CuTeDSL imports are inside this function to avoid slow + module-level initialization when loading from cache. + \"\"\" + global _cubin_generated, _cubin_path + + # Lazy import CuTeDSL only when cubin generation is needed + from cuda.bindings.driver import CUstream + import cutlass + import cutlass.cute as cute + from cutlass.cute.runtime import make_fake_stream, make_fake_compact_tensor + import tilelang.contrib.cutedsl as tl + # We rely on CuTeDSL's keep-cubin artifact rather than custom extraction. 
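+    # Everything below runs only when the cubin is missing: call() flips
+    # _cubin_needs_generation after a successful generation, so the heavy
+    # cutlass/cute imports above are paid only on a cache miss.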
+ import tempfile + import shutil + + _DTYPE_MAP = {{ + "torch.float32": cutlass.Float32, + "torch.float16": cutlass.Float16, + "torch.bfloat16": cutlass.BFloat16, + "torch.float8_e4m3fnuz": cutlass.Float8E4M3FN, + "torch.float8_e4m3fn": cutlass.Float8E4M3FN, + "torch.float8_e5m2": cutlass.Float8E5M2, + "torch.float64": cutlass.Float64, + "torch.int64": cutlass.Int64, + "torch.int32": cutlass.Int32, + "torch.uint32": cutlass.Uint32, + "torch.bool": cutlass.Boolean, + "torch.int8": cutlass.Int8, + "torch.uint8": cutlass.Uint8, + "torch.int16": cutlass.Int16, + "torch.uint16": cutlass.Uint16, + "torch.uchar": cutlass.Uint8, + }} + +{cubin_gen_code} + + _cubin_generated = True + +def _load_cpp_launcher(): + \"\"\"Load C++ kernel launcher.\"\"\" + global _cpp_launcher, _cpp_launcher_lib + if _cpp_launcher is not None: + return _cpp_launcher + + lib_path = os.path.join(os.path.dirname(__file__), "{launcher_lib_name}") + if not os.path.exists(lib_path): + raise FileNotFoundError(f"Launcher not found: {{lib_path}}") + + _cpp_launcher_lib = runtime.load_module(lib_path) + _cpp_launcher = _cpp_launcher_lib["launch_kernel"] + return _cpp_launcher + +def call({call_func_params}, stream): + \"\"\"Kernel dispatch function.\"\"\" + global _cubin_path_bytes, _cubin_needs_generation + + if _cubin_needs_generation: + _generate_cubin_if_needed({cubin_gen_call_args}) + _cubin_needs_generation = False + +{arg_prep_code} + + launcher = _load_cpp_launcher() + result = launcher({launcher_call_args}, stream, _cubin_path_bytes) + + if result != 0: + raise RuntimeError(f"Kernel launch failed with CUDA error {{result}}") +""" + +# ============================================================================= +# WRAPPER CLASS +# ============================================================================= + + +class TLCuTeDSLSourceWrapper(TLCUDASourceWrapper): + """Wrapper class for TileLang CuTe DSL backend with C++ launcher. + + Generates optimized C++ launcher code that: + - Loads cubin via CUDA Driver API + - Passes TMA descriptors by value (host-side, no device copy) + - Launches kernels with minimal Python overhead + - Supports both single and multiple kernel scenarios + """ + + _TYPE_MAP: ClassVar[dict[str, str]] = { + "float32": "cutlass.Float32", + "float16": "cutlass.Float16", + "bfloat16": "cutlass.BFloat16", + "float8_e4m3": "cutlass.Float8E4M3", + "float8_e5m2": "cutlass.Float8E5M2", + "float64": "cutlass.Float64", + "int64": "cutlass.Int64", + "int32": "cutlass.Int32", + "uint32": "cutlass.Uint32", + "bool": "cutlass.Boolean", + "int8": "cutlass.Int8", + "uint8": "cutlass.Uint8", + "int16": "cutlass.Int16", + "uint16": "cutlass.Uint16", + "uchar": "cutlass.Uint8", + } + + # C++ launcher code must not depend on cutlass Python types. + # Use plain C/C++ types for expression rendering inside generated .cpp. 
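+    # (Consumed by _cxx_expr() below, which renders TIR shape/stride expressions
+    # as C++ for the generated launcher.)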
+ _CXX_TYPE_MAP: ClassVar[dict[str, str]] = { + "float32": "float", + "float64": "double", + "int64": "int64_t", + "int32": "int32_t", + "uint32": "uint32_t", + "bool": "bool", + "int8": "int8_t", + "uint8": "uint8_t", + "int16": "int16_t", + "uint16": "uint16_t", + } + + _CTYPES_MAP: ClassVar[dict[str, str]] = { + "buffer": "ctypes.c_uint64", + "cutlass.Float32": "ctypes.c_float", + "cutlass.Float16": "ctypes.c_uint16", + "cutlass.Float64": "ctypes.c_double", + "cutlass.Int64": "ctypes.c_int64", + "cutlass.Int32": "ctypes.c_int32", + "cutlass.Uint32": "ctypes.c_uint32", + "cutlass.Int8": "ctypes.c_int8", + "cutlass.Uint8": "ctypes.c_uint8", + "cutlass.Int16": "ctypes.c_int16", + "cutlass.Uint16": "ctypes.c_uint16", + "int": "ctypes.c_int32", + } + + _generated_host_func: str | None = None + _launcher_lib_name: str | None = None + + def __init__( + self, + scheduled_ir_module: IRModule, + source: str, + target: Target, + device_mod: IRModule | None = None, + host_mod: IRModule | None = None, + pass_configs: dict[str, Any] | None = None, + ): + super().__init__(scheduled_ir_module, source, target, device_mod, host_mod, pass_configs) + + # ========================================================================= + # Properties + # ========================================================================= + + @property + def host_func(self): + """Override parent's host_func to return generated Python code.""" + if self._generated_host_func is not None: + return self._generated_host_func + return super().host_func + + @host_func.setter + def host_func(self, value): + """Allow setting generated host function code.""" + self._generated_host_func = value + + # ========================================================================= + # Utility Methods + # ========================================================================= + + def _pythonic_expr(self, expr: tvm.tir.PrimExpr) -> str: + """Convert TVM expression to Python string.""" + return pythonic_expr(expr, self._TYPE_MAP, floor_div_op="//") + + def _cxx_expr(self, expr: tvm.tir.PrimExpr) -> str: + """Convert TVM expression to C++ string for generated launcher code.""" + return pythonic_expr(expr, self._CXX_TYPE_MAP) + + @staticmethod + def _cxx_cast(ctype: str, expr_str: str) -> str: + return f"static_cast<{ctype}>({expr_str})" + + def _collect_function_args(self) -> tuple[list[dict], list[str]]: + """Collect all function arguments from primary function. 
+ + Returns: + Tuple of (function_args, buffer_args) + """ + function_args = [] + buffer_args = [] + + for param in self.prim_func.params: + if param in self.prim_func.buffer_map: + buffer = self.prim_func.buffer_map[param] + function_args.append({"name": buffer.data.name, "type": "buffer"}) + buffer_args.append(buffer.data.name) + elif isinstance(param, tvm.tir.Var): + function_args.append({"name": param.name, "type": self._TYPE_MAP[param.dtype]}) + else: + raise ValueError(f"Parameter {param} not in buffer map") + + existing_names = {arg["name"] for arg in function_args} + for dyn_sym in self.get_dynamic_symbolic_set(self.prim_func): + dyn_sym_name, dyn_sym_dtype = dyn_sym if isinstance(dyn_sym, tuple) else (dyn_sym, "int32") + if dyn_sym_name in existing_names: + continue + existing_names.add(dyn_sym_name) + function_args.append({"name": dyn_sym_name, "type": self._TYPE_MAP.get(dyn_sym_dtype, "int")}) + + return function_args, buffer_args + + @staticmethod + def _extract_func_call_args( + declaration: str, + function_args: list[dict], + function_params: list, + desc_name_map: dict[str, str] | None = None, + desc_name_var_map: dict[str, tvm.tir.Var] | None = None, + ) -> list[tuple[str, str]]: + """Extract function call arguments from Python function declaration.""" + + def maybe_desc(name: str | tuple[str, str], param_names: list[str], i: int): + name_str = name if isinstance(name, str) else name[0] + param = param_names[i] + if not (param == name_str + "_desc" or param.startswith(name_str + "_desc_")): + return False + if desc_name_map is not None: + desc_name_map[param] = name_str + return True + + def extract_param_names_ast(decl: str) -> list[str] | None: + """Extract parameter names using AST parsing.""" + import ast + import warnings + + try: + # Build a syntactically valid function by adding a body + func_stub = decl.rstrip() + if not func_stub.endswith(":"): + func_stub += ":" + func_stub += "\n pass" + + # Parse and locate the FunctionDef + tree = ast.parse(func_stub) + func_def = None + for node in ast.walk(tree): + if isinstance(node, ast.FunctionDef): + func_def = node + break + + if func_def is None: + return None + + # Extract parameter names, skipping 'self' + param_names = [] + for arg in func_def.args.args: + if arg.arg != "self": + param_names.append(arg.arg) + + return param_names + except Exception as e: + warnings.warn(f"AST parsing failed for function declaration, falling back to split-based parsing: {e}", stacklevel=2) + return None + + def extract_param_names_split(decl: str) -> list[str]: + """Fallback: extract parameter names using naive split-based parsing.""" + paren_start = decl.find("(") + paren_end = decl.rfind(")") + if paren_start == -1 or paren_end == -1: + return [] + + params_str = decl[paren_start + 1 : paren_end].strip() + if not params_str: + return [] + + param_parts = params_str.split(",") + param_names = [] + for param in param_parts: + param = param.strip() + if not param or param == "self": + continue + if ":" in param: + param_name = param.split(":")[0].strip() + else: + param_name = param.strip() + param_names.append(param_name) + + return param_names + + # Try AST-based extraction first, fallback to split-based + param_names = extract_param_names_ast(declaration) + if param_names is None: + param_names = extract_param_names_split(declaration) + + call_args = [] + for i, param_name in enumerate(param_names): + for arg in function_args: + if arg["name"] == param_name: + call_args.append((param_name, arg["type"])) + elif 
maybe_desc(arg["name"], param_names, i): + call_args.append((param_name, "None")) + if desc_name_var_map is not None and function_params is not None: + assert len(call_args) <= len(function_params) + desc_name_var_map[param_name] = function_params[len(call_args) - 1] + return call_args + + @staticmethod + def _filter_non_descriptor_args( + call_args: list[tuple[str, str]], desc_names: list[str], tma_tensors: list[str] + ) -> list[tuple[str, str]]: + """Filter out descriptor arguments.""" + filtered = [] + for arg_name, arg_type in call_args: + if "desc" in arg_name and arg_name in desc_names: + continue + if arg_name in tma_tensors: + continue + filtered.append((arg_name, arg_type)) + return filtered + + # ========================================================================= + # TMA Descriptor Code Generation + # ========================================================================= + + def _generate_tma_desc_init(self, desc_name: str, desc_idx: int, tensor_name: str, info: dict) -> str: + """Generate single TMA descriptor initialization code.""" + if info.get("is_img2col", False): + rank = info["tensor_rank"] + return CPP_TMA_IM2COL_DESC_INIT_TEMPLATE.format( + desc_idx=desc_idx, + desc_name=desc_name, + tensor_name=tensor_name, + rank=rank, + stride_rank=rank - 1, + rank_minus_two=rank - 2, + global_dim_values=", ".join(self._cxx_cast("uint64_t", self._cxx_expr(x)) for x in info["global_dim"]), + global_stride_values=", ".join(self._cxx_cast("uint64_t", self._cxx_expr(x)) for x in info["global_stride"][1:]), + elem_stride_values=", ".join(self._cxx_cast("uint32_t", self._cxx_expr(x)) for x in info["element_strides"]), + lower_corner_values=", ".join(self._cxx_cast("int32_t", self._cxx_expr(x)) for x in info["lower_corner"]), + upper_corner_values=", ".join(self._cxx_cast("int32_t", self._cxx_expr(x)) for x in info["upper_corner"]), + # Match NVRTC wrapper naming: channelsPerPixel then pixelsPerColumn + channels_per_pixel=info["smem_box_channel"], + pixels_per_column=info["smem_box_pixel"], + dtype=info["dtype"], + interleave=info["interleave"], + swizzle=info["swizzle"], + l2_promotion=info["l2Promotion"], + oob_fill=info["oobFill"], + ) + + return CPP_TMA_DESC_INIT_TEMPLATE.format( + desc_idx=desc_idx, + desc_name=desc_name, + tensor_name=tensor_name, + rank=info["tensor_rank"], + global_dim_values=", ".join(self._cxx_cast("uint64_t", self._cxx_expr(x)) for x in info["global_dim"]), + stride_rank=info["tensor_rank"] - 1, + global_stride_values=", ".join(self._cxx_cast("uint64_t", self._cxx_expr(x)) for x in info["global_stride"][1:]), + box_dim_values=", ".join(self._cxx_cast("uint32_t", self._cxx_expr(x)) for x in info["box_dim"]), + elem_stride_values=", ".join(self._cxx_cast("uint32_t", self._cxx_expr(x)) for x in info["element_strides"]), + dtype=info["dtype"], + interleave=info["interleave"], + swizzle=info["swizzle"], + l2_promotion=info["l2Promotion"], + oob_fill=info["oobFill"], + ) + + def _generate_tma_init_func( + self, + desc_names: list[str], + tensor_args: list[str], + tensor_arg_map: dict[str, tuple[str, int]], + scalar_args: list[dict[str, str]], + ) -> str: + """Generate TMA init function code (creates descriptors in caller-provided host array). + + TMA descriptors are stored in stack-local tma_descs[] array in launch_kernel. + cuLaunchKernel automatically handles __grid_constant__ params. 
+ """ + if not desc_names: + return "" + + func_args_parts = [f"uint64_t {arg}_ptr" for arg in tensor_args] + for arg in scalar_args: + if arg["type"] in ["int", "cutlass.Int32"]: + func_args_parts.append(f"int32_t {arg['name']}") + elif arg["type"] in ["float", "cutlass.Float32"]: + func_args_parts.append(f"float {arg['name']}") + else: + # Default to int32_t for scalars used in shape/stride math + func_args_parts.append(f"int32_t {arg['name']}") + func_args = ", ".join(func_args_parts) + num_descs = len(desc_names) + + desc_inits = [] + for idx, desc_name in enumerate(desc_names): + info = self.tma_desc_info[desc_name] + tensor_name, _ = tensor_arg_map[desc_name] + desc_inits.append(self._generate_tma_desc_init(desc_name, idx, tensor_name, info)) + + return CPP_TMA_INIT_FUNC_TEMPLATE.format( + func_args=func_args, + num_descs=num_descs, + desc_init_code="\n".join(desc_inits), + ) + + def _generate_tma_launch_init( + self, desc_names: list[str], tma_tensors: list[str], scalar_args: list[dict[str, str]], num_tma_descs: int + ) -> str: + """Generate TMA initialization code for launch function (host memory mode). + + TMA descriptors stay on host. cuLaunchKernel copies them to param space + when kernel uses __grid_constant__ CUtensorMap parameter. + """ + if not desc_names: + return "" + + # Generate tma_init call args (no device_ptr needed) + call_args_parts = [f"{arg}_ptr" for arg in tma_tensors] + [arg["name"] for arg in scalar_args] + tma_tensor_args = ", ".join(call_args_parts) + + return CPP_TMA_LAUNCH_INIT_TEMPLATE.format( + num_tma_descs=num_tma_descs, + tma_tensor_args=tma_tensor_args, + ) + + # ========================================================================= + # Kernel Code Generation + # ========================================================================= + + def _generate_kernel_init(self, kernel_idx: int, kernel_name: str, smem_size: int) -> str: + """Generate kernel initialization code.""" + return CPP_KERNEL_INIT_TEMPLATE.format( + kernel_idx=kernel_idx, + kernel_name=kernel_name, + smem_size=smem_size, + ) + + def _generate_kernel_launch(self, kernel_meta: dict, kernel_idx: int, all_desc_names: list[str]) -> str: + """Generate single kernel launch code. 
+ + For __grid_constant__ CUtensorMap params: + - Pass CUtensorMap* directly (not CUtensorMap**) + - cuLaunchKernel copies 128 bytes to kernel param space + """ + call_args = kernel_meta["call_args"] + desc_names = kernel_meta["desc_names"] + function_info = kernel_meta["function_info"] + + # Build kernel args + kernel_args = [] + for arg_name, arg_type in call_args: + if "desc" in arg_name and arg_name in desc_names: + # For __grid_constant__ CUtensorMap: pass host pointer directly + # cuLaunchKernel will copy 128-byte CUtensorMap to param space + desc_idx = all_desc_names.index(arg_name) + kernel_args.append(f"&tma_descs[{desc_idx}]") + elif arg_type == "buffer": + kernel_args.append(f"&{arg_name}_ptr") + else: + kernel_args.append(f"&{arg_name}") + + grid = function_info["grid_info"] + block = function_info["block_info"] + smem_size = function_info["dynamic_smem_buf"] or 0 + + return CPP_KERNEL_LAUNCH_TEMPLATE.format( + kernel_idx=kernel_idx, + kernel_name=kernel_meta["function_name"], + kernel_args=", ".join(kernel_args), + grid_x=self._cxx_expr(grid[0]), + grid_y=self._cxx_expr(grid[1]), + grid_z=self._cxx_expr(grid[2]), + block_x=self._cxx_expr(block[0]), + block_y=self._cxx_expr(block[1]), + block_z=self._cxx_expr(block[2]), + smem_size=smem_size, + ) + + # ========================================================================= + # C++ Launcher Generation + # ========================================================================= + + def _generate_cpp_launcher( + self, + kernel_metadata_list: list[dict], + function_args: list[dict], + all_tma_tensors: list[str], + all_desc_names: list[str], + tensor_arg_map: dict[str, tuple[str, int]], + ) -> str: + """Generate complete C++ launcher code using templates. + + TMA descriptors are stored on HOST memory in stack-local tma_descs[] array. + cuLaunchKernel automatically copies 128-byte CUtensorMap to kernel param space + when kernel uses __grid_constant__ parameter. 
+ """ + num_kernels = len(kernel_metadata_list) + num_tma_descs = max(len(all_desc_names), 1) # At least 1 to avoid zero-size array + + # Generate kernel inits + kernel_inits = "\n".join( + self._generate_kernel_init(idx, km["function_name"], km["function_info"]["dynamic_smem_buf"] or 0) + for idx, km in enumerate(kernel_metadata_list) + ) + + # Generate TMA init function + scalar_args = [arg for arg in function_args if arg["type"] != "buffer"] + tma_init_func = self._generate_tma_init_func(all_desc_names, all_tma_tensors, tensor_arg_map, scalar_args) + + # Generate launch function signature and get_ptr code + func_sig_parts = [] + get_ptr_code = "" + for arg in function_args: + if arg["type"] == "buffer": + func_sig_parts.append(f"tvm::ffi::TensorView {arg['name']}") + get_ptr_code += f" uint64_t {arg['name']}_ptr = reinterpret_cast({arg['name']}.data_ptr());\n" + elif arg["type"] in ["int", "cutlass.Int32"]: + func_sig_parts.append(f"int32_t {arg['name']}") + elif arg["type"] in ["float", "cutlass.Float32"]: + func_sig_parts.append(f"float {arg['name']}") + else: + func_sig_parts.append(f"int32_t {arg['name']}") + + # Generate TMA init in launch + tma_init_in_launch = self._generate_tma_launch_init(all_desc_names, all_tma_tensors, scalar_args, num_tma_descs) + + # Generate kernel launches + kernel_launches = "\n".join(self._generate_kernel_launch(km, idx, all_desc_names) for idx, km in enumerate(kernel_metadata_list)) + + return CPP_LAUNCHER_TEMPLATE.format( + num_kernels=num_kernels, + num_tma_descs=num_tma_descs, + kernel_inits=kernel_inits, + tma_init_func=tma_init_func, + launch_func_sig=", ".join(func_sig_parts), + get_ptr_code=get_ptr_code, + tma_init_in_launch=tma_init_in_launch, + kernel_launches=kernel_launches, + ) + + # ========================================================================= + # Python Wrapper Generation + # ========================================================================= + + def _generate_cubin_gen_code( + self, + kernel_metadata_list: list[dict], + buffer_args: list[str], + all_desc_names: list[str], + lib_code: str = "", + ) -> str: + """Generate cubin generation code for Python wrapper using templates. + + Args: + lib_code: The CuTeDSL kernel definitions (@cute.kernel decorated functions). + This will be embedded inside _generate_cubin_if_needed to enable + lazy loading of cutlass/cute modules. 
+ """ + # Build unified wrapper parameters + wrapper_params_union = [] + for kernel_meta in kernel_metadata_list: + for arg_name, _ in kernel_meta["call_args"]: + if arg_name not in wrapper_params_union: + wrapper_params_union.append(arg_name) + + # Build inner args for cute.compile + inner_args = [] + fake_inner_args = [] + for arg_name in wrapper_params_union: + if arg_name in buffer_args: + inner_args.append(f"{arg_name}_") + fake_inner_args.append(f"__fake_{arg_name}__") + elif arg_name in all_desc_names: + continue + else: + inner_args.append(arg_name) + fake_inner_args.append(arg_name) + if all_desc_names: + inner_args.append("__fake_tensor__") + fake_inner_args.append("__fake_tensor__") + fake_inner_args.append("__fake_stream__") + + # Generate TMA init code + tma_init_code = "" + if all_desc_names: + tma_init_lines = [" # Create dummy TMA atoms for compilation"] + tma_init_lines.extend(CUBIN_TMA_ATOM_INIT_TEMPLATE.format(desc_name=desc_name) for desc_name in all_desc_names) + tma_init_code = "\n".join(tma_init_lines) + "\n" + + # Generate kernel launch calls + kernel_launches = "\n".join( + CUBIN_KERNEL_LAUNCH_TEMPLATE.format( + function_name=km["function_name"], + call_args=", ".join(arg[0] if arg[0] not in buffer_args else f"{arg[0]}_" for arg in km["call_args"]), + grid_x=self._pythonic_expr(km["function_info"]["grid_info"][0]), + grid_y=self._pythonic_expr(km["function_info"]["grid_info"][1]), + grid_z=self._pythonic_expr(km["function_info"]["grid_info"][2]), + block_x=self._pythonic_expr(km["function_info"]["block_info"][0]), + block_y=self._pythonic_expr(km["function_info"]["block_info"][1]), + block_z=self._pythonic_expr(km["function_info"]["block_info"][2]), + smem_size=km["function_info"]["dynamic_smem_buf"] or 0, + ) + for km in kernel_metadata_list + ) + + # Generate fake tensor creation code + # IMPORTANT: Generate fake tensors based on the *union* of parameters actually + # passed to cute.compile (wrapper_params_union). + # + # In multi-kernel cases, a tensor may appear both as a TMA descriptor + # (e.g. Output_partial_desc) for one kernel and as a plain tensor argument + # (e.g. Output_partial_) for another kernel. Skipping fake tensor creation + # just because a matching "{arg}_desc" exists is a correctness bug and + # results in undefined names like "__fake_Output_partial__". 
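+        # For a hypothetical buffer arg "A", CUBIN_FAKE_TENSOR_TEMPLATE expands to
+        # (roughly):
+        #   __fake_A__ = make_fake_compact_tensor(_DTYPE_MAP[str(A.dtype)], A.shape,
+        #       stride_order=A.dim_order()[::-1], assumed_align=16)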
+ fake_tensor_code = "\n".join( + CUBIN_FAKE_TENSOR_TEMPLATE.format(arg_name=arg_name) for arg_name in wrapper_params_union if arg_name in buffer_args + ) + if fake_tensor_code: + fake_tensor_code += "\n" + + # Generate fake TMA tensor code + fake_tma_tensor_code = "" + if all_desc_names: + fake_tma_tensor_code = ( + " __fake_tensor__ = make_fake_compact_tensor(cutlass.Int32, (32, 32), stride_order=(1, 0), assumed_align=16)\n" + ) + + # Indent lib_code to be inside the function + indented_lib_code = "\n".join(" " + line if line.strip() else line for line in lib_code.split("\n")) if lib_code else "" + + return CUBIN_GEN_CODE_TEMPLATE.format( + lib_code=indented_lib_code, + wrapper_args=", ".join(inner_args + ["stream: CUstream"]), + tma_init_code=tma_init_code, + kernel_launches=kernel_launches, + fake_tensor_code=fake_tensor_code, + fake_tma_tensor_code=fake_tma_tensor_code, + compile_args=", ".join(fake_inner_args), + primary_name=kernel_metadata_list[0]["function_name"], + ) + + def _generate_python_wrapper( + self, + function_args: list[dict], + cubin_gen_code: str, + cubin_gen_params: str, + ) -> str: + """Generate Python wrapper code.""" + # Build function parameters + call_func_params = ", ".join(arg["name"] for arg in function_args) + launcher_call_args = ", ".join(arg["name"] for arg in function_args) + + return PYTHON_HOST_FUNC_TEMPLATE.format( + cubin_gen_params=cubin_gen_params, + cubin_gen_code=cubin_gen_code, + launcher_lib_name=self._launcher_lib_name, + call_func_params=call_func_params, + cubin_gen_call_args=cubin_gen_params, + arg_prep_code="", + launcher_call_args=launcher_call_args, + ) + + # ========================================================================= + # TMA Descriptor Processing + # ========================================================================= + + def _process_tma_descriptors(self, desc_names: list[str]) -> tuple[list[str], dict[str, tuple[str, int]]]: + """Process TMA descriptors and return tensor args and mapping. + + Returns: + Tuple of (tensor_args, tensor_arg_map) + """ + if not hasattr(self, "tma_desc_info") or not desc_names: + return [], {} + + tensor_args = [] + tensor_arg_map = {} + + for desc_name in desc_names: + info = self.tma_desc_info[desc_name] + # Extract the base buffer variable name (must be a Var, not arbitrary expression) + global_addr = info["globalAddress"] + if not isinstance(global_addr, tvm.tir.Var): + raise ValueError(f"TMA globalAddress must be a buffer Var, got {type(global_addr)}: {global_addr}") + tensor_name = global_addr.name + + if tensor_name not in tensor_args: + tensor_args.append(tensor_name) + tensor_arg_map[desc_name] = (tensor_name, len(tensor_args) - 1) + else: + tensor_arg_map[desc_name] = (tensor_name, tensor_args.index(tensor_name)) + + return tensor_args, tensor_arg_map + + def generate_tma_descriptor_args( + self, + desc_name_map: dict[str, str], + desc_name_var_map: dict[str, tvm.tir.Var], + tma_desc_code_map: dict[str, str], + ) -> list[str]: + """Generate TMA descriptor information for C++ code generation. + + Returns: + List of descriptor variable names in the order they were processed. 
+ """ + if self.tma_descriptor_args is None: + return [] + + if not hasattr(self, "tma_desc_info"): + self.tma_desc_info = {} + + parsed_params = parse_tma_descriptor_args(self.tma_descriptor_args, desc_name_map, desc_name_var_map, self._pythonic_expr) + + desc_names_ordered = [] + + for params in parsed_params: + handle_name = params.handle_name + + if handle_name in tma_desc_code_map: + continue + + desc_var = desc_name_var_map[handle_name] + args = self.tma_descriptor_args[desc_var] + _, dtype, tensor_rank, globalAddress, *remaining_args = args[1:] + tensor_rank = int(tensor_rank) + + global_dim = remaining_args[:tensor_rank] + global_stride = remaining_args[tensor_rank : 2 * tensor_rank] + + if not params.is_img2col: + box_dim = remaining_args[2 * tensor_rank : 3 * tensor_rank] + element_strides = remaining_args[3 * tensor_rank : 4 * tensor_rank] + + self.tma_desc_info[handle_name] = { + "desc_var": desc_var, + "is_img2col": False, + "dtype": params.dtype, + "tensor_rank": params.tensor_rank, + "globalAddress": params.global_address, + "global_dim": global_dim, + "global_stride": global_stride, + "box_dim": box_dim, + "element_strides": element_strides, + "interleave": params.interleave, + "swizzle": params.swizzle, + "l2Promotion": params.l2_promotion, + "oobFill": params.oob_fill, + } + else: + element_strides = remaining_args[2 * tensor_rank : 3 * tensor_rank] + + self.tma_desc_info[handle_name] = { + "desc_var": desc_var, + "is_img2col": True, + "dtype": params.dtype, + "tensor_rank": params.tensor_rank, + "globalAddress": params.global_address, + "global_dim": global_dim, + "global_stride": global_stride, + "element_strides": element_strides, + "lower_corner": params.lower_corner, + "upper_corner": params.upper_corner, + "smem_box_channel": params.smem_box_channel, + "smem_box_pixel": params.smem_box_pixel, + "interleave": params.interleave, + "swizzle": params.swizzle, + "l2Promotion": params.l2_promotion, + "oobFill": params.oob_fill, + } + + tma_desc_code_map[handle_name] = "" + desc_names_ordered.append(handle_name) + + return desc_names_ordered + + # ========================================================================= + # Main Entry Points + # ========================================================================= + + def create_dispatch_func(self, code, function_informations): + """Create dispatch function - always use C++ launcher.""" + return self.create_dispatch_func_cpp_launcher(code, function_informations) + + def create_dispatch_func_cpp_launcher(self, code, function_informations): + """Create dispatch function using C++ launcher.""" + function_args, buffer_args = self._collect_function_args() + + # Process each kernel and collect metadata + kernel_metadata = [] + all_desc_names_union = [] + all_tma_tensors_union = [] + + for function_name, function_info in function_informations.items(): + declaration = extract_python_func_declaration(code, function_name) + desc_name_map: dict[str, str] = {} + desc_name_var_map: dict[str, tvm.tir.Var] = {} + call_args = self._extract_func_call_args( + declaration, + function_args, + function_info["function_params"], + desc_name_map, + desc_name_var_map, + ) + + tma_desc_code_map = {} + desc_names = self.generate_tma_descriptor_args(desc_name_map, desc_name_var_map, tma_desc_code_map) + + tma_tensor_args, _ = self._process_tma_descriptors(desc_names) + + kernel_metadata.append( + { + "function_name": function_name, + "function_info": function_info, + "call_args": call_args, + "desc_names": desc_names, + "tma_tensor_args": 
tma_tensor_args, + "desc_name_map": desc_name_map, + } + ) + + for desc in desc_names: + if desc not in all_desc_names_union: + all_desc_names_union.append(desc) + for t in tma_tensor_args: + if t not in all_tma_tensors_union: + all_tma_tensors_union.append(t) + + # Process all TMA descriptors + all_tma_tensors, tensor_arg_map = self._process_tma_descriptors(all_desc_names_union) + + # Generate C++ launcher + launcher_cpp_code = self._generate_cpp_launcher( + kernel_metadata, function_args, all_tma_tensors, all_desc_names_union, tensor_arg_map + ) + + self.launcher_cpp_code = launcher_cpp_code + # Use a deterministic name so that: + # 1) the generated kernel.py can always locate the launcher in the same directory + # 2) KernelCache can store it under a stable filename + self._launcher_lib_name = "launcher_lib.so" + self.launcher_lib_name = self._launcher_lib_name + + # Generate cubin generation code (includes lib_code with @cute.kernel definitions) + cubin_gen_code = self._generate_cubin_gen_code( + kernel_metadata, buffer_args, all_desc_names_union, lib_code=getattr(self, "lib_code", "") + ) + + # Generate Python wrapper + buffer_names = [arg["name"] for arg in function_args if arg["type"] == "buffer"] + # Cubin generation may reference scalar args (e.g., dynamic symbols like m/n/k) + # inside `kernel_wrapper` and `cute.compile(...)`. They must be visible in + # `_generate_cubin_if_needed(...)` scope, so include them in its signature. + scalar_names = [arg["name"] for arg in function_args if arg["type"] != "buffer"] + cubin_gen_params = ", ".join(buffer_names + scalar_names) + + python_wrapper = self._generate_python_wrapper(function_args, cubin_gen_code, cubin_gen_params) + + return python_wrapper + + def get_launcher_cpp_code(self) -> str: + """Get the generated C++ launcher code.""" + return getattr(self, "launcher_cpp_code", "") + + def update_lib_code(self, code: str): + """Update the library code with the given code string.""" + self.lib_code = code + + function_informations = {} + for function_name in self.function_names: + if (function_name not in self.block_info) or (function_name not in self.grid_info): + continue + + assert function_name in self.device_mod, f"Function {function_name} not found in device module" + device_func = self.device_mod[function_name] + kernel_params_cnt = len(device_func.params) + function_params: list[str] = None + + def visitor(node, fn=function_name, param_cnt=kernel_params_cnt): + nonlocal function_params + if isinstance(node, tvm.tir.Call): + if not (hasattr(node, "op") and node.op == tvm.ir.Op.get("tir.tvm_call_packed")): + return + args = node.args + if not args or args[0] != fn: + return + if len(args) < 1 + param_cnt: + raise AssertionError("tvm_call_packed should have at least 1 argument and match device function parameters") + function_params = args[1 : 1 + param_cnt] + + post_order_visit(self.host_func.body, visitor) + assert function_params is not None, "function_params should not be None" + + function_informations[function_name] = { + "function_name": function_name, + "block_info": self.block_info[function_name], + "grid_info": self.grid_info[function_name], + "dynamic_smem_buf": self.dynamic_smem_buf[function_name], + "function_params": function_params, + } + + self.host_func = self.create_dispatch_func(code, function_informations) + return self.lib_code diff --git a/tilelang/jit/adapter/nvrtc/adapter.py b/tilelang/jit/adapter/nvrtc/adapter.py index b1b672997..083c8f215 100644 --- a/tilelang/jit/adapter/nvrtc/adapter.py +++ 
b/tilelang/jit/adapter/nvrtc/adapter.py @@ -76,7 +76,9 @@ def __init__( self.wrapper.assign_pass_configs(pass_configs) self.wrapper.assign_host_module(host_mod) self.wrapper.assign_device_module(device_mod) - self.host_func, self.function_names = self.wrapper.wrap(device_kernel_source) + wrapper_result = self.wrapper.wrap(device_kernel_source) + self.host_func = wrapper_result["host_func"] + self.function_names = wrapper_result["function_names"] self.lib_generator = NVRTCLibraryGenerator(self.target, self.verbose) self.lib_generator.update_lib_code(self.device_kernel_source) diff --git a/tilelang/jit/adapter/nvrtc/wrapper.py b/tilelang/jit/adapter/nvrtc/wrapper.py index 3df2b3bfa..2316823ec 100644 --- a/tilelang/jit/adapter/nvrtc/wrapper.py +++ b/tilelang/jit/adapter/nvrtc/wrapper.py @@ -273,7 +273,7 @@ def _pythonic_expr(self, expr: tvm.tir.PrimExpr) -> str: Casts are noise in generated Python code - Python is dynamically typed. """ - return pythonic_expr(expr, self._TYPE_MAP, ignore_cast=True) + return pythonic_expr(expr, self._TYPE_MAP, ignore_cast=True, floor_div_op="//") def create_dispatch_func(self, code, function_informations): """Generate Python dispatch function that launches multiple CUDA kernels. diff --git a/tilelang/jit/adapter/utils.py b/tilelang/jit/adapter/utils.py index 15801ffa7..d43adf840 100644 --- a/tilelang/jit/adapter/utils.py +++ b/tilelang/jit/adapter/utils.py @@ -38,6 +38,53 @@ def match_declare_kernel(source: str, annotation: str = "__global__") -> int: raise ValueError("No global kernel found in the source code") +def match_declare_kernel_cutedsl(source: str, annotation: str = "@cute.kernel") -> int: + # Match decorator followed by function definition across lines + # \s+ allows any whitespace including newlines between decorator and def + pattern = r"@cute\.kernel\s+def\s+(\w+)" + matched = re.search(pattern, source, re.MULTILINE) + if matched: + # Find the position of the opening parenthesis after the function name + # matched.start(1) gives position of function name + func_name_pos = matched.start(1) + # Find the '(' after function name + paren_pos = source.find("(", func_name_pos) + if paren_pos != -1: + return paren_pos + raise ValueError("No global kernel found in the source code") + + +def extract_python_func_declaration(source: str, func_name: str) -> str: + """Extract the full Python function declaration from decorator to colon. + + Args: + source: Source code containing the function + func_name: Name of the function to extract (can include '(' suffix) + + Returns: + The function declaration from 'def' to ':', including parameters + + Example: + For code: + @cute.kernel + def kernel(arg1: cute.Tensor, arg2: int): + ... 
+ Returns: "def kernel(arg1: cute.Tensor, arg2: int)" + """ + # Remove '(' suffix if present + if func_name.endswith("("): + func_name = func_name[:-1] + + # Match from def to the closing ) followed by : + # This handles multi-line function signatures + pattern = rf"def\s+{re.escape(func_name)}\s*\([^)]*\)" + matched = re.search(pattern, source, re.DOTALL) + if matched: + return matched.group(0) + + raise ValueError(f"No function declaration found for {func_name}") + + def match_declare_kernel_cpu(source: str, annotation: str = "int32_t") -> int: pattern = r"int32_t\s+\w+" for line in source.split("\n"): @@ -64,6 +111,10 @@ def is_metal_target(target: Target) -> bool: return target.kind.name == "metal" +def is_cutedsl_target(target: Target) -> bool: + return target.kind.name == "cuda" and "cutedsl" in target.keys + + def get_annotated_mod( func_or_mod: tir.PrimFunc | tvm.IRModule, target: str | Target = "auto", @@ -102,7 +153,9 @@ def get_annotated_mod( return dispatch[model_type](mod) -def pythonic_expr(expr: tvm.tir.PrimExpr, dtype_map: dict[str, str] | None = None, ignore_cast: bool = False) -> str: +def pythonic_expr( + expr: tvm.tir.PrimExpr, dtype_map: dict[str, str] | None = None, ignore_cast: bool = False, floor_div_op: str = "/" +) -> str: """ Converts a TVM PrimExpr into a Python-style string, correctly handling operator precedence. @@ -110,6 +163,10 @@ def pythonic_expr(expr: tvm.tir.PrimExpr, dtype_map: dict[str, str] | None = Non expr: The TVM PrimExpr to convert. dtype_map: A dictionary mapping data types to their string representations. ignore_cast: Whether to ignore the cast operator and return the string representation of the value without the cast. + floor_div_op: Operator to use for tvm.tir.FloorDiv. Default '/' preserves prior + behavior (suitable for generating C/C++ expressions). For generating + Python code where integer division is required (e.g. grid/block), + pass '//' explicitly. Returns: A string representation of the expression. """ @@ -180,7 +237,7 @@ def _visitor(node): ): op_map = { tvm.tir.Mul: "*", - tvm.tir.FloorDiv: "/", + tvm.tir.FloorDiv: floor_div_op, tvm.tir.Add: "+", tvm.tir.Sub: "-", tvm.tir.FloorMod: "%", diff --git a/tilelang/jit/adapter/wrapper.py b/tilelang/jit/adapter/wrapper.py index c028a58ef..d83d0ccc0 100644 --- a/tilelang/jit/adapter/wrapper.py +++ b/tilelang/jit/adapter/wrapper.py @@ -4,8 +4,10 @@ from typing import Any from tvm import IRModule from tvm.target import Target + from .utils import ( is_metal_target, + is_cutedsl_target, match_declare_kernel, match_declare_kernel_cpu, is_cuda_target, @@ -198,7 +200,9 @@ def __init__( self.lib_code: str | None = self.update_lib_code(source) def _pythonic_expr(self, expr: tvm.tir.PrimExpr) -> str: - return pythonic_expr(expr, self._TYPE_MAP) + # This wrapper generates C/CUDA source. C/C++ integer division uses '/', + # and '//' is not a valid operator in C/C++. 
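+        # e.g. a tir.FloorDiv(n, 128) node renders with "/" here, while the
+        # Python-emitting NVRTC/CuTeDSL wrappers pass floor_div_op="//" instead.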
+ return pythonic_expr(expr, self._TYPE_MAP, floor_div_op="/") def _lookup_type(self, dtype: str | Any) -> str: key = dtype if isinstance(dtype, str) else str(dtype) @@ -326,9 +330,9 @@ def generate_l2_persistent_map(self, function_name: str) -> str: return init_l2_persistent_map def generate_tma_descriptor_args(self, desc_name_map: dict[str, str], desc_name_var_map: dict[str, tvm.tir.Var]) -> str: - tma_descripter_init = "" + tma_descriptor_init = "" if self.tma_descriptor_args is None: - return tma_descripter_init + return tma_descriptor_init # Parse TMA descriptor arguments using the common utility parsed_params = parse_tma_descriptor_args(self.tma_descriptor_args, desc_name_map, desc_name_var_map, self._pythonic_expr) @@ -336,7 +340,7 @@ def generate_tma_descriptor_args(self, desc_name_map: dict[str, str], desc_name_ # Generate C++ code from parsed parameters for params in parsed_params: if not params.is_img2col: - tma_descripter_init += TMA_DESC_INIT_FUNC.format( + tma_descriptor_init += TMA_DESC_INIT_FUNC.format( params.handle_name, params.dtype, params.tensor_rank, @@ -351,7 +355,7 @@ def generate_tma_descriptor_args(self, desc_name_map: dict[str, str], desc_name_ params.oob_fill, ) else: - tma_descripter_init += TMA_IM2COL_DESC_INIT_FUNC.format( + tma_descriptor_init += TMA_IM2COL_DESC_INIT_FUNC.format( params.handle_name, params.dtype, params.tensor_rank, @@ -369,7 +373,7 @@ def generate_tma_descriptor_args(self, desc_name_map: dict[str, str], desc_name_ params.oob_fill, ) - return tma_descripter_init + return tma_descriptor_init def parse_source_information(self): if self.device_mod is None or self.host_mod is None: @@ -817,6 +821,9 @@ def update_lib_code(self, code: str): return self.lib_code +# TLCuTeDSLSourceWrapper has been moved to tilelang.jit.adapter.cutedsl.wrapper + + class TLWrapper(BaseWrapper): """ A wrapper class for the TileLang backend. @@ -875,9 +882,13 @@ class TLPyWrapper(TLWrapper): def __init__(self, target: Target): super().__init__(target) - def wrap(self, c_source: str): + def wrap(self, py_source: str): # assert self.scheduled_ir_module is not None, "Please assign optimized module first." 
- if is_cuda_target(self.target): + if is_cutedsl_target(self.target): + from tilelang.jit.adapter.cutedsl import TLCuTeDSLSourceWrapper + + wrapper_class = TLCuTeDSLSourceWrapper + elif is_cuda_target(self.target): from tilelang.jit.adapter.nvrtc import TLNVRTCSourceWrapper wrapper_class = TLNVRTCSourceWrapper @@ -885,10 +896,17 @@ def wrap(self, c_source: str): raise ValueError(f"Unsupported target for NVRTC backend: {self.target}") wrapper = wrapper_class( scheduled_ir_module=self.scheduled_ir_module, - source=c_source, + source=py_source, target=self.target, device_mod=self.device_mod, host_mod=self.host_mod, pass_configs=self.pass_configs, ) - return wrapper.host_func, wrapper.function_names + return { + "host_func": getattr(wrapper, "host_func", None), + "function_names": getattr(wrapper, "function_names", None), + "tma_cpp_init_code": getattr(wrapper, "tma_cpp_init_code", None), + "tma_lib_name": getattr(wrapper, "tma_lib_name", None), + "launcher_cpp_code": getattr(wrapper, "launcher_cpp_code", None), + "launcher_lib_name": getattr(wrapper, "launcher_lib_name", None), + } diff --git a/tilelang/jit/execution_backend.py b/tilelang/jit/execution_backend.py index 492e8cb0f..db5e4a8b4 100644 --- a/tilelang/jit/execution_backend.py +++ b/tilelang/jit/execution_backend.py @@ -3,6 +3,8 @@ from collections.abc import Iterable from tvm.target import Target +from tilelang.jit.adapter.utils import is_cutedsl_target +from tilelang.env import env as _env # Canonical names for execution backends used internally _CANONICAL_MAP = { @@ -30,7 +32,9 @@ def allowed_backends_for_target(target: Target, *, include_unavailable: bool = T """ kind = _target_kind(target) - if kind == "cuda": + if is_cutedsl_target(target): + return ["cutedsl"] + elif kind == "cuda": allowed = ["tvm_ffi", "nvrtc", "cython", "ctypes"] elif kind == "hip": allowed = ["tvm_ffi", "cython", "ctypes"] @@ -72,8 +76,26 @@ def resolve_execution_backend(requested: str | None, target: Target) -> str: allowed_all = allowed_backends_for_target(target, include_unavailable=True) allowed_avail = allowed_backends_for_target(target, include_unavailable=False) + def _require_gemm_v1_for_cutedsl(): + if not _env.use_gemm_v1(): + raise ValueError( + "CuTeDSL backend requires GEMM v1. Please set environment variable TILELANG_USE_GEMM_V1=1 before importing tilelang." + ) + # Fail fast with a clear error if CuTeDSL dependencies are missing or incompatible. + try: + from tilelang.jit.adapter.cutedsl.checks import check_cutedsl_available # lazy + + check_cutedsl_available() + except ImportError as e: + # Keep resolve_execution_backend's error semantics (ValueError) while + # preserving the actionable ImportError message. + raise ValueError(str(e)) from e + # Default selection for auto/None if req in (None, "auto"): + if is_cutedsl_target(target): + _require_gemm_v1_for_cutedsl() + return "cutedsl" kind = _target_kind(target) if kind == "cuda": choice = "tvm_ffi" @@ -100,4 +122,8 @@ def resolve_execution_backend(requested: str | None, target: Target) -> str: f"Try one of: {_format_options(allowed_avail)}." 
) + # CuTeDSL requires GEMM v1 + if req == "cutedsl": + _require_gemm_v1_for_cutedsl() + return req diff --git a/tilelang/jit/kernel.py b/tilelang/jit/kernel.py index 9a0dab89c..a788e76ba 100644 --- a/tilelang/jit/kernel.py +++ b/tilelang/jit/kernel.py @@ -7,7 +7,7 @@ except ImportError: # Python < 3.10 from typing_extensions import ParamSpec -from tilelang.jit.adapter.utils import is_metal_target, is_cuda_target +from tilelang.jit.adapter.utils import is_cutedsl_target, is_metal_target, is_cuda_target from tvm.target import Target from tvm.tir import PrimFunc @@ -15,7 +15,14 @@ from tilelang import tvm from tilelang import env from tilelang.engine.param import CompiledArtifact, KernelParam -from tilelang.jit.adapter import BaseKernelAdapter, CtypesKernelAdapter, CythonKernelAdapter, TVMFFIKernelAdapter, MetalKernelAdapter +from tilelang.jit.adapter import ( + BaseKernelAdapter, + CtypesKernelAdapter, + CythonKernelAdapter, + CuTeDSLKernelAdapter, + TVMFFIKernelAdapter, + MetalKernelAdapter, +) from tilelang.profiler import Profiler, TensorSupplyType from tilelang.utils.target import determine_target from tilelang.contrib import nvcc as tl_nvcc @@ -57,7 +64,7 @@ def __init__( self, func: PrimFunc = None, out_idx: list[int] | int = None, - execution_backend: Literal["tvm_ffi", "ctypes", "cython", "nvrtc", "torch"] = "tvm_ffi", + execution_backend: Literal["tvm_ffi", "ctypes", "cython", "nvrtc", "torch", "cutedsl"] = "tvm_ffi", target: str | Target = "auto", target_host: str | Target = None, verbose: bool = False, @@ -74,7 +81,7 @@ def __init__( The TileLang TIR function to compile and wrap. out_idx : Union[List[int], int], optional Index(es) of the output tensors to return (default: None). - execution_backend : Literal["tvm_ffi", "ctypes", "cython", "nvrtc", "torch"], optional + execution_backend : Literal["tvm_ffi", "ctypes", "cython", "nvrtc", "torch", "cutedsl"], optional Execution backend to use for kernel execution. target : Union[str, Target], optional Compilation target, either as a string or a TVM Target object (default: "auto"). @@ -109,6 +116,7 @@ def __init__( "cython", "nvrtc", "torch", + "cutedsl", ], f"Invalid execution backend. {execution_backend}" if execution_backend == "cython": from tilelang.contrib.cc import get_cplus_compiler @@ -316,6 +324,20 @@ def _compile_and_create_adapter(self, tilelang_func: PrimFunc, out_idx: list[int # pass_configs=pass_configs, # compile_flags=compile_flags, ) + elif execution_backend == "cutedsl": + assert is_cutedsl_target(target) + adapter = CuTeDSLKernelAdapter( + params=artifact.params, + result_idx=out_idx, + target=target, + func_or_mod=tilelang_func, + host_mod=artifact.host_mod, + device_mod=artifact.device_mod, + device_kernel_source=artifact.kernel_source, + verbose=verbose, + pass_configs=pass_configs, + compile_flags=compile_flags, + ) else: # Handle invalid backend. raise ValueError(f"Invalid execution backend: {execution_backend}") @@ -387,6 +409,18 @@ def _create_adapter_from_database( pass_configs=pass_configs, compile_flags=compile_flags, ) + elif execution_backend == "cutedsl": + adapter = CuTeDSLKernelAdapter.from_database( + params=params, + result_idx=result_idx, + target=target, + func_or_mod=func_or_mod, + host_kernel_source=host_kernel_source, + device_kernel_source=device_kernel_source, + kernel_lib_path=kernel_lib_path, + pass_configs=pass_configs, + compile_flags=compile_flags, + ) else: # Handle invalid backend. 
raise ValueError(f"Invalid execution backend: {execution_backend}") @@ -437,7 +471,7 @@ def get_kernel_source(self, kernel_only: bool = True) -> str: str The source code of the compiled kernel function. """ - if self.execution_backend in {"ctypes", "cython", "nvrtc", "tvm_ffi"}: + if self.execution_backend in {"ctypes", "cython", "nvrtc", "tvm_ffi", "cutedsl"}: return self.adapter.get_kernel_source(kernel_only=kernel_only) return self.artifact.kernel_source @@ -445,7 +479,7 @@ def get_host_source(self) -> str: """ Returns the source code of the host function. """ - if self.execution_backend in {"ctypes", "cython", "nvrtc", "tvm_ffi"}: + if self.execution_backend in {"ctypes", "cython", "nvrtc", "tvm_ffi", "cutedsl"}: return self.adapter.get_host_source() assert self.artifact.host_mod is not None, "host_mod is not available" return str(self.artifact.host_mod) diff --git a/tilelang/utils/target.py b/tilelang/utils/target.py index 4ead7efd0..a2b88f5e8 100644 --- a/tilelang/utils/target.py +++ b/tilelang/utils/target.py @@ -15,6 +15,7 @@ "llvm": "LLVM CPU target (accepts standard TVM LLVM options).", "webgpu": "WebGPU target for browser/WebGPU runtimes.", "c": "C source backend.", + "cutedsl": "CuTe DSL GPU target.", } @@ -95,6 +96,14 @@ def determine_target(target: str | Target | Literal["auto"] = "auto", return_obj return_var = "metal" else: raise ValueError("No CUDA or HIP or MPS available on this system.") + elif isinstance(target, str) and target.startswith("cutedsl"): + cuda_target_str = target.replace("cutedsl", "cuda", 1) + temp_target = Target(cuda_target_str) + + target_dict = dict(temp_target.export()) + target_dict["keys"] = list(target_dict["keys"]) + ["cutedsl"] + + return_var = Target(target_dict) else: # Validate the target if it's not "auto" if isinstance(target, Target): @@ -115,6 +124,8 @@ def determine_target(target: str | Target | Literal["auto"] = "auto", return_obj else: raise AssertionError(f"Target {target} is not supported") + if isinstance(return_var, Target): + return return_var if return_object: if isinstance(return_var, Target): return return_var