Skip to content
Snippets Groups Projects
Commit 4273d869 authored by STEVAN Antoine's avatar STEVAN Antoine :crab:
Browse files

add a benchmark for the _trusted setup_ (dragoon/komodo!52)

as per title

## changelog
- add a `setup.rs` benchmark which measures
  - the creation of a random setup
  - the serialization of a setup
  - the deserialization of a setup
- refactor `plot.py` a bit to
  - use `argparse`
  - take `--bench` to plot either _linalg_ or _setup_ results
  - write a complete `plot_setup` function
  - add a bit of documentation here and there

## example results
![Figure_1](/uploads/ea4bddc5c0c426d0824bad55e2e2e5aa/Figure_1.png)
parent e06a9b5d
No related branches found
No related tags found
No related merge requests found
......@@ -31,3 +31,7 @@ harness = false
[[bench]]
name = "linalg"
harness = false
[[bench]]
name = "setup"
harness = false
......@@ -3,7 +3,13 @@
nushell> cargo criterion --output-format verbose --message-format json out> results.ndjson
```
## add the _trusted setup_ sizes
```shell
nushell> cargo run --example bench_setup_size out>> results.ndjson
```
## plot the results
```shell
python benches/plot.py results.ndjson
python benches/plot.py results.ndjson --bench linalg
python benches/plot.py results.ndjson --bench setup
```
......@@ -2,47 +2,28 @@ import matplotlib.pyplot as plt
import json
import sys
import os
import argparse
from typing import Any, Dict, List
NB_NS_IN_MS = 1e6
NB_BYTES_IN_KB = 1_024
# represents a full NDJSON dataset, i.e. directly generated by `cargo criterion`,
# filtered to remove invalid lines, e.g. whose `$.reason` is not
# `benchmark-complete`
Data = List[Dict[str, Any]]
def extract(data: Data, k1: str, k2: str) -> List[float]:
    """Pull `$.k1.k2` out of every line of `data`, converted from ns to ms."""
    times = []
    for line in data:
        times.append(line[k1][k2] / NB_NS_IN_MS)
    return times
def plot(data: Data, key: str, ax):
    """Draw the mean and median curves, each with its confidence band, for all
    benchmark results whose `$.id` starts with `key`, against the matrix size
    parsed from the id (e.g. "inverse 32x32" -> 32).
    """
    rows = [line for line in data if line["id"].startswith(key)]
    sizes = [int(row["id"].split(' ')[1].split('x')[0]) for row in rows]

    # one (line, band) pair per statistic, in the same order as before so the
    # legend entries keep their order and colors
    for stat, color in (("mean", "blue"), ("median", "orange")):
        estimates = extract(rows, stat, "estimate")
        upper = extract(rows, stat, "upper_bound")
        lower = extract(rows, stat, "lower_bound")
        ax.plot(sizes, estimates, label=stat, color=color)
        ax.fill_between(sizes, lower, upper, color=color, alpha=0.3, label=f"{stat} bounds")
def parse_args():
    """Return the results filename passed as first positional CLI argument.

    Prints a usage message and exits with status 1 when no argument is given.
    """
    args = sys.argv[1:]
    if not args:
        print("please give a filename as first positional argument")
        exit(1)
    return args[0]
# extract a time series from a criterion dataset, converted from ns to ms
#
# k1: namely `mean` or `median`
# k2: namely `estimate`, `upper_bound` or `lower_bound`; when `None`, `$.k1`
#     itself is taken as the raw value (used for scalar, non-criterion metrics)
def extract_time(data: Data, k1: str, k2: str) -> List[float]:
    # NOTE: the conditional must be parenthesized so the ns -> ms conversion
    # applies to BOTH branches; previously `line[k1][k2] if k2 is not None
    # else line[k1] / NB_NS_IN_MS` divided only the `else` branch, leaving
    # `mean`/`median` estimates in nanoseconds on axes labeled in ms
    return [
        (line[k1][k2] if k2 is not None else line[k1]) / NB_NS_IN_MS
        for line in data
    ]
