From 6edd0540fbef42eb82e8cc93ff217d1fc85467d4 Mon Sep 17 00:00:00 2001 From: Alex Konradi Date: Mon, 29 Apr 2024 13:08:10 -0400 Subject: [PATCH] java: add async class load method Add a method to allow Java code to attempt to load a class on a Tokio worker thread like libsignal does internally. This will be used for testing both in libsignal and in dependents. Fix a bug where exceptions raised during conversion from Rust result values to Java values weren't being correctly propagated to the Java Future that would report the result. --- .../org/signal/libsignal/net/Network.java | 27 ++++++++ .../libsignal/net/TokioAsyncContext.java | 7 ++ .../libsignal/net/TokioAsyncContextTest.java | 67 +++++++++++++++++++ .../org/signal/libsignal/internal/Native.java | 2 + rust/bridge/jni/src/lib.rs | 40 +++++++++-- rust/bridge/shared/src/jni/class_lookup.rs | 21 ++++-- rust/bridge/shared/src/jni/mod.rs | 4 +- 7 files changed, 157 insertions(+), 11 deletions(-) create mode 100644 java/client/src/test/java/org/signal/libsignal/net/TokioAsyncContextTest.java diff --git a/java/client/src/main/java/org/signal/libsignal/net/Network.java b/java/client/src/main/java/org/signal/libsignal/net/Network.java index 14690816..b1ac4974 100644 --- a/java/client/src/main/java/org/signal/libsignal/net/Network.java +++ b/java/client/src/main/java/org/signal/libsignal/net/Network.java @@ -76,6 +76,33 @@ public class Network { }); } + /** + * Try to load several libsignal classes asynchronously, using the same mechanism as native (Rust) + * code. + * + *

