From bb55005fdf731d686c6287b73bc47ec2fac5fd03 Mon Sep 17 00:00:00 2001
From: STEVAN Antoine <antoine.stevan@isae-supaero.fr>
Date: Thu, 30 May 2024 07:12:49 +0000
Subject: [PATCH] support proper 32-byte RNG seeds (dragoon/komodo!126)

- add optional `$.help` to argument `err` of `error throw`
- parse `prng_seed: [u8; 32]` in `rng` and `inbreeding`
- compute the "_local_" seed by hashing the "_global_" seed, the strategy and the iteration index
- pass `--prng-seed: string`, a 64-character hexadecimal seed, to `inbreeding run`
---
 .nushell/error.nu           |  8 ++++-
 bins/inbreeding/Cargo.toml  |  1 +
 bins/inbreeding/README.md   |  2 +-
 bins/inbreeding/inspect.nu  |  2 +-
 bins/inbreeding/run.nu      | 66 ++++++++++++++++++++++++++++---------
 bins/inbreeding/src/main.rs | 21 +++++++++---
 bins/rng/Cargo.toml         |  1 +
 bins/rng/src/main.rs        | 21 +++++++++---
 8 files changed, 93 insertions(+), 29 deletions(-)

diff --git a/.nushell/error.nu b/.nushell/error.nu
index 34e30b9b..9e6a4aeb 100644
--- a/.nushell/error.nu
+++ b/.nushell/error.nu
@@ -1,9 +1,15 @@
-export def "error throw" [err: record<err: string, label: string, span: record<start: int, end: int>>] {
+export def "error throw" [err: record<
+    err: string,
+    label: string,
+    span: record<start: int, end: int>,
+    # help: string?,
+>] {
     error make {
         msg: $"(ansi red_bold)($err.err)(ansi reset)",
         label: {
             text: $err.label,
             span: $err.span,
         },
+        help: $err.help?,  # optional field: `?` tolerates an `$err` without `help`
     }
 }
diff --git a/bins/inbreeding/Cargo.toml b/bins/inbreeding/Cargo.toml
index c8982f38..5fbcc697 100644
--- a/bins/inbreeding/Cargo.toml
+++ b/bins/inbreeding/Cargo.toml
@@ -11,3 +11,4 @@ clap = { version = "4.5.4", features = ["derive"] }
 rand = "0.8.5"
 indicatif = "0.17.8"
 ark-ff = "0.4.2"
+hex = "0.4.3"
diff --git a/bins/inbreeding/README.md b/bins/inbreeding/README.md
index dc77e786..1b27f8c0 100644
--- a/bins/inbreeding/README.md
+++ b/bins/inbreeding/README.md
@@ -7,7 +7,7 @@
 use ./bins/inbreeding
 ```
 ```bash
-const PRNG_SEED = 123
+const PRNG_SEED = "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"
 const OPTS = {
     nb_bytes: (10 * 1_024),
     k: 10,
diff --git a/bins/inbreeding/inspect.nu b/bins/inbreeding/inspect.nu
index 0df2ea82..96a908a5 100644
--- a/bins/inbreeding/inspect.nu
+++ b/bins/inbreeding/inspect.nu
@@ -5,7 +5,7 @@ def get-seeds [] [ nothing -> list<string> ] {
     $consts.CACHE | path join '*' | into glob | ls $in | get name | each { path split | last }
 }
 
-export def main [seed: int@get-seeds]: [
+export def main [seed: string@get-seeds]: [
     nothing -> table<
         seed: string,
         timestamp: string,
diff --git a/bins/inbreeding/run.nu b/bins/inbreeding/run.nu
index 5b01dfce..e27cdc85 100644
--- a/bins/inbreeding/run.nu
+++ b/bins/inbreeding/run.nu
@@ -1,5 +1,42 @@
 use consts.nu
-use ../../.nushell cargo "cargo bin"
+use ../../.nushell error "error throw"
+
+const VALID_HEX_CHARS = "abcdefABCDEF0123456789"
+
+# check that the input string is made of exactly `n` hexadecimal characters
+#
+# on success, the result is `{ ok: true, err: {} }`, otherwise `ok` is `false` and
+# `err` holds the message, the label and the help of the error to report
+def check-hex [-n: int]: [
+    string -> record<
+        ok: bool,
+        err: record<msg: string, label: string, help: string>,
+    >
+] {
+    let s = $in
+    let len = $s | str length
+
+    if $len != $n {
+        return { ok: false, err: {
+            msg: "invalid HEX length",
+            label: $"length is ($len)",
+            help: $"length should be ($n)",
+        } }
+    }
+
+    # stop at the first character that is not listed in `VALID_HEX_CHARS`
+    for c in ($s | split chars | enumerate) {
+        if not ($VALID_HEX_CHARS | str contains $c.item) {
+            return { ok: false, err: {
+                msg: "bad HEX character",
+                label: $"found '($c.item)' at ($c.index)",
+                help: $"expected one of '($VALID_HEX_CHARS)'",
+            } }
+        }
+    }
+
+    { ok: true, err: {} }
+}
 
 export def main [
     --options: record<
@@ -14,7 +51,7 @@ export def main [
         strategies: list<string>,
         environment: string,
     >,
-    --prng-seed: int = 0,
+    --prng-seed: string = "0000000000000000000000000000000000000000000000000000000000000000",
 ] {
     if $options.measurement_schedule_start > $options.max_t {
         error make --unspanned {
@@ -22,6 +59,16 @@ export def main [
         }
     }
 
+    let res = $prng_seed | check-hex -n 64
+    if not $res.ok {
+        error throw {
+            err: $res.err.msg,
+            label: $res.err.label,
+            span: (metadata $prng_seed).span,
+            help: $res.err.help,
+        }
+    }
+
     let now = date now | format date "%s%f"
 
     for s in $options.strategies {
@@ -33,19 +80,6 @@ export def main [
         mkdir $output_dir
         print $"data will be dumped to `($output_dir)`"
 
-        # compute a unique seed for that strategy and global seed
-        let seed = $s + $"($prng_seed)"
-            | hash sha256
-            | split chars
-            | last 2
-            | str join
-            | $"0x($in)"
-            | into int
-        # compute all the seeds for that strategy, one per scenario
-        let seeds = cargo bin rng ...[ -n $options.nb_scenarii --prng-seed $prng_seed ]
-            | lines
-            | into int
-
         for i in 1..$options.nb_scenarii {
             let output = [ $output_dir, $"($i)" ] | path join
 
@@ -60,7 +94,7 @@ export def main [
                 --test-case recoding
                 --strategy $s
                 --environment $options.environment
-                --prng-seed ($seeds | get ($i - 1))
+                --prng-seed ([$prng_seed, $s, $i] | str join | hash sha256)
             ] out> $output
         }
 
diff --git a/bins/inbreeding/src/main.rs b/bins/inbreeding/src/main.rs
index 51f2d08b..47ace20f 100644
--- a/bins/inbreeding/src/main.rs
+++ b/bins/inbreeding/src/main.rs
@@ -149,6 +149,19 @@ where
     Ok(())
 }
 
+/// Parses a 64-character hexadecimal string into a 32-byte seed.
+/// Returns an error message if the length is not 64 or if decoding the hex fails.
+fn parse_hex_string(s: &str) -> Result<[u8; 32], String> {
+    if s.len() != 64 {
+        return Err("Input string must be exactly 64 characters long".to_string());
+    }
+
+    // 64 hex digits decode to a 32-byte `Vec<u8>`, so the array conversion cannot fail
+    hex::decode(s)
+        .map_err(|e| format!("Failed to decode hex string: {}", e))
+        .map(|bytes| bytes.try_into().unwrap())
+}
+
 #[derive(ValueEnum, Clone)]
 enum TestCase {
     EndToEnd,
@@ -190,8 +203,8 @@ struct Cli {
     #[arg(long)]
     measurement_schedule_start: usize,
 
-    #[arg(long)]
-    prng_seed: u8,
+    #[arg(long, value_parser = parse_hex_string)]
+    prng_seed: [u8; 32],
 }
 
 fn main() {
@@ -205,9 +218,7 @@ fn main() {
         exit(1);
     }
 
-    let mut seed: [u8; 32] = [0; 32];
-    seed[0] = cli.prng_seed;
-    let mut rng = StdRng::from_seed(seed);
+    let mut rng = StdRng::from_seed(cli.prng_seed);
 
     let bytes = random_bytes(cli.nb_bytes, &mut rng);
 
diff --git a/bins/rng/Cargo.toml b/bins/rng/Cargo.toml
index 39f70a07..f209c7e3 100644
--- a/bins/rng/Cargo.toml
+++ b/bins/rng/Cargo.toml
@@ -8,4 +8,5 @@ description = "Generate random numbers from a seed."
 
 [dependencies]
 clap = { version = "4.5.4", features = ["derive"] }
+hex = "0.4.3"
 rand = "0.8.5"
diff --git a/bins/rng/src/main.rs b/bins/rng/src/main.rs
index 04f7e40a..32223c9f 100644
--- a/bins/rng/src/main.rs
+++ b/bins/rng/src/main.rs
@@ -1,22 +1,33 @@
 use clap::Parser;
 use rand::{rngs::StdRng, Rng, SeedableRng};
 
+/// Parses a 64-character hexadecimal string into a 32-byte seed.
+/// Returns an error message if the length is not 64 or if decoding the hex fails.
+fn parse_hex_string(s: &str) -> Result<[u8; 32], String> {
+    if s.len() != 64 {
+        return Err("Input string must be exactly 64 characters long".to_string());
+    }
+
+    // 64 hex digits decode to a 32-byte `Vec<u8>`, so the array conversion cannot fail
+    hex::decode(s)
+        .map_err(|e| format!("Failed to decode hex string: {}", e))
+        .map(|bytes| bytes.try_into().unwrap())
+}
+
 #[derive(Parser)]
 #[command(version, about, long_about = None)]
 struct Cli {
     #[arg(short)]
     n: usize,
 
-    #[arg(long)]
-    prng_seed: u8,
+    #[arg(long, value_parser = parse_hex_string)]
+    prng_seed: [u8; 32], // seed bytes decoded from a 64-character hex string
 }
 
 fn main() {
     let cli = Cli::parse();
 
-    let mut seed: [u8; 32] = [0; 32];
-    seed[0] = cli.prng_seed;
-    let mut rng = StdRng::from_seed(seed);
+    let mut rng = StdRng::from_seed(cli.prng_seed);
 
     for _ in 0..cli.n {
         println!("{}", rng.gen::<u8>());
-- 
GitLab