diff --git a/paddle/phi/kernels/funcs/unique_functor.h b/paddle/phi/kernels/funcs/unique_functor.h
index 758b9160096d09..fafb1b284c60a8 100644
--- a/paddle/phi/kernels/funcs/unique_functor.h
+++ b/paddle/phi/kernels/funcs/unique_functor.h
@@ -23,6 +23,35 @@ namespace phi {
 namespace funcs {
 
+template <typename T>
+static bool NaNSafeEqual(const T& a, const T& b) {
+  if constexpr (std::is_floating_point_v<T>) {
+    if (std::isnan(a) && std::isnan(b)) {
+      return &a == &b;
+    }
+    if (std::isnan(a) || std::isnan(b)) {
+      return false;
+    }
+  }
+  return a == b;
+}
+
+template <typename T>
+static bool NaNSafeLess(const T& a, const T& b) {
+  if constexpr (std::is_floating_point_v<T>) {
+    if (std::isnan(a) && !std::isnan(b)) {
+      return false;
+    }
+    if (!std::isnan(a) && std::isnan(b)) {
+      return true;
+    }
+    if (std::isnan(a) && std::isnan(b)) {
+      return &a < &b;
+    }
+  }
+  return a < b;
+}
+
 template <typename Context, typename InT>
 struct UniqueOpFunctor {
   const Context& dev_ctx_;
@@ -122,7 +151,7 @@ static bool Equal(const DenseTensor& a, const DenseTensor& b) {
     return false;
   }
   for (int64_t i = 0; i < a.numel(); ++i) {
-    if (a.data<InT>()[i] != b.data<InT>()[i]) {
+    if (!NaNSafeEqual(a.data<InT>()[i], b.data<InT>()[i])) {
       return false;
     }
   }
@@ -140,7 +169,15 @@ static void UniqueFlattenedTensor(const Context& dev_ctx,
                                   bool return_inverse,
                                   bool return_counts) {
   const InT* in_data = in.data<InT>();
-  std::set<InT> unique(in_data, in_data + in.numel());
+
+  auto nan_safe_comp = [](const InT& a, const InT& b) {
+    return NaNSafeLess(a, b);
+  };
+  std::set<InT, decltype(nan_safe_comp)> unique(nan_safe_comp);
+  for (int64_t i = 0; i < in.numel(); ++i) {
+    unique.insert(in_data[i]);
+  }
+
   out->Resize(common::make_ddim({static_cast<int64_t>(unique.size())}));
   auto* out_data = dev_ctx.template Alloc<InT>(out);
   std::copy(unique.begin(), unique.end(), out_data);
@@ -162,29 +199,27 @@ static void UniqueFlattenedTensor(const Context& dev_ctx,
   if (return_inverse) {
     index->Resize(common::make_ddim({in.numel()}));
     auto inverse_data = dev_ctx.template Alloc<IndexT>(index);
-    std::unordered_map<InT, IndexT> inverse_map;
-    inverse_map.reserve(out->numel());
-    for (int64_t i = 0; i < out->numel(); ++i) {
-      inverse_map[out_data[i]] = i;
-    }
     for (int64_t i = 0; i < in.numel(); ++i) {
-      inverse_data[i] = inverse_map[in_data[i]];
+      for (int64_t j = 0; j < out->numel(); ++j) {
+        if (NaNSafeEqual(in_data[i], out_data[j])) {
+          inverse_data[i] = j;
+          break;
+        }
+      }
     }
   }
 
   if (return_counts) {
     count->Resize(common::make_ddim({out->numel()}));
     auto count_data = dev_ctx.template Alloc<IndexT>(count);
-    std::unordered_map<InT, IndexT> counts_map;
-    counts_map.reserve(out->numel());
-    for (int64_t i = 0; i < out->numel(); ++i) {
-      counts_map[out_data[i]] = 0;
-    }
-    for (int64_t i = 0; i < in.numel(); i++) {
-      counts_map[in_data[i]] += 1;
-    }
-    for (int64_t i = 0; i < out->numel(); i++) {
-      count_data[i] = counts_map[out_data[i]];
+    for (int64_t i = 0; i < out->numel(); ++i) {
+      IndexT cnt = 0;
+      for (int64_t j = 0; j < in.numel(); ++j) {
+        if (NaNSafeEqual(out_data[i], in_data[j])) {
+          cnt++;
+        }
+      }
+      count_data[i] = cnt;
    }
  }
}
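
For reference, here is a minimal standalone sketch (not part of the patch; the `main` driver and test data are hypothetical) of what the address-based tie-break in `NaNSafeLess` does when it backs a `std::set`: two NaNs at different addresses never compare equivalent, so every NaN occurrence survives deduplication, while ordinary values are still merged. Compile with `-std=c++17` for `if constexpr`.

```cpp
#include <cmath>
#include <cstdio>
#include <set>
#include <type_traits>

// Copy of the comparator added by this patch, reduced to what the demo needs.
template <typename T>
static bool NaNSafeLess(const T& a, const T& b) {
  if constexpr (std::is_floating_point_v<T>) {
    if (std::isnan(a) && !std::isnan(b)) {
      return false;  // NaN sorts after every ordinary value
    }
    if (!std::isnan(a) && std::isnan(b)) {
      return true;
    }
    if (std::isnan(a) && std::isnan(b)) {
      return &a < &b;  // distinct addresses => NaNs never equivalent
    }
  }
  return a < b;
}

int main() {
  const double nan = std::nan("");
  const double data[] = {1.0, nan, 2.0, nan, 1.0};

  auto cmp = [](const double& a, const double& b) {
    return NaNSafeLess(a, b);
  };
  std::set<double, decltype(cmp)> unique(cmp);
  for (const double& v : data) {
    unique.insert(v);
  }

  // Expected: 4 elements -- 1.0 and 2.0 merged as usual, both NaNs kept,
  // because the set never finds two NaNs equivalent under the comparator.
  std::printf("unique size = %zu\n", unique.size());
  for (double v : unique) {
    std::printf("%g\n", v);
  }
  return 0;
}
```

Note that comparing addresses of unrelated objects with `<` is formally unspecified in C++; in practice it yields a consistent total order, which is the same assumption the comparator in the patch relies on.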