core: improve shutdown synchronization in BlockChain (#22853)

This change removes misuses of sync.WaitGroup in BlockChain. Before this change,
block insertion modified the WaitGroup counter in order to ensure that Stop would wait
for pending operations to complete. This was racy and could even lead to crashes
if Stop was called at an unfortunate time. The issue is resolved by adding a specialized
'closable' mutex, which prevents chain modifications after stopping while also
synchronizing writers with each other.

Co-authored-by: Felix Lange <fjl@twurst.com>
This commit is contained in:
Martin Holst Swende 2021-10-07 15:47:50 +02:00 committed by GitHub
parent 4e599ee469
commit edb1937cf7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 164 additions and 61 deletions

@ -39,6 +39,7 @@ import (
"github.com/ethereum/go-ethereum/core/vm" "github.com/ethereum/go-ethereum/core/vm"
"github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/event" "github.com/ethereum/go-ethereum/event"
"github.com/ethereum/go-ethereum/internal/syncx"
"github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/metrics" "github.com/ethereum/go-ethereum/metrics"
"github.com/ethereum/go-ethereum/params" "github.com/ethereum/go-ethereum/params"
@ -80,6 +81,7 @@ var (
blockPrefetchInterruptMeter = metrics.NewRegisteredMeter("chain/prefetch/interrupts", nil) blockPrefetchInterruptMeter = metrics.NewRegisteredMeter("chain/prefetch/interrupts", nil)
errInsertionInterrupted = errors.New("insertion is interrupted") errInsertionInterrupted = errors.New("insertion is interrupted")
errChainStopped = errors.New("blockchain is stopped")
) )
const ( const (
@ -183,7 +185,9 @@ type BlockChain struct {
scope event.SubscriptionScope scope event.SubscriptionScope
genesisBlock *types.Block genesisBlock *types.Block
chainmu sync.RWMutex // blockchain insertion lock // This mutex synchronizes chain write operations.
// Readers don't need to take it, they can just read the database.
chainmu *syncx.ClosableMutex
currentBlock atomic.Value // Current head of the block chain currentBlock atomic.Value // Current head of the block chain
currentFastBlock atomic.Value // Current head of the fast-sync chain (may be above the block chain!) currentFastBlock atomic.Value // Current head of the fast-sync chain (may be above the block chain!)
@ -196,8 +200,8 @@ type BlockChain struct {
txLookupCache *lru.Cache // Cache for the most recent transaction lookup data. txLookupCache *lru.Cache // Cache for the most recent transaction lookup data.
futureBlocks *lru.Cache // future blocks are blocks added for later processing futureBlocks *lru.Cache // future blocks are blocks added for later processing
quit chan struct{} // blockchain quit channel wg sync.WaitGroup //
wg sync.WaitGroup // chain processing wait group for shutting down quit chan struct{} // shutdown signal, closed in Stop.
running int32 // 0 if chain is running, 1 when stopped running int32 // 0 if chain is running, 1 when stopped
procInterrupt int32 // interrupt signaler for block processing procInterrupt int32 // interrupt signaler for block processing
@ -235,6 +239,7 @@ func NewBlockChain(db ethdb.Database, cacheConfig *CacheConfig, chainConfig *par
Preimages: cacheConfig.Preimages, Preimages: cacheConfig.Preimages,
}), }),
quit: make(chan struct{}), quit: make(chan struct{}),
chainmu: syncx.NewClosableMutex(),
shouldPreserve: shouldPreserve, shouldPreserve: shouldPreserve,
bodyCache: bodyCache, bodyCache: bodyCache,
bodyRLPCache: bodyRLPCache, bodyRLPCache: bodyRLPCache,
@ -278,6 +283,7 @@ func NewBlockChain(db ethdb.Database, cacheConfig *CacheConfig, chainConfig *par
if err := bc.loadLastState(); err != nil { if err := bc.loadLastState(); err != nil {
return nil, err return nil, err
} }
// Make sure the state associated with the block is available // Make sure the state associated with the block is available
head := bc.CurrentBlock() head := bc.CurrentBlock()
if _, err := state.New(head.Root(), bc.stateCache, bc.snaps); err != nil { if _, err := state.New(head.Root(), bc.stateCache, bc.snaps); err != nil {
@ -306,6 +312,7 @@ func NewBlockChain(db ethdb.Database, cacheConfig *CacheConfig, chainConfig *par
} }
} }
} }
// Ensure that a previous crash in SetHead doesn't leave extra ancients // Ensure that a previous crash in SetHead doesn't leave extra ancients
if frozen, err := bc.db.Ancients(); err == nil && frozen > 0 { if frozen, err := bc.db.Ancients(); err == nil && frozen > 0 {
var ( var (
@ -357,6 +364,7 @@ func NewBlockChain(db ethdb.Database, cacheConfig *CacheConfig, chainConfig *par
} }
} }
} }
// Load any existing snapshot, regenerating it if loading failed // Load any existing snapshot, regenerating it if loading failed
if bc.cacheConfig.SnapshotLimit > 0 { if bc.cacheConfig.SnapshotLimit > 0 {
// If the chain was rewound past the snapshot persistent layer (causing // If the chain was rewound past the snapshot persistent layer (causing
@ -372,14 +380,19 @@ func NewBlockChain(db ethdb.Database, cacheConfig *CacheConfig, chainConfig *par
} }
bc.snaps, _ = snapshot.New(bc.db, bc.stateCache.TrieDB(), bc.cacheConfig.SnapshotLimit, head.Root(), !bc.cacheConfig.SnapshotWait, true, recover) bc.snaps, _ = snapshot.New(bc.db, bc.stateCache.TrieDB(), bc.cacheConfig.SnapshotLimit, head.Root(), !bc.cacheConfig.SnapshotWait, true, recover)
} }
// Take ownership of this particular state
go bc.update() // Start future block processor.
bc.wg.Add(1)
go bc.futureBlocksLoop()
// Start tx indexer/unindexer.
if txLookupLimit != nil { if txLookupLimit != nil {
bc.txLookupLimit = *txLookupLimit bc.txLookupLimit = *txLookupLimit
bc.wg.Add(1) bc.wg.Add(1)
go bc.maintainTxIndex(txIndexBlock) go bc.maintainTxIndex(txIndexBlock)
} }
// If periodic cache journal is required, spin it up. // If periodic cache journal is required, spin it up.
if bc.cacheConfig.TrieCleanRejournal > 0 { if bc.cacheConfig.TrieCleanRejournal > 0 {
if bc.cacheConfig.TrieCleanRejournal < time.Minute { if bc.cacheConfig.TrieCleanRejournal < time.Minute {
@ -488,7 +501,9 @@ func (bc *BlockChain) SetHead(head uint64) error {
// //
// The method returns the block number where the requested root cap was found. // The method returns the block number where the requested root cap was found.
func (bc *BlockChain) SetHeadBeyondRoot(head uint64, root common.Hash) (uint64, error) { func (bc *BlockChain) SetHeadBeyondRoot(head uint64, root common.Hash) (uint64, error) {
bc.chainmu.Lock() if !bc.chainmu.TryLock() {
return 0, errChainStopped
}
defer bc.chainmu.Unlock() defer bc.chainmu.Unlock()
// Track the block number of the requested root hash // Track the block number of the requested root hash
@ -633,8 +648,11 @@ func (bc *BlockChain) FastSyncCommitHead(hash common.Hash) error {
if _, err := trie.NewSecure(block.Root(), bc.stateCache.TrieDB()); err != nil { if _, err := trie.NewSecure(block.Root(), bc.stateCache.TrieDB()); err != nil {
return err return err
} }
// If all checks out, manually set the head block
bc.chainmu.Lock() // If all checks out, manually set the head block.
if !bc.chainmu.TryLock() {
return errChainStopped
}
bc.currentBlock.Store(block) bc.currentBlock.Store(block)
headBlockGauge.Update(int64(block.NumberU64())) headBlockGauge.Update(int64(block.NumberU64()))
bc.chainmu.Unlock() bc.chainmu.Unlock()
@ -707,7 +725,9 @@ func (bc *BlockChain) ResetWithGenesisBlock(genesis *types.Block) error {
if err := bc.SetHead(0); err != nil { if err := bc.SetHead(0); err != nil {
return err return err
} }
bc.chainmu.Lock() if !bc.chainmu.TryLock() {
return errChainStopped
}
defer bc.chainmu.Unlock() defer bc.chainmu.Unlock()
// Prepare the genesis block and reinitialise the chain // Prepare the genesis block and reinitialise the chain
@ -737,8 +757,10 @@ func (bc *BlockChain) Export(w io.Writer) error {
// ExportN writes a subset of the active chain to the given writer. // ExportN writes a subset of the active chain to the given writer.
func (bc *BlockChain) ExportN(w io.Writer, first uint64, last uint64) error { func (bc *BlockChain) ExportN(w io.Writer, first uint64, last uint64) error {
bc.chainmu.RLock() if !bc.chainmu.TryLock() {
defer bc.chainmu.RUnlock() return errChainStopped
}
defer bc.chainmu.Unlock()
if first > last { if first > last {
return fmt.Errorf("export failed: first (%d) is greater than last (%d)", first, last) return fmt.Errorf("export failed: first (%d) is greater than last (%d)", first, last)
@ -991,10 +1013,21 @@ func (bc *BlockChain) Stop() {
if !atomic.CompareAndSwapInt32(&bc.running, 0, 1) { if !atomic.CompareAndSwapInt32(&bc.running, 0, 1) {
return return
} }
// Unsubscribe all subscriptions registered from blockchain
// Unsubscribe all subscriptions registered from blockchain.
bc.scope.Close() bc.scope.Close()
// Signal shutdown to all goroutines.
close(bc.quit) close(bc.quit)
bc.StopInsert() bc.StopInsert()
// Now wait for all chain modifications to end and persistent goroutines to exit.
//
// Note: Close waits for the mutex to become available, i.e. any running chain
// modification will have exited when Close returns. Since we also called StopInsert,
// the mutex should become available quickly. It cannot be taken again after Close has
// returned.
bc.chainmu.Close()
bc.wg.Wait() bc.wg.Wait()
// Ensure that the entirety of the state snapshot is journalled to disk. // Ensure that the entirety of the state snapshot is journalled to disk.
@ -1005,6 +1038,7 @@ func (bc *BlockChain) Stop() {
log.Error("Failed to journal state snapshot", "err", err) log.Error("Failed to journal state snapshot", "err", err)
} }
} }
// Ensure the state of a recent block is also stored to disk before exiting. // Ensure the state of a recent block is also stored to disk before exiting.
// We're writing three different states to catch different restart scenarios: // We're writing three different states to catch different restart scenarios:
// - HEAD: So we don't need to reprocess any blocks in the general case // - HEAD: So we don't need to reprocess any blocks in the general case
@ -1128,7 +1162,9 @@ func (bc *BlockChain) InsertReceiptChain(blockChain types.Blocks, receiptChain [
// updateHead updates the head fast sync block if the inserted blocks are better // updateHead updates the head fast sync block if the inserted blocks are better
// and returns an indicator whether the inserted blocks are canonical. // and returns an indicator whether the inserted blocks are canonical.
updateHead := func(head *types.Block) bool { updateHead := func(head *types.Block) bool {
bc.chainmu.Lock() if !bc.chainmu.TryLock() {
return false
}
defer bc.chainmu.Unlock() defer bc.chainmu.Unlock()
// Rewind may have occurred, skip in that case. // Rewind may have occurred, skip in that case.
@ -1372,8 +1408,9 @@ var lastWrite uint64
// but does not write any state. This is used to construct competing side forks // but does not write any state. This is used to construct competing side forks
// up to the point where they exceed the canonical total difficulty. // up to the point where they exceed the canonical total difficulty.
func (bc *BlockChain) writeBlockWithoutState(block *types.Block, td *big.Int) (err error) { func (bc *BlockChain) writeBlockWithoutState(block *types.Block, td *big.Int) (err error) {
bc.wg.Add(1) if bc.insertStopped() {
defer bc.wg.Done() return errInsertionInterrupted
}
batch := bc.db.NewBatch() batch := bc.db.NewBatch()
rawdb.WriteTd(batch, block.Hash(), block.NumberU64(), td) rawdb.WriteTd(batch, block.Hash(), block.NumberU64(), td)
@ -1387,9 +1424,6 @@ func (bc *BlockChain) writeBlockWithoutState(block *types.Block, td *big.Int) (e
// writeKnownBlock updates the head block flag with a known block // writeKnownBlock updates the head block flag with a known block
// and introduces chain reorg if necessary. // and introduces chain reorg if necessary.
func (bc *BlockChain) writeKnownBlock(block *types.Block) error { func (bc *BlockChain) writeKnownBlock(block *types.Block) error {
bc.wg.Add(1)
defer bc.wg.Done()
current := bc.CurrentBlock() current := bc.CurrentBlock()
if block.ParentHash() != current.Hash() { if block.ParentHash() != current.Hash() {
if err := bc.reorg(current, block); err != nil { if err := bc.reorg(current, block); err != nil {
@ -1402,17 +1436,19 @@ func (bc *BlockChain) writeKnownBlock(block *types.Block) error {
// WriteBlockWithState writes the block and all associated state to the database. // WriteBlockWithState writes the block and all associated state to the database.
func (bc *BlockChain) WriteBlockWithState(block *types.Block, receipts []*types.Receipt, logs []*types.Log, state *state.StateDB, emitHeadEvent bool) (status WriteStatus, err error) { func (bc *BlockChain) WriteBlockWithState(block *types.Block, receipts []*types.Receipt, logs []*types.Log, state *state.StateDB, emitHeadEvent bool) (status WriteStatus, err error) {
bc.chainmu.Lock() if !bc.chainmu.TryLock() {
return NonStatTy, errInsertionInterrupted
}
defer bc.chainmu.Unlock() defer bc.chainmu.Unlock()
return bc.writeBlockWithState(block, receipts, logs, state, emitHeadEvent) return bc.writeBlockWithState(block, receipts, logs, state, emitHeadEvent)
} }
// writeBlockWithState writes the block and all associated state to the database, // writeBlockWithState writes the block and all associated state to the database,
// but is expects the chain mutex to be held. // but is expects the chain mutex to be held.
func (bc *BlockChain) writeBlockWithState(block *types.Block, receipts []*types.Receipt, logs []*types.Log, state *state.StateDB, emitHeadEvent bool) (status WriteStatus, err error) { func (bc *BlockChain) writeBlockWithState(block *types.Block, receipts []*types.Receipt, logs []*types.Log, state *state.StateDB, emitHeadEvent bool) (status WriteStatus, err error) {
bc.wg.Add(1) if bc.insertStopped() {
defer bc.wg.Done() return NonStatTy, errInsertionInterrupted
}
// Calculate the total difficulty of the block // Calculate the total difficulty of the block
ptd := bc.GetTd(block.ParentHash(), block.NumberU64()-1) ptd := bc.GetTd(block.ParentHash(), block.NumberU64()-1)
@ -1576,31 +1612,28 @@ func (bc *BlockChain) InsertChain(chain types.Blocks) (int, error) {
bc.blockProcFeed.Send(true) bc.blockProcFeed.Send(true)
defer bc.blockProcFeed.Send(false) defer bc.blockProcFeed.Send(false)
// Remove already known canon-blocks // Do a sanity check that the provided chain is actually ordered and linked.
var (
block, prev *types.Block
)
// Do a sanity check that the provided chain is actually ordered and linked
for i := 1; i < len(chain); i++ { for i := 1; i < len(chain); i++ {
block = chain[i] block, prev := chain[i], chain[i-1]
prev = chain[i-1]
if block.NumberU64() != prev.NumberU64()+1 || block.ParentHash() != prev.Hash() { if block.NumberU64() != prev.NumberU64()+1 || block.ParentHash() != prev.Hash() {
// Chain broke ancestry, log a message (programming error) and skip insertion log.Error("Non contiguous block insert",
log.Error("Non contiguous block insert", "number", block.Number(), "hash", block.Hash(), "number", block.Number(),
"parent", block.ParentHash(), "prevnumber", prev.Number(), "prevhash", prev.Hash()) "hash", block.Hash(),
"parent", block.ParentHash(),
"prevnumber", prev.Number(),
"prevhash", prev.Hash(),
)
return 0, fmt.Errorf("non contiguous insert: item %d is #%d [%x..], item %d is #%d [%x..] (parent [%x..])", i-1, prev.NumberU64(), return 0, fmt.Errorf("non contiguous insert: item %d is #%d [%x..], item %d is #%d [%x..] (parent [%x..])", i-1, prev.NumberU64(),
prev.Hash().Bytes()[:4], i, block.NumberU64(), block.Hash().Bytes()[:4], block.ParentHash().Bytes()[:4]) prev.Hash().Bytes()[:4], i, block.NumberU64(), block.Hash().Bytes()[:4], block.ParentHash().Bytes()[:4])
} }
} }
// Pre-checks passed, start the full block imports
bc.wg.Add(1)
bc.chainmu.Lock()
n, err := bc.insertChain(chain, true)
bc.chainmu.Unlock()
bc.wg.Done()
return n, err // Pre-check passed, start the full block imports.
if !bc.chainmu.TryLock() {
return 0, errChainStopped
}
defer bc.chainmu.Unlock()
return bc.insertChain(chain, true)
} }
// InsertChainWithoutSealVerification works exactly the same // InsertChainWithoutSealVerification works exactly the same
@ -1609,14 +1642,11 @@ func (bc *BlockChain) InsertChainWithoutSealVerification(block *types.Block) (in
bc.blockProcFeed.Send(true) bc.blockProcFeed.Send(true)
defer bc.blockProcFeed.Send(false) defer bc.blockProcFeed.Send(false)
// Pre-checks passed, start the full block imports if !bc.chainmu.TryLock() {
bc.wg.Add(1) return 0, errChainStopped
bc.chainmu.Lock() }
n, err := bc.insertChain(types.Blocks([]*types.Block{block}), false) defer bc.chainmu.Unlock()
bc.chainmu.Unlock() return bc.insertChain(types.Blocks([]*types.Block{block}), false)
bc.wg.Done()
return n, err
} }
// insertChain is the internal implementation of InsertChain, which assumes that // insertChain is the internal implementation of InsertChain, which assumes that
@ -1628,10 +1658,11 @@ func (bc *BlockChain) InsertChainWithoutSealVerification(block *types.Block) (in
// is imported, but then new canon-head is added before the actual sidechain // is imported, but then new canon-head is added before the actual sidechain
// completes, then the historic state could be pruned again // completes, then the historic state could be pruned again
func (bc *BlockChain) insertChain(chain types.Blocks, verifySeals bool) (int, error) { func (bc *BlockChain) insertChain(chain types.Blocks, verifySeals bool) (int, error) {
// If the chain is terminating, don't even bother starting up // If the chain is terminating, don't even bother starting up.
if atomic.LoadInt32(&bc.procInterrupt) == 1 { if bc.insertStopped() {
return 0, nil return 0, nil
} }
// Start a parallel signature recovery (signer will fluke on fork transition, minimal perf loss) // Start a parallel signature recovery (signer will fluke on fork transition, minimal perf loss)
senderCacher.recoverFromBlocks(types.MakeSigner(bc.chainConfig, chain[0].Number()), chain) senderCacher.recoverFromBlocks(types.MakeSigner(bc.chainConfig, chain[0].Number()), chain)
@ -1666,8 +1697,8 @@ func (bc *BlockChain) insertChain(chain types.Blocks, verifySeals bool) (int, er
// First block (and state) is known // First block (and state) is known
// 1. We did a roll-back, and should now do a re-import // 1. We did a roll-back, and should now do a re-import
// 2. The block is stored as a sidechain, and is lying about it's stateroot, and passes a stateroot // 2. The block is stored as a sidechain, and is lying about it's stateroot, and passes a stateroot
// from the canonical chain, which has not been verified. // from the canonical chain, which has not been verified.
// Skip all known blocks that are behind us // Skip all known blocks that are behind us.
var ( var (
current = bc.CurrentBlock() current = bc.CurrentBlock()
localTd = bc.GetTd(current.Hash(), current.NumberU64()) localTd = bc.GetTd(current.Hash(), current.NumberU64())
@ -1791,9 +1822,9 @@ func (bc *BlockChain) insertChain(chain types.Blocks, verifySeals bool) (int, er
lastCanon = block lastCanon = block
continue continue
} }
// Retrieve the parent block and it's state to execute on top // Retrieve the parent block and it's state to execute on top
start := time.Now() start := time.Now()
parent := it.previous() parent := it.previous()
if parent == nil { if parent == nil {
parent = bc.GetHeader(block.ParentHash(), block.NumberU64()-1) parent = bc.GetHeader(block.ParentHash(), block.NumberU64()-1)
@ -1802,6 +1833,7 @@ func (bc *BlockChain) insertChain(chain types.Blocks, verifySeals bool) (int, er
if err != nil { if err != nil {
return it.index, err return it.index, err
} }
// Enable prefetching to pull in trie node paths while processing transactions // Enable prefetching to pull in trie node paths while processing transactions
statedb.StartPrefetcher("chain") statedb.StartPrefetcher("chain")
activeState = statedb activeState = statedb
@ -1823,6 +1855,7 @@ func (bc *BlockChain) insertChain(chain types.Blocks, verifySeals bool) (int, er
}(time.Now(), followup, throwaway, &followupInterrupt) }(time.Now(), followup, throwaway, &followupInterrupt)
} }
} }
// Process block using the parent state as reference point // Process block using the parent state as reference point
substart := time.Now() substart := time.Now()
receipts, logs, usedGas, err := bc.processor.Process(block, statedb, bc.vmConfig) receipts, logs, usedGas, err := bc.processor.Process(block, statedb, bc.vmConfig)
@ -1831,6 +1864,7 @@ func (bc *BlockChain) insertChain(chain types.Blocks, verifySeals bool) (int, er
atomic.StoreUint32(&followupInterrupt, 1) atomic.StoreUint32(&followupInterrupt, 1)
return it.index, err return it.index, err
} }
// Update the metrics touched during block processing // Update the metrics touched during block processing
accountReadTimer.Update(statedb.AccountReads) // Account reads are complete, we can mark them accountReadTimer.Update(statedb.AccountReads) // Account reads are complete, we can mark them
storageReadTimer.Update(statedb.StorageReads) // Storage reads are complete, we can mark them storageReadTimer.Update(statedb.StorageReads) // Storage reads are complete, we can mark them
@ -1906,6 +1940,7 @@ func (bc *BlockChain) insertChain(chain types.Blocks, verifySeals bool) (int, er
dirty, _ := bc.stateCache.TrieDB().Size() dirty, _ := bc.stateCache.TrieDB().Size()
stats.report(chain, it.index, dirty) stats.report(chain, it.index, dirty)
} }
// Any blocks remaining here? The only ones we care about are the future ones // Any blocks remaining here? The only ones we care about are the future ones
if block != nil && errors.Is(err, consensus.ErrFutureBlock) { if block != nil && errors.Is(err, consensus.ErrFutureBlock) {
if err := bc.addFutureBlock(block); err != nil { if err := bc.addFutureBlock(block); err != nil {
@ -2215,7 +2250,10 @@ func (bc *BlockChain) reorg(oldBlock, newBlock *types.Block) error {
return nil return nil
} }
func (bc *BlockChain) update() { // futureBlocksLoop processes the 'future block' queue.
func (bc *BlockChain) futureBlocksLoop() {
defer bc.wg.Done()
futureTimer := time.NewTicker(5 * time.Second) futureTimer := time.NewTicker(5 * time.Second)
defer futureTimer.Stop() defer futureTimer.Stop()
for { for {
@ -2252,6 +2290,7 @@ func (bc *BlockChain) maintainTxIndex(ancients uint64) {
} }
rawdb.IndexTransactions(bc.db, from, ancients, bc.quit) rawdb.IndexTransactions(bc.db, from, ancients, bc.quit)
} }
// indexBlocks reindexes or unindexes transactions depending on user configuration // indexBlocks reindexes or unindexes transactions depending on user configuration
indexBlocks := func(tail *uint64, head uint64, done chan struct{}) { indexBlocks := func(tail *uint64, head uint64, done chan struct{}) {
defer func() { done <- struct{}{} }() defer func() { done <- struct{}{} }()
@ -2284,6 +2323,7 @@ func (bc *BlockChain) maintainTxIndex(ancients uint64) {
rawdb.UnindexTransactions(bc.db, *tail, head-bc.txLookupLimit+1, bc.quit) rawdb.UnindexTransactions(bc.db, *tail, head-bc.txLookupLimit+1, bc.quit)
} }
} }
// Any reindexing done, start listening to chain events and moving the index window // Any reindexing done, start listening to chain events and moving the index window
var ( var (
done chan struct{} // Non-nil if background unindexing or reindexing routine is active. done chan struct{} // Non-nil if background unindexing or reindexing routine is active.
@ -2351,12 +2391,10 @@ func (bc *BlockChain) InsertHeaderChain(chain []*types.Header, checkFreq int) (i
return i, err return i, err
} }
// Make sure only one thread manipulates the chain at once if !bc.chainmu.TryLock() {
bc.chainmu.Lock() return 0, errChainStopped
}
defer bc.chainmu.Unlock() defer bc.chainmu.Unlock()
bc.wg.Add(1)
defer bc.wg.Done()
_, err := bc.hc.InsertHeaderChain(chain, start) _, err := bc.hc.InsertHeaderChain(chain, start)
return 0, err return 0, err
} }

@ -163,7 +163,8 @@ func testBlockChainImport(chain types.Blocks, blockchain *BlockChain) error {
blockchain.reportBlock(block, receipts, err) blockchain.reportBlock(block, receipts, err)
return err return err
} }
blockchain.chainmu.Lock()
blockchain.chainmu.MustLock()
rawdb.WriteTd(blockchain.db, block.Hash(), block.NumberU64(), new(big.Int).Add(block.Difficulty(), blockchain.GetTdByHash(block.ParentHash()))) rawdb.WriteTd(blockchain.db, block.Hash(), block.NumberU64(), new(big.Int).Add(block.Difficulty(), blockchain.GetTdByHash(block.ParentHash())))
rawdb.WriteBlock(blockchain.db, block) rawdb.WriteBlock(blockchain.db, block)
statedb.Commit(false) statedb.Commit(false)
@ -181,7 +182,7 @@ func testHeaderChainImport(chain []*types.Header, blockchain *BlockChain) error
return err return err
} }
// Manually insert the header into the database, but don't reorganise (allows subsequent testing) // Manually insert the header into the database, but don't reorganise (allows subsequent testing)
blockchain.chainmu.Lock() blockchain.chainmu.MustLock()
rawdb.WriteTd(blockchain.db, header.Hash(), header.Number.Uint64(), new(big.Int).Add(header.Difficulty, blockchain.GetTdByHash(header.ParentHash))) rawdb.WriteTd(blockchain.db, header.Hash(), header.Number.Uint64(), new(big.Int).Add(header.Difficulty, blockchain.GetTdByHash(header.ParentHash)))
rawdb.WriteHeader(blockchain.db, header) rawdb.WriteHeader(blockchain.db, header)
blockchain.chainmu.Unlock() blockchain.chainmu.Unlock()

64
internal/syncx/mutex.go Normal file

@ -0,0 +1,64 @@
// Copyright 2021 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
// Package syncx contains exotic synchronization primitives.
package syncx
// ClosableMutex is a mutex that can also be closed.
// Once closed, it can never be taken again.
type ClosableMutex struct {
ch chan struct{}
}
func NewClosableMutex() *ClosableMutex {
ch := make(chan struct{}, 1)
ch <- struct{}{}
return &ClosableMutex{ch}
}
// TryLock attempts to lock cm.
// If the mutex is closed, TryLock returns false.
func (cm *ClosableMutex) TryLock() bool {
_, ok := <-cm.ch
return ok
}
// MustLock locks cm.
// If the mutex is closed, MustLock panics.
func (cm *ClosableMutex) MustLock() {
_, ok := <-cm.ch
if !ok {
panic("mutex closed")
}
}
// Unlock unlocks cm.
func (cm *ClosableMutex) Unlock() {
select {
case cm.ch <- struct{}{}:
default:
panic("Unlock of already-unlocked ClosableMutex")
}
}
// Close locks the mutex, then closes it.
func (cm *ClosableMutex) Close() {
_, ok := <-cm.ch
if !ok {
panic("Close of already-closed ClosableMutex")
}
close(cm.ch)
}