diff --git a/core/state/statedb_test.go b/core/state/statedb_test.go index 5f58a234c..bfb081191 100644 --- a/core/state/statedb_test.go +++ b/core/state/statedb_test.go @@ -25,10 +25,11 @@ import ( "math/rand" "reflect" "strings" + "sync" "testing" "testing/quick" - check "gopkg.in/check.v1" + "gopkg.in/check.v1" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/rawdb" @@ -53,8 +54,13 @@ func TestUpdateLeaks(t *testing.T) { if i%3 == 0 { state.SetCode(addr, []byte{i, i, i, i, i}) } - state.IntermediateRoot(false) } + + root := state.IntermediateRoot(false) + if err := state.Database().TrieDB().Commit(root, false); err != nil { + t.Errorf("can not commit trie %v to persistent database", root.Hex()) + } + // Ensure that no data was leaked into the database it := db.NewIterator() for it.Next() { @@ -98,27 +104,45 @@ func TestIntermediateLeaks(t *testing.T) { } // Commit and cross check the databases. - if _, err := transState.Commit(false); err != nil { + transRoot, err := transState.Commit(false) + if err != nil { t.Fatalf("failed to commit transition state: %v", err) } - if _, err := finalState.Commit(false); err != nil { + if err = transState.Database().TrieDB().Commit(transRoot, false); err != nil { + t.Errorf("can not commit trie %v to persistent database", transRoot.Hex()) + } + + finalRoot, err := finalState.Commit(false) + if err != nil { t.Fatalf("failed to commit final state: %v", err) } + if err = finalState.Database().TrieDB().Commit(finalRoot, false); err != nil { + t.Errorf("can not commit trie %v to persistent database", finalRoot.Hex()) + } + it := finalDb.NewIterator() for it.Next() { - key := it.Key() - if _, err := transDb.Get(key); err != nil { - t.Errorf("entry missing from the transition database: %x -> %x", key, it.Value()) + key, fvalue := it.Key(), it.Value() + tvalue, err := transDb.Get(key) + if err != nil { + t.Errorf("entry missing from the transition database: %x -> %x", key, fvalue) + } + if !bytes.Equal(fvalue, tvalue) { + t.Errorf("the value associate key %x is mismatch,: %x in transition database ,%x in final database", key, tvalue, fvalue) } } it.Release() it = transDb.NewIterator() for it.Next() { - key := it.Key() - if _, err := finalDb.Get(key); err != nil { + key, tvalue := it.Key(), it.Value() + fvalue, err := finalDb.Get(key) + if err != nil { t.Errorf("extra entry in the transition database: %x -> %x", key, it.Value()) } + if !bytes.Equal(fvalue, tvalue) { + t.Errorf("the value associate key %x is mismatch,: %x in transition database ,%x in final database", key, tvalue, fvalue) + } } } @@ -136,32 +160,45 @@ func TestCopy(t *testing.T) { } orig.Finalise(false) - // Copy the state, modify both in-memory + // Copy the state copy := orig.Copy() + // Copy the copy state + ccopy := copy.Copy() + + // modify all in memory for i := byte(0); i < 255; i++ { origObj := orig.GetOrNewStateObject(common.BytesToAddress([]byte{i})) copyObj := copy.GetOrNewStateObject(common.BytesToAddress([]byte{i})) + ccopyObj := ccopy.GetOrNewStateObject(common.BytesToAddress([]byte{i})) origObj.AddBalance(big.NewInt(2 * int64(i))) copyObj.AddBalance(big.NewInt(3 * int64(i))) + ccopyObj.AddBalance(big.NewInt(4 * int64(i))) orig.updateStateObject(origObj) copy.updateStateObject(copyObj) + ccopy.updateStateObject(copyObj) } - // Finalise the changes on both concurrently - done := make(chan struct{}) - go func() { - orig.Finalise(true) - close(done) - }() - copy.Finalise(true) - <-done - // Verify that the two states have been updated independently + // Finalise the changes on all concurrently + finalise := func(wg *sync.WaitGroup, db *StateDB) { + defer wg.Done() + db.Finalise(true) + } + + var wg sync.WaitGroup + wg.Add(3) + go finalise(&wg, orig) + go finalise(&wg, copy) + go finalise(&wg, ccopy) + wg.Wait() + + // Verify that the three states have been updated independently for i := byte(0); i < 255; i++ { origObj := orig.GetOrNewStateObject(common.BytesToAddress([]byte{i})) copyObj := copy.GetOrNewStateObject(common.BytesToAddress([]byte{i})) + ccopyObj := ccopy.GetOrNewStateObject(common.BytesToAddress([]byte{i})) if want := big.NewInt(3 * int64(i)); origObj.Balance().Cmp(want) != 0 { t.Errorf("orig obj %d: balance mismatch: have %v, want %v", i, origObj.Balance(), want) @@ -169,6 +206,9 @@ func TestCopy(t *testing.T) { if want := big.NewInt(4 * int64(i)); copyObj.Balance().Cmp(want) != 0 { t.Errorf("copy obj %d: balance mismatch: have %v, want %v", i, copyObj.Balance(), want) } + if want := big.NewInt(5 * int64(i)); ccopyObj.Balance().Cmp(want) != 0 { + t.Errorf("copy obj %d: balance mismatch: have %v, want %v", i, ccopyObj.Balance(), want) + } } }