mirror of
https://github.com/signalapp/libsignal.git
synced 2024-09-19 19:42:19 +02:00
ffi: Expose cancellation to Swift
This commit is contained in:
parent
6d3c192208
commit
7dc63b99af
@ -635,6 +635,7 @@ public final class Native {
|
|||||||
public static native CompletableFuture<Integer> TESTING_FutureSuccess(long asyncRuntime, int input);
|
public static native CompletableFuture<Integer> TESTING_FutureSuccess(long asyncRuntime, int input);
|
||||||
public static native CompletableFuture<Void> TESTING_FutureThrowsCustomErrorType(long asyncRuntime);
|
public static native CompletableFuture<Void> TESTING_FutureThrowsCustomErrorType(long asyncRuntime);
|
||||||
public static native void TESTING_NonSuspendingBackgroundThreadRuntime_Destroy(long handle);
|
public static native void TESTING_NonSuspendingBackgroundThreadRuntime_Destroy(long handle);
|
||||||
|
public static native CompletableFuture TESTING_OnlyCompletesByCancellation(long asyncRuntime);
|
||||||
public static native String TESTING_OtherTestingHandleType_getValue(long handle);
|
public static native String TESTING_OtherTestingHandleType_getValue(long handle);
|
||||||
public static native void TESTING_PanicInBodyAsync(Object input);
|
public static native void TESTING_PanicInBodyAsync(Object input);
|
||||||
public static native CompletableFuture TESTING_PanicInBodyIo(long asyncRuntime, Object input);
|
public static native CompletableFuture TESTING_PanicInBodyIo(long asyncRuntime, Object input);
|
||||||
@ -655,6 +656,7 @@ public final class Native {
|
|||||||
public static native void TestingHandleType_Destroy(long handle);
|
public static native void TestingHandleType_Destroy(long handle);
|
||||||
|
|
||||||
public static native void TokioAsyncContext_Destroy(long handle);
|
public static native void TokioAsyncContext_Destroy(long handle);
|
||||||
|
public static native void TokioAsyncContext_cancel(long context, long rawCancellationId);
|
||||||
public static native long TokioAsyncContext_new();
|
public static native long TokioAsyncContext_new();
|
||||||
|
|
||||||
public static native long UnidentifiedSenderMessageContent_Deserialize(byte[] data) throws Exception;
|
public static native long UnidentifiedSenderMessageContent_Deserialize(byte[] data) throws Exception;
|
||||||
|
2
node/Native.d.ts
vendored
2
node/Native.d.ts
vendored
@ -490,6 +490,7 @@ export function TESTING_FutureProducesOtherPointerType(asyncRuntime: Wrapper<Non
|
|||||||
export function TESTING_FutureProducesPointerType(asyncRuntime: Wrapper<NonSuspendingBackgroundThreadRuntime>, input: number): Promise<TestingHandleType>;
|
export function TESTING_FutureProducesPointerType(asyncRuntime: Wrapper<NonSuspendingBackgroundThreadRuntime>, input: number): Promise<TestingHandleType>;
|
||||||
export function TESTING_FutureSuccess(asyncRuntime: Wrapper<NonSuspendingBackgroundThreadRuntime>, input: number): Promise<number>;
|
export function TESTING_FutureSuccess(asyncRuntime: Wrapper<NonSuspendingBackgroundThreadRuntime>, input: number): Promise<number>;
|
||||||
export function TESTING_NonSuspendingBackgroundThreadRuntime_New(): NonSuspendingBackgroundThreadRuntime;
|
export function TESTING_NonSuspendingBackgroundThreadRuntime_New(): NonSuspendingBackgroundThreadRuntime;
|
||||||
|
export function TESTING_OnlyCompletesByCancellation(asyncRuntime: Wrapper<TokioAsyncContext>): Promise<void>;
|
||||||
export function TESTING_OtherTestingHandleType_getValue(handle: Wrapper<OtherTestingHandleType>): string;
|
export function TESTING_OtherTestingHandleType_getValue(handle: Wrapper<OtherTestingHandleType>): string;
|
||||||
export function TESTING_PanicInBodyAsync(_input: null): Promise<void>;
|
export function TESTING_PanicInBodyAsync(_input: null): Promise<void>;
|
||||||
export function TESTING_PanicInBodyIo(asyncRuntime: Wrapper<NonSuspendingBackgroundThreadRuntime>, _input: null): Promise<void>;
|
export function TESTING_PanicInBodyIo(asyncRuntime: Wrapper<NonSuspendingBackgroundThreadRuntime>, _input: null): Promise<void>;
|
||||||
@ -506,6 +507,7 @@ export function TESTING_PanicOnReturnSync(_needsCleanup: null): null;
|
|||||||
export function TESTING_ProcessBytestringArray(input: Buffer[]): Buffer[];
|
export function TESTING_ProcessBytestringArray(input: Buffer[]): Buffer[];
|
||||||
export function TESTING_ReturnStringArray(): string[];
|
export function TESTING_ReturnStringArray(): string[];
|
||||||
export function TESTING_TestingHandleType_getValue(handle: Wrapper<TestingHandleType>): number;
|
export function TESTING_TestingHandleType_getValue(handle: Wrapper<TestingHandleType>): number;
|
||||||
|
export function TokioAsyncContext_cancel(context: Wrapper<TokioAsyncContext>, rawCancellationId: bigint): void;
|
||||||
export function TokioAsyncContext_new(): TokioAsyncContext;
|
export function TokioAsyncContext_new(): TokioAsyncContext;
|
||||||
export function UnidentifiedSenderMessageContent_Deserialize(data: Buffer): UnidentifiedSenderMessageContent;
|
export function UnidentifiedSenderMessageContent_Deserialize(data: Buffer): UnidentifiedSenderMessageContent;
|
||||||
export function UnidentifiedSenderMessageContent_GetContentHint(m: Wrapper<UnidentifiedSenderMessageContent>): number;
|
export function UnidentifiedSenderMessageContent_GetContentHint(m: Wrapper<UnidentifiedSenderMessageContent>): number;
|
||||||
|
@ -64,6 +64,8 @@ renaming_overrides_prefixing = true
|
|||||||
"FfiOptionalServiceIdFixedWidthBinaryBytes" = "SignalOptionalServiceIdFixedWidthBinaryBytes"
|
"FfiOptionalServiceIdFixedWidthBinaryBytes" = "SignalOptionalServiceIdFixedWidthBinaryBytes"
|
||||||
"CPromisec_void" = "SignalCPromiseRawPointer"
|
"CPromisec_void" = "SignalCPromiseRawPointer"
|
||||||
|
|
||||||
|
"RawCancellationId" = "SignalCancellationId"
|
||||||
|
|
||||||
# Avoid double-prefixing these
|
# Avoid double-prefixing these
|
||||||
"SignalFfiError" = "SignalFfiError"
|
"SignalFfiError" = "SignalFfiError"
|
||||||
"SignalErrorCode" = "SignalErrorCode"
|
"SignalErrorCode" = "SignalErrorCode"
|
||||||
|
@ -25,6 +25,7 @@ pub enum SignalErrorCode {
|
|||||||
InvalidArgument = 5,
|
InvalidArgument = 5,
|
||||||
InvalidType = 6,
|
InvalidType = 6,
|
||||||
InvalidUtf8String = 7,
|
InvalidUtf8String = 7,
|
||||||
|
Cancelled = 8,
|
||||||
|
|
||||||
ProtobufError = 10,
|
ProtobufError = 10,
|
||||||
|
|
||||||
@ -105,6 +106,8 @@ impl From<&SignalFfiError> for SignalErrorCode {
|
|||||||
|
|
||||||
SignalFfiError::InvalidUtf8String => SignalErrorCode::InvalidUtf8String,
|
SignalFfiError::InvalidUtf8String => SignalErrorCode::InvalidUtf8String,
|
||||||
|
|
||||||
|
SignalFfiError::Cancelled => SignalErrorCode::Cancelled,
|
||||||
|
|
||||||
SignalFfiError::Signal(SignalProtocolError::InvalidProtobufEncoding) => {
|
SignalFfiError::Signal(SignalProtocolError::InvalidProtobufEncoding) => {
|
||||||
SignalErrorCode::ProtobufError
|
SignalErrorCode::ProtobufError
|
||||||
}
|
}
|
||||||
|
@ -167,10 +167,16 @@ fn bridge_io_body(
|
|||||||
|__cancel| async move {
|
|__cancel| async move {
|
||||||
let __future = ffi::catch_unwind(std::panic::AssertUnwindSafe(async move {
|
let __future = ffi::catch_unwind(std::panic::AssertUnwindSafe(async move {
|
||||||
#(#input_loading)*
|
#(#input_loading)*
|
||||||
let __result = #orig_name(#(#input_names),*).await;
|
::tokio::select! {
|
||||||
// If the original function can't fail, wrap the result in Ok for uniformity.
|
__result = #orig_name(#(#input_names),*) => {
|
||||||
// See TransformHelper::ok_if_needed.
|
// If the original function can't fail, wrap the result in Ok for uniformity.
|
||||||
Ok(TransformHelper(__result).ok_if_needed()?.0)
|
// See TransformHelper::ok_if_needed.
|
||||||
|
Ok(TransformHelper(__result).ok_if_needed()?.0)
|
||||||
|
}
|
||||||
|
_ = __cancel => {
|
||||||
|
Err(ffi::SignalFfiError::Cancelled)
|
||||||
|
}
|
||||||
|
}
|
||||||
}));
|
}));
|
||||||
ffi::FutureResultReporter::new(__future.await)
|
ffi::FutureResultReporter::new(__future.await)
|
||||||
}
|
}
|
||||||
|
@ -54,6 +54,7 @@ pub enum SignalFfiError {
|
|||||||
NullPointer,
|
NullPointer,
|
||||||
InvalidUtf8String,
|
InvalidUtf8String,
|
||||||
InvalidArgument(String),
|
InvalidArgument(String),
|
||||||
|
Cancelled,
|
||||||
InternalError(String),
|
InternalError(String),
|
||||||
UnexpectedPanic(std::boxed::Box<dyn std::any::Any + std::marker::Send>),
|
UnexpectedPanic(std::boxed::Box<dyn std::any::Any + std::marker::Send>),
|
||||||
}
|
}
|
||||||
@ -102,6 +103,7 @@ impl fmt::Display for SignalFfiError {
|
|||||||
SignalFfiError::NullPointer => write!(f, "null pointer"),
|
SignalFfiError::NullPointer => write!(f, "null pointer"),
|
||||||
SignalFfiError::InvalidUtf8String => write!(f, "invalid UTF8 string"),
|
SignalFfiError::InvalidUtf8String => write!(f, "invalid UTF8 string"),
|
||||||
SignalFfiError::InvalidArgument(msg) => write!(f, "invalid argument: {msg}"),
|
SignalFfiError::InvalidArgument(msg) => write!(f, "invalid argument: {msg}"),
|
||||||
|
SignalFfiError::Cancelled => write!(f, "cancelled"),
|
||||||
SignalFfiError::InternalError(msg) => write!(f, "internal error: {msg}"),
|
SignalFfiError::InternalError(msg) => write!(f, "internal error: {msg}"),
|
||||||
SignalFfiError::UnexpectedPanic(e) => {
|
SignalFfiError::UnexpectedPanic(e) => {
|
||||||
write!(f, "unexpected panic: {}", describe_panic(e))
|
write!(f, "unexpected panic: {}", describe_panic(e))
|
||||||
|
@ -10,6 +10,8 @@ use futures_util::{FutureExt, TryFutureExt};
|
|||||||
|
|
||||||
use std::future::Future;
|
use std::future::Future;
|
||||||
|
|
||||||
|
pub type RawCancellationId = u64;
|
||||||
|
|
||||||
/// A C callback used to report the results of Rust futures.
|
/// A C callback used to report the results of Rust futures.
|
||||||
///
|
///
|
||||||
/// cbindgen will produce independent C types like `SignalCPromisei32` and
|
/// cbindgen will produce independent C types like `SignalCPromisei32` and
|
||||||
@ -26,6 +28,7 @@ pub struct CPromise<T> {
|
|||||||
context: *const std::ffi::c_void,
|
context: *const std::ffi::c_void,
|
||||||
),
|
),
|
||||||
context: *const std::ffi::c_void,
|
context: *const std::ffi::c_void,
|
||||||
|
cancellation_id: RawCancellationId,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Keeps track of the information necessary to report a promise result back to C.
|
/// Keeps track of the information necessary to report a promise result back to C.
|
||||||
@ -121,7 +124,8 @@ pub fn run_future_on_runtime<R, F, O>(
|
|||||||
O: ResultTypeInfo + 'static,
|
O: ResultTypeInfo + 'static,
|
||||||
{
|
{
|
||||||
let completion = PromiseCompleter { promise: *promise };
|
let completion = PromiseCompleter { promise: *promise };
|
||||||
runtime.run_future(future, completion);
|
let cancellation_id = runtime.run_future(future, completion);
|
||||||
|
promise.cancellation_id = cancellation_id.into();
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Catches panics that occur in `future` and converts them to [`SignalFfiError::UnexpectedPanic`].
|
/// Catches panics that occur in `future` and converts them to [`SignalFfiError::UnexpectedPanic`].
|
||||||
|
@ -38,6 +38,11 @@ fn TokioAsyncContext_new() -> TokioAsyncContext {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[bridge_fn]
|
||||||
|
fn TokioAsyncContext_cancel(context: &TokioAsyncContext, raw_cancellation_id: u64) {
|
||||||
|
context.cancel(raw_cancellation_id.into())
|
||||||
|
}
|
||||||
|
|
||||||
pub struct TokioContextCancellation(tokio::sync::oneshot::Receiver<()>);
|
pub struct TokioContextCancellation(tokio::sync::oneshot::Receiver<()>);
|
||||||
|
|
||||||
impl Future for TokioContextCancellation {
|
impl Future for TokioContextCancellation {
|
||||||
@ -52,6 +57,12 @@ impl Future for TokioContextCancellation {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Not ideal! tokio doesn't promise that a oneshot::Receiver is in fact panic-safe.
|
||||||
|
// But its interior mutable state is only modified by the Receiver while it's being polled,
|
||||||
|
// and that means a panic would have to happen inside Receiver itself to cause a problem.
|
||||||
|
// Combined with our payload type being (), it's unlikely this can happen in practice.
|
||||||
|
impl std::panic::UnwindSafe for TokioContextCancellation {}
|
||||||
|
|
||||||
impl AsyncRuntimeBase for TokioAsyncContext {
|
impl AsyncRuntimeBase for TokioAsyncContext {
|
||||||
fn cancel(&self, cancellation_token: CancellationId) {
|
fn cancel(&self, cancellation_token: CancellationId) {
|
||||||
if cancellation_token == CancellationId::NotSupported {
|
if cancellation_token == CancellationId::NotSupported {
|
||||||
|
@ -49,6 +49,11 @@ async fn TESTING_CdsiLookupResponseConvert() -> LookupResponse {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[bridge_io(TokioAsyncContext)]
|
||||||
|
async fn TESTING_OnlyCompletesByCancellation() {
|
||||||
|
std::future::pending::<()>().await
|
||||||
|
}
|
||||||
|
|
||||||
#[repr(u8)]
|
#[repr(u8)]
|
||||||
#[derive(Copy, Clone, strum::EnumString)]
|
#[derive(Copy, Clone, strum::EnumString)]
|
||||||
enum TestingCdsiLookupError {
|
enum TestingCdsiLookupError {
|
||||||
|
@ -21,5 +21,7 @@ file_header:
|
|||||||
inclusive_language:
|
inclusive_language:
|
||||||
override_allowed_terms:
|
override_allowed_terms:
|
||||||
- master
|
- master
|
||||||
|
nesting:
|
||||||
|
type_level: 2
|
||||||
excluded:
|
excluded:
|
||||||
- .build/**
|
- .build/**
|
||||||
|
@ -3,6 +3,7 @@
|
|||||||
// SPDX-License-Identifier: AGPL-3.0-only
|
// SPDX-License-Identifier: AGPL-3.0-only
|
||||||
//
|
//
|
||||||
|
|
||||||
|
import Foundation
|
||||||
import SignalFfi
|
import SignalFfi
|
||||||
|
|
||||||
/// Used to check types for values produced asynchronously by Rust.
|
/// Used to check types for values produced asynchronously by Rust.
|
||||||
@ -15,6 +16,12 @@ import SignalFfi
|
|||||||
/// Note that implementing this is **unchecked;** make sure you match up the types correctly!
|
/// Note that implementing this is **unchecked;** make sure you match up the types correctly!
|
||||||
internal protocol PromiseStruct {
|
internal protocol PromiseStruct {
|
||||||
associatedtype Result
|
associatedtype Result
|
||||||
|
|
||||||
|
// We can't declare the 'complete' callback without an associated type,
|
||||||
|
// and that associated type won't get inferred for an imported struct (not sure why).
|
||||||
|
// So we'd have to write out the callback type for every conformer.
|
||||||
|
var context: UnsafeRawPointer! { get }
|
||||||
|
var cancellation_id: SignalCancellationId { get }
|
||||||
}
|
}
|
||||||
|
|
||||||
extension SignalCPromisebool: PromiseStruct {
|
extension SignalCPromisebool: PromiseStruct {
|
||||||
@ -94,7 +101,7 @@ private class Completer<Promise: PromiseStruct>: CompleterBase {
|
|||||||
///
|
///
|
||||||
/// This retains `self`, to be used as the C context pointer for the callback.
|
/// This retains `self`, to be used as the C context pointer for the callback.
|
||||||
/// You must ensure that either the callback is called, or the result is passed to
|
/// You must ensure that either the callback is called, or the result is passed to
|
||||||
/// ``destroyUncompletedPromiseStruct(_:)``.
|
/// ``cleanUpUncompletedPromiseStruct(_:)``.
|
||||||
func makePromiseStruct() -> Promise {
|
func makePromiseStruct() -> Promise {
|
||||||
typealias RawPromiseCallback = @convention(c) (_ error: SignalFfiErrorRef?, _ value: UnsafeRawPointer?, _ context: UnsafeRawPointer?) -> Void
|
typealias RawPromiseCallback = @convention(c) (_ error: SignalFfiErrorRef?, _ value: UnsafeRawPointer?, _ context: UnsafeRawPointer?) -> Void
|
||||||
let completeOpaque: RawPromiseCallback = { error, value, context in
|
let completeOpaque: RawPromiseCallback = { error, value, context in
|
||||||
@ -108,7 +115,7 @@ private class Completer<Promise: PromiseStruct>: CompleterBase {
|
|||||||
// because of how `self.completeUnsafe` is initialized.
|
// because of how `self.completeUnsafe` is initialized.
|
||||||
// So first we build a promise struct---it doesn't matter which one---by reinterpreting the callback...
|
// So first we build a promise struct---it doesn't matter which one---by reinterpreting the callback...
|
||||||
typealias RawPointerPromiseCallback = @convention(c) (_ error: SignalFfiErrorRef?, _ value: UnsafePointer<UnsafeRawPointer?>?, _ context: UnsafeRawPointer?) -> Void
|
typealias RawPointerPromiseCallback = @convention(c) (_ error: SignalFfiErrorRef?, _ value: UnsafePointer<UnsafeRawPointer?>?, _ context: UnsafeRawPointer?) -> Void
|
||||||
let rawPromiseStruct = SignalCPromiseRawPointer(complete: unsafeBitCast(completeOpaque, to: RawPointerPromiseCallback.self), context: Unmanaged.passRetained(self).toOpaque())
|
let rawPromiseStruct = SignalCPromiseRawPointer(complete: unsafeBitCast(completeOpaque, to: RawPointerPromiseCallback.self), context: Unmanaged.passRetained(self).toOpaque(), cancellation_id: 0)
|
||||||
|
|
||||||
// ...And then we reinterpret the entire struct, because all promise structs *also* have the same layout.
|
// ...And then we reinterpret the entire struct, because all promise structs *also* have the same layout.
|
||||||
// (Which we at least check a little bit here.)
|
// (Which we at least check a little bit here.)
|
||||||
@ -119,11 +126,8 @@ private class Completer<Promise: PromiseStruct>: CompleterBase {
|
|||||||
return unsafeBitCast(rawPromiseStruct, to: Promise.self)
|
return unsafeBitCast(rawPromiseStruct, to: Promise.self)
|
||||||
}
|
}
|
||||||
|
|
||||||
func destroyUncompletedPromiseStruct(_ promiseStruct: Promise) {
|
func cleanUpUncompletedPromiseStruct(_ promiseStruct: Promise) {
|
||||||
// Double-check that all promise structs have the same layout, then reverse what we did above.
|
Unmanaged<CompleterBase>.fromOpaque(promiseStruct.context!).release()
|
||||||
precondition(MemoryLayout<SignalCPromiseRawPointer>.size == MemoryLayout<Promise>.size)
|
|
||||||
let rawPromiseStruct = unsafeBitCast(promiseStruct, to: SignalCPromiseRawPointer.self)
|
|
||||||
Unmanaged<CompleterBase>.fromOpaque(rawPromiseStruct.context!).release()
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -136,8 +140,12 @@ private class Completer<Promise: PromiseStruct>: CompleterBase {
|
|||||||
/// signal_do_async_work($0, someInput, someOtherInput)
|
/// signal_do_async_work($0, someInput, someOtherInput)
|
||||||
/// }
|
/// }
|
||||||
/// ```
|
/// ```
|
||||||
|
///
|
||||||
|
/// Prefer ``TokioAsyncContext/invokeAsyncFunction(_:)`` if using a TokioAsyncContext;
|
||||||
|
/// that method supports cancellation.
|
||||||
internal func invokeAsyncFunction<Promise: PromiseStruct>(
|
internal func invokeAsyncFunction<Promise: PromiseStruct>(
|
||||||
_ body: (UnsafeMutablePointer<Promise>) -> SignalFfiErrorRef?
|
_ body: (UnsafeMutablePointer<Promise>) -> SignalFfiErrorRef?,
|
||||||
|
saveCancellationId: (SignalCancellationId) -> Void = { _ in }
|
||||||
) async throws -> Promise.Result {
|
) async throws -> Promise.Result {
|
||||||
try await withCheckedThrowingContinuation { continuation in
|
try await withCheckedThrowingContinuation { continuation in
|
||||||
let completer = Completer<Promise>(continuation: continuation)
|
let completer = Completer<Promise>(continuation: continuation)
|
||||||
@ -145,9 +153,10 @@ internal func invokeAsyncFunction<Promise: PromiseStruct>(
|
|||||||
let startResult = body(&promiseStruct)
|
let startResult = body(&promiseStruct)
|
||||||
if let error = startResult {
|
if let error = startResult {
|
||||||
// Our completion callback is never going to get called, so we need to balance the `passRetained` above.
|
// Our completion callback is never going to get called, so we need to balance the `passRetained` above.
|
||||||
completer.destroyUncompletedPromiseStruct(promiseStruct)
|
completer.cleanUpUncompletedPromiseStruct(promiseStruct)
|
||||||
completer.completeUnsafe(error, nil)
|
completer.completeUnsafe(error, nil)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
saveCancellationId(promiseStruct.cancellation_id)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -139,11 +139,9 @@ public class ChatService: NativeHandleOwner {
|
|||||||
/// Calling this method will result in starting to accept incoming requests from the Chat Service.
|
/// Calling this method will result in starting to accept incoming requests from the Chat Service.
|
||||||
@discardableResult
|
@discardableResult
|
||||||
public func connectAuthenticated() async throws -> DebugInfo {
|
public func connectAuthenticated() async throws -> DebugInfo {
|
||||||
let rawDebugInfo = try await invokeAsyncFunction { promise in
|
let rawDebugInfo = try await self.tokioAsyncContext.invokeAsyncFunction { promise, tokioAsyncContext in
|
||||||
self.tokioAsyncContext.withNativeHandle { tokioAsyncContext in
|
withNativeHandle { chatService in
|
||||||
withNativeHandle { chatService in
|
signal_chat_service_connect_auth(promise, tokioAsyncContext, chatService)
|
||||||
signal_chat_service_connect_auth(promise, tokioAsyncContext, chatService)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return DebugInfo(consuming: rawDebugInfo)
|
return DebugInfo(consuming: rawDebugInfo)
|
||||||
@ -155,11 +153,9 @@ public class ChatService: NativeHandleOwner {
|
|||||||
/// reconnect attempt will be made.
|
/// reconnect attempt will be made.
|
||||||
@discardableResult
|
@discardableResult
|
||||||
public func connectUnauthenticated() async throws -> DebugInfo {
|
public func connectUnauthenticated() async throws -> DebugInfo {
|
||||||
let rawDebugInfo = try await invokeAsyncFunction { promise in
|
let rawDebugInfo = try await self.tokioAsyncContext.invokeAsyncFunction { promise, tokioAsyncContext in
|
||||||
self.tokioAsyncContext.withNativeHandle { tokioAsyncContext in
|
withNativeHandle { chatService in
|
||||||
withNativeHandle { chatService in
|
signal_chat_service_connect_unauth(promise, tokioAsyncContext, chatService)
|
||||||
signal_chat_service_connect_unauth(promise, tokioAsyncContext, chatService)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return DebugInfo(consuming: rawDebugInfo)
|
return DebugInfo(consuming: rawDebugInfo)
|
||||||
@ -174,11 +170,9 @@ public class ChatService: NativeHandleOwner {
|
|||||||
///
|
///
|
||||||
/// Returns when the disconnection is complete.
|
/// Returns when the disconnection is complete.
|
||||||
public func disconnect() async throws {
|
public func disconnect() async throws {
|
||||||
_ = try await invokeAsyncFunction { promise in
|
_ = try await self.tokioAsyncContext.invokeAsyncFunction { promise, tokioAsyncContext in
|
||||||
self.tokioAsyncContext.withNativeHandle { tokioAsyncContext in
|
withNativeHandle { chatService in
|
||||||
withNativeHandle { chatService in
|
signal_chat_service_disconnect(promise, tokioAsyncContext, chatService)
|
||||||
signal_chat_service_disconnect(promise, tokioAsyncContext, chatService)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -190,12 +184,10 @@ public class ChatService: NativeHandleOwner {
|
|||||||
public func unauthenticatedSend(_ request: Request) async throws -> Response {
|
public func unauthenticatedSend(_ request: Request) async throws -> Response {
|
||||||
let internalRequest = try InternalRequest(request)
|
let internalRequest = try InternalRequest(request)
|
||||||
let timeoutMillis = request.timeoutMillis
|
let timeoutMillis = request.timeoutMillis
|
||||||
let rawResponse: SignalFfiChatResponse = try await invokeAsyncFunction { promise in
|
let rawResponse: SignalFfiChatResponse = try await self.tokioAsyncContext.invokeAsyncFunction { promise, tokioAsyncContext in
|
||||||
self.tokioAsyncContext.withNativeHandle { tokioAsyncContext in
|
withNativeHandle { chatService in
|
||||||
withNativeHandle { chatService in
|
internalRequest.withNativeHandle { request in
|
||||||
internalRequest.withNativeHandle { request in
|
signal_chat_service_unauth_send(promise, tokioAsyncContext, chatService, request, timeoutMillis)
|
||||||
signal_chat_service_unauth_send(promise, tokioAsyncContext, chatService, request, timeoutMillis)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -212,12 +204,10 @@ public class ChatService: NativeHandleOwner {
|
|||||||
public func unauthenticatedSendAndDebug(_ request: Request) async throws -> (Response, DebugInfo) {
|
public func unauthenticatedSendAndDebug(_ request: Request) async throws -> (Response, DebugInfo) {
|
||||||
let internalRequest = try InternalRequest(request)
|
let internalRequest = try InternalRequest(request)
|
||||||
let timeoutMillis = request.timeoutMillis
|
let timeoutMillis = request.timeoutMillis
|
||||||
let rawResponse: SignalFfiResponseAndDebugInfo = try await invokeAsyncFunction { promise in
|
let rawResponse: SignalFfiResponseAndDebugInfo = try await self.tokioAsyncContext.invokeAsyncFunction { promise, tokioAsyncContext in
|
||||||
self.tokioAsyncContext.withNativeHandle { tokioAsyncContext in
|
withNativeHandle { chatService in
|
||||||
withNativeHandle { chatService in
|
internalRequest.withNativeHandle { request in
|
||||||
internalRequest.withNativeHandle { request in
|
signal_chat_service_unauth_send_and_debug(promise, tokioAsyncContext, chatService, request, timeoutMillis)
|
||||||
signal_chat_service_unauth_send_and_debug(promise, tokioAsyncContext, chatService, request, timeoutMillis)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -231,12 +221,10 @@ public class ChatService: NativeHandleOwner {
|
|||||||
public func authenticatedSend(_ request: Request) async throws -> Response {
|
public func authenticatedSend(_ request: Request) async throws -> Response {
|
||||||
let internalRequest = try InternalRequest(request)
|
let internalRequest = try InternalRequest(request)
|
||||||
let timeoutMillis = request.timeoutMillis
|
let timeoutMillis = request.timeoutMillis
|
||||||
let rawResponse: SignalFfiChatResponse = try await invokeAsyncFunction { promise in
|
let rawResponse: SignalFfiChatResponse = try await self.tokioAsyncContext.invokeAsyncFunction { promise, tokioAsyncContext in
|
||||||
self.tokioAsyncContext.withNativeHandle { tokioAsyncContext in
|
withNativeHandle { chatService in
|
||||||
withNativeHandle { chatService in
|
internalRequest.withNativeHandle { request in
|
||||||
internalRequest.withNativeHandle { request in
|
signal_chat_service_auth_send(promise, tokioAsyncContext, chatService, request, timeoutMillis)
|
||||||
signal_chat_service_auth_send(promise, tokioAsyncContext, chatService, request, timeoutMillis)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -253,12 +241,10 @@ public class ChatService: NativeHandleOwner {
|
|||||||
public func authenticatedSendAndDebug(_ request: Request) async throws -> (Response, DebugInfo) {
|
public func authenticatedSendAndDebug(_ request: Request) async throws -> (Response, DebugInfo) {
|
||||||
let internalRequest = try InternalRequest(request)
|
let internalRequest = try InternalRequest(request)
|
||||||
let timeoutMillis = request.timeoutMillis
|
let timeoutMillis = request.timeoutMillis
|
||||||
let rawResponse: SignalFfiResponseAndDebugInfo = try await invokeAsyncFunction { promise in
|
let rawResponse: SignalFfiResponseAndDebugInfo = try await self.tokioAsyncContext.invokeAsyncFunction { promise, tokioAsyncContext in
|
||||||
self.tokioAsyncContext.withNativeHandle { tokioAsyncContext in
|
withNativeHandle { chatService in
|
||||||
withNativeHandle { chatService in
|
internalRequest.withNativeHandle { request in
|
||||||
internalRequest.withNativeHandle { request in
|
signal_chat_service_auth_send_and_debug(promise, tokioAsyncContext, chatService, request, timeoutMillis)
|
||||||
signal_chat_service_auth_send_and_debug(promise, tokioAsyncContext, chatService, request, timeoutMillis)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -80,6 +80,9 @@ internal func checkError(_ error: SignalFfiErrorRef?) throws {
|
|||||||
defer { signal_error_free(error) }
|
defer { signal_error_free(error) }
|
||||||
|
|
||||||
switch SignalErrorCode(errType) {
|
switch SignalErrorCode(errType) {
|
||||||
|
case SignalErrorCodeCancelled:
|
||||||
|
// Special case: don't use SignalError for this one.
|
||||||
|
throw CancellationError()
|
||||||
case SignalErrorCodeInvalidState:
|
case SignalErrorCodeInvalidState:
|
||||||
throw SignalError.invalidState(errStr)
|
throw SignalError.invalidState(errStr)
|
||||||
case SignalErrorCodeInternalError:
|
case SignalErrorCodeInternalError:
|
||||||
|
@ -108,12 +108,10 @@ public class Net {
|
|||||||
auth: Auth,
|
auth: Auth,
|
||||||
request: CdsiLookupRequest
|
request: CdsiLookupRequest
|
||||||
) async throws -> CdsiLookup {
|
) async throws -> CdsiLookup {
|
||||||
let handle: OpaquePointer = try await invokeAsyncFunction { promise in
|
let handle: OpaquePointer = try await self.asyncContext.invokeAsyncFunction { promise, asyncContext in
|
||||||
self.asyncContext.withNativeHandle { asyncContext in
|
self.connectionManager.withNativeHandle { connectionManager in
|
||||||
self.connectionManager.withNativeHandle { connectionManager in
|
request.withNativeHandle { request in
|
||||||
request.withNativeHandle { request in
|
signal_cdsi_lookup_new(promise, asyncContext, connectionManager, auth.username, auth.password, request)
|
||||||
signal_cdsi_lookup_new(promise, asyncContext, connectionManager, auth.username, auth.password, request)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -264,11 +262,9 @@ public class CdsiLookup {
|
|||||||
/// `SignalError.networkError` for a network-level connectivity issue,
|
/// `SignalError.networkError` for a network-level connectivity issue,
|
||||||
/// `SignalError.networkProtocolError` for a CDSI or attested connection protocol issue.
|
/// `SignalError.networkProtocolError` for a CDSI or attested connection protocol issue.
|
||||||
public func complete() async throws -> CdsiLookupResponse {
|
public func complete() async throws -> CdsiLookupResponse {
|
||||||
let response: SignalFfiCdsiLookupResponse = try await invokeAsyncFunction { promise in
|
let response: SignalFfiCdsiLookupResponse = try await self.asyncContext.invokeAsyncFunction { promise, asyncContext in
|
||||||
self.asyncContext.withNativeHandle { asyncContext in
|
self.native.withNativeHandle { handle in
|
||||||
self.native.withNativeHandle { handle in
|
signal_cdsi_lookup_complete(promise, asyncContext, handle)
|
||||||
signal_cdsi_lookup_complete(promise, asyncContext, handle)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -362,18 +358,6 @@ extension CdsiLookupResponseEntry: Equatable {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
internal class TokioAsyncContext: NativeHandleOwner {
|
|
||||||
convenience init() {
|
|
||||||
var handle: OpaquePointer?
|
|
||||||
failOnError(signal_tokio_async_context_new(&handle))
|
|
||||||
self.init(owned: handle!)
|
|
||||||
}
|
|
||||||
|
|
||||||
override internal class func destroyNativeHandle(_ handle: OpaquePointer) -> SignalFfiErrorRef? {
|
|
||||||
signal_tokio_async_context_destroy(handle)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
internal class ConnectionManager: NativeHandleOwner {
|
internal class ConnectionManager: NativeHandleOwner {
|
||||||
convenience init(env: Net.Environment, userAgent: String) {
|
convenience init(env: Net.Environment, userAgent: String) {
|
||||||
var handle: OpaquePointer?
|
var handle: OpaquePointer?
|
||||||
|
@ -93,21 +93,19 @@ public class Svr3Client {
|
|||||||
maxTries: UInt32,
|
maxTries: UInt32,
|
||||||
auth: Auth
|
auth: Auth
|
||||||
) async throws -> [UInt8] {
|
) async throws -> [UInt8] {
|
||||||
let output = try await invokeAsyncFunction { promise in
|
let output = try await self.asyncContext.invokeAsyncFunction { promise, asyncContext in
|
||||||
self.asyncContext.withNativeHandle { asyncContext in
|
self.connectionManager.withNativeHandle { connectionManager in
|
||||||
self.connectionManager.withNativeHandle { connectionManager in
|
secret.withUnsafeBorrowedBuffer { secretBuffer in
|
||||||
secret.withUnsafeBorrowedBuffer { secretBuffer in
|
signal_svr3_backup(
|
||||||
signal_svr3_backup(
|
promise,
|
||||||
promise,
|
asyncContext,
|
||||||
asyncContext,
|
connectionManager,
|
||||||
connectionManager,
|
secretBuffer,
|
||||||
secretBuffer,
|
password,
|
||||||
password,
|
maxTries,
|
||||||
maxTries,
|
auth.username,
|
||||||
auth.username,
|
auth.password
|
||||||
auth.password
|
)
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -159,20 +157,18 @@ public class Svr3Client {
|
|||||||
shareSet: some ContiguousBytes,
|
shareSet: some ContiguousBytes,
|
||||||
auth: Auth
|
auth: Auth
|
||||||
) async throws -> [UInt8] {
|
) async throws -> [UInt8] {
|
||||||
let output = try await invokeAsyncFunction { promise in
|
let output = try await self.asyncContext.invokeAsyncFunction { promise, asyncContext in
|
||||||
self.asyncContext.withNativeHandle { asyncContext in
|
self.connectionManager.withNativeHandle { connectionManager in
|
||||||
self.connectionManager.withNativeHandle { connectionManager in
|
shareSet.withUnsafeBorrowedBuffer { shareSetBuffer in
|
||||||
shareSet.withUnsafeBorrowedBuffer { shareSetBuffer in
|
signal_svr3_restore(
|
||||||
signal_svr3_restore(
|
promise,
|
||||||
promise,
|
asyncContext,
|
||||||
asyncContext,
|
connectionManager,
|
||||||
connectionManager,
|
password,
|
||||||
password,
|
shareSetBuffer,
|
||||||
shareSetBuffer,
|
auth.username,
|
||||||
auth.username,
|
auth.password
|
||||||
auth.password
|
)
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
127
swift/Sources/LibSignalClient/TokioAsyncContext.swift
Normal file
127
swift/Sources/LibSignalClient/TokioAsyncContext.swift
Normal file
@ -0,0 +1,127 @@
|
|||||||
|
//
|
||||||
|
// Copyright 2024 Signal Messenger, LLC.
|
||||||
|
// SPDX-License-Identifier: AGPL-3.0-only
|
||||||
|
//
|
||||||
|
|
||||||
|
import Foundation
|
||||||
|
import SignalFfi
|
||||||
|
|
||||||
|
#if canImport(SignalCoreKit)
|
||||||
|
import SignalCoreKit
|
||||||
|
#endif
|
||||||
|
|
||||||
|
internal class TokioAsyncContext: NativeHandleOwner {
|
||||||
|
convenience init() {
|
||||||
|
var handle: OpaquePointer?
|
||||||
|
failOnError(signal_tokio_async_context_new(&handle))
|
||||||
|
self.init(owned: handle!)
|
||||||
|
}
|
||||||
|
|
||||||
|
override internal class func destroyNativeHandle(_ handle: OpaquePointer) -> SignalFfiErrorRef? {
|
||||||
|
signal_tokio_async_context_destroy(handle)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A thread-safe helper for translating Swift task cancellations into calls to
|
||||||
|
/// `signal_tokio_async_context_cancel`.
|
||||||
|
private class CancellationHandoffHelper {
|
||||||
|
enum State {
|
||||||
|
case initial
|
||||||
|
case started(SignalCancellationId)
|
||||||
|
case cancelled
|
||||||
|
}
|
||||||
|
|
||||||
|
// Emulates Rust's `Mutex<State>` (and the containing class is providing an `Arc`)
|
||||||
|
// Unfortunately, doing this in Swift requires a separate allocation for the lock today.
|
||||||
|
var state: State = .initial
|
||||||
|
var lock = NSLock()
|
||||||
|
|
||||||
|
let context: TokioAsyncContext
|
||||||
|
|
||||||
|
init(context: TokioAsyncContext) {
|
||||||
|
self.context = context
|
||||||
|
}
|
||||||
|
|
||||||
|
func setCancellationId(_ id: SignalCancellationId) {
|
||||||
|
// Ideally we would use NSLock.withLock here, but that's not available on Linux,
|
||||||
|
// which we still support for development and CI.
|
||||||
|
do {
|
||||||
|
self.lock.lock()
|
||||||
|
defer { self.lock.unlock() }
|
||||||
|
|
||||||
|
switch self.state {
|
||||||
|
case .initial:
|
||||||
|
self.state = .started(id)
|
||||||
|
fallthrough
|
||||||
|
case .started(_):
|
||||||
|
return
|
||||||
|
case .cancelled:
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we didn't early-exit, we're already cancelled.
|
||||||
|
self.cancel(id)
|
||||||
|
}
|
||||||
|
|
||||||
|
func cancel() {
|
||||||
|
let cancelId: SignalCancellationId
|
||||||
|
// Ideally we would use NSLock.withLock here, but that's not available on Linux,
|
||||||
|
// which we still support for development and CI.
|
||||||
|
do {
|
||||||
|
self.lock.lock()
|
||||||
|
defer { self.lock.unlock() }
|
||||||
|
|
||||||
|
defer { state = .cancelled }
|
||||||
|
switch self.state {
|
||||||
|
case .started(let id):
|
||||||
|
cancelId = id
|
||||||
|
case .initial, .cancelled:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we didn't early-exit, the task has already started and we need to cancel it.
|
||||||
|
self.cancel(cancelId)
|
||||||
|
}
|
||||||
|
|
||||||
|
func cancel(_ id: SignalCancellationId) {
|
||||||
|
do {
|
||||||
|
try self.context.withNativeHandle {
|
||||||
|
try checkError(signal_tokio_async_context_cancel($0, id))
|
||||||
|
}
|
||||||
|
} catch {
|
||||||
|
#if canImport(SignalCoreKit)
|
||||||
|
Logger.warn("failed to cancel libsignal task \(id): \(error)")
|
||||||
|
#else
|
||||||
|
NSLog("failed to cancel libsignal task %ld: %@", id, "\(error)")
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Provides a callback and context for calling Promise-based libsignal\_ffi functions, with cancellation supported.
|
||||||
|
///
|
||||||
|
/// Example:
|
||||||
|
///
|
||||||
|
/// ```
|
||||||
|
/// let result = try await asyncContext.invokeAsyncFunction { promise, runtime in
|
||||||
|
/// signal_do_async_work(promise, runtime, someInput, someOtherInput)
|
||||||
|
/// }
|
||||||
|
/// ```
|
||||||
|
internal func invokeAsyncFunction<Promise: PromiseStruct>(
|
||||||
|
_ body: (UnsafeMutablePointer<Promise>, OpaquePointer?) -> SignalFfiErrorRef?
|
||||||
|
) async throws -> Promise.Result {
|
||||||
|
let cancellationHelper = CancellationHandoffHelper(context: self)
|
||||||
|
return try await withTaskCancellationHandler(operation: {
|
||||||
|
try await LibSignalClient.invokeAsyncFunction({ promise in
|
||||||
|
withNativeHandle { handle in
|
||||||
|
body(promise, handle)
|
||||||
|
}
|
||||||
|
}, saveCancellationId: {
|
||||||
|
cancellationHelper.setCancellationId($0)
|
||||||
|
})
|
||||||
|
}, onCancel: {
|
||||||
|
cancellationHelper.cancel()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
@ -147,6 +147,7 @@ typedef enum {
|
|||||||
SignalErrorCodeInvalidArgument = 5,
|
SignalErrorCodeInvalidArgument = 5,
|
||||||
SignalErrorCodeInvalidType = 6,
|
SignalErrorCodeInvalidType = 6,
|
||||||
SignalErrorCodeInvalidUtf8String = 7,
|
SignalErrorCodeInvalidUtf8String = 7,
|
||||||
|
SignalErrorCodeCancelled = 8,
|
||||||
SignalErrorCodeProtobufError = 10,
|
SignalErrorCodeProtobufError = 10,
|
||||||
SignalErrorCodeLegacyCiphertextVersion = 21,
|
SignalErrorCodeLegacyCiphertextVersion = 21,
|
||||||
SignalErrorCodeUnknownCiphertextVersion = 22,
|
SignalErrorCodeUnknownCiphertextVersion = 22,
|
||||||
@ -493,6 +494,8 @@ typedef struct {
|
|||||||
size_t length;
|
size_t length;
|
||||||
} SignalBorrowedSliceOfBuffers;
|
} SignalBorrowedSliceOfBuffers;
|
||||||
|
|
||||||
|
typedef uint64_t SignalCancellationId;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A C callback used to report the results of Rust futures.
|
* A C callback used to report the results of Rust futures.
|
||||||
*
|
*
|
||||||
@ -505,6 +508,7 @@ typedef struct {
|
|||||||
typedef struct {
|
typedef struct {
|
||||||
void (*complete)(SignalFfiError *error, const SignalOwnedBuffer *result, const void *context);
|
void (*complete)(SignalFfiError *error, const SignalOwnedBuffer *result, const void *context);
|
||||||
const void *context;
|
const void *context;
|
||||||
|
SignalCancellationId cancellation_id;
|
||||||
} SignalCPromiseOwnedBufferOfc_uchar;
|
} SignalCPromiseOwnedBufferOfc_uchar;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -519,6 +523,7 @@ typedef struct {
|
|||||||
typedef struct {
|
typedef struct {
|
||||||
void (*complete)(SignalFfiError *error, const bool *result, const void *context);
|
void (*complete)(SignalFfiError *error, const bool *result, const void *context);
|
||||||
const void *context;
|
const void *context;
|
||||||
|
SignalCancellationId cancellation_id;
|
||||||
} SignalCPromisebool;
|
} SignalCPromisebool;
|
||||||
|
|
||||||
typedef struct {
|
typedef struct {
|
||||||
@ -540,6 +545,7 @@ typedef struct {
|
|||||||
typedef struct {
|
typedef struct {
|
||||||
void (*complete)(SignalFfiError *error, const SignalFfiChatServiceDebugInfo *result, const void *context);
|
void (*complete)(SignalFfiError *error, const SignalFfiChatServiceDebugInfo *result, const void *context);
|
||||||
const void *context;
|
const void *context;
|
||||||
|
SignalCancellationId cancellation_id;
|
||||||
} SignalCPromiseFfiChatServiceDebugInfo;
|
} SignalCPromiseFfiChatServiceDebugInfo;
|
||||||
|
|
||||||
typedef struct {
|
typedef struct {
|
||||||
@ -561,6 +567,7 @@ typedef struct {
|
|||||||
typedef struct {
|
typedef struct {
|
||||||
void (*complete)(SignalFfiError *error, const SignalFfiChatResponse *result, const void *context);
|
void (*complete)(SignalFfiError *error, const SignalFfiChatResponse *result, const void *context);
|
||||||
const void *context;
|
const void *context;
|
||||||
|
SignalCancellationId cancellation_id;
|
||||||
} SignalCPromiseFfiChatResponse;
|
} SignalCPromiseFfiChatResponse;
|
||||||
|
|
||||||
typedef struct {
|
typedef struct {
|
||||||
@ -580,6 +587,7 @@ typedef struct {
|
|||||||
typedef struct {
|
typedef struct {
|
||||||
void (*complete)(SignalFfiError *error, const SignalFfiResponseAndDebugInfo *result, const void *context);
|
void (*complete)(SignalFfiError *error, const SignalFfiResponseAndDebugInfo *result, const void *context);
|
||||||
const void *context;
|
const void *context;
|
||||||
|
SignalCancellationId cancellation_id;
|
||||||
} SignalCPromiseFfiResponseAndDebugInfo;
|
} SignalCPromiseFfiResponseAndDebugInfo;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -594,6 +602,7 @@ typedef struct {
|
|||||||
typedef struct {
|
typedef struct {
|
||||||
void (*complete)(SignalFfiError *error, SignalCdsiLookup *const *result, const void *context);
|
void (*complete)(SignalFfiError *error, SignalCdsiLookup *const *result, const void *context);
|
||||||
const void *context;
|
const void *context;
|
||||||
|
SignalCancellationId cancellation_id;
|
||||||
} SignalCPromiseCdsiLookup;
|
} SignalCPromiseCdsiLookup;
|
||||||
|
|
||||||
typedef struct {
|
typedef struct {
|
||||||
@ -613,6 +622,7 @@ typedef struct {
|
|||||||
typedef struct {
|
typedef struct {
|
||||||
void (*complete)(SignalFfiError *error, const SignalFfiCdsiLookupResponse *result, const void *context);
|
void (*complete)(SignalFfiError *error, const SignalFfiCdsiLookupResponse *result, const void *context);
|
||||||
const void *context;
|
const void *context;
|
||||||
|
SignalCancellationId cancellation_id;
|
||||||
} SignalCPromiseFfiCdsiLookupResponse;
|
} SignalCPromiseFfiCdsiLookupResponse;
|
||||||
|
|
||||||
typedef SignalBytestringArray SignalStringArray;
|
typedef SignalBytestringArray SignalStringArray;
|
||||||
@ -641,6 +651,7 @@ typedef SignalInputStream SignalSyncInputStream;
|
|||||||
typedef struct {
|
typedef struct {
|
||||||
void (*complete)(SignalFfiError *error, const int32_t *result, const void *context);
|
void (*complete)(SignalFfiError *error, const int32_t *result, const void *context);
|
||||||
const void *context;
|
const void *context;
|
||||||
|
SignalCancellationId cancellation_id;
|
||||||
} SignalCPromisei32;
|
} SignalCPromisei32;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -655,6 +666,7 @@ typedef struct {
|
|||||||
typedef struct {
|
typedef struct {
|
||||||
void (*complete)(SignalFfiError *error, SignalTestingHandleType *const *result, const void *context);
|
void (*complete)(SignalFfiError *error, SignalTestingHandleType *const *result, const void *context);
|
||||||
const void *context;
|
const void *context;
|
||||||
|
SignalCancellationId cancellation_id;
|
||||||
} SignalCPromiseTestingHandleType;
|
} SignalCPromiseTestingHandleType;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -669,6 +681,7 @@ typedef struct {
|
|||||||
typedef struct {
|
typedef struct {
|
||||||
void (*complete)(SignalFfiError *error, SignalOtherTestingHandleType *const *result, const void *context);
|
void (*complete)(SignalFfiError *error, SignalOtherTestingHandleType *const *result, const void *context);
|
||||||
const void *context;
|
const void *context;
|
||||||
|
SignalCancellationId cancellation_id;
|
||||||
} SignalCPromiseOtherTestingHandleType;
|
} SignalCPromiseOtherTestingHandleType;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -683,6 +696,7 @@ typedef struct {
|
|||||||
typedef struct {
|
typedef struct {
|
||||||
void (*complete)(SignalFfiError *error, const void *const *result, const void *context);
|
void (*complete)(SignalFfiError *error, const void *const *result, const void *context);
|
||||||
const void *context;
|
const void *context;
|
||||||
|
SignalCancellationId cancellation_id;
|
||||||
} SignalCPromiseRawPointer;
|
} SignalCPromiseRawPointer;
|
||||||
|
|
||||||
typedef uint8_t SignalRandomnessBytes[SignalRANDOMNESS_LEN];
|
typedef uint8_t SignalRandomnessBytes[SignalRANDOMNESS_LEN];
|
||||||
@ -1527,6 +1541,8 @@ SignalFfiError *signal_tokio_async_context_destroy(SignalTokioAsyncContext *p);
|
|||||||
|
|
||||||
SignalFfiError *signal_tokio_async_context_new(SignalTokioAsyncContext **out);
|
SignalFfiError *signal_tokio_async_context_new(SignalTokioAsyncContext **out);
|
||||||
|
|
||||||
|
SignalFfiError *signal_tokio_async_context_cancel(const SignalTokioAsyncContext *context, uint64_t raw_cancellation_id);
|
||||||
|
|
||||||
SignalFfiError *signal_pin_hash_destroy(SignalPinHash *p);
|
SignalFfiError *signal_pin_hash_destroy(SignalPinHash *p);
|
||||||
|
|
||||||
SignalFfiError *signal_pin_hash_clone(SignalPinHash **new_obj, const SignalPinHash *obj);
|
SignalFfiError *signal_pin_hash_clone(SignalPinHash **new_obj, const SignalPinHash *obj);
|
||||||
@ -1685,6 +1701,8 @@ SignalFfiError *signal_testing_process_bytestring_array(SignalBytestringArray *o
|
|||||||
|
|
||||||
SignalFfiError *signal_testing_cdsi_lookup_response_convert(SignalCPromiseFfiCdsiLookupResponse *promise, const SignalTokioAsyncContext *async_runtime);
|
SignalFfiError *signal_testing_cdsi_lookup_response_convert(SignalCPromiseFfiCdsiLookupResponse *promise, const SignalTokioAsyncContext *async_runtime);
|
||||||
|
|
||||||
|
SignalFfiError *signal_testing_only_completes_by_cancellation(SignalCPromisebool *promise, const SignalTokioAsyncContext *async_runtime);
|
||||||
|
|
||||||
SignalFfiError *signal_testing_cdsi_lookup_error_convert(const char *error_description);
|
SignalFfiError *signal_testing_cdsi_lookup_error_convert(const char *error_description);
|
||||||
|
|
||||||
SignalFfiError *signal_testing_chat_service_error_convert(void);
|
SignalFfiError *signal_testing_chat_service_error_convert(void);
|
||||||
|
@ -18,7 +18,7 @@ extension SignalCPromiseOtherTestingHandleType: PromiseStruct {
|
|||||||
public typealias Result = OpaquePointer
|
public typealias Result = OpaquePointer
|
||||||
}
|
}
|
||||||
|
|
||||||
final class AsyncTests: XCTestCase {
|
final class AsyncTests: TestCaseBase {
|
||||||
func testSuccess() async throws {
|
func testSuccess() async throws {
|
||||||
let result: Int32 = try await invokeAsyncFunction {
|
let result: Int32 = try await invokeAsyncFunction {
|
||||||
signal_testing_future_success($0, OpaquePointer(bitPattern: -1), 21)
|
signal_testing_future_success($0, OpaquePointer(bitPattern: -1), 21)
|
||||||
@ -65,6 +65,48 @@ final class AsyncTests: XCTestCase {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func testTokioCancellation() async throws {
|
||||||
|
let asyncContext = TokioAsyncContext()
|
||||||
|
|
||||||
|
// We can replace this with AsyncStream.makeStream(...) when we update our builder.
|
||||||
|
var _continuation: AsyncStream<Int>.Continuation!
|
||||||
|
let completionStream = AsyncStream<Int> { _continuation = $0 }
|
||||||
|
let continuation = _continuation!
|
||||||
|
|
||||||
|
let makeTask = { (id: Int) in
|
||||||
|
Task {
|
||||||
|
defer {
|
||||||
|
// Do this unconditionally so that the outer test procedure doesn't get stuck.
|
||||||
|
continuation.yield(id)
|
||||||
|
}
|
||||||
|
do {
|
||||||
|
_ = try await asyncContext.invokeAsyncFunction { promise, asyncContext in
|
||||||
|
signal_testing_only_completes_by_cancellation(promise, asyncContext)
|
||||||
|
}
|
||||||
|
} catch is CancellationError {
|
||||||
|
// Okay, expected.
|
||||||
|
} catch {
|
||||||
|
XCTFail("incorrect error: \(error)")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let task1 = makeTask(1)
|
||||||
|
let task2 = makeTask(2)
|
||||||
|
|
||||||
|
var completionIter = completionStream.makeAsyncIterator()
|
||||||
|
|
||||||
|
// Complete the tasks in opposite order of starting them,
|
||||||
|
// to make it less likely to get this result by accident.
|
||||||
|
// This is not a rigorous test, only a simple exercise of the feature.
|
||||||
|
task2.cancel()
|
||||||
|
let firstCompletionId = await completionIter.next()
|
||||||
|
XCTAssertEqual(firstCompletionId, 2)
|
||||||
|
|
||||||
|
task1.cancel()
|
||||||
|
let secondCompletionId = await completionIter.next()
|
||||||
|
XCTAssertEqual(secondCompletionId, 1)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
@ -23,10 +23,8 @@ final class NetTests: XCTestCase {
|
|||||||
|
|
||||||
let asyncContext = TokioAsyncContext()
|
let asyncContext = TokioAsyncContext()
|
||||||
|
|
||||||
let output: SignalFfiCdsiLookupResponse = try await invokeAsyncFunction { promise in
|
let output: SignalFfiCdsiLookupResponse = try await asyncContext.invokeAsyncFunction { promise, asyncContext in
|
||||||
asyncContext.withNativeHandle { asyncContext in
|
signal_testing_cdsi_lookup_response_convert(promise, asyncContext)
|
||||||
signal_testing_cdsi_lookup_response_convert(promise, asyncContext)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
XCTAssertEqual(output.debug_permits_used, 123)
|
XCTAssertEqual(output.debug_permits_used, 123)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user