trie: improve node rlp decoding performance (#25357)

This avoids copying the input []byte while decoding trie nodes. In most
cases, particularly when the input slice is provided by the underlying
database, this optimization is safe to use.

For cases where the origin of the input slice is unclear, the copying version
is retained. The new code performs better even when the input must be
copied, because it is now only copied once in decodeNode.
This commit is contained in:
rjl493456442 2022-08-19 06:39:47 +08:00 committed by GitHub
parent cce7f08438
commit a1b8892384
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 158 additions and 7 deletions

@ -163,7 +163,10 @@ func (n *cachedNode) rlp() []byte {
// or by regenerating it from the rlp encoded blob. // or by regenerating it from the rlp encoded blob.
func (n *cachedNode) obj(hash common.Hash) node { func (n *cachedNode) obj(hash common.Hash) node {
if node, ok := n.node.(rawNode); ok { if node, ok := n.node.(rawNode); ok {
return mustDecodeNode(hash[:], node) // The raw-blob format nodes are loaded from either from
// clean cache or the database, they are all in their own
// copy and safe to use unsafe decoder.
return mustDecodeNodeUnsafe(hash[:], node)
} }
return expandNode(hash[:], n.node) return expandNode(hash[:], n.node)
} }
@ -346,7 +349,10 @@ func (db *Database) node(hash common.Hash) node {
if enc := db.cleans.Get(nil, hash[:]); enc != nil { if enc := db.cleans.Get(nil, hash[:]); enc != nil {
memcacheCleanHitMeter.Mark(1) memcacheCleanHitMeter.Mark(1)
memcacheCleanReadMeter.Mark(int64(len(enc))) memcacheCleanReadMeter.Mark(int64(len(enc)))
return mustDecodeNode(hash[:], enc)
// The returned value from cache is in its own copy,
// safe to use mustDecodeNodeUnsafe for decoding.
return mustDecodeNodeUnsafe(hash[:], enc)
} }
} }
// Retrieve the node from the dirty cache if available // Retrieve the node from the dirty cache if available
@ -371,7 +377,9 @@ func (db *Database) node(hash common.Hash) node {
memcacheCleanMissMeter.Mark(1) memcacheCleanMissMeter.Mark(1)
memcacheCleanWriteMeter.Mark(int64(len(enc))) memcacheCleanWriteMeter.Mark(int64(len(enc)))
} }
return mustDecodeNode(hash[:], enc) // The returned value from database is in its own copy,
// safe to use mustDecodeNodeUnsafe for decoding.
return mustDecodeNodeUnsafe(hash[:], enc)
} }
// Node retrieves an encoded cached trie node from memory. If it cannot be found // Node retrieves an encoded cached trie node from memory. If it cannot be found

