0
0
mirror of https://github.com/signalapp/libsignal.git synced 2024-09-20 03:52:17 +02:00

Bridge message backup validation to node

Expose message backup at the bridge layer as a separate async function. Add a 
TS wrapper with the same interface as for the other app languages.
This commit is contained in:
Alex Konradi 2024-02-02 14:47:05 -05:00 committed by GitHub
parent 4888b9ba12
commit 11f7b0b231
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 345 additions and 123 deletions

6
node/Native.d.ts vendored
View File

@ -96,6 +96,11 @@ interface Wrapper<T> {
readonly _nativeHandle: T;
}
interface MessageBackupValidationOutcome {
errorMessage: string | null;
unknownFieldMessages: Array<string>;
}
// eslint-disable-next-line @typescript-eslint/no-unused-vars
type Serialized<T> = Buffer;
@ -237,6 +242,7 @@ export function LookupRequest_new(): LookupRequest;
export function LookupRequest_setReturnAcisWithoutUaks(request: Wrapper<LookupRequest>, returnAcisWithoutUaks: boolean): void;
export function LookupRequest_setToken(request: Wrapper<LookupRequest>, token: Buffer): void;
export function MessageBackupKey_New(masterKey: Buffer, aci: Buffer): MessageBackupKey;
export function MessageBackupValidator_Validate(key: Wrapper<MessageBackupKey>, firstStream: InputStream, secondStream: InputStream, len: Buffer): Promise<MessageBackupValidationOutcome>;
export function Mp4Sanitizer_Sanitize(input: InputStream, len: Buffer): Promise<SanitizedMetadata>;
export function PlaintextContent_Deserialize(data: Buffer): PlaintextContent;
export function PlaintextContent_FromDecryptionErrorMessage(m: Wrapper<DecryptionErrorMessage>): PlaintextContent;

93
node/ts/MessageBackup.ts Normal file
View File

@ -0,0 +1,93 @@
//
// Copyright 2024 Signal Messenger, LLC.
// SPDX-License-Identifier: AGPL-3.0-only
//
/**
* Message backup validation routines.
*
* @module MessageBackup
*/
import * as Native from '../Native';
import { Aci } from './Address';
import { InputStream } from './io';
import { bufferFromBigUInt64BE } from './zkgroup/internal/BigIntUtil';
export type InputStreamFactory = () => InputStream;
/**
* Result of validating a message backup bundle.
*/
export class ValidationOutcome {
/**
* A developer-facing message about the error encountered during validation,
* if any.
*/
public errorMessage: string | null;
/**
* Information about unknown fields encountered during validation.
*/
public unknownFieldMessages: string[];
/**
* `true` if the backup is valid, `false` otherwise.
*
* If this is `true`, there might still be messages about unknown fields.
*/
public get ok(): boolean {
return this.errorMessage == null;
}
constructor(outcome: Native.MessageBackupValidationOutcome) {
const { errorMessage, unknownFieldMessages } = outcome;
this.errorMessage = errorMessage;
this.unknownFieldMessages = unknownFieldMessages;
}
}
/**
* Key used to encrypt and decrypt a message backup bundle.
*/
export class MessageBackupKey {
readonly _nativeHandle: Native.MessageBackupKey;
/**
* Create a public key from the given master key and ACI.
*
* `masterKeyBytes` should contain exactly 32 bytes.
*/
public constructor(masterKeyBytes: Buffer, aci: Aci) {
this._nativeHandle = Native.MessageBackupKey_New(
masterKeyBytes,
aci.getServiceIdFixedWidthBinary()
);
}
}
/**
* Validate a backup file
*
* @param backupKey The key to use to decrypt the backup contents.
* @param inputFactory A function that returns new input streams that read the backup contents.
* @param length The exact length of the input stream.
* @returns The outcome of validation, including any errors and warnings.
* @throws IoError If an IO error on the input occurs.
*/
export async function validate(
backupKey: MessageBackupKey,
inputFactory: InputStreamFactory,
length: bigint
): Promise<ValidationOutcome> {
const firstStream = inputFactory();
const secondStream = inputFactory();
return new ValidationOutcome(
await Native.MessageBackupValidator_Validate(
backupKey,
firstStream,
secondStream,
bufferFromBigUInt64BE(length)
)
);
}

