Reorganize and remove (temporary) unsafe witnessing

This commit is contained in:
Sean Bowe 2016-01-03 03:45:20 -07:00
parent 7415d5ff3c
commit b82a2f60f7
3 changed files with 51 additions and 36 deletions

@ -11,11 +11,11 @@ pub enum Bit {
} }
fn resolve_not(v: &Var) -> Var { fn resolve_not(v: &Var) -> Var {
gadget(&[v], 1, |i, o| { gadget(&[v], 1, |vars| {
if *i[0] == FieldT::zero() { if vars.get_input(0) == FieldT::zero() {
*o[0] = FieldT::one(); vars.set_output(0, FieldT::one());
} else { } else {
*o[0] = FieldT::zero(); vars.set_output(0, FieldT::zero());
} }
}, |i, o, cs| { }, |i, o, cs| {
// (1 - a) * 1 = b // (1 - a) * 1 = b
@ -47,7 +47,7 @@ impl Bit {
} }
pub fn new(v: &Var) -> Bit { pub fn new(v: &Var) -> Bit {
Is(gadget(&[v], 0, |_, _| {}, |i, o, cs| { Is(gadget(&[v], 0, |_| {}, |i, o, cs| {
cs.push(Constraint); cs.push(Constraint);
vec![i[0]] vec![i[0]]
@ -74,11 +74,11 @@ impl Bit {
} }
}, },
(&Is(ref a), &Is(ref b)) => { (&Is(ref a), &Is(ref b)) => {
Is(gadget(&[a, b], 1, |i, o| { Is(gadget(&[a, b], 1, |vars| {
if *i[0] != *i[1] { if vars.get_input(0) != vars.get_input(1) {
*o[0] = FieldT::one(); vars.set_output(0, FieldT::one());
} else { } else {
*o[0] = FieldT::zero(); vars.set_output(0, FieldT::zero());
} }
}, |i, o, cs| { }, |i, o, cs| {
// (2*b) * c = b+c - a // (2*b) * c = b+c - a
@ -119,11 +119,11 @@ impl Bit {
} }
}, },
(&Is(ref a), &Is(ref b)) => { (&Is(ref a), &Is(ref b)) => {
Is(gadget(&[a, b], 1, |i, o| { Is(gadget(&[a, b], 1, |vars| {
if *i[0] == FieldT::one() && *i[1] == FieldT::one() { if vars.get_input(0) == FieldT::one() && vars.get_input(1) == FieldT::one() {
*o[0] = FieldT::one(); vars.set_output(0, FieldT::one());
} else { } else {
*o[0] = FieldT::zero(); vars.set_output(0, FieldT::zero());
} }
}, |i, o, cs| { }, |i, o, cs| {
// a * b = c // a * b = c
@ -142,11 +142,11 @@ impl Bit {
}, },
(&Not(ref n), &Is(ref i)) | (&Is(ref i), &Not(ref n)) => { (&Not(ref n), &Is(ref i)) | (&Is(ref i), &Not(ref n)) => {
//Is(i.clone()).and(&Is(resolve_not(n))) //Is(i.clone()).and(&Is(resolve_not(n)))
Is(gadget(&[n, i], 1, |i, o| { Is(gadget(&[n, i], 1, |vars| {
if *i[0] == FieldT::zero() && *i[1] == FieldT::one() { if vars.get_input(0) == FieldT::zero() && vars.get_input(1) == FieldT::one() {
*o[0] = FieldT::one(); vars.set_output(0, FieldT::one());
} else { } else {
*o[0] = FieldT::zero(); vars.set_output(0, FieldT::zero());
} }
}, |i, o, cs| { }, |i, o, cs| {
// (1-a) * b = c // (1-a) * b = c
@ -157,11 +157,11 @@ impl Bit {
}, },
(&Not(ref a), &Not(ref b)) => { (&Not(ref a), &Not(ref b)) => {
//Is(resolve_not(a)).and(&Is(resolve_not(b))) //Is(resolve_not(a)).and(&Is(resolve_not(b)))
Is(gadget(&[a, b], 1, |i, o| { Is(gadget(&[a, b], 1, |vars| {
if *i[0] == FieldT::zero() && *i[1] == FieldT::zero() { if vars.get_input(0) == FieldT::zero() && vars.get_input(1) == FieldT::zero() {
*o[0] = FieldT::one(); vars.set_output(0, FieldT::one());
} else { } else {
*o[0] = FieldT::zero(); vars.set_output(0, FieldT::zero());
} }
}, |i, o, cs| { }, |i, o, cs| {
// (1 - a) * (1 - b) = c // (1 - a) * (1 - b) = c
@ -203,7 +203,7 @@ fn test_binary_op<F: Fn(&Bit, &Bit) -> Bit>(op: F, a_in: i64, b_in: i64, c_out:
f[1] = FieldT::from(a_in); f[1] = FieldT::from(a_in);
f[2] = FieldT::from(b_in); f[2] = FieldT::from(b_in);
satisfy_field_elements(&mut f, &witness_map); witness_field_elements(&mut f, &witness_map);
assert_eq!(f[3], FieldT::from(c_out)); assert_eq!(f[3], FieldT::from(c_out));
} }

@ -15,7 +15,7 @@ mod bit;
fn main() { fn main() {
tinysnark::init(); tinysnark::init();
let inbytes = 64; let inbytes = 1;
//for inbits in 0..1024 { //for inbits in 0..1024 {
let inbits = inbytes * 8; let inbits = inbytes * 8;
let input: Vec<Bit> = (0..inbits).map(|i| Bit::new(&Var::new(i+1))).collect(); let input: Vec<Bit> = (0..inbits).map(|i| Bit::new(&Var::new(i+1))).collect();
@ -34,7 +34,7 @@ fn main() {
let mut vars: Vec<FieldT> = (0..counter).map(|_| FieldT::zero()).collect(); let mut vars: Vec<FieldT> = (0..counter).map(|_| FieldT::zero()).collect();
vars[0] = FieldT::one(); vars[0] = FieldT::one();
satisfy_field_elements(&mut vars, &witness_map); witness_field_elements(&mut vars, &witness_map);
for b in output.iter().flat_map(|e| e.bits()) { for b in output.iter().flat_map(|e| e.bits()) {
print!("{}", if b.val(&vars) { 1 } else { 0 }); print!("{}", if b.val(&vars) { 1 } else { 0 });

@ -4,23 +4,38 @@ use std::rc::Rc;
use std::fmt; use std::fmt;
use std::collections::BTreeMap; use std::collections::BTreeMap;
pub type WitnessMap = BTreeMap<usize, Vec<(Vec<usize>, Vec<usize>, Rc<Fn(&[&FieldT], &mut [&mut FieldT]) + 'static>)>>; pub type WitnessMap = BTreeMap<usize, Vec<(Vec<usize>, Vec<usize>, Rc<Fn(&mut VariableView) + 'static>)>>;
struct VariableView<'a> {
vars: &'a mut [FieldT],
inputs: &'a [usize],
outputs: &'a [usize]
}
impl<'a> VariableView<'a> {
/// Sets an output variable at `index` to value `to`.
pub fn set_output(&mut self, index: usize, to: FieldT) {
self.vars[self.outputs[index]] = to;
}
/// Gets the value of an input variable at `index`.
pub fn get_input(&self, index: usize) -> FieldT {
self.vars[self.inputs[index]]
}
}
use std::collections::Bound::Unbounded; use std::collections::Bound::Unbounded;
pub fn satisfy_field_elements(vars: &mut [FieldT], witness_map: &WitnessMap) { pub fn witness_field_elements(vars: &mut [FieldT], witness_map: &WitnessMap) {
for (n, group) in witness_map.range(Unbounded, Unbounded) { for (n, group) in witness_map.range(Unbounded, Unbounded) {
for &(ref i, ref o, ref f) in group.iter() { for &(ref i, ref o, ref f) in group.iter() {
let i: Vec<&FieldT> = i.iter().map(|i| &vars[*i]).collect(); let mut vars = VariableView {
let o: Vec<&FieldT> = o.iter().map(|o| &vars[*o]).collect(); vars: vars,
inputs: &*i,
let mut o: Vec<&mut FieldT> = unsafe { outputs: &*o
use std::mem::transmute;
transmute(o)
}; };
f(&i, &mut o); f(&mut vars);
} }
} }
} }
@ -31,7 +46,7 @@ pub struct Constraint;
struct Gadget { struct Gadget {
inputs: Vec<Var>, inputs: Vec<Var>,
aux: Vec<Var>, aux: Vec<Var>,
witness: Rc<Fn(&[&FieldT], &mut [&mut FieldT]) + 'static>, witness: Rc<Fn(&mut VariableView) + 'static>,
constraints: Vec<Constraint>, constraints: Vec<Constraint>,
group: usize, group: usize,
visited: Cell<bool> visited: Cell<bool>
@ -109,7 +124,7 @@ pub fn gadget<W, C>(
constrain: C constrain: C
) -> Vec<Var> ) -> Vec<Var>
where C: for<'a> Fn(&[&'a Var], &[&'a Var], &mut Vec<Constraint>) -> Vec<&'a Var>, where C: for<'a> Fn(&[&'a Var], &[&'a Var], &mut Vec<Constraint>) -> Vec<&'a Var>,
W: Fn(&[&FieldT], &mut [&mut FieldT]) + 'static W: Fn(&mut VariableView) + 'static
{ {
let this_group = inputs.iter().map(|i| i.group()).max().map(|a| a+1).unwrap_or(0); let this_group = inputs.iter().map(|i| i.group()).max().map(|a| a+1).unwrap_or(0);