mod, utils, weierstrass, secp: improve hex errors. secp: improve verify() logic and schnorr

This commit is contained in:
Paul Miller 2023-02-14 16:34:31 +00:00
parent 4d311d7294
commit e1910e85ea
No known key found for this signature in database
GPG Key ID: 697079DA6878B89B
4 changed files with 123 additions and 111 deletions

@ -407,7 +407,7 @@ export function hashToPrivateScalar(
groupOrder: bigint, groupOrder: bigint,
isLE = false isLE = false
): bigint { ): bigint {
hash = ensureBytes(hash); hash = ensureBytes('privateHash', hash);
const hashLen = hash.length; const hashLen = hash.length;
const minLen = nLength(groupOrder).nByteLength + 8; const minLen = nLength(groupOrder).nByteLength + 8;
if (minLen < 24 || hashLen < minLen || hashLen > 1024) if (minLen < 24 || hashLen < minLen || hashLen > 1024)

@ -33,14 +33,14 @@ export function numberToHexUnpadded(num: number | bigint): string {
} }
export function hexToNumber(hex: string): bigint { export function hexToNumber(hex: string): bigint {
if (typeof hex !== 'string') throw new Error('string expected, got ' + typeof hex); if (typeof hex !== 'string') throw new Error('hex string expected, got ' + typeof hex);
// Big Endian // Big Endian
return BigInt(hex === '' ? '0' : `0x${hex}`); return BigInt(hex === '' ? '0' : `0x${hex}`);
} }
// Caching slows it down 2-3x // Caching slows it down 2-3x
export function hexToBytes(hex: string): Uint8Array { export function hexToBytes(hex: string): Uint8Array {
if (typeof hex !== 'string') throw new Error('string expected, got ' + typeof hex); if (typeof hex !== 'string') throw new Error('hex string expected, got ' + typeof hex);
if (hex.length % 2) throw new Error('hex string is invalid: unpadded ' + hex.length); if (hex.length % 2) throw new Error('hex string is invalid: unpadded ' + hex.length);
const array = new Uint8Array(hex.length / 2); const array = new Uint8Array(hex.length / 2);
for (let i = 0; i < array.length; i++) { for (let i = 0; i < array.length; i++) {
@ -68,13 +68,25 @@ export const numberToBytesLE = (n: bigint, len: number) => numberToBytesBE(n, le
// Returns variable number bytes (minimal bigint encoding?) // Returns variable number bytes (minimal bigint encoding?)
export const numberToVarBytesBE = (n: bigint) => hexToBytes(numberToHexUnpadded(n)); export const numberToVarBytesBE = (n: bigint) => hexToBytes(numberToHexUnpadded(n));
export function ensureBytes(hex: Hex, expectedLength?: number): Uint8Array { export function ensureBytes(title: string, hex: Hex, expectedLength?: number): Uint8Array {
let res: Uint8Array;
if (typeof hex === 'string') {
try {
res = hexToBytes(hex);
} catch (e) {
throw new Error(`${title} must be valid hex string, got "${hex}". Cause: ${e}`);
}
} else if (u8a(hex)) {
// Uint8Array.from() instead of hash.slice() because node.js Buffer // Uint8Array.from() instead of hash.slice() because node.js Buffer
// is instance of Uint8Array, and its slice() creates **mutable** copy // is instance of Uint8Array, and its slice() creates **mutable** copy
const bytes = u8a(hex) ? Uint8Array.from(hex) : hexToBytes(hex); res = Uint8Array.from(hex);
if (typeof expectedLength === 'number' && bytes.length !== expectedLength) } else {
throw new Error(`Expected ${expectedLength} bytes`); throw new Error(`${title} must be hex string or Uint8Array`);
return bytes; }
const len = res.length;
if (typeof expectedLength === 'number' && len !== expectedLength)
throw new Error(`${title} expected ${expectedLength} bytes, got ${len}`);
return res;
} }
// Copies several Uint8Arrays into one. // Copies several Uint8Arrays into one.
@ -96,6 +108,16 @@ export function equalBytes(b1: Uint8Array, b2: Uint8Array) {
return true; return true;
} }
// Global symbols in both browsers and Node.js since v11
// See https://github.com/microsoft/TypeScript/issues/31535
declare const TextEncoder: any;
export function utf8ToBytes(str: string): Uint8Array {
if (typeof str !== 'string') {
throw new Error(`utf8ToBytes expected string, got ${typeof str}`);
}
return new TextEncoder().encode(str);
}
// Bit operations // Bit operations
// Amount of bits inside bigint (Same as n.toString(2).length) // Amount of bits inside bigint (Same as n.toString(2).length)

@ -59,9 +59,6 @@ export interface ProjPointType<T> extends Group<ProjPointType<T>> {
readonly py: T; readonly py: T;
readonly pz: T; readonly pz: T;
multiply(scalar: bigint): ProjPointType<T>; multiply(scalar: bigint): ProjPointType<T>;
multiplyUnsafe(scalar: bigint): ProjPointType<T>;
multiplyAndAddUnsafe(Q: ProjPointType<T>, a: bigint, b: bigint): ProjPointType<T> | undefined;
_setWindowSize(windowSize: number): void;
toAffine(iz?: T): AffinePoint<T>; toAffine(iz?: T): AffinePoint<T>;
isTorsionFree(): boolean; isTorsionFree(): boolean;
clearCofactor(): ProjPointType<T>; clearCofactor(): ProjPointType<T>;
@ -69,6 +66,10 @@ export interface ProjPointType<T> extends Group<ProjPointType<T>> {
hasEvenY(): boolean; hasEvenY(): boolean;
toRawBytes(isCompressed?: boolean): Uint8Array; toRawBytes(isCompressed?: boolean): Uint8Array;
toHex(isCompressed?: boolean): string; toHex(isCompressed?: boolean): string;
multiplyUnsafe(scalar: bigint): ProjPointType<T>;
multiplyAndAddUnsafe(Q: ProjPointType<T>, a: bigint, b: bigint): ProjPointType<T> | undefined;
_setWindowSize(windowSize: number): void;
} }
// Static methods for 3d XYZ points // Static methods for 3d XYZ points
export interface ProjConstructor<T> extends GroupConstructor<ProjPointType<T>> { export interface ProjConstructor<T> extends GroupConstructor<ProjPointType<T>> {
@ -213,7 +214,10 @@ export function weierstrassPoints<T>(opts: CurvePointsType<T>) {
} }
let num: bigint; let num: bigint;
try { try {
num = typeof key === 'bigint' ? key : ut.bytesToNumberBE(ensureBytes(key, nByteLength)); num =
typeof key === 'bigint'
? key
: ut.bytesToNumberBE(ensureBytes('private key', key, nByteLength));
} catch (error) { } catch (error) {
throw new Error(`private key must be ${nByteLength} bytes, hex or bigint, not ${typeof key}`); throw new Error(`private key must be ${nByteLength} bytes, hex or bigint, not ${typeof key}`);
} }
@ -276,7 +280,7 @@ export function weierstrassPoints<T>(opts: CurvePointsType<T>) {
* @param hex short/long ECDSA hex * @param hex short/long ECDSA hex
*/ */
static fromHex(hex: Hex): Point { static fromHex(hex: Hex): Point {
const P = Point.fromAffine(CURVE.fromBytes(ensureBytes(hex))); const P = Point.fromAffine(CURVE.fromBytes(ensureBytes('pointHex', hex)));
P.assertValidity(); P.assertValidity();
return P; return P;
} }
@ -782,24 +786,22 @@ export function weierstrass(curveDef: CurveType): CurveFn {
// pair (bytes of r, bytes of s) // pair (bytes of r, bytes of s)
static fromCompact(hex: Hex) { static fromCompact(hex: Hex) {
const gl = CURVE.nByteLength; const l = CURVE.nByteLength;
hex = ensureBytes(hex, gl * 2); hex = ensureBytes('compactSignature', hex, l * 2);
return new Signature(slcNum(hex, 0, gl), slcNum(hex, gl, 2 * gl)); return new Signature(slcNum(hex, 0, l), slcNum(hex, l, 2 * l));
} }
// DER encoded ECDSA signature // DER encoded ECDSA signature
// https://bitcoin.stackexchange.com/questions/57644/what-are-the-parts-of-a-bitcoin-transaction-input-script // https://bitcoin.stackexchange.com/questions/57644/what-are-the-parts-of-a-bitcoin-transaction-input-script
static fromDER(hex: Hex) { static fromDER(hex: Hex) {
if (typeof hex !== 'string' && !(hex instanceof Uint8Array)) const { r, s } = DER.toSig(ensureBytes('DER', hex));
throw new Error(`Signature.fromDER: Expected string or Uint8Array`);
const { r, s } = DER.toSig(ensureBytes(hex));
return new Signature(r, s); return new Signature(r, s);
} }
assertValidity(): void { assertValidity(): void {
// can use assertGE here // can use assertGE here
if (!isWithinCurveOrder(this.r)) throw new Error('r must be 0 < r < n'); if (!isWithinCurveOrder(this.r)) throw new Error('r must be 0 < r < CURVE.n');
if (!isWithinCurveOrder(this.s)) throw new Error('s must be 0 < s < n'); if (!isWithinCurveOrder(this.s)) throw new Error('s must be 0 < s < CURVE.n');
} }
addRecoveryBit(recovery: number) { addRecoveryBit(recovery: number) {
@ -807,11 +809,10 @@ export function weierstrass(curveDef: CurveType): CurveFn {
} }
recoverPublicKey(msgHash: Hex): typeof Point.BASE { recoverPublicKey(msgHash: Hex): typeof Point.BASE {
const { n: N } = CURVE; // ECDSA public key recovery secg.org/sec1-v2.pdf 4.1.6
const { r, s, recovery: rec } = this; const { r, s, recovery: rec } = this;
const h = bits2int_modN(ensureBytes(msgHash)); // Truncate hash const h = bits2int_modN(ensureBytes('msgHash', msgHash)); // Truncate hash
if (rec == null || ![0, 1, 2, 3].includes(rec)) throw new Error('recovery id invalid'); if (rec == null || ![0, 1, 2, 3].includes(rec)) throw new Error('recovery id invalid');
const radj = rec === 2 || rec === 3 ? r + N : r; const radj = rec === 2 || rec === 3 ? r + CURVE.n : r;
if (radj >= Fp.ORDER) throw new Error('recovery id 2 or 3 invalid'); if (radj >= Fp.ORDER) throw new Error('recovery id 2 or 3 invalid');
const prefix = (rec & 1) === 0 ? '02' : '03'; const prefix = (rec & 1) === 0 ? '02' : '03';
const R = Point.fromHex(prefix + numToNByteStr(radj)); const R = Point.fromHex(prefix + numToNByteStr(radj));
@ -936,8 +937,8 @@ export function weierstrass(curveDef: CurveType): CurveFn {
function (bytes: Uint8Array): bigint { function (bytes: Uint8Array): bigint {
// For curves with nBitLength % 8 !== 0: bits2octets(bits2octets(m)) !== bits2octets(m) // For curves with nBitLength % 8 !== 0: bits2octets(bits2octets(m)) !== bits2octets(m)
// for some cases, since bytes.length * 8 is not actual bitLength. // for some cases, since bytes.length * 8 is not actual bitLength.
const delta = bytes.length * 8 - CURVE.nBitLength; // truncate to nBitLength leftmost bits
const num = ut.bytesToNumberBE(bytes); // check for == u8 done here const num = ut.bytesToNumberBE(bytes); // check for == u8 done here
const delta = bytes.length * 8 - CURVE.nBitLength; // truncate to nBitLength leftmost bits
return delta > 0 ? num >> BigInt(delta) : num; return delta > 0 ? num >> BigInt(delta) : num;
}; };
const bits2int_modN = const bits2int_modN =
@ -962,26 +963,25 @@ export function weierstrass(curveDef: CurveType): CurveFn {
// NOTE: we cannot assume here that msgHash has same amount of bytes as curve order, this will be wrong at least for P521. // NOTE: we cannot assume here that msgHash has same amount of bytes as curve order, this will be wrong at least for P521.
// Also it can be bigger for P224 + SHA256 // Also it can be bigger for P224 + SHA256
function prepSig(msgHash: Hex, privateKey: PrivKey, opts = defaultSigOpts) { function prepSig(msgHash: Hex, privateKey: PrivKey, opts = defaultSigOpts) {
const { hash, randomBytes } = CURVE;
if (msgHash == null) throw new Error(`sign: expected valid message hash, not "${msgHash}"`);
if (['recovered', 'canonical'].some((k) => k in opts)) if (['recovered', 'canonical'].some((k) => k in opts))
// Ban legacy options
throw new Error('sign() legacy options not supported'); throw new Error('sign() legacy options not supported');
const { hash, randomBytes } = CURVE;
let { lowS, prehash, extraEntropy: ent } = opts; // generates low-s sigs by default let { lowS, prehash, extraEntropy: ent } = opts; // generates low-s sigs by default
if (prehash) msgHash = hash(ensureBytes(msgHash));
if (lowS == null) lowS = true; // RFC6979 3.2: we skip step A, because we already provide hash if (lowS == null) lowS = true; // RFC6979 3.2: we skip step A, because we already provide hash
msgHash = ensureBytes('msgHash', msgHash);
if (prehash) msgHash = ensureBytes('prehashed msgHash', hash(msgHash));
// We can't later call bits2octets, since nested bits2int is broken for curves // We can't later call bits2octets, since nested bits2int is broken for curves
// with nBitLength % 8 !== 0. Because of that, we unwrap it here as int2octets call. // with nBitLength % 8 !== 0. Because of that, we unwrap it here as int2octets call.
// const bits2octets = (bits) => int2octets(bits2int_modN(bits)) // const bits2octets = (bits) => int2octets(bits2int_modN(bits))
const h1int = bits2int_modN(ensureBytes(msgHash)); const h1int = bits2int_modN(msgHash);
const d = normalizePrivateKey(privateKey); // validate private key, convert to bigint const d = normalizePrivateKey(privateKey); // validate private key, convert to bigint
const seedArgs = [int2octets(d), int2octets(h1int)]; const seedArgs = [int2octets(d), int2octets(h1int)];
// extraEntropy. RFC6979 3.6: additional k' (optional). // extraEntropy. RFC6979 3.6: additional k' (optional).
if (ent != null) { if (ent != null) {
// K = HMAC_K(V || 0x00 || int2octets(x) || bits2octets(h1) || k') // K = HMAC_K(V || 0x00 || int2octets(x) || bits2octets(h1) || k')
// Either pass as-is, or generate random bytes. Then validate for being ui8a of size BYTES const e = ent === true ? randomBytes(Fp.BYTES) : ent; // generate random bytes OR pass as-is
seedArgs.push(ensureBytes(ent === true ? randomBytes(Fp.BYTES) : ent, Fp.BYTES)); seedArgs.push(ensureBytes('extraEntropy', e, Fp.BYTES)); // check for being of size BYTES
} }
const seed = ut.concatBytes(...seedArgs); // Step D of RFC6979 3.2 const seed = ut.concatBytes(...seedArgs); // Step D of RFC6979 3.2
const m = h1int; // NOTE: no need to call bits2int second time here, it is inside truncateHash! const m = h1int; // NOTE: no need to call bits2int second time here, it is inside truncateHash!
@ -1057,30 +1057,38 @@ export function weierstrass(curveDef: CurveType): CurveFn {
publicKey: Hex, publicKey: Hex,
opts = defaultVerOpts opts = defaultVerOpts
): boolean { ): boolean {
let P: ProjPointType<bigint>; const sg = signature;
msgHash = ensureBytes('msgHash', msgHash);
publicKey = ensureBytes('publicKey', publicKey);
if ('strict' in opts) throw new Error('options.strict was renamed to lowS');
const { lowS, prehash } = opts;
let _sig: Signature | undefined = undefined; let _sig: Signature | undefined = undefined;
if (publicKey instanceof Point) throw new Error('publicKey must be hex'); let P: ProjPointType<bigint>;
try { try {
if (signature && typeof signature === 'object' && !(signature instanceof Uint8Array)) { if (typeof sg === 'string' || sg instanceof Uint8Array) {
const { r, s } = signature;
_sig = new Signature(r, s); // assertValidity() is executed on creation
} else {
// Signature can be represented in 2 ways: compact (2*nByteLength) & DER (variable-length). // Signature can be represented in 2 ways: compact (2*nByteLength) & DER (variable-length).
// Since DER can also be 2*nByteLength bytes, we check for it first. // Since DER can also be 2*nByteLength bytes, we check for it first.
try { try {
_sig = Signature.fromDER(signature as Hex); _sig = Signature.fromDER(sg);
} catch (derError) { } catch (derError) {
if (!(derError instanceof DER.Err)) throw derError; if (!(derError instanceof DER.Err)) throw derError;
_sig = Signature.fromCompact(signature as Hex); _sig = Signature.fromCompact(sg);
} }
} else if (typeof sg === 'object' && typeof sg.r === 'bigint' && typeof sg.s === 'bigint') {
const { r, s } = sg;
_sig = new Signature(r, s);
} else {
throw new Error('PARSE');
} }
msgHash = ensureBytes(msgHash);
P = Point.fromHex(publicKey); P = Point.fromHex(publicKey);
} catch (error) { } catch (error) {
if ((error as Error).message === 'PARSE')
throw new Error(`signature must be Signature instance, Uint8Array or hex string`);
return false; return false;
} }
if (opts.lowS && _sig.hasHighS()) return false; if (lowS && _sig.hasHighS()) return false;
if (opts.prehash) msgHash = CURVE.hash(msgHash); if (prehash) msgHash = CURVE.hash(msgHash);
const { r, s } = _sig; const { r, s } = _sig;
const h = bits2int_modN(msgHash); // Cannot use fields methods, since it is group element const h = bits2int_modN(msgHash); // Cannot use fields methods, since it is group element
const is = invN(s); // s^-1 const is = invN(s); // s^-1

@ -1,26 +1,12 @@
/*! noble-curves - MIT License (c) 2022 Paul Miller (paulmillr.com) */ /*! noble-curves - MIT License (c) 2022 Paul Miller (paulmillr.com) */
import { sha256 } from '@noble/hashes/sha256'; import { sha256 } from '@noble/hashes/sha256';
import { Fp as Field, mod, pow2 } from './abstract/modular.js';
import { createCurve } from './_shortw_utils.js';
import { ProjPointType as PointType, mapToCurveSimpleSWU } from './abstract/weierstrass.js';
import {
ensureBytes,
concatBytes,
Hex,
bytesToNumberBE as bytesToInt,
PrivKey,
numberToBytesBE,
} from './abstract/utils.js';
import { randomBytes } from '@noble/hashes/utils'; import { randomBytes } from '@noble/hashes/utils';
import { Fp as Field, mod, pow2 } from './abstract/modular.js';
import { ProjPointType as PointType, mapToCurveSimpleSWU } from './abstract/weierstrass.js';
import type { Hex, PrivKey } from './abstract/utils.js';
import { bytesToNumberBE, concatBytes, ensureBytes, numberToBytesBE } from './abstract/utils.js';
import * as htf from './abstract/hash-to-curve.js'; import * as htf from './abstract/hash-to-curve.js';
import { createCurve } from './_shortw_utils.js';
/**
* secp256k1 belongs to Koblitz curves: it has efficiently computable endomorphism.
* Endomorphism uses 2x less RAM, speeds up precomputation by 2x and ECDH / key recovery by 20%.
* Should always be used for Projective's double-and-add multiplication.
* For affines cached multiplication, it trades off 1/2 init time & 1/3 ram for 20% perf hit.
* https://gist.github.com/paulmillr/eb670806793e84df628a7c434a873066
*/
const secp256k1P = BigInt('0xfffffffffffffffffffffffffffffffffffffffffffffffffffffffefffffc2f'); const secp256k1P = BigInt('0xfffffffffffffffffffffffffffffffffffffffffffffffffffffffefffffc2f');
const secp256k1N = BigInt('0xfffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd0364141'); const secp256k1N = BigInt('0xfffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd0364141');
@ -61,23 +47,22 @@ type Fp = bigint;
export const secp256k1 = createCurve( export const secp256k1 = createCurve(
{ {
// Params: a, b a: BigInt(0), // equation params: a, b
// Seem to be rigid https://bitcointalk.org/index.php?topic=289795.msg3183975#msg3183975 b: BigInt(7), // Seem to be rigid: bitcointalk.org/index.php?topic=289795.msg3183975#msg3183975
a: BigInt(0), Fp, // Field's prime: 2n**256n - 2n**32n - 2n**9n - 2n**8n - 2n**7n - 2n**6n - 2n**4n - 1n
b: BigInt(7), n: secp256k1N, // Curve order, total count of valid points in the field
// Field over which we'll do calculations;
// 2n**256n - 2n**32n - 2n**9n - 2n**8n - 2n**7n - 2n**6n - 2n**4n - 1n
Fp,
// Curve order, total count of valid points in the field
n: secp256k1N,
// Base point (x, y) aka generator point // Base point (x, y) aka generator point
Gx: BigInt('55066263022277343669578718895168534326250603453777594175500187360389116729240'), Gx: BigInt('55066263022277343669578718895168534326250603453777594175500187360389116729240'),
Gy: BigInt('32670510020758816978083085130507043184471273380659243275938904335757337482424'), Gy: BigInt('32670510020758816978083085130507043184471273380659243275938904335757337482424'),
h: BigInt(1), h: BigInt(1), // Cofactor
// Alllow only low-S signatures by default in sign() and verify() lowS: true, // Allow only low-S signatures by default in sign() and verify()
lowS: true, /**
* secp256k1 belongs to Koblitz curves: it has efficiently computable endomorphism.
* Endomorphism uses 2x less RAM, speeds up precomputation by 2x and ECDH / key recovery by 20%.
* For precomputed wNAF it trades off 1/2 init time & 1/3 ram for 20% perf hit.
* Explanation: https://gist.github.com/paulmillr/eb670806793e84df628a7c434a873066
*/
endo: { endo: {
// Params taken from https://gist.github.com/paulmillr/eb670806793e84df628a7c434a873066
beta: BigInt('0x7ae96a2b657c07106e64479eac3434e99cf0497512f58995c1396c28719501ee'), beta: BigInt('0x7ae96a2b657c07106e64479eac3434e99cf0497512f58995c1396c28719501ee'),
splitScalar: (k: bigint) => { splitScalar: (k: bigint) => {
const n = secp256k1N; const n = secp256k1N;
@ -105,19 +90,11 @@ export const secp256k1 = createCurve(
sha256 sha256
); );
// Schnorr signatures are superior to ECDSA from above. // Schnorr signatures are superior to ECDSA from above. Below is Schnorr-specific BIP0340 code.
// Below is Schnorr-specific code as per BIP0340.
// https://github.com/bitcoin/bips/blob/master/bip-0340.mediawiki // https://github.com/bitcoin/bips/blob/master/bip-0340.mediawiki
const _0n = BigInt(0); const _0n = BigInt(0);
const fe = (x: bigint) => typeof x === 'bigint' && _0n < x && x < secp256k1P; const fe = (x: bigint) => typeof x === 'bigint' && _0n < x && x < secp256k1P;
const ge = (x: bigint) => typeof x === 'bigint' && _0n < x && x < secp256k1N; const ge = (x: bigint) => typeof x === 'bigint' && _0n < x && x < secp256k1N;
const TAGS = {
challenge: 'BIP0340/challenge',
aux: 'BIP0340/aux',
nonce: 'BIP0340/nonce',
} as const;
/** An object mapping tags to their tagged hash prefix of [SHA256(tag) | SHA256(tag)] */ /** An object mapping tags to their tagged hash prefix of [SHA256(tag) | SHA256(tag)] */
const TAGGED_HASH_PREFIXES: { [tag: string]: Uint8Array } = {}; const TAGGED_HASH_PREFIXES: { [tag: string]: Uint8Array } = {};
function taggedHash(tag: string, ...messages: Uint8Array[]): Uint8Array { function taggedHash(tag: string, ...messages: Uint8Array[]): Uint8Array {
@ -132,49 +109,53 @@ function taggedHash(tag: string, ...messages: Uint8Array[]): Uint8Array {
const pointToBytes = (point: PointType<bigint>) => point.toRawBytes(true).slice(1); const pointToBytes = (point: PointType<bigint>) => point.toRawBytes(true).slice(1);
const numTo32b = (n: bigint) => numberToBytesBE(n, 32); const numTo32b = (n: bigint) => numberToBytesBE(n, 32);
const modP = (x: bigint) => mod(x, secp256k1P);
const modN = (x: bigint) => mod(x, secp256k1N); const modN = (x: bigint) => mod(x, secp256k1N);
const Point = secp256k1.ProjectivePoint; const Point = secp256k1.ProjectivePoint;
const GmulAdd = (Q: PointType<bigint>, a: bigint, b: bigint) => const GmulAdd = (Q: PointType<bigint>, a: bigint, b: bigint) =>
Point.BASE.multiplyAndAddUnsafe(Q, a, b); Point.BASE.multiplyAndAddUnsafe(Q, a, b);
const hex32ToInt = (key: Hex) => bytesToInt(ensureBytes(key, 32));
function schnorrGetExtPubKey(priv: PrivKey) { function schnorrGetExtPubKey(priv: PrivKey) {
let d = typeof priv === 'bigint' ? priv : hex32ToInt(priv); const d = secp256k1.utils.normPrivateKeyToScalar(priv);
const point = Point.fromPrivateKey(d); // P = d'⋅G; 0 < d' < n check is done inside const point = Point.fromPrivateKey(d); // P = d'⋅G; 0 < d' < n check is done inside
const scalar = point.hasEvenY() ? d : modN(-d); // d = d' if has_even_y(P), otherwise d = n-d' const scalar = point.hasEvenY() ? d : modN(-d); // d = d' if has_even_y(P), otherwise d = n-d'
return { point, scalar, bytes: pointToBytes(point) }; return { point, scalar, bytes: pointToBytes(point) };
} }
function lift_x(x: bigint): PointType<bigint> { function lift_x(x: bigint): PointType<bigint> {
if (!fe(x)) throw new Error('bad x: need 0 < x < p'); // Fail if x ≥ p. if (!fe(x)) throw new Error('bad x: need 0 < x < p'); // Fail if x ≥ p.
const c = mod(x * x * x + BigInt(7), secp256k1P); // Let c = x³ + 7 mod p. const xx = modP(x * x);
const c = modP(xx * x + BigInt(7)); // Let c = x³ + 7 mod p.
let y = sqrtMod(c); // Let y = c^(p+1)/4 mod p. let y = sqrtMod(c); // Let y = c^(p+1)/4 mod p.
if (y % 2n !== 0n) y = mod(-y, secp256k1P); // Return the unique point P such that x(P) = x and if (y % 2n !== 0n) y = modP(-y); // Return the unique point P such that x(P) = x and
const p = new Point(x, y, _1n); // y(P) = y if y mod 2 = 0 or y(P) = p-y otherwise. const p = new Point(x, y, _1n); // y(P) = y if y mod 2 = 0 or y(P) = p-y otherwise.
p.assertValidity(); p.assertValidity();
return p; return p;
} }
function challenge(...args: Uint8Array[]): bigint { function challenge(...args: Uint8Array[]): bigint {
return modN(bytesToInt(taggedHash(TAGS.challenge, ...args))); return modN(bytesToNumberBE(taggedHash('BIP0340/challenge', ...args)));
} }
// Schnorr's pubkey is just `x` of Point (BIP340) /**
* Schnorr public key is just `x` coordinate of Point as per BIP340.
*/
function schnorrGetPublicKey(privateKey: Hex): Uint8Array { function schnorrGetPublicKey(privateKey: Hex): Uint8Array {
return schnorrGetExtPubKey(privateKey).bytes; // d'=int(sk). Fail if d'=0 or d'≥n. Ret bytes(d'⋅G) return schnorrGetExtPubKey(privateKey).bytes; // d'=int(sk). Fail if d'=0 or d'≥n. Ret bytes(d'⋅G)
} }
// Creates Schnorr signature as per BIP340. Verifies itself before returning anything. /**
// auxRand is optional and is not the sole source of k generation: bad CSPRNG won't be dangerous * Creates Schnorr signature as per BIP340. Verifies itself before returning anything.
* auxRand is optional and is not the sole source of k generation: bad CSPRNG won't be dangerous.
*/
function schnorrSign( function schnorrSign(
message: Hex, message: Hex,
privateKey: PrivKey, privateKey: PrivKey,
auxRand: Hex = randomBytes(32) auxRand: Hex = randomBytes(32)
): Uint8Array { ): Uint8Array {
if (message == null) throw new Error(`sign: Expected valid message, not "${message}"`); const m = ensureBytes('message', message);
const m = ensureBytes(message); // checks for isWithinCurveOrder const { bytes: px, scalar: d } = schnorrGetExtPubKey(privateKey); // checks for isWithinCurveOrder
const { bytes: px, scalar: d } = schnorrGetExtPubKey(privateKey); const a = ensureBytes('auxRand', auxRand, 32); // Auxiliary random data a: a 32-byte array
const a = ensureBytes(auxRand, 32); // Auxiliary random data a: a 32-byte array const t = numTo32b(d ^ bytesToNumberBE(taggedHash('BIP0340/aux', a))); // Let t be the byte-wise xor of bytes(d) and hash/aux(a)
const t = numTo32b(d ^ bytesToInt(taggedHash(TAGS.aux, a))); // Let t be the byte-wise xor of bytes(d) and hash/aux(a) const rand = taggedHash('BIP0340/nonce', t, px, m); // Let rand = hash/nonce(t || bytes(P) || m)
const rand = taggedHash(TAGS.nonce, t, px, m); // Let rand = hash/nonce(t || bytes(P) || m) const k_ = modN(bytesToNumberBE(rand)); // Let k' = int(rand) mod n
const k_ = modN(bytesToInt(rand)); // Let k' = int(rand) mod n
if (k_ === _0n) throw new Error('sign failed: k is zero'); // Fail if k' = 0. if (k_ === _0n) throw new Error('sign failed: k is zero'); // Fail if k' = 0.
const { point: R, bytes: rx, scalar: k } = schnorrGetExtPubKey(k_); // Let R = k'⋅G. const { point: R, bytes: rx, scalar: k } = schnorrGetExtPubKey(k_); // Let R = k'⋅G.
const e = challenge(rx, px, m); // Let e = int(hash/challenge(bytes(R) || bytes(P) || m)) mod n. const e = challenge(rx, px, m); // Let e = int(hash/challenge(bytes(R) || bytes(P) || m)) mod n.
@ -187,18 +168,19 @@ function schnorrSign(
} }
/** /**
* Verifies Schnorr signature synchronously. * Verifies Schnorr signature.
*/ */
function schnorrVerify(signature: Hex, message: Hex, publicKey: Hex): boolean { function schnorrVerify(signature: Hex, message: Hex, publicKey: Hex): boolean {
const sig = ensureBytes('signature', signature, 64);
const m = ensureBytes('message', message);
const pub = ensureBytes('publicKey', publicKey, 32);
try { try {
const P = lift_x(hex32ToInt(publicKey)); // P = lift_x(int(pk)); fail if that fails const P = lift_x(bytesToNumberBE(pub)); // P = lift_x(int(pk)); fail if that fails
const sig = ensureBytes(signature, 64); const r = bytesToNumberBE(sig.subarray(0, 32)); // Let r = int(sig[0:32]); fail if r ≥ p.
const r = bytesToInt(sig.subarray(0, 32)); // Let r = int(sig[0:32]); fail if r ≥ p.
if (!fe(r)) return false; if (!fe(r)) return false;
const s = bytesToInt(sig.subarray(32, 64)); // Let s = int(sig[32:64]); fail if s ≥ n. const s = bytesToNumberBE(sig.subarray(32, 64)); // Let s = int(sig[32:64]); fail if s ≥ n.
if (!ge(s)) return false; if (!ge(s)) return false;
const m = ensureBytes(message); const e = challenge(numTo32b(r), pointToBytes(P), m); // int(challenge(bytes(r)||bytes(P)||m))%n
const e = challenge(numTo32b(r), pointToBytes(P), m); // int(challenge(bytes(r)||bytes(P)||m)) mod n
const R = GmulAdd(P, s, modN(-e)); // R = s⋅G - e⋅P const R = GmulAdd(P, s, modN(-e)); // R = s⋅G - e⋅P
if (!R || !R.hasEvenY() || R.toAffine().x !== r) return false; // -eP == (n-e)P if (!R || !R.hasEvenY() || R.toAffine().x !== r) return false; // -eP == (n-e)P
return true; // Fail if is_infinite(R) / not has_even_y(R) / x(R) ≠ r. return true; // Fail if is_infinite(R) / not has_even_y(R) / x(R) ≠ r.
@ -212,11 +194,12 @@ export const schnorr = {
sign: schnorrSign, sign: schnorrSign,
verify: schnorrVerify, verify: schnorrVerify,
utils: { utils: {
randomPrivateKey: secp256k1.utils.randomPrivateKey,
getExtendedPublicKey: schnorrGetExtPubKey, getExtendedPublicKey: schnorrGetExtPubKey,
lift_x, lift_x,
pointToBytes, pointToBytes,
numberToBytesBE, numberToBytesBE,
bytesToNumberBE: bytesToInt, bytesToNumberBE,
taggedHash, taggedHash,
mod, mod,
}, },
@ -259,7 +242,7 @@ const mapSWU = mapToCurveSimpleSWU(Fp, {
B: BigInt('1771'), B: BigInt('1771'),
Z: Fp.create(BigInt('-11')), Z: Fp.create(BigInt('-11')),
}); });
const { hashToCurve, encodeToCurve } = htf.hashToCurve( export const { hashToCurve, encodeToCurve } = htf.hashToCurve(
secp256k1.ProjectivePoint, secp256k1.ProjectivePoint,
(scalars: bigint[]) => { (scalars: bigint[]) => {
const { x, y } = mapSWU(Fp.create(scalars[0])); const { x, y } = mapSWU(Fp.create(scalars[0]));
@ -275,4 +258,3 @@ const { hashToCurve, encodeToCurve } = htf.hashToCurve(
hash: sha256, hash: sha256,
} }
); );
export { hashToCurve, encodeToCurve };