0
0
mirror of https://github.com/signalapp/Signal-Server.git synced 2024-09-20 03:52:16 +02:00

Use destination service ID from the envelope when removing views from shared MRM data

This commit is contained in:
Chris Eager 2024-09-13 17:38:18 -05:00 committed by Chris Eager
parent 11691c3122
commit 374fe087bc
8 changed files with 68 additions and 52 deletions

View File

@ -70,7 +70,7 @@ public class MessageSender {
if (clientPresent) { if (clientPresent) {
messagesManager.insert(account.getUuid(), device.getId(), message.toBuilder().setEphemeral(true).build()); messagesManager.insert(account.getUuid(), device.getId(), message.toBuilder().setEphemeral(true).build());
} else { } else {
messagesManager.removeRecipientViewFromMrmData(account.getUuid(), device.getId(), message); messagesManager.removeRecipientViewFromMrmData(device.getId(), message);
} }
} else { } else {
messagesManager.insert(account.getUuid(), device.getId(), message); messagesManager.insert(account.getUuid(), device.getId(), message);

View File

@ -48,7 +48,6 @@ import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfigurati
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicMessagesConfiguration; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicMessagesConfiguration;
import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.experiment.Experiment; import org.whispersystems.textsecuregcm.experiment.Experiment;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier; import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil; import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
import org.whispersystems.textsecuregcm.redis.FaultTolerantPubSubConnection; import org.whispersystems.textsecuregcm.redis.FaultTolerantPubSubConnection;
@ -296,21 +295,25 @@ public class MessagesCache extends RedisClusterPubSubAdapter<String, String> imp
.thenApplyAsync(serialized -> { .thenApplyAsync(serialized -> {
final List<RemovedMessage> removedMessages = new ArrayList<>(serialized.size()); final List<RemovedMessage> removedMessages = new ArrayList<>(serialized.size());
final List<byte[]> sharedMrmKeysToUpdate = new ArrayList<>(); final Map<ServiceIdentifier, List<byte[]>> serviceIdentifierToMrmKeys = new HashMap<>();
for (final byte[] bytes : serialized) { for (final byte[] bytes : serialized) {
try { try {
final MessageProtos.Envelope envelope = MessageProtos.Envelope.parseFrom(bytes); final MessageProtos.Envelope envelope = MessageProtos.Envelope.parseFrom(bytes);
removedMessages.add(RemovedMessage.fromEnvelope(envelope)); removedMessages.add(RemovedMessage.fromEnvelope(envelope));
if (envelope.hasSharedMrmKey()) { if (envelope.hasSharedMrmKey()) {
sharedMrmKeysToUpdate.add(envelope.getSharedMrmKey().toByteArray()); serviceIdentifierToMrmKeys.computeIfAbsent(
ServiceIdentifier.valueOf(envelope.getDestinationServiceId()), ignored -> new ArrayList<>())
.add(envelope.getSharedMrmKey().toByteArray());
} }
} catch (final InvalidProtocolBufferException e) { } catch (final InvalidProtocolBufferException e) {
logger.warn("Failed to parse envelope", e); logger.warn("Failed to parse envelope", e);
} }
} }
removeRecipientViewFromMrmData(sharedMrmKeysToUpdate, destinationUuid, destinationDevice); serviceIdentifierToMrmKeys.forEach(
(serviceId, keysToUpdate) -> removeRecipientViewFromMrmData(keysToUpdate, serviceId, destinationDevice));
return removedMessages; return removedMessages;
}, messageDeletionExecutorService).whenComplete((ignored, throwable) -> sample.stop(removeByGuidTimer)); }, messageDeletionExecutorService).whenComplete((ignored, throwable) -> sample.stop(removeByGuidTimer));
@ -472,7 +475,8 @@ public class MessagesCache extends RedisClusterPubSubAdapter<String, String> imp
/** /**
* Makes a best-effort attempt at asynchronously updating (and removing when empty) the MRM data structure * Makes a best-effort attempt at asynchronously updating (and removing when empty) the MRM data structure
*/ */
void removeRecipientViewFromMrmData(final List<byte[]> sharedMrmKeys, final UUID accountUuid, final byte deviceId) { void removeRecipientViewFromMrmData(final List<byte[]> sharedMrmKeys, final ServiceIdentifier serviceIdentifier,
final byte deviceId) {
if (sharedMrmKeys.isEmpty()) { if (sharedMrmKeys.isEmpty()) {
return; return;
@ -483,7 +487,7 @@ public class MessagesCache extends RedisClusterPubSubAdapter<String, String> imp
.collectMultimap(SlotHash::getSlot) .collectMultimap(SlotHash::getSlot)
.flatMapMany(slotsAndKeys -> Flux.fromIterable(slotsAndKeys.values())) .flatMapMany(slotsAndKeys -> Flux.fromIterable(slotsAndKeys.values()))
.flatMap( .flatMap(
keys -> removeRecipientViewFromMrmDataScript.execute(keys, new AciServiceIdentifier(accountUuid), deviceId), keys -> removeRecipientViewFromMrmDataScript.execute(keys, serviceIdentifier, deviceId),
REMOVE_MRM_RECIPIENT_VIEW_CONCURRENCY) REMOVE_MRM_RECIPIENT_VIEW_CONCURRENCY)
.doOnNext(sharedMrmDataKeyRemovedCounter::increment) .doOnNext(sharedMrmDataKeyRemovedCounter::increment)
.onErrorResume(e -> { .onErrorResume(e -> {
@ -575,7 +579,7 @@ public class MessagesCache extends RedisClusterPubSubAdapter<String, String> imp
return Mono.empty(); return Mono.empty();
} }
final List<byte[]> mrmKeys = new ArrayList<>(messagesToProcess.size()); final Map<ServiceIdentifier, List<byte[]>> serviceIdentifierToMrmKeys = new HashMap<>();
final List<String> processedMessages = new ArrayList<>(messagesToProcess.size()); final List<String> processedMessages = new ArrayList<>(messagesToProcess.size());
for (byte[] serialized : messagesToProcess) { for (byte[] serialized : messagesToProcess) {
try { try {
@ -584,14 +588,17 @@ public class MessagesCache extends RedisClusterPubSubAdapter<String, String> imp
processedMessages.add(message.getServerGuid()); processedMessages.add(message.getServerGuid());
if (message.hasSharedMrmKey()) { if (message.hasSharedMrmKey()) {
mrmKeys.add(message.getSharedMrmKey().toByteArray()); serviceIdentifierToMrmKeys.computeIfAbsent(ServiceIdentifier.valueOf(message.getDestinationServiceId()),
ignored -> new ArrayList<>())
.add(message.getSharedMrmKey().toByteArray());
} }
} catch (final InvalidProtocolBufferException e) { } catch (final InvalidProtocolBufferException e) {
logger.warn("Failed to parse envelope", e); logger.warn("Failed to parse envelope", e);
} }
} }
removeRecipientViewFromMrmData(mrmKeys, destinationUuid, deviceId); serviceIdentifierToMrmKeys.forEach((serviceId, keysToUpdate) ->
removeRecipientViewFromMrmData(keysToUpdate, serviceId, deviceId));
return removeQueueScript.execute(destinationUuid, deviceId, processedMessages); return removeQueueScript.execute(destinationUuid, deviceId, processedMessages);
}) })

View File

@ -10,7 +10,7 @@ import java.io.IOException;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collection; import java.util.Collection;
import java.util.List; import java.util.List;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier; import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
import org.whispersystems.textsecuregcm.redis.ClusterLuaScript; import org.whispersystems.textsecuregcm.redis.ClusterLuaScript;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster; import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
import reactor.core.publisher.Mono; import reactor.core.publisher.Mono;
@ -30,8 +30,9 @@ class MessagesCacheRemoveRecipientViewFromMrmDataScript {
"lua/remove_recipient_view_from_mrm_data.lua", ScriptOutputType.INTEGER); "lua/remove_recipient_view_from_mrm_data.lua", ScriptOutputType.INTEGER);
} }
Mono<Long> execute(final Collection<byte[]> keysCollection, final AciServiceIdentifier serviceIdentifier, Mono<Long> execute(final Collection<byte[]> keysCollection, final ServiceIdentifier serviceIdentifier,
final byte deviceId) { final byte deviceId) {
final List<byte[]> keys = keysCollection instanceof List<byte[]> final List<byte[]> keys = keysCollection instanceof List<byte[]>
? (List<byte[]>) keysCollection ? (List<byte[]>) keysCollection
: new ArrayList<>(keysCollection); : new ArrayList<>(keysCollection);

View File

@ -24,6 +24,7 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil; import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
import org.whispersystems.textsecuregcm.util.Pair; import org.whispersystems.textsecuregcm.util.Pair;
import reactor.core.observability.micrometer.Micrometer; import reactor.core.observability.micrometer.Micrometer;
@ -214,10 +215,10 @@ public class MessagesManager {
/** /**
* Removes the recipient's view from shared MRM data if necessary * Removes the recipient's view from shared MRM data if necessary
*/ */
public void removeRecipientViewFromMrmData(final UUID destinationUuid, final byte destinationDeviceId, public void removeRecipientViewFromMrmData(final byte destinationDeviceId, final Envelope message) {
final Envelope message) {
if (message.hasSharedMrmKey()) { if (message.hasSharedMrmKey()) {
messagesCache.removeRecipientViewFromMrmData(List.of(message.getSharedMrmKey().toByteArray()), destinationUuid, messagesCache.removeRecipientViewFromMrmData(List.of(message.getSharedMrmKey().toByteArray()),
ServiceIdentifier.valueOf(message.getDestinationServiceId()),
destinationDeviceId); destinationDeviceId);
} }
} }

View File

@ -73,8 +73,7 @@ class MessageSenderTest {
MessageProtos.Envelope.class); MessageProtos.Envelope.class);
verify(messagesManager).insert(any(), anyByte(), envelopeArgumentCaptor.capture()); verify(messagesManager).insert(any(), anyByte(), envelopeArgumentCaptor.capture());
verify(messagesManager, never()).removeRecipientViewFromMrmData(any(), anyByte(), verify(messagesManager, never()).removeRecipientViewFromMrmData(anyByte(), any(MessageProtos.Envelope.class));
any(MessageProtos.Envelope.class));
assertTrue(envelopeArgumentCaptor.getValue().getEphemeral()); assertTrue(envelopeArgumentCaptor.getValue().getEphemeral());
@ -96,7 +95,7 @@ class MessageSenderTest {
} }
verify(messagesManager, never()).insert(any(), anyByte(), any()); verify(messagesManager, never()).insert(any(), anyByte(), any());
verify(messagesManager).removeRecipientViewFromMrmData(any(), anyByte(), any(MessageProtos.Envelope.class)); verify(messagesManager).removeRecipientViewFromMrmData(anyByte(), any(MessageProtos.Envelope.class));
verifyNoInteractions(pushNotificationManager); verifyNoInteractions(pushNotificationManager);
} }

