Skip to content

Commit

Permalink
Add a specialization of IsEqual for bfloat16 based on the specializat…
Browse files Browse the repository at this point in the history
…ion of Eigen::half. This allows bfloat16 tensors with NaNs to be compared in unit tests.

PiperOrigin-RevId: 674907803
  • Loading branch information
LarryLansing authored and tensorflower-gardener committed Sep 15, 2024
1 parent 779c7e5 commit 6537419
Showing 1 changed file with 29 additions and 0 deletions.
29 changes: 29 additions & 0 deletions tensorflow/core/framework/tensor_testutil.cc
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,35 @@ static ::testing::AssertionResult IsEqual(Eigen::half x, Eigen::half y,
}
return EqualFailure(x, y);
}
static ::testing::AssertionResult IsEqual(tsl::bfloat16 x, tsl::bfloat16 y,
Tolerance t) {
// We consider NaNs equal for testing.
if (Eigen::numext::isnan(x) && Eigen::numext::isnan(y))
return ::testing::AssertionSuccess();

// Below is a reimplementation of CmpHelperFloatingPointEQ<tsl::bfloat16>,
// which we cannot use because tsl::bfloat16 is not default-constructible.

if (Eigen::numext::isnan(x) || Eigen::numext::isnan(y))
return EqualFailure(x, y);

auto sign_and_magnitude_to_biased = [](uint16_t sam) {
const uint16_t kSignBitMask = 0x8000;
if (kSignBitMask & sam) return ~sam + 1; // negative number.
return kSignBitMask | sam; // positive number.
};

auto xb = sign_and_magnitude_to_biased(Eigen::numext::bit_cast<uint16_t>(x));
auto yb = sign_and_magnitude_to_biased(Eigen::numext::bit_cast<uint16_t>(y));
if (t == Tolerance::kNone) {
if (xb == yb) return ::testing::AssertionSuccess();
} else {
auto distance = xb >= yb ? xb - yb : yb - xb;
const uint16_t kMaxUlps = 4;
if (distance <= kMaxUlps) return ::testing::AssertionSuccess();
}
return EqualFailure(x, y);
}
template <typename T>
static ::testing::AssertionResult IsEqual(const T& x, const T& y, Tolerance t) {
if (::testing::internal::CmpHelperEQ<T>("", "", x, y))
Expand Down

0 comments on commit 6537419

Please sign in to comment.