@@ -98,28 +98,28 @@ impl SimulationRow {
98
98
99
99
match self . instruction {
100
100
Instruction :: Add ( a, b) => {
101
- registers[ usize:: from ( a) ] += registers[ usize:: from ( b) ]
101
+ registers[ usize:: from ( a) ] = registers[ usize:: from ( a) ]
102
+ . wrapping_add ( registers[ usize:: from ( b) ] ) ;
102
103
}
103
104
Instruction :: Sub ( a, b) => {
104
- registers[ usize:: from ( a) ] -= registers[ usize:: from ( b) ]
105
+ registers[ usize:: from ( a) ] = registers[ usize:: from ( a) ]
106
+ . wrapping_sub ( registers[ usize:: from ( b) ] ) ;
105
107
}
106
108
Instruction :: Mul ( a, b) => {
107
- registers[ usize:: from ( a) ] *= registers[ usize:: from ( b) ]
109
+ registers[ usize:: from ( a) ] = registers[ usize:: from ( a) ]
110
+ . wrapping_mul ( registers[ usize:: from ( b) ] ) ;
108
111
}
109
112
Instruction :: Div ( a, b) => {
110
- registers[ usize:: from ( a) ] /= registers[ usize:: from ( b) ]
113
+ registers[ usize:: from ( a) ] = registers[ usize:: from ( a) ]
114
+ . wrapping_div ( registers[ usize:: from ( b) ] ) ;
111
115
}
112
- Instruction :: Bsl ( reg, amount) => {
113
- if registers[ usize:: from ( amount) ] >= 8 {
114
- return Err ( anyhow ! ( "invalid shift amount" ) ) ;
115
- }
116
- registers[ usize:: from ( reg) ] <<= registers[ usize:: from ( amount) ] ;
116
+ Instruction :: Shl ( reg, amount) => {
117
+ registers[ usize:: from ( reg) ] = registers[ usize:: from ( reg) ]
118
+ . wrapping_shl ( registers[ usize:: from ( amount) ] . into ( ) ) ;
117
119
}
118
- Instruction :: Bsr ( reg, amount) => {
119
- if registers[ usize:: from ( amount) ] >= 8 {
120
- return Err ( anyhow ! ( "invalid shift amount" ) ) ;
121
- }
122
- registers[ usize:: from ( reg) ] >>= registers[ usize:: from ( amount) ] ;
120
+ Instruction :: Shr ( reg, amount) => {
121
+ registers[ usize:: from ( reg) ] = registers[ usize:: from ( reg) ]
122
+ . wrapping_shr ( registers[ usize:: from ( amount) ] . into ( ) ) ;
123
123
}
124
124
Instruction :: Lb ( reg, memloc) => {
125
125
registers[ usize:: from ( reg) ] = self
@@ -197,3 +197,60 @@ impl PreflightSimulation {
197
197
Ok ( Self { trace_rows } )
198
198
}
199
199
}
200
+
201
+ #[ cfg( test) ]
202
+ mod tests {
203
+ use super :: * ;
204
+ use std:: collections:: HashMap ;
205
+
206
+ use crate :: vm_specs:: {
207
+ Instruction ,
208
+ MemoryLocation ,
209
+ Program ,
210
+ Register ,
211
+ } ;
212
+
213
+ #[ test]
214
+ /// Tests whether two numbers in memory can be added together
215
+ /// in the ZKVM
216
+ fn test_preflight_add_memory ( ) {
217
+ let instructions = vec ! [
218
+ Instruction :: Lb ( Register :: R0 , MemoryLocation ( 0x40 ) ) ,
219
+ Instruction :: Lb ( Register :: R1 , MemoryLocation ( 0x41 ) ) ,
220
+ Instruction :: Add ( Register :: R0 , Register :: R1 ) ,
221
+ Instruction :: Sb ( Register :: R0 , MemoryLocation ( 0x42 ) ) ,
222
+ Instruction :: Halt ,
223
+ ] ;
224
+
225
+ let code = instructions
226
+ . into_iter ( )
227
+ . enumerate ( )
228
+ . map ( |( idx, inst) | ( idx as u8 , inst) )
229
+ . collect :: < HashMap < u8 , Instruction > > ( ) ;
230
+
231
+ let memory_init: HashMap < u8 , u8 > =
232
+ HashMap :: from_iter ( vec ! [ ( 0x40 , 0x20 ) , ( 0x41 , 0x45 ) ] ) ;
233
+
234
+ let program = Program {
235
+ entry_point : 0 ,
236
+ code,
237
+ memory_init,
238
+ } ;
239
+
240
+ let expected = ( 0x42 , 0x65 ) ;
241
+
242
+ let simulation = PreflightSimulation :: simulate ( & program) ;
243
+ assert ! ( simulation. is_ok( ) ) ;
244
+ let simulation = simulation. unwrap ( ) ;
245
+
246
+ assert_eq ! (
247
+ simulation. trace_rows[ simulation
248
+ . trace_rows
249
+ . len( )
250
+ - 1 ]
251
+ . get_memory_at( & expected. 0 )
252
+ . unwrap( ) ,
253
+ expected. 1
254
+ ) ;
255
+ }
256
+ }
0 commit comments