From 4273d8691b5db6a71130bb76431bba5f49db561a Mon Sep 17 00:00:00 2001 From: STEVAN Antoine <antoine.stevan@isae-supaero.fr> Date: Tue, 2 Apr 2024 15:04:42 +0000 Subject: [PATCH] add a benchmark for the _trusted setup_ (dragoon/komodo!52) as per title ## changelog - add a `setup.rs` benchmark which measures - the creation of a random setup - the serialization of a setup - the deserialization of a setup - refactor `plot.py` a bit to - use `argparse` - take `--bench` to plot either _linalg_ or _setup_ results - write a complete `plot_setup` function - add a bit of documentation here and there ## example results  --- Cargo.toml | 4 + benches/README.md | 8 +- benches/plot.py | 146 ++++++++++++++++++++++++++--------- benches/setup.rs | 119 ++++++++++++++++++++++++++++ examples/bench_setup_size.rs | 52 +++++++++++++ 5 files changed, 293 insertions(+), 36 deletions(-) create mode 100644 benches/setup.rs create mode 100644 examples/bench_setup_size.rs diff --git a/Cargo.toml b/Cargo.toml index df807bf1..af1c6f88 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -31,3 +31,7 @@ harness = false [[bench]] name = "linalg" harness = false + +[[bench]] +name = "setup" +harness = false diff --git a/benches/README.md b/benches/README.md index 9867ccd1..d085dee9 100644 --- a/benches/README.md +++ b/benches/README.md @@ -3,7 +3,13 @@ nushell> cargo criterion --output-format verbose --message-format json out> results.ndjson ``` +## add the _trusted setup_ sizes +```shell +nushell> cargo run --example bench_setup_size out>> results.ndjson +``` + ## plot the results ```shell -python benches/plot.py results.ndjson +python benches/plot.py results.ndjson --bench linalg +python benches/plot.py results.ndjson --bench setup ``` diff --git a/benches/plot.py b/benches/plot.py index 08b17c60..ff967d67 100644 --- a/benches/plot.py +++ b/benches/plot.py @@ -2,47 +2,28 @@ import matplotlib.pyplot as plt import json import sys import os +import argparse from typing import Any, Dict, List NB_NS_IN_MS = 1e6 +NB_BYTES_IN_KB = 1_024 +# represents a full NDJSON dataset, i.e. directly generated by `cargo criterion`, +# filtered to remove invalid lines, e.g. whose `$.reason` is not +# `benchmark-complete` Data = List[Dict[str, Any]] -def extract(data: Data, k1: str, k2: str) -> List[float]: - return [line[k1][k2] / NB_NS_IN_MS for line in data] - - -def plot(data: Data, key: str, ax): - filtered_data = list(filter(lambda line: line["id"].startswith(key), data)) - - sizes = [ - int(line["id"].split(' ')[1].split('x')[0]) for line in filtered_data - ] - - means = extract(filtered_data, "mean", "estimate") - up = extract(filtered_data, "mean", "upper_bound") - down = extract(filtered_data, "mean", "lower_bound") - - ax.plot(sizes, means, label="mean", color="blue") - ax.fill_between(sizes, down, up, color="blue", alpha=0.3, label="mean bounds") - - medians = extract(filtered_data, "median", "estimate") - up = extract(filtered_data, "median", "upper_bound") - down = extract(filtered_data, "median", "lower_bound") - - ax.plot(sizes, medians, label="median", color="orange") - ax.fill_between(sizes, down, up, color="orange", alpha=0.3, label="median bounds") - - -def parse_args(): - if len(sys.argv) == 1: - print("please give a filename as first positional argument") - exit(1) - - return sys.argv[1] +# k1: namely `mean` or `median` +# k2: namely `estimation`, `upper_bound` or `lower_bound` +def extract_time(data: Data, k1: str, k2: str) -> List[float]: + return [line[k1][k2] if k2 is not None else line[k1] / NB_NS_IN_MS for line in data] +# read a result dataset from an NDJSON file and filter out invalid lines +# +# here, invalid lines are all the lines with `$.reason` not equal to +# `benchmark-complete` that are generated by `cargo criterion` but useless. def read_data(data_file: str) -> Data: if not os.path.exists(data_file): print(f"no such file: `{data_file}`") @@ -60,9 +41,28 @@ def read_data(data_file: str) -> Data: return data -if __name__ == "__main__": - results_file = parse_args() - data = read_data(results_file) +def plot_linalg(data: Data): + # key: the start of the `$.id` field + def plot(data: Data, key: str, ax): + filtered_data = list(filter(lambda line: line["id"].startswith(key), data)) + + sizes = [ + int(line["id"].split(' ')[1].split('x')[0]) for line in filtered_data + ] + + means = extract_time(filtered_data, "mean", "estimate") + up = extract_time(filtered_data, "mean", "upper_bound") + down = extract_time(filtered_data, "mean", "lower_bound") + + ax.plot(sizes, means, label="mean", color="blue") + ax.fill_between(sizes, down, up, color="blue", alpha=0.3, label="mean bounds") + + medians = extract_time(filtered_data, "median", "estimate") + up = extract_time(filtered_data, "median", "upper_bound") + down = extract_time(filtered_data, "median", "lower_bound") + + ax.plot(sizes, medians, label="median", color="orange") + ax.fill_between(sizes, down, up, color="orange", alpha=0.3, label="median bounds") labels = ["transpose", "mul", "inverse"] @@ -77,3 +77,79 @@ if __name__ == "__main__": ax.grid() plt.show() + + +def plot_setup(data: Data): + fig, axs = plt.subplots(4, 1, sharex=True) + + # key: the start of the `$.id` field + # i: the index where the size of the input data needs to be extracted from + def plot(data: Data, key: str, i: int, label: str, color: str, error_bar: bool, ax): + filtered_data = list(filter(lambda line: line["id"].startswith(key), data)) + sizes = [int(line["id"].split(' ')[i]) / NB_BYTES_IN_KB for line in filtered_data] + + if error_bar: + means = extract_time(filtered_data, "mean", "estimate") + up = extract_time(filtered_data, "mean", "upper_bound") + down = extract_time(filtered_data, "mean", "lower_bound") + else: + means = [x * NB_NS_IN_MS / NB_BYTES_IN_KB for x in extract_time(filtered_data, "mean", None)] + + ax.plot(sizes, means, label=label, color=color) + + if error_bar: + ax.fill_between(sizes, down, up, color=color, alpha=0.3, label="mean bounds") + + # setup + plot(data, "setup/setup", 1, "mean", "orange", True, axs[0]) + axs[0].set_title("time to generate a random trusted setup") + axs[0].set_ylabel("time (in ms)") + axs[0].legend() + axs[0].grid() + + # serialization + plot(data, "setup/serializing with compression", -3, "mean compressed", "orange", True, axs[1]) + plot(data, "setup/serializing with no compression", -3, "mean uncompressed", "blue", True, axs[1]) + axs[1].set_title("serialization") + axs[1].set_ylabel("time (in ms)") + axs[1].legend() + axs[1].grid() + + # deserialization + plot(data, "setup/deserializing with no compression and no validation", -3, "mean uncompressed unvalidated", "red", True, axs[2]) + plot(data, "setup/deserializing with compression and no validation", -3, "mean compressed unvalidated", "orange", True, axs[2]) + plot(data, "setup/deserializing with no compression and validation", -3, "mean uncompressed validated", "blue", True, axs[2]) + plot(data, "setup/deserializing with compression and validation", -3, "mean compressed validated", "green", True, axs[2]) + axs[2].set_title("deserialization") + axs[2].set_ylabel("time (in ms)") + axs[2].legend() + axs[2].grid() + + plot(data, "serialized size with no compression and no validation", -3, "mean uncompressed unvalidated", "red", False, axs[3]) + plot(data, "serialized size with compression and no validation", -3, "mean compressed unvalidated", "orange", False, axs[3]) + plot(data, "serialized size with no compression and validation", -3, "mean uncompressed validated", "blue", False, axs[3]) + plot(data, "serialized size with compression and validation", -3, "mean compressed validated", "green", False, axs[3]) + axs[3].set_title("size") + axs[3].set_xlabel("number of expected bytes (in kb)") + axs[3].set_ylabel("size (in kb)") + axs[3].legend() + axs[3].grid() + + plt.show() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("filename", type=str) + parser.add_argument( + "--bench", "-b", type=str, choices=["linalg", "setup"], required=True + ) + args = parser.parse_args() + + data = read_data(args.filename) + + match args.bench: + case "linalg": + plot_linalg(data) + case "setup": + plot_setup(data) diff --git a/benches/setup.rs b/benches/setup.rs new file mode 100644 index 00000000..3df93d88 --- /dev/null +++ b/benches/setup.rs @@ -0,0 +1,119 @@ +use std::ops::Div; + +use ark_bls12_381::Bls12_381; +use ark_ec::pairing::Pairing; +use ark_poly::univariate::DensePolynomial; +use ark_poly::DenseUVPolynomial; + +use ark_poly_commit::kzg10::Powers; +use ark_serialize::{CanonicalDeserialize, CanonicalSerialize, Compress, Validate}; +use criterion::{black_box, criterion_group, criterion_main, Criterion}; + +type UniPoly12_381 = DensePolynomial<<Bls12_381 as Pairing>::ScalarField>; + +fn setup_template<E, P>(c: &mut Criterion, nb_bytes: usize) +where + E: Pairing, + P: DenseUVPolynomial<E::ScalarField, Point = E::ScalarField>, + for<'a, 'b> &'a P: Div<&'b P, Output = P>, +{ + let mut group = c.benchmark_group("setup"); + + group.bench_function( + &format!("setup {} on {}", nb_bytes, std::any::type_name::<E>()), + |b| b.iter(|| komodo::setup::random::<E, P>(nb_bytes).unwrap()), + ); + + let setup = komodo::setup::random::<E, P>(nb_bytes).unwrap(); + + group.bench_function( + &format!( + "serializing with compression {} on {}", + nb_bytes, + std::any::type_name::<E>() + ), + |b| { + b.iter(|| { + let mut serialized = vec![0; setup.serialized_size(Compress::Yes)]; + setup + .serialize_with_mode(&mut serialized[..], Compress::Yes) + .unwrap(); + }) + }, + ); + + group.bench_function( + &format!( + "serializing with no compression {} on {}", + nb_bytes, + std::any::type_name::<E>() + ), + |b| { + b.iter(|| { + let mut serialized = vec![0; setup.serialized_size(Compress::No)]; + setup + .serialize_with_mode(&mut serialized[..], Compress::No) + .unwrap(); + }) + }, + ); + + for (compress, validate) in [ + (Compress::Yes, Validate::Yes), + (Compress::Yes, Validate::No), + (Compress::No, Validate::Yes), + (Compress::No, Validate::No), + ] { + let mut serialized = vec![0; setup.serialized_size(compress)]; + setup + .serialize_with_mode(&mut serialized[..], compress) + .unwrap(); + + println!( + r#"["id": "{} bytes serialized with {} and {} on {}", "size": {}"#, + nb_bytes, + match compress { + Compress::Yes => "compression", + Compress::No => "no compression", + }, + match validate { + Validate::Yes => "validation", + Validate::No => "no validation", + }, + std::any::type_name::<E>(), + serialized.len(), + ); + + group.bench_function( + &format!( + "deserializing with {} and {} {} on {}", + match compress { + Compress::Yes => "compression", + Compress::No => "no compression", + }, + match validate { + Validate::Yes => "validation", + Validate::No => "no validation", + }, + nb_bytes, + std::any::type_name::<E>() + ), + |b| { + b.iter(|| { + Powers::<Bls12_381>::deserialize_with_mode(&serialized[..], compress, validate) + }) + }, + ); + } + + group.finish(); +} + +fn setup(c: &mut Criterion) { + for n in [1, 2, 4, 8, 16] { + setup_template::<Bls12_381, UniPoly12_381>(c, black_box(n * 1024)); + } +} + +criterion_group!(benches, setup); +criterion_main!(benches); diff --git a/examples/bench_setup_size.rs b/examples/bench_setup_size.rs new file mode 100644 index 00000000..1c9321bd --- /dev/null +++ b/examples/bench_setup_size.rs @@ -0,0 +1,52 @@ +use std::ops::Div; + +use ark_bls12_381::Bls12_381; +use ark_ec::pairing::Pairing; +use ark_poly::univariate::DensePolynomial; +use ark_poly::DenseUVPolynomial; + +use ark_serialize::{CanonicalSerialize, Compress, Validate}; + +type UniPoly12_381 = DensePolynomial<<Bls12_381 as Pairing>::ScalarField>; + +fn setup_template<E, P>(nb_bytes: usize) +where + E: Pairing, + P: DenseUVPolynomial<E::ScalarField, Point = E::ScalarField>, + for<'a, 'b> &'a P: Div<&'b P, Output = P>, +{ + let setup = komodo::setup::random::<E, P>(nb_bytes).unwrap(); + + for (compress, validate) in [ + (Compress::Yes, Validate::Yes), + (Compress::Yes, Validate::No), + (Compress::No, Validate::Yes), + (Compress::No, Validate::No), + ] { + let mut serialized = vec![0; setup.serialized_size(compress)]; + setup + .serialize_with_mode(&mut serialized[..], compress) + .unwrap(); + + println!( + r#"{{"reason": "benchmark-complete", "id": "serialized size with {} and {} {} on {}", "mean": {}}}"#, + match compress { + Compress::Yes => "compression", + Compress::No => "no compression", + }, + match validate { + Validate::Yes => "validation", + Validate::No => "no validation", + }, + nb_bytes, + std::any::type_name::<E>(), + serialized.len(), + ); + } +} + +fn main() { + for n in [1, 2, 4, 8, 16] { + setup_template::<Bls12_381, UniPoly12_381>(n * 1024); + } +} -- GitLab