mirror of
https://github.com/signalapp/libsignal.git
synced 2024-09-20 03:52:17 +02:00
java: add async class load method
Add a method to allow Java code to attempt to load a class on a Tokio worker thread like libsignal does internally. This will be used for testing both in libsignal and in dependents. Fix a bug where exceptions raised during conversion from Rust result values to Java values weren't being correctly propagated to the Java Future that would report the result.
This commit is contained in:
parent
f5c2c50047
commit
6edd0540fb
@ -76,6 +76,33 @@ public class Network {
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Try to load several libsignal classes asynchronously, using the same mechanism as native (Rust)
|
||||
* code.
|
||||
*
|
||||
* <p>This should only be called in tests, and can be used to ensure at test time that libsignal
|
||||
* async code won't fail to load exceptions.
|
||||
*/
|
||||
public static void checkClassesCanBeLoadedAsyncForTest() {
|
||||
// This doesn't need to be comprehensive, just check a few classes.
|
||||
final String[] classesToLoad = {
|
||||
"org.signal.libsignal.net.CdsiLookupResponse$Entry",
|
||||
"org.signal.libsignal.net.NetworkException",
|
||||
"org.signal.libsignal.net.ChatServiceException",
|
||||
"org.signal.libsignal.protocol.ServiceId",
|
||||
};
|
||||
TokioAsyncContext context = new TokioAsyncContext();
|
||||
|
||||
for (String className : classesToLoad) {
|
||||
// No need to do anything with the result; if it doesn't throw, it succeeded.
|
||||
try {
|
||||
context.loadClassAsync(className).get();
|
||||
} catch (ExecutionException | InterruptedException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TokioAsyncContext getAsyncContext() {
|
||||
return this.tokioAsyncContext;
|
||||
}
|
||||
|
@ -5,6 +5,7 @@
|
||||
|
||||
package org.signal.libsignal.net;
|
||||
|
||||
import org.signal.libsignal.internal.CompletableFuture;
|
||||
import org.signal.libsignal.internal.Native;
|
||||
import org.signal.libsignal.internal.NativeHandleGuard;
|
||||
|
||||
@ -13,6 +14,12 @@ class TokioAsyncContext extends NativeHandleGuard.SimpleOwner {
|
||||
super(Native.TokioAsyncContext_new());
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
CompletableFuture<Class<Object>> loadClassAsync(String className) {
|
||||
className = className.replace('.', '/');
|
||||
return (CompletableFuture<Class<Object>>) Native.AsyncLoadClass(this, className);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void release(final long nativeHandle) {
|
||||
Native.TokioAsyncContext_Destroy(nativeHandle);
|
||||
|
@ -0,0 +1,67 @@
|
||||
//
|
||||
// Copyright 2024 Signal Messenger, LLC.
|
||||
// SPDX-License-Identifier: AGPL-3.0-only
|
||||
//
|
||||
|
||||
package org.signal.libsignal.net;
|
||||
|
||||
import static org.junit.Assert.*;
|
||||
|
||||
import java.util.concurrent.ExecutionException;
|
||||
import java.util.concurrent.Future;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import org.junit.Test;
|
||||
|
||||
public class TokioAsyncContextTest {
|
||||
@Test
|
||||
public void loadExceptionClasses() throws ExecutionException, InterruptedException {
|
||||
TokioAsyncContext context = new TokioAsyncContext();
|
||||
assertCanLoadClass(context, "org.signal.libsignal.net.CdsiProtocolException");
|
||||
assertCanLoadClass(context, "org.signal.libsignal.net.NetworkException");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void loadNonexistentClasses() throws ExecutionException, InterruptedException {
|
||||
TokioAsyncContext context = new TokioAsyncContext();
|
||||
assertClassNotFound(context, "org.signal.libsignal.ClassThatDoesNotExist1");
|
||||
assertClassNotFound(context, "org.signal.libsignal.ClassThatDoesNotExist2");
|
||||
assertClassNotFound(context, "org.signal.libsignal.ClassThatDoesNotExist3");
|
||||
assertClassNotFound(context, "org.signal.libsignal.ClassThatDoesNotExist4");
|
||||
assertClassNotFound(context, "org.signal.libsignal.ClassThatDoesNotExist5");
|
||||
assertClassNotFound(context, "org.signal.libsignal.ClassThatDoesNotExist6");
|
||||
assertClassNotFound(context, "org.signal.libsignal.ClassThatDoesNotExist7");
|
||||
assertClassNotFound(context, "org.signal.libsignal.ClassThatDoesNotExist8");
|
||||
assertClassNotFound(context, "org.signal.libsignal.ClassThatDoesNotExist9");
|
||||
assertClassNotFound(context, "org.signal.libsignal.ClassThatDoesNotExist10");
|
||||
}
|
||||
|
||||
/** Assert that the class with the given name can be loaded on a Tokio worker thread. */
|
||||
private static void assertCanLoadClass(TokioAsyncContext context, String className)
|
||||
throws ExecutionException, InterruptedException {
|
||||
Future<Class<Object>> loadAsync = context.loadClassAsync(className);
|
||||
// Block waiting for the future to resolve.
|
||||
Class<Object> loaded = loadAsync.get();
|
||||
assertEquals(className, loaded.getName());
|
||||
}
|
||||
|
||||
/** Assert that the class doesn't exist. */
|
||||
private static void assertClassNotFound(TokioAsyncContext context, String className)
|
||||
throws ExecutionException, InterruptedException {
|
||||
Future<Class<Object>> loadAsync = context.loadClassAsync(className);
|
||||
// Block waiting for the future to resolve.
|
||||
Throwable cause =
|
||||
assertThrows(
|
||||
"for " + className,
|
||||
ExecutionException.class,
|
||||
() -> loadAsync.get(10, TimeUnit.SECONDS))
|
||||
.getCause();
|
||||
assertTrue(
|
||||
"unexpected error: " + cause,
|
||||
cause instanceof ClassNotFoundException || cause instanceof NoClassDefFoundError);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void runNetworkClassLoadTestFunction() throws ExecutionException, InterruptedException {
|
||||
Network.checkClassesCanBeLoadedAsyncForTest();
|
||||
}
|
||||
}
|
@ -106,6 +106,8 @@ public final class Native {
|
||||
public static native byte[] Aes256GcmSiv_Encrypt(long aesGcmSivObj, byte[] ptext, byte[] nonce, byte[] associatedData) throws Exception;
|
||||
public static native long Aes256GcmSiv_New(byte[] key) throws Exception;
|
||||
|
||||
public static native Object AsyncLoadClass(Object tokioContext, String className);
|
||||
|
||||
public static native void AuthCredentialPresentation_CheckValidContents(byte[] presentationBytes) throws Exception;
|
||||
public static native byte[] AuthCredentialPresentation_GetPniCiphertext(byte[] presentationBytes);
|
||||
public static native long AuthCredentialPresentation_GetRedemptionTime(byte[] presentationBytes);
|
||||
|
@ -6,15 +6,14 @@
|
||||
#![allow(clippy::missing_safety_doc)]
|
||||
#![deny(clippy::unwrap_used)]
|
||||
|
||||
use jni::objects::{JByteArray, JClass, JLongArray, JObject};
|
||||
use jni::objects::{JByteArray, JClass, JLongArray, JObject, JString};
|
||||
#[cfg(not(target_os = "android"))]
|
||||
use jni::objects::{JMap, JValue};
|
||||
use jni::JNIEnv;
|
||||
|
||||
use libsignal_bridge::jni::*;
|
||||
use libsignal_bridge::jni_args;
|
||||
#[cfg(not(target_os = "android"))]
|
||||
use libsignal_bridge::jni_class_name;
|
||||
use libsignal_bridge::net::TokioAsyncContext;
|
||||
use libsignal_bridge::{jni_args, jni_class_name};
|
||||
use libsignal_protocol::*;
|
||||
|
||||
pub mod logging;
|
||||
@ -65,6 +64,39 @@ pub unsafe extern "C" fn Java_org_signal_libsignal_internal_Native_preloadClasse
|
||||
})
|
||||
}
|
||||
|
||||
#[no_mangle]
|
||||
pub unsafe extern "C" fn Java_org_signal_libsignal_internal_Native_AsyncLoadClass<'local>(
|
||||
mut env: JNIEnv<'local>,
|
||||
_class: JClass,
|
||||
tokio_context: JObject<'local>,
|
||||
class_name: JString,
|
||||
) -> JObject<'local> {
|
||||
struct LoadClassFromName(String);
|
||||
|
||||
impl<'a> ResultTypeInfo<'a> for LoadClassFromName {
|
||||
type ResultType = JClass<'a>;
|
||||
|
||||
fn convert_into(self, env: &mut JNIEnv<'a>) -> Result<Self::ResultType, BridgeLayerError> {
|
||||
find_class(env, &self.0).map_err(Into::into)
|
||||
}
|
||||
}
|
||||
|
||||
run_ffi_safe(&mut env, |env| {
|
||||
let handle = call_method_checked(
|
||||
env,
|
||||
tokio_context,
|
||||
"unsafeNativeHandleWithoutGuard",
|
||||
jni_args!(() -> long),
|
||||
)?;
|
||||
let tokio_context = <&TokioAsyncContext>::convert_from(env, &handle)?;
|
||||
let class_name = env.get_string(&class_name)?.into();
|
||||
run_future_on_runtime(env, tokio_context, async {
|
||||
FutureResultReporter::new(Ok(LoadClassFromName(class_name)), ())
|
||||
})
|
||||
})
|
||||
.into()
|
||||
}
|
||||
|
||||
#[cfg(not(target_os = "android"))]
|
||||
#[no_mangle]
|
||||
pub unsafe extern "C" fn Java_org_signal_libsignal_internal_Native_SealedSender_1MultiRecipientParseSentMessage<
|
||||
|
@ -102,11 +102,22 @@ pub fn preload_classes(env: &mut JNIEnv<'_>) -> Result<(), BridgeLayerError> {
|
||||
/// [`JNIEnv::find_class`].
|
||||
pub fn find_class<'output>(
|
||||
env: &mut JNIEnv<'output>,
|
||||
name: &'static str,
|
||||
) -> Result<JClass<'output>, jni::errors::Error> {
|
||||
name: &str,
|
||||
) -> Result<JClass<'output>, BridgeLayerError> {
|
||||
match get_preloaded_class(env, name)? {
|
||||
Some(c) => Ok(c),
|
||||
None => real_jni_find_class(env, name),
|
||||
None => real_jni_find_class(env, name).or_else(|e| {
|
||||
let exception = env.exception_occurred()?;
|
||||
Err(if !exception.is_null() {
|
||||
env.exception_clear()?;
|
||||
BridgeLayerError::CallbackException(
|
||||
"FindClass",
|
||||
ThrownException::new(env, exception)?,
|
||||
)
|
||||
} else {
|
||||
e.into()
|
||||
})
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
@ -121,7 +132,7 @@ pub fn find_class<'output>(
|
||||
/// be used by natively-spawned threads to access application-defined types.
|
||||
fn get_preloaded_class<'output>(
|
||||
env: &mut JNIEnv<'output>,
|
||||
name: &'static str,
|
||||
name: &str,
|
||||
) -> Result<Option<JClass<'output>>, jni::errors::Error> {
|
||||
let class = PRELOADED_CLASSES
|
||||
.get()
|
||||
@ -143,7 +154,7 @@ fn get_preloaded_class<'output>(
|
||||
#[allow(clippy::disallowed_methods)]
|
||||
fn real_jni_find_class<'output>(
|
||||
env: &mut JNIEnv<'output>,
|
||||
name: &'static str,
|
||||
name: &str,
|
||||
) -> Result<JClass<'output>, jni::errors::Error> {
|
||||
env.find_class(name)
|
||||
}
|
||||
|
@ -617,11 +617,11 @@ where
|
||||
|
||||
let throwable = env
|
||||
.new_string(error.to_string())
|
||||
.map_err(Into::into)
|
||||
.and_then(|message| {
|
||||
let class = find_class(env, exception_type)?;
|
||||
Ok(new_object(env, class, jni_args!((message => java.lang.String) -> void))?.into())
|
||||
})
|
||||
.map_err(Into::into);
|
||||
});
|
||||
consume(env, throwable, &error)
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user