From 9c5164250788c00ae7346efea9439d567b5a9bfe Mon Sep 17 00:00:00 2001 From: enkilee Date: Thu, 24 Jul 2025 15:31:18 +0800 Subject: [PATCH 1/2] fix issue 73692 --- paddle/phi/kernels/funcs/unique_functor.h | 69 +++++++++++++++++------ 1 file changed, 52 insertions(+), 17 deletions(-) diff --git a/paddle/phi/kernels/funcs/unique_functor.h b/paddle/phi/kernels/funcs/unique_functor.h index 758b9160096d09..edb1d35d55ced8 100644 --- a/paddle/phi/kernels/funcs/unique_functor.h +++ b/paddle/phi/kernels/funcs/unique_functor.h @@ -23,6 +23,35 @@ namespace phi { namespace funcs { +template +static bool NaNSafeEqual(const T& a, const T& b) { + if constexpr (std::is_floating_point_v) { + if (std::isnan(a) && std::isnan(b)) { + return true; + } + if (std::isnan(a) || std::isnan(b)) { + return false; + } + } + return a == b; +} + +template +static bool NaNSafeLess(const T& a, const T& b) { + if constexpr (std::is_floating_point_v) { + if (std::isnan(a) && std::isnan(b)) { + return false; + } + if (std::isnan(a)) { + return false; + } + if (std::isnan(b)) { + return true; + } + } + return a < b; +} + template struct UniqueOpFunctor { const Context& dev_ctx_; @@ -122,7 +151,7 @@ static bool Equal(const DenseTensor& a, const DenseTensor& b) { return false; } for (int64_t i = 0; i < a.numel(); ++i) { - if (a.data()[i] != b.data()[i]) { + if (!NaNSafeEqual(a.data()[i], b.data()[i])) { return false; } } @@ -140,7 +169,15 @@ static void UniqueFlattenedTensor(const Context& dev_ctx, bool return_inverse, bool return_counts) { const InT* in_data = in.data(); - std::set unique(in_data, in_data + in.numel()); + + auto nan_safe_comp = [](const InT& a, const InT& b) { + return NaNSafeLess(a, b); + }; + std::set unique(nan_safe_comp); + for (int64_t i = 0; i < in.numel(); ++i) { + unique.insert(in_data[i]); + } + out->Resize(common::make_ddim({static_cast(unique.size())})); auto* out_data = dev_ctx.template Alloc(out); std::copy(unique.begin(), unique.end(), out_data); @@ -162,29 +199,27 @@ static void UniqueFlattenedTensor(const Context& dev_ctx, if (return_inverse) { index->Resize(common::make_ddim({in.numel()})); auto inverse_data = dev_ctx.template Alloc(index); - std::unordered_map inverse_map; - inverse_map.reserve(out->numel()); - for (int64_t i = 0; i < out->numel(); ++i) { - inverse_map[out_data[i]] = i; - } for (int64_t i = 0; i < in.numel(); ++i) { - inverse_data[i] = inverse_map[in_data[i]]; + for (int64_t j = 0; j < out->numel(); ++j) { + if (NaNSafeEqual(in_data[i], out_data[j])) { + inverse_data[i] = j; + break; + } + } } } if (return_counts) { count->Resize(common::make_ddim({out->numel()})); auto count_data = dev_ctx.template Alloc(count); - std::unordered_map counts_map; - counts_map.reserve(out->numel()); for (int64_t i = 0; i < out->numel(); ++i) { - counts_map[out_data[i]] = 0; - } - for (int64_t i = 0; i < in.numel(); i++) { - counts_map[in_data[i]] += 1; - } - for (int64_t i = 0; i < out->numel(); i++) { - count_data[i] = counts_map[out_data[i]]; + IndexT cnt = 0; + for (int64_t j = 0; j < in.numel(); ++j) { + if (NaNSafeEqual(out_data[i], in_data[j])) { + cnt++; + } + } + count_data[i] = cnt; } } } From a56060cc3d27d3f3f722737faae284488b7eaad0 Mon Sep 17 00:00:00 2001 From: enkilee Date: Wed, 6 Aug 2025 09:30:36 +0800 Subject: [PATCH 2/2] fix error --- paddle/phi/kernels/funcs/unique_functor.h | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/paddle/phi/kernels/funcs/unique_functor.h b/paddle/phi/kernels/funcs/unique_functor.h index edb1d35d55ced8..fafb1b284c60a8 100644 --- a/paddle/phi/kernels/funcs/unique_functor.h +++ b/paddle/phi/kernels/funcs/unique_functor.h @@ -27,7 +27,7 @@ template static bool NaNSafeEqual(const T& a, const T& b) { if constexpr (std::is_floating_point_v) { if (std::isnan(a) && std::isnan(b)) { - return true; + return &a == &b; } if (std::isnan(a) || std::isnan(b)) { return false; @@ -39,15 +39,15 @@ static bool NaNSafeEqual(const T& a, const T& b) { template static bool NaNSafeLess(const T& a, const T& b) { if constexpr (std::is_floating_point_v) { - if (std::isnan(a) && std::isnan(b)) { - return false; - } - if (std::isnan(a)) { + if (std::isnan(a) && !std::isnan(b)) { return false; } - if (std::isnan(b)) { + if (!std::isnan(a) && std::isnan(b)) { return true; } + if (std::isnan(a) && std::isnan(b)) { + return &a < &b; + } } return a < b; }