diff --git a/eth/downloader/downloader_test.go b/eth/downloader/downloader_test.go index 1b9f00c695..e18db00ed1 100644 --- a/eth/downloader/downloader_test.go +++ b/eth/downloader/downloader_test.go @@ -144,7 +144,12 @@ func (dl *downloadTester) HasFastBlock(hash common.Hash, number uint64) bool { func (dl *downloadTester) GetHeaderByHash(hash common.Hash) *types.Header { dl.lock.RLock() defer dl.lock.RUnlock() + return dl.getHeaderByHash(hash) +} +// getHeaderByHash returns the header if found either within ancients or own blocks) +// This method assumes that the caller holds at least the read-lock (dl.lock) +func (dl *downloadTester) getHeaderByHash(hash common.Hash) *types.Header { header := dl.ancientHeaders[hash] if header != nil { return header @@ -232,6 +237,13 @@ func (dl *downloadTester) GetTd(hash common.Hash, number uint64) *big.Int { dl.lock.RLock() defer dl.lock.RUnlock() + return dl.getTd(hash) +} + +// getTd retrieves the block's total difficulty if found either within +// ancients or own blocks). +// This method assumes that the caller holds at least the read-lock (dl.lock) +func (dl *downloadTester) getTd(hash common.Hash) *big.Int { if td := dl.ancientChainTd[hash]; td != nil { return td } @@ -243,8 +255,8 @@ func (dl *downloadTester) InsertHeaderChain(headers []*types.Header, checkFreq i dl.lock.Lock() defer dl.lock.Unlock() // Do a quick check, as the blockchain.InsertHeaderChain doesn't insert anything in case of errors - if _, ok := dl.ownHeaders[headers[0].ParentHash]; !ok { - return 0, errors.New("InsertHeaderChain: unknown parent at first position") + if dl.getHeaderByHash(headers[0].ParentHash) == nil { + return 0, fmt.Errorf("InsertHeaderChain: unknown parent at first position, parent of number %d", headers[0].Number) } var hashes []common.Hash for i := 1; i < len(headers); i++ { @@ -258,16 +270,18 @@ func (dl *downloadTester) InsertHeaderChain(headers []*types.Header, checkFreq i // Do a full insert if pre-checks passed for i, header := range headers { hash := hashes[i] - if _, ok := dl.ownHeaders[hash]; ok { + if dl.getHeaderByHash(hash) != nil { continue } - if _, ok := dl.ownHeaders[header.ParentHash]; !ok { + if dl.getHeaderByHash(header.ParentHash) == nil { // This _should_ be impossible, due to precheck and induction return i, fmt.Errorf("InsertHeaderChain: unknown parent at position %d", i) } dl.ownHashes = append(dl.ownHashes, hash) dl.ownHeaders[hash] = header - dl.ownChainTd[hash] = new(big.Int).Add(dl.ownChainTd[header.ParentHash], header.Difficulty) + + td := dl.getTd(header.ParentHash) + dl.ownChainTd[hash] = new(big.Int).Add(td, header.Difficulty) } return len(headers), nil } @@ -276,7 +290,6 @@ func (dl *downloadTester) InsertHeaderChain(headers []*types.Header, checkFreq i func (dl *downloadTester) InsertChain(blocks types.Blocks) (i int, err error) { dl.lock.Lock() defer dl.lock.Unlock() - for i, block := range blocks { if parent, ok := dl.ownBlocks[block.ParentHash()]; !ok { return i, fmt.Errorf("InsertChain: unknown parent at position %d / %d", i, len(blocks)) @@ -290,7 +303,9 @@ func (dl *downloadTester) InsertChain(blocks types.Blocks) (i int, err error) { dl.ownBlocks[block.Hash()] = block dl.ownReceipts[block.Hash()] = make(types.Receipts, 0) dl.stateDb.Put(block.Root().Bytes(), []byte{0x00}) - dl.ownChainTd[block.Hash()] = new(big.Int).Add(dl.ownChainTd[block.ParentHash()], block.Difficulty()) + + td := dl.getTd(block.ParentHash()) + dl.ownChainTd[block.Hash()] = new(big.Int).Add(td, block.Difficulty()) } return len(blocks), nil } @@ -316,7 +331,6 @@ func (dl *downloadTester) InsertReceiptChain(blocks types.Blocks, receipts []typ // Migrate from active db to ancient db dl.ancientHeaders[blocks[i].Hash()] = blocks[i].Header() dl.ancientChainTd[blocks[i].Hash()] = new(big.Int).Add(dl.ancientChainTd[blocks[i].ParentHash()], blocks[i].Difficulty()) - delete(dl.ownHeaders, blocks[i].Hash()) delete(dl.ownChainTd, blocks[i].Hash()) } else { diff --git a/miner/worker_test.go b/miner/worker_test.go index 65eccbc4c0..a5c558ba5f 100644 --- a/miner/worker_test.go +++ b/miner/worker_test.go @@ -285,9 +285,7 @@ func testEmptyWork(t *testing.T, chainConfig *params.ChainConfig, engine consens } w.skipSealHook = func(task *task) bool { return true } w.fullTaskHook = func() { - // Arch64 unit tests are running in a VM on travis, they must - // be given more time to execute. - time.Sleep(time.Second) + time.Sleep(100 * time.Millisecond) } w.start() // Start mining! for i := 0; i < 2; i += 1 {