0
0
mirror of https://github.com/signalapp/libsignal.git synced 2024-09-20 03:52:17 +02:00

java: add async class load method

Add a method to allow Java code to attempt to load a class on a Tokio worker 
thread like libsignal does internally. This will be used for testing both in 
libsignal and in dependents.

Fix a bug where exceptions raised during conversion from Rust result values to 
Java values weren't being correctly propagated to the Java Future that would 
report the result.
This commit is contained in:
Alex Konradi 2024-04-29 13:08:10 -04:00 committed by GitHub
parent f5c2c50047
commit 6edd0540fb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 157 additions and 11 deletions

View File

@ -76,6 +76,33 @@ public class Network {
});
}
/**
* Try to load several libsignal classes asynchronously, using the same mechanism as native (Rust)
* code.
*
* <p>This should only be called in tests, and can be used to ensure at test time that libsignal
* async code won't fail to load exceptions.
*/
public static void checkClassesCanBeLoadedAsyncForTest() {
// This doesn't need to be comprehensive, just check a few classes.
final String[] classesToLoad = {
"org.signal.libsignal.net.CdsiLookupResponse$Entry",
"org.signal.libsignal.net.NetworkException",
"org.signal.libsignal.net.ChatServiceException",
"org.signal.libsignal.protocol.ServiceId",
};
TokioAsyncContext context = new TokioAsyncContext();
for (String className : classesToLoad) {
// No need to do anything with the result; if it doesn't throw, it succeeded.
try {
context.loadClassAsync(className).get();
} catch (ExecutionException | InterruptedException e) {
throw new RuntimeException(e);
}
}
}
TokioAsyncContext getAsyncContext() {
return this.tokioAsyncContext;
}

View File

@ -5,6 +5,7 @@
package org.signal.libsignal.net;
import org.signal.libsignal.internal.CompletableFuture;
import org.signal.libsignal.internal.Native;
import org.signal.libsignal.internal.NativeHandleGuard;
@ -13,6 +14,12 @@ class TokioAsyncContext extends NativeHandleGuard.SimpleOwner {
super(Native.TokioAsyncContext_new());
}
@SuppressWarnings("unchecked")
CompletableFuture<Class<Object>> loadClassAsync(String className) {
className = className.replace('.', '/');
return (CompletableFuture<Class<Object>>) Native.AsyncLoadClass(this, className);
}
@Override
protected void release(final long nativeHandle) {
Native.TokioAsyncContext_Destroy(nativeHandle);

View File

@ -0,0 +1,67 @@
//
// Copyright 2024 Signal Messenger, LLC.
// SPDX-License-Identifier: AGPL-3.0-only
//
package org.signal.libsignal.net;
import static org.junit.Assert.*;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import org.junit.Test;
public class TokioAsyncContextTest {
@Test
public void loadExceptionClasses() throws ExecutionException, InterruptedException {
TokioAsyncContext context = new TokioAsyncContext();
assertCanLoadClass(context, "org.signal.libsignal.net.CdsiProtocolException");
assertCanLoadClass(context, "org.signal.libsignal.net.NetworkException");
}
@Test
public void loadNonexistentClasses() throws ExecutionException, InterruptedException {
TokioAsyncContext context = new TokioAsyncContext();
assertClassNotFound(context, "org.signal.libsignal.ClassThatDoesNotExist1");
assertClassNotFound(context, "org.signal.libsignal.ClassThatDoesNotExist2");
assertClassNotFound(context, "org.signal.libsignal.ClassThatDoesNotExist3");
assertClassNotFound(context, "org.signal.libsignal.ClassThatDoesNotExist4");
assertClassNotFound(context, "org.signal.libsignal.ClassThatDoesNotExist5");
assertClassNotFound(context, "org.signal.libsignal.ClassThatDoesNotExist6");
assertClassNotFound(context, "org.signal.libsignal.ClassThatDoesNotExist7");
assertClassNotFound(context, "org.signal.libsignal.ClassThatDoesNotExist8");
assertClassNotFound(context, "org.signal.libsignal.ClassThatDoesNotExist9");
assertClassNotFound(context, "org.signal.libsignal.ClassThatDoesNotExist10");
}
/** Assert that the class with the given name can be loaded on a Tokio worker thread. */
private static void assertCanLoadClass(TokioAsyncContext context, String className)
throws ExecutionException, InterruptedException {
Future<Class<Object>> loadAsync = context.loadClassAsync(className);
// Block waiting for the future to resolve.
Class<Object> loaded = loadAsync.get();
assertEquals(className, loaded.getName());
}
/** Assert that the class doesn't exist. */
private static void assertClassNotFound(TokioAsyncContext context, String className)
throws ExecutionException, InterruptedException {
Future<Class<Object>> loadAsync = context.loadClassAsync(className);
// Block waiting for the future to resolve.
Throwable cause =
assertThrows(
"for " + className,
ExecutionException.class,
() -> loadAsync.get(10, TimeUnit.SECONDS))
.getCause();
assertTrue(
"unexpected error: " + cause,
cause instanceof ClassNotFoundException || cause instanceof NoClassDefFoundError);
}
@Test
public void runNetworkClassLoadTestFunction() throws ExecutionException, InterruptedException {
Network.checkClassesCanBeLoadedAsyncForTest();
}
}

View File

@ -106,6 +106,8 @@ public final class Native {
public static native byte[] Aes256GcmSiv_Encrypt(long aesGcmSivObj, byte[] ptext, byte[] nonce, byte[] associatedData) throws Exception;
public static native long Aes256GcmSiv_New(byte[] key) throws Exception;
public static native Object AsyncLoadClass(Object tokioContext, String className);
public static native void AuthCredentialPresentation_CheckValidContents(byte[] presentationBytes) throws Exception;
public static native byte[] AuthCredentialPresentation_GetPniCiphertext(byte[] presentationBytes);
public static native long AuthCredentialPresentation_GetRedemptionTime(byte[] presentationBytes);

View File

@ -6,15 +6,14 @@
#![allow(clippy::missing_safety_doc)]
#![deny(clippy::unwrap_used)]
use jni::objects::{JByteArray, JClass, JLongArray, JObject};
use jni::objects::{JByteArray, JClass, JLongArray, JObject, JString};
#[cfg(not(target_os = "android"))]
use jni::objects::{JMap, JValue};
use jni::JNIEnv;
use libsignal_bridge::jni::*;
use libsignal_bridge::jni_args;
#[cfg(not(target_os = "android"))]
use libsignal_bridge::jni_class_name;
use libsignal_bridge::net::TokioAsyncContext;
use libsignal_bridge::{jni_args, jni_class_name};
use libsignal_protocol::*;
pub mod logging;
@ -65,6 +64,39 @@ pub unsafe extern "C" fn Java_org_signal_libsignal_internal_Native_preloadClasse
})
}
#[no_mangle]
pub unsafe extern "C" fn Java_org_signal_libsignal_internal_Native_AsyncLoadClass<'local>(
mut env: JNIEnv<'local>,
_class: JClass,
tokio_context: JObject<'local>,
class_name: JString,
) -> JObject<'local> {
struct LoadClassFromName(String);
impl<'a> ResultTypeInfo<'a> for LoadClassFromName {
type ResultType = JClass<'a>;
fn convert_into(self, env: &mut JNIEnv<'a>) -> Result<Self::ResultType, BridgeLayerError> {
find_class(env, &self.0).map_err(Into::into)
}
}
run_ffi_safe(&mut env, |env| {
let handle = call_method_checked(
env,
tokio_context,
"unsafeNativeHandleWithoutGuard",
jni_args!(() -> long),
)?;
let tokio_context = <&TokioAsyncContext>::convert_from(env, &handle)?;
let class_name = env.get_string(&class_name)?.into();
run_future_on_runtime(env, tokio_context, async {
FutureResultReporter::new(Ok(LoadClassFromName(class_name)), ())
})
})
.into()
}
#[cfg(not(target_os = "android"))]
#[no_mangle]
pub unsafe extern "C" fn Java_org_signal_libsignal_internal_Native_SealedSender_1MultiRecipientParseSentMessage<

