bsc/trie/trie.go

353 lines
7.0 KiB
Go
Raw Normal View History

2014-10-31 15:45:03 +02:00
package trie
2014-02-15 00:56:09 +02:00
import (
2014-07-02 18:47:18 +03:00
"bytes"
2015-01-08 12:47:04 +02:00
"container/list"
2014-02-15 00:56:09 +02:00
"fmt"
2014-05-21 02:12:28 +03:00
"sync"
2014-08-04 11:38:18 +03:00
2015-03-16 12:27:38 +02:00
"github.com/ethereum/go-ethereum/common"
2015-03-19 11:57:02 +02:00
"github.com/ethereum/go-ethereum/crypto"
2014-02-15 00:56:09 +02:00
)
2015-01-08 12:47:04 +02:00
func ParanoiaCheck(t1 *Trie, backend Backend) (bool, *Trie) {
t2 := New(nil, backend)
2014-07-02 18:47:18 +03:00
2015-01-08 12:47:04 +02:00
it := t1.Iterator()
for it.Next() {
t2.Update(it.Key, it.Value)
}
2014-02-15 00:56:09 +02:00
2015-01-08 12:47:04 +02:00
return bytes.Equal(t2.Hash(), t1.Hash()), t2
2014-02-15 00:56:09 +02:00
}
2015-01-08 12:47:04 +02:00
type Trie struct {
mu sync.Mutex
root Node
roothash []byte
cache *Cache
2014-02-15 00:56:09 +02:00
2015-01-08 12:47:04 +02:00
revisions *list.List
2014-02-15 00:56:09 +02:00
}
2015-01-08 12:47:04 +02:00
func New(root []byte, backend Backend) *Trie {
trie := &Trie{}
trie.revisions = list.New()
trie.roothash = root
if backend != nil {
trie.cache = NewCache(backend)
}
2014-02-15 00:56:09 +02:00
2015-01-08 12:47:04 +02:00
if root != nil {
2015-03-16 12:27:38 +02:00
value := common.NewValueFromBytes(trie.cache.Get(root))
2015-01-08 12:47:04 +02:00
trie.root = trie.mknode(value)
2014-02-15 00:56:09 +02:00
}
2015-01-08 12:47:04 +02:00
return trie
2014-06-30 14:08:00 +03:00
}
2015-01-08 12:47:04 +02:00
func (self *Trie) Iterator() *Iterator {
return NewIterator(self)
2014-02-15 00:56:09 +02:00
}
2015-01-08 12:47:04 +02:00
func (self *Trie) Copy() *Trie {
cpy := make([]byte, 32)
copy(cpy, self.roothash)
trie := New(nil, nil)
trie.cache = self.cache.Copy()
if self.root != nil {
trie.root = self.root.Copy(trie)
}
return trie
2014-02-24 13:11:00 +02:00
}
2015-01-08 12:47:04 +02:00
// Legacy support
func (self *Trie) Root() []byte { return self.Hash() }
func (self *Trie) Hash() []byte {
var hash []byte
if self.root != nil {
t := self.root.Hash()
if byts, ok := t.([]byte); ok && len(byts) > 0 {
hash = byts
} else {
2015-03-16 12:27:38 +02:00
hash = crypto.Sha3(common.Encode(self.root.RlpData()))
2014-02-15 00:56:09 +02:00
}
2015-01-08 12:47:04 +02:00
} else {
2015-03-16 12:27:38 +02:00
hash = crypto.Sha3(common.Encode(""))
2014-02-15 00:56:09 +02:00
}
2015-01-08 12:47:04 +02:00
if !bytes.Equal(hash, self.roothash) {
self.revisions.PushBack(self.roothash)
self.roothash = hash
2014-02-15 00:56:09 +02:00
}
2015-01-08 12:47:04 +02:00
return hash
2014-02-15 00:56:09 +02:00
}
2015-01-08 12:47:04 +02:00
func (self *Trie) Commit() {
self.mu.Lock()
defer self.mu.Unlock()
2014-02-15 00:56:09 +02:00
2015-01-08 12:47:04 +02:00
// Hash first
self.Hash()
2014-04-29 13:36:27 +03:00
2015-01-08 12:47:04 +02:00
self.cache.Flush()
2014-04-29 13:36:27 +03:00
}
2015-01-08 12:47:04 +02:00
// Reset should only be called if the trie has been hashed
func (self *Trie) Reset() {
self.mu.Lock()
defer self.mu.Unlock()
2014-04-29 13:36:27 +03:00
2015-01-08 12:47:04 +02:00
self.cache.Reset()
2014-10-29 15:20:42 +02:00
2015-01-08 12:47:04 +02:00
if self.revisions.Len() > 0 {
revision := self.revisions.Remove(self.revisions.Back()).([]byte)
self.roothash = revision
2014-10-29 15:20:42 +02:00
}
2015-03-16 12:27:38 +02:00
value := common.NewValueFromBytes(self.cache.Get(self.roothash))
2015-01-08 12:47:04 +02:00
self.root = self.mknode(value)
2014-02-15 00:56:09 +02:00
}
2015-01-08 12:47:04 +02:00
func (self *Trie) UpdateString(key, value string) Node { return self.Update([]byte(key), []byte(value)) }
func (self *Trie) Update(key, value []byte) Node {
self.mu.Lock()
defer self.mu.Unlock()
2014-05-21 02:12:28 +03:00
2015-01-08 12:47:04 +02:00
k := CompactHexDecode(string(key))
2014-02-15 00:56:09 +02:00
2015-01-08 12:47:04 +02:00
if len(value) != 0 {
self.root = self.insert(self.root, k, &ValueNode{self, value})
} else {
2015-01-08 12:47:04 +02:00
self.root = self.delete(self.root, k)
}
2014-02-15 00:56:09 +02:00
2015-01-08 12:47:04 +02:00
return self.root
2014-02-15 00:56:09 +02:00
}
2015-01-08 12:47:04 +02:00
func (self *Trie) GetString(key string) []byte { return self.Get([]byte(key)) }
func (self *Trie) Get(key []byte) []byte {
self.mu.Lock()
defer self.mu.Unlock()
2014-02-20 15:40:00 +02:00
2015-01-08 12:47:04 +02:00
k := CompactHexDecode(string(key))
2014-07-02 14:40:02 +03:00
2015-01-08 12:47:04 +02:00
n := self.get(self.root, k)
if n != nil {
return n.(*ValueNode).Val()
2014-07-02 14:40:02 +03:00
}
2015-01-08 12:47:04 +02:00
return nil
2014-07-02 14:40:02 +03:00
}
2015-01-08 12:47:04 +02:00
func (self *Trie) DeleteString(key string) Node { return self.Delete([]byte(key)) }
func (self *Trie) Delete(key []byte) Node {
self.mu.Lock()
defer self.mu.Unlock()
2014-07-02 14:40:02 +03:00
2015-01-08 12:47:04 +02:00
k := CompactHexDecode(string(key))
self.root = self.delete(self.root, k)
2014-07-02 14:40:02 +03:00
2015-01-08 12:47:04 +02:00
return self.root
2014-07-02 14:40:02 +03:00
}
2015-01-08 12:47:04 +02:00
func (self *Trie) insert(node Node, key []byte, value Node) Node {
if len(key) == 0 {
return value
2014-02-15 00:56:09 +02:00
}
2015-01-08 12:47:04 +02:00
if node == nil {
return NewShortNode(self, key, value)
}
2014-02-15 00:56:09 +02:00
2015-01-08 12:47:04 +02:00
switch node := node.(type) {
case *ShortNode:
k := node.Key()
cnode := node.Value()
if bytes.Equal(k, key) {
return NewShortNode(self, key, value)
}
2014-02-15 00:56:09 +02:00
2015-01-08 12:47:04 +02:00
var n Node
matchlength := MatchingNibbleLength(key, k)
if matchlength == len(k) {
n = self.insert(cnode, key[matchlength:], value)
2014-02-15 00:56:09 +02:00
} else {
2015-01-08 12:47:04 +02:00
pnode := self.insert(nil, k[matchlength+1:], cnode)
nnode := self.insert(nil, key[matchlength+1:], value)
fulln := NewFullNode(self)
fulln.set(k[matchlength], pnode)
fulln.set(key[matchlength], nnode)
n = fulln
}
if matchlength == 0 {
return n
2014-02-15 00:56:09 +02:00
}
2015-01-08 12:47:04 +02:00
return NewShortNode(self, key[:matchlength], n)
2014-02-15 00:56:09 +02:00
2015-01-08 12:47:04 +02:00
case *FullNode:
cpy := node.Copy(self).(*FullNode)
2015-01-08 12:47:04 +02:00
cpy.set(key[0], self.insert(node.branch(key[0]), key[1:], value))
2014-02-15 00:56:09 +02:00
2015-01-08 12:47:04 +02:00
return cpy
2014-02-15 00:56:09 +02:00
2015-01-08 12:47:04 +02:00
default:
panic(fmt.Sprintf("%T: invalid node: %v", node, node))
2014-02-15 00:56:09 +02:00
}
}
2015-01-08 12:47:04 +02:00
func (self *Trie) get(node Node, key []byte) Node {
2014-02-15 00:56:09 +02:00
if len(key) == 0 {
2015-01-08 12:47:04 +02:00
return node
2014-02-15 00:56:09 +02:00
}
2015-01-08 12:47:04 +02:00
if node == nil {
return nil
2014-02-15 00:56:09 +02:00
}
2015-01-08 12:47:04 +02:00
switch node := node.(type) {
case *ShortNode:
k := node.Key()
cnode := node.Value()
2014-02-15 00:56:09 +02:00
2015-01-08 12:47:04 +02:00
if len(key) >= len(k) && bytes.Equal(k, key[:len(k)]) {
return self.get(cnode, key[len(k):])
2014-02-15 00:56:09 +02:00
}
2015-01-08 12:47:04 +02:00
return nil
case *FullNode:
return self.get(node.branch(key[0]), key[1:])
default:
panic(fmt.Sprintf("%T: invalid node: %v", node, node))
2014-02-15 00:56:09 +02:00
}
}
2015-01-08 12:47:04 +02:00
func (self *Trie) delete(node Node, key []byte) Node {
if len(key) == 0 && node == nil {
return nil
2014-02-20 15:40:00 +02:00
}
2015-01-08 12:47:04 +02:00
switch node := node.(type) {
case *ShortNode:
k := node.Key()
cnode := node.Value()
if bytes.Equal(key, k) {
return nil
} else if bytes.Equal(key[:len(k)], k) {
child := self.delete(cnode, key[len(k):])
var n Node
switch child := child.(type) {
case *ShortNode:
nkey := append(k, child.Key()...)
n = NewShortNode(self, nkey, child.Value())
case *FullNode:
sn := NewShortNode(self, node.Key(), child)
sn.key = node.key
n = sn
2014-02-20 15:40:00 +02:00
}
2015-01-08 12:47:04 +02:00
return n
2014-02-20 15:40:00 +02:00
} else {
return node
}
2015-01-08 12:47:04 +02:00
case *FullNode:
n := node.Copy(self).(*FullNode)
2015-01-08 12:47:04 +02:00
n.set(key[0], self.delete(n.branch(key[0]), key[1:]))
2014-02-20 15:40:00 +02:00
2015-01-08 12:47:04 +02:00
pos := -1
2014-02-20 15:40:00 +02:00
for i := 0; i < 17; i++ {
2015-01-08 12:47:04 +02:00
if n.branch(byte(i)) != nil {
if pos == -1 {
pos = i
2014-02-20 15:40:00 +02:00
} else {
2015-01-08 12:47:04 +02:00
pos = -2
2014-02-20 15:40:00 +02:00
}
}
}
2015-01-08 12:47:04 +02:00
var nnode Node
if pos == 16 {
nnode = NewShortNode(self, []byte{16}, n.branch(byte(pos)))
} else if pos >= 0 {
cnode := n.branch(byte(pos))
switch cnode := cnode.(type) {
case *ShortNode:
// Stitch keys
k := append([]byte{byte(pos)}, cnode.Key()...)
nnode = NewShortNode(self, k, cnode.Value())
case *FullNode:
nnode = NewShortNode(self, []byte{byte(pos)}, n.branch(byte(pos)))
}
2014-02-20 15:40:00 +02:00
} else {
2015-01-08 12:47:04 +02:00
nnode = n
2014-02-20 15:40:00 +02:00
}
2015-01-08 12:47:04 +02:00
return nnode
case nil:
return nil
default:
panic(fmt.Sprintf("%T: invalid node: %v (%v)", node, node, key))
2014-02-20 15:40:00 +02:00
}
2014-02-24 13:11:00 +02:00
}
2015-01-08 12:47:04 +02:00
// casting functions and cache storing
2015-03-16 12:27:38 +02:00
func (self *Trie) mknode(value *common.Value) Node {
2015-01-08 12:47:04 +02:00
l := value.Len()
switch l {
case 0:
return nil
case 2:
// A value node may consists of 2 bytes.
if value.Get(0).Len() != 0 {
2015-03-19 15:31:14 +02:00
key := CompactDecode(string(value.Get(0).Bytes()))
if key[len(key)-1] == 16 {
return NewShortNode(self, key, &ValueNode{self, value.Get(1).Bytes()})
} else {
return NewShortNode(self, key, self.mknode(value.Get(1)))
}
2014-02-24 13:11:00 +02:00
}
2015-01-08 12:47:04 +02:00
case 17:
2015-03-19 11:57:02 +02:00
if len(value.Bytes()) != 17 {
fnode := NewFullNode(self)
2015-03-19 15:31:14 +02:00
for i := 0; i < 16; i++ {
2015-03-19 11:57:02 +02:00
fnode.set(byte(i), self.mknode(value.Get(i)))
}
return fnode
2014-02-24 13:11:00 +02:00
}
2015-01-08 12:47:04 +02:00
case 32:
return &HashNode{value.Bytes(), self}
2014-02-24 13:11:00 +02:00
}
2015-01-08 12:47:04 +02:00
return &ValueNode{self, value.Bytes()}
2014-02-24 13:11:00 +02:00
}
2015-01-08 12:47:04 +02:00
func (self *Trie) trans(node Node) Node {
switch node := node.(type) {
case *HashNode:
2015-03-16 12:27:38 +02:00
value := common.NewValueFromBytes(self.cache.Get(node.key))
2015-01-08 12:47:04 +02:00
return self.mknode(value)
default:
return node
2014-02-24 13:11:00 +02:00
}
}
2015-01-08 12:47:04 +02:00
func (self *Trie) store(node Node) interface{} {
2015-03-16 12:27:38 +02:00
data := common.Encode(node)
2015-01-08 12:47:04 +02:00
if len(data) >= 32 {
key := crypto.Sha3(data)
self.cache.Put(key, data)
2014-05-27 02:08:51 +03:00
2015-01-08 12:47:04 +02:00
return key
}
2014-05-27 02:08:51 +03:00
2015-01-08 12:47:04 +02:00
return node.RlpData()
2014-05-27 02:08:51 +03:00
}
2015-01-08 12:47:04 +02:00
func (self *Trie) PrintRoot() {
fmt.Println(self.root)
fmt.Printf("root=%x\n", self.Root())
2014-05-27 02:08:51 +03:00
}