View File

@ -4,12 +4,12 @@
//
import { assert } from 'chai';
import { InputStream } from '../io';
import * as Mp4Sanitizer from '../Mp4Sanitizer';
import * as WebpSanitizer from '../WebpSanitizer';
import { SanitizedMetadata } from '../Mp4Sanitizer';
import * as util from './util';
import { ErrorCode, LibSignalErrorBase } from '../Errors';
import { ErrorInputStream, Uint8ArrayInputStream } from './ioutil';
util.initLogger();
@ -199,36 +199,3 @@ function assertSanitizedMetadataEqual(
assert.equal(sanitized.getDataOffset(), BigInt(dataOffset));
assert.equal(sanitized.getDataLen(), BigInt(dataLen));
}
class ErrorInputStream extends InputStream {
read(_amount: number): Promise<Buffer> {
throw new Error('test io error');
}
skip(_amount: number): Promise<void> {
throw new Error('test io error');
}
}
class Uint8ArrayInputStream extends InputStream {
data: Uint8Array;
constructor(data: Uint8Array) {
super();
this.data = data;
}
read(amount: number): Promise<Buffer> {
const read_amount = Math.min(amount, this.data.length);
const read_data = this.data.slice(0, read_amount);
this.data = this.data.slice(read_amount);
return Promise.resolve(Buffer.from(read_data));
}
skip(amount: number): Promise<void> {
if (amount > this.data.length) {
throw Error('skipped past end of data');
}
this.data = this.data.slice(amount);
return Promise.resolve();
}
}

View File

@ -0,0 +1,58 @@
//
// Copyright 2024 Signal Messenger, LLC.
// SPDX-License-Identifier: AGPL-3.0-only
//
import { assert } from 'chai';
import * as MessageBackup from '../MessageBackup';
import * as util from './util';
import { Aci } from '../Address';
import { Uint8ArrayInputStream, ErrorInputStream } from './ioutil';
import * as fs from 'node:fs';
import * as path from 'node:path';
import { LogLevel } from '..';
util.initLogger(LogLevel.Trace);
describe('MessageBackup', () => {
const masterKey = Buffer.from(new Uint8Array(32).fill('M'.charCodeAt(0)));
const aci = Aci.fromUuidBytes(new Uint8Array(16).fill(0x11));
const testKey = new MessageBackup.MessageBackupKey(masterKey, aci);
describe('validate', () => {
it('successfully validates a minimal backup', async () => {
const input = fs.readFileSync(
path.join(__dirname, '../../ts/test/new_account.binproto.encrypted')
);
const outcome = await MessageBackup.validate(
testKey,
() => new Uint8ArrayInputStream(input),
BigInt(input.length)
);
assert.equal(outcome.errorMessage, null);
});
it('produces an error message on empty input', async () => {
const outcome = await MessageBackup.validate(
testKey,
() => new Uint8ArrayInputStream(new Uint8Array()),
0n
);
assert.equal(outcome.errorMessage, 'not enough bytes for an HMAC');
});
it('throws a raised IO error', async () => {
try {
await MessageBackup.validate(
testKey,
() => new ErrorInputStream(),
BigInt(234)
);
assert.fail('did not throw');
} catch (e) {
assert.instanceOf(e, ErrorInputStream.Error);
}
});
});
});

41
node/ts/test/ioutil.ts Normal file
View File

@ -0,0 +1,41 @@
//
// Copyright 2024 Signal Messenger, LLC.
// SPDX-License-Identifier: AGPL-3.0-only
//
import { InputStream } from '../io';
export class ErrorInputStream extends InputStream {
public static Error = class extends Error {};
read(_amount: number): Promise<Buffer> {
throw new ErrorInputStream.Error();
}
skip(_amount: number): Promise<void> {
throw new ErrorInputStream.Error();
}
}
export class Uint8ArrayInputStream extends InputStream {
data: Uint8Array;
constructor(data: Uint8Array) {
super();
this.data = data;
}
read(amount: number): Promise<Buffer> {
const read_amount = Math.min(amount, this.data.length);
const read_data = this.data.slice(0, read_amount);
this.data = this.data.slice(read_amount);
return Promise.resolve(Buffer.from(read_data));
}
skip(amount: number): Promise<void> {
if (amount > this.data.length) {
throw Error('skipped past end of data');
}
this.data = this.data.slice(amount);
return Promise.resolve();
}
}

