diff --git a/src/transform/arg_binder.cc b/src/transform/arg_binder.cc
index c3aebc864..294c9f6bc 100644
--- a/src/transform/arg_binder.cc
+++ b/src/transform/arg_binder.cc
@@ -311,446 +311,618 @@ inline PrimExpr TVMArrayGet(DataType t, Var arr,
   return TVMStructGet(t, arr, 0, kind);
 }
 
-void ArgBinder::BindDLTensor(const Buffer &buffer, const PrimExpr &device_type,
-                             const PrimExpr &device_id, const Var &handle,
-                             const std::string &arg_name, bool is_used) {
-  const DataType tvm_shape_type = DataType::ShapeIndex();
-  const DataType tvm_ndim_type = DataType::Int(32);
-  const Stmt nop = Evaluate(0);
+void ArgBinder::BindDLTensors(
+    const std::vector<std::pair<Var, Buffer>> &buffer_def,
+    const PrimExpr &device_type, const PrimExpr &device_id,
+    const std::string &func_name,
+    const std::unordered_set<const VarNode *> &used_param_buffers) {
+  ffi::Array<Buffer> buffers;
+  ffi::Array<Var> handles;
+
+  // First pass: collect shape var -> list of (buffer_name, dim_idx, handle_ptr)
+  struct ShapeVarSource {
+    std::string buf_name;
+    size_t dim_idx;
+    const VarNode *handle_ptr; // Raw pointer to check used_param_buffers
+  };
+  std::unordered_map<const VarNode *, std::vector<ShapeVarSource>>
+      shape_var_sources;
+
+  for (const auto &[handle, buffer] : buffer_def) {
+    std::string arg_name = func_name + "." + buffer->data->name_hint;
+
+    // Scan buffer shape for symbolic variables
+    for (size_t k = 0; k < buffer->shape.size(); ++k) {
+      if (buffer->dtype == DataType::Int(4) ||
+          buffer->dtype == DataType::UInt(4) ||
+          buffer->dtype == DataType::Int(1)) {
+        break;
+      }
 
-  // Allow NULL DLTensor* for optional inputs. When the handle is NULL,
-  // avoid dereferencing it by using expression-level conditionals and
-  // short-circuiting guards in asserts. Cache the null check in a Let-bound
-  // boolean so codegen does not repeat `(handle == NULL)` everywhere.
-
-  Var is_null_var(arg_name + "_is_null", DataType::Bool());
-  init_nest_.emplace_back(
-      LetStmt(is_null_var,
-              Call(DataType::Bool(), builtin::isnullptr(), {handle}), nop));
-  const PrimExpr &is_null = is_used ? const_false() : is_null_var;
-  if (is_used) {
-    init_nest_.emplace_back(AssertStmt(
-        !is_null_var,
-        tvm::tir::StringImm(arg_name + " is expected to have non-NULL pointer"),
-        nop));
+      if (const VarNode *v = buffer->shape[k].as<VarNode>()) {
+        // This dimension is a symbolic variable
+        shape_var_sources[v].push_back({arg_name, k, handle.get()});
+      }
+    }
   }
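+  // Illustrative (hypothetical kernel "f" with params A: (m, n) and B: (n,)):
+  // after this pass, shape_var_sources maps
+  //   m -> [{"f.A", 0, &A}]
+  //   n -> [{"f.A", 1, &A}, {"f.B", 0, &B}]
+  // so n can later be read from whichever of A/B is non-NULL at runtime.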
-  // dimension checks
-  PrimExpr v_ndim = TVMArrayGet(tvm_ndim_type, handle, builtin::kArrNDim);
+  // Second pass: Create is_null vars and shape buffers for all buffers first
+  std::unordered_map<std::string, Var> is_null_map;
+  std::unordered_map<std::string, Buffer> shape_buffer_map;
+  std::unordered_map<std::string, PrimExpr>
+      is_null_expr_map; // arg_name -> is_null expression (const_false for used
+                        // buffers)
 
-  // Helper functions for shape/stride name formatting
-  auto shape_handle_name = [&]() { return arg_name + ".shape"; };
-  auto stride_handle_name = [&]() { return arg_name + ".strides"; };
-  auto array_element_name = [&](const std::string &arr_name, size_t k) {
-    std::stringstream ss;
-    ss << arr_name << '[' << k << ']';
-    return ss.str();
-  };
-  auto shape_element_name = [&](size_t k) {
-    return array_element_name(shape_handle_name(), k);
-  };
-  auto stride_element_name = [&](size_t k) {
-    return array_element_name(stride_handle_name(), k);
-  };
+  const DataType tvm_shape_type = DataType::ShapeIndex();
+  const DataType tvm_ndim_type = DataType::Int(32);
+  const Stmt nop = Evaluate(0);
 
-  PrimExpr a_ndim =
-      make_const(tvm_ndim_type, static_cast<int64_t>(buffer->shape.size()));
-  // Build clearer ndim message with kernel/buffer names
-  std::string kernel_nm = arg_name;
-  std::string buf_nm = arg_name;
-  size_t dot_pos = arg_name.find('.');
-  if (dot_pos != std::string::npos) {
-    kernel_nm = arg_name.substr(0, dot_pos);
-    buf_nm = arg_name.substr(dot_pos + 1);
-  }
-  // Only check ndim when handle is non-NULL: use packed error helper
-  PrimExpr ndim_ok = (a_ndim == v_ndim);
-  ffi::Array<PrimExpr> ndim_args;
-  ndim_args.push_back(StringImm(tvm_error_ndim_mismatch));
-  ndim_args.push_back(StringImm(kernel_nm));
-  ndim_args.push_back(StringImm(buf_nm));
-  ndim_args.push_back(cast(DataType::Int(64), a_ndim));
-  ndim_args.push_back(cast(DataType::Int(64), v_ndim));
-  Stmt ndim_call =
-      Evaluate(Call(DataType::Int(32), builtin::tvm_call_packed(), ndim_args));
-  init_nest_.emplace_back(
-      SeqStmt({IfThenElse(Not(is_null), IfThenElse(Not(ndim_ok), ndim_call),
-                          Evaluate(0)),
-               nop}));
-  // type checks
-  // Guard all dtype field loads by `is_null` using if_then_else
-  PrimExpr v_type_code = tvm::if_then_else(
-      Not(is_null),
-      TVMArrayGet(DataType::UInt(8), handle, builtin::kArrTypeCode),
-      IntImm(DataType::UInt(8), buffer->dtype.code()));
-  PrimExpr v_type_bits = tvm::if_then_else(
-      Not(is_null),
-      TVMArrayGet(DataType::UInt(8), handle, builtin::kArrTypeBits),
-      IntImm(DataType::UInt(8), buffer->dtype.bits()));
-  PrimExpr v_type_lanes = tvm::if_then_else(
-      Not(is_null),
-      TVMArrayGet(DataType::UInt(16), handle, builtin::kArrTypeLanes),
-      IntImm(DataType::UInt(16), buffer->dtype.lanes()));
-  PrimExpr expect_code = IntImm(DataType::UInt(8), buffer->dtype.code());
-  PrimExpr expect_bits = IntImm(DataType::UInt(8), buffer->dtype.bits());
-  PrimExpr expect_lanes = IntImm(DataType::UInt(16), buffer->dtype.lanes());
-
-  PrimExpr cond = (v_type_code == expect_code && v_type_bits == expect_bits &&
-                   v_type_lanes == expect_lanes);
-
-  // Allow float8_e4m3 to match float8_e4m3fn/float8_e4m3fnuz at runtime.
-  if (buffer->dtype.is_float8_e4m3()) {
-    PrimExpr code_e4m3 = IntImm(DataType::UInt(8), DataType::kFloat8_e4m3);
-    PrimExpr code_e4m3fn = IntImm(DataType::UInt(8), DataType::kFloat8_e4m3fn);
-    PrimExpr code_e4m3fnuz =
-        IntImm(DataType::UInt(8), DataType::kFloat8_e4m3fnuz);
-    PrimExpr code_match =
-        (v_type_code == code_e4m3 || v_type_code == code_e4m3fn ||
-         v_type_code == code_e4m3fnuz);
-    cond = cond || (code_match && v_type_bits == expect_bits &&
-                    v_type_lanes == expect_lanes);
-  }
-  // Allow float8_e5m2 to match float8_e5m2fnuz at runtime.
-  if (buffer->dtype.is_float8_e5m2()) {
-    PrimExpr code_e5m2 = IntImm(DataType::UInt(8), DataType::kFloat8_e5m2);
-    PrimExpr code_e5m2fnuz =
-        IntImm(DataType::UInt(8), DataType::kFloat8_e5m2fnuz);
-    PrimExpr code_match =
-        (v_type_code == code_e5m2 || v_type_code == code_e5m2fnuz);
-    cond = cond || (code_match && v_type_bits == expect_bits &&
-                    v_type_lanes == expect_lanes);
-  }
-  // Allow bool to match int8/uint8 at runtime, and also kDLBool(code=6).
-  if (buffer->dtype.is_bool()) {
-    PrimExpr code_int = IntImm(DataType::UInt(8), DataType::kInt);
-    PrimExpr code_uint = IntImm(DataType::UInt(8), DataType::kUInt);
-    PrimExpr code_kdlbool = IntImm(DataType::UInt(8), 6);
-    PrimExpr bits8 = IntImm(DataType::UInt(8), 8);
-    PrimExpr bits1 = IntImm(DataType::UInt(8), 1);
-    PrimExpr lanes_ok = (v_type_lanes == expect_lanes);
-    PrimExpr int8_ok =
-        (v_type_code == code_int && v_type_bits == bits8 && lanes_ok);
-    PrimExpr uint8_ok =
-        (v_type_code == code_uint && v_type_bits == bits8 && lanes_ok);
-    // Some frontends may tag bool tensors as kDLBool(code=6), commonly with
-    // bits=8 or bits=1.
-    PrimExpr kdlbool8_ok =
-        (v_type_code == code_kdlbool && v_type_bits == bits8 && lanes_ok);
-    PrimExpr kdlbool1_ok =
-        (v_type_code == code_kdlbool && v_type_bits == bits1 && lanes_ok);
-    // Also accept any dtype whose bitwidth=1, regardless of code, to be
-    // defensive.
-    PrimExpr bit1_ok = (v_type_bits == bits1 && lanes_ok);
-    cond = cond || int8_ok || uint8_ok || kdlbool8_ok || kdlbool1_ok || bit1_ok;
+  // Create all is_null vars and shape buffers first
+  for (const auto &[handle, buffer] : buffer_def) {
+    bool is_used = used_param_buffers.count(handle.get());
+    std::string arg_name = func_name + "." + buffer->data->name_hint;
+
+    Var is_null_var(arg_name + "_is_null", DataType::Bool());
+    init_nest_.emplace_back(
+        LetStmt(is_null_var,
+                Call(DataType::Bool(), builtin::isnullptr(), {handle}), nop));
+    const PrimExpr &is_null = is_used ? const_false() : is_null_var;
+
+    is_null_map[arg_name] = is_null_var;
+    is_null_expr_map[arg_name] = is_null;
+
+    if (is_used) {
+      init_nest_.emplace_back(
+          AssertStmt(!is_null_var,
+                     tvm::tir::StringImm(
+                         arg_name + " is expected to have non-NULL pointer"),
+                     nop));
+    }
   }
-  // Allow float4 to match int8 at runtime (PyTorch uses int8 as storage for
-  // FP4).
-  if (buffer->dtype.is_float4()) {
-    PrimExpr code_int = IntImm(DataType::UInt(8), DataType::kInt);
-    PrimExpr bits8 = IntImm(DataType::UInt(8), 8);
-    // For FP4, we pack 2 elements per byte, but we still use same lanes at
-    // storage level Accept int8 with same lanes as the fp4 type
-    PrimExpr fp4_lanes_ok = (v_type_lanes == expect_lanes);
-    PrimExpr int8_ok =
-        (v_type_code == code_int && v_type_bits == bits8 && fp4_lanes_ok);
-    cond = cond || int8_ok;
+
+  // Create all shape buffers before binding any shapes
+  for (const auto &[handle, buffer] : buffer_def) {
+    std::string arg_name = func_name + "." + buffer->data->name_hint;
+    const PrimExpr &is_null = is_null_expr_map[arg_name];
+
+    // Helper functions for shape/stride name formatting
+    auto shape_handle_name = [&]() { return arg_name + ".shape"; };
+
+    // shape field
+    Buffer buf_shape =
+        decl_buffer({IntImm(DataType::Int(32), buffer->shape.size())},
+                    tvm_shape_type, shape_handle_name());
+    def_handle_dtype_.Set(buf_shape->data, make_const(tvm_shape_type, 0));
+    // Use if_then_else for NULL guard on the shape pointer itself, avoiding
+    // dereferencing TVMStructGet(handle, kArrShape) when handle is NULL.
+    init_nest_.emplace_back(
+        LetStmt(buf_shape->data,
+                tvm::if_then_else(
+                    Not(is_null),
+                    TVMArrayGet(DataType::Handle(), handle, builtin::kArrShape),
+                    make_zero(DataType::Handle())),
+                nop));
+    init_nest_.emplace_back(DeclBuffer(buf_shape, nop));
+
+    // Save for later use in shape binding
+    shape_buffer_map[arg_name] = buf_shape;
   }
 
-  if (!(buffer->dtype == DataType::Int(1) ||
-        buffer->dtype == DataType::Int(4) ||
-        buffer->dtype == DataType::UInt(4) || buffer->dtype.is_float4())) {
-    // Build FFI packed call to __tvm_error_dtype_mismatch when mismatch occurs.
-    // Only issue the call when handle is non-NULL and cond is false.
-    ffi::Array<PrimExpr> packed_args;
-    packed_args.push_back(StringImm(tvm_error_dtype_mismatch));
-    // Split arg_name of the form "<kernel>.<buffer>" into parts for clearer
-    // diagnostics
-    std::string kernel_name = arg_name;
-    std::string buffer_name = arg_name;
+
+  // Now process each buffer fully
+  for (const auto &[handle, buffer] : buffer_def) {
+    bool is_used = used_param_buffers.count(handle.get());
+    std::string arg_name = func_name + "." + buffer->data->name_hint;
+    const PrimExpr &is_null = is_null_expr_map[arg_name];
+
+    // dimension checks
+    PrimExpr v_ndim = TVMArrayGet(tvm_ndim_type, handle, builtin::kArrNDim);
+
+    // Helper functions for shape/stride name formatting
+    auto shape_handle_name = [&]() { return arg_name + ".shape"; };
+    auto stride_handle_name = [&]() { return arg_name + ".strides"; };
+    auto array_element_name = [&](const std::string &arr_name, size_t k) {
+      std::stringstream ss;
+      ss << arr_name << '[' << k << ']';
+      return ss.str();
+    };
+    auto shape_element_name = [&](size_t k) {
+      return array_element_name(shape_handle_name(), k);
+    };
+    auto stride_element_name = [&](size_t k) {
+      return array_element_name(stride_handle_name(), k);
+    };
+
+    PrimExpr a_ndim =
+        make_const(tvm_ndim_type, static_cast<int64_t>(buffer->shape.size()));
+    // Build clearer ndim message with kernel/buffer names
+    std::string kernel_nm = arg_name;
+    std::string buf_nm = arg_name;
     size_t dot_pos = arg_name.find('.');
     if (dot_pos != std::string::npos) {
-      kernel_name = arg_name.substr(0, dot_pos);
-      buffer_name = arg_name.substr(dot_pos + 1);
+      kernel_nm = arg_name.substr(0, dot_pos);
+      buf_nm = arg_name.substr(dot_pos + 1);
     }
-    packed_args.push_back(StringImm(kernel_name));
-    packed_args.push_back(StringImm(buffer_name));
-
-    auto i64 = DataType::Int(64);
-    // Cast to int64 for FFI function signature
-    packed_args.push_back(cast(i64, v_type_code));  // actual_code
-    packed_args.push_back(cast(i64, v_type_bits));  // actual_bits
-    packed_args.push_back(cast(i64, v_type_lanes)); // actual_lanes
-    packed_args.push_back(cast(i64, expect_code));  // expect_code
-    packed_args.push_back(cast(i64, expect_bits));  // expect_bits
-    packed_args.push_back(cast(i64, expect_lanes)); // expect_lanes
-
-    Stmt call_err = Evaluate(
-        Call(DataType::Int(32), builtin::tvm_call_packed(), packed_args));
-    // Guard the call: only when handle is not null and cond fails
-    Stmt guarded = IfThenElse(Not(is_null) && Not(cond), call_err);
-    asserts_.emplace_back(SeqStmt({guarded, nop}));
-  }
-
-  // shape field
-  Buffer buf_shape =
-      decl_buffer({IntImm(DataType::Int(32), buffer->shape.size())},
-                  tvm_shape_type, shape_handle_name());
-  Var v_shape(shape_handle_name(), DataType::Handle());
-  def_handle_dtype_.Set(v_shape, make_const(tvm_shape_type, 0));
-  // Use if_then_else for NULL guard on the shape pointer itself, avoiding
-  // dereferencing TVMStructGet(handle, kArrShape) when handle is NULL.
-  init_nest_.emplace_back(
-      LetStmt(buf_shape->data,
-              tvm::if_then_else(
-                  Not(is_null),
-                  TVMArrayGet(DataType::Handle(), handle, builtin::kArrShape),
-                  make_zero(DataType::Handle())),
-              nop));
-  init_nest_.emplace_back(DeclBuffer(buf_shape, nop));
-
-  for (size_t k = 0; k < buffer->shape.size(); ++k) {
-    // These packed-bit dtype shapes were not bound in the original
-    // implementation, so we just use them as is.
-    if (buffer->dtype == DataType::Int(4) ||
-        buffer->dtype == DataType::UInt(4) ||
-        buffer->dtype == DataType::Int(1)) {
-      break;
+    // Only check ndim when handle is non-NULL: use packed error helper
+    PrimExpr ndim_ok = (a_ndim == v_ndim);
+    ffi::Array<PrimExpr> ndim_args;
+    ndim_args.push_back(StringImm(tvm_error_ndim_mismatch));
+    ndim_args.push_back(StringImm(kernel_nm));
+    ndim_args.push_back(StringImm(buf_nm));
+    ndim_args.push_back(cast(DataType::Int(64), a_ndim));
+    ndim_args.push_back(cast(DataType::Int(64), v_ndim));
+    Stmt ndim_call = Evaluate(
+        Call(DataType::Int(32), builtin::tvm_call_packed(), ndim_args));
+    init_nest_.emplace_back(
+        SeqStmt({IfThenElse(Not(is_null), IfThenElse(Not(ndim_ok), ndim_call),
+                            Evaluate(0)),
+                 nop}));
+    // type checks
+    // Guard all dtype field loads by `is_null` using if_then_else
+    PrimExpr v_type_code = tvm::if_then_else(
+        Not(is_null),
+        TVMArrayGet(DataType::UInt(8), handle, builtin::kArrTypeCode),
+        IntImm(DataType::UInt(8), buffer->dtype.code()));
+    PrimExpr v_type_bits = tvm::if_then_else(
+        Not(is_null),
+        TVMArrayGet(DataType::UInt(8), handle, builtin::kArrTypeBits),
+        IntImm(DataType::UInt(8), buffer->dtype.bits()));
+    PrimExpr v_type_lanes = tvm::if_then_else(
+        Not(is_null),
+        TVMArrayGet(DataType::UInt(16), handle, builtin::kArrTypeLanes),
+        IntImm(DataType::UInt(16), buffer->dtype.lanes()));
+    PrimExpr expect_code = IntImm(DataType::UInt(8), buffer->dtype.code());
+    PrimExpr expect_bits = IntImm(DataType::UInt(8), buffer->dtype.bits());
+    PrimExpr expect_lanes = IntImm(DataType::UInt(16), buffer->dtype.lanes());
+
+    PrimExpr cond = (v_type_code == expect_code && v_type_bits == expect_bits &&
+                     v_type_lanes == expect_lanes);
+
+    // Allow float8_e4m3 to match float8_e4m3fn/float8_e4m3fnuz at runtime.
+    if (buffer->dtype.is_float8_e4m3()) {
+      PrimExpr code_e4m3 = IntImm(DataType::UInt(8), DataType::kFloat8_e4m3);
+      PrimExpr code_e4m3fn =
+          IntImm(DataType::UInt(8), DataType::kFloat8_e4m3fn);
+      PrimExpr code_e4m3fnuz =
+          IntImm(DataType::UInt(8), DataType::kFloat8_e4m3fnuz);
+      PrimExpr code_match =
+          (v_type_code == code_e4m3 || v_type_code == code_e4m3fn ||
+           v_type_code == code_e4m3fnuz);
+      cond = cond || (code_match && v_type_bits == expect_bits &&
+                      v_type_lanes == expect_lanes);
     }
-    // The "real" runtime shape value read from DLTensor
-    PrimExpr shape_val =
-        cast(buffer->shape[k].dtype(),
-             BufferLoad(buf_shape,
-                        {IntImm(DataType::Int(32), static_cast<int>(k))}));
-
-    // When first encountering a Var (e.g., m), this will generate:
-    // Let(m, bound_shape_val, ...)
-    // Constant dimensions will only generate consistency assertions.
-    BindNullable(buffer->shape[k], shape_val, shape_element_name(k), true,
-                 is_null);
-  }
-
-  // strides field
-  Buffer buf_strides =
-      decl_buffer({IntImm(DataType::Int(32), buffer->strides.size())},
-                  tvm_shape_type, arg_name + ".strides");
-  def_handle_dtype_.Set(buf_strides->data, tir::TypeAnnotation(tvm_shape_type));
-  init_nest_.emplace_back(
-      LetStmt(buf_strides->data,
-              tvm::if_then_else(
-                  Not(is_null),
-                  TVMArrayGet(DataType::Handle(), handle, builtin::kArrStrides),
-                  make_zero(DataType::Handle())),
-              nop));
-  init_nest_.emplace_back(DeclBuffer(buf_strides, nop));
-  PrimExpr v_strides_is_null =
-      Call(DataType::Bool(1), builtin::isnullptr(), {buf_strides->data});
-
-  if (buffer->strides.empty()) {
-    // Assert the buffer is compact
-    DataType stype = buffer->DefaultIndexType();
-    PrimExpr expect_stride = make_const(stype, 1);
-    ffi::Array<PrimExpr> conds;
-    for (size_t i = buffer->shape.size(); i != 0; --i) {
-      size_t k = i - 1;
-      PrimExpr svalue = cast(
-          stype, BufferLoad(buf_strides,
-                            {IntImm(DataType::Int(32), static_cast<int>(k))}));
-      conds.push_back(buffer->shape[k] == 1 || expect_stride == svalue);
-      expect_stride = expect_stride * buffer->shape[k];
+    // Allow float8_e5m2 to match float8_e5m2fnuz at runtime.
+    if (buffer->dtype.is_float8_e5m2()) {
+      PrimExpr code_e5m2 = IntImm(DataType::UInt(8), DataType::kFloat8_e5m2);
+      PrimExpr code_e5m2fnuz =
+          IntImm(DataType::UInt(8), DataType::kFloat8_e5m2fnuz);
+      PrimExpr code_match =
+          (v_type_code == code_e5m2 || v_type_code == code_e5m2fnuz);
+      cond = cond || (code_match && v_type_bits == expect_bits &&
+                      v_type_lanes == expect_lanes);
+    }
+    // Allow bool to match int8/uint8 at runtime, and also kDLBool(code=6).
+    if (buffer->dtype.is_bool()) {
+      PrimExpr code_int = IntImm(DataType::UInt(8), DataType::kInt);
+      PrimExpr code_uint = IntImm(DataType::UInt(8), DataType::kUInt);
+      PrimExpr code_kdlbool = IntImm(DataType::UInt(8), 6);
+      PrimExpr bits8 = IntImm(DataType::UInt(8), 8);
+      PrimExpr bits1 = IntImm(DataType::UInt(8), 1);
+      PrimExpr lanes_ok = (v_type_lanes == expect_lanes);
+      PrimExpr int8_ok =
+          (v_type_code == code_int && v_type_bits == bits8 && lanes_ok);
+      PrimExpr uint8_ok =
+          (v_type_code == code_uint && v_type_bits == bits8 && lanes_ok);
+      // Some frontends may tag bool tensors as kDLBool(code=6), commonly with
+      // bits=8 or bits=1.
+      PrimExpr kdlbool8_ok =
+          (v_type_code == code_kdlbool && v_type_bits == bits8 && lanes_ok);
+      PrimExpr kdlbool1_ok =
+          (v_type_code == code_kdlbool && v_type_bits == bits1 && lanes_ok);
+      // Also accept any dtype whose bitwidth=1, regardless of code, to be
+      // defensive.
+      PrimExpr bit1_ok = (v_type_bits == bits1 && lanes_ok);
+      cond =
+          cond || int8_ok || uint8_ok || kdlbool8_ok || kdlbool1_ok || bit1_ok;
+    }
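+    // Illustrative: under these relaxations a bool buffer accepts dtypes such
+    // as (kDLInt, 8, lanes), (kDLUInt, 8, lanes), (kDLBool, 8, lanes),
+    // (kDLBool, 1, lanes), or any dtype with bits == 1 and matching lanes.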
+    // Allow float4 to match int8 at runtime (PyTorch uses int8 as storage for
+    // FP4).
+    if (buffer->dtype.is_float4()) {
+      PrimExpr code_int = IntImm(DataType::UInt(8), DataType::kInt);
+      PrimExpr bits8 = IntImm(DataType::UInt(8), 8);
+      // For FP4, we pack 2 elements per byte, but we still use the same lanes
+      // at storage level. Accept int8 with the same lanes as the fp4 type.
+      PrimExpr fp4_lanes_ok = (v_type_lanes == expect_lanes);
+      PrimExpr int8_ok =
+          (v_type_code == code_int && v_type_bits == bits8 && fp4_lanes_ok);
+      cond = cond || int8_ok;
     }
-    std::ostringstream stride_err_msg;
-    stride_err_msg
-        << stride_handle_name()
-        << ": expected to be compact array, but got non-compact strides";
-    if (!conds.empty()) {
-      PrimExpr all_ok = foldl([](PrimExpr a, PrimExpr b,
-                                 Span span) { return logical_and(a, b, span); },
-                              const_true(1), conds);
-      // Packed generic violation for non-compact strides
-      std::string kernel_nm3 = arg_name;
-      std::string buf_nm3 = arg_name;
-      size_t dot_pos3 = arg_name.find('.');
-      if (dot_pos3 != std::string::npos) {
-        kernel_nm3 = arg_name.substr(0, dot_pos3);
-        buf_nm3 = arg_name.substr(dot_pos3 + 1);
-      }
-      ffi::Array<PrimExpr> pargs4;
-      pargs4.push_back(StringImm(tvm_error_constraint_violation));
-      pargs4.push_back(StringImm(kernel_nm3));
-      pargs4.push_back(StringImm(buf_nm3));
-      pargs4.push_back(StringImm("strides"));
-      Stmt call_err4 =
-          Evaluate(Call(DataType::Int(32), builtin::tvm_call_packed(), pargs4));
-      // Only check when strides array is present and condition fails
-      Stmt check = IfThenElse(Not(v_strides_is_null),
-                              IfThenElse(Not(all_ok), call_err4), Evaluate(0));
-      asserts_.emplace_back(SeqStmt({check, Evaluate(0)}));
+    if (!(buffer->dtype == DataType::Int(1) ||
+          buffer->dtype == DataType::Int(4) ||
+          buffer->dtype == DataType::UInt(4) || buffer->dtype.is_float4())) {
+      // Build FFI packed call to __tvm_error_dtype_mismatch when mismatch
+      // occurs. Only issue the call when handle is non-NULL and cond is false.
+      ffi::Array<PrimExpr> packed_args;
+      packed_args.push_back(StringImm(tvm_error_dtype_mismatch));
+      // Split arg_name of the form "<kernel>.<buffer>" into parts for clearer
+      // diagnostics
+      std::string kernel_name = arg_name;
+      std::string buffer_name = arg_name;
+      size_t dot_pos = arg_name.find('.');
+      if (dot_pos != std::string::npos) {
+        kernel_name = arg_name.substr(0, dot_pos);
+        buffer_name = arg_name.substr(dot_pos + 1);
       }
+      packed_args.push_back(StringImm(kernel_name));
+      packed_args.push_back(StringImm(buffer_name));
+
+      auto i64 = DataType::Int(64);
+      // Cast to int64 for FFI function signature
+      packed_args.push_back(cast(i64, v_type_code));  // actual_code
+      packed_args.push_back(cast(i64, v_type_bits));  // actual_bits
+      packed_args.push_back(cast(i64, v_type_lanes)); // actual_lanes
+      packed_args.push_back(cast(i64, expect_code));  // expect_code
+      packed_args.push_back(cast(i64, expect_bits));  // expect_bits
+      packed_args.push_back(cast(i64, expect_lanes)); // expect_lanes
+
+      Stmt call_err = Evaluate(
+          Call(DataType::Int(32), builtin::tvm_call_packed(), packed_args));
+      // Guard the call: only when handle is not null and cond fails
+      Stmt guarded = IfThenElse(Not(is_null) && Not(cond), call_err);
+      asserts_.emplace_back(SeqStmt({guarded, nop}));
     }
-  } else if (buffer->buffer_type == kAutoBroadcast) {
-    PrimExpr stride_from_shape = 1;
-    for (size_t i = buffer->shape.size(); i != 0; --i) {
-      size_t k = i - 1;
-      DataType stride_dtype = buffer->strides[k].dtype();
-      PrimExpr explicit_stride =
-          cast(stride_dtype,
-               BufferLoad(buf_strides,
-                          {IntImm(DataType::Int(32), static_cast<int>(k))}));
-      PrimExpr stride_val = tvm::if_then_else(
-          v_strides_is_null, stride_from_shape, explicit_stride);
+      // Get the pre-created shape buffer
+      Buffer buf_shape = shape_buffer_map[arg_name];
+
+      // Bind symbolic variables from buffer shape
+      for (size_t k = 0; k < buffer->shape.size(); ++k) {
+        // These packed-bit dtype shapes were not bound in the original
+        // implementation, so we just use them as is.
+        if (buffer->dtype == DataType::Int(4) ||
+            buffer->dtype == DataType::UInt(4) ||
+            buffer->dtype == DataType::Int(1)) {
+          break;
+        }
-      BindNullable(buffer->strides[k], stride_val, stride_element_name(k), true,
-                   is_null);
+        // The "real" runtime shape value read from DLTensor
+        PrimExpr shape_val =
+            cast(buffer->shape[k].dtype(),
+                 BufferLoad(buf_shape,
+                            {IntImm(DataType::Int(32), static_cast<int>(k))}));
+
+        // Check if this dimension is a symbolic variable
+        if (const VarNode *v = buffer->shape[k].as<VarNode>()) {
+          auto it = def_map_->find(v);
+          if (it == def_map_->end()) {
+            // First time binding this symbolic variable
+            auto sources_it = shape_var_sources.find(v);
+            if (sources_it != shape_var_sources.end() &&
+                sources_it->second.size() > 1) {
+              // This variable appears in multiple buffers
+              // Assert that at least one buffer is non-null
+              PrimExpr any_nonnull = const_false();
+              for (const auto &src : sources_it->second) {
+                bool buf_is_used = used_param_buffers.count(src.handle_ptr);
+                if (buf_is_used) {
+                  any_nonnull = const_true();
+                  break;
+                }
+                Var src_is_null = is_null_map[src.buf_name];
+                any_nonnull = Or(any_nonnull, Not(src_is_null));
+              }
+
+              std::ostringstream err_msg;
+              err_msg << "Symbolic shape variable "
+                      << ffi::GetRef<Var>(v)->name_hint
+                      << " requires at least one non-null buffer among: ";
+              bool first = true;
+              for (const auto &src : sources_it->second) {
+                if (!first)
+                  err_msg << ", ";
+                err_msg << src.buf_name;
+                first = false;
+              }
+
+              init_nest_.emplace_back(AssertStmt(
+                  any_nonnull, tvm::tir::StringImm(err_msg.str()), nop));
+
+              // Build cascaded if_then_else: if !is_null_a then a.shape[k] else
+              // if !is_null_b then b.shape[k] ... We need to construct this in
+              // reverse order
+              PrimExpr cascaded_value;
+              bool is_first_source = true;
+
+              for (auto rit = sources_it->second.rbegin();
+                   rit != sources_it->second.rend(); ++rit) {
+                const auto &src = *rit;
+
+                // Get the shape buffer for this source
+                auto it_buf = shape_buffer_map.find(src.buf_name);
+                if (it_buf == shape_buffer_map.end()) {
+                  LOG(FATAL) << "Shape buffer not found for " << src.buf_name;
+                }
+                Buffer src_shape_buf = it_buf->second;
+
+                // Construct the shape load
+                PrimExpr src_shape_val =
+                    cast(buffer->shape[k].dtype(),
+                         BufferLoad(src_shape_buf,
+                                    {IntImm(DataType::Int(32),
+                                            static_cast<int>(src.dim_idx))}));
+
+                // Check if this buffer is used (non-nullable)
+                bool src_is_used = used_param_buffers.count(src.handle_ptr);
+
+                if (is_first_source) {
+                  // Base case: use this shape value directly (we know at least
+                  // one is non-null from assert)
+                  cascaded_value = src_shape_val;
+                  is_first_source = false;
+                } else {
+                  // If !is_null, use this shape, else use the previous cascaded
+                  // value. But if the buffer is used (non-nullable), always use
+                  // its shape.
+                  if (src_is_used) {
+                    cascaded_value = src_shape_val;
+                  } else {
+                    Var src_is_null = is_null_map[src.buf_name];
+                    cascaded_value = tvm::if_then_else(
+                        Not(src_is_null), src_shape_val, cascaded_value);
+                  }
+                }
+              }
+
+              // Bind the variable to the cascaded expression
+              Var v_arg = ffi::GetRef<Var>(v);
+              defs_.emplace_back(v_arg);
+              (*def_map_)[v] = cascaded_value;
+              init_nest_.emplace_back(
+                  LetStmt(v_arg, cascaded_value, Evaluate(0)));
+            } else {
+              // Single source or no special handling needed, use the original
+              // nullable binding
+              BindNullable(buffer->shape[k], shape_val, shape_element_name(k),
+                           true, is_null);
+            }
+          } else {
+            // Variable already bound, add assertion with nullable guard
+            PrimExpr cond = (it->second == shape_val);
+            BinderAddAssert(&analyzer_, cond, shape_element_name(k), &asserts_,
+                            is_null);
+          }
+        } else {
+          // Constant dimension, just add assertion
+          BindNullable(buffer->shape[k], shape_val, shape_element_name(k), true,
+                       is_null);
+        }
+      }
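+      // Illustrative result (hypothetical buffers a, b, c, all nullable and
+      // all carrying dim m at index 0): the emitted binding is roughly
+      //   assert(!a_is_null || !b_is_null || !c_is_null)
+      //   m = !a_is_null ? a.shape[0]
+      //     : !b_is_null ? b.shape[0]
+      //     : c.shape[0]
+      // A used (non-nullable) carrier short-circuits the chain, since its
+      // shape can always be read safely.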
-      }
-    } else {
-      PrimExpr stride_from_shape = 1;
-      for (int k = static_cast<int>(buffer->strides.size()) - 1; k >= 0; --k) {
-        DataType stride_dtype = buffer->strides[k].dtype();
-        PrimExpr explicit_stride =
-            cast(stride_dtype,
-                 BufferLoad(buf_strides, {IntImm(DataType::Int(32), k)}));
-        PrimExpr shape_stride = cast(
-            stride_dtype, BufferLoad(buf_shape, {IntImm(DataType::Int(32), k)}));
+      // strides field
+      Buffer buf_strides =
+          decl_buffer({IntImm(DataType::Int(32), buffer->strides.size())},
+                      tvm_shape_type, arg_name + ".strides");
+      def_handle_dtype_.Set(buf_strides->data,
+                            tir::TypeAnnotation(tvm_shape_type));
+      init_nest_.emplace_back(
+          LetStmt(buf_strides->data,
+                  tvm::if_then_else(Not(is_null),
+                                    TVMArrayGet(DataType::Handle(), handle,
+                                                builtin::kArrStrides),
+                                    make_zero(DataType::Handle())),
+                  nop));
+      init_nest_.emplace_back(DeclBuffer(buf_strides, nop));
+      PrimExpr v_strides_is_null =
+          Call(DataType::Bool(1), builtin::isnullptr(), {buf_strides->data});
+
+      if (buffer->strides.empty()) {
+        // Assert the buffer is compact
+        DataType stype = buffer->DefaultIndexType();
+        PrimExpr expect_stride = make_const(stype, 1);
+        ffi::Array<PrimExpr> conds;
+        for (size_t i = buffer->shape.size(); i != 0; --i) {
+          size_t k = i - 1;
+          PrimExpr svalue =
+              cast(stype, BufferLoad(buf_strides, {IntImm(DataType::Int(32),
+                                                          static_cast<int>(k))}));
+          conds.push_back(buffer->shape[k] == 1 || expect_stride == svalue);
+          expect_stride = expect_stride * buffer->shape[k];
+        }
+
+        std::ostringstream stride_err_msg;
+        stride_err_msg
+            << stride_handle_name()
+            << ": expected to be compact array, but got non-compact strides";
+        if (!conds.empty()) {
+          PrimExpr all_ok =
+              foldl([](PrimExpr a, PrimExpr b,
+                       Span span) { return logical_and(a, b, span); },
+                    const_true(1), conds);
+          // Packed generic violation for non-compact strides
+          std::string kernel_nm3 = arg_name;
+          std::string buf_nm3 = arg_name;
+          size_t dot_pos3 = arg_name.find('.');
+          if (dot_pos3 != std::string::npos) {
+            kernel_nm3 = arg_name.substr(0, dot_pos3);
+            buf_nm3 = arg_name.substr(dot_pos3 + 1);
+          }
+          ffi::Array<PrimExpr> pargs4;
+          pargs4.push_back(StringImm(tvm_error_constraint_violation));
+          pargs4.push_back(StringImm(kernel_nm3));
+          pargs4.push_back(StringImm(buf_nm3));
+          pargs4.push_back(StringImm("strides"));
+          Stmt call_err4 = Evaluate(
+              Call(DataType::Int(32), builtin::tvm_call_packed(), pargs4));
+          // Only check when strides array is present and condition fails
+          Stmt check =
+              IfThenElse(Not(v_strides_is_null),
+                         IfThenElse(Not(all_ok), call_err4), Evaluate(0));
+          asserts_.emplace_back(SeqStmt({check, Evaluate(0)}));
+        }
+      } else if (buffer->buffer_type == kAutoBroadcast) {
+        PrimExpr stride_from_shape = 1;
+        for (size_t i = buffer->shape.size(); i != 0; --i) {
+          size_t k = i - 1;
+          DataType stride_dtype = buffer->strides[k].dtype();
+          PrimExpr explicit_stride =
+              cast(stride_dtype,
+                   BufferLoad(buf_strides,
+                              {IntImm(DataType::Int(32), static_cast<int>(k))}));
 
-        PrimExpr stride_val = tvm::if_then_else(
-            v_strides_is_null, stride_from_shape, explicit_stride);
+          PrimExpr stride_val = tvm::if_then_else(
+              v_strides_is_null, stride_from_shape, explicit_stride);
 
-        BindNullable(buffer->strides[k], stride_val, stride_element_name(k), true,
-                     is_null);
+          BindNullable(buffer->strides[k], stride_val, stride_element_name(k),
+                       true, is_null);
+        }
+      } else {
+        PrimExpr stride_from_shape = 1;
+
+        for (int k = static_cast<int>(buffer->strides.size()) - 1; k >= 0; --k) {
+          DataType stride_dtype = buffer->strides[k].dtype();
+          PrimExpr explicit_stride =
+              cast(stride_dtype,
+                   BufferLoad(buf_strides, {IntImm(DataType::Int(32), k)}));
+          PrimExpr shape_stride =
+              cast(stride_dtype,
+                   BufferLoad(buf_shape, {IntImm(DataType::Int(32), k)}));
+
+          PrimExpr stride_val = tvm::if_then_else(
+              v_strides_is_null, stride_from_shape, explicit_stride);
+
+          BindNullable(buffer->strides[k], stride_val, stride_element_name(k),
+                       true, is_null);
+        }
       }
-    }
 
-  // Byte_offset field.
-  int data_bytes = GetVectorBytes(buffer->dtype);
+      // Byte_offset field.
+      int data_bytes = GetVectorBytes(buffer->dtype);
+
+      if (const auto *const_offset = buffer->elem_offset.as<IntImmNode>()) {
+        // Constant elem_offset: only need consistency check, no need for
+        // additional Var binding.
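+        // Illustrative: a constant elem_offset of 4 on a float32 buffer
+        // (data_bytes == 4) requires byte_offset == 4 * 4 == 16 in the
+        // DLTensor.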
+        PrimExpr actual_byte_offset = tvm::if_then_else(
+            Not(is_null),
+            TVMArrayGet(DataType::UInt(64), handle, builtin::kArrByteOffset),
+            make_const(DataType::UInt(64), 0));
+        PrimExpr expect_byte_offset =
+            make_const(DataType::UInt(64), const_offset->value * data_bytes);
+        PrimExpr ok = (expect_byte_offset == actual_byte_offset);
+        ffi::Array<PrimExpr> pargs;
+        pargs.push_back(StringImm(tvm_error_byte_offset_mismatch));
+        pargs.push_back(StringImm(kernel_nm));
+        pargs.push_back(StringImm(buf_nm));
+        pargs.push_back(cast(DataType::Int(64), expect_byte_offset));
+        pargs.push_back(cast(DataType::Int(64), actual_byte_offset));
+        Stmt call_err =
+            Evaluate(Call(DataType::Int(32), builtin::tvm_call_packed(), pargs));
+        asserts_.emplace_back(SeqStmt(
+            {IfThenElse(Not(is_null), IfThenElse(Not(ok), call_err), Evaluate(0)),
+             nop}));
+      } else {
+        PrimExpr actual_byte_offset = tvm::if_then_else(
+            Not(is_null),
+            TVMArrayGet(DataType::UInt(64), handle, builtin::kArrByteOffset),
+            make_const(DataType::UInt(64), 0));
+        PrimExpr expect_elem_off = cast(
+            buffer->elem_offset.dtype(),
+            (actual_byte_offset / make_const(DataType::UInt(64), data_bytes)));
+
+        BindNullable(buffer->elem_offset, expect_elem_off,
+                     arg_name + ".elem_offset", true, is_null);
+
+        if (buffer->offset_factor > 1) {
+          PrimExpr offset = buffer->elem_offset;
+          PrimExpr factor = make_const(offset.dtype(), buffer->offset_factor);
+          PrimExpr zero = make_zero(offset.dtype());
+          BindNullable(offset, truncmod(offset, factor),
+                       arg_name + ".elem_offset", true, is_null);
+        }
+      }
 
-  if (const auto *const_offset = buffer->elem_offset.as<IntImmNode>()) {
-    // Constant elem_offset: only need consistency check, no need for additional
-    // Var binding.
-    PrimExpr actual_byte_offset = tvm::if_then_else(
+      // device info.
+      // Define device_id from handle when available (so later passes can use it)
+      PrimExpr actual_dev_type = tvm::if_then_else(
         Not(is_null),
-        TVMArrayGet(DataType::UInt(64), handle, builtin::kArrByteOffset),
-        make_const(DataType::UInt(64), 0));
-    PrimExpr expect_byte_offset =
-        make_const(DataType::UInt(64), const_offset->value * data_bytes);
-    PrimExpr ok = (expect_byte_offset == actual_byte_offset);
-    ffi::Array<PrimExpr> pargs;
-    pargs.push_back(StringImm(tvm_error_byte_offset_mismatch));
-    pargs.push_back(StringImm(kernel_nm));
-    pargs.push_back(StringImm(buf_nm));
-    pargs.push_back(cast(DataType::Int(64), expect_byte_offset));
-    pargs.push_back(cast(DataType::Int(64), actual_byte_offset));
-    Stmt call_err =
-        Evaluate(Call(DataType::Int(32), builtin::tvm_call_packed(), pargs));
-    asserts_.emplace_back(SeqStmt(
-        {IfThenElse(Not(is_null), IfThenElse(Not(ok), call_err), Evaluate(0)),
-         nop}));
-  } else {
-    PrimExpr actual_byte_offset = tvm::if_then_else(
+        TVMArrayGet(DataType::Int(32), handle, builtin::kArrDeviceType),
+        make_zero(DataType::Int(32)));
+      PrimExpr actual_dev_id = tvm::if_then_else(
         Not(is_null),
-        TVMArrayGet(DataType::UInt(64), handle, builtin::kArrByteOffset),
-        make_const(DataType::UInt(64), 0));
-    PrimExpr expect_elem_off =
-        cast(buffer->elem_offset.dtype(),
-             (actual_byte_offset / make_const(DataType::UInt(64), data_bytes)));
-
-    BindNullable(buffer->elem_offset, expect_elem_off,
-                 arg_name + ".elem_offset", true, is_null);
-
-    if (buffer->offset_factor > 1) {
-      PrimExpr offset = buffer->elem_offset;
-      PrimExpr factor = make_const(offset.dtype(), buffer->offset_factor);
-      PrimExpr zero = make_zero(offset.dtype());
-      BindNullable(offset, truncmod(offset, factor), arg_name + ".elem_offset",
-                   true, is_null);
-    }
-  }
-
-  // device info.
-  // Define device_id from handle when available (so later passes can use it)
-  PrimExpr actual_dev_type = tvm::if_then_else(
-      Not(is_null),
-      TVMArrayGet(DataType::Int(32), handle, builtin::kArrDeviceType),
-      make_zero(DataType::Int(32)));
-  PrimExpr actual_dev_id = tvm::if_then_else(
-      Not(is_null),
-      TVMArrayGet(DataType::Int(32), handle, builtin::kArrDeviceId),
-      make_zero(DataType::Int(32)));
-
-  // Bind device_id to a safe expression (0 when NULL handle)
-  BindNullable(device_id, actual_dev_id, arg_name + ".device_id", true,
-               is_null);
-  // Check device_type consistency (device_id equality is implicitly ensured by
-  // binding above)
-  {
-    PrimExpr ok = (device_type == actual_dev_type);
-    ffi::Array<PrimExpr> pargs2;
-    pargs2.push_back(StringImm(tvm_error_device_type_mismatch));
-    pargs2.push_back(StringImm(kernel_nm));
-    pargs2.push_back(StringImm(buf_nm));
-    pargs2.push_back(cast(DataType::Int(64), device_type));
-    pargs2.push_back(cast(DataType::Int(64), actual_dev_type));
-    Stmt call_err2 =
-        Evaluate(Call(DataType::Int(32), builtin::tvm_call_packed(), pargs2));
-    asserts_.emplace_back(SeqStmt(
-        {IfThenElse(Not(is_null), IfThenElse(Not(ok), call_err2), Evaluate(0)),
-         Evaluate(0)}));
-  }
+        TVMArrayGet(DataType::Int(32), handle, builtin::kArrDeviceId),
+        make_zero(DataType::Int(32)));
 
-  // Data field. Because the validation of the data field may depend
-  // on a dynamic size defined by the other DLTensor* parameters, this
-  // field must be generated last.
-  // Bind data pointer using expression-level guard to avoid deref on NULL.
-  {
-    Var vptr(buffer->data);
-    PrimExpr data_ptr = tvm::if_then_else(
-        Not(is_null),
-        TVMArrayGet(DataType::Handle(), handle, builtin::kArrData),
-        make_zero(DataType::Handle()));
-    BindNullable(buffer->data, data_ptr, arg_name + ".data", true, is_null);
-
-    // Check if the data pointer is NULL. This check is skipped for
-    // size-0 arrays and also skipped when handle itself is NULL.
-    auto alloc_size = [&]() -> PrimExpr {
-      PrimExpr product = IntImm(buffer->DefaultIndexType(), 1);
-      for (const auto &dim : buffer->shape)
-        product *= dim;
-      return product;
-    }();
-    // Improve message: kernel/buffer naming for data pointer null check
-    std::string kernel_nm2 = arg_name;
-    std::string buf_nm2 = arg_name;
-    size_t dot_pos2 = arg_name.find('.');
-    if (dot_pos2 != std::string::npos) {
-      kernel_nm2 = arg_name.substr(0, dot_pos2);
-      buf_nm2 = arg_name.substr(dot_pos2 + 1);
+      // Bind device_id to a safe expression (0 when NULL handle)
+      BindNullable(device_id, actual_dev_id, arg_name + ".device_id", true,
+                   is_null);
+      // Check device_type consistency (device_id equality is implicitly ensured
+      // by binding above)
+      {
+        PrimExpr ok = (device_type == actual_dev_type);
+        ffi::Array<PrimExpr> pargs2;
+        pargs2.push_back(StringImm(tvm_error_device_type_mismatch));
+        pargs2.push_back(StringImm(kernel_nm));
+        pargs2.push_back(StringImm(buf_nm));
+        pargs2.push_back(cast(DataType::Int(64), device_type));
+        pargs2.push_back(cast(DataType::Int(64), actual_dev_type));
+        Stmt call_err2 =
+            Evaluate(Call(DataType::Int(32), builtin::tvm_call_packed(), pargs2));
+        asserts_.emplace_back(
+            SeqStmt({IfThenElse(Not(is_null), IfThenElse(Not(ok), call_err2),
+                                Evaluate(0)),
+                     Evaluate(0)}));
     }
-    // expand combined condition via nested IfThenElse for portability
-    ffi::Array<PrimExpr> pargs3;
-    pargs3.push_back(StringImm(tvm_error_null_ptr));
-    pargs3.push_back(StringImm(kernel_nm2));
-    pargs3.push_back(StringImm(buf_nm2));
-    pargs3.push_back(StringImm("data pointer"));
-    Stmt call_err3 =
-        Evaluate(Call(DataType::Int(32), builtin::tvm_call_packed(), pargs3));
-    asserts_.emplace_back(SeqStmt(
-        {IfThenElse(Not(is_null),
-                    IfThenElse(Not(alloc_size == 0),
-                               IfThenElse(Call(DataType::Bool(),
-                                               builtin::isnullptr(), {vptr}),
-                                          call_err3),
-                               Evaluate(0)),
-                    Evaluate(0)),
-         nop}));
-
-    // mark alignment of external bufs
-    init_nest_.emplace_back(
-        AttrStmt(vptr, tir::attr::storage_alignment,
-                 IntImm(DataType::Int(32), buffer->data_alignment), nop));
-    def_handle_dtype_.Set(vptr, tir::TypeAnnotation(buffer->dtype));
+      // Data field. Because the validation of the data field may depend
+      // on a dynamic size defined by the other DLTensor* parameters, this
+      // field must be generated last.
+      // Bind data pointer using expression-level guard to avoid deref on NULL.
+      {
+        Var vptr(buffer->data);
+        PrimExpr data_ptr = tvm::if_then_else(
+            Not(is_null),
+            TVMArrayGet(DataType::Handle(), handle, builtin::kArrData),
+            make_zero(DataType::Handle()));
+        BindNullable(buffer->data, data_ptr, arg_name + ".data", true, is_null);
+
+        // Check if the data pointer is NULL. This check is skipped for
+        // size-0 arrays and also skipped when handle itself is NULL.
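+        // Illustrative: a buffer with shape (0, n) has alloc_size == 0, so a
+        // NULL data pointer is tolerated for it.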
+        PrimExpr alloc_size = IntImm(buffer->DefaultIndexType(), 1);
+        for (const auto &dim : buffer->shape) {
+          alloc_size = alloc_size * dim;
+        }
+        // Improve message: kernel/buffer naming for data pointer null check
+        std::string kernel_nm2 = arg_name;
+        std::string buf_nm2 = arg_name;
+        size_t dot_pos2 = arg_name.find('.');
+        if (dot_pos2 != std::string::npos) {
+          kernel_nm2 = arg_name.substr(0, dot_pos2);
+          buf_nm2 = arg_name.substr(dot_pos2 + 1);
+        }
+        // expand combined condition via nested IfThenElse for portability
+        ffi::Array<PrimExpr> pargs3;
+        pargs3.push_back(StringImm(tvm_error_null_ptr));
+        pargs3.push_back(StringImm(kernel_nm2));
+        pargs3.push_back(StringImm(buf_nm2));
+        pargs3.push_back(StringImm("data pointer"));
+        Stmt call_err3 =
+            Evaluate(Call(DataType::Int(32), builtin::tvm_call_packed(), pargs3));
+        asserts_.emplace_back(SeqStmt(
+            {IfThenElse(Not(is_null),
+                        IfThenElse(Not(alloc_size == 0),
+                                   IfThenElse(Call(DataType::Bool(),
+                                                   builtin::isnullptr(), {vptr}),
+                                              call_err3),
+                                   Evaluate(0)),
+                        Evaluate(0)),
+             nop}));
+
+        // mark alignment of external bufs
+        init_nest_.emplace_back(
+            AttrStmt(vptr, tir::attr::storage_alignment,
+                     IntImm(DataType::Int(32), buffer->data_alignment), nop));
+
+        def_handle_dtype_.Set(vptr, tir::TypeAnnotation(buffer->dtype));
+      }
+    }
 }
 
 } // namespace tl
-} // namespace tvm
+} // namespace tvm
\ No newline at end of file
diff --git a/src/transform/arg_binder.h b/src/transform/arg_binder.h
index 6a580636f..bb7a0f46f 100644
--- a/src/transform/arg_binder.h
+++ b/src/transform/arg_binder.h
@@ -95,17 +95,21 @@ class ArgBinder {
    */
   void BindBuffer(const Buffer &arg, const Buffer &value,
                   const std::string &arg_name, bool fuzzy_match);
+
   /*!
-   * \brief Bind symbolic buffer to a DLTensor handle.
-   * \param buffer The argument buffer to be binded.
-   * \param device_type The device id to be binded.
+   * \brief Bind all symbolic buffers to their DLTensor handles.
+   * \param device_type The device type to be bound.
    * \param device_id The device id to be binded.
-   * \param handle The DLTensor handle.
-   * \param arg_name argument name.
+   * \param buffer_def The (handle, buffer) pairs to bind.
+   * \param func_name The function name, used to prefix per-buffer argument
+   *        names in diagnostics.
+   * \param used_param_buffers Handles of buffers the kernel actually reads;
+   *        these are required to be non-NULL.
    */
-  void BindDLTensor(const Buffer &buffer, const PrimExpr &device_type,
-                    const PrimExpr &device_id, const Var &handle,
-                    const std::string &arg_name, bool is_used);
+  void
+  BindDLTensors(const std::vector<std::pair<Var, Buffer>> &buffer_def,
+                const PrimExpr &device_type, const PrimExpr &device_id,
+                const std::string &func_name,
+                const std::unordered_set<const VarNode *> &used_param_buffers);
 
   /*! \return The defs generated in binding. */
   const std::vector<Var> &defs() const { return defs_; }
@@ -178,4 +182,4 @@ class ArgBinder {
 };
 } // namespace tl
 } // namespace tvm
-#endif // TVM_TL_TRANSFORM_ARG_BINDER_H_
+#endif // TVM_TL_TRANSFORM_ARG_BINDER_H_
\ No newline at end of file
diff --git a/src/transform/make_packed_api.cc b/src/transform/make_packed_api.cc
index 942c652fd..e9e8f76e6 100644
--- a/src/transform/make_packed_api.cc
+++ b/src/transform/make_packed_api.cc
@@ -393,10 +393,15 @@ PrimFunc MakePackedAPI(PrimFunc func) {
         break;
       }
     }
-    if (!has_used_carrier && !carriers.empty()) {
-      // Choose the first carrier to anchor this symbol.
-      used_param_buffers.insert(carriers.front());
-    }
+    // NOTE: With the new nullable shape binding logic in
+    // ArgBinder::BindDLTensors, we no longer need to force one carrier to be
+    // non-NULL. The binder will:
+    //   1. Assert that at least one carrier is non-NULL at runtime
+    //   2. Use cascaded if_then_else to read from the first non-NULL carrier
+    // So we can allow all carriers to be nullable.
+    // if (!has_used_carrier && !carriers.empty()) {
+    //   used_param_buffers.insert(carriers.front());
+    // }
   }
 
   for (int i = 0; i < static_cast<int>(func_ptr->params.size()); ++i) {
@@ -508,14 +513,14 @@ PrimFunc MakePackedAPI(PrimFunc func) {
     binder.Bind(param, expr, name_hint + "." + param->name_hint, true);
   }
 
+  binder.BindDLTensors(buffer_def, device_type, device_id, name_hint,
+                       used_param_buffers);
   for (const auto &[var, buffer] : buffer_def) {
     // Prefer buffer data var name in diagnostics to avoid exposing low-level
     // handle vars
-    std::string display = name_hint + "." + buffer->data->name_hint;
-    binder.BindDLTensor(buffer, device_type, device_id, var, display,
-                        used_param_buffers.count(var.get()));
     arg_buffer_declarations.push_back(DeclBuffer(buffer, nop));
   }
+
   // reset global symbol to attach prefix
   func = WithAttrs(
       std::move(func),
@@ -614,4 +619,4 @@ TVM_FFI_STATIC_INIT_BLOCK() {
 }
 
 } // namespace tl
-} // namespace tvm
+} // namespace tvm
\ No newline at end of file
diff --git a/testing/python/transform/test_nullable_buffer_params.py b/testing/python/transform/test_nullable_buffer_params.py
new file mode 100644
index 000000000..5bbde254b
--- /dev/null
+++ b/testing/python/transform/test_nullable_buffer_params.py
@@ -0,0 +1,73 @@
+import torch
+import tilelang
+import tilelang.testing
+from tilelang import language as T
+
+
+def test_nullable_shared_shape():
+    """Test that buffers sharing a shape variable can be nullable."""
+
+    @tilelang.jit
+    def get_kernel():
+        m = T.dynamic("m")
+
+        @T.prim_func
+        def test_kernel(
+                a: T.Tensor[(m,), T.int32],
+                b: T.Tensor[(m,), T.int32],
+                c: T.Tensor[(m,), T.int32],
+        ):
+            with T.Kernel(1, threads=64):
+                tx = T.get_thread_binding()
+                if tx == 0:
+                    T.print(m)
+
+        return test_kernel
+
+    m = 200
+    kernel = get_kernel()
+
+    # Create test tensors
+    tensor_a = torch.randn((m,), device="cuda", dtype=torch.float32).to(torch.int32)
+    tensor_b = torch.randn((m,), device="cuda", dtype=torch.float32).to(torch.int32)
+    tensor_c = torch.randn((m,), device="cuda", dtype=torch.float32).to(torch.int32)
+
+    print("Test 1: All tensors provided")
+    kernel(tensor_a, tensor_b, tensor_c)
+    print("✓ PASS: All tensors provided")
+
+    print("\nTest 2: Only first tensor provided")
+    kernel(tensor_a, None, None)
+    print("✓ PASS: Only first tensor provided")
+
+    print("\nTest 3: Only middle tensor provided")
+    kernel(None, tensor_b, None)
+    print("✓ PASS: Only middle tensor provided")
+
+    print("\nTest 4: Only last tensor provided")
+    kernel(None, None, tensor_c)
+    print("✓ PASS: Only last tensor provided")
+
+    print("\nTest 5: First and last tensors provided")
+    kernel(tensor_a, None, tensor_c)
+    print("✓ PASS: First and last tensors provided")
+
+    print("\nTest 6: All tensors are None (should fail)")
+    try:
+        kernel(None, None, None)
+        print("✗ FAIL: Should have raised an error")
+        return False
+    except RuntimeError as e:
+        if "at least one non-null buffer" in str(e):
+            print(f"✓ PASS: Correctly rejected with error: {e}")
+        else:
+            print(f"✗ FAIL: Wrong error message: {e}")
+            return False
+
+    print("\n" + "=" * 60)
+    print("All tests passed!")
+    return True
+
+
+if __name__ == "__main__":
+    tilelang.testing.main()
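+
+
+def _first_non_null(*candidates):
+    """Pure-Python sketch (illustrative only, not exercised by the kernel
+    above) of the cascaded if_then_else that ArgBinder::BindDLTensors emits
+    for a shared symbolic dimension: read it from the first non-null
+    carrier, or fail if every carrier is null."""
+    for value in candidates:
+        if value is not None:
+            return value
+    raise RuntimeError("requires at least one non-null buffer")
+
+
+def test_first_non_null_reference():
+    assert _first_non_null(None, 200, None) == 200
+    assert _first_non_null(None, None, 7) == 7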