From 605bfa93813b59c870a6949fb3c1df9d8baa6e8d Mon Sep 17 00:00:00 2001 From: poma Date: Sat, 25 Jan 2020 15:29:25 +0800 Subject: [PATCH] refactor stuff into lambdas and iterators --- phase2/src/bin/export_keys.rs | 94 +++++++---------------------- phase2/src/bin/generate_verifier.rs | 4 +- phase2/src/bin/prove.rs | 7 +-- phase2/src/circom_circuit.rs | 47 +++++++-------- phase2/src/lib.rs | 2 + 5 files changed, 49 insertions(+), 105 deletions(-) diff --git a/phase2/src/bin/export_keys.rs b/phase2/src/bin/export_keys.rs index bc41091..90cafa2 100644 --- a/phase2/src/bin/export_keys.rs +++ b/phase2/src/bin/export_keys.rs @@ -6,9 +6,12 @@ extern crate serde; extern crate serde_json; extern crate num_bigint; extern crate num_traits; +extern crate itertools; use std::fs; use std::fs::OpenOptions; +use std::iter::repeat; +use itertools::Itertools; use serde::{Deserialize, Serialize}; use phase2::parameters::MPCParameters; use phase2::utils::{ @@ -74,84 +77,31 @@ fn main() { let params = MPCParameters::read(reader, disallow_points_at_infinity, true).expect("unable to read params"); let params = params.get_params(); - let mut proving_key = ProvingKeyJson { - a: vec![], - b1: vec![], - b2: vec![], - c: vec![], - vk_alfa_1: vec![], - vk_beta_1: vec![], - vk_delta_1: vec![], - vk_beta_2: vec![], - vk_delta_2: vec![], - h: vec![], + let proving_key = ProvingKeyJson { + a: params.a.iter().map(|e| p1_to_vec(e)).collect_vec(), + b1: params.b_g1.iter().map(|e| p1_to_vec(e)).collect_vec(), + b2: params.b_g2.iter().map(|e| p2_to_vec(e)).collect_vec(), + c: repeat(None).take(params.vk.ic.len()).chain(params.l.iter().map(|e| Some(p1_to_vec(e)))).collect_vec(), + vk_alfa_1: p1_to_vec(¶ms.vk.alpha_g1), + vk_beta_1: p1_to_vec(¶ms.vk.beta_g1), + vk_delta_1: p1_to_vec(¶ms.vk.delta_g1), + vk_beta_2: p2_to_vec(¶ms.vk.beta_g2), + vk_delta_2: p2_to_vec(¶ms.vk.delta_g2), + h: params.h.iter().map(|e| p1_to_vec(e)).collect_vec(), }; - let a = params.a.clone(); - for e in a.iter() { - proving_key.a.push(p1_to_vec(e)); - } - let b1 = params.b_g1.clone(); - for e in b1.iter() { - proving_key.b1.push(p1_to_vec(e)); - } - let b2 = params.b_g2.clone(); - for e in b2.iter() { - proving_key.b2.push(p2_to_vec(e)); - } - let c = params.l.clone(); - for _ in 0..params.vk.ic.len() { - proving_key.c.push(None); - } - for e in c.iter() { - proving_key.c.push(Some(p1_to_vec(e))); - } - - let vk_alfa_1 = params.vk.alpha_g1.clone(); - proving_key.vk_alfa_1 = p1_to_vec(&vk_alfa_1); - - let vk_beta_1 = params.vk.beta_g1.clone(); - proving_key.vk_beta_1 = p1_to_vec(&vk_beta_1); - - let vk_delta_1 = params.vk.delta_g1.clone(); - proving_key.vk_delta_1 = p1_to_vec(&vk_delta_1); - - let vk_beta_2 = params.vk.beta_g2.clone(); - proving_key.vk_beta_2 = p2_to_vec(&vk_beta_2); - - let vk_delta_2 = params.vk.delta_g2.clone(); - proving_key.vk_delta_2 = p2_to_vec(&vk_delta_2); - - let h = params.h.clone(); - for e in h.iter() { - proving_key.h.push(p1_to_vec(e)); - } - - let mut verification_key = VerifyingKeyJson { - ic: vec![], - vk_alfa_1: vec![], - vk_beta_2: vec![], - vk_gamma_2: vec![], - vk_delta_2: vec![], - vk_alfabeta_12: vec![], + let verification_key = VerifyingKeyJson { + ic: params.vk.ic.iter().map(|e| p1_to_vec(e)).collect_vec(), + vk_alfa_1: p1_to_vec(¶ms.vk.alpha_g1), + vk_beta_2: p2_to_vec(¶ms.vk.beta_g2), + vk_gamma_2: p2_to_vec(¶ms.vk.gamma_g2), + vk_delta_2: p2_to_vec(¶ms.vk.delta_g2), + vk_alfabeta_12: pairing_to_vec(&Bn256::pairing(params.vk.alpha_g1, params.vk.beta_g2)), }; - let ic = params.vk.ic.clone(); - for e in ic.iter() { - verification_key.ic.push(p1_to_vec(e)); - } - - verification_key.vk_alfa_1 = p1_to_vec(&vk_alfa_1); - verification_key.vk_beta_2 = p2_to_vec(&vk_beta_2); - let vk_gamma_2 = params.vk.gamma_g2.clone(); - verification_key.vk_gamma_2 = p2_to_vec(&vk_gamma_2); - verification_key.vk_delta_2 = p2_to_vec(&vk_delta_2); - verification_key.vk_alfabeta_12 = pairing_to_vec(&Bn256::pairing(vk_alfa_1, vk_beta_2)); - let pk_json = serde_json::to_string(&proving_key).unwrap(); - fs::write(pk_filename, pk_json.as_bytes()).unwrap(); - let vk_json = serde_json::to_string(&verification_key).unwrap(); + fs::write(pk_filename, pk_json.as_bytes()).unwrap(); fs::write(vk_filename, vk_json.as_bytes()).unwrap(); println!("Created {} and {}.", pk_filename, vk_filename); diff --git a/phase2/src/bin/generate_verifier.rs b/phase2/src/bin/generate_verifier.rs index d114f20..bc2bb54 100644 --- a/phase2/src/bin/generate_verifier.rs +++ b/phase2/src/bin/generate_verifier.rs @@ -12,6 +12,7 @@ use std::fs; use std::fs::OpenOptions; use num_bigint::BigUint; use num_traits::Num; +use phase2::utils::repr_to_big; use phase2::parameters::MPCParameters; use bellman_ce::pairing::{ Engine, @@ -43,9 +44,6 @@ fn main() { let params = MPCParameters::read(reader, should_filter_points_at_infinity, true).expect("unable to read params"); let vk = ¶ms.get_params().vk; - let repr_to_big = |r| { - BigUint::from_str_radix(&format!("{}", r)[2..], 16).unwrap().to_str_radix(10) - }; let p1_to_str = |p: &::G1Affine| { let x = repr_to_big(p.get_x().into_repr()); let y = repr_to_big(p.get_y().into_repr()); diff --git a/phase2/src/bin/prove.rs b/phase2/src/bin/prove.rs index 90a4755..2c32299 100644 --- a/phase2/src/bin/prove.rs +++ b/phase2/src/bin/prove.rs @@ -4,10 +4,12 @@ extern crate exitcode; extern crate serde; extern crate num_bigint; extern crate num_traits; +extern crate itertools; use std::fs; use std::fs::OpenOptions; use serde::{Deserialize, Serialize}; +use itertools::Itertools; use phase2::parameters::MPCParameters; use phase2::circom_circuit::CircomCircuit; use phase2::utils::{ @@ -78,10 +80,7 @@ fn main() { let proof_json = serde_json::to_string(&proof).unwrap(); fs::write(proof_filename, proof_json.as_bytes()).unwrap(); - let mut public_inputs = vec![]; - for x in input[1..].iter() { - public_inputs.push(repr_to_big(x.into_repr())); - } + let public_inputs = input[1..].iter().map(|x| repr_to_big(x.into_repr())).collect_vec(); let public_json = serde_json::to_string(&public_inputs).unwrap(); fs::write(public_filename, public_json.as_bytes()).unwrap(); diff --git a/phase2/src/circom_circuit.rs b/phase2/src/circom_circuit.rs index d981e21..00e4c4d 100644 --- a/phase2/src/circom_circuit.rs +++ b/phase2/src/circom_circuit.rs @@ -6,6 +6,7 @@ use std::str; use std::fs; use std::fs::OpenOptions; use std::collections::BTreeMap; +use itertools::Itertools; use std::io::{ Read, Write, @@ -83,18 +84,13 @@ impl<'a, E: Engine> CircomCircuit { let num_inputs = circuit_json.num_inputs + circuit_json.num_outputs + 1; let num_aux = circuit_json.num_variables - num_inputs; - fn convert_constraint(lc: &BTreeMap) -> Vec<(usize, EE::Fr)> { - let mut coeffs = vec![]; - for (var_index_str, coefficient_str) in lc { - coeffs.push((var_index_str.parse().unwrap(), EE::Fr::from_str(coefficient_str).unwrap())); - } - return coeffs; - } + let convert_constraint = |lc: &BTreeMap| { + lc.iter().map(|(index, coeff)| (index.parse().unwrap(), E::Fr::from_str(coeff).unwrap())).collect_vec() + }; - let mut constraints = vec![]; - for constraint in circuit_json.constraints.iter() { - constraints.push((convert_constraint::(&constraint[0]), convert_constraint::(&constraint[1]), convert_constraint::(&constraint[2]))); - } + let constraints = circuit_json.constraints.iter().map( + |c| (convert_constraint(&c[0]), convert_constraint(&c[1]), convert_constraint(&c[2])) + ).collect_vec(); return CircomCircuit { num_inputs: num_inputs, @@ -130,23 +126,22 @@ impl<'a, E: Engine> Circuit for CircomCircuit { })?; } - fn make_lc(lc_data: Vec<(usize, E::Fr)>, num_inputs: usize) -> LinearCombination { - let mut lc = LinearCombination::::zero(); - for (index, coeff) in lc_data { - let var_index = if index < num_inputs { - Index::Input(index) - } else { - Index::Aux(index - num_inputs) - }; - lc = lc + (coeff, Variable::new_unchecked(var_index)) - } - return lc; - } + let make_index = |index| + if index < self.num_inputs { + Index::Input(index) + } else { + Index::Aux(index - self.num_inputs) + }; + let make_lc = |lc_data: Vec<(usize, E::Fr)>| + lc_data.iter().fold( + LinearCombination::::zero(), + |lc: LinearCombination, (index, coeff)| lc + (*coeff, Variable::new_unchecked(make_index(*index))) + ); for (i, constraint) in self.constraints.iter().enumerate() { cs.enforce(|| format!("constraint {}", i), - |_| make_lc(constraint.0.clone(), self.num_inputs), - |_| make_lc(constraint.1.clone(), self.num_inputs), - |_| make_lc(constraint.2.clone(), self.num_inputs), + |_| make_lc(constraint.0.clone()), + |_| make_lc(constraint.1.clone()), + |_| make_lc(constraint.2.clone()), ); } Ok(()) diff --git a/phase2/src/lib.rs b/phase2/src/lib.rs index 83d23e5..a6cdd47 100644 --- a/phase2/src/lib.rs +++ b/phase2/src/lib.rs @@ -11,6 +11,8 @@ extern crate crossbeam; extern crate num_bigint; extern crate num_traits; extern crate cfg_if; +extern crate itertools; + use cfg_if::cfg_if; pub mod keypair;