From da894473a4fa47ee9a4efb89a73bca11a7b71a85 Mon Sep 17 00:00:00 2001 From: hexoscott <70711990+hexoscott@users.noreply.github.com> Date: Tue, 20 Feb 2024 16:21:57 +0000 Subject: [PATCH] moving zk witness generation into it's own package (#150) --- cmd/rpcdaemon/commands/zkevm_api.go | 264 ++------------------------- zk/witness/witness.go | 265 ++++++++++++++++++++++++++++ 2 files changed, 277 insertions(+), 252 deletions(-) create mode 100644 zk/witness/witness.go diff --git a/cmd/rpcdaemon/commands/zkevm_api.go b/cmd/rpcdaemon/commands/zkevm_api.go index 471a185150e..553a0c1f48e 100644 --- a/cmd/rpcdaemon/commands/zkevm_api.go +++ b/cmd/rpcdaemon/commands/zkevm_api.go @@ -1,7 +1,6 @@ package commands import ( - "bytes" "context" "encoding/json" "errors" @@ -9,32 +8,20 @@ import ( "math/big" "github.com/ledgerwatch/erigon-lib/common" + libcommon "github.com/ledgerwatch/erigon-lib/common" "github.com/ledgerwatch/erigon-lib/common/hexutility" "github.com/ledgerwatch/erigon-lib/kv" - "github.com/ledgerwatch/erigon-lib/kv/memdb" - "github.com/ledgerwatch/log/v3" - - libcommon "github.com/ledgerwatch/erigon-lib/common" "github.com/holiman/uint256" "github.com/ledgerwatch/erigon/common/hexutil" - "github.com/ledgerwatch/erigon/consensus" "github.com/ledgerwatch/erigon/core" - "github.com/ledgerwatch/erigon/core/state" eritypes "github.com/ledgerwatch/erigon/core/types" - "github.com/ledgerwatch/erigon/core/vm" - "github.com/ledgerwatch/erigon/eth/stagedsync" "github.com/ledgerwatch/erigon/eth/stagedsync/stages" "github.com/ledgerwatch/erigon/rpc" - db2 "github.com/ledgerwatch/erigon/smt/pkg/db" - "github.com/ledgerwatch/erigon/smt/pkg/smt" "github.com/ledgerwatch/erigon/turbo/rpchelper" - "github.com/ledgerwatch/erigon/turbo/trie" - dstypes "github.com/ledgerwatch/erigon/zk/datastream/types" "github.com/ledgerwatch/erigon/zk/hermez_db" types "github.com/ledgerwatch/erigon/zk/rpcdaemon" - zkStages "github.com/ledgerwatch/erigon/zk/stages" - zkUtils "github.com/ledgerwatch/erigon/zk/utils" + "github.com/ledgerwatch/erigon/zk/witness" "github.com/ledgerwatch/erigon/zkevm/jsonrpc/client" ) @@ -338,8 +325,7 @@ func (api *ZkEvmAPIImpl) getBlockRangeWitness(ctx context.Context, db kv.RoDB, s return nil, fmt.Errorf("not supported by Erigon3") } - blockNr, hash, _, err := rpchelper.GetCanonicalBlockNumber(startBlockNrOrHash, tx, api.ethApi.filters) // DoCall cannot be executed on non-canonical blocks - + blockNr, _, _, err := rpchelper.GetCanonicalBlockNumber(startBlockNrOrHash, tx, api.ethApi.filters) // DoCall cannot be executed on non-canonical blocks if err != nil { return nil, err } @@ -354,247 +340,21 @@ func (api *ZkEvmAPIImpl) getBlockRangeWitness(ctx context.Context, db kv.RoDB, s return nil, fmt.Errorf("start block number must be less than or equal to end block number, start=%d end=%d", blockNr, endBlockNr) } - // Witness for genesis block is empty - if endBlockNr == 0 { - w := trie.NewWitness(make([]trie.WitnessOperator, 0)) - - var buf bytes.Buffer - _, err = w.WriteInto(&buf, debug) - if err != nil { - return nil, err - } - - return buf.Bytes(), nil - } - - block, err := api.ethApi.blockWithSenders(tx, hash, blockNr) - if err != nil { - return nil, err - } - if block == nil { - return nil, nil - } - - latestBlock, err := rpchelper.GetLatestBlockNumber(tx) - if err != nil { - return nil, err - } - - if latestBlock < endBlockNr { - // shouldn't happen, but check anyway - return nil, fmt.Errorf("block number is in the future latest=%d requested=%d", latestBlock, endBlockNr) - } - - batch := memdb.NewMemoryBatch(tx, api.ethApi.dirs.Tmp) - defer batch.Rollback() - - // Hack for now for the new tables not defined in erigon-lib - err = batch.CreateBucket(db2.TableSmt) - if err != nil { - return nil, err - } - - err = batch.CreateBucket(db2.TableAccountValues) - if err != nil { - return nil, err - } - - err = batch.CreateBucket(db2.TableLastRoot) - if err != nil { - return nil, err - } - - err = batch.CreateBucket(db2.TableMetadata) - if err != nil { - return nil, err - } - - err = batch.CreateBucket(db2.TableHashKey) - if err != nil { - return nil, err - } - - err = batch.CreateBucket(hermez_db.TX_PRICE_PERCENTAGE) - if err != nil { - return nil, err - } - - err = batch.CreateBucket(hermez_db.BLOCKBATCHES) - if err != nil { - return nil, err - } - - err = batch.CreateBucket(hermez_db.BLOCK_GLOBAL_EXIT_ROOTS) - if err != nil { - return nil, err - } - - err = batch.CreateBucket(hermez_db.GLOBAL_EXIT_ROOTS_BATCHES) - if err != nil { - return nil, err - } - - err = batch.CreateBucket(hermez_db.STATE_ROOTS) - if err != nil { - return nil, err - } - - if blockNr-1 < latestBlock { - if latestBlock-blockNr > maxGetProofRewindBlockCount { - return nil, fmt.Errorf("requested block is too old, block must be within %d blocks of the head block number (currently %d)", maxGetProofRewindBlockCount, latestBlock) - } - - unwindState := &stagedsync.UnwindState{UnwindPoint: blockNr - 1} - stageState := &stagedsync.StageState{BlockNumber: latestBlock} - - hashStageCfg := stagedsync.StageHashStateCfg(nil, api.ethApi.dirs, api.ethApi.historyV3(batch), api.ethApi._agg) - if err := stagedsync.UnwindHashStateStage(unwindState, stageState, batch, hashStageCfg, ctx); err != nil { - return nil, err - } - - interHashStageCfg := zkStages.StageZkInterHashesCfg(nil, true, true, false, api.ethApi.dirs.Tmp, api.ethApi._blockReader, nil, api.ethApi.historyV3(batch), api.ethApi._agg, nil) - - err = zkStages.UnwindZkIntermediateHashesStage(unwindState, stageState, batch, interHashStageCfg, ctx) - if err != nil { - return nil, err - } - - tx = batch - } - chainConfig, err := api.ethApi.chainConfig(tx) if err != nil { return nil, err } - prevHeader, err := api.ethApi._blockReader.HeaderByNumber(ctx, tx, blockNr-1) - if err != nil { - return nil, err - } - - tds := state.NewTrieDbState(prevHeader.Root, tx, blockNr-1, nil) - - tds.SetResolveReads(true) - - tds.StartNewBuffer() - trieStateWriter := tds.TrieStateWriter() - - getHeader := func(hash libcommon.Hash, number uint64) *eritypes.Header { - h, e := api.ethApi._blockReader.Header(ctx, tx, hash, number) - if e != nil { - log.Error("getHeader error", "number", number, "hash", hash, "err", e) - } - return h - } - - for i := blockNr; i <= endBlockNr; i++ { - curBlockNum := rpc.BlockNumberOrHashWithNumber(rpc.BlockNumber(i)) - blockNr, curHash, _, err := rpchelper.GetCanonicalBlockNumber(curBlockNum, tx, api.ethApi.filters) // DoCall cannot be executed on non-canonical blocks - - if err != nil { - return nil, err - } - - block, err := api.ethApi.blockWithSenders(tx, curHash, i) - - if err != nil { - return nil, err - } - - reader, err := rpchelper.CreateHistoryStateReader(tx, blockNr, 0, false, chainConfig.ChainName) - if err != nil { - return nil, err - } - - tds.SetStateReader(reader) - - gers := []*dstypes.GerUpdate{} - - hermezDb := hermez_db.NewHermezDbReader(tx) - - //[zkevm] get batches between last block and this one - // plus this blocks ger - lastBatchInserted, err := hermezDb.GetBatchNoByL2Block(i - 1) - if err != nil { - return nil, fmt.Errorf("failed to get batch for block %d: %v", i-1, err) - } - - currentBatch, err := hermezDb.GetBatchNoByL2Block(i) - if err != nil { - return nil, fmt.Errorf("failed to get batch for block %d: %v", i, err) - } - - gersInBetween, err := hermezDb.GetBatchGlobalExitRoots(lastBatchInserted, currentBatch) - if err != nil { - return nil, err - } - - if gersInBetween != nil { - gers = append(gers, gersInBetween...) - } - - blockGer, _, err := hermezDb.GetBlockGlobalExitRoot(i) - if err != nil { - return nil, err - } - emptyHash := libcommon.Hash{} - - if blockGer != emptyHash { - blockGerUpdate := dstypes.GerUpdate{ - GlobalExitRoot: blockGer, - Timestamp: block.Header().Time, - } - gers = append(gers, &blockGerUpdate) - } - - for _, ger := range gers { - // [zkevm] - add GER if there is one for this batch - if err := zkUtils.WriteGlobalExitRoot(tds, trieStateWriter, ger.GlobalExitRoot, ger.Timestamp); err != nil { - return nil, err - } - } - - engine, ok := api.ethApi.engine().(consensus.Engine) - - if !ok { - return nil, fmt.Errorf("engine is not consensus.Engine") - } - - vmConfig := vm.Config{} - - getHashFn := core.GetHashFn(block.Header(), getHeader) - - chainReader := stagedsync.NewChainReaderImpl(chainConfig, tx, nil) - - _, err = core.ExecuteBlockEphemerally(chainConfig, &vmConfig, getHashFn, engine, block, tds, trieStateWriter, chainReader, nil, nil, hermezDb) - - if err != nil { - return nil, err - } - } - - rl, err := tds.ResolveSMTRetainList() - - if err != nil { - return nil, err - } - - eridb := db2.NewEriDb(batch) - smtTrie := smt.NewSMT(eridb) - - witness, err := smt.BuildWitness(smtTrie, rl, ctx) - - if err != nil { - return nil, err - } - - var buf bytes.Buffer - _, err = witness.WriteInto(&buf, debug) - if err != nil { - return nil, err - } + generator := witness.NewGenerator( + api.ethApi.dirs, + api.ethApi.historyV3(tx), + api.ethApi._agg, + api.ethApi._blockReader, + chainConfig, + api.ethApi._engine, + ) - return buf.Bytes(), nil + return generator.GenerateWitness(tx, ctx, blockNr, endBlockNr, debug) } func (api *ZkEvmAPIImpl) GetBatchWitness(ctx context.Context, batchNumber uint64) (hexutility.Bytes, error) { diff --git a/zk/witness/witness.go b/zk/witness/witness.go new file mode 100644 index 00000000000..6e9c4b018fa --- /dev/null +++ b/zk/witness/witness.go @@ -0,0 +1,265 @@ +package witness + +import ( + "bytes" + "context" + "errors" + "fmt" + libcommon "github.com/ledgerwatch/erigon-lib/common" + "github.com/ledgerwatch/erigon-lib/common/datadir" + "github.com/ledgerwatch/erigon-lib/kv" + "github.com/ledgerwatch/erigon-lib/kv/memdb" + libstate "github.com/ledgerwatch/erigon-lib/state" + "github.com/ledgerwatch/erigon/chain" + "github.com/ledgerwatch/erigon/consensus" + "github.com/ledgerwatch/erigon/core" + "github.com/ledgerwatch/erigon/core/rawdb" + "github.com/ledgerwatch/erigon/core/state" + "github.com/ledgerwatch/erigon/core/systemcontracts" + eritypes "github.com/ledgerwatch/erigon/core/types" + "github.com/ledgerwatch/erigon/core/vm" + "github.com/ledgerwatch/erigon/eth/stagedsync" + "github.com/ledgerwatch/erigon/eth/stagedsync/stages" + db2 "github.com/ledgerwatch/erigon/smt/pkg/db" + "github.com/ledgerwatch/erigon/smt/pkg/smt" + "github.com/ledgerwatch/erigon/turbo/services" + "github.com/ledgerwatch/erigon/turbo/trie" + dstypes "github.com/ledgerwatch/erigon/zk/datastream/types" + "github.com/ledgerwatch/erigon/zk/hermez_db" + zkStages "github.com/ledgerwatch/erigon/zk/stages" + zkUtils "github.com/ledgerwatch/erigon/zk/utils" + "github.com/ledgerwatch/log/v3" +) + +var ( + maxGetProofRewindBlockCount uint64 = 1_000 + + ErrEndBeforeStart = errors.New("end block must be higher than start block") +) + +type Generator struct { + tx kv.Tx + dirs datadir.Dirs + historyV3 bool + agg *libstate.AggregatorV3 + blockReader services.FullBlockReader + chainCfg *chain.Config + engine consensus.EngineReader +} + +func NewGenerator( + dirs datadir.Dirs, + historyV3 bool, + agg *libstate.AggregatorV3, + blockReader services.FullBlockReader, + chainCfg *chain.Config, + engine consensus.EngineReader, +) *Generator { + return &Generator{ + dirs: dirs, + historyV3: historyV3, + agg: agg, + blockReader: blockReader, + chainCfg: chainCfg, + engine: engine, + } +} + +func (g *Generator) GenerateWitness(tx kv.Tx, ctx context.Context, startBlock, endBlock uint64, debug bool) ([]byte, error) { + if startBlock > endBlock { + return nil, ErrEndBeforeStart + } + + if endBlock == 0 { + witness := trie.NewWitness([]trie.WitnessOperator{}) + return getWitnessBytes(witness, debug) + } + + latestBlock, err := stages.GetStageProgress(tx, stages.Execution) + if err != nil { + return nil, err + } + + if latestBlock < endBlock { + return nil, fmt.Errorf("block number is in the future latest=%d requested=%d", latestBlock, endBlock) + } + + batch := memdb.NewMemoryBatch(tx, g.dirs.Tmp) + defer batch.Rollback() + if err = populateDbTables(batch); err != nil { + return nil, err + } + + sBlock, err := rawdb.ReadBlockByNumber(tx, startBlock) + if err != nil { + return nil, err + } + if sBlock == nil { + return nil, nil + } + + if startBlock-1 < latestBlock { + if latestBlock-startBlock > maxGetProofRewindBlockCount { + return nil, fmt.Errorf("requested block is too old, block must be within %d blocks of the head block number (currently %d)", maxGetProofRewindBlockCount, latestBlock) + } + + unwindState := &stagedsync.UnwindState{UnwindPoint: startBlock - 1} + stageState := &stagedsync.StageState{BlockNumber: latestBlock} + + hashStageCfg := stagedsync.StageHashStateCfg(nil, g.dirs, g.historyV3, g.agg) + if err := stagedsync.UnwindHashStateStage(unwindState, stageState, batch, hashStageCfg, ctx); err != nil { + return nil, err + } + + interHashStageCfg := zkStages.StageZkInterHashesCfg(nil, true, true, false, g.dirs.Tmp, g.blockReader, nil, g.historyV3, g.agg, nil) + + err = zkStages.UnwindZkIntermediateHashesStage(unwindState, stageState, batch, interHashStageCfg, ctx) + if err != nil { + return nil, err + } + + tx = batch + } + + prevHeader, err := g.blockReader.HeaderByNumber(ctx, tx, startBlock-1) + if err != nil { + return nil, err + } + + tds := state.NewTrieDbState(prevHeader.Root, tx, startBlock-1, nil) + tds.SetResolveReads(true) + tds.StartNewBuffer() + trieStateWriter := tds.TrieStateWriter() + + getHeader := func(hash libcommon.Hash, number uint64) *eritypes.Header { + h, e := g.blockReader.Header(ctx, tx, hash, number) + if e != nil { + log.Error("getHeader error", "number", number, "hash", hash, "err", e) + } + return h + } + + for blockNum := startBlock; blockNum <= endBlock; blockNum++ { + block, err := rawdb.ReadBlockByNumber(tx, blockNum) + + if err != nil { + return nil, err + } + + reader := state.NewPlainState(tx, blockNum, systemcontracts.SystemContractCodeLookup[g.chainCfg.ChainName]) + + tds.SetStateReader(reader) + + hermezDb := hermez_db.NewHermezDbReader(tx) + + //[zkevm] get batches between last block and this one + // plus this blocks ger + lastBatchInserted, err := hermezDb.GetBatchNoByL2Block(blockNum - 1) + if err != nil { + return nil, fmt.Errorf("failed to get batch for block %d: %v", blockNum-1, err) + } + + currentBatch, err := hermezDb.GetBatchNoByL2Block(blockNum) + if err != nil { + return nil, fmt.Errorf("failed to get batch for block %d: %v", blockNum, err) + } + + gersInBetween, err := hermezDb.GetBatchGlobalExitRoots(lastBatchInserted, currentBatch) + if err != nil { + return nil, err + } + + var globalExitRoots []*dstypes.GerUpdate + + if gersInBetween != nil { + globalExitRoots = append(globalExitRoots, gersInBetween...) + } + + blockGer, _, err := hermezDb.GetBlockGlobalExitRoot(blockNum) + if err != nil { + return nil, err + } + emptyHash := libcommon.Hash{} + + if blockGer != emptyHash { + blockGerUpdate := dstypes.GerUpdate{ + GlobalExitRoot: blockGer, + Timestamp: block.Header().Time, + } + globalExitRoots = append(globalExitRoots, &blockGerUpdate) + } + + for _, ger := range globalExitRoots { + // [zkevm] - add GER if there is one for this batch + if err := zkUtils.WriteGlobalExitRoot(tds, trieStateWriter, ger.GlobalExitRoot, ger.Timestamp); err != nil { + return nil, err + } + } + + engine, ok := g.engine.(consensus.Engine) + + if !ok { + return nil, fmt.Errorf("engine is not consensus.Engine") + } + + vmConfig := vm.Config{} + + getHashFn := core.GetHashFn(block.Header(), getHeader) + + chainReader := stagedsync.NewChainReaderImpl(g.chainCfg, tx, nil) + + _, err = core.ExecuteBlockEphemerally(g.chainCfg, &vmConfig, getHashFn, engine, block, tds, trieStateWriter, chainReader, nil, nil, hermezDb) + + if err != nil { + return nil, err + } + } + + rl, err := tds.ResolveSMTRetainList() + + if err != nil { + return nil, err + } + + eridb := db2.NewEriDb(batch) + smtTrie := smt.NewSMT(eridb) + + witness, err := smt.BuildWitness(smtTrie, rl, ctx) + if err != nil { + return nil, err + } + + return getWitnessBytes(witness, debug) +} + +func getWitnessBytes(witness *trie.Witness, debug bool) ([]byte, error) { + var buf bytes.Buffer + _, err := witness.WriteInto(&buf, debug) + if err != nil { + return nil, err + } + return buf.Bytes(), nil +} + +func populateDbTables(batch *memdb.MemoryMutation) error { + tables := []string{ + db2.TableSmt, + db2.TableAccountValues, + db2.TableMetadata, + db2.TableHashKey, + db2.TableLastRoot, + hermez_db.TX_PRICE_PERCENTAGE, + hermez_db.BLOCKBATCHES, + hermez_db.BLOCK_GLOBAL_EXIT_ROOTS, + hermez_db.GLOBAL_EXIT_ROOTS_BATCHES, + hermez_db.STATE_ROOTS, + } + + for _, t := range tables { + if err := batch.CreateBucket(t); err != nil { + return err + } + } + + return nil +}