diff --git a/Cargo.lock b/Cargo.lock index 5ffacb584b..0a3a548194 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6492,6 +6492,7 @@ dependencies = [ "hashbrown 0.14.5", "hex", "indicatif", + "itertools 0.13.0", "log", "num-bigint 0.4.6", "p3-baby-bear", diff --git a/crates/core/machine/src/io.rs b/crates/core/machine/src/io.rs index ee6c652f73..6f5475cf22 100644 --- a/crates/core/machine/src/io.rs +++ b/crates/core/machine/src/io.rs @@ -113,13 +113,20 @@ impl SP1PublicValues { self.buffer.write_slice(slice); } + /// Hash the public values. + pub fn hash(&self) -> Vec { + let mut hasher = Sha256::new(); + hasher.update(self.buffer.data.as_slice()); + hasher.finalize().to_vec() + } + /// Hash the public values, mask the top 3 bits and return a BigUint. Matches the implementation /// of `hashPublicValues` in the Solidity verifier. /// /// ```solidity /// sha256(publicValues) & bytes32(uint256((1 << 253) - 1)); /// ``` - pub fn hash(&self) -> BigUint { + pub fn hash_bn254(&self) -> BigUint { // Hash the public values. let mut hasher = Sha256::new(); hasher.update(self.buffer.data.as_slice()); @@ -188,7 +195,7 @@ mod tests { let mut public_values = SP1PublicValues::new(); public_values.write_slice(&test_bytes); - let hash = public_values.hash(); + let hash = public_values.hash_bn254(); let expected_hash = "1ce987d0a7fcc2636fe87e69295ba12b1cc46c256b369ae7401c51b805ee91bd"; let expected_hash_biguint = BigUint::from_bytes_be(&hex::decode(expected_hash).unwrap()); diff --git a/crates/prover/src/verify.rs b/crates/prover/src/verify.rs index f9c0817ca3..99cf818d01 100644 --- a/crates/prover/src/verify.rs +++ b/crates/prover/src/verify.rs @@ -441,7 +441,7 @@ pub fn verify_plonk_bn254_public_inputs( return Err(PlonkVerificationError::InvalidVerificationKey.into()); } - let public_values_hash = public_values.hash(); + let public_values_hash = public_values.hash_bn254(); if public_values_hash != expected_public_values_hash { return Err(PlonkVerificationError::InvalidPublicValues.into()); } @@ -464,7 +464,7 @@ pub fn verify_groth16_bn254_public_inputs( return Err(Groth16VerificationError::InvalidVerificationKey.into()); } - let public_values_hash = public_values.hash(); + let public_values_hash = public_values.hash_bn254(); if public_values_hash != expected_public_values_hash { return Err(Groth16VerificationError::InvalidPublicValues.into()); } diff --git a/crates/sdk/Cargo.toml b/crates/sdk/Cargo.toml index 1d296f94bf..ada8973cea 100644 --- a/crates/sdk/Cargo.toml +++ b/crates/sdk/Cargo.toml @@ -53,6 +53,7 @@ sysinfo = "0.30.13" sp1-core-executor = { workspace = true } sp1-stark = { workspace = true } getrandom = { version = "0.2.15", features = ["custom", "js"] } +itertools = "0.13.0" [features] default = ["network"] diff --git a/crates/sdk/src/lib.rs b/crates/sdk/src/lib.rs index da73835c01..e4721a0efa 100644 --- a/crates/sdk/src/lib.rs +++ b/crates/sdk/src/lib.rs @@ -39,7 +39,9 @@ use {std::future::Future, tokio::task::block_in_place}; pub use provers::{CpuProver, MockProver, Prover}; pub use sp1_core_executor::{ExecutionReport, HookEnv, SP1Context, SP1ContextBuilder}; -pub use sp1_core_machine::{io::SP1Stdin, riscv::cost::CostEstimator, SP1_CIRCUIT_VERSION}; +pub use sp1_core_machine::{ + io::SP1PublicValues, io::SP1Stdin, riscv::cost::CostEstimator, SP1_CIRCUIT_VERSION, +}; pub use sp1_prover::{ CoreSC, HashableKey, InnerSC, OuterSC, PlonkBn254Proof, SP1Prover, SP1ProvingKey, SP1VerifyingKey, @@ -291,6 +293,8 @@ pub fn block_on(fut: impl Future) -> T { #[cfg(test)] mod tests { + use sp1_prover::init::SP1PublicValues; + use crate::{utils, CostEstimator, ProverClient, SP1Stdin}; #[test] @@ -327,6 +331,48 @@ mod tests { client.execute(elf, stdin).max_cycles(1).run().unwrap(); } + #[test] + fn test_e2e_core() { + utils::setup_logger(); + let client = ProverClient::local(); + let elf = + include_bytes!("../../../examples/fibonacci/program/elf/riscv32im-succinct-zkvm-elf"); + let (pk, vk) = client.setup(elf); + let mut stdin = SP1Stdin::new(); + stdin.write(&10usize); + + // Generate proof & verify. + let mut proof = client.prove(&pk, stdin).run().unwrap(); + client.verify(&proof, &vk).unwrap(); + + // Test invalid public values. + proof.public_values = SP1PublicValues::from(&[255, 4, 84]); + if client.verify(&proof, &vk).is_ok() { + panic!("verified proof with invalid public values") + } + } + + #[test] + fn test_e2e_compressed() { + utils::setup_logger(); + let client = ProverClient::local(); + let elf = + include_bytes!("../../../examples/fibonacci/program/elf/riscv32im-succinct-zkvm-elf"); + let (pk, vk) = client.setup(elf); + let mut stdin = SP1Stdin::new(); + stdin.write(&10usize); + + // Generate proof & verify. + let mut proof = client.prove(&pk, stdin).compressed().run().unwrap(); + client.verify(&proof, &vk).unwrap(); + + // Test invalid public values. + proof.public_values = SP1PublicValues::from(&[255, 4, 84]); + if client.verify(&proof, &vk).is_ok() { + panic!("verified proof with invalid public values") + } + } + #[test] fn test_e2e_prove_plonk() { utils::setup_logger(); @@ -336,8 +382,16 @@ mod tests { let (pk, vk) = client.setup(elf); let mut stdin = SP1Stdin::new(); stdin.write(&10usize); - let proof = client.prove(&pk, stdin).plonk().run().unwrap(); + + // Generate proof & verify. + let mut proof = client.prove(&pk, stdin).plonk().run().unwrap(); client.verify(&proof, &vk).unwrap(); + + // Test invalid public values. + proof.public_values = SP1PublicValues::from(&[255, 4, 84]); + if client.verify(&proof, &vk).is_ok() { + panic!("verified proof with invalid public values") + } } #[test] diff --git a/crates/sdk/src/provers/mock.rs b/crates/sdk/src/provers/mock.rs index d4d2467045..b774cd005b 100644 --- a/crates/sdk/src/provers/mock.rs +++ b/crates/sdk/src/provers/mock.rs @@ -97,7 +97,7 @@ impl Prover for MockProver { proof: SP1Proof::Plonk(PlonkBn254Proof { public_inputs: [ pk.vk.hash_bn254().as_canonical_biguint().to_string(), - public_values.hash().to_string(), + public_values.hash_bn254().to_string(), ], encoded_proof: "".to_string(), raw_proof: "".to_string(), @@ -114,7 +114,7 @@ impl Prover for MockProver { proof: SP1Proof::Groth16(Groth16Bn254Proof { public_inputs: [ pk.vk.hash_bn254().as_canonical_biguint().to_string(), - public_values.hash().to_string(), + public_values.hash_bn254().to_string(), ], encoded_proof: "".to_string(), raw_proof: "".to_string(), diff --git a/crates/sdk/src/provers/mod.rs b/crates/sdk/src/provers/mod.rs index 34152d8460..ecc62ff4ce 100644 --- a/crates/sdk/src/provers/mod.rs +++ b/crates/sdk/src/provers/mod.rs @@ -8,6 +8,11 @@ pub use cpu::CpuProver; pub use cuda::CudaProver; pub use mock::MockProver; +use itertools::Itertools; +use p3_field::PrimeField32; +use std::borrow::Borrow; +use std::time::Duration; + use anyhow::Result; use sp1_core_executor::SP1Context; use sp1_core_machine::{io::SP1Stdin, SP1_CIRCUIT_VERSION}; @@ -15,8 +20,7 @@ use sp1_prover::{ components::SP1ProverComponents, CoreSC, InnerSC, SP1CoreProofData, SP1Prover, SP1ProvingKey, SP1ReduceProof, SP1VerifyingKey, }; -use sp1_stark::{MachineVerificationError, SP1ProverOpts}; -use std::time::Duration; +use sp1_stark::{air::PublicValues, MachineVerificationError, SP1ProverOpts, Word}; use strum_macros::EnumString; use thiserror::Error; @@ -44,6 +48,8 @@ pub struct ProofOpts { #[derive(Error, Debug)] pub enum SP1VerificationError { + #[error("Invalid public values")] + InvalidPublicValues, #[error("Version mismatch")] VersionMismatch(String), #[error("Core machine verification error: {0}")] @@ -90,14 +96,53 @@ pub trait Prover: Send + Sync { return Err(SP1VerificationError::VersionMismatch(bundle.sp1_version.clone())); } match &bundle.proof { - SP1Proof::Core(proof) => self - .sp1_prover() - .verify(&SP1CoreProofData(proof.clone()), vkey) - .map_err(SP1VerificationError::Core), - SP1Proof::Compressed(proof) => self - .sp1_prover() - .verify_compressed(&SP1ReduceProof { proof: proof.clone() }, vkey) - .map_err(SP1VerificationError::Recursion), + SP1Proof::Core(proof) => { + let public_values: &PublicValues, _> = + proof.last().unwrap().public_values.as_slice().borrow(); + + // Get the commited value digest bytes. + let commited_value_digest_bytes = public_values + .committed_value_digest + .iter() + .flat_map(|w| w.0.iter().map(|x| x.as_canonical_u32() as u8)) + .collect_vec(); + + // Make sure the commited value digest matches the public values hash. + for (a, b) in commited_value_digest_bytes.iter().zip_eq(bundle.public_values.hash()) + { + if *a != b { + return Err(SP1VerificationError::InvalidPublicValues); + } + } + + // Verify the core proof. + self.sp1_prover() + .verify(&SP1CoreProofData(proof.clone()), vkey) + .map_err(SP1VerificationError::Core) + } + SP1Proof::Compressed(proof) => { + let public_values: &PublicValues, _> = + proof.public_values.as_slice().borrow(); + + // Get the commited value digest bytes. + let commited_value_digest_bytes = public_values + .committed_value_digest + .iter() + .flat_map(|w| w.0.iter().map(|x| x.as_canonical_u32() as u8)) + .collect_vec(); + + // Make sure the commited value digest matches the public values hash. + for (a, b) in commited_value_digest_bytes.iter().zip_eq(bundle.public_values.hash()) + { + if *a != b { + return Err(SP1VerificationError::InvalidPublicValues); + } + } + + self.sp1_prover() + .verify_compressed(&SP1ReduceProof { proof: proof.clone() }, vkey) + .map_err(SP1VerificationError::Recursion) + } SP1Proof::Plonk(proof) => self .sp1_prover() .verify_plonk_bn254(