From 4273d8691b5db6a71130bb76431bba5f49db561a Mon Sep 17 00:00:00 2001
From: STEVAN Antoine <antoine.stevan@isae-supaero.fr>
Date: Tue, 2 Apr 2024 15:04:42 +0000
Subject: [PATCH] add a benchmark for the _trusted setup_ (dragoon/komodo!52)

as per title

## changelog
- add a `setup.rs` benchmark which measures
  - the creation of a random setup
  - the serialization of a setup
  - the deserialization of a setup
- refactor `plot.py` a bit to
  - use `argparse`
  - take `--bench` to plot either _linalg_ or _setup_ results
  - write a complete `plot_setup` function
  - add a bit of documentation here and there

## example results
![Figure_1](/uploads/ea4bddc5c0c426d0824bad55e2e2e5aa/Figure_1.png)
---
 Cargo.toml                   |   4 +
 benches/README.md            |   8 +-
 benches/plot.py              | 146 ++++++++++++++++++++++++++---------
 benches/setup.rs             | 119 ++++++++++++++++++++++++++++
 examples/bench_setup_size.rs |  52 +++++++++++++
 5 files changed, 293 insertions(+), 36 deletions(-)
 create mode 100644 benches/setup.rs
 create mode 100644 examples/bench_setup_size.rs

diff --git a/Cargo.toml b/Cargo.toml
index df807bf1..af1c6f88 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -31,3 +31,7 @@ harness = false
 [[bench]]
 name = "linalg"
 harness = false
+
+[[bench]]
+name = "setup"
+harness = false
diff --git a/benches/README.md b/benches/README.md
index 9867ccd1..d085dee9 100644
--- a/benches/README.md
+++ b/benches/README.md
@@ -3,7 +3,13 @@
 nushell> cargo criterion --output-format verbose --message-format json out> results.ndjson
 ```
 
+## add the _trusted setup_ sizes
+```shell
+nushell> cargo run --example bench_setup_size out>> results.ndjson
+```
+
 ## plot the results
 ```shell