# read a result dataset from an NDJSON file and filter out invalid lines
#
# here, invalid lines are all the lines whose `$.reason` is not equal to
# `benchmark-complete`: `cargo criterion` generates them but they are useless here.
def read_data(data_file: str) -> Data:
if not os.path.exists(data_file):
print(f"no such file: `{data_file}`")
......@@ -60,9 +41,28 @@ def read_data(data_file: str) -> Data:
return data
if __name__ == "__main__":
    # the results file is given as the first positional CLI argument
    results_file = parse_args()
    # read and filter the NDJSON dataset produced by `cargo criterion`
    data = read_data(results_file)
def plot_linalg(data: Data):
# key: the start of the `$.id` field
def plot(data: Data, key: str, ax):
filtered_data = list(filter(lambda line: line["id"].startswith(key), data))
sizes = [
int(line["id"].split(' ')[1].split('x')[0]) for line in filtered_data
]
means = extract_time(filtered_data, "mean", "estimate")
up = extract_time(filtered_data, "mean", "upper_bound")
down = extract_time(filtered_data, "mean", "lower_bound")
ax.plot(sizes, means, label="mean", color="blue")
ax.fill_between(sizes, down, up, color="blue", alpha=0.3, label="mean bounds")
medians = extract_time(filtered_data, "median", "estimate")
up = extract_time(filtered_data, "median", "upper_bound")
down = extract_time(filtered_data, "median", "lower_bound")
ax.plot(sizes, medians, label="median", color="orange")
ax.fill_between(sizes, down, up, color="orange", alpha=0.3, label="median bounds")
labels = ["transpose", "mul", "inverse"]
......@@ -77,3 +77,79 @@ if __name__ == "__main__":
ax.grid()
plt.show()
def plot_setup(data: Data):
    """Plot the _trusted setup_ benchmark results on 4 stacked subplots
    sharing the x axis (input size in kb): generation time, serialization
    time, deserialization time and serialized size.
    """
    fig, axs = plt.subplots(4, 1, sharex=True)

    # key: the start of the `$.id` field
    # i: the index where the size of the input data needs to be extracted from
    # error_bar: when True, `$.mean.{estimate,upper_bound,lower_bound}` are
    #     criterion durations and a confidence band is drawn; when False,
    #     `$.mean` is a raw byte count plotted in kb with no band
    def plot(data: Data, key: str, i: int, label: str, color: str, error_bar: bool, ax):
        filtered_data = list(filter(lambda line: line["id"].startswith(key), data))
        # x axis: the number of input bytes parsed from the id, shown in kb
        sizes = [int(line["id"].split(' ')[i]) / NB_BYTES_IN_KB for line in filtered_data]
        if error_bar:
            means = extract_time(filtered_data, "mean", "estimate")
            up = extract_time(filtered_data, "mean", "upper_bound")
            down = extract_time(filtered_data, "mean", "lower_bound")
        else:
            # `extract_time` divides raw values by NB_NS_IN_MS: multiply it
            # back to recover the plain byte count, then convert bytes to kb
            means = [x * NB_NS_IN_MS / NB_BYTES_IN_KB for x in extract_time(filtered_data, "mean", None)]
        ax.plot(sizes, means, label=label, color=color)
        if error_bar:
            # `up`/`down` are only defined in the `error_bar` branch above
            ax.fill_between(sizes, down, up, color=color, alpha=0.3, label="mean bounds")

    # setup: time to generate a random trusted setup
    plot(data, "setup/setup", 1, "mean", "orange", True, axs[0])
    axs[0].set_title("time to generate a random trusted setup")
    axs[0].set_ylabel("time (in ms)")
    axs[0].legend()
    axs[0].grid()

    # serialization: with and without compression
    # NOTE: the input size sits 3 tokens from the end of the id, hence i=-3
    plot(data, "setup/serializing with compression", -3, "mean compressed", "orange", True, axs[1])
    plot(data, "setup/serializing with no compression", -3, "mean uncompressed", "blue", True, axs[1])
    axs[1].set_title("serialization")
    axs[1].set_ylabel("time (in ms)")
    axs[1].legend()
    axs[1].grid()

    # deserialization: all 4 combinations of compression x validation
    plot(data, "setup/deserializing with no compression and no validation", -3, "mean uncompressed unvalidated", "red", True, axs[2])
    plot(data, "setup/deserializing with compression and no validation", -3, "mean compressed unvalidated", "orange", True, axs[2])
    plot(data, "setup/deserializing with no compression and validation", -3, "mean uncompressed validated", "blue", True, axs[2])
    plot(data, "setup/deserializing with compression and validation", -3, "mean compressed validated", "green", True, axs[2])
    axs[2].set_title("deserialization")
    axs[2].set_ylabel("time (in ms)")
    axs[2].legend()
    axs[2].grid()

    # serialized sizes: these lines come from the `bench_setup_size` example,
    # where `$.mean` is a byte count, hence `error_bar=False`
    plot(data, "serialized size with no compression and no validation", -3, "mean uncompressed unvalidated", "red", False, axs[3])
    plot(data, "serialized size with compression and no validation", -3, "mean compressed unvalidated", "orange", False, axs[3])
    plot(data, "serialized size with no compression and validation", -3, "mean uncompressed validated", "blue", False, axs[3])
    plot(data, "serialized size with compression and validation", -3, "mean compressed validated", "green", False, axs[3])
    axs[3].set_title("size")
    axs[3].set_xlabel("number of expected bytes (in kb)")
    axs[3].set_ylabel("size (in kb)")
    axs[3].legend()
    axs[3].grid()

    plt.show()
