Skip to content
Snippets Groups Projects
Commit 0c48f632 authored by STEVAN Antoine's avatar STEVAN Antoine :crab:
Browse files

allow passing any matrix as parameter to encoding process (dragoon/komodo!27)

## changelog
- add `--encoding-method` to `komodo prove`
- pass the encoding matrix to `encode` and `fec::encode` instead of `k` and `n`, these two parameters can be extracted without further check by looking at the shape of the encoding matrix
- the global recoding vector is now extracted from the encoding matrix instead of recomputing it (see new `Matrix::get_col` implementation)
- `linalg` and `Matrix::{random, vandermonde}` have been made public (see new `Matrix::random` implementation)
- the computation of `Matrix::vandermonde` has been optimized
parent fbe6fbab
No related branches found
No related tags found
No related merge requests found
......@@ -21,6 +21,13 @@ def "nu-complete log-levels" []: nothing -> list<string> {
]
}
def "nu-complete encoding-methods" []: nothing -> list<string> {
[
"vandermonde"
"random",
]
}
def run-komodo [
--input: path = "",
--nb-bytes: int = 0,
......@@ -31,6 +38,7 @@ def run-komodo [
--verify,
--combine,
--inspect,
--encoding-method: string = "",
--log-level: string,
...block_hashes: string,
]: nothing -> any {
......@@ -56,6 +64,7 @@ def run-komodo [
($combine | into string)
($inspect | into string)
$nb_bytes
$encoding_method
] | append $block_hashes)
} | complete
......@@ -94,6 +103,7 @@ export def "komodo setup" [
export def "komodo prove" [
input: path,
--fec-params: record<k: int, n: int>,
--encoding-method: string@"nu-complete encoding-methods" = "random",
--log-level: string@"nu-complete log-levels" = $DEFAULT_LOG_LEVEL
]: nothing -> list<string> {
(
......@@ -102,6 +112,7 @@ export def "komodo prove" [
--input $input
-k $fec_params.k
-n $fec_params.n
--encoding-method $encoding_method
)
}
......
use std::ops::{Add, Mul};
use ark_ec::pairing::Pairing;
use ark_ff::PrimeField;
use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
use ark_std::{One, Zero};
use ark_std::Zero;
use rs_merkle::algorithms::Sha256;
use rs_merkle::Hasher;
......@@ -48,13 +47,13 @@ impl<E: Pairing> Shard<E> {
}
}
pub fn encode<E: Pairing>(data: &[u8], k: usize, n: usize) -> Result<Vec<Shard<E>>, KomodoError> {
pub fn encode<E: Pairing>(
data: &[u8],
encoding_mat: &Matrix<E::ScalarField>,
) -> Result<Vec<Shard<E>>, KomodoError> {
let hash = Sha256::hash(data).to_vec();
let points: Vec<E::ScalarField> = (0..n)
.map(|i| E::ScalarField::from_le_bytes_mod_order(&i.to_le_bytes()))
.collect();
let encoding = Matrix::vandermonde(&points, k);
let k = encoding_mat.height;
let source_shards = Matrix::from_vec_vec(
field::split_data_into_field_elements::<E>(data, k)
......@@ -64,26 +63,17 @@ pub fn encode<E: Pairing>(data: &[u8], k: usize, n: usize) -> Result<Vec<Shard<E
)?;
Ok(source_shards
.mul(&encoding)?
.mul(encoding_mat)?
.transpose()
.elements
.chunks(source_shards.height)
.enumerate()
.map(|(i, s)| {
let alpha = E::ScalarField::from_le_bytes_mod_order(&i.to_le_bytes());
let mut linear_combination = Vec::new();
linear_combination.push(E::ScalarField::one());
for i in 1..k {
linear_combination.push(linear_combination[i - 1].mul(alpha));
}
Shard {
k: k as u32,
linear_combination,
hash: hash.clone(),
bytes: s.to_vec(),
size: data.len(),
}
.map(|(j, s)| Shard {
k: k as u32,
linear_combination: encoding_mat.get_col(j).unwrap(),
hash: hash.clone(),
bytes: s.to_vec(),
size: data.len(),
})
.collect())
}
......@@ -129,6 +119,7 @@ mod tests {
use crate::{
fec::{decode, encode, Shard},
field,
linalg::Matrix,
};
fn bytes() -> Vec<u8> {
......@@ -143,7 +134,7 @@ mod tests {
let test_case = format!("TEST | data: {} bytes, k: {}, n: {}", data.len(), k, n);
assert_eq!(
data,
decode::<E>(encode(data, k, n).unwrap()).unwrap(),
decode::<E>(encode(data, &Matrix::random(k, n)).unwrap()).unwrap(),
"{test_case}"
);
}
......@@ -161,7 +152,7 @@ mod tests {
}
fn decoding_with_recoding_template<E: Pairing>(data: &[u8], k: usize, n: usize) {
let mut shards = encode(data, k, n).unwrap();
let mut shards = encode(data, &Matrix::random(k, n)).unwrap();
shards[1] = shards[2].combine(to_curve::<E>(7), &shards[4], to_curve::<E>(6));
shards[2] = shards[1].combine(to_curve::<E>(5), &shards[3], to_curve::<E>(4));
assert_eq!(
......
......@@ -11,11 +11,13 @@ use tracing::{debug, info};
mod error;
pub mod fec;
mod field;
mod linalg;
pub mod linalg;
pub mod setup;
use error::KomodoError;
use crate::linalg::Matrix;
#[derive(Debug, Default, Clone, PartialEq, CanonicalSerialize, CanonicalDeserialize)]
pub struct Block<E: Pairing> {
pub shard: fec::Shard<E>,
......@@ -97,8 +99,7 @@ where
pub fn encode<E, P>(
bytes: &[u8],
k: usize,
n: usize,
encoding_mat: &Matrix<E::ScalarField>,
powers: &Powers<E>,
) -> Result<Vec<Block<E>>, ark_poly_commit::Error>
where
......@@ -108,6 +109,8 @@ where
{
info!("encoding and proving {} bytes", bytes.len());
let k = encoding_mat.height;
debug!("splitting bytes into polynomials");
let elements = field::split_data_into_field_elements::<E>(bytes, k);
let polynomials = elements
......@@ -128,7 +131,7 @@ where
debug!("committing the polynomials");
let (commits, _) = commit(powers, &polynomials_to_commit)?;
Ok(fec::encode(bytes, k, n)
Ok(fec::encode(bytes, encoding_mat)
.unwrap() // TODO: don't unwrap here
.iter()
.map(|s| Block {
......@@ -237,6 +240,7 @@ mod tests {
use crate::{
batch_verify, encode,
fec::{decode, Shard},
linalg::Matrix,
recode, setup, verify, Block,
};
......@@ -248,8 +252,7 @@ mod tests {
fn verify_template<E, P>(
bytes: &[u8],
k: usize,
n: usize,
encoding_mat: &Matrix<E::ScalarField>,
batch: &[usize],
) -> Result<(), ark_poly_commit::Error>
where
......@@ -258,7 +261,7 @@ mod tests {
for<'a, 'b> &'a P: Div<&'b P, Output = P>,
{
let powers = setup::random(bytes.len())?;
let blocks = encode::<E, P>(bytes, k, n, &powers)?;
let blocks = encode::<E, P>(bytes, encoding_mat, &powers)?;
for block in &blocks {
assert!(verify::<E, P>(block, &powers)?);
......@@ -289,14 +292,16 @@ mod tests {
let batch = [1, 2, 3];
let bytes = bytes();
let encoding_mat = Matrix::random(k, n);
let test_case = format!("TEST | data: {} bytes, k: {}, n: {}", bytes.len(), k, n);
verify_template::<E, P>(&bytes, k, n, &batch)
verify_template::<E, P>(&bytes, &encoding_mat, &batch)
.unwrap_or_else(|_| panic!("verification failed for bls12-381\n{test_case}"));
verify_template::<E, P>(&bytes[0..(bytes.len() - 10)], k, n, &batch).unwrap_or_else(|_| {
panic!("verification failed for bls12-381 with padding\n{test_case}")
});
verify_template::<E, P>(&bytes[0..(bytes.len() - 10)], &encoding_mat, &batch)
.unwrap_or_else(|_| {
panic!("verification failed for bls12-381 with padding\n{test_case}")
});
}
#[ignore = "Semi-AVID-PR does not support large padding"]
......@@ -309,26 +314,29 @@ mod tests {
let batch = [1, 2, 3];
let bytes = bytes();
let encoding_mat = Matrix::random(k, n);
let test_case = format!("TEST | data: {} bytes, k: {}, n: {}", bytes.len(), k, n);
verify_template::<E, P>(&bytes[0..(bytes.len() - 33)], k, n, &batch).unwrap_or_else(|_| {
panic!("verification failed for bls12-381 with padding\n{test_case}")
});
verify_template::<E, P>(&bytes[0..(bytes.len() - 33)], &encoding_mat, &batch)
.unwrap_or_else(|_| {
panic!("verification failed for bls12-381 with padding\n{test_case}")
});
}
fn verify_with_errors_template<E, P>(
bytes: &[u8],
k: usize,
n: usize,
encoding_mat: &Matrix<E::ScalarField>,
) -> Result<(), ark_poly_commit::Error>
where
E: Pairing,
P: DenseUVPolynomial<E::ScalarField, Point = E::ScalarField>,
for<'a, 'b> &'a P: Div<&'b P, Output = P>,
{
let k = encoding_mat.height;
let powers = setup::random(bytes.len())?;
let blocks = encode::<E, P>(bytes, k, n, &powers)?;
let blocks = encode::<E, P>(bytes, encoding_mat, &powers)?;
for block in &blocks {
assert!(verify::<E, P>(block, &powers)?);
......@@ -368,20 +376,21 @@ mod tests {
let (k, n) = (4, 6);
let bytes = bytes();
let encoding_mat = Matrix::random(k, n);
let test_case = format!("TEST | data: {} bytes, k: {}, n: {}", bytes.len(), k, n);
verify_with_errors_template::<E, P>(&bytes, k, n)
verify_with_errors_template::<E, P>(&bytes, &encoding_mat)
.unwrap_or_else(|_| panic!("verification failed for bls12-381\n{test_case}"));
verify_with_errors_template::<E, P>(&bytes[0..(bytes.len() - 10)], k, n).unwrap_or_else(
|_| panic!("verification failed for bls12-381 with padding\n{test_case}"),
);
verify_with_errors_template::<E, P>(&bytes[0..(bytes.len() - 10)], &encoding_mat)
.unwrap_or_else(|_| {
panic!("verification failed for bls12-381 with padding\n{test_case}")
});
}
fn verify_recoding_template<E, P>(
bytes: &[u8],
k: usize,
n: usize,
encoding_mat: &Matrix<E::ScalarField>,
) -> Result<(), ark_poly_commit::Error>
where
E: Pairing,
......@@ -389,7 +398,7 @@ mod tests {
for<'a, 'b> &'a P: Div<&'b P, Output = P>,
{
let powers = setup::random(bytes.len())?;
let blocks = encode::<E, P>(bytes, k, n, &powers)?;
let blocks = encode::<E, P>(bytes, encoding_mat, &powers)?;
assert!(verify::<E, P>(
&recode(&blocks[2], &blocks[3]).unwrap(),
......@@ -411,20 +420,21 @@ mod tests {
let (k, n) = (4, 6);
let bytes = bytes();
let encoding_mat = Matrix::random(k, n);
let test_case = format!("TEST | data: {} bytes, k: {}, n: {}", bytes.len(), k, n);
verify_recoding_template::<E, P>(&bytes, k, n)
verify_recoding_template::<E, P>(&bytes, &encoding_mat)
.unwrap_or_else(|_| panic!("verification failed for bls12-381\n{test_case}"));
verify_recoding_template::<E, P>(&bytes[0..(bytes.len() - 10)], k, n).unwrap_or_else(
|_| panic!("verification failed for bls12-381 with padding\n{test_case}"),
);
verify_recoding_template::<E, P>(&bytes[0..(bytes.len() - 10)], &encoding_mat)
.unwrap_or_else(|_| {
panic!("verification failed for bls12-381 with padding\n{test_case}")
});
}
fn end_to_end_template<E, P>(
bytes: &[u8],
k: usize,
n: usize,
encoding_mat: &Matrix<E::ScalarField>,
) -> Result<(), ark_poly_commit::Error>
where
E: Pairing,
......@@ -432,7 +442,7 @@ mod tests {
for<'a, 'b> &'a P: Div<&'b P, Output = P>,
{
let powers = setup::random(bytes.len())?;
let blocks: Vec<Shard<E>> = encode::<E, P>(bytes, k, n, &powers)?
let blocks: Vec<Shard<E>> = encode::<E, P>(bytes, encoding_mat, &powers)?
.iter()
.map(|b| b.shard.clone())
.collect();
......@@ -450,14 +460,15 @@ mod tests {
let (k, n) = (4, 6);
let bytes = bytes();
let encoding_mat = Matrix::random(k, n);
let test_case = format!("TEST | data: {} bytes, k: {}, n: {}", bytes.len(), k, n);
end_to_end_template::<E, P>(&bytes, k, n)
end_to_end_template::<E, P>(&bytes, &encoding_mat)
.unwrap_or_else(|_| panic!("end to end failed for bls12-381\n{test_case}"));
end_to_end_template::<E, P>(&bytes[0..(bytes.len() - 10)], k, n).unwrap_or_else(|_| {
panic!("end to end failed for bls12-381 with padding\n{test_case}")
});
end_to_end_template::<E, P>(&bytes[0..(bytes.len() - 10)], &encoding_mat).unwrap_or_else(
|_| panic!("end to end failed for bls12-381 with padding\n{test_case}"),
);
}
fn end_to_end_with_recoding_template<E, P>(bytes: &[u8]) -> Result<(), ark_poly_commit::Error>
......@@ -467,7 +478,7 @@ mod tests {
for<'a, 'b> &'a P: Div<&'b P, Output = P>,
{
let powers = setup::random(bytes.len())?;
let blocks = encode::<E, P>(bytes, 3, 5, &powers)?;
let blocks = encode::<E, P>(bytes, &Matrix::random(3, 5), &powers)?;
let b_0_1 = recode(&blocks[0], &blocks[1]).unwrap();
let shards = vec![
......
use ark_ff::Field;
use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
use rand::Rng;
use crate::error::KomodoError;
#[derive(Clone, PartialEq, Default, Debug)]
pub(super) struct Matrix<T: Field> {
#[derive(Clone, PartialEq, Default, Debug, CanonicalSerialize, CanonicalDeserialize)]
pub struct Matrix<T: Field> {
pub elements: Vec<T>,
pub height: usize,
width: usize,
pub width: usize,
}
impl<T: Field> Matrix<T> {
......@@ -30,15 +32,17 @@ impl<T: Field> Matrix<T> {
Self::from_diagonal(vec![T::one(); size])
}
pub(super) fn vandermonde(points: &[T], height: usize) -> Self {
pub fn vandermonde(points: &[T], height: usize) -> Self {
let width = points.len();
let mut elements = Vec::new();
elements.resize(height * width, T::zero());
for (j, pj) in points.iter().enumerate() {
let mut el = T::one();
for i in 0..height {
elements[i * width + j] = pj.pow([i as u64]);
elements[i * width + j] = el;
el *= pj;
}
}
......@@ -49,6 +53,24 @@ impl<T: Field> Matrix<T> {
}
}
pub fn random(n: usize, m: usize) -> Self {
let mut rng = rand::thread_rng();
Matrix::from_vec_vec(
(0..n)
.map(|_| {
(0..m)
.map(|_| {
let element: u128 = rng.gen();
T::from(element)
})
.collect()
})
.collect::<Vec<Vec<T>>>(),
)
.unwrap()
}
pub(super) fn from_vec_vec(matrix: Vec<Vec<T>>) -> Result<Self, KomodoError> {
let height = matrix.len();
let width = matrix[0].len();
......@@ -87,6 +109,14 @@ impl<T: Field> Matrix<T> {
self.elements[i * self.width + j] = value;
}
pub(super) fn get_col(&self, j: usize) -> Option<Vec<T>> {
if j >= self.width {
return None;
}
Some((0..self.height).map(|i| self.get(i, j)).collect())
}
// compute _row / value_
fn divide_row_by(&mut self, row: usize, value: T) {
for j in 0..self.width {
......@@ -217,18 +247,10 @@ impl<T: Field> Matrix<T> {
#[cfg(test)]
mod tests {
use ark_bls12_381::Fr;
use ark_ff::Field;
use ark_std::{One, Zero};
use rand::Rng;
use super::{KomodoError, Matrix};
fn random_field_element<T: Field>() -> T {
let mut rng = rand::thread_rng();
let element: u128 = rng.gen();
T::from(element)
}
#[test]
fn from_vec_vec() {
let actual = Matrix::from_vec_vec(vec![
......@@ -332,12 +354,7 @@ mod tests {
assert_eq!(inverse.mul(&matrix).unwrap(), Matrix::<Fr>::identity(3));
let n = 20;
let matrix = Matrix::from_vec_vec(
(0..n)
.map(|_| (0..n).map(|_| random_field_element()).collect())
.collect::<Vec<Vec<Fr>>>(),
)
.unwrap();
let matrix = Matrix::random(n, n);
let inverse = matrix.invert().unwrap();
assert_eq!(matrix.mul(&inverse).unwrap(), Matrix::<Fr>::identity(n));
assert_eq!(inverse.mul(&matrix).unwrap(), Matrix::<Fr>::identity(n));
......@@ -439,4 +456,25 @@ mod tests {
.unwrap();
assert_eq!(matrix.truncate(Some(1), Some(2)), truncated);
}
#[test]
fn get_cols() {
let matrix = Matrix::from_vec_vec(vec![
vec![Fr::from(1), Fr::from(2), Fr::from(3), Fr::from(10)],
vec![Fr::from(4), Fr::from(5), Fr::from(6), Fr::from(11)],
vec![Fr::from(7), Fr::from(8), Fr::from(9), Fr::from(12)],
])
.unwrap();
assert!(matrix.get_col(10).is_none());
assert_eq!(
matrix.get_col(0),
Some(vec![Fr::from(1), Fr::from(4), Fr::from(7)])
);
assert_eq!(
matrix.get_col(3),
Some(vec![Fr::from(10), Fr::from(11), Fr::from(12)])
);
}
}
......@@ -6,10 +6,12 @@ use std::{fs::File, path::PathBuf};
use ark_bls12_381::Bls12_381;
use ark_ec::pairing::Pairing;
use ark_ff::PrimeField;
use ark_poly::univariate::DensePolynomial;
use ark_poly::DenseUVPolynomial;
use ark_poly_commit::kzg10::Powers;
use ark_serialize::{CanonicalDeserialize, CanonicalSerialize, Compress, Validate};
use komodo::linalg::Matrix;
use komodo::recode;
use rs_merkle::algorithms::Sha256;
use rs_merkle::Hasher;
......@@ -38,6 +40,7 @@ fn parse_args() -> (
bool,
bool,
usize,
String,
Vec<String>,
) {
let bytes_path = std::env::args()
......@@ -91,7 +94,10 @@ fn parse_args() -> (
.expect("expected nb_bytes as 10th positional argument")
.parse()
.expect("could not parse nb_bytes as a usize");
let block_hashes = std::env::args().skip(11).collect::<Vec<_>>();
let encoding_method = std::env::args()
.nth(11)
.expect("expected encoding_method as 11th positional argument");
let block_hashes = std::env::args().skip(12).collect::<Vec<_>>();
(
bytes,
......@@ -104,6 +110,7 @@ fn parse_args() -> (
do_combine_blocks,
do_inspect_blocks,
nb_bytes,
encoding_method,
block_hashes,
)
}
......@@ -232,6 +239,7 @@ fn main() {
do_combine_blocks,
do_inspect_blocks,
nb_bytes,
encoding_method,
block_hashes,
) = parse_args();
......@@ -324,8 +332,24 @@ fn main() {
exit(0);
}
let encoding_mat = match encoding_method.as_str() {
"vandermonde" => {
let points: Vec<<Bls12_381 as Pairing>::ScalarField> = (0..n)
.map(|i| {
<Bls12_381 as Pairing>::ScalarField::from_le_bytes_mod_order(&i.to_le_bytes())
})
.collect();
Matrix::vandermonde(&points, k)
}
"random" => Matrix::random(k, n),
m => {
throw_error(1, &format!("invalid encoding method: {}", m));
unreachable!()
}
};
dump_blocks(
&encode::<Bls12_381, UniPoly12_381>(&bytes, k, n, &powers).unwrap_or_else(|e| {
&encode::<Bls12_381, UniPoly12_381>(&bytes, &encoding_mat, &powers).unwrap_or_else(|e| {
throw_error(1, &format!("could not encode: {}", e));
unreachable!()
}),
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment