diff --git a/.gitignore b/.gitignore index ddba9612..121ab618 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,8 @@ -/target +nfa.json +target /compiler/pkg -/compiler/.yarn \ No newline at end of file +/compiler/.yarn + +Prover.toml +prover.toml +inputs.txt \ No newline at end of file diff --git a/Cargo.lock b/Cargo.lock index 77e96016..644fa386 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -61,6 +61,12 @@ dependencies = [ "windows-sys", ] +[[package]] +name = "autocfg" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" + [[package]] name = "bumpalo" version = "3.17.0" @@ -75,9 +81,9 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] name = "clap" -version = "4.5.36" +version = "4.5.37" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2df961d8c8a0d08aa9945718ccf584145eee3f3aa06cddbeac12933781102e04" +checksum = "eccb054f56cbd38340b380d4a8e69ef1f02f1af43db2f0cc817a4774d80ae071" dependencies = [ "clap_builder", "clap_derive", @@ -85,9 +91,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.5.36" +version = "4.5.37" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "132dbda40fb6753878316a489d5a1242a8ef2f0d9e47ba01c951ea8aa7d013a5" +checksum = "efd9466fac8543255d3b1fcad4762c5e116ffe808c8a3043d4263cd4fd4862a2" dependencies = [ "anstream", "anstyle", @@ -124,6 +130,7 @@ name = "compiler" version = "2.0.0-alpha.1" dependencies = [ "clap", + "comptime", "heck", "regex-automata", "serde", @@ -133,12 +140,29 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "comptime" +version = "0.1.0" +source = "git+https://github.com/jp4g/sparse_array?branch=feat%2Fcomptime-codegen#36d5251606d1044759de79b941b0d5d0600c695d" +dependencies = [ + "hex", + "lazy_static", + "num-bigint", + "num-traits", +] + [[package]] name = "heck" version = "0.5.0" source 
= "registry+https://github.com/rust-lang/crates.io-index" checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" +[[package]] +name = "hex" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" + [[package]] name = "is_terminal_polyfill" version = "1.70.1" @@ -161,6 +185,12 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "lazy_static" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" + [[package]] name = "log" version = "0.4.27" @@ -173,6 +203,34 @@ version = "2.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" +[[package]] +name = "num-bigint" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a5e44f723f1133c9deac646763579fdb3ac745e418f2a7af9cd0c431da1f20b9" +dependencies = [ + "num-integer", + "num-traits", +] + +[[package]] +name = "num-integer" +version = "0.1.46" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-traits" +version = "0.2.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" +dependencies = [ + "autocfg", +] + [[package]] name = "once_cell" version = "1.21.3" diff --git a/Cargo.toml b/Cargo.toml index 4908f205..05f7e40a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,8 +10,8 @@ heck = "0.5.0" regex-automata = "0.4.9" serde = "1.0.219" serde_json = "1.0.140" -serde-wasm-bindgen = "0.6.5" thiserror = "2.0.12" +serde-wasm-bindgen = "0.6.5" wasm-bindgen = "0.2.100" resolver = "2" \ No newline at end of 
file diff --git a/compiler/Cargo.toml b/compiler/Cargo.toml index 729218d9..a83fd6c7 100644 --- a/compiler/Cargo.toml +++ b/compiler/Cargo.toml @@ -16,6 +16,7 @@ heck = "0.5.0" regex-automata = { workspace = true } serde = { workspace = true, features = ["derive"] } serde_json = { workspace = true } -serde-wasm-bindgen = { workspace = true } +comptime = { git = "https://github.com/jp4g/sparse_array", branch = "feat/comptime-codegen" } thiserror = { workspace = true } -wasm-bindgen = { workspace = true } \ No newline at end of file +serde-wasm-bindgen = { workspace = true } +wasm-bindgen = { workspace = true } diff --git a/compiler/output/Rando.circom b/compiler/output/Rando.circom deleted file mode 100644 index a84ff77d..00000000 --- a/compiler/output/Rando.circom +++ /dev/null @@ -1,103 +0,0 @@ -pragma circom 2.1.5; - -include "@zk-email/zk-regex-circom/circuits/regex_helpers.circom"; - -// regex: a*b -template RandoRegex(maxBytes) { -signal input currStates[maxBytes]; -signal input haystack[maxBytes]; -signal input nextStates[maxBytes]; -signal input traversalPathLength; - -var numStartStates = 4; -var numAcceptStates = 2; -var numTransitions = 11; -var startStates[numStartStates] = [0, 2, 3, 1]; -var acceptStates[numAcceptStates] = [4, 5]; - -signal isCurrentState[numTransitions][maxBytes]; -signal isNextState[numTransitions][maxBytes]; -signal isValidTransition[numTransitions][maxBytes]; -signal reachedLastTransition[maxBytes]; -signal reachedAcceptState[maxBytes]; -signal isValidRegex[maxBytes]; -signal isValidRegexTemp[maxBytes]; -signal isWithinPathLength[maxBytes]; -signal isTransitionLinked[maxBytes]; - -component isValidTraversal[maxBytes]; - - // Check if the first state in the haystack is a valid start state - component isValidStartState; - isValidStartState = MultiOR(numStartStates); - for (var i = 0; i < numStartStates; i++) { - isValidStartState.in[i] <== IsEqual()([startStates[i], currStates[0]]); - } - isValidStartState.out === 1; - - // Check if 
the traversal path has valid transitions - for (var i = 0; i < maxBytes; i++) { - isWithinPathLength[i] <== LessThan(log2Ceil(maxBytes))([i, traversalPathLength]); - - // Check if the traversal is a valid path - if (i != maxBytes - 1) { - isTransitionLinked[i] <== IsEqual()([nextStates[i], currStates[i+1]]); - isTransitionLinked[i] === isWithinPathLength[i]; - } - - // Transition 0: 0 -[0-255]-> 0 -isValidTransition[0][i] <== CheckByteRangeTransition()(0, 0, 0, 255, currStates[i], nextStates[i], haystack[i]); - - // Transition 1: 0 -[97]-> 4 -isValidTransition[1][i] <== CheckByteTransition()(0, 4, 97, currStates[i], nextStates[i], haystack[i]); - - // Transition 2: 1 -[0-255]-> 0 -isValidTransition[2][i] <== CheckByteRangeTransition()(1, 0, 0, 255, currStates[i], nextStates[i], haystack[i]); - - // Transition 3: 2 -[97]-> 4 -isValidTransition[3][i] <== CheckByteTransition()(2, 4, 97, currStates[i], nextStates[i], haystack[i]); - - // Transition 4: 3 -[97]-> 4 -isValidTransition[4][i] <== CheckByteTransition()(3, 4, 97, currStates[i], nextStates[i], haystack[i]); - - // Transition 5: 4 -[97]-> 5 -isValidTransition[5][i] <== CheckByteTransition()(4, 5, 97, currStates[i], nextStates[i], haystack[i]); - - // Transition 6: 4 -[97]-> 5 -isValidTransition[6][i] <== CheckByteTransition()(4, 5, 97, currStates[i], nextStates[i], haystack[i]); - - // Transition 7: 4 -[97]-> 5 -isValidTransition[7][i] <== CheckByteTransition()(4, 5, 97, currStates[i], nextStates[i], haystack[i]); - - // Transition 8: 5 -[97]-> 5 -isValidTransition[8][i] <== CheckByteTransition()(5, 5, 97, currStates[i], nextStates[i], haystack[i]); - - // Transition 9: 5 -[97]-> 5 -isValidTransition[9][i] <== CheckByteTransition()(5, 5, 97, currStates[i], nextStates[i], haystack[i]); - - // Transition 10: 6 -[97]-> 5 -isValidTransition[10][i] <== CheckByteTransition()(6, 5, 97, currStates[i], nextStates[i], haystack[i]); - - // Combine all valid transitions for this byte - isValidTraversal[i] = 
MultiOR(numTransitions); - for (var j = 0; j < numTransitions; j++) { - isValidTraversal[i].in[j] <== isValidTransition[j][i]; - } - isValidTraversal[i].out === isWithinPathLength[i]; - - // Check if any accept state has been reached at the last transition - reachedLastTransition[i] <== IsEqual()([i, traversalPathLength]); - component isAcceptState = MultiOR(numAcceptStates); - for (var j = 0; j < numAcceptStates; j++) { - isAcceptState.in[j] <== IsEqual()([acceptStates[j], nextStates[i]]); - } - reachedAcceptState[i] <== isAcceptState.out; - isValidRegexTemp[i] <== AND()(reachedLastTransition[i], reachedAcceptState[i]); if (i == 0) { - isValidRegex[i] <== isValidRegexTemp[i]; - } else { - isValidRegex[i] <== isValidRegexTemp[i] + isValidRegex[i-1]; - } - } - - isValidRegex[maxBytes-1] === 1; -} diff --git a/compiler/output/Regex.circom b/compiler/output/Regex.circom deleted file mode 100644 index 710c1d60..00000000 --- a/compiler/output/Regex.circom +++ /dev/null @@ -1,166 +0,0 @@ -pragma circom 2.1.5; - -template CheckByteTransition() { - signal input currState; - signal input nextState; - signal input byte; - signal input captureGroupId; - signal input captureGroupStart; - - signal input inCurrState; - signal input inNextState; - signal input inByte; - signal input inCaptureGroupId; - signal input inCaptureGroupStart; - - signal output out; - - signal isCurrentState <== IsEqual()([currState, inCurrState]); - signal isNextState <== IsEqual()([nextState, inNextState]); - signal isByteEqual <== IsEqual()([byte, inByte]); - signal isCaptureGroupEqual <== IsEqual()([captureGroupId, inCaptureGroupId]); - signal isCaptureGroupStartEqual <== IsEqual()([captureGroupStart, inCaptureGroupStart]); - - out <== MultiAND(5)([isCurrentState, isNextState, isByteEqual, isCaptureGroupEqual, isCaptureGroupStartEqual]); -} - -template CheckByteRangeTransition() { - signal input currState; - signal input nextState; - signal input byteStart; - signal input byteEnd; - signal input 
captureGroupId; - signal input captureGroupStart; - - signal input inCurrState; - signal input inNextState; - signal input inByte; - signal input inCaptureGroupId; - signal input inCaptureGroupStart; - - signal output out; - - signal isCurrentState <== IsEqual()([currState, inCurrState]); - signal isNextState <== IsEqual()([nextState, inNextState]); - signal isCaptureGroupEqual <== IsEqual()([captureGroupId, inCaptureGroupId]); - signal isCaptureGroupStartEqual <== IsEqual()([captureGroupStart, inCaptureGroupStart]); - - signal isByteValid[2]; - isByteValid[0] <== GreaterEqThan(8)([inByte, byteStart]); - isByteValid[1] <== LessEqThan(8)([inByte, byteEnd]); - - out <== MultiAND(5)([isCurrentState, isNextState, isByteValid[0], isByteValid[1], isCaptureGroupEqual, isCaptureGroupStartEqual]); -} - -template Regex(maxBytes) { - signal input currStates[maxBytes]; - signal input haystack[maxBytes]; - signal input nextStates[maxBytes]; - signal input captureGroupIds[maxBytes]; - signal input captureGroupStarts[maxBytes]; - signal input traversalPathLength; - - var numStartStates = 6; - var numAcceptStates = 1; - var numTransitions = 10; - var startStates[numStartStates] = [4, 3, 2, 0, 5, 1]; - var acceptStates[numAcceptStates] = [6]; - - signal isCurrentState[numTransitions][maxBytes]; - signal isNextState[numTransitions][maxBytes]; - signal isValidTransition[numTransitions][maxBytes]; - signal reachedLastTransition[maxBytes]; - signal reachedAcceptState[maxBytes]; - signal isValidRegex[maxBytes]; - signal isValidRegexTemp[maxBytes]; - signal isWithinPathLength[maxBytes]; - signal isTransitionLinked[maxBytes]; - - component isValidStartState; - component isValidTraversal[maxBytes]; - - // Check if the first state in the haystack is a valid start state - isValidStartState = MultiOR(numStartStates); - for (var i = 0; i < numStartStates; i++) { - isValidStartState.in[i] <== IsEqual()([startStates[i], currStates[0]]); - } - isValidStartState.out === 1; - - // Check if the 
traversal path has valid transitions - for (var i = 0; i < maxBytes; i++) { - isWithinPathLength[i] <== LessThan(log2Ceil(maxBytes))([i, traversalPathLength]); - - // Check if the traversal is a valid path - if (i !== maxBytes - 1) { - isTransitionLinked[i] <== IsEqual()([nextStates[i], currStates[i+1]]); - isTransitionLinked[i] === isWithinPathLength[i]; - } - - // Transition 0: 0 -[0-255]-> 0 - isValidTransition[0][i] <== CheckByteRangeTransition()(0, 0, 0, 255, currStates[i], nextStates[i], haystack[i]); - - // Transition 1: 0 -[97-97]-> 3 - isValidTransition[1][i] <== CheckByteTransition()(0, 3, 97, 1, 1, currStates[i], nextStates[i], haystack[i]); - - // Transition 2: 0 -[98-98]-> 6 - isValidTransition[2][i] <== CheckByteTransition()(0, 6, 98, currStates[i], nextStates[i], haystack[i]) - - // Transition 3: 1 -[0-255]-> 0 - isValidTransition[3][i] <== CheckByteRangeTransition()(1, 0, 0, 255, currStates[i], nextStates[i], haystack[i]); - - // Transition 4: 2 -[98-98]-> 6 - isValidTransition[4][i] <== CheckByteTransition()(2, 6, 98, currStates[i], nextStates[i], haystack[i]); - - // Transition 5: 2 -[97-97]-> 3 - isValidTransition[5][i] <== CheckByteTransition()(2, 3, 97, currStates[i], nextStates[i], haystack[i]); - - // Transition 6: 3 -[98-98]-> 6 - isValidTransition[6][i] <== CheckByteTransition()(3, 6, 98, currStates[i], nextStates[i], haystack[i]); - - // Transition 7: 3 -[97-97]-> 3 - isValidTransition[7][i] <== CheckByteTransition()(3, 3, 97, currStates[i], nextStates[i], haystack[i]); - - // Transition 8: 4 -[97-97]-> 3 - isValidTransition[8][i] <== CheckByteTransition()(4, 3, 97, currStates[i], nextStates[i], haystack[i]); - - // Transition 9: 5 -[98-98]-> 6 - isValidTransition[9][i] <== CheckByteTransition()(5, 6, 98, currStates[i], nextStates[i], haystack[i]); - - // Combine all valid transitions for this byte - isValidTraversal[i] = MultiOR(numTransitions); - for (var j = 0; j < numTransitions; j++) { - isValidTraversal[i].in[j] <== 
isValidTransition[j][i]; - } - isValidTraversal[i].out === isWithinPathLength[i]; - - // Check if any accept state has been reached at the last transition - reachedLastTransition[i] <== IsEqual()([i, traversalPathLength]); - reachedAcceptState[i] <== IsEqual()([nextStates[i], acceptStates[0]]); - isValidRegexTemp[i] <== AND()(reachedLastTransition[i], reachedAcceptState[i]); - if (i == 0) { - isValidRegex[i] <== isValidRegexTemp[i]; - } else { - isValidRegex[i] <== isValidRegexTemp[i] + isValidRegex[i-1]; - } - } - isValidRegex[maxBytes-1] === 1; - - // Capture group 1 - signal isCaptureGroup1[maxBytes]; - signal isCaptureGroupStart1[maxBytes]; - signal isCaptureGroupEnd1[maxBytes]; - signal isValidCaptureStart1[maxBytes]; - signal isValidCaptureEnd1[maxBytes]; - for (var i = 0; i < maxBytes; i++) { - isCaptureGroup1[i] <== IsEqual()([captureGroupIds[i], 1]); - isCaptureGroupStart1[i] <== IsEqual()([captureGroupStarts[i], 1]); - isCaptureGroupEnd1[i] <== IsEqual()([captureGroupStarts[i], 0]); - if (i == 0) { - isValidCaptureStart1[i] <== AND()(isCaptureGroup1[i], isCaptureGroupStart1[i]); - isValidCaptureEnd1[i] <== AND()(isCaptureGroup1[i], isCaptureGroupEnd1[i]); - } else { - isValidCaptureStart1[i] <== AND()(isCaptureGroup1[i], isCaptureGroupStart1[i]) + isValidCaptureStart1[i-1]; - isValidCaptureEnd1[i] <== AND()(isCaptureGroup1[i], isCaptureGroupEnd1[i]) + isValidCaptureEnd1[i-1]; - } - } -} \ No newline at end of file diff --git a/compiler/src/bin/zk-regex.rs b/compiler/src/bin/zk-regex.rs index 66a148a0..1cc1db49 100644 --- a/compiler/src/bin/zk-regex.rs +++ b/compiler/src/bin/zk-regex.rs @@ -20,11 +20,15 @@ enum Commands { /// Directory path for output files #[arg(short, long)] - circom_file_path: PathBuf, + output_file_path: PathBuf, /// Template name in PascalCase (e.g., TimestampRegex) #[arg(short, long, value_parser = validate_cli_template_name)] template_name: String, + + /// Noir boolean + #[arg(long)] + noir: bool, }, /// Process a raw regex string 
@@ -35,15 +39,19 @@ enum Commands { /// Directory path for output files #[arg(short, long)] - circom_file_path: PathBuf, + output_file_path: PathBuf, /// Template name in PascalCase (e.g., TimestampRegex) #[arg(short, long, value_parser = validate_cli_template_name)] template_name: String, + + /// Noir boolean + #[arg(long)] + noir: bool, }, /// Generate circuit inputs from a cached graph - GenerateCircomInput { + GenerateCircuitInput { /// Path to the graph JSON file #[arg(short, long)] graph_path: PathBuf, @@ -53,7 +61,7 @@ enum Commands { input: String, /// Maximum haystack length - #[arg(short = 'h', long)] + #[arg(short = 'l', long)] max_haystack_len: usize, /// Maximum match length @@ -63,6 +71,10 @@ enum Commands { /// Output JSON file for circuit inputs #[arg(short, long)] output: PathBuf, + + /// Generate inputs for Noir + #[arg(short, long)] + noir: Option }, } @@ -83,6 +95,7 @@ fn save_outputs( circom_code: String, output_dir: &PathBuf, template_name: &str, + file_extension: &str, ) -> Result<(), Box> { validate_cli_template_name(template_name)?; @@ -91,9 +104,9 @@ fn save_outputs( let snake_case_name = template_name.to_snake_case(); - // Save Circom file - let circom_path = output_dir.join(format!("{}_regex.circom", snake_case_name)); - std::fs::write(&circom_path, circom_code)?; + // Save circuit file + let circuit_path = output_dir.join(format!("{}_regex.{}", snake_case_name, file_extension)); + std::fs::write(&circuit_path, circom_code)?; // Save graph JSON let graph_json = nfa.to_json()?; @@ -101,7 +114,7 @@ fn save_outputs( std::fs::write(&graph_path, graph_json)?; println!("Generated files:"); - println!(" Circuit: {}", circom_path.display()); + println!(" Circuit: {}", circuit_path.display()); println!(" Graph: {}", graph_path.display()); Ok(()) @@ -113,8 +126,9 @@ fn main() -> Result<(), Box> { match cli.command { Commands::Decomposed { decomposed_regex_path, - circom_file_path, + output_file_path, template_name, + noir, } => { let config: 
DecomposedRegexConfig = serde_json::from_reader(File::open(decomposed_regex_path)?)?; @@ -122,44 +136,72 @@ fn main() -> Result<(), Box> { let (combined_pattern, max_bytes) = decomposed_to_composed_regex(&config); let nfa = compile(&combined_pattern)?; - - let circom_code = if !max_bytes.is_empty() { - nfa.generate_circom_code(&template_name, &combined_pattern, Some(&max_bytes))? - } else { - nfa.generate_circom_code(&template_name, &combined_pattern, None)? + let max_bytes = match max_bytes.is_empty() { + true => None, + false => Some(&max_bytes[..]), + }; + let code = match noir { + true => nfa.generate_noir_code(&combined_pattern, max_bytes)?, + false => nfa.generate_circom_code(&template_name, &combined_pattern, max_bytes)?, }; - save_outputs(&nfa, circom_code, &circom_file_path, &template_name)?; + let file_extension = if noir { "nr" } else { "circom" }; + save_outputs( + &nfa, + code, + &output_file_path, + &template_name, + &file_extension, + )?; } Commands::Raw { raw_regex, - circom_file_path, + output_file_path, template_name, + noir, } => { let nfa = compile(&raw_regex)?; - let circom_code = nfa.generate_circom_code(&template_name, &raw_regex, None)?; + let code = if noir { + nfa.generate_noir_code(&raw_regex, None)? + } else { + nfa.generate_circom_code(&template_name, &raw_regex, None)? 
+ }; - save_outputs(&nfa, circom_code, &circom_file_path, &template_name)?; + // Create output file path by combining directory and template name + let file_extension = if noir { ".nr" } else { ".circom" }; + save_outputs( + &nfa, + code, + &output_file_path, + &template_name, + &file_extension, + )?; } - Commands::GenerateCircomInput { + Commands::GenerateCircuitInput { graph_path, input, max_haystack_len, max_match_len, output, + noir } => { // Load the cached graph let graph_json = std::fs::read_to_string(graph_path)?; let nfa = NFAGraph::from_json(&graph_json)?; // Generate circuit inputs - let inputs = nfa.generate_circom_inputs(&input, max_haystack_len, max_match_len)?; + let inputs = nfa.generate_circuit_inputs(&input, max_haystack_len, max_match_len)?; // Save inputs - let input_json = serde_json::to_string_pretty(&inputs)?; - std::fs::write(&output, input_json)?; + if noir.is_none_or(|x| !x ) { + let input_json = serde_json::to_string_pretty(&inputs)?; + std::fs::write(&output, input_json)?; + } else { + let input_toml = NFAGraph::to_prover_toml(&inputs); + std::fs::write(&output, input_toml)?; + } println!("Generated circuit inputs: {}", output.display()); } diff --git a/compiler/src/nfa/codegen/circom.rs b/compiler/src/nfa/codegen/circom.rs index bfffeeb7..415cf70e 100644 --- a/compiler/src/nfa/codegen/circom.rs +++ b/compiler/src/nfa/codegen/circom.rs @@ -14,12 +14,11 @@ //! 
- Start/accept state validation use std::collections::{HashMap, HashSet}; - -use regex_automata::meta::Regex; use serde::Serialize; use crate::nfa::NFAGraph; use crate::nfa::error::{NFAError, NFAResult}; +use crate::nfa::codegen::CircuitInputs; #[derive(Serialize)] pub struct CircomInputs { @@ -42,97 +41,27 @@ pub struct CircomInputs { #[serde(skip_serializing_if = "Option::is_none")] #[serde(rename = "captureGroupStartIndices")] capture_group_start_indices: Option>, + #[serde(rename = "traversalPathLength")] + traversal_path_length: usize, } -impl NFAGraph { - /// Generates the core data needed for Circom circuit generation. - /// - /// Returns: - /// - Vector of start states - /// - Vector of accept states - /// - Vector of transitions: (from_state, min_byte, max_byte, to_state, capture_info) - /// - /// The transitions are compressed into byte ranges for efficiency. - pub fn generate_circom_data( - &self, - ) -> NFAResult<( - Vec, - Vec, - Vec<(usize, u8, u8, usize, Option<(usize, bool)>)>, - )> { - if self.start_states.is_empty() { - return Err(NFAError::Verification("NFA has no start states".into())); - } - if self.accept_states.is_empty() { - return Err(NFAError::Verification("NFA has no accept states".into())); - } - - let start_states = self.start_states.iter().copied().collect(); - let accept_states = self.accept_states.iter().copied().collect(); - - let transitions = self.get_transitions_with_capture_info(); - if transitions.is_empty() { - return Err(NFAError::Verification("NFA has no transitions".into())); +impl From for CircomInputs { + fn from(inputs: CircuitInputs) -> Self { + CircomInputs { + in_haystack: inputs.in_haystack, + match_start: inputs.match_start, + match_length: inputs.match_length, + curr_states: inputs.curr_states, + next_states: inputs.next_states, + capture_group_ids: inputs.capture_group_ids, + capture_group_starts: inputs.capture_group_starts, + capture_group_start_indices: inputs.capture_group_start_indices, + 
traversal_path_length: inputs.traversal_path_length, } - - // Group and convert to ranges - let mut range_transitions = Vec::new(); - let mut grouped: HashMap<(usize, usize, Option<(usize, bool)>), Vec> = HashMap::new(); - - for (src, byte, dst, capture) in transitions { - if src >= self.nodes.len() || dst >= self.nodes.len() { - return Err(NFAError::InvalidStateId(format!( - "State {}->{} out of bounds", - src, dst - ))); - } - grouped.entry((src, dst, capture)).or_default().push(byte); - } - - // Convert to ranges - for ((src, dst, capture), mut bytes) in grouped { - if bytes.is_empty() { - continue; - } - - bytes.sort_unstable(); - let mut start = bytes[0]; - let mut prev = start; - - for &byte in &bytes[1..] { - if byte != prev + 1 { - range_transitions.push((src, start, prev, dst, capture)); - start = byte; - } - prev = byte; - } - range_transitions.push((src, start, prev, dst, capture)); - } - - Ok((start_states, accept_states, range_transitions)) - } - - /// Escapes special characters in regex patterns for display in Circom comments. - /// Handles newlines, quotes, control characters etc. - fn escape_regex_for_display(pattern: &str) -> String { - pattern - .chars() - .map(|c| match c { - '\n' => "\\n".to_string(), - '\r' => "\\r".to_string(), - '\t' => "\\t".to_string(), - '\\' => "\\\\".to_string(), - '\0' => "\\0".to_string(), - '\'' => "\\'".to_string(), - '\"' => "\\\"".to_string(), - '\x08' => "\\b".to_string(), - '\x0c' => "\\f".to_string(), - c if c.is_ascii_control() => format!("\\x{:02x}", c as u8), - c => c.to_string(), - }) - .collect() } +} +impl NFAGraph { /// Generates complete Circom circuit code for the NFA. 
/// /// # Arguments @@ -157,7 +86,7 @@ impl NFAGraph { return Err(NFAError::InvalidInput("Empty regex name".into())); } - let (start_states, accept_states, transitions) = self.generate_circom_data()?; + let (start_states, accept_states, transitions) = self.generate_circuit_data()?; // Validate capture groups let capture_group_set: HashSet<_> = transitions @@ -488,88 +417,4 @@ impl NFAGraph { Ok(code) } - - pub fn generate_circom_inputs( - &self, - haystack: &str, - max_haystack_len: usize, - max_match_len: usize, - ) -> NFAResult { - let haystack_bytes = haystack.as_bytes(); - - if haystack_bytes.len() > max_haystack_len { - return Err(NFAError::InvalidInput(format!( - "Haystack length {} exceeds maximum length {}", - haystack_bytes.len(), - max_haystack_len - ))); - } - - // Generate path traversal - let result = self.get_path_to_accept(haystack_bytes)?; - let path = result.path; - let (match_start, match_length) = result.span; - let path_len = path.len(); - - if path_len > max_match_len { - return Err(NFAError::InvalidInput(format!( - "Path length {} exceeds maximum length {}", - path_len, max_match_len - ))); - } - - // Extract and pad arrays to max_haystack_len - let mut curr_states = path.iter().map(|(curr, _, _, _)| *curr).collect::>(); - let mut next_states = path.iter().map(|(_, next, _, _)| *next).collect::>(); - let mut in_haystack = haystack_bytes.to_vec(); - - // Pad with zeros - curr_states.resize(max_match_len, 136279841); - next_states.resize(max_match_len, 136279842); - in_haystack.resize(max_haystack_len, 0); - - // Handle capture groups if they exist - let (capture_group_ids, capture_group_starts, capture_group_start_indices) = - if path.iter().any(|(_, _, _, c)| c.is_some()) { - let mut ids = path - .iter() - .map(|(_, _, _, c)| c.map(|(id, _)| id).unwrap_or(0)) - .collect::>(); - let mut starts = path - .iter() - .map(|(_, _, _, c)| c.map(|(_, start)| start as u8).unwrap_or(0)) - .collect::>(); - - // Use regex_automata to get capture start 
indices - let re = Regex::new(&self.regex).map_err(|e| { - NFAError::RegexCompilation(format!("Failed to compile regex: {}", e)) - })?; - let mut captures = re.create_captures(); - re.captures(&haystack, &mut captures); - - let start_indices = (1..=captures.group_len()) - .filter_map(|i| captures.get_group(i)) - .map(|m| m.start) - .collect(); - - // Pad arrays - ids.resize(max_match_len, 0); - starts.resize(max_match_len, 0); - - (Some(ids), Some(starts), Some(start_indices)) - } else { - (None, None, None) - }; - - Ok(CircomInputs { - in_haystack, - match_start, - match_length, - curr_states, - next_states, - capture_group_ids, - capture_group_starts, - capture_group_start_indices, - }) - } } diff --git a/compiler/src/nfa/codegen/mod.rs b/compiler/src/nfa/codegen/mod.rs index b1edfc86..7f0970fa 100644 --- a/compiler/src/nfa/codegen/mod.rs +++ b/compiler/src/nfa/codegen/mod.rs @@ -1,5 +1,194 @@ //! Code generation module for converting NFAs to various output formats. -mod circom; +pub mod circom; +pub mod noir; -pub use circom::*; +use regex_automata::meta::Regex; +use serde::Serialize; +use std::collections::HashMap; + +use crate::nfa::{ + NFAGraph, + error::{NFAError, NFAResult}, +}; + +#[derive(Serialize)] +pub struct CircuitInputs { + in_haystack: Vec, + match_start: usize, + match_length: usize, + curr_states: Vec, + next_states: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + capture_group_ids: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + capture_group_starts: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + capture_group_start_indices: Option>, + traversal_path_length: usize, +} + +impl NFAGraph { + pub fn generate_circuit_data( + &self, + ) -> NFAResult<( + Vec, + Vec, + Vec<(usize, u8, u8, usize, Option<(usize, bool)>)>, + )> { + if self.start_states.is_empty() { + return Err(NFAError::Verification("NFA has no start states".into())); + } + if self.accept_states.is_empty() { + return 
Err(NFAError::Verification("NFA has no accept states".into())); + } + + let start_states = self.start_states.iter().copied().collect(); + let accept_states = self.accept_states.iter().copied().collect(); + + let transitions = self.get_transitions_with_capture_info(); + if transitions.is_empty() { + return Err(NFAError::Verification("NFA has no transitions".into())); + } + + // Group and convert to ranges + let mut range_transitions = Vec::new(); + let mut grouped: HashMap<(usize, usize, Option<(usize, bool)>), Vec> = HashMap::new(); + + for (src, byte, dst, capture) in transitions { + if src >= self.nodes.len() || dst >= self.nodes.len() { + return Err(NFAError::InvalidStateId(format!( + "State {}->{} out of bounds", + src, dst + ))); + } + grouped.entry((src, dst, capture)).or_default().push(byte); + } + + // Convert to ranges + for ((src, dst, capture), mut bytes) in grouped { + if bytes.is_empty() { + continue; + } + + bytes.sort_unstable(); + let mut start = bytes[0]; + let mut prev = start; + + for &byte in &bytes[1..] 
{ + if byte != prev + 1 { + range_transitions.push((src, start, prev, dst, capture)); + start = byte; + } + prev = byte; + } + range_transitions.push((src, start, prev, dst, capture)); + } + + Ok((start_states, accept_states, range_transitions)) + } + + pub fn generate_circuit_inputs( + &self, + haystack: &str, + max_haystack_len: usize, + max_match_len: usize, + ) -> NFAResult { + let haystack_bytes = haystack.as_bytes(); + + if haystack_bytes.len() > max_haystack_len { + return Err(NFAError::InvalidInput(format!( + "Haystack length {} exceeds maximum length {}", + haystack_bytes.len(), + max_haystack_len + ))); + } + + // Generate path traversal + let result = self.get_path_to_accept(haystack_bytes)?; + let path = result.path; + let (match_start, match_length) = result.span; + let path_len = path.len(); + + if path_len > max_match_len { + return Err(NFAError::InvalidInput(format!( + "Path length {} exceeds maximum length {}", + path_len, max_match_len + ))); + } + + // Extract and pad arrays to max_haystack_len + let mut curr_states = path.iter().map(|(curr, _, _, _)| *curr).collect::>(); + let mut next_states = path.iter().map(|(_, next, _, _)| *next).collect::>(); + let mut in_haystack = haystack_bytes.to_vec(); + + // Pad with zeros + curr_states.resize(max_match_len, 0); + next_states.resize(max_match_len, 0); + in_haystack.resize(max_haystack_len, 0); + + // Handle capture groups if they exist + let (capture_group_ids, capture_group_starts, capture_group_start_indices) = + if path.iter().any(|(_, _, _, c)| c.is_some()) { + let mut ids = path + .iter() + .map(|(_, _, _, c)| c.map(|(id, _)| id).unwrap_or(0)) + .collect::>(); + let mut starts = path + .iter() + .map(|(_, _, _, c)| c.map(|(_, start)| start as u8).unwrap_or(0)) + .collect::>(); + + // Use regex_automata to get capture start indices + let re = Regex::new(&self.regex).map_err(|e| { + NFAError::RegexCompilation(format!("Failed to compile regex: {}", e)) + })?; + let mut captures = 
re.create_captures(); + re.captures(&haystack, &mut captures); + + let start_indices = (1..=captures.group_len()) + .filter_map(|i| captures.get_group(i)) + .map(|m| m.start) + .collect(); + + // Pad arrays + ids.resize(max_match_len, 0); + starts.resize(max_match_len, 0); + + (Some(ids), Some(starts), Some(start_indices)) + } else { + (None, None, None) + }; + + Ok(CircuitInputs { + in_haystack, + match_start, + match_length, + curr_states, + next_states, + capture_group_ids, + capture_group_starts, + capture_group_start_indices, + traversal_path_length: path_len, + }) + } + + pub fn escape_regex_for_display(pattern: &str) -> String { + pattern + .chars() + .map(|c| match c { + '\n' => "\\n".to_string(), + '\r' => "\\r".to_string(), + '\t' => "\\t".to_string(), + '\\' => "\\\\".to_string(), + '\0' => "\\0".to_string(), + '\'' => "\\'".to_string(), + '\"' => "\\\"".to_string(), + '\x08' => "\\b".to_string(), + '\x0c' => "\\f".to_string(), + c if c.is_ascii_control() => format!("\\x{:02x}", c as u8), + c => c.to_string(), + }) + .collect() + } +} diff --git a/compiler/src/nfa/codegen/noir.rs b/compiler/src/nfa/codegen/noir.rs new file mode 100644 index 00000000..66f539f0 --- /dev/null +++ b/compiler/src/nfa/codegen/noir.rs @@ -0,0 +1,466 @@ +use std::collections::{HashMap, HashSet}; + +use crate::nfa::{ + NFAGraph, + codegen::CircuitInputs, + error::{NFAError, NFAResult}, +}; +use comptime::{FieldElement, SparseArray}; +use serde::Serialize; + +#[derive(Serialize)] +pub struct NoirInputs { + in_haystack: Vec, + match_start: usize, + match_length: usize, + curr_states: Vec, + next_states: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + capture_group_ids: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + capture_group_starts: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + capture_group_start_indices: Option>, +} + +impl From for NoirInputs { + fn from(inputs: CircuitInputs) -> Self { + NoirInputs { + in_haystack: 
inputs.in_haystack, + match_start: inputs.match_start, + match_length: inputs.match_length, + curr_states: inputs.curr_states, + next_states: inputs.next_states, + capture_group_ids: inputs.capture_group_ids, + capture_group_starts: inputs.capture_group_starts, + capture_group_start_indices: inputs.capture_group_start_indices, + } + } +} + +impl NFAGraph { + /// Generate Noir code for the NFA + pub fn generate_noir_code( + &self, + regex_pattern: &str, + max_substring_bytes: Option<&[usize]>, + ) -> NFAResult { + // get nfa graph data + let (start_states, accept_states, transitions) = self.generate_circuit_data()?; + + let capture_group_set: HashSet<_> = transitions + .iter() + .filter_map(|(_, _, _, _, cap)| cap.map(|(id, _)| id)) + .collect(); + if !capture_group_set.is_empty() { + if let Some(max_bytes) = max_substring_bytes { + if max_bytes.len() < capture_group_set.len() { + return Err(NFAError::InvalidCapture(format!( + "Insufficient max_substring_bytes: need {} but got {}", + capture_group_set.len(), + max_bytes.len() + ))); + } + for &bytes in max_bytes { + if bytes == 0 { + return Err(NFAError::InvalidCapture( + "max_substring_bytes contains zero length".into(), + )); + } + } + } else { + return Err(NFAError::InvalidCapture( + "max_substring_bytes required for capture groups".into(), + )); + } + } + + let has_capture_groups = !capture_group_set.is_empty(); + + let mut code = String::new(); + + // imports + // todo: ability to change import path + if has_capture_groups { + code.push_str("use zkregex::utils::{\n"); + code.push_str(" select_subarray,\n"); + code.push_str(" captures::capture_substring,\n"); + code.push_str(" sparse_array::SparseArray,\n"); + code.push_str(" transitions::check_transition_with_captures\n"); + code.push_str("};\n\n"); + } else { + code.push_str("use zkregex::utils::{\n"); + code.push_str(" select_subarray,\n"); + code.push_str(" sparse_array::SparseArray,\n"); + code.push_str(" transitions::check_transition\n"); + 
code.push_str("};\n\n"); + } + + // codegen the transition lookup table + let transition_array = match max_substring_bytes.is_some() { + true => packed_transition_sparse_array(&transitions), + false => transition_sparse_array(&transitions), + }; + code.push_str(&format!( + "global TRANSITION_TABLE: {}\n\n", + transition_array.to_noir_string(None) + )); + + // hardcode max substring capture group lengths + if has_capture_groups { + for (index, length) in max_substring_bytes.unwrap().iter().enumerate() { + code.push_str(&format!( + "pub global CAPTURE_{}_MAX_LENGTH: u32 = {};\n", + index + 1, + length + )); + } + code.push_str(&format!( + "pub global NUM_CAPTURE_GROUPS: u32 = {};\n", + capture_group_set.len() + )); + } + + // add check for valid start states + code.push_str(start_state_fn(&start_states).as_str()); + code.push_str(accept_state_fn(&accept_states).as_str()); + + // regex match function doc + code.push_str(&format!("/**\n")); + code.push_str(&format!(" * Regex matching function\n")); + code.push_str(&format!(" * @param in_haystack - The input haystack to search from\n")); + code.push_str(&format!(" * @param match_start - The start index in the haystack for the subarray to match from\n")); + code.push_str(&format!(" * @param match_length - The length of the subarray to extract from haystack\n")); + code.push_str(&format!(" * @param current_states - The current states of the NFA at each index in the match subarray\n")); + code.push_str(&format!(" * @param next_states - The next states of the NFA at each index in the match subarray\n")); + if capture_group_set.len() > 0 { + code.push_str(&format!(" * @param capture_group_ids - The ids of the capture groups in the match subarray\n")); + code.push_str(&format!(" * @param capture_group_starts - The start positions of the capture groups in the match subarray\n")); + code.push_str(&format!(" * @param capture_group_start_indices - The start indices of the capture groups in the match subarray\n")); + 
code.push_str(&format!(" * @return - tuple of substring captures as dictated by the regular expression\n")); + } + code.push_str(&format!(" */\n")); + + // regex match function signature + code.push_str(&format!( + "pub fn regex_match(\n" + )); + code.push_str(&format!(" in_haystack: [u8; MAX_HAYSTACK_LEN],\n")); + code.push_str(&format!(" match_start: u32,\n")); + code.push_str(&format!(" match_length: u32,\n")); + code.push_str(&format!(" current_states: [Field; MAX_MATCH_LEN],\n")); + code.push_str(&format!(" next_states: [Field; MAX_MATCH_LEN],\n")); + if (max_substring_bytes.is_some()) { + code.push_str(&format!(" capture_group_ids: [Field; MAX_MATCH_LEN],\n")); + code.push_str(&format!(" capture_group_starts: [Field; MAX_MATCH_LEN],\n")); + code.push_str(&format!( + " capture_group_start_indices: [Field; NUM_CAPTURE_GROUPS],\n" + )); + } + + // define the return type according to existence of / qualities of capture groups + let return_type = if has_capture_groups { + let mut substrings = Vec::new(); + for i in 0..max_substring_bytes.unwrap().len() { + substrings.push(format!("BoundedVec", i + 1)); + } + format!("-> ({}) ", substrings.join(", ")) + } else { + String::default() + }; + code.push_str(&format!(") {}{{\n", return_type)); + + // print the actual regex match being performed + code.push_str(&format!(" // regex:{:?}\n", regex_pattern)); + + // resize haystack to MAX_MATCH_LEN + code.push_str(&format!(" // resize haystack \n")); + code.push_str(&format!(" let haystack: [u8; MAX_MATCH_LEN] = select_subarray(in_haystack, match_start, match_length);\n\n")); + code.push_str(&format!(" let mut reached_end_state = 1;\n")); + + // check start & range + code.push_str(&format!(" check_start_state(current_states[0]);\n")); + code.push_str(&format!(" for i in 0..MAX_MATCH_LEN-1 {{\n")); + code.push_str(&format!(" // match length - 1 since current states should be 1 less than next states\n")); + code.push_str(&format!( + " let in_range = (i < match_length - 1) as 
Field;\n" + )); + code.push_str(&format!( + " let matching_states = current_states[i + 1] - next_states[i];\n" + )); + code.push_str(&format!( + " assert(in_range * matching_states == 0, \"Invalid Transition Input\");\n" + )); + code.push_str(&format!(" }}\n")); + + // iterate through the haystack and check transitions + code.push_str(&format!(" for i in 0..MAX_MATCH_LEN {{\n")); + if max_substring_bytes.is_some() { + // if capture groups exist, perform check that unpacks transition values + code.push_str(&format!(" check_transition_with_captures(\n")); + code.push_str(&format!(" TRANSITION_TABLE,\n")); + code.push_str(&format!(" haystack[i] as Field,\n")); + code.push_str(&format!(" current_states[i],\n")); + code.push_str(&format!(" next_states[i],\n")); + code.push_str(&format!(" capture_group_ids[i],\n")); + code.push_str(&format!(" capture_group_starts[i],\n")); + code.push_str(&format!(" reached_end_state\n")); + code.push_str(&format!(" );\n")); + } else { + // if no capture groups exist, simple lookup + code.push_str(&format!(" check_transition(\n")); + code.push_str(&format!(" TRANSITION_TABLE,\n")); + code.push_str(&format!(" haystack[i] as Field,\n")); + code.push_str(&format!(" current_states[i],\n")); + code.push_str(&format!(" next_states[i],\n")); + code.push_str(&format!(" reached_end_state\n")); + code.push_str(&format!(" );\n")); + } + // toggle off constraints/ set match assertion if end state found + code.push_str(&format!( + " reached_end_state = reached_end_state * check_accept_state(\n" + )); + code.push_str(&format!(" next_states[i],\n")); + code.push_str(&format!(" i as Field,\n")); + code.push_str(&format!(" match_length as Field,\n")); + code.push_str(&format!(" );\n")); + code.push_str(&format!(" }}\n")); + code.push_str(&format!( + " assert(reached_end_state == 0, \"Did not reach a valid end state\");\n" + )); + // add substring capture logic if capture groups exist + if has_capture_groups { + let mut ids = Vec::new(); + for 
capture_group_id in capture_group_set { + let max_substring_bytes = if let Some(max_substring_bytes) = max_substring_bytes { + max_substring_bytes[capture_group_id - 1] + } else { + return Err(NFAError::InvalidCapture(format!( + "Max substring bytes not provided for capture group {}", + capture_group_id + ))); + }; + + code.push_str(&format!(" // Capture Group {}\n", capture_group_id)); + code.push_str(&format!(" let capture_{} = capture_substring::(\n", capture_group_id, capture_group_id, capture_group_id)); + code.push_str(&format!(" haystack,\n")); + code.push_str(&format!(" capture_group_ids,\n")); + code.push_str(&format!(" capture_group_starts,\n")); + code.push_str(&format!( + " capture_group_start_indices[{}],\n", + capture_group_id - 1 + )); + code.push_str(&format!(" );\n")); + ids.push(format!("capture_{}", capture_group_id)); + } + + // define the return tuple + let return_vec = ids + .iter() + .map(|id| format!("{}", id)) + .collect::>() + .join(", "); + code.push_str(&format!(" ({})\n", return_vec)); + } + code.push_str(&format!("}}\n\n")); + Ok(code) + } + + /// Generate Prover.toml from circuit inputs + pub fn to_prover_toml(inputs: &CircuitInputs) -> String { + let mut toml = String::new(); + + // regex match inputs + let haystack = inputs + .in_haystack + .iter() + .map(|num| format!("\"{num}\"")) + .collect::>() + .join(", "); + toml.push_str(&format!("in_haystack = [{}]\n", haystack)); + toml.push_str(&format!("match_start = \"{}\"\n", inputs.match_start)); + toml.push_str(&format!("match_length = \"{}\"\n", inputs.match_length)); + let curr_states = inputs + .curr_states + .iter() + .map(|num| format!("\"{num}\"")) + .collect::>() + .join(", "); + toml.push_str(&format!("curr_states = [{}]\n", curr_states)); + let next_states = inputs + .next_states + .iter() + .map(|num| format!("\"{num}\"")) + .collect::>() + .join(", "); + toml.push_str(&format!("next_states = [{}]\n", next_states)); + // substring capture inputs + if 
inputs.capture_group_ids.is_some() { + let capture_group_ids = inputs + .capture_group_ids + .as_ref() + .unwrap() + .iter() + .map(|num| format!("\"{num}\"")) + .collect::>() + .join(", "); + toml.push_str(&format!("capture_group_ids = [{}]\n", capture_group_ids)); + let capture_group_starts = inputs + .capture_group_starts + .as_ref() + .unwrap() + .iter() + .map(|num| format!("\"{num}\"")) + .collect::>() + .join(", "); + toml.push_str(&format!( + "capture_group_starts = [{}]\n", + capture_group_starts + )); + let capture_group_start_indices = inputs + .capture_group_start_indices + .as_ref() + .unwrap() + .iter() + .map(|num| format!("\"{num}\"")) + .collect::>() + .join(", "); + toml.push_str(&format!( + "capture_group_start_indices = [{}]\n", + capture_group_start_indices + )); + }; + toml + } +} + +/** + * Forms an expression to determine if any of the start states are matched + * @param start_states - The start states of the NFA + * @returns The expression determining if any of the start states are matched + */ +fn start_state_fn(start_states: &Vec) -> String { + let expression = start_states + .iter() + .map(|state| format!("(start_state - {state})")) + .collect::>() + .join(" * "); + format!( + r#" +/** + * Constrains a start state to be valid + * @dev start states are hardcoded in this function - "(start_state - {{state}})" for each start + * example: `(start_state - 0) * (start_state - 1) * (start_state - 2)` means 0, 1, or 2 + * are valid first states + * + * @param start_state - The start state of the NFA + */ +fn check_start_state(start_state: Field) {{ + let valid_start_state = {expression}; + assert(valid_start_state == 0, "Invalid start state"); +}} + "# + ) +} + +/** + * Forms an expression to determine if any of the accept states are matched + * @param accept_states - The accept states of the NFA + * @returns The expression determining if any of the accept states are matched + */ +fn accept_state_fn(accept_states: &Vec) -> String { + let 
expression = accept_states + .iter() + .map(|state| format!("(next_state - {state})")) + .collect::>() + .join(" * "); + format!( + r#" +/** + * Constrains the recognition of accept_state being reached. If an accept state is reached, + * ensures asserted traversal path is valid + * @dev accept states are hardcoded in this function - "(next_state - {{state}})" for each accept + * example: `(next_state - 19) * (next_state - 20) * (next_state - 21)` means 19, 20, or 21 + * are valid accept states + * + * @param next_state - The asserted next state of the NFA + * @param haystack_index - The index being operated on in the haystack + * @param asserted_match_length - The asserted traversal path length + * @return - 0 if accept_state is reached, nonzero otherwise + */ +fn check_accept_state( + next_state: Field, + haystack_index: Field, + asserted_match_length: Field +) -> Field {{ + // check if the next state is an accept state + let accept_state_reached = {expression}; + let accept_state_reached_bool = (accept_state_reached == 0) as Field; + + // check if the haystack index is the asserted match length + // should equal 1 since haystack_index should be 1 less than asserted_match_length + let asserted_path_traversed = (asserted_match_length - haystack_index == 1) as Field; + + // if accept state reached, check asserted path traversed. 
Else return 1 + let valid_condition = + (1 - accept_state_reached_bool) + (accept_state_reached_bool * asserted_path_traversed); + assert(valid_condition == 1, "Accept state reached but not at asserted path end"); + + // return accept_state reached value + accept_state_reached +}} + +"# + ) +} + +/** + * Creates a sparse array for transitions + * @param transitions - The transitions to create the sparse array for + * @returns The sparse array for the transitions + */ +fn transition_sparse_array( + transitions: &Vec<(usize, u8, u8, usize, Option<(usize, bool)>)>, +) -> SparseArray { + // let r = 256 * transitions.len(); + let r = 257; + let mut entries = Vec::new(); + for (state_idx, start, end, dest, _) in transitions { + let bytes = (*start..=*end).collect::>(); + for byte in bytes { + let key = state_idx + (byte as usize * r) + (r * r * dest); + entries.push(FieldElement::from(key)); + } + } + let values = vec![FieldElement::from(1u32); entries.len()]; + // assume max byte = 256 and max transitions = 200 + let max_size = FieldElement::from(transitions.len() + 256 * r + 200 * r * r); + SparseArray::create(&entries, &values, max_size) +} + +/** + * Creates a packed sparse array for transitions + * byte 0: 1 if transition is valid, 0 if not + * byte 1: if the transition is the start of the capture group 1, 0 otherwise + * byte 2: if the transition is part of a capture group, the id of the capture group + */ +fn packed_transition_sparse_array( + transitions: &Vec<(usize, u8, u8, usize, Option<(usize, bool)>)>, +) -> SparseArray { + let r = 257; + let mut keys = Vec::new(); + let mut values = Vec::new(); + for (state_idx, start, end, dest, capture) in transitions { + let bytes = (*start..=*end).collect::>(); + let (capture_id, capture_bool) = capture.unwrap_or((0, false)); + for byte in bytes { + let key = state_idx + (byte as usize * r) + (r * r * dest); + let value = 1u32 | (capture_bool as u32) << 1 | (capture_id as u32) << 2; + keys.push(FieldElement::from(key)); 
+ values.push(FieldElement::from(value)); + } + } + // assume max byte = 256 and max transitions = 200 + let max_size = FieldElement::from(transitions.len() + 256 * r + 200 * r * r); + SparseArray::create(&keys, &values, max_size) +} diff --git a/compiler/src/nfa/mod.rs b/compiler/src/nfa/mod.rs index 6983c10d..3b35c059 100644 --- a/compiler/src/nfa/mod.rs +++ b/compiler/src/nfa/mod.rs @@ -53,4 +53,4 @@ impl NFAGraph { accept_states: BTreeSet::new(), } } -} +} \ No newline at end of file diff --git a/compiler/src/nfa/wasm.rs b/compiler/src/nfa/wasm.rs index 79959642..b0b69e24 100644 --- a/compiler/src/nfa/wasm.rs +++ b/compiler/src/nfa/wasm.rs @@ -1,9 +1,13 @@ -use crate::{DecomposedRegexConfig, compile, nfa::NFAGraph, utils::decomposed_to_composed_regex}; +use crate::{ + DecomposedRegexConfig, + compile, + nfa::{NFAGraph, codegen::{CircuitInputs, circom::CircomInputs, noir::NoirInputs}}, + utils::decomposed_to_composed_regex +}; use serde::{Deserialize, Serialize}; use thiserror::Error; use wasm_bindgen::prelude::*; -use super::codegen::CircomInputs; /// Supported proving systems #[wasm_bindgen] @@ -11,9 +15,9 @@ use super::codegen::CircomInputs; #[serde(rename_all = "camelCase")] pub enum ProvingSystem { Circom, + Noir, // Future systems: // Halo2, - // Noir, } /// Input types for different proving systems @@ -22,7 +26,8 @@ pub enum ProvingSystem { pub enum ProvingSystemInputs { #[serde(rename = "circom")] Circom(CircomInputs), - // #[serde(rename = "noir")] Noir(NoirInputs), + #[serde(rename = "noir")] + Noir(NoirInputs), } /// Output from regex compilation @@ -157,6 +162,9 @@ fn generate_from_raw_internal( ProvingSystem::Circom => nfa .generate_circom_code(&template_name.0, &raw_regex.0, Some(max_substring_bytes)) .map_err(|e| WasmError::CodeGenError("circom".to_string(), e.to_string()))?, + ProvingSystem::Noir => nfa + .generate_noir_code(&raw_regex.0, Some(max_substring_bytes)) + .map_err(|e| WasmError::CodeGenError("noir".to_string(), e.to_string()))?, }; 
Ok(RegexOutput { graph, code }) @@ -198,10 +206,16 @@ fn generate_circuit_inputs_internal( let inputs = match proving_system { ProvingSystem::Circom => { - let circom_inputs = graph - .generate_circom_inputs(&haystack.0, max_haystack_length, max_match_length) + let inputs = graph + .generate_circuit_inputs(&haystack.0, max_haystack_length, max_match_length) .map_err(|e| WasmError::InputGenError(e.to_string()))?; - ProvingSystemInputs::Circom(circom_inputs) + ProvingSystemInputs::Circom(CircomInputs::from(inputs)) + } + ProvingSystem::Noir => { + let inputs = NoirInputs::from(graph + .generate_circuit_inputs(&haystack.0, max_haystack_length, max_match_length) + .map_err(|e| WasmError::InputGenError(e.to_string()))?); + ProvingSystemInputs::Noir(NoirInputs::from(inputs)) } }; diff --git a/noir/Nargo.toml b/noir/Nargo.toml new file mode 100644 index 00000000..c08e3ca1 --- /dev/null +++ b/noir/Nargo.toml @@ -0,0 +1,8 @@ +[package] +name = "zkregex.nr" +type = "lib" +authors = [""] +compiler_version = ">=1.0.0" + +[dependencies] +sort = { tag = "v0.2.3", git = "https://github.com/noir-lang/noir_sort" } diff --git a/noir/scripts/gen_inputs.sh b/noir/scripts/gen_inputs.sh new file mode 100755 index 00000000..d1034409 --- /dev/null +++ b/noir/scripts/gen_inputs.sh @@ -0,0 +1,122 @@ +#!/bin/bash + +SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" +cd "$SCRIPT_DIR/../../" + +# Input args for the script +set_inputs_to_use() { + local template="$1" + + case "$template" in + "timestamp") + HAYSTACK=$(echo -e "dkim-signature:v=1; a=rsa-sha256; c=relaxed/relaxed; d=gmail.com; s=20230601; t=1694989812; x=1695594612; dara=google.com; h=to:subject:message-id:date:from:mime-version:from:to:cc:subject :date:message-id:reply-to; bh=BWETwQ9JDReS4GyR2v2TTR8Bpzj9ayumsWQJ3q7vehs=;\x00") + MAX_HAYSTACK_LEN=300 + MAX_MATCH_LEN=100 + ;; + "simple") + HAYSTACK=aaaaaaab + MAX_HAYSTACK_LEN=10 + MAX_MATCH_LEN=10 + ;; + "subject_all") + HAYSTACK=$(echo -e "\r\nsubject: this si a 
buject\r\n ") + MAX_HAYSTACK_LEN=50 + MAX_MATCH_LEN=50 + ;; + *) + echo "Error: Invalid template '$template' supplied" >&2 + folders=$(find ./noir/templates -mindepth 1 -maxdepth 1 -type d -exec basename {} \; | awk '{print "\""$0"\""}' | paste -sd ", " -) + echo "Valid temlates are: $folders" >&2 + return 0 + ;; + esac +} + +gen_prover_toml() { + local template="$1" + + cargo run \ + --bin zk-regex generate-circuit-input \ + --graph-path ./noir/templates/$template/${template}_graph.json \ + --input "$HAYSTACK" \ + --max-haystack-len "$MAX_HAYSTACK_LEN" \ + --max-match-len "$MAX_MATCH_LEN" \ + --output ./noir/templates/$template/Prover.toml \ + --noir true +} + +transform_inputs() { + local template_path="$1" + local prover_toml_file="$1/Prover.toml" + local transformed_inputs_file="$1/inputs.txt" + + # If the transformed_inputs_file exists, delete it + if [ -f "$transformed_inputs_file" ]; then + rm "$transformed_inputs_file" + fi + + while IFS= read -r line || [ -n "$line" ]; do + + # Skip empty lines + if [ -z "$line" ] || [[ "$line" =~ ^[[:space:]]*$ ]]; then + continue + fi + + # Transform the line + # 1. Add "let " prefix + # 2. Remove double quotes + # 3. Add semicolon at the end + transformed_line="let ${line//\"}" + + # Skip if the line doesn't contain an equals sign + if [[ ! "$transformed_line" =~ "=" ]]; then + continue + fi + + echo "$transformed_line;" >> "$transformed_inputs_file" + done < "$prover_toml_file" +} + +## Ensure an argument is provided +if [ $# -ne 1 ]; then + echo "ERROR: Supply the template name you'd like to generate sample inputs for" >&2 + folders=$(find ./noir/templates -mindepth 1 -maxdepth 1 -type d -exec basename {} \; | awk '{print "\""$0"\""}' | paste -sd ", " -) + echo "Available templates: $folders" >&2 + echo "" >&2 + echo "Usage: $0