0
0
mirror of https://github.com/signalapp/libsignal.git synced 2024-09-20 03:52:17 +02:00

net: Handle responses to requests even if the socket has since closed

This commit is contained in:
Jordan Rose 2024-05-31 17:03:37 -07:00
parent 0134e3e15c
commit 7202905f5e

View File

@ -103,8 +103,19 @@ struct PendingMessagesMap {
next_id: u64,
}
struct NoMoreRequests;
impl PendingMessagesMap {
fn insert(&mut self, responder: oneshot::Sender<ResponseProto>) -> RequestId {
const CANCELLED: u64 = u64::MAX;
fn insert(
&mut self,
responder: oneshot::Sender<ResponseProto>,
) -> Result<RequestId, NoMoreRequests> {
if self.next_id == Self::CANCELLED {
return Err(NoMoreRequests);
}
let id = RequestId::new(self.next_id);
let prev = self.pending.insert(id, responder);
assert!(
@ -112,12 +123,17 @@ impl PendingMessagesMap {
"IDs are picked uniquely and shouldn't wrap around in a reasonable amount of time"
);
self.next_id += 1;
id
Ok(id)
}
fn remove(&mut self, id: &RequestId) -> Option<oneshot::Sender<ResponseProto>> {
self.pending.remove(id)
}
fn cancel_all(&mut self) {
self.next_id = Self::CANCELLED;
self.pending.clear();
}
}
#[derive_where(Clone)]
@ -272,6 +288,10 @@ async fn reader_task<S: AsyncDuplexStream + 'static>(
}
// before terminating the task, marking channel as inactive
service_status.stop_service();
// Clear the pending messages map. These requests don't wait on the service status just in case
// a response comes in late; dropping the response senders is how we cancel them.
pending_messages.lock().await.cancel_all();
}
#[derive_where(Clone)]
@ -305,7 +325,10 @@ where
// defining a scope here to release the lock ASAP
let id = {
let map = &mut self.pending_messages.lock().await;
// It's possible that the service has been stopped between the check above and the
// insert below. This accounts for that.
map.insert(response_tx)
.map_err(|_| WebSocketServiceError::ChannelClosed)?
};
let msg = request_to_websocket_proto(msg, id)
@ -313,19 +336,15 @@ where
self.ws_client_writer.send(msg.encode_to_vec()).await?;
let res = tokio::select! {
result = response_rx => Ok(result.expect("sender is not dropped before receiver")),
_ = tokio::time::sleep(timeout) => Err(ChatServiceError::Timeout),
_ = self.service_status.stopped() => Err(WebSocketServiceError::ChannelClosed.into())
tokio::select! {
result = response_rx => result.map_err(|_| WebSocketServiceError::ChannelClosed.into()),
_ = tokio::time::sleep(timeout) => {
let map = &mut self.pending_messages.lock().await;
map.remove(&id);
Err(ChatServiceError::Timeout)
},
}
.and_then(|response_proto| Ok(response_proto.try_into()?));
if res.is_err() {
// in case of an error we need to clean up the listener from the `pending_messages` map
let map = &mut self.pending_messages.lock().await;
map.remove(&id);
}
res
.and_then(|response_proto| Ok(response_proto.try_into()?))
}
async fn connect(&self) -> Result<(), ChatServiceError> {
@ -593,8 +612,38 @@ mod test {
}
#[tokio::test(flavor = "current_thread", start_paused = true)]
async fn ws_service_fails_request_if_stopped_before_reponse_received() {
// creating a server that responds to requests with 200
async fn ws_service_request_succeeds_even_if_server_closes_immediately_after() {
// creating a server that accepts one request, responds with 200, and then closes
let (ws_server, server_res_rx) = ws_warp_filter(move |websocket| async move {
let (mut tx, mut rx) = websocket.split();
let msg = rx
.next()
.await
.expect("stream should not be closed")
.expect("should be Ok");
assert!(msg.is_binary(), "not binary: {msg:?}");
let request = decode_and_validate(msg.as_bytes()).expect("chat message");
let message_proto =
response_for_request(&request, StatusCode::OK).expect("not an error");
let send_result = tx
.send(warp::ws::Message::binary(message_proto.encode_to_vec()))
.await;
assert_matches!(send_result, Ok(_));
});
let ws_config = test_ws_config();
let (ws_chat, _) = create_ws_chat_service(ws_config, ws_server).await;
let response = ws_chat
.send(test_request(Method::GET, "/"), TIMEOUT_DURATION)
.await;
response.expect("response");
validate_server_stopped_successfully(server_res_rx).await;
}
#[tokio::test(flavor = "current_thread", start_paused = true)]
async fn ws_service_fails_request_if_stopped_before_response_received() {
// creating a server that accepts one request and then closes
let (ws_server, server_res_rx) = ws_warp_filter(move |websocket| async move {
let (mut tx, mut rx) = websocket.split();
let _: warp::ws::Message = rx.next().await.expect("not closed").expect("not an error");