From 55ac7166e046f8a14d82308b76ea4607fd7abba5 Mon Sep 17 00:00:00 2001 From: Sergey Skrobotov Date: Wed, 7 Aug 2024 16:38:45 -0700 Subject: [PATCH] net: dropping auto-reconnect logic --- node/Native.d.ts | 3 +- node/ts/net.ts | 55 +- node/ts/test/NetTest.ts | 12 +- rust/bridge/shared/src/net/chat.rs | 27 +- rust/bridge/shared/types/src/ffi/chat.rs | 4 +- rust/bridge/shared/types/src/net/chat.rs | 62 ++- rust/bridge/shared/types/src/node/chat.rs | 4 +- rust/net/examples/chat_smoke_test.rs | 6 +- rust/net/src/chat.rs | 11 +- rust/net/src/chat/chat_reconnect.rs | 4 +- rust/net/src/chat/error.rs | 1 - rust/net/src/chat/server_requests.rs | 9 +- rust/net/src/chat/ws.rs | 100 ++-- rust/net/src/infra.rs | 11 +- rust/net/src/infra/reconnect.rs | 493 ++---------------- rust/net/src/infra/ws.rs | 41 +- .../Sources/LibSignalClient/ChatService.swift | 4 +- swift/Sources/SignalFfi/signal_ffi.h | 2 +- 18 files changed, 267 insertions(+), 582 deletions(-) diff --git a/node/Native.d.ts b/node/Native.d.ts index 64a11558..ff311caa 100644 --- a/node/Native.d.ts +++ b/node/Native.d.ts @@ -180,7 +180,8 @@ export function Cds2ClientState_New(mrenclave: Buffer, attestationMsg: Buffer, c export function CdsiLookup_complete(asyncRuntime: Wrapper, lookup: Wrapper): Promise; export function CdsiLookup_new(asyncRuntime: Wrapper, connectionManager: Wrapper, username: string, password: string, request: Wrapper): Promise; export function CdsiLookup_token(lookup: Wrapper): Buffer; -export function ChatServer_SetListener(runtime: Wrapper, chat: Wrapper, makeListener: MakeChatListener | null): void; +export function ChatService_SetListenerAuth(runtime: Wrapper, chat: Wrapper, makeListener: MakeChatListener | null): void; +export function ChatService_SetListenerUnauth(runtime: Wrapper, chat: Wrapper, makeListener: MakeChatListener | null): void; export function ChatService_auth_send(asyncRuntime: Wrapper, chat: Wrapper, httpRequest: Wrapper, timeoutMillis: number): Promise; export function ChatService_auth_send_and_debug(asyncRuntime: Wrapper, chat: Wrapper, httpRequest: Wrapper, timeoutMillis: number): Promise; export function ChatService_connect_auth(asyncRuntime: Wrapper, chat: Wrapper): Promise; diff --git a/node/ts/net.ts b/node/ts/net.ts index f2c4d820..e23bf685 100644 --- a/node/ts/net.ts +++ b/node/ts/net.ts @@ -122,7 +122,17 @@ export class ChatServerMessageAck { } } -export interface ChatServiceListener { +export interface ConnectionEventsListener { + /** + * Called when the client gets disconnected from the server. + * + * This includes both deliberate disconnects as well as unexpected socket closures that will be + * automatically retried. + */ + onConnectionInterrupted(): void; +} + +export interface ChatServiceListener extends ConnectionEventsListener { /** * Called when the server delivers an incoming message to the client. * @@ -144,17 +154,6 @@ export interface ChatServiceListener { * were in the queue *when the connection was established* have been delivered. */ onQueueEmpty(): void; - - /** - * Called when the client gets disconnected from the server. - * - * This includes both deliberate disconnects as well as unexpected socket closures that will be - * automatically retried. - * - * Will not be called if no other requests have been invoked for this connection attempt. That is, - * you should never see this as the first callback, nor two of these callbacks in a row. - */ - onConnectionInterrupted(): void; } /** @@ -256,7 +255,7 @@ export class AuthenticatedChatService implements ChatService { listener.onConnectionInterrupted(); }, }; - Native.ChatServer_SetListener( + Native.ChatService_SetListenerAuth( asyncContext, this.chatService, nativeChatListener @@ -315,11 +314,32 @@ export class UnauthenticatedChatService implements ChatService { constructor( private readonly asyncContext: TokioAsyncContext, - connectionManager: ConnectionManager + connectionManager: ConnectionManager, + listener: ConnectionEventsListener ) { this.chatService = newNativeHandle( Native.ChatService_new(connectionManager, '', '', false) ); + const nativeChatListener = { + _incoming_message( + _envelope: Buffer, + _timestamp: number, + _ack: ServerMessageAck + ): void { + throw new Error('Event not supported on unauthenticated connection'); + }, + _queue_empty(): void { + throw new Error('Event not supported on unauthenticated connection'); + }, + _connection_interrupted(): void { + listener.onConnectionInterrupted(); + }, + }; + Native.ChatService_SetListenerUnauth( + asyncContext, + this.chatService, + nativeChatListener + ); } disconnect(): Promise { @@ -426,10 +446,13 @@ export class Net { /** * Creates a new instance of {@link UnauthenticatedChatService}. */ - public newUnauthenticatedChatService(): UnauthenticatedChatService { + public newUnauthenticatedChatService( + listener: ConnectionEventsListener + ): UnauthenticatedChatService { return new UnauthenticatedChatService( this.asyncContext, - this.connectionManager + this.connectionManager, + listener ); } diff --git a/node/ts/test/NetTest.ts b/node/ts/test/NetTest.ts index a50caa13..b5684d7c 100644 --- a/node/ts/test/NetTest.ts +++ b/node/ts/test/NetTest.ts @@ -176,9 +176,13 @@ describe('chat service api', () => { it('can connect unauthenticated', async () => { const net = new Net(Environment.Staging, userAgent); - const chatService = net.newUnauthenticatedChatService(); + const listener = { + onConnectionInterrupted: sinon.stub(), + }; + const chatService = net.newUnauthenticatedChatService(listener); await chatService.connect(); await chatService.disconnect(); + expect(listener.onConnectionInterrupted).to.have.been.calledOnce; }).timeout(10000); it('can connect through a proxy server', async () => { @@ -190,9 +194,13 @@ describe('chat service api', () => { const [host = PROXY_SERVER, port = '443'] = PROXY_SERVER.split(':', 2); net.setProxy(host, parseInt(port, 10)); - const chatService = net.newUnauthenticatedChatService(); + const listener = { + onConnectionInterrupted: sinon.stub(), + }; + const chatService = net.newUnauthenticatedChatService(listener); await chatService.connect(); await chatService.disconnect(); + expect(listener.onConnectionInterrupted).to.have.been.calledOnce; }).timeout(10000); }); diff --git a/rust/bridge/shared/src/net/chat.rs b/rust/bridge/shared/src/net/chat.rs index 739b4dda..88dadfc9 100644 --- a/rust/bridge/shared/src/net/chat.rs +++ b/rust/bridge/shared/src/net/chat.rs @@ -16,6 +16,7 @@ use libsignal_net::auth::Auth; use libsignal_net::chat::{ self, ChatServiceError, DebugInfo as ChatServiceDebugInfo, Request, Response as ChatResponse, }; +use libsignal_net::infra::ws::WebSocketServiceError; use crate::support::*; use crate::*; @@ -172,19 +173,35 @@ async fn ChatService_auth_send_and_debug( } #[bridge_fn(jni = false)] -fn ChatServer_SetListener( +fn ChatService_SetListenerAuth( runtime: &TokioAsyncContext, chat: &Chat, make_listener: Option<&dyn MakeChatListener>, ) { let Some(maker) = make_listener else { - chat.clear_listener(); + chat.clear_listener_auth(); return; }; let listener = maker.make_listener(); - chat.set_listener(listener, runtime) + chat.set_listener_auth(listener, runtime) +} + +#[bridge_fn(jni = false, ffi = false)] +fn ChatService_SetListenerUnauth( + runtime: &TokioAsyncContext, + chat: &Chat, + make_listener: Option<&dyn MakeChatListener>, +) { + let Some(maker) = make_listener else { + chat.clear_listener_unauth(); + return; + }; + + let listener = maker.make_listener(); + + chat.set_listener_unauth(listener, runtime) } #[cfg(feature = "testing-fns")] @@ -201,7 +218,9 @@ fn TESTING_ChatService_InjectRawServerRequest(chat: &Chat, bytes: &[u8]) { #[bridge_fn] fn TESTING_ChatService_InjectConnectionInterrupted(chat: &Chat) { chat.synthetic_request_tx - .blocking_send(chat::ws::ServerEvent::Stopped) + .blocking_send(chat::ws::ServerEvent::Stopped(ChatServiceError::WebSocket( + WebSocketServiceError::ChannelClosed, + ))) .expect("not closed"); } diff --git a/rust/bridge/shared/types/src/ffi/chat.rs b/rust/bridge/shared/types/src/ffi/chat.rs index b9e9ca1d..15afe59e 100644 --- a/rust/bridge/shared/types/src/ffi/chat.rs +++ b/rust/bridge/shared/types/src/ffi/chat.rs @@ -7,6 +7,7 @@ use super::*; use crate::net::chat::{ChatListener, MakeChatListener, ServerMessageAck}; +use libsignal_net::chat::ChatServiceError; use std::ffi::{c_uchar, c_void}; type ReceivedIncomingMessage = extern "C" fn( @@ -76,7 +77,8 @@ impl ChatListener for ChatListenerStruct { (self.0.received_queue_empty)(self.0.ctx) } - fn connection_interrupted(&mut self) { + // TODO: pass `_disconnect_cause` to `connection_interrupted` + fn connection_interrupted(&mut self, _disconnect_cause: ChatServiceError) { (self.0.connection_interrupted)(self.0.ctx) } } diff --git a/rust/bridge/shared/types/src/net/chat.rs b/rust/bridge/shared/types/src/net/chat.rs index e40444a5..c217acd4 100644 --- a/rust/bridge/shared/types/src/net/chat.rs +++ b/rust/bridge/shared/types/src/net/chat.rs @@ -15,7 +15,9 @@ use http::status::InvalidStatusCode; use http::uri::{InvalidUri, PathAndQuery}; use http::{HeaderMap, HeaderName, HeaderValue}; use libsignal_net::auth::Auth; -use libsignal_net::chat::{self, DebugInfo as ChatServiceDebugInfo, Response as ChatResponse}; +use libsignal_net::chat::{ + self, ChatServiceError, DebugInfo as ChatServiceDebugInfo, Response as ChatResponse, +}; use libsignal_protocol::Timestamp; use tokio::sync::{mpsc, oneshot}; @@ -57,7 +59,8 @@ pub struct Chat { Arc, Arc, >, - listener: std::sync::Mutex, + listener_auth: std::sync::Mutex, + listener_unauth: std::sync::Mutex, pub synthetic_request_tx: mpsc::Sender>, } @@ -66,9 +69,14 @@ impl RefUnwindSafe for Chat {} impl Chat { pub fn new(connection_manager: &ConnectionManager, auth: Auth, receive_stories: bool) -> Self { - let (incoming_tx, incoming_rx) = mpsc::channel(1); - let incoming_stream = chat::server_requests::stream_incoming_messages(incoming_rx); - let synthetic_request_tx = incoming_tx.clone(); + let (incoming_auth_tx, incoming_auth_rx) = mpsc::channel(1); + let incoming_stream_auth = + chat::server_requests::stream_incoming_messages(incoming_auth_rx); + let synthetic_request_tx = incoming_auth_tx.clone(); + + let (incoming_unauth_tx, incoming_unauth_rx) = mpsc::channel(1); + let incoming_stream_unauth = + chat::server_requests::stream_incoming_messages(incoming_unauth_rx); Chat { service: chat::chat_service( @@ -78,22 +86,44 @@ impl Chat { .lock() .expect("not poisoned") .clone(), - incoming_tx, + incoming_auth_tx, + incoming_unauth_tx, auth, receive_stories, ) .into_dyn(), - listener: std::sync::Mutex::new(ChatListenerState::Inactive(Box::pin(incoming_stream))), + listener_auth: std::sync::Mutex::new(ChatListenerState::Inactive(Box::pin( + incoming_stream_auth, + ))), + listener_unauth: std::sync::Mutex::new(ChatListenerState::Inactive(Box::pin( + incoming_stream_unauth, + ))), synthetic_request_tx, } } - pub fn set_listener(&self, listener: Box, runtime: &TokioAsyncContext) { + pub fn set_listener_auth(&self, listener: Box, runtime: &TokioAsyncContext) { + Chat::set_listener(listener, &self.listener_auth, runtime); + } + + pub fn set_listener_unauth( + &self, + listener: Box, + runtime: &TokioAsyncContext, + ) { + Chat::set_listener(listener, &self.listener_unauth, runtime); + } + + fn set_listener( + listener: Box, + listener_state: &std::sync::Mutex, + runtime: &TokioAsyncContext, + ) { use futures_util::future::Either; let (cancel_tx, cancel_rx) = oneshot::channel::<()>(); - let mut guard = self.listener.lock().expect("unpoisoned"); + let mut guard = listener_state.lock().expect("unpoisoned"); let request_stream_future = match std::mem::replace(&mut *guard, ChatListenerState::CurrentlyBeingMutated) { ChatListenerState::Inactive(request_stream) => { @@ -125,8 +155,12 @@ impl Chat { drop(guard); } - pub fn clear_listener(&self) { - self.listener.lock().expect("unpoisoned").cancel(); + pub fn clear_listener_auth(&self) { + self.listener_auth.lock().expect("unpoisoned").cancel(); + } + + pub fn clear_listener_unauth(&self) { + self.listener_unauth.lock().expect("unpoisoned").cancel(); } } @@ -204,7 +238,7 @@ pub trait ChatListener: Send { ack: ServerMessageAck, ); fn received_queue_empty(&mut self); - fn connection_interrupted(&mut self); + fn connection_interrupted(&mut self, disconnect_cause: ChatServiceError); } impl dyn ChatListener { @@ -223,7 +257,9 @@ impl dyn ChatListener { ServerMessageAck::new(send_ack), ), chat::server_requests::ServerMessage::QueueEmpty => self.received_queue_empty(), - chat::server_requests::ServerMessage::Stopped => self.connection_interrupted(), + chat::server_requests::ServerMessage::Stopped(error) => { + self.connection_interrupted(error) + } } } diff --git a/rust/bridge/shared/types/src/node/chat.rs b/rust/bridge/shared/types/src/node/chat.rs index 622b4b7d..24c0dd00 100644 --- a/rust/bridge/shared/types/src/node/chat.rs +++ b/rust/bridge/shared/types/src/node/chat.rs @@ -5,6 +5,7 @@ use crate::net::chat::{ChatListener, MakeChatListener, ServerMessageAck}; use crate::node::ResultTypeInfo; +use libsignal_net::chat::ChatServiceError; use libsignal_protocol::Timestamp; use neon::context::FunctionContext; use neon::event::Channel; @@ -53,7 +54,8 @@ impl ChatListener for NodeChatListener { }); } - fn connection_interrupted(&mut self) { + // TODO: pass `_disconnect_cause` to `_connection_interrupted` + fn connection_interrupted(&mut self, _disconnect_cause: ChatServiceError) { let callback_object_shared = self.callback_object.clone(); self.js_channel.send(move |mut cx| { let callback = callback_object_shared.to_inner(&mut cx); diff --git a/rust/net/examples/chat_smoke_test.rs b/rust/net/examples/chat_smoke_test.rs index 2cded1cd..da31ca75 100644 --- a/rust/net/examples/chat_smoke_test.rs +++ b/rust/net/examples/chat_smoke_test.rs @@ -115,11 +115,13 @@ async fn test_connection( &network_change_event, ); - let (incoming_tx, _incoming_rx) = mpsc::channel(1); + let (incoming_auth_tx, _incoming_rx) = mpsc::channel(1); + let (incoming_unauth_tx, _incoming_rx) = mpsc::channel(1); let chat = chat_service( &connection, transport_connector, - incoming_tx, + incoming_auth_tx, + incoming_unauth_tx, Auth { username: "".to_owned(), password: "".to_owned(), diff --git a/rust/net/src/chat.rs b/rust/net/src/chat.rs index 85db4a6f..8c984896 100644 --- a/rust/net/src/chat.rs +++ b/rust/net/src/chat.rs @@ -447,18 +447,19 @@ fn build_anonymous_chat_service( pub fn chat_service( endpoint: &EndpointConnection, transport_connector: T, - incoming_tx: tokio::sync::mpsc::Sender>, + incoming_auth_tx: tokio::sync::mpsc::Sender>, + incoming_unauth_tx: tokio::sync::mpsc::Sender>, auth: Auth, receive_stories: bool, ) -> Chat { // Cannot reuse the same connector, since they lock on `incoming_tx` internally. let unauth_ws_connector = ChatOverWebSocketServiceConnector::new( WebSocketClientConnector::new(transport_connector.clone(), endpoint.config.clone()), - incoming_tx.clone(), + incoming_unauth_tx, ); let auth_ws_connector = ChatOverWebSocketServiceConnector::new( WebSocketClientConnector::new(transport_connector, endpoint.config.clone()), - incoming_tx, + incoming_auth_tx, ); { let auth_service = build_authorized_chat_service( @@ -513,7 +514,7 @@ pub(crate) mod test { timeout: Duration, ) -> Result { match &*self.inner { - ServiceState::Active(service, status) if !status.is_stopped() => { + ServiceState::Active(service, status) if !status.is_cancelled() => { service.clone().send(msg, timeout).await } _ => Err(ChatServiceError::AllConnectionRoutesFailed { attempts: 1 }), @@ -526,7 +527,7 @@ pub(crate) mod test { async fn disconnect(&self) { if let ServiceState::Active(_, status) = &*self.inner { - status.stop_service() + status.cancel() } } } diff --git a/rust/net/src/chat/chat_reconnect.rs b/rust/net/src/chat/chat_reconnect.rs index e6147e2b..866e6f5a 100644 --- a/rust/net/src/chat/chat_reconnect.rs +++ b/rust/net/src/chat/chat_reconnect.rs @@ -33,7 +33,7 @@ where } async fn connect(&self) -> Result<(), ChatServiceError> { - Ok(self.connect_from_inactive().await?) + Ok(self.connect().await?) } async fn disconnect(&self) { @@ -103,7 +103,7 @@ where async fn connect_and_debug(&self) -> Result { let start = Instant::now(); - self.connect_from_inactive().await?; + self.connect().await?; let connection_info = self.connection_info().await?; let ip_type = IpType::from_host(&connection_info.address); diff --git a/rust/net/src/chat/error.rs b/rust/net/src/chat/error.rs index ab55e375..1946be3c 100644 --- a/rust/net/src/chat/error.rs +++ b/rust/net/src/chat/error.rs @@ -100,7 +100,6 @@ impl> From e.into(), - reconnect::ReconnectError::Inactive => Self::ServiceInactive, } } } diff --git a/rust/net/src/chat/server_requests.rs b/rust/net/src/chat/server_requests.rs index 646e8406..805681a0 100644 --- a/rust/net/src/chat/server_requests.rs +++ b/rust/net/src/chat/server_requests.rs @@ -28,7 +28,7 @@ pub enum ServerMessage { send_ack: ResponseEnvelopeSender, }, /// Not actually a message, but an event processed as part of the message stream. - Stopped, + Stopped(ChatServiceError), } impl std::fmt::Debug for ServerMessage { @@ -46,7 +46,10 @@ impl std::fmt::Debug for ServerMessage { .field("envelope", &format_args!("{} bytes", envelope.len())) .field("server_delivery_timestamp", server_delivery_timestamp) .finish(), - Self::Stopped => write!(f, "Stopped"), + Self::Stopped(error) => f + .debug_struct("ConnectionInterrupted") + .field("reason", error) + .finish(), } } } @@ -55,7 +58,7 @@ pub fn stream_incoming_messages( receiver: mpsc::Receiver>, ) -> impl Stream { ReceiverStream::new(receiver).filter_map(|request| match request { - ServerEvent::Stopped => Some(ServerMessage::Stopped), + ServerEvent::Stopped(error) => Some(ServerMessage::Stopped(error)), ServerEvent::Request { request_proto, response_sender, diff --git a/rust/net/src/chat/ws.rs b/rust/net/src/chat/ws.rs index 23e37108..b93d8bb0 100644 --- a/rust/net/src/chat/ws.rs +++ b/rust/net/src/chat/ws.rs @@ -15,12 +15,13 @@ use prost::Message; use tokio::sync::{mpsc, oneshot, Mutex}; use tokio::time::Instant; use tokio_tungstenite::WebSocketStream; +use tokio_util::sync::CancellationToken; use crate::chat::{ ChatMessageType, ChatService, ChatServiceError, MessageProto, RemoteAddressInfo, Request, RequestProto, Response, ResponseProto, }; -use crate::infra::reconnect::{ServiceConnector, ServiceStatus}; +use crate::infra::reconnect::ServiceConnector; use crate::infra::ws::{ NextOrClose, TextOrBinary, WebSocketClient, WebSocketClientConnector, WebSocketClientReader, WebSocketClientWriter, WebSocketConnectError, WebSocketServiceError, @@ -67,7 +68,7 @@ pub enum ServerEvent { request_proto: RequestProto, response_sender: ResponseSender, }, - Stopped, + Stopped(ChatServiceError), } impl ServerEvent { @@ -173,10 +174,7 @@ impl ServiceConnector for ChatOverWebSocketServiceConnect .await } - fn start_service( - &self, - channel: Self::Channel, - ) -> (Self::Service, ServiceStatus) { + fn start_service(&self, channel: Self::Channel) -> (Self::Service, CancellationToken) { let (ws_client, service_status) = self.ws_client_connector.start_service(channel); let WebSocketClient { ws_client_writer, @@ -194,7 +192,7 @@ impl ServiceConnector for ChatOverWebSocketServiceConnect ( ChatOverWebSocket { ws_client_writer, - service_status: service_status.clone(), + service_cancellation: service_status.clone(), pending_messages, connection_info, }, @@ -208,7 +206,7 @@ async fn reader_task( ws_client_writer: WebSocketClientWriter, incoming_tx: Arc>>>, pending_messages: Arc>, - service_status: ServiceStatus, + service_cancellation: CancellationToken, ) { const LONG_REQUEST_PROCESSING_THRESHOLD: Duration = Duration::from_millis(500); @@ -219,23 +217,22 @@ async fn reader_task( let mut previous_request_paths_for_logging = VecDeque::with_capacity(incoming_tx.max_capacity()); - let mut has_ever_sent_server_event = false; - loop { + let error = loop { let data = match ws_client_reader.next().await { Ok(NextOrClose::Next(TextOrBinary::Binary(data))) => data, Ok(NextOrClose::Next(TextOrBinary::Text(_))) => { log::info!("Text frame received on chat websocket"); - service_status.stop_service_with_error(ChatServiceError::UnexpectedFrameReceived); - break; + service_cancellation.cancel(); + break ChatServiceError::UnexpectedFrameReceived; } Ok(NextOrClose::Close(_)) => { - service_status.stop_service_with_error(WebSocketServiceError::ChannelClosed.into()); - break; + service_cancellation.cancel(); + break WebSocketServiceError::ChannelClosed.into(); } Err(e) => { - service_status.stop_service_with_error(e); - break; + service_cancellation.cancel(); + break e; } }; @@ -246,8 +243,8 @@ async fn reader_task( let server_request = match ServerEvent::new(req, ws_client_writer.clone()) { Ok(server_request) => server_request, Err(e) => { - service_status.stop_service_with_error(e); - break; + service_cancellation.cancel(); + break e; } }; @@ -256,9 +253,8 @@ async fn reader_task( let request_send_elapsed = request_send_start.elapsed(); if delivery_result.is_err() { - service_status.stop_service_with_error( - ChatServiceError::FailedToPassMessageToIncomingChannel, - ); + service_cancellation.cancel(); + break ChatServiceError::FailedToPassMessageToIncomingChannel; } if previous_request_paths_for_logging.len() == incoming_tx.max_capacity() { @@ -282,7 +278,6 @@ async fn reader_task( } } previous_request_paths_for_logging.push_back(request_path); - has_ever_sent_server_event = true; } Ok(ChatMessage::Response(id, res)) => { let map = &mut pending_messages.lock().await; @@ -293,22 +288,16 @@ async fn reader_task( } } Err(e) => { - service_status.stop_service_with_error(e); + service_cancellation.cancel(); + break e; } } - } + }; - if has_ever_sent_server_event { - // We only need a Stopped event to separate the events from different connections. If there - // haven't been any events from this connection, we don't bother. This also avoids sending - // events for the unauth socket, even though it still has to read responses. - // (But don't worry about delivery failure here; if no one's listening, the event is - // superfluous anyway.) - _ = incoming_tx.send(ServerEvent::Stopped).await; - } + _ = incoming_tx.send(ServerEvent::Stopped(error)).await; // before terminating the task, marking channel as inactive - service_status.stop_service(); + service_cancellation.cancel(); // Clear the pending messages map. These requests don't wait on the service status just in case // a response comes in late; dropping the response senders is how we cancel them. @@ -319,7 +308,7 @@ async fn reader_task( #[derive(Debug)] pub struct ChatOverWebSocket { ws_client_writer: WebSocketClientWriter, - service_status: ServiceStatus, + service_cancellation: CancellationToken, pending_messages: Arc>, connection_info: ConnectionInfo, } @@ -337,7 +326,7 @@ where { async fn send(&self, msg: Request, timeout: Duration) -> Result { // checking if channel has been closed - if self.service_status.is_stopped() { + if self.service_cancellation.is_cancelled() { return Err(WebSocketServiceError::ChannelClosed.into()); } @@ -374,7 +363,7 @@ where } async fn disconnect(&self) { - self.service_status.stop_service() + self.service_cancellation.cancel() } } @@ -488,12 +477,12 @@ mod test { let ws_config = test_ws_config(); let time_to_wait = ws_config.max_idle_time * 2; let (ws_chat, _) = create_ws_chat_service(ws_config, ws_server).await; - assert!(!ws_chat.service_status().unwrap().is_stopped()); + assert!(!ws_chat.service_status().unwrap().is_cancelled()); // sleeping for a period of time long enough to stop the service // in case of missing PONG responses tokio::time::sleep(time_to_wait).await; - assert!(!ws_chat.service_status().unwrap().is_stopped()); + assert!(!ws_chat.service_status().unwrap().is_cancelled()); } #[tokio::test(flavor = "current_thread", start_paused = true)] @@ -507,12 +496,12 @@ mod test { }); let (ws_chat, _) = create_ws_chat_service(ws_config, ws_server).await; - assert!(!ws_chat.service_status().unwrap().is_stopped()); + assert!(!ws_chat.service_status().unwrap().is_cancelled()); // sleeping for a period of time long enough for the service to stop, // which is what should happen since the PONG messages are not sent back tokio::time::sleep(duration).await; - assert!(ws_chat.service_status().unwrap().is_stopped()); + assert!(ws_chat.service_status().unwrap().is_cancelled()); } #[tokio::test(flavor = "current_thread", start_paused = true)] @@ -533,12 +522,12 @@ mod test { }); let (ws_chat, _) = create_ws_chat_service(ws_config, ws_server).await; - assert!(!ws_chat.service_status().unwrap().is_stopped()); + assert!(!ws_chat.service_status().unwrap().is_cancelled()); // sleeping for a period of time long enough to stop the service // in case of missing PONG responses tokio::time::sleep(time_to_wait).await; - assert!(ws_chat.service_status().unwrap().is_stopped()); + assert!(ws_chat.service_status().unwrap().is_cancelled()); // making sure server logic completed in the expected way validate_server_stopped_successfully(server_res_rx).await; } @@ -562,12 +551,12 @@ mod test { let ws_config = test_ws_config(); let (ws_chat, _) = create_ws_chat_service(ws_config, ws_server).await; - assert!(!ws_chat.service_status().unwrap().is_stopped()); + assert!(!ws_chat.service_status().unwrap().is_cancelled()); // sleeping for a period of time long enough to stop the service // in case of missing PONG responses tokio::time::sleep(time_to_wait).await; - assert!(ws_chat.service_status().unwrap().is_stopped()); + assert!(ws_chat.service_status().unwrap().is_cancelled()); // making sure server logic completed in the expected way validate_server_stopped_successfully(server_res_rx).await; } @@ -595,13 +584,26 @@ mod test { }); let ws_config = test_ws_config(); - let (ws_chat, _) = create_ws_chat_service(ws_config, ws_server).await; + let (ws_chat, mut incoming_rx) = create_ws_chat_service(ws_config, ws_server).await; let response = ws_chat .send(test_request(Method::GET, "/"), TIMEOUT_DURATION) .await; response.expect("response"); validate_server_running(server_res_rx).await; + + // now we're disconnecting manually in which case we expect a `Stopped` event + // with `ChatServiceError::WebSocket(WebSocketServiceError::ChannelClosed)` error + let service_status = ws_chat.service_status().expect("some status"); + service_status.cancel(); + service_status.cancelled().await; + let event = incoming_rx.recv().await; + assert_matches!( + event, + Some(ServerEvent::Stopped(ChatServiceError::WebSocket( + WebSocketServiceError::ChannelClosed + ))) + ); } #[tokio::test(flavor = "current_thread", start_paused = true)] @@ -737,12 +739,12 @@ mod test { assert_matches!( incoming_rx.recv().await.expect("server request"), - ServerEvent::Stopped + ServerEvent::Stopped(_) ); } #[tokio::test(flavor = "current_thread", start_paused = true)] - async fn ws_service_skips_stop_event_without_requests() { + async fn ws_service_receives_stopped_event_if_server_disconnects() { // creating a server that accepts a request, responds, then closes the connection. let (ws_server, server_res_rx) = ws_warp_filter(move |websocket| async move { let (mut tx, mut rx) = websocket.split(); @@ -774,7 +776,9 @@ mod test { assert_matches!( incoming_rx.try_recv(), - Err(mpsc::error::TryRecvError::Disconnected) + Ok(ServerEvent::Stopped(ChatServiceError::WebSocket( + WebSocketServiceError::Protocol(_) + ))) ); } diff --git a/rust/net/src/infra.rs b/rust/net/src/infra.rs index 88b33cbc..af703191 100644 --- a/rust/net/src/infra.rs +++ b/rust/net/src/infra.rs @@ -340,13 +340,12 @@ pub(crate) mod test { use derive_where::derive_where; use displaydoc::Display; use tokio::io::DuplexStream; + use tokio_util::sync::CancellationToken; use warp::{Filter, Reply}; use crate::infra::connection_manager::{ConnectionManager, ErrorClass, ErrorClassifier}; use crate::infra::errors::{LogSafeDisplay, TransportConnectError}; - use crate::infra::reconnect::{ - ServiceConnector, ServiceInitializer, ServiceState, ServiceStatus, - }; + use crate::infra::reconnect::{ServiceConnector, ServiceInitializer, ServiceState}; use crate::infra::{ Alpn, ConnectionInfo, ConnectionParams, DnsSource, RouteType, StreamAndInfo, TransportConnector, @@ -462,7 +461,7 @@ pub(crate) mod test { #[derive_where(Clone)] pub(crate) struct NoReconnectService { - pub(crate) inner: Arc>, + pub(crate) inner: Arc>, } impl NoReconnectService @@ -484,9 +483,9 @@ pub(crate) mod test { } } - pub(crate) fn service_status(&self) -> Option<&ServiceStatus> { + pub(crate) fn service_status(&self) -> Option<&CancellationToken> { match &*self.inner { - ServiceState::Active(_, status) => Some(status), + ServiceState::Active(_, service_cancellation) => Some(service_cancellation), _ => None, } } diff --git a/rust/net/src/infra/reconnect.rs b/rust/net/src/infra/reconnect.rs index a7d95345..91bda4e2 100644 --- a/rust/net/src/infra/reconnect.rs +++ b/rust/net/src/infra/reconnect.rs @@ -5,13 +5,11 @@ use std::fmt::Debug; use std::sync::atomic::{AtomicU32, Ordering}; -use std::sync::{Arc, OnceLock}; +use std::sync::Arc; use std::time::Duration; use async_trait::async_trait; -use derive_where::derive_where; use displaydoc::Display; -use static_assertions::const_assert; use tokio::sync::Mutex; use tokio::time::{timeout_at, Instant}; use tokio_util::sync::CancellationToken; @@ -22,7 +20,6 @@ use crate::infra::connection_manager::{ }; use crate::infra::errors::LogSafeDisplay; use crate::infra::{ConnectionInfo, ConnectionParams, HttpRequestDecorator}; -use crate::timeouts::CONNECTION_ROUTE_COOLDOWN_INTERVALS; // A duration where, if this is all that's left on the timeout, we're more likely to fail than not. // Useful for debouncing repeated connection attempts. @@ -31,13 +28,13 @@ const MINIMUM_CONNECTION_TIME: Duration = Duration::from_millis(500); /// For a service that needs to go through some initialization procedure /// before it's ready for use, this enum describes its possible states. #[derive(Debug)] -pub(crate) enum ServiceState { +pub(crate) enum ServiceState { /// Service was not explicitly activated. Inactive, /// Contains an instance of the service which is initialized and ready to use. /// Also, since we're not actively listening for the event of service going inactive, - /// the `ServiceStatus` could be used to see if the service is actually running. - Active(T, ServiceStatus), + /// the `CancellationToken` could be used to see if the service is actually running. + Active(T, CancellationToken), /// The service is inactive and no initialization attempts are to be made /// until the `Instant` held by this object. Cooldown(Instant), @@ -62,10 +59,7 @@ pub(crate) trait ServiceConnector: Clone { connection_params: &ConnectionParams, ) -> Result; - fn start_service( - &self, - channel: Self::Channel, - ) -> (Self::Service, ServiceStatus); + fn start_service(&self, channel: Self::Channel) -> (Self::Service, CancellationToken); } #[async_trait] @@ -85,10 +79,7 @@ where (*self).connect_channel(connection_params).await } - fn start_service( - &self, - channel: Self::Channel, - ) -> (Self::Service, ServiceStatus) { + fn start_service(&self, channel: Self::Channel) -> (Self::Service, CancellationToken) { (*self).start_service(channel) } } @@ -125,72 +116,11 @@ where self.inner.connect_channel(&decorated).await } - fn start_service( - &self, - channel: Self::Channel, - ) -> (Self::Service, ServiceStatus) { + fn start_service(&self, channel: Self::Channel) -> (Self::Service, CancellationToken) { self.inner.start_service(channel) } } -#[derive(Debug)] -#[derive_where(Clone)] -pub(crate) struct ServiceStatus { - maybe_error: Arc>>, - service_cancellation: CancellationToken, -} - -impl Default for ServiceStatus { - fn default() -> Self { - Self { - maybe_error: Arc::new(OnceLock::new()), - service_cancellation: CancellationToken::new(), - } - } -} - -impl ServiceStatus { - pub(crate) fn stop_service(&self) { - self.maybe_error.get_or_init(|| None); - self.service_cancellation.cancel(); - } - - pub(crate) fn stop_service_with_error(&self, error: E) { - self.maybe_error.get_or_init(|| Some(error)); - self.service_cancellation.cancel(); - } - - pub(crate) fn is_stopped(&self) -> bool { - self.service_cancellation.is_cancelled() - } - - pub(crate) async fn stopped(&self) { - self.service_cancellation.cancelled().await - } - - /// Returns an error if `stop_service_with_error` was called previously. - /// - /// Note that returning `None` could mean the service is still running, or that the service was - /// stopped deliberately without an error, or that the service stopped because of an error but - /// that error was handled elsewhere. - pub(crate) fn get_error(&self) -> Option<&E> { - match self.maybe_error.get() { - None => { - // service not stopped - None - } - Some(None) => { - // service stopped without error - None - } - Some(Some(e)) => { - // service stopped with error - Some(e) - } - } - } -} - pub(crate) struct ServiceInitializer { service_connector: C, connection_manager: M, @@ -211,7 +141,7 @@ where } } - pub(crate) async fn connect(&self) -> ServiceState { + pub(crate) async fn connect(&self) -> ServiceState { log::debug!("attempting a connection"); let connection_attempt_result = self .connection_manager @@ -252,7 +182,7 @@ where pub(crate) struct ServiceWithReconnectData { reconnect_count: AtomicU32, - state: Mutex>, + state: Mutex>, service_initializer: ServiceInitializer, connection_timeout: Duration, } @@ -270,8 +200,6 @@ pub(crate) enum ReconnectError { AllRoutesFailed { attempts: u16 }, /// Rejected by server: {0} RejectedByServer(E), - /// Service is in the inactive state - Inactive, } impl ErrorClassifier for ReconnectError { @@ -280,7 +208,7 @@ impl ErrorClassifier for ReconnectError { ReconnectError::Timeout { .. } | ReconnectError::AllRoutesFailed { .. } => { ErrorClass::Intermittent } - ReconnectError::RejectedByServer(_) | ReconnectError::Inactive => ErrorClass::Fatal, + ReconnectError::RejectedByServer(_) => ErrorClass::Fatal, } } } @@ -300,7 +228,7 @@ where async fn map_service(&self, mapper: fn(&C::Service) -> T) -> Result { let guard = self.data.state.lock().await; match &*guard { - ServiceState::Active(service, status) if !status.is_stopped() => Ok(mapper(service)), + ServiceState::Active(service, status) if !status.is_cancelled() => Ok(mapper(service)), ServiceState::Inactive => Err(StateError::Inactive), ServiceState::Cooldown(_) | ServiceState::ConnectionTimedOut @@ -348,20 +276,11 @@ where self.data.reconnect_count.load(Ordering::Relaxed) } - pub(crate) async fn reconnect_if_active(&self) -> Result<(), ReconnectError> { - self.connect(false).await + pub(crate) async fn connect(&self) -> Result<(), ReconnectError> { + self.do_connect().await } - pub(crate) async fn connect_from_inactive( - &self, - ) -> Result<(), ReconnectError> { - self.connect(true).await - } - - async fn connect( - &self, - is_explicit_connect: bool, - ) -> Result<(), ReconnectError> { + async fn do_connect(&self) -> Result<(), ReconnectError> { let mut attempts: u16 = 0; let start_of_connection_process = Instant::now(); let deadline = start_of_connection_process + self.data.connection_timeout; @@ -379,13 +298,10 @@ where loop { match &*guard { ServiceState::Inactive => { - if !is_explicit_connect { - return Err(ReconnectError::Inactive); - } - // otherwise, proceeding to connect + // proceeding to connect } ServiceState::Active(_, service_status) => { - if !service_status.is_stopped() { + if !service_status.is_cancelled() { // if the state is `Active` and service has not been stopped, // clone the service and return it log::debug!("reusing active service instance"); @@ -437,13 +353,6 @@ where continue; } ErrorClass::Fatal => { - if !is_explicit_connect { - // Only explicit connection requests have a place to report this - // error, so for now, non-explicit attempts treat this as a generic - // failure. - return Err(ReconnectError::AllRoutesFailed { attempts }); - } - let state = std::mem::replace(&mut *guard, ServiceState::Inactive); let ServiceState::Error(e) = state else { unreachable!("we checked this above, matching on &*guard"); @@ -464,21 +373,16 @@ where } attempts += 1; - *guard = match timeout_at(deadline, self.data.service_initializer.connect()).await { - Ok(ServiceState::Active(service, service_state)) => { - self.schedule_reconnect(service_state.clone()); - ServiceState::Active(service, service_state) - } - Ok(result) => result, - Err(_) => ServiceState::ConnectionTimedOut, - } + *guard = timeout_at(deadline, self.data.service_initializer.connect()) + .await + .unwrap_or(ServiceState::ConnectionTimedOut); } } pub(crate) async fn disconnect(&self) { let mut guard = self.data.state.lock().await; if let ServiceState::Active(_, service_status) = &*guard { - service_status.stop_service(); + service_status.cancel(); } *guard = ServiceState::Inactive; log::info!("service disconnected"); @@ -487,67 +391,6 @@ where pub(crate) async fn service(&self) -> Result { self.map_service(|service| service.clone()).await } - - fn schedule_reconnect(&self, service_status: ServiceStatus) { - let service_with_reconnect = self.clone(); - tokio::spawn(async move { - let _ = service_status.service_cancellation.cancelled().await; - if let Some(error) = service_status.get_error() { - log::debug!("Service stopped due to an error: {:?}", error); - log::info!("Service stopped due to an error: {}", error); - } else { - log::info!("Service stopped"); - } - // This is a background thread so there is no overall timeout on reconnect. - // Each attempt is limited by the `data.connection_timeout` duration - // but unless we're in one of the non-proceeding states, we'll be trying to - // connect. - let mut sleep_until = Instant::now(); - loop { - if sleep_until > Instant::now() { - tokio::time::sleep_until(sleep_until).await; - } - log::debug!("attempting reconnect"); - match service_with_reconnect.reconnect_if_active().await { - Ok(_) => { - log::info!("reconnect attempt succeeded"); - service_with_reconnect - .data - .reconnect_count - .fetch_add(1, Ordering::Relaxed); - return; - } - Err(ReconnectError::Inactive) => { - return; - } - Err(error) => { - log::warn!("reconnect attempt failed: {}", error); - let guard = service_with_reconnect.data.state.lock().await; - match &*guard { - ServiceState::Cooldown(next_attempt_time) => { - sleep_until = *next_attempt_time; - } - ServiceState::ConnectionTimedOut | ServiceState::Error(_) => { - // Keep trying, but throttle a little in case of early exits - // TODO: In practice, this should only happen when there's a Fatal - // error, in which case retrying won't help and we should really - // shut down the reconnect loop. But part of doing that would be - // making sure the error is reported. Even when that's implemented, - // though, this is a good backstop against bugs. - const_assert!(CONNECTION_ROUTE_COOLDOWN_INTERVALS[0].is_zero()); - sleep_until += CONNECTION_ROUTE_COOLDOWN_INTERVALS[1]; - } - ServiceState::Inactive | ServiceState::Active(_, _) => { - // most likely, `disconnect()` was called and we - // switched to the `ServiceState::Inactive` state - return; - } - } - } - } - } - }); - } } #[cfg(test)] @@ -557,7 +400,7 @@ mod test { use std::sync::{Arc, Mutex}; use std::time::Duration; - use crate::timeouts::{CONNECTION_ROUTE_COOLDOWN_INTERVALS, CONNECTION_ROUTE_MAX_COOLDOWN}; + use crate::timeouts::CONNECTION_ROUTE_MAX_COOLDOWN; use assert_matches::assert_matches; use async_trait::async_trait; use futures_util::FutureExt; @@ -576,19 +419,7 @@ mod test { use crate::utils::{sleep_and_catch_up, ObservableEvent}; #[derive(Clone, Debug)] - struct TestService { - service_status: ServiceStatus, - } - - impl TestService { - fn new(service_status: ServiceStatus) -> Self { - Self { service_status } - } - - fn close_channel(&self) { - self.service_status.stop_service(); - } - } + struct TestService; #[derive(Clone)] struct TestServiceConnector { @@ -642,13 +473,10 @@ mod test { } } - fn start_service( - &self, - _channel: Self::Channel, - ) -> (Self::Service, ServiceStatus) { - let service_status_arc = ServiceStatus::default(); - let service = TestService::new(service_status_arc.clone()); - (service, service_status_arc) + fn start_service(&self, _channel: Self::Channel) -> (Self::Service, CancellationToken) { + let service_cancellation = CancellationToken::new(); + let service = TestService; + (service, service_cancellation) } } @@ -678,45 +506,11 @@ mod test { #[tokio::test] async fn service_started_with_request() { let (connector, service_with_reconnect) = connector_and_service(); - service_with_reconnect - .connect_from_inactive() - .await - .expect(""); + service_with_reconnect.connect().await.expect(""); let _service = service_with_reconnect.service().await; assert_eq!(connector.attempts_made(), 1); } - #[tokio::test(start_paused = true)] - async fn service_tries_to_reconnect_if_connection_lost() { - let (connector, service_with_reconnect) = connector_and_service(); - service_with_reconnect - .connect_from_inactive() - .await - .expect("connected"); - - let service = service_with_reconnect.service().await.expect("available"); - - // `close_channel()` call emulates lost connection and reconnection will be triggered - // unless service_with_reconnect is in the `Inactive` state - service.close_channel(); - - // giving time to reconnect - sleep_and_catch_up(NORMAL_CONNECTION_TIME).await; - - let service = service_with_reconnect.service().await.expect("available"); - - // we're doing it again, but this time we'll instruct service connector to fail, - // and as a result, service won't be available - service.close_channel(); - connector.set_connection_error(Some(ClassifiableTestError(ErrorClass::Intermittent))); - time::advance(TIME_ADVANCE_VALUE).await; - - assert_matches!( - service_with_reconnect.service().await, - Err(StateError::ServiceUnavailable) - ); - } - #[tokio::test] async fn service_is_inactive_before_connected() { let (_, service_with_reconnect) = connector_and_service(); @@ -729,10 +523,7 @@ mod test { #[tokio::test(start_paused = true)] async fn service_doesnt_reconnect_if_disconnected() { let (_, service_with_reconnect) = connector_and_service(); - service_with_reconnect - .connect_from_inactive() - .await - .expect("connected"); + service_with_reconnect.connect().await.expect("connected"); // making sure service is available let _ = service_with_reconnect.service().await.expect("available"); @@ -754,7 +545,7 @@ mod test { let (connector, service_with_reconnect) = connector_and_service(); connector.set_connection_error(Some(ClassifiableTestError(ErrorClass::Intermittent))); - let connection_result = service_with_reconnect.connect_from_inactive().await; + let connection_result = service_with_reconnect.connect().await; // Here we have 3 attempts made by the reconnect service: // - first attempt went to the connector and resulted in expected error @@ -792,10 +583,10 @@ mod test { let (connector, service_with_reconnect) = connector_and_service(); let aaa1 = service_with_reconnect.clone(); - let handle1 = tokio::spawn(async move { aaa1.connect_from_inactive().await }); + let handle1 = tokio::spawn(async move { aaa1.connect().await }); let aaa2 = service_with_reconnect.clone(); - let handle2 = tokio::spawn(async move { aaa2.connect_from_inactive().await }); + let handle2 = tokio::spawn(async move { aaa2.connect().await }); let (s1, s2) = tokio::join!(handle1, handle2); assert!(s1.expect("future completed successfully").is_ok()); @@ -819,7 +610,7 @@ mod test { ); let service_with_reconnect = ServiceWithReconnect::new(connector.clone(), manager, service_with_reconnect_timeout); - let res = service_with_reconnect.connect_from_inactive().await; + let res = service_with_reconnect.connect().await; // now the time should've auto-advanced from `start` by the `connection_timeout` value assert!(res.is_err()); @@ -843,7 +634,7 @@ mod test { ); let service_with_reconnect = ServiceWithReconnect::new(connector.clone(), manager, service_with_reconnect_timeout); - let res = service_with_reconnect.connect_from_inactive().await; + let res = service_with_reconnect.connect().await; // now the time should've auto-advanced from `start` by the `connection_timeout` value assert_matches!(res, Err(ReconnectError::Timeout { attempts: 1 })); @@ -856,7 +647,7 @@ mod test { time::advance(TIME_ADVANCE_VALUE).await; connector.set_connection_error(Some(ClassifiableTestError(ErrorClass::Intermittent))); - let connection_result = service_with_reconnect.connect_from_inactive().await; + let connection_result = service_with_reconnect.connect().await; // number of attempts is the same as in the `immediately_fail_if_in_cooldown()` test assert_matches!( @@ -869,7 +660,7 @@ mod test { time::advance(CONNECTION_ROUTE_MAX_COOLDOWN).await; connector.set_connection_error(None); - let connection_result = service_with_reconnect.connect_from_inactive().await; + let connection_result = service_with_reconnect.connect().await; assert_matches!(connection_result, Ok(_)); } @@ -878,7 +669,7 @@ mod test { let (connector, service_with_reconnect) = connector_and_service(); time::advance(TIME_ADVANCE_VALUE).await; connector.set_time_to_connect(LONG_CONNECTION_TIME); - let connection_result = service_with_reconnect.connect_from_inactive().await; + let connection_result = service_with_reconnect.connect().await; assert_matches!( connection_result, Err(ReconnectError::Timeout { attempts: 1 }) @@ -889,125 +680,17 @@ mod test { time::advance(CONNECTION_ROUTE_MAX_COOLDOWN).await; connector.set_time_to_connect(NORMAL_CONNECTION_TIME); - let connection_result = service_with_reconnect.connect_from_inactive().await; + let connection_result = service_with_reconnect.connect().await; assert_matches!(connection_result, Ok(_)); } - #[tokio::test(start_paused = true)] - async fn service_keep_reconnecting_attempts_if_first_fails() { - let (connector, service_with_reconnect) = connector_and_service(); - service_with_reconnect - .connect_from_inactive() - .await - .expect("connected"); - let service = service_with_reconnect.service().await.expect("service"); - - // at this point, one successfull connection attempt - assert_eq!(connector.attempts.load(Ordering::Relaxed), 1); - - // internet connection lost - connector.set_connection_error(Some(ClassifiableTestError(ErrorClass::Intermittent))); - service.close_channel(); - - sleep_and_catch_up(NORMAL_CONNECTION_TIME).await; - assert_eq!(connector.attempts.load(Ordering::Relaxed), 2); - - sleep_and_catch_up(CONNECTION_ROUTE_COOLDOWN_INTERVALS[0] + NORMAL_CONNECTION_TIME).await; - assert_eq!(connector.attempts.load(Ordering::Relaxed), 3); - assert_matches!(service_with_reconnect.service().await, Err(_)); - - sleep_and_catch_up(CONNECTION_ROUTE_COOLDOWN_INTERVALS[1] + NORMAL_CONNECTION_TIME).await; - assert_eq!(connector.attempts.load(Ordering::Relaxed), 4); - assert_matches!(service_with_reconnect.service().await, Err(_)); - - // now internet connection is back - // letting next cooldown interval pass and checking again - connector.set_connection_error(None); - - sleep_and_catch_up(CONNECTION_ROUTE_COOLDOWN_INTERVALS[2] + NORMAL_CONNECTION_TIME).await; - assert_eq!(connector.attempts.load(Ordering::Relaxed), 5); - assert_matches!(service_with_reconnect.service().await, Ok(_)); - } - - #[tokio::test(start_paused = true)] - async fn service_stops_reconnect_attempts_if_disconnected_after_some_time() { - let (connector, service_with_reconnect) = connector_and_service(); - service_with_reconnect - .connect_from_inactive() - .await - .expect("connected"); - let service = service_with_reconnect.service().await.expect("service"); - - // at this point, one successfull connection attempt - assert_eq!(connector.attempts.load(Ordering::Relaxed), 1); - - // internet connection lost - connector.set_connection_error(Some(ClassifiableTestError(ErrorClass::Intermittent))); - service.close_channel(); - - sleep_and_catch_up(NORMAL_CONNECTION_TIME).await; - assert_eq!(connector.attempts.load(Ordering::Relaxed), 2); - - sleep_and_catch_up(CONNECTION_ROUTE_COOLDOWN_INTERVALS[0] + NORMAL_CONNECTION_TIME).await; - assert_eq!(connector.attempts.load(Ordering::Relaxed), 3); - assert_matches!(service_with_reconnect.service().await, Err(_)); - - sleep_and_catch_up(CONNECTION_ROUTE_COOLDOWN_INTERVALS[1] + NORMAL_CONNECTION_TIME).await; - assert_eq!(connector.attempts.load(Ordering::Relaxed), 4); - assert_matches!(service_with_reconnect.service().await, Err(_)); - - // now we decide to disconnect, and we need to make sure we're not making - // any more attempts - service_with_reconnect.disconnect().await; - for interval in CONNECTION_ROUTE_COOLDOWN_INTERVALS.into_iter().skip(2) { - sleep_and_catch_up(interval + NORMAL_CONNECTION_TIME).await; - assert_eq!(connector.attempts.load(Ordering::Relaxed), 4); - assert_matches!( - service_with_reconnect.service().await, - Err(StateError::Inactive) - ); - } - } - - #[tokio::test(start_paused = true)] - async fn reconnect_count_behaves_correctly() { - let (_, service_with_reconnect) = connector_and_service(); - service_with_reconnect - .connect_from_inactive() - .await - .expect("connected"); - let service = service_with_reconnect.service().await.expect("service"); - - // manual connection should not count as "reconnect" - assert_eq!(0, service_with_reconnect.reconnect_count()); - - // emulating unexpected disconnect - service.close_channel(); - // giving time to reconnect - sleep_and_catch_up(NORMAL_CONNECTION_TIME).await; - - // reconnect count should increase by 1 - assert_eq!(1, service_with_reconnect.reconnect_count()); - - // now, manually disconnecting and connecting again - service_with_reconnect.disconnect().await; - service_with_reconnect - .connect_from_inactive() - .await - .expect("connected"); - - // reconnect count should not change - assert_eq!(1, service_with_reconnect.reconnect_count()); - } - #[tokio::test(start_paused = true)] async fn service_times_out_early_on_guard_contention() { let (connector, service_with_reconnect) = connector_and_service(); let guard = service_with_reconnect.data.state.lock().await; let service_for_task = service_with_reconnect.clone(); - let connection_task = - tokio::spawn(async move { service_for_task.connect_from_inactive().await }); + let connection_task = tokio::spawn(async move { service_for_task.connect().await }); sleep_and_catch_up(TIMEOUT_DURATION - MINIMUM_CONNECTION_TIME).await; drop(guard); @@ -1020,106 +703,6 @@ mod test { assert_eq!(connector.attempts_made(), 0); } - #[tokio::test(start_paused = true)] - async fn intermittent_errors_do_not_get_propagated() { - let (connector, service_with_reconnect) = connector_and_service(); - connector.set_connection_error(Some(ClassifiableTestError(ErrorClass::Intermittent))); - - // A "reconnect" won't even make it to the connection error yet. - let inactive_error = service_with_reconnect - .reconnect_if_active() - .await - .expect_err("not active yet"); - assert_matches!(inactive_error, ReconnectError::Inactive); - - // A proper connect will. - let fatal_error = service_with_reconnect - .connect_from_inactive() - .await - .expect_err("should have returned the connection error"); - assert_matches!(fatal_error, ReconnectError::AllRoutesFailed { .. }); - assert_matches!( - *service_with_reconnect.data.state.lock().await, - ServiceState::Cooldown(_) - ); - - // Okay, let's connect properly... - time::advance(CONNECTION_ROUTE_MAX_COOLDOWN).await; - connector.set_connection_error(None); - service_with_reconnect - .connect_from_inactive() - .await - .expect("success"); - let service = service_with_reconnect.service().await.expect("service"); - - // ...then disconnect... - connector.set_connection_error(Some(ClassifiableTestError(ErrorClass::Intermittent))); - service.close_channel(); - - // ...then try to auto-reconnect. - let reconnect_error = service_with_reconnect - .reconnect_if_active() - .await - .expect_err("not active yet"); - assert_matches!(reconnect_error, ReconnectError::AllRoutesFailed { .. }); - assert_matches!( - *service_with_reconnect.data.state.lock().await, - ServiceState::Cooldown(_) - ); - } - - #[tokio::test(start_paused = true)] - async fn fatal_error_gets_propagated_on_explicit_connect_only() { - let (connector, service_with_reconnect) = connector_and_service(); - connector.set_connection_error(Some(ClassifiableTestError(ErrorClass::Fatal))); - - // A "reconnect" won't even make it to the connection error yet. - let inactive_error = service_with_reconnect - .reconnect_if_active() - .await - .expect_err("not active yet"); - assert_matches!(inactive_error, ReconnectError::Inactive); - - // A proper connect will. - let fatal_error = service_with_reconnect - .connect_from_inactive() - .await - .expect_err("should have returned the connection error"); - assert_matches!( - fatal_error, - ReconnectError::RejectedByServer(ClassifiableTestError(ErrorClass::Fatal)) - ); - - // And it will leave us inactive. - assert_matches!( - *service_with_reconnect.data.state.lock().await, - ServiceState::Inactive - ); - - // Okay, let's connect properly... - connector.set_connection_error(None); - service_with_reconnect - .connect_from_inactive() - .await - .expect("success"); - let service = service_with_reconnect.service().await.expect("service"); - - // ...then disconnect... - connector.set_connection_error(Some(ClassifiableTestError(ErrorClass::Fatal))); - service.close_channel(); - - // ...then try to auto-reconnect. - let reconnect_error = service_with_reconnect - .reconnect_if_active() - .await - .expect_err("not active yet"); - assert_matches!(reconnect_error, ReconnectError::AllRoutesFailed { .. }); - assert_matches!( - *service_with_reconnect.data.state.lock().await, - ServiceState::Error(ClassifiableTestError(ErrorClass::Fatal)) - ); - } - fn connector_and_service() -> ( TestServiceConnector, ServiceWithReconnect, diff --git a/rust/net/src/infra/ws.rs b/rust/net/src/infra/ws.rs index 64be0253..e61f227c 100644 --- a/rust/net/src/infra/ws.rs +++ b/rust/net/src/infra/ws.rs @@ -19,12 +19,13 @@ use http::uri::PathAndQuery; use tokio::sync::Mutex; use tokio::time::Instant; use tokio_tungstenite::WebSocketStream; +use tokio_util::sync::CancellationToken; use tungstenite::handshake::client::generate_key; use tungstenite::protocol::CloseFrame; use tungstenite::{http, Message}; use crate::infra::errors::LogSafeDisplay; -use crate::infra::reconnect::{ServiceConnector, ServiceStatus}; +use crate::infra::reconnect::ServiceConnector; use crate::infra::ws::error::{HttpFormatError, ProtocolError, SpaceError}; use crate::infra::{ Alpn, AsyncDuplexStream, ConnectionInfo, ConnectionParams, StreamAndInfo, TransportConnector, @@ -145,10 +146,7 @@ where .map_err(Into::into) } - fn start_service( - &self, - channel: Self::Channel, - ) -> (Self::Service, ServiceStatus) { + fn start_service(&self, channel: Self::Channel) -> (Self::Service, CancellationToken) { start_ws_service( channel.0, channel.1, @@ -163,19 +161,20 @@ fn start_ws_service( connection_info: ConnectionInfo, keep_alive_interval: Duration, max_idle_time: Duration, -) -> (WebSocketClient, ServiceStatus) { - let service_status = ServiceStatus::default(); +) -> (WebSocketClient, CancellationToken) { + let service_cancellation = CancellationToken::new(); let (ws_sink, ws_stream) = channel.split(); let ws_client_writer = WebSocketClientWriter { ws_sink: Arc::new(Mutex::new(ws_sink)), - service_status: service_status.clone(), + service_cancellation: service_cancellation.clone(), + error_type: Default::default(), }; let ws_client_reader = WebSocketClientReader { ws_stream, keep_alive_interval, max_idle_time, ws_writer: ws_client_writer.clone(), - service_status: service_status.clone(), + service_cancellation: service_cancellation.clone(), last_frame_received: Instant::now(), last_keepalive_sent: Instant::now(), }; @@ -185,7 +184,7 @@ fn start_ws_service( ws_client_reader, connection_info, }, - service_status, + service_cancellation, ) } @@ -193,7 +192,8 @@ fn start_ws_service( #[derive(Debug)] pub(crate) struct WebSocketClientWriter { ws_sink: Arc, Message>>>, - service_status: ServiceStatus, + service_cancellation: CancellationToken, + error_type: PhantomData, } impl WebSocketClientWriter @@ -201,7 +201,7 @@ where WebSocketServiceError: Into, { pub async fn send(&self, message: impl Into) -> Result<(), E> { - run_and_update_status(&self.service_status, || { + run_and_update_status(&self.service_cancellation, || { async { let mut guard = self.ws_sink.lock().await; guard.send(message.into()).await?; @@ -218,7 +218,7 @@ where pub(crate) struct WebSocketClientReader { ws_stream: SplitStream>, ws_writer: WebSocketClientWriter, - service_status: ServiceStatus, + service_cancellation: CancellationToken, keep_alive_interval: Duration, max_idle_time: Duration, last_frame_received: Instant, @@ -236,7 +236,7 @@ where IdleTimeout, StopService, } - run_and_update_status(&self.service_status, || async { + run_and_update_status(&self.service_cancellation, || async { loop { // first, waiting for the next lifecycle action let next_ping_time = self.last_keepalive_sent + self.keep_alive_interval; @@ -245,7 +245,7 @@ where maybe_message = self.ws_stream.next() => Event::Message(maybe_message), _ = tokio::time::sleep_until(next_ping_time) => Event::SendKeepAlive, _ = tokio::time::sleep_until(idle_timeout_time) => Event::IdleTimeout, - _ = self.service_status.stopped() => Event::StopService, + _ = self.service_cancellation.cancelled() => Event::StopService, } { Event::SendKeepAlive => { self.ws_writer.send(Message::Ping(vec![])).await?; @@ -281,7 +281,7 @@ where Message::Binary(b) => return Ok(NextOrClose::Next(b.into())), Message::Ping(_) | Message::Pong(_) => continue, Message::Close(close_frame) => { - self.service_status.stop_service(); + self.service_cancellation.cancel(); return Ok(NextOrClose::Close(close_frame)); } Message::Frame(_) => unreachable!("only for sending"), @@ -292,18 +292,21 @@ where } } -async fn run_and_update_status(service_status: &ServiceStatus, f: F) -> Result +async fn run_and_update_status( + service_status: &CancellationToken, + f: F, +) -> Result where WebSocketServiceError: Into, F: FnOnce() -> Ft, Ft: Future>, { - if service_status.is_stopped() { + if service_status.is_cancelled() { return Err(WebSocketServiceError::ChannelClosed.into()); } let result = f().await; if result.is_err() { - service_status.stop_service(); + service_status.cancel(); } result.map_err(Into::into) } diff --git a/swift/Sources/LibSignalClient/ChatService.swift b/swift/Sources/LibSignalClient/ChatService.swift index bfad0d84..e6730988 100644 --- a/swift/Sources/LibSignalClient/ChatService.swift +++ b/swift/Sources/LibSignalClient/ChatService.swift @@ -81,9 +81,9 @@ public class AuthenticatedChatService: NativeHandleOwner, ChatService { withNativeHandle { chatService in if let listener { var listenerStruct = ChatListenerBridge(chatService: self, chatListener: listener).makeListenerStruct() - failOnError(signal_chat_server_set_listener(tokioAsyncContext, chatService, &listenerStruct)) + failOnError(signal_chat_service_set_listener_auth(tokioAsyncContext, chatService, &listenerStruct)) } else { - failOnError(signal_chat_server_set_listener(tokioAsyncContext, chatService, nil)) + failOnError(signal_chat_service_set_listener_auth(tokioAsyncContext, chatService, nil)) } } } diff --git a/swift/Sources/SignalFfi/signal_ffi.h b/swift/Sources/SignalFfi/signal_ffi.h index 948e78e9..4bc2077d 100644 --- a/swift/Sources/SignalFfi/signal_ffi.h +++ b/swift/Sources/SignalFfi/signal_ffi.h @@ -1583,7 +1583,7 @@ SignalFfiError *signal_chat_service_auth_send(SignalCPromiseFfiChatResponse *pro SignalFfiError *signal_chat_service_auth_send_and_debug(SignalCPromiseFfiResponseAndDebugInfo *promise, const SignalTokioAsyncContext *async_runtime, const SignalChat *chat, const SignalHttpRequest *http_request, uint32_t timeout_millis); -SignalFfiError *signal_chat_server_set_listener(const SignalTokioAsyncContext *runtime, const SignalChat *chat, const SignalFfiMakeChatListenerStruct *make_listener); +SignalFfiError *signal_chat_service_set_listener_auth(const SignalTokioAsyncContext *runtime, const SignalChat *chat, const SignalFfiMakeChatListenerStruct *make_listener); SignalFfiError *signal_testing_chat_service_inject_raw_server_request(const SignalChat *chat, SignalBorrowedBuffer bytes);