Montgomery multiplication optimized

This commit is contained in:
Jordi Baylina 2019-06-25 15:51:30 +02:00
parent 4e6f320667
commit 44f11945f1
No known key found for this signature in database
GPG Key ID: 7480C80C1BE43112
7 changed files with 274 additions and 9 deletions

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

@ -32,7 +32,7 @@
"mocha": "^6.1.4", "mocha": "^6.1.4",
"package": "^1.0.1", "package": "^1.0.1",
"snarkjs": "^0.1.12", "snarkjs": "^0.1.12",
"wasmbuilder": "0.0.2" "wasmbuilder": "0.0.3"
}, },
"dependencies": { "dependencies": {
"big-integer": "^1.6.42" "big-integer": "^1.6.42"

@ -187,17 +187,224 @@ module.exports = function buildF1m(module, _q, _prefix, _intPrefix) {
} }
function buildMul() {
const pAux2 = module.alloc(n8*2); function buildMul() {
const f = module.addFunction(prefix+"_mul"); const f = module.addFunction(prefix+"_mul");
f.addParam("x", "i32"); f.addParam("x", "i32");
f.addParam("y", "i32"); f.addParam("y", "i32");
f.addParam("r", "i32"); f.addParam("r", "i32");
f.addLocal("c0", "i64");
f.addLocal("c1", "i64");
f.addLocal("np32", "i64");
for (let i=0;i<n32; i++) {
f.addLocal("x"+i, "i64");
f.addLocal("y"+i, "i64");
f.addLocal("m"+i, "i64");
f.addLocal("q"+i, "i64");
}
const c = f.getCodeBuilder(); const c = f.getCodeBuilder();
f.addCode(c.call(intPrefix + "_mul", c.getLocal("x"), c.getLocal("y"), c.i32_const(pAux2) ));
const np32 = bigInt("100000000",16).minus( q.modInv(bigInt("100000000",16))).toJSNumber();
f.addCode(c.setLocal("np32", c.i64_const(np32)));
const loadX = [];
const loadY = [];
const loadQ = [];
function mulij(i, j) {
let X,Y;
if (!loadX[i]) {
X = c.teeLocal("x"+i, c.i64_load32_u( c.getLocal("x"), i*4));
loadX[i] = true;
} else {
X = c.getLocal("x"+i);
}
if (!loadY[j]) {
Y = c.teeLocal("y"+j, c.i64_load32_u( c.getLocal("y"), j*4));
loadY[j] = true;
} else {
Y = c.getLocal("y"+j);
}
return c.i64_mul( X, Y );
}
function mulqm(i, j) {
let Q,M;
if (!loadQ[i]) {
Q = c.teeLocal("q"+i, c.i64_load32_u(c.i32_const(0), pq+i*4 ));
loadQ[i] = true;
} else {
Q = c.getLocal("q"+i);
}
M = c.getLocal("m"+j);
return c.i64_mul( Q, M );
}
let c0 = "c0";
let c1 = "c1";
for (let k=0; k<n32*2-1; k++) {
for (let i=Math.max(0, k-n32+1); (i<=k)&&(i<n32); i++) {
const j= k-i;
f.addCode(
c.setLocal(c0,
c.i64_add(
c.i64_and(
c.getLocal(c0),
c.i64_const(0xFFFFFFFF)
),
mulij(i,j)
)
)
);
f.addCode(
c.setLocal(c1,
c.i64_add(
c.getLocal(c1),
c.i64_shr_u(
c.getLocal(c0),
c.i64_const(32)
)
)
)
);
}
for (let i=Math.max(1, k-n32+1); (i<=k)&&(i<n32); i++) {
const j= k-i;
f.addCode(
c.setLocal(c0,
c.i64_add(
c.i64_and(
c.getLocal(c0),
c.i64_const(0xFFFFFFFF)
),
mulqm(i,j)
)
)
);
f.addCode(
c.setLocal(c1,
c.i64_add(
c.getLocal(c1),
c.i64_shr_u(
c.getLocal(c0),
c.i64_const(32)
)
)
)
);
}
if (k<n32) {
f.addCode(
c.setLocal(
"m"+k,
c.i64_and(
c.i64_mul(
c.i64_and(
c.getLocal(c0),
c.i64_const(0xFFFFFFFF)
),
c.getLocal("np32")
),
c.i64_const("0xFFFFFFFF")
)
)
);
f.addCode(
c.setLocal(c0,
c.i64_add(
c.i64_and(
c.getLocal(c0),
c.i64_const(0xFFFFFFFF)
),
mulqm(0,k)
)
)
);
f.addCode(
c.setLocal(c1,
c.i64_add(
c.getLocal(c1),
c.i64_shr_u(
c.getLocal(c0),
c.i64_const(32)
)
)
)
);
}
if (k>=n32) {
f.addCode(
c.i64_store32(
c.getLocal("r"),
(k-n32)*4,
c.getLocal(c0)
)
);
}
[c0, c1] = [c1, c0];
f.addCode(
c.setLocal(c1,
c.i64_shr_u(
c.getLocal(c0),
c.i64_const(32)
)
)
);
}
f.addCode(
c.i64_store32(
c.getLocal("r"),
n32*4-4,
c.getLocal(c0)
)
);
f.addCode(
c.if(
c.i32_wrap_i64(c.getLocal(c1)),
c.drop(c.call(intPrefix+"_sub", c.getLocal("r"), c.i32_const(pq), c.getLocal("r"))),
c.if(
c.call(intPrefix+"_gte", c.getLocal("r"), c.i32_const(pq) ),
c.drop(c.call(intPrefix+"_sub", c.getLocal("r"), c.i32_const(pq), c.getLocal("r"))),
)
)
);
}
function buildMulOld() {
const pAux2 = module.alloc(n8*2);
const f = module.addFunction(prefix+"_mulOld");
f.addParam("x", "i32");
f.addParam("y", "i32");
f.addParam("r", "i32");
const c = f.getCodeBuilder();
f.addCode(c.call(intPrefix + "_mulOld", c.getLocal("x"), c.getLocal("y"), c.i32_const(pAux2) ));
f.addCode(c.call(prefix + "_mReduct", c.i32_const(pAux2), c.getLocal("r"))); f.addCode(c.call(prefix + "_mReduct", c.i32_const(pAux2), c.getLocal("r")));
} }
@ -242,6 +449,7 @@ module.exports = function buildF1m(module, _q, _prefix, _intPrefix) {
buildNeg(); buildNeg();
buildMReduct(); buildMReduct();
buildMul(); buildMul();
buildMulOld();
buildToMontgomery(); buildToMontgomery();
buildFromMontgomery(); buildFromMontgomery();
buildInverse(); buildInverse();
@ -250,6 +458,7 @@ module.exports = function buildF1m(module, _q, _prefix, _intPrefix) {
module.exportFunction(prefix + "_neg"); module.exportFunction(prefix + "_neg");
module.exportFunction(prefix + "_mReduct"); module.exportFunction(prefix + "_mReduct");
module.exportFunction(prefix + "_mul"); module.exportFunction(prefix + "_mul");
module.exportFunction(prefix + "_mulOld");
module.exportFunction(prefix + "_fromMontgomery"); module.exportFunction(prefix + "_fromMontgomery");
module.exportFunction(prefix + "_toMontgomery"); module.exportFunction(prefix + "_toMontgomery");
module.exportFunction(prefix + "_inverse"); module.exportFunction(prefix + "_inverse");

@ -2,6 +2,9 @@ const assert = require("assert");
const bigInt = require("big-integer"); const bigInt = require("big-integer");
const buildF1 = require("../index.js").buildF1; const buildF1 = require("../index.js").buildF1;
const buildF1m = require("../src/build_f1m");
const buildProtoboard = require("../src/protoboard.js");
const buildTest = require("../src/build_test.js");
describe("Basic tests for Zq", () => { describe("Basic tests for Zq", () => {
it("It should do a basic addition", async () => { it("It should do a basic addition", async () => {
@ -482,4 +485,57 @@ describe("Basic tests for Zq", () => {
assert(a.equals(v[i])); assert(a.equals(v[i]));
} }
}); });
it("It should profile int", async () => {
let start,end,time;
const q = bigInt("21888242871839275222246405745257275088548364400416034343698204186575808495617");
const A=q.minus(1);
const B=q.minus(1).shiftRight(1);
const pbF1m = await buildProtoboard((module) => {
buildF1m(module, q);
buildTest(module, "f1m_mul");
buildTest(module, "f1m_mulOld");
}, 32);
const pA = pbF1m.alloc();
const pB = pbF1m.alloc();
const pC = pbF1m.alloc();
pbF1m.set(pA, A);
pbF1m.f1m_toMontgomery(pA, pA);
pbF1m.set(pB, B);
pbF1m.f1m_toMontgomery(pB, pB);
start = new Date().getTime();
pbF1m.test_f1m_mul(pA, pB, pC, 50000000);
end = new Date().getTime();
time = end - start;
pbF1m.f1m_fromMontgomery(pC, pC);
const c1 = pbF1m.get(pC, 1, 32);
assert(c1.equals(A.times(B).mod(q)));
console.log("Mul Time (ms): " + time);
start = new Date().getTime();
pbF1m.test_f1m_mulOld(pA, pB, pC, 50000000);
end = new Date().getTime();
time = end - start;
pbF1m.f1m_fromMontgomery(pC, pC);
const c2 = pbF1m .get(pC, 1, 32);
assert(c2.equals(A.times(B).mod(q)));
console.log("Mul Old Time (ms): " + time);
}).timeout(10000000);
}); });

@ -37,8 +37,8 @@ function buildWasm() {
buildCurve(moduleBuilder, "g2", "f2m"); buildCurve(moduleBuilder, "g2", "f2m");
buildMultiexp(moduleBuilder, "g2", "g2", "f2m", "fr"); buildMultiexp(moduleBuilder, "g2", "g2", "f2m", "fr");
buildTest(moduleBuilder, "int_mul"); buildTest(moduleBuilder, "f1m_mul");
buildTest(moduleBuilder, "int_mulOld"); buildTest(moduleBuilder, "f1m_mulOld");
const code = moduleBuilder.build(); const code = moduleBuilder.build();