diff --git a/src/bls12-381.ts b/src/bls12-381.ts index 85d12f7..8aebabc 100644 --- a/src/bls12-381.ts +++ b/src/bls12-381.ts @@ -40,7 +40,6 @@ import { numberToBytesBE, bytesToNumberBE, bitLen, - bitSet, bitGet, Hex, bitMask, @@ -1019,11 +1018,31 @@ const htfDefaults = Object.freeze({ // Encoding utils // Point on G1 curve: (x, y) -const C_BIT_POS = Fp.BITS; // C_bit, compression bit for serialization flag -const I_BIT_POS = Fp.BITS + 1; // I_bit, point-at-infinity bit for serialization flag -const S_BIT_POS = Fp.BITS + 2; // S_bit, sign bit for serialization flag + // Compressed point of infinity -const COMPRESSED_ZERO = Fp.toBytes(bitSet(bitSet(_0n, I_BIT_POS, true), S_BIT_POS, true)); // set compressed & point-at-infinity bits +const COMPRESSED_ZERO = setMask(Fp.toBytes(_0n), { infinity: true, compressed: true }); // set compressed & point-at-infinity bits + +function parseMask(bytes: Uint8Array) { + // Copy, so we can remove mask data. It will be removed also later, when Fp.create will call modulo. + bytes = bytes.slice(); + const mask = bytes[0] & 0b1110_0000; + const compressed = !!((mask >> 7) & 1); // compression bit (0b1000_0000) + const infinity = !!((mask >> 6) & 1); // point at infinity bit (0b0100_0000) + const sort = !!((mask >> 5) & 1); // sort bit (0b0010_0000) + bytes[0] &= 0b0001_1111; // clear mask (zero first 3 bits) + return { compressed, infinity, sort, value: bytes }; +} + +function setMask( + bytes: Uint8Array, + mask: { compressed?: boolean; infinity?: boolean; sort?: boolean } +) { + if (bytes[0] & 0b1110_0000) throw new Error('setMask: non-empty mask'); + if (mask.compressed) bytes[0] |= 0b1000_0000; + if (mask.infinity) bytes[0] |= 0b0100_0000; + if (mask.sort) bytes[0] |= 0b0010_0000; + return bytes; +} function signatureG1ToRawBytes(point: ProjPointType) { point.assertValidity(); @@ -1031,10 +1050,8 @@ function signatureG1ToRawBytes(point: ProjPointType) { const { x, y } = point.toAffine(); if (isZero) return COMPRESSED_ZERO.slice(); const P = Fp.ORDER; - let num; - num = bitSet(x, C_BIT_POS, Boolean((y * _2n) / P)); // set aflag - num = bitSet(num, S_BIT_POS, true); - return numberToBytesBE(num, Fp.BYTES); + const sort = Boolean((y * _2n) / P); + return setMask(numberToBytesBE(x, Fp.BYTES), { compressed: true, sort }); } function signatureG2ToRawBytes(point: ProjPointType) { @@ -1047,10 +1064,12 @@ function signatureG2ToRawBytes(point: ProjPointType) { const { re: x0, im: x1 } = Fp2.reim(x); const { re: y0, im: y1 } = Fp2.reim(y); const tmp = y1 > _0n ? y1 * _2n : y0 * _2n; - const aflag1 = Boolean((tmp / Fp.ORDER) & _1n); - const z1 = bitSet(bitSet(x1, 381, aflag1), S_BIT_POS, true); + const sort = Boolean((tmp / Fp.ORDER) & _1n); const z2 = x0; - return concatB(numberToBytesBE(z1, len), numberToBytesBE(z2, len)); + return concatB( + setMask(numberToBytesBE(x1, len), { sort, compressed: true }), + numberToBytesBE(z2, len) + ); } // To verify curve parameters, see pairing-friendly-curves spec: @@ -1131,26 +1150,30 @@ export const bls12_381: CurveFn = bls({ return isogenyMapG1(x, y); }, fromBytes: (bytes: Uint8Array): AffinePoint => { - bytes = bytes.slice(); - if (bytes.length === 48) { + const { compressed, infinity, sort, value } = parseMask(bytes); + if (value.length === 48 && compressed) { // TODO: Fp.bytes const P = Fp.ORDER; - const compressedValue = bytesToNumberBE(bytes); - const bflag = bitGet(compressedValue, I_BIT_POS); + const compressedValue = bytesToNumberBE(value); // Zero - if (bflag === _1n) return { x: _0n, y: _0n }; const x = Fp.create(compressedValue & Fp.MASK); + if (infinity) { + if (x !== _0n) throw new Error('G1: non-empty compressed point at infinity'); + return { x: _0n, y: _0n }; + } const right = Fp.add(Fp.pow(x, _3n), Fp.create(bls12_381.params.G1b)); // y² = x³ + b let y = Fp.sqrt(right); if (!y) throw new Error('Invalid compressed G1 point'); - const aflag = bitGet(compressedValue, C_BIT_POS); - if ((y * _2n) / P !== aflag) y = Fp.neg(y); + if ((y * _2n) / P !== BigInt(sort)) y = Fp.neg(y); return { x: Fp.create(x), y: Fp.create(y) }; - } else if (bytes.length === 96) { + } else if (value.length === 96 && !compressed) { // Check if the infinity flag is set - if ((bytes[0] & (1 << 6)) !== 0) return bls12_381.G1.ProjectivePoint.ZERO.toAffine(); - const x = bytesToNumberBE(bytes.subarray(0, Fp.BYTES)); - const y = bytesToNumberBE(bytes.subarray(Fp.BYTES)); + const x = bytesToNumberBE(value.subarray(0, Fp.BYTES)); + const y = bytesToNumberBE(value.subarray(Fp.BYTES)); + if (infinity) { + if (x !== _0n || y !== _0n) throw new Error('G1: non-empty point at infinity'); + return bls12_381.G1.ProjectivePoint.ZERO.toAffine(); + } return { x: Fp.create(x), y: Fp.create(y) }; } else { throw new Error('Invalid point G1, expected 48/96 bytes'); @@ -1162,10 +1185,8 @@ export const bls12_381: CurveFn = bls({ if (isCompressed) { if (isZero) return COMPRESSED_ZERO.slice(); const P = Fp.ORDER; - let num; - num = bitSet(x, C_BIT_POS, Boolean((y * _2n) / P)); // set aflag - num = bitSet(num, S_BIT_POS, true); - return numberToBytesBE(num, Fp.BYTES); + const sort = Boolean((y * _2n) / P); + return setMask(numberToBytesBE(x, Fp.BYTES), { compressed: true, sort }); } else { if (isZero) { // 2x PUBLIC_KEY_LENGTH @@ -1178,18 +1199,16 @@ export const bls12_381: CurveFn = bls({ }, ShortSignature: { fromHex(hex: Hex): ProjPointType { - const bytes = ensureBytes('signatureHex', hex, 48); - + const { infinity, sort, value } = parseMask(ensureBytes('signatureHex', hex, 48)); const P = Fp.ORDER; - const compressedValue = bytesToNumberBE(bytes); - const bflag = bitGet(compressedValue, I_BIT_POS); + const compressedValue = bytesToNumberBE(value); // Zero - if (bflag === _1n) return bls12_381.G1.ProjectivePoint.ZERO; + if (infinity) return bls12_381.G1.ProjectivePoint.ZERO; const x = Fp.create(compressedValue & Fp.MASK); const right = Fp.add(Fp.pow(x, _3n), Fp.create(bls12_381.params.G1b)); // y² = x³ + b let y = Fp.sqrt(right); if (!y) throw new Error('Invalid compressed G1 point'); - const aflag = bitGet(compressedValue, C_BIT_POS); + const aflag = BigInt(sort); if ((y * _2n) / P !== aflag) y = Fp.neg(y); const point = bls12_381.G1.ProjectivePoint.fromAffine({ x, y }); point.assertValidity(); @@ -1273,45 +1292,45 @@ export const bls12_381: CurveFn = bls({ return Q; // [x²-x-1]P + [x-1]Ψ(P) + Ψ²(2P) }, fromBytes: (bytes: Uint8Array): AffinePoint => { - bytes = bytes.slice(); - const m_byte = bytes[0] & 0xe0; - if (m_byte === 0x20 || m_byte === 0x60 || m_byte === 0xe0) { - throw new Error('Invalid encoding flag: ' + m_byte); + const { compressed, infinity, sort, value } = parseMask(bytes); + if ( + (!compressed && !infinity && sort) || // 00100000 + (!compressed && infinity && sort) || // 01100000 + (sort && infinity && compressed) // 11100000 + ) { + throw new Error('Invalid encoding flag: ' + (bytes[0] & 0b1110_0000)); } - const bitC = m_byte & 0x80; // compression bit - const bitI = m_byte & 0x40; // point at infinity bit - const bitS = m_byte & 0x20; // sign bit const L = Fp.BYTES; const slc = (b: Uint8Array, from: number, to?: number) => bytesToNumberBE(b.slice(from, to)); - if (bytes.length === 96 && bitC) { + if (value.length === 96 && compressed) { const b = bls12_381.params.G2b; const P = Fp.ORDER; - - bytes[0] = bytes[0] & 0x1f; // clear flags - if (bitI) { + if (infinity) { // check that all bytes are 0 - if (bytes.reduce((p, c) => (p !== 0 ? c + 1 : c), 0) > 0) { + if (value.reduce((p, c) => (p !== 0 ? c + 1 : c), 0) > 0) { throw new Error('Invalid compressed G2 point'); } return { x: Fp2.ZERO, y: Fp2.ZERO }; } - const x_1 = slc(bytes, 0, L); - const x_0 = slc(bytes, L, 2 * L); + const x_1 = slc(value, 0, L); + const x_0 = slc(value, L, 2 * L); const x = Fp2.create({ c0: Fp.create(x_0), c1: Fp.create(x_1) }); const right = Fp2.add(Fp2.pow(x, _3n), b); // y² = x³ + 4 * (u+1) = x³ + b let y = Fp2.sqrt(right); const Y_bit = y.c1 === _0n ? (y.c0 * _2n) / P : (y.c1 * _2n) / P ? _1n : _0n; - y = bitS > 0 && Y_bit > 0 ? y : Fp2.neg(y); + y = sort && Y_bit > 0 ? y : Fp2.neg(y); return { x, y }; - } else if (bytes.length === 192 && !bitC) { - // Check if the infinity flag is set - if ((bytes[0] & (1 << 6)) !== 0) { + } else if (value.length === 192 && !compressed) { + if (infinity) { + if (value.reduce((p, c) => (p !== 0 ? c + 1 : c), 0) > 0) { + throw new Error('Invalid uncompressed G2 point'); + } return { x: Fp2.ZERO, y: Fp2.ZERO }; } - const x1 = slc(bytes, 0, L); - const x0 = slc(bytes, L, 2 * L); - const y1 = slc(bytes, 2 * L, 3 * L); - const y0 = slc(bytes, 3 * L, 4 * L); + const x1 = slc(value, 0, L); + const x0 = slc(value, L, 2 * L); + const y1 = slc(value, 2 * L, 3 * L); + const y0 = slc(value, 3 * L, 4 * L); return { x: Fp2.fromBigTuple([x0, x1]), y: Fp2.fromBigTuple([y0, y1]) }; } else { throw new Error('Invalid point G2, expected 96/192 bytes'); @@ -1324,10 +1343,10 @@ export const bls12_381: CurveFn = bls({ if (isCompressed) { if (isZero) return concatB(COMPRESSED_ZERO, numberToBytesBE(_0n, len)); const flag = Boolean(y.c1 === _0n ? (y.c0 * _2n) / P : (y.c1 * _2n) / P); - // set compressed & sign bits (looks like different offsets than for G1/Fp?) - let x_1 = bitSet(x.c1, C_BIT_POS, flag); - x_1 = bitSet(x_1, S_BIT_POS, true); - return concatB(numberToBytesBE(x_1, len), numberToBytesBE(x.c0, len)); + return concatB( + setMask(numberToBytesBE(x.c1, len), { compressed: true, sort: flag }), + numberToBytesBE(x.c0, len) + ); } else { if (isZero) return concatB(new Uint8Array([0x40]), new Uint8Array(4 * len - 1)); // bytes[0] |= 1 << 6; const { re: x0, im: x1 } = Fp2.reim(x); @@ -1343,17 +1362,15 @@ export const bls12_381: CurveFn = bls({ Signature: { // TODO: Optimize, it's very slow because of sqrt. fromHex(hex: Hex): ProjPointType { - hex = ensureBytes('signatureHex', hex); + const { infinity, sort, value } = parseMask(ensureBytes('signatureHex', hex)); const P = Fp.ORDER; const half = hex.length / 2; if (half !== 48 && half !== 96) throw new Error('Invalid compressed signature length, must be 96 or 192'); - const z1 = bytesToNumberBE(hex.slice(0, half)); - const z2 = bytesToNumberBE(hex.slice(half)); + const z1 = bytesToNumberBE(value.slice(0, half)); + const z2 = bytesToNumberBE(value.slice(half)); // Indicates the infinity point - const bflag1 = bitGet(z1, I_BIT_POS); - if (bflag1 === _1n) return bls12_381.G2.ProjectivePoint.ZERO; - + if (infinity) return bls12_381.G2.ProjectivePoint.ZERO; const x1 = Fp.create(z1 & Fp.MASK); const x2 = Fp.create(z2); const x = Fp2.create({ c0: x2, c1: x1 }); @@ -1365,7 +1382,7 @@ export const bls12_381: CurveFn = bls({ // Choose the y whose leftmost bit of the imaginary part is equal to the a_flag1 // If y1 happens to be zero, then use the bit of y0 const { re: y0, im: y1 } = Fp2.reim(y); - const aflag1 = bitGet(z1, 381); + const aflag1 = BigInt(sort); const isGreater = y1 > _0n && (y1 * _2n) / P !== aflag1; const isZero = y1 === _0n && (y0 * _2n) / P !== aflag1; if (isGreater || isZero) y = Fp2.neg(y); diff --git a/test/bls12-381.test.js b/test/bls12-381.test.js index 6c7ff87..846dbcb 100644 --- a/test/bls12-381.test.js +++ b/test/bls12-381.test.js @@ -5,7 +5,9 @@ import { describe, should } from 'micro-should'; import { wNAF } from '../esm/abstract/curve.js'; import { bytesToHex, utf8ToBytes } from '../esm/abstract/utils.js'; import { hash_to_field } from '../esm/abstract/hash-to-curve.js'; -import { bls12_381 as bls } from '../esm/bls12-381.js'; +import { bls12_381 as bls, bls12_381 } from '../esm/bls12-381.js'; + +import * as utils from '../esm/abstract/utils.js'; import zkVectors from './bls12-381/zkcrypto/converted.json' assert { type: 'json' }; import pairingVectors from './bls12-381/go_pairing_vectors/pairing.json' assert { type: 'json' }; @@ -1415,6 +1417,37 @@ describe('bls12-381 deterministic', () => { } } }); + should(`zkcrypt/G1 & G2 encoding edge cases`, () => { + const Fp = bls12_381.fields.Fp; + const S_BIT_POS = Fp.BITS; // C_bit, compression bit for serialization flag + const I_BIT_POS = Fp.BITS + 1; // I_bit, point-at-infinity bit for serialization flag + const C_BIT_POS = Fp.BITS + 2; // S_bit, sort bit for serialization flag + const VECTORS = [ + { pos: C_BIT_POS, shift: 7 }, // compression_flag_set = Choice::from((bytes[0] >> 7) & 1); + { pos: I_BIT_POS, shift: 6 }, // infinity_flag_set = Choice::from((bytes[0] >> 6) & 1) + { pos: S_BIT_POS, shift: 5 }, // sort_flag_set = Choice::from((bytes[0] >> 5) & 1) + ]; + for (const { pos, shift } of VECTORS) { + const d = utils.numberToBytesBE(utils.bitSet(0n, pos, Boolean(true)), Fp.BYTES); + deepStrictEqual((d[0] >> shift) & 1, 1, `${pos}`); + } + const baseC = G1Point.BASE.toRawBytes(); + deepStrictEqual(baseC.length, 48); + const baseU = G1Point.BASE.toRawBytes(false); + deepStrictEqual(baseU.length, 96); + const compressedBit = baseU.slice(); + compressedBit[0] |= 0b1000_0000; // add compression bit + throws(() => G1Point.fromHex(compressedBit), 'compressed bit'); // uncompressed point with compressed length + const uncompressedBit = baseC.slice(); + uncompressedBit[0] &= 0b0111_1111; // remove compression bit + throws(() => G1Point.fromHex(uncompressedBit), 'uncompressed bit'); + const infinityUncompressed = baseU.slice(); + infinityUncompressed[0] |= 0b0100_0000; + throws(() => G1Point.fromHex(compressedBit), 'infinity uncompressed'); + const infinityCompressed = baseC.slice(); + infinityCompressed[0] |= 0b0100_0000; + throws(() => G1Point.fromHex(compressedBit), 'infinity compressed'); + }); }); // ESM is broken.