134 changes: 134 additions & 0 deletions specs/opcode/1bSHL.md
@@ -0,0 +1,134 @@
# SHL opcode

## Procedure

The `SHL` opcode shifts the bits of a word towards the most significant bit. Bits shifted past the 256th position are discarded, and the vacated low bits are set to 0.

### EVM behavior

Pop two EVM words `a` and `shift` from the stack, and push `b` to the stack, where `b` is computed as:

1. If `shift >= 256`, then `b` is set to zero.
2. If `shift < 256`, compute `b = a << shift`, keeping only the low 256 bits.
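
As a rough illustration of the two cases above, the EVM-level behavior can be sketched on plain Python integers. The names `evm_shl`, `WORD_BITS` and `WORD_MASK` are hypothetical and are not part of the spec code.

```python
WORD_BITS = 256
WORD_MASK = (1 << WORD_BITS) - 1


def evm_shl(a: int, shift: int) -> int:
    """Reference SHL on plain integers (hypothetical helper, not spec code)."""
    if shift >= WORD_BITS:
        # Every bit of `a` is shifted out of the word.
        return 0
    # Shift left and discard the bits that moved past the 256th position.
    return (a << shift) & WORD_MASK
```

For example, `evm_shl(1, 255)` sets only the most significant bit, while `evm_shl(1, 256)` is zero.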

### Circuit behavior

To prove the `SHL` opcode, we first construct a `ShlGadget` that proves `a << shift == b`, where `a`, `b` and `shift` are all 256-bit words.
As usual, we use 32 cells to represent each of the words `a` and `b`, where each cell holds an 8-bit value. We then split each word into four 64-bit limbs, denoted `a64s[idx]` and `b64s[idx]` with `idx` in `(0, 1, 2, 3)`.
We put the lower `64 - n` bits of each limb into the `lo` array and the higher `n` bits into the `hi` array, where `n` is `shift % 64`. During the SHL operation, each `lo` part moves to the higher bits of a result limb, and each `hi` part moves to the lower bits of the next result limb.

The following figure illustrates how shift left works when `shift < 64`.

```
------+-------------------------------+-------------------------------+------
|a0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10| 11| 12| 13| 14| 15| ...
------+-------------------------------+-------------------------------+------
| a64s[0] | a64s[1] | ...
------+------------+------------------+------------+------------------+------
| a64s_lo[0] | a64s_hi[0] | a64s_lo[1] | a64s_hi[1] | ...
------+------------+------------------+------------+------------------+------
b64s[0] | b64s[1] | b64s[2]
------+-------------------------------+------------+-------------------------
```

More formally, the variables are defined as follows:

```
shf0 = bytes_to_fq(shift.le_bytes[:1])
shf_div64 = shf0 // 64
shf_mod64 = shf0 % 64
shf_lt256 = is_zero(sum(shift[1:]))
p_lo = 1 << (64 - shf_mod64)
p_hi = 1 << shf_mod64
a64s = word_to_64s(a)
a64s_lo[idx] = a64s[idx] % p_lo
a64s_hi[idx] = a64s[idx] // p_lo
```

If `shift >= 256`, the bytes of `b` are all 0. Otherwise, `b64s` is obtained by computing `a << shf0` and splitting the low 256 bits of the result into four 64-bit limbs.
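
The decomposition can be sketched on plain integers as follows. This is only an illustration under the assumption that field elements can be modelled as Python `int`s; `split_limbs` and `MASK64` are hypothetical names. Note that `p_lo` equals `2**64` when `shf_mod64` is zero, in which case every `hi` part is 0.

```python
MASK64 = (1 << 64) - 1


def split_limbs(a: int, shift: int):
    """Split `a` into four 64-bit limbs and each limb into lo/hi parts."""
    shf0 = shift & 0xFF  # lowest byte of the shift word
    shf_mod64 = shf0 % 64
    p_lo = 1 << (64 - shf_mod64)  # 2**64 when shf_mod64 == 0
    p_hi = 1 << shf_mod64
    # Little-endian limbs: a64s[0] holds the least significant 64 bits.
    a64s = [(a >> (64 * i)) & MASK64 for i in range(4)]
    a64s_lo = [limb % p_lo for limb in a64s]
    a64s_hi = [limb // p_lo for limb in a64s]
    return a64s, a64s_lo, a64s_hi, p_lo, p_hi
```

With `a = 0x0123456789ABCDEF` and `shift = 8`, the lowest limb splits into `lo = 0x23456789ABCDEF` (56 bits) and `hi = 0x01` (8 bits).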

Now putting things together, the constraints can be constructed as follows:

1. `a64s` and `b64s` constraints:

* First calculate `shf_lt256` as `is_zero(sum(shift[1:]))`.
* `a64s[idx]`: It should be equal to `from_bytes(a[8 * idx : 8 * (idx + 1)])` where idx in `(0, 1, 2, 3)`.
* `b64s[idx] * shf_lt256`: It should be equal to `from_bytes(b[8 * idx : 8 * (idx + 1)])` where idx in `(0, 1, 2, 3)`.

2. `a64s_lo` and `a64s_hi` constraints:

* `a64s[idx]`: It should be equal to `a64s_lo[idx] + a64s_hi[idx] * p_lo`.
* `a64s_hi[idx]`: It should always be less than `p_hi` (`a64s_hi[idx] < p_hi`).

3. Merge constraints:

* First create three `IsZero` gadgets:
```
shf_div64_eq0 = is_zero(shf_div64)
shf_div64_eq1 = is_zero(shf_div64 - 1)
shf_div64_eq2 = is_zero(shf_div64 - 2)
```

* `b64s[0]` should be equal to:
```
shf_div64_eq0 * a64s_lo[0] * p_hi
```

* `b64s[1]` should be equal to:
```
shf_div64_eq0 * (a64s_hi[0] + a64s_lo[1] * p_hi) +
shf_div64_eq1 * a64s_lo[0] * p_hi
```

* `b64s[2]` should be equal to:
```
shf_div64_eq0 * (a64s_hi[1] + a64s_lo[2] * p_hi) +
shf_div64_eq1 * (a64s_hi[0] + a64s_lo[1] * p_hi) +
shf_div64_eq2 * a64s_lo[0] * p_hi
```

* `b64s[3]` should be equal to:
```
shf_div64_eq0 * (a64s_hi[2] + a64s_lo[3] * p_hi) +
shf_div64_eq1 * (a64s_hi[1] + a64s_lo[2] * p_hi) +
shf_div64_eq2 * (a64s_hi[0] + a64s_lo[1] * p_hi) +
(1 - shf_div64_eq0 - shf_div64_eq1 - shf_div64_eq2) * a64s_lo[0] * p_hi
```

4. `shift[0]` constraint:

* `shift[0]`: It should be equal to `shf_mod64 + shf_div64 * 64`.

5. `Pow2` table lookup:

* First build the `Pow2` table from tuples `(value, value_pow)` that satisfy `value_pow == pow(2, value)`, for `value` from 0 to 64.

* Look up `(64 - shf_mod64, p_lo)` and `(shf_mod64, p_hi)`.

6. Stack pop and push:

* Pop word `a`
* Pop word `shift`
* Push word `shf_lt256 * b`
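
The merge constraints above can be sanity-checked on plain integers. The sketch below is not the spec implementation: it assumes field elements behave like Python `int`s, models the `Pow2` table as a dictionary, and replaces the `IsZero` gadgets with integer comparisons. The names `merged_limbs`, `expected_limbs` and `POW2_TABLE` are hypothetical.

```python
MASK64 = (1 << 64) - 1
# Stand-in for the fixed Pow2 table: value -> 2**value for value in 0..64.
POW2_TABLE = {value: 1 << value for value in range(65)}


def merged_limbs(a: int, shf0: int):
    """Rebuild b64s from the lo/hi parts using the merge formulas (shf0 < 256)."""
    shf_div64, shf_mod64 = divmod(shf0, 64)
    p_lo = POW2_TABLE[64 - shf_mod64]
    p_hi = POW2_TABLE[shf_mod64]
    a64s = [(a >> (64 * i)) & MASK64 for i in range(4)]
    lo = [limb % p_lo for limb in a64s]
    hi = [limb // p_lo for limb in a64s]

    # Integer stand-ins for the three IsZero gadgets.
    eq0, eq1, eq2 = (int(shf_div64 == k) for k in range(3))
    eq3 = 1 - eq0 - eq1 - eq2
    return [
        eq0 * lo[0] * p_hi,
        eq0 * (hi[0] + lo[1] * p_hi) + eq1 * lo[0] * p_hi,
        eq0 * (hi[1] + lo[2] * p_hi)
        + eq1 * (hi[0] + lo[1] * p_hi)
        + eq2 * lo[0] * p_hi,
        eq0 * (hi[2] + lo[3] * p_hi)
        + eq1 * (hi[1] + lo[2] * p_hi)
        + eq2 * (hi[0] + lo[1] * p_hi)
        + eq3 * lo[0] * p_hi,
    ]


def expected_limbs(a: int, shf0: int):
    """Limbs of the low 256 bits of `a << shf0`, computed directly."""
    b = (a << shf0) & ((1 << 256) - 1)
    return [(b >> (64 * i)) & MASK64 for i in range(4)]
```

Comparing `merged_limbs(a, shf0)` with `expected_limbs(a, shf0)` for every `shf0` below 256 is a quick way to convince oneself that the four formulas cover all values of `shf_div64`.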

## Constraints

1. opId = OpcodeId(0x1b)
2. state transition:
- gc + 3 (2 stack reads + 1 stack write)
- stack\_pointer + 1
- pc + 1
- gas + 3
3. lookups: 3 busmapping lookups
- `a` is at the top of the stack
- `shift` is at the second position of the stack
- `b`, the result, is at the new top of the stack

## Exceptions

1. stack underflow: `1023 <= stack_pointer <= 1024`
2. out of gas: remaining gas is not enough
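
The underflow range can be read as "fewer than two items on the stack". A minimal sketch, assuming the stack pointer is 1024 for an empty stack and decreases by one per pushed item (which is what the range above implies); `shl_stack_underflow` and `STACK_CAPACITY` are hypothetical names:

```python
STACK_CAPACITY = 1024


def shl_stack_underflow(stack_pointer: int) -> bool:
    """True when fewer than the two operands SHL pops are on the stack."""
    # Assumption: stack_pointer == 1024 means empty; each push decrements it.
    items_on_stack = STACK_CAPACITY - stack_pointer
    return items_on_stack < 2
```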

## Code

See `src/zkevm_specs/evm/execution/shl.py`
2 changes: 2 additions & 0 deletions src/zkevm_specs/evm/execution/__init__.py
@@ -33,6 +33,7 @@
from .selfbalance import *
from .extcodehash import *
from .log import *
from .shl import *


EXECUTION_STATE_IMPL: Dict[ExecutionState, Callable] = {
@@ -65,4 +66,5 @@
    ExecutionState.LOG: log,
    ExecutionState.CALL: call,
    ExecutionState.ISZERO: iszero,
    ExecutionState.SHL: shl,
}
148 changes: 148 additions & 0 deletions src/zkevm_specs/evm/execution/shl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
from ...util import FQ, N_BYTES_U64, RLC
from ..instruction import Instruction, Transition
from ..typing import Sequence


def shl(instruction: Instruction):
    opcode = instruction.opcode_lookup(True)

    a = instruction.stack_pop()
    shift = instruction.stack_pop()
    b = instruction.stack_push()

    (
        a64s,
        b64s,
        a64s_lo,
        a64s_hi,
        shf_div64,
        shf_mod64,
        p_lo,
        p_hi,
    ) = gen_witness(instruction, a, shift)
    check_witness(
        instruction,
        a,
        shift,
        b,
        a64s,
        b64s,
        a64s_lo,
        a64s_hi,
        shf_div64,
        shf_mod64,
        p_lo,
        p_hi,
    )

    instruction.step_state_transition_in_same_context(
        opcode,
        # 2 stack reads + 1 stack write, matching `gc + 3` in the spec.
        rw_counter=Transition.delta(3),
        program_counter=Transition.delta(1),
        stack_pointer=Transition.delta(1),
    )


def check_witness(
    instruction: Instruction,
    a: RLC,
    shift: RLC,
    b: RLC,
    a64s: Sequence[FQ],
    b64s: Sequence[FQ],
    a64s_lo: Sequence[FQ],
    a64s_hi: Sequence[FQ],
    shf_div64,
    shf_mod64,
    p_lo,
    p_hi,
):
    shf_lt256 = instruction.is_zero(instruction.sum(shift.le_bytes[1:]))
    for idx in range(4):
        offset = idx * N_BYTES_U64

        # a64s constraint
        instruction.constrain_equal(
            a64s[idx],
            instruction.bytes_to_fq(a.le_bytes[offset : offset + N_BYTES_U64]),
        )

        # b64s constraint
        instruction.constrain_equal(
            b64s[idx] * shf_lt256,
            instruction.bytes_to_fq(b.le_bytes[offset : offset + N_BYTES_U64]),
        )

        # `a64s[idx] == a64s_lo[idx] + a64s_hi[idx] * p_lo`
        instruction.constrain_equal(a64s[idx], a64s_lo[idx] + a64s_hi[idx] * p_lo)

        # `a64s_hi[idx] < p_hi`
        #
        # TRICKY:
        # When `shf_mod64` is zero, `p_lo` equals `1 << 64`, which does not fit in
        # `N_BYTES_U64` (8 bytes). So instead of comparing `a64s_lo[idx]` against
        # `p_lo`, compare `a64s_hi[idx]` against `p_hi` here.
        a64s_hi_lt_p_hi, _ = instruction.compare(a64s_hi[idx], p_hi, N_BYTES_U64)
        instruction.constrain_equal(a64s_hi_lt_p_hi, FQ(1))

    # merge constraints
    shf_div64_eq0 = instruction.is_zero(shf_div64)
    shf_div64_eq1 = instruction.is_zero(shf_div64 - 1)
    shf_div64_eq2 = instruction.is_zero(shf_div64 - 2)
    instruction.constrain_equal(b64s[0], shf_div64_eq0 * a64s_lo[0] * p_hi)
    instruction.constrain_equal(
        b64s[1],
        shf_div64_eq0 * (a64s_hi[0] + a64s_lo[1] * p_hi) + shf_div64_eq1 * a64s_lo[0] * p_hi,
    )
    instruction.constrain_equal(
        b64s[2],
        shf_div64_eq0 * (a64s_hi[1] + a64s_lo[2] * p_hi)
        + shf_div64_eq1 * (a64s_hi[0] + a64s_lo[1] * p_hi)
        + shf_div64_eq2 * a64s_lo[0] * p_hi,
    )
    instruction.constrain_equal(
        b64s[3],
        shf_div64_eq0 * (a64s_hi[2] + a64s_lo[3] * p_hi)
        + shf_div64_eq1 * (a64s_hi[1] + a64s_lo[2] * p_hi)
        + shf_div64_eq2 * (a64s_hi[0] + a64s_lo[1] * p_hi)
        + (1 - shf_div64_eq0 - shf_div64_eq1 - shf_div64_eq2) * a64s_lo[0] * p_hi,
    )

    # shift constraint
    instruction.constrain_equal(
        instruction.bytes_to_fq(shift.le_bytes[:1]),
        shf_mod64 + shf_div64 * 64,
    )

    # `p_lo == pow(2, 64 - shf_mod64)` and `p_hi == pow(2, shf_mod64)`.
    instruction.pow2_lookup(64 - shf_mod64, p_lo)
    instruction.pow2_lookup(shf_mod64, p_hi)


def gen_witness(instruction: Instruction, a: RLC, shift: RLC):
    shf0 = instruction.bytes_to_fq(shift.le_bytes[:1])
    shf_div64 = FQ(shf0.n // 64)
    shf_mod64 = FQ(shf0.n % 64)
    # For SHL, the `lo` part keeps the lower `64 - shf_mod64` bits (the reverse of SHR).
    p_lo = FQ(1 << (64 - shf_mod64.n))
    p_hi = FQ(1 << shf_mod64.n)

    a64s = instruction.word_to_64s(a)
    a64s_lo = [FQ(0)] * 4
    a64s_hi = [FQ(0)] * 4
    for idx in range(4):
        a64s_lo[idx] = FQ(a64s[idx].n % p_lo.n)
        a64s_hi[idx] = FQ(a64s[idx].n // p_lo.n)

    # Only the low 256 bits of the shifted value are kept, as four 64-bit limbs.
    bb = a.int_value << shf0.n
    b64s = [FQ((bb >> 64 * i) & 0xFFFFFFFFFFFFFFFF) for i in range(4)]

    return (
        a64s,
        b64s,
        a64s_lo,
        a64s_hi,
        shf_div64,
        shf_mod64,
        p_lo,
        p_hi,
    )
3 changes: 3 additions & 0 deletions src/zkevm_specs/evm/instruction.py
@@ -817,3 +817,6 @@ def memory_copier_gas_cost(
        gas_cost = word_size * GAS_COST_COPY + memory_expansion_gas_cost
        self.range_check(gas_cost, N_BYTES_GAS)
        return gas_cost

    def pow2_lookup(self, value: Expression, value_pow: Expression):
        self.fixed_lookup(FixedTableTag.Pow2, value, value_pow)
3 changes: 3 additions & 0 deletions src/zkevm_specs/evm/table.py
@@ -26,6 +26,7 @@ class FixedTableTag(IntEnum):
    BitwiseOr = auto()  # lhs, rhs, lhs | rhs, 0
    BitwiseXor = auto()  # lhs, rhs, lhs ^ rhs, 0
    ResponsibleOpcode = auto()  # execution_state, opcode, aux
    Pow2 = auto()  # value, value_pow

    def table_assignments(self) -> List[FixedTableRow]:
        if self == FixedTableTag.Range5:
@@ -68,6 +69,8 @@ def table_assignments(self) -> List[FixedTableRow]:
execution_state.responsible_opcode(),
)
]
        elif self == FixedTableTag.Pow2:
            return [FixedTableRow(FQ(self), FQ(value), FQ(1 << value)) for value in range(65)]
        else:
            raise ValueError("Unreacheable")