-python benches/plot.py results.ndjson
+python benches/plot.py results.ndjson --bench linalg
+python benches/plot.py results.ndjson --bench setup
 ```
diff --git a/benches/plot.py b/benches/plot.py
index 08b17c60..ff967d67 100644
--- a/benches/plot.py
+++ b/benches/plot.py
@@ -2,47 +2,28 @@ import matplotlib.pyplot as plt
 import json
 import sys
 import os
+import argparse
 from typing import Any, Dict, List
 
 NB_NS_IN_MS = 1e6
+NB_BYTES_IN_KB = 1_024
 
+# represents a full NDJSON dataset, i.e. directly generated by `cargo criterion`,
+# filtered to remove invalid lines, e.g. whose `$.reason` is not
+# `benchmark-complete`
 Data = List[Dict[str, Any]]
 
 
-def extract(data: Data, k1: str, k2: str) -> List[float]:
-    return [line[k1][k2] / NB_NS_IN_MS for line in data]
-
-
-def plot(data: Data, key: str, ax):
-    filtered_data = list(filter(lambda line: line["id"].startswith(key), data))
-
-    sizes = [
-        int(line["id"].split(' ')[1].split('x')[0]) for line in filtered_data
-    ]
-
-    means = extract(filtered_data, "mean", "estimate")
-    up = extract(filtered_data, "mean", "upper_bound")
-    down = extract(filtered_data, "mean", "lower_bound")
-
-    ax.plot(sizes, means, label="mean", color="blue")
-    ax.fill_between(sizes, down, up, color="blue", alpha=0.3, label="mean bounds")
-
-    medians = extract(filtered_data, "median", "estimate")
-    up = extract(filtered_data, "median", "upper_bound")
-    down = extract(filtered_data, "median", "lower_bound")
-
-    ax.plot(sizes, medians, label="median", color="orange")
-    ax.fill_between(sizes, down, up, color="orange", alpha=0.3, label="median bounds")
-
-
-def parse_args():
-    if len(sys.argv) == 1:
-        print("please give a filename as first positional argument")
-        exit(1)
-
-    return sys.argv[1]
+# k1: namely `mean` or `median`; the value read is `line[k1][k2]`, or `line[k1]` itself when `k2` is None
+# k2: namely `estimate`, `upper_bound`, `lower_bound` or None; every value is divided by NB_NS_IN_MS (ns -> ms)
+def extract_time(data: Data, k1: str, k2: str) -> List[float]:
+    return [(line[k1][k2] if k2 is not None else line[k1]) / NB_NS_IN_MS for line in data]
 
 
+# read a result dataset from an NDJSON file and filter out invalid lines
+#
+# here, invalid lines are all the lines with `$.reason` not equal to
+# `benchmark-complete` that are generated by `cargo criterion` but useless.
 def read_data(data_file: str) -> Data:
     if not os.path.exists(data_file):
         print(f"no such file: `{data_file}`")
@@ -60,9 +41,28 @@ def read_data(data_file: str) -> Data:
     return data
 
 
-if __name__ == "__main__":
-    results_file = parse_args()
-    data = read_data(results_file)
+def plot_linalg(data: Data):
+    # key: the start of the `$.id` field
+    def plot(data: Data, key: str, ax):
+        filtered_data = list(filter(lambda line: line["id"].startswith(key), data))
+
+        sizes = [
+            int(line["id"].split(' ')[1].split('x')[0]) for line in filtered_data
+        ]
+
+        means = extract_time(filtered_data, "mean", "estimate")
+        up = extract_time(filtered_data, "mean", "upper_bound")
+        down = extract_time(filtered_data, "mean", "lower_bound")
+
+        ax.plot(sizes, means, label="mean", color="blue")
+        ax.fill_between(sizes, down, up, color="blue", alpha=0.3, label="mean bounds")
+
+        medians = extract_time(filtered_data, "median", "estimate")
+        up = extract_time(filtered_data, "median", "upper_bound")
+        down = extract_time(filtered_data, "median", "lower_bound")
+
+        ax.plot(sizes, medians, label="median", color="orange")
+        ax.fill_between(sizes, down, up, color="orange", alpha=0.3, label="median bounds")
 
     labels = ["transpose", "mul", "inverse"]
 
@@ -77,3 +77,79 @@ if __name__ == "__main__":
         ax.grid()
 
     plt.show()
+
+
+def plot_setup(data: Data):
+    fig, axs = plt.subplots(4, 1, sharex=True)
+
+    # key: the start of the `$.id` field
+    # i: the index, in the space-separated `$.id`, of the number of bytes (plotted in kb on the x-axis)
+    def plot(data: Data, key: str, i: int, label: str, color: str, error_bar: bool, ax):
+        filtered_data = list(filter(lambda line: line["id"].startswith(key), data))
+        sizes = [int(line["id"].split(' ')[i]) / NB_BYTES_IN_KB for line in filtered_data]
+        # with `error_bar`: mean time and its bounds; otherwise `$.mean` is a plain number (a size in bytes for the callers below), converted to kb
+        if error_bar:
+            means = extract_time(filtered_data, "mean", "estimate")
+            up = extract_time(filtered_data, "mean", "upper_bound")
+            down = extract_time(filtered_data, "mean", "lower_bound")
+        else:
+            means = [x * NB_NS_IN_MS / NB_BYTES_IN_KB for x in extract_time(filtered_data, "mean", None)]
+
+        ax.plot(sizes, means, label=label, color=color)
+
+        if error_bar:
+            ax.fill_between(sizes, down, up, color=color, alpha=0.3, label="mean bounds")
+
+    # setup
+    plot(data, "setup/setup", 1, "mean", "orange", True, axs[0])
+    axs[0].set_title("time to generate a random trusted setup")
+    axs[0].set_ylabel("time (in ms)")
+    axs[0].legend()
+    axs[0].grid()
+
+    # serialization
+    plot(data, "setup/serializing with compression", -3, "mean compressed", "orange", True, axs[1])
+    plot(data, "setup/serializing with no compression", -3, "mean uncompressed", "blue", True, axs[1])
+    axs[1].set_title("serialization")
+    axs[1].set_ylabel("time (in ms)")
+    axs[1].legend()
+    axs[1].grid()
+
+    # deserialization
+    plot(data, "setup/deserializing with no compression and no validation", -3, "mean uncompressed unvalidated", "red", True, axs[2])
+    plot(data, "setup/deserializing with compression and no validation", -3, "mean compressed unvalidated", "orange", True, axs[2])
+    plot(data, "setup/deserializing with no compression and validation", -3, "mean uncompressed validated", "blue", True, axs[2])
+    plot(data, "setup/deserializing with compression and validation", -3, "mean compressed validated", "green", True, axs[2])
+    axs[2].set_title("deserialization")
+    axs[2].set_ylabel("time (in ms)")
+    axs[2].legend()
+    axs[2].grid()
+    # serialized size: ids starting with `serialized size` (as printed by `examples/bench_setup_size.rs`), no error bars
+    plot(data, "serialized size with no compression and no validation", -3, "mean uncompressed unvalidated", "red", False, axs[3])
+    plot(data, "serialized size with compression and no validation", -3, "mean compressed unvalidated", "orange", False, axs[3])
+    plot(data, "serialized size with no compression and validation", -3, "mean uncompressed validated", "blue", False, axs[3])
+    plot(data, "serialized size with compression and validation", -3, "mean compressed validated", "green", False, axs[3])
+    axs[3].set_title("size")
+    axs[3].set_xlabel("number of expected bytes (in kb)")
+    axs[3].set_ylabel("size (in kb)")
+    axs[3].legend()
+    axs[3].grid()
+
+    plt.show()
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("filename", type=str)
+    parser.add_argument(
+        "--bench", "-b", type=str, choices=["linalg", "setup"], required=True
+    )
+    args = parser.parse_args()
+
+    data = read_data(args.filename)
+    # NOTE: the `match` statement below requires Python >= 3.10
+    match args.bench:
+        case "linalg":
+            plot_linalg(data)
+        case "setup":
+            plot_setup(data)
diff --git a/benches/setup.rs b/benches/setup.rs
new file mode 100644
index 00000000..3df93d88
--- /dev/null
+++ b/benches/setup.rs
@@ -0,0 +1,119 @@
+use std::ops::Div;
+
+use ark_bls12_381::Bls12_381;
+use ark_ec::pairing::Pairing;
+use ark_poly::univariate::DensePolynomial;
+use ark_poly::DenseUVPolynomial;
+
+use ark_poly_commit::kzg10::Powers;
+use ark_serialize::{CanonicalDeserialize, CanonicalSerialize, Compress, Validate};
+use criterion::{black_box, criterion_group, criterion_main, Criterion};
+
+type UniPoly12_381 = DensePolynomial<<Bls12_381 as Pairing>::ScalarField>;
+
+fn setup_template<E, P>(c: &mut Criterion, nb_bytes: usize)
+where
+    E: Pairing,
+    P: DenseUVPolynomial<E::ScalarField, Point = E::ScalarField>,
+    for<'a, 'b> &'a P: Div<&'b P, Output = P>,
+{
+    let mut group = c.benchmark_group("setup");
+
+    group.bench_function(
+        &format!("setup {} on {}", nb_bytes, std::any::type_name::<E>()),
+        |b| b.iter(|| komodo::setup::random::<E, P>(nb_bytes).unwrap()),
+    );
+    // a reference setup, reused as input by all the (de)serialization benchmarks below
+    let setup = komodo::setup::random::<E, P>(nb_bytes).unwrap();
+
+    group.bench_function(
+        &format!(
+            "serializing with compression {} on {}",
+            nb_bytes,
+            std::any::type_name::<E>()
+        ),
+        |b| {
+            b.iter(|| {
+                let mut serialized = vec![0; setup.serialized_size(Compress::Yes)];
+                setup
+                    .serialize_with_mode(&mut serialized[..], Compress::Yes)
+                    .unwrap();
+            })
+        },
+    );
+
+    group.bench_function(
+        &format!(
+            "serializing with no compression {} on {}",
+            nb_bytes,
+            std::any::type_name::<E>()
+        ),
+        |b| {
+            b.iter(|| {
+                let mut serialized = vec![0; setup.serialized_size(Compress::No)];
+                setup
+                    .serialize_with_mode(&mut serialized[..], Compress::No)
+                    .unwrap();
+            })
+        },
+    );
+    // `validate` is only passed to the deserialization: the serialized bytes depend on `compress` alone
+    for (compress, validate) in [
+        (Compress::Yes, Validate::Yes),
+        (Compress::Yes, Validate::No),
+        (Compress::No, Validate::Yes),
+        (Compress::No, Validate::No),
+    ] {
+        let mut serialized = vec![0; setup.serialized_size(compress)];
+        setup
+            .serialize_with_mode(&mut serialized[..], compress)
+            .unwrap();
+        // report the serialized size on stdout as a well-formed JSON object (balanced braces)
+        println!(
+            r#"{{"id": "{} bytes serialized with {} and {} on {}", "size": {}}}"#,
+            nb_bytes,
+            match compress {
+                Compress::Yes => "compression",
+                Compress::No => "no compression",
+            },
+            match validate {
+                Validate::Yes => "validation",
+                Validate::No => "no validation",
+            },
+            std::any::type_name::<E>(),
+            serialized.len(),
+        );
+
+        group.bench_function(
+            &format!(
+                "deserializing with {} and {} {} on {}",
+                match compress {
+                    Compress::Yes => "compression",
+                    Compress::No => "no compression",
+                },
+                match validate {
+                    Validate::Yes => "validation",
+                    Validate::No => "no validation",
+                },
+                nb_bytes,
+                std::any::type_name::<E>()
+            ),
+            |b| {
+                b.iter(|| {
+                    Powers::<E>::deserialize_with_mode(&serialized[..], compress, validate)
+                })
+            },
+        );
+    }
+
+    group.finish();
+}
+
+fn setup(c: &mut Criterion) {
+    for n in [1, 2, 4, 8, 16] {
+        setup_template::<Bls12_381, UniPoly12_381>(c, black_box(n * 1024)); // `nb_bytes` = n * 1024
+    }
+}
+
+criterion_group!(benches, setup);
+criterion_main!(benches);
diff --git a/examples/bench_setup_size.rs b/examples/bench_setup_size.rs
new file mode 100644
index 00000000..1c9321bd
--- /dev/null
+++ b/examples/bench_setup_size.rs
@@ -0,0 +1,52 @@
+use std::ops::Div;
+
+use ark_bls12_381::Bls12_381;
+use ark_ec::pairing::Pairing;
+use ark_poly::univariate::DensePolynomial;
+use ark_poly::DenseUVPolynomial;
+
+use ark_serialize::{CanonicalSerialize, Compress, Validate};
+
+type UniPoly12_381 = DensePolynomial<<Bls12_381 as Pairing>::ScalarField>;
+
+fn setup_template<E, P>(nb_bytes: usize)
+where
+    E: Pairing,
+    P: DenseUVPolynomial<E::ScalarField, Point = E::ScalarField>,
+    for<'a, 'b> &'a P: Div<&'b P, Output = P>,
+{
+    let setup = komodo::setup::random::<E, P>(nb_bytes).unwrap();
+    // `validate` is only echoed in the printed id: the serialized bytes depend on `compress` alone
+    for (compress, validate) in [
+        (Compress::Yes, Validate::Yes),
+        (Compress::Yes, Validate::No),
+        (Compress::No, Validate::Yes),
+        (Compress::No, Validate::No),
+    ] {
+        let mut serialized = vec![0; setup.serialized_size(compress)];
+        setup
+            .serialize_with_mode(&mut serialized[..], compress)
+            .unwrap();
+        // one JSON line per configuration, tagged `benchmark-complete`, with the size in bytes as `mean`
+        println!(
+            r#"{{"reason": "benchmark-complete", "id": "serialized size with {} and {} {} on {}", "mean": {}}}"#,
+            match compress {
+                Compress::Yes => "compression",
+                Compress::No => "no compression",
+            },
+            match validate {
+                Validate::Yes => "validation",
+                Validate::No => "no validation",
+            },
+            nb_bytes,
+            std::any::type_name::<E>(),
+            serialized.len(),
+        );
+    }
+}
+
+fn main() {
+    for n in [1, 2, 4, 8, 16] {
+        setup_template::<Bls12_381, UniPoly12_381>(n * 1024); // `nb_bytes` = n * 1024
+    }
+}
-- 
GitLab