View File

@ -0,0 +1 @@
../../../rust/message-backup/tests/res/test-cases/valid/new_account.binproto.encrypted

View File

@ -96,6 +96,11 @@ interface Wrapper<T> {
readonly _nativeHandle: T;
}
interface MessageBackupValidationOutcome {
errorMessage: string | null;
unknownFieldMessages: Array<string>;
}
// eslint-disable-next-line @typescript-eslint/no-unused-vars
type Serialized<T> = Buffer;

View File

@ -160,7 +160,7 @@ impl From<signal_media::sanitize::mp4::Error> for SignalFfiError {
fn from(e: signal_media::sanitize::mp4::Error) -> SignalFfiError {
use signal_media::sanitize::mp4::Error;
match e {
Error::Io(e) => Self::Io(e.into()),
Error::Io(e) => Self::Io(e),
Error::Parse(e) => Self::Mp4SanitizeParse(e),
}
}
@ -171,7 +171,7 @@ impl From<signal_media::sanitize::webp::Error> for SignalFfiError {
fn from(e: signal_media::sanitize::webp::Error) -> SignalFfiError {
use signal_media::sanitize::webp::Error;
match e {
Error::Io(e) => Self::Io(e.into()),
Error::Io(e) => Self::Io(e),
Error::Parse(e) => Self::WebpSanitizeParse(e),
}
}

View File

@ -192,7 +192,7 @@ impl From<signal_media::sanitize::mp4::Error> for SignalJniError {
fn from(e: signal_media::sanitize::mp4::Error) -> Self {
use signal_media::sanitize::mp4::Error;
match e {
Error::Io(e) => Self::Io(e.into()),
Error::Io(e) => Self::Io(e),
Error::Parse(e) => Self::Mp4SanitizeParse(e),
}
}
@ -203,7 +203,7 @@ impl From<signal_media::sanitize::webp::Error> for SignalJniError {
fn from(e: signal_media::sanitize::webp::Error) -> Self {
use signal_media::sanitize::webp::Error;
match e {
Error::Io(e) => Self::Io(e.into()),
Error::Io(e) => Self::Io(e),
Error::Parse(e) => Self::WebpSanitizeParse(e),
}
}

View File

@ -3,21 +3,17 @@
// SPDX-License-Identifier: AGPL-3.0-only
//
use libsignal_bridge_macros::*;
#[cfg(any(feature = "jni", feature = "ffi"))]
use futures_util::FutureExt as _;
use libsignal_bridge_macros::*;
use libsignal_message_backup::frame::{
LimitedReaderFactory, ValidationError as FrameValidationError,
};
use libsignal_message_backup::key::{BackupKey, MessageBackupKey as MessageBackupKeyInner};
#[cfg(any(feature = "jni", feature = "ffi"))]
use libsignal_message_backup::parse::ParseError;
#[cfg(any(feature = "jni", feature = "ffi"))]
use libsignal_message_backup::Error;
#[cfg(any(feature = "jni", feature = "ffi"))]
use libsignal_message_backup::{BackupReader, FoundUnknownField, ReadResult};
use libsignal_message_backup::{BackupReader, Error, FoundUnknownField, ReadResult};
use libsignal_protocol::Aci;
#[cfg(any(feature = "jni", feature = "ffi"))]
use crate::io::{AsyncInput, InputStream};
use crate::support::*;
use crate::*;
@ -34,13 +30,11 @@ fn MessageBackupKey_New(master_key: &[u8; 32], aci: Aci) -> MessageBackupKey {
}
#[derive(Debug)]
#[cfg(any(feature = "jni", feature = "ffi"))]
enum MessageBackupValidationError {
Io(std::io::Error),
String(String),
}
#[cfg(any(feature = "jni", feature = "ffi"))]
impl From<Error> for MessageBackupValidationError {
fn from(value: Error) -> Self {
match value {
@ -53,7 +47,6 @@ impl From<Error> for MessageBackupValidationError {
}
}
#[cfg(any(feature = "jni", feature = "ffi"))]
impl From<FrameValidationError> for MessageBackupValidationError {
fn from(value: FrameValidationError) -> Self {
match value {
@ -65,7 +58,6 @@ impl From<FrameValidationError> for MessageBackupValidationError {
}
}
#[cfg(any(feature = "jni", feature = "ffi"))]
pub struct MessageBackupValidationOutcome {
pub(crate) error_message: Option<String>,
pub(crate) found_unknown_fields: Vec<FoundUnknownField>,
@ -91,8 +83,8 @@ fn MessageBackupValidationOutcome_getUnknownFields(
.collect()
}
#[bridge_fn(node = false)]
fn MessageBackupValidator_Validate(
#[bridge_fn]
async fn MessageBackupValidator_Validate(
key: &MessageBackupKey,
first_stream: &mut dyn InputStream,
second_stream: &mut dyn InputStream,
@ -106,27 +98,19 @@ fn MessageBackupValidator_Validate(
];
let factory = LimitedReaderFactory::new(streams);
let runtime = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.expect("failed to create runtime");
let (error, found_unknown_fields) =
match BackupReader::new_encrypted_compressed(key, factory).await {
Err(e) => (Some(e.into()), Vec::new()),
Ok(reader) => {
let ReadResult {
result,
found_unknown_fields,
} = reader.validate_all().await;
let (error, found_unknown_fields) = runtime.block_on(async move {
let reader = match BackupReader::new_encrypted_compressed(key, factory).await {
Ok(reader) => reader,
Err(e) => {
return (Some(e.into()), Vec::new());
(result.err().map(Into::into), found_unknown_fields)
}
};
let ReadResult {
result,
found_unknown_fields,
} = reader.validate_all().await;
(result.err().map(Into::into), found_unknown_fields)
});
let error_message = error
.map(|m| match m {
MessageBackupValidationError::Io(io) => Err(io),

View File

@ -811,6 +811,33 @@ impl<'a> ResultTypeInfo<'a> for () {
}
}
impl<'a> ResultTypeInfo<'a> for crate::message_backup::MessageBackupValidationOutcome {
type ResultType = JsObject;
fn convert_into(self, cx: &mut impl Context<'a>) -> JsResult<'a, Self::ResultType> {
let Self {
error_message,
found_unknown_fields,
} = self;
let error_message = error_message.convert_into(cx)?;
let unknown_field_messages = JsArray::new(
cx,
found_unknown_fields.len().try_into().expect("< u32::MAX"),
);
for (unknown, i) in found_unknown_fields.into_iter().zip(0..) {
let message = JsString::new(cx, unknown.to_string());
unknown_field_messages.set(cx, i, message)?;
}
let obj = JsObject::new(cx);
obj.set(cx, "errorMessage", error_message)?;
obj.set(cx, "unknownFieldMessages", unknown_field_messages)?;
Ok(obj)
}
}
impl<'a, T: Value> ResultTypeInfo<'a> for Handle<'a, T> {
type ResultType = T;
fn convert_into(self, _cx: &mut impl Context<'a>) -> NeonResult<Handle<'a, Self::ResultType>> {

View File

@ -2,13 +2,13 @@
// Copyright 2021 Signal Messenger, LLC.
// SPDX-License-Identifier: AGPL-3.0-only
//
use super::*;
use std::fmt;
use paste::paste;
use signal_media::sanitize::mp4::{Error as Mp4Error, ParseError as Mp4ParseError};
use signal_media::sanitize::webp::{Error as WebpError, ParseError as WebpParseError};
use std::fmt;
use super::*;
const ERRORS_PROPERTY_NAME: &str = "Errors";
const ERROR_CLASS_NAME: &str = "LibSignalErrorBase";
@ -66,6 +66,56 @@ fn new_js_error<'a>(
}
}
/// [`std::error::Error`] implementer that wraps a thrown value.
#[derive(Debug)]
pub(crate) enum ThrownException {
Error(Root<JsError>),
String(String),
}
impl ThrownException {
pub(crate) fn from_value<'a>(
cx: &mut CallContext<'a, JsObject>,
error: Handle<'a, JsValue>,
) -> Self {
if let Ok(e) = error.downcast::<JsError, _>(cx) {
ThrownException::Error(e.root(cx))
} else if let Ok(e) = error.downcast::<JsString, _>(cx) {
ThrownException::String(e.value(cx))
} else {
ThrownException::String(
error
.to_string(cx)
.expect("can convert to string")
.value(cx),
)
}
}
}
impl Default for ThrownException {
fn default() -> Self {
Self::String(String::default())
}
}
impl From<&str> for ThrownException {
fn from(value: &str) -> Self {
Self::String(value.to_string())
}
}
impl std::fmt::Display for ThrownException {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Error(r) => write!(f, "{:?}", r),
Self::String(s) => write!(f, "{}", s),
}
}
}
impl std::error::Error for ThrownException {}
pub trait SignalNodeError: Sized + fmt::Display {
fn throw<'a>(
self,
@ -319,6 +369,32 @@ impl SignalNodeError for WebpError {
}
}
impl SignalNodeError for std::io::Error {
fn throw<'a>(
mut self,
cx: &mut impl Context<'a>,
_module: Handle<'a, JsObject>,
_operation_name: &str,
) -> JsResult<'a, JsValue> {
let exception = (self.kind() == std::io::ErrorKind::Other)
.then(|| {
self.get_mut()
.and_then(|e| e.downcast_mut::<ThrownException>())
})
.flatten()
.map(std::mem::take);
match exception {
Some(ThrownException::Error(e)) => {
let inner = e.into_inner(cx);
cx.throw(inner)
}
Some(ThrownException::String(s)) => cx.throw_error(s),
None => cx.throw_error(self.to_string()),
}
}
}
impl SignalNodeError for libsignal_net::cdsi::Error {
fn throw<'a>(
self,

View File

@ -26,7 +26,7 @@ pub struct PromiseSettler<T, E> {
impl<T, E> PromiseSettler<T, E>
where
T: for<'a> ResultTypeInfo<'a> + std::panic::UnwindSafe + Send + 'static,
E: SignalNodeError + std::panic::UnwindSafe + Send + 'static,
E: SignalNodeError + Send + 'static,
{
/// Stores the information necessary to complete a JavaScript Promise.
///
@ -68,7 +68,7 @@ impl<T, E, U: Finalize + Send + 'static> FutureResultReporter<T, E, U> {
impl<T, E, U> ResultReporter for FutureResultReporter<T, E, U>
where
T: for<'a> ResultTypeInfo<'a> + std::panic::UnwindSafe + Send + 'static,
E: SignalNodeError + std::panic::UnwindSafe + Send + 'static,
E: SignalNodeError + Send + 'static,
U: Finalize + Send + 'static,
{
type Receiver = PromiseSettler<T, E>;
@ -108,10 +108,12 @@ where
// But if the panic is in *our* code, the context will be fine.
// And if Neon panics, there's not much we can do about it.
let mut cx = std::panic::AssertUnwindSafe(&mut cx);
std::panic::catch_unwind(move || match result {
Ok(success) => Ok(success.convert_into(*cx)?.upcast()),
Err(failure) => failure.throw(*cx, error_module, node_function_name),
})
match result {
Ok(success) => std::panic::catch_unwind(move || {
Ok(success.convert_into(*cx)?.upcast())
}),
Err(failure) => Ok(failure.throw(*cx, error_module, node_function_name)),
}
});
settled_result.unwrap_or_else(|panic| {
@ -171,7 +173,7 @@ where
F: Future + std::panic::UnwindSafe + 'static,
F::Output: ResultReporter<Receiver = PromiseSettler<O, E>>,
O: for<'a> ResultTypeInfo<'a> + Send + std::panic::UnwindSafe + 'static,
E: SignalNodeError + Send + std::panic::UnwindSafe + 'static,
E: SignalNodeError + Send + 'static,
{
let (deferred, promise) = cx.promise();
let completer = PromiseSettler::new(cx, deferred, node_function_name);

View File

@ -34,7 +34,7 @@ impl NodeInputStream {
}
}
async fn do_read(&self, amount: u32) -> Result<Vec<u8>, String> {
async fn do_read(&self, amount: u32) -> Result<Vec<u8>, ThrownException> {
let stream_object_shared = self.stream_object.clone();
let read_data = JsFuture::get_promise(&self.js_channel, move |cx| {
let stream_object = stream_object_shared.to_inner(cx);
@ -49,10 +49,7 @@ impl NodeInputStream {
Ok(b) => Ok(b.as_slice(cx).to_vec()),
Err(_) => Err("unexpected result from _read".into()),
},
Err(error) => Err(error
.to_string(cx)
.expect("can convert to string")
.value(cx)),
Err(error) => Err(ThrownException::from_value(cx, error)),
})
.await?;
if read_data.is_empty() {
@ -61,7 +58,7 @@ impl NodeInputStream {
Ok(read_data)
}
async fn do_skip(&self, amount: u64) -> Result<(), String> {
async fn do_skip(&self, amount: u64) -> Result<(), ThrownException> {
let amount = amount as f64;
if amount > MAX_SAFE_JS_INTEGER {
return Err("skipped more than fits in JsInteger".into());
@ -81,10 +78,7 @@ impl NodeInputStream {
Ok(_) => Ok(()),
Err(_) => Err("unexpected result from _skip".into()),
},
Err(error) => Err(error
.to_string(cx)
.expect("can convert to string")
.value(cx)),
Err(error) => Err(ThrownException::from_value(cx, error)),
})
.await
}

View File

@ -4,25 +4,17 @@ use mediasan_common::error::ReportableError;
// cbindgen does not like this being called simply `Error`.
/// Error type returned by [`sanitize_*`].
#[derive(Clone, Debug, thiserror::Error)]
#[derive(Debug, thiserror::Error)]
pub enum SanitizerError<E> {
/// An IO error while reading the media.
#[error("{0}")]
Io(IoError),
Io(io::Error),
/// An error parsing the media stream.
#[error("{0}")]
Parse(ParseErrorReport<E>),
}
#[derive(Clone, Debug, thiserror::Error)]
#[error("IO error: {kind}: {message}")]
/// A decomposed and stringified [`io::Error'].
pub struct IoError {
pub kind: io::ErrorKind,
pub message: String,
}
#[derive(Clone, Debug, thiserror::Error)]
#[error("Parse error: {kind}\n{report}")]
/// A decomposed and stringified [`error_stack::Report<ParseError>`](mediasan_common::Error::Parse).
@ -37,13 +29,7 @@ pub struct ParseErrorReport<E> {
impl<E: Clone + ReportableError> From<mediasan_common::Error<E>> for SanitizerError<E> {
fn from(from: mediasan_common::Error<E>) -> Self {
match from {
mediasan_common::Error::Io(err) => Self::Io(IoError {
kind: err.kind(),
message: err
.into_inner()
.map(|err| format!("{err:?}"))
.unwrap_or_default(),
}),
mediasan_common::Error::Io(err) => Self::Io(err),
mediasan_common::Error::Parse(err) => Self::Parse(ParseErrorReport {
kind: err.get_ref().clone(),
report: format!("{err:?}"),
@ -51,21 +37,3 @@ impl<E: Clone + ReportableError> From<mediasan_common::Error<E>> for SanitizerEr
}
}
}
impl<E> From<io::Error> for SanitizerError<E> {
fn from(from: io::Error) -> Self {
Self::Io(IoError {
kind: from.kind(),
message: from
.into_inner()
.map(|err| format!("{err:?}"))
.unwrap_or_default(),
})
}
}
impl From<IoError> for io::Error {
fn from(from: IoError) -> Self {
io::Error::new(from.kind, from.message)
}
}