diff --git a/src/curves/bls381/ec.rs b/src/curves/bls381/ec.rs index dcbaee0..06ade0b 100644 --- a/src/curves/bls381/ec.rs +++ b/src/curves/bls381/ec.rs @@ -44,6 +44,9 @@ macro_rules! curve_impl { } impl Group<$engine> for $name { + fn group_zero(e: &$engine) -> $name { + $name::zero(e) + } fn group_mul_assign(&mut self, e: &$engine, scalar: &$scalarfield) { self.mul_assign(e, scalar); } diff --git a/src/curves/bls381/mod.rs b/src/curves/bls381/mod.rs index 8dd956a..d100da7 100644 --- a/src/curves/bls381/mod.rs +++ b/src/curves/bls381/mod.rs @@ -97,6 +97,9 @@ fp_impl!( ); impl Group for Fr { + fn group_zero(_: &Bls381) -> Fr { + Fr::zero() + } fn group_mul_assign(&mut self, e: &Bls381, scalar: &Fr) { self.mul_assign(e, scalar); } diff --git a/src/curves/mod.rs b/src/curves/mod.rs index c1414e3..60d2a94 100644 --- a/src/curves/mod.rs +++ b/src/curves/mod.rs @@ -9,7 +9,7 @@ use super::{Cow, Convert}; pub mod bls381; -pub trait Engine: Sized + Clone +pub trait Engine: Sized + Clone + Send + Sync { type Fq: PrimeField; type Fr: SnarkField; @@ -46,8 +46,9 @@ pub trait Engine: Sized + Clone fn batch_baseexp, S: AsRef<[Self::Fr]>>(&self, table: &WindowTable>, scalars: S) -> Vec; } -pub trait Group: Copy +pub trait Group: Copy + Send + Sync + Sized { + fn group_zero(&E) -> Self; fn group_mul_assign(&mut self, &E, scalar: &E::Fr); fn group_add_assign(&mut self, &E, other: &Self); fn group_sub_assign(&mut self, &E, other: &Self); diff --git a/src/groth16/domain.rs b/src/groth16/domain.rs index 490e9c5..6db3acc 100644 --- a/src/groth16/domain.rs +++ b/src/groth16/domain.rs @@ -1,4 +1,6 @@ use curves::{Engine, Field, SnarkField, PrimeField, Group}; +use crossbeam; +use num_cpus; pub struct EvaluationDomain { pub m: u64, @@ -49,19 +51,36 @@ impl EvaluationDomain { pub fn ifft>(&self, e: &E, v: &mut [T]) { assert!(v.len() == self.m as usize); - self._fft(e, v, &self.omegainv); - for v in v { - v.group_mul_assign(e, &self.minv); - } + parallel_fft(e, v, &self.omegainv, self.exp); + + let chunk = (v.len() / num_cpus::get()) + 1; + + crossbeam::scope(|scope| { + for v in v.chunks_mut(chunk) { + scope.spawn(move || { + for v in v { + v.group_mul_assign(e, &self.minv); + } + }); + } + }); } fn mul_coset(&self, e: &E, v: &mut [E::Fr], g: &E::Fr) { - let mut u = *g; - for v in v.iter_mut().skip(1) { - v.mul_assign(e, &u); - u.mul_assign(e, g); - } + let chunk = (v.len() / num_cpus::get()) + 1; + + crossbeam::scope(|scope| { + for (i, v) in v.chunks_mut(chunk).enumerate() { + scope.spawn(move || { + let mut u = g.pow(e, &[(i * chunk) as u64]); + for v in v.iter_mut() { + v.mul_assign(e, &u); + u.mul_assign(e, g); + } + }); + } + }); } pub fn coset_fft(&self, e: &E, v: &mut [E::Fr]) @@ -79,59 +98,119 @@ impl EvaluationDomain { pub fn divide_by_z_on_coset(&self, e: &E, v: &mut [E::Fr]) { let i = self.z(e, &E::Fr::multiplicative_generator(e)).inverse(e).unwrap(); - for v in v { - v.mul_assign(e, &i); - } + + let chunk = (v.len() / num_cpus::get()) + 1; + + crossbeam::scope(|scope| { + for v in v.chunks_mut(chunk) { + scope.spawn(move || { + for v in v { + v.mul_assign(e, &i); + } + }); + } + }); } pub fn fft>(&self, e: &E, a: &mut [T]) { - self._fft(e, a, &self.omega); + parallel_fft(e, a, &self.omega, self.exp); + } +} + +fn parallel_fft>(e: &E, a: &mut [T], omega: &E::Fr, log_n: u64) +{ + let log_cpus = get_log_cpus(); + let num_cpus = 1 << log_cpus; + + if log_n < log_cpus { + serial_fft(e, a, omega, log_n) + } else { + // Shuffle + let log_new_n = log_n - log_cpus; + let mut tmp = vec![vec![T::group_zero(e); 1 << log_new_n]; num_cpus]; + let omega_num_cpus = omega.pow(e, &[num_cpus as u64]); + + crossbeam::scope(|scope| { + let a = &*a; + + for (j, tmp) in tmp.iter_mut().enumerate() { + scope.spawn(move || { + let omega_j = omega.pow(e, &[j as u64]); + let omega_step = omega.pow(e, &[(j as u64) << log_new_n]); + + let mut elt = E::Fr::one(e); + for i in 0..(1 << log_new_n) { + for s in 0..num_cpus { + let idx = (i + (s << log_new_n)) % (1 << log_n); + let mut t = a[idx]; + t.group_mul_assign(e, &elt); + tmp[i].group_add_assign(e, &t); + elt.mul_assign(e, &omega_step); + } + elt.mul_assign(e, &omega_j); + } + + serial_fft(e, tmp, &omega_num_cpus, log_new_n); + }); + } + }); + + // TODO: parallelize + // Unshuffle + for i in 0..num_cpus { + for j in 0..(1 << log_new_n) { + a[(j << log_cpus) + i] = tmp[i][j]; + } + } + } +} + +fn serial_fft>(e: &E, a: &mut [T], omega: &E::Fr, log_n: u64) +{ + fn bitreverse(mut n: usize, l: u64) -> usize { + let mut r = 0; + for _ in 0..l { + r = (r << 1) | (n & 1); + n >>= 1; + } + r } - fn _fft>(&self, e: &E, a: &mut [T], omega: &E::Fr) - { - fn bitreverse(mut n: usize, l: u64) -> usize { - let mut r = 0; - for _ in 0..l { - r = (r << 1) | (n & 1); - n >>= 1; - } - r + let n = a.len(); + assert_eq!(n, 1 << log_n); + + for k in 0..n { + let rk = bitreverse(k, log_n); + if k < rk { + let tmp1 = a[rk]; + let tmp2 = a[k]; + a[rk] = tmp2; + a[k] = tmp1; } + } - for k in 0..(self.m as usize) { - let rk = bitreverse(k, self.exp); - if k < rk { - let tmp1 = a[rk]; - let tmp2 = a[k]; - a[rk] = tmp2; - a[k] = tmp1; - } - } + let mut m = 1; + for _ in 0..log_n { + let w_m = omega.pow(e, &[(n / (2*m)) as u64]); - let mut m = 1; - for _ in 0..self.exp { - let w_m = omega.pow(e, &[(self.m / (2*m)) as u64]); - - let mut k = 0; - while k < self.m { - let mut w = E::Fr::one(e); - for j in 0..m { - let mut t = a[(k+j+m) as usize]; - t.group_mul_assign(e, &w); - let mut tmp = a[(k+j) as usize]; - tmp.group_sub_assign(e, &t); - a[(k+j+m) as usize] = tmp; - a[(k+j) as usize].group_add_assign(e, &t); - w.mul_assign(e, &w_m); - } - - k += 2*m; + let mut k = 0; + while k < n { + let mut w = E::Fr::one(e); + for j in 0..m { + let mut t = a[(k+j+m) as usize]; + t.group_mul_assign(e, &w); + let mut tmp = a[(k+j) as usize]; + tmp.group_sub_assign(e, &t); + a[(k+j+m) as usize] = tmp; + a[(k+j) as usize].group_add_assign(e, &t); + w.mul_assign(e, &w_m); } - m *= 2; + k += 2*m; } + + m *= 2; } } @@ -190,3 +269,32 @@ fn polynomial_arith() { test_mul(e, rng); } + +fn get_log_cpus() -> u64 { + let num = num_cpus::get(); + log2_floor(num) +} + +fn log2_floor(num: usize) -> u64 { + assert!(num > 0); + + let mut pow = 0; + + while (1 << (pow+1)) <= num { + pow += 1; + } + + pow +} + +#[test] +fn test_log2_floor() { + assert_eq!(log2_floor(1), 0); + assert_eq!(log2_floor(2), 1); + assert_eq!(log2_floor(3), 1); + assert_eq!(log2_floor(4), 2); + assert_eq!(log2_floor(5), 2); + assert_eq!(log2_floor(6), 2); + assert_eq!(log2_floor(7), 2); + assert_eq!(log2_floor(8), 3); +}