0
0
mirror of https://github.com/signalapp/Signal-Server.git synced 2024-09-19 19:42:18 +02:00

Switch websocket-resources from ListenableFuture to CompletableFuture

This commit is contained in:
Moxie Marlinspike 2019-05-02 15:05:44 -07:00
parent 7e4b572699
commit 0c81556b90
6 changed files with 74 additions and 96 deletions

View File

@ -1,8 +1,5 @@
package org.whispersystems.textsecuregcm.websocket;
import com.google.common.util.concurrent.FutureCallback;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.protobuf.InvalidProtocolBufferException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@ -10,7 +7,6 @@ import org.whispersystems.dispatch.DispatchChannel;
import org.whispersystems.textsecuregcm.entities.MessageProtos.ProvisioningUuid;
import org.whispersystems.textsecuregcm.storage.PubSubProtos.PubSubMessage;
import org.whispersystems.websocket.WebSocketClient;
import org.whispersystems.websocket.messages.WebSocketResponseMessage;
import java.util.Optional;
@ -32,19 +28,12 @@ public class ProvisioningConnection implements DispatchChannel {
if (outgoingMessage.getType() == PubSubMessage.Type.DELIVER) {
Optional<byte[]> body = Optional.of(outgoingMessage.getContent().toByteArray());
ListenableFuture<WebSocketResponseMessage> response = client.sendRequest("PUT", "/v1/message", null, body);
Futures.addCallback(response, new FutureCallback<WebSocketResponseMessage>() {
@Override
public void onSuccess(WebSocketResponseMessage webSocketResponseMessage) {
client.close(1001, "All you get.");
}
@Override
public void onFailure(Throwable throwable) {
client.close(1001, "That's all!");
}
});
client.sendRequest("PUT", "/v1/message", null, body)
.thenAccept(response -> client.close(1001, "All you get."))
.exceptionally(throwable -> {
client.close(1001, "That's all!");
return null;
});
}
} catch (InvalidProtocolBufferException e) {
logger.warn("Protobuf Error: ", e);

View File

@ -3,9 +3,6 @@ package org.whispersystems.textsecuregcm.websocket;
import com.codahale.metrics.Histogram;
import com.codahale.metrics.MetricRegistry;
import com.codahale.metrics.SharedMetricRegistries;
import com.google.common.util.concurrent.FutureCallback;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.protobuf.ByteString;
import com.google.protobuf.InvalidProtocolBufferException;
import org.slf4j.Logger;
@ -28,8 +25,6 @@ import org.whispersystems.textsecuregcm.util.Util;
import org.whispersystems.websocket.WebSocketClient;
import org.whispersystems.websocket.messages.WebSocketResponseMessage;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import javax.ws.rs.WebApplicationException;
import java.util.Collections;
import java.util.Iterator;
@ -123,35 +118,26 @@ public class WebSocketConnection implements DispatchChannel {
body = Optional.ofNullable(new EncryptedOutgoingMessage(message, device.getSignalingKey()).toByteArray());
}
ListenableFuture<WebSocketResponseMessage> response = client.sendRequest("PUT", "/api/v1/message", Collections.singletonList(header), body);
client.sendRequest("PUT", "/api/v1/message", Collections.singletonList(header), body)
.thenAccept(response -> {
boolean isReceipt = message.getType() == Envelope.Type.RECEIPT;
Futures.addCallback(response, new FutureCallback<WebSocketResponseMessage>() {
@Override
public void onSuccess(@Nullable WebSocketResponseMessage response) {
boolean isReceipt = message.getType() == Envelope.Type.RECEIPT;
if (isSuccessResponse(response) && !isReceipt) {
messageTime.update(System.currentTimeMillis() - message.getTimestamp());
}
if (isSuccessResponse(response) && !isReceipt) {
messageTime.update(System.currentTimeMillis() - message.getTimestamp());
}
if (isSuccessResponse(response)) {
if (storedMessageInfo.isPresent()) messagesManager.delete(account.getNumber(), device.getId(), storedMessageInfo.get().id, storedMessageInfo.get().cached);
if (!isReceipt) sendDeliveryReceiptFor(message);
if (requery) processStoredMessages();
} else if (!isSuccessResponse(response) && !storedMessageInfo.isPresent()) {
requeueMessage(message);
}
}
@Override
public void onFailure(@Nonnull Throwable throwable) {
if (!storedMessageInfo.isPresent()) requeueMessage(message);
}
private boolean isSuccessResponse(WebSocketResponseMessage response) {
return response != null && response.getStatus() >= 200 && response.getStatus() < 300;
}
});
if (isSuccessResponse(response)) {
if (storedMessageInfo.isPresent()) messagesManager.delete(account.getNumber(), device.getId(), storedMessageInfo.get().id, storedMessageInfo.get().cached);
if (!isReceipt) sendDeliveryReceiptFor(message);
if (requery) processStoredMessages();
} else if (!isSuccessResponse(response) && !storedMessageInfo.isPresent()) {
requeueMessage(message);
}
})
.exceptionally(throwable -> {
if (!storedMessageInfo.isPresent()) requeueMessage(message);
return null;
});
} catch (CryptoEncodingException e) {
logger.warn("Bad signaling key", e);
}
@ -179,6 +165,10 @@ public class WebSocketConnection implements DispatchChannel {
}
}
private boolean isSuccessResponse(WebSocketResponseMessage response) {
return response != null && response.getStatus() >= 200 && response.getStatus() < 300;
}
private void processStoredMessages() {
OutgoingMessageEntityList messages = messagesManager.getMessagesForDevice(account.getNumber(), device.getId());
Iterator<OutgoingMessageEntity> iterator = messages.getMessages().iterator();

View File

@ -1,6 +1,5 @@
package org.whispersystems.textsecuregcm.tests.websocket;
import com.google.common.util.concurrent.SettableFuture;
import com.google.protobuf.ByteString;
import org.eclipse.jetty.websocket.api.UpgradeRequest;
import org.junit.Test;
@ -39,6 +38,7 @@ import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import io.dropwizard.auth.basic.BasicCredentials;
import static org.junit.Assert.*;
@ -134,14 +134,14 @@ public class WebSocketConnectionTest {
when(storedMessages.getMessagesForDevice(account.getNumber(), device.getId()))
.thenReturn(outgoingMessagesList);
final List<SettableFuture<WebSocketResponseMessage>> futures = new LinkedList<>();
final WebSocketClient client = mock(WebSocketClient.class);
final List<CompletableFuture<WebSocketResponseMessage>> futures = new LinkedList<>();
final WebSocketClient client = mock(WebSocketClient.class);
when(client.sendRequest(eq("PUT"), eq("/api/v1/message"), ArgumentMatchers.nullable(List.class), ArgumentMatchers.<Optional<byte[]>>any()))
.thenAnswer(new Answer<SettableFuture<WebSocketResponseMessage>>() {
.thenAnswer(new Answer<CompletableFuture<WebSocketResponseMessage>>() {
@Override
public SettableFuture<WebSocketResponseMessage> answer(InvocationOnMock invocationOnMock) throws Throwable {
SettableFuture<WebSocketResponseMessage> future = SettableFuture.create();
public CompletableFuture<WebSocketResponseMessage> answer(InvocationOnMock invocationOnMock) throws Throwable {
CompletableFuture<WebSocketResponseMessage> future = new CompletableFuture<>();
futures.add(future);
return future;
}
@ -158,10 +158,10 @@ public class WebSocketConnectionTest {
WebSocketResponseMessage response = mock(WebSocketResponseMessage.class);
when(response.getStatus()).thenReturn(200);
futures.get(1).set(response);
futures.get(1).complete(response);
futures.get(0).setException(new IOException());
futures.get(2).setException(new IOException());
futures.get(0).completeExceptionally(new IOException());
futures.get(2).completeExceptionally(new IOException());
verify(storedMessages, times(1)).delete(eq(account.getNumber()), eq(2L), eq(2L), eq(false));
verify(receiptSender, times(1)).sendReceipt(eq(account), eq("sender1"), eq(2222L));
@ -217,14 +217,14 @@ public class WebSocketConnectionTest {
when(storedMessages.getMessagesForDevice(account.getNumber(), device.getId()))
.thenReturn(pendingMessagesList);
final List<SettableFuture<WebSocketResponseMessage>> futures = new LinkedList<>();
final WebSocketClient client = mock(WebSocketClient.class);
final List<CompletableFuture<WebSocketResponseMessage>> futures = new LinkedList<>();
final WebSocketClient client = mock(WebSocketClient.class);
when(client.sendRequest(eq("PUT"), eq("/api/v1/message"), ArgumentMatchers.nullable(List.class), ArgumentMatchers.<Optional<byte[]>>any()))
.thenAnswer(new Answer<SettableFuture<WebSocketResponseMessage>>() {
.thenAnswer(new Answer<CompletableFuture<WebSocketResponseMessage>>() {
@Override
public SettableFuture<WebSocketResponseMessage> answer(InvocationOnMock invocationOnMock) throws Throwable {
SettableFuture<WebSocketResponseMessage> future = SettableFuture.create();
public CompletableFuture<WebSocketResponseMessage> answer(InvocationOnMock invocationOnMock) throws Throwable {
CompletableFuture<WebSocketResponseMessage> future = new CompletableFuture<>();
futures.add(future);
return future;
}
@ -251,8 +251,8 @@ public class WebSocketConnectionTest {
WebSocketResponseMessage response = mock(WebSocketResponseMessage.class);
when(response.getStatus()).thenReturn(200);
futures.get(1).set(response);
futures.get(0).setException(new IOException());
futures.get(1).complete(response);
futures.get(0).completeExceptionally(new IOException());
verify(receiptSender, times(1)).sendReceipt(eq(account), eq("sender2"), eq(secondMessage.getTimestamp()));
verify(websocketSender, times(1)).queueMessage(eq(account), eq(device), any(Envelope.class));
@ -322,14 +322,14 @@ public class WebSocketConnectionTest {
when(storedMessages.getMessagesForDevice(account.getNumber(), device.getId()))
.thenReturn(pendingMessagesList);
final List<SettableFuture<WebSocketResponseMessage>> futures = new LinkedList<>();
final WebSocketClient client = mock(WebSocketClient.class);
final List<CompletableFuture<WebSocketResponseMessage>> futures = new LinkedList<>();
final WebSocketClient client = mock(WebSocketClient.class);
when(client.sendRequest(eq("PUT"), eq("/api/v1/message"), ArgumentMatchers.nullable(List.class), ArgumentMatchers.<Optional<byte[]>>any()))
.thenAnswer(new Answer<SettableFuture<WebSocketResponseMessage>>() {
.thenAnswer(new Answer<CompletableFuture<WebSocketResponseMessage>>() {
@Override
public SettableFuture<WebSocketResponseMessage> answer(InvocationOnMock invocationOnMock) throws Throwable {
SettableFuture<WebSocketResponseMessage> future = SettableFuture.create();
public CompletableFuture<WebSocketResponseMessage> answer(InvocationOnMock invocationOnMock) throws Throwable {
CompletableFuture<WebSocketResponseMessage> future = new CompletableFuture<>();
futures.add(future);
return future;
}
@ -347,8 +347,8 @@ public class WebSocketConnectionTest {
WebSocketResponseMessage response = mock(WebSocketResponseMessage.class);
when(response.getStatus()).thenReturn(200);
futures.get(1).set(response);
futures.get(0).setException(new IOException());
futures.get(1).complete(response);
futures.get(0).completeExceptionally(new IOException());
verify(receiptSender, times(1)).sendReceipt(eq(account), eq("sender2"), eq(secondMessage.getTimestamp()));
verifyNoMoreInteractions(websocketSender);

View File

@ -1,4 +1,4 @@
/**
/*
* Copyright (C) 2014 Open WhisperSystems
*
* This program is free software: you can redistribute it and/or modify
@ -16,8 +16,6 @@
*/
package org.whispersystems.websocket;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.SettableFuture;
import org.eclipse.jetty.websocket.api.RemoteEndpoint;
import org.eclipse.jetty.websocket.api.Session;
import org.eclipse.jetty.websocket.api.WebSocketException;
@ -34,20 +32,21 @@ import java.security.SecureRandom;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
public class WebSocketClient {
private static final Logger logger = LoggerFactory.getLogger(WebSocketClient.class);
private final Session session;
private final RemoteEndpoint remoteEndpoint;
private final WebSocketMessageFactory messageFactory;
private final Map<Long, SettableFuture<WebSocketResponseMessage>> pendingRequestMapper;
private final Session session;
private final RemoteEndpoint remoteEndpoint;
private final WebSocketMessageFactory messageFactory;
private final Map<Long, CompletableFuture<WebSocketResponseMessage>> pendingRequestMapper;
public WebSocketClient(Session session, RemoteEndpoint remoteEndpoint,
WebSocketMessageFactory messageFactory,
Map<Long, SettableFuture<WebSocketResponseMessage>> pendingRequestMapper)
Map<Long, CompletableFuture<WebSocketResponseMessage>> pendingRequestMapper)
{
this.session = session;
this.remoteEndpoint = remoteEndpoint;
@ -55,12 +54,12 @@ public class WebSocketClient {
this.pendingRequestMapper = pendingRequestMapper;
}
public ListenableFuture<WebSocketResponseMessage> sendRequest(String verb, String path,
List<String> headers,
Optional<byte[]> body)
public CompletableFuture<WebSocketResponseMessage> sendRequest(String verb, String path,
List<String> headers,
Optional<byte[]> body)
{
final long requestId = generateRequestId();
final SettableFuture<WebSocketResponseMessage> future = SettableFuture.create();
final long requestId = generateRequestId();
final CompletableFuture<WebSocketResponseMessage> future = new CompletableFuture<>();
pendingRequestMapper.put(requestId, future);
@ -72,7 +71,7 @@ public class WebSocketClient {
public void writeFailed(Throwable x) {
logger.debug("Write failed", x);
pendingRequestMapper.remove(requestId);
future.setException(x);
future.completeExceptionally(x);
}
@Override
@ -81,7 +80,7 @@ public class WebSocketClient {
} catch (WebSocketException e) {
logger.debug("Write", e);
pendingRequestMapper.remove(requestId);
future.setException(e);
future.completeExceptionally(e);
}
return future;

View File

@ -1,4 +1,4 @@
/**
/*
* Copyright (C) 2014 Open WhisperSystems
*
* This program is free software: you can redistribute it and/or modify
@ -17,7 +17,6 @@
package org.whispersystems.websocket;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.util.concurrent.SettableFuture;
import org.eclipse.jetty.server.RequestLog;
import org.eclipse.jetty.websocket.api.RemoteEndpoint;
import org.eclipse.jetty.websocket.api.Session;
@ -48,6 +47,7 @@ import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
@ -56,7 +56,7 @@ public class WebSocketResourceProvider implements WebSocketListener {
private static final Logger logger = LoggerFactory.getLogger(WebSocketResourceProvider.class);
private final Map<Long, SettableFuture<WebSocketResponseMessage>> requestMap = new ConcurrentHashMap<>();
private final Map<Long, CompletableFuture<WebSocketResponseMessage>> requestMap = new ConcurrentHashMap<>();
private final Object authenticated;
private final WebSocketMessageFactory messageFactory;
@ -131,10 +131,10 @@ public class WebSocketResourceProvider implements WebSocketListener {
context.notifyClosed(statusCode, reason);
for (long requestId : requestMap.keySet()) {
SettableFuture outstandingRequest = requestMap.remove(requestId);
CompletableFuture outstandingRequest = requestMap.remove(requestId);
if (outstandingRequest != null) {
outstandingRequest.setException(new IOException("Connection closed!"));
outstandingRequest.completeExceptionally(new IOException("Connection closed!"));
}
}
}
@ -160,10 +160,10 @@ public class WebSocketResourceProvider implements WebSocketListener {
}
private void handleResponse(WebSocketResponseMessage responseMessage) {
SettableFuture<WebSocketResponseMessage> future = requestMap.remove(responseMessage.getRequestId());
CompletableFuture<WebSocketResponseMessage> future = requestMap.remove(responseMessage.getRequestId());
if (future != null) {
future.set(responseMessage);
future.complete(responseMessage);
}
}
@ -197,7 +197,7 @@ public class WebSocketResourceProvider implements WebSocketListener {
error.getStatus(),
"Error response",
headers,
Optional.<byte[]>empty());
Optional.empty());
remoteEndpoint.sendBytesByFuture(ByteBuffer.wrap(response.toByteArray()));
}

View File

@ -1,4 +1,4 @@
/**
/*
* Copyright (C) 2014 Open WhisperSystems
*
* This program is free software: you can redistribute it and/or modify