diff --git a/core/state/statedb.go b/core/state/statedb.go index 0afd7554dcee..6684f41cafee 100644 --- a/core/state/statedb.go +++ b/core/state/statedb.go @@ -18,6 +18,7 @@ package state import ( + "bytes" "fmt" "math/big" "sort" @@ -270,9 +271,12 @@ func (self *StateDB) GetCodeSize(addr common.Address) int { if stateObject.code != nil { return len(stateObject.code) } + if bytes.Equal(stateObject.CodeHash(), emptyCode[:]) { + return 0 + } size, err := self.db.ContractCodeSize(stateObject.addrHash, common.BytesToHash(stateObject.CodeHash())) if err != nil { - self.setError(err) + self.setError(fmt.Errorf("GetCodeSize (%x) error: %v", addr[:], err)) } return size } @@ -419,14 +423,18 @@ func (self *StateDB) updateStateObject(stateObject *stateObject) { if err != nil { panic(fmt.Errorf("can't encode object at %x: %v", addr[:], err)) } - self.setError(self.trie.TryUpdate(addr[:], data)) + if err = self.trie.TryUpdate(addr[:], data); err != nil { + self.setError(fmt.Errorf("updateStateObject (%x) error: %v", addr[:], err)) + } } // deleteStateObject removes the given object from the state trie. func (self *StateDB) deleteStateObject(stateObject *stateObject) { stateObject.deleted = true addr := stateObject.Address() - self.setError(self.trie.TryDelete(addr[:])) + if err := self.trie.TryDelete(addr[:]); err != nil { + self.setError(fmt.Errorf("deleteStateObject (%x) error: %v", addr[:], err)) + } } // DeleteAddress removes the address from the state trie. @@ -678,6 +686,10 @@ func (s *StateDB) clearJournalAndRefund() { // Commit writes the state to the underlying in-memory trie database. func (s *StateDB) Commit(deleteEmptyObjects bool) (root common.Hash, err error) { + if s.dbErr != nil { + return common.Hash{}, fmt.Errorf("commit aborted due to earlier error: %v", s.dbErr) + } + defer s.clearJournalAndRefund() // Commit objects to the trie. diff --git a/core/state/statedb_test.go b/core/state/statedb_test.go index 15533ec5a086..c5642d5db35b 100644 --- a/core/state/statedb_test.go +++ b/core/state/statedb_test.go @@ -615,3 +615,51 @@ func TestStateDBAccessList(t *testing.T) { t.Fatalf("expected empty, got %d", got) } } + +// TestMissingTrieNodes tests that if the statedb fails to load parts of the trie, +// the Commit operation fails with an error +// If we are missing trie nodes, we should not continue writing to the trie +func TestMissingTrieNodes(t *testing.T) { + + // Create an initial state with a few accounts + memDb := rawdb.NewMemoryDatabase() + db := NewDatabase(memDb) + var root common.Hash + state, _ := New(common.Hash{}, db) + addr := toAddr([]byte("so")) + { + state.SetBalance(addr, big.NewInt(1)) + state.SetCode(addr, []byte{1, 2, 3}) + a2 := toAddr([]byte("another")) + state.SetBalance(a2, big.NewInt(100)) + state.SetCode(a2, []byte{1, 2, 4}) + root, _ = state.Commit(false) + t.Logf("root: %x", root) + // force-flush + state.Database().TrieDB().Cap(0) + } + // Create a new state on the old root + state, _ = New(root, db) + //state, _ = New(root, db, nil) + // Now we clear out the memdb + it := memDb.NewIterator(nil, nil) + for it.Next() { + k := it.Key() + // Leave the root intact + if !bytes.Equal(k, root[:]) { + t.Logf("key: %x", k) + memDb.Delete(k) + } + } + balance := state.GetBalance(addr) + // The removed elem should lead to it returning zero balance + if exp, got := uint64(0), balance.Uint64(); got != exp { + t.Errorf("expected %d, got %d", exp, got) + } + // Modify the state + state.SetBalance(addr, big.NewInt(2)) + root, err := state.Commit(false) + if err == nil { + t.Fatalf("expected error, got root :%x", root) + } +}