Merge pull request #2 from Gustav-Simonsson/correct_ecies_shared_key_generation

Correct ECIES shared key length check
This commit is contained in:
Jeffrey Wilcke 2015-02-12 00:10:00 +01:00
commit 04c1a81509

@ -13,11 +13,12 @@ import (
) )
var ( var (
ErrImport = fmt.Errorf("ecies: failed to import key") ErrImport = fmt.Errorf("ecies: failed to import key")
ErrInvalidCurve = fmt.Errorf("ecies: invalid elliptic curve") ErrInvalidCurve = fmt.Errorf("ecies: invalid elliptic curve")
ErrInvalidParams = fmt.Errorf("ecies: invalid ECIES parameters") ErrInvalidParams = fmt.Errorf("ecies: invalid ECIES parameters")
ErrInvalidPublicKey = fmt.Errorf("ecies: invalid public key") ErrInvalidPublicKey = fmt.Errorf("ecies: invalid public key")
ErrSharedKeyTooBig = fmt.Errorf("ecies: shared key is too big") ErrSharedKeyIsPointAtInfinity = fmt.Errorf("ecies: shared key is point at infinity")
ErrSharedKeyTooBig = fmt.Errorf("ecies: shared key params are too big")
) )
// PublicKey is a representation of an elliptic curve public key. // PublicKey is a representation of an elliptic curve public key.
@ -90,16 +91,20 @@ func MaxSharedKeyLength(pub *PublicKey) int {
// ECDH key agreement method used to establish secret keys for encryption. // ECDH key agreement method used to establish secret keys for encryption.
func (prv *PrivateKey) GenerateShared(pub *PublicKey, skLen, macLen int) (sk []byte, err error) { func (prv *PrivateKey) GenerateShared(pub *PublicKey, skLen, macLen int) (sk []byte, err error) {
if prv.PublicKey.Curve != pub.Curve { if prv.PublicKey.Curve != pub.Curve {
err = ErrInvalidCurve return nil, ErrInvalidCurve
return }
if skLen+macLen > MaxSharedKeyLength(pub) {
return nil, ErrSharedKeyTooBig
} }
x, _ := pub.Curve.ScalarMult(pub.X, pub.Y, prv.D.Bytes()) x, _ := pub.Curve.ScalarMult(pub.X, pub.Y, prv.D.Bytes())
if x == nil || (x.BitLen()+7)/8 < (skLen+macLen) { if x == nil {
err = ErrSharedKeyTooBig return nil, ErrSharedKeyIsPointAtInfinity
return
} }
sk = x.Bytes()[:skLen+macLen]
return sk = make([]byte, skLen+macLen)
skBytes := x.Bytes()
copy(sk[len(sk)-len(skBytes):], skBytes)
return sk, nil
} }
var ( var (