diff --git a/devnet-sdk/system/interfaces.go b/devnet-sdk/system/interfaces.go index 84e328f3940ac..61a6d12234985 100644 --- a/devnet-sdk/system/interfaces.go +++ b/devnet-sdk/system/interfaces.go @@ -135,7 +135,6 @@ type InteropSet interface { // Supervisor provides access to the query interface of the supervisor type Supervisor interface { - CheckMessage(context.Context, supervisorTypes.Identifier, common.Hash, supervisorTypes.ExecutingDescriptor) (supervisorTypes.SafetyLevel, error) LocalUnsafe(context.Context, eth.ChainID) (eth.BlockID, error) CrossSafe(context.Context, eth.ChainID) (supervisorTypes.DerivedIDPair, error) Finalized(context.Context, eth.ChainID) (eth.BlockID, error) diff --git a/devnet-sdk/testing/systest/testing_test.go b/devnet-sdk/testing/systest/testing_test.go index be21f3c57e954..86e8418a15366 100644 --- a/devnet-sdk/testing/systest/testing_test.go +++ b/devnet-sdk/testing/systest/testing_test.go @@ -128,10 +128,6 @@ func (m *mockInteropSet) L2s() []system.Chain { return []system.Chain{&mockChain // mockSupervisor implements the system.Supervisor interface for testing type mockSupervisor struct{} -func (m *mockSupervisor) CheckMessage(ctx context.Context, id supervisorTypes.Identifier, hash common.Hash, desc supervisorTypes.ExecutingDescriptor) (supervisorTypes.SafetyLevel, error) { - return supervisorTypes.Invalid, nil -} - func (m *mockSupervisor) LocalUnsafe(ctx context.Context, chainID eth.ChainID) (eth.BlockID, error) { return eth.BlockID{}, nil } diff --git a/op-e2e/interop/interop_test.go b/op-e2e/interop/interop_test.go index 7fd364805399e..5a670fcedae03 100644 --- a/op-e2e/interop/interop_test.go +++ b/op-e2e/interop/interop_test.go @@ -197,8 +197,8 @@ func TestInterop_EmitLogs(t *testing.T) { supervisor := s2.SupervisorClient() - // helper function to turn a log into an identifier and the expected hash of the payload - logToIdentifier := func(chainID string, log gethTypes.Log) (types.Identifier, common.Hash) { + // helper function to turn a log into an access-list object + logToAccess := func(chainID string, log gethTypes.Log) types.Access { client := s2.L2GethClient(chainID, "sequencer") // construct the expected hash of the log's payload // (topics concatenated with data) @@ -207,7 +207,7 @@ func TestInterop_EmitLogs(t *testing.T) { msgPayload = append(msgPayload, topic.Bytes()...) } msgPayload = append(msgPayload, log.Data...) - expectedHash := common.BytesToHash(crypto.Keccak256(msgPayload)) + msgHash := crypto.Keccak256Hash(msgPayload) // get block for the log (for timestamp) ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) @@ -215,43 +215,36 @@ func TestInterop_EmitLogs(t *testing.T) { block, err := client.BlockByHash(ctx, log.BlockHash) require.NoError(t, err) - // make an identifier out of the sample log - identifier := types.Identifier{ - Origin: log.Address, + args := types.ChecksumArgs{ BlockNumber: log.BlockNumber, - LogIndex: uint32(log.Index), Timestamp: block.Time(), + LogIndex: uint32(log.Index), ChainID: eth.ChainIDFromBig(s2.ChainID(chainID)), + LogHash: types.PayloadHashToLogHash(msgHash, log.Address), } - return identifier, expectedHash + return args.Access() } - // all logs should be cross-safe - for _, log := range logsA { - identifier, expectedHash := logToIdentifier(chainA, log) - safety, err := supervisor.CheckMessage(context.Background(), identifier, expectedHash, types.ExecutingDescriptor{Timestamp: identifier.Timestamp}) - require.NoError(t, err) - // the supervisor could progress the safety level more quickly than we expect, - // which is why we check for a minimum safety level - require.True(t, safety.AtLeastAsSafe(types.CrossSafe), "log: %v should be at least Cross-Safe, but is %s", log, safety.String()) + var accessEntries []types.Access + for _, evLog := range logsA { + accessEntries = append(accessEntries, logToAccess(chainA, evLog)) } - for _, log := range logsB { - identifier, expectedHash := logToIdentifier(chainB, log) - safety, err := supervisor.CheckMessage(context.Background(), identifier, expectedHash, types.ExecutingDescriptor{Timestamp: identifier.Timestamp}) - require.NoError(t, err) - // the supervisor could progress the safety level more quickly than we expect, - // which is why we check for a minimum safety level - require.True(t, safety.AtLeastAsSafe(types.CrossSafe), "log: %v should be at least Cross-Safe, but is %s", log, safety.String()) + for _, evLog := range logsB { + accessEntries = append(accessEntries, logToAccess(chainB, evLog)) } + accessList := types.EncodeAccessList(accessEntries) - // a log should be invalid if the timestamp is incorrect - identifier, expectedHash := logToIdentifier(chainA, logsA[0]) - // make the timestamp incorrect - identifier.Timestamp = 333 - safety, err := supervisor.CheckMessage(context.Background(), identifier, expectedHash, types.ExecutingDescriptor{Timestamp: 333}) - require.NoError(t, err) - require.Equal(t, types.Invalid, safety) + timestamp := uint64(time.Now().Unix()) + ed := types.ExecutingDescriptor{Timestamp: timestamp} + ctx = context.Background() + err = supervisor.CheckAccessList(ctx, accessList, types.CrossSafe, ed) + require.NoError(t, err, "logsA must all be cross-safe") + // a log should be invalid if the timestamp is incorrect + accessEntries[0].Timestamp = 333 + accessList = types.EncodeAccessList(accessEntries) + err = supervisor.CheckAccessList(ctx, accessList, types.CrossSafe, ed) + require.ErrorContains(t, err, "conflict") } config := SuperSystemConfig{ mempoolFiltering: false, @@ -265,6 +258,10 @@ func TestInteropBlockBuilding(t *testing.T) { logger := testlog.Logger(t, log.LevelInfo) oplog.SetGlobalLogHandler(logger.Handler()) + // TODO(14697): re-enable once op-geth block-building uses access-lists. + // When re-enabling, the txs that execute other messages will need access-lists. + t.Skip("blocked by issue #14697") + test := func(t *testing.T, s2 SuperSystem) { ids := s2.L2IDs() chainA := ids[0] diff --git a/op-program/client/interop/consolidate.go b/op-program/client/interop/consolidate.go index e946aae7b37c4..d45628a596f37 100644 --- a/op-program/client/interop/consolidate.go +++ b/op-program/client/interop/consolidate.go @@ -275,9 +275,15 @@ func (d *consolidateCheckDeps) Contains(chain eth.ChainID, query supervisortypes for _, receipt := range receipts { for i, log := range receipt.Logs { if current+uint32(i) == query.LogIdx { - msgHash := logToMessageHash(log) - if msgHash != query.LogHash { - return supervisortypes.BlockSeal{}, fmt.Errorf("payload hash mismatch: %s != %s: %w", msgHash, query.LogHash, supervisortypes.ErrConflict) + checksum := supervisortypes.ChecksumArgs{ + BlockNumber: query.BlockNum, + LogIndex: query.LogIdx, + Timestamp: query.Timestamp, + ChainID: chain, + LogHash: logToLogHash(log), + }.Checksum() + if checksum != query.Checksum { + return supervisortypes.BlockSeal{}, fmt.Errorf("checksum mismatch: %s != %s: %w", checksum, query.Checksum, supervisortypes.ErrConflict) } else if block.Time() != query.Timestamp { return supervisortypes.BlockSeal{}, fmt.Errorf("block timestamp mismatch: %d != %d: %w", block.Time(), query.Timestamp, supervisortypes.ErrConflict) } else { @@ -294,7 +300,7 @@ func (d *consolidateCheckDeps) Contains(chain eth.ChainID, query supervisortypes return supervisortypes.BlockSeal{}, fmt.Errorf("log not found") } -func logToMessageHash(l *ethtypes.Log) common.Hash { +func logToLogHash(l *ethtypes.Log) common.Hash { payloadHash := crypto.Keccak256Hash(supervisortypes.LogToMessagePayload(l)) return supervisortypes.PayloadHashToLogHash(payloadHash, l.Address) } diff --git a/op-service/sources/supervisor_client.go b/op-service/sources/supervisor_client.go index 78d60814a8a8b..5641fca3f6264 100644 --- a/op-service/sources/supervisor_client.go +++ b/op-service/sources/supervisor_client.go @@ -20,9 +20,8 @@ type SupervisorAdminAPI interface { } type SupervisorQueryAPI interface { - CheckMessage(ctx context.Context, identifier types.Identifier, payloadHash common.Hash, executingDescriptor types.ExecutingDescriptor) (types.SafetyLevel, error) - CheckMessages(ctx context.Context, messages []types.Message, minSafety types.SafetyLevel) error - CheckMessagesV2(ctx context.Context, messages []types.Message, minSafety types.SafetyLevel, executingDescriptor types.ExecutingDescriptor) error + CheckAccessList(ctx context.Context, inboxEntries []common.Hash, + minSafety types.SafetyLevel, executingDescriptor types.ExecutingDescriptor) error CrossDerivedToSource(ctx context.Context, chainID eth.ChainID, derived eth.BlockID) (derivedFrom eth.BlockRef, err error) LocalUnsafe(ctx context.Context, chainID eth.ChainID) (eth.BlockID, error) CrossSafe(ctx context.Context, chainID eth.ChainID) (types.DerivedIDPair, error) @@ -74,29 +73,9 @@ func (cl *SupervisorClient) AddL2RPC(ctx context.Context, rpc string, auth eth.B return result } -func (cl *SupervisorClient) CheckMessage(ctx context.Context, identifier types.Identifier, logHash common.Hash, - executingDescriptor types.ExecutingDescriptor) (types.SafetyLevel, error) { - - var result types.SafetyLevel - err := cl.client.CallContext(ctx, &result, "supervisor_checkMessage", identifier, logHash, executingDescriptor) - if err != nil { - return types.Invalid, fmt.Errorf("failed to check message (chain %s), (block %v), (index %v), (logHash %s), (executingTimestamp %v): %w", - identifier.ChainID, - identifier.BlockNumber, - identifier.LogIndex, - logHash, - executingDescriptor.Timestamp, - err) - } - return result, nil -} - -func (cl *SupervisorClient) CheckMessages(ctx context.Context, messages []types.Message, minSafety types.SafetyLevel) error { - return cl.client.CallContext(ctx, nil, "supervisor_checkMessages", messages, minSafety) -} - -func (cl *SupervisorClient) CheckMessagesV2(ctx context.Context, messages []types.Message, minSafety types.SafetyLevel, executingDescriptor types.ExecutingDescriptor) error { - return cl.client.CallContext(ctx, nil, "supervisor_checkMessagesV2", messages, minSafety, executingDescriptor) +func (cl *SupervisorClient) CheckAccessList(ctx context.Context, inboxEntries []common.Hash, + minSafety types.SafetyLevel, executingDescriptor types.ExecutingDescriptor) error { + return cl.client.CallContext(ctx, nil, "supervisor_checkAccessList", inboxEntries, minSafety, executingDescriptor) } func (cl *SupervisorClient) CrossDerivedToSource(ctx context.Context, chainID eth.ChainID, derived eth.BlockID) (derivedFrom eth.BlockRef, err error) { diff --git a/op-supervisor/supervisor/backend/backend.go b/op-supervisor/supervisor/backend/backend.go index 65a99451eb9b2..8438a0a3a94c4 100644 --- a/op-supervisor/supervisor/backend/backend.go +++ b/op-supervisor/supervisor/backend/backend.go @@ -423,94 +423,82 @@ func (su *SupervisorBackend) DependencySet() depset.DependencySet { // Query methods // ---------------------------- -func (su *SupervisorBackend) CheckMessage(ctx context.Context, identifier types.Identifier, payloadHash common.Hash, executingDescriptor types.ExecutingDescriptor) (types.SafetyLevel, error) { - logHash := types.PayloadHashToLogHash(payloadHash, identifier.Origin) - chainID := identifier.ChainID - blockNum := identifier.BlockNumber - logIdx := identifier.LogIndex - _, err := su.chainDBs.Contains(chainID, - types.ContainsQuery{ - BlockNum: blockNum, - Timestamp: identifier.Timestamp, - LogIdx: logIdx, - LogHash: logHash, - }) - if errors.Is(err, types.ErrFuture) { - su.logger.Debug("Future message", "identifier", identifier, "payloadHash", payloadHash, "err", err) - return types.LocalUnsafe, nil - } - if errors.Is(err, types.ErrConflict) { - su.logger.Debug("Conflicting message", "identifier", identifier, "payloadHash", payloadHash, "err", err) - return types.Invalid, nil +// checkAccess checks message timestamp invariants and inclusion in the chain. +// If the initiating message exists, the block it is included in is returned. +func (su *SupervisorBackend) checkAccess(acc types.Access, execAt types.ExecutingDescriptor) (eth.BlockID, error) { + // Check if message passes time checks + if err := execAt.AccessCheck(su.depSet.MessageExpiryWindow(), acc.Timestamp); err != nil { + return eth.BlockID{}, err } + + // Check if message exists + bl, err := su.chainDBs.Contains(acc.ChainID, types.ContainsQuery{ + Timestamp: acc.Timestamp, + BlockNum: acc.BlockNumber, + LogIdx: acc.LogIndex, + Checksum: acc.Checksum, + }) if err != nil { - return types.Invalid, fmt.Errorf("failed to check log: %w", err) - } - if identifier.Timestamp+su.depSet.MessageExpiryWindow() < executingDescriptor.Timestamp { - su.logger.Debug("Message expired", "identifier", identifier, "payloadHash", payloadHash, "executingTimestamp", executingDescriptor.Timestamp) - return types.Invalid, nil + return eth.BlockID{}, err } - if identifier.Timestamp > executingDescriptor.Timestamp { - su.logger.Debug("Message timestamp is in the future", "identifier", identifier, "payloadHash", payloadHash, "executingTimestamp", executingDescriptor.Timestamp) - return types.Invalid, nil + return bl.ID(), nil +} + +// checkSafety is a helper method to check if a block has the given safety level. +// It is already assumed to exist in the canonical unsafe chain. +func (su *SupervisorBackend) checkSafety(chainID eth.ChainID, blockID eth.BlockID, safetyLevel types.SafetyLevel) error { + switch safetyLevel { + case types.LocalUnsafe: + return nil // msg exists, nothing more to check + case types.CrossUnsafe: + return su.chainDBs.IsCrossUnsafe(chainID, blockID) + case types.LocalSafe: + return su.chainDBs.IsLocalSafe(chainID, blockID) + case types.CrossSafe: + return su.chainDBs.IsCrossSafe(chainID, blockID) + case types.Finalized: + return su.chainDBs.IsFinalized(chainID, blockID) + default: + return types.ErrConflict } - return su.chainDBs.Safest(chainID, blockNum, logIdx) } -func (su *SupervisorBackend) CheckMessagesV2( - ctx context.Context, - messages []types.Message, - minSafety types.SafetyLevel, - executingDescriptor types.ExecutingDescriptor) error { - su.logger.Debug("Checking messages", "count", len(messages), "minSafety", minSafety, "executingTimestamp", executingDescriptor.Timestamp) +func (su *SupervisorBackend) CheckAccessList(ctx context.Context, inboxEntries []common.Hash, + minSafety types.SafetyLevel, executingDescriptor types.ExecutingDescriptor) error { + switch minSafety { + case types.LocalUnsafe, types.CrossUnsafe, types.LocalSafe, types.CrossSafe, types.Finalized: + // valid safety level + default: + return errors.New("unexpected min-safety level") + } - for _, msg := range messages { - su.logger.Debug("Checking message", - "identifier", msg.Identifier, "payloadHash", msg.PayloadHash.String(), "executingTimestamp", executingDescriptor.Timestamp) - safety, err := su.CheckMessage(ctx, msg.Identifier, msg.PayloadHash, executingDescriptor) - if err != nil { - su.logger.Error("Check message failed", "err", err, - "identifier", msg.Identifier, "payloadHash", msg.PayloadHash.String(), "executingTimestamp", executingDescriptor.Timestamp) - return fmt.Errorf("failed to check message: %w", err) + su.logger.Debug("Checking access-list", + "minSafety", minSafety, "length", len(inboxEntries)) + + // TODO(#14800): acquire a rewind-read-lock, so we can ensure the safety of all entries is consistent + + entries := inboxEntries + for len(entries) > 0 { + if err := ctx.Err(); err != nil { + return fmt.Errorf("stopped acces-list check early: %w", err) } - if !safety.AtLeastAsSafe(minSafety) { - su.logger.Error("Message is not sufficiently safe", - "safety", safety, "minSafety", minSafety, - "identifier", msg.Identifier, "payloadHash", msg.PayloadHash.String(), "executingTimestamp", executingDescriptor.Timestamp) - return fmt.Errorf("message %v (safety level: %v) does not meet the minimum safety %v", - msg.Identifier, - safety, - minSafety) + remaining, acc, err := types.ParseAccess(entries) + if err != nil { + return fmt.Errorf("failed to read data: %w", err) } - } - return nil -} + entries = remaining -func (su *SupervisorBackend) CheckMessages( - ctx context.Context, - messages []types.Message, - minSafety types.SafetyLevel) error { - su.logger.Debug("Checking messages", "count", len(messages), "minSafety", minSafety) - - for _, msg := range messages { - su.logger.Debug("Checking message", - "identifier", msg.Identifier, "payloadHash", msg.PayloadHash.String()) - // Guarantee message expiry checks do not fail by setting the executing timestamp to the message timestamp - // This is intentionally done to avoid breaking checkMessagesV1 which doesn't handle message expiry checks - safety, err := su.CheckMessage(ctx, msg.Identifier, msg.PayloadHash, types.ExecutingDescriptor{Timestamp: msg.Identifier.Timestamp}) + msgBlock, err := su.checkAccess(acc, executingDescriptor) if err != nil { - su.logger.Error("Check message failed", "err", err, - "identifier", msg.Identifier, "payloadHash", msg.PayloadHash.String()) - return fmt.Errorf("failed to check message: %w", err) + su.logger.Debug("Access-list inclusion check failed", "err", err) + return types.ErrConflict } - if !safety.AtLeastAsSafe(minSafety) { - su.logger.Error("Message is not sufficiently safe", - "safety", safety, "minSafety", minSafety, - "identifier", msg.Identifier, "payloadHash", msg.PayloadHash.String()) - return fmt.Errorf("message %v (safety level: %v) does not meet the minimum safety %v", - msg.Identifier, - safety, - minSafety) + // TODO(#14800) add msgBlock to rewind lock + + // TODO(#14800): this can be deferred to only check the latest block of all access entries + if err := su.checkSafety(acc.ChainID, msgBlock, minSafety); err != nil { + su.logger.Debug("Access-list safety check failed", "err", err) + return types.ErrConflict } } return nil diff --git a/op-supervisor/supervisor/backend/cross/hazard_set.go b/op-supervisor/supervisor/backend/cross/hazard_set.go index c13e99b2a4231..4bc74f8c38dd9 100644 --- a/op-supervisor/supervisor/backend/cross/hazard_set.go +++ b/op-supervisor/supervisor/backend/cross/hazard_set.go @@ -147,13 +147,14 @@ func (h *HazardSet) build(deps HazardDeps, logger log.Logger, chainID eth.ChainI if err := h.checkChainCanInitiate(depSet, srcChainID, candidate, msg); err != nil { return err } - includedIn, err := deps.Contains(srcChainID, - types.ContainsQuery{ - Timestamp: msg.Timestamp, - BlockNum: msg.BlockNum, - LogIdx: msg.LogIdx, - LogHash: msg.Hash, - }) + q := types.ChecksumArgs{ + BlockNumber: msg.BlockNum, + LogIndex: msg.LogIdx, + Timestamp: msg.Timestamp, + ChainID: srcChainID, + LogHash: msg.Hash, + }.Query() + includedIn, err := deps.Contains(srcChainID, q) if err != nil { return fmt.Errorf("executing msg %s failed inclusion check: %w", msg, err) } diff --git a/op-supervisor/supervisor/backend/cross/safe_update_test.go b/op-supervisor/supervisor/backend/cross/safe_update_test.go index af783e52c67f8..891610fe297c9 100644 --- a/op-supervisor/supervisor/backend/cross/safe_update_test.go +++ b/op-supervisor/supervisor/backend/cross/safe_update_test.go @@ -8,7 +8,6 @@ import ( "github.com/ethereum-optimism/optimism/op-service/testlog" "github.com/ethereum-optimism/optimism/op-supervisor/supervisor/backend/depset" "github.com/ethereum-optimism/optimism/op-supervisor/supervisor/types" - "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/log" "github.com/stretchr/testify/require" ) @@ -31,7 +30,7 @@ func TestCrossSafeUpdate(t *testing.T) { csd.openBlockFn = func(chainID eth.ChainID, blockNum uint64) (ref eth.BlockRef, logCount uint32, execMsgs map[uint32]*types.ExecutingMessage, err error) { return opened, 10, execs, nil } - csd.checkFn = func(chainID eth.ChainID, blockNum uint64, logIdx uint32, logHash common.Hash) (types.BlockSeal, error) { + csd.checkFn = func(chainID eth.ChainID, blockNum uint64, logIdx uint32, checksum types.MessageChecksum) (types.BlockSeal, error) { return types.BlockSeal{Number: 1, Timestamp: 1}, nil } csd.deps = mockDependencySet{} @@ -126,7 +125,7 @@ func TestCrossSafeUpdate(t *testing.T) { csd.openBlockFn = func(chainID eth.ChainID, blockNum uint64) (ref eth.BlockRef, logCount uint32, execMsgs map[uint32]*types.ExecutingMessage, err error) { return opened, 10, execs, nil } - csd.checkFn = func(chainID eth.ChainID, blockNum uint64, logIdx uint32, logHash common.Hash) (types.BlockSeal, error) { + csd.checkFn = func(chainID eth.ChainID, blockNum uint64, logIdx uint32, checksum types.MessageChecksum) (types.BlockSeal, error) { return types.BlockSeal{Number: 1, Timestamp: 1}, nil } invalidated := false @@ -362,7 +361,7 @@ func TestScopedCrossSafeUpdate(t *testing.T) { csd.openBlockFn = func(chainID eth.ChainID, blockNum uint64) (ref eth.BlockRef, logCount uint32, execMsgs map[uint32]*types.ExecutingMessage, err error) { return opened, 10, execs, nil } - csd.checkFn = func(chainID eth.ChainID, blockNum uint64, logIdx uint32, logHash common.Hash) (types.BlockSeal, error) { + csd.checkFn = func(chainID eth.ChainID, blockNum uint64, logIdx uint32, checksum types.MessageChecksum) (types.BlockSeal, error) { return types.BlockSeal{Number: 1, Timestamp: 1}, nil } count := 0 @@ -401,7 +400,7 @@ func TestScopedCrossSafeUpdate(t *testing.T) { csd.openBlockFn = func(chainID eth.ChainID, blockNum uint64) (ref eth.BlockRef, logCount uint32, execMsgs map[uint32]*types.ExecutingMessage, err error) { return opened, 3, map[uint32]*types.ExecutingMessage{1: em1, 2: em2}, nil } - csd.checkFn = func(chainID eth.ChainID, blockNum uint64, logIdx uint32, logHash common.Hash) (types.BlockSeal, error) { + csd.checkFn = func(chainID eth.ChainID, blockNum uint64, logIdx uint32, checksum types.MessageChecksum) (types.BlockSeal, error) { return types.BlockSeal{Number: 1, Timestamp: 1}, nil } csd.deps = mockDependencySet{} @@ -429,7 +428,7 @@ func TestScopedCrossSafeUpdate(t *testing.T) { csd.openBlockFn = func(chainID eth.ChainID, blockNum uint64) (ref eth.BlockRef, logCount uint32, execMsgs map[uint32]*types.ExecutingMessage, err error) { return opened, 10, execs, nil } - csd.checkFn = func(chainID eth.ChainID, blockNum uint64, logIdx uint32, logHash common.Hash) (types.BlockSeal, error) { + csd.checkFn = func(chainID eth.ChainID, blockNum uint64, logIdx uint32, checksum types.MessageChecksum) (types.BlockSeal, error) { return types.BlockSeal{Number: 1, Timestamp: 1}, nil } csd.deps = mockDependencySet{} @@ -497,7 +496,7 @@ func TestScopedCrossSafeUpdate(t *testing.T) { // when no errors occur, the update is carried out // the used candidate and scope are from CandidateCrossSafe // the candidateScope is returned - csd.checkFn = func(chainID eth.ChainID, blockNum uint64, logIdx uint32, logHash common.Hash) (types.BlockSeal, error) { + csd.checkFn = func(chainID eth.ChainID, blockNum uint64, logIdx uint32, checksum types.MessageChecksum) (types.BlockSeal, error) { return types.BlockSeal{Number: 1, Timestamp: 1}, nil } pair, err := scopedCrossSafeUpdate(logger, chainID, csd) @@ -517,7 +516,7 @@ type mockCrossSafeDeps struct { updateCrossSafeFn func(chain eth.ChainID, l1View eth.BlockRef, lastCrossDerived eth.BlockRef) error nextSourceFn func(chain eth.ChainID, source eth.BlockID) (after eth.BlockRef, err error) previousDerivedFn func(chain eth.ChainID, derived eth.BlockID) (prevDerived types.BlockSeal, err error) - checkFn func(chainID eth.ChainID, blockNum uint64, logIdx uint32, logHash common.Hash) (types.BlockSeal, error) + checkFn func(chainID eth.ChainID, blockNum uint64, logIdx uint32, checksum types.MessageChecksum) (types.BlockSeal, error) invalidateLocalSafeFn func(chainID eth.ChainID, candidate types.DerivedBlockRefPair) error } @@ -547,7 +546,7 @@ func (m *mockCrossSafeDeps) CrossDerivedToSource(chainID eth.ChainID, derived et func (m *mockCrossSafeDeps) Contains(chainID eth.ChainID, q types.ContainsQuery) (types.BlockSeal, error) { if m.checkFn != nil { - return m.checkFn(chainID, q.BlockNum, q.LogIdx, q.LogHash) + return m.checkFn(chainID, q.BlockNum, q.LogIdx, q.Checksum) } return types.BlockSeal{}, nil } diff --git a/op-supervisor/supervisor/backend/cross/unsafe_update_test.go b/op-supervisor/supervisor/backend/cross/unsafe_update_test.go index 3f291a27cd651..c1f6fba990496 100644 --- a/op-supervisor/supervisor/backend/cross/unsafe_update_test.go +++ b/op-supervisor/supervisor/backend/cross/unsafe_update_test.go @@ -138,7 +138,7 @@ func TestCrossUnsafeUpdate(t *testing.T) { usd.openBlockFn = func(chainID eth.ChainID, blockNum uint64) (ref eth.BlockRef, logCount uint32, execMsgs map[uint32]*types.ExecutingMessage, err error) { return bl, 3, map[uint32]*types.ExecutingMessage{1: em1, 2: em2}, nil } - usd.checkFn = func(chainID eth.ChainID, blockNum uint64, timestamp uint64, logIdx uint32, logHash common.Hash) (types.BlockSeal, error) { + usd.checkFn = func(chainID eth.ChainID, blockNum uint64, timestamp uint64, logIdx uint32, checksum types.MessageChecksum) (types.BlockSeal, error) { return types.BlockSeal{Number: 1, Timestamp: 1}, nil } usd.deps = mockDependencySet{} @@ -167,7 +167,7 @@ func TestCrossUnsafeUpdate(t *testing.T) { Hash: crossUnsafe.Hash, }, 0, nil, nil } - usd.checkFn = func(chainID eth.ChainID, blockNum uint64, timestamp uint64, logIdx uint32, logHash common.Hash) (types.BlockSeal, error) { + usd.checkFn = func(chainID eth.ChainID, blockNum uint64, timestamp uint64, logIdx uint32, checksum types.MessageChecksum) (types.BlockSeal, error) { return crossUnsafe, nil } usd.deps = mockDependencySet{} @@ -193,7 +193,7 @@ type mockCrossUnsafeDeps struct { crossUnsafeFn func(chainID eth.ChainID) (types.BlockSeal, error) openBlockFn func(chainID eth.ChainID, blockNum uint64) (ref eth.BlockRef, logCount uint32, execMsgs map[uint32]*types.ExecutingMessage, err error) updateCrossUnsafeFn func(chain eth.ChainID, crossUnsafe types.BlockSeal) error - checkFn func(chainID eth.ChainID, blockNum uint64, timestamp uint64, logIdx uint32, logHash common.Hash) (types.BlockSeal, error) + checkFn func(chainID eth.ChainID, blockNum uint64, timestamp uint64, logIdx uint32, checksum types.MessageChecksum) (types.BlockSeal, error) } func (m *mockCrossUnsafeDeps) CrossUnsafe(chainID eth.ChainID) (derived types.BlockSeal, err error) { @@ -213,7 +213,7 @@ func (m *mockCrossUnsafeDeps) MessageExpiryWindow() uint64 { func (m *mockCrossUnsafeDeps) Contains(chainID eth.ChainID, q types.ContainsQuery) (types.BlockSeal, error) { if m.checkFn != nil { - return m.checkFn(chainID, q.BlockNum, q.Timestamp, q.LogIdx, q.LogHash) + return m.checkFn(chainID, q.BlockNum, q.Timestamp, q.LogIdx, q.Checksum) } return types.BlockSeal{}, nil } diff --git a/op-supervisor/supervisor/backend/db/logs/db.go b/op-supervisor/supervisor/backend/db/logs/db.go index 51bd8cd330a29..55f61fe1d8ec6 100644 --- a/op-supervisor/supervisor/backend/db/logs/db.go +++ b/op-supervisor/supervisor/backend/db/logs/db.go @@ -38,22 +38,25 @@ type DB struct { store entrydb.EntryStore[EntryType, Entry] rwLock sync.RWMutex + chainID eth.ChainID + lastEntryContext logContext } -func NewFromFile(logger log.Logger, m Metrics, path string, trimToLastSealed bool) (*DB, error) { +func NewFromFile(logger log.Logger, m Metrics, chainID eth.ChainID, path string, trimToLastSealed bool) (*DB, error) { store, err := entrydb.NewEntryDB[EntryType, Entry, EntryBinary](logger, path) if err != nil { return nil, fmt.Errorf("failed to open DB: %w", err) } - return NewFromEntryStore(logger, m, store, trimToLastSealed) + return NewFromEntryStore(logger, m, chainID, store, trimToLastSealed) } -func NewFromEntryStore(logger log.Logger, m Metrics, store entrydb.EntryStore[EntryType, Entry], trimToLastSealed bool) (*DB, error) { +func NewFromEntryStore(logger log.Logger, m Metrics, chainID eth.ChainID, store entrydb.EntryStore[EntryType, Entry], trimToLastSealed bool) (*DB, error) { db := &DB{ - log: logger, - m: m, - store: store, + log: logger, + m: m, + store: store, + chainID: chainID, } if err := db.init(trimToLastSealed); err != nil { return nil, fmt.Errorf("failed to init database: %w", err) @@ -266,15 +269,6 @@ func (db *DB) LatestSealedBlock() (id eth.BlockID, ok bool) { }, true } -// Get returns the hash of the log at the specified blockNum (of the sealed block) -// and logIdx (of the log after the block), or an error if the log is not found. -func (db *DB) Get(blockNum uint64, logIdx uint32) (common.Hash, error) { - db.rwLock.RLock() - defer db.rwLock.RUnlock() - hash, _, err := db.findLogInfo(blockNum, logIdx) - return hash, err -} - // Contains returns no error iff the specified logHash is recorded in the specified blockNum and logIdx. // If the log is out of reach, then ErrFuture is returned. // If the log is determined to conflict with the canonical chain, then ErrConflict is returned. @@ -283,10 +277,10 @@ func (db *DB) Get(blockNum uint64, logIdx uint32) (common.Hash, error) { // The block-seal of the blockNum block, that the log was included in, is returned. // This seal may be fully zeroed, without error, if the block isn't fully known yet. func (db *DB) Contains(query types.ContainsQuery) (types.BlockSeal, error) { - blockNum, logIdx, logHash, timestamp := query.BlockNum, query.LogIdx, query.LogHash, query.Timestamp + blockNum, logIdx, timestamp := query.BlockNum, query.LogIdx, query.Timestamp db.rwLock.RLock() defer db.rwLock.RUnlock() - db.log.Trace("Checking for log", "blockNum", blockNum, "logIdx", logIdx, "hash", logHash) + db.log.Trace("Checking for log", "blockNum", blockNum, "logIdx", logIdx) // Hot-path: check if we have the block if db.lastEntryContext.hasCompleteBlock() && db.lastEntryContext.blockNum < blockNum { @@ -299,15 +293,11 @@ func (db *DB) Contains(query types.ContainsQuery) (types.BlockSeal, error) { return types.BlockSeal{}, types.ErrFuture } - evtHash, iter, err := db.findLogInfo(blockNum, logIdx) + entryLogHash, iter, err := db.findLogInfo(blockNum, logIdx) if err != nil { return types.BlockSeal{}, err // may be ErrConflict if the block does not have as many logs } - db.log.Trace("Found initiatingEvent", "blockNum", blockNum, "logIdx", logIdx, "hash", evtHash) - // Found the requested block and log index, check if the hash matches - if evtHash != logHash { - return types.BlockSeal{}, fmt.Errorf("payload hash mismatch: expected %s, got %s %w", logHash, evtHash, types.ErrConflict) - } + db.log.Trace("Found initiatingEvent", "blockNum", blockNum, "logIdx", logIdx, "hash", entryLogHash) // Now find the block seal after the log, to identify where the log was included in. err = iter.TraverseConditional(func(state IteratorState) error { _, n, ok := state.SealedBlock() @@ -336,6 +326,17 @@ func (db *DB) Contains(query types.ContainsQuery) (types.BlockSeal, error) { if t != timestamp { return types.BlockSeal{}, fmt.Errorf("timestamp mismatch: expected %d, got %d %w", timestamp, t, types.ErrConflict) } + entryChecksum := types.ChecksumArgs{ + BlockNumber: n, + LogIndex: logIdx, + Timestamp: t, + ChainID: db.chainID, + LogHash: entryLogHash, + }.Checksum() + // Found the requested block and log index, check if the hash matches + if entryChecksum != query.Checksum { + return types.BlockSeal{}, fmt.Errorf("payload hash mismatch: expected %s, got %s %w", query.Checksum, entryChecksum, types.ErrConflict) + } // construct a block seal with the found data now that we know it's correct return types.BlockSeal{ Hash: h, diff --git a/op-supervisor/supervisor/backend/db/logs/db_test.go b/op-supervisor/supervisor/backend/db/logs/db_test.go index 299680421104d..89f603fb158db 100644 --- a/op-supervisor/supervisor/backend/db/logs/db_test.go +++ b/op-supervisor/supervisor/backend/db/logs/db_test.go @@ -38,7 +38,8 @@ func createHash(i int) common.Hash { func TestErrorOpeningDatabase(t *testing.T) { dir := t.TempDir() - _, err := NewFromFile(testlog.Logger(t, log.LvlInfo), &stubMetrics{}, filepath.Join(dir, "missing-dir", "file.db"), false) + chainID := eth.ChainIDFromUInt64(123) + _, err := NewFromFile(testlog.Logger(t, log.LvlInfo), &stubMetrics{}, chainID, filepath.Join(dir, "missing-dir", "file.db"), false) require.ErrorIs(t, err, os.ErrNotExist) } @@ -47,7 +48,8 @@ func runDBTest(t *testing.T, setup func(t *testing.T, db *DB, m *stubMetrics), a logger := testlog.Logger(t, log.LvlTrace) path := filepath.Join(dir, "test.db") m := &stubMetrics{} - db, err := NewFromFile(logger, m, path, false) + chainID := eth.ChainIDFromUInt64(123) + db, err := NewFromFile(logger, m, chainID, path, false) require.NoError(t, err, "Failed to create database") t.Cleanup(func() { err := db.Close() @@ -852,12 +854,14 @@ func requireContains(t *testing.T, db *DB, blockNum uint64, logIdx uint32, times require.LessOrEqual(t, len(execMsg), 1, "cannot have multiple executing messages for a single log") m, ok := db.m.(*stubMetrics) require.True(t, ok, "Did not get the expected metrics type") - _, err := db.Contains(types.ContainsQuery{ - Timestamp: timestamp, - BlockNum: blockNum, - LogIdx: logIdx, - LogHash: logHash, - }) + q := types.ChecksumArgs{ + BlockNumber: blockNum, + LogIndex: logIdx, + Timestamp: timestamp, + ChainID: db.chainID, + LogHash: logHash, + }.Query() + _, err := db.Contains(q) require.NoErrorf(t, err, "Error searching for log %v in block %v", logIdx, blockNum) require.LessOrEqual(t, m.entriesReadForSearch, int64(searchCheckpointFrequency*2), "Should not need to read more than between two checkpoints") require.NotZero(t, m.entriesReadForSearch, "Must read at least some entries to find the log") @@ -872,12 +876,14 @@ func requireContains(t *testing.T, db *DB, blockNum uint64, logIdx uint32, times func requireConflicts(t *testing.T, db *DB, blockNum uint64, logIdx uint32, timestamp uint64, logHash common.Hash) { m, ok := db.m.(*stubMetrics) require.True(t, ok, "Did not get the expected metrics type") - _, err := db.Contains(types.ContainsQuery{ - Timestamp: timestamp, - BlockNum: blockNum, - LogIdx: logIdx, - LogHash: logHash, - }) + q := types.ChecksumArgs{ + BlockNumber: blockNum, + LogIndex: logIdx, + Timestamp: timestamp, + ChainID: db.chainID, + LogHash: logHash, + }.Query() + _, err := db.Contains(q) require.ErrorIs(t, err, types.ErrConflict, "canonical chain must not include this log") require.LessOrEqual(t, m.entriesReadForSearch, int64(searchCheckpointFrequency*2), "Should not need to read more than between two checkpoints") } @@ -885,12 +891,14 @@ func requireConflicts(t *testing.T, db *DB, blockNum uint64, logIdx uint32, time func requireFuture(t *testing.T, db *DB, blockNum uint64, logIdx uint32, timestamp uint64, logHash common.Hash) { m, ok := db.m.(*stubMetrics) require.True(t, ok, "Did not get the expected metrics type") - _, err := db.Contains(types.ContainsQuery{ - Timestamp: timestamp, - BlockNum: blockNum, - LogIdx: logIdx, - LogHash: logHash, - }) + q := types.ChecksumArgs{ + BlockNumber: blockNum, + LogIndex: logIdx, + Timestamp: timestamp, + ChainID: db.chainID, + LogHash: logHash, + }.Query() + _, err := db.Contains(q) require.ErrorIs(t, err, types.ErrFuture, "canonical chain does not yet include this log") require.LessOrEqual(t, m.entriesReadForSearch, int64(searchCheckpointFrequency*2), "Should not need to read more than between two checkpoints") } @@ -912,10 +920,11 @@ func requireExecutingMessage(t *testing.T, db *DB, blockNum uint64, logIdx uint3 } func TestRecoverOnCreate(t *testing.T) { + chainID := eth.ChainIDFromUInt64(123) createDb := func(t *testing.T, store *entrydb.MemEntryStore[EntryType, Entry]) (*DB, *stubMetrics, error) { logger := testlog.Logger(t, log.LvlInfo) m := &stubMetrics{} - db, err := NewFromEntryStore(logger, m, store, true) + db, err := NewFromEntryStore(logger, m, chainID, store, true) return db, m, err } diff --git a/op-supervisor/supervisor/backend/db/open.go b/op-supervisor/supervisor/backend/db/open.go index 8e6b44edea080..4129175202e3d 100644 --- a/op-supervisor/supervisor/backend/db/open.go +++ b/op-supervisor/supervisor/backend/db/open.go @@ -15,7 +15,7 @@ func OpenLogDB(logger log.Logger, chainID eth.ChainID, dataDir string, m logs.Me if err != nil { return nil, fmt.Errorf("failed to create datadir for chain %s: %w", chainID, err) } - logDB, err := logs.NewFromFile(logger, m, path, true) + logDB, err := logs.NewFromFile(logger, m, chainID, path, true) if err != nil { return nil, fmt.Errorf("failed to create logdb for chain %s at %v: %w", chainID, path, err) } diff --git a/op-supervisor/supervisor/backend/db/query.go b/op-supervisor/supervisor/backend/db/query.go index 8d4e7a95e9a51..400d95ec4da71 100644 --- a/op-supervisor/supervisor/backend/db/query.go +++ b/op-supervisor/supervisor/backend/db/query.go @@ -98,6 +98,21 @@ func (db *ChainsDB) IsLocalSafe(chainID eth.ChainID, block eth.BlockID) error { return ldb.ContainsDerived(block) } +func (db *ChainsDB) IsFinalized(chainID eth.ChainID, block eth.BlockID) error { + finL1 := db.FinalizedL1() + if finL1 == (eth.BlockRef{}) { + return types.ErrUninitialized + } + source, err := db.CrossDerivedToSource(chainID, block) + if err != nil { + return fmt.Errorf("failed to get cross-safe source: %w", err) + } + if finL1.Number >= source.Number { + return nil + } + return fmt.Errorf("cross-safe source block is not finalized: %w", types.ErrFuture) +} + func (db *ChainsDB) SafeDerivedAt(chainID eth.ChainID, source eth.BlockID) (types.BlockSeal, error) { lDB, ok := db.localDBs.Get(chainID) if !ok { @@ -403,41 +418,6 @@ func (db *ChainsDB) NextSource(chain eth.ChainID, source eth.BlockID) (after eth return v.MustWithParent(source), nil } -// Safest returns the strongest safety level that can be guaranteed for the given log entry. -// it assumes the log entry has already been checked and is valid, this function only checks safety levels. -// Safety levels are assumed to graduate from LocalUnsafe to LocalSafe to CrossUnsafe to CrossSafe, with Finalized as the strongest. -func (db *ChainsDB) Safest(chainID eth.ChainID, blockNum uint64, index uint32) (safest types.SafetyLevel, err error) { - if finalized, err := db.Finalized(chainID); err == nil { - if finalized.Number >= blockNum { - return types.Finalized, nil - } - } - crossSafe, err := db.CrossSafe(chainID) - if err != nil { - return types.Invalid, err - } - if crossSafe.Derived.Number >= blockNum { - return types.CrossSafe, nil - } - crossUnsafe, err := db.CrossUnsafe(chainID) - if err != nil { - return types.Invalid, err - } - // TODO(#12425): API: "index" for in-progress block building shouldn't be exposed from DB. - // For now we're not counting anything cross-safe until the block is sealed. - if blockNum <= crossUnsafe.Number { - return types.CrossUnsafe, nil - } - localSafe, err := db.LocalSafe(chainID) - if err != nil { - return types.Invalid, err - } - if blockNum <= localSafe.Derived.Number { - return types.LocalSafe, nil - } - return types.LocalUnsafe, nil -} - func (db *ChainsDB) IteratorStartingAt(chain eth.ChainID, sealedNum uint64, logIndex uint32) (logs.Iterator, error) { logDB, ok := db.logDBs.Get(chain) if !ok { diff --git a/op-supervisor/supervisor/backend/mock.go b/op-supervisor/supervisor/backend/mock.go index 94a6f4682ae0e..98157ddab821f 100644 --- a/op-supervisor/supervisor/backend/mock.go +++ b/op-supervisor/supervisor/backend/mock.go @@ -47,15 +47,8 @@ func (m *MockBackend) AddL2RPC(ctx context.Context, rpc string, jwtSecret eth.By return nil } -func (m *MockBackend) CheckMessage(ctx context.Context, identifier types.Identifier, payloadHash common.Hash, executingDescriptor types.ExecutingDescriptor) (types.SafetyLevel, error) { - return types.CrossUnsafe, nil -} - -func (m *MockBackend) CheckMessages(ctx context.Context, messages []types.Message, minSafety types.SafetyLevel) error { - return nil -} - -func (m *MockBackend) CheckMessagesV2(ctx context.Context, messages []types.Message, minSafety types.SafetyLevel, executingDescriptor types.ExecutingDescriptor) error { +func (m *MockBackend) CheckAccessList(ctx context.Context, inboxEntries []common.Hash, + minSafety types.SafetyLevel, executingDescriptor types.ExecutingDescriptor) error { return nil } diff --git a/op-supervisor/supervisor/backend/rewinder/rewinder_test.go b/op-supervisor/supervisor/backend/rewinder/rewinder_test.go index 19353fe0ca783..ab358d13ea450 100644 --- a/op-supervisor/supervisor/backend/rewinder/rewinder_test.go +++ b/op-supervisor/supervisor/backend/rewinder/rewinder_test.go @@ -1455,7 +1455,7 @@ func setupTestChains(t *testing.T, chainIDs ...eth.ChainID) *testSetup { require.NoError(t, err) // Create and open the log DB - logDB, err := logs.NewFromFile(logger, &stubMetrics{}, filepath.Join(chainDir, "log.db"), true) + logDB, err := logs.NewFromFile(logger, &stubMetrics{}, chainID, filepath.Join(chainDir, "log.db"), true) require.NoError(t, err) chainsDB.AddLogDB(chainID, logDB) diff --git a/op-supervisor/supervisor/frontend/frontend.go b/op-supervisor/supervisor/frontend/frontend.go index 5b4b239cd12cd..fce86030488fa 100644 --- a/op-supervisor/supervisor/frontend/frontend.go +++ b/op-supervisor/supervisor/frontend/frontend.go @@ -22,30 +22,9 @@ type QueryFrontend struct { var _ sources.SupervisorQueryAPI = (*QueryFrontend)(nil) -// CheckMessage checks the safety-level of an individual message. -// The payloadHash references the hash of the message-payload of the message. -func (q *QueryFrontend) CheckMessage(ctx context.Context, identifier types.Identifier, payloadHash common.Hash, executingDescriptor types.ExecutingDescriptor) (types.SafetyLevel, error) { - return q.Supervisor.CheckMessage(ctx, identifier, payloadHash, executingDescriptor) -} - -// CheckMessagesV2 checks the safety-level of a collection of messages, -// and returns if the minimum safety-level is met for all messages. -func (q *QueryFrontend) CheckMessagesV2( - ctx context.Context, - messages []types.Message, - minSafety types.SafetyLevel, - executingDescriptor types.ExecutingDescriptor) error { - return q.Supervisor.CheckMessagesV2(ctx, messages, minSafety, executingDescriptor) -} - -// CheckMessages checks the safety-level of a collection of messages, -// and returns if the minimum safety-level is met for all messages. -// Deprecated: This method does not check for message expiry. -func (q *QueryFrontend) CheckMessages( - ctx context.Context, - messages []types.Message, - minSafety types.SafetyLevel) error { - return q.Supervisor.CheckMessages(ctx, messages, minSafety) +func (q *QueryFrontend) CheckAccessList(ctx context.Context, inboxEntries []common.Hash, + minSafety types.SafetyLevel, executingDescriptor types.ExecutingDescriptor) error { + return q.Supervisor.CheckAccessList(ctx, inboxEntries, minSafety, executingDescriptor) } func (q *QueryFrontend) LocalUnsafe(ctx context.Context, chainID eth.ChainID) (eth.BlockID, error) { diff --git a/op-supervisor/supervisor/service_test.go b/op-supervisor/supervisor/service_test.go index 01003e17f6e75..b7508ce909b8f 100644 --- a/op-supervisor/supervisor/service_test.go +++ b/op-supervisor/supervisor/service_test.go @@ -65,18 +65,10 @@ func TestSupervisorService(t *testing.T) { cl, err := dial.DialRPCClientWithTimeout(context.Background(), time.Second*5, logger, endpoint) require.NoError(t, err) ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) - var dest types.SafetyLevel - err = cl.CallContext(ctx, &dest, "supervisor_checkMessage", - types.Identifier{ - Origin: common.Address{0xaa}, - BlockNumber: 123, - LogIndex: 42, - Timestamp: 1234567, - ChainID: eth.ChainID{0xbb}, - }, common.Hash{0xcc}, types.ExecutingDescriptor{Timestamp: 1234568}) + err = cl.CallContext(ctx, nil, "supervisor_checkAccessList", + []common.Hash{}, types.CrossUnsafe, types.ExecutingDescriptor{Timestamp: 1234568}) cancel() require.NoError(t, err) - require.Equal(t, types.CrossUnsafe, dest, "expecting mock to return cross-unsafe") cl.Close() } require.NoError(t, supervisor.Stop(context.Background()), "stop service") diff --git a/op-supervisor/supervisor/types/types.go b/op-supervisor/supervisor/types/types.go index 6302a7086bf6e..46908ecce2e35 100644 --- a/op-supervisor/supervisor/types/types.go +++ b/op-supervisor/supervisor/types/types.go @@ -1,22 +1,25 @@ package types import ( + "encoding/binary" "encoding/json" "errors" "fmt" "math" "strconv" - ethTypes "github.com/ethereum/go-ethereum/core/types" + "github.com/holiman/uint256" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common/hexutil" + ethTypes "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum-optimism/optimism/op-service/eth" ) // ChainIndex represents the lifetime of a chain in a dependency set. +// Warning: JSON-encoded as string, in base-10. type ChainIndex uint32 func (ci ChainIndex) String() string { @@ -42,7 +45,7 @@ type ContainsQuery struct { Timestamp uint64 BlockNum uint64 LogIdx uint32 - LogHash common.Hash // LogHash commits to the origin-address and the message payload-hash + Checksum MessageChecksum } type ExecutingMessage struct { @@ -50,7 +53,7 @@ type ExecutingMessage struct { BlockNum uint64 LogIdx uint32 Timestamp uint64 - Hash common.Hash + Hash common.Hash // LogHash (hash of msgHash and origin address) } func (s *ExecutingMessage) String() string { @@ -63,6 +66,56 @@ type Message struct { PayloadHash common.Hash `json:"payloadHash"` } +func (m *Message) Checksum() MessageChecksum { + args := ChecksumArgs{ + BlockNumber: m.Identifier.BlockNumber, + LogIndex: m.Identifier.LogIndex, + Timestamp: m.Identifier.Timestamp, + ChainID: m.Identifier.ChainID, + LogHash: PayloadHashToLogHash(m.PayloadHash, m.Identifier.Origin), + } + return args.Checksum() +} + +type ChecksumArgs struct { + BlockNumber uint64 + LogIndex uint32 + Timestamp uint64 + ChainID eth.ChainID + LogHash common.Hash +} + +func (args ChecksumArgs) Checksum() MessageChecksum { + idPacked := make([]byte, 12, 32) // 12 zero bytes, as padding to 32 bytes + idPacked = binary.BigEndian.AppendUint64(idPacked, args.BlockNumber) + idPacked = binary.BigEndian.AppendUint64(idPacked, args.Timestamp) + idPacked = binary.BigEndian.AppendUint32(idPacked, args.LogIndex) + idLogHash := crypto.Keccak256Hash(args.LogHash[:], idPacked) + chainID := args.ChainID.Bytes32() + out := crypto.Keccak256Hash(idLogHash[:], chainID[:]) + out[0] = 0x03 // type/version byte + return MessageChecksum(out) +} + +func (args ChecksumArgs) Access() Access { + return Access{ + BlockNumber: args.BlockNumber, + Timestamp: args.Timestamp, + LogIndex: args.LogIndex, + ChainID: args.ChainID, + Checksum: args.Checksum(), + } +} + +func (args ChecksumArgs) Query() ContainsQuery { + return ContainsQuery{ + BlockNum: args.BlockNumber, + Timestamp: args.Timestamp, + LogIdx: args.LogIndex, + Checksum: args.Checksum(), + } +} + type Identifier struct { Origin common.Address BlockNumber uint64 @@ -137,30 +190,6 @@ func (lvl *SafetyLevel) UnmarshalText(text []byte) error { return nil } -// AtLeastAsSafe returns true if the receiver is at least as safe as the other SafetyLevel. -// Safety levels are assumed to graduate from LocalUnsafe to LocalSafe to CrossUnsafe to CrossSafe, with Finalized as the strongest. -func (lvl *SafetyLevel) AtLeastAsSafe(min SafetyLevel) bool { - relativeSafety := map[SafetyLevel]int{ - Invalid: 0, - LocalUnsafe: 1, - LocalSafe: 2, - CrossUnsafe: 3, - CrossSafe: 4, - Finalized: 5, - } - // if either level is not recognized, return false - _, ok := relativeSafety[*lvl] - if !ok { - return false - } - _, ok = relativeSafety[min] - if !ok { - return false - } - // compare the relative safety levels to determine if the receiver is at least as safe as the other - return relativeSafety[*lvl] >= relativeSafety[min] -} - const ( // Finalized is CrossSafe, with the additional constraint that every // dependency is derived only from finalized L1 input data. @@ -187,15 +216,59 @@ const ( type ExecutingDescriptor struct { // Timestamp is the timestamp of the executing message Timestamp uint64 + + // Timeout, requests verification to still hold at Timestamp+Timeout (incl.). Defaults to 0. + // I.e. Timestamp is used as lower-bound validity, and Timeout defines the span to the upper-bound. + Timeout uint64 +} + +func (ed *ExecutingDescriptor) AccessCheck(expiryWindow uint64, initMsgTimestamp uint64) error { + // Check upper-bound invariant, strictly + // (for access-lists we don't afford to check intra-timestamp dependencies) + if ed.Timestamp < initMsgTimestamp { + return fmt.Errorf("message broke timestamp invariant: exec: %d, init: %d, %w", + ed.Timestamp, initMsgTimestamp, ErrConflict) + } + if ed.Timestamp == initMsgTimestamp { + return fmt.Errorf("access-list check does not allow intra-timestamp (%d): %w", ed.Timestamp, ErrConflict) + } + + // Check message expiry + expiryAt := initMsgTimestamp + expiryWindow + if expiryAt < initMsgTimestamp { + return fmt.Errorf("message timestamp too high, overflows: %d, %w", + initMsgTimestamp, ErrConflict) + } + if ed.Timestamp > expiryAt { + return fmt.Errorf("cannot message execute at %d, message expired at %d: %w", + ed.Timestamp, expiryAt, ErrConflict) + } + if ed.Timeout == 0 { + // If no timeout, then just checking the exact execution time was sufficient + return nil + } + + // If a timeout is set, check if executing late is still within the expiry window + if ed.Timestamp+ed.Timeout < ed.Timestamp { + return fmt.Errorf("message timeout too high, overflows: %d, %w", + ed.Timestamp, ErrConflict) + } + if v := ed.Timestamp + ed.Timeout; v > expiryAt { + return fmt.Errorf("cannot execute message at timeout %d, expired at %d: %w", + v, expiryAt, ErrConflict) + } + return nil } type executingDescriptorMarshaling struct { Timestamp hexutil.Uint64 `json:"timestamp"` + Timeout hexutil.Uint64 `json:"timeout,omitempty"` } func (ed ExecutingDescriptor) MarshalJSON() ([]byte, error) { var enc executingDescriptorMarshaling enc.Timestamp = hexutil.Uint64(ed.Timestamp) + enc.Timeout = hexutil.Uint64(ed.Timeout) return json.Marshal(&enc) } @@ -205,6 +278,7 @@ func (ed *ExecutingDescriptor) UnmarshalJSON(input []byte) error { return err } ed.Timestamp = uint64(dec.Timestamp) + ed.Timeout = uint64(dec.Timeout) return nil } @@ -348,3 +422,157 @@ type ManagedEvent struct { ReplaceBlock *BlockReplacement `json:"replaceBlock,omitempty"` DerivationOriginUpdate *eth.BlockRef `json:"derivationOriginUpdate,omitempty"` } + +// MessageChecksum represents a message checksum, as used for access-list checks. +type MessageChecksum common.Hash + +func (mc MessageChecksum) MarshalText() ([]byte, error) { + return common.Hash(mc).MarshalText() +} + +func (mc *MessageChecksum) UnmarshalText(data []byte) error { + return (*common.Hash)(mc).UnmarshalText(data) +} + +func (mc MessageChecksum) String() string { + return common.Hash(mc).String() +} + +// Access represents access to a message, parsed from an access-list +type Access struct { + BlockNumber uint64 + Timestamp uint64 + LogIndex uint32 + ChainID eth.ChainID + Checksum MessageChecksum +} + +// lookupEntry encodes a lookup entry for an access-list +func (acc Access) lookupEntry() common.Hash { + var out common.Hash + out[0] = PrefixLookup + binary.BigEndian.PutUint64(out[4:12], (*uint256.Int)(&acc.ChainID).Uint64()) + binary.BigEndian.PutUint64(out[12:20], acc.BlockNumber) + binary.BigEndian.PutUint64(out[20:28], acc.Timestamp) + binary.BigEndian.PutUint32(out[28:32], acc.LogIndex) + return out +} + +// chainIDExtensionEntry encodes a chainID-extension entry for an access-list +func (acc Access) chainIDExtensionEntry() common.Hash { + var out common.Hash + dat := (*uint256.Int)(&acc.ChainID).Bytes32() + out[0] = PrefixChainIDExtension + copy(out[8:32], dat[0:24]) + return out +} + +type accessMarshaling struct { + BlockNumber hexutil.Uint64 `json:"blockNumber"` + Timestamp hexutil.Uint64 `json:"timestamp"` + LogIndex uint32 `json:"logIndex"` + ChainID eth.ChainID `json:"chainID"` + Checksum MessageChecksum `json:"checksum"` +} + +func (a Access) MarshalJSON() ([]byte, error) { + enc := accessMarshaling{ + BlockNumber: hexutil.Uint64(a.BlockNumber), + Timestamp: hexutil.Uint64(a.Timestamp), + LogIndex: a.LogIndex, + ChainID: a.ChainID, + Checksum: a.Checksum, + } + return json.Marshal(&enc) +} + +func (a *Access) UnmarshalJSON(input []byte) error { + var dec accessMarshaling + if err := json.Unmarshal(input, &dec); err != nil { + return err + } + a.BlockNumber = uint64(dec.BlockNumber) + a.Timestamp = uint64(dec.Timestamp) + a.LogIndex = dec.LogIndex + a.ChainID = dec.ChainID + a.Checksum = dec.Checksum + return nil +} + +const ( + PrefixLookup = 1 + PrefixChainIDExtension = 2 + PrefixChecksum = 3 +) + +var ( + errExpectedEntry = errors.New("expected entry") + errMalformedEntry = errors.New("malformed entry") + errUnexpectedEntryType = errors.New("unexpected entry type") +) + +// ParseAccess parses some access-list entries into an Access, and returns the remaining entries. +// This process can be repeated until no entries are left, to parse an access-list. +func ParseAccess(entries []common.Hash) ([]common.Hash, Access, error) { + if len(entries) == 0 { + return nil, Access{}, errExpectedEntry + } + entry := entries[0] + entries = entries[1:] + if typeByte := entry[0]; typeByte != PrefixLookup { + return nil, Access{}, fmt.Errorf("expected lookup, got entry type %d: %w", + typeByte, errUnexpectedEntryType) + } + if ([3]byte)(entry[1:4]) != ([3]byte{}) { + return nil, Access{}, fmt.Errorf("expected zero bytes: %w", errMalformedEntry) + } + var access Access + access.ChainID = eth.ChainIDFromUInt64(binary.BigEndian.Uint64(entry[4:12])) + access.BlockNumber = binary.BigEndian.Uint64(entry[12:20]) + access.Timestamp = binary.BigEndian.Uint64(entry[20:28]) + access.LogIndex = binary.BigEndian.Uint32(entry[28:32]) + + if len(entries) == 0 { + return nil, Access{}, errExpectedEntry + } + entry = entries[0] + entries = entries[1:] + if typeByte := entry[0]; typeByte == PrefixChainIDExtension { + if ([7]byte)(entry[1:8]) != ([7]byte{}) { + return nil, Access{}, fmt.Errorf("expected zero bytes") + } + // The lower 8 bytes is set to the uint64 in the first entry. + // The upper 24 bytes are set with this extension entry. + chIDBytes32 := access.ChainID.Bytes32() + copy(chIDBytes32[0:24], entry[8:32]) + access.ChainID = eth.ChainIDFromBytes32(chIDBytes32) + if len(entries) == 0 { + return nil, Access{}, errExpectedEntry + } + entry = entries[0] + entries = entries[1:] + } + if typeByte := entry[0]; typeByte != PrefixChecksum { + return nil, Access{}, fmt.Errorf("expected checksum, got entry type %d: %w", + typeByte, errUnexpectedEntryType) + } + access.Checksum = MessageChecksum(entry) + return entries, access, nil +} + +func EncodeAccessList(accesses []Access) []common.Hash { + out := make([]common.Hash, 0, len(accesses)*2) + for _, acc := range accesses { + out = append(out, acc.lookupEntry()) + + if !(*uint256.Int)(&acc.ChainID).IsUint64() { + out = append(out, acc.chainIDExtensionEntry()) + } + + if acc.Checksum[0] != PrefixChecksum { + panic("invalid checksum entry") + } + out = append(out, common.Hash(acc.Checksum)) + } + return out +} diff --git a/op-supervisor/supervisor/types/types_test.go b/op-supervisor/supervisor/types/types_test.go index bc01fa63efa5f..3c9d018dbc461 100644 --- a/op-supervisor/supervisor/types/types_test.go +++ b/op-supervisor/supervisor/types/types_test.go @@ -1,14 +1,22 @@ package types import ( + "encoding/binary" "encoding/json" + "fmt" "math/big" + "math/rand" + "strings" "testing" - "github.com/ethereum-optimism/optimism/op-service/eth" "github.com/stretchr/testify/require" "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/common/hexutil" + ethTypes "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/crypto" + + "github.com/ethereum-optimism/optimism/op-service/eth" ) func FuzzRoundtripIdentifierJSONMarshal(f *testing.F) { @@ -38,3 +46,506 @@ func FuzzRoundtripIdentifierJSONMarshal(f *testing.F) { require.Equal(t, id.ChainID, dec.ChainID) }) } + +func TestChainIndex(t *testing.T) { + var x ChainIndex + require.NoError(t, json.Unmarshal([]byte(`"1"`), &x)) + require.Equal(t, ChainIndex(1), x) + data, err := json.Marshal(x) + require.NoError(t, err) + require.Equal(t, `"1"`, string(data)) + + require.NoError(t, json.Unmarshal([]byte(`"4294967295"`), &x)) + require.Equal(t, ChainIndex(0xff_ff_ff_ff), x) + data, err = json.Marshal(x) + require.NoError(t, err) + require.Equal(t, `"4294967295"`, string(data)) + + require.ErrorContains(t, json.Unmarshal([]byte(`"-1"`), &x), "invalid") + require.ErrorContains(t, json.Unmarshal([]byte(`"4294967296"`), &x), "out of range") +} + +func TestHashing(t *testing.T) { + keccak256 := func(name string, parts ...[]byte) (h common.Hash) { + t.Logf("%s = H(", name) + for _, p := range parts { + t.Logf(" %x,", p) + } + t.Logf(")") + h = crypto.Keccak256Hash(parts...) + t.Logf("%s = %s", name, h) + return h + } + id := Identifier{ + Origin: common.HexToAddress("0xe0e1e2e3e4e5e6e7e8e9f0f1f2f3f4f5f6f7f8f9"), + BlockNumber: 0xa1a2_a3a4_a5a6_a7a8, + LogIndex: 0xb1b2_b3b4, + Timestamp: 0xc1c2_c3c4_c5c6_c7c8, + ChainID: eth.ChainIDFromUInt64(0xd1d2_d3d4_d5d6_d7d8), + } + payloadHash := keccak256("payloadHash", []byte("example payload")) // aka msgHash + logHash := keccak256("logHash", id.Origin[:], payloadHash[:]) + x := PayloadHashToLogHash(payloadHash, id.Origin) + require.Equal(t, logHash, x, "check op-supervisor version of log-hashing matches intermediate value") + + var idPacked []byte + idPacked = append(idPacked, make([]byte, 12)...) + idPacked = binary.BigEndian.AppendUint64(idPacked, id.BlockNumber) + idPacked = binary.BigEndian.AppendUint64(idPacked, id.Timestamp) + idPacked = binary.BigEndian.AppendUint32(idPacked, id.LogIndex) + t.Logf("idPacked: %x", idPacked) + + idLogHash := keccak256("idLogHash", logHash[:], idPacked) + chainID := id.ChainID.Bytes32() + bareChecksum := keccak256("bareChecksum", idLogHash[:], chainID[:]) + + checksum := bareChecksum + checksum[0] = 0x03 + t.Logf("Checksum: %s", checksum) +} + +var ( + testOrigin = common.HexToAddress("0xe0e1e2e3e4e5e6e7e8e9f0f1f2f3f4f5f6f7f8f9") + testBlockNumber = uint64(0xa1a2_a3a4_a5a6_a7a8) + testLogIndex = uint32(0xb1b2_b3b4) + testTimestamp = uint64(0xc1c2_c3c4_c5c6_c7c8) + testChainID = eth.ChainIDFromUInt64(0xd1d2_d3d4_d5d6_d7d8) + testPayload = []byte("example payload") + testMsgHash = common.HexToHash("0x8017559a85b12c04b14a1a425d53486d1015f833714a09bd62f04152a7e2ae9b") + testLogHash = common.HexToHash("0xf9ed05990c887d3f86718aabd7e940faaa75d6a5cd44602e89642586ce85f2aa") + testChecksum = MessageChecksum(common.HexToHash("0x03749e87fd7789575de9906569deb05aaf220dc4cfab3d8abbfd34a2e1d7d357")) + testLookupEntry = common.HexToHash("0x01000000d1d2d3d4d5d6d7d8a1a2a3a4a5a6a7a8c1c2c3c4c5c6c7c8b1b2b3b4") +) + +func TestMessage(t *testing.T) { + msg := Message{ + Identifier: Identifier{ + Origin: testOrigin, + BlockNumber: testBlockNumber, + LogIndex: testLogIndex, + Timestamp: testTimestamp, + ChainID: testChainID, + }, + PayloadHash: testMsgHash, + } + t.Run("checksum", func(t *testing.T) { + require.Equal(t, testChecksum, msg.Checksum()) + }) + t.Run("json roundtrip", func(t *testing.T) { + data, err := json.Marshal(msg) + require.NoError(t, err) + var out Message + require.NoError(t, json.Unmarshal(data, &out)) + require.Equal(t, msg, out) + }) +} + +func TestChecksumArgs(t *testing.T) { + args := ChecksumArgs{ + BlockNumber: testBlockNumber, + LogIndex: testLogIndex, + Timestamp: testTimestamp, + ChainID: testChainID, + LogHash: testLogHash, + } + t.Run("checksum", func(t *testing.T) { + require.Equal(t, testChecksum, args.Checksum()) + }) + t.Run("as query", func(t *testing.T) { + q := args.Query() + require.Equal(t, testBlockNumber, q.BlockNum) + require.Equal(t, testTimestamp, q.Timestamp) + require.Equal(t, testLogIndex, q.LogIdx) + require.Equal(t, testChecksum, q.Checksum) + }) + t.Run("as access", func(t *testing.T) { + acc := args.Access() + require.Equal(t, testBlockNumber, acc.BlockNumber) + require.Equal(t, testTimestamp, acc.Timestamp) + require.Equal(t, testLogIndex, acc.LogIndex) + require.Equal(t, testChainID, acc.ChainID) + require.Equal(t, testChecksum, acc.Checksum) + }) +} + +func TestIdentifier(t *testing.T) { + id := Identifier{ + Origin: testOrigin, + BlockNumber: testBlockNumber, + LogIndex: testLogIndex, + Timestamp: testTimestamp, + ChainID: testChainID, + } + t.Run("json roundtrip", func(t *testing.T) { + data, err := json.Marshal(id) + require.NoError(t, err) + var out Identifier + require.NoError(t, json.Unmarshal(data, &out)) + require.Equal(t, id, out) + }) +} + +func TestSafetyLevel(t *testing.T) { + for _, lvl := range []SafetyLevel{ + Finalized, + CrossSafe, + LocalSafe, + CrossUnsafe, + LocalUnsafe, + Invalid, + } { + upper := strings.ToUpper(lvl.String()) + var x SafetyLevel + require.ErrorContains(t, json.Unmarshal([]byte(fmt.Sprintf("%q", upper)), &x), "unrecognized", "case sensitive") + require.NoError(t, json.Unmarshal([]byte(fmt.Sprintf("%q", lvl.String())), &x)) + dat, err := json.Marshal(x) + require.NoError(t, err) + require.Equal(t, fmt.Sprintf("%q", lvl.String()), string(dat)) + } + var x SafetyLevel + require.ErrorContains(t, json.Unmarshal([]byte(`""`), &x), "unrecognized", "empty") + require.ErrorContains(t, json.Unmarshal([]byte(`"foobar"`), &x), "unrecognized", "other") +} + +type execDescrTestCase struct { + name string + ed ExecutingDescriptor + expiryWindow uint64 + initMsgTimestamp uint64 + errStr string // empty if no error +} + +func TestExecutingDescriptorAccessCheck(t *testing.T) { + testCases := []execDescrTestCase{ + { + name: "success", + ed: ExecutingDescriptor{ + Timestamp: 3, + Timeout: 0, + }, + expiryWindow: 10, + initMsgTimestamp: 2, + }, + { + name: "future exec", + ed: ExecutingDescriptor{ + Timestamp: 3, + Timeout: 0, + }, + expiryWindow: 10, + initMsgTimestamp: 4, + errStr: "broke timestamp invariant", + }, + { + name: "access-list checks are extra strict, don't allow intra-timestamp", + ed: ExecutingDescriptor{ + Timestamp: 3, + Timeout: 0, + }, + expiryWindow: 10, + initMsgTimestamp: 3, + errStr: "not allow intra-timestamp", + }, + { + name: "attempt init-msg timestamp overflow", + ed: ExecutingDescriptor{ + Timestamp: (^uint64(0)) - 2, + Timeout: 0, + }, + expiryWindow: 10, + initMsgTimestamp: (^uint64(0)) - 3, + errStr: "overflow", + }, + { + name: "expired", + ed: ExecutingDescriptor{ + Timestamp: 100, + Timeout: 0, + }, + expiryWindow: 10, + initMsgTimestamp: 89, + errStr: "expired", + }, + { + name: "timeout overflow", + ed: ExecutingDescriptor{ + Timestamp: 100, + Timeout: (^uint64(0)) - 3, + }, + expiryWindow: 10, + initMsgTimestamp: 99, + errStr: "overflow", + }, + { + name: "timeout, valid at exec timestamp, but not shortly after", + ed: ExecutingDescriptor{ + Timestamp: 100, + Timeout: 10, //timeout asks for 100+10=110 + }, + expiryWindow: 10, // valid till 95+10 = 105 + initMsgTimestamp: 95, + errStr: "timeout", + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := tc.ed.AccessCheck(tc.expiryWindow, tc.initMsgTimestamp) + if tc.errStr == "" { + require.NoError(t, err) + } else { + require.ErrorContains(t, err, tc.errStr) + } + }) + } +} + +func TestPayloadHashToLogHash(t *testing.T) { + logHash := PayloadHashToLogHash(testMsgHash, testOrigin) + require.Equal(t, testLogHash, logHash) +} + +func TestLogToMessagePayload(t *testing.T) { + payload := LogToMessagePayload(ðTypes.Log{ + Data: testPayload, + }) + require.Equal(t, hexutil.Encode(testPayload), hexutil.Encode(payload)) + + t.Run("1 topic", func(t *testing.T) { + v := LogToMessagePayload(ðTypes.Log{ + Data: []byte(`foobar`), + Topics: []common.Hash{ + crypto.Keccak256Hash([]byte(`topic0`)), + }, + }) + expected := make([]byte, 0) + expected = append(expected, crypto.Keccak256([]byte(`topic0`))...) + expected = append(expected, []byte(`foobar`)...) + require.Equal(t, expected, v) + }) + + t.Run("4 topics", func(t *testing.T) { + v := LogToMessagePayload(ðTypes.Log{ + Data: []byte(`foobar`), + Topics: []common.Hash{ + crypto.Keccak256Hash([]byte(`topic0`)), + crypto.Keccak256Hash([]byte(`topic1`)), + crypto.Keccak256Hash([]byte(`topic2`)), + crypto.Keccak256Hash([]byte(`topic3`)), + }, + }) + expected := make([]byte, 0) + expected = append(expected, crypto.Keccak256([]byte(`topic0`))...) + expected = append(expected, crypto.Keccak256([]byte(`topic1`))...) + expected = append(expected, crypto.Keccak256([]byte(`topic2`))...) + expected = append(expected, crypto.Keccak256([]byte(`topic3`))...) + expected = append(expected, []byte(`foobar`)...) + require.Equal(t, expected, v) + }) +} + +func TestAccess(t *testing.T) { + acc := Access{ + BlockNumber: testBlockNumber, + Timestamp: testTimestamp, + LogIndex: testLogIndex, + ChainID: testChainID, + Checksum: MessageChecksum(testChecksum), + } + t.Run("json roundtrip", func(t *testing.T) { + data, err := json.Marshal(acc) + require.NoError(t, err) + var out Access + require.NoError(t, json.Unmarshal(data, &out)) + require.Equal(t, acc, out) + }) +} + +func TestParseAccess(t *testing.T) { + t.Run("empty", func(t *testing.T) { + _, _, err := ParseAccess(nil) + require.ErrorIs(t, err, errExpectedEntry) + }) + t.Run("unexpected 0 type", func(t *testing.T) { + _, _, err := ParseAccess([]common.Hash{ + {0: 0x00}, + }) + require.ErrorIs(t, err, errUnexpectedEntryType) + require.ErrorContains(t, err, "expected lookup") + }) + t.Run("unexpected arbitrary type", func(t *testing.T) { + _, _, err := ParseAccess([]common.Hash{ + {0: 10}, + }) + require.ErrorIs(t, err, errUnexpectedEntryType) + require.ErrorContains(t, err, "expected lookup") + }) + t.Run("unexpected non-zero padding", func(t *testing.T) { + _, _, err := ParseAccess([]common.Hash{ + {0: PrefixLookup, 1: 0x01}, // valid lookup prefix byte, but non-zero value in padding area + }) + require.ErrorIs(t, err, errMalformedEntry) + require.ErrorContains(t, err, "expected zero bytes") + }) + t.Run("incomplete", func(t *testing.T) { + _, _, err := ParseAccess([]common.Hash{ + {0: PrefixLookup}, // valid lookup, but no checksum after + }) + require.ErrorIs(t, err, errExpectedEntry) + }) + t.Run("unexpected 0 type after checksum", func(t *testing.T) { + _, _, err := ParseAccess([]common.Hash{ + {0: PrefixLookup}, + {0: 0}, + }) + require.ErrorIs(t, err, errUnexpectedEntryType) + }) + t.Run("unexpected lookup repeat", func(t *testing.T) { + _, _, err := ParseAccess([]common.Hash{ + {0: PrefixLookup}, + {0: PrefixLookup}, + }) + require.ErrorIs(t, err, errUnexpectedEntryType) + }) + t.Run("unexpected arbitrary type after checksum", func(t *testing.T) { + _, _, err := ParseAccess([]common.Hash{ + {0: PrefixLookup}, + {0: 10}, // unexpected type byte + }) + require.ErrorIs(t, err, errUnexpectedEntryType) + }) + t.Run("valid but zero", func(t *testing.T) { + remaining, acc, err := ParseAccess([]common.Hash{ + {0: PrefixLookup}, // valid lookup entry + {0: PrefixChecksum}, // valid checksum entry + }) + require.NoError(t, err) + require.Equal(t, Access{ + BlockNumber: 0, + Timestamp: 0, + LogIndex: 0, + ChainID: eth.ChainID{}, + Checksum: MessageChecksum{0: PrefixChecksum}, + }, acc) + require.Empty(t, remaining) + }) + t.Run("valid", func(t *testing.T) { + acc := Access{ + BlockNumber: testBlockNumber, + Timestamp: testTimestamp, + LogIndex: testLogIndex, + ChainID: testChainID, + Checksum: MessageChecksum(testChecksum), + } + remaining, parsed, err := ParseAccess([]common.Hash{ + testLookupEntry, + common.Hash(acc.Checksum), + }) + require.NoError(t, err) + require.Equal(t, acc, parsed) + require.Empty(t, remaining) + }) + t.Run("repeat", func(t *testing.T) { + acc := Access{ + BlockNumber: testBlockNumber, + Timestamp: testTimestamp, + LogIndex: testLogIndex, + ChainID: testChainID, + Checksum: MessageChecksum(testChecksum), + } + remaining, parsed, err := ParseAccess([]common.Hash{ + testLookupEntry, + common.Hash(acc.Checksum), + testLookupEntry, + common.Hash(acc.Checksum), + }) + require.NoError(t, err) + require.Equal(t, acc, parsed) + require.Len(t, remaining, 2) + remaining2, parsed2, err := ParseAccess(remaining) + require.NoError(t, err) + require.Equal(t, acc, parsed2) + require.Empty(t, remaining2) + }) + t.Run("with chainID extension", func(t *testing.T) { + acc := Access{ + BlockNumber: testBlockNumber, + Timestamp: testTimestamp, + LogIndex: testLogIndex, + ChainID: eth.ChainIDFromBytes32([32]byte{0: 7, 31: 10}), + Checksum: MessageChecksum(testChecksum), + } + remaining, parsed, err := ParseAccess([]common.Hash{ + acc.lookupEntry(), + acc.chainIDExtensionEntry(), + common.Hash(acc.Checksum), + }) + require.NoError(t, err) + require.Equal(t, acc, parsed) + require.Empty(t, remaining) + }) +} + +func TestEncodeAccessList(t *testing.T) { + acc := Access{ + BlockNumber: testBlockNumber, + Timestamp: testTimestamp, + LogIndex: testLogIndex, + ChainID: testChainID, + Checksum: MessageChecksum(testChecksum), + } + t.Run("valid single", func(t *testing.T) { + accList := EncodeAccessList([]Access{acc}) + require.Len(t, accList, 2) + require.Equal(t, testLookupEntry, accList[0]) + require.Equal(t, common.Hash(testChecksum), accList[1]) + _, result, err := ParseAccess(accList) + require.NoError(t, err) + require.Equal(t, acc, result, "roundtrip") + }) + t.Run("valid repeat", func(t *testing.T) { + accList := EncodeAccessList([]Access{ + acc, + acc, + }) + require.Len(t, accList, 4) + require.Equal(t, testLookupEntry, accList[0]) + require.Equal(t, common.Hash(testChecksum), accList[1]) + require.Equal(t, testLookupEntry, accList[2]) + require.Equal(t, common.Hash(testChecksum), accList[3]) + }) + t.Run("roundtrip", func(t *testing.T) { + accObjects := make([]Access, 0) + rng := rand.New(rand.NewSource(1234)) + randB32 := func() (out [32]byte) { + rng.Read(out[:]) + return + } + // test a big random access-list + count := 200 + for i := 0; i < count; i++ { + chainID := eth.ChainIDFromBytes32(randB32()) + if rng.Intn(5) < 2 { // don't make them all full random bytes32 + chainID = eth.ChainIDFromUInt64(rng.Uint64()) + } + checksum := randB32() + checksum[0] = PrefixChecksum + accObjects = append(accObjects, Access{ + BlockNumber: rng.Uint64(), + Timestamp: rng.Uint64(), + LogIndex: rng.Uint32(), + ChainID: chainID, + Checksum: checksum, + }) + } + list := EncodeAccessList(accObjects) + var result []Access + for i := 0; i < count && len(list) > 0; i++ { + remaining, v, err := ParseAccess(list) + require.NoError(t, err) + result = append(result, v) + list = remaining + } + require.Empty(t, list, "need to exhaust entries, expecting to be done") + require.Equal(t, accObjects, result, "roundtrip of random entries should work") + }) +}