0
0
mirror of https://github.com/signalapp/libsignal.git synced 2024-09-19 19:42:19 +02:00

net: dropping auto-reconnect logic

This commit is contained in:
Sergey Skrobotov 2024-08-07 16:38:45 -07:00 committed by GitHub
parent aa3f6532b2
commit 55ac7166e0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
18 changed files with 267 additions and 582 deletions

3
node/Native.d.ts vendored
View File

@ -180,7 +180,8 @@ export function Cds2ClientState_New(mrenclave: Buffer, attestationMsg: Buffer, c
export function CdsiLookup_complete(asyncRuntime: Wrapper<TokioAsyncContext>, lookup: Wrapper<CdsiLookup>): Promise<LookupResponse>;
export function CdsiLookup_new(asyncRuntime: Wrapper<TokioAsyncContext>, connectionManager: Wrapper<ConnectionManager>, username: string, password: string, request: Wrapper<LookupRequest>): Promise<CdsiLookup>;
export function CdsiLookup_token(lookup: Wrapper<CdsiLookup>): Buffer;
export function ChatServer_SetListener(runtime: Wrapper<TokioAsyncContext>, chat: Wrapper<Chat>, makeListener: MakeChatListener | null): void;
export function ChatService_SetListenerAuth(runtime: Wrapper<TokioAsyncContext>, chat: Wrapper<Chat>, makeListener: MakeChatListener | null): void;
export function ChatService_SetListenerUnauth(runtime: Wrapper<TokioAsyncContext>, chat: Wrapper<Chat>, makeListener: MakeChatListener | null): void;
export function ChatService_auth_send(asyncRuntime: Wrapper<TokioAsyncContext>, chat: Wrapper<Chat>, httpRequest: Wrapper<HttpRequest>, timeoutMillis: number): Promise<ChatResponse>;
export function ChatService_auth_send_and_debug(asyncRuntime: Wrapper<TokioAsyncContext>, chat: Wrapper<Chat>, httpRequest: Wrapper<HttpRequest>, timeoutMillis: number): Promise<ResponseAndDebugInfo>;
export function ChatService_connect_auth(asyncRuntime: Wrapper<TokioAsyncContext>, chat: Wrapper<Chat>): Promise<ChatServiceDebugInfo>;

View File

@ -122,7 +122,17 @@ export class ChatServerMessageAck {
}
}
export interface ChatServiceListener {
export interface ConnectionEventsListener {
  /**
   * Called when the client gets disconnected from the server.
   *
   * This includes both deliberate disconnects as well as unexpected socket
   * closures. (This commit removes the auto-reconnect logic, so an interrupted
   * connection is not retried automatically; callers decide how to reconnect.)
   */
  onConnectionInterrupted(): void;
}
export interface ChatServiceListener extends ConnectionEventsListener {
/**
* Called when the server delivers an incoming message to the client.
*
@ -144,17 +154,6 @@ export interface ChatServiceListener {
* were in the queue *when the connection was established* have been delivered.
*/
onQueueEmpty(): void;
/**
* Called when the client gets disconnected from the server.
*
* This includes both deliberate disconnects as well as unexpected socket closures that will be
* automatically retried.
*
* Will not be called if no other requests have been invoked for this connection attempt. That is,
* you should never see this as the first callback, nor two of these callbacks in a row.
*/
onConnectionInterrupted(): void;
}
/**
@ -256,7 +255,7 @@ export class AuthenticatedChatService implements ChatService {
listener.onConnectionInterrupted();
},
};
Native.ChatServer_SetListener(
Native.ChatService_SetListenerAuth(
asyncContext,
this.chatService,
nativeChatListener
@ -315,11 +314,32 @@ export class UnauthenticatedChatService implements ChatService {
constructor(
private readonly asyncContext: TokioAsyncContext,
connectionManager: ConnectionManager
connectionManager: ConnectionManager,
listener: ConnectionEventsListener
) {
this.chatService = newNativeHandle(
Native.ChatService_new(connectionManager, '', '', false)
);
const nativeChatListener = {
_incoming_message(
_envelope: Buffer,
_timestamp: number,
_ack: ServerMessageAck
): void {
throw new Error('Event not supported on unauthenticated connection');
},
_queue_empty(): void {
throw new Error('Event not supported on unauthenticated connection');
},
_connection_interrupted(): void {
listener.onConnectionInterrupted();
},
};
Native.ChatService_SetListenerUnauth(
asyncContext,
this.chatService,
nativeChatListener
);
}
disconnect(): Promise<void> {
@ -426,10 +446,13 @@ export class Net {
/**
* Creates a new instance of {@link UnauthenticatedChatService}.
*/
public newUnauthenticatedChatService(): UnauthenticatedChatService {
public newUnauthenticatedChatService(
listener: ConnectionEventsListener
): UnauthenticatedChatService {
return new UnauthenticatedChatService(
this.asyncContext,
this.connectionManager
this.connectionManager,
listener
);
}

View File

@ -176,9 +176,13 @@ describe('chat service api', () => {
it('can connect unauthenticated', async () => {
const net = new Net(Environment.Staging, userAgent);
const chatService = net.newUnauthenticatedChatService();
const listener = {
onConnectionInterrupted: sinon.stub(),
};
const chatService = net.newUnauthenticatedChatService(listener);
await chatService.connect();
await chatService.disconnect();
expect(listener.onConnectionInterrupted).to.have.been.calledOnce;
}).timeout(10000);
it('can connect through a proxy server', async () => {
@ -190,9 +194,13 @@ describe('chat service api', () => {
const [host = PROXY_SERVER, port = '443'] = PROXY_SERVER.split(':', 2);
net.setProxy(host, parseInt(port, 10));
const chatService = net.newUnauthenticatedChatService();
const listener = {
onConnectionInterrupted: sinon.stub(),
};
const chatService = net.newUnauthenticatedChatService(listener);
await chatService.connect();
await chatService.disconnect();
expect(listener.onConnectionInterrupted).to.have.been.calledOnce;
}).timeout(10000);
});

View File

@ -16,6 +16,7 @@ use libsignal_net::auth::Auth;
use libsignal_net::chat::{
self, ChatServiceError, DebugInfo as ChatServiceDebugInfo, Request, Response as ChatResponse,
};
use libsignal_net::infra::ws::WebSocketServiceError;
use crate::support::*;
use crate::*;
@ -172,19 +173,35 @@ async fn ChatService_auth_send_and_debug(
}
#[bridge_fn(jni = false)]
fn ChatServer_SetListener(
fn ChatService_SetListenerAuth(
runtime: &TokioAsyncContext,
chat: &Chat,
make_listener: Option<&dyn MakeChatListener>,
) {
let Some(maker) = make_listener else {
chat.clear_listener();
chat.clear_listener_auth();
return;
};
let listener = maker.make_listener();
chat.set_listener(listener, runtime)
chat.set_listener_auth(listener, runtime)
}
/// Installs the listener produced by `make_listener` as the handler for server
/// events on the *unauthenticated* chat connection, or detaches the current
/// listener when `make_listener` is `None`.
#[bridge_fn(jni = false, ffi = false)]
fn ChatService_SetListenerUnauth(
    runtime: &TokioAsyncContext,
    chat: &Chat,
    make_listener: Option<&dyn MakeChatListener>,
) {
    // `None` means "clear": stop delivering events to any previously set listener.
    let Some(maker) = make_listener else {
        chat.clear_listener_unauth();
        return;
    };
    // NOTE(review): `runtime` is handed to `set_listener_unauth`, which
    // presumably spawns the event-delivery task on it — confirm in `Chat`.
    let listener = maker.make_listener();
    chat.set_listener_unauth(listener, runtime)
}
#[cfg(feature = "testing-fns")]
@ -201,7 +218,9 @@ fn TESTING_ChatService_InjectRawServerRequest(chat: &Chat, bytes: &[u8]) {
#[bridge_fn]
fn TESTING_ChatService_InjectConnectionInterrupted(chat: &Chat) {
chat.synthetic_request_tx
.blocking_send(chat::ws::ServerEvent::Stopped)
.blocking_send(chat::ws::ServerEvent::Stopped(ChatServiceError::WebSocket(
WebSocketServiceError::ChannelClosed,
)))
.expect("not closed");
}

View File

@ -7,6 +7,7 @@ use super::*;
use crate::net::chat::{ChatListener, MakeChatListener, ServerMessageAck};
use libsignal_net::chat::ChatServiceError;
use std::ffi::{c_uchar, c_void};
type ReceivedIncomingMessage = extern "C" fn(
@ -76,7 +77,8 @@ impl ChatListener for ChatListenerStruct {
(self.0.received_queue_empty)(self.0.ctx)
}
fn connection_interrupted(&mut self) {
// TODO: pass `_disconnect_cause` to `connection_interrupted`
fn connection_interrupted(&mut self, _disconnect_cause: ChatServiceError) {
(self.0.connection_interrupted)(self.0.ctx)
}
}

View File

@ -15,7 +15,9 @@ use http::status::InvalidStatusCode;
use http::uri::{InvalidUri, PathAndQuery};
use http::{HeaderMap, HeaderName, HeaderValue};
use libsignal_net::auth::Auth;
use libsignal_net::chat::{self, DebugInfo as ChatServiceDebugInfo, Response as ChatResponse};
use libsignal_net::chat::{
self, ChatServiceError, DebugInfo as ChatServiceDebugInfo, Response as ChatResponse,
};
use libsignal_protocol::Timestamp;
use tokio::sync::{mpsc, oneshot};
@ -57,7 +59,8 @@ pub struct Chat {
Arc<dyn chat::ChatServiceWithDebugInfo + Send + Sync>,
Arc<dyn chat::ChatServiceWithDebugInfo + Send + Sync>,
>,
listener: std::sync::Mutex<ChatListenerState>,
listener_auth: std::sync::Mutex<ChatListenerState>,
listener_unauth: std::sync::Mutex<ChatListenerState>,
pub synthetic_request_tx:
mpsc::Sender<chat::ws::ServerEvent<libsignal_net::infra::tcp_ssl::TcpSslConnectorStream>>,
}
@ -66,9 +69,14 @@ impl RefUnwindSafe for Chat {}
impl Chat {
pub fn new(connection_manager: &ConnectionManager, auth: Auth, receive_stories: bool) -> Self {
let (incoming_tx, incoming_rx) = mpsc::channel(1);
let incoming_stream = chat::server_requests::stream_incoming_messages(incoming_rx);
let synthetic_request_tx = incoming_tx.clone();
let (incoming_auth_tx, incoming_auth_rx) = mpsc::channel(1);
let incoming_stream_auth =
chat::server_requests::stream_incoming_messages(incoming_auth_rx);
let synthetic_request_tx = incoming_auth_tx.clone();
let (incoming_unauth_tx, incoming_unauth_rx) = mpsc::channel(1);
let incoming_stream_unauth =
chat::server_requests::stream_incoming_messages(incoming_unauth_rx);
Chat {
service: chat::chat_service(
@ -78,22 +86,44 @@ impl Chat {
.lock()
.expect("not poisoned")
.clone(),
incoming_tx,
incoming_auth_tx,
incoming_unauth_tx,
auth,
receive_stories,
)
.into_dyn(),
listener: std::sync::Mutex::new(ChatListenerState::Inactive(Box::pin(incoming_stream))),
listener_auth: std::sync::Mutex::new(ChatListenerState::Inactive(Box::pin(
incoming_stream_auth,
))),
listener_unauth: std::sync::Mutex::new(ChatListenerState::Inactive(Box::pin(
incoming_stream_unauth,
))),
synthetic_request_tx,
}
}
pub fn set_listener(&self, listener: Box<dyn ChatListener>, runtime: &TokioAsyncContext) {
pub fn set_listener_auth(&self, listener: Box<dyn ChatListener>, runtime: &TokioAsyncContext) {
Chat::set_listener(listener, &self.listener_auth, runtime);
}
pub fn set_listener_unauth(
&self,
listener: Box<dyn ChatListener>,
runtime: &TokioAsyncContext,
) {
Chat::set_listener(listener, &self.listener_unauth, runtime);
}
fn set_listener(
listener: Box<dyn ChatListener>,
listener_state: &std::sync::Mutex<ChatListenerState>,
runtime: &TokioAsyncContext,
) {
use futures_util::future::Either;
let (cancel_tx, cancel_rx) = oneshot::channel::<()>();
let mut guard = self.listener.lock().expect("unpoisoned");
let mut guard = listener_state.lock().expect("unpoisoned");
let request_stream_future =
match std::mem::replace(&mut *guard, ChatListenerState::CurrentlyBeingMutated) {
ChatListenerState::Inactive(request_stream) => {
@ -125,8 +155,12 @@ impl Chat {
drop(guard);
}
pub fn clear_listener(&self) {
self.listener.lock().expect("unpoisoned").cancel();
pub fn clear_listener_auth(&self) {
self.listener_auth.lock().expect("unpoisoned").cancel();
}
pub fn clear_listener_unauth(&self) {
self.listener_unauth.lock().expect("unpoisoned").cancel();
}
}
@ -204,7 +238,7 @@ pub trait ChatListener: Send {
ack: ServerMessageAck,
);
fn received_queue_empty(&mut self);
fn connection_interrupted(&mut self);
fn connection_interrupted(&mut self, disconnect_cause: ChatServiceError);
}
impl dyn ChatListener {
@ -223,7 +257,9 @@ impl dyn ChatListener {
ServerMessageAck::new(send_ack),
),
chat::server_requests::ServerMessage::QueueEmpty => self.received_queue_empty(),
chat::server_requests::ServerMessage::Stopped => self.connection_interrupted(),
chat::server_requests::ServerMessage::Stopped(error) => {
self.connection_interrupted(error)
}
}
}

View File

@ -5,6 +5,7 @@
use crate::net::chat::{ChatListener, MakeChatListener, ServerMessageAck};
use crate::node::ResultTypeInfo;
use libsignal_net::chat::ChatServiceError;
use libsignal_protocol::Timestamp;
use neon::context::FunctionContext;
use neon::event::Channel;
@ -53,7 +54,8 @@ impl ChatListener for NodeChatListener {
});
}
fn connection_interrupted(&mut self) {
// TODO: pass `_disconnect_cause` to `_connection_interrupted`
fn connection_interrupted(&mut self, _disconnect_cause: ChatServiceError) {
let callback_object_shared = self.callback_object.clone();
self.js_channel.send(move |mut cx| {
let callback = callback_object_shared.to_inner(&mut cx);

View File

@ -115,11 +115,13 @@ async fn test_connection(
&network_change_event,
);
let (incoming_tx, _incoming_rx) = mpsc::channel(1);
let (incoming_auth_tx, _incoming_rx) = mpsc::channel(1);
let (incoming_unauth_tx, _incoming_rx) = mpsc::channel(1);
let chat = chat_service(
&connection,
transport_connector,
incoming_tx,
incoming_auth_tx,
incoming_unauth_tx,
Auth {
username: "".to_owned(),
password: "".to_owned(),

View File

@ -447,18 +447,19 @@ fn build_anonymous_chat_service(
pub fn chat_service<T: TransportConnector + 'static>(
endpoint: &EndpointConnection<MultiRouteConnectionManager>,
transport_connector: T,
incoming_tx: tokio::sync::mpsc::Sender<ServerEvent<T::Stream>>,
incoming_auth_tx: tokio::sync::mpsc::Sender<ServerEvent<T::Stream>>,
incoming_unauth_tx: tokio::sync::mpsc::Sender<ServerEvent<T::Stream>>,
auth: Auth,
receive_stories: bool,
) -> Chat<impl ChatServiceWithDebugInfo, impl ChatServiceWithDebugInfo> {
// Cannot reuse the same connector for both services, since each connector locks on its incoming event sender internally.
let unauth_ws_connector = ChatOverWebSocketServiceConnector::new(
WebSocketClientConnector::new(transport_connector.clone(), endpoint.config.clone()),
incoming_tx.clone(),
incoming_unauth_tx,
);
let auth_ws_connector = ChatOverWebSocketServiceConnector::new(
WebSocketClientConnector::new(transport_connector, endpoint.config.clone()),
incoming_tx,
incoming_auth_tx,
);
{
let auth_service = build_authorized_chat_service(
@ -513,7 +514,7 @@ pub(crate) mod test {
timeout: Duration,
) -> Result<Response, ChatServiceError> {
match &*self.inner {
ServiceState::Active(service, status) if !status.is_stopped() => {
ServiceState::Active(service, status) if !status.is_cancelled() => {
service.clone().send(msg, timeout).await
}
_ => Err(ChatServiceError::AllConnectionRoutesFailed { attempts: 1 }),
@ -526,7 +527,7 @@ pub(crate) mod test {
async fn disconnect(&self) {
if let ServiceState::Active(_, status) = &*self.inner {
status.stop_service()
status.cancel()
}
}
}

View File

@ -33,7 +33,7 @@ where
}
async fn connect(&self) -> Result<(), ChatServiceError> {
Ok(self.connect_from_inactive().await?)
Ok(self.connect().await?)
}
async fn disconnect(&self) {
@ -103,7 +103,7 @@ where
async fn connect_and_debug(&self) -> Result<DebugInfo, ChatServiceError> {
let start = Instant::now();
self.connect_from_inactive().await?;
self.connect().await?;
let connection_info = self.connection_info().await?;
let ip_type = IpType::from_host(&connection_info.address);

View File

@ -100,7 +100,6 @@ impl<E: LogSafeDisplay + Into<ChatServiceError>> From<reconnect::ReconnectError<
Self::AllConnectionRoutesFailed { attempts }
}
reconnect::ReconnectError::RejectedByServer(e) => e.into(),
reconnect::ReconnectError::Inactive => Self::ServiceInactive,
}
}
}

View File

@ -28,7 +28,7 @@ pub enum ServerMessage {
send_ack: ResponseEnvelopeSender,
},
/// Not actually a message, but an event processed as part of the message stream.
Stopped,
Stopped(ChatServiceError),
}
impl std::fmt::Debug for ServerMessage {
@ -46,7 +46,10 @@ impl std::fmt::Debug for ServerMessage {
.field("envelope", &format_args!("{} bytes", envelope.len()))
.field("server_delivery_timestamp", server_delivery_timestamp)
.finish(),
Self::Stopped => write!(f, "Stopped"),
Self::Stopped(error) => f
.debug_struct("ConnectionInterrupted")
.field("reason", error)
.finish(),
}
}
}
@ -55,7 +58,7 @@ pub fn stream_incoming_messages(
receiver: mpsc::Receiver<ServerEvent<impl AsyncDuplexStream + 'static>>,
) -> impl Stream<Item = ServerMessage> {
ReceiverStream::new(receiver).filter_map(|request| match request {
ServerEvent::Stopped => Some(ServerMessage::Stopped),
ServerEvent::Stopped(error) => Some(ServerMessage::Stopped(error)),
ServerEvent::Request {
request_proto,
response_sender,

View File

@ -15,12 +15,13 @@ use prost::Message;
use tokio::sync::{mpsc, oneshot, Mutex};
use tokio::time::Instant;
use tokio_tungstenite::WebSocketStream;
use tokio_util::sync::CancellationToken;
use crate::chat::{
ChatMessageType, ChatService, ChatServiceError, MessageProto, RemoteAddressInfo, Request,
RequestProto, Response, ResponseProto,
};
use crate::infra::reconnect::{ServiceConnector, ServiceStatus};
use crate::infra::reconnect::ServiceConnector;
use crate::infra::ws::{
NextOrClose, TextOrBinary, WebSocketClient, WebSocketClientConnector, WebSocketClientReader,
WebSocketClientWriter, WebSocketConnectError, WebSocketServiceError,
@ -67,7 +68,7 @@ pub enum ServerEvent<S> {
request_proto: RequestProto,
response_sender: ResponseSender<S>,
},
Stopped,
Stopped(ChatServiceError),
}
impl<S: AsyncDuplexStream> ServerEvent<S> {
@ -173,10 +174,7 @@ impl<T: TransportConnector> ServiceConnector for ChatOverWebSocketServiceConnect
.await
}
fn start_service(
&self,
channel: Self::Channel,
) -> (Self::Service, ServiceStatus<Self::StartError>) {
fn start_service(&self, channel: Self::Channel) -> (Self::Service, CancellationToken) {
let (ws_client, service_status) = self.ws_client_connector.start_service(channel);
let WebSocketClient {
ws_client_writer,
@ -194,7 +192,7 @@ impl<T: TransportConnector> ServiceConnector for ChatOverWebSocketServiceConnect
(
ChatOverWebSocket {
ws_client_writer,
service_status: service_status.clone(),
service_cancellation: service_status.clone(),
pending_messages,
connection_info,
},
@ -208,7 +206,7 @@ async fn reader_task<S: AsyncDuplexStream + 'static>(
ws_client_writer: WebSocketClientWriter<S, ChatServiceError>,
incoming_tx: Arc<Mutex<mpsc::Sender<ServerEvent<S>>>>,
pending_messages: Arc<Mutex<PendingMessagesMap>>,
service_status: ServiceStatus<ChatServiceError>,
service_cancellation: CancellationToken,
) {
const LONG_REQUEST_PROCESSING_THRESHOLD: Duration = Duration::from_millis(500);
@ -219,23 +217,22 @@ async fn reader_task<S: AsyncDuplexStream + 'static>(
let mut previous_request_paths_for_logging =
VecDeque::with_capacity(incoming_tx.max_capacity());
let mut has_ever_sent_server_event = false;
loop {
let error = loop {
let data = match ws_client_reader.next().await {
Ok(NextOrClose::Next(TextOrBinary::Binary(data))) => data,
Ok(NextOrClose::Next(TextOrBinary::Text(_))) => {
log::info!("Text frame received on chat websocket");
service_status.stop_service_with_error(ChatServiceError::UnexpectedFrameReceived);
break;
service_cancellation.cancel();
break ChatServiceError::UnexpectedFrameReceived;
}
Ok(NextOrClose::Close(_)) => {
service_status.stop_service_with_error(WebSocketServiceError::ChannelClosed.into());
break;
service_cancellation.cancel();
break WebSocketServiceError::ChannelClosed.into();
}
Err(e) => {
service_status.stop_service_with_error(e);
break;
service_cancellation.cancel();
break e;
}
};
@ -246,8 +243,8 @@ async fn reader_task<S: AsyncDuplexStream + 'static>(
let server_request = match ServerEvent::new(req, ws_client_writer.clone()) {
Ok(server_request) => server_request,
Err(e) => {
service_status.stop_service_with_error(e);
break;
service_cancellation.cancel();
break e;
}
};
@ -256,9 +253,8 @@ async fn reader_task<S: AsyncDuplexStream + 'static>(
let request_send_elapsed = request_send_start.elapsed();
if delivery_result.is_err() {
service_status.stop_service_with_error(
ChatServiceError::FailedToPassMessageToIncomingChannel,
);
service_cancellation.cancel();
break ChatServiceError::FailedToPassMessageToIncomingChannel;
}
if previous_request_paths_for_logging.len() == incoming_tx.max_capacity() {
@ -282,7 +278,6 @@ async fn reader_task<S: AsyncDuplexStream + 'static>(
}
}
previous_request_paths_for_logging.push_back(request_path);
has_ever_sent_server_event = true;
}
Ok(ChatMessage::Response(id, res)) => {
let map = &mut pending_messages.lock().await;
@ -293,22 +288,16 @@ async fn reader_task<S: AsyncDuplexStream + 'static>(
}
}
Err(e) => {
service_status.stop_service_with_error(e);
service_cancellation.cancel();
break e;
}
}
}
};
if has_ever_sent_server_event {
// We only need a Stopped event to separate the events from different connections. If there
// haven't been any events from this connection, we don't bother. This also avoids sending
// events for the unauth socket, even though it still has to read responses.
// (But don't worry about delivery failure here; if no one's listening, the event is
// superfluous anyway.)
_ = incoming_tx.send(ServerEvent::Stopped).await;
}
_ = incoming_tx.send(ServerEvent::Stopped(error)).await;
// Before terminating the task, mark the channel as inactive.
service_status.stop_service();
service_cancellation.cancel();
// Clear the pending messages map. These requests don't wait on the cancellation token just in case
// a response comes in late; dropping the response senders is how we cancel them.
@ -319,7 +308,7 @@ async fn reader_task<S: AsyncDuplexStream + 'static>(
#[derive(Debug)]
pub struct ChatOverWebSocket<S> {
ws_client_writer: WebSocketClientWriter<S, ChatServiceError>,
service_status: ServiceStatus<ChatServiceError>,
service_cancellation: CancellationToken,
pending_messages: Arc<Mutex<PendingMessagesMap>>,
connection_info: ConnectionInfo,
}
@ -337,7 +326,7 @@ where
{
async fn send(&self, msg: Request, timeout: Duration) -> Result<Response, ChatServiceError> {
// checking if channel has been closed
if self.service_status.is_stopped() {
if self.service_cancellation.is_cancelled() {
return Err(WebSocketServiceError::ChannelClosed.into());
}
@ -374,7 +363,7 @@ where
}
async fn disconnect(&self) {
self.service_status.stop_service()
self.service_cancellation.cancel()
}
}
@ -488,12 +477,12 @@ mod test {
let ws_config = test_ws_config();
let time_to_wait = ws_config.max_idle_time * 2;
let (ws_chat, _) = create_ws_chat_service(ws_config, ws_server).await;
assert!(!ws_chat.service_status().unwrap().is_stopped());
assert!(!ws_chat.service_status().unwrap().is_cancelled());
// sleeping for a period of time long enough to stop the service
// in case of missing PONG responses
tokio::time::sleep(time_to_wait).await;
assert!(!ws_chat.service_status().unwrap().is_stopped());
assert!(!ws_chat.service_status().unwrap().is_cancelled());
}
#[tokio::test(flavor = "current_thread", start_paused = true)]
@ -507,12 +496,12 @@ mod test {
});
let (ws_chat, _) = create_ws_chat_service(ws_config, ws_server).await;
assert!(!ws_chat.service_status().unwrap().is_stopped());
assert!(!ws_chat.service_status().unwrap().is_cancelled());
// sleeping for a period of time long enough for the service to stop,
// which is what should happen since the PONG messages are not sent back
tokio::time::sleep(duration).await;
assert!(ws_chat.service_status().unwrap().is_stopped());
assert!(ws_chat.service_status().unwrap().is_cancelled());
}
#[tokio::test(flavor = "current_thread", start_paused = true)]
@ -533,12 +522,12 @@ mod test {
});
let (ws_chat, _) = create_ws_chat_service(ws_config, ws_server).await;
assert!(!ws_chat.service_status().unwrap().is_stopped());
assert!(!ws_chat.service_status().unwrap().is_cancelled());
// sleeping for a period of time long enough to stop the service
// in case of missing PONG responses
tokio::time::sleep(time_to_wait).await;
assert!(ws_chat.service_status().unwrap().is_stopped());
assert!(ws_chat.service_status().unwrap().is_cancelled());
// making sure server logic completed in the expected way
validate_server_stopped_successfully(server_res_rx).await;
}
@ -562,12 +551,12 @@ mod test {
let ws_config = test_ws_config();
let (ws_chat, _) = create_ws_chat_service(ws_config, ws_server).await;
assert!(!ws_chat.service_status().unwrap().is_stopped());
assert!(!ws_chat.service_status().unwrap().is_cancelled());
// sleeping for a period of time long enough to stop the service
// in case of missing PONG responses
tokio::time::sleep(time_to_wait).await;
assert!(ws_chat.service_status().unwrap().is_stopped());
assert!(ws_chat.service_status().unwrap().is_cancelled());
// making sure server logic completed in the expected way
validate_server_stopped_successfully(server_res_rx).await;
}
@ -595,13 +584,26 @@ mod test {
});
let ws_config = test_ws_config();
let (ws_chat, _) = create_ws_chat_service(ws_config, ws_server).await;
let (ws_chat, mut incoming_rx) = create_ws_chat_service(ws_config, ws_server).await;
let response = ws_chat
.send(test_request(Method::GET, "/"), TIMEOUT_DURATION)
.await;
response.expect("response");
validate_server_running(server_res_rx).await;
// now we're disconnecting manually in which case we expect a `Stopped` event
// with `ChatServiceError::WebSocket(WebSocketServiceError::ChannelClosed)` error
let service_status = ws_chat.service_status().expect("some status");
service_status.cancel();
service_status.cancelled().await;
let event = incoming_rx.recv().await;
assert_matches!(
event,
Some(ServerEvent::Stopped(ChatServiceError::WebSocket(
WebSocketServiceError::ChannelClosed
)))
);
}
#[tokio::test(flavor = "current_thread", start_paused = true)]
@ -737,12 +739,12 @@ mod test {
assert_matches!(
incoming_rx.recv().await.expect("server request"),
ServerEvent::Stopped
ServerEvent::Stopped(_)
);
}
#[tokio::test(flavor = "current_thread", start_paused = true)]
async fn ws_service_skips_stop_event_without_requests() {
async fn ws_service_receives_stopped_event_if_server_disconnects() {
// creating a server that accepts a request, responds, then closes the connection.
let (ws_server, server_res_rx) = ws_warp_filter(move |websocket| async move {
let (mut tx, mut rx) = websocket.split();
@ -774,7 +776,9 @@ mod test {
assert_matches!(
incoming_rx.try_recv(),
Err(mpsc::error::TryRecvError::Disconnected)
Ok(ServerEvent::Stopped(ChatServiceError::WebSocket(
WebSocketServiceError::Protocol(_)
)))
);
}

View File

@ -340,13 +340,12 @@ pub(crate) mod test {
use derive_where::derive_where;
use displaydoc::Display;
use tokio::io::DuplexStream;
use tokio_util::sync::CancellationToken;
use warp::{Filter, Reply};
use crate::infra::connection_manager::{ConnectionManager, ErrorClass, ErrorClassifier};
use crate::infra::errors::{LogSafeDisplay, TransportConnectError};
use crate::infra::reconnect::{
ServiceConnector, ServiceInitializer, ServiceState, ServiceStatus,
};
use crate::infra::reconnect::{ServiceConnector, ServiceInitializer, ServiceState};
use crate::infra::{
Alpn, ConnectionInfo, ConnectionParams, DnsSource, RouteType, StreamAndInfo,
TransportConnector,
@ -462,7 +461,7 @@ pub(crate) mod test {
#[derive_where(Clone)]
pub(crate) struct NoReconnectService<C: ServiceConnector> {
pub(crate) inner: Arc<ServiceState<C::Service, C::ConnectError, C::StartError>>,
pub(crate) inner: Arc<ServiceState<C::Service, C::ConnectError>>,
}
impl<C> NoReconnectService<C>
@ -484,9 +483,9 @@ pub(crate) mod test {
}
}
pub(crate) fn service_status(&self) -> Option<&ServiceStatus<C::StartError>> {
pub(crate) fn service_status(&self) -> Option<&CancellationToken> {
match &*self.inner {
ServiceState::Active(_, status) => Some(status),
ServiceState::Active(_, service_cancellation) => Some(service_cancellation),
_ => None,
}
}

View File

@ -5,13 +5,11 @@
use std::fmt::Debug;
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::{Arc, OnceLock};
use std::sync::Arc;
use std::time::Duration;
use async_trait::async_trait;
use derive_where::derive_where;
use displaydoc::Display;
use static_assertions::const_assert;
use tokio::sync::Mutex;
use tokio::time::{timeout_at, Instant};
use tokio_util::sync::CancellationToken;
@ -22,7 +20,6 @@ use crate::infra::connection_manager::{
};
use crate::infra::errors::LogSafeDisplay;
use crate::infra::{ConnectionInfo, ConnectionParams, HttpRequestDecorator};
use crate::timeouts::CONNECTION_ROUTE_COOLDOWN_INTERVALS;
// A duration where, if this is all that's left on the timeout, we're more likely to fail than not.
// Useful for debouncing repeated connection attempts.
@ -31,13 +28,13 @@ const MINIMUM_CONNECTION_TIME: Duration = Duration::from_millis(500);
/// For a service that needs to go through some initialization procedure
/// before it's ready for use, this enum describes its possible states.
#[derive(Debug)]
pub(crate) enum ServiceState<T, CE, SE> {
pub(crate) enum ServiceState<T, CE> {
/// Service was not explicitly activated.
Inactive,
/// Contains an instance of the service which is initialized and ready to use.
/// Also, since we're not actively listening for the event of service going inactive,
/// the `ServiceStatus` could be used to see if the service is actually running.
Active(T, ServiceStatus<SE>),
/// the `CancellationToken` could be used to see if the service is actually running.
Active(T, CancellationToken),
/// The service is inactive and no initialization attempts are to be made
/// until the `Instant` held by this object.
Cooldown(Instant),
@ -62,10 +59,7 @@ pub(crate) trait ServiceConnector: Clone {
connection_params: &ConnectionParams,
) -> Result<Self::Channel, Self::ConnectError>;
fn start_service(
&self,
channel: Self::Channel,
) -> (Self::Service, ServiceStatus<Self::StartError>);
fn start_service(&self, channel: Self::Channel) -> (Self::Service, CancellationToken);
}
#[async_trait]
@ -85,10 +79,7 @@ where
(*self).connect_channel(connection_params).await
}
fn start_service(
&self,
channel: Self::Channel,
) -> (Self::Service, ServiceStatus<Self::StartError>) {
fn start_service(&self, channel: Self::Channel) -> (Self::Service, CancellationToken) {
(*self).start_service(channel)
}
}
@ -125,72 +116,11 @@ where
self.inner.connect_channel(&decorated).await
}
fn start_service(
&self,
channel: Self::Channel,
) -> (Self::Service, ServiceStatus<Self::StartError>) {
fn start_service(&self, channel: Self::Channel) -> (Self::Service, CancellationToken) {
self.inner.start_service(channel)
}
}
#[derive(Debug)]
#[derive_where(Clone)]
pub(crate) struct ServiceStatus<E> {
maybe_error: Arc<OnceLock<Option<E>>>,
service_cancellation: CancellationToken,
}
impl<E> Default for ServiceStatus<E> {
fn default() -> Self {
Self {
maybe_error: Arc::new(OnceLock::new()),
service_cancellation: CancellationToken::new(),
}
}
}
impl<E> ServiceStatus<E> {
pub(crate) fn stop_service(&self) {
self.maybe_error.get_or_init(|| None);
self.service_cancellation.cancel();
}
pub(crate) fn stop_service_with_error(&self, error: E) {
self.maybe_error.get_or_init(|| Some(error));
self.service_cancellation.cancel();
}
pub(crate) fn is_stopped(&self) -> bool {
self.service_cancellation.is_cancelled()
}
pub(crate) async fn stopped(&self) {
self.service_cancellation.cancelled().await
}
/// Returns an error if `stop_service_with_error` was called previously.
///
/// Note that returning `None` could mean the service is still running, or that the service was
/// stopped deliberately without an error, or that the service stopped because of an error but
/// that error was handled elsewhere.
pub(crate) fn get_error(&self) -> Option<&E> {
match self.maybe_error.get() {
None => {
// service not stopped
None
}
Some(None) => {
// service stopped without error
None
}
Some(Some(e)) => {
// service stopped with error
Some(e)
}
}
}
}
pub(crate) struct ServiceInitializer<C, M> {
service_connector: C,
connection_manager: M,
@ -211,7 +141,7 @@ where
}
}
pub(crate) async fn connect(&self) -> ServiceState<C::Service, C::ConnectError, C::StartError> {
pub(crate) async fn connect(&self) -> ServiceState<C::Service, C::ConnectError> {
log::debug!("attempting a connection");
let connection_attempt_result = self
.connection_manager
@ -252,7 +182,7 @@ where
pub(crate) struct ServiceWithReconnectData<C: ServiceConnector, M> {
reconnect_count: AtomicU32,
state: Mutex<ServiceState<C::Service, C::ConnectError, C::StartError>>,
state: Mutex<ServiceState<C::Service, C::ConnectError>>,
service_initializer: ServiceInitializer<C, M>,
connection_timeout: Duration,
}
@ -270,8 +200,6 @@ pub(crate) enum ReconnectError<E: LogSafeDisplay> {
AllRoutesFailed { attempts: u16 },
/// Rejected by server: {0}
RejectedByServer(E),
/// Service is in the inactive state
Inactive,
}
impl<E: LogSafeDisplay> ErrorClassifier for ReconnectError<E> {
@ -280,7 +208,7 @@ impl<E: LogSafeDisplay> ErrorClassifier for ReconnectError<E> {
ReconnectError::Timeout { .. } | ReconnectError::AllRoutesFailed { .. } => {
ErrorClass::Intermittent
}
ReconnectError::RejectedByServer(_) | ReconnectError::Inactive => ErrorClass::Fatal,
ReconnectError::RejectedByServer(_) => ErrorClass::Fatal,
}
}
}
@ -300,7 +228,7 @@ where
async fn map_service<T>(&self, mapper: fn(&C::Service) -> T) -> Result<T, StateError> {
let guard = self.data.state.lock().await;
match &*guard {
ServiceState::Active(service, status) if !status.is_stopped() => Ok(mapper(service)),
ServiceState::Active(service, status) if !status.is_cancelled() => Ok(mapper(service)),
ServiceState::Inactive => Err(StateError::Inactive),
ServiceState::Cooldown(_)
| ServiceState::ConnectionTimedOut
@ -348,20 +276,11 @@ where
self.data.reconnect_count.load(Ordering::Relaxed)
}
pub(crate) async fn reconnect_if_active(&self) -> Result<(), ReconnectError<C::ConnectError>> {
self.connect(false).await
pub(crate) async fn connect(&self) -> Result<(), ReconnectError<C::ConnectError>> {
self.do_connect().await
}
pub(crate) async fn connect_from_inactive(
&self,
) -> Result<(), ReconnectError<C::ConnectError>> {
self.connect(true).await
}
async fn connect(
&self,
is_explicit_connect: bool,
) -> Result<(), ReconnectError<C::ConnectError>> {
async fn do_connect(&self) -> Result<(), ReconnectError<C::ConnectError>> {
let mut attempts: u16 = 0;
let start_of_connection_process = Instant::now();
let deadline = start_of_connection_process + self.data.connection_timeout;
@ -379,13 +298,10 @@ where
loop {
match &*guard {
ServiceState::Inactive => {
if !is_explicit_connect {
return Err(ReconnectError::Inactive);
}
// otherwise, proceeding to connect
// proceeding to connect
}
ServiceState::Active(_, service_status) => {
if !service_status.is_stopped() {
if !service_status.is_cancelled() {
// if the state is `Active` and service has not been stopped,
// clone the service and return it
log::debug!("reusing active service instance");
@ -437,13 +353,6 @@ where
continue;
}
ErrorClass::Fatal => {
if !is_explicit_connect {
// Only explicit connection requests have a place to report this
// error, so for now, non-explicit attempts treat this as a generic
// failure.
return Err(ReconnectError::AllRoutesFailed { attempts });
}
let state = std::mem::replace(&mut *guard, ServiceState::Inactive);
let ServiceState::Error(e) = state else {
unreachable!("we checked this above, matching on &*guard");
@ -464,21 +373,16 @@ where
}
attempts += 1;
*guard = match timeout_at(deadline, self.data.service_initializer.connect()).await {
Ok(ServiceState::Active(service, service_state)) => {
self.schedule_reconnect(service_state.clone());
ServiceState::Active(service, service_state)
}
Ok(result) => result,
Err(_) => ServiceState::ConnectionTimedOut,
}
*guard = timeout_at(deadline, self.data.service_initializer.connect())
.await
.unwrap_or(ServiceState::ConnectionTimedOut);
}
}
pub(crate) async fn disconnect(&self) {
let mut guard = self.data.state.lock().await;
if let ServiceState::Active(_, service_status) = &*guard {
service_status.stop_service();
service_status.cancel();
}
*guard = ServiceState::Inactive;
log::info!("service disconnected");
@ -487,67 +391,6 @@ where
pub(crate) async fn service(&self) -> Result<C::Service, StateError> {
self.map_service(|service| service.clone()).await
}
fn schedule_reconnect(&self, service_status: ServiceStatus<C::StartError>) {
let service_with_reconnect = self.clone();
tokio::spawn(async move {
let _ = service_status.service_cancellation.cancelled().await;
if let Some(error) = service_status.get_error() {
log::debug!("Service stopped due to an error: {:?}", error);
log::info!("Service stopped due to an error: {}", error);
} else {
log::info!("Service stopped");
}
// This is a background thread so there is no overall timeout on reconnect.
// Each attempt is limited by the `data.connection_timeout` duration
// but unless we're in one of the non-proceeding states, we'll be trying to
// connect.
let mut sleep_until = Instant::now();
loop {
if sleep_until > Instant::now() {
tokio::time::sleep_until(sleep_until).await;
}
log::debug!("attempting reconnect");
match service_with_reconnect.reconnect_if_active().await {
Ok(_) => {
log::info!("reconnect attempt succeeded");
service_with_reconnect
.data
.reconnect_count
.fetch_add(1, Ordering::Relaxed);
return;
}
Err(ReconnectError::Inactive) => {
return;
}
Err(error) => {
log::warn!("reconnect attempt failed: {}", error);
let guard = service_with_reconnect.data.state.lock().await;
match &*guard {
ServiceState::Cooldown(next_attempt_time) => {
sleep_until = *next_attempt_time;
}
ServiceState::ConnectionTimedOut | ServiceState::Error(_) => {
// Keep trying, but throttle a little in case of early exits
// TODO: In practice, this should only happen when there's a Fatal
// error, in which case retrying won't help and we should really
// shut down the reconnect loop. But part of doing that would be
// making sure the error is reported. Even when that's implemented,
// though, this is a good backstop against bugs.
const_assert!(CONNECTION_ROUTE_COOLDOWN_INTERVALS[0].is_zero());
sleep_until += CONNECTION_ROUTE_COOLDOWN_INTERVALS[1];
}
ServiceState::Inactive | ServiceState::Active(_, _) => {
// most likely, `disconnect()` was called and we
// switched to the `ServiceState::Inactive` state
return;
}
}
}
}
}
});
}
}
#[cfg(test)]
@ -557,7 +400,7 @@ mod test {
use std::sync::{Arc, Mutex};
use std::time::Duration;
use crate::timeouts::{CONNECTION_ROUTE_COOLDOWN_INTERVALS, CONNECTION_ROUTE_MAX_COOLDOWN};
use crate::timeouts::CONNECTION_ROUTE_MAX_COOLDOWN;
use assert_matches::assert_matches;
use async_trait::async_trait;
use futures_util::FutureExt;
@ -576,19 +419,7 @@ mod test {
use crate::utils::{sleep_and_catch_up, ObservableEvent};
#[derive(Clone, Debug)]
struct TestService {
service_status: ServiceStatus<TestError>,
}
impl TestService {
fn new(service_status: ServiceStatus<TestError>) -> Self {
Self { service_status }
}
fn close_channel(&self) {
self.service_status.stop_service();
}
}
struct TestService;
#[derive(Clone)]
struct TestServiceConnector {
@ -642,13 +473,10 @@ mod test {
}
}
fn start_service(
&self,
_channel: Self::Channel,
) -> (Self::Service, ServiceStatus<Self::StartError>) {
let service_status_arc = ServiceStatus::default();
let service = TestService::new(service_status_arc.clone());
(service, service_status_arc)
fn start_service(&self, _channel: Self::Channel) -> (Self::Service, CancellationToken) {
let service_cancellation = CancellationToken::new();
let service = TestService;
(service, service_cancellation)
}
}
@ -678,45 +506,11 @@ mod test {
#[tokio::test]
async fn service_started_with_request() {
let (connector, service_with_reconnect) = connector_and_service();
service_with_reconnect
.connect_from_inactive()
.await
.expect("");
service_with_reconnect.connect().await.expect("");
let _service = service_with_reconnect.service().await;
assert_eq!(connector.attempts_made(), 1);
}
#[tokio::test(start_paused = true)]
async fn service_tries_to_reconnect_if_connection_lost() {
let (connector, service_with_reconnect) = connector_and_service();
service_with_reconnect
.connect_from_inactive()
.await
.expect("connected");
let service = service_with_reconnect.service().await.expect("available");
// `close_channel()` call emulates lost connection and reconnection will be triggered
// unless service_with_reconnect is in the `Inactive` state
service.close_channel();
// giving time to reconnect
sleep_and_catch_up(NORMAL_CONNECTION_TIME).await;
let service = service_with_reconnect.service().await.expect("available");
// we're doing it again, but this time we'll instruct service connector to fail,
// and as a result, service won't be available
service.close_channel();
connector.set_connection_error(Some(ClassifiableTestError(ErrorClass::Intermittent)));
time::advance(TIME_ADVANCE_VALUE).await;
assert_matches!(
service_with_reconnect.service().await,
Err(StateError::ServiceUnavailable)
);
}
#[tokio::test]
async fn service_is_inactive_before_connected() {
let (_, service_with_reconnect) = connector_and_service();
@ -729,10 +523,7 @@ mod test {
#[tokio::test(start_paused = true)]
async fn service_doesnt_reconnect_if_disconnected() {
let (_, service_with_reconnect) = connector_and_service();
service_with_reconnect
.connect_from_inactive()
.await
.expect("connected");
service_with_reconnect.connect().await.expect("connected");
// making sure service is available
let _ = service_with_reconnect.service().await.expect("available");
@ -754,7 +545,7 @@ mod test {
let (connector, service_with_reconnect) = connector_and_service();
connector.set_connection_error(Some(ClassifiableTestError(ErrorClass::Intermittent)));
let connection_result = service_with_reconnect.connect_from_inactive().await;
let connection_result = service_with_reconnect.connect().await;
// Here we have 3 attempts made by the reconnect service:
// - first attempt went to the connector and resulted in expected error
@ -792,10 +583,10 @@ mod test {
let (connector, service_with_reconnect) = connector_and_service();
let aaa1 = service_with_reconnect.clone();
let handle1 = tokio::spawn(async move { aaa1.connect_from_inactive().await });
let handle1 = tokio::spawn(async move { aaa1.connect().await });
let aaa2 = service_with_reconnect.clone();
let handle2 = tokio::spawn(async move { aaa2.connect_from_inactive().await });
let handle2 = tokio::spawn(async move { aaa2.connect().await });
let (s1, s2) = tokio::join!(handle1, handle2);
assert!(s1.expect("future completed successfully").is_ok());
@ -819,7 +610,7 @@ mod test {
);
let service_with_reconnect =
ServiceWithReconnect::new(connector.clone(), manager, service_with_reconnect_timeout);
let res = service_with_reconnect.connect_from_inactive().await;
let res = service_with_reconnect.connect().await;
// now the time should've auto-advanced from `start` by the `connection_timeout` value
assert!(res.is_err());
@ -843,7 +634,7 @@ mod test {
);
let service_with_reconnect =
ServiceWithReconnect::new(connector.clone(), manager, service_with_reconnect_timeout);
let res = service_with_reconnect.connect_from_inactive().await;
let res = service_with_reconnect.connect().await;
// now the time should've auto-advanced from `start` by the `connection_timeout` value
assert_matches!(res, Err(ReconnectError::Timeout { attempts: 1 }));
@ -856,7 +647,7 @@ mod test {
time::advance(TIME_ADVANCE_VALUE).await;
connector.set_connection_error(Some(ClassifiableTestError(ErrorClass::Intermittent)));
let connection_result = service_with_reconnect.connect_from_inactive().await;
let connection_result = service_with_reconnect.connect().await;
// number of attempts is the same as in the `immediately_fail_if_in_cooldown()` test
assert_matches!(
@ -869,7 +660,7 @@ mod test {
time::advance(CONNECTION_ROUTE_MAX_COOLDOWN).await;
connector.set_connection_error(None);
let connection_result = service_with_reconnect.connect_from_inactive().await;
let connection_result = service_with_reconnect.connect().await;
assert_matches!(connection_result, Ok(_));
}
@ -878,7 +669,7 @@ mod test {
let (connector, service_with_reconnect) = connector_and_service();
time::advance(TIME_ADVANCE_VALUE).await;
connector.set_time_to_connect(LONG_CONNECTION_TIME);
let connection_result = service_with_reconnect.connect_from_inactive().await;
let connection_result = service_with_reconnect.connect().await;
assert_matches!(
connection_result,
Err(ReconnectError::Timeout { attempts: 1 })
@ -889,125 +680,17 @@ mod test {
time::advance(CONNECTION_ROUTE_MAX_COOLDOWN).await;
connector.set_time_to_connect(NORMAL_CONNECTION_TIME);
let connection_result = service_with_reconnect.connect_from_inactive().await;
let connection_result = service_with_reconnect.connect().await;
assert_matches!(connection_result, Ok(_));
}
#[tokio::test(start_paused = true)]
async fn service_keep_reconnecting_attempts_if_first_fails() {
let (connector, service_with_reconnect) = connector_and_service();
service_with_reconnect
.connect_from_inactive()
.await
.expect("connected");
let service = service_with_reconnect.service().await.expect("service");
// at this point, one successfull connection attempt
assert_eq!(connector.attempts.load(Ordering::Relaxed), 1);
// internet connection lost
connector.set_connection_error(Some(ClassifiableTestError(ErrorClass::Intermittent)));
service.close_channel();
sleep_and_catch_up(NORMAL_CONNECTION_TIME).await;
assert_eq!(connector.attempts.load(Ordering::Relaxed), 2);
sleep_and_catch_up(CONNECTION_ROUTE_COOLDOWN_INTERVALS[0] + NORMAL_CONNECTION_TIME).await;
assert_eq!(connector.attempts.load(Ordering::Relaxed), 3);
assert_matches!(service_with_reconnect.service().await, Err(_));
sleep_and_catch_up(CONNECTION_ROUTE_COOLDOWN_INTERVALS[1] + NORMAL_CONNECTION_TIME).await;
assert_eq!(connector.attempts.load(Ordering::Relaxed), 4);
assert_matches!(service_with_reconnect.service().await, Err(_));
// now internet connection is back
// letting next cooldown interval pass and checking again
connector.set_connection_error(None);
sleep_and_catch_up(CONNECTION_ROUTE_COOLDOWN_INTERVALS[2] + NORMAL_CONNECTION_TIME).await;
assert_eq!(connector.attempts.load(Ordering::Relaxed), 5);
assert_matches!(service_with_reconnect.service().await, Ok(_));
}
#[tokio::test(start_paused = true)]
async fn service_stops_reconnect_attempts_if_disconnected_after_some_time() {
let (connector, service_with_reconnect) = connector_and_service();
service_with_reconnect
.connect_from_inactive()
.await
.expect("connected");
let service = service_with_reconnect.service().await.expect("service");
// at this point, one successfull connection attempt
assert_eq!(connector.attempts.load(Ordering::Relaxed), 1);
// internet connection lost
connector.set_connection_error(Some(ClassifiableTestError(ErrorClass::Intermittent)));
service.close_channel();
sleep_and_catch_up(NORMAL_CONNECTION_TIME).await;
assert_eq!(connector.attempts.load(Ordering::Relaxed), 2);
sleep_and_catch_up(CONNECTION_ROUTE_COOLDOWN_INTERVALS[0] + NORMAL_CONNECTION_TIME).await;
assert_eq!(connector.attempts.load(Ordering::Relaxed), 3);
assert_matches!(service_with_reconnect.service().await, Err(_));
sleep_and_catch_up(CONNECTION_ROUTE_COOLDOWN_INTERVALS[1] + NORMAL_CONNECTION_TIME).await;
assert_eq!(connector.attempts.load(Ordering::Relaxed), 4);
assert_matches!(service_with_reconnect.service().await, Err(_));
// now we decide to disconnect, and we need to make sure we're not making
// any more attempts
service_with_reconnect.disconnect().await;
for interval in CONNECTION_ROUTE_COOLDOWN_INTERVALS.into_iter().skip(2) {
sleep_and_catch_up(interval + NORMAL_CONNECTION_TIME).await;
assert_eq!(connector.attempts.load(Ordering::Relaxed), 4);
assert_matches!(
service_with_reconnect.service().await,
Err(StateError::Inactive)
);
}
}
#[tokio::test(start_paused = true)]
async fn reconnect_count_behaves_correctly() {
let (_, service_with_reconnect) = connector_and_service();
service_with_reconnect
.connect_from_inactive()
.await
.expect("connected");
let service = service_with_reconnect.service().await.expect("service");
// manual connection should not count as "reconnect"
assert_eq!(0, service_with_reconnect.reconnect_count());
// emulating unexpected disconnect
service.close_channel();
// giving time to reconnect
sleep_and_catch_up(NORMAL_CONNECTION_TIME).await;
// reconnect count should increase by 1
assert_eq!(1, service_with_reconnect.reconnect_count());
// now, manually disconnecting and connecting again
service_with_reconnect.disconnect().await;
service_with_reconnect
.connect_from_inactive()
.await
.expect("connected");
// reconnect count should not change
assert_eq!(1, service_with_reconnect.reconnect_count());
}
#[tokio::test(start_paused = true)]
async fn service_times_out_early_on_guard_contention() {
let (connector, service_with_reconnect) = connector_and_service();
let guard = service_with_reconnect.data.state.lock().await;
let service_for_task = service_with_reconnect.clone();
let connection_task =
tokio::spawn(async move { service_for_task.connect_from_inactive().await });
let connection_task = tokio::spawn(async move { service_for_task.connect().await });
sleep_and_catch_up(TIMEOUT_DURATION - MINIMUM_CONNECTION_TIME).await;
drop(guard);
@ -1020,106 +703,6 @@ mod test {
assert_eq!(connector.attempts_made(), 0);
}
#[tokio::test(start_paused = true)]
async fn intermittent_errors_do_not_get_propagated() {
let (connector, service_with_reconnect) = connector_and_service();
connector.set_connection_error(Some(ClassifiableTestError(ErrorClass::Intermittent)));
// A "reconnect" won't even make it to the connection error yet.
let inactive_error = service_with_reconnect
.reconnect_if_active()
.await
.expect_err("not active yet");
assert_matches!(inactive_error, ReconnectError::Inactive);
// A proper connect will.
let fatal_error = service_with_reconnect
.connect_from_inactive()
.await
.expect_err("should have returned the connection error");
assert_matches!(fatal_error, ReconnectError::AllRoutesFailed { .. });
assert_matches!(
*service_with_reconnect.data.state.lock().await,
ServiceState::Cooldown(_)
);
// Okay, let's connect properly...
time::advance(CONNECTION_ROUTE_MAX_COOLDOWN).await;
connector.set_connection_error(None);
service_with_reconnect
.connect_from_inactive()
.await
.expect("success");
let service = service_with_reconnect.service().await.expect("service");
// ...then disconnect...
connector.set_connection_error(Some(ClassifiableTestError(ErrorClass::Intermittent)));
service.close_channel();
// ...then try to auto-reconnect.
let reconnect_error = service_with_reconnect
.reconnect_if_active()
.await
.expect_err("not active yet");
assert_matches!(reconnect_error, ReconnectError::AllRoutesFailed { .. });
assert_matches!(
*service_with_reconnect.data.state.lock().await,
ServiceState::Cooldown(_)
);
}
#[tokio::test(start_paused = true)]
async fn fatal_error_gets_propagated_on_explicit_connect_only() {
let (connector, service_with_reconnect) = connector_and_service();
connector.set_connection_error(Some(ClassifiableTestError(ErrorClass::Fatal)));
// A "reconnect" won't even make it to the connection error yet.
let inactive_error = service_with_reconnect
.reconnect_if_active()
.await
.expect_err("not active yet");
assert_matches!(inactive_error, ReconnectError::Inactive);
// A proper connect will.
let fatal_error = service_with_reconnect
.connect_from_inactive()
.await
.expect_err("should have returned the connection error");
assert_matches!(
fatal_error,
ReconnectError::RejectedByServer(ClassifiableTestError(ErrorClass::Fatal))
);
// And it will leave us inactive.
assert_matches!(
*service_with_reconnect.data.state.lock().await,
ServiceState::Inactive
);
// Okay, let's connect properly...
connector.set_connection_error(None);
service_with_reconnect
.connect_from_inactive()
.await
.expect("success");
let service = service_with_reconnect.service().await.expect("service");
// ...then disconnect...
connector.set_connection_error(Some(ClassifiableTestError(ErrorClass::Fatal)));
service.close_channel();
// ...then try to auto-reconnect.
let reconnect_error = service_with_reconnect
.reconnect_if_active()
.await
.expect_err("not active yet");
assert_matches!(reconnect_error, ReconnectError::AllRoutesFailed { .. });
assert_matches!(
*service_with_reconnect.data.state.lock().await,
ServiceState::Error(ClassifiableTestError(ErrorClass::Fatal))
);
}
fn connector_and_service() -> (
TestServiceConnector,
ServiceWithReconnect<TestServiceConnector, SingleRouteThrottlingConnectionManager>,

View File

@ -19,12 +19,13 @@ use http::uri::PathAndQuery;
use tokio::sync::Mutex;
use tokio::time::Instant;
use tokio_tungstenite::WebSocketStream;
use tokio_util::sync::CancellationToken;
use tungstenite::handshake::client::generate_key;
use tungstenite::protocol::CloseFrame;
use tungstenite::{http, Message};
use crate::infra::errors::LogSafeDisplay;
use crate::infra::reconnect::{ServiceConnector, ServiceStatus};
use crate::infra::reconnect::ServiceConnector;
use crate::infra::ws::error::{HttpFormatError, ProtocolError, SpaceError};
use crate::infra::{
Alpn, AsyncDuplexStream, ConnectionInfo, ConnectionParams, StreamAndInfo, TransportConnector,
@ -145,10 +146,7 @@ where
.map_err(Into::into)
}
fn start_service(
&self,
channel: Self::Channel,
) -> (Self::Service, ServiceStatus<Self::StartError>) {
fn start_service(&self, channel: Self::Channel) -> (Self::Service, CancellationToken) {
start_ws_service(
channel.0,
channel.1,
@ -163,19 +161,20 @@ fn start_ws_service<S: AsyncDuplexStream, E>(
connection_info: ConnectionInfo,
keep_alive_interval: Duration,
max_idle_time: Duration,
) -> (WebSocketClient<S, E>, ServiceStatus<E>) {
let service_status = ServiceStatus::default();
) -> (WebSocketClient<S, E>, CancellationToken) {
let service_cancellation = CancellationToken::new();
let (ws_sink, ws_stream) = channel.split();
let ws_client_writer = WebSocketClientWriter {
ws_sink: Arc::new(Mutex::new(ws_sink)),
service_status: service_status.clone(),
service_cancellation: service_cancellation.clone(),
error_type: Default::default(),
};
let ws_client_reader = WebSocketClientReader {
ws_stream,
keep_alive_interval,
max_idle_time,
ws_writer: ws_client_writer.clone(),
service_status: service_status.clone(),
service_cancellation: service_cancellation.clone(),
last_frame_received: Instant::now(),
last_keepalive_sent: Instant::now(),
};
@ -185,7 +184,7 @@ fn start_ws_service<S: AsyncDuplexStream, E>(
ws_client_reader,
connection_info,
},
service_status,
service_cancellation,
)
}
@ -193,7 +192,8 @@ fn start_ws_service<S: AsyncDuplexStream, E>(
#[derive(Debug)]
pub(crate) struct WebSocketClientWriter<S, E> {
ws_sink: Arc<Mutex<SplitSink<WebSocketStream<S>, Message>>>,
service_status: ServiceStatus<E>,
service_cancellation: CancellationToken,
error_type: PhantomData<E>,
}
impl<S: AsyncDuplexStream, E> WebSocketClientWriter<S, E>
@ -201,7 +201,7 @@ where
WebSocketServiceError: Into<E>,
{
pub async fn send(&self, message: impl Into<Message>) -> Result<(), E> {
run_and_update_status(&self.service_status, || {
run_and_update_status(&self.service_cancellation, || {
async {
let mut guard = self.ws_sink.lock().await;
guard.send(message.into()).await?;
@ -218,7 +218,7 @@ where
pub(crate) struct WebSocketClientReader<S, E> {
ws_stream: SplitStream<WebSocketStream<S>>,
ws_writer: WebSocketClientWriter<S, E>,
service_status: ServiceStatus<E>,
service_cancellation: CancellationToken,
keep_alive_interval: Duration,
max_idle_time: Duration,
last_frame_received: Instant,
@ -236,7 +236,7 @@ where
IdleTimeout,
StopService,
}
run_and_update_status(&self.service_status, || async {
run_and_update_status(&self.service_cancellation, || async {
loop {
// first, waiting for the next lifecycle action
let next_ping_time = self.last_keepalive_sent + self.keep_alive_interval;
@ -245,7 +245,7 @@ where
maybe_message = self.ws_stream.next() => Event::Message(maybe_message),
_ = tokio::time::sleep_until(next_ping_time) => Event::SendKeepAlive,
_ = tokio::time::sleep_until(idle_timeout_time) => Event::IdleTimeout,
_ = self.service_status.stopped() => Event::StopService,
_ = self.service_cancellation.cancelled() => Event::StopService,
} {
Event::SendKeepAlive => {
self.ws_writer.send(Message::Ping(vec![])).await?;
@ -281,7 +281,7 @@ where
Message::Binary(b) => return Ok(NextOrClose::Next(b.into())),
Message::Ping(_) | Message::Pong(_) => continue,
Message::Close(close_frame) => {
self.service_status.stop_service();
self.service_cancellation.cancel();
return Ok(NextOrClose::Close(close_frame));
}
Message::Frame(_) => unreachable!("only for sending"),
@ -292,18 +292,21 @@ where
}
}
async fn run_and_update_status<T, F, Ft, E>(service_status: &ServiceStatus<E>, f: F) -> Result<T, E>
async fn run_and_update_status<T, F, Ft, E>(
service_status: &CancellationToken,
f: F,
) -> Result<T, E>
where
WebSocketServiceError: Into<E>,
F: FnOnce() -> Ft,
Ft: Future<Output = Result<T, E>>,
{
if service_status.is_stopped() {
if service_status.is_cancelled() {
return Err(WebSocketServiceError::ChannelClosed.into());
}
let result = f().await;
if result.is_err() {
service_status.stop_service();
service_status.cancel();
}
result.map_err(Into::into)
}

View File

@ -81,9 +81,9 @@ public class AuthenticatedChatService: NativeHandleOwner, ChatService {
withNativeHandle { chatService in
if let listener {
var listenerStruct = ChatListenerBridge(chatService: self, chatListener: listener).makeListenerStruct()
failOnError(signal_chat_server_set_listener(tokioAsyncContext, chatService, &listenerStruct))
failOnError(signal_chat_service_set_listener_auth(tokioAsyncContext, chatService, &listenerStruct))
} else {
failOnError(signal_chat_server_set_listener(tokioAsyncContext, chatService, nil))
failOnError(signal_chat_service_set_listener_auth(tokioAsyncContext, chatService, nil))
}
}
}

View File

@ -1583,7 +1583,7 @@ SignalFfiError *signal_chat_service_auth_send(SignalCPromiseFfiChatResponse *pro
SignalFfiError *signal_chat_service_auth_send_and_debug(SignalCPromiseFfiResponseAndDebugInfo *promise, const SignalTokioAsyncContext *async_runtime, const SignalChat *chat, const SignalHttpRequest *http_request, uint32_t timeout_millis);
SignalFfiError *signal_chat_server_set_listener(const SignalTokioAsyncContext *runtime, const SignalChat *chat, const SignalFfiMakeChatListenerStruct *make_listener);
SignalFfiError *signal_chat_service_set_listener_auth(const SignalTokioAsyncContext *runtime, const SignalChat *chat, const SignalFfiMakeChatListenerStruct *make_listener);
SignalFfiError *signal_testing_chat_service_inject_raw_server_request(const SignalChat *chat, SignalBorrowedBuffer bytes);