Skip to content

Commit

Permalink
fix: core compress pv in sdk (#1501)
Browse files Browse the repository at this point in the history
  • Loading branch information
jtguibas authored Sep 12, 2024
1 parent 26fd15e commit ef09ebb
Show file tree
Hide file tree
Showing 7 changed files with 126 additions and 18 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

11 changes: 9 additions & 2 deletions crates/core/machine/src/io.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,13 +113,20 @@ impl SP1PublicValues {
self.buffer.write_slice(slice);
}

/// Hash the public values.
pub fn hash(&self) -> Vec<u8> {
let mut hasher = Sha256::new();
hasher.update(self.buffer.data.as_slice());
hasher.finalize().to_vec()
}

/// Hash the public values, mask the top 3 bits and return a BigUint. Matches the implementation
/// of `hashPublicValues` in the Solidity verifier.
///
/// ```solidity
/// sha256(publicValues) & bytes32(uint256((1 << 253) - 1));
/// ```
pub fn hash(&self) -> BigUint {
pub fn hash_bn254(&self) -> BigUint {
// Hash the public values.
let mut hasher = Sha256::new();
hasher.update(self.buffer.data.as_slice());
Expand Down Expand Up @@ -188,7 +195,7 @@ mod tests {

let mut public_values = SP1PublicValues::new();
public_values.write_slice(&test_bytes);
let hash = public_values.hash();
let hash = public_values.hash_bn254();

let expected_hash = "1ce987d0a7fcc2636fe87e69295ba12b1cc46c256b369ae7401c51b805ee91bd";
let expected_hash_biguint = BigUint::from_bytes_be(&hex::decode(expected_hash).unwrap());
Expand Down
4 changes: 2 additions & 2 deletions crates/prover/src/verify.rs
Original file line number Diff line number Diff line change
Expand Up @@ -441,7 +441,7 @@ pub fn verify_plonk_bn254_public_inputs(
return Err(PlonkVerificationError::InvalidVerificationKey.into());
}

let public_values_hash = public_values.hash();
let public_values_hash = public_values.hash_bn254();
if public_values_hash != expected_public_values_hash {
return Err(PlonkVerificationError::InvalidPublicValues.into());
}
Expand All @@ -464,7 +464,7 @@ pub fn verify_groth16_bn254_public_inputs(
return Err(Groth16VerificationError::InvalidVerificationKey.into());
}

let public_values_hash = public_values.hash();
let public_values_hash = public_values.hash_bn254();
if public_values_hash != expected_public_values_hash {
return Err(Groth16VerificationError::InvalidPublicValues.into());
}
Expand Down
1 change: 1 addition & 0 deletions crates/sdk/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ sysinfo = "0.30.13"
sp1-core-executor = { workspace = true }
sp1-stark = { workspace = true }
getrandom = { version = "0.2.15", features = ["custom", "js"] }
itertools = "0.13.0"

[features]
default = ["network"]
Expand Down
58 changes: 56 additions & 2 deletions crates/sdk/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,9 @@ use {std::future::Future, tokio::task::block_in_place};
pub use provers::{CpuProver, MockProver, Prover};

pub use sp1_core_executor::{ExecutionReport, HookEnv, SP1Context, SP1ContextBuilder};
pub use sp1_core_machine::{io::SP1Stdin, riscv::cost::CostEstimator, SP1_CIRCUIT_VERSION};
pub use sp1_core_machine::{
io::SP1PublicValues, io::SP1Stdin, riscv::cost::CostEstimator, SP1_CIRCUIT_VERSION,
};
pub use sp1_prover::{
CoreSC, HashableKey, InnerSC, OuterSC, PlonkBn254Proof, SP1Prover, SP1ProvingKey,
SP1VerifyingKey,
Expand Down Expand Up @@ -291,6 +293,8 @@ pub fn block_on<T>(fut: impl Future<Output = T>) -> T {
#[cfg(test)]
mod tests {

use sp1_prover::init::SP1PublicValues;

use crate::{utils, CostEstimator, ProverClient, SP1Stdin};

#[test]
Expand Down Expand Up @@ -327,6 +331,48 @@ mod tests {
client.execute(elf, stdin).max_cycles(1).run().unwrap();
}

#[test]
fn test_e2e_core() {
utils::setup_logger();
let client = ProverClient::local();
let elf =
include_bytes!("../../../examples/fibonacci/program/elf/riscv32im-succinct-zkvm-elf");
let (pk, vk) = client.setup(elf);
let mut stdin = SP1Stdin::new();
stdin.write(&10usize);

// Generate proof & verify.
let mut proof = client.prove(&pk, stdin).run().unwrap();
client.verify(&proof, &vk).unwrap();

// Test invalid public values.
proof.public_values = SP1PublicValues::from(&[255, 4, 84]);
if client.verify(&proof, &vk).is_ok() {
panic!("verified proof with invalid public values")
}
}

#[test]
fn test_e2e_compressed() {
utils::setup_logger();
let client = ProverClient::local();
let elf =
include_bytes!("../../../examples/fibonacci/program/elf/riscv32im-succinct-zkvm-elf");
let (pk, vk) = client.setup(elf);
let mut stdin = SP1Stdin::new();
stdin.write(&10usize);

// Generate proof & verify.
let mut proof = client.prove(&pk, stdin).compressed().run().unwrap();
client.verify(&proof, &vk).unwrap();

// Test invalid public values.
proof.public_values = SP1PublicValues::from(&[255, 4, 84]);
if client.verify(&proof, &vk).is_ok() {
panic!("verified proof with invalid public values")
}
}

#[test]
fn test_e2e_prove_plonk() {
utils::setup_logger();
Expand All @@ -336,8 +382,16 @@ mod tests {
let (pk, vk) = client.setup(elf);
let mut stdin = SP1Stdin::new();
stdin.write(&10usize);
let proof = client.prove(&pk, stdin).plonk().run().unwrap();

// Generate proof & verify.
let mut proof = client.prove(&pk, stdin).plonk().run().unwrap();
client.verify(&proof, &vk).unwrap();

// Test invalid public values.
proof.public_values = SP1PublicValues::from(&[255, 4, 84]);
if client.verify(&proof, &vk).is_ok() {
panic!("verified proof with invalid public values")
}
}

#[test]
Expand Down
4 changes: 2 additions & 2 deletions crates/sdk/src/provers/mock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ impl Prover<DefaultProverComponents> for MockProver {
proof: SP1Proof::Plonk(PlonkBn254Proof {
public_inputs: [
pk.vk.hash_bn254().as_canonical_biguint().to_string(),
public_values.hash().to_string(),
public_values.hash_bn254().to_string(),
],
encoded_proof: "".to_string(),
raw_proof: "".to_string(),
Expand All @@ -114,7 +114,7 @@ impl Prover<DefaultProverComponents> for MockProver {
proof: SP1Proof::Groth16(Groth16Bn254Proof {
public_inputs: [
pk.vk.hash_bn254().as_canonical_biguint().to_string(),
public_values.hash().to_string(),
public_values.hash_bn254().to_string(),
],
encoded_proof: "".to_string(),
raw_proof: "".to_string(),
Expand Down
65 changes: 55 additions & 10 deletions crates/sdk/src/provers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,19 @@ pub use cpu::CpuProver;
pub use cuda::CudaProver;
pub use mock::MockProver;

use itertools::Itertools;
use p3_field::PrimeField32;
use std::borrow::Borrow;
use std::time::Duration;

use anyhow::Result;
use sp1_core_executor::SP1Context;
use sp1_core_machine::{io::SP1Stdin, SP1_CIRCUIT_VERSION};
use sp1_prover::{
components::SP1ProverComponents, CoreSC, InnerSC, SP1CoreProofData, SP1Prover, SP1ProvingKey,
SP1ReduceProof, SP1VerifyingKey,
};
use sp1_stark::{MachineVerificationError, SP1ProverOpts};
use std::time::Duration;
use sp1_stark::{air::PublicValues, MachineVerificationError, SP1ProverOpts, Word};
use strum_macros::EnumString;
use thiserror::Error;

Expand Down Expand Up @@ -44,6 +48,8 @@ pub struct ProofOpts {

#[derive(Error, Debug)]
pub enum SP1VerificationError {
#[error("Invalid public values")]
InvalidPublicValues,
#[error("Version mismatch")]
VersionMismatch(String),
#[error("Core machine verification error: {0}")]
Expand Down Expand Up @@ -90,14 +96,53 @@ pub trait Prover<C: SP1ProverComponents>: Send + Sync {
return Err(SP1VerificationError::VersionMismatch(bundle.sp1_version.clone()));
}
match &bundle.proof {
SP1Proof::Core(proof) => self
.sp1_prover()
.verify(&SP1CoreProofData(proof.clone()), vkey)
.map_err(SP1VerificationError::Core),
SP1Proof::Compressed(proof) => self
.sp1_prover()
.verify_compressed(&SP1ReduceProof { proof: proof.clone() }, vkey)
.map_err(SP1VerificationError::Recursion),
SP1Proof::Core(proof) => {
let public_values: &PublicValues<Word<_>, _> =
proof.last().unwrap().public_values.as_slice().borrow();

// Get the commited value digest bytes.
let commited_value_digest_bytes = public_values
.committed_value_digest
.iter()
.flat_map(|w| w.0.iter().map(|x| x.as_canonical_u32() as u8))
.collect_vec();

// Make sure the commited value digest matches the public values hash.
for (a, b) in commited_value_digest_bytes.iter().zip_eq(bundle.public_values.hash())
{
if *a != b {
return Err(SP1VerificationError::InvalidPublicValues);
}
}

// Verify the core proof.
self.sp1_prover()
.verify(&SP1CoreProofData(proof.clone()), vkey)
.map_err(SP1VerificationError::Core)
}
SP1Proof::Compressed(proof) => {
let public_values: &PublicValues<Word<_>, _> =
proof.public_values.as_slice().borrow();

// Get the commited value digest bytes.
let commited_value_digest_bytes = public_values
.committed_value_digest
.iter()
.flat_map(|w| w.0.iter().map(|x| x.as_canonical_u32() as u8))
.collect_vec();

// Make sure the commited value digest matches the public values hash.
for (a, b) in commited_value_digest_bytes.iter().zip_eq(bundle.public_values.hash())
{
if *a != b {
return Err(SP1VerificationError::InvalidPublicValues);
}
}

self.sp1_prover()
.verify_compressed(&SP1ReduceProof { proof: proof.clone() }, vkey)
.map_err(SP1VerificationError::Recursion)
}
SP1Proof::Plonk(proof) => self
.sp1_prover()
.verify_plonk_bn254(
Expand Down

0 comments on commit ef09ebb

Please sign in to comment.