@ -99,6 +99,7 @@ func (n valueNode) fstring(ind string) string {
return fmt.Sprintf("%x ", []byte(n)) return fmt.Sprintf("%x ", []byte(n))
} }
// mustDecodeNode is a wrapper of decodeNode and panic if any error is encountered.
func mustDecodeNode(hash, buf []byte) node { func mustDecodeNode(hash, buf []byte) node {
n, err := decodeNode(hash, buf) n, err := decodeNode(hash, buf)
if err != nil { if err != nil {
@ -107,8 +108,29 @@ func mustDecodeNode(hash, buf []byte) node {
return n return n
} }
// decodeNode parses the RLP encoding of a trie node. // mustDecodeNodeUnsafe is a wrapper of decodeNodeUnsafe and panic if any error is
// encountered.
func mustDecodeNodeUnsafe(hash, buf []byte) node {
n, err := decodeNodeUnsafe(hash, buf)
if err != nil {
panic(fmt.Sprintf("node %x: %v", hash, err))
}
return n
}
// decodeNode parses the RLP encoding of a trie node. It will deep-copy the passed
// byte slice for decoding, so it's safe to modify the byte slice afterwards. The-
// decode performance of this function is not optimal, but it is suitable for most
// scenarios with low performance requirements and hard to determine whether the
// byte slice be modified or not.
func decodeNode(hash, buf []byte) (node, error) { func decodeNode(hash, buf []byte) (node, error) {
return decodeNodeUnsafe(hash, common.CopyBytes(buf))
}
// decodeNodeUnsafe parses the RLP encoding of a trie node. The passed byte slice
// will be directly referenced by node without bytes deep copy, so the input MUST
// not be changed after.
func decodeNodeUnsafe(hash, buf []byte) (node, error) {
if len(buf) == 0 { if len(buf) == 0 {
return nil, io.ErrUnexpectedEOF return nil, io.ErrUnexpectedEOF
} }
@ -141,7 +163,7 @@ func decodeShort(hash, elems []byte) (node, error) {
if err != nil { if err != nil {
return nil, fmt.Errorf("invalid value node: %v", err) return nil, fmt.Errorf("invalid value node: %v", err)
} }
return &shortNode{key, append(valueNode{}, val...), flag}, nil return &shortNode{key, valueNode(val), flag}, nil
} }
r, _, err := decodeRef(rest) r, _, err := decodeRef(rest)
if err != nil { if err != nil {
@ -164,7 +186,7 @@ func decodeFull(hash, elems []byte) (*fullNode, error) {
return n, err return n, err
} }
if len(val) > 0 { if len(val) > 0 {
n.Children[16] = append(valueNode{}, val...) n.Children[16] = valueNode(val)
} }
return n, nil return n, nil
} }
@ -190,7 +212,7 @@ func decodeRef(buf []byte) (node, []byte, error) {
// empty node // empty node
return nil, rest, nil return nil, rest, nil
case kind == rlp.String && len(val) == 32: case kind == rlp.String && len(val) == 32:
return append(hashNode{}, val...), rest, nil return hashNode(val), rest, nil
default: default:
return nil, nil, fmt.Errorf("invalid RLP string size %d (want 0 or 32)", len(val)) return nil, nil, fmt.Errorf("invalid RLP string size %d (want 0 or 32)", len(val))
} }

@ -20,6 +20,7 @@ import (
"bytes" "bytes"
"testing" "testing"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/rlp"
) )
@ -92,3 +93,123 @@ func TestDecodeFullNode(t *testing.T) {
t.Fatalf("decode full node err: %v", err) t.Fatalf("decode full node err: %v", err)
} }
} }
// goos: darwin
// goarch: arm64
// pkg: github.com/ethereum/go-ethereum/trie
// BenchmarkEncodeShortNode
// BenchmarkEncodeShortNode-8 16878850 70.81 ns/op 48 B/op 1 allocs/op
func BenchmarkEncodeShortNode(b *testing.B) {
node := &shortNode{
Key: []byte{0x1, 0x2},
Val: hashNode(randBytes(32)),
}
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
nodeToBytes(node)
}
}
// goos: darwin
// goarch: arm64
// pkg: github.com/ethereum/go-ethereum/trie
// BenchmarkEncodeFullNode
// BenchmarkEncodeFullNode-8 4323273 284.4 ns/op 576 B/op 1 allocs/op
func BenchmarkEncodeFullNode(b *testing.B) {
node := &fullNode{}
for i := 0; i < 16; i++ {
node.Children[i] = hashNode(randBytes(32))
}
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
nodeToBytes(node)
}
}
// goos: darwin
// goarch: arm64
// pkg: github.com/ethereum/go-ethereum/trie
// BenchmarkDecodeShortNode
// BenchmarkDecodeShortNode-8 7925638 151.0 ns/op 157 B/op 4 allocs/op
func BenchmarkDecodeShortNode(b *testing.B) {
node := &shortNode{
Key: []byte{0x1, 0x2},
Val: hashNode(randBytes(32)),
}
blob := nodeToBytes(node)
hash := crypto.Keccak256(blob)
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
mustDecodeNode(hash, blob)
}
}
// goos: darwin
// goarch: arm64
// pkg: github.com/ethereum/go-ethereum/trie
// BenchmarkDecodeShortNodeUnsafe
// BenchmarkDecodeShortNodeUnsafe-8 9027476 128.6 ns/op 109 B/op 3 allocs/op
func BenchmarkDecodeShortNodeUnsafe(b *testing.B) {
node := &shortNode{
Key: []byte{0x1, 0x2},
Val: hashNode(randBytes(32)),
}
blob := nodeToBytes(node)
hash := crypto.Keccak256(blob)
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
mustDecodeNodeUnsafe(hash, blob)
}
}
// goos: darwin
// goarch: arm64
// pkg: github.com/ethereum/go-ethereum/trie
// BenchmarkDecodeFullNode
// BenchmarkDecodeFullNode-8 1597462 761.9 ns/op 1280 B/op 18 allocs/op
func BenchmarkDecodeFullNode(b *testing.B) {
node := &fullNode{}
for i := 0; i < 16; i++ {
node.Children[i] = hashNode(randBytes(32))
}
blob := nodeToBytes(node)
hash := crypto.Keccak256(blob)
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
mustDecodeNode(hash, blob)
}
}
// goos: darwin
// goarch: arm64
// pkg: github.com/ethereum/go-ethereum/trie
// BenchmarkDecodeFullNodeUnsafe
// BenchmarkDecodeFullNodeUnsafe-8 1789070 687.1 ns/op 704 B/op 17 allocs/op
func BenchmarkDecodeFullNodeUnsafe(b *testing.B) {
node := &fullNode{}
for i := 0; i < 16; i++ {
node.Children[i] = hashNode(randBytes(32))
}
blob := nodeToBytes(node)
hash := crypto.Keccak256(blob)
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
mustDecodeNodeUnsafe(hash, blob)
}
}