edwards: remove affine Point, Signature. Stricter types

This commit is contained in:
Paul Miller 2023-01-24 04:34:56 +00:00
parent bfbcf733e6
commit a49f0d266e
No known key found for this signature in database
GPG Key ID: 697079DA6878B89B

@ -60,19 +60,8 @@ function validateOpts(curve: CurveType) {
return Object.freeze({ ...opts } as const); return Object.freeze({ ...opts } as const);
} }
// Instance // 2d point in XY coords
export interface SignatureType { export interface AffinePoint { x: bigint; y: bigint };
readonly r: PointType;
readonly s: bigint;
assertValidity(): SignatureType;
toRawBytes(): Uint8Array;
toHex(): string;
}
// Static methods
export type SignatureConstructor = {
new (r: PointType, s: bigint): SignatureType;
fromHex(hex: Hex): SignatureType;
};
// Instance of Extended Point with coordinates in X, Y, Z, T // Instance of Extended Point with coordinates in X, Y, Z, T
export interface ExtendedPointType extends Group<ExtendedPointType> { export interface ExtendedPointType extends Group<ExtendedPointType> {
@ -80,56 +69,36 @@ export interface ExtendedPointType extends Group<ExtendedPointType> {
readonly y: bigint; readonly y: bigint;
readonly z: bigint; readonly z: bigint;
readonly t: bigint; readonly t: bigint;
multiply(scalar: bigint, affinePoint?: PointType): ExtendedPointType; multiply(scalar: bigint): ExtendedPointType;
multiplyUnsafe(scalar: bigint): ExtendedPointType; multiplyUnsafe(scalar: bigint): ExtendedPointType;
isSmallOrder(): boolean; isSmallOrder(): boolean;
isTorsionFree(): boolean; isTorsionFree(): boolean;
toAffine(invZ?: bigint): PointType; toAffine(invZ?: bigint): AffinePoint;
clearCofactor(): ExtendedPointType; clearCofactor(): ExtendedPointType;
} }
// Static methods of Extended Point with coordinates in X, Y, Z, T // Static methods of Extended Point with coordinates in X, Y, Z, T
export interface ExtendedPointConstructor extends GroupConstructor<ExtendedPointType> { export interface ExtendedPointConstructor extends GroupConstructor<ExtendedPointType> {
new (x: bigint, y: bigint, z: bigint, t: bigint): ExtendedPointType; new (x: bigint, y: bigint, z: bigint, t: bigint): ExtendedPointType;
fromAffine(p: PointType): ExtendedPointType; fromAffine(p: AffinePoint): ExtendedPointType;
toAffineBatch(points: ExtendedPointType[]): PointType[]; toAffineBatch(points: ExtendedPointType[]): AffinePoint[];
normalizeZ(points: ExtendedPointType[]): ExtendedPointType[]; fromHex(hex: Hex): ExtendedPointType;
fromPrivateKey(privateKey: PrivKey): ExtendedPointType; // TODO: remove
} }
// Instance of Affine Point with coordinates in X, Y
export interface PointType extends Group<PointType> {
readonly x: bigint;
readonly y: bigint;
_setWindowSize(windowSize: number): void;
toRawBytes(isCompressed?: boolean): Uint8Array;
toHex(isCompressed?: boolean): string;
isTorsionFree(): boolean;
clearCofactor(): PointType;
}
// Static methods of Affine Point with coordinates in X, Y
export interface PointConstructor extends GroupConstructor<PointType> {
new (x: bigint, y: bigint): PointType;
fromHex(hex: Hex): PointType;
fromPrivateKey(privateKey: PrivKey): PointType;
}
export type PubKey = Hex | PointType;
export type SigType = Hex | SignatureType;
export type CurveFn = { export type CurveFn = {
CURVE: ReturnType<typeof validateOpts>; CURVE: ReturnType<typeof validateOpts>;
getPublicKey: (privateKey: PrivKey, isCompressed?: boolean) => Uint8Array; getPublicKey: (privateKey: PrivKey, isCompressed?: boolean) => Uint8Array;
sign: (message: Hex, privateKey: Hex) => Uint8Array; sign: (message: Hex, privateKey: Hex) => Uint8Array;
verify: (sig: SigType, message: Hex, publicKey: PubKey) => boolean; verify: (sig: Hex, message: Hex, publicKey: Hex) => boolean;
Point: PointConstructor; // Point: PointConstructor;
ExtendedPoint: ExtendedPointConstructor; ExtendedPoint: ExtendedPointConstructor;
Signature: SignatureConstructor;
utils: { utils: {
randomPrivateKey: () => Uint8Array; randomPrivateKey: () => Uint8Array;
getExtendedPublicKey: (key: PrivKey) => { getExtendedPublicKey: (key: PrivKey) => {
head: Uint8Array; head: Uint8Array;
prefix: Uint8Array; prefix: Uint8Array;
scalar: bigint; scalar: bigint;
point: PointType; point: ExtendedPointType;
pointBytes: Uint8Array; pointBytes: Uint8Array;
}; };
}; };
@ -199,23 +168,21 @@ export function twistedEdwards(curveDef: CurveType): CurveFn {
static BASE = new ExtendedPoint(CURVE.Gx, CURVE.Gy); static BASE = new ExtendedPoint(CURVE.Gx, CURVE.Gy);
static ZERO = new ExtendedPoint(_0n, _1n); // 0, 1, 1, 0 static ZERO = new ExtendedPoint(_0n, _1n); // 0, 1, 1, 0
static fromAffine(p: Point): ExtendedPoint { static fromAffine(p: AffinePoint): ExtendedPoint {
if (!(p instanceof Point)) { if (!(p && typeof p === 'object' && typeof p.x === 'bigint' && typeof p.y === 'bigint'))
throw new TypeError('ExtendedPoint#fromAffine: expected Point'); throw new Error('fromAffine error');
} if (p.x === 0n && p.y === 1n) return ExtendedPoint.ZERO;
if (p.equals(Point.ZERO)) return ExtendedPoint.ZERO;
return new ExtendedPoint(p.x, p.y); return new ExtendedPoint(p.x, p.y);
} }
// Takes a bunch of Jacobian Points but executes only one // Takes a bunch of Jacobian Points but executes only one
// invert on all of them. invert is very slow operation, // invert on all of them. invert is very slow operation,
// so this improves performance massively. // so this improves performance massively.
static toAffineBatch(points: ExtendedPoint[]): Point[] { static toAffineBatch(points: ExtendedPoint[]): AffinePoint[] {
const toInv = Fp.invertBatch(points.map((p) => p.z)); const toInv = Fp.invertBatch(points.map((p) => p.z));
return points.map((p, i) => p.toAffine(toInv[i])); return points.map((p, i) => p.toAffine(toInv[i]));
} }
static normalizeZ(denorm: ExtendedPoint[]): ExtendedPoint[] {
static normalizeZ(points: ExtendedPoint[]): ExtendedPoint[] { return ExtendedPoint.toAffineBatch(denorm).map(ExtendedPoint.fromAffine);
return this.toAffineBatch(points).map(this.fromAffine);
} }
// Compare one point to another. // Compare one point to another.
@ -305,26 +272,11 @@ export function twistedEdwards(curveDef: CurveType): CurveFn {
return this.add(other.negate()); return this.add(other.negate());
} }
private wNAF(n: bigint, affinePoint?: Point): ExtendedPoint {
if (!affinePoint && this.equals(ExtendedPoint.BASE)) affinePoint = Point.BASE;
const W = (affinePoint && affinePoint._WINDOW_SIZE) || 1;
let precomputes = affinePoint && pointPrecomputes.get(affinePoint);
if (!precomputes) {
precomputes = wnaf.precomputeWindow(this, W) as ExtendedPoint[];
if (affinePoint && W !== 1) {
precomputes = ExtendedPoint.normalizeZ(precomputes);
pointPrecomputes.set(affinePoint, precomputes);
}
}
const { p, f } = wnaf.wNAF(W, precomputes, n);
return ExtendedPoint.normalizeZ([p, f])[0];
}
// Constant time multiplication. // Constant time multiplication.
// Uses wNAF method. Windowed method may be 10% faster, // Uses wNAF method. Windowed method may be 10% faster,
// but takes 2x longer to generate and consumes 2x memory. // but takes 2x longer to generate and consumes 2x memory.
multiply(scalar: bigint, affinePoint?: Point): ExtendedPoint { multiply(scalar: bigint): ExtendedPoint {
return this.wNAF(assertGE(scalar), affinePoint); return wNAF_TMP_FN(this, assertGE(scalar));
} }
// Non-constant-time multiplication. Uses double-and-add algorithm. // Non-constant-time multiplication. Uses double-and-add algorithm.
@ -332,11 +284,9 @@ export function twistedEdwards(curveDef: CurveType): CurveFn {
// an exposed private key e.g. sig verification. // an exposed private key e.g. sig verification.
multiplyUnsafe(scalar: bigint): ExtendedPoint { multiplyUnsafe(scalar: bigint): ExtendedPoint {
let n = assertGE0(scalar); let n = assertGE0(scalar);
const G = ExtendedPoint.BASE; if (n === _0n) return I;
const P0 = ExtendedPoint.ZERO; if (this.equals(I) || n === _1n) return this;
if (n === _0n) return P0; if (this.equals(G)) return wNAF_TMP_FN(this, n);
if (this.equals(P0) || n === _1n) return this;
if (this.equals(G)) return this.wNAF(n);
return wnaf.unsafeLadder(this, n); return wnaf.unsafeLadder(this, n);
} }
@ -356,53 +306,22 @@ export function twistedEdwards(curveDef: CurveType): CurveFn {
// Converts Extended point to default (x, y) coordinates. // Converts Extended point to default (x, y) coordinates.
// Can accept precomputed Z^-1 - for example, from invertBatch. // Can accept precomputed Z^-1 - for example, from invertBatch.
toAffine(invZ?: bigint): Point { toAffine(iz?: bigint): AffinePoint {
const { x, y, z } = this; const { x, y, z } = this;
const is0 = this.equals(ExtendedPoint.ZERO); const is0 = this.equals(ExtendedPoint.ZERO);
if (invZ == null) invZ = is0 ? _8n : (Fp.invert(z) as bigint); // 8 was chosen arbitrarily if (iz == null) iz = is0 ? _8n : (Fp.invert(z) as bigint); // 8 was chosen arbitrarily
const ax = modP(x * invZ); const ax = modP(x * iz);
const ay = modP(y * invZ); const ay = modP(y * iz);
const zz = modP(z * invZ); const zz = modP(z * iz);
if (is0) return Point.ZERO; if (is0) return { x: _0n, y: _1n };
if (zz !== _1n) throw new Error('invZ was invalid'); if (zz !== _1n) throw new Error('invZ was invalid');
return new Point(ax, ay); return { x: ax, y: ay };
} }
clearCofactor(): ExtendedPoint { clearCofactor(): ExtendedPoint {
const { h: cofactor } = CURVE; const { h: cofactor } = CURVE;
if (cofactor === _1n) return this; if (cofactor === _1n) return this;
return this.multiplyUnsafe(cofactor); return this.multiplyUnsafe(cofactor);
} }
}
const wnaf = wNAF(ExtendedPoint, CURVE.nByteLength * 8);
function assertExtPoint(other: unknown) {
if (!(other instanceof ExtendedPoint)) throw new TypeError('ExtendedPoint expected');
}
// Stores precomputed values for points.
const pointPrecomputes = new WeakMap<Point, ExtendedPoint[]>();
/**
* Default Point works in affine coordinates: (x, y)
*/
class Point implements PointType {
// Base point aka generator
// public_key = Point.BASE * private_key
static BASE: Point = new Point(CURVE.Gx, CURVE.Gy);
// Identity point aka point at infinity
// point = point + zero_point
static ZERO: Point = new Point(_0n, _1n);
// We calculate precomputes for elliptic curve point multiplication
// using windowed method. This specifies window size and
// stores precomputed values. Usually only base point would be precomputed.
_WINDOW_SIZE?: number;
constructor(readonly x: bigint, readonly y: bigint) {}
// "Private method", don't use it directly.
_setWindowSize(windowSize: number) {
this._WINDOW_SIZE = windowSize;
pointPrecomputes.delete(this);
}
// Converts hash string or Uint8Array to Point. // Converts hash string or Uint8Array to Point.
// Uses algo from RFC8032 5.1.3. // Uses algo from RFC8032 5.1.3.
@ -448,100 +367,46 @@ export function twistedEdwards(curveDef: CurveType): CurveFn {
const isXOdd = (x & _1n) === _1n; const isXOdd = (x & _1n) === _1n;
const isLastByteOdd = (lastByte & 0x80) !== 0; const isLastByteOdd = (lastByte & 0x80) !== 0;
if (isLastByteOdd !== isXOdd) x = modP(-x); if (isLastByteOdd !== isXOdd) x = modP(-x);
return new Point(x, y); return new ExtendedPoint(x, y);
} }
static fromPrivateKey(privateKey: PrivKey) { static fromPrivateKey(privateKey: PrivKey) {
return getExtendedPublicKey(privateKey).point; return getExtendedPublicKey(privateKey).point;
} }
// There can always be only two x values (x, -x) for any y
// When compressing point, it's enough to only store its y coordinate
// and use the last byte to encode sign of x.
toRawBytes(): Uint8Array { toRawBytes(): Uint8Array {
const bytes = ut.numberToBytesLE(this.y, Fp.BYTES); const { x, y } = this.toAffine();
bytes[Fp.BYTES - 1] |= this.x & _1n ? 0x80 : 0;
return bytes;
}
// Same as toRawBytes, but returns string.
toHex(): string {
return ut.bytesToHex(this.toRawBytes());
}
// Determines if point is in prime-order subgroup.
// Returns `false` is the point is dirty.
isTorsionFree(): boolean {
return ExtendedPoint.fromAffine(this).isTorsionFree();
}
equals(other: Point): boolean {
if (!(other instanceof Point)) throw new TypeError('Point#equals: expected Point');
return this.x === other.x && this.y === other.y;
}
negate(): Point {
return new Point(modP(-this.x), this.y);
}
double(): Point {
return ExtendedPoint.fromAffine(this).double().toAffine();
}
add(other: Point) {
return ExtendedPoint.fromAffine(this).add(ExtendedPoint.fromAffine(other)).toAffine();
}
subtract(other: Point) {
return this.add(other.negate());
}
/**
* Constant time multiplication.
* @param scalar Big-Endian number
* @returns new point
*/
multiply(scalar: bigint): Point {
return ExtendedPoint.fromAffine(this).multiply(scalar, this).toAffine();
}
clearCofactor() {
return ExtendedPoint.fromAffine(this).clearCofactor().toAffine();
}
}
/**
* EDDSA signature.
*/
class Signature implements SignatureType {
constructor(readonly r: Point, readonly s: bigint) {
this.assertValidity();
}
static fromHex(hex: Hex) {
const len = Fp.BYTES; const len = Fp.BYTES;
const bytes = ensureBytes(hex, 2 * len); const bytes = ut.numberToBytesLE(y, len); // each y has 2 x values (x, -y)
const r = Point.fromHex(bytes.slice(0, len), false); bytes[len - 1] |= x & _1n ? 0x80 : 0; // when compressing, it's enough to store y
const s = ut.bytesToNumberLE(bytes.slice(len, 2 * len)); return bytes; // and use the last byte to encode sign of x
return new Signature(r, s);
} }
toHex(): string {
assertValidity() { return ut.bytesToHex(this.toRawBytes()); // Same as toRawBytes, but returns string.
const { r, s } = this;
if (!(r instanceof Point)) throw new Error('Expected Point instance');
assertGE0(s); // 0 <= s < l
return this;
}
toRawBytes() {
return ut.concatBytes(this.r.toRawBytes(), ut.numberToBytesLE(this.s, Fp.BYTES));
}
toHex() {
return ut.bytesToHex(this.toRawBytes());
} }
} }
const { BASE: G, ZERO: I } = ExtendedPoint;
let Gpows: ExtendedPoint[] | undefined = undefined; // precomputes for base point G
const wnaf = wNAF(ExtendedPoint, CURVE.nByteLength * 8);
function wNAF_TMP_FN(P: ExtendedPoint, n: bigint): ExtendedPoint {
if (P.equals(G)) {
const W = 8;
if (!Gpows) {
const denorm = wnaf.precomputeWindow(P, W) as ExtendedPoint[];
const norm = ExtendedPoint.toAffineBatch(denorm).map(ExtendedPoint.fromAffine);
Gpows = norm;
}
const comp = Gpows;
const { p, f } = wnaf.wNAF(W, comp, n);
return ExtendedPoint.normalizeZ([p, f])[0];
}
const W = 1;
const denorm = wnaf.precomputeWindow(P, W) as ExtendedPoint[];
const norm = ExtendedPoint.toAffineBatch(denorm).map(ExtendedPoint.fromAffine);
const { p, f } = wnaf.wNAF(W, norm, n);
return ExtendedPoint.normalizeZ([p, f])[0];
}
function assertExtPoint(other: unknown) {
if (!(other instanceof ExtendedPoint)) throw new TypeError('ExtendedPoint expected');
}
// Little-endian SHA512 with modulo n // Little-endian SHA512 with modulo n
function modnLE(hash: Uint8Array): bigint { function modnLE(hash: Uint8Array): bigint {
return mod.mod(ut.bytesToNumberLE(hash), CURVE_ORDER); return mod.mod(ut.bytesToNumberLE(hash), CURVE_ORDER);
@ -563,7 +428,7 @@ export function twistedEdwards(curveDef: CurveType): CurveFn {
// The actual private scalar // The actual private scalar
const scalar = modnLE(head); const scalar = modnLE(head);
// Point on Edwards curve aka public key // Point on Edwards curve aka public key
const point = Point.BASE.multiply(scalar); const point = G.multiply(scalar);
// Uint8Array representation // Uint8Array representation
const pointBytes = point.toRawBytes(); const pointBytes = point.toRawBytes();
return { head, prefix, scalar, point, pointBytes }; return { head, prefix, scalar, point, pointBytes };
@ -590,10 +455,11 @@ export function twistedEdwards(curveDef: CurveType): CurveFn {
if (CURVE.preHash) message = CURVE.preHash(message); if (CURVE.preHash) message = CURVE.preHash(message);
const { prefix, scalar, pointBytes } = getExtendedPublicKey(privateKey); const { prefix, scalar, pointBytes } = getExtendedPublicKey(privateKey);
const r = hashDomainToScalar(ut.concatBytes(prefix, message), context); const r = hashDomainToScalar(ut.concatBytes(prefix, message), context);
const R = Point.BASE.multiply(r); // R = rG const R = G.multiply(r); // R = rG
const k = hashDomainToScalar(ut.concatBytes(R.toRawBytes(), pointBytes, message), context); // k = hash(R+P+msg) const k = hashDomainToScalar(ut.concatBytes(R.toRawBytes(), pointBytes, message), context); // k = hash(R+P+msg)
const s = mod.mod(r + k * scalar, CURVE_ORDER); // s = r + kp const s = mod.mod(r + k * scalar, CURVE_ORDER); // s = r + kp
return new Signature(R, s).toRawBytes(); assertGE0(s); // 0 <= s < l
return ut.concatBytes(R.toRawBytes(), ut.numberToBytesLE(s, Fp.BYTES));
} }
/** /**
@ -605,40 +471,24 @@ export function twistedEdwards(curveDef: CurveType): CurveFn {
* 0 <= sig.s < l * 0 <= sig.s < l
* Not compliant with RFC8032: it's not possible to comply to both ZIP & RFC at the same time. * Not compliant with RFC8032: it's not possible to comply to both ZIP & RFC at the same time.
*/ */
function verify(sig: SigType, message: Hex, publicKey: PubKey, context?: Hex): boolean { function verify(sig: Hex, message: Hex, publicKey: Hex, context?: Hex): boolean {
const len = Fp.BYTES;
sig = ensureBytes(sig, 2 * len);
message = ensureBytes(message); message = ensureBytes(message);
if (CURVE.preHash) message = CURVE.preHash(message); if (CURVE.preHash) message = CURVE.preHash(message);
// When hex is passed, we check public key fully. const R = ExtendedPoint.fromHex(sig.slice(0, len), false); // non-strict; allows 0..MASK
// When Point instance is passed, we assume it has already been checked, for performance. const s = ut.bytesToNumberLE(sig.slice(len, 2 * len));
// If user passes Point/Sig instance, we assume it has been already verified. const A = ExtendedPoint.fromHex(publicKey, false); // Check for s bounds, hex validity
// We don't check its equations for performance. We do check for valid bounds for s though const SB = G.multiplyUnsafe(s);
// We always check for: a) s bounds. b) hex validity const k = hashDomainToScalar(ut.concatBytes(R.toRawBytes(), A.toRawBytes(), message), context);
if (publicKey instanceof Point) { const kA = A.multiplyUnsafe(k);
// ignore const RkA = R.add(kA);
} else if (publicKey instanceof Uint8Array || typeof publicKey === 'string') {
publicKey = Point.fromHex(publicKey, false);
} else {
throw new Error(`Invalid publicKey: ${publicKey}`);
}
if (sig instanceof Signature) sig.assertValidity();
else if (sig instanceof Uint8Array || typeof sig === 'string') sig = Signature.fromHex(sig);
else throw new Error(`Wrong signature: ${sig}`);
const { r, s } = sig;
const SB = ExtendedPoint.BASE.multiplyUnsafe(s);
const k = hashDomainToScalar(
ut.concatBytes(r.toRawBytes(), publicKey.toRawBytes(), message),
context
);
const kA = ExtendedPoint.fromAffine(publicKey).multiplyUnsafe(k);
const RkA = ExtendedPoint.fromAffine(r).add(kA);
// [8][S]B = [8]R + [8][k]A' // [8][S]B = [8]R + [8][k]A'
return RkA.subtract(SB).clearCofactor().equals(ExtendedPoint.ZERO); return RkA.subtract(SB).clearCofactor().equals(ExtendedPoint.ZERO);
} }
// Enable precomputes. Slows down first publicKey computation by 20ms. // Enable precomputes. Slows down first publicKey computation by 20ms.
Point.BASE._setWindowSize(8); // G._setWindowSize(8);
const utils = { const utils = {
getExtendedPublicKey, getExtendedPublicKey,
@ -659,11 +509,12 @@ export function twistedEdwards(curveDef: CurveType): CurveFn {
* but allows to speed-up subsequent getPublicKey() calls up to 20x. * but allows to speed-up subsequent getPublicKey() calls up to 20x.
* @param windowSize 2, 4, 8, 16 * @param windowSize 2, 4, 8, 16
*/ */
precompute(windowSize = 8, point = Point.BASE): Point { precompute(windowSize = 8, point = G): ExtendedPoint {
const cached = point.equals(Point.BASE) ? point : new Point(point.x, point.y); return G.multiply(2n);
cached._setWindowSize(windowSize); // const cached = point.equals(Point.BASE) ? point : new Point(point.x, point.y);
cached.multiply(_2n); // cached._setWindowSize(windowSize);
return cached; // cached.multiply(_2n);
// return cached;
}, },
}; };
@ -673,8 +524,7 @@ export function twistedEdwards(curveDef: CurveType): CurveFn {
sign, sign,
verify, verify,
ExtendedPoint, ExtendedPoint,
Point, // Point: ExtendedPoint,
Signature,
utils, utils,
}; };
} }