From 1c617284f390dfd5943d77bf5e41fb240195f1cd Mon Sep 17 00:00:00 2001 From: Chris Eager Date: Fri, 6 Sep 2024 17:22:41 -0500 Subject: [PATCH] Add MRM views experiment to `MessagesCache.getMessagesToPersist()` --- .../storage/MessagePersister.java | 19 ++----- .../textsecuregcm/storage/MessagesCache.java | 50 ++++++++++++------ .../MessagePersisterServiceCommand.java | 5 +- .../MessagePersisterIntegrationTest.java | 4 +- .../storage/MessagePersisterTest.java | 10 +--- .../storage/MessagesCacheTest.java | 52 ++++++++++++++++++- 6 files changed, 90 insertions(+), 50 deletions(-) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagePersister.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagePersister.java index 23045769..ec86f95d 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagePersister.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagePersister.java @@ -15,22 +15,14 @@ import io.micrometer.core.instrument.Metrics; import io.micrometer.core.instrument.Timer; import java.time.Duration; import java.time.Instant; -import java.util.Comparator; import java.util.List; import java.util.Optional; import java.util.UUID; -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.TimeUnit; -import org.apache.commons.lang3.StringUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; import org.whispersystems.textsecuregcm.entities.MessageProtos; -import org.whispersystems.textsecuregcm.push.ClientPresenceManager; import org.whispersystems.textsecuregcm.util.Util; -import reactor.core.publisher.Flux; -import reactor.util.function.Tuple2; -import reactor.util.function.Tuples; import software.amazon.awssdk.services.dynamodb.model.ItemCollectionSizeLimitExceededException; public class MessagePersister implements Managed { @@ -38,8 +30,6 @@ public class MessagePersister implements Managed { private final MessagesCache messagesCache; private final MessagesManager messagesManager; private final AccountsManager accountsManager; - private final ClientPresenceManager clientPresenceManager; - private final KeysManager keysManager; private final Duration persistDelay; @@ -72,17 +62,14 @@ public class MessagePersister implements Managed { private static final Logger logger = LoggerFactory.getLogger(MessagePersister.class); public MessagePersister(final MessagesCache messagesCache, final MessagesManager messagesManager, - final AccountsManager accountsManager, final ClientPresenceManager clientPresenceManager, - final KeysManager keysManager, - final DynamicConfigurationManager dynamicConfigurationManager, - final Duration persistDelay, + final AccountsManager accountsManager, + final DynamicConfigurationManager dynamicConfigurationManager, final Duration persistDelay, final int dedicatedProcessWorkerThreadCount ) { + this.messagesCache = messagesCache; this.messagesManager = messagesManager; this.accountsManager = accountsManager; - this.clientPresenceManager = clientPresenceManager; - this.keysManager = keysManager; this.persistDelay = persistDelay; this.workerThreads = new Thread[dedicatedProcessWorkerThreadCount]; diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java index 0797e3f1..d791bbea 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java @@ -11,7 +11,6 @@ import com.google.common.annotations.VisibleForTesting; import com.google.protobuf.ByteString; import com.google.protobuf.InvalidProtocolBufferException; import io.dropwizard.lifecycle.Managed; -import io.lettuce.core.ScoredValue; import io.lettuce.core.ZAddArgs; import io.lettuce.core.cluster.SlotHash; import io.lettuce.core.cluster.models.partitions.RedisClusterNode; @@ -493,8 +492,7 @@ public class MessagesCache extends RedisClusterPubSubAdapter imp .subscribe(); } - private Mono, Long>> getNextMessagePage(final UUID destinationUuid, - final byte destinationDevice, + private Mono, Long>> getNextMessagePage(final UUID destinationUuid, final byte destinationDevice, long messageId) { return getItemsScript.execute(destinationUuid, destinationDevice, PAGE_SIZE, messageId) @@ -520,22 +518,40 @@ public class MessagesCache extends RedisClusterPubSubAdapter imp @VisibleForTesting List getMessagesToPersist(final UUID accountUuid, final byte destinationDevice, final int limit) { - return getMessagesTimer.record(() -> { - final List> scoredMessages = redisCluster.withBinaryCluster( - connection -> connection.sync() - .zrangeWithScores(getMessageQueueKey(accountUuid, destinationDevice), 0, limit)); - final List envelopes = new ArrayList<>(scoredMessages.size()); - for (final ScoredValue scoredMessage : scoredMessages) { - try { - envelopes.add(MessageProtos.Envelope.parseFrom(scoredMessage.getValue())); - } catch (InvalidProtocolBufferException e) { - logger.warn("Failed to parse envelope", e); - } - } + final Timer.Sample sample = Timer.start(); - return envelopes; - }); + final List messages = redisCluster.withBinaryCluster(connection -> + connection.sync().zrange(getMessageQueueKey(accountUuid, destinationDevice), 0, limit)); + + return Flux.fromIterable(messages) + .mapNotNull(message -> { + try { + return MessageProtos.Envelope.parseFrom(message); + } catch (InvalidProtocolBufferException e) { + logger.warn("Failed to parse envelope", e); + return null; + } + }) + .concatMap(message -> { + final Mono messageMono; + if (message.hasSharedMrmKey()) { + final Mono experimentMono = maybeRunMrmViewExperiment(message, accountUuid, destinationDevice); + + // mrm views phase 1: messageMono for sharedMrmKey is always Mono.just(), because messages always have content + // To avoid races, wait for the experiment to run, but ignore any errors + messageMono = experimentMono + .onErrorComplete() + .then(Mono.just(message.toBuilder().clearSharedMrmKey().build())); + } else { + messageMono = Mono.just(message); + } + + return messageMono; + }) + .collectList() + .doOnTerminate(() -> sample.stop(getMessagesTimer)) + .block(Duration.ofSeconds(5)); } public CompletableFuture clear(final UUID destinationUuid) { diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/workers/MessagePersisterServiceCommand.java b/service/src/main/java/org/whispersystems/textsecuregcm/workers/MessagePersisterServiceCommand.java index 1a89e444..25a10b82 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/workers/MessagePersisterServiceCommand.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/workers/MessagePersisterServiceCommand.java @@ -61,10 +61,7 @@ public class MessagePersisterServiceCommand extends ServerCommand { final UUID destinationUuid = invocation.getArgument(0); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheTest.java index 54061bc6..700b16a8 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheTest.java @@ -206,8 +206,7 @@ class MessagesCacheTest { .collect(Collectors.toList())).get(5, TimeUnit.SECONDS); assertEquals(messagesToRemove.stream().map(RemovedMessage::fromEnvelope).toList(), removedMessages); - assertEquals(messagesToPreserve, - messagesCache.getMessagesToPersist(DESTINATION_UUID, DESTINATION_DEVICE_ID, messageCount)); + assertEquals(messagesToPreserve, get(DESTINATION_UUID, DESTINATION_DEVICE_ID, messageCount)); } @Test @@ -620,6 +619,55 @@ class MessagesCacheTest { }, "Shared MRM data should be deleted asynchronously"); } + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void testGetMessagesToPersist(final boolean sharedMrmKeyPresent) throws Exception { + final UUID destinationUuid = UUID.randomUUID(); + final byte deviceId = 1; + + final UUID messageGuid = UUID.randomUUID(); + final MessageProtos.Envelope message = generateRandomMessage(destinationUuid, true); + + messagesCache.insert(messageGuid, destinationUuid, deviceId, message); + + final SealedSenderMultiRecipientMessage mrm = generateRandomMrmMessage( + new AciServiceIdentifier(destinationUuid), deviceId); + + final byte[] sharedMrmDataKey; + if (sharedMrmKeyPresent) { + sharedMrmDataKey = messagesCache.insertSharedMultiRecipientMessagePayload(mrm); + } else { + sharedMrmDataKey = new byte[]{1}; + } + + final UUID mrmMessageGuid = UUID.randomUUID(); + final MessageProtos.Envelope mrmMessage = generateRandomMessage(mrmMessageGuid, true) + .toBuilder() + // clear some things added by the helper + .clearServerGuid() + // mrm views phase 1: messages have content + .setContent( + ByteString.copyFrom(mrm.messageForRecipient(mrm.getRecipients().get(new ServiceId.Aci(destinationUuid))))) + .setSharedMrmKey(ByteString.copyFrom(sharedMrmDataKey)) + .build(); + messagesCache.insert(mrmMessageGuid, destinationUuid, deviceId, mrmMessage); + + final List messages = get(destinationUuid, deviceId, 100); + + assertEquals(2, messages.size()); + + assertEquals(message.toBuilder() + .setServerGuid(messageGuid.toString()) + .build(), + messages.getFirst()); + + assertEquals(mrmMessage.toBuilder(). + clearSharedMrmKey(). + setServerGuid(mrmMessageGuid.toString()) + .build(), + messages.getLast()); + } + private List get(final UUID destinationUuid, final byte destinationDeviceId, final int messageCount) { return Flux.from(messagesCache.get(destinationUuid, destinationDeviceId))