0
0
mirror of https://github.com/signalapp/Signal-Server.git synced 2024-09-19 11:32:16 +02:00

Create separate key stores for different kinds of pre-keys

This commit is contained in:
Jon Chambers 2023-06-06 17:08:26 -04:00 committed by GitHub
parent cac04146de
commit 2b08742c0a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
34 changed files with 1482 additions and 847 deletions

View File

@ -296,7 +296,7 @@ commercial, industrial or non-consumer uses, unless such uses represent
the only significant mode of use of the product.
"Installation Information" for a User Product means any methods,
procedures, authorization keys, or other information required to install
procedures, authorization keysManager, or other information required to install
and execute modified versions of a covered work in that User Product from
a modified version of its Corresponding Source. The information must
suffice to ensure that the continued functioning of the modified object

View File

@ -176,7 +176,7 @@ import org.whispersystems.textsecuregcm.storage.ChangeNumberManager;
import org.whispersystems.textsecuregcm.storage.DeletedAccounts;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
import org.whispersystems.textsecuregcm.storage.IssuedReceiptsManager;
import org.whispersystems.textsecuregcm.storage.Keys;
import org.whispersystems.textsecuregcm.storage.KeysManager;
import org.whispersystems.textsecuregcm.storage.MessagePersister;
import org.whispersystems.textsecuregcm.storage.MessagesCache;
import org.whispersystems.textsecuregcm.storage.MessagesDynamoDb;
@ -345,10 +345,11 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
config.getDynamoDbTables().getPhoneNumberIdentifiers().getTableName());
Profiles profiles = new Profiles(dynamoDbClient, dynamoDbAsyncClient,
config.getDynamoDbTables().getProfiles().getTableName());
Keys keys = new Keys(dynamoDbClient,
KeysManager keys = new KeysManager(
dynamoDbAsyncClient,
config.getDynamoDbTables().getEcKeys().getTableName(),
config.getDynamoDbTables().getPqKeys().getTableName(),
config.getDynamoDbTables().getPqLastResortKeys().getTableName());
config.getDynamoDbTables().getKemKeys().getTableName(),
config.getDynamoDbTables().getKemLastResortKeys().getTableName());
MessagesDynamoDb messagesDynamoDb = new MessagesDynamoDb(dynamoDbClient, dynamoDbAsyncClient,
config.getDynamoDbTables().getMessages().getTableName(),
config.getDynamoDbTables().getMessages().getExpiration(),

View File

@ -51,8 +51,8 @@ public class DynamoDbTables {
private final Table deletedAccountsLock;
private final IssuedReceiptsTableConfiguration issuedReceipts;
private final Table ecKeys;
private final Table pqKeys;
private final Table pqLastResortKeys;
private final Table kemKeys;
private final Table kemLastResortKeys;
private final TableWithExpiration messages;
private final Table pendingAccounts;
private final Table pendingDevices;
@ -72,8 +72,8 @@ public class DynamoDbTables {
@JsonProperty("deletedAccountsLock") final Table deletedAccountsLock,
@JsonProperty("issuedReceipts") final IssuedReceiptsTableConfiguration issuedReceipts,
@JsonProperty("ecKeys") final Table ecKeys,
@JsonProperty("pqKeys") final Table pqKeys,
@JsonProperty("pqLastResortKeys") final Table pqLastResortKeys,
@JsonProperty("pqKeys") final Table kemKeys,
@JsonProperty("pqLastResortKeys") final Table kemLastResortKeys,
@JsonProperty("messages") final TableWithExpiration messages,
@JsonProperty("pendingAccounts") final Table pendingAccounts,
@JsonProperty("pendingDevices") final Table pendingDevices,
@ -92,8 +92,8 @@ public class DynamoDbTables {
this.deletedAccountsLock = deletedAccountsLock;
this.issuedReceipts = issuedReceipts;
this.ecKeys = ecKeys;
this.pqKeys = pqKeys;
this.pqLastResortKeys = pqLastResortKeys;
this.kemKeys = kemKeys;
this.kemLastResortKeys = kemLastResortKeys;
this.messages = messages;
this.pendingAccounts = pendingAccounts;
this.pendingDevices = pendingDevices;
@ -140,14 +140,14 @@ public class DynamoDbTables {
@NotNull
@Valid
public Table getPqKeys() {
return pqKeys;
public Table getKemKeys() {
return kemKeys;
}
@NotNull
@Valid
public Table getPqLastResortKeys() {
return pqLastResortKeys;
public Table getKemLastResortKeys() {
return kemLastResortKeys;
}
@NotNull

View File

@ -51,7 +51,7 @@ import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.Device.DeviceCapabilities;
import org.whispersystems.textsecuregcm.storage.Keys;
import org.whispersystems.textsecuregcm.storage.KeysManager;
import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.storage.StoredVerificationCodeManager;
import org.whispersystems.textsecuregcm.util.Pair;
@ -67,14 +67,14 @@ public class DeviceController {
private final StoredVerificationCodeManager pendingDevices;
private final AccountsManager accounts;
private final MessagesManager messages;
private final Keys keys;
private final KeysManager keys;
private final RateLimiters rateLimiters;
private final Map<String, Integer> maxDeviceConfiguration;
public DeviceController(StoredVerificationCodeManager pendingDevices,
AccountsManager accounts,
MessagesManager messages,
Keys keys,
KeysManager keys,
RateLimiters rateLimiters,
Map<String, Integer> maxDeviceConfiguration) {
this.pendingDevices = pendingDevices;

View File

@ -53,7 +53,7 @@ import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.Keys;
import org.whispersystems.textsecuregcm.storage.KeysManager;
@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
@Path("/v2/keys")
@ -61,7 +61,7 @@ import org.whispersystems.textsecuregcm.storage.Keys;
public class KeysController {
private final RateLimiters rateLimiters;
private final Keys keys;
private final KeysManager keys;
private final AccountsManager accounts;
private static final String IDENTITY_KEY_CHANGE_COUNTER_NAME = name(KeysController.class, "identityKeyChange");
@ -70,7 +70,7 @@ public class KeysController {
private static final String IDENTITY_TYPE_TAG_NAME = "identityType";
private static final String HAS_IDENTITY_KEY_TAG_NAME = "hasIdentityKey";
public KeysController(RateLimiters rateLimiters, Keys keys, AccountsManager accounts) {
public KeysController(RateLimiters rateLimiters, KeysManager keys, AccountsManager accounts) {
this.rateLimiters = rateLimiters;
this.keys = keys;
this.accounts = accounts;

View File

@ -48,7 +48,7 @@ import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.Keys;
import org.whispersystems.textsecuregcm.storage.KeysManager;
import org.whispersystems.textsecuregcm.util.HeaderUtils;
import org.whispersystems.textsecuregcm.util.Util;
@ -74,18 +74,18 @@ public class RegistrationController {
private final AccountsManager accounts;
private final PhoneVerificationTokenManager phoneVerificationTokenManager;
private final RegistrationLockVerificationManager registrationLockVerificationManager;
private final Keys keys;
private final KeysManager keysManager;
private final RateLimiters rateLimiters;
public RegistrationController(final AccountsManager accounts,
final PhoneVerificationTokenManager phoneVerificationTokenManager,
final RegistrationLockVerificationManager registrationLockVerificationManager,
final Keys keys,
final KeysManager keysManager,
final RateLimiters rateLimiters) {
this.accounts = accounts;
this.phoneVerificationTokenManager = phoneVerificationTokenManager;
this.registrationLockVerificationManager = registrationLockVerificationManager;
this.keys = keys;
this.keysManager = keysManager;
this.rateLimiters = rateLimiters;
}
@ -176,8 +176,8 @@ public class RegistrationController {
registrationRequest.deviceActivationRequest().gcmToken().ifPresent(gcmRegistrationId ->
device.setGcmId(gcmRegistrationId.gcmRegistrationId()));
keys.storePqLastResort(a.getUuid(), Map.of(Device.MASTER_ID, registrationRequest.deviceActivationRequest().aciPqLastResortPreKey().get()));
keys.storePqLastResort(a.getPhoneNumberIdentifier(), Map.of(Device.MASTER_ID, registrationRequest.deviceActivationRequest().pniPqLastResortPreKey().get()));
keysManager.storePqLastResort(a.getUuid(), Map.of(Device.MASTER_ID, registrationRequest.deviceActivationRequest().aciPqLastResortPreKey().get()));
keysManager.storePqLastResort(a.getPhoneNumberIdentifier(), Map.of(Device.MASTER_ID, registrationRequest.deviceActivationRequest().pniPqLastResortPreKey().get()));
});
}

View File

@ -43,7 +43,7 @@ public record ChangeNumberRequest(
@NotEmpty byte[] pniIdentityKey,
@Schema(description="""
A list of synchronization messages to send to companion devices to supply the private keys
A list of synchronization messages to send to companion devices to supply the private keysManager
associated with the new identity key and their new prekeys.
Exactly one message must be supplied for each enabled device other than the sending (primary) device.""")
@NotNull @Valid List<@NotNull @Valid IncomingMessage> deviceMessages,

View File

@ -36,7 +36,7 @@ public record ChangePhoneNumberRequest(
@Nullable byte[] pniIdentityKey,
@Schema(description="""
A list of synchronization messages to send to companion devices to supply the private keys
A list of synchronization messages to send to companion devices to supply the private keysManager
associated with the new identity key and their new prekeys.
Exactly one message must be supplied for each enabled device other than the sending (primary) device.""")
@Nullable List<IncomingMessage> deviceMessages,

View File

@ -30,7 +30,7 @@ public record PhoneNumberIdentityKeyDistributionRequest(
@NotNull
@Valid
@Schema(description="""
A list of synchronization messages to send to companion devices to supply the private keys
A list of synchronization messages to send to companion devices to supply the private keysManager
associated with the new identity key and their new prekeys.
Exactly one message must be supplied for each enabled device other than the sending (primary) device.""")
List<@NotNull @Valid IncomingMessage> deviceMessages,

View File

@ -90,7 +90,7 @@ public class AccountsManager {
private final FaultTolerantRedisCluster cacheCluster;
private final AccountLockManager accountLockManager;
private final DeletedAccounts deletedAccounts;
private final Keys keys;
private final KeysManager keysManager;
private final MessagesManager messagesManager;
private final ProfilesManager profilesManager;
private final StoredVerificationCodeManager pendingAccounts;
@ -134,7 +134,7 @@ public class AccountsManager {
final FaultTolerantRedisCluster cacheCluster,
final AccountLockManager accountLockManager,
final DeletedAccounts deletedAccounts,
final Keys keys,
final KeysManager keysManager,
final MessagesManager messagesManager,
final ProfilesManager profilesManager,
final StoredVerificationCodeManager pendingAccounts,
@ -150,7 +150,7 @@ public class AccountsManager {
this.cacheCluster = cacheCluster;
this.accountLockManager = accountLockManager;
this.deletedAccounts = deletedAccounts;
this.keys = keys;
this.keysManager = keysManager;
this.messagesManager = messagesManager;
this.profilesManager = profilesManager;
this.pendingAccounts = pendingAccounts;
@ -223,8 +223,8 @@ public class AccountsManager {
// account and need to clear out messages and keys that may have been stored for the old account.
if (!originalUuid.equals(actualUuid)) {
messagesManager.clear(actualUuid);
keys.delete(actualUuid);
keys.delete(account.getPhoneNumberIdentifier());
keysManager.delete(actualUuid);
keysManager.delete(account.getPhoneNumberIdentifier());
profilesManager.deleteAll(actualUuid);
clientPresenceManager.disconnectAllPresencesForUuid(actualUuid);
}
@ -315,13 +315,13 @@ public class AccountsManager {
updatedAccount.set(numberChangedAccount);
keys.delete(phoneNumberIdentifier);
keys.delete(originalPhoneNumberIdentifier);
keysManager.delete(phoneNumberIdentifier);
keysManager.delete(originalPhoneNumberIdentifier);
if (pniPqLastResortPreKeys != null) {
keys.storePqLastResort(
keysManager.storePqLastResort(
phoneNumberIdentifier,
keys.getPqEnabledDevices(uuid).stream().collect(
keysManager.getPqEnabledDevices(uuid).stream().collect(
Collectors.toMap(
Function.identity(),
pniPqLastResortPreKeys::get)));
@ -356,10 +356,10 @@ public class AccountsManager {
final UUID pni = account.getPhoneNumberIdentifier();
final Account updatedAccount = update(account, a -> { return setPniKeys(a, pniIdentityKey, pniSignedPreKeys, pniRegistrationIds); });
final List<Long> pqEnabledDeviceIDs = keys.getPqEnabledDevices(pni);
keys.delete(pni);
final List<Long> pqEnabledDeviceIDs = keysManager.getPqEnabledDevices(pni);
keysManager.delete(pni);
if (pniPqLastResortPreKeys != null) {
keys.storePqLastResort(pni, pqEnabledDeviceIDs.stream().collect(Collectors.toMap(Function.identity(), pniPqLastResortPreKeys::get)));
keysManager.storePqLastResort(pni, pqEnabledDeviceIDs.stream().collect(Collectors.toMap(Function.identity(), pniPqLastResortPreKeys::get)));
}
return updatedAccount;
@ -740,8 +740,8 @@ public class AccountsManager {
account.getUuid());
profilesManager.deleteAll(account.getUuid());
keys.delete(account.getUuid());
keys.delete(account.getPhoneNumberIdentifier());
keysManager.delete(account.getUuid());
keysManager.delete(account.getPhoneNumberIdentifier());
messagesManager.clear(account.getUuid());
messagesManager.clear(account.getPhoneNumberIdentifier());
registrationRecoveryPasswordsManager.removeForNumber(account.getNumber());

View File

@ -1,417 +0,0 @@
/*
* Copyright 2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.Multimap;
import com.google.common.collect.MultimapBuilder;
import com.google.common.collect.Multimaps;
import io.micrometer.core.instrument.DistributionSummary;
import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Timer;
import io.micrometer.core.instrument.Counter;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Base64;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;
import java.util.function.Function;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import org.apache.commons.lang3.StringUtils;
import org.whispersystems.textsecuregcm.entities.PreKey;
import org.whispersystems.textsecuregcm.entities.SignedPreKey;
import org.whispersystems.textsecuregcm.util.AttributeValues;
import software.amazon.awssdk.services.dynamodb.DynamoDbClient;
import software.amazon.awssdk.services.dynamodb.model.AttributeValue;
import software.amazon.awssdk.services.dynamodb.model.DeleteItemRequest;
import software.amazon.awssdk.services.dynamodb.model.DeleteItemResponse;
import software.amazon.awssdk.services.dynamodb.model.DeleteRequest;
import software.amazon.awssdk.services.dynamodb.model.PutRequest;
import software.amazon.awssdk.services.dynamodb.model.QueryRequest;
import software.amazon.awssdk.services.dynamodb.model.QueryResponse;
import software.amazon.awssdk.services.dynamodb.model.ReturnValue;
import software.amazon.awssdk.services.dynamodb.model.Select;
import software.amazon.awssdk.services.dynamodb.model.WriteRequest;
import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name;
public class Keys extends AbstractDynamoDbStore {
private final String ecTableName;
private final String pqTableName;
private final String pqLastResortTableName;
static final String KEY_ACCOUNT_UUID = "U";
static final String KEY_DEVICE_ID_KEY_ID = "DK";
static final String KEY_PUBLIC_KEY = "P";
static final String KEY_SIGNATURE = "S";
private static final Timer STORE_KEYS_TIMER = Metrics.timer(name(Keys.class, "storeKeys"));
private static final Timer TAKE_KEY_FOR_DEVICE_TIMER = Metrics.timer(name(Keys.class, "takeKeyForDevice"));
private static final Timer GET_KEY_COUNT_TIMER = Metrics.timer(name(Keys.class, "getKeyCount"));
private static final Timer DELETE_KEYS_FOR_DEVICE_TIMER = Metrics.timer(name(Keys.class, "deleteKeysForDevice"));
private static final Timer DELETE_KEYS_FOR_ACCOUNT_TIMER = Metrics.timer(name(Keys.class, "deleteKeysForAccount"));
private static final DistributionSummary CONTESTED_KEY_DISTRIBUTION = Metrics.summary(name(Keys.class, "contestedKeys"));
private static final DistributionSummary KEY_COUNT_DISTRIBUTION = Metrics.summary(name(Keys.class, "keyCount"));
private static final Counter KEYS_EMPTY_TAKE_COUNTER = Metrics.counter(name(Keys.class, "takeKeyEmpty"));
private static final Counter TOO_MANY_LAST_RESORT_KEYS_COUNTER = Metrics.counter(name(Keys.class, "tooManyLastResortKeys"));
private static final Counter PARSE_BYTES_FROM_STRING_COUNTER = Metrics.counter(name(Keys.class, "parseByteArray"), "format", "string");
private static final Counter READ_BYTES_FROM_BYTE_ARRAY_COUNTER = Metrics.counter(name(Keys.class, "parseByteArray"), "format", "bytes");
public Keys(
final DynamoDbClient dynamoDB,
final String ecTableName,
final String pqTableName,
final String pqLastResortTableName) {
super(dynamoDB);
this.ecTableName = ecTableName;
this.pqTableName = pqTableName;
this.pqLastResortTableName = pqLastResortTableName;
}
public void store(final UUID identifier, final long deviceId, final List<PreKey> keys) {
store(identifier, deviceId, keys, null, null);
}
public void store(
final UUID identifier, final long deviceId,
@Nullable final List<PreKey> ecKeys,
@Nullable final List<SignedPreKey> pqKeys,
@Nullable final SignedPreKey pqLastResortKey) {
Multimap<String, PreKey> keys = MultimapBuilder.hashKeys().arrayListValues().build();
List<String> tablesToClear = new ArrayList<>();
if (ecKeys != null && !ecKeys.isEmpty()) {
keys.putAll(ecTableName, ecKeys);
tablesToClear.add(ecTableName);
}
if (pqKeys != null && !pqKeys.isEmpty()) {
keys.putAll(pqTableName, pqKeys);
tablesToClear.add(pqTableName);
}
if (pqLastResortKey != null) {
keys.put(pqLastResortTableName, pqLastResortKey);
tablesToClear.add(pqLastResortTableName);
}
STORE_KEYS_TIMER.record(() -> {
delete(tablesToClear, identifier, deviceId);
writeInBatches(
keys.entries(),
batch -> {
Multimap<String, WriteRequest> writes = batch.stream()
.collect(
Multimaps.toMultimap(
Map.Entry<String, PreKey>::getKey,
entry -> WriteRequest.builder()
.putRequest(PutRequest.builder()
.item(getItemFromPreKey(identifier, deviceId, entry.getValue()))
.build())
.build(),
MultimapBuilder.hashKeys().arrayListValues()::build));
executeTableWriteItemsUntilComplete(writes.asMap());
});
});
}
public void storePqLastResort(final UUID identifier, final Map<Long, SignedPreKey> keys) {
final AttributeValue partitionKey = getPartitionKey(identifier);
final QueryRequest queryRequest = QueryRequest.builder()
.tableName(pqLastResortTableName)
.keyConditionExpression("#uuid = :uuid")
.expressionAttributeNames(Map.of("#uuid", KEY_ACCOUNT_UUID))
.expressionAttributeValues(Map.of(":uuid", partitionKey))
.projectionExpression(KEY_DEVICE_ID_KEY_ID)
.consistentRead(true)
.build();
final List<WriteRequest> writes = new ArrayList<>(2 * keys.size());
final Map<Long, Map<String, AttributeValue>> newItems = keys.entrySet().stream()
.collect(Collectors.toMap(Map.Entry::getKey, e -> getItemFromPreKey(identifier, e.getKey(), e.getValue())));
for (final Map<String, AttributeValue> item : db().query(queryRequest).items()) {
final AttributeValue oldSortKey = item.get(KEY_DEVICE_ID_KEY_ID);
final Long oldDeviceId = oldSortKey.b().asByteBuffer().getLong();
if (newItems.containsKey(oldDeviceId)) {
final Map<String, AttributeValue> replacement = newItems.get(oldDeviceId);
if (!replacement.get(KEY_DEVICE_ID_KEY_ID).equals(oldSortKey)) {
writes.add(WriteRequest.builder()
.deleteRequest(DeleteRequest.builder()
.key(Map.of(
KEY_ACCOUNT_UUID, partitionKey,
KEY_DEVICE_ID_KEY_ID, oldSortKey))
.build())
.build());
}
}
}
newItems.forEach((unusedKey, item) ->
writes.add(WriteRequest.builder().putRequest(PutRequest.builder().item(item).build()).build()));
executeTableWriteItemsUntilComplete(Map.of(pqLastResortTableName, writes));
}
public Optional<PreKey> takeEC(final UUID identifier, final long deviceId) {
return take(ecTableName, identifier, deviceId);
}
public Optional<SignedPreKey> takePQ(final UUID identifier, final long deviceId) {
return take(pqTableName, identifier, deviceId)
.or(() -> getLastResort(identifier, deviceId))
.map(pk -> (SignedPreKey) pk);
}
private Optional<PreKey> take(final String tableName, final UUID identifier, final long deviceId) {
return TAKE_KEY_FOR_DEVICE_TIMER.record(() -> {
final AttributeValue partitionKey = getPartitionKey(identifier);
QueryRequest queryRequest = QueryRequest.builder()
.tableName(tableName)
.keyConditionExpression("#uuid = :uuid AND begins_with (#sort, :sortprefix)")
.expressionAttributeNames(Map.of("#uuid", KEY_ACCOUNT_UUID, "#sort", KEY_DEVICE_ID_KEY_ID))
.expressionAttributeValues(Map.of(
":uuid", partitionKey,
":sortprefix", getSortKeyPrefix(deviceId)))
.projectionExpression(KEY_DEVICE_ID_KEY_ID)
.consistentRead(false)
.build();
int contestedKeys = 0;
try {
QueryResponse response = db().query(queryRequest);
for (Map<String, AttributeValue> candidate : response.items()) {
DeleteItemRequest deleteItemRequest = DeleteItemRequest.builder()
.tableName(tableName)
.key(Map.of(
KEY_ACCOUNT_UUID, partitionKey,
KEY_DEVICE_ID_KEY_ID, candidate.get(KEY_DEVICE_ID_KEY_ID)))
.returnValues(ReturnValue.ALL_OLD)
.build();
DeleteItemResponse deleteItemResponse = db().deleteItem(deleteItemRequest);
if (deleteItemResponse.hasAttributes()) {
return Optional.of(getPreKeyFromItem(deleteItemResponse.attributes()));
}
contestedKeys++;
}
KEYS_EMPTY_TAKE_COUNTER.increment();
return Optional.empty();
} finally {
CONTESTED_KEY_DISTRIBUTION.record(contestedKeys);
}
});
}
@VisibleForTesting
Optional<PreKey> getLastResort(final UUID identifier, final long deviceId) {
final AttributeValue partitionKey = getPartitionKey(identifier);
QueryRequest queryRequest = QueryRequest.builder()
.tableName(pqLastResortTableName)
.keyConditionExpression("#uuid = :uuid AND begins_with (#sort, :sortprefix)")
.expressionAttributeNames(Map.of("#uuid", KEY_ACCOUNT_UUID, "#sort", KEY_DEVICE_ID_KEY_ID))
.expressionAttributeValues(Map.of(
":uuid", partitionKey,
":sortprefix", getSortKeyPrefix(deviceId)))
.consistentRead(false)
.select(Select.ALL_ATTRIBUTES)
.build();
QueryResponse response = db().query(queryRequest);
if (response.count() > 1) {
TOO_MANY_LAST_RESORT_KEYS_COUNTER.increment();
}
return response.items().stream().findFirst().map(this::getPreKeyFromItem);
}
public List<Long> getPqEnabledDevices(final UUID identifier) {
final AttributeValue partitionKey = getPartitionKey(identifier);
final QueryRequest queryRequest = QueryRequest.builder()
.tableName(pqLastResortTableName)
.keyConditionExpression("#uuid = :uuid")
.expressionAttributeNames(Map.of("#uuid", KEY_ACCOUNT_UUID))
.expressionAttributeValues(Map.of(":uuid", partitionKey))
.projectionExpression(KEY_DEVICE_ID_KEY_ID)
.consistentRead(false)
.build();
final QueryResponse response = db().query(queryRequest);
return response.items().stream()
.map(item -> item.get(KEY_DEVICE_ID_KEY_ID).b().asByteBuffer().getLong())
.toList();
}
public int getEcCount(final UUID identifier, final long deviceId) {
return getCount(ecTableName, identifier, deviceId);
}
public int getPqCount(final UUID identifier, final long deviceId) {
return getCount(pqTableName, identifier, deviceId);
}
private int getCount(final String tableName, final UUID identifier, final long deviceId) {
return GET_KEY_COUNT_TIMER.record(() -> {
QueryRequest queryRequest = QueryRequest.builder()
.tableName(tableName)
.keyConditionExpression("#uuid = :uuid AND begins_with (#sort, :sortprefix)")
.expressionAttributeNames(Map.of("#uuid", KEY_ACCOUNT_UUID, "#sort", KEY_DEVICE_ID_KEY_ID))
.expressionAttributeValues(Map.of(
":uuid", getPartitionKey(identifier),
":sortprefix", getSortKeyPrefix(deviceId)))
.select(Select.COUNT)
.consistentRead(false)
.build();
int keyCount = 0;
// This is very confusing, but does appear to be the intended behavior. See:
//
// - https://github.com/aws/aws-sdk-java/issues/693
// - https://github.com/aws/aws-sdk-java/issues/915
// - https://docs.aws.amazon.com/amazondynamodb/latest/developerguide/Query.html#Query.Count
for (final QueryResponse page : db().queryPaginator(queryRequest)) {
keyCount += page.count();
}
KEY_COUNT_DISTRIBUTION.record(keyCount);
return keyCount;
});
}
public void delete(final UUID accountUuid) {
DELETE_KEYS_FOR_ACCOUNT_TIMER.record(() -> {
final QueryRequest queryRequest = QueryRequest.builder()
.keyConditionExpression("#uuid = :uuid")
.expressionAttributeNames(Map.of("#uuid", KEY_ACCOUNT_UUID))
.expressionAttributeValues(Map.of(
":uuid", getPartitionKey(accountUuid)))
.projectionExpression(KEY_DEVICE_ID_KEY_ID)
.consistentRead(true)
.build();
deleteItemsForAccountMatchingQuery(List.of(ecTableName, pqTableName, pqLastResortTableName), accountUuid, queryRequest);
});
}
public void delete(final UUID accountUuid, final long deviceId) {
delete(List.of(ecTableName, pqTableName, pqLastResortTableName), accountUuid, deviceId);
}
private void delete(final List<String> tableNames, final UUID accountUuid, final long deviceId) {
DELETE_KEYS_FOR_DEVICE_TIMER.record(() -> {
final QueryRequest queryRequest = QueryRequest.builder()
.keyConditionExpression("#uuid = :uuid AND begins_with (#sort, :sortprefix)")
.expressionAttributeNames(Map.of("#uuid", KEY_ACCOUNT_UUID, "#sort", KEY_DEVICE_ID_KEY_ID))
.expressionAttributeValues(Map.of(
":uuid", getPartitionKey(accountUuid),
":sortprefix", getSortKeyPrefix(deviceId)))
.projectionExpression(KEY_DEVICE_ID_KEY_ID)
.consistentRead(true)
.build();
deleteItemsForAccountMatchingQuery(tableNames, accountUuid, queryRequest);
});
}
private void deleteItemsForAccountMatchingQuery(final List<String> tableNames, final UUID accountUuid, final QueryRequest querySpec) {
final AttributeValue partitionKey = getPartitionKey(accountUuid);
Multimap<String, Map<String, AttributeValue>> itemStream = tableNames.stream()
.collect(
Multimaps.flatteningToMultimap(
Function.identity(),
tableName ->
db().query(querySpec.toBuilder().tableName(tableName).build())
.items()
.stream(),
MultimapBuilder.hashKeys(tableNames.size()).arrayListValues()::build));
writeInBatches(
itemStream.entries(),
batch -> {
Multimap<String, WriteRequest> deletes = batch.stream()
.collect(Multimaps.toMultimap(
Map.Entry<String, Map<String, AttributeValue>>::getKey,
entry -> WriteRequest.builder()
.deleteRequest(DeleteRequest.builder()
.key(Map.of(
KEY_ACCOUNT_UUID, partitionKey,
KEY_DEVICE_ID_KEY_ID, entry.getValue().get(KEY_DEVICE_ID_KEY_ID)))
.build())
.build(),
MultimapBuilder.hashKeys(tableNames.size()).arrayListValues()::build));
executeTableWriteItemsUntilComplete(deletes.asMap());
});
}
private static AttributeValue getPartitionKey(final UUID accountUuid) {
return AttributeValues.fromUUID(accountUuid);
}
private static AttributeValue getSortKey(final long deviceId, final long keyId) {
final ByteBuffer byteBuffer = ByteBuffer.wrap(new byte[16]);
byteBuffer.putLong(deviceId);
byteBuffer.putLong(keyId);
return AttributeValues.fromByteBuffer(byteBuffer.flip());
}
@VisibleForTesting
static AttributeValue getSortKeyPrefix(final long deviceId) {
final ByteBuffer byteBuffer = ByteBuffer.wrap(new byte[8]);
byteBuffer.putLong(deviceId);
return AttributeValues.fromByteBuffer(byteBuffer.flip());
}
private Map<String, AttributeValue> getItemFromPreKey(final UUID accountUuid, final long deviceId, final PreKey preKey) {
if (preKey instanceof final SignedPreKey spk) {
return Map.of(
KEY_ACCOUNT_UUID, getPartitionKey(accountUuid),
KEY_DEVICE_ID_KEY_ID, getSortKey(deviceId, spk.getKeyId()),
KEY_PUBLIC_KEY, AttributeValues.fromByteArray(spk.getPublicKey()),
KEY_SIGNATURE, AttributeValues.fromByteArray(spk.getSignature()));
}
return Map.of(
KEY_ACCOUNT_UUID, getPartitionKey(accountUuid),
KEY_DEVICE_ID_KEY_ID, getSortKey(deviceId, preKey.getKeyId()),
KEY_PUBLIC_KEY, AttributeValues.fromByteArray(preKey.getPublicKey()));
}
private PreKey getPreKeyFromItem(Map<String, AttributeValue> item) {
final long keyId = item.get(KEY_DEVICE_ID_KEY_ID).b().asByteBuffer().getLong(8);
final byte[] publicKey = extractByteArray(item.get(KEY_PUBLIC_KEY));
if (item.containsKey(KEY_SIGNATURE)) {
// All PQ prekeys are signed, and therefore have this attribute. Signed EC prekeys are stored
// in the Accounts table, so EC prekeys retrieved by this class are never SignedPreKeys.
return new SignedPreKey(keyId, publicKey, extractByteArray(item.get(KEY_SIGNATURE)));
}
return new PreKey(keyId, publicKey);
}
/**
* Extracts a byte array from an {@link AttributeValue} that may be either a byte array or a base64-encoded string.
*
* @param attributeValue the {@code AttributeValue} from which to extract a byte array
*
* @return the byte array represented by the given {@code AttributeValue}
*/
@VisibleForTesting
static byte[] extractByteArray(final AttributeValue attributeValue) {
if (attributeValue.b() != null) {
READ_BYTES_FROM_BYTE_ARRAY_COUNTER.increment();
return attributeValue.b().asByteArray();
} else if (StringUtils.isNotBlank(attributeValue.s())) {
PARSE_BYTES_FROM_STRING_COUNTER.increment();
return Base64.getDecoder().decode(attributeValue.s());
}
throw new IllegalArgumentException("Attribute value has neither a byte array nor a string value");
}
}

View File

@ -0,0 +1,111 @@
/*
* Copyright 2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import com.google.common.annotations.VisibleForTesting;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import javax.annotation.Nullable;
import org.whispersystems.textsecuregcm.entities.PreKey;
import org.whispersystems.textsecuregcm.entities.SignedPreKey;
import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient;
public class KeysManager {
private final SingleUseECPreKeyStore ecPreKeys;
private final SingleUseKEMPreKeyStore pqPreKeys;
private final RepeatedUseSignedPreKeyStore pqLastResortKeys;
public KeysManager(
final DynamoDbAsyncClient dynamoDbAsyncClient,
final String ecTableName,
final String pqTableName,
final String pqLastResortTableName) {
this.ecPreKeys = new SingleUseECPreKeyStore(dynamoDbAsyncClient, ecTableName);
this.pqPreKeys = new SingleUseKEMPreKeyStore(dynamoDbAsyncClient, pqTableName);
this.pqLastResortKeys = new RepeatedUseSignedPreKeyStore(dynamoDbAsyncClient, pqLastResortTableName);
}
public void store(final UUID identifier, final long deviceId, final List<PreKey> keys) {
store(identifier, deviceId, keys, null, null);
}
public void store(
final UUID identifier, final long deviceId,
@Nullable final List<PreKey> ecKeys,
@Nullable final List<SignedPreKey> pqKeys,
@Nullable final SignedPreKey pqLastResortKey) {
final List<CompletableFuture<Void>> storeFutures = new ArrayList<>();
if (ecKeys != null && !ecKeys.isEmpty()) {
storeFutures.add(ecPreKeys.store(identifier, deviceId, ecKeys));
}
if (pqKeys != null && !pqKeys.isEmpty()) {
storeFutures.add(pqPreKeys.store(identifier, deviceId, pqKeys));
}
if (pqLastResortKey != null) {
storeFutures.add(pqLastResortKeys.store(identifier, deviceId, pqLastResortKey));
}
CompletableFuture.allOf(storeFutures.toArray(new CompletableFuture[0])).join();
}
public void storePqLastResort(final UUID identifier, final Map<Long, SignedPreKey> keys) {
pqLastResortKeys.store(identifier, keys).join();
}
public Optional<PreKey> takeEC(final UUID identifier, final long deviceId) {
return ecPreKeys.take(identifier, deviceId).join();
}
public Optional<SignedPreKey> takePQ(final UUID identifier, final long deviceId) {
return pqPreKeys.take(identifier, deviceId)
.thenCompose(maybeSingleUsePreKey -> maybeSingleUsePreKey
.map(singleUsePreKey -> CompletableFuture.completedFuture(maybeSingleUsePreKey))
.orElseGet(() -> pqLastResortKeys.find(identifier, deviceId))).join();
}
@VisibleForTesting
Optional<PreKey> getLastResort(final UUID identifier, final long deviceId) {
return pqLastResortKeys.find(identifier, deviceId).join()
.map(signedPreKey -> signedPreKey);
}
public List<Long> getPqEnabledDevices(final UUID identifier) {
return pqLastResortKeys.getDeviceIdsWithKeys(identifier).collectList().block();
}
public int getEcCount(final UUID identifier, final long deviceId) {
return ecPreKeys.getCount(identifier, deviceId).join();
}
public int getPqCount(final UUID identifier, final long deviceId) {
return pqPreKeys.getCount(identifier, deviceId).join();
}
public void delete(final UUID accountUuid) {
CompletableFuture.allOf(
ecPreKeys.delete(accountUuid),
pqPreKeys.delete(accountUuid),
pqLastResortKeys.delete(accountUuid))
.join();
}
public void delete(final UUID accountUuid, final long deviceId) {
CompletableFuture.allOf(
ecPreKeys.delete(accountUuid, deviceId),
pqPreKeys.delete(accountUuid, deviceId),
pqLastResortKeys.delete(accountUuid, deviceId))
.join();
}
}

View File

@ -0,0 +1,228 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Timer;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import org.whispersystems.textsecuregcm.entities.SignedPreKey;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
import org.whispersystems.textsecuregcm.util.AttributeValues;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient;
import software.amazon.awssdk.services.dynamodb.model.AttributeValue;
import software.amazon.awssdk.services.dynamodb.model.DeleteItemRequest;
import software.amazon.awssdk.services.dynamodb.model.GetItemRequest;
import software.amazon.awssdk.services.dynamodb.model.Put;
import software.amazon.awssdk.services.dynamodb.model.PutItemRequest;
import software.amazon.awssdk.services.dynamodb.model.QueryRequest;
import software.amazon.awssdk.services.dynamodb.model.TransactWriteItem;
import software.amazon.awssdk.services.dynamodb.model.TransactWriteItemsRequest;
/**
* A repeated-use signed pre-key store manages storage for pre-keys that may be used more than once. Generally, these
* are considered "last resort" keys and should only be used when a device's supply of single-use pre-keys has been
* exhausted.
* <p/>
* Each {@link Account} may have one or more {@link Device devices}. Each "active" (i.e. those that have completed
* provisioning and are capable of sending and receiving messages) must have exactly one "last resort" pre-key.
*/
public class RepeatedUseSignedPreKeyStore {
private final DynamoDbAsyncClient dynamoDbAsyncClient;
private final String tableName;
static final String KEY_ACCOUNT_UUID = "U";
static final String KEY_DEVICE_ID = "D";
static final String ATTR_KEY_ID = "I";
static final String ATTR_PUBLIC_KEY = "P";
static final String ATTR_SIGNATURE = "S";
private static final Timer STORE_SINGLE_KEY_TIMER =
Metrics.timer(MetricsUtil.name(RepeatedUseSignedPreKeyStore.class, "storeSingleKey"));
private static final Timer STORE_KEY_BATCH_TIMER =
Metrics.timer(MetricsUtil.name(RepeatedUseSignedPreKeyStore.class, "storeKeyBatch"));
private static final Timer DELETE_FOR_DEVICE_TIMER =
Metrics.timer(MetricsUtil.name(RepeatedUseSignedPreKeyStore.class, "deleteForDevice"));
private static final Timer DELETE_FOR_ACCOUNT_TIMER =
Metrics.timer(MetricsUtil.name(RepeatedUseSignedPreKeyStore.class, "deleteForAccount"));
private static final String FIND_KEY_TIMER_NAME = MetricsUtil.name(RepeatedUseSignedPreKeyStore.class, "findKey");
private static final String KEY_PRESENT_TAG_NAME = "keyPresent";
public RepeatedUseSignedPreKeyStore(final DynamoDbAsyncClient dynamoDbAsyncClient, final String tableName) {
this.dynamoDbAsyncClient = dynamoDbAsyncClient;
this.tableName = tableName;
}
/**
* Stores a repeated-use pre-key for a specific device, displacing any previously-stored repeated-use pre-key for that
* device.
*
* @param identifier the identifier for the account/identity with which the target device is associated
* @param deviceId the identifier for the device within the given account/identity
* @param signedPreKey the key to store for the target device
*
* @return a future that completes once the key has been stored
*/
public CompletableFuture<Void> store(final UUID identifier, final long deviceId, final SignedPreKey signedPreKey) {
final Timer.Sample sample = Timer.start();
return dynamoDbAsyncClient.putItem(PutItemRequest.builder()
.tableName(tableName)
.item(getItemFromPreKey(identifier, deviceId, signedPreKey))
.build())
.thenRun(() -> sample.stop(STORE_SINGLE_KEY_TIMER));
}
/**
* Stores repeated-use pre-keys for a collection of devices associated with a single account/identity, displacing any
* previously-stored repeated-use pre-keys for the targeted devices. Note that this method is transactional; either
* all keys will be stored or none will.
*
* @param identifier the identifier for the account/identity with which the target devices are associated
* @param signedPreKeysByDeviceId a map of device identifiers to pre-keys
*
* @return a future that completes once all keys have been stored
*/
public CompletableFuture<Void> store(final UUID identifier, final Map<Long, SignedPreKey> signedPreKeysByDeviceId) {
final Timer.Sample sample = Timer.start();
return dynamoDbAsyncClient.transactWriteItems(TransactWriteItemsRequest.builder()
.transactItems(signedPreKeysByDeviceId.entrySet().stream()
.map(entry -> {
final long deviceId = entry.getKey();
final SignedPreKey signedPreKey = entry.getValue();
return TransactWriteItem.builder()
.put(Put.builder()
.tableName(tableName)
.item(getItemFromPreKey(identifier, deviceId, signedPreKey))
.build())
.build();
})
.toList())
.build())
.thenRun(() -> sample.stop(STORE_KEY_BATCH_TIMER));
}
/**
* Finds a repeated-use pre-key for a specific device.
*
* @param identifier the identifier for the account/identity with which the target device is associated
* @param deviceId the identifier for the device within the given account/identity
*
* @return a future that yields an optional signed pre-key if one is available for the target device or empty if no
* key could be found for the target device
*/
public CompletableFuture<Optional<SignedPreKey>> find(final UUID identifier, final long deviceId) {
final Timer.Sample sample = Timer.start();
final CompletableFuture<Optional<SignedPreKey>> findFuture = dynamoDbAsyncClient.getItem(GetItemRequest.builder()
.tableName(tableName)
.key(getPrimaryKey(identifier, deviceId))
.consistentRead(true)
.build())
.thenApply(response -> response.hasItem() ? Optional.of(getPreKeyFromItem(response.item())) : Optional.empty());
findFuture.whenComplete((maybeSignedPreKey, throwable) ->
sample.stop(Metrics.timer(FIND_KEY_TIMER_NAME, KEY_PRESENT_TAG_NAME, String.valueOf(maybeSignedPreKey != null && maybeSignedPreKey.isPresent()))));
return findFuture;
}
/**
* Clears all repeated-use pre-keys associated with the given account/identity.
*
* @param identifier the identifier for the account/identity for which to clear repeated-use pre-keys
*
* @return a future that completes once repeated-use pre-keys have been cleared from all devices associated with the
* target account/identity
*/
public CompletableFuture<Void> delete(final UUID identifier) {
final Timer.Sample sample = Timer.start();
return getDeviceIdsWithKeys(identifier)
.map(deviceId -> DeleteItemRequest.builder()
.tableName(tableName)
.key(getPrimaryKey(identifier, deviceId))
.build())
.flatMap(deleteItemRequest -> Mono.fromFuture(dynamoDbAsyncClient.deleteItem(deleteItemRequest)))
// Idiom: wait for everything to finish, but discard the results
.reduce(0, (a, b) -> 0)
.toFuture()
.thenRun(() -> sample.stop(DELETE_FOR_ACCOUNT_TIMER));
}
/**
* Removes the repeated-use pre-key associated with a specific device.
*
* @param identifier the identifier for the account/identity with which the target device is associated
* @param deviceId the identifier for the device within the given account/identity
*
* @return a future that completes once the repeated-use pre-key has been removed from the target device
*/
public CompletableFuture<Void> delete(final UUID identifier, final long deviceId) {
final Timer.Sample sample = Timer.start();
return dynamoDbAsyncClient.deleteItem(DeleteItemRequest.builder()
.tableName(tableName)
.key(getPrimaryKey(identifier, deviceId))
.build())
.thenRun(() -> sample.stop(DELETE_FOR_DEVICE_TIMER));
}
public Flux<Long> getDeviceIdsWithKeys(final UUID identifier) {
return Flux.from(dynamoDbAsyncClient.queryPaginator(QueryRequest.builder()
.tableName(tableName)
.keyConditionExpression("#uuid = :uuid")
.expressionAttributeNames(Map.of("#uuid", KEY_ACCOUNT_UUID))
.expressionAttributeValues(Map.of(
":uuid", getPartitionKey(identifier)))
.projectionExpression(KEY_DEVICE_ID)
.consistentRead(true)
.build())
.items())
.map(item -> Long.parseLong(item.get(KEY_DEVICE_ID).n()));
}
private static Map<String, AttributeValue> getPrimaryKey(final UUID identifier, final long deviceId) {
return Map.of(
KEY_ACCOUNT_UUID, getPartitionKey(identifier),
KEY_DEVICE_ID, getSortKey(deviceId));
}
private static AttributeValue getPartitionKey(final UUID accountUuid) {
return AttributeValues.fromUUID(accountUuid);
}
private static AttributeValue getSortKey(final long deviceId) {
return AttributeValues.fromLong(deviceId);
}
private static Map<String, AttributeValue> getItemFromPreKey(final UUID accountUuid, final long deviceId, final SignedPreKey signedPreKey) {
return Map.of(
KEY_ACCOUNT_UUID, getPartitionKey(accountUuid),
KEY_DEVICE_ID, getSortKey(deviceId),
ATTR_KEY_ID, AttributeValues.fromLong(signedPreKey.getKeyId()),
ATTR_PUBLIC_KEY, AttributeValues.fromByteArray(signedPreKey.getPublicKey()),
ATTR_SIGNATURE, AttributeValues.fromByteArray(signedPreKey.getSignature()));
}
private static SignedPreKey getPreKeyFromItem(final Map<String, AttributeValue> item) {
return new SignedPreKey(
Long.parseLong(item.get(ATTR_KEY_ID).n()),
item.get(ATTR_PUBLIC_KEY).b().asByteArray(),
item.get(ATTR_SIGNATURE).b().asByteArray());
}
}

View File

@ -0,0 +1,36 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import org.whispersystems.textsecuregcm.entities.PreKey;
import org.whispersystems.textsecuregcm.util.AttributeValues;
import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient;
import software.amazon.awssdk.services.dynamodb.model.AttributeValue;
import java.util.Map;
import java.util.UUID;
public class SingleUseECPreKeyStore extends SingleUsePreKeyStore<PreKey> {
protected SingleUseECPreKeyStore(final DynamoDbAsyncClient dynamoDbAsyncClient, final String tableName) {
super(dynamoDbAsyncClient, tableName);
}
@Override
protected Map<String, AttributeValue> getItemFromPreKey(final UUID identifier, final long deviceId, final PreKey preKey) {
return Map.of(
KEY_ACCOUNT_UUID, getPartitionKey(identifier),
KEY_DEVICE_ID_KEY_ID, getSortKey(deviceId, preKey.getKeyId()),
ATTR_PUBLIC_KEY, AttributeValues.fromByteArray(preKey.getPublicKey()));
}
@Override
protected PreKey getPreKeyFromItem(final Map<String, AttributeValue> item) {
final long keyId = item.get(KEY_DEVICE_ID_KEY_ID).b().asByteBuffer().getLong(8);
final byte[] publicKey = extractByteArray(item.get(ATTR_PUBLIC_KEY));
return new PreKey(keyId, publicKey);
}
}

View File

@ -0,0 +1,38 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import org.whispersystems.textsecuregcm.entities.SignedPreKey;
import org.whispersystems.textsecuregcm.util.AttributeValues;
import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient;
import software.amazon.awssdk.services.dynamodb.model.AttributeValue;
import java.util.Map;
import java.util.UUID;
public class SingleUseKEMPreKeyStore extends SingleUsePreKeyStore<SignedPreKey> {
protected SingleUseKEMPreKeyStore(final DynamoDbAsyncClient dynamoDbAsyncClient, final String tableName) {
super(dynamoDbAsyncClient, tableName);
}
@Override
protected Map<String, AttributeValue> getItemFromPreKey(final UUID identifier, final long deviceId, final SignedPreKey signedPreKey) {
return Map.of(
KEY_ACCOUNT_UUID, getPartitionKey(identifier),
KEY_DEVICE_ID_KEY_ID, getSortKey(deviceId, signedPreKey.getKeyId()),
ATTR_PUBLIC_KEY, AttributeValues.fromByteArray(signedPreKey.getPublicKey()),
ATTR_SIGNATURE, AttributeValues.fromByteArray(signedPreKey.getSignature()));
}
@Override
protected SignedPreKey getPreKeyFromItem(final Map<String, AttributeValue> item) {
final long keyId = item.get(KEY_DEVICE_ID_KEY_ID).b().asByteBuffer().getLong(8);
final byte[] publicKey = extractByteArray(item.get(ATTR_PUBLIC_KEY));
final byte[] signature = extractByteArray(item.get(ATTR_SIGNATURE));
return new SignedPreKey(keyId, publicKey, signature);
}
}

View File

@ -0,0 +1,312 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name;
import com.google.common.annotations.VisibleForTesting;
import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.DistributionSummary;
import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Timer;
import java.nio.ByteBuffer;
import java.time.Duration;
import java.util.Base64;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.atomic.AtomicInteger;
import org.apache.commons.lang3.StringUtils;
import org.whispersystems.textsecuregcm.entities.PreKey;
import org.whispersystems.textsecuregcm.util.AttributeValues;
import org.whispersystems.textsecuregcm.util.Util;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient;
import software.amazon.awssdk.services.dynamodb.model.AttributeValue;
import software.amazon.awssdk.services.dynamodb.model.DeleteItemRequest;
import software.amazon.awssdk.services.dynamodb.model.DeleteItemResponse;
import software.amazon.awssdk.services.dynamodb.model.PutItemRequest;
import software.amazon.awssdk.services.dynamodb.model.QueryRequest;
import software.amazon.awssdk.services.dynamodb.model.QueryResponse;
import software.amazon.awssdk.services.dynamodb.model.ReturnValue;
import software.amazon.awssdk.services.dynamodb.model.Select;
/**
* A single-use pre-key store stores single-use pre-keys of a specific type. Keys returned by a single-use pre-key
* store's {@link #take(UUID, long)} method are guaranteed to be returned exactly once, and repeated calls will never
* yield the same key.
* <p/>
* Each {@link Account} may have one or more {@link Device devices}. Clients <em>should</em> regularly check their
* supply of single-use pre-keys (see {@link #getCount(UUID, long)}) and upload new keys when their supply runs low. In
* the event that a party wants to begin a session with a device that has no single-use pre-keys remaining, that party
* may fall back to using the device's repeated-use ("last-resort") signed pre-key instead.
*/
public abstract class SingleUsePreKeyStore<K extends PreKey> {
private final DynamoDbAsyncClient dynamoDbAsyncClient;
private final String tableName;
private final Timer storeKeyTimer = Metrics.timer(name(getClass(), "storeKey"));
private final Timer storeKeyBatchTimer = Metrics.timer(name(getClass(), "storeKeyBatch"));
private final Timer getKeyCountTimer = Metrics.timer(name(getClass(), "getCount"));
private final Timer deleteForDeviceTimer = Metrics.timer(name(getClass(), "deleteForDevice"));
private final Timer deleteForAccountTimer = Metrics.timer(name(getClass(), "deleteForAccount"));
final DistributionSummary keysConsideredForTakeDistributionSummary = DistributionSummary
.builder(name(getClass(), "keysConsideredForTake"))
.publishPercentiles(0.5, 0.75, 0.95, 0.99, 0.999)
.distributionStatisticExpiry(Duration.ofMinutes(10))
.register(Metrics.globalRegistry);
final DistributionSummary availableKeyCountDistributionSummary = DistributionSummary
.builder(name(getClass(), "availableKeyCount"))
.publishPercentiles(0.5, 0.75, 0.95, 0.99, 0.999)
.distributionStatisticExpiry(Duration.ofMinutes(10))
.register(Metrics.globalRegistry);
private final String takeKeyTimerName = name(getClass(), "takeKey");
private static final String KEY_PRESENT_TAG_NAME = "keyPresent";
private final Counter parseBytesFromStringCounter = Metrics.counter(name(getClass(), "parseByteArray"), "format", "string");
private final Counter readBytesFromByteArrayCounter = Metrics.counter(name(getClass(), "parseByteArray"), "format", "bytes");
static final String KEY_ACCOUNT_UUID = "U";
static final String KEY_DEVICE_ID_KEY_ID = "DK";
static final String ATTR_PUBLIC_KEY = "P";
static final String ATTR_SIGNATURE = "S";
protected SingleUsePreKeyStore(final DynamoDbAsyncClient dynamoDbAsyncClient, final String tableName) {
this.dynamoDbAsyncClient = dynamoDbAsyncClient;
this.tableName = tableName;
}
/**
* Stores a batch of single-use pre-keys for a specific device. All previously-stored keys for the device are cleared
* before storing new keys.
*
* @param identifier the identifier for the account/identity with which the target device is associated
* @param deviceId the identifier for the device within the given account/identity
* @param preKeys a collection of single-use pre-keys to store for the target device
*
* @return a future that completes when all previously-stored keys have been removed and the given collection of
* pre-keys has been stored in its place
*/
public CompletableFuture<Void> store(final UUID identifier, final long deviceId, final List<K> preKeys) {
final Timer.Sample sample = Timer.start();
return delete(identifier, deviceId)
.thenCompose(ignored -> CompletableFuture.allOf(preKeys.stream()
.map(preKey -> store(identifier, deviceId, preKey))
.toList()
.toArray(new CompletableFuture[0])))
.thenRun(() -> sample.stop(storeKeyBatchTimer));
}
private CompletableFuture<Void> store(final UUID identifier, final long deviceId, final K preKey) {
final Timer.Sample sample = Timer.start();
return dynamoDbAsyncClient.putItem(PutItemRequest.builder()
.tableName(tableName)
.item(getItemFromPreKey(identifier, deviceId, preKey))
.build())
.thenRun(() -> sample.stop(storeKeyTimer));
}
/**
* Attempts to retrieve a single-use pre-key for a specific device. Keys may only be returned by this method at most
* once; once the key is returned, it is removed from the key store and subsequent calls to this method will never
* return the same key.
*
* @param identifier the identifier for the account/identity with which the target device is associated
* @param deviceId the identifier for the device within the given account/identity
*
* @return a future that yields a single-use pre-key if one is available or empty if no single-use pre-keys are
* available for the target device
*/
public CompletableFuture<Optional<K>> take(final UUID identifier, final long deviceId) {
final Timer.Sample sample = Timer.start();
final AttributeValue partitionKey = getPartitionKey(identifier);
final AtomicInteger keysConsidered = new AtomicInteger(0);
return Flux.from(dynamoDbAsyncClient.queryPaginator(QueryRequest.builder()
.tableName(tableName)
.keyConditionExpression("#uuid = :uuid AND begins_with (#sort, :sortprefix)")
.expressionAttributeNames(Map.of("#uuid", KEY_ACCOUNT_UUID, "#sort", KEY_DEVICE_ID_KEY_ID))
.expressionAttributeValues(Map.of(
":uuid", partitionKey,
":sortprefix", getSortKeyPrefix(deviceId)))
.projectionExpression(KEY_DEVICE_ID_KEY_ID)
.consistentRead(false)
.build())
.items())
.map(item -> DeleteItemRequest.builder()
.tableName(tableName)
.key(Map.of(
KEY_ACCOUNT_UUID, partitionKey,
KEY_DEVICE_ID_KEY_ID, item.get(KEY_DEVICE_ID_KEY_ID)))
.returnValues(ReturnValue.ALL_OLD)
.build())
.flatMap(deleteItemRequest -> Mono.fromFuture(dynamoDbAsyncClient.deleteItem(deleteItemRequest)), 1)
.doOnNext(deleteItemResponse -> keysConsidered.incrementAndGet())
.filter(DeleteItemResponse::hasAttributes)
.next()
.map(deleteItemResponse -> getPreKeyFromItem(deleteItemResponse.attributes()))
.toFuture()
.thenApply(Optional::ofNullable)
.whenComplete((maybeKey, throwable) -> {
sample.stop(Metrics.timer(takeKeyTimerName, KEY_PRESENT_TAG_NAME, String.valueOf(maybeKey != null && maybeKey.isPresent())));
keysConsideredForTakeDistributionSummary.record(keysConsidered.get());
});
}
/**
* Estimates the number of single-use pre-keys available for a given device.
* @param identifier the identifier for the account/identity with which the target device is associated
* @param deviceId the identifier for the device within the given account/identity
* @return a future that yields the approximate number of single-use pre-keys currently available for the target
* device
*/
public CompletableFuture<Integer> getCount(final UUID identifier, final long deviceId) {
final Timer.Sample sample = Timer.start();
// Getting an accurate count from DynamoDB can be very confusing. See:
//
// - https://github.com/aws/aws-sdk-java/issues/693
// - https://github.com/aws/aws-sdk-java/issues/915
// - https://docs.aws.amazon.com/amazondynamodb/latest/developerguide/Query.html#Query.Count
return Flux.from(dynamoDbAsyncClient.queryPaginator(QueryRequest.builder()
.tableName(tableName)
.keyConditionExpression("#uuid = :uuid AND begins_with (#sort, :sortprefix)")
.expressionAttributeNames(Map.of("#uuid", KEY_ACCOUNT_UUID, "#sort", KEY_DEVICE_ID_KEY_ID))
.expressionAttributeValues(Map.of(
":uuid", getPartitionKey(identifier),
":sortprefix", getSortKeyPrefix(deviceId)))
.select(Select.COUNT)
.consistentRead(false)
.build()))
.map(QueryResponse::count)
.reduce(0, Integer::sum)
.toFuture()
.whenComplete((keyCount, throwable) -> {
sample.stop(getKeyCountTimer);
if (throwable == null && keyCount != null) {
availableKeyCountDistributionSummary.record(keyCount);
}
});
}
/**
* Removes all single-use pre-keys for all devices associated with the given account/identity.
*
* @param identifier the identifier for the account/identity for which to remove single-use pre-keys
*
* @return a future that completes when all single-use pre-keys have been removed for all devices associated with the
* given account/identity
*/
public CompletableFuture<Void> delete(final UUID identifier) {
final Timer.Sample sample = Timer.start();
return deleteItems(getPartitionKey(identifier), Flux.from(dynamoDbAsyncClient.queryPaginator(QueryRequest.builder()
.tableName(tableName)
.keyConditionExpression("#uuid = :uuid")
.expressionAttributeNames(Map.of("#uuid", KEY_ACCOUNT_UUID))
.expressionAttributeValues(Map.of(":uuid", getPartitionKey(identifier)))
.projectionExpression(KEY_DEVICE_ID_KEY_ID)
.consistentRead(true)
.build())
.items()))
.thenRun(() -> sample.stop(deleteForAccountTimer));
}
/**
* Removes all single-use pre-keys for a specific device.
*
* @param identifier the identifier for the account/identity with which the target device is associated
* @param deviceId the identifier for the device within the given account/identity
* @return a future that completes when all single-use pre-keys have been removed for the target device
*/
public CompletableFuture<Void> delete(final UUID identifier, final long deviceId) {
final Timer.Sample sample = Timer.start();
return deleteItems(getPartitionKey(identifier), Flux.from(dynamoDbAsyncClient.queryPaginator(QueryRequest.builder()
.tableName(tableName)
.keyConditionExpression("#uuid = :uuid AND begins_with (#sort, :sortprefix)")
.expressionAttributeNames(Map.of("#uuid", KEY_ACCOUNT_UUID, "#sort", KEY_DEVICE_ID_KEY_ID))
.expressionAttributeValues(Map.of(
":uuid", getPartitionKey(identifier),
":sortprefix", getSortKeyPrefix(deviceId)))
.projectionExpression(KEY_DEVICE_ID_KEY_ID)
.consistentRead(true)
.build())
.items()))
.thenRun(() -> sample.stop(deleteForDeviceTimer));
}
private CompletableFuture<Void> deleteItems(final AttributeValue partitionKey, final Flux<Map<String, AttributeValue>> items) {
return items
.map(item -> DeleteItemRequest.builder()
.tableName(tableName)
.key(Map.of(
KEY_ACCOUNT_UUID, partitionKey,
KEY_DEVICE_ID_KEY_ID, item.get(KEY_DEVICE_ID_KEY_ID)
))
.build())
.flatMap(deleteItemRequest -> Mono.fromFuture(dynamoDbAsyncClient.deleteItem(deleteItemRequest)))
// Idiom: wait for everything to finish, but discard the results
.reduce(0, (a, b) -> 0)
.toFuture()
.thenRun(Util.NOOP);
}
protected static AttributeValue getPartitionKey(final UUID accountUuid) {
return AttributeValues.fromUUID(accountUuid);
}
protected static AttributeValue getSortKey(final long deviceId, final long keyId) {
final ByteBuffer byteBuffer = ByteBuffer.wrap(new byte[16]);
byteBuffer.putLong(deviceId);
byteBuffer.putLong(keyId);
return AttributeValues.fromByteBuffer(byteBuffer.flip());
}
private static AttributeValue getSortKeyPrefix(final long deviceId) {
final ByteBuffer byteBuffer = ByteBuffer.wrap(new byte[8]);
byteBuffer.putLong(deviceId);
return AttributeValues.fromByteBuffer(byteBuffer.flip());
}
protected abstract Map<String, AttributeValue> getItemFromPreKey(final UUID identifier, final long deviceId,
final K preKey);
protected abstract K getPreKeyFromItem(final Map<String, AttributeValue> item);
/**
* Extracts a byte array from an {@link AttributeValue} that may be either a byte array or a base64-encoded string.
*
* @param attributeValue the {@code AttributeValue} from which to extract a byte array
*
* @return the byte array represented by the given {@code AttributeValue}
*/
@VisibleForTesting
byte[] extractByteArray(final AttributeValue attributeValue) {
if (attributeValue.b() != null) {
readBytesFromByteArrayCounter.increment();
return attributeValue.b().asByteArray();
} else if (StringUtils.isNotBlank(attributeValue.s())) {
parseBytesFromStringCounter.increment();
return Base64.getDecoder().decode(attributeValue.s());
}
throw new IllegalArgumentException("Attribute value has neither a byte array nor a string value");
}
}

View File

@ -42,7 +42,7 @@ import org.whispersystems.textsecuregcm.storage.Accounts;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.DeletedAccounts;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
import org.whispersystems.textsecuregcm.storage.Keys;
import org.whispersystems.textsecuregcm.storage.KeysManager;
import org.whispersystems.textsecuregcm.storage.MessagesCache;
import org.whispersystems.textsecuregcm.storage.MessagesDynamoDb;
import org.whispersystems.textsecuregcm.storage.MessagesManager;
@ -171,10 +171,11 @@ public class AssignUsernameCommand extends EnvironmentCommand<WhisperServerConfi
configuration.getDynamoDbTables().getPhoneNumberIdentifiers().getTableName());
Profiles profiles = new Profiles(dynamoDbClient, dynamoDbAsyncClient,
configuration.getDynamoDbTables().getProfiles().getTableName());
Keys keys = new Keys(dynamoDbClient,
KeysManager keys = new KeysManager(
dynamoDbAsyncClient,
configuration.getDynamoDbTables().getEcKeys().getTableName(),
configuration.getDynamoDbTables().getPqKeys().getTableName(),
configuration.getDynamoDbTables().getPqLastResortKeys().getTableName());
configuration.getDynamoDbTables().getKemKeys().getTableName(),
configuration.getDynamoDbTables().getKemLastResortKeys().getTableName());
MessagesDynamoDb messagesDynamoDb = new MessagesDynamoDb(dynamoDbClient, dynamoDbAsyncClient,
configuration.getDynamoDbTables().getMessages().getTableName(),
configuration.getDynamoDbTables().getMessages().getExpiration(),

View File

@ -36,7 +36,7 @@ import org.whispersystems.textsecuregcm.storage.Accounts;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.DeletedAccounts;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
import org.whispersystems.textsecuregcm.storage.Keys;
import org.whispersystems.textsecuregcm.storage.KeysManager;
import org.whispersystems.textsecuregcm.storage.MessagesCache;
import org.whispersystems.textsecuregcm.storage.MessagesDynamoDb;
import org.whispersystems.textsecuregcm.storage.MessagesManager;
@ -66,7 +66,7 @@ record CommandDependencies(
MessagesManager messagesManager,
StoredVerificationCodeManager pendingAccountsManager,
ClientPresenceManager clientPresenceManager,
Keys keys,
KeysManager keysManager,
FaultTolerantRedisCluster cacheCluster,
ClientResources redisClusterClientResources) {
@ -153,10 +153,11 @@ record CommandDependencies(
configuration.getDynamoDbTables().getPhoneNumberIdentifiers().getTableName());
Profiles profiles = new Profiles(dynamoDbClient, dynamoDbAsyncClient,
configuration.getDynamoDbTables().getProfiles().getTableName());
Keys keys = new Keys(dynamoDbClient,
KeysManager keys = new KeysManager(
dynamoDbAsyncClient,
configuration.getDynamoDbTables().getEcKeys().getTableName(),
configuration.getDynamoDbTables().getPqKeys().getTableName(),
configuration.getDynamoDbTables().getPqLastResortKeys().getTableName());
configuration.getDynamoDbTables().getKemKeys().getTableName(),
configuration.getDynamoDbTables().getKemLastResortKeys().getTableName());
MessagesDynamoDb messagesDynamoDb = new MessagesDynamoDb(dynamoDbClient, dynamoDbAsyncClient,
configuration.getDynamoDbTables().getMessages().getTableName(),
configuration.getDynamoDbTables().getMessages().getExpiration(),

View File

@ -65,7 +65,7 @@ public class UnlinkDeviceCommand extends EnvironmentCommand<WhisperServerConfigu
account = deps.accountsManager().update(account, a -> a.removeDevice(deviceId));
System.out.format("Removing keys for device %s::%d\n", aci, deviceId);
deps.keys().delete(account.getUuid(), deviceId);
deps.keysManager().delete(account.getUuid(), deviceId);
System.out.format("Clearing additional messages for %s::%d\n", aci, deviceId);
deps.messagesManager().clear(account.getUuid(), deviceId);

View File

@ -67,7 +67,7 @@ import org.whispersystems.textsecuregcm.registration.RegistrationServiceClient;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.Keys;
import org.whispersystems.textsecuregcm.storage.KeysManager;
import org.whispersystems.textsecuregcm.storage.RegistrationRecoveryPasswordsManager;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.tests.util.KeysHelper;
@ -90,7 +90,7 @@ class RegistrationControllerTest {
RegistrationLockVerificationManager.class);
private final RegistrationRecoveryPasswordsManager registrationRecoveryPasswordsManager = mock(
RegistrationRecoveryPasswordsManager.class);
private final Keys keys = mock(Keys.class);
private final KeysManager keysManager = mock(KeysManager.class);
private final RateLimiters rateLimiters = mock(RateLimiters.class);
private final RateLimiter registrationLimiter = mock(RateLimiter.class);
@ -105,7 +105,7 @@ class RegistrationControllerTest {
.addResource(
new RegistrationController(accountsManager,
new PhoneVerificationTokenManager(registrationServiceClient, registrationRecoveryPasswordsManager),
registrationLockVerificationManager, keys, rateLimiters))
registrationLockVerificationManager, keysManager, rateLimiters))
.build();
@BeforeEach
@ -669,8 +669,8 @@ class RegistrationControllerTest {
verify(device).setSignedPreKey(expectedAciSignedPreKey);
verify(device).setPhoneNumberIdentitySignedPreKey(expectedPniSignedPreKey);
verify(keys).storePqLastResort(accountIdentifier, Map.of(Device.MASTER_ID, expectedAciPqLastResortPreKey));
verify(keys).storePqLastResort(phoneNumberIdentifier, Map.of(Device.MASTER_ID, expectedPniPqLastResortPreKey));
verify(keysManager).storePqLastResort(accountIdentifier, Map.of(Device.MASTER_ID, expectedAciPqLastResortPreKey));
verify(keysManager).storePqLastResort(phoneNumberIdentifier, Map.of(Device.MASTER_ID, expectedPniPqLastResortPreKey));
expectedApnsToken.ifPresentOrElse(expectedToken -> verify(device).setApnId(expectedToken),
() -> verify(device, never()).setApnId(any()));

View File

@ -101,7 +101,9 @@ class AccountsManagerChangeNumberIntegrationTest {
accounts,
phoneNumberIdentifiers,
CACHE_CLUSTER_EXTENSION.getRedisCluster(),
accountLockManager, deletedAccounts, mock(Keys.class),
accountLockManager,
deletedAccounts,
mock(KeysManager.class),
mock(MessagesManager.class),
mock(ProfilesManager.class),
mock(StoredVerificationCodeManager.class),

View File

@ -112,7 +112,9 @@ class AccountsManagerConcurrentModificationIntegrationTest {
accounts,
phoneNumberIdentifiers,
RedisClusterHelper.builder().stringCommands(commands).build(),
accountLockManager, deletedAccounts, mock(Keys.class),
accountLockManager,
deletedAccounts,
mock(KeysManager.class),
mock(MessagesManager.class),
mock(ProfilesManager.class),
mock(StoredVerificationCodeManager.class),

View File

@ -71,7 +71,7 @@ class AccountsManagerTest {
private Accounts accounts;
private DeletedAccounts deletedAccounts;
private Keys keys;
private KeysManager keysManager;
private MessagesManager messagesManager;
private ProfilesManager profilesManager;
private ClientPresenceManager clientPresenceManager;
@ -94,7 +94,7 @@ class AccountsManagerTest {
void setup() throws InterruptedException {
accounts = mock(Accounts.class);
deletedAccounts = mock(DeletedAccounts.class);
keys = mock(Keys.class);
keysManager = mock(KeysManager.class);
messagesManager = mock(MessagesManager.class);
profilesManager = mock(ProfilesManager.class);
clientPresenceManager = mock(ClientPresenceManager.class);
@ -157,7 +157,7 @@ class AccountsManagerTest {
RedisClusterHelper.builder().stringCommands(commands).build(),
accountLockManager,
deletedAccounts,
keys,
keysManager,
messagesManager,
profilesManager,
mock(StoredVerificationCodeManager.class),
@ -542,7 +542,7 @@ class AccountsManagerTest {
accountsManager.create(e164, "password", null, attributes, new ArrayList<>());
verify(accounts).create(argThat(account -> e164.equals(account.getNumber())));
verifyNoInteractions(keys);
verifyNoInteractions(keysManager);
verifyNoInteractions(messagesManager);
verifyNoInteractions(profilesManager);
}
@ -565,8 +565,8 @@ class AccountsManagerTest {
verify(accounts)
.create(argThat(account -> e164.equals(account.getNumber()) && existingUuid.equals(account.getUuid())));
verify(keys).delete(existingUuid);
verify(keys).delete(phoneNumberIdentifiersByE164.get(e164));
verify(keysManager).delete(existingUuid);
verify(keysManager).delete(phoneNumberIdentifiersByE164.get(e164));
verify(messagesManager).clear(existingUuid);
verify(profilesManager).deleteAll(existingUuid);
verify(clientPresenceManager).disconnectAllPresencesForUuid(existingUuid);
@ -585,7 +585,7 @@ class AccountsManagerTest {
verify(accounts).create(
argThat(account -> e164.equals(account.getNumber()) && recentlyDeletedUuid.equals(account.getUuid())));
verifyNoInteractions(keys);
verifyNoInteractions(keysManager);
verifyNoInteractions(messagesManager);
verifyNoInteractions(profilesManager);
}
@ -646,8 +646,8 @@ class AccountsManagerTest {
assertTrue(phoneNumberIdentifiersByE164.containsKey(targetNumber));
verify(keys).delete(originalPni);
verify(keys).delete(phoneNumberIdentifiersByE164.get(targetNumber));
verify(keysManager).delete(originalPni);
verify(keysManager).delete(phoneNumberIdentifiersByE164.get(targetNumber));
}
@Test
@ -659,7 +659,7 @@ class AccountsManagerTest {
assertEquals(number, account.getNumber());
verify(deletedAccounts, never()).put(any(), any());
verify(keys, never()).delete(any());
verify(keysManager, never()).delete(any());
}
@Test
@ -674,7 +674,7 @@ class AccountsManagerTest {
verify(accounts, never()).update(any());
verifyNoInteractions(deletedAccounts);
verifyNoInteractions(keys);
verifyNoInteractions(keysManager);
}
@Test
@ -697,11 +697,11 @@ class AccountsManagerTest {
assertTrue(phoneNumberIdentifiersByE164.containsKey(targetNumber));
final UUID newPni = phoneNumberIdentifiersByE164.get(targetNumber);
verify(keys).delete(existingAccountUuid);
verify(keys).delete(originalPni);
verify(keys, atLeastOnce()).delete(targetPni);
verify(keys).delete(newPni);
verifyNoMoreInteractions(keys);
verify(keysManager).delete(existingAccountUuid);
verify(keysManager).delete(originalPni);
verify(keysManager, atLeastOnce()).delete(targetPni);
verify(keysManager).delete(newPni);
verifyNoMoreInteractions(keysManager);
}
@Test
@ -723,7 +723,7 @@ class AccountsManagerTest {
final Account existingAccount = AccountsHelper.generateTestAccount(targetNumber, existingAccountUuid, targetPni, new ArrayList<>(), new byte[16]);
when(accounts.getByE164(targetNumber)).thenReturn(Optional.of(existingAccount));
when(keys.getPqEnabledDevices(uuid)).thenReturn(List.of(1L));
when(keysManager.getPqEnabledDevices(uuid)).thenReturn(List.of(1L));
final List<Device> devices = List.of(DevicesHelper.createDevice(1L, 0L, 101), DevicesHelper.createDevice(2L, 0L, 102));
final Account account = AccountsHelper.generateTestAccount(originalNumber, uuid, originalPni, devices, new byte[16]);
@ -735,13 +735,13 @@ class AccountsManagerTest {
assertTrue(phoneNumberIdentifiersByE164.containsKey(targetNumber));
final UUID newPni = phoneNumberIdentifiersByE164.get(targetNumber);
verify(keys).delete(existingAccountUuid);
verify(keys, atLeastOnce()).delete(targetPni);
verify(keys).delete(newPni);
verify(keys).delete(originalPni);
verify(keys).getPqEnabledDevices(uuid);
verify(keys).storePqLastResort(eq(newPni), eq(Map.of(1L, newSignedPqKeys.get(1L))));
verifyNoMoreInteractions(keys);
verify(keysManager).delete(existingAccountUuid);
verify(keysManager, atLeastOnce()).delete(targetPni);
verify(keysManager).delete(newPni);
verify(keysManager).delete(originalPni);
verify(keysManager).getPqEnabledDevices(uuid);
verify(keysManager).storePqLastResort(eq(newPni), eq(Map.of(1L, newSignedPqKeys.get(1L))));
verifyNoMoreInteractions(keysManager);
}
@Test
@ -792,7 +792,7 @@ class AccountsManagerTest {
verify(accounts).update(any());
verifyNoInteractions(deletedAccounts);
verify(keys).delete(oldPni);
verify(keysManager).delete(oldPni);
}
@Test
@ -813,7 +813,7 @@ class AccountsManagerTest {
UUID oldUuid = account.getUuid();
UUID oldPni = account.getPhoneNumberIdentifier();
when(keys.getPqEnabledDevices(oldPni)).thenReturn(List.of(1L));
when(keysManager.getPqEnabledDevices(oldPni)).thenReturn(List.of(1L));
Map<Long, SignedPreKey> oldSignedPreKeys = account.getDevices().stream().collect(Collectors.toMap(Device::getId, Device::getSignedPreKey));
@ -839,10 +839,10 @@ class AccountsManagerTest {
verify(accounts).update(any());
verifyNoInteractions(deletedAccounts);
verify(keys).delete(oldPni);
verify(keysManager).delete(oldPni);
// only the pq key for the already-pq-enabled device should be saved
verify(keys).storePqLastResort(eq(oldPni), eq(Map.of(1L, newSignedPqKeys.get(1L))));
verify(keysManager).storePqLastResort(eq(oldPni), eq(Map.of(1L, newSignedPqKeys.get(1L))));
}
@Test

View File

@ -116,7 +116,7 @@ class AccountsManagerUsernameIntegrationTest {
CACHE_CLUSTER_EXTENSION.getRedisCluster(),
accountLockManager,
deletedAccounts,
mock(Keys.class),
mock(KeysManager.class),
mock(MessagesManager.class),
mock(ProfilesManager.class),
mock(StoredVerificationCodeManager.class),

View File

@ -156,7 +156,7 @@ class AccountsTest {
mock(FaultTolerantRedisCluster.class),
mock(AccountLockManager.class),
mock(DeletedAccounts.class),
mock(Keys.class),
mock(KeysManager.class),
mock(MessagesManager.class),
mock(ProfilesManager.class),
mock(StoredVerificationCodeManager.class),

View File

@ -88,44 +88,44 @@ public final class DynamoDbExtensionSchema {
List.of(), List.of()),
EC_KEYS("keys_test",
Keys.KEY_ACCOUNT_UUID,
Keys.KEY_DEVICE_ID_KEY_ID,
SingleUsePreKeyStore.KEY_ACCOUNT_UUID,
SingleUsePreKeyStore.KEY_DEVICE_ID_KEY_ID,
List.of(
AttributeDefinition.builder()
.attributeName(Keys.KEY_ACCOUNT_UUID)
.attributeName(SingleUsePreKeyStore.KEY_ACCOUNT_UUID)
.attributeType(ScalarAttributeType.B)
.build(),
AttributeDefinition.builder()
.attributeName(Keys.KEY_DEVICE_ID_KEY_ID)
.attributeName(SingleUsePreKeyStore.KEY_DEVICE_ID_KEY_ID)
.attributeType(ScalarAttributeType.B)
.build()),
List.of(), List.of()),
PQ_KEYS("pq_keys_test",
Keys.KEY_ACCOUNT_UUID,
Keys.KEY_DEVICE_ID_KEY_ID,
SingleUsePreKeyStore.KEY_ACCOUNT_UUID,
SingleUsePreKeyStore.KEY_DEVICE_ID_KEY_ID,
List.of(
AttributeDefinition.builder()
.attributeName(Keys.KEY_ACCOUNT_UUID)
.attributeName(SingleUsePreKeyStore.KEY_ACCOUNT_UUID)
.attributeType(ScalarAttributeType.B)
.build(),
AttributeDefinition.builder()
.attributeName(Keys.KEY_DEVICE_ID_KEY_ID)
.attributeName(SingleUsePreKeyStore.KEY_DEVICE_ID_KEY_ID)
.attributeType(ScalarAttributeType.B)
.build()),
List.of(), List.of()),
PQ_LAST_RESORT_KEYS("pq_last_resort_keys_test",
Keys.KEY_ACCOUNT_UUID,
Keys.KEY_DEVICE_ID_KEY_ID,
REPEATED_USE_SIGNED_PRE_KEYS("repeated_use_signed_pre_keys_test",
RepeatedUseSignedPreKeyStore.KEY_ACCOUNT_UUID,
RepeatedUseSignedPreKeyStore.KEY_DEVICE_ID,
List.of(
AttributeDefinition.builder()
.attributeName(Keys.KEY_ACCOUNT_UUID)
.attributeName(RepeatedUseSignedPreKeyStore.KEY_ACCOUNT_UUID)
.attributeType(ScalarAttributeType.B)
.build(),
AttributeDefinition.builder()
.attributeName(Keys.KEY_DEVICE_ID_KEY_ID)
.attributeType(ScalarAttributeType.B)
.attributeName(RepeatedUseSignedPreKeyStore.KEY_DEVICE_ID)
.attributeType(ScalarAttributeType.N)
.build()),
List.of(), List.of()),

View File

@ -0,0 +1,257 @@
/*
* Copyright 2021-2022 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertIterableEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import java.security.SecureRandom;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.UUID;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.signal.libsignal.protocol.ecc.Curve;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.whispersystems.textsecuregcm.entities.PreKey;
import org.whispersystems.textsecuregcm.entities.SignedPreKey;
import org.whispersystems.textsecuregcm.storage.DynamoDbExtensionSchema.Tables;
import org.whispersystems.textsecuregcm.tests.util.KeysHelper;
class KeysManagerTest {
private KeysManager keysManager;
@RegisterExtension
static final DynamoDbExtension DYNAMO_DB_EXTENSION = new DynamoDbExtension(
Tables.EC_KEYS, Tables.PQ_KEYS, Tables.REPEATED_USE_SIGNED_PRE_KEYS);
private static final UUID ACCOUNT_UUID = UUID.randomUUID();
private static final long DEVICE_ID = 1L;
@BeforeEach
void setup() {
keysManager = new KeysManager(
DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(),
Tables.EC_KEYS.tableName(),
Tables.PQ_KEYS.tableName(),
Tables.REPEATED_USE_SIGNED_PRE_KEYS.tableName());
}
@Test
void testStore() {
assertEquals(0, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID),
"Initial pre-key count for an account should be zero");
assertEquals(0, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID),
"Initial pre-key count for an account should be zero");
assertFalse(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).isPresent(),
"Initial last-resort pre-key for an account should be missing");
keysManager.store(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestPreKey(1)));
assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID));
keysManager.store(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestPreKey(1)));
assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID),
"Repeatedly storing same key should have no effect");
keysManager.store(ACCOUNT_UUID, DEVICE_ID, null, List.of(generateTestSignedPreKey(1)), null);
assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID),
"Uploading new PQ prekeys should have no effect on EC prekeys");
assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID));
keysManager.store(ACCOUNT_UUID, DEVICE_ID, null, null, generateTestSignedPreKey(1001));
assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID),
"Uploading new PQ last-resort prekey should have no effect on EC prekeys");
assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID),
"Uploading new PQ last-resort prekey should have no effect on one-time PQ prekeys");
assertEquals(1001, keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).get().getKeyId());
keysManager.store(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestPreKey(2)), null, null);
assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID),
"Inserting a new key should overwrite all prior keys of the same type for the given account/device");
assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID),
"Uploading new EC prekeys should have no effect on PQ prekeys");
keysManager.store(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestPreKey(3)), List.of(generateTestSignedPreKey(2)), null);
assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID),
"Inserting a new key should overwrite all prior keys of the same type for the given account/device");
assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID),
"Inserting a new key should overwrite all prior keys of the same type for the given account/device");
keysManager.store(ACCOUNT_UUID, DEVICE_ID,
List.of(generateTestPreKey(4), generateTestPreKey(5)),
List.of(generateTestSignedPreKey(6), generateTestSignedPreKey(7)),
generateTestSignedPreKey(1002));
assertEquals(2, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID),
"Inserting multiple new keys should overwrite all prior keys for the given account/device");
assertEquals(2, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID),
"Inserting multiple new keys should overwrite all prior keys for the given account/device");
assertEquals(1002, keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).get().getKeyId(),
"Uploading new last-resort key should overwrite prior last-resort key for the account/device");
}
@Test
void testTakeAccountAndDeviceId() {
assertEquals(Optional.empty(), keysManager.takeEC(ACCOUNT_UUID, DEVICE_ID));
final PreKey preKey = generateTestPreKey(1);
keysManager.store(ACCOUNT_UUID, DEVICE_ID, List.of(preKey, generateTestPreKey(2)));
final Optional<PreKey> takenKey = keysManager.takeEC(ACCOUNT_UUID, DEVICE_ID);
assertEquals(Optional.of(preKey), takenKey);
assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID));
}
@Test
void testTakePQ() {
assertEquals(Optional.empty(), keysManager.takeEC(ACCOUNT_UUID, DEVICE_ID));
final SignedPreKey preKey1 = generateTestSignedPreKey(1);
final SignedPreKey preKey2 = generateTestSignedPreKey(2);
final SignedPreKey preKeyLast = generateTestSignedPreKey(1001);
keysManager.store(ACCOUNT_UUID, DEVICE_ID, null, List.of(preKey1, preKey2), preKeyLast);
assertEquals(Optional.of(preKey1), keysManager.takePQ(ACCOUNT_UUID, DEVICE_ID));
assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID));
assertEquals(Optional.of(preKey2), keysManager.takePQ(ACCOUNT_UUID, DEVICE_ID));
assertEquals(0, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID));
assertEquals(Optional.of(preKeyLast), keysManager.takePQ(ACCOUNT_UUID, DEVICE_ID));
assertEquals(0, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID));
assertEquals(Optional.of(preKeyLast), keysManager.takePQ(ACCOUNT_UUID, DEVICE_ID));
assertEquals(0, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID));
}
@Test
void testGetCount() {
assertEquals(0, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID));
assertEquals(0, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID));
keysManager.store(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestPreKey(1)), List.of(generateTestSignedPreKey(1)), null);
assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID));
assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID));
}
@Test
void testDeleteByAccount() {
keysManager.store(ACCOUNT_UUID, DEVICE_ID,
List.of(generateTestPreKey(1), generateTestPreKey(2)),
List.of(generateTestSignedPreKey(3), generateTestSignedPreKey(4)),
generateTestSignedPreKey(5));
keysManager.store(ACCOUNT_UUID, DEVICE_ID + 1,
List.of(generateTestPreKey(6)),
List.of(generateTestSignedPreKey(7)),
generateTestSignedPreKey(8));
assertEquals(2, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID));
assertEquals(2, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID));
assertTrue(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).isPresent());
assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID + 1));
assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID + 1));
assertTrue(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID + 1).isPresent());
keysManager.delete(ACCOUNT_UUID);
assertEquals(0, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID));
assertEquals(0, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID));
assertFalse(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).isPresent());
assertEquals(0, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID + 1));
assertEquals(0, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID + 1));
assertFalse(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID + 1).isPresent());
}
@Test
void testDeleteByAccountAndDevice() {
keysManager.store(ACCOUNT_UUID, DEVICE_ID,
List.of(generateTestPreKey(1), generateTestPreKey(2)),
List.of(generateTestSignedPreKey(3), generateTestSignedPreKey(4)),
generateTestSignedPreKey(5));
keysManager.store(ACCOUNT_UUID, DEVICE_ID + 1,
List.of(generateTestPreKey(6)),
List.of(generateTestSignedPreKey(7)),
generateTestSignedPreKey(8));
assertEquals(2, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID));
assertEquals(2, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID));
assertTrue(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).isPresent());
assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID + 1));
assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID + 1));
assertTrue(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID + 1).isPresent());
keysManager.delete(ACCOUNT_UUID, DEVICE_ID);
assertEquals(0, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID));
assertEquals(0, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID));
assertFalse(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).isPresent());
assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID + 1));
assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID + 1));
assertTrue(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID + 1).isPresent());
}
@Test
void testStorePqLastResort() {
assertEquals(0, keysManager.getPqEnabledDevices(ACCOUNT_UUID).size());
final ECKeyPair identityKeyPair = Curve.generateKeyPair();
keysManager.storePqLastResort(
ACCOUNT_UUID,
Map.of(1L, KeysHelper.signedKEMPreKey(1, identityKeyPair), 2L, KeysHelper.signedKEMPreKey(2, identityKeyPair)));
assertEquals(2, keysManager.getPqEnabledDevices(ACCOUNT_UUID).size());
assertEquals(1L, keysManager.getLastResort(ACCOUNT_UUID, 1L).get().getKeyId());
assertEquals(2L, keysManager.getLastResort(ACCOUNT_UUID, 2L).get().getKeyId());
assertFalse(keysManager.getLastResort(ACCOUNT_UUID, 3L).isPresent());
keysManager.storePqLastResort(
ACCOUNT_UUID,
Map.of(1L, KeysHelper.signedKEMPreKey(3, identityKeyPair), 3L, KeysHelper.signedKEMPreKey(4, identityKeyPair)));
assertEquals(3, keysManager.getPqEnabledDevices(ACCOUNT_UUID).size(), "storing new last-resort keys should not create duplicates");
assertEquals(3L, keysManager.getLastResort(ACCOUNT_UUID, 1L).get().getKeyId(), "storing new last-resort keys should overwrite old ones");
assertEquals(2L, keysManager.getLastResort(ACCOUNT_UUID, 2L).get().getKeyId(), "storing new last-resort keys should leave untouched ones alone");
assertEquals(4L, keysManager.getLastResort(ACCOUNT_UUID, 3L).get().getKeyId(), "storing new last-resort keys should overwrite old ones");
}
@Test
void testGetPqEnabledDevices() {
final ECKeyPair identityKeyPair = Curve.generateKeyPair();
keysManager.store(ACCOUNT_UUID, DEVICE_ID, null, List.of(KeysHelper.signedKEMPreKey(1, identityKeyPair)), null);
keysManager.store(ACCOUNT_UUID, DEVICE_ID + 1, null, null, KeysHelper.signedKEMPreKey(2, identityKeyPair));
keysManager.store(ACCOUNT_UUID, DEVICE_ID + 2, null, List.of(KeysHelper.signedKEMPreKey(3, identityKeyPair)), KeysHelper.signedKEMPreKey(4, identityKeyPair));
keysManager.store(ACCOUNT_UUID, DEVICE_ID + 3, null, null, null);
assertIterableEquals(
Set.of(DEVICE_ID + 1, DEVICE_ID + 2),
Set.copyOf(keysManager.getPqEnabledDevices(ACCOUNT_UUID)));
}
private static PreKey generateTestPreKey(final long keyId) {
final byte[] key = new byte[32];
new SecureRandom().nextBytes(key);
return new PreKey(keyId, key);
}
private static SignedPreKey generateTestSignedPreKey(final long keyId) {
final byte[] key = new byte[32];
final byte[] signature = new byte[32];
final SecureRandom secureRandom = new SecureRandom();
secureRandom.nextBytes(key);
secureRandom.nextBytes(signature);
return new SignedPreKey(keyId, key, signature);
}
}

View File

@ -1,314 +0,0 @@
/*
* Copyright 2021-2022 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import java.security.SecureRandom;
import java.util.Base64;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.UUID;
import java.util.stream.Stream;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.signal.libsignal.protocol.ecc.Curve;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.whispersystems.textsecuregcm.entities.PreKey;
import org.whispersystems.textsecuregcm.entities.SignedPreKey;
import org.whispersystems.textsecuregcm.storage.DynamoDbExtensionSchema.Tables;
import org.whispersystems.textsecuregcm.tests.util.KeysHelper;
import org.whispersystems.textsecuregcm.util.AttributeValues;
import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.services.dynamodb.model.AttributeValue;
import software.amazon.awssdk.services.dynamodb.model.QueryRequest;
import software.amazon.awssdk.services.dynamodb.model.QueryResponse;
import software.amazon.awssdk.services.dynamodb.model.Select;
import static org.junit.jupiter.api.Assertions.*;
class KeysTest {
private Keys keys;
@RegisterExtension
static final DynamoDbExtension DYNAMO_DB_EXTENSION = new DynamoDbExtension(
Tables.EC_KEYS, Tables.PQ_KEYS, Tables.PQ_LAST_RESORT_KEYS);
private static final UUID ACCOUNT_UUID = UUID.randomUUID();
private static final long DEVICE_ID = 1L;
@BeforeEach
void setup() {
keys = new Keys(
DYNAMO_DB_EXTENSION.getDynamoDbClient(),
Tables.EC_KEYS.tableName(),
Tables.PQ_KEYS.tableName(),
Tables.PQ_LAST_RESORT_KEYS.tableName());
}
@Test
void testStore() {
assertEquals(0, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID),
"Initial pre-key count for an account should be zero");
assertEquals(0, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID),
"Initial pre-key count for an account should be zero");
assertFalse(keys.getLastResort(ACCOUNT_UUID, DEVICE_ID).isPresent(),
"Initial last-resort pre-key for an account should be missing");
keys.store(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestPreKey(1)));
assertEquals(1, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID));
keys.store(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestPreKey(1)));
assertEquals(1, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID),
"Repeatedly storing same key should have no effect");
keys.store(ACCOUNT_UUID, DEVICE_ID, null, List.of(generateTestSignedPreKey(1)), null);
assertEquals(1, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID),
"Uploading new PQ prekeys should have no effect on EC prekeys");
assertEquals(1, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID));
keys.store(ACCOUNT_UUID, DEVICE_ID, null, null, generateTestSignedPreKey(1001));
assertEquals(1, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID),
"Uploading new PQ last-resort prekey should have no effect on EC prekeys");
assertEquals(1, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID),
"Uploading new PQ last-resort prekey should have no effect on one-time PQ prekeys");
assertEquals(1001, keys.getLastResort(ACCOUNT_UUID, DEVICE_ID).get().getKeyId());
keys.store(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestPreKey(2)), null, null);
assertEquals(1, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID),
"Inserting a new key should overwrite all prior keys of the same type for the given account/device");
assertEquals(1, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID),
"Uploading new EC prekeys should have no effect on PQ prekeys");
keys.store(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestPreKey(3)), List.of(generateTestSignedPreKey(2)), null);
assertEquals(1, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID),
"Inserting a new key should overwrite all prior keys of the same type for the given account/device");
assertEquals(1, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID),
"Inserting a new key should overwrite all prior keys of the same type for the given account/device");
keys.store(ACCOUNT_UUID, DEVICE_ID,
List.of(generateTestPreKey(4), generateTestPreKey(5)),
List.of(generateTestSignedPreKey(6), generateTestSignedPreKey(7)),
generateTestSignedPreKey(1002));
assertEquals(2, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID),
"Inserting multiple new keys should overwrite all prior keys for the given account/device");
assertEquals(2, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID),
"Inserting multiple new keys should overwrite all prior keys for the given account/device");
assertEquals(1002, keys.getLastResort(ACCOUNT_UUID, DEVICE_ID).get().getKeyId(),
"Uploading new last-resort key should overwrite prior last-resort key for the account/device");
}
@Test
void testTakeAccountAndDeviceId() {
assertEquals(Optional.empty(), keys.takeEC(ACCOUNT_UUID, DEVICE_ID));
final PreKey preKey = generateTestPreKey(1);
keys.store(ACCOUNT_UUID, DEVICE_ID, List.of(preKey, generateTestPreKey(2)));
final Optional<PreKey> takenKey = keys.takeEC(ACCOUNT_UUID, DEVICE_ID);
assertEquals(Optional.of(preKey), takenKey);
assertEquals(1, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID));
}
@Test
void testTakePQ() {
assertEquals(Optional.empty(), keys.takeEC(ACCOUNT_UUID, DEVICE_ID));
final SignedPreKey preKey1 = generateTestSignedPreKey(1);
final SignedPreKey preKey2 = generateTestSignedPreKey(2);
final SignedPreKey preKeyLast = generateTestSignedPreKey(1001);
keys.store(ACCOUNT_UUID, DEVICE_ID, null, List.of(preKey1, preKey2), preKeyLast);
assertEquals(Optional.of(preKey1), keys.takePQ(ACCOUNT_UUID, DEVICE_ID));
assertEquals(1, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID));
assertEquals(Optional.of(preKey2), keys.takePQ(ACCOUNT_UUID, DEVICE_ID));
assertEquals(0, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID));
assertEquals(Optional.of(preKeyLast), keys.takePQ(ACCOUNT_UUID, DEVICE_ID));
assertEquals(0, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID));
assertEquals(Optional.of(preKeyLast), keys.takePQ(ACCOUNT_UUID, DEVICE_ID));
assertEquals(0, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID));
}
@Test
void testGetCount() {
assertEquals(0, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID));
assertEquals(0, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID));
keys.store(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestPreKey(1)), List.of(generateTestSignedPreKey(1)), null);
assertEquals(1, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID));
assertEquals(1, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID));
}
@Test
void testDeleteByAccount() {
keys.store(ACCOUNT_UUID, DEVICE_ID,
List.of(generateTestPreKey(1), generateTestPreKey(2)),
List.of(generateTestSignedPreKey(3), generateTestSignedPreKey(4)),
generateTestSignedPreKey(5));
keys.store(ACCOUNT_UUID, DEVICE_ID + 1,
List.of(generateTestPreKey(6)),
List.of(generateTestSignedPreKey(7)),
generateTestSignedPreKey(8));
assertEquals(2, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID));
assertEquals(2, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID));
assertTrue(keys.getLastResort(ACCOUNT_UUID, DEVICE_ID).isPresent());
assertEquals(1, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID + 1));
assertEquals(1, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID + 1));
assertTrue(keys.getLastResort(ACCOUNT_UUID, DEVICE_ID + 1).isPresent());
keys.delete(ACCOUNT_UUID);
assertEquals(0, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID));
assertEquals(0, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID));
assertFalse(keys.getLastResort(ACCOUNT_UUID, DEVICE_ID).isPresent());
assertEquals(0, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID + 1));
assertEquals(0, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID + 1));
assertFalse(keys.getLastResort(ACCOUNT_UUID, DEVICE_ID + 1).isPresent());
}
@Test
void testDeleteByAccountAndDevice() {
keys.store(ACCOUNT_UUID, DEVICE_ID,
List.of(generateTestPreKey(1), generateTestPreKey(2)),
List.of(generateTestSignedPreKey(3), generateTestSignedPreKey(4)),
generateTestSignedPreKey(5));
keys.store(ACCOUNT_UUID, DEVICE_ID + 1,
List.of(generateTestPreKey(6)),
List.of(generateTestSignedPreKey(7)),
generateTestSignedPreKey(8));
assertEquals(2, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID));
assertEquals(2, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID));
assertTrue(keys.getLastResort(ACCOUNT_UUID, DEVICE_ID).isPresent());
assertEquals(1, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID + 1));
assertEquals(1, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID + 1));
assertTrue(keys.getLastResort(ACCOUNT_UUID, DEVICE_ID + 1).isPresent());
keys.delete(ACCOUNT_UUID, DEVICE_ID);
assertEquals(0, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID));
assertEquals(0, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID));
assertFalse(keys.getLastResort(ACCOUNT_UUID, DEVICE_ID).isPresent());
assertEquals(1, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID + 1));
assertEquals(1, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID + 1));
assertTrue(keys.getLastResort(ACCOUNT_UUID, DEVICE_ID + 1).isPresent());
}
@Test
void testStorePqLastResort() {
assertEquals(0, getLastResortCount(ACCOUNT_UUID));
final ECKeyPair identityKeyPair = Curve.generateKeyPair();
keys.storePqLastResort(
ACCOUNT_UUID,
Map.of(1L, KeysHelper.signedKEMPreKey(1, identityKeyPair), 2L, KeysHelper.signedKEMPreKey(2, identityKeyPair)));
assertEquals(2, getLastResortCount(ACCOUNT_UUID));
assertEquals(1L, keys.getLastResort(ACCOUNT_UUID, 1L).get().getKeyId());
assertEquals(2L, keys.getLastResort(ACCOUNT_UUID, 2L).get().getKeyId());
assertFalse(keys.getLastResort(ACCOUNT_UUID, 3L).isPresent());
keys.storePqLastResort(
ACCOUNT_UUID,
Map.of(1L, KeysHelper.signedKEMPreKey(3, identityKeyPair), 3L, KeysHelper.signedKEMPreKey(4, identityKeyPair)));
assertEquals(3, getLastResortCount(ACCOUNT_UUID), "storing new last-resort keys should not create duplicates");
assertEquals(3L, keys.getLastResort(ACCOUNT_UUID, 1L).get().getKeyId(), "storing new last-resort keys should overwrite old ones");
assertEquals(2L, keys.getLastResort(ACCOUNT_UUID, 2L).get().getKeyId(), "storing new last-resort keys should leave untouched ones alone");
assertEquals(4L, keys.getLastResort(ACCOUNT_UUID, 3L).get().getKeyId(), "storing new last-resort keys should overwrite old ones");
}
private int getLastResortCount(UUID uuid) {
QueryRequest queryRequest = QueryRequest.builder()
.tableName(Tables.PQ_LAST_RESORT_KEYS.tableName())
.keyConditionExpression("#uuid = :uuid")
.expressionAttributeNames(Map.of("#uuid", Keys.KEY_ACCOUNT_UUID))
.expressionAttributeValues(Map.of(":uuid", AttributeValues.fromUUID(uuid)))
.select(Select.COUNT)
.build();
QueryResponse response = DYNAMO_DB_EXTENSION.getDynamoDbClient().query(queryRequest);
return response.count();
}
@Test
void testGetPqEnabledDevices() {
final ECKeyPair identityKeyPair = Curve.generateKeyPair();
keys.store(ACCOUNT_UUID, DEVICE_ID, null, List.of(KeysHelper.signedKEMPreKey(1, identityKeyPair)), null);
keys.store(ACCOUNT_UUID, DEVICE_ID + 1, null, null, KeysHelper.signedKEMPreKey(2, identityKeyPair));
keys.store(ACCOUNT_UUID, DEVICE_ID + 2, null, List.of(KeysHelper.signedKEMPreKey(3, identityKeyPair)), KeysHelper.signedKEMPreKey(4, identityKeyPair));
keys.store(ACCOUNT_UUID, DEVICE_ID + 3, null, null, null);
assertIterableEquals(
Set.of(DEVICE_ID + 1, DEVICE_ID + 2),
Set.copyOf(keys.getPqEnabledDevices(ACCOUNT_UUID)));
}
@Test
void testSortKeyPrefix() {
AttributeValue got = Keys.getSortKeyPrefix(123);
assertArrayEquals(new byte[]{0, 0, 0, 0, 0, 0, 0, 123}, got.b().asByteArray());
}
@ParameterizedTest
@MethodSource
void extractByteArray(final AttributeValue attributeValue, final byte[] expectedByteArray) {
assertArrayEquals(expectedByteArray, Keys.extractByteArray(attributeValue));
}
private static Stream<Arguments> extractByteArray() {
final byte[] key = Base64.getDecoder().decode("c+k+8zv8WaFdDjR9IOvCk6BcY5OI7rge/YUDkaDGyRc=");
return Stream.of(
Arguments.of(AttributeValue.fromB(SdkBytes.fromByteArray(key)), key),
Arguments.of(AttributeValue.fromS(Base64.getEncoder().encodeToString(key)), key),
Arguments.of(AttributeValue.fromS(Base64.getEncoder().withoutPadding().encodeToString(key)), key)
);
}
@ParameterizedTest
@MethodSource
void extractByteArrayIllegalArgument(final AttributeValue attributeValue) {
assertThrows(IllegalArgumentException.class, () -> Keys.extractByteArray(attributeValue));
}
private static Stream<Arguments> extractByteArrayIllegalArgument() {
return Stream.of(
Arguments.of(AttributeValue.fromN("12")),
Arguments.of(AttributeValue.fromS("")),
Arguments.of(AttributeValue.fromS("Definitely not legitimate base64 👎"))
);
}
private static PreKey generateTestPreKey(final long keyId) {
final byte[] key = new byte[32];
new SecureRandom().nextBytes(key);
return new PreKey(keyId, key);
}
private static SignedPreKey generateTestSignedPreKey(final long keyId) {
final byte[] key = new byte[32];
final byte[] signature = new byte[32];
final SecureRandom secureRandom = new SecureRandom();
secureRandom.nextBytes(key);
secureRandom.nextBytes(signature);
return new SignedPreKey(keyId, key, signature);
}
}

View File

@ -0,0 +1,149 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.reactivestreams.Subscriber;
import org.signal.libsignal.protocol.ecc.Curve;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.whispersystems.textsecuregcm.entities.SignedPreKey;
import org.whispersystems.textsecuregcm.tests.util.KeysHelper;
import org.whispersystems.textsecuregcm.util.AttributeValues;
import reactor.core.publisher.Flux;
import software.amazon.awssdk.core.async.SdkPublisher;
import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient;
import software.amazon.awssdk.services.dynamodb.model.AttributeValue;
import software.amazon.awssdk.services.dynamodb.model.DeleteItemRequest;
import software.amazon.awssdk.services.dynamodb.model.DeleteItemResponse;
import software.amazon.awssdk.services.dynamodb.model.QueryRequest;
import software.amazon.awssdk.services.dynamodb.paginators.QueryPublisher;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import static org.junit.jupiter.api.Assertions.*;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
class RepeatedUseSignedPreKeyStoreTest {
private RepeatedUseSignedPreKeyStore keys;
private static final ECKeyPair IDENTITY_KEY_PAIR = Curve.generateKeyPair();
@RegisterExtension
static final DynamoDbExtension DYNAMO_DB_EXTENSION =
new DynamoDbExtension(DynamoDbExtensionSchema.Tables.REPEATED_USE_SIGNED_PRE_KEYS);
@BeforeEach
void setUp() {
keys = new RepeatedUseSignedPreKeyStore(DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(),
DynamoDbExtensionSchema.Tables.REPEATED_USE_SIGNED_PRE_KEYS.tableName());
}
@Test
void storeFind() {
assertEquals(Optional.empty(), keys.find(UUID.randomUUID(), 1).join());
{
final UUID identifier = UUID.randomUUID();
final long deviceId = 1;
final SignedPreKey signedPreKey = generateSignedPreKey();
assertDoesNotThrow(() -> keys.store(identifier, deviceId, signedPreKey).join());
assertEquals(Optional.of(signedPreKey), keys.find(identifier, deviceId).join());
}
{
final UUID identifier = UUID.randomUUID();
final Map<Long, SignedPreKey> signedPreKeys = Map.of(
1L, generateSignedPreKey(),
2L, generateSignedPreKey()
);
assertDoesNotThrow(() -> keys.store(identifier, signedPreKeys).join());
assertEquals(Optional.of(signedPreKeys.get(1L)), keys.find(identifier, 1).join());
assertEquals(Optional.of(signedPreKeys.get(2L)), keys.find(identifier, 2).join());
}
}
@Test
void delete() {
assertDoesNotThrow(() -> keys.delete(UUID.randomUUID()).join());
{
final UUID identifier = UUID.randomUUID();
final Map<Long, SignedPreKey> signedPreKeys = Map.of(
1L, generateSignedPreKey(),
2L, generateSignedPreKey()
);
keys.store(identifier, signedPreKeys).join();
keys.delete(identifier, 1).join();
assertEquals(Optional.empty(), keys.find(identifier, 1).join());
assertEquals(Optional.of(signedPreKeys.get(2L)), keys.find(identifier, 2).join());
}
{
final UUID identifier = UUID.randomUUID();
final Map<Long, SignedPreKey> signedPreKeys = Map.of(
1L, generateSignedPreKey(),
2L, generateSignedPreKey()
);
keys.store(identifier, signedPreKeys).join();
keys.delete(identifier).join();
assertEquals(Optional.empty(), keys.find(identifier, 1).join());
assertEquals(Optional.empty(), keys.find(identifier, 2).join());
}
}
@Test
void deleteWithError() {
final DynamoDbAsyncClient mockClient = mock(DynamoDbAsyncClient.class);
final QueryPublisher queryPublisher = mock(QueryPublisher.class);
final SdkPublisher<Map<String, AttributeValue>> itemPublisher = new SdkPublisher<Map<String, AttributeValue>>() {
final Flux<Map<String, AttributeValue>> items = Flux.just(
Map.of(RepeatedUseSignedPreKeyStore.KEY_DEVICE_ID, AttributeValues.fromLong(1)),
Map.of(RepeatedUseSignedPreKeyStore.KEY_DEVICE_ID, AttributeValues.fromLong(2)));
@Override
public void subscribe(final Subscriber<? super Map<String, AttributeValue>> subscriber) {
items.subscribe(subscriber);
}
};
when(queryPublisher.items()).thenReturn(itemPublisher);
when(mockClient.queryPaginator(any(QueryRequest.class))).thenReturn(queryPublisher);
final Exception deleteItemException = new IllegalArgumentException("OH NO");
when(mockClient.deleteItem(any(DeleteItemRequest.class)))
.thenReturn(CompletableFuture.completedFuture(DeleteItemResponse.builder().build()))
.thenReturn(CompletableFuture.failedFuture(deleteItemException));
final RepeatedUseSignedPreKeyStore keyStore = new RepeatedUseSignedPreKeyStore(mockClient,
DynamoDbExtensionSchema.Tables.REPEATED_USE_SIGNED_PRE_KEYS.tableName());
final CompletionException completionException =
assertThrows(CompletionException.class, () -> keyStore.delete(UUID.randomUUID()).join());
assertEquals(deleteItemException, completionException.getCause());
}
private static SignedPreKey generateSignedPreKey() {
return KeysHelper.signedECPreKey(1, IDENTITY_KEY_PAIR);
}
}

View File

@ -0,0 +1,35 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.signal.libsignal.protocol.ecc.Curve;
import org.whispersystems.textsecuregcm.entities.PreKey;
class SingleUseECPreKeyStoreTest extends SingleUsePreKeyStoreTest<PreKey> {
private SingleUseECPreKeyStore preKeyStore;
@RegisterExtension
static final DynamoDbExtension DYNAMO_DB_EXTENSION = new DynamoDbExtension(DynamoDbExtensionSchema.Tables.EC_KEYS);
@BeforeEach
void setUp() {
preKeyStore = new SingleUseECPreKeyStore(DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(),
DynamoDbExtensionSchema.Tables.EC_KEYS.tableName());
}
@Override
protected SingleUsePreKeyStore<PreKey> getPreKeyStore() {
return preKeyStore;
}
@Override
protected PreKey generatePreKey(final long keyId) {
return new PreKey(keyId, Curve.generateKeyPair().getPublicKey().serialize());
}
}

View File

@ -0,0 +1,39 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.signal.libsignal.protocol.ecc.Curve;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.whispersystems.textsecuregcm.entities.SignedPreKey;
import org.whispersystems.textsecuregcm.tests.util.KeysHelper;
class SingleUseKEMPreKeyStoreTest extends SingleUsePreKeyStoreTest<SignedPreKey> {
private SingleUseKEMPreKeyStore preKeyStore;
private static final ECKeyPair IDENTITY_KEY_PAIR = Curve.generateKeyPair();
@RegisterExtension
static final DynamoDbExtension DYNAMO_DB_EXTENSION = new DynamoDbExtension(DynamoDbExtensionSchema.Tables.PQ_KEYS);
@BeforeEach
void setUp() {
preKeyStore = new SingleUseKEMPreKeyStore(DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(),
DynamoDbExtensionSchema.Tables.PQ_KEYS.tableName());
}
@Override
protected SingleUsePreKeyStore<SignedPreKey> getPreKeyStore() {
return preKeyStore;
}
@Override
protected SignedPreKey generatePreKey(final long keyId) {
return KeysHelper.signedKEMPreKey(keyId, IDENTITY_KEY_PAIR);
}
}

View File

@ -0,0 +1,155 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import java.util.ArrayList;
import java.util.Base64;
import java.util.List;
import java.util.Optional;
import java.util.UUID;
import java.util.stream.Stream;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.whispersystems.textsecuregcm.entities.PreKey;
import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.services.dynamodb.model.AttributeValue;
abstract class SingleUsePreKeyStoreTest<K extends PreKey> {
private static final int KEY_COUNT = 100;
protected abstract SingleUsePreKeyStore<K> getPreKeyStore();
protected abstract K generatePreKey(final long keyId);
@Test
void storeTake() {
final SingleUsePreKeyStore<K> preKeyStore = getPreKeyStore();
final UUID accountIdentifier = UUID.randomUUID();
final long deviceId = 1;
assertEquals(Optional.empty(), preKeyStore.take(accountIdentifier, deviceId).join());
final List<K> preKeys = new ArrayList<>(KEY_COUNT);
for (int i = 0; i < KEY_COUNT; i++) {
preKeys.add(generatePreKey(i));
}
assertDoesNotThrow(() -> preKeyStore.store(accountIdentifier, deviceId, preKeys).join());
assertEquals(Optional.of(preKeys.get(0)), preKeyStore.take(accountIdentifier, deviceId).join());
assertEquals(Optional.of(preKeys.get(1)), preKeyStore.take(accountIdentifier, deviceId).join());
}
@Test
void getCount() {
final SingleUsePreKeyStore<K> preKeyStore = getPreKeyStore();
final UUID accountIdentifier = UUID.randomUUID();
final long deviceId = 1;
assertEquals(0, preKeyStore.getCount(accountIdentifier, deviceId).join());
final List<K> preKeys = new ArrayList<>(KEY_COUNT);
for (int i = 0; i < KEY_COUNT; i++) {
preKeys.add(generatePreKey(i));
}
preKeyStore.store(accountIdentifier, deviceId, preKeys).join();
assertEquals(KEY_COUNT, preKeyStore.getCount(accountIdentifier, deviceId).join());
}
@Test
void deleteSingleDevice() {
final SingleUsePreKeyStore<K> preKeyStore = getPreKeyStore();
final UUID accountIdentifier = UUID.randomUUID();
final long deviceId = 1;
assertEquals(0, preKeyStore.getCount(accountIdentifier, deviceId).join());
assertDoesNotThrow(() -> preKeyStore.delete(accountIdentifier, deviceId).join());
final List<K> preKeys = new ArrayList<>(KEY_COUNT);
for (int i = 0; i < KEY_COUNT; i++) {
preKeys.add(generatePreKey(i));
}
preKeyStore.store(accountIdentifier, deviceId, preKeys).join();
preKeyStore.store(accountIdentifier, deviceId + 1, preKeys).join();
assertDoesNotThrow(() -> preKeyStore.delete(accountIdentifier, deviceId).join());
assertEquals(0, preKeyStore.getCount(accountIdentifier, deviceId).join());
assertEquals(KEY_COUNT, preKeyStore.getCount(accountIdentifier, deviceId + 1).join());
}
@Test
void deleteAllDevices() {
final SingleUsePreKeyStore<K> preKeyStore = getPreKeyStore();
final UUID accountIdentifier = UUID.randomUUID();
final long deviceId = 1;
assertEquals(0, preKeyStore.getCount(accountIdentifier, deviceId).join());
assertDoesNotThrow(() -> preKeyStore.delete(accountIdentifier).join());
final List<K> preKeys = new ArrayList<>(KEY_COUNT);
for (int i = 0; i < KEY_COUNT; i++) {
preKeys.add(generatePreKey(i));
}
preKeyStore.store(accountIdentifier, deviceId, preKeys).join();
preKeyStore.store(accountIdentifier, deviceId + 1, preKeys).join();
assertDoesNotThrow(() -> preKeyStore.delete(accountIdentifier).join());
assertEquals(0, preKeyStore.getCount(accountIdentifier, deviceId).join());
assertEquals(0, preKeyStore.getCount(accountIdentifier, deviceId + 1).join());
}
@ParameterizedTest
@MethodSource
void extractByteArray(final AttributeValue attributeValue, final byte[] expectedByteArray) {
assertArrayEquals(expectedByteArray, getPreKeyStore().extractByteArray(attributeValue));
}
private static Stream<Arguments> extractByteArray() {
final byte[] key = Base64.getDecoder().decode("c+k+8zv8WaFdDjR9IOvCk6BcY5OI7rge/YUDkaDGyRc=");
return Stream.of(
Arguments.of(AttributeValue.fromB(SdkBytes.fromByteArray(key)), key),
Arguments.of(AttributeValue.fromS(Base64.getEncoder().encodeToString(key)), key),
Arguments.of(AttributeValue.fromS(Base64.getEncoder().withoutPadding().encodeToString(key)), key)
);
}
@ParameterizedTest
@MethodSource
void extractByteArrayIllegalArgument(final AttributeValue attributeValue) {
assertThrows(IllegalArgumentException.class, () -> getPreKeyStore().extractByteArray(attributeValue));
}
private static Stream<Arguments> extractByteArrayIllegalArgument() {
return Stream.of(
Arguments.of(AttributeValue.fromN("12")),
Arguments.of(AttributeValue.fromS("")),
Arguments.of(AttributeValue.fromS("Definitely not legitimate base64 👎"))
);
}
}

View File

@ -23,7 +23,6 @@ import io.dropwizard.auth.PolymorphicAuthValueFactoryProvider;
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
import io.dropwizard.testing.junit5.ResourceExtension;
import java.nio.charset.StandardCharsets;
import java.util.Base64;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@ -65,7 +64,7 @@ import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.Device.DeviceCapabilities;
import org.whispersystems.textsecuregcm.storage.Keys;
import org.whispersystems.textsecuregcm.storage.KeysManager;
import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.storage.StoredVerificationCodeManager;
import org.whispersystems.textsecuregcm.tests.util.AccountsHelper;
@ -82,7 +81,7 @@ class DeviceControllerTest {
public DumbVerificationDeviceController(StoredVerificationCodeManager pendingDevices,
AccountsManager accounts,
MessagesManager messages,
Keys keys,
KeysManager keys,
RateLimiters rateLimiters,
Map<String, Integer> deviceConfiguration) {
super(pendingDevices, accounts, messages, keys, rateLimiters, deviceConfiguration);
@ -97,7 +96,7 @@ class DeviceControllerTest {
private static StoredVerificationCodeManager pendingDevicesManager = mock(StoredVerificationCodeManager.class);
private static AccountsManager accountsManager = mock(AccountsManager.class);
private static MessagesManager messagesManager = mock(MessagesManager.class);
private static Keys keys = mock(Keys.class);
private static KeysManager keysManager = mock(KeysManager.class);
private static RateLimiters rateLimiters = mock(RateLimiters.class);
private static RateLimiter rateLimiter = mock(RateLimiter.class);
private static Account account = mock(Account.class);
@ -117,7 +116,7 @@ class DeviceControllerTest {
.addResource(new DumbVerificationDeviceController(pendingDevicesManager,
accountsManager,
messagesManager,
keys,
keysManager,
rateLimiters,
deviceConfiguration))
.build();
@ -161,7 +160,7 @@ class DeviceControllerTest {
pendingDevicesManager,
accountsManager,
messagesManager,
keys,
keysManager,
rateLimiters,
rateLimiter,
account,
@ -314,8 +313,8 @@ class DeviceControllerTest {
verify(pendingDevicesManager).remove(AuthHelper.VALID_NUMBER);
verify(messagesManager).clear(eq(AuthHelper.VALID_UUID), eq(42L));
verify(clientPresenceManager).disconnectPresence(AuthHelper.VALID_UUID, Device.MASTER_ID);
verify(keys).storePqLastResort(AuthHelper.VALID_UUID, Map.of(response.getDeviceId(), aciPqLastResortPreKey.get()));
verify(keys).storePqLastResort(AuthHelper.VALID_PNI, Map.of(response.getDeviceId(), pniPqLastResortPreKey.get()));
verify(keysManager).storePqLastResort(AuthHelper.VALID_UUID, Map.of(response.getDeviceId(), aciPqLastResortPreKey.get()));
verify(keysManager).storePqLastResort(AuthHelper.VALID_PNI, Map.of(response.getDeviceId(), pniPqLastResortPreKey.get()));
}
private static Stream<Arguments> linkDeviceAtomic() {
@ -822,7 +821,7 @@ class DeviceControllerTest {
verify(messagesManager, times(2)).clear(AuthHelper.VALID_UUID, deviceId);
verify(accountsManager, times(1)).update(eq(AuthHelper.VALID_ACCOUNT), any());
verify(AuthHelper.VALID_ACCOUNT).removeDevice(deviceId);
verify(keys).delete(AuthHelper.VALID_UUID, deviceId);
verify(keysManager).delete(AuthHelper.VALID_UUID, deviceId);
}
}

View File

@ -58,7 +58,7 @@ import org.whispersystems.textsecuregcm.mappers.ServerRejectedExceptionMapper;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.Keys;
import org.whispersystems.textsecuregcm.storage.KeysManager;
import org.whispersystems.textsecuregcm.tests.util.AccountsHelper;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.tests.util.KeysHelper;
@ -107,7 +107,7 @@ class KeysControllerTest {
private final SignedPreKey VALID_DEVICE_SIGNED_KEY = KeysHelper.signedECPreKey(89898, IDENTITY_KEY_PAIR);
private final SignedPreKey VALID_DEVICE_PNI_SIGNED_KEY = KeysHelper.signedECPreKey(7777, PNI_IDENTITY_KEY_PAIR);
private final static Keys KEYS = mock(Keys.class );
private final static KeysManager KEYS = mock(KeysManager.class );
private final static AccountsManager accounts = mock(AccountsManager.class );
private final static Account existsAccount = mock(Account.class );