diff --git a/src/lib.rs b/src/lib.rs index 9a2a6de..713bce9 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -19,62 +19,3 @@ mod vm_specs; // STARK tables ------------- #[allow(dead_code)] mod stark_program_instructions; - -#[cfg(test)] -mod tests { - use std::collections::HashMap; - - use crate::{ - preflight_simulator::PreflightSimulation, - vm_specs::{ - Instruction, - MemoryLocation, - Program, - Register, - }, - }; - - #[test] - /// Tests whether two numbers in memory can be added together - /// in the ZKVM - fn test_preflight_add_memory() { - let instructions = vec![ - Instruction::Lb(Register::R0, MemoryLocation(0x40)), - Instruction::Lb(Register::R1, MemoryLocation(0x41)), - Instruction::Add(Register::R0, Register::R1), - Instruction::Sb(Register::R0, MemoryLocation(0x42)), - Instruction::Halt, - ]; - - let code = instructions - .into_iter() - .enumerate() - .map(|(idx, inst)| (idx as u8, inst)) - .collect::>(); - - let memory_init: HashMap = - HashMap::from_iter(vec![(0x40, 0x20), (0x41, 0x45)]); - - let program = Program { - entry_point: 0, - code, - memory_init, - }; - - let expected = (0x42, 0x65); - - let simulation = PreflightSimulation::simulate(&program); - assert!(simulation.is_ok()); - let simulation = simulation.unwrap(); - - assert_eq!( - simulation.trace_rows[simulation - .trace_rows - .len() - - 1] - .get_memory_at(&expected.0) - .unwrap(), - expected.1 - ); - } -} diff --git a/src/preflight_simulator.rs b/src/preflight_simulator.rs index e26308b..edc29b4 100644 --- a/src/preflight_simulator.rs +++ b/src/preflight_simulator.rs @@ -98,28 +98,28 @@ impl SimulationRow { match self.instruction { Instruction::Add(a, b) => { - registers[usize::from(a)] += registers[usize::from(b)] + registers[usize::from(a)] = registers[usize::from(a)] + .wrapping_add(registers[usize::from(b)]); } Instruction::Sub(a, b) => { - registers[usize::from(a)] -= registers[usize::from(b)] + registers[usize::from(a)] = registers[usize::from(a)] + .wrapping_sub(registers[usize::from(b)]); } Instruction::Mul(a, b) => { - registers[usize::from(a)] *= registers[usize::from(b)] + registers[usize::from(a)] = registers[usize::from(a)] + .wrapping_mul(registers[usize::from(b)]); } Instruction::Div(a, b) => { - registers[usize::from(a)] /= registers[usize::from(b)] + registers[usize::from(a)] = registers[usize::from(a)] + .wrapping_div(registers[usize::from(b)]); } - Instruction::Bsl(reg, amount) => { - if registers[usize::from(amount)] >= 8 { - return Err(anyhow!("invalid shift amount")); - } - registers[usize::from(reg)] <<= registers[usize::from(amount)]; + Instruction::Shl(reg, amount) => { + registers[usize::from(reg)] = registers[usize::from(reg)] + .wrapping_shl(registers[usize::from(amount)].into()); } - Instruction::Bsr(reg, amount) => { - if registers[usize::from(amount)] >= 8 { - return Err(anyhow!("invalid shift amount")); - } - registers[usize::from(reg)] >>= registers[usize::from(amount)]; + Instruction::Shr(reg, amount) => { + registers[usize::from(reg)] = registers[usize::from(reg)] + .wrapping_shr(registers[usize::from(amount)].into()); } Instruction::Lb(reg, memloc) => { registers[usize::from(reg)] = self @@ -197,3 +197,60 @@ impl PreflightSimulation { Ok(Self { trace_rows }) } } + +#[cfg(test)] +mod tests { + use super::*; + use std::collections::HashMap; + + use crate::vm_specs::{ + Instruction, + MemoryLocation, + Program, + Register, + }; + + #[test] + /// Tests whether two numbers in memory can be added together + /// in the ZKVM + fn test_preflight_add_memory() { + let instructions = vec![ + Instruction::Lb(Register::R0, MemoryLocation(0x40)), + Instruction::Lb(Register::R1, MemoryLocation(0x41)), + Instruction::Add(Register::R0, Register::R1), + Instruction::Sb(Register::R0, MemoryLocation(0x42)), + Instruction::Halt, + ]; + + let code = instructions + .into_iter() + .enumerate() + .map(|(idx, inst)| (idx as u8, inst)) + .collect::>(); + + let memory_init: HashMap = + HashMap::from_iter(vec![(0x40, 0x20), (0x41, 0x45)]); + + let program = Program { + entry_point: 0, + code, + memory_init, + }; + + let expected = (0x42, 0x65); + + let simulation = PreflightSimulation::simulate(&program); + assert!(simulation.is_ok()); + let simulation = simulation.unwrap(); + + assert_eq!( + simulation.trace_rows[simulation + .trace_rows + .len() + - 1] + .get_memory_at(&expected.0) + .unwrap(), + expected.1 + ); + } +} diff --git a/src/vm_specs.rs b/src/vm_specs.rs index 6b31a2b..4d994a9 100644 --- a/src/vm_specs.rs +++ b/src/vm_specs.rs @@ -30,8 +30,8 @@ pub enum Instruction { Sub(Register, Register), Mul(Register, Register), Div(Register, Register), - Bsl(Register, Register), - Bsr(Register, Register), + Shl(Register, Register), + Shr(Register, Register), Lb(Register, MemoryLocation), Sb(Register, MemoryLocation), #[default] @@ -47,8 +47,8 @@ impl Instruction { Instruction::Sub(_, _) => 1, Instruction::Mul(_, _) => 2, Instruction::Div(_, _) => 3, - Instruction::Bsl(_, _) => 4, - Instruction::Bsr(_, _) => 5, + Instruction::Shl(_, _) => 4, + Instruction::Shr(_, _) => 5, Instruction::Lb(_, _) => 6, Instruction::Sb(_, _) => 7, Instruction::Halt => 8,