diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java index 5e450267..9d964453 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java @@ -164,7 +164,7 @@ public class KeysController { if (setKeysRequest.pqLastResortPreKey() != null) { storeFutures.add( - keys.storePqLastResort(identifier, Map.of(device.getId(), setKeysRequest.pqLastResortPreKey()))); + keys.storePqLastResort(identifier, device.getId(), setKeysRequest.pqLastResortPreKey())); } return CompletableFuture.allOf(storeFutures.toArray(EMPTY_FUTURE_ARRAY)); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/KeysGrpcService.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/KeysGrpcService.java index 05b92132..2a3c607d 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/KeysGrpcService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/KeysGrpcService.java @@ -200,7 +200,7 @@ public class KeysGrpcService extends ReactorKeysGrpc.KeysImplBase { final UUID identifier = account.getIdentifier(identityType); return Flux.merge( - Mono.fromFuture(() -> keysManager.storeEcSignedPreKeys(identifier, Map.of(authenticatedDevice.deviceId(), signedPreKey))), + Mono.fromFuture(() -> keysManager.storeEcSignedPreKeys(identifier, authenticatedDevice.deviceId(), signedPreKey)), Mono.fromFuture(() -> accountsManager.updateDeviceAsync(account, authenticatedDevice.deviceId(), deviceUpdater))) .then(); })); @@ -217,7 +217,7 @@ public class KeysGrpcService extends ReactorKeysGrpc.KeysImplBase { final UUID identifier = account.getIdentifier(IdentityTypeUtil.fromGrpcIdentityType(request.getIdentityType())); - return Mono.fromFuture(() -> keysManager.storePqLastResort(identifier, Map.of(authenticatedDevice.deviceId(), lastResortKey))); + return Mono.fromFuture(() -> keysManager.storePqLastResort(identifier, authenticatedDevice.deviceId(), lastResortKey)); })); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/KeysManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/KeysManager.java index dee58ecc..1f080f25 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/KeysManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/KeysManager.java @@ -8,7 +8,6 @@ package org.whispersystems.textsecuregcm.storage; import com.google.common.annotations.VisibleForTesting; import java.util.ArrayList; import java.util.List; -import java.util.Map; import java.util.Optional; import java.util.UUID; import java.util.concurrent.CompletableFuture; @@ -100,9 +99,9 @@ public class KeysManager { return writeItems; } - public CompletableFuture storeEcSignedPreKeys(final UUID identifier, final Map keys) { + public CompletableFuture storeEcSignedPreKeys(final UUID identifier, final byte deviceId, final ECSignedPreKey ecSignedPreKey) { if (dynamicConfigurationManager.getConfiguration().getEcPreKeyMigrationConfiguration().storeEcSignedPreKeys()) { - return ecSignedPreKeys.store(identifier, keys); + return ecSignedPreKeys.store(identifier, deviceId, ecSignedPreKey); } else { return CompletableFuture.completedFuture(null); } @@ -113,8 +112,8 @@ public class KeysManager { return ecSignedPreKeys.storeIfAbsent(identifier, deviceId, signedPreKey); } - public CompletableFuture storePqLastResort(final UUID identifier, final Map keys) { - return pqLastResortKeys.store(identifier, keys); + public CompletableFuture storePqLastResort(final UUID identifier, final byte deviceId, final KEMSignedPreKey lastResortKey) { + return pqLastResortKeys.store(identifier, deviceId, lastResortKey); } public CompletableFuture storeEcOneTimePreKeys(final UUID identifier, final byte deviceId, diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/RepeatedUseSignedPreKeyStore.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/RepeatedUseSignedPreKeyStore.java index 07ac03cc..2b67cbbe 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/RepeatedUseSignedPreKeyStore.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/RepeatedUseSignedPreKeyStore.java @@ -23,7 +23,6 @@ import software.amazon.awssdk.services.dynamodb.model.Put; import software.amazon.awssdk.services.dynamodb.model.PutItemRequest; import software.amazon.awssdk.services.dynamodb.model.QueryRequest; import software.amazon.awssdk.services.dynamodb.model.TransactWriteItem; -import software.amazon.awssdk.services.dynamodb.model.TransactWriteItemsRequest; /** * A repeated-use signed pre-key store manages storage for pre-keys that may be used more than once. Generally, these @@ -45,7 +44,6 @@ public abstract class RepeatedUseSignedPreKeyStore> { static final String ATTR_SIGNATURE = "S"; private final Timer storeSingleKeyTimer = Metrics.timer(MetricsUtil.name(getClass(), "storeSingleKey")); - private final Timer storeKeyBatchTimer = Metrics.timer(MetricsUtil.name(getClass(), "storeKeyBatch")); private final String findKeyTimerName = MetricsUtil.name(getClass(), "findKey"); @@ -74,41 +72,6 @@ public abstract class RepeatedUseSignedPreKeyStore> { .thenRun(() -> sample.stop(storeSingleKeyTimer)); } - /** - * Stores repeated-use pre-keys for a collection of devices associated with a single account/identity, displacing any - * previously-stored repeated-use pre-keys for the targeted devices. Note that this method is transactional; either - * all keys will be stored or none will. - * - * @param identifier the identifier for the account/identity with which the target devices are associated - * @param signedPreKeysByDeviceId a map of device identifiers to pre-keys - * - * @return a future that completes once all keys have been stored - */ - public CompletableFuture store(final UUID identifier, final Map signedPreKeysByDeviceId) { - if (signedPreKeysByDeviceId.isEmpty()) { - return CompletableFuture.completedFuture(null); - } - - final Timer.Sample sample = Timer.start(); - - return dynamoDbAsyncClient.transactWriteItems(TransactWriteItemsRequest.builder() - .transactItems(signedPreKeysByDeviceId.entrySet().stream() - .map(entry -> { - final byte deviceId = entry.getKey(); - final K signedPreKey = entry.getValue(); - - return TransactWriteItem.builder() - .put(Put.builder() - .tableName(tableName) - .item(getItemFromPreKey(identifier, deviceId, signedPreKey)) - .build()) - .build(); - }) - .toList()) - .build()) - .thenRun(() -> sample.stop(storeKeyBatchTimer)); - } - TransactWriteItem buildTransactWriteItemForInsertion(final UUID identifier, final byte deviceId, final K preKey) { return TransactWriteItem.builder() .put(Put.builder() diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/KeysControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/KeysControllerTest.java index 7eac5b2d..c56750fb 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/KeysControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/KeysControllerTest.java @@ -254,11 +254,11 @@ class KeysControllerTest { when(KEYS.storeKemOneTimePreKeys(any(), anyByte(), any())) .thenReturn(CompletableFutureTestUtil.almostCompletedFuture(null)); - when(KEYS.storePqLastResort(any(), any())) + when(KEYS.storePqLastResort(any(), anyByte(), any())) .thenReturn(CompletableFutureTestUtil.almostCompletedFuture(null)); when(KEYS.getEcSignedPreKey(any(), anyByte())).thenReturn(CompletableFuture.completedFuture(Optional.empty())); - when(KEYS.storeEcSignedPreKeys(any(), any())).thenReturn(CompletableFutureTestUtil.almostCompletedFuture(null)); + when(KEYS.storeEcSignedPreKeys(any(), anyByte(), any())).thenReturn(CompletableFutureTestUtil.almostCompletedFuture(null)); when(KEYS.takeEC(EXISTS_UUID, sampleDeviceId)).thenReturn( CompletableFuture.completedFuture(Optional.of(SAMPLE_KEY))); @@ -827,7 +827,7 @@ class KeysControllerTest { ArgumentCaptor> pqCaptor = ArgumentCaptor.forClass(List.class); verify(KEYS).storeEcOneTimePreKeys(eq(AuthHelper.VALID_UUID), eq(SAMPLE_DEVICE_ID), ecCaptor.capture()); verify(KEYS).storeKemOneTimePreKeys(eq(AuthHelper.VALID_UUID), eq(SAMPLE_DEVICE_ID), pqCaptor.capture()); - verify(KEYS).storePqLastResort(AuthHelper.VALID_UUID, Map.of(SAMPLE_DEVICE_ID, pqLastResortPreKey)); + verify(KEYS).storePqLastResort(AuthHelper.VALID_UUID, SAMPLE_DEVICE_ID, pqLastResortPreKey); assertThat(ecCaptor.getValue()).containsExactly(preKey); assertThat(pqCaptor.getValue()).containsExactly(pqPreKey); @@ -966,7 +966,7 @@ class KeysControllerTest { ArgumentCaptor> pqCaptor = ArgumentCaptor.forClass(List.class); verify(KEYS).storeEcOneTimePreKeys(eq(AuthHelper.VALID_PNI), eq(SAMPLE_DEVICE_ID), ecCaptor.capture()); verify(KEYS).storeKemOneTimePreKeys(eq(AuthHelper.VALID_PNI), eq(SAMPLE_DEVICE_ID), pqCaptor.capture()); - verify(KEYS).storePqLastResort(AuthHelper.VALID_PNI, Map.of(SAMPLE_DEVICE_ID, pqLastResortPreKey)); + verify(KEYS).storePqLastResort(AuthHelper.VALID_PNI, SAMPLE_DEVICE_ID, pqLastResortPreKey); assertThat(ecCaptor.getValue()).containsExactly(preKey); assertThat(pqCaptor.getValue()).containsExactly(pqPreKey); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/KeysGrpcServiceTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/KeysGrpcServiceTest.java index d7dc7436..497a0c85 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/KeysGrpcServiceTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/KeysGrpcServiceTest.java @@ -304,7 +304,7 @@ class KeysGrpcServiceTest extends SimpleBaseGrpcTest ACI_IDENTITY_KEY_PAIR; @@ -326,12 +326,12 @@ class KeysGrpcServiceTest extends SimpleBaseGrpcTest { verify(authenticatedDevice).setSignedPreKey(signedPreKey); - verify(keysManager).storeEcSignedPreKeys(AUTHENTICATED_ACI, Map.of(AUTHENTICATED_DEVICE_ID, signedPreKey)); + verify(keysManager).storeEcSignedPreKeys(AUTHENTICATED_ACI, AUTHENTICATED_DEVICE_ID, signedPreKey); } case IDENTITY_TYPE_PNI -> { verify(authenticatedDevice).setPhoneNumberIdentitySignedPreKey(signedPreKey); - verify(keysManager).storeEcSignedPreKeys(AUTHENTICATED_PNI, Map.of(AUTHENTICATED_DEVICE_ID, signedPreKey)); + verify(keysManager).storeEcSignedPreKeys(AUTHENTICATED_PNI, AUTHENTICATED_DEVICE_ID, signedPreKey); } } } @@ -387,7 +387,7 @@ class KeysGrpcServiceTest extends SimpleBaseGrpcTest ACI_IDENTITY_KEY_PAIR; @@ -412,7 +412,7 @@ class KeysGrpcServiceTest extends SimpleBaseGrpcTest throw new AssertionError("Bad identity type"); }; - verify(keysManager).storePqLastResort(expectedIdentifier, Map.of(AUTHENTICATED_DEVICE_ID, lastResortPreKey)); + verify(keysManager).storePqLastResort(expectedIdentifier, AUTHENTICATED_DEVICE_ID, lastResortPreKey); } @ParameterizedTest diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java index 74eb76f3..9031c158 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java @@ -1284,7 +1284,7 @@ class AccountsManagerTest { final Account existingAccount = AccountsHelper.generateTestAccount(targetNumber, existingAccountUuid, targetPni, new ArrayList<>(), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]); when(accounts.getByE164(targetNumber)).thenReturn(Optional.of(existingAccount)); when(keysManager.getPqEnabledDevices(uuid)).thenReturn(CompletableFuture.completedFuture(List.of(Device.PRIMARY_ID, deviceId3))); - when(keysManager.storePqLastResort(any(), any())).thenReturn(CompletableFuture.completedFuture(null)); + when(keysManager.storePqLastResort(any(), anyByte(), any())).thenReturn(CompletableFuture.completedFuture(null)); final List devices = List.of( DevicesHelper.createDevice(Device.PRIMARY_ID, 0L, 101), @@ -1381,7 +1381,7 @@ class AccountsManagerTest { final IdentityKey pniIdentityKey = new IdentityKey(Curve.generateKeyPair().getPublicKey()); when(keysManager.getPqEnabledDevices(any())).thenReturn(CompletableFuture.completedFuture(Collections.emptyList())); - when(keysManager.storeEcSignedPreKeys(any(), any())).thenReturn(CompletableFuture.completedFuture(null)); + when(keysManager.storeEcSignedPreKeys(any(), anyByte(), any())).thenReturn(CompletableFuture.completedFuture(null)); final Account updatedAccount = accountsManager.updatePniKeys(account, pniIdentityKey, newSignedKeys, null, newRegistrationIds); @@ -1437,8 +1437,8 @@ class AccountsManagerTest { when(keysManager.getPqEnabledDevices(oldPni)).thenReturn( CompletableFuture.completedFuture(List.of(Device.PRIMARY_ID))); - when(keysManager.storeEcSignedPreKeys(any(), any())).thenReturn(CompletableFuture.completedFuture(null)); - when(keysManager.storePqLastResort(any(), any())).thenReturn(CompletableFuture.completedFuture(null)); + when(keysManager.storeEcSignedPreKeys(any(), anyByte(), any())).thenReturn(CompletableFuture.completedFuture(null)); + when(keysManager.storePqLastResort(any(), anyByte(), any())).thenReturn(CompletableFuture.completedFuture(null)); Map oldSignedPreKeys = account.getDevices().stream() .collect(Collectors.toMap(Device::getId, d -> d.getSignedPreKey(IdentityType.ACI))); @@ -1500,8 +1500,8 @@ class AccountsManagerTest { UUID oldPni = account.getPhoneNumberIdentifier(); when(keysManager.getPqEnabledDevices(oldPni)).thenReturn(CompletableFuture.completedFuture(List.of())); - when(keysManager.storeEcSignedPreKeys(any(), any())).thenReturn(CompletableFuture.completedFuture(null)); - when(keysManager.storePqLastResort(any(), any())).thenReturn(CompletableFuture.completedFuture(null)); + when(keysManager.storeEcSignedPreKeys(any(), anyByte(), any())).thenReturn(CompletableFuture.completedFuture(null)); + when(keysManager.storePqLastResort(any(), anyByte(), any())).thenReturn(CompletableFuture.completedFuture(null)); Map oldSignedPreKeys = account.getDevices().stream() .collect(Collectors.toMap(Device::getId, d -> d.getSignedPreKey(IdentityType.ACI))); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/KeysManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/KeysManagerTest.java index 9802222f..7d6e7787 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/KeysManagerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/KeysManagerTest.java @@ -97,7 +97,7 @@ class KeysManagerTest { final ECSignedPreKey signedPreKey = generateTestECSignedPreKey(1); - keysManager.storeEcSignedPreKeys(ACCOUNT_UUID, Map.of(DEVICE_ID, signedPreKey)).join(); + keysManager.storeEcSignedPreKeys(ACCOUNT_UUID, DEVICE_ID, signedPreKey).join(); assertEquals(Optional.of(signedPreKey), keysManager.getEcSignedPreKey(ACCOUNT_UUID, DEVICE_ID).join()); } @@ -124,7 +124,7 @@ class KeysManagerTest { final KEMSignedPreKey preKeyLast = generateTestKEMSignedPreKey(1001); keysManager.storeKemOneTimePreKeys(ACCOUNT_UUID, DEVICE_ID, List.of(preKey1, preKey2)).join(); - keysManager.storePqLastResort(ACCOUNT_UUID, Map.of(DEVICE_ID, preKeyLast)).join(); + keysManager.storePqLastResort(ACCOUNT_UUID, DEVICE_ID, preKeyLast).join(); assertEquals(Optional.of(preKey1), keysManager.takePQ(ACCOUNT_UUID, DEVICE_ID).join()); assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join()); @@ -146,8 +146,8 @@ class KeysManagerTest { for (byte deviceId : new byte[] {DEVICE_ID, DEVICE_ID + 1}) { keysManager.storeEcOneTimePreKeys(ACCOUNT_UUID, deviceId, List.of(generateTestPreKey(keyId++))).join(); keysManager.storeKemOneTimePreKeys(ACCOUNT_UUID, deviceId, List.of(generateTestKEMSignedPreKey(keyId++))).join(); - keysManager.storeEcSignedPreKeys(ACCOUNT_UUID, Map.of(deviceId, generateTestECSignedPreKey(keyId++))).join(); - keysManager.storePqLastResort(ACCOUNT_UUID, Map.of(deviceId, generateTestKEMSignedPreKey(keyId++))).join(); + keysManager.storeEcSignedPreKeys(ACCOUNT_UUID, deviceId, generateTestECSignedPreKey(keyId++)).join(); + keysManager.storePqLastResort(ACCOUNT_UUID, deviceId, generateTestKEMSignedPreKey(keyId++)).join(); } for (byte deviceId : new byte[] {DEVICE_ID, DEVICE_ID + 1}) { @@ -174,8 +174,8 @@ class KeysManagerTest { for (byte deviceId : new byte[] {DEVICE_ID, DEVICE_ID + 1}) { keysManager.storeEcOneTimePreKeys(ACCOUNT_UUID, deviceId, List.of(generateTestPreKey(keyId++))).join(); keysManager.storeKemOneTimePreKeys(ACCOUNT_UUID, deviceId, List.of(generateTestKEMSignedPreKey(keyId++))).join(); - keysManager.storeEcSignedPreKeys(ACCOUNT_UUID, Map.of(deviceId, generateTestECSignedPreKey(keyId++))).join(); - keysManager.storePqLastResort(ACCOUNT_UUID, Map.of(deviceId, generateTestKEMSignedPreKey(keyId++))).join(); + keysManager.storeEcSignedPreKeys(ACCOUNT_UUID, deviceId, generateTestECSignedPreKey(keyId++)).join(); + keysManager.storePqLastResort(ACCOUNT_UUID, deviceId, generateTestKEMSignedPreKey(keyId++)).join(); } for (byte deviceId : new byte[] {DEVICE_ID, DEVICE_ID + 1}) { @@ -207,19 +207,17 @@ class KeysManagerTest { final byte deviceId2 = 2; final byte deviceId3 = 3; - keysManager.storePqLastResort( - ACCOUNT_UUID, - Map.of(DEVICE_ID, KeysHelper.signedKEMPreKey(1, identityKeyPair), (byte) 2, - KeysHelper.signedKEMPreKey(2, identityKeyPair))).join(); + keysManager.storePqLastResort(ACCOUNT_UUID, DEVICE_ID, KeysHelper.signedKEMPreKey(1, identityKeyPair)).join(); + keysManager.storePqLastResort(ACCOUNT_UUID, (byte) 2, KeysHelper.signedKEMPreKey(2, identityKeyPair)).join(); + assertEquals(2, keysManager.getPqEnabledDevices(ACCOUNT_UUID).join().size()); assertEquals(1L, keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).join().get().keyId()); assertEquals(2L, keysManager.getLastResort(ACCOUNT_UUID, deviceId2).join().get().keyId()); assertFalse(keysManager.getLastResort(ACCOUNT_UUID, deviceId3).join().isPresent()); - keysManager.storePqLastResort( - ACCOUNT_UUID, - Map.of(DEVICE_ID, KeysHelper.signedKEMPreKey(3, identityKeyPair), deviceId3, - KeysHelper.signedKEMPreKey(4, identityKeyPair))).join(); + keysManager.storePqLastResort(ACCOUNT_UUID, DEVICE_ID, KeysHelper.signedKEMPreKey(3, identityKeyPair)).join(); + keysManager.storePqLastResort(ACCOUNT_UUID, deviceId3, KeysHelper.signedKEMPreKey(4, identityKeyPair)).join(); + assertEquals(3, keysManager.getPqEnabledDevices(ACCOUNT_UUID).join().size(), "storing new last-resort keys should not create duplicates"); assertEquals(3L, keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).join().get().keyId(), "storing new last-resort keys should overwrite old ones"); @@ -227,19 +225,14 @@ class KeysManagerTest { "storing new last-resort keys should leave untouched ones alone"); assertEquals(4L, keysManager.getLastResort(ACCOUNT_UUID, deviceId3).join().get().keyId(), "storing new last-resort keys should overwrite old ones"); - - keysManager.storePqLastResort(ACCOUNT_UUID, Map.of()).join(); - assertEquals(3, keysManager.getPqEnabledDevices(ACCOUNT_UUID).join().size(), "storing zero last-resort keys should be a no-op"); } @Test void testGetPqEnabledDevices() { keysManager.storeKemOneTimePreKeys(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestKEMSignedPreKey(1))).join(); - - keysManager.storePqLastResort(ACCOUNT_UUID, Map.of((byte) (DEVICE_ID + 1), generateTestKEMSignedPreKey(2))).join(); - + keysManager.storePqLastResort(ACCOUNT_UUID, (byte) (DEVICE_ID + 1), generateTestKEMSignedPreKey(2)).join(); keysManager.storeKemOneTimePreKeys(ACCOUNT_UUID, (byte) (DEVICE_ID + 2), List.of(generateTestKEMSignedPreKey(3))).join(); - keysManager.storePqLastResort(ACCOUNT_UUID, Map.of((byte) (DEVICE_ID + 2), generateTestKEMSignedPreKey(4))).join(); + keysManager.storePqLastResort(ACCOUNT_UUID, (byte) (DEVICE_ID + 2), generateTestKEMSignedPreKey(4)).join(); assertIterableEquals( Set.of((byte) (DEVICE_ID + 1), (byte) (DEVICE_ID + 2)), @@ -250,7 +243,7 @@ class KeysManagerTest { void testStoreEcSignedPreKeyDisabled() { when(ecPreKeyMigrationConfiguration.storeEcSignedPreKeys()).thenReturn(false); - keysManager.storeEcSignedPreKeys(ACCOUNT_UUID, Map.of(DEVICE_ID, generateTestECSignedPreKey(1))).join(); + keysManager.storeEcSignedPreKeys(ACCOUNT_UUID, DEVICE_ID, generateTestECSignedPreKey(1)).join(); assertFalse(keysManager.getEcSignedPreKey(ACCOUNT_UUID, DEVICE_ID).join().isPresent()); } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/RepeatedUseSignedPreKeyStoreTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/RepeatedUseSignedPreKeyStoreTest.java index 9cd7c198..263317e6 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/RepeatedUseSignedPreKeyStoreTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/RepeatedUseSignedPreKeyStoreTest.java @@ -30,27 +30,12 @@ abstract class RepeatedUseSignedPreKeyStoreTest> { assertEquals(Optional.empty(), keys.find(UUID.randomUUID(), Device.PRIMARY_ID).join()); - { - final UUID identifier = UUID.randomUUID(); - final byte deviceId = 1; - final K signedPreKey = generateSignedPreKey(); + final UUID identifier = UUID.randomUUID(); + final byte deviceId = 1; + final K signedPreKey = generateSignedPreKey(); - assertDoesNotThrow(() -> keys.store(identifier, deviceId, signedPreKey).join()); - assertEquals(Optional.of(signedPreKey), keys.find(identifier, deviceId).join()); - } - - { - final UUID identifier = UUID.randomUUID(); - final byte deviceId2 = 2; - final Map signedPreKeys = Map.of( - Device.PRIMARY_ID, generateSignedPreKey(), - deviceId2, generateSignedPreKey() - ); - - assertDoesNotThrow(() -> keys.store(identifier, signedPreKeys).join()); - assertEquals(Optional.of(signedPreKeys.get(Device.PRIMARY_ID)), keys.find(identifier, Device.PRIMARY_ID).join()); - assertEquals(Optional.of(signedPreKeys.get(deviceId2)), keys.find(identifier, deviceId2).join()); - } + assertDoesNotThrow(() -> keys.store(identifier, deviceId, signedPreKey).join()); + assertEquals(Optional.of(signedPreKey), keys.find(identifier, deviceId).join()); } @Test @@ -75,18 +60,16 @@ abstract class RepeatedUseSignedPreKeyStoreTest> { final UUID identifier = UUID.randomUUID(); final byte deviceId2 = 2; - final Map signedPreKeys = Map.of( - Device.PRIMARY_ID, generateSignedPreKey(), - deviceId2, generateSignedPreKey() - ); + final K retainedPreKey = generateSignedPreKey(); - keys.store(identifier, signedPreKeys).join(); + keys.store(identifier, Device.PRIMARY_ID, generateSignedPreKey()).join(); + keys.store(identifier, deviceId2, retainedPreKey).join(); getDynamoDbClient().transactWriteItems(TransactWriteItemsRequest.builder() .transactItems(keys.buildTransactWriteItemForDeletion(identifier, Device.PRIMARY_ID)) .build()); assertEquals(Optional.empty(), keys.find(identifier, Device.PRIMARY_ID).join()); - assertEquals(Optional.of(signedPreKeys.get(deviceId2)), keys.find(identifier, deviceId2).join()); + assertEquals(Optional.of(retainedPreKey), keys.find(identifier, deviceId2).join()); } }