diff --git a/src/keccak.rs b/src/keccak.rs index 8a30466..0574809 100644 --- a/src/keccak.rs +++ b/src/keccak.rs @@ -198,7 +198,7 @@ fn keccak(rate: usize, capacity: usize, mut input: &[Byte], delimited_suffix: u8 { use std::cmp::min; - let mut st: Vec = Some(Bit::byte(0)).into_iter().cycle().take(200).collect(); + let mut st: Vec = Some(Byte::new(0)).into_iter().cycle().take(200).collect(); let rate_in_bytes = rate / 8; let mut input_byte_len = input.len(); @@ -224,13 +224,13 @@ fn keccak(rate: usize, capacity: usize, mut input: &[Byte], delimited_suffix: u8 } } - st[block_size] = st[block_size].xor(&Bit::byte(delimited_suffix)); + st[block_size] = st[block_size].xor(&Byte::new(delimited_suffix)); if ((delimited_suffix & 0x80) != 0) && (block_size == (rate_in_bytes-1)) { keccakf(&mut st, num_rounds); } - st[rate_in_bytes-1] = st[rate_in_bytes-1].xor(&Bit::byte(0x80)); + st[rate_in_bytes-1] = st[rate_in_bytes-1].xor(&Byte::new(0x80)); keccakf(&mut st, num_rounds); @@ -260,6 +260,14 @@ struct Byte { } impl Byte { + fn new(byte: u8) -> Byte { + Byte { + bits: (0..8).map(|i| Bit::constant(byte & (1 << i) != 0)) + .rev() + .collect() + } + } + fn unwrap_constant(&self) -> u8 { let mut cur = 7; let mut acc = 0; @@ -289,14 +297,6 @@ impl Byte { } impl Bit { - fn byte(byte: u8) -> Byte { - Byte { - bits: (0..8).map(|i| Bit::constant(byte & (1 << i) != 0)) - .rev() - .collect() - } - } - fn constant(num: bool) -> Bit { Bit::Constant(num) } @@ -324,32 +324,33 @@ impl Bit { #[test] fn test_sha3_256() { - let test_vector: Vec<(Vec, [u8; 32])> = vec![ - (vec![Bit::byte(0x30), Bit::byte(0x31), Bit::byte(0x30), Bit::byte(0x31)], + let test_vector: Vec<(Vec, [u8; 32])> = vec![ + (vec![0x30, 0x31, 0x30, 0x31], [0xe5,0xbf,0x4a,0xd7,0xda,0x2b,0x4d,0x64,0x0d,0x2b,0x8d,0xd3,0xae,0x9b,0x6e,0x71,0xb3,0x6e,0x0f,0x3d,0xb7,0x6a,0x1e,0xc0,0xad,0x6b,0x87,0x2f,0x3e,0xcc,0x2e,0xbc] ), - (vec![Bit::byte(0x30)], + (vec![0x30], [0xf9,0xe2,0xea,0xaa,0x42,0xd9,0xfe,0x9e,0x55,0x8a,0x9b,0x8e,0xf1,0xbf,0x36,0x6f,0x19,0x0a,0xac,0xaa,0x83,0xba,0xd2,0x64,0x1e,0xe1,0x06,0xe9,0x04,0x10,0x96,0xe4] ), - (vec![Bit::byte(0x30),Bit::byte(0x30)], + (vec![0x30,0x30], [0x2e,0x16,0xaa,0xb4,0x83,0xcb,0x95,0x57,0x7c,0x50,0xd3,0x8c,0x8d,0x0d,0x70,0x40,0xf4,0x67,0x26,0x83,0x23,0x84,0x46,0xc9,0x90,0xba,0xbb,0xca,0x5a,0xe1,0x33,0xc8] ), - ((0..64).map(|_| Bit::byte(0x30)).collect::>(), + ((0..64).map(|_| 0x30).collect::>(), [0xc6,0xfd,0xd7,0xa7,0xf7,0x08,0x62,0xb3,0x6a,0x26,0xcc,0xd1,0x47,0x52,0x26,0x80,0x61,0xe9,0x81,0x03,0x29,0x9b,0x28,0xfe,0x77,0x63,0xbd,0x96,0x29,0x92,0x6f,0x4b] ), - ((0..128).map(|_| Bit::byte(0x30)).collect::>(), + ((0..128).map(|_| 0x30).collect::>(), [0x99,0x9d,0xb4,0xd4,0x28,0x7b,0x52,0x15,0x20,0x8d,0x11,0xe4,0x0a,0x27,0xca,0x54,0xac,0xa0,0x09,0xb2,0x5c,0x4f,0x7a,0xb9,0x1a,0xd8,0xaa,0x93,0x60,0xf0,0x63,0x71] ), - ((0..256).map(|_| Bit::byte(0x30)).collect::>(), + ((0..256).map(|_| 0x30).collect::>(), [0x11,0xea,0x74,0x37,0x7b,0x74,0xf1,0x53,0x9f,0x2e,0xd9,0x0a,0xb8,0xca,0x9e,0xb1,0xe0,0x70,0x8a,0x4b,0xfb,0xad,0x4e,0x81,0xcc,0x77,0xd9,0xa1,0x61,0x9a,0x10,0xdb] ), - ((0..512).map(|_| Bit::byte(0x30)).collect::>(), + ((0..512).map(|_| 0x30).collect::>(), [0x1c,0x80,0x1b,0x16,0x3a,0x2a,0xbe,0xd0,0xe8,0x07,0x1e,0x7f,0xf2,0x60,0x4e,0x98,0x11,0x22,0x80,0x54,0x14,0xf3,0xc8,0xfd,0x96,0x59,0x5d,0x7e,0xe1,0xd6,0x54,0xe2] ), ]; for (i, &(ref message, ref expected)) in test_vector.iter().enumerate() { - let result: Vec = sha3_256(message).into_iter().map(|a| a.unwrap_constant()).collect(); + let message: Vec = message.iter().map(|a| Byte::new(*a)).collect(); + let result: Vec = sha3_256(&message).into_iter().map(|a| a.unwrap_constant()).collect(); if &*result != expected { print!("Got: ");