diff --git a/barretenberg/cpp/scripts/bench_cpu_scaling_remote.sh b/barretenberg/cpp/scripts/bench_cpu_scaling_remote.sh
new file mode 100755
index 000000000000..d3dacce9ff8a
--- /dev/null
+++ b/barretenberg/cpp/scripts/bench_cpu_scaling_remote.sh
@@ -0,0 +1,293 @@
+#!/bin/bash
+
+# CPU scaling benchmark wrapper that uses benchmark_remote.sh properly.
+#
+# Runs a command multiple times with different HARDWARE_CONCURRENCY values and
+# tracks the scaling performance of one specific BB_BENCH entry. The command is
+# given the --bench_out flag so timings can be extracted from JSON output.
+#
+# Usage: bench_cpu_scaling_remote.sh "benchmark_name" "command" [cpu_counts]
+#   benchmark_name  BB_BENCH entry whose time is tracked
+#   command         command line executed on the remote machine
+#   cpu_counts      comma-separated positive integers (default: 1,2,4,8,16)
+#
+# Required environment: BB_SSH_KEY, BB_SSH_INSTANCE, BB_SSH_CPP_PATH
+
+# Only -e is enabled; pipefail is deliberately left off because the timing
+# extraction further down uses `grep ... | head -1` pipelines in which grep
+# legitimately exits non-zero when nothing matches.
+set -e
+
+# Colors for output
+RED='\033[0;31m'
+GREEN='\033[0;32m'
+YELLOW='\033[1;33m'
+BLUE='\033[0;34m'
+CYAN='\033[0;36m'
+MAGENTA='\033[0;35m'
+NC='\033[0m' # No Color
+
+# Parse arguments (usage goes to stderr so it never pollutes captured stdout)
+if [[ $# -lt 2 ]]; then
+  {
+    echo -e "${RED}Usage: $0 \"benchmark_name\" \"command\" [cpu_counts]${NC}"
+    echo -e "Example: $0 \"ClientIvcProve\" \"./build/bin/bb prove --ivc_inputs_path input.msgpack --scheme client_ivc\""
+    echo -e "Example: $0 \"construct_mock_function_circuit\" \"./build/bin/ultra_honk_bench --benchmark_filter=.*power_of_2.*/15\" \"1,2,4,8\""
+  } >&2
+  exit 1
+fi
+
+BENCH_NAME="$1"
+COMMAND="$2"
+CPU_LIST="${3:-1,2,4,8,16}"
+
+# Convert comma-separated list to array
+IFS=',' read -ra CPU_COUNTS <<< "$CPU_LIST"
+
+# Reject anything that is not a positive integer up front: each entry becomes
+# a HARDWARE_CONCURRENCY value and is later used as a divisor when computing
+# efficiency, so 0 or non-numeric input would only fail after a long run.
+if [[ ${#CPU_COUNTS[@]} -eq 0 ]]; then
+  echo -e "${RED}Error: cpu_counts must contain at least one value${NC}" >&2
+  exit 1
+fi
+for cpu_count in "${CPU_COUNTS[@]}"; do
+  if ! [[ "$cpu_count" =~ ^[1-9][0-9]*$ ]]; then
+    echo -e "${RED}Error: invalid CPU count '$cpu_count' in '$CPU_LIST' (expected comma-separated positive integers)${NC}" >&2
+    exit 1
+  fi
+done
+
+# Check if required environment variables are set for remote execution
+if [[ -z "${BB_SSH_KEY:-}" || -z "${BB_SSH_INSTANCE:-}" || -z "${BB_SSH_CPP_PATH:-}" ]]; then
+  {
+    echo -e "${RED}Error: Remote execution requires BB_SSH_KEY, BB_SSH_INSTANCE, and BB_SSH_CPP_PATH environment variables${NC}"
+    echo "Please set:"
+    echo " export BB_SSH_KEY='-i /path/to/key.pem'"
+    echo " export BB_SSH_INSTANCE='user@ec2-instance.amazonaws.com'"
+    echo " export BB_SSH_CPP_PATH='/path/to/barretenberg/cpp'"
+  } >&2
+  exit 1
+fi
+
+# Create output directory with timestamp
+TIMESTAMP=$(date +%Y%m%d_%H%M%S)
+OUTPUT_DIR="bench_scaling_remote_${TIMESTAMP}"
+mkdir -p "$OUTPUT_DIR"
+
+# Results file
+RESULTS_FILE="$OUTPUT_DIR/scaling_results.txt"
+CSV_FILE="$OUTPUT_DIR/scaling_results.csv"
+
+echo -e "${GREEN}╔════════════════════════════════════════════════════════════════╗${NC}"
+echo -e "${GREEN}║ CPU Scaling Benchmark (Remote Execution) ║${NC}"
+echo -e "${GREEN}╚════════════════════════════════════════════════════════════════╝${NC}"
+echo ""
+echo -e "${CYAN}Benchmark Entry:${NC} ${YELLOW}$BENCH_NAME${NC}"
+echo -e "${CYAN}Command:${NC} $COMMAND"
+# [*] joins the counts into one word; [@] inside a larger string is SC2145.
+echo -e "${CYAN}CPU Counts:${NC} ${CPU_COUNTS[*]}"
+echo -e "${CYAN}Remote Host:${NC} ${BB_SSH_INSTANCE}"
+echo -e "${CYAN}Remote Path:${NC} ${BB_SSH_CPP_PATH}"
+echo -e "${CYAN}Output Directory:${NC} $OUTPUT_DIR"
+echo ""
+
+# Initialize results file
+{
+  echo "CPU Scaling Benchmark: $BENCH_NAME"
+  echo "Command: $COMMAND"
+  echo "Remote Host: $BB_SSH_INSTANCE"
+  echo "Date: $(date)"
+  echo "================================================"
+  echo ""
+} > "$RESULTS_FILE"
+
+# Initialize CSV file
+echo "CPUs,Time_ms,Time_s,Speedup,Efficiency" > "$CSV_FILE"
+
+#######################################
+# Print the time, in nanoseconds, recorded for one benchmark entry.
+# The JSON file is consulted first; if it yields nothing, the output.log that
+# sits next to it is scanned for a hierarchical BB_BENCH line instead.
+# Arguments:
+#   $1 - path to the bench JSON file,
+#        format: {"benchmark_name": time_in_nanoseconds, ...}
+#   $2 - benchmark entry name (matched literally, not as a regex)
+# Outputs:
+#   the time in nanoseconds on stdout, or an empty line if nothing was found
+#######################################
+extract_bench_time() {
+  local json_file=$1
+  local bench_name=$2
+  local time_ns=""
+  local quoted_name log_file time_s
+
+  # Wrap the name in PCRE \Q...\E so characters such as ( ) . + : are taken
+  # literally. A literal \E inside the name would end the quoting early, so it
+  # is rewritten as: close quote, escaped backslash, E, reopen quote.
+  quoted_name=${bench_name//'\E'/'\E\\E\Q'}
+  quoted_name="\\Q${quoted_name}\\E"
+
+  if [ -f "$json_file" ]; then
+    # Extract the value for the specific benchmark name from JSON
+    time_ns=$(grep -oP "\"${quoted_name}\":\s*\K\d+" "$json_file" 2>/dev/null | head -1)
+  fi
+
+  # If JSON extraction failed, try to extract from log file (fallback)
+  log_file="${json_file%/bench.json}/output.log"
+  if [ -z "$time_ns" ] && [ -f "$log_file" ]; then
+    # Try to extract from hierarchical BB_BENCH output
+    # Look for pattern like: " ├─ ClientIvcProve ... 28.13s"
+    time_s=$(grep -P "├─.*${quoted_name}" "$log_file" | grep -oP '\d+\.\d+s' | grep -oP '\d+\.\d+' | head -1)
+    if [ -n "$time_s" ]; then
+      # Convert seconds to nanoseconds
+      time_ns=$(awk -v s="$time_s" 'BEGIN{printf "%.0f", s * 1000000000}')
+    fi
+  fi
+
+  echo "$time_ns"
+}
+
+# Store baseline time for speedup calculation
+BASELINE_TIME=""
+
+# Arrays to store results
+declare -a ALL_CPUS=()
+declare -a ALL_TIMES=()
+declare -a ALL_SPEEDUPS=()
+
+echo -e "${BLUE}Starting benchmark runs on remote machine...${NC}"
+echo ""
+
+# Run benchmark for each CPU count
+for cpu_count in "${CPU_COUNTS[@]}"; do
+  echo -e "${YELLOW}━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━${NC}"
+  echo -e "${YELLOW}Running with ${cpu_count} CPU(s)...${NC}"
+
+  # Create output subdirectory
+  run_dir="$OUTPUT_DIR/run_${cpu_count}cpus"
+  mkdir -p "$run_dir"
+  log_file="$run_dir/output.log"
+  bench_json_file="$run_dir/bench.json"
+
+  # Per-invocation remote file name: a fixed /tmp/bench_<n>.json left behind by
+  # an earlier run could otherwise be mistaken for this run's result.
+  remote_json="/tmp/bench_scaling_${TIMESTAMP}_$$_${cpu_count}.json"
+
+  # Run command on remote machine with specified CPU count
+  echo -e "${CYAN}Executing on remote via benchmark_remote.sh...${NC}"
+  start_time=$(date +%s.%N)
+
+  # Use benchmark_remote.sh to execute on remote with --bench_out for JSON output
+  # The benchmark_remote.sh script handles locking and setup
+  # Use tee to show output in real-time AND save to log file
+  ./scripts/benchmark_remote.sh bb "HARDWARE_CONCURRENCY=$cpu_count $COMMAND --bench_out $remote_json" 2>&1 | tee "$log_file"
+  # The pipeline's status is tee's, so a failing benchmark never trips set -e;
+  # surface it as a warning and still try to collect whatever was produced.
+  remote_status=${PIPESTATUS[0]}
+  if [ "$remote_status" -ne 0 ]; then
+    echo -e "${YELLOW}Warning: benchmark_remote.sh exited with status ${remote_status} for ${cpu_count} CPU(s)${NC}" >&2
+  fi
+
+  # Retrieve the JSON file from remote, then delete it there.
+  # Retrieval failure must not abort the script (set -e): the log-file fallback
+  # in extract_bench_time and the "no timing data" path below depend on it.
+  # BB_SSH_KEY is intentionally unquoted: it holds an option plus its argument
+  # (e.g. '-i /path/to/key.pem') and must split into two words.
+  if ! ssh $BB_SSH_KEY "$BB_SSH_INSTANCE" "cat '$remote_json'; rc=\$?; rm -f '$remote_json'; exit \$rc" > "$bench_json_file" 2>/dev/null; then
+    echo -e "${YELLOW}Warning: could not retrieve ${remote_json} from remote; falling back to log output${NC}" >&2
+    rm -f -- "$bench_json_file"
+  fi
+
+  end_time=$(date +%s.%N)
+  wall_time=$(awk -v e="$end_time" -v s="$start_time" 'BEGIN{printf "%.2f", e-s}')
+
+  # Extract the specific benchmark time from JSON file
+  bench_time_ns=$(extract_bench_time "$bench_json_file" "$BENCH_NAME")
+
+  if [ -z "$bench_time_ns" ] || [ "$bench_time_ns" = "0" ]; then
+    echo -e "${RED}Warning: Could not extract timing for '$BENCH_NAME' from JSON${NC}"
+    echo -e "${YELLOW}Check the JSON file: $bench_json_file${NC}"
+
+    # Show what's in the JSON file for debugging
+    if [ -f "$bench_json_file" ]; then
+      echo -e "${YELLOW}JSON content (first 500 chars):${NC}"
+      head -c 500 "$bench_json_file"
+      echo ""
+    fi
+
+    echo "CPUs: $cpu_count - No timing data found" >> "$RESULTS_FILE"
+    continue
+  fi
+
+  # Convert to milliseconds and seconds
+  bench_time_ms=$(awk -v ns="$bench_time_ns" 'BEGIN{printf "%.2f", ns / 1000000}')
+  bench_time_s=$(awk -v ns="$bench_time_ns" 'BEGIN{printf "%.3f", ns / 1000000000}')
+
+  # Calculate speedup and efficiency.
+  # The first run that yields a timing becomes the baseline. Efficiency is
+  # speedup / cpus, which is only meaningful when that baseline ran on 1 CPU.
+  if [ -z "$BASELINE_TIME" ]; then
+    BASELINE_TIME="$bench_time_ns"
+    speedup="1.00"
+    efficiency="100.0"
+  else
+    speedup=$(awk -v base="$BASELINE_TIME" -v curr="$bench_time_ns" 'BEGIN{printf "%.2f", base / curr}')
+    efficiency=$(awk -v sp="$speedup" -v cpus="$cpu_count" 'BEGIN{printf "%.1f", (sp / cpus) * 100}')
+  fi
+
+  # Store results
+  ALL_CPUS+=("$cpu_count")
+  ALL_TIMES+=("$bench_time_ms")
+  ALL_SPEEDUPS+=("$speedup")
+
+  # Write to results file
+  {
+    echo "CPUs: $cpu_count"
+    echo " Time: ${bench_time_ms} ms (${bench_time_s} s)"
+    echo " Speedup: ${speedup}x"
+    echo " Efficiency: ${efficiency}%"
+    echo " Wall time: ${wall_time}s"
+    echo ""
+  } >> "$RESULTS_FILE"
+
+  # Write to CSV
+  echo "$cpu_count,$bench_time_ms,$bench_time_s,$speedup,$efficiency" >> "$CSV_FILE"
+
+  # Display results
+  echo -e "${GREEN}✓ Completed${NC}"
+  echo -e " ${CYAN}Time for '$BENCH_NAME':${NC} ${bench_time_ms} ms"
+  echo -e " ${CYAN}Speedup:${NC} ${speedup}x"
+  echo -e " ${CYAN}Efficiency:${NC} ${efficiency}%"
+  echo ""
+done
+
+# Generate summary
+echo -e "${GREEN}╔════════════════════════════════════════════════════════════════╗${NC}"
+echo -e "${GREEN}║ SUMMARY ║${NC}"
+echo -e "${GREEN}╚════════════════════════════════════════════════════════════════╝${NC}"
+echo ""
+
+# Print table header
+printf "${CYAN}%-8s %-15s %-12s 
%-12s${NC}\n" "CPUs" "Time (ms)" "Speedup" "Efficiency" +printf "${CYAN}%-8s %-15s %-12s %-12s${NC}\n" "────" "──────────" "───────" "──────────" + +# Print results table +for i in "${!ALL_CPUS[@]}"; do + cpu="${ALL_CPUS[$i]}" + time="${ALL_TIMES[$i]}" + speedup="${ALL_SPEEDUPS[$i]}" + + if [ "$i" -eq 0 ]; then + efficiency="100.0%" + else + efficiency=$(awk -v sp="$speedup" -v cpus="$cpu" 'BEGIN{printf "%.1f%%", (sp / cpus) * 100}') + fi + + # Color code based on efficiency + if [ "$i" -eq 0 ]; then + color="${GREEN}" + else + eff_val=$(echo "$efficiency" | sed 's/%//') + if (( $(echo "$eff_val > 75" | bc -l) )); then + color="${GREEN}" + elif (( $(echo "$eff_val > 50" | bc -l) )); then + color="${YELLOW}" + else + color="${RED}" + fi + fi + + printf "${color}%-8s %-15s %-12s %-12s${NC}\n" "$cpu" "$time" "${speedup}x" "$efficiency" +done + +echo "" +echo -e "${MAGENTA}═══════════════════════════════════════════════════════════════${NC}" +echo "" + +# Generate scaling plot (ASCII art) +echo -e "${CYAN}Scaling Visualization:${NC}" +echo "" + +if [ "${#ALL_TIMES[@]}" -gt 0 ]; then + # Find max time for scaling + max_time=$(printf '%s\n' "${ALL_TIMES[@]}" | sort -rn | head -1) + + # Create ASCII bar chart + for i in "${!ALL_CPUS[@]}"; do + cpu="${ALL_CPUS[$i]}" + time="${ALL_TIMES[$i]}" + + # Calculate bar length (max 50 chars) + bar_len=$(awk -v t="$time" -v m="$max_time" 'BEGIN{printf "%.0f", (t/m) * 50}') + + # Create bar + bar="" + for ((j=0; j ivc = steps.accumulate(); // Construct the hiding kernel as the final step of the IVC - const bool verified = ivc->prove_and_verify(); + auto proof = ivc->prove(); + const bool verified = ClientIVC::verify(proof, ivc->get_vk()); return verified; } diff --git a/barretenberg/cpp/src/barretenberg/bbapi/bbapi_client_ivc.cpp b/barretenberg/cpp/src/barretenberg/bbapi/bbapi_client_ivc.cpp index a511fc2c4b71..d7e9fc73458e 100644 --- a/barretenberg/cpp/src/barretenberg/bbapi/bbapi_client_ivc.cpp +++ 
b/barretenberg/cpp/src/barretenberg/bbapi/bbapi_client_ivc.cpp @@ -86,7 +86,8 @@ ClientIvcProve::Response ClientIvcProve::execute(BBApiRequest& request) && // We verify this proof. Another bb call to verify has some overhead of loading VK/proof/SRS, // and it is mysterious if this transaction fails later in the lifecycle. info("ClientIvcProve - verifying the generated proof as a sanity check"); - if (!request.ivc_in_progress->verify(proof)) { + ClientIVC::VerificationKey vk = request.ivc_in_progress->get_vk(); + if (!ClientIVC::verify(proof, vk)) { throw_or_abort("Failed to verify the generated proof!"); } diff --git a/barretenberg/cpp/src/barretenberg/client_ivc/client_ivc.cpp b/barretenberg/cpp/src/barretenberg/client_ivc/client_ivc.cpp index 97c739ffcb4a..324b8b705bb8 100644 --- a/barretenberg/cpp/src/barretenberg/client_ivc/client_ivc.cpp +++ b/barretenberg/cpp/src/barretenberg/client_ivc/client_ivc.cpp @@ -172,8 +172,6 @@ ClientIVC::perform_recursive_verification_and_databus_consistency_checks( } case QUEUE_TYPE::PG: case QUEUE_TYPE::PG_TAIL: { - BB_ASSERT_NEQ(input_verifier_accumulator, nullptr); - output_verifier_accumulator = perform_pg_recursive_verification(circuit, input_verifier_accumulator, verifier_instance, @@ -184,18 +182,10 @@ ClientIVC::perform_recursive_verification_and_databus_consistency_checks( break; } case QUEUE_TYPE::PG_FINAL: { - BB_ASSERT_NEQ(input_verifier_accumulator, nullptr); BB_ASSERT_EQ(stdlib_verification_queue.size(), size_t(1)); hide_op_queue_accumulation_result(circuit); - // Propagate the public inputs of the tail kernel by converting them to public inputs of the hiding circuit. 
- auto num_public_inputs = static_cast(honk_vk->num_public_inputs); - num_public_inputs -= KernelIO::PUBLIC_INPUTS_SIZE; // exclude fixed kernel_io public inputs - for (size_t i = 0; i < num_public_inputs; i++) { - verifier_inputs.proof[i].set_public(); - } - auto final_verifier_accumulator = perform_pg_recursive_verification(circuit, input_verifier_accumulator, verifier_instance, @@ -245,7 +235,7 @@ ClientIVC::perform_recursive_verification_and_databus_consistency_checks( kernel_input.output_pg_accum_hash.assert_equal(*prev_accum_hash); if (!is_hiding_kernel) { - // The hiding kernel has no return data but uses the traditional public-inputs mechanism + // The hiding kernel has no return data; it uses the traditional public-inputs mechanism bus_depot.set_kernel_return_data_commitment(witness_commitments.return_data); } } else { @@ -384,7 +374,7 @@ HonkProof ClientIVC::construct_oink_proof(const std::shared_ptrgate_challenges = prover_accumulation_transcript->template get_powers_of_challenge("gate_challenge", CONST_PG_LOG_N); - fold_output.accumulator = proving_key; // initialize the prover accum with the completed key + prover_accumulator = proving_key; // initialize the prover accum with the completed key HonkProof oink_proof = oink_prover.export_proof(); vinfo("oink proof constructed"); @@ -406,13 +396,14 @@ HonkProof ClientIVC::construct_pg_proof(const std::shared_ptr info("Accumulator hash in PG prover: ", accum_hash); } auto verifier_instance = std::make_shared>(honk_vk); - FoldingProver folding_prover({ fold_output.accumulator, proving_key }, + FoldingProver folding_prover({ prover_accumulator, proving_key }, { native_verifier_accum, verifier_instance }, transcript, trace_usage_tracker); - fold_output = folding_prover.prove(); + auto output = folding_prover.prove(); + prover_accumulator = output.accumulator; // update the prover accumulator vinfo("pg proof constructed"); - return fold_output.proof; + return output.proof; } /** @@ -472,8 +463,6 @@ void 
ClientIVC::accumulate(ClientCircuit& circuit, const std::shared_ptrcommitment_key = bn254_commitment_key; trace_usage_tracker.update(circuit); - honk_vk = precomputed_vk; - // We're accumulating a kernel if the verification queue is empty (because the kernel circuit contains recursive // verifiers for all the entries previously present in the verification queue) and if it's not the first accumulate // call (which will always be for an app circuit). @@ -495,22 +484,22 @@ void ClientIVC::accumulate(ClientCircuit& circuit, const std::shared_ptr& verification_key) { // Note: a structured trace is not used for the hiding kernel auto hiding_decider_pk = std::make_shared(circuit, TraceSettings(), bn254_commitment_key); - honk_vk = std::make_shared(hiding_decider_pk->get_precomputed()); - auto& hiding_circuit_vk = honk_vk; + // Hiding circuit is proven by a MegaZKProver - MegaZKProver prover(hiding_decider_pk, hiding_circuit_vk, transcript); + MegaZKProver prover(hiding_decider_pk, verification_key, transcript); HonkProof proof = prover.construct_proof(); return proof; @@ -633,7 +622,7 @@ HonkProof ClientIVC::construct_mega_proof_for_hiding_kernel(ClientCircuit& circu ClientIVC::Proof ClientIVC::prove() { // deallocate the protogalaxy accumulator - fold_output.accumulator = nullptr; + prover_accumulator = nullptr; auto mega_proof = verification_queue.front().proof; // A transcript is shared between the Hiding circuit prover and the Goblin prover @@ -669,17 +658,6 @@ bool ClientIVC::verify(const Proof& proof, const VerificationKey& vk) return goblin_verified && mega_verified; } -/** - * @brief Verify a full proof of the IVC - * - * @param proof - * @return bool - */ -bool ClientIVC::verify(const Proof& proof) const -{ - return verify(proof, get_vk()); -} - /** * @brief Internal method for constructing a decider proof * @@ -688,36 +666,12 @@ bool ClientIVC::verify(const Proof& proof) const HonkProof ClientIVC::construct_decider_proof(const std::shared_ptr& transcript) { 
vinfo("prove decider..."); - fold_output.accumulator->commitment_key = bn254_commitment_key; - MegaDeciderProver decider_prover(fold_output.accumulator, transcript); + prover_accumulator->commitment_key = bn254_commitment_key; + MegaDeciderProver decider_prover(prover_accumulator, transcript); decider_prover.construct_proof(); return decider_prover.export_proof(); } -/** - * @brief Construct and verify a proof for the IVC - * @note Use of this method only makes sense when the prover and verifier are the same entity, e.g. in - * development/testing. - * - */ -bool ClientIVC::prove_and_verify() -{ - auto start = std::chrono::steady_clock::now(); - const auto proof = prove(); - auto end = std::chrono::steady_clock::now(); - auto diff = std::chrono::duration_cast(end - start); - vinfo("time to call ClientIVC::prove: ", diff.count(), " ms."); - - start = end; - const bool verified = verify(proof); - end = std::chrono::steady_clock::now(); - - diff = std::chrono::duration_cast(end - start); - vinfo("time to verify ClientIVC proof: ", diff.count(), " ms."); - - return verified; -} - // Proof methods size_t ClientIVC::Proof::size() const { @@ -842,7 +796,12 @@ ClientIVC::Proof ClientIVC::Proof::from_file_msgpack(const std::string& filename // VerificationKey construction ClientIVC::VerificationKey ClientIVC::get_vk() const { - return { honk_vk, std::make_shared(), std::make_shared() }; + BB_ASSERT_EQ(verification_queue.size(), 1UL); + BB_ASSERT_EQ(verification_queue.front().type == QUEUE_TYPE::MEGA, true); + auto verification_key = verification_queue.front().honk_vk; + return { verification_key, + std::make_shared(), + std::make_shared() }; } void ClientIVC::update_native_verifier_accumulator(const VerifierInputs& queue_entry, diff --git a/barretenberg/cpp/src/barretenberg/client_ivc/client_ivc.hpp b/barretenberg/cpp/src/barretenberg/client_ivc/client_ivc.hpp index 5b61c7f278a1..de5d3c202ea3 100644 --- a/barretenberg/cpp/src/barretenberg/client_ivc/client_ivc.hpp +++ 
b/barretenberg/cpp/src/barretenberg/client_ivc/client_ivc.hpp @@ -257,8 +257,6 @@ class ClientIVC { ExecutionTraceUsageTracker trace_usage_tracker; private: - using ProverFoldOutput = FoldingResult; - // Transcript for CIVC prover (shared between Hiding circuit, Merge, ECCVM, and Translator) std::shared_ptr transcript = std::make_shared(); @@ -269,14 +267,13 @@ class ClientIVC { public: size_t num_circuits_accumulated = 0; // number of circuits accumulated so far - ProverFoldOutput fold_output; // prover accumulator and fold proof - HonkProof decider_proof; // decider proof to be verified in the hiding circuit + std::shared_ptr prover_accumulator; // current PG prover accumulator instance + HonkProof decider_proof; // decider proof to be verified in the hiding circuit std::shared_ptr recursive_verifier_native_accum; // native verifier accumulator used in recursive folding std::shared_ptr - native_verifier_accum; // native verifier accumulator used in prover folding - std::shared_ptr honk_vk; // honk vk to be completed and folded into the accumulator + native_verifier_accum; // native verifier accumulator used in prover folding // Set of tuples {proof, verification_key, type (Oink/PG)} to be recursively verified VerificationQueue verification_queue; @@ -327,14 +324,9 @@ class ClientIVC { static void hide_op_queue_accumulation_result(ClientCircuit& circuit); static void hide_op_queue_content_in_tail(ClientCircuit& circuit); static void hide_op_queue_content_in_hiding(ClientCircuit& circuit); - HonkProof construct_mega_proof_for_hiding_kernel(ClientCircuit& circuit); static bool verify(const Proof& proof, const VerificationKey& vk); - bool verify(const Proof& proof) const; - - bool prove_and_verify(); - HonkProof construct_decider_proof(const std::shared_ptr& transcript); VerificationKey get_vk() const; @@ -358,6 +350,9 @@ class ClientIVC { const std::shared_ptr& transcript, bool is_kernel); + HonkProof construct_honk_proof_for_hiding_kernel(ClientCircuit& circuit, + 
const std::shared_ptr& verification_key); + QUEUE_TYPE get_queue_type() const; static std::shared_ptr perform_oink_recursive_verification( diff --git a/barretenberg/cpp/src/barretenberg/client_ivc/client_ivc.test.cpp b/barretenberg/cpp/src/barretenberg/client_ivc/client_ivc.test.cpp index eadf345d8aef..642b0f7d2412 100644 --- a/barretenberg/cpp/src/barretenberg/client_ivc/client_ivc.test.cpp +++ b/barretenberg/cpp/src/barretenberg/client_ivc/client_ivc.test.cpp @@ -67,10 +67,6 @@ class ClientIVCTests : public ::testing::Test { ClientIVC ivc{ num_circuits, trace_settings }; for (size_t j = 0; j < num_circuits; ++j) { - // Use default test settings for the mock hiding kernel since it's size must always be consistent - if (j == num_circuits - 1) { - settings = TestSettings{}; - } circuit_producer.construct_and_accumulate_next_circuit(ivc, settings); } return { ivc.prove(), ivc.get_vk() }; @@ -112,7 +108,8 @@ TEST_F(ClientIVCTests, BadProofFailure) for (size_t idx = 0; idx < NUM_CIRCUITS; ++idx) { circuit_producer.construct_and_accumulate_next_circuit(ivc, settings); } - EXPECT_TRUE(ivc.prove_and_verify()); + auto proof = ivc.prove(); + EXPECT_TRUE(ClientIVC::verify(proof, ivc.get_vk())); } // The IVC throws an exception if the FIRST fold proof is tampered with @@ -139,7 +136,8 @@ TEST_F(ClientIVCTests, BadProofFailure) num_public_inputs); // tamper with first proof } } - EXPECT_FALSE(ivc.prove_and_verify()); + auto proof = ivc.prove(); + EXPECT_FALSE(ClientIVC::verify(proof, ivc.get_vk())); } // The IVC fails if the SECOND fold proof is tampered with @@ -160,7 +158,8 @@ TEST_F(ClientIVCTests, BadProofFailure) circuit.num_public_inputs()); // tamper with second proof } } - EXPECT_FALSE(ivc.prove_and_verify()); + auto proof = ivc.prove(); + EXPECT_FALSE(ClientIVC::verify(proof, ivc.get_vk())); } EXPECT_TRUE(true); @@ -313,7 +312,8 @@ TEST_F(ClientIVCTests, StructuredTraceOverflow) log2_num_gates += 1; } - EXPECT_TRUE(ivc.prove_and_verify()); + auto proof = ivc.prove(); 
+ EXPECT_TRUE(ClientIVC::verify(proof, ivc.get_vk())); }; /** @@ -348,8 +348,9 @@ TEST_F(ClientIVCTests, DynamicTraceOverflow) ivc, { .log2_num_gates = test.log2_num_arith_gates[idx] }); } - EXPECT_EQ(check_accumulator_target_sum_manual(ivc.fold_output.accumulator), true); - EXPECT_TRUE(ivc.prove_and_verify()); + EXPECT_EQ(check_accumulator_target_sum_manual(ivc.prover_accumulator), true); + auto proof = ivc.prove(); + EXPECT_TRUE(ClientIVC::verify(proof, ivc.get_vk())); } } @@ -421,5 +422,6 @@ TEST_F(ClientIVCTests, DatabusFailure) ivc.accumulate(circuit, vk); } - EXPECT_FALSE(ivc.prove_and_verify()); + auto proof = ivc.prove(); + EXPECT_FALSE(ClientIVC::verify(proof, ivc.get_vk())); }; diff --git a/barretenberg/cpp/src/barretenberg/client_ivc/mock_circuit_producer.hpp b/barretenberg/cpp/src/barretenberg/client_ivc/mock_circuit_producer.hpp index 30f9d981d423..44478f28b85d 100644 --- a/barretenberg/cpp/src/barretenberg/client_ivc/mock_circuit_producer.hpp +++ b/barretenberg/cpp/src/barretenberg/client_ivc/mock_circuit_producer.hpp @@ -202,6 +202,11 @@ class PrivateFunctionExecutionMockCircuitProducer { std::pair> create_next_circuit_and_vk(ClientIVC& ivc, TestSettings settings = {}) { + // If this is a mock hiding kernel, remove the settings and use a default (non-structured) trace + if (ivc.num_circuits_accumulated == ivc.get_num_circuits() - 1) { + settings = TestSettings{}; + ivc.trace_settings = TraceSettings{}; + } auto circuit = create_next_circuit(ivc, settings.log2_num_gates, settings.num_public_inputs); return { circuit, get_verification_key(circuit, ivc.trace_settings) }; } diff --git a/barretenberg/cpp/src/barretenberg/client_ivc/mock_kernel_pinning.test.cpp b/barretenberg/cpp/src/barretenberg/client_ivc/mock_kernel_pinning.test.cpp index b4bacd16136e..cb6c7ecc3ae7 100644 --- a/barretenberg/cpp/src/barretenberg/client_ivc/mock_kernel_pinning.test.cpp +++ b/barretenberg/cpp/src/barretenberg/client_ivc/mock_kernel_pinning.test.cpp @@ -32,8 +32,12 @@ 
TEST_F(MockKernelTest, PinFoldingKernelSizes) auto [circuit, vk] = circuit_producer.create_next_circuit_and_vk(ivc); ivc.accumulate(circuit, vk); - EXPECT_TRUE(circuit.blocks.has_overflow); // trace overflow mechanism should be triggered + // Expect trace overflow for all but the hiding kernel (final circuit) + if (idx < NUM_CIRCUITS - 1) { + EXPECT_TRUE(circuit.blocks.has_overflow); + EXPECT_EQ(ivc.prover_accumulator->log_dyadic_size(), 19); + } else { + EXPECT_FALSE(circuit.blocks.has_overflow); + } } - - EXPECT_EQ(ivc.fold_output.accumulator->log_dyadic_size(), 19); } diff --git a/barretenberg/cpp/src/barretenberg/client_ivc/test_bench_shared.hpp b/barretenberg/cpp/src/barretenberg/client_ivc/test_bench_shared.hpp index a034bcf98cd0..6584fcc3f0d0 100644 --- a/barretenberg/cpp/src/barretenberg/client_ivc/test_bench_shared.hpp +++ b/barretenberg/cpp/src/barretenberg/client_ivc/test_bench_shared.hpp @@ -13,23 +13,6 @@ namespace bb { -/** - * @brief Verify an IVC proof - * - */ -bool verify_ivc(ClientIVC::Proof& proof, ClientIVC& ivc) -{ - bool verified = ivc.verify(proof); - - // This is a benchmark, not a test, so just print success or failure to the log - if (verified) { - info("IVC successfully verified!"); - } else { - info("IVC failed to verify."); - } - return verified; -} - /** * @brief Perform a specified number of circuit accumulation rounds * diff --git a/barretenberg/cpp/src/barretenberg/common/bb_bench.cpp b/barretenberg/cpp/src/barretenberg/common/bb_bench.cpp index 79ef78974b9c..15f02989e9af 100644 --- a/barretenberg/cpp/src/barretenberg/common/bb_bench.cpp +++ b/barretenberg/cpp/src/barretenberg/common/bb_bench.cpp @@ -37,7 +37,7 @@ std::string format_time(double time_ms) std::ostringstream oss; if (time_ms >= 1000.0) { oss << std::fixed << std::setprecision(2) << (time_ms / 1000.0) << " s"; - } else if (time_ms >= 1.0) { + } else if (time_ms >= 1.0 && time_ms < 1000.0) { oss << std::fixed << std::setprecision(2) << time_ms << " ms"; } else { oss << 
std::fixed << std::setprecision(1) << (time_ms * 1000.0) << " μs"; @@ -50,42 +50,75 @@ std::string format_time_aligned(double time_ms) { std::ostringstream oss; if (time_ms >= 1000.0) { - std::string time_str = - (std::ostringstream{} << std::fixed << std::setprecision(2) << (time_ms / 1000.0) << "s").str(); - oss << std::left << std::setw(10) << time_str; + std::ostringstream time_oss; + time_oss << std::fixed << std::setprecision(2) << (time_ms / 1000.0) << "s"; + oss << std::left << std::setw(10) << time_oss.str(); } else { - std::string time_str = (std::ostringstream{} << std::fixed << std::setprecision(1) << time_ms << "ms").str(); - oss << std::left << std::setw(10) << time_str; + std::ostringstream time_oss; + time_oss << std::fixed << std::setprecision(1) << time_ms << "ms"; + oss << std::left << std::setw(10) << time_oss.str(); } return oss.str(); } +// Helper to format percentage value +std::string format_percentage_value(double percentage, const char* color) +{ + std::ostringstream oss; + oss << color << " " << std::left << std::fixed << std::setprecision(1) << std::setw(5) << percentage << "%" + << Colors::RESET; + return oss.str(); +} + // Helper to format percentage with color based on percentage value std::string format_percentage(double value, double total, double min_threshold = 0.0) { - if (total <= 0) { - return " "; - } - double percentage = (value / total) * 100.0; - if (percentage < min_threshold) { + double percentage = (total <= 0) ? 
0.0 : (value / total) * 100.0; + if (total <= 0 || percentage < min_threshold) { return " "; } // Choose color based on percentage value (like time colors) const char* color = Colors::CYAN; // Default color + return format_percentage_value(percentage, color); +} + +// Helper to format percentage section +std::string format_percentage_section(double time_ms, double parent_time, size_t indent_level) +{ + if (parent_time > 0 && indent_level > 0) { + return format_percentage(time_ms * 1000000.0, parent_time); + } + return " "; +} + +// Helper to format time section +std::string format_time_section(double time_ms) +{ std::ostringstream oss; - oss << color << " " << std::left << std::fixed << std::setprecision(1) << std::setw(5) << percentage << "%" - << Colors::RESET; + oss << " "; + if (time_ms >= 100.0 && time_ms < 1000.0) { + oss << Colors::DIM << format_time_aligned(time_ms) << Colors::RESET; + } else { + oss << format_time_aligned(time_ms); + } + return oss.str(); +} + +// Helper to format call stats +std::string format_call_stats(double time_ms, uint64_t count) +{ + if (!(time_ms >= 100.0 && count > 1)) { + return ""; + } + double avg_ms = time_ms / static_cast(count); + std::ostringstream oss; + oss << Colors::DIM << " (" << format_time(avg_ms) << " x " << count << ")" << Colors::RESET; return oss.str(); } -std::string format_aligned_section(double time_ms, - double parent_time, - uint64_t count, - size_t indent_level, - size_t num_threads = 1, - double mean_ms = 0.0) +std::string format_aligned_section(double time_ms, double parent_time, uint64_t count, size_t indent_level) { std::ostringstream oss; @@ -93,29 +126,13 @@ std::string format_aligned_section(double time_ms, oss << Colors::MAGENTA << "[" << indent_level << "] " << Colors::RESET; // Format percentage FIRST - if (parent_time > 0 && indent_level > 0) { - oss << format_percentage(time_ms * 1000000.0, parent_time); - } else { - oss << " "; // Keep alignment for root entries - } + oss << 
format_percentage_section(time_ms, parent_time, indent_level); // Format time AFTER percentage with appropriate color (with more spacing) - if (time_ms >= 100.0 && time_ms < 1000.0) { - oss << " " << Colors::DIM << format_time_aligned(time_ms) << Colors::RESET; - } else { - oss << " " << format_time_aligned(time_ms); - } + oss << format_time_section(time_ms); // Format calls/threads info - only show for >= 100ms, make it DIM - if (time_ms >= 100.0) { - if (num_threads > 1) { - oss << Colors::DIM << " (" << std::fixed << std::setprecision(2) << mean_ms << " ms x " << num_threads - << ")" << Colors::RESET; - } else if (count > 1) { - double avg_ms = time_ms / static_cast(count); - oss << Colors::DIM << " (" << format_time(avg_ms) << " x " << count << ")" << Colors::RESET; - } - } + oss << format_call_stats(time_ms, count); return oss.str(); } @@ -129,12 +146,12 @@ struct TimeColor { TimeColor get_time_colors(double time_ms) { if (time_ms >= 1000.0) { - return { Colors::BOLD, Colors::WHITE }; // Bold white for >= 1 second + return { Colors::BOLD, Colors::WHITE }; } if (time_ms >= 100.0) { - return { Colors::YELLOW, Colors::YELLOW }; // Yellow for >= 100ms + return { Colors::YELLOW, Colors::YELLOW }; } - return { Colors::DIM, Colors::DIM }; // Dim for < 100ms + return { Colors::DIM, Colors::DIM }; } // Print separator line @@ -164,12 +181,12 @@ void AggregateEntry::add_thread_time_sample(const TimeAndCount& stats) // Account for aggregate time and count time += stats.time; count += stats.count; + time_max = std::max(static_cast(stats.time), time_max); // Use Welford's method to be able to track the variance - double time_ms = static_cast(stats.time / stats.count) / 1000000.0; num_threads++; - double delta = time_ms - time_mean; + double delta = static_cast(stats.time) - time_mean; time_mean += delta / static_cast(num_threads); - double delta2 = time_ms - time_mean; + double delta2 = static_cast(stats.time) - time_mean; time_m2 += delta * delta2; } @@ -261,7 +278,7 @@ 
void GlobalBenchStatsContainer::print_aggregate_counts(std::ostream& os, size_t // Loop for a flattened view uint64_t time = 0; for (auto& [parent_key, entry] : entry_map) { - time += entry.time; + time += entry.time_max; } if (!first) { @@ -307,10 +324,7 @@ void GlobalBenchStatsContainer::print_aggregate_counts_hierarchical(std::ostream // Helper function to print a stat line with tree drawing auto print_entry = [&](const AggregateEntry& entry, size_t indent_level, bool is_last, uint64_t parent_time) { std::string indent(indent_level * 2, ' '); - std::string prefix; - if (indent_level > 0) { - prefix = is_last ? "└─ " : "├─ "; - } + std::string prefix = (indent_level == 0) ? "" : (is_last ? "└─ " : "├─ "); // Use exactly 80 characters for function name without indent const size_t name_width = 80; @@ -319,7 +333,7 @@ void GlobalBenchStatsContainer::print_aggregate_counts_hierarchical(std::ostream display_name = display_name.substr(0, name_width - 3) + "..."; } - double time_ms = static_cast(entry.time) / 1000000.0; + double time_ms = static_cast(entry.time_max) / 1000000.0; auto colors = get_time_colors(time_ms); // Print indent + prefix + name (exactly 80 chars) + time/percentage/calls @@ -330,29 +344,24 @@ void GlobalBenchStatsContainer::print_aggregate_counts_hierarchical(std::ostream os << std::left << std::setw(static_cast(name_width)) << display_name << Colors::RESET; // Print time if available with aligned section including indent level - if (entry.time > 0) { + if (entry.time_max > 0) { if (time_ms < 100.0) { // Minimal format for <100ms: only [level] and percentage, no time display std::ostringstream minimal_oss; minimal_oss << Colors::MAGENTA << "[" << indent_level << "] " << Colors::RESET; - - // Format percentage FIRST - if (parent_time > 0 && indent_level > 0) { - minimal_oss << format_percentage(time_ms * 1000000.0, static_cast(parent_time)); - } else { - minimal_oss << " "; // Keep alignment for root entries - } - - // Add spacing to replace where 
time would be (with extra spacing to match) - minimal_oss << " " << std::setw(10) << ""; - + minimal_oss << format_percentage_section(time_ms, static_cast(parent_time), indent_level); + minimal_oss << " " << std::setw(10) << ""; // Add spacing to replace where time would be os << " " << colors.time_color << std::setw(40) << std::left << minimal_oss.str() << Colors::RESET; } else { - // Full format for >=100ms - double mean_ms = entry.num_threads > 1 ? entry.time_mean / 1000000.0 : 0.0; - std::string aligned_section = format_aligned_section( - time_ms, static_cast(parent_time), entry.count, indent_level, entry.num_threads, mean_ms); + std::string aligned_section = + format_aligned_section(time_ms, static_cast(parent_time), entry.count, indent_level); os << " " << colors.time_color << std::setw(40) << std::left << aligned_section << Colors::RESET; + if (entry.num_threads > 1) { + double mean_ms = entry.time_mean / 1000000.0; + double stddev_percentage = floor(entry.get_std_dev() * 100 / entry.time_mean); + os << " " << entry.num_threads << " threads " << mean_ms << "ms average " << stddev_percentage + << "% stddev"; + } } } @@ -392,7 +401,7 @@ void GlobalBenchStatsContainer::print_aggregate_counts_hierarchical(std::ostream if (!printed_in_detail.contains(key)) { for (const auto& [child_key, parent_map] : aggregated) { for (const auto& [parent_key, entry] : parent_map) { - if (parent_key == key && entry.time >= 500000) { // 0.5ms in nanoseconds + if (parent_key == key && entry.time_max >= 500000) { // 0.5ms in nanoseconds children.push_back(child_key); break; } @@ -405,27 +414,22 @@ void GlobalBenchStatsContainer::print_aggregate_counts_hierarchical(std::ostream std::ranges::sort(children, [&](OperationKey a, OperationKey b) { uint64_t time_a = 0; uint64_t time_b = 0; - - // Get time for child 'a' when called from THIS parent if (auto it = aggregated.find(a); it != aggregated.end()) { for (const auto& [parent_key, entry] : it->second) { if (parent_key == key) { - 
time_a = entry.time; + time_a = entry.time_max; break; } } } - - // Get time for child 'b' when called from THIS parent if (auto it = aggregated.find(b); it != aggregated.end()) { for (const auto& [parent_key, entry] : it->second) { if (parent_key == key) { - time_b = entry.time; + time_b = entry.time_max; break; } } } - return time_a > time_b; }); @@ -433,25 +437,21 @@ void GlobalBenchStatsContainer::print_aggregate_counts_hierarchical(std::ostream uint64_t children_total_time = 0; for (const auto& child_key : children) { if (auto it = aggregated.find(child_key); it != aggregated.end()) { - // Sum time for this child across all parent contexts where parent matches current key for (const auto& [parent_key, entry] : it->second) { - if (parent_key == key && entry.time >= 500000) { // 0.5ms in nanoseconds - children_total_time += entry.time; + if (parent_key == key && entry.time_max >= 500000) { // 0.5ms in nanoseconds + children_total_time += entry.time_max; } } } } - - // Check if there's significant unaccounted time (>5%) and we have children - uint64_t parent_total_time = entry_to_print->time; + uint64_t parent_total_time = entry_to_print->time_max; bool should_add_other = false; - uint64_t other_time = 0; if (!children.empty() && parent_total_time > 0 && children_total_time < parent_total_time) { - other_time = parent_total_time - children_total_time; - double other_percentage = - (static_cast(other_time) / static_cast(parent_total_time)) * 100.0; - should_add_other = other_percentage > 5.0 && other_time > 0; + uint64_t unaccounted = parent_total_time - children_total_time; + double percentage = (static_cast(unaccounted) / static_cast(parent_total_time)) * 100.0; + should_add_other = percentage > 5.0 && unaccounted > 0; } + uint64_t other_time = should_add_other ? (parent_total_time - children_total_time) : 0; if (!children.empty() && keys_to_parents[key].size() > 1) { os << std::string(indent_level * 2, ' ') << " ├─ NOTE: Shared children. 
Can add up to > 100%.\n"; @@ -465,13 +465,12 @@ void GlobalBenchStatsContainer::print_aggregate_counts_hierarchical(std::ostream // Print "(other)" category if significant unaccounted time exists if (should_add_other && keys_to_parents[key].size() <= 1) { - // Create fake AggregateEntry for (other) AggregateEntry other_entry; other_entry.key = "(other)"; other_entry.time = other_time; + other_entry.time_max = other_time; other_entry.count = 1; other_entry.num_threads = 1; - print_entry(other_entry, indent_level + 1, true, parent_total_time); // always last } }; @@ -479,26 +478,24 @@ void GlobalBenchStatsContainer::print_aggregate_counts_hierarchical(std::ostream // Find root entries (those that ONLY have empty parent key and significant time) std::vector roots; for (const auto& [key, parent_map] : aggregated) { - // Check if this operation has an empty parent key entry with significant time - if (auto empty_parent_it = parent_map.find(""); empty_parent_it != parent_map.end()) { - if (empty_parent_it->second.time > 0) { - roots.push_back(key); - } + auto empty_parent_it = parent_map.find(""); + if (empty_parent_it != parent_map.end() && empty_parent_it->second.time > 0) { + roots.push_back(key); } } - // Sort roots by time + // Sort roots by time (descending) std::ranges::sort(roots, [&](OperationKey a, OperationKey b) { uint64_t time_a = 0; uint64_t time_b = 0; - if (auto it = aggregated.find(a); it != aggregated.end()) { - if (auto parent_it = it->second.find(""); parent_it != it->second.end()) { - time_a = parent_it->second.time; + if (auto it_a = aggregated.find(a); it_a != aggregated.end()) { + if (auto parent_it = it_a->second.find(""); parent_it != it_a->second.end()) { + time_a = parent_it->second.time_max; } } - if (auto it = aggregated.find(b); it != aggregated.end()) { - if (auto parent_it = it->second.find(""); parent_it != it->second.end()) { - time_b = parent_it->second.time; + if (auto it_b = aggregated.find(b); it_b != aggregated.end()) { + if (auto 
parent_it = it_b->second.find(""); parent_it != it_b->second.end()) { + time_b = parent_it->second.time_max; } } return time_a > time_b; @@ -513,41 +510,47 @@ void GlobalBenchStatsContainer::print_aggregate_counts_hierarchical(std::ostream print_separator(os, false); // Calculate totals from root entries - uint64_t total_time = 0; - uint64_t total_calls = 0; - std::set unique_functions; - uint64_t shared_count = 0; - - for (auto& [key, parent_map] : aggregated) { - unique_functions.insert(key); + std::set unique_funcs; + for (const auto& [key, _] : aggregated) { + unique_funcs.insert(key); + } + size_t unique_functions_count = unique_funcs.size(); - // Count as shared if has multiple parents - if (keys_to_parents[key].size() > 1) { + uint64_t shared_count = 0; + for (const auto& [key, parents] : keys_to_parents) { + if (parents.size() > 1) { shared_count++; } - auto& root_entry = parent_map[""]; - total_time += root_entry.time; - // Sum ALL calls - for (auto& entry : parent_map) { - total_calls += entry.second.count; + } + + uint64_t total_time = 0; + for (const auto& [_, parent_map] : aggregated) { + if (auto it = parent_map.find(""); it != parent_map.end()) { + total_time = std::max(static_cast(total_time), it->second.time_max); + } + } + + uint64_t total_calls = 0; + for (const auto& [_, parent_map] : aggregated) { + for (const auto& [__, entry] : parent_map) { + total_calls += entry.count; } } double total_time_ms = static_cast(total_time) / 1000000.0; - os << " " << Colors::BOLD << "Total: " << Colors::RESET << Colors::MAGENTA << unique_functions.size() + os << " " << Colors::BOLD << "Total: " << Colors::RESET << Colors::MAGENTA << unique_functions_count << " functions" << Colors::RESET; if (shared_count > 0) { os << " (" << Colors::RED << shared_count << " shared" << Colors::RESET << ")"; } - os << ", " << Colors::GREEN << total_calls << " measurements" << Colors::RESET << ", "; - + os << ", " << Colors::GREEN << total_calls << " measurements" << 
Colors::RESET << ", " << Colors::YELLOW; if (total_time_ms >= 1000.0) { - os << Colors::YELLOW << std::fixed << std::setprecision(2) << (total_time_ms / 1000.0) << " seconds" - << Colors::RESET; + os << std::fixed << std::setprecision(2) << (total_time_ms / 1000.0) << " seconds"; } else { - os << Colors::YELLOW << std::fixed << std::setprecision(2) << total_time_ms << " ms" << Colors::RESET; + os << std::fixed << std::setprecision(2) << total_time_ms << " ms"; } + os << Colors::RESET; os << "\n"; print_separator(os, true); diff --git a/barretenberg/cpp/src/barretenberg/common/bb_bench.hpp b/barretenberg/cpp/src/barretenberg/common/bb_bench.hpp index c4b71f383166..a202c28a4880 100644 --- a/barretenberg/cpp/src/barretenberg/common/bb_bench.hpp +++ b/barretenberg/cpp/src/barretenberg/common/bb_bench.hpp @@ -52,6 +52,7 @@ struct AggregateEntry { std::size_t count = 0; size_t num_threads = 0; double time_mean = 0; + std::size_t time_max = 0; double time_stddev = 0; // Welford's algorithm state diff --git a/barretenberg/cpp/src/barretenberg/common/parallel_for_mutex_pool.cpp b/barretenberg/cpp/src/barretenberg/common/parallel_for_mutex_pool.cpp index c923e3e7a62f..b60b2f732bce 100644 --- a/barretenberg/cpp/src/barretenberg/common/parallel_for_mutex_pool.cpp +++ b/barretenberg/cpp/src/barretenberg/common/parallel_for_mutex_pool.cpp @@ -41,6 +41,7 @@ class ThreadPool { do_iterations(); { + BB_BENCH_NAME("spinning main thread"); std::unique_lock lock(tasks_mutex); complete_condition_.wait(lock, [this] { return complete_ == num_iterations_; }); } @@ -71,6 +72,7 @@ class ThreadPool { } iteration = iteration_++; } + BB_BENCH_NAME("do_iterations()"); task_(iteration); { std::unique_lock lock(tasks_mutex); diff --git a/barretenberg/cpp/src/barretenberg/dsl/acir_format/acir_integration.test.cpp b/barretenberg/cpp/src/barretenberg/dsl/acir_format/acir_integration.test.cpp index e117af832254..e630574ef248 100644 --- 
a/barretenberg/cpp/src/barretenberg/dsl/acir_format/acir_integration.test.cpp +++ b/barretenberg/cpp/src/barretenberg/dsl/acir_format/acir_integration.test.cpp @@ -511,7 +511,7 @@ TEST_F(AcirIntegrationTest, DISABLED_ClientIVCMsgpackInputs) std::shared_ptr ivc = steps.accumulate(); ClientIVC::Proof proof = ivc->prove(); - EXPECT_TRUE(ivc->verify(proof)); + EXPECT_TRUE(ivc->verify(proof, ivc->get_vk())); } /** diff --git a/barretenberg/cpp/src/barretenberg/dsl/acir_format/pg_recursion_constraint.cpp b/barretenberg/cpp/src/barretenberg/dsl/acir_format/pg_recursion_constraint.cpp index a9fce6c6e8e9..418414f4b8c8 100644 --- a/barretenberg/cpp/src/barretenberg/dsl/acir_format/pg_recursion_constraint.cpp +++ b/barretenberg/cpp/src/barretenberg/dsl/acir_format/pg_recursion_constraint.cpp @@ -162,8 +162,6 @@ void mock_ivc_accumulation(const std::shared_ptr& ivc, ClientIVC::QUE ivc->goblin.merge_verification_queue.emplace_back(acir_format::create_mock_merge_proof()); // If the type is PG_FINAL, we also need to populate the ivc instance with a mock decider proof if (type == ClientIVC::QUEUE_TYPE::PG_FINAL) { - // we have to create a mock honk vk - ivc->honk_vk = entry.honk_vk; ivc->decider_proof = acir_format::create_mock_decider_proof(); } ivc->num_circuits_accumulated++; diff --git a/barretenberg/cpp/src/barretenberg/dsl/acir_format/pg_recursion_constraint.test.cpp b/barretenberg/cpp/src/barretenberg/dsl/acir_format/pg_recursion_constraint.test.cpp index db0b6ddf5c25..51a4229b1f65 100644 --- a/barretenberg/cpp/src/barretenberg/dsl/acir_format/pg_recursion_constraint.test.cpp +++ b/barretenberg/cpp/src/barretenberg/dsl/acir_format/pg_recursion_constraint.test.cpp @@ -313,7 +313,8 @@ TEST_F(IvcRecursionConstraintTest, AccumulateSingleApp) // add the trailing kernels construct_and_accumulate_trailing_kernels(ivc, trace_settings); - EXPECT_TRUE(ivc->prove_and_verify()); + auto proof = ivc->prove(); + EXPECT_TRUE(ClientIVC::verify(proof, ivc->get_vk())); } /** @@ -343,7 +344,8 
@@ TEST_F(IvcRecursionConstraintTest, AccumulateTwoApps) // Accumulate the trailing kernels construct_and_accumulate_trailing_kernels(ivc, trace_settings); - EXPECT_TRUE(ivc->prove_and_verify()); + auto proof = ivc->prove(); + EXPECT_TRUE(ClientIVC::verify(proof, ivc->get_vk())); } // Test generation of "init" kernel VK via dummy IVC data @@ -584,7 +586,8 @@ TEST_F(IvcRecursionConstraintTest, RecursiveVerifierAppCircuitTest) construct_and_accumulate_trailing_kernels(ivc, trace_settings); - EXPECT_TRUE(ivc->prove_and_verify()); + auto proof = ivc->prove(); + EXPECT_TRUE(ClientIVC::verify(proof, ivc->get_vk())); } /** @@ -607,5 +610,6 @@ TEST_F(IvcRecursionConstraintTest, BadRecursiveVerifierAppCircuitTest) construct_and_accumulate_trailing_kernels(ivc, trace_settings); // We expect the CIVC proof to fail due to the app with a failed UH recursive verification - EXPECT_FALSE(ivc->prove_and_verify()); + auto proof = ivc->prove(); + EXPECT_FALSE(ClientIVC::verify(proof, ivc->get_vk())); } diff --git a/barretenberg/cpp/src/barretenberg/ecc/groups/element_impl.hpp b/barretenberg/cpp/src/barretenberg/ecc/groups/element_impl.hpp index bbb99d27e840..36a1e15ea90c 100644 --- a/barretenberg/cpp/src/barretenberg/ecc/groups/element_impl.hpp +++ b/barretenberg/cpp/src/barretenberg/ecc/groups/element_impl.hpp @@ -795,7 +795,7 @@ std::vector> element::batch_mul_with_endomo const std::span>& points, const Fr& scalar) noexcept { BB_BENCH(); - typedef affine_element affine_element; + using affine_element = affine_element; const size_t num_points = points.size(); // Space for temporary values @@ -833,20 +833,6 @@ std::vector> element::batch_mul_with_endomo } }; - /** - * @brief Perform batch affine addition in parallel - * - */ - const auto batch_affine_add_internal = - [num_points, &scratch_space, &batch_affine_add_chunked](const affine_element* lhs, affine_element* rhs) { - parallel_for_heuristic( - num_points, - [&](size_t start, size_t end, BB_UNUSED size_t chunk_index) { - 
batch_affine_add_chunked(lhs + start, rhs + start, end - start, &scratch_space[0] + start); - }, - thread_heuristics::FF_ADDITION_COST * 6 + thread_heuristics::FF_MULTIPLICATION_COST * 6); - }; - /** * @brief Perform point doubling lhs[i]=lhs[i]+lhs[i] with batch inversion * @@ -878,18 +864,6 @@ std::vector> element::batch_mul_with_endomo lhs[i].y = personal_scratch_space[i] * (temp - lhs[i].x) - lhs[i].y; } }; - /** - * @brief Perform point doubling in parallel - * - */ - const auto batch_affine_double = [num_points, &scratch_space, &batch_affine_double_chunked](affine_element* lhs) { - parallel_for_heuristic( - num_points, - [&](size_t start, size_t end, BB_UNUSED size_t chunk_index) { - batch_affine_double_chunked(lhs + start, end - start, &scratch_space[0] + start); - }, - thread_heuristics::FF_ADDITION_COST * 7 + thread_heuristics::FF_MULTIPLICATION_COST * 6); - }; // We compute the resulting point through WNAF by evaluating (the (\sum_i (16ⁱ⋅ // (a_i ∈ {-15,-13,-11,-9,-7,-5,-3,-1,1,3,5,7,9,11,13,15}))) - skew), where skew is 0 or 1. The result of the sum is @@ -916,50 +890,58 @@ std::vector> element::batch_mul_with_endomo constexpr size_t LOOKUP_SIZE = 8; constexpr size_t NUM_ROUNDS = 32; + + detail::EndoScalars endo_scalars = Fr::split_into_endomorphism_scalars(converted_scalar); + detail::EndomorphismWnaf wnaf{ endo_scalars }; + + std::vector work_elements(num_points); std::array, LOOKUP_SIZE> lookup_table; for (auto& table : lookup_table) { table.resize(num_points); } - // Initialize first etnries in lookup table std::vector temp_point_vector(num_points); - parallel_for_heuristic( - num_points, - [&](size_t i) { - // If the point is at infinity we fix-up the result later - // To avoid 'trying to invert zero in the field' we set the point to 'one' here - temp_point_vector[i] = points[i].is_point_at_infinity() ? affine_element::one() : points[i]; - lookup_table[0][i] = points[i].is_point_at_infinity() ? 
affine_element::one() : points[i]; - }, - thread_heuristics::FF_COPY_COST * 2); - - // Construct lookup table - batch_affine_double(&temp_point_vector[0]); - for (size_t j = 1; j < LOOKUP_SIZE; ++j) { - parallel_for_heuristic( - num_points, - [&](size_t i) { lookup_table[j][i] = lookup_table[j - 1][i]; }, - thread_heuristics::FF_COPY_COST); - batch_affine_add_internal(&temp_point_vector[0], &lookup_table[j][0]); - } - detail::EndoScalars endo_scalars = Fr::split_into_endomorphism_scalars(converted_scalar); - detail::EndomorphismWnaf wnaf{ endo_scalars }; + auto execute_range = [&](size_t start, size_t end) { + // Perform batch affine addition in parallel + const auto add_chunked = [&](const affine_element* lhs, affine_element* rhs) { + batch_affine_add_chunked(&lhs[start], &rhs[start], end - start, &scratch_space[start]); + }; - std::vector work_elements(num_points); + // Perform point doubling in parallel + const auto double_chunked = [&](affine_element* lhs) { + batch_affine_double_chunked(&lhs[start], end - start, &scratch_space[start]); + }; - constexpr Fq beta = Fq::cube_root_of_unity(); - uint64_t wnaf_entry = 0; - uint64_t index = 0; - bool sign = 0; - // Prepare elements for the first batch addition - for (size_t j = 0; j < 2; ++j) { - wnaf_entry = wnaf.table[j]; - index = wnaf_entry & 0x0fffffffU; - sign = static_cast((wnaf_entry >> 31) & 1); - const bool is_odd = ((j & 1) == 1); - parallel_for_heuristic( - num_points, - [&](size_t i) { + // Initialize first entries in lookup table + for (size_t i = start; i < end; ++i) { + if (points[i].is_point_at_infinity()) { + temp_point_vector[i] = affine_element::one(); + lookup_table[0][i] = affine_element::one(); + } else { + temp_point_vector[i] = points[i]; + lookup_table[0][i] = points[i]; + } + } + // Costruct lookup table + double_chunked(&temp_point_vector[0]); + for (size_t j = 1; j < LOOKUP_SIZE; ++j) { + for (size_t i = start; i < end; ++i) { + lookup_table[j][i] = lookup_table[j - 1][i]; + } + 
add_chunked(&temp_point_vector[0], &lookup_table[j][0]); + } + + constexpr Fq beta = Fq::cube_root_of_unity(); + uint64_t wnaf_entry = 0; + uint64_t index = 0; + bool sign = 0; + // Prepare elements for the first batch addition + for (size_t j = 0; j < 2; ++j) { + wnaf_entry = wnaf.table[j]; + index = wnaf_entry & 0x0fffffffU; + sign = static_cast((wnaf_entry >> 31) & 1); + const bool is_odd = ((j & 1) == 1); + for (size_t i = start; i < end; ++i) { auto to_add = lookup_table[static_cast(index)][i]; to_add.y.self_conditional_negate(sign ^ is_odd); if (is_odd) { @@ -970,64 +952,51 @@ std::vector> element::batch_mul_with_endomo } else { temp_point_vector[i] = to_add; } - }, - (is_odd ? thread_heuristics::FF_MULTIPLICATION_COST : 0) + thread_heuristics::FF_COPY_COST + - thread_heuristics::FF_ADDITION_COST); - } - // First cycle of addition - batch_affine_add_internal(&temp_point_vector[0], &work_elements[0]); - // Run through SM logic in wnaf form (excluding the skew) - for (size_t j = 2; j < NUM_ROUNDS * 2; ++j) { - wnaf_entry = wnaf.table[j]; - index = wnaf_entry & 0x0fffffffU; - sign = static_cast((wnaf_entry >> 31) & 1); - const bool is_odd = ((j & 1) == 1); - if (!is_odd) { - for (size_t k = 0; k < 4; ++k) { - batch_affine_double(&work_elements[0]); } } - parallel_for_heuristic( - num_points, - [&](size_t i) { + add_chunked(&temp_point_vector[0], &work_elements[0]); + // Run through SM logic in wnaf form (excluding the skew) + for (size_t j = 2; j < NUM_ROUNDS * 2; ++j) { + wnaf_entry = wnaf.table[j]; + index = wnaf_entry & 0x0fffffffU; + sign = static_cast((wnaf_entry >> 31) & 1); + const bool is_odd = ((j & 1) == 1); + if (!is_odd) { + for (size_t k = 0; k < 4; ++k) { + double_chunked(&work_elements[0]); + } + } + for (size_t i = start; i < end; ++i) { auto to_add = lookup_table[static_cast(index)][i]; to_add.y.self_conditional_negate(sign ^ is_odd); if (is_odd) { to_add.x *= beta; } temp_point_vector[i] = to_add; - }, - (is_odd ? 
thread_heuristics::FF_MULTIPLICATION_COST : 0) + thread_heuristics::FF_COPY_COST + - thread_heuristics::FF_ADDITION_COST); - batch_affine_add_internal(&temp_point_vector[0], &work_elements[0]); - } - - // Apply skew for the first endo scalar - if (wnaf.skew) { - parallel_for_heuristic( - num_points, - [&](size_t i) { temp_point_vector[i] = -lookup_table[0][i]; }, - thread_heuristics::FF_ADDITION_COST + thread_heuristics::FF_COPY_COST); - batch_affine_add_internal(&temp_point_vector[0], &work_elements[0]); - } - // Apply skew for the second endo scalar - if (wnaf.endo_skew) { - parallel_for_heuristic( - num_points, - [&](size_t i) { + } + add_chunked(&temp_point_vector[0], &work_elements[0]); + } + // Apply skew for the first endo scalar + if (wnaf.skew) { + for (size_t i = start; i < end; ++i) { + temp_point_vector[i] = -lookup_table[0][i]; + } + add_chunked(&temp_point_vector[0], &work_elements[0]); + } + // Apply skew for the second endo scalar + if (wnaf.endo_skew) { + for (size_t i = start; i < end; ++i) { temp_point_vector[i] = lookup_table[0][i]; temp_point_vector[i].x *= beta; - }, - thread_heuristics::FF_MULTIPLICATION_COST + thread_heuristics::FF_COPY_COST); - batch_affine_add_internal(&temp_point_vector[0], &work_elements[0]); - } - // handle points at infinity explicitly - parallel_for_heuristic( - num_points, - [&](size_t i) { + } + add_chunked(&temp_point_vector[0], &work_elements[0]); + } + // handle points at infinity explicitly + for (size_t i = start; i < end; ++i) { work_elements[i] = points[i].is_point_at_infinity() ? 
work_elements[i].set_infinity() : work_elements[i]; - }, - thread_heuristics::FF_COPY_COST); + } + }; + parallel_for_range(num_points, execute_range); return work_elements; } diff --git a/barretenberg/cpp/src/barretenberg/relations/translator_vm/translator_extra_relations.hpp b/barretenberg/cpp/src/barretenberg/relations/translator_vm/translator_extra_relations.hpp index e3abe1497d28..060839382f9e 100644 --- a/barretenberg/cpp/src/barretenberg/relations/translator_vm/translator_extra_relations.hpp +++ b/barretenberg/cpp/src/barretenberg/relations/translator_vm/translator_extra_relations.hpp @@ -27,8 +27,9 @@ template class TranslatorOpcodeConstraintRelationImpl { * @brief Returns true if the contribution from all subrelations for the provided inputs is identically zero * */ - template inline static bool skip(const AllEntities& in) + template static bool skip(const AllEntities& in) { + // All contributions are zero outside the minicircuit or at odd indices not masked return (in.lagrange_even_in_minicircuit + in.lagrange_mini_masking).is_zero(); } /** @@ -79,9 +80,13 @@ template class TranslatorAccumulatorTransferRelationImpl { * slower. 
* */ - template inline static bool skip(const AllEntities& in) + template static bool skip(const AllEntities& in) { - return (in.lagrange_odd_in_minicircuit + in.lagrange_last_in_minicircuit + in.lagrange_result_row).is_zero(); + // All contributions are zero outside the minicircuit or at even indices within the minicircuite excluding + // masked areas (except from the last and result row in minicircuit) + return (in.lagrange_odd_in_minicircuit + in.lagrange_last_in_minicircuit + in.lagrange_result_row + + in.lagrange_mini_masking) + .is_zero(); } /** * @brief Relation enforcing non-arithmetic transitions of accumulator (value that is tracking the batched @@ -110,7 +115,7 @@ template class TranslatorZeroConstraintsRelationImpl { // 1 + polynomial degree of this relation static constexpr size_t RELATION_LENGTH = 4; // degree((some lagrange)(A)) = 2 - static constexpr std::array SUBRELATION_PARTIAL_LENGTHS{ + static constexpr std::array SUBRELATION_PARTIAL_LENGTHS{ 4, // p_x_low_limbs_range_constraint_0 is zero outside of the minicircuit 4, // p_x_low_limbs_range_constraint_1 is zero outside of the minicircuit 4, // p_x_low_limbs_range_constraint_2 is zero outside of the minicircuit @@ -175,18 +180,25 @@ template class TranslatorZeroConstraintsRelationImpl { 4, // accumulator_high_limbs_range_constraint_tail is zero outside of the minicircuit 4, // quotient_low_limbs_range_constraint_tail is zero outside of the minicircuit 4, // quotient_high_limbs_range_constraint_tail is zero outside of the minicircuit + 4, // op is zero outside of the minicircuit + 4, // x_lo_y_hi is zero outside of the minicircuit + 4, // x_hi_z_1 is zero outside of the minicircuit + 4, // y_lo_z_2 is zero outside of the minicircuit }; /** - * @brief Might return true if the contribution from all subrelations for the provided inputs is identically zero + * @brief Returns true if the contribution from all subrelations for the provided inputs is identically zero * * */ - template inline static bool 
skip(const AllEntities& in) + template static bool skip(const AllEntities& in) { + // All contributions are identically zero if outside the minicircuit and masked area or when we have a + // no-op (i.e. op is zero at an even index) static constexpr auto minus_one = -FF(1); - return (in.lagrange_even_in_minicircuit + in.lagrange_last_in_minicircuit + minus_one).is_zero(); + return (in.lagrange_even_in_minicircuit + in.op + minus_one).is_zero() || + (in.lagrange_odd_in_minicircuit + in.lagrange_even_in_minicircuit + in.lagrange_mini_masking).is_zero(); } /** * @brief Relation enforcing all the range-constraint polynomials to be zero after the minicircuit diff --git a/barretenberg/cpp/src/barretenberg/relations/translator_vm/translator_extra_relations_impl.hpp b/barretenberg/cpp/src/barretenberg/relations/translator_vm/translator_extra_relations_impl.hpp index e838fcedfd8b..a46e369342a7 100644 --- a/barretenberg/cpp/src/barretenberg/relations/translator_vm/translator_extra_relations_impl.hpp +++ b/barretenberg/cpp/src/barretenberg/relations/translator_vm/translator_extra_relations_impl.hpp @@ -207,9 +207,8 @@ void TranslatorAccumulatorTransferRelationImpl::accumulate(ContainerOverSubr }; /** - * @brief Relation enforcing all the range-constraint polynomials to be zero after the minicircuit - * @details This relation ensures that while we are out of the minicircuit the range constraint polynomials are zero - * + * @brief Relation enforcing all the range-constraint and op queue polynomials to be zero after the minicircuit + * @param evals transformed to `evals + C(in(X)...)*scaling_factor` * @param in an std::array containing the fully extended Univariate edges. * @param parameters contains beta, gamma, and public_input_delta, .... 
@@ -295,6 +294,10 @@ void TranslatorZeroConstraintsRelationImpl::accumulate(ContainerOverSubrelat auto accumulator_high_limbs_range_constraint_tail = View(in.accumulator_high_limbs_range_constraint_tail); auto quotient_low_limbs_range_constraint_tail = View(in.quotient_low_limbs_range_constraint_tail); auto quotient_high_limbs_range_constraint_tail = View(in.quotient_high_limbs_range_constraint_tail); + auto op = View(in.op); + auto x_lo_y_hi = View(in.x_lo_y_hi); + auto x_hi_z_1 = View(in.x_hi_z_1); + auto y_lo_z_2 = View(in.y_lo_z_2); auto lagrange_mini_masking = View(in.lagrange_mini_masking); // 0 in the minicircuit, -1 outside @@ -492,5 +495,17 @@ void TranslatorZeroConstraintsRelationImpl::accumulate(ContainerOverSubrelat // Contribution 63, ensure quotient_high_limbs_range_constraint_tail is 0 outside of minicircuit std::get<63>(accumulators) += quotient_high_limbs_range_constraint_tail * not_in_mininicircuit_or_masked; + + // Contribution 64, ensure op is 0 outside of minicircuit + std::get<64>(accumulators) += op * not_in_mininicircuit_or_masked; + + // Contribution 65, ensure x_lo_y_hi is 0 outside of minicircuit + std::get<65>(accumulators) += x_lo_y_hi * not_in_mininicircuit_or_masked; + + // Contribution 66, ensure x_hi_z_1 is 0 outside of minicircuit + std::get<66>(accumulators) += x_hi_z_1 * not_in_mininicircuit_or_masked; + + // Contribution 67, ensure y_lo_z_2 is 0 outside of minicircuit + std::get<67>(accumulators) += y_lo_z_2 * not_in_mininicircuit_or_masked; }; } // namespace bb diff --git a/barretenberg/cpp/src/barretenberg/relations/translator_vm/translator_relation_consistency.test.cpp b/barretenberg/cpp/src/barretenberg/relations/translator_vm/translator_relation_consistency.test.cpp index c60a2925ee38..27cb69a35ead 100644 --- a/barretenberg/cpp/src/barretenberg/relations/translator_vm/translator_relation_consistency.test.cpp +++ b/barretenberg/cpp/src/barretenberg/relations/translator_vm/translator_relation_consistency.test.cpp @@ -912,7 
+912,10 @@ TEST_F(TranslatorRelationConsistency, ZeroConstraintsRelation) const auto& relation_wide_limbs_range_constraint_1 = input_elements.relation_wide_limbs_range_constraint_1; const auto& relation_wide_limbs_range_constraint_2 = input_elements.relation_wide_limbs_range_constraint_2; const auto& relation_wide_limbs_range_constraint_3 = input_elements.relation_wide_limbs_range_constraint_3; - + const auto& op = input_elements.op; + const auto& x_lo_y_hi = input_elements.x_lo_y_hi; + const auto& x_hi_z_1 = input_elements.x_hi_z_1; + const auto& y_lo_z_2 = input_elements.y_lo_z_2; const auto& lagrange_odd_in_minicircuit = input_elements.lagrange_odd_in_minicircuit; const auto& lagrange_even_in_minicircuit = input_elements.lagrange_even_in_minicircuit; const auto& lagrange_mini_masking = input_elements.lagrange_mini_masking; @@ -1049,6 +1052,14 @@ TEST_F(TranslatorRelationConsistency, ZeroConstraintsRelation) (lagrange_mini_masking - FF(1)) * quotient_low_limbs_range_constraint_tail; expected_values[63] = (lagrange_even_in_minicircuit + lagrange_odd_in_minicircuit - 1) * (lagrange_mini_masking - FF(1)) * quotient_high_limbs_range_constraint_tail; + expected_values[64] = + (lagrange_even_in_minicircuit + lagrange_odd_in_minicircuit - 1) * (lagrange_mini_masking - FF(1)) * op; + expected_values[65] = (lagrange_even_in_minicircuit + lagrange_odd_in_minicircuit - 1) * + (lagrange_mini_masking - FF(1)) * x_lo_y_hi; + expected_values[66] = (lagrange_even_in_minicircuit + lagrange_odd_in_minicircuit - 1) * + (lagrange_mini_masking - FF(1)) * x_hi_z_1; + expected_values[67] = (lagrange_even_in_minicircuit + lagrange_odd_in_minicircuit - 1) * + (lagrange_mini_masking - FF(1)) * y_lo_z_2; validate_relation_execution(expected_values, input_elements, parameters); }; diff --git a/barretenberg/cpp/src/barretenberg/stdlib/primitives/bigfield/bigfield.test.cpp b/barretenberg/cpp/src/barretenberg/stdlib/primitives/bigfield/bigfield.test.cpp index 4f2080befce9..2ccbe85433a8 100644 
--- a/barretenberg/cpp/src/barretenberg/stdlib/primitives/bigfield/bigfield.test.cpp +++ b/barretenberg/cpp/src/barretenberg/stdlib/primitives/bigfield/bigfield.test.cpp @@ -100,6 +100,8 @@ template class stdlib_bigfield : public testing::Test { fr elt_native_lo = fr(uint256_t(elt_native).slice(0, fq_ct::NUM_LIMB_BITS * 2)); fr elt_native_hi = fr(uint256_t(elt_native).slice(fq_ct::NUM_LIMB_BITS * 2, fq_ct::NUM_LIMB_BITS * 4)); fq_ct elt_ct(witness_ct(builder, elt_native_lo), witness_ct(builder, elt_native_hi)); + // UNset free witness tag so we don't have to unset it in every test + elt_ct.unset_free_witness_tag(); return std::make_pair(elt_native, elt_ct); } diff --git a/barretenberg/cpp/src/barretenberg/stdlib/primitives/bigfield/bigfield_impl.hpp b/barretenberg/cpp/src/barretenberg/stdlib/primitives/bigfield/bigfield_impl.hpp index a43f5a966b95..12c546b973aa 100644 --- a/barretenberg/cpp/src/barretenberg/stdlib/primitives/bigfield/bigfield_impl.hpp +++ b/barretenberg/cpp/src/barretenberg/stdlib/primitives/bigfield/bigfield_impl.hpp @@ -223,6 +223,8 @@ template bigfield::bigfield(const byt const uint256_t hi_nibble_shift = uint256_t(1) << 4; const field_t sum = lo_nibble + (hi_nibble * hi_nibble_shift); sum.assert_equal(split_byte); + lo_nibble.set_origin_tag(split_byte.tag); + hi_nibble.set_origin_tag(split_byte.tag); return std::make_pair(lo_nibble, hi_nibble); }; @@ -1892,7 +1894,6 @@ template void bigfield::assert_less_t template void bigfield::assert_equal(const bigfield& other) const { Builder* ctx = this->context ? this->context : other.context; - (void)OriginTag(get_origin_tag(), other.get_origin_tag()); if (is_constant() && other.is_constant()) { std::cerr << "bigfield: calling assert equal on 2 CONSTANT bigfield elements...is this intended?" 
<< std::endl; BB_ASSERT_EQ(get_value(), other.get_value(), "We expect constants to be less than the target modulus"); @@ -1922,6 +1923,12 @@ template void bigfield::assert_equal( return; } else { + // Remove tags, we don't want to cause violations on assert_equal + const auto original_tag = get_origin_tag(); + const auto other_original_tag = other.get_origin_tag(); + set_origin_tag(OriginTag()); + other.set_origin_tag(OriginTag()); + bigfield diff = *this - other; const uint512_t diff_val = diff.get_value(); const uint512_t modulus(target_basis.modulus); @@ -1938,6 +1945,10 @@ template void bigfield::assert_equal( false, num_quotient_bits); unsafe_evaluate_multiply_add(diff, { one() }, {}, quotient, { zero() }); + + // Restore tags + set_origin_tag(original_tag); + other.set_origin_tag(other_original_tag); } } diff --git a/barretenberg/cpp/src/barretenberg/stdlib/primitives/field/field.cpp b/barretenberg/cpp/src/barretenberg/stdlib/primitives/field/field.cpp index 1406280082a0..4e731bd45205 100644 --- a/barretenberg/cpp/src/barretenberg/stdlib/primitives/field/field.cpp +++ b/barretenberg/cpp/src/barretenberg/stdlib/primitives/field/field.cpp @@ -29,7 +29,9 @@ field_t::field_t(const witness_t& value) , additive_constant(bb::fr::zero()) , multiplicative_constant(bb::fr::one()) , witness_index(value.witness_index) -{} +{ + set_free_witness_tag(); +} template field_t::field_t(Builder* parent_context, const bb::fr& value) diff --git a/barretenberg/cpp/src/barretenberg/stdlib/primitives/field/field.test.cpp b/barretenberg/cpp/src/barretenberg/stdlib/primitives/field/field.test.cpp index e3431b164bb4..e551b5d7f657 100644 --- a/barretenberg/cpp/src/barretenberg/stdlib/primitives/field/field.test.cpp +++ b/barretenberg/cpp/src/barretenberg/stdlib/primitives/field/field.test.cpp @@ -1326,8 +1326,8 @@ template class stdlib_field : public testing::Test { uint256_t(bb::fr::random_element()) & ((uint256_t(1) << grumpkin::MAX_NO_WRAP_INTEGER_BIT_LENGTH) - 1); auto a = 
field_ct(witness_ct(&builder, a_val)); auto b = field_ct(witness_ct(&builder, bb::fr::random_element())); - EXPECT_TRUE(a.get_origin_tag().is_empty()); - EXPECT_TRUE(b.get_origin_tag().is_empty()); + EXPECT_TRUE(a.get_origin_tag().is_free_witness()); + EXPECT_TRUE(b.get_origin_tag().is_free_witness()); const size_t parent_id = 0; const auto submitted_value_origin_tag = OriginTag(parent_id, /*round_id=*/0, /*is_submitted=*/true); diff --git a/barretenberg/cpp/src/barretenberg/stdlib/primitives/group/cycle_group.cpp b/barretenberg/cpp/src/barretenberg/stdlib/primitives/group/cycle_group.cpp index f5d5941bf000..ce8645ad03f4 100644 --- a/barretenberg/cpp/src/barretenberg/stdlib/primitives/group/cycle_group.cpp +++ b/barretenberg/cpp/src/barretenberg/stdlib/primitives/group/cycle_group.cpp @@ -11,6 +11,7 @@ #include "barretenberg/ecc/curves/grumpkin/grumpkin.hpp" #include "./cycle_group.hpp" +#include "barretenberg/numeric/general/general.hpp" #include "barretenberg/stdlib/primitives/plookup/plookup.hpp" #include "barretenberg/stdlib_circuit_builders/plookup_tables/fixed_base/fixed_base.hpp" #include "barretenberg/stdlib_circuit_builders/plookup_tables/types.hpp" @@ -765,7 +766,9 @@ typename cycle_group::batch_mul_internal_output cycle_group::_ const std::span offset_generators, const bool unconditional_add) { - BB_ASSERT_EQ(scalars.size(), base_points.size()); + BB_ASSERT_EQ(!scalars.empty(), true, "Empty scalars provided to variable base batch mul!"); + BB_ASSERT_EQ(scalars.size(), base_points.size(), "Points/scalars size mismatch in variable base batch mul!"); + const size_t num_points = scalars.size(); Builder* context = nullptr; for (auto& scalar : scalars) { @@ -781,21 +784,16 @@ typename cycle_group::batch_mul_internal_output cycle_group::_ } } - size_t num_bits = 0; + size_t num_bits = scalars[0].num_bits(); for (auto& s : scalars) { - num_bits = std::max(num_bits, s.num_bits()); + BB_ASSERT_EQ(s.num_bits(), num_bits, "Scalars of different bit-lengths not 
supported!"); } - size_t num_rounds = (num_bits + TABLE_BITS - 1) / TABLE_BITS; - - const size_t num_points = scalars.size(); + size_t num_rounds = numeric::ceil_div(num_bits, TABLE_BITS); std::vector scalar_slices; scalar_slices.reserve(num_points); - for (size_t i = 0; i < num_points; ++i) { - scalar_slices.emplace_back(straus_scalar_slice(context, scalars[i], TABLE_BITS)); - // AUDITTODO: temporary safety check. See test MixedLengthScalarsIsNotSupported - BB_ASSERT_EQ( - scalar_slices[i].slices_native.size() == num_rounds, true, "Scalars of different sizes not supported!"); + for (const auto& scalar : scalars) { + scalar_slices.emplace_back(context, scalar, TABLE_BITS); } /** @@ -805,31 +803,24 @@ typename cycle_group::batch_mul_internal_output cycle_group::_ * generation times */ std::vector operation_transcript; - std::vector> native_straus_tables; Element offset_generator_accumulator = offset_generators[0]; { + // Construct native straus lookup table for each point + std::vector> native_straus_tables; for (size_t i = 0; i < num_points; ++i) { - std::vector native_straus_table; - native_straus_table.emplace_back(offset_generators[i + 1]); - size_t table_size = 1ULL << TABLE_BITS; - for (size_t j = 1; j < table_size; ++j) { - native_straus_table.emplace_back(native_straus_table[j - 1] + base_points[i].get_value()); - } - native_straus_tables.emplace_back(native_straus_table); - } - for (size_t i = 0; i < num_points; ++i) { - auto table_transcript = straus_lookup_table::compute_straus_lookup_table_hints( + std::vector table_transcript = straus_lookup_table::compute_straus_lookup_table_hints( base_points[i].get_value(), offset_generators[i + 1], TABLE_BITS); std::copy(table_transcript.begin() + 1, table_transcript.end(), std::back_inserter(operation_transcript)); + native_straus_tables.emplace_back(std::move(table_transcript)); } - Element accumulator = offset_generators[0]; + // Perform the Straus algorithm natively to generate the witness values (hints) for all 
intermediate points + Element accumulator = offset_generators[0]; for (size_t i = 0; i < num_rounds; ++i) { if (i != 0) { for (size_t j = 0; j < TABLE_BITS; ++j) { - // offset_generator_accuulator is a regular Element, so dbl() won't add constraints accumulator = accumulator.dbl(); - operation_transcript.emplace_back(accumulator); + operation_transcript.push_back(accumulator); offset_generator_accumulator = offset_generator_accumulator.dbl(); } } @@ -839,29 +830,30 @@ typename cycle_group::batch_mul_internal_output cycle_group::_ accumulator += point; - operation_transcript.emplace_back(accumulator); + operation_transcript.push_back(accumulator); offset_generator_accumulator = offset_generator_accumulator + Element(offset_generators[j + 1]); } } } - // Normalize the computed witness points and convert into AffineElement type + // Normalize the computed witness points and convert them into AffineElements Element::batch_normalize(&operation_transcript[0], operation_transcript.size()); - std::vector operation_hints; operation_hints.reserve(operation_transcript.size()); - for (auto& element : operation_transcript) { - operation_hints.emplace_back(AffineElement(element.x, element.y)); + for (const Element& element : operation_transcript) { + operation_hints.emplace_back(element.x, element.y); } + // Construct an in-circuit straus lookup table for each point std::vector point_tables; const size_t hints_per_table = (1ULL << TABLE_BITS) - 1; OriginTag tag{}; for (size_t i = 0; i < num_points; ++i) { - std::span table_hints(&operation_hints[i * hints_per_table], hints_per_table); // Merge tags tag = OriginTag(tag, scalars[i].get_origin_tag(), base_points[i].get_origin_tag()); - point_tables.emplace_back(straus_lookup_table(context, base_points[i], offset_generators[i + 1], TABLE_BITS)); + + std::span table_hints(&operation_hints[i * hints_per_table], hints_per_table); + point_tables.emplace_back(context, base_points[i], offset_generators[i + 1], TABLE_BITS, table_hints); } 
AffineElement* hint_ptr = &operation_hints[num_points * hints_per_table]; @@ -874,39 +866,34 @@ typename cycle_group::batch_mul_internal_output cycle_group::_ std::vector points_to_add; for (size_t i = 0; i < num_rounds; ++i) { for (size_t j = 0; j < num_points; ++j) { - const std::optional scalar_slice = scalar_slices[j].read(num_rounds - i - 1); - // if we are doing a batch mul over scalars of different bit-lengths, we may not have any scalar bits for a - // given round and a given scalar - if (scalar_slice.has_value()) { - const cycle_group point = point_tables[j].read(scalar_slice.value()); - points_to_add.emplace_back(point); - } + const field_t scalar_slice = scalar_slices[j].read(num_rounds - i - 1); + const cycle_group point = point_tables[j].read(scalar_slice); + points_to_add.push_back(point); } } + // Perform the Straus algorithm in-circuit, using the previously computed native hints std::vector> x_coordinate_checks; size_t point_counter = 0; for (size_t i = 0; i < num_rounds; ++i) { + // perform once-per-round doublings (except for first round) if (i != 0) { for (size_t j = 0; j < TABLE_BITS; ++j) { accumulator = accumulator.dbl(*hint_ptr); hint_ptr++; } } - + // perform each round's additions for (size_t j = 0; j < num_points; ++j) { - const std::optional scalar_slice = scalar_slices[j].read(num_rounds - i - 1); - // if we are doing a batch mul over scalars of different bit-lengths, we may not have a bit slice - // for a given round and a given scalar - BB_ASSERT_EQ(scalar_slice.value().get_value(), scalar_slices[j].slices_native[num_rounds - i - 1]); - if (scalar_slice.has_value()) { - const auto& point = points_to_add[point_counter++]; - if (!unconditional_add) { - x_coordinate_checks.push_back({ accumulator.x, point.x }); - } - accumulator = accumulator.unconditional_add(point, *hint_ptr); - hint_ptr++; + field_t scalar_slice = scalar_slices[j].read(num_rounds - i - 1); + + BB_ASSERT_EQ(scalar_slice.get_value(), 
scalar_slices[j].slices_native[num_rounds - i - 1]); + const auto& point = points_to_add[point_counter++]; + if (!unconditional_add) { + x_coordinate_checks.emplace_back(accumulator.x, point.x); } + accumulator = accumulator.unconditional_add(point, *hint_ptr); + hint_ptr++; } } @@ -915,7 +902,7 @@ typename cycle_group::batch_mul_internal_output cycle_group::_ // because `assert_is_not_zero` witness generation needs a modular inversion (expensive) field_t coordinate_check_product = 1; for (auto& [x1, x2] : x_coordinate_checks) { - auto x_diff = x2 - x1; + const field_t x_diff = x2 - x1; coordinate_check_product *= x_diff; } coordinate_check_product.assert_is_not_zero("_variable_base_batch_mul_internal x-coordinate collision"); diff --git a/barretenberg/cpp/src/barretenberg/stdlib/primitives/group/cycle_group.test.cpp b/barretenberg/cpp/src/barretenberg/stdlib/primitives/group/cycle_group.test.cpp index c4bcbf68a84f..59a84992d524 100644 --- a/barretenberg/cpp/src/barretenberg/stdlib/primitives/group/cycle_group.test.cpp +++ b/barretenberg/cpp/src/barretenberg/stdlib/primitives/group/cycle_group.test.cpp @@ -733,227 +733,272 @@ TYPED_TEST(CycleGroupTest, TestSubtract) EXPECT_EQ(proof_result, true); } -TYPED_TEST(CycleGroupTest, TestBatchMul) +/** + * @brief Assign different tags to all points and scalars and return the union of that tag + * @details We assign the tags with the same round index to a (point,scalar) pair, but the point is treated as + * submitted value, while scalar as a challenge. 
Merging these tags should not run into any edgecases + * + */ +template auto assign_and_merge_tags(T1& points, T2& scalars) +{ + OriginTag merged_tag; + for (size_t i = 0; i < points.size(); i++) { + const auto point_tag = OriginTag(/*parent_index=*/0, /*round_index=*/i, /*is_submitted=*/true); + const auto scalar_tag = OriginTag(/*parent_index=*/0, /*round_index=*/i, /*is_submitted=*/false); + + merged_tag = OriginTag(merged_tag, OriginTag(point_tag, scalar_tag)); + points[i].set_origin_tag(point_tag); + scalars[i].set_origin_tag(scalar_tag); + } + return merged_tag; +} + +TYPED_TEST(CycleGroupTest, TestBatchMulGeneralMSM) { STDLIB_TYPE_ALIASES; auto builder = Builder(); const size_t num_muls = 1; - /** - * @brief Assign different tags to all points and scalars and return the union of that tag - * - *@details We assign the tags with the same round index to a (point,scalar) pair, but the point is treated as - *submitted value, while scalar as a challenge. Merging these tags should not run into any edgecases - */ - auto assign_and_merge_tags = [](auto& points, auto& scalars) { - OriginTag merged_tag; - for (size_t i = 0; i < points.size(); i++) { - const auto point_tag = OriginTag(/*parent_index=*/0, /*round_index=*/i, /*is_submitted=*/true); - const auto scalar_tag = OriginTag(/*parent_index=*/0, /*round_index=*/i, /*is_submitted=*/false); - - merged_tag = OriginTag(merged_tag, OriginTag(point_tag, scalar_tag)); - points[i].set_origin_tag(point_tag); - scalars[i].set_origin_tag(scalar_tag); - } - return merged_tag; - }; // case 1, general MSM with inputs that are combinations of constant and witnesses - { - std::vector points; - std::vector scalars; - Element expected = Group::point_at_infinity; + std::vector points; + std::vector scalars; + Element expected = Group::point_at_infinity; - for (size_t i = 0; i < num_muls; ++i) { - auto element = TestFixture::generators[i]; - typename Group::Fr scalar = Group::Fr::random_element(&engine); + for (size_t i = 0; i < 
num_muls; ++i) { + auto element = TestFixture::generators[i]; + typename Group::Fr scalar = Group::Fr::random_element(&engine); - // 1: add entry where point, scalar are witnesses - expected += (element * scalar); - points.emplace_back(cycle_group_ct::from_witness(&builder, element)); - scalars.emplace_back(cycle_group_ct::cycle_scalar::from_witness(&builder, scalar)); + // 1: add entry where point, scalar are witnesses + expected += (element * scalar); + points.emplace_back(cycle_group_ct::from_witness(&builder, element)); + scalars.emplace_back(cycle_group_ct::cycle_scalar::from_witness(&builder, scalar)); - // 2: add entry where point is constant, scalar is witness - expected += (element * scalar); - points.emplace_back(cycle_group_ct(element)); - scalars.emplace_back(cycle_group_ct::cycle_scalar::from_witness(&builder, scalar)); + // 2: add entry where point is constant, scalar is witness + expected += (element * scalar); + points.emplace_back(cycle_group_ct(element)); + scalars.emplace_back(cycle_group_ct::cycle_scalar::from_witness(&builder, scalar)); - // 3: add entry where point is witness, scalar is constant - expected += (element * scalar); - points.emplace_back(cycle_group_ct::from_witness(&builder, element)); - scalars.emplace_back(typename cycle_group_ct::cycle_scalar(scalar)); + // 3: add entry where point is witness, scalar is constant + expected += (element * scalar); + points.emplace_back(cycle_group_ct::from_witness(&builder, element)); + scalars.emplace_back(typename cycle_group_ct::cycle_scalar(scalar)); - // 4: add entry where point is constant, scalar is constant - expected += (element * scalar); - points.emplace_back(cycle_group_ct(element)); - scalars.emplace_back(typename cycle_group_ct::cycle_scalar(scalar)); - } + // 4: add entry where point is constant, scalar is constant + expected += (element * scalar); + points.emplace_back(cycle_group_ct(element)); + scalars.emplace_back(typename cycle_group_ct::cycle_scalar(scalar)); + } - // Here 
and in the following cases assign different tags to points and scalars and get the union of them back - const auto expected_tag = assign_and_merge_tags(points, scalars); + // Here and in the following cases assign different tags to points and scalars and get the union of them back + const auto expected_tag = assign_and_merge_tags(points, scalars); - auto result = cycle_group_ct::batch_mul(points, scalars); - EXPECT_EQ(result.get_value(), AffineElement(expected)); - // The tag should the union of all tags - EXPECT_EQ(result.get_origin_tag(), expected_tag); - } + auto result = cycle_group_ct::batch_mul(points, scalars); + EXPECT_EQ(result.get_value(), AffineElement(expected)); + // The tag should the union of all tags + EXPECT_EQ(result.get_origin_tag(), expected_tag); + + bool check_result = CircuitChecker::check(builder); + EXPECT_EQ(check_result, true); +} + +TYPED_TEST(CycleGroupTest, TestBatchMulProducesInfinity) +{ + STDLIB_TYPE_ALIASES; + auto builder = Builder(); // case 2, MSM that produces point at infinity - { - std::vector points; - std::vector scalars; + std::vector points; + std::vector scalars; - auto element = TestFixture::generators[0]; - typename Group::Fr scalar = Group::Fr::random_element(&engine); - points.emplace_back(cycle_group_ct::from_witness(&builder, element)); - scalars.emplace_back(cycle_group_ct::cycle_scalar::from_witness(&builder, scalar)); + auto element = TestFixture::generators[0]; + typename Group::Fr scalar = Group::Fr::random_element(&engine); + points.emplace_back(cycle_group_ct::from_witness(&builder, element)); + scalars.emplace_back(cycle_group_ct::cycle_scalar::from_witness(&builder, scalar)); - points.emplace_back(cycle_group_ct::from_witness(&builder, element)); - scalars.emplace_back(cycle_group_ct::cycle_scalar::from_witness(&builder, -scalar)); + points.emplace_back(cycle_group_ct::from_witness(&builder, element)); + scalars.emplace_back(cycle_group_ct::cycle_scalar::from_witness(&builder, -scalar)); - const auto 
expected_tag = assign_and_merge_tags(points, scalars); + const auto expected_tag = assign_and_merge_tags(points, scalars); - auto result = cycle_group_ct::batch_mul(points, scalars); - EXPECT_TRUE(result.is_point_at_infinity().get_value()); + auto result = cycle_group_ct::batch_mul(points, scalars); + EXPECT_TRUE(result.is_point_at_infinity().get_value()); - EXPECT_EQ(result.get_origin_tag(), expected_tag); - } + EXPECT_EQ(result.get_origin_tag(), expected_tag); + + bool check_result = CircuitChecker::check(builder); + EXPECT_EQ(check_result, true); +} + +TYPED_TEST(CycleGroupTest, TestBatchMulMultiplyByZero) +{ + STDLIB_TYPE_ALIASES; + auto builder = Builder(); // case 3. Multiply by zero - { - std::vector points; - std::vector scalars; + std::vector points; + std::vector scalars; - auto element = TestFixture::generators[0]; - typename Group::Fr scalar = 0; - points.emplace_back(cycle_group_ct::from_witness(&builder, element)); - scalars.emplace_back(cycle_group_ct::cycle_scalar::from_witness(&builder, scalar)); + auto element = TestFixture::generators[0]; + typename Group::Fr scalar = 0; + points.emplace_back(cycle_group_ct::from_witness(&builder, element)); + scalars.emplace_back(cycle_group_ct::cycle_scalar::from_witness(&builder, scalar)); - const auto expected_tag = assign_and_merge_tags(points, scalars); - auto result = cycle_group_ct::batch_mul(points, scalars); - EXPECT_TRUE(result.is_point_at_infinity().get_value()); - EXPECT_EQ(result.get_origin_tag(), expected_tag); - } + const auto expected_tag = assign_and_merge_tags(points, scalars); + auto result = cycle_group_ct::batch_mul(points, scalars); + EXPECT_TRUE(result.is_point_at_infinity().get_value()); + EXPECT_EQ(result.get_origin_tag(), expected_tag); + + bool check_result = CircuitChecker::check(builder); + EXPECT_EQ(check_result, true); +} + +TYPED_TEST(CycleGroupTest, TestBatchMulInputsAreInfinity) +{ + STDLIB_TYPE_ALIASES; + auto builder = Builder(); // case 4. 
Inputs are points at infinity + std::vector points; + std::vector scalars; + + auto element = TestFixture::generators[0]; + typename Group::Fr scalar = Group::Fr::random_element(&engine); + + // is_infinity = witness + { + cycle_group_ct point = cycle_group_ct::from_witness(&builder, element); + point.set_point_at_infinity(witness_ct(&builder, true)); + points.emplace_back(point); + scalars.emplace_back(cycle_group_ct::cycle_scalar::from_witness(&builder, scalar)); + } + // is_infinity = constant { - std::vector points; - std::vector scalars; + cycle_group_ct point = cycle_group_ct::from_witness(&builder, element); + point.set_point_at_infinity(true); + points.emplace_back(point); + scalars.emplace_back(cycle_group_ct::cycle_scalar::from_witness(&builder, scalar)); + } - auto element = TestFixture::generators[0]; - typename Group::Fr scalar = Group::Fr::random_element(&engine); + const auto expected_tag = assign_and_merge_tags(points, scalars); + auto result = cycle_group_ct::batch_mul(points, scalars); + EXPECT_TRUE(result.is_point_at_infinity().get_value()); + EXPECT_EQ(result.get_origin_tag(), expected_tag); - // is_infinity = witness - { - cycle_group_ct point = cycle_group_ct::from_witness(&builder, element); - point.set_point_at_infinity(witness_ct(&builder, true)); - points.emplace_back(point); - scalars.emplace_back(cycle_group_ct::cycle_scalar::from_witness(&builder, scalar)); - } - // is_infinity = constant - { - cycle_group_ct point = cycle_group_ct::from_witness(&builder, element); - point.set_point_at_infinity(true); - points.emplace_back(point); - scalars.emplace_back(cycle_group_ct::cycle_scalar::from_witness(&builder, scalar)); - } + bool check_result = CircuitChecker::check(builder); + EXPECT_EQ(check_result, true); +} - const auto expected_tag = assign_and_merge_tags(points, scalars); - auto result = cycle_group_ct::batch_mul(points, scalars); - EXPECT_TRUE(result.is_point_at_infinity().get_value()); - EXPECT_EQ(result.get_origin_tag(), 
expected_tag); - } +TYPED_TEST(CycleGroupTest, TestBatchMulFixedBaseInLookupTable) +{ + STDLIB_TYPE_ALIASES; + auto builder = Builder(); + const size_t num_muls = 1; // case 5, fixed-base MSM with inputs that are combinations of constant and witnesses (group elements are in // lookup table) - { - std::vector points; - std::vector scalars; - std::vector scalars_native; - Element expected = Group::point_at_infinity; - for (size_t i = 0; i < num_muls; ++i) { - auto element = plookup::fixed_base::table::lhs_generator_point(); - typename Group::Fr scalar = Group::Fr::random_element(&engine); - - // 1: add entry where point is constant, scalar is witness - expected += (element * scalar); - points.emplace_back(element); - scalars.emplace_back(cycle_group_ct::cycle_scalar::from_witness(&builder, scalar)); - scalars_native.emplace_back(uint256_t(scalar)); - - // 2: add entry where point is constant, scalar is constant - element = plookup::fixed_base::table::rhs_generator_point(); - expected += (element * scalar); - points.emplace_back(element); - scalars.emplace_back(typename cycle_group_ct::cycle_scalar(scalar)); - scalars_native.emplace_back(uint256_t(scalar)); - } - const auto expected_tag = assign_and_merge_tags(points, scalars); - auto result = cycle_group_ct::batch_mul(points, scalars); - EXPECT_EQ(result.get_value(), AffineElement(expected)); - EXPECT_EQ(result.get_value(), crypto::pedersen_commitment::commit_native(scalars_native)); - EXPECT_EQ(result.get_origin_tag(), expected_tag); + std::vector points; + std::vector scalars; + std::vector scalars_native; + Element expected = Group::point_at_infinity; + for (size_t i = 0; i < num_muls; ++i) { + auto element = plookup::fixed_base::table::lhs_generator_point(); + typename Group::Fr scalar = Group::Fr::random_element(&engine); + + // 1: add entry where point is constant, scalar is witness + expected += (element * scalar); + points.emplace_back(element); + 
scalars.emplace_back(cycle_group_ct::cycle_scalar::from_witness(&builder, scalar)); + scalars_native.emplace_back(uint256_t(scalar)); + + // 2: add entry where point is constant, scalar is constant + element = plookup::fixed_base::table::rhs_generator_point(); + expected += (element * scalar); + points.emplace_back(element); + scalars.emplace_back(typename cycle_group_ct::cycle_scalar(scalar)); + scalars_native.emplace_back(uint256_t(scalar)); } + const auto expected_tag = assign_and_merge_tags(points, scalars); + auto result = cycle_group_ct::batch_mul(points, scalars); + EXPECT_EQ(result.get_value(), AffineElement(expected)); + EXPECT_EQ(result.get_value(), crypto::pedersen_commitment::commit_native(scalars_native)); + EXPECT_EQ(result.get_origin_tag(), expected_tag); + bool check_result = CircuitChecker::check(builder); + EXPECT_EQ(check_result, true); +} + +TYPED_TEST(CycleGroupTest, TestBatchMulFixedBaseSomeInLookupTable) +{ + STDLIB_TYPE_ALIASES; + auto builder = Builder(); + + const size_t num_muls = 1; // case 6, fixed-base MSM with inputs that are combinations of constant and witnesses (some group elements are // in lookup table) - { - std::vector points; - std::vector scalars; - std::vector scalars_native; - Element expected = Group::point_at_infinity; - for (size_t i = 0; i < num_muls; ++i) { - auto element = plookup::fixed_base::table::lhs_generator_point(); - typename Group::Fr scalar = Group::Fr::random_element(&engine); - - // 1: add entry where point is constant, scalar is witness - expected += (element * scalar); - points.emplace_back(element); - scalars.emplace_back(cycle_group_ct::cycle_scalar::from_witness(&builder, scalar)); - scalars_native.emplace_back(scalar); - - // 2: add entry where point is constant, scalar is constant - element = plookup::fixed_base::table::rhs_generator_point(); - expected += (element * scalar); - points.emplace_back(element); - scalars.emplace_back(typename cycle_group_ct::cycle_scalar(scalar)); - 
scalars_native.emplace_back(scalar); - - // // 3: add entry where point is constant, scalar is witness - scalar = Group::Fr::random_element(&engine); - element = Group::one * Group::Fr::random_element(&engine); - expected += (element * scalar); - points.emplace_back(element); - scalars.emplace_back(cycle_group_ct::cycle_scalar::from_witness(&builder, scalar)); - scalars_native.emplace_back(scalar); - } - const auto expected_tag = assign_and_merge_tags(points, scalars); - auto result = cycle_group_ct::batch_mul(points, scalars); - EXPECT_EQ(result.get_value(), AffineElement(expected)); - EXPECT_EQ(result.get_origin_tag(), expected_tag); + std::vector points; + std::vector scalars; + std::vector scalars_native; + Element expected = Group::point_at_infinity; + for (size_t i = 0; i < num_muls; ++i) { + auto element = plookup::fixed_base::table::lhs_generator_point(); + typename Group::Fr scalar = Group::Fr::random_element(&engine); + + // 1: add entry where point is constant, scalar is witness + expected += (element * scalar); + points.emplace_back(element); + scalars.emplace_back(cycle_group_ct::cycle_scalar::from_witness(&builder, scalar)); + scalars_native.emplace_back(scalar); + + // 2: add entry where point is constant, scalar is constant + element = plookup::fixed_base::table::rhs_generator_point(); + expected += (element * scalar); + points.emplace_back(element); + scalars.emplace_back(typename cycle_group_ct::cycle_scalar(scalar)); + scalars_native.emplace_back(scalar); + + // 3: add entry where point is constant, scalar is witness + scalar = Group::Fr::random_element(&engine); + element = Group::one * Group::Fr::random_element(&engine); + expected += (element * scalar); + points.emplace_back(element); + scalars.emplace_back(cycle_group_ct::cycle_scalar::from_witness(&builder, scalar)); + scalars_native.emplace_back(scalar); } + const auto expected_tag = assign_and_merge_tags(points, scalars); + auto result = cycle_group_ct::batch_mul(points, scalars); + 
EXPECT_EQ(result.get_value(), AffineElement(expected)); + EXPECT_EQ(result.get_origin_tag(), expected_tag); + bool check_result = CircuitChecker::check(builder); + EXPECT_EQ(check_result, true); +} + +TYPED_TEST(CycleGroupTest, TestBatchMulFixedBaseZeroScalars) +{ + STDLIB_TYPE_ALIASES; + auto builder = Builder(); + + const size_t num_muls = 1; // case 7, Fixed-base MSM where input scalars are 0 - { - std::vector points; - std::vector scalars; + std::vector points; + std::vector scalars; - for (size_t i = 0; i < num_muls; ++i) { - auto element = plookup::fixed_base::table::lhs_generator_point(); - typename Group::Fr scalar = 0; + for (size_t i = 0; i < num_muls; ++i) { + auto element = plookup::fixed_base::table::lhs_generator_point(); + typename Group::Fr scalar = 0; - // 1: add entry where point is constant, scalar is witness - points.emplace_back((element)); - scalars.emplace_back(cycle_group_ct::cycle_scalar::from_witness(&builder, scalar)); + // 1: add entry where point is constant, scalar is witness + points.emplace_back((element)); + scalars.emplace_back(cycle_group_ct::cycle_scalar::from_witness(&builder, scalar)); - // // 2: add entry where point is constant, scalar is constant - points.emplace_back((element)); - scalars.emplace_back(typename cycle_group_ct::cycle_scalar(scalar)); - } - const auto expected_tag = assign_and_merge_tags(points, scalars); - auto result = cycle_group_ct::batch_mul(points, scalars); - EXPECT_EQ(result.is_point_at_infinity().get_value(), true); - EXPECT_EQ(result.get_origin_tag(), expected_tag); + // 2: add entry where point is constant, scalar is constant + points.emplace_back((element)); + scalars.emplace_back(typename cycle_group_ct::cycle_scalar(scalar)); } + const auto expected_tag = assign_and_merge_tags(points, scalars); + auto result = cycle_group_ct::batch_mul(points, scalars); + EXPECT_EQ(result.is_point_at_infinity().get_value(), true); + EXPECT_EQ(result.get_origin_tag(), expected_tag); bool check_result = 
CircuitChecker::check(builder); EXPECT_EQ(check_result, true); diff --git a/barretenberg/cpp/src/barretenberg/stdlib/primitives/group/straus_lookup_table.cpp b/barretenberg/cpp/src/barretenberg/stdlib/primitives/group/straus_lookup_table.cpp index e25c233f3cfa..2b00073f282b 100644 --- a/barretenberg/cpp/src/barretenberg/stdlib/primitives/group/straus_lookup_table.cpp +++ b/barretenberg/cpp/src/barretenberg/stdlib/primitives/group/straus_lookup_table.cpp @@ -30,11 +30,10 @@ std::vector::Element> straus_lookup_table< size_t table_bits) { const size_t table_size = 1UL << table_bits; - Element base = base_point.is_point_at_infinity() ? Group::one : base_point; std::vector hints; hints.emplace_back(offset_generator); for (size_t i = 1; i < table_size; ++i) { - hints.emplace_back(hints[i - 1] + base); + hints.emplace_back(hints[i - 1] + base_point); } return hints; } @@ -81,6 +80,10 @@ straus_lookup_table::straus_lookup_table(Builder* context, field_t modded_x = field_t::conditional_assign(base_point.is_point_at_infinity(), fallback_point.x, base_point.x); field_t modded_y = field_t::conditional_assign(base_point.is_point_at_infinity(), fallback_point.y, base_point.y); cycle_group modded_base_point(modded_x, modded_y, false); + // We assume that the native hints (if present) do not account for the point at infinity edge case in the same way + // as above (i.e. replacing with "one") so we avoid using any provided hints in this case. (N.B. No efficiency is + // lost here since native addition with the point at infinity is nearly free). + const bool hint_available = hints.has_value() && !base_point.is_point_at_infinity().get_value(); // if the input point is constant, it is cheaper to fix the point as a witness and then derive the table, than it is // to derive the table and fix its witnesses to be constant! 
(due to group additions = 1 gate, and fixing x/y coords @@ -90,7 +93,7 @@ straus_lookup_table::straus_lookup_table(Builder* context, point_table[0] = cycle_group::from_constant_witness(_context, offset_generator.get_value()); for (size_t i = 1; i < table_size; ++i) { std::optional hint = - hints.has_value() ? std::optional(hints.value()[i - 1]) : std::nullopt; + hint_available ? std::optional(hints.value()[i - 1]) : std::nullopt; point_table[i] = point_table[i - 1].unconditional_add(modded_base_point, hint); } } else { @@ -98,7 +101,7 @@ straus_lookup_table::straus_lookup_table(Builder* context, // ensure all of the ecc add gates are lined up so that we can pay 1 gate per add and not 2 for (size_t i = 1; i < table_size; ++i) { std::optional hint = - hints.has_value() ? std::optional(hints.value()[i - 1]) : std::nullopt; + hint_available ? std::optional(hints.value()[i - 1]) : std::nullopt; x_coordinate_checks.emplace_back(point_table[i - 1].x, modded_base_point.x); point_table[i] = point_table[i - 1].unconditional_add(modded_base_point, hint); } diff --git a/barretenberg/cpp/src/barretenberg/stdlib/primitives/group/straus_scalar_slice.cpp b/barretenberg/cpp/src/barretenberg/stdlib/primitives/group/straus_scalar_slice.cpp index 0db07b8d4d09..925c84d51a3a 100644 --- a/barretenberg/cpp/src/barretenberg/stdlib/primitives/group/straus_scalar_slice.cpp +++ b/barretenberg/cpp/src/barretenberg/stdlib/primitives/group/straus_scalar_slice.cpp @@ -127,11 +127,9 @@ straus_scalar_slice::straus_scalar_slice(Builder* context, * @param index * @return field_t */ -template std::optional> straus_scalar_slice::read(size_t index) +template field_t straus_scalar_slice::read(size_t index) { - if (index >= slices.size()) { - return std::nullopt; - } + BB_ASSERT_LT(index, slices.size(), "Straus scalar slice index out of bounds!"); return slices[index]; } diff --git a/barretenberg/cpp/src/barretenberg/stdlib/primitives/group/straus_scalar_slice.hpp 
b/barretenberg/cpp/src/barretenberg/stdlib/primitives/group/straus_scalar_slice.hpp index 58d1b5a5c246..62895b5cba19 100644 --- a/barretenberg/cpp/src/barretenberg/stdlib/primitives/group/straus_scalar_slice.hpp +++ b/barretenberg/cpp/src/barretenberg/stdlib/primitives/group/straus_scalar_slice.hpp @@ -26,10 +26,10 @@ template class straus_scalar_slice { using field_t = stdlib::field_t; straus_scalar_slice(Builder* context, const cycle_scalar& scalars, size_t table_bits); - std::optional read(size_t index); + field_t read(size_t index); size_t _table_bits; std::vector slices; std::vector slices_native; }; -} // namespace bb::stdlib \ No newline at end of file +} // namespace bb::stdlib diff --git a/barretenberg/cpp/src/barretenberg/stdlib/primitives/safe_uint/safe_uint.cpp b/barretenberg/cpp/src/barretenberg/stdlib/primitives/safe_uint/safe_uint.cpp index f0bcd7968164..91577fb0b04a 100644 --- a/barretenberg/cpp/src/barretenberg/stdlib/primitives/safe_uint/safe_uint.cpp +++ b/barretenberg/cpp/src/barretenberg/stdlib/primitives/safe_uint/safe_uint.cpp @@ -127,6 +127,10 @@ safe_uint_t safe_uint_t::divide( safe_uint_t remainder( remainder_field, remainder_bit_size, format("divide method remainder: ", description)); + const auto merged_tag = OriginTag(get_origin_tag(), other.get_origin_tag()); + quotient.set_origin_tag(merged_tag); + remainder.set_origin_tag(merged_tag); + // This line implicitly checks we are not overflowing safe_uint_t int_val = quotient * other + remainder; @@ -138,7 +142,6 @@ safe_uint_t safe_uint_t::divide( this->assert_equal(int_val, "divide method quotient and/or remainder incorrect"); - quotient.set_origin_tag(OriginTag(get_origin_tag(), other.get_origin_tag())); return quotient; } @@ -161,6 +164,10 @@ template safe_uint_t safe_uint_t::operator/ safe_uint_t remainder( remainder_field, (size_t)(other.current_max.get_msb() + 1), format("/ operator remainder")); + const auto merged_tag = OriginTag(get_origin_tag(), other.get_origin_tag()); + 
quotient.set_origin_tag(merged_tag); + remainder.set_origin_tag(merged_tag); + // This line implicitly checks we are not overflowing safe_uint_t int_val = quotient * other + remainder; @@ -172,7 +179,6 @@ template safe_uint_t safe_uint_t::operator/ this->assert_equal(int_val, "/ operator quotient and/or remainder incorrect"); - quotient.set_origin_tag(OriginTag(get_origin_tag(), other.get_origin_tag())); return quotient; } diff --git a/barretenberg/cpp/src/barretenberg/stdlib/primitives/safe_uint/safe_uint.test.cpp b/barretenberg/cpp/src/barretenberg/stdlib/primitives/safe_uint/safe_uint.test.cpp index 195a1365913e..a6d3839fc37d 100644 --- a/barretenberg/cpp/src/barretenberg/stdlib/primitives/safe_uint/safe_uint.test.cpp +++ b/barretenberg/cpp/src/barretenberg/stdlib/primitives/safe_uint/safe_uint.test.cpp @@ -437,6 +437,8 @@ TYPED_TEST(SafeUintTest, TestDivideMethod) field_ct a1(witness_ct(&builder, 2)); field_ct b1(witness_ct(&builder, 9)); + a1.unset_free_witness_tag(); + b1.unset_free_witness_tag(); suint_ct c1(a1, 2); c1.set_origin_tag(submitted_value_origin_tag); suint_ct d1(b1, 4); diff --git a/barretenberg/cpp/src/barretenberg/stdlib/transcript/transcript.hpp b/barretenberg/cpp/src/barretenberg/stdlib/transcript/transcript.hpp index 2326ceab03e4..01e837830bcc 100644 --- a/barretenberg/cpp/src/barretenberg/stdlib/transcript/transcript.hpp +++ b/barretenberg/cpp/src/barretenberg/stdlib/transcript/transcript.hpp @@ -22,7 +22,6 @@ template struct StdlibTranscriptParams { { ASSERT(!data.empty()); - return stdlib::poseidon2::hash(data); } /** diff --git a/barretenberg/cpp/src/barretenberg/transcript/transcript.hpp b/barretenberg/cpp/src/barretenberg/transcript/transcript.hpp index 5c95cedc9084..368c9df2643f 100644 --- a/barretenberg/cpp/src/barretenberg/transcript/transcript.hpp +++ b/barretenberg/cpp/src/barretenberg/transcript/transcript.hpp @@ -385,6 +385,15 @@ template class BaseTranscript { manifest.add_challenge(round_number, labels...); } + // In case the 
transcript is used for recursive verification, we need to sanitize current round data so we don't + // get an origin tag violation inside the hasher. We are doing this to ensure that the free witness tagged + // elements that are sent to the transcript and are assigned tags externally, don't trigger the origin tag + // security mechanism while we are hashing them + if constexpr (in_circuit) { + for (auto& element : current_round_data) { + element.unset_free_witness_tag(); + } + } // Compute the new challenge buffer from which we derive the challenges. // Create challenges from Frs. @@ -500,6 +509,13 @@ template class BaseTranscript { */ DataType hash_independent_buffer() { + // In case the transcript is used for recursive verification, we need to sanitize current round data so we don't + // get an origin tag violation inside the hasher + if constexpr (in_circuit) { + for (auto& element : independent_hash_buffer) { + element.unset_free_witness_tag(); + } + } DataType buffer_hash = TranscriptParams::hash(independent_hash_buffer); independent_hash_buffer.clear(); return buffer_hash; @@ -606,19 +622,6 @@ template class BaseTranscript { BB_ASSERT_LTE(num_frs_read + element_size, proof_data.size()); auto element_frs = std::span{ proof_data }.subspan(num_frs_read, element_size); - num_frs_read += element_size; - - BaseTranscript::add_element_frs_to_hash_buffer(label, element_frs); - - auto element = TranscriptParams::template deserialize(element_frs); - DEBUG_LOG(label, element); - -#ifdef LOG_INTERACTIONS - if constexpr (Loggable) { - info("received: ", label, ": ", element); - } -#endif - // In case the transcript is used for recursive verification, we can track proper Fiat-Shamir usage if constexpr (in_circuit) { // The verifier is receiving data from the prover. 
If before this we were in the challenge generation phase, @@ -627,16 +630,38 @@ template class BaseTranscript { reception_phase = true; round_index++; } - // If the element is iterable, then we need to assign origin tags to all the elements + // Assign an origin tag to the elements going into the hash buffer + const auto element_origin_tag = OriginTag(transcript_index, round_index, /*is_submitted=*/true); + for (auto& subelement : element_frs) { + subelement.set_origin_tag(element_origin_tag); + } + } + num_frs_read += element_size; + + BaseTranscript::add_element_frs_to_hash_buffer(label, element_frs); + + auto element = TranscriptParams::template deserialize(element_frs); + DEBUG_LOG(label, element); + + // Ensure that the element got assigned an origin tag + if constexpr (in_circuit) { + const auto element_origin_tag = OriginTag(transcript_index, round_index, /*is_submitted=*/true); + // If the element is iterable, then we need to check origin tags to all the elements if constexpr (is_iterable_v) { for (auto& subelement : element) { - subelement.set_origin_tag(OriginTag(transcript_index, round_index, /*is_submitted=*/true)); + ASSERT(subelement.get_origin_tag() == element_origin_tag); } } else { - // If the element is not iterable, then we need to assign an origin tag to the element - element.set_origin_tag(OriginTag(transcript_index, round_index, /*is_submitted=*/true)); + // If the element is not iterable, then we need to check an origin tag of the element + ASSERT(element.get_origin_tag() == element_origin_tag); } } +#ifdef LOG_INTERACTIONS + if constexpr (Loggable) { + info("received: ", label, ": ", element); + } +#endif + return element; } diff --git a/noir-projects/noir-protocol-circuits/bootstrap.sh b/noir-projects/noir-protocol-circuits/bootstrap.sh index 119cb9081082..9b7e67c529c0 100755 --- a/noir-projects/noir-protocol-circuits/bootstrap.sh +++ b/noir-projects/noir-protocol-circuits/bootstrap.sh @@ -28,11 +28,11 @@ export circuits_hash=$(hash_str 
"$NOIR_HASH" $(cache_content_hash "^noir-project # Circuits matching these patterns we have client-ivc keys computed, rather than ultra-honk. readarray -t ivc_patterns < <(jq -r '.[]' "../client_ivc_circuits.json") -readarray -t ivc_tail_patterns < <(jq -r '.[]' "../client_ivc_tail_circuits.json") +ivc_hiding_pattern=("hiding") readarray -t rollup_honk_patterns < <(jq -r '.[]' "../rollup_honk_circuits.json") # Convert to regex string here and export for use in exported functions. export ivc_regex=$(IFS="|"; echo "${ivc_patterns[*]}") -export ivc_tail_regex=$(IFS="|"; echo "${ivc_tail_patterns[*]}") +export hiding_kernel_regex=$(IFS="|"; echo "${ivc_hiding_pattern[*]}") export rollup_honk_regex=$(IFS="|"; echo "${rollup_honk_patterns[*]}") function on_exit { @@ -90,7 +90,7 @@ function compile { local outdir=$(mktemp -d) trap "rm -rf $outdir" EXIT function write_vk { - if echo "$name" | grep -qE "${ivc_tail_regex}"; then + if echo "$name" | grep -qE "${hiding_kernel_regex}"; then # We still need the standalone IVC vk. We also create the final IVC vk from the tail (specifically, the number of public inputs is used from it). denoise "$BB write_vk --scheme client_ivc --verifier_type standalone_hiding -b - -o $outdir" elif echo "$name" | grep -qE "${ivc_regex}"; then @@ -130,7 +130,7 @@ function compile { echo_stderr "Root rollup verifier at: $verifier_path (${SECONDS}s)" # Include the verifier path if we create it. cache_upload vk-$hash.tar.gz $key_path $verifier_path &> /dev/null - elif echo "$name" | grep -qE "${ivc_tail_regex}"; then + elif echo "$name" | grep -qE "${hiding_kernel_regex}"; then # If we are a tail kernel circuit, we also need to generate the ivc vk. SECONDS=0 local ivc_vk_path="$key_dir/${name}.ivc.vk"