0
0
mirror of https://github.com/signalapp/Signal-Server.git synced 2024-09-20 03:52:16 +02:00

Use reactive streams for WebSocket message queue

Initially, uses `ExperimentEnrollmentManager` to do a safe rollout.
This commit is contained in:
Chris Eager 2022-10-31 10:35:37 -05:00 committed by GitHub
parent 4252284405
commit c10fda8363
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
32 changed files with 2359 additions and 1260 deletions

View File

@ -57,7 +57,7 @@
<jedis.version>2.9.0</jedis.version>
<kotlin.version>1.7.10</kotlin.version>
<kotlinx-serialization.version>1.4.0</kotlinx-serialization.version>
<lettuce.version>6.1.9.RELEASE</lettuce.version>
<lettuce.version>6.2.0.RELEASE</lettuce.version>
<libphonenumber.version>8.12.54</libphonenumber.version>
<logstash.logback.version>7.0.1</logstash.logback.version>
<micrometer.version>1.9.3</micrometer.version>
@ -151,6 +151,13 @@
<type>pom</type>
<scope>import</scope>
</dependency>
<dependency>
<groupId>io.projectreactor</groupId>
<artifactId>reactor-bom</artifactId>
<version>2020.0.23</version> <!-- 3.4.x, see https://github.com/reactor/reactor#bom-versioning-scheme -->
<type>pom</type>
<scope>import</scope>
</dependency>
<dependency>
<groupId>com.eatthepath</groupId>
<artifactId>pushy</artifactId>

View File

@ -228,6 +228,10 @@
<groupId>io.github.resilience4j</groupId>
<artifactId>resilience4j-retry</artifactId>
</dependency>
<dependency>
<groupId>io.github.resilience4j</groupId>
<artifactId>resilience4j-reactor</artifactId>
</dependency>
<dependency>
<groupId>io.grpc</groupId>
@ -407,7 +411,6 @@
<dependency>
<groupId>io.projectreactor</groupId>
<artifactId>reactor-core</artifactId>
<version>3.3.22.RELEASE</version>
</dependency>
<dependency>
<groupId>io.vavr</groupId>
@ -420,6 +423,11 @@
<scope>test</scope>
</dependency>
<dependency>
<groupId>io.projectreactor</groupId>
<artifactId>reactor-test</artifactId>
</dependency>
<dependency>
<groupId>org.signal</groupId>
<artifactId>embedded-redis</artifactId>

View File

@ -116,7 +116,6 @@ import org.whispersystems.textsecuregcm.filters.TimestampResponseFilter;
import org.whispersystems.textsecuregcm.limits.DynamicRateLimiters;
import org.whispersystems.textsecuregcm.limits.PushChallengeManager;
import org.whispersystems.textsecuregcm.limits.RateLimitChallengeManager;
import org.whispersystems.textsecuregcm.limits.RateLimitChallengeOptionManager;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.mappers.CompletionExceptionMapper;
import org.whispersystems.textsecuregcm.mappers.DeviceLimitExceededExceptionMapper;
@ -330,6 +329,13 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
config.getAppConfig().getConfigurationName(),
DynamicConfiguration.class);
BlockingQueue<Runnable> messageDeletionQueue = new ArrayBlockingQueue<>(10_000);
Metrics.gaugeCollectionSize(name(getClass(), "messageDeletionQueueSize"), Collections.emptyList(),
messageDeletionQueue);
ExecutorService messageDeletionAsyncExecutor = environment.lifecycle()
.executorService(name(getClass(), "messageDeletionAsyncExecutor-%d")).maxThreads(16)
.workQueue(messageDeletionQueue).build();
Accounts accounts = new Accounts(dynamicConfigurationManager,
dynamoDbClient,
dynamoDbAsyncClient,
@ -345,9 +351,10 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
Profiles profiles = new Profiles(dynamoDbClient, dynamoDbAsyncClient,
config.getDynamoDbTables().getProfiles().getTableName());
Keys keys = new Keys(dynamoDbClient, config.getDynamoDbTables().getKeys().getTableName());
MessagesDynamoDb messagesDynamoDb = new MessagesDynamoDb(dynamoDbClient,
MessagesDynamoDb messagesDynamoDb = new MessagesDynamoDb(dynamoDbClient, dynamoDbAsyncClient,
config.getDynamoDbTables().getMessages().getTableName(),
config.getDynamoDbTables().getMessages().getExpiration());
config.getDynamoDbTables().getMessages().getExpiration(),
messageDeletionAsyncExecutor);
RemoteConfigs remoteConfigs = new RemoteConfigs(dynamoDbClient,
config.getDynamoDbTables().getRemoteConfig().getTableName());
PushChallengeDynamoDb pushChallengeDynamoDb = new PushChallengeDynamoDb(dynamoDbClient,
@ -452,9 +459,10 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
DirectoryQueue directoryQueue = new DirectoryQueue(config.getDirectoryConfiguration().getSqsConfiguration());
StoredVerificationCodeManager pendingAccountsManager = new StoredVerificationCodeManager(pendingAccounts);
StoredVerificationCodeManager pendingDevicesManager = new StoredVerificationCodeManager(pendingDevices);
ProfilesManager profilesManager = new ProfilesManager(profiles, cacheCluster);
MessagesCache messagesCache = new MessagesCache(messagesCluster, messagesCluster, keyspaceNotificationDispatchExecutor);
PushLatencyManager pushLatencyManager = new PushLatencyManager(metricsCluster, dynamicConfigurationManager);
ProfilesManager profilesManager = new ProfilesManager(profiles, cacheCluster);
MessagesCache messagesCache = new MessagesCache(messagesCluster, messagesCluster, Clock.systemUTC(),
keyspaceNotificationDispatchExecutor, messageDeletionAsyncExecutor);
PushLatencyManager pushLatencyManager = new PushLatencyManager(metricsCluster, dynamicConfigurationManager);
ReportMessageManager reportMessageManager = new ReportMessageManager(reportMessageDynamoDb, rateLimitersCluster, config.getReportMessageConfiguration().getCounterTtl());
MessagesManager messagesManager = new MessagesManager(messagesDynamoDb, messagesCache, reportMessageManager);
UsernameGenerator usernameGenerator = new UsernameGenerator(config.getUsername());
@ -503,8 +511,6 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
PushChallengeManager pushChallengeManager = new PushChallengeManager(pushNotificationManager, pushChallengeDynamoDb);
RateLimitChallengeManager rateLimitChallengeManager = new RateLimitChallengeManager(pushChallengeManager,
recaptchaClient, dynamicRateLimiters);
RateLimitChallengeOptionManager rateLimitChallengeOptionManager =
new RateLimitChallengeOptionManager(dynamicRateLimiters, dynamicConfigurationManager);
MessagePersister messagePersister = new MessagePersister(messagesCache, messagesManager, accountsManager, dynamicConfigurationManager, Duration.ofMinutes(config.getMessageCacheConfiguration().getPersistDelayMinutes()));
ChangeNumberManager changeNumberManager = new ChangeNumberManager(messageSender, accountsManager);
@ -628,8 +634,9 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
webSocketEnvironment.setAuthenticator(new WebSocketAccountAuthenticator(accountAuthenticator));
webSocketEnvironment.setConnectListener(
new AuthenticatedConnectListener(receiptSender, messagesManager, pushNotificationManager,
clientPresenceManager, websocketScheduledExecutor));
webSocketEnvironment.jersey().register(new WebsocketRefreshApplicationEventListener(accountsManager, clientPresenceManager));
clientPresenceManager, websocketScheduledExecutor, experimentEnrollmentManager));
webSocketEnvironment.jersey()
.register(new WebsocketRefreshApplicationEventListener(accountsManager, clientPresenceManager));
webSocketEnvironment.jersey().register(new ContentLengthFilter(TrafficSource.WEBSOCKET));
webSocketEnvironment.jersey().register(MultiRecipientMessageProvider.class);
webSocketEnvironment.jersey().register(new MetricsApplicationEventListener(TrafficSource.WEBSOCKET));

View File

