diff --git a/ptrie/fullnode.go b/ptrie/fullnode.go index 2b1a627891..eaa4611b66 100644 --- a/ptrie/fullnode.go +++ b/ptrie/fullnode.go @@ -14,6 +14,9 @@ func (self *FullNode) Value() Node { self.nodes[16] = self.trie.trans(self.nodes[16]) return self.nodes[16] } +func (self *FullNode) Branches() []Node { + return self.nodes[:16] +} func (self *FullNode) Copy() Node { return self } @@ -49,7 +52,7 @@ func (self *FullNode) set(k byte, value Node) { self.nodes[int(k)] = value } -func (self *FullNode) get(i byte) Node { +func (self *FullNode) branch(i byte) Node { if self.nodes[int(i)] != nil { self.nodes[int(i)] = self.trie.trans(self.nodes[int(i)]) diff --git a/ptrie/iterator.go b/ptrie/iterator.go new file mode 100644 index 0000000000..c6d4f64a01 --- /dev/null +++ b/ptrie/iterator.go @@ -0,0 +1,114 @@ +package ptrie + +import ( + "bytes" + + "github.com/ethereum/go-ethereum/trie" +) + +type Iterator struct { + trie *Trie + + Key []byte + Value []byte +} + +func NewIterator(trie *Trie) *Iterator { + return &Iterator{trie: trie, Key: []byte{0}} +} + +func (self *Iterator) Next() bool { + self.trie.mu.Lock() + defer self.trie.mu.Unlock() + + key := trie.RemTerm(trie.CompactHexDecode(string(self.Key))) + k := self.next(self.trie.root, key) + + self.Key = []byte(trie.DecodeCompact(k)) + + return len(k) > 0 + +} + +func (self *Iterator) next(node Node, key []byte) []byte { + if node == nil { + return nil + } + + switch node := node.(type) { + case *FullNode: + if len(key) > 0 { + k := self.next(node.branch(key[0]), key[1:]) + if k != nil { + return append([]byte{key[0]}, k...) + } + } + + var r byte + if len(key) > 0 { + r = key[0] + 1 + } + + for i := r; i < 16; i++ { + k := self.key(node.branch(byte(i))) + if k != nil { + return append([]byte{i}, k...) + } + } + + case *ShortNode: + k := trie.RemTerm(node.Key()) + if vnode, ok := node.Value().(*ValueNode); ok { + if bytes.Compare([]byte(k), key) > 0 { + self.Value = vnode.Val() + return k + } + } else { + cnode := node.Value() + skey := key[len(k):] + + var ret []byte + if trie.BeginsWith(key, k) { + ret = self.next(cnode, skey) + } else if bytes.Compare(k, key[:len(k)]) > 0 { + ret = self.key(node) + } + + if ret != nil { + return append(k, ret...) + } + } + } + + return nil +} + +func (self *Iterator) key(node Node) []byte { + switch node := node.(type) { + case *ShortNode: + // Leaf node + if vnode, ok := node.Value().(*ValueNode); ok { + k := trie.RemTerm(node.Key()) + self.Value = vnode.Val() + + return k + } else { + return self.key(node.Value()) + } + case *FullNode: + if node.Value() != nil { + self.Value = node.Value().(*ValueNode).Val() + + return []byte{16} + } + + for i := 0; i < 16; i++ { + k := self.key(node.branch(byte(i))) + if k != nil { + return append([]byte{byte(i)}, k...) + } + } + } + + return nil +} diff --git a/ptrie/iterator_test.go b/ptrie/iterator_test.go new file mode 100644 index 0000000000..8921bb6708 --- /dev/null +++ b/ptrie/iterator_test.go @@ -0,0 +1,28 @@ +package ptrie + +import "testing" + +func TestIterator(t *testing.T) { + trie := NewEmpty() + vals := []struct{ k, v string }{ + {"do", "verb"}, + {"ether", "wookiedoo"}, + {"horse", "stallion"}, + } + v := make(map[string]bool) + for _, val := range vals { + v[val.k] = false + trie.UpdateString(val.k, val.v) + } + + it := trie.Iterator() + for it.Next() { + v[string(it.Key)] = true + } + + for k, found := range v { + if !found { + t.Error("iterator didn't find", k) + } + } +} diff --git a/ptrie/trie.go b/ptrie/trie.go index 207aad91e2..bb2b3845ad 100644 --- a/ptrie/trie.go +++ b/ptrie/trie.go @@ -45,6 +45,10 @@ func New(root []byte, backend Backend) *Trie { return trie } +func (self *Trie) Iterator() *Iterator { + return NewIterator(self) +} + // Legacy support func (self *Trie) Root() []byte { return self.Hash() } func (self *Trie) Hash() []byte { @@ -144,7 +148,7 @@ func (self *Trie) insert(node Node, key []byte, value Node) Node { case *FullNode: cpy := node.Copy().(*FullNode) - cpy.set(key[0], self.insert(node.get(key[0]), key[1:], value)) + cpy.set(key[0], self.insert(node.branch(key[0]), key[1:], value)) return cpy @@ -173,7 +177,7 @@ func (self *Trie) get(node Node, key []byte) Node { return nil case *FullNode: - return self.get(node.get(key[0]), key[1:]) + return self.get(node.branch(key[0]), key[1:]) default: panic("Invalid node") } @@ -209,11 +213,11 @@ func (self *Trie) delete(node Node, key []byte) Node { case *FullNode: n := node.Copy().(*FullNode) - n.set(key[0], self.delete(n.get(key[0]), key[1:])) + n.set(key[0], self.delete(n.branch(key[0]), key[1:])) pos := -1 for i := 0; i < 17; i++ { - if n.get(byte(i)) != nil { + if n.branch(byte(i)) != nil { if pos == -1 { pos = i } else { @@ -224,16 +228,16 @@ func (self *Trie) delete(node Node, key []byte) Node { var nnode Node if pos == 16 { - nnode = NewShortNode(self, []byte{16}, n.get(byte(pos))) + nnode = NewShortNode(self, []byte{16}, n.branch(byte(pos))) } else if pos >= 0 { - cnode := n.get(byte(pos)) + cnode := n.branch(byte(pos)) switch cnode := cnode.(type) { case *ShortNode: // Stitch keys k := append([]byte{byte(pos)}, cnode.Key()...) nnode = NewShortNode(self, k, cnode.Value()) case *FullNode: - nnode = NewShortNode(self, []byte{byte(pos)}, n.get(byte(pos))) + nnode = NewShortNode(self, []byte{byte(pos)}, n.branch(byte(pos))) } } else { nnode = n diff --git a/ptrie/trie_test.go b/ptrie/trie_test.go index 6cdd2bde49..6af6e1b406 100644 --- a/ptrie/trie_test.go +++ b/ptrie/trie_test.go @@ -139,6 +139,8 @@ func BenchmarkUpdate(b *testing.B) { // Not actual test func TestOutput(t *testing.T) { + t.Skip() + base := "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" trie := NewEmpty() for i := 0; i < 50; i++ { diff --git a/state/state_object.go b/state/state_object.go index 729e32ae40..f02d1b5abc 100644 --- a/state/state_object.go +++ b/state/state_object.go @@ -148,9 +148,7 @@ func (self *StateObject) EachStorage(cb trie.EachCallback) { func (self *StateObject) Sync() { for key, value := range self.storage { - if value.Len() == 0 { // value.BigInt().Cmp(ethutil.Big0) == 0 { - //data := self.getStorage([]byte(key)) - //fmt.Printf("deleting %x %x 0x%x\n", self.Address(), []byte(key), data) + if value.Len() == 0 { self.State.Trie.Delete(string(key)) continue }