From 8c2962ac11334ccad016ff8bfb2363c804c5342d Mon Sep 17 00:00:00 2001 From: moiseev-signal <122060238+moiseev-signal@users.noreply.github.com> Date: Tue, 23 Jan 2024 10:15:12 -0800 Subject: [PATCH] Use NonZeroU32 for SVR3 Backup max_tries parameter --- Cargo.lock | 2 ++ rust/net/Cargo.toml | 1 + rust/net/examples/svr3.rs | 14 +++++--- rust/net/examples/svr3_2xsgx.rs | 14 +++++--- rust/net/src/svr3.rs | 5 +-- rust/svr3/Cargo.toml | 1 + rust/svr3/src/lib.rs | 64 ++++++++++++++++++--------------- 7 files changed, 63 insertions(+), 38 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 5b658746..c0231bdc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1861,6 +1861,7 @@ dependencies = [ "libsignal-core", "libsignal-svr3", "log", + "nonzero_ext", "pin-project-lite", "prost", "prost-build", @@ -1953,6 +1954,7 @@ dependencies = [ "hex-literal", "hkdf", "http 1.0.0", + "nonzero_ext", "prost", "prost-build", "rand_core", diff --git a/rust/net/Cargo.toml b/rust/net/Cargo.toml index 7b3669ed..d75ea0c2 100644 --- a/rust/net/Cargo.toml +++ b/rust/net/Cargo.toml @@ -48,6 +48,7 @@ prost-build = "0.12.1" assert_matches = "1.5.0" clap = { version = "4.4.11", features = ["derive"] } env_logger = "0.10.0" +nonzero_ext = "0.3.0" snow = "0.9.3" tokio = { version = "1", features = ["test-util", "rt-multi-thread"] } tokio-stream = "0.1.14" diff --git a/rust/net/examples/svr3.rs b/rust/net/examples/svr3.rs index ec3d022a..a3766120 100644 --- a/rust/net/examples/svr3.rs +++ b/rust/net/examples/svr3.rs @@ -12,6 +12,7 @@ use std::time::Duration; use base64::prelude::{Engine, BASE64_STANDARD}; use clap::Parser; +use nonzero_ext::nonzero; use rand_core::{CryptoRngCore, OsRng, RngCore}; use attest::svr2::RaftConfig; @@ -92,10 +93,15 @@ async fn main() { println!("Secret to be stored: {}", hex::encode(secret)); let share_set_bytes = { - let opaque_share_set = - Svr3Env::backup(&mut connect().await, &args.password, secret, 10, &mut rng) - .await - .expect("can multi backup"); + let opaque_share_set = Svr3Env::backup( + &mut connect().await, + &args.password, + secret, + nonzero!(10u32), + &mut rng, + ) + .await + .expect("can multi backup"); opaque_share_set.serialize().expect("can serialize") }; println!("Share set: {}", hex::encode(&share_set_bytes)); diff --git a/rust/net/examples/svr3_2xsgx.rs b/rust/net/examples/svr3_2xsgx.rs index 8e70be7d..292d0208 100644 --- a/rust/net/examples/svr3_2xsgx.rs +++ b/rust/net/examples/svr3_2xsgx.rs @@ -14,6 +14,7 @@ use std::time::Duration; use base64::prelude::{Engine, BASE64_STANDARD}; use clap::Parser; use hex_literal::hex; +use nonzero_ext::nonzero; use rand_core::{CryptoRngCore, OsRng, RngCore}; use attest::svr2::RaftConfig; @@ -128,10 +129,15 @@ async fn main() { println!("Secret to be stored: {}", hex::encode(secret)); let share_set_bytes = { - let opaque_share_set = - TwoForTwoEnv::backup(&mut connect().await, &args.password, secret, 10, &mut rng) - .await - .expect("can multi backup"); + let opaque_share_set = TwoForTwoEnv::backup( + &mut connect().await, + &args.password, + secret, + nonzero!(10u32), + &mut rng, + ) + .await + .expect("can multi backup"); opaque_share_set.serialize().expect("can serialize") }; println!("Share set: {}", hex::encode(&share_set_bytes)); diff --git a/rust/net/src/svr3.rs b/rust/net/src/svr3.rs index 0b514a90..dac622d4 100644 --- a/rust/net/src/svr3.rs +++ b/rust/net/src/svr3.rs @@ -12,6 +12,7 @@ use futures_util::future::try_join_all; use libsignal_svr3::{Backup, MaskedShareSet, Restore}; use rand_core::CryptoRngCore; use serde::{Deserialize, Serialize}; +use std::num::NonZeroU32; const MASKED_SHARE_SET_FORMAT: u8 = 0; @@ -136,7 +137,7 @@ pub trait PpssOps: PpssSetup { connections: &mut Self::Connections, password: &str, secret: [u8; 32], - max_tries: u32, + max_tries: NonZeroU32, rng: &mut impl CryptoRngCore, ) -> Result; @@ -154,7 +155,7 @@ impl PpssOps for Env { connections: &mut Self::Connections, password: &str, secret: [u8; 32], - max_tries: u32, + max_tries: NonZeroU32, rng: &mut impl CryptoRngCore, ) -> Result { let server_ids = Self::server_ids().as_mut().to_owned(); diff --git a/rust/svr3/Cargo.toml b/rust/svr3/Cargo.toml index 775c4e5c..e644aa39 100644 --- a/rust/svr3/Cargo.toml +++ b/rust/svr3/Cargo.toml @@ -27,6 +27,7 @@ hex = "0.4" hex-literal = "0.4.1" criterion = "0.5" bytemuck = "1.13.0" +nonzero_ext = "0.3.0" test-case = "3.2.1" [build-dependencies] diff --git a/rust/svr3/src/lib.rs b/rust/svr3/src/lib.rs index f7d69d17..2425ccac 100644 --- a/rust/svr3/src/lib.rs +++ b/rust/svr3/src/lib.rs @@ -2,19 +2,20 @@ // Copyright 2023 Signal Messenger, LLC. // SPDX-License-Identifier: AGPL-3.0-only // -mod oprf; -mod ppss; -pub use ppss::MaskedShareSet; -mod errors; -mod proto; -pub use errors::{Error, OPRFError, PPSSError}; +use std::num::NonZeroU32; -use crate::ppss::OPRFSession; use prost::Message; use rand_core::CryptoRngCore; -use crate::proto::svr3; -use crate::proto::svr3::{create_response, evaluate_response}; +mod oprf; +mod ppss; +pub use ppss::{MaskedShareSet, OPRFSession}; + +mod errors; +pub use errors::{Error, OPRFError, PPSSError}; +mod proto; +use proto::svr3; +use proto::svr3::{create_response, evaluate_response}; const CONTEXT: &str = "Signal_SVR3_20231121_PPSS_Context"; @@ -31,14 +32,13 @@ impl<'a> Backup<'a> { server_ids: &[u64], password: &'a str, secret: [u8; 32], - max_tries: u32, + max_tries: NonZeroU32, rng: &mut R, ) -> Result { - assert_ne!(0, max_tries); let oprfs = ppss::begin_oprfs(CONTEXT, server_ids, password, rng)?; let requests = oprfs .iter() - .map(|oprf| crate::make_create_request(max_tries, &oprf.blinded_elt_bytes)) + .map(|oprf| crate::make_create_request(max_tries.into(), &oprf.blinded_elt_bytes)) .map(|request| request.encode_to_vec()) .collect(); Ok(Self { @@ -170,14 +170,16 @@ fn decode_evaluate_response(bytes: &[u8]) -> Result<[u8; 32], Error> { #[cfg(test)] mod test { - use super::*; - - use crate::oprf::ciphersuite::hash_to_group; - use crate::proto::svr3; + use nonzero_ext::nonzero; use prost::Message; use rand_core::{OsRng, RngCore}; use test_case::test_case; + use crate::oprf::ciphersuite::hash_to_group; + use crate::proto::svr3; + + use super::*; + fn make_secret() -> [u8; 32] { let mut rng = OsRng; let mut secret = [0; 32]; @@ -185,18 +187,12 @@ mod test { secret } - #[test] - #[should_panic] - fn zero_max_tries() { - let _ = Backup::new(&[], "", [0; 32], 0, &mut OsRng); - } - #[test] fn backup_request_basic_checks() { let mut rng = OsRng; let secret = make_secret(); - let backup = - Backup::new(&[1, 2, 3], "password", secret, 1, &mut rng).expect("can create backup"); + let backup = Backup::new(&[1, 2, 3], "password", secret, nonzero!(1u32), &mut rng) + .expect("can create backup"); assert_eq!(3, backup.requests.len()); for request_bytes in backup.requests.into_iter() { let decode_result = svr3::Request::decode(&*request_bytes); @@ -228,8 +224,14 @@ mod test { #[test_case(svr3::create_response::Status::InvalidRequest, false ; "status_invalid_request")] #[test_case(svr3::create_response::Status::Error, false ; "status_error")] fn backup_finalize_checks_status(status: svr3::create_response::Status, should_succeed: bool) { - let backup = Backup::new(&[1, 2, 3], "password", make_secret(), 1, &mut OsRng) - .expect("can create backup"); + let backup = Backup::new( + &[1, 2, 3], + "password", + make_secret(), + nonzero!(1u32), + &mut OsRng, + ) + .expect("can create backup"); let responses: Vec<_> = std::iter::repeat(make_create_response(status).encode_to_vec()) .take(3) .collect(); @@ -241,8 +243,14 @@ mod test { #[test_case(vec![1, 2, 3]; "bad_protobuf")] #[test_case(make_evaluate_response(svr3::evaluate_response::Status::Ok).encode_to_vec(); "wrong_response_type")] fn backup_invalid_response(response: Vec) { - let backup = Backup::new(&[1, 2, 3], "password", make_secret(), 1, &mut OsRng) - .expect("can create backup"); + let backup = Backup::new( + &[1, 2, 3], + "password", + make_secret(), + nonzero!(1u32), + &mut OsRng, + ) + .expect("can create backup"); let mut rng = OsRng; let result = backup.finalize(&mut rng, &[response]); assert!(matches!(result, Err(Error::Protocol(_))));