diff --git a/barretenberg/cpp/src/barretenberg/bbapi/bbapi_chonk.cpp b/barretenberg/cpp/src/barretenberg/bbapi/bbapi_chonk.cpp index 1af3d5bc6c82..0fb97f41e8cb 100644 --- a/barretenberg/cpp/src/barretenberg/bbapi/bbapi_chonk.cpp +++ b/barretenberg/cpp/src/barretenberg/bbapi/bbapi_chonk.cpp @@ -1,6 +1,7 @@ #include "barretenberg/bbapi/bbapi_chonk.hpp" #include "barretenberg/chonk/chonk_verifier.hpp" #include "barretenberg/chonk/mock_circuit_producer.hpp" +#include "barretenberg/chonk/proof_compression.hpp" #include "barretenberg/common/log.hpp" #include "barretenberg/common/serialize.hpp" #include "barretenberg/common/throw_or_abort.hpp" @@ -253,4 +254,17 @@ ChonkStats::Response ChonkStats::execute([[maybe_unused]] BBApiRequest& request) return response; } +ChonkCompressProof::Response ChonkCompressProof::execute(const BBApiRequest& /*request*/) && +{ + BB_BENCH_NAME(MSGPACK_SCHEMA_NAME); + return { .compressed_proof = ProofCompressor::compress_chonk_proof(proof) }; +} + +ChonkDecompressProof::Response ChonkDecompressProof::execute(const BBApiRequest& /*request*/) && +{ + BB_BENCH_NAME(MSGPACK_SCHEMA_NAME); + size_t mega_num_pub = ProofCompressor::compressed_mega_num_public_inputs(compressed_proof.size()); + return { .proof = ProofCompressor::decompress_chonk_proof(compressed_proof, mega_num_pub) }; +} + } // namespace bb::bbapi diff --git a/barretenberg/cpp/src/barretenberg/bbapi/bbapi_chonk.hpp b/barretenberg/cpp/src/barretenberg/bbapi/bbapi_chonk.hpp index 7d003c9096fa..d156a5a0c6c3 100644 --- a/barretenberg/cpp/src/barretenberg/bbapi/bbapi_chonk.hpp +++ b/barretenberg/cpp/src/barretenberg/bbapi/bbapi_chonk.hpp @@ -239,4 +239,48 @@ struct ChonkStats { bool operator==(const ChonkStats&) const = default; }; +/** + * @struct ChonkCompressProof + * @brief Compress a Chonk proof to a compact byte representation + * + * @details Uses point compression and uniform 32-byte encoding to reduce proof size (~1.72x). + */ +struct ChonkCompressProof { + static constexpr const char MSGPACK_SCHEMA_NAME[] = "ChonkCompressProof"; + + struct Response { + static constexpr const char MSGPACK_SCHEMA_NAME[] = "ChonkCompressProofResponse"; + std::vector compressed_proof; + MSGPACK_FIELDS(compressed_proof); + bool operator==(const Response&) const = default; + }; + + ChonkProof proof; + Response execute(const BBApiRequest& request = {}) &&; + MSGPACK_FIELDS(proof); + bool operator==(const ChonkCompressProof&) const = default; +}; + +/** + * @struct ChonkDecompressProof + * @brief Decompress a compressed Chonk proof back to field elements + * + * @details Derives mega_num_public_inputs from the compressed size automatically. + */ +struct ChonkDecompressProof { + static constexpr const char MSGPACK_SCHEMA_NAME[] = "ChonkDecompressProof"; + + struct Response { + static constexpr const char MSGPACK_SCHEMA_NAME[] = "ChonkDecompressProofResponse"; + ChonkProof proof; + MSGPACK_FIELDS(proof); + bool operator==(const Response&) const = default; + }; + + std::vector compressed_proof; + Response execute(const BBApiRequest& request = {}) &&; + MSGPACK_FIELDS(compressed_proof); + bool operator==(const ChonkDecompressProof&) const = default; +}; + } // namespace bb::bbapi diff --git a/barretenberg/cpp/src/barretenberg/bbapi/bbapi_execute.hpp b/barretenberg/cpp/src/barretenberg/bbapi/bbapi_execute.hpp index f459e600bcf2..ab16da99a508 100644 --- a/barretenberg/cpp/src/barretenberg/bbapi/bbapi_execute.hpp +++ b/barretenberg/cpp/src/barretenberg/bbapi/bbapi_execute.hpp @@ -28,6 +28,8 @@ using Command = NamedUnion #include "barretenberg/chonk/chonk_verifier.hpp" +#include "barretenberg/chonk/proof_compression.hpp" #include "barretenberg/chonk/test_bench_shared.hpp" #include "barretenberg/common/google_bb_bench.hpp" @@ -58,10 +59,43 @@ BENCHMARK_DEFINE_F(ChonkBench, Full)(benchmark::State& state) } } +/** + * @brief Benchmark proof compression (prover-side cost) + */ +BENCHMARK_DEFINE_F(ChonkBench, ProofCompress)(benchmark::State& state) +{ + size_t NUM_APP_CIRCUITS = 1; + auto precomputed_vks = precompute_vks(NUM_APP_CIRCUITS); + auto [proof, vk_and_hash] = accumulate_and_prove_with_precomputed_vks(NUM_APP_CIRCUITS, precomputed_vks); + + for (auto _ : state) { + benchmark::DoNotOptimize(ProofCompressor::compress_chonk_proof(proof)); + } +} + +/** + * @brief Benchmark proof decompression (verifier-side cost) + */ +BENCHMARK_DEFINE_F(ChonkBench, ProofDecompress)(benchmark::State& state) +{ + size_t NUM_APP_CIRCUITS = 1; + auto precomputed_vks = precompute_vks(NUM_APP_CIRCUITS); + auto [proof, vk_and_hash] = accumulate_and_prove_with_precomputed_vks(NUM_APP_CIRCUITS, precomputed_vks); + + auto compressed = ProofCompressor::compress_chonk_proof(proof); + size_t mega_num_pub_inputs = proof.mega_proof.size() - ChonkProof::HIDING_KERNEL_PROOF_LENGTH_WITHOUT_PUBLIC_INPUTS; + + for (auto _ : state) { + benchmark::DoNotOptimize(ProofCompressor::decompress_chonk_proof(compressed, mega_num_pub_inputs)); + } +} + #define ARGS Arg(ChonkBench::NUM_ITERATIONS_MEDIUM_COMPLEXITY)->Arg(2) BENCHMARK_REGISTER_F(ChonkBench, Full)->Unit(benchmark::kMillisecond)->ARGS; BENCHMARK_REGISTER_F(ChonkBench, VerificationOnly)->Unit(benchmark::kMillisecond); +BENCHMARK_REGISTER_F(ChonkBench, ProofCompress)->Unit(benchmark::kMillisecond); +BENCHMARK_REGISTER_F(ChonkBench, ProofDecompress)->Unit(benchmark::kMillisecond); } // namespace diff --git a/barretenberg/cpp/src/barretenberg/chonk/chonk.test.cpp b/barretenberg/cpp/src/barretenberg/chonk/chonk.test.cpp index d88d5d12a698..ee677801dc41 100644 --- a/barretenberg/cpp/src/barretenberg/chonk/chonk.test.cpp +++ b/barretenberg/cpp/src/barretenberg/chonk/chonk.test.cpp @@ -3,8 +3,10 @@ #include "barretenberg/chonk/chonk.hpp" #include "barretenberg/chonk/chonk_verifier.hpp" #include "barretenberg/chonk/mock_circuit_producer.hpp" +#include "barretenberg/chonk/proof_compression.hpp" #include "barretenberg/chonk/test_bench_shared.hpp" #include "barretenberg/common/assert.hpp" +#include "barretenberg/common/log.hpp" #include "barretenberg/common/mem.hpp" #include "barretenberg/common/test.hpp" #include "barretenberg/ecc/curves/grumpkin/grumpkin.hpp" @@ -563,3 +565,33 @@ TEST_F(ChonkTests, MTailPropagationConsistency) { ChonkTests::test_hiding_kernel_io_propagation(HidingKernelIOField::ECC_OP_TABLES); } + +TEST_F(ChonkTests, ProofCompressionRoundtrip) +{ + TestSettings settings{ .log2_num_gates = SMALL_LOG_2_NUM_GATES }; + auto [proof, vk_and_hash] = accumulate_and_prove_ivc(/*num_app_circuits=*/1, settings); + + auto original_flat = proof.to_field_elements(); + info("Original proof size: ", original_flat.size(), " Fr elements (", original_flat.size() * 32, " bytes)"); + + auto compressed = ProofCompressor::compress_chonk_proof(proof); + double ratio = static_cast(original_flat.size() * 32) / static_cast(compressed.size()); + info("Compressed proof size: ", compressed.size(), " bytes"); + info("Compression ratio: ", ratio, "x"); + + // Compression should achieve at least 1.5x (commitments 4 Fr → 32 bytes, scalars 1:1) + EXPECT_GE(ratio, 1.5) << "Compression ratio " << ratio << "x is below the expected minimum of 1.5x"; + + size_t mega_num_pub_inputs = proof.mega_proof.size() - ChonkProof::HIDING_KERNEL_PROOF_LENGTH_WITHOUT_PUBLIC_INPUTS; + ChonkProof decompressed = ProofCompressor::decompress_chonk_proof(compressed, mega_num_pub_inputs); + + // Verify element-by-element roundtrip + auto decompressed_flat = decompressed.to_field_elements(); + ASSERT_EQ(decompressed_flat.size(), original_flat.size()); + for (size_t i = 0; i < original_flat.size(); i++) { + ASSERT_EQ(decompressed_flat[i], original_flat[i]) << "Mismatch at element " << i; + } + + // Verify the decompressed proof + EXPECT_TRUE(verify_chonk(decompressed, vk_and_hash)); +} diff --git a/barretenberg/cpp/src/barretenberg/chonk/proof_compression.hpp b/barretenberg/cpp/src/barretenberg/chonk/proof_compression.hpp new file mode 100644 index 000000000000..435cebeb8dd2 --- /dev/null +++ b/barretenberg/cpp/src/barretenberg/chonk/proof_compression.hpp @@ -0,0 +1,570 @@ +#pragma once + +#include "barretenberg/chonk/chonk_proof.hpp" +#include "barretenberg/common/assert.hpp" +#include "barretenberg/constants.hpp" +#include "barretenberg/ecc/curves/bn254/bn254.hpp" +#include "barretenberg/ecc/curves/grumpkin/grumpkin.hpp" +#include "barretenberg/eccvm/eccvm_flavor.hpp" +#include "barretenberg/flavor/mega_zk_flavor.hpp" +#include "barretenberg/honk/proof_system/types/proof.hpp" +#include "barretenberg/translator_vm/translator_flavor.hpp" +#include +#include +#include + +namespace bb { + +/** + * @brief Compresses Chonk proofs from vector to compact byte representations. + * + * Compression techniques: + * 1. Point compression: store only x-coordinate + sign bit (instead of x and y) + * 2. Fq-as-u256: store each Fq coordinate as 32 bytes (instead of 2 Fr for lo/hi split) + * 3. Fr-as-u256: store each Fr scalar as 32 bytes (uniform encoding) + * + * Every element compresses to exactly 32 bytes regardless of type: + * - BN254 commitment (4 Fr → 32 bytes): point compression on Fq coordinates + * - BN254 scalar (1 Fr → 32 bytes): direct u256 encoding + * - Grumpkin commitment (2 Fr → 32 bytes): point compression on Fr coordinates + * - Grumpkin scalar (2 Fr → 32 bytes): reconstruct Fq, write as u256 + */ +class ProofCompressor { + using Fr = curve::BN254::ScalarField; + using Fq = curve::BN254::BaseField; + + static constexpr uint256_t SIGN_BIT_MASK = uint256_t(1) << 255; + + // Fq values are stored as (lo, hi) Fr pairs split at 2*NUM_LIMB_BITS = 136 bits. + static constexpr uint64_t NUM_LIMB_BITS = 68; + static constexpr uint64_t FQ_SPLIT_BITS = NUM_LIMB_BITS * 2; // 136 + + /** @brief True if y is in the "upper half" of its field, used for point compression sign bit. */ + template static bool y_is_negative(const Field& y) + { + return uint256_t(y) > (uint256_t(Field::modulus) - 1) / 2; + } + + // ========================================================================= + // Serialization helpers + // ========================================================================= + + static void write_u256(std::vector& out, const uint256_t& val) + { + for (int i = 31; i >= 0; --i) { + out.push_back(static_cast(val.data[i / 8] >> (8 * (i % 8)))); + } + } + + static uint256_t read_u256(const std::vector& data, size_t& pos) + { + uint256_t val{ 0, 0, 0, 0 }; + for (int i = 31; i >= 0; --i) { + val.data[i / 8] |= static_cast(data[pos++]) << (8 * (i % 8)); + } + return val; + } + + static Fq reconstruct_fq(const Fr& lo, const Fr& hi) + { + return Fq(uint256_t(lo) + (uint256_t(hi) << FQ_SPLIT_BITS)); + } + + static std::pair split_fq(const Fq& val) + { + constexpr uint256_t LOWER_MASK = (uint256_t(1) << FQ_SPLIT_BITS) - 1; + const uint256_t v = uint256_t(val); + return { Fr(v & LOWER_MASK), Fr(v >> FQ_SPLIT_BITS) }; + } + + // ========================================================================= + // Walk functions — define proof layouts once for compress/decompress + // ========================================================================= + + /** + * @brief Walk a MegaZK proof (BN254, ZK sumcheck). + * @details Layout from MegaZKStructuredProofBase and sumcheck prover code. + */ + template + static void walk_mega_zk_proof(ScalarFn&& process_scalar, + CommitmentFn&& process_commitment, + size_t num_public_inputs) + { + constexpr size_t log_n = MegaZKFlavor::VIRTUAL_LOG_N; + + // Public inputs + for (size_t i = 0; i < num_public_inputs; i++) { + process_scalar(); + } + // Witness commitments (hiding poly + 24 mega witness = NUM_WITNESS_ENTITIES total) + for (size_t i = 0; i < MegaZKFlavor::NUM_WITNESS_ENTITIES; i++) { + process_commitment(); + } + // Libra concatenation commitment + process_commitment(); + // Libra sum + process_scalar(); + // Sumcheck round univariates + for (size_t i = 0; i < log_n * MegaZKFlavor::BATCHED_RELATION_PARTIAL_LENGTH; i++) { + process_scalar(); + } + // Sumcheck evaluations + for (size_t i = 0; i < MegaZKFlavor::NUM_ALL_ENTITIES; i++) { + process_scalar(); + } + // Libra claimed evaluation + process_scalar(); + // Libra grand sum + quotient commitments + process_commitment(); + process_commitment(); + // Gemini fold commitments + for (size_t i = 0; i < log_n - 1; i++) { + process_commitment(); + } + // Gemini fold evaluations + for (size_t i = 0; i < log_n; i++) { + process_scalar(); + } + // Small IPA evaluations (for ZK) + for (size_t i = 0; i < NUM_SMALL_IPA_EVALUATIONS; i++) { + process_scalar(); + } + // Shplonk Q + KZG W + process_commitment(); + process_commitment(); + } + + /** + * @brief Walk a Merge proof (42 Fr, all BN254). + * @details Layout from MergeProver::construct_proof. + */ + template + static void walk_merge_proof(ScalarFn&& process_scalar, CommitmentFn&& process_commitment) + { + // shift_size + process_scalar(); + // 4 merged table commitments + for (size_t i = 0; i < 4; i++) { + process_commitment(); + } + // Reversed batched left tables commitment + process_commitment(); + // 4 left + 4 right + 4 merged table evaluations + 1 reversed eval = 13 scalars + for (size_t i = 0; i < 13; i++) { + process_scalar(); + } + // Shplonk Q + KZG W + process_commitment(); + process_commitment(); + } + + /** + * @brief Walk an ECCVM proof (all Grumpkin). + * @details Layout from ECCVMFlavor::PROOF_LENGTH formula and ECCVM prover code. + * Grumpkin RoundUnivariateHandler commits to each round univariate and sends + * 2 evaluations (at 0 and 1), interleaved per round. + */ + template + static void walk_eccvm_proof(ScalarFn&& process_scalar, CommitmentFn&& process_commitment) + { + constexpr size_t log_n = CONST_ECCVM_LOG_N; + constexpr size_t num_witness = ECCVMFlavor::NUM_WITNESS_ENTITIES + ECCVMFlavor::NUM_MASKING_POLYNOMIALS; + + // Witness commitments (wires + derived + masking poly) + for (size_t i = 0; i < num_witness; i++) { + process_commitment(); + } + // Libra concatenation commitment + process_commitment(); + // Libra sum + process_scalar(); + // Sumcheck round univariates: per round, Grumpkin commits then sends 2 evaluations + for (size_t i = 0; i < log_n; i++) { + process_commitment(); // univariate commitment for round i + process_scalar(); // eval at 0 for round i + process_scalar(); // eval at 1 for round i + } + // Sumcheck evaluations + for (size_t i = 0; i < ECCVMFlavor::NUM_ALL_ENTITIES; i++) { + process_scalar(); + } + // Libra claimed evaluation + process_scalar(); + // Libra grand sum + quotient commitments + process_commitment(); + process_commitment(); + // Gemini fold commitments + for (size_t i = 0; i < log_n - 1; i++) { + process_commitment(); + } + // Gemini fold evaluations + for (size_t i = 0; i < log_n; i++) { + process_scalar(); + } + // Small IPA evaluations (for sumcheck libra) + for (size_t i = 0; i < NUM_SMALL_IPA_EVALUATIONS; i++) { + process_scalar(); + } + // Shplonk Q + process_commitment(); + + // --- Translation section --- + // Translator concatenated masking commitment + process_commitment(); + // 5 translation evaluations (op, Px, Py, z1, z2) + for (size_t i = 0; i < NUM_TRANSLATION_EVALUATIONS; i++) { + process_scalar(); + } + // Translation masking term evaluation + process_scalar(); + // Translation grand sum + quotient commitments + process_commitment(); + process_commitment(); + // Translation SmallSubgroupIPA evaluations + for (size_t i = 0; i < NUM_SMALL_IPA_EVALUATIONS; i++) { + process_scalar(); + } + // Translation Shplonk Q + process_commitment(); + } + + /** + * @brief Walk an IPA proof (64 Fr, all Grumpkin). + * @details IPA_PROOF_LENGTH = 4 * CONST_ECCVM_LOG_N + 4 + */ + template + static void walk_ipa_proof(ScalarFn&& process_scalar, CommitmentFn&& process_commitment) + { + // L and R commitments per round + for (size_t i = 0; i < CONST_ECCVM_LOG_N; i++) { + process_commitment(); // L_i + process_commitment(); // R_i + } + // G_0 commitment + process_commitment(); + // a_0 scalar + process_scalar(); + } + + /** + * @brief Walk a Translator proof (all BN254). + * @details Layout from TranslatorFlavor::PROOF_LENGTH formula. + */ + template + static void walk_translator_proof(ScalarFn&& process_scalar, CommitmentFn&& process_commitment) + { + constexpr size_t log_n = TranslatorFlavor::CONST_TRANSLATOR_LOG_N; + + // Gemini masking poly commitment + process_commitment(); + // Wire commitments: concatenated + ordered range constraints + for (size_t i = 0; i < TranslatorFlavor::NUM_COMMITMENTS_IN_PROOF; i++) { + process_commitment(); + } + // Z_PERM commitment + process_commitment(); + // Libra concatenation commitment + process_commitment(); + // Libra sum + process_scalar(); + // Sumcheck round univariates + for (size_t i = 0; i < log_n * TranslatorFlavor::BATCHED_RELATION_PARTIAL_LENGTH; i++) { + process_scalar(); + } + // Sumcheck evaluations (computable precomputed and concat evals excluded) + for (size_t i = 0; i < TranslatorFlavor::NUM_SENT_EVALUATIONS; i++) { + process_scalar(); + } + // Libra claimed evaluation + process_scalar(); + // Libra grand sum + quotient commitments + process_commitment(); + process_commitment(); + // Gemini fold commitments + for (size_t i = 0; i < log_n - 1; i++) { + process_commitment(); + } + // Gemini fold evaluations + for (size_t i = 0; i < log_n; i++) { + process_scalar(); + } + // Small IPA evaluations + for (size_t i = 0; i < NUM_SMALL_IPA_EVALUATIONS; i++) { + process_scalar(); + } + // Shplonk Q + KZG W + process_commitment(); + process_commitment(); + } + + /** + * @brief Walk a full Chonk proof (5 sub-proofs across two curves). + */ + template + static void walk_chonk_proof(BN254ScalarFn&& bn254_scalar, + BN254CommFn&& bn254_comm, + GrumpkinScalarFn&& grumpkin_scalar, + GrumpkinCommFn&& grumpkin_comm, + size_t mega_num_public_inputs) + { + walk_mega_zk_proof(bn254_scalar, bn254_comm, mega_num_public_inputs); + walk_merge_proof(bn254_scalar, bn254_comm); + walk_eccvm_proof(grumpkin_scalar, grumpkin_comm); + walk_ipa_proof(grumpkin_scalar, grumpkin_comm); + walk_translator_proof(bn254_scalar, bn254_comm); + } + + // ========================================================================= + // Walk count validation — ensure the constants used in walks match PROOF_LENGTH. + // These mirror the walk logic using the same constants; if a PROOF_LENGTH formula + // changes, the static_assert fires, prompting an update to the corresponding walk. + // ========================================================================= + + // Fr-elements per element type for each curve + static constexpr size_t BN254_FRS_PER_SCALAR = 1; + static constexpr size_t BN254_FRS_PER_COMM = 4; // Fq x,y each as (lo,hi) Fr pair + static constexpr size_t GRUMPKIN_FRS_PER_SCALAR = 2; // Fq stored as (lo,hi) Fr pair + static constexpr size_t GRUMPKIN_FRS_PER_COMM = 2; // Fr x,y coordinates + + // clang-format off + // MegaZK (without public inputs) — mirrors walk_mega_zk_proof with num_public_inputs=0 + static constexpr size_t EXPECTED_MEGA_ZK_FRS = + MegaZKFlavor::NUM_WITNESS_ENTITIES * BN254_FRS_PER_COMM + // witness comms + 1 * BN254_FRS_PER_COMM + // libra concat + 1 * BN254_FRS_PER_SCALAR + // libra sum + MegaZKFlavor::VIRTUAL_LOG_N * MegaZKFlavor::BATCHED_RELATION_PARTIAL_LENGTH * BN254_FRS_PER_SCALAR +// sumcheck univariates + MegaZKFlavor::NUM_ALL_ENTITIES * BN254_FRS_PER_SCALAR + // sumcheck evals + 1 * BN254_FRS_PER_SCALAR + // libra claimed eval + 2 * BN254_FRS_PER_COMM + // libra grand sum + quotient + (MegaZKFlavor::VIRTUAL_LOG_N - 1) * BN254_FRS_PER_COMM + // gemini folds + MegaZKFlavor::VIRTUAL_LOG_N * BN254_FRS_PER_SCALAR + // gemini evals + NUM_SMALL_IPA_EVALUATIONS * BN254_FRS_PER_SCALAR + // small IPA evals + 2 * BN254_FRS_PER_COMM; // shplonk Q + KZG W + static_assert(EXPECTED_MEGA_ZK_FRS == ChonkProof::HIDING_KERNEL_PROOF_LENGTH_WITHOUT_PUBLIC_INPUTS); + + // Merge — mirrors walk_merge_proof + static constexpr size_t EXPECTED_MERGE_FRS = + 1 * BN254_FRS_PER_SCALAR + // shift_size + 5 * BN254_FRS_PER_COMM + // 4 merged tables + 1 reversed batched left + 13 * BN254_FRS_PER_SCALAR + // evaluations + 2 * BN254_FRS_PER_COMM; // shplonk Q + KZG W + static_assert(EXPECTED_MERGE_FRS == MERGE_PROOF_SIZE); + + // ECCVM — mirrors walk_eccvm_proof + static constexpr size_t EXPECTED_ECCVM_FRS = + (ECCVMFlavor::NUM_WITNESS_ENTITIES + ECCVMFlavor::NUM_MASKING_POLYNOMIALS) * GRUMPKIN_FRS_PER_COMM + // witnesses + 1 * GRUMPKIN_FRS_PER_COMM + // libra concat + 1 * GRUMPKIN_FRS_PER_SCALAR + // libra sum + CONST_ECCVM_LOG_N * GRUMPKIN_FRS_PER_COMM + // sumcheck univariate comms + 2 * CONST_ECCVM_LOG_N * GRUMPKIN_FRS_PER_SCALAR + // sumcheck univariate evals (2 per round) + ECCVMFlavor::NUM_ALL_ENTITIES * GRUMPKIN_FRS_PER_SCALAR + // sumcheck evals + 1 * GRUMPKIN_FRS_PER_SCALAR + // libra claimed eval + 2 * GRUMPKIN_FRS_PER_COMM + // libra grand sum + quotient + (CONST_ECCVM_LOG_N - 1) * GRUMPKIN_FRS_PER_COMM + // gemini folds + CONST_ECCVM_LOG_N * GRUMPKIN_FRS_PER_SCALAR + // gemini evals + NUM_SMALL_IPA_EVALUATIONS * GRUMPKIN_FRS_PER_SCALAR + // small IPA evals + 1 * GRUMPKIN_FRS_PER_COMM + // shplonk Q + 1 * GRUMPKIN_FRS_PER_COMM + // translator masking comm + NUM_TRANSLATION_EVALUATIONS * GRUMPKIN_FRS_PER_SCALAR + // translation evals + 1 * GRUMPKIN_FRS_PER_SCALAR + // masking term eval + 2 * GRUMPKIN_FRS_PER_COMM + // translation grand sum + quotient + NUM_SMALL_IPA_EVALUATIONS * GRUMPKIN_FRS_PER_SCALAR + // translation small IPA evals + 1 * GRUMPKIN_FRS_PER_COMM; // translation shplonk Q + static_assert(EXPECTED_ECCVM_FRS == ECCVMFlavor::PROOF_LENGTH); + + // IPA — mirrors walk_ipa_proof + static constexpr size_t EXPECTED_IPA_FRS = + 2 * CONST_ECCVM_LOG_N * GRUMPKIN_FRS_PER_COMM + // L and R per round + 1 * GRUMPKIN_FRS_PER_COMM + // G_0 + 1 * GRUMPKIN_FRS_PER_SCALAR; // a_0 + static_assert(EXPECTED_IPA_FRS == IPA_PROOF_LENGTH); + + // Translator — mirrors walk_translator_proof + static constexpr size_t EXPECTED_TRANSLATOR_FRS = + 1 * BN254_FRS_PER_COMM + // gemini masking poly + TranslatorFlavor::NUM_COMMITMENTS_IN_PROOF * BN254_FRS_PER_COMM + // wire comms (concat + ordered) + 1 * BN254_FRS_PER_COMM + // z_perm + 1 * BN254_FRS_PER_COMM + // libra concat + 1 * BN254_FRS_PER_SCALAR + // libra sum + TranslatorFlavor::CONST_TRANSLATOR_LOG_N * TranslatorFlavor::BATCHED_RELATION_PARTIAL_LENGTH * BN254_FRS_PER_SCALAR + // sumcheck univariates + TranslatorFlavor::NUM_SENT_EVALUATIONS * BN254_FRS_PER_SCALAR + // sumcheck evals + 1 * BN254_FRS_PER_SCALAR + // libra claimed eval + 2 * BN254_FRS_PER_COMM + // libra grand sum + quotient + (TranslatorFlavor::CONST_TRANSLATOR_LOG_N - 1) * BN254_FRS_PER_COMM + // gemini folds + TranslatorFlavor::CONST_TRANSLATOR_LOG_N * BN254_FRS_PER_SCALAR + // gemini evals + NUM_SMALL_IPA_EVALUATIONS * BN254_FRS_PER_SCALAR + // small IPA evals + 2 * BN254_FRS_PER_COMM; // shplonk Q + KZG W + static_assert(EXPECTED_TRANSLATOR_FRS == TranslatorFlavor::PROOF_LENGTH); + // clang-format on + + public: + /** + * @brief Count the total compressed elements for a Chonk proof. + * Each element (scalar or commitment, either curve) compresses to exactly 32 bytes. + */ + static size_t compressed_element_count(size_t mega_num_public_inputs = 0) + { + size_t count = 0; + auto counter = [&]() { count++; }; + walk_chonk_proof(counter, counter, counter, counter, mega_num_public_inputs); + return count; + } + + /** + * @brief Derive mega_num_public_inputs from compressed proof size. + * @param compressed_bytes Total size of the compressed proof in bytes. + */ + static size_t compressed_mega_num_public_inputs(size_t compressed_bytes) + { + BB_ASSERT(compressed_bytes % 32 == 0); + size_t total_elements = compressed_bytes / 32; + size_t fixed_elements = compressed_element_count(0); + BB_ASSERT(total_elements >= fixed_elements); + return total_elements - fixed_elements; + } + + // ========================================================================= + // Chonk proof compression + // ========================================================================= + + static std::vector compress_chonk_proof(const ChonkProof& proof) + { + auto flat = proof.to_field_elements(); + std::vector out; + out.reserve(flat.size() * 32); // upper bound: every element compresses to 32 bytes + size_t offset = 0; + + // BN254 callbacks + auto bn254_scalar = [&]() { write_u256(out, uint256_t(flat[offset++])); }; + + auto bn254_comm = [&]() { + bool is_infinity = flat[offset].is_zero() && flat[offset + 1].is_zero() && flat[offset + 2].is_zero() && + flat[offset + 3].is_zero(); + if (is_infinity) { + write_u256(out, uint256_t(0)); + offset += 4; + return; + } + + Fq x = reconstruct_fq(flat[offset], flat[offset + 1]); + Fq y = reconstruct_fq(flat[offset + 2], flat[offset + 3]); + offset += 4; + + uint256_t x_val = uint256_t(x); + if (y_is_negative(y)) { + x_val |= SIGN_BIT_MASK; + } + write_u256(out, x_val); + }; + + // Grumpkin callbacks + // Grumpkin commitments have coordinates in BN254::ScalarField (Fr), so x and y are each 1 Fr. + auto grumpkin_comm = [&]() { + Fr x = flat[offset]; + Fr y = flat[offset + 1]; + offset += 2; + + if (x.is_zero() && y.is_zero()) { + write_u256(out, uint256_t(0)); + return; + } + + uint256_t x_val = uint256_t(x); + if (y_is_negative(y)) { + x_val |= SIGN_BIT_MASK; + } + write_u256(out, x_val); + }; + + // Grumpkin scalars are Fq values stored as (lo, hi) Fr pairs + auto grumpkin_scalar = [&]() { + Fq fq_val = reconstruct_fq(flat[offset], flat[offset + 1]); + offset += 2; + write_u256(out, uint256_t(fq_val)); + }; + + size_t mega_num_pub_inputs = + proof.mega_proof.size() - ChonkProof::HIDING_KERNEL_PROOF_LENGTH_WITHOUT_PUBLIC_INPUTS; + walk_chonk_proof(bn254_scalar, bn254_comm, grumpkin_scalar, grumpkin_comm, mega_num_pub_inputs); + BB_ASSERT(offset == flat.size()); + return out; + } + + static ChonkProof decompress_chonk_proof(const std::vector& compressed, size_t mega_num_public_inputs) + { + HonkProof flat; + size_t pos = 0; + + // BN254 callbacks + auto bn254_scalar = [&]() { flat.emplace_back(read_u256(compressed, pos)); }; + + auto bn254_comm = [&]() { + uint256_t raw = read_u256(compressed, pos); + bool sign = (raw & SIGN_BIT_MASK) != 0; + uint256_t x_val = raw & ~SIGN_BIT_MASK; + + if (x_val == uint256_t(0) && !sign) { + for (int j = 0; j < 4; j++) { + flat.emplace_back(Fr::zero()); + } + return; + } + + Fq x(x_val); + Fq y_squared = x * x * x + Bn254G1Params::b; + auto [is_square, y] = y_squared.sqrt(); + BB_ASSERT(is_square); + + if (y_is_negative(y) != sign) { + y = -y; + } + + auto [x_lo, x_hi] = split_fq(x); + auto [y_lo, y_hi] = split_fq(y); + flat.emplace_back(x_lo); + flat.emplace_back(x_hi); + flat.emplace_back(y_lo); + flat.emplace_back(y_hi); + }; + + // Grumpkin callbacks + auto grumpkin_comm = [&]() { + uint256_t raw = read_u256(compressed, pos); + bool sign = (raw & SIGN_BIT_MASK) != 0; + uint256_t x_val = raw & ~SIGN_BIT_MASK; + + if (x_val == uint256_t(0) && !sign) { + flat.emplace_back(Fr::zero()); + flat.emplace_back(Fr::zero()); + return; + } + + Fr x(x_val); + // Grumpkin curve: y² = x³ + b, where b = -17 (in BN254::ScalarField) + Fr y_squared = x * x * x + grumpkin::G1Params::b; + auto [is_square, y] = y_squared.sqrt(); + BB_ASSERT(is_square); + + if (y_is_negative(y) != sign) { + y = -y; + } + + flat.emplace_back(x); + flat.emplace_back(y); + }; + + auto grumpkin_scalar = [&]() { + uint256_t raw = read_u256(compressed, pos); + Fq fq_val(raw); + auto [lo, hi] = split_fq(fq_val); + flat.emplace_back(lo); + flat.emplace_back(hi); + }; + + walk_chonk_proof(bn254_scalar, bn254_comm, grumpkin_scalar, grumpkin_comm, mega_num_public_inputs); + BB_ASSERT(pos == compressed.size()); + return ChonkProof::from_field_elements(flat); + } +}; + +} // namespace bb diff --git a/barretenberg/cpp/src/barretenberg/commitment_schemes/gemini/gemini.hpp b/barretenberg/cpp/src/barretenberg/commitment_schemes/gemini/gemini.hpp index 3f1a41927378..fa441c6904a6 100644 --- a/barretenberg/cpp/src/barretenberg/commitment_schemes/gemini/gemini.hpp +++ b/barretenberg/cpp/src/barretenberg/commitment_schemes/gemini/gemini.hpp @@ -125,6 +125,7 @@ template class GeminiProver_ { class PolynomialBatcher { size_t full_batched_size = 0; // size of the full batched polynomial (generally the circuit size) + size_t actual_data_size_ = 0; // max end_index across all polynomials (actual data extent) Polynomial batched_unshifted; // linear combination of unshifted polynomials Polynomial batched_to_be_shifted_by_one; // linear combination of to-be-shifted polynomials @@ -133,10 +134,11 @@ template class GeminiProver_ { RefVector unshifted; // set of unshifted polynomials RefVector to_be_shifted_by_one; // set of polynomials to be left shifted by 1 - PolynomialBatcher(const size_t full_batched_size) + PolynomialBatcher(const size_t full_batched_size, const size_t actual_data_size = 0) : full_batched_size(full_batched_size) - , batched_unshifted(full_batched_size) - , batched_to_be_shifted_by_one(Polynomial::shiftable(full_batched_size)) + , actual_data_size_(actual_data_size == 0 ? full_batched_size : actual_data_size) + , batched_unshifted(actual_data_size_, full_batched_size) + , batched_to_be_shifted_by_one(Polynomial::shiftable(actual_data_size_, full_batched_size)) {} bool has_unshifted() const { return unshifted.size() > 0; } @@ -191,8 +193,8 @@ template class GeminiProver_ { */ std::pair compute_partially_evaluated_batch_polynomials(const Fr& r_challenge) { - // Initialize A₀₊ and compute A₀₊ += F as necessary - Polynomial A_0_pos(full_batched_size); // A₀₊ + // Initialize A₀₊ with only the actual data extent; virtual zeroes cover the rest + Polynomial A_0_pos(actual_data_size_, full_batched_size); // A₀₊ if (has_unshifted()) { A_0_pos += batched_unshifted; // A₀₊ += F diff --git a/barretenberg/cpp/src/barretenberg/commitment_schemes/shplonk/shplonk.hpp b/barretenberg/cpp/src/barretenberg/commitment_schemes/shplonk/shplonk.hpp index 0840c841da17..5325a763a011 100644 --- a/barretenberg/cpp/src/barretenberg/commitment_schemes/shplonk/shplonk.hpp +++ b/barretenberg/cpp/src/barretenberg/commitment_schemes/shplonk/shplonk.hpp @@ -61,10 +61,6 @@ template class ShplonkProver_ { max_poly_size = std::max(max_poly_size, claim.polynomial.size()); } } - // The polynomials in Sumcheck Round claims and Libra opening claims are generally not dyadic, - // so we round up to the next power of 2. - max_poly_size = numeric::round_up_power_2(max_poly_size); - // Q(X) = ∑ⱼ νʲ ⋅ ( fⱼ(X) − vⱼ) / ( X − xⱼ ) Polynomial Q(max_poly_size); Polynomial tmp(max_poly_size); diff --git a/barretenberg/cpp/src/barretenberg/ecc/groups/affine_element.hpp b/barretenberg/cpp/src/barretenberg/ecc/groups/affine_element.hpp index 369dbcb0d981..233934567ff3 100644 --- a/barretenberg/cpp/src/barretenberg/ecc/groups/affine_element.hpp +++ b/barretenberg/cpp/src/barretenberg/ecc/groups/affine_element.hpp @@ -17,6 +17,11 @@ #include namespace bb::group_elements { + +// MSB of the top 64-bit limb in a uint256_t (bit 255). Used in point compression to encode the +// y-coordinate parity bit, and cleared when recovering the x-coordinate. +static constexpr uint64_t UINT256_TOP_LIMB_MSB = 0x8000000000000000ULL; + template concept SupportsHashToCurve = T::can_hash_to_curve; template class alignas(64) affine_element { @@ -80,10 +85,6 @@ template class alignas(64) affine constexpr affine_element operator*(const Fr& exponent) const noexcept; - template > 255) == uint256_t(0), void>> - [[nodiscard]] constexpr uint256_t compress() const noexcept; - static constexpr affine_element infinity(); constexpr affine_element set_infinity() const noexcept; constexpr void self_set_infinity() noexcept; diff --git a/barretenberg/cpp/src/barretenberg/ecc/groups/affine_element.test.cpp b/barretenberg/cpp/src/barretenberg/ecc/groups/affine_element.test.cpp index 420ad908be68..3dd3b3f048cf 100644 --- a/barretenberg/cpp/src/barretenberg/ecc/groups/affine_element.test.cpp +++ b/barretenberg/cpp/src/barretenberg/ecc/groups/affine_element.test.cpp @@ -139,7 +139,10 @@ template class TestAffineElement : public testing::Test { { for (size_t i = 0; i < 10; i++) { affine_element P = affine_element(element::random_element()); - uint256_t compressed = P.compress(); + uint256_t compressed = uint256_t(P.x); + if (uint256_t(P.y).get_bit(0)) { + compressed.data[3] |= group_elements::UINT256_TOP_LIMB_MSB; + } affine_element Q = affine_element::from_compressed(compressed); EXPECT_EQ(P, Q); } @@ -168,8 +171,6 @@ template class TestAffineElement : public testing::Test { affine_element R(0, P.y); ASSERT_FALSE(P == R); } - // Regression test to ensure that the point at infinity is not equal to its coordinate-wise reduction, which may lie - // on the curve, depending on the y-coordinate. static void test_infinity_ordering_regression() { affine_element P(0, 1); diff --git a/barretenberg/cpp/src/barretenberg/ecc/groups/affine_element_impl.hpp b/barretenberg/cpp/src/barretenberg/ecc/groups/affine_element_impl.hpp index ba51cb33093a..294f4889125e 100644 --- a/barretenberg/cpp/src/barretenberg/ecc/groups/affine_element_impl.hpp +++ b/barretenberg/cpp/src/barretenberg/ecc/groups/affine_element_impl.hpp @@ -21,7 +21,7 @@ template constexpr affine_element affine_element::from_compressed(const uint256_t& compressed) noexcept { uint256_t x_coordinate = compressed; - x_coordinate.data[3] = x_coordinate.data[3] & (~0x8000000000000000ULL); + x_coordinate.data[3] = x_coordinate.data[3] & (~UINT256_TOP_LIMB_MSB); bool y_bit = compressed.get_bit(255); Fq x = Fq(x_coordinate); @@ -80,18 +80,6 @@ constexpr affine_element affine_element::operator*(const F return bb::group_elements::element(*this) * exponent; } -template -template - -constexpr uint256_t affine_element::compress() const noexcept -{ - uint256_t out(x); - if (uint256_t(y).get_bit(0)) { - out.data[3] = out.data[3] | 0x8000000000000000ULL; - } - return out; -} - template constexpr affine_element affine_element::infinity() { affine_element e{}; @@ -157,15 +145,9 @@ constexpr bool affine_element::operator==(const affine_element& other return !only_one_is_infinity && (both_infinity || ((x == other.x) && (y == other.y))); } -/** - * Comparison operators (for std::sort) - * - * @details CAUTION!! Don't use this operator. It has no meaning other than for use by std::sort. - **/ template constexpr bool affine_element::operator>(const affine_element& other) const noexcept { - // We are setting point at infinity to always be the lowest element if (is_point_at_infinity()) { return false; } diff --git a/barretenberg/cpp/src/barretenberg/ecc/groups/element.hpp b/barretenberg/cpp/src/barretenberg/ecc/groups/element.hpp index fc2afad7028d..b6ae9683eca4 100644 --- a/barretenberg/cpp/src/barretenberg/ecc/groups/element.hpp +++ b/barretenberg/cpp/src/barretenberg/ecc/groups/element.hpp @@ -59,7 +59,6 @@ template class alignas(32) element { constexpr element dbl() const noexcept; constexpr void self_dbl() noexcept; - constexpr void self_mixed_add_or_sub(const affine_element& other, uint64_t predicate) noexcept; constexpr element operator+(const element& other) const noexcept; constexpr element operator+(const affine_element& other) const noexcept; @@ -128,27 +127,6 @@ template class alignas(32) element { template > static element random_coordinates_on_curve(numeric::RNG* engine = nullptr) noexcept; - // { - // bool found_one = false; - // Fq yy; - // Fq x; - // Fq y; - // Fq t0; - // while (!found_one) { - // x = Fq::random_element(engine); - // yy = x.sqr() * x + Params::b; - // if constexpr (Params::has_a) { - // yy += (x * Params::a); - // } - // y = yy.sqrt(); - // t0 = y.sqr(); - // found_one = (yy == t0); - // } - // return { x, y, Fq::one() }; - // } - static void conditional_negate_affine(const affine_element& in, - affine_element& out, - uint64_t predicate) noexcept; friend std::ostream& operator<<(std::ostream& os, const element& a) { @@ -162,10 +140,6 @@ template std::ostream& operator<<(std::ostrea return os << "x:" << e.x << " y:" << e.y << " z:" << e.z; } -// constexpr element::one = element{ Params::one_x, Params::one_y, Fq::one() }; -// constexpr element::point_at_infinity = one.set_infinity(); -// constexpr element::curve_b = Params::b; - } // namespace bb::group_elements #include "./element_impl.hpp" diff --git a/barretenberg/cpp/src/barretenberg/ecc/groups/element_impl.hpp b/barretenberg/cpp/src/barretenberg/ecc/groups/element_impl.hpp index 3e45799aaade..54ce94e5feb7 100644 --- a/barretenberg/cpp/src/barretenberg/ecc/groups/element_impl.hpp +++ b/barretenberg/cpp/src/barretenberg/ecc/groups/element_impl.hpp @@ -155,99 +155,6 @@ template constexpr element element -constexpr void element::self_mixed_add_or_sub(const affine_element& other, - const uint64_t predicate) noexcept -{ - if constexpr (Fq::modulus.data[3] >= MODULUS_TOP_LIMB_LARGE_THRESHOLD) { - if (is_point_at_infinity()) { - conditional_negate_affine(other, *(affine_element*)this, predicate); // NOLINT - z = Fq::one(); - return; - } - } else { - const bool edge_case_trigger = x.is_msb_set() || other.x.is_msb_set(); - if (edge_case_trigger) { - if (x.is_msb_set()) { - conditional_negate_affine(other, *(affine_element*)this, predicate); // NOLINT - z = Fq::one(); - } - return; - } - } - - // T0 = z1.z1 - Fq T0 = z.sqr(); - - // T1 = x2.t0 - x1 = x2.z1.z1 - x1 - Fq T1 = other.x * T0; - T1 -= x; - - // T2 = T0.z1 = z1.z1.z1 - // T2 = T2.y2 - y1 = y2.z1.z1.z1 - y1 - Fq T2 = z * T0; - T2 *= other.y; - T2.self_conditional_negate(predicate); - T2 -= y; - - if (__builtin_expect(T1.is_zero(), 0)) { - if (T2.is_zero()) { - // y2 equals y1, x2 equals x1, double x1 - self_dbl(); - return; - } - self_set_infinity(); - return; - } - - // T2 = 2T2 = 2(y2.z1.z1.z1 - y1) = R - // z3 = z1 + H - T2 += T2; - z += T1; - - // T3 = T1*T1 = HH - Fq T3 = T1.sqr(); - - // z3 = z3 - z1z1 - HH - T0 += T3; - - // z3 = (z1 + H)*(z1 + H) - z.self_sqr(); - z -= T0; - - // T3 = 4HH - T3 += T3; - T3 += T3; - - // T1 = T1*T3 = 4HHH - T1 *= T3; - - // T3 = T3 * x1 = 4HH*x1 - T3 *= x; - - // T0 = 2T3 - T0 = T3 + T3; - - // T0 = T0 + T1 = 2(4HH*x1) + 4HHH - T0 += T1; - x = T2.sqr(); - - // x3 = x3 - T0 = R*R - 8HH*x1 -4HHH - x -= T0; - - // T3 = T3 - x3 = 4HH*x1 - x3 - T3 -= x; - - T1 *= y; - T1 += T1; - - // T3 = T2 * T3 = R*(4HH*x1 - x3) - T3 *= T2; - - // y3 = T3 - T1 - y = T3 - T1; -} - template constexpr element element::operator+=(const affine_element& other) noexcept { @@ -1057,14 +964,6 @@ std::vector> element::batch_mul_with_endomo return work_elements; } -template -void element::conditional_negate_affine(const affine_element& in, - affine_element& out, - const uint64_t predicate) noexcept -{ - out = { in.x, predicate ? -in.y : in.y }; -} - template void element::batch_normalize(element* elements, const size_t num_elements) noexcept { diff --git a/barretenberg/cpp/src/barretenberg/ecc/groups/group.hpp b/barretenberg/cpp/src/barretenberg/ecc/groups/group.hpp index e60de391c108..4c02e1c39983 100644 --- a/barretenberg/cpp/src/barretenberg/ecc/groups/group.hpp +++ b/barretenberg/cpp/src/barretenberg/ecc/groups/group.hpp @@ -130,16 +130,6 @@ template class group { } return derive_generators(domain_bytes, num_generators, starting_index); } - - BB_INLINE static void conditional_negate_affine(const affine_element* src, - affine_element* dest, - uint64_t predicate); }; } // namespace bb - -#ifdef DISABLE_ASM -#include "group_impl_int128.tcc" -#else -#include "group_impl_asm.tcc" -#endif diff --git a/barretenberg/cpp/src/barretenberg/ecc/groups/group_impl_asm.tcc b/barretenberg/cpp/src/barretenberg/ecc/groups/group_impl_asm.tcc deleted file mode 100644 index 2177ba1ad37a..000000000000 --- a/barretenberg/cpp/src/barretenberg/ecc/groups/group_impl_asm.tcc +++ /dev/null @@ -1,162 +0,0 @@ -#pragma once - -#ifndef DISABLE_ASM - -#include "barretenberg/ecc/groups/group.hpp" -#include - -namespace bb { -// copies src into dest. n.b. both src and dest must be aligned on 32 byte boundaries -// template -// inline void group::copy(const affine_element* src, affine_element* -// dest) -// { -// if constexpr (Params::small_elements) { -// #if defined __AVX__ && defined USE_AVX -// ASSERT((((uintptr_t)src & 0x1f) == 0)); -// ASSERT((((uintptr_t)dest & 0x1f) == 0)); -// __asm__ __volatile__("vmovdqa 0(%0), %%ymm0 \n\t" -// "vmovdqa 32(%0), %%ymm1 \n\t" -// "vmovdqa %%ymm0, 0(%1) \n\t" -// "vmovdqa %%ymm1, 32(%1) \n\t" -// : -// : "r"(src), "r"(dest) -// : "%ymm0", "%ymm1", "memory"); -// #else -// *dest = *src; -// #endif -// } else { -// *dest = *src; -// } -// } - -// // copies src into dest. n.b. both src and dest must be aligned on 32 byte boundaries -// template -// inline void group::copy(const element* src, element* dest) -// { -// if constexpr (Params::small_elements) { -// #if defined __AVX__ && defined USE_AVX -// ASSERT((((uintptr_t)src & 0x1f) == 0)); -// ASSERT((((uintptr_t)dest & 0x1f) == 0)); -// __asm__ __volatile__("vmovdqa 0(%0), %%ymm0 \n\t" -// "vmovdqa 32(%0), %%ymm1 \n\t" -// "vmovdqa 64(%0), %%ymm2 \n\t" -// "vmovdqa %%ymm0, 0(%1) \n\t" -// "vmovdqa %%ymm1, 32(%1) \n\t" -// "vmovdqa %%ymm2, 64(%1) \n\t" -// : -// : "r"(src), "r"(dest) -// : "%ymm0", "%ymm1", "%ymm2", "memory"); -// #else -// *dest = *src; -// #endif -// } else { -// *dest = src; -// } -// } - -// copies src into dest, inverting y-coordinate if 'predicate' is true -// n.b. requires src and dest to be aligned on 32 byte boundary -template -inline void group::conditional_negate_affine(const affine_element* src, - affine_element* dest, - uint64_t predicate) -{ - constexpr uint256_t twice_modulus = Fq::modulus + Fq::modulus; - - constexpr uint64_t twice_modulus_0 = twice_modulus.data[0]; - constexpr uint64_t twice_modulus_1 = twice_modulus.data[1]; - constexpr uint64_t twice_modulus_2 = twice_modulus.data[2]; - constexpr uint64_t twice_modulus_3 = twice_modulus.data[3]; - - if constexpr (Params::small_elements) { -#if defined __AVX__ && defined USE_AVX - BB_ASSERT_EQ(((uintptr_t)src & 0x1f, 0)); - BB_ASSERT_EQ(((uintptr_t)dest & 0x1f, 0)); - __asm__ __volatile__("xorq %%r8, %%r8 \n\t" - "movq 32(%0), %%r8 \n\t" - "movq 40(%0), %%r9 \n\t" - "movq 48(%0), %%r10 \n\t" - "movq 56(%0), %%r11 \n\t" - "movq %[modulus_0], %%r12 \n\t" - "movq %[modulus_1], %%r13 \n\t" - "movq %[modulus_2], %%r14 \n\t" - "movq %[modulus_3], %%r15 \n\t" - "subq %%r8, %%r12 \n\t" - "sbbq %%r9, %%r13 \n\t" - "sbbq %%r10, %%r14 \n\t" - "sbbq %%r11, %%r15 \n\t" - "testq %2, %2 \n\t" - "cmovnzq %%r12, %%r8 \n\t" - "cmovnzq %%r13, %%r9 \n\t" - "cmovnzq %%r14, %%r10 \n\t" - "cmovnzq %%r15, %%r11 \n\t" - "vmovdqa 0(%0), %%ymm0 \n\t" - "vmovdqa %%ymm0, 0(%1) \n\t" - "movq %%r8, 32(%1) \n\t" - "movq %%r9, 40(%1) \n\t" - "movq %%r10, 48(%1) \n\t" - "movq %%r11, 56(%1) \n\t" - : - : "r"(src), - "r"(dest), - "r"(predicate), - [modulus_0] "i"(twice_modulus_0), - [modulus_1] "i"(twice_modulus_1), - [modulus_2] "i"(twice_modulus_2), - [modulus_3] "i"(twice_modulus_3) - : "%r8", "%r9", "%r10", "%r11", "%r12", "%r13", "%r14", "%r15", "%ymm0", "memory", "cc"); -#else - __asm__ __volatile__("xorq %%r8, %%r8 \n\t" - "movq 32(%0), %%r8 \n\t" - "movq 40(%0), %%r9 \n\t" - "movq 48(%0), %%r10 \n\t" - "movq 56(%0), %%r11 \n\t" - "movq %[modulus_0], %%r12 \n\t" - "movq %[modulus_1], %%r13 \n\t" - "movq %[modulus_2], %%r14 \n\t" - "movq %[modulus_3], %%r15 \n\t" - "subq %%r8, %%r12 \n\t" - "sbbq %%r9, %%r13 \n\t" - "sbbq %%r10, %%r14 \n\t" - "sbbq %%r11, %%r15 \n\t" - "testq %2, %2 \n\t" - "cmovnzq %%r12, %%r8 \n\t" - "cmovnzq %%r13, %%r9 \n\t" - "cmovnzq %%r14, %%r10 \n\t" - "cmovnzq %%r15, %%r11 \n\t" - "movq 0(%0), %%r12 \n\t" - "movq 8(%0), %%r13 \n\t" - "movq 16(%0), %%r14 \n\t" - "movq 24(%0), %%r15 \n\t" - "movq %%r8, 32(%1) \n\t" - "movq %%r9, 40(%1) \n\t" - "movq %%r10, 48(%1) \n\t" - "movq %%r11, 56(%1) \n\t" - "movq %%r12, 0(%1) \n\t" - "movq %%r13, 8(%1) \n\t" - "movq %%r14, 16(%1) \n\t" - "movq %%r15, 24(%1) \n\t" - : - : "r"(src), - "r"(dest), - "r"(predicate), - [modulus_0] "i"(twice_modulus_0), - [modulus_1] "i"(twice_modulus_1), - [modulus_2] "i"(twice_modulus_2), - [modulus_3] "i"(twice_modulus_3) - : "%r8", "%r9", "%r10", "%r11", "%r12", "%r13", "%r14", "%r15", "memory", "cc"); -#endif - } else { - if (predicate) { // NOLINT - Fq::__copy(src->x, dest->x); - dest->y = -src->y; - } else { - copy_affine(*src, *dest); - } - } -} - -} // namespace bb - -#endif diff --git a/barretenberg/cpp/src/barretenberg/ecc/groups/group_impl_int128.tcc b/barretenberg/cpp/src/barretenberg/ecc/groups/group_impl_int128.tcc deleted file mode 100644 index 761cbe7d1334..000000000000 --- a/barretenberg/cpp/src/barretenberg/ecc/groups/group_impl_int128.tcc +++ /dev/null @@ -1,34 +0,0 @@ -#pragma once - -#ifdef DISABLE_ASM - -#include "barretenberg/ecc/groups/group.hpp" -#include - -namespace bb { - -// // copies src into dest. n.b. both src and dest must be aligned on 32 byte boundaries -// template -// inline void group::copy(const affine_element* src, affine_element* -// dest) -// { -// *dest = *src; -// } - -// // copies src into dest. n.b. both src and dest must be aligned on 32 byte boundaries -// template -// inline void group::copy(const element* src, element* dest) -// { -// *dest = *src; -// } - -template -inline void group::conditional_negate_affine(const affine_element* src, - affine_element* dest, - uint64_t predicate) -{ - *dest = predicate ? -(*src) : (*src); -} -} // namespace bb - -#endif diff --git a/barretenberg/cpp/src/barretenberg/ecc/groups/wnaf.hpp b/barretenberg/cpp/src/barretenberg/ecc/groups/wnaf.hpp index e6c32a388f8a..9ec8606b9788 100644 --- a/barretenberg/cpp/src/barretenberg/ecc/groups/wnaf.hpp +++ b/barretenberg/cpp/src/barretenberg/ecc/groups/wnaf.hpp @@ -7,93 +7,60 @@ #pragma once #include "barretenberg/numeric/bitop/get_msb.hpp" #include -#include // NOLINTBEGIN(readability-implicit-bool-conversion) + +/** + * @brief Fixed-window non-adjacent form (WNAF) scalar decomposition for elliptic curve scalar multiplication. + * + * @details WNAF decomposes a scalar into a sequence of odd signed digits in the range [-(2^w - 1), 2^w - 1], + * where w = wnaf_bits. Each digit is packed into a uint64_t entry with the following bit layout: + * + * Bit 63 32 31 30 0 + * ┌────────────────────────┬────┬──────────────────────────┐ + * │ point_index │sign│ table_index │ + * └────────────────────────┴────┴──────────────────────────┘ + * + * - table_index (bits 0-30): abs(digit) >> 1. Since all digits are odd, the absolute value is always + * 2*k + 1 for some k, so table_index = k. This directly indexes a precomputed + * lookup table of odd multiples [1·P, 3·P, 5·P, ...]. + * In the Pippenger MSM path, this is the bucket index that determines which + * bucket the point is accumulated into. + * - sign (bit 31): 0 = positive digit, 1 = negative digit (negate the point's y-coordinate). + * - point_index (bits 32-63): identifies which input point this entry refers to. In single-scalar + * multiplication this is 0. In multi-scalar multiplication (Pippenger), + * this records which of the N input points the entry belongs to, since the + * schedule is later sorted by bucket and the original point ordering is lost. + * + * The template `wnaf_round` / `fixed_wnaf` variants shift point_index into bits 32+ internally. + * The runtime `fixed_wnaf` variant expects the caller to pass point_index pre-shifted. + */ namespace bb::wnaf { constexpr size_t SCALAR_BITS = 127; #define WNAF_SIZE(x) ((bb::wnaf::SCALAR_BITS + (x) - 1) / (x)) // NOLINT(cppcoreguidelines-macro-usage) -constexpr size_t get_optimal_bucket_width(const size_t num_points) -{ - if (num_points >= 14617149) { - return 21; - } - if (num_points >= 1139094) { - return 18; - } - // if (num_points >= 100000) - if (num_points >= 155975) { - return 15; - } - if (num_points >= 144834) - // if (num_points >= 100000) - { - return 14; - } - if (num_points >= 25067) { - return 12; - } - if (num_points >= 13926) { - return 11; - } - if (num_points >= 7659) { - return 10; - } - if (num_points >= 2436) { - return 9; - } - if (num_points >= 376) { - return 7; - } - if (num_points >= 231) { - return 6; - } - if (num_points >= 97) { - return 5; - } - if (num_points >= 35) { - return 4; - } - if (num_points >= 10) { - return 3; - } - if (num_points >= 2) { - return 2; - } - return 1; -} -constexpr size_t get_num_buckets(const size_t num_points) -{ - const size_t bits_per_bucket = get_optimal_bucket_width(num_points / 2); - return 1UL << bits_per_bucket; -} - -constexpr size_t get_num_rounds(const size_t num_points) -{ - const size_t bits_per_bucket = get_optimal_bucket_width(num_points / 2); - return WNAF_SIZE(bits_per_bucket + 1); -} +/** + * @brief Extract a window of `bits` consecutive bits starting at `bit_position` from a 128-bit scalar. + * + * @tparam bits The number of bits in the window (0 returns 0). + * @tparam bit_position The starting bit index within the 128-bit scalar. + * @param scalar Pointer to a 128-bit scalar stored as two consecutive uint64_t limbs (little-endian word order). + * @return The integer value of the extracted bit window. + * + * @details We determine which 64-bit limb(s) the window touches by computing + * lo_limb_idx = bit_position / 64 and hi_limb_idx = (bit_position + bits - 1) / 64. + * For the low limb, we right-shift by (bit_position % 64) to align the desired bits to position 0. + * If the window fits entirely within one limb (lo_limb_idx == hi_limb_idx), we simply mask off `bits` bits. + * Otherwise, the window straddles two limbs: we left-shift the high limb by (64 - bit_position % 64) to place + * its contributing bits adjacent to the low limb's bits, OR them together, and then mask to `bits` bits. + */ template inline uint64_t get_wnaf_bits_const(const uint64_t* scalar) noexcept { if constexpr (bits == 0) { return 0ULL; } else { - /** - * we want to take a 128 bit scalar and shift it down by (bit_position). - * We then wish to mask out `bits` number of bits. - * Low limb contains first 64 bits, so we wish to shift this limb by (bit_position mod 64), which is also - * (bit_position & 63) If we require bits from the high limb, these need to be shifted left, not right. Actual - * bit position of bit in high limb = `b`. Desired position = 64 - (amount we shifted low limb by) = 64 - - * (bit_position & 63) - * - * So, step 1: - * get low limb and shift right by (bit_position & 63) - * get high limb and shift left by (64 - (bit_position & 63)) - * - */ constexpr size_t lo_limb_idx = bit_position / 64; constexpr size_t hi_limb_idx = (bit_position + bits - 1) / 64; constexpr uint64_t lo_shift = bit_position & 63UL; @@ -110,21 +77,17 @@ template inline uint64_t get_wnaf_bits_const( } } +/** + * @brief A variant of the previous function that the bit position and number of bits are provided at runtime. + * + * @param scalar Pointer to a 128-bit scalar stored as two consecutive uint64_t limbs (little-endian word order). + * @param bits The number of bits in the window (0 returns 0). + * @param bit_position The starting bit index within the 128-bit scalar. + * @return The integer value of the extracted bit window. + */ inline uint64_t get_wnaf_bits(const uint64_t* scalar, const uint64_t bits, const uint64_t bit_position) noexcept { - /** - * we want to take a 128 bit scalar and shift it down by (bit_position). - * We then wish to mask out `bits` number of bits. - * Low limb contains first 64 bits, so we wish to shift this limb by (bit_position mod 64), which is also - * (bit_position & 63) If we require bits from the high limb, these need to be shifted left, not right. Actual bit - * position of bit in high limb = `b`. Desired position = 64 - (amount we shifted low limb by) = 64 - (bit_position - * & 63) - * - * So, step 1: - * get low limb and shift right by (bit_position & 63) - * get high limb and shift left by (64 - (bit_position & 63)) - * - */ + const auto lo_limb_idx = static_cast(bit_position >> 6); const auto hi_limb_idx = static_cast((bit_position + bits - 1) >> 6); const uint64_t lo_shift = bit_position & 63UL; @@ -138,35 +101,11 @@ inline uint64_t get_wnaf_bits(const uint64_t* scalar, const uint64_t bits, const return (lo & bit_mask) | (hi & hi_mask); } -inline void fixed_wnaf_packed( - const uint64_t* scalar, uint64_t* wnaf, bool& skew_map, const uint64_t point_index, const size_t wnaf_bits) noexcept -{ - skew_map = ((scalar[0] & 1) == 0); - uint64_t previous = get_wnaf_bits(scalar, wnaf_bits, 0) + static_cast(skew_map); - const size_t wnaf_entries = (SCALAR_BITS + wnaf_bits - 1) / wnaf_bits; - - for (size_t round_i = 1; round_i < wnaf_entries - 1; ++round_i) { - uint64_t slice = get_wnaf_bits(scalar, wnaf_bits, round_i * wnaf_bits); - uint64_t predicate = ((slice & 1UL) == 0UL); - wnaf[(wnaf_entries - round_i)] = - ((((previous - (predicate << (wnaf_bits /*+ 1*/))) ^ (0UL - predicate)) >> 1UL) | (predicate << 31UL)) | - (point_index); - previous = slice + predicate; - } - size_t final_bits = SCALAR_BITS - (wnaf_bits * (wnaf_entries - 1)); - uint64_t slice = get_wnaf_bits(scalar, final_bits, (wnaf_entries - 1) * wnaf_bits); - uint64_t predicate = ((slice & 1UL) == 0UL); - - wnaf[1] = ((((previous - (predicate << (wnaf_bits /*+ 1*/))) ^ (0UL - predicate)) >> 1UL) | (predicate << 31UL)) | - (point_index); - wnaf[0] = ((slice + predicate) >> 1UL) | (point_index); -} - /** * @brief Performs fixed-window non-adjacent form (WNAF) computation for scalar multiplication. * - * WNAF is a method for representing integers which optimizes the number of non-zero terms, which in turn optimizes - * the number of point doublings in scalar multiplication, in turn aiding efficiency. + * @details WNAF is a method for representing integers which optimizes the number of non-zero terms, which in turn + * optimizes the number of point doublings in scalar multiplication, in turn aiding efficiency. * * @param scalar Pointer to 128-bit scalar for which WNAF is to be computed. * @param wnaf Pointer to num_points+1 size array where the computed WNAF will be stored. @@ -182,16 +121,29 @@ inline void fixed_wnaf(const uint64_t* scalar, const uint64_t num_points, const size_t wnaf_bits) noexcept { + // If the scalar is even, we set the skew map to true. The skew is used to subtract a base point from the msm result + // in case scalar is even. skew_map = ((scalar[0] & 1) == 0); + // The first slice is the least significant slice of the scalar. uint64_t previous = get_wnaf_bits(scalar, wnaf_bits, 0) + static_cast(skew_map); const size_t wnaf_entries = (SCALAR_BITS + wnaf_bits - 1) / wnaf_bits; + // For the rest we start a rolling window of wnaf_bits bits, and compute the wnaf slice. for (size_t round_i = 1; round_i < wnaf_entries - 1; ++round_i) { uint64_t slice = get_wnaf_bits(scalar, wnaf_bits, round_i * wnaf_bits); + // Check if the slice is even. This will be used to borrow from the previous slice. uint64_t predicate = ((slice & 1UL) == 0UL); + // If the current slice is odd (predicate=0), the WNAF digit is simply `previous`. + // If even (predicate=1), we borrow: subtract 2^wnaf_bits from `previous` to get a + // negative value, then negate via XOR with all-ones (two's complement identity: + // -x = ~x + 1, but we immediately shift right by 1, absorbing the +1 since the + // result is guaranteed odd). The >> 1 converts from the raw odd value to a bucket + // index (e.g., value 5 → bucket 2, value 7 → bucket 3). Bit 31 stores the sign + // (1 = negative), and the upper bits carry point_index for multi-scalar indexing. wnaf[(wnaf_entries - round_i) * num_points] = - ((((previous - (predicate << (wnaf_bits /*+ 1*/))) ^ (0UL - predicate)) >> 1UL) | (predicate << 31UL)) | + ((((previous - (predicate << wnaf_bits)) ^ (0UL - predicate)) >> 1UL) | (predicate << 31UL)) | (point_index); + // Carry the borrow into the next window: if we borrowed, add 1 to the current slice. previous = slice + predicate; } size_t final_bits = SCALAR_BITS - (wnaf_bits * (wnaf_entries - 1)); @@ -199,151 +151,19 @@ inline void fixed_wnaf(const uint64_t* scalar, uint64_t predicate = ((slice & 1UL) == 0UL); wnaf[num_points] = - ((((previous - (predicate << (wnaf_bits /*+ 1*/))) ^ (0UL - predicate)) >> 1UL) | (predicate << 31UL)) | - (point_index); + ((((previous - (predicate << (wnaf_bits))) ^ (0UL - predicate)) >> 1UL) | (predicate << 31UL)) | (point_index); wnaf[0] = ((slice + predicate) >> 1UL) | (point_index); } /** - * Current flow... - * - * If a wnaf entry is even, we add +1 to it, and subtract 32 from the previous entry. - * This works if the previous entry is odd. If we recursively apply this process, starting at the least significant - *window, this will always be the case. - * - * However, we want to skip over windows that are 0, which poses a problem. - * - * Scenario 1: even window followed by 0 window followed by any window 'x' - * - * We can't add 1 to the even window and subtract 32 from the 0 window, as we don't have a bucket that maps to -32 - * This means that we have to identify whether we are going to borrow 32 from 'x', requiring us to look at least 2 - *steps ahead - * - * Scenario 2: <0> <0> - * - * This problem proceeds indefinitely - if we have adjacent 0 windows, we do not know whether we need to track a - *borrow flag until we identify the next non-zero window - * - * Scenario 3: <0> - * - * This one works... - * - * Ok, so we should be a bit more limited with when we don't include window entries. - * The goal here is to identify short scalars, so we want to identify the most significant non-zero window - **/ -inline uint64_t get_num_scalar_bits(const uint64_t* scalar) -{ - const uint64_t msb_1 = numeric::get_msb(scalar[1]); - const uint64_t msb_0 = numeric::get_msb(scalar[0]); - - const uint64_t scalar_1_mask = (0ULL - (scalar[1] > 0)); - const uint64_t scalar_0_mask = (0ULL - (scalar[0] > 0)) & ~scalar_1_mask; - - const uint64_t msb = (scalar_1_mask & (msb_1 + 64)) | (scalar_0_mask & (msb_0)); - return msb; -} - -/** - * How to compute an x-bit wnaf slice? - * - * Iterate over number of slices in scalar. - * For each slice, if slice is even, ADD +1 to current slice and SUBTRACT 2^x from previous slice. - * (for 1st slice we instead add +1 and set the scalar's 'skew' value to 'true' (i.e. need to subtract 1 from it at the - * end of our scalar mul algo)) - * - * In *wnaf we store the following: - * 1. bits 0-30: ABSOLUTE value of wnaf (i.e. -3 goes to 3) - * 2. bit 31: 'predicate' bool (i.e. does the wnaf value need to be negated?) - * 3. bits 32-63: position in a point array that describes the elliptic curve point this wnaf slice is referencing - * - * N.B. IN OUR STDLIB ALGORITHMS THE SKEW VALUE REPRESENTS AN ADDITION NOT A SUBTRACTION (i.e. we add +1 at the end of - * the scalar mul algo we don't sub 1) (this is to eliminate situations which could produce the point at infinity as an - * output as our circuit logic cannot accommodate this edge case). - * - * Credits: Zac W. - * - * @param scalar Pointer to the 128-bit non-montgomery scalar that is supposed to be transformed into wnaf - * @param wnaf Pointer to output array that needs to accommodate enough 64-bit WNAF entries - * @param skew_map Reference to output skew value, which if true shows that the point should be added once at the end of - * computation - * @param wnaf_round_counts Pointer to output array specifying the number of points participating in each round - * @param point_index The index of the point that should be multiplied by this scalar in the point array - * @param num_points Total points in the MSM (2*num_initial_points) + * @brief Recursive WNAF round for a fixed 127-bit scalar (SCALAR_BITS). * + * @details Processes one window per recursive call, using compile-time unrolling via `round_i`. + * Uses the runtime `get_wnaf_bits` for bit extraction. The WNAF output array is interleaved: + * entry for round `r` is stored at index `(wnaf_entries - r) << log2(num_points)`, so that + * entries for the same round across different points are contiguous for cache locality. + * Each entry packs: bits [0..30] = lookup table index, bit 31 = sign, bits [32..63] = point_index. */ -inline void fixed_wnaf_with_counts(const uint64_t* scalar, - uint64_t* wnaf, - bool& skew_map, - uint64_t* wnaf_round_counts, - const uint64_t point_index, - const uint64_t num_points, - const size_t wnaf_bits) noexcept -{ - const size_t max_wnaf_entries = (SCALAR_BITS + wnaf_bits - 1) / wnaf_bits; - if ((scalar[0] | scalar[1]) == 0ULL) { - skew_map = false; - for (size_t round_i = 0; round_i < max_wnaf_entries; ++round_i) { - wnaf[(round_i)*num_points] = 0xffffffffffffffffULL; - } - return; - } - const auto current_scalar_bits = static_cast(get_num_scalar_bits(scalar) + 1); - skew_map = ((scalar[0] & 1) == 0); - uint64_t previous = get_wnaf_bits(scalar, wnaf_bits, 0) + static_cast(skew_map); - const auto wnaf_entries = static_cast((current_scalar_bits + wnaf_bits - 1) / wnaf_bits); - - if (wnaf_entries == 1) { - wnaf[(max_wnaf_entries - 1) * num_points] = (previous >> 1UL) | (point_index); - ++wnaf_round_counts[max_wnaf_entries - 1]; - for (size_t j = wnaf_entries; j < max_wnaf_entries; ++j) { - wnaf[(max_wnaf_entries - 1 - j) * num_points] = 0xffffffffffffffffULL; - } - return; - } - - // If there are several windows - for (size_t round_i = 1; round_i < wnaf_entries - 1; ++round_i) { - - // Get a bit slice - uint64_t slice = get_wnaf_bits(scalar, wnaf_bits, round_i * wnaf_bits); - - // Get the predicate (last bit is zero) - uint64_t predicate = ((slice & 1UL) == 0UL); - - // Update round count - ++wnaf_round_counts[max_wnaf_entries - round_i]; - - // Calculate entry value - // If the last bit of current slice is 1, we simply put the previous value with the point index - // If the last bit of the current slice is 0, we negate everything, so that we subtract from the WNAF form and - // make it 0 - wnaf[(max_wnaf_entries - round_i) * num_points] = - ((((previous - (predicate << (wnaf_bits /*+ 1*/))) ^ (0UL - predicate)) >> 1UL) | (predicate << 31UL)) | - (point_index); - - // Update the previous value to the next windows - previous = slice + predicate; - } - // The final iteration for top bits - auto final_bits = static_cast(current_scalar_bits - (wnaf_bits * (wnaf_entries - 1))); - uint64_t slice = get_wnaf_bits(scalar, final_bits, (wnaf_entries - 1) * wnaf_bits); - uint64_t predicate = ((slice & 1UL) == 0UL); - - ++wnaf_round_counts[(max_wnaf_entries - wnaf_entries + 1)]; - wnaf[((max_wnaf_entries - wnaf_entries + 1) * num_points)] = - ((((previous - (predicate << (wnaf_bits /*+ 1*/))) ^ (0UL - predicate)) >> 1UL) | (predicate << 31UL)) | - (point_index); - - // Saving top bits - ++wnaf_round_counts[max_wnaf_entries - wnaf_entries]; - wnaf[(max_wnaf_entries - wnaf_entries) * num_points] = ((slice + predicate) >> 1UL) | (point_index); - - // Fill all unused slots with -1 - for (size_t j = wnaf_entries; j < max_wnaf_entries; ++j) { - wnaf[(max_wnaf_entries - 1 - j) * num_points] = 0xffffffffffffffffULL; - } -} - template inline void wnaf_round(uint64_t* scalar, uint64_t* wnaf, const uint64_t point_index, const uint64_t previous) noexcept { @@ -354,21 +174,29 @@ inline void wnaf_round(uint64_t* scalar, uint64_t* wnaf, const uint64_t point_in uint64_t slice = get_wnaf_bits(scalar, wnaf_bits, round_i * wnaf_bits); uint64_t predicate = ((slice & 1UL) == 0UL); wnaf[(wnaf_entries - round_i) << log2_num_points] = - ((((previous - (predicate << (wnaf_bits /*+ 1*/))) ^ (0UL - predicate)) >> 1UL) | (predicate << 31UL)) | + ((((previous - (predicate << wnaf_bits)) ^ (0UL - predicate)) >> 1UL) | (predicate << 31UL)) | (point_index << 32UL); wnaf_round(scalar, wnaf, point_index, slice + predicate); } else { constexpr size_t final_bits = SCALAR_BITS - (SCALAR_BITS / wnaf_bits) * wnaf_bits; uint64_t slice = get_wnaf_bits(scalar, final_bits, (wnaf_entries - 1) * wnaf_bits); - // uint64_t slice = get_wnaf_bits_const(scalar); uint64_t predicate = ((slice & 1UL) == 0UL); wnaf[num_points] = - ((((previous - (predicate << (wnaf_bits /*+ 1*/))) ^ (0UL - predicate)) >> 1UL) | (predicate << 31UL)) | + ((((previous - (predicate << wnaf_bits)) ^ (0UL - predicate)) >> 1UL) | (predicate << 31UL)) | (point_index << 32UL); wnaf[0] = ((slice + predicate) >> 1UL) | (point_index << 32UL); } } +/** + * @brief Recursive WNAF round for an arbitrary-width scalar. + * + * @details Same algorithm as the SCALAR_BITS overload above, but parametrized by `scalar_bits` so it can + * handle scalars of any bit width (e.g., after an endomorphism split produces shorter scalars). + * Uses the compile-time `get_wnaf_bits_const` for bit extraction since all parameters are template constants. + * Correctly handles the edge case where `scalar_bits` is an exact multiple of `wnaf_bits` (the final + * window is a full `wnaf_bits` wide rather than the remainder). + */ template inline void wnaf_round(uint64_t* scalar, uint64_t* wnaf, const uint64_t point_index, const uint64_t previous) noexcept { @@ -379,7 +207,7 @@ inline void wnaf_round(uint64_t* scalar, uint64_t* wnaf, const uint64_t point_in uint64_t slice = get_wnaf_bits_const(scalar); uint64_t predicate = ((slice & 1UL) == 0UL); wnaf[(wnaf_entries - round_i) << log2_num_points] = - ((((previous - (predicate << (wnaf_bits /*+ 1*/))) ^ (0UL - predicate)) >> 1UL) | (predicate << 31UL)) | + ((((previous - (predicate << wnaf_bits)) ^ (0UL - predicate)) >> 1UL) | (predicate << 31UL)) | (point_index << 32UL); wnaf_round(scalar, wnaf, point_index, slice + predicate); } else { @@ -389,41 +217,12 @@ inline void wnaf_round(uint64_t* scalar, uint64_t* wnaf, const uint64_t point_in uint64_t slice = get_wnaf_bits_const(scalar); uint64_t predicate = ((slice & 1UL) == 0UL); wnaf[num_points] = - ((((previous - (predicate << (wnaf_bits /*+ 1*/))) ^ (0UL - predicate)) >> 1UL) | (predicate << 31UL)) | + ((((previous - (predicate << wnaf_bits)) ^ (0UL - predicate)) >> 1UL) | (predicate << 31UL)) | (point_index << 32UL); wnaf[0] = ((slice + predicate) >> 1UL) | (point_index << 32UL); } } -template -inline void wnaf_round_packed(const uint64_t* scalar, - uint64_t* wnaf, - const uint64_t point_index, - const uint64_t previous) noexcept -{ - constexpr size_t wnaf_entries = (SCALAR_BITS + wnaf_bits - 1) / wnaf_bits; - - if constexpr (round_i < wnaf_entries - 1) { - uint64_t slice = get_wnaf_bits(scalar, wnaf_bits, round_i * wnaf_bits); - // uint64_t slice = get_wnaf_bits_const(scalar); - uint64_t predicate = ((slice & 1UL) == 0UL); - wnaf[(wnaf_entries - round_i)] = - ((((previous - (predicate << (wnaf_bits /*+ 1*/))) ^ (0UL - predicate)) >> 1UL) | (predicate << 31UL)) | - (point_index); - wnaf_round_packed(scalar, wnaf, point_index, slice + predicate); - } else { - constexpr size_t final_bits = SCALAR_BITS - (SCALAR_BITS / wnaf_bits) * wnaf_bits; - uint64_t slice = get_wnaf_bits(scalar, final_bits, (wnaf_entries - 1) * wnaf_bits); - // uint64_t slice = get_wnaf_bits_const(scalar); - uint64_t predicate = ((slice & 1UL) == 0UL); - wnaf[1] = - ((((previous - (predicate << (wnaf_bits /*+ 1*/))) ^ (0UL - predicate)) >> 1UL) | (predicate << 31UL)) | - (point_index); - - wnaf[0] = ((slice + predicate) >> 1UL) | (point_index); - } -} - template inline void fixed_wnaf(uint64_t* scalar, uint64_t* wnaf, bool& skew_map, const size_t point_index) noexcept { @@ -440,80 +239,6 @@ inline void fixed_wnaf(uint64_t* scalar, uint64_t* wnaf, bool& skew_map, const s wnaf_round(scalar, wnaf, point_index, previous); } -template -inline void wnaf_round_with_restricted_first_slice(uint64_t* scalar, - uint64_t* wnaf, - const uint64_t point_index, - const uint64_t previous) noexcept -{ - constexpr size_t wnaf_entries = (scalar_bits + wnaf_bits - 1) / wnaf_bits; - constexpr auto log2_num_points = static_cast(numeric::get_msb(static_cast(num_points))); - constexpr size_t bits_in_first_slice = scalar_bits % wnaf_bits; - if constexpr (round_i == 1) { - uint64_t slice = get_wnaf_bits_const(scalar); - uint64_t predicate = ((slice & 1UL) == 0UL); - - wnaf[(wnaf_entries - round_i) << log2_num_points] = - ((((previous - (predicate << (bits_in_first_slice /*+ 1*/))) ^ (0UL - predicate)) >> 1UL) | - (predicate << 31UL)) | - (point_index << 32UL); - if (round_i == 1) { - std::cerr << "writing value " << std::hex << wnaf[(wnaf_entries - round_i) << log2_num_points] << std::dec - << " at index " << ((wnaf_entries - round_i) << log2_num_points) << std::endl; - } - wnaf_round_with_restricted_first_slice( - scalar, wnaf, point_index, slice + predicate); - - } else if constexpr (round_i < wnaf_entries - 1) { - uint64_t slice = get_wnaf_bits_const(scalar); - uint64_t predicate = ((slice & 1UL) == 0UL); - wnaf[(wnaf_entries - round_i) << log2_num_points] = - ((((previous - (predicate << (wnaf_bits /*+ 1*/))) ^ (0UL - predicate)) >> 1UL) | (predicate << 31UL)) | - (point_index << 32UL); - wnaf_round_with_restricted_first_slice( - scalar, wnaf, point_index, slice + predicate); - } else { - uint64_t slice = get_wnaf_bits_const(scalar); - uint64_t predicate = ((slice & 1UL) == 0UL); - wnaf[num_points] = - ((((previous - (predicate << (wnaf_bits /*+ 1*/))) ^ (0UL - predicate)) >> 1UL) | (predicate << 31UL)) | - (point_index << 32UL); - wnaf[0] = ((slice + predicate) >> 1UL) | (point_index << 32UL); - } -} - -template -inline void fixed_wnaf_with_restricted_first_slice(uint64_t* scalar, - uint64_t* wnaf, - bool& skew_map, - const size_t point_index) noexcept -{ - constexpr size_t bits_in_first_slice = num_bits % wnaf_bits; - std::cerr << "bits in first slice = " << bits_in_first_slice << std::endl; - skew_map = ((scalar[0] & 1) == 0); - uint64_t previous = get_wnaf_bits_const(scalar) + static_cast(skew_map); - std::cerr << "previous = " << previous << std::endl; - wnaf_round_with_restricted_first_slice(scalar, wnaf, point_index, previous); -} - -// template -// inline void fixed_wnaf_packed(const uint64_t* scalar, -// uint64_t* wnaf, -// bool& skew_map, -// const uint64_t point_index) noexcept -// { -// skew_map = ((scalar[0] & 1) == 0); -// uint64_t previous = get_wnaf_bits_const(scalar) + (uint64_t)skew_map; -// wnaf_round_packed(scalar, wnaf, point_index, previous); -// } - -// template -// inline constexpr std::array fixed_wnaf(const uint64_t *scalar) const noexcept -// { -// bool skew_map = ((scalar[0] * 1) == 0); -// uint64_t previous = get_wnaf_bits_const(scalar) + (uint64_t)skew_map; -// std::array result; -// } } // namespace bb::wnaf // NOLINTEND(readability-implicit-bool-conversion) diff --git a/barretenberg/cpp/src/barretenberg/ecc/groups/wnaf.test.cpp b/barretenberg/cpp/src/barretenberg/ecc/groups/wnaf.test.cpp index 7890577d64c1..b91504a99abb 100644 --- a/barretenberg/cpp/src/barretenberg/ecc/groups/wnaf.test.cpp +++ b/barretenberg/cpp/src/barretenberg/ecc/groups/wnaf.test.cpp @@ -14,26 +14,90 @@ namespace { void recover_fixed_wnaf(const uint64_t* wnaf, bool skew, uint64_t& hi, uint64_t& lo, size_t wnaf_bits) { - size_t wnaf_entries = (127 + wnaf_bits - 1) / wnaf_bits; - uint128_t scalar = 0; // (uint128_t)(skew); - for (int i = 0; i < static_cast(wnaf_entries); ++i) { - uint64_t entry_formatted = wnaf[static_cast(i)]; - bool negative = (entry_formatted >> 31) != 0U; - uint64_t entry = ((entry_formatted & 0x0fffffffU) << 1) + 1; + const size_t wnaf_entries = (127 + wnaf_bits - 1) / wnaf_bits; + const uint64_t max_table_index = (1UL << (wnaf_bits - 1)) - 1; + + for (size_t i = 0; i < wnaf_entries; ++i) { + uint64_t entry = wnaf[i]; + uint64_t table_index = entry & 0x7fffffffUL; + bool sign = ((entry >> 31) & 1) != 0U; + uint64_t point_index_bits = entry >> 32; + + EXPECT_LE(table_index, max_table_index) + << "entry " << i << ": table_index " << table_index << " exceeds max " << max_table_index; + + // The most significant digit is always positive by construction (no sign bit is OR'd in). + if (i == 0) { + EXPECT_FALSE(sign) << "entry 0 (most significant digit) must be positive"; + } + + // All current callers use point_index=0, so bits 32-63 should be clear. + EXPECT_EQ(point_index_bits, 0UL) << "entry " << i << ": unexpected non-zero point_index bits"; + } + + // Recover the scalar: sum signed odd digits at their positional weights, then subtract skew. + uint128_t scalar = 0; + for (size_t i = 0; i < wnaf_entries; ++i) { + uint64_t entry_formatted = wnaf[i]; + bool negative = ((entry_formatted >> 31) & 1) != 0U; + uint64_t digit = ((entry_formatted & 0x7fffffffUL) << 1) + 1; + auto shift = static_cast(wnaf_bits * (wnaf_entries - 1 - i)); if (negative) { - scalar -= (static_cast(entry)) - << static_cast(wnaf_bits * (wnaf_entries - 1 - static_cast(i))); + scalar -= static_cast(digit) << shift; } else { - scalar += (static_cast(entry)) - << static_cast(wnaf_bits * (wnaf_entries - 1 - static_cast(i))); + scalar += static_cast(digit) << shift; } } scalar -= static_cast(skew); hi = static_cast(scalar >> static_cast(64)); - lo = static_cast(static_cast(scalar & static_cast(0xffff'ffff'ffff'ffff))); + lo = static_cast(scalar & static_cast(0xffff'ffff'ffff'ffff)); } } // namespace +TEST(wnaf, GetWnafBitsConstLimbBoundary) +{ + // scalar[0] bits 59-63 = 1,0,1,0,1 and scalar[1] bits 0-4 = 1,0,1,0,1 + // Full bit pattern around the boundary (bit 63 | bit 64): + // bit: ...59 60 61 62 63 | 64 65 66 67 68 69... + // val: ... 1 0 1 0 1 | 1 0 1 0 1 0... + const uint64_t scalar[2] = { 0xA800000000000000ULL, 0x0000000000000015ULL }; + + // Window starts at bit 63 — straddles the limb boundary (2 bits from lo, 3 from hi) + // bits 63,64,65,66,67 = 1,1,0,1,0 → 1 + 2 + 0 + 8 + 0 = 11 + EXPECT_EQ((wnaf::get_wnaf_bits_const<5, 63>(scalar)), 11ULL); + + // Window starts at bit 64 — exactly at the hi limb start + // bits 64,65,66,67,68 = 1,0,1,0,1 → 1 + 0 + 4 + 0 + 16 = 21 + EXPECT_EQ((wnaf::get_wnaf_bits_const<5, 64>(scalar)), 21ULL); + + // Window starts at bit 65 — one past the boundary, entirely in hi limb + // bits 65,66,67,68,69 = 0,1,0,1,0 → 0 + 2 + 0 + 8 + 0 = 10 + EXPECT_EQ((wnaf::get_wnaf_bits_const<5, 65>(scalar)), 10ULL); +} + +TEST(wnaf, WnafPowerOfTwo) +{ + // Powers of 2 are all even (skew = true) and have a single 1-bit with all lower bits zero, + // so every window below the leading bit is even, forcing borrows to cascade through all rounds. + auto test_power_of_two_scalar = [](uint64_t lo, uint64_t hi) { + uint64_t buffer[2] = { lo, hi }; + uint64_t wnaf_out[WNAF_SIZE(5)] = { 0 }; + bool skew = false; + wnaf::fixed_wnaf<1, 5>(buffer, wnaf_out, skew, 0); + EXPECT_TRUE(skew); // all powers of 2 are even + uint64_t recovered_hi = 0; + uint64_t recovered_lo = 0; + recover_fixed_wnaf(wnaf_out, skew, recovered_hi, recovered_lo, 5); + EXPECT_EQ(lo, recovered_lo); + EXPECT_EQ(hi, recovered_hi); + }; + + test_power_of_two_scalar(2ULL, 0ULL); // 2^1: smallest even, borrows cascade through all 26 windows + test_power_of_two_scalar(1ULL << 32, 0ULL); // 2^32: mid-lo-limb + test_power_of_two_scalar(0ULL, 1ULL); // 2^64: exactly at the limb boundary + test_power_of_two_scalar(0ULL, 1ULL << 62); // 2^126: near the 127-bit maximum +} + TEST(wnaf, WnafZero) { uint64_t buffer[2]{ 0, 0 }; diff --git a/barretenberg/cpp/src/barretenberg/flavor/prover_polynomials.hpp b/barretenberg/cpp/src/barretenberg/flavor/prover_polynomials.hpp index 6af015c231b3..a7290f15d807 100644 --- a/barretenberg/cpp/src/barretenberg/flavor/prover_polynomials.hpp +++ b/barretenberg/cpp/src/barretenberg/flavor/prover_polynomials.hpp @@ -64,6 +64,16 @@ class ProverPolynomialsBase : public AllEntitiesBase { } } + // Returns the maximum end_index across all polynomials (i.e. the actual data extent) + [[nodiscard]] size_t max_end_index() const + { + size_t result = 0; + for (const auto& poly : this->get_all()) { + result = std::max(result, poly.end_index()); + } + return result; + } + void increase_polynomials_virtual_size(const size_t size_in) { for (auto& polynomial : this->get_all()) { diff --git a/barretenberg/cpp/src/barretenberg/polynomials/polynomial.cpp b/barretenberg/cpp/src/barretenberg/polynomials/polynomial.cpp index e5b9d5ea022e..f7f4009f8747 100644 --- a/barretenberg/cpp/src/barretenberg/polynomials/polynomial.cpp +++ b/barretenberg/cpp/src/barretenberg/polynomials/polynomial.cpp @@ -186,8 +186,13 @@ template Polynomial& Polynomial::operator+=(PolynomialSpan template Fr Polynomial::evaluate(const Fr& z) const { - BB_ASSERT(size() == virtual_size()); - return polynomial_arithmetic::evaluate(data(), z, size()); + // Evaluate only the backing data; virtual zeroes beyond backing contribute nothing. + // When start_index > 0, multiply by z^start_index to account for the offset. + Fr result = polynomial_arithmetic::evaluate(data(), z, size()); + if (start_index() > 0) { + result *= z.pow(start_index()); + } + return result; } template Fr Polynomial::evaluate_mle(std::span evaluation_points, bool shift) const diff --git a/barretenberg/cpp/src/barretenberg/ultra_honk/oink_prover.cpp b/barretenberg/cpp/src/barretenberg/ultra_honk/oink_prover.cpp index 65608330b60b..df06f1b624bc 100644 --- a/barretenberg/cpp/src/barretenberg/ultra_honk/oink_prover.cpp +++ b/barretenberg/cpp/src/barretenberg/ultra_honk/oink_prover.cpp @@ -22,7 +22,7 @@ namespace bb { template void OinkProver::prove() { BB_BENCH_NAME("OinkProver::prove"); - commitment_key = CommitmentKey(prover_instance->dyadic_size()); + commitment_key = CommitmentKey(prover_instance->polynomials.max_end_index()); send_vk_hash_and_public_inputs(); commit_to_masking_poly(); commit_to_wires(); diff --git a/barretenberg/cpp/src/barretenberg/ultra_honk/ultra_prover.cpp b/barretenberg/cpp/src/barretenberg/ultra_honk/ultra_prover.cpp index d938df981329..567862e08c60 100644 --- a/barretenberg/cpp/src/barretenberg/ultra_honk/ultra_prover.cpp +++ b/barretenberg/cpp/src/barretenberg/ultra_honk/ultra_prover.cpp @@ -60,8 +60,15 @@ template void UltraProver_::generate_gate_challenges() template typename UltraProver_::Proof UltraProver_::construct_proof() { - size_t key_size = prover_instance->dyadic_size(); + // The CRS only needs to accommodate the actual data extent (max_end_index) rather than the + // full dyadic_size. All committed polynomials fit within this bound: witness/selector polys + // have backing ≤ max_end_index, Gemini fold polys have size ≤ dyadic_size/2 < max_end_index, + // Shplonk quotient Q is sized at max(claim sizes), and KZG opening proof is sized at Q.size(). + // For ZK, the gemini_masking_poly (at dyadic_size) is already reflected in max_end_index. + size_t key_size = prover_instance->polynomials.max_end_index(); if constexpr (Flavor::HasZK) { + // SmallSubgroupIPA commits fixed-size polynomials (up to SUBGROUP_SIZE + 3). Ensure the + // CRS is large enough for tiny test circuits where max_end_index may be smaller. constexpr size_t log_subgroup_size = static_cast(numeric::get_msb(Curve::SUBGROUP_SIZE)); key_size = std::max(key_size, size_t{ 1 } << (log_subgroup_size + 1)); } @@ -120,7 +127,7 @@ template void UltraProver_::execute_pcs() auto& ck = commitment_key; - PolynomialBatcher polynomial_batcher(prover_instance->dyadic_size()); + PolynomialBatcher polynomial_batcher(prover_instance->dyadic_size(), prover_instance->polynomials.max_end_index()); polynomial_batcher.set_unshifted(prover_instance->polynomials.get_unshifted()); polynomial_batcher.set_to_be_shifted_by_one(prover_instance->polynomials.get_to_be_shifted()); diff --git a/barretenberg/ts/src/index.ts b/barretenberg/ts/src/index.ts index 247c5825ebf0..5fc183d8786b 100644 --- a/barretenberg/ts/src/index.ts +++ b/barretenberg/ts/src/index.ts @@ -21,12 +21,15 @@ export { BBApiException } from './bbapi_exception.js'; export type { Bn254G1Point, Bn254G2Point, + ChonkProof, GrumpkinPoint, Secp256k1Point, Secp256r1Point, Field2, } from './cbind/generated/api_types.js'; +export { toChonkProof } from './cbind/generated/api_types.js'; + // Export curve constants for use in foundation export { BN254_FQ_MODULUS, diff --git a/yarn-project/ivc-integration/src/chonk_integration.test.ts b/yarn-project/ivc-integration/src/chonk_integration.test.ts index 6e47c3ac5c86..b089de47d01f 100644 --- a/yarn-project/ivc-integration/src/chonk_integration.test.ts +++ b/yarn-project/ivc-integration/src/chonk_integration.test.ts @@ -1,7 +1,8 @@ -import { AztecClientBackend, BackendType, Barretenberg } from '@aztec/bb.js'; +import { AztecClientBackend, BackendType, Barretenberg, toChonkProof } from '@aztec/bb.js'; import { createLogger } from '@aztec/foundation/log'; import { jest } from '@jest/globals'; +import { Decoder } from 'msgpackr'; import { ungzip } from 'pako'; import { @@ -61,6 +62,29 @@ describe.each([BackendType.Wasm, BackendType.NativeUnixSocket])('Client IVC Inte expect(verified).toBe(true); }); + it('Should compress and decompress a client IVC proof via bbapi', async () => { + const [bytecodes, witnessStack, , vks] = await generateTestingIVCStack(1, 0); + const ivcBackend = new AztecClientBackend(bytecodes, barretenberg); + const [, proof, vk] = await ivcBackend.prove(witnessStack, vks); + + // Decode the msgpack-encoded proof back to a ChonkProof object + const chonkProof = toChonkProof(new Decoder({ useRecords: false }).decode(proof)); + + // Compress via bbapi + const compressResult = await barretenberg.chonkCompressProof({ proof: chonkProof }); + expect(compressResult.compressedProof.length).toBeGreaterThan(0); + logger.info(`Compressed proof: ${compressResult.compressedProof.length} bytes`); + + // Decompress via bbapi + const decompressResult = await barretenberg.chonkDecompressProof({ + compressedProof: compressResult.compressedProof, + }); + + // Verify the decompressed proof matches the original + const verified = await barretenberg.chonkVerify({ proof: decompressResult.proof, vk }); + expect(verified.valid).toBe(true); + }); + it('Should generate an array of gate numbers for the stack of programs being proved by ClientIVC', async () => { // Create ACIR bytecodes const bytecodes = [