diff --git a/Makefile b/Makefile index a450b648d..d90c6de50 100644 --- a/Makefile +++ b/Makefile @@ -12,7 +12,7 @@ fmt: ## Format the code lint: ## Check whether the code is formated correctly black . --check - mdformat . --number --check + mdformat specs/ --number --check test: ## Run tests pytest diff --git a/specs/opcode/12SLT_13SGT.md b/specs/opcode/12SLT_13SGT.md new file mode 100644 index 000000000..1521d678e --- /dev/null +++ b/specs/opcode/12SLT_13SGT.md @@ -0,0 +1,67 @@ +# SLT & SGT opcode + +## Procedure + +The `SLT` and `SGT` opcodes compare the top two values on the stack, and push the result (0 or 1) back to the stack. + +The stack inputs `a` and `b` are 256-bits signed integers with the most significant bit being the sign (1 for negative and 0 for positive). The 256-bits are in the two's complement form of representing signed integers. + +#### Circuit Behaviour + +The `SignedComparatorGadget` takes arguments `a: [u8; 32]`, `b: [u8; 32]` and `is_sgt: bool`. + +It returns the result of `a < b` where: + +- `Stack = [b, a]` if `is_sgt == false` +- `Stack = [a, b]` if `is_sgt == true` + +We basically swap the stack inputs if `is_sgt == true`, so that we only need to compare `a < b` in our gadget. + +The gadget is constructed with the following logic: + +```python +# a < 0 and b >= 0 +if a[31] >= 128 and b[31] < 128: + result = 1 +# b < 0 and a >= 0 +elif b[31] >= 128 and a[31] < 128: + result = 0 +# (a < 0 and b < 0) or (a >= 0 and b >= 0) +else: + if a_hi < b_hi: + result = 1 + elif a_hi == b_hi and a_lo < b_lo: + result = 1 + else: + result = 0 +``` + +where: + +- `a[31]` and `b[31]` represent the most significant bytes of `a` and `b` respectively. +- `a[31] >= 128` (same for `b`) signifies that `a` (same for `b`) is a negative number. +- `a_hi = a[16..32]` and `a_lo = a[0..16]` (same for `b`) with `a` (same for `b`) being represented in the little-endian form. + +## Constraints + +- `OpcodeId` check: + - opId === OpcodeId(0x12) for `SLT` + - opId === OpcodeId(0x13) for `SGT` +- State Transition: + - gc -> gc + 3 + - stack pointer -> stack pointer + 1 + - pc -> pc + 1 + - gas -> gas + 3 +- Lookups: + - `a` is at the top of the stack + - `b` is at the second position of the stack + - `result` is the new top of the stack + +## Exceptions + +1. Stack underflow: `1023 <= stack pointer <= 1024` +2. Out of gas: gas left \< 3 + +## Code + +See [`slt_sgt.py`](src/zkevm_specs/evm/execution/slt_sgt.py) diff --git a/src/zkevm_specs/evm/execution/__init__.py b/src/zkevm_specs/evm/execution/__init__.py index 429364adc..5163c12c2 100644 --- a/src/zkevm_specs/evm/execution/__init__.py +++ b/src/zkevm_specs/evm/execution/__init__.py @@ -13,6 +13,7 @@ from .push import * from .block_coinbase import * from .caller import * +from .slt_sgt import * EXECUTION_STATE_IMPL: Dict[ExecutionState, Callable] = { @@ -25,4 +26,5 @@ ExecutionState.JUMP: jump, ExecutionState.JUMPI: jumpi, ExecutionState.PUSH: push, + ExecutionState.SCMP: scmp, } diff --git a/src/zkevm_specs/evm/execution/slt_sgt.py b/src/zkevm_specs/evm/execution/slt_sgt.py new file mode 100644 index 000000000..58f28badc --- /dev/null +++ b/src/zkevm_specs/evm/execution/slt_sgt.py @@ -0,0 +1,51 @@ +from typing import Sequence, Tuple + +from ..instruction import Instruction, Transition +from ..opcode import Opcode + + +def scmp(instruction: Instruction): + opcode = instruction.opcode_lookup(True) + + is_sgt, _ = instruction.pair_select(opcode, Opcode.SGT, Opcode.SLT) + + a = instruction.stack_pop() + b = instruction.stack_pop() + c = instruction.stack_push() + + # swap a and b if the opcode is SGT + aa = b if is_sgt else a + bb = a if is_sgt else b + + # decode RLC to bytes for a and b + a8s = instruction.rlc_to_le_bytes(aa) + b8s = instruction.rlc_to_le_bytes(bb) + c8s = instruction.rlc_to_le_bytes(c) + + a_lo = int.from_bytes(a8s[:16], "little") + a_hi = int.from_bytes(a8s[16:], "little") + b_lo = int.from_bytes(b8s[:16], "little") + b_hi = int.from_bytes(b8s[16:], "little") + cc = int.from_bytes(c8s, "little") + + a_lt_b_lo, a_eq_b_lo = instruction.compare(a_lo, b_lo, 16) + a_lt_b_hi, a_eq_b_hi = instruction.compare(a_hi, b_hi, 16) + + a_lt_b = instruction.select(a_lt_b_hi, 1, instruction.select(a_eq_b_hi * a_lt_b_lo, 1, 0)) + + # a < 0 and b >= 0 => a < b == true + if a8s[31] >= 128 and b8s[31] < 128: + instruction.constrain_equal(cc, 1) + # b < 0 and a >= 0 => a < b == false + elif b8s[31] >= 128 and a8s[31] < 128: + instruction.constrain_equal(cc, 0) + # (a < 0 and b < 0) or (a >= 0 and b >= 0) + else: + instruction.constrain_equal(cc, a_lt_b) + + instruction.step_state_transition_in_same_context( + opcode, + rw_counter=Transition.delta(3), + program_counter=Transition.delta(1), + stack_pointer=Transition.delta(1), + ) diff --git a/tests/common.py b/tests/common.py index 6e700b676..5fad7994c 100644 --- a/tests/common.py +++ b/tests/common.py @@ -18,4 +18,6 @@ (65536, 65536), ((1 << 256) - 1, (1 << 256) - 2), ((1 << 256) - 2, (1 << 256) - 1), + ((1 << 256) - 1, 0), + (0, (1 << 256) - 1), ) diff --git a/tests/evm/test_slt_sgt.py b/tests/evm/test_slt_sgt.py new file mode 100644 index 000000000..7d26ccb5c --- /dev/null +++ b/tests/evm/test_slt_sgt.py @@ -0,0 +1,238 @@ +import pytest + +from zkevm_specs.evm import ( + ExecutionState, + StepState, + Opcode, + verify_steps, + Tables, + RWTableTag, + RW, + Block, + Bytecode, +) +from zkevm_specs.util import rand_fp, rand_word, RLC + +RAND_1 = rand_word() + +RAND_2 = rand_word() + +TESTING_DATA = ( + # a >= 0 and b >= 0 + ( + Opcode.SLT, + 0x00, + 0x01, + 0x01, + ), + ( + Opcode.SGT, + 0x00, + 0x01, + 0x00, + ), + ( + Opcode.SLT, + 0x01, + 0x00, + 0x00, + ), + ( + Opcode.SGT, + 0x01, + 0x00, + 0x01, + ), + # a < 0 and b >= 0 + ( + Opcode.SLT, + 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF, + 0x00, + 0x01, + ), + ( + Opcode.SGT, + 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF, + 0x00, + 0x00, + ), + ( + Opcode.SLT, + 0x00, + 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF, + 0x00, + ), + ( + Opcode.SGT, + 0x00, + 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF, + 0x01, + ), + # a < 0 and b < 0 + ( + Opcode.SLT, + 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFE, + 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF, + 0x01, + ), + ( + Opcode.SGT, + 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFE, + 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF, + 0x00, + ), + ( + Opcode.SLT, + 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF, + 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFE, + 0x00, + ), + ( + Opcode.SGT, + 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF, + 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFE, + 0x01, + ), + # a_hi == b_hi and a_lo < b_lo and a < 0 and b < 0 + ( + Opcode.SLT, + 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF11111111111111111111111111111111, + 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF22222222222222222222222222222222, + 0x01, + ), + ( + Opcode.SGT, + 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF11111111111111111111111111111111, + 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF22222222222222222222222222222222, + 0x00, + ), + ( + Opcode.SLT, + 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF22222222222222222222222222222222, + 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF11111111111111111111111111111111, + 0x00, + ), + ( + Opcode.SGT, + 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF22222222222222222222222222222222, + 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF11111111111111111111111111111111, + 0x01, + ), + # a_hi == b_hi and a_lo < b_lo and a >= 0 and b >= 0 + ( + Opcode.SLT, + 0x1111111111111111111111111111111144444444444444444444444444444443, + 0x1111111111111111111111111111111144444444444444444444444444444444, + 0x01, + ), + ( + Opcode.SGT, + 0x1111111111111111111111111111111144444444444444444444444444444443, + 0x1111111111111111111111111111111144444444444444444444444444444444, + 0x00, + ), + ( + Opcode.SLT, + 0x1111111111111111111111111111111144444444444444444444444444444444, + 0x1111111111111111111111111111111144444444444444444444444444444443, + 0x00, + ), + ( + Opcode.SGT, + 0x1111111111111111111111111111111144444444444444444444444444444444, + 0x1111111111111111111111111111111144444444444444444444444444444443, + 0x01, + ), + # both equal + ( + Opcode.SLT, + RAND_1, + RAND_1, + 0x00, + ), + ( + Opcode.SGT, + RAND_2, + RAND_2, + 0x00, + ), + # more cases where contiguous bytes are different + ( + Opcode.SLT, + 0x1234567812345678123456781234567812345678123456781234567812345678, + 0x2345678123456781234567812345678123456781234567812345678123456781, + 0x01, + ), + ( + Opcode.SGT, + 0x1234567812345678123456781234567812345678123456781234567812345678, + 0x2345678123456781234567812345678123456781234567812345678123456781, + 0x00, + ), + ( + Opcode.SLT, + 0x2345678123456781234567812345678123456781234567812345678123456781, + 0x1234567812345678123456781234567812345678123456781234567812345678, + 0x00, + ), + ( + Opcode.SGT, + 0x2345678123456781234567812345678123456781234567812345678123456781, + 0x1234567812345678123456781234567812345678123456781234567812345678, + 0x01, + ), +) + + +@pytest.mark.parametrize("opcode, a, b, res", TESTING_DATA) +def test_slt_sgt(opcode: Opcode, a: int, b: int, res: int): + randomness = rand_fp() + + a = RLC(a, randomness) + b = RLC(b, randomness) + res = RLC(res, randomness) + + bytecode = Bytecode().slt(a, b) if opcode == Opcode.SLT else Bytecode().sgt(a, b) + bytecode_hash = RLC(bytecode.hash(), randomness) + + tables = Tables( + block_table=set(Block().table_assignments(randomness)), + tx_table=set(), + bytecode_table=set(bytecode.table_assignments(randomness)), + rw_table=set( + [ + (9, RW.Read, RWTableTag.Stack, 1, 1022, a, 0, 0), + (10, RW.Read, RWTableTag.Stack, 1, 1023, b, 0, 0), + (11, RW.Write, RWTableTag.Stack, 1, 1023, res, 0, 0), + ] + ), + ) + + verify_steps( + randomness=randomness, + tables=tables, + steps=[ + StepState( + execution_state=ExecutionState.SCMP, + rw_counter=9, + call_id=1, + is_root=True, + is_create=False, + code_source=bytecode_hash, + program_counter=66, + stack_pointer=1022, + gas_left=3, + ), + StepState( + execution_state=ExecutionState.STOP, + rw_counter=12, + call_id=1, + is_root=True, + is_create=False, + code_source=bytecode_hash, + program_counter=67, + stack_pointer=1023, + gas_left=0, + ), + ], + )