accounts/keystore: fix double import race (#20915)
* accounts/keystore: fix race in Import/ImportECDSA * accounts/keystore: added import/export tests * cmd/geth: improved TestAccountImport test * accounts/keystore: added import/export tests * accounts/keystore: fixed naming * accounts/keystore: fixed typo * accounts/keystore: use mutex instead of rwmutex * accounts: use errors instead of fmt
This commit is contained in:
parent
2ec7232191
commit
38aab0aa83
@ -24,7 +24,6 @@ import (
|
||||
"crypto/ecdsa"
|
||||
crand "crypto/rand"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"os"
|
||||
"path/filepath"
|
||||
@ -67,7 +66,8 @@ type KeyStore struct {
|
||||
updateScope event.SubscriptionScope // Subscription scope tracking current live listeners
|
||||
updating bool // Whether the event notification loop is running
|
||||
|
||||
mu sync.RWMutex
|
||||
mu sync.RWMutex
|
||||
importMu sync.Mutex // Import Mutex locks the import to prevent two insertions from racing
|
||||
}
|
||||
|
||||
type unlocked struct {
|
||||
@ -443,14 +443,21 @@ func (ks *KeyStore) Import(keyJSON []byte, passphrase, newPassphrase string) (ac
|
||||
if err != nil {
|
||||
return accounts.Account{}, err
|
||||
}
|
||||
ks.importMu.Lock()
|
||||
defer ks.importMu.Unlock()
|
||||
if ks.cache.hasAddress(key.Address) {
|
||||
return accounts.Account{}, errors.New("account already exists")
|
||||
}
|
||||
return ks.importKey(key, newPassphrase)
|
||||
}
|
||||
|
||||
// ImportECDSA stores the given key into the key directory, encrypting it with the passphrase.
|
||||
func (ks *KeyStore) ImportECDSA(priv *ecdsa.PrivateKey, passphrase string) (accounts.Account, error) {
|
||||
key := newKeyFromECDSA(priv)
|
||||
ks.importMu.Lock()
|
||||
defer ks.importMu.Unlock()
|
||||
if ks.cache.hasAddress(key.Address) {
|
||||
return accounts.Account{}, fmt.Errorf("account already exists")
|
||||
return accounts.Account{}, errors.New("account already exists")
|
||||
}
|
||||
return ks.importKey(key, passphrase)
|
||||
}
|
||||
|
@ -23,11 +23,14 @@ import (
|
||||
"runtime"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ethereum/go-ethereum/accounts"
|
||||
"github.com/ethereum/go-ethereum/common"
|
||||
"github.com/ethereum/go-ethereum/crypto"
|
||||
"github.com/ethereum/go-ethereum/event"
|
||||
)
|
||||
|
||||
@ -338,6 +341,88 @@ func TestWalletNotifications(t *testing.T) {
|
||||
checkEvents(t, wantEvents, events)
|
||||
}
|
||||
|
||||
// TestImportExport tests the import functionality of a keystore.
|
||||
func TestImportECDSA(t *testing.T) {
|
||||
dir, ks := tmpKeyStore(t, true)
|
||||
defer os.RemoveAll(dir)
|
||||
key, err := crypto.GenerateKey()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to generate key: %v", key)
|
||||
}
|
||||
if _, err = ks.ImportECDSA(key, "old"); err != nil {
|
||||
t.Errorf("importing failed: %v", err)
|
||||
}
|
||||
if _, err = ks.ImportECDSA(key, "old"); err == nil {
|
||||
t.Errorf("importing same key twice succeeded")
|
||||
}
|
||||
if _, err = ks.ImportECDSA(key, "new"); err == nil {
|
||||
t.Errorf("importing same key twice succeeded")
|
||||
}
|
||||
}
|
||||
|
||||
// TestImportECDSA tests the import and export functionality of a keystore.
|
||||
func TestImportExport(t *testing.T) {
|
||||
dir, ks := tmpKeyStore(t, true)
|
||||
defer os.RemoveAll(dir)
|
||||
acc, err := ks.NewAccount("old")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create account: %v", acc)
|
||||
}
|
||||
json, err := ks.Export(acc, "old", "new")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to export account: %v", acc)
|
||||
}
|
||||
dir2, ks2 := tmpKeyStore(t, true)
|
||||
defer os.RemoveAll(dir2)
|
||||
if _, err = ks2.Import(json, "old", "old"); err == nil {
|
||||
t.Errorf("importing with invalid password succeeded")
|
||||
}
|
||||
acc2, err := ks2.Import(json, "new", "new")
|
||||
if err != nil {
|
||||
t.Errorf("importing failed: %v", err)
|
||||
}
|
||||
if acc.Address != acc2.Address {
|
||||
t.Error("imported account does not match exported account")
|
||||
}
|
||||
if _, err = ks2.Import(json, "new", "new"); err == nil {
|
||||
t.Errorf("importing a key twice succeeded")
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// TestImportRace tests the keystore on races.
|
||||
// This test should fail under -race if importing races.
|
||||
func TestImportRace(t *testing.T) {
|
||||
dir, ks := tmpKeyStore(t, true)
|
||||
defer os.RemoveAll(dir)
|
||||
acc, err := ks.NewAccount("old")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create account: %v", acc)
|
||||
}
|
||||
json, err := ks.Export(acc, "old", "new")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to export account: %v", acc)
|
||||
}
|
||||
dir2, ks2 := tmpKeyStore(t, true)
|
||||
defer os.RemoveAll(dir2)
|
||||
var atom uint32
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
for i := 0; i < 2; i++ {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
if _, err := ks2.Import(json, "new", "new"); err != nil {
|
||||
atomic.AddUint32(&atom, 1)
|
||||
}
|
||||
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
if atom != 1 {
|
||||
t.Errorf("Import is racy")
|
||||
}
|
||||
}
|
||||
|
||||
// checkAccounts checks that all known live accounts are present in the wallet list.
|
||||
func checkAccounts(t *testing.T, live map[common.Address]accounts.Account, wallets []accounts.Wallet) {
|
||||
if len(live) != len(wallets) {
|
||||
|
@ -89,18 +89,23 @@ Path of the secret key file: .*UTC--.+--[0-9a-f]{40}
|
||||
}
|
||||
|
||||
func TestAccountImport(t *testing.T) {
|
||||
tests := []struct{ key, output string }{
|
||||
tests := []struct{ name, key, output string }{
|
||||
{
|
||||
name: "correct account",
|
||||
key: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef",
|
||||
output: "Address: {fcad0b19bb29d4674531d6f115237e16afce377c}\n",
|
||||
},
|
||||
{
|
||||
name: "invalid character",
|
||||
key: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef1",
|
||||
output: "Fatal: Failed to load the private key: invalid character '1' at end of key file\n",
|
||||
},
|
||||
}
|
||||
for _, test := range tests {
|
||||
importAccountWithExpect(t, test.key, test.output)
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
importAccountWithExpect(t, test.key, test.output)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user