From 035cb82f98d56ffe44afcb67ed57ed079c43030b Mon Sep 17 00:00:00 2001 From: Junekey Jeon Date: Fri, 3 May 2024 17:36:09 -0700 Subject: [PATCH] Properly implement preferred integer types policies (match (default), prefer-32, minimal) --- include/dragonbox/dragonbox.h | 297 ++++++++++++++++++++++------------ 1 file changed, 193 insertions(+), 104 deletions(-) diff --git a/include/dragonbox/dragonbox.h b/include/dragonbox/dragonbox.h index 0b0bf94..1288803 100644 --- a/include/dragonbox/dragonbox.h +++ b/include/dragonbox/dragonbox.h @@ -218,10 +218,13 @@ namespace jkj { // We need assert() macro, but it is not namespaced anyway, so nothing to do here. // + using JKJ_STD_REPLACEMENT_NAMESPACE::int_least8_t; + using JKJ_STD_REPLACEMENT_NAMESPACE::int_least16_t; using JKJ_STD_REPLACEMENT_NAMESPACE::int_least32_t; using JKJ_STD_REPLACEMENT_NAMESPACE::int_fast8_t; using JKJ_STD_REPLACEMENT_NAMESPACE::int_fast16_t; using JKJ_STD_REPLACEMENT_NAMESPACE::int_fast32_t; + using JKJ_STD_REPLACEMENT_NAMESPACE::uint_least8_t; using JKJ_STD_REPLACEMENT_NAMESPACE::uint_least16_t; using JKJ_STD_REPLACEMENT_NAMESPACE::uint_least32_t; using JKJ_STD_REPLACEMENT_NAMESPACE::uint_least64_t; @@ -1146,50 +1149,65 @@ namespace jkj { // Returns true if and only if n is divisible by 10^N. // Precondition: n <= 10^(N+1) // !!It takes an in-out parameter!! - template + template struct divide_by_pow10_info; - template <> - struct divide_by_pow10_info<1> { + template + struct divide_by_pow10_info<1, UInt> { static constexpr stdr::uint_least32_t magic_number = 6554; static constexpr int shift_amount = 16; }; template <> - struct divide_by_pow10_info<2> { + struct divide_by_pow10_info<1, stdr::uint_least8_t> { + static constexpr stdr::uint_least16_t magic_number = 103; + static constexpr int shift_amount = 10; + }; + + template <> + struct divide_by_pow10_info<1, stdr::uint_least16_t> { + static constexpr stdr::uint_least16_t magic_number = 103; + static constexpr int shift_amount = 10; + }; + + template + struct divide_by_pow10_info<2, UInt> { static constexpr stdr::uint_least32_t magic_number = 656; static constexpr int shift_amount = 16; }; - template - JKJ_CONSTEXPR14 bool - check_divisibility_and_divide_by_pow10(stdr::uint_least32_t& n) noexcept { + template <> + struct divide_by_pow10_info<2, stdr::uint_least16_t> { + static constexpr stdr::uint_least32_t magic_number = 41; + static constexpr int shift_amount = 12; + }; + + template + JKJ_CONSTEXPR14 bool check_divisibility_and_divide_by_pow10(UInt& n) noexcept { // Make sure the computation for max_n does not overflow. - static_assert(N + 1 <= log::floor_log10_pow2(31), ""); - assert(n <= compute_power(stdr::uint_least32_t(10))); + static_assert(N + 1 <= log::floor_log10_pow2(int(value_bits::value)), ""); + assert(n <= compute_power(UInt(10))); - using info = divide_by_pow10_info; - n *= info::magic_number; + using info = divide_by_pow10_info; + auto const prod = n * info::magic_number; - constexpr auto mask = - stdr::uint_least32_t((stdr::uint_least32_t(1) << info::shift_amount) - 1); - bool const result = ((n & mask) < info::magic_number); + constexpr auto mask = decltype(prod)((decltype(prod)(1) << info::shift_amount) - 1); + bool const result = ((prod & mask) < info::magic_number); - n >>= info::shift_amount; + n = UInt(prod >> info::shift_amount); return result; } // Compute floor(n / 10^N) for small n and N. // Precondition: n <= 10^(N+1) - template - JKJ_CONSTEXPR14 stdr::uint_least32_t - small_division_by_pow10(stdr::uint_least32_t n) noexcept { + template + JKJ_CONSTEXPR14 UInt small_division_by_pow10(UInt n) noexcept { // Make sure the computation for max_n does not overflow. - static_assert(N + 1 <= log::floor_log10_pow2(31), ""); - assert(n <= compute_power(stdr::uint_least32_t(10))); + static_assert(N + 1 <= log::floor_log10_pow2(int(value_bits::value)), ""); + assert(n <= compute_power(UInt(10))); - return (n * divide_by_pow10_info::magic_number) >> - divide_by_pow10_info::shift_amount; + return UInt((n * divide_by_pow10_info::magic_number) >> + divide_by_pow10_info::shift_amount); } // Compute floor(n / 10^N) for small N. @@ -2868,9 +2886,74 @@ namespace jkj { JKJ_INLINE_VARIABLE struct match_t { using preferred_integer_types_policy = match_t; - template + template + using remainder_type = typename FormatTraits::carrier_uint; + + template using decimal_exponent_type = typename FormatTraits::exponent_int; } match; + + JKJ_INLINE_VARIABLE struct prefer_32_t { + using preferred_integer_types_policy = prefer_32_t; + + template + using remainder_type = typename detail::stdr::conditional< + upper_bound <= + detail::stdr::numeric_limits::max(), + detail::stdr::uint_least32_t, typename FormatTraits::carrier_uint>::type; + + template + using decimal_exponent_type = typename detail::stdr::conditional< + FormatTraits::format::exponent_bits <= + detail::value_bits::value, + detail::stdr::int_least32_t, typename FormatTraits::exponent_int>::type; + } prefer_32; + + JKJ_INLINE_VARIABLE struct minimal_t { + using preferred_integer_types_policy = minimal_t; + + template + using remainder_type = typename detail::stdr::conditional< + upper_bound <= detail::stdr::numeric_limits::max(), + detail::stdr::uint_least8_t, + typename detail::stdr::conditional< + upper_bound <= + detail::stdr::numeric_limits::max(), + detail::stdr::uint_least16_t, + typename detail::stdr::conditional< + upper_bound <= + detail::stdr::numeric_limits::max(), + detail::stdr::uint_least32_t, + typename detail::stdr::conditional< + upper_bound <= detail::stdr::numeric_limits< + detail::stdr::uint_least64_t>::max(), + detail::stdr::uint_least64_t, + typename FormatTraits::carrier_uint>::type>::type>::type>::type; + + template + using decimal_exponent_type = typename detail::stdr::conditional< + lower_bound >= + detail::stdr::numeric_limits::min() && + upper_bound <= + detail::stdr::numeric_limits::max(), + detail::stdr::int_least8_t, + typename detail::stdr::conditional< + lower_bound >= + detail::stdr::numeric_limits::min() && + upper_bound <= + detail::stdr::numeric_limits::max(), + detail::stdr::int_least16_t, + typename detail::stdr::conditional< + lower_bound >= detail::stdr::numeric_limits< + detail::stdr::int_least32_t>::min() && + upper_bound <= detail::stdr::numeric_limits< + detail::stdr::int_least32_t>::max(), + detail::stdr::int_least32_t, + typename FormatTraits::exponent_int>::type>::type>::type; + } minimal; } } @@ -2917,9 +3000,9 @@ namespace jkj { return {carrier_uint(r >> 32), carrier_uint(r) == 0}; } - static constexpr detail::stdr::uint_least32_t compute_delta(cache_entry_type const& cache, + static constexpr detail::stdr::uint_least64_t compute_delta(cache_entry_type const& cache, int beta) noexcept { - return detail::stdr::uint_least32_t(cache >> (cache_bits - 1 - beta)); + return detail::stdr::uint_least64_t(cache >> (cache_bits - 1 - beta)); } static JKJ_CONSTEXPR20 compute_mul_parity_result @@ -2966,9 +3049,9 @@ namespace jkj { return {r.high(), r.low() == 0}; } - static constexpr detail::stdr::uint_least32_t compute_delta(cache_entry_type const& cache, + static constexpr detail::stdr::uint_least64_t compute_delta(cache_entry_type const& cache, int beta) noexcept { - return detail::stdr::uint_least32_t(cache.high() >> (total_bits - 1 - beta)); + return detail::stdr::uint_least64_t(cache.high() >> (total_bits - 1 - beta)); } static JKJ_CONSTEXPR20 compute_mul_parity_result @@ -3061,24 +3144,36 @@ namespace jkj { static constexpr int shorter_interval_tie_upper_threshold = -log::floor_log5_pow2(significand_bits + 2) - 2 - significand_bits; + template + using remainder_type = typename PreferredIntegerTypesPolicy::template remainder_type< + FormatTraits, compute_power(detail::stdr::uint_least64_t(10))>; + + template + using decimal_exponent_type = + typename PreferredIntegerTypesPolicy::template decimal_exponent_type< + FormatTraits, detail::stdr::int_least32_t(min(-max_k, min_k)), + detail::stdr::int_least32_t(max(max_k, -min_k + kappa + 1))>; + + template + using return_type = + decimal_fp, + SignPolicy::return_has_sign, TrailingZeroPolicy::report_trailing_zeros>; + //// The main algorithm assumes the input is a normal/subnormal finite number. template - JKJ_SAFEBUFFERS static JKJ_CONSTEXPR20 decimal_fp< - carrier_uint, - typename PreferredIntegerTypesPolicy::template decimal_exponent_type, - SignPolicy::return_has_sign, TrailingZeroPolicy::report_trailing_zeros> - compute_nearest(signed_significand_bits s, - exponent_int exponent_bits) noexcept { + JKJ_SAFEBUFFERS static JKJ_CONSTEXPR20 + return_type + compute_nearest(signed_significand_bits s, + exponent_int exponent_bits) noexcept { using cache_holder_type = typename CachePolicy::template cache_holder_type; static_assert( min_k >= cache_holder_type::min_k && max_k <= cache_holder_type::max_k, ""); - using decimal_exponent_type = - typename PreferredIntegerTypesPolicy::template decimal_exponent_type< - FormatTraits>; + using remainder_type_ = remainder_type; + using decimal_exponent_type_ = decimal_exponent_type; using multiplication_traits_ = multiplication_traits( + max_exponent - format::significand_bits, decimal_exponent_type_>( binary_exponent); auto const beta = binary_exponent + log::floor_log2_pow10( - decimal_exponent_type(-minus_k)); + decimal_exponent_type_(-minus_k)); // Compute xi and zi. auto const cache = CachePolicy::template get_cache( - decimal_exponent_type(-minus_k)); + decimal_exponent_type_(-minus_k)); auto xi = multiplication_traits_::compute_left_endpoint_for_shorter_interval_case( @@ -3173,7 +3268,7 @@ namespace jkj { if (decimal_significand * 10 >= xi) { return SignPolicy::handle_sign( s, TrailingZeroPolicy::template on_trailing_zeros( - decimal_significand, decimal_exponent_type(minus_k + 1))); + decimal_significand, decimal_exponent_type_(minus_k + 1))); } // Otherwise, compute the round-up of y. @@ -3192,7 +3287,7 @@ namespace jkj { } return SignPolicy::handle_sign( s, TrailingZeroPolicy::template no_trailing_zeros( - decimal_significand, decimal_exponent_type(minus_k))); + decimal_significand, decimal_exponent_type_(minus_k))); } // Normal interval case. @@ -3211,19 +3306,20 @@ namespace jkj { auto interval_type = IntervalTypeProvider::normal_interval(s); // Compute k and beta. - auto const minus_k = decimal_exponent_type( + auto const minus_k = decimal_exponent_type_( log::floor_log10_pow2(binary_exponent) - + decimal_exponent_type_>(binary_exponent) - kappa); auto const cache = - CachePolicy::template get_cache(decimal_exponent_type(-minus_k)); + CachePolicy::template get_cache(decimal_exponent_type_(-minus_k)); auto const beta = binary_exponent + log::floor_log2_pow10( - decimal_exponent_type(-minus_k)); + decimal_exponent_type_(-minus_k)); // Compute zi and deltai. // 10^kappa <= deltai < 10^(kappa + 1) - auto const deltai = multiplication_traits_::compute_delta(cache, beta); + auto const deltai = static_cast( + multiplication_traits_::compute_delta(cache, beta)); // For the case of binary32, the result of integer check is not correct for // 29711844 * 2^-82 // = 6.1442653300000000008655037797566933477355632930994033813476... * 10^-18 @@ -3242,8 +3338,8 @@ namespace jkj { // Step 2: Try larger divisor; remove trailing zeros if necessary. ////////////////////////////////////////////////////////////////////// - constexpr auto big_divisor = compute_power(stdr::uint_least32_t(10)); - constexpr auto small_divisor = compute_power(stdr::uint_least32_t(10)); + constexpr auto big_divisor = compute_power(remainder_type_(10)); + constexpr auto small_divisor = compute_power(remainder_type_(10)); // Using an upper bound on zi, we might be able to optimize the division // better than the compiler; we are computing zi / big_divisor here. @@ -3251,14 +3347,13 @@ namespace jkj { div::divide_by_pow10(z_result.integer_part); - auto r = - stdr::uint_least32_t(z_result.integer_part - big_divisor * decimal_significand); + auto r = remainder_type_(z_result.integer_part - big_divisor * decimal_significand); do { if (r < deltai) { // Exclude the right endpoint if necessary. - if ((r | stdr::uint_least32_t(!z_result.is_integer) | - stdr::uint_least32_t(interval_type.include_right_endpoint())) == 0) { + if ((r | remainder_type_(!z_result.is_integer) | + remainder_type_(interval_type.include_right_endpoint())) == 0) { JKJ_IF_CONSTEXPR( BinaryToDecimalRoundingPolicy::tag == policy::binary_to_decimal_rounding::tag_t::do_not_care) { @@ -3267,7 +3362,7 @@ namespace jkj { return SignPolicy::handle_sign( s, TrailingZeroPolicy::template no_trailing_zeros( decimal_significand, - decimal_exponent_type(minus_k + kappa))); + decimal_exponent_type_(minus_k + kappa))); } else { --decimal_significand; @@ -3293,7 +3388,7 @@ namespace jkj { // We may need to remove trailing zeros. return SignPolicy::handle_sign( s, TrailingZeroPolicy::template on_trailing_zeros( - decimal_significand, decimal_exponent_type(minus_k + kappa + 1))); + decimal_significand, decimal_exponent_type_(minus_k + kappa + 1))); } while (false); @@ -3326,7 +3421,9 @@ namespace jkj { } } else { - auto dist = r - (deltai / 2) + (small_divisor / 2); + // delta is equal to 10^(kappa + elog10(2) - floor(elog10(2))), so dist cannot + // be larger than r. + auto dist = remainder_type_(r - (deltai / 2) + (small_divisor / 2)); bool const approx_y_parity = ((dist ^ (small_divisor / 2)) & 1) != 0; // Is dist divisible by 10^kappa? @@ -3362,24 +3459,21 @@ namespace jkj { } return SignPolicy::handle_sign( s, TrailingZeroPolicy::template no_trailing_zeros( - decimal_significand, decimal_exponent_type(minus_k + kappa))); + decimal_significand, decimal_exponent_type_(minus_k + kappa))); } template - JKJ_FORCEINLINE JKJ_SAFEBUFFERS static JKJ_CONSTEXPR20 decimal_fp< - carrier_uint, - typename PreferredIntegerTypesPolicy::template decimal_exponent_type, - SignPolicy::return_has_sign, TrailingZeroPolicy::report_trailing_zeros> - compute_left_closed_directed(signed_significand_bits s, - exponent_int exponent_bits) noexcept { + JKJ_FORCEINLINE JKJ_SAFEBUFFERS static JKJ_CONSTEXPR20 + return_type + compute_left_closed_directed(signed_significand_bits s, + exponent_int exponent_bits) noexcept { using cache_holder_type = typename CachePolicy::template cache_holder_type; static_assert( min_k >= cache_holder_type::min_k && max_k <= cache_holder_type::max_k, ""); - using decimal_exponent_type = - typename PreferredIntegerTypesPolicy::template decimal_exponent_type< - FormatTraits>; + using remainder_type_ = remainder_type; + using decimal_exponent_type_ = decimal_exponent_type; using multiplication_traits_ = multiplication_traits(binary_exponent) - + decimal_exponent_type_>(binary_exponent) - kappa); auto const cache = - CachePolicy::template get_cache(decimal_exponent_type(-minus_k)); + CachePolicy::template get_cache(decimal_exponent_type_(-minus_k)); int const beta = binary_exponent + log::floor_log2_pow10( - decimal_exponent_type(-minus_k)); + decimal_exponent_type_(-minus_k)); // Compute xi and deltai. // 10^kappa <= deltai < 10^(kappa + 1) - auto const deltai = multiplication_traits_::compute_delta(cache, beta); + auto const deltai = static_cast( + multiplication_traits_::compute_delta(cache, beta)); auto x_result = multiplication_traits_::compute_mul(two_fc << beta, cache); // Deal with the unique exceptional cases @@ -3439,7 +3534,7 @@ namespace jkj { // Step 2: Try larger divisor; remove trailing zeros if necessary. ////////////////////////////////////////////////////////////////////// - constexpr auto big_divisor = compute_power(stdr::uint_least32_t(10)); + constexpr auto big_divisor = compute_power(remainder_type_(10)); // Using an upper bound on xi, we might be able to optimize the division // better than the compiler; we are computing xi / big_divisor here. @@ -3447,12 +3542,12 @@ namespace jkj { div::divide_by_pow10(x_result.integer_part); - auto r = - stdr::uint_least32_t(x_result.integer_part - big_divisor * decimal_significand); + auto r = static_cast(x_result.integer_part - + big_divisor * decimal_significand); if (r != 0) { ++decimal_significand; - r = big_divisor - r; + r = remainder_type_(big_divisor - r); } do { @@ -3489,7 +3584,7 @@ namespace jkj { // The ceiling is inside, so we are done. return SignPolicy::handle_sign( s, TrailingZeroPolicy::template on_trailing_zeros( - decimal_significand, decimal_exponent_type(minus_k + kappa + 1))); + decimal_significand, decimal_exponent_type_(minus_k + kappa + 1))); } while (false); @@ -3501,24 +3596,21 @@ namespace jkj { decimal_significand -= div::small_division_by_pow10(r); return SignPolicy::handle_sign( s, TrailingZeroPolicy::template no_trailing_zeros( - decimal_significand, decimal_exponent_type(minus_k + kappa))); + decimal_significand, decimal_exponent_type_(minus_k + kappa))); } template - JKJ_FORCEINLINE JKJ_SAFEBUFFERS static JKJ_CONSTEXPR20 decimal_fp< - carrier_uint, - typename PreferredIntegerTypesPolicy::template decimal_exponent_type, - SignPolicy::return_has_sign, TrailingZeroPolicy::report_trailing_zeros> - compute_right_closed_directed(signed_significand_bits s, - exponent_int exponent_bits) noexcept { + JKJ_FORCEINLINE JKJ_SAFEBUFFERS static JKJ_CONSTEXPR20 + return_type + compute_right_closed_directed(signed_significand_bits s, + exponent_int exponent_bits) noexcept { using cache_holder_type = typename CachePolicy::template cache_holder_type; static_assert( min_k >= cache_holder_type::min_k && max_k <= cache_holder_type::max_k, ""); - using decimal_exponent_type = - typename PreferredIntegerTypesPolicy::template decimal_exponent_type< - FormatTraits>; + using remainder_type_ = remainder_type; + using decimal_exponent_type_ = decimal_exponent_type; using multiplication_traits_ = multiplication_traits(binary_exponent - - (shorter_interval ? 1 : 0)) - + decimal_exponent_type_>(binary_exponent - + (shorter_interval ? 1 : 0)) - kappa); auto const cache = CachePolicy::template get_cache(deicmal_exponent_type(-minus_k)); @@ -3560,8 +3652,8 @@ namespace jkj { // Compute zi and deltai. // 10^kappa <= deltai < 10^(kappa + 1) - auto const deltai = - multiplication_traits_::compute_delta(cache, beta - shorter_interval ? 1 : 0); + auto const deltai = static_cast( + multiplication_traits_::compute_delta(cache, beta - shorter_interval ? 1 : 0)); carrier_uint const zi = multiplication_traits_::compute_mul(two_fc << beta, cache).integer_part; @@ -3570,7 +3662,7 @@ namespace jkj { // Step 2: Try larger divisor; remove trailing zeros if necessary. ////////////////////////////////////////////////////////////////////// - constexpr auto big_divisor = compute_power(stdr::uint_least32_t(10)); + constexpr auto big_divisor = compute_power(remainder_type_(10)); // Using an upper bound on zi, we might be able to optimize the division better // than the compiler; we are computing zi / big_divisor here. @@ -3578,7 +3670,7 @@ namespace jkj { div::divide_by_pow10(zi); - auto const r = stdr::uint_least32_t(zi - big_divisor * decimal_significand); + auto const r = remainder_type_(zi - big_divisor * decimal_significand); do { if (r > deltai) { @@ -3596,7 +3688,7 @@ namespace jkj { // The floor is inside, so we are done. return SignPolicy::handle_sign( s, TrailingZeroPolicy::template on_trailing_zeros( - decimal_significand, decimal_exponent_type(minus_k + kappa + 1))); + decimal_significand, decimal_exponent_type_(minus_k + kappa + 1))); } while (false); @@ -3608,7 +3700,7 @@ namespace jkj { decimal_significand += div::small_division_by_pow10(r); return SignPolicy::handle_sign( s, TrailingZeroPolicy::template no_trailing_zeros( - decimal_significand, decimal_exponent_type(minus_k + kappa))); + decimal_significand, decimal_exponent_type_(minus_k + kappa))); } static constexpr bool @@ -3897,19 +3989,13 @@ namespace jkj { Policies...>; template - using to_decimal_return_type = - decimal_fp::template decimal_exponent_type, - to_decimal_policy_holder::return_has_sign, - to_decimal_policy_holder::report_trailing_zeros>; + using to_decimal_return_type = typename impl::template return_type< + typename to_decimal_policy_holder::sign_policy, + typename to_decimal_policy_holder::trailing_zero_policy, + typename to_decimal_policy_holder::preferred_integer_types_policy>; template struct to_decimal_dispatcher { - using return_type = - decimal_fp, - PolicyHolder::return_has_sign, PolicyHolder::report_trailing_zeros>; using sign_policy = typename PolicyHolder::sign_policy; using trailing_zero_policy = typename PolicyHolder::trailing_zero_policy; using binary_to_decimal_rounding_policy = @@ -3917,6 +4003,9 @@ namespace jkj { using cache_policy = typename PolicyHolder::cache_policy; using preferred_integer_types_policy = typename PolicyHolder::preferred_integer_types_policy; + using return_type = + typename impl::template return_type; template JKJ_FORCEINLINE JKJ_SAFEBUFFERS JKJ_CONSTEXPR20 return_type