diff --git a/src/fec.rs b/src/fec.rs index f1a9e7cb1ebfaa80d0c81194f34e7dc989b50392..a1fd2da61b99e39138be447af095367b4bdfa34c 100644 --- a/src/fec.rs +++ b/src/fec.rs @@ -209,6 +209,9 @@ mod tests { use super::recode_with_coeffs; + type LC = Vec<usize>; + type LCExclusion = Vec<usize>; + fn bytes() -> Vec<u8> { include_bytes!("../tests/dragoon_32x32.png").to_vec() } @@ -240,16 +243,20 @@ mod tests { &[3, 6, 8, 9, 10], &[vec![2, 4, 6], vec![1, 3, 7], vec![6, 7, 8], vec![3, 6, 8]], )); + assert!(contains_one_of(&[0, 4, 5], &[vec![2, 3, 5], vec![4, 5]])); } fn try_all_decoding_combinations<F: PrimeField>( data: &[u8], shards: &[Shard<F>], k: usize, + n: usize, test_case: &str, limit: Option<usize>, - should_not_be_decodable: Vec<Vec<usize>>, + should_not_be_decodable: Vec<LCExclusion>, ) { + let there_are_recoded_shards = shards.len() > n; + for c in shards .iter() .cloned() @@ -257,27 +264,52 @@ mod tests { .combinations(k) .take(limit.unwrap_or(usize::MAX)) { - let s = c.iter().map(|(_, s)| s).cloned().collect(); - let is: Vec<usize> = c.iter().map(|(i, _)| i).cloned().collect(); + let is: Vec<usize> = c.iter().map(|(i, _)| *i).collect(); + if there_are_recoded_shards { + let contains_recoded_shards = *is.iter().max().unwrap() < n; + if contains_recoded_shards { + continue; + } + } - let actual = decode::<F>(s); + let pretty_is = is + .iter() + .map(|&i| { + #[allow(clippy::comparison_chain)] + if i == n { + "(n)".into() + } else if i > n { + format!("(n + {})", i - n) + } else { + format!("{}", i) + } + }) + .collect::<Vec<_>>() + .join(", "); + let pretty_is = format!("[{pretty_is}]"); + + let actual = decode::<F>(c.iter().map(|(_, s)| s).cloned().collect()); if contains_one_of(&is, &should_not_be_decodable) { assert!( actual.is_err(), - "should not decode with {:?} {test_case}", - is + "should not decode with {} {test_case}", + pretty_is ); continue; } - assert!(actual.is_ok(), "could not decode with {:?} {test_case}", is); + assert!( + actual.is_ok(), + "could not decode with {} {test_case}", + pretty_is + ); assert_eq!( data, actual.unwrap(), - "bad decoded data with {:?} {test_case}", - is, + "bad decoded data with {} {test_case}", + pretty_is, ); } } @@ -289,37 +321,29 @@ mod tests { let shards = encode::<F>(data, &Matrix::random(k, n, &mut rng)) .unwrap_or_else(|_| panic!("could not encode {test_case}")); - try_all_decoding_combinations(data, &shards, k, &test_case, None, vec![]); + try_all_decoding_combinations(data, &shards, k, n, &test_case, None, vec![]); } - fn end_to_end_with_recoding_template<F: PrimeField>(data: &[u8], k: usize, n: usize) { - assert!(n >= 5, "n should be at least 5, found {}", n); - + fn end_to_end_with_recoding_template<F: PrimeField>( + data: &[u8], + k: usize, + n: usize, + recoding_steps: Vec<LC>, + should_not_be_decodable: Vec<LCExclusion>, + name: &str, + ) { let mut rng = ark_std::test_rng(); - let test_case = format!("TEST | data: {} bytes, k: {}, n: {}", data.len(), k, n); + let test_case = format!( + "TEST | data: {} bytes, k: {}, n: {}, scenario: {}", + data.len(), + k, + n, + name + ); let mut shards = encode::<F>(data, &Matrix::random(k, n, &mut rng)) .unwrap_or_else(|_| panic!("could not encode {test_case}")); - let recoding_steps = [ - vec![2, 4], // = n - vec![1, 3], // = (n + 1) - vec![n, (n + 1)], // = (n + 2) = ((2, 4), (1, 3)) - vec![0], // = (n + 3) = (0) - vec![(n + 3)], // = (n + 4) = (0) - ]; - let should_not_be_decodable = vec![ - vec![2, 4, n], - vec![1, 3, (n + 1)], - vec![n, (n + 1), (n + 2)], - vec![1, 3, n, (n + 2)], - vec![2, 4, (n + 1), (n + 2)], - vec![1, 2, 3, 4, (n + 2)], - vec![0, (n + 3)], - vec![0, (n + 4)], - vec![(n + 3), (n + 4)], - ]; - for step in recoding_steps { let shards_to_recode: Vec<_> = shards .iter() @@ -330,18 +354,26 @@ mod tests { shards.push(recode_random(&shards_to_recode, &mut rng).unwrap().unwrap()); } - try_all_decoding_combinations(data, &shards, k, &test_case, None, should_not_be_decodable); + try_all_decoding_combinations( + data, + &shards, + k, + n, + &test_case, + None, + should_not_be_decodable, + ); } #[test] fn end_to_end() { let bytes = bytes(); - for k in [3, 5] { - for rho in [0.5, 0.33] { - let n = (k as f64 / rho) as usize; - end_to_end_template::<Fr>(&bytes, k, n); - } + let ks = [3, 5]; + let n = 5; + + for k in ks { + end_to_end_template::<Fr>(&bytes, k, n); } } @@ -349,10 +381,62 @@ mod tests { fn end_to_end_with_recoding() { let bytes = bytes(); - for k in [3, 5] { - for rho in [0.50, 0.33] { - let n = (k as f64 / rho) as usize; - end_to_end_with_recoding_template::<Fr>(&bytes, k, n); + fn get_scenarii(n: usize) -> Vec<(String, Vec<LC>, Vec<LCExclusion>)> { + vec![ + // ```mermaid + // graph TD; + // a[n+1]; b[n+2]; c[n+3]; + // + // 1; + // 3-->a; 5-->a; + // 2-->b; 4-->b; + // a-->c; b-->c; + // ``` + ( + "simple".into(), + vec![ + vec![2, 4], // = n + vec![1, 3], // = (n + 1) + vec![n, (n + 1)], // = (n + 2) = ((2, 4), (1, 3)) + ], + vec![ + vec![2, 4, n], + // + vec![1, 3, (n + 1)], + vec![n, (n + 1), (n + 2)], + vec![1, 3, n, (n + 2)], + vec![2, 4, (n + 1), (n + 2)], + vec![1, 2, 3, 4, (n + 2)], + ], + ), + // ```mermaid + // graph TD; + // a[n+1]; b[n+2]; + // + // 1-->a; a-->b; + // 2; 3; 4; 5; + // ``` + ( + "chain".into(), + vec![ + vec![0], // = (n) = (0) + vec![(n)], // = (n + 1) = (0) + ], + vec![vec![0, (n)], vec![0, (n + 1)], vec![(n), (n + 1)]], + ), + ] + } + + for (k, n) in [(3, 5), (5, 5), (8, 10)] { + for (name, steps, should_not_decode) in get_scenarii(n) { + end_to_end_with_recoding_template::<Fr>( + &bytes, + k, + n, + steps, + should_not_decode, + &name, + ); } } }