diff --git a/Cargo.toml b/Cargo.toml index 17ad977fcb932acc52823c713fa7291b7cb02d6f..404f080bf421ce0213e693d17cf42e3b8280519d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,3 +26,6 @@ members = [ "bins/inbreeding", "bins/rank", ] + +[dev-dependencies] +itertools = "0.13.0" diff --git a/src/fec.rs b/src/fec.rs index 9284d326093f961c5babe9ae37bd87feb4386c50..f1a9e7cb1ebfaa80d0c81194f34e7dc989b50392 100644 --- a/src/fec.rs +++ b/src/fec.rs @@ -200,11 +200,13 @@ mod tests { use ark_ff::PrimeField; use crate::{ - fec::{decode, encode, Shard}, + fec::{decode, encode, recode_random, Shard}, field, linalg::Matrix, }; + use itertools::Itertools; + use super::recode_with_coeffs; fn bytes() -> Vec<u8> { @@ -215,59 +217,144 @@ mod tests { F::from_le_bytes_mod_order(&n.to_le_bytes()) } - fn end_to_end_template<F: PrimeField>(data: &[u8], k: usize, n: usize) { - let mut rng = ark_std::test_rng(); + /// `contains_one_of(x, set)` is true iif `x` fully contains one of the lists from `set` + /// + /// > **Note** + /// > see [`containment`] for some example + fn contains_one_of(x: &[usize], set: &[Vec<usize>]) -> bool { + set.iter().any(|y| y.iter().all(|z| x.contains(z))) + } - let test_case = format!("TEST | data: {} bytes, k: {}, n: {}", data.len(), k, n); - assert_eq!( - data, - decode::<F>(encode(data, &Matrix::random(k, n, &mut rng)).unwrap()).unwrap(), - "{test_case}" - ); + #[test] + fn containment() { + assert!(contains_one_of(&[1, 2, 3], &[vec![1, 2, 3]])); + assert!(contains_one_of( + &[3, 6, 8], + &[vec![2, 4, 6], vec![1, 3, 7], vec![6, 7, 8], vec![3, 6, 8]], + )); + assert!(!contains_one_of( + &[1, 6, 8], + &[vec![2, 4, 6], vec![1, 3, 7], vec![6, 7, 8], vec![3, 6, 8]], + )); + assert!(contains_one_of( + &[3, 6, 8, 9, 10], + &[vec![2, 4, 6], vec![1, 3, 7], vec![6, 7, 8], vec![3, 6, 8]], + )); } - /// k should be at least 5 - fn end_to_end_with_recoding_template<F: PrimeField>(data: &[u8], k: usize, n: usize) { + fn try_all_decoding_combinations<F: PrimeField>( + data: &[u8], + shards: &[Shard<F>], + k: usize, + test_case: &str, + limit: Option<usize>, + should_not_be_decodable: Vec<Vec<usize>>, + ) { + for c in shards + .iter() + .cloned() + .enumerate() + .combinations(k) + .take(limit.unwrap_or(usize::MAX)) + { + let s = c.iter().map(|(_, s)| s).cloned().collect(); + let is: Vec<usize> = c.iter().map(|(i, _)| i).cloned().collect(); + + let actual = decode::<F>(s); + + if contains_one_of(&is, &should_not_be_decodable) { + assert!( + actual.is_err(), + "should not decode with {:?} {test_case}", + is + ); + continue; + } + + assert!(actual.is_ok(), "could not decode with {:?} {test_case}", is); + + assert_eq!( + data, + actual.unwrap(), + "bad decoded data with {:?} {test_case}", + is, + ); + } + } + + fn end_to_end_template<F: PrimeField>(data: &[u8], k: usize, n: usize) { let mut rng = ark_std::test_rng(); + let test_case = format!("TEST | data: {} bytes, k: {}, n: {}", data.len(), k, n); - let mut shards = encode(data, &Matrix::random(k, n, &mut rng)).unwrap(); - shards[1] = shards[2].recode_with(to_curve(7), &shards[4], to_curve(6)); - shards[2] = shards[1].recode_with(to_curve(5), &shards[3], to_curve(4)); - assert_eq!( - data, - decode::<F>(shards).unwrap(), - "TEST | data: {} bytes, k: {}, n: {}", - data.len(), - k, - n - ); + let shards = encode::<F>(data, &Matrix::random(k, n, &mut rng)) + .unwrap_or_else(|_| panic!("could not encode {test_case}")); + + try_all_decoding_combinations(data, &shards, k, &test_case, None, vec![]); } - // NOTE: this is part of an experiment, to be honest, to be able to see how - // much these tests could be refactored and simplified - fn run_template<F, Fun>(test: Fun) - where - F: PrimeField, - Fun: Fn(&[u8], usize, usize), - { - let bytes = bytes(); - let (k, n) = (3, 5); + fn end_to_end_with_recoding_template<F: PrimeField>(data: &[u8], k: usize, n: usize) { + assert!(n >= 5, "n should be at least 5, found {}", n); - let modulus_byte_size = F::MODULUS_BIT_SIZE as usize / 8; - // NOTE: starting at `modulus_byte_size * (k - 1) + 1` to include at least _k_ elements - for b in (modulus_byte_size * (k - 1) + 1)..bytes.len() { - test(&bytes[..b], k, n); + let mut rng = ark_std::test_rng(); + let test_case = format!("TEST | data: {} bytes, k: {}, n: {}", data.len(), k, n); + + let mut shards = encode::<F>(data, &Matrix::random(k, n, &mut rng)) + .unwrap_or_else(|_| panic!("could not encode {test_case}")); + + let recoding_steps = [ + vec![2, 4], // = n + vec![1, 3], // = (n + 1) + vec![n, (n + 1)], // = (n + 2) = ((2, 4), (1, 3)) + vec![0], // = (n + 3) = (0) + vec![(n + 3)], // = (n + 4) = (0) + ]; + let should_not_be_decodable = vec![ + vec![2, 4, n], + vec![1, 3, (n + 1)], + vec![n, (n + 1), (n + 2)], + vec![1, 3, n, (n + 2)], + vec![2, 4, (n + 1), (n + 2)], + vec![1, 2, 3, 4, (n + 2)], + vec![0, (n + 3)], + vec![0, (n + 4)], + vec![(n + 3), (n + 4)], + ]; + + for step in recoding_steps { + let shards_to_recode: Vec<_> = shards + .iter() + .cloned() + .enumerate() + .filter_map(|(i, s)| if step.contains(&i) { Some(s) } else { None }) + .collect(); + shards.push(recode_random(&shards_to_recode, &mut rng).unwrap().unwrap()); } + + try_all_decoding_combinations(data, &shards, k, &test_case, None, should_not_be_decodable); } #[test] fn end_to_end() { - run_template::<Fr, _>(end_to_end_template::<Fr>); + let bytes = bytes(); + + for k in [3, 5] { + for rho in [0.5, 0.33] { + let n = (k as f64 / rho) as usize; + end_to_end_template::<Fr>(&bytes, k, n); + } + } } #[test] fn end_to_end_with_recoding() { - run_template::<Fr, _>(end_to_end_with_recoding_template::<Fr>); + let bytes = bytes(); + + for k in [3, 5] { + for rho in [0.50, 0.33] { + let n = (k as f64 / rho) as usize; + end_to_end_with_recoding_template::<Fr>(&bytes, k, n); + } + } } fn create_fake_shard<F: PrimeField>(linear_combination: &[F], bytes: &[u8]) -> Shard<F> { diff --git a/src/linalg.rs b/src/linalg.rs index 86deb4046278e619916c8b7563105260385efba1..12ece96d08c926a3589452a706a7ccfa910758c5 100644 --- a/src/linalg.rs +++ b/src/linalg.rs @@ -666,6 +666,20 @@ mod tests { assert_eq!(product, expected); } + #[test] + fn random() { + let mut rng = ark_std::test_rng(); + + for n in 0..10 { + for m in 0..10 { + let mat = Matrix::<Fr>::random(n, m, &mut rng); + assert_eq!(mat.elements.len(), n * m); + assert_eq!(mat.width, m); + assert_eq!(mat.height, n); + } + } + } + #[test] fn inverse() { let mut rng = ark_std::test_rng(); @@ -679,11 +693,12 @@ mod tests { assert_eq!(matrix.mul(&inverse).unwrap(), Matrix::<Fr>::identity(3)); assert_eq!(inverse.mul(&matrix).unwrap(), Matrix::<Fr>::identity(3)); - let n = 20; - let matrix = Matrix::random(n, n, &mut rng); - let inverse = matrix.invert().unwrap(); - assert_eq!(matrix.mul(&inverse).unwrap(), Matrix::<Fr>::identity(n)); - assert_eq!(inverse.mul(&matrix).unwrap(), Matrix::<Fr>::identity(n)); + for n in 1..20 { + let matrix = Matrix::random(n, n, &mut rng); + let inverse = matrix.invert().unwrap(); + assert_eq!(matrix.mul(&inverse).unwrap(), Matrix::<Fr>::identity(n)); + assert_eq!(inverse.mul(&matrix).unwrap(), Matrix::<Fr>::identity(n)); + } let inverse = Matrix::<Fr>::from_vec_vec(mat_to_elements(vec![vec![1, 0, 0], vec![0, 1, 0]]))