From aca995d7458a847e77cf47991371c06562a8ba62 Mon Sep 17 00:00:00 2001 From: Sergey Skrobotov Date: Tue, 26 Mar 2024 22:47:54 -0700 Subject: [PATCH] libsignal-net: additional API and debug info --- .../org/signal/libsignal/net/ChatService.java | 66 +++++++++++++- .../signal/libsignal/net/ChatServiceTest.java | 2 + .../org/signal/libsignal/internal/Native.java | 2 + node/Native.d.ts | 4 + node/ts/net.ts | 14 +++ node/ts/test/NetTest.ts | 2 + rust/bridge/node/bin/Native.d.ts.in | 2 + rust/bridge/shared/src/jni/convert.rs | 12 ++- rust/bridge/shared/src/net.rs | 10 +++ rust/bridge/shared/src/node/convert.rs | 6 ++ rust/bridge/shared/src/testing/net.rs | 6 +- rust/net/src/cdsi.rs | 15 ++-- rust/net/src/chat.rs | 86 ++++++++++++------ rust/net/src/chat/chat_reconnect.rs | 54 +++++++++--- rust/net/src/chat/error.rs | 20 ++++- rust/net/src/chat/ws.rs | 20 +++-- rust/net/src/env.rs | 7 +- rust/net/src/infra.rs | 87 ++++++++++++++++--- rust/net/src/infra/connection_manager.rs | 1 + rust/net/src/infra/dns.rs | 28 +++++- rust/net/src/infra/reconnect.rs | 85 +++++++++++------- rust/net/src/infra/ws.rs | 32 ++++--- 22 files changed, 438 insertions(+), 123 deletions(-) diff --git a/java/client/src/main/java/org/signal/libsignal/net/ChatService.java b/java/client/src/main/java/org/signal/libsignal/net/ChatService.java index 93c16574..4382ee5b 100644 --- a/java/client/src/main/java/org/signal/libsignal/net/ChatService.java +++ b/java/client/src/main/java/org/signal/libsignal/net/ChatService.java @@ -54,6 +54,51 @@ public class ChatService extends NativeHandleGuard.SimpleOwner { Native.ChatService_disconnect(asyncContextHandle, chatServiceHandle))); } + /** + * Initiates establishing of the underlying authenticated connection to the Chat Service. Once the + * service is connected, all the requests will be using the established connection. Also, if the + * connection is lost for any reason other than the call to {@link #disconnect()}, an automatic + * reconnect attempt will be made. + * + *

Note: it's not necessary to call this method before attempting the first request. If the + * service is not connected, {@code connectAuthenticated()} will be called before the first + * authenticated request. However, in the case of the authenticated connection, calling this + * method will result in starting to accept incoming requests from the Chat Service. + * + * @return a future with the result of the connection attempt (either a {@link DebugInfo} or an + * error). + */ + public CompletableFuture connectAuthenticated() { + return tokioAsyncContext.guardedMap( + asyncContextHandle -> + guardedMap( + chatServiceHandle -> + Native.ChatService_connect_auth(asyncContextHandle, chatServiceHandle) + .thenApply(o -> (DebugInfo) o))); + } + + /** + * Initiates establishing of the underlying unauthenticated connection to the Chat Service. Once + * the service is connected, all the requests will be using the established connection. Also, if + * the connection is lost for any reason other than the call to {@link #disconnect()}, an + * automatic reconnect attempt will be made. + * + *

Note: it's not necessary to call this method before attempting the first request. If the + * service is not connected, {@code connectUnauthenticated()} ()} will be called before the first + * unauthenticated request. + * + * @return a future with the result of the connection attempt (either a {@link DebugInfo} or an + * error). + */ + public CompletableFuture connectUnauthenticated() { + return tokioAsyncContext.guardedMap( + asyncContextHandle -> + guardedMap( + chatServiceHandle -> + Native.ChatService_connect_unauth(asyncContextHandle, chatServiceHandle) + .thenApply(o -> (DebugInfo) o))); + } + /** * Sends request to the Chat Service over an unauthenticated channel. * @@ -143,10 +188,25 @@ public class ChatService extends NativeHandleGuard.SimpleOwner { public record Response(int status, String message, Map headers, byte[] body) {} - public record DebugInfo(boolean connectionReused, int reconnectCount, IpType ipType) { + public record DebugInfo( + boolean connectionReused, + int reconnectCount, + IpType ipType, + int durationMs, + String connectionInfo) { @CalledFromNative - DebugInfo(boolean connectionReused, int reconnectCount, byte ipTypeCode) { - this(connectionReused, reconnectCount, IpType.values()[ipTypeCode]); + DebugInfo( + boolean connectionReused, + int reconnectCount, + byte ipTypeCode, + int durationMs, + String connectionInfo) { + this( + connectionReused, + reconnectCount, + IpType.values()[ipTypeCode], + durationMs, + connectionInfo); } } diff --git a/java/client/src/test/java/org/signal/libsignal/net/ChatServiceTest.java b/java/client/src/test/java/org/signal/libsignal/net/ChatServiceTest.java index 96cb0f24..9f7fd638 100644 --- a/java/client/src/test/java/org/signal/libsignal/net/ChatServiceTest.java +++ b/java/client/src/test/java/org/signal/libsignal/net/ChatServiceTest.java @@ -50,6 +50,8 @@ public class ChatServiceTest { assertTrue(debugInfo.connectionReused()); assertEquals(2, debugInfo.reconnectCount()); assertEquals(IpType.IPv4, debugInfo.ipType()); + assertEquals(200, debugInfo.durationMs()); + assertEquals("connection_info", debugInfo.connectionInfo()); } @Test diff --git a/java/shared/java/org/signal/libsignal/internal/Native.java b/java/shared/java/org/signal/libsignal/internal/Native.java index c06d486e..3069695c 100644 --- a/java/shared/java/org/signal/libsignal/internal/Native.java +++ b/java/shared/java/org/signal/libsignal/internal/Native.java @@ -161,6 +161,8 @@ public final class Native { public static native CompletableFuture CdsiLookup_new(long asyncRuntime, long connectionManager, String username, String password, long request, int timeoutMillis); public static native byte[] CdsiLookup_token(long lookup); + public static native CompletableFuture ChatService_connect_auth(long asyncRuntime, long chat); + public static native CompletableFuture ChatService_connect_unauth(long asyncRuntime, long chat); public static native CompletableFuture ChatService_disconnect(long asyncRuntime, long chat); public static native long ChatService_new(long connectionManager, String username, String password); public static native CompletableFuture ChatService_unauth_send(long asyncRuntime, long chat, long httpRequest, int timeoutMillis); diff --git a/node/Native.d.ts b/node/Native.d.ts index dcfabc74..380da02b 100644 --- a/node/Native.d.ts +++ b/node/Native.d.ts @@ -32,6 +32,8 @@ interface DebugInfo { connectionReused: boolean; reconnectCount: number; ipType: number; + durationMillis: number; + connectionInfo: string; } interface ResponseAndDebugInfo { @@ -166,6 +168,8 @@ export function Cds2ClientState_New(mrenclave: Buffer, attestationMsg: Buffer, c export function CdsiLookup_complete(asyncRuntime: Wrapper, lookup: Wrapper): Promise; export function CdsiLookup_new(asyncRuntime: Wrapper, connectionManager: Wrapper, username: string, password: string, request: Wrapper, timeoutMillis: number): Promise; export function CdsiLookup_token(lookup: Wrapper): Buffer; +export function ChatService_connect_auth(asyncRuntime: Wrapper, chat: Wrapper): Promise; +export function ChatService_connect_unauth(asyncRuntime: Wrapper, chat: Wrapper): Promise; export function ChatService_disconnect(asyncRuntime: Wrapper, chat: Wrapper): Promise; export function ChatService_new(connectionManager: Wrapper, username: string, password: string): Chat; export function ChatService_unauth_send(asyncRuntime: Wrapper, chat: Wrapper, httpRequest: Wrapper, timeoutMillis: number): Promise; diff --git a/node/ts/net.ts b/node/ts/net.ts index cd74b5d8..69696d3e 100644 --- a/node/ts/net.ts +++ b/node/ts/net.ts @@ -83,6 +83,20 @@ export class Net { await Native.ChatService_disconnect(this._asyncContext, this._chatService); } + async connectUnauthenticatedChatService(): Promise { + await Native.ChatService_connect_unauth( + this._asyncContext, + this._chatService + ); + } + + async connectAuthenticatedChatService(): Promise { + await Native.ChatService_connect_auth( + this._asyncContext, + this._chatService + ); + } + async unauthenticatedFetchAndDebug( chatRequest: ChatRequest ): Promise { diff --git a/node/ts/test/NetTest.ts b/node/ts/test/NetTest.ts index abaa34eb..ad62efc6 100644 --- a/node/ts/test/NetTest.ts +++ b/node/ts/test/NetTest.ts @@ -56,6 +56,8 @@ describe('chat service api', () => { connectionReused: true, reconnectCount: 2, ipType: 1, + durationMillis: 200, + connectionInfo: 'connection_info', }; expect(Native.TESTING_ChatServiceDebugInfoConvert()).deep.equals(expected); }); diff --git a/rust/bridge/node/bin/Native.d.ts.in b/rust/bridge/node/bin/Native.d.ts.in index 3839d4b3..5d3ea09e 100644 --- a/rust/bridge/node/bin/Native.d.ts.in +++ b/rust/bridge/node/bin/Native.d.ts.in @@ -32,6 +32,8 @@ interface DebugInfo { connectionReused: boolean; reconnectCount: number; ipType: number; + durationMillis: number; + connectionInfo: string; } interface ResponseAndDebugInfo { diff --git a/rust/bridge/shared/src/jni/convert.rs b/rust/bridge/shared/src/jni/convert.rs index ca732d9d..76b65f0c 100644 --- a/rust/bridge/shared/src/jni/convert.rs +++ b/rust/bridge/shared/src/jni/convert.rs @@ -1082,6 +1082,8 @@ impl<'a> ResultTypeInfo<'a> for libsignal_net::chat::DebugInfo { connection_reused, reconnect_count, ip_type, + duration, + connection_info, } = self; // reconnect count as i32 @@ -1090,6 +1092,12 @@ impl<'a> ResultTypeInfo<'a> for libsignal_net::chat::DebugInfo { // ip type as code let ip_type_byte = ip_type as i8; + // duration as millis + let duration_ms: i32 = duration.as_millis().try_into().expect("within i32 range"); + + // connection info string + let connection_info_string = env.new_string(connection_info)?; + let class = { const RESPONSE_CLASS: &str = jni_class_name!(org.signal.libsignal.net.ChatService::DebugInfo); @@ -1104,7 +1112,9 @@ impl<'a> ResultTypeInfo<'a> for libsignal_net::chat::DebugInfo { jni_args!(( connection_reused => boolean, reconnect_count_i32 => int, - ip_type_byte => byte + ip_type_byte => byte, + duration_ms => int, + connection_info_string => java.lang.String, ) -> void), )?) } diff --git a/rust/bridge/shared/src/net.rs b/rust/bridge/shared/src/net.rs index 84d08bb1..c8943fa1 100644 --- a/rust/bridge/shared/src/net.rs +++ b/rust/bridge/shared/src/net.rs @@ -322,6 +322,16 @@ async fn ChatService_disconnect(chat: &Chat) { chat.service.disconnect().await } +#[bridge_io(TokioAsyncContext, ffi = false)] +async fn ChatService_connect_unauth(chat: &Chat) -> Result { + chat.service.connect_unauthenticated().await +} + +#[bridge_io(TokioAsyncContext, ffi = false)] +async fn ChatService_connect_auth(chat: &Chat) -> Result { + chat.service.connect_authenticated().await +} + #[bridge_io(TokioAsyncContext, ffi = false)] async fn ChatService_unauth_send( chat: &Chat, diff --git a/rust/bridge/shared/src/node/convert.rs b/rust/bridge/shared/src/node/convert.rs index 0bcebab7..6318935f 100644 --- a/rust/bridge/shared/src/node/convert.rs +++ b/rust/bridge/shared/src/node/convert.rs @@ -955,16 +955,22 @@ impl<'a> ResultTypeInfo<'a> for libsignal_net::chat::DebugInfo { connection_reused, reconnect_count, ip_type, + duration, + connection_info, } = self; let obj = JsObject::new(cx); let connection_reused = cx.boolean(connection_reused); let reconnect_count = cx.number(reconnect_count); let ip_type = cx.number(ip_type as u8); + let duration = cx.number(duration.as_millis().try_into().unwrap_or(u32::MAX)); + let connection_info = cx.string(connection_info); obj.set(cx, "connectionReused", connection_reused)?; obj.set(cx, "reconnectCount", reconnect_count)?; obj.set(cx, "ipType", ip_type)?; + obj.set(cx, "durationMillis", duration)?; + obj.set(cx, "connectionInfo", connection_info)?; Ok(obj) } diff --git a/rust/bridge/shared/src/testing/net.rs b/rust/bridge/shared/src/testing/net.rs index 44c5dd50..15e0b07b 100644 --- a/rust/bridge/shared/src/testing/net.rs +++ b/rust/bridge/shared/src/testing/net.rs @@ -4,11 +4,13 @@ // use std::str::FromStr; +use std::time::Duration; use http::{HeaderMap, HeaderName, HeaderValue, StatusCode}; use libsignal_bridge_macros::*; use libsignal_net::cdsi::{LookupError, LookupResponse, LookupResponseEntry, E164}; -use libsignal_net::chat::{ChatServiceError, DebugInfo, IpType, Response}; +use libsignal_net::chat::{ChatServiceError, DebugInfo, Response}; +use libsignal_net::infra::IpType; use libsignal_protocol::{Aci, Pni}; use nonzero_ext::nonzero; use uuid::Uuid; @@ -155,6 +157,8 @@ fn TESTING_ChatServiceDebugInfoConvert() -> Result connection_reused: true, reconnect_count: 2, ip_type: IpType::V4, + duration: Duration::from_millis(200), + connection_info: "connection_info".to_string(), }) } diff --git a/rust/net/src/cdsi.rs b/rust/net/src/cdsi.rs index b59b304a..2b58985a 100644 --- a/rust/net/src/cdsi.rs +++ b/rust/net/src/cdsi.rs @@ -484,7 +484,8 @@ mod test { use crate::infra::test::shared::InMemoryWarpConnector; use crate::infra::ws::testutil::{ - fake_websocket, run_attested_server, AttestedServerOutput, FAKE_ATTESTATION, + fake_websocket, mock_connection_info, run_attested_server, AttestedServerOutput, + FAKE_ATTESTATION, }; use crate::infra::ws::WebSocketClient; @@ -673,8 +674,7 @@ mod test { fake_server, )); - let ws_client = - WebSocketClient::new_fake(client, url::Host::Domain("localhost".to_string())); + let ws_client = WebSocketClient::new_fake(client, mock_connection_info()); let cdsi_connection = CdsiConnection( AttestedConnection::connect(ws_client, |fake_attestation| { assert_eq!(fake_attestation, FAKE_ATTESTATION); @@ -729,8 +729,7 @@ mod test { fake_server, )); - let ws_client = - WebSocketClient::new_fake(client, url::Host::Domain("localhost".to_string())); + let ws_client = WebSocketClient::new_fake(client, mock_connection_info()); let cdsi_connection = CdsiConnection( AttestedConnection::connect(ws_client, |fake_attestation| { assert_eq!(fake_attestation, FAKE_ATTESTATION); @@ -777,8 +776,7 @@ mod test { fake_server, )); - let ws_client = - WebSocketClient::new_fake(client, url::Host::Domain("localhost".to_string())); + let ws_client = WebSocketClient::new_fake(client, mock_connection_info()); let cdsi_connection = CdsiConnection( AttestedConnection::connect(ws_client, |fake_attestation| { assert_eq!(fake_attestation, FAKE_ATTESTATION); @@ -851,8 +849,7 @@ mod test { fake_server, )); - let ws_client = - WebSocketClient::new_fake(client, url::Host::Domain("localhost".to_string())); + let ws_client = WebSocketClient::new_fake(client, mock_connection_info()); let cdsi_connection = CdsiConnection( AttestedConnection::connect(ws_client, |fake_attestation| { assert_eq!(fake_attestation, FAKE_ATTESTATION); diff --git a/rust/net/src/chat.rs b/rust/net/src/chat.rs index 6ca0e2a3..4ce81038 100644 --- a/rust/net/src/chat.rs +++ b/rust/net/src/chat.rs @@ -8,13 +8,14 @@ use std::time::Duration; use ::http::uri::PathAndQuery; use ::http::{HeaderMap, HeaderName, HeaderValue, StatusCode}; use async_trait::async_trait; -use url::Host; use crate::chat::ws::{ChatOverWebSocketServiceConnector, ServerRequest}; use crate::infra::connection_manager::MultiRouteConnectionManager; use crate::infra::reconnect::{ServiceConnectorWithDecorator, ServiceWithReconnect}; use crate::infra::ws::WebSocketClientConnector; -use crate::infra::{EndpointConnection, HttpRequestDecorator, TransportConnector}; +use crate::infra::{ + ConnectionInfo, EndpointConnection, HttpRequestDecorator, IpType, TransportConnector, +}; use crate::proto; use crate::utils::basic_authorization; @@ -38,6 +39,9 @@ pub trait ChatService { /// or HTTP) capable of sending [Request] objects. async fn send(&self, msg: Request, timeout: Duration) -> Result; + /// Establish a connection without sending a request. + async fn connect(&self) -> Result<(), ChatServiceError>; + /// If the service is currently holding an open connection, closes that connection. /// /// Depending on the implementing logic, the connection may be re-established later @@ -53,43 +57,28 @@ pub trait ChatServiceWithDebugInfo: ChatService { msg: Request, timeout: Duration, ) -> (Result, DebugInfo); + + /// Establish a connection without sending a request. + async fn connect_and_debug(&self) -> Result; } pub trait RemoteAddressInfo { /// Provides information about the remote address the service is connected to - /// - /// If IP information is available, implementation should prefer to return [Host::Ipv4] or [Host::Ipv6] - /// and only use [Host::Domain] as a fallback. - fn remote_address(&self) -> Host; + fn connection_info(&self) -> ConnectionInfo; } -#[derive(Copy, Clone, Debug)] -#[repr(u8)] -pub enum IpType { - Unknown = 0, - V4 = 1, - V6 = 2, -} - -impl From for IpType { - fn from(host: Host) -> Self { - match host { - Host::Domain(_) => IpType::Unknown, - Host::Ipv4(_) => IpType::V4, - Host::Ipv6(_) => IpType::V6, - } - } -} - -#[derive(Copy, Clone, Debug)] +#[derive(Debug)] pub struct DebugInfo { /// Indicates if the connection was active at the time of the call. pub connection_reused: bool, /// Number of times a connection had to be established since the service was created. pub reconnect_count: u32, - /// IP type of the connection that was used for the request. `0`, if information is not available - /// or if the connection failed. + /// IP type of the connection that was used for the request. pub ip_type: IpType, + /// Time it took to complete the request. + pub duration: Duration, + /// Connection information summary. + pub connection_info: String, } #[derive(Clone, Debug)] @@ -195,6 +184,14 @@ where self.unauth_service.send_and_debug(msg, timeout).await } + pub async fn connect_authenticated(&self) -> Result { + self.auth_service.connect_and_debug().await + } + + pub async fn connect_unauthenticated(&self) -> Result { + self.unauth_service.connect_and_debug().await + } + pub async fn disconnect(&self) { self.unauth_service.disconnect().await; self.auth_service.disconnect().await; @@ -246,6 +243,10 @@ where self.inner.send(msg, timeout).await } + async fn connect(&self) -> Result<(), ChatServiceError> { + self.inner.connect().await + } + async fn disconnect(&self) { self.inner.disconnect().await } @@ -263,6 +264,10 @@ where ) -> (Result, DebugInfo) { self.inner.send_and_debug(msg, timeout).await } + + async fn connect_and_debug(&self) -> Result { + self.inner.connect_and_debug().await + } } struct AuthorizedChatService { @@ -286,6 +291,10 @@ where self.inner.send(msg, timeout).await } + async fn connect(&self) -> Result<(), ChatServiceError> { + self.inner.connect().await + } + async fn disconnect(&self) { self.inner.disconnect().await } @@ -297,6 +306,10 @@ impl ChatService for Arc { self.as_ref().send(msg, timeout).await } + async fn connect(&self) -> Result<(), ChatServiceError> { + self.as_ref().connect().await + } + async fn disconnect(&self) { self.as_ref().disconnect().await } @@ -314,6 +327,10 @@ where ) -> (Result, DebugInfo) { self.inner.send_and_debug(msg, timeout).await } + + async fn connect_and_debug(&self) -> Result { + self.inner.connect_and_debug().await + } } #[async_trait] @@ -322,6 +339,10 @@ impl ChatService for Arc { self.as_ref().send(msg, timeout).await } + async fn connect(&self) -> Result<(), ChatServiceError> { + self.as_ref().connect().await + } + async fn disconnect(&self) { self.as_ref().disconnect().await } @@ -336,6 +357,10 @@ impl ChatServiceWithDebugInfo for Arc (Result, DebugInfo) { self.as_ref().send_and_debug(msg, timeout).await } + + async fn connect_and_debug(&self) -> Result { + self.as_ref().connect_and_debug().await + } } fn build_authorized_chat_service( @@ -443,10 +468,14 @@ pub(crate) mod test { ServiceState::Active(service, status) if !status.is_stopped() => { service.clone().send(msg, timeout).await } - _ => Err(ChatServiceError::NoServiceConnection), + _ => Err(ChatServiceError::AllConnectionRoutesFailed { attempts: 1 }), } } + async fn connect(&self) -> Result<(), ChatServiceError> { + Ok(()) + } + async fn disconnect(&self) { if let ServiceState::Active(_, status) = &*self.inner { status.stop_service() @@ -465,6 +494,7 @@ pub(crate) mod test { pub fn connection_manager() -> SingleRouteThrottlingConnectionManager { let connection_params = ConnectionParams::new( + "test", "test.signal.org", "test.signal.org", 443, diff --git a/rust/net/src/chat/chat_reconnect.rs b/rust/net/src/chat/chat_reconnect.rs index 93e4ad6d..c84249b5 100644 --- a/rust/net/src/chat/chat_reconnect.rs +++ b/rust/net/src/chat/chat_reconnect.rs @@ -28,11 +28,12 @@ where C::StartError: Send + Sync + Debug + LogSafeDisplay, { async fn send(&self, msg: Request, timeout: Duration) -> Result { - let service = self.service_clone().await; - match service { - Some(s) => s.send(msg, timeout).await, - None => Err(ChatServiceError::NoServiceConnection), - } + self.service_clone().await?.send(msg, timeout).await + } + + async fn connect(&self) -> Result<(), ChatServiceError> { + self.service_clone().await?; + Ok(()) } async fn disconnect(&self) { @@ -55,23 +56,50 @@ where msg: Request, timeout: Duration, ) -> (Result, DebugInfo) { - let deadline = Instant::now() + timeout; - let is_connected = self.is_connected(deadline).await; + let start = Instant::now(); + let initial_reconnect_count = self.reconnect_count(); + let deadline = start + timeout; let service = self.service_clone().await; - let (response, ip_type) = match service { - Some(s) => { + let (response, ip_type, connection_info) = match service { + Ok(s) => { let result = s.send(msg, deadline - Instant::now()).await; - (result, s.remote_address().into()) + ( + result, + IpType::from_host(&s.connection_info().address), + s.connection_info().description(), + ) } - None => (Err(ChatServiceError::NoServiceConnection), IpType::Unknown), + Err(e) => (Err(e.into()), IpType::Unknown, "".to_string()), }; + let duration = start.elapsed(); + let reconnect_count = self.reconnect_count(); ( response, DebugInfo { - reconnect_count: self.reconnect_count(), - connection_reused: is_connected, + connection_reused: reconnect_count == initial_reconnect_count, + reconnect_count, ip_type, + duration, + connection_info, }, ) } + + async fn connect_and_debug(&self) -> Result { + let start = Instant::now(); + let initial_reconnect_count = self.reconnect_count(); + let service = self.service_clone().await?; + let connection_info = service.connection_info(); + let ip_type = IpType::from_host(&connection_info.address); + let connection_info = connection_info.description(); + let duration = start.elapsed(); + let reconnect_count = self.reconnect_count(); + Ok(DebugInfo { + connection_reused: reconnect_count == initial_reconnect_count, + reconnect_count, + ip_type, + duration, + connection_info, + }) + } } diff --git a/rust/net/src/chat/error.rs b/rust/net/src/chat/error.rs index 44b1b717..86a07a13 100644 --- a/rust/net/src/chat/error.rs +++ b/rust/net/src/chat/error.rs @@ -6,6 +6,7 @@ use http::header::ToStrError; use crate::infra::errors::LogSafeDisplay; +use crate::infra::reconnect; use crate::infra::ws::WebSocketServiceError; #[derive(Debug, thiserror::Error, displaydoc::Display)] @@ -24,8 +25,10 @@ pub enum ChatServiceError { RequestHasInvalidHeader, /// Timeout Timeout, - /// Service is not connected - NoServiceConnection, + /// Timed out while establishing connection after {attempts} attempts + TimeoutEstablishingConnection { attempts: u16 }, + /// All connection routes failed or timed out, {attempts} attempts made + AllConnectionRoutesFailed { attempts: u16 }, } impl LogSafeDisplay for ChatServiceError {} @@ -35,3 +38,16 @@ impl From for ChatServiceError { ChatServiceError::RequestHasInvalidHeader } } + +impl From for ChatServiceError { + fn from(e: reconnect::ReconnectError) -> Self { + match e { + reconnect::ReconnectError::Timeout { attempts } => { + Self::TimeoutEstablishingConnection { attempts } + } + reconnect::ReconnectError::AllRoutesFailed { attempts } => { + Self::AllConnectionRoutesFailed { attempts } + } + } + } +} diff --git a/rust/net/src/chat/ws.rs b/rust/net/src/chat/ws.rs index dff945ab..bc8b3c08 100644 --- a/rust/net/src/chat/ws.rs +++ b/rust/net/src/chat/ws.rs @@ -14,7 +14,6 @@ use http::status::StatusCode; use prost::Message; use tokio::sync::{mpsc, oneshot, Mutex}; use tokio_tungstenite::WebSocketStream; -use url::Host; use crate::chat::{ ChatMessageType, ChatService, ChatServiceError, MessageProto, RemoteAddressInfo, Request, @@ -25,7 +24,7 @@ use crate::infra::ws::{ NextOrClose, TextOrBinary, WebSocketClient, WebSocketClientConnector, WebSocketClientReader, WebSocketClientWriter, WebSocketConnectError, WebSocketServiceError, }; -use crate::infra::{AsyncDuplexStream, ConnectionParams, TransportConnector}; +use crate::infra::{AsyncDuplexStream, ConnectionInfo, ConnectionParams, TransportConnector}; use crate::proto::chat_websocket::web_socket_message::Type; #[derive(Debug, Default, Eq, Hash, PartialEq, Clone, Copy)] @@ -122,7 +121,7 @@ impl ChatOverWebSocketServiceConnector { #[async_trait] impl ServiceConnector for ChatOverWebSocketServiceConnector { type Service = ChatOverWebSocket; - type Channel = (WebSocketStream, Host); + type Channel = (WebSocketStream, ConnectionInfo); type ConnectError = WebSocketConnectError; type StartError = ChatServiceError; @@ -143,7 +142,7 @@ impl ServiceConnector for ChatOverWebSocketServiceConnect let WebSocketClient { ws_client_writer, ws_client_reader, - remote_address, + connection_info, } = ws_client; let pending_messages: Arc> = Default::default(); tokio::spawn(reader_task( @@ -158,7 +157,7 @@ impl ServiceConnector for ChatOverWebSocketServiceConnect ws_client_writer, service_status: service_status.clone(), pending_messages, - remote_address, + connection_info, }, service_status, ) @@ -232,12 +231,12 @@ pub struct ChatOverWebSocket { ws_client_writer: WebSocketClientWriter, service_status: ServiceStatus, pending_messages: Arc>, - remote_address: Host, + connection_info: ConnectionInfo, } impl RemoteAddressInfo for ChatOverWebSocket { - fn remote_address(&self) -> Host { - self.remote_address.clone() + fn connection_info(&self) -> ConnectionInfo { + self.connection_info.clone() } } @@ -279,6 +278,11 @@ where res } + async fn connect(&self) -> Result<(), ChatServiceError> { + // ChatServiceOverWebsocket is created connected + Ok(()) + } + async fn disconnect(&self) { self.service_status.stop_service() } diff --git a/rust/net/src/env.rs b/rust/net/src/env.rs index f8bc3e34..28a67d16 100644 --- a/rust/net/src/env.rs +++ b/rust/net/src/env.rs @@ -131,6 +131,7 @@ pub const DOMAIN_CONFIG_SVR3_TPM2SNP_STAGING: DomainConfig = DomainConfig { }; const PROXY_CONFIG_F: ProxyConfig = ProxyConfig { + route_log_name: "proxy_f", hostname: "reflector-signal.global.ssl.fastly.net", sni_list: &[ "github.githubassets.com", @@ -140,6 +141,7 @@ const PROXY_CONFIG_F: ProxyConfig = ProxyConfig { }; const PROXY_CONFIG_G: ProxyConfig = ProxyConfig { + route_log_name: "proxy_g", hostname: "reflector-nrgwuv7kwq-uc.a.run.app", sni_list: &[ "www.google.com", @@ -163,12 +165,13 @@ impl DomainConfig { pub fn static_fallback(&self) -> (&'static str, LookupResult) { ( self.hostname, - LookupResult::new(self.ip_v4.into(), self.ip_v6.into()), + LookupResult::new_static(self.ip_v4.into(), self.ip_v6.into()), ) } pub fn connection_params(&self) -> ConnectionParams { ConnectionParams::new( + "direct", self.hostname, self.hostname, 443, @@ -189,6 +192,7 @@ impl DomainConfig { } pub struct ProxyConfig { + route_log_name: &'static str, hostname: &'static str, sni_list: &'static [&'static str], } @@ -203,6 +207,7 @@ impl ProxyConfig { sni_list.shuffle(&mut rng); sni_list.into_iter().map(move |sni| { ConnectionParams::new( + self.route_log_name, sni, self.hostname, 443, diff --git a/rust/net/src/infra.rs b/rust/net/src/infra.rs index 99a62e30..4ddb83ae 100644 --- a/rust/net/src/infra.rs +++ b/rust/net/src/infra.rs @@ -18,6 +18,7 @@ use futures_util::TryFutureExt; use tokio::io::{AsyncRead, AsyncWrite}; use tokio::net::TcpStream; use tokio_boring::SslStream; +use url::Host; use crate::infra::certs::RootCertificates; use crate::infra::connection_manager::{ @@ -37,6 +38,24 @@ pub mod ws; const CONNECTION_ATTEMPT_DELAY: Duration = Duration::from_millis(200); +#[derive(Copy, Clone, Debug)] +#[repr(u8)] +pub enum IpType { + Unknown = 0, + V4 = 1, + V6 = 2, +} + +impl IpType { + pub(crate) fn from_host(host: &Host) -> Self { + match host { + Host::Domain(_) => IpType::Unknown, + Host::Ipv4(_) => IpType::V4, + Host::Ipv6(_) => IpType::V6, + } + } +} + /// A collection of commonly used decorators for HTTP requests. #[derive(Clone, Debug)] pub enum HttpRequestDecorator { @@ -71,6 +90,7 @@ impl From for HttpRequestDecoratorSeq { /// only be applied to the initial connection upgrade request). #[derive(Clone, Debug)] pub struct ConnectionParams { + pub route_type: &'static str, pub sni: Arc, pub host: Arc, pub port: u16, @@ -80,6 +100,7 @@ pub struct ConnectionParams { impl ConnectionParams { pub fn new( + route_type: &'static str, sni: &str, host: &str, port: u16, @@ -87,6 +108,7 @@ impl ConnectionParams { certs: RootCertificates, ) -> Self { Self { + route_type, sni: Arc::from(sni), host: Arc::from(host), port, @@ -107,6 +129,32 @@ impl ConnectionParams { } } +#[derive(Debug, Clone)] +pub struct ConnectionInfo { + /// Type of the connection, e.g. direct or via proxy + pub route_type: &'static str, + + /// The source of the DNS data, e.g. lookup or static fallback + pub dns_source: &'static str, + + /// Address that was used to establish the connection + /// + /// If IP information is available, it's recommended to use [Host::Ipv4] or [Host::Ipv6] + /// and only use [Host::Domain] as a fallback. + pub address: Host, +} + +impl ConnectionInfo { + pub fn description(&self) -> String { + format!( + "route={};dns_source={};ip_type={:?}", + self.route_type, + self.dns_source, + IpType::from_host(&self.address) + ) + } +} + impl HttpRequestDecoratorSeq { pub fn decorate_request( &self, @@ -142,7 +190,7 @@ impl HttpRequestDecorator { } } -pub struct StreamAndHost(T, url::Host); +pub struct StreamAndInfo(T, ConnectionInfo); pub trait AsyncDuplexStream: AsyncRead + AsyncWrite + Unpin + Send + Sync {} @@ -156,7 +204,7 @@ pub trait TransportConnector: Clone + Send + Sync { &self, connection_params: &ConnectionParams, alpn: &[u8], - ) -> Result, TransportConnectError>; + ) -> Result, TransportConnectError>; } #[derive(Clone)] @@ -172,9 +220,10 @@ impl TransportConnector for TcpSslTransportConnector { &self, connection_params: &ConnectionParams, alpn: &[u8], - ) -> Result, TransportConnectError> { - let StreamAndHost(tcp_stream, remote_address) = connect_tcp( + ) -> Result, TransportConnectError> { + let StreamAndInfo(tcp_stream, remote_address) = connect_tcp( &self.dns_resolver, + connection_params.route_type, &connection_params.sni, connection_params.port, ) @@ -188,7 +237,7 @@ impl TransportConnector for TcpSslTransportConnector { .await .map_err(|_| TransportConnectError::SslFailedHandshake)?; - Ok(StreamAndHost(ssl_stream, remote_address)) + Ok(StreamAndInfo(ssl_stream, remote_address)) } } @@ -251,9 +300,10 @@ pub fn make_ws_config( pub(crate) async fn connect_tcp( dns_resolver: &DnsResolver, + route_type: &'static str, host: &str, port: u16, -) -> Result, TransportConnectError> { +) -> Result, TransportConnectError> { let dns_lookup = dns_resolver .lookup_ip(host) .await @@ -263,6 +313,8 @@ pub(crate) async fn connect_tcp( return Err(TransportConnectError::DnsError); } + let dns_source = dns_lookup.source(); + // The idea is to go through the list of candidate IP addresses // and to attempt a connection to each of them, giving each one a `CONNECTION_ATTEMPT_DELAY` headstart // before moving on to the next candidate. @@ -282,7 +334,16 @@ pub(crate) async fn connect_tcp( log::debug!("failed to connect to IP [{}] with an error: {:?}", ip, e) }) .await - .map(|r| StreamAndHost(r, ip_addr_to_host(ip))) + .map(|r| { + StreamAndInfo( + r, + ConnectionInfo { + route_type, + dns_source, + address: ip_addr_to_host(ip), + }, + ) + }) } }); @@ -322,7 +383,7 @@ pub(crate) mod test { use crate::infra::reconnect::{ ServiceConnector, ServiceInitializer, ServiceState, ServiceStatus, }; - use crate::infra::{ConnectionParams, StreamAndHost, TransportConnector}; + use crate::infra::{ConnectionInfo, ConnectionParams, StreamAndInfo, TransportConnector}; #[derive(Debug, Display)] pub(crate) enum TestError { @@ -374,7 +435,7 @@ pub(crate) mod test { &self, connection_params: &ConnectionParams, _alpn: &[u8], - ) -> Result, TransportConnectError> { + ) -> Result, TransportConnectError> { let (client, server) = tokio::io::duplex(1024); let routes = self.filter.clone(); tokio::spawn(async { @@ -382,9 +443,13 @@ pub(crate) mod test { futures_util::stream::iter(vec![Ok::(server)]); warp::serve(routes).run_incoming(one_element_iter).await; }); - Ok(StreamAndHost( + Ok(StreamAndInfo( client, - url::Host::Domain(connection_params.host.to_string()), + ConnectionInfo { + route_type: "test", + dns_source: "test", + address: url::Host::Domain(connection_params.host.to_string()), + }, )) } } diff --git a/rust/net/src/infra/connection_manager.rs b/rust/net/src/infra/connection_manager.rs index fda613d1..4b84e789 100644 --- a/rust/net/src/infra/connection_manager.rs +++ b/rust/net/src/infra/connection_manager.rs @@ -494,6 +494,7 @@ mod test { fn example_connection_params(host: &str) -> ConnectionParams { ConnectionParams::new( + "test", host, host, 443, diff --git a/rust/net/src/infra/dns.rs b/rust/net/src/infra/dns.rs index 9385a572..69f06447 100644 --- a/rust/net/src/infra/dns.rs +++ b/rust/net/src/infra/dns.rs @@ -23,6 +23,7 @@ pub enum Error { #[derive(Debug, Default, Clone)] pub struct LookupResult { + from_lookup: bool, ipv4: Vec, ipv6: Vec, } @@ -45,8 +46,27 @@ impl IntoIterator for LookupResult { } impl LookupResult { - pub fn new(ipv4: Vec, ipv6: Vec) -> Self { - Self { ipv4, ipv6 } + pub fn from_lookup(ipv4: Vec, ipv6: Vec) -> Self { + Self { + from_lookup: true, + ipv4, + ipv6, + } + } + + pub fn new_static(ipv4: Vec, ipv6: Vec) -> Self { + Self { + from_lookup: false, + ipv4, + ipv6, + } + } + + pub(crate) fn source(&self) -> &'static str { + match self.from_lookup { + true => "lookup", + false => "static", + } } pub(crate) fn is_empty(&self) -> bool { @@ -96,7 +116,7 @@ impl DnsResolver { SocketAddr::V4(v4) => Either::Left(*v4.ip()), SocketAddr::V6(v6) => Either::Right(*v6.ip()), }); - match LookupResult::new(ipv4s, ipv6s) { + match LookupResult::from_lookup(ipv4s, ipv6s) { lookup_result if !lookup_result.is_empty() => Ok(lookup_result), _ => Err(Error::LookupFailed), } @@ -161,7 +181,7 @@ mod test { } fn validate_expected_order(ipv4s: Vec, ipv6s: Vec, expected: Vec) { - let lookup_result = LookupResult::new(ipv4s, ipv6s); + let lookup_result = LookupResult::new_static(ipv4s, ipv6s); let actual: Vec = lookup_result.into_iter().collect(); assert_eq!(expected, actual); } diff --git a/rust/net/src/infra/reconnect.rs b/rust/net/src/infra/reconnect.rs index 28d033d2..03b1fbe7 100644 --- a/rust/net/src/infra/reconnect.rs +++ b/rust/net/src/infra/reconnect.rs @@ -10,6 +10,7 @@ use std::time::Duration; use async_trait::async_trait; use derive_where::derive_where; +use displaydoc::Display; use tokio::sync::Mutex; use tokio::time::{timeout_at, Instant}; use tokio_util::sync::CancellationToken; @@ -231,6 +232,14 @@ pub(crate) struct ServiceWithReconnect { data: Arc>, } +#[derive(Debug, Display)] +pub(crate) enum ReconnectError { + /// Operation timed out + Timeout { attempts: u16 }, + /// All attempted routes failed to connect + AllRoutesFailed { attempts: u16 }, +} + impl ServiceWithReconnect where M: ConnectionManager + 'static, @@ -257,17 +266,6 @@ where } } - pub(crate) async fn is_connected(&self, deadline: Instant) -> bool { - let guard = match timeout_at(deadline, self.data.state.lock()).await { - Ok(guard) => guard, - Err(_) => { - log::info!("Timed out waiting for the state lock"); - return false; - } - }; - matches!(&*guard, ServiceState::Active(_, status) if !status.is_stopped()) - } - pub(crate) fn reconnect_count(&self) -> u32 { self.data.reconnect_count.load(Ordering::Relaxed) } @@ -278,13 +276,14 @@ where } } - pub(crate) async fn service_clone(&self) -> Option { + pub(crate) async fn service_clone(&self) -> Result { + let mut attempts: u16 = 0; let deadline = Instant::now() + self.data.connection_timeout; let mut guard = match timeout_at(deadline, self.data.state.lock()).await { Ok(guard) => guard, Err(_) => { log::info!("Timed out waiting for the state lock"); - return None; + return Err(ReconnectError::Timeout { attempts }); } }; loop { @@ -294,7 +293,7 @@ where // if the state is `Active` and service has not been stopped, // clone the service and return it log::debug!("reusing active service instance"); - return Some(service.clone()); + return Ok(service.clone()); } if let Some(error) = service_status.get_error() { log::debug!("Service stopped due to an error: {:?}", error); @@ -305,7 +304,7 @@ where // checking if the `next_attempt_time` is still in the future if next_attempt_time > &deadline { log::debug!("All possible routes are in cooldown state"); - return None; + return Err(ReconnectError::AllRoutesFailed { attempts }); } // it's safe to sleep without a `timeout` // because we just checked that we'll wake before the deadline @@ -315,16 +314,17 @@ where // keep trying until we hit our own timeout deadline log::info!("Connection attempt timed out"); if Instant::now() >= deadline { - return None; + return Err(ReconnectError::Timeout { attempts }); } } ServiceState::Error(e) => { - // short circuiting mechanism is responsibility of the `ConnectionManager`, + // short-circuiting mechanism is responsibility of the `ConnectionManager`, // so here we're just going to keep trying until we get into // one of the non-retryable states, `Cooldown` or time out. log::info!("Connection attempt resulted in an error: {}", e); } }; + attempts += 1; *guard = match timeout_at(deadline, self.data.service_initializer.connect()).await { Ok(result) => { self.data.reconnect_count.fetch_add(1, Ordering::Relaxed); @@ -354,7 +354,7 @@ mod test { SingleRouteThrottlingConnectionManager, MAX_COOLDOWN_INTERVAL, }; use crate::infra::reconnect::{ - ServiceConnector, ServiceState, ServiceStatus, ServiceWithReconnect, + ReconnectError, ServiceConnector, ServiceState, ServiceStatus, ServiceWithReconnect, }; use crate::infra::test::shared::{ TestError, LONG_CONNECTION_TIME, NORMAL_CONNECTION_TIME, TIMEOUT_DURATION, @@ -442,6 +442,7 @@ mod test { fn example_connection_params() -> ConnectionParams { ConnectionParams::new( + "test", "chat.signal.org", "chat.signal.org", 443, @@ -510,8 +511,25 @@ mod test { let service_with_reconnect = ServiceWithReconnect::new(connector.clone(), manager, TIMEOUT_DURATION); let service = service_with_reconnect.service_clone().await; - assert!(service.is_none()); - assert!(connector.attempts_made() > 1); + + // Here we have 3 attempts made by the reconnect service: + // - first attempt went to the connector and resulted in expected error + // - after the first attempt, the configured cooldown is 0, so the second attempt + // also went to the connector and resulted in expected error + // - after two consecutive unsuccessful attempts, the configured cooldown is 1 second, + // so the third attempt was made by the reconnect service but didn't reach the connector + // and immediately resulted in a Cooldown result + // - 1 second is longder than our test TIMEOUT_DURATION, so no more attempts were made + // Based on that, connector only saw 2 attempts, but ServiceWithReconnect had time + // to perform 3 attempts. + // Note that if the values in `COOLDOWN_INTERVALS` constant change, the number of attempts + // may also change + assert_eq!(connector.attempts_made(), 2); + assert_matches!( + service, + Err(ReconnectError::AllRoutesFailed { attempts: 3 }) + ); + assert_matches!( *service_with_reconnect.data.state.lock().await, ServiceState::Cooldown(_) @@ -519,10 +537,10 @@ mod test { let now_or_never_service_option = service_with_reconnect.service_clone().now_or_never(); // the future should be completed immediately - // but the result of the future should be `None` because we're in cooldown + // but the result of the future should be `Err()` because we're in cooldown assert!(now_or_never_service_option .expect("completed future") - .is_none()); + .is_err()); } #[tokio::test] @@ -538,7 +556,7 @@ mod test { service.expect("service is present").close_channel(); let service = service_with_reconnect.service_clone().await; assert_eq!(connector.attempts_made(), 2); - assert!(service.is_some()); + assert_matches!(service, Ok(_)); } #[tokio::test(flavor = "multi_thread", worker_threads = 2)] @@ -558,8 +576,8 @@ mod test { let handle2 = tokio::spawn(async move { aaa2.service_clone().await }); let (s1, s2) = tokio::join!(handle1, handle2); - assert!(s1.expect("future completed successfully").is_some()); - assert!(s2.expect("future completed successfully").is_some()); + assert!(s1.expect("future completed successfully").is_ok()); + assert!(s2.expect("future completed successfully").is_ok()); assert_eq!(connector.attempts_made(), 1); } @@ -581,7 +599,7 @@ mod test { let res = service_with_reconnect.service_clone().await; // now the time should've auto-advanced from `start` by the `connection_timeout` value - assert!(res.is_none()); + assert!(res.is_err()); assert_eq!(Instant::now(), start + service_with_reconnect_timeout); } @@ -604,7 +622,7 @@ mod test { let res = service_with_reconnect.service_clone().await; // now the time should've auto-advanced from `start` by the `connection_timeout` value - assert!(res.is_none()); + assert_matches!(res, Err(ReconnectError::Timeout { attempts: 1 })); assert_eq!(Instant::now(), start + service_with_reconnect_timeout); } @@ -621,7 +639,12 @@ mod test { time::advance(TIME_ADVANCE_VALUE).await; connector.set_service_healthy(false); let service = service_with_reconnect.service_clone().await; - assert!(service.is_none()); + + // number of attempts is the same as in the `immediately_fail_if_in_cooldown()` test + assert_matches!( + service, + Err(ReconnectError::AllRoutesFailed { attempts: 3 }) + ); // At this point, `service_with_reconnect` tried multiple times to connect // and hit the cooldown. Let's advance time to make sure next attempt will be made. @@ -629,7 +652,7 @@ mod test { connector.set_service_healthy(true); let service = service_with_reconnect.service_clone().await; - assert!(service.is_some()); + assert_matches!(service, Ok(_)); } #[tokio::test(flavor = "current_thread", start_paused = true)] @@ -645,7 +668,7 @@ mod test { time::advance(TIME_ADVANCE_VALUE).await; connector.set_time_to_connect(LONG_CONNECTION_TIME); let service = service_with_reconnect.service_clone().await; - assert!(service.is_none()); + assert_matches!(service, Err(ReconnectError::Timeout { attempts: 1 })); // At this point, `service_with_reconnect` tried multiple times to connect // and hit the cooldown. Let's advance time to make sure next attempt will be made. @@ -653,6 +676,6 @@ mod test { connector.set_time_to_connect(NORMAL_CONNECTION_TIME); let service = service_with_reconnect.service_clone().await; - assert!(service.is_some()); + assert_matches!(service, Ok(_)); } } diff --git a/rust/net/src/infra/ws.rs b/rust/net/src/infra/ws.rs index cc114ca7..a40bb64c 100644 --- a/rust/net/src/infra/ws.rs +++ b/rust/net/src/infra/ws.rs @@ -26,7 +26,9 @@ use tungstenite::{http, Message}; use crate::infra::errors::LogSafeDisplay; use crate::infra::reconnect::{ServiceConnector, ServiceStatus}; use crate::infra::ws::error::{HttpFormatError, ProtocolError, SpaceError}; -use crate::infra::{AsyncDuplexStream, ConnectionParams, StreamAndHost, TransportConnector}; +use crate::infra::{ + AsyncDuplexStream, ConnectionInfo, ConnectionParams, StreamAndInfo, TransportConnector, +}; use crate::utils::timeout; pub mod error; @@ -121,7 +123,7 @@ where WebSocketServiceError: Into, { type Service = WebSocketClient; - type Channel = (WebSocketStream, url::Host); + type Channel = (WebSocketStream, ConnectionInfo); type ConnectError = WebSocketConnectError; type StartError = E; @@ -159,7 +161,7 @@ where fn start_ws_service( channel: WebSocketStream, - remote_address: url::Host, + connection_info: ConnectionInfo, keep_alive_interval: Duration, max_idle_time: Duration, ) -> (WebSocketClient, ServiceStatus) { @@ -182,7 +184,7 @@ fn start_ws_service( WebSocketClient { ws_client_writer, ws_client_reader, - remote_address, + connection_info, }, service_status, ) @@ -312,8 +314,8 @@ async fn connect_websocket( endpoint: PathAndQuery, ws_config: tungstenite::protocol::WebSocketConfig, transport_connector: &T, -) -> Result<(WebSocketStream, url::Host), WebSocketConnectError> { - let StreamAndHost(ssl_stream, remote_address) = transport_connector +) -> Result<(WebSocketStream, ConnectionInfo), WebSocketConnectError> { + let StreamAndInfo(ssl_stream, remote_address) = transport_connector .connect(connection_params, WS_ALPN) .await?; @@ -385,7 +387,7 @@ impl From for Message { pub struct WebSocketClient { pub(crate) ws_client_writer: WebSocketClientWriter, pub(crate) ws_client_reader: WebSocketClientReader, - pub(crate) remote_address: url::Host, + pub(crate) connection_info: ConnectionInfo, } impl WebSocketClient @@ -393,11 +395,11 @@ where WebSocketServiceError: Into, { #[cfg(test)] - pub(crate) fn new_fake(channel: WebSocketStream, remote_address: url::Host) -> Self { + pub(crate) fn new_fake(channel: WebSocketStream, connection_info: ConnectionInfo) -> Self { const VERY_LARGE_TIMEOUT: Duration = Duration::from_secs(u32::MAX as u64); let (client, _service_status) = start_ws_service( channel, - remote_address, + connection_info, VERY_LARGE_TIMEOUT, VERY_LARGE_TIMEOUT, ); @@ -462,7 +464,7 @@ pub struct AttestedConnection { impl AttestedConnection { pub(crate) fn remote_address(&self) -> &url::Host { - &self.websocket.remote_address + &self.websocket.connection_info.address } } @@ -626,12 +628,20 @@ pub(crate) mod testutil { (server_stream, client_stream) } + pub(crate) fn mock_connection_info() -> ConnectionInfo { + ConnectionInfo { + route_type: "test", + dns_source: "test", + address: url::Host::Domain("localhost".to_string()), + } + } + pub(crate) fn websocket_test_client( channel: WebSocketStream, ) -> WebSocketClient { start_ws_service( channel, - url::Host::Domain("localhost".to_string()), + mock_connection_info(), WS_KEEP_ALIVE_INTERVAL, WS_MAX_IDLE_TIME, )