This should only be called in tests, and can be used to ensure at test time that libsignal + * async code won't fail to load exceptions. + */ + public static void checkClassesCanBeLoadedAsyncForTest() { + // This doesn't need to be comprehensive, just check a few classes. + final String[] classesToLoad = { + "org.signal.libsignal.net.CdsiLookupResponse$Entry", + "org.signal.libsignal.net.NetworkException", + "org.signal.libsignal.net.ChatServiceException", + "org.signal.libsignal.protocol.ServiceId", + }; + TokioAsyncContext context = new TokioAsyncContext(); + + for (String className : classesToLoad) { + // No need to do anything with the result; if it doesn't throw, it succeeded. + try { + context.loadClassAsync(className).get(); + } catch (ExecutionException | InterruptedException e) { + throw new RuntimeException(e); + } + } + } + TokioAsyncContext getAsyncContext() { return this.tokioAsyncContext; } diff --git a/java/client/src/main/java/org/signal/libsignal/net/TokioAsyncContext.java b/java/client/src/main/java/org/signal/libsignal/net/TokioAsyncContext.java index e646b3e2..28202b63 100644 --- a/java/client/src/main/java/org/signal/libsignal/net/TokioAsyncContext.java +++ b/java/client/src/main/java/org/signal/libsignal/net/TokioAsyncContext.java @@ -5,6 +5,7 @@ package org.signal.libsignal.net; +import org.signal.libsignal.internal.CompletableFuture; import org.signal.libsignal.internal.Native; import org.signal.libsignal.internal.NativeHandleGuard; @@ -13,6 +14,12 @@ class TokioAsyncContext extends NativeHandleGuard.SimpleOwner { super(Native.TokioAsyncContext_new()); } + @SuppressWarnings("unchecked") + CompletableFuture> loadClassAsync(String className) { + className = className.replace('.', '/'); + return (CompletableFuture>) Native.AsyncLoadClass(this, className); + } + @Override protected void release(final long nativeHandle) { Native.TokioAsyncContext_Destroy(nativeHandle); diff --git a/java/client/src/test/java/org/signal/libsignal/net/TokioAsyncContextTest.java b/java/client/src/test/java/org/signal/libsignal/net/TokioAsyncContextTest.java new file mode 100644 index 00000000..b4c95248 --- /dev/null +++ b/java/client/src/test/java/org/signal/libsignal/net/TokioAsyncContextTest.java @@ -0,0 +1,67 @@ +// +// Copyright 2024 Signal Messenger, LLC. +// SPDX-License-Identifier: AGPL-3.0-only +// + +package org.signal.libsignal.net; + +import static org.junit.Assert.*; + +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import org.junit.Test; + +public class TokioAsyncContextTest { + @Test + public void loadExceptionClasses() throws ExecutionException, InterruptedException { + TokioAsyncContext context = new TokioAsyncContext(); + assertCanLoadClass(context, "org.signal.libsignal.net.CdsiProtocolException"); + assertCanLoadClass(context, "org.signal.libsignal.net.NetworkException"); + } + + @Test + public void loadNonexistentClasses() throws ExecutionException, InterruptedException { + TokioAsyncContext context = new TokioAsyncContext(); + assertClassNotFound(context, "org.signal.libsignal.ClassThatDoesNotExist1"); + assertClassNotFound(context, "org.signal.libsignal.ClassThatDoesNotExist2"); + assertClassNotFound(context, "org.signal.libsignal.ClassThatDoesNotExist3"); + assertClassNotFound(context, "org.signal.libsignal.ClassThatDoesNotExist4"); + assertClassNotFound(context, "org.signal.libsignal.ClassThatDoesNotExist5"); + assertClassNotFound(context, "org.signal.libsignal.ClassThatDoesNotExist6"); + assertClassNotFound(context, "org.signal.libsignal.ClassThatDoesNotExist7"); + assertClassNotFound(context, "org.signal.libsignal.ClassThatDoesNotExist8"); + assertClassNotFound(context, "org.signal.libsignal.ClassThatDoesNotExist9"); + assertClassNotFound(context, "org.signal.libsignal.ClassThatDoesNotExist10"); + } + + /** Assert that the class with the given name can be loaded on a Tokio worker thread. */ + private static void assertCanLoadClass(TokioAsyncContext context, String className) + throws ExecutionException, InterruptedException { + Future> loadAsync = context.loadClassAsync(className); + // Block waiting for the future to resolve. + Class loaded = loadAsync.get(); + assertEquals(className, loaded.getName()); + } + + /** Assert that the class doesn't exist. */ + private static void assertClassNotFound(TokioAsyncContext context, String className) + throws ExecutionException, InterruptedException { + Future> loadAsync = context.loadClassAsync(className); + // Block waiting for the future to resolve. + Throwable cause = + assertThrows( + "for " + className, + ExecutionException.class, + () -> loadAsync.get(10, TimeUnit.SECONDS)) + .getCause(); + assertTrue( + "unexpected error: " + cause, + cause instanceof ClassNotFoundException || cause instanceof NoClassDefFoundError); + } + + @Test + public void runNetworkClassLoadTestFunction() throws ExecutionException, InterruptedException { + Network.checkClassesCanBeLoadedAsyncForTest(); + } +} diff --git a/java/shared/java/org/signal/libsignal/internal/Native.java b/java/shared/java/org/signal/libsignal/internal/Native.java index c63eea50..4ed0238a 100644 --- a/java/shared/java/org/signal/libsignal/internal/Native.java +++ b/java/shared/java/org/signal/libsignal/internal/Native.java @@ -106,6 +106,8 @@ public final class Native { public static native byte[] Aes256GcmSiv_Encrypt(long aesGcmSivObj, byte[] ptext, byte[] nonce, byte[] associatedData) throws Exception; public static native long Aes256GcmSiv_New(byte[] key) throws Exception; + public static native Object AsyncLoadClass(Object tokioContext, String className); + public static native void AuthCredentialPresentation_CheckValidContents(byte[] presentationBytes) throws Exception; public static native byte[] AuthCredentialPresentation_GetPniCiphertext(byte[] presentationBytes); public static native long AuthCredentialPresentation_GetRedemptionTime(byte[] presentationBytes); diff --git a/rust/bridge/jni/src/lib.rs b/rust/bridge/jni/src/lib.rs index 5898c0fd..34a0498e 100644 --- a/rust/bridge/jni/src/lib.rs +++ b/rust/bridge/jni/src/lib.rs @@ -6,15 +6,14 @@ #![allow(clippy::missing_safety_doc)] #![deny(clippy::unwrap_used)] -use jni::objects::{JByteArray, JClass, JLongArray, JObject}; +use jni::objects::{JByteArray, JClass, JLongArray, JObject, JString}; #[cfg(not(target_os = "android"))] use jni::objects::{JMap, JValue}; use jni::JNIEnv; use libsignal_bridge::jni::*; -use libsignal_bridge::jni_args; -#[cfg(not(target_os = "android"))] -use libsignal_bridge::jni_class_name; +use libsignal_bridge::net::TokioAsyncContext; +use libsignal_bridge::{jni_args, jni_class_name}; use libsignal_protocol::*; pub mod logging; @@ -65,6 +64,39 @@ pub unsafe extern "C" fn Java_org_signal_libsignal_internal_Native_preloadClasse }) } +#[no_mangle] +pub unsafe extern "C" fn Java_org_signal_libsignal_internal_Native_AsyncLoadClass<'local>( + mut env: JNIEnv<'local>, + _class: JClass, + tokio_context: JObject<'local>, + class_name: JString, +) -> JObject<'local> { + struct LoadClassFromName(String); + + impl<'a> ResultTypeInfo<'a> for LoadClassFromName { + type ResultType = JClass<'a>; + + fn convert_into(self, env: &mut JNIEnv<'a>) -> Result { + find_class(env, &self.0).map_err(Into::into) + } + } + + run_ffi_safe(&mut env, |env| { + let handle = call_method_checked( + env, + tokio_context, + "unsafeNativeHandleWithoutGuard", + jni_args!(() -> long), + )?; + let tokio_context = <&TokioAsyncContext>::convert_from(env, &handle)?; + let class_name = env.get_string(&class_name)?.into(); + run_future_on_runtime(env, tokio_context, async { + FutureResultReporter::new(Ok(LoadClassFromName(class_name)), ()) + }) + }) + .into() +} + #[cfg(not(target_os = "android"))] #[no_mangle] pub unsafe extern "C" fn Java_org_signal_libsignal_internal_Native_SealedSender_1MultiRecipientParseSentMessage< diff --git a/rust/bridge/shared/src/jni/class_lookup.rs b/rust/bridge/shared/src/jni/class_lookup.rs index e0e00d87..dadbed4d 100644 --- a/rust/bridge/shared/src/jni/class_lookup.rs +++ b/rust/bridge/shared/src/jni/class_lookup.rs @@ -102,11 +102,22 @@ pub fn preload_classes(env: &mut JNIEnv<'_>) -> Result<(), BridgeLayerError> { /// [`JNIEnv::find_class`]. pub fn find_class<'output>( env: &mut JNIEnv<'output>, - name: &'static str, -) -> Result, jni::errors::Error> { + name: &str, +) -> Result, BridgeLayerError> { match get_preloaded_class(env, name)? { Some(c) => Ok(c), - None => real_jni_find_class(env, name), + None => real_jni_find_class(env, name).or_else(|e| { + let exception = env.exception_occurred()?; + Err(if !exception.is_null() { + env.exception_clear()?; + BridgeLayerError::CallbackException( + "FindClass", + ThrownException::new(env, exception)?, + ) + } else { + e.into() + }) + }), } } @@ -121,7 +132,7 @@ pub fn find_class<'output>( /// be used by natively-spawned threads to access application-defined types. fn get_preloaded_class<'output>( env: &mut JNIEnv<'output>, - name: &'static str, + name: &str, ) -> Result>, jni::errors::Error> { let class = PRELOADED_CLASSES .get() @@ -143,7 +154,7 @@ fn get_preloaded_class<'output>( #[allow(clippy::disallowed_methods)] fn real_jni_find_class<'output>( env: &mut JNIEnv<'output>, - name: &'static str, + name: &str, ) -> Result, jni::errors::Error> { env.find_class(name) } diff --git a/rust/bridge/shared/src/jni/mod.rs b/rust/bridge/shared/src/jni/mod.rs index 8d14f867..d7af336e 100644 --- a/rust/bridge/shared/src/jni/mod.rs +++ b/rust/bridge/shared/src/jni/mod.rs @@ -617,11 +617,11 @@ where let throwable = env .new_string(error.to_string()) + .map_err(Into::into) .and_then(|message| { let class = find_class(env, exception_type)?; Ok(new_object(env, class, jni_args!((message => java.lang.String) -> void))?.into()) - }) - .map_err(Into::into); + }); consume(env, throwable, &error) }