diff --git a/specs/opcode/02MUL_04DIV_06MOD.md b/specs/opcode/02MUL_04DIV_06MOD_1bSHL_1cSHR.md similarity index 54% rename from specs/opcode/02MUL_04DIV_06MOD.md rename to specs/opcode/02MUL_04DIV_06MOD_1bSHL_1cSHR.md index 5f119c363..aa9e45afd 100644 --- a/specs/opcode/02MUL_04DIV_06MOD.md +++ b/specs/opcode/02MUL_04DIV_06MOD_1bSHL_1cSHR.md @@ -1,4 +1,4 @@ -# MUL, DIV, and MOD opcodes +# MUL, DIV, MOD, SHL and SHR opcodes ## Procedure @@ -9,10 +9,12 @@ Pop two EVM words `a` and `b` from the stack, and push `c` to the stack, where ` - for opcode `MUL`, compute `c = (a * b) % 2^256` - for opcode `DIV`, compute `c = a // b` when `b != 0` otherwise `c = 0` - for opcode `MOD`, compute `c = a mod b` when `b != 0` otherwise `c = 0` +- for opcode `SHL`, `a` is the number of bits to shift to the left, compute `c = (b * 2^a) % 2^256` when `a < 256` otherwise `c = 0` +- for opcode `SHR`, `a` is the number of bits to shift to the right, compute `c = b // 2^a` when `a < 256` otherwise `c = 0` ### Circuit behavior -To prove the `MUL/DIV/MOD` opcode, we first construct a `MulAddWordsGadget` that proves `a * b + c = d (mod 2^256)` where `a, b, c, d` are all 256-bit words. +To prove the `MUL/DIV/MOD/SHL/SHR` opcode, we first construct a `MulAddWordsGadget` that proves `quotient * divisor + remainder = dividend (mod 2^256)` where `quotient, divisor, remainder, dividend` are all 256-bit words. Here `quotient, divisor, remainder, dividend` correspond to `a, b, c, d` in `MulAddWordsGadget`. As usual, we use 32 cells to represent each word shown as the table below, where each cell holds a 8-bit value. @@ -60,20 +62,20 @@ $$ overflow = carry_{hi} + A_1B_3 + A_2B_2 + A_3B_1 + A_2B_3 + A_3B_2 + A_3B_3 $$ -Now back to the opcode circuit for `MUL`, `DIV`, and `MOD`, we first construct -the `MulAddWordsGadget` with four EVM words `a, b, c, d`. 
+Now back to the opcode circuit for `MUL`, `DIV`, `MOD`, `SHL` and `SHR`, we first construct the `MulAddWordsGadget` with four EVM words `quotient, divisor, remainder, dividend`. Based on different opcode cases, we constrain the stack pops and pushes as follows -- for `MUL`, two stack pops are `a` and `b`, and the stack push is `d` -- for `DIV`, two stack pops are `d` and `b`, and the stack push is `a` if `b != 0`; otherwise 0. -- for `MOD`, two stack pops are `d` and `b`, and the stack push is `c` if `b != 0`; otherwise 0. +- for `MUL`, two stack pops are `quotient` and `divisor`, and the stack push is `dividend`. +- for `DIV`, two stack pops are `dividend` and `divisor`, and the stack push is `quotient` if `divisor != 0` and 0 otherwise. +- for `MOD`, two stack pops are `dividend` and `divisor`, and the stack push is `remainder` if `divisor != 0` and 0 otherwise. +- for `SHL`, two stack pops are `shift` and `quotient`, where `divisor = 2^shift` if `shift < 256` and 0 otherwise. The stack push is `dividend` if `shift < 256` and 0 otherwise. +- for `SHR`, two stack pops are `shift` and `dividend`, where `divisor = 2^shift` if `shift < 256` and 0 otherwise. The stack push is `quotient` if `shift < 256` and 0 otherwise. The opcode circuit also adds extra constraints for different opcodes: -- if the opcode is `MUL`, constrain `c == 0`. -- if the opcode is not `MUL`, - - use a `LtWordGadget` to constrain `c < b` when `b != 0` - - constrain `overflow == 0` +- use a `LtWordGadget` to constrain `remainder < divisor` when `divisor != 0`. +- if the opcode is `MUL` or `SHL`, constrain `remainder == 0`. +- if the opcode is `DIV`, `MOD` or `SHR`, constrain `overflow == 0`. ## Constraints @@ -81,21 +83,26 @@ The opcode circuit also adds extra constraints for different opcodes: 1. opId === OpcodeId(0x02) for `MUL` 2. opId === OpcodeId(0x04) for `DIV` 3. opId === OpcodeId(0x06) for `MOD` + 4. opId === OpcodeId(0x1b) for `SHL` + 5. opId === OpcodeId(0x1c) for `SHR` 2. 
state transition: - gc + 3 - stack_pointer + 1 - pc + 1 - - gas + 5 + - gas + - when opcode is `MUL`, `DIV` or `MOD`, gas + 5. + - when opcode is `SHL` or `SHR`, gas + 3. 3. Lookups: 3 busmapping lookups - - top of the stack : - - when it's `MUL`, `a` is at the top of the stack - - when it's `DIV`, `d` is at the top of the stack. - - when it's `MOD`, `d` is at the top of the stack. - - `b` is at the second position of the stack + - top of the stack + - when opcode is `MUL`, `quotient` is at the top of the stack; when opcode is `SHL` or `SHR`, `shift` is at the top of the stack, where `divisor = 2^shift` if `shift < 256` and 0 otherwise. + - when opcode is `DIV` or `MOD`, `dividend` is at the top of the stack. + - second position of the stack + - when opcode is `MUL`, `DIV` or `MOD`, `divisor` is at the second position of the stack. + - when opcode is `SHL`, `quotient` is at the second position of the stack; when opcode is `SHR`, `dividend` is at the second position of the stack. - new top of the stack - - when it's `MUL`, `d` is at the new top of the stack - - when it's `DIV`, `a` is at the new top of the stack when `b != 0`, otherwise 0 - - when it's `MOD`, `c` is at the new top of the stack when `b != 0`, otherwise 0 + - when opcode is `MUL` or `SHL`, `dividend` is at the new top of the stack. + - when opcode is `DIV` or `SHR`, `quotient` is at the new top of the stack if `divisor != 0`, otherwise 0. + - when opcode is `MOD`, `remainder` is at the new top of the stack if `divisor != 0`, otherwise 0. ## Exceptions @@ -104,4 +111,4 @@ The opcode circuit also adds extra constraints for different opcodes: ## Code -See `src/zkevm_specs/evm/execution/mul_div_mod.py` +See `src/zkevm_specs/evm/execution/mul_div_mod_shl_shr.py` diff --git a/specs/opcode/09MULMOD.md b/specs/opcode/09MULMOD.md new file mode 100644 index 000000000..e0e0d0f9e --- /dev/null +++ b/specs/opcode/09MULMOD.md @@ -0,0 +1,83 @@ +# MULMOD opcode + +## Procedure + +### EVM behavior + + +Pop 3 EVM words `a`, `b` and `N` from the stack. + +If `N` is 0: + push 0 into the stack. +else: + compute `r = (a * b) mod N` and push `r` into the stack. 
+ +*Note* +All intermediate calculations of this operation are not subject to the 2^256 modulo. + +### Circuit behavior + +The MulModGadget takes arguments: + - `a: [u8;32]` + - `b: [u8;32]`, + - `r: [u8;32]`, + - `N: [u8;32]`, +and keeps 5 words for storing: + - `a_reduced: [u8;32]` , + - `k: [u8;32]`, + - `e: [u8;32]`, + - `d: [u8;32]` + - `zero: [u8;32]`. + + + Witness `a_reduced ← a mod N` if `N != 0` else `a_reduced ← 0` + Witness `(e, d) ← ( (a_reduced * b) % 2^256, (a_reduced * b ) // 2^256)` + Witness `(r, k) ← ( (a_reduced * b) % N, (a_reduced * b ) // N)` + + 1. Check `a_reduced = a mod N`. + which uses `ModGadget` that in turn checks: + - Check the equality ` j * N + a_reduced == a ` + - Check (`a_reduced = 0` and `N == 0`) or `a_reduced < N` + + 2. Check `r = a * b mod N` + which uses 2 `MulAddWords512Gadget` to check: + ` a_reduced * b = k * N + r` in 2 steps + - `a_reduced * b + zero == d * 2^256 + e` + - `k * N + r == d * 2^256 + e` + + 1 `IsZeroGadget` and 1 `LtWordsGadget` that check: + `(r == 0 and N == 0)` or `(r < N)` + + +#### Note + +The first step of the computation, reducing `a` mod `N`, is taken in order +to prevent overflow in the factor `k`, which could happen for high values +of `a` and `b` and low `N`. This step ensures: + +$$ +k \leq \frac{(a \bmod N) \cdot b}{N} \leq \frac{(N-1) \cdot b}{N} < b \leq MAXU256 +$$ + +## Constraints + +1. opcodeID checks + opId == OpcodeId(0x09) +2. state transition: + - gc + 4 + - stack_pointer + 2 + - pc + 1 + - gas + 8 +3. Lookups: 4 busmapping lookups + - `a` is on top of the stack. + - `b` is in the second position of the stack. + - `N` is in the third position of the stack. + - `r`, the result, is on top of the new stack. + + +## Exceptions + +1. stack underflow: `1022 <= stack_pointer <= 1024`. +2. out of gas: Remaining gas is not enough. 
+ +See `src/zkevm_specs/evm/execution/mulmod.py` diff --git a/specs/opcode/1cSHR.md b/specs/opcode/1cSHR.md deleted file mode 100644 index 6faad0496..000000000 --- a/specs/opcode/1cSHR.md +++ /dev/null @@ -1,134 +0,0 @@ -# SHR opcode - -## Procedure - -The `SHR` opcode shifts the bits towards the least significant one. The bits moved before the first one are discarded, the new bits are set to 0. - -### EVM behavior - -Pop two EVM words `a` and `shift` from the stack, and push `b` to the stack, where `b` is computed as: - -1. If `shift >= 256`,then `b` is set to zero. -2. If `shift < 256`,compute `b = a >> shift`. - -### Circuit behavior - -To prove the `SHR` opcode, we first construct a `ShrGadget` that proves `a >> shift = b` where `a, b, shift` are all 256-bit words. -As usual, we use 32 cells to represent word `a` and `b`, where each cell holds a 8-bit value. Then split each word into four 64-bit limbs denoted by `a64s[idx]` and `b64s[idx]` where idx in `(0, 1, 2, 3)`. -We put the lower `n` bits of a limb into the `lo` array, and put the higher `64 - n` bits into the `hi` array, where `n` is `shift % 64`. During the SHR operation, the `lo` array will move to higher bits of the result, and the `hi` array will move to lower bits of the result. - -The following figure illustrates how shift right works under the case of `shift < 64`. - -``` -+-------------------------------+-------------------------------+----- -|a0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10| 11| 12| 13| 14| 15| ... -+-------------------------------+-------------------------------+----- -| a64s[0] | a64s[1] | ... -+------------+------------------+------------+------------------+----- -| a64s_lo[0] | a64s_hi[0] | a64s_lo[1] | a64s_hi[1] | ... 
-+------------+------------------+------------+------------------+----- - | b64s[0] | b64s[1] - +-------------------------------+------------------------ -``` - -More formally, the variables are defined as follows: - -``` -shf0 = bytes_to_fq(shift.le_bytes[:1]) -shf_div64 = shift // 64 -shf_mod64 = shift % 64 -shf_lt256 = is_zero(sum(shift[1:])) -p_lo = 1 << shf_mod64 -p_hi = 1 << (64 - shf_mod64) -a64s = word_to_64s(a) -a64s_lo[idx] = a64s[idx] % p_lo -a64s_hi[idx] = a64s[idx] / p_lo -``` - -If `shift >= 256`, `b64s` are all 0. Otherwise, `b64s` can be calculated by `a >> shf0` then split into four 64-bit limbs. - -Now putting things together, the constraints can be constructed as follows: - -1. `a64s` and `b64s` constraints: - -* First calculate `shf_lt256` as `is_zero(sum(shift[1:]))`. -* `a64s[idx]`: It should be equal to `from_bytes(a[8 * idx : 8 * (idx + 1)])` where idx in `(0, 1, 2, 3)`. -* `b64s[idx] * shf_lt256`: It should be equal to `from_bytes(b[8 * idx : 8 * (idx + 1)])` where idx in `(0, 1, 2, 3)`. - -2. `a64s_lo` and `a64s_hi` constraints: - -* `a64s[idx]`: It should be equal to `a64s_lo[idx] + a64s_hi[idx] * p_lo`. -* `a64s_lo[idx]`: It should always be less than `p_lo` (`a64s_lo[idx] < p_lo`). - -3. 
Merge constraints: - -* First create three `IsZero` gadgets: -``` -shf_div64_eq0 = is_zero(shf_div64) -shf_div64_eq1 = is_zero(shf_div64 - 1) -shf_div64_eq2 = is_zero(shf_div64 - 2) -``` - -* `b64s[0]` should be equal to: -``` -(a64s_hi[0] + a64s_lo[1] * p_hi) * shf_div64_eq0 + - (a64s_hi[1] + a64s_lo[2] * p_hi) * shf_div64_eq1 + - (a64s_hi[2] + a64s_lo[3] * p_hi) * shf_div64_eq2 + - a64s_hi[3] * (1 - shf_div64_eq0 - shf_div64_eq1 - shf_div64_eq2) -``` - -* `b64s[1]` should be equal to: -``` -(a64s_hi[1] + a64s_lo[2] * p_hi) * shf_div64_eq0 + - (a64s_hi[2] + a64s_lo[3] * p_hi) * shf_div64_eq1 + - a64s_hi[3] * shf_div64_eq2 -``` - -* `b64s[2]` should be equal to: -``` -(a64s_hi[2] + a64s_lo[3] * p_hi) * shf_div64_eq0 + - a64s_hi[3] * shf_div64_eq1 -``` - -* `b64s[3]` should be equal to: -``` -a64s_hi[3] * shf_div64_eq0 -``` - -4. `shift[0]` constraint: - -* `shift[0]`: It should be equal to `shf_mod64 + shf_div64 * 64`. - -5. `Pow2` table look up: - -* First build `Pow2` table by tuple $[value, value\_pow]$ which meets $${value\_pow == 2^{value}}$$ - -* Look up for `(shf_mod64, p_lo)` and `(64 - shf_mod64, p_hi)` - -6. Stack pop and push: - -* Pop word `a` -* Pop word `shift` -* Push word `shift_lt256 * b` - -## Constraints - -1. opId = OpcodeId(0x1c) -2. state transition: - - gc + 3 (2 stack reads + 1 stack write) - - stack\_pointer + 1 - - pc + 1 - - gas + 3 -3. lookups: 3 busmapping lookups - - `a` is at the top of the stack - - `shift` is at the second position of the stack - - `b`, the result, is at the new top of the stack - -## Exceptions - -1. stack underflow: `1023 <= stack_pointer <= 1024` -2. 
out of gas: remaining gas is not enough - -## Code - -See `src/zkevm_specs/evm/execution/shr.py` diff --git a/src/zkevm_specs/evm/execution/__init__.py b/src/zkevm_specs/evm/execution/__init__.py index 8039bd3ca..f19612c79 100644 --- a/src/zkevm_specs/evm/execution/__init__.py +++ b/src/zkevm_specs/evm/execution/__init__.py @@ -12,6 +12,7 @@ # Opcode's successful cases from .add_sub import * from .addmod import * +from .mulmod import * from .block_ctx import * from .call import * from .calldatasize import * @@ -25,7 +26,6 @@ from .iszero import * from .jump import * from .jumpi import * -from .mul_div_mod import * from .origin import * from .push import * from .slt_sgt import * @@ -35,7 +35,7 @@ from .selfbalance import * from .extcodehash import * from .log import * -from .shr import * +from .mul_div_mod_shl_shr import mul_div_mod_shl_shr EXECUTION_STATE_IMPL: Dict[ExecutionState, Callable] = { @@ -46,7 +46,7 @@ ExecutionState.CopyToMemory: copy_to_memory, ExecutionState.ADD: add_sub, ExecutionState.ADDMOD: addmod, - ExecutionState.MUL: mul_div_mod, + ExecutionState.MULMOD: mulmod, ExecutionState.ORIGIN: origin, ExecutionState.CALLER: caller, ExecutionState.CALLVALUE: callvalue, @@ -70,5 +70,5 @@ ExecutionState.LOG: log, ExecutionState.CALL: call, ExecutionState.ISZERO: iszero, - ExecutionState.SHR: shr, + ExecutionState.MUL_DIV_MOD_SHL_SHR: mul_div_mod_shl_shr, } diff --git a/src/zkevm_specs/evm/execution/mul_div_mod.py b/src/zkevm_specs/evm/execution/mul_div_mod.py deleted file mode 100644 index 7f0abee7e..000000000 --- a/src/zkevm_specs/evm/execution/mul_div_mod.py +++ /dev/null @@ -1,71 +0,0 @@ -from ..instruction import Instruction, Transition -from ..opcode import Opcode -from ...util import FQ - - -def mul_div_mod(instruction: Instruction): - opcode = instruction.opcode_lookup(True) - - # The opcode value for MUL, DIV and MOD is 2, 4, 6. When the opcode is MUL, - # (Opcode.DIV - opcode) * (Opcode.MOD - opcode) is 8. 
To make `is_mul` be - # either 0 or 1, we need to divide the product by 8, which is equivalent to - # multiply it by inversion of 8. Similarly, we also need to multiply the - # inversion of 4 and 8 for `is_div` and `is_mod` respectively. - is_mul = (Opcode.DIV - opcode) * (Opcode.MOD - opcode) * FQ(8).inv() - is_div = (opcode - Opcode.MUL) * (Opcode.MOD - opcode) * FQ(4).inv() - is_mod = (opcode - Opcode.MUL) * (opcode - Opcode.DIV) * FQ(8).inv() - - pop1 = instruction.stack_pop() - pop2 = instruction.stack_pop() - push = instruction.stack_push() - - # this part corresponds to witness assignment in the zkevm circuit - if is_mul == 1: - a = pop1 - b = pop2 - c = instruction.rlc_encode(0, 32) - d = push - elif is_div == 1: - d = pop1 # dividend - b = pop2 # divisor - a = push # quotient - c = instruction.rlc_encode(d.int_value - b.int_value * a.int_value, 32) # remainder - else: # is_mod == 1 - d = pop1 # dividend - b = pop2 # divisor - if b.int_value == 0: - c = d - a = instruction.rlc_encode(0, 32) - else: - c = push - a = instruction.rlc_encode((d.int_value - c.int_value) // b.int_value, 32) - - divisor_is_zero = instruction.word_is_zero(b) - overflow = instruction.mul_add_words(a, b, c, d) - - # constrain the push and pop values - instruction.constrain_equal(pop1, instruction.select(is_mul, a, d)) - instruction.constrain_equal(pop2, b) - instruction.constrain_equal( - push, - is_mul * d.expr() - + is_div * a.expr() * (1 - divisor_is_zero) - + is_mod * c.expr() * (1 - divisor_is_zero), - ) - - # constrain c == 0 for MUL - instruction.constrain_zero(is_mul * instruction.sum(c.le_bytes)) - - # constrain remainder < divisor when divisor != 0 for DIV and MOD - lt, _ = instruction.compare_word(c, b) - instruction.constrain_zero((1 - is_mul) * (1 - divisor_is_zero) * (1 - lt)) - - # constrain overflow == 0 for DIV and MOD - instruction.constrain_zero((1 - is_mul) * overflow) - - instruction.step_state_transition_in_same_context( - opcode, - 
rw_counter=Transition.delta(3), - program_counter=Transition.delta(1), - stack_pointer=Transition.delta(1), - ) diff --git a/src/zkevm_specs/evm/execution/mul_div_mod_shl_shr.py b/src/zkevm_specs/evm/execution/mul_div_mod_shl_shr.py new file mode 100644 index 000000000..f99ce9120 --- /dev/null +++ b/src/zkevm_specs/evm/execution/mul_div_mod_shl_shr.py @@ -0,0 +1,232 @@ +from ...util import FQ, RLC +from ..instruction import Instruction, Transition +from ..opcode import Opcode +from typing import Tuple + + +def mul_div_mod_shl_shr(instruction: Instruction): + opcode = instruction.opcode_lookup(True) + + pop1 = instruction.stack_pop() + pop2 = instruction.stack_pop() + push = instruction.stack_push() + + ( + is_mul, + is_div, + is_mod, + is_shl, + is_shr, + shf0, + dividend, + divisor, + quotient, + remainder, + shift, + ) = gen_witness(opcode, pop1, pop2, push) + check_witness( + instruction, + is_mul, + is_div, + is_mod, + is_shl, + is_shr, + shf0, + dividend, + divisor, + quotient, + remainder, + shift, + pop1, + pop2, + push, + ) + + instruction.step_state_transition_in_same_context( + opcode, + rw_counter=Transition.delta(3), + program_counter=Transition.delta(1), + stack_pointer=Transition.delta(1), + ) + + +def check_witness( + instruction: Instruction, + is_mul: FQ, + is_div: FQ, + is_mod: FQ, + is_shl: FQ, + is_shr: FQ, + shf0: FQ, + dividend: RLC, + divisor: RLC, + quotient: RLC, + remainder: RLC, + shift: RLC, + pop1: RLC, + pop2: RLC, + push: RLC, +): + divisor_is_zero = instruction.word_is_zero(divisor) + + # Based on different opcode cases, constrain stack pops and pushes as: + # - for `MUL`, two pops are quotient and divisor, and push is dividend. + # - for `DIV`, two pops are dividend and divisor, and push is quotient. + # - for `MOD`, two pops are dividend and divisor, and push is remainder. + # - for `SHL`, two pops are shift and quotient, and push is dividend. + # - for `SHR`, two pops are shift and dividend, and push is quotient. 
+ instruction.constrain_equal( + pop1.expr(), + is_mul * quotient.expr() + + (is_div + is_mod) * dividend.expr() + + (is_shl + is_shr) * shift.expr(), + ) + instruction.constrain_equal( + pop2.expr(), + (is_mul + is_div + is_mod) * divisor.expr() + + is_shl * quotient.expr() + + is_shr * dividend.expr(), + ) + instruction.constrain_equal( + push.expr(), + (is_mul + is_shl) * dividend.expr() + + (is_div + is_shr) * quotient.expr() * (1 - divisor_is_zero) + + is_mod * remainder.expr() * (1 - divisor_is_zero), + ) + + # Constrain remainder < divisor when divisor != 0. + divisor_is_zero = instruction.word_is_zero(divisor) + remainder_lt_divisor, _ = instruction.compare_word(remainder, divisor) + instruction.constrain_zero((1 - divisor_is_zero) * (1 - remainder_lt_divisor)) + + # Constrain remainder == 0 for both MUL and SHL. + remainder_is_zero = instruction.word_is_zero(remainder) + instruction.constrain_zero((is_mul + is_shl) * (1 - remainder_is_zero)) + + # Constrain overflow == 0 for DIV, MOD and SHR. + overflow = instruction.mul_add_words(quotient, divisor, remainder, dividend) + instruction.constrain_zero((is_div + is_mod + is_shr) * overflow) + + # Constrain pop1 == pop1.le_bytes[0] when divisor != 0 for opcode SHL and SHR. + instruction.constrain_zero( + (is_shl + is_shr) * (1 - divisor_is_zero) * (pop1.expr() - pop1.le_bytes[0]), + ) + + # For opcode SHL and SHR, constrain `divisor_lo == 2^shf0` when + # `shf0 < 128`, and `divisor_hi == 2^(shf0 - 128)` otherwise. 
+ divisor_lo = instruction.bytes_to_fq(divisor.le_bytes[:16]) + divisor_hi = instruction.bytes_to_fq(divisor.le_bytes[16:]) + if (is_shl + is_shr) * (1 - divisor_is_zero) == 1: + instruction.pow2_lookup(shf0, divisor_lo, divisor_hi) + + +def gen_witness(opcode: FQ, pop1: RLC, pop2: RLC, push: RLC): + is_mul = is_op_mul(opcode) + is_div = is_op_div(opcode) + is_mod = is_op_mod(opcode) + is_shl = is_op_shl(opcode) + is_shr = is_op_shr(opcode) + + # Get the first byte of shift value only for opcode SHL and SHR. + shf0 = pop1.le_bytes[0] + + if is_mul.n == 1: + quotient = pop1 + divisor = pop2 + remainder = RLC(0) + dividend = push + shift = RLC(0) + elif is_div.n == 1: + quotient = push + divisor = pop2 + remainder = RLC(pop1.int_value - push.int_value * pop2.int_value) + dividend = pop1 + shift = RLC(0) + elif is_mod.n == 1: + quotient = RLC(0) if pop2.int_value == 0 else RLC(pop1.int_value // pop2.int_value) + divisor = pop2 + remainder = pop1 if pop2.int_value == 0 else push + dividend = pop1 + shift = RLC(0) + elif is_shl.n == 1: + divisor = RLC(1 << shf0) if shf0 == pop1.int_value else RLC(0) + quotient = pop2 + remainder = RLC(0) + dividend = push + shift = pop1 + else: # SHR + divisor = RLC(1 << shf0) if shf0 == pop1.int_value else RLC(0) + quotient = push + remainder = RLC(pop2.int_value - push.int_value * divisor.int_value) + dividend = pop2 + shift = pop1 + + return ( + is_mul, + is_div, + is_mod, + is_shl, + is_shr, + shf0, + dividend, + divisor, + quotient, + remainder, + shift, + ) + + +# The opcode value for MUL, DIV, MOD, SHL and SHR are 2, 4, 6, 0x1b and 0x1c. +# When the opcode is MUL, the result of below formula is 5200: +# (DIV - opcode) * (MOD- opcode) * (SHL - opcode) * (SHR - opcode) +# To make `is_mul` be either 0 or 1, the result needs to be divided by 5200, +# which is equivalent to multiply it by inversion of 5200. +# And calculate `is_div`, `is_mod`, `is_shl` and `is_shr` respectively. 
+def is_op_mul(opcode: FQ) -> FQ: + return ( + (Opcode.DIV - opcode) + * (Opcode.MOD - opcode) + * (Opcode.SHL - opcode) + * (Opcode.SHR - opcode) + * FQ(5200).inv() + ) + + +def is_op_div(opcode: FQ) -> FQ: + return ( + (opcode - Opcode.MUL) + * (Opcode.MOD - opcode) + * (Opcode.SHL - opcode) + * (Opcode.SHR - opcode) + * FQ(2208).inv() + ) + + +def is_op_mod(opcode: FQ) -> FQ: + return ( + (opcode - Opcode.MUL) + * (opcode - Opcode.DIV) + * (Opcode.SHL - opcode) + * (Opcode.SHR - opcode) + * FQ(3696).inv() + ) + + +def is_op_shl(opcode: FQ) -> FQ: + return ( + (opcode - Opcode.MUL) + * (opcode - Opcode.DIV) + * (opcode - Opcode.MOD) + * (Opcode.SHR - opcode) + * FQ(12075).inv() + ) + + +def is_op_shr(opcode: FQ) -> FQ: + return ( + (opcode - Opcode.MUL) + * (opcode - Opcode.DIV) + * (opcode - Opcode.MOD) + * (opcode - Opcode.SHL) + * FQ(13728).inv() + ) diff --git a/src/zkevm_specs/evm/execution/mulmod.py b/src/zkevm_specs/evm/execution/mulmod.py new file mode 100644 index 000000000..24630c2db --- /dev/null +++ b/src/zkevm_specs/evm/execution/mulmod.py @@ -0,0 +1,73 @@ +from ..instruction import Instruction, Transition +from ..opcode import Opcode +from zkevm_specs.util import FQ, RLC + + +def mod(instruction: Instruction, a: RLC, n: RLC, r: RLC): + """ + The function constraints r = a mod n, where a, n, r a re 256-bit words. 
+ This in turn constraints: + - k * n + r = a if n != 0 + - r = 0 if n == 0 + """ + if n.int_value == 0: + a_or_zero = RLC(0) + k = 0 + else: + a_or_zero = a + k = a.int_value // n.int_value + + instruction.mul_add_words(RLC(k), n, r, a_or_zero) + eq = instruction.is_equal(a, a_or_zero) + cmp = instruction.compare_word(r, n) + n_is_zero = instruction.is_zero(n) + a_or_is_zero = instruction.is_zero(a_or_zero) + # a_or_zero = a if n!=0 else a_or_zero = 0 + instruction.constrain_zero((FQ(1) - eq) * (FQ(1) - n_is_zero * a_or_is_zero)) + # r> shf0.n - b64s = [FQ((bb >> 64 * i) & 0xFFFFFFFFFFFFFFFF) for i in range(4)] - - return ( - a64s, - b64s, - a64s_lo, - a64s_hi, - shf_div64, - shf_mod64, - p_lo, - p_hi, - ) diff --git a/src/zkevm_specs/evm/execution_state.py b/src/zkevm_specs/evm/execution_state.py index 37ecea387..79ce2044a 100644 --- a/src/zkevm_specs/evm/execution_state.py +++ b/src/zkevm_specs/evm/execution_state.py @@ -26,7 +26,7 @@ class ExecutionState(IntEnum): # Opcode's successful cases STOP = auto() ADD = auto() # ADD, SUB - MUL = auto() # MUL, DIV, MOD + MUL_DIV_MOD_SHL_SHR = auto() # MUL, DIV, MOD, SHL, SHR SDIV = auto() SMOD = auto() ADDMOD = auto() @@ -39,8 +39,6 @@ class ExecutionState(IntEnum): BITWISE = auto() # AND, OR, XOR NOT = auto() BYTE = auto() - SHL = auto() - SHR = auto() SAR = auto() SHA3 = auto() ADDRESS = auto() @@ -146,8 +144,8 @@ def responsible_opcode(self) -> Union[Sequence[int], Sequence[Tuple[int, int]]]: Opcode.ADD, Opcode.SUB, ] - elif self == ExecutionState.MUL: - return [Opcode.MUL, Opcode.DIV, Opcode.MOD] + elif self == ExecutionState.MUL_DIV_MOD_SHL_SHR: + return [Opcode.MUL, Opcode.DIV, Opcode.MOD, Opcode.SHL, Opcode.SHR] elif self == ExecutionState.SDIV: return [Opcode.SDIV] elif self == ExecutionState.SMOD: @@ -183,10 +181,6 @@ def responsible_opcode(self) -> Union[Sequence[int], Sequence[Tuple[int, int]]]: return [Opcode.NOT] elif self == ExecutionState.BYTE: return [Opcode.BYTE] - elif self == ExecutionState.SHL: - 
return [Opcode.SHL] - elif self == ExecutionState.SHR: - return [Opcode.SHR] elif self == ExecutionState.SAR: return [Opcode.SAR] elif self == ExecutionState.SHA3: diff --git a/src/zkevm_specs/evm/instruction.py b/src/zkevm_specs/evm/instruction.py index 75b62d839..f13d87ef3 100644 --- a/src/zkevm_specs/evm/instruction.py +++ b/src/zkevm_specs/evm/instruction.py @@ -321,39 +321,6 @@ def rlc_to_fq(self, word: RLC, n_bytes: int) -> FQ: raise ConstraintUnsatFailure(f"Word {word} has too many bytes to fit {n_bytes} bytes") return self.bytes_to_fq(word.le_bytes[:n_bytes]) - def mul_add_words_512(self, a: RLC, b: RLC, c: RLC, d: RLC, e: RLC): - """ - The function constrains a * b + c == d * 2**256 + e, where a, b, c, d are 256-bit words. - """ - a64s = self.word_to_64s(a) - b64s = self.word_to_64s(b) - c_lo, c_hi = self.word_to_lo_hi(c) - d_lo, d_hi = self.word_to_lo_hi(d) - e_lo, e_hi = self.word_to_lo_hi(e) - - t0 = a64s[0] * b64s[0] - t1 = a64s[0] * b64s[1] + a64s[1] * b64s[0] - t2 = a64s[0] * b64s[2] + a64s[1] * b64s[1] + a64s[2] * b64s[0] - t3 = a64s[0] * b64s[3] + a64s[1] * b64s[2] + a64s[2] * b64s[1] + a64s[3] * b64s[0] - - t4 = a64s[1] * b64s[3] + a64s[2] * b64s[2] + a64s[3] * b64s[1] - t5 = a64s[2] * b64s[3] + a64s[3] * b64s[2] - t6 = a64s[3] * b64s[3] - - carry_0 = (t0 + t1 * (2**64) + c_lo - e_lo) / (2**128) - carry_1 = (t2 + t3 * (2**64) + c_hi + carry_0 - e_hi) / (2**128) - carry_2 = (t4 + t5 * (2**64) + carry_1 - d_lo) / (2**128) - - # range check for carries - self.range_check(carry_0, 9) - self.range_check(carry_1, 9) - self.range_check(carry_2, 9) - - self.constrain_equal(t0 + t1 * (2**64) + c_lo, e_lo + carry_0 * (2**128)) - self.constrain_equal(t2 + t3 * (2**64) + c_hi + carry_0, e_hi + carry_1 * (2**128)) - self.constrain_equal(t4 + t5 * (2**64) + carry_1, d_lo + carry_2 * (2**128)) - self.constrain_equal(t6 + carry_2, d_hi) - def word_is_zero(self, word: RLC) -> FQ: assert len(word.le_bytes) == 32, "Expected word to contain 32 bytes" return 
self.is_zero(self.sum(word.le_bytes)) @@ -471,6 +438,39 @@ def mul_add_words(self, a: RLC, b: RLC, c: RLC, d: RLC) -> FQ: return overflow + def mul_add_words_512(self, a: RLC, b: RLC, c: RLC, d: RLC, e: RLC): + """ + The function constrains a * b + c == d * 2**256 + e, where a, b, c, d are 256-bit words. + """ + a64s = self.word_to_64s(a) + b64s = self.word_to_64s(b) + c_lo, c_hi = self.word_to_lo_hi(c) + d_lo, d_hi = self.word_to_lo_hi(d) + e_lo, e_hi = self.word_to_lo_hi(e) + + t0 = a64s[0] * b64s[0] + t1 = a64s[0] * b64s[1] + a64s[1] * b64s[0] + t2 = a64s[0] * b64s[2] + a64s[1] * b64s[1] + a64s[2] * b64s[0] + t3 = a64s[0] * b64s[3] + a64s[1] * b64s[2] + a64s[2] * b64s[1] + a64s[3] * b64s[0] + + t4 = a64s[1] * b64s[3] + a64s[2] * b64s[2] + a64s[3] * b64s[1] + t5 = a64s[2] * b64s[3] + a64s[3] * b64s[2] + t6 = a64s[3] * b64s[3] + + carry_0 = (t0 + t1 * (2**64) + c_lo - e_lo) / (2**128) + carry_1 = (t2 + t3 * (2**64) + c_hi + carry_0 - e_hi) / (2**128) + carry_2 = (t4 + t5 * (2**64) + carry_1 - d_lo) / (2**128) + + # range check for carries + self.range_check(carry_0, 9) + self.range_check(carry_1, 9) + self.range_check(carry_2, 9) + + self.constrain_equal(t0 + t1 * (2**64) + c_lo, e_lo + carry_0 * (2**128)) + self.constrain_equal(t2 + t3 * (2**64) + c_hi + carry_0, e_hi + carry_1 * (2**128)) + self.constrain_equal(t4 + t5 * (2**64) + carry_1, d_lo + carry_2 * (2**128)) + self.constrain_equal(t6 + carry_2, d_hi) + def fixed_lookup( self, tag: FixedTableTag, @@ -877,5 +877,5 @@ def memory_copier_gas_cost( self.range_check(gas_cost, N_BYTES_GAS) return gas_cost - def pow2_lookup(self, value: Expression, value_pow: Expression): - self.fixed_lookup(FixedTableTag.Pow2, value, value_pow) + def pow2_lookup(self, value: Expression, pow_lo128: Expression, pow_hi128: Expression): + self.fixed_lookup(FixedTableTag.Pow2, value, pow_lo128, pow_hi128) diff --git a/src/zkevm_specs/evm/table.py b/src/zkevm_specs/evm/table.py index 7a5b06732..44ab3f23c 100644 --- 
a/src/zkevm_specs/evm/table.py +++ b/src/zkevm_specs/evm/table.py @@ -70,7 +70,15 @@ def table_assignments(self) -> List[FixedTableRow]: ) ] elif self == FixedTableTag.Pow2: - return [FixedTableRow(FQ(self), FQ(value), FQ(1 << value)) for value in range(65)] + return [ + FixedTableRow( + FQ(self), + FQ(value), + FQ(1 << value) if value < 128 else FQ(0), + FQ(0) if value < 128 else FQ(1 << (value - 128)), + ) + for value in range(256) + ] else: raise ValueError("Unreacheable") diff --git a/tests/evm/test_mul_div_mod.py b/tests/evm/test_mul_div_mod_shl_shr.py similarity index 52% rename from tests/evm/test_mul_div_mod.py rename to tests/evm/test_mul_div_mod_shl_shr.py index 42429be54..b1d508327 100644 --- a/tests/evm/test_mul_div_mod.py +++ b/tests/evm/test_mul_div_mod_shl_shr.py @@ -1,6 +1,5 @@ import pytest -from typing import Optional from zkevm_specs.evm import ( ExecutionState, StepState, @@ -15,55 +14,79 @@ from common import generate_nasty_tests +MAX_WORD = (1 << 256) - 1 + TESTING_DATA = [ - (Opcode.MUL, 0x030201, 0x060504), - ( - Opcode.MUL, - 3402823669209384634633746074317682114560, - 34028236692093846346337460743176821145600, - ), - ( - Opcode.MUL, - 3402823669209384634633746074317682114560, - 34028236692093846346337460743176821145500, - ), - (Opcode.DIV, 0xFFFFFF, 0xABC), - (Opcode.DIV, 0xABC, 0xFFFFFF), - (Opcode.DIV, 0xFFFFFF, 0xFFFFFFF), - (Opcode.DIV, 0xABC, 0), - (Opcode.MOD, 0xFFFFFF, 0xABC), - (Opcode.MOD, 0xABC, 0xFFFFFF), - (Opcode.MOD, 0xFFFFFF, 0xFFFFFFF), - (Opcode.MOD, 0xABC, 0), + (Opcode.MUL, 0xABCD, 0x1234), + (Opcode.MUL, 0xABCD, 0x1234 << 240), + (Opcode.MUL, 0xABCD << 240, 0x1234 << 240), + (Opcode.MUL, MAX_WORD, 0x1234), + (Opcode.MUL, MAX_WORD, 0), + (Opcode.DIV, 0xABCD, 0x1234), + (Opcode.DIV, 0xABCD, 0x1234 << 240), + (Opcode.DIV, 0xABCD << 240, 0x1234 << 240), + (Opcode.DIV, MAX_WORD, 0x1234), + (Opcode.DIV, MAX_WORD, 0), + (Opcode.MOD, 0xABCD, 0x1234), + (Opcode.MOD, 0xABCD, 0x1234 << 240), + (Opcode.MOD, 0xABCD << 240, 0x1234 << 
240), + (Opcode.MOD, MAX_WORD, 0x1234), + (Opcode.MOD, MAX_WORD, 0), + (Opcode.SHL, 8, 0xABCD << 240), + (Opcode.SHL, 7, 0x1234 << 240), + (Opcode.SHL, 17, 0x8765 << 240), + (Opcode.SHL, 0, 0x4321 << 240), + (Opcode.SHL, 256, 0xFFFF), + (Opcode.SHL, 256 + 8 + 1, 0x12345), + (Opcode.SHL, 63, MAX_WORD), + (Opcode.SHL, 128, MAX_WORD), + (Opcode.SHL, 129, MAX_WORD), + (Opcode.SHR, 8, 0xABCD), + (Opcode.SHR, 7, 0x1234), + (Opcode.SHR, 17, 0x8765), + (Opcode.SHR, 0, 0x4321), + (Opcode.SHR, 256, 0xFFFF), + (Opcode.SHR, 256 + 8 + 1, 0x12345), + (Opcode.SHR, 63, (1 << 256) - 1), + (Opcode.SHR, 128, (1 << 256) - 1), + (Opcode.SHR, 129, (1 << 256) - 1), (Opcode.MUL, rand_word(), rand_word()), (Opcode.DIV, rand_word(), rand_word()), (Opcode.MOD, rand_word(), rand_word()), + (Opcode.SHL, rand_word(), rand_word()), + (Opcode.SHR, rand_word(), rand_word()), ] -generate_nasty_tests(TESTING_DATA, (Opcode.MUL, Opcode.DIV, Opcode.MOD)) +generate_nasty_tests(TESTING_DATA, (Opcode.MUL, Opcode.DIV, Opcode.MOD, Opcode.SHL, Opcode.SHR)) @pytest.mark.parametrize("opcode, a, b", TESTING_DATA) -def test_mul_div_mod(opcode: Opcode, a: int, b: int): - randomness = rand_fq() - +def test_mul_div_mod_shl_shr(opcode: Opcode, a: int, b: int): if opcode == Opcode.MUL: - c = a * b % 2**256 + c = a * b & MAX_WORD + bytecode = Bytecode().mul(a, b) + used_gas = 5 elif opcode == Opcode.DIV: c = 0 if b == 0 else a // b - else: # Opcode.MOD + bytecode = Bytecode().div(a, b) + used_gas = 5 + elif opcode == Opcode.MOD: c = 0 if b == 0 else a % b + bytecode = Bytecode().mod(a, b) + used_gas = 5 + elif opcode == Opcode.SHL: + c = b << a & MAX_WORD if a <= 255 else 0 + bytecode = Bytecode().shl(a, b) + used_gas = 3 + else: # SHR + c = b >> a if a <= 255 else 0 + bytecode = Bytecode().shr(a, b) + used_gas = 3 + randomness = rand_fq() a = RLC(a, randomness) b = RLC(b, randomness) c = RLC(c, randomness) - - if opcode == Opcode.MUL: - bytecode = Bytecode().mul(a, b) - elif opcode == Opcode.DIV: - bytecode = 
Bytecode().div(a, b) - else: - bytecode = Bytecode().mod(a, b) bytecode_hash = RLC(bytecode.hash(), randomness) tables = Tables( @@ -84,7 +107,7 @@ def test_mul_div_mod(opcode: Opcode, a: int, b: int): tables=tables, steps=[ StepState( - execution_state=ExecutionState.MUL, + execution_state=ExecutionState.MUL_DIV_MOD_SHL_SHR, rw_counter=9, call_id=1, is_root=True, @@ -92,7 +115,7 @@ def test_mul_div_mod(opcode: Opcode, a: int, b: int): code_hash=bytecode_hash, program_counter=66, stack_pointer=1022, - gas_left=5, + gas_left=used_gas, ), StepState( execution_state=ExecutionState.STOP, diff --git a/tests/evm/test_shr.py b/tests/evm/test_mulmod.py similarity index 52% rename from tests/evm/test_shr.py rename to tests/evm/test_mulmod.py index 758032dca..ca1c0ec27 100644 --- a/tests/evm/test_shr.py +++ b/tests/evm/test_mulmod.py @@ -1,5 +1,6 @@ import pytest +from typing import Optional from zkevm_specs.evm import ( ExecutionState, StepState, @@ -9,38 +10,37 @@ Bytecode, RWDictionary, ) -from zkevm_specs.util import ( - rand_fq, - rand_range, - rand_word, - RLC, - U256, -) +from zkevm_specs.util import rand_fq, RLC +MAXU256 = (2**256) - 1 -TESTING_DATA = ( - (0xABCD, 8), - (0x1234, 7), - (0x8765, 17), - (0x4321, 0), - (0xFFFF, 256), - (0x12345, 256 + 8 + 1), - ((1 << 256) - 1, 63), - ((1 << 256) - 1, 128), - ((1 << 256) - 1, 129), -) +TESTING_DATA = [ + (1, 1, 2), + (1, 1, 0), + (0, 2, 3), + (MAXU256, MAXU256, MAXU256), + (MAXU256, MAXU256, 1), + (MAXU256, 1, MAXU256), + (MAXU256, 2, 2), + (0, 0, 0), +] -@pytest.mark.parametrize("value, shift", TESTING_DATA) -def test_shr(value: U256, shift: int): - result = value >> shift if shift <= 255 else 0 +@pytest.mark.parametrize("a, b, n", TESTING_DATA) +def test_mulmod(a: int, b: int, n: int): randomness = rand_fq() - value = RLC(value, randomness) - shift = RLC(shift, randomness) - result = RLC(result, randomness) - bytecode = Bytecode().push32(value).push32(shift).shr().stop() + if n == 0: + r = RLC(0, randomness) + else: + 
r = RLC((a * b) % n, randomness) + + a = RLC(a, randomness) + b = RLC(b, randomness) + n = RLC(n, randomness) + + bytecode = Bytecode().mulmod(a, b, n).stop() bytecode_hash = RLC(bytecode.hash(), randomness) tables = Tables( @@ -49,9 +49,10 @@ def test_shr(value: U256, shift: int): bytecode_table=set(bytecode.table_assignments(randomness)), rw_table=set( RWDictionary(9) - .stack_read(1, 1022, value) - .stack_read(1, 1023, shift) - .stack_write(1, 1023, result) + .stack_read(1, 1021, a) + .stack_read(1, 1022, b) + .stack_read(1, 1023, n) + .stack_write(1, 1023, r) .rws ), ) @@ -61,24 +62,24 @@ def test_shr(value: U256, shift: int): tables=tables, steps=[ StepState( - execution_state=ExecutionState.SHR, + execution_state=ExecutionState.MULMOD, rw_counter=9, call_id=1, is_root=True, is_create=False, code_hash=bytecode_hash, - program_counter=66, - stack_pointer=1022, - gas_left=3, + program_counter=99, + stack_pointer=1021, + gas_left=8, ), StepState( execution_state=ExecutionState.STOP, - rw_counter=11, + rw_counter=13, call_id=1, is_root=True, is_create=False, code_hash=bytecode_hash, - program_counter=67, + program_counter=100, stack_pointer=1023, gas_left=0, ),