View File

@ -102,11 +102,22 @@ pub fn preload_classes(env: &mut JNIEnv<'_>) -> Result<(), BridgeLayerError> {
/// [`JNIEnv::find_class`].
pub fn find_class<'output>(
env: &mut JNIEnv<'output>,
name: &'static str,
) -> Result<JClass<'output>, jni::errors::Error> {
name: &str,
) -> Result<JClass<'output>, BridgeLayerError> {
match get_preloaded_class(env, name)? {
Some(c) => Ok(c),
None => real_jni_find_class(env, name),
None => real_jni_find_class(env, name).or_else(|e| {
let exception = env.exception_occurred()?;
Err(if !exception.is_null() {
env.exception_clear()?;
BridgeLayerError::CallbackException(
"FindClass",
ThrownException::new(env, exception)?,
)
} else {
e.into()
})
}),
}
}
@ -121,7 +132,7 @@ pub fn find_class<'output>(
/// be used by natively-spawned threads to access application-defined types.
fn get_preloaded_class<'output>(
env: &mut JNIEnv<'output>,
name: &'static str,
name: &str,
) -> Result<Option<JClass<'output>>, jni::errors::Error> {
let class = PRELOADED_CLASSES
.get()
@ -143,7 +154,7 @@ fn get_preloaded_class<'output>(
#[allow(clippy::disallowed_methods)]
fn real_jni_find_class<'output>(
env: &mut JNIEnv<'output>,
name: &'static str,
name: &str,
) -> Result<JClass<'output>, jni::errors::Error> {
env.find_class(name)
}

View File

@ -617,11 +617,11 @@ where
let throwable = env
.new_string(error.to_string())
.map_err(Into::into)
.and_then(|message| {
let class = find_class(env, exception_type)?;
Ok(new_object(env, class, jni_args!((message => java.lang.String) -> void))?.into())
})
.map_err(Into::into);
});
consume(env, throwable, &error)
}