diff --git a/.clang-tidy b/.clang-tidy index 2ddbefbf9..5c2a7aa65 100644 --- a/.clang-tidy +++ b/.clang-tidy @@ -4,7 +4,7 @@ ExtraArgs: ['-v'] FormatStyle: file UseColor: true WarningsAsErrors: '*' -ExcludeHeaderFilterRegex: '^(3rdparty|tvm)/.*$' +HeaderFilterRegex: '^(?!.*(?:/|^)(3rdparty|tvm)/).*' # NOTE: there must be no spaces before the '-', so put the comma last. Checks: >- diff --git a/src/transform/arg_binder.cc b/src/transform/arg_binder.cc new file mode 100644 index 000000000..2caef2239 --- /dev/null +++ b/src/transform/arg_binder.cc @@ -0,0 +1,376 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file arg_binder.cc + * \brief Helper utility to match and bind arguments. + */ +#include "arg_binder.h" + +#include +#include +#include +#include + +#include + +#include "tir/transforms/ir_utils.h" + +namespace tvm { +namespace tl { + +using namespace tir; + +void BinderAddAssert(arith::Analyzer *ana, PrimExpr cond, + const std::string &arg_name, std::vector *asserts) { + PrimExpr scond = ana->Simplify(cond); + if (is_zero(scond)) { + LOG(FATAL) << "Bind have an unmet assertion: " << cond << ", " + << " on argument " << arg_name; + } + if (!is_one(scond)) { + std::ostringstream os; + os << "Argument " << arg_name << " has an unsatisfied constraint: " << cond; + asserts->emplace_back(AssertStmt(scond, StringImm(os.str()), Evaluate(0))); + } +} + +bool ArgBinder::Bind_(const PrimExpr &arg, const PrimExpr &value, + const std::string &arg_name, bool with_lets) { + ICHECK_EQ(arg.dtype(), value.dtype()) << "arg " << arg << " value " << value; + if (const VarNode *v = arg.as()) { + auto it = def_map_->find(v); + if (it == def_map_->end()) { + Var v_arg = Downcast(arg); + defs_.emplace_back(v_arg); + if (with_lets) { + (*def_map_)[v] = arg; + init_nest_.emplace_back(LetStmt(v_arg, value, Evaluate(0))); + } else { + (*def_map_)[v] = value; + } + return true; + } else { + BinderAddAssert(&analyzer_, it->second == value, arg_name, &asserts_); + } + } else { + BinderAddAssert(&analyzer_, arg == value, arg_name, &asserts_); + } + return false; +} + +void ArgBinder::Bind(const PrimExpr &arg, const PrimExpr &value, + const std::string &arg_name, bool with_let) { + Bind_(arg, value, arg_name, with_let); +} + +void ArgBinder::BindArray(const Array &arg, + const Array &value, + const std::string &arg_name) { + ICHECK_EQ(arg.size(), value.size()) + << "Argument " << arg_name << " array size mismatch"; + for (size_t i = 0; i < arg.size(); ++i) { + std::ostringstream os; + os << arg_name << "[" << i << "]"; + this->Bind(arg[i], value[i], os.str()); + } +} + +void ArgBinder::BindBuffer(const Buffer &arg, const Buffer &value, + const std::string &arg_name, bool fuzzy_match) { + ICHECK_EQ(arg.scope(), value.scope()) + << "Argument " << arg_name << " Buffer bind scope mismatch"; + ICHECK_EQ(arg->dtype, value->dtype) + << "Argument " << arg_name << " Buffer bind data type mismatch"; + if (value->data_alignment % arg->data_alignment != 0) { + LOG(WARNING) << "Trying to bind buffer to another one with lower alignment " + "requirement " + << " required_alignment=" << arg->data_alignment + << ", provided_alignment=" << value->data_alignment; + } + + if (value->elem_offset.defined()) { + // bind pointer and offset. + if (is_zero(arg->elem_offset)) { + ICHECK(is_zero(value->elem_offset)) + << "Trying to bind a Buffer with offset into one without offset " + << " required elem_offset=" << arg->elem_offset + << ", provided elem_offset=" << value->elem_offset; + } + + this->Bind(arg->data, value->data, arg_name + ".data"); + if (Bind_(arg->elem_offset, value->elem_offset, arg_name + ".elem_offset", + false)) { + if (arg->offset_factor > 1) { + PrimExpr offset = value->elem_offset; + PrimExpr factor = make_const(offset.dtype(), arg->offset_factor); + PrimExpr zero = make_zero(offset.dtype()); + BinderAddAssert(&analyzer_, truncmod(offset, factor) == zero, + arg_name + ".elem_offset", &asserts_); + } + } + } + + if (arg->shape.size() < value->shape.size()) { + ICHECK(fuzzy_match) << "Argument " << arg_name << " size mismatch"; + size_t diff = value->shape.size() - arg->shape.size(); + for (size_t i = 0; i < diff; ++i) { + ICHECK(is_one(analyzer_.Simplify(value->shape[i]))) + << "Argument " << arg_name << " shape mismatch" << arg->shape + << " vs " << value->shape; + } + for (size_t i = 0; i < arg->shape.size(); ++i) { + std::ostringstream os; + os << arg_name << ".shape[" << i << "]"; + this->Bind(arg->shape[i], value->shape[i + diff], os.str()); + } + if (!value->strides.empty()) { + ICHECK_EQ(arg->strides.size(), arg->shape.size()); + ICHECK_EQ(value->strides.size(), value->shape.size()); + for (size_t i = 0; i < arg->strides.size(); ++i) { + std::ostringstream os; + os << arg_name << ".strides[" << i << "]"; + this->Bind(arg->strides[i], value->strides[i + diff], os.str()); + } + } + } else { + this->BindArray(arg->shape, value->shape, arg_name + ".shape"); + this->BindArray(arg->strides, value->strides, arg_name + ".strides"); + } +} + +inline PrimExpr TVMArrayGet(DataType t, Var arr, + builtin::TVMStructFieldKind kind) { + return TVMStructGet(t, arr, 0, kind); +} + +void ArgBinder::BindDLTensor(const Buffer &buffer, const PrimExpr &device_type, + const PrimExpr &device_id, const Var &handle, + const std::string &arg_name) { + const DataType tvm_shape_type = DataType::ShapeIndex(); + const DataType tvm_ndim_type = DataType::Int(32); + const Stmt nop = Evaluate(0); + + init_nest_.emplace_back(AssertStmt( + !Call(DataType::Bool(), builtin::isnullptr(), {handle}), + StringImm(arg_name + " is expected to have non-NULL DLTensor* pointer"), + nop)); + + // dimension checks + PrimExpr v_ndim = TVMArrayGet(tvm_ndim_type, handle, builtin::kArrNDim); + + // Helper functions for shape/stride name formatting + auto shape_handle_name = [&]() { return arg_name + ".shape"; }; + auto stride_handle_name = [&]() { return arg_name + ".strides"; }; + auto array_element_name = [&](const std::string &arr_name, size_t k) { + std::stringstream ss; + ss << arr_name << '[' << k << ']'; + return ss.str(); + }; + auto shape_element_name = [&](size_t k) { + return array_element_name(shape_handle_name(), k); + }; + auto stride_element_name = [&](size_t k) { + return array_element_name(stride_handle_name(), k); + }; + + PrimExpr a_ndim = + make_const(tvm_ndim_type, static_cast(buffer->shape.size())); + std::ostringstream ndim_err_msg; + ndim_err_msg << arg_name << ".ndim is expected to equal " + << buffer->shape.size(); + auto msg = StringImm(ndim_err_msg.str()); + init_nest_.emplace_back(AssertStmt(a_ndim == v_ndim, msg, nop)); + // type checks + std::ostringstream type_err_msg; + type_err_msg << arg_name << ".dtype is expected to be " << buffer->dtype; + PrimExpr cond = + (TVMArrayGet(DataType::UInt(8), handle, builtin::kArrTypeCode) == + IntImm(DataType::UInt(8), buffer->dtype.code()) && + TVMArrayGet(DataType::UInt(8), handle, builtin::kArrTypeBits) == + IntImm(DataType::UInt(8), buffer->dtype.bits()) && + TVMArrayGet(DataType::UInt(16), handle, builtin::kArrTypeLanes) == + IntImm(DataType::UInt(16), buffer->dtype.lanes())); + if (!(buffer->dtype == DataType::Int(1) || + buffer->dtype == DataType::Int(4) || + buffer->dtype == DataType::UInt(4))) { + auto type_msg = StringImm(type_err_msg.str()); + asserts_.emplace_back(AssertStmt(cond, type_msg, nop)); + } + + // shape field + Buffer buf_shape = + decl_buffer({IntImm(DataType::Int(32), buffer->shape.size())}, + tvm_shape_type, shape_handle_name()); + Var v_shape(shape_handle_name(), DataType::Handle()); + def_handle_dtype_.Set(v_shape, make_const(tvm_shape_type, 0)); + init_nest_.emplace_back(LetStmt( + buf_shape->data, + TVMArrayGet(DataType::Handle(), handle, builtin::kArrShape), nop)); + init_nest_.emplace_back(DeclBuffer(buf_shape, nop)); + for (size_t k = 0; k < buffer->shape.size(); ++k) { + if (buffer->dtype == DataType::Int(4) || + buffer->dtype == DataType::UInt(4) || + buffer->dtype == DataType::Int(1)) { + break; + } + Bind_(buffer->shape[k], + cast(buffer->shape[k].dtype(), + BufferLoad(buf_shape, {IntImm(DataType::Int(32), k)})), + shape_element_name(k), true); + } + // strides field + Buffer buf_strides = + decl_buffer({IntImm(DataType::Int(32), buffer->strides.size())}, + tvm_shape_type, arg_name + ".strides"); + def_handle_dtype_.Set(buf_strides->data, tir::TypeAnnotation(tvm_shape_type)); + init_nest_.emplace_back(LetStmt( + buf_strides->data, + TVMArrayGet(DataType::Handle(), handle, builtin::kArrStrides), nop)); + init_nest_.emplace_back(DeclBuffer(buf_strides, nop)); + PrimExpr v_strides_is_null = + Call(DataType::Bool(1), builtin::isnullptr(), {buf_strides->data}); + if (buffer->strides.empty()) { + // Assert the buffer is compact + DataType stype = buffer->DefaultIndexType(); + PrimExpr expect_stride = make_const(stype, 1); + Array conds; + for (size_t i = buffer->shape.size(); i != 0; --i) { + size_t k = i - 1; + PrimExpr svalue = + cast(stype, BufferLoad(buf_strides, {IntImm(DataType::Int(32), k)})); + conds.push_back(buffer->shape[k] == 1 || expect_stride == svalue); + expect_stride = expect_stride * buffer->shape[k]; + } + std::ostringstream stride_err_msg; + stride_err_msg << stride_handle_name() << ": expected to be compact array"; + if (!conds.empty()) { + auto stride_msg = StringImm(stride_err_msg.str()); + Stmt check = + AssertStmt(foldl([](PrimExpr a, PrimExpr b, + Span span) { return logical_and(a, b, span); }, + const_true(1), conds), + stride_msg, Evaluate(0)); + check = IfThenElse(Not(v_strides_is_null), check); + asserts_.emplace_back(SeqStmt({check, Evaluate(0)})); + } + } else if (buffer->buffer_type == kAutoBroadcast) { + PrimExpr stride_from_shape = make_const(buffer->DefaultIndexType(), 1); + for (size_t i = buffer->shape.size(); i != 0; --i) { + size_t k = i - 1; + DataType stride_dtype = buffer->strides[k].dtype(); + PrimExpr explicit_stride = + cast(stride_dtype, + BufferLoad(buf_strides, {IntImm(DataType::Int(32), k)})); + PrimExpr stride_from_shape_cast = cast(stride_dtype, stride_from_shape); + PrimExpr value = tvm::if_then_else( + v_strides_is_null, stride_from_shape_cast, explicit_stride); + value = tvm::if_then_else(buffer->shape[k] == 1, make_zero(stride_dtype), + value); + Bind_(buffer->strides[k], value, stride_element_name(k), true); + PrimExpr shape_extent = cast(stride_dtype, buffer->shape[k]); + stride_from_shape = + analyzer_.Simplify(stride_from_shape_cast * shape_extent); + } + } else { + PrimExpr stride_from_shape = make_const(buffer->DefaultIndexType(), 1); + + for (int k = buffer->strides.size() - 1; k >= 0; k--) { + DataType stride_dtype = buffer->strides[k].dtype(); + PrimExpr explicit_stride = + cast(stride_dtype, + BufferLoad(buf_strides, {IntImm(DataType::Int(32), k)})); + PrimExpr shape_stride = cast( + stride_dtype, BufferLoad(buf_shape, {IntImm(DataType::Int(32), k)})); + PrimExpr stride_from_shape_cast = cast(stride_dtype, stride_from_shape); + + Bind_(buffer->strides[k], + tvm::if_then_else(v_strides_is_null, stride_from_shape_cast, + explicit_stride), + stride_element_name(k), true); + + stride_from_shape = + analyzer_.Simplify(stride_from_shape_cast * shape_stride); + } + } + // Byte_offset field. + int data_bytes = GetVectorBytes(buffer->dtype); + + if (const auto *const_offset = buffer->elem_offset.as()) { + Bind_(make_const(DataType::UInt(64), const_offset->value * data_bytes), + TVMArrayGet(DataType::UInt(64), handle, builtin::kArrByteOffset), + arg_name + ".byte_offset", true); + } else { + if (Bind_(buffer->elem_offset, + cast(buffer->elem_offset.dtype(), + (TVMArrayGet(DataType::UInt(64), handle, + builtin::kArrByteOffset) / + make_const(DataType::UInt(64), data_bytes))), + arg_name + ".elem_offset", true)) { + if (buffer->offset_factor > 1) { + PrimExpr offset = buffer->elem_offset; + PrimExpr factor = make_const(offset.dtype(), buffer->offset_factor); + PrimExpr zero = make_zero(offset.dtype()); + BinderAddAssert(&analyzer_, truncmod(offset, factor) == zero, + arg_name + ".elem_offset", &asserts_); + } + } + } + // device info. + Bind_(device_type, + TVMArrayGet(DataType::Int(32), handle, builtin::kArrDeviceType), + arg_name + ".device_type", true); + Bind_(device_id, + TVMArrayGet(DataType::Int(32), handle, builtin::kArrDeviceId), + arg_name + ".device_id", true); + + // Data field. Because the validation of the data field may depend + // on a dynamic size defined by the other DLTensor* parameters, this + // field must be generated last. + if (Bind_(buffer->data, + TVMArrayGet(DataType::Handle(), handle, builtin::kArrData), + arg_name + ".data", true)) { + Var vptr(buffer->data); + + // Check if the data pointer is NULL. This check is skipped for + // size-0 arrays, since CUDA provides a NULL pointer for size-zero + // allocations. + auto alloc_size = [&]() -> PrimExpr { + PrimExpr product = IntImm(buffer->DefaultIndexType(), 1); + for (const auto &dim : buffer->shape) { + product *= dim; + } + return product; + }(); + asserts_.emplace_back(AssertStmt( + alloc_size == 0 || + !Call(DataType::Bool(), builtin::isnullptr(), {vptr}), + StringImm(arg_name + " is expected to have non-NULL data pointer"), + nop)); + + def_handle_dtype_.Set(vptr, tir::TypeAnnotation(buffer->dtype)); + // mark alignment of external bufs + init_nest_.emplace_back( + AttrStmt(vptr, tir::attr::storage_alignment, + IntImm(DataType::Int(32), buffer->data_alignment), nop)); + } +} + +} // namespace tl +} // namespace tvm diff --git a/src/transform/arg_binder.h b/src/transform/arg_binder.h new file mode 100644 index 000000000..d2dcc06aa --- /dev/null +++ b/src/transform/arg_binder.h @@ -0,0 +1,175 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file arg_binder.h + * \brief Helper utility to match and bind arguments. + */ +#ifndef TVM_TL_TRANSFORM_ARG_BINDER_H_ +#define TVM_TL_TRANSFORM_ARG_BINDER_H_ + +#include +#include +#include + +#include +#include +#include + +namespace tvm { +namespace tl { + +using namespace tir; + +/*! + * \brief Helper utility to generate match and bind of arguments. + * + * \note There is many places in TVM IR where we need argument bindings. + * + * Consider a function f(tA(shape=var(n)), tB(shape=3), tC(shape=(n+2)). + * Here n is a undefined variable that is decided by the outside, tB imposes + * a constraint such that it can only take tensor with shape 3, tC imposes + * another constraint that it's shape must equals n + 2. + * So if we call it with f(bufferA, bufferB, bufferC), we need to generate + * the following binding sequence: + * - define n = bufferA.shape[0] + * - assert bufferB.shape[0] == 3 + * - assert bufferB.shape[1] == n + 3 + * + * In general, this is a constraint solving problem. We have simplified + * assumption over the binding declaration, such that we require the variable + * occurred in constraint must be declared in argument list. So it is illegal to + * have signature f(tA(shape=(n+3))) without any argument variable corresponds + * to n, even though it is already enough to derive n from the input argument. + */ +class ArgBinder { +public: + /*! + * \brief Constructor + * \param def_map A definition map that contains definition of known + * variables. ArgBinder will update this def_map when adding new definitions. + */ + explicit ArgBinder(std::unordered_map *def_map) + : def_map_(def_map) {} + /*! + * \brief Try to bind arg to value, generate constraint if necessary. + * \param arg The argument to be binded. + * \param value The target expression value + * \param arg_name argument name. + * \param with_let Whether add lets during bind + */ + void Bind(const PrimExpr &arg, const PrimExpr &value, + const std::string &arg_name, bool with_let = false); + /*! + * \brief Bind array to array + * \param arg The argument to be binded. + * \param value The target expression value + * \param arg_name argument name. + */ + void BindArray(const Array &arg, const Array &value, + const std::string &arg_name); + /*! + * \brief Bind symbolic buffer to another symbolic buffer + * \param arg The argument to be binded. + * \param value The target expression value + * \param arg_name argument name. + * \param fuzzy_match If enabled, we allow value's dimension to be smaller + * than arg, as long as arg's higher dimensions are of 1. + */ + void BindBuffer(const Buffer &arg, const Buffer &value, + const std::string &arg_name, bool fuzzy_match); + /*! + * \brief Bind symbolic buffer to a DLTensor handle. + * \param buffer The argument buffer to be binded. + * \param device_type The device id to be binded. + * \param device_id The device id to be binded. + * \param handle The DLTensor handle. + * \param arg_name argument name. + */ + void BindDLTensor(const Buffer &buffer, const PrimExpr &device_type, + const PrimExpr &device_id, const Var &handle, + const std::string &arg_name); + + /*! \return The defs generated in binding. */ + const std::vector &defs() const { return defs_; } + + /*! \return The asserts generated in binding + * + * This contains statements that assert the correct value has been + * bound. For example, `binder.Bind(var, expr_1)` will produce an + * entry mapping `var` to `expr_1` in the `binder.defs()`. If + * `binder.Bind(var, expr_2)` is called later, then this will + * produce an assert statemtn that `expr_1 == expr_2`. + * + * Note: Some assert statements produced by BindDLTensor are located + * in `binder.init_nest()`, not within `binder.asserts()`. This is + * deliberate, as some values may require checks prior to + * initialization. (e.g. Intializing `m = dl_tensor->shape[3]` + * requires first asserting that `3 < dl_tensor->ndim`.) + */ + const std::vector &asserts() const { return asserts_; } + + /*! + * \brief Initialization nest generated + * + * This contains both variable bindings and any assert statements + * that are required in order to safely produce those variable + * bindings. + * + * \note Variable bindings may be implemented either as a `LetStmt` + * that defines the variable, or as a variable replacement. Any + * bindings implemented as a `LetStmt` will be in the + * initialization list. Any bindings implemented as a variable + * replacement will be stored in the `var_def` map. + * + * A `tir::LetStmt` is usually generated when binding to a + * `DLTensor`. This requires loading values from memory, which + * should only be performed once. If the binding to a + * `DLTensor` were implemented as a variable replacement, it + * would load values from memory once for each usage of the + * variable. + * + * \return The initialization nest generated during binding. + */ + const std::vector &init_nest() const { return init_nest_; } + /*! \return Handle data type of the data */ + const Map &def_handle_dtype() const { + return def_handle_dtype_; + } + +private: + // Internal bind function + bool Bind_(const PrimExpr &arg, const PrimExpr &value, + const std::string &arg_name, bool with_lets); + /*! \brief The definition map, can be uses to substitute */ + std::unordered_map *def_map_; + /*! \brief defs generated in the current binder */ + std::vector defs_; + /*! \brief Initialize nest */ + std::vector init_nest_; + /*! \brief handle data type in the defintiions */ + Map def_handle_dtype_; + /*! \brief asserts generated */ + std::vector asserts_; + /*! \brief internal analyzer. */ + arith::Analyzer analyzer_; +}; +} // namespace tl +} // namespace tvm +#endif // TVM_TL_TRANSFORM_ARG_BINDER_H_ diff --git a/src/transform/loop_vectorize.cc b/src/transform/loop_vectorize.cc index cda4ad2e1..4550af8e4 100644 --- a/src/transform/loop_vectorize.cc +++ b/src/transform/loop_vectorize.cc @@ -262,24 +262,32 @@ bool IndiceCanVectorize(const PrimExpr &expr, Var var, return true; // Extent must be divisible - if (!analyzer->CanProveEqual(FloorMod(iter_var_size, target_vectorized_size), + PrimExpr target_size_for_iter = + make_const(iter_var_size.dtype(), target_vectorized_size); + PrimExpr target_size_for_expr = + make_const(expr.dtype(), target_vectorized_size); + PrimExpr target_size_for_var = + make_const(var.dtype(), target_vectorized_size); + PrimExpr zero = make_const(var.dtype(), 0); + + if (!analyzer->CanProveEqual(FloorMod(iter_var_size, target_size_for_iter), 0)) return false; // The base offset must be divisible if (!analyzer->CanProveEqual( - FloorMod(Substitute(expr, {{var, 0}}), target_vectorized_size), 0)) { + FloorMod(Substitute(expr, {{var, zero}}), target_size_for_expr), 0)) { return false; } // Bind thread range - Var v0("v0"), v1("v1"); - analyzer->Bind(v0, Range(0, target_vectorized_size)); - analyzer->Bind(v1, Range(0, analyzer->Simplify(FloorDiv( - iter_var_size, target_vectorized_size)))); + Var v0("v0", var.dtype()), v1("v1", var.dtype()); + analyzer->Bind(v0, Range(zero, target_size_for_var)); + analyzer->Bind(v1, Range(zero, analyzer->Simplify(FloorDiv( + iter_var_size, target_size_for_iter)))); PrimExpr expr_transformed = analyzer->Simplify( - Substitute(expr, {{var, v0 + v1 * target_vectorized_size}})); - Vectorizer vectorizer(v0, IntImm(v0->dtype, target_vectorized_size)); + Substitute(expr, {{var, v0 + v1 * target_size_for_var}})); + Vectorizer vectorizer(v0, target_size_for_var); PrimExpr expr_vectorized = vectorizer.VisitExpr(expr_transformed); // This simplify is necessary for thread region specified diff --git a/src/transform/make_packed_api.cc b/src/transform/make_packed_api.cc index a124027ce..b03193c8c 100644 --- a/src/transform/make_packed_api.cc +++ b/src/transform/make_packed_api.cc @@ -36,7 +36,7 @@ #include #include "../op/builtin.h" -#include "tir/transforms/arg_binder.h" +#include "arg_binder.h" #include "tir/transforms/ir_utils.h" namespace tvm { @@ -496,7 +496,6 @@ tvm::transform::Pass MakePackedAPI() { func->body)) { func.CopyOnWrite()->body = body.value(); } - func = MakePackedAPI(std::move(func)); if (!func.same_as(orig_func)) {