From f9f93c77e2576e11e58b6276039df29455a11637 Mon Sep 17 00:00:00 2001 From: Jon Chambers <63609320+jon-signal@users.noreply.github.com> Date: Mon, 20 Jul 2020 12:18:10 -0400 Subject: [PATCH] Use UUIDs instead of phone numbers as account identifiers in clustered message cache --- .../textsecuregcm/WhisperServerService.java | 2 +- .../controllers/AccountController.java | 4 +- .../controllers/DeviceController.java | 4 +- .../controllers/MessageController.java | 3 + .../textsecuregcm/push/WebsocketSender.java | 2 +- .../textsecuregcm/storage/MessagesCache.java | 26 +++--- .../storage/MessagesManager.java | 28 +++---- .../storage/RedisClusterMessagesCache.java | 79 ++++++++++--------- .../storage/UserMessagesCache.java | 15 ++-- .../websocket/DeadLetterHandler.java | 19 ++++- .../websocket/WebSocketConnection.java | 4 +- .../storage/AbstractMessagesCacheTest.java | 41 +++++----- .../RedisClusterMessagesCacheTest.java | 11 ++- .../controllers/DeviceControllerTest.java | 3 +- .../controllers/MessageControllerTest.java | 10 +-- .../websocket/WebSocketConnectionTest.java | 12 ++- 16 files changed, 148 insertions(+), 115 deletions(-) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java index b84b9760..47222da9 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java @@ -314,7 +314,7 @@ public class WhisperServerService extends Application maybeExistingAccount = accounts.get(number); + Device device = new Device(); device.setId(Device.MASTER_ID); device.setAuthenticationCredentials(new AuthenticationCredentials(password)); @@ -643,7 +645,7 @@ public class AccountController { directoryQueue.deleteRegisteredUser(account.getUuid(), number); } - messagesManager.clear(number); + messagesManager.clear(number, maybeExistingAccount.map(Account::getUuid).orElse(null)); pendingAccounts.remove(number); return account; diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java index f3f0398a..b5fc3147 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java @@ -117,7 +117,7 @@ public class DeviceController { directoryQueue.deleteRegisteredUser(account.getUuid(), account.getNumber()); } - messages.clear(account.getNumber(), deviceId); + messages.clear(account.getNumber(), account.getUuid(), deviceId); } @Timed @@ -205,7 +205,7 @@ public class DeviceController { device.setCreated(System.currentTimeMillis()); account.get().addDevice(device); - messages.clear(account.get().getNumber(), device.getId()); + messages.clear(account.get().getNumber(), account.get().getUuid(), device.getId()); accounts.update(account.get()); pendingDevices.remove(number); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java index a1a49fdc..4d9ea769 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java @@ -189,6 +189,7 @@ public class MessageController { } return messagesManager.getMessagesForDevice(account.getNumber(), + account.getUuid(), account.getAuthenticatedDevice().get().getId()); } @@ -203,6 +204,7 @@ public class MessageController { WebSocketConnection.messageTime.update(System.currentTimeMillis() - timestamp); Optional message = messagesManager.delete(account.getNumber(), + account.getUuid(), account.getAuthenticatedDevice().get().getId(), source, timestamp); @@ -222,6 +224,7 @@ public class MessageController { public void removePendingMessage(@Auth Account account, @PathParam("uuid") UUID uuid) { try { Optional message = messagesManager.delete(account.getNumber(), + account.getUuid(), account.getAuthenticatedDevice().get().getId(), uuid); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/push/WebsocketSender.java b/service/src/main/java/org/whispersystems/textsecuregcm/push/WebsocketSender.java index fe84ee61..fd2ca1cf 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/push/WebsocketSender.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/push/WebsocketSender.java @@ -96,7 +96,7 @@ public class WebsocketSender { WebsocketAddress address = new WebsocketAddress(account.getNumber(), device.getId()); - messagesManager.insert(account.getNumber(), device.getId(), message); + messagesManager.insert(account.getNumber(), account.getUuid(), device.getId(), message); pubSubManager.publish(address, PubSubMessage.newBuilder() .setType(PubSubMessage.Type.QUERY_DB) .build()); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java index 835453da..e7f1b85a 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java @@ -89,14 +89,14 @@ public class MessagesCache implements Managed, UserMessagesCache { } @Override - public long insert(UUID guid, String destination, long destinationDevice, Envelope message) { + public long insert(UUID guid, String destination, final UUID destinationUuid, long destinationDevice, Envelope message) { final Envelope messageWithGuid = message.toBuilder().setServerGuid(guid.toString()).build(); Timer.Context timer = insertTimer.time(); try { final long messageId = insertOperation.insert(guid, destination, destinationDevice, System.currentTimeMillis(), messageWithGuid); - insertExperiment.compareSupplierResultAsync(messageId, () -> clusterMessagesCache.insert(guid, destination, destinationDevice, message, messageId), experimentExecutor); + insertExperiment.compareSupplierResultAsync(messageId, () -> clusterMessagesCache.insert(guid, destination, destinationUuid, destinationDevice, message, messageId), experimentExecutor); return messageId; } finally { @@ -105,7 +105,7 @@ public class MessagesCache implements Managed, UserMessagesCache { } @Override - public Optional remove(String destination, long destinationDevice, long id) { + public Optional remove(String destination, final UUID destinationUuid, long destinationDevice, long id) { OutgoingMessageEntity removedMessageEntity = null; try (Jedis jedis = jedisPool.getWriteResource(); @@ -122,13 +122,13 @@ public class MessagesCache implements Managed, UserMessagesCache { final Optional maybeRemovedMessage = Optional.ofNullable(removedMessageEntity); - removeByIdExperiment.compareSupplierResultAsync(maybeRemovedMessage, () -> clusterMessagesCache.remove(destination, destinationDevice, id), experimentExecutor); + removeByIdExperiment.compareSupplierResultAsync(maybeRemovedMessage, () -> clusterMessagesCache.remove(destination, destinationUuid, destinationDevice, id), experimentExecutor); return maybeRemovedMessage; } @Override - public Optional remove(String destination, long destinationDevice, String sender, long timestamp) { + public Optional remove(String destination, final UUID destinationUuid, long destinationDevice, String sender, long timestamp) { OutgoingMessageEntity removedMessageEntity = null; Timer.Context timer = removeByNameTimer.time(); @@ -146,13 +146,13 @@ public class MessagesCache implements Managed, UserMessagesCache { final Optional maybeRemovedMessage = Optional.ofNullable(removedMessageEntity); - removeBySenderExperiment.compareSupplierResultAsync(maybeRemovedMessage, () -> clusterMessagesCache.remove(destination, destinationDevice, sender, timestamp), experimentExecutor); + removeBySenderExperiment.compareSupplierResultAsync(maybeRemovedMessage, () -> clusterMessagesCache.remove(destination, destinationUuid, destinationDevice, sender, timestamp), experimentExecutor); return maybeRemovedMessage; } @Override - public Optional remove(String destination, long destinationDevice, UUID guid) { + public Optional remove(String destination, final UUID destinationUuid, long destinationDevice, UUID guid) { OutgoingMessageEntity removedMessageEntity = null; Timer.Context timer = removeByGuidTimer.time(); @@ -170,13 +170,13 @@ public class MessagesCache implements Managed, UserMessagesCache { final Optional maybeRemovedMessage = Optional.ofNullable(removedMessageEntity); - removeByUuidExperiment.compareSupplierResultAsync(maybeRemovedMessage, () -> clusterMessagesCache.remove(destination, destinationDevice, guid), experimentExecutor); + removeByUuidExperiment.compareSupplierResultAsync(maybeRemovedMessage, () -> clusterMessagesCache.remove(destination, destinationUuid, destinationDevice, guid), experimentExecutor); return maybeRemovedMessage; } @Override - public List get(String destination, long destinationDevice, int limit) { + public List get(String destination, final UUID destinationUuid, long destinationDevice, int limit) { Timer.Context timer = getTimer.time(); try { @@ -194,7 +194,7 @@ public class MessagesCache implements Managed, UserMessagesCache { } } - getMessagesExperiment.compareSupplierResultAsync(results, () -> clusterMessagesCache.get(destination, destinationDevice, limit), experimentExecutor); + getMessagesExperiment.compareSupplierResultAsync(results, () -> clusterMessagesCache.get(destination, destinationUuid, destinationDevice, limit), experimentExecutor); return results; } finally { @@ -203,12 +203,12 @@ public class MessagesCache implements Managed, UserMessagesCache { } @Override - public void clear(String destination) { + public void clear(String destination, final UUID destinationUuid) { Timer.Context timer = clearAccountTimer.time(); try { for (int i = 1; i < 255; i++) { - clear(destination, i); + clear(destination, destinationUuid, i); } } finally { timer.stop(); @@ -216,7 +216,7 @@ public class MessagesCache implements Managed, UserMessagesCache { } @Override - public void clear(String destination, long deviceId) { + public void clear(String destination, final UUID destinationUuid, long deviceId) { Timer.Context timer = clearDeviceTimer.time(); try { diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesManager.java index b87a254a..4a10fe88 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesManager.java @@ -34,34 +34,34 @@ public class MessagesManager { this.messagesCache = messagesCache; } - public void insert(String destination, long destinationDevice, Envelope message) { + public void insert(String destination, UUID destinationUuid, long destinationDevice, Envelope message) { UUID guid = UUID.randomUUID(); - messagesCache.insert(guid, destination, destinationDevice, message); + messagesCache.insert(guid, destination, destinationUuid, destinationDevice, message); } - public OutgoingMessageEntityList getMessagesForDevice(String destination, long destinationDevice) { + public OutgoingMessageEntityList getMessagesForDevice(String destination, UUID destinationUuid, long destinationDevice) { List messages = this.messages.load(destination, destinationDevice); if (messages.size() <= Messages.RESULT_SET_CHUNK_SIZE) { - messages.addAll(this.messagesCache.get(destination, destinationDevice, Messages.RESULT_SET_CHUNK_SIZE - messages.size())); + messages.addAll(this.messagesCache.get(destination, destinationUuid, destinationDevice, Messages.RESULT_SET_CHUNK_SIZE - messages.size())); } return new OutgoingMessageEntityList(messages, messages.size() >= Messages.RESULT_SET_CHUNK_SIZE); } - public void clear(String destination) { - this.messagesCache.clear(destination); + public void clear(String destination, UUID destinationUuid) { + this.messagesCache.clear(destination, destinationUuid); this.messages.clear(destination); } - public void clear(String destination, long deviceId) { - this.messagesCache.clear(destination, deviceId); + public void clear(String destination, UUID destinationUuid, long deviceId) { + this.messagesCache.clear(destination, destinationUuid, deviceId); this.messages.clear(destination, deviceId); } - public Optional delete(String destination, long destinationDevice, String source, long timestamp) + public Optional delete(String destination, UUID destinationUuid, long destinationDevice, String source, long timestamp) { - Optional removed = this.messagesCache.remove(destination, destinationDevice, source, timestamp); + Optional removed = this.messagesCache.remove(destination, destinationUuid, destinationDevice, source, timestamp); if (!removed.isPresent()) { removed = this.messages.remove(destination, destinationDevice, source, timestamp); @@ -73,8 +73,8 @@ public class MessagesManager { return removed; } - public Optional delete(String destination, long deviceId, UUID guid) { - Optional removed = this.messagesCache.remove(destination, deviceId, guid); + public Optional delete(String destination, UUID destinationUuid, long deviceId, UUID guid) { + Optional removed = this.messagesCache.remove(destination, destinationUuid, deviceId, guid); if (!removed.isPresent()) { removed = this.messages.remove(destination, guid); @@ -86,9 +86,9 @@ public class MessagesManager { return removed; } - public void delete(String destination, long deviceId, long id, boolean cached) { + public void delete(String destination, UUID destinationUuid, long deviceId, long id, boolean cached) { if (cached) { - this.messagesCache.remove(destination, deviceId, id); + this.messagesCache.remove(destination, destinationUuid, deviceId, id); cacheHitByIdMeter.mark(); } else { this.messages.remove(destination, id); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/RedisClusterMessagesCache.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/RedisClusterMessagesCache.java index c6e2c885..a5c32325 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/RedisClusterMessagesCache.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/RedisClusterMessagesCache.java @@ -53,28 +53,28 @@ public class RedisClusterMessagesCache implements UserMessagesCache { } @Override - public long insert(final UUID guid, final String destination, final long destinationDevice, final MessageProtos.Envelope message) { + public long insert(final UUID guid, final String destination, final UUID destinationUuid, final long destinationDevice, final MessageProtos.Envelope message) { final MessageProtos.Envelope messageWithGuid = message.toBuilder().setServerGuid(guid.toString()).build(); final String sender = message.hasSource() ? (message.getSource() + "::" + message.getTimestamp()) : "nil"; return (long)Metrics.timer(INSERT_TIMER_NAME).record(() -> - insertScript.executeBinary(List.of(getMessageQueueKey(destination, destinationDevice), - getMessageQueueMetadataKey(destination, destinationDevice), - getQueueIndexKey(destination, destinationDevice)), + insertScript.executeBinary(List.of(getMessageQueueKey(destinationUuid, destinationDevice), + getMessageQueueMetadataKey(destinationUuid, destinationDevice), + getQueueIndexKey(destinationUuid, destinationDevice)), List.of(messageWithGuid.toByteArray(), String.valueOf(message.getTimestamp()).getBytes(StandardCharsets.UTF_8), sender.getBytes(StandardCharsets.UTF_8), guid.toString().getBytes(StandardCharsets.UTF_8)))); } - public long insert(final UUID guid, final String destination, final long destinationDevice, final MessageProtos.Envelope message, final long messageId) { + public long insert(final UUID guid, final String destination, final UUID destinationUuid, final long destinationDevice, final MessageProtos.Envelope message, final long messageId) { final MessageProtos.Envelope messageWithGuid = message.toBuilder().setServerGuid(guid.toString()).build(); final String sender = message.hasSource() ? (message.getSource() + "::" + message.getTimestamp()) : "nil"; return (long)Metrics.timer(INSERT_TIMER_NAME).record(() -> - insertScript.executeBinary(List.of(getMessageQueueKey(destination, destinationDevice), - getMessageQueueMetadataKey(destination, destinationDevice), - getQueueIndexKey(destination, destinationDevice)), + insertScript.executeBinary(List.of(getMessageQueueKey(destinationUuid, destinationDevice), + getMessageQueueMetadataKey(destinationUuid, destinationDevice), + getQueueIndexKey(destinationUuid, destinationDevice)), List.of(messageWithGuid.toByteArray(), String.valueOf(message.getTimestamp()).getBytes(StandardCharsets.UTF_8), sender.getBytes(StandardCharsets.UTF_8), @@ -83,12 +83,12 @@ public class RedisClusterMessagesCache implements UserMessagesCache { } @Override - public Optional remove(final String destination, final long destinationDevice, final long id) { + public Optional remove(final String destination, final UUID destinationUuid, final long destinationDevice, final long id) { try { final byte[] serialized = (byte[])Metrics.timer(REMOVE_TIMER_NAME, REMOVE_METHOD_TAG, REMOVE_METHOD_ID).record(() -> - removeByIdScript.executeBinary(List.of(getMessageQueueKey(destination, destinationDevice), - getMessageQueueMetadataKey(destination, destinationDevice), - getQueueIndexKey(destination, destinationDevice)), + removeByIdScript.executeBinary(List.of(getMessageQueueKey(destinationUuid, destinationDevice), + getMessageQueueMetadataKey(destinationUuid, destinationDevice), + getQueueIndexKey(destinationUuid, destinationDevice)), List.of(String.valueOf(id).getBytes(StandardCharsets.UTF_8)))); @@ -103,12 +103,12 @@ public class RedisClusterMessagesCache implements UserMessagesCache { } @Override - public Optional remove(final String destination, final long destinationDevice, final String sender, final long timestamp) { + public Optional remove(final String destination, final UUID destinationUuid, final long destinationDevice, final String sender, final long timestamp) { try { final byte[] serialized = (byte[])Metrics.timer(REMOVE_TIMER_NAME, REMOVE_METHOD_TAG, REMOVE_METHOD_SENDER).record(() -> - removeBySenderScript.executeBinary(List.of(getMessageQueueKey(destination, destinationDevice), - getMessageQueueMetadataKey(destination, destinationDevice), - getQueueIndexKey(destination, destinationDevice)), + removeBySenderScript.executeBinary(List.of(getMessageQueueKey(destinationUuid, destinationDevice), + getMessageQueueMetadataKey(destinationUuid, destinationDevice), + getQueueIndexKey(destinationUuid, destinationDevice)), List.of((sender + "::" + timestamp).getBytes(StandardCharsets.UTF_8)))); if (serialized != null) { @@ -122,12 +122,12 @@ public class RedisClusterMessagesCache implements UserMessagesCache { } @Override - public Optional remove(final String destination, final long destinationDevice, final UUID guid) { + public Optional remove(final String destination, final UUID destinationUuid, final long destinationDevice, final UUID guid) { try { final byte[] serialized = (byte[])Metrics.timer(REMOVE_TIMER_NAME, REMOVE_METHOD_TAG, REMOVE_METHOD_UUID).record(() -> - removeByGuidScript.executeBinary(List.of(getMessageQueueKey(destination, destinationDevice), - getMessageQueueMetadataKey(destination, destinationDevice), - getQueueIndexKey(destination, destinationDevice)), + removeByGuidScript.executeBinary(List.of(getMessageQueueKey(destinationUuid, destinationDevice), + getMessageQueueMetadataKey(destinationUuid, destinationDevice), + getQueueIndexKey(destinationUuid, destinationDevice)), List.of(guid.toString().getBytes(StandardCharsets.UTF_8)))); if (serialized != null) { @@ -142,10 +142,10 @@ public class RedisClusterMessagesCache implements UserMessagesCache { @Override @SuppressWarnings("unchecked") - public List get(String destination, long destinationDevice, int limit) { + public List get(String destination, final UUID destinationUuid, long destinationDevice, int limit) { return Metrics.timer(GET_TIMER_NAME).record(() -> { - final List queueItems = (List)getItemsScript.executeBinary(List.of(getMessageQueueKey(destination, destinationDevice), - getPersistInProgressKey(destination, destinationDevice)), + final List queueItems = (List)getItemsScript.executeBinary(List.of(getMessageQueueKey(destinationUuid, destinationDevice), + getPersistInProgressKey(destinationUuid, destinationDevice)), List.of(String.valueOf(limit).getBytes())); final List messageEntities; @@ -173,34 +173,37 @@ public class RedisClusterMessagesCache implements UserMessagesCache { } @Override - public void clear(final String destination) { - for (int i = 1; i < 256; i++) { - clear(destination, i); + public void clear(final String destination, final UUID destinationUuid) { + // TODO Remove null check in a fully UUID-based world + if (destinationUuid != null) { + for (int i = 1; i < 256; i++) { + clear(destination, destinationUuid, i); + } } } @Override - public void clear(final String destination, final long deviceId) { + public void clear(final String destination, final UUID destinationUuid, final long deviceId) { Metrics.timer(CLEAR_TIMER_NAME).record(() -> - removeQueueScript.executeBinary(List.of(getMessageQueueKey(destination, deviceId), - getMessageQueueMetadataKey(destination, deviceId), - getQueueIndexKey(destination, deviceId)), + removeQueueScript.executeBinary(List.of(getMessageQueueKey(destinationUuid, deviceId), + getMessageQueueMetadataKey(destinationUuid, deviceId), + getQueueIndexKey(destinationUuid, deviceId)), Collections.emptyList())); } - private static byte[] getMessageQueueKey(final String address, final long deviceId) { - return ("user_queue::{" + address + "::" + deviceId + "}").getBytes(StandardCharsets.UTF_8); + private static byte[] getMessageQueueKey(final UUID accountUuid, final long deviceId) { + return ("user_queue::{" + accountUuid.toString() + "::" + deviceId + "}").getBytes(StandardCharsets.UTF_8); } - private static byte[] getMessageQueueMetadataKey(final String address, final long deviceId) { - return ("user_queue_metadata::{" + address + "::" + deviceId + "}").getBytes(StandardCharsets.UTF_8); + private static byte[] getMessageQueueMetadataKey(final UUID accountUuid, final long deviceId) { + return ("user_queue_metadata::{" + accountUuid.toString() + "::" + deviceId + "}").getBytes(StandardCharsets.UTF_8); } - private byte[] getQueueIndexKey(final String address, final long deviceId) { - return ("user_queue_index::{" + RedisClusterUtil.getMinimalHashTag(address + "::" + deviceId) + "}").getBytes(StandardCharsets.UTF_8); + private byte[] getQueueIndexKey(final UUID accountUuid, final long deviceId) { + return ("user_queue_index::{" + RedisClusterUtil.getMinimalHashTag(accountUuid.toString() + "::" + deviceId) + "}").getBytes(StandardCharsets.UTF_8); } - private byte[] getPersistInProgressKey(final String address, final long deviceId) { - return ("user_queue_persisting::{" + address + "::" + deviceId + "}").getBytes(StandardCharsets.UTF_8); + private byte[] getPersistInProgressKey(final UUID accountUuid, final long deviceId) { + return ("user_queue_persisting::{" + accountUuid.toString() + "::" + deviceId + "}").getBytes(StandardCharsets.UTF_8); } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/UserMessagesCache.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/UserMessagesCache.java index f3f81323..d70ebc00 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/UserMessagesCache.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/UserMessagesCache.java @@ -3,7 +3,6 @@ package org.whispersystems.textsecuregcm.storage; import com.google.common.annotations.VisibleForTesting; import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity; -import org.whispersystems.textsecuregcm.push.PushSender; import java.util.List; import java.util.Optional; @@ -25,17 +24,17 @@ public interface UserMessagesCache { envelope.hasServerTimestamp() ? envelope.getServerTimestamp() : 0); } - long insert(UUID guid, String destination, long destinationDevice, MessageProtos.Envelope message); + long insert(UUID guid, String destination, UUID destinationUuid, long destinationDevice, MessageProtos.Envelope message); - Optional remove(String destination, long destinationDevice, long id); + Optional remove(String destination, UUID destinationUuid, long destinationDevice, long id); - Optional remove(String destination, long destinationDevice, String sender, long timestamp); + Optional remove(String destination, UUID destinationUuid, long destinationDevice, String sender, long timestamp); - Optional remove(String destination, long destinationDevice, UUID guid); + Optional remove(String destination, UUID destinationUuid, long destinationDevice, UUID guid); - List get(String destination, long destinationDevice, int limit); + List get(String destination, UUID destinationUuid, long destinationDevice, int limit); - void clear(String destination); + void clear(String destination, UUID destinationUuid); - void clear(String destination, long deviceId); + void clear(String destination, UUID destinationUuid, long deviceId); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/DeadLetterHandler.java b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/DeadLetterHandler.java index 0151fea1..b15ce57c 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/DeadLetterHandler.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/DeadLetterHandler.java @@ -7,20 +7,26 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.whispersystems.dispatch.DispatchChannel; import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; +import org.whispersystems.textsecuregcm.storage.Account; +import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.MessagesManager; import org.whispersystems.textsecuregcm.storage.PubSubProtos.PubSubMessage; +import java.util.Optional; + import static com.codahale.metrics.MetricRegistry.name; public class DeadLetterHandler implements DispatchChannel { private final Logger logger = LoggerFactory.getLogger(DeadLetterHandler.class); + private final AccountsManager accountsManager; private final MessagesManager messagesManager; private final Counter deadLetterCounter = Metrics.counter(name(getClass(), "deadLetterCounter")); - public DeadLetterHandler(MessagesManager messagesManager) { + public DeadLetterHandler(AccountsManager accountsManager, MessagesManager messagesManager) { + this.accountsManager = accountsManager; this.messagesManager = messagesManager; } @@ -35,8 +41,15 @@ public class DeadLetterHandler implements DispatchChannel { switch (pubSubMessage.getType().getNumber()) { case PubSubMessage.Type.DELIVER_VALUE: - Envelope message = Envelope.parseFrom(pubSubMessage.getContent()); - messagesManager.insert(address.getNumber(), address.getDeviceId(), message); + Envelope message = Envelope.parseFrom(pubSubMessage.getContent()); + Optional maybeAccount = accountsManager.get(address.getNumber()); + + if (maybeAccount.isPresent()) { + messagesManager.insert(address.getNumber(), maybeAccount.get().getUuid(), address.getDeviceId(), message); + } else { + logger.warn("Dead letter for account that no longer exists: {}", address); + } + break; } } catch (InvalidProtocolBufferException e) { diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java index a7d916c7..fdeee7a5 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java @@ -129,7 +129,7 @@ public class WebSocketConnection implements DispatchChannel { } if (isSuccessResponse(response)) { - if (storedMessageInfo.isPresent()) messagesManager.delete(account.getNumber(), device.getId(), storedMessageInfo.get().id, storedMessageInfo.get().cached); + if (storedMessageInfo.isPresent()) messagesManager.delete(account.getNumber(), account.getUuid(), device.getId(), storedMessageInfo.get().id, storedMessageInfo.get().cached); if (!isReceipt) sendDeliveryReceiptFor(message); if (requery) processStoredMessages(); } else if (!isSuccessResponse(response) && !storedMessageInfo.isPresent()) { @@ -172,7 +172,7 @@ public class WebSocketConnection implements DispatchChannel { } private void processStoredMessages() { - OutgoingMessageEntityList messages = messagesManager.getMessagesForDevice(account.getNumber(), device.getId()); + OutgoingMessageEntityList messages = messagesManager.getMessagesForDevice(account.getNumber(), account.getUuid(), device.getId()); Iterator iterator = messages.getMessages().iterator(); while (iterator.hasNext()) { diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AbstractMessagesCacheTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AbstractMessagesCacheTest.java index 73480349..30b655c2 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AbstractMessagesCacheTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AbstractMessagesCacheTest.java @@ -24,6 +24,7 @@ import static org.junit.Assert.assertTrue; public abstract class AbstractMessagesCacheTest extends AbstractRedisClusterTest { private static final String DESTINATION_ACCOUNT = "+18005551234"; + private static final UUID DESTINATION_UUID = UUID.randomUUID(); private static final int DESTINATION_DEVICE_ID = 7; private final Random random = new Random(); @@ -35,7 +36,7 @@ public abstract class AbstractMessagesCacheTest extends AbstractRedisClusterTest @Parameters({"true", "false"}) public void testInsert(final boolean sealedSender) { final UUID messageGuid = UUID.randomUUID(); - assertTrue(getMessagesCache().insert(messageGuid, DESTINATION_ACCOUNT, DESTINATION_DEVICE_ID, generateRandomMessage(messageGuid, sealedSender)) > 0); + assertTrue(getMessagesCache().insert(messageGuid, DESTINATION_ACCOUNT, DESTINATION_UUID, DESTINATION_DEVICE_ID, generateRandomMessage(messageGuid, sealedSender)) > 0); } @Test @@ -44,12 +45,12 @@ public abstract class AbstractMessagesCacheTest extends AbstractRedisClusterTest final UUID messageGuid = UUID.randomUUID(); final MessageProtos.Envelope message = generateRandomMessage(messageGuid, sealedSender); - final long messageId = getMessagesCache().insert(messageGuid, DESTINATION_ACCOUNT, DESTINATION_DEVICE_ID, message); - final Optional maybeRemovedMessage = getMessagesCache().remove(DESTINATION_ACCOUNT, DESTINATION_DEVICE_ID, messageId); + final long messageId = getMessagesCache().insert(messageGuid, DESTINATION_ACCOUNT, DESTINATION_UUID, DESTINATION_DEVICE_ID, message); + final Optional maybeRemovedMessage = getMessagesCache().remove(DESTINATION_ACCOUNT, DESTINATION_UUID, DESTINATION_DEVICE_ID, messageId); assertTrue(maybeRemovedMessage.isPresent()); assertEquals(UserMessagesCache.constructEntityFromEnvelope(messageId, message), maybeRemovedMessage.get()); - assertEquals(Optional.empty(), getMessagesCache().remove(DESTINATION_ACCOUNT, DESTINATION_DEVICE_ID, messageId)); + assertEquals(Optional.empty(), getMessagesCache().remove(DESTINATION_ACCOUNT, DESTINATION_UUID, DESTINATION_DEVICE_ID, messageId)); } @Test @@ -57,12 +58,12 @@ public abstract class AbstractMessagesCacheTest extends AbstractRedisClusterTest final UUID messageGuid = UUID.randomUUID(); final MessageProtos.Envelope message = generateRandomMessage(messageGuid, false); - getMessagesCache().insert(messageGuid, DESTINATION_ACCOUNT, DESTINATION_DEVICE_ID, message); - final Optional maybeRemovedMessage = getMessagesCache().remove(DESTINATION_ACCOUNT, DESTINATION_DEVICE_ID, message.getSource(), message.getTimestamp()); + getMessagesCache().insert(messageGuid, DESTINATION_ACCOUNT, DESTINATION_UUID, DESTINATION_DEVICE_ID, message); + final Optional maybeRemovedMessage = getMessagesCache().remove(DESTINATION_ACCOUNT, DESTINATION_UUID, DESTINATION_DEVICE_ID, message.getSource(), message.getTimestamp()); assertTrue(maybeRemovedMessage.isPresent()); assertEquals(UserMessagesCache.constructEntityFromEnvelope(0, message), maybeRemovedMessage.get()); - assertEquals(Optional.empty(), getMessagesCache().remove(DESTINATION_ACCOUNT, DESTINATION_DEVICE_ID, message.getSource(), message.getTimestamp())); + assertEquals(Optional.empty(), getMessagesCache().remove(DESTINATION_ACCOUNT, DESTINATION_UUID, DESTINATION_DEVICE_ID, message.getSource(), message.getTimestamp())); } @Test @@ -70,12 +71,12 @@ public abstract class AbstractMessagesCacheTest extends AbstractRedisClusterTest public void testRemoveByUUID(final boolean sealedSender) { final UUID messageGuid = UUID.randomUUID(); - assertEquals(Optional.empty(), getMessagesCache().remove(DESTINATION_ACCOUNT, DESTINATION_DEVICE_ID, messageGuid)); + assertEquals(Optional.empty(), getMessagesCache().remove(DESTINATION_ACCOUNT, DESTINATION_UUID, DESTINATION_DEVICE_ID, messageGuid)); final MessageProtos.Envelope message = generateRandomMessage(messageGuid, sealedSender); - getMessagesCache().insert(messageGuid, DESTINATION_ACCOUNT, DESTINATION_DEVICE_ID, message); - final Optional maybeRemovedMessage = getMessagesCache().remove(DESTINATION_ACCOUNT, DESTINATION_DEVICE_ID, messageGuid); + getMessagesCache().insert(messageGuid, DESTINATION_ACCOUNT, DESTINATION_UUID, DESTINATION_DEVICE_ID, message); + final Optional maybeRemovedMessage = getMessagesCache().remove(DESTINATION_ACCOUNT, DESTINATION_UUID, DESTINATION_DEVICE_ID, messageGuid); assertTrue(maybeRemovedMessage.isPresent()); assertEquals(UserMessagesCache.constructEntityFromEnvelope(0, message), maybeRemovedMessage.get()); @@ -91,12 +92,12 @@ public abstract class AbstractMessagesCacheTest extends AbstractRedisClusterTest for (int i = 0; i < messageCount; i++) { final UUID messageGuid = UUID.randomUUID(); final MessageProtos.Envelope message = generateRandomMessage(messageGuid, sealedSender); - final long messageId = getMessagesCache().insert(messageGuid, DESTINATION_ACCOUNT, DESTINATION_DEVICE_ID, message); + final long messageId = getMessagesCache().insert(messageGuid, DESTINATION_ACCOUNT, DESTINATION_UUID, DESTINATION_DEVICE_ID, message); expectedMessages.add(UserMessagesCache.constructEntityFromEnvelope(messageId, message)); } - assertEquals(expectedMessages, getMessagesCache().get(DESTINATION_ACCOUNT, DESTINATION_DEVICE_ID, messageCount)); + assertEquals(expectedMessages, getMessagesCache().get(DESTINATION_ACCOUNT, DESTINATION_UUID, DESTINATION_DEVICE_ID, messageCount)); } @Test @@ -109,14 +110,14 @@ public abstract class AbstractMessagesCacheTest extends AbstractRedisClusterTest final UUID messageGuid = UUID.randomUUID(); final MessageProtos.Envelope message = generateRandomMessage(messageGuid, sealedSender); - getMessagesCache().insert(messageGuid, DESTINATION_ACCOUNT, deviceId, message); + getMessagesCache().insert(messageGuid, DESTINATION_ACCOUNT, DESTINATION_UUID, deviceId, message); } } - getMessagesCache().clear(DESTINATION_ACCOUNT, DESTINATION_DEVICE_ID); + getMessagesCache().clear(DESTINATION_ACCOUNT, DESTINATION_UUID, DESTINATION_DEVICE_ID); - assertEquals(Collections.emptyList(), getMessagesCache().get(DESTINATION_ACCOUNT, DESTINATION_DEVICE_ID, messageCount)); - assertEquals(messageCount, getMessagesCache().get(DESTINATION_ACCOUNT, DESTINATION_DEVICE_ID + 1, messageCount).size()); + assertEquals(Collections.emptyList(), getMessagesCache().get(DESTINATION_ACCOUNT, DESTINATION_UUID, DESTINATION_DEVICE_ID, messageCount)); + assertEquals(messageCount, getMessagesCache().get(DESTINATION_ACCOUNT, DESTINATION_UUID, DESTINATION_DEVICE_ID + 1, messageCount).size()); } @Test @@ -129,14 +130,14 @@ public abstract class AbstractMessagesCacheTest extends AbstractRedisClusterTest final UUID messageGuid = UUID.randomUUID(); final MessageProtos.Envelope message = generateRandomMessage(messageGuid, sealedSender); - getMessagesCache().insert(messageGuid, DESTINATION_ACCOUNT, deviceId, message); + getMessagesCache().insert(messageGuid, DESTINATION_ACCOUNT, DESTINATION_UUID, deviceId, message); } } - getMessagesCache().clear(DESTINATION_ACCOUNT); + getMessagesCache().clear(DESTINATION_ACCOUNT, DESTINATION_UUID); - assertEquals(Collections.emptyList(), getMessagesCache().get(DESTINATION_ACCOUNT, DESTINATION_DEVICE_ID, messageCount)); - assertEquals(Collections.emptyList(), getMessagesCache().get(DESTINATION_ACCOUNT, DESTINATION_DEVICE_ID + 1, messageCount)); + assertEquals(Collections.emptyList(), getMessagesCache().get(DESTINATION_ACCOUNT, DESTINATION_UUID, DESTINATION_DEVICE_ID, messageCount)); + assertEquals(Collections.emptyList(), getMessagesCache().get(DESTINATION_ACCOUNT, DESTINATION_UUID, DESTINATION_DEVICE_ID + 1, messageCount)); } protected MessageProtos.Envelope generateRandomMessage(final UUID messageGuid, final boolean sealedSender) { diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/RedisClusterMessagesCacheTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/RedisClusterMessagesCacheTest.java index dc167bf2..6efe93ef 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/RedisClusterMessagesCacheTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/RedisClusterMessagesCacheTest.java @@ -12,6 +12,7 @@ import static org.junit.Assert.assertEquals; public class RedisClusterMessagesCacheTest extends AbstractMessagesCacheTest { private static final String DESTINATION_ACCOUNT = "+18005551234"; + private static final UUID DESTINATION_UUID = UUID.randomUUID(); private static final int DESTINATION_DEVICE_ID = 7; private RedisClusterMessagesCache messagesCache; @@ -42,7 +43,13 @@ public class RedisClusterMessagesCacheTest extends AbstractMessagesCacheTest { final UUID secondMessageGuid = UUID.randomUUID(); final long messageId = 74; - assertEquals(messageId, messagesCache.insert(firstMessageGuid, DESTINATION_ACCOUNT, DESTINATION_DEVICE_ID, generateRandomMessage(firstMessageGuid, sealedSender), messageId)); - assertEquals(messageId + 1, messagesCache.insert(secondMessageGuid, DESTINATION_ACCOUNT, DESTINATION_DEVICE_ID, generateRandomMessage(secondMessageGuid, sealedSender))); + assertEquals(messageId, messagesCache.insert(firstMessageGuid, DESTINATION_ACCOUNT, DESTINATION_UUID, DESTINATION_DEVICE_ID, generateRandomMessage(firstMessageGuid, sealedSender), messageId)); + assertEquals(messageId + 1, messagesCache.insert(secondMessageGuid, DESTINATION_ACCOUNT, DESTINATION_UUID, DESTINATION_DEVICE_ID, generateRandomMessage(secondMessageGuid, sealedSender))); + } + + @Test + public void testClearNullUuid() { + // We're happy as long as this doesn't throw an exception + messagesCache.clear(DESTINATION_ACCOUNT, null); } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/DeviceControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/DeviceControllerTest.java index 68148f62..c9b1fe56 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/DeviceControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/DeviceControllerTest.java @@ -113,6 +113,7 @@ public class DeviceControllerTest { when(account.getNextDeviceId()).thenReturn(42L); when(account.getNumber()).thenReturn(AuthHelper.VALID_NUMBER); + when(account.getUuid()).thenReturn(AuthHelper.VALID_UUID); // when(maxedAccount.getActiveDeviceCount()).thenReturn(6); when(account.getAuthenticatedDevice()).thenReturn(Optional.of(masterDevice)); when(account.isEnabled()).thenReturn(false); @@ -144,7 +145,7 @@ public class DeviceControllerTest { assertThat(response.getDeviceId()).isEqualTo(42L); verify(pendingDevicesManager).remove(AuthHelper.VALID_NUMBER); - verify(messagesManager).clear(eq(AuthHelper.VALID_NUMBER), eq(42L)); + verify(messagesManager).clear(eq(AuthHelper.VALID_NUMBER), eq(AuthHelper.VALID_UUID), eq(42L)); } @Test diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/MessageControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/MessageControllerTest.java index c3674a72..4a71d6b7 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/MessageControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/MessageControllerTest.java @@ -257,7 +257,7 @@ public class MessageControllerTest { OutgoingMessageEntityList messagesList = new OutgoingMessageEntityList(messages, false); - when(messagesManager.getMessagesForDevice(eq(AuthHelper.VALID_NUMBER), eq(1L))).thenReturn(messagesList); + when(messagesManager.getMessagesForDevice(eq(AuthHelper.VALID_NUMBER), eq(AuthHelper.VALID_UUID), eq(1L))).thenReturn(messagesList); OutgoingMessageEntityList response = resources.getJerseyTest().target("/v1/messages/") @@ -294,7 +294,7 @@ public class MessageControllerTest { OutgoingMessageEntityList messagesList = new OutgoingMessageEntityList(messages, false); - when(messagesManager.getMessagesForDevice(eq(AuthHelper.VALID_NUMBER), eq(1L))).thenReturn(messagesList); + when(messagesManager.getMessagesForDevice(eq(AuthHelper.VALID_NUMBER), eq(AuthHelper.VALID_UUID), eq(1L))).thenReturn(messagesList); Response response = resources.getJerseyTest().target("/v1/messages/") @@ -312,20 +312,20 @@ public class MessageControllerTest { UUID sourceUuid = UUID.randomUUID(); - when(messagesManager.delete(AuthHelper.VALID_NUMBER, 1, "+14152222222", 31337)) + when(messagesManager.delete(AuthHelper.VALID_NUMBER, AuthHelper.VALID_UUID, 1, "+14152222222", 31337)) .thenReturn(Optional.of(new OutgoingMessageEntity(31337L, true, null, Envelope.Type.CIPHERTEXT_VALUE, null, timestamp, "+14152222222", sourceUuid, 1, "hi".getBytes(), null, 0))); - when(messagesManager.delete(AuthHelper.VALID_NUMBER, 1, "+14152222222", 31338)) + when(messagesManager.delete(AuthHelper.VALID_NUMBER, AuthHelper.VALID_UUID, 1, "+14152222222", 31338)) .thenReturn(Optional.of(new OutgoingMessageEntity(31337L, true, null, Envelope.Type.RECEIPT_VALUE, null, System.currentTimeMillis(), "+14152222222", sourceUuid, 1, null, null, 0))); - when(messagesManager.delete(AuthHelper.VALID_NUMBER, 1, "+14152222222", 31339)) + when(messagesManager.delete(AuthHelper.VALID_NUMBER, AuthHelper.VALID_UUID, 1, "+14152222222", 31339)) .thenReturn(Optional.empty()); Response response = resources.getJerseyTest() diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/websocket/WebSocketConnectionTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/websocket/WebSocketConnectionTest.java index 034712d0..d50c45af 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/websocket/WebSocketConnectionTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/websocket/WebSocketConnectionTest.java @@ -105,6 +105,7 @@ public class WebSocketConnectionTest { public void testOpen() throws Exception { MessagesManager storedMessages = mock(MessagesManager.class); + UUID accountUuid = UUID.randomUUID(); UUID senderOneUuid = UUID.randomUUID(); UUID senderTwoUuid = UUID.randomUUID(); @@ -121,6 +122,7 @@ public class WebSocketConnectionTest { when(account.getAuthenticatedDevice()).thenReturn(Optional.of(device)); when(account.getNumber()).thenReturn("+14152222222"); + when(account.getUuid()).thenReturn(accountUuid); final Device sender1device = mock(Device.class); @@ -134,7 +136,7 @@ public class WebSocketConnectionTest { when(accountsManager.get("sender1")).thenReturn(Optional.of(sender1)); when(accountsManager.get("sender2")).thenReturn(Optional.empty()); - when(storedMessages.getMessagesForDevice(account.getNumber(), device.getId())) + when(storedMessages.getMessagesForDevice(account.getNumber(), account.getUuid(), device.getId())) .thenReturn(outgoingMessagesList); final List> futures = new LinkedList<>(); @@ -166,7 +168,7 @@ public class WebSocketConnectionTest { futures.get(0).completeExceptionally(new IOException()); futures.get(2).completeExceptionally(new IOException()); - verify(storedMessages, times(1)).delete(eq(account.getNumber()), eq(2L), eq(2L), eq(false)); + verify(storedMessages, times(1)).delete(eq(account.getNumber()), eq(accountUuid), eq(2L), eq(2L), eq(false)); verify(receiptSender, times(1)).sendReceipt(eq(account), eq("sender1"), eq(2222L)); connection.onDispatchUnsubscribed(websocketAddress.serialize()); @@ -204,6 +206,7 @@ public class WebSocketConnectionTest { when(account.getAuthenticatedDevice()).thenReturn(Optional.of(device)); when(account.getNumber()).thenReturn("+14152222222"); + when(account.getUuid()).thenReturn(UUID.randomUUID()); final Device sender1device = mock(Device.class); @@ -217,7 +220,7 @@ public class WebSocketConnectionTest { when(accountsManager.get("sender1")).thenReturn(Optional.of(sender1)); when(accountsManager.get("sender2")).thenReturn(Optional.empty()); - when(storedMessages.getMessagesForDevice(account.getNumber(), device.getId())) + when(storedMessages.getMessagesForDevice(account.getNumber(), account.getUuid(), device.getId())) .thenReturn(pendingMessagesList); final List> futures = new LinkedList<>(); @@ -311,6 +314,7 @@ public class WebSocketConnectionTest { when(account.getAuthenticatedDevice()).thenReturn(Optional.of(device)); when(account.getNumber()).thenReturn("+14152222222"); + when(account.getUuid()).thenReturn(UUID.randomUUID()); final Device sender1device = mock(Device.class); @@ -324,7 +328,7 @@ public class WebSocketConnectionTest { when(accountsManager.get("sender1")).thenReturn(Optional.of(sender1)); when(accountsManager.get("sender2")).thenReturn(Optional.empty()); - when(storedMessages.getMessagesForDevice(account.getNumber(), device.getId())) + when(storedMessages.getMessagesForDevice(account.getNumber(), account.getUuid(), device.getId())) .thenReturn(pendingMessagesList); final List> futures = new LinkedList<>();