if __name__ == "__main__":
    # CLI: a positional results file and a mandatory benchmark selector
    cli = argparse.ArgumentParser()
    cli.add_argument("filename", type=str)
    cli.add_argument(
        "--bench", "-b", type=str, choices=["linalg", "setup"], required=True
    )
    options = cli.parse_args()

    # `choices` guarantees the key is present in the dispatch table
    plotters = {"linalg": plot_linalg, "setup": plot_setup}
    plotters[options.bench](read_data(options.filename))
use std::ops::Div;
use ark_bls12_381::Bls12_381;
use ark_ec::pairing::Pairing;
use ark_poly::univariate::DensePolynomial;
use ark_poly::DenseUVPolynomial;
use ark_poly_commit::kzg10::Powers;
use ark_serialize::{CanonicalDeserialize, CanonicalSerialize, Compress, Validate};
use criterion::{black_box, criterion_group, criterion_main, Criterion};
type UniPoly12_381 = DensePolynomial<<Bls12_381 as Pairing>::ScalarField>;
/// Benchmark, on curve `E` with polynomials `P`:
/// - the generation of a random trusted setup for `nb_bytes` of data
/// - its serialization, with and without compression
/// - its deserialization, for all combinations of compression and validation
///
/// The serialized sizes are also printed on stdout.
fn setup_template<E, P>(c: &mut Criterion, nb_bytes: usize)
where
    E: Pairing,
    P: DenseUVPolynomial<E::ScalarField, Point = E::ScalarField>,
    for<'a, 'b> &'a P: Div<&'b P, Output = P>,
{
    let mut group = c.benchmark_group("setup");

    group.bench_function(
        &format!("setup {} on {}", nb_bytes, std::any::type_name::<E>()),
        |b| b.iter(|| komodo::setup::random::<E, P>(nb_bytes).unwrap()),
    );

    // a single fixed setup reused by all the (de)serialization benchmarks
    let setup = komodo::setup::random::<E, P>(nb_bytes).unwrap();

    group.bench_function(
        &format!(
            "serializing with compression {} on {}",
            nb_bytes,
            std::any::type_name::<E>()
        ),
        |b| {
            b.iter(|| {
                let mut serialized = vec![0; setup.serialized_size(Compress::Yes)];
                setup
                    .serialize_with_mode(&mut serialized[..], Compress::Yes)
                    .unwrap();
            })
        },
    );

    group.bench_function(
        &format!(
            "serializing with no compression {} on {}",
            nb_bytes,
            std::any::type_name::<E>()
        ),
        |b| {
            b.iter(|| {
                let mut serialized = vec![0; setup.serialized_size(Compress::No)];
                setup
                    .serialize_with_mode(&mut serialized[..], Compress::No)
                    .unwrap();
            })
        },
    );

    for (compress, validate) in [
        (Compress::Yes, Validate::Yes),
        (Compress::Yes, Validate::No),
        (Compress::No, Validate::Yes),
        (Compress::No, Validate::No),
    ] {
        // serialize once outside the bench loop: only deserialization is timed
        let mut serialized = vec![0; setup.serialized_size(compress)];
        setup
            .serialize_with_mode(&mut serialized[..], compress)
            .unwrap();

        // NOTE(review): this line is not valid JSON (unbalanced brackets, no
        // closing brace) and is filtered out by `read_data` in
        // `benches/plot.py`; see `examples/bench_setup_size.rs` for the
        // well-formed NDJSON equivalent — confirm whether this output is
        // still needed.
        println!(
            r#"["id": "{} bytes serialized with {} and {} on {}", "size": {}"#,
            nb_bytes,
            match compress {
                Compress::Yes => "compression",
                Compress::No => "no compression",
            },
            match validate {
                Validate::Yes => "validation",
                Validate::No => "no validation",
            },
            std::any::type_name::<E>(),
            serialized.len(),
        );

        group.bench_function(
            &format!(
                "deserializing with {} and {} {} on {}",
                match compress {
                    Compress::Yes => "compression",
                    Compress::No => "no compression",
                },
                match validate {
                    Validate::Yes => "validation",
                    Validate::No => "no validation",
                },
                nb_bytes,
                std::any::type_name::<E>()
            ),
            |b| {
                b.iter(|| {
                    // deserialize as `Powers<E>`: the previous hard-coded
                    // `Powers::<Bls12_381>` only happened to work because the
                    // sole instantiation of this template uses `Bls12_381`,
                    // and would silently mismatch for any other curve
                    Powers::<E>::deserialize_with_mode(&serialized[..], compress, validate)
                })
            },
        );
    }

    group.finish();
}
/// Run the trusted-setup benchmark for input sizes of 1, 2, 4, 8 and 16 kb.
fn setup(c: &mut Criterion) {
    // powers of two from 1 to 16, expressed in kilobytes
    for power in 0..5u32 {
        setup_template::<Bls12_381, UniPoly12_381>(c, black_box((1usize << power) * 1024));
    }
}

