Improve modular math

This commit is contained in:
Paul Miller 2022-12-31 06:49:42 +00:00
parent cc2c84f040
commit 12da04a2bb
No known key found for this signature in database
GPG Key ID: 697079DA6878B89B
2 changed files with 134 additions and 141 deletions

@ -1,10 +1,11 @@
/*! noble-curves - MIT License (c) 2022 Paul Miller (paulmillr.com) */
// TODO: remove circular imports
import * as utils from './utils.js';
// Utilities for modular arithmetics and finite fields
// prettier-ignore
const _0n = BigInt(0), _1n = BigInt(1), _2n = BigInt(2), _3n = BigInt(3);
// prettier-ignore
const _4n = BigInt(4), _5n = BigInt(5), _7n = BigInt(7), _8n = BigInt(8);
const _4n = BigInt(4), _5n = BigInt(5), _8n = BigInt(8);
// prettier-ignore
const _9n = BigInt(9), _16n = BigInt(16);
@ -66,26 +67,68 @@ export function invert(number: bigint, modulo: bigint): bigint {
return mod(x, modulo);
}
/**
* Calculates Legendre symbol (a | p), which denotes the value of a^((p-1)/2) (mod p).
* * (a | p) 1 if a is a square (mod p)
* * (a | p) -1 if a is not a square (mod p)
* * (a | p) 0 if a 0 (mod p)
*/
export function legendre(num: bigint, fieldPrime: bigint): bigint {
return pow(num, (fieldPrime - _1n) / _2n, fieldPrime);
// Tonelli-Shanks algorithm
// https://eprint.iacr.org/2012/685.pdf (page 12)
export function tonelliShanks(P: bigint) {
// Legendre constant: used to calculate Legendre symbol (a | p),
// which denotes the value of a^((p-1)/2) (mod p).
// (a | p) ≡ 1 if a is a square (mod p)
// (a | p) ≡ -1 if a is not a square (mod p)
// (a | p) ≡ 0 if a ≡ 0 (mod p)
const legendreC = (P - _1n) / _2n;
let Q: bigint, S: number, Z: bigint;
// Step 1: By factoring out powers of 2 from p - 1,
// find q and s such that p - 1 = q2s with q odd
for (Q = P - _1n, S = 0; Q % _2n === _0n; Q /= _2n, S++);
// Step 2: Select a non-square z such that (z | p) ≡ -1 and set c ≡ zq
for (Z = _2n; Z < P && pow(Z, legendreC, P) !== P - _1n; Z++);
// Fast-path
if (S === 1) {
const p1div4 = (P + _1n) / _4n;
return function tonelliFast<T>(Fp: Field<T>, n: T) {
const root = Fp.pow(n, p1div4);
if (!Fp.equals(Fp.square(root), n)) throw new Error('Cannot find square root');
return root;
};
}
// Slow-path
const Q1div2 = (Q + _1n) / _2n;
return function tonelliSlow<T>(Fp: Field<T>, n: T): T {
// Step 0: Check that n is indeed a square: (n | p) must be ≡ 1
if (Fp.pow(n, legendreC) !== Fp.ONE) throw new Error('Cannot find square root');
let s = S;
let c = pow(Z, Q, P);
let r = Fp.pow(n, Q1div2);
let t = Fp.pow(n, Q);
let t2 = Fp.ZERO;
while (!Fp.equals(Fp.sub(t, Fp.ONE), Fp.ZERO)) {
t2 = Fp.square(t);
let i;
for (i = 1; i < s; i++) {
// stop if t2-1 == 0
if (Fp.equals(Fp.sub(t2, Fp.ONE), Fp.ZERO)) break;
// t2 *= t2
t2 = Fp.square(t2);
}
let b = pow(c, BigInt(1 << (s - i - 1)), P);
r = Fp.mul(r, b);
c = mod(b * b, P);
t = Fp.mul(t, c);
s = i;
}
return r;
};
}
/**
* Calculates square root of a number in a finite field.
* a mod P
*/
// TODO: rewrite as generic Fp function && remove bls versions
export function sqrt(number: bigint, modulo: bigint): bigint {
// prettier-ignore
const n = number;
const P = modulo;
const p1div4 = (P + _1n) / _4n;
export function FpSqrt(P: bigint) {
// NOTE: different algorithms can give different roots, it is up to user to decide which one they want.
// For example there is FpSqrtOdd/FpSqrtEven to choice root based on oddness (used for hash-to-curve).
// P ≡ 3 (mod 4)
// √n = n^((P+1)/4)
@ -94,48 +137,54 @@ export function sqrt(number: bigint, modulo: bigint): bigint {
// const ORDER =
// 0x1a0111ea397fe69a4b1ba7b6434bacd764774b84f38512bf6730d2a0f6b0f6241eabfffeb153ffffb9feffffffffaaabn;
// const NUM = 72057594037927816n;
// TODO: fix sqrtMod in secp256k1
const root = pow(n, p1div4, P);
if (mod(root * root, modulo) !== number) throw new Error('Cannot find square root');
const p1div4 = (P + _1n) / _4n;
return function sqrt3mod4<T>(Fp: Field<T>, n: T) {
const root = Fp.pow(n, p1div4);
// Throw if root**2 != n
if (!Fp.equals(Fp.square(root), n)) throw new Error('Cannot find square root');
return root;
};
}
// P ≡ 5 (mod 8)
// Atkin algorithm for q ≡ 5 (mod 8), https://eprint.iacr.org/2012/685.pdf (page 10)
if (P % _8n === _5n) {
const n2 = mod(n * _2n, P);
const v = pow(n2, (P - _5n) / _8n, P);
const nv = mod(n * v, P);
const i = mod(_2n * nv * v, P);
const r = mod(nv * (i - _1n), P);
return r;
const c1 = (P - _5n) / _8n;
return function sqrt5mod8<T>(Fp: Field<T>, n: T) {
const n2 = Fp.mul(n, _2n);
const v = Fp.pow(n2, c1);
const nv = Fp.mul(n, v);
const i = Fp.mul(Fp.mul(nv, _2n), v);
const root = Fp.mul(nv, Fp.sub(i, Fp.ONE));
if (!Fp.equals(Fp.square(root), n)) throw new Error('Cannot find square root');
return root;
};
}
// P ≡ 9 (mod 16)
if (P % _16n === _9n) {
// NOTE: tonelli is too slow for bls-Fp2 calculations even on start
// Means we cannot use sqrt for constants at all!
//
// const c1 = Fp.sqrt(Fp.negate(Fp.ONE)); // 1. c1 = sqrt(-1) in F, i.e., (c1^2) == -1 in F
// const c2 = Fp.sqrt(c1); // 2. c2 = sqrt(c1) in F, i.e., (c2^2) == c1 in F
// const c3 = Fp.sqrt(Fp.negate(c1)); // 3. c3 = sqrt(-c1) in F, i.e., (c3^2) == -c1 in F
// const c4 = (P + _7n) / _16n; // 4. c4 = (q + 7) / 16 # Integer arithmetic
// sqrt = (x) => {
// let tv1 = Fp.pow(x, c4); // 1. tv1 = x^c4
// let tv2 = Fp.mul(c1, tv1); // 2. tv2 = c1 * tv1
// const tv3 = Fp.mul(c2, tv1); // 3. tv3 = c2 * tv1
// let tv4 = Fp.mul(c3, tv1); // 4. tv4 = c3 * tv1
// const e1 = Fp.equals(Fp.square(tv2), x); // 5. e1 = (tv2^2) == x
// const e2 = Fp.equals(Fp.square(tv3), x); // 6. e2 = (tv3^2) == x
// tv1 = Fp.cmov(tv1, tv2, e1); // 7. tv1 = CMOV(tv1, tv2, e1) # Select tv2 if (tv2^2) == x
// tv2 = Fp.cmov(tv4, tv3, e2); // 8. tv2 = CMOV(tv4, tv3, e2) # Select tv3 if (tv3^2) == x
// const e3 = Fp.equals(Fp.square(tv2), x); // 9. e3 = (tv2^2) == x
// return Fp.cmov(tv1, tv2, e3); // 10. z = CMOV(tv1, tv2, e3) # Select the sqrt from tv1 and tv2
// }
}
// Other cases: Tonelli-Shanks algorithm
if (legendre(n, P) !== _1n) throw new Error('Cannot find square root');
let q: bigint, s: number, z: bigint;
for (q = P - _1n, s = 0; q % _2n === _0n; q /= _2n, s++);
if (s === 1) return pow(n, p1div4, P);
for (z = _2n; z < P && legendre(z, P) !== P - _1n; z++);
let c = pow(z, q, P);
let r = pow(n, (q + _1n) / _2n, P);
let t = pow(n, q, P);
let t2 = _0n;
while (mod(t - _1n, P) !== _0n) {
t2 = mod(t * t, P);
let i;
for (i = 1; i < s; i++) {
if (mod(t2 - _1n, P) === _0n) break;
t2 = mod(t2 * t2, P);
}
let b = pow(c, BigInt(1 << (s - i - 1)), P);
r = mod(r * b, P);
c = mod(b * b, P);
t = mod(t * c, P);
s = i;
}
return r;
return tonelliShanks(P);
}
// Little-endian check for first LE bit (last BE bit);
@ -176,6 +225,7 @@ export interface Field<T> {
// Optional
// Should be same as sgn0 function in https://datatracker.ietf.org/doc/draft-irtf-cfrg-hash-to-curve/
// NOTE: sgn0 is 'negative in LE', which is same as odd. And negative in LE is kinda strange definition anyway.
isOdd?(num: T): boolean; // Odd instead of even since we have it for Fp2
legendre?(num: T): T;
pow(lhs: T, power: bigint): T;
@ -246,21 +296,31 @@ export function FpDiv<T>(f: Field<T>, lhs: T, rhs: T | bigint): T {
return f.mul(lhs, typeof rhs === 'bigint' ? invert(rhs, f.ORDER) : f.invert(rhs));
}
// This function returns True whenever the value x is a square in the field F.
export function FpIsSquare<T>(f: Field<T>) {
const legendreConst = (f.ORDER - _1n) / _2n; // Integer arithmetic
return (x: T): boolean => {
const p = f.pow(x, legendreConst);
return f.equals(p, f.ZERO) || f.equals(p, f.ONE);
};
}
// NOTE: very fragile, always bench. Major performance points:
// - NonNormalized ops
// - Object.freeze
// - same shape of object (don't add/remove keys)
type FpField = Field<bigint> & Required<Pick<Field<bigint>, 'isOdd'>>;
export function Fp(
ORDER: bigint,
bitLen?: number,
isLE = false,
redef: Partial<Field<bigint>> = {}
): Readonly<Field<bigint>> {
): Readonly<FpField> {
if (ORDER <= _0n) throw new Error(`Expected Fp ORDER > 0, got ${ORDER}`);
const { nBitLength: BITS, nByteLength: BYTES } = utils.nLength(ORDER, bitLen);
if (BYTES > 2048) throw new Error('Field lengths over 2048 bytes are not supported');
const sqrtP = (num: bigint) => sqrt(num, ORDER);
const f: Field<bigint> = Object.freeze({
const sqrtP = FpSqrt(ORDER);
const f: Readonly<FpField> = Object.freeze({
ORDER,
BITS,
BYTES,
@ -292,7 +352,7 @@ export function Fp(
mulN: (lhs, rhs) => lhs * rhs,
invert: (num) => invert(num, ORDER),
sqrt: redef.sqrt || sqrtP,
sqrt: redef.sqrt || ((n) => sqrtP(f, n)),
invertBatch: (lst) => FpInvertBatch(f, lst),
// TODO: do we really need constant cmov?
// We don't have const-time bigints anyway, so probably will be not very useful
@ -305,87 +365,18 @@ export function Fp(
throw new Error(`Fp.fromBytes: expected ${BYTES}, got ${bytes.length}`);
return isLE ? utils.bytesToNumberLE(bytes) : utils.bytesToNumberBE(bytes);
},
} as Field<bigint>);
} as FpField);
return Object.freeze(f);
}
// TODO: re-use in bls/generic sqrt for field/etc?
// Something like sqrtUnsafe which always returns value, but sqrt throws exception if non-square
// From draft-irtf-cfrg-hash-to-curve-16
export function FpSqrt<T>(Fp: Field<T>) {
// NOTE: it requires another sqrt for constant precomputes, but no need for roots of unity,
// probably we can simply bls code using it
const q = Fp.ORDER;
const squareConst = (q - _1n) / _2n;
// is_square(x) := { True, if x^((q - 1) / 2) is 0 or 1 in F;
// { False, otherwise.
let isSquare: (x: T) => boolean = (x) => {
const p = Fp.pow(x, squareConst);
return Fp.equals(p, Fp.ZERO) || Fp.equals(p, Fp.ONE);
};
// Constant-time Tonelli-Shanks algorithm
let l = _0n;
for (let o = q - _1n; o % _2n === _0n; o /= _2n) l += _1n;
const c1 = l; // 1. c1, the largest integer such that 2^c1 divides q - 1.
const c2 = (q - _1n) / _2n ** c1; // 2. c2 = (q - 1) / (2^c1) # Integer arithmetic
const c3 = (c2 - _1n) / _2n; // 3. c3 = (c2 - 1) / 2 # Integer arithmetic
// 4. c4, a non-square value in F
// 5. c5 = c4^c2 in F
let c4 = Fp.ONE;
while (isSquare(c4)) c4 = Fp.add(c4, Fp.ONE);
const c5 = Fp.pow(c4, c2);
let sqrt: (x: T) => T = (x) => {
let z = Fp.pow(x, c3); // 1. z = x^c3
let t = Fp.square(z); // 2. t = z * z
t = Fp.mul(t, x); // 3. t = t * x
z = Fp.mul(z, x); // 4. z = z * x
let b = t; // 5. b = t
let c = c5; // 6. c = c5
// 7. for i in (c1, c1 - 1, ..., 2):
for (let i = c1; i > 1; i--) {
// 8. for j in (1, 2, ..., i - 2):
// 9. b = b * b
for (let j = _1n; j < i - _1n; i++) b = Fp.square(b);
const e = Fp.equals(b, Fp.ONE); // 10. e = b == 1
const zt = Fp.mul(z, c); // 11. zt = z * c
z = Fp.cmov(zt, z, e); // 12. z = CMOV(zt, z, e)
c = Fp.square(c); // 13. c = c * c
let tt = Fp.mul(t, c); // 14. tt = t * c
t = Fp.cmov(tt, t, e); // 15. t = CMOV(tt, t, e)
b = t; // 16. b = t
}
return z; // 17. return z
};
if (q % _4n === _3n) {
const c1 = (q + _1n) / _4n; // 1. c1 = (q + 1) / 4 # Integer arithmetic
sqrt = (x) => Fp.pow(x, c1);
} else if (q % _8n === _5n) {
const c1 = Fp.sqrt(Fp.negate(Fp.ONE)); // 1. c1 = sqrt(-1) in F, i.e., (c1^2) == -1 in F
const c2 = (q + _3n) / _8n; // 2. c2 = (q + 3) / 8 # Integer arithmetic
sqrt = (x) => {
let tv1 = Fp.pow(x, c2); // 1. tv1 = x^c2
let tv2 = Fp.mul(tv1, c1); // 2. tv2 = tv1 * c1
let e = Fp.equals(Fp.square(tv1), x); // 3. e = (tv1^2) == x
return Fp.cmov(tv2, tv1, e); // 4. z = CMOV(tv2, tv1, e)
};
} else if (Fp.ORDER % _16n === _9n) {
const c1 = Fp.sqrt(Fp.negate(Fp.ONE)); // 1. c1 = sqrt(-1) in F, i.e., (c1^2) == -1 in F
const c2 = Fp.sqrt(c1); // 2. c2 = sqrt(c1) in F, i.e., (c2^2) == c1 in F
const c3 = Fp.sqrt(Fp.negate(c1)); // 3. c3 = sqrt(-c1) in F, i.e., (c3^2) == -c1 in F
const c4 = (Fp.ORDER + _7n) / _16n; // 4. c4 = (q + 7) / 16 # Integer arithmetic
sqrt = (x) => {
let tv1 = Fp.pow(x, c4); // 1. tv1 = x^c4
let tv2 = Fp.mul(c1, tv1); // 2. tv2 = c1 * tv1
const tv3 = Fp.mul(c2, tv1); // 3. tv3 = c2 * tv1
let tv4 = Fp.mul(c3, tv1); // 4. tv4 = c3 * tv1
const e1 = Fp.equals(Fp.square(tv2), x); // 5. e1 = (tv2^2) == x
const e2 = Fp.equals(Fp.square(tv3), x); // 6. e2 = (tv3^2) == x
tv1 = Fp.cmov(tv1, tv2, e1); // 7. tv1 = CMOV(tv1, tv2, e1) # Select tv2 if (tv2^2) == x
tv2 = Fp.cmov(tv4, tv3, e2); // 8. tv2 = CMOV(tv4, tv3, e2) # Select tv3 if (tv3^2) == x
const e3 = Fp.equals(Fp.square(tv2), x); // 9. e3 = (tv2^2) == x
return Fp.cmov(tv1, tv2, e3); // 10. z = CMOV(tv1, tv2, e3) # Select the sqrt from tv1 and tv2
};
}
return { sqrt, isSquare };
export function FpSqrtOdd<T>(Fp: Field<T>, elm: T) {
if (!Fp.isOdd) throw new Error(`Field doesn't have isOdd`);
const root = Fp.sqrt(elm);
return Fp.isOdd(root) ? root : Fp.negate(root);
}
export function FpSqrtEven<T>(Fp: Field<T>, elm: T) {
if (!Fp.isOdd) throw new Error(`Field doesn't have isOdd`);
const root = Fp.sqrt(elm);
return Fp.isOdd(root) ? Fp.negate(root) : root;
}

@ -56,7 +56,9 @@ function sqrtMod(y: bigint): bigint {
const b223 = (pow2(b220, _3n, P) * b3) % P;
const t1 = (pow2(b223, _23n, P) * b22) % P;
const t2 = (pow2(t1, _6n, P) * b2) % P;
return pow2(t2, _2n, P);
const root = pow2(t2, _2n, P);
if (!Fp.equals(Fp.square(root), y)) throw new Error('Cannot find square root');
return root;
}
const Fp = Field(secp256k1P, undefined, undefined, { sqrt: sqrtMod });
@ -152,7 +154,7 @@ export const secp256k1 = createCurve(
p: Fp.ORDER,
m: 1,
k: 128,
expand: true,
expand: 'xmd',
hash: sha256,
},
},