Added primitive circuit abstraction, tests for sha3.
This commit is contained in:
parent
67003a471b
commit
e24fcfdc5c
87
src/bit.rs
87
src/bit.rs
@ -5,6 +5,7 @@ use std::cell::RefCell;
|
||||
use super::variable::*;
|
||||
use self::Bit::*;
|
||||
use self::Op::*;
|
||||
use super::circuit::*;
|
||||
|
||||
macro_rules! mirror {
|
||||
($a:pat, $b:pat) => (($a, $b) | ($b, $a))
|
||||
@ -206,6 +207,59 @@ pub enum Bit {
|
||||
Bin(BinaryOp, bool)
|
||||
}
|
||||
|
||||
struct BitEquality {
|
||||
a: Bit,
|
||||
b: Var
|
||||
}
|
||||
|
||||
impl Constrainable for BitEquality {
|
||||
type Result = Var;
|
||||
|
||||
fn synthesize(&self, enforce: &Bit) -> Var {
|
||||
// TODO: currently only support unconditional enforcement
|
||||
match enforce {
|
||||
&Bit::Constant(true) => {},
|
||||
_ => unimplemented!()
|
||||
}
|
||||
|
||||
match self.a {
|
||||
Bin(ref binop, inverted) => {
|
||||
// TODO: figure this out later
|
||||
assert!(binop.resolved.borrow().is_none());
|
||||
|
||||
let mut op = binop.op;
|
||||
|
||||
if inverted {
|
||||
op = op.not();
|
||||
}
|
||||
|
||||
gadget(&[&binop.a, &binop.b, &self.b], 0, move |vals| {
|
||||
let a = vals.get_input(0);
|
||||
let b = vals.get_input(1);
|
||||
|
||||
unsafe { vals.set_input(2, op.val(a, b)) };
|
||||
}, |i, o, cs| {
|
||||
cs.push(binaryop_constraint(i[0], i[1], i[2], op));
|
||||
|
||||
vec![i[2]]
|
||||
}).remove(0)
|
||||
},
|
||||
_ => unimplemented!()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Equals<Var> for Bit {
|
||||
type Result = BitEquality;
|
||||
|
||||
fn must_equal(&self, other: &Var) -> BitEquality {
|
||||
BitEquality {
|
||||
a: self.clone(),
|
||||
b: other.clone()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn binaryop_constraint(a: &Var, b: &Var, c: &Var, op: Op) -> Constraint {
|
||||
match op {
|
||||
// a * b = c
|
||||
@ -286,6 +340,24 @@ fn resolve(a: &Var, b: &Var, op: Op) -> Var {
|
||||
}).remove(0)
|
||||
}
|
||||
|
||||
impl ConstraintWalker for Bit {
|
||||
fn walk(&self, counter: &mut usize, constraints: &mut Vec<Constraint>, witness_map: &mut WitnessMap)
|
||||
{
|
||||
match *self {
|
||||
Constant(_) => {},
|
||||
Not(ref v) => {
|
||||
v.walk(counter, constraints, witness_map);
|
||||
},
|
||||
Is(ref v) => {
|
||||
v.walk(counter, constraints, witness_map);
|
||||
},
|
||||
Bin(ref bin, _) => {
|
||||
bin.walk(counter, constraints, witness_map);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Bit {
|
||||
pub fn val(&self, map: &[FieldT]) -> bool {
|
||||
match *self {
|
||||
@ -304,21 +376,6 @@ impl Bit {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn walk(&self, counter: &mut usize, constraints: &mut Vec<Constraint>, witness_map: &mut WitnessMap) {
|
||||
match *self {
|
||||
Constant(_) => {},
|
||||
Not(ref v) => {
|
||||
v.walk(counter, constraints, witness_map);
|
||||
},
|
||||
Is(ref v) => {
|
||||
v.walk(counter, constraints, witness_map);
|
||||
},
|
||||
Bin(ref bin, _) => {
|
||||
bin.walk(counter, constraints, witness_map);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new(v: &Var) -> Bit {
|
||||
Is(gadget(&[v], 0, |_| {}, |i, o, cs| {
|
||||
// boolean constraint:
|
||||
|
148
src/circuit.rs
Normal file
148
src/circuit.rs
Normal file
@ -0,0 +1,148 @@
|
||||
use tinysnark::{Proof, Keypair, FieldT, LinearTerm, ConstraintSystem};
|
||||
use super::variable::{Var,Constraint,WitnessMap,witness_field_elements};
|
||||
use super::bit::Bit;
|
||||
|
||||
pub trait ConstraintWalker: 'static {
|
||||
fn walk(&self,
|
||||
counter: &mut usize,
|
||||
constraints: &mut Vec<Constraint>,
|
||||
witness_map: &mut WitnessMap);
|
||||
}
|
||||
|
||||
impl<C: ConstraintWalker> ConstraintWalker for Vec<C> {
|
||||
fn walk(&self,
|
||||
counter: &mut usize,
|
||||
constraints: &mut Vec<Constraint>,
|
||||
witness_map: &mut WitnessMap)
|
||||
{
|
||||
for i in self {
|
||||
i.walk(counter, constraints, witness_map);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub trait Constrainable {
|
||||
type Result: ConstraintWalker;
|
||||
|
||||
fn synthesize(&self, enforce: &Bit) -> Self::Result;
|
||||
}
|
||||
|
||||
impl<C: Constrainable> Constrainable for Vec<C> {
|
||||
type Result = Vec<C::Result>;
|
||||
|
||||
fn synthesize(&self, enforce: &Bit) -> Vec<C::Result> {
|
||||
self.iter().map(|a| a.synthesize(enforce)).collect()
|
||||
}
|
||||
}
|
||||
|
||||
pub trait Equals<Rhs: ?Sized> {
|
||||
type Result: Constrainable;
|
||||
|
||||
fn must_equal(&self, other: &Rhs) -> Self::Result;
|
||||
}
|
||||
|
||||
impl<Lhs, Rhs> Equals<[Rhs]> for [Lhs] where Lhs: Equals<Rhs> {
|
||||
type Result = Vec<Lhs::Result>;
|
||||
|
||||
fn must_equal(&self, other: &[Rhs]) -> Vec<Lhs::Result> {
|
||||
assert_eq!(self.len(), other.len());
|
||||
|
||||
self.iter().zip(other.iter()).map(|(a, b)| a.must_equal(b)).collect()
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Circuit {
|
||||
public_inputs: usize,
|
||||
private_inputs: usize,
|
||||
aux_inputs: usize,
|
||||
keypair: Keypair,
|
||||
witness_map: WitnessMap
|
||||
}
|
||||
|
||||
impl Circuit {
|
||||
pub fn verify(&self, proof: &Proof, public: &[FieldT]) -> bool
|
||||
{
|
||||
proof.verify(&self.keypair, public)
|
||||
}
|
||||
|
||||
pub fn prove(&self, public: &[FieldT], private: &[FieldT]) -> Result<Proof, ()>
|
||||
{
|
||||
assert_eq!(public.len(), self.public_inputs);
|
||||
assert_eq!(private.len(), self.private_inputs);
|
||||
|
||||
let mut vars = Vec::new();
|
||||
vars.push(FieldT::one());
|
||||
vars.extend_from_slice(public);
|
||||
vars.extend_from_slice(private);
|
||||
|
||||
for i in 0..self.aux_inputs {
|
||||
vars.push(FieldT::zero());
|
||||
}
|
||||
|
||||
witness_field_elements(&mut vars, &self.witness_map);
|
||||
|
||||
let primary = &vars[1..public.len()+1];
|
||||
let aux = &vars[1+public.len()..];
|
||||
|
||||
if !self.keypair.is_satisfied(primary, aux) {
|
||||
return Err(())
|
||||
}
|
||||
|
||||
Ok(Proof::new(&self.keypair, primary, aux))
|
||||
}
|
||||
}
|
||||
|
||||
pub struct CircuitBuilder {
|
||||
public_inputs: usize,
|
||||
private_inputs: usize,
|
||||
constraints: Vec<Box<ConstraintWalker>>
|
||||
}
|
||||
|
||||
impl CircuitBuilder {
|
||||
pub fn new(num_public: usize, num_private: usize) -> (Vec<Var>, Vec<Var>, CircuitBuilder) {
|
||||
(
|
||||
(0..num_public).map(|x| Var::new(1+x)).collect(),
|
||||
(0..num_private).map(|x| Var::new(1+num_public+x)).collect(),
|
||||
CircuitBuilder {
|
||||
public_inputs: num_public,
|
||||
private_inputs: num_private,
|
||||
constraints: Vec::new()
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
pub fn constrain<C: Constrainable>(&mut self, constraint: C) {
|
||||
self.constraints.push(Box::new(constraint.synthesize(&Bit::constant(true))));
|
||||
}
|
||||
|
||||
pub fn finalize(self) -> Circuit {
|
||||
let mut counter = 1 + self.public_inputs + self.private_inputs;
|
||||
let mut constraints = vec![];
|
||||
let mut witness_map = WitnessMap::new();
|
||||
|
||||
for c in self.constraints.into_iter() {
|
||||
c.walk(&mut counter, &mut constraints, &mut witness_map);
|
||||
}
|
||||
|
||||
let mut cs = ConstraintSystem::new(self.public_inputs, (counter - 1) - self.public_inputs);
|
||||
|
||||
for Constraint(a, b, c) in constraints {
|
||||
let a: Vec<_> = a.into_iter().map(|x| LinearTerm { coeff: x.0, index: x.1.index() }).collect();
|
||||
let b: Vec<_> = b.into_iter().map(|x| LinearTerm { coeff: x.0, index: x.1.index() }).collect();
|
||||
let c: Vec<_> = c.into_iter().map(|x| LinearTerm { coeff: x.0, index: x.1.index() }).collect();
|
||||
|
||||
cs.add_constraint(&a, &b, &c);
|
||||
}
|
||||
|
||||
let kp = Keypair::new(&cs);
|
||||
|
||||
Circuit {
|
||||
public_inputs: self.public_inputs,
|
||||
private_inputs: self.private_inputs,
|
||||
aux_inputs: ((counter - 1) - self.public_inputs) - self.private_inputs,
|
||||
keypair: kp,
|
||||
witness_map: witness_map
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -191,13 +191,13 @@ fn keccakf(st: &mut [Byte], rounds: usize)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn sha3_256(message: &[Byte]) -> Vec<Byte> {
|
||||
pub fn sha3_256(message: &[Byte]) -> Vec<Bit> {
|
||||
// As defined by FIPS202
|
||||
keccak(1088, 512, message, 0x06, 32, 24)
|
||||
}
|
||||
|
||||
fn keccak(rate: usize, capacity: usize, mut input: &[Byte], delimited_suffix: u8, mut mdlen: usize, num_rounds: usize)
|
||||
-> Vec<Byte>
|
||||
-> Vec<Bit>
|
||||
{
|
||||
use std::cmp::min;
|
||||
|
||||
@ -249,11 +249,15 @@ fn keccak(rate: usize, capacity: usize, mut input: &[Byte], delimited_suffix: u8
|
||||
}
|
||||
}
|
||||
|
||||
output
|
||||
output.into_iter().flat_map(|byte| byte.bits.into_iter()).collect()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sha3_256() {
|
||||
use super::circuit::{CircuitBuilder,Equals};
|
||||
use super::variable::Var;
|
||||
use tinysnark::{self,FieldT};
|
||||
|
||||
let test_vector: Vec<(Vec<u8>, [u8; 32])> = vec![
|
||||
(vec![0xff],
|
||||
[0x44,0x4b,0x89,0xec,0xce,0x39,0x5a,0xec,0x5d,0xc9,0x8f,0x19,0xde,0xfd,0x3a,0x23,0xbc,0xa0,0x82,0x2f,0xc7,0x22,0x26,0xf5,0x8c,0xa4,0x6a,0x17,0xee,0xec,0xa4,0x42]
|
||||
@ -289,7 +293,11 @@ fn test_sha3_256() {
|
||||
|
||||
for (i, &(ref message, ref expected)) in test_vector.iter().enumerate() {
|
||||
let message: Vec<Byte> = message.iter().map(|a| Byte::new(*a)).collect();
|
||||
let result: Vec<u8> = sha3_256(&message).into_iter().map(|a| a.unwrap_constant()).collect();
|
||||
let result: Vec<u8> = sha3_256(&message)
|
||||
.chunks(8)
|
||||
.map(|a| Byte::from(a))
|
||||
.map(|a| a.unwrap_constant())
|
||||
.collect();
|
||||
|
||||
if &*result != expected {
|
||||
print!("Got: ");
|
||||
@ -306,6 +314,44 @@ fn test_sha3_256() {
|
||||
println!("--- HASH {} SUCCESS ---", i+1);
|
||||
}
|
||||
}
|
||||
|
||||
tinysnark::init();
|
||||
|
||||
for (i, &(ref message, ref expected)) in test_vector.iter().enumerate() {
|
||||
fn into_bytes(a: &[Var]) -> Vec<Byte> {
|
||||
let a: Vec<_> = a.into_iter().map(|a| Bit::new(a)).collect();
|
||||
|
||||
a.chunks(8).map(|a| Byte::from(a)).collect()
|
||||
}
|
||||
|
||||
fn into_fieldt(a: &[u8], vars: &mut [FieldT]) {
|
||||
let mut counter = 0;
|
||||
|
||||
for byte in a {
|
||||
for bit in (0..8).map(|i| byte & (1 << i) != 0).rev() {
|
||||
if bit { vars[counter] = FieldT::one() } else { vars[counter] = FieldT::zero() }
|
||||
counter += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let (public, private, mut circuit) = CircuitBuilder::new(expected.len() * 8, message.len() * 8);
|
||||
|
||||
let private = into_bytes(&private);
|
||||
|
||||
circuit.constrain(sha3_256(&private).must_equal(&public));
|
||||
|
||||
let circuit = circuit.finalize();
|
||||
|
||||
let mut input: Vec<FieldT> = (0..message.len() * 8).map(|_| FieldT::zero()).collect();
|
||||
let mut output: Vec<FieldT> = (0..expected.len() * 8).map(|_| FieldT::zero()).collect();
|
||||
|
||||
into_fieldt(message, &mut input);
|
||||
into_fieldt(expected, &mut output);
|
||||
|
||||
let proof = circuit.prove(&output, &input).unwrap();
|
||||
assert!(circuit.verify(&proof, &output));
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
|
33
src/main.rs
33
src/main.rs
@ -5,42 +5,15 @@ extern crate rand;
|
||||
|
||||
use tinysnark::{Proof, Keypair, FieldT, LinearTerm, ConstraintSystem};
|
||||
use variable::*;
|
||||
use circuit::*;
|
||||
use keccak::*;
|
||||
use bit::*;
|
||||
|
||||
mod variable;
|
||||
mod keccak;
|
||||
mod bit;
|
||||
mod circuit;
|
||||
|
||||
fn main() {
|
||||
tinysnark::init();
|
||||
|
||||
let inbytes = 64;
|
||||
//for inbits in 0..1024 {
|
||||
let inbits = inbytes * 8;
|
||||
let input: Vec<Bit> = (0..inbits).map(|i| Bit::new(&Var::new(i+1))).collect();
|
||||
let input: Vec<Byte> = input.chunks(8).map(|c| Byte::from(c)).collect();
|
||||
|
||||
let output = sha3_256(&input);
|
||||
|
||||
let mut counter = 1 + (8*input.len());
|
||||
let mut constraints = vec![];
|
||||
let mut witness_map = WitnessMap::new();
|
||||
|
||||
for o in output.iter().flat_map(|e| e.bits().into_iter()) {
|
||||
o.walk(&mut counter, &mut constraints, &mut witness_map);
|
||||
}
|
||||
|
||||
let mut vars: Vec<FieldT> = (0..counter).map(|_| FieldT::zero()).collect();
|
||||
vars[0] = FieldT::one();
|
||||
|
||||
witness_field_elements(&mut vars, &witness_map);
|
||||
|
||||
for b in output.iter().flat_map(|e| e.bits()) {
|
||||
print!("{}", if b.val(&vars) { 1 } else { 0 });
|
||||
}
|
||||
println!("");
|
||||
|
||||
println!("{}: {} constraints", inbits, constraints.len());
|
||||
//}
|
||||
}
|
||||
}
|
@ -2,6 +2,7 @@ use tinysnark::FieldT;
|
||||
use std::cell::Cell;
|
||||
use std::rc::Rc;
|
||||
use std::collections::BTreeMap;
|
||||
use super::circuit::ConstraintWalker;
|
||||
|
||||
pub type WitnessMap = BTreeMap<usize, Vec<(Vec<usize>, Vec<usize>, Rc<Fn(&mut VariableView) + 'static>)>>;
|
||||
|
||||
@ -21,6 +22,14 @@ impl<'a> VariableView<'a> {
|
||||
pub fn get_input(&self, index: usize) -> FieldT {
|
||||
self.vars[self.inputs[index]]
|
||||
}
|
||||
|
||||
/// Sets the value of an input variable. This is unsafe
|
||||
/// because theoretically this should not be necessary,
|
||||
/// and could cause soundness problems, but I've temporarily
|
||||
/// done this to make testing easier.
|
||||
pub fn set_input(&mut self, index: usize, to: FieldT) {
|
||||
self.vars[self.inputs[index]] = to;
|
||||
}
|
||||
}
|
||||
|
||||
use std::collections::Bound::Unbounded;
|
||||
@ -102,9 +111,8 @@ impl Var {
|
||||
}
|
||||
}
|
||||
|
||||
// make this not public or unsafe too
|
||||
pub fn index(&self) -> Rc<Cell<usize>> {
|
||||
self.index.clone()
|
||||
pub fn index(&self) -> usize {
|
||||
self.index.get()
|
||||
}
|
||||
|
||||
pub fn val(&self, map: &[FieldT]) -> FieldT {
|
||||
@ -119,8 +127,10 @@ impl Var {
|
||||
Some(ref g) => g.group
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn walk(&self, counter: &mut usize, constraints: &mut Vec<Constraint>, witness_map: &mut WitnessMap) {
|
||||
impl ConstraintWalker for Var {
|
||||
fn walk(&self, counter: &mut usize, constraints: &mut Vec<Constraint>, witness_map: &mut WitnessMap) {
|
||||
match self.gadget {
|
||||
None => {},
|
||||
Some(ref g) => g.walk(counter, constraints, witness_map)
|
||||
|
@ -88,6 +88,15 @@ impl Keypair {
|
||||
aux_size: constraint_system.aux_size
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_satisfied(&self, primary: &[FieldT], aux: &[FieldT]) -> bool {
|
||||
assert_eq!(primary.len(), self.primary_size);
|
||||
assert_eq!(aux.len(), self.aux_size);
|
||||
|
||||
unsafe {
|
||||
tinysnark_keypair_satisfies_test(self.kp, primary.get_unchecked(0), aux.get_unchecked(0))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for Keypair {
|
||||
@ -99,6 +108,7 @@ impl Drop for Keypair {
|
||||
extern "C" {
|
||||
fn tinysnark_gen_keypair(cs: *mut R1ConstraintSystem) -> *mut R1CSKeypair;
|
||||
fn tinysnark_drop_keypair(cs: *mut R1CSKeypair);
|
||||
fn tinysnark_keypair_satisfies_test(kp: *mut R1CSKeypair, primary: *const FieldT, aux: *const FieldT) -> bool;
|
||||
}
|
||||
|
||||
#[repr(C)]
|
||||
|
@ -78,6 +78,18 @@ extern "C" void tinysnark_drop_r1cs(void * ics) {
|
||||
delete cs;
|
||||
}
|
||||
|
||||
extern "C" bool tinysnark_keypair_satisfies_test(void * kp, FieldT* primary, FieldT* aux)
|
||||
{
|
||||
r1cs_ppzksnark_keypair<default_r1cs_ppzksnark_pp>* keypair = static_cast<r1cs_ppzksnark_keypair<default_r1cs_ppzksnark_pp>*>(kp);
|
||||
|
||||
r1cs_constraint_system<FieldT>* cs = &keypair->pk.constraint_system;
|
||||
|
||||
r1cs_primary_input<FieldT> primary_input(primary, primary+(cs->primary_input_size));
|
||||
r1cs_auxiliary_input<FieldT> aux_input(aux, aux+(cs->auxiliary_input_size));
|
||||
|
||||
return cs->is_valid() && cs->is_satisfied(primary_input, aux_input);
|
||||
}
|
||||
|
||||
extern "C" bool tinysnark_satisfy_test(void * ics, FieldT* primary, FieldT* aux) {
|
||||
r1cs_constraint_system<FieldT>* cs = static_cast<r1cs_constraint_system<FieldT>*>(ics);
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user