diff --git a/crates/vm/src/arch/segment.rs b/crates/vm/src/arch/segment.rs index 634632ce2b..642b50e756 100644 --- a/crates/vm/src/arch/segment.rs +++ b/crates/vm/src/arch/segment.rs @@ -282,16 +282,18 @@ impl> ExecutionSegment { Some(SysPhantom::CtStart) => { #[cfg(feature = "bench-metrics")] - metrics - .cycle_tracker - .start(dsl_instr.cloned().unwrap_or("Default".to_string())) + metrics.cycle_tracker.start( + dsl_instr.cloned().unwrap_or("Default".to_string()), + metrics.cycle_count, + ) } Some(SysPhantom::CtEnd) => { #[cfg(feature = "bench-metrics")] - metrics - .cycle_tracker - .end(dsl_instr.cloned().unwrap_or("Default".to_string())) + metrics.cycle_tracker.end( + dsl_instr.cloned().unwrap_or("Default".to_string()), + metrics.cycle_count, + ) } _ => {} } diff --git a/crates/vm/src/metrics/cycle_tracker/mod.rs b/crates/vm/src/metrics/cycle_tracker/mod.rs index b1ef065451..c9fe306ca3 100644 --- a/crates/vm/src/metrics/cycle_tracker/mod.rs +++ b/crates/vm/src/metrics/cycle_tracker/mod.rs @@ -1,7 +1,18 @@ +/// Stats for a nested span in the execution segment that is tracked by the [`CycleTracker`]. +#[derive(Clone, Debug, Default)] +pub struct SpanInfo { + /// The name of the span. + pub tag: String, + /// The cycle count at which the span starts. + pub start: usize, +} + #[derive(Clone, Debug, Default)] pub struct CycleTracker { /// Stack of span names, with most recent at the end - stack: Vec, + stack: Vec, + /// Depth of the stack. + depth: usize, } impl CycleTracker { @@ -10,29 +21,41 @@ impl CycleTracker { } pub fn top(&self) -> Option<&String> { - self.stack.last() + match self.stack.last() { + Some(span) => Some(&span.tag), + _ => None + } } /// Starts a new cycle tracker span for the given name. - /// If a span already exists for the given name, it ends the existing span and pushes a new one - /// to the vec. - pub fn start(&mut self, mut name: String) { + /// If a span already exists for the given name, it ends the existing span and pushes a new one to the vec. + pub fn start(&mut self, mut name: String, cycles_count: usize) { // hack to remove "CT-" prefix if name.starts_with("CT-") { name = name.split_off(3); } - self.stack.push(name); + self.stack.push(SpanInfo { + tag: name.clone(), + start: cycles_count, + }); + let padding = "│ ".repeat(self.depth); + tracing::info!("{}┌╴{}", padding, name); + self.depth += 1; } /// Ends the cycle tracker span for the given name. /// If no span exists for the given name, it panics. - pub fn end(&mut self, mut name: String) { + pub fn end(&mut self, mut name: String, cycles_count: usize) { // hack to remove "CT-" prefix if name.starts_with("CT-") { name = name.split_off(3); } - let stack_top = self.stack.pop(); - assert_eq!(stack_top.unwrap(), name, "Stack top does not match name"); + let SpanInfo { tag, start } = self.stack.pop().unwrap(); + assert_eq!(tag, name, "Stack top does not match name"); + self.depth -= 1; + let padding = "│ ".repeat(self.depth); + let span_cycles = cycles_count - start; + tracing::info!("{}└╴{} cycles", padding, span_cycles); } /// Ends the current cycle tracker span. @@ -42,7 +65,11 @@ impl CycleTracker { /// Get full name of span with all parent names separated by ";" in flamegraph format pub fn get_full_name(&self) -> String { - self.stack.join(";") + self.stack + .iter() + .map(|span_info| span_info.tag.clone()) + .collect::>() + .join(";") } } diff --git a/crates/vm/src/metrics/mod.rs b/crates/vm/src/metrics/mod.rs index 916e8251ac..fe73dce25d 100644 --- a/crates/vm/src/metrics/mod.rs +++ b/crates/vm/src/metrics/mod.rs @@ -117,7 +117,7 @@ impl VmMetrics { .map(|(_, func)| (*func).clone()) .unwrap(); if pc == self.current_fn.start { - self.cycle_tracker.start(self.current_fn.name.clone()); + self.cycle_tracker.start(self.current_fn.name.clone(), 0); } else { while let Some(name) = self.cycle_tracker.top() { if name == &self.current_fn.name { diff --git a/crates/vm/src/system/memory/controller/mod.rs b/crates/vm/src/system/memory/controller/mod.rs index 680a03ab8e..32a81ffd2b 100644 --- a/crates/vm/src/system/memory/controller/mod.rs +++ b/crates/vm/src/system/memory/controller/mod.rs @@ -59,7 +59,7 @@ pub const MERKLE_AIR_OFFSET: usize = 1; pub const BOUNDARY_AIR_OFFSET: usize = 0; #[repr(C)] -#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)] pub struct RecordId(pub usize); pub type MemoryImage = AddressMap; diff --git a/extensions/native/circuit/Cargo.toml b/extensions/native/circuit/Cargo.toml index 5d5913b4be..e48af8f70a 100644 --- a/extensions/native/circuit/Cargo.toml +++ b/extensions/native/circuit/Cargo.toml @@ -19,7 +19,6 @@ openvm-instructions = { workspace = true } openvm-rv32im-circuit = { workspace = true } openvm-native-compiler = { workspace = true } - strum.workspace = true itertools.workspace = true tracing.workspace = true diff --git a/extensions/native/circuit/src/extension.rs b/extensions/native/circuit/src/extension.rs index 385c9392ac..1ee1af6885 100644 --- a/extensions/native/circuit/src/extension.rs +++ b/extensions/native/circuit/src/extension.rs @@ -1,4 +1,4 @@ -use air::VerifyBatchBus; +use poseidon2::air::VerifyBatchBus; use alu_native_adapter::AluNativeAdapterChip; use branch_native_adapter::BranchNativeAdapterChip; use derive_more::derive::From; @@ -30,7 +30,7 @@ use strum::IntoEnumIterator; use crate::{ adapters::{convert_adapter::ConvertAdapterChip, *}, - chip::NativePoseidon2Chip, + poseidon2::chip::NativePoseidon2Chip, phantom::*, *, }; @@ -203,6 +203,7 @@ impl VmExtension for Native { VerifyBatchOpcode::VERIFY_BATCH.global_opcode(), Poseidon2Opcode::PERM_POS2.global_opcode(), Poseidon2Opcode::COMP_POS2.global_opcode(), + Poseidon2Opcode::MULTI_OBSERVE.global_opcode(), ], )?; diff --git a/extensions/native/circuit/src/poseidon2/air.rs b/extensions/native/circuit/src/poseidon2/air.rs index 5ed28abd60..e1fca5c827 100644 --- a/extensions/native/circuit/src/poseidon2/air.rs +++ b/extensions/native/circuit/src/poseidon2/air.rs @@ -2,12 +2,12 @@ use std::{array::from_fn, borrow::Borrow, sync::Arc}; use openvm_circuit::{ arch::{ExecutionBridge, ExecutionState}, - system::memory::{offline_checker::MemoryBridge, MemoryAddress}, + system::memory::{offline_checker::MemoryBridge, MemoryAddress, CHUNK}, }; use openvm_circuit_primitives::utils::not; use openvm_instructions::LocalOpcode; use openvm_native_compiler::{ - Poseidon2Opcode::{COMP_POS2, PERM_POS2}, + Poseidon2Opcode::{COMP_POS2, PERM_POS2, MULTI_OBSERVE}, VerifyBatchOpcode::VERIFY_BATCH, }; use openvm_poseidon2_air::{Poseidon2SubAir, BABY_BEAR_POSEIDON2_HALF_FULL_ROUNDS}; @@ -20,16 +20,12 @@ use openvm_stark_backend::{ rap::{BaseAirWithPublicValues, PartitionedBaseAir}, }; -use crate::{ - chip::{NUM_INITIAL_READS, NUM_SIMPLE_ACCESSES}, - poseidon2::{ +use crate::poseidon2::{ + chip::{NUM_INITIAL_READS, NUM_SIMPLE_ACCESSES}, columns::{ - InsideRowSpecificCols, NativePoseidon2Cols, SimplePoseidonSpecificCols, - TopLevelSpecificCols, + InsideRowSpecificCols, MultiObserveCols, NativePoseidon2Cols, SimplePoseidonSpecificCols, TopLevelSpecificCols }, - CHUNK, - }, -}; + }; #[derive(Clone, Debug)] pub struct NativePoseidon2Air { @@ -72,6 +68,7 @@ impl Air incorporate_sibling, inside_row, simple, + multi_observe_row, end_inside_row, end_top_level, start_top_level, @@ -99,7 +96,8 @@ impl Air builder.assert_bool(incorporate_sibling); builder.assert_bool(inside_row); builder.assert_bool(simple); - let enabled = incorporate_row + incorporate_sibling + inside_row + simple; + builder.assert_bool(multi_observe_row); + let enabled = incorporate_row + incorporate_sibling + inside_row + simple + multi_observe_row; builder.assert_bool(enabled.clone()); builder.assert_bool(end_inside_row); builder.when(end_inside_row).assert_one(inside_row); @@ -680,6 +678,285 @@ impl Air &write_data_2, ) .eval(builder, simple * is_permute); + + //// multi_observe contraints + let multi_observe_specific: &MultiObserveCols = + specific[..MultiObserveCols::::width()].borrow(); + let next_multi_observe_specific: &MultiObserveCols = + next.specific[..MultiObserveCols::::width()].borrow(); + let &MultiObserveCols { + pc, + final_timestamp_increment, + state_ptr, + input_ptr, + init_pos, + len, + is_first, + is_last, + curr_len, + start_idx, + end_idx, + aux_after_start, + aux_before_end, + read_data, + write_data, + data, + should_permute, + read_sponge_state, + write_sponge_state, + write_final_idx, + final_idx, + input_register_1, + input_register_2, + input_register_3, + output_register + } = multi_observe_specific; + + builder + .when(multi_observe_row) + .assert_bool(is_first); + builder + .when(multi_observe_row) + .assert_bool(is_last); + builder + .when(multi_observe_row) + .assert_bool(should_permute); + + self.execution_bridge + .execute_and_increment_pc( + AB::F::from_canonical_usize(MULTI_OBSERVE.global_opcode().as_usize()), + [ + output_register.into(), + input_register_1.into(), + input_register_2.into(), + self.address_space.into(), + self.address_space.into(), + input_register_3.into(), + ], + ExecutionState::new(pc, very_first_timestamp), + final_timestamp_increment, + ) + .eval(builder, multi_observe_row * is_first); + + self.memory_bridge + .read( + MemoryAddress::new(self.address_space, output_register), + [state_ptr], + very_first_timestamp, + &read_data[0], + ) + .eval(builder, multi_observe_row * is_first); + + self.memory_bridge + .read( + MemoryAddress::new(self.address_space, input_register_1), + [init_pos], + very_first_timestamp + AB::F::ONE, + &read_data[1], + ) + .eval(builder, multi_observe_row * is_first); + + self.memory_bridge + .read( + MemoryAddress::new(self.address_space, input_register_2), + [input_ptr], + very_first_timestamp + AB::F::TWO, + &read_data[2], + ) + .eval(builder, multi_observe_row * is_first); + + self.memory_bridge + .read( + MemoryAddress::new(self.address_space, input_register_3), + [len], + very_first_timestamp + AB::F::from_canonical_usize(3), + &read_data[3], + ) + .eval(builder, multi_observe_row * is_first); + + for i in 0..CHUNK { + let i_var = AB::F::from_canonical_usize(i); + self.memory_bridge + .read( + MemoryAddress::new(self.address_space, input_ptr + curr_len + i_var - start_idx), + [data[i]], + start_timestamp + i_var * AB::F::TWO - start_idx * AB::F::TWO, + &read_data[i] + ) + .eval(builder, aux_after_start[i] * aux_before_end[i]); + + self.memory_bridge + .write( + MemoryAddress::new( + self.address_space, + state_ptr + i_var, + ), + [data[i]], + start_timestamp + i_var * AB::F::TWO - start_idx * AB::F::TWO + AB::F::ONE, + &write_data[i], + ) + .eval(builder, aux_after_start[i] * aux_before_end[i]); + } + + for i in 0..(CHUNK - 1) { + builder + .when(aux_after_start[i]) + .assert_one(aux_after_start[i + 1]); + } + + for i in 1..CHUNK { + builder + .when(aux_before_end[i]) + .assert_one(aux_before_end[i - 1]); + } + + builder + .when(multi_observe_row) + .when(not(is_first)) + .assert_eq( + aux_after_start[0] + + aux_after_start[1] + + aux_after_start[2] + + aux_after_start[3] + + aux_after_start[4] + + aux_after_start[5] + + aux_after_start[6] + + aux_after_start[7], + AB::Expr::from_canonical_usize(CHUNK) - start_idx.into() + ); + + builder + .when(multi_observe_row) + .when(not(is_first)) + .assert_eq( + aux_before_end[0] + + aux_before_end[1] + + aux_before_end[2] + + aux_before_end[3] + + aux_before_end[4] + + aux_before_end[5] + + aux_before_end[6] + + aux_before_end[7], + end_idx + ); + + let full_sponge_input = from_fn::<_, {CHUNK * 2}, _>(|i| local.inner.inputs[i]); + let full_sponge_output = from_fn::<_, {CHUNK * 2}, _>(|i| local.inner.ending_full_rounds[BABY_BEAR_POSEIDON2_HALF_FULL_ROUNDS - 1].post[i]); + + self.memory_bridge + .read( + MemoryAddress::new( + self.address_space, + state_ptr, + ), + full_sponge_input, + start_timestamp + end_idx * AB::F::TWO - start_idx * AB::F::TWO, + &read_sponge_state, + ) + .eval(builder, multi_observe_row * should_permute); + + self.memory_bridge + .write( + MemoryAddress::new( + self.address_space, + state_ptr + ), + full_sponge_output, + start_timestamp + end_idx * AB::F::TWO - start_idx * AB::F::TWO + AB::F::ONE, + &write_sponge_state, + ) + .eval(builder, multi_observe_row * should_permute); + + self.memory_bridge + .write( + MemoryAddress::new( + self.address_space, + input_register_1, + ), + [final_idx], + start_timestamp + is_first * AB::F::from_canonical_usize(4) + (end_idx - start_idx) * AB::F::TWO + should_permute * AB::F::TWO, + &write_final_idx + ) + .eval(builder, multi_observe_row * is_last); + + // Field transitions + builder + .when(next.multi_observe_row) + .when(not(next_multi_observe_specific.is_first)) + .assert_eq(next_multi_observe_specific.curr_len, multi_observe_specific.curr_len + end_idx - start_idx); + + // Boundary conditions + builder + .when(multi_observe_row) + .when(is_first) + .assert_zero(curr_len); + + builder + .when(multi_observe_row) + .when(is_last) + .assert_eq(curr_len + (end_idx - start_idx), len); + + builder + .when(next.multi_observe_row) + .when(not(next_multi_observe_specific.is_first)) + .assert_one(multi_observe_row); + + builder + .when(multi_observe_row) + .when(not(is_last)) + .assert_one(next.multi_observe_row); + + // Field consistency + builder + .when(next.multi_observe_row) + .when(not(next_multi_observe_specific.is_first)) + .assert_eq(state_ptr, next_multi_observe_specific.state_ptr); + + builder + .when(next.multi_observe_row) + .when(not(next_multi_observe_specific.is_first)) + .assert_eq(input_ptr, next_multi_observe_specific.input_ptr); + + builder + .when(next.multi_observe_row) + .when(not(next_multi_observe_specific.is_first)) + .assert_eq(init_pos, next_multi_observe_specific.init_pos); + + builder + .when(next.multi_observe_row) + .when(not(next_multi_observe_specific.is_first)) + .assert_eq(len, next_multi_observe_specific.len); + + builder + .when(next.multi_observe_row) + .when(not(next_multi_observe_specific.is_first)) + .assert_eq(input_register_1, next_multi_observe_specific.input_register_1); + + builder + .when(next.multi_observe_row) + .when(not(next_multi_observe_specific.is_first)) + .assert_eq(input_register_2, next_multi_observe_specific.input_register_2); + + builder + .when(next.multi_observe_row) + .when(not(next_multi_observe_specific.is_first)) + .assert_eq(input_register_3, next_multi_observe_specific.input_register_3); + + builder + .when(next.multi_observe_row) + .when(not(next_multi_observe_specific.is_first)) + .assert_eq(output_register, next_multi_observe_specific.output_register); + + // Timestamp constraints + builder + .when(next.multi_observe_row) + .when(not(next_multi_observe_specific.is_first)) + .assert_eq(very_first_timestamp, next.very_first_timestamp); + + builder + .when(next.multi_observe_row) + .when(not(next_multi_observe_specific.is_first)) + .assert_eq(next.start_timestamp, start_timestamp + is_first * AB::F::from_canonical_usize(4) + (end_idx - start_idx) * AB::F::TWO + should_permute * AB::F::TWO); } } diff --git a/extensions/native/circuit/src/poseidon2/chip.rs b/extensions/native/circuit/src/poseidon2/chip.rs index 426b089a9c..47ffc156d5 100644 --- a/extensions/native/circuit/src/poseidon2/chip.rs +++ b/extensions/native/circuit/src/poseidon2/chip.rs @@ -9,7 +9,7 @@ use openvm_circuit::{ use openvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP, LocalOpcode}; use openvm_native_compiler::{ conversion::AS, - Poseidon2Opcode::{COMP_POS2, PERM_POS2}, + Poseidon2Opcode::{COMP_POS2, PERM_POS2, MULTI_OBSERVE}, VerifyBatchOpcode::VERIFY_BATCH, }; use openvm_poseidon2_air::{Poseidon2Config, Poseidon2SubAir, Poseidon2SubChip}; @@ -120,11 +120,50 @@ pub struct SimplePoseidonRecord { pub p2_input: [F; 2 * CHUNK], } +#[repr(C)] +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +#[serde(bound = "F: Field")] +pub struct TranscriptObservationRecord { + pub from_state: ExecutionState, + pub instruction: Instruction, + pub start_idx: usize, + pub end_idx: usize, + pub is_first: bool, + pub is_last: bool, + pub curr_timestamp_increment: usize, + pub final_timestamp_increment: usize, + + pub state_ptr: F, + pub input_ptr: F, + pub init_pos: F, + pub len: usize, + pub curr_len: usize, + pub should_permute: bool, + + pub read_input_data: [RecordId; CHUNK], + pub write_input_data: [RecordId; CHUNK], + pub input_data: [F; CHUNK], + + pub read_sponge_state: RecordId, + pub write_sponge_state: RecordId, + pub permutation_input: [F; 2 * CHUNK], + pub permutation_output: [F; 2 * CHUNK], + + pub write_final_idx: RecordId, + pub final_idx: usize, + + pub input_register_1: F, + pub input_register_2: F, + pub input_register_3: F, + pub output_register: F, +} + #[derive(Debug, Clone, Serialize, Deserialize, Default)] #[serde(bound = "F: Field")] pub struct NativePoseidon2RecordSet { pub verify_batch_records: Vec>, pub simple_permute_records: Vec>, + pub transcript_observation_records: Vec>, } pub struct NativePoseidon2Chip { @@ -259,6 +298,131 @@ impl InstructionExecutor p2_input, }); self.height += 1; + } else if instruction.opcode == MULTI_OBSERVE.global_opcode() { + let mut observation_records: Vec> = vec![]; + + let &Instruction { + a: output_register, + b: input_register_1, + c: input_register_2, + d: data_address_space, + e: register_address_space, + f: input_register_3, + .. + } = instruction; + + let (read_sponge_ptr, sponge_ptr) = memory.read_cell(register_address_space, output_register); + let (read_init_pos, pos) = memory.read_cell(register_address_space, input_register_1); + let (read_arr_ptr, arr_ptr) = memory.read_cell(register_address_space, input_register_2); + let init_pos = pos.clone(); + + let mut pos = pos.as_canonical_u32() as usize; + let (read_len, len) = memory.read_cell(register_address_space, input_register_3); + let init_len = len.as_canonical_u32() as usize; + let mut len = len.as_canonical_u32() as usize; + + let mut header_record: TranscriptObservationRecord = TranscriptObservationRecord { + from_state, + instruction: instruction.clone(), + curr_timestamp_increment: 0, + is_first: true, + state_ptr: sponge_ptr, + input_ptr: arr_ptr, + init_pos, + len: init_len, + input_register_1, + input_register_2, + input_register_3, + output_register, + ..Default::default() + }; + header_record.read_input_data[0] = read_sponge_ptr; + header_record.read_input_data[1] = read_arr_ptr; + header_record.read_input_data[2] = read_init_pos; + header_record.read_input_data[3] = read_len; + + observation_records.push(header_record); + self.height += 1; + + // Observe bytes + let mut observation_chunks: Vec<(usize, usize, bool)> = vec![]; + while len > 0 { + if len >= (CHUNK - pos) { + observation_chunks.push((pos.clone(), CHUNK.clone(), true)); + len -= CHUNK - pos; + pos = 0; + } else { + observation_chunks.push((pos.clone(), pos + len, false)); + len = 0; + pos = pos + len; + } + } + + let mut curr_timestamp = 4usize; + let mut input_idx: usize = 0; + for chunk in observation_chunks { + let mut record: TranscriptObservationRecord = TranscriptObservationRecord { + from_state, + instruction: instruction.clone(), + + start_idx: chunk.0, + end_idx: chunk.1, + + curr_timestamp_increment: curr_timestamp, + state_ptr: sponge_ptr, + input_ptr: arr_ptr, + init_pos, + len: init_len, + curr_len: input_idx, + input_register_1, + input_register_2, + input_register_3, + output_register, + ..Default::default() + }; + + for j in chunk.0..chunk.1 { + let (n_read, n_f) = memory.read_cell(data_address_space, arr_ptr + F::from_canonical_usize(input_idx)); + record.read_input_data[j] = n_read; + record.input_data[j] = n_f; + input_idx += 1; + curr_timestamp += 1; + + let (n_write, _) = memory.write_cell(data_address_space, sponge_ptr + F::from_canonical_usize(j), n_f); + record.write_input_data[j] = n_write; + curr_timestamp += 1; + } + + if record.end_idx >= CHUNK { + let (read_sponge_record, permutation_input) = memory.read::<{CHUNK * 2}>(data_address_space, sponge_ptr); + let output = self.subchip.permute(permutation_input); + let (write_sponge_record, _) = memory.write::<{CHUNK * 2}>(data_address_space, sponge_ptr, std::array::from_fn(|i| output[i])); + + curr_timestamp += 2; + + record.should_permute = true; + record.read_sponge_state = read_sponge_record; + record.write_sponge_state = write_sponge_record; + record.permutation_input = permutation_input; + record.permutation_output = output; + } + + observation_records.push(record); + self.height += 1; + } + + let last_record = observation_records.last_mut().unwrap(); + let final_idx = last_record.end_idx % CHUNK; + let (write_final, _) = memory.write_cell(register_address_space, input_register_1, F::from_canonical_usize(final_idx)); + last_record.is_last = true; + last_record.write_final_idx = write_final; + last_record.final_idx = final_idx; + curr_timestamp += 1; + + for record in &mut observation_records { + record.final_timestamp_increment = curr_timestamp; + } + self.record_set.transcript_observation_records.extend(observation_records); } else if instruction.opcode == VERIFY_BATCH.global_opcode() { let &Instruction { a: dim_register, @@ -501,7 +665,9 @@ impl InstructionExecutor String::from("PERM_POS2") } else if opcode == COMP_POS2.global_opcode().as_usize() { String::from("COMP_POS2") - } else { + } else if opcode == MULTI_OBSERVE.global_opcode().as_usize() { + String::from("MULTI_OBSERVE") + }else { unreachable!("unsupported opcode: {}", opcode) } } diff --git a/extensions/native/circuit/src/poseidon2/columns.rs b/extensions/native/circuit/src/poseidon2/columns.rs index 6c47c23245..fe0fce881a 100644 --- a/extensions/native/circuit/src/poseidon2/columns.rs +++ b/extensions/native/circuit/src/poseidon2/columns.rs @@ -28,6 +28,8 @@ pub struct NativePoseidon2Cols { pub inside_row: T, /// Indicates that this row is a simple row. pub simple: T, + /// Indicates that this row is a multi_observe row. + pub multi_observe_row: T, /// Indicates the last row in an inside-row block. pub end_inside_row: T, @@ -60,15 +62,16 @@ pub struct NativePoseidon2Cols { /// indicates that cell `i + 1` inside a chunk is exhausted. pub is_exhausted: [T; CHUNK - 1], - pub specific: [T; max3( + pub specific: [T; max4( TopLevelSpecificCols::::width(), InsideRowSpecificCols::::width(), SimplePoseidonSpecificCols::::width(), + MultiObserveCols::::width(), )], } -const fn max3(a: usize, b: usize, c: usize) -> usize { - const_max(a, const_max(b, c)) +const fn max4(a: usize, b: usize, c: usize, d: usize) -> usize { + const_max(a, const_max(b, const_max(c, d))) } #[repr(C)] #[derive(AlignedBorrow)] @@ -200,3 +203,44 @@ pub struct SimplePoseidonSpecificCols { pub write_data_1: MemoryWriteAuxCols, pub write_data_2: MemoryWriteAuxCols, } + +#[repr(C)] +#[derive(AlignedBorrow, Copy, Clone)] +pub struct MultiObserveCols { + // Program states + pub pc: T, + pub final_timestamp_increment: T, + + // Initial reads from registers + pub state_ptr: T, + pub input_ptr: T, + pub init_pos: T, + pub len: T, + + pub is_first: T, + pub is_last: T, + pub curr_len: T, + pub start_idx: T, + pub end_idx: T, + pub aux_after_start: [T; CHUNK], + pub aux_before_end: [T; CHUNK], + + // Transcript observation + pub read_data: [MemoryReadAuxCols; CHUNK], + pub write_data: [MemoryWriteAuxCols; CHUNK], + pub data: [T; CHUNK], + + // Permutation + pub should_permute: T, + pub read_sponge_state: MemoryReadAuxCols, + pub write_sponge_state: MemoryWriteAuxCols, + + // Final write back and registers + pub write_final_idx: MemoryWriteAuxCols, + pub final_idx: T, + + pub input_register_1: T, + pub input_register_2: T, + pub input_register_3: T, + pub output_register: T, +} \ No newline at end of file diff --git a/extensions/native/circuit/src/poseidon2/tests.rs b/extensions/native/circuit/src/poseidon2/tests.rs index 32a0e483a3..6d703c549b 100644 --- a/extensions/native/circuit/src/poseidon2/tests.rs +++ b/extensions/native/circuit/src/poseidon2/tests.rs @@ -434,7 +434,8 @@ fn tester_with_random_poseidon2_ops(num_ops: usize) -> VmChipTester { tester.write(e, lhs, data); - } + }, + MULTI_OBSERVE => {} } tester.execute(&mut chip, &instruction); @@ -449,6 +450,7 @@ fn tester_with_random_poseidon2_ops(num_ops: usize) -> VmChipTester(e, dst); assert_eq!(hash, actual); } + MULTI_OBSERVE => {} } } tester.build().load(chip).finalize() diff --git a/extensions/native/circuit/src/poseidon2/trace.rs b/extensions/native/circuit/src/poseidon2/trace.rs index df8547767f..9a2c2b4cb1 100644 --- a/extensions/native/circuit/src/poseidon2/trace.rs +++ b/extensions/native/circuit/src/poseidon2/trace.rs @@ -15,18 +15,15 @@ use openvm_stark_backend::{ }; use crate::{ - chip::{SimplePoseidonRecord, NUM_INITIAL_READS}, - poseidon2::{ + chip::TranscriptObservationRecord, poseidon2::{ chip::{ - CellRecord, IncorporateRowRecord, IncorporateSiblingRecord, InsideRowRecord, - NativePoseidon2Chip, VerifyBatchRecord, + CellRecord, IncorporateRowRecord, IncorporateSiblingRecord, InsideRowRecord, NativePoseidon2Chip, SimplePoseidonRecord, VerifyBatchRecord, NUM_INITIAL_READS }, columns::{ - InsideRowSpecificCols, NativePoseidon2Cols, SimplePoseidonSpecificCols, - TopLevelSpecificCols, + InsideRowSpecificCols, MultiObserveCols, NativePoseidon2Cols, SimplePoseidonSpecificCols, TopLevelSpecificCols }, CHUNK, - }, + } }; impl ChipUsageGetter for NativePoseidon2Chip @@ -432,6 +429,79 @@ impl NativePoseidon2Chip, + aux_cols_factory: &MemoryAuxColsFactory, + slice: &mut [F], + memory: &OfflineMemory, + ) { + self.generate_subair_cols(record.permutation_input, slice); + let cols: &mut NativePoseidon2Cols = slice.borrow_mut(); + cols.very_first_timestamp = F::from_canonical_u32(record.from_state.timestamp); + cols.start_timestamp = F::from_canonical_usize(record.from_state.timestamp as usize + record.curr_timestamp_increment); + cols.multi_observe_row = F::ONE; + + let specific: &mut MultiObserveCols = + cols.specific[..MultiObserveCols::::width()].borrow_mut(); + + specific.pc = F::from_canonical_u32(record.from_state.pc); + specific.final_timestamp_increment = F::from_canonical_usize(record.final_timestamp_increment); + specific.state_ptr = record.state_ptr; + specific.input_ptr = record.input_ptr; + specific.init_pos = record.init_pos; + specific.len = F::from_canonical_usize(record.len); + specific.curr_len = F::from_canonical_usize(record.curr_len); + + if record.is_first { + specific.is_first = F::ONE; + let read_state_ptr_record = memory.record_by_id(record.read_input_data[0]); + let read_input_ptr_record = memory.record_by_id(record.read_input_data[1]); + let read_init_pos_record = memory.record_by_id(record.read_input_data[2]); + let read_len_record = memory.record_by_id(record.read_input_data[3]); + aux_cols_factory.generate_read_aux(read_state_ptr_record, &mut specific.read_data[0]); + aux_cols_factory.generate_read_aux(read_init_pos_record, &mut specific.read_data[1]); + aux_cols_factory.generate_read_aux(read_input_ptr_record, &mut specific.read_data[2]); + aux_cols_factory.generate_read_aux(read_len_record, &mut specific.read_data[3]); + } else { + specific.start_idx = F::from_canonical_usize(record.start_idx); + specific.end_idx = F::from_canonical_usize(record.end_idx); + + for i in record.start_idx..CHUNK { + specific.aux_after_start[i] = F::ONE; + } + for i in 0..record.end_idx { + specific.aux_before_end[i] = F::ONE; + } + for i in record.start_idx..record.end_idx { + let read_data_record = memory.record_by_id(record.read_input_data[i]); + let write_data_record = memory.record_by_id(record.write_input_data[i]); + aux_cols_factory.generate_read_aux(read_data_record, &mut specific.read_data[i]); + aux_cols_factory.generate_write_aux(write_data_record, &mut specific.write_data[i]); + specific.data[i] = record.input_data[i]; + } + if record.should_permute { + let read_sponge_record = memory.record_by_id(record.read_sponge_state); + let write_sponge_record = memory.record_by_id(record.write_sponge_state); + aux_cols_factory.generate_read_aux(read_sponge_record, &mut specific.read_sponge_state); + aux_cols_factory.generate_write_aux(write_sponge_record, &mut specific.write_sponge_state); + specific.should_permute = F::ONE; + } + } + + if record.is_last { + specific.is_last = F::ONE; + specific.final_idx = F::from_canonical_usize(record.final_idx); + let write_final_idx_record = memory.record_by_id(record.write_final_idx); + aux_cols_factory.generate_write_aux(write_final_idx_record, &mut specific.write_final_idx); + } + + specific.input_register_1 = record.input_register_1; + specific.input_register_2 = record.input_register_2; + specific.input_register_3 = record.input_register_3; + specific.output_register = record.output_register; + } + fn generate_trace(self) -> RowMajorMatrix { let width = self.trace_width(); let height = next_power_of_two_or_zero(self.height); @@ -459,6 +529,15 @@ impl NativePoseidon2Chip + TwoAdicField> AsmCo DslIr::HintBitsF(var, len) => { self.push(AsmInstruction::HintBits(var.fp(), len), debug_info); } + DslIr::Poseidon2MultiObserve(dst, init_pos, arr_ptr, len) => { + self.push( + AsmInstruction::Poseidon2MultiObserve(dst.fp(), init_pos.fp(), arr_ptr.fp(), len.get_var().fp()), + debug_info, + ); + }, DslIr::Poseidon2PermuteBabyBear(dst, src) => match (dst, src) { (Array::Dyn(dst, _), Array::Dyn(src, _)) => self.push( AsmInstruction::Poseidon2Permute(dst.fp(), src.fp()), @@ -617,6 +623,15 @@ impl + TwoAdicField> AsmCo debug_info, ); } + DslIr::ExtFromBaseVec(ext, base_vec) => { + assert_eq!(base_vec.len(), EF::D); + for (i, base) in base_vec.into_iter().enumerate() { + self.push( + AsmInstruction::CopyF(ext.fp() + (i as i32), base.fp()), + debug_info.clone(), + ); + } + } _ => unimplemented!(), } } diff --git a/extensions/native/compiler/src/asm/instruction.rs b/extensions/native/compiler/src/asm/instruction.rs index 1aa5ea8527..cd4990b08b 100644 --- a/extensions/native/compiler/src/asm/instruction.rs +++ b/extensions/native/compiler/src/asm/instruction.rs @@ -110,6 +110,11 @@ pub enum AsmInstruction { /// Halt. Halt, + /// Absorbs multiple base elements into a duplex transcript with Poseidon2 permutation + /// (sponge_state, init_pos, arr_ptr, len) + /// Returns the final index position of hash sponge + Poseidon2MultiObserve(i32, i32, i32, i32), + /// Perform a Poseidon2 permutation on state starting at address `lhs` /// and store new state at `rhs`. /// (a, b) are pointers to (lhs, rhs). @@ -334,6 +339,9 @@ impl> AsmInstruction { AsmInstruction::Trap => write!(f, "trap"), AsmInstruction::Halt => write!(f, "halt"), AsmInstruction::HintBits(src, len) => write!(f, "hint_bits ({})fp, {}", src, len), + AsmInstruction::Poseidon2MultiObserve(dst, init_pos, arr, len) => { + write!(f, "poseidon2_multi_observe ({})fp, ({})fp ({})fp ({})fp", dst, init_pos, arr, len) + } AsmInstruction::Poseidon2Permute(dst, lhs) => { write!(f, "poseidon2_permute ({})fp, ({})fp", dst, lhs) } diff --git a/extensions/native/compiler/src/constraints/halo2/compiler.rs b/extensions/native/compiler/src/constraints/halo2/compiler.rs index ce108addaa..6b831e596b 100644 --- a/extensions/native/compiler/src/constraints/halo2/compiler.rs +++ b/extensions/native/compiler/src/constraints/halo2/compiler.rs @@ -493,11 +493,11 @@ impl Halo2ConstraintCompiler { } DslIr::CycleTrackerStart(_name) => { #[cfg(feature = "bench-metrics")] - cell_tracker.start(_name); + cell_tracker.start(_name, 0); } DslIr::CycleTrackerEnd(_name) => { #[cfg(feature = "bench-metrics")] - cell_tracker.end(_name); + cell_tracker.end(_name, 0); } DslIr::CircuitPublish(val, index) => { public_values[index] = vars[&val.0]; diff --git a/extensions/native/compiler/src/conversion/mod.rs b/extensions/native/compiler/src/conversion/mod.rs index 9c3fc8d752..f8da82c30b 100644 --- a/extensions/native/compiler/src/conversion/mod.rs +++ b/extensions/native/compiler/src/conversion/mod.rs @@ -441,6 +441,18 @@ fn convert_instruction>( AS::Native, AS::Native, )], + AsmInstruction::Poseidon2MultiObserve(dst, init, arr, len) => vec![ + Instruction { + opcode: options.opcode_with_offset(Poseidon2Opcode::MULTI_OBSERVE), + a: i32_f(dst), + b: i32_f(init), + c: i32_f(arr), + d: AS::Native.to_field(), + e: AS::Native.to_field(), + f: i32_f(len), + g: F::ZERO, + } + ], AsmInstruction::Poseidon2Compress(dst, src1, src2) => vec![inst( options.opcode_with_offset(Poseidon2Opcode::COMP_POS2), i32_f(dst), diff --git a/extensions/native/compiler/src/ir/builder.rs b/extensions/native/compiler/src/ir/builder.rs index 966c0db21c..64c62e64fa 100644 --- a/extensions/native/compiler/src/ir/builder.rs +++ b/extensions/native/compiler/src/ir/builder.rs @@ -620,6 +620,10 @@ impl Builder { self.witness_space.get(id.value()).unwrap() } + pub fn ext_from_base_vec(&mut self, ext: Ext, base_vec: Vec>) { + self.push(DslIr::ExtFromBaseVec(ext, base_vec)); + } + /// Throws an error. pub fn error(&mut self) { self.operations.trace_push(DslIr::Error()); diff --git a/extensions/native/compiler/src/ir/instructions.rs b/extensions/native/compiler/src/ir/instructions.rs index 9159c5775b..24654cb70f 100644 --- a/extensions/native/compiler/src/ir/instructions.rs +++ b/extensions/native/compiler/src/ir/instructions.rs @@ -208,6 +208,13 @@ pub enum DslIr { /// Permutes an array of Bn254 elements using Poseidon2 (output = p2_permute(array)). Should /// only be used when target is a circuit. CircuitPoseidon2Permute([Var; 3]), + /// Absorbs an array of baby bear elements into a duplex transcript with Poseidon2 permutations (output = p2_multi_observe(array, els)). + Poseidon2MultiObserve( + Ptr, // sponge_state + Var, // initial input_ptr position + Ptr, // input array (base elements) + Usize, // len of els + ), // Miscellaneous instructions. /// Prints a variable. @@ -241,6 +248,9 @@ pub enum DslIr { /// Operation to halt the program. Should be the last instruction in the program. Halt, + /// Packs a vector of felts into an ext. + ExtFromBaseVec(Ext, Vec>), + // Public inputs for circuits. /// Publish a field element as the ith public value. Should only be used when target is a /// circuit. diff --git a/extensions/native/compiler/src/ir/poseidon.rs b/extensions/native/compiler/src/ir/poseidon.rs index 12ec526c89..ee9bc0d87d 100644 --- a/extensions/native/compiler/src/ir/poseidon.rs +++ b/extensions/native/compiler/src/ir/poseidon.rs @@ -1,6 +1,8 @@ use openvm_native_compiler_derive::iter_zip; use openvm_stark_backend::p3_field::FieldAlgebra; +use crate::ir::Variable; + use super::{Array, ArrayLike, Builder, Config, DslIr, Ext, Felt, MemIndex, Ptr, Usize, Var}; pub const DIGEST_SIZE: usize = 8; @@ -8,6 +10,47 @@ pub const HASH_RATE: usize = 8; pub const PERMUTATION_WIDTH: usize = 16; impl Builder { + /// Extends native VM ability to observe multiple base elements in one opcode operation + /// Absorbs elements sequentially at the RATE portion of sponge state and performs as many permutations as necessary. + /// Returns the index position of the next input_ptr. + /// + /// [Reference](https://docs.rs/p3-poseidon2/latest/p3_poseidon2/struct.Poseidon2.html) + pub fn poseidon2_multi_observe( + &mut self, + sponge_state: &Array>, + input_ptr: Ptr, + arr: &Array>, + ) -> Usize { + let buffer_size: Var = Var::uninit(self); + self.assign(&buffer_size, C::N::from_canonical_usize(HASH_RATE)); + + match sponge_state { + Array::Fixed(_) => { + panic!("Poseidon2 permutation is not allowed on fixed arrays"); + } + Array::Dyn(sponge_ptr, _) => { + match arr { + Array::Fixed(_) => { + panic!("Base elements input must be dynamic"); + } + Array::Dyn(ptr, len) => { + let init_pos: Var = Var::uninit(self); + self.assign(&init_pos, input_ptr.address - sponge_ptr.address); + + self.operations.push(DslIr::Poseidon2MultiObserve( + *sponge_ptr, + init_pos, + *ptr, + len.clone(), + )); + + Usize::Var(init_pos) + } + } + } + } + } + /// Applies the Poseidon2 permutation to the given array. /// /// [Reference](https://docs.rs/p3-poseidon2/latest/p3_poseidon2/struct.Poseidon2.html) diff --git a/extensions/native/compiler/src/lib.rs b/extensions/native/compiler/src/lib.rs index ef28b37139..66c786fbd9 100644 --- a/extensions/native/compiler/src/lib.rs +++ b/extensions/native/compiler/src/lib.rs @@ -184,6 +184,7 @@ pub enum NativePhantom { pub enum Poseidon2Opcode { PERM_POS2, COMP_POS2, + MULTI_OBSERVE, } /// Opcodes for FRI opening proofs. diff --git a/extensions/native/compiler/tests/ext.rs b/extensions/native/compiler/tests/ext.rs index 5da70cb53b..70584fd926 100644 --- a/extensions/native/compiler/tests/ext.rs +++ b/extensions/native/compiler/tests/ext.rs @@ -35,7 +35,7 @@ fn test_ext2felt() { } #[test] -fn test_ext_from_base_slice() { +fn test_ext_from_base_vec() { const D: usize = 4; type F = BabyBear; type EF = BinomialExtensionField; @@ -52,8 +52,9 @@ fn test_ext_from_base_slice() { let val = EF::from_base_slice(base_slice); let expected: Ext<_, _> = builder.constant(val); - let felts = base_slice.map(|e| builder.constant::>(e)); - let actual = builder.ext_from_base_slice(&felts); + let felts = base_slice.map(|e| builder.constant::>(e)).to_vec(); + let actual = builder.uninit(); + builder.ext_from_base_vec(actual, felts); builder.assert_ext_eq(actual, expected); builder.halt(); diff --git a/extensions/native/recursion/src/challenger/duplex.rs b/extensions/native/recursion/src/challenger/duplex.rs index 7c0cd4dd88..2d45d896be 100644 --- a/extensions/native/recursion/src/challenger/duplex.rs +++ b/extensions/native/recursion/src/challenger/duplex.rs @@ -101,7 +101,9 @@ impl DuplexChallengerVariable { let b = self.sample(builder); let c = self.sample(builder); let d = self.sample(builder); - builder.ext_from_base_slice(&[a, b, c, d]) + let ext = builder.uninit(); + builder.ext_from_base_vec(ext, vec![a, b, c, d]); + ext } fn sample_bits(&self, builder: &mut Builder, nb_bits: RVar) -> Array> diff --git a/extensions/native/recursion/tests/recursion.rs b/extensions/native/recursion/tests/recursion.rs index 8f354f3316..2147a18203 100644 --- a/extensions/native/recursion/tests/recursion.rs +++ b/extensions/native/recursion/tests/recursion.rs @@ -1,13 +1,30 @@ -use openvm_circuit::arch::{instructions::program::Program, SystemConfig, VmConfig, VmExecutor}; +use openvm_circuit::arch::{instructions::program::Program, SystemConfig, VmConfig, VmExecutor, verify_single, VirtualMachine,}; use openvm_native_circuit::{Native, NativeConfig}; -use openvm_native_compiler::{asm::AsmBuilder, ir::Felt}; -use openvm_native_recursion::testing_utils::inner::run_recursive_test; +use openvm_native_compiler::{ + prelude::*, + asm::{AsmBuilder, AsmCompiler}, ir::Felt, + conversion::{convert_program, CompilerOptions}, +}; +use openvm_native_recursion::{testing_utils::inner::run_recursive_test, challenger::duplex::DuplexChallengerVariable}; use openvm_stark_backend::{ config::{Domain, StarkGenericConfig}, p3_commit::PolynomialSpace, p3_field::{extension::BinomialExtensionField, FieldAlgebra}, }; -use openvm_stark_sdk::{config::FriParameters, p3_baby_bear::BabyBear, utils::ProofInputForTest}; +use openvm_stark_sdk::{ + config::FriParameters, + p3_baby_bear::BabyBear, + utils::ProofInputForTest, + config::{ + baby_bear_poseidon2::BabyBearPoseidon2Engine, + fri_params::standard_fri_params_with_100_bits_conjectured_security, + }, + engine::StarkFriEngine, + utils::create_seeded_rng, +}; +use rand::Rng; +pub type F = BabyBear; +pub type E = BinomialExtensionField; fn fibonacci_program(a: u32, b: u32, n: u32) -> Program { type F = BabyBear; @@ -79,3 +96,80 @@ fn test_fibonacci_program_halo2_verify() { let fib_program_stark = fibonacci_program_test_proof_input(0, 1, 32); run_static_verifier_test(fib_program_stark, FriParameters::new_for_testing(3)); } + +#[test] +fn test_multi_observe() { + let mut builder = AsmBuilder::>::default(); + + build_test_program(&mut builder); + + // Fill in test program logic + builder.halt(); + + let compilation_options = CompilerOptions::default().with_cycle_tracker(); + let mut compiler = AsmCompiler::new(compilation_options.word_size); + compiler.build(builder.operations); + let asm_code = compiler.code(); + + // let program = Program::from_instructions(&instructions); + let program: Program<_> = convert_program(asm_code, compilation_options); + + let poseidon2_max_constraint_degree = 3; + + let fri_params = if matches!(std::env::var("OPENVM_FAST_TEST"), Ok(x) if &x == "1") { + FriParameters { + // max constraint degree = 2^log_blowup + 1 + log_blowup: 1, + log_final_poly_len: 0, + num_queries: 2, + proof_of_work_bits: 0, + } + } else { + standard_fri_params_with_100_bits_conjectured_security(1) + }; + + let engine = BabyBearPoseidon2Engine::new(fri_params); + let mut config = NativeConfig::aggregation(0, poseidon2_max_constraint_degree); + config.system.memory_config.max_access_adapter_n = 16; + + let vm = VirtualMachine::new(engine, config); + + let pk = vm.keygen(); + let result = vm.execute_and_generate(program, vec![]).unwrap(); + let proofs = vm.prove(&pk, result); + for proof in proofs { + verify_single(&vm.engine, &pk.get_vk(), &proof).expect("Verification failed"); + } +} + +fn build_test_program( + builder: &mut Builder, +) { + let sample_lens: Vec = vec![10, 2, 0, 3, 20]; + + let mut rng = create_seeded_rng(); + let challenger = DuplexChallengerVariable::new(builder); + + for l in sample_lens { + let sample_input: Array> = builder.dyn_array(l); + builder.range(0, l).for_each(|idx_vec, builder| { + let f_u32: u32 = rng.gen_range(1..1 << 30); + builder.set(&sample_input, idx_vec[0], C::F::from_canonical_u32(f_u32)); + }); + + let next_input_ptr = builder.poseidon2_multi_observe(&challenger.sponge_state, challenger.input_ptr, &sample_input); + + builder.assign( + &challenger.input_ptr, + challenger.io_empty_ptr + next_input_ptr.clone(), + ); + builder.if_ne(next_input_ptr, Usize::from(0)).then_or_else( + |builder| { + builder.assign(&challenger.output_ptr, challenger.io_empty_ptr); + }, + |builder| { + builder.assign(&challenger.output_ptr, challenger.io_full_ptr); + }, + ); + } +}