Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
189 changes: 99 additions & 90 deletions rollup/internal/controller/relayer/l2_relayer_sanity.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@ package relayer

import (
"fmt"
"math/big"

"github.com/scroll-tech/da-codec/encoding"
"github.com/scroll-tech/go-ethereum/common"
"github.com/scroll-tech/go-ethereum/core/types"
"github.com/scroll-tech/go-ethereum/crypto/kzg4844"

"scroll-tech/rollup/internal/orm"
Expand All @@ -14,36 +16,24 @@ import (
// transaction data (calldata and blobs) by parsing them and comparing against database records.
// This ensures the constructed transaction data is correct and consistent with the database state.
func (r *Layer2Relayer) sanityChecksCommitBatchCodecV7CalldataAndBlobs(calldata []byte, blobs []*kzg4844.Blob) error {
if len(blobs) == 0 {
return fmt.Errorf("no blobs provided")
}

calldataInfo, err := r.parseCommitBatchesCalldata(calldata)
if err != nil {
return fmt.Errorf("failed to parse calldata: %w", err)
}

batchesToValidate, firstBatch, lastBatch, err := r.getBatchesFromCalldata(calldataInfo)
batchesToValidate, err := r.getBatchesFromCalldata(calldataInfo)
if err != nil {
return fmt.Errorf("failed to get batches from database: %w", err)
}

if len(blobs) != len(batchesToValidate) {
return fmt.Errorf("blob count mismatch: got %d blobs, expected %d batches", len(blobs), len(batchesToValidate))
}

if err := r.validateCalldataAgainstDatabase(calldataInfo, firstBatch, lastBatch); err != nil {
return fmt.Errorf("calldata validation failed: %w", err)
if err := r.validateCalldataAndBlobsAgainstDatabase(calldataInfo, blobs, batchesToValidate); err != nil {
return fmt.Errorf("calldata and blobs validation failed: %w", err)
}

if err := r.validateDatabaseConsistency(batchesToValidate); err != nil {
return fmt.Errorf("database consistency validation failed: %w", err)
}

if err := r.validateBlobsAgainstDatabase(blobs, batchesToValidate); err != nil {
return fmt.Errorf("blob validation failed: %w", err)
}

return nil
}

Expand Down Expand Up @@ -91,17 +81,17 @@ func (r *Layer2Relayer) parseCommitBatchesCalldata(calldata []byte) (*CalldataIn
}

// getBatchesFromCalldata retrieves the relevant batches from database based on calldata information
func (r *Layer2Relayer) getBatchesFromCalldata(info *CalldataInfo) ([]*dbBatchWithChunks, *orm.Batch, *orm.Batch, error) {
func (r *Layer2Relayer) getBatchesFromCalldata(info *CalldataInfo) ([]*dbBatchWithChunks, error) {
// Get the parent batch to determine the starting point
parentBatch, err := r.batchOrm.GetBatchByHash(r.ctx, info.ParentBatchHash.Hex())
if err != nil {
return nil, nil, nil, fmt.Errorf("failed to get parent batch by hash %s: %w", info.ParentBatchHash.Hex(), err)
return nil, fmt.Errorf("failed to get parent batch by hash %s: %w", info.ParentBatchHash.Hex(), err)
}

// Get the last batch to determine the ending point
lastBatch, err := r.batchOrm.GetBatchByHash(r.ctx, info.LastBatchHash.Hex())
if err != nil {
return nil, nil, nil, fmt.Errorf("failed to get last batch by hash %s: %w", info.LastBatchHash.Hex(), err)
return nil, fmt.Errorf("failed to get last batch by hash %s: %w", info.LastBatchHash.Hex(), err)
}

// Get all batches in the range (parent+1 to last)
Expand All @@ -110,20 +100,20 @@ func (r *Layer2Relayer) getBatchesFromCalldata(info *CalldataInfo) ([]*dbBatchWi

// Check if the range is valid
if firstBatchIndex > lastBatchIndex {
return nil, nil, nil, fmt.Errorf("no batches found in range: first index %d, last index %d", firstBatchIndex, lastBatchIndex)
return nil, fmt.Errorf("no batches found in range: first index %d, last index %d", firstBatchIndex, lastBatchIndex)
}

var batchesToValidate []*dbBatchWithChunks
for batchIndex := firstBatchIndex; batchIndex <= lastBatchIndex; batchIndex++ {
dbBatch, err := r.batchOrm.GetBatchByIndex(r.ctx, batchIndex)
if err != nil {
return nil, nil, nil, fmt.Errorf("failed to get batch by index %d: %w", batchIndex, err)
return nil, fmt.Errorf("failed to get batch by index %d: %w", batchIndex, err)
}

// Get chunks for this batch
dbChunks, err := r.chunkOrm.GetChunksInRange(r.ctx, dbBatch.StartChunkIndex, dbBatch.EndChunkIndex)
if err != nil {
return nil, nil, nil, fmt.Errorf("failed to get chunks for batch %d: %w", batchIndex, err)
return nil, fmt.Errorf("failed to get chunks for batch %d: %w", batchIndex, err)
}

batchesToValidate = append(batchesToValidate, &dbBatchWithChunks{
Expand All @@ -132,30 +122,7 @@ func (r *Layer2Relayer) getBatchesFromCalldata(info *CalldataInfo) ([]*dbBatchWi
})
}

// Get first batch for return
firstBatch := batchesToValidate[0].Batch

return batchesToValidate, firstBatch, lastBatch, nil
}

// validateCalldataAgainstDatabase validates calldata parameters against database records
func (r *Layer2Relayer) validateCalldataAgainstDatabase(info *CalldataInfo, firstBatch, lastBatch *orm.Batch) error {
// Validate codec version
if info.Version != uint8(firstBatch.CodecVersion) {
return fmt.Errorf("version mismatch: calldata=%d, db=%d", info.Version, firstBatch.CodecVersion)
}

// Validate parent batch hash
if info.ParentBatchHash != common.HexToHash(firstBatch.ParentBatchHash) {
return fmt.Errorf("parentBatchHash mismatch: calldata=%s, db=%s", info.ParentBatchHash.Hex(), firstBatch.ParentBatchHash)
}

// Validate last batch hash
if info.LastBatchHash != common.HexToHash(lastBatch.Hash) {
return fmt.Errorf("lastBatchHash mismatch: calldata=%s, db=%s", info.LastBatchHash.Hex(), lastBatch.Hash)
}

return nil
return batchesToValidate, nil
}

// validateDatabaseConsistency performs comprehensive validation of database records
Expand Down Expand Up @@ -328,10 +295,38 @@ func (r *Layer2Relayer) validateSingleChunkConsistency(chunk *orm.Chunk, prevChu
return nil
}

// validateBlobsAgainstDatabase validates blobs against database records
func (r *Layer2Relayer) validateBlobsAgainstDatabase(blobs []*kzg4844.Blob, batchesToValidate []*dbBatchWithChunks) error {
// Get codec for blob decoding
// validateCalldataAndBlobsAgainstDatabase validates calldata and blobs against database records
func (r *Layer2Relayer) validateCalldataAndBlobsAgainstDatabase(calldataInfo *CalldataInfo, blobs []*kzg4844.Blob, batchesToValidate []*dbBatchWithChunks) error {
// Validate blobs
if len(blobs) == 0 {
return fmt.Errorf("no blobs provided")
}

// Validate blob count
if len(blobs) != len(batchesToValidate) {
return fmt.Errorf("blob count mismatch: got %d blobs, expected %d batches", len(blobs), len(batchesToValidate))
}

// Get first and last batches for validation, length check is already done above
firstBatch := batchesToValidate[0].Batch
lastBatch := batchesToValidate[len(batchesToValidate)-1].Batch

// Validate codec version
if calldataInfo.Version != uint8(firstBatch.CodecVersion) {
return fmt.Errorf("version mismatch: calldata=%d, db=%d", calldataInfo.Version, firstBatch.CodecVersion)
}

// Validate parent batch hash
if calldataInfo.ParentBatchHash != common.HexToHash(firstBatch.ParentBatchHash) {
return fmt.Errorf("parentBatchHash mismatch: calldata=%s, db=%s", calldataInfo.ParentBatchHash.Hex(), firstBatch.ParentBatchHash)
}

// Validate last batch hash
if calldataInfo.LastBatchHash != common.HexToHash(lastBatch.Hash) {
return fmt.Errorf("lastBatchHash mismatch: calldata=%s, db=%s", calldataInfo.LastBatchHash.Hex(), lastBatch.Hash)
}

// Get codec for blob decoding
codec, err := encoding.CodecFromVersion(encoding.CodecVersion(firstBatch.CodecVersion))
if err != nil {
return fmt.Errorf("failed to get codec: %w", err)
Expand All @@ -340,9 +335,7 @@ func (r *Layer2Relayer) validateBlobsAgainstDatabase(blobs []*kzg4844.Blob, batc
// Validate each blob against its corresponding batch
for i, blob := range blobs {
dbBatch := batchesToValidate[i].Batch
dbChunks := batchesToValidate[i].Chunks

if err := r.validateSingleBlobAgainstBatch(blob, dbBatch, dbChunks, codec); err != nil {
if err := r.validateSingleBlobAgainstBatch(calldataInfo, blob, dbBatch, codec); err != nil {
return fmt.Errorf("blob validation failed for batch %d: %w", dbBatch.Index, err)
}
}
Expand All @@ -351,53 +344,21 @@ func (r *Layer2Relayer) validateBlobsAgainstDatabase(blobs []*kzg4844.Blob, batc
}

// validateSingleBlobAgainstBatch validates a single blob against its batch data
func (r *Layer2Relayer) validateSingleBlobAgainstBatch(blob *kzg4844.Blob, dbBatch *orm.Batch, dbChunks []*orm.Chunk, codec encoding.Codec) error {
// Collect all blocks for the batch
var batchBlocks []*encoding.Block
for _, c := range dbChunks {
blocks, err := r.l2BlockOrm.GetL2BlocksInRange(r.ctx, c.StartBlockNumber, c.EndBlockNumber)
if err != nil {
return fmt.Errorf("failed to get blocks for chunk %d: %w", c.Index, err)
}

if len(blocks) == 0 {
return fmt.Errorf("chunk %d has no blocks in range [%d, %d]", c.Index, c.StartBlockNumber, c.EndBlockNumber)
}

// Verify block count matches expected range
expectedBlockCount := c.EndBlockNumber - c.StartBlockNumber + 1
if len(blocks) != int(expectedBlockCount) {
return fmt.Errorf("chunk %d expected %d blocks but got %d", c.Index, expectedBlockCount, len(blocks))
}

batchBlocks = append(batchBlocks, blocks...)
}

func (r *Layer2Relayer) validateSingleBlobAgainstBatch(calldataInfo *CalldataInfo, blob *kzg4844.Blob, dbBatch *orm.Batch, codec encoding.Codec) error {
// Decode blob payload
payload, err := codec.DecodeBlob(blob)
if err != nil {
return fmt.Errorf("failed to decode blob: %w", err)
}

// Validate L1 message queue hashes
if payload.PrevL1MessageQueueHash() != common.HexToHash(dbBatch.PrevL1MessageQueueHash) {
return fmt.Errorf("prevL1MessageQueueHash mismatch: decoded=%s, db=%s", payload.PrevL1MessageQueueHash().Hex(), dbBatch.PrevL1MessageQueueHash)
}

if payload.PostL1MessageQueueHash() != common.HexToHash(dbBatch.PostL1MessageQueueHash) {
return fmt.Errorf("postL1MessageQueueHash mismatch: decoded=%s, db=%s", payload.PostL1MessageQueueHash().Hex(), dbBatch.PostL1MessageQueueHash)
}

// Validate block data
decodedBlocks := payload.Blocks()
if len(decodedBlocks) != len(batchBlocks) {
return fmt.Errorf("block count mismatch: decoded=%d, db=%d", len(decodedBlocks), len(batchBlocks))
// Validate batch hash
daBatch, err := assembleDABatchFromPayload(calldataInfo, payload, dbBatch, codec)
if err != nil {
return fmt.Errorf("failed to assemble batch from payload: %w", err)
}

for j, dbBlock := range batchBlocks {
if decodedBlocks[j].Number() != dbBlock.Header.Number.Uint64() {
return fmt.Errorf("block number mismatch at index %d: decoded=%d, db=%d", j, decodedBlocks[j].Number(), dbBlock.Header.Number.Uint64())
}
if daBatch.Hash() != common.HexToHash(dbBatch.Hash) {
return fmt.Errorf("batch hash mismatch: decoded from blob=%s, db=%s", daBatch.Hash().Hex(), dbBatch.Hash)
}

return nil
Expand Down Expand Up @@ -436,3 +397,51 @@ func (r *Layer2Relayer) validateMessageQueueConsistency(batchIndex uint64, chunk

return nil
}

func assembleDABatchFromPayload(calldataInfo *CalldataInfo, payload encoding.DABlobPayload, dbBatch *orm.Batch, codec encoding.Codec) (encoding.DABatch, error) {
blocks, err := assembleBlocksFromPayload(payload)
if err != nil {
return nil, fmt.Errorf("failed to assemble blocks from payload batch_index=%d codec_version=%d parent_batch_hash=%s: %w", dbBatch.Index, dbBatch.CodecVersion, calldataInfo.ParentBatchHash.Hex(), err)
}
parentBatchHash := calldataInfo.ParentBatchHash
batch := &encoding.Batch{
Index: dbBatch.Index, // The database provides only batch index, other fields are derived from blob payload
ParentBatchHash: parentBatchHash,
PrevL1MessageQueueHash: payload.PrevL1MessageQueueHash(),
PostL1MessageQueueHash: payload.PostL1MessageQueueHash(),
Blocks: blocks,
Chunks: []*encoding.Chunk{ // One chunk for this batch to pass sanity checks when building DABatch
{
Blocks: blocks,
PrevL1MessageQueueHash: payload.PrevL1MessageQueueHash(),
PostL1MessageQueueHash: payload.PostL1MessageQueueHash(),
},
},
}
daBatch, err := codec.NewDABatch(batch)
if err != nil {
return nil, fmt.Errorf("failed to build DABatch batch_index=%d codec_version=%d parent_batch_hash=%s: %w", dbBatch.Index, dbBatch.CodecVersion, calldataInfo.ParentBatchHash.Hex(), err)
}
return daBatch, nil
}

func assembleBlocksFromPayload(payload encoding.DABlobPayload) ([]*encoding.Block, error) {
daBlocks := payload.Blocks()
txss := payload.Transactions()
if len(daBlocks) != len(txss) {
return nil, fmt.Errorf("mismatched number of blocks and transactions: %d blocks, %d transactions", len(daBlocks), len(txss))
}
blocks := make([]*encoding.Block, len(daBlocks))
for i := range daBlocks {
blocks[i] = &encoding.Block{
Header: &types.Header{
Number: new(big.Int).SetUint64(daBlocks[i].Number()),
Time: daBlocks[i].Timestamp(),
BaseFee: daBlocks[i].BaseFee(),
GasLimit: daBlocks[i].GasLimit(),
},
Transactions: encoding.TxsToTxsData(txss[i]),
}
}
return blocks, nil
}