diff --git a/include/tvm/node/structural_equal.h b/include/tvm/node/structural_equal.h index acc362758a7c..f5439bbb290c 100644 --- a/include/tvm/node/structural_equal.h +++ b/include/tvm/node/structural_equal.h @@ -108,9 +108,11 @@ class StructuralEqual : public BaseValueEqual { * \brief Compare objects via strutural equal. * \param lhs The left operand. * \param rhs The right operand. + * \param map_free_params Whether or not to map free variables. * \return The comparison result. */ - TVM_DLL bool operator()(const ObjectRef& lhs, const ObjectRef& rhs) const; + TVM_DLL bool operator()(const ObjectRef& lhs, const ObjectRef& rhs, + const bool map_free_params = false) const; }; /*! diff --git a/src/node/structural_equal.cc b/src/node/structural_equal.cc index e0de514122b8..379a75f6109b 100644 --- a/src/node/structural_equal.cc +++ b/src/node/structural_equal.cc @@ -563,8 +563,9 @@ TVM_REGISTER_GLOBAL("node.GetFirstStructuralMismatch") return first_mismatch; }); -bool StructuralEqual::operator()(const ObjectRef& lhs, const ObjectRef& rhs) const { - return SEqualHandlerDefault(false, nullptr, false).Equal(lhs, rhs, false); +bool StructuralEqual::operator()(const ObjectRef& lhs, const ObjectRef& rhs, + bool map_free_params) const { + return SEqualHandlerDefault(false, nullptr, false).Equal(lhs, rhs, map_free_params); } bool NDArrayEqual(const runtime::NDArray::Container* lhs, const runtime::NDArray::Container* rhs,