diff --git a/barretenberg/cpp/src/barretenberg/stdlib/honk_verifier/ultra_recursive_verifier.cpp b/barretenberg/cpp/src/barretenberg/stdlib/honk_verifier/ultra_recursive_verifier.cpp index a6967859ae15..f013197028f1 100644 --- a/barretenberg/cpp/src/barretenberg/stdlib/honk_verifier/ultra_recursive_verifier.cpp +++ b/barretenberg/cpp/src/barretenberg/stdlib/honk_verifier/ultra_recursive_verifier.cpp @@ -100,7 +100,7 @@ UltraRecursiveVerifier_::Output UltraRecursiveVerifier_::verify_ // TODO(https://github.com/AztecProtocol/barretenberg/issues/995): generate this challenge properly. typename Curve::ScalarField recursion_separator = Curve::ScalarField::from_witness_index(builder, builder->add_variable(42)); - agg_obj.aggregate(nested_agg_obj, recursion_separator); + agg_obj.template aggregate(nested_agg_obj, recursion_separator); // Execute Sumcheck Verifier and extract multivariate opening point u = (u_0, ..., u_{d-1}) and purported // multivariate evaluations at u @@ -143,11 +143,11 @@ UltraRecursiveVerifier_::Output UltraRecursiveVerifier_::verify_ pairing_points[0] = pairing_points[0].normalize(); pairing_points[1] = pairing_points[1].normalize(); // TODO(https://github.com/AztecProtocol/barretenberg/issues/995): generate recursion separator challenge properly. - agg_obj.aggregate(pairing_points, recursion_separator); + agg_obj.template aggregate(pairing_points, recursion_separator); output.agg_obj = std::move(agg_obj); // Extract the IPA claim from the public inputs - // Parse out the nested IPA claim using key->ipa_claim_public_input_indices and runs the native IPA verifier. + // Parse out the nested IPA claim using key->ipa_claim_public_input_indices and run the native IPA verifier. if constexpr (HasIPAAccumulator) { const auto recover_fq_from_public_inputs = [](std::array& limbs) { for (size_t k = 0; k < Curve::BaseField::NUM_LIMBS; k++) { diff --git a/barretenberg/cpp/src/barretenberg/stdlib/plonk_recursion/aggregation_state/aggregation_state.hpp b/barretenberg/cpp/src/barretenberg/stdlib/plonk_recursion/aggregation_state/aggregation_state.hpp index 4eba0fa247f3..af1b30725b7a 100644 --- a/barretenberg/cpp/src/barretenberg/stdlib/plonk_recursion/aggregation_state/aggregation_state.hpp +++ b/barretenberg/cpp/src/barretenberg/stdlib/plonk_recursion/aggregation_state/aggregation_state.hpp @@ -20,17 +20,36 @@ template struct aggregation_state { { return P0 == other.P0 && P1 == other.P1; }; - + template void aggregate(aggregation_state const& other, typename Curve::ScalarField recursion_separator) { - P0 += other.P0 * recursion_separator; - P1 += other.P1 * recursion_separator; + if constexpr (std::is_same_v) { + P0 += other.P0 * recursion_separator; + P1 += other.P1 * recursion_separator; + } else { + // Save gates using short scalars. We don't apply `bn254_endo_batch_mul` to the vector {1, + // recursion_separator} directly to avoid edge cases. + typename Curve::Group point_to_aggregate = other.P0.scalar_mul(recursion_separator, 128); + P0 += point_to_aggregate; + point_to_aggregate = other.P1.scalar_mul(recursion_separator, 128); + P1 += point_to_aggregate; + } } + template void aggregate(std::array const& other, typename Curve::ScalarField recursion_separator) { - P0 += other[0] * recursion_separator; - P1 += other[1] * recursion_separator; + if constexpr (std::is_same_v) { + P0 += other[0] * recursion_separator; + P1 += other[1] * recursion_separator; + } else { + // Save gates using short scalars. We don't apply `bn254_endo_batch_mul` to the vector {1, + // recursion_separator} directly to avoid edge cases. + typename Curve::Group point_to_aggregate = other[0].scalar_mul(recursion_separator, 128); + P0 += point_to_aggregate; + point_to_aggregate = other[1].scalar_mul(recursion_separator, 128); + P1 += point_to_aggregate; + } } PairingPointAccumulatorIndices get_witness_indices() diff --git a/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup.hpp b/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup.hpp index eb27853ba67f..ee11b3b75e61 100644 --- a/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup.hpp +++ b/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup.hpp @@ -143,6 +143,7 @@ template class element { result.y.assert_is_in_field(); return result; } + element scalar_mul(const Fr& scalar, const size_t max_num_bits = 0) const; element reduce() const { @@ -525,7 +526,10 @@ template class element { num_fives = num_points / 5; num_sixes = 0; // size-6 table is expensive and only benefits us if creating them reduces the number of total tables - if (num_fives * 5 == (num_points - 1)) { + if (num_points == 1) { + num_fives = 0; + num_sixes = 0; + } else if (num_fives * 5 == (num_points - 1)) { num_fives -= 1; num_sixes = 1; } else if (num_fives * 5 == (num_points - 2) && num_fives >= 2) { diff --git a/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup.test.cpp b/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup.test.cpp index 037a470649b3..6a96f54f7598 100644 --- a/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup.test.cpp +++ b/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup.test.cpp @@ -414,6 +414,142 @@ template class stdlib_biggroup : public testing::Test { EXPECT_CIRCUIT_CORRECTNESS(builder); } + // Test short scalar mul with variable even bit length. For efficiency, it's split into two tests. + static void test_short_scalar_mul_2_126() + { + Builder builder; + const size_t max_num_bits = 128; + + // We only test even bit lengths, because `bn254_endo_batch_mul` used in 'scalar_mul' can't handle odd lengths. + for (size_t i = 2; i < max_num_bits; i += 2) { + affine_element input(element::random_element()); + // Get a random 256 integer + uint256_t scalar_raw = engine.get_random_uint256(); + // Produce a length =< i scalar. + scalar_raw = scalar_raw >> (256 - i); + fr scalar = fr(scalar_raw); + + // Avoid multiplication by 0 that may occur when `i` is small + if (scalar == fr(0)) { + scalar += 1; + }; + + element_ct P = element_ct::from_witness(&builder, input); + scalar_ct x = scalar_ct::from_witness(&builder, scalar); + + // Set input tags + x.set_origin_tag(challenge_origin_tag); + P.set_origin_tag(submitted_value_origin_tag); + + std::cerr << "gates before mul " << builder.get_estimated_num_finalized_gates() << std::endl; + // Multiply using specified scalar length + element_ct c = P.scalar_mul(x, i); + std::cerr << "builder aftr mul " << builder.get_estimated_num_finalized_gates() << std::endl; + affine_element c_expected(element(input) * scalar); + + // Check the result of the multiplication has a tag that's the union of inputs' tags + EXPECT_EQ(c.get_origin_tag(), first_two_merged_tag); + fq c_x_result(c.x.get_value().lo); + fq c_y_result(c.y.get_value().lo); + + EXPECT_EQ(c_x_result, c_expected.x); + + EXPECT_EQ(c_y_result, c_expected.y); + } + + EXPECT_CIRCUIT_CORRECTNESS(builder); + } + + static void test_short_scalar_mul_128_252() + { + Builder builder; + const size_t max_num_bits = 254; + + // We only test even bit lengths, because `bn254_endo_batch_mul` used in 'scalar_mul' can't handle odd lengths. + for (size_t i = 128; i < max_num_bits; i += 2) { + affine_element input(element::random_element()); + // Get a random 256-bit integer + uint256_t scalar_raw = engine.get_random_uint256(); + // Produce a length =< i scalar. + scalar_raw = scalar_raw >> (256 - i); + fr scalar = fr(scalar_raw); + + element_ct P = element_ct::from_witness(&builder, input); + scalar_ct x = scalar_ct::from_witness(&builder, scalar); + + // Set input tags + x.set_origin_tag(challenge_origin_tag); + P.set_origin_tag(submitted_value_origin_tag); + + std::cerr << "gates before mul " << builder.get_estimated_num_finalized_gates() << std::endl; + // Multiply using specified scalar length + element_ct c = P.scalar_mul(x, i); + std::cerr << "builder aftr mul " << builder.get_estimated_num_finalized_gates() << std::endl; + affine_element c_expected(element(input) * scalar); + + // Check the result of the multiplication has a tag that's the union of inputs' tags + EXPECT_EQ(c.get_origin_tag(), first_two_merged_tag); + fq c_x_result(c.x.get_value().lo); + fq c_y_result(c.y.get_value().lo); + + EXPECT_EQ(c_x_result, c_expected.x); + + EXPECT_EQ(c_y_result, c_expected.y); + } + + EXPECT_CIRCUIT_CORRECTNESS(builder); + } + + static void test_short_scalar_mul_infinity() + { + // We check that a point at infinity preserves `is_point_at_infinity()` flag after being multiplied against a + // short scalar and also check that the number of gates in this case is equal to the number of gates spent on a + // finite point. + + // Populate test points. + std::vector points(2); + + points[0] = element::infinity(); + points[1] = element::random_element(); + // Containter for gate counts. + std::vector gates(2); + + // We initialize this flag as `true`, because the first result is expected to be the point at infinity. + bool expect_infinity = true; + + for (auto [point, num_gates] : zip_view(points, gates)) { + Builder builder; + + const size_t max_num_bits = 128; + // Get a random 256-bit integer + uint256_t scalar_raw = engine.get_random_uint256(); + // Produce a length =< max_num_bits scalar. + scalar_raw = scalar_raw >> (256 - max_num_bits); + fr scalar = fr(scalar_raw); + + element_ct P = element_ct::from_witness(&builder, point); + scalar_ct x = scalar_ct::from_witness(&builder, scalar); + + // Set input tags + x.set_origin_tag(challenge_origin_tag); + P.set_origin_tag(submitted_value_origin_tag); + + std::cerr << "gates before mul " << builder.get_estimated_num_finalized_gates() << std::endl; + element_ct c = P.scalar_mul(x, max_num_bits); + std::cerr << "builder aftr mul " << builder.get_estimated_num_finalized_gates() << std::endl; + num_gates = builder.get_estimated_num_finalized_gates(); + // Check the result of the multiplication has a tag that's the union of inputs' tags + EXPECT_EQ(c.get_origin_tag(), first_two_merged_tag); + + EXPECT_EQ(c.is_point_at_infinity().get_value(), expect_infinity); + EXPECT_CIRCUIT_CORRECTNESS(builder); + // The second point is finite, hence we flip the flag + expect_infinity = false; + } + // Check that the numbers of gates are equal in both cases. + EXPECT_EQ(gates[0], gates[1]); + } + static void test_twin_mul() { Builder builder; @@ -950,13 +1086,25 @@ template class stdlib_biggroup : public testing::Test { static void test_compute_naf() { Builder builder = Builder(); - size_t num_repetitions(32); - for (size_t i = 0; i < num_repetitions; i++) { - fr scalar_val = fr::random_element(); + size_t max_num_bits = 254; + // Our design of NAF and the way it is used assumes the even length of scalars. + for (size_t length = 2; length < max_num_bits; length += 2) { + + fr scalar_val; + + uint256_t scalar_raw = engine.get_random_uint256(); + scalar_raw = scalar_raw >> (256 - length); + + scalar_val = fr(scalar_raw); + + // NAF with short scalars doesn't handle 0 + if (scalar_val == fr(0)) { + scalar_val += 1; + }; scalar_ct scalar = scalar_ct::from_witness(&builder, scalar_val); // Set tag for scalar scalar.set_origin_tag(submitted_value_origin_tag); - auto naf = element_ct::compute_naf(scalar); + auto naf = element_ct::compute_naf(scalar, length); for (const auto& bit : naf) { // Check that the tag is propagated to bits @@ -964,12 +1112,13 @@ template class stdlib_biggroup : public testing::Test { } // scalar = -naf[254] + \sum_{i=0}^{253}(1-2*naf[i]) 2^{253-i} fr reconstructed_val(0); - for (size_t i = 0; i < 254; i++) { - reconstructed_val += (fr(1) - fr(2) * fr(naf[i].witness_bool)) * fr(uint256_t(1) << (253 - i)); + for (size_t i = 0; i < length; i++) { + reconstructed_val += (fr(1) - fr(2) * fr(naf[i].witness_bool)) * fr(uint256_t(1) << (length - 1 - i)); }; - reconstructed_val -= fr(naf[254].witness_bool); + reconstructed_val -= fr(naf[length].witness_bool); EXPECT_EQ(scalar_val, reconstructed_val); } + EXPECT_CIRCUIT_CORRECTNESS(builder); } @@ -1614,6 +1763,33 @@ HEAVY_TYPED_TEST(stdlib_biggroup, mul) { TestFixture::test_mul(); } + +HEAVY_TYPED_TEST(stdlib_biggroup, short_scalar_mul_2_126_bits) +{ + if constexpr (HasGoblinBuilder) { + GTEST_SKIP(); + } else { + TestFixture::test_short_scalar_mul_2_126(); + } +} +HEAVY_TYPED_TEST(stdlib_biggroup, short_scalar_mul_128_252_bits) +{ + if constexpr (HasGoblinBuilder) { + GTEST_SKIP(); + } else { + TestFixture::test_short_scalar_mul_128_252(); + } +} + +HEAVY_TYPED_TEST(stdlib_biggroup, short_scalar_mul_infinity) +{ + if constexpr (HasGoblinBuilder) { + GTEST_SKIP(); + } else { + TestFixture::test_short_scalar_mul_infinity(); + } +} + HEAVY_TYPED_TEST(stdlib_biggroup, twin_mul) { if constexpr (HasGoblinBuilder) { diff --git a/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_bn254.hpp b/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_bn254.hpp index eb3798830ed8..9f397c45230d 100644 --- a/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_bn254.hpp +++ b/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_bn254.hpp @@ -226,7 +226,9 @@ element element::bn254_endo_batch_mul(const std::vec const std::vector& small_scalars, const size_t max_num_small_bits) { - ASSERT(max_num_small_bits >= 128); + + ASSERT(max_num_small_bits % 2 == 0); + const size_t num_big_points = big_points.size(); const size_t num_small_points = small_points.size(); C* ctx = nullptr; diff --git a/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_impl.hpp b/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_impl.hpp index b133b9e0f9fd..2545b0a2034c 100644 --- a/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_impl.hpp +++ b/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_impl.hpp @@ -836,15 +836,28 @@ element element::batch_mul(const std::vector element element::operator*(const Fr& scalar) const { + // Use `scalar_mul` method without specifying the length of `scalar`. + return scalar_mul(scalar); +} + +template +/** + * @brief Implements scalar multiplication that supports short scalars. + * For multiple scalar multiplication use one of the `batch_mul` methods to save gates. + * @param scalar A field element. If `max_num_bits`>0, the length of the scalar must not exceed `max_num_bits`. + * @param max_num_bits Even integer < 254. Default value 0 corresponds to scalar multiplication by scalars of + * unspecified length. + * @return element + */ +element element::scalar_mul(const Fr& scalar, const size_t max_num_bits) const +{ + ASSERT(max_num_bits % 2 == 0); /** * * Let's say we have some curve E defined over a field Fq. The order of E is p, which is prime. @@ -868,27 +881,31 @@ element element::operator*(const Fr& scalar) const * specifics. * **/ + OriginTag tag{}; + tag = OriginTag(tag, OriginTag(this->get_origin_tag(), scalar.get_origin_tag())); - constexpr uint64_t num_rounds = Fr::modulus.get_msb() + 1; - - std::vector naf_entries = compute_naf(scalar); + bool_ct is_point_at_infinity = this->is_point_at_infinity(); - const auto offset_generators = compute_offset_generators(num_rounds); + const size_t num_rounds = (max_num_bits == 0) ? Fr::modulus.get_msb() + 1 : max_num_bits; - element accumulator = *this + offset_generators.first; + element result; + if (max_num_bits != 0) { + // The case of short scalars + result = element::bn254_endo_batch_mul({}, {}, { *this }, { scalar }, num_rounds); + } else { + // The case of arbitrary length scalars + result = element::bn254_endo_batch_mul({ *this }, { scalar }, {}, {}, num_rounds); + }; - for (size_t i = 1; i < num_rounds; ++i) { - bool_ct predicate = naf_entries[i]; - bigfield y_test = y.conditional_negate(predicate); - element to_add(x, y_test); - accumulator = accumulator.montgomery_ladder(to_add); - } + // Handle point at infinity + result.x = Fq::conditional_assign(is_point_at_infinity, x, result.x); + result.y = Fq::conditional_assign(is_point_at_infinity, y, result.y); - element skew_output = accumulator - (*this); + result.set_point_at_infinity(is_point_at_infinity); - Fq out_x = accumulator.x.conditional_select(skew_output.x, naf_entries[num_rounds]); - Fq out_y = accumulator.y.conditional_select(skew_output.y, naf_entries[num_rounds]); + // Propagate the origin tag + result.set_origin_tag(tag); - return element(out_x, out_y) - element(offset_generators.second); + return result; } } // namespace bb::stdlib::element_default diff --git a/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_nafs.hpp b/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_nafs.hpp index a29e0c3716e0..352eb14d6c55 100644 --- a/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_nafs.hpp +++ b/barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_nafs.hpp @@ -441,7 +441,7 @@ std::vector> element::compute_wnaf(const Fr& scalar) // updates multiplicative constants without computing new witnesses. This ensures the low accumulator will not // underflow // - // Once we hvae reconstructed an Fr element out of our accumulators, + // Once we have reconstructed an Fr element out of our accumulators, // we ALSO construct an Fr element from the constant offset terms we left out // We then subtract off the constant term and call `Fr::assert_is_in_field` to reduce the value modulo // Fr::modulus @@ -488,6 +488,9 @@ std::vector> element::compute_wnaf(const Fr& scalar) template std::vector> element::compute_naf(const Fr& scalar, const size_t max_num_bits) { + // We are not handling the case of odd bit lengths here. + ASSERT(max_num_bits % 2 == 0); + C* ctx = scalar.context; uint512_t scalar_multiplier_512 = uint512_t(uint256_t(scalar.get_value()) % Fr::modulus); uint256_t scalar_multiplier = scalar_multiplier_512.lo; @@ -576,9 +579,23 @@ std::vector> element::compute_naf(const Fr& scalar, cons } return std::make_pair(positive_accumulator, negative_accumulator); }; - const size_t midpoint = num_rounds - Fr::NUM_LIMB_BITS * 2; - auto hi_accumulators = reconstruct_half_naf(&naf_entries[0], midpoint); - auto lo_accumulators = reconstruct_half_naf(&naf_entries[midpoint], num_rounds - midpoint); + const size_t midpoint = + (num_rounds > Fr::NUM_LIMB_BITS * 2) ? num_rounds - Fr::NUM_LIMB_BITS * 2 : num_rounds / 2; + + std::pair, field_t> hi_accumulators; + std::pair, field_t> lo_accumulators; + + if (num_rounds > Fr::NUM_LIMB_BITS * 2) { + hi_accumulators = reconstruct_half_naf(&naf_entries[0], midpoint); + lo_accumulators = reconstruct_half_naf(&naf_entries[midpoint], num_rounds - midpoint); + + } else { + // If the number of rounds is smaller than Fr::NUM_LIMB_BITS, the high bits of the resulting Fr element are + // 0. + const field_t zero = field_t::from_witness_index(ctx, 0); + lo_accumulators = reconstruct_half_naf(&naf_entries[0], num_rounds); + hi_accumulators = std::make_pair(zero, zero); + } lo_accumulators.second = lo_accumulators.second + field_t(naf_entries[num_rounds]);