Skip to content

Commit bcc0e1d

Browse files
committed
[TIR] Modify IntImmNode deep_equal to match regardless of type
This patch makes a small change to compare the values of IntImmNode to see if they're equal when performing a deep_equal of expressions. This is to try and align it with how the [`PEqualChecker<IntImm>`](https://github.com/apache/tvm/blob/b2204ae6988c7745ea9736340ccd900bc21ae821/src/arith/pattern_match.h#L166) works where we only compare the values if both are IntImm. This caused some simplifications to be inconsistent based on whether we used IntImmNode or PrimExpr to pass an integer between different passes, and it seemed to make more sense to say that if the values are equal, then we can conclude the immediates are equal.
1 parent b2204ae commit bcc0e1d

File tree

2 files changed

+3
-1
lines changed

2 files changed

+3
-1
lines changed

src/tir/analysis/deep_equal.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ bool ExprDeepEqual::operator()(const PrimExpr& lhs, const PrimExpr& rhs) const {
6363
if (lhs->type_index() != rhs->type_index()) return false;
6464
if (auto* plhs = lhs.as<IntImmNode>()) {
6565
auto* prhs = rhs.as<IntImmNode>();
66-
return plhs->dtype == prhs->dtype && plhs->value == prhs->value;
66+
return plhs->value == prhs->value;
6767
}
6868
if (lhs.as<AnyNode>()) {
6969
return false;

tests/python/tir-analysis/test_tir_analysis_expr_deep_equal.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ def func2():
3232
assert tvm.tir.analysis.expr_deep_equal(func2(), func2())
3333
assert not tvm.tir.analysis.expr_deep_equal(func2(), func1())
3434

35+
def test_equal_ints():
36+
assert tvm.tir.analysis.expr_deep_equal(128, tvm.tir.IntImm(dtype="int64", value=128))
3537

3638
if __name__ == "__main__":
3739
test_equal_expr()

0 commit comments

Comments
 (0)