From 8c07220b52f070839a7312db2c899a58679c9424 Mon Sep 17 00:00:00 2001
From: STEVAN Antoine <antoine.stevan@isae-supaero.fr>
Date: Tue, 6 Aug 2024 10:38:45 +0000
Subject: [PATCH] refactor algebra module (dragoon/komodo!165)

`algebra`, `field` and `linalg` were doing extremely similar things before...

this MR merges them into a single module `algebra`
- old `algebra` and `field` are at the root of the new `algebra`
- old `linalg` is now `algebra::linalg`

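all references to these modules have been updated in the rest of the codebase, and the pairing-based helpers of `algebra` are now gated behind the `kzg` / `aplonk` features so that the module itself is always compiled.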
---
 benchmarks/src/bin/fec.rs          |   2 +-
 benchmarks/src/bin/linalg.rs       |   2 +-
 benchmarks/src/bin/recoding.rs     |   4 +-
 bins/rank/src/main.rs              |   2 +-
 bins/saclin/src/main.rs            |   2 +-
 src/{ => algebra}/linalg.rs        |   8 +-
 src/{algebra.rs => algebra/mod.rs} | 127 ++++++++++++++++++++++++++++-
 src/aplonk/mod.rs                  |   6 +-
 src/fec.rs                         |  12 +--
 src/field.rs                       | 111 -------------------------
 src/kzg.rs                         |   6 +-
 src/lib.rs                         |   5 +-
 src/semi_avid.rs                   |   6 +-
 13 files changed, 152 insertions(+), 141 deletions(-)
 rename src/{ => algebra}/linalg.rs (99%)
 rename src/{algebra.rs => algebra/mod.rs} (58%)
 delete mode 100644 src/field.rs

diff --git a/benchmarks/src/bin/fec.rs b/benchmarks/src/bin/fec.rs
index 90a16dd3..173438d2 100644
--- a/benchmarks/src/bin/fec.rs
+++ b/benchmarks/src/bin/fec.rs
@@ -2,7 +2,7 @@
 use ark_ff::PrimeField;
 
 use clap::{arg, command, Parser, ValueEnum};
-use komodo::{fec, linalg::Matrix};
+use komodo::{algebra::linalg::Matrix, fec};
 use plnk::Bencher;
 use rand::{rngs::ThreadRng, thread_rng, Rng, RngCore};
 
diff --git a/benchmarks/src/bin/linalg.rs b/benchmarks/src/bin/linalg.rs
index 022c4e20..be8e3d28 100644
--- a/benchmarks/src/bin/linalg.rs
+++ b/benchmarks/src/bin/linalg.rs
@@ -2,7 +2,7 @@
 use ark_ff::PrimeField;
 
 use clap::{arg, command, Parser};
-use komodo::linalg::Matrix;
+use komodo::algebra::linalg::Matrix;
 use plnk::Bencher;
 
 fn inverse_template<F: PrimeField>(b: &Bencher, n: usize) {
diff --git a/benchmarks/src/bin/recoding.rs b/benchmarks/src/bin/recoding.rs
index b7bce615..69c55de0 100644
--- a/benchmarks/src/bin/recoding.rs
+++ b/benchmarks/src/bin/recoding.rs
@@ -4,8 +4,8 @@ use ark_std::rand::Rng;
 
 use clap::{arg, command, Parser, ValueEnum};
 use komodo::{
+    algebra,
     fec::{recode_with_coeffs, Shard},
-    field,
 };
 use plnk::Bencher;
 
@@ -23,7 +23,7 @@ fn create_fake_shard<F: PrimeField>(nb_bytes: usize, k: usize) -> Shard<F> {
         k: k as u32,
         linear_combination,
         hash: vec![],
-        data: field::split_data_into_field_elements::<F>(&bytes, 1),
+        data: algebra::split_data_into_field_elements::<F>(&bytes, 1),
         size: 0,
     }
 }
diff --git a/bins/rank/src/main.rs b/bins/rank/src/main.rs
index 0739a85d..9393e801 100644
--- a/bins/rank/src/main.rs
+++ b/bins/rank/src/main.rs
@@ -2,7 +2,7 @@ use ark_bls12_381::Fr;
 use ark_ff::Field;
 use ark_std::rand::{Rng, RngCore};
 
-use komodo::linalg::Matrix;
+use komodo::algebra::linalg::Matrix;
 
 fn rand<T: Field, R: RngCore>(rng: &mut R) -> T {
     let element: u128 = rng.gen();
diff --git a/bins/saclin/src/main.rs b/bins/saclin/src/main.rs
index cc262431..2f6b1ff2 100644
--- a/bins/saclin/src/main.rs
+++ b/bins/saclin/src/main.rs
@@ -13,10 +13,10 @@ use ark_std::rand::RngCore;
 use tracing::{info, warn};
 
 use komodo::{
+    algebra::linalg::Matrix,
     error::KomodoError,
     fec::{self, decode, Shard},
     fs,
-    linalg::Matrix,
     semi_avid::{build, prove, recode, verify, Block},
     zk::{self, Powers},
 };
diff --git a/src/linalg.rs b/src/algebra/linalg.rs
similarity index 99%
rename from src/linalg.rs
rename to src/algebra/linalg.rs
index 12ece96d..5e6d6275 100644
--- a/src/linalg.rs
+++ b/src/algebra/linalg.rs
@@ -79,7 +79,7 @@ impl<T: Field> Matrix<T> {
     /// # Example
     /// ```rust
     /// # use ark_ff::Field;
-    /// # use komodo::linalg::Matrix;
+    /// # use komodo::algebra::linalg::Matrix;
     /// // helper to convert integers to field elements
     /// fn vec_to_elements<T: Field>(elements: Vec<u128>) -> Vec<T>
     /// # {
@@ -160,7 +160,7 @@ impl<T: Field> Matrix<T> {
     ///
     /// # Example
     /// ```rust
-    /// # use komodo::linalg::Matrix;
+    /// # use komodo::algebra::linalg::Matrix;
     /// # use ark_ff::Field;
     /// // helper to convert integers to field elements
     /// fn vec_to_elements<T: Field>(elements: Vec<u128>) -> Vec<T>
@@ -250,7 +250,7 @@ impl<T: Field> Matrix<T> {
     ///
     /// > **Note**
     /// > returns `None` if the provided index is out of bounds
-    pub(super) fn get_col(&self, j: usize) -> Option<Vec<T>> {
+    pub(crate) fn get_col(&self, j: usize) -> Option<Vec<T>> {
         if j >= self.width {
             return None;
         }
@@ -438,7 +438,7 @@ impl<T: Field> Matrix<T> {
     /// # Example
     /// if a matrix has shape `(10, 11)` and is truncated to `(5, 7)`, the 5
     /// bottom rows and 4 right columns will be removed.
-    pub(super) fn truncate(&self, rows: Option<usize>, cols: Option<usize>) -> Self {
+    pub(crate) fn truncate(&self, rows: Option<usize>, cols: Option<usize>) -> Self {
         let width = if let Some(w) = cols {
             self.width - w
         } else {
diff --git a/src/algebra.rs b/src/algebra/mod.rs
similarity index 58%
rename from src/algebra.rs
rename to src/algebra/mod.rs
index 63c7c025..a34ab0b9 100644
--- a/src/algebra.rs
+++ b/src/algebra/mod.rs
@@ -1,10 +1,53 @@
+//! manipulate finite field elements and, through [`linalg`], matrices of
+//! such elements
+#[cfg(any(feature = "kzg", feature = "aplonk"))]
 use ark_ec::pairing::Pairing;
 #[cfg(feature = "aplonk")]
 use ark_ec::pairing::PairingOutput;
+use ark_ff::{BigInteger, PrimeField};
+#[cfg(any(feature = "kzg", feature = "aplonk"))]
 use ark_poly::DenseUVPolynomial;
+#[cfg(any(feature = "kzg", feature = "aplonk"))]
 use ark_std::One;
+#[cfg(any(feature = "kzg", feature = "aplonk"))]
 use std::ops::{Div, Mul};
 
+pub mod linalg;
+
+/// split a sequence of raw bytes into valid field elements
+///
+/// the output vector is padded with `F::one()` until its length is a multiple
+/// of `modulus`, e.g. 11 elements become 16 with a `modulus` of 8.
+pub fn split_data_into_field_elements<F: PrimeField>(bytes: &[u8], modulus: usize) -> Vec<F> {
+    let bytes_per_element = (F::MODULUS_BIT_SIZE as usize) / 8;
+
+    let mut elements = Vec::new();
+    for chunk in bytes.chunks(bytes_per_element) {
+        elements.push(F::from_le_bytes_mod_order(chunk));
+    }
+
+    if elements.len() % modulus != 0 {
+        elements.resize((elements.len() / modulus + 1) * modulus, F::one());
+    }
+
+    elements
+}
+
+/// merges field elements back into a sequence of bytes
+///
+/// this is the inverse operation of [`split_data_into_field_elements`].
+pub(crate) fn merge_elements_into_bytes<F: PrimeField>(elements: &[F]) -> Vec<u8> {
+    let mut bytes = vec![];
+    for e in elements {
+        let mut b = e.into_bigint().to_bytes_le();
+        b.pop();
+        bytes.append(&mut b);
+    }
+
+    bytes
+}
+
+#[cfg(any(feature = "kzg", feature = "aplonk"))]
 pub(crate) fn scalar_product_polynomial<E, P>(lhs: &[E::ScalarField], rhs: &[P]) -> P
 where
     E: Pairing,
@@ -69,6 +112,7 @@ pub(super) mod vector {
 /// following vector:
 ///         [1, r, r^2, ..., r^(n-1)]
 /// where *n* is the number of powers
+#[cfg(any(feature = "kzg", feature = "aplonk"))]
 pub(crate) fn powers_of<E: Pairing>(step: E::ScalarField, nb_powers: usize) -> Vec<E::ScalarField> {
     let mut powers = Vec::with_capacity(nb_powers);
     powers.push(E::ScalarField::one());
@@ -81,12 +125,86 @@ pub(crate) fn powers_of<E: Pairing>(step: E::ScalarField, nb_powers: usize) -> V
 
 #[cfg(test)]
 mod tests {
+    #[cfg(any(feature = "kzg", feature = "aplonk"))]
     use ark_bls12_381::Bls12_381;
+    use ark_bls12_381::Fr;
+    #[cfg(any(feature = "kzg", feature = "aplonk"))]
     use ark_ec::pairing::Pairing;
+    #[cfg(any(feature = "kzg", feature = "aplonk"))]
     use ark_ff::Field;
-    use ark_std::test_rng;
-    use ark_std::UniformRand;
+    use ark_ff::PrimeField;
+    #[cfg(any(feature = "kzg", feature = "aplonk"))]
+    use ark_std::{test_rng, UniformRand};
+
+    fn bytes() -> Vec<u8> {
+        include_bytes!("../../assets/dragoon_32x32.png").to_vec()
+    }
+
+    fn split_data_template<F: PrimeField>(
+        bytes: &[u8],
+        modulus: usize,
+        exact_length: Option<usize>,
+    ) {
+        let test_case = format!(
+            "TEST | modulus: {}, exact_length: {:?}",
+            modulus, exact_length
+        );
+
+        let elements = super::split_data_into_field_elements::<F>(bytes, modulus);
+        assert!(
+            elements.len() % modulus == 0,
+            "number of elements should be divisible by {}, found {}\n{test_case}",
+            modulus,
+            elements.len(),
+        );
+
+        if let Some(length) = exact_length {
+            assert!(
+                elements.len() == length,
+                "number of elements should be exactly {}, found {}\n{test_case}",
+                length,
+                elements.len(),
+            );
+        }
+
+        assert!(
+            !elements.iter().any(|&e| e == F::zero()),
+            "elements should not contain any 0\n{test_case}"
+        );
+    }
+
+    #[test]
+    fn split_data() {
+        split_data_template::<Fr>(&bytes(), 1, None);
+        split_data_template::<Fr>(&bytes(), 8, None);
+        split_data_template::<Fr>(&[], 1, None);
+        split_data_template::<Fr>(&[], 8, None);
+
+        let nb_bytes = 11 * (Fr::MODULUS_BIT_SIZE as usize / 8);
+        split_data_template::<Fr>(&bytes()[..nb_bytes], 1, Some(11));
+        split_data_template::<Fr>(&bytes()[..nb_bytes], 8, Some(16));
+
+        let nb_bytes = 11 * (Fr::MODULUS_BIT_SIZE as usize / 8) - 10;
+        split_data_template::<Fr>(&bytes()[..nb_bytes], 1, Some(11));
+        split_data_template::<Fr>(&bytes()[..nb_bytes], 8, Some(16));
+    }
+
+    fn split_and_merge_template<F: PrimeField>(bytes: &[u8], modulus: usize) {
+        let elements: Vec<F> = super::split_data_into_field_elements(bytes, modulus);
+        let mut actual = super::merge_elements_into_bytes(&elements);
+        actual.resize(bytes.len(), 0);
+        assert_eq!(bytes, actual, "TEST | modulus: {modulus}");
+    }
+
+    #[test]
+    fn split_and_merge() {
+        split_and_merge_template::<Fr>(&bytes(), 1);
+        split_and_merge_template::<Fr>(&bytes(), 8);
+        split_and_merge_template::<Fr>(&bytes(), 64);
+        split_and_merge_template::<Fr>(&bytes(), 4096);
+    }
 
+    #[cfg(any(feature = "kzg", feature = "aplonk"))]
     fn powers_of_template<E: Pairing>() {
         let rng = &mut test_rng();
 
@@ -99,16 +217,19 @@ mod tests {
         );
     }
 
+    #[cfg(any(feature = "kzg", feature = "aplonk"))]
     #[test]
     fn powers_of() {
         powers_of_template::<Bls12_381>();
     }
 
+    #[cfg(any(feature = "kzg", feature = "aplonk"))]
     mod scalar_product {
         use ark_bls12_381::Bls12_381;
         use ark_ec::pairing::Pairing;
         use ark_ff::PrimeField;
-        use ark_poly::{univariate::DensePolynomial, DenseUVPolynomial};
+        use ark_poly::univariate::DensePolynomial;
+        use ark_poly::DenseUVPolynomial;
         #[cfg(feature = "aplonk")]
         use ark_std::test_rng;
         #[cfg(feature = "aplonk")]
diff --git a/src/aplonk/mod.rs b/src/aplonk/mod.rs
index 86400edb..5b488347 100644
--- a/src/aplonk/mod.rs
+++ b/src/aplonk/mod.rs
@@ -387,7 +387,9 @@ where
 #[cfg(test)]
 mod tests {
     use super::{commit, prove, setup, Block};
-    use crate::{conversions::u32_to_u8_vec, fec::encode, field, linalg::Matrix, zk::trim};
+    use crate::{
+        algebra, algebra::linalg::Matrix, conversions::u32_to_u8_vec, fec::encode, zk::trim,
+    };
 
     use ark_bls12_381::Bls12_381;
     use ark_ec::{pairing::Pairing, AffineRepr};
@@ -428,7 +430,7 @@ mod tests {
         let params = setup::<E, P>(degree, vector_length_bound)?;
         let (_, vk_psi) = trim(params.kzg.clone(), degree);
 
-        let elements = field::split_data_into_field_elements::<E::ScalarField>(bytes, k);
+        let elements = algebra::split_data_into_field_elements::<E::ScalarField>(bytes, k);
         let mut polynomials = Vec::new();
         for chunk in elements.chunks(k) {
             polynomials.push(P::from_coefficients_vec(chunk.to_vec()))
diff --git a/src/fec.rs b/src/fec.rs
index dbe603c3..7a53afbe 100644
--- a/src/fec.rs
+++ b/src/fec.rs
@@ -6,7 +6,7 @@ use ark_std::rand::RngCore;
 
 use rs_merkle::{algorithms::Sha256, Hasher};
 
-use crate::{error::KomodoError, field, linalg::Matrix};
+use crate::{algebra, algebra::linalg::Matrix, error::KomodoError};
 
 /// representation of a FEC shard of data.
 #[derive(Debug, Default, Clone, PartialEq, CanonicalSerialize, CanonicalDeserialize)]
@@ -151,7 +151,7 @@ pub fn encode<F: PrimeField>(
     let k = encoding_mat.height;
 
     let source_shards = Matrix::from_vec_vec(
-        field::split_data_into_field_elements(data, k)
+        algebra::split_data_into_field_elements(data, k)
             .chunks(k)
             .map(|c| c.to_vec())
             .collect(),
@@ -210,7 +210,7 @@ pub fn decode<F: PrimeField>(shards: Vec<Shard<F>>) -> Result<Vec<u8>, KomodoErr
 
     let source_shards = encoding_mat.invert()?.mul(&shard_mat)?.transpose().elements;
 
-    let mut bytes = field::merge_elements_into_bytes(&source_shards);
+    let mut bytes = algebra::merge_elements_into_bytes(&source_shards);
     bytes.resize(shards[0].size, 0);
     Ok(bytes)
 }
@@ -221,9 +221,9 @@ mod tests {
     use ark_ff::PrimeField;
 
     use crate::{
+        algebra,
+        algebra::linalg::Matrix,
         fec::{decode, encode, recode_random, Shard},
-        field,
-        linalg::Matrix,
     };
 
     use itertools::Itertools;
@@ -512,7 +512,7 @@ mod tests {
             k: 2,
             linear_combination: linear_combination.to_vec(),
             hash: vec![],
-            data: field::split_data_into_field_elements(bytes, 1),
+            data: algebra::split_data_into_field_elements(bytes, 1),
             size: 0,
         }
     }
diff --git a/src/field.rs b/src/field.rs
deleted file mode 100644
index 0970b260..00000000
--- a/src/field.rs
+++ /dev/null
@@ -1,111 +0,0 @@
-//! manipulate finite field elements
-use ark_ff::{BigInteger, PrimeField};
-
-/// split a sequence of raw bytes into valid field elements
-///
-/// [`split_data_into_field_elements`] supports padding the output vector of
-/// elements by giving a number that needs to divide the length of the vector.
-pub fn split_data_into_field_elements<F: PrimeField>(bytes: &[u8], modulus: usize) -> Vec<F> {
-    let bytes_per_element = (F::MODULUS_BIT_SIZE as usize) / 8;
-
-    let mut elements = Vec::new();
-    for chunk in bytes.chunks(bytes_per_element) {
-        elements.push(F::from_le_bytes_mod_order(chunk));
-    }
-
-    if elements.len() % modulus != 0 {
-        elements.resize((elements.len() / modulus + 1) * modulus, F::one());
-    }
-
-    elements
-}
-
-/// merges elliptic curve elements back into a sequence of bytes
-///
-/// this is the inverse operation of [`split_data_into_field_elements`].
-pub(crate) fn merge_elements_into_bytes<F: PrimeField>(elements: &[F]) -> Vec<u8> {
-    let mut bytes = vec![];
-    for e in elements {
-        let mut b = e.into_bigint().to_bytes_le();
-        b.pop();
-        bytes.append(&mut b);
-    }
-
-    bytes
-}
-
-#[cfg(test)]
-mod tests {
-    use ark_bls12_381::Fr;
-    use ark_ff::PrimeField;
-
-    use crate::field::{self, merge_elements_into_bytes};
-
-    fn bytes() -> Vec<u8> {
-        include_bytes!("../assets/dragoon_32x32.png").to_vec()
-    }
-
-    fn split_data_template<F: PrimeField>(
-        bytes: &[u8],
-        modulus: usize,
-        exact_length: Option<usize>,
-    ) {
-        let test_case = format!(
-            "TEST | modulus: {}, exact_length: {:?}",
-            modulus, exact_length
-        );
-
-        let elements = field::split_data_into_field_elements::<F>(bytes, modulus);
-        assert!(
-            elements.len() % modulus == 0,
-            "number of elements should be divisible by {}, found {}\n{test_case}",
-            modulus,
-            elements.len(),
-        );
-
-        if let Some(length) = exact_length {
-            assert!(
-                elements.len() == length,
-                "number of elements should be exactly {}, found {}\n{test_case}",
-                length,
-                elements.len(),
-            );
-        }
-
-        assert!(
-            !elements.iter().any(|&e| e == F::zero()),
-            "elements should not contain any 0\n{test_case}"
-        );
-    }
-
-    #[test]
-    fn split_data() {
-        split_data_template::<Fr>(&bytes(), 1, None);
-        split_data_template::<Fr>(&bytes(), 8, None);
-        split_data_template::<Fr>(&[], 1, None);
-        split_data_template::<Fr>(&[], 8, None);
-
-        let nb_bytes = 11 * (Fr::MODULUS_BIT_SIZE as usize / 8);
-        split_data_template::<Fr>(&bytes()[..nb_bytes], 1, Some(11));
-        split_data_template::<Fr>(&bytes()[..nb_bytes], 8, Some(16));
-
-        let nb_bytes = 11 * (Fr::MODULUS_BIT_SIZE as usize / 8) - 10;
-        split_data_template::<Fr>(&bytes()[..nb_bytes], 1, Some(11));
-        split_data_template::<Fr>(&bytes()[..nb_bytes], 8, Some(16));
-    }
-
-    fn split_and_merge_template<F: PrimeField>(bytes: &[u8], modulus: usize) {
-        let elements: Vec<F> = field::split_data_into_field_elements(bytes, modulus);
-        let mut actual = merge_elements_into_bytes(&elements);
-        actual.resize(bytes.len(), 0);
-        assert_eq!(bytes, actual, "TEST | modulus: {modulus}");
-    }
-
-    #[test]
-    fn split_and_merge() {
-        split_and_merge_template::<Fr>(&bytes(), 1);
-        split_and_merge_template::<Fr>(&bytes(), 8);
-        split_and_merge_template::<Fr>(&bytes(), 64);
-        split_and_merge_template::<Fr>(&bytes(), 4096);
-    }
-}
diff --git a/src/kzg.rs b/src/kzg.rs
index e065b7c8..72baa083 100644
--- a/src/kzg.rs
+++ b/src/kzg.rs
@@ -211,7 +211,9 @@ mod tests {
     use ark_std::test_rng;
     use std::ops::{Div, Mul};
 
-    use crate::{conversions::u32_to_u8_vec, fec::encode, field, linalg::Matrix, zk::trim};
+    use crate::{
+        algebra, algebra::linalg::Matrix, conversions::u32_to_u8_vec, fec::encode, zk::trim,
+    };
 
     type UniPoly381 = DensePolynomial<<Bls12_381 as Pairing>::ScalarField>;
 
@@ -237,7 +239,7 @@ mod tests {
         let params = KZG10::<E, P>::setup(degree, false, rng)?;
         let (powers, verifier_key) = trim(params, degree);
 
-        let elements = field::split_data_into_field_elements::<E::ScalarField>(bytes, k);
+        let elements = algebra::split_data_into_field_elements::<E::ScalarField>(bytes, k);
         let mut polynomials = Vec::new();
         for chunk in elements.chunks(k) {
             polynomials.push(P::from_coefficients_vec(chunk.to_vec()))
diff --git a/src/lib.rs b/src/lib.rs
index 1e6f077a..78afa610 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,16 +1,13 @@
 //! Komodo: Cryptographically-proven Erasure Coding
-#[cfg(any(feature = "kzg", feature = "aplonk"))]
-mod algebra;
+pub mod algebra;
 #[cfg(feature = "aplonk")]
 pub mod aplonk;
 #[cfg(any(feature = "kzg", feature = "aplonk"))]
 mod conversions;
 pub mod error;
 pub mod fec;
-pub mod field;
 pub mod fs;
 #[cfg(feature = "kzg")]
 pub mod kzg;
-pub mod linalg;
 pub mod semi_avid;
 pub mod zk;
diff --git a/src/semi_avid.rs b/src/semi_avid.rs
index a2bc5314..2371b570 100644
--- a/src/semi_avid.rs
+++ b/src/semi_avid.rs
@@ -8,9 +8,9 @@ use ark_std::rand::RngCore;
 use tracing::{debug, info};
 
 use crate::{
+    algebra,
     error::KomodoError,
     fec::{self, Shard},
-    field,
     zk::{self, Commitment, Powers},
 };
 
@@ -123,7 +123,7 @@ where
     info!("encoding and proving {} bytes", bytes.len());
 
     debug!("splitting bytes into polynomials");
-    let elements = field::split_data_into_field_elements(bytes, k);
+    let elements = algebra::split_data_into_field_elements(bytes, k);
     let polynomials = elements
         .chunks(k)
         .map(|c| P::from_coefficients_vec(c.to_vec()))
@@ -197,9 +197,9 @@ mod tests {
     use ark_std::{ops::Div, test_rng};
 
     use crate::{
+        algebra::linalg::Matrix,
         error::KomodoError,
         fec::{decode, encode, Shard},
-        linalg::Matrix,
         zk::{setup, Commitment},
     };
 
-- 
GitLab