criterion_group!(benches, setup);
criterion_main!(benches);
use std::ops::Div;
use ark_bls12_381::Bls12_381;
use ark_ec::pairing::Pairing;
use ark_poly::univariate::DensePolynomial;
use ark_poly::DenseUVPolynomial;
use ark_serialize::{CanonicalSerialize, Compress, Validate};
type UniPoly12_381 = DensePolynomial<<Bls12_381 as Pairing>::ScalarField>;
/// Generate a random trusted setup for `nb_bytes` of data and print, as
/// NDJSON lines on stdout, the size of its serialization for every
/// combination of compression and validation.
///
/// The output mimics the `cargo criterion` format so `benches/plot.py` can
/// ingest it alongside real benchmark results.
fn setup_template<E, P>(nb_bytes: usize)
where
    E: Pairing,
    P: DenseUVPolynomial<E::ScalarField, Point = E::ScalarField>,
    for<'a, 'b> &'a P: Div<&'b P, Output = P>,
{
    let setup = komodo::setup::random::<E, P>(nb_bytes).unwrap();

    // same (compress, validate) order as before: (Yes,Yes), (Yes,No),
    // (No,Yes), (No,No)
    for compress in [Compress::Yes, Compress::No] {
        for validate in [Validate::Yes, Validate::No] {
            let mut buffer = vec![0; setup.serialized_size(compress)];
            setup.serialize_with_mode(&mut buffer[..], compress).unwrap();

            let compression = match compress {
                Compress::Yes => "compression",
                Compress::No => "no compression",
            };
            let validation = match validate {
                Validate::Yes => "validation",
                Validate::No => "no validation",
            };

            println!(
                r#"{{"reason": "benchmark-complete", "id": "serialized size with {} and {} {} on {}", "mean": {}}}"#,
                compression,
                validation,
                nb_bytes,
                std::any::type_name::<E>(),
                buffer.len(),
            );
        }
    }
}
fn main() {
    // measure serialized setup sizes for 1, 2, 4, 8 and 16 kb of input data
    [1usize, 2, 4, 8, 16]
        .iter()
        .for_each(|&n| setup_template::<Bls12_381, UniPoly12_381>(n * 1024));
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment