Reorganize and remove (temporary) unsafe witnessing
This commit is contained in:
parent
7415d5ff3c
commit
b82a2f60f7
44
src/bit.rs
44
src/bit.rs
@ -11,11 +11,11 @@ pub enum Bit {
|
||||
}
|
||||
|
||||
fn resolve_not(v: &Var) -> Var {
|
||||
gadget(&[v], 1, |i, o| {
|
||||
if *i[0] == FieldT::zero() {
|
||||
*o[0] = FieldT::one();
|
||||
gadget(&[v], 1, |vars| {
|
||||
if vars.get_input(0) == FieldT::zero() {
|
||||
vars.set_output(0, FieldT::one());
|
||||
} else {
|
||||
*o[0] = FieldT::zero();
|
||||
vars.set_output(0, FieldT::zero());
|
||||
}
|
||||
}, |i, o, cs| {
|
||||
// (1 - a) * 1 = b
|
||||
@ -47,7 +47,7 @@ impl Bit {
|
||||
}
|
||||
|
||||
pub fn new(v: &Var) -> Bit {
|
||||
Is(gadget(&[v], 0, |_, _| {}, |i, o, cs| {
|
||||
Is(gadget(&[v], 0, |_| {}, |i, o, cs| {
|
||||
cs.push(Constraint);
|
||||
|
||||
vec![i[0]]
|
||||
@ -74,11 +74,11 @@ impl Bit {
|
||||
}
|
||||
},
|
||||
(&Is(ref a), &Is(ref b)) => {
|
||||
Is(gadget(&[a, b], 1, |i, o| {
|
||||
if *i[0] != *i[1] {
|
||||
*o[0] = FieldT::one();
|
||||
Is(gadget(&[a, b], 1, |vars| {
|
||||
if vars.get_input(0) != vars.get_input(1) {
|
||||
vars.set_output(0, FieldT::one());
|
||||
} else {
|
||||
*o[0] = FieldT::zero();
|
||||
vars.set_output(0, FieldT::zero());
|
||||
}
|
||||
}, |i, o, cs| {
|
||||
// (2*b) * c = b+c - a
|
||||
@ -119,11 +119,11 @@ impl Bit {
|
||||
}
|
||||
},
|
||||
(&Is(ref a), &Is(ref b)) => {
|
||||
Is(gadget(&[a, b], 1, |i, o| {
|
||||
if *i[0] == FieldT::one() && *i[1] == FieldT::one() {
|
||||
*o[0] = FieldT::one();
|
||||
Is(gadget(&[a, b], 1, |vars| {
|
||||
if vars.get_input(0) == FieldT::one() && vars.get_input(1) == FieldT::one() {
|
||||
vars.set_output(0, FieldT::one());
|
||||
} else {
|
||||
*o[0] = FieldT::zero();
|
||||
vars.set_output(0, FieldT::zero());
|
||||
}
|
||||
}, |i, o, cs| {
|
||||
// a * b = c
|
||||
@ -142,11 +142,11 @@ impl Bit {
|
||||
},
|
||||
(&Not(ref n), &Is(ref i)) | (&Is(ref i), &Not(ref n)) => {
|
||||
//Is(i.clone()).and(&Is(resolve_not(n)))
|
||||
Is(gadget(&[n, i], 1, |i, o| {
|
||||
if *i[0] == FieldT::zero() && *i[1] == FieldT::one() {
|
||||
*o[0] = FieldT::one();
|
||||
Is(gadget(&[n, i], 1, |vars| {
|
||||
if vars.get_input(0) == FieldT::zero() && vars.get_input(1) == FieldT::one() {
|
||||
vars.set_output(0, FieldT::one());
|
||||
} else {
|
||||
*o[0] = FieldT::zero();
|
||||
vars.set_output(0, FieldT::zero());
|
||||
}
|
||||
}, |i, o, cs| {
|
||||
// (1-a) * b = c
|
||||
@ -157,11 +157,11 @@ impl Bit {
|
||||
},
|
||||
(&Not(ref a), &Not(ref b)) => {
|
||||
//Is(resolve_not(a)).and(&Is(resolve_not(b)))
|
||||
Is(gadget(&[a, b], 1, |i, o| {
|
||||
if *i[0] == FieldT::zero() && *i[1] == FieldT::zero() {
|
||||
*o[0] = FieldT::one();
|
||||
Is(gadget(&[a, b], 1, |vars| {
|
||||
if vars.get_input(0) == FieldT::zero() && vars.get_input(1) == FieldT::zero() {
|
||||
vars.set_output(0, FieldT::one());
|
||||
} else {
|
||||
*o[0] = FieldT::zero();
|
||||
vars.set_output(0, FieldT::zero());
|
||||
}
|
||||
}, |i, o, cs| {
|
||||
// (1 - a) * (1 - b) = c
|
||||
@ -203,7 +203,7 @@ fn test_binary_op<F: Fn(&Bit, &Bit) -> Bit>(op: F, a_in: i64, b_in: i64, c_out:
|
||||
f[1] = FieldT::from(a_in);
|
||||
f[2] = FieldT::from(b_in);
|
||||
|
||||
satisfy_field_elements(&mut f, &witness_map);
|
||||
witness_field_elements(&mut f, &witness_map);
|
||||
|
||||
assert_eq!(f[3], FieldT::from(c_out));
|
||||
}
|
||||
|
@ -15,7 +15,7 @@ mod bit;
|
||||
fn main() {
|
||||
tinysnark::init();
|
||||
|
||||
let inbytes = 64;
|
||||
let inbytes = 1;
|
||||
//for inbits in 0..1024 {
|
||||
let inbits = inbytes * 8;
|
||||
let input: Vec<Bit> = (0..inbits).map(|i| Bit::new(&Var::new(i+1))).collect();
|
||||
@ -34,7 +34,7 @@ fn main() {
|
||||
let mut vars: Vec<FieldT> = (0..counter).map(|_| FieldT::zero()).collect();
|
||||
vars[0] = FieldT::one();
|
||||
|
||||
satisfy_field_elements(&mut vars, &witness_map);
|
||||
witness_field_elements(&mut vars, &witness_map);
|
||||
|
||||
for b in output.iter().flat_map(|e| e.bits()) {
|
||||
print!("{}", if b.val(&vars) { 1 } else { 0 });
|
||||
|
@ -4,23 +4,38 @@ use std::rc::Rc;
|
||||
use std::fmt;
|
||||
use std::collections::BTreeMap;
|
||||
|
||||
pub type WitnessMap = BTreeMap<usize, Vec<(Vec<usize>, Vec<usize>, Rc<Fn(&[&FieldT], &mut [&mut FieldT]) + 'static>)>>;
|
||||
pub type WitnessMap = BTreeMap<usize, Vec<(Vec<usize>, Vec<usize>, Rc<Fn(&mut VariableView) + 'static>)>>;
|
||||
|
||||
struct VariableView<'a> {
|
||||
vars: &'a mut [FieldT],
|
||||
inputs: &'a [usize],
|
||||
outputs: &'a [usize]
|
||||
}
|
||||
|
||||
impl<'a> VariableView<'a> {
|
||||
/// Sets an output variable at `index` to value `to`.
|
||||
pub fn set_output(&mut self, index: usize, to: FieldT) {
|
||||
self.vars[self.outputs[index]] = to;
|
||||
}
|
||||
|
||||
/// Gets the value of an input variable at `index`.
|
||||
pub fn get_input(&self, index: usize) -> FieldT {
|
||||
self.vars[self.inputs[index]]
|
||||
}
|
||||
}
|
||||
|
||||
use std::collections::Bound::Unbounded;
|
||||
|
||||
pub fn satisfy_field_elements(vars: &mut [FieldT], witness_map: &WitnessMap) {
|
||||
pub fn witness_field_elements(vars: &mut [FieldT], witness_map: &WitnessMap) {
|
||||
for (n, group) in witness_map.range(Unbounded, Unbounded) {
|
||||
for &(ref i, ref o, ref f) in group.iter() {
|
||||
let i: Vec<&FieldT> = i.iter().map(|i| &vars[*i]).collect();
|
||||
let o: Vec<&FieldT> = o.iter().map(|o| &vars[*o]).collect();
|
||||
|
||||
let mut o: Vec<&mut FieldT> = unsafe {
|
||||
use std::mem::transmute;
|
||||
|
||||
transmute(o)
|
||||
let mut vars = VariableView {
|
||||
vars: vars,
|
||||
inputs: &*i,
|
||||
outputs: &*o
|
||||
};
|
||||
|
||||
f(&i, &mut o);
|
||||
f(&mut vars);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -31,7 +46,7 @@ pub struct Constraint;
|
||||
struct Gadget {
|
||||
inputs: Vec<Var>,
|
||||
aux: Vec<Var>,
|
||||
witness: Rc<Fn(&[&FieldT], &mut [&mut FieldT]) + 'static>,
|
||||
witness: Rc<Fn(&mut VariableView) + 'static>,
|
||||
constraints: Vec<Constraint>,
|
||||
group: usize,
|
||||
visited: Cell<bool>
|
||||
@ -109,7 +124,7 @@ pub fn gadget<W, C>(
|
||||
constrain: C
|
||||
) -> Vec<Var>
|
||||
where C: for<'a> Fn(&[&'a Var], &[&'a Var], &mut Vec<Constraint>) -> Vec<&'a Var>,
|
||||
W: Fn(&[&FieldT], &mut [&mut FieldT]) + 'static
|
||||
W: Fn(&mut VariableView) + 'static
|
||||
{
|
||||
let this_group = inputs.iter().map(|i| i.group()).max().map(|a| a+1).unwrap_or(0);
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user