From 2f9d02ea523817ace03a74cbc36e82fc7e627981 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Thu, 27 Nov 2025 19:18:54 +0800 Subject: [PATCH 01/10] [Refactor] Improve assertion handling in CodeGenCHost and ArgBinder This commit refines the assertion message generation in CodeGenCHost by optimizing the handling of equality checks and reducing buffer size for error messages. Additionally, it enhances the ArgBinder by introducing a nullable guard mechanism for assertions, allowing for more precise error handling when binding arguments. The changes improve the clarity and efficiency of assertion handling across the codebase. --- src/target/codegen_c_host.cc | 81 ++----------- src/transform/arg_binder.cc | 194 +++++++++++++------------------ src/transform/make_packed_api.cc | 3 +- src/transform/merge_if_stmt.cc | 45 +++++-- src/transform/merge_if_stmt.h | 52 +++++++++ 5 files changed, 181 insertions(+), 194 deletions(-) create mode 100644 src/transform/merge_if_stmt.h diff --git a/src/target/codegen_c_host.cc b/src/target/codegen_c_host.cc index b5e74b0a3..fedf8a1d6 100644 --- a/src/target/codegen_c_host.cc +++ b/src/target/codegen_c_host.cc @@ -348,7 +348,6 @@ void CodeGenCHost::VisitExpr_(const tvm::tir::CallNode *op, } void CodeGenCHost::VisitStmt_(const tvm::tir::AssertStmtNode *op) { // NOLINT(*) - using namespace tvm::tir; if (emit_asserts_) { std::string cond = PrintExpr(op->condition); PrintIndent(); @@ -356,88 +355,28 @@ void CodeGenCHost::VisitStmt_(const tvm::tir::AssertStmtNode *op) { // NOLINT(*) int assert_if_scope = this->BeginScope(); { // Prepare the base error message - const auto *msg_node = op->message.as(); + const auto *msg_node = op->message.as(); ICHECK(msg_node != nullptr) << "Assert message expected to be StringImm"; const std::string &raw_msg = msg_node->value; const std::string esc_msg = tvm::support::StrEscape( raw_msg.c_str(), raw_msg.length(), /*use_octal_escape=*/true, /*escape_whitespace_special_chars=*/true); - // If the assertion condition contains any equality checks anywhere - // in a composite boolean expression, append the actual LHS/RHS values - // Collect all EQ nodes within the condition (including inside And/Or/Not) - std::vector eq_nodes; - { - std::vector stk; - stk.push_back(op->condition); - while (!stk.empty()) { - PrimExpr cur = stk.back(); - stk.pop_back(); - if (const auto *eq = cur.as()) { - eq_nodes.push_back(eq); - continue; - } - if (const auto *an = cur.as()) { - stk.push_back(an->a); - stk.push_back(an->b); - continue; - } - if (const auto *on = cur.as()) { - stk.push_back(on->a); - stk.push_back(on->b); - continue; - } - if (const auto *nn = cur.as()) { - stk.push_back(nn->a); - continue; - } - } - } - - if (!eq_nodes.empty()) { - // Build a single detailed message that includes all LHS/RHS pairs + // If the assertion is an equality check, append the actual LHS/RHS values + if (const auto *eq = op->condition.as()) { + std::string lhs = PrintExpr(eq->a); + std::string rhs = PrintExpr(eq->b); PrintIndent(); - stream << "char __tvm_assert_msg_buf[1024];\n"; + stream << "char __tvm_assert_msg_buf[512];\n"; PrintIndent(); - stream << "int __tvm_assert_msg_len = snprintf(__tvm_assert_msg_buf, " - "sizeof(__tvm_assert_msg_buf), \"%s\", \"" - << esc_msg << "\");\n"; - - auto escape_for_printf_literal = [&](const std::string &s) { - std::string out; - out.reserve(s.size()); - for (char c : s) { - if (c == '%') { - out += "%%"; - } else if (c == '"') { - out += "\\\""; - } else if (c == '\\') { - out += "\\\\"; - } else { - out.push_back(c); - } - } - return out; - }; - - for (const auto *eq : eq_nodes) { - std::string lhs = PrintExpr(eq->a); - std::string rhs = PrintExpr(eq->b); - std::string lhs_disp = escape_for_printf_literal(lhs); - std::string rhs_disp = escape_for_printf_literal(rhs); - PrintIndent(); - stream << "__tvm_assert_msg_len += snprintf(__tvm_assert_msg_buf + " - "__tvm_assert_msg_len, " - "sizeof(__tvm_assert_msg_buf) - __tvm_assert_msg_len, \"; (" - << lhs_disp << " == " << rhs_disp - << ") got: %lld, expected: %lld\", (long long)(" << lhs - << "), (long long)(" << rhs << "));\n"; - } + stream << "snprintf(__tvm_assert_msg_buf, 512, \"%s; expected: %lld, " + "got: %lld\", \"" + << esc_msg << "\", (long long)(" << lhs << "), (long long)(" + << rhs << "));\n"; PrintIndent(); stream << "TVMFFIErrorSetRaisedFromCStr(\"RuntimeError\", " "__tvm_assert_msg_buf);\n"; } else { - // Fallback: just emit the base message PrintIndent(); stream << "TVMFFIErrorSetRaisedFromCStr(\"RuntimeError\", \"" << esc_msg << "\");\n"; diff --git a/src/transform/arg_binder.cc b/src/transform/arg_binder.cc index 361cfe909..20b8047e1 100644 --- a/src/transform/arg_binder.cc +++ b/src/transform/arg_binder.cc @@ -1,22 +1,3 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - /*! * \file arg_binder.cc * \brief Helper utility to match and bind arguments. @@ -24,6 +5,7 @@ #include "arg_binder.h" #include +#include #include #include #include @@ -44,16 +26,32 @@ namespace tl { using namespace tir; void BinderAddAssert(arith::Analyzer *ana, PrimExpr cond, - const std::string &arg_name, std::vector *asserts) { + const std::string &arg_name, std::vector *asserts, + PrimExpr nullable_guard = PrimExpr()) { PrimExpr scond = ana->Simplify(cond); if (is_zero(scond)) { LOG(FATAL) << "Bind have an unmet assertion: " << cond << ", " << " on argument " << arg_name; } + if (!is_one(scond)) { std::ostringstream os; os << "Argument " << arg_name << " has an unsatisfied constraint: " << cond; - asserts->emplace_back(AssertStmt(scond, StringImm(os.str()), Evaluate(0))); + + // Check if the condition is of the form "is_null || actual_cond" + // If so, generate "if !is_null: assert actual_cond" instead of "assert + // is_null || actual_cond" + if (nullable_guard.defined()) { + // Pattern: nullable_guard || actual_condition + // We want to transform this into: if !nullable_guard: assert + // actual_condition + Stmt check = AssertStmt(scond, StringImm(os.str()), Evaluate(0)); + check = IfThenElse(Not(nullable_guard), check); + asserts->emplace_back(SeqStmt({check, Evaluate(0)})); + } else { + asserts->emplace_back( + AssertStmt(scond, StringImm(os.str()), Evaluate(0))); + } } } @@ -106,8 +104,8 @@ bool ArgBinder::BindNullable(const PrimExpr &arg, const PrimExpr &value, return true; } else { // Second or later binding: add is_null short-circuit - PrimExpr cond = MakeGuarded(it->second == value); - BinderAddAssert(&analyzer_, cond, arg_name, &asserts_); + PrimExpr cond = value == it->second; + BinderAddAssert(&analyzer_, cond, arg_name, &asserts_, nullable_guard); } } else { // 2. complex binding expr = value @@ -129,7 +127,7 @@ bool ArgBinder::BindNullable(const PrimExpr &arg, const PrimExpr &value, auto value_opt = sol->src_to_dst.Get(v); ICHECK(value_opt->defined()) << "Unable to solve variable `" << v << "` from expression `" - << (arg == value) << "`"; + << (value == arg) << "`"; auto value = ffi::GetRef(sol->src_to_dst.Get(v)->get()); BindVar(v.as(), value); } @@ -138,9 +136,10 @@ bool ArgBinder::BindNullable(const PrimExpr &arg, const PrimExpr &value, // because the solved expression may contain floordiv (e.g. 3 * m == n // ==> m = n // 3) we re-compute the constraint to verify the solution // is correct - PrimExpr cond = MakeGuarded(arg == value); - BinderAddAssert(&analyzer_, cond, arg_name, &asserts_); + PrimExpr cond = value == arg; + BinderAddAssert(&analyzer_, cond, arg_name, &asserts_, nullable_guard); } + // ICHECK(false); return false; } @@ -160,10 +159,10 @@ bool ArgBinder::Bind_(const PrimExpr &arg, const PrimExpr &value, } return true; } else { - BinderAddAssert(&analyzer_, it->second == value, arg_name, &asserts_); + BinderAddAssert(&analyzer_, value == it->second, arg_name, &asserts_); } } else { - BinderAddAssert(&analyzer_, arg == value, arg_name, &asserts_); + BinderAddAssert(&analyzer_, value == arg, arg_name, &asserts_); } return false; } @@ -236,7 +235,7 @@ void ArgBinder::BindBuffer(const Buffer &arg, const Buffer &value, PrimExpr offset = value->elem_offset; PrimExpr factor = make_const(offset.dtype(), arg->offset_factor); PrimExpr zero = make_zero(offset.dtype()); - BinderAddAssert(&analyzer_, truncmod(offset, factor) == zero, + BinderAddAssert(&analyzer_, zero == truncmod(offset, factor), arg_name + ".elem_offset", &asserts_); } } @@ -318,9 +317,10 @@ void ArgBinder::BindDLTensor(const Buffer &buffer, const PrimExpr &device_type, ndim_err_msg << arg_name << ".ndim is expected to equal " << buffer->shape.size() << ", but got mismatched ndim"; auto msg = StringImm(ndim_err_msg.str()); - // Only check ndim when handle is non-NULL (using short-circuit OR) - v_ndim = tvm::if_then_else(Not(is_null), v_ndim, make_zero(tvm_ndim_type)); - init_nest_.emplace_back(AssertStmt(Or(is_null, a_ndim == v_ndim), msg, nop)); + // Only check ndim when handle is non-NULL (using if statement) + Stmt ndim_check = AssertStmt(a_ndim == v_ndim, msg, nop); + ndim_check = IfThenElse(Not(is_null), ndim_check); + init_nest_.emplace_back(SeqStmt({ndim_check, nop})); // type checks std::ostringstream type_err_msg; // Avoid dumping TIR expressions in error text; just state mismatch. @@ -396,8 +396,10 @@ void ArgBinder::BindDLTensor(const Buffer &buffer, const PrimExpr &device_type, buffer->dtype == DataType::Int(4) || buffer->dtype == DataType::UInt(4))) { auto type_msg = StringImm(type_err_msg.str()); - // Only check dtype when handle is non-NULL (short-circuit) - asserts_.emplace_back(AssertStmt(Or(is_null, cond), type_msg, nop)); + // Only check dtype when handle is non-NULL (using if statement) + Stmt dtype_check = AssertStmt(cond, type_msg, nop); + dtype_check = IfThenElse(Not(is_null), dtype_check); + asserts_.emplace_back(SeqStmt({dtype_check, nop})); } // shape field @@ -427,31 +429,16 @@ void ArgBinder::BindDLTensor(const Buffer &buffer, const PrimExpr &device_type, } // The "real" runtime shape value read from DLTensor - PrimExpr raw_shape_val = + PrimExpr shape_val = cast(buffer->shape[k].dtype(), BufferLoad(buf_shape, {IntImm(DataType::Int(32), static_cast(k))})); - // Bind to the value of the symbolic dimension (e.g., m) in TIR, with an - // is_null guard: - // handle is NULL → use 0, placeholder but no dereference - // handle non-NULL → actually read from DLTensor's shape array - PrimExpr bound_shape_val = tvm::if_then_else( - is_null, make_zero(buffer->shape[k].dtype()), raw_shape_val); - // When first encountering a Var (e.g., m), this will generate: // Let(m, bound_shape_val, ...) // Constant dimensions will only generate consistency assertions. - BindNullable(buffer->shape[k], bound_shape_val, shape_element_name(k), true, + BindNullable(buffer->shape[k], shape_val, shape_element_name(k), true, is_null); - - // Keep an explicit "consistency check": when non-NULL, the symbolic - // dimension must equal the DLTensor's shape. - Stmt shape_check = AssertStmt( - Or(is_null, buffer->shape[k] == raw_shape_val), - StringImm(shape_element_name(k) + " mismatch with DLTensor shape"), - Evaluate(0)); - asserts_.emplace_back(shape_check); } // strides field @@ -499,7 +486,7 @@ void ArgBinder::BindDLTensor(const Buffer &buffer, const PrimExpr &device_type, asserts_.emplace_back(SeqStmt({check, Evaluate(0)})); } } else if (buffer->buffer_type == kAutoBroadcast) { - PrimExpr stride_from_shape = make_const(buffer->DefaultIndexType(), 1); + PrimExpr stride_from_shape = 1; for (size_t i = buffer->shape.size(); i != 0; --i) { size_t k = i - 1; DataType stride_dtype = buffer->strides[k].dtype(); @@ -507,31 +494,15 @@ void ArgBinder::BindDLTensor(const Buffer &buffer, const PrimExpr &device_type, cast(stride_dtype, BufferLoad(buf_strides, {IntImm(DataType::Int(32), static_cast(k))})); - PrimExpr stride_from_shape_cast = cast(stride_dtype, stride_from_shape); - - PrimExpr core_value = tvm::if_then_else( - v_strides_is_null, stride_from_shape_cast, explicit_stride); - core_value = tvm::if_then_else(buffer->shape[k] == 1, - make_zero(stride_dtype), core_value); - - // Bind like shape: define var when needed, and only assert when non-NULL - PrimExpr bound_stride_val = - tvm::if_then_else(is_null, make_zero(stride_dtype), core_value); - BindNullable(buffer->strides[k], bound_stride_val, stride_element_name(k), - true, is_null); - Stmt stride_check = AssertStmt( - Or(is_null, buffer->strides[k] == core_value), - StringImm(stride_element_name(k) + " mismatch with DLTensor strides"), - Evaluate(0)); - asserts_.emplace_back(stride_check); + PrimExpr stride_val = tvm::if_then_else( + v_strides_is_null, stride_from_shape, explicit_stride); - PrimExpr shape_extent = cast(stride_dtype, buffer->shape[k]); - stride_from_shape = - analyzer_.Simplify(stride_from_shape_cast * shape_extent); + BindNullable(buffer->strides[k], stride_val, stride_element_name(k), true, + is_null); } } else { - PrimExpr stride_from_shape = make_const(buffer->DefaultIndexType(), 1); + PrimExpr stride_from_shape = 1; for (int k = static_cast(buffer->strides.size()) - 1; k >= 0; --k) { DataType stride_dtype = buffer->strides[k].dtype(); @@ -540,24 +511,12 @@ void ArgBinder::BindDLTensor(const Buffer &buffer, const PrimExpr &device_type, BufferLoad(buf_strides, {IntImm(DataType::Int(32), k)})); PrimExpr shape_stride = cast( stride_dtype, BufferLoad(buf_shape, {IntImm(DataType::Int(32), k)})); - PrimExpr stride_from_shape_cast = cast(stride_dtype, stride_from_shape); - PrimExpr core_value = tvm::if_then_else( - v_strides_is_null, stride_from_shape_cast, explicit_stride); + PrimExpr stride_val = tvm::if_then_else( + v_strides_is_null, stride_from_shape, explicit_stride); - PrimExpr bound_stride_val = - tvm::if_then_else(is_null, make_zero(stride_dtype), core_value); - BindNullable(buffer->strides[k], bound_stride_val, stride_element_name(k), - true, is_null); - - Stmt stride_check = AssertStmt( - Or(is_null, buffer->strides[k] == core_value), - StringImm(stride_element_name(k) + " mismatch with DLTensor strides"), - Evaluate(0)); - asserts_.emplace_back(stride_check); - - stride_from_shape = - analyzer_.Simplify(stride_from_shape_cast * shape_stride); + BindNullable(buffer->strides[k], stride_val, stride_element_name(k), true, + is_null); } } @@ -574,9 +533,10 @@ void ArgBinder::BindDLTensor(const Buffer &buffer, const PrimExpr &device_type, PrimExpr expect_byte_offset = make_const(DataType::UInt(64), const_offset->value * data_bytes); Stmt byte_off_check = - AssertStmt(Or(is_null, expect_byte_offset == actual_byte_offset), + AssertStmt(expect_byte_offset == actual_byte_offset, StringImm(arg_name + ".byte_offset mismatch"), nop); - asserts_.emplace_back(byte_off_check); + byte_off_check = IfThenElse(Not(is_null), byte_off_check); + asserts_.emplace_back(SeqStmt({byte_off_check, nop})); } else { PrimExpr actual_byte_offset = tvm::if_then_else( Not(is_null), @@ -586,28 +546,15 @@ void ArgBinder::BindDLTensor(const Buffer &buffer, const PrimExpr &device_type, cast(buffer->elem_offset.dtype(), (actual_byte_offset / make_const(DataType::UInt(64), data_bytes))); - // Like shape/stride, do NULL-safe binding for elem_offset: - // handle is NULL → 0 - // handle non-NULL → actual_byte_offset / data_bytes - PrimExpr bound_elem_off = tvm::if_then_else( - is_null, make_zero(buffer->elem_offset.dtype()), expect_elem_off); - BindNullable(buffer->elem_offset, bound_elem_off, arg_name + ".elem_offset", - true, is_null); - - // Strict consistency check for non-NULL case - Stmt elem_off_check = - AssertStmt(Or(is_null, buffer->elem_offset == expect_elem_off), - StringImm(arg_name + ".elem_offset mismatch"), nop); - asserts_.emplace_back(elem_off_check); + BindNullable(buffer->elem_offset, expect_elem_off, + arg_name + ".elem_offset", true, is_null); if (buffer->offset_factor > 1) { PrimExpr offset = buffer->elem_offset; PrimExpr factor = make_const(offset.dtype(), buffer->offset_factor); PrimExpr zero = make_zero(offset.dtype()); - Stmt off_factor_check = - AssertStmt(Or(is_null, truncmod(offset, factor) == zero), - StringImm(arg_name + ".elem_offset factor mismatch"), nop); - asserts_.emplace_back(off_factor_check); + BindNullable(offset, truncmod(offset, factor), arg_name + ".elem_offset", + true, is_null); } } @@ -621,14 +568,29 @@ void ArgBinder::BindDLTensor(const Buffer &buffer, const PrimExpr &device_type, Not(is_null), TVMArrayGet(DataType::Int(32), handle, builtin::kArrDeviceId), make_zero(DataType::Int(32))); + // Bind device_id to a safe expression (0 when NULL handle) BindNullable(device_id, actual_dev_id, arg_name + ".device_id", true, is_null); // Check device_type consistency (device_id equality is implicitly ensured by // binding above) - init_nest_.emplace_back( - AssertStmt(Or(is_null, device_type == actual_dev_type), - StringImm(arg_name + ".device_type mismatch"), nop)); + { + std::ostringstream dev_msg; + dev_msg << arg_name << ".device_type mismatch"; + if (const auto *imm = device_type.as()) { + dev_msg << " [expected: " << imm->value << " (" + << tvm::runtime::DLDeviceType2Str(static_cast(imm->value)) + << ")]"; + } + // Give a short legend so users can interpret numeric codes in the + // appended "got/expected" part printed by the runtime. + dev_msg << "; DLPack codes: 1=CPU, 2=CUDA, 7=Vulkan, 8=Metal, 10=ROCM, " + "14=OneAPI, 15=WebGPU"; + auto device_type_check = + IfThenElse(Not(is_null), AssertStmt(device_type == actual_dev_type, + StringImm(dev_msg.str()), nop)); + asserts_.emplace_back(SeqStmt({device_type_check, Evaluate(0)})); + } // Data field. Because the validation of the data field may depend // on a dynamic size defined by the other DLTensor* parameters, this @@ -650,12 +612,14 @@ void ArgBinder::BindDLTensor(const Buffer &buffer, const PrimExpr &device_type, product *= dim; return product; }(); - asserts_.emplace_back(AssertStmt( - Or(is_null, (alloc_size == 0) || - !Call(DataType::Bool(), builtin::isnullptr(), {vptr})), + Stmt data_null_check = AssertStmt( + (alloc_size == 0) || + !Call(DataType::Bool(), builtin::isnullptr(), {vptr}), StringImm(arg_name + " is expected to have non-NULL data pointer, but got NULL"), - nop)); + nop); + data_null_check = IfThenElse(Not(is_null), data_null_check); + asserts_.emplace_back(SeqStmt({data_null_check, nop})); // mark alignment of external bufs init_nest_.emplace_back( diff --git a/src/transform/make_packed_api.cc b/src/transform/make_packed_api.cc index 187a75dc3..be57f071d 100644 --- a/src/transform/make_packed_api.cc +++ b/src/transform/make_packed_api.cc @@ -39,6 +39,7 @@ #include "../op/builtin.h" #include "arg_binder.h" +#include "merge_if_stmt.h" #include "tir/transforms/ir_utils.h" namespace tvm { @@ -436,7 +437,6 @@ PrimFunc MakePackedAPI(PrimFunc func) { func_ptr->buffer_map = ffi::Map(); func_ptr->ret_type = PrimType(DataType::Int(32)); - // return the function. return func; } @@ -467,6 +467,7 @@ tvm::transform::Pass MakePackedAPI() { func.CopyOnWrite()->body = body.value(); } func = MakePackedAPI(std::move(func)); + func = MergeIfStmtSubstitute(func); if (!func.same_as(orig_func)) { updates->Add(gvar, func); diff --git a/src/transform/merge_if_stmt.cc b/src/transform/merge_if_stmt.cc index 39ea3b0b7..98d9d3ac2 100644 --- a/src/transform/merge_if_stmt.cc +++ b/src/transform/merge_if_stmt.cc @@ -3,6 +3,8 @@ * \brief Merge the If Stmt in SeqStmt */ +#include "merge_if_stmt.h" + #include #include #include @@ -20,23 +22,46 @@ using namespace tir; class MergeIfStmtRewriter : public StmtExprMutator { public: static PrimFunc Substitute(PrimFunc &f) { - auto rewriter = MergeIfStmtRewriter(); - f.CopyOnWrite()->body = rewriter(f->body); + f.CopyOnWrite()->body = MergeIfStmtRewriter::Apply(f->body); return f; } + static Stmt Apply(Stmt stmt) { + auto rewriter = MergeIfStmtRewriter(); + return rewriter(stmt); + } + private: MergeIfStmtRewriter() = default; + void FlattenAppend(const Stmt &s, Array *out) { + if (const auto *seq = s.as()) { + for (const Stmt &e : seq->seq) { + FlattenAppend(e, out); + } + } else { + out->push_back(s); + } + } + Stmt VisitStmt_(const SeqStmtNode *op) final { - Array new_seq; + // First, recursively flatten nested SeqStmt so that + // SeqStmt{ if, SeqStmt{ if, SeqStmt{ if } } } + // becomes a single-level sequence of [if, if, if]. + Array flat_seq; + for (const Stmt &stmt : op->seq) { + Stmt new_stmt = this->VisitStmt(stmt); + FlattenAppend(new_stmt, &flat_seq); + } + // Then, merge consecutive IfThenElse (without else) that share the same + // condition. + Array new_seq; PrimExpr current_condition; Array current_if_bodies; - for (const Stmt &stmt : op->seq) { - Stmt new_stmt = this->VisitStmt(stmt); - if (const IfThenElseNode *if_node = new_stmt.as()) { + for (const Stmt &stmt : flat_seq) { + if (const auto *if_node = stmt.as()) { if (!if_node->else_case.defined()) { if (current_condition.defined() && ExprDeepEqual()(current_condition, if_node->condition)) { @@ -73,7 +98,7 @@ class MergeIfStmtRewriter : public StmtExprMutator { current_if_bodies.clear(); } - new_seq.push_back(new_stmt); + new_seq.push_back(stmt); } if (!current_if_bodies.empty()) { @@ -90,6 +115,12 @@ class MergeIfStmtRewriter : public StmtExprMutator { } }; +PrimFunc MergeIfStmtSubstitute(PrimFunc &f) { + return MergeIfStmtRewriter::Substitute(f); +} + +Stmt ApplyMergeIfStmt(Stmt stmt) { return MergeIfStmtRewriter::Apply(stmt); } + using namespace tir::transform; tvm::transform::Pass MergeIfStmt() { auto pass_func = [=](PrimFunc f, const IRModule &m, const PassContext &ctx) { diff --git a/src/transform/merge_if_stmt.h b/src/transform/merge_if_stmt.h new file mode 100644 index 000000000..5d7a282d1 --- /dev/null +++ b/src/transform/merge_if_stmt.h @@ -0,0 +1,52 @@ +/*! + * \file merge_if_stmt.h + * \brief Merge consecutive If statements with the same condition + */ +#ifndef TVM_TL_TRANSFORM_MERGE_IF_STMT_H_ +#define TVM_TL_TRANSFORM_MERGE_IF_STMT_H_ + +#include +#include + +namespace tvm { +namespace tl { + +using namespace tir; + +// Forward declaration +class MergeIfStmtRewriter; + +/*! + * \brief Apply MergeIfStmt transformation to a PrimFunc + * + * This function merges consecutive IfThenElse statements that have the same + * condition into a single if statement with a SeqStmt body. + * + * Example: + * if (cond) { stmt1 } + * if (cond) { stmt2 } + * if (cond) { stmt3 } + * + * Becomes: + * if (cond) { + * stmt1 + * stmt2 + * stmt3 + * } + * + * \param f The PrimFunc to transform + * \return Transformed PrimFunc with merged if statements + */ +PrimFunc MergeIfStmtSubstitute(PrimFunc &f); + +/*! + * \brief Apply MergeIfStmt transformation to a statement + * \param stmt The statement to transform + * \return Transformed statement with merged if statements + */ +Stmt ApplyMergeIfStmt(Stmt stmt); + +} // namespace tl +} // namespace tvm + +#endif // TVM_TL_TRANSFORM_MERGE_IF_STMT_H_ From ad67d196e3f2792eec34a447500278ceec43112d Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Thu, 27 Nov 2025 20:18:21 +0800 Subject: [PATCH 02/10] [Enhancement] Update matmul kernel and optimize argument binding This commit enhances the matmul kernel by introducing additional tensor parameters and refining the pipeline stages for improved performance. It also updates the argument binding mechanism to include a flag indicating whether buffers are used, enhancing the efficiency of buffer management. Furthermore, the optimization phase in the engine is improved by adding a simplification step, ensuring better performance and clarity in the generated code. --- src/transform/arg_binder.cc | 5 ++- src/transform/arg_binder.h | 2 +- src/transform/make_packed_api.cc | 76 +++++++++++++++++++++++++++++++- 3 files changed, 79 insertions(+), 4 deletions(-) diff --git a/src/transform/arg_binder.cc b/src/transform/arg_binder.cc index 20b8047e1..25e94344e 100644 --- a/src/transform/arg_binder.cc +++ b/src/transform/arg_binder.cc @@ -276,7 +276,7 @@ inline PrimExpr TVMArrayGet(DataType t, Var arr, void ArgBinder::BindDLTensor(const Buffer &buffer, const PrimExpr &device_type, const PrimExpr &device_id, const Var &handle, - const std::string &arg_name) { + const std::string &arg_name, bool is_used) { const DataType tvm_shape_type = DataType::ShapeIndex(); const DataType tvm_ndim_type = DataType::Int(32); const Stmt nop = Evaluate(0); @@ -285,11 +285,12 @@ void ArgBinder::BindDLTensor(const Buffer &buffer, const PrimExpr &device_type, // avoid dereferencing it by using expression-level conditionals and // short-circuiting guards in asserts. Cache the null check in a Let-bound // boolean so codegen does not repeat `(handle == NULL)` everywhere. + Var is_null_var(arg_name + "_is_null", DataType::Bool()); init_nest_.emplace_back( LetStmt(is_null_var, Call(DataType::Bool(), builtin::isnullptr(), {handle}), nop)); - const PrimExpr &is_null = is_null_var; + const PrimExpr &is_null = is_used ? const_false():is_null_var; // dimension checks PrimExpr v_ndim = TVMArrayGet(tvm_ndim_type, handle, builtin::kArrNDim); diff --git a/src/transform/arg_binder.h b/src/transform/arg_binder.h index 793ada111..6a580636f 100644 --- a/src/transform/arg_binder.h +++ b/src/transform/arg_binder.h @@ -105,7 +105,7 @@ class ArgBinder { */ void BindDLTensor(const Buffer &buffer, const PrimExpr &device_type, const PrimExpr &device_id, const Var &handle, - const std::string &arg_name); + const std::string &arg_name, bool is_used); /*! \return The defs generated in binding. */ const std::vector &defs() const { return defs_; } diff --git a/src/transform/make_packed_api.cc b/src/transform/make_packed_api.cc index be57f071d..e6fbe93c6 100644 --- a/src/transform/make_packed_api.cc +++ b/src/transform/make_packed_api.cc @@ -298,6 +298,80 @@ PrimFunc MakePackedAPI(PrimFunc func) { std::vector> var_def; std::vector> buffer_def; + // First, collect a reverse map from Buffer->data var to parameter var so we + // can detect whether a buffer is actually used by the function body. In + // addition, collect variables that appear in the buffer's shape/stride so we + // can consider uses of those symbols as a use of the buffer itself. + std::unordered_map data_var2param; + std::unordered_map> + shape_var2params; + for (const auto &kv : func_ptr->buffer_map) { + const Var ¶m = kv.first; + const Buffer &buf = kv.second; + data_var2param[buf->data.get()] = param.get(); + auto record_shape_vars = [&](const PrimExpr &e) { + PostOrderVisit(e, [&](const ObjectRef &n) { + if (const auto *v = n.as()) { + shape_var2params[v].push_back(param.get()); + } + }); + }; + for (const PrimExpr &e : buf->shape) + record_shape_vars(e); + for (const PrimExpr &e : buf->strides) + record_shape_vars(e); + if (buf->elem_offset.defined()) + record_shape_vars(buf->elem_offset); + } + + // A visitor that marks a buffer as used when its underlying data var is + // referenced (e.g. BufferLoad/BufferStore or any direct var usage). + struct UsedBufferDetector : public StmtExprVisitor { + UsedBufferDetector( + const std::unordered_map &data2param, + const std::unordered_map> + &shape2params) + : data2param(data2param), shape2params(shape2params) {} + void VisitExpr_(const VarNode *op) override { + auto it = data2param.find(op); + if (it != data2param.end()) { + used_params.insert(it->second); + } + auto it2 = shape2params.find(op); + if (it2 != shape2params.end()) { + for (const VarNode *p : it2->second) used_params.insert(p); + } + StmtExprVisitor::VisitExpr_(op); + } + void VisitStmt_(const BufferStoreNode *op) override { + auto it = data2param.find(op->buffer->data.get()); + if (it != data2param.end()) { + used_params.insert(it->second); + } + StmtExprVisitor::VisitStmt_(op); + } + void VisitExpr_(const BufferLoadNode *op) override { + auto it = data2param.find(op->buffer->data.get()); + if (it != data2param.end()) { + used_params.insert(it->second); + } + StmtExprVisitor::VisitExpr_(op); + } + + const std::unordered_map &data2param; + const std::unordered_map> + &shape2params; + std::unordered_set used_params; + }; + + UsedBufferDetector detector(data_var2param, shape_var2params); + detector(func_ptr->body); + + // Build the packed argument handling. While doing so, keep track of whether + // each parameter buffer is actually used. Unused input buffers can be + // nullable and do not require DLTensor field dereferences. + std::unordered_set used_param_buffers = detector.used_params; + for (int i = 0; i < static_cast(func_ptr->params.size()); ++i) { Var param = func_ptr->params[i]; PrimExpr arg_value; @@ -390,7 +464,7 @@ PrimFunc MakePackedAPI(PrimFunc func) { for (const auto &[var, buffer] : buffer_def) { binder.BindDLTensor(buffer, device_type, device_id, var, - name_hint + "." + var->name_hint); + name_hint + "." + var->name_hint, used_param_buffers.count(var.get())); arg_buffer_declarations.push_back(DeclBuffer(buffer, nop)); } // reset global symbol to attach prefix From a00e6d40765f4fba779a091d7cc0d9965a4965ef Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Thu, 27 Nov 2025 20:19:22 +0800 Subject: [PATCH 03/10] lint fix --- src/transform/arg_binder.cc | 4 ++-- src/transform/make_packed_api.cc | 6 ++++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/transform/arg_binder.cc b/src/transform/arg_binder.cc index 25e94344e..4ffcdf6be 100644 --- a/src/transform/arg_binder.cc +++ b/src/transform/arg_binder.cc @@ -285,12 +285,12 @@ void ArgBinder::BindDLTensor(const Buffer &buffer, const PrimExpr &device_type, // avoid dereferencing it by using expression-level conditionals and // short-circuiting guards in asserts. Cache the null check in a Let-bound // boolean so codegen does not repeat `(handle == NULL)` everywhere. - + Var is_null_var(arg_name + "_is_null", DataType::Bool()); init_nest_.emplace_back( LetStmt(is_null_var, Call(DataType::Bool(), builtin::isnullptr(), {handle}), nop)); - const PrimExpr &is_null = is_used ? const_false():is_null_var; + const PrimExpr &is_null = is_used ? const_false() : is_null_var; // dimension checks PrimExpr v_ndim = TVMArrayGet(tvm_ndim_type, handle, builtin::kArrNDim); diff --git a/src/transform/make_packed_api.cc b/src/transform/make_packed_api.cc index e6fbe93c6..cbbc1a34b 100644 --- a/src/transform/make_packed_api.cc +++ b/src/transform/make_packed_api.cc @@ -339,7 +339,8 @@ PrimFunc MakePackedAPI(PrimFunc func) { } auto it2 = shape2params.find(op); if (it2 != shape2params.end()) { - for (const VarNode *p : it2->second) used_params.insert(p); + for (const VarNode *p : it2->second) + used_params.insert(p); } StmtExprVisitor::VisitExpr_(op); } @@ -464,7 +465,8 @@ PrimFunc MakePackedAPI(PrimFunc func) { for (const auto &[var, buffer] : buffer_def) { binder.BindDLTensor(buffer, device_type, device_id, var, - name_hint + "." + var->name_hint, used_param_buffers.count(var.get())); + name_hint + "." + var->name_hint, + used_param_buffers.count(var.get())); arg_buffer_declarations.push_back(DeclBuffer(buffer, nop)); } // reset global symbol to attach prefix From 3070a89c5e4218b2c36cd3984e5a0e9fae2d7150 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Thu, 27 Nov 2025 22:01:38 +0800 Subject: [PATCH 04/10] [Enhancement] Add tensor checks documentation and improve argument binding assertions This commit introduces a new documentation page for host-side tensor checks, detailing the automatic validations performed by TileLang on kernel arguments. It enhances the ArgBinder by adding assertions for non-null pointers when arguments are used, improving error handling. Additionally, the optimization phase in the engine is updated to include a simplification step, ensuring better performance and clarity in the generated code. --- docs/compiler_internals/tensor_checks.md | 387 +++++++++++++++++++++++ docs/index.md | 1 + src/transform/arg_binder.cc | 6 + tilelang/engine/phase.py | 1 + 4 files changed, 395 insertions(+) create mode 100644 docs/compiler_internals/tensor_checks.md diff --git a/docs/compiler_internals/tensor_checks.md b/docs/compiler_internals/tensor_checks.md new file mode 100644 index 000000000..b4d2a0b3c --- /dev/null +++ b/docs/compiler_internals/tensor_checks.md @@ -0,0 +1,387 @@ +# Tensor Checks (Host-Side Auto-Validation) + +This page explains the host-side checks that TileLang automatically inserts into the generated host stub for kernels. When you pass `torch.Tensor` or any DLPack-compatible object to a TileLang kernel, the host stub validates argument count, pointer kinds, dtype, shape, strides, device, and more — so you don’t need to handwrite Python checks. This keeps the ABI stable and significantly reduces Python overhead compared to doing equivalent checks in Python or via pybind. + +## Why Host-Side Checks +- ABI stability: the entry is based on TVM FFI + DLPack, consistently accepting tensors and scalars. +- Lower overhead: shifting checks from Python into C reduces interpreter/property-access costs; the call overhead is lower than pybind-based approaches. +- Focused error reporting: assertions are raised close to the call site with precise “which field failed” messages. + +## How To Inspect Host Source +You can inspect the auto-generated host source (with all checks and the final device-kernel call) for debugging: + +```python +print(matmul_relu_kernel.get_host_source()) +``` + +--- + +## What The Host Checks + +### 1) Argument count and pointer kind +- `num_args` must match the number of formal parameters; otherwise the kernel returns `-1` with an error message. +- Each argument’s FFI type must be a pointer kind (for DLTensor/handle) or a valid scalar type; otherwise you’ll see errors like `Expect arg[i] to be pointer` or a scalar type error. + +### 2) Tensor checks (per tensor, after nullability decision) +- Nullability + - If the tensor is “statically reachable/used” by the function body, the handle must be non-NULL; otherwise: `xxx is expected to have non-NULL pointer`. + - If an input tensor is not used by the function (statically unreachable), NULL is allowed; other field checks are executed only when `handle != NULL`. +- Rank (`ndim`) + - Runtime `ndim` must equal the compile-time rank. +- Data type (`dtype`) + - Match the triple `(code, bits, lanes)` with tolerance: + - `float8_e4m3`: accept `e4m3`, `e4m3fn`, `e4m3fnuz`. + - `float8_e5m2`: accept `e5m2`, `e5m2fnuz`. + - `bool`: accept `int8/uint8` with `bits=8` (same lanes), `kDLBool(code=6, bits=1 or 8)`, and any `bitwidth=1` (lanes must match). + - For packed-bit dtypes (e.g., `Int(1)`, `Int(4)`, `UInt(4)`), strict dtype checking is skipped. +- Shape + - Each runtime dimension is bound to the compile-time shape (constants or symbols) and checked for consistency. + - Linear equations among symbolic dims can be solved on the fly (when there’s only one unknown at a given check point), enabling cross-tensor constraints. +- Strides + - If `buffer_type = AutoBroadcast`: allow `strides == NULL` and derive strides from `shape`. If explicit `strides` is present, bind to compile-time constraints and check for equality. + - Otherwise: check per-dimension; if `strides == NULL`, derive from `shape` and compare (e.g., contiguous: `strides[-1] == 1`, `strides[-2] == shape[-1]`). +- `byte_offset` + - Must be 0 (non-zero raises an error) to keep addressing simple and aligned. +- Device info + - Assert `device_type == target backend` (CUDA/ROCM/Metal/OneAPI/WebGPU/CPU, etc.). Error messages include a DLPack code legend. + - When multiple tensors participate, assert that `device_id` matches across them. +- Data pointer + - Must be non-NULL when the tensor is required to be non-null by the nullability rule. + +### 3) Scalar checks +- `T.int*` family: require integer; error: `Expect arg[i] to be int`. +- `T.bool`: require boolean; error: `Expect arg[i] to be boolean`. + +--- + +## Shapes and Symbolic Equations: Linear Solving +When shapes are symbolic, the host binds and (when possible) solves linear relations at runtime (only one unknown per check point). Example: + +```python +@T.prim_func +def main( + A: T.Tensor((m,), dtype), + B: T.Tensor((m + n,), dtype), + C: T.Tensor((n * k,), dtype), +): + ... +``` + +This enables enforcing cross-tensor relationships like `len(B) == m + n` and `len(C) == n * k` at runtime. + +--- + +## Nullability Rules and Examples +Which tensors may be NULL? + +- Rule: If an input tensor is not used by the function under static analysis (i.e., the access is statically unreachable), it is considered Nullable; otherwise it must be non-NULL. +- Examples: + +1) Must be non-NULL (used) +```python +@T.prim_func +def main(A: T.Tensor((M, K), dtype)): + A[0] = 1 +``` +Passing `None` raises: `main.A_handle is expected to have non-NULL pointer`. + +2) Still must be non-NULL (constant-true branch) +```python +some_cond: bool = True +@T.prim_func +def main(A: T.Tensor((M, K), dtype)): + if some_cond: + A[0] = 1 +``` + +3) Nullable (constant-false branch, statically unreachable) +```python +some_cond: bool = False +@T.prim_func +def main(A: T.Tensor((M, K), dtype)): + if some_cond: + A[0] = 1 +``` + +4) Must be non-NULL (runtime condition) +```python +@T.prim_func +def main(A: T.Tensor((M, K), dtype), some_cond: T.bool): + if some_cond: + A[0] = 1 +``` +Since `some_cond` is only known at runtime, static analysis cannot prove `A` is unused; `A` is thus non-nullable. + +--- + +## Device Type Codes (DLPack) +Supported and referenced device codes in error messages: `1=CPU, 2=CUDA, 7=Vulkan, 8=Metal, 10=ROCM, 14=OneAPI, 15=WebGPU`. +Kernels assert that `device_type` matches the target backend, and require `device_id` consistency across tensors. + +--- + +## Common Error Examples (What you’ll see) +- Argument count mismatch (num_args) + - Trigger: missing/extra argument + - Error: `: num_args should be N; expected: , got: N` + +- Pointer-typed argument expected + - Trigger: scalar passed where a tensor is expected + - Error: `: Expect arg[i] to be pointer` + +- Rank (ndim) mismatch + - Trigger: runtime rank differs from compile-time rank + - Error: `..ndim is expected to equal R, but got mismatched ndim` + +- Dtype mismatch + - Trigger: dtype not equal to the compiled dtype and not within the tolerance set + - Error: `..dtype is expected to be , but got incompatible dtype` + +- Shape constraint violation + - Trigger: a dimension doesn’t match a constant/symbol binding + - Error: `Argument ..shape[i] has an unsatisfied constraint: ... == ` + +- Strides check failed (e.g., non-contiguous layout) + - Trigger: transposed/sliced tensors that violate expected strides + - Error: `Argument ..strides[j] has an unsatisfied constraint: ... == ` + +- Device type mismatch + - Trigger: calling a CUDA kernel with CPU tensors, etc. + - Error: `..device_type mismatch [expected: ()] ...` + +- Device id mismatch + - Trigger: mixing tensors from different GPUs + - Error: `Argument ..device_id has an unsatisfied constraint: ... == ...` + +- NULL data pointer + - Trigger: tensor required to be non-null has a NULL data pointer + - Error: `. is expected to have non-NULL data pointer, but got NULL` + +- Scalar type mismatch + - Trigger: passing float to `T.int32`, or non-boolean to `T.bool` + - Error: `: Expect arg[i] to be int/boolean` + +--- + +## Troubleshooting Tips +- Print the host source: `print(fn.get_host_source())` to see the exact assertion and expected vs. actual fields. +- Fix strides: call `.contiguous()` for non-contiguous tensors, or avoid generating transposed/sliced layouts that break assumptions. +- Align devices: ensure all participating tensors share the same `device_type` and `device_id`. +- Align dtype: use `.to()` or construct tensors with the correct dtype; pay attention to `float8` and `bool` tolerance. +- Dynamic shapes: ensure cross-tensor linear relations can be uniquely determined at the check point (only one unknown at a time). + +--- + +## FAQ +- Can I disable the checks? + - Not recommended and usually not supported. Checks are done on the host to preserve ABI stability and fail early close to the device call. +- Is the overhead noticeable? + - The checks are lightweight (branches and field reads). Compared to Python-side checks, it’s faster; the dominating cost remains the Python→C boundary. Overall it’s cheaper than equivalent checks in Python. + +--- + +## Reference Example (Matmul + ReLU) + +```python +@T.prim_func +def matmul_relu_kernel( + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), +): + # Initialize Kernel Context + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_K, block_N), dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + T.clear(C_local) + for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=0): + T.copy(A[by * block_M, ko * block_K], A_shared) + T.copy(B[ko * block_K, bx * block_N], B_shared) + T.gemm(A_shared, B_shared, C_local) + T.copy(C_local, C[by * block_M, bx * block_N]) + +# For debugging, print the host source +print(matmul_relu_kernel.get_host_source()) +``` + +The host will insert all checks described above for this example. + +--- + +## Quick Error Reference (Short List) +- Argument count + - Trigger: missing/extra args; Error: `num_args should be N; expected: , got: N`. +- Pointer kind + - Trigger: scalar passed to tensor arg; Error: `Expect arg[i] to be pointer`. +- Rank (ndim) + - Trigger: runtime rank != compile-time; Error: `ndim ... expected to equal R`. +- Dtype + - Trigger: mismatch and not tolerated; Error: `dtype ... expected to be `. +- Shape + - Trigger: constant/symbol binding violated; Error: `shape[i] ... == `. +- Strides + - Trigger: layout mismatch; Error: `strides[j] ... == `. +- Device type + - Trigger: wrong backend device; Error: `device_type mismatch [expected: ...]`. +- Device id + - Trigger: tensors on different GPUs; Error: `device_id ... == ...`. +- Data pointer + - Trigger: required non-NULL but NULL; Error: `non-NULL data pointer`. +- Scalar types + - Trigger: wrong scalar type; Error: `Expect arg[i] to be int/boolean`. + +--- + +## Host Error Troubleshooting (Minimal Repros) + +Below are minimal repro snippets for common host-side errors, assuming a CUDA-targeted kernel like `matmul_relu_kernel` with: + +```python +# Convention: +# A: float16 [M, K] +# B: float16 [K, N] +# C: float16 [M, N] +# Target: CUDA (device_type=2) +fn = matmul_relu_kernel # your compiled function +M = N = K = 1024 +``` + +Adjust dtype/device if your kernel differs. + +### 0. Tip: print the host source +```python +print(fn.get_host_source()) +``` + +### 1. num_args mismatch +```python +import torch + +A = torch.empty((M, K), device='cuda', dtype=torch.float16) +B = torch.empty((K, N), device='cuda', dtype=torch.float16) +# Missing C +fn(A, B) +``` +Expected: `: num_args should be 3; expected: , got: 3`. + +Fix: pass all arguments per the signature. + +### 2. Expect pointer (tensor) but got scalar +```python +import torch + +B = torch.empty((K, N), device='cuda', dtype=torch.float16) +C = torch.empty((M, N), device='cuda', dtype=torch.float16) +fn(1, B, C) +``` +Expected: `: Expect arg[0] to be pointer`. + +Fix: pass a DLPack-compatible tensor (e.g., torch.Tensor). + +### 3. ndim mismatch +```python +import torch + +A = torch.empty((M, K, 1), device='cuda', dtype=torch.float16) # rank=3 +B = torch.empty((K, N), device='cuda', dtype=torch.float16) +C = torch.empty((M, N), device='cuda', dtype=torch.float16) +fn(A, B, C) +``` +Expected: `.A_handle.ndim is expected to equal 2, but got mismatched ndim`. + +Fix: ensure runtime rank equals compiled rank. + +### 4. dtype mismatch +```python +import torch + +A = torch.empty((M, K), device='cuda', dtype=torch.float32) # should be float16 +B = torch.empty((K, N), device='cuda', dtype=torch.float16) +C = torch.empty((M, N), device='cuda', dtype=torch.float16) +fn(A, B, C) +``` +Expected: `.A_handle.dtype is expected to be float16, but got incompatible dtype`. + +Fix: `A = A.to(torch.float16)` or create with the correct dtype. + +### 5. Shape constant/symbol mismatch +```python +import torch + +A = torch.empty((M, K + 1), device='cuda', dtype=torch.float16) # K mismatched +B = torch.empty((K, N), device='cuda', dtype=torch.float16) +C = torch.empty((M, N), device='cuda', dtype=torch.float16) +fn(A, B, C) +``` +Expected: `Argument .A_handle.shape[i] has an unsatisfied constraint: ... == `. + +Fix: satisfy linear constraints and constants across tensors. + +### 6. Strides check failure (non-contiguous) +```python +import torch + +A = torch.empty((M, K), device='cuda', dtype=torch.float16) +A_nc = A.t() # transpose -> non-contiguous +B = torch.empty((K, N), device='cuda', dtype=torch.float16) +C = torch.empty((M, N), device='cuda', dtype=torch.float16) +fn(A_nc, B, C) +``` +Expected: `Argument .A_handle.strides[1] has an unsatisfied constraint: ... == 1`. + +Fix: pass `A_nc.contiguous()` or align the layout expectation in the kernel. + +### 7. device_type mismatch +```python +import torch + +A = torch.empty((M, K), device='cpu', dtype=torch.float16) +B = torch.empty((K, N), device='cpu', dtype=torch.float16) +C = torch.empty((M, N), device='cpu', dtype=torch.float16) +fn(A, B, C) # CUDA-targeted kernel +``` +Expected: `.A_handle.device_type mismatch [expected: 2 (cuda)] ...`. + +Fix: move tensors to the CUDA device. + +### 8. device_id mismatch (multi-GPU) +```python +import torch + +A = torch.empty((M, K), device='cuda:0', dtype=torch.float16) +B = torch.empty((K, N), device='cuda:1', dtype=torch.float16) +C = torch.empty((M, N), device='cuda:0', dtype=torch.float16) +fn(A, B, C) +``` +Expected: `Argument .B_handle.device_id has an unsatisfied constraint: ... == ...`. + +Fix: place all tensors on the same GPU (e.g., `cuda:0`). + +### 9. NULL data pointer (advanced) +This usually comes from hand-constructed DLTensor/NDArray, or external frameworks passing unallocated/freed storage. Regular `torch.Tensor` allocations rarely hit this. + +Expected: `. is expected to have non-NULL data pointer, but got NULL`. + +Fix: ensure valid underlying storage; in PyTorch scenarios, avoid constructing tensors from invalid external handles. + +### 10. Scalar type mismatch (int / bool) +```python +import tilelang.language as T + +@T.prim_func +def scalar_check(x: T.int32, flag: T.bool()): + T.evaluate(0) + +scalar_check(1.0, True) # x is float -> Expect arg[0] to be int +scalar_check(1, 2.5) # flag is float -> Expect arg[1] to be boolean +``` + +Fix: pass correct scalar types, e.g., `scalar_check(1, True)`. + +--- + +## Closing Notes +- Cross-check “shape / strides / device / dtype” against the kernel signature to localize issues efficiently. +- For complex symbolic relations, print the host source to confirm binding/solving order, then adjust runtime shapes/layouts accordingly. + diff --git a/docs/index.md b/docs/index.md index 5d9a158f8..9f7947766 100644 --- a/docs/index.md +++ b/docs/index.md @@ -42,6 +42,7 @@ deeplearning_operators/deepseek_mla compiler_internals/letstmt_inline compiler_internals/inject_fence_proxy +compiler_internals/tensor_checks ::: :::{toctree} diff --git a/src/transform/arg_binder.cc b/src/transform/arg_binder.cc index 4ffcdf6be..6857f502c 100644 --- a/src/transform/arg_binder.cc +++ b/src/transform/arg_binder.cc @@ -291,6 +291,12 @@ void ArgBinder::BindDLTensor(const Buffer &buffer, const PrimExpr &device_type, LetStmt(is_null_var, Call(DataType::Bool(), builtin::isnullptr(), {handle}), nop)); const PrimExpr &is_null = is_used ? const_false() : is_null_var; + if (is_used) { + init_nest_.emplace_back(AssertStmt( + !is_null_var, + tvm::tir::StringImm(arg_name + " is expected to have non-NULL pointer"), + nop)); + } // dimension checks PrimExpr v_ndim = TVMArrayGet(tvm_ndim_type, handle, builtin::kArrNDim); diff --git a/tilelang/engine/phase.py b/tilelang/engine/phase.py index 17d6e4aa5..dfa8050a3 100644 --- a/tilelang/engine/phase.py +++ b/tilelang/engine/phase.py @@ -225,6 +225,7 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule: if allow_tma_and_warp_specialized(pass_ctx=pass_ctx, target=target): mod = tilelang.transform.AnnotateWarpGroupRegAlloc()(mod) mod = tilelang.transform.MakePackedAPI()(mod) + mod = tilelang.transform.Simplify()(mod) mod = tilelang.transform.LowerDeviceKernelLaunch()(mod) # Transform threadblock to persistent threadblock From f52617651cd421967a2f6e2925f22c4ab589d238 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Fri, 28 Nov 2025 02:16:53 +0800 Subject: [PATCH 05/10] [Enhancement] Update .gitignore and refine matmul kernel for improved performance This commit adds host checks logs to the .gitignore file to prevent unnecessary log files from being tracked. Additionally, it refines the matmul kernel by adjusting pipeline stages, updating tensor parameters, and enhancing argument handling for better performance. The changes also include improved error messages in the argument binding process, ensuring clearer diagnostics for users. --- .gitignore | 3 + examples/quickstart.py | 42 +++++++---- maint/host_checks/01_num_args_mismatch.py | 22 ++++++ maint/host_checks/02_pointer_type_error.py | 22 ++++++ maint/host_checks/03_ndim_mismatch.py | 20 ++++++ maint/host_checks/04_dtype_mismatch.py | 21 ++++++ maint/host_checks/05_shape_mismatch.py | 20 ++++++ maint/host_checks/06_strides_mismatch.py | 20 ++++++ maint/host_checks/07_device_type_mismatch.py | 19 +++++ maint/host_checks/08_device_id_mismatch.py | 26 +++++++ maint/host_checks/09_null_data_pointer.py | 25 +++++++ maint/host_checks/10_scalar_type_mismatch.py | 16 +++++ maint/host_checks/README.md | 21 ++++++ maint/host_checks/common.py | 43 ++++++++++++ maint/host_checks/run_all.py | 73 ++++++++++++++++++++ src/runtime/error_helpers.cc | 49 +++++++++++++ src/transform/make_packed_api.cc | 28 ++++++-- tilelang/jit/adapter/tvm_ffi.py | 15 ---- 18 files changed, 449 insertions(+), 36 deletions(-) create mode 100644 maint/host_checks/01_num_args_mismatch.py create mode 100644 maint/host_checks/02_pointer_type_error.py create mode 100644 maint/host_checks/03_ndim_mismatch.py create mode 100644 maint/host_checks/04_dtype_mismatch.py create mode 100644 maint/host_checks/05_shape_mismatch.py create mode 100644 maint/host_checks/06_strides_mismatch.py create mode 100644 maint/host_checks/07_device_type_mismatch.py create mode 100644 maint/host_checks/08_device_id_mismatch.py create mode 100644 maint/host_checks/09_null_data_pointer.py create mode 100644 maint/host_checks/10_scalar_type_mismatch.py create mode 100644 maint/host_checks/README.md create mode 100644 maint/host_checks/common.py create mode 100644 maint/host_checks/run_all.py create mode 100644 src/runtime/error_helpers.cc diff --git a/.gitignore b/.gitignore index 752f6cb76..730398dfc 100644 --- a/.gitignore +++ b/.gitignore @@ -108,3 +108,6 @@ cmake-build-*/ # pre-commit cache .pre-commit-cache/* + +# host checks logs +maint/host_checks/logs/* diff --git a/examples/quickstart.py b/examples/quickstart.py index 46a39e0d9..568d74471 100644 --- a/examples/quickstart.py +++ b/examples/quickstart.py @@ -5,7 +5,11 @@ # @tilelang.jit(target="cuda") # target currently can be "cuda" or "hip" or "cpu". # if not specified, it will be inferred from the input tensors during compile time -@tilelang.jit +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + }) def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): @T.prim_func @@ -13,6 +17,7 @@ def matmul_relu_kernel( A: T.Tensor((M, K), dtype), B: T.Tensor((K, N), dtype), C: T.Tensor((M, N), dtype), + # D: T.Tensor((M, N), dtype), ): # Initialize Kernel Context with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): @@ -26,7 +31,7 @@ def matmul_relu_kernel( # Clear local accumulation T.clear(C_local) - for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3): + for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=0): # Copy tile of A # This is a sugar syntax for parallelized copy T.copy(A[by * block_M, ko * block_K], A_shared) @@ -36,11 +41,7 @@ def matmul_relu_kernel( # Perform a tile-level GEMM on the shared buffers # Currently we dispatch to the cute/hip on Nvidia/AMD GPUs - T.gemm(A_shared, B_shared, C_local) - - # relu - for i, j in T.Parallel(block_M, block_N): - C_local[i, j] = T.max(C_local[i, j], 0) + T.gemm_v1(A_shared, B_shared, C_local) # Copy result back to global memory T.copy(C_local, C[by * block_M, bx * block_N]) @@ -48,37 +49,48 @@ def matmul_relu_kernel( return matmul_relu_kernel -M = 1024 # M = T.dynamic("m") if you want to use dynamic shape -N = 1024 -K = 1024 +tilelang.disable_cache() +M = 16384 # M = T.dynamic("m") if you want to use dynamic shape +N = 16384 +K = 16384 block_M = 128 block_N = 128 block_K = 32 # Define the kernel (matmul) and compile/lower it into an executable module matmul_relu_kernel = matmul(M, N, K, block_M, block_N, block_K) +print(matmul_relu_kernel.get_host_source()) + # Test the kernel in Python with PyTorch data import torch # Create random input tensors on the GPU a = torch.randn(M, K, device="cuda", dtype=torch.float16) +a_1 = torch.randn(M, K + 1, device="cuda", dtype=torch.float16) +a_c = torch.randn(M, K, device="cpu", dtype=torch.float16) b = torch.randn(K, N, device="cuda", dtype=torch.float16) c = torch.empty(M, N, device="cuda", dtype=torch.float16) # Run the kernel through the Profiler -matmul_relu_kernel(a, b, c) +matmul_relu_kernel(a_1, b, c) + +# matmul_relu_kernel(None, b, c, None) + +# matmul_relu_kernel(a_c, b, c) + +# matmul_relu_kernel(a, b, c) print(c) # Reference multiplication using PyTorch ref_c = torch.relu(a @ b) # Validate correctness -torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2) -print("Kernel output matches PyTorch reference.") +# torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2) +# print("Kernel output matches PyTorch reference.") # 4. Retrieve and inspect the generated CUDA source (optional) -# cuda_source = jit_kernel.get_kernel_source() -# print("Generated CUDA kernel:\n", cuda_source) +cuda_source = matmul_relu_kernel.get_kernel_source() +print("Generated CUDA kernel:\n", cuda_source) # 5.Profile latency with kernel profiler = matmul_relu_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) diff --git a/maint/host_checks/01_num_args_mismatch.py b/maint/host_checks/01_num_args_mismatch.py new file mode 100644 index 000000000..ab154c6c7 --- /dev/null +++ b/maint/host_checks/01_num_args_mismatch.py @@ -0,0 +1,22 @@ +"""Reproduce: Argument count mismatch. + +Note: The adapter-level wrapper expects only inputs (A, B) because C is marked as output. +Calling with the wrong number of inputs raises a ValueError before host entry. +""" +import torch +from common import build_matmul_kernel + + +def main(): + M = N = K = 256 + fn = build_matmul_kernel(M, N, K, target="cuda") + + a = torch.empty((M, K), device="cuda", dtype=torch.float16) + # Missing b + # Expected: ValueError with message about expected vs. actual inputs + fn(a) + + +if __name__ == "__main__": + main() + diff --git a/maint/host_checks/02_pointer_type_error.py b/maint/host_checks/02_pointer_type_error.py new file mode 100644 index 000000000..fd3585405 --- /dev/null +++ b/maint/host_checks/02_pointer_type_error.py @@ -0,0 +1,22 @@ +"""Reproduce: Pointer-type argument expected but scalar provided. + +We pass an integer for A; wrapper forwards it to the host where a pointer is expected. +Expected: error like "Expect buffer A_handle to be pointer or tensor" (exact name depends on kernel param). +""" +import torch +from common import build_matmul_kernel + + +def main(): + M = N = K = 256 + fn = build_matmul_kernel(M, N, K, target="cuda") + + # Wrong type for A (int instead of tensor) + a = 1 + b = torch.empty((K, N), device="cuda", dtype=torch.float16) + + fn(a, b) + + +if __name__ == "__main__": + main() diff --git a/maint/host_checks/03_ndim_mismatch.py b/maint/host_checks/03_ndim_mismatch.py new file mode 100644 index 000000000..4818730ee --- /dev/null +++ b/maint/host_checks/03_ndim_mismatch.py @@ -0,0 +1,20 @@ +"""Reproduce: ndim (rank) mismatch for A. +""" +import torch +from common import build_matmul_kernel + + +def main(): + M = N = K = 128 + fn = build_matmul_kernel(M, N, K, target="cuda") + + # A has rank 3 instead of 2 + a = torch.empty((M, K, 1), device="cuda", dtype=torch.float16) + b = torch.empty((K, N), device="cuda", dtype=torch.float16) + + fn(a, b) + + +if __name__ == "__main__": + main() + diff --git a/maint/host_checks/04_dtype_mismatch.py b/maint/host_checks/04_dtype_mismatch.py new file mode 100644 index 000000000..ebaa1f943 --- /dev/null +++ b/maint/host_checks/04_dtype_mismatch.py @@ -0,0 +1,21 @@ +"""Reproduce: dtype mismatch for A (float32 vs expected float16). +""" +import torch +from common import build_matmul_kernel + + +def main(): + M = N = K = 128 + fn = build_matmul_kernel(M, N, K, target="cuda") + print(fn.get_host_source()) + + a = torch.empty((M, K), device="cuda", dtype=torch.float32) # should be float16 + b = torch.empty((K, N), device="cuda", dtype=torch.float16) + + fn(a, b) + + + +if __name__ == "__main__": + main() + diff --git a/maint/host_checks/05_shape_mismatch.py b/maint/host_checks/05_shape_mismatch.py new file mode 100644 index 000000000..54d251de9 --- /dev/null +++ b/maint/host_checks/05_shape_mismatch.py @@ -0,0 +1,20 @@ +"""Reproduce: shape constant/symbol mismatch on A. +""" +import torch +from common import build_matmul_kernel + + +def main(): + M = N = K = 128 + fn = build_matmul_kernel(M, N, K, target="cuda") + + # A's second dimension is wrong (K+1 instead of K) + a = torch.empty((M, K + 1), device="cuda", dtype=torch.float16) + b = torch.empty((K, N), device="cuda", dtype=torch.float16) + + fn(a, b) + + +if __name__ == "__main__": + main() + diff --git a/maint/host_checks/06_strides_mismatch.py b/maint/host_checks/06_strides_mismatch.py new file mode 100644 index 000000000..14af496f5 --- /dev/null +++ b/maint/host_checks/06_strides_mismatch.py @@ -0,0 +1,20 @@ +"""Reproduce: strides check failure (non-contiguous A via transpose). +""" +import torch +from common import build_matmul_kernel + + +def main(): + M = N = K = 128 + fn = build_matmul_kernel(M, N, K, target="cuda") + + a = torch.empty((M, K), device="cuda", dtype=torch.float16) + a_nc = a.t() # non-contiguous after transpose + b = torch.empty((K, N), device="cuda", dtype=torch.float16) + + fn(a_nc, b) + + +if __name__ == "__main__": + main() + diff --git a/maint/host_checks/07_device_type_mismatch.py b/maint/host_checks/07_device_type_mismatch.py new file mode 100644 index 000000000..8de7fd287 --- /dev/null +++ b/maint/host_checks/07_device_type_mismatch.py @@ -0,0 +1,19 @@ +"""Reproduce: device_type mismatch by passing CPU tensors to a CUDA kernel. +""" +import torch +from common import build_matmul_kernel + + +def main(): + M = N = K = 64 + fn = build_matmul_kernel(M, N, K, target="cuda") + + a = torch.empty((M, K), device="cpu", dtype=torch.float16) + b = torch.empty((K, N), device="cpu", dtype=torch.float16) + + fn(a, b) + + +if __name__ == "__main__": + main() + diff --git a/maint/host_checks/08_device_id_mismatch.py b/maint/host_checks/08_device_id_mismatch.py new file mode 100644 index 000000000..d3fc34041 --- /dev/null +++ b/maint/host_checks/08_device_id_mismatch.py @@ -0,0 +1,26 @@ +"""Reproduce: device_id mismatch (requires >=2 CUDA devices). +""" +import torch +from common import build_matmul_kernel + + +def main(): + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is not available") + if torch.cuda.device_count() < 2: + print("[SKIP] Need at least 2 CUDA devices to reproduce device_id mismatch.") + return + + M = N = K = 64 + fn = build_matmul_kernel(M, N, K, target="cuda") + + a = torch.empty((M, K), device="cuda:0", dtype=torch.float16) + b = torch.empty((K, N), device="cuda:1", dtype=torch.float16) + # Output device is derived by the adapter; mismatch occurs in host checks + + fn(a, b) + + +if __name__ == "__main__": + main() + diff --git a/maint/host_checks/09_null_data_pointer.py b/maint/host_checks/09_null_data_pointer.py new file mode 100644 index 000000000..00bac67dd --- /dev/null +++ b/maint/host_checks/09_null_data_pointer.py @@ -0,0 +1,25 @@ +"""Reproduce: NULL data pointer (advanced). + +Passing None for a tensor argument will be forwarded through the adapter. Depending on +FFI handling, this commonly triggers a pointer-type assertion (e.g., "Expect buffer to be pointer or tensor") +or a host-side non-NULL pointer check. + +Note: Constructing a true DLTensor with NULL data in PyTorch is not typical; this script +demonstrates passing None, which still reproduces the intended class of failure. +""" +import torch +from common import build_matmul_kernel + + +def main(): + M = N = K = 64 + fn = build_matmul_kernel(M, N, K, target="cuda") + + a = None # attempt to pass a null-like pointer + b = torch.empty((K, N), device="cuda", dtype=torch.float16) + + fn(a, b) + + +if __name__ == "__main__": + main() diff --git a/maint/host_checks/10_scalar_type_mismatch.py b/maint/host_checks/10_scalar_type_mismatch.py new file mode 100644 index 000000000..c4275a902 --- /dev/null +++ b/maint/host_checks/10_scalar_type_mismatch.py @@ -0,0 +1,16 @@ +"""Reproduce: scalar parameter type mismatch (int/bool). +""" +from common import build_scalar_check_kernel + + +def main(): + fn = build_scalar_check_kernel(target="cuda") + + # Wrong types + fn(1.0, True) # x should be int -> Expect arg[0] to be int + fn(1, 2.5) # flag should be bool -> Expect arg[1] to be boolean + + +if __name__ == "__main__": + main() + diff --git a/maint/host_checks/README.md b/maint/host_checks/README.md new file mode 100644 index 000000000..ac23d6fd2 --- /dev/null +++ b/maint/host_checks/README.md @@ -0,0 +1,21 @@ +# Host-Side Check Repro Scripts + +This folder contains standalone scripts that deliberately trigger host-side (and adapter-side) validation errors described in `docs/compiler_internals/tensor_checks.md`. Each script can be run directly and will reproduce the corresponding error with a minimal example. + +Prerequisites +- CUDA-capable environment (most scripts compile a CUDA-targeted kernel) +- Python packages: torch, tilelang + +Usage +- Run any script, e.g.: + - `python 01_num_args_mismatch.py` + - `python 02_pointer_type_error.py` + - ... up to `10_scalar_type_mismatch.py` + +- Or run all at once with a summary: + - `python run_all.py` + - Logs per test are saved under `logs/` as `