diff --git a/crypto/key_store_passphrase.go b/crypto/key_store_passphrase.go index b7ae9e1de9..1e8d5509bd 100644 --- a/crypto/key_store_passphrase.go +++ b/crypto/key_store_passphrase.go @@ -75,15 +75,7 @@ func (ks keyStorePassphrase) GenerateNewKey(rand io.Reader, auth string) (key *K } func (ks keyStorePassphrase) GetKey(keyAddr common.Address, auth string) (key *Key, err error) { - keyBytes, keyId, err := decryptKeyFromFile(ks.keysDirPath, keyAddr, auth) - if err == nil { - key = &Key{ - Id: uuid.UUID(keyId), - Address: keyAddr, - PrivateKey: ToECDSA(keyBytes), - } - } - return + return decryptKeyFromFile(ks.keysDirPath, keyAddr, auth) } func (ks keyStorePassphrase) Cleanup(keyAddr common.Address) (err error) { @@ -145,39 +137,58 @@ func (ks keyStorePassphrase) StoreKey(key *Key, auth string) (err error) { return writeKeyFile(key.Address, ks.keysDirPath, keyJSON) } -func (ks keyStorePassphrase) DeleteKey(keyAddr common.Address, auth string) (err error) { +func (ks keyStorePassphrase) DeleteKey(keyAddr common.Address, auth string) error { // only delete if correct passphrase is given - _, _, err = decryptKeyFromFile(ks.keysDirPath, keyAddr, auth) - if err != nil { + if _, err := decryptKeyFromFile(ks.keysDirPath, keyAddr, auth); err != nil { return err } - return deleteKey(ks.keysDirPath, keyAddr) } -func decryptKeyFromFile(keysDirPath string, keyAddr common.Address, auth string) (keyBytes []byte, keyId []byte, err error) { +// DecryptKey decrypts a key from a json blob, returning the private key itself. +func DecryptKey(keyjson []byte, auth string) (*Key, error) { + // Parse the json into a simple map to fetch the key version m := make(map[string]interface{}) - err = getKey(keysDirPath, keyAddr, &m) - if err != nil { - return + if err := json.Unmarshal(keyjson, &m); err != nil { + return nil, err } - + // Depending on the version try to parse one way or another + var ( + keyBytes, keyId []byte + err error + ) v := reflect.ValueOf(m["version"]) if v.Kind() == reflect.String && v.String() == "1" { k := new(encryptedKeyJSONV1) - err = getKey(keysDirPath, keyAddr, &k) - if err != nil { - return + if err := json.Unmarshal(keyjson, k); err != nil { + return nil, err } - return decryptKeyV1(k, auth) + keyBytes, keyId, err = decryptKeyV1(k, auth) } else { k := new(encryptedKeyJSONV3) - err = getKey(keysDirPath, keyAddr, &k) - if err != nil { - return + if err := json.Unmarshal(keyjson, k); err != nil { + return nil, err } - return decryptKeyV3(k, auth) + keyBytes, keyId, err = decryptKeyV3(k, auth) } + // Handle any decryption errors and return the key + if err != nil { + return nil, err + } + key := ToECDSA(keyBytes) + return &Key{ + Id: uuid.UUID(keyId), + Address: PubkeyToAddress(key.PublicKey), + PrivateKey: key, + }, nil +} + +func decryptKeyFromFile(keysDirPath string, keyAddr common.Address, auth string) (key *Key, err error) { + keyjson, err := getKeyFile(keysDirPath, keyAddr) + if err != nil { + return nil, err + } + return DecryptKey(keyjson, auth) } func decryptKeyV3(keyProtected *encryptedKeyJSONV3, auth string) (keyBytes []byte, keyId []byte, err error) { diff --git a/crypto/key_store_plain.go b/crypto/key_store_plain.go index c1c23f8b8b..4ce789a30e 100644 --- a/crypto/key_store_plain.go +++ b/crypto/key_store_plain.go @@ -62,18 +62,16 @@ func GenerateNewKeyDefault(ks KeyStore, rand io.Reader, auth string) (key *Key, return key, err } -func (ks keyStorePlain) GetKey(keyAddr common.Address, auth string) (key *Key, err error) { - key = new(Key) - err = getKey(ks.keysDirPath, keyAddr, key) - return -} - -func getKey(keysDirPath string, keyAddr common.Address, content interface{}) (err error) { - fileContent, err := getKeyFile(keysDirPath, keyAddr) +func (ks keyStorePlain) GetKey(keyAddr common.Address, auth string) (*Key, error) { + keyjson, err := getKeyFile(ks.keysDirPath, keyAddr) if err != nil { - return + return nil, err } - return json.Unmarshal(fileContent, content) + key := new(Key) + if err := json.Unmarshal(keyjson, key); err != nil { + return nil, err + } + return key, nil } func (ks keyStorePlain) GetKeyAddresses() (addresses []common.Address, err error) {