Skip to content

Commit 1ca9833

Browse files
authored
[IR] Handle NaN in StructuralEqual and StructuralHash (#17249)
* [IR] Handle NaN in StructuralEqual and StructuralHash Prior to this commit, `NaN` values did not have any special handling in either `StructuralEqual` or `StructuralHash`. `StructuralEqual` checked whether the LHS and RHS were within some tolerance of each other. If the LHS and RHS are both `NaN`, this would evaluate to false. The updated `StructuralEqual` now checks for this case, and returns true if both sides are `NaN`. `StructuralHash` used the bit-pattern of a floating-point number to compute the hash. A `NaN` value may have any non-zero value in its mantissa, and so this could produce distinct hashes for ASTs that differ only by the choice of non-zero mantissa value. The updated `StructuralHash` uses the same `std::numeric_limits<double>::quiet_NaN()` value for all `NaN` values. With these changes, `StructuralEqual` and `StructuralHash` can now compare two IR functions, even if they contain `NaN`. Closes #17247 * lint fix
1 parent 6f4ac23 commit 1ca9833

File tree

3 files changed

+71
-6
lines changed

3 files changed

+71
-6
lines changed

include/tvm/node/structural_equal.h

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
#include <tvm/runtime/container/array.h>
2929
#include <tvm/runtime/data_type.h>
3030

31+
#include <cmath>
3132
#include <string>
3233

3334
namespace tvm {
@@ -38,11 +39,21 @@ namespace tvm {
3839
class BaseValueEqual {
3940
public:
4041
bool operator()(const double& lhs, const double& rhs) const {
41-
// fuzzy float pt comparison
42-
constexpr double atol = 1e-9;
43-
if (lhs == rhs) return true;
44-
double diff = lhs - rhs;
45-
return diff > -atol && diff < atol;
42+
if (std::isnan(lhs) && std::isnan(rhs)) {
43+
// IEEE NaN values do not compare as equal to each other.
44+
// However, for the purpose of comparing IR representation, two
45+
// NaN values are equivalent.
46+
return true;
47+
} else if (std::isnan(lhs) || std::isnan(rhs)) {
48+
return false;
49+
} else if (lhs == rhs) {
50+
return true;
51+
} else {
52+
// fuzzy float pt comparison
53+
constexpr double atol = 1e-9;
54+
double diff = lhs - rhs;
55+
return diff > -atol && diff < atol;
56+
}
4657
}
4758

4859
bool operator()(const int64_t& lhs, const int64_t& rhs) const { return lhs == rhs; }

include/tvm/node/structural_hash.h

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,9 @@
2727
#include <tvm/runtime/data_type.h>
2828
#include <tvm/runtime/ndarray.h>
2929

30+
#include <cmath>
3031
#include <functional>
32+
#include <limits>
3133
#include <string>
3234

3335
namespace tvm {
@@ -52,7 +54,16 @@ class BaseValueHash {
5254

5355
public:
5456
uint64_t operator()(const float& key) const { return Reinterpret<float, uint32_t>(key); }
55-
uint64_t operator()(const double& key) const { return Reinterpret<double, uint64_t>(key); }
57+
uint64_t operator()(const double& key) const {
58+
if (std::isnan(key)) {
59+
// The IEEE format defines more than one bit-pattern that
60+
// represents NaN. For the purpose of comparing IR
61+
// representations, all NaN values are considered equivalent.
62+
return Reinterpret<double, uint64_t>(std::numeric_limits<double>::quiet_NaN());
63+
} else {
64+
return Reinterpret<double, uint64_t>(key);
65+
}
66+
}
5667
uint64_t operator()(const int64_t& key) const { return Reinterpret<int64_t, uint64_t>(key); }
5768
uint64_t operator()(const uint64_t& key) const { return key; }
5869
uint64_t operator()(const int& key) const { return Reinterpret<int, uint32_t>(key); }

tests/python/tir-base/test_tir_structural_equal_hash.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -419,5 +419,48 @@ def func(A: T.Buffer(1, "int32")):
419419
assert '<root>.functions[I.GlobalVar("func")].body.extent.value' in err.value.args[0]
420420

421421

422+
def test_nan_values_are_equivalent():
423+
"""Structural equality treats two NaN values as equivalent.
424+
425+
By IEEE, a check of `NaN == NaN` returns false, as does
426+
`abs(NaN - NaN) < tolerance`. However, for the purpose of
427+
comparing IR representations, both NaN values are equivalent.
428+
429+
"""
430+
431+
@T.prim_func(private=True)
432+
def func_1():
433+
return T.float32("nan")
434+
435+
@T.prim_func(private=True)
436+
def func_2():
437+
return T.float32("nan")
438+
439+
tvm.ir.assert_structural_equal(func_1, func_2)
440+
assert tvm.ir.structural_hash(func_1) == tvm.ir.structural_hash(func_2)
441+
442+
443+
def test_all_nan_values_are_equivalent():
444+
"""Structural equality treats two NaN values as equivalent.
445+
446+
IEEE defines NaN as any value that has all exponent bits set,
447+
and has a non-zero mantissa. For the purposes of comparing IR
448+
representations, all NaN values are considered equivalent.
449+
450+
"""
451+
452+
# A quiet NaN: only the most-significant mantissa bit is set.
453+
nan_all_zeros = np.int32(0x7FC00000).view("float32")
454+
455+
# A NaN with the last payload bit set.
456+
nan_with_payload = np.int32(0x7F800001).view("float32")
457+
458+
float_1 = T.float32(nan_all_zeros)
459+
float_2 = T.float32(nan_with_payload)
460+
461+
tvm.ir.assert_structural_equal(float_1, float_2)
462+
assert tvm.ir.structural_hash(float_1) == tvm.ir.structural_hash(float_2)
463+
464+
422465
if __name__ == "__main__":
423466
tvm.testing.main()

0 commit comments

Comments
 (0)