Skip to content
Snippets Groups Projects
Commit e8cebad2 authored by Jon Chambers's avatar Jon Chambers Committed by Jon Chambers
Browse files

Avoid modifying original `Account` instances when constructing JSON for updates

parent 6441d583
No related branches found
No related tags found
No related merge requests found
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import org.whispersystems.textsecuregcm.util.SystemMapper;
import java.io.IOException;
class AccountUtil {
static Account cloneAccountAsNotStale(final Account account) {
try {
return SystemMapper.jsonMapper().readValue(
SystemMapper.jsonMapper().writeValueAsBytes(account), Account.class);
} catch (final IOException e) {
// this should really, truly, never happen
throw new IllegalArgumentException(e);
}
}
}
...@@ -434,29 +434,54 @@ public class Accounts extends AbstractDynamoDbStore { ...@@ -434,29 +434,54 @@ public class Accounts extends AbstractDynamoDbStore {
*/ */
public CompletableFuture<Void> confirmUsernameHash(final Account account, final byte[] usernameHash, @Nullable final byte[] encryptedUsername) { public CompletableFuture<Void> confirmUsernameHash(final Account account, final byte[] usernameHash, @Nullable final byte[] encryptedUsername) {
final Timer.Sample sample = Timer.start(); final Timer.Sample sample = Timer.start();
final UUID newLinkHandle = UUID.randomUUID();
final Optional<byte[]> maybeOriginalUsernameHash = account.getUsernameHash(); final TransactWriteItemsRequest request;
final Optional<byte[]> maybeOriginalReservationHash = account.getReservedUsernameHash();
final Optional<UUID> maybeOriginalUsernameLinkHandle = Optional.ofNullable(account.getUsernameLinkHandle());
final Optional<byte[]> maybeOriginalEncryptedUsername = account.getEncryptedUsername();
final UUID newLinkHandle = UUID.randomUUID(); try {
final Account updatedAccount = AccountUtil.cloneAccountAsNotStale(account);
updatedAccount.setUsernameHash(usernameHash);
updatedAccount.setReservedUsernameHash(null);
updatedAccount.setUsernameLinkDetails(encryptedUsername == null ? null : newLinkHandle, encryptedUsername);
request = buildConfirmUsernameHashRequest(updatedAccount, account.getUsernameHash());
} catch (final JsonProcessingException e) {
throw new IllegalArgumentException(e);
}
return asyncClient.transactWriteItems(request)
.thenRun(() -> {
account.setUsernameHash(usernameHash); account.setUsernameHash(usernameHash);
account.setReservedUsernameHash(null); account.setReservedUsernameHash(null);
account.setUsernameLinkDetails(encryptedUsername == null ? null : newLinkHandle, encryptedUsername); account.setUsernameLinkDetails(encryptedUsername == null ? null : newLinkHandle, encryptedUsername);
final TransactWriteItemsRequest request; account.setVersion(account.getVersion() + 1);
})
.exceptionally(throwable -> {
if (ExceptionUtils.unwrap(throwable) instanceof TransactionCanceledException transactionCanceledException) {
if (transactionCanceledException.cancellationReasons().stream().map(CancellationReason::code).anyMatch(CONDITIONAL_CHECK_FAILED::equals)) {
throw new ContestedOptimisticLockException();
}
}
throw ExceptionUtils.wrap(throwable);
})
.whenComplete((ignored, throwable) -> sample.stop(SET_USERNAME_TIMER));
}
private TransactWriteItemsRequest buildConfirmUsernameHashRequest(final Account updatedAccount, final Optional<byte[]> maybeOriginalUsernameHash)
throws JsonProcessingException {
try {
final List<TransactWriteItem> writeItems = new ArrayList<>(); final List<TransactWriteItem> writeItems = new ArrayList<>();
final byte[] usernameHash = updatedAccount.getUsernameHash()
.orElseThrow(() -> new IllegalArgumentException("Account must have a username hash"));
// add the username hash to the constraint table, wiping out the ttl if we had already reserved the hash // add the username hash to the constraint table, wiping out the ttl if we had already reserved the hash
writeItems.add(TransactWriteItem.builder() writeItems.add(TransactWriteItem.builder()
.put(Put.builder() .put(Put.builder()
.tableName(usernamesConstraintTableName) .tableName(usernamesConstraintTableName)
.item(Map.of( .item(Map.of(
KEY_ACCOUNT_UUID, AttributeValues.fromUUID(account.getUuid()), KEY_ACCOUNT_UUID, AttributeValues.fromUUID(updatedAccount.getUuid()),
ATTR_USERNAME_HASH, AttributeValues.fromByteArray(usernameHash), ATTR_USERNAME_HASH, AttributeValues.fromByteArray(usernameHash),
ATTR_CONFIRMED, AttributeValues.fromBool(true))) ATTR_CONFIRMED, AttributeValues.fromBool(true)))
// it's not in the constraint table OR it's expired OR it was reserved by us // it's not in the constraint table OR it's expired OR it was reserved by us
...@@ -464,7 +489,7 @@ public class Accounts extends AbstractDynamoDbStore { ...@@ -464,7 +489,7 @@ public class Accounts extends AbstractDynamoDbStore {
.expressionAttributeNames(Map.of("#username_hash", ATTR_USERNAME_HASH, "#ttl", ATTR_TTL, "#aci", KEY_ACCOUNT_UUID, "#confirmed", ATTR_CONFIRMED)) .expressionAttributeNames(Map.of("#username_hash", ATTR_USERNAME_HASH, "#ttl", ATTR_TTL, "#aci", KEY_ACCOUNT_UUID, "#confirmed", ATTR_CONFIRMED))
.expressionAttributeValues(Map.of( .expressionAttributeValues(Map.of(
":now", AttributeValues.fromLong(clock.instant().getEpochSecond()), ":now", AttributeValues.fromLong(clock.instant().getEpochSecond()),
":aci", AttributeValues.fromUUID(account.getUuid()), ":aci", AttributeValues.fromUUID(updatedAccount.getUuid()),
":confirmed", AttributeValues.fromBool(false))) ":confirmed", AttributeValues.fromBool(false)))
.returnValuesOnConditionCheckFailure(ReturnValuesOnConditionCheckFailure.ALL_OLD) .returnValuesOnConditionCheckFailure(ReturnValuesOnConditionCheckFailure.ALL_OLD)
.build()) .build())
...@@ -472,13 +497,13 @@ public class Accounts extends AbstractDynamoDbStore { ...@@ -472,13 +497,13 @@ public class Accounts extends AbstractDynamoDbStore {
final StringBuilder updateExpr = new StringBuilder("SET #data = :data, #username_hash = :username_hash"); final StringBuilder updateExpr = new StringBuilder("SET #data = :data, #username_hash = :username_hash");
final Map<String, AttributeValue> expressionAttributeValues = new HashMap<>(Map.of( final Map<String, AttributeValue> expressionAttributeValues = new HashMap<>(Map.of(
":data", accountDataAttributeValue(account), ":data", accountDataAttributeValue(updatedAccount),
":username_hash", AttributeValues.fromByteArray(usernameHash), ":username_hash", AttributeValues.fromByteArray(usernameHash),
":version", AttributeValues.fromInt(account.getVersion()), ":version", AttributeValues.fromInt(updatedAccount.getVersion()),
":version_increment", AttributeValues.fromInt(1))); ":version_increment", AttributeValues.fromInt(1)));
if (account.getUsernameLinkHandle() != null) { if (updatedAccount.getUsernameLinkHandle() != null) {
updateExpr.append(", #ul = :ul"); updateExpr.append(", #ul = :ul");
expressionAttributeValues.put(":ul", AttributeValues.fromUUID(account.getUsernameLinkHandle())); expressionAttributeValues.put(":ul", AttributeValues.fromUUID(updatedAccount.getUsernameLinkHandle()));
} else { } else {
updateExpr.append(" REMOVE #ul"); updateExpr.append(" REMOVE #ul");
} }
...@@ -488,7 +513,7 @@ public class Accounts extends AbstractDynamoDbStore { ...@@ -488,7 +513,7 @@ public class Accounts extends AbstractDynamoDbStore {
TransactWriteItem.builder() TransactWriteItem.builder()
.update(Update.builder() .update(Update.builder()
.tableName(accountsTableName) .tableName(accountsTableName)
.key(Map.of(KEY_ACCOUNT_UUID, AttributeValues.fromUUID(account.getUuid()))) .key(Map.of(KEY_ACCOUNT_UUID, AttributeValues.fromUUID(updatedAccount.getUuid())))
.updateExpression(updateExpr.toString()) .updateExpression(updateExpr.toString())
.conditionExpression("#version = :version") .conditionExpression("#version = :version")
.expressionAttributeNames(Map.of("#data", ATTR_ACCOUNT_DATA, .expressionAttributeNames(Map.of("#data", ATTR_ACCOUNT_DATA,
...@@ -502,57 +527,60 @@ public class Accounts extends AbstractDynamoDbStore { ...@@ -502,57 +527,60 @@ public class Accounts extends AbstractDynamoDbStore {
maybeOriginalUsernameHash.ifPresent(originalUsernameHash -> writeItems.add( maybeOriginalUsernameHash.ifPresent(originalUsernameHash -> writeItems.add(
buildDelete(usernamesConstraintTableName, ATTR_USERNAME_HASH, originalUsernameHash))); buildDelete(usernamesConstraintTableName, ATTR_USERNAME_HASH, originalUsernameHash)));
request = TransactWriteItemsRequest.builder() return TransactWriteItemsRequest.builder()
.transactItems(writeItems) .transactItems(writeItems)
.build(); .build();
}
public CompletableFuture<Void> clearUsernameHash(final Account account) {
return account.getUsernameHash().map(usernameHash -> {
final Timer.Sample sample = Timer.start();
@Nullable final UUID originalLinkHandle = account.getUsernameLinkHandle();
@Nullable final byte[] originalEncryptedUsername = account.getEncryptedUsername().orElse(null);
final TransactWriteItemsRequest request;
try {
final Account updatedAccount = AccountUtil.cloneAccountAsNotStale(account);
updatedAccount.setUsernameHash(null);
updatedAccount.setUsernameLinkDetails(null, null);
request = buildClearUsernameHashRequest(updatedAccount, usernameHash);
} catch (final JsonProcessingException e) { } catch (final JsonProcessingException e) {
throw new IllegalArgumentException(e); throw new IllegalArgumentException(e);
} finally {
account.setUsernameLinkDetails(maybeOriginalUsernameLinkHandle.orElse(null), maybeOriginalEncryptedUsername.orElse(null));
account.setReservedUsernameHash(maybeOriginalReservationHash.orElse(null));
account.setUsernameHash(maybeOriginalUsernameHash.orElse(null));
} }
return asyncClient.transactWriteItems(request) return asyncClient.transactWriteItems(request)
.thenRun(() -> { .thenAccept(ignored -> {
account.setUsernameHash(usernameHash); account.setUsernameHash(null);
account.setReservedUsernameHash(null); account.setUsernameLinkDetails(null, null);
account.setUsernameLinkDetails(encryptedUsername == null ? null : newLinkHandle, encryptedUsername);
account.setVersion(account.getVersion() + 1); account.setVersion(account.getVersion() + 1);
}) })
.exceptionally(throwable -> { .exceptionally(throwable -> {
if (ExceptionUtils.unwrap(throwable) instanceof TransactionCanceledException transactionCanceledException) { if (ExceptionUtils.unwrap(throwable) instanceof TransactionCanceledException transactionCanceledException) {
if (transactionCanceledException.cancellationReasons().stream().map(CancellationReason::code).anyMatch(CONDITIONAL_CHECK_FAILED::equals)) { if (conditionalCheckFailed(transactionCanceledException.cancellationReasons().get(0))) {
throw new ContestedOptimisticLockException(); throw new ContestedOptimisticLockException();
} }
} }
throw ExceptionUtils.wrap(throwable); throw ExceptionUtils.wrap(throwable);
}) })
.whenComplete((ignored, throwable) -> sample.stop(SET_USERNAME_TIMER)); .whenComplete((ignored, throwable) -> sample.stop(CLEAR_USERNAME_HASH_TIMER));
}).orElseGet(() -> CompletableFuture.completedFuture(null));
} }
public CompletableFuture<Void> clearUsernameHash(final Account account) { private TransactWriteItemsRequest buildClearUsernameHashRequest(final Account updatedAccount, final byte[] originalUsernameHash)
return account.getUsernameHash().map(usernameHash -> { throws JsonProcessingException {
final Timer.Sample sample = Timer.start();
@Nullable final UUID originalLinkHandle = account.getUsernameLinkHandle();
@Nullable final byte[] originalEncryptedUsername = account.getEncryptedUsername().orElse(null);
final TransactWriteItemsRequest request;
try {
final List<TransactWriteItem> writeItems = new ArrayList<>(); final List<TransactWriteItem> writeItems = new ArrayList<>();
account.setUsernameHash(null);
account.setUsernameLinkDetails(null, null);
writeItems.add( writeItems.add(
TransactWriteItem.builder() TransactWriteItem.builder()
.update(Update.builder() .update(Update.builder()
.tableName(accountsTableName) .tableName(accountsTableName)
.key(Map.of(KEY_ACCOUNT_UUID, AttributeValues.fromUUID(account.getUuid()))) .key(Map.of(KEY_ACCOUNT_UUID, AttributeValues.fromUUID(updatedAccount.getUuid())))
.updateExpression("SET #data = :data REMOVE #username_hash, #username_link ADD #version :version_increment") .updateExpression("SET #data = :data REMOVE #username_hash, #username_link ADD #version :version_increment")
.conditionExpression("#version = :version") .conditionExpression("#version = :version")
.expressionAttributeNames(Map.of("#data", ATTR_ACCOUNT_DATA, .expressionAttributeNames(Map.of("#data", ATTR_ACCOUNT_DATA,
...@@ -560,42 +588,17 @@ public class Accounts extends AbstractDynamoDbStore { ...@@ -560,42 +588,17 @@ public class Accounts extends AbstractDynamoDbStore {
"#username_link", ATTR_USERNAME_LINK_UUID, "#username_link", ATTR_USERNAME_LINK_UUID,
"#version", ATTR_VERSION)) "#version", ATTR_VERSION))
.expressionAttributeValues(Map.of( .expressionAttributeValues(Map.of(
":data", accountDataAttributeValue(account), ":data", accountDataAttributeValue(updatedAccount),
":version", AttributeValues.fromInt(account.getVersion()), ":version", AttributeValues.fromInt(updatedAccount.getVersion()),
":version_increment", AttributeValues.fromInt(1))) ":version_increment", AttributeValues.fromInt(1)))
.build()) .build())
.build()); .build());
writeItems.add(buildDelete(usernamesConstraintTableName, ATTR_USERNAME_HASH, usernameHash)); writeItems.add(buildDelete(usernamesConstraintTableName, ATTR_USERNAME_HASH, originalUsernameHash));
request = TransactWriteItemsRequest.builder() return TransactWriteItemsRequest.builder()
.transactItems(writeItems) .transactItems(writeItems)
.build(); .build();
} catch (final JsonProcessingException e) {
throw new IllegalArgumentException(e);
} finally {
account.setUsernameHash(usernameHash);
account.setUsernameLinkDetails(originalLinkHandle, originalEncryptedUsername);
}
return asyncClient.transactWriteItems(request)
.thenAccept(ignored -> {
account.setUsernameHash(null);
account.setUsernameLinkDetails(null, null);
account.setVersion(account.getVersion() + 1);
})
.exceptionally(throwable -> {
if (ExceptionUtils.unwrap(throwable) instanceof TransactionCanceledException transactionCanceledException) {
if (conditionalCheckFailed(transactionCanceledException.cancellationReasons().get(0))) {
throw new ContestedOptimisticLockException();
}
}
throw ExceptionUtils.wrap(throwable);
})
.whenComplete((ignored, throwable) -> sample.stop(CLEAR_USERNAME_HASH_TIMER));
}).orElseGet(() -> CompletableFuture.completedFuture(null));
} }
@Nonnull @Nonnull
......
...@@ -660,7 +660,7 @@ public class AccountsManager { ...@@ -660,7 +660,7 @@ public class AccountsManager {
final Supplier<Account> retriever, final Supplier<Account> retriever,
final AccountChangeValidator changeValidator) throws UsernameHashNotAvailableException { final AccountChangeValidator changeValidator) throws UsernameHashNotAvailableException {
Account originalAccount = cloneAccountAsNotStale(account); Account originalAccount = AccountUtil.cloneAccountAsNotStale(account);
if (!updater.apply(account)) { if (!updater.apply(account)) {
return account; return account;
...@@ -674,7 +674,7 @@ public class AccountsManager { ...@@ -674,7 +674,7 @@ public class AccountsManager {
try { try {
persister.persistAccount(account); persister.persistAccount(account);
final Account updatedAccount = cloneAccountAsNotStale(account); final Account updatedAccount = AccountUtil.cloneAccountAsNotStale(account);
account.markStale(); account.markStale();
changeValidator.validateChange(originalAccount, updatedAccount); changeValidator.validateChange(originalAccount, updatedAccount);
...@@ -684,7 +684,7 @@ public class AccountsManager { ...@@ -684,7 +684,7 @@ public class AccountsManager {
tries++; tries++;
account = retriever.get(); account = retriever.get();
originalAccount = cloneAccountAsNotStale(account); originalAccount = AccountUtil.cloneAccountAsNotStale(account);
if (!updater.apply(account)) { if (!updater.apply(account)) {
return account; return account;
...@@ -702,7 +702,7 @@ public class AccountsManager { ...@@ -702,7 +702,7 @@ public class AccountsManager {
final AccountChangeValidator changeValidator, final AccountChangeValidator changeValidator,
final int remainingTries) { final int remainingTries) {
final Account originalAccount = cloneAccountAsNotStale(account); final Account originalAccount = AccountUtil.cloneAccountAsNotStale(account);
if (!updater.apply(account)) { if (!updater.apply(account)) {
return CompletableFuture.completedFuture(account); return CompletableFuture.completedFuture(account);
...@@ -711,7 +711,7 @@ public class AccountsManager { ...@@ -711,7 +711,7 @@ public class AccountsManager {
if (remainingTries > 0) { if (remainingTries > 0) {
return persister.apply(account) return persister.apply(account)
.thenApply(ignored -> { .thenApply(ignored -> {
final Account updatedAccount = cloneAccountAsNotStale(account); final Account updatedAccount = AccountUtil.cloneAccountAsNotStale(account);
account.markStale(); account.markStale();
changeValidator.validateChange(originalAccount, updatedAccount); changeValidator.validateChange(originalAccount, updatedAccount);
...@@ -731,16 +731,6 @@ public class AccountsManager { ...@@ -731,16 +731,6 @@ public class AccountsManager {
return CompletableFuture.failedFuture(new OptimisticLockRetryLimitExceededException()); return CompletableFuture.failedFuture(new OptimisticLockRetryLimitExceededException());
} }
private static Account cloneAccountAsNotStale(final Account account) {
try {
return SystemMapper.jsonMapper().readValue(
SystemMapper.jsonMapper().writeValueAsBytes(account), Account.class);
} catch (final IOException e) {
// this should really, truly, never happen
throw new IllegalArgumentException(e);
}
}
public Account updateDevice(Account account, long deviceId, Consumer<Device> deviceUpdater) { public Account updateDevice(Account account, long deviceId, Consumer<Device> deviceUpdater) {
return update(account, a -> { return update(account, a -> {
a.getDevice(deviceId).ifPresent(deviceUpdater); a.getDevice(deviceId).ifPresent(deviceUpdater);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please to comment