diff --git a/java/shared/java/org/signal/libsignal/internal/Native.java b/java/shared/java/org/signal/libsignal/internal/Native.java index ec2b234b..3b5aa8b7 100644 --- a/java/shared/java/org/signal/libsignal/internal/Native.java +++ b/java/shared/java/org/signal/libsignal/internal/Native.java @@ -635,6 +635,7 @@ public final class Native { public static native CompletableFuture TESTING_FutureSuccess(long asyncRuntime, int input); public static native CompletableFuture TESTING_FutureThrowsCustomErrorType(long asyncRuntime); public static native void TESTING_NonSuspendingBackgroundThreadRuntime_Destroy(long handle); + public static native CompletableFuture TESTING_OnlyCompletesByCancellation(long asyncRuntime); public static native String TESTING_OtherTestingHandleType_getValue(long handle); public static native void TESTING_PanicInBodyAsync(Object input); public static native CompletableFuture TESTING_PanicInBodyIo(long asyncRuntime, Object input); @@ -655,6 +656,7 @@ public final class Native { public static native void TestingHandleType_Destroy(long handle); public static native void TokioAsyncContext_Destroy(long handle); + public static native void TokioAsyncContext_cancel(long context, long rawCancellationId); public static native long TokioAsyncContext_new(); public static native long UnidentifiedSenderMessageContent_Deserialize(byte[] data) throws Exception; diff --git a/node/Native.d.ts b/node/Native.d.ts index 5ab31ba6..bb32c168 100644 --- a/node/Native.d.ts +++ b/node/Native.d.ts @@ -490,6 +490,7 @@ export function TESTING_FutureProducesOtherPointerType(asyncRuntime: Wrapper, input: number): Promise; export function TESTING_FutureSuccess(asyncRuntime: Wrapper, input: number): Promise; export function TESTING_NonSuspendingBackgroundThreadRuntime_New(): NonSuspendingBackgroundThreadRuntime; +export function TESTING_OnlyCompletesByCancellation(asyncRuntime: Wrapper): Promise; export function TESTING_OtherTestingHandleType_getValue(handle: Wrapper): string; export function TESTING_PanicInBodyAsync(_input: null): Promise; export function TESTING_PanicInBodyIo(asyncRuntime: Wrapper, _input: null): Promise; @@ -506,6 +507,7 @@ export function TESTING_PanicOnReturnSync(_needsCleanup: null): null; export function TESTING_ProcessBytestringArray(input: Buffer[]): Buffer[]; export function TESTING_ReturnStringArray(): string[]; export function TESTING_TestingHandleType_getValue(handle: Wrapper): number; +export function TokioAsyncContext_cancel(context: Wrapper, rawCancellationId: bigint): void; export function TokioAsyncContext_new(): TokioAsyncContext; export function UnidentifiedSenderMessageContent_Deserialize(data: Buffer): UnidentifiedSenderMessageContent; export function UnidentifiedSenderMessageContent_GetContentHint(m: Wrapper): number; diff --git a/rust/bridge/ffi/cbindgen.toml b/rust/bridge/ffi/cbindgen.toml index 78e30bf9..c10bb65a 100644 --- a/rust/bridge/ffi/cbindgen.toml +++ b/rust/bridge/ffi/cbindgen.toml @@ -64,6 +64,8 @@ renaming_overrides_prefixing = true "FfiOptionalServiceIdFixedWidthBinaryBytes" = "SignalOptionalServiceIdFixedWidthBinaryBytes" "CPromisec_void" = "SignalCPromiseRawPointer" +"RawCancellationId" = "SignalCancellationId" + # Avoid double-prefixing these "SignalFfiError" = "SignalFfiError" "SignalErrorCode" = "SignalErrorCode" diff --git a/rust/bridge/ffi/src/util.rs b/rust/bridge/ffi/src/util.rs index 226c4dc4..23ab5a9a 100644 --- a/rust/bridge/ffi/src/util.rs +++ b/rust/bridge/ffi/src/util.rs @@ -25,6 +25,7 @@ pub enum SignalErrorCode { InvalidArgument = 5, InvalidType = 6, InvalidUtf8String = 7, + Cancelled = 8, ProtobufError = 10, @@ -105,6 +106,8 @@ impl From<&SignalFfiError> for SignalErrorCode { SignalFfiError::InvalidUtf8String => SignalErrorCode::InvalidUtf8String, + SignalFfiError::Cancelled => SignalErrorCode::Cancelled, + SignalFfiError::Signal(SignalProtocolError::InvalidProtobufEncoding) => { SignalErrorCode::ProtobufError } diff --git a/rust/bridge/shared/macros/src/ffi.rs b/rust/bridge/shared/macros/src/ffi.rs index b0705478..cf5ae8c1 100644 --- a/rust/bridge/shared/macros/src/ffi.rs +++ b/rust/bridge/shared/macros/src/ffi.rs @@ -167,10 +167,16 @@ fn bridge_io_body( |__cancel| async move { let __future = ffi::catch_unwind(std::panic::AssertUnwindSafe(async move { #(#input_loading)* - let __result = #orig_name(#(#input_names),*).await; - // If the original function can't fail, wrap the result in Ok for uniformity. - // See TransformHelper::ok_if_needed. - Ok(TransformHelper(__result).ok_if_needed()?.0) + ::tokio::select! { + __result = #orig_name(#(#input_names),*) => { + // If the original function can't fail, wrap the result in Ok for uniformity. + // See TransformHelper::ok_if_needed. + Ok(TransformHelper(__result).ok_if_needed()?.0) + } + _ = __cancel => { + Err(ffi::SignalFfiError::Cancelled) + } + } })); ffi::FutureResultReporter::new(__future.await) } diff --git a/rust/bridge/shared/src/ffi/error.rs b/rust/bridge/shared/src/ffi/error.rs index 66d83d90..953b8ab6 100644 --- a/rust/bridge/shared/src/ffi/error.rs +++ b/rust/bridge/shared/src/ffi/error.rs @@ -54,6 +54,7 @@ pub enum SignalFfiError { NullPointer, InvalidUtf8String, InvalidArgument(String), + Cancelled, InternalError(String), UnexpectedPanic(std::boxed::Box), } @@ -102,6 +103,7 @@ impl fmt::Display for SignalFfiError { SignalFfiError::NullPointer => write!(f, "null pointer"), SignalFfiError::InvalidUtf8String => write!(f, "invalid UTF8 string"), SignalFfiError::InvalidArgument(msg) => write!(f, "invalid argument: {msg}"), + SignalFfiError::Cancelled => write!(f, "cancelled"), SignalFfiError::InternalError(msg) => write!(f, "internal error: {msg}"), SignalFfiError::UnexpectedPanic(e) => { write!(f, "unexpected panic: {}", describe_panic(e)) diff --git a/rust/bridge/shared/src/ffi/futures.rs b/rust/bridge/shared/src/ffi/futures.rs index ad3722de..d9e079de 100644 --- a/rust/bridge/shared/src/ffi/futures.rs +++ b/rust/bridge/shared/src/ffi/futures.rs @@ -10,6 +10,8 @@ use futures_util::{FutureExt, TryFutureExt}; use std::future::Future; +pub type RawCancellationId = u64; + /// A C callback used to report the results of Rust futures. /// /// cbindgen will produce independent C types like `SignalCPromisei32` and @@ -26,6 +28,7 @@ pub struct CPromise { context: *const std::ffi::c_void, ), context: *const std::ffi::c_void, + cancellation_id: RawCancellationId, } /// Keeps track of the information necessary to report a promise result back to C. @@ -121,7 +124,8 @@ pub fn run_future_on_runtime( O: ResultTypeInfo + 'static, { let completion = PromiseCompleter { promise: *promise }; - runtime.run_future(future, completion); + let cancellation_id = runtime.run_future(future, completion); + promise.cancellation_id = cancellation_id.into(); } /// Catches panics that occur in `future` and converts them to [`SignalFfiError::UnexpectedPanic`]. diff --git a/rust/bridge/shared/src/net/tokio.rs b/rust/bridge/shared/src/net/tokio.rs index cb117adf..08572793 100644 --- a/rust/bridge/shared/src/net/tokio.rs +++ b/rust/bridge/shared/src/net/tokio.rs @@ -38,6 +38,11 @@ fn TokioAsyncContext_new() -> TokioAsyncContext { } } +#[bridge_fn] +fn TokioAsyncContext_cancel(context: &TokioAsyncContext, raw_cancellation_id: u64) { + context.cancel(raw_cancellation_id.into()) +} + pub struct TokioContextCancellation(tokio::sync::oneshot::Receiver<()>); impl Future for TokioContextCancellation { @@ -52,6 +57,12 @@ impl Future for TokioContextCancellation { } } +// Not ideal! tokio doesn't promise that a oneshot::Receiver is in fact panic-safe. +// But its interior mutable state is only modified by the Receiver while it's being polled, +// and that means a panic would have to happen inside Receiver itself to cause a problem. +// Combined with our payload type being (), it's unlikely this can happen in practice. +impl std::panic::UnwindSafe for TokioContextCancellation {} + impl AsyncRuntimeBase for TokioAsyncContext { fn cancel(&self, cancellation_token: CancellationId) { if cancellation_token == CancellationId::NotSupported { diff --git a/rust/bridge/shared/src/testing/net.rs b/rust/bridge/shared/src/testing/net.rs index 9bf64bc7..44629498 100644 --- a/rust/bridge/shared/src/testing/net.rs +++ b/rust/bridge/shared/src/testing/net.rs @@ -49,6 +49,11 @@ async fn TESTING_CdsiLookupResponseConvert() -> LookupResponse { } } +#[bridge_io(TokioAsyncContext)] +async fn TESTING_OnlyCompletesByCancellation() { + std::future::pending::<()>().await +} + #[repr(u8)] #[derive(Copy, Clone, strum::EnumString)] enum TestingCdsiLookupError { diff --git a/swift/.swiftlint.yml b/swift/.swiftlint.yml index 6dbbe3af..1310620f 100644 --- a/swift/.swiftlint.yml +++ b/swift/.swiftlint.yml @@ -21,5 +21,7 @@ file_header: inclusive_language: override_allowed_terms: - master +nesting: + type_level: 2 excluded: - .build/** diff --git a/swift/Sources/LibSignalClient/AsyncUtils.swift b/swift/Sources/LibSignalClient/AsyncUtils.swift index 1a743760..dd000d87 100644 --- a/swift/Sources/LibSignalClient/AsyncUtils.swift +++ b/swift/Sources/LibSignalClient/AsyncUtils.swift @@ -3,6 +3,7 @@ // SPDX-License-Identifier: AGPL-3.0-only // +import Foundation import SignalFfi /// Used to check types for values produced asynchronously by Rust. @@ -15,6 +16,12 @@ import SignalFfi /// Note that implementing this is **unchecked;** make sure you match up the types correctly! internal protocol PromiseStruct { associatedtype Result + + // We can't declare the 'complete' callback without an associated type, + // and that associated type won't get inferred for an imported struct (not sure why). + // So we'd have to write out the callback type for every conformer. + var context: UnsafeRawPointer! { get } + var cancellation_id: SignalCancellationId { get } } extension SignalCPromisebool: PromiseStruct { @@ -94,7 +101,7 @@ private class Completer: CompleterBase { /// /// This retains `self`, to be used as the C context pointer for the callback. /// You must ensure that either the callback is called, or the result is passed to - /// ``destroyUncompletedPromiseStruct(_:)``. + /// ``cleanUpUncompletedPromiseStruct(_:)``. func makePromiseStruct() -> Promise { typealias RawPromiseCallback = @convention(c) (_ error: SignalFfiErrorRef?, _ value: UnsafeRawPointer?, _ context: UnsafeRawPointer?) -> Void let completeOpaque: RawPromiseCallback = { error, value, context in @@ -108,7 +115,7 @@ private class Completer: CompleterBase { // because of how `self.completeUnsafe` is initialized. // So first we build a promise struct---it doesn't matter which one---by reinterpreting the callback... typealias RawPointerPromiseCallback = @convention(c) (_ error: SignalFfiErrorRef?, _ value: UnsafePointer?, _ context: UnsafeRawPointer?) -> Void - let rawPromiseStruct = SignalCPromiseRawPointer(complete: unsafeBitCast(completeOpaque, to: RawPointerPromiseCallback.self), context: Unmanaged.passRetained(self).toOpaque()) + let rawPromiseStruct = SignalCPromiseRawPointer(complete: unsafeBitCast(completeOpaque, to: RawPointerPromiseCallback.self), context: Unmanaged.passRetained(self).toOpaque(), cancellation_id: 0) // ...And then we reinterpret the entire struct, because all promise structs *also* have the same layout. // (Which we at least check a little bit here.) @@ -119,11 +126,8 @@ private class Completer: CompleterBase { return unsafeBitCast(rawPromiseStruct, to: Promise.self) } - func destroyUncompletedPromiseStruct(_ promiseStruct: Promise) { - // Double-check that all promise structs have the same layout, then reverse what we did above. - precondition(MemoryLayout.size == MemoryLayout.size) - let rawPromiseStruct = unsafeBitCast(promiseStruct, to: SignalCPromiseRawPointer.self) - Unmanaged.fromOpaque(rawPromiseStruct.context!).release() + func cleanUpUncompletedPromiseStruct(_ promiseStruct: Promise) { + Unmanaged.fromOpaque(promiseStruct.context!).release() } } @@ -136,8 +140,12 @@ private class Completer: CompleterBase { /// signal_do_async_work($0, someInput, someOtherInput) /// } /// ``` +/// +/// Prefer ``TokioAsyncContext/invokeAsyncFunction(_:)`` if using a TokioAsyncContext; +/// that method supports cancellation. internal func invokeAsyncFunction( - _ body: (UnsafeMutablePointer) -> SignalFfiErrorRef? + _ body: (UnsafeMutablePointer) -> SignalFfiErrorRef?, + saveCancellationId: (SignalCancellationId) -> Void = { _ in } ) async throws -> Promise.Result { try await withCheckedThrowingContinuation { continuation in let completer = Completer(continuation: continuation) @@ -145,9 +153,10 @@ internal func invokeAsyncFunction( let startResult = body(&promiseStruct) if let error = startResult { // Our completion callback is never going to get called, so we need to balance the `passRetained` above. - completer.destroyUncompletedPromiseStruct(promiseStruct) + completer.cleanUpUncompletedPromiseStruct(promiseStruct) completer.completeUnsafe(error, nil) return } + saveCancellationId(promiseStruct.cancellation_id) } } diff --git a/swift/Sources/LibSignalClient/ChatService.swift b/swift/Sources/LibSignalClient/ChatService.swift index 73fd4a24..70341e61 100644 --- a/swift/Sources/LibSignalClient/ChatService.swift +++ b/swift/Sources/LibSignalClient/ChatService.swift @@ -139,11 +139,9 @@ public class ChatService: NativeHandleOwner { /// Calling this method will result in starting to accept incoming requests from the Chat Service. @discardableResult public func connectAuthenticated() async throws -> DebugInfo { - let rawDebugInfo = try await invokeAsyncFunction { promise in - self.tokioAsyncContext.withNativeHandle { tokioAsyncContext in - withNativeHandle { chatService in - signal_chat_service_connect_auth(promise, tokioAsyncContext, chatService) - } + let rawDebugInfo = try await self.tokioAsyncContext.invokeAsyncFunction { promise, tokioAsyncContext in + withNativeHandle { chatService in + signal_chat_service_connect_auth(promise, tokioAsyncContext, chatService) } } return DebugInfo(consuming: rawDebugInfo) @@ -155,11 +153,9 @@ public class ChatService: NativeHandleOwner { /// reconnect attempt will be made. @discardableResult public func connectUnauthenticated() async throws -> DebugInfo { - let rawDebugInfo = try await invokeAsyncFunction { promise in - self.tokioAsyncContext.withNativeHandle { tokioAsyncContext in - withNativeHandle { chatService in - signal_chat_service_connect_unauth(promise, tokioAsyncContext, chatService) - } + let rawDebugInfo = try await self.tokioAsyncContext.invokeAsyncFunction { promise, tokioAsyncContext in + withNativeHandle { chatService in + signal_chat_service_connect_unauth(promise, tokioAsyncContext, chatService) } } return DebugInfo(consuming: rawDebugInfo) @@ -174,11 +170,9 @@ public class ChatService: NativeHandleOwner { /// /// Returns when the disconnection is complete. public func disconnect() async throws { - _ = try await invokeAsyncFunction { promise in - self.tokioAsyncContext.withNativeHandle { tokioAsyncContext in - withNativeHandle { chatService in - signal_chat_service_disconnect(promise, tokioAsyncContext, chatService) - } + _ = try await self.tokioAsyncContext.invokeAsyncFunction { promise, tokioAsyncContext in + withNativeHandle { chatService in + signal_chat_service_disconnect(promise, tokioAsyncContext, chatService) } } } @@ -190,12 +184,10 @@ public class ChatService: NativeHandleOwner { public func unauthenticatedSend(_ request: Request) async throws -> Response { let internalRequest = try InternalRequest(request) let timeoutMillis = request.timeoutMillis - let rawResponse: SignalFfiChatResponse = try await invokeAsyncFunction { promise in - self.tokioAsyncContext.withNativeHandle { tokioAsyncContext in - withNativeHandle { chatService in - internalRequest.withNativeHandle { request in - signal_chat_service_unauth_send(promise, tokioAsyncContext, chatService, request, timeoutMillis) - } + let rawResponse: SignalFfiChatResponse = try await self.tokioAsyncContext.invokeAsyncFunction { promise, tokioAsyncContext in + withNativeHandle { chatService in + internalRequest.withNativeHandle { request in + signal_chat_service_unauth_send(promise, tokioAsyncContext, chatService, request, timeoutMillis) } } } @@ -212,12 +204,10 @@ public class ChatService: NativeHandleOwner { public func unauthenticatedSendAndDebug(_ request: Request) async throws -> (Response, DebugInfo) { let internalRequest = try InternalRequest(request) let timeoutMillis = request.timeoutMillis - let rawResponse: SignalFfiResponseAndDebugInfo = try await invokeAsyncFunction { promise in - self.tokioAsyncContext.withNativeHandle { tokioAsyncContext in - withNativeHandle { chatService in - internalRequest.withNativeHandle { request in - signal_chat_service_unauth_send_and_debug(promise, tokioAsyncContext, chatService, request, timeoutMillis) - } + let rawResponse: SignalFfiResponseAndDebugInfo = try await self.tokioAsyncContext.invokeAsyncFunction { promise, tokioAsyncContext in + withNativeHandle { chatService in + internalRequest.withNativeHandle { request in + signal_chat_service_unauth_send_and_debug(promise, tokioAsyncContext, chatService, request, timeoutMillis) } } } @@ -231,12 +221,10 @@ public class ChatService: NativeHandleOwner { public func authenticatedSend(_ request: Request) async throws -> Response { let internalRequest = try InternalRequest(request) let timeoutMillis = request.timeoutMillis - let rawResponse: SignalFfiChatResponse = try await invokeAsyncFunction { promise in - self.tokioAsyncContext.withNativeHandle { tokioAsyncContext in - withNativeHandle { chatService in - internalRequest.withNativeHandle { request in - signal_chat_service_auth_send(promise, tokioAsyncContext, chatService, request, timeoutMillis) - } + let rawResponse: SignalFfiChatResponse = try await self.tokioAsyncContext.invokeAsyncFunction { promise, tokioAsyncContext in + withNativeHandle { chatService in + internalRequest.withNativeHandle { request in + signal_chat_service_auth_send(promise, tokioAsyncContext, chatService, request, timeoutMillis) } } } @@ -253,12 +241,10 @@ public class ChatService: NativeHandleOwner { public func authenticatedSendAndDebug(_ request: Request) async throws -> (Response, DebugInfo) { let internalRequest = try InternalRequest(request) let timeoutMillis = request.timeoutMillis - let rawResponse: SignalFfiResponseAndDebugInfo = try await invokeAsyncFunction { promise in - self.tokioAsyncContext.withNativeHandle { tokioAsyncContext in - withNativeHandle { chatService in - internalRequest.withNativeHandle { request in - signal_chat_service_auth_send_and_debug(promise, tokioAsyncContext, chatService, request, timeoutMillis) - } + let rawResponse: SignalFfiResponseAndDebugInfo = try await self.tokioAsyncContext.invokeAsyncFunction { promise, tokioAsyncContext in + withNativeHandle { chatService in + internalRequest.withNativeHandle { request in + signal_chat_service_auth_send_and_debug(promise, tokioAsyncContext, chatService, request, timeoutMillis) } } } diff --git a/swift/Sources/LibSignalClient/Error.swift b/swift/Sources/LibSignalClient/Error.swift index 3f445ce5..c350032f 100644 --- a/swift/Sources/LibSignalClient/Error.swift +++ b/swift/Sources/LibSignalClient/Error.swift @@ -80,6 +80,9 @@ internal func checkError(_ error: SignalFfiErrorRef?) throws { defer { signal_error_free(error) } switch SignalErrorCode(errType) { + case SignalErrorCodeCancelled: + // Special case: don't use SignalError for this one. + throw CancellationError() case SignalErrorCodeInvalidState: throw SignalError.invalidState(errStr) case SignalErrorCodeInternalError: diff --git a/swift/Sources/LibSignalClient/Net.swift b/swift/Sources/LibSignalClient/Net.swift index a449f830..6d570213 100644 --- a/swift/Sources/LibSignalClient/Net.swift +++ b/swift/Sources/LibSignalClient/Net.swift @@ -108,12 +108,10 @@ public class Net { auth: Auth, request: CdsiLookupRequest ) async throws -> CdsiLookup { - let handle: OpaquePointer = try await invokeAsyncFunction { promise in - self.asyncContext.withNativeHandle { asyncContext in - self.connectionManager.withNativeHandle { connectionManager in - request.withNativeHandle { request in - signal_cdsi_lookup_new(promise, asyncContext, connectionManager, auth.username, auth.password, request) - } + let handle: OpaquePointer = try await self.asyncContext.invokeAsyncFunction { promise, asyncContext in + self.connectionManager.withNativeHandle { connectionManager in + request.withNativeHandle { request in + signal_cdsi_lookup_new(promise, asyncContext, connectionManager, auth.username, auth.password, request) } } } @@ -264,11 +262,9 @@ public class CdsiLookup { /// `SignalError.networkError` for a network-level connectivity issue, /// `SignalError.networkProtocolError` for a CDSI or attested connection protocol issue. public func complete() async throws -> CdsiLookupResponse { - let response: SignalFfiCdsiLookupResponse = try await invokeAsyncFunction { promise in - self.asyncContext.withNativeHandle { asyncContext in - self.native.withNativeHandle { handle in - signal_cdsi_lookup_complete(promise, asyncContext, handle) - } + let response: SignalFfiCdsiLookupResponse = try await self.asyncContext.invokeAsyncFunction { promise, asyncContext in + self.native.withNativeHandle { handle in + signal_cdsi_lookup_complete(promise, asyncContext, handle) } } @@ -362,18 +358,6 @@ extension CdsiLookupResponseEntry: Equatable { } } -internal class TokioAsyncContext: NativeHandleOwner { - convenience init() { - var handle: OpaquePointer? - failOnError(signal_tokio_async_context_new(&handle)) - self.init(owned: handle!) - } - - override internal class func destroyNativeHandle(_ handle: OpaquePointer) -> SignalFfiErrorRef? { - signal_tokio_async_context_destroy(handle) - } -} - internal class ConnectionManager: NativeHandleOwner { convenience init(env: Net.Environment, userAgent: String) { var handle: OpaquePointer? diff --git a/swift/Sources/LibSignalClient/Svr3.swift b/swift/Sources/LibSignalClient/Svr3.swift index b55ca0f4..2ccd30bf 100644 --- a/swift/Sources/LibSignalClient/Svr3.swift +++ b/swift/Sources/LibSignalClient/Svr3.swift @@ -93,21 +93,19 @@ public class Svr3Client { maxTries: UInt32, auth: Auth ) async throws -> [UInt8] { - let output = try await invokeAsyncFunction { promise in - self.asyncContext.withNativeHandle { asyncContext in - self.connectionManager.withNativeHandle { connectionManager in - secret.withUnsafeBorrowedBuffer { secretBuffer in - signal_svr3_backup( - promise, - asyncContext, - connectionManager, - secretBuffer, - password, - maxTries, - auth.username, - auth.password - ) - } + let output = try await self.asyncContext.invokeAsyncFunction { promise, asyncContext in + self.connectionManager.withNativeHandle { connectionManager in + secret.withUnsafeBorrowedBuffer { secretBuffer in + signal_svr3_backup( + promise, + asyncContext, + connectionManager, + secretBuffer, + password, + maxTries, + auth.username, + auth.password + ) } } } @@ -159,20 +157,18 @@ public class Svr3Client { shareSet: some ContiguousBytes, auth: Auth ) async throws -> [UInt8] { - let output = try await invokeAsyncFunction { promise in - self.asyncContext.withNativeHandle { asyncContext in - self.connectionManager.withNativeHandle { connectionManager in - shareSet.withUnsafeBorrowedBuffer { shareSetBuffer in - signal_svr3_restore( - promise, - asyncContext, - connectionManager, - password, - shareSetBuffer, - auth.username, - auth.password - ) - } + let output = try await self.asyncContext.invokeAsyncFunction { promise, asyncContext in + self.connectionManager.withNativeHandle { connectionManager in + shareSet.withUnsafeBorrowedBuffer { shareSetBuffer in + signal_svr3_restore( + promise, + asyncContext, + connectionManager, + password, + shareSetBuffer, + auth.username, + auth.password + ) } } } diff --git a/swift/Sources/LibSignalClient/TokioAsyncContext.swift b/swift/Sources/LibSignalClient/TokioAsyncContext.swift new file mode 100644 index 00000000..5d24e25b --- /dev/null +++ b/swift/Sources/LibSignalClient/TokioAsyncContext.swift @@ -0,0 +1,127 @@ +// +// Copyright 2024 Signal Messenger, LLC. +// SPDX-License-Identifier: AGPL-3.0-only +// + +import Foundation +import SignalFfi + +#if canImport(SignalCoreKit) +import SignalCoreKit +#endif + +internal class TokioAsyncContext: NativeHandleOwner { + convenience init() { + var handle: OpaquePointer? + failOnError(signal_tokio_async_context_new(&handle)) + self.init(owned: handle!) + } + + override internal class func destroyNativeHandle(_ handle: OpaquePointer) -> SignalFfiErrorRef? { + signal_tokio_async_context_destroy(handle) + } + + /// A thread-safe helper for translating Swift task cancellations into calls to + /// `signal_tokio_async_context_cancel`. + private class CancellationHandoffHelper { + enum State { + case initial + case started(SignalCancellationId) + case cancelled + } + + // Emulates Rust's `Mutex` (and the containing class is providing an `Arc`) + // Unfortunately, doing this in Swift requires a separate allocation for the lock today. + var state: State = .initial + var lock = NSLock() + + let context: TokioAsyncContext + + init(context: TokioAsyncContext) { + self.context = context + } + + func setCancellationId(_ id: SignalCancellationId) { + // Ideally we would use NSLock.withLock here, but that's not available on Linux, + // which we still support for development and CI. + do { + self.lock.lock() + defer { self.lock.unlock() } + + switch self.state { + case .initial: + self.state = .started(id) + fallthrough + case .started(_): + return + case .cancelled: + break + } + } + + // If we didn't early-exit, we're already cancelled. + self.cancel(id) + } + + func cancel() { + let cancelId: SignalCancellationId + // Ideally we would use NSLock.withLock here, but that's not available on Linux, + // which we still support for development and CI. + do { + self.lock.lock() + defer { self.lock.unlock() } + + defer { state = .cancelled } + switch self.state { + case .started(let id): + cancelId = id + case .initial, .cancelled: + return + } + } + + // If we didn't early-exit, the task has already started and we need to cancel it. + self.cancel(cancelId) + } + + func cancel(_ id: SignalCancellationId) { + do { + try self.context.withNativeHandle { + try checkError(signal_tokio_async_context_cancel($0, id)) + } + } catch { +#if canImport(SignalCoreKit) + Logger.warn("failed to cancel libsignal task \(id): \(error)") +#else + NSLog("failed to cancel libsignal task %ld: %@", id, "\(error)") +#endif + } + } + } + + /// Provides a callback and context for calling Promise-based libsignal\_ffi functions, with cancellation supported. + /// + /// Example: + /// + /// ``` + /// let result = try await asyncContext.invokeAsyncFunction { promise, runtime in + /// signal_do_async_work(promise, runtime, someInput, someOtherInput) + /// } + /// ``` + internal func invokeAsyncFunction( + _ body: (UnsafeMutablePointer, OpaquePointer?) -> SignalFfiErrorRef? + ) async throws -> Promise.Result { + let cancellationHelper = CancellationHandoffHelper(context: self) + return try await withTaskCancellationHandler(operation: { + try await LibSignalClient.invokeAsyncFunction({ promise in + withNativeHandle { handle in + body(promise, handle) + } + }, saveCancellationId: { + cancellationHelper.setCancellationId($0) + }) + }, onCancel: { + cancellationHelper.cancel() + }) + } +} diff --git a/swift/Sources/SignalFfi/signal_ffi.h b/swift/Sources/SignalFfi/signal_ffi.h index 1b29cee4..66343e6d 100644 --- a/swift/Sources/SignalFfi/signal_ffi.h +++ b/swift/Sources/SignalFfi/signal_ffi.h @@ -147,6 +147,7 @@ typedef enum { SignalErrorCodeInvalidArgument = 5, SignalErrorCodeInvalidType = 6, SignalErrorCodeInvalidUtf8String = 7, + SignalErrorCodeCancelled = 8, SignalErrorCodeProtobufError = 10, SignalErrorCodeLegacyCiphertextVersion = 21, SignalErrorCodeUnknownCiphertextVersion = 22, @@ -493,6 +494,8 @@ typedef struct { size_t length; } SignalBorrowedSliceOfBuffers; +typedef uint64_t SignalCancellationId; + /** * A C callback used to report the results of Rust futures. * @@ -505,6 +508,7 @@ typedef struct { typedef struct { void (*complete)(SignalFfiError *error, const SignalOwnedBuffer *result, const void *context); const void *context; + SignalCancellationId cancellation_id; } SignalCPromiseOwnedBufferOfc_uchar; /** @@ -519,6 +523,7 @@ typedef struct { typedef struct { void (*complete)(SignalFfiError *error, const bool *result, const void *context); const void *context; + SignalCancellationId cancellation_id; } SignalCPromisebool; typedef struct { @@ -540,6 +545,7 @@ typedef struct { typedef struct { void (*complete)(SignalFfiError *error, const SignalFfiChatServiceDebugInfo *result, const void *context); const void *context; + SignalCancellationId cancellation_id; } SignalCPromiseFfiChatServiceDebugInfo; typedef struct { @@ -561,6 +567,7 @@ typedef struct { typedef struct { void (*complete)(SignalFfiError *error, const SignalFfiChatResponse *result, const void *context); const void *context; + SignalCancellationId cancellation_id; } SignalCPromiseFfiChatResponse; typedef struct { @@ -580,6 +587,7 @@ typedef struct { typedef struct { void (*complete)(SignalFfiError *error, const SignalFfiResponseAndDebugInfo *result, const void *context); const void *context; + SignalCancellationId cancellation_id; } SignalCPromiseFfiResponseAndDebugInfo; /** @@ -594,6 +602,7 @@ typedef struct { typedef struct { void (*complete)(SignalFfiError *error, SignalCdsiLookup *const *result, const void *context); const void *context; + SignalCancellationId cancellation_id; } SignalCPromiseCdsiLookup; typedef struct { @@ -613,6 +622,7 @@ typedef struct { typedef struct { void (*complete)(SignalFfiError *error, const SignalFfiCdsiLookupResponse *result, const void *context); const void *context; + SignalCancellationId cancellation_id; } SignalCPromiseFfiCdsiLookupResponse; typedef SignalBytestringArray SignalStringArray; @@ -641,6 +651,7 @@ typedef SignalInputStream SignalSyncInputStream; typedef struct { void (*complete)(SignalFfiError *error, const int32_t *result, const void *context); const void *context; + SignalCancellationId cancellation_id; } SignalCPromisei32; /** @@ -655,6 +666,7 @@ typedef struct { typedef struct { void (*complete)(SignalFfiError *error, SignalTestingHandleType *const *result, const void *context); const void *context; + SignalCancellationId cancellation_id; } SignalCPromiseTestingHandleType; /** @@ -669,6 +681,7 @@ typedef struct { typedef struct { void (*complete)(SignalFfiError *error, SignalOtherTestingHandleType *const *result, const void *context); const void *context; + SignalCancellationId cancellation_id; } SignalCPromiseOtherTestingHandleType; /** @@ -683,6 +696,7 @@ typedef struct { typedef struct { void (*complete)(SignalFfiError *error, const void *const *result, const void *context); const void *context; + SignalCancellationId cancellation_id; } SignalCPromiseRawPointer; typedef uint8_t SignalRandomnessBytes[SignalRANDOMNESS_LEN]; @@ -1527,6 +1541,8 @@ SignalFfiError *signal_tokio_async_context_destroy(SignalTokioAsyncContext *p); SignalFfiError *signal_tokio_async_context_new(SignalTokioAsyncContext **out); +SignalFfiError *signal_tokio_async_context_cancel(const SignalTokioAsyncContext *context, uint64_t raw_cancellation_id); + SignalFfiError *signal_pin_hash_destroy(SignalPinHash *p); SignalFfiError *signal_pin_hash_clone(SignalPinHash **new_obj, const SignalPinHash *obj); @@ -1685,6 +1701,8 @@ SignalFfiError *signal_testing_process_bytestring_array(SignalBytestringArray *o SignalFfiError *signal_testing_cdsi_lookup_response_convert(SignalCPromiseFfiCdsiLookupResponse *promise, const SignalTokioAsyncContext *async_runtime); +SignalFfiError *signal_testing_only_completes_by_cancellation(SignalCPromisebool *promise, const SignalTokioAsyncContext *async_runtime); + SignalFfiError *signal_testing_cdsi_lookup_error_convert(const char *error_description); SignalFfiError *signal_testing_chat_service_error_convert(void); diff --git a/swift/Tests/LibSignalClientTests/AsyncTests.swift b/swift/Tests/LibSignalClientTests/AsyncTests.swift index f63b7ace..08daae2f 100644 --- a/swift/Tests/LibSignalClientTests/AsyncTests.swift +++ b/swift/Tests/LibSignalClientTests/AsyncTests.swift @@ -18,7 +18,7 @@ extension SignalCPromiseOtherTestingHandleType: PromiseStruct { public typealias Result = OpaquePointer } -final class AsyncTests: XCTestCase { +final class AsyncTests: TestCaseBase { func testSuccess() async throws { let result: Int32 = try await invokeAsyncFunction { signal_testing_future_success($0, OpaquePointer(bitPattern: -1), 21) @@ -65,6 +65,48 @@ final class AsyncTests: XCTestCase { ) } } + + func testTokioCancellation() async throws { + let asyncContext = TokioAsyncContext() + + // We can replace this with AsyncStream.makeStream(...) when we update our builder. + var _continuation: AsyncStream.Continuation! + let completionStream = AsyncStream { _continuation = $0 } + let continuation = _continuation! + + let makeTask = { (id: Int) in + Task { + defer { + // Do this unconditionally so that the outer test procedure doesn't get stuck. + continuation.yield(id) + } + do { + _ = try await asyncContext.invokeAsyncFunction { promise, asyncContext in + signal_testing_only_completes_by_cancellation(promise, asyncContext) + } + } catch is CancellationError { + // Okay, expected. + } catch { + XCTFail("incorrect error: \(error)") + } + } + } + let task1 = makeTask(1) + let task2 = makeTask(2) + + var completionIter = completionStream.makeAsyncIterator() + + // Complete the tasks in opposite order of starting them, + // to make it less likely to get this result by accident. + // This is not a rigorous test, only a simple exercise of the feature. + task2.cancel() + let firstCompletionId = await completionIter.next() + XCTAssertEqual(firstCompletionId, 2) + + task1.cancel() + let secondCompletionId = await completionIter.next() + XCTAssertEqual(secondCompletionId, 1) + } } #endif diff --git a/swift/Tests/LibSignalClientTests/NetTests.swift b/swift/Tests/LibSignalClientTests/NetTests.swift index 74de45a5..38595022 100644 --- a/swift/Tests/LibSignalClientTests/NetTests.swift +++ b/swift/Tests/LibSignalClientTests/NetTests.swift @@ -23,10 +23,8 @@ final class NetTests: XCTestCase { let asyncContext = TokioAsyncContext() - let output: SignalFfiCdsiLookupResponse = try await invokeAsyncFunction { promise in - asyncContext.withNativeHandle { asyncContext in - signal_testing_cdsi_lookup_response_convert(promise, asyncContext) - } + let output: SignalFfiCdsiLookupResponse = try await asyncContext.invokeAsyncFunction { promise, asyncContext in + signal_testing_cdsi_lookup_response_convert(promise, asyncContext) } XCTAssertEqual(output.debug_permits_used, 123)