diff --git a/Cargo.toml b/Cargo.toml index 5d0d67b6959816e287b56ab9eb3a773606566c91..f3332ea10078cccf78cfd50c4b7113ba6fbbe38e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,7 +13,10 @@ ark-poly = "0.4.2" ark-poly-commit = "0.4.0" ark-serialize = "0.4.2" ark-std = "0.4.0" -reed-solomon-erasure = { git = "https://github.com/jdetchart/reed-solomon-erasure", branch = "master" } rs_merkle = "1.4.1" +thiserror = "1.0.50" tracing = "0.1.40" tracing-subscriber = "0.3.17" + +[dev-dependencies] +rand = "0.8.5" diff --git a/src/fec.rs b/src/fec.rs index 0dddaaba44d0e19026b69401a9a5570d507337b5..fe848f87db2118f5699adc5c0ee92275699e857a 100644 --- a/src/fec.rs +++ b/src/fec.rs @@ -1,11 +1,12 @@ use std::ops::{Add, Mul}; use ark_ec::pairing::Pairing; +use ark_ff::PrimeField; use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; use ark_std::{One, Zero}; -use reed_solomon_erasure::{Error, Field as GF, ReedSolomonNonSystematic}; use crate::field; +use crate::linalg::{LinalgError, Matrix}; #[derive(Debug, Default, Clone, PartialEq, CanonicalSerialize, CanonicalDeserialize)] pub struct LinearCombinationElement<E: Pairing> { @@ -34,7 +35,7 @@ impl<E: Pairing> Shard<E> { .map(|e| e.mul(alpha)) .collect::<Vec<_>>(); - field::merge_elements_into_bytes::<E>(&elements) + field::merge_elements_into_bytes::<E>(&elements, false) }; Shard { @@ -89,39 +90,47 @@ impl<E: Pairing> Shard<E> { k: self.k, linear_combination, hash: self.hash.clone(), - bytes: field::merge_elements_into_bytes::<E>(&elements), + bytes: field::merge_elements_into_bytes::<E>(&elements, false), size: self.size, } } } -pub fn decode<F: GF, E: Pairing>(blocks: Vec<Shard<E>>) -> Result<Vec<u8>, Error> { +pub fn decode<E: Pairing>(blocks: Vec<Shard<E>>) -> Result<Vec<u8>, LinalgError> { let k = blocks[0].k; - let n = blocks - .iter() - // FIXME: this is incorrect - .map(|b| b.linear_combination[0].index) - .max() - .unwrap_or(0) - + 1; if blocks.len() < k as usize { - return Err(Error::TooFewShards); - } - - let mut shards: Vec<Option<Vec<F::Elem>>> = Vec::with_capacity(n as usize); - shards.resize(n as usize, None); - for block in &blocks { - // FIXME: this is incorrect - shards[block.linear_combination[0].index as usize] = Some(F::deserialize(&block.bytes)); + return Err(LinalgError::Other("too few shards".to_string())); } - ReedSolomonNonSystematic::<F>::vandermonde(k as usize, n as usize)?.reconstruct(&mut shards)?; - let elements: Vec<_> = shards.iter().filter_map(|x| x.clone()).flatten().collect(); + let points: Vec<_> = blocks + .iter() + .take(k as usize) + .map(|b| { + E::ScalarField::from_le_bytes_mod_order( + // TODO: use the real linear combination + &(b.linear_combination[0].index as u64).to_le_bytes(), + ) + }) + .collect(); - let mut data = F::into_data(elements.as_slice()); - data.resize(blocks[0].size, 0); - Ok(data) + let shards = Matrix::from_vec_vec( + blocks + .iter() + .take(k as usize) + .map(|b| field::split_data_into_field_elements::<E>(&b.bytes, 1, true)) + .collect(), + )? + .transpose(); + + let source_shards = shards + .mul(&Matrix::vandermonde(&points, k as usize).invert()?)? + .transpose() + .elements; + + let mut bytes = field::merge_elements_into_bytes::<E>(&source_shards, true); + bytes.resize(blocks[0].size, 0); + Ok(bytes) } #[cfg(test)] @@ -130,7 +139,6 @@ mod tests { use ark_ec::pairing::Pairing; use ark_ff::PrimeField; use ark_std::One; - use reed_solomon_erasure::galois_prime::Field as GF; use rs_merkle::algorithms::Sha256; use rs_merkle::Hasher; @@ -152,7 +160,7 @@ mod tests { [3042u32, 3021u32, 3731u32], [4218u32, 4185u32, 5187u32], ]; - const LOST_SHARDS: [usize; 3] = [1, 3, 6]; + const LOST_SHARDS: [usize; 2] = [4, 0]; fn decoding_template<E: Pairing>() { let hash = Sha256::hash(DATA).to_vec(); @@ -181,13 +189,13 @@ mod tests { weight: E::ScalarField::one(), }], hash: hash.clone(), - bytes: field::merge_elements_into_bytes::<E>(bytes), + bytes: field::merge_elements_into_bytes::<E>(bytes, false), size: DATA.len(), }); } } - assert_eq!(DATA, decode::<GF, E>(blocks).unwrap()) + assert_eq!(DATA, decode::<E>(blocks).unwrap()) } #[test] diff --git a/src/field.rs b/src/field.rs index 9cd6f965390a6e847be05ccdf8837b75921fac3b..716b5d3777d1c863b87279fbfbb668f052155434 100644 --- a/src/field.rs +++ b/src/field.rs @@ -34,10 +34,17 @@ pub(crate) fn split_data_into_field_elements<E: Pairing>( elements } -pub(crate) fn merge_elements_into_bytes<E: Pairing>(elements: &[E::ScalarField]) -> Vec<u8> { +pub(crate) fn merge_elements_into_bytes<E: Pairing>( + elements: &[E::ScalarField], + one_less: bool, +) -> Vec<u8> { let mut bytes = vec![]; for e in elements { - bytes.append(&mut e.into_bigint().to_bytes_le()); + let mut b = e.into_bigint().to_bytes_le(); + if one_less { + b.pop(); + } + bytes.append(&mut b); } bytes diff --git a/src/lib.rs b/src/lib.rs index c43c3a3c33ffd680868994d828045b89bac9d807..852019c26029325f8a630f2a0598e0d756f8aaee 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -13,6 +13,7 @@ use tracing::{debug, info}; pub mod fec; mod field; +mod linalg; pub mod setup; #[derive(Debug, Default, Clone, PartialEq, CanonicalSerialize, CanonicalDeserialize)] @@ -72,7 +73,7 @@ where weight: E::ScalarField::one(), }], hash: hash.to_vec(), - bytes: field::merge_elements_into_bytes::<E>(row), + bytes: field::merge_elements_into_bytes::<E>(row, false), size: nb_bytes, }, commit: commits.clone(), diff --git a/src/linalg.rs b/src/linalg.rs new file mode 100644 index 0000000000000000000000000000000000000000..00ee65f16a8ff278e7172e2718160e0b51223489 --- /dev/null +++ b/src/linalg.rs @@ -0,0 +1,404 @@ +use ark_ff::Field; +use thiserror::Error; + +#[derive(Clone, PartialEq, Default, Debug)] +pub(super) struct Matrix<T: Field> { + pub elements: Vec<T>, + pub height: usize, + width: usize, +} + +#[derive(Clone, Debug, Error, PartialEq)] +pub enum LinalgError { + #[error("Invalid matrix elements: {0}")] + InvalidMatrixElements(String), + #[error("Matrix is not a square")] + NonSquareMatrix(usize, usize), + #[error("Matrix is not invertible at row {0}")] + NonInvertibleMatrix(usize), + #[error("Matrices don't have compatible shapes: ({0} x {1}) and ({2} x {3})")] + IncompatibleMatrixShapes(usize, usize, usize, usize), + #[error("Another error: {0}")] + Other(String), +} + +impl<T: Field> Matrix<T> { + fn from_diagonal(diagonal: Vec<T>) -> Self { + let size = diagonal.len(); + + let mut elements = Vec::new(); + elements.resize(size * size, T::zero()); + for i in 0..size { + elements[i * size + i] = diagonal[i]; + } + + Self { + elements, + height: size, + width: size, + } + } + + fn identity(size: usize) -> Self { + Self::from_diagonal(vec![T::one(); size]) + } + + pub(super) fn vandermonde(points: &[T], height: usize) -> Self { + let width = points.len(); + + let mut elements = Vec::new(); + elements.resize(height * width, T::zero()); + + for (j, pj) in points.iter().enumerate() { + for i in 0..height { + elements[i * width + j] = pj.pow([i as u64]); + } + } + + Self { + elements, + height, + width, + } + } + + pub(super) fn from_vec_vec(matrix: Vec<Vec<T>>) -> Result<Self, LinalgError> { + let height = matrix.len(); + let width = matrix[0].len(); + + for (i, row) in matrix.iter().enumerate() { + if row.len() != width { + return Err(LinalgError::InvalidMatrixElements(format!( + "expected rows to be of same length {}, found {} at row {}", + width, + row.len(), + i + ))); + } + } + + let mut elements = Vec::new(); + elements.resize(height * width, T::zero()); + for i in 0..height { + for j in 0..width { + elements[i * width + j] = matrix[i][j]; + } + } + + Ok(Self { + elements, + height, + width, + }) + } + + fn get(&self, i: usize, j: usize) -> T { + self.elements[i * self.width + j] + } + + fn set(&mut self, i: usize, j: usize, value: T) { + self.elements[i * self.width + j] = value; + } + + // compute _row / value_ + fn divide_row_by(&mut self, row: usize, value: T) { + for j in 0..self.width { + self.set(row, j, self.get(row, j) / value); + } + } + + // compute _destination = destination + source * value_ + fn multiply_row_by_and_add_to_row(&mut self, source: usize, value: T, destination: usize) { + for j in 0..self.width { + self.set( + destination, + j, + self.get(destination, j) + self.get(source, j) * value, + ); + } + } + + pub(super) fn invert(&self) -> Result<Self, LinalgError> { + if self.height != self.width { + return Err(LinalgError::NonSquareMatrix(self.height, self.width)); + } + + let mut inverse = Self::identity(self.height); + let mut matrix = self.clone(); + + for i in 0..matrix.height { + let pivot = matrix.get(i, i); + if pivot.is_zero() { + return Err(LinalgError::NonInvertibleMatrix(i)); + } + + inverse.divide_row_by(i, pivot); + matrix.divide_row_by(i, pivot); + + for k in 0..matrix.height { + if k != i { + let factor = matrix.get(k, i); + inverse.multiply_row_by_and_add_to_row(i, -factor, k); + matrix.multiply_row_by_and_add_to_row(i, -factor, k); + } + } + } + + Ok(inverse) + } + + pub(super) fn mul(&self, rhs: &Self) -> Result<Self, LinalgError> { + if self.width != rhs.height { + return Err(LinalgError::IncompatibleMatrixShapes( + self.height, + self.width, + rhs.height, + rhs.width, + )); + } + + let height = self.height; + let width = rhs.width; + let common = self.width; + + let mut elements = Vec::new(); + elements.resize(height * width, T::zero()); + + for i in 0..height { + for j in 0..width { + elements[i * width + j] = (0..common).map(|k| self.get(i, k) * rhs.get(k, j)).sum(); + } + } + + Ok(Self { + elements, + height, + width, + }) + } + + pub(super) fn transpose(&self) -> Self { + let height = self.width; + let width = self.height; + + let mut elements = Vec::new(); + elements.resize(height * width, T::zero()); + + for i in 0..height { + for j in 0..width { + elements[i * width + j] = self.get(j, i); + } + } + + Self { + elements, + height, + width, + } + } +} + +#[cfg(test)] +mod tests { + use ark_bls12_381::Fr; + use ark_ff::Field; + use ark_std::{One, Zero}; + use rand::Rng; + + use super::{LinalgError, Matrix}; + + fn random_field_element<T: Field>() -> T { + let mut rng = rand::thread_rng(); + let element: u128 = rng.gen(); + T::from(element) + } + + #[test] + fn from_vec_vec() { + let actual = Matrix::from_vec_vec(vec![ + vec![Fr::from(2), Fr::zero(), Fr::zero()], + vec![Fr::zero(), Fr::from(3), Fr::zero()], + vec![Fr::zero(), Fr::zero(), Fr::from(4)], + vec![Fr::from(2), Fr::from(3), Fr::from(4)], + ]) + .unwrap(); + let expected = Matrix { + elements: vec![ + Fr::from(2), + Fr::zero(), + Fr::zero(), + Fr::zero(), + Fr::from(3), + Fr::zero(), + Fr::zero(), + Fr::zero(), + Fr::from(4), + Fr::from(2), + Fr::from(3), + Fr::from(4), + ], + height: 4, + width: 3, + }; + assert_eq!(actual, expected); + + let matrix = Matrix::from_vec_vec(vec![vec![Fr::zero()], vec![Fr::zero(), Fr::zero()]]); + assert!(matrix.is_err()); + assert!(matches!( + matrix.err().unwrap(), + LinalgError::InvalidMatrixElements(..) + )); + } + + #[test] + fn diagonal() { + let actual = Matrix::<Fr>::from_diagonal(vec![Fr::from(2), Fr::from(3), Fr::from(4)]); + let expected = Matrix::from_vec_vec(vec![ + vec![Fr::from(2), Fr::zero(), Fr::zero()], + vec![Fr::zero(), Fr::from(3), Fr::zero()], + vec![Fr::zero(), Fr::zero(), Fr::from(4)], + ]) + .unwrap(); + assert_eq!(actual, expected); + } + + #[test] + fn identity() { + let actual = Matrix::<Fr>::identity(3); + let expected = Matrix::from_vec_vec(vec![ + vec![Fr::one(), Fr::zero(), Fr::zero()], + vec![Fr::zero(), Fr::one(), Fr::zero()], + vec![Fr::zero(), Fr::zero(), Fr::one()], + ]) + .unwrap(); + assert_eq!(actual, expected); + } + + #[test] + fn multiplication() { + let a = Matrix::from_vec_vec(vec![ + vec![Fr::from(9), Fr::from(4), Fr::from(3)], + vec![Fr::from(8), Fr::from(5), Fr::from(2)], + vec![Fr::from(7), Fr::from(6), Fr::from(1)], + ]) + .unwrap(); + let b = Matrix::from_vec_vec(vec![ + vec![Fr::from(1), Fr::from(2), Fr::from(3)], + vec![Fr::from(4), Fr::from(5), Fr::from(6)], + vec![Fr::from(7), Fr::from(8), Fr::from(9)], + ]) + .unwrap(); + + assert!(matches!( + a.mul(&Matrix::from_vec_vec(vec![vec![Fr::from(1), Fr::from(2)]]).unwrap()), + Err(LinalgError::IncompatibleMatrixShapes(3, 3, 1, 2)) + )); + + let product = a.mul(&b).unwrap(); + let expected = Matrix::from_vec_vec(vec![ + vec![Fr::from(46), Fr::from(62), Fr::from(78)], + vec![Fr::from(42), Fr::from(57), Fr::from(72)], + vec![Fr::from(38), Fr::from(52), Fr::from(66)], + ]) + .unwrap(); + assert_eq!(product, expected); + } + + #[test] + fn inverse() { + let matrix = Matrix::<Fr>::identity(3); + let inverse = matrix.invert().unwrap(); + assert_eq!(Matrix::<Fr>::identity(3), inverse); + + let matrix = Matrix::<Fr>::from_diagonal(vec![Fr::from(2), Fr::from(3), Fr::from(4)]); + let inverse = matrix.invert().unwrap(); + assert_eq!(matrix.mul(&inverse).unwrap(), Matrix::<Fr>::identity(3)); + assert_eq!(inverse.mul(&matrix).unwrap(), Matrix::<Fr>::identity(3)); + + let n = 20; + let matrix = Matrix::from_vec_vec( + (0..n) + .map(|_| (0..n).map(|_| random_field_element()).collect()) + .collect::<Vec<Vec<Fr>>>(), + ) + .unwrap(); + let inverse = matrix.invert().unwrap(); + assert_eq!(matrix.mul(&inverse).unwrap(), Matrix::<Fr>::identity(n)); + assert_eq!(inverse.mul(&matrix).unwrap(), Matrix::<Fr>::identity(n)); + + let inverse = Matrix::from_vec_vec(vec![ + vec![Fr::one(), Fr::zero(), Fr::zero()], + vec![Fr::zero(), Fr::one(), Fr::zero()], + ]) + .unwrap() + .invert(); + assert!(inverse.is_err()); + assert!(matches!( + inverse.err().unwrap(), + LinalgError::NonSquareMatrix(..) + )); + + let inverse = + Matrix::<Fr>::from_diagonal(vec![Fr::zero(), Fr::from(3), Fr::from(4)]).invert(); + assert!(inverse.is_err()); + assert!(matches!( + inverse.err().unwrap(), + LinalgError::NonInvertibleMatrix(0) + )); + + let inverse = Matrix::from_vec_vec(vec![ + vec![Fr::one(), Fr::one(), Fr::zero()], + vec![Fr::zero(), Fr::zero(), Fr::zero()], + vec![Fr::zero(), Fr::zero(), Fr::one()], + ]) + .unwrap() + .invert(); + assert!(inverse.is_err()); + assert!(matches!( + inverse.err().unwrap(), + LinalgError::NonInvertibleMatrix(1) + )); + } + + #[test] + fn vandermonde() { + let actual = Matrix::vandermonde( + &[ + Fr::from(0), + Fr::from(1), + Fr::from(2), + Fr::from(3), + Fr::from(4), + ], + 4, + ); + #[rustfmt::skip] + let expected = Matrix::from_vec_vec(vec![ + vec![Fr::from(1), Fr::from(1), Fr::from(1), Fr::from(1), Fr::from(1)], + vec![Fr::from(0), Fr::from(1), Fr::from(2), Fr::from(3), Fr::from(4)], + vec![Fr::from(0), Fr::from(1), Fr::from(4), Fr::from(9), Fr::from(16)], + vec![Fr::from(0), Fr::from(1), Fr::from(8), Fr::from(27), Fr::from(64)], + ]) + .unwrap(); + assert_eq!(actual, expected); + } + + #[test] + fn transpose() { + let matrix = Matrix::from_vec_vec(vec![ + vec![Fr::from(1), Fr::from(2), Fr::from(3), Fr::from(10)], + vec![Fr::from(4), Fr::from(5), Fr::from(6), Fr::from(11)], + vec![Fr::from(7), Fr::from(8), Fr::from(9), Fr::from(12)], + ]) + .unwrap(); + let transpose = Matrix::from_vec_vec(vec![ + vec![Fr::from(1), Fr::from(4), Fr::from(7)], + vec![Fr::from(2), Fr::from(5), Fr::from(8)], + vec![Fr::from(3), Fr::from(6), Fr::from(9)], + vec![Fr::from(10), Fr::from(11), Fr::from(12)], + ]) + .unwrap(); + + assert_eq!(matrix.transpose(), transpose); + } +} diff --git a/src/main.rs b/src/main.rs index 95f23064f57b1945e5237328ed059d0cb58857b0..9a0b0301734a9800ffc69122e1cc2fc53eb1fa31 100644 --- a/src/main.rs +++ b/src/main.rs @@ -9,7 +9,6 @@ use ark_poly::univariate::DensePolynomial; use ark_poly::DenseUVPolynomial; use ark_poly_commit::kzg10::Powers; use ark_serialize::{CanonicalDeserialize, CanonicalSerialize, Compress, Validate}; -use reed_solomon_erasure::galois_prime::Field as GF; use tracing::{debug, info, warn}; use komodo::{ @@ -182,7 +181,7 @@ fn main() { .cloned() .map(|b| b.1.shard) .collect(); - eprintln!("{:?}", decode::<GF, Bls12_381>(blocks).unwrap()); + eprintln!("{:?}", decode::<Bls12_381>(blocks).unwrap()); exit(0); }