diff --git a/specs/bytecode-proof.md b/specs/bytecode-proof.md index f952a624d..091b9b25f 100644 --- a/specs/bytecode-proof.md +++ b/specs/bytecode-proof.md @@ -4,24 +4,29 @@ The bytecode proof helps the EVM proof by making the bytecode (identified by its ## Circuit Layout -| Column | Description | -| --------------------- | ------------------------------------------------------------------ | -| `q_first` | `1` on the first row, else `0` | -| `q_last` | `1` on the last row, else `0` | -| `hash` | The keccak hash of the bytecode | -| `tag` | Tag indicates whether `value` is a byte or length of the bytecode | -| `index` | The position of the byte in the bytecode | -| `value` | The byte data for the current position, or length of the bytecode | -| `is_code` | `1` if the byte is code, `0` if the byte is PUSH data | -| `push_data_left` | The number of PUSH data bytes that still follow the current row | -| `hash_rlc` | The accumulator containing the current and previous bytes | -| `hash_length` | The bytecode length | -| `byte_push_size` | The number of bytes pushed for the current byte | -| `is_final` | `1` if the current byte is the last byte of the bytecode, else `0` | -| `padding` | `1` if the current row is padding, else `0` | -| `push_data_left_inv` | The inverse of `push_data_left` (`IsZeroChip` helper) | -| `push_table.byte` | Push Table: A byte value | -| `push_table.push_size`| Push Table: The number of bytes pushed for this byte as opcode | +The column `tag` (advice) makes the circuit behave as a state machine, selecting different constraints depending on the current and next row value. The `tag` column can have two different values: Header, Byte. A row of `tag==Header` precedes a series of `tag==Byte` rows that contain a complete bytecode sequence. The row `tag==Header` contains the length of the bytecode and the hash of the bytecode, each `tag==Byte` contains one byte of the bytecode, bytecode hash, its length and other for the push data. + + +| Column | Description | +| --------------------- | --------------------------------------------------------------------| +| `q_first` (fixed) | `1` on the first row, else `0` | +| `q_last` (fixed) | `1` on the last row, else `0` | +| `hash` | The keccak hash of the bytecode | +| `index` | The position of the byte in the bytecode starting from 0 | +| `value` | Value for this row bytecode byte, and the length in Header rows. | +| `is_code` | `1` if the byte is code, `0` if the byte is PUSH data | +| `push_data_left` | The number of PUSH data bytes that still follow the current row | +| `value_rlc` | The accumulator containing the current and previous bytes RLC | +| `length` | The bytecode length, that could be 0 for empty bytecodes and padding| +| `push_data_size` | The number of bytes pushed for the current byte | +| `push_table.byte` | Push Table: A byte value | +| `push_table.push_size`| Push Table: The number of bytes pushed for this byte as opcode | + + +After all the bytecodes have been added, the rest of the rows are filled with padding in the form of `tag == Header && length == 0 && value == 0 && hash == EMPTY_HASH` rows. + +Additionally we will need two columns for IsZeroChip for `length` and `push_data_left` + ## Push table @@ -35,30 +40,117 @@ Because we do this lookup for each byte, this table is also indirectly used to r | \[OpcodeId::PUSH1, OpcodeId::PUSH32\] | `[1..32]` | | \[OpcodeId::PUSH32, 256\] | `0` | -### Circuit behavior +## Witness generation -The circuit starts by adding a row that contains the bytecode length using `tag = Length`. Then it -runs over all the bytes of the bytecode starting at the byte at position `0`. -Each following row unrolls a single byte of the bytecode while also storing its position +The circuit starts by adding a row that contains the bytecode length using `tag = Header`. + +Then it runs over all the bytes of the bytecode in order starting at the byte at position `0`. +Each following row unrolls a single byte (using `tag = Byte` and `value = the actual byte value`) of the bytecode while also storing its position (`index`), the code hash it's part of (`hash`), and if it is code or not -(`is_code`); using `tag = Byte` +(`is_code`). Also `push_data_size` is filled to match the push table, and `push_data_left` is computed. + +All byte data is accumulated per byte (with one byte per row) into `value_rlc` as follows, where r is a challenge: + +``` +first_bytecode.value_rlc := first_bytecode.value + +next.value_rlc := cur.value_rlc * r + next.value +``` + +For detecting which byte is code and which byte is push data the [Push table](#push-table) is used. This table allows finding out how many bytes an opcode pushes. This is used to set `next.push_data_left` if and only if the current byte is code (the first byte in any bytecode is code). + +If a row contains a zero value for `push_data_left` we know the current byte is an opcode: + +``` +first_bytecode.is_code := 1 +cur.is_code := cur.push_data_left == 0 +next.push_data_left := cur.byte_push_size if cur.is_code else cur.push_data_left - 1 +``` + +The fixed columns `q_first` and `q_last` should be zero for all rows, except the first one where `q_first := 1` and the last one where `q_last := 1`. + +## Circuit constrains + +All circuit constraints are based on the current row (`cur`) and the `next` row. + +First of all if `cur.q_first` or `cur.q_last` are `1`, then `cur.tag == Header`. + +We should have the following constraint based on `cur.tag` and `next.tag` (state transition), for all rows except the last one (`cur.q_last == 1`). + +To enable lookup all `cur.tag == Header` rows should have: + +``` +assert cur.index == 0 +assert cur.value == cur.length +``` + +Also, each `cur.tag == Byte` should have: + +``` +assert push_data_size_table_lookup(cur.value, cur.push_data_size) +assert cur.is_code == (cur.push_data_left == 0) +if cur.is_code: + assert next.push_data_left == cur.push_data_size +else: + assert next.push_data_left == cur.push_data_left - 1 +``` + +This way we make sure is_code and next.push_data_left have the right values. + +### cur.tag == Header and next.tag == Header + +We are in a transition from a empty bytecode to the begining of another bytecode that could be empty or not. + +Hence: +``` +assert cur.length == 0 +assert cur.hash == EMPTY_HASH +``` -All byte data is accumulated per byte (with one byte per row) into `hash_rlc` as follows: +### cur.tag == Header and next.tag == Byte + +We are at the begining of a non-empty bytecode. + +Hence: + +``` +assert next.length == cur.length +assert next.index == 0 +assert next.is_code == 1 +assert next.hash == cur.hash +assert next.value_rlc == next.value +``` + +### cur.tag == Byte and next.tag == Byte + +We are working on an actual bytecode byte that is not the last one. + +Hence: ``` -hash_rlc := hash_rlc_prev * r + byte +assert next.length == cur.length +assert next.index == cur.index + 1 +assert next.hash == cur.hash +assert next.value_rlc == cur.value_rlc * randomness + next.value ``` -For detecting which byte is code and which byte is push data the [Push table](#push-table) is used. This table allows finding out how many bytes an opcode pushes. This is used to set `push_data_left` if and only if the current byte is code (the first byte in any bytecode is code). If a row contains a non-zero value for `push_data_left` on its previous row we know the current byte is an opcode: +We make sure that `index` is incremented and `value_rlc` is accumulated. + +### cur.tag == Bytecode and next.tag == Header + +We are at the last byte of a bytecode. + +Hence: ``` -is_code := prev_push_data_left == 0 -push_data_left := byte_push_size if is_code else prev_push_data_left - 1 +assert cur.index + 1 == cur.length +assert keccak256_table_lookup(cur.hash, cur.length, cur.value_rlc) ``` -At the last byte the prover can set `is_final` to `1`, which will enable the keccak lookup on `(hash_rlc, hash_length, hash)`. This will ensure that the byte data passed into the circuit matches the data the prover gave as input (all the byte data is accumulated into `hash_rlc`). This has the consequence that the circuit _requires_ the full bytecode to be a part of its state, otherwise the prover could pass in invalid byte data for the specified hash. This is enforced by the circuit by requiring the last row in the circuit (when `q_last == 1`, note that `q_first` of the next row _cannot_ be used because of unusable rows) to either have `is_final == 1` or `padding == 1`, and padding itself can only be enabled after a `is_final` was set to `1`. +First, we make sure that the bytecode has `cur.length` bytes in the table. + +Second, we ensure that the byte data passed into the circuit matches the data the prover gave as input (all the byte data is accumulated into `value_rlc`). This has the consequence that the circuit _requires_ the full bytecode to be a part of its state, otherwise the prover could pass in invalid byte data for the specified hash. -Explicit padding is added to this circuit to be able to fully fill the circuit with valid data without depending on the keccak circuit to either hash this padding data or support looking up e.g. all zero data. ## Code diff --git a/src/zkevm_specs/bytecode.py b/src/zkevm_specs/bytecode.py index 83ced381d..881a1e6f2 100644 --- a/src/zkevm_specs/bytecode.py +++ b/src/zkevm_specs/bytecode.py @@ -1,13 +1,13 @@ -from typing import Sequence, Union, Tuple, Set, NamedTuple +from typing import Sequence, Tuple, Set, NamedTuple from collections import namedtuple from .util import keccak256, EMPTY_HASH, FQ, RLC from .evm import get_push_size, BytecodeFieldTag, BytecodeTableRow -from .encoding import U8, U256, is_circuit_code +from .encoding import is_circuit_code # Row in the circuit Row = namedtuple( "Row", - "q_first q_last hash tag index is_code value push_data_left hash_rlc hash_length byte_push_size is_final padding", + "q_first q_last hash tag index value is_code push_data_left value_rlc length push_data_size", ) # Unrolled bytecode class UnrolledBytecode(NamedTuple): @@ -16,104 +16,73 @@ class UnrolledBytecode(NamedTuple): @is_circuit_code -def assert_bool(value: Union[int, bool]): - assert value in [0, 1] +def check_bytecode_row( + cur: Row, + next: Row, + push_table: Set[Tuple[int, int]], + keccak_table: Set[Tuple[int, int, int]], + randomness: int, +): + cur = Row(*[v if isinstance(v, RLC) else FQ(v) for v in cur]) + next = Row(*[v if isinstance(v, RLC) else FQ(v) for v in next]) + + if cur.q_first == 1: + assert cur.tag == BytecodeFieldTag.Header + + if cur.q_last == 0: + if cur.tag == BytecodeFieldTag.Header: + assert cur.value == cur.length + assert cur.index == 0 + if next.tag == BytecodeFieldTag.Byte: + check_bytecode_row_header_to_byte(cur, next) + if next.tag == BytecodeFieldTag.Header: + check_bytecode_row_header_to_header(cur, randomness) + + if cur.tag == BytecodeFieldTag.Byte: + assert (cur.value, cur.push_data_size) in push_table + assert cur.is_code == (cur.push_data_left == 0) + + if next.tag == BytecodeFieldTag.Byte: + check_bytecode_row_byte_to_byte(cur, next, randomness) + if next.tag == BytecodeFieldTag.Header: + check_bytecode_row_byte_to_header(cur, keccak_table) + + if cur.q_last == 1: + assert cur.tag == BytecodeFieldTag.Header + check_bytecode_row_header_to_header(cur, randomness) @is_circuit_code -def select( - selector: U8, - when_true: U256, - when_false: U256, -) -> U256: - return U256(selector * when_true + (1 - selector) * when_false) +def check_bytecode_row_header_to_byte(cur: Row, next: Row): + assert next.length == cur.length + assert next.index == 0 + assert next.is_code == 1 + assert next.hash == cur.hash + assert next.value_rlc == next.value @is_circuit_code -def check_bytecode_row( - row: Row, - prev_row: Row, - next_row: Row, - push_table: Set[Tuple[int, int]], - keccak_table: Set[Tuple[int, int, int]], - r: int, -): - row = Row(*[v if isinstance(v, RLC) else FQ(v) for v in row]) - prev_row = Row(*[v if isinstance(v, RLC) else FQ(v) for v in prev_row]) - next_row = Row(*[v if isinstance(v, RLC) else FQ(v) for v in next_row]) - if row.q_first == 0 and prev_row.is_final == 0: - # Continue - if prev_row.tag == BytecodeFieldTag.Length: - # index starts from 0 - assert row.index == 0 - # is_code := 1, since this is the first byte of the bytecode - assert row.is_code == 1 - else: - # index is 1 more than previous row's index - assert row.index == prev_row.index + 1 - # is_code := push_data_left_prev == 0 - assert row.is_code == (prev_row.push_data_left == 0) - # hash_rlc := hash_rlc_prev * r + byte - assert row.hash_rlc == prev_row.hash_rlc * r + row.value - # padding needs to remain the same - assert row.padding == prev_row.padding - # hash needs to remain the same - assert row.hash == prev_row.hash - # hash_length needs to remain the same - assert row.hash_length == prev_row.hash_length +def check_bytecode_row_header_to_header(cur: Row, randomness: int): + assert cur.length == 0 + assert cur.hash == RLC(EMPTY_HASH, FQ(randomness)).expr() + + +@is_circuit_code +def check_bytecode_row_byte_to_byte(cur: Row, next: Row, r: int): + assert next.length == cur.length + assert next.index == cur.index + 1 + assert next.hash == cur.hash + assert next.value_rlc == cur.value_rlc * r + next.value + if cur.is_code == 1: + assert next.push_data_left == cur.push_data_size else: - # Start - # the row following an `is_final` previous row is either tagged Length - if row.tag == BytecodeFieldTag.Length: - # value matches hash length - assert row.value == row.hash_length - # if bytecode length is zero - if row.value == 0: - # bytecode hash should be EMPTY_HASH - assert row.hash == RLC(EMPTY_HASH, FQ(r)).expr() - # the next row should be a tag Length or padding - assert (next_row.tag == BytecodeFieldTag.Length) or (next_row.tag == 0) - else: - # the next row should be tag Byte - assert next_row.tag == BytecodeFieldTag.Byte - # or is the start of padding rows - else: - assert row.padding == 1 - - # is_final needs to be boolean - assert_bool(row.is_final) - # padding needs to be boolean - assert_bool(row.padding) - - if row.tag == BytecodeFieldTag.Byte: - # push_data_left := is_code ? byte_push_size : push_data_left_prev - 1 - assert row.push_data_left == select( - row.is_code, row.byte_push_size, prev_row.push_data_left - 1 - ) + assert next.push_data_left == cur.push_data_left - 1 - # Padding - if row.q_first == 0: - # padding can only go 0 -> 1 once - assert_bool(row.padding - prev_row.padding) - - # Last row - # The hash is checked on the latest row because only then have - # we accumulated all the bytes. We also have to go through the bytes - # in a forward manner because that's the only way we can know which - # bytes are op codes and which are push data. - if row.q_last == 1: - # padding needs to be enabled OR - # the last row needs to be the last byte - assert row.padding == 1 or row.is_final == 1 - - if row.tag == BytecodeFieldTag.Byte: - # Lookup how many bytes the current opcode pushes - # (also indirectly range checks `byte` to be in [0, 255]) - assert (row.value, row.byte_push_size) in push_table - - # keccak lookup when on the last byte - if row.is_final == 1 and row.padding == 0: - assert (row.hash_rlc, row.hash_length, row.hash) in keccak_table + +@is_circuit_code +def check_bytecode_row_byte_to_header(cur: Row, keccak_table: Set[Tuple[int, int, int]]): + assert cur.index + 1 == cur.length + assert (cur.value_rlc, cur.length, cur.hash) in keccak_table # Populate the circuit matrix @@ -124,36 +93,34 @@ def assign_bytecode_circuit(k: int, bytecodes: Sequence[UnrolledBytecode], rando rows = [] offset = 0 for bytecode in bytecodes: - push_data_left = 0 - hash_rlc = FQ(0) + next_push_data_left = 0 + value_rlc = FQ(0) for idx, row in enumerate(bytecode.rows): # Subsequent rows represent the bytecode bytes # Track which byte is an opcode and which is push data + push_data_left = next_push_data_left is_code = push_data_left == 0 - byte_push_size = 0 + push_data_size = 0 if idx > 0: - byte_push_size = get_push_size(row.value) - push_data_left = byte_push_size if is_code else push_data_left - 1 + push_data_size = get_push_size(row.value) + next_push_data_left = push_data_size if is_code else push_data_left - 1 # Add the byte to the accumulator - hash_rlc = hash_rlc * randomness + row.value + value_rlc = value_rlc * randomness + row.value # Set the data for this row rows.append( Row( - offset == 0, - offset == last_row_offset, - row.bytecode_hash, - row.field_tag, - row.index, - row.is_code, - row.value, - push_data_left, - hash_rlc, - len(bytecode.bytes), - byte_push_size, - # Since 1 row is taken up by the Length tag - idx == len(bytecode.bytes), - False, + q_first=offset == 0, + q_last=offset == last_row_offset, + hash=row.bytecode_hash, + tag=row.field_tag, + index=row.index, + value=row.value, + is_code=row.is_code, + push_data_left=push_data_left, + value_rlc=value_rlc, + length=len(bytecode.bytes), + push_data_size=push_data_size, ) ) @@ -166,19 +133,17 @@ def assign_bytecode_circuit(k: int, bytecodes: Sequence[UnrolledBytecode], rando for idx in range(offset, 2**k): rows.append( Row( - idx == 0, - idx == last_row_offset, - 0, - 0, - 0, - 0, - True, - 0, - 0, - 0, - 0, - True, - True, + q_first=idx == 0, + q_last=idx == last_row_offset, + hash=RLC(EMPTY_HASH, FQ(randomness)).expr(), + tag=BytecodeFieldTag.Header, + index=0, + value=0, + is_code=False, + push_data_left=0, + value_rlc=0, + length=0, + push_data_size=0, ) ) diff --git a/src/zkevm_specs/evm/instruction.py b/src/zkevm_specs/evm/instruction.py index 42b38eea4..8a14323cb 100644 --- a/src/zkevm_specs/evm/instruction.py +++ b/src/zkevm_specs/evm/instruction.py @@ -686,7 +686,7 @@ def bytecode_lookup_pair( def bytecode_length(self, bytecode_hash: Expression) -> Expression: return self.tables.bytecode_lookup( - bytecode_hash, FQ(BytecodeFieldTag.Length), FQ(0), FQ(0) + bytecode_hash, FQ(BytecodeFieldTag.Header), FQ(0), FQ(0) ).value def tx_gas_price(self, tx_id: Expression) -> RLC: diff --git a/src/zkevm_specs/evm/table.py b/src/zkevm_specs/evm/table.py index 5d2627b2f..4a7d91347 100644 --- a/src/zkevm_specs/evm/table.py +++ b/src/zkevm_specs/evm/table.py @@ -164,7 +164,7 @@ class BytecodeFieldTag(IntEnum): Tag for BytecodeTable lookup. """ - Length = 1 + Header = 1 Byte = 2 diff --git a/src/zkevm_specs/evm/typing.py b/src/zkevm_specs/evm/typing.py index 1ad2dd7a2..5504a0226 100644 --- a/src/zkevm_specs/evm/typing.py +++ b/src/zkevm_specs/evm/typing.py @@ -321,7 +321,7 @@ def __next__(self): if self.idx == 0: self.idx += 1 return BytecodeTableRow( - self.hash, FQ(BytecodeFieldTag.Length), FQ(0), FQ(0), FQ(len(self.code)) + self.hash, FQ(BytecodeFieldTag.Header), FQ(0), FQ(0), FQ(len(self.code)) ) if self.idx > len(self.code): diff --git a/tests/test_bytecode_circuit.py b/tests/test_bytecode_circuit.py index 91eea29fc..d2510f24d 100644 --- a/tests/test_bytecode_circuit.py +++ b/tests/test_bytecode_circuit.py @@ -3,7 +3,7 @@ from zkevm_specs.bytecode import * from zkevm_specs.evm import Opcode, Bytecode, BytecodeFieldTag, BytecodeTableRow, is_push -from zkevm_specs.util import RLC, rand_fq +from zkevm_specs.util import RLC, rand_fq, U256 # Unroll the bytecode def unroll(bytecode, randomness): @@ -12,18 +12,24 @@ def unroll(bytecode, randomness): # Verify the bytecode circuit with the given data def verify(k, bytecodes, randomness, success): + rows = assign_bytecode_circuit(k, bytecodes, randomness) + verify_rows(bytecodes, rows, success) + + +def verify_rows(bytecodes, rows, success): push_table = assign_push_table() keccak_table = assign_keccak_table(map(lambda v: v.bytes, bytecodes), randomness) - rows = assign_bytecode_circuit(k, bytecodes, randomness) try: for (idx, row) in enumerate(rows): - prev_row = rows[(idx - 1) % len(rows)] next_row = rows[(idx + 1) % len(rows)] - check_bytecode_row(row, prev_row, next_row, push_table, keccak_table, randomness) + check_bytecode_row(row, next_row, push_table, keccak_table, randomness) ok = True except AssertionError as e: if success: traceback.print_exc() + print(idx) + print(row) + print(next_row) ok = False assert ok == success @@ -53,7 +59,7 @@ def test_bytecode_unrolling(): for i in range(len(rows)): rows[i] = BytecodeTableRow(hash.expr(), rows[i][1], rows[i][2], rows[i][3], rows[i][4]) # Prepend the length of bytecode to rows - rows.insert(0, BytecodeTableRow(hash.expr(), BytecodeFieldTag.Length, 0, 0, len(bytecode))) + rows.insert(0, BytecodeTableRow(hash.expr(), BytecodeFieldTag.Header, 0, 0, len(bytecode))) # Unroll the bytecode unrolled = unroll(bytes(bytecode), randomness) # Check if the bytecode was unrolled correctly @@ -68,7 +74,10 @@ def test_bytecode_empty(): def test_bytecode_full(): - bytecodes = [unroll(bytes([7] * (2**k - 1)), randomness)] + bytecodes = [ + unroll(bytes([7] * (2**k - 2)), randomness), + unroll(bytes([]), randomness), # Last row must be tag=Header + ] verify(k, bytecodes, randomness, True) @@ -203,3 +212,69 @@ def test_bytecode_invalid_is_code(): row = unrolled.rows[7] invalid.rows[7] = BytecodeTableRow(row.bytecode_hash, row.field_tag, row.index, 1, row.value) verify(k, [invalid], randomness, False) + + +def test_last_row(): + unrolled = unroll(bytes([8, 2, 3, 8, 9, 7, 128]), randomness) + verify(k, [unrolled], randomness, True) + + # last row has length != 0 + rows = assign_bytecode_circuit(k, [unrolled], randomness) + rows[-1] = Row( + q_first=0, + q_last=1, + hash=RLC(EMPTY_HASH, FQ(randomness)).expr(), + tag=BytecodeFieldTag.Header, + index=0, + value=0, + is_code=False, + push_data_left=0, + value_rlc=0, + length=1000, + push_data_size=0, + ) + verify_rows([unrolled], rows, False) + + # last row has hash != EMPTY_HASH + NOT_EMPTY_HASH = U256( + int.from_bytes( + keccak256(bytes("why is there something instead of nothing?", "utf-8")), "big" + ) + ) + rows = assign_bytecode_circuit(k, [unrolled], randomness) + rows[-1] = Row( + q_first=0, + q_last=1, + hash=RLC(NOT_EMPTY_HASH, FQ(randomness)).expr(), + tag=BytecodeFieldTag.Header, + index=0, + value=0, + is_code=False, + push_data_left=0, + value_rlc=0, + length=0, + push_data_size=0, + ) + verify_rows([unrolled], rows, False) + + # last row is not Header + NOT_EMPTY_HASH = U256( + int.from_bytes( + keccak256(bytes("why is there something instead of nothing?", "utf-8")), "big" + ) + ) + rows = assign_bytecode_circuit(k, [unrolled], randomness) + rows[-1] = Row( + q_first=0, + q_last=1, + hash=RLC(NOT_EMPTY_HASH, FQ(randomness)).expr(), + tag=BytecodeFieldTag.Byte, + index=0, + value=0, + is_code=False, + push_data_left=0, + value_rlc=0, + length=0, + push_data_size=0, + ) + verify_rows([unrolled], rows, False)