@ -1,5 +1,5 @@
/*
* Copyright 2013-2021 Signal Messenger, LLC
* Copyright 2013-2022 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.controllers;
@ -30,6 +30,7 @@ import java.util.Optional;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.Callable;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Function;
@ -538,25 +539,29 @@ public class MessageController {
@Timed
@DELETE
@Path("/uuid/{uuid}")
public void removePendingMessage(@Auth AuthenticatedAccount auth, @PathParam("uuid") UUID uuid) {
messagesManager.delete(
auth.getAccount().getUuid(),
auth.getAuthenticatedDevice().getId(),
uuid,
null).ifPresent(deletedMessage -> {
public CompletableFuture<Void> removePendingMessage(@Auth AuthenticatedAccount auth, @PathParam("uuid") UUID uuid) {
return messagesManager.delete(
auth.getAccount().getUuid(),
auth.getAuthenticatedDevice().getId(),
uuid,
null)
.thenAccept(maybeDeletedMessage -> {
maybeDeletedMessage.ifPresent(deletedMessage -> {
WebSocketConnection.recordMessageDeliveryDuration(deletedMessage.getTimestamp(), auth.getAuthenticatedDevice());
WebSocketConnection.recordMessageDeliveryDuration(deletedMessage.getTimestamp(),
auth.getAuthenticatedDevice());
if (deletedMessage.hasSourceUuid() && deletedMessage.getType() != Type.SERVER_DELIVERY_RECEIPT) {
try {
receiptSender.sendReceipt(
UUID.fromString(deletedMessage.getDestinationUuid()), auth.getAuthenticatedDevice().getId(),
UUID.fromString(deletedMessage.getSourceUuid()), deletedMessage.getTimestamp());
} catch (Exception e) {
logger.warn("Failed to send delivery receipt", e);
}
}
});
if (deletedMessage.hasSourceUuid() && deletedMessage.getType() != Type.SERVER_DELIVERY_RECEIPT) {
try {
receiptSender.sendReceipt(
UUID.fromString(deletedMessage.getDestinationUuid()), auth.getAuthenticatedDevice().getId(),
UUID.fromString(deletedMessage.getSourceUuid()), deletedMessage.getTimestamp());
} catch (Exception e) {
logger.warn("Failed to send delivery receipt", e);
}
}
});
});
}
@Timed

View File

@ -6,6 +6,7 @@
package org.whispersystems.textsecuregcm.redis;
import com.google.common.annotations.VisibleForTesting;
import io.lettuce.core.RedisException;
import io.lettuce.core.RedisNoScriptException;
import io.lettuce.core.ScriptOutputType;
import io.lettuce.core.cluster.api.StatefulRedisClusterConnection;
@ -15,9 +16,12 @@ import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import org.apache.commons.codec.binary.Hex;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
public class ClusterLuaScript {
@ -73,11 +77,31 @@ public class ClusterLuaScript {
execute(connection, keys.toArray(STRING_ARRAY), args.toArray(STRING_ARRAY)));
}
public CompletableFuture<Object> executeAsync(final List<String> keys, final List<String> args) {
return redisCluster.withCluster(connection ->
executeAsync(connection, keys.toArray(STRING_ARRAY), args.toArray(STRING_ARRAY)));
}
public Flux<Object> executeReactive(final List<String> keys, final List<String> args) {
return redisCluster.withCluster(connection ->
executeReactive(connection, keys.toArray(STRING_ARRAY), args.toArray(STRING_ARRAY)));
}
public Object executeBinary(final List<byte[]> keys, final List<byte[]> args) {
return redisCluster.withBinaryCluster(connection ->
execute(connection, keys.toArray(BYTE_ARRAY_ARRAY), args.toArray(BYTE_ARRAY_ARRAY)));
}
public CompletableFuture<Object> executeBinaryAsync(final List<byte[]> keys, final List<byte[]> args) {
return redisCluster.withBinaryCluster(connection ->
executeAsync(connection, keys.toArray(BYTE_ARRAY_ARRAY), args.toArray(BYTE_ARRAY_ARRAY)));
}
public Flux<Object> executeBinaryReactive(final List<byte[]> keys, final List<byte[]> args) {
return redisCluster.withBinaryCluster(connection ->
executeReactive(connection, keys.toArray(BYTE_ARRAY_ARRAY), args.toArray(BYTE_ARRAY_ARRAY)));
}
private <T> Object execute(final StatefulRedisClusterConnection<T, T> connection, final T[] keys, final T[] args) {
try {
try {
@ -90,4 +114,32 @@ public class ClusterLuaScript {
throw e;
}
}
private <T> CompletableFuture<Object> executeAsync(final StatefulRedisClusterConnection<T, T> connection,
final T[] keys, final T[] args) {
return connection.async().evalsha(sha, scriptOutputType, keys, args)
.exceptionallyCompose(throwable -> {
if (throwable instanceof RedisNoScriptException) {
return connection.async().eval(script, scriptOutputType, keys, args);
}
log.warn("Failed to execute script", throwable);
throw new RedisException(throwable);
}).toCompletableFuture();
}
private <T> Flux<Object> executeReactive(final StatefulRedisClusterConnection<T, T> connection,
final T[] keys, final T[] args) {
return connection.reactive().evalsha(sha, scriptOutputType, keys, args)
.onErrorResume(e -> {
if (e instanceof RedisNoScriptException) {
return connection.reactive().eval(script, scriptOutputType, keys, args);
}
log.warn("Failed to execute script", e);
return Mono.error(e);
});
}
}

View File

@ -8,6 +8,8 @@ package org.whispersystems.textsecuregcm.redis;
import com.codahale.metrics.SharedMetricRegistries;
import com.google.common.annotations.VisibleForTesting;
import io.github.resilience4j.circuitbreaker.CircuitBreaker;
import io.github.resilience4j.reactor.circuitbreaker.operator.CircuitBreakerOperator;
import io.github.resilience4j.reactor.retry.RetryOperator;
import io.github.resilience4j.retry.Retry;
import io.lettuce.core.ClientOptions.DisconnectedBehavior;
import io.lettuce.core.RedisCommandTimeoutException;
@ -24,11 +26,13 @@ import java.util.ArrayList;
import java.util.List;
import java.util.function.Consumer;
import java.util.function.Function;
import org.reactivestreams.Publisher;
import org.whispersystems.textsecuregcm.configuration.CircuitBreakerConfiguration;
import org.whispersystems.textsecuregcm.configuration.RedisClusterConfiguration;
import org.whispersystems.textsecuregcm.configuration.RetryConfiguration;
import org.whispersystems.textsecuregcm.util.CircuitBreakerUtil;
import org.whispersystems.textsecuregcm.util.Constants;
import reactor.core.publisher.Flux;
/**
* A fault-tolerant access manager for a Redis cluster. A fault-tolerant Redis cluster provides managed,
@ -81,64 +85,79 @@ public class FaultTolerantRedisCluster {
}
void shutdown() {
stringConnection.close();
binaryConnection.close();
stringConnection.close();
binaryConnection.close();
for (final StatefulRedisClusterPubSubConnection<?, ?> pubSubConnection : pubSubConnections) {
pubSubConnection.close();
}
for (final StatefulRedisClusterPubSubConnection<?, ?> pubSubConnection : pubSubConnections) {
pubSubConnection.close();
}
clusterClient.shutdown();
clusterClient.shutdown();
}
public String getName() {
return name;
}
public String getName() {
return name;
}
public void useCluster(final Consumer<StatefulRedisClusterConnection<String, String>> consumer) {
useConnection(stringConnection, consumer);
}
public void useCluster(final Consumer<StatefulRedisClusterConnection<String, String>> consumer) {
useConnection(stringConnection, consumer);
}
public <T> T withCluster(final Function<StatefulRedisClusterConnection<String, String>, T> function) {
return withConnection(stringConnection, function);
}
public <T> T withCluster(final Function<StatefulRedisClusterConnection<String, String>, T> function) {
return withConnection(stringConnection, function);
}
public void useBinaryCluster(final Consumer<StatefulRedisClusterConnection<byte[], byte[]>> consumer) {
useConnection(binaryConnection, consumer);
}
public void useBinaryCluster(final Consumer<StatefulRedisClusterConnection<byte[], byte[]>> consumer) {
useConnection(binaryConnection, consumer);
}
public <T> T withBinaryCluster(final Function<StatefulRedisClusterConnection<byte[], byte[]>, T> function) {
return withConnection(binaryConnection, function);
}
public <T> T withBinaryCluster(final Function<StatefulRedisClusterConnection<byte[], byte[]>, T> function) {
return withConnection(binaryConnection, function);
}
private <K, V> void useConnection(final StatefulRedisClusterConnection<K, V> connection, final Consumer<StatefulRedisClusterConnection<K, V>> consumer) {
try {
circuitBreaker.executeCheckedRunnable(() -> retry.executeRunnable(() -> consumer.accept(connection)));
} catch (final Throwable t) {
if (t instanceof RedisException) {
throw (RedisException) t;
} else {
throw new RedisException(t);
}
}
}
public <T> Publisher<T> withBinaryClusterReactive(
final Function<StatefulRedisClusterConnection<byte[], byte[]>, Publisher<T>> function) {
return withConnectionReactive(binaryConnection, function);
}
private <T, K, V> T withConnection(final StatefulRedisClusterConnection<K, V> connection, final Function<StatefulRedisClusterConnection<K, V>, T> function) {
try {
return circuitBreaker.executeCheckedSupplier(() -> retry.executeCallable(() -> function.apply(connection)));
} catch (final Throwable t) {
if (t instanceof RedisException) {
throw (RedisException) t;
} else {
throw new RedisException(t);
}
}
private <K, V> void useConnection(final StatefulRedisClusterConnection<K, V> connection,
final Consumer<StatefulRedisClusterConnection<K, V>> consumer) {
try {
circuitBreaker.executeCheckedRunnable(() -> retry.executeRunnable(() -> consumer.accept(connection)));
} catch (final Throwable t) {
if (t instanceof RedisException) {
throw (RedisException) t;
} else {
throw new RedisException(t);
}
}
}
public FaultTolerantPubSubConnection<String, String> createPubSubConnection() {
final StatefulRedisClusterPubSubConnection<String, String> pubSubConnection = clusterClient.connectPubSub();
pubSubConnections.add(pubSubConnection);
return new FaultTolerantPubSubConnection<>(name, pubSubConnection, circuitBreaker, retry);
private <T, K, V> T withConnection(final StatefulRedisClusterConnection<K, V> connection,
final Function<StatefulRedisClusterConnection<K, V>, T> function) {
try {
return circuitBreaker.executeCheckedSupplier(() -> retry.executeCallable(() -> function.apply(connection)));
} catch (final Throwable t) {
if (t instanceof RedisException) {
throw (RedisException) t;
} else {
throw new RedisException(t);
}
}
}
private <T, K, V> Publisher<T> withConnectionReactive(final StatefulRedisClusterConnection<K, V> connection,
final Function<StatefulRedisClusterConnection<K, V>, Publisher<T>> function) {
return Flux.from(function.apply(connection))
.transformDeferred(RetryOperator.of(retry))
.transformDeferred(CircuitBreakerOperator.of(circuitBreaker));
}
public FaultTolerantPubSubConnection<String, String> createPubSubConnection() {
final StatefulRedisClusterPubSubConnection<String, String> pubSubConnection = clusterClient.connectPubSub();
pubSubConnections.add(pubSubConnection);
return new FaultTolerantPubSubConnection<>(name, pubSubConnection, circuitBreaker, retry);
}
}

View File

@ -26,7 +26,7 @@ import software.amazon.awssdk.services.dynamodb.model.BatchWriteItemResponse;
import software.amazon.awssdk.services.dynamodb.model.ScanRequest;
import software.amazon.awssdk.services.dynamodb.model.WriteRequest;
public class AbstractDynamoDbStore {
public abstract class AbstractDynamoDbStore {
private final DynamoDbClient dynamoDbClient;

View File

@ -1,5 +1,5 @@
/*
* Copyright 2013-2021 Signal Messenger, LLC
* Copyright 2013-2022 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
@ -22,6 +22,7 @@ import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Timer;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
import java.util.ArrayList;
@ -34,23 +35,32 @@ import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import org.reactivestreams.Publisher;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.redis.ClusterLuaScript;
import org.whispersystems.textsecuregcm.redis.FaultTolerantPubSubConnection;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.textsecuregcm.util.RedisClusterUtil;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.scheduler.Schedulers;
public class MessagesCache extends RedisClusterPubSubAdapter<String, String> implements Managed {
private final FaultTolerantRedisCluster readDeleteCluster;
private final FaultTolerantPubSubConnection<String, String> pubSubConnection;
private final Clock clock;
private final ExecutorService notificationExecutorService;
private final ExecutorService messageDeletionExecutorService;
private final ClusterLuaScript insertScript;
private final ClusterLuaScript removeByGuidScript;
@ -79,22 +89,23 @@ public class MessagesCache extends RedisClusterPubSubAdapter<String, String> imp
private static final String QUEUE_KEYSPACE_PREFIX = "__keyspace@0__:user_queue::";
private static final String PERSISTING_KEYSPACE_PREFIX = "__keyspace@0__:user_queue_persisting::";
private static final Duration MAX_EPHEMERAL_MESSAGE_DELAY = Duration.ofSeconds(10);
@VisibleForTesting
static final Duration MAX_EPHEMERAL_MESSAGE_DELAY = Duration.ofSeconds(10);
private static final String REMOVE_TIMER_NAME = name(MessagesCache.class, "remove");
private static final String REMOVE_METHOD_TAG = "method";
private static final String REMOVE_METHOD_UUID = "uuid";
private static final int PAGE_SIZE = 100;
private static final Logger logger = LoggerFactory.getLogger(MessagesCache.class);
public MessagesCache(final FaultTolerantRedisCluster insertCluster, final FaultTolerantRedisCluster readDeleteCluster,
final ExecutorService notificationExecutorService) throws IOException {
final Clock clock, final ExecutorService notificationExecutorService,
final ExecutorService messageDeletionExecutorService) throws IOException {
this.readDeleteCluster = readDeleteCluster;
this.pubSubConnection = readDeleteCluster.createPubSubConnection();
this.clock = clock;
this.notificationExecutorService = notificationExecutorService;
this.messageDeletionExecutorService = messageDeletionExecutorService;
this.insertScript = ClusterLuaScript.fromResource(insertCluster, "lua/insert_item.lua", ScriptOutputType.INTEGER);
this.removeByGuidScript = ClusterLuaScript.fromResource(readDeleteCluster, "lua/remove_item_by_guid.lua",
@ -147,33 +158,39 @@ public class MessagesCache extends RedisClusterPubSubAdapter<String, String> imp
guid.toString().getBytes(StandardCharsets.UTF_8))));
}
public Optional<MessageProtos.Envelope> remove(final UUID destinationUuid, final long destinationDevice,
public CompletableFuture<Optional<MessageProtos.Envelope>> remove(final UUID destinationUuid,
final long destinationDevice,
final UUID messageGuid) {
return remove(destinationUuid, destinationDevice, List.of(messageGuid)).stream().findFirst();
return remove(destinationUuid, destinationDevice, List.of(messageGuid))
.thenApply(removed -> removed.isEmpty() ? Optional.empty() : Optional.of(removed.get(0)));
}
@SuppressWarnings("unchecked")
public List<MessageProtos.Envelope> remove(final UUID destinationUuid, final long destinationDevice,
public CompletableFuture<List<MessageProtos.Envelope>> remove(final UUID destinationUuid,
final long destinationDevice,
final List<UUID> messageGuids) {
final List<byte[]> serialized = (List<byte[]>) Metrics.timer(REMOVE_TIMER_NAME, REMOVE_METHOD_TAG,
REMOVE_METHOD_UUID).record(() ->
removeByGuidScript.executeBinary(List.of(getMessageQueueKey(destinationUuid, destinationDevice),
return removeByGuidScript.executeBinaryAsync(List.of(getMessageQueueKey(destinationUuid, destinationDevice),
getMessageQueueMetadataKey(destinationUuid, destinationDevice),
getQueueIndexKey(destinationUuid, destinationDevice)),
messageGuids.stream().map(guid -> guid.toString().getBytes(StandardCharsets.UTF_8))
.collect(Collectors.toList())));
.collect(Collectors.toList()))
.thenApplyAsync(result -> {
List<byte[]> serialized = (List<byte[]>) result;
final List<MessageProtos.Envelope> removedMessages = new ArrayList<>(serialized.size());
final List<MessageProtos.Envelope> removedMessages = new ArrayList<>(serialized.size());
for (final byte[] bytes : serialized) {
try {
removedMessages.add(MessageProtos.Envelope.parseFrom(bytes));
} catch (final InvalidProtocolBufferException e) {
logger.warn("Failed to parse envelope", e);
}
}
for (final byte[] bytes : serialized) {
try {
removedMessages.add(MessageProtos.Envelope.parseFrom(bytes));
} catch (final InvalidProtocolBufferException e) {
logger.warn("Failed to parse envelope", e);
}
}
return removedMessages;
return removedMessages;
}, messageDeletionExecutorService);
}
public boolean hasMessages(final UUID destinationUuid, final long destinationDevice) {
@ -181,50 +198,110 @@ public class MessagesCache extends RedisClusterPubSubAdapter<String, String> imp
connection -> connection.sync().zcard(getMessageQueueKey(destinationUuid, destinationDevice)) > 0);
}
@SuppressWarnings("unchecked")
public List<MessageProtos.Envelope> get(final UUID destinationUuid, final long destinationDevice, final int limit) {
return getMessagesTimer.record(() -> {
final List<byte[]> queueItems = (List<byte[]>) getItemsScript.executeBinary(
List.of(getMessageQueueKey(destinationUuid, destinationDevice),
getPersistInProgressKey(destinationUuid, destinationDevice)),
List.of(String.valueOf(limit).getBytes(StandardCharsets.UTF_8)));
public Publisher<MessageProtos.Envelope> get(final UUID destinationUuid, final long destinationDevice) {
final long earliestAllowableEphemeralTimestamp =
System.currentTimeMillis() - MAX_EPHEMERAL_MESSAGE_DELAY.toMillis();
final long earliestAllowableEphemeralTimestamp =
clock.millis() - MAX_EPHEMERAL_MESSAGE_DELAY.toMillis();
final List<MessageProtos.Envelope> messageEntities;
final List<UUID> staleEphemeralMessageGuids = new ArrayList<>();
final Flux<MessageProtos.Envelope> allMessages = getAllMessages(destinationUuid, destinationDevice)
.publish()
// We expect exactly two subscribers to this base flux:
// 1. the websocket that delivers messages to clients
// 2. an internal process to discard stale ephemeral messages
// The discard subscriber will subscribe immediately, but we dont want to do any work if the
// websocket never subscribes.
.autoConnect(2);
if (queueItems.size() % 2 == 0) {
messageEntities = new ArrayList<>(queueItems.size() / 2);
final Flux<MessageProtos.Envelope> messagesToPublish = allMessages
.filter(Predicate.not(envelope -> isStaleEphemeralMessage(envelope, earliestAllowableEphemeralTimestamp)));
for (int i = 0; i < queueItems.size() - 1; i += 2) {
try {
final MessageProtos.Envelope message = MessageProtos.Envelope.parseFrom(queueItems.get(i));
if (message.getEphemeral() && message.getTimestamp() < earliestAllowableEphemeralTimestamp) {
staleEphemeralMessageGuids.add(UUID.fromString(message.getServerGuid()));
continue;
}
final Flux<MessageProtos.Envelope> staleEphemeralMessages = allMessages
.filter(envelope -> isStaleEphemeralMessage(envelope, earliestAllowableEphemeralTimestamp));
messageEntities.add(message);
} catch (InvalidProtocolBufferException e) {
logger.warn("Failed to parse envelope", e);
discardStaleEphemeralMessages(destinationUuid, destinationDevice, staleEphemeralMessages);
return messagesToPublish;
}
private static boolean isStaleEphemeralMessage(final MessageProtos.Envelope message,
long earliestAllowableTimestamp) {
return message.hasEphemeral() && message.getEphemeral() && message.getTimestamp() < earliestAllowableTimestamp;
}
private void discardStaleEphemeralMessages(final UUID destinationUuid, final long destinationDevice,
Flux<MessageProtos.Envelope> staleEphemeralMessages) {
staleEphemeralMessages
.map(e -> UUID.fromString(e.getServerGuid()))
.buffer(PAGE_SIZE)
.subscribeOn(Schedulers.boundedElastic())
.subscribe(staleEphemeralMessageGuids ->
remove(destinationUuid, destinationDevice, staleEphemeralMessageGuids)
.thenAccept(removedMessages -> staleEphemeralMessagesCounter.increment(removedMessages.size())),
e -> logger.warn("Could not remove stale ephemeral messages from cache", e));
}
@VisibleForTesting
Flux<MessageProtos.Envelope> getAllMessages(final UUID destinationUuid, final long destinationDevice) {
// fetch messages by page
return getNextMessagePage(destinationUuid, destinationDevice, -1)
.expand(queueItemsAndLastMessageId -> {
// expand() is breadth-first, so each page will be published in order
if (queueItemsAndLastMessageId.first().isEmpty()) {
return Mono.empty();
}
}
} else {
logger.error("\"Get messages\" operation returned a list with a non-even number of elements.");
messageEntities = Collections.emptyList();
}
try {
remove(destinationUuid, destinationDevice, staleEphemeralMessageGuids);
staleEphemeralMessagesCounter.increment(staleEphemeralMessageGuids.size());
} catch (final Throwable e) {
logger.warn("Could not remove stale ephemeral messages from cache", e);
}
return getNextMessagePage(destinationUuid, destinationDevice, queueItemsAndLastMessageId.second());
})
.limitRate(1)
// we want to ensure we dont accidentally block the Lettuce/netty i/o executors
.publishOn(Schedulers.boundedElastic())
.map(Pair::first)
.flatMapIterable(queueItems -> {
final List<MessageProtos.Envelope> envelopes = new ArrayList<>(queueItems.size() / 2);
return messageEntities;
});
for (int i = 0; i < queueItems.size() - 1; i += 2) {
try {
final MessageProtos.Envelope message = MessageProtos.Envelope.parseFrom(queueItems.get(i));
envelopes.add(message);
} catch (InvalidProtocolBufferException e) {
logger.warn("Failed to parse envelope", e);
}
}
return envelopes;
});
}
private Flux<Pair<List<byte[]>, Long>> getNextMessagePage(final UUID destinationUuid, final long destinationDevice,
long messageId) {
return getItemsScript.executeBinaryReactive(
List.of(getMessageQueueKey(destinationUuid, destinationDevice),
getPersistInProgressKey(destinationUuid, destinationDevice)),
List.of(String.valueOf(PAGE_SIZE).getBytes(StandardCharsets.UTF_8),
String.valueOf(messageId).getBytes(StandardCharsets.UTF_8)))
.map(result -> {
logger.trace("Processing page: {}", messageId);
@SuppressWarnings("unchecked")
List<byte[]> queueItems = (List<byte[]>) result;
if (queueItems.isEmpty()) {
return new Pair<>(Collections.emptyList(), null);
}
if (queueItems.size() % 2 != 0) {
logger.error("\"Get messages\" operation returned a list with a non-even number of elements.");
return new Pair<>(Collections.emptyList(), null);
}
final long lastMessageId = Long.parseLong(
new String(queueItems.get(queueItems.size() - 1), StandardCharsets.UTF_8));
return new Pair<>(queueItems, lastMessageId);
});
}
@VisibleForTesting

View File

@ -1,5 +1,5 @@
/*
* Copyright 2021 Signal Messenger, LLC
* Copyright 2021-2022 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
@ -17,19 +17,24 @@ import java.time.Duration;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.UUID;
import java.util.stream.Collectors;
import javax.annotation.Nonnull;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.function.Predicate;
import org.reactivestreams.Publisher;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.util.AttributeValues;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient;
import software.amazon.awssdk.services.dynamodb.DynamoDbClient;
import software.amazon.awssdk.services.dynamodb.model.AttributeValue;
import software.amazon.awssdk.services.dynamodb.model.DeleteItemRequest;
import software.amazon.awssdk.services.dynamodb.model.DeleteItemResponse;
import software.amazon.awssdk.services.dynamodb.model.DeleteRequest;
import software.amazon.awssdk.services.dynamodb.model.PutRequest;
import software.amazon.awssdk.services.dynamodb.model.QueryRequest;
@ -48,22 +53,25 @@ public class MessagesDynamoDb extends AbstractDynamoDbStore {
private static final String KEY_ENVELOPE_BYTES = "EB";
private final Timer storeTimer = timer(name(getClass(), "store"));
private final Timer loadTimer = timer(name(getClass(), "load"));
private final Timer deleteByGuid = timer(name(getClass(), "delete", "guid"));
private final Timer deleteByKey = timer(name(getClass(), "delete", "key"));
private final Timer deleteByAccount = timer(name(getClass(), "delete", "account"));
private final Timer deleteByDevice = timer(name(getClass(), "delete", "device"));
private final DynamoDbAsyncClient dbAsyncClient;
private final String tableName;
private final Duration timeToLive;
private final ExecutorService messageDeletionExecutor;
private static final Logger logger = LoggerFactory.getLogger(MessagesDynamoDb.class);
public MessagesDynamoDb(DynamoDbClient dynamoDb, String tableName, Duration timeToLive) {
public MessagesDynamoDb(DynamoDbClient dynamoDb, DynamoDbAsyncClient dynamoDbAsyncClient, String tableName,
Duration timeToLive, ExecutorService messageDeletionExecutor) {
super(dynamoDb);
this.dbAsyncClient = dynamoDbAsyncClient;
this.tableName = tableName;
this.timeToLive = timeToLive;
this.messageDeletionExecutor = messageDeletionExecutor;
}
public void store(final List<MessageProtos.Envelope> messages, final UUID destinationAccountUuid, final long destinationDeviceId) {
@ -95,105 +103,105 @@ public class MessagesDynamoDb extends AbstractDynamoDbStore {
executeTableWriteItemsUntilComplete(Map.of(tableName, writeItems));
}
public List<MessageProtos.Envelope> load(final UUID destinationAccountUuid, final long destinationDeviceId, final int requestedNumberOfMessagesToFetch) {
return loadTimer.record(() -> {
final int numberOfMessagesToFetch = Math.min(requestedNumberOfMessagesToFetch, RESULT_SET_CHUNK_SIZE);
final AttributeValue partitionKey = convertPartitionKey(destinationAccountUuid);
final QueryRequest queryRequest = QueryRequest.builder()
.tableName(tableName)
.consistentRead(true)
.keyConditionExpression("#part = :part AND begins_with ( #sort , :sortprefix )")
.expressionAttributeNames(Map.of(
"#part", KEY_PARTITION,
"#sort", KEY_SORT))
.expressionAttributeValues(Map.of(
":part", partitionKey,
":sortprefix", convertDestinationDeviceIdToSortKeyPrefix(destinationDeviceId)))
.limit(numberOfMessagesToFetch)
.build();
List<MessageProtos.Envelope> messageEntities = new ArrayList<>(numberOfMessagesToFetch);
for (Map<String, AttributeValue> message : db().queryPaginator(queryRequest).items()) {
try {
messageEntities.add(convertItemToEnvelope(message));
} catch (final InvalidProtocolBufferException e) {
logger.error("Failed to parse envelope", e);
}
public Publisher<MessageProtos.Envelope> load(final UUID destinationAccountUuid, final long destinationDeviceId,
final Integer limit) {
if (messageEntities.size() == numberOfMessagesToFetch) {
// queryPaginator() uses limit() as the page size, not as an absolute limit
// but a page might be smaller than limit, because a page is capped at 1 MB
break;
}
}
return messageEntities;
});
}
final AttributeValue partitionKey = convertPartitionKey(destinationAccountUuid);
final QueryRequest.Builder queryRequestBuilder = QueryRequest.builder()
.tableName(tableName)
.consistentRead(true)
.keyConditionExpression("#part = :part AND begins_with ( #sort , :sortprefix )")
.expressionAttributeNames(Map.of(
"#part", KEY_PARTITION,
"#sort", KEY_SORT))
.expressionAttributeValues(Map.of(
":part", partitionKey,
":sortprefix", convertDestinationDeviceIdToSortKeyPrefix(destinationDeviceId)));
public Optional<MessageProtos.Envelope> deleteMessageByDestinationAndGuid(final UUID destinationAccountUuid,
final UUID messageUuid) {
return deleteByGuid.record(() -> {
final AttributeValue partitionKey = convertPartitionKey(destinationAccountUuid);
final QueryRequest queryRequest = QueryRequest.builder()
.tableName(tableName)
.indexName(LOCAL_INDEX_MESSAGE_UUID_NAME)
.projectionExpression(KEY_SORT)
.consistentRead(true)
.keyConditionExpression("#part = :part AND #uuid = :uuid")
.expressionAttributeNames(Map.of(
"#part", KEY_PARTITION,
"#uuid", LOCAL_INDEX_MESSAGE_UUID_KEY_SORT))
.expressionAttributeValues(Map.of(
":part", partitionKey,
":uuid", convertLocalIndexMessageUuidSortKey(messageUuid)))
.build();
return deleteItemsMatchingQueryAndReturnFirstOneActuallyDeleted(partitionKey, queryRequest);
});
}
public Optional<MessageProtos.Envelope> deleteMessage(final UUID destinationAccountUuid,
final long destinationDeviceId, final UUID messageUuid, final long serverTimestamp) {
return deleteByKey.record(() -> {
final AttributeValue partitionKey = convertPartitionKey(destinationAccountUuid);
final AttributeValue sortKey = convertSortKey(destinationDeviceId, serverTimestamp, messageUuid);
DeleteItemRequest.Builder deleteItemRequest = DeleteItemRequest.builder()
.tableName(tableName)
.key(Map.of(KEY_PARTITION, partitionKey, KEY_SORT, sortKey))
.returnValues(ReturnValue.ALL_OLD);
final DeleteItemResponse deleteItemResponse = db().deleteItem(deleteItemRequest.build());
if (deleteItemResponse.attributes() != null && deleteItemResponse.attributes().containsKey(KEY_PARTITION)) {
try {
return Optional.of(convertItemToEnvelope(deleteItemResponse.attributes()));
} catch (final InvalidProtocolBufferException e) {
logger.error("Failed to parse envelope", e);
return Optional.empty();
}
}
return Optional.empty();
});
}
@Nonnull
private Optional<MessageProtos.Envelope> deleteItemsMatchingQueryAndReturnFirstOneActuallyDeleted(AttributeValue partitionKey, QueryRequest queryRequest) {
Optional<MessageProtos.Envelope> result = Optional.empty();
for (Map<String, AttributeValue> item : db().queryPaginator(queryRequest).items()) {
final byte[] rangeKeyValue = item.get(KEY_SORT).b().asByteArray();
DeleteItemRequest.Builder deleteItemRequest = DeleteItemRequest.builder()
.tableName(tableName)
.key(Map.of(KEY_PARTITION, partitionKey, KEY_SORT, AttributeValues.fromByteArray(rangeKeyValue)));
if (result.isEmpty()) {
deleteItemRequest.returnValues(ReturnValue.ALL_OLD);
}
final DeleteItemResponse deleteItemResponse = db().deleteItem(deleteItemRequest.build());
if (deleteItemResponse.attributes() != null && deleteItemResponse.attributes().containsKey(KEY_PARTITION)) {
try {
result = Optional.of(convertItemToEnvelope(deleteItemResponse.attributes()));
} catch (final InvalidProtocolBufferException e) {
logger.error("Failed to parse envelope", e);
}
}
if (limit != null) {
// some callers dont take advantage of reactive streams, so we want to support limiting the fetch size. Otherwise,
// we could fetch up to 1 MB (likely >1,000 messages) and discard 90% of them
queryRequestBuilder.limit(Math.min(RESULT_SET_CHUNK_SIZE, limit));
}
return result;
final QueryRequest queryRequest = queryRequestBuilder.build();
return dbAsyncClient.queryPaginator(queryRequest).items()
.map(message -> {
try {
return convertItemToEnvelope(message);
} catch (final InvalidProtocolBufferException e) {
logger.error("Failed to parse envelope", e);
return null;
}
})
.filter(Predicate.not(Objects::isNull));
}
public CompletableFuture<Optional<MessageProtos.Envelope>> deleteMessageByDestinationAndGuid(
final UUID destinationAccountUuid, final UUID messageUuid) {
final AttributeValue partitionKey = convertPartitionKey(destinationAccountUuid);
final QueryRequest queryRequest = QueryRequest.builder()
.tableName(tableName)
.indexName(LOCAL_INDEX_MESSAGE_UUID_NAME)
.projectionExpression(KEY_SORT)
.consistentRead(true)
.keyConditionExpression("#part = :part AND #uuid = :uuid")
.expressionAttributeNames(Map.of(
"#part", KEY_PARTITION,
"#uuid", LOCAL_INDEX_MESSAGE_UUID_KEY_SORT))
.expressionAttributeValues(Map.of(
":part", partitionKey,
":uuid", convertLocalIndexMessageUuidSortKey(messageUuid)))
.build();
// because we are filtering on message UUID, this query should return at most one item,
// but its simpler to handle the full stream and return the last item
return Flux.from(dbAsyncClient.queryPaginator(queryRequest).items())
.flatMap(item -> Mono.fromCompletionStage(dbAsyncClient.deleteItem(DeleteItemRequest.builder()
.tableName(tableName)
.key(Map.of(KEY_PARTITION, partitionKey, KEY_SORT,
AttributeValues.fromByteArray(item.get(KEY_SORT).b().asByteArray())))
.returnValues(ReturnValue.ALL_OLD)
.build())))
.mapNotNull(deleteItemResponse -> {
try {
if (deleteItemResponse.attributes() != null && deleteItemResponse.attributes().containsKey(KEY_PARTITION)) {
return convertItemToEnvelope(deleteItemResponse.attributes());
}
} catch (final InvalidProtocolBufferException e) {
logger.error("Failed to parse envelope", e);
}
return null;
})
.last()
.toFuture()
.thenApply(Optional::ofNullable);
}
public CompletableFuture<Optional<MessageProtos.Envelope>> deleteMessage(final UUID destinationAccountUuid,
final long destinationDeviceId, final UUID messageUuid, final long serverTimestamp) {
final AttributeValue partitionKey = convertPartitionKey(destinationAccountUuid);
final AttributeValue sortKey = convertSortKey(destinationDeviceId, serverTimestamp, messageUuid);
DeleteItemRequest.Builder deleteItemRequest = DeleteItemRequest.builder()
.tableName(tableName)
.key(Map.of(KEY_PARTITION, partitionKey, KEY_SORT, sortKey))
.returnValues(ReturnValue.ALL_OLD);
return dbAsyncClient.deleteItem(deleteItemRequest.build())
.thenApplyAsync(deleteItemResponse -> {
if (deleteItemResponse.attributes() != null && deleteItemResponse.attributes().containsKey(KEY_PARTITION)) {
try {
return Optional.of(convertItemToEnvelope(deleteItemResponse.attributes()));
} catch (final InvalidProtocolBufferException e) {
logger.error("Failed to parse envelope", e);
}
}
return Optional.empty();
}, messageDeletionExecutor);
}
public void deleteAllMessagesForAccount(final UUID destinationAccountUuid) {
@ -248,7 +256,7 @@ public class MessagesDynamoDb extends AbstractDynamoDbStore {
KEY_PARTITION, partitionKey,
KEY_SORT, item.get(KEY_SORT))).build())
.build())
.collect(Collectors.toList());
.toList();
executeTableWriteItemsUntilComplete(Map.of(tableName, deletes));
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2013-2021 Signal Messenger, LLC
* Copyright 2013-2022 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
@ -9,19 +9,30 @@ import static com.codahale.metrics.MetricRegistry.name;
import com.codahale.metrics.Meter;
import com.codahale.metrics.MetricRegistry;
import com.codahale.metrics.SharedMetricRegistries;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import org.reactivestreams.Publisher;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
import org.whispersystems.textsecuregcm.util.Constants;
import org.whispersystems.textsecuregcm.util.Pair;
import reactor.core.publisher.Flux;
public class MessagesManager {
private static final int RESULT_SET_CHUNK_SIZE = 100;
private static final Logger logger = LoggerFactory.getLogger(MessagesManager.class);
private static final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME);
private static final Meter cacheHitByGuidMeter = metricRegistry.meter(name(MessagesManager.class, "cacheHitByGuid"));
private static final Meter cacheMissByGuidMeter = metricRegistry.meter(
@ -55,18 +66,32 @@ public class MessagesManager {
return messagesCache.hasMessages(destinationUuid, destinationDevice);
}
public Pair<List<Envelope>, Boolean> getMessagesForDevice(UUID destinationUuid, long destinationDevice, final boolean cachedMessagesOnly) {
List<Envelope> messageList = new ArrayList<>();
public Pair<List<Envelope>, Boolean> getMessagesForDevice(UUID destinationUuid, long destinationDevice,
boolean cachedMessagesOnly) {
if (!cachedMessagesOnly) {
messageList.addAll(messagesDynamoDb.load(destinationUuid, destinationDevice, RESULT_SET_CHUNK_SIZE));
}
final List<Envelope> envelopes = Flux.from(
getMessagesForDevice(destinationUuid, destinationDevice, RESULT_SET_CHUNK_SIZE, cachedMessagesOnly))
.take(RESULT_SET_CHUNK_SIZE, true)
.collectList()
.blockOptional().orElse(Collections.emptyList());
if (messageList.size() < RESULT_SET_CHUNK_SIZE) {
messageList.addAll(messagesCache.get(destinationUuid, destinationDevice, RESULT_SET_CHUNK_SIZE - messageList.size()));
}
return new Pair<>(envelopes, envelopes.size() >= RESULT_SET_CHUNK_SIZE);
}
return new Pair<>(messageList, messageList.size() >= RESULT_SET_CHUNK_SIZE);
public Publisher<Envelope> getMessagesForDeviceReactive(UUID destinationUuid, long destinationDevice,
final boolean cachedMessagesOnly) {
return getMessagesForDevice(destinationUuid, destinationDevice, null, cachedMessagesOnly);
}
private Publisher<Envelope> getMessagesForDevice(UUID destinationUuid, long destinationDevice,
@Nullable Integer limit, final boolean cachedMessagesOnly) {
final Publisher<Envelope> dynamoPublisher =
cachedMessagesOnly ? Flux.empty() : messagesDynamoDb.load(destinationUuid, destinationDevice, limit);
final Publisher<Envelope> cachePublisher = messagesCache.get(destinationUuid, destinationDevice);
return Flux.concat(dynamoPublisher, cachePublisher);
}
public void clear(UUID destinationUuid) {
@ -79,21 +104,25 @@ public class MessagesManager {
messagesDynamoDb.deleteAllMessagesForDevice(destinationUuid, deviceId);
}
public Optional<Envelope> delete(UUID destinationUuid, long destinationDeviceId, UUID guid, Long serverTimestamp) {
Optional<Envelope> removed = messagesCache.remove(destinationUuid, destinationDeviceId, guid);
public CompletableFuture<Optional<Envelope>> delete(UUID destinationUuid, long destinationDeviceId, UUID guid,
@Nullable Long serverTimestamp) {
return messagesCache.remove(destinationUuid, destinationDeviceId, guid)
.thenCompose(removed -> {
if (removed.isEmpty()) {
if (serverTimestamp == null) {
removed = messagesDynamoDb.deleteMessageByDestinationAndGuid(destinationUuid, guid);
} else {
removed = messagesDynamoDb.deleteMessage(destinationUuid, destinationDeviceId, guid, serverTimestamp);
}
cacheMissByGuidMeter.mark();
} else {
cacheHitByGuidMeter.mark();
}
if (removed.isPresent()) {
cacheHitByGuidMeter.mark();
return CompletableFuture.completedFuture(removed);
}
return removed;
cacheMissByGuidMeter.mark();
if (serverTimestamp == null) {
return messagesDynamoDb.deleteMessageByDestinationAndGuid(destinationUuid, guid);
} else {
return messagesDynamoDb.deleteMessage(destinationUuid, destinationDeviceId, guid, serverTimestamp);
}
});
}
/**
@ -112,10 +141,15 @@ public class MessagesManager {
final List<UUID> messageGuids = messages.stream().map(message -> UUID.fromString(message.getServerGuid()))
.collect(Collectors.toList());
int messagesRemovedFromCache = messagesCache.remove(destinationUuid, destinationDeviceId, messageGuids).size();
persistMessageMeter.mark(nonEphemeralMessages.size());
int messagesRemovedFromCache = 0;
try {
messagesRemovedFromCache = messagesCache.remove(destinationUuid, destinationDeviceId, messageGuids)
.get(30, TimeUnit.SECONDS).size();
persistMessageMeter.mark(nonEphemeralMessages.size());
} catch (InterruptedException | ExecutionException | TimeoutException e) {
logger.warn("Failed to remove messages from cache", e);
}
return messagesRemovedFromCache;
}
@ -129,4 +163,5 @@ public class MessagesManager {
public void removeMessageAvailabilityListener(final MessageAvailabilityListener listener) {
messagesCache.removeMessageAvailabilityListener(listener);
}
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2013-2021 Signal Messenger, LLC
* Copyright 2013-2022 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
@ -11,14 +11,19 @@ import com.codahale.metrics.Counter;
import com.codahale.metrics.MetricRegistry;
import com.codahale.metrics.SharedMetricRegistries;
import com.codahale.metrics.Timer;
import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Tags;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import org.eclipse.jetty.websocket.api.UpgradeResponse;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.push.NotPushRegisteredException;
import org.whispersystems.textsecuregcm.push.PushNotificationManager;
@ -32,32 +37,48 @@ import org.whispersystems.websocket.setup.WebSocketConnectListener;
public class AuthenticatedConnectListener implements WebSocketConnectListener {
private static final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME);
private static final Timer durationTimer = metricRegistry.timer(name(WebSocketConnection.class, "connected_duration" ));
private static final Timer unauthenticatedDurationTimer = metricRegistry.timer(name(WebSocketConnection.class, "unauthenticated_connection_duration"));
private static final Counter openWebsocketCounter = metricRegistry.counter(name(WebSocketConnection.class, "open_websockets"));
private static final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME);
private static final Timer durationTimer = metricRegistry.timer(
name(WebSocketConnection.class, "connected_duration"));
private static final Timer unauthenticatedDurationTimer = metricRegistry.timer(
name(WebSocketConnection.class, "unauthenticated_connection_duration"));
private static final Counter openWebsocketCounter = metricRegistry.counter(
name(WebSocketConnection.class, "open_websockets"));
private static final String OPEN_WEBSOCKET_COUNTER_NAME = MetricsUtil.name(WebSocketConnection.class,
"openWebsockets");
private static final long RENEW_PRESENCE_INTERVAL_MINUTES = 5;
private static final String REACTIVE_MESSAGE_QUEUE_EXPERIMENT_NAME = "reactive_message_queue_v1";
private static final Logger log = LoggerFactory.getLogger(AuthenticatedConnectListener.class);
private final ReceiptSender receiptSender;
private final MessagesManager messagesManager;
private final ReceiptSender receiptSender;
private final MessagesManager messagesManager;
private final PushNotificationManager pushNotificationManager;
private final ClientPresenceManager clientPresenceManager;
private final ScheduledExecutorService scheduledExecutorService;
private final ExperimentEnrollmentManager experimentEnrollmentManager;
private final AtomicInteger openReactiveWebSockets = new AtomicInteger(0);
private final AtomicInteger openStandardWebSockets = new AtomicInteger(0);
public AuthenticatedConnectListener(ReceiptSender receiptSender,
MessagesManager messagesManager,
PushNotificationManager pushNotificationManager,
ClientPresenceManager clientPresenceManager,
ScheduledExecutorService scheduledExecutorService)
{
this.receiptSender = receiptSender;
this.messagesManager = messagesManager;
ScheduledExecutorService scheduledExecutorService,
ExperimentEnrollmentManager experimentEnrollmentManager) {
this.receiptSender = receiptSender;
this.messagesManager = messagesManager;
this.pushNotificationManager = pushNotificationManager;
this.clientPresenceManager = clientPresenceManager;
this.scheduledExecutorService = scheduledExecutorService;
this.experimentEnrollmentManager = experimentEnrollmentManager;
Metrics.gauge(OPEN_WEBSOCKET_COUNTER_NAME, Tags.of("reactive", String.valueOf(true)), openReactiveWebSockets);
Metrics.gauge(OPEN_WEBSOCKET_COUNTER_NAME, Tags.of("reactive", String.valueOf(false)), openStandardWebSockets);
}
@Override
@ -66,43 +87,56 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener {
final AuthenticatedAccount auth = context.getAuthenticated(AuthenticatedAccount.class);
final Device device = auth.getAuthenticatedDevice();
final Timer.Context timer = durationTimer.time();
final boolean enrolledInReactiveMessageQueue = experimentEnrollmentManager.isEnrolled(
auth.getAccount().getUuid(),
REACTIVE_MESSAGE_QUEUE_EXPERIMENT_NAME);
final WebSocketConnection connection = new WebSocketConnection(receiptSender,
messagesManager, auth, device,
context.getClient(),
scheduledExecutorService);
scheduledExecutorService,
enrolledInReactiveMessageQueue);
openWebsocketCounter.inc();
if (enrolledInReactiveMessageQueue) {
openReactiveWebSockets.incrementAndGet();
} else {
openStandardWebSockets.incrementAndGet();
}
pushNotificationManager.handleMessagesRetrieved(auth.getAccount(), device, context.getClient().getUserAgent());
final AtomicReference<ScheduledFuture<?>> renewPresenceFutureReference = new AtomicReference<>();
context.addListener(new WebSocketSessionContext.WebSocketEventListener() {
@Override
public void onWebSocketClose(WebSocketSessionContext context, int statusCode, String reason) {
openWebsocketCounter.dec();
timer.stop();
final ScheduledFuture<?> renewPresenceFuture = renewPresenceFutureReference.get();
if (renewPresenceFuture != null) {
renewPresenceFuture.cancel(false);
}
connection.stop();
RedisOperation.unchecked(
() -> clientPresenceManager.clearPresence(auth.getAccount().getUuid(), device.getId()));
RedisOperation.unchecked(() -> {
messagesManager.removeMessageAvailabilityListener(connection);
if (messagesManager.hasCachedMessages(auth.getAccount().getUuid(), device.getId())) {
try {
pushNotificationManager.sendNewMessageNotification(auth.getAccount(), device.getId(), true);
} catch (NotPushRegisteredException ignored) {
}
}
});
context.addListener((closingContext, statusCode, reason) -> {
openWebsocketCounter.dec();
if (enrolledInReactiveMessageQueue) {
openReactiveWebSockets.decrementAndGet();
} else {
openStandardWebSockets.decrementAndGet();
}
timer.stop();
final ScheduledFuture<?> renewPresenceFuture = renewPresenceFutureReference.get();
if (renewPresenceFuture != null) {
renewPresenceFuture.cancel(false);
}
connection.stop();
RedisOperation.unchecked(
() -> clientPresenceManager.clearPresence(auth.getAccount().getUuid(), device.getId()));
RedisOperation.unchecked(() -> {
messagesManager.removeMessageAvailabilityListener(connection);
if (messagesManager.hasCachedMessages(auth.getAccount().getUuid(), device.getId())) {
try {
pushNotificationManager.sendNewMessageNotification(auth.getAccount(), device.getId(), true);
} catch (NotPushRegisteredException ignored) {
}
}
});
});
try {

View File

@ -1,12 +1,11 @@
/*
* Copyright 2013-2021 Signal Messenger, LLC
* Copyright 2013-2022 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.websocket;
import static com.codahale.metrics.MetricRegistry.name;
import static org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
import com.codahale.metrics.Histogram;
import com.codahale.metrics.Meter;
@ -34,11 +33,13 @@ import java.util.concurrent.atomic.AtomicReference;
import java.util.concurrent.atomic.LongAdder;
import javax.ws.rs.WebApplicationException;
import org.apache.commons.lang3.StringUtils;
import org.reactivestreams.Publisher;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.controllers.MessageController;
import org.whispersystems.textsecuregcm.controllers.NoSuchUserException;
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
import org.whispersystems.textsecuregcm.metrics.MessageMetrics;
import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
import org.whispersystems.textsecuregcm.push.DisplacedPresenceListener;
@ -49,13 +50,14 @@ import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.util.Constants;
import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.textsecuregcm.util.TimestampHeaderUtil;
import org.whispersystems.textsecuregcm.util.ua.ClientPlatform;
import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException;
import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil;
import org.whispersystems.websocket.WebSocketClient;
import org.whispersystems.websocket.messages.WebSocketResponseMessage;
import reactor.core.Disposable;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.scheduler.Scheduler;
import reactor.core.scheduler.Schedulers;
@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
public class WebSocketConnection implements MessageAvailabilityListener, DisplacedPresenceListener {
private static final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME);
@ -70,8 +72,6 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac
name(WebSocketConnection.class, "messagesPersisted"));
private static final Meter bytesSentMeter = metricRegistry.meter(name(WebSocketConnection.class, "bytes_sent"));
private static final Meter sendFailuresMeter = metricRegistry.meter(name(WebSocketConnection.class, "send_failures"));
private static final Meter discardedMessagesMeter = metricRegistry.meter(
name(WebSocketConnection.class, "discardedMessages"));
private static final String INITIAL_QUEUE_LENGTH_DISTRIBUTION_NAME = name(WebSocketConnection.class,
"initialQueueLength");
@ -85,11 +85,12 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac
"messageAvailableAfterClientClosed");
private static final String STATUS_CODE_TAG = "status";
private static final String STATUS_MESSAGE_TAG = "message";
private static final String REACTIVE_TAG = "reactive";
private static final long SLOW_DRAIN_THRESHOLD = 10_000;
@VisibleForTesting
static final int MAX_DESKTOP_MESSAGE_SIZE = 1024 * 1024;
static final int MESSAGE_PUBLISHER_LIMIT_RATE = 100;
@VisibleForTesting
static final int MAX_CONSECUTIVE_RETRIES = 5;
@ -111,18 +112,19 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac
private final ScheduledExecutorService scheduledExecutorService;
private final boolean isDesktopClient;
private final Semaphore processStoredMessagesSemaphore = new Semaphore(1);
private final AtomicReference<StoredMessageState> storedMessageState = new AtomicReference<>(
StoredMessageState.PERSISTED_NEW_MESSAGES_AVAILABLE);
private final AtomicBoolean sentInitialQueueEmptyMessage = new AtomicBoolean(false);
private final LongAdder sentMessageCounter = new LongAdder();
private final AtomicLong queueDrainStartTime = new AtomicLong();
private final AtomicInteger consecutiveRetries = new AtomicInteger();
private final AtomicReference<ScheduledFuture<?>> retryFuture = new AtomicReference<>();
private final AtomicInteger consecutiveRetries = new AtomicInteger();
private final AtomicReference<ScheduledFuture<?>> retryFuture = new AtomicReference<>();
private final AtomicReference<Disposable> messageSubscription = new AtomicReference<>();
private final Random random = new Random();
private final boolean useReactive;
private Scheduler reactiveScheduler;
private enum StoredMessageState {
EMPTY,
@ -135,7 +137,28 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac
AuthenticatedAccount auth,
Device device,
WebSocketClient client,
ScheduledExecutorService scheduledExecutorService) {
ScheduledExecutorService scheduledExecutorService,
boolean useReactive) {
this(receiptSender,
messagesManager,
auth,
device,
client,
scheduledExecutorService,
useReactive,
Schedulers.boundedElastic());
}
@VisibleForTesting
WebSocketConnection(ReceiptSender receiptSender,
MessagesManager messagesManager,
AuthenticatedAccount auth,
Device device,
WebSocketClient client,
ScheduledExecutorService scheduledExecutorService,
boolean useReactive,
Scheduler reactiveScheduler) {
this(receiptSender,
messagesManager,
@ -143,7 +166,9 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac
device,
client,
DEFAULT_SEND_FUTURES_TIMEOUT_MILLIS,
scheduledExecutorService);
scheduledExecutorService,
useReactive,
reactiveScheduler);
}
@VisibleForTesting
@ -153,7 +178,9 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac
Device device,
WebSocketClient client,
int sendFuturesTimeoutMillis,
ScheduledExecutorService scheduledExecutorService) {
ScheduledExecutorService scheduledExecutorService,
boolean useReactive,
Scheduler reactiveScheduler) {
this.receiptSender = receiptSender;
this.messagesManager = messagesManager;
@ -162,16 +189,8 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac
this.client = client;
this.sendFuturesTimeoutMillis = sendFuturesTimeoutMillis;
this.scheduledExecutorService = scheduledExecutorService;
Optional<ClientPlatform> maybePlatform;
try {
maybePlatform = Optional.of(UserAgentUtil.parseUserAgentString(client.getUserAgent()).getPlatform());
} catch (final UnrecognizedUserAgentException e) {
maybePlatform = Optional.empty();
}
this.isDesktopClient = maybePlatform.map(platform -> platform == ClientPlatform.DESKTOP).orElse(false);
this.useReactive = useReactive;
this.reactiveScheduler = reactiveScheduler;
}
public void start() {
@ -186,10 +205,15 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac
future.cancel(false);
}
final Disposable subscription = messageSubscription.get();
if (subscription != null) {
subscription.dispose();
}
client.close(1000, "OK");
}
private CompletableFuture<WebSocketResponseMessage> sendMessage(final Envelope message, final Optional<StoredMessageInfo> storedMessageInfo) {
private CompletableFuture<?> sendMessage(final Envelope message, StoredMessageInfo storedMessageInfo) {
// clear ephemeral field from the envelope
final Optional<byte[]> body = Optional.ofNullable(message.toBuilder().clearEphemeral().build().toByteArray());
@ -199,33 +223,43 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac
MessageMetrics.measureAccountEnvelopeUuidMismatches(auth.getAccount(), message);
// X-Signal-Key: false must be sent until Android stops assuming it missing means true
return client.sendRequest("PUT", "/api/v1/message", List.of("X-Signal-Key: false", TimestampHeaderUtil.getTimestampHeader()), body).whenComplete((response, throwable) -> {
if (throwable == null) {
if (isSuccessResponse(response)) {
if (storedMessageInfo.isPresent()) {
messagesManager.delete(auth.getAccount().getUuid(), device.getId(), storedMessageInfo.get().getGuid(), storedMessageInfo.get().getServerTimestamp());
}
if (message.getType() != Envelope.Type.SERVER_DELIVERY_RECEIPT) {
recordMessageDeliveryDuration(message.getTimestamp(), device);
sendDeliveryReceiptFor(message);
}
} else {
final List<Tag> tags = new ArrayList<>(
List.of(Tag.of(STATUS_CODE_TAG, String.valueOf(response.getStatus())),
UserAgentTagUtil.getPlatformTag(client.getUserAgent())));
// TODO Remove this once we've identified the cause of message rejections from desktop clients
if (StringUtils.isNotBlank(response.getMessage())) {
tags.add(Tag.of(STATUS_MESSAGE_TAG, response.getMessage()));
return client.sendRequest("PUT", "/api/v1/message",
List.of("X-Signal-Key: false", TimestampHeaderUtil.getTimestampHeader()), body)
.whenComplete((ignored, throwable) -> {
if (throwable != null) {
sendFailuresMeter.mark();
}
}).thenCompose(response -> {
final CompletableFuture<?> result;
if (isSuccessResponse(response)) {
Metrics.counter(NON_SUCCESS_RESPONSE_COUNTER_NAME, tags).increment();
}
} else {
sendFailuresMeter.mark();
}
});
result = messagesManager.delete(auth.getAccount().getUuid(), device.getId(),
storedMessageInfo.guid(), storedMessageInfo.serverTimestamp());
if (message.getType() != Envelope.Type.SERVER_DELIVERY_RECEIPT) {
recordMessageDeliveryDuration(message.getTimestamp(), device);
sendDeliveryReceiptFor(message);
}
} else {
final List<Tag> tags = new ArrayList<>(
List.of(
Tag.of(STATUS_CODE_TAG, String.valueOf(response.getStatus())),
UserAgentTagUtil.getPlatformTag(client.getUserAgent()),
Tag.of(REACTIVE_TAG, String.valueOf(useReactive))
));
// TODO Remove this once we've identified the cause of message rejections from desktop clients
if (StringUtils.isNotBlank(response.getMessage())) {
tags.add(Tag.of(STATUS_MESSAGE_TAG, response.getMessage()));
}
Metrics.counter(NON_SUCCESS_RESPONSE_COUNTER_NAME, tags).increment();
result = CompletableFuture.completedFuture(null);
}
return result;
});
}
public static void recordMessageDeliveryDuration(long timestamp, Device messageDestinationDevice) {
@ -260,65 +294,96 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac
@VisibleForTesting
void processStoredMessages() {
if (processStoredMessagesSemaphore.tryAcquire()) {
final StoredMessageState state = storedMessageState.getAndSet(StoredMessageState.EMPTY);
final CompletableFuture<Void> queueClearedFuture = new CompletableFuture<>();
sendNextMessagePage(state != StoredMessageState.PERSISTED_NEW_MESSAGES_AVAILABLE, queueClearedFuture);
queueClearedFuture.whenComplete((v, cause) -> {
if (cause == null) {
consecutiveRetries.set(0);
if (sentInitialQueueEmptyMessage.compareAndSet(false, true)) {
final List<Tag> tags = List.of(UserAgentTagUtil.getPlatformTag(client.getUserAgent()));
final long drainDuration = System.currentTimeMillis() - queueDrainStartTime.get();
Metrics.summary(INITIAL_QUEUE_LENGTH_DISTRIBUTION_NAME, tags).record(sentMessageCounter.sum());
Metrics.timer(INITIAL_QUEUE_DRAIN_TIMER_NAME, tags).record(drainDuration, TimeUnit.MILLISECONDS);
if (drainDuration > SLOW_DRAIN_THRESHOLD) {
Metrics.counter(SLOW_QUEUE_DRAIN_COUNTER_NAME, tags).increment();
}
client.sendRequest("PUT", "/api/v1/queue/empty",
Collections.singletonList(TimestampHeaderUtil.getTimestampHeader()), Optional.empty());
}
} else {
storedMessageState.compareAndSet(StoredMessageState.EMPTY, state);
}
processStoredMessagesSemaphore.release();
if (cause == null) {
if (storedMessageState.get() != StoredMessageState.EMPTY) {
processStoredMessages();
}
} else {
if (client.isOpen()) {
if (consecutiveRetries.incrementAndGet() > MAX_CONSECUTIVE_RETRIES) {
logger.warn("Max consecutive retries exceeded", cause);
client.close(1011, "Failed to retrieve messages");
} else {
logger.debug("Failed to clear queue", cause);
final List<Tag> tags = List.of(UserAgentTagUtil.getPlatformTag(client.getUserAgent()));
Metrics.counter(QUEUE_DRAIN_RETRY_COUNTER_NAME, tags).increment();
final long delay = RETRY_DELAY_MILLIS + random.nextInt(RETRY_DELAY_JITTER_MILLIS);
retryFuture
.set(scheduledExecutorService.schedule(this::processStoredMessages, delay, TimeUnit.MILLISECONDS));
}
} else {
logger.debug("Client disconnected before queue cleared");
}
}
});
if (useReactive) {
processStoredMessages_reactive();
} else {
processStoredMessage_paged();
}
}
private void sendNextMessagePage(final boolean cachedMessagesOnly, final CompletableFuture<Void> queueClearedFuture) {
private void processStoredMessage_paged() {
assert !useReactive;
if (processStoredMessagesSemaphore.tryAcquire()) {
final StoredMessageState state = storedMessageState.getAndSet(StoredMessageState.EMPTY);
final CompletableFuture<Void> queueCleared = new CompletableFuture<>();
sendNextMessagePage(state != StoredMessageState.PERSISTED_NEW_MESSAGES_AVAILABLE, queueCleared);
setQueueClearedHandler(state, queueCleared);
}
}
private void setQueueClearedHandler(final StoredMessageState state, final CompletableFuture<Void> queueCleared) {
queueCleared.whenComplete((v, cause) -> {
if (cause == null) {
consecutiveRetries.set(0);
if (sentInitialQueueEmptyMessage.compareAndSet(false, true)) {
final List<Tag> tags = List.of(
UserAgentTagUtil.getPlatformTag(client.getUserAgent()),
Tag.of(REACTIVE_TAG, String.valueOf(useReactive))
);
final long drainDuration = System.currentTimeMillis() - queueDrainStartTime.get();
Metrics.summary(INITIAL_QUEUE_LENGTH_DISTRIBUTION_NAME, tags).record(sentMessageCounter.sum());
Metrics.timer(INITIAL_QUEUE_DRAIN_TIMER_NAME, tags).record(drainDuration, TimeUnit.MILLISECONDS);
if (drainDuration > SLOW_DRAIN_THRESHOLD) {
Metrics.counter(SLOW_QUEUE_DRAIN_COUNTER_NAME, tags).increment();
}
client.sendRequest("PUT", "/api/v1/queue/empty",
Collections.singletonList(TimestampHeaderUtil.getTimestampHeader()), Optional.empty());
}
} else {
storedMessageState.compareAndSet(StoredMessageState.EMPTY, state);
}
processStoredMessagesSemaphore.release();
if (cause == null) {
if (storedMessageState.get() != StoredMessageState.EMPTY) {
processStoredMessages();
}
} else {
if (client.isOpen()) {
if (consecutiveRetries.incrementAndGet() > MAX_CONSECUTIVE_RETRIES) {
logger.warn("Max consecutive retries exceeded", cause);
client.close(1011, "Failed to retrieve messages");
} else {
logger.debug("Failed to clear queue", cause);
final List<Tag> tags = List.of(UserAgentTagUtil.getPlatformTag(client.getUserAgent()));
Metrics.counter(QUEUE_DRAIN_RETRY_COUNTER_NAME, tags).increment();
final long delay = RETRY_DELAY_MILLIS + random.nextInt(RETRY_DELAY_JITTER_MILLIS);
retryFuture
.set(scheduledExecutorService.schedule(this::processStoredMessages, delay, TimeUnit.MILLISECONDS));
}
} else {
logger.debug("Client disconnected before queue cleared");
}
}
});
}
private void processStoredMessages_reactive() {
assert useReactive;
if (processStoredMessagesSemaphore.tryAcquire()) {
final StoredMessageState state = storedMessageState.getAndSet(StoredMessageState.EMPTY);
final CompletableFuture<Void> queueCleared = new CompletableFuture<>();
sendMessagesReactive(state != StoredMessageState.PERSISTED_NEW_MESSAGES_AVAILABLE, queueCleared);
setQueueClearedHandler(state, queueCleared);
}
}
private void sendNextMessagePage(final boolean cachedMessagesOnly, final CompletableFuture<Void> queueCleared) {
try {
final Pair<List<Envelope>, Boolean> messagesAndHasMore = messagesManager.getMessagesForDevice(
auth.getAccount().getUuid(), device.getId(), cachedMessagesOnly);
@ -330,25 +395,7 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac
for (int i = 0; i < messages.size(); i++) {
final Envelope envelope = messages.get(i);
final UUID messageGuid = UUID.fromString(envelope.getServerGuid());
final boolean discard;
if (isDesktopClient && envelope.getSerializedSize() > MAX_DESKTOP_MESSAGE_SIZE) {
discard = true;
} else if (envelope.getStory() && !client.shouldDeliverStories()) {
discard = true;
} else {
discard = false;
}
if (discard) {
messagesManager.delete(auth.getAccount().getUuid(), device.getId(), messageGuid, envelope.getServerTimestamp());
discardedMessagesMeter.mark();
sendFutures[i] = CompletableFuture.completedFuture(null);
} else {
sendFutures[i] = sendMessage(envelope, Optional.of(new StoredMessageInfo(messageGuid, envelope.getServerTimestamp())));
}
sendFutures[i] = sendMessage(envelope);
}
// Set a large, non-zero timeout, to prevent any failure to acknowledge receipt from blocking indefinitely
@ -357,16 +404,45 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac
.whenComplete((v, cause) -> {
if (cause == null) {
if (hasMore) {
sendNextMessagePage(cachedMessagesOnly, queueClearedFuture);
sendNextMessagePage(cachedMessagesOnly, queueCleared);
} else {
queueClearedFuture.complete(null);
queueCleared.complete(null);
}
} else {
queueClearedFuture.completeExceptionally(cause);
queueCleared.completeExceptionally(cause);
}
});
} catch (final Exception e) {
queueClearedFuture.completeExceptionally(e);
queueCleared.completeExceptionally(e);
}
}
private void sendMessagesReactive(final boolean cachedMessagesOnly, final CompletableFuture<Void> queueCleared) {
final Publisher<Envelope> messages =
messagesManager.getMessagesForDeviceReactive(auth.getAccount().getUuid(), device.getId(), cachedMessagesOnly);
final Disposable subscription = Flux.from(messages)
.limitRate(MESSAGE_PUBLISHER_LIMIT_RATE)
.flatMapSequential(envelope ->
Mono.fromFuture(sendMessage(envelope).orTimeout(sendFuturesTimeoutMillis, TimeUnit.MILLISECONDS)))
.doOnError(queueCleared::completeExceptionally)
.doOnComplete(() -> queueCleared.complete(null))
.subscribeOn(reactiveScheduler)
.subscribe();
messageSubscription.set(subscription);
}
private CompletableFuture<?> sendMessage(Envelope envelope) {
final UUID messageGuid = UUID.fromString(envelope.getServerGuid());
if (envelope.getStory() && !client.shouldDeliverStories()) {
messagesManager.delete(auth.getAccount().getUuid(), device.getId(), messageGuid, envelope.getServerTimestamp());
return CompletableFuture.completedFuture(null);
} else {
return sendMessage(envelope, new StoredMessageInfo(messageGuid, envelope.getServerTimestamp()));
}
}
@ -381,6 +457,7 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac
messageAvailableMeter.mark();
storedMessageState.compareAndSet(StoredMessageState.EMPTY, StoredMessageState.CACHED_NEW_MESSAGES_AVAILABLE);
processStoredMessages();
return true;
@ -396,6 +473,7 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac
messagesPersistedMeter.mark();
storedMessageState.set(StoredMessageState.PERSISTED_NEW_MESSAGES_AVAILABLE);
processStoredMessages();
return true;
@ -405,7 +483,8 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac
public void handleDisplacement(final boolean connectedElsewhere) {
final Tags tags = Tags.of(
UserAgentTagUtil.getPlatformTag(client.getUserAgent()),
Tag.of("connectedElsewhere", String.valueOf(connectedElsewhere)));
Tag.of("connectedElsewhere", String.valueOf(connectedElsewhere)),
Tag.of(REACTIVE_TAG, String.valueOf(useReactive)));
Metrics.counter(DISPLACEMENT_COUNTER_NAME, tags).increment();
@ -429,21 +508,7 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac
}
}
private static class StoredMessageInfo {
private final UUID guid;
private final long serverTimestamp;
private record StoredMessageInfo(UUID guid, long serverTimestamp) {
public StoredMessageInfo(UUID guid, long serverTimestamp) {
this.guid = guid;
this.serverTimestamp = serverTimestamp;
}
public UUID getGuid() {
return guid;
}
public long getServerTimestamp() {
return serverTimestamp;
}
}
}

View File

@ -25,7 +25,6 @@ import net.sourceforge.argparse4j.inf.Subparser;
import org.whispersystems.textsecuregcm.WhisperServerConfiguration;
import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentialGenerator;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.push.PushLatencyManager;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
@ -45,9 +44,9 @@ import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.storage.PhoneNumberIdentifiers;
import org.whispersystems.textsecuregcm.storage.Profiles;
import org.whispersystems.textsecuregcm.storage.ProfilesManager;
import org.whispersystems.textsecuregcm.storage.ProhibitedUsernames;
import org.whispersystems.textsecuregcm.storage.ReportMessageDynamoDb;
import org.whispersystems.textsecuregcm.storage.ReportMessageManager;
import org.whispersystems.textsecuregcm.storage.ProhibitedUsernames;
import org.whispersystems.textsecuregcm.storage.StoredVerificationCodeManager;
import org.whispersystems.textsecuregcm.storage.UsernameNotAvailableException;
import org.whispersystems.textsecuregcm.storage.VerificationCodeStore;
@ -97,6 +96,8 @@ public class AssignUsernameCommand extends EnvironmentCommand<WhisperServerConfi
ExecutorService keyspaceNotificationDispatchExecutor = environment.lifecycle()
.executorService(name(getClass(), "keyspaceNotification-%d")).maxThreads(4).build();
ExecutorService messageDeletionExecutor = environment.lifecycle()
.executorService(name(getClass(), "messageDeletion-%d")).maxThreads(4).build();
ExecutorService backupServiceExecutor = environment.lifecycle()
.executorService(name(getClass(), "backupService-%d")).maxThreads(8).minThreads(1).build();
ExecutorService storageServiceExecutor = environment.lifecycle()
@ -156,15 +157,14 @@ public class AssignUsernameCommand extends EnvironmentCommand<WhisperServerConfi
configuration.getDynamoDbTables().getReservedUsernames().getTableName());
Keys keys = new Keys(dynamoDbClient,
configuration.getDynamoDbTables().getKeys().getTableName());
MessagesDynamoDb messagesDynamoDb = new MessagesDynamoDb(dynamoDbClient,
MessagesDynamoDb messagesDynamoDb = new MessagesDynamoDb(dynamoDbClient, dynamoDbAsyncClient,
configuration.getDynamoDbTables().getMessages().getTableName(),
configuration.getDynamoDbTables().getMessages().getExpiration());
configuration.getDynamoDbTables().getMessages().getExpiration(),
messageDeletionExecutor);
FaultTolerantRedisCluster messageInsertCacheCluster = new FaultTolerantRedisCluster("message_insert_cluster",
configuration.getMessageCacheConfiguration().getRedisClusterConfiguration(), redisClusterClientResources);
FaultTolerantRedisCluster messageReadDeleteCluster = new FaultTolerantRedisCluster("message_read_delete_cluster",
configuration.getMessageCacheConfiguration().getRedisClusterConfiguration(), redisClusterClientResources);
FaultTolerantRedisCluster metricsCluster = new FaultTolerantRedisCluster("metrics_cluster",
configuration.getMetricsClusterConfiguration(), redisClusterClientResources);
FaultTolerantRedisCluster clientPresenceCluster = new FaultTolerantRedisCluster("client_presence_cluster",
configuration.getClientPresenceClusterConfiguration(), redisClusterClientResources);
FaultTolerantRedisCluster rateLimitersCluster = new FaultTolerantRedisCluster("rate_limiters",
@ -176,8 +176,7 @@ public class AssignUsernameCommand extends EnvironmentCommand<WhisperServerConfi
ClientPresenceManager clientPresenceManager = new ClientPresenceManager(clientPresenceCluster,
Executors.newSingleThreadScheduledExecutor(), keyspaceNotificationDispatchExecutor);
MessagesCache messagesCache = new MessagesCache(messageInsertCacheCluster, messageReadDeleteCluster,
keyspaceNotificationDispatchExecutor);
PushLatencyManager pushLatencyManager = new PushLatencyManager(metricsCluster, dynamicConfigurationManager);
Clock.systemUTC(), keyspaceNotificationDispatchExecutor, messageDeletionExecutor);
DirectoryQueue directoryQueue = new DirectoryQueue(
configuration.getDirectoryConfiguration().getSqsConfiguration());
ProfilesManager profilesManager = new ProfilesManager(profiles, cacheCluster);

View File

@ -27,7 +27,6 @@ import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.WhisperServerConfiguration;
import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentialGenerator;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.push.PushLatencyManager;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
@ -48,9 +47,9 @@ import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.storage.PhoneNumberIdentifiers;
import org.whispersystems.textsecuregcm.storage.Profiles;
import org.whispersystems.textsecuregcm.storage.ProfilesManager;
import org.whispersystems.textsecuregcm.storage.ProhibitedUsernames;
import org.whispersystems.textsecuregcm.storage.ReportMessageDynamoDb;
import org.whispersystems.textsecuregcm.storage.ReportMessageManager;
import org.whispersystems.textsecuregcm.storage.ProhibitedUsernames;
import org.whispersystems.textsecuregcm.storage.StoredVerificationCodeManager;
import org.whispersystems.textsecuregcm.storage.VerificationCodeStore;
import org.whispersystems.textsecuregcm.util.DynamoDbFromConfig;
@ -99,6 +98,8 @@ public class DeleteUserCommand extends EnvironmentCommand<WhisperServerConfigura
ExecutorService keyspaceNotificationDispatchExecutor = environment.lifecycle()
.executorService(name(getClass(), "keyspaceNotification-%d")).maxThreads(4).build();
ExecutorService messageDeletionExecutor = environment.lifecycle()
.executorService(name(getClass(), "messageDeletion-%d")).maxThreads(4).build();
ExecutorService backupServiceExecutor = environment.lifecycle()
.executorService(name(getClass(), "backupService-%d")).maxThreads(8).minThreads(1).build();
ExecutorService storageServiceExecutor = environment.lifecycle()
@ -158,15 +159,14 @@ public class DeleteUserCommand extends EnvironmentCommand<WhisperServerConfigura
configuration.getDynamoDbTables().getReservedUsernames().getTableName());
Keys keys = new Keys(dynamoDbClient,
configuration.getDynamoDbTables().getKeys().getTableName());
MessagesDynamoDb messagesDynamoDb = new MessagesDynamoDb(dynamoDbClient,
MessagesDynamoDb messagesDynamoDb = new MessagesDynamoDb(dynamoDbClient, dynamoDbAsyncClient,
configuration.getDynamoDbTables().getMessages().getTableName(),
configuration.getDynamoDbTables().getMessages().getExpiration());
configuration.getDynamoDbTables().getMessages().getExpiration(),
messageDeletionExecutor);
FaultTolerantRedisCluster messageInsertCacheCluster = new FaultTolerantRedisCluster("message_insert_cluster",
configuration.getMessageCacheConfiguration().getRedisClusterConfiguration(), redisClusterClientResources);
FaultTolerantRedisCluster messageReadDeleteCluster = new FaultTolerantRedisCluster("message_read_delete_cluster",
configuration.getMessageCacheConfiguration().getRedisClusterConfiguration(), redisClusterClientResources);
FaultTolerantRedisCluster metricsCluster = new FaultTolerantRedisCluster("metrics_cluster",
configuration.getMetricsClusterConfiguration(), redisClusterClientResources);
FaultTolerantRedisCluster clientPresenceCluster = new FaultTolerantRedisCluster("client_presence_cluster",
configuration.getClientPresenceClusterConfiguration(), redisClusterClientResources);
FaultTolerantRedisCluster rateLimitersCluster = new FaultTolerantRedisCluster("rate_limiters",
@ -178,8 +178,7 @@ public class DeleteUserCommand extends EnvironmentCommand<WhisperServerConfigura
ClientPresenceManager clientPresenceManager = new ClientPresenceManager(clientPresenceCluster,
Executors.newSingleThreadScheduledExecutor(), keyspaceNotificationDispatchExecutor);
MessagesCache messagesCache = new MessagesCache(messageInsertCacheCluster, messageReadDeleteCluster,
keyspaceNotificationDispatchExecutor);
PushLatencyManager pushLatencyManager = new PushLatencyManager(metricsCluster, dynamicConfigurationManager);
Clock.systemUTC(), keyspaceNotificationDispatchExecutor, messageDeletionExecutor);
DirectoryQueue directoryQueue = new DirectoryQueue(
configuration.getDirectoryConfiguration().getSqsConfiguration());
ProfilesManager profilesManager = new ProfilesManager(profiles, cacheCluster);

View File

@ -26,7 +26,6 @@ import net.sourceforge.argparse4j.inf.Subparser;
import org.whispersystems.textsecuregcm.WhisperServerConfiguration;
import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentialGenerator;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.push.PushLatencyManager;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
@ -46,9 +45,9 @@ import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.storage.PhoneNumberIdentifiers;
import org.whispersystems.textsecuregcm.storage.Profiles;
import org.whispersystems.textsecuregcm.storage.ProfilesManager;
import org.whispersystems.textsecuregcm.storage.ProhibitedUsernames;
import org.whispersystems.textsecuregcm.storage.ReportMessageDynamoDb;
import org.whispersystems.textsecuregcm.storage.ReportMessageManager;
import org.whispersystems.textsecuregcm.storage.ProhibitedUsernames;
import org.whispersystems.textsecuregcm.storage.StoredVerificationCodeManager;
import org.whispersystems.textsecuregcm.storage.VerificationCodeStore;
import org.whispersystems.textsecuregcm.util.DynamoDbFromConfig;
@ -102,6 +101,8 @@ public class SetUserDiscoverabilityCommand extends EnvironmentCommand<WhisperSer
ExecutorService keyspaceNotificationDispatchExecutor = environment.lifecycle()
.executorService(name(getClass(), "keyspaceNotification-%d")).maxThreads(4).build();
ExecutorService messageDeletionExecutor = environment.lifecycle()
.executorService(name(getClass(), "messageDeletion-%d")).maxThreads(4).build();
ExecutorService backupServiceExecutor = environment.lifecycle()
.executorService(name(getClass(), "backupService-%d")).maxThreads(8).minThreads(1).build();
ExecutorService storageServiceExecutor = environment.lifecycle()
@ -161,15 +162,14 @@ public class SetUserDiscoverabilityCommand extends EnvironmentCommand<WhisperSer
configuration.getDynamoDbTables().getReservedUsernames().getTableName());
Keys keys = new Keys(dynamoDbClient,
configuration.getDynamoDbTables().getKeys().getTableName());
MessagesDynamoDb messagesDynamoDb = new MessagesDynamoDb(dynamoDbClient,
MessagesDynamoDb messagesDynamoDb = new MessagesDynamoDb(dynamoDbClient, dynamoDbAsyncClient,
configuration.getDynamoDbTables().getMessages().getTableName(),
configuration.getDynamoDbTables().getMessages().getExpiration());
configuration.getDynamoDbTables().getMessages().getExpiration(),
messageDeletionExecutor);
FaultTolerantRedisCluster messageInsertCacheCluster = new FaultTolerantRedisCluster("message_insert_cluster",
configuration.getMessageCacheConfiguration().getRedisClusterConfiguration(), redisClusterClientResources);
FaultTolerantRedisCluster messageReadDeleteCluster = new FaultTolerantRedisCluster("message_read_delete_cluster",
configuration.getMessageCacheConfiguration().getRedisClusterConfiguration(), redisClusterClientResources);
FaultTolerantRedisCluster metricsCluster = new FaultTolerantRedisCluster("metrics_cluster",
configuration.getMetricsClusterConfiguration(), redisClusterClientResources);
FaultTolerantRedisCluster clientPresenceCluster = new FaultTolerantRedisCluster("client_presence",
configuration.getClientPresenceClusterConfiguration(), redisClusterClientResources);
SecureBackupClient secureBackupClient = new SecureBackupClient(backupCredentialsGenerator, backupServiceExecutor,
@ -179,8 +179,7 @@ public class SetUserDiscoverabilityCommand extends EnvironmentCommand<WhisperSer
ClientPresenceManager clientPresenceManager = new ClientPresenceManager(clientPresenceCluster,
Executors.newSingleThreadScheduledExecutor(), keyspaceNotificationDispatchExecutor);
MessagesCache messagesCache = new MessagesCache(messageInsertCacheCluster, messageReadDeleteCluster,
keyspaceNotificationDispatchExecutor);
PushLatencyManager pushLatencyManager = new PushLatencyManager(metricsCluster, dynamicConfigurationManager);
Clock.systemUTC(), keyspaceNotificationDispatchExecutor, messageDeletionExecutor);
DirectoryQueue directoryQueue = new DirectoryQueue(
configuration.getDirectoryConfiguration().getSqsConfiguration());
ProfilesManager profilesManager = new ProfilesManager(profiles, cacheCluster);

View File

@ -1,6 +1,7 @@
local queueKey = KEYS[1]
local queueLockKey = KEYS[2]
local limit = ARGV[1]
local afterMessageId = ARGV[2]
local locked = redis.call("GET", queueLockKey)
@ -8,12 +9,17 @@ if locked then
return {}
end
-- The range is inclusive
local min = 0
local max = limit - 1
if afterMessageId == "null" then
-- An index range is inclusive
local min = 0
local max = limit - 1
if max < 0 then
return {}
if max < 0 then
return {}
end
return redis.call("ZRANGE", queueKey, min, max, "WITHSCORES")
else
-- note: this is deprecated in Redis 6.2, and should be migrated to zrange after the cluster is updated
return redis.call("ZRANGEBYSCORE", queueKey, "("..afterMessageId, "+inf", "WITHSCORES", "LIMIT", 0, limit)
end
return redis.call("ZRANGE", queueKey, min, max, "WITHSCORES")

View File

@ -46,6 +46,7 @@ import java.util.Optional;
import java.util.Random;
import java.util.UUID;
import java.util.concurrent.Callable;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.stream.Stream;
@ -533,17 +534,25 @@ class MessageControllerTest {
UUID sourceUuid = UUID.randomUUID();
UUID uuid1 = UUID.randomUUID();
when(messagesManager.delete(AuthHelper.VALID_UUID, 1, uuid1, null)).thenReturn(Optional.of(generateEnvelope(
uuid1, Envelope.Type.CIPHERTEXT_VALUE,
timestamp, sourceUuid, 1, AuthHelper.VALID_UUID, null, "hi".getBytes(), 0)));
when(messagesManager.delete(AuthHelper.VALID_UUID, 1, uuid1, null))
.thenReturn(
CompletableFuture.completedFuture(Optional.of(generateEnvelope(uuid1, Envelope.Type.CIPHERTEXT_VALUE,
timestamp, sourceUuid, 1, AuthHelper.VALID_UUID, null, "hi".getBytes(), 0))));
UUID uuid2 = UUID.randomUUID();
when(messagesManager.delete(AuthHelper.VALID_UUID, 1, uuid2, null)).thenReturn(Optional.of(generateEnvelope(
uuid2, Envelope.Type.SERVER_DELIVERY_RECEIPT_VALUE,
System.currentTimeMillis(), sourceUuid, 1, AuthHelper.VALID_UUID, null, null, 0)));
when(messagesManager.delete(AuthHelper.VALID_UUID, 1, uuid2, null))
.thenReturn(
CompletableFuture.completedFuture(Optional.of(generateEnvelope(
uuid2, Envelope.Type.SERVER_DELIVERY_RECEIPT_VALUE,
System.currentTimeMillis(), sourceUuid, 1, AuthHelper.VALID_UUID, null, null, 0))));
UUID uuid3 = UUID.randomUUID();
when(messagesManager.delete(AuthHelper.VALID_UUID, 1, uuid3, null)).thenReturn(Optional.empty());
when(messagesManager.delete(AuthHelper.VALID_UUID, 1, uuid3, null))
.thenReturn(CompletableFuture.completedFuture(Optional.empty()));
UUID uuid4 = UUID.randomUUID();
when(messagesManager.delete(AuthHelper.VALID_UUID, 1, uuid4, null))
.thenReturn(CompletableFuture.failedFuture(new RuntimeException("Oh No")));
Response response = resources.getJerseyTest()
.target(String.format("/v1/messages/uuid/%s", uuid1))
@ -573,6 +582,15 @@ class MessageControllerTest {
assertThat("Good Response Code", response.getStatus(), is(equalTo(204)));
verifyNoMoreInteractions(receiptSender);
response = resources.getJerseyTest()
.target(String.format("/v1/messages/uuid/%s", uuid4))
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.delete();
assertThat("Bad Response Code", response.getStatus(), is(equalTo(500)));
verifyNoMoreInteractions(receiptSender);
}
@Test
@ -700,7 +718,7 @@ class MessageControllerTest {
.target(String.format("/v1/messages/%s", SINGLE_DEVICE_UUID))
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.header("User-Agent", "FIXME")
.header("User-Agent", "Test-UA")
.put(Entity.entity(SystemMapper.getMapper().readValue(jsonFixture(payloadFilename), IncomingMessageList.class),
MediaType.APPLICATION_JSON_TYPE));

View File

@ -13,18 +13,29 @@ import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import io.lettuce.core.FlushMode;
import io.lettuce.core.RedisFuture;
import io.lettuce.core.RedisNoScriptException;
import io.lettuce.core.ScriptOutputType;
import io.lettuce.core.cluster.api.async.RedisAdvancedClusterAsyncCommands;
import io.lettuce.core.cluster.api.reactive.RedisAdvancedClusterReactiveCommands;
import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands;
import io.lettuce.core.protocol.AsyncCommand;
import io.lettuce.core.protocol.RedisCommand;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.TimeUnit;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.EnumSource;
import org.whispersystems.textsecuregcm.tests.util.RedisClusterHelper;
import reactor.core.publisher.Flux;
public class ClusterLuaScriptTest {
class ClusterLuaScriptTest {
@RegisterExtension
static final RedisClusterExtension REDIS_CLUSTER_EXTENSION = RedisClusterExtension.builder().build();
@ -32,7 +43,7 @@ public class ClusterLuaScriptTest {
@Test
void testExecute() {
final RedisAdvancedClusterCommands<String, String> commands = mock(RedisAdvancedClusterCommands.class);
final FaultTolerantRedisCluster mockCluster = RedisClusterHelper.buildMockRedisCluster(commands);
final FaultTolerantRedisCluster mockCluster = RedisClusterHelper.builder().stringCommands(commands).build();
final String script = "return redis.call(\"SET\", KEYS[1], ARGV[1])";
final ScriptOutputType scriptOutputType = ScriptOutputType.VALUE;
@ -51,7 +62,7 @@ public class ClusterLuaScriptTest {
@Test
void testExecuteScriptNotLoaded() {
final RedisAdvancedClusterCommands<String, String> commands = mock(RedisAdvancedClusterCommands.class);
final FaultTolerantRedisCluster mockCluster = RedisClusterHelper.buildMockRedisCluster(commands);
final FaultTolerantRedisCluster mockCluster = RedisClusterHelper.builder().stringCommands(commands).build();
final String script = "return redis.call(\"SET\", KEYS[1], ARGV[1])";
final ScriptOutputType scriptOutputType = ScriptOutputType.VALUE;
@ -71,8 +82,10 @@ public class ClusterLuaScriptTest {
void testExecuteBinaryScriptNotLoaded() {
final RedisAdvancedClusterCommands<String, String> stringCommands = mock(RedisAdvancedClusterCommands.class);
final RedisAdvancedClusterCommands<byte[], byte[]> binaryCommands = mock(RedisAdvancedClusterCommands.class);
final FaultTolerantRedisCluster mockCluster =
RedisClusterHelper.buildMockRedisCluster(stringCommands, binaryCommands);
final FaultTolerantRedisCluster mockCluster = RedisClusterHelper.builder()
.stringCommands(stringCommands)
.binaryCommands(binaryCommands)
.build();
final String script = "return redis.call(\"SET\", KEYS[1], ARGV[1])";
final ScriptOutputType scriptOutputType = ScriptOutputType.VALUE;
@ -85,17 +98,85 @@ public class ClusterLuaScriptTest {
luaScript.executeBinary(keys, values);
verify(binaryCommands).eval(script, scriptOutputType, keys.toArray(new byte[0][]), values.toArray(new byte[0][]));
verify(binaryCommands).evalsha(luaScript.getSha(), scriptOutputType, keys.toArray(new byte[0][]), values.toArray(new byte[0][]));
verify(binaryCommands).evalsha(luaScript.getSha(), scriptOutputType, keys.toArray(new byte[0][]),
values.toArray(new byte[0][]));
}
@Test
public void testExecuteRealCluster() {
void testExecuteBinaryAsyncScriptNotLoaded() throws Exception {
final RedisAdvancedClusterAsyncCommands<byte[], byte[]> binaryAsyncCommands =
mock(RedisAdvancedClusterAsyncCommands.class);
final FaultTolerantRedisCluster mockCluster =
RedisClusterHelper.builder().binaryAsyncCommands(binaryAsyncCommands).build();
final String script = "return redis.call(\"SET\", KEYS[1], ARGV[1])";
final ScriptOutputType scriptOutputType = ScriptOutputType.VALUE;
final List<byte[]> keys = List.of("key".getBytes(StandardCharsets.UTF_8));
final List<byte[]> values = List.of("value".getBytes(StandardCharsets.UTF_8));
final AsyncCommand<?, ?, ?> evalShaFailure = new AsyncCommand<>(mock(RedisCommand.class));
evalShaFailure.completeExceptionally(new RedisNoScriptException("OH NO"));
final AsyncCommand<?, ?, ?> evalSuccess = new AsyncCommand<>(mock(RedisCommand.class));
evalSuccess.complete();
when(binaryAsyncCommands.evalsha(any(), any(), any(), any())).thenReturn((RedisFuture<Object>) evalShaFailure);
when(binaryAsyncCommands.eval(anyString(), any(), any(), any())).thenReturn((RedisFuture<Object>) evalSuccess);
final ClusterLuaScript luaScript = new ClusterLuaScript(mockCluster, script, scriptOutputType);
luaScript.executeBinaryAsync(keys, values).get(5, TimeUnit.SECONDS);
verify(binaryAsyncCommands).eval(script, scriptOutputType, keys.toArray(new byte[0][]),
values.toArray(new byte[0][]));
verify(binaryAsyncCommands).evalsha(luaScript.getSha(), scriptOutputType, keys.toArray(new byte[0][]),
values.toArray(new byte[0][]));
}
@Test
void testExecuteBinaryReactiveScriptNotLoaded() {
final RedisAdvancedClusterReactiveCommands<byte[], byte[]> binaryReactiveCommands =
mock(RedisAdvancedClusterReactiveCommands.class);
final FaultTolerantRedisCluster mockCluster = RedisClusterHelper.builder()
.binaryReactiveCommands(binaryReactiveCommands).build();
final String script = "return redis.call(\"SET\", KEYS[1], ARGV[1])";
final ScriptOutputType scriptOutputType = ScriptOutputType.VALUE;
final List<byte[]> keys = List.of("key".getBytes(StandardCharsets.UTF_8));
final List<byte[]> values = List.of("value".getBytes(StandardCharsets.UTF_8));
when(binaryReactiveCommands.evalsha(any(), any(), any(), any()))
.thenReturn(Flux.error(new RedisNoScriptException("OH NO")));
when(binaryReactiveCommands.eval(anyString(), any(), any(), any())).thenReturn(Flux.just("ok"));
final ClusterLuaScript luaScript = new ClusterLuaScript(mockCluster, script, scriptOutputType);
luaScript.executeBinaryReactive(keys, values).blockLast(Duration.ofSeconds(5));
verify(binaryReactiveCommands).eval(script, scriptOutputType, keys.toArray(new byte[0][]),
values.toArray(new byte[0][]));
verify(binaryReactiveCommands).evalsha(luaScript.getSha(), scriptOutputType, keys.toArray(new byte[0][]),
values.toArray(new byte[0][]));
}
@ParameterizedTest
@EnumSource(ExecuteMode.class)
void testExecuteRealCluster(final ExecuteMode mode) throws Exception {
REDIS_CLUSTER_EXTENSION.getRedisCluster().withCluster(c -> c.sync().scriptFlush(FlushMode.SYNC));
REDIS_CLUSTER_EXTENSION.getRedisCluster().withCluster(c -> c.sync().configResetstat());
final ClusterLuaScript script = new ClusterLuaScript(REDIS_CLUSTER_EXTENSION.getRedisCluster(),
"return 2;",
ScriptOutputType.INTEGER);
for (int i = 0; i < 7; i++) {
assertEquals(2L, script.execute(Collections.emptyList(), Collections.emptyList()));
final long actual = switch (mode) {
case SYNC -> (long) script.execute(Collections.emptyList(), Collections.emptyList());
case ASYNC ->
(long) script.executeAsync(Collections.emptyList(), Collections.emptyList()).get(5, TimeUnit.SECONDS);
case REACTIVE -> (long) script.executeReactive(Collections.emptyList(), Collections.emptyList())
.blockLast(Duration.ofSeconds(5));
};
assertEquals(2L, actual);
}
final int evalCount = REDIS_CLUSTER_EXTENSION.getRedisCluster().withCluster(connection -> {
@ -120,4 +201,11 @@ public class ClusterLuaScriptTest {
assertEquals(1, evalCount);
}
private enum ExecuteMode {
SYNC,
ASYNC,
REACTIVE
}
}

View File

@ -155,7 +155,7 @@ class AccountsManagerConcurrentModificationIntegrationTest {
accountsManager = new AccountsManager(
accounts,
phoneNumberIdentifiers,
RedisClusterHelper.buildMockRedisCluster(commands),
RedisClusterHelper.builder().stringCommands(commands).build(),
deletedAccountsManager,
mock(DirectoryQueue.class),
mock(Keys.class),

View File

@ -147,7 +147,7 @@ class AccountsManagerTest {
accountsManager = new AccountsManager(
accounts,
phoneNumberIdentifiers,
RedisClusterHelper.buildMockRedisCluster(commands),
RedisClusterHelper.builder().stringCommands(commands).build(),
deletedAccountsManager,
directoryQueue,
keys,

View File

@ -78,7 +78,14 @@ public class DynamoDbExtension implements BeforeEachCallback, AfterEachCallback
}
@Override
public void afterEach(ExtensionContext context) throws Exception {
public void afterEach(ExtensionContext context) {
stopServer();
}
/**
* For use in integration tests that want to test resiliency/error handling
*/
public void stopServer() {
try {
server.stop();
} catch (Exception e) {

View File

@ -15,6 +15,7 @@ import com.google.protobuf.ByteString;
import com.google.protobuf.InvalidProtocolBufferException;
import io.lettuce.core.cluster.SlotHash;
import java.nio.ByteBuffer;
import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
import java.util.ArrayList;
@ -32,7 +33,6 @@ import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.push.PushLatencyManager;
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
import org.whispersystems.textsecuregcm.tests.util.MessagesDynamoDbExtension;
import software.amazon.awssdk.services.dynamodb.DynamoDbClient;
@ -47,6 +47,7 @@ class MessagePersisterIntegrationTest {
static final RedisClusterExtension REDIS_CLUSTER_EXTENSION = RedisClusterExtension.builder().build();
private ExecutorService notificationExecutorService;
private ExecutorService messageDeletionExecutorService;
private MessagesCache messagesCache;
private MessagesManager messagesManager;
private MessagePersister messagePersister;
@ -66,13 +67,16 @@ class MessagePersisterIntegrationTest {
when(dynamicConfigurationManager.getConfiguration()).thenReturn(new DynamicConfiguration());
messageDeletionExecutorService = Executors.newSingleThreadExecutor();
final MessagesDynamoDb messagesDynamoDb = new MessagesDynamoDb(dynamoDbExtension.getDynamoDbClient(),
MessagesDynamoDbExtension.TABLE_NAME, Duration.ofDays(14));
dynamoDbExtension.getDynamoDbAsyncClient(), MessagesDynamoDbExtension.TABLE_NAME, Duration.ofDays(14),
messageDeletionExecutorService);
final AccountsManager accountsManager = mock(AccountsManager.class);
notificationExecutorService = Executors.newSingleThreadExecutor();
messagesCache = new MessagesCache(REDIS_CLUSTER_EXTENSION.getRedisCluster(),
REDIS_CLUSTER_EXTENSION.getRedisCluster(), notificationExecutorService);
REDIS_CLUSTER_EXTENSION.getRedisCluster(), Clock.systemUTC(), notificationExecutorService,
messageDeletionExecutorService);
messagesManager = new MessagesManager(messagesDynamoDb, messagesCache, mock(ReportMessageManager.class));
messagePersister = new MessagePersister(messagesCache, messagesManager, accountsManager,
dynamicConfigurationManager, PERSIST_DELAY);
@ -94,6 +98,9 @@ class MessagePersisterIntegrationTest {
void tearDown() throws Exception {
notificationExecutorService.shutdown();
notificationExecutorService.awaitTermination(15, TimeUnit.SECONDS);
messageDeletionExecutorService.shutdown();
messageDeletionExecutorService.awaitTermination(15, TimeUnit.SECONDS);
}
@Test

View File

@ -22,6 +22,7 @@ import static org.mockito.Mockito.when;
import com.google.protobuf.ByteString;
import io.lettuce.core.cluster.SlotHash;
import java.nio.charset.StandardCharsets;
import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
import java.util.List;
@ -46,7 +47,7 @@ class MessagePersisterTest {
@RegisterExtension
static final RedisClusterExtension REDIS_CLUSTER_EXTENSION = RedisClusterExtension.builder().build();
private ExecutorService notificationExecutorService;
private ExecutorService sharedExecutorService;
private MessagesCache messagesCache;
private MessagesDynamoDb messagesDynamoDb;
private MessagePersister messagePersister;
@ -74,9 +75,9 @@ class MessagePersisterTest {
when(account.getNumber()).thenReturn(DESTINATION_ACCOUNT_NUMBER);
when(dynamicConfigurationManager.getConfiguration()).thenReturn(new DynamicConfiguration());
notificationExecutorService = Executors.newSingleThreadExecutor();
sharedExecutorService = Executors.newSingleThreadExecutor();
messagesCache = new MessagesCache(REDIS_CLUSTER_EXTENSION.getRedisCluster(),
REDIS_CLUSTER_EXTENSION.getRedisCluster(), notificationExecutorService);
REDIS_CLUSTER_EXTENSION.getRedisCluster(), Clock.systemUTC(), sharedExecutorService, sharedExecutorService);
messagePersister = new MessagePersister(messagesCache, messagesManager, accountsManager,
dynamicConfigurationManager, PERSIST_DELAY);
@ -88,7 +89,7 @@ class MessagePersisterTest {
messagesDynamoDb.store(messages, destinationUuid, destinationDeviceId);
for (final MessageProtos.Envelope message : messages) {
messagesCache.remove(destinationUuid, destinationDeviceId, UUID.fromString(message.getServerGuid()));
messagesCache.remove(destinationUuid, destinationDeviceId, UUID.fromString(message.getServerGuid())).get();
}
return null;
@ -97,8 +98,8 @@ class MessagePersisterTest {
@AfterEach
void tearDown() throws Exception {
notificationExecutorService.shutdown();
notificationExecutorService.awaitTermination(1, TimeUnit.SECONDS);
sharedExecutorService.shutdown();
sharedExecutorService.awaitTermination(1, TimeUnit.SECONDS);
}
@Test

View File

@ -9,14 +9,26 @@ import static org.assertj.core.api.Assertions.assertThat;
import com.google.protobuf.ByteString;
import java.time.Duration;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.UUID;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.reactivestreams.Publisher;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.tests.util.MessageHelper;
import org.whispersystems.textsecuregcm.tests.util.MessagesDynamoDbExtension;
import reactor.core.publisher.Flux;
import reactor.test.StepVerifier;
class MessagesDynamoDbTest {
@ -59,6 +71,7 @@ class MessagesDynamoDbTest {
MESSAGE3 = builder.build();
}
private ExecutorService messageDeletionExecutorService;
private MessagesDynamoDb messagesDynamoDb;
@ -67,8 +80,18 @@ class MessagesDynamoDbTest {
@BeforeEach
void setup() {
messagesDynamoDb = new MessagesDynamoDb(dynamoDbExtension.getDynamoDbClient(), MessagesDynamoDbExtension.TABLE_NAME,
Duration.ofDays(14));
messageDeletionExecutorService = Executors.newSingleThreadExecutor();
messagesDynamoDb = new MessagesDynamoDb(dynamoDbExtension.getDynamoDbClient(),
dynamoDbExtension.getDynamoDbAsyncClient(), MessagesDynamoDbExtension.TABLE_NAME, Duration.ofDays(14),
messageDeletionExecutorService);
}
@AfterEach
void teardown() throws Exception {
messageDeletionExecutorService.shutdown();
messageDeletionExecutorService.awaitTermination(5, TimeUnit.SECONDS);
StepVerifier.resetDefaultTimeout();
}
@Test
@ -77,7 +100,7 @@ class MessagesDynamoDbTest {
final int destinationDeviceId = random.nextInt(255) + 1;
messagesDynamoDb.store(List.of(MESSAGE1, MESSAGE2, MESSAGE3), destinationUuid, destinationDeviceId);
final List<MessageProtos.Envelope> messagesStored = messagesDynamoDb.load(destinationUuid, destinationDeviceId,
final List<MessageProtos.Envelope> messagesStored = load(destinationUuid, destinationDeviceId,
MessagesDynamoDb.RESULT_SET_CHUNK_SIZE);
assertThat(messagesStored).isNotNull().hasSize(3);
final MessageProtos.Envelope firstMessage =
@ -88,6 +111,73 @@ class MessagesDynamoDbTest {
assertThat(messagesStored).element(2).isEqualTo(MESSAGE2);
}
@ParameterizedTest
@ValueSource(ints = {10, 100, 100, 1_000, 3_000})
void testLoadManyAfterInsert(final int messageCount) {
final UUID destinationUuid = UUID.randomUUID();
final int destinationDeviceId = random.nextInt(255) + 1;
final List<MessageProtos.Envelope> messages = new ArrayList<>(messageCount);
for (int i = 0; i < messageCount; i++) {
messages.add(MessageHelper.createMessage(UUID.randomUUID(), 1, destinationUuid, (i + 1L) * 1000, "message " + i));
}
messagesDynamoDb.store(messages, destinationUuid, destinationDeviceId);
final Publisher<?> fetchedMessages = messagesDynamoDb.load(destinationUuid, destinationDeviceId, null);
final long firstRequest = Math.min(10, messageCount);
StepVerifier.setDefaultTimeout(Duration.ofSeconds(15));
StepVerifier.Step<?> step = StepVerifier.create(fetchedMessages, 0)
.expectSubscription()
.thenRequest(firstRequest)
.expectNextCount(firstRequest);
if (messageCount > firstRequest) {
step = step.thenRequest(messageCount)
.expectNextCount(messageCount - firstRequest);
}
step.thenCancel()
.verify();
}
@Test
void testLimitedLoad() {
final int messageCount = 200;
final UUID destinationUuid = UUID.randomUUID();
final int destinationDeviceId = random.nextInt(255) + 1;
final List<MessageProtos.Envelope> messages = new ArrayList<>(messageCount);
for (int i = 0; i < messageCount; i++) {
messages.add(MessageHelper.createMessage(UUID.randomUUID(), 1, destinationUuid, (i + 1L) * 1000, "message " + i));
}
messagesDynamoDb.store(messages, destinationUuid, destinationDeviceId);
final int messageLoadLimit = 100;
final int halfOfMessageLoadLimit = messageLoadLimit / 2;
final Publisher<?> fetchedMessages = messagesDynamoDb.load(destinationUuid, destinationDeviceId, messageLoadLimit);
StepVerifier.setDefaultTimeout(Duration.ofSeconds(10));
final AtomicInteger messagesRemaining = new AtomicInteger(messageLoadLimit);
StepVerifier.create(fetchedMessages, 0)
.expectSubscription()
.thenRequest(halfOfMessageLoadLimit)
.expectNextCount(halfOfMessageLoadLimit)
// the first 100 should be fetched and buffered, but further requests should fail
.then(() -> dynamoDbExtension.stopServer())
.thenRequest(halfOfMessageLoadLimit)
.expectNextCount(halfOfMessageLoadLimit)
// weve consumed all the buffered messages, so a single request will fail
.thenRequest(1)
.expectError()
.verify();
}
@Test
void testDeleteForDestination() {
final UUID destinationUuid = UUID.randomUUID();
@ -96,18 +186,18 @@ class MessagesDynamoDbTest {
messagesDynamoDb.store(List.of(MESSAGE2), secondDestinationUuid, 1);
messagesDynamoDb.store(List.of(MESSAGE3), destinationUuid, 2);
assertThat(messagesDynamoDb.load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
assertThat(load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
.element(0).isEqualTo(MESSAGE1);
assertThat(messagesDynamoDb.load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
assertThat(load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
.element(0).isEqualTo(MESSAGE3);
assertThat(messagesDynamoDb.load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull()
assertThat(load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull()
.hasSize(1).element(0).isEqualTo(MESSAGE2);
messagesDynamoDb.deleteAllMessagesForAccount(destinationUuid);
assertThat(messagesDynamoDb.load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().isEmpty();
assertThat(messagesDynamoDb.load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().isEmpty();
assertThat(messagesDynamoDb.load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull()
assertThat(load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().isEmpty();
assertThat(load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().isEmpty();
assertThat(load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull()
.hasSize(1).element(0).isEqualTo(MESSAGE2);
}
@ -119,71 +209,79 @@ class MessagesDynamoDbTest {
messagesDynamoDb.store(List.of(MESSAGE2), secondDestinationUuid, 1);
messagesDynamoDb.store(List.of(MESSAGE3), destinationUuid, 2);
assertThat(messagesDynamoDb.load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
assertThat(load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
.element(0).isEqualTo(MESSAGE1);
assertThat(messagesDynamoDb.load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
assertThat(load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
.element(0).isEqualTo(MESSAGE3);
assertThat(messagesDynamoDb.load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull()
assertThat(load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull()
.hasSize(1).element(0).isEqualTo(MESSAGE2);
messagesDynamoDb.deleteAllMessagesForDevice(destinationUuid, 2);
assertThat(messagesDynamoDb.load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
assertThat(load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
.element(0).isEqualTo(MESSAGE1);
assertThat(messagesDynamoDb.load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().isEmpty();
assertThat(messagesDynamoDb.load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull()
assertThat(load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().isEmpty();
assertThat(load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull()
.hasSize(1).element(0).isEqualTo(MESSAGE2);
}
@Test
void testDeleteMessageByDestinationAndGuid() {
void testDeleteMessageByDestinationAndGuid() throws Exception {
final UUID destinationUuid = UUID.randomUUID();
final UUID secondDestinationUuid = UUID.randomUUID();
messagesDynamoDb.store(List.of(MESSAGE1), destinationUuid, 1);
messagesDynamoDb.store(List.of(MESSAGE2), secondDestinationUuid, 1);
messagesDynamoDb.store(List.of(MESSAGE3), destinationUuid, 2);
assertThat(messagesDynamoDb.load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
assertThat(load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
.element(0).isEqualTo(MESSAGE1);
assertThat(messagesDynamoDb.load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
assertThat(load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
.element(0).isEqualTo(MESSAGE3);
assertThat(messagesDynamoDb.load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull()
assertThat(load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull()
.hasSize(1).element(0).isEqualTo(MESSAGE2);
messagesDynamoDb.deleteMessageByDestinationAndGuid(secondDestinationUuid,
UUID.fromString(MESSAGE2.getServerGuid()));
UUID.fromString(MESSAGE2.getServerGuid())).get(5, TimeUnit.SECONDS);
assertThat(messagesDynamoDb.load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
assertThat(load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
.element(0).isEqualTo(MESSAGE1);
assertThat(messagesDynamoDb.load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
assertThat(load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
.element(0).isEqualTo(MESSAGE3);
assertThat(messagesDynamoDb.load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull()
assertThat(load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull()
.isEmpty();
}
@Test
void testDeleteSingleMessage() {
void testDeleteSingleMessage() throws Exception {
final UUID destinationUuid = UUID.randomUUID();
final UUID secondDestinationUuid = UUID.randomUUID();
messagesDynamoDb.store(List.of(MESSAGE1), destinationUuid, 1);
messagesDynamoDb.store(List.of(MESSAGE2), secondDestinationUuid, 1);
messagesDynamoDb.store(List.of(MESSAGE3), destinationUuid, 2);
assertThat(messagesDynamoDb.load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
assertThat(load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
.element(0).isEqualTo(MESSAGE1);
assertThat(messagesDynamoDb.load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
assertThat(load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
.element(0).isEqualTo(MESSAGE3);
assertThat(messagesDynamoDb.load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull()
assertThat(load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull()
.hasSize(1).element(0).isEqualTo(MESSAGE2);
messagesDynamoDb.deleteMessage(secondDestinationUuid, 1,
UUID.fromString(MESSAGE2.getServerGuid()), MESSAGE2.getServerTimestamp());
UUID.fromString(MESSAGE2.getServerGuid()), MESSAGE2.getServerTimestamp()).get(1, TimeUnit.SECONDS);
assertThat(messagesDynamoDb.load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
assertThat(load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
.element(0).isEqualTo(MESSAGE1);
assertThat(messagesDynamoDb.load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
assertThat(load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
.element(0).isEqualTo(MESSAGE3);
assertThat(messagesDynamoDb.load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull()
assertThat(load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull()
.isEmpty();
}
private List<MessageProtos.Envelope> load(final UUID destinationUuid, final long destinationDeviceId,
final int count) {
return Flux.from(messagesDynamoDb.load(destinationUuid, destinationDeviceId, count))
.take(count, true)
.collectList()
.block();
}
}

View File

@ -14,13 +14,11 @@ import static org.mockito.Mockito.verifyNoMoreInteractions;
import java.util.UUID;
import org.junit.jupiter.api.Test;
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
import org.whispersystems.textsecuregcm.push.PushLatencyManager;
class MessagesManagerTest {
private final MessagesDynamoDb messagesDynamoDb = mock(MessagesDynamoDb.class);
private final MessagesCache messagesCache = mock(MessagesCache.class);
private final PushLatencyManager pushLatencyManager = mock(PushLatencyManager.class);
private final ReportMessageManager reportMessageManager = mock(ReportMessageManager.class);
private final MessagesManager messagesManager = new MessagesManager(messagesDynamoDb, messagesCache,

View File

@ -41,7 +41,7 @@ public class ProfilesManagerTest {
void setUp() {
//noinspection unchecked
commands = mock(RedisAdvancedClusterCommands.class);
final FaultTolerantRedisCluster cacheCluster = RedisClusterHelper.buildMockRedisCluster(commands);
final FaultTolerantRedisCluster cacheCluster = RedisClusterHelper.builder().stringCommands(commands).build();
profiles = mock(Profiles.class);

View File

@ -0,0 +1,28 @@
/*
* Copyright 2022 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.tests.util;
import com.google.protobuf.ByteString;
import java.nio.charset.StandardCharsets;
import java.util.UUID;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
public class MessageHelper {
public static MessageProtos.Envelope createMessage(UUID senderUuid, final int senderDeviceId, UUID destinationUuid,
long timestamp, String content) {
return MessageProtos.Envelope.newBuilder()
.setServerGuid(UUID.randomUUID().toString())
.setType(MessageProtos.Envelope.Type.CIPHERTEXT)
.setTimestamp(timestamp)
.setServerTimestamp(0)
.setSourceUuid(senderUuid.toString())
.setSourceDevice(senderDeviceId)
.setDestinationUuid(destinationUuid.toString())
.setContent(ByteString.copyFrom(content.getBytes(StandardCharsets.UTF_8)))
.build();
}
}

View File

@ -5,70 +5,118 @@
package org.whispersystems.textsecuregcm.tests.util;
import io.lettuce.core.cluster.api.StatefulRedisClusterConnection;
import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
import java.util.function.Consumer;
import java.util.function.Function;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import io.lettuce.core.cluster.api.StatefulRedisClusterConnection;
import io.lettuce.core.cluster.api.async.RedisAdvancedClusterAsyncCommands;
import io.lettuce.core.cluster.api.reactive.RedisAdvancedClusterReactiveCommands;
import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands;
import java.util.function.Consumer;
import java.util.function.Function;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
public class RedisClusterHelper {
@SuppressWarnings("unchecked")
public static FaultTolerantRedisCluster buildMockRedisCluster(final RedisAdvancedClusterCommands<String, String> stringCommands) {
return buildMockRedisCluster(stringCommands, mock(RedisAdvancedClusterCommands.class));
public static RedisClusterHelper.Builder builder() {
return new Builder();
}
@SuppressWarnings("unchecked")
private static FaultTolerantRedisCluster buildMockRedisCluster(
final RedisAdvancedClusterCommands<String, String> stringCommands,
final RedisAdvancedClusterCommands<byte[], byte[]> binaryCommands,
final RedisAdvancedClusterAsyncCommands<byte[], byte[]> binaryAsyncCommands,
final RedisAdvancedClusterReactiveCommands<byte[], byte[]> binaryReactiveCommands) {
final FaultTolerantRedisCluster cluster = mock(FaultTolerantRedisCluster.class);
final StatefulRedisClusterConnection<String, String> stringConnection = mock(StatefulRedisClusterConnection.class);
final StatefulRedisClusterConnection<byte[], byte[]> binaryConnection = mock(StatefulRedisClusterConnection.class);
when(stringConnection.sync()).thenReturn(stringCommands);
when(binaryConnection.sync()).thenReturn(binaryCommands);
when(binaryConnection.async()).thenReturn(binaryAsyncCommands);
when(binaryConnection.reactive()).thenReturn(binaryReactiveCommands);
when(cluster.withCluster(any(Function.class))).thenAnswer(invocation -> {
return invocation.getArgument(0, Function.class).apply(stringConnection);
});
doAnswer(invocation -> {
invocation.getArgument(0, Consumer.class).accept(stringConnection);
return null;
}).when(cluster).useCluster(any(Consumer.class));
when(cluster.withCluster(any(Function.class))).thenAnswer(invocation -> {
return invocation.getArgument(0, Function.class).apply(stringConnection);
});
doAnswer(invocation -> {
invocation.getArgument(0, Consumer.class).accept(stringConnection);
return null;
}).when(cluster).useCluster(any(Consumer.class));
when(cluster.withBinaryCluster(any(Function.class))).thenAnswer(invocation -> {
return invocation.getArgument(0, Function.class).apply(binaryConnection);
});
doAnswer(invocation -> {
invocation.getArgument(0, Consumer.class).accept(binaryConnection);
return null;
}).when(cluster).useBinaryCluster(any(Consumer.class));
when(cluster.withBinaryCluster(any(Function.class))).thenAnswer(invocation -> {
return invocation.getArgument(0, Function.class).apply(binaryConnection);
});
doAnswer(invocation -> {
invocation.getArgument(0, Consumer.class).accept(binaryConnection);
return null;
}).when(cluster).useBinaryCluster(any(Consumer.class));
return cluster;
}
@SuppressWarnings("unchecked")
public static class Builder {
private RedisAdvancedClusterCommands<String, String> stringCommands = mock(RedisAdvancedClusterCommands.class);
private RedisAdvancedClusterCommands<byte[], byte[]> binaryCommands = mock(RedisAdvancedClusterCommands.class);
private RedisAdvancedClusterAsyncCommands<byte[], byte[]> binaryAsyncCommands = mock(
RedisAdvancedClusterAsyncCommands.class);
private RedisAdvancedClusterReactiveCommands<byte[], byte[]> binaryReactiveCommands = mock(
RedisAdvancedClusterReactiveCommands.class);
private Builder() {
}
@SuppressWarnings("unchecked")
public static FaultTolerantRedisCluster buildMockRedisCluster(final RedisAdvancedClusterCommands<String, String> stringCommands, final RedisAdvancedClusterCommands<byte[], byte[]> binaryCommands) {
final FaultTolerantRedisCluster cluster = mock(FaultTolerantRedisCluster.class);
final StatefulRedisClusterConnection<String, String> stringConnection = mock(StatefulRedisClusterConnection.class);
final StatefulRedisClusterConnection<byte[], byte[]> binaryConnection = mock(StatefulRedisClusterConnection.class);
when(stringConnection.sync()).thenReturn(stringCommands);
when(binaryConnection.sync()).thenReturn(binaryCommands);
when(cluster.withCluster(any(Function.class))).thenAnswer(invocation -> {
return invocation.getArgument(0, Function.class).apply(stringConnection);
});
doAnswer(invocation -> {
invocation.getArgument(0, Consumer.class).accept(stringConnection);
return null;
}).when(cluster).useCluster(any(Consumer.class));
when(cluster.withCluster(any(Function.class))).thenAnswer(invocation -> {
return invocation.getArgument(0, Function.class).apply(stringConnection);
});
doAnswer(invocation -> {
invocation.getArgument(0, Consumer.class).accept(stringConnection);
return null;
}).when(cluster).useCluster(any(Consumer.class));
when(cluster.withBinaryCluster(any(Function.class))).thenAnswer(invocation -> {
return invocation.getArgument(0, Function.class).apply(binaryConnection);
});
doAnswer(invocation -> {
invocation.getArgument(0, Consumer.class).accept(binaryConnection);
return null;
}).when(cluster).useBinaryCluster(any(Consumer.class));
when(cluster.withBinaryCluster(any(Function.class))).thenAnswer(invocation -> {
return invocation.getArgument(0, Function.class).apply(binaryConnection);
});
doAnswer(invocation -> {
invocation.getArgument(0, Consumer.class).accept(binaryConnection);
return null;
}).when(cluster).useBinaryCluster(any(Consumer.class));
return cluster;
public Builder stringCommands(final RedisAdvancedClusterCommands<String, String> stringCommands) {
this.stringCommands = stringCommands;
return this;
}
public Builder binaryCommands(final RedisAdvancedClusterCommands<byte[], byte[]> binaryCommands) {
this.binaryCommands = binaryCommands;
return this;
}
public Builder binaryAsyncCommands(final RedisAdvancedClusterAsyncCommands<byte[], byte[]> binaryAsyncCommands) {
this.binaryAsyncCommands = binaryAsyncCommands;
return this;
}
public Builder binaryReactiveCommands(
final RedisAdvancedClusterReactiveCommands<byte[], byte[]> binaryReactiveCommands) {
this.binaryReactiveCommands = binaryReactiveCommands;
return this;
}
public FaultTolerantRedisCluster build() {
return RedisClusterHelper.buildMockRedisCluster(stringCommands, binaryCommands, binaryAsyncCommands,
binaryReactiveCommands);
}
}
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2013-2021 Signal Messenger, LLC
* Copyright 2013-2022 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
@ -22,6 +22,7 @@ import static org.mockito.Mockito.when;
import com.google.protobuf.ByteString;
import com.google.protobuf.InvalidProtocolBufferException;
import java.io.IOException;
import java.time.Clock;
import java.time.Duration;
import java.util.ArrayList;
import java.util.List;
@ -36,8 +37,10 @@ import java.util.concurrent.atomic.AtomicBoolean;
import org.apache.commons.lang3.RandomStringUtils;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.CsvSource;
import org.junit.jupiter.params.provider.ValueSource;
import org.mockito.ArgumentCaptor;
import org.mockito.stubbing.Answer;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
@ -56,6 +59,7 @@ import org.whispersystems.textsecuregcm.tests.util.MessagesDynamoDbExtension;
import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.websocket.WebSocketClient;
import org.whispersystems.websocket.messages.WebSocketResponseMessage;
import reactor.core.scheduler.Schedulers;
class WebSocketConnectionIntegrationTest {
@ -65,16 +69,13 @@ class WebSocketConnectionIntegrationTest {
@RegisterExtension
static final RedisClusterExtension REDIS_CLUSTER_EXTENSION = RedisClusterExtension.builder().build();
private static final int SEND_FUTURES_TIMEOUT_MILLIS = 100;
private ExecutorService executorService;
private ExecutorService sharedExecutorService;
private MessagesDynamoDb messagesDynamoDb;
private MessagesCache messagesCache;
private ReportMessageManager reportMessageManager;
private Account account;
private Device device;
private WebSocketClient webSocketClient;
private WebSocketConnection webSocketConnection;
private ScheduledExecutorService retrySchedulingExecutor;
private long serialTimestamp = System.currentTimeMillis();
@ -82,11 +83,12 @@ class WebSocketConnectionIntegrationTest {
@BeforeEach
void setUp() throws Exception {
executorService = Executors.newSingleThreadExecutor();
sharedExecutorService = Executors.newSingleThreadExecutor();
messagesCache = new MessagesCache(REDIS_CLUSTER_EXTENSION.getRedisCluster(),
REDIS_CLUSTER_EXTENSION.getRedisCluster(), executorService);
messagesDynamoDb = new MessagesDynamoDb(dynamoDbExtension.getDynamoDbClient(), MessagesDynamoDbExtension.TABLE_NAME,
Duration.ofDays(7));
REDIS_CLUSTER_EXTENSION.getRedisCluster(), Clock.systemUTC(), sharedExecutorService, sharedExecutorService);
messagesDynamoDb = new MessagesDynamoDb(dynamoDbExtension.getDynamoDbClient(),
dynamoDbExtension.getDynamoDbAsyncClient(), MessagesDynamoDbExtension.TABLE_NAME, Duration.ofDays(7),
sharedExecutorService);
reportMessageManager = mock(ReportMessageManager.class);
account = mock(Account.class);
device = mock(Device.class);
@ -96,30 +98,36 @@ class WebSocketConnectionIntegrationTest {
when(account.getNumber()).thenReturn("+18005551234");
when(account.getUuid()).thenReturn(UUID.randomUUID());
when(device.getId()).thenReturn(1L);
webSocketConnection = new WebSocketConnection(
mock(ReceiptSender.class),
new MessagesManager(messagesDynamoDb, messagesCache, reportMessageManager),
new AuthenticatedAccount(() -> new Pair<>(account, device)),
device,
webSocketClient,
SEND_FUTURES_TIMEOUT_MILLIS,
retrySchedulingExecutor);
}
@AfterEach
void tearDown() throws Exception {
executorService.shutdown();
executorService.awaitTermination(2, TimeUnit.SECONDS);
sharedExecutorService.shutdown();
sharedExecutorService.awaitTermination(2, TimeUnit.SECONDS);
retrySchedulingExecutor.shutdown();
retrySchedulingExecutor.awaitTermination(2, TimeUnit.SECONDS);
}
@Test
void testProcessStoredMessages() {
final int persistedMessageCount = 207;
final int cachedMessageCount = 173;
@ParameterizedTest
@CsvSource({
"207, 173, true",
"207, 173, false",
"323, 0, true",
"323, 0, false",
"0, 221, true",
"0, 221, false",
})
void testProcessStoredMessages(final int persistedMessageCount, final int cachedMessageCount,
final boolean useReactive) {
final WebSocketConnection webSocketConnection = new WebSocketConnection(
mock(ReceiptSender.class),
new MessagesManager(messagesDynamoDb, messagesCache, reportMessageManager),
new AuthenticatedAccount(() -> new Pair<>(account, device)),
device,
webSocketClient,
retrySchedulingExecutor,
useReactive);
final List<MessageProtos.Envelope> expectedMessages = new ArrayList<>(persistedMessageCount + cachedMessageCount);
@ -150,8 +158,8 @@ class WebSocketConnectionIntegrationTest {
final AtomicBoolean queueCleared = new AtomicBoolean(false);
when(successResponse.getStatus()).thenReturn(200);
when(webSocketClient.sendRequest(eq("PUT"), eq("/api/v1/message"), anyList(), any())).thenReturn(
CompletableFuture.completedFuture(successResponse));
when(webSocketClient.sendRequest(eq("PUT"), eq("/api/v1/message"), anyList(), any()))
.thenReturn(CompletableFuture.completedFuture(successResponse));
when(webSocketClient.sendRequest(eq("PUT"), eq("/api/v1/queue/empty"), anyList(), any())).thenAnswer(
(Answer<CompletableFuture<WebSocketResponseMessage>>) invocation -> {
@ -194,8 +202,18 @@ class WebSocketConnectionIntegrationTest {
});
}
@Test
void testProcessStoredMessagesClientClosed() {
@ParameterizedTest
@ValueSource(booleans = {true, false})
void testProcessStoredMessagesClientClosed(final boolean useReactive) {
final WebSocketConnection webSocketConnection = new WebSocketConnection(
mock(ReceiptSender.class),
new MessagesManager(messagesDynamoDb, messagesCache, reportMessageManager),
new AuthenticatedAccount(() -> new Pair<>(account, device)),
device,
webSocketClient,
retrySchedulingExecutor,
useReactive);
final int persistedMessageCount = 207;
final int cachedMessageCount = 173;
@ -250,8 +268,20 @@ class WebSocketConnectionIntegrationTest {
});
}
@Test
void testProcessStoredMessagesSendFutureTimeout() {
@ParameterizedTest
@ValueSource(booleans = {true, false})
void testProcessStoredMessagesSendFutureTimeout(final boolean useReactive) {
final WebSocketConnection webSocketConnection = new WebSocketConnection(
mock(ReceiptSender.class),
new MessagesManager(messagesDynamoDb, messagesCache, reportMessageManager),
new AuthenticatedAccount(() -> new Pair<>(account, device)),
device,
webSocketClient,
100, // use a very short timeout, so that this test completes quickly
retrySchedulingExecutor,
useReactive,
Schedulers.boundedElastic());
final int persistedMessageCount = 207;
final int cachedMessageCount = 173;
@ -346,4 +376,5 @@ class WebSocketConnectionIntegrationTest {
.setDestinationUuid(UUID.randomUUID().toString())
.build();
}
}

View File

@ -1,11 +1,14 @@
<configuration>
<appender name="STDOUT" class="ch.qos.logback.core.ConsoleAppender">
<encoder>
<pattern>%d{HH:mm:ss.SSS} [%thread] %-5level %logger{36} - %msg%n</pattern>
</encoder>
</appender>
<appender name="STDOUT" class="ch.qos.logback.core.ConsoleAppender">
<encoder>
<pattern>%d{HH:mm:ss.SSS} [%thread] %-5level %logger{36} - %msg%n</pattern>
</encoder>
</appender>
<root level="warn">
<appender-ref ref="STDOUT" />
</root>
<root level="warn">
<appender-ref ref="STDOUT"/>
</root>
<!-- uncomment and combine with .log() in StepVerifier for more insight into reactor operations -->
<!-- <logger name="reactor" level="debug"/> -->
</configuration>