mirror of
https://github.com/signalapp/libsignal.git
synced 2024-09-19 19:42:19 +02:00
ffi: Expose cancellation to Swift
This commit is contained in:
parent
6d3c192208
commit
7dc63b99af
@ -635,6 +635,7 @@ public final class Native {
|
||||
public static native CompletableFuture<Integer> TESTING_FutureSuccess(long asyncRuntime, int input);
|
||||
public static native CompletableFuture<Void> TESTING_FutureThrowsCustomErrorType(long asyncRuntime);
|
||||
public static native void TESTING_NonSuspendingBackgroundThreadRuntime_Destroy(long handle);
|
||||
public static native CompletableFuture TESTING_OnlyCompletesByCancellation(long asyncRuntime);
|
||||
public static native String TESTING_OtherTestingHandleType_getValue(long handle);
|
||||
public static native void TESTING_PanicInBodyAsync(Object input);
|
||||
public static native CompletableFuture TESTING_PanicInBodyIo(long asyncRuntime, Object input);
|
||||
@ -655,6 +656,7 @@ public final class Native {
|
||||
public static native void TestingHandleType_Destroy(long handle);
|
||||
|
||||
public static native void TokioAsyncContext_Destroy(long handle);
|
||||
public static native void TokioAsyncContext_cancel(long context, long rawCancellationId);
|
||||
public static native long TokioAsyncContext_new();
|
||||
|
||||
public static native long UnidentifiedSenderMessageContent_Deserialize(byte[] data) throws Exception;
|
||||
|
2
node/Native.d.ts
vendored
2
node/Native.d.ts
vendored
@ -490,6 +490,7 @@ export function TESTING_FutureProducesOtherPointerType(asyncRuntime: Wrapper<Non
|
||||
export function TESTING_FutureProducesPointerType(asyncRuntime: Wrapper<NonSuspendingBackgroundThreadRuntime>, input: number): Promise<TestingHandleType>;
|
||||
export function TESTING_FutureSuccess(asyncRuntime: Wrapper<NonSuspendingBackgroundThreadRuntime>, input: number): Promise<number>;
|
||||
export function TESTING_NonSuspendingBackgroundThreadRuntime_New(): NonSuspendingBackgroundThreadRuntime;
|
||||
export function TESTING_OnlyCompletesByCancellation(asyncRuntime: Wrapper<TokioAsyncContext>): Promise<void>;
|
||||
export function TESTING_OtherTestingHandleType_getValue(handle: Wrapper<OtherTestingHandleType>): string;
|
||||
export function TESTING_PanicInBodyAsync(_input: null): Promise<void>;
|
||||
export function TESTING_PanicInBodyIo(asyncRuntime: Wrapper<NonSuspendingBackgroundThreadRuntime>, _input: null): Promise<void>;
|
||||
@ -506,6 +507,7 @@ export function TESTING_PanicOnReturnSync(_needsCleanup: null): null;
|
||||
export function TESTING_ProcessBytestringArray(input: Buffer[]): Buffer[];
|
||||
export function TESTING_ReturnStringArray(): string[];
|
||||
export function TESTING_TestingHandleType_getValue(handle: Wrapper<TestingHandleType>): number;
|
||||
export function TokioAsyncContext_cancel(context: Wrapper<TokioAsyncContext>, rawCancellationId: bigint): void;
|
||||
export function TokioAsyncContext_new(): TokioAsyncContext;
|
||||
export function UnidentifiedSenderMessageContent_Deserialize(data: Buffer): UnidentifiedSenderMessageContent;
|
||||
export function UnidentifiedSenderMessageContent_GetContentHint(m: Wrapper<UnidentifiedSenderMessageContent>): number;
|
||||
|
@ -64,6 +64,8 @@ renaming_overrides_prefixing = true
|
||||
"FfiOptionalServiceIdFixedWidthBinaryBytes" = "SignalOptionalServiceIdFixedWidthBinaryBytes"
|
||||
"CPromisec_void" = "SignalCPromiseRawPointer"
|
||||
|
||||
"RawCancellationId" = "SignalCancellationId"
|
||||
|
||||
# Avoid double-prefixing these
|
||||
"SignalFfiError" = "SignalFfiError"
|
||||
"SignalErrorCode" = "SignalErrorCode"
|
||||
|
@ -25,6 +25,7 @@ pub enum SignalErrorCode {
|
||||
InvalidArgument = 5,
|
||||
InvalidType = 6,
|
||||
InvalidUtf8String = 7,
|
||||
Cancelled = 8,
|
||||
|
||||
ProtobufError = 10,
|
||||
|
||||
@ -105,6 +106,8 @@ impl From<&SignalFfiError> for SignalErrorCode {
|
||||
|
||||
SignalFfiError::InvalidUtf8String => SignalErrorCode::InvalidUtf8String,
|
||||
|
||||
SignalFfiError::Cancelled => SignalErrorCode::Cancelled,
|
||||
|
||||
SignalFfiError::Signal(SignalProtocolError::InvalidProtobufEncoding) => {
|
||||
SignalErrorCode::ProtobufError
|
||||
}
|
||||
|
@ -167,10 +167,16 @@ fn bridge_io_body(
|
||||
|__cancel| async move {
|
||||
let __future = ffi::catch_unwind(std::panic::AssertUnwindSafe(async move {
|
||||
#(#input_loading)*
|
||||
let __result = #orig_name(#(#input_names),*).await;
|
||||
// If the original function can't fail, wrap the result in Ok for uniformity.
|
||||
// See TransformHelper::ok_if_needed.
|
||||
Ok(TransformHelper(__result).ok_if_needed()?.0)
|
||||
::tokio::select! {
|
||||
__result = #orig_name(#(#input_names),*) => {
|
||||
// If the original function can't fail, wrap the result in Ok for uniformity.
|
||||
// See TransformHelper::ok_if_needed.
|
||||
Ok(TransformHelper(__result).ok_if_needed()?.0)
|
||||
}
|
||||
_ = __cancel => {
|
||||
Err(ffi::SignalFfiError::Cancelled)
|
||||
}
|
||||
}
|
||||
}));
|
||||
ffi::FutureResultReporter::new(__future.await)
|
||||
}
|
||||
|
@ -54,6 +54,7 @@ pub enum SignalFfiError {
|
||||
NullPointer,
|
||||
InvalidUtf8String,
|
||||
InvalidArgument(String),
|
||||
Cancelled,
|
||||
InternalError(String),
|
||||
UnexpectedPanic(std::boxed::Box<dyn std::any::Any + std::marker::Send>),
|
||||
}
|
||||
@ -102,6 +103,7 @@ impl fmt::Display for SignalFfiError {
|
||||
SignalFfiError::NullPointer => write!(f, "null pointer"),
|
||||
SignalFfiError::InvalidUtf8String => write!(f, "invalid UTF8 string"),
|
||||
SignalFfiError::InvalidArgument(msg) => write!(f, "invalid argument: {msg}"),
|
||||
SignalFfiError::Cancelled => write!(f, "cancelled"),
|
||||
SignalFfiError::InternalError(msg) => write!(f, "internal error: {msg}"),
|
||||
SignalFfiError::UnexpectedPanic(e) => {
|
||||
write!(f, "unexpected panic: {}", describe_panic(e))
|
||||
|
@ -10,6 +10,8 @@ use futures_util::{FutureExt, TryFutureExt};
|
||||
|
||||
use std::future::Future;
|
||||
|
||||
pub type RawCancellationId = u64;
|
||||
|
||||
/// A C callback used to report the results of Rust futures.
|
||||
///
|
||||
/// cbindgen will produce independent C types like `SignalCPromisei32` and
|
||||
@ -26,6 +28,7 @@ pub struct CPromise<T> {
|
||||
context: *const std::ffi::c_void,
|
||||
),
|
||||
context: *const std::ffi::c_void,
|
||||
cancellation_id: RawCancellationId,
|
||||
}
|
||||
|
||||
/// Keeps track of the information necessary to report a promise result back to C.
|
||||
@ -121,7 +124,8 @@ pub fn run_future_on_runtime<R, F, O>(
|
||||
O: ResultTypeInfo + 'static,
|
||||
{
|
||||
let completion = PromiseCompleter { promise: *promise };
|
||||
runtime.run_future(future, completion);
|
||||
let cancellation_id = runtime.run_future(future, completion);
|
||||
promise.cancellation_id = cancellation_id.into();
|
||||
}
|
||||
|
||||
/// Catches panics that occur in `future` and converts them to [`SignalFfiError::UnexpectedPanic`].
|
||||
|
@ -38,6 +38,11 @@ fn TokioAsyncContext_new() -> TokioAsyncContext {
|
||||
}
|
||||
}
|
||||
|
||||
#[bridge_fn]
|
||||
fn TokioAsyncContext_cancel(context: &TokioAsyncContext, raw_cancellation_id: u64) {
|
||||
context.cancel(raw_cancellation_id.into())
|
||||
}
|
||||
|
||||
pub struct TokioContextCancellation(tokio::sync::oneshot::Receiver<()>);
|
||||
|
||||
impl Future for TokioContextCancellation {
|
||||
@ -52,6 +57,12 @@ impl Future for TokioContextCancellation {
|
||||
}
|
||||
}
|
||||
|
||||
// Not ideal! tokio doesn't promise that a oneshot::Receiver is in fact panic-safe.
|
||||
// But its interior mutable state is only modified by the Receiver while it's being polled,
|
||||
// and that means a panic would have to happen inside Receiver itself to cause a problem.
|
||||
// Combined with our payload type being (), it's unlikely this can happen in practice.
|
||||
impl std::panic::UnwindSafe for TokioContextCancellation {}
|
||||
|
||||
impl AsyncRuntimeBase for TokioAsyncContext {
|
||||
fn cancel(&self, cancellation_token: CancellationId) {
|
||||
if cancellation_token == CancellationId::NotSupported {
|
||||
|
@ -49,6 +49,11 @@ async fn TESTING_CdsiLookupResponseConvert() -> LookupResponse {
|
||||
}
|
||||
}
|
||||
|
||||
#[bridge_io(TokioAsyncContext)]
|
||||
async fn TESTING_OnlyCompletesByCancellation() {
|
||||
std::future::pending::<()>().await
|
||||
}
|
||||
|
||||
#[repr(u8)]
|
||||
#[derive(Copy, Clone, strum::EnumString)]
|
||||
enum TestingCdsiLookupError {
|
||||
|
@ -21,5 +21,7 @@ file_header:
|
||||
inclusive_language:
|
||||
override_allowed_terms:
|
||||
- master
|
||||
nesting:
|
||||
type_level: 2
|
||||
excluded:
|
||||
- .build/**
|
||||
|
@ -3,6 +3,7 @@
|
||||
// SPDX-License-Identifier: AGPL-3.0-only
|
||||
//
|
||||
|
||||
import Foundation
|
||||
import SignalFfi
|
||||
|
||||
/// Used to check types for values produced asynchronously by Rust.
|
||||
@ -15,6 +16,12 @@ import SignalFfi
|
||||
/// Note that implementing this is **unchecked;** make sure you match up the types correctly!
|
||||
internal protocol PromiseStruct {
|
||||
associatedtype Result
|
||||
|
||||
// We can't declare the 'complete' callback without an associated type,
|
||||
// and that associated type won't get inferred for an imported struct (not sure why).
|
||||
// So we'd have to write out the callback type for every conformer.
|
||||
var context: UnsafeRawPointer! { get }
|
||||
var cancellation_id: SignalCancellationId { get }
|
||||
}
|
||||
|
||||
extension SignalCPromisebool: PromiseStruct {
|
||||
@ -94,7 +101,7 @@ private class Completer<Promise: PromiseStruct>: CompleterBase {
|
||||
///
|
||||
/// This retains `self`, to be used as the C context pointer for the callback.
|
||||
/// You must ensure that either the callback is called, or the result is passed to
|
||||
/// ``destroyUncompletedPromiseStruct(_:)``.
|
||||
/// ``cleanUpUncompletedPromiseStruct(_:)``.
|
||||
func makePromiseStruct() -> Promise {
|
||||
typealias RawPromiseCallback = @convention(c) (_ error: SignalFfiErrorRef?, _ value: UnsafeRawPointer?, _ context: UnsafeRawPointer?) -> Void
|
||||
let completeOpaque: RawPromiseCallback = { error, value, context in
|
||||
@ -108,7 +115,7 @@ private class Completer<Promise: PromiseStruct>: CompleterBase {
|
||||
// because of how `self.completeUnsafe` is initialized.
|
||||
// So first we build a promise struct---it doesn't matter which one---by reinterpreting the callback...
|
||||
typealias RawPointerPromiseCallback = @convention(c) (_ error: SignalFfiErrorRef?, _ value: UnsafePointer<UnsafeRawPointer?>?, _ context: UnsafeRawPointer?) -> Void
|
||||
let rawPromiseStruct = SignalCPromiseRawPointer(complete: unsafeBitCast(completeOpaque, to: RawPointerPromiseCallback.self), context: Unmanaged.passRetained(self).toOpaque())
|
||||
let rawPromiseStruct = SignalCPromiseRawPointer(complete: unsafeBitCast(completeOpaque, to: RawPointerPromiseCallback.self), context: Unmanaged.passRetained(self).toOpaque(), cancellation_id: 0)
|
||||
|
||||
// ...And then we reinterpret the entire struct, because all promise structs *also* have the same layout.
|
||||
// (Which we at least check a little bit here.)
|
||||
@ -119,11 +126,8 @@ private class Completer<Promise: PromiseStruct>: CompleterBase {
|
||||
return unsafeBitCast(rawPromiseStruct, to: Promise.self)
|
||||
}
|
||||
|
||||
func destroyUncompletedPromiseStruct(_ promiseStruct: Promise) {
|
||||
// Double-check that all promise structs have the same layout, then reverse what we did above.
|
||||
precondition(MemoryLayout<SignalCPromiseRawPointer>.size == MemoryLayout<Promise>.size)
|
||||
let rawPromiseStruct = unsafeBitCast(promiseStruct, to: SignalCPromiseRawPointer.self)
|
||||
Unmanaged<CompleterBase>.fromOpaque(rawPromiseStruct.context!).release()
|
||||
func cleanUpUncompletedPromiseStruct(_ promiseStruct: Promise) {
|
||||
Unmanaged<CompleterBase>.fromOpaque(promiseStruct.context!).release()
|
||||
}
|
||||
}
|
||||
|
||||
@ -136,8 +140,12 @@ private class Completer<Promise: PromiseStruct>: CompleterBase {
|
||||
/// signal_do_async_work($0, someInput, someOtherInput)
|
||||
/// }
|
||||
/// ```
|
||||
///
|
||||
/// Prefer ``TokioAsyncContext/invokeAsyncFunction(_:)`` if using a TokioAsyncContext;
|
||||
/// that method supports cancellation.
|
||||
internal func invokeAsyncFunction<Promise: PromiseStruct>(
|
||||
_ body: (UnsafeMutablePointer<Promise>) -> SignalFfiErrorRef?
|
||||
_ body: (UnsafeMutablePointer<Promise>) -> SignalFfiErrorRef?,
|
||||
saveCancellationId: (SignalCancellationId) -> Void = { _ in }
|
||||
) async throws -> Promise.Result {
|
||||
try await withCheckedThrowingContinuation { continuation in
|
||||
let completer = Completer<Promise>(continuation: continuation)
|
||||
@ -145,9 +153,10 @@ internal func invokeAsyncFunction<Promise: PromiseStruct>(
|
||||
let startResult = body(&promiseStruct)
|
||||
if let error = startResult {
|
||||
// Our completion callback is never going to get called, so we need to balance the `passRetained` above.
|
||||
completer.destroyUncompletedPromiseStruct(promiseStruct)
|
||||
completer.cleanUpUncompletedPromiseStruct(promiseStruct)
|
||||
completer.completeUnsafe(error, nil)
|
||||
return
|
||||
}
|
||||
saveCancellationId(promiseStruct.cancellation_id)
|
||||
}
|
||||
}
|
||||
|
@ -139,11 +139,9 @@ public class ChatService: NativeHandleOwner {
|
||||
/// Calling this method will result in starting to accept incoming requests from the Chat Service.
|
||||
@discardableResult
|
||||
public func connectAuthenticated() async throws -> DebugInfo {
|
||||
let rawDebugInfo = try await invokeAsyncFunction { promise in
|
||||
self.tokioAsyncContext.withNativeHandle { tokioAsyncContext in
|
||||
withNativeHandle { chatService in
|
||||
signal_chat_service_connect_auth(promise, tokioAsyncContext, chatService)
|
||||
}
|
||||
let rawDebugInfo = try await self.tokioAsyncContext.invokeAsyncFunction { promise, tokioAsyncContext in
|
||||
withNativeHandle { chatService in
|
||||
signal_chat_service_connect_auth(promise, tokioAsyncContext, chatService)
|
||||
}
|
||||
}
|
||||
return DebugInfo(consuming: rawDebugInfo)
|
||||
@ -155,11 +153,9 @@ public class ChatService: NativeHandleOwner {
|
||||
/// reconnect attempt will be made.
|
||||
@discardableResult
|
||||
public func connectUnauthenticated() async throws -> DebugInfo {
|
||||
let rawDebugInfo = try await invokeAsyncFunction { promise in
|
||||
self.tokioAsyncContext.withNativeHandle { tokioAsyncContext in
|
||||
withNativeHandle { chatService in
|
||||
signal_chat_service_connect_unauth(promise, tokioAsyncContext, chatService)
|
||||
}
|
||||
let rawDebugInfo = try await self.tokioAsyncContext.invokeAsyncFunction { promise, tokioAsyncContext in
|
||||
withNativeHandle { chatService in
|
||||
signal_chat_service_connect_unauth(promise, tokioAsyncContext, chatService)
|
||||
}
|
||||
}
|
||||
return DebugInfo(consuming: rawDebugInfo)
|
||||
@ -174,11 +170,9 @@ public class ChatService: NativeHandleOwner {
|
||||
///
|
||||
/// Returns when the disconnection is complete.
|
||||
public func disconnect() async throws {
|
||||
_ = try await invokeAsyncFunction { promise in
|
||||
self.tokioAsyncContext.withNativeHandle { tokioAsyncContext in
|
||||
withNativeHandle { chatService in
|
||||
signal_chat_service_disconnect(promise, tokioAsyncContext, chatService)
|
||||
}
|
||||
_ = try await self.tokioAsyncContext.invokeAsyncFunction { promise, tokioAsyncContext in
|
||||
withNativeHandle { chatService in
|
||||
signal_chat_service_disconnect(promise, tokioAsyncContext, chatService)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -190,12 +184,10 @@ public class ChatService: NativeHandleOwner {
|
||||
public func unauthenticatedSend(_ request: Request) async throws -> Response {
|
||||
let internalRequest = try InternalRequest(request)
|
||||
let timeoutMillis = request.timeoutMillis
|
||||
let rawResponse: SignalFfiChatResponse = try await invokeAsyncFunction { promise in
|
||||
self.tokioAsyncContext.withNativeHandle { tokioAsyncContext in
|
||||
withNativeHandle { chatService in
|
||||
internalRequest.withNativeHandle { request in
|
||||
signal_chat_service_unauth_send(promise, tokioAsyncContext, chatService, request, timeoutMillis)
|
||||
}
|
||||
let rawResponse: SignalFfiChatResponse = try await self.tokioAsyncContext.invokeAsyncFunction { promise, tokioAsyncContext in
|
||||
withNativeHandle { chatService in
|
||||
internalRequest.withNativeHandle { request in
|
||||
signal_chat_service_unauth_send(promise, tokioAsyncContext, chatService, request, timeoutMillis)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -212,12 +204,10 @@ public class ChatService: NativeHandleOwner {
|
||||
public func unauthenticatedSendAndDebug(_ request: Request) async throws -> (Response, DebugInfo) {
|
||||
let internalRequest = try InternalRequest(request)
|
||||
let timeoutMillis = request.timeoutMillis
|
||||
let rawResponse: SignalFfiResponseAndDebugInfo = try await invokeAsyncFunction { promise in
|
||||
self.tokioAsyncContext.withNativeHandle { tokioAsyncContext in
|
||||
withNativeHandle { chatService in
|
||||
internalRequest.withNativeHandle { request in
|
||||
signal_chat_service_unauth_send_and_debug(promise, tokioAsyncContext, chatService, request, timeoutMillis)
|
||||
}
|
||||
let rawResponse: SignalFfiResponseAndDebugInfo = try await self.tokioAsyncContext.invokeAsyncFunction { promise, tokioAsyncContext in
|
||||
withNativeHandle { chatService in
|
||||
internalRequest.withNativeHandle { request in
|
||||
signal_chat_service_unauth_send_and_debug(promise, tokioAsyncContext, chatService, request, timeoutMillis)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -231,12 +221,10 @@ public class ChatService: NativeHandleOwner {
|
||||
public func authenticatedSend(_ request: Request) async throws -> Response {
|
||||
let internalRequest = try InternalRequest(request)
|
||||
let timeoutMillis = request.timeoutMillis
|
||||
let rawResponse: SignalFfiChatResponse = try await invokeAsyncFunction { promise in
|
||||
self.tokioAsyncContext.withNativeHandle { tokioAsyncContext in
|
||||
withNativeHandle { chatService in
|
||||
internalRequest.withNativeHandle { request in
|
||||
signal_chat_service_auth_send(promise, tokioAsyncContext, chatService, request, timeoutMillis)
|
||||
}
|
||||
let rawResponse: SignalFfiChatResponse = try await self.tokioAsyncContext.invokeAsyncFunction { promise, tokioAsyncContext in
|
||||
withNativeHandle { chatService in
|
||||
internalRequest.withNativeHandle { request in
|
||||
signal_chat_service_auth_send(promise, tokioAsyncContext, chatService, request, timeoutMillis)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -253,12 +241,10 @@ public class ChatService: NativeHandleOwner {
|
||||
public func authenticatedSendAndDebug(_ request: Request) async throws -> (Response, DebugInfo) {
|
||||
let internalRequest = try InternalRequest(request)
|
||||
let timeoutMillis = request.timeoutMillis
|
||||
let rawResponse: SignalFfiResponseAndDebugInfo = try await invokeAsyncFunction { promise in
|
||||
self.tokioAsyncContext.withNativeHandle { tokioAsyncContext in
|
||||
withNativeHandle { chatService in
|
||||
internalRequest.withNativeHandle { request in
|
||||
signal_chat_service_auth_send_and_debug(promise, tokioAsyncContext, chatService, request, timeoutMillis)
|
||||
}
|
||||
let rawResponse: SignalFfiResponseAndDebugInfo = try await self.tokioAsyncContext.invokeAsyncFunction { promise, tokioAsyncContext in
|
||||
withNativeHandle { chatService in
|
||||
internalRequest.withNativeHandle { request in
|
||||
signal_chat_service_auth_send_and_debug(promise, tokioAsyncContext, chatService, request, timeoutMillis)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -80,6 +80,9 @@ internal func checkError(_ error: SignalFfiErrorRef?) throws {
|
||||
defer { signal_error_free(error) }
|
||||
|
||||
switch SignalErrorCode(errType) {
|
||||
case SignalErrorCodeCancelled:
|
||||
// Special case: don't use SignalError for this one.
|
||||
throw CancellationError()
|
||||
case SignalErrorCodeInvalidState:
|
||||
throw SignalError.invalidState(errStr)
|
||||
case SignalErrorCodeInternalError:
|
||||
|
@ -108,12 +108,10 @@ public class Net {
|
||||
auth: Auth,
|
||||
request: CdsiLookupRequest
|
||||
) async throws -> CdsiLookup {
|
||||
let handle: OpaquePointer = try await invokeAsyncFunction { promise in
|
||||
self.asyncContext.withNativeHandle { asyncContext in
|
||||
self.connectionManager.withNativeHandle { connectionManager in
|
||||
request.withNativeHandle { request in
|
||||
signal_cdsi_lookup_new(promise, asyncContext, connectionManager, auth.username, auth.password, request)
|
||||
}
|
||||
let handle: OpaquePointer = try await self.asyncContext.invokeAsyncFunction { promise, asyncContext in
|
||||
self.connectionManager.withNativeHandle { connectionManager in
|
||||
request.withNativeHandle { request in
|
||||
signal_cdsi_lookup_new(promise, asyncContext, connectionManager, auth.username, auth.password, request)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -264,11 +262,9 @@ public class CdsiLookup {
|
||||
/// `SignalError.networkError` for a network-level connectivity issue,
|
||||
/// `SignalError.networkProtocolError` for a CDSI or attested connection protocol issue.
|
||||
public func complete() async throws -> CdsiLookupResponse {
|
||||
let response: SignalFfiCdsiLookupResponse = try await invokeAsyncFunction { promise in
|
||||
self.asyncContext.withNativeHandle { asyncContext in
|
||||
self.native.withNativeHandle { handle in
|
||||
signal_cdsi_lookup_complete(promise, asyncContext, handle)
|
||||
}
|
||||
let response: SignalFfiCdsiLookupResponse = try await self.asyncContext.invokeAsyncFunction { promise, asyncContext in
|
||||
self.native.withNativeHandle { handle in
|
||||
signal_cdsi_lookup_complete(promise, asyncContext, handle)
|
||||
}
|
||||
}
|
||||
|
||||
@ -362,18 +358,6 @@ extension CdsiLookupResponseEntry: Equatable {
|
||||
}
|
||||
}
|
||||
|
||||
internal class TokioAsyncContext: NativeHandleOwner {
|
||||
convenience init() {
|
||||
var handle: OpaquePointer?
|
||||
failOnError(signal_tokio_async_context_new(&handle))
|
||||
self.init(owned: handle!)
|
||||
}
|
||||
|
||||
override internal class func destroyNativeHandle(_ handle: OpaquePointer) -> SignalFfiErrorRef? {
|
||||
signal_tokio_async_context_destroy(handle)
|
||||
}
|
||||
}
|
||||
|
||||
internal class ConnectionManager: NativeHandleOwner {
|
||||
convenience init(env: Net.Environment, userAgent: String) {
|
||||
var handle: OpaquePointer?
|
||||
|
@ -93,21 +93,19 @@ public class Svr3Client {
|
||||
maxTries: UInt32,
|
||||
auth: Auth
|
||||
) async throws -> [UInt8] {
|
||||
let output = try await invokeAsyncFunction { promise in
|
||||
self.asyncContext.withNativeHandle { asyncContext in
|
||||
self.connectionManager.withNativeHandle { connectionManager in
|
||||
secret.withUnsafeBorrowedBuffer { secretBuffer in
|
||||
signal_svr3_backup(
|
||||
promise,
|
||||
asyncContext,
|
||||
connectionManager,
|
||||
secretBuffer,
|
||||
password,
|
||||
maxTries,
|
||||
auth.username,
|
||||
auth.password
|
||||
)
|
||||
}
|
||||
let output = try await self.asyncContext.invokeAsyncFunction { promise, asyncContext in
|
||||
self.connectionManager.withNativeHandle { connectionManager in
|
||||
secret.withUnsafeBorrowedBuffer { secretBuffer in
|
||||
signal_svr3_backup(
|
||||
promise,
|
||||
asyncContext,
|
||||
connectionManager,
|
||||
secretBuffer,
|
||||
password,
|
||||
maxTries,
|
||||
auth.username,
|
||||
auth.password
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -159,20 +157,18 @@ public class Svr3Client {
|
||||
shareSet: some ContiguousBytes,
|
||||
auth: Auth
|
||||
) async throws -> [UInt8] {
|
||||
let output = try await invokeAsyncFunction { promise in
|
||||
self.asyncContext.withNativeHandle { asyncContext in
|
||||
self.connectionManager.withNativeHandle { connectionManager in
|
||||
shareSet.withUnsafeBorrowedBuffer { shareSetBuffer in
|
||||
signal_svr3_restore(
|
||||
promise,
|
||||
asyncContext,
|
||||
connectionManager,
|
||||
password,
|
||||
shareSetBuffer,
|
||||
auth.username,
|
||||
auth.password
|
||||
)
|
||||
}
|
||||
let output = try await self.asyncContext.invokeAsyncFunction { promise, asyncContext in
|
||||
self.connectionManager.withNativeHandle { connectionManager in
|
||||
shareSet.withUnsafeBorrowedBuffer { shareSetBuffer in
|
||||
signal_svr3_restore(
|
||||
promise,
|
||||
asyncContext,
|
||||
connectionManager,
|
||||
password,
|
||||
shareSetBuffer,
|
||||
auth.username,
|
||||
auth.password
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
127
swift/Sources/LibSignalClient/TokioAsyncContext.swift
Normal file
127
swift/Sources/LibSignalClient/TokioAsyncContext.swift
Normal file
@ -0,0 +1,127 @@
|
||||
//
|
||||
// Copyright 2024 Signal Messenger, LLC.
|
||||
// SPDX-License-Identifier: AGPL-3.0-only
|
||||
//
|
||||
|
||||
import Foundation
|
||||
import SignalFfi
|
||||
|
||||
#if canImport(SignalCoreKit)
|
||||
import SignalCoreKit
|
||||
#endif
|
||||
|
||||
internal class TokioAsyncContext: NativeHandleOwner {
|
||||
convenience init() {
|
||||
var handle: OpaquePointer?
|
||||
failOnError(signal_tokio_async_context_new(&handle))
|
||||
self.init(owned: handle!)
|
||||
}
|
||||
|
||||
override internal class func destroyNativeHandle(_ handle: OpaquePointer) -> SignalFfiErrorRef? {
|
||||
signal_tokio_async_context_destroy(handle)
|
||||
}
|
||||
|
||||
/// A thread-safe helper for translating Swift task cancellations into calls to
|
||||
/// `signal_tokio_async_context_cancel`.
|
||||
private class CancellationHandoffHelper {
|
||||
enum State {
|
||||
case initial
|
||||
case started(SignalCancellationId)
|
||||
case cancelled
|
||||
}
|
||||
|
||||
// Emulates Rust's `Mutex<State>` (and the containing class is providing an `Arc`)
|
||||
// Unfortunately, doing this in Swift requires a separate allocation for the lock today.
|
||||
var state: State = .initial
|
||||
var lock = NSLock()
|
||||
|
||||
let context: TokioAsyncContext
|
||||
|
||||
init(context: TokioAsyncContext) {
|
||||
self.context = context
|
||||
}
|
||||
|
||||
func setCancellationId(_ id: SignalCancellationId) {
|
||||
// Ideally we would use NSLock.withLock here, but that's not available on Linux,
|
||||
// which we still support for development and CI.
|
||||
do {
|
||||
self.lock.lock()
|
||||
defer { self.lock.unlock() }
|
||||
|
||||
switch self.state {
|
||||
case .initial:
|
||||
self.state = .started(id)
|
||||
fallthrough
|
||||
case .started(_):
|
||||
return
|
||||
case .cancelled:
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// If we didn't early-exit, we're already cancelled.
|
||||
self.cancel(id)
|
||||
}
|
||||
|
||||
func cancel() {
|
||||
let cancelId: SignalCancellationId
|
||||
// Ideally we would use NSLock.withLock here, but that's not available on Linux,
|
||||
// which we still support for development and CI.
|
||||
do {
|
||||
self.lock.lock()
|
||||
defer { self.lock.unlock() }
|
||||
|
||||
defer { state = .cancelled }
|
||||
switch self.state {
|
||||
case .started(let id):
|
||||
cancelId = id
|
||||
case .initial, .cancelled:
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// If we didn't early-exit, the task has already started and we need to cancel it.
|
||||
self.cancel(cancelId)
|
||||
}
|
||||
|
||||
func cancel(_ id: SignalCancellationId) {
|
||||
do {
|
||||
try self.context.withNativeHandle {
|
||||
try checkError(signal_tokio_async_context_cancel($0, id))
|
||||
}
|
||||
} catch {
|
||||
#if canImport(SignalCoreKit)
|
||||
Logger.warn("failed to cancel libsignal task \(id): \(error)")
|
||||
#else
|
||||
NSLog("failed to cancel libsignal task %ld: %@", id, "\(error)")
|
||||
#endif
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Provides a callback and context for calling Promise-based libsignal\_ffi functions, with cancellation supported.
|
||||
///
|
||||
/// Example:
|
||||
///
|
||||
/// ```
|
||||
/// let result = try await asyncContext.invokeAsyncFunction { promise, runtime in
|
||||
/// signal_do_async_work(promise, runtime, someInput, someOtherInput)
|
||||
/// }
|
||||
/// ```
|
||||
internal func invokeAsyncFunction<Promise: PromiseStruct>(
|
||||
_ body: (UnsafeMutablePointer<Promise>, OpaquePointer?) -> SignalFfiErrorRef?
|
||||
) async throws -> Promise.Result {
|
||||
let cancellationHelper = CancellationHandoffHelper(context: self)
|
||||
return try await withTaskCancellationHandler(operation: {
|
||||
try await LibSignalClient.invokeAsyncFunction({ promise in
|
||||
withNativeHandle { handle in
|
||||
body(promise, handle)
|
||||
}
|
||||
}, saveCancellationId: {
|
||||
cancellationHelper.setCancellationId($0)
|
||||
})
|
||||
}, onCancel: {
|
||||
cancellationHelper.cancel()
|
||||
})
|
||||
}
|
||||
}
|
@ -147,6 +147,7 @@ typedef enum {
|
||||
SignalErrorCodeInvalidArgument = 5,
|
||||
SignalErrorCodeInvalidType = 6,
|
||||
SignalErrorCodeInvalidUtf8String = 7,
|
||||
SignalErrorCodeCancelled = 8,
|
||||
SignalErrorCodeProtobufError = 10,
|
||||
SignalErrorCodeLegacyCiphertextVersion = 21,
|
||||
SignalErrorCodeUnknownCiphertextVersion = 22,
|
||||
@ -493,6 +494,8 @@ typedef struct {
|
||||
size_t length;
|
||||
} SignalBorrowedSliceOfBuffers;
|
||||
|
||||
typedef uint64_t SignalCancellationId;
|
||||
|
||||
/**
|
||||
* A C callback used to report the results of Rust futures.
|
||||
*
|
||||
@ -505,6 +508,7 @@ typedef struct {
|
||||
typedef struct {
|
||||
void (*complete)(SignalFfiError *error, const SignalOwnedBuffer *result, const void *context);
|
||||
const void *context;
|
||||
SignalCancellationId cancellation_id;
|
||||
} SignalCPromiseOwnedBufferOfc_uchar;
|
||||
|
||||
/**
|
||||
@ -519,6 +523,7 @@ typedef struct {
|
||||
typedef struct {
|
||||
void (*complete)(SignalFfiError *error, const bool *result, const void *context);
|
||||
const void *context;
|
||||
SignalCancellationId cancellation_id;
|
||||
} SignalCPromisebool;
|
||||
|
||||
typedef struct {
|
||||
@ -540,6 +545,7 @@ typedef struct {
|
||||
typedef struct {
|
||||
void (*complete)(SignalFfiError *error, const SignalFfiChatServiceDebugInfo *result, const void *context);
|
||||
const void *context;
|
||||
SignalCancellationId cancellation_id;
|
||||
} SignalCPromiseFfiChatServiceDebugInfo;
|
||||
|
||||
typedef struct {
|
||||
@ -561,6 +567,7 @@ typedef struct {
|
||||
typedef struct {
|
||||
void (*complete)(SignalFfiError *error, const SignalFfiChatResponse *result, const void *context);
|
||||
const void *context;
|
||||
SignalCancellationId cancellation_id;
|
||||
} SignalCPromiseFfiChatResponse;
|
||||
|
||||
typedef struct {
|
||||
@ -580,6 +587,7 @@ typedef struct {
|
||||
typedef struct {
|
||||
void (*complete)(SignalFfiError *error, const SignalFfiResponseAndDebugInfo *result, const void *context);
|
||||
const void *context;
|
||||
SignalCancellationId cancellation_id;
|
||||
} SignalCPromiseFfiResponseAndDebugInfo;
|
||||
|
||||
/**
|
||||
@ -594,6 +602,7 @@ typedef struct {
|
||||
typedef struct {
|
||||
void (*complete)(SignalFfiError *error, SignalCdsiLookup *const *result, const void *context);
|
||||
const void *context;
|
||||
SignalCancellationId cancellation_id;
|
||||
} SignalCPromiseCdsiLookup;
|
||||
|
||||
typedef struct {
|
||||
@ -613,6 +622,7 @@ typedef struct {
|
||||
typedef struct {
|
||||
void (*complete)(SignalFfiError *error, const SignalFfiCdsiLookupResponse *result, const void *context);
|
||||
const void *context;
|
||||
SignalCancellationId cancellation_id;
|
||||
} SignalCPromiseFfiCdsiLookupResponse;
|
||||
|
||||
typedef SignalBytestringArray SignalStringArray;
|
||||
@ -641,6 +651,7 @@ typedef SignalInputStream SignalSyncInputStream;
|
||||
typedef struct {
|
||||
void (*complete)(SignalFfiError *error, const int32_t *result, const void *context);
|
||||
const void *context;
|
||||
SignalCancellationId cancellation_id;
|
||||
} SignalCPromisei32;
|
||||
|
||||
/**
|
||||
@ -655,6 +666,7 @@ typedef struct {
|
||||
typedef struct {
|
||||
void (*complete)(SignalFfiError *error, SignalTestingHandleType *const *result, const void *context);
|
||||
const void *context;
|
||||
SignalCancellationId cancellation_id;
|
||||
} SignalCPromiseTestingHandleType;
|
||||
|
||||
/**
|
||||
@ -669,6 +681,7 @@ typedef struct {
|
||||
typedef struct {
|
||||
void (*complete)(SignalFfiError *error, SignalOtherTestingHandleType *const *result, const void *context);
|
||||
const void *context;
|
||||
SignalCancellationId cancellation_id;
|
||||
} SignalCPromiseOtherTestingHandleType;
|
||||
|
||||
/**
|
||||
@ -683,6 +696,7 @@ typedef struct {
|
||||
typedef struct {
|
||||
void (*complete)(SignalFfiError *error, const void *const *result, const void *context);
|
||||
const void *context;
|
||||
SignalCancellationId cancellation_id;
|
||||
} SignalCPromiseRawPointer;
|
||||
|
||||
typedef uint8_t SignalRandomnessBytes[SignalRANDOMNESS_LEN];
|
||||
@ -1527,6 +1541,8 @@ SignalFfiError *signal_tokio_async_context_destroy(SignalTokioAsyncContext *p);
|
||||
|
||||
SignalFfiError *signal_tokio_async_context_new(SignalTokioAsyncContext **out);
|
||||
|
||||
SignalFfiError *signal_tokio_async_context_cancel(const SignalTokioAsyncContext *context, uint64_t raw_cancellation_id);
|
||||
|
||||
SignalFfiError *signal_pin_hash_destroy(SignalPinHash *p);
|
||||
|
||||
SignalFfiError *signal_pin_hash_clone(SignalPinHash **new_obj, const SignalPinHash *obj);
|
||||
@ -1685,6 +1701,8 @@ SignalFfiError *signal_testing_process_bytestring_array(SignalBytestringArray *o
|
||||
|
||||
SignalFfiError *signal_testing_cdsi_lookup_response_convert(SignalCPromiseFfiCdsiLookupResponse *promise, const SignalTokioAsyncContext *async_runtime);
|
||||
|
||||
SignalFfiError *signal_testing_only_completes_by_cancellation(SignalCPromisebool *promise, const SignalTokioAsyncContext *async_runtime);
|
||||
|
||||
SignalFfiError *signal_testing_cdsi_lookup_error_convert(const char *error_description);
|
||||
|
||||
SignalFfiError *signal_testing_chat_service_error_convert(void);
|
||||
|
@ -18,7 +18,7 @@ extension SignalCPromiseOtherTestingHandleType: PromiseStruct {
|
||||
public typealias Result = OpaquePointer
|
||||
}
|
||||
|
||||
final class AsyncTests: XCTestCase {
|
||||
final class AsyncTests: TestCaseBase {
|
||||
func testSuccess() async throws {
|
||||
let result: Int32 = try await invokeAsyncFunction {
|
||||
signal_testing_future_success($0, OpaquePointer(bitPattern: -1), 21)
|
||||
@ -65,6 +65,48 @@ final class AsyncTests: XCTestCase {
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func testTokioCancellation() async throws {
|
||||
let asyncContext = TokioAsyncContext()
|
||||
|
||||
// We can replace this with AsyncStream.makeStream(...) when we update our builder.
|
||||
var _continuation: AsyncStream<Int>.Continuation!
|
||||
let completionStream = AsyncStream<Int> { _continuation = $0 }
|
||||
let continuation = _continuation!
|
||||
|
||||
let makeTask = { (id: Int) in
|
||||
Task {
|
||||
defer {
|
||||
// Do this unconditionally so that the outer test procedure doesn't get stuck.
|
||||
continuation.yield(id)
|
||||
}
|
||||
do {
|
||||
_ = try await asyncContext.invokeAsyncFunction { promise, asyncContext in
|
||||
signal_testing_only_completes_by_cancellation(promise, asyncContext)
|
||||
}
|
||||
} catch is CancellationError {
|
||||
// Okay, expected.
|
||||
} catch {
|
||||
XCTFail("incorrect error: \(error)")
|
||||
}
|
||||
}
|
||||
}
|
||||
let task1 = makeTask(1)
|
||||
let task2 = makeTask(2)
|
||||
|
||||
var completionIter = completionStream.makeAsyncIterator()
|
||||
|
||||
// Complete the tasks in opposite order of starting them,
|
||||
// to make it less likely to get this result by accident.
|
||||
// This is not a rigorous test, only a simple exercise of the feature.
|
||||
task2.cancel()
|
||||
let firstCompletionId = await completionIter.next()
|
||||
XCTAssertEqual(firstCompletionId, 2)
|
||||
|
||||
task1.cancel()
|
||||
let secondCompletionId = await completionIter.next()
|
||||
XCTAssertEqual(secondCompletionId, 1)
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
||||
|
@ -23,10 +23,8 @@ final class NetTests: XCTestCase {
|
||||
|
||||
let asyncContext = TokioAsyncContext()
|
||||
|
||||
let output: SignalFfiCdsiLookupResponse = try await invokeAsyncFunction { promise in
|
||||
asyncContext.withNativeHandle { asyncContext in
|
||||
signal_testing_cdsi_lookup_response_convert(promise, asyncContext)
|
||||
}
|
||||
let output: SignalFfiCdsiLookupResponse = try await asyncContext.invokeAsyncFunction { promise, asyncContext in
|
||||
signal_testing_cdsi_lookup_response_convert(promise, asyncContext)
|
||||
}
|
||||
XCTAssertEqual(output.debug_permits_used, 123)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user