0
0
mirror of https://github.com/signalapp/libsignal.git synced 2024-09-19 19:42:19 +02:00

ffi: Expose cancellation to Swift

This commit is contained in:
Jordan Rose 2024-05-03 09:52:29 -07:00
parent 6d3c192208
commit 7dc63b99af
19 changed files with 312 additions and 110 deletions

View File

@ -635,6 +635,7 @@ public final class Native {
public static native CompletableFuture<Integer> TESTING_FutureSuccess(long asyncRuntime, int input);
public static native CompletableFuture<Void> TESTING_FutureThrowsCustomErrorType(long asyncRuntime);
public static native void TESTING_NonSuspendingBackgroundThreadRuntime_Destroy(long handle);
public static native CompletableFuture TESTING_OnlyCompletesByCancellation(long asyncRuntime);
public static native String TESTING_OtherTestingHandleType_getValue(long handle);
public static native void TESTING_PanicInBodyAsync(Object input);
public static native CompletableFuture TESTING_PanicInBodyIo(long asyncRuntime, Object input);
@ -655,6 +656,7 @@ public final class Native {
public static native void TestingHandleType_Destroy(long handle);
public static native void TokioAsyncContext_Destroy(long handle);
public static native void TokioAsyncContext_cancel(long context, long rawCancellationId);
public static native long TokioAsyncContext_new();
public static native long UnidentifiedSenderMessageContent_Deserialize(byte[] data) throws Exception;

2
node/Native.d.ts vendored
View File

@ -490,6 +490,7 @@ export function TESTING_FutureProducesOtherPointerType(asyncRuntime: Wrapper<Non
export function TESTING_FutureProducesPointerType(asyncRuntime: Wrapper<NonSuspendingBackgroundThreadRuntime>, input: number): Promise<TestingHandleType>;
export function TESTING_FutureSuccess(asyncRuntime: Wrapper<NonSuspendingBackgroundThreadRuntime>, input: number): Promise<number>;
export function TESTING_NonSuspendingBackgroundThreadRuntime_New(): NonSuspendingBackgroundThreadRuntime;
export function TESTING_OnlyCompletesByCancellation(asyncRuntime: Wrapper<TokioAsyncContext>): Promise<void>;
export function TESTING_OtherTestingHandleType_getValue(handle: Wrapper<OtherTestingHandleType>): string;
export function TESTING_PanicInBodyAsync(_input: null): Promise<void>;
export function TESTING_PanicInBodyIo(asyncRuntime: Wrapper<NonSuspendingBackgroundThreadRuntime>, _input: null): Promise<void>;
@ -506,6 +507,7 @@ export function TESTING_PanicOnReturnSync(_needsCleanup: null): null;
export function TESTING_ProcessBytestringArray(input: Buffer[]): Buffer[];
export function TESTING_ReturnStringArray(): string[];
export function TESTING_TestingHandleType_getValue(handle: Wrapper<TestingHandleType>): number;
export function TokioAsyncContext_cancel(context: Wrapper<TokioAsyncContext>, rawCancellationId: bigint): void;
export function TokioAsyncContext_new(): TokioAsyncContext;
export function UnidentifiedSenderMessageContent_Deserialize(data: Buffer): UnidentifiedSenderMessageContent;
export function UnidentifiedSenderMessageContent_GetContentHint(m: Wrapper<UnidentifiedSenderMessageContent>): number;

View File

@ -64,6 +64,8 @@ renaming_overrides_prefixing = true
"FfiOptionalServiceIdFixedWidthBinaryBytes" = "SignalOptionalServiceIdFixedWidthBinaryBytes"
"CPromisec_void" = "SignalCPromiseRawPointer"
"RawCancellationId" = "SignalCancellationId"
# Avoid double-prefixing these
"SignalFfiError" = "SignalFfiError"
"SignalErrorCode" = "SignalErrorCode"

View File

@ -25,6 +25,7 @@ pub enum SignalErrorCode {
InvalidArgument = 5,
InvalidType = 6,
InvalidUtf8String = 7,
Cancelled = 8,
ProtobufError = 10,
@ -105,6 +106,8 @@ impl From<&SignalFfiError> for SignalErrorCode {
SignalFfiError::InvalidUtf8String => SignalErrorCode::InvalidUtf8String,
SignalFfiError::Cancelled => SignalErrorCode::Cancelled,
SignalFfiError::Signal(SignalProtocolError::InvalidProtobufEncoding) => {
SignalErrorCode::ProtobufError
}

View File

@ -167,10 +167,16 @@ fn bridge_io_body(
|__cancel| async move {
let __future = ffi::catch_unwind(std::panic::AssertUnwindSafe(async move {
#(#input_loading)*
let __result = #orig_name(#(#input_names),*).await;
// If the original function can't fail, wrap the result in Ok for uniformity.
// See TransformHelper::ok_if_needed.
Ok(TransformHelper(__result).ok_if_needed()?.0)
::tokio::select! {
__result = #orig_name(#(#input_names),*) => {
// If the original function can't fail, wrap the result in Ok for uniformity.
// See TransformHelper::ok_if_needed.
Ok(TransformHelper(__result).ok_if_needed()?.0)
}
_ = __cancel => {
Err(ffi::SignalFfiError::Cancelled)
}
}
}));
ffi::FutureResultReporter::new(__future.await)
}

View File

@ -54,6 +54,7 @@ pub enum SignalFfiError {
NullPointer,
InvalidUtf8String,
InvalidArgument(String),
Cancelled,
InternalError(String),
UnexpectedPanic(std::boxed::Box<dyn std::any::Any + std::marker::Send>),
}
@ -102,6 +103,7 @@ impl fmt::Display for SignalFfiError {
SignalFfiError::NullPointer => write!(f, "null pointer"),
SignalFfiError::InvalidUtf8String => write!(f, "invalid UTF8 string"),
SignalFfiError::InvalidArgument(msg) => write!(f, "invalid argument: {msg}"),
SignalFfiError::Cancelled => write!(f, "cancelled"),
SignalFfiError::InternalError(msg) => write!(f, "internal error: {msg}"),
SignalFfiError::UnexpectedPanic(e) => {
write!(f, "unexpected panic: {}", describe_panic(e))

View File

@ -10,6 +10,8 @@ use futures_util::{FutureExt, TryFutureExt};
use std::future::Future;
pub type RawCancellationId = u64;
/// A C callback used to report the results of Rust futures.
///
/// cbindgen will produce independent C types like `SignalCPromisei32` and
@ -26,6 +28,7 @@ pub struct CPromise<T> {
context: *const std::ffi::c_void,
),
context: *const std::ffi::c_void,
cancellation_id: RawCancellationId,
}
/// Keeps track of the information necessary to report a promise result back to C.
@ -121,7 +124,8 @@ pub fn run_future_on_runtime<R, F, O>(
O: ResultTypeInfo + 'static,
{
let completion = PromiseCompleter { promise: *promise };
runtime.run_future(future, completion);
let cancellation_id = runtime.run_future(future, completion);
promise.cancellation_id = cancellation_id.into();
}
/// Catches panics that occur in `future` and converts them to [`SignalFfiError::UnexpectedPanic`].

View File

@ -38,6 +38,11 @@ fn TokioAsyncContext_new() -> TokioAsyncContext {
}
}
#[bridge_fn]
fn TokioAsyncContext_cancel(context: &TokioAsyncContext, raw_cancellation_id: u64) {
context.cancel(raw_cancellation_id.into())
}
pub struct TokioContextCancellation(tokio::sync::oneshot::Receiver<()>);
impl Future for TokioContextCancellation {
@ -52,6 +57,12 @@ impl Future for TokioContextCancellation {
}
}
// Not ideal! tokio doesn't promise that a oneshot::Receiver is in fact panic-safe.
// But its interior mutable state is only modified by the Receiver while it's being polled,
// and that means a panic would have to happen inside Receiver itself to cause a problem.
// Combined with our payload type being (), it's unlikely this can happen in practice.
impl std::panic::UnwindSafe for TokioContextCancellation {}
impl AsyncRuntimeBase for TokioAsyncContext {
fn cancel(&self, cancellation_token: CancellationId) {
if cancellation_token == CancellationId::NotSupported {

View File

@ -49,6 +49,11 @@ async fn TESTING_CdsiLookupResponseConvert() -> LookupResponse {
}
}
#[bridge_io(TokioAsyncContext)]
async fn TESTING_OnlyCompletesByCancellation() {
std::future::pending::<()>().await
}
#[repr(u8)]
#[derive(Copy, Clone, strum::EnumString)]
enum TestingCdsiLookupError {

View File

@ -21,5 +21,7 @@ file_header:
inclusive_language:
override_allowed_terms:
- master
nesting:
type_level: 2
excluded:
- .build/**

View File

@ -3,6 +3,7 @@
// SPDX-License-Identifier: AGPL-3.0-only
//
import Foundation
import SignalFfi
/// Used to check types for values produced asynchronously by Rust.
@ -15,6 +16,12 @@ import SignalFfi
/// Note that implementing this is **unchecked;** make sure you match up the types correctly!
internal protocol PromiseStruct {
associatedtype Result
// We can't declare the 'complete' callback without an associated type,
// and that associated type won't get inferred for an imported struct (not sure why).
// So we'd have to write out the callback type for every conformer.
var context: UnsafeRawPointer! { get }
var cancellation_id: SignalCancellationId { get }
}
extension SignalCPromisebool: PromiseStruct {
@ -94,7 +101,7 @@ private class Completer<Promise: PromiseStruct>: CompleterBase {
///
/// This retains `self`, to be used as the C context pointer for the callback.
/// You must ensure that either the callback is called, or the result is passed to
/// ``destroyUncompletedPromiseStruct(_:)``.
/// ``cleanUpUncompletedPromiseStruct(_:)``.
func makePromiseStruct() -> Promise {
typealias RawPromiseCallback = @convention(c) (_ error: SignalFfiErrorRef?, _ value: UnsafeRawPointer?, _ context: UnsafeRawPointer?) -> Void
let completeOpaque: RawPromiseCallback = { error, value, context in
@ -108,7 +115,7 @@ private class Completer<Promise: PromiseStruct>: CompleterBase {
// because of how `self.completeUnsafe` is initialized.
// So first we build a promise struct---it doesn't matter which one---by reinterpreting the callback...
typealias RawPointerPromiseCallback = @convention(c) (_ error: SignalFfiErrorRef?, _ value: UnsafePointer<UnsafeRawPointer?>?, _ context: UnsafeRawPointer?) -> Void
let rawPromiseStruct = SignalCPromiseRawPointer(complete: unsafeBitCast(completeOpaque, to: RawPointerPromiseCallback.self), context: Unmanaged.passRetained(self).toOpaque())
let rawPromiseStruct = SignalCPromiseRawPointer(complete: unsafeBitCast(completeOpaque, to: RawPointerPromiseCallback.self), context: Unmanaged.passRetained(self).toOpaque(), cancellation_id: 0)
// ...And then we reinterpret the entire struct, because all promise structs *also* have the same layout.
// (Which we at least check a little bit here.)
@ -119,11 +126,8 @@ private class Completer<Promise: PromiseStruct>: CompleterBase {
return unsafeBitCast(rawPromiseStruct, to: Promise.self)
}
func destroyUncompletedPromiseStruct(_ promiseStruct: Promise) {
// Double-check that all promise structs have the same layout, then reverse what we did above.
precondition(MemoryLayout<SignalCPromiseRawPointer>.size == MemoryLayout<Promise>.size)
let rawPromiseStruct = unsafeBitCast(promiseStruct, to: SignalCPromiseRawPointer.self)
Unmanaged<CompleterBase>.fromOpaque(rawPromiseStruct.context!).release()
func cleanUpUncompletedPromiseStruct(_ promiseStruct: Promise) {
Unmanaged<CompleterBase>.fromOpaque(promiseStruct.context!).release()
}
}
@ -136,8 +140,12 @@ private class Completer<Promise: PromiseStruct>: CompleterBase {
/// signal_do_async_work($0, someInput, someOtherInput)
/// }
/// ```
///
/// Prefer ``TokioAsyncContext/invokeAsyncFunction(_:)`` if using a TokioAsyncContext;
/// that method supports cancellation.
internal func invokeAsyncFunction<Promise: PromiseStruct>(
_ body: (UnsafeMutablePointer<Promise>) -> SignalFfiErrorRef?
_ body: (UnsafeMutablePointer<Promise>) -> SignalFfiErrorRef?,
saveCancellationId: (SignalCancellationId) -> Void = { _ in }
) async throws -> Promise.Result {
try await withCheckedThrowingContinuation { continuation in
let completer = Completer<Promise>(continuation: continuation)
@ -145,9 +153,10 @@ internal func invokeAsyncFunction<Promise: PromiseStruct>(
let startResult = body(&promiseStruct)
if let error = startResult {
// Our completion callback is never going to get called, so we need to balance the `passRetained` above.
completer.destroyUncompletedPromiseStruct(promiseStruct)
completer.cleanUpUncompletedPromiseStruct(promiseStruct)
completer.completeUnsafe(error, nil)
return
}
saveCancellationId(promiseStruct.cancellation_id)
}
}

View File

@ -139,11 +139,9 @@ public class ChatService: NativeHandleOwner {
/// Calling this method will result in starting to accept incoming requests from the Chat Service.
@discardableResult
public func connectAuthenticated() async throws -> DebugInfo {
let rawDebugInfo = try await invokeAsyncFunction { promise in
self.tokioAsyncContext.withNativeHandle { tokioAsyncContext in
withNativeHandle { chatService in
signal_chat_service_connect_auth(promise, tokioAsyncContext, chatService)
}
let rawDebugInfo = try await self.tokioAsyncContext.invokeAsyncFunction { promise, tokioAsyncContext in
withNativeHandle { chatService in
signal_chat_service_connect_auth(promise, tokioAsyncContext, chatService)
}
}
return DebugInfo(consuming: rawDebugInfo)
@ -155,11 +153,9 @@ public class ChatService: NativeHandleOwner {
/// reconnect attempt will be made.
@discardableResult
public func connectUnauthenticated() async throws -> DebugInfo {
let rawDebugInfo = try await invokeAsyncFunction { promise in
self.tokioAsyncContext.withNativeHandle { tokioAsyncContext in
withNativeHandle { chatService in
signal_chat_service_connect_unauth(promise, tokioAsyncContext, chatService)
}
let rawDebugInfo = try await self.tokioAsyncContext.invokeAsyncFunction { promise, tokioAsyncContext in
withNativeHandle { chatService in
signal_chat_service_connect_unauth(promise, tokioAsyncContext, chatService)
}
}
return DebugInfo(consuming: rawDebugInfo)
@ -174,11 +170,9 @@ public class ChatService: NativeHandleOwner {
///
/// Returns when the disconnection is complete.
public func disconnect() async throws {
_ = try await invokeAsyncFunction { promise in
self.tokioAsyncContext.withNativeHandle { tokioAsyncContext in
withNativeHandle { chatService in
signal_chat_service_disconnect(promise, tokioAsyncContext, chatService)
}
_ = try await self.tokioAsyncContext.invokeAsyncFunction { promise, tokioAsyncContext in
withNativeHandle { chatService in
signal_chat_service_disconnect(promise, tokioAsyncContext, chatService)
}
}
}
@ -190,12 +184,10 @@ public class ChatService: NativeHandleOwner {
public func unauthenticatedSend(_ request: Request) async throws -> Response {
let internalRequest = try InternalRequest(request)
let timeoutMillis = request.timeoutMillis
let rawResponse: SignalFfiChatResponse = try await invokeAsyncFunction { promise in
self.tokioAsyncContext.withNativeHandle { tokioAsyncContext in
withNativeHandle { chatService in
internalRequest.withNativeHandle { request in
signal_chat_service_unauth_send(promise, tokioAsyncContext, chatService, request, timeoutMillis)
}
let rawResponse: SignalFfiChatResponse = try await self.tokioAsyncContext.invokeAsyncFunction { promise, tokioAsyncContext in
withNativeHandle { chatService in
internalRequest.withNativeHandle { request in
signal_chat_service_unauth_send(promise, tokioAsyncContext, chatService, request, timeoutMillis)
}
}
}
@ -212,12 +204,10 @@ public class ChatService: NativeHandleOwner {
public func unauthenticatedSendAndDebug(_ request: Request) async throws -> (Response, DebugInfo) {
let internalRequest = try InternalRequest(request)
let timeoutMillis = request.timeoutMillis
let rawResponse: SignalFfiResponseAndDebugInfo = try await invokeAsyncFunction { promise in
self.tokioAsyncContext.withNativeHandle { tokioAsyncContext in
withNativeHandle { chatService in
internalRequest.withNativeHandle { request in
signal_chat_service_unauth_send_and_debug(promise, tokioAsyncContext, chatService, request, timeoutMillis)
}
let rawResponse: SignalFfiResponseAndDebugInfo = try await self.tokioAsyncContext.invokeAsyncFunction { promise, tokioAsyncContext in
withNativeHandle { chatService in
internalRequest.withNativeHandle { request in
signal_chat_service_unauth_send_and_debug(promise, tokioAsyncContext, chatService, request, timeoutMillis)
}
}
}
@ -231,12 +221,10 @@ public class ChatService: NativeHandleOwner {
public func authenticatedSend(_ request: Request) async throws -> Response {
let internalRequest = try InternalRequest(request)
let timeoutMillis = request.timeoutMillis
let rawResponse: SignalFfiChatResponse = try await invokeAsyncFunction { promise in
self.tokioAsyncContext.withNativeHandle { tokioAsyncContext in
withNativeHandle { chatService in
internalRequest.withNativeHandle { request in
signal_chat_service_auth_send(promise, tokioAsyncContext, chatService, request, timeoutMillis)
}
let rawResponse: SignalFfiChatResponse = try await self.tokioAsyncContext.invokeAsyncFunction { promise, tokioAsyncContext in
withNativeHandle { chatService in
internalRequest.withNativeHandle { request in
signal_chat_service_auth_send(promise, tokioAsyncContext, chatService, request, timeoutMillis)
}
}
}
@ -253,12 +241,10 @@ public class ChatService: NativeHandleOwner {
public func authenticatedSendAndDebug(_ request: Request) async throws -> (Response, DebugInfo) {
let internalRequest = try InternalRequest(request)
let timeoutMillis = request.timeoutMillis
let rawResponse: SignalFfiResponseAndDebugInfo = try await invokeAsyncFunction { promise in
self.tokioAsyncContext.withNativeHandle { tokioAsyncContext in
withNativeHandle { chatService in
internalRequest.withNativeHandle { request in
signal_chat_service_auth_send_and_debug(promise, tokioAsyncContext, chatService, request, timeoutMillis)
}
let rawResponse: SignalFfiResponseAndDebugInfo = try await self.tokioAsyncContext.invokeAsyncFunction { promise, tokioAsyncContext in
withNativeHandle { chatService in
internalRequest.withNativeHandle { request in
signal_chat_service_auth_send_and_debug(promise, tokioAsyncContext, chatService, request, timeoutMillis)
}
}
}

View File

@ -80,6 +80,9 @@ internal func checkError(_ error: SignalFfiErrorRef?) throws {
defer { signal_error_free(error) }
switch SignalErrorCode(errType) {
case SignalErrorCodeCancelled:
// Special case: don't use SignalError for this one.
throw CancellationError()
case SignalErrorCodeInvalidState:
throw SignalError.invalidState(errStr)
case SignalErrorCodeInternalError:

View File

@ -108,12 +108,10 @@ public class Net {
auth: Auth,
request: CdsiLookupRequest
) async throws -> CdsiLookup {
let handle: OpaquePointer = try await invokeAsyncFunction { promise in
self.asyncContext.withNativeHandle { asyncContext in
self.connectionManager.withNativeHandle { connectionManager in
request.withNativeHandle { request in
signal_cdsi_lookup_new(promise, asyncContext, connectionManager, auth.username, auth.password, request)
}
let handle: OpaquePointer = try await self.asyncContext.invokeAsyncFunction { promise, asyncContext in
self.connectionManager.withNativeHandle { connectionManager in
request.withNativeHandle { request in
signal_cdsi_lookup_new(promise, asyncContext, connectionManager, auth.username, auth.password, request)
}
}
}
@ -264,11 +262,9 @@ public class CdsiLookup {
/// `SignalError.networkError` for a network-level connectivity issue,
/// `SignalError.networkProtocolError` for a CDSI or attested connection protocol issue.
public func complete() async throws -> CdsiLookupResponse {
let response: SignalFfiCdsiLookupResponse = try await invokeAsyncFunction { promise in
self.asyncContext.withNativeHandle { asyncContext in
self.native.withNativeHandle { handle in
signal_cdsi_lookup_complete(promise, asyncContext, handle)
}
let response: SignalFfiCdsiLookupResponse = try await self.asyncContext.invokeAsyncFunction { promise, asyncContext in
self.native.withNativeHandle { handle in
signal_cdsi_lookup_complete(promise, asyncContext, handle)
}
}
@ -362,18 +358,6 @@ extension CdsiLookupResponseEntry: Equatable {
}
}
internal class TokioAsyncContext: NativeHandleOwner {
convenience init() {
var handle: OpaquePointer?
failOnError(signal_tokio_async_context_new(&handle))
self.init(owned: handle!)
}
override internal class func destroyNativeHandle(_ handle: OpaquePointer) -> SignalFfiErrorRef? {
signal_tokio_async_context_destroy(handle)
}
}
internal class ConnectionManager: NativeHandleOwner {
convenience init(env: Net.Environment, userAgent: String) {
var handle: OpaquePointer?

View File

@ -93,21 +93,19 @@ public class Svr3Client {
maxTries: UInt32,
auth: Auth
) async throws -> [UInt8] {
let output = try await invokeAsyncFunction { promise in
self.asyncContext.withNativeHandle { asyncContext in
self.connectionManager.withNativeHandle { connectionManager in
secret.withUnsafeBorrowedBuffer { secretBuffer in
signal_svr3_backup(
promise,
asyncContext,
connectionManager,
secretBuffer,
password,
maxTries,
auth.username,
auth.password
)
}
let output = try await self.asyncContext.invokeAsyncFunction { promise, asyncContext in
self.connectionManager.withNativeHandle { connectionManager in
secret.withUnsafeBorrowedBuffer { secretBuffer in
signal_svr3_backup(
promise,
asyncContext,
connectionManager,
secretBuffer,
password,
maxTries,
auth.username,
auth.password
)
}
}
}
@ -159,20 +157,18 @@ public class Svr3Client {
shareSet: some ContiguousBytes,
auth: Auth
) async throws -> [UInt8] {
let output = try await invokeAsyncFunction { promise in
self.asyncContext.withNativeHandle { asyncContext in
self.connectionManager.withNativeHandle { connectionManager in
shareSet.withUnsafeBorrowedBuffer { shareSetBuffer in
signal_svr3_restore(
promise,
asyncContext,
connectionManager,
password,
shareSetBuffer,
auth.username,
auth.password
)
}
let output = try await self.asyncContext.invokeAsyncFunction { promise, asyncContext in
self.connectionManager.withNativeHandle { connectionManager in
shareSet.withUnsafeBorrowedBuffer { shareSetBuffer in
signal_svr3_restore(
promise,
asyncContext,
connectionManager,
password,
shareSetBuffer,
auth.username,
auth.password
)
}
}
}

View File

@ -0,0 +1,127 @@
//
// Copyright 2024 Signal Messenger, LLC.
// SPDX-License-Identifier: AGPL-3.0-only
//
import Foundation
import SignalFfi
#if canImport(SignalCoreKit)
import SignalCoreKit
#endif
internal class TokioAsyncContext: NativeHandleOwner {
convenience init() {
var handle: OpaquePointer?
failOnError(signal_tokio_async_context_new(&handle))
self.init(owned: handle!)
}
override internal class func destroyNativeHandle(_ handle: OpaquePointer) -> SignalFfiErrorRef? {
signal_tokio_async_context_destroy(handle)
}
/// A thread-safe helper for translating Swift task cancellations into calls to
/// `signal_tokio_async_context_cancel`.
private class CancellationHandoffHelper {
enum State {
case initial
case started(SignalCancellationId)
case cancelled
}
// Emulates Rust's `Mutex<State>` (and the containing class is providing an `Arc`)
// Unfortunately, doing this in Swift requires a separate allocation for the lock today.
var state: State = .initial
var lock = NSLock()
let context: TokioAsyncContext
init(context: TokioAsyncContext) {
self.context = context
}
func setCancellationId(_ id: SignalCancellationId) {
// Ideally we would use NSLock.withLock here, but that's not available on Linux,
// which we still support for development and CI.
do {
self.lock.lock()
defer { self.lock.unlock() }
switch self.state {
case .initial:
self.state = .started(id)
fallthrough
case .started(_):
return
case .cancelled:
break
}
}
// If we didn't early-exit, we're already cancelled.
self.cancel(id)
}
func cancel() {
let cancelId: SignalCancellationId
// Ideally we would use NSLock.withLock here, but that's not available on Linux,
// which we still support for development and CI.
do {
self.lock.lock()
defer { self.lock.unlock() }
defer { state = .cancelled }
switch self.state {
case .started(let id):
cancelId = id
case .initial, .cancelled:
return
}
}
// If we didn't early-exit, the task has already started and we need to cancel it.
self.cancel(cancelId)
}
func cancel(_ id: SignalCancellationId) {
do {
try self.context.withNativeHandle {
try checkError(signal_tokio_async_context_cancel($0, id))
}
} catch {
#if canImport(SignalCoreKit)
Logger.warn("failed to cancel libsignal task \(id): \(error)")
#else
NSLog("failed to cancel libsignal task %ld: %@", id, "\(error)")
#endif
}
}
}
/// Provides a callback and context for calling Promise-based libsignal\_ffi functions, with cancellation supported.
///
/// Example:
///
/// ```
/// let result = try await asyncContext.invokeAsyncFunction { promise, runtime in
/// signal_do_async_work(promise, runtime, someInput, someOtherInput)
/// }
/// ```
internal func invokeAsyncFunction<Promise: PromiseStruct>(
_ body: (UnsafeMutablePointer<Promise>, OpaquePointer?) -> SignalFfiErrorRef?
) async throws -> Promise.Result {
let cancellationHelper = CancellationHandoffHelper(context: self)
return try await withTaskCancellationHandler(operation: {
try await LibSignalClient.invokeAsyncFunction({ promise in
withNativeHandle { handle in
body(promise, handle)
}
}, saveCancellationId: {
cancellationHelper.setCancellationId($0)
})
}, onCancel: {
cancellationHelper.cancel()
})
}
}

View File

@ -147,6 +147,7 @@ typedef enum {
SignalErrorCodeInvalidArgument = 5,
SignalErrorCodeInvalidType = 6,
SignalErrorCodeInvalidUtf8String = 7,
SignalErrorCodeCancelled = 8,
SignalErrorCodeProtobufError = 10,
SignalErrorCodeLegacyCiphertextVersion = 21,
SignalErrorCodeUnknownCiphertextVersion = 22,
@ -493,6 +494,8 @@ typedef struct {
size_t length;
} SignalBorrowedSliceOfBuffers;
typedef uint64_t SignalCancellationId;
/**
* A C callback used to report the results of Rust futures.
*
@ -505,6 +508,7 @@ typedef struct {
typedef struct {
void (*complete)(SignalFfiError *error, const SignalOwnedBuffer *result, const void *context);
const void *context;
SignalCancellationId cancellation_id;
} SignalCPromiseOwnedBufferOfc_uchar;
/**
@ -519,6 +523,7 @@ typedef struct {
typedef struct {
void (*complete)(SignalFfiError *error, const bool *result, const void *context);
const void *context;
SignalCancellationId cancellation_id;
} SignalCPromisebool;
typedef struct {
@ -540,6 +545,7 @@ typedef struct {
typedef struct {
void (*complete)(SignalFfiError *error, const SignalFfiChatServiceDebugInfo *result, const void *context);
const void *context;
SignalCancellationId cancellation_id;
} SignalCPromiseFfiChatServiceDebugInfo;
typedef struct {
@ -561,6 +567,7 @@ typedef struct {
typedef struct {
void (*complete)(SignalFfiError *error, const SignalFfiChatResponse *result, const void *context);
const void *context;
SignalCancellationId cancellation_id;
} SignalCPromiseFfiChatResponse;
typedef struct {
@ -580,6 +587,7 @@ typedef struct {
typedef struct {
void (*complete)(SignalFfiError *error, const SignalFfiResponseAndDebugInfo *result, const void *context);
const void *context;
SignalCancellationId cancellation_id;
} SignalCPromiseFfiResponseAndDebugInfo;
/**
@ -594,6 +602,7 @@ typedef struct {
typedef struct {
void (*complete)(SignalFfiError *error, SignalCdsiLookup *const *result, const void *context);
const void *context;
SignalCancellationId cancellation_id;
} SignalCPromiseCdsiLookup;
typedef struct {
@ -613,6 +622,7 @@ typedef struct {
typedef struct {
void (*complete)(SignalFfiError *error, const SignalFfiCdsiLookupResponse *result, const void *context);
const void *context;
SignalCancellationId cancellation_id;
} SignalCPromiseFfiCdsiLookupResponse;
typedef SignalBytestringArray SignalStringArray;
@ -641,6 +651,7 @@ typedef SignalInputStream SignalSyncInputStream;
typedef struct {
void (*complete)(SignalFfiError *error, const int32_t *result, const void *context);
const void *context;
SignalCancellationId cancellation_id;
} SignalCPromisei32;
/**
@ -655,6 +666,7 @@ typedef struct {
typedef struct {
void (*complete)(SignalFfiError *error, SignalTestingHandleType *const *result, const void *context);
const void *context;
SignalCancellationId cancellation_id;
} SignalCPromiseTestingHandleType;
/**
@ -669,6 +681,7 @@ typedef struct {
typedef struct {
void (*complete)(SignalFfiError *error, SignalOtherTestingHandleType *const *result, const void *context);
const void *context;
SignalCancellationId cancellation_id;
} SignalCPromiseOtherTestingHandleType;
/**
@ -683,6 +696,7 @@ typedef struct {
typedef struct {
void (*complete)(SignalFfiError *error, const void *const *result, const void *context);
const void *context;
SignalCancellationId cancellation_id;
} SignalCPromiseRawPointer;
typedef uint8_t SignalRandomnessBytes[SignalRANDOMNESS_LEN];
@ -1527,6 +1541,8 @@ SignalFfiError *signal_tokio_async_context_destroy(SignalTokioAsyncContext *p);
SignalFfiError *signal_tokio_async_context_new(SignalTokioAsyncContext **out);
SignalFfiError *signal_tokio_async_context_cancel(const SignalTokioAsyncContext *context, uint64_t raw_cancellation_id);
SignalFfiError *signal_pin_hash_destroy(SignalPinHash *p);
SignalFfiError *signal_pin_hash_clone(SignalPinHash **new_obj, const SignalPinHash *obj);
@ -1685,6 +1701,8 @@ SignalFfiError *signal_testing_process_bytestring_array(SignalBytestringArray *o
SignalFfiError *signal_testing_cdsi_lookup_response_convert(SignalCPromiseFfiCdsiLookupResponse *promise, const SignalTokioAsyncContext *async_runtime);
SignalFfiError *signal_testing_only_completes_by_cancellation(SignalCPromisebool *promise, const SignalTokioAsyncContext *async_runtime);
SignalFfiError *signal_testing_cdsi_lookup_error_convert(const char *error_description);
SignalFfiError *signal_testing_chat_service_error_convert(void);

View File

@ -18,7 +18,7 @@ extension SignalCPromiseOtherTestingHandleType: PromiseStruct {
public typealias Result = OpaquePointer
}
final class AsyncTests: XCTestCase {
final class AsyncTests: TestCaseBase {
func testSuccess() async throws {
let result: Int32 = try await invokeAsyncFunction {
signal_testing_future_success($0, OpaquePointer(bitPattern: -1), 21)
@ -65,6 +65,48 @@ final class AsyncTests: XCTestCase {
)
}
}
func testTokioCancellation() async throws {
let asyncContext = TokioAsyncContext()
// We can replace this with AsyncStream.makeStream(...) when we update our builder.
var _continuation: AsyncStream<Int>.Continuation!
let completionStream = AsyncStream<Int> { _continuation = $0 }
let continuation = _continuation!
let makeTask = { (id: Int) in
Task {
defer {
// Do this unconditionally so that the outer test procedure doesn't get stuck.
continuation.yield(id)
}
do {
_ = try await asyncContext.invokeAsyncFunction { promise, asyncContext in
signal_testing_only_completes_by_cancellation(promise, asyncContext)
}
} catch is CancellationError {
// Okay, expected.
} catch {
XCTFail("incorrect error: \(error)")
}
}
}
let task1 = makeTask(1)
let task2 = makeTask(2)
var completionIter = completionStream.makeAsyncIterator()
// Complete the tasks in opposite order of starting them,
// to make it less likely to get this result by accident.
// This is not a rigorous test, only a simple exercise of the feature.
task2.cancel()
let firstCompletionId = await completionIter.next()
XCTAssertEqual(firstCompletionId, 2)
task1.cancel()
let secondCompletionId = await completionIter.next()
XCTAssertEqual(secondCompletionId, 1)
}
}
#endif

View File

@ -23,10 +23,8 @@ final class NetTests: XCTestCase {
let asyncContext = TokioAsyncContext()
let output: SignalFfiCdsiLookupResponse = try await invokeAsyncFunction { promise in
asyncContext.withNativeHandle { asyncContext in
signal_testing_cdsi_lookup_response_convert(promise, asyncContext)
}
let output: SignalFfiCdsiLookupResponse = try await asyncContext.invokeAsyncFunction { promise, asyncContext in
signal_testing_cdsi_lookup_response_convert(promise, asyncContext)
}
XCTAssertEqual(output.debug_permits_used, 123)