Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions python/tvm/tir/analysis/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,14 +349,14 @@ def apply_prim_func_arg_and_result_memory_constraints(
)


def verify_well_formed(func: PrimFunc, assert_mode: bool = True) -> bool:
def verify_well_formed(obj: Union[PrimFunc, IRModule], assert_mode: bool = True) -> bool:
"""Verify if the given TIR is well-formed. The verification includes:
- Check if expressions not contain vars that is defined outside the block.

Parameters
----------
func: tvm.tir.PrimFunc
The function to be verified.
obj: Union[tvm.tir.PrimFunc, tvm.ir.IRModule]
The function or module to be verified.

assert_mode: bool
The indicator if it raises an error when the function is not well-formed.
Expand All @@ -366,7 +366,7 @@ def verify_well_formed(func: PrimFunc, assert_mode: bool = True) -> bool:
result: bool
Whether it is a well-formed TIR function.
"""
return _ffi_api.VerifyWellFormed(func, assert_mode) # type: ignore # pylint: disable=no-member
return _ffi_api.VerifyWellFormed(obj, assert_mode) # type: ignore # pylint: disable=no-member


def OOBChecker():
Expand Down
25 changes: 24 additions & 1 deletion src/tir/analysis/verify_well_formed.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <tvm/tir/stmt_functor.h>

#include "../ir/functor_common.h"
#include "tvm/ir/module.h"

namespace tvm {
namespace tir {
Expand Down Expand Up @@ -142,7 +143,29 @@ bool VerifyWellFormed(const PrimFunc& func, bool assert_mode) {
return true;
}

TVM_REGISTER_GLOBAL("tir.analysis.VerifyWellFormed").set_body_typed(VerifyWellFormed);
bool VerifyWellFormed(const IRModule& mod, bool assert_mode) {
for (const auto& [gvar, base_func] : mod->functions) {
if (auto prim_func = base_func.as<PrimFunc>()) {
bool res = VerifyWellFormed(prim_func.value(), assert_mode);
if (!res) {
return false;
}
}
}
return true;
}

TVM_REGISTER_GLOBAL("tir.analysis.VerifyWellFormed")
.set_body_typed([](const ObjectRef& obj, bool assert_mode) {
if (auto opt = obj.as<PrimFunc>()) {
return VerifyWellFormed(opt.value(), assert_mode);
} else if (auto opt = obj.as<IRModule>()) {
return VerifyWellFormed(opt.value(), assert_mode);
} else {
LOG(FATAL) << "Expected VerifyWellFormed argument to be a PrimFunc or IRModule, but found "
<< obj->GetTypeKey();
}
});

} // namespace tir
} // namespace tvm
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def element_wise(
C[i, j] = B[i, j] * 2.0

assert tvm.tir.analysis.verify_well_formed(element_wise)
assert tvm.tir.analysis.verify_well_formed(tvm.IRModule.from_expr(element_wise))


def test_fail_use_out_loop_var():
Expand Down