mirror of
https://github.com/signalapp/libsignal.git
synced 2024-09-20 03:52:17 +02:00
net: Handle responses to requests even if the socket has since closed
This commit is contained in:
parent
0134e3e15c
commit
7202905f5e
@ -103,8 +103,19 @@ struct PendingMessagesMap {
|
||||
next_id: u64,
|
||||
}
|
||||
|
||||
struct NoMoreRequests;
|
||||
|
||||
impl PendingMessagesMap {
|
||||
fn insert(&mut self, responder: oneshot::Sender<ResponseProto>) -> RequestId {
|
||||
const CANCELLED: u64 = u64::MAX;
|
||||
|
||||
fn insert(
|
||||
&mut self,
|
||||
responder: oneshot::Sender<ResponseProto>,
|
||||
) -> Result<RequestId, NoMoreRequests> {
|
||||
if self.next_id == Self::CANCELLED {
|
||||
return Err(NoMoreRequests);
|
||||
}
|
||||
|
||||
let id = RequestId::new(self.next_id);
|
||||
let prev = self.pending.insert(id, responder);
|
||||
assert!(
|
||||
@ -112,12 +123,17 @@ impl PendingMessagesMap {
|
||||
"IDs are picked uniquely and shouldn't wrap around in a reasonable amount of time"
|
||||
);
|
||||
self.next_id += 1;
|
||||
id
|
||||
Ok(id)
|
||||
}
|
||||
|
||||
fn remove(&mut self, id: &RequestId) -> Option<oneshot::Sender<ResponseProto>> {
|
||||
self.pending.remove(id)
|
||||
}
|
||||
|
||||
fn cancel_all(&mut self) {
|
||||
self.next_id = Self::CANCELLED;
|
||||
self.pending.clear();
|
||||
}
|
||||
}
|
||||
|
||||
#[derive_where(Clone)]
|
||||
@ -272,6 +288,10 @@ async fn reader_task<S: AsyncDuplexStream + 'static>(
|
||||
}
|
||||
// before terminating the task, marking channel as inactive
|
||||
service_status.stop_service();
|
||||
|
||||
// Clear the pending messages map. These requests don't wait on the service status just in case
|
||||
// a response comes in late; dropping the response senders is how we cancel them.
|
||||
pending_messages.lock().await.cancel_all();
|
||||
}
|
||||
|
||||
#[derive_where(Clone)]
|
||||
@ -305,7 +325,10 @@ where
|
||||
// defining a scope here to release the lock ASAP
|
||||
let id = {
|
||||
let map = &mut self.pending_messages.lock().await;
|
||||
// It's possible that the service has been stopped between the check above and the
|
||||
// insert below. This accounts for that.
|
||||
map.insert(response_tx)
|
||||
.map_err(|_| WebSocketServiceError::ChannelClosed)?
|
||||
};
|
||||
|
||||
let msg = request_to_websocket_proto(msg, id)
|
||||
@ -313,19 +336,15 @@ where
|
||||
|
||||
self.ws_client_writer.send(msg.encode_to_vec()).await?;
|
||||
|
||||
let res = tokio::select! {
|
||||
result = response_rx => Ok(result.expect("sender is not dropped before receiver")),
|
||||
_ = tokio::time::sleep(timeout) => Err(ChatServiceError::Timeout),
|
||||
_ = self.service_status.stopped() => Err(WebSocketServiceError::ChannelClosed.into())
|
||||
tokio::select! {
|
||||
result = response_rx => result.map_err(|_| WebSocketServiceError::ChannelClosed.into()),
|
||||
_ = tokio::time::sleep(timeout) => {
|
||||
let map = &mut self.pending_messages.lock().await;
|
||||
map.remove(&id);
|
||||
Err(ChatServiceError::Timeout)
|
||||
},
|
||||
}
|
||||
.and_then(|response_proto| Ok(response_proto.try_into()?));
|
||||
|
||||
if res.is_err() {
|
||||
// in case of an error we need to clean up the listener from the `pending_messages` map
|
||||
let map = &mut self.pending_messages.lock().await;
|
||||
map.remove(&id);
|
||||
}
|
||||
res
|
||||
.and_then(|response_proto| Ok(response_proto.try_into()?))
|
||||
}
|
||||
|
||||
async fn connect(&self) -> Result<(), ChatServiceError> {
|
||||
@ -593,8 +612,38 @@ mod test {
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "current_thread", start_paused = true)]
|
||||
async fn ws_service_fails_request_if_stopped_before_reponse_received() {
|
||||
// creating a server that responds to requests with 200
|
||||
async fn ws_service_request_succeeds_even_if_server_closes_immediately_after() {
|
||||
// creating a server that accepts one request, responds with 200, and then closes
|
||||
let (ws_server, server_res_rx) = ws_warp_filter(move |websocket| async move {
|
||||
let (mut tx, mut rx) = websocket.split();
|
||||
let msg = rx
|
||||
.next()
|
||||
.await
|
||||
.expect("stream should not be closed")
|
||||
.expect("should be Ok");
|
||||
assert!(msg.is_binary(), "not binary: {msg:?}");
|
||||
let request = decode_and_validate(msg.as_bytes()).expect("chat message");
|
||||
let message_proto =
|
||||
response_for_request(&request, StatusCode::OK).expect("not an error");
|
||||
let send_result = tx
|
||||
.send(warp::ws::Message::binary(message_proto.encode_to_vec()))
|
||||
.await;
|
||||
assert_matches!(send_result, Ok(_));
|
||||
});
|
||||
|
||||
let ws_config = test_ws_config();
|
||||
let (ws_chat, _) = create_ws_chat_service(ws_config, ws_server).await;
|
||||
|
||||
let response = ws_chat
|
||||
.send(test_request(Method::GET, "/"), TIMEOUT_DURATION)
|
||||
.await;
|
||||
response.expect("response");
|
||||
validate_server_stopped_successfully(server_res_rx).await;
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "current_thread", start_paused = true)]
|
||||
async fn ws_service_fails_request_if_stopped_before_response_received() {
|
||||
// creating a server that accepts one request and then closes
|
||||
let (ws_server, server_res_rx) = ws_warp_filter(move |websocket| async move {
|
||||
let (mut tx, mut rx) = websocket.split();
|
||||
let _: warp::ws::Message = rx.next().await.expect("not closed").expect("not an error");
|
||||
|
Loading…
Reference in New Issue
Block a user