0
0
mirror of https://github.com/signalapp/Signal-Server.git synced 2024-09-19 19:42:18 +02:00

Multi-recipient message views

This adds support for storing multi-recipient message payloads and recipient views in Redis, and only fanning out on delivery or persistence. Phase 1: confirm storage and retrieval correctness.
This commit is contained in:
Chris Eager 2024-09-04 13:58:20 -05:00 committed by GitHub
parent d78c8370b6
commit 11601fd091
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
50 changed files with 1544 additions and 328 deletions

View File

@ -632,7 +632,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
keyspaceNotificationDispatchExecutor);
ProfilesManager profilesManager = new ProfilesManager(profiles, cacheCluster);
MessagesCache messagesCache = new MessagesCache(messagesCluster, keyspaceNotificationDispatchExecutor,
messageDeliveryScheduler, messageDeletionAsyncExecutor, clock);
messageDeliveryScheduler, messageDeletionAsyncExecutor, clock, dynamicConfigurationManager);
ClientReleaseManager clientReleaseManager = new ClientReleaseManager(clientReleases,
recurringJobExecutor,
config.getClientReleaseConfiguration().refreshInterval(),

View File

@ -27,5 +27,4 @@ public class MessageCacheConfiguration {
public int getPersistDelayMinutes() {
return persistDelayMinutes;
}
}

View File

@ -5,21 +5,9 @@
package org.whispersystems.textsecuregcm.configuration.dynamic;
import java.util.List;
import javax.validation.constraints.NotNull;
public record DynamicMessagesConfiguration(@NotNull List<DynamoKeyScheme> dynamoKeySchemes) {
public enum DynamoKeyScheme {
TRADITIONAL,
LAZY_DELETION;
}
public record DynamicMessagesConfiguration(boolean storeSharedMrmData, boolean mrmViewExperimentEnabled) {
public DynamicMessagesConfiguration() {
this(List.of(DynamoKeyScheme.TRADITIONAL));
}
public DynamoKeyScheme writeKeyScheme() {
return dynamoKeySchemes().getLast();
this(false, false);
}
}

View File

@ -24,7 +24,6 @@ import io.swagger.v3.oas.annotations.Parameter;
import io.swagger.v3.oas.annotations.media.Content;
import io.swagger.v3.oas.annotations.media.Schema;
import io.swagger.v3.oas.annotations.responses.ApiResponse;
import java.security.MessageDigest;
import java.time.Clock;
import java.time.Duration;
@ -73,8 +72,8 @@ import javax.ws.rs.core.Response.Status;
import org.apache.commons.lang3.StringUtils;
import org.glassfish.jersey.server.ManagedAsync;
import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage;
import org.signal.libsignal.protocol.ServiceId;
import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage.Recipient;
import org.signal.libsignal.protocol.ServiceId;
import org.signal.libsignal.protocol.util.Pair;
import org.signal.libsignal.zkgroup.ServerSecretParams;
import org.signal.libsignal.zkgroup.VerificationFailedException;
@ -261,7 +260,7 @@ public class MessageController {
@Consumes(MediaType.APPLICATION_JSON)
@Produces(MediaType.APPLICATION_JSON)
@ManagedAsync
@Operation(
@Operation(
summary = "Send a message",
description = """
Deliver a message to a single recipient. May be authenticated or unauthenticated; if unauthenticated,
@ -309,9 +308,10 @@ public class MessageController {
if (groupSendToken != null) {
if (!source.isEmpty() || !accessKey.isEmpty()) {
throw new BadRequestException("Group send endorsement tokens should not be combined with other authentication");
throw new BadRequestException(
"Group send endorsement tokens should not be combined with other authentication");
} else if (isStory) {
throw new BadRequestException("Group send endorsement tokens should not be sent for story messages");
throw new BadRequestException("Group send endorsement tokens should not be sent for story messages");
}
}
@ -346,8 +346,7 @@ public class MessageController {
}
final Optional<byte[]> spamReportToken = switch (senderType) {
case SENDER_TYPE_IDENTIFIED ->
reportSpamTokenProvider.makeReportSpamToken(context, source.get(), destination);
case SENDER_TYPE_IDENTIFIED -> reportSpamTokenProvider.makeReportSpamToken(context, source.get(), destination);
default -> Optional.empty();
};
@ -470,7 +469,7 @@ public class MessageController {
throw new WebApplicationException(Response.status(409)
.type(MediaType.APPLICATION_JSON_TYPE)
.entity(new MismatchedDevices(e.getMissingDevices(),
e.getExtraDevices()))
e.getExtraDevices()))
.build());
} catch (StaleDevicesException e) {
throw new WebApplicationException(Response.status(410)
@ -621,27 +620,28 @@ public class MessageController {
Collection<AccountMismatchedDevices> accountMismatchedDevices = new ArrayList<>();
Collection<AccountStaleDevices> accountStaleDevices = new ArrayList<>();
recipients.values().forEach(recipient -> {
final Account account = recipient.account();
final Account account = recipient.account();
try {
DestinationDeviceValidator.validateCompleteDeviceList(account, recipient.deviceIdToRegistrationId().keySet(), Collections.emptySet());
try {
DestinationDeviceValidator.validateCompleteDeviceList(account, recipient.deviceIdToRegistrationId().keySet(),
Collections.emptySet());
DestinationDeviceValidator.validateRegistrationIds(
account,
recipient.deviceIdToRegistrationId().entrySet(),
Map.Entry<Byte, Short>::getKey,
e -> Integer.valueOf(e.getValue()),
recipient.serviceIdentifier().identityType() == IdentityType.PNI);
} catch (MismatchedDevicesException e) {
accountMismatchedDevices.add(
new AccountMismatchedDevices(
recipient.serviceIdentifier(),
new MismatchedDevices(e.getMissingDevices(), e.getExtraDevices())));
} catch (StaleDevicesException e) {
accountStaleDevices.add(
new AccountStaleDevices(recipient.serviceIdentifier(), new StaleDevices(e.getStaleDevices())));
}
});
DestinationDeviceValidator.validateRegistrationIds(
account,
recipient.deviceIdToRegistrationId().entrySet(),
Map.Entry<Byte, Short>::getKey,
e -> Integer.valueOf(e.getValue()),
recipient.serviceIdentifier().identityType() == IdentityType.PNI);
} catch (MismatchedDevicesException e) {
accountMismatchedDevices.add(
new AccountMismatchedDevices(
recipient.serviceIdentifier(),
new MismatchedDevices(e.getMissingDevices(), e.getExtraDevices())));
} catch (StaleDevicesException e) {
accountStaleDevices.add(
new AccountStaleDevices(recipient.serviceIdentifier(), new StaleDevices(e.getStaleDevices())));
}
});
if (!accountMismatchedDevices.isEmpty()) {
return Response
.status(409)
@ -667,6 +667,11 @@ public class MessageController {
}
try {
@Nullable final byte[] sharedMrmKey =
dynamicConfigurationManager.getConfiguration().getMessagesConfiguration().storeSharedMrmData()
? messagesManager.insertSharedMultiRecipientMessagePayload(multiRecipientMessage)
: null;
CompletableFuture.allOf(
recipients.values().stream()
.flatMap(recipientData -> {
@ -692,8 +697,7 @@ public class MessageController {
sentMessageCounter.increment();
sendCommonPayloadMessage(
destinationAccount, destinationDevice, recipientData.serviceIdentifier(), timestamp,
online,
isStory, isUrgent, payload);
online, isStory, isUrgent, payload, sharedMrmKey);
},
multiRecipientMessageExecutor));
})
@ -739,8 +743,8 @@ public class MessageController {
.filter(Predicate.not(Account::isUnrestrictedUnidentifiedAccess))
.map(account ->
account.getUnidentifiedAccessKey()
.filter(b -> b.length == keyLength)
.orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED)))
.filter(b -> b.length == keyLength)
.orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED)))
.reduce(new byte[keyLength],
(a, b) -> {
final byte[] xor = new byte[keyLength];
@ -828,23 +832,28 @@ public class MessageController {
auth.getAuthenticatedDevice(),
uuid,
null)
.thenAccept(maybeDeletedMessage -> {
maybeDeletedMessage.ifPresent(deletedMessage -> {
.thenAccept(maybeRemovedMessage -> maybeRemovedMessage.ifPresent(removedMessage -> {
WebSocketConnection.recordMessageDeliveryDuration(deletedMessage.getServerTimestamp(),
auth.getAuthenticatedDevice());
WebSocketConnection.recordMessageDeliveryDuration(removedMessage.serverTimestamp(),
auth.getAuthenticatedDevice());
if (deletedMessage.hasSourceUuid() && deletedMessage.getType() != Type.SERVER_DELIVERY_RECEIPT) {
if (removedMessage.sourceServiceId().isPresent()
&& removedMessage.envelopeType() != Type.SERVER_DELIVERY_RECEIPT) {
if (removedMessage.sourceServiceId().get() instanceof AciServiceIdentifier aciServiceIdentifier) {
try {
receiptSender.sendReceipt(
ServiceIdentifier.valueOf(deletedMessage.getDestinationUuid()), auth.getAuthenticatedDevice().getId(),
AciServiceIdentifier.valueOf(deletedMessage.getSourceUuid()), deletedMessage.getTimestamp());
receiptSender.sendReceipt(removedMessage.destinationServiceId(), auth.getAuthenticatedDevice().getId(),
aciServiceIdentifier, removedMessage.clientTimestamp());
} catch (Exception e) {
logger.warn("Failed to send delivery receipt", e);
}
} else {
// If source service ID is present and the envelope type is not a server delivery receipt, then
// the source service ID *should always* be an ACI -- PNIs are receive-only, so they can only be the
// "source" via server delivery receipts
logger.warn("Source service ID unexpectedly a PNI service ID");
}
});
})
}
}))
.thenApply(Util.ASYNC_EMPTY_RESPONSE);
}
@ -943,19 +952,25 @@ public class MessageController {
boolean online,
boolean story,
boolean urgent,
byte[] payload) {
byte[] payload,
@Nullable byte[] sharedMrmKey) {
final Envelope.Builder messageBuilder = Envelope.newBuilder();
final long serverTimestamp = System.currentTimeMillis();
messageBuilder
.setType(Type.UNIDENTIFIED_SENDER)
.setTimestamp(timestamp == 0 ? serverTimestamp : timestamp)
.setClientTimestamp(timestamp == 0 ? serverTimestamp : timestamp)
.setServerTimestamp(serverTimestamp)
.setContent(ByteString.copyFrom(payload))
.setStory(story)
.setUrgent(urgent)
.setDestinationUuid(serviceIdentifier.toServiceIdentifierString());
.setDestinationServiceId(serviceIdentifier.toServiceIdentifierString());
if (sharedMrmKey != null) {
messageBuilder.setSharedMrmKey(ByteString.copyFrom(sharedMrmKey));
}
// mrm views phase 1: always set content
messageBuilder.setContent(ByteString.copyFrom(payload));
messageSender.sendMessage(destinationAccount, destinationDevice, messageBuilder.build(), online);
}

View File

@ -31,15 +31,15 @@ public record IncomingMessage(int type, byte destinationDeviceId, int destinatio
final MessageProtos.Envelope.Builder envelopeBuilder = MessageProtos.Envelope.newBuilder();
envelopeBuilder.setType(envelopeType)
.setTimestamp(timestamp)
.setClientTimestamp(timestamp)
.setServerTimestamp(System.currentTimeMillis())
.setDestinationUuid(destinationIdentifier.toServiceIdentifierString())
.setDestinationServiceId(destinationIdentifier.toServiceIdentifierString())
.setStory(story)
.setUrgent(urgent);
if (sourceAccount != null && sourceDeviceId != null) {
envelopeBuilder
.setSourceUuid(new AciServiceIdentifier(sourceAccount.getUuid()).toServiceIdentifierString())
.setSourceServiceId(new AciServiceIdentifier(sourceAccount.getUuid()).toServiceIdentifierString())
.setSourceDevice(sourceDeviceId.intValue());
}

View File

@ -40,15 +40,15 @@ public record OutgoingMessageEntity(UUID guid,
public MessageProtos.Envelope toEnvelope() {
final MessageProtos.Envelope.Builder builder = MessageProtos.Envelope.newBuilder()
.setType(MessageProtos.Envelope.Type.forNumber(type()))
.setTimestamp(timestamp())
.setClientTimestamp(timestamp())
.setServerTimestamp(serverTimestamp())
.setDestinationUuid(destinationUuid().toServiceIdentifierString())
.setDestinationServiceId(destinationUuid().toServiceIdentifierString())
.setServerGuid(guid().toString())
.setStory(story)
.setUrgent(urgent);
if (sourceUuid() != null) {
builder.setSourceUuid(sourceUuid().toServiceIdentifierString());
builder.setSourceServiceId(sourceUuid().toServiceIdentifierString());
builder.setSourceDevice(sourceDevice());
}
@ -72,10 +72,10 @@ public record OutgoingMessageEntity(UUID guid,
return new OutgoingMessageEntity(
UUID.fromString(envelope.getServerGuid()),
envelope.getType().getNumber(),
envelope.getTimestamp(),
envelope.hasSourceUuid() ? ServiceIdentifier.valueOf(envelope.getSourceUuid()) : null,
envelope.getClientTimestamp(),
envelope.hasSourceServiceId() ? ServiceIdentifier.valueOf(envelope.getSourceServiceId()) : null,
envelope.getSourceDevice(),
envelope.hasDestinationUuid() ? ServiceIdentifier.valueOf(envelope.getDestinationUuid()) : null,
envelope.hasDestinationServiceId() ? ServiceIdentifier.valueOf(envelope.getDestinationServiceId()) : null,
envelope.hasUpdatedPni() ? UUID.fromString(envelope.getUpdatedPni()) : null,
envelope.getContent().toByteArray(),
envelope.getServerTimestamp(),

View File

@ -50,11 +50,11 @@ public final class MessageMetrics {
public void measureAccountEnvelopeUuidMismatches(final Account account,
final MessageProtos.Envelope envelope) {
if (envelope.hasDestinationUuid()) {
if (envelope.hasDestinationServiceId()) {
try {
measureAccountDestinationUuidMismatches(account, ServiceIdentifier.valueOf(envelope.getDestinationUuid()));
measureAccountDestinationUuidMismatches(account, ServiceIdentifier.valueOf(envelope.getDestinationServiceId()));
} catch (final IllegalArgumentException ignored) {
logger.warn("Envelope had invalid destination UUID: {}", envelope.getDestinationUuid());
logger.warn("Envelope had invalid destination UUID: {}", envelope.getDestinationServiceId());
}
}
}

View File

@ -92,7 +92,7 @@ public class MessageSender {
CLIENT_ONLINE_TAG_NAME, String.valueOf(clientPresent),
URGENT_TAG_NAME, String.valueOf(message.getUrgent()),
STORY_TAG_NAME, String.valueOf(message.getStory()),
SEALED_SENDER_TAG_NAME, String.valueOf(!message.hasSourceUuid()))
SEALED_SENDER_TAG_NAME, String.valueOf(!message.hasSourceServiceId()))
.increment();
}
}

View File

@ -45,10 +45,10 @@ public class ReceiptSender {
destinationAccount -> {
final Envelope.Builder message = Envelope.newBuilder()
.setServerTimestamp(System.currentTimeMillis())
.setSourceUuid(sourceIdentifier.toServiceIdentifierString())
.setSourceDevice((int) sourceDeviceId)
.setDestinationUuid(destinationIdentifier.toServiceIdentifierString())
.setTimestamp(messageId)
.setSourceServiceId(sourceIdentifier.toServiceIdentifierString())
.setSourceDevice(sourceDeviceId)
.setDestinationServiceId(destinationIdentifier.toServiceIdentifierString())
.setClientTimestamp(messageId)
.setType(Envelope.Type.SERVER_DELIVERY_RECEIPT)
.setUrgent(false);

View File

@ -138,12 +138,13 @@ public class ChangeNumberManager {
final long serverTimestamp = System.currentTimeMillis();
final Envelope envelope = Envelope.newBuilder()
.setType(Envelope.Type.forNumber(message.type()))
.setTimestamp(serverTimestamp)
.setClientTimestamp(serverTimestamp)
.setServerTimestamp(serverTimestamp)
.setDestinationUuid(new AciServiceIdentifier(sourceAndDestinationAccount.getUuid()).toServiceIdentifierString())
.setDestinationServiceId(
new AciServiceIdentifier(sourceAndDestinationAccount.getUuid()).toServiceIdentifierString())
.setContent(ByteString.copyFrom(contents.get()))
.setSourceUuid(new AciServiceIdentifier(sourceAndDestinationAccount.getUuid()).toServiceIdentifierString())
.setSourceDevice((int) Device.PRIMARY_ID)
.setSourceServiceId(new AciServiceIdentifier(sourceAndDestinationAccount.getUuid()).toServiceIdentifierString())
.setSourceDevice(Device.PRIMARY_ID)
.setUpdatedPni(sourceAndDestinationAccount.getPhoneNumberIdentifier().toString())
.setUrgent(true)
.build();

View File

@ -8,10 +8,10 @@ package org.whispersystems.textsecuregcm.storage;
import static com.codahale.metrics.MetricRegistry.name;
import com.google.common.annotations.VisibleForTesting;
import com.google.protobuf.ByteString;
import com.google.protobuf.InvalidProtocolBufferException;
import io.dropwizard.lifecycle.Managed;
import io.lettuce.core.ScoredValue;
import io.lettuce.core.ScriptOutputType;
import io.lettuce.core.ZAddArgs;
import io.lettuce.core.cluster.SlotHash;
import io.lettuce.core.cluster.models.partitions.RedisClusterNode;
@ -20,6 +20,7 @@ import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Timer;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.time.Clock;
import java.time.Duration;
@ -38,14 +39,17 @@ import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.locks.ReentrantLock;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import org.reactivestreams.Publisher;
import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicMessagesConfiguration;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.experiment.Experiment;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
import org.whispersystems.textsecuregcm.redis.ClusterLuaScript;
import org.whispersystems.textsecuregcm.redis.FaultTolerantPubSubConnection;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
import org.whispersystems.textsecuregcm.util.Pair;
@ -57,6 +61,62 @@ import reactor.core.publisher.Mono;
import reactor.core.scheduler.Scheduler;
import reactor.core.scheduler.Schedulers;
/**
* Manages short-term storage of messages in Redis. Messages are frequently delivered to their destination and deleted
* shortly after they reach the server, and this cache acts as a low-latency holding area for new messages, reducing
* load on higher-latency, longer-term storage systems. Redis in particular provides keyspace notifications, which act
* as a form of pub-sub notifications to alert listeners when new messages arrive.
* <p>
* The following structures are used:
* <dl>
* <dt>{@code queueKey}</code></dt>
* <dd>A sorted set of messages in a devices queue. A messages score is its queue-local message ID. See
* <a href="https://redis.io/docs/latest/develop/use/patterns/twitter-clone/#the-sorted-set-data-type">Redis.io: The
* Sorted Set data type</a> for background on scores and this data structure.</dd>
* <dt>{@code queueMetadataKey}</dt>
* <dd>A hash containing message guids and their queue-local message ID. It also contains a {@code counter} key, which is
* incremented to supply the next message ID. This is used to remove a message by GUID from {@code queueKey} by its
* local messageId.</dd>
* <dt>{@code sharedMrmKey}</dt>
* <dd>A hash containing a single multi-recipient message pending delivery. It contains:
* <ul>
* <li>{@code data} - the serialized SealedSenderMultiRecipientMessage data</li>
* <li>fields with each recipient device's view into the payload ({@link SealedSenderMultiRecipientMessage#serializedRecipientView(SealedSenderMultiRecipientMessage.Recipient)}</li>
* </ul>
* Note: this is shared among all of the message's recipients, and it may be located in any Redis shard. As each recipients
* message is delivered, its corresponding view is idempotently removed. When {@code data} is the only remaining
* field, the hash will be deleted.
* </dd>
* <dt>{@code queueLockKey}</dt>
* <dd>Used to indicate that a queue is being modified by the {@link MessagePersister} and that {@code get_items} should
* return an empty list.</dd>
* <dt>{@code queueTotalIndexKey}</dt>
* <dd>A sorted set of all queues in a shard. A queues score is the timestamp of its oldest message, which is used by
* the {@link MessagePersister} to prioritize queues to persist.</dd>
* </dl>
* <p>
* At a high level, the process is:
* <ol>
* <li>Insert: the queue metadata is queried for the next incremented message ID. The message data is inserted into
* the queue at that ID, and the message GUID is inserted in the queue metadata.</li>
* <li>Get: a batch of messages are retrieved from the queue, potentially with an after-message-ID offset.</li>
* <li>Remove: a set of messages are remove by GUID. For each GUID, the message ID is retrieved from the queue metadata,
* and then that single-value range is removed from the queue.</li>
* </ol>
* For multi-recipient messages (sometimes abbreviated MRM), there are similar operations on the common data during
* insert, get, and remove. MRM inserts must occur before individual queue inserts, while removal is considered
* best-effort, and uses key expiration as back-stop garbage collection.
* <p>
* For atomicity, many operations are implemented as Lua scripts that are executed on the Redis server using
* {@code EVAL}/{@code EVALSHA}.
*
* @see MessagesCacheInsertScript
* @see MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript
* @see MessagesCacheGetItemsScript
* @see MessagesCacheRemoveByGuidScript
* @see MessagesCacheRemoveRecipientViewFromMrmDataScript
* @see MessagesCacheRemoveQueueScript
*/
public class MessagesCache extends RedisClusterPubSubAdapter<String, String> implements Managed {
private final FaultTolerantRedisCluster redisCluster;
@ -69,17 +129,22 @@ public class MessagesCache extends RedisClusterPubSubAdapter<String, String> imp
// messageDeletionExecutorService wrapped into a reactor Scheduler
private final Scheduler messageDeletionScheduler;
private final ClusterLuaScript insertScript;
private final ClusterLuaScript removeByGuidScript;
private final ClusterLuaScript getItemsScript;
private final ClusterLuaScript removeQueueScript;
private final ClusterLuaScript getQueuesToPersistScript;
private final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager;
private final MessagesCacheInsertScript insertScript;
private final MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript insertMrmScript;
private final MessagesCacheRemoveByGuidScript removeByGuidScript;
private final MessagesCacheGetItemsScript getItemsScript;
private final MessagesCacheRemoveQueueScript removeQueueScript;
private final MessagesCacheGetQueuesToPersistScript getQueuesToPersistScript;
private final MessagesCacheRemoveRecipientViewFromMrmDataScript removeRecipientViewFromMrmDataScript;
private final ReentrantLock messageListenersLock = new ReentrantLock();
private final Map<String, MessageAvailabilityListener> messageListenersByQueueName = new HashMap<>();
private final Map<MessageAvailabilityListener, String> queueNamesByMessageListener = new IdentityHashMap<>();
private final Timer insertTimer = Metrics.timer(name(MessagesCache.class, "insert"));
private final Timer insertSharedMrmPayloadTimer = Metrics.timer(name(MessagesCache.class, "insertSharedMrmPayload"));
private final Timer getMessagesTimer = Metrics.timer(name(MessagesCache.class, "get"));
private final Timer getQueuesToPersistTimer = Metrics.timer(name(MessagesCache.class, "getQueuesToPersist"));
private final Timer removeByGuidTimer = Metrics.timer(name(MessagesCache.class, "removeByGuid"));
@ -95,6 +160,9 @@ public class MessagesCache extends RedisClusterPubSubAdapter<String, String> imp
name(MessagesCache.class, "messageAvailabilityListenerRemovedAfterAdd"));
private final Counter prunedStaleSubscriptionCounter = Metrics.counter(
name(MessagesCache.class, "prunedStaleSubscription"));
private final Counter mrmContentRetrievedCounter = Metrics.counter(name(MessagesCache.class, "mrmViewRetrieved"));
private final Counter sharedMrmDataKeyRemovedCounter = Metrics.counter(
name(MessagesCache.class, "sharedMrmKeyRemoved"));
static final String NEXT_SLOT_TO_PERSIST_KEY = "user_queue_persist_slot";
private static final byte[] LOCK_VALUE = "1".getBytes(StandardCharsets.UTF_8);
@ -102,16 +170,49 @@ public class MessagesCache extends RedisClusterPubSubAdapter<String, String> imp
private static final String QUEUE_KEYSPACE_PREFIX = "__keyspace@0__:user_queue::";
private static final String PERSISTING_KEYSPACE_PREFIX = "__keyspace@0__:user_queue_persisting::";
private static final String MRM_VIEWS_EXPERIMENT_NAME = "mrmViews";
@VisibleForTesting
static final Duration MAX_EPHEMERAL_MESSAGE_DELAY = Duration.ofSeconds(10);
private static final String GET_FLUX_NAME = MetricsUtil.name(MessagesCache.class, "get");
private static final int PAGE_SIZE = 100;
private static final int REMOVE_MRM_RECIPIENT_VIEW_CONCURRENCY = 8;
private static final Logger logger = LoggerFactory.getLogger(MessagesCache.class);
public MessagesCache(final FaultTolerantRedisCluster redisCluster, final ExecutorService notificationExecutorService,
final Scheduler messageDeliveryScheduler, final ExecutorService messageDeletionExecutorService, final Clock clock)
final Scheduler messageDeliveryScheduler, final ExecutorService messageDeletionExecutorService, final Clock clock,
final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager)
throws IOException {
this(
redisCluster,
notificationExecutorService,
messageDeliveryScheduler,
messageDeletionExecutorService,
clock,
dynamicConfigurationManager,
new MessagesCacheInsertScript(redisCluster),
new MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript(redisCluster),
new MessagesCacheGetItemsScript(redisCluster),
new MessagesCacheRemoveByGuidScript(redisCluster),
new MessagesCacheRemoveQueueScript(redisCluster),
new MessagesCacheGetQueuesToPersistScript(redisCluster),
new MessagesCacheRemoveRecipientViewFromMrmDataScript(redisCluster)
);
}
@VisibleForTesting
MessagesCache(final FaultTolerantRedisCluster redisCluster, final ExecutorService notificationExecutorService,
final Scheduler messageDeliveryScheduler, final ExecutorService messageDeletionExecutorService, final Clock clock,
final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager,
final MessagesCacheInsertScript insertScript,
final MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript insertMrmScript,
final MessagesCacheGetItemsScript getItemsScript, final MessagesCacheRemoveByGuidScript removeByGuidScript,
final MessagesCacheRemoveQueueScript removeQueueScript,
final MessagesCacheGetQueuesToPersistScript getQueuesToPersistScript,
final MessagesCacheRemoveRecipientViewFromMrmDataScript removeRecipientViewFromMrmDataScript)
throws IOException {
this.redisCluster = redisCluster;
@ -123,14 +224,15 @@ public class MessagesCache extends RedisClusterPubSubAdapter<String, String> imp
this.messageDeletionExecutorService = messageDeletionExecutorService;
this.messageDeletionScheduler = Schedulers.fromExecutorService(messageDeletionExecutorService, "messageDeletion");
this.insertScript = ClusterLuaScript.fromResource(redisCluster, "lua/insert_item.lua", ScriptOutputType.INTEGER);
this.removeByGuidScript = ClusterLuaScript.fromResource(redisCluster, "lua/remove_item_by_guid.lua",
ScriptOutputType.MULTI);
this.getItemsScript = ClusterLuaScript.fromResource(redisCluster, "lua/get_items.lua", ScriptOutputType.MULTI);
this.removeQueueScript = ClusterLuaScript.fromResource(redisCluster, "lua/remove_queue.lua",
ScriptOutputType.STATUS);
this.getQueuesToPersistScript = ClusterLuaScript.fromResource(redisCluster, "lua/get_queues_to_persist.lua",
ScriptOutputType.MULTI);
this.dynamicConfigurationManager = dynamicConfigurationManager;
this.insertScript = insertScript;
this.insertMrmScript = insertMrmScript;
this.removeByGuidScript = removeByGuidScript;
this.getItemsScript = getItemsScript;
this.removeQueueScript = removeQueueScript;
this.getQueuesToPersistScript = getQueuesToPersistScript;
this.removeRecipientViewFromMrmDataScript = removeRecipientViewFromMrmDataScript;
}
@Override
@ -164,51 +266,51 @@ public class MessagesCache extends RedisClusterPubSubAdapter<String, String> imp
public long insert(final UUID guid, final UUID destinationUuid, final byte destinationDevice,
final MessageProtos.Envelope message) {
final MessageProtos.Envelope messageWithGuid = message.toBuilder().setServerGuid(guid.toString()).build();
return (long) insertTimer.record(() ->
insertScript.executeBinary(List.of(getMessageQueueKey(destinationUuid, destinationDevice),
getMessageQueueMetadataKey(destinationUuid, destinationDevice),
getQueueIndexKey(destinationUuid, destinationDevice)),
List.of(messageWithGuid.toByteArray(),
String.valueOf(message.getServerTimestamp()).getBytes(StandardCharsets.UTF_8),
guid.toString().getBytes(StandardCharsets.UTF_8))));
return insertTimer.record(() -> insertScript.execute(destinationUuid, destinationDevice, messageWithGuid));
}
public CompletableFuture<Optional<MessageProtos.Envelope>> remove(final UUID destinationUuid,
public byte[] insertSharedMultiRecipientMessagePayload(UUID mrmGuid,
SealedSenderMultiRecipientMessage sealedSenderMultiRecipientMessage) {
final byte[] sharedMrmKey = getSharedMrmKey(mrmGuid);
insertSharedMrmPayloadTimer.record(() -> insertMrmScript.execute(sharedMrmKey, sealedSenderMultiRecipientMessage));
return sharedMrmKey;
}
public CompletableFuture<Optional<RemovedMessage>> remove(final UUID destinationUuid,
final byte destinationDevice,
final UUID messageGuid) {
return remove(destinationUuid, destinationDevice, List.of(messageGuid))
.thenApply(removed -> removed.isEmpty() ? Optional.empty() : Optional.of(removed.get(0)));
.thenApply(removed -> removed.isEmpty() ? Optional.empty() : Optional.of(removed.getFirst()));
}
@SuppressWarnings("unchecked")
public CompletableFuture<List<MessageProtos.Envelope>> remove(final UUID destinationUuid,
final byte destinationDevice,
final List<UUID> messageGuids) {
public CompletableFuture<List<RemovedMessage>> remove(final UUID destinationUuid,
final byte destinationDevice, final List<UUID> messageGuids) {
final Timer.Sample sample = Timer.start();
return removeByGuidScript.executeBinaryAsync(List.of(getMessageQueueKey(destinationUuid, destinationDevice),
getMessageQueueMetadataKey(destinationUuid, destinationDevice),
getQueueIndexKey(destinationUuid, destinationDevice)),
messageGuids.stream().map(guid -> guid.toString().getBytes(StandardCharsets.UTF_8))
.collect(Collectors.toList()))
.thenApplyAsync(result -> {
List<byte[]> serialized = (List<byte[]>) result;
return removeByGuidScript.execute(destinationUuid, destinationDevice, messageGuids)
.thenApplyAsync(serialized -> {
final List<MessageProtos.Envelope> removedMessages = new ArrayList<>(serialized.size());
final List<RemovedMessage> removedMessages = new ArrayList<>(serialized.size());
final List<byte[]> sharedMrmKeysToUpdate = new ArrayList<>();
for (final byte[] bytes : serialized) {
try {
removedMessages.add(MessageProtos.Envelope.parseFrom(bytes));
final MessageProtos.Envelope envelope = MessageProtos.Envelope.parseFrom(bytes);
removedMessages.add(RemovedMessage.fromEnvelope(envelope));
if (envelope.hasSharedMrmKey()) {
sharedMrmKeysToUpdate.add(envelope.getSharedMrmKey().toByteArray());
}
} catch (final InvalidProtocolBufferException e) {
logger.warn("Failed to parse envelope", e);
}
}
removeRecipientViewFromMrmData(sharedMrmKeysToUpdate, destinationUuid, destinationDevice);
return removedMessages;
}, messageDeletionExecutorService)
.whenComplete((ignored, throwable) -> sample.stop(removeByGuidTimer));
}, messageDeletionExecutorService).whenComplete((ignored, throwable) -> sample.stop(removeByGuidTimer));
}
public boolean hasMessages(final UUID destinationUuid, final byte destinationDevice) {
@ -251,7 +353,7 @@ public class MessagesCache extends RedisClusterPubSubAdapter<String, String> imp
private static boolean isStaleEphemeralMessage(final MessageProtos.Envelope message,
long earliestAllowableTimestamp) {
return message.hasEphemeral() && message.getEphemeral() && message.getTimestamp() < earliestAllowableTimestamp;
return message.getEphemeral() && message.getClientTimestamp() < earliestAllowableTimestamp;
}
private void discardStaleEphemeralMessages(final UUID destinationUuid, final byte destinationDevice,
@ -283,37 +385,101 @@ public class MessagesCache extends RedisClusterPubSubAdapter<String, String> imp
// we want to ensure we dont accidentally block the Lettuce/netty i/o executors
.publishOn(messageDeliveryScheduler)
.map(Pair::first)
.flatMapIterable(queueItems -> {
final List<MessageProtos.Envelope> envelopes = new ArrayList<>(queueItems.size() / 2);
.concatMap(queueItems -> {
final List<Mono<MessageProtos.Envelope>> envelopes = new ArrayList<>(queueItems.size() / 2);
for (int i = 0; i < queueItems.size() - 1; i += 2) {
try {
final MessageProtos.Envelope message = MessageProtos.Envelope.parseFrom(queueItems.get(i));
envelopes.add(message);
final Mono<MessageProtos.Envelope> messageMono;
if (message.hasSharedMrmKey()) {
maybeRunMrmViewExperiment(message, destinationUuid, destinationDevice);
// mrm views phase 1: messageMono for sharedMrmKey is always Mono.just(), because messages always have content
messageMono = Mono.just(message.toBuilder().clearSharedMrmKey().build());
} else {
messageMono = Mono.just(message);
}
envelopes.add(messageMono);
} catch (InvalidProtocolBufferException e) {
logger.warn("Failed to parse envelope", e);
}
}
return envelopes;
return Flux.mergeSequential(envelopes);
});
}
private Flux<Pair<List<byte[]>, Long>> getNextMessagePage(final UUID destinationUuid, final byte destinationDevice,
/**
* Runs the fetch and compare logic for the MRM view experiment, if it is enabled.
*
* @see DynamicMessagesConfiguration#mrmViewExperimentEnabled()
*/
private void maybeRunMrmViewExperiment(final MessageProtos.Envelope mrmMessage, final UUID destinationUuid,
final byte destinationDevice) {
if (dynamicConfigurationManager.getConfiguration().getMessagesConfiguration()
.mrmViewExperimentEnabled()) {
final Experiment experiment = new Experiment(MRM_VIEWS_EXPERIMENT_NAME);
final byte[] key = mrmMessage.getSharedMrmKey().toByteArray();
final byte[] sharedMrmViewKey = MessagesCache.getSharedMrmViewKey(
new AciServiceIdentifier(destinationUuid), destinationDevice);
final Mono<MessageProtos.Envelope> mrmMessageMono = Mono.from(redisCluster.withBinaryClusterReactive(
conn -> conn.reactive().hmget(key, "data".getBytes(StandardCharsets.UTF_8), sharedMrmViewKey)
.collectList()
.publishOn(messageDeliveryScheduler)
.handle((mrmDataAndView, sink) -> {
try {
assert mrmDataAndView.size() == 2;
final byte[] content = SealedSenderMultiRecipientMessage.messageForRecipient(
mrmDataAndView.getFirst().getValue(),
mrmDataAndView.getLast().getValue());
sink.next(mrmMessage.toBuilder()
.clearSharedMrmKey()
.setContent(ByteString.copyFrom(content))
.build());
mrmContentRetrievedCounter.increment();
} catch (Exception e) {
sink.error(e);
}
})));
experiment.compareFutureResult(mrmMessage.toBuilder().clearSharedMrmKey().build(),
mrmMessageMono.toFuture());
}
}
/**
* Makes a best-effort attempt at asynchronously updating (and removing when empty) the MRM data structure
*/
private void removeRecipientViewFromMrmData(final List<byte[]> sharedMrmKeys, final UUID accountUuid,
final byte deviceId) {
Flux.fromIterable(sharedMrmKeys)
.collectMultimap(SlotHash::getSlot)
.flatMapMany(slotsAndKeys -> Flux.fromIterable(slotsAndKeys.values()))
.flatMap(
keys -> removeRecipientViewFromMrmDataScript.execute(keys, new AciServiceIdentifier(accountUuid), deviceId),
REMOVE_MRM_RECIPIENT_VIEW_CONCURRENCY)
.subscribe(sharedMrmDataKeyRemovedCounter::increment, e -> logger.warn("Error removing recipient view", e));
}
private Mono<Pair<List<byte[]>, Long>> getNextMessagePage(final UUID destinationUuid,
final byte destinationDevice,
long messageId) {
return getItemsScript.executeBinaryReactive(
List.of(getMessageQueueKey(destinationUuid, destinationDevice),
getPersistInProgressKey(destinationUuid, destinationDevice)),
List.of(String.valueOf(PAGE_SIZE).getBytes(StandardCharsets.UTF_8),
String.valueOf(messageId).getBytes(StandardCharsets.UTF_8)))
.map(result -> {
return getItemsScript.execute(destinationUuid, destinationDevice, PAGE_SIZE, messageId)
.map(queueItems -> {
logger.trace("Processing page: {}", messageId);
@SuppressWarnings("unchecked")
List<byte[]> queueItems = (List<byte[]>) result;
if (queueItems.isEmpty()) {
return new Pair<>(Collections.emptyList(), null);
}
@ -324,7 +490,7 @@ public class MessagesCache extends RedisClusterPubSubAdapter<String, String> imp
}
final long lastMessageId = Long.parseLong(
new String(queueItems.get(queueItems.size() - 1), StandardCharsets.UTF_8));
new String(queueItems.getLast(), StandardCharsets.UTF_8));
return new Pair<>(queueItems, lastMessageId);
});
@ -362,10 +528,35 @@ public class MessagesCache extends RedisClusterPubSubAdapter<String, String> imp
public CompletableFuture<Void> clear(final UUID destinationUuid, final byte deviceId) {
final Timer.Sample sample = Timer.start();
return removeQueueScript.executeBinaryAsync(List.of(getMessageQueueKey(destinationUuid, deviceId),
getMessageQueueMetadataKey(destinationUuid, deviceId),
getQueueIndexKey(destinationUuid, deviceId)),
Collections.emptyList())
return removeQueueScript.execute(destinationUuid, deviceId, Collections.emptyList())
.publishOn(messageDeletionScheduler)
.expand(messagesToProcess -> {
if (messagesToProcess.isEmpty()) {
return Mono.empty();
}
final List<byte[]> mrmKeys = new ArrayList<>(messagesToProcess.size());
final List<String> processedMessages = new ArrayList<>(messagesToProcess.size());
for (byte[] serialized : messagesToProcess) {
try {
final MessageProtos.Envelope message = MessageProtos.Envelope.parseFrom(serialized);
processedMessages.add(message.getServerGuid());
if (message.hasSharedMrmKey()) {
mrmKeys.add(message.getSharedMrmKey().toByteArray());
}
} catch (final InvalidProtocolBufferException e) {
logger.warn("Failed to parse envelope", e);
}
}
removeRecipientViewFromMrmData(mrmKeys, destinationUuid, deviceId);
return removeQueueScript.execute(destinationUuid, deviceId, processedMessages);
})
.then()
.toFuture()
.thenRun(() -> sample.stop(clearQueueTimer));
}
@ -375,11 +566,7 @@ public class MessagesCache extends RedisClusterPubSubAdapter<String, String> imp
}
List<String> getQueuesToPersist(final int slot, final Instant maxTime, final int limit) {
//noinspection unchecked
return getQueuesToPersistTimer.record(() -> (List<String>) getQueuesToPersistScript.execute(
List.of(new String(getQueueIndexKey(slot), StandardCharsets.UTF_8)),
List.of(String.valueOf(maxTime.toEpochMilli()),
String.valueOf(limit))));
return getQueuesToPersistTimer.record(() -> getQueuesToPersistScript.execute(slot, maxTime, limit));
}
void addQueueToPersist(final UUID accountUuid, final byte deviceId) {
@ -538,29 +725,36 @@ public class MessagesCache extends RedisClusterPubSubAdapter<String, String> imp
return channel.substring(startOfHashTag + 1, endOfHashTag);
}
@VisibleForTesting
static byte[] getMessageQueueKey(final UUID accountUuid, final byte deviceId) {
return ("user_queue::{" + accountUuid.toString() + "::" + deviceId + "}").getBytes(StandardCharsets.UTF_8);
}
private static byte[] getMessageQueueMetadataKey(final UUID accountUuid, final byte deviceId) {
static byte[] getMessageQueueMetadataKey(final UUID accountUuid, final byte deviceId) {
return ("user_queue_metadata::{" + accountUuid.toString() + "::" + deviceId + "}").getBytes(StandardCharsets.UTF_8);
}
private static byte[] getQueueIndexKey(final UUID accountUuid, final byte deviceId) {
static byte[] getQueueIndexKey(final UUID accountUuid, final byte deviceId) {
return getQueueIndexKey(SlotHash.getSlot(accountUuid.toString() + "::" + deviceId));
}
private static byte[] getQueueIndexKey(final int slot) {
static byte[] getQueueIndexKey(final int slot) {
return ("user_queue_index::{" + RedisClusterUtil.getMinimalHashTag(slot) + "}").getBytes(StandardCharsets.UTF_8);
}
private static byte[] getPersistInProgressKey(final UUID accountUuid, final byte deviceId) {
static byte[] getSharedMrmKey(final UUID mrmGuid) {
return ("mrm::{" + mrmGuid.toString() + "}").getBytes(StandardCharsets.UTF_8);
}
static byte[] getPersistInProgressKey(final UUID accountUuid, final byte deviceId) {
return ("user_queue_persisting::{" + accountUuid + "::" + deviceId + "}").getBytes(StandardCharsets.UTF_8);
}
private static byte[] getUnlinkInProgressKey(final UUID accountUuid) {
return ("user_account_unlinking::{" + accountUuid + "}").getBytes(StandardCharsets.UTF_8);
static byte[] getSharedMrmViewKey(final AciServiceIdentifier serviceIdentifier, final byte deviceId) {
final ByteBuffer keyBb = ByteBuffer.allocate(18);
keyBb.put(serviceIdentifier.toFixedWidthByteArray());
keyBb.put(deviceId);
assert !keyBb.hasRemaining();
return keyBb.array();
}
static UUID getAccountUuidFromQueueName(final String queueName) {

View File

@ -0,0 +1,45 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import io.lettuce.core.ScriptOutputType;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.UUID;
import org.whispersystems.textsecuregcm.redis.ClusterLuaScript;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
import reactor.core.publisher.Mono;
/**
* Retrieves a list of messages and their corresponding queue-local IDs for the device. To support streaming processing,
* the last queue-local message ID from a previous call may be used as the {@code afterMessageId}.
*/
class MessagesCacheGetItemsScript {
private final ClusterLuaScript getItemsScript;
MessagesCacheGetItemsScript(FaultTolerantRedisCluster redisCluster) throws IOException {
this.getItemsScript = ClusterLuaScript.fromResource(redisCluster, "lua/get_items.lua", ScriptOutputType.OBJECT);
}
Mono<List<byte[]>> execute(final UUID destinationUuid, final byte destinationDevice,
int limit, long afterMessageId) {
final List<byte[]> keys = List.of(
MessagesCache.getMessageQueueKey(destinationUuid, destinationDevice), // queueKey
MessagesCache.getPersistInProgressKey(destinationUuid, destinationDevice) // queueLockKey
);
final List<byte[]> args = List.of(
String.valueOf(limit).getBytes(StandardCharsets.UTF_8), // limit
String.valueOf(afterMessageId).getBytes(StandardCharsets.UTF_8) // afterMessageId
);
//noinspection unchecked
return getItemsScript.executeBinaryReactive(keys, args)
.map(result -> (List<byte[]>) result)
.next();
}
}

View File

@ -0,0 +1,43 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import io.lettuce.core.ScriptOutputType;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.time.Instant;
import java.util.List;
import org.whispersystems.textsecuregcm.redis.ClusterLuaScript;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
/**
* Returns a list of queues that may be persisted. They will be sorted from oldest to more recent, limited by the
* {@code maxTime} argument.
*
* @see MessagePersister
*/
class MessagesCacheGetQueuesToPersistScript {
private final ClusterLuaScript getQueuesToPersistScript;
MessagesCacheGetQueuesToPersistScript(final FaultTolerantRedisCluster redisCluster) throws IOException {
this.getQueuesToPersistScript = ClusterLuaScript.fromResource(redisCluster, "lua/get_queues_to_persist.lua",
ScriptOutputType.MULTI);
}
List<String> execute(final int slot, final Instant maxTime, final int limit) {
final List<String> keys = List.of(
new String(MessagesCache.getQueueIndexKey(slot), StandardCharsets.UTF_8) // queueTotalIndexKey
);
final List<String> args = List.of(
String.valueOf(maxTime.toEpochMilli()), // maxTime
String.valueOf(limit) // limit
);
//noinspection unchecked
return (List<String>) getQueuesToPersistScript.execute(keys, args);
}
}

View File

@ -0,0 +1,48 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import io.lettuce.core.ScriptOutputType;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.UUID;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.redis.ClusterLuaScript;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
/**
* Inserts an envelope into the message queue for a destination device.
*/
class MessagesCacheInsertScript {
private final ClusterLuaScript insertScript;
MessagesCacheInsertScript(FaultTolerantRedisCluster redisCluster) throws IOException {
this.insertScript = ClusterLuaScript.fromResource(redisCluster, "lua/insert_item.lua", ScriptOutputType.INTEGER);
}
long execute(final UUID destinationUuid, final byte destinationDevice, final MessageProtos.Envelope envelope) {
assert envelope.hasServerGuid();
assert envelope.hasServerTimestamp();
final List<byte[]> keys = List.of(
MessagesCache.getMessageQueueKey(destinationUuid, destinationDevice), // queueKey
MessagesCache.getMessageQueueMetadataKey(destinationUuid, destinationDevice), // queueMetadataKey
MessagesCache.getQueueIndexKey(destinationUuid, destinationDevice) // queueTotalIndexKey
);
final List<byte[]> args = new ArrayList<>(Arrays.asList(
envelope.toByteArray(), // message
String.valueOf(envelope.getServerTimestamp()).getBytes(StandardCharsets.UTF_8), // currentTime
envelope.getServerGuid().getBytes(StandardCharsets.UTF_8) // guid
));
return (long) insertScript.executeBinary(keys, args);
}
}

View File

@ -0,0 +1,53 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import io.lettuce.core.ScriptOutputType;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage;
import org.whispersystems.textsecuregcm.redis.ClusterLuaScript;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
/**
* Inserts the shared multi-recipient message payload into the cache. The list of recipients and views will be set as
* fields in the hash.
*
* @see SealedSenderMultiRecipientMessage#serializedRecipientView(SealedSenderMultiRecipientMessage.Recipient)
*/
class MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript {
private final ClusterLuaScript script;
MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript(FaultTolerantRedisCluster redisCluster)
throws IOException {
this.script = ClusterLuaScript.fromResource(redisCluster, "lua/insert_shared_multirecipient_message_data.lua",
ScriptOutputType.INTEGER);
}
void execute(final byte[] sharedMrmKey, final SealedSenderMultiRecipientMessage message) {
final List<byte[]> keys = List.of(
sharedMrmKey // sharedMrmKey
);
// Pre-allocate capacity for the most fields we expect -- 6 devices per recipient, plus the data field.
final List<byte[]> args = new ArrayList<>(message.getRecipients().size() * 6 + 1);
args.add(message.serialized());
message.getRecipients().forEach((serviceId, recipient) -> {
for (byte device : recipient.getDevices()) {
final byte[] key = new byte[18];
System.arraycopy(serviceId.toServiceIdFixedWidthBinary(), 0, key, 0, 17);
key[17] = device;
args.add(key);
args.add(message.serializedRecipientView(recipient));
}
});
script.executeBinary(keys, args);
}
}

View File

@ -0,0 +1,45 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import io.lettuce.core.ScriptOutputType;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import org.whispersystems.textsecuregcm.redis.ClusterLuaScript;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
/**
* Removes a list of message GUIDs from the queue of a destination device.
*/
class MessagesCacheRemoveByGuidScript {
private final ClusterLuaScript removeByGuidScript;
MessagesCacheRemoveByGuidScript(final FaultTolerantRedisCluster redisCluster) throws IOException {
this.removeByGuidScript = ClusterLuaScript.fromResource(redisCluster, "lua/remove_item_by_guid.lua",
ScriptOutputType.OBJECT);
}
CompletableFuture<List<byte[]>> execute(final UUID destinationUuid, final byte destinationDevice,
final List<UUID> messageGuids) {
final List<byte[]> keys = List.of(
MessagesCache.getMessageQueueKey(destinationUuid, destinationDevice), // queueKey
MessagesCache.getMessageQueueMetadataKey(destinationUuid, destinationDevice), // queueMetadataKey
MessagesCache.getQueueIndexKey(destinationUuid, destinationDevice) // queueTotalIndexKey
);
final List<byte[]> args = messageGuids.stream().map(guid -> guid.toString().getBytes(StandardCharsets.UTF_8))
.toList();
//noinspection unchecked
return removeByGuidScript.executeBinaryAsync(keys, args)
.thenApply(result -> (List<byte[]>) result);
}
}

View File

@ -0,0 +1,58 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import io.lettuce.core.ScriptOutputType;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import java.util.UUID;
import org.whispersystems.textsecuregcm.redis.ClusterLuaScript;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
import reactor.core.publisher.Mono;
/**
* Removes a device's queue from the cache. For a non-empty queue, this script must be executed multiple times.
* <ol>
* <li>The first call will return a list of messages to check for {@code sharedMrmKeys}. If a {@code sharedMrmKey} is present, {@link MessagesCacheRemoveRecipientViewFromMrmDataScript} must be called.</li>
* <li>Once theses messages have been processed, this script should be called again, confirming that the messages have been processed.</li>
* <li>This should be repeated until the script returns an empty list, as the script only returns a page ({@value PAGE_SIZE}) of messages at a time.</li>
* </ol>
*/
class MessagesCacheRemoveQueueScript {
private static final int PAGE_SIZE = 100;
private final ClusterLuaScript removeQueueScript;
MessagesCacheRemoveQueueScript(FaultTolerantRedisCluster redisCluster) throws IOException {
this.removeQueueScript = ClusterLuaScript.fromResource(redisCluster, "lua/remove_queue.lua",
ScriptOutputType.MULTI);
}
Mono<List<byte[]>> execute(final UUID destinationUuid, final byte destinationDevice,
final List<String> processedMessageGuids) {
final List<byte[]> keys = List.of(
MessagesCache.getMessageQueueKey(destinationUuid, destinationDevice), // queueKey
MessagesCache.getMessageQueueMetadataKey(destinationUuid, destinationDevice), // queueMetadataKey
MessagesCache.getQueueIndexKey(destinationUuid, destinationDevice) // queueTotalIndexKey
);
final List<byte[]> args = new ArrayList<>();
args.addFirst(String.valueOf(PAGE_SIZE).getBytes(StandardCharsets.UTF_8)); // limit
args.addAll(processedMessageGuids.stream().map(guid -> guid.getBytes(StandardCharsets.UTF_8))
.toList()); // processedMessageGuids
//noinspection unchecked
return removeQueueScript.executeBinaryReactive(keys, args)
.map(result -> (List<byte[]>) result)
.next();
}
}

View File

@ -0,0 +1,44 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import io.lettuce.core.ScriptOutputType;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.redis.ClusterLuaScript;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
import reactor.core.publisher.Mono;
/**
* Removes the given destination device from the given {@code sharedMrmKeys}. If there are no devices remaining in the
* hash as a result, the shared payload is deleted.
* <p>
* NOTE: Callers are responsible for ensuring that all keys are in the same slot.
*/
class MessagesCacheRemoveRecipientViewFromMrmDataScript {
private final ClusterLuaScript removeRecipientViewFromMrmDataScript;
MessagesCacheRemoveRecipientViewFromMrmDataScript(final FaultTolerantRedisCluster redisCluster) throws IOException {
this.removeRecipientViewFromMrmDataScript = ClusterLuaScript.fromResource(redisCluster,
"lua/remove_recipient_view_from_mrm_data.lua", ScriptOutputType.INTEGER);
}
Mono<Long> execute(final Collection<byte[]> keysCollection, final AciServiceIdentifier serviceIdentifier,
final byte deviceId) {
final List<byte[]> keys = keysCollection instanceof List<byte[]>
? (List<byte[]>) keysCollection
: new ArrayList<>(keysCollection);
return removeRecipientViewFromMrmDataScript.executeBinaryReactive(keys,
List.of(MessagesCache.getSharedMrmViewKey(serviceIdentifier, deviceId)))
.map(o -> (long) o)
.next();
}
}

View File

@ -19,8 +19,10 @@ import java.util.concurrent.TimeoutException;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import org.reactivestreams.Publisher;
import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
import org.whispersystems.textsecuregcm.util.Pair;
@ -62,8 +64,8 @@ public class MessagesManager {
messagesCache.insert(messageGuid, destinationUuid, destinationDevice, message);
if (message.hasSourceUuid() && !destinationUuid.toString().equals(message.getSourceUuid())) {
reportMessageManager.store(message.getSourceUuid(), messageGuid);
if (message.hasSourceServiceId() && !destinationUuid.toString().equals(message.getSourceServiceId())) {
reportMessageManager.store(message.getSourceServiceId(), messageGuid);
}
}
@ -137,7 +139,7 @@ public class MessagesManager {
return messagesCache.clear(destinationUuid, deviceId);
}
public CompletableFuture<Optional<Envelope>> delete(UUID destinationUuid, Device destinationDevice, UUID guid,
public CompletableFuture<Optional<RemovedMessage>> delete(UUID destinationUuid, Device destinationDevice, UUID guid,
@Nullable Long serverTimestamp) {
return messagesCache.remove(destinationUuid, destinationDevice.getId(), guid)
.thenComposeAsync(removed -> {
@ -146,12 +148,16 @@ public class MessagesManager {
return CompletableFuture.completedFuture(removed);
}
final CompletableFuture<Optional<MessageProtos.Envelope>> maybeDeletedEnvelope;
if (serverTimestamp == null) {
return messagesDynamoDb.deleteMessageByDestinationAndGuid(destinationUuid, destinationDevice, guid);
maybeDeletedEnvelope = messagesDynamoDb.deleteMessageByDestinationAndGuid(destinationUuid,
destinationDevice, guid);
} else {
return messagesDynamoDb.deleteMessage(destinationUuid, destinationDevice, guid, serverTimestamp);
maybeDeletedEnvelope = messagesDynamoDb.deleteMessage(destinationUuid, destinationDevice, guid,
serverTimestamp);
}
return maybeDeletedEnvelope.thenApply(maybeEnvelope -> maybeEnvelope.map(RemovedMessage::fromEnvelope));
}, messageDeletionExecutor);
}
@ -194,4 +200,14 @@ public class MessagesManager {
messagesCache.removeMessageAvailabilityListener(listener);
}
/**
* Inserts the shared multi-recipient message payload to storage.
*
* @return a key where the shared data is stored
* @see MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript
*/
public byte[] insertSharedMultiRecipientMessagePayload(
SealedSenderMultiRecipientMessage sealedSenderMultiRecipientMessage) {
return messagesCache.insertSharedMultiRecipientMessagePayload(UUID.randomUUID(), sealedSenderMultiRecipientMessage);
}
}

View File

@ -0,0 +1,30 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import com.google.common.annotations.VisibleForTesting;
import java.util.Optional;
import java.util.UUID;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
public record RemovedMessage(Optional<ServiceIdentifier> sourceServiceId, ServiceIdentifier destinationServiceId,
@VisibleForTesting UUID serverGuid, long serverTimestamp, long clientTimestamp,
MessageProtos.Envelope.Type envelopeType) {
public static RemovedMessage fromEnvelope(MessageProtos.Envelope envelope) {
return new RemovedMessage(
envelope.hasSourceServiceId()
? Optional.of(ServiceIdentifier.valueOf(envelope.getSourceServiceId()))
: Optional.empty(),
ServiceIdentifier.valueOf(envelope.getDestinationServiceId()),
UUID.fromString(envelope.getServerGuid()),
envelope.getServerTimestamp(),
envelope.getClientTimestamp(),
envelope.getType()
);
}
}

View File

@ -294,16 +294,16 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac
}
private void sendDeliveryReceiptFor(Envelope message) {
if (!message.hasSourceUuid()) {
if (!message.hasSourceServiceId()) {
return;
}
try {
receiptSender.sendReceipt(ServiceIdentifier.valueOf(message.getDestinationUuid()),
auth.getAuthenticatedDevice().getId(), AciServiceIdentifier.valueOf(message.getSourceUuid()),
message.getTimestamp());
receiptSender.sendReceipt(ServiceIdentifier.valueOf(message.getDestinationServiceId()),
auth.getAuthenticatedDevice().getId(), AciServiceIdentifier.valueOf(message.getSourceServiceId()),
message.getClientTimestamp());
} catch (IllegalArgumentException e) {
logger.error("Could not parse UUID: {}", message.getSourceUuid());
logger.error("Could not parse UUID: {}", message.getSourceServiceId());
} catch (Exception e) {
logger.warn("Failed to send receipt", e);
}

View File

@ -205,7 +205,7 @@ record CommandDependencies(
ClientPresenceManager clientPresenceManager = new ClientPresenceManager(clientPresenceCluster,
recurringJobExecutor, keyspaceNotificationDispatchExecutor);
MessagesCache messagesCache = new MessagesCache(messagesCluster, keyspaceNotificationDispatchExecutor,
messageDeliveryScheduler, messageDeletionExecutor, Clock.systemUTC());
messageDeliveryScheduler, messageDeletionExecutor, Clock.systemUTC(), dynamicConfigurationManager);
ProfilesManager profilesManager = new ProfilesManager(profiles, cacheCluster);
ReportMessageDynamoDb reportMessageDynamoDb = new ReportMessageDynamoDb(dynamoDbClient,
configuration.getDynamoDbTables().getReportMessage().getTableName(),

View File

@ -11,30 +11,31 @@ option java_outer_classname = "MessageProtos";
message Envelope {
enum Type {
UNKNOWN = 0;
CIPHERTEXT = 1;
KEY_EXCHANGE = 2;
PREKEY_BUNDLE = 3;
UNKNOWN = 0;
CIPHERTEXT = 1;
KEY_EXCHANGE = 2;
PREKEY_BUNDLE = 3;
SERVER_DELIVERY_RECEIPT = 5;
UNIDENTIFIED_SENDER = 6;
reserved 7;
PLAINTEXT_CONTENT = 8; // for decryption error receipts
}
optional Type type = 1;
optional string source_uuid = 11;
optional Type type = 1;
optional string source_service_id = 11;
optional uint32 source_device = 7;
optional uint64 timestamp = 5;
optional bytes content = 8; // Contains an encrypted Content
optional uint64 client_timestamp = 5;
optional bytes content = 8; // Contains an encrypted Content
optional string server_guid = 9;
optional uint64 server_timestamp = 10;
optional bool ephemeral = 12; // indicates that the message should not be persisted if the recipient is offline
optional string destination_uuid = 13;
optional string destination_service_id = 13;
optional bool urgent = 14 [default=true];
optional string updated_pni = 15;
optional bool story = 16; // indicates that the content is a story.
optional bytes report_spam_token = 17; // token sent when reporting spam
// next: 18
optional bytes shared_mrm_key = 18; // indicates content should be fetched from multi-recipient message datastore
// next: 19
}
message ProvisioningUuid {
@ -42,25 +43,25 @@ message ProvisioningUuid {
}
message ServerCertificate {
message Certificate {
optional uint32 id = 1;
optional bytes key = 2;
}
message Certificate {
optional uint32 id = 1;
optional bytes key = 2;
}
optional bytes certificate = 1;
optional bytes signature = 2;
optional bytes certificate = 1;
optional bytes signature = 2;
}
message SenderCertificate {
message Certificate {
optional string sender = 1;
optional string sender_uuid = 6;
optional uint32 sender_device = 2;
optional fixed64 expires = 3;
optional bytes identity_key = 4;
optional ServerCertificate signer = 5;
}
message Certificate {
optional string sender = 1;
optional string sender_uuid = 6;
optional uint32 sender_device = 2;
optional fixed64 expires = 3;
optional bytes identity_key = 4;
optional ServerCertificate signer = 5;
}
optional bytes certificate = 1;
optional bytes signature = 2;
optional bytes certificate = 1;
optional bytes signature = 2;
}

View File

@ -46,7 +46,7 @@ local getNextInterval = function(interval)
end
local results = redis.call("ZRANGEBYSCORE", pendingNotificationQueue, 0, maxTime, "LIMIT", 0, limit)
local results = redis.call("ZRANGE", pendingNotificationQueue, 0, maxTime, "BYSCORE", "LIMIT", 0, limit)
local collated = {}
if results and next(results) then

View File

@ -1,7 +1,10 @@
local queueKey = KEYS[1]
local queueLockKey = KEYS[2]
local limit = ARGV[1]
local afterMessageId = ARGV[2]
-- gets messages from a device's queue, up to a given limit
-- returns a list of all envelopes and their queue-local IDs
local queueKey = KEYS[1] -- sorted set of all Envelopes for a device, scored by queue-local ID
local queueLockKey = KEYS[2] -- a key whose presence indicates that the queue is being persistent and must not be read
local limit = ARGV[1] -- [number] the maximum number of messages to return
local afterMessageId = ARGV[2] -- [number] a queue-local ID to exclusively start after, to support pagination. Use -1 to start at the beginning
local locked = redis.call("GET", queueLockKey)
@ -9,17 +12,8 @@ if locked then
return {}
end
if afterMessageId == "null" then
-- An index range is inclusive
local min = 0
local max = limit - 1
if max < 0 then
return {}
end
return redis.call("ZRANGE", queueKey, min, max, "WITHSCORES")
else
-- note: this is deprecated in Redis 6.2, and should be migrated to zrange after the cluster is updated
return redis.call("ZRANGEBYSCORE", queueKey, "("..afterMessageId, "+inf", "WITHSCORES", "LIMIT", 0, limit)
if afterMessageId == "null" or afterMessageId == nil then
return redis.error_reply("ERR afterMessageId is required")
end
return redis.call("ZRANGE", queueKey, "("..afterMessageId, "+inf", "BYSCORE", "LIMIT", 0, limit, "WITHSCORES")

View File

@ -1,8 +1,10 @@
local queueTotalIndexKey = KEYS[1]
local maxTime = ARGV[1]
local limit = ARGV[2]
-- returns a list of queues that meet persistence criteria
local results = redis.call("ZRANGEBYSCORE", queueTotalIndexKey, 0, maxTime, "LIMIT", 0, limit)
local queueTotalIndexKey = KEYS[1] -- sorted set of all queues in the shard, by timestamp of oldest message
local maxTime = ARGV[1] -- [number] the most recent queue timestamp that may be fetched
local limit = ARGV[2] -- [number] the maximum number of queues to fetch
local results = redis.call("ZRANGE", queueTotalIndexKey, 0, maxTime, "BYSCORE", "LIMIT", 0, limit)
if results and next(results) then
redis.call("ZREM", queueTotalIndexKey, unpack(results))

View File

@ -1,9 +1,12 @@
local queueKey = KEYS[1]
local queueMetadataKey = KEYS[2]
local queueTotalIndexKey = KEYS[3]
local message = ARGV[1]
local currentTime = ARGV[2]
local guid = ARGV[3]
-- inserts a message into a device's queue, and updates relevant associated data
-- returns a number, the queue-local message ID (useful for testing)
local queueKey = KEYS[1] -- sorted set of Envelopes for a device, by queue-local ID
local queueMetadataKey = KEYS[2] -- hash of message GUID to queue-local IDs
local queueTotalIndexKey = KEYS[3] -- sorted set of all queues in the shard, by timestamp of oldest message
local message = ARGV[1] -- [bytes] the Envelope to insert
local currentTime = ARGV[2] -- [number] the message timestamp, to sort the queue in the queueTotalIndex
local guid = ARGV[3] -- [string] the message GUID
if redis.call("HEXISTS", queueMetadataKey, guid) == 1 then
return tonumber(redis.call("HGET", queueMetadataKey, guid))
@ -14,9 +17,8 @@ local messageId = redis.call("HINCRBY", queueMetadataKey, "counter", 1)
redis.call("ZADD", queueKey, "NX", messageId, message)
redis.call("HSET", queueMetadataKey, guid, messageId)
redis.call("EXPIRE", queueKey, 7776000) -- 90 days
redis.call("EXPIRE", queueMetadataKey, 7776000) -- 90 days
redis.call("EXPIRE", queueKey, 2678400) -- 31 days
redis.call("EXPIRE", queueMetadataKey, 2678400) -- 31 days
redis.call("ZADD", queueTotalIndexKey, "NX", currentTime, queueKey)
return messageId

View File

@ -0,0 +1,13 @@
-- inserts shared multi-recipient message data
local sharedMrmKey = KEYS[1] -- [string] the key containing the shared MRM data
local mrmData = ARGV[1] -- [bytes] the serialized multi-recipient message data
-- the remainder of ARGV is list of recipient keys and view data
redis.call("HSET", sharedMrmKey, "data", mrmData);
redis.call("EXPIRE", sharedMrmKey, 604800) -- 7 days
-- unpack() fails with "too many results" at very large table sizes, so we loop
for i = 2, #ARGV, 2 do
redis.call("HSET", sharedMrmKey, ARGV[i], ARGV[i + 1])
end

View File

@ -1,20 +1,26 @@
local queueKey = KEYS[1]
local queueMetadataKey = KEYS[2]
local queueTotalIndexKey = KEYS[3]
-- removes a list of messages by ID from the cluster, returning the deleted messages
-- returns a list of removed envelopes
-- Note: content may be absent for MRM messages, and for these messages, the caller must update the sharedMrmKey
-- to remove the recipient's reference
local queueKey = KEYS[1] -- sorted set of Envelopes for a device, by queue-local ID
local queueMetadataKey = KEYS[2] -- hash of message GUID to queue-local IDs
local queueTotalIndexKey = KEYS[3] -- sorted set of all queues in the shard, by timestamp of oldest message
local messageGuids = ARGV -- [list[string]] message GUIDs
local removedMessages = {}
for _, guid in ipairs(ARGV) do
for _, guid in ipairs(messageGuids) do
local messageId = redis.call("HGET", queueMetadataKey, guid)
if messageId then
local envelope = redis.call("ZRANGEBYSCORE", queueKey, messageId, messageId, "LIMIT", 0, 1)
local envelope = redis.call("ZRANGE", queueKey, messageId, messageId, "BYSCORE", "LIMIT", 0, 1)
redis.call("ZREMRANGEBYSCORE", queueKey, messageId, messageId)
redis.call("HDEL", queueMetadataKey, guid)
if envelope and next(envelope) then
removedMessages[#removedMessages + 1] = envelope[1]
table.insert(removedMessages, envelope[1])
end
end
end

View File

@ -1,7 +1,29 @@
local queueKey = KEYS[1]
local queueMetadataKey = KEYS[2]
local queueTotalIndexKey = KEYS[3]
-- incrementally removes a given device's queue and associated data
-- returns: a page of messages and scores.
-- The messages must be checked for mrmKeys to update. After updating MRM keys, this script must be called again
-- with processedMessageGuids. If the returned table is empty, then
-- the queue has been fully deleted.
redis.call("DEL", queueKey)
redis.call("DEL", queueMetadataKey)
redis.call("ZREM", queueTotalIndexKey, queueKey)
local queueKey = KEYS[1] -- sorted set of Envelopes for a device, by queue-local ID
local queueMetadataKey = KEYS[2] -- hash of message GUID to queue-local IDs
local queueTotalIndexKey = KEYS[3] -- sorted set of all queues in the shard, by timestamp of oldest message
local limit = ARGV[1] -- the maximum number of messages to return
local processedMessageGuids = { unpack(ARGV, 2) }
for _, guid in ipairs(processedMessageGuids) do
local messageId = redis.call("HGET", queueMetadataKey, guid)
if messageId then
redis.call("ZREMRANGEBYSCORE", queueKey, messageId, messageId)
redis.call("HDEL", queueMetadataKey, guid)
end
end
local messages = redis.call("ZRANGE", queueKey, 0, limit-1)
if #messages == 0 then
redis.call("DEL", queueKey)
redis.call("DEL", queueMetadataKey)
redis.call("ZREM", queueTotalIndexKey, queueKey)
end
return messages

View File

@ -0,0 +1,17 @@
-- Removes the given recipient view from the shared MRM data. If the only field remaining after the removal is the
-- `data` field, then the key will be deleted
local sharedMrmKeys = KEYS -- KEYS: list of all keys in a single slot to update
local recipientViewToRemove = ARGV[1] -- the recipient view to remove from the hash
local keysDeleted = 0
for _, sharedMrmKey in ipairs(sharedMrmKeys) do
redis.call("HDEL", sharedMrmKey, recipientViewToRemove)
if redis.call("HLEN", sharedMrmKey) == 1 then
redis.call("DEL", sharedMrmKey)
keysDeleted = keysDeleted + 1
end
end
return keysDeleted

View File

@ -87,6 +87,7 @@ import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.auth.UnidentifiedAccessUtil;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicInboundMessageByteLimitConfiguration;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicMessagesConfiguration;
import org.whispersystems.textsecuregcm.entities.AccountMismatchedDevices;
import org.whispersystems.textsecuregcm.entities.AccountStaleDevices;
import org.whispersystems.textsecuregcm.entities.IncomingMessage;
@ -121,6 +122,7 @@ import org.whispersystems.textsecuregcm.storage.ClientReleaseManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.storage.RemovedMessage;
import org.whispersystems.textsecuregcm.storage.ReportMessageManager;
import org.whispersystems.textsecuregcm.tests.util.AccountsHelper;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
@ -251,6 +253,7 @@ class MessageControllerTest {
final DynamicConfiguration dynamicConfiguration = mock(DynamicConfiguration.class);
when(dynamicConfiguration.getInboundMessageByteLimitConfiguration()).thenReturn(inboundMessageByteLimitConfiguration);
when(dynamicConfiguration.getMessagesConfiguration()).thenReturn(new DynamicMessagesConfiguration(true, true));
when(dynamicConfigurationManager.getConfiguration()).thenReturn(dynamicConfiguration);
@ -311,7 +314,7 @@ class MessageControllerTest {
ArgumentCaptor<Envelope> captor = ArgumentCaptor.forClass(Envelope.class);
verify(messageSender, times(1)).sendMessage(any(Account.class), any(Device.class), captor.capture(), eq(false));
assertTrue(captor.getValue().hasSourceUuid());
assertTrue(captor.getValue().hasSourceServiceId());
assertTrue(captor.getValue().hasSourceDevice());
assertTrue(captor.getValue().getUrgent());
}
@ -353,7 +356,7 @@ class MessageControllerTest {
ArgumentCaptor<Envelope> captor = ArgumentCaptor.forClass(Envelope.class);
verify(messageSender, times(1)).sendMessage(any(Account.class), any(Device.class), captor.capture(), eq(false));
assertTrue(captor.getValue().hasSourceUuid());
assertTrue(captor.getValue().hasSourceServiceId());
assertTrue(captor.getValue().hasSourceDevice());
assertFalse(captor.getValue().getUrgent());
}
@ -375,7 +378,7 @@ class MessageControllerTest {
ArgumentCaptor<Envelope> captor = ArgumentCaptor.forClass(Envelope.class);
verify(messageSender, times(1)).sendMessage(any(Account.class), any(Device.class), captor.capture(), eq(false));
assertTrue(captor.getValue().hasSourceUuid());
assertTrue(captor.getValue().hasSourceServiceId());
assertTrue(captor.getValue().hasSourceDevice());
}
}
@ -410,7 +413,7 @@ class MessageControllerTest {
ArgumentCaptor<Envelope> captor = ArgumentCaptor.forClass(Envelope.class);
verify(messageSender, times(1)).sendMessage(any(Account.class), any(Device.class), captor.capture(), eq(false));
assertFalse(captor.getValue().hasSourceUuid());
assertFalse(captor.getValue().hasSourceServiceId());
assertFalse(captor.getValue().hasSourceDevice());
}
}
@ -444,7 +447,7 @@ class MessageControllerTest {
assertThat("Good Response", response.getStatus(), is(equalTo(expectedResponse)));
if (expectedResponse == 200) {
verify(messageSender).sendMessage(
any(Account.class), any(Device.class), argThat(env -> !env.hasSourceUuid() && !env.hasSourceDevice()),
any(Account.class), any(Device.class), argThat(env -> !env.hasSourceServiceId() && !env.hasSourceDevice()),
eq(false));
} else {
verifyNoMoreInteractions(messageSender);
@ -732,23 +735,27 @@ class MessageControllerTest {
@Test
void testDeleteMessages() {
long timestamp = System.currentTimeMillis();
long clientTimestamp = System.currentTimeMillis();
UUID sourceUuid = UUID.randomUUID();
UUID uuid1 = UUID.randomUUID();
final long serverTimestamp = 0;
when(messagesManager.delete(AuthHelper.VALID_UUID, AuthHelper.VALID_DEVICE, uuid1, null))
.thenReturn(
CompletableFutureTestUtil.almostCompletedFuture(Optional.of(generateEnvelope(uuid1, Envelope.Type.CIPHERTEXT_VALUE,
timestamp, sourceUuid, (byte) 1, AuthHelper.VALID_UUID, null, "hi".getBytes(), 0))));
CompletableFutureTestUtil.almostCompletedFuture(Optional.of(
new RemovedMessage(Optional.of(new AciServiceIdentifier(sourceUuid)),
new AciServiceIdentifier(AuthHelper.VALID_UUID), uuid1, serverTimestamp, clientTimestamp,
Envelope.Type.CIPHERTEXT))));
UUID uuid2 = UUID.randomUUID();
when(messagesManager.delete(AuthHelper.VALID_UUID, AuthHelper.VALID_DEVICE, uuid2, null))
.thenReturn(
CompletableFutureTestUtil.almostCompletedFuture(Optional.of(generateEnvelope(
uuid2, Envelope.Type.SERVER_DELIVERY_RECEIPT_VALUE,
System.currentTimeMillis(), sourceUuid, (byte) 1, AuthHelper.VALID_UUID, null, null, 0))));
CompletableFutureTestUtil.almostCompletedFuture(Optional.of(
new RemovedMessage(Optional.of(new AciServiceIdentifier(sourceUuid)),
new AciServiceIdentifier(AuthHelper.VALID_UUID), uuid2, serverTimestamp, clientTimestamp,
Envelope.Type.SERVER_DELIVERY_RECEIPT))));
UUID uuid3 = UUID.randomUUID();
when(messagesManager.delete(AuthHelper.VALID_UUID, AuthHelper.VALID_DEVICE, uuid3, null))
@ -766,7 +773,7 @@ class MessageControllerTest {
assertThat("Good Response Code", response.getStatus(), is(equalTo(204)));
verify(receiptSender).sendReceipt(eq(new AciServiceIdentifier(AuthHelper.VALID_UUID)), eq((byte) 1),
eq(new AciServiceIdentifier(sourceUuid)), eq(timestamp));
eq(new AciServiceIdentifier(sourceUuid)), eq(clientTimestamp));
}
try (final Response response = resources.getJerseyTest()
@ -1068,9 +1075,16 @@ class MessageControllerTest {
}
private record Recipient(ServiceIdentifier uuid,
byte deviceId,
int registrationId,
byte[] perRecipientKeyMaterial) {
Byte[] deviceId,
Integer[] registrationId,
byte[] perRecipientKeyMaterial) {
Recipient(ServiceIdentifier uuid,
byte deviceId,
int registrationId,
byte[] perRecipientKeyMaterial) {
this(uuid, new Byte[]{deviceId}, new Integer[]{registrationId}, perRecipientKeyMaterial);
}
}
private static void writeMultiPayloadRecipient(final ByteBuffer bb, final Recipient r,
@ -1081,8 +1095,13 @@ class MessageControllerTest {
bb.put(UUIDUtil.toBytes(r.uuid().uuid()));
}
bb.put(r.deviceId()); // device id (1 byte)
bb.putShort((short) r.registrationId()); // registration id (2 bytes)
assert (r.deviceId.length == r.registrationId.length);
for (int i = 0; i < r.deviceId.length; i++) {
final int hasMore = i == r.deviceId.length - 1 ? 0 : 0x8000;
bb.put(r.deviceId()[i]); // device id (1 byte)
bb.putShort((short) (r.registrationId()[i] | hasMore)); // registration id (2 bytes)
}
bb.put(r.perRecipientKeyMaterial()); // key material (48 bytes)
}
@ -1157,7 +1176,7 @@ class MessageControllerTest {
.queryParam("story", true)
.queryParam("urgent", false)
.request()
.header(HttpHeaders.USER_AGENT, "FIXME")
.header(HttpHeaders.USER_AGENT, "test")
.put(entity)) {
assertThat(response.readEntity(String.class), response.getStatus(), is(equalTo(200)));
@ -1206,7 +1225,7 @@ class MessageControllerTest {
.queryParam("story", isStory)
.queryParam("urgent", urgent)
.request()
.header(HttpHeaders.USER_AGENT, "FIXME")
.header(HttpHeaders.USER_AGENT, "test")
.header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, accessHeader)
.put(entity)) {
@ -1216,7 +1235,7 @@ class MessageControllerTest {
.sendMessage(
any(),
any(),
argThat(env -> env.getUrgent() == urgent && !env.hasSourceUuid() && !env.hasSourceDevice()),
argThat(env -> env.getUrgent() == urgent && !env.hasSourceServiceId() && !env.hasSourceDevice()),
eq(true));
if (expectedStatus == 200) {
SendMultiRecipientMessageResponse smrmr = response.readEntity(SendMultiRecipientMessageResponse.class);
@ -1384,7 +1403,7 @@ class MessageControllerTest {
.queryParam("story", false)
.queryParam("urgent", false)
.request()
.header(HttpHeaders.USER_AGENT, "FIXME")
.header(HttpHeaders.USER_AGENT, "test")
.header(HeaderUtils.GROUP_SEND_TOKEN, AuthHelper.validGroupSendTokenHeader(
serverSecretParams, List.of(SINGLE_DEVICE_ACI_ID, MULTI_DEVICE_ACI_ID), Instant.parse("2024-04-10T00:00:00.00Z")))
.put(Entity.entity(stream, MultiRecipientMessageProvider.MEDIA_TYPE))) {
@ -1395,7 +1414,7 @@ class MessageControllerTest {
.sendMessage(
any(),
any(),
argThat(env -> !env.hasSourceUuid() && !env.hasSourceDevice()),
argThat(env -> !env.hasSourceServiceId() && !env.hasSourceDevice()),
eq(true));
SendMultiRecipientMessageResponse smrmr = response.readEntity(SendMultiRecipientMessageResponse.class);
assertThat(smrmr.uuids404(), is(empty()));
@ -1423,7 +1442,7 @@ class MessageControllerTest {
.queryParam("story", false)
.queryParam("urgent", false)
.request()
.header(HttpHeaders.USER_AGENT, "FIXME")
.header(HttpHeaders.USER_AGENT, "test")
.header(HeaderUtils.GROUP_SEND_TOKEN, AuthHelper.validGroupSendTokenHeader(
serverSecretParams, List.of(MULTI_DEVICE_ACI_ID), Instant.parse("2024-04-10T00:00:00.00Z")))
.put(Entity.entity(stream, MultiRecipientMessageProvider.MEDIA_TYPE))) {
@ -1454,7 +1473,7 @@ class MessageControllerTest {
.queryParam("story", false)
.queryParam("urgent", false)
.request()
.header(HttpHeaders.USER_AGENT, "FIXME")
.header(HttpHeaders.USER_AGENT, "test")
.header(HeaderUtils.GROUP_SEND_TOKEN, AuthHelper.validGroupSendTokenHeader(
serverSecretParams, List.of(SINGLE_DEVICE_ACI_ID, MULTI_DEVICE_ACI_ID), Instant.parse("2024-04-10T00:00:00.00Z")))
.put(Entity.entity(stream, MultiRecipientMessageProvider.MEDIA_TYPE))) {
@ -1620,7 +1639,7 @@ class MessageControllerTest {
.queryParam("story", false)
.queryParam("urgent", true)
.request()
.header(HttpHeaders.USER_AGENT, "FIXME")
.header(HttpHeaders.USER_AGENT, "test")
.header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES));
// make the PUT request
@ -1663,7 +1682,7 @@ class MessageControllerTest {
.queryParam("story", false)
.queryParam("urgent", true)
.request()
.header(HttpHeaders.USER_AGENT, "FIXME")
.header(HttpHeaders.USER_AGENT, "test")
.header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES));
// make the PUT request
@ -1702,7 +1721,7 @@ class MessageControllerTest {
.queryParam("story", true)
.queryParam("urgent", true)
.request()
.header(HttpHeaders.USER_AGENT, "FIXME")
.header(HttpHeaders.USER_AGENT, "test")
.header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES));
when(rateLimiter.validateAsync(any(UUID.class)))
@ -1730,14 +1749,14 @@ class MessageControllerTest {
final MessageProtos.Envelope.Builder builder = MessageProtos.Envelope.newBuilder()
.setType(MessageProtos.Envelope.Type.forNumber(type))
.setTimestamp(timestamp)
.setClientTimestamp(timestamp)
.setServerTimestamp(serverTimestamp)
.setDestinationUuid(destinationUuid.toString())
.setDestinationServiceId(destinationUuid.toString())
.setStory(story)
.setServerGuid(guid.toString());
if (sourceUuid != null) {
builder.setSourceUuid(sourceUuid.toString());
builder.setSourceServiceId(sourceUuid.toString());
builder.setSourceDevice(sourceDevice);
}

View File

@ -104,7 +104,7 @@ class MessageMetricsTest {
final MessageProtos.Envelope.Builder builder = MessageProtos.Envelope.newBuilder();
if (destinationIdentifier != null) {
builder.setDestinationUuid(destinationIdentifier.toServiceIdentifierString());
builder.setDestinationServiceId(destinationIdentifier.toServiceIdentifierString());
}
return builder.build();

View File

@ -151,7 +151,7 @@ class MessageSenderTest {
private MessageProtos.Envelope generateRandomMessage() {
return MessageProtos.Envelope.newBuilder()
.setTimestamp(System.currentTimeMillis())
.setClientTimestamp(System.currentTimeMillis())
.setServerTimestamp(System.currentTimeMillis())
.setContent(ByteString.copyFromUtf8(RandomStringUtils.randomAlphanumeric(256)))
.setType(MessageProtos.Envelope.Type.CIPHERTEXT)

View File

@ -54,8 +54,8 @@ import java.util.concurrent.ExecutionException;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Supplier;
import javax.annotation.Nullable;
import org.apache.commons.lang3.StringUtils;
import org.jetbrains.annotations.Nullable;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.Timeout;

View File

@ -160,8 +160,8 @@ public class ChangeNumberManagerTest {
final MessageProtos.Envelope envelope = envelopeCaptor.getValue();
assertEquals(aci, UUID.fromString(envelope.getDestinationUuid()));
assertEquals(aci, UUID.fromString(envelope.getSourceUuid()));
assertEquals(aci, UUID.fromString(envelope.getDestinationServiceId()));
assertEquals(aci, UUID.fromString(envelope.getSourceServiceId()));
assertEquals(Device.PRIMARY_ID, envelope.getSourceDevice());
assertEquals(updatedPhoneNumberIdentifiersByAccount.get(account), UUID.fromString(envelope.getUpdatedPni()));
}
@ -208,8 +208,8 @@ public class ChangeNumberManagerTest {
final MessageProtos.Envelope envelope = envelopeCaptor.getValue();
assertEquals(aci, UUID.fromString(envelope.getDestinationUuid()));
assertEquals(aci, UUID.fromString(envelope.getSourceUuid()));
assertEquals(aci, UUID.fromString(envelope.getDestinationServiceId()));
assertEquals(aci, UUID.fromString(envelope.getSourceServiceId()));
assertEquals(Device.PRIMARY_ID, envelope.getSourceDevice());
assertEquals(updatedPhoneNumberIdentifiersByAccount.get(account), UUID.fromString(envelope.getUpdatedPni()));
}
@ -254,8 +254,8 @@ public class ChangeNumberManagerTest {
final MessageProtos.Envelope envelope = envelopeCaptor.getValue();
assertEquals(aci, UUID.fromString(envelope.getDestinationUuid()));
assertEquals(aci, UUID.fromString(envelope.getSourceUuid()));
assertEquals(aci, UUID.fromString(envelope.getDestinationServiceId()));
assertEquals(aci, UUID.fromString(envelope.getSourceServiceId()));
assertEquals(Device.PRIMARY_ID, envelope.getSourceDevice());
assertFalse(updatedPhoneNumberIdentifiersByAccount.containsKey(account));
}
@ -296,8 +296,8 @@ public class ChangeNumberManagerTest {
final MessageProtos.Envelope envelope = envelopeCaptor.getValue();
assertEquals(aci, UUID.fromString(envelope.getDestinationUuid()));
assertEquals(aci, UUID.fromString(envelope.getSourceUuid()));
assertEquals(aci, UUID.fromString(envelope.getDestinationServiceId()));
assertEquals(aci, UUID.fromString(envelope.getSourceServiceId()));
assertEquals(Device.PRIMARY_ID, envelope.getSourceDevice());
assertFalse(updatedPhoneNumberIdentifiersByAccount.containsKey(account));
}
@ -340,8 +340,8 @@ public class ChangeNumberManagerTest {
final MessageProtos.Envelope envelope = envelopeCaptor.getValue();
assertEquals(aci, UUID.fromString(envelope.getDestinationUuid()));
assertEquals(aci, UUID.fromString(envelope.getSourceUuid()));
assertEquals(aci, UUID.fromString(envelope.getDestinationServiceId()));
assertEquals(aci, UUID.fromString(envelope.getSourceServiceId()));
assertEquals(Device.PRIMARY_ID, envelope.getSourceDevice());
assertFalse(updatedPhoneNumberIdentifiersByAccount.containsKey(account));
}

View File

@ -81,7 +81,7 @@ class MessagePersisterIntegrationTest {
notificationExecutorService = Executors.newSingleThreadExecutor();
messagesCache = new MessagesCache(REDIS_CLUSTER_EXTENSION.getRedisCluster(), notificationExecutorService,
messageDeliveryScheduler, messageDeletionExecutorService, Clock.systemUTC());
messageDeliveryScheduler, messageDeletionExecutorService, Clock.systemUTC(), dynamicConfigurationManager);
messagesManager = new MessagesManager(messagesDynamoDb, messagesCache, mock(ReportMessageManager.class),
messageDeletionExecutorService);
messagePersister = new MessagePersister(messagesCache, messagesManager, accountsManager,
@ -185,12 +185,12 @@ class MessagePersisterIntegrationTest {
private MessageProtos.Envelope generateRandomMessage(final UUID messageGuid, final long serverTimestamp) {
return MessageProtos.Envelope.newBuilder()
.setTimestamp(serverTimestamp * 2) // client timestamp may not be accurate
.setClientTimestamp(serverTimestamp * 2) // client timestamp may not be accurate
.setServerTimestamp(serverTimestamp)
.setContent(ByteString.copyFromUtf8(RandomStringUtils.randomAlphanumeric(256)))
.setType(MessageProtos.Envelope.Type.CIPHERTEXT)
.setServerGuid(messageGuid.toString())
.setDestinationUuid(UUID.randomUUID().toString())
.setDestinationServiceId(UUID.randomUUID().toString())
.build();
}
}

View File

@ -40,6 +40,7 @@ import org.apache.commons.lang3.RandomStringUtils;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.Timeout;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.mockito.ArgumentCaptor;
import org.mockito.stubbing.Answer;
@ -48,12 +49,11 @@ import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
import org.whispersystems.textsecuregcm.tests.util.DevicesHelper;
import reactor.core.publisher.Mono;
import reactor.core.scheduler.Scheduler;
import reactor.core.scheduler.Schedulers;
import software.amazon.awssdk.services.dynamodb.model.ItemCollectionSizeLimitExceededException;
@Timeout(value = 5, threadMode = Timeout.ThreadMode.SEPARATE_THREAD)
class MessagePersisterTest {
@RegisterExtension
@ -104,7 +104,7 @@ class MessagePersisterTest {
resubscribeRetryExecutorService = Executors.newSingleThreadScheduledExecutor();
messageDeliveryScheduler = Schedulers.newBoundedElastic(10, 10_000, "messageDelivery");
messagesCache = new MessagesCache(REDIS_CLUSTER_EXTENSION.getRedisCluster(), sharedExecutorService,
messageDeliveryScheduler, sharedExecutorService, Clock.systemUTC());
messageDeliveryScheduler, sharedExecutorService, Clock.systemUTC(), dynamicConfigurationManager);
messagePersister = new MessagePersister(messagesCache, messagesManager, accountsManager, clientPresenceManager,
keysManager, dynamicConfigurationManager, PERSIST_DELAY, 1);
@ -356,7 +356,8 @@ class MessagePersisterTest {
final UUID messageGuid = UUID.randomUUID();
final MessageProtos.Envelope envelope = MessageProtos.Envelope.newBuilder()
.setTimestamp(firstMessageTimestamp.toEpochMilli() + i)
.setDestinationServiceId(accountUuid.toString())
.setClientTimestamp(firstMessageTimestamp.toEpochMilli() + i)
.setServerTimestamp(firstMessageTimestamp.toEpochMilli() + i)
.setContent(ByteString.copyFromUtf8(RandomStringUtils.randomAlphanumeric(256)))
.setType(MessageProtos.Envelope.Type.CIPHERTEXT)

View File

@ -0,0 +1,74 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import io.lettuce.core.RedisCommandExecutionException;
import io.lettuce.core.ScriptOutputType;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.time.Instant;
import java.util.List;
import java.util.UUID;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.redis.ClusterLuaScript;
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
class MessagesCacheGetItemsScriptTest {
@RegisterExtension
static final RedisClusterExtension REDIS_CLUSTER_EXTENSION = RedisClusterExtension.builder().build();
@Test
void testCacheGetItemsScript() throws Exception {
final MessagesCacheInsertScript insertScript = new MessagesCacheInsertScript(
REDIS_CLUSTER_EXTENSION.getRedisCluster());
final UUID destinationUuid = UUID.randomUUID();
final byte deviceId = 1;
final String serverGuid = UUID.randomUUID().toString();
final MessageProtos.Envelope envelope1 = MessageProtos.Envelope.newBuilder()
.setServerTimestamp(Instant.now().getEpochSecond())
.setServerGuid(serverGuid)
.build();
insertScript.execute(destinationUuid, deviceId, envelope1);
final MessagesCacheGetItemsScript getItemsScript = new MessagesCacheGetItemsScript(
REDIS_CLUSTER_EXTENSION.getRedisCluster());
final List<byte[]> messageAndScores = getItemsScript.execute(destinationUuid, deviceId, 1, -1)
.block(Duration.ofSeconds(1));
assertNotNull(messageAndScores);
assertEquals(2, messageAndScores.size());
final MessageProtos.Envelope resultEnvelope = MessageProtos.Envelope.parseFrom(
messageAndScores.getFirst());
assertEquals(serverGuid, resultEnvelope.getServerGuid());
}
@Test
void testCacheGetItemsInvalidParameter() throws Exception {
final ClusterLuaScript getItemsScript = ClusterLuaScript.fromResource(REDIS_CLUSTER_EXTENSION.getRedisCluster(),
"lua/get_items.lua", ScriptOutputType.OBJECT);
final byte[] fakeKey = new byte[]{1};
final Exception e = assertThrows(RedisCommandExecutionException.class,
() -> getItemsScript.executeBinaryReactive(List.of(fakeKey, fakeKey),
List.of("1".getBytes(StandardCharsets.UTF_8)))
.next()
.block(Duration.ofSeconds(1)));
assertEquals("ERR afterMessageId is required", e.getMessage());
}
}

View File

@ -0,0 +1,45 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import static org.junit.jupiter.api.Assertions.assertEquals;
import java.time.Instant;
import java.util.UUID;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
class MessagesCacheInsertScriptTest {
@RegisterExtension
static final RedisClusterExtension REDIS_CLUSTER_EXTENSION = RedisClusterExtension.builder().build();
@Test
void testCacheInsertScript() throws Exception {
final MessagesCacheInsertScript insertScript = new MessagesCacheInsertScript(
REDIS_CLUSTER_EXTENSION.getRedisCluster());
final UUID destinationUuid = UUID.randomUUID();
final byte deviceId = 1;
final MessageProtos.Envelope envelope1 = MessageProtos.Envelope.newBuilder()
.setServerTimestamp(Instant.now().getEpochSecond())
.setServerGuid(UUID.randomUUID().toString())
.build();
assertEquals(1, insertScript.execute(destinationUuid, deviceId, envelope1));
final MessageProtos.Envelope envelope2 = MessageProtos.Envelope.newBuilder()
.setServerTimestamp(Instant.now().getEpochSecond())
.setServerGuid(UUID.randomUUID().toString())
.build();
assertEquals(2, insertScript.execute(destinationUuid, deviceId, envelope2));
assertEquals(1, insertScript.execute(destinationUuid, deviceId, envelope1),
"Repeated with same guid should have same message ID");
}
}

View File

@ -0,0 +1,74 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import static org.junit.jupiter.api.Assertions.assertEquals;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
import org.whispersystems.textsecuregcm.util.Pair;
class MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScriptTest {
@RegisterExtension
static final RedisClusterExtension REDIS_CLUSTER_EXTENSION = RedisClusterExtension.builder().build();
@ParameterizedTest
@MethodSource
void testInsert(final int count, final Map<AciServiceIdentifier, List<Byte>> destinations) throws Exception {
final MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript insertMrmScript = new MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript(
REDIS_CLUSTER_EXTENSION.getRedisCluster());
final byte[] sharedMrmKey = MessagesCache.getSharedMrmKey(UUID.randomUUID());
insertMrmScript.execute(sharedMrmKey,
MessagesCacheTest.generateRandomMrmMessage(destinations));
final int totalDevices = destinations.values().stream().mapToInt(List::size).sum();
final long hashFieldCount = REDIS_CLUSTER_EXTENSION.getRedisCluster()
.withBinaryCluster(conn -> conn.sync().hlen(sharedMrmKey));
assertEquals(totalDevices + 1, hashFieldCount);
}
public static List<Arguments> testInsert() {
final Map<AciServiceIdentifier, List<Byte>> singleAccount = Map.of(
new AciServiceIdentifier(UUID.randomUUID()), List.of((byte) 1, (byte) 2));
final List<Arguments> testCases = new ArrayList<>();
testCases.add(Arguments.of(1, singleAccount));
for (int j = 1000; j <= 30000; j += 1000) {
final Map<Integer, List<Byte>> deviceLists = new HashMap<>();
final Map<AciServiceIdentifier, List<Byte>> manyAccounts = IntStream.range(0, j)
.mapToObj(i -> {
final int deviceCount = 1 + i % 5;
final List<Byte> devices = deviceLists.computeIfAbsent(deviceCount, count -> IntStream.rangeClosed(1, count)
.mapToObj(v -> (byte) v)
.toList());
return new Pair<>(new AciServiceIdentifier(UUID.randomUUID()), devices);
})
.collect(Collectors.toMap(Pair::first, Pair::second));
testCases.add(Arguments.of(j, manyAccounts));
}
return testCases;
}
}

View File

@ -0,0 +1,52 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import static org.junit.jupiter.api.Assertions.assertEquals;
import java.time.Instant;
import java.util.List;
import java.util.UUID;
import java.util.concurrent.TimeUnit;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
class MessagesCacheRemoveByGuidScriptTest {
@RegisterExtension
static final RedisClusterExtension REDIS_CLUSTER_EXTENSION = RedisClusterExtension.builder().build();
@Test
void testCacheRemoveByGuid() throws Exception {
final MessagesCacheInsertScript insertScript = new MessagesCacheInsertScript(
REDIS_CLUSTER_EXTENSION.getRedisCluster());
final UUID destinationUuid = UUID.randomUUID();
final byte deviceId = 1;
final UUID serverGuid = UUID.randomUUID();
final MessageProtos.Envelope envelope1 = MessageProtos.Envelope.newBuilder()
.setServerTimestamp(Instant.now().getEpochSecond())
.setServerGuid(serverGuid.toString())
.build();
insertScript.execute(destinationUuid, deviceId, envelope1);
final MessagesCacheRemoveByGuidScript removeByGuidScript = new MessagesCacheRemoveByGuidScript(
REDIS_CLUSTER_EXTENSION.getRedisCluster());
final List<byte[]> removedMessages = removeByGuidScript.execute(destinationUuid, deviceId,
List.of(serverGuid)).get(1, TimeUnit.SECONDS);
assertEquals(1, removedMessages.size());
final MessageProtos.Envelope resultMessage = MessageProtos.Envelope.parseFrom(
removedMessages.getFirst());
assertEquals(serverGuid, UUID.fromString(resultMessage.getServerGuid()));
}
}

View File

@ -0,0 +1,50 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import static org.junit.jupiter.api.Assertions.assertEquals;
import java.time.Duration;
import java.time.Instant;
import java.util.Collections;
import java.util.List;
import java.util.UUID;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
class MessagesCacheRemoveQueueScriptTest {
@RegisterExtension
static final RedisClusterExtension REDIS_CLUSTER_EXTENSION = RedisClusterExtension.builder().build();
@Test
void testCacheRemoveQueueScript() throws Exception {
final MessagesCacheInsertScript insertScript = new MessagesCacheInsertScript(
REDIS_CLUSTER_EXTENSION.getRedisCluster());
final UUID destinationUuid = UUID.randomUUID();
final byte deviceId = 1;
final MessageProtos.Envelope envelope1 = MessageProtos.Envelope.newBuilder()
.setServerTimestamp(Instant.now().getEpochSecond())
.setServerGuid(UUID.randomUUID().toString())
.build();
insertScript.execute(destinationUuid, deviceId, envelope1);
final MessagesCacheRemoveQueueScript removeScript = new MessagesCacheRemoveQueueScript(
REDIS_CLUSTER_EXTENSION.getRedisCluster());
final List<byte[]> messagesToCheckForMrmKeys = removeScript.execute(destinationUuid, deviceId,
Collections.emptyList())
.block(Duration.ofSeconds(1));
assertEquals(1, messagesToCheckForMrmKeys.size());
}
}

View File

@ -0,0 +1,124 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import static org.junit.jupiter.api.Assertions.assertEquals;
import io.lettuce.core.cluster.SlotHash;
import java.time.Duration;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.UUID;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.junit.jupiter.params.provider.ValueSource;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
import org.whispersystems.textsecuregcm.util.Pair;
import reactor.core.publisher.Flux;
import reactor.util.function.Tuples;
class MessagesCacheRemoveRecipientViewFromMrmDataScriptTest {
@RegisterExtension
static final RedisClusterExtension REDIS_CLUSTER_EXTENSION = RedisClusterExtension.builder().build();
@ParameterizedTest
@MethodSource
void testUpdateSingleKey(final Map<AciServiceIdentifier, List<Byte>> destinations) throws Exception {
final MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript insertMrmScript = new MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript(
REDIS_CLUSTER_EXTENSION.getRedisCluster());
final byte[] sharedMrmKey = MessagesCache.getSharedMrmKey(UUID.randomUUID());
insertMrmScript.execute(sharedMrmKey,
MessagesCacheTest.generateRandomMrmMessage(destinations));
final MessagesCacheRemoveRecipientViewFromMrmDataScript removeRecipientViewFromMrmDataScript = new MessagesCacheRemoveRecipientViewFromMrmDataScript(
REDIS_CLUSTER_EXTENSION.getRedisCluster());
final long keysRemoved = Objects.requireNonNull(Flux.fromIterable(destinations.entrySet())
.flatMap(e -> Flux.fromStream(e.getValue().stream().map(deviceId -> Tuples.of(e.getKey(), deviceId))))
.flatMap(aciServiceIdentifierByteTuple -> removeRecipientViewFromMrmDataScript.execute(List.of(sharedMrmKey),
aciServiceIdentifierByteTuple.getT1(), aciServiceIdentifierByteTuple.getT2()))
.reduce(Long::sum)
.block(Duration.ofSeconds(35)));
assertEquals(1, keysRemoved);
final long keyExists = REDIS_CLUSTER_EXTENSION.getRedisCluster()
.withBinaryCluster(conn -> conn.sync().exists(sharedMrmKey));
assertEquals(0, keyExists);
}
public static List<Map<AciServiceIdentifier, List<Byte>>> testUpdateSingleKey() {
final Map<AciServiceIdentifier, List<Byte>> singleAccount = Map.of(
new AciServiceIdentifier(UUID.randomUUID()), List.of((byte) 1, (byte) 2));
final List<Map<AciServiceIdentifier, List<Byte>>> testCases = new ArrayList<>();
testCases.add(singleAccount);
// Generate a more, from smallish to very large
for (int j = 1000; j <= 81000; j *= 3) {
final Map<Integer, List<Byte>> deviceLists = new HashMap<>();
final Map<AciServiceIdentifier, List<Byte>> manyAccounts = IntStream.range(0, j)
.mapToObj(i -> {
final int deviceCount = 1 + i % 5;
final List<Byte> devices = deviceLists.computeIfAbsent(deviceCount, count -> IntStream.rangeClosed(1, count)
.mapToObj(v -> (byte) v)
.toList());
return new Pair<>(new AciServiceIdentifier(UUID.randomUUID()), devices);
})
.collect(Collectors.toMap(Pair::first, Pair::second));
testCases.add(manyAccounts);
}
return testCases;
}
@ParameterizedTest
@ValueSource(ints = {1, 10, 100, 1000, 10000})
void testUpdateManyKeys(int keyCount) throws Exception {
final List<byte[]> sharedMrmKeys = new ArrayList<>(keyCount);
final AciServiceIdentifier aciServiceIdentifier = new AciServiceIdentifier(UUID.randomUUID());
final byte deviceId = 1;
for (int i = 0; i < keyCount; i++) {
final MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript insertMrmScript = new MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript(
REDIS_CLUSTER_EXTENSION.getRedisCluster());
final byte[] sharedMrmKey = MessagesCache.getSharedMrmKey(UUID.randomUUID());
insertMrmScript.execute(sharedMrmKey,
MessagesCacheTest.generateRandomMrmMessage(aciServiceIdentifier, deviceId));
sharedMrmKeys.add(sharedMrmKey);
}
final MessagesCacheRemoveRecipientViewFromMrmDataScript removeRecipientViewFromMrmDataScript = new MessagesCacheRemoveRecipientViewFromMrmDataScript(
REDIS_CLUSTER_EXTENSION.getRedisCluster());
final long keysRemoved = Objects.requireNonNull(Flux.fromIterable(sharedMrmKeys)
.collectMultimap(SlotHash::getSlot)
.flatMapMany(slotsAndKeys -> Flux.fromIterable(slotsAndKeys.values()))
.flatMap(keys -> removeRecipientViewFromMrmDataScript.execute(keys, aciServiceIdentifier, deviceId))
.reduce(Long::sum)
.block(Duration.ofSeconds(5)));
assertEquals(sharedMrmKeys.size(), keysRemoved);
}
}

View File

@ -5,6 +5,7 @@
package org.whispersystems.textsecuregcm.storage;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTimeoutPreemptively;
@ -25,6 +26,8 @@ import io.lettuce.core.cluster.api.async.RedisAdvancedClusterAsyncCommands;
import io.lettuce.core.cluster.api.reactive.RedisAdvancedClusterReactiveCommands;
import io.lettuce.core.protocol.AsyncCommand;
import io.lettuce.core.protocol.RedisCommand;
import java.io.ByteArrayOutputStream;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.time.Clock;
import java.time.Duration;
@ -32,9 +35,12 @@ import java.time.Instant;
import java.time.ZoneId;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Deque;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Random;
import java.util.UUID;
@ -42,11 +48,13 @@ import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.RandomStringUtils;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
@ -57,7 +65,12 @@ import org.junit.jupiter.api.extension.RegisterExtension;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.reactivestreams.Publisher;
import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage;
import org.signal.libsignal.protocol.ServiceId;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicMessagesConfiguration;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
import org.whispersystems.textsecuregcm.tests.util.RedisClusterHelper;
@ -83,6 +96,8 @@ class MessagesCacheTest {
private Scheduler messageDeliveryScheduler;
private MessagesCache messagesCache;
private DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager;
private static final UUID DESTINATION_UUID = UUID.randomUUID();
private static final byte DESTINATION_DEVICE_ID = 7;
@ -95,11 +110,16 @@ class MessagesCacheTest {
connection.sync().upstream().commands().configSet("notify-keyspace-events", "K$glz");
});
final DynamicConfiguration dynamicConfiguration = mock(DynamicConfiguration.class);
when(dynamicConfiguration.getMessagesConfiguration()).thenReturn(new DynamicMessagesConfiguration(true, true));
dynamicConfigurationManager = mock(DynamicConfigurationManager.class);
when(dynamicConfigurationManager.getConfiguration()).thenReturn(dynamicConfiguration);
sharedExecutorService = Executors.newSingleThreadExecutor();
resubscribeRetryExecutorService = Executors.newSingleThreadScheduledExecutor();
messageDeliveryScheduler = Schedulers.newBoundedElastic(10, 10_000, "messageDelivery");
messagesCache = new MessagesCache(REDIS_CLUSTER_EXTENSION.getRedisCluster(), sharedExecutorService,
messageDeliveryScheduler, sharedExecutorService, Clock.systemUTC());
messageDeliveryScheduler, sharedExecutorService, Clock.systemUTC(), dynamicConfigurationManager);
messagesCache.start();
}
@ -148,10 +168,10 @@ class MessagesCacheTest {
final MessageProtos.Envelope message = generateRandomMessage(messageGuid, sealedSender);
messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message);
final Optional<MessageProtos.Envelope> maybeRemovedMessage = messagesCache.remove(DESTINATION_UUID,
final Optional<RemovedMessage> maybeRemovedMessage = messagesCache.remove(DESTINATION_UUID,
DESTINATION_DEVICE_ID, messageGuid).get(5, TimeUnit.SECONDS);
assertEquals(Optional.of(message), maybeRemovedMessage);
assertEquals(Optional.of(RemovedMessage.fromEnvelope(message)), maybeRemovedMessage);
}
@ParameterizedTest
@ -181,11 +201,11 @@ class MessagesCacheTest {
message);
}
final List<MessageProtos.Envelope> removedMessages = messagesCache.remove(DESTINATION_UUID, DESTINATION_DEVICE_ID,
final List<RemovedMessage> removedMessages = messagesCache.remove(DESTINATION_UUID, DESTINATION_DEVICE_ID,
messagesToRemove.stream().map(message -> UUID.fromString(message.getServerGuid()))
.collect(Collectors.toList())).get(5, TimeUnit.SECONDS);
assertEquals(messagesToRemove, removedMessages);
assertEquals(messagesToRemove.stream().map(RemovedMessage::fromEnvelope).toList(), removedMessages);
assertEquals(messagesToPreserve,
messagesCache.getMessagesToPersist(DESTINATION_UUID, DESTINATION_DEVICE_ID, messageCount));
}
@ -283,7 +303,8 @@ class MessagesCacheTest {
}
final MessagesCache messagesCache = new MessagesCache(REDIS_CLUSTER_EXTENSION.getRedisCluster(),
sharedExecutorService, messageDeliveryScheduler, sharedExecutorService, cacheClock);
sharedExecutorService, messageDeliveryScheduler, sharedExecutorService, cacheClock,
dynamicConfigurationManager);
final List<MessageProtos.Envelope> actualMessages = Flux.from(
messagesCache.get(DESTINATION_UUID, DESTINATION_DEVICE_ID))
@ -320,7 +341,7 @@ class MessagesCacheTest {
@ParameterizedTest
@ValueSource(booleans = {true, false})
void testClearQueueForDevice(final boolean sealedSender) {
final int messageCount = 100;
final int messageCount = 1000;
for (final byte deviceId : new byte[]{DESTINATION_DEVICE_ID, DESTINATION_DEVICE_ID + 1}) {
for (int i = 0; i < messageCount; i++) {
@ -340,7 +361,7 @@ class MessagesCacheTest {
@ParameterizedTest
@ValueSource(booleans = {true, false})
void testClearQueueForAccount(final boolean sealedSender) {
final int messageCount = 100;
final int messageCount = 1000;
for (final byte deviceId : new byte[]{DESTINATION_DEVICE_ID, DESTINATION_DEVICE_ID + 1}) {
for (int i = 0; i < messageCount; i++) {
@ -542,6 +563,57 @@ class MessagesCacheTest {
});
}
@Test
void testMultiRecipientMessage() throws Exception {
final UUID destinationUuid = UUID.randomUUID();
final byte deviceId = 1;
final UUID mrmGuid = UUID.randomUUID();
final SealedSenderMultiRecipientMessage mrm = generateRandomMrmMessage(
new AciServiceIdentifier(destinationUuid), deviceId);
final byte[] sharedMrmDataKey = messagesCache.insertSharedMultiRecipientMessagePayload(mrmGuid, mrm);
final UUID guid = UUID.randomUUID();
final MessageProtos.Envelope message = generateRandomMessage(guid, true)
.toBuilder()
// clear some things added by the helper
.clearServerGuid()
// mrm views phase 1: messages have content
.setContent(
ByteString.copyFrom(mrm.messageForRecipient(mrm.getRecipients().get(new ServiceId.Aci(destinationUuid)))))
.setSharedMrmKey(ByteString.copyFrom(sharedMrmDataKey))
.build();
messagesCache.insert(guid, destinationUuid, deviceId, message);
assertEquals(1L, (long) REDIS_CLUSTER_EXTENSION.getRedisCluster()
.withBinaryCluster(conn -> conn.sync().exists(MessagesCache.getSharedMrmKey(mrmGuid))));
final List<MessageProtos.Envelope> messages = get(destinationUuid, deviceId, 1);
assertEquals(1, messages.size());
assertEquals(guid, UUID.fromString(messages.getFirst().getServerGuid()));
assertFalse(messages.getFirst().hasSharedMrmKey());
final SealedSenderMultiRecipientMessage.Recipient recipient = mrm.getRecipients()
.get(new ServiceId.Aci(destinationUuid));
assertArrayEquals(mrm.messageForRecipient(recipient), messages.getFirst().getContent().toByteArray());
final Optional<RemovedMessage> removedMessage = messagesCache.remove(destinationUuid, deviceId, guid)
.join();
assertTrue(removedMessage.isPresent());
assertEquals(guid, UUID.fromString(removedMessage.get().serverGuid().toString()));
assertTrue(get(destinationUuid, deviceId, 1).isEmpty());
// updating the shared MRM data is purely async, so we just wait for it
assertTimeoutPreemptively(Duration.ofSeconds(1), () -> {
boolean exists;
do {
exists = 1 == REDIS_CLUSTER_EXTENSION.getRedisCluster()
.withBinaryCluster(conn -> conn.sync().exists(MessagesCache.getSharedMrmKey(mrmGuid)));
} while (exists);
});
}
private List<MessageProtos.Envelope> get(final UUID destinationUuid, final byte destinationDeviceId,
final int messageCount) {
return Flux.from(messagesCache.get(destinationUuid, destinationDeviceId))
@ -573,7 +645,7 @@ class MessagesCacheTest {
messageDeliveryScheduler = Schedulers.newBoundedElastic(10, 10_000, "messageDelivery");
messagesCache = new MessagesCache(mockCluster, mock(ExecutorService.class), messageDeliveryScheduler,
Executors.newSingleThreadExecutor(), Clock.systemUTC());
Executors.newSingleThreadExecutor(), Clock.systemUTC(), mock(DynamicConfigurationManager.class));
}
@AfterEach
@ -755,18 +827,85 @@ class MessagesCacheTest {
private MessageProtos.Envelope generateRandomMessage(final UUID messageGuid, final boolean sealedSender,
final long timestamp) {
final MessageProtos.Envelope.Builder envelopeBuilder = MessageProtos.Envelope.newBuilder()
.setTimestamp(timestamp)
.setClientTimestamp(timestamp)
.setServerTimestamp(timestamp)
.setContent(ByteString.copyFromUtf8(RandomStringUtils.randomAlphanumeric(256)))
.setType(MessageProtos.Envelope.Type.CIPHERTEXT)
.setServerGuid(messageGuid.toString())
.setDestinationUuid(UUID.randomUUID().toString());
.setDestinationServiceId(UUID.randomUUID().toString());
if (!sealedSender) {
envelopeBuilder.setSourceDevice(random.nextInt(Device.MAXIMUM_DEVICE_ID) + 1)
.setSourceUuid(UUID.randomUUID().toString());
.setSourceServiceId(UUID.randomUUID().toString());
}
return envelopeBuilder.build();
}
static SealedSenderMultiRecipientMessage generateRandomMrmMessage(
Map<AciServiceIdentifier, List<Byte>> destinations) {
try {
final ByteBuffer prefix = ByteBuffer.allocate(7);
prefix.put((byte) 0x23); // version
writeVarint(prefix, destinations.size()); // recipient count
prefix.flip();
List<ByteBuffer> recipients = new ArrayList<>(destinations.size());
for (Map.Entry<AciServiceIdentifier, List<Byte>> aciServiceIdentifierAndDeviceIds : destinations.entrySet()) {
final AciServiceIdentifier destination = aciServiceIdentifierAndDeviceIds.getKey();
final List<Byte> deviceIds = aciServiceIdentifierAndDeviceIds.getValue();
assert deviceIds.size() < 255;
final ByteBuffer recipient = ByteBuffer.allocate(17 + 3 * deviceIds.size() + 48);
recipient.put(destination.toFixedWidthByteArray());
for (int i = 0; i < deviceIds.size(); i++) {
final int hasMore = i == deviceIds.size() - 1 ? 0x0000 : 0x8000;
recipient.put(new byte[]{deviceIds.get(i)}); // device ID
recipient.putShort((short) ((100 + deviceIds.get(i)) | hasMore)); // registration ID
}
final byte[] keyMaterial = new byte[48];
ThreadLocalRandom.current().nextBytes(keyMaterial);
recipient.put(keyMaterial);
recipients.add(recipient);
}
final byte[] commonPayload = new byte[64];
ThreadLocalRandom.current().nextBytes(commonPayload);
final ByteArrayOutputStream baos = new ByteArrayOutputStream();
baos.write(prefix.array(), 0, prefix.limit());
for (ByteBuffer recipient : recipients) {
baos.write(recipient.array());
}
baos.write(commonPayload);
return SealedSenderMultiRecipientMessage.parse(baos.toByteArray());
} catch (Exception e) {
throw new RuntimeException(e);
}
}
static SealedSenderMultiRecipientMessage generateRandomMrmMessage(AciServiceIdentifier destination,
byte... deviceIds) {
final Map<AciServiceIdentifier, List<Byte>> destinations = new HashMap<>();
destinations.put(destination, Arrays.asList(ArrayUtils.toObject(deviceIds)));
return generateRandomMrmMessage(destinations);
}
private static void writeVarint(ByteBuffer bb, long n) {
while (n >= 0x80) {
bb.put((byte) (n & 0x7F | 0x80));
n = n >> 7;
}
bb.put((byte) (n & 0x7F));
}
}

View File

@ -6,8 +6,6 @@
package org.whispersystems.textsecuregcm.storage;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import com.google.protobuf.ByteString;
import java.time.Duration;
@ -31,7 +29,6 @@ import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.storage.DynamoDbExtensionSchema.Tables;
import org.whispersystems.textsecuregcm.tests.util.DevicesHelper;
import org.whispersystems.textsecuregcm.tests.util.MessageHelper;
import org.whispersystems.textsecuregcm.util.SystemMapper;
import reactor.core.publisher.Flux;
import reactor.test.StepVerifier;
@ -47,31 +44,31 @@ class MessagesDynamoDbTest {
final long serverTimestamp = System.currentTimeMillis();
MessageProtos.Envelope.Builder builder = MessageProtos.Envelope.newBuilder();
builder.setType(MessageProtos.Envelope.Type.UNIDENTIFIED_SENDER);
builder.setTimestamp(123456789L);
builder.setClientTimestamp(123456789L);
builder.setContent(ByteString.copyFrom(new byte[]{(byte) 0xDE, (byte) 0xAD, (byte) 0xBE, (byte) 0xEF}));
builder.setServerGuid(UUID.randomUUID().toString());
builder.setServerTimestamp(serverTimestamp);
builder.setDestinationUuid(UUID.randomUUID().toString());
builder.setDestinationServiceId(UUID.randomUUID().toString());
MESSAGE1 = builder.build();
builder.setType(MessageProtos.Envelope.Type.CIPHERTEXT);
builder.setSourceUuid(UUID.randomUUID().toString());
builder.setSourceServiceId(UUID.randomUUID().toString());
builder.setSourceDevice(1);
builder.setContent(ByteString.copyFromUtf8("MOO"));
builder.setServerGuid(UUID.randomUUID().toString());
builder.setServerTimestamp(serverTimestamp + 1);
builder.setDestinationUuid(UUID.randomUUID().toString());
builder.setDestinationServiceId(UUID.randomUUID().toString());
MESSAGE2 = builder.build();
builder.setType(MessageProtos.Envelope.Type.UNIDENTIFIED_SENDER);
builder.clearSourceUuid();
builder.clearSourceDevice();
builder.clearSourceDevice();
builder.setContent(ByteString.copyFromUtf8("COW"));
builder.setServerGuid(UUID.randomUUID().toString());
builder.setServerTimestamp(serverTimestamp); // Test same millisecond arrival for two different messages
builder.setDestinationUuid(UUID.randomUUID().toString());
builder.setDestinationServiceId(UUID.randomUUID().toString());
MESSAGE3 = builder.build();
}

View File

@ -35,7 +35,7 @@ class MessagesManagerTest {
void insert() {
final UUID sourceAci = UUID.randomUUID();
final Envelope message = Envelope.newBuilder()
.setSourceUuid(sourceAci.toString())
.setSourceServiceId(sourceAci.toString())
.build();
final UUID destinationUuid = UUID.randomUUID();
@ -45,7 +45,7 @@ class MessagesManagerTest {
verify(reportMessageManager).store(eq(sourceAci.toString()), any(UUID.class));
final Envelope syncMessage = Envelope.newBuilder(message)
.setSourceUuid(destinationUuid.toString())
.setSourceServiceId(destinationUuid.toString())
.build();
messagesManager.insert(destinationUuid, Device.PRIMARY_ID, syncMessage);

View File

@ -17,11 +17,11 @@ public class MessageHelper {
return MessageProtos.Envelope.newBuilder()
.setServerGuid(UUID.randomUUID().toString())
.setType(MessageProtos.Envelope.Type.CIPHERTEXT)
.setTimestamp(timestamp)
.setClientTimestamp(timestamp)
.setServerTimestamp(0)
.setSourceUuid(senderUuid.toString())
.setSourceServiceId(senderUuid.toString())
.setSourceDevice(senderDeviceId)
.setDestinationUuid(destinationUuid.toString())
.setDestinationServiceId(destinationUuid.toString())
.setContent(ByteString.copyFrom(content.getBytes(StandardCharsets.UTF_8)))
.build();
}

View File

@ -44,6 +44,7 @@ import org.junit.jupiter.params.provider.CsvSource;
import org.mockito.ArgumentCaptor;
import org.mockito.stubbing.Answer;
import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
import org.whispersystems.textsecuregcm.limits.MessageDeliveryLoopMonitor;
@ -55,6 +56,7 @@ import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.ClientReleaseManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
import org.whispersystems.textsecuregcm.storage.DynamoDbExtension;
import org.whispersystems.textsecuregcm.storage.DynamoDbExtensionSchema.Tables;
import org.whispersystems.textsecuregcm.storage.MessagesCache;
@ -85,6 +87,8 @@ class WebSocketConnectionIntegrationTest {
private Scheduler messageDeliveryScheduler;
private ClientReleaseManager clientReleaseManager;
private DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager;
private long serialTimestamp = System.currentTimeMillis();
@BeforeEach
@ -92,8 +96,10 @@ class WebSocketConnectionIntegrationTest {
sharedExecutorService = Executors.newSingleThreadExecutor();
scheduledExecutorService = Executors.newSingleThreadScheduledExecutor();
messageDeliveryScheduler = Schedulers.newBoundedElastic(10, 10_000, "messageDelivery");
dynamicConfigurationManager = mock(DynamicConfigurationManager.class);
when(dynamicConfigurationManager.getConfiguration()).thenReturn(new DynamicConfiguration());
messagesCache = new MessagesCache(REDIS_CLUSTER_EXTENSION.getRedisCluster(), sharedExecutorService,
messageDeliveryScheduler, sharedExecutorService, Clock.systemUTC());
messageDeliveryScheduler, sharedExecutorService, Clock.systemUTC(), dynamicConfigurationManager);
messagesDynamoDb = new MessagesDynamoDb(DYNAMO_DB_EXTENSION.getDynamoDbClient(),
DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(), Tables.MESSAGES.tableName(), Duration.ofDays(7),
sharedExecutorService);
@ -381,12 +387,12 @@ class WebSocketConnectionIntegrationTest {
final long timestamp = serialTimestamp++;
return MessageProtos.Envelope.newBuilder()
.setTimestamp(timestamp)
.setClientTimestamp(timestamp)
.setServerTimestamp(timestamp)
.setContent(ByteString.copyFromUtf8(RandomStringUtils.randomAlphanumeric(256)))
.setType(MessageProtos.Envelope.Type.CIPHERTEXT)
.setServerGuid(messageGuid.toString())
.setDestinationUuid(UUID.randomUUID().toString())
.setDestinationServiceId(UUID.randomUUID().toString())
.build();
}

View File

@ -48,7 +48,6 @@ import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import org.eclipse.jetty.websocket.api.UpgradeRequest;
import org.jetbrains.annotations.NotNull;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
@ -297,19 +296,19 @@ class WebSocketConnectionTest {
final Envelope firstMessage = Envelope.newBuilder()
.setServerGuid(UUID.randomUUID().toString())
.setSourceUuid(UUID.randomUUID().toString())
.setDestinationUuid(accountUuid.toString())
.setSourceServiceId(UUID.randomUUID().toString())
.setDestinationServiceId(accountUuid.toString())
.setUpdatedPni(UUID.randomUUID().toString())
.setTimestamp(System.currentTimeMillis())
.setClientTimestamp(System.currentTimeMillis())
.setSourceDevice(1)
.setType(Envelope.Type.CIPHERTEXT)
.build();
final Envelope secondMessage = Envelope.newBuilder()
.setServerGuid(UUID.randomUUID().toString())
.setSourceUuid(senderTwoUuid.toString())
.setDestinationUuid(accountUuid.toString())
.setTimestamp(System.currentTimeMillis())
.setSourceServiceId(senderTwoUuid.toString())
.setDestinationServiceId(accountUuid.toString())
.setClientTimestamp(System.currentTimeMillis())
.setSourceDevice(2)
.setType(Envelope.Type.CIPHERTEXT)
.build();
@ -365,7 +364,7 @@ class WebSocketConnectionTest {
futures.get(0).completeExceptionally(new IOException());
verify(receiptSender, times(1)).sendReceipt(eq(new AciServiceIdentifier(account.getUuid())), eq(deviceId), eq(new AciServiceIdentifier(senderTwoUuid)),
eq(secondMessage.getTimestamp()));
eq(secondMessage.getClientTimestamp()));
connection.stop();
verify(client).close(anyInt(), anyString());
@ -616,10 +615,10 @@ class WebSocketConnectionTest {
final byte[] body = argument.get();
try {
final Envelope envelope = Envelope.parseFrom(body);
if (!envelope.hasSourceUuid() || envelope.getSourceUuid().length() == 0) {
if (!envelope.hasSourceServiceId() || envelope.getSourceServiceId().length() == 0) {
return false;
}
return envelope.getSourceUuid().equals(senderUuid.toString());
return envelope.getSourceServiceId().equals(senderUuid.toString());
} catch (InvalidProtocolBufferException e) {
return false;
}
@ -627,7 +626,7 @@ class WebSocketConnectionTest {
verify(client).sendRequest(eq("PUT"), eq("/api/v1/queue/empty"), any(List.class), eq(Optional.empty()));
}
private @NotNull WebSocketConnection webSocketConnection(final WebSocketClient client) {
private WebSocketConnection webSocketConnection(final WebSocketClient client) {
return new WebSocketConnection(receiptSender, messagesManager, new MessageMetrics(),
mock(PushNotificationManager.class), mock(PushNotificationScheduler.class), auth, client,
retrySchedulingExecutor, Schedulers.immediate(), clientReleaseManager, mock(MessageDeliveryLoopMonitor.class));
@ -933,11 +932,11 @@ class WebSocketConnectionTest {
return Envelope.newBuilder()
.setServerGuid(UUID.randomUUID().toString())
.setType(Envelope.Type.CIPHERTEXT)
.setTimestamp(timestamp)
.setClientTimestamp(timestamp)
.setServerTimestamp(0)
.setSourceUuid(senderUuid.toString())
.setSourceServiceId(senderUuid.toString())
.setSourceDevice(SOURCE_DEVICE_ID)
.setDestinationUuid(destinationUuid.toString())
.setDestinationServiceId(destinationUuid.toString())
.setContent(ByteString.copyFrom(content.getBytes(StandardCharsets.UTF_8)))
.build();
}