diff --git a/.github/workflows/docker_release.yml b/.github/workflows/docker_release.yml index 54c8e8abc..fb3b062be 100644 --- a/.github/workflows/docker_release.yml +++ b/.github/workflows/docker_release.yml @@ -4,9 +4,14 @@ on: push: tags: - morph-v* + workflow_dispatch: + inputs: + tag: + description: 'Tag name to build (e.g. morph-v2.2.0)' + required: true env: - IMAGE_NAME: go-ethereum + IMAGE_NAME: ghcr.io/${{ github.repository }} jobs: push: @@ -14,28 +19,43 @@ jobs: steps: - uses: actions/checkout@v4 + with: + ref: ${{ inputs.tag || github.ref }} - - name: Log into registry - run: echo "${{ secrets.PACKAGE_TOKEN }}" | docker login ghcr.io -u ${{ github.actor }} --password-stdin + - name: Set up QEMU + uses: docker/setup-qemu-action@v3 - - name: Build the Docker image - run: | - docker build . --file Dockerfile \ - --build-arg COMMIT="${{ github.sha }}" \ - --build-arg VERSION="${{ github.ref_name }}" \ - -t "${IMAGE_NAME}" + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 - - name: Push image + - name: Log into registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.PACKAGE_TOKEN }} + + - name: Extract version and commit + id: meta run: | - IMAGE_ID="ghcr.io/${{ github.repository }}" - - # Change all uppercase to lowercase - IMAGE_ID=$(echo "$IMAGE_ID" | tr '[A-Z]' '[a-z]') - # Strip "morph-v" prefix from tag name - VERSION=$(echo "${{ github.ref_name }}" | sed -e 's/^morph-v//') - echo IMAGE_ID="$IMAGE_ID" - echo VERSION="$VERSION" - docker tag "$IMAGE_NAME" "$IMAGE_ID:$VERSION" - docker tag "$IMAGE_NAME" "$IMAGE_ID:latest" - docker push "$IMAGE_ID:$VERSION" - docker push "$IMAGE_ID:latest" + TAG="${{ inputs.tag || github.ref_name }}" + VERSION=$(echo "$TAG" | sed -e 's/^morph-v//') + COMMIT=$(git rev-parse HEAD) + echo "version=${VERSION}" >> "$GITHUB_OUTPUT" + echo "commit=${COMMIT}" >> "$GITHUB_OUTPUT" + + - name: Build and push + uses: docker/build-push-action@v6 + with: + context: . + file: Dockerfile + platforms: linux/amd64,linux/arm64 + push: true + tags: | + ${{ env.IMAGE_NAME }}:${{ steps.meta.outputs.version }} + ${{ env.IMAGE_NAME }}:latest + build-args: | + COMMIT=${{ steps.meta.outputs.commit }} + VERSION=${{ steps.meta.outputs.version }} + cache-from: type=gha + cache-to: type=gha,mode=max diff --git a/core/state/pruner/pruner.go b/core/state/pruner/pruner.go index 2ce5935c5..e67fa4caf 100644 --- a/core/state/pruner/pruner.go +++ b/core/state/pruner/pruner.go @@ -63,6 +63,26 @@ var ( emptyKeccakCodeHash = codehash.EmptyKeccakCodeHash.Bytes() ) +// teeWriter writes to both the real database and the state bloom filter. +// This ensures GenerateTrie persists trie nodes to disk (filling gaps from +// unclean shutdowns) while also recording their hashes in the bloom filter +// for pruning protection. +type teeWriter struct { + db ethdb.KeyValueWriter + bloom *stateBloom +} + +func (w *teeWriter) Put(key, value []byte) error { + if err := w.db.Put(key, value); err != nil { + return err + } + return w.bloom.Put(key, value) +} + +func (w *teeWriter) Delete(key []byte) error { + return w.db.Delete(key) +} + // Pruner is an offline tool to prune the stale state with the // help of the snapshot. The workflow of pruner is very simple: // @@ -230,7 +250,8 @@ func prune(snaptree *snapshot.Tree, root common.Hash, maindb ethdb.Database, sta // Prune deletes all historical state nodes except the nodes belong to the // specified state version. If user doesn't specify the state version, use -// the bottom-most snapshot diff layer as the target. +// the HEAD state as the target so the node restarts at exactly the same +// height it stopped at (zero block rewind). func (p *Pruner) Prune(root common.Hash) error { // If the state bloom filter is already committed previously, // reuse it for pruning instead of generating a new one. It's @@ -243,61 +264,16 @@ func (p *Pruner) Prune(root common.Hash) error { if stateBloomRoot != (common.Hash{}) { return RecoverPruning(p.datadir, p.db, p.trieCachePath) } - // If the target state root is not specified, use the HEAD-127 as the - // target. The reason for picking it is: - // - in most of the normal cases, the related state is available - // - the probability of this layer being reorg is very low - var layers []snapshot.Snapshot if root == (common.Hash{}) { - // Retrieve all snapshot layers from the current HEAD. - // In theory there are 128 difflayers + 1 disk layer present, - // so 128 diff layers are expected to be returned. - layers = p.snaptree.Snapshots(p.headHeader.Root, 128, true) - if len(layers) != 128 { - // Reject if the accumulated diff layers are less than 128. It - // means in most of normal cases, there is no associated state - // with bottom-most diff layer. - return fmt.Errorf("snapshot not old enough yet: need %d more blocks", 128-len(layers)) - } - // Use the bottom-most diff layer as the target - root = layers[len(layers)-1].Root() - } - // Ensure the root is really present. The weak assumption - // is the presence of root can indicate the presence of the - // entire trie. - if blob := rawdb.ReadTrieNode(p.db, root); len(blob) == 0 { - // The special case is for clique based networks(rinkeby, goerli - // and some other private networks), it's possible that two - // consecutive blocks will have same root. In this case snapshot - // difflayer won't be created. So HEAD-127 may not paired with - // head-127 layer. Instead the paired layer is higher than the - // bottom-most diff layer. Try to find the bottom-most snapshot - // layer with state available. - // - // Note HEAD and HEAD-1 is ignored. Usually there is the associated - // state available, but we don't want to use the topmost state - // as the pruning target. - var found bool - for i := len(layers) - 2; i >= 2; i-- { - if blob := rawdb.ReadTrieNode(p.db, layers[i].Root()); len(blob) != 0 { - root = layers[i].Root() - found = true - log.Info("Selecting middle-layer as the pruning target", "root", root, "depth", i) - break - } - } - if !found { - if len(layers) > 0 { - return errors.New("no snapshot paired state") - } - return fmt.Errorf("associated state[%x] is not present", root) - } + // Use HEAD as the pruning target. On L2 chains reorgs are + // essentially non-existent, so the HEAD-127 safety margin + // from L1 is unnecessary. Combined with the teeWriter that + // persists the trie from snapshot, this guarantees zero + // height drop after pruning regardless of shutdown method. + root = p.headHeader.Root + log.Info("Selecting HEAD as the pruning target", "root", root, "height", p.headHeader.Number.Uint64()) } else { - if len(layers) > 0 { - log.Info("Selecting bottom-most difflayer as the pruning target", "root", root, "height", p.headHeader.Number.Uint64()-127) - } else { - log.Info("Selecting user-specified state as the pruning target", "root", root) - } + log.Info("Selecting user-specified state as the pruning target", "root", root) } // Before start the pruning, delete the clean trie cache first. // It's necessary otherwise in the next restart we will hit the @@ -305,19 +281,12 @@ func (p *Pruner) Prune(root common.Hash) error { // state is picked for usage. deleteCleanTrieCache(p.trieCachePath) - // All the state roots of the middle layer should be forcibly pruned, - // otherwise the dangling state will be left. - middleRoots := make(map[common.Hash]struct{}) - for _, layer := range layers { - if layer.Root() == root { - break - } - middleRoots[layer.Root()] = struct{}{} - } // Traverse the target state, re-construct the whole state trie and - // commit to the given bloom filter. + // commit to the given bloom filter. The teeWriter ensures trie nodes + // are also persisted to disk, filling any gaps from unclean shutdowns. start := time.Now() - if err := snapshot.GenerateTrie(p.snaptree, root, p.db, p.stateBloom); err != nil { + writer := &teeWriter{db: p.db, bloom: p.stateBloom} + if err := snapshot.GenerateTrie(p.snaptree, root, p.db, writer); err != nil { return err } // Traverse the genesis, put all genesis state entries into the @@ -332,7 +301,7 @@ func (p *Pruner) Prune(root common.Hash) error { return err } log.Info("State bloom filter committed", "name", filterName) - return prune(p.snaptree, root, p.db, p.stateBloom, filterName, middleRoots, start) + return prune(p.snaptree, root, p.db, p.stateBloom, filterName, nil, start) } // RecoverPruning will resume the pruning procedure during the system restart. @@ -410,7 +379,19 @@ func extractGenesis(db ethdb.Database, stateBloom *stateBloom) error { if genesis == nil { return errors.New("missing genesis block") } - t, err := trie.NewSecure(genesis.Root(), trie.NewDatabase(db)) + // genesis.Root() may be a zkTrie root (overridden via GenesisStateRoot + // for block hash compatibility). Resolve it to the actual MPT disk root + // so trie.NewSecure can open the trie. + genesisRoot := genesis.Root() + mptRoot, err := rawdb.ReadDiskStateRoot(db, genesisRoot) + if err != nil { + return fmt.Errorf("failed to read disk state root mapping for genesis root %x: %w", genesisRoot, err) + } + if mptRoot == (common.Hash{}) { + return fmt.Errorf("empty disk state root mapping for genesis root %x", genesisRoot) + } + genesisRoot = mptRoot + t, err := trie.NewSecure(genesisRoot, trie.NewDatabase(db)) if err != nil { return err } diff --git a/core/token_gas.go b/core/token_gas.go index 25c8cb48f..a2e814373 100644 --- a/core/token_gas.go +++ b/core/token_gas.go @@ -19,6 +19,25 @@ var ( maxGas uint64 = 200000 ) +func startSystemCallTrace(evm *vm.EVM) func() { + if evm == nil || evm.Config.Tracer == nil { + return nil + } + tracer := evm.Config.Tracer + if tracer.OnSystemCallEnd == nil { + return nil + } + if tracer.OnSystemCallStartV2 != nil { + tracer.OnSystemCallStartV2(evm.GetVMContext()) + return tracer.OnSystemCallEnd + } + if tracer.OnSystemCallStart != nil { + tracer.OnSystemCallStart() + return tracer.OnSystemCallEnd + } + return nil +} + // GetAltTokenBalanceHybrid returns the balance of an alt token using either storage slot or call method // If balanceSlot is zero hash, uses call method; otherwise uses storage slot method func (st *StateTransition) GetAltTokenBalanceHybrid(tokenID uint16, user common.Address) (*fees.TokenInfo, *big.Int, error) { @@ -26,7 +45,7 @@ func (st *StateTransition) GetAltTokenBalanceHybrid(tokenID uint16, user common. if err != nil { return nil, nil, err } - balance := new(big.Int) + var balance *big.Int if !info.HasSlot { balance, err = GetAltTokenBalanceByEVM(st.evm, info.TokenAddress, user) if err != nil { @@ -61,7 +80,7 @@ func GetAltTokenBalance(evm *vm.EVM, tokenID uint16, user common.Address) (*big. if err != nil { return nil, fmt.Errorf("failed to get token address for token ID %d: %v", tokenID, err) } - balance := new(big.Int) + var balance *big.Int if info.HasSlot { // balance slot exist balance, _, err = fees.GetAltTokenBalanceFromSlot(evm.StateDB, info.TokenAddress, user, info.BalanceSlot) @@ -89,9 +108,8 @@ func GetAltTokenBalanceByEVM(evm *vm.EVM, tokenAddress, userAddress common.Addre // Create a message call context sender := vm.AccountRef(userAddress) - if evm.Config.Tracer != nil && evm.Config.Tracer.OnSystemCallStartV2 != nil && evm.Config.Tracer.OnSystemCallEnd != nil { - evm.Config.Tracer.OnSystemCallStartV2(evm.GetVMContext()) - defer evm.Config.Tracer.OnSystemCallEnd() + if endTrace := startSystemCallTrace(evm); endTrace != nil { + defer endTrace() } // Execute the call (using StaticCall since we're only reading state) @@ -142,13 +160,14 @@ func transferAltTokenByEVM(evm *vm.EVM, tokenAddress, from, to common.Address, a // Create a message call context sender := vm.AccountRef(from) - if evm.Config.Tracer != nil && evm.Config.Tracer.OnSystemCallStartV2 != nil && evm.Config.Tracer.OnSystemCallEnd != nil { - evm.Config.Tracer.OnSystemCallStartV2(evm.GetVMContext()) - defer evm.Config.Tracer.OnSystemCallEnd() - } - // Execute the call - ret, _, err := evm.Call(sender, tokenAddress, data, maxGas, big.NewInt(0)) + var ret []byte + func() { + if endTrace := startSystemCallTrace(evm); endTrace != nil { + defer endTrace() + } + ret, _, err = evm.Call(sender, tokenAddress, data, maxGas, big.NewInt(0)) + }() if err != nil { return fmt.Errorf("alt token transfer call failed: %v", err) } diff --git a/core/token_gas_test.go b/core/token_gas_test.go new file mode 100644 index 000000000..e2ace0607 --- /dev/null +++ b/core/token_gas_test.go @@ -0,0 +1,88 @@ +package core + +import ( + "math/big" + "testing" + + "github.com/morph-l2/go-ethereum/core/tracing" + "github.com/morph-l2/go-ethereum/core/vm" +) + +func TestStartSystemCallTraceFallsBackToLegacyHook(t *testing.T) { + t.Parallel() + + var starts, ends int + evm := &vm.EVM{ + Config: vm.Config{ + Tracer: &tracing.Hooks{ + OnSystemCallStart: func() { starts++ }, + OnSystemCallEnd: func() { ends++ }, + }, + }, + } + + end := startSystemCallTrace(evm) + if starts != 1 { + t.Fatalf("unexpected legacy start count: %d", starts) + } + if end == nil { + t.Fatal("expected end hook") + } + end() + if ends != 1 { + t.Fatalf("unexpected end count: %d", ends) + } +} + +func TestStartSystemCallTracePrefersV2Hook(t *testing.T) { + t.Parallel() + + var legacyStarts, v2Starts int + evm := &vm.EVM{ + Context: vm.BlockContext{ + Time: big.NewInt(0), + BlockNumber: big.NewInt(0), + }, + Config: vm.Config{ + Tracer: &tracing.Hooks{ + OnSystemCallStart: func() { legacyStarts++ }, + OnSystemCallStartV2: func(*tracing.VMContext) { v2Starts++ }, + OnSystemCallEnd: func() {}, + }, + }, + } + + end := startSystemCallTrace(evm) + if end == nil { + t.Fatal("expected end hook") + } + if legacyStarts != 0 || v2Starts != 1 { + t.Fatalf("unexpected hook counts: legacy=%d v2=%d", legacyStarts, v2Starts) + } +} + +func TestStartSystemCallTraceRequiresEndHook(t *testing.T) { + t.Parallel() + + var legacyStarts, v2Starts int + evm := &vm.EVM{ + Context: vm.BlockContext{ + Time: big.NewInt(0), + BlockNumber: big.NewInt(0), + }, + Config: vm.Config{ + Tracer: &tracing.Hooks{ + OnSystemCallStart: func() { legacyStarts++ }, + OnSystemCallStartV2: func(*tracing.VMContext) { v2Starts++ }, + }, + }, + } + + end := startSystemCallTrace(evm) + if end != nil { + t.Fatal("expected nil end hook") + } + if legacyStarts != 0 || v2Starts != 0 { + t.Fatalf("unexpected hook counts without end hook: legacy=%d v2=%d", legacyStarts, v2Starts) + } +} diff --git a/core/types/morph_tx.go b/core/types/morph_tx.go index d0e7a906b..26795e8fc 100644 --- a/core/types/morph_tx.go +++ b/core/types/morph_tx.go @@ -190,6 +190,39 @@ func (tx *MorphTx) EncodeRLP(w io.Writer) error { return err } +// DecodeRLP implements rlp.Decoder so that direct rlp.Decode calls use the +// version-aware decode logic instead of reflection-based struct decoding. +// Without this, the field order mismatch between MorphTx (which has Version +// before FeeTokenID) and the v0 wire format (which lacks Version) causes +// decode failures. +func (tx *MorphTx) DecodeRLP(s *rlp.Stream) error { + kind, _, err := s.Kind() + if err != nil { + return err + } + if kind == rlp.List { + // V0 format: data is a single RLP list + raw, err := s.Raw() + if err != nil { + return err + } + return decodeV0MorphTxRLP(tx, raw) + } + // V1+ format: version byte followed by RLP list + versionByte, err := s.Uint8() + if err != nil { + return err + } + if versionByte != MorphTxVersion1 { + return errors.New("unsupported morph tx version: " + strconv.Itoa(int(versionByte))) + } + raw, err := s.Raw() + if err != nil { + return err + } + return decodeV1MorphTxRLP(tx, raw) +} + func (tx *MorphTx) encode(b *bytes.Buffer) error { switch tx.Version { case MorphTxVersion0: diff --git a/core/types/morph_tx_compat_test.go b/core/types/morph_tx_compat_test.go deleted file mode 100644 index 464fcc96b..000000000 --- a/core/types/morph_tx_compat_test.go +++ /dev/null @@ -1,593 +0,0 @@ -package types - -import ( - "bytes" - "encoding/hex" - "math/big" - "testing" - - "github.com/morph-l2/go-ethereum/common" - "github.com/morph-l2/go-ethereum/crypto" - "github.com/morph-l2/go-ethereum/rlp" -) - -// TestMorphTxV0BackwardCompatibility tests that old AltFeeTx encoded data -// can be correctly decoded by the new MorphTx decoder. -// These hex values were generated from the original AltFeeTx implementation. -func TestMorphTxV0BackwardCompatibility(t *testing.T) { - // Expected values from the original encoding - expectedTo := common.HexToAddress("0x1234567890123456789012345678901234567890") - expectedChainID := big.NewInt(2818) - expectedNonce := uint64(1) - expectedGasTipCap := big.NewInt(1000000000) - expectedGasFeeCap := big.NewInt(2000000000) - expectedGas := uint64(21000) - expectedValue := big.NewInt(1000000000000000000) // 1 ETH - expectedR, _ := new(big.Int).SetString("abcdef1234567890abcdef1234567890abcdef1234567890abcdef1234567890", 16) - expectedS, _ := new(big.Int).SetString("1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef", 16) - - testCases := []struct { - name string - fullHex string // Full hex including 0x7F prefix - feeTokenID uint16 - feeLimit *big.Int - }{ - { - // Case 1: FeeLimit has value (0.5 ETH = 500000000000000000) - name: "V0 with FeeLimit value", - fullHex: "7ff87e820b0201843b9aca008477359400825208941234567890123456789012345678901234567890880de0b6b3a764000080c0018806f05b59d3b2000001a0abcdef1234567890abcdef1234567890abcdef1234567890abcdef1234567890a01234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef", - feeTokenID: 1, - feeLimit: big.NewInt(500000000000000000), - }, - { - // Case 2: FeeLimit is nil (encoded as 0x80) - name: "V0 with nil FeeLimit", - fullHex: "7ff876820b0201843b9aca008477359400825208941234567890123456789012345678901234567890880de0b6b3a764000080c0018001a0abcdef1234567890abcdef1234567890abcdef1234567890abcdef1234567890a01234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef", - feeTokenID: 1, - feeLimit: nil, - }, - { - // Case 3: FeeLimit is 0 (also encoded as 0x80) - name: "V0 with zero FeeLimit", - fullHex: "7ff876820b0201843b9aca008477359400825208941234567890123456789012345678901234567890880de0b6b3a764000080c0018001a0abcdef1234567890abcdef1234567890abcdef1234567890abcdef1234567890a01234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef", - feeTokenID: 1, - feeLimit: nil, // 0 is encoded as empty, decoded as nil - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - data, err := hex.DecodeString(tc.fullHex) - if err != nil { - t.Fatalf("failed to decode hex: %v", err) - } - - // Verify first byte is MorphTxType (0x7F) - if data[0] != MorphTxType { - t.Fatalf("expected first byte 0x7F, got 0x%x", data[0]) - } - - // Skip txType byte, decode the rest - innerData := data[1:] - t.Logf("First inner byte: 0x%x (should be RLP list prefix >= 0xC0)", innerData[0]) - - // Verify it's RLP list prefix (V0 format) - if innerData[0] < 0xC0 { - t.Errorf("V0 data should start with RLP list prefix, got 0x%x", innerData[0]) - } - - // Decode using MorphTx.decode - var decoded MorphTx - if err := decoded.decode(innerData); err != nil { - t.Fatalf("failed to decode MorphTx: %v", err) - } - - // Verify version is 0 (V0 format) - if decoded.Version != MorphTxVersion0 { - t.Errorf("expected Version 0, got %d", decoded.Version) - } - - // Verify FeeTokenID - if decoded.FeeTokenID != tc.feeTokenID { - t.Errorf("expected FeeTokenID %d, got %d", tc.feeTokenID, decoded.FeeTokenID) - } - - // Verify FeeLimit - if tc.feeLimit == nil { - if decoded.FeeLimit != nil && decoded.FeeLimit.Sign() != 0 { - t.Errorf("expected nil/zero FeeLimit, got %v", decoded.FeeLimit) - } - } else { - if decoded.FeeLimit == nil || decoded.FeeLimit.Cmp(tc.feeLimit) != 0 { - t.Errorf("expected FeeLimit %v, got %v", tc.feeLimit, decoded.FeeLimit) - } - } - - // Verify other common fields - if decoded.ChainID.Cmp(expectedChainID) != 0 { - t.Errorf("ChainID mismatch: expected %v, got %v", expectedChainID, decoded.ChainID) - } - if decoded.Nonce != expectedNonce { - t.Errorf("Nonce mismatch: expected %d, got %d", expectedNonce, decoded.Nonce) - } - if decoded.GasTipCap.Cmp(expectedGasTipCap) != 0 { - t.Errorf("GasTipCap mismatch: expected %v, got %v", expectedGasTipCap, decoded.GasTipCap) - } - if decoded.GasFeeCap.Cmp(expectedGasFeeCap) != 0 { - t.Errorf("GasFeeCap mismatch: expected %v, got %v", expectedGasFeeCap, decoded.GasFeeCap) - } - if decoded.Gas != expectedGas { - t.Errorf("Gas mismatch: expected %d, got %d", expectedGas, decoded.Gas) - } - if decoded.To == nil || *decoded.To != expectedTo { - t.Errorf("To mismatch: expected %v, got %v", expectedTo, decoded.To) - } - if decoded.Value.Cmp(expectedValue) != 0 { - t.Errorf("Value mismatch: expected %v, got %v", expectedValue, decoded.Value) - } - if decoded.R.Cmp(expectedR) != 0 { - t.Errorf("R mismatch: expected %v, got %v", expectedR, decoded.R) - } - if decoded.S.Cmp(expectedS) != 0 { - t.Errorf("S mismatch: expected %v, got %v", expectedS, decoded.S) - } - - t.Logf("Successfully decoded V0 MorphTx: ChainID=%v, Nonce=%d, FeeTokenID=%d, FeeLimit=%v, Version=%d", - decoded.ChainID, decoded.Nonce, decoded.FeeTokenID, decoded.FeeLimit, decoded.Version) - }) - } -} - -// encodeMorphTx encodes a MorphTx using its encode method -func encodeMorphTx(tx *MorphTx) ([]byte, error) { - buf := new(bytes.Buffer) - buf.WriteByte(MorphTxType) // Write txType prefix - if err := tx.encode(buf); err != nil { - return nil, err - } - return buf.Bytes(), nil -} - -// TestMorphTxV1Encoding tests the new V1 encoding format -// where version is a prefix byte before RLP data. -func TestMorphTxV1Encoding(t *testing.T) { - reference := common.HexToReference("0x1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef") - memo := []byte("test memo") - to := common.HexToAddress("0x1234567890123456789012345678901234567890") - - tx := &MorphTx{ - ChainID: big.NewInt(2818), - Nonce: 1, - GasTipCap: big.NewInt(1000000000), - GasFeeCap: big.NewInt(2000000000), - Gas: 21000, - To: &to, - Value: big.NewInt(0), - Data: []byte{}, - AccessList: AccessList{}, - FeeTokenID: 0, // ETH - FeeLimit: nil, - Version: MorphTxVersion1, - Reference: &reference, - Memo: &memo, - V: big.NewInt(0), - R: big.NewInt(0), - S: big.NewInt(0), - } - - // Encode - encoded, err := encodeMorphTx(tx) - if err != nil { - t.Fatalf("failed to encode: %v", err) - } - - t.Logf("V1 encoded hex: %s", hex.EncodeToString(encoded)) - t.Logf("First byte (type): 0x%x", encoded[0]) - t.Logf("Second byte (version): 0x%x", encoded[1]) - - // Verify first byte is MorphTxType - if encoded[0] != MorphTxType { - t.Errorf("expected first byte 0x%x, got 0x%x", MorphTxType, encoded[0]) - } - - // Verify second byte is version - if encoded[1] != MorphTxVersion1 { - t.Errorf("expected second byte 0x%x (version 1), got 0x%x", MorphTxVersion1, encoded[1]) - } - - // Decode back - var decoded MorphTx - if err := decoded.decode(encoded[1:]); err != nil { // Skip txType byte - t.Fatalf("failed to decode: %v", err) - } - - // Verify fields - if decoded.Version != MorphTxVersion1 { - t.Errorf("expected Version 1, got %d", decoded.Version) - } - if decoded.Reference == nil || *decoded.Reference != reference { - t.Errorf("Reference mismatch") - } - if decoded.Memo == nil || string(*decoded.Memo) != string(memo) { - t.Errorf("Memo mismatch") - } - - t.Logf("Successfully encoded and decoded V1 MorphTx") -} - -// TestMorphTxV0V1RoundTrip tests encoding/decoding round trip for both versions -func TestMorphTxV0V1RoundTrip(t *testing.T) { - to := common.HexToAddress("0x1234567890123456789012345678901234567890") - reference := common.HexToReference("0xabcdef1234567890abcdef1234567890abcdef1234567890abcdef1234567890") - memo := []byte("hello") - - testCases := []struct { - name string - tx *MorphTx - }{ - { - name: "V0 with FeeTokenID", - tx: &MorphTx{ - ChainID: big.NewInt(1), - Nonce: 1, - GasTipCap: big.NewInt(1000000000), - GasFeeCap: big.NewInt(2000000000), - Gas: 21000, - To: &to, - Value: big.NewInt(1000000000000000000), - Data: []byte{}, - AccessList: AccessList{}, - FeeTokenID: 1, // Non-zero required for V0 - FeeLimit: big.NewInt(100000000000000000), - Version: MorphTxVersion0, - V: big.NewInt(1), - R: big.NewInt(123456), - S: big.NewInt(654321), - }, - }, - { - name: "V1 with Reference and Memo", - tx: &MorphTx{ - ChainID: big.NewInt(1), - Nonce: 2, - GasTipCap: big.NewInt(1000000000), - GasFeeCap: big.NewInt(2000000000), - Gas: 21000, - To: &to, - Value: big.NewInt(0), - Data: []byte{0x01, 0x02, 0x03}, - AccessList: AccessList{}, - FeeTokenID: 0, - FeeLimit: nil, - Version: MorphTxVersion1, - Reference: &reference, - Memo: &memo, - V: big.NewInt(0), - R: big.NewInt(111), - S: big.NewInt(222), - }, - }, - { - name: "V1 with FeeTokenID and Reference", - tx: &MorphTx{ - ChainID: big.NewInt(1), - Nonce: 3, - GasTipCap: big.NewInt(1000000000), - GasFeeCap: big.NewInt(2000000000), - Gas: 50000, - To: &to, - Value: big.NewInt(500000000000000000), - Data: []byte{}, - AccessList: AccessList{}, - FeeTokenID: 2, - FeeLimit: big.NewInt(200000000000000000), - Version: MorphTxVersion1, - Reference: &reference, - Memo: nil, - V: big.NewInt(1), - R: big.NewInt(333), - S: big.NewInt(444), - }, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - // Encode - encoded, err := encodeMorphTx(tc.tx) - if err != nil { - t.Fatalf("failed to encode: %v", err) - } - - t.Logf("Encoded hex: %s", hex.EncodeToString(encoded)) - t.Logf("Length: %d bytes", len(encoded)) - - // Decode - var decoded MorphTx - if err := decoded.decode(encoded[1:]); err != nil { // Skip txType byte - t.Fatalf("failed to decode: %v", err) - } - - // Verify key fields - if decoded.Version != tc.tx.Version { - t.Errorf("Version mismatch: expected %d, got %d", tc.tx.Version, decoded.Version) - } - if decoded.FeeTokenID != tc.tx.FeeTokenID { - t.Errorf("FeeTokenID mismatch: expected %d, got %d", tc.tx.FeeTokenID, decoded.FeeTokenID) - } - if decoded.Nonce != tc.tx.Nonce { - t.Errorf("Nonce mismatch: expected %d, got %d", tc.tx.Nonce, decoded.Nonce) - } - if decoded.Gas != tc.tx.Gas { - t.Errorf("Gas mismatch: expected %d, got %d", tc.tx.Gas, decoded.Gas) - } - - t.Logf("Round-trip successful for %s", tc.name) - }) - } -} - -// TestMorphTxVersionDetection tests the version detection logic in decode -func TestMorphTxVersionDetection(t *testing.T) { - // Create a V0 transaction (legacy format) - to := common.HexToAddress("0x1234567890123456789012345678901234567890") - v0Tx := &MorphTx{ - ChainID: big.NewInt(1), - Nonce: 1, - GasTipCap: big.NewInt(1000000000), - GasFeeCap: big.NewInt(2000000000), - Gas: 21000, - To: &to, - Value: big.NewInt(0), - FeeTokenID: 1, - FeeLimit: big.NewInt(100), - Version: MorphTxVersion0, - V: big.NewInt(0), - R: big.NewInt(0), - S: big.NewInt(0), - } - - encoded, err := encodeMorphTx(v0Tx) - if err != nil { - t.Fatalf("failed to encode V0: %v", err) - } - innerData := encoded[1:] // Skip txType - - // V0 should start with RLP list prefix (0xC0-0xFF) - if innerData[0] < 0xC0 { - t.Errorf("V0 encoded data should start with RLP list prefix, got 0x%x", innerData[0]) - } - t.Logf("V0 first inner byte: 0x%x (RLP list prefix)", innerData[0]) - - // Create a V1 transaction - reference := common.HexToReference("0x1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef") - v1Tx := &MorphTx{ - ChainID: big.NewInt(1), - Nonce: 1, - GasTipCap: big.NewInt(1000000000), - GasFeeCap: big.NewInt(2000000000), - Gas: 21000, - To: &to, - Value: big.NewInt(0), - FeeTokenID: 0, - Version: MorphTxVersion1, - Reference: &reference, - V: big.NewInt(0), - R: big.NewInt(0), - S: big.NewInt(0), - } - - encoded, err = encodeMorphTx(v1Tx) - if err != nil { - t.Fatalf("failed to encode V1: %v", err) - } - innerData = encoded[1:] // Skip txType - - // V1 should start with version byte (0x01) - if innerData[0] != MorphTxVersion1 { - t.Errorf("V1 encoded data should start with version byte 0x01, got 0x%x", innerData[0]) - } - t.Logf("V1 first inner byte: 0x%x (version prefix)", innerData[0]) - - // Second byte should be RLP list prefix - if innerData[1] < 0xC0 { - t.Errorf("V1 second byte should be RLP list prefix, got 0x%x", innerData[1]) - } - t.Logf("V1 second inner byte: 0x%x (RLP list prefix)", innerData[1]) -} - -// TestMorphTxEncodeRLPConsistency verifies that rlp.Encode(morphTx) produces -// the same output as the custom encode() method. This ensures Hash() (which -// uses rlp.Encode internally) is consistent with the wire format. -func TestMorphTxEncodeRLPConsistency(t *testing.T) { - to := common.HexToAddress("0x1234567890123456789012345678901234567890") - reference := common.HexToReference("0xabcdef1234567890abcdef1234567890abcdef1234567890abcdef1234567890") - memo := []byte("hello") - - testCases := []struct { - name string - tx *MorphTx - }{ - { - name: "V0", - tx: &MorphTx{ - ChainID: big.NewInt(1), - Nonce: 1, - GasTipCap: big.NewInt(1000000000), - GasFeeCap: big.NewInt(2000000000), - Gas: 21000, - To: &to, - Value: big.NewInt(1000000000000000000), - Data: []byte{}, - AccessList: AccessList{}, - FeeTokenID: 1, - FeeLimit: big.NewInt(100000000000000000), - Version: MorphTxVersion0, - V: big.NewInt(1), - R: big.NewInt(123456), - S: big.NewInt(654321), - }, - }, - { - name: "V1 with Reference and Memo", - tx: &MorphTx{ - ChainID: big.NewInt(1), - Nonce: 2, - GasTipCap: big.NewInt(1000000000), - GasFeeCap: big.NewInt(2000000000), - Gas: 21000, - To: &to, - Value: big.NewInt(0), - Data: []byte{0x01, 0x02, 0x03}, - AccessList: AccessList{}, - FeeTokenID: 0, - FeeLimit: nil, - Version: MorphTxVersion1, - Reference: &reference, - Memo: &memo, - V: big.NewInt(0), - R: big.NewInt(111), - S: big.NewInt(222), - }, - }, - { - name: "V1 minimal", - tx: &MorphTx{ - ChainID: big.NewInt(1), - Nonce: 3, - GasTipCap: big.NewInt(1000000000), - GasFeeCap: big.NewInt(2000000000), - Gas: 50000, - To: &to, - Value: big.NewInt(0), - Data: []byte{}, - AccessList: AccessList{}, - FeeTokenID: 0, - FeeLimit: nil, - Version: MorphTxVersion1, - Reference: nil, - Memo: nil, - V: big.NewInt(0), - R: big.NewInt(0), - S: big.NewInt(0), - }, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - // Path 1: rlp.Encode (used by Hash via prefixedRlpHash) - var rlpBuf bytes.Buffer - if err := rlp.Encode(&rlpBuf, tc.tx); err != nil { - t.Fatalf("rlp.Encode failed: %v", err) - } - - // Path 2: custom encode() (used by wire format via encodeTyped) - var encodeBuf bytes.Buffer - if err := tc.tx.encode(&encodeBuf); err != nil { - t.Fatalf("encode() failed: %v", err) - } - - if !bytes.Equal(rlpBuf.Bytes(), encodeBuf.Bytes()) { - t.Errorf("rlp.Encode and encode() produce different output:\n rlp.Encode = %s\n encode() = %s", - hex.EncodeToString(rlpBuf.Bytes()), hex.EncodeToString(encodeBuf.Bytes())) - } - }) - } -} - -// TestMorphTxHashMatchesWireFormat verifies that tx.Hash() equals -// keccak256(wire_bytes) for both V0 and V1 MorphTx transactions. -func TestMorphTxHashMatchesWireFormat(t *testing.T) { - to := common.HexToAddress("0x1234567890123456789012345678901234567890") - reference := common.HexToReference("0xabcdef1234567890abcdef1234567890abcdef1234567890abcdef1234567890") - memo := []byte("hello") - - testCases := []struct { - name string - tx *MorphTx - }{ - { - name: "V0", - tx: &MorphTx{ - ChainID: big.NewInt(1), - Nonce: 1, - GasTipCap: big.NewInt(1000000000), - GasFeeCap: big.NewInt(2000000000), - Gas: 21000, - To: &to, - Value: big.NewInt(1000000000000000000), - Data: []byte{}, - AccessList: AccessList{}, - FeeTokenID: 1, - FeeLimit: big.NewInt(100000000000000000), - Version: MorphTxVersion0, - V: big.NewInt(1), - R: big.NewInt(123456), - S: big.NewInt(654321), - }, - }, - { - name: "V1 with Reference and Memo", - tx: &MorphTx{ - ChainID: big.NewInt(1), - Nonce: 2, - GasTipCap: big.NewInt(1000000000), - GasFeeCap: big.NewInt(2000000000), - Gas: 21000, - To: &to, - Value: big.NewInt(0), - Data: []byte{0x01, 0x02, 0x03}, - AccessList: AccessList{}, - FeeTokenID: 0, - FeeLimit: nil, - Version: MorphTxVersion1, - Reference: &reference, - Memo: &memo, - V: big.NewInt(0), - R: big.NewInt(111), - S: big.NewInt(222), - }, - }, - { - name: "V1 minimal", - tx: &MorphTx{ - ChainID: big.NewInt(1), - Nonce: 3, - GasTipCap: big.NewInt(1000000000), - GasFeeCap: big.NewInt(2000000000), - Gas: 50000, - To: &to, - Value: big.NewInt(0), - Data: []byte{}, - AccessList: AccessList{}, - FeeTokenID: 0, - FeeLimit: nil, - Version: MorphTxVersion1, - Reference: nil, - Memo: nil, - V: big.NewInt(0), - R: big.NewInt(0), - S: big.NewInt(0), - }, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - tx := NewTx(tc.tx) - - wireBytes, err := tx.MarshalBinary() - if err != nil { - t.Fatalf("MarshalBinary failed: %v", err) - } - - expectedHash := crypto.Keccak256Hash(wireBytes) - - if tx.Hash() != expectedHash { - t.Errorf("Hash mismatch:\n tx.Hash() = %s\n keccak256(wire) = %s\n wireBytes = %s", - tx.Hash().Hex(), expectedHash.Hex(), hex.EncodeToString(wireBytes)) - } - }) - } -} diff --git a/core/types/morph_tx_test.go b/core/types/morph_tx_test.go new file mode 100644 index 000000000..f451e9466 --- /dev/null +++ b/core/types/morph_tx_test.go @@ -0,0 +1,1465 @@ +package types + +import ( + "bytes" + "encoding/hex" + "math/big" + "testing" + + "github.com/morph-l2/go-ethereum/common" + "github.com/morph-l2/go-ethereum/crypto" + "github.com/morph-l2/go-ethereum/rlp" +) + +// --------------------------------------------------------------------------- +// Encoding / Decoding compatibility tests (from morph_tx_compat_test.go) +// --------------------------------------------------------------------------- + +// TestMorphTxV0BackwardCompatibility tests that old AltFeeTx encoded data +// can be correctly decoded by the new MorphTx decoder. +// These hex values were generated from the original AltFeeTx implementation. +func TestMorphTxV0BackwardCompatibility(t *testing.T) { + // Expected values from the original encoding + expectedTo := common.HexToAddress("0x1234567890123456789012345678901234567890") + expectedChainID := big.NewInt(2818) + expectedNonce := uint64(1) + expectedGasTipCap := big.NewInt(1000000000) + expectedGasFeeCap := big.NewInt(2000000000) + expectedGas := uint64(21000) + expectedValue := big.NewInt(1000000000000000000) // 1 ETH + expectedR, _ := new(big.Int).SetString("abcdef1234567890abcdef1234567890abcdef1234567890abcdef1234567890", 16) + expectedS, _ := new(big.Int).SetString("1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef", 16) + + testCases := []struct { + name string + fullHex string // Full hex including 0x7F prefix + feeTokenID uint16 + feeLimit *big.Int + }{ + { + // Case 1: FeeLimit has value (0.5 ETH = 500000000000000000) + name: "V0 with FeeLimit value", + fullHex: "7ff87e820b0201843b9aca008477359400825208941234567890123456789012345678901234567890880de0b6b3a764000080c0018806f05b59d3b2000001a0abcdef1234567890abcdef1234567890abcdef1234567890abcdef1234567890a01234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef", + feeTokenID: 1, + feeLimit: big.NewInt(500000000000000000), + }, + { + // Case 2: FeeLimit is nil (encoded as 0x80) + name: "V0 with nil FeeLimit", + fullHex: "7ff876820b0201843b9aca008477359400825208941234567890123456789012345678901234567890880de0b6b3a764000080c0018001a0abcdef1234567890abcdef1234567890abcdef1234567890abcdef1234567890a01234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef", + feeTokenID: 1, + feeLimit: nil, + }, + { + // Case 3: FeeLimit is 0 (also encoded as 0x80) + name: "V0 with zero FeeLimit", + fullHex: "7ff876820b0201843b9aca008477359400825208941234567890123456789012345678901234567890880de0b6b3a764000080c0018001a0abcdef1234567890abcdef1234567890abcdef1234567890abcdef1234567890a01234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef", + feeTokenID: 1, + feeLimit: nil, // 0 is encoded as empty, decoded as nil + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + data, err := hex.DecodeString(tc.fullHex) + if err != nil { + t.Fatalf("failed to decode hex: %v", err) + } + + // Verify first byte is MorphTxType (0x7F) + if data[0] != MorphTxType { + t.Fatalf("expected first byte 0x7F, got 0x%x", data[0]) + } + + // Skip txType byte, decode the rest + innerData := data[1:] + t.Logf("First inner byte: 0x%x (should be RLP list prefix >= 0xC0)", innerData[0]) + + // Verify it's RLP list prefix (V0 format) + if innerData[0] < 0xC0 { + t.Errorf("V0 data should start with RLP list prefix, got 0x%x", innerData[0]) + } + + // Decode using MorphTx.decode + var decoded MorphTx + if err := decoded.decode(innerData); err != nil { + t.Fatalf("failed to decode MorphTx: %v", err) + } + + // Verify version is 0 (V0 format) + if decoded.Version != MorphTxVersion0 { + t.Errorf("expected Version 0, got %d", decoded.Version) + } + + // Verify FeeTokenID + if decoded.FeeTokenID != tc.feeTokenID { + t.Errorf("expected FeeTokenID %d, got %d", tc.feeTokenID, decoded.FeeTokenID) + } + + // Verify FeeLimit + if tc.feeLimit == nil { + if decoded.FeeLimit != nil && decoded.FeeLimit.Sign() != 0 { + t.Errorf("expected nil/zero FeeLimit, got %v", decoded.FeeLimit) + } + } else { + if decoded.FeeLimit == nil || decoded.FeeLimit.Cmp(tc.feeLimit) != 0 { + t.Errorf("expected FeeLimit %v, got %v", tc.feeLimit, decoded.FeeLimit) + } + } + + // Verify other common fields + if decoded.ChainID.Cmp(expectedChainID) != 0 { + t.Errorf("ChainID mismatch: expected %v, got %v", expectedChainID, decoded.ChainID) + } + if decoded.Nonce != expectedNonce { + t.Errorf("Nonce mismatch: expected %d, got %d", expectedNonce, decoded.Nonce) + } + if decoded.GasTipCap.Cmp(expectedGasTipCap) != 0 { + t.Errorf("GasTipCap mismatch: expected %v, got %v", expectedGasTipCap, decoded.GasTipCap) + } + if decoded.GasFeeCap.Cmp(expectedGasFeeCap) != 0 { + t.Errorf("GasFeeCap mismatch: expected %v, got %v", expectedGasFeeCap, decoded.GasFeeCap) + } + if decoded.Gas != expectedGas { + t.Errorf("Gas mismatch: expected %d, got %d", expectedGas, decoded.Gas) + } + if decoded.To == nil || *decoded.To != expectedTo { + t.Errorf("To mismatch: expected %v, got %v", expectedTo, decoded.To) + } + if decoded.Value.Cmp(expectedValue) != 0 { + t.Errorf("Value mismatch: expected %v, got %v", expectedValue, decoded.Value) + } + if decoded.R.Cmp(expectedR) != 0 { + t.Errorf("R mismatch: expected %v, got %v", expectedR, decoded.R) + } + if decoded.S.Cmp(expectedS) != 0 { + t.Errorf("S mismatch: expected %v, got %v", expectedS, decoded.S) + } + + t.Logf("Successfully decoded V0 MorphTx: ChainID=%v, Nonce=%d, FeeTokenID=%d, FeeLimit=%v, Version=%d", + decoded.ChainID, decoded.Nonce, decoded.FeeTokenID, decoded.FeeLimit, decoded.Version) + }) + } +} + +// TestMorphTxV1Encoding tests the new V1 encoding format +// where version is a prefix byte before RLP data. +func TestMorphTxV1Encoding(t *testing.T) { + reference := common.HexToReference("0x1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef") + memo := []byte("test memo") + to := common.HexToAddress("0x1234567890123456789012345678901234567890") + + tx := &MorphTx{ + ChainID: big.NewInt(2818), + Nonce: 1, + GasTipCap: big.NewInt(1000000000), + GasFeeCap: big.NewInt(2000000000), + Gas: 21000, + To: &to, + Value: big.NewInt(0), + Data: []byte{}, + AccessList: AccessList{}, + FeeTokenID: 0, // ETH + FeeLimit: nil, + Version: MorphTxVersion1, + Reference: &reference, + Memo: &memo, + V: big.NewInt(0), + R: big.NewInt(0), + S: big.NewInt(0), + } + + // Encode + encoded, err := encodeMorphTx(tx) + if err != nil { + t.Fatalf("failed to encode: %v", err) + } + + t.Logf("V1 encoded hex: %s", hex.EncodeToString(encoded)) + t.Logf("First byte (type): 0x%x", encoded[0]) + t.Logf("Second byte (version): 0x%x", encoded[1]) + + // Verify first byte is MorphTxType + if encoded[0] != MorphTxType { + t.Errorf("expected first byte 0x%x, got 0x%x", MorphTxType, encoded[0]) + } + + // Verify second byte is version + if encoded[1] != MorphTxVersion1 { + t.Errorf("expected second byte 0x%x (version 1), got 0x%x", MorphTxVersion1, encoded[1]) + } + + // Decode back + var decoded MorphTx + if err := decoded.decode(encoded[1:]); err != nil { // Skip txType byte + t.Fatalf("failed to decode: %v", err) + } + + // Verify fields + if decoded.Version != MorphTxVersion1 { + t.Errorf("expected Version 1, got %d", decoded.Version) + } + if decoded.Reference == nil || *decoded.Reference != reference { + t.Errorf("Reference mismatch") + } + if decoded.Memo == nil || string(*decoded.Memo) != string(memo) { + t.Errorf("Memo mismatch") + } + + t.Logf("Successfully encoded and decoded V1 MorphTx") +} + +// TestMorphTxV0V1RoundTrip tests encoding/decoding round trip for both versions +func TestMorphTxV0V1RoundTrip(t *testing.T) { + to := common.HexToAddress("0x1234567890123456789012345678901234567890") + reference := common.HexToReference("0xabcdef1234567890abcdef1234567890abcdef1234567890abcdef1234567890") + memo := []byte("hello") + + testCases := []struct { + name string + tx *MorphTx + }{ + { + name: "V0 with FeeTokenID", + tx: &MorphTx{ + ChainID: big.NewInt(1), + Nonce: 1, + GasTipCap: big.NewInt(1000000000), + GasFeeCap: big.NewInt(2000000000), + Gas: 21000, + To: &to, + Value: big.NewInt(1000000000000000000), + Data: []byte{}, + AccessList: AccessList{}, + FeeTokenID: 1, // Non-zero required for V0 + FeeLimit: big.NewInt(100000000000000000), + Version: MorphTxVersion0, + V: big.NewInt(1), + R: big.NewInt(123456), + S: big.NewInt(654321), + }, + }, + { + name: "V1 with Reference and Memo", + tx: &MorphTx{ + ChainID: big.NewInt(1), + Nonce: 2, + GasTipCap: big.NewInt(1000000000), + GasFeeCap: big.NewInt(2000000000), + Gas: 21000, + To: &to, + Value: big.NewInt(0), + Data: []byte{0x01, 0x02, 0x03}, + AccessList: AccessList{}, + FeeTokenID: 0, + FeeLimit: nil, + Version: MorphTxVersion1, + Reference: &reference, + Memo: &memo, + V: big.NewInt(0), + R: big.NewInt(111), + S: big.NewInt(222), + }, + }, + { + name: "V1 with FeeTokenID and Reference", + tx: &MorphTx{ + ChainID: big.NewInt(1), + Nonce: 3, + GasTipCap: big.NewInt(1000000000), + GasFeeCap: big.NewInt(2000000000), + Gas: 50000, + To: &to, + Value: big.NewInt(500000000000000000), + Data: []byte{}, + AccessList: AccessList{}, + FeeTokenID: 2, + FeeLimit: big.NewInt(200000000000000000), + Version: MorphTxVersion1, + Reference: &reference, + Memo: nil, + V: big.NewInt(1), + R: big.NewInt(333), + S: big.NewInt(444), + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Encode + encoded, err := encodeMorphTx(tc.tx) + if err != nil { + t.Fatalf("failed to encode: %v", err) + } + + t.Logf("Encoded hex: %s", hex.EncodeToString(encoded)) + t.Logf("Length: %d bytes", len(encoded)) + + // Decode + var decoded MorphTx + if err := decoded.decode(encoded[1:]); err != nil { // Skip txType byte + t.Fatalf("failed to decode: %v", err) + } + + // Verify key fields + if decoded.Version != tc.tx.Version { + t.Errorf("Version mismatch: expected %d, got %d", tc.tx.Version, decoded.Version) + } + if decoded.FeeTokenID != tc.tx.FeeTokenID { + t.Errorf("FeeTokenID mismatch: expected %d, got %d", tc.tx.FeeTokenID, decoded.FeeTokenID) + } + if decoded.Nonce != tc.tx.Nonce { + t.Errorf("Nonce mismatch: expected %d, got %d", tc.tx.Nonce, decoded.Nonce) + } + if decoded.Gas != tc.tx.Gas { + t.Errorf("Gas mismatch: expected %d, got %d", tc.tx.Gas, decoded.Gas) + } + + t.Logf("Round-trip successful for %s", tc.name) + }) + } +} + +// TestMorphTxVersionDetection tests the version detection logic in decode +func TestMorphTxVersionDetection(t *testing.T) { + // Create a V0 transaction (legacy format) + to := common.HexToAddress("0x1234567890123456789012345678901234567890") + v0Tx := &MorphTx{ + ChainID: big.NewInt(1), + Nonce: 1, + GasTipCap: big.NewInt(1000000000), + GasFeeCap: big.NewInt(2000000000), + Gas: 21000, + To: &to, + Value: big.NewInt(0), + FeeTokenID: 1, + FeeLimit: big.NewInt(100), + Version: MorphTxVersion0, + V: big.NewInt(0), + R: big.NewInt(0), + S: big.NewInt(0), + } + + encoded, err := encodeMorphTx(v0Tx) + if err != nil { + t.Fatalf("failed to encode V0: %v", err) + } + innerData := encoded[1:] // Skip txType + + // V0 should start with RLP list prefix (0xC0-0xFF) + if innerData[0] < 0xC0 { + t.Errorf("V0 encoded data should start with RLP list prefix, got 0x%x", innerData[0]) + } + t.Logf("V0 first inner byte: 0x%x (RLP list prefix)", innerData[0]) + + // Create a V1 transaction + reference := common.HexToReference("0x1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef") + v1Tx := &MorphTx{ + ChainID: big.NewInt(1), + Nonce: 1, + GasTipCap: big.NewInt(1000000000), + GasFeeCap: big.NewInt(2000000000), + Gas: 21000, + To: &to, + Value: big.NewInt(0), + FeeTokenID: 0, + Version: MorphTxVersion1, + Reference: &reference, + V: big.NewInt(0), + R: big.NewInt(0), + S: big.NewInt(0), + } + + encoded, err = encodeMorphTx(v1Tx) + if err != nil { + t.Fatalf("failed to encode V1: %v", err) + } + innerData = encoded[1:] // Skip txType + + // V1 should start with version byte (0x01) + if innerData[0] != MorphTxVersion1 { + t.Errorf("V1 encoded data should start with version byte 0x01, got 0x%x", innerData[0]) + } + t.Logf("V1 first inner byte: 0x%x (version prefix)", innerData[0]) + + // Second byte should be RLP list prefix + if innerData[1] < 0xC0 { + t.Errorf("V1 second byte should be RLP list prefix, got 0x%x", innerData[1]) + } + t.Logf("V1 second inner byte: 0x%x (RLP list prefix)", innerData[1]) +} + +// TestMorphTxEncodeRLPConsistency verifies that rlp.Encode(morphTx) produces +// the same output as the custom encode() method. This ensures Hash() (which +// uses rlp.Encode internally) is consistent with the wire format. +func TestMorphTxEncodeRLPConsistency(t *testing.T) { + to := common.HexToAddress("0x1234567890123456789012345678901234567890") + reference := common.HexToReference("0xabcdef1234567890abcdef1234567890abcdef1234567890abcdef1234567890") + memo := []byte("hello") + + testCases := []struct { + name string + tx *MorphTx + }{ + { + name: "V0", + tx: &MorphTx{ + ChainID: big.NewInt(1), + Nonce: 1, + GasTipCap: big.NewInt(1000000000), + GasFeeCap: big.NewInt(2000000000), + Gas: 21000, + To: &to, + Value: big.NewInt(1000000000000000000), + Data: []byte{}, + AccessList: AccessList{}, + FeeTokenID: 1, + FeeLimit: big.NewInt(100000000000000000), + Version: MorphTxVersion0, + V: big.NewInt(1), + R: big.NewInt(123456), + S: big.NewInt(654321), + }, + }, + { + name: "V1 with Reference and Memo", + tx: &MorphTx{ + ChainID: big.NewInt(1), + Nonce: 2, + GasTipCap: big.NewInt(1000000000), + GasFeeCap: big.NewInt(2000000000), + Gas: 21000, + To: &to, + Value: big.NewInt(0), + Data: []byte{0x01, 0x02, 0x03}, + AccessList: AccessList{}, + FeeTokenID: 0, + FeeLimit: nil, + Version: MorphTxVersion1, + Reference: &reference, + Memo: &memo, + V: big.NewInt(0), + R: big.NewInt(111), + S: big.NewInt(222), + }, + }, + { + name: "V1 minimal", + tx: &MorphTx{ + ChainID: big.NewInt(1), + Nonce: 3, + GasTipCap: big.NewInt(1000000000), + GasFeeCap: big.NewInt(2000000000), + Gas: 50000, + To: &to, + Value: big.NewInt(0), + Data: []byte{}, + AccessList: AccessList{}, + FeeTokenID: 0, + FeeLimit: nil, + Version: MorphTxVersion1, + Reference: nil, + Memo: nil, + V: big.NewInt(0), + R: big.NewInt(0), + S: big.NewInt(0), + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Path 1: rlp.Encode (used by Hash via prefixedRlpHash) + var rlpBuf bytes.Buffer + if err := rlp.Encode(&rlpBuf, tc.tx); err != nil { + t.Fatalf("rlp.Encode failed: %v", err) + } + + // Path 2: custom encode() (used by wire format via encodeTyped) + var encodeBuf bytes.Buffer + if err := tc.tx.encode(&encodeBuf); err != nil { + t.Fatalf("encode() failed: %v", err) + } + + if !bytes.Equal(rlpBuf.Bytes(), encodeBuf.Bytes()) { + t.Errorf("rlp.Encode and encode() produce different output:\n rlp.Encode = %s\n encode() = %s", + hex.EncodeToString(rlpBuf.Bytes()), hex.EncodeToString(encodeBuf.Bytes())) + } + }) + } +} + +// TestMorphTxHashMatchesWireFormat verifies that tx.Hash() equals +// keccak256(wire_bytes) for both V0 and V1 MorphTx transactions. +func TestMorphTxHashMatchesWireFormat(t *testing.T) { + to := common.HexToAddress("0x1234567890123456789012345678901234567890") + reference := common.HexToReference("0xabcdef1234567890abcdef1234567890abcdef1234567890abcdef1234567890") + memo := []byte("hello") + + testCases := []struct { + name string + tx *MorphTx + }{ + { + name: "V0", + tx: &MorphTx{ + ChainID: big.NewInt(1), + Nonce: 1, + GasTipCap: big.NewInt(1000000000), + GasFeeCap: big.NewInt(2000000000), + Gas: 21000, + To: &to, + Value: big.NewInt(1000000000000000000), + Data: []byte{}, + AccessList: AccessList{}, + FeeTokenID: 1, + FeeLimit: big.NewInt(100000000000000000), + Version: MorphTxVersion0, + V: big.NewInt(1), + R: big.NewInt(123456), + S: big.NewInt(654321), + }, + }, + { + name: "V1 with Reference and Memo", + tx: &MorphTx{ + ChainID: big.NewInt(1), + Nonce: 2, + GasTipCap: big.NewInt(1000000000), + GasFeeCap: big.NewInt(2000000000), + Gas: 21000, + To: &to, + Value: big.NewInt(0), + Data: []byte{0x01, 0x02, 0x03}, + AccessList: AccessList{}, + FeeTokenID: 0, + FeeLimit: nil, + Version: MorphTxVersion1, + Reference: &reference, + Memo: &memo, + V: big.NewInt(0), + R: big.NewInt(111), + S: big.NewInt(222), + }, + }, + { + name: "V1 minimal", + tx: &MorphTx{ + ChainID: big.NewInt(1), + Nonce: 3, + GasTipCap: big.NewInt(1000000000), + GasFeeCap: big.NewInt(2000000000), + Gas: 50000, + To: &to, + Value: big.NewInt(0), + Data: []byte{}, + AccessList: AccessList{}, + FeeTokenID: 0, + FeeLimit: nil, + Version: MorphTxVersion1, + Reference: nil, + Memo: nil, + V: big.NewInt(0), + R: big.NewInt(0), + S: big.NewInt(0), + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + tx := NewTx(tc.tx) + + wireBytes, err := tx.MarshalBinary() + if err != nil { + t.Fatalf("MarshalBinary failed: %v", err) + } + + expectedHash := crypto.Keccak256Hash(wireBytes) + + if tx.Hash() != expectedHash { + t.Errorf("Hash mismatch:\n tx.Hash() = %s\n keccak256(wire) = %s\n wireBytes = %s", + tx.Hash().Hex(), expectedHash.Hex(), hex.EncodeToString(wireBytes)) + } + }) + } +} + +// --------------------------------------------------------------------------- +// DecodeRLP tests (from morph_tx_decode_rlp_test.go) +// --------------------------------------------------------------------------- + +// TestDecodeRLP_V0RoundTrip tests that rlp.Encode → rlp.DecodeBytes round-trip +// works correctly for V0 MorphTx via DecodeRLP. This is the scenario that +// triggered the original bug where derivation module called rlp.DecodeBytes +// directly on MorphTx, causing reflection-based field misalignment. +func TestDecodeRLP_V0RoundTrip(t *testing.T) { + to := common.HexToAddress("0x1234567890123456789012345678901234567890") + + testCases := []struct { + name string + tx *MorphTx + }{ + { + name: "V0 basic", + tx: &MorphTx{ + ChainID: big.NewInt(2818), + Nonce: 1, + GasTipCap: big.NewInt(1000000000), + GasFeeCap: big.NewInt(2000000000), + Gas: 21000, + To: &to, + Value: big.NewInt(1000000000000000000), + Data: []byte{}, + AccessList: AccessList{}, + FeeTokenID: 1, + FeeLimit: big.NewInt(500000000000000000), + Version: MorphTxVersion0, + V: big.NewInt(1), + R: big.NewInt(123456), + S: big.NewInt(654321), + }, + }, + { + name: "V0 large FeeLimit (>65535, the bug trigger)", + tx: &MorphTx{ + ChainID: big.NewInt(2818), + Nonce: 42, + GasTipCap: big.NewInt(1000000000), + GasFeeCap: big.NewInt(2000000000), + Gas: 21000, + To: &to, + Value: big.NewInt(0), + Data: []byte{}, + AccessList: AccessList{}, + FeeTokenID: 1, + FeeLimit: big.NewInt(999999999999999999), + Version: MorphTxVersion0, + V: big.NewInt(0), + R: big.NewInt(111), + S: big.NewInt(222), + }, + }, + { + name: "V0 nil FeeLimit", + tx: &MorphTx{ + ChainID: big.NewInt(1), + Nonce: 0, + GasTipCap: big.NewInt(100), + GasFeeCap: big.NewInt(200), + Gas: 21000, + To: &to, + Value: big.NewInt(0), + Data: []byte{}, + AccessList: AccessList{}, + FeeTokenID: 5, + FeeLimit: nil, + Version: MorphTxVersion0, + V: big.NewInt(0), + R: big.NewInt(0), + S: big.NewInt(0), + }, + }, + { + name: "V0 max FeeTokenID", + tx: &MorphTx{ + ChainID: big.NewInt(1), + Nonce: 1, + GasTipCap: big.NewInt(1000000000), + GasFeeCap: big.NewInt(2000000000), + Gas: 100000, + To: &to, + Value: big.NewInt(0), + Data: []byte{0xde, 0xad, 0xbe, 0xef}, + AccessList: AccessList{}, + FeeTokenID: 65535, + FeeLimit: big.NewInt(1), + Version: MorphTxVersion0, + V: big.NewInt(1), + R: big.NewInt(999), + S: big.NewInt(888), + }, + }, + { + name: "V0 with AccessList", + tx: &MorphTx{ + ChainID: big.NewInt(1), + Nonce: 7, + GasTipCap: big.NewInt(1000000000), + GasFeeCap: big.NewInt(2000000000), + Gas: 50000, + To: &to, + Value: big.NewInt(0), + Data: []byte{}, + AccessList: AccessList{ + {Address: common.HexToAddress("0xaaaa"), StorageKeys: []common.Hash{common.HexToHash("0x01")}}, + }, + FeeTokenID: 3, + FeeLimit: big.NewInt(100), + Version: MorphTxVersion0, + V: big.NewInt(1), + R: big.NewInt(100), + S: big.NewInt(200), + }, + }, + { + name: "V0 contract creation (To=nil)", + tx: &MorphTx{ + ChainID: big.NewInt(1), + Nonce: 0, + GasTipCap: big.NewInt(1000000000), + GasFeeCap: big.NewInt(2000000000), + Gas: 3000000, + To: nil, + Value: big.NewInt(0), + Data: []byte{0x60, 0x60, 0x60, 0x40}, + AccessList: AccessList{}, + FeeTokenID: 1, + FeeLimit: big.NewInt(500000), + Version: MorphTxVersion0, + V: big.NewInt(0), + R: big.NewInt(12345), + S: big.NewInt(67890), + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + var encoded bytes.Buffer + if err := rlp.Encode(&encoded, tc.tx); err != nil { + t.Fatalf("rlp.Encode failed: %v", err) + } + + var decoded MorphTx + if err := rlp.DecodeBytes(encoded.Bytes(), &decoded); err != nil { + t.Fatalf("rlp.DecodeBytes failed: %v", err) + } + + assertMorphTxEqual(t, tc.tx, &decoded) + }) + } +} + +// TestDecodeRLP_V1RoundTrip tests that rlp.Encode → rlp.DecodeBytes round-trip +// works correctly for V1 MorphTx. V1 uses a version byte prefix before the RLP list. +func TestDecodeRLP_V1RoundTrip(t *testing.T) { + to := common.HexToAddress("0x1234567890123456789012345678901234567890") + ref := common.HexToReference("0xabcdef1234567890abcdef1234567890abcdef1234567890abcdef1234567890") + memo := []byte("test memo data") + + testCases := []struct { + name string + tx *MorphTx + }{ + { + name: "V1 with Reference and Memo", + tx: &MorphTx{ + ChainID: big.NewInt(2818), + Nonce: 1, + GasTipCap: big.NewInt(1000000000), + GasFeeCap: big.NewInt(2000000000), + Gas: 21000, + To: &to, + Value: big.NewInt(0), + Data: []byte{}, + AccessList: AccessList{}, + FeeTokenID: 1, + FeeLimit: big.NewInt(100000), + Version: MorphTxVersion1, + Reference: &ref, + Memo: &memo, + V: big.NewInt(0), + R: big.NewInt(111), + S: big.NewInt(222), + }, + }, + { + name: "V1 Reference only (no Memo, no FeeTokenID)", + tx: &MorphTx{ + ChainID: big.NewInt(1), + Nonce: 10, + GasTipCap: big.NewInt(1000000000), + GasFeeCap: big.NewInt(2000000000), + Gas: 21000, + To: &to, + Value: big.NewInt(0), + Data: []byte{}, + AccessList: AccessList{}, + FeeTokenID: 0, + FeeLimit: nil, + Version: MorphTxVersion1, + Reference: &ref, + Memo: nil, + V: big.NewInt(0), + R: big.NewInt(0), + S: big.NewInt(0), + }, + }, + { + name: "V1 Memo only (no Reference, no FeeTokenID)", + tx: &MorphTx{ + ChainID: big.NewInt(1), + Nonce: 20, + GasTipCap: big.NewInt(1000000000), + GasFeeCap: big.NewInt(2000000000), + Gas: 21000, + To: &to, + Value: big.NewInt(0), + Data: []byte{}, + AccessList: AccessList{}, + FeeTokenID: 0, + FeeLimit: nil, + Version: MorphTxVersion1, + Reference: nil, + Memo: &memo, + V: big.NewInt(0), + R: big.NewInt(333), + S: big.NewInt(444), + }, + }, + { + name: "V1 minimal (no Reference, no Memo, no FeeTokenID)", + tx: &MorphTx{ + ChainID: big.NewInt(1), + Nonce: 0, + GasTipCap: big.NewInt(100), + GasFeeCap: big.NewInt(200), + Gas: 21000, + To: &to, + Value: big.NewInt(0), + Data: []byte{}, + AccessList: AccessList{}, + FeeTokenID: 0, + FeeLimit: nil, + Version: MorphTxVersion1, + Reference: nil, + Memo: nil, + V: big.NewInt(0), + R: big.NewInt(0), + S: big.NewInt(0), + }, + }, + { + name: "V1 with FeeTokenID + Reference + Memo + large FeeLimit", + tx: &MorphTx{ + ChainID: big.NewInt(2818), + Nonce: 999, + GasTipCap: big.NewInt(5000000000), + GasFeeCap: big.NewInt(10000000000), + Gas: 500000, + To: &to, + Value: big.NewInt(1000000000000000000), + Data: []byte{0x01, 0x02, 0x03, 0x04}, + AccessList: AccessList{}, + FeeTokenID: 2, + FeeLimit: big.NewInt(999999999999999999), + Version: MorphTxVersion1, + Reference: &ref, + Memo: &memo, + V: big.NewInt(1), + R: big.NewInt(12345), + S: big.NewInt(67890), + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + var encoded bytes.Buffer + if err := rlp.Encode(&encoded, tc.tx); err != nil { + t.Fatalf("rlp.Encode failed: %v", err) + } + + var decoded MorphTx + if err := rlp.DecodeBytes(encoded.Bytes(), &decoded); err != nil { + t.Fatalf("rlp.DecodeBytes failed: %v", err) + } + + assertMorphTxEqual(t, tc.tx, &decoded) + }) + } +} + +// TestDecodeRLP_V0LargeFeeLimit_BugRepro reproduces the exact bug from the error: +// +// rlp: input string too long for uint16, decoding into (types.MorphTx).FeeTokenID +// +// Without DecodeRLP, rlp.DecodeBytes uses reflection which misaligns fields: +// V0 wire format has [.., AccessList, FeeTokenID(uint16), FeeLimit(*big.Int), ..] +// but MorphTx struct has [.., AccessList, Version(uint8), FeeTokenID(uint16), FeeLimit, ..] +// so FeeLimit's big.Int bytes are decoded into FeeTokenID (uint16), causing the error. +func TestDecodeRLP_V0LargeFeeLimit_BugRepro(t *testing.T) { + to := common.HexToAddress("0x1234567890123456789012345678901234567890") + feeLimits := []*big.Int{ + big.NewInt(65536), // just above uint16 max + big.NewInt(500000000000000000), // 0.5 ETH + big.NewInt(999999999999999999), // ~1 ETH + new(big.Int).Exp(big.NewInt(10), big.NewInt(30), nil), // 10^30 + } + + for _, feeLimit := range feeLimits { + t.Run("FeeLimit="+feeLimit.String(), func(t *testing.T) { + tx := &MorphTx{ + ChainID: big.NewInt(2818), + Nonce: 1, + GasTipCap: big.NewInt(1000000000), + GasFeeCap: big.NewInt(2000000000), + Gas: 21000, + To: &to, + Value: big.NewInt(0), + Data: []byte{}, + AccessList: AccessList{}, + FeeTokenID: 1, + FeeLimit: feeLimit, + Version: MorphTxVersion0, + V: big.NewInt(0), + R: big.NewInt(111), + S: big.NewInt(222), + } + + // Encode via rlp.Encode (uses EncodeRLP → encode → v0MorphTxRLP) + var encoded bytes.Buffer + if err := rlp.Encode(&encoded, tx); err != nil { + t.Fatalf("rlp.Encode failed: %v", err) + } + + // Decode via rlp.DecodeBytes (uses DecodeRLP). + // Without DecodeRLP, this would fail with: + // "rlp: input string too long for uint16, decoding into (types.MorphTx).FeeTokenID" + var decoded MorphTx + if err := rlp.DecodeBytes(encoded.Bytes(), &decoded); err != nil { + t.Fatalf("rlp.DecodeBytes failed (this is the bug!): %v", err) + } + + if decoded.FeeTokenID != tx.FeeTokenID { + t.Errorf("FeeTokenID mismatch: want %d, got %d", tx.FeeTokenID, decoded.FeeTokenID) + } + if decoded.FeeLimit == nil || decoded.FeeLimit.Cmp(feeLimit) != 0 { + t.Errorf("FeeLimit mismatch: want %v, got %v", feeLimit, decoded.FeeLimit) + } + if decoded.Version != MorphTxVersion0 { + t.Errorf("Version mismatch: want %d, got %d", MorphTxVersion0, decoded.Version) + } + }) + } +} + +// TestDecodeRLP_MatchesDecode verifies that DecodeRLP (via rlp.DecodeBytes) +// produces the same result as the custom decode() method for both V0 and V1. +func TestDecodeRLP_MatchesDecode(t *testing.T) { + to := common.HexToAddress("0x1234567890123456789012345678901234567890") + ref := common.HexToReference("0xabcdef1234567890abcdef1234567890abcdef1234567890abcdef1234567890") + memo := []byte("hello world") + + testCases := []struct { + name string + tx *MorphTx + }{ + { + name: "V0", + tx: &MorphTx{ + ChainID: big.NewInt(2818), + Nonce: 1, + GasTipCap: big.NewInt(1000000000), + GasFeeCap: big.NewInt(2000000000), + Gas: 21000, + To: &to, + Value: big.NewInt(1000000000000000000), + Data: []byte{}, + AccessList: AccessList{}, + FeeTokenID: 1, + FeeLimit: big.NewInt(500000000000000000), + Version: MorphTxVersion0, + V: big.NewInt(1), + R: big.NewInt(123456), + S: big.NewInt(654321), + }, + }, + { + name: "V1", + tx: &MorphTx{ + ChainID: big.NewInt(2818), + Nonce: 2, + GasTipCap: big.NewInt(1000000000), + GasFeeCap: big.NewInt(2000000000), + Gas: 21000, + To: &to, + Value: big.NewInt(0), + Data: []byte{0x01}, + AccessList: AccessList{}, + FeeTokenID: 2, + FeeLimit: big.NewInt(100000), + Version: MorphTxVersion1, + Reference: &ref, + Memo: &memo, + V: big.NewInt(0), + R: big.NewInt(999), + S: big.NewInt(888), + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Encode via encode() to get raw wire bytes + var buf bytes.Buffer + if err := tc.tx.encode(&buf); err != nil { + t.Fatalf("encode failed: %v", err) + } + wireBytes := buf.Bytes() + + // Path 1: decode via custom decode() + var fromDecode MorphTx + if err := fromDecode.decode(wireBytes); err != nil { + t.Fatalf("decode() failed: %v", err) + } + + // Path 2: decode via rlp.DecodeBytes → DecodeRLP + var encoded bytes.Buffer + if err := rlp.Encode(&encoded, tc.tx); err != nil { + t.Fatalf("rlp.Encode failed: %v", err) + } + var fromDecodeRLP MorphTx + if err := rlp.DecodeBytes(encoded.Bytes(), &fromDecodeRLP); err != nil { + t.Fatalf("rlp.DecodeBytes failed: %v", err) + } + + // Both paths should produce identical results + assertMorphTxEqual(t, &fromDecode, &fromDecodeRLP) + }) + } +} + +// TestDecodeRLP_InRLPList tests DecodeRLP when MorphTx is embedded within an +// RLP list, simulating the batch parsing scenario in the derivation module. +func TestDecodeRLP_InRLPList(t *testing.T) { + to := common.HexToAddress("0x1234567890123456789012345678901234567890") + ref := common.HexToReference("0xabcdef1234567890abcdef1234567890abcdef1234567890abcdef1234567890") + memo := []byte("memo") + + txV0 := &MorphTx{ + ChainID: big.NewInt(2818), + Nonce: 1, + GasTipCap: big.NewInt(1000000000), + GasFeeCap: big.NewInt(2000000000), + Gas: 21000, + To: &to, + Value: big.NewInt(0), + Data: []byte{}, + AccessList: AccessList{}, + FeeTokenID: 1, + FeeLimit: big.NewInt(500000000000000000), + Version: MorphTxVersion0, + V: big.NewInt(1), + R: big.NewInt(111), + S: big.NewInt(222), + } + + txV1 := &MorphTx{ + ChainID: big.NewInt(2818), + Nonce: 2, + GasTipCap: big.NewInt(1000000000), + GasFeeCap: big.NewInt(2000000000), + Gas: 21000, + To: &to, + Value: big.NewInt(0), + Data: []byte{}, + AccessList: AccessList{}, + FeeTokenID: 0, + FeeLimit: nil, + Version: MorphTxVersion1, + Reference: &ref, + Memo: &memo, + V: big.NewInt(0), + R: big.NewInt(333), + S: big.NewInt(444), + } + + // Encode a list of MorphTx (simulating a batch) + batch := []*MorphTx{txV0, txV1} + var encoded bytes.Buffer + if err := rlp.Encode(&encoded, batch); err != nil { + t.Fatalf("rlp.Encode batch failed: %v", err) + } + + // Decode back + var decoded []*MorphTx + if err := rlp.DecodeBytes(encoded.Bytes(), &decoded); err != nil { + t.Fatalf("rlp.DecodeBytes batch failed: %v", err) + } + + if len(decoded) != 2 { + t.Fatalf("expected 2 transactions, got %d", len(decoded)) + } + + t.Run("batch[0] V0", func(t *testing.T) { + assertMorphTxEqual(t, txV0, decoded[0]) + }) + t.Run("batch[1] V1", func(t *testing.T) { + assertMorphTxEqual(t, txV1, decoded[1]) + }) +} + +// TestDecodeRLP_V0BackwardCompat_HardcodedHex tests that DecodeRLP correctly +// handles hardcoded V0 data from the original AltFeeTx encoding (the same test +// vectors from TestMorphTxV0BackwardCompatibility, but decoded via rlp.DecodeBytes). +func TestDecodeRLP_V0BackwardCompat_HardcodedHex(t *testing.T) { + // This hex was generated from the original AltFeeTx implementation. + // Inner data (after stripping MorphTxType 0x7F prefix) is a V0 RLP list. + // + // Fields: ChainID=2818, Nonce=1, GasTipCap=1e9, GasFeeCap=2e9, Gas=21000, + // To=0x1234..., Value=1e18, Data=[], AccessList=[], FeeTokenID=1, + // FeeLimit=5e17, V=1, R=..., S=... + innerHex := "f87e820b0201843b9aca008477359400825208941234567890123456789012345678901234567890880de0b6b3a764000080c0018806f05b59d3b2000001a0abcdef1234567890abcdef1234567890abcdef1234567890abcdef1234567890a01234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef" + + data, err := hex.DecodeString(innerHex) + if err != nil { + t.Fatalf("hex decode failed: %v", err) + } + + // Decode using rlp.DecodeBytes → DecodeRLP + var decoded MorphTx + if err := rlp.DecodeBytes(data, &decoded); err != nil { + t.Fatalf("rlp.DecodeBytes failed: %v", err) + } + + if decoded.Version != MorphTxVersion0 { + t.Errorf("Version: want %d, got %d", MorphTxVersion0, decoded.Version) + } + if decoded.FeeTokenID != 1 { + t.Errorf("FeeTokenID: want 1, got %d", decoded.FeeTokenID) + } + if decoded.ChainID.Cmp(big.NewInt(2818)) != 0 { + t.Errorf("ChainID: want 2818, got %v", decoded.ChainID) + } + if decoded.Nonce != 1 { + t.Errorf("Nonce: want 1, got %d", decoded.Nonce) + } + expectedFeeLimit := big.NewInt(500000000000000000) + if decoded.FeeLimit == nil || decoded.FeeLimit.Cmp(expectedFeeLimit) != 0 { + t.Errorf("FeeLimit: want %v, got %v", expectedFeeLimit, decoded.FeeLimit) + } +} + +// TestDecodeRLP_ErrorCases tests that DecodeRLP correctly rejects invalid data. +func TestDecodeRLP_ErrorCases(t *testing.T) { + testCases := []struct { + name string + input []byte + }{ + { + name: "empty input", + input: []byte{}, + }, + { + name: "unsupported version byte 0x02", + input: []byte{0x02, 0xc0}, + }, + { + name: "unsupported version byte 0xFF handled as V0 but invalid RLP list content", + input: []byte{0xc1, 0xff}, + }, + { + name: "truncated V0 RLP list", + input: []byte{0xc5, 0x01, 0x02}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + var decoded MorphTx + err := rlp.DecodeBytes(tc.input, &decoded) + if err == nil { + t.Error("expected error, got nil") + } + }) + } +} + +// TestDecodeRLP_V0FeeTokenIDValues tests various FeeTokenID values to ensure +// they survive the rlp.Encode → rlp.DecodeBytes round-trip without corruption. +func TestDecodeRLP_V0FeeTokenIDValues(t *testing.T) { + to := common.HexToAddress("0x1234567890123456789012345678901234567890") + feeTokenIDs := []uint16{1, 2, 127, 128, 255, 256, 1000, 65535} + + for _, fid := range feeTokenIDs { + t.Run("FeeTokenID="+big.NewInt(int64(fid)).String(), func(t *testing.T) { + tx := &MorphTx{ + ChainID: big.NewInt(1), + Nonce: 1, + GasTipCap: big.NewInt(1000000000), + GasFeeCap: big.NewInt(2000000000), + Gas: 21000, + To: &to, + Value: big.NewInt(0), + Data: []byte{}, + AccessList: AccessList{}, + FeeTokenID: fid, + FeeLimit: big.NewInt(100000), + Version: MorphTxVersion0, + V: big.NewInt(0), + R: big.NewInt(0), + S: big.NewInt(0), + } + + var encoded bytes.Buffer + if err := rlp.Encode(&encoded, tx); err != nil { + t.Fatalf("rlp.Encode failed: %v", err) + } + + var decoded MorphTx + if err := rlp.DecodeBytes(encoded.Bytes(), &decoded); err != nil { + t.Fatalf("rlp.DecodeBytes failed: %v", err) + } + + if decoded.FeeTokenID != fid { + t.Errorf("FeeTokenID: want %d, got %d", fid, decoded.FeeTokenID) + } + }) + } +} + +// TestDecodeRLP_TransactionWrapperConsistency verifies that decoding via +// Transaction.UnmarshalBinary (the normal path) and decoding via direct +// rlp.DecodeBytes on MorphTx produce semantically equivalent results. +func TestDecodeRLP_TransactionWrapperConsistency(t *testing.T) { + to := common.HexToAddress("0x1234567890123456789012345678901234567890") + ref := common.HexToReference("0xabcdef1234567890abcdef1234567890abcdef1234567890abcdef1234567890") + memo := []byte("memo") + + testCases := []struct { + name string + tx *MorphTx + }{ + { + name: "V0", + tx: &MorphTx{ + ChainID: big.NewInt(1), + Nonce: 1, + GasTipCap: big.NewInt(1000000000), + GasFeeCap: big.NewInt(2000000000), + Gas: 21000, + To: &to, + Value: big.NewInt(0), + Data: []byte{}, + AccessList: AccessList{}, + FeeTokenID: 1, + FeeLimit: big.NewInt(500000000000000000), + Version: MorphTxVersion0, + V: big.NewInt(1), + R: big.NewInt(123), + S: big.NewInt(456), + }, + }, + { + name: "V1", + tx: &MorphTx{ + ChainID: big.NewInt(1), + Nonce: 2, + GasTipCap: big.NewInt(1000000000), + GasFeeCap: big.NewInt(2000000000), + Gas: 21000, + To: &to, + Value: big.NewInt(0), + Data: []byte{}, + AccessList: AccessList{}, + FeeTokenID: 0, + FeeLimit: nil, + Version: MorphTxVersion1, + Reference: &ref, + Memo: &memo, + V: big.NewInt(0), + R: big.NewInt(789), + S: big.NewInt(101), + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Path 1: Transaction.MarshalBinary → Transaction.UnmarshalBinary + wrappedTx := NewTx(tc.tx) + wireBytes, err := wrappedTx.MarshalBinary() + if err != nil { + t.Fatalf("MarshalBinary failed: %v", err) + } + var parsedTx Transaction + if err := parsedTx.UnmarshalBinary(wireBytes); err != nil { + t.Fatalf("UnmarshalBinary failed: %v", err) + } + fromWrapper := parsedTx.inner.(*MorphTx) + + // Path 2: rlp.Encode → rlp.DecodeBytes (direct MorphTx) + var encoded bytes.Buffer + if err := rlp.Encode(&encoded, tc.tx); err != nil { + t.Fatalf("rlp.Encode failed: %v", err) + } + var fromDirect MorphTx + if err := rlp.DecodeBytes(encoded.Bytes(), &fromDirect); err != nil { + t.Fatalf("rlp.DecodeBytes failed: %v", err) + } + + assertMorphTxEqual(t, fromWrapper, &fromDirect) + }) + } +} + +// TestDecodeRLP_EncodeDecodeSymmetry verifies that rlp.Encode output can be +// fed back into rlp.DecodeBytes and produce an identical MorphTx for both versions. +func TestDecodeRLP_EncodeDecodeSymmetry(t *testing.T) { + to := common.HexToAddress("0x1234567890123456789012345678901234567890") + ref := common.HexToReference("0x1111111111111111111111111111111111111111111111111111111111111111") + memo := []byte("symmetric test") + + txV0 := &MorphTx{ + ChainID: big.NewInt(1), Nonce: 1, + GasTipCap: big.NewInt(1e9), GasFeeCap: big.NewInt(2e9), Gas: 21000, + To: &to, Value: big.NewInt(1e18), Data: []byte{}, + AccessList: AccessList{}, FeeTokenID: 1, FeeLimit: big.NewInt(1e17), + Version: MorphTxVersion0, + V: big.NewInt(1), R: big.NewInt(100), S: big.NewInt(200), + } + txV1 := &MorphTx{ + ChainID: big.NewInt(1), Nonce: 2, + GasTipCap: big.NewInt(1e9), GasFeeCap: big.NewInt(2e9), Gas: 21000, + To: &to, Value: big.NewInt(0), Data: []byte{0xab}, + AccessList: AccessList{}, FeeTokenID: 3, FeeLimit: big.NewInt(5e17), + Version: MorphTxVersion1, Reference: &ref, Memo: &memo, + V: big.NewInt(0), R: big.NewInt(300), S: big.NewInt(400), + } + + for _, tc := range []struct { + name string + tx *MorphTx + }{ + {"V0", txV0}, + {"V1", txV1}, + } { + t.Run(tc.name, func(t *testing.T) { + // Encode → Decode → Re-encode and verify byte-for-byte equality + var buf1 bytes.Buffer + if err := rlp.Encode(&buf1, tc.tx); err != nil { + t.Fatalf("first rlp.Encode failed: %v", err) + } + + var decoded MorphTx + if err := rlp.DecodeBytes(buf1.Bytes(), &decoded); err != nil { + t.Fatalf("rlp.DecodeBytes failed: %v", err) + } + + var buf2 bytes.Buffer + if err := rlp.Encode(&buf2, &decoded); err != nil { + t.Fatalf("second rlp.Encode failed: %v", err) + } + + if !bytes.Equal(buf1.Bytes(), buf2.Bytes()) { + t.Errorf("encode→decode→encode not stable:\n first: %x\n second: %x", + buf1.Bytes(), buf2.Bytes()) + } + }) + } +} + +// --------------------------------------------------------------------------- +// Test helpers +// --------------------------------------------------------------------------- + +// encodeMorphTx encodes a MorphTx using its encode method with txType prefix. +func encodeMorphTx(tx *MorphTx) ([]byte, error) { + buf := new(bytes.Buffer) + buf.WriteByte(MorphTxType) + if err := tx.encode(buf); err != nil { + return nil, err + } + return buf.Bytes(), nil +} + +// assertMorphTxEqual compares two MorphTx structs field by field. +func assertMorphTxEqual(t *testing.T, want, got *MorphTx) { + t.Helper() + + if want.Version != got.Version { + t.Errorf("Version: want %d, got %d", want.Version, got.Version) + } + if want.FeeTokenID != got.FeeTokenID { + t.Errorf("FeeTokenID: want %d, got %d", want.FeeTokenID, got.FeeTokenID) + } + if want.Nonce != got.Nonce { + t.Errorf("Nonce: want %d, got %d", want.Nonce, got.Nonce) + } + if want.Gas != got.Gas { + t.Errorf("Gas: want %d, got %d", want.Gas, got.Gas) + } + assertBigIntEqual(t, "ChainID", want.ChainID, got.ChainID) + assertBigIntEqual(t, "GasTipCap", want.GasTipCap, got.GasTipCap) + assertBigIntEqual(t, "GasFeeCap", want.GasFeeCap, got.GasFeeCap) + assertBigIntEqual(t, "Value", want.Value, got.Value) + assertBigIntEqual(t, "V", want.V, got.V) + assertBigIntEqual(t, "R", want.R, got.R) + assertBigIntEqual(t, "S", want.S, got.S) + + // FeeLimit: nil and zero are treated as equivalent in RLP + wantFeeLimit := want.FeeLimit + gotFeeLimit := got.FeeLimit + if wantFeeLimit == nil { + wantFeeLimit = new(big.Int) + } + if gotFeeLimit == nil { + gotFeeLimit = new(big.Int) + } + if wantFeeLimit.Cmp(gotFeeLimit) != 0 { + t.Errorf("FeeLimit: want %v, got %v", want.FeeLimit, got.FeeLimit) + } + + if !bytes.Equal(want.Data, got.Data) { + t.Errorf("Data: want %x, got %x", want.Data, got.Data) + } + + // To + if want.To == nil && got.To != nil { + t.Errorf("To: want nil, got %v", got.To) + } else if want.To != nil && got.To == nil { + t.Errorf("To: want %v, got nil", want.To) + } else if want.To != nil && got.To != nil && *want.To != *got.To { + t.Errorf("To: want %v, got %v", want.To, got.To) + } + + // Reference + if want.Reference == nil && got.Reference != nil { + t.Errorf("Reference: want nil, got %v", got.Reference) + } else if want.Reference != nil && got.Reference == nil { + t.Errorf("Reference: want %v, got nil", want.Reference) + } else if want.Reference != nil && got.Reference != nil && *want.Reference != *got.Reference { + t.Errorf("Reference: want %v, got %v", want.Reference, got.Reference) + } + + // Memo + var wantMemo, gotMemo []byte + if want.Memo != nil { + wantMemo = *want.Memo + } + if got.Memo != nil { + gotMemo = *got.Memo + } + if !bytes.Equal(wantMemo, gotMemo) { + t.Errorf("Memo: want %x, got %x", wantMemo, gotMemo) + } +} + +func assertBigIntEqual(t *testing.T, name string, want, got *big.Int) { + t.Helper() + if want == nil && got == nil { + return + } + if want == nil { + want = new(big.Int) + } + if got == nil { + got = new(big.Int) + } + if want.Cmp(got) != 0 { + t.Errorf("%s: want %v, got %v", name, want, got) + } +} diff --git a/eth/api.go b/eth/api.go index 72ac6ce39..22b6f3ed9 100644 --- a/eth/api.go +++ b/eth/api.go @@ -688,6 +688,10 @@ type DiskAndHeaderRoot struct { // This is useful for debugging cross-format state access (zkTrie ↔ MPT). // If no disk root mapping exists, returns the block's root for both fields. func (api *MorphAPI) DiskRoot(ctx context.Context, blockNrOrHash *rpc.BlockNumberOrHash) (DiskAndHeaderRoot, error) { + if blockNrOrHash == nil { + latest := rpc.BlockNumberOrHashWithNumber(rpc.LatestBlockNumber) + blockNrOrHash = &latest + } block, err := api.eth.APIBackend.BlockByNumberOrHash(ctx, *blockNrOrHash) if err != nil { return DiskAndHeaderRoot{}, fmt.Errorf("failed to retrieve block: %w", err) diff --git a/eth/tracers/api.go b/eth/tracers/api.go index 586768e45..ff35be457 100644 --- a/eth/tracers/api.go +++ b/eth/tracers/api.go @@ -39,6 +39,7 @@ import ( "github.com/morph-l2/go-ethereum/core/tracing" "github.com/morph-l2/go-ethereum/core/types" "github.com/morph-l2/go-ethereum/core/vm" + "github.com/morph-l2/go-ethereum/crypto" "github.com/morph-l2/go-ethereum/eth/tracers/logger" "github.com/morph-l2/go-ethereum/ethdb" "github.com/morph-l2/go-ethereum/internal/ethapi" @@ -91,17 +92,284 @@ type Backend interface { type API struct { backend Backend morphTracerWrapper morphTracerWrapper - addERC20Balance func(evm *vm.EVM, tokenID uint16, addr common.Address, amount *big.Int) error + addERC20Balance func(statedb *state.StateDB, evm *vm.EVM, tokenID uint16, addr common.Address, amount *big.Int, mask *syntheticPrecreditMask) error } // NewAPI creates a new API definition for the tracing methods of the Ethereum service. func NewAPI(backend Backend, morphTracerWrapper morphTracerWrapper) *API { api := &API{backend: backend, morphTracerWrapper: morphTracerWrapper} - // TODO - api.addERC20Balance = func(evm *vm.EVM, tokenID uint16, addr common.Address, amount *big.Int) error { + api.addERC20Balance = addERC20Balance + return api +} + +type balanceSlotTracer struct { + token common.Address + slots map[common.Hash]struct{} +} + +type syntheticPrecreditMask struct { + active bool + balances map[common.Address]*big.Int + storage map[common.Address]map[common.Hash]common.Hash +} + +type maskedTracingStateDB struct { + inner tracing.StateDB + mask *syntheticPrecreditMask +} + +func (t *balanceSlotTracer) OnOpcode(_ uint64, opcode byte, _ uint64, _ uint64, scope tracing.OpContext, _ []byte, _ int, err error) { + if err != nil || vm.OpCode(opcode) != vm.SLOAD || scope.Address() != t.token { + return + } + stack := scope.StackData() + if len(stack) == 0 { + return + } + t.slots[common.Hash(stack[len(stack)-1].Bytes32())] = struct{}{} +} + +func addERC20Balance(statedb *state.StateDB, evm *vm.EVM, tokenID uint16, addr common.Address, amount *big.Int, mask *syntheticPrecreditMask) error { + if amount == nil || amount.Sign() == 0 { return nil } - return api + info, err := fees.GetTokenInfo(statedb, tokenID) + if err != nil { + return err + } + var slot common.Hash + if info.HasSlot { + _, slot, err = fees.GetAltTokenBalanceFromSlot(statedb, info.TokenAddress, addr, info.BalanceSlot) + } else { + slot, err = discoverERC20BalanceSlot(statedb, evm, info.TokenAddress, addr) + } + if err != nil { + return err + } + current := statedb.GetState(info.TokenAddress, slot) + if mask != nil { + mask.addStorage(info.TokenAddress, slot, current) + } + statedb.SetState(info.TokenAddress, slot, common.BigToHash(new(big.Int).Add(new(big.Int).SetBytes(current.Bytes()), amount))) + return nil +} + +func newSyntheticPrecreditMask() *syntheticPrecreditMask { + return &syntheticPrecreditMask{ + active: true, + balances: make(map[common.Address]*big.Int), + storage: make(map[common.Address]map[common.Hash]common.Hash), + } +} + +func (m *syntheticPrecreditMask) hasEntries() bool { + return m != nil && (len(m.balances) > 0 || len(m.storage) > 0) +} + +func (m *syntheticPrecreditMask) addBalance(addr common.Address, balance *big.Int) { + if m == nil { + return + } + m.balances[addr] = new(big.Int).Set(balance) +} + +func (m *syntheticPrecreditMask) addStorage(addr common.Address, slot common.Hash, value common.Hash) { + if m == nil { + return + } + if m.storage[addr] == nil { + m.storage[addr] = make(map[common.Hash]common.Hash) + } + m.storage[addr][slot] = value +} + +func (m *syntheticPrecreditMask) originalBalance(addr common.Address) (*big.Int, bool) { + if m == nil || !m.active { + return nil, false + } + balance, ok := m.balances[addr] + if !ok { + return nil, false + } + return new(big.Int).Set(balance), true +} + +func (m *syntheticPrecreditMask) originalStorage(addr common.Address, slot common.Hash) (common.Hash, bool) { + if m == nil || !m.active { + return common.Hash{}, false + } + value, ok := m.storage[addr][slot] + return value, ok +} + +func (db *maskedTracingStateDB) GetBalance(addr common.Address) *big.Int { + if balance, ok := db.mask.originalBalance(addr); ok { + return balance + } + return db.inner.GetBalance(addr) +} + +func (db *maskedTracingStateDB) GetNonce(addr common.Address) uint64 { + return db.inner.GetNonce(addr) +} + +func (db *maskedTracingStateDB) GetCode(addr common.Address) []byte { + return db.inner.GetCode(addr) +} + +func (db *maskedTracingStateDB) GetKeccakCodeHash(addr common.Address) common.Hash { + return db.inner.GetKeccakCodeHash(addr) +} + +func (db *maskedTracingStateDB) GetPoseidonCodeHash(addr common.Address) common.Hash { + return db.inner.GetPoseidonCodeHash(addr) +} + +func (db *maskedTracingStateDB) GetState(addr common.Address, key common.Hash) common.Hash { + if value, ok := db.mask.originalStorage(addr, key); ok { + return value + } + return db.inner.GetState(addr, key) +} + +func (db *maskedTracingStateDB) GetTransientState(addr common.Address, key common.Hash) common.Hash { + return db.inner.GetTransientState(addr, key) +} + +func (db *maskedTracingStateDB) Exist(addr common.Address) bool { + return db.inner.Exist(addr) +} + +func (db *maskedTracingStateDB) GetRefund() uint64 { + return db.inner.GetRefund() +} + +func (db *maskedTracingStateDB) GetCodeSize(addr common.Address) uint64 { + return db.inner.GetCodeSize(addr) +} + +func shouldMaskSyntheticPrecredits(config *TraceConfig) bool { + if config == nil || config.Tracer == nil { + return false + } + if *config.Tracer == "prestateTracer" { + return true + } + if *config.Tracer != "muxTracer" || len(config.TracerConfig) == 0 { + return false + } + var subtracers map[string]json.RawMessage + if err := json.Unmarshal(config.TracerConfig, &subtracers); err != nil { + return false + } + _, ok := subtracers["prestateTracer"] + return ok +} + +func wrapSyntheticPrecreditHooks(hooks *tracing.Hooks, mask *syntheticPrecreditMask) *tracing.Hooks { + if hooks == nil || !mask.hasEntries() { + return hooks + } + base := hooks + wrapped := *base + wrapped.OnTxStart = func(env *tracing.VMContext, tx *types.Transaction, from common.Address) { + if base.OnTxStart == nil { + return + } + envCopy := *env + envCopy.StateDB = &maskedTracingStateDB{inner: env.StateDB, mask: mask} + base.OnTxStart(&envCopy, tx, from) + } + wrapped.OnTxEnd = func(receipt *types.Receipt, err error) { + mask.active = false + if base.OnTxEnd != nil { + base.OnTxEnd(receipt, err) + } + } + wrapped.OnBalanceChange = func(addr common.Address, prev, new *big.Int, reason tracing.BalanceChangeReason) { + if original, ok := mask.originalBalance(addr); ok { + prev = original + } + if base.OnBalanceChange != nil { + base.OnBalanceChange(addr, prev, new, reason) + } + } + wrapped.OnStorageChange = func(addr common.Address, slot common.Hash, prev, new common.Hash) { + if original, ok := mask.originalStorage(addr, slot); ok { + prev = original + } + if base.OnStorageChange != nil { + base.OnStorageChange(addr, slot, prev, new) + } + } + return &wrapped +} + +func discoverERC20BalanceSlot(statedb *state.StateDB, evm *vm.EVM, tokenAddress, user common.Address) (common.Hash, error) { + currentBalance, userSlots, err := traceERC20BalanceSlots(statedb, evm, tokenAddress, user) + if err != nil { + return common.Hash{}, err + } + probe := deriveBalanceProbeAddress(tokenAddress, user) + _, probeSlots, err := traceERC20BalanceSlots(statedb, evm, tokenAddress, probe) + if err != nil { + return common.Hash{}, err + } + for slot := range userSlots { + if _, seen := probeSlots[slot]; seen { + continue + } + ok, err := validateERC20BalanceSlot(statedb, evm, tokenAddress, user, currentBalance, slot) + if err != nil { + return common.Hash{}, err + } + if ok { + return slot, nil + } + } + return common.Hash{}, fmt.Errorf("failed to discover ERC20 balance slot for token %s", tokenAddress.Hex()) +} + +func traceERC20BalanceSlots(statedb *state.StateDB, evm *vm.EVM, tokenAddress, user common.Address) (*big.Int, map[common.Hash]struct{}, error) { + probeTracer := &balanceSlotTracer{ + token: tokenAddress, + slots: make(map[common.Hash]struct{}), + } + cfg := evm.Config + cfg.Tracer = &tracing.Hooks{OnOpcode: probeTracer.OnOpcode} + probeEVM := vm.NewEVM(evm.Context, evm.TxContext, statedb, evm.ChainConfig(), cfg) + balance, err := core.GetAltTokenBalanceByEVM(probeEVM, tokenAddress, user) + if err != nil { + return nil, nil, err + } + return balance, probeTracer.slots, nil +} + +func validateERC20BalanceSlot(statedb *state.StateDB, evm *vm.EVM, tokenAddress, user common.Address, currentBalance *big.Int, slot common.Hash) (bool, error) { + snapshot := statedb.Snapshot() + defer statedb.RevertToSnapshot(snapshot) + + current := new(big.Int).SetBytes(statedb.GetState(tokenAddress, slot).Bytes()) + statedb.SetState(tokenAddress, slot, common.BigToHash(new(big.Int).Add(current, big.NewInt(1)))) + + balance, _, err := traceERC20BalanceSlots(statedb, evm, tokenAddress, user) + if err != nil { + return false, err + } + expected := new(big.Int).Add(new(big.Int).Set(currentBalance), big.NewInt(1)) + return balance.Cmp(expected) == 0, nil +} + +func deriveBalanceProbeAddress(tokenAddress, user common.Address) common.Address { + probe := common.BytesToAddress(crypto.Keccak256(tokenAddress.Bytes(), user.Bytes(), []byte("tracecall-balance-probe"))[12:]) + var zero common.Address + if probe == zero || probe == user { + probe[19] ^= 0x01 + if probe == user { + probe[18] ^= 0x01 + } + } + return probe } type chainContext struct { @@ -938,6 +1206,9 @@ func (api *API) traceTx(ctx context.Context, tx *types.Transaction, message core var ( tracer *Tracer err error + execState = statedb + hooks *tracing.Hooks + mask *syntheticPrecreditMask structLogger *logger.StructLogger timeout = defaultTraceTimeout usedGas uint64 @@ -960,9 +1231,34 @@ func (api *API) traceTx(ctx context.Context, tx *types.Transaction, message core return nil, err } } - tracingStateDB := state.NewHookedState(statedb, tracer.Hooks) + hooks = tracer.Hooks + if l1DataFee != nil && l1DataFee.Sign() > 0 && message.GasPrice().Cmp(big.NewInt(0)) == 0 { + if shouldMaskSyntheticPrecredits(config) { + execState = statedb.Copy() + mask = newSyntheticPrecreditMask() + } + if message.FeeTokenID() == 0 { + if mask != nil { + mask.addBalance(message.From(), execState.GetBalance(message.From())) + } + execState.AddBalance(message.From(), l1DataFee, tracing.BalanceChangeUnspecified) + } else { + probeEVM := vm.NewEVM(vmctx, txContext, execState, api.backend.ChainConfig(), vm.Config{NoBaseFee: true}) + erc20Amount, err := fees.EthToAlt(execState, message.FeeTokenID(), l1DataFee) + if err != nil { + return nil, err + } + if err := api.addERC20Balance(execState, probeEVM, message.FeeTokenID(), message.From(), erc20Amount, mask); err != nil { + return nil, err + } + } + } + if mask != nil { + hooks = wrapSyntheticPrecreditHooks(hooks, mask) + } + tracingStateDB := state.NewHookedState(execState, hooks) // Run the transaction with tracing enabled. - vmenv := vm.NewEVM(vmctx, txContext, tracingStateDB, api.backend.ChainConfig(), vm.Config{Tracer: tracer.Hooks, NoBaseFee: true}) + vmenv := vm.NewEVM(vmctx, txContext, tracingStateDB, api.backend.ChainConfig(), vm.Config{Tracer: hooks, NoBaseFee: true}) // Define a meaningful timeout of a single transaction trace if config.Timeout != nil { @@ -981,23 +1277,9 @@ func (api *API) traceTx(ctx context.Context, tx *types.Transaction, message core }() defer cancel() - // If gasPrice is 0, make sure that the account has sufficient balance to cover `l1DataFee`. - if message.GasPrice().Cmp(big.NewInt(0)) == 0 { - if message.FeeTokenID() == 0 { - statedb.AddBalance(message.From(), l1DataFee, tracing.BalanceChangeUnspecified) - } else { - erc20Amount, err := fees.EthToAlt(statedb, message.FeeTokenID(), l1DataFee) - if err != nil { - return nil, err - } - if err := api.addERC20Balance(vmenv, message.FeeTokenID(), message.From(), erc20Amount); err != nil { - return nil, err - } - } - } // Call Prepare to clear out the statedb access list - statedb.SetTxContext(txctx.TxHash, txctx.TxIndex) - _, err = core.ApplyTransactionWithEVM(message, api.backend.ChainConfig(), new(core.GasPool).AddGas(message.Gas()), statedb, vmctx.BlockNumber, txctx.BlockHash, tx, &usedGas, vmenv) + execState.SetTxContext(txctx.TxHash, txctx.TxIndex) + _, err = core.ApplyTransactionWithEVM(message, api.backend.ChainConfig(), new(core.GasPool).AddGas(message.Gas()), execState, vmctx.BlockNumber, txctx.BlockHash, tx, &usedGas, vmenv) if err != nil { return nil, fmt.Errorf("tracing failed: %w", err) } diff --git a/eth/tracers/api_test.go b/eth/tracers/api_test.go index afd8049cb..7bac0f4f9 100644 --- a/eth/tracers/api_test.go +++ b/eth/tracers/api_test.go @@ -36,6 +36,7 @@ import ( "github.com/morph-l2/go-ethereum/core" "github.com/morph-l2/go-ethereum/core/rawdb" "github.com/morph-l2/go-ethereum/core/state" + "github.com/morph-l2/go-ethereum/core/tracing" "github.com/morph-l2/go-ethereum/core/types" "github.com/morph-l2/go-ethereum/core/vm" "github.com/morph-l2/go-ethereum/crypto" @@ -45,6 +46,7 @@ import ( "github.com/morph-l2/go-ethereum/params" "github.com/morph-l2/go-ethereum/rollup/fees" "github.com/morph-l2/go-ethereum/rpc" + "github.com/stretchr/testify/require" ) var ( @@ -252,7 +254,7 @@ func TestTraceCall(t *testing.T) { }, config: nil, expectErr: nil, - expect: `{"gas":21000,"failed":false,"returnValue":"0x","structLogs":[]}`, + expect: `{"gas":21000,"failed":false,"returnValue":"0x","structLogs":[],"l1DataFee":"0x0"}`, }, // Standard JSON trace upon the head, plain transfer. { @@ -264,7 +266,7 @@ func TestTraceCall(t *testing.T) { }, config: nil, expectErr: nil, - expect: `{"gas":21000,"failed":false,"returnValue":"0x","structLogs":[]}`, + expect: `{"gas":21000,"failed":false,"returnValue":"0x","structLogs":[],"l1DataFee":"0x0"}`, }, // Upon the last state, default to the post block's state { @@ -275,7 +277,7 @@ func TestTraceCall(t *testing.T) { Value: (*hexutil.Big)(new(big.Int).Add(big.NewInt(params.Ether), big.NewInt(100))), }, config: nil, - expect: `{"gas":21000,"failed":false,"returnValue":"0x","structLogs":[]}`, + expect: `{"gas":21000,"failed":false,"returnValue":"0x","structLogs":[],"l1DataFee":"0x0"}`, }, // Before the first transaction, should be failed { @@ -310,7 +312,7 @@ func TestTraceCall(t *testing.T) { }, config: nil, expectErr: nil, - expect: `{"gas":21000,"failed":false,"returnValue":"0x","structLogs":[]}`, + expect: `{"gas":21000,"failed":false,"returnValue":"0x","structLogs":[],"l1DataFee":"0x0"}`, }, // Standard JSON trace upon the pending block { @@ -322,7 +324,7 @@ func TestTraceCall(t *testing.T) { }, config: nil, expectErr: nil, - expect: `{"gas":21000,"failed":false,"returnValue":"0x","structLogs":[]}`, + expect: `{"gas":21000,"failed":false,"returnValue":"0x","structLogs":[],"l1DataFee":"0x0"}`, }, } for i, testspec := range testSuite { @@ -332,7 +334,7 @@ func TestTraceCall(t *testing.T) { t.Errorf("Expect error %v, get nothing", testspec.expectErr) continue } - if !reflect.DeepEqual(err.Error(), testspec.expectErr.Error()) { + if err.Error() != testspec.expectErr.Error() { t.Errorf("Error mismatch, want %v, get %v", testspec.expectErr, err) } } else { @@ -453,7 +455,7 @@ func TestTraceBlock(t *testing.T) { t.Errorf("test %d, want error %v", i, tc.expectErr) continue } - if !reflect.DeepEqual(err, tc.expectErr) { + if err.Error() != tc.expectErr.Error() { t.Errorf("test %d: error mismatch, want %v, get %v", i, tc.expectErr, err) } continue @@ -592,6 +594,106 @@ func TestTracingWithOverrides(t *testing.T) { } } +func TestAddERC20BalanceDynamicSlotDiscovery(t *testing.T) { + t.Parallel() + + tokenID := uint16(1) + user := common.HexToAddress("0x100") + token := common.HexToAddress("0x200") + statedb := newTestStateDB(t, nil) + + // Runtime bytecode: + // balanceOf(address) => return sload(keccak256(abi.encode(address, 0))) + statedb.SetCode(token, common.FromHex("0x600435600052600060205260406000205460005260206000f3")) + setTestTokenInfo(statedb, tokenID, token, nil) + + userSlot := fees.CalculateAltTokenBalanceSlot(user, common.Hash{}) + statedb.SetState(token, userSlot, common.BigToHash(big.NewInt(7))) + + evm := vm.NewEVM(vm.BlockContext{ + CanTransfer: core.CanTransfer, + Transfer: core.Transfer, + GetHash: func(uint64) common.Hash { return common.Hash{} }, + Coinbase: common.Address{}, + GasLimit: 30_000_000, + BlockNumber: big.NewInt(1), + Time: big.NewInt(0), + Difficulty: big.NewInt(0), + BaseFee: big.NewInt(0), + }, vm.TxContext{ + Origin: user, + GasPrice: big.NewInt(0), + }, statedb, params.TestChainConfig, vm.Config{NoBaseFee: true}) + + if err := addERC20Balance(statedb, evm, tokenID, user, big.NewInt(5), nil); err != nil { + t.Fatalf("failed to add ERC20 balance: %v", err) + } + balance, err := core.GetAltTokenBalanceByEVM(evm, token, user) + if err != nil { + t.Fatalf("failed to read ERC20 balance: %v", err) + } + if balance.Cmp(big.NewInt(12)) != 0 { + t.Fatalf("unexpected ERC20 balance, have %v want %v", balance, 12) + } +} + +func TestSyntheticPrecreditMaskPreservesPrestateTracerView(t *testing.T) { + t.Parallel() + + user := common.HexToAddress("0x100") + recipient := common.HexToAddress("0x101") + token := common.HexToAddress("0x200") + slot := fees.CalculateAltTokenBalanceSlot(user, common.Hash{}) + original := common.BigToHash(big.NewInt(7)) + synthetic := common.BigToHash(big.NewInt(12)) + + statedb := newTestStateDB(t, nil) + statedb.SetState(token, slot, synthetic) + + var ( + seenStart common.Hash + seenPrev common.Hash + seenFinal common.Hash + seenState tracing.StateDB + ) + baseHooks := &tracing.Hooks{ + OnTxStart: func(env *tracing.VMContext, _ *types.Transaction, _ common.Address) { + seenState = env.StateDB + seenStart = env.StateDB.GetState(token, slot) + }, + OnStorageChange: func(addr common.Address, key common.Hash, prev, _ common.Hash) { + if addr == token && key == slot { + seenPrev = prev + } + }, + OnTxEnd: func(_ *types.Receipt, _ error) { + seenFinal = seenState.GetState(token, slot) + }, + } + mask := newSyntheticPrecreditMask() + mask.addStorage(token, slot, original) + hooks := wrapSyntheticPrecreditHooks(baseHooks, mask) + + tx := types.NewTx(&types.LegacyTx{ + Nonce: 0, + To: &recipient, + Value: big.NewInt(0), + Gas: params.TxGas, + GasPrice: big.NewInt(0), + }) + hooks.OnTxStart(&tracing.VMContext{ + BlockNumber: big.NewInt(1), + StateDB: statedb, + }, tx, user) + hooks.OnStorageChange(token, slot, synthetic, original) + statedb.SetState(token, slot, original) + hooks.OnTxEnd(&types.Receipt{GasUsed: params.TxGas}, nil) + + require.Equal(t, original, seenStart) + require.Equal(t, original, seenPrev) + require.Equal(t, original, seenFinal) +} + type Account struct { key *ecdsa.PrivateKey addr common.Address @@ -633,3 +735,29 @@ func newStates(keys []common.Hash, vals []common.Hash) *map[common.Hash]common.H } return &m } + +func newTestStateDB(t *testing.T, alloc core.GenesisAlloc) *state.StateDB { + t.Helper() + + db := rawdb.NewMemoryDatabase() + block := (&core.Genesis{Config: params.TestChainConfig, Alloc: alloc}).MustCommit(db) + statedb, err := state.New(block.Root(), state.NewDatabase(db), nil) + if err != nil { + t.Fatalf("failed to create state db: %v", err) + } + return statedb +} + +func setTestTokenInfo(statedb *state.StateDB, tokenID uint16, token common.Address, balanceSlot *common.Hash) { + baseSlot := fees.GetTokenInfoStructBaseSlot(tokenID) + statedb.SetState(fees.TokenRegistryAddress, fees.CalculateStructFieldSlot(baseSlot, 0), common.BytesToHash(common.LeftPadBytes(token.Bytes(), 32))) + if balanceSlot != nil { + slotPlusOne := new(big.Int).Add(new(big.Int).SetBytes(balanceSlot.Bytes()), big.NewInt(1)) + statedb.SetState(fees.TokenRegistryAddress, fees.CalculateStructFieldSlot(baseSlot, 1), common.BigToHash(slotPlusOne)) + } + var status common.Hash + status[30] = 18 + status[31] = 1 + statedb.SetState(fees.TokenRegistryAddress, fees.CalculateStructFieldSlot(baseSlot, 2), status) + statedb.SetState(fees.TokenRegistryAddress, fees.CalculateStructFieldSlot(baseSlot, 3), common.BigToHash(big.NewInt(1))) +} diff --git a/eth/tracers/js/goja.go b/eth/tracers/js/goja.go index ed33035d6..9c24072b7 100644 --- a/eth/tracers/js/goja.go +++ b/eth/tracers/js/goja.go @@ -120,8 +120,9 @@ type jsTracer struct { activePrecompiles []common.Address // List of active precompiles at current block traceStep bool // True if tracer object exposes a `step()` method traceFrame bool // True if tracer object exposes the `enter()` and `exit()` methods - err error // Any error that should stop tracing - obj *goja.Object // Trace object + systemCallDepth int + err error // Any error that should stop tracing + obj *goja.Object // Trace object // Methods exposed by tracer result goja.Callable @@ -236,12 +237,14 @@ func newJsTracer(code string, ctx *tracers.Context, cfg json.RawMessage, chainCo return &tracers.Tracer{ Hooks: &tracing.Hooks{ - OnTxStart: t.OnTxStart, - OnTxEnd: t.OnTxEnd, - OnEnter: t.OnEnter, - OnExit: t.OnExit, - OnOpcode: t.OnOpcode, - OnFault: t.OnFault, + OnTxStart: t.OnTxStart, + OnTxEnd: t.OnTxEnd, + OnEnter: t.OnEnter, + OnExit: t.OnExit, + OnOpcode: t.OnOpcode, + OnFault: t.OnFault, + OnSystemCallStartV2: t.OnSystemCallStart, + OnSystemCallEnd: t.OnSystemCallEnd, }, GetResult: t.GetResult, Stop: t.Stop, @@ -331,6 +334,9 @@ func (t *jsTracer) onStart(from common.Address, to common.Address, create bool, // OnOpcode implements the Tracer interface to trace a single step of VM execution. func (t *jsTracer) OnOpcode(pc uint64, op byte, gas, cost uint64, scope tracing.OpContext, rData []byte, depth int, err error) { + if t.systemCallDepth > 0 { + return + } if !t.traceStep { return } @@ -356,6 +362,9 @@ func (t *jsTracer) OnOpcode(pc uint64, op byte, gas, cost uint64, scope tracing. // OnFault implements the Tracer interface to trace an execution fault func (t *jsTracer) OnFault(pc uint64, op byte, gas, cost uint64, scope tracing.OpContext, depth int, err error) { + if t.systemCallDepth > 0 { + return + } if t.err != nil { return } @@ -384,6 +393,9 @@ func (t *jsTracer) onEnd(output []byte, gasUsed uint64, err error, reverted bool // OnEnter is called when EVM enters a new scope (via call, create or selfdestruct). func (t *jsTracer) OnEnter(depth int, typ byte, from common.Address, to common.Address, input []byte, gas uint64, value *big.Int) { + if t.systemCallDepth > 0 { + return + } if t.err != nil { return } @@ -413,6 +425,9 @@ func (t *jsTracer) OnEnter(depth int, typ byte, from common.Address, to common.A // OnExit is called when EVM exits a scope, even if the scope didn't // execute any code. func (t *jsTracer) OnExit(depth int, output []byte, gasUsed uint64, err error, reverted bool) { + if t.systemCallDepth > 0 { + return + } if t.err != nil { return } @@ -455,6 +470,16 @@ func (t *jsTracer) Stop(err error) { t.vm.Interrupt(err) } +func (t *jsTracer) OnSystemCallStart(*tracing.VMContext) { + t.systemCallDepth++ +} + +func (t *jsTracer) OnSystemCallEnd() { + if t.systemCallDepth > 0 { + t.systemCallDepth-- + } +} + // onError is called anytime the running JS code is interrupted // and returns an error. It in turn pings the EVM to cancel its // execution. diff --git a/eth/tracers/js/goja_systemcall_test.go b/eth/tracers/js/goja_systemcall_test.go new file mode 100644 index 000000000..1907a29ec --- /dev/null +++ b/eth/tracers/js/goja_systemcall_test.go @@ -0,0 +1,101 @@ +package js + +import ( + "encoding/json" + "math/big" + "testing" + + "github.com/holiman/uint256" + "github.com/morph-l2/go-ethereum/common" + "github.com/morph-l2/go-ethereum/core/tracing" + "github.com/morph-l2/go-ethereum/core/types" + "github.com/morph-l2/go-ethereum/core/vm" + "github.com/morph-l2/go-ethereum/eth/tracers" + "github.com/morph-l2/go-ethereum/params" +) + +type jsTestStateDB struct{} + +func (jsTestStateDB) GetBalance(common.Address) *big.Int { return new(big.Int) } +func (jsTestStateDB) GetNonce(common.Address) uint64 { return 0 } +func (jsTestStateDB) GetCode(common.Address) []byte { return nil } +func (jsTestStateDB) GetKeccakCodeHash(common.Address) common.Hash { return common.Hash{} } +func (jsTestStateDB) GetPoseidonCodeHash(common.Address) common.Hash { return common.Hash{} } +func (jsTestStateDB) GetState(common.Address, common.Hash) common.Hash { return common.Hash{} } +func (jsTestStateDB) GetTransientState(common.Address, common.Hash) common.Hash { + return common.Hash{} +} +func (jsTestStateDB) Exist(common.Address) bool { return false } +func (jsTestStateDB) GetRefund() uint64 { return 0 } +func (jsTestStateDB) GetCodeSize(common.Address) uint64 { return 0 } + +type jsTestScope struct{} + +func (jsTestScope) MemoryData() []byte { return nil } +func (jsTestScope) StackData() []uint256.Int { return nil } +func (jsTestScope) Caller() common.Address { return common.Address{} } +func (jsTestScope) Address() common.Address { return common.Address{} } +func (jsTestScope) CallValue() *big.Int { return new(big.Int) } +func (jsTestScope) CallInput() []byte { return nil } +func (jsTestScope) ContractCode() []byte { return nil } + +func TestJSTracerSkipsSystemCalls(t *testing.T) { + t.Parallel() + + const code = `{ + steps: 0, + enters: 0, + exits: 0, + step: function(log, db) { this.steps++; }, + enter: function(frame) { this.enters++; }, + exit: function(frame) { this.exits++; }, + fault: function(log, db) {}, + result: function(ctx, db) { return {steps: this.steps, enters: this.enters, exits: this.exits}; } + }` + + tracer, err := newJsTracer(code, &tracers.Context{}, nil, params.TestChainConfig) + if err != nil { + t.Fatalf("failed to create js tracer: %v", err) + } + tx := types.NewTx(&types.LegacyTx{ + To: &common.Address{}, + Gas: 21000, + GasPrice: big.NewInt(0), + Value: big.NewInt(0), + }) + env := &tracing.VMContext{ + BlockNumber: big.NewInt(0), + Time: 0, + BaseFee: big.NewInt(0), + Coinbase: common.Address{}, + StateDB: jsTestStateDB{}, + } + scope := jsTestScope{} + + tracer.OnTxStart(env, tx, common.Address{}) + tracer.OnSystemCallStartV2(env) + tracer.OnEnter(1, byte(vm.CALL), common.Address{}, common.Address{}, []byte{1, 2, 3, 4}, 21000, big.NewInt(0)) + tracer.OnOpcode(0, 0x00, 1, 1, scope, nil, 1, nil) + tracer.OnExit(1, nil, 0, nil, false) + tracer.OnSystemCallEnd() + + tracer.OnEnter(1, byte(vm.CALL), common.Address{}, common.Address{}, []byte{1, 2, 3, 4}, 21000, big.NewInt(0)) + tracer.OnOpcode(1, 0x00, 1, 1, scope, nil, 1, nil) + tracer.OnExit(1, nil, 0, nil, false) + + res, err := tracer.GetResult() + if err != nil { + t.Fatalf("failed to get trace result: %v", err) + } + var got struct { + Steps int `json:"steps"` + Enters int `json:"enters"` + Exits int `json:"exits"` + } + if err := json.Unmarshal(res, &got); err != nil { + t.Fatalf("failed to decode trace result: %v", err) + } + if got.Steps != 1 || got.Enters != 1 || got.Exits != 1 { + t.Fatalf("unexpected result: %+v", got) + } +} diff --git a/eth/tracers/logger/logger.go b/eth/tracers/logger/logger.go index 8f31cdaca..d707fd21a 100644 --- a/eth/tracers/logger/logger.go +++ b/eth/tracers/logger/logger.go @@ -236,9 +236,9 @@ type StructLogger struct { structLogs []*StructLog resultSize int - interrupt atomic.Bool // Atomic flag to signal execution interruption - reason error // Textual reason for the interruption - skip bool // skip processing hooks. + interrupt atomic.Bool // Atomic flag to signal execution interruption + reason error // Textual reason for the interruption + systemCallDepth int // skip visible processing hooks while inside system calls. } // NewStreamingStructLogger returns a new streaming logger. @@ -307,7 +307,7 @@ func (l *StructLogger) OnOpcode(pc uint64, opcode byte, gas, cost uint64, scope return } // Processing a system call. - if l.skip { + if l.systemCallDepth > 0 { return } // check if already accumulated the size of the response. @@ -406,7 +406,7 @@ func (l *StructLogger) OnExit(depth int, output []byte, gasUsed uint64, err erro if depth != 0 { return } - if l.skip { + if l.systemCallDepth > 0 { return } l.output = output @@ -499,11 +499,13 @@ func (l *StructLogger) OnTxStart(env *tracing.VMContext, tx *types.Transaction, } func (l *StructLogger) OnSystemCallStart(env *tracing.VMContext) { - l.skip = true + l.systemCallDepth++ } func (l *StructLogger) OnSystemCallEnd() { - l.skip = false + if l.systemCallDepth > 0 { + l.systemCallDepth-- + } } func (l *StructLogger) OnTxEnd(receipt *types.Receipt, err error) { @@ -521,6 +523,8 @@ func (l *StructLogger) OnTxEnd(receipt *types.Receipt, err error) { } func (l *StructLogger) OnEnter(depth int, typ byte, from common.Address, to common.Address, input []byte, gas uint64, value *big.Int) { + // Keep tracking affected accounts during system calls because UpdatedAccounts() + // feeds proof collection and fee-token helper calls can touch required accounts. l.statesAffected[to] = struct{}{} target, ok := types.ParseDelegation(l.env.StateDB.GetCode(to)) // if the target is a delegation, we need to trace the target @@ -558,10 +562,10 @@ func WriteLogs(writer io.Writer, logs []*types.Log) { } type mdLogger struct { - out io.Writer - cfg *Config - env *tracing.VMContext - skip bool + out io.Writer + cfg *Config + env *tracing.VMContext + systemCallDepth int } // NewMarkdownLogger creates a logger which outputs information in a format adapted @@ -591,15 +595,17 @@ func (t *mdLogger) OnTxStart(env *tracing.VMContext, tx *types.Transaction, from } func (t *mdLogger) OnSystemCallStart(env *tracing.VMContext) { - t.skip = true + t.systemCallDepth++ } func (t *mdLogger) OnSystemCallEnd() { - t.skip = false + if t.systemCallDepth > 0 { + t.systemCallDepth-- + } } func (t *mdLogger) OnEnter(depth int, typ byte, from common.Address, to common.Address, input []byte, gas uint64, value *big.Int) { - if t.skip { + if t.systemCallDepth > 0 { return } if depth != 0 { @@ -629,7 +635,7 @@ func (t *mdLogger) OnEnter(depth int, typ byte, from common.Address, to common.A } func (t *mdLogger) OnExit(depth int, output []byte, gasUsed uint64, err error, reverted bool) { - if t.skip { + if t.systemCallDepth > 0 { return } if depth == 0 { @@ -643,7 +649,7 @@ func (t *mdLogger) OnExit(depth int, output []byte, gasUsed uint64, err error, r // OnOpcode also tracks SLOAD/SSTORE ops to track storage change. func (t *mdLogger) OnOpcode(pc uint64, op byte, gas, cost uint64, scope tracing.OpContext, rData []byte, depth int, err error) { - if t.skip { + if t.systemCallDepth > 0 { return } stack := scope.StackData() @@ -666,7 +672,7 @@ func (t *mdLogger) OnOpcode(pc uint64, op byte, gas, cost uint64, scope tracing. } func (t *mdLogger) OnFault(pc uint64, op byte, gas, cost uint64, scope tracing.OpContext, depth int, err error) { - if t.skip { + if t.systemCallDepth > 0 { return } fmt.Fprintf(t.out, "\nError: at pc=%d, op=%v: %v\n", pc, op, err) diff --git a/eth/tracers/logger/logger_json.go b/eth/tracers/logger/logger_json.go index e32076db6..eed48032c 100644 --- a/eth/tracers/logger/logger_json.go +++ b/eth/tracers/logger/logger_json.go @@ -55,10 +55,11 @@ func (c *callFrame) Type() string { } type jsonLogger struct { - encoder *json.Encoder - cfg *Config - env *tracing.VMContext - hooks *tracing.Hooks + encoder *json.Encoder + cfg *Config + env *tracing.VMContext + hooks *tracing.Hooks + systemCallDepth int } // NewJSONLogger creates a new EVM tracer that prints execution steps as JSON objects @@ -69,11 +70,13 @@ func NewJSONLogger(cfg *Config, writer io.Writer) *tracing.Hooks { l.cfg = &Config{} } l.hooks = &tracing.Hooks{ - OnTxStart: l.OnTxStart, - OnSystemCallStart: l.onSystemCallStart, - OnExit: l.OnExit, - OnOpcode: l.OnOpcode, - OnFault: l.OnFault, + OnTxStart: l.OnTxStart, + OnSystemCallStart: l.onSystemCallStartLegacy, + OnSystemCallStartV2: l.OnSystemCallStart, + OnSystemCallEnd: l.OnSystemCallEnd, + OnExit: l.OnExit, + OnOpcode: l.OnOpcode, + OnFault: l.OnFault, } return l.hooks } @@ -86,22 +89,30 @@ func NewJSONLoggerWithCallFrames(cfg *Config, writer io.Writer) *tracing.Hooks { l.cfg = &Config{} } l.hooks = &tracing.Hooks{ - OnTxStart: l.OnTxStart, - OnSystemCallStart: l.onSystemCallStart, - OnEnter: l.OnEnter, - OnExit: l.OnExit, - OnOpcode: l.OnOpcode, - OnFault: l.OnFault, + OnTxStart: l.OnTxStart, + OnSystemCallStart: l.onSystemCallStartLegacy, + OnSystemCallStartV2: l.OnSystemCallStart, + OnSystemCallEnd: l.OnSystemCallEnd, + OnEnter: l.OnEnter, + OnExit: l.OnExit, + OnOpcode: l.OnOpcode, + OnFault: l.OnFault, } return l.hooks } func (l *jsonLogger) OnFault(pc uint64, op byte, gas uint64, cost uint64, scope tracing.OpContext, depth int, err error) { + if l.systemCallDepth > 0 { + return + } // TODO: Add rData to this interface as well l.OnOpcode(pc, op, gas, cost, scope, nil, depth, err) } func (l *jsonLogger) OnOpcode(pc uint64, op byte, gas, cost uint64, scope tracing.OpContext, rData []byte, depth int, err error) { + if l.systemCallDepth > 0 { + return + } memory := scope.MemoryData() stack := scope.StackData() @@ -127,18 +138,25 @@ func (l *jsonLogger) OnOpcode(pc uint64, op byte, gas, cost uint64, scope tracin l.encoder.Encode(log) } -func (l *jsonLogger) onSystemCallStart() { - // Process no events while in system call. - hooks := *l.hooks - *l.hooks = tracing.Hooks{ - OnSystemCallEnd: func() { - *l.hooks = hooks - }, +func (l *jsonLogger) OnSystemCallStart(*tracing.VMContext) { + l.systemCallDepth++ +} + +func (l *jsonLogger) onSystemCallStartLegacy() { + l.systemCallDepth++ +} + +func (l *jsonLogger) OnSystemCallEnd() { + if l.systemCallDepth > 0 { + l.systemCallDepth-- } } // OnEnter is not enabled by default. func (l *jsonLogger) OnEnter(depth int, typ byte, from common.Address, to common.Address, input []byte, gas uint64, value *big.Int) { + if l.systemCallDepth > 0 { + return + } frame := callFrame{ op: vm.OpCode(typ), From: from, @@ -153,6 +171,9 @@ func (l *jsonLogger) OnEnter(depth int, typ byte, from common.Address, to common } func (l *jsonLogger) OnExit(depth int, output []byte, gasUsed uint64, err error, reverted bool) { + if l.systemCallDepth > 0 { + return + } type endLog struct { Output string `json:"output"` GasUsed math.HexOrDecimal64 `json:"gasUsed"` diff --git a/eth/tracers/logger/logger_json_test.go b/eth/tracers/logger/logger_json_test.go new file mode 100644 index 000000000..777644d9f --- /dev/null +++ b/eth/tracers/logger/logger_json_test.go @@ -0,0 +1,73 @@ +package logger + +import ( + "bytes" + "encoding/json" + "math/big" + "strings" + "testing" + + "github.com/holiman/uint256" + "github.com/morph-l2/go-ethereum/common" + "github.com/morph-l2/go-ethereum/core/tracing" + "github.com/morph-l2/go-ethereum/core/types" + "github.com/morph-l2/go-ethereum/core/vm" +) + +type loggerTestStateDB struct{} + +func (loggerTestStateDB) GetBalance(common.Address) *big.Int { return new(big.Int) } +func (loggerTestStateDB) GetNonce(common.Address) uint64 { return 0 } +func (loggerTestStateDB) GetCode(common.Address) []byte { return nil } +func (loggerTestStateDB) GetKeccakCodeHash(common.Address) common.Hash { return common.Hash{} } +func (loggerTestStateDB) GetPoseidonCodeHash(common.Address) common.Hash { return common.Hash{} } +func (loggerTestStateDB) GetState(common.Address, common.Hash) common.Hash { return common.Hash{} } +func (loggerTestStateDB) GetTransientState(common.Address, common.Hash) common.Hash { + return common.Hash{} +} +func (loggerTestStateDB) Exist(common.Address) bool { return false } +func (loggerTestStateDB) GetRefund() uint64 { return 0 } +func (loggerTestStateDB) GetCodeSize(common.Address) uint64 { return 0 } + +type loggerTestScope struct{} + +func (loggerTestScope) MemoryData() []byte { return nil } +func (loggerTestScope) StackData() []uint256.Int { return nil } +func (loggerTestScope) Caller() common.Address { return common.Address{} } +func (loggerTestScope) Address() common.Address { return common.Address{} } +func (loggerTestScope) CallValue() *big.Int { return new(big.Int) } +func (loggerTestScope) CallInput() []byte { return nil } +func (loggerTestScope) ContractCode() []byte { return nil } + +func TestJSONLoggerSkipsSystemCallV2(t *testing.T) { + t.Parallel() + + var out bytes.Buffer + hooks := NewJSONLogger(nil, &out) + tx := types.NewTx(&types.LegacyTx{ + To: &common.Address{}, + Gas: 21000, + GasPrice: big.NewInt(0), + Value: big.NewInt(0), + }) + env := &tracing.VMContext{StateDB: loggerTestStateDB{}} + scope := loggerTestScope{} + + hooks.OnTxStart(env, tx, common.Address{}) + hooks.OnSystemCallStartV2(env) + hooks.OnOpcode(0, byte(vm.SLOAD), 1, 1, scope, nil, 0, nil) + hooks.OnSystemCallEnd() + hooks.OnOpcode(1, byte(vm.SLOAD), 1, 1, scope, nil, 0, nil) + + lines := strings.Split(strings.TrimSpace(out.String()), "\n") + if len(lines) != 1 { + t.Fatalf("unexpected number of log lines: %q", out.String()) + } + var entry map[string]any + if err := json.Unmarshal([]byte(lines[0]), &entry); err != nil { + t.Fatalf("failed to decode log line: %v", err) + } + if entry["pc"] != float64(1) { + t.Fatalf("unexpected logged pc: %v", entry["pc"]) + } +} diff --git a/eth/tracers/native/4byte.go b/eth/tracers/native/4byte.go index 86112e952..e78391040 100644 --- a/eth/tracers/native/4byte.go +++ b/eth/tracers/native/4byte.go @@ -54,6 +54,7 @@ type fourByteTracer struct { reason error // Textual reason for the interruption chainConfig *params.ChainConfig activePrecompiles []common.Address // Updated on CaptureStart based on given rules + systemCallDepth int } // newFourByteTracer returns a native go tracer which collects @@ -65,8 +66,10 @@ func newFourByteTracer(ctx *tracers.Context, cfg json.RawMessage, chainConfig *p } return &tracers.Tracer{ Hooks: &tracing.Hooks{ - OnTxStart: t.OnTxStart, - OnEnter: t.OnEnter, + OnTxStart: t.OnTxStart, + OnEnter: t.OnEnter, + OnSystemCallStartV2: t.OnSystemCallStart, + OnSystemCallEnd: t.OnSystemCallEnd, }, GetResult: t.GetResult, Stop: t.Stop, @@ -98,7 +101,7 @@ func (t *fourByteTracer) OnTxStart(env *tracing.VMContext, tx *types.Transaction // OnEnter is called when EVM enters a new scope (via call, create or selfdestruct). func (t *fourByteTracer) OnEnter(depth int, opcode byte, from common.Address, to common.Address, input []byte, gas uint64, value *big.Int) { // Skip if tracing was interrupted - if t.interrupt.Load() { + if t.interrupt.Load() || t.systemCallDepth > 0 { return } if len(input) < 4 { @@ -117,6 +120,16 @@ func (t *fourByteTracer) OnEnter(depth int, opcode byte, from common.Address, to t.store(input[0:4], len(input)-4) } +func (t *fourByteTracer) OnSystemCallStart(*tracing.VMContext) { + t.systemCallDepth++ +} + +func (t *fourByteTracer) OnSystemCallEnd() { + if t.systemCallDepth > 0 { + t.systemCallDepth-- + } +} + // GetResult returns the json-encoded nested list of call traces, and any // error arising from the encoding or forceful termination (via `Stop`). func (t *fourByteTracer) GetResult() (json.RawMessage, error) { diff --git a/eth/tracers/native/call.go b/eth/tracers/native/call.go index d7fb3f9f4..17128402a 100644 --- a/eth/tracers/native/call.go +++ b/eth/tracers/native/call.go @@ -113,13 +113,13 @@ type callFrameMarshaling struct { } type callTracer struct { - callstack []callFrame - config callTracerConfig - gasLimit uint64 - depth int - interrupt atomic.Bool // Atomic flag to signal execution interruption - reason error // Textual reason for the interruption - skip bool + callstack []callFrame + config callTracerConfig + gasLimit uint64 + depth int + interrupt atomic.Bool // Atomic flag to signal execution interruption + reason error // Textual reason for the interruption + systemCallDepth int } type callTracerConfig struct { @@ -161,7 +161,7 @@ func newCallTracerObject(ctx *tracers.Context, cfg json.RawMessage) (*callTracer // OnEnter is called when EVM enters a new scope (via call, create or selfdestruct). func (t *callTracer) OnEnter(depth int, typ byte, from common.Address, to common.Address, input []byte, gas uint64, value *big.Int) { - if t.skip { + if t.systemCallDepth > 0 { return } t.depth = depth @@ -191,7 +191,7 @@ func (t *callTracer) OnEnter(depth int, typ byte, from common.Address, to common // OnExit is called when EVM exits a scope, even if the scope didn't // execute any code. func (t *callTracer) OnExit(depth int, output []byte, gasUsed uint64, err error, reverted bool) { - if t.skip { + if t.systemCallDepth > 0 { return } if depth == 0 { @@ -227,16 +227,10 @@ func (t *callTracer) captureEnd(output []byte, gasUsed uint64, err error, revert } func (t *callTracer) OnTxStart(env *tracing.VMContext, tx *types.Transaction, from common.Address) { - if t.skip { - return - } t.gasLimit = tx.Gas() } func (t *callTracer) OnTxEnd(receipt *types.Receipt, err error) { - if t.skip { - return - } // Error happened during tx validation. if err != nil { return @@ -251,15 +245,17 @@ func (t *callTracer) OnTxEnd(receipt *types.Receipt, err error) { } func (t *callTracer) OnSystemCall(env *tracing.VMContext) { - t.skip = true + t.systemCallDepth++ } func (t *callTracer) OnSystemCallEnd() { - t.skip = false + if t.systemCallDepth > 0 { + t.systemCallDepth-- + } } func (t *callTracer) OnLog(log *types.Log) { - if t.skip { + if t.systemCallDepth > 0 { return } // Only logs need to be captured via opcode processing diff --git a/eth/tracers/native/call_flat.go b/eth/tracers/native/call_flat.go index dbaa57580..5e7075ec0 100644 --- a/eth/tracers/native/call_flat.go +++ b/eth/tracers/native/call_flat.go @@ -120,6 +120,7 @@ type flatCallTracer struct { ctx *tracers.Context // Holds tracer context data interrupt atomic.Bool // Atomic flag to signal execution interruption activePrecompiles []common.Address // Updated on tx start based on given rules + systemCallDepth int } type flatCallTracerConfig struct { @@ -144,10 +145,12 @@ func newFlatCallTracer(ctx *tracers.Context, cfg json.RawMessage, chainConfig *p ft := &flatCallTracer{tracer: t, ctx: ctx, config: config, chainConfig: chainConfig} return &tracers.Tracer{ Hooks: &tracing.Hooks{ - OnTxStart: ft.OnTxStart, - OnTxEnd: ft.OnTxEnd, - OnEnter: ft.OnEnter, - OnExit: ft.OnExit, + OnTxStart: ft.OnTxStart, + OnTxEnd: ft.OnTxEnd, + OnEnter: ft.OnEnter, + OnExit: ft.OnExit, + OnSystemCallStartV2: ft.OnSystemCallStart, + OnSystemCallEnd: ft.OnSystemCallEnd, }, Stop: ft.Stop, GetResult: ft.GetResult, @@ -156,7 +159,7 @@ func newFlatCallTracer(ctx *tracers.Context, cfg json.RawMessage, chainConfig *p // OnEnter is called when EVM enters a new scope (via call, create or selfdestruct). func (t *flatCallTracer) OnEnter(depth int, typ byte, from common.Address, to common.Address, input []byte, gas uint64, value *big.Int) { - if t.interrupt.Load() { + if t.interrupt.Load() || t.systemCallDepth > 0 { return } t.tracer.OnEnter(depth, typ, from, to, input, gas, value) @@ -174,7 +177,7 @@ func (t *flatCallTracer) OnEnter(depth int, typ byte, from common.Address, to co // OnExit is called when EVM exits a scope, even if the scope didn't // execute any code. func (t *flatCallTracer) OnExit(depth int, output []byte, gasUsed uint64, err error, reverted bool) { - if t.interrupt.Load() { + if t.interrupt.Load() || t.systemCallDepth > 0 { return } t.tracer.OnExit(depth, output, gasUsed, err, reverted) @@ -218,6 +221,24 @@ func (t *flatCallTracer) OnTxEnd(receipt *types.Receipt, err error) { t.tracer.OnTxEnd(receipt, err) } +func (t *flatCallTracer) OnSystemCallStart(env *tracing.VMContext) { + if t.interrupt.Load() { + return + } + t.systemCallDepth++ + t.tracer.OnSystemCall(env) +} + +func (t *flatCallTracer) OnSystemCallEnd() { + if t.interrupt.Load() { + return + } + if t.systemCallDepth > 0 { + t.systemCallDepth-- + } + t.tracer.OnSystemCallEnd() +} + // GetResult returns an empty json object. func (t *flatCallTracer) GetResult() (json.RawMessage, error) { if len(t.tracer.callstack) < 1 { diff --git a/eth/tracers/native/call_flat_test.go b/eth/tracers/native/call_flat_test.go index eb9ec00f5..a56aa26de 100644 --- a/eth/tracers/native/call_flat_test.go +++ b/eth/tracers/native/call_flat_test.go @@ -17,6 +17,7 @@ package native_test import ( + "encoding/json" "errors" "math/big" "testing" @@ -60,3 +61,36 @@ func TestCallFlatStop(t *testing.T) { _, tracerError := tracer.GetResult() require.Equal(t, stopError, tracerError) } + +func TestCallFlatIgnoresHiddenSystemCallFrames(t *testing.T) { + tracer, err := tracers.DefaultDirectory.New("flatCallTracer", &tracers.Context{}, nil, params.MainnetChainConfig) + require.NoError(t, err) + + tx := types.NewTx(&types.LegacyTx{ + Nonce: 0, + To: &common.Address{}, + Value: big.NewInt(0), + Gas: 21000, + GasPrice: big.NewInt(0), + Data: nil, + }) + + tracer.OnTxStart(&tracing.VMContext{BlockNumber: big.NewInt(0)}, tx, common.Address{}) + tracer.OnEnter(0, byte(vm.CALL), common.Address{}, common.Address{}, nil, 21000, big.NewInt(0)) + + tracer.OnSystemCallStartV2(&tracing.VMContext{}) + tracer.OnEnter(1, byte(vm.STATICCALL), common.Address{}, common.Address{}, nil, 1000, nil) + tracer.OnExit(1, nil, 0, nil, false) + tracer.OnSystemCallEnd() + + tracer.OnExit(0, nil, 21000, nil, false) + tracer.OnTxEnd(&types.Receipt{GasUsed: 21000}, nil) + + res, err := tracer.GetResult() + require.NoError(t, err) + + var frames []map[string]any + require.NoError(t, json.Unmarshal(res, &frames)) + require.Len(t, frames, 1) + require.Equal(t, float64(0), frames[0]["subtraces"]) +} diff --git a/eth/tracers/native/mux.go b/eth/tracers/native/mux.go index 95ee0f51c..9c267eae8 100644 --- a/eth/tracers/native/mux.go +++ b/eth/tracers/native/mux.go @@ -58,18 +58,22 @@ func newMuxTracer(ctx *tracers.Context, cfg json.RawMessage, chainConfig *params t := &muxTracer{names: names, tracers: objects} return &tracers.Tracer{ Hooks: &tracing.Hooks{ - OnTxStart: t.OnTxStart, - OnTxEnd: t.OnTxEnd, - OnEnter: t.OnEnter, - OnExit: t.OnExit, - OnOpcode: t.OnOpcode, - OnFault: t.OnFault, - OnGasChange: t.OnGasChange, - OnBalanceChange: t.OnBalanceChange, - OnNonceChange: t.OnNonceChange, - OnCodeChange: t.OnCodeChange, - OnStorageChange: t.OnStorageChange, - OnLog: t.OnLog, + OnTxStart: t.OnTxStart, + OnTxEnd: t.OnTxEnd, + OnEnter: t.OnEnter, + OnExit: t.OnExit, + OnOpcode: t.OnOpcode, + OnFault: t.OnFault, + OnGasChange: t.OnGasChange, + OnBalanceChange: t.OnBalanceChange, + OnNonceChange: t.OnNonceChange, + OnNonceChangeV2: t.OnNonceChangeV2, + OnCodeChange: t.OnCodeChange, + OnStorageChange: t.OnStorageChange, + OnLog: t.OnLog, + OnSystemCallStart: t.OnSystemCallStart, + OnSystemCallStartV2: t.OnSystemCallStartV2, + OnSystemCallEnd: t.OnSystemCallEnd, }, GetResult: t.GetResult, Stop: t.Stop, @@ -148,6 +152,16 @@ func (t *muxTracer) OnNonceChange(a common.Address, prev, new uint64) { } } +func (t *muxTracer) OnNonceChangeV2(a common.Address, prev, new uint64, reason tracing.NonceChangeReason) { + for _, t := range t.tracers { + if t.OnNonceChangeV2 != nil { + t.OnNonceChangeV2(a, prev, new, reason) + } else if t.OnNonceChange != nil { + t.OnNonceChange(a, prev, new) + } + } +} + func (t *muxTracer) OnCodeChange(a common.Address, prevCodeHash common.Hash, prev []byte, codeHash common.Hash, code []byte) { for _, t := range t.tracers { if t.OnCodeChange != nil { @@ -172,6 +186,32 @@ func (t *muxTracer) OnLog(log *types.Log) { } } +func (t *muxTracer) OnSystemCallStart() { + for _, t := range t.tracers { + if t.OnSystemCallStart != nil { + t.OnSystemCallStart() + } + } +} + +func (t *muxTracer) OnSystemCallStartV2(env *tracing.VMContext) { + for _, t := range t.tracers { + if t.OnSystemCallStartV2 != nil { + t.OnSystemCallStartV2(env) + } else if t.OnSystemCallStart != nil { + t.OnSystemCallStart() + } + } +} + +func (t *muxTracer) OnSystemCallEnd() { + for _, t := range t.tracers { + if t.OnSystemCallEnd != nil { + t.OnSystemCallEnd() + } + } +} + // GetResult returns an empty json object. func (t *muxTracer) GetResult() (json.RawMessage, error) { resObject := make(map[string]json.RawMessage) diff --git a/eth/tracers/native/prestate.go b/eth/tracers/native/prestate.go index a85ab8e83..e2fc9d2bf 100644 --- a/eth/tracers/native/prestate.go +++ b/eth/tracers/native/prestate.go @@ -50,7 +50,6 @@ type prestateTracer struct { env *tracing.VMContext pre stateMap post stateMap - create bool to common.Address config prestateTracerConfig chainConfig *params.ChainConfig @@ -87,9 +86,13 @@ func newPrestateTracer(ctx *tracers.Context, cfg json.RawMessage, chainConfig *p } return &tracers.Tracer{ Hooks: &tracing.Hooks{ - OnTxStart: t.OnTxStart, - OnTxEnd: t.OnTxEnd, - OnOpcode: t.OnOpcode, + OnTxStart: t.OnTxStart, + OnTxEnd: t.OnTxEnd, + OnOpcode: t.OnOpcode, + OnBalanceChange: t.OnBalanceChange, + OnNonceChangeV2: t.OnNonceChange, + OnCodeChange: t.OnCodeChange, + OnStorageChange: t.OnStorageChange, }, GetResult: t.GetResult, Stop: t.Stop, @@ -294,27 +297,107 @@ func (t *prestateTracer) Stop(err error) { t.interrupt.Store(true) } -// lookupAccount fetches details of an account and adds it to the prestate -// if it doesn't exist there. -func (t *prestateTracer) lookupAccount(addr common.Address) { - if _, ok := t.pre[addr]; ok { +func (t *prestateTracer) OnBalanceChange(addr common.Address, prev, _ *big.Int, reason tracing.BalanceChangeReason) { + if t.interrupt.Load() || t.env == nil { + return + } + acc, existed := t.ensureAccount(addr) + if !existed { + acc.Balance = new(big.Int).Set(prev) + acc.empty = isEmptyAccount(acc.Balance, acc.Nonce, acc.Code, acc.CodeHash) + } +} + +func (t *prestateTracer) OnNonceChange(addr common.Address, prev, _ uint64, reason tracing.NonceChangeReason) { + if t.interrupt.Load() || t.env == nil { return } + acc, existed := t.ensureAccount(addr) + if !existed { + acc.Nonce = prev + acc.empty = isEmptyAccount(acc.Balance, acc.Nonce, acc.Code, acc.CodeHash) + } +} - t.pre[addr] = &account{ - Balance: t.env.StateDB.GetBalance(addr), - Nonce: t.env.StateDB.GetNonce(addr), - Code: t.env.StateDB.GetCode(addr), - Storage: make(map[common.Hash]common.Hash), +func (t *prestateTracer) OnCodeChange(addr common.Address, prevCodeHash common.Hash, prevCode []byte, codeHash common.Hash, code []byte) { + if t.interrupt.Load() || t.env == nil { + return + } + acc, existed := t.ensureAccount(addr) + if !existed { + acc.Code = common.CopyBytes(prevCode) + acc.CodeHash = normalizeCodeHash(prevCodeHash) + if t.config.DisableCode { + acc.Code = nil + } + acc.empty = isEmptyAccount(acc.Balance, acc.Nonce, prevCode, normalizeCodeHash(prevCodeHash)) } } +func (t *prestateTracer) OnStorageChange(addr common.Address, slot common.Hash, prev, _ common.Hash) { + if t.interrupt.Load() || t.env == nil { + return + } + acc, _ := t.ensureAccount(addr) + if t.config.DisableStorage { + return + } + if _, ok := acc.Storage[slot]; ok { + return + } + acc.Storage[slot] = prev +} + +// lookupAccount fetches details of an account and adds it to the prestate +// if it doesn't exist there. +func (t *prestateTracer) lookupAccount(addr common.Address) { + t.ensureAccount(addr) +} + // lookupStorage fetches the requested storage slot and adds -// it to the prestate of the given contract. It assumes `lookupAccount` -// has been performed on the contract before. +// it to the prestate of the given contract. It ensures the account +// exists in the prestate before accessing its storage. func (t *prestateTracer) lookupStorage(addr common.Address, key common.Hash) { - if _, ok := t.pre[addr].Storage[key]; ok { + acc, _ := t.ensureAccount(addr) + if t.config.DisableStorage { + return + } + if _, ok := acc.Storage[key]; ok { return } - t.pre[addr].Storage[key] = t.env.StateDB.GetState(addr, key) + acc.Storage[key] = t.env.StateDB.GetState(addr, key) +} + +func (t *prestateTracer) ensureAccount(addr common.Address) (*account, bool) { + if acc, ok := t.pre[addr]; ok { + return acc, true + } + code := t.env.StateDB.GetCode(addr) + acc := &account{ + Balance: t.env.StateDB.GetBalance(addr), + Nonce: t.env.StateDB.GetNonce(addr), + Code: code, + CodeHash: normalizeCodeHash(t.env.StateDB.GetKeccakCodeHash(addr)), + } + acc.empty = isEmptyAccount(acc.Balance, acc.Nonce, code, acc.CodeHash) + if t.config.DisableCode { + acc.Code = nil + } + if !t.config.DisableStorage { + acc.Storage = make(map[common.Hash]common.Hash) + } + t.pre[addr] = acc + return acc, false +} + +func normalizeCodeHash(codeHash common.Hash) *common.Hash { + if codeHash == (common.Hash{}) || codeHash == codehash.EmptyKeccakCodeHash { + return nil + } + h := codeHash + return &h +} + +func isEmptyAccount(balance *big.Int, nonce uint64, code []byte, codeHash *common.Hash) bool { + return nonce == 0 && len(code) == 0 && codeHash == nil && (balance == nil || balance.Sign() == 0) } diff --git a/eth/tracers/native/systemcall_prestate_test.go b/eth/tracers/native/systemcall_prestate_test.go new file mode 100644 index 000000000..cceaa7b58 --- /dev/null +++ b/eth/tracers/native/systemcall_prestate_test.go @@ -0,0 +1,225 @@ +package native + +import ( + "encoding/json" + "math/big" + "testing" + + "github.com/morph-l2/go-ethereum/common" + "github.com/morph-l2/go-ethereum/core/tracing" + "github.com/morph-l2/go-ethereum/core/types" + "github.com/morph-l2/go-ethereum/core/vm" + "github.com/morph-l2/go-ethereum/crypto" + "github.com/morph-l2/go-ethereum/crypto/codehash" + etracers "github.com/morph-l2/go-ethereum/eth/tracers" + "github.com/morph-l2/go-ethereum/params" +) + +type tracerTestStateDB struct { + balances map[common.Address]*big.Int + nonces map[common.Address]uint64 + codes map[common.Address][]byte + storage map[common.Address]map[common.Hash]common.Hash +} + +func (db *tracerTestStateDB) GetBalance(addr common.Address) *big.Int { + if bal, ok := db.balances[addr]; ok { + return new(big.Int).Set(bal) + } + return new(big.Int) +} + +func (db *tracerTestStateDB) GetNonce(addr common.Address) uint64 { + return db.nonces[addr] +} + +func (db *tracerTestStateDB) GetCode(addr common.Address) []byte { + return common.CopyBytes(db.codes[addr]) +} + +func (db *tracerTestStateDB) GetKeccakCodeHash(addr common.Address) common.Hash { + code := db.codes[addr] + if len(code) == 0 { + return codehash.EmptyKeccakCodeHash + } + return crypto.Keccak256Hash(code) +} + +func (db *tracerTestStateDB) GetPoseidonCodeHash(common.Address) common.Hash { + return common.Hash{} +} + +func (db *tracerTestStateDB) GetState(addr common.Address, key common.Hash) common.Hash { + if slots, ok := db.storage[addr]; ok { + return slots[key] + } + return common.Hash{} +} + +func (db *tracerTestStateDB) GetTransientState(common.Address, common.Hash) common.Hash { + return common.Hash{} +} + +func (db *tracerTestStateDB) Exist(addr common.Address) bool { + if bal := db.GetBalance(addr); bal.Sign() != 0 { + return true + } + if db.GetNonce(addr) != 0 || len(db.codes[addr]) != 0 { + return true + } + return len(db.storage[addr]) != 0 +} + +func (db *tracerTestStateDB) GetRefund() uint64 { + return 0 +} + +func (db *tracerTestStateDB) GetCodeSize(addr common.Address) uint64 { + return uint64(len(db.codes[addr])) +} + +func newTracerTestStateDB() *tracerTestStateDB { + return &tracerTestStateDB{ + balances: make(map[common.Address]*big.Int), + nonces: make(map[common.Address]uint64), + codes: make(map[common.Address][]byte), + storage: make(map[common.Address]map[common.Hash]common.Hash), + } +} + +func newPrestateTestTracer(cfg prestateTracerConfig, db *tracerTestStateDB) *prestateTracer { + return &prestateTracer{ + env: &tracing.VMContext{ + BlockNumber: big.NewInt(0), + Time: 0, + StateDB: db, + }, + pre: stateMap{}, + post: stateMap{}, + config: cfg, + chainConfig: params.TestChainConfig, + created: make(map[common.Address]bool), + deleted: make(map[common.Address]bool), + } +} + +func TestPrestateTracerLookupStorageCreatesAccount(t *testing.T) { + t.Parallel() + + db := newTracerTestStateDB() + addr := common.HexToAddress("0x100") + slot := common.HexToHash("0x1") + value := common.HexToHash("0x2") + db.codes[addr] = []byte{0x60, 0x00} + db.storage[addr] = map[common.Hash]common.Hash{slot: value} + + tracer := newPrestateTestTracer(prestateTracerConfig{}, db) + tracer.lookupStorage(addr, slot) + + if tracer.pre[addr] == nil { + t.Fatalf("expected account to be added to prestate") + } + if got := tracer.pre[addr].Storage[slot]; got != value { + t.Fatalf("unexpected storage value, have %s want %s", got, value) + } +} + +func TestPrestateTracerOnStorageChangeUsesPreviousValue(t *testing.T) { + t.Parallel() + + db := newTracerTestStateDB() + addr := common.HexToAddress("0x200") + slot := common.HexToHash("0x3") + prev := common.HexToHash("0x4") + next := common.HexToHash("0x5") + db.codes[addr] = []byte{0x60, 0x00} + db.storage[addr] = map[common.Hash]common.Hash{slot: next} + + tracer := newPrestateTestTracer(prestateTracerConfig{}, db) + tracer.OnStorageChange(addr, slot, prev, next) + + if got := tracer.pre[addr].Storage[slot]; got != prev { + t.Fatalf("unexpected prestate slot value, have %s want %s", got, prev) + } +} + +func TestPrestateTracerOnBalanceChangeUsesPreviousBalance(t *testing.T) { + t.Parallel() + + db := newTracerTestStateDB() + addr := common.HexToAddress("0x300") + db.balances[addr] = big.NewInt(9) + + tracer := newPrestateTestTracer(prestateTracerConfig{}, db) + tracer.OnBalanceChange(addr, big.NewInt(4), big.NewInt(9), tracing.BalanceChangeTransfer) + + if got := tracer.pre[addr].Balance; got.Cmp(big.NewInt(4)) != 0 { + t.Fatalf("unexpected prestate balance, have %v want %v", got, 4) + } +} + +func TestFourByteTracerSkipsSystemCalls(t *testing.T) { + t.Parallel() + + tracer, err := newFourByteTracer(nil, nil, params.TestChainConfig) + if err != nil { + t.Fatalf("failed to create tracer: %v", err) + } + tx := types.NewTx(&types.LegacyTx{ + To: &common.Address{}, + Gas: 21000, + GasPrice: big.NewInt(0), + Value: big.NewInt(0), + }) + tracer.OnTxStart(&tracing.VMContext{BlockNumber: big.NewInt(0)}, tx, common.Address{}) + + systemInput := append(common.FromHex("0x70a08231"), make([]byte, 32)...) + userInput := append(common.FromHex("0xa9059cbb"), make([]byte, 64)...) + + tracer.OnSystemCallStartV2(&tracing.VMContext{}) + tracer.OnEnter(1, byte(vm.CALL), common.Address{}, common.HexToAddress("0x1111"), systemInput, 0, big.NewInt(0)) + tracer.OnSystemCallEnd() + tracer.OnEnter(1, byte(vm.CALL), common.Address{}, common.HexToAddress("0x2222"), userInput, 0, big.NewInt(0)) + + res, err := tracer.GetResult() + if err != nil { + t.Fatalf("failed to get trace result: %v", err) + } + var ids map[string]int + if err := json.Unmarshal(res, &ids); err != nil { + t.Fatalf("failed to decode trace result: %v", err) + } + if len(ids) != 1 || ids["0xa9059cbb-64"] != 1 { + t.Fatalf("unexpected 4byte result: %v", ids) + } +} + +func TestMuxTracerForwardsSystemCallsAndNonceChangeV2(t *testing.T) { + t.Parallel() + + var ( + systemStarts int + systemEnds int + nonceV2Calls int + ) + sub := &etracers.Tracer{ + Hooks: &tracing.Hooks{ + OnSystemCallStartV2: func(*tracing.VMContext) { systemStarts++ }, + OnSystemCallEnd: func() { systemEnds++ }, + OnNonceChangeV2: func(common.Address, uint64, uint64, tracing.NonceChangeReason) { + nonceV2Calls++ + }, + }, + GetResult: func() (json.RawMessage, error) { return json.RawMessage(`{}`), nil }, + Stop: func(error) {}, + } + mux := &muxTracer{tracers: []*etracers.Tracer{sub}} + + mux.OnSystemCallStartV2(&tracing.VMContext{}) + mux.OnSystemCallEnd() + mux.OnNonceChangeV2(common.Address{}, 1, 2, tracing.NonceChangeUnspecified) + + if systemStarts != 1 || systemEnds != 1 || nonceV2Calls != 1 { + t.Fatalf("unexpected forwarded counts: starts=%d ends=%d nonceV2=%d", systemStarts, systemEnds, nonceV2Calls) + } +}