View File

@ -8,6 +8,7 @@ package org.whispersystems.textsecuregcm.storage;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertThrows;
import io.lettuce.core.RedisCommandExecutionException;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.HashMap; import java.util.HashMap;
import java.util.List; import java.util.List;
@ -15,13 +16,13 @@ import java.util.Map;
import java.util.UUID; import java.util.UUID;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import java.util.stream.IntStream; import java.util.stream.IntStream;
import io.lettuce.core.RedisCommandExecutionException;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension; import org.junit.jupiter.api.extension.RegisterExtension;
import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.MethodSource;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier; import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension; import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
import org.whispersystems.textsecuregcm.util.Pair; import org.whispersystems.textsecuregcm.util.Pair;
@ -32,7 +33,7 @@ class MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScriptTest {
@ParameterizedTest @ParameterizedTest
@MethodSource @MethodSource
void testInsert(final int count, final Map<AciServiceIdentifier, List<Byte>> destinations) throws Exception { void testInsert(final int count, final Map<ServiceIdentifier, List<Byte>> destinations) throws Exception {
final MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript insertMrmScript = new MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript( final MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript insertMrmScript = new MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript(
REDIS_CLUSTER_EXTENSION.getRedisCluster()); REDIS_CLUSTER_EXTENSION.getRedisCluster());
@ -49,7 +50,7 @@ class MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScriptTest {
} }
public static List<Arguments> testInsert() { public static List<Arguments> testInsert() {
final Map<AciServiceIdentifier, List<Byte>> singleAccount = Map.of( final Map<ServiceIdentifier, List<Byte>> singleAccount = Map.of(
new AciServiceIdentifier(UUID.randomUUID()), List.of((byte) 1, (byte) 2)); new AciServiceIdentifier(UUID.randomUUID()), List.of((byte) 1, (byte) 2));
final List<Arguments> testCases = new ArrayList<>(); final List<Arguments> testCases = new ArrayList<>();
@ -58,7 +59,7 @@ class MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScriptTest {
for (int j = 1000; j <= 30000; j += 1000) { for (int j = 1000; j <= 30000; j += 1000) {
final Map<Integer, List<Byte>> deviceLists = new HashMap<>(); final Map<Integer, List<Byte>> deviceLists = new HashMap<>();
final Map<AciServiceIdentifier, List<Byte>> manyAccounts = IntStream.range(0, j) final Map<ServiceIdentifier, List<Byte>> manyAccounts = IntStream.range(0, j)
.mapToObj(i -> { .mapToObj(i -> {
final int deviceCount = 1 + i % 5; final int deviceCount = 1 + i % 5;
final List<Byte> devices = deviceLists.computeIfAbsent(deviceCount, count -> IntStream.rangeClosed(1, count) final List<Byte> devices = deviceLists.computeIfAbsent(deviceCount, count -> IntStream.rangeClosed(1, count)

View File

@ -22,6 +22,7 @@ import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.MethodSource;
import org.junit.jupiter.params.provider.ValueSource; import org.junit.jupiter.params.provider.ValueSource;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier; import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension; import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
import org.whispersystems.textsecuregcm.util.Pair; import org.whispersystems.textsecuregcm.util.Pair;
import reactor.core.publisher.Flux; import reactor.core.publisher.Flux;
@ -34,7 +35,7 @@ class MessagesCacheRemoveRecipientViewFromMrmDataScriptTest {
@ParameterizedTest @ParameterizedTest
@MethodSource @MethodSource
void testUpdateSingleKey(final Map<AciServiceIdentifier, List<Byte>> destinations) throws Exception { void testUpdateSingleKey(final Map<ServiceIdentifier, List<Byte>> destinations) throws Exception {
final MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript insertMrmScript = new MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript( final MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript insertMrmScript = new MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript(
REDIS_CLUSTER_EXTENSION.getRedisCluster()); REDIS_CLUSTER_EXTENSION.getRedisCluster());
@ -48,8 +49,8 @@ class MessagesCacheRemoveRecipientViewFromMrmDataScriptTest {
final long keysRemoved = Objects.requireNonNull(Flux.fromIterable(destinations.entrySet()) final long keysRemoved = Objects.requireNonNull(Flux.fromIterable(destinations.entrySet())
.flatMap(e -> Flux.fromStream(e.getValue().stream().map(deviceId -> Tuples.of(e.getKey(), deviceId)))) .flatMap(e -> Flux.fromStream(e.getValue().stream().map(deviceId -> Tuples.of(e.getKey(), deviceId))))
.flatMap(aciServiceIdentifierByteTuple -> removeRecipientViewFromMrmDataScript.execute(List.of(sharedMrmKey), .flatMap(serviceIdentifierByteTuple -> removeRecipientViewFromMrmDataScript.execute(List.of(sharedMrmKey),
aciServiceIdentifierByteTuple.getT1(), aciServiceIdentifierByteTuple.getT2())) serviceIdentifierByteTuple.getT1(), serviceIdentifierByteTuple.getT2()))
.reduce(Long::sum) .reduce(Long::sum)
.block(Duration.ofSeconds(35))); .block(Duration.ofSeconds(35)));
@ -60,18 +61,18 @@ class MessagesCacheRemoveRecipientViewFromMrmDataScriptTest {
assertEquals(0, keyExists); assertEquals(0, keyExists);
} }
public static List<Map<AciServiceIdentifier, List<Byte>>> testUpdateSingleKey() { public static List<Map<ServiceIdentifier, List<Byte>>> testUpdateSingleKey() {
final Map<AciServiceIdentifier, List<Byte>> singleAccount = Map.of( final Map<ServiceIdentifier, List<Byte>> singleAccount = Map.of(
new AciServiceIdentifier(UUID.randomUUID()), List.of((byte) 1, (byte) 2)); new AciServiceIdentifier(UUID.randomUUID()), List.of((byte) 1, (byte) 2));
final List<Map<AciServiceIdentifier, List<Byte>>> testCases = new ArrayList<>(); final List<Map<ServiceIdentifier, List<Byte>>> testCases = new ArrayList<>();
testCases.add(singleAccount); testCases.add(singleAccount);
// Generate a more, from smallish to very large // Generate a more, from smallish to very large
for (int j = 1000; j <= 81000; j *= 3) { for (int j = 1000; j <= 81000; j *= 3) {
final Map<Integer, List<Byte>> deviceLists = new HashMap<>(); final Map<Integer, List<Byte>> deviceLists = new HashMap<>();
final Map<AciServiceIdentifier, List<Byte>> manyAccounts = IntStream.range(0, j) final Map<ServiceIdentifier, List<Byte>> manyAccounts = IntStream.range(0, j)
.mapToObj(i -> { .mapToObj(i -> {
final int deviceCount = 1 + i % 5; final int deviceCount = 1 + i % 5;
final List<Byte> devices = deviceLists.computeIfAbsent(deviceCount, count -> IntStream.rangeClosed(1, count) final List<Byte> devices = deviceLists.computeIfAbsent(deviceCount, count -> IntStream.rangeClosed(1, count)
@ -93,7 +94,7 @@ class MessagesCacheRemoveRecipientViewFromMrmDataScriptTest {
void testUpdateManyKeys(int keyCount) throws Exception { void testUpdateManyKeys(int keyCount) throws Exception {
final List<byte[]> sharedMrmKeys = new ArrayList<>(keyCount); final List<byte[]> sharedMrmKeys = new ArrayList<>(keyCount);
final AciServiceIdentifier aciServiceIdentifier = new AciServiceIdentifier(UUID.randomUUID()); final ServiceIdentifier serviceIdentifier = new AciServiceIdentifier(UUID.randomUUID());
final byte deviceId = 1; final byte deviceId = 1;
for (int i = 0; i < keyCount; i++) { for (int i = 0; i < keyCount; i++) {
@ -103,7 +104,7 @@ class MessagesCacheRemoveRecipientViewFromMrmDataScriptTest {
final byte[] sharedMrmKey = MessagesCache.getSharedMrmKey(UUID.randomUUID()); final byte[] sharedMrmKey = MessagesCache.getSharedMrmKey(UUID.randomUUID());
insertMrmScript.execute(sharedMrmKey, insertMrmScript.execute(sharedMrmKey,
MessagesCacheTest.generateRandomMrmMessage(aciServiceIdentifier, deviceId)); MessagesCacheTest.generateRandomMrmMessage(serviceIdentifier, deviceId));
sharedMrmKeys.add(sharedMrmKey); sharedMrmKeys.add(sharedMrmKey);
} }
@ -114,7 +115,7 @@ class MessagesCacheRemoveRecipientViewFromMrmDataScriptTest {
final long keysRemoved = Objects.requireNonNull(Flux.fromIterable(sharedMrmKeys) final long keysRemoved = Objects.requireNonNull(Flux.fromIterable(sharedMrmKeys)
.collectMultimap(SlotHash::getSlot) .collectMultimap(SlotHash::getSlot)
.flatMapMany(slotsAndKeys -> Flux.fromIterable(slotsAndKeys.values())) .flatMapMany(slotsAndKeys -> Flux.fromIterable(slotsAndKeys.values()))
.flatMap(keys -> removeRecipientViewFromMrmDataScript.execute(keys, aciServiceIdentifier, deviceId)) .flatMap(keys -> removeRecipientViewFromMrmDataScript.execute(keys, serviceIdentifier, deviceId))
.reduce(Long::sum) .reduce(Long::sum)
.block(Duration.ofSeconds(5))); .block(Duration.ofSeconds(5)));

View File

@ -71,6 +71,7 @@ import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfigurati
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicMessagesConfiguration; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicMessagesConfiguration;
import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier; import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster; import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension; import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
import org.whispersystems.textsecuregcm.tests.util.RedisClusterHelper; import org.whispersystems.textsecuregcm.tests.util.RedisClusterHelper;
@ -565,11 +566,10 @@ class MessagesCacheTest {
@ParameterizedTest @ParameterizedTest
@ValueSource(booleans = {true, false}) @ValueSource(booleans = {true, false})
void testMultiRecipientMessage(final boolean sharedMrmKeyPresent) throws Exception { void testMultiRecipientMessage(final boolean sharedMrmKeyPresent) throws Exception {
final UUID destinationUuid = UUID.randomUUID(); final ServiceIdentifier destinationServiceId = new AciServiceIdentifier(UUID.randomUUID());
final byte deviceId = 1; final byte deviceId = 1;
final SealedSenderMultiRecipientMessage mrm = generateRandomMrmMessage( final SealedSenderMultiRecipientMessage mrm = generateRandomMrmMessage(destinationServiceId, deviceId);
new AciServiceIdentifier(destinationUuid), deviceId);
final byte[] sharedMrmDataKey; final byte[] sharedMrmDataKey;
if (sharedMrmKeyPresent) { if (sharedMrmKeyPresent) {
@ -579,35 +579,35 @@ class MessagesCacheTest {
} }
final UUID guid = UUID.randomUUID(); final UUID guid = UUID.randomUUID();
final MessageProtos.Envelope message = generateRandomMessage(guid, true) final MessageProtos.Envelope message = generateRandomMessage(guid, destinationServiceId, true)
.toBuilder() .toBuilder()
// clear some things added by the helper // clear some things added by the helper
.clearServerGuid() .clearServerGuid()
// mrm views phase 1: messages have content // mrm views phase 1: messages have content
.setContent( .setContent(
ByteString.copyFrom(mrm.messageForRecipient(mrm.getRecipients().get(new ServiceId.Aci(destinationUuid))))) ByteString.copyFrom(mrm.messageForRecipient(mrm.getRecipients().get(destinationServiceId.toLibsignal()))))
.setSharedMrmKey(ByteString.copyFrom(sharedMrmDataKey)) .setSharedMrmKey(ByteString.copyFrom(sharedMrmDataKey))
.build(); .build();
messagesCache.insert(guid, destinationUuid, deviceId, message); messagesCache.insert(guid, destinationServiceId.uuid(), deviceId, message);
assertEquals(sharedMrmKeyPresent ? 1 : 0, (long) REDIS_CLUSTER_EXTENSION.getRedisCluster() assertEquals(sharedMrmKeyPresent ? 1 : 0, (long) REDIS_CLUSTER_EXTENSION.getRedisCluster()
.withBinaryCluster(conn -> conn.sync().exists(sharedMrmDataKey))); .withBinaryCluster(conn -> conn.sync().exists(sharedMrmDataKey)));
final List<MessageProtos.Envelope> messages = get(destinationUuid, deviceId, 1); final List<MessageProtos.Envelope> messages = get(destinationServiceId.uuid(), deviceId, 1);
assertEquals(1, messages.size()); assertEquals(1, messages.size());
assertEquals(guid, UUID.fromString(messages.getFirst().getServerGuid())); assertEquals(guid, UUID.fromString(messages.getFirst().getServerGuid()));
assertFalse(messages.getFirst().hasSharedMrmKey()); assertFalse(messages.getFirst().hasSharedMrmKey());
final SealedSenderMultiRecipientMessage.Recipient recipient = mrm.getRecipients() final SealedSenderMultiRecipientMessage.Recipient recipient = mrm.getRecipients()
.get(new ServiceId.Aci(destinationUuid)); .get(destinationServiceId.toLibsignal());
assertArrayEquals(mrm.messageForRecipient(recipient), messages.getFirst().getContent().toByteArray()); assertArrayEquals(mrm.messageForRecipient(recipient), messages.getFirst().getContent().toByteArray());
final Optional<RemovedMessage> removedMessage = messagesCache.remove(destinationUuid, deviceId, guid) final Optional<RemovedMessage> removedMessage = messagesCache.remove(destinationServiceId.uuid(), deviceId, guid)
.join(); .join();
assertTrue(removedMessage.isPresent()); assertTrue(removedMessage.isPresent());
assertEquals(guid, UUID.fromString(removedMessage.get().serverGuid().toString())); assertEquals(guid, UUID.fromString(removedMessage.get().serverGuid().toString()));
assertTrue(get(destinationUuid, deviceId, 1).isEmpty()); assertTrue(get(destinationServiceId.uuid(), deviceId, 1).isEmpty());
// updating the shared MRM data is purely async, so we just wait for it // updating the shared MRM data is purely async, so we just wait for it
assertTimeoutPreemptively(Duration.ofSeconds(1), () -> { assertTimeoutPreemptively(Duration.ofSeconds(1), () -> {
@ -874,10 +874,17 @@ class MessagesCacheTest {
} }
private MessageProtos.Envelope generateRandomMessage(final UUID messageGuid, final boolean sealedSender) { private MessageProtos.Envelope generateRandomMessage(final UUID messageGuid, final boolean sealedSender) {
return generateRandomMessage(messageGuid, sealedSender, serialTimestamp++); return generateRandomMessage(messageGuid, new AciServiceIdentifier(UUID.randomUUID()), sealedSender,
serialTimestamp++);
} }
private MessageProtos.Envelope generateRandomMessage(final UUID messageGuid, final boolean sealedSender, private MessageProtos.Envelope generateRandomMessage(final UUID messageGuid,
final ServiceIdentifier destinationServiceId, final boolean sealedSender) {
return generateRandomMessage(messageGuid, destinationServiceId, sealedSender, serialTimestamp++);
}
private MessageProtos.Envelope generateRandomMessage(final UUID messageGuid,
final ServiceIdentifier destinationServiceId, final boolean sealedSender,
final long timestamp) { final long timestamp) {
final MessageProtos.Envelope.Builder envelopeBuilder = MessageProtos.Envelope.newBuilder() final MessageProtos.Envelope.Builder envelopeBuilder = MessageProtos.Envelope.newBuilder()
.setClientTimestamp(timestamp) .setClientTimestamp(timestamp)
@ -885,7 +892,7 @@ class MessagesCacheTest {
.setContent(ByteString.copyFromUtf8(RandomStringUtils.randomAlphanumeric(256))) .setContent(ByteString.copyFromUtf8(RandomStringUtils.randomAlphanumeric(256)))
.setType(MessageProtos.Envelope.Type.CIPHERTEXT) .setType(MessageProtos.Envelope.Type.CIPHERTEXT)
.setServerGuid(messageGuid.toString()) .setServerGuid(messageGuid.toString())
.setDestinationServiceId(UUID.randomUUID().toString()); .setDestinationServiceId(destinationServiceId.toServiceIdentifierString());
if (!sealedSender) { if (!sealedSender) {
envelopeBuilder.setSourceDevice(random.nextInt(Device.MAXIMUM_DEVICE_ID) + 1) envelopeBuilder.setSourceDevice(random.nextInt(Device.MAXIMUM_DEVICE_ID) + 1)
@ -896,8 +903,7 @@ class MessagesCacheTest {
} }
static SealedSenderMultiRecipientMessage generateRandomMrmMessage( static SealedSenderMultiRecipientMessage generateRandomMrmMessage(
Map<AciServiceIdentifier, List<Byte>> destinations) { Map<ServiceIdentifier, List<Byte>> destinations) {
try { try {
final ByteBuffer prefix = ByteBuffer.allocate(7); final ByteBuffer prefix = ByteBuffer.allocate(7);
@ -907,10 +913,10 @@ class MessagesCacheTest {
List<ByteBuffer> recipients = new ArrayList<>(destinations.size()); List<ByteBuffer> recipients = new ArrayList<>(destinations.size());
for (Map.Entry<AciServiceIdentifier, List<Byte>> aciServiceIdentifierAndDeviceIds : destinations.entrySet()) { for (Map.Entry<ServiceIdentifier, List<Byte>> serviceIdentifierAndDeviceIds : destinations.entrySet()) {
final AciServiceIdentifier destination = aciServiceIdentifierAndDeviceIds.getKey(); final ServiceIdentifier destination = serviceIdentifierAndDeviceIds.getKey();
final List<Byte> deviceIds = aciServiceIdentifierAndDeviceIds.getValue(); final List<Byte> deviceIds = serviceIdentifierAndDeviceIds.getValue();
assert deviceIds.size() < 255; assert deviceIds.size() < 255;
@ -946,10 +952,10 @@ class MessagesCacheTest {
} }
} }
static SealedSenderMultiRecipientMessage generateRandomMrmMessage(AciServiceIdentifier destination, static SealedSenderMultiRecipientMessage generateRandomMrmMessage(ServiceIdentifier destination,
byte... deviceIds) { byte... deviceIds) {
final Map<AciServiceIdentifier, List<Byte>> destinations = new HashMap<>(); final Map<ServiceIdentifier, List<Byte>> destinations = new HashMap<>();
destinations.put(destination, Arrays.asList(ArrayUtils.toObject(deviceIds))); destinations.put(destination, Arrays.asList(ArrayUtils.toObject(deviceIds)));
return generateRandomMrmMessage(destinations); return generateRandomMrmMessage(destinations);
} }