diff --git a/src/main.etk b/src/main.etk index bec086d..37c9e10 100644 --- a/src/main.etk +++ b/src/main.etk @@ -39,10 +39,9 @@ 0xfffffffffffffffffffffffffffffffffffffffe %end -# get_input loads the first word of input from calldata. %macro get_input() - %push0() # [0] - calldataload # [calldata[0..32]] + push1 4 # [4] + calldataload # [calldata[4:36]] %end # get_timestamp_index calculates the index a timestamp should be stored at. @@ -52,10 +51,13 @@ mod # [timestamp % rootmod] %end -%macro revert() - %push0() # [0] - %push0() # [0, 0] - revert # [] +# revert_if_neq reverts if the top two stack arguments are not equal +# stack: [a, b] (assumed precondition) +%macro revert_if_neq() + eq # [a == b] + iszero # [a != b] + push1 revert # [revert, a != b] + jumpi # [] %end # ----------------------------------------------------------------------------- @@ -64,7 +66,7 @@ # Protect the submit routine by verifying the caller is equal to sysaddr(). caller # [caller] -push20 sysaddr() # [sysaddr, caller] +push20 sysaddr() # [sysaddr, caller] eq # [sysaddr == caller] push1 submit # [submitaddr, sysaddr == caller] jumpi # [] @@ -72,19 +74,18 @@ jumpi # [] # Fallthrough if addresses don't match -- this means the caller intends to read # a root. -# ----------------------------------------------------------------------------- -# READ ROOT ------------------------------------------------------------------- -# ----------------------------------------------------------------------------- - -# Check if calldata is equal to 32 bytes. -push1 32 # [32] -calldatasize # [calldatasize, 32] -eq # [calldatasize == 32] +# Check if calldata is 36 bytes +push1 36 # [36] +calldatasize # [calldatasize, 36] +%revert_if_neq() -# Jump to continue if length-check passed, otherwise revert. -push1 load # [load_addr, calldatasize == 32] -jumpi # [] -%revert() # [] +# Check if calldata has the function signature +push1 0 # [0] +calldataload # [calldataload] +push1 224 # [224, calldataload] +shr # [calldataload >> 224] +push4 selector("get(uint256)") # [function_sig, calldataload >> 224] +%revert_if_neq() # Load stored timestamp. load: @@ -96,11 +97,8 @@ dup1 # [time_index, time_index] sload # [timestamp, time_index] # Verify stored timestamp matches input. -%get_input() # [input, timestamp, time_index] -eq # [input == timestamp, time_index] -iszero # [input != timestamp, time_index] -push1 exit # [exit_addr, got != want, time_index] -jumpi # [time_index] +%get_input() # [got, want, time_index] +%revert_if_neq() # Extend index to get root index. push3 rootmod() # [rootmod, time_index] @@ -142,3 +140,17 @@ stop # [] # ----------------------------------------------------------------------------- # SUBMIT END ------------------------------------------------------------------ # ----------------------------------------------------------------------------- + +########## +# REVERT # +########## + +revert: +jumpdest +%push0() +%push0() +revert + +############## +# REVERT END # +############## \ No newline at end of file diff --git a/test/Contract.t.sol.in b/test/Contract.t.sol.in index 8f62118..96dee48 100644 --- a/test/Contract.t.sol.in +++ b/test/Contract.t.sol.in @@ -33,12 +33,23 @@ contract ContractTest is Test { vm.store(unit, timestamp_idx(), timestamp()); vm.store(unit, root_idx(), root); - (bool ret, bytes memory data) = unit.call(bytes.concat(timestamp())); + (bool ret, bytes memory data) = unit.call(bytes.concat(bytes4(0x9507d39a), bytes32(uint256(block.timestamp)))); assertTrue(ret); assertEq(data, bytes.concat(root)); } + function testReadWrongSig() public { + assertTrue(unit != address(0)); + + vm.store(unit, timestamp_idx(), timestamp()); + vm.store(unit, root_idx(), root); + + (bool ret, /*bytes memory data*/) = unit.call(bytes.concat(bytes4(0x01020304), bytes32(uint256(block.timestamp)))); + + assertFalse(ret); + } + function testUpdate() public { vm.prank(sysaddr); (bool ret, bytes memory data) = unit.call(bytes.concat(root));