diff --git a/rvgo/fast/instrumented.go b/rvgo/fast/instrumented.go index 6ac8b11a..da8796cd 100644 --- a/rvgo/fast/instrumented.go +++ b/rvgo/fast/instrumented.go @@ -55,9 +55,6 @@ func (m *InstrumentedState) Step(proof bool) (wit *StepWitness, err error) { } err = m.riscvStep() - if err != nil { - return nil, err - } if proof { wit.MemProof = make([]byte, 0, len(m.memProofs)*memProofSize) @@ -70,6 +67,7 @@ func (m *InstrumentedState) Step(proof bool) (wit *StepWitness, err error) { wit.PreimageValue = m.lastPreimage } } + return } diff --git a/rvgo/fast/vm.go b/rvgo/fast/vm.go index d9d77100..8a5f0e8f 100644 --- a/rvgo/fast/vm.go +++ b/rvgo/fast/vm.go @@ -16,6 +16,22 @@ func (e *UnsupportedSyscallErr) Error() string { return fmt.Sprintf("unsupported system call: %d", e.SyscallNum) } +type UnrecognizedSyscallErr struct { + SyscallNum U64 +} + +func (e *UnrecognizedSyscallErr) Error() string { + return fmt.Sprintf("unrecognized system call: %d", e.SyscallNum) +} + +type UnrecognizedResourceErr struct { + Resource U64 +} + +func (e *UnrecognizedResourceErr) Error() string { + return fmt.Sprintf("unrecognized resource limit lookup: %d", e.Resource) +} + // riscvStep runs a single instruction // Note: errors are only returned in debugging/tooling modes, not in production use. 
func (inst *InstrumentedState) riscvStep() (outErr error) { @@ -123,7 +139,7 @@ func (inst *InstrumentedState) riscvStep() (outErr error) { getRegister := func(reg U64) U64 { if reg > 31 { - revertWithCode(0xbad4e9, fmt.Errorf("cannot load invalid register: %d", reg)) + revertWithCode(riscv.ErrInvalidRegister, fmt.Errorf("cannot load invalid register: %d", reg)) } //fmt.Printf("load reg %2d: %016x\n", reg, state.Registers[reg]) return s.Registers[reg] @@ -150,7 +166,7 @@ func (inst *InstrumentedState) riscvStep() (outErr error) { getMemoryB32 := func(addr U64, proofIndex uint8) (out [32]byte) { if addr&31 != 0 { // quick addr alignment check - revertWithCode(0xbad10ad0, fmt.Errorf("addr %d not aligned with 32 bytes", addr)) + revertWithCode(riscv.ErrNotAlignedAddr, fmt.Errorf("addr %d not aligned with 32 bytes", addr)) } inst.trackMemAccess(addr, proofIndex) s.Memory.GetUnaligned(addr, out[:]) @@ -168,12 +184,12 @@ func (inst *InstrumentedState) riscvStep() (outErr error) { // load unaligned, optionally signed, little-endian, integer of 1 ... 
8 bytes from memory loadMem := func(addr U64, size U64, signed bool, proofIndexL uint8, proofIndexR uint8) (out U64) { if size > 8 { - revertWithCode(0xbad512e0, fmt.Errorf("cannot load more than 8 bytes: %d", size)) + revertWithCode(riscv.ErrLoadExceeds8Bytes, fmt.Errorf("cannot load more than 8 bytes: %d", size)) } inst.trackMemAccess(addr&^31, proofIndexL) if (addr+size-1)&^31 != addr&^31 { if proofIndexR == 0xff { - revertWithCode(0xbad22220, fmt.Errorf("unexpected need for right-side proof %d in loadMem", proofIndexR)) + revertWithCode(riscv.ErrUnexpectedRProofLoad, fmt.Errorf("unexpected need for right-side proof %d in loadMem", proofIndexR)) } inst.trackMemAccess((addr+size-1)&^31, proofIndexR) } @@ -190,7 +206,7 @@ func (inst *InstrumentedState) riscvStep() (outErr error) { storeMemUnaligned := func(addr U64, size U64, value U256, proofIndexL uint8, proofIndexR uint8, verifyL bool, verifyR bool) { if size > 32 { - revertWithCode(0xbad512e1, fmt.Errorf("cannot store more than 32 bytes: %d", size)) + revertWithCode(riscv.ErrStoreExceeds32Bytes, fmt.Errorf("cannot store more than 32 bytes: %d", size)) } var bytez [32]byte binary.LittleEndian.PutUint64(bytez[:8], value[0]) @@ -208,7 +224,7 @@ func (inst *InstrumentedState) riscvStep() (outErr error) { return } if proofIndexR == 0xff { - revertWithCode(0xbad22221, fmt.Errorf("unexpected need for right-side proof %d in storeMemUnaligned", proofIndexR)) + revertWithCode(riscv.ErrUnexpectedRProofStoreUnaligned, fmt.Errorf("unexpected need for right-side proof %d in storeMemUnaligned", proofIndexR)) } // if not aligned rightAddr := leftAddr + 32 @@ -223,7 +239,7 @@ func (inst *InstrumentedState) riscvStep() (outErr error) { storeMem := func(addr U64, size U64, value U64, proofIndexL uint8, proofIndexR uint8, verifyL bool, verifyR bool) { if size > 8 { - revertWithCode(0xbad512e8, fmt.Errorf("cannot store more than 8 bytes: %d", size)) + revertWithCode(riscv.ErrStoreExceeds8Bytes, fmt.Errorf("cannot store more than 8 
bytes: %d", size)) } var bytez [8]byte binary.LittleEndian.PutUint64(bytez[:], value) @@ -238,7 +254,7 @@ func (inst *InstrumentedState) riscvStep() (outErr error) { } // if not aligned if proofIndexR == 0xff { - revertWithCode(0xbad2222f, fmt.Errorf("unexpected need for right-side proof %d in storeMem", proofIndexR)) + revertWithCode(riscv.ErrUnexpectedRProofStore, fmt.Errorf("unexpected need for right-side proof %d in storeMem", proofIndexR)) } rightAddr := leftAddr + 32 leftSize := rightAddr - addr @@ -271,7 +287,7 @@ func (inst *InstrumentedState) riscvStep() (outErr error) { case 3: // ?11 = CSRRC(I) v = and64(out, not64(v)) default: - revertWithCode(0xbadc0de0, fmt.Errorf("unkwown CSR mode: %d", mode)) + revertWithCode(riscv.ErrUnknownCSRMode, fmt.Errorf("unknown CSR mode: %d", mode)) } writeCSR(num, v) return @@ -315,7 +331,7 @@ func (inst *InstrumentedState) riscvStep() (outErr error) { pdatB32, pdatlen, err := inst.readPreimage(preImageKey, offset) // pdat is left-aligned if err != nil { - revertWithCode(0xbadf00d0, err) + revertWithCode(riscv.ErrFailToReadPreimage, err) } if iszero64(pdatlen) { // EOF return toU64(0) } @@ -538,7 +554,7 @@ func (inst *InstrumentedState) riscvStep() (outErr error) { setRegister(toU64(10), toU64(0)) setRegister(toU64(11), toU64(0)) default: - revertWithCode(0xf0012, fmt.Errorf("unrecognized resource limit lookup: %d", res)) + revertWithCode(riscv.ErrUnrecognizedResource, &UnrecognizedResourceErr{Resource: res}) } case riscv.SysMadvise: // madvise - ignored setRegister(toU64(10), toU64(0)) @@ -568,13 +584,13 @@ func (inst *InstrumentedState) riscvStep() (outErr error) { setRegister(toU64(10), toU64(0)) setRegister(toU64(11), toU64(0)) case riscv.SysPrlimit64: // prlimit64 -- unsupported, we have getrlimit, is prlimit64 even called? 
- revertWithCode(0xf001ca11, &UnsupportedSyscallErr{SyscallNum: a7}) + revertWithCode(riscv.ErrInvalidSyscall, &UnsupportedSyscallErr{SyscallNum: a7}) case riscv.SysFutex: // futex - not supported, for now - revertWithCode(0xf001ca11, &UnsupportedSyscallErr{SyscallNum: a7}) + revertWithCode(riscv.ErrInvalidSyscall, &UnsupportedSyscallErr{SyscallNum: a7}) case riscv.SysNanosleep: // nanosleep - not supported, for now - revertWithCode(0xf001ca11, &UnsupportedSyscallErr{SyscallNum: a7}) + revertWithCode(riscv.ErrInvalidSyscall, &UnsupportedSyscallErr{SyscallNum: a7}) default: - revertWithCode(0xf001ca11, fmt.Errorf("unrecognized system call: %d", a7)) + revertWithCode(riscv.ErrInvalidSyscall, &UnrecognizedSyscallErr{SyscallNum: a7}) } } @@ -889,7 +905,7 @@ func (inst *InstrumentedState) riscvStep() (outErr error) { // 0b011 == RV64A D variants size := shl64(funct3, toU64(1)) if lt64(size, toU64(4)) != 0 { - revertWithCode(0xbada70, fmt.Errorf("bad AMO size: %d", size)) + revertWithCode(riscv.ErrBadAMOSize, fmt.Errorf("bad AMO size: %d", size)) } addr := getRegister(rs1) // TODO check if addr is aligned @@ -945,7 +961,7 @@ func (inst *InstrumentedState) riscvStep() (outErr error) { v = value } default: - revertWithCode(0xf001a70, fmt.Errorf("unknown atomic operation %d", op)) + revertWithCode(riscv.ErrUnknownAtomicOperation, fmt.Errorf("unknown atomic operation %d", op)) } storeMem(addr, size, v, 1, 3, false, true) // after overwriting 1, proof 2 is no longer valid setRegister(rd, rdValue) @@ -963,7 +979,7 @@ func (inst *InstrumentedState) riscvStep() (outErr error) { case 0x53: // FADD etc. no-op is enough to pass Go runtime check setPC(add64(pc, toU64(4))) // no-op this. 
default: - revertWithCode(0xf001c0de, fmt.Errorf("unknown instruction opcode: %d", opcode)) + revertWithCode(riscv.ErrUnknownOpCode, fmt.Errorf("unknown instruction opcode: %d", opcode)) } return nil } diff --git a/rvgo/riscv/constants.go b/rvgo/riscv/constants.go index 142f2557..9360e868 100644 --- a/rvgo/riscv/constants.go +++ b/rvgo/riscv/constants.go @@ -38,4 +38,21 @@ const ( FdHintWrite = 4 FdPreimageRead = 5 FdPreimageWrite = 6 + + ErrUnrecognizedResource = uint64(0xf0012) + ErrUnknownAtomicOperation = uint64(0xf001a70) + ErrUnknownOpCode = uint64(0xf001c0de) + ErrInvalidSyscall = uint64(0xf001ca11) + ErrInvalidRegister = uint64(0xbad4e9) + ErrNotAlignedAddr = uint64(0xbad10ad0) + ErrLoadExceeds8Bytes = uint64(0xbad512e0) + ErrStoreExceeds8Bytes = uint64(0xbad512e8) + ErrStoreExceeds32Bytes = uint64(0xbad512e1) + ErrUnexpectedRProofLoad = uint64(0xbad22220) + ErrUnexpectedRProofStoreUnaligned = uint64(0xbad22221) + ErrUnexpectedRProofStore = uint64(0xbad2222f) + ErrUnknownCSRMode = uint64(0xbadc0de0) + ErrBadAMOSize = uint64(0xbada70) + ErrFailToReadPreimage = uint64(0xbadf00d0) + ErrBadMemoryProof = uint64(0xbadf00d1) ) diff --git a/rvgo/slow/vm.go b/rvgo/slow/vm.go index adfbdd77..da4b43e9 100644 --- a/rvgo/slow/vm.go +++ b/rvgo/slow/vm.go @@ -60,6 +60,22 @@ func (e *UnsupportedSyscallErr) Error() string { return fmt.Sprintf("unsupported system call: %d", e.SyscallNum) } +type UnrecognizedSyscallErr struct { + SyscallNum U64 +} + +func (e *UnrecognizedSyscallErr) Error() string { + return fmt.Sprintf("unrecognized system call: %d", e.SyscallNum) +} + +type UnrecognizedResourceErr struct { + Resource U64 +} + +func (e *UnrecognizedResourceErr) Error() string { + return fmt.Sprintf("unrecognized resource limit lookup: %d", e.Resource) +} + type PreimageOracle interface { ReadPreimagePart(key [32]byte, offset uint64) (dat [32]byte, datlen uint8, err error) } @@ -67,8 +83,13 @@ type PreimageOracle interface { func Step(calldata []byte, po PreimageOracle) 
(stateHash common.Hash, outErr error) { var revertCode uint64 defer func() { - if err := recover(); err != nil { - outErr = fmt.Errorf("revert: %v", err) + if errInterface := recover(); errInterface != nil { + if err, ok := errInterface.(error); ok { + outErr = fmt.Errorf("revert: %w", err) + } else { + outErr = fmt.Errorf("revert: %v", errInterface) + } + } if revertCode != 0 { outErr = fmt.Errorf("revert %x: %w", revertCode, outErr) @@ -196,7 +217,7 @@ func Step(calldata []byte, po PreimageOracle) (stateHash common.Hash, outErr err getRegister := func(reg U64) U64 { if gt64(reg, toU64(31)) != (U64{}) { - revertWithCode(0xbad4e9, fmt.Errorf("cannot load invalid register: %d", reg.val())) + revertWithCode(riscv.ErrInvalidRegister, fmt.Errorf("cannot load invalid register: %d", reg.val())) } //fmt.Printf("load reg %2d: %016x\n", reg, state.Registers[reg]) offset := add64(toU64(stateOffsetRegisters), mul64(reg, toU64(8))) @@ -209,7 +230,7 @@ func Step(calldata []byte, po PreimageOracle) (stateHash common.Hash, outErr err return } if gt64(reg, toU64(31)) != (U64{}) { - revertWithCode(0xbad4e9, fmt.Errorf("unknown register %d, cannot write %x", reg.val(), v.val())) + revertWithCode(riscv.ErrInvalidRegister, fmt.Errorf("unknown register %d, cannot write %x", reg.val(), v.val())) } offset := add64(toU64(stateOffsetRegisters), mul64(reg, toU64(8))) writeState(offset.val(), 8, encodeU64BE(v)) @@ -261,7 +282,7 @@ func Step(calldata []byte, po PreimageOracle) (stateHash common.Hash, outErr err getMemoryB32 := func(addr U64, proofIndex uint8) (out [32]byte) { if and64(addr, toU64(31)) != (U64{}) { // quick addr alignment check - revertWithCode(0xbad10ad0, fmt.Errorf("addr %d not aligned with 32 bytes", addr)) + revertWithCode(riscv.ErrNotAlignedAddr, fmt.Errorf("addr %d not aligned with 32 bytes", addr)) } offset := proofOffset(proofIndex) leaf := calldataload(offset) @@ -281,7 +302,7 @@ func Step(calldata []byte, po PreimageOracle) (stateHash common.Hash, outErr err } memRoot := 
getMemRoot() if iszero(eq(b32asBEWord(node), b32asBEWord(memRoot))) { // verify the root matches - revertWithCode(0xbadf00d1, fmt.Errorf("bad memory proof, got mem root: %x, expected %x", node, memRoot)) + revertWithCode(riscv.ErrBadMemoryProof, fmt.Errorf("bad memory proof, got mem root: %x, expected %x", node, memRoot)) } out = leaf return @@ -291,7 +312,7 @@ func Step(calldata []byte, po PreimageOracle) (stateHash common.Hash, outErr err // it assumes the same memory proof has been verified with getMemoryB32 setMemoryB32 := func(addr U64, v [32]byte, proofIndex uint8) { if and64(addr, toU64(31)) != (U64{}) { - revertWithCode(0xbad10ad0, fmt.Errorf("addr %d not aligned with 32 bytes", addr)) + revertWithCode(riscv.ErrNotAlignedAddr, fmt.Errorf("addr %d not aligned with 32 bytes", addr)) } offset := proofOffset(proofIndex) leaf := v @@ -315,7 +336,7 @@ func Step(calldata []byte, po PreimageOracle) (stateHash common.Hash, outErr err // load unaligned, optionally signed, little-endian, integer of 1 ... 
8 bytes from memory loadMem := func(addr U64, size U64, signed bool, proofIndexL uint8, proofIndexR uint8) (out U64) { if size.val() > 8 { - revertWithCode(0xbad512e0, fmt.Errorf("cannot load more than 8 bytes: %d", size)) + revertWithCode(riscv.ErrLoadExceeds8Bytes, fmt.Errorf("cannot load more than 8 bytes: %d", size)) } // load/verify left part leftAddr := and64(addr, not64(toU64(31))) @@ -329,7 +350,7 @@ func Step(calldata []byte, po PreimageOracle) (stateHash common.Hash, outErr err if iszero64(eq64(leftAddr, rightAddr)) { // if unaligned, use second proof for the right part if proofIndexR == 0xff { - revertWithCode(0xbad22220, fmt.Errorf("unexpected need for right-side proof %d in loadMem", proofIndexR)) + revertWithCode(riscv.ErrUnexpectedRProofLoad, fmt.Errorf("unexpected need for right-side proof %d in loadMem", proofIndexR)) } // load/verify right part right = b32asBEWord(getMemoryB32(rightAddr, proofIndexR)) @@ -400,7 +421,7 @@ func Step(calldata []byte, po PreimageOracle) (stateHash common.Hash, outErr err storeMemUnaligned := func(addr U64, size U64, value U256, proofIndexL uint8, proofIndexR uint8) { if size.val() > 32 { - revertWithCode(0xbad512e1, fmt.Errorf("cannot store more than 32 bytes: %d", size)) + revertWithCode(riscv.ErrStoreExceeds32Bytes, fmt.Errorf("cannot store more than 32 bytes: %d", size)) } leftAddr := and64(addr, not64(toU64(31))) @@ -420,7 +441,7 @@ func Step(calldata []byte, po PreimageOracle) (stateHash common.Hash, outErr err return } if proofIndexR == 0xff { - revertWithCode(0xbad22221, fmt.Errorf("unexpected need for right-side proof %d in storeMemUnaligned", proofIndexR)) + revertWithCode(riscv.ErrUnexpectedRProofStoreUnaligned, fmt.Errorf("unexpected need for right-side proof %d in storeMemUnaligned", proofIndexR)) } // load the right base (with updated mem root) right := b32asBEWord(getMemoryB32(rightAddr, proofIndexR)) @@ -454,7 +475,7 @@ func Step(calldata []byte, po PreimageOracle) (stateHash common.Hash, outErr err 
case 3: // ?11 = CSRRC(I) v = and64(out, not64(v)) default: - revertWithCode(0xbadc0de0, fmt.Errorf("unkwown CSR mode: %d", mode.val())) + revertWithCode(riscv.ErrUnknownCSRMode, fmt.Errorf("unknown CSR mode: %d", mode.val())) } writeCSR(num, v) return @@ -500,7 +521,7 @@ func Step(calldata []byte, po PreimageOracle) (stateHash common.Hash, outErr err datlen = toU64(l) return } - revertWithCode(0xbadf00d0, err) + revertWithCode(riscv.ErrFailToReadPreimage, err) return } @@ -713,7 +734,7 @@ func Step(calldata []byte, po PreimageOracle) (stateHash common.Hash, outErr err setRegister(toU64(10), toU64(0)) setRegister(toU64(11), toU64(0)) default: - revertWithCode(0xf0012, fmt.Errorf("unrecognized resource limit lookup: %d", res)) + revertWithCode(riscv.ErrUnrecognizedResource, &UnrecognizedResourceErr{Resource: res}) } case riscv.SysMadvise: // madvise - ignored setRegister(toU64(10), toU64(0)) @@ -743,13 +764,13 @@ func Step(calldata []byte, po PreimageOracle) (stateHash common.Hash, outErr err setRegister(toU64(10), toU64(0)) setRegister(toU64(11), toU64(0)) case riscv.SysPrlimit64: // prlimit64 -- unsupported, we have getrlimit, is prlimit64 even called? 
- revertWithCode(0xf001ca11, &UnsupportedSyscallErr{SyscallNum: a7}) + revertWithCode(riscv.ErrInvalidSyscall, &UnsupportedSyscallErr{SyscallNum: a7}) case riscv.SysFutex: // futex - not supported, for now - revertWithCode(0xf001ca11, &UnsupportedSyscallErr{SyscallNum: a7}) + revertWithCode(riscv.ErrInvalidSyscall, &UnsupportedSyscallErr{SyscallNum: a7}) case riscv.SysNanosleep: // nanosleep - not supported, for now - revertWithCode(0xf001ca11, &UnsupportedSyscallErr{SyscallNum: a7}) + revertWithCode(riscv.ErrInvalidSyscall, &UnsupportedSyscallErr{SyscallNum: a7}) default: - revertWithCode(0xf001ca11, fmt.Errorf("unrecognized system call: %d", a7)) + revertWithCode(riscv.ErrInvalidSyscall, &UnrecognizedSyscallErr{SyscallNum: a7}) } } @@ -1064,7 +1085,7 @@ func Step(calldata []byte, po PreimageOracle) (stateHash common.Hash, outErr err // 0b011 == RV64A D variants size := shl64(funct3, toU64(1)) if lt64(size, toU64(4)) != (U64{}) { - revertWithCode(0xbada70, fmt.Errorf("bad AMO size: %d", size)) + revertWithCode(riscv.ErrBadAMOSize, fmt.Errorf("bad AMO size: %d", size)) } addr := getRegister(rs1) // TODO check if addr is aligned @@ -1120,7 +1141,7 @@ func Step(calldata []byte, po PreimageOracle) (stateHash common.Hash, outErr err v = value } default: - revertWithCode(0xf001a70, fmt.Errorf("unknown atomic operation %d", op)) + revertWithCode(riscv.ErrUnknownAtomicOperation, fmt.Errorf("unknown atomic operation %d", op)) } storeMem(addr, size, v, 1, 3) // after overwriting 1, proof 2 is no longer valid setRegister(rd, rdValue) @@ -1138,7 +1159,7 @@ func Step(calldata []byte, po PreimageOracle) (stateHash common.Hash, outErr err case 0x53: // FADD etc. no-op is enough to pass Go runtime check setPC(add64(pc, toU64(4))) // no-op this. 
default: - revertWithCode(0xf001c0de, fmt.Errorf("unknown instruction opcode: %d", opcode)) + revertWithCode(riscv.ErrUnknownOpCode, fmt.Errorf("unknown instruction opcode: %d", opcode)) } return computeStateHash(), nil } diff --git a/rvgo/test/evm_test.go b/rvgo/test/evm_test.go index af9af170..ec464e4a 100644 --- a/rvgo/test/evm_test.go +++ b/rvgo/test/evm_test.go @@ -143,7 +143,7 @@ func addTracer(t *testing.T, env *vm.EVM, addrs *Addresses, contracts *Contracts }, os.Stdout) } -func stepEVM(t *testing.T, env *vm.EVM, wit *fast.StepWitness, addrs *Addresses, step uint64) (postState []byte, postHash common.Hash, gasUsed uint64) { +func stepEVM(t *testing.T, env *vm.EVM, wit *fast.StepWitness, addrs *Addresses, step uint64, revertCode []byte) (postState []byte, postHash common.Hash, gasUsed uint64) { startingGas := uint64(30_000_000) snap := env.StateDB.Snapshot() @@ -158,6 +158,11 @@ func stepEVM(t *testing.T, env *vm.EVM, wit *fast.StepWitness, addrs *Addresses, input := wit.EncodeStepInput(fast.LocalContext{}) ret, leftOverGas, err := env.Call(vm.AccountRef(addrs.Sender), addrs.RISCV, input, startingGas, big.NewInt(0)) + if revertCode != nil { + require.ErrorIs(t, err, vm.ErrExecutionReverted) + require.Equal(t, ret, revertCode) + return + } require.NoError(t, err, "evm must not fail (ret: %x), at step %d", ret, step) gasUsed = startingGas - leftOverGas diff --git a/rvgo/test/syscall_test.go b/rvgo/test/syscall_test.go index 539e9663..39298352 100644 --- a/rvgo/test/syscall_test.go +++ b/rvgo/test/syscall_test.go @@ -30,22 +30,33 @@ func staticOracle(t *testing.T, preimageData []byte) *testOracle { } } -func runEVM(t *testing.T, contracts *Contracts, addrs *Addresses, stepWitness *fast.StepWitness, fastPost fast.StateWitness) { +func runEVM(t *testing.T, contracts *Contracts, addrs *Addresses, stepWitness *fast.StepWitness, fastPost fast.StateWitness, revertCode []byte) { env := newEVMEnv(t, contracts, addrs) - evmPost, _, _ := stepEVM(t, env, stepWitness, 
addrs, 0) + evmPost, _, _ := stepEVM(t, env, stepWitness, addrs, 0, revertCode) require.Equal(t, hexutil.Bytes(fastPost).String(), hexutil.Bytes(evmPost).String(), "fast VM produced different state than EVM") } -func runSlow(t *testing.T, stepWitness *fast.StepWitness, fastPost fast.StateWitness, po slow.PreimageOracle) { +func runSlow(t *testing.T, stepWitness *fast.StepWitness, fastPost fast.StateWitness, po slow.PreimageOracle, expectedErr interface{}) { slowPostHash, err := slow.Step(stepWitness.EncodeStepInput(fast.LocalContext{}), po) - require.NoError(t, err) - fastPostHash, err := fastPost.StateHash() - require.NoError(t, err) - require.Equal(t, fastPostHash, slowPostHash, "fast VM produced different state than slow VM") + if expectedErr != nil { + require.ErrorAs(t, err, expectedErr) + } else { + require.NoError(t, err) + fastPostHash, err := fastPost.StateHash() + require.NoError(t, err) + require.Equal(t, fastPostHash, slowPostHash, "fast VM produced different state than slow VM") + } + +} + +func errCodeToByte32(errCode uint64) []byte { + return binary.BigEndian.AppendUint64(make([]byte, 24), errCode) } func TestStateSyscallUnsupported(t *testing.T) { + contracts := testContracts(t) + addrs := testAddrs syscalls := []int{ riscv.SysPrlimit64, riscv.SysFutex, @@ -68,11 +79,14 @@ func TestStateSyscallUnsupported(t *testing.T) { state.Memory.SetUnaligned(pc, syscallInsn) fastState := fast.NewInstrumentedState(state, nil, os.Stdout, os.Stderr) - _, err := fastState.Step(true) - var syscallErr *fast.UnsupportedSyscallErr - require.ErrorAs(t, err, &syscallErr) + stepWitness, err := fastState.Step(true) + var fastSyscallErr *fast.UnsupportedSyscallErr + require.ErrorAs(t, err, &fastSyscallErr) + + runEVM(t, contracts, addrs, stepWitness, nil, errCodeToByte32(riscv.ErrInvalidSyscall)) - // TODO: Test EVM & slow VM + var slowSyscallErr *slow.UnsupportedSyscallErr + runSlow(t, stepWitness, nil, nil, &slowSyscallErr) }) } } @@ -114,8 +128,8 @@ func 
FuzzStateSyscallExit(f *testing.F) { require.Equal(t, preStateRegisters, state.Registers) fastPost := state.EncodeWitness() - runEVM(t, contracts, addrs, stepWitness, fastPost) - runSlow(t, stepWitness, fastPost, nil) + runEVM(t, contracts, addrs, stepWitness, fastPost, nil) + runSlow(t, stepWitness, fastPost, nil, nil) } f.Fuzz(func(t *testing.T, exitCode uint8, pc uint64, step uint64) { @@ -162,8 +176,8 @@ func FuzzStateSyscallBrk(f *testing.F) { require.Equal(t, expectedRegisters, state.Registers) fastPost := state.EncodeWitness() - runEVM(t, contracts, addrs, stepWitness, fastPost) - runSlow(t, stepWitness, fastPost, nil) + runEVM(t, contracts, addrs, stepWitness, fastPost, nil) + runSlow(t, stepWitness, fastPost, nil, nil) }) } @@ -216,8 +230,8 @@ func FuzzStateSyscallMmap(f *testing.F) { require.Equal(t, newHeap, state.Heap) fastPost := state.EncodeWitness() - runEVM(t, contracts, addrs, stepWitness, fastPost) - runSlow(t, stepWitness, fastPost, nil) + runEVM(t, contracts, addrs, stepWitness, fastPost, nil) + runSlow(t, stepWitness, fastPost, nil, nil) }) } @@ -258,8 +272,8 @@ func FuzzStateSyscallFcntl(f *testing.F) { require.Equal(t, expectedRegisters, state.Registers) fastPost := state.EncodeWitness() - runEVM(t, contracts, addrs, stepWitness, fastPost) - runSlow(t, stepWitness, fastPost, nil) + runEVM(t, contracts, addrs, stepWitness, fastPost, nil) + runSlow(t, stepWitness, fastPost, nil, nil) } f.Fuzz(func(t *testing.T, fd uint64, cmd uint64, pc uint64, step uint64) { @@ -321,8 +335,8 @@ func FuzzStateSyscallOpenat(f *testing.F) { require.Equal(t, expectedRegisters, state.Registers) fastPost := state.EncodeWitness() - runEVM(t, contracts, addrs, stepWitness, fastPost) - runSlow(t, stepWitness, fastPost, nil) + runEVM(t, contracts, addrs, stepWitness, fastPost, nil) + runSlow(t, stepWitness, fastPost, nil, nil) }) } @@ -369,8 +383,8 @@ func FuzzStateSyscallClockGettime(f *testing.F) { require.Equal(t, expectedRegisters, state.Registers) fastPost := 
state.EncodeWitness() - runEVM(t, contracts, addrs, stepWitness, fastPost) - runSlow(t, stepWitness, fastPost, nil) + runEVM(t, contracts, addrs, stepWitness, fastPost, nil) + runSlow(t, stepWitness, fastPost, nil, nil) }) } @@ -411,8 +425,8 @@ func FuzzStateSyscallClone(f *testing.F) { require.Equal(t, expectedRegisters, state.Registers) fastPost := state.EncodeWitness() - runEVM(t, contracts, addrs, stepWitness, fastPost) - runSlow(t, stepWitness, fastPost, nil) + runEVM(t, contracts, addrs, stepWitness, fastPost, nil) + runSlow(t, stepWitness, fastPost, nil, nil) }) } @@ -460,8 +474,8 @@ func FuzzStateSyscallGetrlimit(f *testing.F) { require.Equal(t, expectedRegisters, state.Registers) fastPost := state.EncodeWitness() - runEVM(t, contracts, addrs, stepWitness, fastPost) - runSlow(t, stepWitness, fastPost, nil) + runEVM(t, contracts, addrs, stepWitness, fastPost, nil) + runSlow(t, stepWitness, fastPost, nil, nil) } testGetrlimitErr := func(t *testing.T, res, addr, pc, step uint64) { @@ -480,9 +494,14 @@ func FuzzStateSyscallGetrlimit(f *testing.F) { state.Memory.SetUnaligned(pc, syscallInsn) fastState := fast.NewInstrumentedState(state, nil, os.Stdout, os.Stderr) - _, err := fastState.Step(true) - require.Contains(t, err.Error(), "f0012") - // TODO: Test EVM & slow VM + stepWitness, err := fastState.Step(true) + var fastSyscallErr *fast.UnrecognizedResourceErr + require.ErrorAs(t, err, &fastSyscallErr) + + runEVM(t, contracts, addrs, stepWitness, nil, errCodeToByte32(riscv.ErrUnrecognizedResource)) + + var slowSyscallErr *slow.UnrecognizedResourceErr + runSlow(t, stepWitness, nil, nil, &slowSyscallErr) } f.Fuzz(func(t *testing.T, res, addr, pc, step uint64) { @@ -553,8 +572,8 @@ func FuzzStateSyscallNoop(f *testing.F) { require.Equal(t, expectedRegisters, state.Registers) fastPost := state.EncodeWitness() - runEVM(t, contracts, addrs, stepWitness, fastPost) - runSlow(t, stepWitness, fastPost, nil) + runEVM(t, contracts, addrs, stepWitness, fastPost, nil) + 
runSlow(t, stepWitness, fastPost, nil, nil) } f.Fuzz(func(t *testing.T, arg uint64, pc uint64, step uint64) { @@ -601,8 +620,8 @@ func FuzzStateSyscallRead(f *testing.F) { require.Equal(t, expectedRegisters, state.Registers) fastPost := state.EncodeWitness() - runEVM(t, contracts, addrs, stepWitness, fastPost) - runSlow(t, stepWitness, fastPost, nil) + runEVM(t, contracts, addrs, stepWitness, fastPost, nil) + runSlow(t, stepWitness, fastPost, nil, nil) } f.Fuzz(func(t *testing.T, fd, addr, count, pc, step uint64) { @@ -665,8 +684,8 @@ func FuzzStateHintRead(f *testing.F) { require.Equal(t, expectedRegisters, state.Registers) fastPost := state.EncodeWitness() - runEVM(t, contracts, addrs, stepWitness, fastPost) - runSlow(t, stepWitness, fastPost, oracle) + runEVM(t, contracts, addrs, stepWitness, fastPost, nil) + runSlow(t, stepWitness, fastPost, oracle, nil) }) } @@ -732,8 +751,8 @@ func FuzzStatePreimageRead(f *testing.F) { require.Equal(t, expectedRegisters, state.Registers) fastPost := state.EncodeWitness() - runEVM(t, contracts, addrs, stepWitness, fastPost) - runSlow(t, stepWitness, fastPost, oracle) + runEVM(t, contracts, addrs, stepWitness, fastPost, nil) + runSlow(t, stepWitness, fastPost, oracle, nil) }) } @@ -774,8 +793,8 @@ func FuzzStateSyscallWrite(f *testing.F) { require.Equal(t, expectedRegisters, state.Registers) fastPost := state.EncodeWitness() - runEVM(t, contracts, addrs, stepWitness, fastPost) - runSlow(t, stepWitness, fastPost, nil) + runEVM(t, contracts, addrs, stepWitness, fastPost, nil) + runSlow(t, stepWitness, fastPost, nil, nil) } f.Fuzz(func(t *testing.T, fd, addr, count, pc, step uint64) { @@ -845,8 +864,8 @@ func FuzzStateHintWrite(f *testing.F) { require.Equal(t, expectedRegisters, state.Registers) fastPost := state.EncodeWitness() - runEVM(t, contracts, addrs, stepWitness, fastPost) - runSlow(t, stepWitness, fastPost, oracle) + runEVM(t, contracts, addrs, stepWitness, fastPost, nil) + runSlow(t, stepWitness, fastPost, oracle, nil) 
}) } @@ -917,7 +936,7 @@ func FuzzStatePreimageWrite(f *testing.F) { require.Equal(t, expectedKey, state.PreimageKey) fastPost := state.EncodeWitness() - runEVM(t, contracts, addrs, stepWitness, fastPost) - runSlow(t, stepWitness, fastPost, oracle) + runEVM(t, contracts, addrs, stepWitness, fastPost, nil) + runSlow(t, stepWitness, fastPost, oracle, nil) }) } diff --git a/rvgo/test/vm_go_test.go b/rvgo/test/vm_go_test.go index 0b60a4a8..d06cd6b2 100644 --- a/rvgo/test/vm_go_test.go +++ b/rvgo/test/vm_go_test.go @@ -114,7 +114,7 @@ func fullTest(t *testing.T, vmState *fast.VMState, po *testOracle, symbols fast. } if runEVM { - evmPost, evmPostHash, gasUsed := stepEVM(t, env, wit, addrs, i) + evmPost, evmPostHash, gasUsed := stepEVM(t, env, wit, addrs, i, nil) if gasUsed > maxGasUsed { maxGasUsed = gasUsed } diff --git a/rvgo/test/vm_test.go b/rvgo/test/vm_test.go index 20834dbb..71e9956d 100644 --- a/rvgo/test/vm_test.go +++ b/rvgo/test/vm_test.go @@ -121,7 +121,7 @@ func runEVMTestSuite(t *testing.T, path string) { wit, err := instState.Step(true) require.NoError(t, err) - evmPost, evmPostHash, gasUsed := stepEVM(t, env, wit, addrs, i) + evmPost, evmPostHash, gasUsed := stepEVM(t, env, wit, addrs, i, nil) if gasUsed > maxGasUsed { maxGasUsed = gasUsed }