diff --git a/trie/hasher.go b/trie/hasher.go index 57e156ebf..b6223bf32 100644 --- a/trie/hasher.go +++ b/trie/hasher.go @@ -75,23 +75,20 @@ func (h *hasher) hash(n node, db DatabaseWriter, force bool) (node, node, error) if err != nil { return hashNode{}, n, err } - // Cache the hash of the ndoe for later reuse. - if hash, ok := hashed.(hashNode); ok && !force { - switch cached := cached.(type) { - case *shortNode: - cached = cached.copy() - cached.flags.hash = hash - if db != nil { - cached.flags.dirty = false - } - return hashed, cached, nil - case *fullNode: - cached = cached.copy() - cached.flags.hash = hash - if db != nil { - cached.flags.dirty = false - } - return hashed, cached, nil + // Cache the hash of the ndoe for later reuse and remove + // the dirty flag in commit mode. It's fine to assign these values directly + // without copying the node first because hashChildren copies it. + cachedHash, _ := hashed.(hashNode) + switch cn := cached.(type) { + case *shortNode: + cn.flags.hash = cachedHash + if db != nil { + cn.flags.dirty = false + } + case *fullNode: + cn.flags.hash = cachedHash + if db != nil { + cn.flags.dirty = false } } return hashed, cached, nil diff --git a/trie/trie_test.go b/trie/trie_test.go index da0d2360b..14ac5a666 100644 --- a/trie/trie_test.go +++ b/trie/trie_test.go @@ -462,31 +462,44 @@ func runRandTest(rt randTest) bool { return false } case opCheckCacheInvariant: - return checkCacheInvariant(tr.root, tr.cachegen, 0) + return checkCacheInvariant(tr.root, nil, tr.cachegen, false, 0) } } return true } -func checkCacheInvariant(n node, parentCachegen uint16, depth int) bool { +func checkCacheInvariant(n, parent node, parentCachegen uint16, parentDirty bool, depth int) bool { + var children []node + var flag nodeFlag switch n := n.(type) { case *shortNode: - if n.flags.gen > parentCachegen { - fmt.Printf("cache invariant violation: %d > %d\nat depth %d node %s", n.flags.gen, parentCachegen, depth, spew.Sdump(n)) - return false - } - return checkCacheInvariant(n.Val, n.flags.gen, depth+1) + flag = n.flags + children = []node{n.Val} case *fullNode: - if n.flags.gen > parentCachegen { - fmt.Printf("cache invariant violation: %d > %d\nat depth %d node %s", n.flags.gen, parentCachegen, depth, spew.Sdump(n)) + flag = n.flags + children = n.Children[:] + default: + return true + } + + showerror := func() { + fmt.Printf("at depth %d node %s", depth, spew.Sdump(n)) + fmt.Printf("parent: %s", spew.Sdump(parent)) + } + if flag.gen > parentCachegen { + fmt.Printf("cache invariant violation: %d > %d\n", flag.gen, parentCachegen) + showerror() + return false + } + if depth > 0 && !parentDirty && flag.dirty { + fmt.Printf("cache invariant violation: child is dirty but parent isn't\n") + showerror() + return false + } + for _, child := range children { + if !checkCacheInvariant(child, n, flag.gen, flag.dirty, depth+1) { return false } - for _, child := range n.Children { - if !checkCacheInvariant(child, n.flags.gen, depth+1) { - return false - } - } - return true } return true }