0
0
mirror of https://github.com/signalapp/Signal-Server.git synced 2024-09-19 19:42:18 +02:00

Multi-recipient message views

This adds support for storing multi-recipient message payloads and recipient views in Redis, and only fanning out on delivery or persistence. Phase 1: confirm storage and retrieval correctness.
This commit is contained in:
Chris Eager 2024-09-04 13:58:20 -05:00 committed by GitHub
parent d78c8370b6
commit 11601fd091
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
50 changed files with 1544 additions and 328 deletions

View File

@ -632,7 +632,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
keyspaceNotificationDispatchExecutor); keyspaceNotificationDispatchExecutor);
ProfilesManager profilesManager = new ProfilesManager(profiles, cacheCluster); ProfilesManager profilesManager = new ProfilesManager(profiles, cacheCluster);
MessagesCache messagesCache = new MessagesCache(messagesCluster, keyspaceNotificationDispatchExecutor, MessagesCache messagesCache = new MessagesCache(messagesCluster, keyspaceNotificationDispatchExecutor,
messageDeliveryScheduler, messageDeletionAsyncExecutor, clock); messageDeliveryScheduler, messageDeletionAsyncExecutor, clock, dynamicConfigurationManager);
ClientReleaseManager clientReleaseManager = new ClientReleaseManager(clientReleases, ClientReleaseManager clientReleaseManager = new ClientReleaseManager(clientReleases,
recurringJobExecutor, recurringJobExecutor,
config.getClientReleaseConfiguration().refreshInterval(), config.getClientReleaseConfiguration().refreshInterval(),

View File

@ -27,5 +27,4 @@ public class MessageCacheConfiguration {
public int getPersistDelayMinutes() { public int getPersistDelayMinutes() {
return persistDelayMinutes; return persistDelayMinutes;
} }
} }

View File

@ -5,21 +5,9 @@
package org.whispersystems.textsecuregcm.configuration.dynamic; package org.whispersystems.textsecuregcm.configuration.dynamic;
import java.util.List; public record DynamicMessagesConfiguration(boolean storeSharedMrmData, boolean mrmViewExperimentEnabled) {
import javax.validation.constraints.NotNull;
public record DynamicMessagesConfiguration(@NotNull List<DynamoKeyScheme> dynamoKeySchemes) {
public enum DynamoKeyScheme {
TRADITIONAL,
LAZY_DELETION;
}
public DynamicMessagesConfiguration() { public DynamicMessagesConfiguration() {
this(List.of(DynamoKeyScheme.TRADITIONAL)); this(false, false);
}
public DynamoKeyScheme writeKeyScheme() {
return dynamoKeySchemes().getLast();
} }
} }

View File

@ -24,7 +24,6 @@ import io.swagger.v3.oas.annotations.Parameter;
import io.swagger.v3.oas.annotations.media.Content; import io.swagger.v3.oas.annotations.media.Content;
import io.swagger.v3.oas.annotations.media.Schema; import io.swagger.v3.oas.annotations.media.Schema;
import io.swagger.v3.oas.annotations.responses.ApiResponse; import io.swagger.v3.oas.annotations.responses.ApiResponse;
import java.security.MessageDigest; import java.security.MessageDigest;
import java.time.Clock; import java.time.Clock;
import java.time.Duration; import java.time.Duration;
@ -73,8 +72,8 @@ import javax.ws.rs.core.Response.Status;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.glassfish.jersey.server.ManagedAsync; import org.glassfish.jersey.server.ManagedAsync;
import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage; import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage;
import org.signal.libsignal.protocol.ServiceId;
import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage.Recipient; import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage.Recipient;
import org.signal.libsignal.protocol.ServiceId;
import org.signal.libsignal.protocol.util.Pair; import org.signal.libsignal.protocol.util.Pair;
import org.signal.libsignal.zkgroup.ServerSecretParams; import org.signal.libsignal.zkgroup.ServerSecretParams;
import org.signal.libsignal.zkgroup.VerificationFailedException; import org.signal.libsignal.zkgroup.VerificationFailedException;
@ -261,7 +260,7 @@ public class MessageController {
@Consumes(MediaType.APPLICATION_JSON) @Consumes(MediaType.APPLICATION_JSON)
@Produces(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON)
@ManagedAsync @ManagedAsync
@Operation( @Operation(
summary = "Send a message", summary = "Send a message",
description = """ description = """
Deliver a message to a single recipient. May be authenticated or unauthenticated; if unauthenticated, Deliver a message to a single recipient. May be authenticated or unauthenticated; if unauthenticated,
@ -309,9 +308,10 @@ public class MessageController {
if (groupSendToken != null) { if (groupSendToken != null) {
if (!source.isEmpty() || !accessKey.isEmpty()) { if (!source.isEmpty() || !accessKey.isEmpty()) {
throw new BadRequestException("Group send endorsement tokens should not be combined with other authentication"); throw new BadRequestException(
"Group send endorsement tokens should not be combined with other authentication");
} else if (isStory) { } else if (isStory) {
throw new BadRequestException("Group send endorsement tokens should not be sent for story messages"); throw new BadRequestException("Group send endorsement tokens should not be sent for story messages");
} }
} }
@ -346,8 +346,7 @@ public class MessageController {
} }
final Optional<byte[]> spamReportToken = switch (senderType) { final Optional<byte[]> spamReportToken = switch (senderType) {
case SENDER_TYPE_IDENTIFIED -> case SENDER_TYPE_IDENTIFIED -> reportSpamTokenProvider.makeReportSpamToken(context, source.get(), destination);
reportSpamTokenProvider.makeReportSpamToken(context, source.get(), destination);
default -> Optional.empty(); default -> Optional.empty();
}; };
@ -470,7 +469,7 @@ public class MessageController {
throw new WebApplicationException(Response.status(409) throw new WebApplicationException(Response.status(409)
.type(MediaType.APPLICATION_JSON_TYPE) .type(MediaType.APPLICATION_JSON_TYPE)
.entity(new MismatchedDevices(e.getMissingDevices(), .entity(new MismatchedDevices(e.getMissingDevices(),
e.getExtraDevices())) e.getExtraDevices()))
.build()); .build());
} catch (StaleDevicesException e) { } catch (StaleDevicesException e) {
throw new WebApplicationException(Response.status(410) throw new WebApplicationException(Response.status(410)
@ -621,27 +620,28 @@ public class MessageController {
Collection<AccountMismatchedDevices> accountMismatchedDevices = new ArrayList<>(); Collection<AccountMismatchedDevices> accountMismatchedDevices = new ArrayList<>();
Collection<AccountStaleDevices> accountStaleDevices = new ArrayList<>(); Collection<AccountStaleDevices> accountStaleDevices = new ArrayList<>();
recipients.values().forEach(recipient -> { recipients.values().forEach(recipient -> {
final Account account = recipient.account(); final Account account = recipient.account();
try { try {
DestinationDeviceValidator.validateCompleteDeviceList(account, recipient.deviceIdToRegistrationId().keySet(), Collections.emptySet()); DestinationDeviceValidator.validateCompleteDeviceList(account, recipient.deviceIdToRegistrationId().keySet(),
Collections.emptySet());
DestinationDeviceValidator.validateRegistrationIds( DestinationDeviceValidator.validateRegistrationIds(
account, account,
recipient.deviceIdToRegistrationId().entrySet(), recipient.deviceIdToRegistrationId().entrySet(),
Map.Entry<Byte, Short>::getKey, Map.Entry<Byte, Short>::getKey,
e -> Integer.valueOf(e.getValue()), e -> Integer.valueOf(e.getValue()),
recipient.serviceIdentifier().identityType() == IdentityType.PNI); recipient.serviceIdentifier().identityType() == IdentityType.PNI);
} catch (MismatchedDevicesException e) { } catch (MismatchedDevicesException e) {
accountMismatchedDevices.add( accountMismatchedDevices.add(
new AccountMismatchedDevices( new AccountMismatchedDevices(
recipient.serviceIdentifier(), recipient.serviceIdentifier(),
new MismatchedDevices(e.getMissingDevices(), e.getExtraDevices()))); new MismatchedDevices(e.getMissingDevices(), e.getExtraDevices())));
} catch (StaleDevicesException e) { } catch (StaleDevicesException e) {
accountStaleDevices.add( accountStaleDevices.add(
new AccountStaleDevices(recipient.serviceIdentifier(), new StaleDevices(e.getStaleDevices()))); new AccountStaleDevices(recipient.serviceIdentifier(), new StaleDevices(e.getStaleDevices())));
} }
}); });
if (!accountMismatchedDevices.isEmpty()) { if (!accountMismatchedDevices.isEmpty()) {
return Response return Response
.status(409) .status(409)
@ -667,6 +667,11 @@ public class MessageController {
} }
try { try {
@Nullable final byte[] sharedMrmKey =
dynamicConfigurationManager.getConfiguration().getMessagesConfiguration().storeSharedMrmData()
? messagesManager.insertSharedMultiRecipientMessagePayload(multiRecipientMessage)
: null;
CompletableFuture.allOf( CompletableFuture.allOf(
recipients.values().stream() recipients.values().stream()
.flatMap(recipientData -> { .flatMap(recipientData -> {
@ -692,8 +697,7 @@ public class MessageController {
sentMessageCounter.increment(); sentMessageCounter.increment();
sendCommonPayloadMessage( sendCommonPayloadMessage(
destinationAccount, destinationDevice, recipientData.serviceIdentifier(), timestamp, destinationAccount, destinationDevice, recipientData.serviceIdentifier(), timestamp,
online, online, isStory, isUrgent, payload, sharedMrmKey);
isStory, isUrgent, payload);
}, },
multiRecipientMessageExecutor)); multiRecipientMessageExecutor));
}) })
@ -739,8 +743,8 @@ public class MessageController {
.filter(Predicate.not(Account::isUnrestrictedUnidentifiedAccess)) .filter(Predicate.not(Account::isUnrestrictedUnidentifiedAccess))
.map(account -> .map(account ->
account.getUnidentifiedAccessKey() account.getUnidentifiedAccessKey()
.filter(b -> b.length == keyLength) .filter(b -> b.length == keyLength)
.orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED))) .orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED)))
.reduce(new byte[keyLength], .reduce(new byte[keyLength],
(a, b) -> { (a, b) -> {
final byte[] xor = new byte[keyLength]; final byte[] xor = new byte[keyLength];
@ -828,23 +832,28 @@ public class MessageController {
auth.getAuthenticatedDevice(), auth.getAuthenticatedDevice(),
uuid, uuid,
null) null)
.thenAccept(maybeDeletedMessage -> { .thenAccept(maybeRemovedMessage -> maybeRemovedMessage.ifPresent(removedMessage -> {
maybeDeletedMessage.ifPresent(deletedMessage -> {
WebSocketConnection.recordMessageDeliveryDuration(deletedMessage.getServerTimestamp(), WebSocketConnection.recordMessageDeliveryDuration(removedMessage.serverTimestamp(),
auth.getAuthenticatedDevice()); auth.getAuthenticatedDevice());
if (deletedMessage.hasSourceUuid() && deletedMessage.getType() != Type.SERVER_DELIVERY_RECEIPT) { if (removedMessage.sourceServiceId().isPresent()
&& removedMessage.envelopeType() != Type.SERVER_DELIVERY_RECEIPT) {
if (removedMessage.sourceServiceId().get() instanceof AciServiceIdentifier aciServiceIdentifier) {
try { try {
receiptSender.sendReceipt( receiptSender.sendReceipt(removedMessage.destinationServiceId(), auth.getAuthenticatedDevice().getId(),
ServiceIdentifier.valueOf(deletedMessage.getDestinationUuid()), auth.getAuthenticatedDevice().getId(), aciServiceIdentifier, removedMessage.clientTimestamp());
AciServiceIdentifier.valueOf(deletedMessage.getSourceUuid()), deletedMessage.getTimestamp());
} catch (Exception e) { } catch (Exception e) {
logger.warn("Failed to send delivery receipt", e); logger.warn("Failed to send delivery receipt", e);
} }
} else {
// If source service ID is present and the envelope type is not a server delivery receipt, then
// the source service ID *should always* be an ACI -- PNIs are receive-only, so they can only be the
// "source" via server delivery receipts
logger.warn("Source service ID unexpectedly a PNI service ID");
} }
}); }
}) }))
.thenApply(Util.ASYNC_EMPTY_RESPONSE); .thenApply(Util.ASYNC_EMPTY_RESPONSE);
} }
@ -943,19 +952,25 @@ public class MessageController {
boolean online, boolean online,
boolean story, boolean story,
boolean urgent, boolean urgent,
byte[] payload) { byte[] payload,
@Nullable byte[] sharedMrmKey) {
final Envelope.Builder messageBuilder = Envelope.newBuilder(); final Envelope.Builder messageBuilder = Envelope.newBuilder();
final long serverTimestamp = System.currentTimeMillis(); final long serverTimestamp = System.currentTimeMillis();
messageBuilder messageBuilder
.setType(Type.UNIDENTIFIED_SENDER) .setType(Type.UNIDENTIFIED_SENDER)
.setTimestamp(timestamp == 0 ? serverTimestamp : timestamp) .setClientTimestamp(timestamp == 0 ? serverTimestamp : timestamp)
.setServerTimestamp(serverTimestamp) .setServerTimestamp(serverTimestamp)
.setContent(ByteString.copyFrom(payload))
.setStory(story) .setStory(story)
.setUrgent(urgent) .setUrgent(urgent)
.setDestinationUuid(serviceIdentifier.toServiceIdentifierString()); .setDestinationServiceId(serviceIdentifier.toServiceIdentifierString());
if (sharedMrmKey != null) {
messageBuilder.setSharedMrmKey(ByteString.copyFrom(sharedMrmKey));
}
// mrm views phase 1: always set content
messageBuilder.setContent(ByteString.copyFrom(payload));
messageSender.sendMessage(destinationAccount, destinationDevice, messageBuilder.build(), online); messageSender.sendMessage(destinationAccount, destinationDevice, messageBuilder.build(), online);
} }

View File

@ -31,15 +31,15 @@ public record IncomingMessage(int type, byte destinationDeviceId, int destinatio
final MessageProtos.Envelope.Builder envelopeBuilder = MessageProtos.Envelope.newBuilder(); final MessageProtos.Envelope.Builder envelopeBuilder = MessageProtos.Envelope.newBuilder();
envelopeBuilder.setType(envelopeType) envelopeBuilder.setType(envelopeType)
.setTimestamp(timestamp) .setClientTimestamp(timestamp)
.setServerTimestamp(System.currentTimeMillis()) .setServerTimestamp(System.currentTimeMillis())
.setDestinationUuid(destinationIdentifier.toServiceIdentifierString()) .setDestinationServiceId(destinationIdentifier.toServiceIdentifierString())
.setStory(story) .setStory(story)
.setUrgent(urgent); .setUrgent(urgent);
if (sourceAccount != null && sourceDeviceId != null) { if (sourceAccount != null && sourceDeviceId != null) {
envelopeBuilder envelopeBuilder
.setSourceUuid(new AciServiceIdentifier(sourceAccount.getUuid()).toServiceIdentifierString()) .setSourceServiceId(new AciServiceIdentifier(sourceAccount.getUuid()).toServiceIdentifierString())
.setSourceDevice(sourceDeviceId.intValue()); .setSourceDevice(sourceDeviceId.intValue());
} }

View File

@ -40,15 +40,15 @@ public record OutgoingMessageEntity(UUID guid,
public MessageProtos.Envelope toEnvelope() { public MessageProtos.Envelope toEnvelope() {
final MessageProtos.Envelope.Builder builder = MessageProtos.Envelope.newBuilder() final MessageProtos.Envelope.Builder builder = MessageProtos.Envelope.newBuilder()
.setType(MessageProtos.Envelope.Type.forNumber(type())) .setType(MessageProtos.Envelope.Type.forNumber(type()))
.setTimestamp(timestamp()) .setClientTimestamp(timestamp())
.setServerTimestamp(serverTimestamp()) .setServerTimestamp(serverTimestamp())
.setDestinationUuid(destinationUuid().toServiceIdentifierString()) .setDestinationServiceId(destinationUuid().toServiceIdentifierString())
.setServerGuid(guid().toString()) .setServerGuid(guid().toString())
.setStory(story) .setStory(story)
.setUrgent(urgent); .setUrgent(urgent);
if (sourceUuid() != null) { if (sourceUuid() != null) {
builder.setSourceUuid(sourceUuid().toServiceIdentifierString()); builder.setSourceServiceId(sourceUuid().toServiceIdentifierString());
builder.setSourceDevice(sourceDevice()); builder.setSourceDevice(sourceDevice());
} }
@ -72,10 +72,10 @@ public record OutgoingMessageEntity(UUID guid,
return new OutgoingMessageEntity( return new OutgoingMessageEntity(
UUID.fromString(envelope.getServerGuid()), UUID.fromString(envelope.getServerGuid()),
envelope.getType().getNumber(), envelope.getType().getNumber(),
envelope.getTimestamp(), envelope.getClientTimestamp(),
envelope.hasSourceUuid() ? ServiceIdentifier.valueOf(envelope.getSourceUuid()) : null, envelope.hasSourceServiceId() ? ServiceIdentifier.valueOf(envelope.getSourceServiceId()) : null,
envelope.getSourceDevice(), envelope.getSourceDevice(),
envelope.hasDestinationUuid() ? ServiceIdentifier.valueOf(envelope.getDestinationUuid()) : null, envelope.hasDestinationServiceId() ? ServiceIdentifier.valueOf(envelope.getDestinationServiceId()) : null,
envelope.hasUpdatedPni() ? UUID.fromString(envelope.getUpdatedPni()) : null, envelope.hasUpdatedPni() ? UUID.fromString(envelope.getUpdatedPni()) : null,
envelope.getContent().toByteArray(), envelope.getContent().toByteArray(),
envelope.getServerTimestamp(), envelope.getServerTimestamp(),

View File

@ -50,11 +50,11 @@ public final class MessageMetrics {
public void measureAccountEnvelopeUuidMismatches(final Account account, public void measureAccountEnvelopeUuidMismatches(final Account account,
final MessageProtos.Envelope envelope) { final MessageProtos.Envelope envelope) {
if (envelope.hasDestinationUuid()) { if (envelope.hasDestinationServiceId()) {
try { try {
measureAccountDestinationUuidMismatches(account, ServiceIdentifier.valueOf(envelope.getDestinationUuid())); measureAccountDestinationUuidMismatches(account, ServiceIdentifier.valueOf(envelope.getDestinationServiceId()));
} catch (final IllegalArgumentException ignored) { } catch (final IllegalArgumentException ignored) {
logger.warn("Envelope had invalid destination UUID: {}", envelope.getDestinationUuid()); logger.warn("Envelope had invalid destination UUID: {}", envelope.getDestinationServiceId());
} }
} }
} }

View File

@ -92,7 +92,7 @@ public class MessageSender {
CLIENT_ONLINE_TAG_NAME, String.valueOf(clientPresent), CLIENT_ONLINE_TAG_NAME, String.valueOf(clientPresent),
URGENT_TAG_NAME, String.valueOf(message.getUrgent()), URGENT_TAG_NAME, String.valueOf(message.getUrgent()),
STORY_TAG_NAME, String.valueOf(message.getStory()), STORY_TAG_NAME, String.valueOf(message.getStory()),
SEALED_SENDER_TAG_NAME, String.valueOf(!message.hasSourceUuid())) SEALED_SENDER_TAG_NAME, String.valueOf(!message.hasSourceServiceId()))
.increment(); .increment();
} }
} }

View File

@ -45,10 +45,10 @@ public class ReceiptSender {
destinationAccount -> { destinationAccount -> {
final Envelope.Builder message = Envelope.newBuilder() final Envelope.Builder message = Envelope.newBuilder()
.setServerTimestamp(System.currentTimeMillis()) .setServerTimestamp(System.currentTimeMillis())
.setSourceUuid(sourceIdentifier.toServiceIdentifierString()) .setSourceServiceId(sourceIdentifier.toServiceIdentifierString())
.setSourceDevice((int) sourceDeviceId) .setSourceDevice(sourceDeviceId)
.setDestinationUuid(destinationIdentifier.toServiceIdentifierString()) .setDestinationServiceId(destinationIdentifier.toServiceIdentifierString())
.setTimestamp(messageId) .setClientTimestamp(messageId)
.setType(Envelope.Type.SERVER_DELIVERY_RECEIPT) .setType(Envelope.Type.SERVER_DELIVERY_RECEIPT)
.setUrgent(false); .setUrgent(false);

View File

@ -138,12 +138,13 @@ public class ChangeNumberManager {
final long serverTimestamp = System.currentTimeMillis(); final long serverTimestamp = System.currentTimeMillis();
final Envelope envelope = Envelope.newBuilder() final Envelope envelope = Envelope.newBuilder()
.setType(Envelope.Type.forNumber(message.type())) .setType(Envelope.Type.forNumber(message.type()))
.setTimestamp(serverTimestamp) .setClientTimestamp(serverTimestamp)
.setServerTimestamp(serverTimestamp) .setServerTimestamp(serverTimestamp)
.setDestinationUuid(new AciServiceIdentifier(sourceAndDestinationAccount.getUuid()).toServiceIdentifierString()) .setDestinationServiceId(
new AciServiceIdentifier(sourceAndDestinationAccount.getUuid()).toServiceIdentifierString())
.setContent(ByteString.copyFrom(contents.get())) .setContent(ByteString.copyFrom(contents.get()))
.setSourceUuid(new AciServiceIdentifier(sourceAndDestinationAccount.getUuid()).toServiceIdentifierString()) .setSourceServiceId(new AciServiceIdentifier(sourceAndDestinationAccount.getUuid()).toServiceIdentifierString())
.setSourceDevice((int) Device.PRIMARY_ID) .setSourceDevice(Device.PRIMARY_ID)
.setUpdatedPni(sourceAndDestinationAccount.getPhoneNumberIdentifier().toString()) .setUpdatedPni(sourceAndDestinationAccount.getPhoneNumberIdentifier().toString())
.setUrgent(true) .setUrgent(true)
.build(); .build();

View File

@ -8,10 +8,10 @@ package org.whispersystems.textsecuregcm.storage;
import static com.codahale.metrics.MetricRegistry.name; import static com.codahale.metrics.MetricRegistry.name;
import com.google.common.annotations.VisibleForTesting; import com.google.common.annotations.VisibleForTesting;
import com.google.protobuf.ByteString;
import com.google.protobuf.InvalidProtocolBufferException; import com.google.protobuf.InvalidProtocolBufferException;
import io.dropwizard.lifecycle.Managed; import io.dropwizard.lifecycle.Managed;
import io.lettuce.core.ScoredValue; import io.lettuce.core.ScoredValue;
import io.lettuce.core.ScriptOutputType;
import io.lettuce.core.ZAddArgs; import io.lettuce.core.ZAddArgs;
import io.lettuce.core.cluster.SlotHash; import io.lettuce.core.cluster.SlotHash;
import io.lettuce.core.cluster.models.partitions.RedisClusterNode; import io.lettuce.core.cluster.models.partitions.RedisClusterNode;
@ -20,6 +20,7 @@ import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.Metrics; import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Timer; import io.micrometer.core.instrument.Timer;
import java.io.IOException; import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.time.Clock; import java.time.Clock;
import java.time.Duration; import java.time.Duration;
@ -38,14 +39,17 @@ import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService; import java.util.concurrent.ExecutorService;
import java.util.concurrent.locks.ReentrantLock; import java.util.concurrent.locks.ReentrantLock;
import java.util.function.Predicate; import java.util.function.Predicate;
import java.util.stream.Collectors;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import org.reactivestreams.Publisher; import org.reactivestreams.Publisher;
import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicMessagesConfiguration;
import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.experiment.Experiment;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil; import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
import org.whispersystems.textsecuregcm.redis.ClusterLuaScript;
import org.whispersystems.textsecuregcm.redis.FaultTolerantPubSubConnection; import org.whispersystems.textsecuregcm.redis.FaultTolerantPubSubConnection;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster; import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
import org.whispersystems.textsecuregcm.util.Pair; import org.whispersystems.textsecuregcm.util.Pair;
@ -57,6 +61,62 @@ import reactor.core.publisher.Mono;
import reactor.core.scheduler.Scheduler; import reactor.core.scheduler.Scheduler;
import reactor.core.scheduler.Schedulers; import reactor.core.scheduler.Schedulers;
/**
* Manages short-term storage of messages in Redis. Messages are frequently delivered to their destination and deleted
* shortly after they reach the server, and this cache acts as a low-latency holding area for new messages, reducing
* load on higher-latency, longer-term storage systems. Redis in particular provides keyspace notifications, which act
* as a form of pub-sub notifications to alert listeners when new messages arrive.
* <p>
* The following structures are used:
* <dl>
* <dt>{@code queueKey}</code></dt>
* <dd>A sorted set of messages in a devices queue. A messages score is its queue-local message ID. See
* <a href="https://redis.io/docs/latest/develop/use/patterns/twitter-clone/#the-sorted-set-data-type">Redis.io: The
* Sorted Set data type</a> for background on scores and this data structure.</dd>
* <dt>{@code queueMetadataKey}</dt>
* <dd>A hash containing message guids and their queue-local message ID. It also contains a {@code counter} key, which is
* incremented to supply the next message ID. This is used to remove a message by GUID from {@code queueKey} by its
* local messageId.</dd>
* <dt>{@code sharedMrmKey}</dt>
* <dd>A hash containing a single multi-recipient message pending delivery. It contains:
* <ul>
* <li>{@code data} - the serialized SealedSenderMultiRecipientMessage data</li>
* <li>fields with each recipient device's view into the payload ({@link SealedSenderMultiRecipientMessage#serializedRecipientView(SealedSenderMultiRecipientMessage.Recipient)}</li>
* </ul>
* Note: this is shared among all of the message's recipients, and it may be located in any Redis shard. As each recipients
* message is delivered, its corresponding view is idempotently removed. When {@code data} is the only remaining
* field, the hash will be deleted.
* </dd>
* <dt>{@code queueLockKey}</dt>
* <dd>Used to indicate that a queue is being modified by the {@link MessagePersister} and that {@code get_items} should
* return an empty list.</dd>
* <dt>{@code queueTotalIndexKey}</dt>
* <dd>A sorted set of all queues in a shard. A queues score is the timestamp of its oldest message, which is used by
* the {@link MessagePersister} to prioritize queues to persist.</dd>
* </dl>
* <p>
* At a high level, the process is:
* <ol>
* <li>Insert: the queue metadata is queried for the next incremented message ID. The message data is inserted into
* the queue at that ID, and the message GUID is inserted in the queue metadata.</li>
* <li>Get: a batch of messages are retrieved from the queue, potentially with an after-message-ID offset.</li>
* <li>Remove: a set of messages are remove by GUID. For each GUID, the message ID is retrieved from the queue metadata,
* and then that single-value range is removed from the queue.</li>
* </ol>
* For multi-recipient messages (sometimes abbreviated MRM), there are similar operations on the common data during
* insert, get, and remove. MRM inserts must occur before individual queue inserts, while removal is considered
* best-effort, and uses key expiration as back-stop garbage collection.
* <p>
* For atomicity, many operations are implemented as Lua scripts that are executed on the Redis server using
* {@code EVAL}/{@code EVALSHA}.
*
* @see MessagesCacheInsertScript
* @see MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript
* @see MessagesCacheGetItemsScript
* @see MessagesCacheRemoveByGuidScript
* @see MessagesCacheRemoveRecipientViewFromMrmDataScript
* @see MessagesCacheRemoveQueueScript
*/
public class MessagesCache extends RedisClusterPubSubAdapter<String, String> implements Managed { public class MessagesCache extends RedisClusterPubSubAdapter<String, String> implements Managed {
private final FaultTolerantRedisCluster redisCluster; private final FaultTolerantRedisCluster redisCluster;
@ -69,17 +129,22 @@ public class MessagesCache extends RedisClusterPubSubAdapter<String, String> imp
// messageDeletionExecutorService wrapped into a reactor Scheduler // messageDeletionExecutorService wrapped into a reactor Scheduler
private final Scheduler messageDeletionScheduler; private final Scheduler messageDeletionScheduler;
private final ClusterLuaScript insertScript; private final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager;
private final ClusterLuaScript removeByGuidScript;
private final ClusterLuaScript getItemsScript; private final MessagesCacheInsertScript insertScript;
private final ClusterLuaScript removeQueueScript; private final MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript insertMrmScript;
private final ClusterLuaScript getQueuesToPersistScript; private final MessagesCacheRemoveByGuidScript removeByGuidScript;
private final MessagesCacheGetItemsScript getItemsScript;
private final MessagesCacheRemoveQueueScript removeQueueScript;
private final MessagesCacheGetQueuesToPersistScript getQueuesToPersistScript;
private final MessagesCacheRemoveRecipientViewFromMrmDataScript removeRecipientViewFromMrmDataScript;
private final ReentrantLock messageListenersLock = new ReentrantLock(); private final ReentrantLock messageListenersLock = new ReentrantLock();
private final Map<String, MessageAvailabilityListener> messageListenersByQueueName = new HashMap<>(); private final Map<String, MessageAvailabilityListener> messageListenersByQueueName = new HashMap<>();
private final Map<MessageAvailabilityListener, String> queueNamesByMessageListener = new IdentityHashMap<>(); private final Map<MessageAvailabilityListener, String> queueNamesByMessageListener = new IdentityHashMap<>();
private final Timer insertTimer = Metrics.timer(name(MessagesCache.class, "insert")); private final Timer insertTimer = Metrics.timer(name(MessagesCache.class, "insert"));
private final Timer insertSharedMrmPayloadTimer = Metrics.timer(name(MessagesCache.class, "insertSharedMrmPayload"));
private final Timer getMessagesTimer = Metrics.timer(name(MessagesCache.class, "get")); private final Timer getMessagesTimer = Metrics.timer(name(MessagesCache.class, "get"));
private final Timer getQueuesToPersistTimer = Metrics.timer(name(MessagesCache.class, "getQueuesToPersist")); private final Timer getQueuesToPersistTimer = Metrics.timer(name(MessagesCache.class, "getQueuesToPersist"));
private final Timer removeByGuidTimer = Metrics.timer(name(MessagesCache.class, "removeByGuid")); private final Timer removeByGuidTimer = Metrics.timer(name(MessagesCache.class, "removeByGuid"));
@ -95,6 +160,9 @@ public class MessagesCache extends RedisClusterPubSubAdapter<String, String> imp
name(MessagesCache.class, "messageAvailabilityListenerRemovedAfterAdd")); name(MessagesCache.class, "messageAvailabilityListenerRemovedAfterAdd"));
private final Counter prunedStaleSubscriptionCounter = Metrics.counter( private final Counter prunedStaleSubscriptionCounter = Metrics.counter(
name(MessagesCache.class, "prunedStaleSubscription")); name(MessagesCache.class, "prunedStaleSubscription"));
private final Counter mrmContentRetrievedCounter = Metrics.counter(name(MessagesCache.class, "mrmViewRetrieved"));
private final Counter sharedMrmDataKeyRemovedCounter = Metrics.counter(
name(MessagesCache.class, "sharedMrmKeyRemoved"));
static final String NEXT_SLOT_TO_PERSIST_KEY = "user_queue_persist_slot"; static final String NEXT_SLOT_TO_PERSIST_KEY = "user_queue_persist_slot";
private static final byte[] LOCK_VALUE = "1".getBytes(StandardCharsets.UTF_8); private static final byte[] LOCK_VALUE = "1".getBytes(StandardCharsets.UTF_8);
@ -102,16 +170,49 @@ public class MessagesCache extends RedisClusterPubSubAdapter<String, String> imp
private static final String QUEUE_KEYSPACE_PREFIX = "__keyspace@0__:user_queue::"; private static final String QUEUE_KEYSPACE_PREFIX = "__keyspace@0__:user_queue::";
private static final String PERSISTING_KEYSPACE_PREFIX = "__keyspace@0__:user_queue_persisting::"; private static final String PERSISTING_KEYSPACE_PREFIX = "__keyspace@0__:user_queue_persisting::";
private static final String MRM_VIEWS_EXPERIMENT_NAME = "mrmViews";
@VisibleForTesting @VisibleForTesting
static final Duration MAX_EPHEMERAL_MESSAGE_DELAY = Duration.ofSeconds(10); static final Duration MAX_EPHEMERAL_MESSAGE_DELAY = Duration.ofSeconds(10);
private static final String GET_FLUX_NAME = MetricsUtil.name(MessagesCache.class, "get"); private static final String GET_FLUX_NAME = MetricsUtil.name(MessagesCache.class, "get");
private static final int PAGE_SIZE = 100; private static final int PAGE_SIZE = 100;
private static final int REMOVE_MRM_RECIPIENT_VIEW_CONCURRENCY = 8;
private static final Logger logger = LoggerFactory.getLogger(MessagesCache.class); private static final Logger logger = LoggerFactory.getLogger(MessagesCache.class);
public MessagesCache(final FaultTolerantRedisCluster redisCluster, final ExecutorService notificationExecutorService, public MessagesCache(final FaultTolerantRedisCluster redisCluster, final ExecutorService notificationExecutorService,
final Scheduler messageDeliveryScheduler, final ExecutorService messageDeletionExecutorService, final Clock clock) final Scheduler messageDeliveryScheduler, final ExecutorService messageDeletionExecutorService, final Clock clock,
final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager)
throws IOException {
this(
redisCluster,
notificationExecutorService,
messageDeliveryScheduler,
messageDeletionExecutorService,
clock,
dynamicConfigurationManager,
new MessagesCacheInsertScript(redisCluster),
new MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript(redisCluster),
new MessagesCacheGetItemsScript(redisCluster),
new MessagesCacheRemoveByGuidScript(redisCluster),
new MessagesCacheRemoveQueueScript(redisCluster),
new MessagesCacheGetQueuesToPersistScript(redisCluster),
new MessagesCacheRemoveRecipientViewFromMrmDataScript(redisCluster)
);
}
@VisibleForTesting
MessagesCache(final FaultTolerantRedisCluster redisCluster, final ExecutorService notificationExecutorService,
final Scheduler messageDeliveryScheduler, final ExecutorService messageDeletionExecutorService, final Clock clock,
final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager,
final MessagesCacheInsertScript insertScript,
final MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript insertMrmScript,
final MessagesCacheGetItemsScript getItemsScript, final MessagesCacheRemoveByGuidScript removeByGuidScript,
final MessagesCacheRemoveQueueScript removeQueueScript,
final MessagesCacheGetQueuesToPersistScript getQueuesToPersistScript,
final MessagesCacheRemoveRecipientViewFromMrmDataScript removeRecipientViewFromMrmDataScript)
throws IOException { throws IOException {
this.redisCluster = redisCluster; this.redisCluster = redisCluster;
@ -123,14 +224,15 @@ public class MessagesCache extends RedisClusterPubSubAdapter<String, String> imp
this.messageDeletionExecutorService = messageDeletionExecutorService; this.messageDeletionExecutorService = messageDeletionExecutorService;
this.messageDeletionScheduler = Schedulers.fromExecutorService(messageDeletionExecutorService, "messageDeletion"); this.messageDeletionScheduler = Schedulers.fromExecutorService(messageDeletionExecutorService, "messageDeletion");
this.insertScript = ClusterLuaScript.fromResource(redisCluster, "lua/insert_item.lua", ScriptOutputType.INTEGER); this.dynamicConfigurationManager = dynamicConfigurationManager;
this.removeByGuidScript = ClusterLuaScript.fromResource(redisCluster, "lua/remove_item_by_guid.lua",
ScriptOutputType.MULTI); this.insertScript = insertScript;
this.getItemsScript = ClusterLuaScript.fromResource(redisCluster, "lua/get_items.lua", ScriptOutputType.MULTI); this.insertMrmScript = insertMrmScript;
this.removeQueueScript = ClusterLuaScript.fromResource(redisCluster, "lua/remove_queue.lua", this.removeByGuidScript = removeByGuidScript;
ScriptOutputType.STATUS); this.getItemsScript = getItemsScript;
this.getQueuesToPersistScript = ClusterLuaScript.fromResource(redisCluster, "lua/get_queues_to_persist.lua", this.removeQueueScript = removeQueueScript;
ScriptOutputType.MULTI); this.getQueuesToPersistScript = getQueuesToPersistScript;
this.removeRecipientViewFromMrmDataScript = removeRecipientViewFromMrmDataScript;
} }
@Override @Override
@ -164,51 +266,51 @@ public class MessagesCache extends RedisClusterPubSubAdapter<String, String> imp
public long insert(final UUID guid, final UUID destinationUuid, final byte destinationDevice, public long insert(final UUID guid, final UUID destinationUuid, final byte destinationDevice,
final MessageProtos.Envelope message) { final MessageProtos.Envelope message) {
final MessageProtos.Envelope messageWithGuid = message.toBuilder().setServerGuid(guid.toString()).build(); final MessageProtos.Envelope messageWithGuid = message.toBuilder().setServerGuid(guid.toString()).build();
return (long) insertTimer.record(() -> return insertTimer.record(() -> insertScript.execute(destinationUuid, destinationDevice, messageWithGuid));
insertScript.executeBinary(List.of(getMessageQueueKey(destinationUuid, destinationDevice),
getMessageQueueMetadataKey(destinationUuid, destinationDevice),
getQueueIndexKey(destinationUuid, destinationDevice)),
List.of(messageWithGuid.toByteArray(),
String.valueOf(message.getServerTimestamp()).getBytes(StandardCharsets.UTF_8),
guid.toString().getBytes(StandardCharsets.UTF_8))));
} }
public CompletableFuture<Optional<MessageProtos.Envelope>> remove(final UUID destinationUuid, public byte[] insertSharedMultiRecipientMessagePayload(UUID mrmGuid,
SealedSenderMultiRecipientMessage sealedSenderMultiRecipientMessage) {
final byte[] sharedMrmKey = getSharedMrmKey(mrmGuid);
insertSharedMrmPayloadTimer.record(() -> insertMrmScript.execute(sharedMrmKey, sealedSenderMultiRecipientMessage));
return sharedMrmKey;
}
public CompletableFuture<Optional<RemovedMessage>> remove(final UUID destinationUuid,
final byte destinationDevice, final byte destinationDevice,
final UUID messageGuid) { final UUID messageGuid) {
return remove(destinationUuid, destinationDevice, List.of(messageGuid)) return remove(destinationUuid, destinationDevice, List.of(messageGuid))
.thenApply(removed -> removed.isEmpty() ? Optional.empty() : Optional.of(removed.get(0))); .thenApply(removed -> removed.isEmpty() ? Optional.empty() : Optional.of(removed.getFirst()));
} }
@SuppressWarnings("unchecked") public CompletableFuture<List<RemovedMessage>> remove(final UUID destinationUuid,
public CompletableFuture<List<MessageProtos.Envelope>> remove(final UUID destinationUuid, final byte destinationDevice, final List<UUID> messageGuids) {
final byte destinationDevice,
final List<UUID> messageGuids) {
final Timer.Sample sample = Timer.start(); final Timer.Sample sample = Timer.start();
return removeByGuidScript.executeBinaryAsync(List.of(getMessageQueueKey(destinationUuid, destinationDevice), return removeByGuidScript.execute(destinationUuid, destinationDevice, messageGuids)
getMessageQueueMetadataKey(destinationUuid, destinationDevice), .thenApplyAsync(serialized -> {
getQueueIndexKey(destinationUuid, destinationDevice)),
messageGuids.stream().map(guid -> guid.toString().getBytes(StandardCharsets.UTF_8))
.collect(Collectors.toList()))
.thenApplyAsync(result -> {
List<byte[]> serialized = (List<byte[]>) result;
final List<MessageProtos.Envelope> removedMessages = new ArrayList<>(serialized.size()); final List<RemovedMessage> removedMessages = new ArrayList<>(serialized.size());
final List<byte[]> sharedMrmKeysToUpdate = new ArrayList<>();
for (final byte[] bytes : serialized) { for (final byte[] bytes : serialized) {
try { try {
removedMessages.add(MessageProtos.Envelope.parseFrom(bytes)); final MessageProtos.Envelope envelope = MessageProtos.Envelope.parseFrom(bytes);
removedMessages.add(RemovedMessage.fromEnvelope(envelope));
if (envelope.hasSharedMrmKey()) {
sharedMrmKeysToUpdate.add(envelope.getSharedMrmKey().toByteArray());
}
} catch (final InvalidProtocolBufferException e) { } catch (final InvalidProtocolBufferException e) {
logger.warn("Failed to parse envelope", e); logger.warn("Failed to parse envelope", e);
} }
} }
removeRecipientViewFromMrmData(sharedMrmKeysToUpdate, destinationUuid, destinationDevice);
return removedMessages; return removedMessages;
}, messageDeletionExecutorService) }, messageDeletionExecutorService).whenComplete((ignored, throwable) -> sample.stop(removeByGuidTimer));
.whenComplete((ignored, throwable) -> sample.stop(removeByGuidTimer));
} }
public boolean hasMessages(final UUID destinationUuid, final byte destinationDevice) { public boolean hasMessages(final UUID destinationUuid, final byte destinationDevice) {
@ -251,7 +353,7 @@ public class MessagesCache extends RedisClusterPubSubAdapter<String, String> imp
private static boolean isStaleEphemeralMessage(final MessageProtos.Envelope message, private static boolean isStaleEphemeralMessage(final MessageProtos.Envelope message,
long earliestAllowableTimestamp) { long earliestAllowableTimestamp) {
return message.hasEphemeral() && message.getEphemeral() && message.getTimestamp() < earliestAllowableTimestamp; return message.getEphemeral() && message.getClientTimestamp() < earliestAllowableTimestamp;
} }
private void discardStaleEphemeralMessages(final UUID destinationUuid, final byte destinationDevice, private void discardStaleEphemeralMessages(final UUID destinationUuid, final byte destinationDevice,
@ -283,37 +385,101 @@ public class MessagesCache extends RedisClusterPubSubAdapter<String, String> imp
// we want to ensure we dont accidentally block the Lettuce/netty i/o executors // we want to ensure we dont accidentally block the Lettuce/netty i/o executors
.publishOn(messageDeliveryScheduler) .publishOn(messageDeliveryScheduler)
.map(Pair::first) .map(Pair::first)
.flatMapIterable(queueItems -> { .concatMap(queueItems -> {
final List<MessageProtos.Envelope> envelopes = new ArrayList<>(queueItems.size() / 2);
final List<Mono<MessageProtos.Envelope>> envelopes = new ArrayList<>(queueItems.size() / 2);
for (int i = 0; i < queueItems.size() - 1; i += 2) { for (int i = 0; i < queueItems.size() - 1; i += 2) {
try { try {
final MessageProtos.Envelope message = MessageProtos.Envelope.parseFrom(queueItems.get(i)); final MessageProtos.Envelope message = MessageProtos.Envelope.parseFrom(queueItems.get(i));
envelopes.add(message); final Mono<MessageProtos.Envelope> messageMono;
if (message.hasSharedMrmKey()) {
maybeRunMrmViewExperiment(message, destinationUuid, destinationDevice);
// mrm views phase 1: messageMono for sharedMrmKey is always Mono.just(), because messages always have content
messageMono = Mono.just(message.toBuilder().clearSharedMrmKey().build());
} else {
messageMono = Mono.just(message);
}
envelopes.add(messageMono);
} catch (InvalidProtocolBufferException e) { } catch (InvalidProtocolBufferException e) {
logger.warn("Failed to parse envelope", e); logger.warn("Failed to parse envelope", e);
} }
} }
return envelopes; return Flux.mergeSequential(envelopes);
}); });
} }
private Flux<Pair<List<byte[]>, Long>> getNextMessagePage(final UUID destinationUuid, final byte destinationDevice, /**
* Runs the fetch and compare logic for the MRM view experiment, if it is enabled.
*
* @see DynamicMessagesConfiguration#mrmViewExperimentEnabled()
*/
private void maybeRunMrmViewExperiment(final MessageProtos.Envelope mrmMessage, final UUID destinationUuid,
final byte destinationDevice) {
if (dynamicConfigurationManager.getConfiguration().getMessagesConfiguration()
.mrmViewExperimentEnabled()) {
final Experiment experiment = new Experiment(MRM_VIEWS_EXPERIMENT_NAME);
final byte[] key = mrmMessage.getSharedMrmKey().toByteArray();
final byte[] sharedMrmViewKey = MessagesCache.getSharedMrmViewKey(
new AciServiceIdentifier(destinationUuid), destinationDevice);
final Mono<MessageProtos.Envelope> mrmMessageMono = Mono.from(redisCluster.withBinaryClusterReactive(
conn -> conn.reactive().hmget(key, "data".getBytes(StandardCharsets.UTF_8), sharedMrmViewKey)
.collectList()
.publishOn(messageDeliveryScheduler)
.handle((mrmDataAndView, sink) -> {
try {
assert mrmDataAndView.size() == 2;
final byte[] content = SealedSenderMultiRecipientMessage.messageForRecipient(
mrmDataAndView.getFirst().getValue(),
mrmDataAndView.getLast().getValue());
sink.next(mrmMessage.toBuilder()
.clearSharedMrmKey()
.setContent(ByteString.copyFrom(content))
.build());
mrmContentRetrievedCounter.increment();
} catch (Exception e) {
sink.error(e);
}
})));
experiment.compareFutureResult(mrmMessage.toBuilder().clearSharedMrmKey().build(),
mrmMessageMono.toFuture());
}
}
/**
* Makes a best-effort attempt at asynchronously updating (and removing when empty) the MRM data structure
*/
private void removeRecipientViewFromMrmData(final List<byte[]> sharedMrmKeys, final UUID accountUuid,
final byte deviceId) {
Flux.fromIterable(sharedMrmKeys)
.collectMultimap(SlotHash::getSlot)
.flatMapMany(slotsAndKeys -> Flux.fromIterable(slotsAndKeys.values()))
.flatMap(
keys -> removeRecipientViewFromMrmDataScript.execute(keys, new AciServiceIdentifier(accountUuid), deviceId),
REMOVE_MRM_RECIPIENT_VIEW_CONCURRENCY)
.subscribe(sharedMrmDataKeyRemovedCounter::increment, e -> logger.warn("Error removing recipient view", e));
}
private Mono<Pair<List<byte[]>, Long>> getNextMessagePage(final UUID destinationUuid,
final byte destinationDevice,
long messageId) { long messageId) {
return getItemsScript.executeBinaryReactive( return getItemsScript.execute(destinationUuid, destinationDevice, PAGE_SIZE, messageId)
List.of(getMessageQueueKey(destinationUuid, destinationDevice), .map(queueItems -> {
getPersistInProgressKey(destinationUuid, destinationDevice)),
List.of(String.valueOf(PAGE_SIZE).getBytes(StandardCharsets.UTF_8),
String.valueOf(messageId).getBytes(StandardCharsets.UTF_8)))
.map(result -> {
logger.trace("Processing page: {}", messageId); logger.trace("Processing page: {}", messageId);
@SuppressWarnings("unchecked")
List<byte[]> queueItems = (List<byte[]>) result;
if (queueItems.isEmpty()) { if (queueItems.isEmpty()) {
return new Pair<>(Collections.emptyList(), null); return new Pair<>(Collections.emptyList(), null);
} }
@ -324,7 +490,7 @@ public class MessagesCache extends RedisClusterPubSubAdapter<String, String> imp
} }
final long lastMessageId = Long.parseLong( final long lastMessageId = Long.parseLong(
new String(queueItems.get(queueItems.size() - 1), StandardCharsets.UTF_8)); new String(queueItems.getLast(), StandardCharsets.UTF_8));
return new Pair<>(queueItems, lastMessageId); return new Pair<>(queueItems, lastMessageId);
}); });
@ -362,10 +528,35 @@ public class MessagesCache extends RedisClusterPubSubAdapter<String, String> imp
public CompletableFuture<Void> clear(final UUID destinationUuid, final byte deviceId) { public CompletableFuture<Void> clear(final UUID destinationUuid, final byte deviceId) {
final Timer.Sample sample = Timer.start(); final Timer.Sample sample = Timer.start();
return removeQueueScript.executeBinaryAsync(List.of(getMessageQueueKey(destinationUuid, deviceId), return removeQueueScript.execute(destinationUuid, deviceId, Collections.emptyList())
getMessageQueueMetadataKey(destinationUuid, deviceId), .publishOn(messageDeletionScheduler)
getQueueIndexKey(destinationUuid, deviceId)), .expand(messagesToProcess -> {
Collections.emptyList()) if (messagesToProcess.isEmpty()) {
return Mono.empty();
}
final List<byte[]> mrmKeys = new ArrayList<>(messagesToProcess.size());
final List<String> processedMessages = new ArrayList<>(messagesToProcess.size());
for (byte[] serialized : messagesToProcess) {
try {
final MessageProtos.Envelope message = MessageProtos.Envelope.parseFrom(serialized);
processedMessages.add(message.getServerGuid());
if (message.hasSharedMrmKey()) {
mrmKeys.add(message.getSharedMrmKey().toByteArray());
}
} catch (final InvalidProtocolBufferException e) {
logger.warn("Failed to parse envelope", e);
}
}
removeRecipientViewFromMrmData(mrmKeys, destinationUuid, deviceId);
return removeQueueScript.execute(destinationUuid, deviceId, processedMessages);
})
.then()
.toFuture()
.thenRun(() -> sample.stop(clearQueueTimer)); .thenRun(() -> sample.stop(clearQueueTimer));
} }
@ -375,11 +566,7 @@ public class MessagesCache extends RedisClusterPubSubAdapter<String, String> imp
} }
List<String> getQueuesToPersist(final int slot, final Instant maxTime, final int limit) { List<String> getQueuesToPersist(final int slot, final Instant maxTime, final int limit) {
//noinspection unchecked return getQueuesToPersistTimer.record(() -> getQueuesToPersistScript.execute(slot, maxTime, limit));
return getQueuesToPersistTimer.record(() -> (List<String>) getQueuesToPersistScript.execute(
List.of(new String(getQueueIndexKey(slot), StandardCharsets.UTF_8)),
List.of(String.valueOf(maxTime.toEpochMilli()),
String.valueOf(limit))));
} }
void addQueueToPersist(final UUID accountUuid, final byte deviceId) { void addQueueToPersist(final UUID accountUuid, final byte deviceId) {
@ -538,29 +725,36 @@ public class MessagesCache extends RedisClusterPubSubAdapter<String, String> imp
return channel.substring(startOfHashTag + 1, endOfHashTag); return channel.substring(startOfHashTag + 1, endOfHashTag);
} }
@VisibleForTesting
static byte[] getMessageQueueKey(final UUID accountUuid, final byte deviceId) { static byte[] getMessageQueueKey(final UUID accountUuid, final byte deviceId) {
return ("user_queue::{" + accountUuid.toString() + "::" + deviceId + "}").getBytes(StandardCharsets.UTF_8); return ("user_queue::{" + accountUuid.toString() + "::" + deviceId + "}").getBytes(StandardCharsets.UTF_8);
} }
private static byte[] getMessageQueueMetadataKey(final UUID accountUuid, final byte deviceId) { static byte[] getMessageQueueMetadataKey(final UUID accountUuid, final byte deviceId) {
return ("user_queue_metadata::{" + accountUuid.toString() + "::" + deviceId + "}").getBytes(StandardCharsets.UTF_8); return ("user_queue_metadata::{" + accountUuid.toString() + "::" + deviceId + "}").getBytes(StandardCharsets.UTF_8);
} }
private static byte[] getQueueIndexKey(final UUID accountUuid, final byte deviceId) { static byte[] getQueueIndexKey(final UUID accountUuid, final byte deviceId) {
return getQueueIndexKey(SlotHash.getSlot(accountUuid.toString() + "::" + deviceId)); return getQueueIndexKey(SlotHash.getSlot(accountUuid.toString() + "::" + deviceId));
} }
private static byte[] getQueueIndexKey(final int slot) { static byte[] getQueueIndexKey(final int slot) {
return ("user_queue_index::{" + RedisClusterUtil.getMinimalHashTag(slot) + "}").getBytes(StandardCharsets.UTF_8); return ("user_queue_index::{" + RedisClusterUtil.getMinimalHashTag(slot) + "}").getBytes(StandardCharsets.UTF_8);
} }
private static byte[] getPersistInProgressKey(final UUID accountUuid, final byte deviceId) { static byte[] getSharedMrmKey(final UUID mrmGuid) {
return ("mrm::{" + mrmGuid.toString() + "}").getBytes(StandardCharsets.UTF_8);
}
static byte[] getPersistInProgressKey(final UUID accountUuid, final byte deviceId) {
return ("user_queue_persisting::{" + accountUuid + "::" + deviceId + "}").getBytes(StandardCharsets.UTF_8); return ("user_queue_persisting::{" + accountUuid + "::" + deviceId + "}").getBytes(StandardCharsets.UTF_8);
} }
private static byte[] getUnlinkInProgressKey(final UUID accountUuid) { static byte[] getSharedMrmViewKey(final AciServiceIdentifier serviceIdentifier, final byte deviceId) {
return ("user_account_unlinking::{" + accountUuid + "}").getBytes(StandardCharsets.UTF_8); final ByteBuffer keyBb = ByteBuffer.allocate(18);
keyBb.put(serviceIdentifier.toFixedWidthByteArray());
keyBb.put(deviceId);
assert !keyBb.hasRemaining();
return keyBb.array();
} }
static UUID getAccountUuidFromQueueName(final String queueName) { static UUID getAccountUuidFromQueueName(final String queueName) {

View File

@ -0,0 +1,45 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import io.lettuce.core.ScriptOutputType;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.UUID;
import org.whispersystems.textsecuregcm.redis.ClusterLuaScript;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
import reactor.core.publisher.Mono;
/**
* Retrieves a list of messages and their corresponding queue-local IDs for the device. To support streaming processing,
* the last queue-local message ID from a previous call may be used as the {@code afterMessageId}.
*/
class MessagesCacheGetItemsScript {
private final ClusterLuaScript getItemsScript;
MessagesCacheGetItemsScript(FaultTolerantRedisCluster redisCluster) throws IOException {
this.getItemsScript = ClusterLuaScript.fromResource(redisCluster, "lua/get_items.lua", ScriptOutputType.OBJECT);
}
Mono<List<byte[]>> execute(final UUID destinationUuid, final byte destinationDevice,
int limit, long afterMessageId) {
final List<byte[]> keys = List.of(
MessagesCache.getMessageQueueKey(destinationUuid, destinationDevice), // queueKey
MessagesCache.getPersistInProgressKey(destinationUuid, destinationDevice) // queueLockKey
);
final List<byte[]> args = List.of(
String.valueOf(limit).getBytes(StandardCharsets.UTF_8), // limit
String.valueOf(afterMessageId).getBytes(StandardCharsets.UTF_8) // afterMessageId
);
//noinspection unchecked
return getItemsScript.executeBinaryReactive(keys, args)
.map(result -> (List<byte[]>) result)
.next();
}
}

View File

@ -0,0 +1,43 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import io.lettuce.core.ScriptOutputType;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.time.Instant;
import java.util.List;
import org.whispersystems.textsecuregcm.redis.ClusterLuaScript;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
/**
* Returns a list of queues that may be persisted. They will be sorted from oldest to more recent, limited by the
* {@code maxTime} argument.
*
* @see MessagePersister
*/
class MessagesCacheGetQueuesToPersistScript {
private final ClusterLuaScript getQueuesToPersistScript;
MessagesCacheGetQueuesToPersistScript(final FaultTolerantRedisCluster redisCluster) throws IOException {
this.getQueuesToPersistScript = ClusterLuaScript.fromResource(redisCluster, "lua/get_queues_to_persist.lua",
ScriptOutputType.MULTI);
}
List<String> execute(final int slot, final Instant maxTime, final int limit) {
final List<String> keys = List.of(
new String(MessagesCache.getQueueIndexKey(slot), StandardCharsets.UTF_8) // queueTotalIndexKey
);
final List<String> args = List.of(
String.valueOf(maxTime.toEpochMilli()), // maxTime
String.valueOf(limit) // limit
);
//noinspection unchecked
return (List<String>) getQueuesToPersistScript.execute(keys, args);
}
}

View File

@ -0,0 +1,48 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import io.lettuce.core.ScriptOutputType;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.UUID;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.redis.ClusterLuaScript;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
/**
* Inserts an envelope into the message queue for a destination device.
*/
class MessagesCacheInsertScript {
private final ClusterLuaScript insertScript;
MessagesCacheInsertScript(FaultTolerantRedisCluster redisCluster) throws IOException {
this.insertScript = ClusterLuaScript.fromResource(redisCluster, "lua/insert_item.lua", ScriptOutputType.INTEGER);
}
long execute(final UUID destinationUuid, final byte destinationDevice, final MessageProtos.Envelope envelope) {
assert envelope.hasServerGuid();
assert envelope.hasServerTimestamp();
final List<byte[]> keys = List.of(
MessagesCache.getMessageQueueKey(destinationUuid, destinationDevice), // queueKey
MessagesCache.getMessageQueueMetadataKey(destinationUuid, destinationDevice), // queueMetadataKey
MessagesCache.getQueueIndexKey(destinationUuid, destinationDevice) // queueTotalIndexKey
);
final List<byte[]> args = new ArrayList<>(Arrays.asList(
envelope.toByteArray(), // message
String.valueOf(envelope.getServerTimestamp()).getBytes(StandardCharsets.UTF_8), // currentTime
envelope.getServerGuid().getBytes(StandardCharsets.UTF_8) // guid
));
return (long) insertScript.executeBinary(keys, args);
}
}

View File

@ -0,0 +1,53 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import io.lettuce.core.ScriptOutputType;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage;
import org.whispersystems.textsecuregcm.redis.ClusterLuaScript;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
/**
* Inserts the shared multi-recipient message payload into the cache. The list of recipients and views will be set as
* fields in the hash.
*
* @see SealedSenderMultiRecipientMessage#serializedRecipientView(SealedSenderMultiRecipientMessage.Recipient)
*/
class MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript {
private final ClusterLuaScript script;
MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript(FaultTolerantRedisCluster redisCluster)
throws IOException {
this.script = ClusterLuaScript.fromResource(redisCluster, "lua/insert_shared_multirecipient_message_data.lua",
ScriptOutputType.INTEGER);
}
void execute(final byte[] sharedMrmKey, final SealedSenderMultiRecipientMessage message) {
final List<byte[]> keys = List.of(
sharedMrmKey // sharedMrmKey
);
// Pre-allocate capacity for the most fields we expect -- 6 devices per recipient, plus the data field.
final List<byte[]> args = new ArrayList<>(message.getRecipients().size() * 6 + 1);
args.add(message.serialized());
message.getRecipients().forEach((serviceId, recipient) -> {
for (byte device : recipient.getDevices()) {
final byte[] key = new byte[18];
System.arraycopy(serviceId.toServiceIdFixedWidthBinary(), 0, key, 0, 17);
key[17] = device;
args.add(key);
args.add(message.serializedRecipientView(recipient));
}
});
script.executeBinary(keys, args);
}
}

View File

@ -0,0 +1,45 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import io.lettuce.core.ScriptOutputType;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import org.whispersystems.textsecuregcm.redis.ClusterLuaScript;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
/**
* Removes a list of message GUIDs from the queue of a destination device.
*/
class MessagesCacheRemoveByGuidScript {
private final ClusterLuaScript removeByGuidScript;
MessagesCacheRemoveByGuidScript(final FaultTolerantRedisCluster redisCluster) throws IOException {
this.removeByGuidScript = ClusterLuaScript.fromResource(redisCluster, "lua/remove_item_by_guid.lua",
ScriptOutputType.OBJECT);
}
CompletableFuture<List<byte[]>> execute(final UUID destinationUuid, final byte destinationDevice,
final List<UUID> messageGuids) {
final List<byte[]> keys = List.of(
MessagesCache.getMessageQueueKey(destinationUuid, destinationDevice), // queueKey
MessagesCache.getMessageQueueMetadataKey(destinationUuid, destinationDevice), // queueMetadataKey
MessagesCache.getQueueIndexKey(destinationUuid, destinationDevice) // queueTotalIndexKey
);
final List<byte[]> args = messageGuids.stream().map(guid -> guid.toString().getBytes(StandardCharsets.UTF_8))
.toList();
//noinspection unchecked
return removeByGuidScript.executeBinaryAsync(keys, args)
.thenApply(result -> (List<byte[]>) result);
}
}

View File

@ -0,0 +1,58 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import io.lettuce.core.ScriptOutputType;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import java.util.UUID;
import org.whispersystems.textsecuregcm.redis.ClusterLuaScript;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
import reactor.core.publisher.Mono;
/**
* Removes a device's queue from the cache. For a non-empty queue, this script must be executed multiple times.
* <ol>
* <li>The first call will return a list of messages to check for {@code sharedMrmKeys}. If a {@code sharedMrmKey} is present, {@link MessagesCacheRemoveRecipientViewFromMrmDataScript} must be called.</li>
* <li>Once theses messages have been processed, this script should be called again, confirming that the messages have been processed.</li>
* <li>This should be repeated until the script returns an empty list, as the script only returns a page ({@value PAGE_SIZE}) of messages at a time.</li>
* </ol>
*/
class MessagesCacheRemoveQueueScript {
private static final int PAGE_SIZE = 100;
private final ClusterLuaScript removeQueueScript;
MessagesCacheRemoveQueueScript(FaultTolerantRedisCluster redisCluster) throws IOException {
this.removeQueueScript = ClusterLuaScript.fromResource(redisCluster, "lua/remove_queue.lua",
ScriptOutputType.MULTI);
}
Mono<List<byte[]>> execute(final UUID destinationUuid, final byte destinationDevice,
final List<String> processedMessageGuids) {
final List<byte[]> keys = List.of(
MessagesCache.getMessageQueueKey(destinationUuid, destinationDevice), // queueKey
MessagesCache.getMessageQueueMetadataKey(destinationUuid, destinationDevice), // queueMetadataKey
MessagesCache.getQueueIndexKey(destinationUuid, destinationDevice) // queueTotalIndexKey
);
final List<byte[]> args = new ArrayList<>();
args.addFirst(String.valueOf(PAGE_SIZE).getBytes(StandardCharsets.UTF_8)); // limit
args.addAll(processedMessageGuids.stream().map(guid -> guid.getBytes(StandardCharsets.UTF_8))
.toList()); // processedMessageGuids
//noinspection unchecked
return removeQueueScript.executeBinaryReactive(keys, args)
.map(result -> (List<byte[]>) result)
.next();
}
}

View File

@ -0,0 +1,44 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import io.lettuce.core.ScriptOutputType;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.redis.ClusterLuaScript;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
import reactor.core.publisher.Mono;
/**
* Removes the given destination device from the given {@code sharedMrmKeys}. If there are no devices remaining in the
* hash as a result, the shared payload is deleted.
* <p>
* NOTE: Callers are responsible for ensuring that all keys are in the same slot.
*/
class MessagesCacheRemoveRecipientViewFromMrmDataScript {
private final ClusterLuaScript removeRecipientViewFromMrmDataScript;
MessagesCacheRemoveRecipientViewFromMrmDataScript(final FaultTolerantRedisCluster redisCluster) throws IOException {
this.removeRecipientViewFromMrmDataScript = ClusterLuaScript.fromResource(redisCluster,
"lua/remove_recipient_view_from_mrm_data.lua", ScriptOutputType.INTEGER);
}
Mono<Long> execute(final Collection<byte[]> keysCollection, final AciServiceIdentifier serviceIdentifier,
final byte deviceId) {
final List<byte[]> keys = keysCollection instanceof List<byte[]>
? (List<byte[]>) keysCollection
: new ArrayList<>(keysCollection);
return removeRecipientViewFromMrmDataScript.executeBinaryReactive(keys,
List.of(MessagesCache.getSharedMrmViewKey(serviceIdentifier, deviceId)))
.map(o -> (long) o)
.next();
}
}

View File

@ -19,8 +19,10 @@ import java.util.concurrent.TimeoutException;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import org.reactivestreams.Publisher; import org.reactivestreams.Publisher;
import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil; import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
import org.whispersystems.textsecuregcm.util.Pair; import org.whispersystems.textsecuregcm.util.Pair;
@ -62,8 +64,8 @@ public class MessagesManager {
messagesCache.insert(messageGuid, destinationUuid, destinationDevice, message); messagesCache.insert(messageGuid, destinationUuid, destinationDevice, message);
if (message.hasSourceUuid() && !destinationUuid.toString().equals(message.getSourceUuid())) { if (message.hasSourceServiceId() && !destinationUuid.toString().equals(message.getSourceServiceId())) {
reportMessageManager.store(message.getSourceUuid(), messageGuid); reportMessageManager.store(message.getSourceServiceId(), messageGuid);
} }
} }
@ -137,7 +139,7 @@ public class MessagesManager {
return messagesCache.clear(destinationUuid, deviceId); return messagesCache.clear(destinationUuid, deviceId);
} }
public CompletableFuture<Optional<Envelope>> delete(UUID destinationUuid, Device destinationDevice, UUID guid, public CompletableFuture<Optional<RemovedMessage>> delete(UUID destinationUuid, Device destinationDevice, UUID guid,
@Nullable Long serverTimestamp) { @Nullable Long serverTimestamp) {
return messagesCache.remove(destinationUuid, destinationDevice.getId(), guid) return messagesCache.remove(destinationUuid, destinationDevice.getId(), guid)
.thenComposeAsync(removed -> { .thenComposeAsync(removed -> {
@ -146,12 +148,16 @@ public class MessagesManager {
return CompletableFuture.completedFuture(removed); return CompletableFuture.completedFuture(removed);
} }
final CompletableFuture<Optional<MessageProtos.Envelope>> maybeDeletedEnvelope;
if (serverTimestamp == null) { if (serverTimestamp == null) {
return messagesDynamoDb.deleteMessageByDestinationAndGuid(destinationUuid, destinationDevice, guid); maybeDeletedEnvelope = messagesDynamoDb.deleteMessageByDestinationAndGuid(destinationUuid,
destinationDevice, guid);
} else { } else {
return messagesDynamoDb.deleteMessage(destinationUuid, destinationDevice, guid, serverTimestamp); maybeDeletedEnvelope = messagesDynamoDb.deleteMessage(destinationUuid, destinationDevice, guid,
serverTimestamp);
} }
return maybeDeletedEnvelope.thenApply(maybeEnvelope -> maybeEnvelope.map(RemovedMessage::fromEnvelope));
}, messageDeletionExecutor); }, messageDeletionExecutor);
} }
@ -194,4 +200,14 @@ public class MessagesManager {
messagesCache.removeMessageAvailabilityListener(listener); messagesCache.removeMessageAvailabilityListener(listener);
} }
/**
* Inserts the shared multi-recipient message payload to storage.
*
* @return a key where the shared data is stored
* @see MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript
*/
public byte[] insertSharedMultiRecipientMessagePayload(
SealedSenderMultiRecipientMessage sealedSenderMultiRecipientMessage) {
return messagesCache.insertSharedMultiRecipientMessagePayload(UUID.randomUUID(), sealedSenderMultiRecipientMessage);
}
} }

View File

@ -0,0 +1,30 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import com.google.common.annotations.VisibleForTesting;
import java.util.Optional;
import java.util.UUID;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
public record RemovedMessage(Optional<ServiceIdentifier> sourceServiceId, ServiceIdentifier destinationServiceId,
@VisibleForTesting UUID serverGuid, long serverTimestamp, long clientTimestamp,
MessageProtos.Envelope.Type envelopeType) {
public static RemovedMessage fromEnvelope(MessageProtos.Envelope envelope) {
return new RemovedMessage(
envelope.hasSourceServiceId()
? Optional.of(ServiceIdentifier.valueOf(envelope.getSourceServiceId()))
: Optional.empty(),
ServiceIdentifier.valueOf(envelope.getDestinationServiceId()),
UUID.fromString(envelope.getServerGuid()),
envelope.getServerTimestamp(),
envelope.getClientTimestamp(),
envelope.getType()
);
}
}

View File

@ -294,16 +294,16 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac
} }
private void sendDeliveryReceiptFor(Envelope message) { private void sendDeliveryReceiptFor(Envelope message) {
if (!message.hasSourceUuid()) { if (!message.hasSourceServiceId()) {
return; return;
} }
try { try {
receiptSender.sendReceipt(ServiceIdentifier.valueOf(message.getDestinationUuid()), receiptSender.sendReceipt(ServiceIdentifier.valueOf(message.getDestinationServiceId()),
auth.getAuthenticatedDevice().getId(), AciServiceIdentifier.valueOf(message.getSourceUuid()), auth.getAuthenticatedDevice().getId(), AciServiceIdentifier.valueOf(message.getSourceServiceId()),
message.getTimestamp()); message.getClientTimestamp());
} catch (IllegalArgumentException e) { } catch (IllegalArgumentException e) {
logger.error("Could not parse UUID: {}", message.getSourceUuid()); logger.error("Could not parse UUID: {}", message.getSourceServiceId());
} catch (Exception e) { } catch (Exception e) {
logger.warn("Failed to send receipt", e); logger.warn("Failed to send receipt", e);
} }

View File

@ -205,7 +205,7 @@ record CommandDependencies(
ClientPresenceManager clientPresenceManager = new ClientPresenceManager(clientPresenceCluster, ClientPresenceManager clientPresenceManager = new ClientPresenceManager(clientPresenceCluster,
recurringJobExecutor, keyspaceNotificationDispatchExecutor); recurringJobExecutor, keyspaceNotificationDispatchExecutor);
MessagesCache messagesCache = new MessagesCache(messagesCluster, keyspaceNotificationDispatchExecutor, MessagesCache messagesCache = new MessagesCache(messagesCluster, keyspaceNotificationDispatchExecutor,
messageDeliveryScheduler, messageDeletionExecutor, Clock.systemUTC()); messageDeliveryScheduler, messageDeletionExecutor, Clock.systemUTC(), dynamicConfigurationManager);
ProfilesManager profilesManager = new ProfilesManager(profiles, cacheCluster); ProfilesManager profilesManager = new ProfilesManager(profiles, cacheCluster);
ReportMessageDynamoDb reportMessageDynamoDb = new ReportMessageDynamoDb(dynamoDbClient, ReportMessageDynamoDb reportMessageDynamoDb = new ReportMessageDynamoDb(dynamoDbClient,
configuration.getDynamoDbTables().getReportMessage().getTableName(), configuration.getDynamoDbTables().getReportMessage().getTableName(),

View File

@ -11,30 +11,31 @@ option java_outer_classname = "MessageProtos";
message Envelope { message Envelope {
enum Type { enum Type {
UNKNOWN = 0; UNKNOWN = 0;
CIPHERTEXT = 1; CIPHERTEXT = 1;
KEY_EXCHANGE = 2; KEY_EXCHANGE = 2;
PREKEY_BUNDLE = 3; PREKEY_BUNDLE = 3;
SERVER_DELIVERY_RECEIPT = 5; SERVER_DELIVERY_RECEIPT = 5;
UNIDENTIFIED_SENDER = 6; UNIDENTIFIED_SENDER = 6;
reserved 7; reserved 7;
PLAINTEXT_CONTENT = 8; // for decryption error receipts PLAINTEXT_CONTENT = 8; // for decryption error receipts
} }
optional Type type = 1; optional Type type = 1;
optional string source_uuid = 11; optional string source_service_id = 11;
optional uint32 source_device = 7; optional uint32 source_device = 7;
optional uint64 timestamp = 5; optional uint64 client_timestamp = 5;
optional bytes content = 8; // Contains an encrypted Content optional bytes content = 8; // Contains an encrypted Content
optional string server_guid = 9; optional string server_guid = 9;
optional uint64 server_timestamp = 10; optional uint64 server_timestamp = 10;
optional bool ephemeral = 12; // indicates that the message should not be persisted if the recipient is offline optional bool ephemeral = 12; // indicates that the message should not be persisted if the recipient is offline
optional string destination_uuid = 13; optional string destination_service_id = 13;
optional bool urgent = 14 [default=true]; optional bool urgent = 14 [default=true];
optional string updated_pni = 15; optional string updated_pni = 15;
optional bool story = 16; // indicates that the content is a story. optional bool story = 16; // indicates that the content is a story.
optional bytes report_spam_token = 17; // token sent when reporting spam optional bytes report_spam_token = 17; // token sent when reporting spam
// next: 18 optional bytes shared_mrm_key = 18; // indicates content should be fetched from multi-recipient message datastore
// next: 19
} }
message ProvisioningUuid { message ProvisioningUuid {
@ -42,25 +43,25 @@ message ProvisioningUuid {
} }
message ServerCertificate { message ServerCertificate {
message Certificate { message Certificate {
optional uint32 id = 1; optional uint32 id = 1;
optional bytes key = 2; optional bytes key = 2;
} }
optional bytes certificate = 1; optional bytes certificate = 1;
optional bytes signature = 2; optional bytes signature = 2;
} }
message SenderCertificate { message SenderCertificate {
message Certificate { message Certificate {
optional string sender = 1; optional string sender = 1;
optional string sender_uuid = 6; optional string sender_uuid = 6;
optional uint32 sender_device = 2; optional uint32 sender_device = 2;
optional fixed64 expires = 3; optional fixed64 expires = 3;
optional bytes identity_key = 4; optional bytes identity_key = 4;
optional ServerCertificate signer = 5; optional ServerCertificate signer = 5;
} }
optional bytes certificate = 1; optional bytes certificate = 1;
optional bytes signature = 2; optional bytes signature = 2;
} }

View File

@ -46,7 +46,7 @@ local getNextInterval = function(interval)
end end
local results = redis.call("ZRANGEBYSCORE", pendingNotificationQueue, 0, maxTime, "LIMIT", 0, limit) local results = redis.call("ZRANGE", pendingNotificationQueue, 0, maxTime, "BYSCORE", "LIMIT", 0, limit)
local collated = {} local collated = {}
if results and next(results) then if results and next(results) then

View File

@ -1,7 +1,10 @@
local queueKey = KEYS[1] -- gets messages from a device's queue, up to a given limit
local queueLockKey = KEYS[2] -- returns a list of all envelopes and their queue-local IDs
local limit = ARGV[1]
local afterMessageId = ARGV[2] local queueKey = KEYS[1] -- sorted set of all Envelopes for a device, scored by queue-local ID
local queueLockKey = KEYS[2] -- a key whose presence indicates that the queue is being persistent and must not be read
local limit = ARGV[1] -- [number] the maximum number of messages to return
local afterMessageId = ARGV[2] -- [number] a queue-local ID to exclusively start after, to support pagination. Use -1 to start at the beginning
local locked = redis.call("GET", queueLockKey) local locked = redis.call("GET", queueLockKey)
@ -9,17 +12,8 @@ if locked then
return {} return {}
end end
if afterMessageId == "null" then if afterMessageId == "null" or afterMessageId == nil then
-- An index range is inclusive return redis.error_reply("ERR afterMessageId is required")
local min = 0
local max = limit - 1
if max < 0 then
return {}
end
return redis.call("ZRANGE", queueKey, min, max, "WITHSCORES")
else
-- note: this is deprecated in Redis 6.2, and should be migrated to zrange after the cluster is updated
return redis.call("ZRANGEBYSCORE", queueKey, "("..afterMessageId, "+inf", "WITHSCORES", "LIMIT", 0, limit)
end end
return redis.call("ZRANGE", queueKey, "("..afterMessageId, "+inf", "BYSCORE", "LIMIT", 0, limit, "WITHSCORES")

View File

@ -1,8 +1,10 @@
local queueTotalIndexKey = KEYS[1] -- returns a list of queues that meet persistence criteria
local maxTime = ARGV[1]
local limit = ARGV[2]
local results = redis.call("ZRANGEBYSCORE", queueTotalIndexKey, 0, maxTime, "LIMIT", 0, limit) local queueTotalIndexKey = KEYS[1] -- sorted set of all queues in the shard, by timestamp of oldest message
local maxTime = ARGV[1] -- [number] the most recent queue timestamp that may be fetched
local limit = ARGV[2] -- [number] the maximum number of queues to fetch
local results = redis.call("ZRANGE", queueTotalIndexKey, 0, maxTime, "BYSCORE", "LIMIT", 0, limit)
if results and next(results) then if results and next(results) then
redis.call("ZREM", queueTotalIndexKey, unpack(results)) redis.call("ZREM", queueTotalIndexKey, unpack(results))

View File

@ -1,9 +1,12 @@
local queueKey = KEYS[1] -- inserts a message into a device's queue, and updates relevant associated data
local queueMetadataKey = KEYS[2] -- returns a number, the queue-local message ID (useful for testing)
local queueTotalIndexKey = KEYS[3]
local message = ARGV[1] local queueKey = KEYS[1] -- sorted set of Envelopes for a device, by queue-local ID
local currentTime = ARGV[2] local queueMetadataKey = KEYS[2] -- hash of message GUID to queue-local IDs
local guid = ARGV[3] local queueTotalIndexKey = KEYS[3] -- sorted set of all queues in the shard, by timestamp of oldest message
local message = ARGV[1] -- [bytes] the Envelope to insert
local currentTime = ARGV[2] -- [number] the message timestamp, to sort the queue in the queueTotalIndex
local guid = ARGV[3] -- [string] the message GUID
if redis.call("HEXISTS", queueMetadataKey, guid) == 1 then if redis.call("HEXISTS", queueMetadataKey, guid) == 1 then
return tonumber(redis.call("HGET", queueMetadataKey, guid)) return tonumber(redis.call("HGET", queueMetadataKey, guid))
@ -14,9 +17,8 @@ local messageId = redis.call("HINCRBY", queueMetadataKey, "counter", 1)
redis.call("ZADD", queueKey, "NX", messageId, message) redis.call("ZADD", queueKey, "NX", messageId, message)
redis.call("HSET", queueMetadataKey, guid, messageId) redis.call("HSET", queueMetadataKey, guid, messageId)
redis.call("EXPIRE", queueKey, 2678400) -- 31 days
redis.call("EXPIRE", queueKey, 7776000) -- 90 days redis.call("EXPIRE", queueMetadataKey, 2678400) -- 31 days
redis.call("EXPIRE", queueMetadataKey, 7776000) -- 90 days
redis.call("ZADD", queueTotalIndexKey, "NX", currentTime, queueKey) redis.call("ZADD", queueTotalIndexKey, "NX", currentTime, queueKey)
return messageId return messageId

View File

@ -0,0 +1,13 @@
-- inserts shared multi-recipient message data
local sharedMrmKey = KEYS[1] -- [string] the key containing the shared MRM data
local mrmData = ARGV[1] -- [bytes] the serialized multi-recipient message data
-- the remainder of ARGV is list of recipient keys and view data
redis.call("HSET", sharedMrmKey, "data", mrmData);
redis.call("EXPIRE", sharedMrmKey, 604800) -- 7 days
-- unpack() fails with "too many results" at very large table sizes, so we loop
for i = 2, #ARGV, 2 do
redis.call("HSET", sharedMrmKey, ARGV[i], ARGV[i + 1])
end

View File

@ -1,20 +1,26 @@
local queueKey = KEYS[1] -- removes a list of messages by ID from the cluster, returning the deleted messages
local queueMetadataKey = KEYS[2] -- returns a list of removed envelopes
local queueTotalIndexKey = KEYS[3] -- Note: content may be absent for MRM messages, and for these messages, the caller must update the sharedMrmKey
-- to remove the recipient's reference
local queueKey = KEYS[1] -- sorted set of Envelopes for a device, by queue-local ID
local queueMetadataKey = KEYS[2] -- hash of message GUID to queue-local IDs
local queueTotalIndexKey = KEYS[3] -- sorted set of all queues in the shard, by timestamp of oldest message
local messageGuids = ARGV -- [list[string]] message GUIDs
local removedMessages = {} local removedMessages = {}
for _, guid in ipairs(ARGV) do for _, guid in ipairs(messageGuids) do
local messageId = redis.call("HGET", queueMetadataKey, guid) local messageId = redis.call("HGET", queueMetadataKey, guid)
if messageId then if messageId then
local envelope = redis.call("ZRANGEBYSCORE", queueKey, messageId, messageId, "LIMIT", 0, 1) local envelope = redis.call("ZRANGE", queueKey, messageId, messageId, "BYSCORE", "LIMIT", 0, 1)
redis.call("ZREMRANGEBYSCORE", queueKey, messageId, messageId) redis.call("ZREMRANGEBYSCORE", queueKey, messageId, messageId)
redis.call("HDEL", queueMetadataKey, guid) redis.call("HDEL", queueMetadataKey, guid)
if envelope and next(envelope) then if envelope and next(envelope) then
removedMessages[#removedMessages + 1] = envelope[1] table.insert(removedMessages, envelope[1])
end end
end end
end end

View File

@ -1,7 +1,29 @@
local queueKey = KEYS[1] -- incrementally removes a given device's queue and associated data
local queueMetadataKey = KEYS[2] -- returns: a page of messages and scores.
local queueTotalIndexKey = KEYS[3] -- The messages must be checked for mrmKeys to update. After updating MRM keys, this script must be called again
-- with processedMessageGuids. If the returned table is empty, then
-- the queue has been fully deleted.
redis.call("DEL", queueKey) local queueKey = KEYS[1] -- sorted set of Envelopes for a device, by queue-local ID
redis.call("DEL", queueMetadataKey) local queueMetadataKey = KEYS[2] -- hash of message GUID to queue-local IDs
redis.call("ZREM", queueTotalIndexKey, queueKey) local queueTotalIndexKey = KEYS[3] -- sorted set of all queues in the shard, by timestamp of oldest message
local limit = ARGV[1] -- the maximum number of messages to return
local processedMessageGuids = { unpack(ARGV, 2) }
for _, guid in ipairs(processedMessageGuids) do
local messageId = redis.call("HGET", queueMetadataKey, guid)
if messageId then
redis.call("ZREMRANGEBYSCORE", queueKey, messageId, messageId)
redis.call("HDEL", queueMetadataKey, guid)
end
end
local messages = redis.call("ZRANGE", queueKey, 0, limit-1)
if #messages == 0 then
redis.call("DEL", queueKey)
redis.call("DEL", queueMetadataKey)
redis.call("ZREM", queueTotalIndexKey, queueKey)
end
return messages

View File

@ -0,0 +1,17 @@
-- Removes the given recipient view from the shared MRM data. If the only field remaining after the removal is the
-- `data` field, then the key will be deleted
local sharedMrmKeys = KEYS -- KEYS: list of all keys in a single slot to update
local recipientViewToRemove = ARGV[1] -- the recipient view to remove from the hash
local keysDeleted = 0
for _, sharedMrmKey in ipairs(sharedMrmKeys) do
redis.call("HDEL", sharedMrmKey, recipientViewToRemove)
if redis.call("HLEN", sharedMrmKey) == 1 then
redis.call("DEL", sharedMrmKey)
keysDeleted = keysDeleted + 1
end
end
return keysDeleted

View File

@ -87,6 +87,7 @@ import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.auth.UnidentifiedAccessUtil; import org.whispersystems.textsecuregcm.auth.UnidentifiedAccessUtil;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicInboundMessageByteLimitConfiguration; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicInboundMessageByteLimitConfiguration;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicMessagesConfiguration;
import org.whispersystems.textsecuregcm.entities.AccountMismatchedDevices; import org.whispersystems.textsecuregcm.entities.AccountMismatchedDevices;
import org.whispersystems.textsecuregcm.entities.AccountStaleDevices; import org.whispersystems.textsecuregcm.entities.AccountStaleDevices;
import org.whispersystems.textsecuregcm.entities.IncomingMessage; import org.whispersystems.textsecuregcm.entities.IncomingMessage;
@ -121,6 +122,7 @@ import org.whispersystems.textsecuregcm.storage.ClientReleaseManager;
import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager; import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
import org.whispersystems.textsecuregcm.storage.MessagesManager; import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.storage.RemovedMessage;
import org.whispersystems.textsecuregcm.storage.ReportMessageManager; import org.whispersystems.textsecuregcm.storage.ReportMessageManager;
import org.whispersystems.textsecuregcm.tests.util.AccountsHelper; import org.whispersystems.textsecuregcm.tests.util.AccountsHelper;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper; import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
@ -251,6 +253,7 @@ class MessageControllerTest {
final DynamicConfiguration dynamicConfiguration = mock(DynamicConfiguration.class); final DynamicConfiguration dynamicConfiguration = mock(DynamicConfiguration.class);
when(dynamicConfiguration.getInboundMessageByteLimitConfiguration()).thenReturn(inboundMessageByteLimitConfiguration); when(dynamicConfiguration.getInboundMessageByteLimitConfiguration()).thenReturn(inboundMessageByteLimitConfiguration);
when(dynamicConfiguration.getMessagesConfiguration()).thenReturn(new DynamicMessagesConfiguration(true, true));
when(dynamicConfigurationManager.getConfiguration()).thenReturn(dynamicConfiguration); when(dynamicConfigurationManager.getConfiguration()).thenReturn(dynamicConfiguration);
@ -311,7 +314,7 @@ class MessageControllerTest {
ArgumentCaptor<Envelope> captor = ArgumentCaptor.forClass(Envelope.class); ArgumentCaptor<Envelope> captor = ArgumentCaptor.forClass(Envelope.class);
verify(messageSender, times(1)).sendMessage(any(Account.class), any(Device.class), captor.capture(), eq(false)); verify(messageSender, times(1)).sendMessage(any(Account.class), any(Device.class), captor.capture(), eq(false));
assertTrue(captor.getValue().hasSourceUuid()); assertTrue(captor.getValue().hasSourceServiceId());
assertTrue(captor.getValue().hasSourceDevice()); assertTrue(captor.getValue().hasSourceDevice());
assertTrue(captor.getValue().getUrgent()); assertTrue(captor.getValue().getUrgent());
} }
@ -353,7 +356,7 @@ class MessageControllerTest {
ArgumentCaptor<Envelope> captor = ArgumentCaptor.forClass(Envelope.class); ArgumentCaptor<Envelope> captor = ArgumentCaptor.forClass(Envelope.class);
verify(messageSender, times(1)).sendMessage(any(Account.class), any(Device.class), captor.capture(), eq(false)); verify(messageSender, times(1)).sendMessage(any(Account.class), any(Device.class), captor.capture(), eq(false));
assertTrue(captor.getValue().hasSourceUuid()); assertTrue(captor.getValue().hasSourceServiceId());
assertTrue(captor.getValue().hasSourceDevice()); assertTrue(captor.getValue().hasSourceDevice());
assertFalse(captor.getValue().getUrgent()); assertFalse(captor.getValue().getUrgent());
} }
@ -375,7 +378,7 @@ class MessageControllerTest {
ArgumentCaptor<Envelope> captor = ArgumentCaptor.forClass(Envelope.class); ArgumentCaptor<Envelope> captor = ArgumentCaptor.forClass(Envelope.class);
verify(messageSender, times(1)).sendMessage(any(Account.class), any(Device.class), captor.capture(), eq(false)); verify(messageSender, times(1)).sendMessage(any(Account.class), any(Device.class), captor.capture(), eq(false));
assertTrue(captor.getValue().hasSourceUuid()); assertTrue(captor.getValue().hasSourceServiceId());
assertTrue(captor.getValue().hasSourceDevice()); assertTrue(captor.getValue().hasSourceDevice());
} }
} }
@ -410,7 +413,7 @@ class MessageControllerTest {
ArgumentCaptor<Envelope> captor = ArgumentCaptor.forClass(Envelope.class); ArgumentCaptor<Envelope> captor = ArgumentCaptor.forClass(Envelope.class);
verify(messageSender, times(1)).sendMessage(any(Account.class), any(Device.class), captor.capture(), eq(false)); verify(messageSender, times(1)).sendMessage(any(Account.class), any(Device.class), captor.capture(), eq(false));
assertFalse(captor.getValue().hasSourceUuid()); assertFalse(captor.getValue().hasSourceServiceId());
assertFalse(captor.getValue().hasSourceDevice()); assertFalse(captor.getValue().hasSourceDevice());
} }
} }
@ -444,7 +447,7 @@ class MessageControllerTest {
assertThat("Good Response", response.getStatus(), is(equalTo(expectedResponse))); assertThat("Good Response", response.getStatus(), is(equalTo(expectedResponse)));
if (expectedResponse == 200) { if (expectedResponse == 200) {
verify(messageSender).sendMessage( verify(messageSender).sendMessage(
any(Account.class), any(Device.class), argThat(env -> !env.hasSourceUuid() && !env.hasSourceDevice()), any(Account.class), any(Device.class), argThat(env -> !env.hasSourceServiceId() && !env.hasSourceDevice()),
eq(false)); eq(false));
} else { } else {
verifyNoMoreInteractions(messageSender); verifyNoMoreInteractions(messageSender);
@ -732,23 +735,27 @@ class MessageControllerTest {
@Test @Test
void testDeleteMessages() { void testDeleteMessages() {
long timestamp = System.currentTimeMillis(); long clientTimestamp = System.currentTimeMillis();
UUID sourceUuid = UUID.randomUUID(); UUID sourceUuid = UUID.randomUUID();
UUID uuid1 = UUID.randomUUID(); UUID uuid1 = UUID.randomUUID();
final long serverTimestamp = 0;
when(messagesManager.delete(AuthHelper.VALID_UUID, AuthHelper.VALID_DEVICE, uuid1, null)) when(messagesManager.delete(AuthHelper.VALID_UUID, AuthHelper.VALID_DEVICE, uuid1, null))
.thenReturn( .thenReturn(
CompletableFutureTestUtil.almostCompletedFuture(Optional.of(generateEnvelope(uuid1, Envelope.Type.CIPHERTEXT_VALUE, CompletableFutureTestUtil.almostCompletedFuture(Optional.of(
timestamp, sourceUuid, (byte) 1, AuthHelper.VALID_UUID, null, "hi".getBytes(), 0)))); new RemovedMessage(Optional.of(new AciServiceIdentifier(sourceUuid)),
new AciServiceIdentifier(AuthHelper.VALID_UUID), uuid1, serverTimestamp, clientTimestamp,
Envelope.Type.CIPHERTEXT))));
UUID uuid2 = UUID.randomUUID(); UUID uuid2 = UUID.randomUUID();
when(messagesManager.delete(AuthHelper.VALID_UUID, AuthHelper.VALID_DEVICE, uuid2, null)) when(messagesManager.delete(AuthHelper.VALID_UUID, AuthHelper.VALID_DEVICE, uuid2, null))
.thenReturn( .thenReturn(
CompletableFutureTestUtil.almostCompletedFuture(Optional.of(generateEnvelope( CompletableFutureTestUtil.almostCompletedFuture(Optional.of(
uuid2, Envelope.Type.SERVER_DELIVERY_RECEIPT_VALUE, new RemovedMessage(Optional.of(new AciServiceIdentifier(sourceUuid)),
System.currentTimeMillis(), sourceUuid, (byte) 1, AuthHelper.VALID_UUID, null, null, 0)))); new AciServiceIdentifier(AuthHelper.VALID_UUID), uuid2, serverTimestamp, clientTimestamp,
Envelope.Type.SERVER_DELIVERY_RECEIPT))));
UUID uuid3 = UUID.randomUUID(); UUID uuid3 = UUID.randomUUID();
when(messagesManager.delete(AuthHelper.VALID_UUID, AuthHelper.VALID_DEVICE, uuid3, null)) when(messagesManager.delete(AuthHelper.VALID_UUID, AuthHelper.VALID_DEVICE, uuid3, null))
@ -766,7 +773,7 @@ class MessageControllerTest {
assertThat("Good Response Code", response.getStatus(), is(equalTo(204))); assertThat("Good Response Code", response.getStatus(), is(equalTo(204)));
verify(receiptSender).sendReceipt(eq(new AciServiceIdentifier(AuthHelper.VALID_UUID)), eq((byte) 1), verify(receiptSender).sendReceipt(eq(new AciServiceIdentifier(AuthHelper.VALID_UUID)), eq((byte) 1),
eq(new AciServiceIdentifier(sourceUuid)), eq(timestamp)); eq(new AciServiceIdentifier(sourceUuid)), eq(clientTimestamp));
} }
try (final Response response = resources.getJerseyTest() try (final Response response = resources.getJerseyTest()
@ -1068,9 +1075,16 @@ class MessageControllerTest {
} }
private record Recipient(ServiceIdentifier uuid, private record Recipient(ServiceIdentifier uuid,
byte deviceId, Byte[] deviceId,
int registrationId, Integer[] registrationId,
byte[] perRecipientKeyMaterial) { byte[] perRecipientKeyMaterial) {
Recipient(ServiceIdentifier uuid,
byte deviceId,
int registrationId,
byte[] perRecipientKeyMaterial) {
this(uuid, new Byte[]{deviceId}, new Integer[]{registrationId}, perRecipientKeyMaterial);
}
} }
private static void writeMultiPayloadRecipient(final ByteBuffer bb, final Recipient r, private static void writeMultiPayloadRecipient(final ByteBuffer bb, final Recipient r,
@ -1081,8 +1095,13 @@ class MessageControllerTest {
bb.put(UUIDUtil.toBytes(r.uuid().uuid())); bb.put(UUIDUtil.toBytes(r.uuid().uuid()));
} }
bb.put(r.deviceId()); // device id (1 byte) assert (r.deviceId.length == r.registrationId.length);
bb.putShort((short) r.registrationId()); // registration id (2 bytes)
for (int i = 0; i < r.deviceId.length; i++) {
final int hasMore = i == r.deviceId.length - 1 ? 0 : 0x8000;
bb.put(r.deviceId()[i]); // device id (1 byte)
bb.putShort((short) (r.registrationId()[i] | hasMore)); // registration id (2 bytes)
}
bb.put(r.perRecipientKeyMaterial()); // key material (48 bytes) bb.put(r.perRecipientKeyMaterial()); // key material (48 bytes)
} }
@ -1157,7 +1176,7 @@ class MessageControllerTest {
.queryParam("story", true) .queryParam("story", true)
.queryParam("urgent", false) .queryParam("urgent", false)
.request() .request()
.header(HttpHeaders.USER_AGENT, "FIXME") .header(HttpHeaders.USER_AGENT, "test")
.put(entity)) { .put(entity)) {
assertThat(response.readEntity(String.class), response.getStatus(), is(equalTo(200))); assertThat(response.readEntity(String.class), response.getStatus(), is(equalTo(200)));
@ -1206,7 +1225,7 @@ class MessageControllerTest {
.queryParam("story", isStory) .queryParam("story", isStory)
.queryParam("urgent", urgent) .queryParam("urgent", urgent)
.request() .request()
.header(HttpHeaders.USER_AGENT, "FIXME") .header(HttpHeaders.USER_AGENT, "test")
.header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, accessHeader) .header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, accessHeader)
.put(entity)) { .put(entity)) {
@ -1216,7 +1235,7 @@ class MessageControllerTest {
.sendMessage( .sendMessage(
any(), any(),
any(), any(),
argThat(env -> env.getUrgent() == urgent && !env.hasSourceUuid() && !env.hasSourceDevice()), argThat(env -> env.getUrgent() == urgent && !env.hasSourceServiceId() && !env.hasSourceDevice()),
eq(true)); eq(true));
if (expectedStatus == 200) { if (expectedStatus == 200) {
SendMultiRecipientMessageResponse smrmr = response.readEntity(SendMultiRecipientMessageResponse.class); SendMultiRecipientMessageResponse smrmr = response.readEntity(SendMultiRecipientMessageResponse.class);
@ -1384,7 +1403,7 @@ class MessageControllerTest {
.queryParam("story", false) .queryParam("story", false)
.queryParam("urgent", false) .queryParam("urgent", false)
.request() .request()
.header(HttpHeaders.USER_AGENT, "FIXME") .header(HttpHeaders.USER_AGENT, "test")
.header(HeaderUtils.GROUP_SEND_TOKEN, AuthHelper.validGroupSendTokenHeader( .header(HeaderUtils.GROUP_SEND_TOKEN, AuthHelper.validGroupSendTokenHeader(
serverSecretParams, List.of(SINGLE_DEVICE_ACI_ID, MULTI_DEVICE_ACI_ID), Instant.parse("2024-04-10T00:00:00.00Z"))) serverSecretParams, List.of(SINGLE_DEVICE_ACI_ID, MULTI_DEVICE_ACI_ID), Instant.parse("2024-04-10T00:00:00.00Z")))
.put(Entity.entity(stream, MultiRecipientMessageProvider.MEDIA_TYPE))) { .put(Entity.entity(stream, MultiRecipientMessageProvider.MEDIA_TYPE))) {
@ -1395,7 +1414,7 @@ class MessageControllerTest {
.sendMessage( .sendMessage(
any(), any(),
any(), any(),
argThat(env -> !env.hasSourceUuid() && !env.hasSourceDevice()), argThat(env -> !env.hasSourceServiceId() && !env.hasSourceDevice()),
eq(true)); eq(true));
SendMultiRecipientMessageResponse smrmr = response.readEntity(SendMultiRecipientMessageResponse.class); SendMultiRecipientMessageResponse smrmr = response.readEntity(SendMultiRecipientMessageResponse.class);
assertThat(smrmr.uuids404(), is(empty())); assertThat(smrmr.uuids404(), is(empty()));
@ -1423,7 +1442,7 @@ class MessageControllerTest {
.queryParam("story", false) .queryParam("story", false)
.queryParam("urgent", false) .queryParam("urgent", false)
.request() .request()
.header(HttpHeaders.USER_AGENT, "FIXME") .header(HttpHeaders.USER_AGENT, "test")
.header(HeaderUtils.GROUP_SEND_TOKEN, AuthHelper.validGroupSendTokenHeader( .header(HeaderUtils.GROUP_SEND_TOKEN, AuthHelper.validGroupSendTokenHeader(
serverSecretParams, List.of(MULTI_DEVICE_ACI_ID), Instant.parse("2024-04-10T00:00:00.00Z"))) serverSecretParams, List.of(MULTI_DEVICE_ACI_ID), Instant.parse("2024-04-10T00:00:00.00Z")))
.put(Entity.entity(stream, MultiRecipientMessageProvider.MEDIA_TYPE))) { .put(Entity.entity(stream, MultiRecipientMessageProvider.MEDIA_TYPE))) {
@ -1454,7 +1473,7 @@ class MessageControllerTest {
.queryParam("story", false) .queryParam("story", false)
.queryParam("urgent", false) .queryParam("urgent", false)
.request() .request()
.header(HttpHeaders.USER_AGENT, "FIXME") .header(HttpHeaders.USER_AGENT, "test")
.header(HeaderUtils.GROUP_SEND_TOKEN, AuthHelper.validGroupSendTokenHeader( .header(HeaderUtils.GROUP_SEND_TOKEN, AuthHelper.validGroupSendTokenHeader(
serverSecretParams, List.of(SINGLE_DEVICE_ACI_ID, MULTI_DEVICE_ACI_ID), Instant.parse("2024-04-10T00:00:00.00Z"))) serverSecretParams, List.of(SINGLE_DEVICE_ACI_ID, MULTI_DEVICE_ACI_ID), Instant.parse("2024-04-10T00:00:00.00Z")))
.put(Entity.entity(stream, MultiRecipientMessageProvider.MEDIA_TYPE))) { .put(Entity.entity(stream, MultiRecipientMessageProvider.MEDIA_TYPE))) {
@ -1620,7 +1639,7 @@ class MessageControllerTest {
.queryParam("story", false) .queryParam("story", false)
.queryParam("urgent", true) .queryParam("urgent", true)
.request() .request()
.header(HttpHeaders.USER_AGENT, "FIXME") .header(HttpHeaders.USER_AGENT, "test")
.header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES)); .header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES));
// make the PUT request // make the PUT request
@ -1663,7 +1682,7 @@ class MessageControllerTest {
.queryParam("story", false) .queryParam("story", false)
.queryParam("urgent", true) .queryParam("urgent", true)
.request() .request()
.header(HttpHeaders.USER_AGENT, "FIXME") .header(HttpHeaders.USER_AGENT, "test")
.header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES)); .header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES));
// make the PUT request // make the PUT request
@ -1702,7 +1721,7 @@ class MessageControllerTest {
.queryParam("story", true) .queryParam("story", true)
.queryParam("urgent", true) .queryParam("urgent", true)
.request() .request()
.header(HttpHeaders.USER_AGENT, "FIXME") .header(HttpHeaders.USER_AGENT, "test")
.header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES)); .header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES));
when(rateLimiter.validateAsync(any(UUID.class))) when(rateLimiter.validateAsync(any(UUID.class)))
@ -1730,14 +1749,14 @@ class MessageControllerTest {
final MessageProtos.Envelope.Builder builder = MessageProtos.Envelope.newBuilder() final MessageProtos.Envelope.Builder builder = MessageProtos.Envelope.newBuilder()
.setType(MessageProtos.Envelope.Type.forNumber(type)) .setType(MessageProtos.Envelope.Type.forNumber(type))
.setTimestamp(timestamp) .setClientTimestamp(timestamp)
.setServerTimestamp(serverTimestamp) .setServerTimestamp(serverTimestamp)
.setDestinationUuid(destinationUuid.toString()) .setDestinationServiceId(destinationUuid.toString())
.setStory(story) .setStory(story)
.setServerGuid(guid.toString()); .setServerGuid(guid.toString());
if (sourceUuid != null) { if (sourceUuid != null) {
builder.setSourceUuid(sourceUuid.toString()); builder.setSourceServiceId(sourceUuid.toString());
builder.setSourceDevice(sourceDevice); builder.setSourceDevice(sourceDevice);
} }

View File

@ -104,7 +104,7 @@ class MessageMetricsTest {
final MessageProtos.Envelope.Builder builder = MessageProtos.Envelope.newBuilder(); final MessageProtos.Envelope.Builder builder = MessageProtos.Envelope.newBuilder();
if (destinationIdentifier != null) { if (destinationIdentifier != null) {
builder.setDestinationUuid(destinationIdentifier.toServiceIdentifierString()); builder.setDestinationServiceId(destinationIdentifier.toServiceIdentifierString());
} }
return builder.build(); return builder.build();

View File

@ -151,7 +151,7 @@ class MessageSenderTest {
private MessageProtos.Envelope generateRandomMessage() { private MessageProtos.Envelope generateRandomMessage() {
return MessageProtos.Envelope.newBuilder() return MessageProtos.Envelope.newBuilder()
.setTimestamp(System.currentTimeMillis()) .setClientTimestamp(System.currentTimeMillis())
.setServerTimestamp(System.currentTimeMillis()) .setServerTimestamp(System.currentTimeMillis())
.setContent(ByteString.copyFromUtf8(RandomStringUtils.randomAlphanumeric(256))) .setContent(ByteString.copyFromUtf8(RandomStringUtils.randomAlphanumeric(256)))
.setType(MessageProtos.Envelope.Type.CIPHERTEXT) .setType(MessageProtos.Envelope.Type.CIPHERTEXT)

View File

@ -54,8 +54,8 @@ import java.util.concurrent.ExecutionException;
import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Supplier; import java.util.function.Supplier;
import javax.annotation.Nullable;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.jetbrains.annotations.Nullable;
import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.Timeout; import org.junit.jupiter.api.Timeout;

View File

@ -160,8 +160,8 @@ public class ChangeNumberManagerTest {
final MessageProtos.Envelope envelope = envelopeCaptor.getValue(); final MessageProtos.Envelope envelope = envelopeCaptor.getValue();
assertEquals(aci, UUID.fromString(envelope.getDestinationUuid())); assertEquals(aci, UUID.fromString(envelope.getDestinationServiceId()));
assertEquals(aci, UUID.fromString(envelope.getSourceUuid())); assertEquals(aci, UUID.fromString(envelope.getSourceServiceId()));
assertEquals(Device.PRIMARY_ID, envelope.getSourceDevice()); assertEquals(Device.PRIMARY_ID, envelope.getSourceDevice());
assertEquals(updatedPhoneNumberIdentifiersByAccount.get(account), UUID.fromString(envelope.getUpdatedPni())); assertEquals(updatedPhoneNumberIdentifiersByAccount.get(account), UUID.fromString(envelope.getUpdatedPni()));
} }
@ -208,8 +208,8 @@ public class ChangeNumberManagerTest {
final MessageProtos.Envelope envelope = envelopeCaptor.getValue(); final MessageProtos.Envelope envelope = envelopeCaptor.getValue();
assertEquals(aci, UUID.fromString(envelope.getDestinationUuid())); assertEquals(aci, UUID.fromString(envelope.getDestinationServiceId()));
assertEquals(aci, UUID.fromString(envelope.getSourceUuid())); assertEquals(aci, UUID.fromString(envelope.getSourceServiceId()));
assertEquals(Device.PRIMARY_ID, envelope.getSourceDevice()); assertEquals(Device.PRIMARY_ID, envelope.getSourceDevice());
assertEquals(updatedPhoneNumberIdentifiersByAccount.get(account), UUID.fromString(envelope.getUpdatedPni())); assertEquals(updatedPhoneNumberIdentifiersByAccount.get(account), UUID.fromString(envelope.getUpdatedPni()));
} }
@ -254,8 +254,8 @@ public class ChangeNumberManagerTest {
final MessageProtos.Envelope envelope = envelopeCaptor.getValue(); final MessageProtos.Envelope envelope = envelopeCaptor.getValue();
assertEquals(aci, UUID.fromString(envelope.getDestinationUuid())); assertEquals(aci, UUID.fromString(envelope.getDestinationServiceId()));
assertEquals(aci, UUID.fromString(envelope.getSourceUuid())); assertEquals(aci, UUID.fromString(envelope.getSourceServiceId()));
assertEquals(Device.PRIMARY_ID, envelope.getSourceDevice()); assertEquals(Device.PRIMARY_ID, envelope.getSourceDevice());
assertFalse(updatedPhoneNumberIdentifiersByAccount.containsKey(account)); assertFalse(updatedPhoneNumberIdentifiersByAccount.containsKey(account));
} }
@ -296,8 +296,8 @@ public class ChangeNumberManagerTest {
final MessageProtos.Envelope envelope = envelopeCaptor.getValue(); final MessageProtos.Envelope envelope = envelopeCaptor.getValue();
assertEquals(aci, UUID.fromString(envelope.getDestinationUuid())); assertEquals(aci, UUID.fromString(envelope.getDestinationServiceId()));
assertEquals(aci, UUID.fromString(envelope.getSourceUuid())); assertEquals(aci, UUID.fromString(envelope.getSourceServiceId()));
assertEquals(Device.PRIMARY_ID, envelope.getSourceDevice()); assertEquals(Device.PRIMARY_ID, envelope.getSourceDevice());
assertFalse(updatedPhoneNumberIdentifiersByAccount.containsKey(account)); assertFalse(updatedPhoneNumberIdentifiersByAccount.containsKey(account));
} }
@ -340,8 +340,8 @@ public class ChangeNumberManagerTest {
final MessageProtos.Envelope envelope = envelopeCaptor.getValue(); final MessageProtos.Envelope envelope = envelopeCaptor.getValue();
assertEquals(aci, UUID.fromString(envelope.getDestinationUuid())); assertEquals(aci, UUID.fromString(envelope.getDestinationServiceId()));
assertEquals(aci, UUID.fromString(envelope.getSourceUuid())); assertEquals(aci, UUID.fromString(envelope.getSourceServiceId()));
assertEquals(Device.PRIMARY_ID, envelope.getSourceDevice()); assertEquals(Device.PRIMARY_ID, envelope.getSourceDevice());
assertFalse(updatedPhoneNumberIdentifiersByAccount.containsKey(account)); assertFalse(updatedPhoneNumberIdentifiersByAccount.containsKey(account));
} }

View File

@ -81,7 +81,7 @@ class MessagePersisterIntegrationTest {
notificationExecutorService = Executors.newSingleThreadExecutor(); notificationExecutorService = Executors.newSingleThreadExecutor();
messagesCache = new MessagesCache(REDIS_CLUSTER_EXTENSION.getRedisCluster(), notificationExecutorService, messagesCache = new MessagesCache(REDIS_CLUSTER_EXTENSION.getRedisCluster(), notificationExecutorService,
messageDeliveryScheduler, messageDeletionExecutorService, Clock.systemUTC()); messageDeliveryScheduler, messageDeletionExecutorService, Clock.systemUTC(), dynamicConfigurationManager);
messagesManager = new MessagesManager(messagesDynamoDb, messagesCache, mock(ReportMessageManager.class), messagesManager = new MessagesManager(messagesDynamoDb, messagesCache, mock(ReportMessageManager.class),
messageDeletionExecutorService); messageDeletionExecutorService);
messagePersister = new MessagePersister(messagesCache, messagesManager, accountsManager, messagePersister = new MessagePersister(messagesCache, messagesManager, accountsManager,
@ -185,12 +185,12 @@ class MessagePersisterIntegrationTest {
private MessageProtos.Envelope generateRandomMessage(final UUID messageGuid, final long serverTimestamp) { private MessageProtos.Envelope generateRandomMessage(final UUID messageGuid, final long serverTimestamp) {
return MessageProtos.Envelope.newBuilder() return MessageProtos.Envelope.newBuilder()
.setTimestamp(serverTimestamp * 2) // client timestamp may not be accurate .setClientTimestamp(serverTimestamp * 2) // client timestamp may not be accurate
.setServerTimestamp(serverTimestamp) .setServerTimestamp(serverTimestamp)
.setContent(ByteString.copyFromUtf8(RandomStringUtils.randomAlphanumeric(256))) .setContent(ByteString.copyFromUtf8(RandomStringUtils.randomAlphanumeric(256)))
.setType(MessageProtos.Envelope.Type.CIPHERTEXT) .setType(MessageProtos.Envelope.Type.CIPHERTEXT)
.setServerGuid(messageGuid.toString()) .setServerGuid(messageGuid.toString())
.setDestinationUuid(UUID.randomUUID().toString()) .setDestinationServiceId(UUID.randomUUID().toString())
.build(); .build();
} }
} }

View File

@ -40,6 +40,7 @@ import org.apache.commons.lang3.RandomStringUtils;
import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.Timeout;
import org.junit.jupiter.api.extension.RegisterExtension; import org.junit.jupiter.api.extension.RegisterExtension;
import org.mockito.ArgumentCaptor; import org.mockito.ArgumentCaptor;
import org.mockito.stubbing.Answer; import org.mockito.stubbing.Answer;
@ -48,12 +49,11 @@ import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager; import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension; import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
import org.whispersystems.textsecuregcm.tests.util.DevicesHelper; import org.whispersystems.textsecuregcm.tests.util.DevicesHelper;
import reactor.core.publisher.Mono;
import reactor.core.scheduler.Scheduler; import reactor.core.scheduler.Scheduler;
import reactor.core.scheduler.Schedulers; import reactor.core.scheduler.Schedulers;
import software.amazon.awssdk.services.dynamodb.model.ItemCollectionSizeLimitExceededException; import software.amazon.awssdk.services.dynamodb.model.ItemCollectionSizeLimitExceededException;
@Timeout(value = 5, threadMode = Timeout.ThreadMode.SEPARATE_THREAD)
class MessagePersisterTest { class MessagePersisterTest {
@RegisterExtension @RegisterExtension
@ -104,7 +104,7 @@ class MessagePersisterTest {
resubscribeRetryExecutorService = Executors.newSingleThreadScheduledExecutor(); resubscribeRetryExecutorService = Executors.newSingleThreadScheduledExecutor();
messageDeliveryScheduler = Schedulers.newBoundedElastic(10, 10_000, "messageDelivery"); messageDeliveryScheduler = Schedulers.newBoundedElastic(10, 10_000, "messageDelivery");
messagesCache = new MessagesCache(REDIS_CLUSTER_EXTENSION.getRedisCluster(), sharedExecutorService, messagesCache = new MessagesCache(REDIS_CLUSTER_EXTENSION.getRedisCluster(), sharedExecutorService,
messageDeliveryScheduler, sharedExecutorService, Clock.systemUTC()); messageDeliveryScheduler, sharedExecutorService, Clock.systemUTC(), dynamicConfigurationManager);
messagePersister = new MessagePersister(messagesCache, messagesManager, accountsManager, clientPresenceManager, messagePersister = new MessagePersister(messagesCache, messagesManager, accountsManager, clientPresenceManager,
keysManager, dynamicConfigurationManager, PERSIST_DELAY, 1); keysManager, dynamicConfigurationManager, PERSIST_DELAY, 1);
@ -356,7 +356,8 @@ class MessagePersisterTest {
final UUID messageGuid = UUID.randomUUID(); final UUID messageGuid = UUID.randomUUID();
final MessageProtos.Envelope envelope = MessageProtos.Envelope.newBuilder() final MessageProtos.Envelope envelope = MessageProtos.Envelope.newBuilder()
.setTimestamp(firstMessageTimestamp.toEpochMilli() + i) .setDestinationServiceId(accountUuid.toString())
.setClientTimestamp(firstMessageTimestamp.toEpochMilli() + i)
.setServerTimestamp(firstMessageTimestamp.toEpochMilli() + i) .setServerTimestamp(firstMessageTimestamp.toEpochMilli() + i)
.setContent(ByteString.copyFromUtf8(RandomStringUtils.randomAlphanumeric(256))) .setContent(ByteString.copyFromUtf8(RandomStringUtils.randomAlphanumeric(256)))
.setType(MessageProtos.Envelope.Type.CIPHERTEXT) .setType(MessageProtos.Envelope.Type.CIPHERTEXT)

View File

@ -0,0 +1,74 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import io.lettuce.core.RedisCommandExecutionException;
import io.lettuce.core.ScriptOutputType;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.time.Instant;
import java.util.List;
import java.util.UUID;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.redis.ClusterLuaScript;
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
class MessagesCacheGetItemsScriptTest {
@RegisterExtension
static final RedisClusterExtension REDIS_CLUSTER_EXTENSION = RedisClusterExtension.builder().build();
@Test
void testCacheGetItemsScript() throws Exception {
final MessagesCacheInsertScript insertScript = new MessagesCacheInsertScript(
REDIS_CLUSTER_EXTENSION.getRedisCluster());
final UUID destinationUuid = UUID.randomUUID();
final byte deviceId = 1;
final String serverGuid = UUID.randomUUID().toString();
final MessageProtos.Envelope envelope1 = MessageProtos.Envelope.newBuilder()
.setServerTimestamp(Instant.now().getEpochSecond())
.setServerGuid(serverGuid)
.build();
insertScript.execute(destinationUuid, deviceId, envelope1);
final MessagesCacheGetItemsScript getItemsScript = new MessagesCacheGetItemsScript(
REDIS_CLUSTER_EXTENSION.getRedisCluster());
final List<byte[]> messageAndScores = getItemsScript.execute(destinationUuid, deviceId, 1, -1)
.block(Duration.ofSeconds(1));
assertNotNull(messageAndScores);
assertEquals(2, messageAndScores.size());
final MessageProtos.Envelope resultEnvelope = MessageProtos.Envelope.parseFrom(
messageAndScores.getFirst());
assertEquals(serverGuid, resultEnvelope.getServerGuid());
}
@Test
void testCacheGetItemsInvalidParameter() throws Exception {
final ClusterLuaScript getItemsScript = ClusterLuaScript.fromResource(REDIS_CLUSTER_EXTENSION.getRedisCluster(),
"lua/get_items.lua", ScriptOutputType.OBJECT);
final byte[] fakeKey = new byte[]{1};
final Exception e = assertThrows(RedisCommandExecutionException.class,
() -> getItemsScript.executeBinaryReactive(List.of(fakeKey, fakeKey),
List.of("1".getBytes(StandardCharsets.UTF_8)))
.next()
.block(Duration.ofSeconds(1)));
assertEquals("ERR afterMessageId is required", e.getMessage());
}
}

View File

@ -0,0 +1,45 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import static org.junit.jupiter.api.Assertions.assertEquals;
import java.time.Instant;
import java.util.UUID;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
class MessagesCacheInsertScriptTest {
@RegisterExtension
static final RedisClusterExtension REDIS_CLUSTER_EXTENSION = RedisClusterExtension.builder().build();
@Test
void testCacheInsertScript() throws Exception {
final MessagesCacheInsertScript insertScript = new MessagesCacheInsertScript(
REDIS_CLUSTER_EXTENSION.getRedisCluster());
final UUID destinationUuid = UUID.randomUUID();
final byte deviceId = 1;
final MessageProtos.Envelope envelope1 = MessageProtos.Envelope.newBuilder()
.setServerTimestamp(Instant.now().getEpochSecond())
.setServerGuid(UUID.randomUUID().toString())
.build();
assertEquals(1, insertScript.execute(destinationUuid, deviceId, envelope1));
final MessageProtos.Envelope envelope2 = MessageProtos.Envelope.newBuilder()
.setServerTimestamp(Instant.now().getEpochSecond())
.setServerGuid(UUID.randomUUID().toString())
.build();
assertEquals(2, insertScript.execute(destinationUuid, deviceId, envelope2));
assertEquals(1, insertScript.execute(destinationUuid, deviceId, envelope1),
"Repeated with same guid should have same message ID");
}
}

View File

@ -0,0 +1,74 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import static org.junit.jupiter.api.Assertions.assertEquals;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
import org.whispersystems.textsecuregcm.util.Pair;
class MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScriptTest {
@RegisterExtension
static final RedisClusterExtension REDIS_CLUSTER_EXTENSION = RedisClusterExtension.builder().build();
@ParameterizedTest
@MethodSource
void testInsert(final int count, final Map<AciServiceIdentifier, List<Byte>> destinations) throws Exception {
final MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript insertMrmScript = new MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript(
REDIS_CLUSTER_EXTENSION.getRedisCluster());
final byte[] sharedMrmKey = MessagesCache.getSharedMrmKey(UUID.randomUUID());
insertMrmScript.execute(sharedMrmKey,
MessagesCacheTest.generateRandomMrmMessage(destinations));
final int totalDevices = destinations.values().stream().mapToInt(List::size).sum();
final long hashFieldCount = REDIS_CLUSTER_EXTENSION.getRedisCluster()
.withBinaryCluster(conn -> conn.sync().hlen(sharedMrmKey));
assertEquals(totalDevices + 1, hashFieldCount);
}
public static List<Arguments> testInsert() {
final Map<AciServiceIdentifier, List<Byte>> singleAccount = Map.of(
new AciServiceIdentifier(UUID.randomUUID()), List.of((byte) 1, (byte) 2));
final List<Arguments> testCases = new ArrayList<>();
testCases.add(Arguments.of(1, singleAccount));
for (int j = 1000; j <= 30000; j += 1000) {
final Map<Integer, List<Byte>> deviceLists = new HashMap<>();
final Map<AciServiceIdentifier, List<Byte>> manyAccounts = IntStream.range(0, j)
.mapToObj(i -> {
final int deviceCount = 1 + i % 5;
final List<Byte> devices = deviceLists.computeIfAbsent(deviceCount, count -> IntStream.rangeClosed(1, count)
.mapToObj(v -> (byte) v)
.toList());
return new Pair<>(new AciServiceIdentifier(UUID.randomUUID()), devices);
})
.collect(Collectors.toMap(Pair::first, Pair::second));
testCases.add(Arguments.of(j, manyAccounts));
}
return testCases;
}
}

View File

@ -0,0 +1,52 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import static org.junit.jupiter.api.Assertions.assertEquals;
import java.time.Instant;
import java.util.List;
import java.util.UUID;
import java.util.concurrent.TimeUnit;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
class MessagesCacheRemoveByGuidScriptTest {
@RegisterExtension
static final RedisClusterExtension REDIS_CLUSTER_EXTENSION = RedisClusterExtension.builder().build();
@Test
void testCacheRemoveByGuid() throws Exception {
final MessagesCacheInsertScript insertScript = new MessagesCacheInsertScript(
REDIS_CLUSTER_EXTENSION.getRedisCluster());
final UUID destinationUuid = UUID.randomUUID();
final byte deviceId = 1;
final UUID serverGuid = UUID.randomUUID();
final MessageProtos.Envelope envelope1 = MessageProtos.Envelope.newBuilder()
.setServerTimestamp(Instant.now().getEpochSecond())
.setServerGuid(serverGuid.toString())
.build();
insertScript.execute(destinationUuid, deviceId, envelope1);
final MessagesCacheRemoveByGuidScript removeByGuidScript = new MessagesCacheRemoveByGuidScript(
REDIS_CLUSTER_EXTENSION.getRedisCluster());
final List<byte[]> removedMessages = removeByGuidScript.execute(destinationUuid, deviceId,
List.of(serverGuid)).get(1, TimeUnit.SECONDS);
assertEquals(1, removedMessages.size());
final MessageProtos.Envelope resultMessage = MessageProtos.Envelope.parseFrom(
removedMessages.getFirst());
assertEquals(serverGuid, UUID.fromString(resultMessage.getServerGuid()));
}
}

View File

@ -0,0 +1,50 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import static org.junit.jupiter.api.Assertions.assertEquals;
import java.time.Duration;
import java.time.Instant;
import java.util.Collections;
import java.util.List;
import java.util.UUID;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
class MessagesCacheRemoveQueueScriptTest {
@RegisterExtension
static final RedisClusterExtension REDIS_CLUSTER_EXTENSION = RedisClusterExtension.builder().build();
@Test
void testCacheRemoveQueueScript() throws Exception {
final MessagesCacheInsertScript insertScript = new MessagesCacheInsertScript(
REDIS_CLUSTER_EXTENSION.getRedisCluster());
final UUID destinationUuid = UUID.randomUUID();
final byte deviceId = 1;
final MessageProtos.Envelope envelope1 = MessageProtos.Envelope.newBuilder()
.setServerTimestamp(Instant.now().getEpochSecond())
.setServerGuid(UUID.randomUUID().toString())
.build();
insertScript.execute(destinationUuid, deviceId, envelope1);
final MessagesCacheRemoveQueueScript removeScript = new MessagesCacheRemoveQueueScript(
REDIS_CLUSTER_EXTENSION.getRedisCluster());
final List<byte[]> messagesToCheckForMrmKeys = removeScript.execute(destinationUuid, deviceId,
Collections.emptyList())
.block(Duration.ofSeconds(1));
assertEquals(1, messagesToCheckForMrmKeys.size());
}
}

View File

@ -0,0 +1,124 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import static org.junit.jupiter.api.Assertions.assertEquals;
import io.lettuce.core.cluster.SlotHash;
import java.time.Duration;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.UUID;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.junit.jupiter.params.provider.ValueSource;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
import org.whispersystems.textsecuregcm.util.Pair;
import reactor.core.publisher.Flux;
import reactor.util.function.Tuples;
class MessagesCacheRemoveRecipientViewFromMrmDataScriptTest {
@RegisterExtension
static final RedisClusterExtension REDIS_CLUSTER_EXTENSION = RedisClusterExtension.builder().build();
@ParameterizedTest
@MethodSource
void testUpdateSingleKey(final Map<AciServiceIdentifier, List<Byte>> destinations) throws Exception {
final MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript insertMrmScript = new MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript(
REDIS_CLUSTER_EXTENSION.getRedisCluster());
final byte[] sharedMrmKey = MessagesCache.getSharedMrmKey(UUID.randomUUID());
insertMrmScript.execute(sharedMrmKey,
MessagesCacheTest.generateRandomMrmMessage(destinations));
final MessagesCacheRemoveRecipientViewFromMrmDataScript removeRecipientViewFromMrmDataScript = new MessagesCacheRemoveRecipientViewFromMrmDataScript(
REDIS_CLUSTER_EXTENSION.getRedisCluster());
final long keysRemoved = Objects.requireNonNull(Flux.fromIterable(destinations.entrySet())
.flatMap(e -> Flux.fromStream(e.getValue().stream().map(deviceId -> Tuples.of(e.getKey(), deviceId))))
.flatMap(aciServiceIdentifierByteTuple -> removeRecipientViewFromMrmDataScript.execute(List.of(sharedMrmKey),
aciServiceIdentifierByteTuple.getT1(), aciServiceIdentifierByteTuple.getT2()))
.reduce(Long::sum)
.block(Duration.ofSeconds(35)));
assertEquals(1, keysRemoved);
final long keyExists = REDIS_CLUSTER_EXTENSION.getRedisCluster()
.withBinaryCluster(conn -> conn.sync().exists(sharedMrmKey));
assertEquals(0, keyExists);
}
public static List<Map<AciServiceIdentifier, List<Byte>>> testUpdateSingleKey() {
final Map<AciServiceIdentifier, List<Byte>> singleAccount = Map.of(
new AciServiceIdentifier(UUID.randomUUID()), List.of((byte) 1, (byte) 2));
final List<Map<AciServiceIdentifier, List<Byte>>> testCases = new ArrayList<>();
testCases.add(singleAccount);
// Generate a more, from smallish to very large
for (int j = 1000; j <= 81000; j *= 3) {
final Map<Integer, List<Byte>> deviceLists = new HashMap<>();
final Map<AciServiceIdentifier, List<Byte>> manyAccounts = IntStream.range(0, j)
.mapToObj(i -> {
final int deviceCount = 1 + i % 5;
final List<Byte> devices = deviceLists.computeIfAbsent(deviceCount, count -> IntStream.rangeClosed(1, count)
.mapToObj(v -> (byte) v)
.toList());
return new Pair<>(new AciServiceIdentifier(UUID.randomUUID()), devices);
})
.collect(Collectors.toMap(Pair::first, Pair::second));
testCases.add(manyAccounts);
}
return testCases;
}
@ParameterizedTest
@ValueSource(ints = {1, 10, 100, 1000, 10000})
void testUpdateManyKeys(int keyCount) throws Exception {
final List<byte[]> sharedMrmKeys = new ArrayList<>(keyCount);
final AciServiceIdentifier aciServiceIdentifier = new AciServiceIdentifier(UUID.randomUUID());
final byte deviceId = 1;
for (int i = 0; i < keyCount; i++) {
final MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript insertMrmScript = new MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript(
REDIS_CLUSTER_EXTENSION.getRedisCluster());
final byte[] sharedMrmKey = MessagesCache.getSharedMrmKey(UUID.randomUUID());
insertMrmScript.execute(sharedMrmKey,
MessagesCacheTest.generateRandomMrmMessage(aciServiceIdentifier, deviceId));
sharedMrmKeys.add(sharedMrmKey);
}
final MessagesCacheRemoveRecipientViewFromMrmDataScript removeRecipientViewFromMrmDataScript = new MessagesCacheRemoveRecipientViewFromMrmDataScript(
REDIS_CLUSTER_EXTENSION.getRedisCluster());
final long keysRemoved = Objects.requireNonNull(Flux.fromIterable(sharedMrmKeys)
.collectMultimap(SlotHash::getSlot)
.flatMapMany(slotsAndKeys -> Flux.fromIterable(slotsAndKeys.values()))
.flatMap(keys -> removeRecipientViewFromMrmDataScript.execute(keys, aciServiceIdentifier, deviceId))
.reduce(Long::sum)
.block(Duration.ofSeconds(5)));
assertEquals(sharedMrmKeys.size(), keysRemoved);
}
}

View File

@ -5,6 +5,7 @@
package org.whispersystems.textsecuregcm.storage; package org.whispersystems.textsecuregcm.storage;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTimeoutPreemptively; import static org.junit.jupiter.api.Assertions.assertTimeoutPreemptively;
@ -25,6 +26,8 @@ import io.lettuce.core.cluster.api.async.RedisAdvancedClusterAsyncCommands;
import io.lettuce.core.cluster.api.reactive.RedisAdvancedClusterReactiveCommands; import io.lettuce.core.cluster.api.reactive.RedisAdvancedClusterReactiveCommands;
import io.lettuce.core.protocol.AsyncCommand; import io.lettuce.core.protocol.AsyncCommand;
import io.lettuce.core.protocol.RedisCommand; import io.lettuce.core.protocol.RedisCommand;
import java.io.ByteArrayOutputStream;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.time.Clock; import java.time.Clock;
import java.time.Duration; import java.time.Duration;
@ -32,9 +35,12 @@ import java.time.Instant;
import java.time.ZoneId; import java.time.ZoneId;
import java.util.ArrayDeque; import java.util.ArrayDeque;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
import java.util.Deque; import java.util.Deque;
import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map;
import java.util.Optional; import java.util.Optional;
import java.util.Random; import java.util.Random;
import java.util.UUID; import java.util.UUID;
@ -42,11 +48,13 @@ import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService; import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors; import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Consumer; import java.util.function.Consumer;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.RandomStringUtils; import org.apache.commons.lang3.RandomStringUtils;
import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
@ -57,7 +65,12 @@ import org.junit.jupiter.api.extension.RegisterExtension;
import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource; import org.junit.jupiter.params.provider.ValueSource;
import org.reactivestreams.Publisher; import org.reactivestreams.Publisher;
import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage;
import org.signal.libsignal.protocol.ServiceId;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicMessagesConfiguration;
import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster; import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension; import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
import org.whispersystems.textsecuregcm.tests.util.RedisClusterHelper; import org.whispersystems.textsecuregcm.tests.util.RedisClusterHelper;
@ -83,6 +96,8 @@ class MessagesCacheTest {
private Scheduler messageDeliveryScheduler; private Scheduler messageDeliveryScheduler;
private MessagesCache messagesCache; private MessagesCache messagesCache;
private DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager;
private static final UUID DESTINATION_UUID = UUID.randomUUID(); private static final UUID DESTINATION_UUID = UUID.randomUUID();
private static final byte DESTINATION_DEVICE_ID = 7; private static final byte DESTINATION_DEVICE_ID = 7;
@ -95,11 +110,16 @@ class MessagesCacheTest {
connection.sync().upstream().commands().configSet("notify-keyspace-events", "K$glz"); connection.sync().upstream().commands().configSet("notify-keyspace-events", "K$glz");
}); });
final DynamicConfiguration dynamicConfiguration = mock(DynamicConfiguration.class);
when(dynamicConfiguration.getMessagesConfiguration()).thenReturn(new DynamicMessagesConfiguration(true, true));
dynamicConfigurationManager = mock(DynamicConfigurationManager.class);
when(dynamicConfigurationManager.getConfiguration()).thenReturn(dynamicConfiguration);
sharedExecutorService = Executors.newSingleThreadExecutor(); sharedExecutorService = Executors.newSingleThreadExecutor();
resubscribeRetryExecutorService = Executors.newSingleThreadScheduledExecutor(); resubscribeRetryExecutorService = Executors.newSingleThreadScheduledExecutor();
messageDeliveryScheduler = Schedulers.newBoundedElastic(10, 10_000, "messageDelivery"); messageDeliveryScheduler = Schedulers.newBoundedElastic(10, 10_000, "messageDelivery");
messagesCache = new MessagesCache(REDIS_CLUSTER_EXTENSION.getRedisCluster(), sharedExecutorService, messagesCache = new MessagesCache(REDIS_CLUSTER_EXTENSION.getRedisCluster(), sharedExecutorService,
messageDeliveryScheduler, sharedExecutorService, Clock.systemUTC()); messageDeliveryScheduler, sharedExecutorService, Clock.systemUTC(), dynamicConfigurationManager);
messagesCache.start(); messagesCache.start();
} }
@ -148,10 +168,10 @@ class MessagesCacheTest {
final MessageProtos.Envelope message = generateRandomMessage(messageGuid, sealedSender); final MessageProtos.Envelope message = generateRandomMessage(messageGuid, sealedSender);
messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message); messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message);
final Optional<MessageProtos.Envelope> maybeRemovedMessage = messagesCache.remove(DESTINATION_UUID, final Optional<RemovedMessage> maybeRemovedMessage = messagesCache.remove(DESTINATION_UUID,
DESTINATION_DEVICE_ID, messageGuid).get(5, TimeUnit.SECONDS); DESTINATION_DEVICE_ID, messageGuid).get(5, TimeUnit.SECONDS);
assertEquals(Optional.of(message), maybeRemovedMessage); assertEquals(Optional.of(RemovedMessage.fromEnvelope(message)), maybeRemovedMessage);
} }
@ParameterizedTest @ParameterizedTest
@ -181,11 +201,11 @@ class MessagesCacheTest {
message); message);
} }
final List<MessageProtos.Envelope> removedMessages = messagesCache.remove(DESTINATION_UUID, DESTINATION_DEVICE_ID, final List<RemovedMessage> removedMessages = messagesCache.remove(DESTINATION_UUID, DESTINATION_DEVICE_ID,
messagesToRemove.stream().map(message -> UUID.fromString(message.getServerGuid())) messagesToRemove.stream().map(message -> UUID.fromString(message.getServerGuid()))
.collect(Collectors.toList())).get(5, TimeUnit.SECONDS); .collect(Collectors.toList())).get(5, TimeUnit.SECONDS);
assertEquals(messagesToRemove, removedMessages); assertEquals(messagesToRemove.stream().map(RemovedMessage::fromEnvelope).toList(), removedMessages);
assertEquals(messagesToPreserve, assertEquals(messagesToPreserve,
messagesCache.getMessagesToPersist(DESTINATION_UUID, DESTINATION_DEVICE_ID, messageCount)); messagesCache.getMessagesToPersist(DESTINATION_UUID, DESTINATION_DEVICE_ID, messageCount));
} }
@ -283,7 +303,8 @@ class MessagesCacheTest {
} }
final MessagesCache messagesCache = new MessagesCache(REDIS_CLUSTER_EXTENSION.getRedisCluster(), final MessagesCache messagesCache = new MessagesCache(REDIS_CLUSTER_EXTENSION.getRedisCluster(),
sharedExecutorService, messageDeliveryScheduler, sharedExecutorService, cacheClock); sharedExecutorService, messageDeliveryScheduler, sharedExecutorService, cacheClock,
dynamicConfigurationManager);
final List<MessageProtos.Envelope> actualMessages = Flux.from( final List<MessageProtos.Envelope> actualMessages = Flux.from(
messagesCache.get(DESTINATION_UUID, DESTINATION_DEVICE_ID)) messagesCache.get(DESTINATION_UUID, DESTINATION_DEVICE_ID))
@ -320,7 +341,7 @@ class MessagesCacheTest {
@ParameterizedTest @ParameterizedTest
@ValueSource(booleans = {true, false}) @ValueSource(booleans = {true, false})
void testClearQueueForDevice(final boolean sealedSender) { void testClearQueueForDevice(final boolean sealedSender) {
final int messageCount = 100; final int messageCount = 1000;
for (final byte deviceId : new byte[]{DESTINATION_DEVICE_ID, DESTINATION_DEVICE_ID + 1}) { for (final byte deviceId : new byte[]{DESTINATION_DEVICE_ID, DESTINATION_DEVICE_ID + 1}) {
for (int i = 0; i < messageCount; i++) { for (int i = 0; i < messageCount; i++) {
@ -340,7 +361,7 @@ class MessagesCacheTest {
@ParameterizedTest @ParameterizedTest
@ValueSource(booleans = {true, false}) @ValueSource(booleans = {true, false})
void testClearQueueForAccount(final boolean sealedSender) { void testClearQueueForAccount(final boolean sealedSender) {
final int messageCount = 100; final int messageCount = 1000;
for (final byte deviceId : new byte[]{DESTINATION_DEVICE_ID, DESTINATION_DEVICE_ID + 1}) { for (final byte deviceId : new byte[]{DESTINATION_DEVICE_ID, DESTINATION_DEVICE_ID + 1}) {
for (int i = 0; i < messageCount; i++) { for (int i = 0; i < messageCount; i++) {
@ -542,6 +563,57 @@ class MessagesCacheTest {
}); });
} }
@Test
void testMultiRecipientMessage() throws Exception {
final UUID destinationUuid = UUID.randomUUID();
final byte deviceId = 1;
final UUID mrmGuid = UUID.randomUUID();
final SealedSenderMultiRecipientMessage mrm = generateRandomMrmMessage(
new AciServiceIdentifier(destinationUuid), deviceId);
final byte[] sharedMrmDataKey = messagesCache.insertSharedMultiRecipientMessagePayload(mrmGuid, mrm);
final UUID guid = UUID.randomUUID();
final MessageProtos.Envelope message = generateRandomMessage(guid, true)
.toBuilder()
// clear some things added by the helper
.clearServerGuid()
// mrm views phase 1: messages have content
.setContent(
ByteString.copyFrom(mrm.messageForRecipient(mrm.getRecipients().get(new ServiceId.Aci(destinationUuid)))))
.setSharedMrmKey(ByteString.copyFrom(sharedMrmDataKey))
.build();
messagesCache.insert(guid, destinationUuid, deviceId, message);
assertEquals(1L, (long) REDIS_CLUSTER_EXTENSION.getRedisCluster()
.withBinaryCluster(conn -> conn.sync().exists(MessagesCache.getSharedMrmKey(mrmGuid))));
final List<MessageProtos.Envelope> messages = get(destinationUuid, deviceId, 1);
assertEquals(1, messages.size());
assertEquals(guid, UUID.fromString(messages.getFirst().getServerGuid()));
assertFalse(messages.getFirst().hasSharedMrmKey());
final SealedSenderMultiRecipientMessage.Recipient recipient = mrm.getRecipients()
.get(new ServiceId.Aci(destinationUuid));
assertArrayEquals(mrm.messageForRecipient(recipient), messages.getFirst().getContent().toByteArray());
final Optional<RemovedMessage> removedMessage = messagesCache.remove(destinationUuid, deviceId, guid)
.join();
assertTrue(removedMessage.isPresent());
assertEquals(guid, UUID.fromString(removedMessage.get().serverGuid().toString()));
assertTrue(get(destinationUuid, deviceId, 1).isEmpty());
// updating the shared MRM data is purely async, so we just wait for it
assertTimeoutPreemptively(Duration.ofSeconds(1), () -> {
boolean exists;
do {
exists = 1 == REDIS_CLUSTER_EXTENSION.getRedisCluster()
.withBinaryCluster(conn -> conn.sync().exists(MessagesCache.getSharedMrmKey(mrmGuid)));
} while (exists);
});
}
private List<MessageProtos.Envelope> get(final UUID destinationUuid, final byte destinationDeviceId, private List<MessageProtos.Envelope> get(final UUID destinationUuid, final byte destinationDeviceId,
final int messageCount) { final int messageCount) {
return Flux.from(messagesCache.get(destinationUuid, destinationDeviceId)) return Flux.from(messagesCache.get(destinationUuid, destinationDeviceId))
@ -573,7 +645,7 @@ class MessagesCacheTest {
messageDeliveryScheduler = Schedulers.newBoundedElastic(10, 10_000, "messageDelivery"); messageDeliveryScheduler = Schedulers.newBoundedElastic(10, 10_000, "messageDelivery");
messagesCache = new MessagesCache(mockCluster, mock(ExecutorService.class), messageDeliveryScheduler, messagesCache = new MessagesCache(mockCluster, mock(ExecutorService.class), messageDeliveryScheduler,
Executors.newSingleThreadExecutor(), Clock.systemUTC()); Executors.newSingleThreadExecutor(), Clock.systemUTC(), mock(DynamicConfigurationManager.class));
} }
@AfterEach @AfterEach
@ -755,18 +827,85 @@ class MessagesCacheTest {
private MessageProtos.Envelope generateRandomMessage(final UUID messageGuid, final boolean sealedSender, private MessageProtos.Envelope generateRandomMessage(final UUID messageGuid, final boolean sealedSender,
final long timestamp) { final long timestamp) {
final MessageProtos.Envelope.Builder envelopeBuilder = MessageProtos.Envelope.newBuilder() final MessageProtos.Envelope.Builder envelopeBuilder = MessageProtos.Envelope.newBuilder()
.setTimestamp(timestamp) .setClientTimestamp(timestamp)
.setServerTimestamp(timestamp) .setServerTimestamp(timestamp)
.setContent(ByteString.copyFromUtf8(RandomStringUtils.randomAlphanumeric(256))) .setContent(ByteString.copyFromUtf8(RandomStringUtils.randomAlphanumeric(256)))
.setType(MessageProtos.Envelope.Type.CIPHERTEXT) .setType(MessageProtos.Envelope.Type.CIPHERTEXT)
.setServerGuid(messageGuid.toString()) .setServerGuid(messageGuid.toString())
.setDestinationUuid(UUID.randomUUID().toString()); .setDestinationServiceId(UUID.randomUUID().toString());
if (!sealedSender) { if (!sealedSender) {
envelopeBuilder.setSourceDevice(random.nextInt(Device.MAXIMUM_DEVICE_ID) + 1) envelopeBuilder.setSourceDevice(random.nextInt(Device.MAXIMUM_DEVICE_ID) + 1)
.setSourceUuid(UUID.randomUUID().toString()); .setSourceServiceId(UUID.randomUUID().toString());
} }
return envelopeBuilder.build(); return envelopeBuilder.build();
} }
static SealedSenderMultiRecipientMessage generateRandomMrmMessage(
Map<AciServiceIdentifier, List<Byte>> destinations) {
try {
final ByteBuffer prefix = ByteBuffer.allocate(7);
prefix.put((byte) 0x23); // version
writeVarint(prefix, destinations.size()); // recipient count
prefix.flip();
List<ByteBuffer> recipients = new ArrayList<>(destinations.size());
for (Map.Entry<AciServiceIdentifier, List<Byte>> aciServiceIdentifierAndDeviceIds : destinations.entrySet()) {
final AciServiceIdentifier destination = aciServiceIdentifierAndDeviceIds.getKey();
final List<Byte> deviceIds = aciServiceIdentifierAndDeviceIds.getValue();
assert deviceIds.size() < 255;
final ByteBuffer recipient = ByteBuffer.allocate(17 + 3 * deviceIds.size() + 48);
recipient.put(destination.toFixedWidthByteArray());
for (int i = 0; i < deviceIds.size(); i++) {
final int hasMore = i == deviceIds.size() - 1 ? 0x0000 : 0x8000;
recipient.put(new byte[]{deviceIds.get(i)}); // device ID
recipient.putShort((short) ((100 + deviceIds.get(i)) | hasMore)); // registration ID
}
final byte[] keyMaterial = new byte[48];
ThreadLocalRandom.current().nextBytes(keyMaterial);
recipient.put(keyMaterial);
recipients.add(recipient);
}
final byte[] commonPayload = new byte[64];
ThreadLocalRandom.current().nextBytes(commonPayload);
final ByteArrayOutputStream baos = new ByteArrayOutputStream();
baos.write(prefix.array(), 0, prefix.limit());
for (ByteBuffer recipient : recipients) {
baos.write(recipient.array());
}
baos.write(commonPayload);
return SealedSenderMultiRecipientMessage.parse(baos.toByteArray());
} catch (Exception e) {
throw new RuntimeException(e);
}
}
static SealedSenderMultiRecipientMessage generateRandomMrmMessage(AciServiceIdentifier destination,
byte... deviceIds) {
final Map<AciServiceIdentifier, List<Byte>> destinations = new HashMap<>();
destinations.put(destination, Arrays.asList(ArrayUtils.toObject(deviceIds)));
return generateRandomMrmMessage(destinations);
}
private static void writeVarint(ByteBuffer bb, long n) {
while (n >= 0x80) {
bb.put((byte) (n & 0x7F | 0x80));
n = n >> 7;
}
bb.put((byte) (n & 0x7F));
}
} }

View File

@ -6,8 +6,6 @@
package org.whispersystems.textsecuregcm.storage; package org.whispersystems.textsecuregcm.storage;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import com.google.protobuf.ByteString; import com.google.protobuf.ByteString;
import java.time.Duration; import java.time.Duration;
@ -31,7 +29,6 @@ import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.storage.DynamoDbExtensionSchema.Tables; import org.whispersystems.textsecuregcm.storage.DynamoDbExtensionSchema.Tables;
import org.whispersystems.textsecuregcm.tests.util.DevicesHelper; import org.whispersystems.textsecuregcm.tests.util.DevicesHelper;
import org.whispersystems.textsecuregcm.tests.util.MessageHelper; import org.whispersystems.textsecuregcm.tests.util.MessageHelper;
import org.whispersystems.textsecuregcm.util.SystemMapper;
import reactor.core.publisher.Flux; import reactor.core.publisher.Flux;
import reactor.test.StepVerifier; import reactor.test.StepVerifier;
@ -47,31 +44,31 @@ class MessagesDynamoDbTest {
final long serverTimestamp = System.currentTimeMillis(); final long serverTimestamp = System.currentTimeMillis();
MessageProtos.Envelope.Builder builder = MessageProtos.Envelope.newBuilder(); MessageProtos.Envelope.Builder builder = MessageProtos.Envelope.newBuilder();
builder.setType(MessageProtos.Envelope.Type.UNIDENTIFIED_SENDER); builder.setType(MessageProtos.Envelope.Type.UNIDENTIFIED_SENDER);
builder.setTimestamp(123456789L); builder.setClientTimestamp(123456789L);
builder.setContent(ByteString.copyFrom(new byte[]{(byte) 0xDE, (byte) 0xAD, (byte) 0xBE, (byte) 0xEF})); builder.setContent(ByteString.copyFrom(new byte[]{(byte) 0xDE, (byte) 0xAD, (byte) 0xBE, (byte) 0xEF}));
builder.setServerGuid(UUID.randomUUID().toString()); builder.setServerGuid(UUID.randomUUID().toString());
builder.setServerTimestamp(serverTimestamp); builder.setServerTimestamp(serverTimestamp);
builder.setDestinationUuid(UUID.randomUUID().toString()); builder.setDestinationServiceId(UUID.randomUUID().toString());
MESSAGE1 = builder.build(); MESSAGE1 = builder.build();
builder.setType(MessageProtos.Envelope.Type.CIPHERTEXT); builder.setType(MessageProtos.Envelope.Type.CIPHERTEXT);
builder.setSourceUuid(UUID.randomUUID().toString()); builder.setSourceServiceId(UUID.randomUUID().toString());
builder.setSourceDevice(1); builder.setSourceDevice(1);
builder.setContent(ByteString.copyFromUtf8("MOO")); builder.setContent(ByteString.copyFromUtf8("MOO"));
builder.setServerGuid(UUID.randomUUID().toString()); builder.setServerGuid(UUID.randomUUID().toString());
builder.setServerTimestamp(serverTimestamp + 1); builder.setServerTimestamp(serverTimestamp + 1);
builder.setDestinationUuid(UUID.randomUUID().toString()); builder.setDestinationServiceId(UUID.randomUUID().toString());
MESSAGE2 = builder.build(); MESSAGE2 = builder.build();
builder.setType(MessageProtos.Envelope.Type.UNIDENTIFIED_SENDER); builder.setType(MessageProtos.Envelope.Type.UNIDENTIFIED_SENDER);
builder.clearSourceUuid(); builder.clearSourceDevice();
builder.clearSourceDevice(); builder.clearSourceDevice();
builder.setContent(ByteString.copyFromUtf8("COW")); builder.setContent(ByteString.copyFromUtf8("COW"));
builder.setServerGuid(UUID.randomUUID().toString()); builder.setServerGuid(UUID.randomUUID().toString());
builder.setServerTimestamp(serverTimestamp); // Test same millisecond arrival for two different messages builder.setServerTimestamp(serverTimestamp); // Test same millisecond arrival for two different messages
builder.setDestinationUuid(UUID.randomUUID().toString()); builder.setDestinationServiceId(UUID.randomUUID().toString());
MESSAGE3 = builder.build(); MESSAGE3 = builder.build();
} }

View File

@ -35,7 +35,7 @@ class MessagesManagerTest {
void insert() { void insert() {
final UUID sourceAci = UUID.randomUUID(); final UUID sourceAci = UUID.randomUUID();
final Envelope message = Envelope.newBuilder() final Envelope message = Envelope.newBuilder()
.setSourceUuid(sourceAci.toString()) .setSourceServiceId(sourceAci.toString())
.build(); .build();
final UUID destinationUuid = UUID.randomUUID(); final UUID destinationUuid = UUID.randomUUID();
@ -45,7 +45,7 @@ class MessagesManagerTest {
verify(reportMessageManager).store(eq(sourceAci.toString()), any(UUID.class)); verify(reportMessageManager).store(eq(sourceAci.toString()), any(UUID.class));
final Envelope syncMessage = Envelope.newBuilder(message) final Envelope syncMessage = Envelope.newBuilder(message)
.setSourceUuid(destinationUuid.toString()) .setSourceServiceId(destinationUuid.toString())
.build(); .build();
messagesManager.insert(destinationUuid, Device.PRIMARY_ID, syncMessage); messagesManager.insert(destinationUuid, Device.PRIMARY_ID, syncMessage);

View File

@ -17,11 +17,11 @@ public class MessageHelper {
return MessageProtos.Envelope.newBuilder() return MessageProtos.Envelope.newBuilder()
.setServerGuid(UUID.randomUUID().toString()) .setServerGuid(UUID.randomUUID().toString())
.setType(MessageProtos.Envelope.Type.CIPHERTEXT) .setType(MessageProtos.Envelope.Type.CIPHERTEXT)
.setTimestamp(timestamp) .setClientTimestamp(timestamp)
.setServerTimestamp(0) .setServerTimestamp(0)
.setSourceUuid(senderUuid.toString()) .setSourceServiceId(senderUuid.toString())
.setSourceDevice(senderDeviceId) .setSourceDevice(senderDeviceId)
.setDestinationUuid(destinationUuid.toString()) .setDestinationServiceId(destinationUuid.toString())
.setContent(ByteString.copyFrom(content.getBytes(StandardCharsets.UTF_8))) .setContent(ByteString.copyFrom(content.getBytes(StandardCharsets.UTF_8)))
.build(); .build();
} }

View File

@ -44,6 +44,7 @@ import org.junit.jupiter.params.provider.CsvSource;
import org.mockito.ArgumentCaptor; import org.mockito.ArgumentCaptor;
import org.mockito.stubbing.Answer; import org.mockito.stubbing.Answer;
import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice; import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
import org.whispersystems.textsecuregcm.limits.MessageDeliveryLoopMonitor; import org.whispersystems.textsecuregcm.limits.MessageDeliveryLoopMonitor;
@ -55,6 +56,7 @@ import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.ClientReleaseManager; import org.whispersystems.textsecuregcm.storage.ClientReleaseManager;
import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
import org.whispersystems.textsecuregcm.storage.DynamoDbExtension; import org.whispersystems.textsecuregcm.storage.DynamoDbExtension;
import org.whispersystems.textsecuregcm.storage.DynamoDbExtensionSchema.Tables; import org.whispersystems.textsecuregcm.storage.DynamoDbExtensionSchema.Tables;
import org.whispersystems.textsecuregcm.storage.MessagesCache; import org.whispersystems.textsecuregcm.storage.MessagesCache;
@ -85,6 +87,8 @@ class WebSocketConnectionIntegrationTest {
private Scheduler messageDeliveryScheduler; private Scheduler messageDeliveryScheduler;
private ClientReleaseManager clientReleaseManager; private ClientReleaseManager clientReleaseManager;
private DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager;
private long serialTimestamp = System.currentTimeMillis(); private long serialTimestamp = System.currentTimeMillis();
@BeforeEach @BeforeEach
@ -92,8 +96,10 @@ class WebSocketConnectionIntegrationTest {
sharedExecutorService = Executors.newSingleThreadExecutor(); sharedExecutorService = Executors.newSingleThreadExecutor();
scheduledExecutorService = Executors.newSingleThreadScheduledExecutor(); scheduledExecutorService = Executors.newSingleThreadScheduledExecutor();
messageDeliveryScheduler = Schedulers.newBoundedElastic(10, 10_000, "messageDelivery"); messageDeliveryScheduler = Schedulers.newBoundedElastic(10, 10_000, "messageDelivery");
dynamicConfigurationManager = mock(DynamicConfigurationManager.class);
when(dynamicConfigurationManager.getConfiguration()).thenReturn(new DynamicConfiguration());
messagesCache = new MessagesCache(REDIS_CLUSTER_EXTENSION.getRedisCluster(), sharedExecutorService, messagesCache = new MessagesCache(REDIS_CLUSTER_EXTENSION.getRedisCluster(), sharedExecutorService,
messageDeliveryScheduler, sharedExecutorService, Clock.systemUTC()); messageDeliveryScheduler, sharedExecutorService, Clock.systemUTC(), dynamicConfigurationManager);
messagesDynamoDb = new MessagesDynamoDb(DYNAMO_DB_EXTENSION.getDynamoDbClient(), messagesDynamoDb = new MessagesDynamoDb(DYNAMO_DB_EXTENSION.getDynamoDbClient(),
DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(), Tables.MESSAGES.tableName(), Duration.ofDays(7), DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(), Tables.MESSAGES.tableName(), Duration.ofDays(7),
sharedExecutorService); sharedExecutorService);
@ -381,12 +387,12 @@ class WebSocketConnectionIntegrationTest {
final long timestamp = serialTimestamp++; final long timestamp = serialTimestamp++;
return MessageProtos.Envelope.newBuilder() return MessageProtos.Envelope.newBuilder()
.setTimestamp(timestamp) .setClientTimestamp(timestamp)
.setServerTimestamp(timestamp) .setServerTimestamp(timestamp)
.setContent(ByteString.copyFromUtf8(RandomStringUtils.randomAlphanumeric(256))) .setContent(ByteString.copyFromUtf8(RandomStringUtils.randomAlphanumeric(256)))
.setType(MessageProtos.Envelope.Type.CIPHERTEXT) .setType(MessageProtos.Envelope.Type.CIPHERTEXT)
.setServerGuid(messageGuid.toString()) .setServerGuid(messageGuid.toString())
.setDestinationUuid(UUID.randomUUID().toString()) .setDestinationServiceId(UUID.randomUUID().toString())
.build(); .build();
} }

View File

@ -48,7 +48,6 @@ import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.IntStream; import java.util.stream.IntStream;
import java.util.stream.Stream; import java.util.stream.Stream;
import org.eclipse.jetty.websocket.api.UpgradeRequest; import org.eclipse.jetty.websocket.api.UpgradeRequest;
import org.jetbrains.annotations.NotNull;
import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
@ -297,19 +296,19 @@ class WebSocketConnectionTest {
final Envelope firstMessage = Envelope.newBuilder() final Envelope firstMessage = Envelope.newBuilder()
.setServerGuid(UUID.randomUUID().toString()) .setServerGuid(UUID.randomUUID().toString())
.setSourceUuid(UUID.randomUUID().toString()) .setSourceServiceId(UUID.randomUUID().toString())
.setDestinationUuid(accountUuid.toString()) .setDestinationServiceId(accountUuid.toString())
.setUpdatedPni(UUID.randomUUID().toString()) .setUpdatedPni(UUID.randomUUID().toString())
.setTimestamp(System.currentTimeMillis()) .setClientTimestamp(System.currentTimeMillis())
.setSourceDevice(1) .setSourceDevice(1)
.setType(Envelope.Type.CIPHERTEXT) .setType(Envelope.Type.CIPHERTEXT)
.build(); .build();
final Envelope secondMessage = Envelope.newBuilder() final Envelope secondMessage = Envelope.newBuilder()
.setServerGuid(UUID.randomUUID().toString()) .setServerGuid(UUID.randomUUID().toString())
.setSourceUuid(senderTwoUuid.toString()) .setSourceServiceId(senderTwoUuid.toString())
.setDestinationUuid(accountUuid.toString()) .setDestinationServiceId(accountUuid.toString())
.setTimestamp(System.currentTimeMillis()) .setClientTimestamp(System.currentTimeMillis())
.setSourceDevice(2) .setSourceDevice(2)
.setType(Envelope.Type.CIPHERTEXT) .setType(Envelope.Type.CIPHERTEXT)
.build(); .build();
@ -365,7 +364,7 @@ class WebSocketConnectionTest {
futures.get(0).completeExceptionally(new IOException()); futures.get(0).completeExceptionally(new IOException());
verify(receiptSender, times(1)).sendReceipt(eq(new AciServiceIdentifier(account.getUuid())), eq(deviceId), eq(new AciServiceIdentifier(senderTwoUuid)), verify(receiptSender, times(1)).sendReceipt(eq(new AciServiceIdentifier(account.getUuid())), eq(deviceId), eq(new AciServiceIdentifier(senderTwoUuid)),
eq(secondMessage.getTimestamp())); eq(secondMessage.getClientTimestamp()));
connection.stop(); connection.stop();
verify(client).close(anyInt(), anyString()); verify(client).close(anyInt(), anyString());
@ -616,10 +615,10 @@ class WebSocketConnectionTest {
final byte[] body = argument.get(); final byte[] body = argument.get();
try { try {
final Envelope envelope = Envelope.parseFrom(body); final Envelope envelope = Envelope.parseFrom(body);
if (!envelope.hasSourceUuid() || envelope.getSourceUuid().length() == 0) { if (!envelope.hasSourceServiceId() || envelope.getSourceServiceId().length() == 0) {
return false; return false;
} }
return envelope.getSourceUuid().equals(senderUuid.toString()); return envelope.getSourceServiceId().equals(senderUuid.toString());
} catch (InvalidProtocolBufferException e) { } catch (InvalidProtocolBufferException e) {
return false; return false;
} }
@ -627,7 +626,7 @@ class WebSocketConnectionTest {
verify(client).sendRequest(eq("PUT"), eq("/api/v1/queue/empty"), any(List.class), eq(Optional.empty())); verify(client).sendRequest(eq("PUT"), eq("/api/v1/queue/empty"), any(List.class), eq(Optional.empty()));
} }
private @NotNull WebSocketConnection webSocketConnection(final WebSocketClient client) { private WebSocketConnection webSocketConnection(final WebSocketClient client) {
return new WebSocketConnection(receiptSender, messagesManager, new MessageMetrics(), return new WebSocketConnection(receiptSender, messagesManager, new MessageMetrics(),
mock(PushNotificationManager.class), mock(PushNotificationScheduler.class), auth, client, mock(PushNotificationManager.class), mock(PushNotificationScheduler.class), auth, client,
retrySchedulingExecutor, Schedulers.immediate(), clientReleaseManager, mock(MessageDeliveryLoopMonitor.class)); retrySchedulingExecutor, Schedulers.immediate(), clientReleaseManager, mock(MessageDeliveryLoopMonitor.class));
@ -933,11 +932,11 @@ class WebSocketConnectionTest {
return Envelope.newBuilder() return Envelope.newBuilder()
.setServerGuid(UUID.randomUUID().toString()) .setServerGuid(UUID.randomUUID().toString())
.setType(Envelope.Type.CIPHERTEXT) .setType(Envelope.Type.CIPHERTEXT)
.setTimestamp(timestamp) .setClientTimestamp(timestamp)
.setServerTimestamp(0) .setServerTimestamp(0)
.setSourceUuid(senderUuid.toString()) .setSourceServiceId(senderUuid.toString())
.setSourceDevice(SOURCE_DEVICE_ID) .setSourceDevice(SOURCE_DEVICE_ID)
.setDestinationUuid(destinationUuid.toString()) .setDestinationServiceId(destinationUuid.toString())
.setContent(ByteString.copyFrom(content.getBytes(StandardCharsets.UTF_8))) .setContent(ByteString.copyFrom(content.getBytes(StandardCharsets.UTF_8)))
.build(); .build();
} }