diff --git a/.nushell/error.nu b/.nushell/error.nu index 34e30b9b1b8465f85645414dda6a451f4b2ea991..9e6a4aeb278e3035ebde3f34149b980e138ffb8e 100644 --- a/.nushell/error.nu +++ b/.nushell/error.nu @@ -1,9 +1,15 @@ -export def "error throw" [err: record<err: string, label: string, span: record<start: int, end: int>>] { +export def "error throw" [err: record< + err: string, + label: string, + span: record<start: int, end: int>, + # help: string?, +>] { error make { msg: $"(ansi red_bold)($err.err)(ansi reset)", label: { text: $err.label, span: $err.span, }, + help: $err.help?, } } diff --git a/bins/inbreeding/Cargo.toml b/bins/inbreeding/Cargo.toml index c8982f389c03592e9c7cae1ce70f4fc42e4b415b..5fbcc6973bda64d8ceafa4d6d7aa506a46208afc 100644 --- a/bins/inbreeding/Cargo.toml +++ b/bins/inbreeding/Cargo.toml @@ -11,3 +11,4 @@ clap = { version = "4.5.4", features = ["derive"] } rand = "0.8.5" indicatif = "0.17.8" ark-ff = "0.4.2" +hex = "0.4.3" diff --git a/bins/inbreeding/README.md b/bins/inbreeding/README.md index dc77e786db6f26a0228b1dafb953373448b8a443..1b27f8c034d76eee5c04e55a9bdaa3a12b623d34 100644 --- a/bins/inbreeding/README.md +++ b/bins/inbreeding/README.md @@ -7,7 +7,7 @@ use ./bins/inbreeding ``` ```bash -const PRNG_SEED = 123 +const PRNG_SEED = "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff" const OPTS = { nb_bytes: (10 * 1_024), k: 10, diff --git a/bins/inbreeding/inspect.nu b/bins/inbreeding/inspect.nu index 0df2ea82e6c0665aea322ad8147f3aacc22933a3..96a908a56e0ac25da5da29ce1d983c0153fb3544 100644 --- a/bins/inbreeding/inspect.nu +++ b/bins/inbreeding/inspect.nu @@ -5,7 +5,7 @@ def get-seeds [] [ nothing -> list<string> ] { $consts.CACHE | path join '*' | into glob | ls $in | get name | each { path split | last } } -export def main [seed: int@get-seeds]: [ +export def main [seed: string@get-seeds]: [ nothing -> table< seed: string, timestamp: string, diff --git a/bins/inbreeding/run.nu b/bins/inbreeding/run.nu index 5b01dfce6a0d5535322e6db825a9beb746ab96bc..e27cdc8599b7d33e5a43e79fe1a5f090ddfcf855 100644 --- a/bins/inbreeding/run.nu +++ b/bins/inbreeding/run.nu @@ -1,5 +1,42 @@ use consts.nu -use ../../.nushell cargo "cargo bin" +use ../../.nushell error "error throw" + +const VALID_HEX_CHARS = "abcdefABCDEF0123456789" + +def check-hex [-n: int]: [ + string -> record< + ok: bool, + err: record<msg: string, label: string, help: string>, + > +] { + let s = $in + + if ($s | str length) != $n { + return { + ok: false, + err: { + msg: "invalid HEX length" + label : $"length is ($s | str length)", + help: "length should be 64", + }, + } + } + + for c in ($s | split chars | enumerate) { + if not ($VALID_HEX_CHARS | str contains $c.item) { + return { + ok: false, + err: { + msg: "bad HEX character", + label: $"found '($c.item)' at ($c.index)", + help: $"expected one of '($VALID_HEX_CHARS)'", + }, + } + } + } + + { ok: true, err: {} } +} export def main [ --options: record< @@ -14,7 +51,7 @@ export def main [ strategies: list<string>, environment: string, >, - --prng-seed: int = 0, + --prng-seed: string = "0000000000000000000000000000000000000000000000000000000000000000", ] { if $options.measurement_schedule_start > $options.max_t { error make --unspanned { @@ -22,6 +59,16 @@ export def main [ } } + let res = $prng_seed | check-hex -n 64 + if not $res.ok { + error throw { + err: $res.err.msg, + label: $res.err.label, + span: (metadata $prng_seed).span, + help: $res.err.help, + } + } + let now = date now | format date "%s%f" for s in $options.strategies { @@ -33,19 +80,6 @@ export def main [ mkdir $output_dir print $"data will be dumped to `($output_dir)`" - # compute a unique seed for that strategy and global seed - let seed = $s + $"($prng_seed)" - | hash sha256 - | split chars - | last 2 - | str join - | $"0x($in)" - | into int - # compute all the seeds for that strategy, one per scenario - let seeds = cargo bin rng ...[ -n $options.nb_scenarii --prng-seed $prng_seed ] - | lines - | into int - for i in 1..$options.nb_scenarii { let output = [ $output_dir, $"($i)" ] | path join @@ -60,7 +94,7 @@ export def main [ --test-case recoding --strategy $s --environment $options.environment - --prng-seed ($seeds | get ($i - 1)) + --prng-seed ([$prng_seed, $s, $i] | str join | hash sha256) ] out> $output } diff --git a/bins/inbreeding/src/main.rs b/bins/inbreeding/src/main.rs index 51f2d08b22af02b937b1760744b89177aac6c33e..47ace20fc3aa55b5e53d01358d65ef3acf58831d 100644 --- a/bins/inbreeding/src/main.rs +++ b/bins/inbreeding/src/main.rs @@ -149,6 +149,19 @@ where Ok(()) } +fn parse_hex_string(s: &str) -> Result<[u8; 32], String> { + if s.len() != 64 { + return Err("Input string must be exactly 64 characters long".to_string()); + } + + match hex::decode(s) { + // `bytes` will be a `Vec<u8>` of size `32`, so it's safe to `unwrap` + // the conversion to `[u8: 32]` + Ok(bytes) => Ok(bytes.try_into().unwrap()), + Err(e) => Err(format!("Failed to decode hex string: {}", e)), + } +} + #[derive(ValueEnum, Clone)] enum TestCase { EndToEnd, @@ -190,8 +203,8 @@ struct Cli { #[arg(long)] measurement_schedule_start: usize, - #[arg(long)] - prng_seed: u8, + #[arg(long, value_parser = parse_hex_string)] + prng_seed: [u8; 32], } fn main() { @@ -205,9 +218,7 @@ fn main() { exit(1); } - let mut seed: [u8; 32] = [0; 32]; - seed[0] = cli.prng_seed; - let mut rng = StdRng::from_seed(seed); + let mut rng = StdRng::from_seed(cli.prng_seed); let bytes = random_bytes(cli.nb_bytes, &mut rng); diff --git a/bins/rng/Cargo.toml b/bins/rng/Cargo.toml index 39f70a07a8bc58a0012463ea4b4643578d459ecc..f209c7e3afb139cdfd2d20a0b0f4de4b64814c55 100644 --- a/bins/rng/Cargo.toml +++ b/bins/rng/Cargo.toml @@ -8,4 +8,5 @@ description = "Generate random numbers from a seed." [dependencies] clap = { version = "4.5.4", features = ["derive"] } +hex = "0.4.3" rand = "0.8.5" diff --git a/bins/rng/src/main.rs b/bins/rng/src/main.rs index 04f7e40aba1db36692df93be38921663e41b0429..32223c9fca320339413fdc94d923219f24bd8152 100644 --- a/bins/rng/src/main.rs +++ b/bins/rng/src/main.rs @@ -1,22 +1,33 @@ use clap::Parser; use rand::{rngs::StdRng, Rng, SeedableRng}; +fn parse_hex_string(s: &str) -> Result<[u8; 32], String> { + if s.len() != 64 { + return Err("Input string must be exactly 64 characters long".to_string()); + } + + match hex::decode(s) { + // `bytes` will be a `Vec<u8>` of size `32`, so it's safe to `unwrap` + // the conversion to `[u8: 32]` + Ok(bytes) => Ok(bytes.try_into().unwrap()), + Err(e) => Err(format!("Failed to decode hex string: {}", e)), + } +} + #[derive(Parser)] #[command(version, about, long_about = None)] struct Cli { #[arg(short)] n: usize, - #[arg(long)] - prng_seed: u8, + #[arg(long, value_parser = parse_hex_string)] + prng_seed: [u8; 32], } fn main() { let cli = Cli::parse(); - let mut seed: [u8; 32] = [0; 32]; - seed[0] = cli.prng_seed; - let mut rng = StdRng::from_seed(seed); + let mut rng = StdRng::from_seed(cli.prng_seed); for _ in 0..cli.n { println!("{}", rng.gen::<u8>());