diff --git a/barretenberg/cpp/src/barretenberg/ecc/groups/affine_element.hpp b/barretenberg/cpp/src/barretenberg/ecc/groups/affine_element.hpp index 369dbcb0d981..233934567ff3 100644 --- a/barretenberg/cpp/src/barretenberg/ecc/groups/affine_element.hpp +++ b/barretenberg/cpp/src/barretenberg/ecc/groups/affine_element.hpp @@ -17,6 +17,11 @@ #include namespace bb::group_elements { + +// MSB of the top 64-bit limb in a uint256_t (bit 255). Used in point compression to encode the +// y-coordinate parity bit, and cleared when recovering the x-coordinate. +static constexpr uint64_t UINT256_TOP_LIMB_MSB = 0x8000000000000000ULL; + template concept SupportsHashToCurve = T::can_hash_to_curve; template class alignas(64) affine_element { @@ -80,10 +85,6 @@ template class alignas(64) affine constexpr affine_element operator*(const Fr& exponent) const noexcept; - template > 255) == uint256_t(0), void>> - [[nodiscard]] constexpr uint256_t compress() const noexcept; - static constexpr affine_element infinity(); constexpr affine_element set_infinity() const noexcept; constexpr void self_set_infinity() noexcept; diff --git a/barretenberg/cpp/src/barretenberg/ecc/groups/affine_element.test.cpp b/barretenberg/cpp/src/barretenberg/ecc/groups/affine_element.test.cpp index 420ad908be68..3dd3b3f048cf 100644 --- a/barretenberg/cpp/src/barretenberg/ecc/groups/affine_element.test.cpp +++ b/barretenberg/cpp/src/barretenberg/ecc/groups/affine_element.test.cpp @@ -139,7 +139,10 @@ template class TestAffineElement : public testing::Test { { for (size_t i = 0; i < 10; i++) { affine_element P = affine_element(element::random_element()); - uint256_t compressed = P.compress(); + uint256_t compressed = uint256_t(P.x); + if (uint256_t(P.y).get_bit(0)) { + compressed.data[3] |= group_elements::UINT256_TOP_LIMB_MSB; + } affine_element Q = affine_element::from_compressed(compressed); EXPECT_EQ(P, Q); } @@ -168,8 +171,6 @@ template class TestAffineElement : public testing::Test { affine_element R(0, P.y); ASSERT_FALSE(P == R); } - // Regression test to ensure that the point at infinity is not equal to its coordinate-wise reduction, which may lie - // on the curve, depending on the y-coordinate. static void test_infinity_ordering_regression() { affine_element P(0, 1); diff --git a/barretenberg/cpp/src/barretenberg/ecc/groups/affine_element_impl.hpp b/barretenberg/cpp/src/barretenberg/ecc/groups/affine_element_impl.hpp index ba51cb33093a..294f4889125e 100644 --- a/barretenberg/cpp/src/barretenberg/ecc/groups/affine_element_impl.hpp +++ b/barretenberg/cpp/src/barretenberg/ecc/groups/affine_element_impl.hpp @@ -21,7 +21,7 @@ template constexpr affine_element affine_element::from_compressed(const uint256_t& compressed) noexcept { uint256_t x_coordinate = compressed; - x_coordinate.data[3] = x_coordinate.data[3] & (~0x8000000000000000ULL); + x_coordinate.data[3] = x_coordinate.data[3] & (~UINT256_TOP_LIMB_MSB); bool y_bit = compressed.get_bit(255); Fq x = Fq(x_coordinate); @@ -80,18 +80,6 @@ constexpr affine_element affine_element::operator*(const F return bb::group_elements::element(*this) * exponent; } -template -template - -constexpr uint256_t affine_element::compress() const noexcept -{ - uint256_t out(x); - if (uint256_t(y).get_bit(0)) { - out.data[3] = out.data[3] | 0x8000000000000000ULL; - } - return out; -} - template constexpr affine_element affine_element::infinity() { affine_element e{}; @@ -157,15 +145,9 @@ constexpr bool affine_element::operator==(const affine_element& other return !only_one_is_infinity && (both_infinity || ((x == other.x) && (y == other.y))); } -/** - * Comparison operators (for std::sort) - * - * @details CAUTION!! Don't use this operator. It has no meaning other than for use by std::sort. - **/ template constexpr bool affine_element::operator>(const affine_element& other) const noexcept { - // We are setting point at infinity to always be the lowest element if (is_point_at_infinity()) { return false; } diff --git a/barretenberg/cpp/src/barretenberg/ecc/groups/element.hpp b/barretenberg/cpp/src/barretenberg/ecc/groups/element.hpp index fc2afad7028d..b6ae9683eca4 100644 --- a/barretenberg/cpp/src/barretenberg/ecc/groups/element.hpp +++ b/barretenberg/cpp/src/barretenberg/ecc/groups/element.hpp @@ -59,7 +59,6 @@ template class alignas(32) element { constexpr element dbl() const noexcept; constexpr void self_dbl() noexcept; - constexpr void self_mixed_add_or_sub(const affine_element& other, uint64_t predicate) noexcept; constexpr element operator+(const element& other) const noexcept; constexpr element operator+(const affine_element& other) const noexcept; @@ -128,27 +127,6 @@ template class alignas(32) element { template > static element random_coordinates_on_curve(numeric::RNG* engine = nullptr) noexcept; - // { - // bool found_one = false; - // Fq yy; - // Fq x; - // Fq y; - // Fq t0; - // while (!found_one) { - // x = Fq::random_element(engine); - // yy = x.sqr() * x + Params::b; - // if constexpr (Params::has_a) { - // yy += (x * Params::a); - // } - // y = yy.sqrt(); - // t0 = y.sqr(); - // found_one = (yy == t0); - // } - // return { x, y, Fq::one() }; - // } - static void conditional_negate_affine(const affine_element& in, - affine_element& out, - uint64_t predicate) noexcept; friend std::ostream& operator<<(std::ostream& os, const element& a) { @@ -162,10 +140,6 @@ template std::ostream& operator<<(std::ostrea return os << "x:" << e.x << " y:" << e.y << " z:" << e.z; } -// constexpr element::one = element{ Params::one_x, Params::one_y, Fq::one() }; -// constexpr element::point_at_infinity = one.set_infinity(); -// constexpr element::curve_b = Params::b; - } // namespace bb::group_elements #include "./element_impl.hpp" diff --git a/barretenberg/cpp/src/barretenberg/ecc/groups/element_impl.hpp b/barretenberg/cpp/src/barretenberg/ecc/groups/element_impl.hpp index 3e45799aaade..54ce94e5feb7 100644 --- a/barretenberg/cpp/src/barretenberg/ecc/groups/element_impl.hpp +++ b/barretenberg/cpp/src/barretenberg/ecc/groups/element_impl.hpp @@ -155,99 +155,6 @@ template constexpr element element -constexpr void element::self_mixed_add_or_sub(const affine_element& other, - const uint64_t predicate) noexcept -{ - if constexpr (Fq::modulus.data[3] >= MODULUS_TOP_LIMB_LARGE_THRESHOLD) { - if (is_point_at_infinity()) { - conditional_negate_affine(other, *(affine_element*)this, predicate); // NOLINT - z = Fq::one(); - return; - } - } else { - const bool edge_case_trigger = x.is_msb_set() || other.x.is_msb_set(); - if (edge_case_trigger) { - if (x.is_msb_set()) { - conditional_negate_affine(other, *(affine_element*)this, predicate); // NOLINT - z = Fq::one(); - } - return; - } - } - - // T0 = z1.z1 - Fq T0 = z.sqr(); - - // T1 = x2.t0 - x1 = x2.z1.z1 - x1 - Fq T1 = other.x * T0; - T1 -= x; - - // T2 = T0.z1 = z1.z1.z1 - // T2 = T2.y2 - y1 = y2.z1.z1.z1 - y1 - Fq T2 = z * T0; - T2 *= other.y; - T2.self_conditional_negate(predicate); - T2 -= y; - - if (__builtin_expect(T1.is_zero(), 0)) { - if (T2.is_zero()) { - // y2 equals y1, x2 equals x1, double x1 - self_dbl(); - return; - } - self_set_infinity(); - return; - } - - // T2 = 2T2 = 2(y2.z1.z1.z1 - y1) = R - // z3 = z1 + H - T2 += T2; - z += T1; - - // T3 = T1*T1 = HH - Fq T3 = T1.sqr(); - - // z3 = z3 - z1z1 - HH - T0 += T3; - - // z3 = (z1 + H)*(z1 + H) - z.self_sqr(); - z -= T0; - - // T3 = 4HH - T3 += T3; - T3 += T3; - - // T1 = T1*T3 = 4HHH - T1 *= T3; - - // T3 = T3 * x1 = 4HH*x1 - T3 *= x; - - // T0 = 2T3 - T0 = T3 + T3; - - // T0 = T0 + T1 = 2(4HH*x1) + 4HHH - T0 += T1; - x = T2.sqr(); - - // x3 = x3 - T0 = R*R - 8HH*x1 -4HHH - x -= T0; - - // T3 = T3 - x3 = 4HH*x1 - x3 - T3 -= x; - - T1 *= y; - T1 += T1; - - // T3 = T2 * T3 = R*(4HH*x1 - x3) - T3 *= T2; - - // y3 = T3 - T1 - y = T3 - T1; -} - template constexpr element element::operator+=(const affine_element& other) noexcept { @@ -1057,14 +964,6 @@ std::vector> element::batch_mul_with_endomo return work_elements; } -template -void element::conditional_negate_affine(const affine_element& in, - affine_element& out, - const uint64_t predicate) noexcept -{ - out = { in.x, predicate ? -in.y : in.y }; -} - template void element::batch_normalize(element* elements, const size_t num_elements) noexcept { diff --git a/barretenberg/cpp/src/barretenberg/ecc/groups/group.hpp b/barretenberg/cpp/src/barretenberg/ecc/groups/group.hpp index e60de391c108..4c02e1c39983 100644 --- a/barretenberg/cpp/src/barretenberg/ecc/groups/group.hpp +++ b/barretenberg/cpp/src/barretenberg/ecc/groups/group.hpp @@ -130,16 +130,6 @@ template class group { } return derive_generators(domain_bytes, num_generators, starting_index); } - - BB_INLINE static void conditional_negate_affine(const affine_element* src, - affine_element* dest, - uint64_t predicate); }; } // namespace bb - -#ifdef DISABLE_ASM -#include "group_impl_int128.tcc" -#else -#include "group_impl_asm.tcc" -#endif diff --git a/barretenberg/cpp/src/barretenberg/ecc/groups/group_impl_asm.tcc b/barretenberg/cpp/src/barretenberg/ecc/groups/group_impl_asm.tcc deleted file mode 100644 index 2177ba1ad37a..000000000000 --- a/barretenberg/cpp/src/barretenberg/ecc/groups/group_impl_asm.tcc +++ /dev/null @@ -1,162 +0,0 @@ -#pragma once - -#ifndef DISABLE_ASM - -#include "barretenberg/ecc/groups/group.hpp" -#include - -namespace bb { -// copies src into dest. n.b. both src and dest must be aligned on 32 byte boundaries -// template -// inline void group::copy(const affine_element* src, affine_element* -// dest) -// { -// if constexpr (Params::small_elements) { -// #if defined __AVX__ && defined USE_AVX -// ASSERT((((uintptr_t)src & 0x1f) == 0)); -// ASSERT((((uintptr_t)dest & 0x1f) == 0)); -// __asm__ __volatile__("vmovdqa 0(%0), %%ymm0 \n\t" -// "vmovdqa 32(%0), %%ymm1 \n\t" -// "vmovdqa %%ymm0, 0(%1) \n\t" -// "vmovdqa %%ymm1, 32(%1) \n\t" -// : -// : "r"(src), "r"(dest) -// : "%ymm0", "%ymm1", "memory"); -// #else -// *dest = *src; -// #endif -// } else { -// *dest = *src; -// } -// } - -// // copies src into dest. n.b. both src and dest must be aligned on 32 byte boundaries -// template -// inline void group::copy(const element* src, element* dest) -// { -// if constexpr (Params::small_elements) { -// #if defined __AVX__ && defined USE_AVX -// ASSERT((((uintptr_t)src & 0x1f) == 0)); -// ASSERT((((uintptr_t)dest & 0x1f) == 0)); -// __asm__ __volatile__("vmovdqa 0(%0), %%ymm0 \n\t" -// "vmovdqa 32(%0), %%ymm1 \n\t" -// "vmovdqa 64(%0), %%ymm2 \n\t" -// "vmovdqa %%ymm0, 0(%1) \n\t" -// "vmovdqa %%ymm1, 32(%1) \n\t" -// "vmovdqa %%ymm2, 64(%1) \n\t" -// : -// : "r"(src), "r"(dest) -// : "%ymm0", "%ymm1", "%ymm2", "memory"); -// #else -// *dest = *src; -// #endif -// } else { -// *dest = src; -// } -// } - -// copies src into dest, inverting y-coordinate if 'predicate' is true -// n.b. requires src and dest to be aligned on 32 byte boundary -template -inline void group::conditional_negate_affine(const affine_element* src, - affine_element* dest, - uint64_t predicate) -{ - constexpr uint256_t twice_modulus = Fq::modulus + Fq::modulus; - - constexpr uint64_t twice_modulus_0 = twice_modulus.data[0]; - constexpr uint64_t twice_modulus_1 = twice_modulus.data[1]; - constexpr uint64_t twice_modulus_2 = twice_modulus.data[2]; - constexpr uint64_t twice_modulus_3 = twice_modulus.data[3]; - - if constexpr (Params::small_elements) { -#if defined __AVX__ && defined USE_AVX - BB_ASSERT_EQ(((uintptr_t)src & 0x1f, 0)); - BB_ASSERT_EQ(((uintptr_t)dest & 0x1f, 0)); - __asm__ __volatile__("xorq %%r8, %%r8 \n\t" - "movq 32(%0), %%r8 \n\t" - "movq 40(%0), %%r9 \n\t" - "movq 48(%0), %%r10 \n\t" - "movq 56(%0), %%r11 \n\t" - "movq %[modulus_0], %%r12 \n\t" - "movq %[modulus_1], %%r13 \n\t" - "movq %[modulus_2], %%r14 \n\t" - "movq %[modulus_3], %%r15 \n\t" - "subq %%r8, %%r12 \n\t" - "sbbq %%r9, %%r13 \n\t" - "sbbq %%r10, %%r14 \n\t" - "sbbq %%r11, %%r15 \n\t" - "testq %2, %2 \n\t" - "cmovnzq %%r12, %%r8 \n\t" - "cmovnzq %%r13, %%r9 \n\t" - "cmovnzq %%r14, %%r10 \n\t" - "cmovnzq %%r15, %%r11 \n\t" - "vmovdqa 0(%0), %%ymm0 \n\t" - "vmovdqa %%ymm0, 0(%1) \n\t" - "movq %%r8, 32(%1) \n\t" - "movq %%r9, 40(%1) \n\t" - "movq %%r10, 48(%1) \n\t" - "movq %%r11, 56(%1) \n\t" - : - : "r"(src), - "r"(dest), - "r"(predicate), - [modulus_0] "i"(twice_modulus_0), - [modulus_1] "i"(twice_modulus_1), - [modulus_2] "i"(twice_modulus_2), - [modulus_3] "i"(twice_modulus_3) - : "%r8", "%r9", "%r10", "%r11", "%r12", "%r13", "%r14", "%r15", "%ymm0", "memory", "cc"); -#else - __asm__ __volatile__("xorq %%r8, %%r8 \n\t" - "movq 32(%0), %%r8 \n\t" - "movq 40(%0), %%r9 \n\t" - "movq 48(%0), %%r10 \n\t" - "movq 56(%0), %%r11 \n\t" - "movq %[modulus_0], %%r12 \n\t" - "movq %[modulus_1], %%r13 \n\t" - "movq %[modulus_2], %%r14 \n\t" - "movq %[modulus_3], %%r15 \n\t" - "subq %%r8, %%r12 \n\t" - "sbbq %%r9, %%r13 \n\t" - "sbbq %%r10, %%r14 \n\t" - "sbbq %%r11, %%r15 \n\t" - "testq %2, %2 \n\t" - "cmovnzq %%r12, %%r8 \n\t" - "cmovnzq %%r13, %%r9 \n\t" - "cmovnzq %%r14, %%r10 \n\t" - "cmovnzq %%r15, %%r11 \n\t" - "movq 0(%0), %%r12 \n\t" - "movq 8(%0), %%r13 \n\t" - "movq 16(%0), %%r14 \n\t" - "movq 24(%0), %%r15 \n\t" - "movq %%r8, 32(%1) \n\t" - "movq %%r9, 40(%1) \n\t" - "movq %%r10, 48(%1) \n\t" - "movq %%r11, 56(%1) \n\t" - "movq %%r12, 0(%1) \n\t" - "movq %%r13, 8(%1) \n\t" - "movq %%r14, 16(%1) \n\t" - "movq %%r15, 24(%1) \n\t" - : - : "r"(src), - "r"(dest), - "r"(predicate), - [modulus_0] "i"(twice_modulus_0), - [modulus_1] "i"(twice_modulus_1), - [modulus_2] "i"(twice_modulus_2), - [modulus_3] "i"(twice_modulus_3) - : "%r8", "%r9", "%r10", "%r11", "%r12", "%r13", "%r14", "%r15", "memory", "cc"); -#endif - } else { - if (predicate) { // NOLINT - Fq::__copy(src->x, dest->x); - dest->y = -src->y; - } else { - copy_affine(*src, *dest); - } - } -} - -} // namespace bb - -#endif diff --git a/barretenberg/cpp/src/barretenberg/ecc/groups/group_impl_int128.tcc b/barretenberg/cpp/src/barretenberg/ecc/groups/group_impl_int128.tcc deleted file mode 100644 index 761cbe7d1334..000000000000 --- a/barretenberg/cpp/src/barretenberg/ecc/groups/group_impl_int128.tcc +++ /dev/null @@ -1,34 +0,0 @@ -#pragma once - -#ifdef DISABLE_ASM - -#include "barretenberg/ecc/groups/group.hpp" -#include - -namespace bb { - -// // copies src into dest. n.b. both src and dest must be aligned on 32 byte boundaries -// template -// inline void group::copy(const affine_element* src, affine_element* -// dest) -// { -// *dest = *src; -// } - -// // copies src into dest. n.b. both src and dest must be aligned on 32 byte boundaries -// template -// inline void group::copy(const element* src, element* dest) -// { -// *dest = *src; -// } - -template -inline void group::conditional_negate_affine(const affine_element* src, - affine_element* dest, - uint64_t predicate) -{ - *dest = predicate ? -(*src) : (*src); -} -} // namespace bb - -#endif diff --git a/barretenberg/cpp/src/barretenberg/ecc/groups/wnaf.hpp b/barretenberg/cpp/src/barretenberg/ecc/groups/wnaf.hpp index e6c32a388f8a..9ec8606b9788 100644 --- a/barretenberg/cpp/src/barretenberg/ecc/groups/wnaf.hpp +++ b/barretenberg/cpp/src/barretenberg/ecc/groups/wnaf.hpp @@ -7,93 +7,60 @@ #pragma once #include "barretenberg/numeric/bitop/get_msb.hpp" #include -#include // NOLINTBEGIN(readability-implicit-bool-conversion) + +/** + * @brief Fixed-window non-adjacent form (WNAF) scalar decomposition for elliptic curve scalar multiplication. + * + * @details WNAF decomposes a scalar into a sequence of odd signed digits in the range [-(2^w - 1), 2^w - 1], + * where w = wnaf_bits. Each digit is packed into a uint64_t entry with the following bit layout: + * + * Bit 63 32 31 30 0 + * ┌────────────────────────┬────┬──────────────────────────┐ + * │ point_index │sign│ table_index │ + * └────────────────────────┴────┴──────────────────────────┘ + * + * - table_index (bits 0-30): abs(digit) >> 1. Since all digits are odd, the absolute value is always + * 2*k + 1 for some k, so table_index = k. This directly indexes a precomputed + * lookup table of odd multiples [1·P, 3·P, 5·P, ...]. + * In the Pippenger MSM path, this is the bucket index that determines which + * bucket the point is accumulated into. + * - sign (bit 31): 0 = positive digit, 1 = negative digit (negate the point's y-coordinate). + * - point_index (bits 32-63): identifies which input point this entry refers to. In single-scalar + * multiplication this is 0. In multi-scalar multiplication (Pippenger), + * this records which of the N input points the entry belongs to, since the + * schedule is later sorted by bucket and the original point ordering is lost. + * + * The template `wnaf_round` / `fixed_wnaf` variants shift point_index into bits 32+ internally. + * The runtime `fixed_wnaf` variant expects the caller to pass point_index pre-shifted. + */ namespace bb::wnaf { constexpr size_t SCALAR_BITS = 127; #define WNAF_SIZE(x) ((bb::wnaf::SCALAR_BITS + (x) - 1) / (x)) // NOLINT(cppcoreguidelines-macro-usage) -constexpr size_t get_optimal_bucket_width(const size_t num_points) -{ - if (num_points >= 14617149) { - return 21; - } - if (num_points >= 1139094) { - return 18; - } - // if (num_points >= 100000) - if (num_points >= 155975) { - return 15; - } - if (num_points >= 144834) - // if (num_points >= 100000) - { - return 14; - } - if (num_points >= 25067) { - return 12; - } - if (num_points >= 13926) { - return 11; - } - if (num_points >= 7659) { - return 10; - } - if (num_points >= 2436) { - return 9; - } - if (num_points >= 376) { - return 7; - } - if (num_points >= 231) { - return 6; - } - if (num_points >= 97) { - return 5; - } - if (num_points >= 35) { - return 4; - } - if (num_points >= 10) { - return 3; - } - if (num_points >= 2) { - return 2; - } - return 1; -} -constexpr size_t get_num_buckets(const size_t num_points) -{ - const size_t bits_per_bucket = get_optimal_bucket_width(num_points / 2); - return 1UL << bits_per_bucket; -} - -constexpr size_t get_num_rounds(const size_t num_points) -{ - const size_t bits_per_bucket = get_optimal_bucket_width(num_points / 2); - return WNAF_SIZE(bits_per_bucket + 1); -} +/** + * @brief Extract a window of `bits` consecutive bits starting at `bit_position` from a 128-bit scalar. + * + * @tparam bits The number of bits in the window (0 returns 0). + * @tparam bit_position The starting bit index within the 128-bit scalar. + * @param scalar Pointer to a 128-bit scalar stored as two consecutive uint64_t limbs (little-endian word order). + * @return The integer value of the extracted bit window. + * + * @details We determine which 64-bit limb(s) the window touches by computing + * lo_limb_idx = bit_position / 64 and hi_limb_idx = (bit_position + bits - 1) / 64. + * For the low limb, we right-shift by (bit_position % 64) to align the desired bits to position 0. + * If the window fits entirely within one limb (lo_limb_idx == hi_limb_idx), we simply mask off `bits` bits. + * Otherwise, the window straddles two limbs: we left-shift the high limb by (64 - bit_position % 64) to place + * its contributing bits adjacent to the low limb's bits, OR them together, and then mask to `bits` bits. + */ template inline uint64_t get_wnaf_bits_const(const uint64_t* scalar) noexcept { if constexpr (bits == 0) { return 0ULL; } else { - /** - * we want to take a 128 bit scalar and shift it down by (bit_position). - * We then wish to mask out `bits` number of bits. - * Low limb contains first 64 bits, so we wish to shift this limb by (bit_position mod 64), which is also - * (bit_position & 63) If we require bits from the high limb, these need to be shifted left, not right. Actual - * bit position of bit in high limb = `b`. Desired position = 64 - (amount we shifted low limb by) = 64 - - * (bit_position & 63) - * - * So, step 1: - * get low limb and shift right by (bit_position & 63) - * get high limb and shift left by (64 - (bit_position & 63)) - * - */ constexpr size_t lo_limb_idx = bit_position / 64; constexpr size_t hi_limb_idx = (bit_position + bits - 1) / 64; constexpr uint64_t lo_shift = bit_position & 63UL; @@ -110,21 +77,17 @@ template inline uint64_t get_wnaf_bits_const( } } +/** + * @brief A variant of the previous function that the bit position and number of bits are provided at runtime. + * + * @param scalar Pointer to a 128-bit scalar stored as two consecutive uint64_t limbs (little-endian word order). + * @param bits The number of bits in the window (0 returns 0). + * @param bit_position The starting bit index within the 128-bit scalar. + * @return The integer value of the extracted bit window. + */ inline uint64_t get_wnaf_bits(const uint64_t* scalar, const uint64_t bits, const uint64_t bit_position) noexcept { - /** - * we want to take a 128 bit scalar and shift it down by (bit_position). - * We then wish to mask out `bits` number of bits. - * Low limb contains first 64 bits, so we wish to shift this limb by (bit_position mod 64), which is also - * (bit_position & 63) If we require bits from the high limb, these need to be shifted left, not right. Actual bit - * position of bit in high limb = `b`. Desired position = 64 - (amount we shifted low limb by) = 64 - (bit_position - * & 63) - * - * So, step 1: - * get low limb and shift right by (bit_position & 63) - * get high limb and shift left by (64 - (bit_position & 63)) - * - */ + const auto lo_limb_idx = static_cast(bit_position >> 6); const auto hi_limb_idx = static_cast((bit_position + bits - 1) >> 6); const uint64_t lo_shift = bit_position & 63UL; @@ -138,35 +101,11 @@ inline uint64_t get_wnaf_bits(const uint64_t* scalar, const uint64_t bits, const return (lo & bit_mask) | (hi & hi_mask); } -inline void fixed_wnaf_packed( - const uint64_t* scalar, uint64_t* wnaf, bool& skew_map, const uint64_t point_index, const size_t wnaf_bits) noexcept -{ - skew_map = ((scalar[0] & 1) == 0); - uint64_t previous = get_wnaf_bits(scalar, wnaf_bits, 0) + static_cast(skew_map); - const size_t wnaf_entries = (SCALAR_BITS + wnaf_bits - 1) / wnaf_bits; - - for (size_t round_i = 1; round_i < wnaf_entries - 1; ++round_i) { - uint64_t slice = get_wnaf_bits(scalar, wnaf_bits, round_i * wnaf_bits); - uint64_t predicate = ((slice & 1UL) == 0UL); - wnaf[(wnaf_entries - round_i)] = - ((((previous - (predicate << (wnaf_bits /*+ 1*/))) ^ (0UL - predicate)) >> 1UL) | (predicate << 31UL)) | - (point_index); - previous = slice + predicate; - } - size_t final_bits = SCALAR_BITS - (wnaf_bits * (wnaf_entries - 1)); - uint64_t slice = get_wnaf_bits(scalar, final_bits, (wnaf_entries - 1) * wnaf_bits); - uint64_t predicate = ((slice & 1UL) == 0UL); - - wnaf[1] = ((((previous - (predicate << (wnaf_bits /*+ 1*/))) ^ (0UL - predicate)) >> 1UL) | (predicate << 31UL)) | - (point_index); - wnaf[0] = ((slice + predicate) >> 1UL) | (point_index); -} - /** * @brief Performs fixed-window non-adjacent form (WNAF) computation for scalar multiplication. * - * WNAF is a method for representing integers which optimizes the number of non-zero terms, which in turn optimizes - * the number of point doublings in scalar multiplication, in turn aiding efficiency. + * @details WNAF is a method for representing integers which optimizes the number of non-zero terms, which in turn + * optimizes the number of point doublings in scalar multiplication, in turn aiding efficiency. * * @param scalar Pointer to 128-bit scalar for which WNAF is to be computed. * @param wnaf Pointer to num_points+1 size array where the computed WNAF will be stored. @@ -182,16 +121,29 @@ inline void fixed_wnaf(const uint64_t* scalar, const uint64_t num_points, const size_t wnaf_bits) noexcept { + // If the scalar is even, we set the skew map to true. The skew is used to subtract a base point from the msm result + // in case scalar is even. skew_map = ((scalar[0] & 1) == 0); + // The first slice is the least significant slice of the scalar. uint64_t previous = get_wnaf_bits(scalar, wnaf_bits, 0) + static_cast(skew_map); const size_t wnaf_entries = (SCALAR_BITS + wnaf_bits - 1) / wnaf_bits; + // For the rest we start a rolling window of wnaf_bits bits, and compute the wnaf slice. for (size_t round_i = 1; round_i < wnaf_entries - 1; ++round_i) { uint64_t slice = get_wnaf_bits(scalar, wnaf_bits, round_i * wnaf_bits); + // Check if the slice is even. This will be used to borrow from the previous slice. uint64_t predicate = ((slice & 1UL) == 0UL); + // If the current slice is odd (predicate=0), the WNAF digit is simply `previous`. + // If even (predicate=1), we borrow: subtract 2^wnaf_bits from `previous` to get a + // negative value, then negate via XOR with all-ones (two's complement identity: + // -x = ~x + 1, but we immediately shift right by 1, absorbing the +1 since the + // result is guaranteed odd). The >> 1 converts from the raw odd value to a bucket + // index (e.g., value 5 → bucket 2, value 7 → bucket 3). Bit 31 stores the sign + // (1 = negative), and the upper bits carry point_index for multi-scalar indexing. wnaf[(wnaf_entries - round_i) * num_points] = - ((((previous - (predicate << (wnaf_bits /*+ 1*/))) ^ (0UL - predicate)) >> 1UL) | (predicate << 31UL)) | + ((((previous - (predicate << wnaf_bits)) ^ (0UL - predicate)) >> 1UL) | (predicate << 31UL)) | (point_index); + // Carry the borrow into the next window: if we borrowed, add 1 to the current slice. previous = slice + predicate; } size_t final_bits = SCALAR_BITS - (wnaf_bits * (wnaf_entries - 1)); @@ -199,151 +151,19 @@ inline void fixed_wnaf(const uint64_t* scalar, uint64_t predicate = ((slice & 1UL) == 0UL); wnaf[num_points] = - ((((previous - (predicate << (wnaf_bits /*+ 1*/))) ^ (0UL - predicate)) >> 1UL) | (predicate << 31UL)) | - (point_index); + ((((previous - (predicate << (wnaf_bits))) ^ (0UL - predicate)) >> 1UL) | (predicate << 31UL)) | (point_index); wnaf[0] = ((slice + predicate) >> 1UL) | (point_index); } /** - * Current flow... - * - * If a wnaf entry is even, we add +1 to it, and subtract 32 from the previous entry. - * This works if the previous entry is odd. If we recursively apply this process, starting at the least significant - *window, this will always be the case. - * - * However, we want to skip over windows that are 0, which poses a problem. - * - * Scenario 1: even window followed by 0 window followed by any window 'x' - * - * We can't add 1 to the even window and subtract 32 from the 0 window, as we don't have a bucket that maps to -32 - * This means that we have to identify whether we are going to borrow 32 from 'x', requiring us to look at least 2 - *steps ahead - * - * Scenario 2: <0> <0> - * - * This problem proceeds indefinitely - if we have adjacent 0 windows, we do not know whether we need to track a - *borrow flag until we identify the next non-zero window - * - * Scenario 3: <0> - * - * This one works... - * - * Ok, so we should be a bit more limited with when we don't include window entries. - * The goal here is to identify short scalars, so we want to identify the most significant non-zero window - **/ -inline uint64_t get_num_scalar_bits(const uint64_t* scalar) -{ - const uint64_t msb_1 = numeric::get_msb(scalar[1]); - const uint64_t msb_0 = numeric::get_msb(scalar[0]); - - const uint64_t scalar_1_mask = (0ULL - (scalar[1] > 0)); - const uint64_t scalar_0_mask = (0ULL - (scalar[0] > 0)) & ~scalar_1_mask; - - const uint64_t msb = (scalar_1_mask & (msb_1 + 64)) | (scalar_0_mask & (msb_0)); - return msb; -} - -/** - * How to compute an x-bit wnaf slice? - * - * Iterate over number of slices in scalar. - * For each slice, if slice is even, ADD +1 to current slice and SUBTRACT 2^x from previous slice. - * (for 1st slice we instead add +1 and set the scalar's 'skew' value to 'true' (i.e. need to subtract 1 from it at the - * end of our scalar mul algo)) - * - * In *wnaf we store the following: - * 1. bits 0-30: ABSOLUTE value of wnaf (i.e. -3 goes to 3) - * 2. bit 31: 'predicate' bool (i.e. does the wnaf value need to be negated?) - * 3. bits 32-63: position in a point array that describes the elliptic curve point this wnaf slice is referencing - * - * N.B. IN OUR STDLIB ALGORITHMS THE SKEW VALUE REPRESENTS AN ADDITION NOT A SUBTRACTION (i.e. we add +1 at the end of - * the scalar mul algo we don't sub 1) (this is to eliminate situations which could produce the point at infinity as an - * output as our circuit logic cannot accommodate this edge case). - * - * Credits: Zac W. - * - * @param scalar Pointer to the 128-bit non-montgomery scalar that is supposed to be transformed into wnaf - * @param wnaf Pointer to output array that needs to accommodate enough 64-bit WNAF entries - * @param skew_map Reference to output skew value, which if true shows that the point should be added once at the end of - * computation - * @param wnaf_round_counts Pointer to output array specifying the number of points participating in each round - * @param point_index The index of the point that should be multiplied by this scalar in the point array - * @param num_points Total points in the MSM (2*num_initial_points) + * @brief Recursive WNAF round for a fixed 127-bit scalar (SCALAR_BITS). * + * @details Processes one window per recursive call, using compile-time unrolling via `round_i`. + * Uses the runtime `get_wnaf_bits` for bit extraction. The WNAF output array is interleaved: + * entry for round `r` is stored at index `(wnaf_entries - r) << log2(num_points)`, so that + * entries for the same round across different points are contiguous for cache locality. + * Each entry packs: bits [0..30] = lookup table index, bit 31 = sign, bits [32..63] = point_index. */ -inline void fixed_wnaf_with_counts(const uint64_t* scalar, - uint64_t* wnaf, - bool& skew_map, - uint64_t* wnaf_round_counts, - const uint64_t point_index, - const uint64_t num_points, - const size_t wnaf_bits) noexcept -{ - const size_t max_wnaf_entries = (SCALAR_BITS + wnaf_bits - 1) / wnaf_bits; - if ((scalar[0] | scalar[1]) == 0ULL) { - skew_map = false; - for (size_t round_i = 0; round_i < max_wnaf_entries; ++round_i) { - wnaf[(round_i)*num_points] = 0xffffffffffffffffULL; - } - return; - } - const auto current_scalar_bits = static_cast(get_num_scalar_bits(scalar) + 1); - skew_map = ((scalar[0] & 1) == 0); - uint64_t previous = get_wnaf_bits(scalar, wnaf_bits, 0) + static_cast(skew_map); - const auto wnaf_entries = static_cast((current_scalar_bits + wnaf_bits - 1) / wnaf_bits); - - if (wnaf_entries == 1) { - wnaf[(max_wnaf_entries - 1) * num_points] = (previous >> 1UL) | (point_index); - ++wnaf_round_counts[max_wnaf_entries - 1]; - for (size_t j = wnaf_entries; j < max_wnaf_entries; ++j) { - wnaf[(max_wnaf_entries - 1 - j) * num_points] = 0xffffffffffffffffULL; - } - return; - } - - // If there are several windows - for (size_t round_i = 1; round_i < wnaf_entries - 1; ++round_i) { - - // Get a bit slice - uint64_t slice = get_wnaf_bits(scalar, wnaf_bits, round_i * wnaf_bits); - - // Get the predicate (last bit is zero) - uint64_t predicate = ((slice & 1UL) == 0UL); - - // Update round count - ++wnaf_round_counts[max_wnaf_entries - round_i]; - - // Calculate entry value - // If the last bit of current slice is 1, we simply put the previous value with the point index - // If the last bit of the current slice is 0, we negate everything, so that we subtract from the WNAF form and - // make it 0 - wnaf[(max_wnaf_entries - round_i) * num_points] = - ((((previous - (predicate << (wnaf_bits /*+ 1*/))) ^ (0UL - predicate)) >> 1UL) | (predicate << 31UL)) | - (point_index); - - // Update the previous value to the next windows - previous = slice + predicate; - } - // The final iteration for top bits - auto final_bits = static_cast(current_scalar_bits - (wnaf_bits * (wnaf_entries - 1))); - uint64_t slice = get_wnaf_bits(scalar, final_bits, (wnaf_entries - 1) * wnaf_bits); - uint64_t predicate = ((slice & 1UL) == 0UL); - - ++wnaf_round_counts[(max_wnaf_entries - wnaf_entries + 1)]; - wnaf[((max_wnaf_entries - wnaf_entries + 1) * num_points)] = - ((((previous - (predicate << (wnaf_bits /*+ 1*/))) ^ (0UL - predicate)) >> 1UL) | (predicate << 31UL)) | - (point_index); - - // Saving top bits - ++wnaf_round_counts[max_wnaf_entries - wnaf_entries]; - wnaf[(max_wnaf_entries - wnaf_entries) * num_points] = ((slice + predicate) >> 1UL) | (point_index); - - // Fill all unused slots with -1 - for (size_t j = wnaf_entries; j < max_wnaf_entries; ++j) { - wnaf[(max_wnaf_entries - 1 - j) * num_points] = 0xffffffffffffffffULL; - } -} - template inline void wnaf_round(uint64_t* scalar, uint64_t* wnaf, const uint64_t point_index, const uint64_t previous) noexcept { @@ -354,21 +174,29 @@ inline void wnaf_round(uint64_t* scalar, uint64_t* wnaf, const uint64_t point_in uint64_t slice = get_wnaf_bits(scalar, wnaf_bits, round_i * wnaf_bits); uint64_t predicate = ((slice & 1UL) == 0UL); wnaf[(wnaf_entries - round_i) << log2_num_points] = - ((((previous - (predicate << (wnaf_bits /*+ 1*/))) ^ (0UL - predicate)) >> 1UL) | (predicate << 31UL)) | + ((((previous - (predicate << wnaf_bits)) ^ (0UL - predicate)) >> 1UL) | (predicate << 31UL)) | (point_index << 32UL); wnaf_round(scalar, wnaf, point_index, slice + predicate); } else { constexpr size_t final_bits = SCALAR_BITS - (SCALAR_BITS / wnaf_bits) * wnaf_bits; uint64_t slice = get_wnaf_bits(scalar, final_bits, (wnaf_entries - 1) * wnaf_bits); - // uint64_t slice = get_wnaf_bits_const(scalar); uint64_t predicate = ((slice & 1UL) == 0UL); wnaf[num_points] = - ((((previous - (predicate << (wnaf_bits /*+ 1*/))) ^ (0UL - predicate)) >> 1UL) | (predicate << 31UL)) | + ((((previous - (predicate << wnaf_bits)) ^ (0UL - predicate)) >> 1UL) | (predicate << 31UL)) | (point_index << 32UL); wnaf[0] = ((slice + predicate) >> 1UL) | (point_index << 32UL); } } +/** + * @brief Recursive WNAF round for an arbitrary-width scalar. + * + * @details Same algorithm as the SCALAR_BITS overload above, but parametrized by `scalar_bits` so it can + * handle scalars of any bit width (e.g., after an endomorphism split produces shorter scalars). + * Uses the compile-time `get_wnaf_bits_const` for bit extraction since all parameters are template constants. + * Correctly handles the edge case where `scalar_bits` is an exact multiple of `wnaf_bits` (the final + * window is a full `wnaf_bits` wide rather than the remainder). + */ template inline void wnaf_round(uint64_t* scalar, uint64_t* wnaf, const uint64_t point_index, const uint64_t previous) noexcept { @@ -379,7 +207,7 @@ inline void wnaf_round(uint64_t* scalar, uint64_t* wnaf, const uint64_t point_in uint64_t slice = get_wnaf_bits_const(scalar); uint64_t predicate = ((slice & 1UL) == 0UL); wnaf[(wnaf_entries - round_i) << log2_num_points] = - ((((previous - (predicate << (wnaf_bits /*+ 1*/))) ^ (0UL - predicate)) >> 1UL) | (predicate << 31UL)) | + ((((previous - (predicate << wnaf_bits)) ^ (0UL - predicate)) >> 1UL) | (predicate << 31UL)) | (point_index << 32UL); wnaf_round(scalar, wnaf, point_index, slice + predicate); } else { @@ -389,41 +217,12 @@ inline void wnaf_round(uint64_t* scalar, uint64_t* wnaf, const uint64_t point_in uint64_t slice = get_wnaf_bits_const(scalar); uint64_t predicate = ((slice & 1UL) == 0UL); wnaf[num_points] = - ((((previous - (predicate << (wnaf_bits /*+ 1*/))) ^ (0UL - predicate)) >> 1UL) | (predicate << 31UL)) | + ((((previous - (predicate << wnaf_bits)) ^ (0UL - predicate)) >> 1UL) | (predicate << 31UL)) | (point_index << 32UL); wnaf[0] = ((slice + predicate) >> 1UL) | (point_index << 32UL); } } -template -inline void wnaf_round_packed(const uint64_t* scalar, - uint64_t* wnaf, - const uint64_t point_index, - const uint64_t previous) noexcept -{ - constexpr size_t wnaf_entries = (SCALAR_BITS + wnaf_bits - 1) / wnaf_bits; - - if constexpr (round_i < wnaf_entries - 1) { - uint64_t slice = get_wnaf_bits(scalar, wnaf_bits, round_i * wnaf_bits); - // uint64_t slice = get_wnaf_bits_const(scalar); - uint64_t predicate = ((slice & 1UL) == 0UL); - wnaf[(wnaf_entries - round_i)] = - ((((previous - (predicate << (wnaf_bits /*+ 1*/))) ^ (0UL - predicate)) >> 1UL) | (predicate << 31UL)) | - (point_index); - wnaf_round_packed(scalar, wnaf, point_index, slice + predicate); - } else { - constexpr size_t final_bits = SCALAR_BITS - (SCALAR_BITS / wnaf_bits) * wnaf_bits; - uint64_t slice = get_wnaf_bits(scalar, final_bits, (wnaf_entries - 1) * wnaf_bits); - // uint64_t slice = get_wnaf_bits_const(scalar); - uint64_t predicate = ((slice & 1UL) == 0UL); - wnaf[1] = - ((((previous - (predicate << (wnaf_bits /*+ 1*/))) ^ (0UL - predicate)) >> 1UL) | (predicate << 31UL)) | - (point_index); - - wnaf[0] = ((slice + predicate) >> 1UL) | (point_index); - } -} - template inline void fixed_wnaf(uint64_t* scalar, uint64_t* wnaf, bool& skew_map, const size_t point_index) noexcept { @@ -440,80 +239,6 @@ inline void fixed_wnaf(uint64_t* scalar, uint64_t* wnaf, bool& skew_map, const s wnaf_round(scalar, wnaf, point_index, previous); } -template -inline void wnaf_round_with_restricted_first_slice(uint64_t* scalar, - uint64_t* wnaf, - const uint64_t point_index, - const uint64_t previous) noexcept -{ - constexpr size_t wnaf_entries = (scalar_bits + wnaf_bits - 1) / wnaf_bits; - constexpr auto log2_num_points = static_cast(numeric::get_msb(static_cast(num_points))); - constexpr size_t bits_in_first_slice = scalar_bits % wnaf_bits; - if constexpr (round_i == 1) { - uint64_t slice = get_wnaf_bits_const(scalar); - uint64_t predicate = ((slice & 1UL) == 0UL); - - wnaf[(wnaf_entries - round_i) << log2_num_points] = - ((((previous - (predicate << (bits_in_first_slice /*+ 1*/))) ^ (0UL - predicate)) >> 1UL) | - (predicate << 31UL)) | - (point_index << 32UL); - if (round_i == 1) { - std::cerr << "writing value " << std::hex << wnaf[(wnaf_entries - round_i) << log2_num_points] << std::dec - << " at index " << ((wnaf_entries - round_i) << log2_num_points) << std::endl; - } - wnaf_round_with_restricted_first_slice( - scalar, wnaf, point_index, slice + predicate); - - } else if constexpr (round_i < wnaf_entries - 1) { - uint64_t slice = get_wnaf_bits_const(scalar); - uint64_t predicate = ((slice & 1UL) == 0UL); - wnaf[(wnaf_entries - round_i) << log2_num_points] = - ((((previous - (predicate << (wnaf_bits /*+ 1*/))) ^ (0UL - predicate)) >> 1UL) | (predicate << 31UL)) | - (point_index << 32UL); - wnaf_round_with_restricted_first_slice( - scalar, wnaf, point_index, slice + predicate); - } else { - uint64_t slice = get_wnaf_bits_const(scalar); - uint64_t predicate = ((slice & 1UL) == 0UL); - wnaf[num_points] = - ((((previous - (predicate << (wnaf_bits /*+ 1*/))) ^ (0UL - predicate)) >> 1UL) | (predicate << 31UL)) | - (point_index << 32UL); - wnaf[0] = ((slice + predicate) >> 1UL) | (point_index << 32UL); - } -} - -template -inline void fixed_wnaf_with_restricted_first_slice(uint64_t* scalar, - uint64_t* wnaf, - bool& skew_map, - const size_t point_index) noexcept -{ - constexpr size_t bits_in_first_slice = num_bits % wnaf_bits; - std::cerr << "bits in first slice = " << bits_in_first_slice << std::endl; - skew_map = ((scalar[0] & 1) == 0); - uint64_t previous = get_wnaf_bits_const(scalar) + static_cast(skew_map); - std::cerr << "previous = " << previous << std::endl; - wnaf_round_with_restricted_first_slice(scalar, wnaf, point_index, previous); -} - -// template -// inline void fixed_wnaf_packed(const uint64_t* scalar, -// uint64_t* wnaf, -// bool& skew_map, -// const uint64_t point_index) noexcept -// { -// skew_map = ((scalar[0] & 1) == 0); -// uint64_t previous = get_wnaf_bits_const(scalar) + (uint64_t)skew_map; -// wnaf_round_packed(scalar, wnaf, point_index, previous); -// } - -// template -// inline constexpr std::array fixed_wnaf(const uint64_t *scalar) const noexcept -// { -// bool skew_map = ((scalar[0] * 1) == 0); -// uint64_t previous = get_wnaf_bits_const(scalar) + (uint64_t)skew_map; -// std::array result; -// } } // namespace bb::wnaf // NOLINTEND(readability-implicit-bool-conversion) diff --git a/barretenberg/cpp/src/barretenberg/ecc/groups/wnaf.test.cpp b/barretenberg/cpp/src/barretenberg/ecc/groups/wnaf.test.cpp index 7890577d64c1..b91504a99abb 100644 --- a/barretenberg/cpp/src/barretenberg/ecc/groups/wnaf.test.cpp +++ b/barretenberg/cpp/src/barretenberg/ecc/groups/wnaf.test.cpp @@ -14,26 +14,90 @@ namespace { void recover_fixed_wnaf(const uint64_t* wnaf, bool skew, uint64_t& hi, uint64_t& lo, size_t wnaf_bits) { - size_t wnaf_entries = (127 + wnaf_bits - 1) / wnaf_bits; - uint128_t scalar = 0; // (uint128_t)(skew); - for (int i = 0; i < static_cast(wnaf_entries); ++i) { - uint64_t entry_formatted = wnaf[static_cast(i)]; - bool negative = (entry_formatted >> 31) != 0U; - uint64_t entry = ((entry_formatted & 0x0fffffffU) << 1) + 1; + const size_t wnaf_entries = (127 + wnaf_bits - 1) / wnaf_bits; + const uint64_t max_table_index = (1UL << (wnaf_bits - 1)) - 1; + + for (size_t i = 0; i < wnaf_entries; ++i) { + uint64_t entry = wnaf[i]; + uint64_t table_index = entry & 0x7fffffffUL; + bool sign = ((entry >> 31) & 1) != 0U; + uint64_t point_index_bits = entry >> 32; + + EXPECT_LE(table_index, max_table_index) + << "entry " << i << ": table_index " << table_index << " exceeds max " << max_table_index; + + // The most significant digit is always positive by construction (no sign bit is OR'd in). + if (i == 0) { + EXPECT_FALSE(sign) << "entry 0 (most significant digit) must be positive"; + } + + // All current callers use point_index=0, so bits 32-63 should be clear. + EXPECT_EQ(point_index_bits, 0UL) << "entry " << i << ": unexpected non-zero point_index bits"; + } + + // Recover the scalar: sum signed odd digits at their positional weights, then subtract skew. + uint128_t scalar = 0; + for (size_t i = 0; i < wnaf_entries; ++i) { + uint64_t entry_formatted = wnaf[i]; + bool negative = ((entry_formatted >> 31) & 1) != 0U; + uint64_t digit = ((entry_formatted & 0x7fffffffUL) << 1) + 1; + auto shift = static_cast(wnaf_bits * (wnaf_entries - 1 - i)); if (negative) { - scalar -= (static_cast(entry)) - << static_cast(wnaf_bits * (wnaf_entries - 1 - static_cast(i))); + scalar -= static_cast(digit) << shift; } else { - scalar += (static_cast(entry)) - << static_cast(wnaf_bits * (wnaf_entries - 1 - static_cast(i))); + scalar += static_cast(digit) << shift; } } scalar -= static_cast(skew); hi = static_cast(scalar >> static_cast(64)); - lo = static_cast(static_cast(scalar & static_cast(0xffff'ffff'ffff'ffff))); + lo = static_cast(scalar & static_cast(0xffff'ffff'ffff'ffff)); } } // namespace +TEST(wnaf, GetWnafBitsConstLimbBoundary) +{ + // scalar[0] bits 59-63 = 1,0,1,0,1 and scalar[1] bits 0-4 = 1,0,1,0,1 + // Full bit pattern around the boundary (bit 63 | bit 64): + // bit: ...59 60 61 62 63 | 64 65 66 67 68 69... + // val: ... 1 0 1 0 1 | 1 0 1 0 1 0... + const uint64_t scalar[2] = { 0xA800000000000000ULL, 0x0000000000000015ULL }; + + // Window starts at bit 63 — straddles the limb boundary (2 bits from lo, 3 from hi) + // bits 63,64,65,66,67 = 1,1,0,1,0 → 1 + 2 + 0 + 8 + 0 = 11 + EXPECT_EQ((wnaf::get_wnaf_bits_const<5, 63>(scalar)), 11ULL); + + // Window starts at bit 64 — exactly at the hi limb start + // bits 64,65,66,67,68 = 1,0,1,0,1 → 1 + 0 + 4 + 0 + 16 = 21 + EXPECT_EQ((wnaf::get_wnaf_bits_const<5, 64>(scalar)), 21ULL); + + // Window starts at bit 65 — one past the boundary, entirely in hi limb + // bits 65,66,67,68,69 = 0,1,0,1,0 → 0 + 2 + 0 + 8 + 0 = 10 + EXPECT_EQ((wnaf::get_wnaf_bits_const<5, 65>(scalar)), 10ULL); +} + +TEST(wnaf, WnafPowerOfTwo) +{ + // Powers of 2 are all even (skew = true) and have a single 1-bit with all lower bits zero, + // so every window below the leading bit is even, forcing borrows to cascade through all rounds. + auto test_power_of_two_scalar = [](uint64_t lo, uint64_t hi) { + uint64_t buffer[2] = { lo, hi }; + uint64_t wnaf_out[WNAF_SIZE(5)] = { 0 }; + bool skew = false; + wnaf::fixed_wnaf<1, 5>(buffer, wnaf_out, skew, 0); + EXPECT_TRUE(skew); // all powers of 2 are even + uint64_t recovered_hi = 0; + uint64_t recovered_lo = 0; + recover_fixed_wnaf(wnaf_out, skew, recovered_hi, recovered_lo, 5); + EXPECT_EQ(lo, recovered_lo); + EXPECT_EQ(hi, recovered_hi); + }; + + test_power_of_two_scalar(2ULL, 0ULL); // 2^1: smallest even, borrows cascade through all 26 windows + test_power_of_two_scalar(1ULL << 32, 0ULL); // 2^32: mid-lo-limb + test_power_of_two_scalar(0ULL, 1ULL); // 2^64: exactly at the limb boundary + test_power_of_two_scalar(0ULL, 1ULL << 62); // 2^126: near the 127-bit maximum +} + TEST(wnaf, WnafZero) { uint64_t buffer[2]{ 0, 0 };