From b1eaab97a36f3b17129af8bf91cb7e33ceb02ae9 Mon Sep 17 00:00:00 2001 From: Adam Tucker Date: Fri, 1 May 2026 10:17:13 -0600 Subject: [PATCH] Validate PIR Merkle proofs --- Cargo.lock | 2 +- zcash_voting/Cargo.toml | 2 +- zcash_voting/src/storage/operations.rs | 42 +++++++++++-- zcash_voting/src/zkp1.rs | 86 ++++++++++++++++++++++++++ 4 files changed, 126 insertions(+), 6 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index de2cc9db..b1261260 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3731,7 +3731,7 @@ dependencies = [ [[package]] name = "zcash_voting" -version = "0.2.2" +version = "0.2.3" dependencies = [ "blake2b_simd", "ff", diff --git a/zcash_voting/Cargo.toml b/zcash_voting/Cargo.toml index f2012b25..c731fe2e 100644 --- a/zcash_voting/Cargo.toml +++ b/zcash_voting/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "zcash_voting" -version = "0.2.2" +version = "0.2.3" edition = "2021" description = "Client-side library for Zcash shielded voting: ZKP delegation and vote-commitment proofs (Halo 2), ElGamal encryption, governance PCZT construction, Merkle witness generation, and SQLite round-state persistence." license = "MIT OR Apache-2.0" diff --git a/zcash_voting/src/storage/operations.rs b/zcash_voting/src/storage/operations.rs index 65663fc5..2c9e981a 100644 --- a/zcash_voting/src/storage/operations.rs +++ b/zcash_voting/src/storage/operations.rs @@ -345,14 +345,48 @@ impl VotingDb { }) .collect::, _>>()?; - let imt_proofs: Vec<_> = pir_client + let expected_nf_imt_root = { + let root_bytes: [u8; 32] = + params.nullifier_imt_root.as_slice().try_into().map_err(|_| { + VotingError::Internal { + message: format!( + "nullifier_imt_root must be 32 bytes, got {}", + params.nullifier_imt_root.len() + ), + } + })?; + Option::from(pasta_curves::pallas::Base::from_repr(root_bytes)).ok_or_else(|| { + VotingError::Internal { + message: "nullifier_imt_root is not a valid field element".to_string(), + } + })? + }; + + let raw_imt_proofs = pir_client .fetch_proofs(&nullifiers) .map_err(|e| VotingError::Internal { message: format!("PIR parallel fetch failed: {e}"), - })? + })?; + if raw_imt_proofs.len() != nullifiers.len() { + return Err(VotingError::Internal { + message: format!( + "PIR returned {} proofs for {} nullifiers", + raw_imt_proofs.len(), + nullifiers.len() + ), + }); + } + let imt_proofs: Vec<_> = raw_imt_proofs .into_iter() - .map(crate::zkp1::convert_pir_proof) - .collect(); + .zip(nullifiers.iter().copied()) + .map(|(proof, nullifier)| { + crate::zkp1::validate_and_convert_pir_proof( + proof, + nullifier, + expected_nf_imt_root, + ) + }) + .collect::, _>>()?; let pir_elapsed = pir_start.elapsed(); eprintln!( "[ZKP1] PIR fetch total: {:.2}s for {} proofs", diff --git a/zcash_voting/src/zkp1.rs b/zcash_voting/src/zkp1.rs index 8c8cd85e..6305cc92 100644 --- a/zcash_voting/src/zkp1.rs +++ b/zcash_voting/src/zkp1.rs @@ -87,6 +87,41 @@ pub fn convert_pir_proof(pir: pir_client::ImtProofData) -> ImtProofData { } } +fn base_hex(value: pallas::Base) -> String { + hex::encode(value.to_repr()) +} + +fn validate_pir_proof_raw( + proof: &pir_client::ImtProofData, + nullifier: pallas::Base, + expected_root: pallas::Base, +) -> Result<(), String> { + if !proof.verify(nullifier) { + return Err( + "PIR proof verification failed: Merkle path/root does not authenticate queried nullifier" + .to_string(), + ); + } + if proof.root != expected_root { + return Err(format!( + "PIR proof root mismatch: expected {}, got {}", + base_hex(expected_root), + base_hex(proof.root) + )); + } + Ok(()) +} + +pub(crate) fn validate_and_convert_pir_proof( + proof: pir_client::ImtProofData, + nullifier: pallas::Base, + expected_root: pallas::Base, +) -> Result { + validate_pir_proof_raw(&proof, nullifier, expected_root) + .map_err(|message| VotingError::Internal { message })?; + Ok(convert_pir_proof(proof)) +} + /// IMT provider that wraps pre-fetched proofs for real notes and /// fetches proofs for padded notes on-the-fly via PIR. struct PirImtProvider<'a> { @@ -112,6 +147,7 @@ impl ImtProvider for PirImtProvider<'_> { let pir_proof = client .fetch_proof(nf) .map_err(|e| ImtError(format!("PIR fetch failed: {e}")))?; + validate_pir_proof_raw(&pir_proof, nf, self.root).map_err(ImtError)?; Ok(convert_pir_proof(pir_proof)) } } @@ -662,6 +698,56 @@ mod tests { } } + fn raw_pir_proof(proof: ImtProofData) -> pir_client::ImtProofData { + pir_client::ImtProofData { + root: proof.root, + nf_bounds: proof.nf_bounds, + leaf_pos: proof.leaf_pos, + path: proof.path, + } + } + + #[test] + fn validate_and_convert_pir_proof_accepts_valid_proof() { + let imt = TestImt::new(); + let nf = imt.leaves[0][0] + pallas::Base::one(); + let proof = raw_pir_proof(imt.proof(nf)); + + let converted = validate_and_convert_pir_proof(proof, nf, imt.root).unwrap(); + + assert_eq!(converted.root, imt.root); + } + + #[test] + fn validate_and_convert_pir_proof_rejects_unverified_path() { + let imt = TestImt::new(); + let nf = imt.leaves[0][0] + pallas::Base::one(); + let proof = raw_pir_proof(imt.proof(nf)); + let boundary_value = imt.leaves[0][0]; + + let err = validate_and_convert_pir_proof(proof, boundary_value, imt.root).unwrap_err(); + + assert!( + err.to_string().contains("PIR proof verification failed"), + "unexpected error: {err}" + ); + } + + #[test] + fn validate_and_convert_pir_proof_rejects_wrong_root() { + let imt = TestImt::new(); + let nf = imt.leaves[0][0] + pallas::Base::one(); + let proof = raw_pir_proof(imt.proof(nf)); + let wrong_root = imt.root + pallas::Base::one(); + + let err = validate_and_convert_pir_proof(proof, nf, wrong_root).unwrap_err(); + + assert!( + err.to_string().contains("PIR proof root mismatch"), + "unexpected error: {err}" + ); + } + #[test] fn test_build_and_prove_validation() { let reporter = TestReporter {