diff --git a/air/src/air/tests.rs b/air/src/air/tests.rs index 4d16face3..6e4ff458b 100644 --- a/air/src/air/tests.rs +++ b/air/src/air/tests.rs @@ -9,6 +9,7 @@ use super::{ }; use crate::{AuxTraceRandElements, FieldExtension}; use crypto::{hashers::Blake3_256, DefaultRandomCoin, RandomCoin}; +use fri::fri_schedule::FoldingSchedule; use math::{fields::f64::BaseElement, get_power_series, polynom, FieldElement, StarkField}; use utils::collections::{BTreeMap, Vec}; @@ -219,20 +220,22 @@ impl MockAir { column_values: Vec>, trace_length: usize, ) -> Self { + let fri_constant_schedule = FoldingSchedule::new_constant(4, 31); let mut result = Self::new( TraceInfo::with_meta(4, trace_length, vec![1]), (), - ProofOptions::new(32, 8, 0, FieldExtension::None, 4, 31), + ProofOptions::new(32, 8, 0, FieldExtension::None, &fri_constant_schedule), ); result.periodic_columns = column_values; result } pub fn with_assertions(assertions: Vec>, trace_length: usize) -> Self { + let fri_constant_schedule = FoldingSchedule::new_constant(4, 31); let mut result = Self::new( TraceInfo::with_meta(4, trace_length, vec![assertions.len() as u8]), (), - ProofOptions::new(32, 8, 0, FieldExtension::None, 4, 31), + ProofOptions::new(32, 8, 0, FieldExtension::None, &fri_constant_schedule), ); result.assertions = assertions; result @@ -282,7 +285,8 @@ pub fn build_context( trace_width: usize, num_assertions: usize, ) -> AirContext { - let options = ProofOptions::new(32, 8, 0, FieldExtension::None, 4, 31); + let fri_constant_schedule = FoldingSchedule::new_constant(4, 31); + let options = ProofOptions::new(32, 8, 0, FieldExtension::None, &fri_constant_schedule); let t_degrees = vec![TransitionConstraintDegree::new(2)]; let trace_info = TraceInfo::new(trace_width, trace_length); AirContext::new(trace_info, t_degrees, num_assertions, options) diff --git a/air/src/options.rs b/air/src/options.rs index d8114ec86..51ac2e045 100644 --- a/air/src/options.rs +++ b/air/src/options.rs @@ -3,7 +3,7 @@ // This source code is licensed under the MIT license found in the // LICENSE file in the root directory of this source tree. -use fri::FriOptions; +use fri::{fri_schedule::FoldingSchedule, FriOptions}; use math::{StarkField, ToElements}; use utils::{ collections::Vec, ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable, @@ -69,6 +69,10 @@ pub enum FieldExtension { /// 4. Grinding factor - higher values increase proof soundness, but also may increase proof /// generation time. More precisely, conjectured proof soundness is bounded by /// `num_queries * log2(blowup_factor) + grinding_factor`. +/// 5. FRI Schedule - The strategy for the FRI reduction process, which can be either a constant +/// folding factor or a dynamic schedule of factors. This setting influences the FRI proof +/// generation process, potentially affecting both the proof generation time and the resultant +/// proof size. /// /// Another important parameter in defining STARK security level, which is not a part of [ProofOptions] /// is the hash function used in the protocol. The soundness of a STARK proof is limited by the @@ -80,8 +84,7 @@ pub struct ProofOptions { blowup_factor: u8, grinding_factor: u8, field_extension: FieldExtension, - fri_folding_factor: u8, - fri_remainder_max_degree: u8, + fri_schedule: FoldingSchedule, } // PROOF OPTIONS IMPLEMENTATION @@ -106,16 +109,16 @@ impl ProofOptions { /// - `num_queries` is zero or greater than 255. /// - `blowup_factor` is smaller than 2, greater than 128, or is not a power of two. /// - `grinding_factor` is greater than 32. - /// - `fri_folding_factor` is not 2, 4, 8, or 16. - /// - `fri_remainder_max_degree` is greater than 255 or is not a power of two minus 1. + /// - `fri_folding_schedule` is not 2, 4, 8, or 16 in case of a constant FRI schedule, or is not + /// a power of two in case of a dynamic FRI schedule. In the case of constant FRI schedule, the + /// max remainder degree must be less than 255 and must be one less than a power of two. #[rustfmt::skip] pub fn new( num_queries: usize, blowup_factor: usize, grinding_factor: u32, field_extension: FieldExtension, - fri_folding_factor: usize, - fri_remainder_max_degree: usize, + fri_folding_schedule: &FoldingSchedule, ) -> ProofOptions { // TODO: return errors instead of panicking assert!(num_queries > 0, "number of queries must be greater than 0"); @@ -127,26 +130,36 @@ impl ProofOptions { assert!(grinding_factor <= MAX_GRINDING_FACTOR, "grinding factor cannot be greater than {MAX_GRINDING_FACTOR}"); - assert!(fri_folding_factor.is_power_of_two(), "FRI folding factor must be a power of 2"); - assert!(fri_folding_factor >= FRI_MIN_FOLDING_FACTOR, "FRI folding factor cannot be smaller than {FRI_MIN_FOLDING_FACTOR}"); - assert!(fri_folding_factor <= FRI_MAX_FOLDING_FACTOR, "FRI folding factor cannot be greater than {FRI_MAX_FOLDING_FACTOR}"); - - assert!( - (fri_remainder_max_degree + 1).is_power_of_two(), - "FRI polynomial remainder degree must be one less than a power of two" - ); - assert!( - fri_remainder_max_degree <= FRI_MAX_REMAINDER_DEGREE, - "FRI polynomial remainder degree cannot be greater than {FRI_MAX_REMAINDER_DEGREE}" - ); + // let fri_folding_factor = fri_folding_schedule.get_factor().unwrap_or(3); + // let fri_remainder_max_degree = fri_folding_schedule.get_max_remainder_degree().unwrap_or(127); + match fri_folding_schedule { + FoldingSchedule::Constant { fri_folding_factor, fri_remainder_max_degree } => { + assert!(fri_folding_factor.is_power_of_two(), "FRI folding factor must be a power of 2"); + assert!(*fri_folding_factor as usize >= FRI_MIN_FOLDING_FACTOR, "FRI folding factor cannot be smaller than {FRI_MIN_FOLDING_FACTOR}"); + assert!(*fri_folding_factor as usize <= FRI_MAX_FOLDING_FACTOR, "FRI folding factor cannot be greater than {FRI_MAX_FOLDING_FACTOR}"); + + assert!( + (fri_remainder_max_degree + 1).is_power_of_two(), + "FRI polynomial remainder degree must be one less than a power of two" + ); + assert!( + *fri_remainder_max_degree as usize <= FRI_MAX_REMAINDER_DEGREE, + "FRI polynomial remainder degree cannot be greater than {FRI_MAX_REMAINDER_DEGREE}" + ); + + }, + FoldingSchedule::Dynamic { schedule} => { + assert!(schedule.iter().all(|factor| factor.is_power_of_two()), "FRI folding factors must be powers of 2"); + assert!(!schedule.is_empty(), "FRI folding schedule cannot be empty"); + }, + } ProofOptions { num_queries: num_queries as u8, blowup_factor: blowup_factor as u8, grinding_factor: grinding_factor as u8, field_extension, - fri_folding_factor: fri_folding_factor as u8, - fri_remainder_max_degree: fri_remainder_max_degree as u8, + fri_schedule: fri_folding_schedule.clone(), } } @@ -202,9 +215,7 @@ impl ProofOptions { /// Returns options for FRI protocol instantiated with parameters from this proof options. pub fn to_fri_options(&self) -> FriOptions { - let folding_factor = self.fri_folding_factor as usize; - let remainder_max_degree = self.fri_remainder_max_degree as usize; - FriOptions::new(self.blowup_factor(), folding_factor, remainder_max_degree) + FriOptions::new(self.blowup_factor(), self.fri_schedule.clone()) } } @@ -212,8 +223,21 @@ impl ToElements for ProofOptions { fn to_elements(&self) -> Vec { // encode field extension and FRI parameters into a single field element let mut buf = self.field_extension as u32; - buf = (buf << 8) | self.fri_folding_factor as u32; - buf = (buf << 8) | self.fri_remainder_max_degree as u32; + match &self.fri_schedule { + FoldingSchedule::Constant { + fri_folding_factor, + fri_remainder_max_degree, + } => { + buf = (buf << 8) | *fri_folding_factor as u32; + buf = (buf << 8) | *fri_remainder_max_degree as u32; + } + FoldingSchedule::Dynamic { schedule } => { + buf = (buf << 8) | schedule.len() as u32; + for factor in schedule { + buf = (buf << 8) | *factor as u32; + } + } + } vec![ E::from(buf), @@ -231,8 +255,7 @@ impl Serializable for ProofOptions { target.write_u8(self.blowup_factor); target.write_u8(self.grinding_factor); target.write(self.field_extension); - target.write_u8(self.fri_folding_factor); - target.write_u8(self.fri_remainder_max_degree); + target.write(self.fri_schedule.clone()); } } @@ -247,8 +270,7 @@ impl Deserializable for ProofOptions { source.read_u8()? as usize, source.read_u8()? as u32, FieldExtension::read_from(source)?, - source.read_u8()? as usize, - source.read_u8()? as usize, + &FoldingSchedule::read_from(source)?, )) } } @@ -299,6 +321,7 @@ impl Deserializable for FieldExtension { #[cfg(test)] mod tests { use super::{FieldExtension, ProofOptions, ToElements}; + use fri::fri_schedule::FoldingSchedule; use math::fields::f64::BaseElement; #[test] @@ -306,6 +329,10 @@ mod tests { let field_extension = FieldExtension::None; let fri_folding_factor = 8; let fri_remainder_max_degree = 127; + let fri_folding_schedule = FoldingSchedule::Constant { + fri_folding_factor, + fri_remainder_max_degree, + }; let grinding_factor = 20; let blowup_factor = 8; let num_queries = 30; @@ -318,7 +345,7 @@ mod tests { ]); let expected = vec![ BaseElement::from(ext_fri), - BaseElement::from(grinding_factor as u32), + BaseElement::from(grinding_factor), BaseElement::from(blowup_factor as u32), BaseElement::from(num_queries as u32), ]; @@ -328,8 +355,7 @@ mod tests { blowup_factor, grinding_factor, field_extension, - fri_folding_factor as usize, - fri_remainder_max_degree as usize, + &fri_folding_schedule, ); assert_eq!(expected, options.to_elements()); } diff --git a/air/src/proof/context.rs b/air/src/proof/context.rs index 573a79516..2405896eb 100644 --- a/air/src/proof/context.rs +++ b/air/src/proof/context.rs @@ -223,6 +223,7 @@ fn bytes_to_element(bytes: &[u8]) -> B { mod tests { use super::{Context, ProofOptions, ToElements, TraceInfo}; use crate::{FieldExtension, TraceLayout}; + use fri::fri_schedule::FoldingSchedule; use math::fields::f64::BaseElement; #[test] @@ -247,6 +248,9 @@ mod tests { 0, ]); + let fri_constant_schedule = + FoldingSchedule::new_constant(fri_folding_factor, fri_remainder_max_degree); + let layout_info = u32::from_le_bytes([aux_rands, aux_width, num_aux_segments, main_width]); let expected = vec![ @@ -254,7 +258,7 @@ mod tests { BaseElement::from(1_u32), // lower bits of field modulus BaseElement::from(u32::MAX), // upper bits of field modulus BaseElement::from(ext_fri), - BaseElement::from(grinding_factor as u32), + BaseElement::from(grinding_factor), BaseElement::from(blowup_factor as u32), BaseElement::from(num_queries as u32), BaseElement::from(trace_length as u32), @@ -265,8 +269,7 @@ mod tests { blowup_factor, grinding_factor, field_extension, - fri_folding_factor as usize, - fri_remainder_max_degree as usize, + &fri_constant_schedule, ); let layout = TraceLayout::new( main_width as usize, diff --git a/crypto/src/hash/griffin/griffin64_256_jive/tests.rs b/crypto/src/hash/griffin/griffin64_256_jive/tests.rs index fd621cc86..c15328da3 100644 --- a/crypto/src/hash/griffin/griffin64_256_jive/tests.rs +++ b/crypto/src/hash/griffin/griffin64_256_jive/tests.rs @@ -204,7 +204,7 @@ proptest! { for i in 0..STATE_WIDTH { v1[i] = BaseElement::new(a[i]); } - v2 = v1.clone(); + v2 = v1; apply_mds_naive(&mut v1); GriffinJive64_256::apply_linear(&mut v2); diff --git a/crypto/src/hash/rescue/rp64_256/tests.rs b/crypto/src/hash/rescue/rp64_256/tests.rs index 49bb11273..2ba0aa416 100644 --- a/crypto/src/hash/rescue/rp64_256/tests.rs +++ b/crypto/src/hash/rescue/rp64_256/tests.rs @@ -199,7 +199,7 @@ proptest! { for i in 0..STATE_WIDTH { v1[i] = BaseElement::new(a[i]); } - v2 = v1.clone(); + v2 = v1; apply_mds_naive(&mut v1); Rp64_256::apply_mds(&mut v2); diff --git a/crypto/src/hash/rescue/rp64_256_jive/tests.rs b/crypto/src/hash/rescue/rp64_256_jive/tests.rs index 8148173f6..ac36546e1 100644 --- a/crypto/src/hash/rescue/rp64_256_jive/tests.rs +++ b/crypto/src/hash/rescue/rp64_256_jive/tests.rs @@ -36,7 +36,7 @@ fn mds_inv_test() { #[test] fn test_alphas() { let e: BaseElement = rand_value(); - let e_exp = e.exp(ALPHA.into()); + let e_exp = e.exp(ALPHA); assert_eq!(e, e_exp.exp(INV_ALPHA)); } @@ -197,7 +197,7 @@ proptest! { for i in 0..STATE_WIDTH { v1[i] = BaseElement::new(a[i]); } - v2 = v1.clone(); + v2 = v1; apply_mds_naive(&mut v1); RpJive64_256::apply_mds(&mut v2); diff --git a/examples/Cargo.toml b/examples/Cargo.toml index 7d7d5e3d1..65dc71b14 100644 --- a/examples/Cargo.toml +++ b/examples/Cargo.toml @@ -27,6 +27,7 @@ std = ["hex/std", "winterfell/std", "core-utils/std", "rand-utils"] [dependencies] winterfell = { version="0.6", path = "../winterfell", default-features = false } +winter-fri = { version="0.6", path = "../fri", default-features = false } core-utils = { version = "0.6", path = "../utils/core", package = "winter-utils", default-features = false } rand-utils = { version = "0.6", path = "../utils/rand", package = "winter-rand-utils", optional = true } hex = { version = "0.4", optional = true } diff --git a/examples/benches/fibonacci.rs b/examples/benches/fibonacci.rs index 9e5c2617d..7619651bf 100644 --- a/examples/benches/fibonacci.rs +++ b/examples/benches/fibonacci.rs @@ -6,6 +6,7 @@ use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; use examples::{fibonacci, Example}; use std::time::Duration; +use winter_fri::fri_schedule::FoldingSchedule; use winterfell::{ crypto::hashers::Blake3_256, math::fields::f128::BaseElement, FieldExtension, ProofOptions, }; @@ -17,7 +18,9 @@ fn fibonacci(c: &mut Criterion) { group.sample_size(10); group.measurement_time(Duration::from_secs(20)); - let options = ProofOptions::new(32, 8, 0, FieldExtension::None, 4, 255); + let fri_constant_schedule = FoldingSchedule::new_constant(4, 255); + + let options = ProofOptions::new(32, 8, 0, FieldExtension::None, &fri_constant_schedule); for &size in SIZES.iter() { let fib = diff --git a/examples/benches/rescue.rs b/examples/benches/rescue.rs index b150f8628..81e9d3836 100644 --- a/examples/benches/rescue.rs +++ b/examples/benches/rescue.rs @@ -5,6 +5,7 @@ use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; use examples::{rescue, Example}; +use winter_fri::fri_schedule::FoldingSchedule; use std::time::Duration; use winterfell::{ @@ -18,7 +19,9 @@ fn rescue(c: &mut Criterion) { group.sample_size(10); group.measurement_time(Duration::from_secs(25)); - let options = ProofOptions::new(32, 32, 0, FieldExtension::None, 4, 255); + let fri_constant_schedule = FoldingSchedule::new_constant(4, 255); + + let options = ProofOptions::new(32, 32, 0, FieldExtension::None, &fri_constant_schedule); for &size in SIZES.iter() { let resc = rescue::RescueExample::>::new(size, options.clone()); diff --git a/examples/src/fibonacci/utils.rs b/examples/src/fibonacci/utils.rs index e2f29f7c2..b3fd9dc96 100644 --- a/examples/src/fibonacci/utils.rs +++ b/examples/src/fibonacci/utils.rs @@ -31,6 +31,7 @@ pub fn compute_mulfib_term(n: usize) -> BaseElement { #[cfg(test)] pub fn build_proof_options(use_extension_field: bool) -> winterfell::ProofOptions { + use winter_fri::fri_schedule::FoldingSchedule; use winterfell::{FieldExtension, ProofOptions}; let extension = if use_extension_field { @@ -38,5 +39,7 @@ pub fn build_proof_options(use_extension_field: bool) -> winterfell::ProofOption } else { FieldExtension::None }; - ProofOptions::new(28, 8, 0, extension, 4, 7) + + let fri_constant_schedule = FoldingSchedule::new_constant(4, 7); + ProofOptions::new(28, 8, 0, extension, &fri_constant_schedule) } diff --git a/examples/src/lib.rs b/examples/src/lib.rs index 14f27a584..6ce63b19c 100644 --- a/examples/src/lib.rs +++ b/examples/src/lib.rs @@ -4,6 +4,7 @@ // LICENSE file in the root directory of this source tree. use structopt::StructOpt; +use winter_fri::fri_schedule::FoldingSchedule; use winterfell::{ crypto::hashers::{GriffinJive64_256, Rp64_256, RpJive64_256}, math::fields::f128::BaseElement, @@ -92,14 +93,17 @@ impl ExampleOptions { val => panic!("'{val}' is not a valid hash function option"), }; + let fri_constant_schedule = FoldingSchedule::Constant { + fri_folding_factor: self.folding_factor as u8, + fri_remainder_max_degree: 31, + }; ( ProofOptions::new( num_queries, blowup_factor, self.grinding_factor, field_extension, - self.folding_factor, - 31, + &fri_constant_schedule, ), hash_fn, ) diff --git a/examples/src/merkle/tests.rs b/examples/src/merkle/tests.rs index 2f2d345af..b847b9e03 100644 --- a/examples/src/merkle/tests.rs +++ b/examples/src/merkle/tests.rs @@ -4,6 +4,7 @@ // LICENSE file in the root directory of this source tree. use super::Blake3_256; +use winter_fri::fri_schedule::FoldingSchedule; use winterfell::{FieldExtension, ProofOptions}; #[test] @@ -39,5 +40,6 @@ fn build_options(use_extension_field: bool) -> ProofOptions { } else { FieldExtension::None }; - ProofOptions::new(28, 8, 0, extension, 4, 31) + let fri_constant_schedule = FoldingSchedule::new_constant(4, 31); + ProofOptions::new(28, 8, 0, extension, &fri_constant_schedule) } diff --git a/examples/src/rescue/tests.rs b/examples/src/rescue/tests.rs index 5a79caa7c..76251a672 100644 --- a/examples/src/rescue/tests.rs +++ b/examples/src/rescue/tests.rs @@ -4,6 +4,7 @@ // LICENSE file in the root directory of this source tree. use super::Blake3_256; +use winter_fri::fri_schedule::FoldingSchedule; use winterfell::{FieldExtension, ProofOptions}; #[test] @@ -39,5 +40,7 @@ fn build_options(use_extension_field: bool) -> ProofOptions { } else { FieldExtension::None }; - ProofOptions::new(28, 8, 0, extension, 4, 31) + + let fri_constant_schedule = FoldingSchedule::new_constant(4, 31); + ProofOptions::new(28, 8, 0, extension, &fri_constant_schedule) } diff --git a/examples/src/rescue_raps/tests.rs b/examples/src/rescue_raps/tests.rs index a1e9b1c0c..c406599df 100644 --- a/examples/src/rescue_raps/tests.rs +++ b/examples/src/rescue_raps/tests.rs @@ -4,6 +4,7 @@ // LICENSE file in the root directory of this source tree. use super::Blake3_256; +use winter_fri::fri_schedule::FoldingSchedule; use winterfell::{FieldExtension, ProofOptions}; #[test] @@ -39,5 +40,7 @@ fn build_options(use_extension_field: bool) -> ProofOptions { } else { FieldExtension::None }; - ProofOptions::new(28, 8, 0, extension, 4, 31) + + let fri_constant_schedule = FoldingSchedule::new_constant(4, 31); + ProofOptions::new(28, 8, 0, extension, &fri_constant_schedule) } diff --git a/examples/src/vdf/exempt/tests.rs b/examples/src/vdf/exempt/tests.rs index d39f1906a..7ae7397db 100644 --- a/examples/src/vdf/exempt/tests.rs +++ b/examples/src/vdf/exempt/tests.rs @@ -4,6 +4,7 @@ // LICENSE file in the root directory of this source tree. use super::Blake3_256; +use winter_fri::fri_schedule::FoldingSchedule; use winterfell::{FieldExtension, ProofOptions}; #[test] @@ -39,5 +40,7 @@ fn build_options(use_extension_field: bool) -> ProofOptions { } else { FieldExtension::None }; - ProofOptions::new(85, 2, 0, extension, 4, 31) + + let fri_constant_schedule = FoldingSchedule::new_constant(4, 31); + ProofOptions::new(85, 2, 0, extension, &fri_constant_schedule) } diff --git a/examples/src/vdf/regular/tests.rs b/examples/src/vdf/regular/tests.rs index 7caf91475..bfa8bfce8 100644 --- a/examples/src/vdf/regular/tests.rs +++ b/examples/src/vdf/regular/tests.rs @@ -4,6 +4,7 @@ // LICENSE file in the root directory of this source tree. use super::Blake3_256; +use winter_fri::fri_schedule::FoldingSchedule; use winterfell::{FieldExtension, ProofOptions}; #[test] @@ -39,5 +40,7 @@ fn build_options(use_extension_field: bool) -> ProofOptions { } else { FieldExtension::None }; - ProofOptions::new(85, 2, 0, extension, 4, 31) + + let fri_constant_schedule = FoldingSchedule::new_constant(4, 31); + ProofOptions::new(85, 2, 0, extension, &fri_constant_schedule) } diff --git a/fri/benches/prover.rs b/fri/benches/prover.rs index bab72db1b..894ea2a42 100644 --- a/fri/benches/prover.rs +++ b/fri/benches/prover.rs @@ -8,7 +8,7 @@ use crypto::{hashers::Blake3_256, DefaultRandomCoin}; use math::{fft, fields::f128::BaseElement, FieldElement}; use rand_utils::rand_vector; use std::time::Duration; -use winter_fri::{DefaultProverChannel, FriOptions, FriProver}; +use winter_fri::{fri_schedule::FoldingSchedule, DefaultProverChannel, FriOptions, FriProver}; static BATCH_SIZES: [usize; 3] = [65536, 131072, 262144]; static BLOWUP_FACTOR: usize = 8; @@ -18,7 +18,8 @@ pub fn build_layers(c: &mut Criterion) { fri_group.sample_size(10); fri_group.measurement_time(Duration::from_secs(10)); - let options = FriOptions::new(BLOWUP_FACTOR, 4, 255); + let fri_constant_schedule = FoldingSchedule::new_constant(4, 255); + let options = FriOptions::new(BLOWUP_FACTOR, fri_constant_schedule); for &domain_size in &BATCH_SIZES { let evaluations = build_evaluations(domain_size); diff --git a/fri/src/fri_schedule.rs b/fri/src/fri_schedule.rs new file mode 100644 index 000000000..3930cd93a --- /dev/null +++ b/fri/src/fri_schedule.rs @@ -0,0 +1,149 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +use utils::{ + collections::Vec, ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable, +}; + +/// Enumerates the possible schedules for the FRI folding process. +/// +/// The FRI folding process can operate under a constant factor or +/// can follow a dynamic sequence of factors. This enum provides a +/// way to specify which approach to use. +/// +/// # Variants +/// +/// - `Constant`: Represents a constant folding factor. This means that +/// the prover will use the same folding factor iteratively throughout +/// the FRI folding process. The prover will also specify the maximum +/// degree of the remainder polynomial at the last FRI layer. +/// +/// - `Dynamic`: Represents a dynamic schedule of folding factors. This means +/// that the prover can use different folding factors across different rounds. +#[derive(Clone, Debug, Eq, PartialEq)] +pub enum FoldingSchedule { + Constant { + fri_folding_factor: u8, + fri_remainder_max_degree: u8, + }, + Dynamic { + schedule: Vec, + }, +} + +impl FoldingSchedule { + // Constructors + // ------------------------------------------------------------------------------------------- + + pub fn new_constant(fri_folding_factor: u8, fri_remainder_max_degree: u8) -> Self { + FoldingSchedule::Constant { + fri_folding_factor, + fri_remainder_max_degree, + } + } + + pub fn new_dynamic(schedule: Vec) -> Self { + FoldingSchedule::Dynamic { schedule } + } + + // Accessors + // ------------------------------------------------------------------------------------------- + + pub fn get_factor(&self) -> Option { + match self { + FoldingSchedule::Constant { + fri_folding_factor, .. + } => Some(*fri_folding_factor), + FoldingSchedule::Dynamic { schedule: _ } => None, + } + } + + pub fn get_schedule(&self) -> Option<&[u8]> { + match self { + FoldingSchedule::Constant { .. } => None, + FoldingSchedule::Dynamic { schedule } => Some(schedule), + } + } + + pub fn get_max_remainder_degree(&self) -> Option { + match self { + FoldingSchedule::Constant { + fri_remainder_max_degree, + .. + } => Some(*fri_remainder_max_degree), + FoldingSchedule::Dynamic { schedule: _ } => None, + } + } + + // Utility methods + // ------------------------------------------------------------------------------------------- + + /// Returns true if the schedule is constant, false otherwise. + pub fn is_constant(&self) -> bool { + matches!(self, FoldingSchedule::Constant { .. }) + } + + /// Returns true if the schedule is dynamic, false otherwise. + pub fn is_dynamic(&self) -> bool { + matches!(self, FoldingSchedule::Dynamic { .. }) + } + + /// Returns the number of layers in the schedule if the schedule is dynamic, None otherwise. + pub fn len_schedule(&self) -> Option { + match self { + FoldingSchedule::Dynamic { schedule, .. } => Some(schedule.len()), + _ => None, + } + } +} + +// FRI SCHEDULE IMPLEMENTATION +// ================================================================================================ + +impl Serializable for FoldingSchedule { + // Serializes `FoldingSchedule` and writes the resulting bytes into the `target`. + fn write_into(&self, target: &mut W) { + match self { + FoldingSchedule::Constant { + fri_folding_factor, + fri_remainder_max_degree, + } => { + target.write_u8(1); + target.write_u8(*fri_folding_factor); + target.write_u8(*fri_remainder_max_degree); + } + FoldingSchedule::Dynamic { schedule } => { + target.write_u8(2); + target.write_u8(schedule.len() as u8); + for factor in schedule { + target.write_u8(*factor); + } + } + } + } +} + +impl Deserializable for FoldingSchedule { + // Reads a `FoldingSchedule` from the specified `source`. + fn read_from(source: &mut W) -> Result { + match source.read_u8()? { + 1 => Ok(FoldingSchedule::Constant { + fri_folding_factor: source.read_u8()?, + fri_remainder_max_degree: source.read_u8()?, + }), + 2 => { + let len = source.read_u8()?; + let mut schedule = Vec::with_capacity(len as usize); + for _ in 0..len { + schedule.push(source.read_u8()?); + } + Ok(FoldingSchedule::Dynamic { schedule }) + } + value => Err(DeserializationError::InvalidValue(format!( + "value {value} cannot be deserialized as FoldingSchedule enum" + ))), + } + } +} diff --git a/fri/src/lib.rs b/fri/src/lib.rs index 45e027c5f..22c0f66a4 100644 --- a/fri/src/lib.rs +++ b/fri/src/lib.rs @@ -69,6 +69,7 @@ extern crate alloc; pub mod folding; +pub mod fri_schedule; mod prover; pub use prover::{DefaultProverChannel, FriProver, ProverChannel}; diff --git a/fri/src/options.rs b/fri/src/options.rs index b762774f1..47f71b880 100644 --- a/fri/src/options.rs +++ b/fri/src/options.rs @@ -3,6 +3,7 @@ // This source code is licensed under the MIT license found in the // LICENSE file in the root directory of this source tree. +use crate::fri_schedule::FoldingSchedule; use math::StarkField; // FRI OPTIONS @@ -11,8 +12,7 @@ use math::StarkField; /// FRI protocol config options for proof generation and verification. #[derive(Clone, PartialEq, Eq)] pub struct FriOptions { - folding_factor: usize, - remainder_max_degree: usize, + folding_schedule: FoldingSchedule, blowup_factor: usize, } @@ -23,22 +23,37 @@ impl FriOptions { /// Panics if: /// - `blowup_factor` is not a power of two. /// - `folding_factor` is not 2, 4, 8, or 16. - pub fn new(blowup_factor: usize, folding_factor: usize, remainder_max_degree: usize) -> Self { + pub fn new(blowup_factor: usize, folding_schedule: FoldingSchedule) -> Self { // TODO: change panics to errors assert!( blowup_factor.is_power_of_two(), "blowup factor must be a power of two, but was {blowup_factor}" ); - assert!( - folding_factor == 2 - || folding_factor == 4 - || folding_factor == 8 - || folding_factor == 16, - "folding factor {folding_factor} is not supported" - ); + + match &folding_schedule { + FoldingSchedule::Constant { + fri_folding_factor, + fri_remainder_max_degree: _, + } => { + assert!( + *fri_folding_factor == 2 + || *fri_folding_factor == 4 + || *fri_folding_factor == 8 + || *fri_folding_factor == 16, + "folding factor {fri_folding_factor} is not supported" + ); + } + FoldingSchedule::Dynamic { schedule } => { + assert!( + schedule.iter().all(|factor| factor.is_power_of_two()), + "FRI folding factors must be powers of 2" + ); + assert!(!schedule.is_empty(), "FRI folding schedule cannot be empty"); + } + } + FriOptions { - folding_factor, - remainder_max_degree, + folding_schedule, blowup_factor, } } @@ -57,16 +72,20 @@ impl FriOptions { /// /// In combination with `remainder_max_degree_plus_1` this property defines how many FRI layers are /// needed for an evaluation domain of a given size. - pub fn folding_factor(&self) -> usize { - self.folding_factor + pub fn folding_factor(&self) -> Option { + self.folding_schedule + .get_factor() + .map(|factor| factor as usize) } /// Returns maximum allowed remainder polynomial degree. /// /// In combination with `folding_factor` this property defines how many FRI layers are needed /// for an evaluation domain of a given size. - pub fn remainder_max_degree(&self) -> usize { - self.remainder_max_degree + pub fn remainder_max_degree(&self) -> Option { + self.folding_schedule + .get_max_remainder_degree() + .map(|degree| degree as usize) } /// Returns a blowup factor of the evaluation domain. @@ -78,17 +97,36 @@ impl FriOptions { self.blowup_factor } - /// Computes and return the number of FRI layers required for a domain of the specified size. + pub fn get_schedule(&self) -> &FoldingSchedule { + &self.folding_schedule + } + + /// Computes and returns the number of FRI layers required for a domain of the specified size. + /// + /// The number of layers for a given domain size is determined based on the folding schedule: + /// - For a `Constant` schedule, the number of layers is defined by the `fri_folding_factor`, + /// `fri_remainder_max_degree`, and `blowup_factor` settings. + /// - For a `Dynamic` schedule, it's simply the length of the custom folding schedule. /// - /// The number of layers for a given domain size is defined by the `folding_factor` and - /// `remainder_max_degree` and `blowup_factor` settings. + /// Note that for a `Constant` schedule, the domain size is progressively reduced by the folding + /// factor until it is less than or equal to the threshold defined by + /// `(fri_remainder_max_degree + 1) * blowup_factor`. pub fn num_fri_layers(&self, mut domain_size: usize) -> usize { - let mut result = 0; - let max_remainder_size = (self.remainder_max_degree + 1) * self.blowup_factor; - while domain_size > max_remainder_size { - domain_size /= self.folding_factor; - result += 1; + match self.get_schedule() { + FoldingSchedule::Constant { + fri_folding_factor, + fri_remainder_max_degree, + } => { + let mut result = 0; + let max_remainder_size = + (*fri_remainder_max_degree as usize + 1) * self.blowup_factor(); + while domain_size > max_remainder_size { + domain_size /= *fri_folding_factor as usize; + result += 1; + } + result + } + FoldingSchedule::Dynamic { schedule } => schedule.len(), } - result } } diff --git a/fri/src/proof.rs b/fri/src/proof.rs index 4a32282f9..d99760008 100644 --- a/fri/src/proof.rs +++ b/fri/src/proof.rs @@ -10,6 +10,8 @@ use utils::{ DeserializationError, Serializable, SliceReader, }; +use crate::fri_schedule::FoldingSchedule; + // FRI PROOF // ================================================================================================ @@ -121,7 +123,7 @@ impl FriProof { pub fn parse_layers( self, mut domain_size: usize, - folding_factor: usize, + folding_schedule: &FoldingSchedule, ) -> Result<(Vec>, Vec>), DeserializationError> where E: FieldElement, @@ -131,23 +133,61 @@ impl FriProof { domain_size.is_power_of_two(), "domain size must be a power of two" ); - assert!( - folding_factor.is_power_of_two(), - "folding factor must be a power of two" - ); - assert!(folding_factor > 1, "folding factor must be greater than 1"); let mut layer_proofs = Vec::new(); let mut layer_queries = Vec::new(); - // parse all layers - for (i, layer) in self.layers.into_iter().enumerate() { - domain_size /= folding_factor; - let (qv, mp) = layer.parse(domain_size, folding_factor).map_err(|err| { - DeserializationError::InvalidValue(format!("failed to parse FRI layer {i}: {err}")) - })?; - layer_proofs.push(mp); - layer_queries.push(qv); + match folding_schedule { + FoldingSchedule::Constant { + fri_folding_factor, + fri_remainder_max_degree: _, + } => { + let folding_factor = *fri_folding_factor as usize; + assert!( + folding_factor.is_power_of_two(), + "folding factor must be a power of two" + ); + assert!(folding_factor > 1, "folding factor must be greater than 1"); + + // parse all layers. + for (i, layer) in self.layers.into_iter().enumerate() { + domain_size /= folding_factor; + let (qv, mp) = layer.parse(domain_size, folding_factor).map_err(|err| { + DeserializationError::InvalidValue(format!( + "failed to parse FRI layer {}: {}", + i, err + )) + })?; + layer_proofs.push(mp); + layer_queries.push(qv); + } + } + FoldingSchedule::Dynamic { schedule } => { + for (i, layer) in self.layers.into_iter().enumerate() { + let fri_folding_factor = schedule[i]; + + assert!( + fri_folding_factor.is_power_of_two(), + "folding factor must be a power of two" + ); + assert!( + fri_folding_factor > 1, + "folding factor must be greater than 1" + ); + + domain_size /= fri_folding_factor as usize; + let (qv, mp) = layer + .parse(domain_size, fri_folding_factor.into()) + .map_err(|err| { + DeserializationError::InvalidValue(format!( + "failed to parse FRI layer {}: {}", + i, err + )) + })?; + layer_proofs.push(mp); + layer_queries.push(qv); + } + } } Ok((layer_queries, layer_proofs)) diff --git a/fri/src/prover/mod.rs b/fri/src/prover/mod.rs index 0c3b3c6e7..688d6856e 100644 --- a/fri/src/prover/mod.rs +++ b/fri/src/prover/mod.rs @@ -5,6 +5,7 @@ use crate::{ folding::{apply_drp, fold_positions}, + fri_schedule::FoldingSchedule, proof::{FriProof, FriProofLayer}, utils::hash_values, FriOptions, @@ -132,11 +133,6 @@ where // ACCESSORS // -------------------------------------------------------------------------------------------- - /// Returns folding factor for this prover. - pub fn folding_factor(&self) -> usize { - self.options.folding_factor() - } - /// Returns offset of the domain over which FRI protocol is executed by this prover. pub fn domain_offset(&self) -> B { self.options.domain_offset() @@ -175,15 +171,41 @@ where "a prior proof generation request has not been completed yet" ); - // reduce the degree by folding_factor at each iteration until the remaining polynomial - // has small enough degree - for _ in 0..self.options.num_fri_layers(evaluations.len()) { - match self.folding_factor() { - 2 => self.build_layer::<2>(channel, &mut evaluations), - 4 => self.build_layer::<4>(channel, &mut evaluations), - 8 => self.build_layer::<8>(channel, &mut evaluations), - 16 => self.build_layer::<16>(channel, &mut evaluations), - _ => unimplemented!("folding factor {} is not supported", self.folding_factor()), + let schedule = self.options.get_schedule().clone(); + + match schedule { + FoldingSchedule::Constant { + fri_folding_factor, + fri_remainder_max_degree: _, + } => { + // reduce the degree by folding_factor at each iteration until the remaining polynomial + // has small enough degree + for _ in 0..self.options.num_fri_layers(evaluations.len()) { + match fri_folding_factor { + 2 => self.build_layer::<2>(channel, &mut evaluations), + 4 => self.build_layer::<4>(channel, &mut evaluations), + 8 => self.build_layer::<8>(channel, &mut evaluations), + 16 => self.build_layer::<16>(channel, &mut evaluations), + _ => { + unimplemented!("folding factor {} is not supported", fri_folding_factor) + } + } + } + } + + FoldingSchedule::Dynamic { schedule } => { + for &fri_folding_factor in schedule.iter() { + match fri_folding_factor { + 2 => self.build_layer::<2>(channel, &mut evaluations), + 4 => self.build_layer::<4>(channel, &mut evaluations), + 8 => self.build_layer::<8>(channel, &mut evaluations), + 16 => self.build_layer::<16>(channel, &mut evaluations), + 32 => self.build_layer::<32>(channel, &mut evaluations), + _ => { + unimplemented!("folding factor {} is not supported", fri_folding_factor) + } + } + } } } @@ -247,24 +269,66 @@ where if !self.layers.is_empty() { let mut positions = positions.to_vec(); let mut domain_size = self.layers[0].evaluations.len(); - let folding_factor = self.options.folding_factor(); - - // for all FRI layers, except the last one, record tree root, determine a set of query - // positions, and query the layer at these positions. - for i in 0..self.layers.len() { - positions = fold_positions(&positions, domain_size, folding_factor); - - // sort of a static dispatch for folding_factor parameter - let proof_layer = match folding_factor { - 2 => query_layer::(&self.layers[i], &positions), - 4 => query_layer::(&self.layers[i], &positions), - 8 => query_layer::(&self.layers[i], &positions), - 16 => query_layer::(&self.layers[i], &positions), - _ => unimplemented!("folding factor {} is not supported", folding_factor), - }; - - layers.push(proof_layer); - domain_size /= folding_factor; + + match self.options.get_schedule() { + FoldingSchedule::Constant { + fri_folding_factor, + fri_remainder_max_degree: _, + } => { + // for all FRI layers, except the last one, record tree root, determine a set of query + // positions, and query the layer at these positions. + for i in 0..self.layers.len() { + positions = + fold_positions(&positions, domain_size, *fri_folding_factor as usize); + + // sort of a static dispatch for folding_factor parameter + let proof_layer = match fri_folding_factor { + 2 => query_layer::(&self.layers[i], &positions), + 4 => query_layer::(&self.layers[i], &positions), + 8 => query_layer::(&self.layers[i], &positions), + 16 => query_layer::(&self.layers[i], &positions), + _ => { + unimplemented!( + "folding factor {} is not supported", + fri_folding_factor + ) + } + }; + + layers.push(proof_layer); + domain_size /= *fri_folding_factor as usize; + } + } + + FoldingSchedule::Dynamic { schedule } => { + // for all FRI layers, except the last one, record tree root, determine a set of query + // positions, and query the layer at these positions. + #[allow(clippy::needless_range_loop)] + for i in 0..self.layers.len() { + let fri_folding_factor = schedule[i]; + + positions = + fold_positions(&positions, domain_size, fri_folding_factor as usize); + + // sort of a static dispatch for folding_factor parameter + let proof_layer = match fri_folding_factor { + 2 => query_layer::(&self.layers[i], &positions), + 4 => query_layer::(&self.layers[i], &positions), + 8 => query_layer::(&self.layers[i], &positions), + 16 => query_layer::(&self.layers[i], &positions), + 32 => query_layer::(&self.layers[i], &positions), + _ => { + unimplemented!( + "folding factor {} is not supported", + fri_folding_factor + ) + } + }; + + layers.push(proof_layer); + domain_size /= fri_folding_factor as usize; + } + } } } diff --git a/fri/src/prover/tests.rs b/fri/src/prover/tests.rs index 2e2c3b01f..a2c95fbb9 100644 --- a/fri/src/prover/tests.rs +++ b/fri/src/prover/tests.rs @@ -5,6 +5,7 @@ use super::{DefaultProverChannel, FriProver}; use crate::{ + fri_schedule::FoldingSchedule, verifier::{DefaultVerifierChannel, FriVerifier}, FriOptions, FriProof, VerifierError, }; @@ -21,28 +22,20 @@ type Blake3 = Blake3_256; fn fri_folding_2() { let trace_length_e = 12; let lde_blowup_e = 3; - let folding_factor_e = 1; + let folding_factor_e = 2; let max_remainder_degree = 7; - fri_prove_verify( - trace_length_e, - lde_blowup_e, - folding_factor_e, - max_remainder_degree, - ) + let folding_schedule = FoldingSchedule::new_constant(folding_factor_e, max_remainder_degree); + fri_prove_verify(trace_length_e, lde_blowup_e, folding_schedule) } #[test] fn fri_folding_4() { let trace_length_e = 12; let lde_blowup_e = 3; - let folding_factor_e = 2; + let folding_factor_e = 4; let max_remainder_degree = 255; - fri_prove_verify( - trace_length_e, - lde_blowup_e, - folding_factor_e, - max_remainder_degree, - ) + let folding_schedule = FoldingSchedule::new_constant(folding_factor_e, max_remainder_degree); + fri_prove_verify(trace_length_e, lde_blowup_e, folding_schedule) } // TEST UTILS @@ -89,7 +82,7 @@ pub fn verify_proof( proof, commitments, domain_size, - options.folding_factor(), + options.get_schedule(), ) .unwrap(); let mut coin = DefaultRandomCoin::::new(&[]); @@ -101,17 +94,12 @@ pub fn verify_proof( verifier.verify(&mut channel, &queried_evaluations, positions) } -fn fri_prove_verify( - trace_length_e: usize, - lde_blowup_e: usize, - folding_factor_e: usize, - max_remainder_degree: usize, -) { +fn fri_prove_verify(trace_length_e: usize, lde_blowup_e: usize, folding_schedule: FoldingSchedule) { let trace_length = 1 << trace_length_e; let lde_blowup = 1 << lde_blowup_e; - let folding_factor = 1 << folding_factor_e; + // let folding_factor = 1 << folding_factor_e; - let options = FriOptions::new(lde_blowup, folding_factor, max_remainder_degree); + let options = FriOptions::new(lde_blowup, folding_schedule); let mut channel = build_prover_channel(trace_length, &options); let evaluations = build_evaluations(trace_length, lde_blowup); diff --git a/fri/src/verifier/channel.rs b/fri/src/verifier/channel.rs index ebcf0acdd..92fbad46d 100644 --- a/fri/src/verifier/channel.rs +++ b/fri/src/verifier/channel.rs @@ -3,7 +3,7 @@ // This source code is licensed under the MIT license found in the // LICENSE file in the root directory of this source tree. -use crate::{FriProof, VerifierError}; +use crate::{fri_schedule::FoldingSchedule, FriProof, VerifierError}; use crypto::{BatchMerkleProof, ElementHasher, Hasher, MerkleTree}; use math::FieldElement; use utils::{collections::Vec, group_vector_elements, DeserializationError}; @@ -128,13 +128,13 @@ where proof: FriProof, layer_commitments: Vec, domain_size: usize, - folding_factor: usize, + folding_schedule: &FoldingSchedule, ) -> Result { let num_partitions = proof.num_partitions(); let remainder = proof.parse_remainder()?; let (layer_queries, layer_proofs) = - proof.parse_layers::(domain_size, folding_factor)?; + proof.parse_layers::(domain_size, folding_schedule)?; Ok(DefaultVerifierChannel { layer_commitments, diff --git a/fri/src/verifier/mod.rs b/fri/src/verifier/mod.rs index 7f3544ef8..a62f65eae 100644 --- a/fri/src/verifier/mod.rs +++ b/fri/src/verifier/mod.rs @@ -5,8 +5,11 @@ //! Contains an implementation of FRI verifier and associated components. -use crate::{folding::fold_positions, utils::map_positions_to_indexes, FriOptions, VerifierError}; -use core::{convert::TryInto, marker::PhantomData, mem}; +use crate::{ + folding::fold_positions, fri_schedule::FoldingSchedule, utils::map_positions_to_indexes, + FriOptions, VerifierError, +}; +use core::{convert::TryInto, marker::PhantomData}; use crypto::{ElementHasher, RandomCoin}; use math::{polynom, FieldElement, StarkField}; use utils::collections::Vec; @@ -115,23 +118,45 @@ where let layer_commitments = channel.read_fri_layer_commitments(); let mut layer_alphas = Vec::with_capacity(layer_commitments.len()); let mut max_degree_plus_1 = max_poly_degree + 1; + let num_layers = layer_commitments.len(); + for (depth, commitment) in layer_commitments.iter().enumerate() { public_coin.reseed(*commitment); let alpha = public_coin.draw().map_err(VerifierError::RandomCoinError)?; layer_alphas.push(alpha); - // make sure the degree can be reduced by the folding factor at all layers - // but the remainder layer - if depth != layer_commitments.len() - 1 - && max_degree_plus_1 % options.folding_factor() != 0 - { - return Err(VerifierError::DegreeTruncation( - max_degree_plus_1 - 1, - options.folding_factor(), - depth, - )); + match options.get_schedule() { + FoldingSchedule::Constant { + fri_folding_factor, + fri_remainder_max_degree: _, + } => { + // make sure the degree can be reduced by the folding factor at all layers + // but the remainder layer + if depth != num_layers - 1 + && max_degree_plus_1 % *fri_folding_factor as usize != 0 + { + return Err(VerifierError::DegreeTruncation( + max_degree_plus_1 - 1, + *fri_folding_factor as usize, + depth, + )); + } + max_degree_plus_1 /= *fri_folding_factor as usize; + } + FoldingSchedule::Dynamic { schedule } => { + // make sure the degree can be reduced by the folding factor at all layers + // but the remainder layer + if depth != num_layers - 1 && max_degree_plus_1 % schedule[depth] as usize != 0 + { + return Err(VerifierError::DegreeTruncation( + max_degree_plus_1 - 1, + schedule[depth] as usize, + depth, + )); + } + max_degree_plus_1 /= schedule[depth] as usize; + } } - max_degree_plus_1 /= options.folding_factor(); } Ok(FriVerifier { @@ -214,26 +239,166 @@ where )); } - // static dispatch for folding factor parameter - let folding_factor = self.options.folding_factor(); - match folding_factor { - 2 => self.verify_generic::<2>(channel, evaluations, positions), - 4 => self.verify_generic::<4>(channel, evaluations, positions), - 8 => self.verify_generic::<8>(channel, evaluations, positions), - 16 => self.verify_generic::<16>(channel, evaluations, positions), - _ => Err(VerifierError::UnsupportedFoldingFactor(folding_factor)), + let mut domain_generator = self.domain_generator; + let mut domain_size = self.domain_size; + let mut max_degree_plus_1 = self.max_poly_degree + 1; + let mut positions = positions.to_vec(); + let mut evaluations = evaluations.to_vec(); + + match self.options.get_schedule() { + FoldingSchedule::Constant { + fri_folding_factor, + fri_remainder_max_degree: _, + } => { + for depth in 0..self.options.num_fri_layers(self.domain_size) { + let (next_evaluations, next_positions) = match fri_folding_factor { + 2 => self.verify_layer::<2>( + channel, + &evaluations, + &positions, + depth, + domain_generator, + domain_size, + max_degree_plus_1, + )?, + 4 => self.verify_layer::<4>( + channel, + &evaluations, + &positions, + depth, + domain_generator, + domain_size, + max_degree_plus_1, + )?, + 8 => self.verify_layer::<8>( + channel, + &evaluations, + &positions, + depth, + domain_generator, + domain_size, + max_degree_plus_1, + )?, + 16 => self.verify_layer::<16>( + channel, + &evaluations, + &positions, + depth, + domain_generator, + domain_size, + max_degree_plus_1, + )?, + _ => { + return Err(VerifierError::UnsupportedFoldingFactor( + (*fri_folding_factor).into(), + )) + } + }; + + evaluations = next_evaluations; + positions = next_positions; + + // Update the variables + domain_generator = + domain_generator.exp_vartime((*fri_folding_factor as u32).into()); + max_degree_plus_1 /= *fri_folding_factor as usize; + domain_size /= *fri_folding_factor as usize; + } + } + FoldingSchedule::Dynamic { schedule } => { + for (depth, &factor) in schedule.iter().enumerate() { + let (next_evaluations, next_positions) = match factor { + 2 => self.verify_layer::<2>( + channel, + &evaluations, + &positions, + depth, + domain_generator, + domain_size, + max_degree_plus_1, + )?, + 4 => self.verify_layer::<4>( + channel, + &evaluations, + &positions, + depth, + domain_generator, + domain_size, + max_degree_plus_1, + )?, + 8 => self.verify_layer::<8>( + channel, + &evaluations, + &positions, + depth, + domain_generator, + domain_size, + max_degree_plus_1, + )?, + 16 => self.verify_layer::<16>( + channel, + &evaluations, + &positions, + depth, + domain_generator, + domain_size, + max_degree_plus_1, + )?, + _ => return Err(VerifierError::UnsupportedFoldingFactor(factor.into())), + }; + + evaluations = next_evaluations; + positions = next_positions; + + // Update the variables + domain_generator = domain_generator.exp_vartime((factor as u32).into()); + max_degree_plus_1 /= factor as usize; + domain_size /= factor as usize; + } + } } + + // verify remainder + self.verify_remainder( + channel, + &evaluations, + &positions, + max_degree_plus_1, + domain_generator, + ) } - /// This is the actual implementation of the verification procedure described above, but it - /// also takes folding factor as a generic parameter N. - fn verify_generic( + /// Executes the query phase of the FRI protocol. + #[allow(clippy::too_many_arguments)] + fn verify_layer( &self, channel: &mut C, evaluations: &[E], positions: &[usize], - ) -> Result<(), VerifierError> { - // pre-compute roots of unity used in computing x coordinates in the folded domain + depth: usize, + domain_generator: E::BaseField, + domain_size: usize, + max_degree_plus_1: usize, + ) -> Result<(Vec, Vec), VerifierError> { + // 1. Determining which evaluations were queried in the folded layer. + let folded_positions = fold_positions(positions, domain_size, N); + + // 2. Finding these evaluations in the commitment Merkle tree. + let position_indexes = + map_positions_to_indexes(&folded_positions, domain_size, N, self.num_partitions); + + // 3. Reading the query values from the specified indexes in the Merkle tree. + let layer_commitment = self.layer_commitments[depth]; + let layer_values: Vec<[E; N]> = + channel.read_layer_queries(&position_indexes, &layer_commitment)?; + let query_values = + get_query_values(&layer_values, positions, &folded_positions, domain_size); + + if evaluations != query_values { + return Err(VerifierError::InvalidLayerFolding(depth)); + } + + // 4. Building x coordinates for each row polynomial. let folding_roots = (0..N) .map(|i| { self.domain_generator @@ -241,83 +406,57 @@ where }) .collect::>(); - // 1 ----- verify the recursive components of the FRI proof ----------------------------------- - let mut domain_generator = self.domain_generator; - let mut domain_size = self.domain_size; - let mut max_degree_plus_1 = self.max_poly_degree + 1; - let mut positions = positions.to_vec(); - let mut evaluations = evaluations.to_vec(); - - for depth in 0..self.options.num_fri_layers(self.domain_size) { - // determine which evaluations were queried in the folded layer - let mut folded_positions = - fold_positions(&positions, domain_size, self.options.folding_factor()); - // determine where these evaluations are in the commitment Merkle tree - let position_indexes = map_positions_to_indexes( - &folded_positions, - domain_size, - self.options.folding_factor(), - self.num_partitions, - ); - // read query values from the specified indexes in the Merkle tree - let layer_commitment = self.layer_commitments[depth]; - // TODO: add layer depth to the potential error message - let layer_values = channel.read_layer_queries(&position_indexes, &layer_commitment)?; - let query_values = - get_query_values::(&layer_values, &positions, &folded_positions, domain_size); - if evaluations != query_values { - return Err(VerifierError::InvalidLayerFolding(depth)); - } - - // build a set of x coordinates for each row polynomial - #[rustfmt::skip] - let xs = folded_positions.iter().map(|&i| { - let xe = domain_generator.exp_vartime((i as u64).into()) * self.options.domain_offset(); - folding_roots.iter() + let xs = folded_positions + .iter() + .map(|&i| { + let xe = + domain_generator.exp_vartime((i as u64).into()) * self.options.domain_offset(); + folding_roots + .iter() .map(|&r| E::from(xe * r)) - .collect::>().try_into().unwrap() + .collect::>() + .try_into() + .unwrap() }) .collect::>(); - // interpolate x and y values into row polynomials - let row_polys = polynom::interpolate_batch(&xs, &layer_values); + // 5. Interpolating x and y values into row polynomials. + let row_polys = polynom::interpolate_batch(&xs, &layer_values); + let alpha = self.layer_alphas[depth]; + let next_evaluations = row_polys.iter().map(|p| polynom::eval(p, alpha)).collect(); - // calculate the pseudo-random value used for linear combination in layer folding - let alpha = self.layer_alphas[depth]; - - // check that when the polynomials are evaluated at alpha, the result is equal to - // the corresponding column value - evaluations = row_polys.iter().map(|p| polynom::eval(p, alpha)).collect(); - - // make sure next degree reduction does not result in degree truncation - if max_degree_plus_1 % N != 0 { - return Err(VerifierError::DegreeTruncation( - max_degree_plus_1 - 1, - N, - depth, - )); - } - - // update variables for the next iteration of the loop - domain_generator = domain_generator.exp_vartime((N as u32).into()); - max_degree_plus_1 /= N; - domain_size /= N; - mem::swap(&mut positions, &mut folded_positions); + if max_degree_plus_1 % N != 0 { + return Err(VerifierError::DegreeTruncation( + max_degree_plus_1 - 1, + N, + depth, + )); } - // 2 ----- verify the remainder polynomial of the FRI proof ------------------------------- + Ok((next_evaluations, folded_positions)) + } + fn verify_remainder( + &self, + channel: &mut C, + evaluations: &[E], + positions: &[usize], + max_degree_plus_1: usize, + domain_generator: E::BaseField, + ) -> Result<(), VerifierError> { // read the remainder polynomial from the channel and make sure it agrees with the evaluations // from the previous layer. let remainder_poly = channel.read_remainder()?; + if remainder_poly.len() > max_degree_plus_1 { return Err(VerifierError::RemainderDegreeMismatch( max_degree_plus_1 - 1, )); } - let offset: E::BaseField = self.options().domain_offset(); - for (&position, evaluation) in positions.iter().zip(evaluations) { + let offset: E::BaseField = self.options.domain_offset(); + + for (&position, &evaluation) in positions.iter().zip(evaluations) { let comp_eval = eval_horner::( &remainder_poly, offset * domain_generator.exp_vartime((position as u64).into()), diff --git a/prover/src/tests/mod.rs b/prover/src/tests/mod.rs index 8e3ae886e..a6becd88b 100644 --- a/prover/src/tests/mod.rs +++ b/prover/src/tests/mod.rs @@ -8,6 +8,7 @@ use air::{ Air, AirContext, Assertion, EvaluationFrame, FieldExtension, ProofOptions, TraceInfo, TransitionConstraintDegree, }; +use fri::fri_schedule::FoldingSchedule; use math::{fields::f128::BaseElement, FieldElement, StarkField}; use utils::collections::Vec; @@ -39,10 +40,11 @@ pub struct MockAir { impl MockAir { pub fn with_trace_length(trace_length: usize) -> Self { + let fri_constant_schedule = FoldingSchedule::new_constant(4, 31); Self::new( TraceInfo::new(4, trace_length), (), - ProofOptions::new(32, 8, 0, FieldExtension::None, 4, 31), + ProofOptions::new(32, 8, 0, FieldExtension::None, &fri_constant_schedule), ) } @@ -50,20 +52,22 @@ impl MockAir { column_values: Vec>, trace_length: usize, ) -> Self { + let fri_constant_schedule = FoldingSchedule::new_constant(4, 31); let mut result = Self::new( TraceInfo::new(4, trace_length), (), - ProofOptions::new(32, 8, 0, FieldExtension::None, 4, 31), + ProofOptions::new(32, 8, 0, FieldExtension::None, &fri_constant_schedule), ); result.periodic_columns = column_values; result } pub fn with_assertions(assertions: Vec>, trace_length: usize) -> Self { + let fri_constant_schedule = FoldingSchedule::new_constant(4, 31); let mut result = Self::new( TraceInfo::new(4, trace_length), (), - ProofOptions::new(32, 8, 0, FieldExtension::None, 4, 31), + ProofOptions::new(32, 8, 0, FieldExtension::None, &fri_constant_schedule), ); result.assertions = assertions; result @@ -112,7 +116,14 @@ fn build_context( blowup_factor: usize, num_assertions: usize, ) -> AirContext { - let options = ProofOptions::new(32, blowup_factor, 0, FieldExtension::None, 4, 31); + let fri_constant_schedule = FoldingSchedule::new_constant(4, 31); + let options = ProofOptions::new( + 32, + blowup_factor, + 0, + FieldExtension::None, + &fri_constant_schedule, + ); let t_degrees = vec![TransitionConstraintDegree::new(2)]; AirContext::new(trace_info, t_degrees, num_assertions, options) } diff --git a/verifier/src/channel.rs b/verifier/src/channel.rs index 391bf58db..f606a4d5b 100644 --- a/verifier/src/channel.rs +++ b/verifier/src/channel.rs @@ -89,7 +89,7 @@ impl> VerifierChanne .parse_remainder() .map_err(|err| VerifierError::ProofDeserializationError(err.to_string()))?; let (fri_layer_queries, fri_layer_proofs) = fri_proof - .parse_layers::(lde_domain_size, fri_options.folding_factor()) + .parse_layers::(lde_domain_size, fri_options.get_schedule()) .map_err(|err| VerifierError::ProofDeserializationError(err.to_string()))?; // --- parse out-of-domain evaluation frame ----------------------------------------------- diff --git a/winterfell/Cargo.toml b/winterfell/Cargo.toml index 1d728b5e1..3648fee62 100644 --- a/winterfell/Cargo.toml +++ b/winterfell/Cargo.toml @@ -23,6 +23,7 @@ std = ["prover/std", "verifier/std"] [dependencies] prover = { version = "0.6", path = "../prover", package = "winter-prover", default-features = false } verifier = { version = "0.6", path = "../verifier", package = "winter-verifier", default-features = false } +winter-fri = { version = "0.6", path = "../fri", package = "winter-fri", default-features = false } # Allow math in docs [package.metadata.docs.rs] diff --git a/winterfell/src/lib.rs b/winterfell/src/lib.rs index d580e5510..14d8b39bd 100644 --- a/winterfell/src/lib.rs +++ b/winterfell/src/lib.rs @@ -372,6 +372,7 @@ //! # TransitionConstraintDegree, TraceTable, FieldExtension, Prover, ProofOptions, //! # StarkProof, Trace, crypto::{hashers::Blake3_256, DefaultRandomCoin}, //! # }; +//! # use winter_fri::fri_schedule::FoldingSchedule; //! # //! # pub fn build_do_work_trace(start: BaseElement, n: usize) -> TraceTable { //! # let trace_width = 1; @@ -483,14 +484,20 @@ //! let trace = build_do_work_trace(start, n); //! let result = trace.get(0, n - 1); //! +//! // Define a constant FRI folding schedule. This means that the prover will use the same +//! // folding factor for all FRI layers. +//! // The first parameter is the folding factor, and the second parameter is the maximum +//! // degree of the remainder polynomial at the last FRI layer. +//! let fri_constant_schedule = FoldingSchedule::new_constant(8, 31); +//! //! // Define proof options; these will be enough for ~96-bit security level. //! let options = ProofOptions::new( //! 32, // number of queries //! 8, // blowup factor //! 0, // grinding factor //! FieldExtension::None, -//! 8, // FRI folding factor -//! 31, // FRI max remainder polynomial degree +//! &fri_constant_schedule, // constant FRI folding schedule with folding factor 8 and +//! // remainder polynomial degree 31 //! ); //! //! // Instantiate the prover and generate the proof.