From 8108b6d3d0815243caf1be5ce34f053023fc5989 Mon Sep 17 00:00:00 2001 From: Jordan Rose Date: Tue, 11 Jul 2023 17:46:57 -0700 Subject: [PATCH] zkgroup: Add support for encoding ServiceIds in UidStructs This does mean that the 'bytes' field in a UidStruct isn't as useful anymore, because it can't distinguish different kinds of ServiceIds without extra work. Unfortunately, it was serialized inside a client-stored AuthCredential, so we can't just change it or take it out. Fortunately, nothing actually reads this field anyway except when decrypting, so it's okay to change how decryption works and ignore the 'bytes' field going forward. --- Cargo.lock | 2 + rust/zkgroup/Cargo.toml | 4 +- rust/zkgroup/src/api/call_links/params.rs | 2 +- rust/zkgroup/src/api/groups/group_params.rs | 2 +- rust/zkgroup/src/crypto/uid_encryption.rs | 73 ++++++++++++++++----- rust/zkgroup/src/crypto/uid_struct.rs | 43 +++++++----- 6 files changed, 89 insertions(+), 37 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 99cd86e3..618c55fb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2912,11 +2912,13 @@ dependencies = [ "displaydoc", "hex", "lazy_static", + "libsignal-protocol", "poksho", "rand 0.7.3", "serde", "sha2 0.9.9", "signal-crypto", "subtle", + "uuid", "zkcredential", ] diff --git a/rust/zkgroup/Cargo.toml b/rust/zkgroup/Cargo.toml index e2132296..38ce6271 100644 --- a/rust/zkgroup/Cargo.toml +++ b/rust/zkgroup/Cargo.toml @@ -12,9 +12,10 @@ description = "A zero-knowledge group library" license = "AGPL-3.0-only" [dependencies] +libsignal-protocol = { path = "../protocol" } poksho = { path = "../poksho" } -zkcredential = { path = "../zkcredential" } signal-crypto = { path = "../crypto" } +zkcredential = { path = "../zkcredential" } bincode = "1.2.1" serde = { version = "1.0.106", features = ["derive"] } @@ -25,6 +26,7 @@ aes-gcm-siv = "0.10.0" displaydoc = "0.2" lazy_static = "1.4.0" subtle = "2.3" +uuid = "1.1.2" # For generation base64 = { version = "0.13.0", optional = true } diff --git a/rust/zkgroup/src/api/call_links/params.rs b/rust/zkgroup/src/api/call_links/params.rs index b88410b9..ba457ef7 100644 --- a/rust/zkgroup/src/api/call_links/params.rs +++ b/rust/zkgroup/src/api/call_links/params.rs @@ -63,6 +63,6 @@ impl CallLinkSecretParams { ciphertext: api::groups::UuidCiphertext, ) -> Result { let uid = self.uid_enc_key_pair.decrypt(ciphertext.ciphertext)?; - Ok(uid.to_bytes()) + Ok(uid.raw_uuid().into_bytes()) } } diff --git a/rust/zkgroup/src/api/groups/group_params.rs b/rust/zkgroup/src/api/groups/group_params.rs index 6e1a70a2..8de79ec0 100644 --- a/rust/zkgroup/src/api/groups/group_params.rs +++ b/rust/zkgroup/src/api/groups/group_params.rs @@ -120,7 +120,7 @@ impl GroupSecretParams { ciphertext: api::groups::UuidCiphertext, ) -> Result { let uid = self.uid_enc_key_pair.decrypt(ciphertext.ciphertext)?; - Ok(uid.to_bytes()) + Ok(uid.raw_uuid().into_bytes()) } pub fn encrypt_profile_key( diff --git a/rust/zkgroup/src/crypto/uid_encryption.rs b/rust/zkgroup/src/crypto/uid_encryption.rs index 60bdda1b..21e9608b 100644 --- a/rust/zkgroup/src/crypto/uid_encryption.rs +++ b/rust/zkgroup/src/crypto/uid_encryption.rs @@ -8,12 +8,13 @@ use crate::common::errors::*; use crate::common::sho::*; use crate::crypto::uid_struct; + use curve25519_dalek::constants::RISTRETTO_BASEPOINT_POINT; use curve25519_dalek::ristretto::RistrettoPoint; use curve25519_dalek::scalar::Scalar; -use serde::{Deserialize, Serialize}; - use lazy_static::lazy_static; +use serde::{Deserialize, Serialize}; +use subtle::{ConditionallySelectable, ConstantTimeEq}; lazy_static! { static ref SYSTEM_PARAMS: SystemParams = @@ -80,35 +81,50 @@ impl KeyPair { } pub fn encrypt(&self, uid: uid_struct::UidStruct) -> Ciphertext { - let E_A1 = self.calc_E_A1(uid); + let E_A1 = self.a1 * uid.M1; let E_A2 = (self.a2 * E_A1) + uid.M2; Ciphertext { E_A1, E_A2 } } - // Might return VerificationFailure pub fn decrypt( &self, ciphertext: Ciphertext, - ) -> Result { + ) -> Result { if ciphertext.E_A1 == RISTRETTO_BASEPOINT_POINT { return Err(ZkGroupVerificationFailure); } - match uid_struct::UidStruct::from_M2(ciphertext.E_A2 - (self.a2 * ciphertext.E_A1)) { - Err(_) => Err(ZkGroupVerificationFailure), - Ok(decrypted_uid) => { - if ciphertext.E_A1 == self.calc_E_A1(decrypted_uid) { - Ok(decrypted_uid) - } else { - Err(ZkGroupVerificationFailure) - } + let M2 = ciphertext.E_A2 - (self.a2 * ciphertext.E_A1); + match M2.lizard_decode::() { + None => Err(ZkGroupVerificationFailure), + Some(bytes) => { + // We want to do a constant-time choice between the ACI and the PNI possibilities. + // Only at the end do we do a normal branch to see if decryption succeeded, + // and even then we don't want to expose whether we picked the ACI or the PNI. + // So we store them both in an array, and index into it at the very end. + // This isn't fully "data-oblivious"; only one service ID gets loaded from memory at + // the end, and which one is data-dependent. But it is constant-time. + let decoded_uuid = uuid::Uuid::from_bytes(bytes); + let decoded_service_ids = [ + libsignal_protocol::Aci::from(decoded_uuid).into(), + libsignal_protocol::Pni::from(decoded_uuid).into(), + ]; + let decoded_aci = &decoded_service_ids[0]; + let decoded_pni = &decoded_service_ids[1]; + let aci_M1 = uid_struct::UidStruct::calc_M1(*decoded_aci); + let pni_M1 = uid_struct::UidStruct::calc_M1(*decoded_pni); + debug_assert!(aci_M1 != pni_M1); + let decrypted_M1 = self.a1.invert() * ciphertext.E_A1; + let mut index = u8::MAX; + index.conditional_assign(&0, decrypted_M1.ct_eq(&aci_M1)); + index.conditional_assign(&1, decrypted_M1.ct_eq(&pni_M1)); + decoded_service_ids + .get(index as usize) + .copied() + .ok_or(ZkGroupVerificationFailure) } } } - fn calc_E_A1(&self, uid: uid_struct::UidStruct) -> RistrettoPoint { - self.a1 * uid.M1 - } - pub fn get_public_key(&self) -> PublicKey { PublicKey { A: self.A } } @@ -204,7 +220,28 @@ mod tests { ); let plaintext = key_pair.decrypt(ciphertext2).unwrap(); + assert!(matches!(plaintext, libsignal_protocol::ServiceId::Aci(_))); + assert!(uid_struct::UidStruct::from_service_id(plaintext) == uid); + } - assert!(plaintext == uid); + #[test] + fn test_pni_encryption() { + let mut sho = Sho::new(b"Test_Pni_Encryption", &[]); + let key_pair = KeyPair::derive_from(&mut sho); + + let uid = uid_struct::UidStruct::from_service_id( + libsignal_protocol::Pni::from(uuid::Uuid::from_bytes(TEST_ARRAY_16)).into(), + ); + let ciphertext = key_pair.encrypt(uid); + + // Test serialize / deserialize of Ciphertext + let ciphertext_bytes = bincode::serialize(&ciphertext).unwrap(); + assert!(ciphertext_bytes.len() == 64); + let ciphertext2: Ciphertext = bincode::deserialize(&ciphertext_bytes).unwrap(); + assert!(ciphertext == ciphertext2); + + let plaintext = key_pair.decrypt(ciphertext2).unwrap(); + assert!(matches!(plaintext, libsignal_protocol::ServiceId::Pni(_))); + assert!(uid_struct::UidStruct::from_service_id(plaintext) == uid); } } diff --git a/rust/zkgroup/src/crypto/uid_struct.rs b/rust/zkgroup/src/crypto/uid_struct.rs index 7b5d74d3..f448b6eb 100644 --- a/rust/zkgroup/src/crypto/uid_struct.rs +++ b/rust/zkgroup/src/crypto/uid_struct.rs @@ -8,39 +8,50 @@ use crate::common::sho::*; use crate::common::simple_types::*; use curve25519_dalek::ristretto::RistrettoPoint; +use libsignal_protocol::ServiceId; use serde::{Deserialize, Serialize}; use sha2::Sha256; #[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct UidStruct { - pub(crate) bytes: UidBytes, + // Currently unused. It would be possible to convert this back to the correct kind of ServiceId + // using the same technique as decryption: comparing possible M1 points and seeing which one + // matches. But we don't have a need for that, and therefore it's better if that operation + // remains part of decryption, so that you're guaranteed to get a valid result or an error in + // one step. + // + // At the same time, we can't just remove the field: it's serialized as part of AuthCredential + // and AuthCredentialWithPni, which clients store locally. + #[serde(rename = "bytes")] + raw_uuid_bytes: UidBytes, pub(crate) M1: RistrettoPoint, pub(crate) M2: RistrettoPoint, } -pub struct PointDecodeFailure; - impl UidStruct { pub fn new(uid_bytes: UidBytes) -> Self { - let mut sho = Sho::new(b"Signal_ZKGroup_20200424_UID_CalcM1", &uid_bytes); - let M1 = sho.get_point(); - let M2 = RistrettoPoint::lizard_encode::(&uid_bytes); + Self::from_service_id( + libsignal_protocol::Aci::from(uuid::Uuid::from_bytes(uid_bytes)).into(), + ) + } + + pub fn from_service_id(service_id: ServiceId) -> Self { + let M1 = Self::calc_M1(service_id); + let raw_uuid_bytes = service_id.raw_uuid().into_bytes(); + let M2 = RistrettoPoint::lizard_encode::(&raw_uuid_bytes); UidStruct { - bytes: uid_bytes, + raw_uuid_bytes, M1, M2, } } - pub fn from_M2(M2: RistrettoPoint) -> Result { - match M2.lizard_decode::() { - None => Err(PointDecodeFailure), - Some(bytes) => Ok(Self::new(bytes)), - } - } - - pub fn to_bytes(&self) -> UidBytes { - self.bytes + pub fn calc_M1(service_id: ServiceId) -> RistrettoPoint { + let mut sho = Sho::new( + b"Signal_ZKGroup_20200424_UID_CalcM1", + &service_id.service_id_binary(), + ); + sho.get_point() } }