diff --git a/Cargo.lock b/Cargo.lock index a7bb2c1a41..10b8938615 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -138,7 +138,7 @@ dependencies = [ [[package]] name = "algebraic" version = "0.0.1" -source = "git+https://github.com/0xEigenLabs/eigen-zkvm.git?branch=main#4b9d8d840f70995c0fe75ee5aed63ce3e171ae10" +source = "git+https://github.com/0xEigenLabs/eigen-zkvm.git?rev=4ed1da7#4ed1da7766f7aa8740f355d8c1f3d213411ccf4b" dependencies = [ "byteorder", "ff_ce 0.11.0", @@ -3462,7 +3462,6 @@ dependencies = [ name = "pil_analyzer" version = "0.1.0" dependencies = [ - "analysis 0.1.0", "ast 0.1.0", "env_logger", "itertools 0.10.5", @@ -3522,7 +3521,7 @@ dependencies = [ [[package]] name = "plonky" version = "0.0.2" -source = "git+https://github.com/0xEigenLabs/eigen-zkvm.git?branch=main#4b9d8d840f70995c0fe75ee5aed63ce3e171ae10" +source = "git+https://github.com/0xEigenLabs/eigen-zkvm.git?rev=4ed1da7#4ed1da7766f7aa8740f355d8c1f3d213411ccf4b" dependencies = [ "algebraic", "bellman_vk_codegen", @@ -3635,6 +3634,7 @@ dependencies = [ "parser 0.1.0", "pilopt", "riscv", + "riscv_executor", "strum", "tempfile", ] @@ -4271,10 +4271,24 @@ dependencies = [ "number 0.1.0", "parser_util 0.1.0", "regex-syntax 0.6.29", + "riscv_executor", "serde_json", "test-log", ] +[[package]] +name = "riscv_executor" +version = "0.1.0" +dependencies = [ + "analysis 0.1.0", + "ast 0.1.0", + "importer", + "itertools 0.11.0", + "log", + "number 0.1.0", + "parser 0.1.0", +] + [[package]] name = "rkyv" version = "0.7.42" @@ -4699,7 +4713,7 @@ checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" [[package]] name = "starky" version = "0.0.1" -source = "git+https://github.com/0xEigenLabs/eigen-zkvm.git?branch=main#4b9d8d840f70995c0fe75ee5aed63ce3e171ae10" +source = "git+https://github.com/0xEigenLabs/eigen-zkvm.git?rev=4ed1da7#4ed1da7766f7aa8740f355d8c1f3d213411ccf4b" dependencies = [ "algebraic", "array_tool", diff --git a/asm_to_pil/src/vm_to_constrained.rs b/asm_to_pil/src/vm_to_constrained.rs index 68cf46dbc2..e21b02d198 100644 --- a/asm_to_pil/src/vm_to_constrained.rs +++ b/asm_to_pil/src/vm_to_constrained.rs @@ -144,6 +144,7 @@ impl ASMPILConverter { ), PilStatement::PolynomialIdentity( 0, + None, lhs - (Expression::from(T::one()) - next_reference("first_step")) * direct_reference(pc_update_name), @@ -154,10 +155,14 @@ impl ASMPILConverter { ReadOnly => { let not_reset: Expression = Expression::from(T::one()) - direct_reference("instr__reset"); - vec![PilStatement::PolynomialIdentity(0, not_reset * (lhs - rhs))] + vec![PilStatement::PolynomialIdentity( + 0, + None, + not_reset * (lhs - rhs), + )] } _ => { - vec![PilStatement::PolynomialIdentity(0, lhs - rhs)] + vec![PilStatement::PolynomialIdentity(0, None, lhs - rhs)] } } }) @@ -389,7 +394,7 @@ impl ASMPILConverter { }); for mut statement in body { - if let PilStatement::PolynomialIdentity(_start, expr) = statement { + if let PilStatement::PolynomialIdentity(_start, _attr, expr) = statement { match extract_update(expr) { (Some(var), expr) => { let reference = direct_reference(&instruction_flag); @@ -406,12 +411,13 @@ impl ASMPILConverter { } (None, expr) => self.pil.push(PilStatement::PolynomialIdentity( 0, + None, direct_reference(&instruction_flag) * expr.clone(), )), } } else { match &mut statement { - PilStatement::PermutationIdentity(_, left, _) + PilStatement::PermutationIdentity(_, _, left, _) | PilStatement::PlookupIdentity(_, left, _) => { assert!( left.selector.is_none(), @@ -715,6 +721,7 @@ impl ASMPILConverter { .sum(); self.pil.push(PilStatement::PolynomialIdentity( 0, + None, direct_reference(register) - assign_constraint, )); } diff --git a/ast/src/analyzed/display.rs b/ast/src/analyzed/display.rs index 1218143de2..c5194b4992 100644 --- a/ast/src/analyzed/display.rs +++ b/ast/src/analyzed/display.rs @@ -149,7 +149,13 @@ impl Display for Identity> { } } IdentityKind::Plookup => write!(f, "{} in {};", self.left, self.right), - IdentityKind::Permutation => write!(f, "{} is {};", self.left, self.right), + IdentityKind::Permutation => write!( + f, + "#[{}] {} is {};", + self.attribute.clone().unwrap_or_default(), + self.left, + self.right + ), IdentityKind::Connect => write!(f, "{} connect {};", self.left, self.right), } } diff --git a/ast/src/analyzed/mod.rs b/ast/src/analyzed/mod.rs index 26ec1bf9d8..4fda5bbfb1 100644 --- a/ast/src/analyzed/mod.rs +++ b/ast/src/analyzed/mod.rs @@ -226,6 +226,7 @@ impl Analyzed { self.identities.push(Identity { id, kind: IdentityKind::Polynomial, + attribute: None, // TODO(md): None for the meantime as we do not have tagged identities, will be updated in following pr source, left: SelectedExpressions { selector: Some(identity), @@ -504,6 +505,7 @@ pub struct Identity { /// The ID is specific to the identity kind. pub id: u64, pub kind: IdentityKind, + pub attribute: Option, pub source: SourceRef, /// For a simple polynomial identity, the selector contains /// the actual expression (see expression_for_poly_id). diff --git a/ast/src/parsed/display.rs b/ast/src/parsed/display.rs index ac23386dc4..27c435f543 100644 --- a/ast/src/parsed/display.rs +++ b/ast/src/parsed/display.rs @@ -383,7 +383,7 @@ impl Display for PilStatement { value.as_ref().map(|v| format!("{v}")).unwrap_or_default() ) } - PilStatement::PolynomialIdentity(_, expression) => { + PilStatement::PolynomialIdentity(_, _attr, expression) => { if let Expression::BinaryOperation(left, BinaryOperator::Sub, right) = expression { write!(f, "{left} = {right};") } else { @@ -391,7 +391,12 @@ impl Display for PilStatement { } } PilStatement::PlookupIdentity(_, left, right) => write!(f, "{left} in {right};"), - PilStatement::PermutationIdentity(_, left, right) => write!(f, "{left} is {right};"), + PilStatement::PermutationIdentity( + _, // + _, // + left, + right, + ) => write!(f, "{left} is {right};"), // PilStatement::ConnectIdentity(_, left, right) => write!( f, "{{ {} }} connect {{ {} }};", diff --git a/ast/src/parsed/mod.rs b/ast/src/parsed/mod.rs index 26657af854..3a588c7c76 100644 --- a/ast/src/parsed/mod.rs +++ b/ast/src/parsed/mod.rs @@ -34,7 +34,7 @@ pub enum PilStatement { PolynomialConstantDeclaration(usize, Vec>), PolynomialConstantDefinition(usize, String, FunctionDefinition), PolynomialCommitDeclaration(usize, Vec>, Option>), - PolynomialIdentity(usize, Expression), + PolynomialIdentity(usize, Option, Expression), PlookupIdentity( usize, SelectedExpressions>, @@ -42,6 +42,7 @@ pub enum PilStatement { ), PermutationIdentity( usize, + Option, SelectedExpressions>, SelectedExpressions>, ), @@ -50,6 +51,11 @@ pub enum PilStatement { Expression(usize, Expression), } +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone)] +pub struct Attribute { + pub name: Option, +} + #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone)] pub struct SelectedExpressions { pub selector: Option, diff --git a/ast/src/parsed/visitor.rs b/ast/src/parsed/visitor.rs index 8fa9e7bfa2..831249137b 100644 --- a/ast/src/parsed/visitor.rs +++ b/ast/src/parsed/visitor.rs @@ -193,7 +193,12 @@ impl ExpressionVisitable> for Pi match self { PilStatement::Expression(_, e) => e.visit_expressions_mut(f, o), PilStatement::PlookupIdentity(_, left, right) - | PilStatement::PermutationIdentity(_, left, right) => [left, right] + | PilStatement::PermutationIdentity( + _, // + _, // + left, + right, + ) => [left, right] // .into_iter() .try_for_each(|e| e.visit_expressions_mut(f, o)), PilStatement::ConnectIdentity(_start, left, right) => left @@ -203,7 +208,7 @@ impl ExpressionVisitable> for Pi PilStatement::Namespace(_, _, e) | PilStatement::PolynomialDefinition(_, _, e) - | PilStatement::PolynomialIdentity(_, e) + | PilStatement::PolynomialIdentity(_, _, e) | PilStatement::PublicDeclaration(_, _, _, None, e) | PilStatement::ConstantDefinition(_, _, e) | PilStatement::LetStatement(_, _, Some(e)) => e.visit_expressions_mut(f, o), @@ -230,7 +235,12 @@ impl ExpressionVisitable> for Pi match self { PilStatement::Expression(_, e) => e.visit_expressions(f, o), PilStatement::PlookupIdentity(_, left, right) - | PilStatement::PermutationIdentity(_, left, right) => [left, right] + | PilStatement::PermutationIdentity( + _, // + _, // + left, + right, + ) => [left, right] // .into_iter() .try_for_each(|e| e.visit_expressions(f, o)), PilStatement::ConnectIdentity(_start, left, right) => left @@ -240,7 +250,7 @@ impl ExpressionVisitable> for Pi PilStatement::Namespace(_, _, e) | PilStatement::PolynomialDefinition(_, _, e) - | PilStatement::PolynomialIdentity(_, e) + | PilStatement::PolynomialIdentity(_, _, e) | PilStatement::PublicDeclaration(_, _, _, None, e) | PilStatement::ConstantDefinition(_, _, e) | PilStatement::LetStatement(_, _, Some(e)) => e.visit_expressions(f, o), diff --git a/bberg/src/circuit_builder.rs b/bberg/src/circuit_builder.rs index d868806898..2057851a5c 100644 --- a/bberg/src/circuit_builder.rs +++ b/bberg/src/circuit_builder.rs @@ -9,22 +9,26 @@ pub trait CircuitBuilder { &mut self, name: &str, relations: &[String], + permutations: &[String], fixed: &[String], shifted: &[String], all_cols_with_shifts: &[String], ); } -fn circuit_hpp_includes(name: &str, relations: &[String]) -> String { - let relation_imports = get_relations_imports(name, relations); +fn circuit_hpp_includes(name: &str, relations: &[String], permutations: &[String]) -> String { + let relation_imports = get_relations_imports(name, relations, permutations); format!( " // AUTOGENERATED FILE #pragma once + #include \"barretenberg/common/constexpr_utils.hpp\" #include \"barretenberg/common/throw_or_abort.hpp\" #include \"barretenberg/ecc/curves/bn254/fr.hpp\" #include \"barretenberg/proof_system/circuit_builder/circuit_builder_base.hpp\" + #include \"barretenberg/relations/generic_permutation/generic_permutation_relation.hpp\" + #include \"barretenberg/honk/proof_system/logderivative_library.hpp\" #include \"barretenberg/flavor/generated/{name}_flavor.hpp\" {relation_imports} @@ -32,6 +36,23 @@ fn circuit_hpp_includes(name: &str, relations: &[String]) -> String { ) } +fn get_params() -> &'static str { + r#" + const FF gamma = FF::random_element(); + const FF beta = FF::random_element(); + proof_system::RelationParameters params{ + .eta = 0, + .beta = beta, + .gamma = gamma, + .public_input_delta = 0, + .lookup_grand_product_delta = 0, + .beta_sqr = 0, + .beta_cube = 0, + .eccvm_set_permutation_delta = 0, + }; + "# +} + impl CircuitBuilder for BBFiles { // Create circuit builder // Generate some code that can read a commits.bin and constants.bin into data structures that bberg understands @@ -39,11 +60,12 @@ impl CircuitBuilder for BBFiles { &mut self, name: &str, relations: &[String], + permutations: &[String], all_cols: &[String], to_be_shifted: &[String], all_cols_with_shifts: &[String], ) { - let includes = circuit_hpp_includes(name, relations); + let includes = circuit_hpp_includes(name, relations, permutations); let row_with_all_included = create_row_type(&format!("{name}Full"), all_cols_with_shifts); @@ -57,19 +79,40 @@ impl CircuitBuilder for BBFiles { |name: &String| format!("polys.{name}_shift = Polynomial(polys.{name}.shifted());"); let check_circuit_transformation = |relation_name: &String| { format!( - "if (!evaluate_relation.template operator()<{name}_vm::{relation_name}>(\"{relation_name}\")) {{ + "if (!evaluate_relation.template operator()<{name}_vm::{relation_name}>(\"{relation_name}\", {name}_vm::get_relation_label_{relation_name})) {{ return false; }}", name = name, relation_name = relation_name ) }; + let check_permutation_transformation = |permutation_name: &String| { + format!( + "if (!evaluate_permutation.template operator()>(\"{permutation_name}\")) {{ + return false; + }}", + permutation_name = permutation_name + ) + }; // Apply transformations let compute_polys_assignemnt = map_with_newline(all_cols, compute_polys_transformation); let all_poly_shifts = map_with_newline(to_be_shifted, all_polys_transformation); let check_circuit_for_each_relation = map_with_newline(relations, check_circuit_transformation); + let check_circuit_for_each_permutation = + map_with_newline(permutations, check_permutation_transformation); + + let (params, permutation_check_closure) = if !permutations.is_empty() { + (get_params(), get_permutation_check_closure()) + } else { + ("", "".to_owned()) + }; + let relation_check_closure = if !relations.is_empty() { + get_relation_check_closure() + } else { + "".to_owned() + }; let circuit_hpp = format!(" {includes} @@ -116,36 +159,19 @@ class {name}CircuitBuilder {{ [[maybe_unused]] bool check_circuit() {{ + {params} + auto polys = compute_polynomials(); const size_t num_rows = polys.get_polynomial_size(); - const auto evaluate_relation = [&](const std::string& relation_name) {{ - typename Relation::SumcheckArrayOfValuesOverSubrelations result; - for (auto& r : result) {{ - r = 0; - }} - constexpr size_t NUM_SUBRELATIONS = result.size(); - - for (size_t i = 0; i < num_rows; ++i) {{ - Relation::accumulate(result, polys.get_row(i), {{}}, 1); - - bool x = true; - for (size_t j = 0; j < NUM_SUBRELATIONS; ++j) {{ - if (result[j] != 0) {{ - throw_or_abort( - format(\"Relation \", relation_name, \", subrelation index \", j, \" failed at row \", i)); - x = false; - }} - }} - if (!x) {{ - return false; - }} - }} - return true; - }}; + {relation_check_closure} + + {permutation_check_closure} {check_circuit_for_each_relation} + {check_circuit_for_each_permutation} + return true; }} @@ -172,3 +198,64 @@ class {name}CircuitBuilder {{ ); } } + +fn get_permutation_check_closure() -> String { + " + const auto evaluate_permutation = [&](const std::string& permutation_name) { + + // Check the tuple permutation relation + proof_system::honk::logderivative_library::compute_logderivative_inverse< + Flavor, + PermutationSettings>( + polys, params, num_rows); + + typename PermutationSettings::SumcheckArrayOfValuesOverSubrelations + permutation_result; + + for (auto& r : permutation_result) { + r = 0; + } + for (size_t i = 0; i < num_rows; ++i) { + PermutationSettings::accumulate(permutation_result, polys.get_row(i), params, 1); + } + for (auto r : permutation_result) { + if (r != 0) { + info(\"Tuple \", permutation_name, \" failed.\"); + return false; + } + } + return true; + }; + ".to_string() +} + +fn get_relation_check_closure() -> String { + " + const auto evaluate_relation = [&](const std::string& relation_name, + std::string (*debug_label)(int)) { + typename Relation::SumcheckArrayOfValuesOverSubrelations result; + for (auto& r : result) { + r = 0; + } + constexpr size_t NUM_SUBRELATIONS = result.size(); + + for (size_t i = 0; i < num_rows; ++i) { + Relation::accumulate(result, polys.get_row(i), {}, 1); + + bool x = true; + for (size_t j = 0; j < NUM_SUBRELATIONS; ++j) { + if (result[j] != 0) { + std::string row_name = debug_label(static_cast(j)); + throw_or_abort( + format(\"Relation \", relation_name, \", subrelation index \", row_name, \" failed at row \", i)); + x = false; + } + } + if (!x) { + return false; + } + } + return true; + }; + ".to_string() +} diff --git a/bberg/src/flavor_builder.rs b/bberg/src/flavor_builder.rs index 4c42d8dbaf..0c10ea4f64 100644 --- a/bberg/src/flavor_builder.rs +++ b/bberg/src/flavor_builder.rs @@ -1,5 +1,6 @@ use crate::{ file_writer::BBFiles, + permutation_builder::{get_inverses_from_permutations, Permutation}, utils::{get_relations_imports, map_with_newline}, }; @@ -9,6 +10,7 @@ pub trait FlavorBuilder { &mut self, name: &str, relation_file_names: &[String], + permutations: &[Permutation], fixed: &[String], witness: &[String], all_cols: &[String], @@ -24,6 +26,7 @@ impl FlavorBuilder for BBFiles { &mut self, name: &str, relation_file_names: &[String], + permutations: &[Permutation], fixed: &[String], witness: &[String], all_cols: &[String], @@ -31,15 +34,19 @@ impl FlavorBuilder for BBFiles { shifted: &[String], all_cols_and_shifts: &[String], ) { + // TODO: move elsewhere and rename + let inverses = get_inverses_from_permutations(permutations); + let first_poly = &witness[0]; - let includes = flavor_includes(name, relation_file_names); + let includes = flavor_includes(name, relation_file_names, &inverses); let num_precomputed = fixed.len(); let num_witness = witness.len(); let num_all = all_cols_and_shifts.len(); // Top of file boilerplate let class_aliases = create_class_aliases(); - let relation_definitions = create_relation_definitions(name, relation_file_names); + let relation_definitions = + create_relation_definitions(name, relation_file_names, permutations); let container_size_definitions = container_size_definitions(num_precomputed, num_witness, num_all); @@ -58,6 +65,8 @@ impl FlavorBuilder for BBFiles { let transcript = generate_transcript(witness); + let declare_permutation_sumcheck = create_permuation_sumcheck_declaration(&inverses, name); + let flavor_hpp = format!( " {includes} @@ -96,6 +105,10 @@ class {name}Flavor {{ }}; }} // namespace proof_system::honk::flavor + +namespace sumcheck {{ + {declare_permutation_sumcheck} +}} }} // namespace proof_system::honk @@ -107,8 +120,8 @@ class {name}Flavor {{ } /// Imports located at the top of the flavor files -fn flavor_includes(name: &str, relation_file_names: &[String]) -> String { - let relation_imports = get_relations_imports(name, relation_file_names); +fn flavor_includes(name: &str, relation_file_names: &[String], permutations: &[String]) -> String { + let relation_imports = get_relations_imports(name, relation_file_names, permutations); format!( " @@ -119,6 +132,8 @@ fn flavor_includes(name: &str, relation_file_names: &[String]) -> String { #include \"barretenberg/polynomials/barycentric.hpp\" #include \"barretenberg/polynomials/univariate.hpp\" +#include \"barretenberg/relations/generic_permutation/generic_permutation_relation.hpp\" + #include \"barretenberg/flavor/flavor_macros.hpp\" #include \"barretenberg/transcript/transcript.hpp\" #include \"barretenberg/polynomials/evaluation_domain.hpp\" @@ -138,6 +153,16 @@ fn create_relations_tuple(master_name: &str, relation_file_names: &[String]) -> .join(", ") } +/// Creates comma separated relations tuple file +/// TODO(md): maybe need the filename in here too if we scope these +fn create_permutations_tuple(permutations: &[Permutation]) -> String { + permutations + .iter() + .map(|perm| format!("sumcheck::{}_relation", perm.attribute.clone().unwrap())) + .collect::>() + .join(", ") +} + /// Create Class Aliases /// /// Contains boilerplate defining key characteristics of the flavor class @@ -165,12 +190,21 @@ fn create_class_aliases() -> &'static str { /// definitions. /// /// We then also define some constants, making use of the preprocessor. -fn create_relation_definitions(name: &str, relation_file_names: &[String]) -> String { +fn create_relation_definitions( + name: &str, + relation_file_names: &[String], + permutations: &[Permutation], +) -> String { // Relations tuple = ns::relation_name_0, ns::relation_name_1, ... ns::relation_name_n (comma speratated) let comma_sep_relations = create_relations_tuple(name, relation_file_names); + let comma_sep_perms: String = create_permutations_tuple(permutations); + let mut all_relations = comma_sep_relations.to_string(); + if !permutations.is_empty() { + all_relations = all_relations + &format!(", {comma_sep_perms}"); + } format!(" - using Relations = std::tuple<{comma_sep_relations}>; + using Relations = std::tuple<{all_relations}>; static constexpr size_t MAX_PARTIAL_RELATION_LENGTH = compute_max_partial_relation_length(); @@ -445,12 +479,11 @@ fn generate_transcript(witness: &[String]) -> String { let declaration_transform = |c: &_| format!("Commitment {c};"); let deserialize_transform = |name: &_| { format!( - "{name} = deserialize_from_buffer(Transcript::proof_data, num_bytes_read);", - ) - }; - let serialize_transform = |name: &_| { - format!("serialize_to_buffer({name}, Transcript::proof_data);") + "{name} = deserialize_from_buffer(Transcript::proof_data, num_bytes_read);", + ) }; + let serialize_transform = + |name: &_| format!("serialize_to_buffer({name}, Transcript::proof_data);"); // Perform Transformations let declarations = map_with_newline(witness, declaration_transform); @@ -524,3 +557,15 @@ fn generate_transcript(witness: &[String]) -> String { }}; ") } + +fn create_permuation_sumcheck_declaration(permutations: &[String], name: &str) -> String { + let sumcheck_transformation = |perm: &String| { + format!( + " + DECLARE_SUMCHECK_RELATION_CLASS({perm}, flavor::{name}Flavor); + ", + ) + }; + + map_with_newline(permutations, sumcheck_transformation) +} diff --git a/bberg/src/lib.rs b/bberg/src/lib.rs index ba9ac632dd..0682b25107 100644 --- a/bberg/src/lib.rs +++ b/bberg/src/lib.rs @@ -3,6 +3,7 @@ mod circuit_builder; mod composer_builder; mod file_writer; mod flavor_builder; +pub mod permutation_builder; mod prover_builder; mod relation_builder; mod utils; diff --git a/bberg/src/permutation_builder.rs b/bberg/src/permutation_builder.rs new file mode 100644 index 0000000000..9d501461dd --- /dev/null +++ b/bberg/src/permutation_builder.rs @@ -0,0 +1,280 @@ +use crate::file_writer::BBFiles; +use ast::{ + analyzed::{AlgebraicExpression, Analyzed, Identity, IdentityKind}, + parsed::SelectedExpressions, +}; +use itertools::Itertools; +use number::FieldElement; + +use crate::utils::sanitize_name; + +#[derive(Debug)] +/// Permutation +/// +/// Contains the information required to produce a permutation relation +pub struct Permutation { + /// -> Attribute - the name given to the inverse helper column + pub attribute: Option, + /// -> PermSide - the left side of the permutation + pub left: PermutationSide, + /// -> PermSide - the right side of the permutation + pub right: PermutationSide, +} + +#[derive(Debug)] +/// PermSide +/// +/// One side of a two sided permutation relationship +pub struct PermutationSide { + /// -> Option - the selector for the permutation ( on / off toggle ) + selector: Option, + /// The columns involved in this side of the permutation + cols: Vec, +} + +pub trait PermutationBuilder { + /// Takes in an AST and works out what permutation relations are needed + /// Note: returns the name of the inverse columns, such that they can be added to he prover in subsequent steps + fn create_permutation_files( + &self, + name: &str, + analyzed: &Analyzed, + ) -> Vec; +} + +impl PermutationBuilder for BBFiles { + fn create_permutation_files( + &self, + project_name: &str, + analyzed: &Analyzed, + ) -> Vec { + let perms: Vec<&Identity>> = analyzed + .identities + .iter() + .filter(|identity| matches!(identity.kind, IdentityKind::Permutation)) + .collect(); + let new_perms = perms + .iter() + .map(|perm| Permutation { + attribute: perm.attribute.clone(), + left: get_perm_side(&perm.left), + right: get_perm_side(&perm.right), + }) + .collect_vec(); + + create_permutations(self, project_name, &new_perms); + new_perms + } +} + +/// The attributes of a permutation contain the name of the inverse, we collect all of these to create the inverse column +pub fn get_inverses_from_permutations(permutations: &[Permutation]) -> Vec { + permutations + .iter() + .map(|perm| perm.attribute.clone().unwrap()) + .collect() +} + +/// Write the permutation settings files to disk +fn create_permutations(bb_files: &BBFiles, project_name: &str, permutations: &Vec) { + for permutation in permutations { + let perm_settings = create_permutation_settings_file(permutation); + + let folder = format!("{}/{}", bb_files.rel, project_name); + let file_name = format!( + "{}{}", + permutation.attribute.clone().unwrap_or("NONAME".to_owned()), + ".hpp".to_owned() + ); + bb_files.write_file(&folder, &file_name, &perm_settings); + } +} + +/// All relation types eventually get wrapped in the relation type +/// This function creates the export for the relation type so that it can be added to the flavor +fn create_relation_exporter(permutation_name: &str) -> String { + let settings_name = format!("{}_permutation_settings", permutation_name); + let permutation_export = format!("template using {permutation_name}_relation = GenericPermutationRelation<{settings_name}, FF_>;"); + let relation_export = format!("template using {permutation_name} = GenericPermutation<{settings_name}, FF_>;"); + + format!( + " + {permutation_export} + {relation_export} + " + ) +} + +fn permutation_settings_includes() -> &'static str { + r#" + #pragma once + + #include "barretenberg/relations/generic_permutation/generic_permutation_relation.hpp" + + #include + #include + "# +} + +fn create_permutation_settings_file(permutation: &Permutation) -> String { + println!("Permutation: {:?}", permutation); + let columns_per_set = permutation.left.cols.len(); + // TODO(md): Throw an error if no attribute is provided for the permutation + // TODO(md): In the future we will need to condense off the back of this - combining those with the same inverse column + let permutation_name = permutation + .attribute + .clone() + .expect("Inverse column name must be provided"); // TODO(md): catch this earlier than here + + // NOTE: syntax is not flexible enough to enable the single row case right now :(:(:(:(:)))) + // This also will need to work for both sides of this ! + let selector = permutation.left.selector.clone().unwrap(); // TODO: deal with unwrap + let lhs_cols = permutation.left.cols.clone(); + let rhs_cols = permutation.right.cols.clone(); + + // 0. The polynomial containing the inverse products -> taken from the attributes + // 1. The polynomial enabling the relation (the selector) + // 2. lhs selector + // 3. rhs selector + // 4.. + columns per set. lhs cols + // 4 + columns per set.. . rhs cols + let mut perm_entities: Vec = [ + permutation_name.clone(), + selector.clone(), + selector.clone(), + selector.clone(), // TODO: update this away from the simple example + ] + .to_vec(); + + perm_entities.extend(lhs_cols); + perm_entities.extend(rhs_cols); + + let permutation_settings_includes = permutation_settings_includes(); + let inverse_computed_at = create_inverse_computed_at(selector); + let const_entities = create_get_const_entities(&perm_entities); + let nonconst_entities = create_get_nonconst_entities(&perm_entities); + let relation_exporter = create_relation_exporter(&permutation_name); + + format!( + // TODO: replace with the inverse label name! + " + {permutation_settings_includes} + + namespace proof_system::honk::sumcheck {{ + + class {permutation_name}_permutation_settings {{ + public: + // This constant defines how many columns are bundled together to form each set. + constexpr static size_t COLUMNS_PER_SET = {columns_per_set}; + + /** + * @brief If this method returns true on a row of values, then the inverse polynomial at this index. Otherwise the + * value needs to be set to zero. + * + * @details If this is true then permutation takes place in this row + */ + {inverse_computed_at} + + /** + * @brief Get all the entities for the permutation when we don't need to update them + * + * @details The entities are returned as a tuple of references in the following order: + * - The entity/polynomial used to store the product of the inverse values + * - The entity/polynomial that switches on the subrelation of the permutation relation that ensures correctness of + * the inverse polynomial + * - The entity/polynomial that enables adding a tuple-generated value from the first set to the logderivative sum + * subrelation + * - The entity/polynomial that enables adding a tuple-generated value from the second set to the logderivative sum + * subrelation + * - A sequence of COLUMNS_PER_SET entities/polynomials that represent the first set (N.B. ORDER IS IMPORTANT!) + * - A sequence of COLUMNS_PER_SET entities/polynomials that represent the second set (N.B. ORDER IS IMPORTANT!) + * + * @return All the entities needed for the permutation + */ + {const_entities} + + /** + * @brief Get all the entities for the permutation when need to update them + * + * @details The entities are returned as a tuple of references in the following order: + * - The entity/polynomial used to store the product of the inverse values + * - The entity/polynomial that switches on the subrelation of the permutation relation that ensures correctness of + * the inverse polynomial + * - The entity/polynomial that enables adding a tuple-generated value from the first set to the logderivative sum + * subrelation + * - The entity/polynomial that enables adding a tuple-generated value from the second set to the logderivative sum + * subrelation + * - A sequence of COLUMNS_PER_SET entities/polynomials that represent the first set (N.B. ORDER IS IMPORTANT!) + * - A sequence of COLUMNS_PER_SET entities/polynomials that represent the second set (N.B. ORDER IS IMPORTANT!) + * + * @return All the entities needed for the permutation + */ + {nonconst_entities} + }}; + + {relation_exporter} + }} + " + ) +} + +// TODO: make this dynamic such that there can be more than one +fn create_inverse_computed_at(inverse_selector: String) -> String { + let inverse_computed_selector = format!("in.{inverse_selector}"); + format!(" + template static inline auto inverse_polynomial_is_computed_at_row(const AllEntities& in) {{ + return ({inverse_computed_selector} == 1); + }}") +} + +fn create_get_const_entities(settings: &[String]) -> String { + let forward = create_forward_as_tuple(settings); + format!( + " + template static inline auto get_const_entities(const AllEntities& in) {{ + {forward} + }} + " + ) +} + +fn create_get_nonconst_entities(settings: &[String]) -> String { + let forward = create_forward_as_tuple(settings); + format!( + " + template static inline auto get_nonconst_entities(AllEntities& in) {{ + {forward} + }} + " + ) +} + +fn create_forward_as_tuple(settings: &[String]) -> String { + let adjusted = settings.iter().map(|col| format!("in.{col}")).join(",\n"); + format!( + " + return std::forward_as_tuple( + {} + ); + ", + adjusted + ) +} + +fn get_perm_side( + def: &SelectedExpressions>, +) -> PermutationSide { + let get_name = |expr: &AlgebraicExpression| match expr { + AlgebraicExpression::Reference(a_ref) => sanitize_name(&a_ref.name), + _ => panic!("Expected reference"), + }; + + PermutationSide { + selector: def.selector.as_ref().map(|expr| get_name(expr)), + cols: def + .expressions + .iter() + .map(|expr| get_name(expr)) + .collect_vec(), + } +} diff --git a/bberg/src/prover_builder.rs b/bberg/src/prover_builder.rs index 9f35ecc8d9..b7eff21811 100644 --- a/bberg/src/prover_builder.rs +++ b/bberg/src/prover_builder.rs @@ -237,7 +237,7 @@ fn includes_cpp(name: &str) -> String { #include \"{name}_prover.hpp\" #include \"barretenberg/commitment_schemes/claim.hpp\" #include \"barretenberg/commitment_schemes/commitment_key.hpp\" - #include \"barretenberg/honk/proof_system/lookup_library.hpp\" + #include \"barretenberg/honk/proof_system/logderivative_library.hpp\" #include \"barretenberg/honk/proof_system/permutation_library.hpp\" #include \"barretenberg/honk/proof_system/power_polynomial.hpp\" #include \"barretenberg/polynomials/polynomial.hpp\" diff --git a/bberg/src/relation_builder.rs b/bberg/src/relation_builder.rs index 811c8d66f7..bdfcde4973 100644 --- a/bberg/src/relation_builder.rs +++ b/bberg/src/relation_builder.rs @@ -1,40 +1,132 @@ +use ast::analyzed::AlgebraicExpression; use ast::analyzed::Identity; use ast::analyzed::{ AlgebraicBinaryOperator, AlgebraicExpression as Expression, AlgebraicUnaryOperator, IdentityKind, }; use ast::parsed::SelectedExpressions; +use itertools::Itertools; +use std::collections::HashMap; use std::collections::HashSet; use number::{DegreeType, FieldElement}; use crate::file_writer::BBFiles; +use crate::utils::capitalize; use crate::utils::map_with_newline; +/// Returned back to the vm builder from the create_relations call +pub struct RelationOutput { + /// A list of the names of the created relations + pub relations: Vec, + /// A list of the names of all of the 'used' shifted polys + pub shifted_polys: Vec, +} + +/// Each created bb Identity is passed around with its degree so as needs to be manually +/// provided for sumcheck +type BBIdentity = (DegreeType, String); + pub trait RelationBuilder { - fn create_relations( + /// Create Relations + /// + /// Takes in the ast ( for relations ), groups each of them by file, and then + /// calls 'create relation' for each + /// + /// Relation output is passed back to the caller as the prover requires both: + /// - The shifted polys + /// - The names of the relations files created + fn create_relations( + &self, + root_name: &str, + identities: &[Identity>], + ) -> RelationOutput; + + /// Create Relation + /// + /// Name and root name are required to determine the file path, e.g. it will be in the bberg/relations/generated + /// followed by /root_name/name + /// - root name should be the name provided with the --name flag + /// - name will be a pil namespace + /// + /// - Identities are the identities that will be used to create the relations, they are generated within create_relations + /// - row_type contains all of the columns that the relations namespace touches. + fn create_relation( &self, root_name: &str, name: &str, sub_relations: &[String], identities: &[BBIdentity], row_type: &str, + labels_lookup: String, ); + /// Declare views + /// + /// Declare views is a macro that generates a reference for each of the columns + /// This reference will be a span into a sumcheck related object, it must be declared for EACH sub-relation + /// as the sumcheck object is sensitive to the degree of the relation. fn create_declare_views(&self, name: &str, all_cols_and_shifts: &[String]); } -// TODO: MOve -> to gen code we need to know the degree of each poly -type BBIdentity = (DegreeType, String); - impl RelationBuilder for BBFiles { - fn create_relations( + fn create_relations( + &self, + file_name: &str, + analyzed_identities: &[Identity>], + ) -> RelationOutput { + // Group relations per file + let grouped_relations: HashMap>>> = + group_relations_per_file(analyzed_identities); + let relations = grouped_relations.keys().cloned().collect_vec(); + + // Contains all of the rows in each relation, will be useful for creating composite builder types + let mut all_rows: HashMap = HashMap::new(); + let mut shifted_polys: Vec = Vec::new(); + + // ----------------------- Create the relation files ----------------------- + for (relation_name, analyzed_idents) in grouped_relations.iter() { + let IdentitiesOutput { + subrelations, + identities, + collected_cols, + collected_shifts, + expression_labels, + } = create_identities(file_name, analyzed_idents); + + // TODO: This can probably be moved into the create_identities function + let row_type = create_row_type(&capitalize(relation_name), &collected_cols); + + // Aggregate all shifted polys + shifted_polys.extend(collected_shifts); + // Aggregate all rows + all_rows.insert(relation_name.to_owned(), row_type.clone()); + + let labels_lookup = create_relation_labels(relation_name, expression_labels); + self.create_relation( + file_name, + relation_name, + &subrelations, + &identities, + &row_type, + labels_lookup, + ); + } + + RelationOutput { + relations, + shifted_polys, + } + } + + fn create_relation( &self, root_name: &str, name: &str, sub_relations: &[String], identities: &[BBIdentity], row_type: &str, + labels_lookup: String, ) { let includes = relation_includes(); let class_boilerplate = relation_class_boilerplate(name, sub_relations, identities); @@ -46,6 +138,8 @@ namespace proof_system::{root_name}_vm {{ {row_type}; +{labels_lookup} + {class_boilerplate} {export} @@ -67,7 +161,7 @@ namespace proof_system::{root_name}_vm {{ let declare_views = format!( " - #define DECLARE_VIEWS(index) \\ + #define {name}_DECLARE_VIEWS(index) \\ using Accumulator = typename std::tuple_element::type; \\ using View = typename Accumulator::View; \\ {make_view_per_row} @@ -84,6 +178,35 @@ namespace proof_system::{root_name}_vm {{ } } +/// Group relations per file +/// +/// The compiler returns all relations in one large vector, however we want to distinguish +/// which files .pil files the relations belong to for later code gen +/// +/// Say we have two files foo.pil and bar.pil +/// foo.pil contains the following relations: +/// - foo1 +/// - foo2 +/// bar.pil contains the following relations: +/// - bar1 +/// - bar2 +/// +/// This function will return a hashmap with the following structure: +/// { +/// "foo": [foo1, foo2], +/// "bar": [bar1, bar2] +/// } +/// +/// This allows us to generate a relation.hpp file containing ONLY the relations for that .pil file +fn group_relations_per_file( + identities: &[Identity>], +) -> HashMap>>> { + identities + .iter() + .cloned() + .into_group_map_by(|identity| identity.source.file.clone().replace(".pil", "")) +} + fn relation_class_boilerplate( name: &str, sub_relations: &[String], @@ -250,33 +373,45 @@ fn craft_expression( } } +pub struct IdentitiesOutput { + subrelations: Vec, + identities: Vec, + collected_cols: Vec, + collected_shifts: Vec, + expression_labels: HashMap, +} + pub(crate) fn create_identities( + file_name: &str, identities: &[Identity>], -) -> (Vec, Vec, Vec, Vec) { +) -> IdentitiesOutput { // We only want the expressions for now // When we have a poly type, we only need the left side of it - let expressions = identities + let ids = identities .iter() - .filter_map(|identity| { - if identity.kind == IdentityKind::Polynomial { - Some(identity.left.clone()) - } else { - None - } - }) + .filter(|identity| identity.kind == IdentityKind::Polynomial) .collect::>(); let mut identities = Vec::new(); let mut subrelations = Vec::new(); + let mut expression_labels: HashMap = HashMap::new(); // Each relation can be given a label, this label can be assigned here let mut collected_cols: HashSet = HashSet::new(); let mut collected_public_identities: HashSet = HashSet::new(); + // Collect labels for each identity + // TODO: shite + for (i, id) in ids.iter().enumerate() { + if let Some(label) = &id.attribute { + expression_labels.insert(i, label.clone()); + } + } + + let expressions = ids.iter().map(|id| id.left.clone()).collect::>(); for (i, expression) in expressions.iter().enumerate() { let relation_boilerplate = format!( - "DECLARE_VIEWS({i}); + "{file_name}_DECLARE_VIEWS({i}); ", ); - // TODO: deal with unwrap // TODO: collected pattern is shit let mut identity = create_identity( @@ -311,6 +446,45 @@ pub(crate) fn create_identities( }) .collect(); - // Returning both for now - (subrelations, identities, collected_cols, collected_shifts) + IdentitiesOutput { + subrelations, + identities, + collected_cols, + collected_shifts, + expression_labels, + } +} + +/// Relation labels +/// +/// To view relation labels we create a sparse switch that contains all of the collected labels +/// Whenever there is a failure, we can lookup into this mapping +/// +/// Note: this mapping will never be that big, so we are quite naive in implementation +/// It should be able to be called from else where with relation_name::get_relation_label +fn create_relation_labels(relation_name: &str, labels: HashMap) -> String { + let label_transformation = |(index, label)| { + format!( + "case {index}: + return \"{label}\"; + " + ) + }; + + let switch_statement: String = labels + .into_iter() + .map(label_transformation) + .collect::>() + .join("\n"); + + format!( + " + inline std::string get_relation_label_{relation_name}(int index) {{ + switch (index) {{ + {switch_statement} + }} + return std::to_string(index); + }} + " + ) } diff --git a/bberg/src/utils.rs b/bberg/src/utils.rs index 6bbb4e43db..676c3fac1d 100644 --- a/bberg/src/utils.rs +++ b/bberg/src/utils.rs @@ -4,12 +4,13 @@ use number::FieldElement; /// /// We may have multiple relation files in the generated foler /// This method will return all of the imports for the relation header files -pub fn get_relations_imports(name: &str, relations: &[String]) -> String { +pub fn get_relations_imports(name: &str, relations: &[String], permutations: &[String]) -> String { + let all_relations = flatten(&[relations.to_vec(), permutations.to_vec()]); let transformation = |relation_name: &_| { format!("#include \"barretenberg/relations/generated/{name}/{relation_name}.hpp\"") }; - map_with_newline(relations, transformation) + map_with_newline(&all_relations, transformation) } /// Sanitize Names diff --git a/bberg/src/verifier_builder.rs b/bberg/src/verifier_builder.rs index 8ac07514da..918231122a 100644 --- a/bberg/src/verifier_builder.rs +++ b/bberg/src/verifier_builder.rs @@ -10,9 +10,11 @@ impl VerifierBuilder for BBFiles { fn create_verifier_cpp(&mut self, name: &str, witness: &[String]) { let include_str = includes_cpp(name); - let wire_transformation = |n: &String| format!( + let wire_transformation = |n: &String| { + format!( "commitments.{n} = transcript->template receive_from_prover(commitment_labels.{n});" - ); + ) + }; let wire_commitments = map_with_newline(witness, wire_transformation); let ver_cpp = format!(" diff --git a/bberg/src/vm_builder.rs b/bberg/src/vm_builder.rs index bbd940369a..7bea6a6a37 100644 --- a/bberg/src/vm_builder.rs +++ b/bberg/src/vm_builder.rs @@ -1,35 +1,44 @@ -use std::collections::HashMap; - -use ast::analyzed::AlgebraicExpression as Expression; use ast::analyzed::Analyzed; -use ast::analyzed::Identity; -use itertools::Itertools; use number::FieldElement; use crate::circuit_builder::CircuitBuilder; use crate::composer_builder::ComposerBuilder; use crate::file_writer::BBFiles; use crate::flavor_builder::FlavorBuilder; +use crate::permutation_builder::get_inverses_from_permutations; +use crate::permutation_builder::Permutation; +use crate::permutation_builder::PermutationBuilder; use crate::prover_builder::ProverBuilder; -use crate::relation_builder::{create_identities, create_row_type, RelationBuilder}; -use crate::utils::capitalize; +use crate::relation_builder::RelationBuilder; +use crate::relation_builder::RelationOutput; use crate::utils::collect_col; use crate::utils::flatten; use crate::utils::sanitize_name; use crate::utils::transform_map; use crate::verifier_builder::VerifierBuilder; +/// All of the combinations of columns that are used in a bberg flavor file struct ColumnGroups { + /// fixed or constant columns in pil -> will be found in vk fixed: Vec, + /// witness or commit columns in pil -> will be found in proof witness: Vec, + /// fixed + witness columns all_cols: Vec, + /// Columns that will not be shifted unshifted: Vec, + /// Columns that will be shifted to_be_shifted: Vec, + /// The shifts of the columns that will be shifted shifted: Vec, + /// fixed + witness + shifted all_cols_with_shifts: Vec, } +/// Analyzed to cpp +/// +/// Converts an analyzed pil AST into a set of cpp files that can be used to generate a proof pub(crate) fn analyzed_to_cpp( analyzed: &Analyzed, fixed: &[(String, Vec)], @@ -42,37 +51,16 @@ pub(crate) fn analyzed_to_cpp( // Inlining step to remove the intermediate poly definitions let analyzed_identities = analyzed.identities_with_inlined_intermediate_polynomials(); - // Group relations per file - let grouped_relations = group_relations_per_file(&analyzed_identities); - let relations = grouped_relations.keys().cloned().collect_vec(); - - // Contains all of the rows in each relation, will be useful for creating composite builder types - // TODO: this will change up - let mut all_rows: HashMap = HashMap::new(); - let mut shifted_polys: Vec = Vec::new(); - - // ----------------------- Create the relation files ----------------------- - for (relation_name, analyzed_idents) in grouped_relations.iter() { - // TODO: make this more granular instead of doing everything at once - let (subrelations, identities, collected_polys, collected_shifts) = - create_identities(analyzed_idents); + // ----------------------- Handle Standard Relation Identities ----------------------- + // We collect all references to shifts as we traverse all identities and create relation files + let RelationOutput { + relations, + shifted_polys, + } = bb_files.create_relations(file_name, &analyzed_identities); - shifted_polys.extend(collected_shifts); - - // let all_cols_with_shifts = combine_cols(collected_polys, collected_shifts); - // TODO: This can probably be moved into the create_identities function - let row_type = create_row_type(&capitalize(relation_name), &collected_polys); - - all_rows.insert(relation_name.clone(), row_type.clone()); - - bb_files.create_relations( - file_name, - relation_name, - &subrelations, - &identities, - &row_type, - ); - } + // ----------------------- Handle Lookup / Permutation Relation Identities ----------------------- + let permutations = bb_files.create_permutation_files(file_name, analyzed); + let inverses = get_inverses_from_permutations(&permutations); // TODO: hack - this can be removed with some restructuring let shifted_polys: Vec = shifted_polys @@ -90,7 +78,7 @@ pub(crate) fn analyzed_to_cpp( to_be_shifted, shifted, all_cols_with_shifts, - } = get_all_col_names(fixed, witness, &shifted_polys); + } = get_all_col_names(fixed, witness, &shifted_polys, &permutations); bb_files.create_declare_views(file_name, &all_cols_with_shifts); @@ -98,6 +86,7 @@ pub(crate) fn analyzed_to_cpp( bb_files.create_circuit_builder_hpp( file_name, &relations, + &inverses, &all_cols, &to_be_shifted, &all_cols_with_shifts, @@ -107,6 +96,7 @@ pub(crate) fn analyzed_to_cpp( bb_files.create_flavor_hpp( file_name, &relations, + &permutations, &fixed, &witness, &all_cols, @@ -128,35 +118,6 @@ pub(crate) fn analyzed_to_cpp( bb_files.create_prover_hpp(file_name); } -/// Group relations per file -/// -/// The compiler returns all relations in one large vector, however we want to distinguish -/// which files .pil files the relations belong to for later code gen -/// -/// Say we have two files foo.pil and bar.pil -/// foo.pil contains the following relations: -/// - foo1 -/// - foo2 -/// bar.pil contains the following relations: -/// - bar1 -/// - bar2 -/// -/// This function will return a hashmap with the following structure: -/// { -/// "foo": [foo1, foo2], -/// "bar": [bar1, bar2] -/// } -/// -/// This allows us to generate a relation.hpp file containing ONLY the relations for that .pil file -fn group_relations_per_file( - identities: &[Identity>], -) -> HashMap>>> { - identities - .iter() - .cloned() - .into_group_map_by(|identity| identity.source.file.clone().replace(".pil", "")) -} - /// Get all col names /// /// In the flavor file, there are a number of different groups of columns that we need to keep track of @@ -171,14 +132,18 @@ fn get_all_col_names( fixed: &[(String, Vec)], witness: &[(String, Vec)], to_be_shifted: &[String], + permutations: &[Permutation], ) -> ColumnGroups { // Transformations let sanitize = |(name, _): &(String, Vec)| sanitize_name(name).to_owned(); let append_shift = |name: &String| format!("{}_shift", *name); + let perm_inverses = get_inverses_from_permutations(permutations); + // Gather sanitized column names let fixed_names = collect_col(fixed, sanitize); let witness_names = collect_col(witness, sanitize); + let witness_names = flatten(&[witness_names, perm_inverses]); // Group columns by properties let shifted = transform_map(to_be_shifted, append_shift); @@ -191,7 +156,6 @@ fn get_all_col_names( let all_cols_with_shifts: Vec = flatten(&[fixed_names.clone(), witness_names.clone(), shifted.clone()]); - // TODO: remove dup ColumnGroups { fixed: fixed_names, witness: witness_names, diff --git a/compiler/src/lib.rs b/compiler/src/lib.rs index bb8d50c351..3396d33ee4 100644 --- a/compiler/src/lib.rs +++ b/compiler/src/lib.rs @@ -270,7 +270,7 @@ pub fn convert_analyzed_to_pil_with_callback( force_overwrite: bool, prove_with: Option, external_witness_values: Vec<(&str, Vec)>, - bname: Option + bname: Option, ) -> Result<(PathBuf, Option>), Vec> { let mut monitor = DiffMonitor::default(); let analyzed = compile_asm_string_to_analyzed_ast(file_name, contents, Some(&mut monitor))?; @@ -307,7 +307,7 @@ pub fn compile_asm_string( force_overwrite, prove_with, external_witness_values, - bname + bname, ) } @@ -321,7 +321,7 @@ pub fn compile_asm_string_with_callback>( force_overwrite: bool, prove_with: Option, external_witness_values: Vec<(&str, Vec)>, - bname: Option + bname: Option, ) -> Result<(PathBuf, Option>), Vec> { let mut monitor = DiffMonitor::default(); let analyzed = compile_asm_string_to_analyzed_ast(file_name, contents, Some(&mut monitor))?; @@ -337,7 +337,7 @@ pub fn compile_asm_string_with_callback>( force_overwrite, prove_with, external_witness_values, - bname + bname, ) } diff --git a/parser/src/lib.rs b/parser/src/lib.rs index 311ab1f2c6..0909e43a4c 100644 --- a/parser/src/lib.rs +++ b/parser/src/lib.rs @@ -110,6 +110,32 @@ mod test { ); } + #[test] + fn parse_permutation_attribute() { + let parsed = powdr::PILFileParser::new() + .parse::( + " + #[attribute] + { f } is { g };", + ) + .unwrap(); + assert_eq!( + parsed, + PILFile(vec![PilStatement::PermutationIdentity( + 13, + Some("attribute".to_string()), + SelectedExpressions { + selector: None, + expressions: vec![direct_reference("f")] + }, + SelectedExpressions { + selector: None, + expressions: vec![direct_reference("g")] + } + )]) + ); + } + fn parse_file(name: &str) -> PILFile { let file = std::path::PathBuf::from(format!( "{}/../test_data/{name}", diff --git a/parser/src/powdr.lalrpop b/parser/src/powdr.lalrpop index 4682c19661..fc1075aeb4 100644 --- a/parser/src/powdr.lalrpop +++ b/parser/src/powdr.lalrpop @@ -127,7 +127,7 @@ PolynomialCommitDeclaration: PilStatement = { } PolynomialIdentity: PilStatement = { - "=" => PilStatement::PolynomialIdentity(start, Expression::BinaryOperation(l, BinaryOperator::Sub, r)) + "=" => PilStatement::PolynomialIdentity(start, attr, Expression::BinaryOperation(l, BinaryOperator::Sub, r)) } PolynomialNameList: Vec> = { @@ -147,8 +147,12 @@ SelectedExpressions: SelectedExpressions> = { Expression => SelectedExpressions{selector: None, expressions: vec![<>]}, } +Attribute: String = { + "#[" "]" => <>.to_string() +} + PermutationIdentity: PilStatement = { - <@L> "is" => PilStatement::PermutationIdentity(<>) + <@L> "is" => PilStatement::PermutationIdentity(<>) } ConnectIdentity: PilStatement = { diff --git a/pil_analyzer/src/condenser.rs b/pil_analyzer/src/condenser.rs index 0cd8c8edeb..6a8f56f677 100644 --- a/pil_analyzer/src/condenser.rs +++ b/pil_analyzer/src/condenser.rs @@ -119,6 +119,7 @@ impl Condenser { .map(|constraint| Identity { id: identity.id, kind: identity.kind, + attribute: identity.attribute.clone(), source: identity.source.clone(), left: SelectedExpressions { selector: Some(constraint), @@ -131,6 +132,7 @@ impl Condenser { vec![Identity { id: identity.id, kind: identity.kind, + attribute: identity.attribute.clone(), source: identity.source.clone(), left: self.condense_selected_expressions(&identity.left), right: self.condense_selected_expressions(&identity.right), diff --git a/pil_analyzer/src/statement_processor.rs b/pil_analyzer/src/statement_processor.rs index 42511d7b6f..22c081f8e4 100644 --- a/pil_analyzer/src/statement_processor.rs +++ b/pil_analyzer/src/statement_processor.rs @@ -198,11 +198,21 @@ where } fn handle_identity_statement(&mut self, statement: PilStatement) -> Vec> { - let (start, kind, left, right) = match statement { - PilStatement::PolynomialIdentity(start, expression) - | PilStatement::Expression(start, expression) => ( + let (start, kind, attribute, left, right) = match statement { + PilStatement::PolynomialIdentity(start, attr, expression) => ( start, IdentityKind::Polynomial, + attr, + SelectedExpressions { + selector: Some(self.process_expression(expression)), + expressions: vec![], + }, + SelectedExpressions::default(), + ), + PilStatement::Expression(start, expression) => ( + start, + IdentityKind::Polynomial, + None, SelectedExpressions { selector: Some(self.process_expression(expression)), expressions: vec![], @@ -212,18 +222,21 @@ where PilStatement::PlookupIdentity(start, key, haystack) => ( start, IdentityKind::Plookup, + None, self.process_selected_expressions(key), self.process_selected_expressions(haystack), ), - PilStatement::PermutationIdentity(start, left, right) => ( + PilStatement::PermutationIdentity(start, attribute, left, right) => ( start, IdentityKind::Permutation, + attribute.clone(), self.process_selected_expressions(left), self.process_selected_expressions(right), ), PilStatement::ConnectIdentity(start, left, right) => ( start, IdentityKind::Connect, + None, SelectedExpressions { selector: None, expressions: self.expression_processor().process_expressions(left), @@ -242,6 +255,7 @@ where vec![PILItem::Identity(Identity { id: self.counters.dispense_identity_id(kind), kind, + attribute, source: self.driver.source_position_to_source_ref(start), left, right, diff --git a/riscv/tests/common/mod.rs b/riscv/tests/common/mod.rs index d34d062132..505b411f1f 100644 --- a/riscv/tests/common/mod.rs +++ b/riscv/tests/common/mod.rs @@ -29,6 +29,7 @@ pub fn verify_riscv_asm_string(file_name: &str, contents: &str, inputs: Vec TypeChecker { let errors: Vec<_> = statements .iter() .filter_map(|s| match s { - ast::parsed::PilStatement::PolynomialIdentity(_, _) => None, - ast::parsed::PilStatement::PermutationIdentity(_, l, _) + ast::parsed::PilStatement::PolynomialIdentity(_, _, _) => None, + ast::parsed::PilStatement::PermutationIdentity( + _, // + _, // + l, + _, + ) | ast::parsed::PilStatement::PlookupIdentity(_, l, _) => l .selector .is_some()