0
0
mirror of https://github.com/signalapp/libsignal.git synced 2024-09-20 12:02:18 +02:00

Make ffi::CallbackError::check return a Result

Use Result to signal success or failure instead of Option::Some
signalling an error. This makes it easy to use combinators like
Result::map_err to more succinctly express the same operations.
Introduce a helper for SignalProtocolError to more succinctly construct
ApplicationCallbackError instances.
This commit is contained in:
Alex Konradi 2024-01-18 09:46:29 -05:00 committed by GitHub
parent b2f6a791d3
commit 174919865d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 70 additions and 118 deletions

View File

@ -201,10 +201,12 @@ pub struct CallbackError {
}
impl CallbackError {
/// Returns `None` if `value` is zero; otherwise, wraps the value in `Self`.
pub fn check(value: i32) -> Option<Self> {
let value = std::num::NonZeroI32::try_from(value).ok()?;
Some(Self { value })
/// Returns `Ok(())` if `value` is zero; otherwise, wraps the value in `Self` as an error.
pub fn check(value: i32) -> Result<(), Self> {
match std::num::NonZeroI32::try_from(value).ok() {
None => Ok(()),
Some(value) => Err(Self { value }),
}
}
}

View File

@ -30,18 +30,13 @@ impl FfiInputStreamStruct {
fn do_read(&self, buf: &mut [u8]) -> io::Result<usize> {
let mut amount_read = 0;
let result = (self.read)(self.ctx, buf.as_mut_ptr(), buf.len(), &mut amount_read);
match CallbackError::check(result) {
Some(error) => Err(io::Error::new(io::ErrorKind::Other, error)),
None => Ok(amount_read),
}
CallbackError::check(result).map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
Ok(amount_read)
}
fn do_skip(&self, amount: u64) -> io::Result<()> {
let result = (self.skip)(self.ctx, amount);
match CallbackError::check(result) {
Some(error) => Err(io::Error::new(io::ErrorKind::Other, error)),
None => Ok(()),
}
CallbackError::check(result).map_err(|e| io::Error::new(io::ErrorKind::Other, e))
}
}

View File

@ -54,12 +54,9 @@ impl IdentityKeyStore for &FfiIdentityKeyStoreStruct {
let mut key = std::ptr::null_mut();
let result = (self.get_identity_key_pair)(self.ctx, &mut key);
if let Some(error) = CallbackError::check(result) {
return Err(SignalProtocolError::ApplicationCallbackError(
"get_identity_key_pair",
Box::new(error),
));
}
CallbackError::check(result).map_err(SignalProtocolError::for_application_callback(
"get_identity_key_pair",
))?;
if key.is_null() {
return Err(SignalProtocolError::InvalidState(
@ -78,12 +75,9 @@ impl IdentityKeyStore for &FfiIdentityKeyStoreStruct {
let mut id = 0;
let result = (self.get_local_registration_id)(self.ctx, &mut id);
if let Some(error) = CallbackError::check(result) {
return Err(SignalProtocolError::ApplicationCallbackError(
"get_local_registration_id",
Box::new(error),
));
}
CallbackError::check(result).map_err(SignalProtocolError::for_application_callback(
"get_local_registration_id",
))?;
Ok(id)
}
@ -98,9 +92,10 @@ impl IdentityKeyStore for &FfiIdentityKeyStoreStruct {
match result {
0 => Ok(false),
1 => Ok(true),
r => Err(SignalProtocolError::ApplicationCallbackError(
r => Err(SignalProtocolError::for_application_callback(
"save_identity",
Box::new(CallbackError::check(r).expect("verified non-zero")),
)(
CallbackError::check(r).expect_err("verified non-zero")
)),
}
}
@ -121,9 +116,10 @@ impl IdentityKeyStore for &FfiIdentityKeyStoreStruct {
match result {
0 => Ok(false),
1 => Ok(true),
r => Err(SignalProtocolError::ApplicationCallbackError(
r => Err(SignalProtocolError::for_application_callback(
"is_trusted_identity",
Box::new(CallbackError::check(r).expect("verified non-zero")),
)(
CallbackError::check(r).expect_err("verified non-zero")
)),
}
}
@ -135,12 +131,9 @@ impl IdentityKeyStore for &FfiIdentityKeyStoreStruct {
let mut key = std::ptr::null_mut();
let result = (self.get_identity)(self.ctx, &mut key, address);
if let Some(error) = CallbackError::check(result) {
return Err(SignalProtocolError::ApplicationCallbackError(
"get_identity",
Box::new(error),
));
}
CallbackError::check(result).map_err(SignalProtocolError::for_application_callback(
"get_identity",
))?;
if key.is_null() {
return Ok(None);
@ -173,12 +166,9 @@ impl PreKeyStore for &FfiPreKeyStoreStruct {
let mut record = std::ptr::null_mut();
let result = (self.load_pre_key)(self.ctx, &mut record, prekey_id.into());
if let Some(error) = CallbackError::check(result) {
return Err(SignalProtocolError::ApplicationCallbackError(
"load_pre_key",
Box::new(error),
));
}
CallbackError::check(result).map_err(SignalProtocolError::for_application_callback(
"load_pre_key",
))?;
if record.is_null() {
return Err(SignalProtocolError::InvalidPreKeyId);
@ -195,27 +185,17 @@ impl PreKeyStore for &FfiPreKeyStoreStruct {
) -> Result<(), SignalProtocolError> {
let result = (self.store_pre_key)(self.ctx, prekey_id.into(), record);
if let Some(error) = CallbackError::check(result) {
return Err(SignalProtocolError::ApplicationCallbackError(
"store_pre_key",
Box::new(error),
));
}
Ok(())
CallbackError::check(result).map_err(SignalProtocolError::for_application_callback(
"store_pre_key",
))
}
async fn remove_pre_key(&mut self, prekey_id: PreKeyId) -> Result<(), SignalProtocolError> {
let result = (self.remove_pre_key)(self.ctx, prekey_id.into());
if let Some(error) = CallbackError::check(result) {
return Err(SignalProtocolError::ApplicationCallbackError(
"remove_pre_key",
Box::new(error),
));
}
Ok(())
CallbackError::check(result).map_err(SignalProtocolError::for_application_callback(
"remove_pre_key",
))
}
}
@ -241,12 +221,9 @@ impl SignedPreKeyStore for &FfiSignedPreKeyStoreStruct {
let mut record = std::ptr::null_mut();
let result = (self.load_signed_pre_key)(self.ctx, &mut record, prekey_id.into());
if let Some(error) = CallbackError::check(result) {
return Err(SignalProtocolError::ApplicationCallbackError(
"load_signed_pre_key",
Box::new(error),
));
}
CallbackError::check(result).map_err(SignalProtocolError::for_application_callback(
"load_signed_pre_key",
))?;
if record.is_null() {
return Err(SignalProtocolError::InvalidSignedPreKeyId);
@ -264,12 +241,9 @@ impl SignedPreKeyStore for &FfiSignedPreKeyStoreStruct {
) -> Result<(), SignalProtocolError> {
let result = (self.store_signed_pre_key)(self.ctx, prekey_id.into(), record);
if let Some(error) = CallbackError::check(result) {
return Err(SignalProtocolError::ApplicationCallbackError(
"store_signed_pre_key",
Box::new(error),
));
}
CallbackError::check(result).map_err(SignalProtocolError::for_application_callback(
"store_signed_pre_key",
))?;
Ok(())
}
@ -299,12 +273,9 @@ impl KyberPreKeyStore for &FfiKyberPreKeyStoreStruct {
let mut record = std::ptr::null_mut();
let result = (self.load_kyber_pre_key)(self.ctx, &mut record, id.into());
if let Some(error) = CallbackError::check(result) {
return Err(SignalProtocolError::ApplicationCallbackError(
"load_kyber_pre_key",
Box::new(error),
));
}
CallbackError::check(result).map_err(SignalProtocolError::for_application_callback(
"load_kyber_pre_key",
))?;
if record.is_null() {
return Err(SignalProtocolError::InvalidKyberPreKeyId);
@ -322,14 +293,9 @@ impl KyberPreKeyStore for &FfiKyberPreKeyStoreStruct {
) -> Result<(), SignalProtocolError> {
let result = (self.store_kyber_pre_key)(self.ctx, id.into(), record);
if let Some(error) = CallbackError::check(result) {
return Err(SignalProtocolError::ApplicationCallbackError(
"store_kyber_pre_key",
Box::new(error),
));
}
Ok(())
CallbackError::check(result).map_err(SignalProtocolError::for_application_callback(
"store_kyber_pre_key",
))
}
async fn mark_kyber_pre_key_used(
@ -338,14 +304,9 @@ impl KyberPreKeyStore for &FfiKyberPreKeyStoreStruct {
) -> Result<(), SignalProtocolError> {
let result = (self.mark_kyber_pre_key_used)(self.ctx, id.into());
if let Some(error) = CallbackError::check(result) {
return Err(SignalProtocolError::ApplicationCallbackError(
"mark_kyber_pre_key_used",
Box::new(error),
));
}
Ok(())
CallbackError::check(result).map_err(SignalProtocolError::for_application_callback(
"mark_kyber_pre_key_used",
))
}
}
@ -377,12 +338,9 @@ impl SessionStore for &FfiSessionStoreStruct {
let mut record = std::ptr::null_mut();
let result = (self.load_session)(self.ctx, &mut record, address);
if let Some(error) = CallbackError::check(result) {
return Err(SignalProtocolError::ApplicationCallbackError(
"load_session",
Box::new(error),
));
}
CallbackError::check(result).map_err(SignalProtocolError::for_application_callback(
"load_session",
))?;
if record.is_null() {
return Ok(None);
@ -400,14 +358,9 @@ impl SessionStore for &FfiSessionStoreStruct {
) -> Result<(), SignalProtocolError> {
let result = (self.store_session)(self.ctx, address, record);
if let Some(error) = CallbackError::check(result) {
return Err(SignalProtocolError::ApplicationCallbackError(
"store_session",
Box::new(error),
));
}
Ok(())
CallbackError::check(result).map_err(SignalProtocolError::for_application_callback(
"store_session",
))
}
}
@ -442,14 +395,9 @@ impl SenderKeyStore for &FfiSenderKeyStoreStruct {
) -> Result<(), SignalProtocolError> {
let result = (self.store_sender_key)(self.ctx, sender, distribution_id.as_bytes(), record);
if let Some(error) = CallbackError::check(result) {
return Err(SignalProtocolError::ApplicationCallbackError(
"store_sender_key",
Box::new(error),
));
}
Ok(())
CallbackError::check(result).map_err(SignalProtocolError::for_application_callback(
"store_sender_key",
))
}
async fn load_sender_key(
@ -461,12 +409,9 @@ impl SenderKeyStore for &FfiSenderKeyStoreStruct {
let result =
(self.load_sender_key)(self.ctx, &mut record, sender, distribution_id.as_bytes());
if let Some(error) = CallbackError::check(result) {
return Err(SignalProtocolError::ApplicationCallbackError(
"load_sender_key",
Box::new(error),
));
}
CallbackError::check(result).map_err(SignalProtocolError::for_application_callback(
"load_sender_key",
))?;
if record.is_null() {
return Ok(None);

View File

@ -102,3 +102,13 @@ pub enum SignalProtocolError {
/// bad KEM ciphertext length <{1}> for key with type <{0}>
BadKEMCiphertextLength(kem::KeyType, usize),
}
impl SignalProtocolError {
/// Convenience factory for [`SignalProtocolError::ApplicationCallbackError`].
#[inline]
pub fn for_application_callback<E: std::error::Error + Send + Sync + UnwindSafe + 'static>(
method: &'static str,
) -> impl FnOnce(E) -> Self {
move |error| Self::ApplicationCallbackError(method, Box::new(error))
}
}