0
0
mirror of https://github.com/signalapp/libsignal.git synced 2024-09-19 19:42:19 +02:00

Add PreKeySignalMessage struct implentation

This commit is contained in:
Ehren Kret 2020-05-07 06:01:19 -07:00
parent ae92413766
commit a551b45c67
2 changed files with 220 additions and 20 deletions

View File

@ -19,7 +19,7 @@ pub struct IdentityKey {
impl IdentityKey {
#[inline]
pub fn public_key(&self) -> &dyn curve::PublicKey {
pub fn public_key(&self) -> &(dyn curve::PublicKey + 'static) {
self.public_key.as_ref()
}

View File

@ -132,7 +132,7 @@ impl SignalMessage {
}
#[inline]
pub fn sender_ratchet_key(&self) -> &dyn curve::PublicKey {
pub fn sender_ratchet_key(&self) -> &(dyn curve::PublicKey + 'static) {
&*self.sender_ratchet_key
}
@ -184,17 +184,17 @@ impl AsRef<[u8]> for SignalMessage {
}
}
impl TryInto<SignalMessage> for &[u8] {
impl TryFrom<&[u8]> for SignalMessage {
type Error = CiphertextMessageDeserializationError;
fn try_into(self) -> Result<SignalMessage, Self::Error> {
if self.len() < SignalMessage::MAC_LENGTH + 1 {
fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
if value.len() < SignalMessage::MAC_LENGTH + 1 {
return Err(CiphertextMessageDeserializationError::MessageTooShort(
self.len(),
value.len(),
));
}
let message_version = self[0] >> 4;
let ciphertext_version = self[0] & 0x0F;
let message_version = value[0] >> 4;
let ciphertext_version = value[0] & 0x0F;
if ciphertext_version < CIPHERTEXT_MESSAGE_CURRENT_VERSION {
return Err(CiphertextMessageDeserializationError::LegacyVersion(
ciphertext_version,
@ -207,7 +207,7 @@ impl TryInto<SignalMessage> for &[u8] {
}
let proto_structure =
proto::wire::SignalMessage::decode(&self[1..self.len() - SignalMessage::MAC_LENGTH])?;
proto::wire::SignalMessage::decode(&value[1..value.len() - SignalMessage::MAC_LENGTH])?;
if proto_structure.ciphertext.is_none()
|| proto_structure.counter.is_none()
|| proto_structure.ratchet_key.is_none()
@ -222,12 +222,142 @@ impl TryInto<SignalMessage> for &[u8] {
counter: proto_structure.counter.unwrap(),
previous_counter: proto_structure.previous_counter.unwrap_or(0),
ciphertext: proto_structure.ciphertext.unwrap().into_boxed_slice(),
serialized: Box::from(self),
serialized: Box::from(value),
})
}
}
pub struct PreKeySignalMessage {
message_version: u8,
registration_id: u32,
pre_key_id: Option<u32>,
signed_pre_key_id: u32,
base_key: Box<dyn curve::PublicKey>,
identity_key: IdentityKey,
message: SignalMessage,
serialized: Box<[u8]>,
}
impl PreKeySignalMessage {
pub fn new(
message_version: u8,
registration_id: u32,
pre_key_id: Option<u32>,
signed_pre_key_id: u32,
base_key: Box<dyn curve::PublicKey>,
identity_key: IdentityKey,
message: SignalMessage,
) -> Self {
let proto_message = proto::wire::PreKeySignalMessage {
registration_id: Some(registration_id),
pre_key_id,
signed_pre_key_id: Some(signed_pre_key_id),
base_key: Some(base_key.serialize().into_vec()),
identity_key: Some(identity_key.serialize().into_vec()),
message: Some(Vec::from(message.as_ref())),
};
let mut serialized = vec![0u8; 1 + proto_message.encoded_len()];
serialized[0] = ((message_version & 0xF) << 4) | CIPHERTEXT_MESSAGE_CURRENT_VERSION;
proto_message.encode(&mut &mut serialized[1..]).unwrap();
Self {
message_version,
registration_id,
pre_key_id,
signed_pre_key_id,
base_key,
identity_key,
message,
serialized: serialized.into_boxed_slice(),
}
}
#[inline]
pub fn message_version(&self) -> u8 {
self.message_version
}
#[inline]
pub fn registration_id(&self) -> u32 {
self.registration_id
}
#[inline]
pub fn pre_key_id(&self) -> Option<u32> {
self.pre_key_id
}
#[inline]
pub fn signed_pre_key_id(&self) -> u32 {
self.signed_pre_key_id
}
#[inline]
pub fn base_key(&self) -> &(dyn curve::PublicKey + 'static) {
&*self.base_key
}
#[inline]
pub fn identity_key(&self) -> &IdentityKey {
&self.identity_key
}
#[inline]
pub fn message(&self) -> &SignalMessage {
&self.message
}
}
impl AsRef<[u8]> for PreKeySignalMessage {
fn as_ref(&self) -> &[u8] {
&*self.serialized
}
}
impl TryFrom<&[u8]> for PreKeySignalMessage {
type Error = CiphertextMessageDeserializationError;
fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
if value.is_empty() {
return Err(CiphertextMessageDeserializationError::MessageTooShort(
value.len(),
));
}
let message_version = value[0] >> 4;
let ciphertext_version = value[0] & 0x0F;
if ciphertext_version < CIPHERTEXT_MESSAGE_CURRENT_VERSION {
return Err(CiphertextMessageDeserializationError::LegacyVersion(
ciphertext_version,
));
}
if ciphertext_version > CIPHERTEXT_MESSAGE_CURRENT_VERSION {
return Err(CiphertextMessageDeserializationError::UnrecognizedVersion(
ciphertext_version,
));
}
let proto_structure = proto::wire::PreKeySignalMessage::decode(&value[1..])?;
if proto_structure.signed_pre_key_id.is_none()
|| proto_structure.base_key.is_none()
|| proto_structure.identity_key.is_none()
|| proto_structure.message.is_none()
{
return Err(CiphertextMessageDeserializationError::InvalidMessage(None));
}
let base_key = curve::decode_point(proto_structure.base_key.unwrap().as_ref())?;
Ok(PreKeySignalMessage {
message_version,
registration_id: proto_structure.registration_id.unwrap_or(0),
pre_key_id: proto_structure.pre_key_id,
signed_pre_key_id: proto_structure.signed_pre_key_id.unwrap(),
base_key,
identity_key: IdentityKey::try_from(proto_structure.identity_key.unwrap().as_ref())?,
message: SignalMessage::try_from(proto_structure.message.unwrap().as_ref())?,
serialized: Box::from(value),
})
}
}
pub struct PreKeySignalMessage {}
pub struct SenderKeyMessage {}
pub struct SenderKeyDistributionMessage {}
@ -236,12 +366,12 @@ mod tests {
use super::*;
use rand::rngs::OsRng;
use rand::RngCore;
#[test]
fn test_signal_message_serialize_deserialize() {
let mut csprng = OsRng;
use rand::{CryptoRng, Rng, RngCore};
fn create_signal_message<T>(csprng: &mut T) -> SignalMessage
where
T: Rng + CryptoRng,
{
let mut mac_key = [0u8; 32];
csprng.fill_bytes(&mut mac_key);
let mac_key = mac_key;
@ -250,11 +380,11 @@ mod tests {
csprng.fill_bytes(&mut ciphertext);
let ciphertext = ciphertext;
let sender_ratchet_key_pair = curve::KeyPair::new(&mut csprng);
let sender_identity_key_pair = curve::KeyPair::new(&mut csprng);
let receiver_identity_key_pair = curve::KeyPair::new(&mut csprng);
let sender_ratchet_key_pair = curve::KeyPair::new(csprng);
let sender_identity_key_pair = curve::KeyPair::new(csprng);
let receiver_identity_key_pair = curve::KeyPair::new(csprng);
let msg = SignalMessage::new(
SignalMessage::new(
3,
&mac_key,
sender_ratchet_key_pair.public_key,
@ -263,6 +393,76 @@ mod tests {
Box::new(ciphertext),
&sender_identity_key_pair.public_key.into(),
&receiver_identity_key_pair.public_key.into(),
)
}
fn assert_signal_message_equals(m1: &SignalMessage, m2: &SignalMessage) {
assert_eq!(m1.message_version, m2.message_version);
assert_eq!(*m1.sender_ratchet_key, *m2.sender_ratchet_key);
assert_eq!(m1.counter, m2.counter);
assert_eq!(m1.previous_counter, m2.previous_counter);
assert_eq!(m1.ciphertext, m2.ciphertext);
assert_eq!(m1.serialized, m2.serialized);
}
#[test]
fn test_signal_message_serialize_deserialize() {
let mut csprng = OsRng;
let message = create_signal_message(&mut csprng);
let deser_message =
SignalMessage::try_from(message.as_ref()).expect("should deserialize without error");
assert_signal_message_equals(&message, &deser_message);
}
#[test]
fn test_pre_key_signal_message_serialize_deserialize() {
let mut csprng = OsRng;
let identity_key_pair = curve::KeyPair::new(&mut csprng);
let base_key_pair = curve::KeyPair::new(&mut csprng);
let message = create_signal_message(&mut csprng);
let pre_key_signal_message = PreKeySignalMessage::new(
3,
365,
None,
97,
base_key_pair.public_key,
identity_key_pair.public_key.into(),
message,
);
let deser_pre_key_signal_message =
PreKeySignalMessage::try_from(pre_key_signal_message.as_ref())
.expect("should deserialized without error");
assert_eq!(
pre_key_signal_message.message_version,
deser_pre_key_signal_message.message_version
);
assert_eq!(
pre_key_signal_message.registration_id,
deser_pre_key_signal_message.registration_id
);
assert_eq!(
pre_key_signal_message.pre_key_id,
deser_pre_key_signal_message.pre_key_id
);
assert_eq!(
pre_key_signal_message.signed_pre_key_id,
deser_pre_key_signal_message.signed_pre_key_id
);
assert_eq!(
*pre_key_signal_message.base_key,
*deser_pre_key_signal_message.base_key
);
assert_eq!(
pre_key_signal_message.identity_key.public_key(),
deser_pre_key_signal_message.identity_key.public_key()
);
assert_signal_message_equals(
&pre_key_signal_message.message,
&deser_pre_key_signal_message.message,
);
assert_eq!(
pre_key_signal_message.serialized,
deser_pre_key_signal_message.serialized
);
}
}