diff --git a/include/tvm/node/object_path.h b/include/tvm/node/object_path.h new file mode 100644 index 000000000000..986ee53a1258 --- /dev/null +++ b/include/tvm/node/object_path.h @@ -0,0 +1,271 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/node/object_path.h + * ObjectPath class that represents a path from a root object to one of its descendants + * via attribute access, array indexing etc. + */ + +#ifndef TVM_NODE_OBJECT_PATH_H_ +#define TVM_NODE_OBJECT_PATH_H_ + +#include +#include + +#include + +namespace tvm { + +using runtime::Object; +using runtime::ObjectPtr; +using runtime::ObjectRef; + +class ObjectPath; + +/*! + * \brief Path to an object from some root object. + * + * Motivation: + * + * Same IR node object can be referenced in several different contexts inside a larger IR object. + * For example, a variable could be referenced in several statements within a block. + * + * This makes it impossible to use an object pointer to uniquely identify a "location" within + * the larger IR object for error reporting purposes. The ObjectPath class addresses this problem + * by serving as a unique "location" identifier. + */ +class ObjectPathNode : public Object { + public: + /*! \brief Get the parent path */ + ObjectPath GetParent() const; + + /*! \brief Extend this path with access to an object attribute. */ + ObjectPath Attr(const char* attr_key); + + /*! \brief Extend this path with access to an object attribute. */ + ObjectPath Attr(String attr_key); + + /*! \brief Extend this path with access to an array element. */ + ObjectPath ArrayIndex(size_t index); + + /*! \brief Extend this path with access to a missing array element. */ + ObjectPath MissingArrayElement(size_t index); + + /*! \brief Extend this path with access to a map value. */ + ObjectPath MapValue(ObjectRef key); + + /*! \brief Extend this path with access to a missing map entry. */ + ObjectPath MissingMapEntry(); + + static constexpr const char* _type_key = "ObjectPath"; + TVM_DECLARE_BASE_OBJECT_INFO(ObjectPathNode, Object); + + protected: + explicit ObjectPathNode(ObjectPathNode* parent); + + friend class ObjectPath; + friend std::string GetObjectPathRepr(const ObjectPathNode* node); + + const ObjectPathNode* ParentNode() const; + + /*! Compares just the last node of the path, without comparing the whole path. */ + virtual bool LastNodeEqual(const ObjectPathNode* other) const = 0; + + virtual std::string LastNodeString() const = 0; + + private: + ObjectRef parent_; + size_t length_; +}; + +class ObjectPath : public ObjectRef { + public: + size_t Length() const; + + ObjectPath GetPrefix(size_t length) const; + + bool IsPrefixOf(const ObjectPath& other) const; + + bool PathsEqual(const ObjectPath& other) const; + + /*! \brief Create a path that represents the root object itself. */ + static ObjectPath Root(); + + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(ObjectPath, ObjectRef, ObjectPathNode); +}; + +struct ObjectPathPair { + ObjectPath lhs_path; + ObjectPath rhs_path; +}; + +//------------------------------------------------------------------------- +//----- Concrete object path nodes ------------------------------------ +//------------------------------------------------------------------------- + +// ----- Root ----- + +class RootPathNode final : public ObjectPathNode { + public: + explicit RootPathNode(); + + static constexpr const char* _type_key = "RootPath"; + TVM_DECLARE_FINAL_OBJECT_INFO(RootPathNode, ObjectPathNode); + + protected: + bool LastNodeEqual(const ObjectPathNode* other) const final; + std::string LastNodeString() const final; +}; + +class RootPath : public ObjectPath { + public: + TVM_DEFINE_OBJECT_REF_METHODS(RootPath, ObjectPath, RootPathNode); +}; + +// ----- Attribute access ----- + +class AttributeAccessPathNode final : public ObjectPathNode { + public: + /*! \brief Name of the attribute being accessed. Must be a static string. */ + String attr_key; + + explicit AttributeAccessPathNode(ObjectPathNode* parent, String attr_key); + + static constexpr const char* _type_key = "AttributeAccessPath"; + TVM_DECLARE_FINAL_OBJECT_INFO(AttributeAccessPathNode, ObjectPathNode); + + protected: + bool LastNodeEqual(const ObjectPathNode* other) const final; + std::string LastNodeString() const final; +}; + +class AttributeAccessPath : public ObjectPath { + public: + TVM_DEFINE_OBJECT_REF_METHODS(AttributeAccessPath, ObjectPath, AttributeAccessPathNode); +}; + +// ----- Unknown attribute access ----- + +class UnknownAttributeAccessPathNode final : public ObjectPathNode { + public: + explicit UnknownAttributeAccessPathNode(ObjectPathNode* parent); + + static constexpr const char* _type_key = "UnknownAttributeAccessPath"; + TVM_DECLARE_FINAL_OBJECT_INFO(UnknownAttributeAccessPathNode, ObjectPathNode); + + protected: + bool LastNodeEqual(const ObjectPathNode* other) const final; + std::string LastNodeString() const final; +}; + +class UnknownAttributeAccessPath : public ObjectPath { + public: + TVM_DEFINE_OBJECT_REF_METHODS(UnknownAttributeAccessPath, ObjectPath, + UnknownAttributeAccessPathNode); +}; + +// ----- Array element access by index ----- + +class ArrayIndexPathNode : public ObjectPathNode { + public: + /*! \brief Index of the array element that is being accessed. */ + size_t index; + + explicit ArrayIndexPathNode(ObjectPathNode* parent, size_t index); + + static constexpr const char* _type_key = "ArrayIndexPath"; + TVM_DECLARE_FINAL_OBJECT_INFO(ArrayIndexPathNode, ObjectPathNode); + + protected: + bool LastNodeEqual(const ObjectPathNode* other) const final; + std::string LastNodeString() const final; +}; + +class ArrayIndexPath : public ObjectPath { + public: + TVM_DEFINE_OBJECT_REF_METHODS(ArrayIndexPath, ObjectPath, ArrayIndexPathNode); +}; + +// ----- Missing array element ----- + +class MissingArrayElementPathNode : public ObjectPathNode { + public: + /*! \brief Index of the array element that is missing. */ + size_t index; + + explicit MissingArrayElementPathNode(ObjectPathNode* parent, size_t index); + + static constexpr const char* _type_key = "MissingArrayElementPath"; + TVM_DECLARE_FINAL_OBJECT_INFO(MissingArrayElementPathNode, ObjectPathNode); + + protected: + bool LastNodeEqual(const ObjectPathNode* other) const final; + std::string LastNodeString() const final; +}; + +class MissingArrayElementPath : public ObjectPath { + public: + TVM_DEFINE_OBJECT_REF_METHODS(MissingArrayElementPath, ObjectPath, MissingArrayElementPathNode); +}; + +// ----- Map value ----- + +class MapValuePathNode : public ObjectPathNode { + public: + /*! \brief Key of the map entry that is being accessed */ + ObjectRef key; + + explicit MapValuePathNode(ObjectPathNode* parent, ObjectRef key); + + static constexpr const char* _type_key = "MapValuePath"; + TVM_DECLARE_FINAL_OBJECT_INFO(MapValuePathNode, ObjectPathNode); + + protected: + bool LastNodeEqual(const ObjectPathNode* other) const final; + std::string LastNodeString() const final; +}; + +class MapValuePath : public ObjectPath { + public: + TVM_DEFINE_OBJECT_REF_METHODS(MapValuePath, ObjectPath, MapValuePathNode); +}; + +// ----- Missing map entry ----- + +class MissingMapEntryPathNode : public ObjectPathNode { + public: + explicit MissingMapEntryPathNode(ObjectPathNode* parent); + + static constexpr const char* _type_key = "MissingMapEntryPath"; + TVM_DECLARE_FINAL_OBJECT_INFO(MissingMapEntryPathNode, ObjectPathNode); + + protected: + bool LastNodeEqual(const ObjectPathNode* other) const final; + std::string LastNodeString() const final; +}; + +class MissingMapEntryPath : public ObjectPath { + public: + TVM_DEFINE_OBJECT_REF_METHODS(MissingMapEntryPath, ObjectPath, MissingMapEntryPathNode); +}; + +} // namespace tvm + +#endif // TVM_NODE_OBJECT_PATH_H_ diff --git a/include/tvm/node/reflection.h b/include/tvm/node/reflection.h index 5d514e24d8d9..7f279a010ac1 100644 --- a/include/tvm/node/reflection.h +++ b/include/tvm/node/reflection.h @@ -404,5 +404,11 @@ inline bool ReflectionVTable::GetReprBytes(const Object* self, std::string* repr } } +/*! + * \brief Given an object and an address of its attribute, return the key of the attribute. + * \return nullptr if no attribute with the given address exists. + */ +const char* GetAttrKeyByAddress(const Object* object, const void* attr_address); + } // namespace tvm #endif // TVM_NODE_REFLECTION_H_ diff --git a/include/tvm/node/structural_equal.h b/include/tvm/node/structural_equal.h index 6c25c3d2d21d..8d5fd4d2a034 100644 --- a/include/tvm/node/structural_equal.h +++ b/include/tvm/node/structural_equal.h @@ -56,6 +56,8 @@ class BaseValueEqual { } }; +struct ObjectPathPair; + /*! * \brief Content-aware structural equality comparator for objects. * @@ -99,7 +101,10 @@ class StructuralEqual : public BaseValueEqual { * equality checking. Instead, it can store the necessary equality conditions * and check later via an internally managed stack. */ -class SEqualReducer : public BaseValueEqual { +class SEqualReducer { + private: + struct PathTracingData; + public: /*! \brief Internal handler that defines custom behaviors.. */ class Handler { @@ -110,12 +115,24 @@ class SEqualReducer : public BaseValueEqual { * \param lhs The left operand. * \param rhs The right operand. * \param map_free_vars Whether do we allow remap variables if possible. + * \param current_paths Optional paths to `lhs` and `rhs` objects, for error traceability. * * \return false if there is an immediate failure, true otherwise. * \note This function may save the equality condition of (lhs == rhs) in an internal * stack and try to resolve later. */ - virtual bool SEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars) = 0; + virtual bool SEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars, + const ObjectPathPair& current_paths) = 0; + + /*! + * \brief Mark the comparison as failed, but don't fail immediately. + * + * This is useful for producing better error messages when comparing containers. + * For example, if two array sizes mismatch, it's better to mark the comparison as failed + * but compare array elements anyway, so that we could find the true first mismatch. + */ + virtual void DeferFail(const ObjectPathPair& mismatch_paths) = 0; + /*! * \brief Lookup the graph node equal map for vars that are already mapped. * @@ -129,9 +146,10 @@ class SEqualReducer : public BaseValueEqual { * \brief Mark current comparison as graph node equal comparison. */ virtual void MarkGraphNode() = 0; - }; - using BaseValueEqual::operator(); + protected: + using PathTracingData = SEqualReducer::PathTracingData; + }; /*! \brief default constructor */ SEqualReducer() = default; @@ -140,17 +158,58 @@ class SEqualReducer : public BaseValueEqual { * \param handler The equal handler for objects. * \param map_free_vars Whether or not to map free variables. */ - explicit SEqualReducer(Handler* handler, bool map_free_vars) - : handler_(handler), map_free_vars_(map_free_vars) {} + explicit SEqualReducer(Handler* handler, const PathTracingData* tracing_data, bool map_free_vars) + : handler_(handler), tracing_data_(tracing_data), map_free_vars_(map_free_vars) {} + + /*! + * \brief Reduce condition to comparison of two attribute values. + * \param lhs The left operand. + * \param rhs The right operand. + * \return the immediate check result. + */ + bool operator()(const double& lhs, const double& rhs) const; + bool operator()(const int64_t& lhs, const int64_t& rhs) const; + bool operator()(const uint64_t& lhs, const uint64_t& rhs) const; + bool operator()(const int& lhs, const int& rhs) const; + bool operator()(const bool& lhs, const bool& rhs) const; + bool operator()(const std::string& lhs, const std::string& rhs) const; + bool operator()(const DataType& lhs, const DataType& rhs) const; + + template ::value>::type> + bool operator()(const ENum& lhs, const ENum& rhs) const { + using Underlying = typename std::underlying_type::type; + static_assert(std::is_same::value); + return EnumAttrsEqual(static_cast(lhs), static_cast(rhs), &lhs, &rhs); + } + /*! * \brief Reduce condition to comparison of two objects. * \param lhs The left operand. * \param rhs The right operand. * \return the immediate check result. */ - bool operator()(const ObjectRef& lhs, const ObjectRef& rhs) const { - return handler_->SEqualReduce(lhs, rhs, map_free_vars_); + bool operator()(const ObjectRef& lhs, const ObjectRef& rhs) const; + + /*! + * \brief Reduce condition to comparison of two objects. + * + * Like `operator()`, but with an additional `paths` parameter that specifies explicit object + * paths for `lhs` and `rhs`. This is useful for implementing SEqualReduce() methods for container + * objects like Array and Map, or other custom objects that store nested objects that are not + * simply attributes. + * + * Can only be called when `IsPathTracingEnabled()` is `true`. + * + * \param lhs The left operand. + * \param rhs The right operand. + * \param paths Object paths for `lhs` and `rhs`. + * \return the immediate check result. + */ + bool operator()(const ObjectRef& lhs, const ObjectRef& rhs, const ObjectPathPair& paths) const { + ICHECK(IsPathTracingEnabled()) << "Path tracing must be enabled when calling this function"; + return ObjectAttrsEqual(lhs, rhs, map_free_vars_, &paths); } + /*! * \brief Reduce condition to comparison of two definitions, * where free vars can be mapped. @@ -162,9 +221,8 @@ class SEqualReducer : public BaseValueEqual { * \param rhs The right operand. * \return the immediate check result. */ - bool DefEqual(const ObjectRef& lhs, const ObjectRef& rhs) { - return handler_->SEqualReduce(lhs, rhs, true); - } + bool DefEqual(const ObjectRef& lhs, const ObjectRef& rhs); + /*! * \brief Reduce condition to comparison of two arrays. * \param lhs The left operand. @@ -173,13 +231,20 @@ class SEqualReducer : public BaseValueEqual { */ template bool operator()(const Array& lhs, const Array& rhs) const { - // quick specialization for Array to reduce amount of recursion - // depth as array comparison is pretty common. - if (lhs.size() != rhs.size()) return false; - for (size_t i = 0; i < lhs.size(); ++i) { - if (!(operator()(lhs[i], rhs[i]))) return false; + if (tracing_data_ == nullptr) { + // quick specialization for Array to reduce amount of recursion + // depth as array comparison is pretty common. + if (lhs.size() != rhs.size()) return false; + for (size_t i = 0; i < lhs.size(); ++i) { + if (!(operator()(lhs[i], rhs[i]))) return false; + } + return true; } - return true; + + // If tracing is enabled, fall back to the regular path + const ObjectRef& lhs_obj = lhs; + const ObjectRef& rhs_obj = rhs; + return (*this)(lhs_obj, rhs_obj); } /*! * \brief Implementation for equality rule of var type objects(e.g. TypeVar, tir::Var). @@ -198,9 +263,39 @@ class SEqualReducer : public BaseValueEqual { /*! \return Get the internal handler. */ Handler* operator->() const { return handler_; } + /*! \brief Check if this reducer is tracing paths to the first mismatch. */ + bool IsPathTracingEnabled() const { return tracing_data_ != nullptr; } + + /*! + * \brief Get the paths of the currently compared objects. + * + * Can only be called when `IsPathTracingEnabled()` is true. + */ + const ObjectPathPair& GetCurrentObjectPaths() const; + + /*! + * \brief Specify the object paths of a detected mismatch. + */ + void RecordMismatchPaths(const ObjectPathPair& paths) const; + private: + bool EnumAttrsEqual(int lhs, int rhs, const void* lhs_address, const void* rhs_address) const; + + bool ObjectAttrsEqual(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars, + const ObjectPathPair* paths) const; + + static void GetPathsFromAttrAddressesAndStoreMismatch(const void* lhs_address, + const void* rhs_address, + const PathTracingData* tracing_data); + + template + static bool CompareAttributeValues(const T& lhs, const T& rhs, + const PathTracingData* tracing_data); + /*! \brief Internal class pointer. */ Handler* handler_; + /*! \brief Pointer to the current path tracing context, or nullptr if path tracing is disabled. */ + const PathTracingData* tracing_data_; /*! \brief Whether or not to map free vars. */ bool map_free_vars_; }; diff --git a/python/tvm/ir/base.py b/python/tvm/ir/base.py index 00514b472d67..f2a25a409b81 100644 --- a/python/tvm/ir/base.py +++ b/python/tvm/ir/base.py @@ -209,6 +209,17 @@ def structural_equal(lhs, rhs, map_free_vars=False): return bool(tvm.runtime._ffi_node_api.StructuralEqual(lhs, rhs, False, map_free_vars)) +def get_first_structural_mismatch(lhs, rhs, map_free_vars=False): + lhs = tvm.runtime.convert(lhs) + rhs = tvm.runtime.convert(rhs) + mismatch = tvm.runtime._ffi_node_api.GetFirstStructuralMismatch(lhs, rhs, map_free_vars) + if len(mismatch) == 0: + return None + else: + assert len(mismatch) == 2 + return tuple(mismatch) + + def assert_structural_equal(lhs, rhs, map_free_vars=False): """Assert lhs and rhs are structurally equal to each other. diff --git a/python/tvm/runtime/__init__.py b/python/tvm/runtime/__init__.py index 114f01dd0e50..443315bfebb2 100644 --- a/python/tvm/runtime/__init__.py +++ b/python/tvm/runtime/__init__.py @@ -20,6 +20,7 @@ from .packed_func import PackedFunc from .object import Object from .object_generic import ObjectGeneric, ObjectTypes +from .object_path import ObjectPath from .ndarray import NDArray, DataType, DataTypeCode, Device from .module import Module, num_threads from .profiling import Report diff --git a/python/tvm/runtime/object_path.py b/python/tvm/runtime/object_path.py new file mode 100644 index 000000000000..da4b08dad478 --- /dev/null +++ b/python/tvm/runtime/object_path.py @@ -0,0 +1,90 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm._ffi +from . import _ffi_node_api +from tvm.runtime import Object + + +@tvm._ffi.register_object("ObjectPath") +class ObjectPath(Object): + def __init__(self) -> None: + super().__init__() + raise ValueError( + "ObjectPath can't be initialized directly. " + "Use ObjectPath.root() to create a path to the root object" + ) + + @staticmethod + def root() -> "ObjectPath": + return _ffi_node_api.ObjectPathRoot() + + def __eq__(self, other): + return _ffi_node_api.ObjectPathEqual(self, other) + + def __ne__(self, other): + return not _ffi_node_api.ObjectPathEqual(self, other) + + def attr(self, attr_key) -> "ObjectPath": + return _ffi_node_api.ObjectPathAttr(self, attr_key) + + def array_index(self, index) -> "ObjectPath": + return _ffi_node_api.ObjectPathArrayIndex(self, index) + + def missing_array_element(self, index) -> "ObjectPath": + return _ffi_node_api.ObjectPathMissingArrayElement(self, index) + + def map_value(self, key) -> "ObjectPath": + return _ffi_node_api.ObjectPathMapValue(self, key) + + def missing_map_entry(self) -> "ObjectPath": + return _ffi_node_api.ObjectPathMissingMapEntry(self) + + +@tvm._ffi.register_object("RootPath") +class RootPath(ObjectPath): + pass + + +@tvm._ffi.register_object("AttributeAccessPath") +class AttributeAccessPath(ObjectPath): + pass + + +@tvm._ffi.register_object("UnknownAttributeAccessPath") +class UnknownAttributeAccessPath(ObjectPath): + pass + + +@tvm._ffi.register_object("ArrayIndexPath") +class ArrayIndexPath(ObjectPath): + pass + + +@tvm._ffi.register_object("MissingArrayElementPath") +class MissingArrayElementPath(ObjectPath): + pass + + +@tvm._ffi.register_object("MapValuePath") +class MapValuePath(ObjectPath): + pass + + +@tvm._ffi.register_object("MissingMapEntryPath") +class MissingMapEntryPath(ObjectPath): + pass diff --git a/src/node/object_path.cc b/src/node/object_path.cc new file mode 100644 index 000000000000..3716ec60807f --- /dev/null +++ b/src/node/object_path.cc @@ -0,0 +1,272 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include + +#include +#include + +using namespace tvm::runtime; + +namespace tvm { + +ObjectPathNode::ObjectPathNode(ObjectPathNode* parent) + : parent_(GetRef(parent)), length_(parent == nullptr ? 1 : parent->length_ + 1) {} + +ObjectPath ObjectPathNode::GetParent() const { return Downcast(parent_); } + +ObjectPath ObjectPathNode::Attr(const char* attr_key) { + if (attr_key != nullptr) { + return ObjectPath(make_object(this, attr_key)); + } else { + return ObjectPath(make_object(this)); + } +} + +ObjectPath ObjectPathNode::Attr(String attr_key) { + if (attr_key.defined()) { + return ObjectPath(make_object(this, attr_key)); + } else { + return ObjectPath(make_object(this)); + } +} + +TVM_REGISTER_GLOBAL("node.ObjectPathAttr") + .set_body_typed([](const ObjectPath& path, String attr_key) { + return path->Attr(std::move(attr_key)); + }); + +ObjectPath ObjectPathNode::ArrayIndex(size_t index) { + return ObjectPath(make_object(this, index)); +} + +TVM_REGISTER_GLOBAL("node.ObjectPathArrayIndex") + .set_body_typed([](const ObjectPath& path, size_t index) { return path->ArrayIndex(index); }); + +ObjectPath ObjectPathNode::MissingArrayElement(size_t index) { + return ObjectPath(make_object(this, index)); +} + +TVM_REGISTER_GLOBAL("node.ObjectPathMissingArrayElement") + .set_body_typed([](const ObjectPath& path, size_t index) { + return path->MissingArrayElement(index); + }); + +ObjectPath ObjectPathNode::MapValue(ObjectRef key) { + return ObjectPath(make_object(this, std::move(key))); +} + +TVM_REGISTER_GLOBAL("node.ObjectPathMapValue") + .set_body_typed([](const ObjectPath& path, const ObjectRef& key) { + return path->MapValue(key); + }); + +ObjectPath ObjectPathNode::MissingMapEntry() { + return ObjectPath(make_object(this)); +} + +TVM_REGISTER_GLOBAL("node.ObjectPathMissingMapEntry").set_body_typed([](const ObjectPath& path) { + return path->MissingMapEntry(); +}); + +const ObjectPathNode* ObjectPathNode::ParentNode() const { + return static_cast(parent_.get()); +} + +/* static */ ObjectPath ObjectPath::Root() { return ObjectPath(make_object()); } + +TVM_REGISTER_GLOBAL("node.ObjectPathRoot").set_body_typed([]() { return ObjectPath::Root(); }); + +size_t ObjectPath::Length() const { + if (defined()) { + return (*this)->length_; + } else { + return 0; + } +} + +ObjectPath ObjectPath::GetPrefix(size_t length) const { + ICHECK(length <= Length()); + + const ObjectPathNode* node = static_cast(get()); + for (size_t i = 0; i < Length() - length; ++i) { + node = node->ParentNode(); + } + + return GetRef(node); +} + +bool ObjectPath::IsPrefixOf(const ObjectPath& other) const { + size_t this_len = Length(); + if (this_len > other.Length()) { + return false; + } + return this->PathsEqual(other.GetPrefix(this_len)); +} + +bool ObjectPath::PathsEqual(const ObjectPath& other) const { + if (Length() != other.Length()) { + return false; + } + + const ObjectPathNode* lhs = static_cast(get()); + const ObjectPathNode* rhs = static_cast(other.get()); + + while (lhs != nullptr && rhs != nullptr) { + if (lhs->type_index() != rhs->type_index()) { + return false; + } + if (!lhs->LastNodeEqual(rhs)) { + return false; + } + lhs = lhs->ParentNode(); + rhs = rhs->ParentNode(); + } + + return lhs == nullptr && rhs == nullptr; +} + +TVM_REGISTER_GLOBAL("node.ObjectPathEqual") + .set_body_typed([](const ObjectPath& lhs, const ObjectPath& rhs) { + return lhs.PathsEqual(rhs); + }); + +std::string GetObjectPathRepr(const ObjectPathNode* node) { + std::string ret; + while (node != nullptr) { + std::string node_str = node->LastNodeString(); + ret.append(node_str.rbegin(), node_str.rend()); + node = static_cast(node->GetParent().get()); + } + std::reverse(ret.begin(), ret.end()); + return ret; +} + +static void PrintObjectPathRepr(const ObjectRef& node, ReprPrinter* p) { + p->stream << GetObjectPathRepr(static_cast(node.get())); +} + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(PrintObjectPathRepr); + +// ----- Root ----- + +RootPathNode::RootPathNode() : ObjectPathNode(nullptr) {} + +bool RootPathNode::LastNodeEqual(const ObjectPathNode* other) const { return true; } + +std::string RootPathNode::LastNodeString() const { return ""; } + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(PrintObjectPathRepr); + +// ----- AttributeAccess ----- + +AttributeAccessPathNode::AttributeAccessPathNode(ObjectPathNode* parent, String attr_key) + : ObjectPathNode(parent), attr_key(std::move(attr_key)) {} + +bool AttributeAccessPathNode::LastNodeEqual(const ObjectPathNode* other) const { + const auto* otherAttrAccess = static_cast(other); + return attr_key == otherAttrAccess->attr_key; +} + +std::string AttributeAccessPathNode::LastNodeString() const { return "." + attr_key; } + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch(PrintObjectPathRepr); + +// ----- UnknownAttributeAccess ----- + +UnknownAttributeAccessPathNode::UnknownAttributeAccessPathNode(ObjectPathNode* parent) + : ObjectPathNode(parent) {} + +bool UnknownAttributeAccessPathNode::LastNodeEqual(const ObjectPathNode* other) const { + // Consider any two unknown attribute accesses unequal + return false; +} + +std::string UnknownAttributeAccessPathNode::LastNodeString() const { + return "."; +} + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch(PrintObjectPathRepr); + +// ----- ArrayIndexPath ----- + +ArrayIndexPathNode::ArrayIndexPathNode(ObjectPathNode* parent, size_t index) + : ObjectPathNode(parent), index(index) {} + +bool ArrayIndexPathNode::LastNodeEqual(const ObjectPathNode* other) const { + const auto* otherArrayIndex = static_cast(other); + return index == otherArrayIndex->index; +} + +std::string ArrayIndexPathNode::LastNodeString() const { return "[" + std::to_string(index) + "]"; } + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(PrintObjectPathRepr); + +// ----- MissingArrayElement ----- + +MissingArrayElementPathNode::MissingArrayElementPathNode(ObjectPathNode* parent, size_t index) + : ObjectPathNode(parent), index(index) {} + +bool MissingArrayElementPathNode::LastNodeEqual(const ObjectPathNode* other) const { + const auto* otherMissingElement = static_cast(other); + return index == otherMissingElement->index; +} + +std::string MissingArrayElementPathNode::LastNodeString() const { + return "[]"; +} + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch(PrintObjectPathRepr); + +// ----- MapValue ----- + +MapValuePathNode::MapValuePathNode(ObjectPathNode* parent, ObjectRef key) + : ObjectPathNode(parent), key(std::move(key)) {} + +bool MapValuePathNode::LastNodeEqual(const ObjectPathNode* other) const { + const auto* otherMapValue = static_cast(other); + return ObjectEqual()(key, otherMapValue->key); +} + +std::string MapValuePathNode::LastNodeString() const { + std::ostringstream s; + s << "[" << key << "]"; + return s.str(); +} + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(PrintObjectPathRepr); + +// ----- MissingMapEntry ----- + +MissingMapEntryPathNode::MissingMapEntryPathNode(ObjectPathNode* parent) : ObjectPathNode(parent) {} + +bool MissingMapEntryPathNode::LastNodeEqual(const ObjectPathNode* other) const { return true; } + +std::string MissingMapEntryPathNode::LastNodeString() const { return "[]"; } + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch(PrintObjectPathRepr); + +} // namespace tvm diff --git a/src/node/reflection.cc b/src/node/reflection.cc index a7c3493e7feb..db5567a3de55 100644 --- a/src/node/reflection.cc +++ b/src/node/reflection.cc @@ -281,4 +281,43 @@ TVM_REGISTER_GLOBAL("node.NodeGetAttr").set_body(NodeGetAttr); TVM_REGISTER_GLOBAL("node.NodeListAttrNames").set_body(NodeListAttrNames); TVM_REGISTER_GLOBAL("node.MakeNode").set_body(MakeNode); + +namespace { +// Attribute visitor class for finding the attribute key by its address +class GetAttrKeyByAddressVisitor : public AttrVisitor { + public: + explicit GetAttrKeyByAddressVisitor(const void* attr_address) + : attr_address_(attr_address), key_(nullptr) {} + + void Visit(const char* key, double* value) final { DoVisit(key, value); } + void Visit(const char* key, int64_t* value) final { DoVisit(key, value); } + void Visit(const char* key, uint64_t* value) final { DoVisit(key, value); } + void Visit(const char* key, int* value) final { DoVisit(key, value); } + void Visit(const char* key, bool* value) final { DoVisit(key, value); } + void Visit(const char* key, std::string* value) final { DoVisit(key, value); } + void Visit(const char* key, void** value) final { DoVisit(key, value); } + void Visit(const char* key, DataType* value) final { DoVisit(key, value); } + void Visit(const char* key, runtime::NDArray* value) final { DoVisit(key, value); } + void Visit(const char* key, runtime::ObjectRef* value) final { DoVisit(key, value); } + + const char* GetKey() const { return key_; } + + private: + const void* attr_address_; + const char* key_; + + void DoVisit(const char* key, const void* candidate) { + if (attr_address_ == candidate) { + key_ = key; + } + } +}; +} // anonymous namespace + +const char* GetAttrKeyByAddress(const Object* object, const void* attr_address) { + GetAttrKeyByAddressVisitor visitor(attr_address); + ReflectionVTable::Global()->VisitAttrs(const_cast(object), &visitor); + return visitor.GetKey(); +} + } // namespace tvm diff --git a/src/node/structural_equal.cc b/src/node/structural_equal.cc index 8e52af60d235..f016f5340e36 100644 --- a/src/node/structural_equal.cc +++ b/src/node/structural_equal.cc @@ -22,6 +22,7 @@ #include #include #include +#include #include #include #include @@ -42,6 +43,147 @@ bool ReflectionVTable::SEqualReduce(const Object* self, const Object* other, return fsequal_reduce_[tindex](self, other, equal); } +namespace { + +// Represents the first found mismatch, if any. +struct FirstMismatch { + ObjectPathPair paths; + bool found{false}; + + void MaybeStoreMismatch(const ObjectPathPair new_paths) { + if (!found) { + paths = new_paths; + found = true; + } + } +}; + +} // anonymous namespace + +struct SEqualReducer::PathTracingData { + ObjectPathPair current_paths; + ObjectRef lhs_object; + ObjectRef rhs_object; + FirstMismatch* first_mismatch; +}; + +bool SEqualReducer::operator()(const ObjectRef& lhs, const ObjectRef& rhs) const { + if (tracing_data_ == nullptr) { + // Fast path: no tracing + return handler_->SEqualReduce(lhs, rhs, map_free_vars_, {}); + } + return ObjectAttrsEqual(lhs, rhs, map_free_vars_, nullptr); +} + +bool SEqualReducer::DefEqual(const ObjectRef& lhs, const ObjectRef& rhs) { + if (tracing_data_ == nullptr) { + // Fast path: no tracing + return handler_->SEqualReduce(lhs, rhs, true, {}); + } + return ObjectAttrsEqual(lhs, rhs, true, nullptr); +} + +/* static */ void SEqualReducer::GetPathsFromAttrAddressesAndStoreMismatch( + const void* lhs_address, const void* rhs_address, const PathTracingData* tracing_data) { + if (tracing_data != nullptr) { + const char* lhs_attr_key = GetAttrKeyByAddress(tracing_data->lhs_object.get(), lhs_address); + const char* rhs_attr_key = GetAttrKeyByAddress(tracing_data->rhs_object.get(), rhs_address); + tracing_data->first_mismatch->MaybeStoreMismatch( + {tracing_data->current_paths.lhs_path->Attr(lhs_attr_key), + tracing_data->current_paths.rhs_path->Attr(rhs_attr_key)}); + } +} + +template +/* static */ bool SEqualReducer::CompareAttributeValues(const T& lhs, const T& rhs, + const PathTracingData* tracing_data) { + if (BaseValueEqual()(lhs, rhs)) { + return true; + } else { + GetPathsFromAttrAddressesAndStoreMismatch(&lhs, &rhs, tracing_data); + return false; + } +} + +bool SEqualReducer::operator()(const double& lhs, const double& rhs) const { + return CompareAttributeValues(lhs, rhs, tracing_data_); +} + +bool SEqualReducer::operator()(const int64_t& lhs, const int64_t& rhs) const { + return CompareAttributeValues(lhs, rhs, tracing_data_); +} + +bool SEqualReducer::operator()(const uint64_t& lhs, const uint64_t& rhs) const { + return CompareAttributeValues(lhs, rhs, tracing_data_); +} + +bool SEqualReducer::operator()(const int& lhs, const int& rhs) const { + return CompareAttributeValues(lhs, rhs, tracing_data_); +} + +bool SEqualReducer::operator()(const bool& lhs, const bool& rhs) const { + return CompareAttributeValues(lhs, rhs, tracing_data_); +} + +bool SEqualReducer::operator()(const std::string& lhs, const std::string& rhs) const { + return CompareAttributeValues(lhs, rhs, tracing_data_); +} + +bool SEqualReducer::operator()(const DataType& lhs, const DataType& rhs) const { + return CompareAttributeValues(lhs, rhs, tracing_data_); +} + +bool SEqualReducer::EnumAttrsEqual(int lhs, int rhs, const void* lhs_address, + const void* rhs_address) const { + if (lhs == rhs) { + return true; + } else { + GetPathsFromAttrAddressesAndStoreMismatch(lhs_address, rhs_address, tracing_data_); + return false; + } +} + +const ObjectPathPair& SEqualReducer::GetCurrentObjectPaths() const { + ICHECK(tracing_data_ != nullptr) + << "GetCurrentObjectPaths() can only be called when path tracing is enabled"; + return tracing_data_->current_paths; +} + +void SEqualReducer::RecordMismatchPaths(const ObjectPathPair& paths) const { + ICHECK(tracing_data_ != nullptr) + << "RecordMismatchPaths() can only be called when path tracing is enabled"; + tracing_data_->first_mismatch->MaybeStoreMismatch(paths); +} + +bool SEqualReducer::ObjectAttrsEqual(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars, + const ObjectPathPair* paths) const { + if (tracing_data_ == nullptr) { + // Fast path: no tracing + return handler_->SEqualReduce(lhs, rhs, map_free_vars, {}); + } + + // Slow path: tracing object paths for better error reporting + + ObjectPathPair new_paths; + if (paths != nullptr) { + // If paths of `lhs` and `rhs` are explicitly given, use them + new_paths = *paths; + } else { + // Otherwise, assume that `lhs` and `rhs` are attributes and try to find their keys + const char* lhs_attr_key = GetAttrKeyByAddress(tracing_data_->lhs_object.get(), &lhs); + const char* rhs_attr_key = GetAttrKeyByAddress(tracing_data_->rhs_object.get(), &rhs); + new_paths.lhs_path = tracing_data_->current_paths.lhs_path->Attr(lhs_attr_key); + new_paths.rhs_path = tracing_data_->current_paths.rhs_path->Attr(rhs_attr_key); + } + + if (handler_->SEqualReduce(lhs, rhs, map_free_vars, new_paths)) { + return true; + } else { + tracing_data_->first_mismatch->MaybeStoreMismatch(new_paths); + return false; + } +} + /*! * \brief A non recursive stack based SEqual handler that can remaps vars. * @@ -53,9 +195,11 @@ bool ReflectionVTable::SEqualReduce(const Object* self, const Object* other, */ class RemapVarSEqualHandler : public SEqualReducer::Handler { public: - explicit RemapVarSEqualHandler(bool assert_mode) : assert_mode_(assert_mode) {} + explicit RemapVarSEqualHandler(bool assert_mode, FirstMismatch* first_mismatch) + : assert_mode_(assert_mode), first_mismatch_(first_mismatch) {} - bool SEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars) final { + bool SEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars, + const ObjectPathPair& current_paths) final { // We cannot use check lhs.same_as(rhs) to check equality. // if we choose to enable var remapping. // @@ -82,11 +226,16 @@ class RemapVarSEqualHandler : public SEqualReducer::Handler { return it->second.same_as(rhs); } if (equal_map_rhs_.count(rhs)) return false; + // need to push to pending tasks in this case - pending_tasks_.emplace_back(Task(lhs, rhs, map_free_vars)); + pending_tasks_.emplace_back(lhs, rhs, map_free_vars, current_paths); return true; }; - return CheckResult(run(), lhs, rhs); + return CheckResult(run(), lhs, rhs, current_paths); + } + + void DeferFail(const ObjectPathPair& mismatch_paths) final { + pending_tasks_.emplace_back(Task::ForceFailTag{}, mismatch_paths); } void MarkGraphNode() final { @@ -108,7 +257,15 @@ class RemapVarSEqualHandler : public SEqualReducer::Handler { pending_tasks_.clear(); equal_map_lhs_.clear(); equal_map_rhs_.clear(); - if (!SEqualReduce(lhs, rhs, map_free_vars)) return false; + + ObjectPathPair current_paths; + if (IsPathTracingEnabled()) { + current_paths.lhs_path = current_paths.rhs_path = ObjectPath::Root(); + } + if (!SEqualReduce(lhs, rhs, map_free_vars, current_paths)) { + return false; + } + ICHECK_EQ(pending_tasks_.size(), 1U); ICHECK(allow_push_to_stack_); task_stack_.emplace_back(std::move(pending_tasks_.back())); @@ -118,7 +275,11 @@ class RemapVarSEqualHandler : public SEqualReducer::Handler { protected: // Check the result. - bool CheckResult(bool result, const ObjectRef& lhs, const ObjectRef& rhs) { + bool CheckResult(bool result, const ObjectRef& lhs, const ObjectRef& rhs, + const ObjectPathPair& current_paths) { + if (IsPathTracingEnabled() && !result) { + first_mismatch_->MaybeStoreMismatch(current_paths); + } if (assert_mode_ && !result) { LOG(FATAL) << "ValueError: StructuralEqual check failed, caused by lhs:" << std::endl << PrettyPrint(lhs) << std::endl @@ -137,6 +298,13 @@ class RemapVarSEqualHandler : public SEqualReducer::Handler { // Caution: entry becomes invalid when the stack changes auto& entry = task_stack_.back(); + if (entry.force_fail) { + if (IsPathTracingEnabled()) { + first_mismatch_->MaybeStoreMismatch(entry.current_paths); + } + return false; + } + if (entry.children_expanded) { // When all the children has expanded and visited. // This means all the condition checks for @@ -161,7 +329,8 @@ class RemapVarSEqualHandler : public SEqualReducer::Handler { // which populates the pending tasks. ICHECK_EQ(pending_tasks_.size(), 0U); allow_push_to_stack_ = false; - if (!DispatchSEqualReduce(entry.lhs, entry.rhs, entry.map_free_vars)) return false; + if (!DispatchSEqualReduce(entry.lhs, entry.rhs, entry.map_free_vars, entry.current_paths)) + return false; allow_push_to_stack_ = true; // Push pending tasks in reverse order, so earlier tasks get to // expand first in the stack @@ -175,7 +344,8 @@ class RemapVarSEqualHandler : public SEqualReducer::Handler { } // The default equal as registered in the structural equal vtable. - bool DispatchSEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars) { + bool DispatchSEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars, + const ObjectPathPair& current_paths) { auto compute = [=]() { ICHECK(lhs.defined() && rhs.defined() && lhs->type_index() == rhs->type_index()); // skip entries that already have equality maps. @@ -184,10 +354,18 @@ class RemapVarSEqualHandler : public SEqualReducer::Handler { return it->second.same_as(rhs); } if (equal_map_rhs_.count(rhs)) return false; + // Run reduce check for free nodes. - return vtable_->SEqualReduce(lhs.get(), rhs.get(), SEqualReducer(this, map_free_vars)); + if (!IsPathTracingEnabled()) { + return vtable_->SEqualReduce(lhs.get(), rhs.get(), + SEqualReducer(this, nullptr, map_free_vars)); + } else { + PathTracingData tracing_data = {current_paths, lhs, rhs, first_mismatch_}; + return vtable_->SEqualReduce(lhs.get(), rhs.get(), + SEqualReducer(this, &tracing_data, map_free_vars)); + } }; - return CheckResult(compute(), lhs, rhs); + return CheckResult(compute(), lhs, rhs, current_paths); } private: @@ -197,17 +375,28 @@ class RemapVarSEqualHandler : public SEqualReducer::Handler { ObjectRef lhs; /*! \brief The rhs operand to be compared. */ ObjectRef rhs; + /*! \brief If path tracing is enabled, paths taken so far from the root to `lhs` and `rhs` + * objects. */ + ObjectPathPair current_paths; /*! \brief The map free var argument. */ bool map_free_vars; /*! \brief Whether the children has been expanded via SEqualReduce */ bool children_expanded{false}; /*! \brief whether the task is about graph equality(need remap). */ bool graph_equal{false}; + bool force_fail{false}; Task() = default; - Task(ObjectRef lhs, ObjectRef rhs, bool map_free_vars) - : lhs(lhs), rhs(rhs), map_free_vars(map_free_vars) {} + Task(ObjectRef lhs, ObjectRef rhs, bool map_free_vars, const ObjectPathPair& current_paths) + : lhs(lhs), rhs(rhs), current_paths(current_paths), map_free_vars(map_free_vars) {} + + struct ForceFailTag {}; + Task(ForceFailTag, const ObjectPathPair& current_paths) + : current_paths(current_paths), force_fail(true) {} }; + + bool IsPathTracingEnabled() const { return first_mismatch_ != nullptr; } + // list of pending tasks to be pushed to the stack. std::vector pending_tasks_; // Internal task stack to executed the task. @@ -216,6 +405,8 @@ class RemapVarSEqualHandler : public SEqualReducer::Handler { bool allow_push_to_stack_{true}; // If in assert mode, must return true, and will throw error otherwise. bool assert_mode_{false}; + // Location to store the paths to the first detected mismatch, or nullptr to disable path tracing. + FirstMismatch* first_mismatch_; // reflection vtable ReflectionVTable* vtable_ = ReflectionVTable::Global(); // map from lhs to rhs @@ -227,11 +418,21 @@ class RemapVarSEqualHandler : public SEqualReducer::Handler { TVM_REGISTER_GLOBAL("node.StructuralEqual") .set_body_typed([](const ObjectRef& lhs, const ObjectRef& rhs, bool assert_mode, bool map_free_vars) { - return RemapVarSEqualHandler(assert_mode).Equal(lhs, rhs, map_free_vars); + return RemapVarSEqualHandler(assert_mode, nullptr).Equal(lhs, rhs, map_free_vars); + }); + +TVM_REGISTER_GLOBAL("node.GetFirstStructuralMismatch") + .set_body_typed([](const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars) { + FirstMismatch first_mismatch; + if (RemapVarSEqualHandler(false, &first_mismatch).Equal(lhs, rhs, map_free_vars)) { + return Array(); + } else { + return Array({first_mismatch.paths.lhs_path, first_mismatch.paths.rhs_path}); + } }); bool StructuralEqual::operator()(const ObjectRef& lhs, const ObjectRef& rhs) const { - return RemapVarSEqualHandler(false).Equal(lhs, rhs, false); + return RemapVarSEqualHandler(false, nullptr).Equal(lhs, rhs, false); } } // namespace tvm diff --git a/src/node/structural_hash.cc b/src/node/structural_hash.cc index 23811e219078..524fdb4bd210 100644 --- a/src/node/structural_hash.cc +++ b/src/node/structural_hash.cc @@ -22,6 +22,7 @@ #include #include #include +#include #include #include #include @@ -395,12 +396,73 @@ struct ArrayNodeTrait { } static bool SEqualReduce(const ArrayNode* lhs, const ArrayNode* rhs, SEqualReducer equal) { + if (equal.IsPathTracingEnabled()) { + return SEqualReduceTraced(lhs, rhs, equal); + } + if (lhs->size() != rhs->size()) return false; for (size_t i = 0; i < lhs->size(); ++i) { if (!equal(lhs->at(i), rhs->at(i))) return false; } return true; } + + private: + static bool SEqualReduceTraced(const ArrayNode* lhs, const ArrayNode* rhs, + const SEqualReducer& equal) { + size_t min_size = std::min(lhs->size(), rhs->size()); + const ObjectPathPair& array_paths = equal.GetCurrentObjectPaths(); + + for (size_t index = 0; index < min_size; ++index) { + ObjectPathPair element_paths = {array_paths.lhs_path->ArrayIndex(index), + array_paths.rhs_path->ArrayIndex(index)}; + if (!equal(lhs->at(index), rhs->at(index), element_paths)) { + return false; + } + } + + if (lhs->size() == rhs->size()) { + return true; + } + + // If the array length is mismatched, don't report it immediately. + // Instead, defer the failure until we visit all children. + // + // This is for human readability. For example, say we have two sequences + // + // (1) a b c d e f g h i j k l m + // (2) a b c d e g h i j k l m + // + // If we directly report a mismatch at the end of the array right now, + // the user will see that array (1) has an element `m` at index 12 but array (2) + // has no index 12 because it's too short: + // + // (1) a b c d e f g h i j k l m + // ^error here + // (2) a b c d e g h i j k l m + // ^ error here + // + // This is not very helpful. Instead, if we defer reporting this mismatch until all elements + // are fully visited, we can be much more helpful with pointing out the location: + // + // (1) a b c d e f g h i j k l m + // ^ + // error here + // + // (2) a b c d e g h i j k l m + // ^ + // error here + if (lhs->size() > min_size) { + equal->DeferFail({array_paths.lhs_path->ArrayIndex(min_size), + array_paths.rhs_path->MissingArrayElement(min_size)}); + } else { + equal->DeferFail({array_paths.lhs_path->MissingArrayElement(min_size), + array_paths.rhs_path->ArrayIndex(min_size)}); + } + + // Can return `true` pretending that everything is good since we have deferred the failure. + return true; + } }; TVM_REGISTER_REFLECTION_VTABLE(ArrayNode, ArrayNodeTrait) .set_creator([](const std::string&) -> ObjectPtr { @@ -501,13 +563,105 @@ struct MapNodeTrait { return true; } + static bool IsStringMap(const MapNode* map) { + return std::all_of(map->begin(), map->end(), + [](const auto& v) { return v.first->template IsInstance(); }); + } + + static bool SEqualReduceTracedForOMap(const MapNode* lhs, const MapNode* rhs, + const SEqualReducer& equal) { + const ObjectPathPair& map_paths = equal.GetCurrentObjectPaths(); + + std::vector seen_rhs_keys; + + // First, check that every key from `lhs` is also in `rhs`, + // and their values are mapped to each other. + for (const auto& kv : *lhs) { + ObjectPath lhs_path = map_paths.lhs_path->MapValue(kv.first); + + ObjectRef rhs_key = equal->MapLhsToRhs(kv.first); + if (!rhs_key.defined()) { + equal.RecordMismatchPaths({lhs_path, map_paths.rhs_path->MissingMapEntry()}); + return false; + } + + auto it = rhs->find(rhs_key); + if (it == rhs->end()) { + equal.RecordMismatchPaths({lhs_path, map_paths.rhs_path->MissingMapEntry()}); + return false; + } + + if (!equal(kv.second, it->second, {lhs_path, map_paths.rhs_path->MapValue(it->first)})) { + return false; + } + + seen_rhs_keys.push_back(it->first.get()); + } + + std::sort(seen_rhs_keys.begin(), seen_rhs_keys.end()); + + // Second, check that we have visited every `rhs` key when iterating over `lhs`. + for (const auto& kv : *rhs) { + if (!std::binary_search(seen_rhs_keys.begin(), seen_rhs_keys.end(), kv.first.get())) { + equal.RecordMismatchPaths( + {map_paths.lhs_path->MissingMapEntry(), map_paths.rhs_path->MapValue(kv.first)}); + return false; + } + } + + ICHECK(lhs->size() == rhs->size()); + return true; + } + + static bool SEqualReduceTracedForSMap(const MapNode* lhs, const MapNode* rhs, + const SEqualReducer& equal) { + const ObjectPathPair& map_paths = equal.GetCurrentObjectPaths(); + + // First, check that every key from `lhs` is also in `rhs`, and their values are equal. + for (const auto& kv : *lhs) { + ObjectPath lhs_path = map_paths.lhs_path->MapValue(kv.first); + auto it = rhs->find(kv.first); + if (it == rhs->end()) { + equal.RecordMismatchPaths({lhs_path, map_paths.rhs_path->MissingMapEntry()}); + return false; + } + + if (!equal(kv.second, it->second, {lhs_path, map_paths.rhs_path->MapValue(it->first)})) { + return false; + } + } + + // Second, make sure every key from `rhs` is also in `lhs`. + for (const auto& kv : *rhs) { + ObjectPath rhs_path = map_paths.rhs_path->MapValue(kv.first); + if (!lhs->count(kv.first)) { + equal.RecordMismatchPaths({map_paths.lhs_path->MissingMapEntry(), rhs_path}); + return false; + } + } + + ICHECK(lhs->size() == rhs->size()); + return true; + } + + static bool SEqualReduceTraced(const MapNode* lhs, const MapNode* rhs, + const SEqualReducer& equal) { + if (IsStringMap(lhs)) { + return SEqualReduceTracedForSMap(lhs, rhs, equal); + } else { + return SEqualReduceTracedForOMap(lhs, rhs, equal); + } + } + static bool SEqualReduce(const MapNode* lhs, const MapNode* rhs, SEqualReducer equal) { + if (equal.IsPathTracingEnabled()) { + return SEqualReduceTraced(lhs, rhs, equal); + } + if (rhs->size() != lhs->size()) return false; if (rhs->size() == 0) return true; - bool ls = std::all_of(lhs->begin(), lhs->end(), - [](const auto& v) { return v.first->template IsInstance(); }); - bool rs = std::all_of(rhs->begin(), rhs->end(), - [](const auto& v) { return v.first->template IsInstance(); }); + bool ls = IsStringMap(lhs); + bool rs = IsStringMap(rhs); if (ls != rs) { return false; } diff --git a/src/tir/analysis/deep_equal.cc b/src/tir/analysis/deep_equal.cc index 7f48cc439234..09f47c1e85de 100644 --- a/src/tir/analysis/deep_equal.cc +++ b/src/tir/analysis/deep_equal.cc @@ -21,6 +21,7 @@ * \file tir/analysis/deep_equal.cc * \brief Deep equality checking. */ +#include #include #include #include @@ -32,21 +33,25 @@ namespace tir { class DeepCmpSEqualHandler : public SEqualReducer::Handler { public: // use direct recursion. - bool SEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars) final { + bool SEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars, + const ObjectPathPair&) final { if (lhs.same_as(rhs)) return true; if (!lhs.defined() && rhs.defined()) return false; if (!rhs.defined() && lhs.defined()) return false; if (lhs->type_index() != rhs->type_index()) return false; - return vtable_->SEqualReduce(lhs.get(), rhs.get(), SEqualReducer(this, false)); + return vtable_->SEqualReduce(lhs.get(), rhs.get(), SEqualReducer(this, nullptr, false)) && + !fail_; } - ObjectRef MapLhsToRhs(const ObjectRef& lhs) final { return ObjectRef(nullptr); } + void DeferFail(const ObjectPathPair&) final { fail_ = true; } + ObjectRef MapLhsToRhs(const ObjectRef& lhs) final { return ObjectRef(nullptr); } void MarkGraphNode() final {} private: // reflection vtable ReflectionVTable* vtable_ = ReflectionVTable::Global(); + bool fail_ = false; }; bool ExprDeepEqual::operator()(const PrimExpr& lhs, const PrimExpr& rhs) const { @@ -62,7 +67,7 @@ bool ExprDeepEqual::operator()(const PrimExpr& lhs, const PrimExpr& rhs) const { if (lhs.as()) { return false; } - return DeepCmpSEqualHandler().SEqualReduce(lhs, rhs, false); + return DeepCmpSEqualHandler().SEqualReduce(lhs, rhs, false, {}); } TVM_REGISTER_GLOBAL("tir.analysis.expr_deep_equal") diff --git a/tests/python/unittest/test_tir_structural_equal_hash.py b/tests/python/unittest/test_tir_structural_equal_hash.py index ff02f1e369ea..56ec76b81fa3 100644 --- a/tests/python/unittest/test_tir_structural_equal_hash.py +++ b/tests/python/unittest/test_tir_structural_equal_hash.py @@ -29,7 +29,7 @@ def consistent_equal(x, y, map_free_vars=False): if struct_equal0 != struct_equal1: raise ValueError( - "Non-communicative {} vs {}, sequal0={}, sequal1={}".format( + "Non-commutative {} vs {}, sequal0={}, sequal1={}".format( x, y, struct_equal0, struct_equal1 ) ) @@ -45,6 +45,28 @@ def consistent_equal(x, y, map_free_vars=False): return struct_equal0 +def get_sequal_mismatch(x, y, map_free_vars=False): + mismatch_0 = tvm.ir.base.get_first_structural_mismatch(x, y, map_free_vars) + mismatch_1 = tvm.ir.base.get_first_structural_mismatch(y, x, map_free_vars) + + if mismatch_0 is None and mismatch_1 is None: + return None + + if ( + mismatch_0 is None + or mismatch_1 is None + or mismatch_0[0] != mismatch_1[1] + or mismatch_0[1] != mismatch_1[0] + ): + raise ValueError( + "Non-commutative {} vs {}, mismatch_0={}, mismatch_1={}".format( + x, y, mismatch_0, mismatch_1 + ) + ) + + return mismatch_0 + + def test_exprs(): # save load json x = tvm.tir.const(1, "int32") @@ -107,6 +129,47 @@ def test_prim_func(): tvm.ir.assert_structural_equal(mod0, mod1) +def test_prim_func_param_count_mismatch(): + x = te.var("x") + y = te.var("y") + z = te.var("z") + # counter example of same equality + func0 = tvm.tir.PrimFunc([x, y], tvm.tir.Evaluate(x)) + func1 = tvm.tir.PrimFunc([x, y, z], tvm.tir.Evaluate(x)) + lhs_path, rhs_path = get_sequal_mismatch(func0, func1) + expected_lhs_path = tvm.runtime.ObjectPath.root().attr("params").missing_array_element(2) + expected_rhs_path = tvm.runtime.ObjectPath.root().attr("params").array_index(2) + assert lhs_path == expected_lhs_path + assert rhs_path == expected_rhs_path + + +def test_prim_func_param_dtype_mismatch(): + x = te.var("x") + y_0 = te.var("y", dtype="int32") + y_1 = te.var("z", dtype="float32") + # counter example of same equality + func0 = tvm.tir.PrimFunc([x, y_0], tvm.tir.Evaluate(x)) + func1 = tvm.tir.PrimFunc([x, y_1], tvm.tir.Evaluate(x)) + lhs_path, rhs_path = get_sequal_mismatch(func0, func1) + expected_path = tvm.runtime.ObjectPath.root().attr("params").array_index(1).attr("dtype") + assert lhs_path == expected_path + assert rhs_path == expected_path + + +def test_prim_func_body_mismatch(): + x_0 = te.var("x") + y_0 = te.var("y") + x_1 = te.var("x") + y_1 = te.var("y") + # counter example of same equality + func0 = tvm.tir.PrimFunc([x_0, y_0], tvm.tir.Evaluate(x_0 + x_0)) + func1 = tvm.tir.PrimFunc([x_1, y_1], tvm.tir.Evaluate(x_1 + y_1)) + lhs_path, rhs_path = get_sequal_mismatch(func0, func1) + expected_path = tvm.runtime.ObjectPath.root().attr("body").attr("value").attr("b") + assert lhs_path == expected_path + assert rhs_path == expected_path + + def test_array(): x = np.arange(10) nx = tvm.nd.array(x) @@ -183,6 +246,49 @@ def test_buffer_storage_scope(): assert not consistent_equal(func0, func2) +def test_buffer_map_mismatch(): + x = te.var("x") + buffer_0 = tvm.tir.decl_buffer((10, 10)) + buffer_0_clone = tvm.tir.decl_buffer((10, 10)) + buffer_1 = tvm.tir.decl_buffer((10, 20)) + + func_0 = tvm.tir.PrimFunc([x], tvm.tir.Evaluate(x), buffer_map={x: buffer_0}) + func_0_clone = tvm.tir.PrimFunc([x], tvm.tir.Evaluate(x), buffer_map={x: buffer_0_clone}) + func_1 = tvm.tir.PrimFunc([x], tvm.tir.Evaluate(x), buffer_map={x: buffer_1}) + + lhs_path, rhs_path = get_sequal_mismatch(func_0, func_1) + expected_path = ( + tvm.runtime.ObjectPath.root() + .attr("buffer_map") + .map_value(x) + .attr("shape") + .array_index(1) + .attr("value") + ) + assert lhs_path == expected_path + assert rhs_path == expected_path + + assert get_sequal_mismatch(func_0, func_0_clone) is None + + +def test_buffer_map_length_mismatch(): + x = te.var("x") + y = te.var("x") + + buffer_0 = tvm.tir.decl_buffer((10, 10)) + buffer_1 = tvm.tir.decl_buffer((10, 20)) + + func_0 = tvm.tir.PrimFunc([x], tvm.tir.Evaluate(x), buffer_map={x: buffer_0}) + func_1 = tvm.tir.PrimFunc([x], tvm.tir.Evaluate(x), buffer_map={x: buffer_0, y: buffer_1}) + + lhs_path, rhs_path = get_sequal_mismatch(func_0, func_1) + + expected_lhs_path = tvm.runtime.ObjectPath.root().attr("buffer_map").missing_map_entry() + assert lhs_path == expected_lhs_path + expected_rhs_path = tvm.runtime.ObjectPath.root().attr("buffer_map").map_value(y) + assert rhs_path == expected_rhs_path + + def test_buffer_load_store(): b = tvm.tir.decl_buffer((10, 10), "float32") x = tvm.tir.BufferLoad(b, [0, 1]) @@ -208,6 +314,100 @@ def test_while(): assert consistent_equal(wx, wy, map_free_vars=True) +def test_while_condition_mismatch(): + x = tvm.tir.Var("x", "int32") + w_0 = tvm.tir.While(x > 0, tvm.tir.Evaluate(x)) + w_1 = tvm.tir.While(x < 0, tvm.tir.Evaluate(x)) + lhs_path, rhs_path = get_sequal_mismatch(w_0, w_1) + expected_path = tvm.runtime.ObjectPath.root().attr("condition") + assert lhs_path == expected_path + assert rhs_path == expected_path + + +def test_while_body_mismatch(): + x = tvm.tir.Var("x", "int32") + w_0 = tvm.tir.While(x > 0, tvm.tir.Evaluate(x)) + w_1 = tvm.tir.While(x > 0, tvm.tir.Evaluate(x + 1)) + lhs_path, rhs_path = get_sequal_mismatch(w_0, w_1) + expected_path = tvm.runtime.ObjectPath.root().attr("body").attr("value") + assert lhs_path == expected_path + assert rhs_path == expected_path + + +def test_seq_mismatch(): + x = tvm.tir.Var("x", "int32") + seq_0 = tvm.tir.SeqStmt( + [ + tvm.tir.Evaluate(x), + tvm.tir.Evaluate(x + 1), + tvm.tir.Evaluate(x + 2), + tvm.tir.Evaluate(x + 3), + ] + ) + seq_1 = tvm.tir.SeqStmt( + [ + tvm.tir.Evaluate(x), + tvm.tir.Evaluate(x + 1), + tvm.tir.Evaluate(x + 99), + tvm.tir.Evaluate(x + 3), + ] + ) + lhs_path, rhs_path = get_sequal_mismatch(seq_0, seq_1) + expected_path = ( + tvm.runtime.ObjectPath.root() + .attr("seq") + .array_index(2) + .attr("value") + .attr("b") + .attr("value") + ) + assert lhs_path == expected_path + assert rhs_path == expected_path + + +def test_seq_mismatch_different_lengths(): + # Make sure we report a difference inside the array first, rather than the difference in length + x = tvm.tir.Var("x", "int32") + seq_0 = tvm.tir.SeqStmt( + [ + tvm.tir.Evaluate(x), + tvm.tir.Evaluate(x + 1), + tvm.tir.Evaluate(x + 2), + tvm.tir.Evaluate(x + 3), + ] + ) + seq_1 = tvm.tir.SeqStmt([tvm.tir.Evaluate(x), tvm.tir.Evaluate(x + 1), tvm.tir.Evaluate(x + 3)]) + lhs_path, rhs_path = get_sequal_mismatch(seq_0, seq_1) + expected_path = ( + tvm.runtime.ObjectPath.root() + .attr("seq") + .array_index(2) + .attr("value") + .attr("b") + .attr("value") + ) + assert lhs_path == expected_path + assert rhs_path == expected_path + + +def test_seq_length_mismatch(): + x = tvm.tir.Var("x", "int32") + seq_0 = tvm.tir.SeqStmt( + [ + tvm.tir.Evaluate(x), + tvm.tir.Evaluate(x + 1), + tvm.tir.Evaluate(x + 2), + tvm.tir.Evaluate(x + 3), + ] + ) + seq_1 = tvm.tir.SeqStmt([tvm.tir.Evaluate(x), tvm.tir.Evaluate(x + 1), tvm.tir.Evaluate(x + 2)]) + lhs_path, rhs_path = get_sequal_mismatch(seq_0, seq_1) + expected_lhs_path = tvm.runtime.ObjectPath.root().attr("seq").array_index(3) + expected_rhs_path = tvm.runtime.ObjectPath.root().attr("seq").missing_array_element(3) + assert lhs_path == expected_lhs_path + assert rhs_path == expected_rhs_path + + if __name__ == "__main__": test_exprs() test_prim_func()