Added primitive circuit abstraction, tests for sha3.

This commit is contained in:
Sean Bowe 2016-01-28 20:37:54 -07:00
parent 67003a471b
commit e24fcfdc5c
7 changed files with 309 additions and 53 deletions

@ -5,6 +5,7 @@ use std::cell::RefCell;
use super::variable::*; use super::variable::*;
use self::Bit::*; use self::Bit::*;
use self::Op::*; use self::Op::*;
use super::circuit::*;
macro_rules! mirror { macro_rules! mirror {
($a:pat, $b:pat) => (($a, $b) | ($b, $a)) ($a:pat, $b:pat) => (($a, $b) | ($b, $a))
@ -206,6 +207,59 @@ pub enum Bit {
Bin(BinaryOp, bool) Bin(BinaryOp, bool)
} }
struct BitEquality {
a: Bit,
b: Var
}
impl Constrainable for BitEquality {
type Result = Var;
fn synthesize(&self, enforce: &Bit) -> Var {
// TODO: currently only support unconditional enforcement
match enforce {
&Bit::Constant(true) => {},
_ => unimplemented!()
}
match self.a {
Bin(ref binop, inverted) => {
// TODO: figure this out later
assert!(binop.resolved.borrow().is_none());
let mut op = binop.op;
if inverted {
op = op.not();
}
gadget(&[&binop.a, &binop.b, &self.b], 0, move |vals| {
let a = vals.get_input(0);
let b = vals.get_input(1);
unsafe { vals.set_input(2, op.val(a, b)) };
}, |i, o, cs| {
cs.push(binaryop_constraint(i[0], i[1], i[2], op));
vec![i[2]]
}).remove(0)
},
_ => unimplemented!()
}
}
}
impl Equals<Var> for Bit {
type Result = BitEquality;
fn must_equal(&self, other: &Var) -> BitEquality {
BitEquality {
a: self.clone(),
b: other.clone()
}
}
}
fn binaryop_constraint(a: &Var, b: &Var, c: &Var, op: Op) -> Constraint { fn binaryop_constraint(a: &Var, b: &Var, c: &Var, op: Op) -> Constraint {
match op { match op {
// a * b = c // a * b = c
@ -286,6 +340,24 @@ fn resolve(a: &Var, b: &Var, op: Op) -> Var {
}).remove(0) }).remove(0)
} }
impl ConstraintWalker for Bit {
fn walk(&self, counter: &mut usize, constraints: &mut Vec<Constraint>, witness_map: &mut WitnessMap)
{
match *self {
Constant(_) => {},
Not(ref v) => {
v.walk(counter, constraints, witness_map);
},
Is(ref v) => {
v.walk(counter, constraints, witness_map);
},
Bin(ref bin, _) => {
bin.walk(counter, constraints, witness_map);
}
}
}
}
impl Bit { impl Bit {
pub fn val(&self, map: &[FieldT]) -> bool { pub fn val(&self, map: &[FieldT]) -> bool {
match *self { match *self {
@ -304,21 +376,6 @@ impl Bit {
} }
} }
pub fn walk(&self, counter: &mut usize, constraints: &mut Vec<Constraint>, witness_map: &mut WitnessMap) {
match *self {
Constant(_) => {},
Not(ref v) => {
v.walk(counter, constraints, witness_map);
},
Is(ref v) => {
v.walk(counter, constraints, witness_map);
},
Bin(ref bin, _) => {
bin.walk(counter, constraints, witness_map);
}
}
}
pub fn new(v: &Var) -> Bit { pub fn new(v: &Var) -> Bit {
Is(gadget(&[v], 0, |_| {}, |i, o, cs| { Is(gadget(&[v], 0, |_| {}, |i, o, cs| {
// boolean constraint: // boolean constraint:

148
src/circuit.rs Normal file

@ -0,0 +1,148 @@
use tinysnark::{Proof, Keypair, FieldT, LinearTerm, ConstraintSystem};
use super::variable::{Var,Constraint,WitnessMap,witness_field_elements};
use super::bit::Bit;
pub trait ConstraintWalker: 'static {
fn walk(&self,
counter: &mut usize,
constraints: &mut Vec<Constraint>,
witness_map: &mut WitnessMap);
}
impl<C: ConstraintWalker> ConstraintWalker for Vec<C> {
fn walk(&self,
counter: &mut usize,
constraints: &mut Vec<Constraint>,
witness_map: &mut WitnessMap)
{
for i in self {
i.walk(counter, constraints, witness_map);
}
}
}
pub trait Constrainable {
type Result: ConstraintWalker;
fn synthesize(&self, enforce: &Bit) -> Self::Result;
}
impl<C: Constrainable> Constrainable for Vec<C> {
type Result = Vec<C::Result>;
fn synthesize(&self, enforce: &Bit) -> Vec<C::Result> {
self.iter().map(|a| a.synthesize(enforce)).collect()
}
}
pub trait Equals<Rhs: ?Sized> {
type Result: Constrainable;
fn must_equal(&self, other: &Rhs) -> Self::Result;
}
impl<Lhs, Rhs> Equals<[Rhs]> for [Lhs] where Lhs: Equals<Rhs> {
type Result = Vec<Lhs::Result>;
fn must_equal(&self, other: &[Rhs]) -> Vec<Lhs::Result> {
assert_eq!(self.len(), other.len());
self.iter().zip(other.iter()).map(|(a, b)| a.must_equal(b)).collect()
}
}
pub struct Circuit {
public_inputs: usize,
private_inputs: usize,
aux_inputs: usize,
keypair: Keypair,
witness_map: WitnessMap
}
impl Circuit {
pub fn verify(&self, proof: &Proof, public: &[FieldT]) -> bool
{
proof.verify(&self.keypair, public)
}
pub fn prove(&self, public: &[FieldT], private: &[FieldT]) -> Result<Proof, ()>
{
assert_eq!(public.len(), self.public_inputs);
assert_eq!(private.len(), self.private_inputs);
let mut vars = Vec::new();
vars.push(FieldT::one());
vars.extend_from_slice(public);
vars.extend_from_slice(private);
for i in 0..self.aux_inputs {
vars.push(FieldT::zero());
}
witness_field_elements(&mut vars, &self.witness_map);
let primary = &vars[1..public.len()+1];
let aux = &vars[1+public.len()..];
if !self.keypair.is_satisfied(primary, aux) {
return Err(())
}
Ok(Proof::new(&self.keypair, primary, aux))
}
}
pub struct CircuitBuilder {
public_inputs: usize,
private_inputs: usize,
constraints: Vec<Box<ConstraintWalker>>
}
impl CircuitBuilder {
pub fn new(num_public: usize, num_private: usize) -> (Vec<Var>, Vec<Var>, CircuitBuilder) {
(
(0..num_public).map(|x| Var::new(1+x)).collect(),
(0..num_private).map(|x| Var::new(1+num_public+x)).collect(),
CircuitBuilder {
public_inputs: num_public,
private_inputs: num_private,
constraints: Vec::new()
},
)
}
pub fn constrain<C: Constrainable>(&mut self, constraint: C) {
self.constraints.push(Box::new(constraint.synthesize(&Bit::constant(true))));
}
pub fn finalize(self) -> Circuit {
let mut counter = 1 + self.public_inputs + self.private_inputs;
let mut constraints = vec![];
let mut witness_map = WitnessMap::new();
for c in self.constraints.into_iter() {
c.walk(&mut counter, &mut constraints, &mut witness_map);
}
let mut cs = ConstraintSystem::new(self.public_inputs, (counter - 1) - self.public_inputs);
for Constraint(a, b, c) in constraints {
let a: Vec<_> = a.into_iter().map(|x| LinearTerm { coeff: x.0, index: x.1.index() }).collect();
let b: Vec<_> = b.into_iter().map(|x| LinearTerm { coeff: x.0, index: x.1.index() }).collect();
let c: Vec<_> = c.into_iter().map(|x| LinearTerm { coeff: x.0, index: x.1.index() }).collect();
cs.add_constraint(&a, &b, &c);
}
let kp = Keypair::new(&cs);
Circuit {
public_inputs: self.public_inputs,
private_inputs: self.private_inputs,
aux_inputs: ((counter - 1) - self.public_inputs) - self.private_inputs,
keypair: kp,
witness_map: witness_map
}
}
}

@ -191,13 +191,13 @@ fn keccakf(st: &mut [Byte], rounds: usize)
} }
} }
pub fn sha3_256(message: &[Byte]) -> Vec<Byte> { pub fn sha3_256(message: &[Byte]) -> Vec<Bit> {
// As defined by FIPS202 // As defined by FIPS202
keccak(1088, 512, message, 0x06, 32, 24) keccak(1088, 512, message, 0x06, 32, 24)
} }
fn keccak(rate: usize, capacity: usize, mut input: &[Byte], delimited_suffix: u8, mut mdlen: usize, num_rounds: usize) fn keccak(rate: usize, capacity: usize, mut input: &[Byte], delimited_suffix: u8, mut mdlen: usize, num_rounds: usize)
-> Vec<Byte> -> Vec<Bit>
{ {
use std::cmp::min; use std::cmp::min;
@ -249,11 +249,15 @@ fn keccak(rate: usize, capacity: usize, mut input: &[Byte], delimited_suffix: u8
} }
} }
output output.into_iter().flat_map(|byte| byte.bits.into_iter()).collect()
} }
#[test] #[test]
fn test_sha3_256() { fn test_sha3_256() {
use super::circuit::{CircuitBuilder,Equals};
use super::variable::Var;
use tinysnark::{self,FieldT};
let test_vector: Vec<(Vec<u8>, [u8; 32])> = vec![ let test_vector: Vec<(Vec<u8>, [u8; 32])> = vec![
(vec![0xff], (vec![0xff],
[0x44,0x4b,0x89,0xec,0xce,0x39,0x5a,0xec,0x5d,0xc9,0x8f,0x19,0xde,0xfd,0x3a,0x23,0xbc,0xa0,0x82,0x2f,0xc7,0x22,0x26,0xf5,0x8c,0xa4,0x6a,0x17,0xee,0xec,0xa4,0x42] [0x44,0x4b,0x89,0xec,0xce,0x39,0x5a,0xec,0x5d,0xc9,0x8f,0x19,0xde,0xfd,0x3a,0x23,0xbc,0xa0,0x82,0x2f,0xc7,0x22,0x26,0xf5,0x8c,0xa4,0x6a,0x17,0xee,0xec,0xa4,0x42]
@ -289,7 +293,11 @@ fn test_sha3_256() {
for (i, &(ref message, ref expected)) in test_vector.iter().enumerate() { for (i, &(ref message, ref expected)) in test_vector.iter().enumerate() {
let message: Vec<Byte> = message.iter().map(|a| Byte::new(*a)).collect(); let message: Vec<Byte> = message.iter().map(|a| Byte::new(*a)).collect();
let result: Vec<u8> = sha3_256(&message).into_iter().map(|a| a.unwrap_constant()).collect(); let result: Vec<u8> = sha3_256(&message)
.chunks(8)
.map(|a| Byte::from(a))
.map(|a| a.unwrap_constant())
.collect();
if &*result != expected { if &*result != expected {
print!("Got: "); print!("Got: ");
@ -306,6 +314,44 @@ fn test_sha3_256() {
println!("--- HASH {} SUCCESS ---", i+1); println!("--- HASH {} SUCCESS ---", i+1);
} }
} }
tinysnark::init();
for (i, &(ref message, ref expected)) in test_vector.iter().enumerate() {
fn into_bytes(a: &[Var]) -> Vec<Byte> {
let a: Vec<_> = a.into_iter().map(|a| Bit::new(a)).collect();
a.chunks(8).map(|a| Byte::from(a)).collect()
}
fn into_fieldt(a: &[u8], vars: &mut [FieldT]) {
let mut counter = 0;
for byte in a {
for bit in (0..8).map(|i| byte & (1 << i) != 0).rev() {
if bit { vars[counter] = FieldT::one() } else { vars[counter] = FieldT::zero() }
counter += 1;
}
}
}
let (public, private, mut circuit) = CircuitBuilder::new(expected.len() * 8, message.len() * 8);
let private = into_bytes(&private);
circuit.constrain(sha3_256(&private).must_equal(&public));
let circuit = circuit.finalize();
let mut input: Vec<FieldT> = (0..message.len() * 8).map(|_| FieldT::zero()).collect();
let mut output: Vec<FieldT> = (0..expected.len() * 8).map(|_| FieldT::zero()).collect();
into_fieldt(message, &mut input);
into_fieldt(expected, &mut output);
let proof = circuit.prove(&output, &input).unwrap();
assert!(circuit.verify(&proof, &output));
}
} }
#[derive(Clone)] #[derive(Clone)]

@ -5,42 +5,15 @@ extern crate rand;
use tinysnark::{Proof, Keypair, FieldT, LinearTerm, ConstraintSystem}; use tinysnark::{Proof, Keypair, FieldT, LinearTerm, ConstraintSystem};
use variable::*; use variable::*;
use circuit::*;
use keccak::*; use keccak::*;
use bit::*; use bit::*;
mod variable; mod variable;
mod keccak; mod keccak;
mod bit; mod bit;
mod circuit;
fn main() { fn main() {
tinysnark::init();
let inbytes = 64;
//for inbits in 0..1024 {
let inbits = inbytes * 8;
let input: Vec<Bit> = (0..inbits).map(|i| Bit::new(&Var::new(i+1))).collect();
let input: Vec<Byte> = input.chunks(8).map(|c| Byte::from(c)).collect();
let output = sha3_256(&input);
let mut counter = 1 + (8*input.len());
let mut constraints = vec![];
let mut witness_map = WitnessMap::new();
for o in output.iter().flat_map(|e| e.bits().into_iter()) {
o.walk(&mut counter, &mut constraints, &mut witness_map);
}
let mut vars: Vec<FieldT> = (0..counter).map(|_| FieldT::zero()).collect();
vars[0] = FieldT::one();
witness_field_elements(&mut vars, &witness_map);
for b in output.iter().flat_map(|e| e.bits()) {
print!("{}", if b.val(&vars) { 1 } else { 0 });
}
println!("");
println!("{}: {} constraints", inbits, constraints.len());
//}
} }

@ -2,6 +2,7 @@ use tinysnark::FieldT;
use std::cell::Cell; use std::cell::Cell;
use std::rc::Rc; use std::rc::Rc;
use std::collections::BTreeMap; use std::collections::BTreeMap;
use super::circuit::ConstraintWalker;
pub type WitnessMap = BTreeMap<usize, Vec<(Vec<usize>, Vec<usize>, Rc<Fn(&mut VariableView) + 'static>)>>; pub type WitnessMap = BTreeMap<usize, Vec<(Vec<usize>, Vec<usize>, Rc<Fn(&mut VariableView) + 'static>)>>;
@ -21,6 +22,14 @@ impl<'a> VariableView<'a> {
pub fn get_input(&self, index: usize) -> FieldT { pub fn get_input(&self, index: usize) -> FieldT {
self.vars[self.inputs[index]] self.vars[self.inputs[index]]
} }
/// Sets the value of an input variable. This is unsafe
/// because theoretically this should not be necessary,
/// and could cause soundness problems, but I've temporarily
/// done this to make testing easier.
pub fn set_input(&mut self, index: usize, to: FieldT) {
self.vars[self.inputs[index]] = to;
}
} }
use std::collections::Bound::Unbounded; use std::collections::Bound::Unbounded;
@ -102,9 +111,8 @@ impl Var {
} }
} }
// make this not public or unsafe too pub fn index(&self) -> usize {
pub fn index(&self) -> Rc<Cell<usize>> { self.index.get()
self.index.clone()
} }
pub fn val(&self, map: &[FieldT]) -> FieldT { pub fn val(&self, map: &[FieldT]) -> FieldT {
@ -119,8 +127,10 @@ impl Var {
Some(ref g) => g.group Some(ref g) => g.group
} }
} }
}
pub fn walk(&self, counter: &mut usize, constraints: &mut Vec<Constraint>, witness_map: &mut WitnessMap) { impl ConstraintWalker for Var {
fn walk(&self, counter: &mut usize, constraints: &mut Vec<Constraint>, witness_map: &mut WitnessMap) {
match self.gadget { match self.gadget {
None => {}, None => {},
Some(ref g) => g.walk(counter, constraints, witness_map) Some(ref g) => g.walk(counter, constraints, witness_map)

@ -88,6 +88,15 @@ impl Keypair {
aux_size: constraint_system.aux_size aux_size: constraint_system.aux_size
} }
} }
pub fn is_satisfied(&self, primary: &[FieldT], aux: &[FieldT]) -> bool {
assert_eq!(primary.len(), self.primary_size);
assert_eq!(aux.len(), self.aux_size);
unsafe {
tinysnark_keypair_satisfies_test(self.kp, primary.get_unchecked(0), aux.get_unchecked(0))
}
}
} }
impl Drop for Keypair { impl Drop for Keypair {
@ -99,6 +108,7 @@ impl Drop for Keypair {
extern "C" { extern "C" {
fn tinysnark_gen_keypair(cs: *mut R1ConstraintSystem) -> *mut R1CSKeypair; fn tinysnark_gen_keypair(cs: *mut R1ConstraintSystem) -> *mut R1CSKeypair;
fn tinysnark_drop_keypair(cs: *mut R1CSKeypair); fn tinysnark_drop_keypair(cs: *mut R1CSKeypair);
fn tinysnark_keypair_satisfies_test(kp: *mut R1CSKeypair, primary: *const FieldT, aux: *const FieldT) -> bool;
} }
#[repr(C)] #[repr(C)]

@ -78,6 +78,18 @@ extern "C" void tinysnark_drop_r1cs(void * ics) {
delete cs; delete cs;
} }
extern "C" bool tinysnark_keypair_satisfies_test(void * kp, FieldT* primary, FieldT* aux)
{
r1cs_ppzksnark_keypair<default_r1cs_ppzksnark_pp>* keypair = static_cast<r1cs_ppzksnark_keypair<default_r1cs_ppzksnark_pp>*>(kp);
r1cs_constraint_system<FieldT>* cs = &keypair->pk.constraint_system;
r1cs_primary_input<FieldT> primary_input(primary, primary+(cs->primary_input_size));
r1cs_auxiliary_input<FieldT> aux_input(aux, aux+(cs->auxiliary_input_size));
return cs->is_valid() && cs->is_satisfied(primary_input, aux_input);
}
extern "C" bool tinysnark_satisfy_test(void * ics, FieldT* primary, FieldT* aux) { extern "C" bool tinysnark_satisfy_test(void * ics, FieldT* primary, FieldT* aux) {
r1cs_constraint_system<FieldT>* cs = static_cast<r1cs_constraint_system<FieldT>*>(ics); r1cs_constraint_system<FieldT>* cs = static_cast<r1cs_constraint_system<FieldT>*>(ics);