diff --git a/eth/downloader/downloader_test.go b/eth/downloader/downloader_test.go index d3d45b8f09f1..4391abec7e33 100644 --- a/eth/downloader/downloader_test.go +++ b/eth/downloader/downloader_test.go @@ -130,7 +130,12 @@ func (dl *downloadTester) HasFastBlock(hash common.Hash, number uint64) bool { func (dl *downloadTester) GetHeaderByHash(hash common.Hash) *types.Header { dl.lock.RLock() defer dl.lock.RUnlock() + return dl.getHeaderByHash(hash) +} +// getHeaderByHash returns the header if found either within ancients or own blocks) +// This method assumes that the caller holds at least the read-lock (dl.lock) +func (dl *downloadTester) getHeaderByHash(hash common.Hash) *types.Header { return dl.ownHeaders[hash] } @@ -197,7 +202,13 @@ func (dl *downloadTester) FastSyncCommitHead(hash common.Hash) error { func (dl *downloadTester) GetTd(hash common.Hash, number uint64) *big.Int { dl.lock.RLock() defer dl.lock.RUnlock() + return dl.getTd(hash) +} +// getTd retrieves the block's total difficulty if found either within +// ancients or own blocks). +// This method assumes that the caller holds at least the read-lock (dl.lock) +func (dl *downloadTester) getTd(hash common.Hash) *big.Int { return dl.ownChainTd[hash] } @@ -205,27 +216,34 @@ func (dl *downloadTester) GetTd(hash common.Hash, number uint64) *big.Int { func (dl *downloadTester) InsertHeaderChain(headers []*types.Header, checkFreq int) (i int, err error) { dl.lock.Lock() defer dl.lock.Unlock() - // Do a quick check, as the blockchain.InsertHeaderChain doesn't insert anything in case of errors - if _, ok := dl.ownHeaders[headers[0].ParentHash]; !ok { - return 0, errors.New("unknown parent") + if dl.getHeaderByHash(headers[0].ParentHash) == nil { + return 0, fmt.Errorf("InsertHeaderChain: unknown parent at first position, parent of number %d", headers[0].Number) } + var hashes []common.Hash for i := 1; i < len(headers); i++ { + hash := headers[i-1].Hash() if headers[i].ParentHash != headers[i-1].Hash() { - return i, errors.New("unknown parent") + return i, fmt.Errorf("non-contiguous import at position %d", i) } + hashes = append(hashes, hash) } + hashes = append(hashes, headers[len(headers)-1].Hash()) // Do a full insert if pre-checks passed for i, header := range headers { - if _, ok := dl.ownHeaders[header.Hash()]; ok { + hash := hashes[i] + if dl.getHeaderByHash(hash) != nil { continue } - if _, ok := dl.ownHeaders[header.ParentHash]; !ok { - return i, errors.New("unknown parent") + if dl.getHeaderByHash(header.ParentHash) == nil { + // This _should_ be impossible, due to precheck and induction + return i, fmt.Errorf("InsertHeaderChain: unknown parent at position %d", i) } - dl.ownHashes = append(dl.ownHashes, header.Hash()) - dl.ownHeaders[header.Hash()] = header - dl.ownChainTd[header.Hash()] = new(big.Int).Add(dl.ownChainTd[header.ParentHash], header.Difficulty) + dl.ownHashes = append(dl.ownHashes, hash) + dl.ownHeaders[hash] = header + + td := dl.getTd(header.ParentHash) + dl.ownChainTd[hash] = new(big.Int).Add(td, header.Difficulty) } return len(headers), nil } @@ -237,9 +255,9 @@ func (dl *downloadTester) InsertChain(blocks types.Blocks) (i int, err error) { for i, block := range blocks { if parent, ok := dl.ownBlocks[block.ParentHash()]; !ok { - return i, errors.New("unknown parent") + return i, fmt.Errorf("InsertChain: unknown parent at position %d / %d", i, len(blocks)) } else if _, err := dl.stateDb.Get(parent.Root().Bytes()); err != nil { - return i, fmt.Errorf("unknown parent state %x: %v", parent.Root(), err) + return i, fmt.Errorf("InsertChain: unknown parent state %x: %v", parent.Root(), err) } if _, ok := dl.ownHeaders[block.Hash()]; !ok { dl.ownHashes = append(dl.ownHashes, block.Hash()) @@ -262,7 +280,7 @@ func (dl *downloadTester) InsertReceiptChain(blocks types.Blocks, receipts []typ return i, errors.New("unknown owner") } if _, ok := dl.ownBlocks[blocks[i].ParentHash()]; !ok { - return i, errors.New("unknown parent") + return i, errors.New("InsertReceiptChain: unknown parent") } dl.ownBlocks[blocks[i].Hash()] = blocks[i] dl.ownReceipts[blocks[i].Hash()] = receipts[i]