From e8cebad27ee5872701bf7a896fc1efda93b9847d Mon Sep 17 00:00:00 2001
From: Jon Chambers <jon@signal.org>
Date: Thu, 19 Oct 2023 18:52:45 -0400
Subject: [PATCH] Avoid modifying original `Account` instances when
 constructing JSON for updates

---
 .../textsecuregcm/storage/AccountUtil.java    |  22 ++
 .../textsecuregcm/storage/Accounts.java       | 201 +++++++++---------
 .../storage/AccountsManager.java              |  20 +-
 3 files changed, 129 insertions(+), 114 deletions(-)
 create mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountUtil.java

diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountUtil.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountUtil.java
new file mode 100644
index 000000000..cffdeea13
--- /dev/null
+++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountUtil.java
@@ -0,0 +1,22 @@
+/*
+ * Copyright 2023 Signal Messenger, LLC
+ * SPDX-License-Identifier: AGPL-3.0-only
+ */
+
+package org.whispersystems.textsecuregcm.storage;
+
+import org.whispersystems.textsecuregcm.util.SystemMapper;
+import java.io.IOException;
+
+class AccountUtil {
+
+  static Account cloneAccountAsNotStale(final Account account) {
+    try {
+      return SystemMapper.jsonMapper().readValue(
+          SystemMapper.jsonMapper().writeValueAsBytes(account), Account.class);
+    } catch (final IOException e) {
+      // this should really, truly, never happen
+      throw new IllegalArgumentException(e);
+    }
+  }
+}
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/Accounts.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/Accounts.java
index 2b67c1972..e8a2670df 100644
--- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/Accounts.java
+++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/Accounts.java
@@ -434,83 +434,19 @@ public class Accounts extends AbstractDynamoDbStore {
    */
   public CompletableFuture<Void> confirmUsernameHash(final Account account, final byte[] usernameHash, @Nullable final byte[] encryptedUsername) {
     final Timer.Sample sample = Timer.start();
-
-    final Optional<byte[]> maybeOriginalUsernameHash = account.getUsernameHash();
-    final Optional<byte[]> maybeOriginalReservationHash = account.getReservedUsernameHash();
-    final Optional<UUID> maybeOriginalUsernameLinkHandle = Optional.ofNullable(account.getUsernameLinkHandle());
-    final Optional<byte[]> maybeOriginalEncryptedUsername = account.getEncryptedUsername();
-
     final UUID newLinkHandle = UUID.randomUUID();
 
-    account.setUsernameHash(usernameHash);
-    account.setReservedUsernameHash(null);
-    account.setUsernameLinkDetails(encryptedUsername == null ? null : newLinkHandle, encryptedUsername);
-
     final TransactWriteItemsRequest request;
 
     try {
-      final List<TransactWriteItem> writeItems = new ArrayList<>();
-
-      // add the username hash to the constraint table, wiping out the ttl if we had already reserved the hash
-      writeItems.add(TransactWriteItem.builder()
-          .put(Put.builder()
-              .tableName(usernamesConstraintTableName)
-              .item(Map.of(
-                  KEY_ACCOUNT_UUID, AttributeValues.fromUUID(account.getUuid()),
-                  ATTR_USERNAME_HASH, AttributeValues.fromByteArray(usernameHash),
-                  ATTR_CONFIRMED, AttributeValues.fromBool(true)))
-              // it's not in the constraint table OR it's expired OR it was reserved by us
-              .conditionExpression("attribute_not_exists(#username_hash) OR #ttl < :now OR (#aci = :aci AND #confirmed = :confirmed)")
-              .expressionAttributeNames(Map.of("#username_hash", ATTR_USERNAME_HASH, "#ttl", ATTR_TTL, "#aci", KEY_ACCOUNT_UUID, "#confirmed", ATTR_CONFIRMED))
-              .expressionAttributeValues(Map.of(
-                  ":now", AttributeValues.fromLong(clock.instant().getEpochSecond()),
-                  ":aci", AttributeValues.fromUUID(account.getUuid()),
-                  ":confirmed", AttributeValues.fromBool(false)))
-              .returnValuesOnConditionCheckFailure(ReturnValuesOnConditionCheckFailure.ALL_OLD)
-              .build())
-          .build());
-
-      final StringBuilder updateExpr = new StringBuilder("SET #data = :data, #username_hash = :username_hash");
-      final Map<String, AttributeValue> expressionAttributeValues = new HashMap<>(Map.of(
-          ":data", accountDataAttributeValue(account),
-          ":username_hash", AttributeValues.fromByteArray(usernameHash),
-          ":version", AttributeValues.fromInt(account.getVersion()),
-          ":version_increment", AttributeValues.fromInt(1)));
-      if (account.getUsernameLinkHandle() != null) {
-        updateExpr.append(", #ul = :ul");
-        expressionAttributeValues.put(":ul", AttributeValues.fromUUID(account.getUsernameLinkHandle()));
-      } else {
-        updateExpr.append(" REMOVE #ul");
-      }
-      updateExpr.append(" ADD #version :version_increment");
-
-      writeItems.add(
-          TransactWriteItem.builder()
-              .update(Update.builder()
-                  .tableName(accountsTableName)
-                  .key(Map.of(KEY_ACCOUNT_UUID, AttributeValues.fromUUID(account.getUuid())))
-                  .updateExpression(updateExpr.toString())
-                  .conditionExpression("#version = :version")
-                  .expressionAttributeNames(Map.of("#data", ATTR_ACCOUNT_DATA,
-                      "#username_hash", ATTR_USERNAME_HASH,
-                      "#ul", ATTR_USERNAME_LINK_UUID,
-                      "#version", ATTR_VERSION))
-                  .expressionAttributeValues(expressionAttributeValues)
-                  .build())
-              .build());
+      final Account updatedAccount = AccountUtil.cloneAccountAsNotStale(account);
+      updatedAccount.setUsernameHash(usernameHash);
+      updatedAccount.setReservedUsernameHash(null);
+      updatedAccount.setUsernameLinkDetails(encryptedUsername == null ? null : newLinkHandle, encryptedUsername);
 
-      maybeOriginalUsernameHash.ifPresent(originalUsernameHash -> writeItems.add(
-          buildDelete(usernamesConstraintTableName, ATTR_USERNAME_HASH, originalUsernameHash)));
-
-      request = TransactWriteItemsRequest.builder()
-          .transactItems(writeItems)
-          .build();
+      request = buildConfirmUsernameHashRequest(updatedAccount, account.getUsernameHash());
     } catch (final JsonProcessingException e) {
       throw new IllegalArgumentException(e);
-    } finally {
-      account.setUsernameLinkDetails(maybeOriginalUsernameLinkHandle.orElse(null), maybeOriginalEncryptedUsername.orElse(null));
-      account.setReservedUsernameHash(maybeOriginalReservationHash.orElse(null));
-      account.setUsernameHash(maybeOriginalUsernameHash.orElse(null));
     }
 
     return asyncClient.transactWriteItems(request)
@@ -533,6 +469,69 @@ public class Accounts extends AbstractDynamoDbStore {
         .whenComplete((ignored, throwable) -> sample.stop(SET_USERNAME_TIMER));
   }
 
+  private TransactWriteItemsRequest buildConfirmUsernameHashRequest(final Account updatedAccount, final Optional<byte[]> maybeOriginalUsernameHash)
+      throws JsonProcessingException {
+
+    final List<TransactWriteItem> writeItems = new ArrayList<>();
+    final byte[] usernameHash = updatedAccount.getUsernameHash()
+        .orElseThrow(() -> new IllegalArgumentException("Account must have a username hash"));
+
+    // add the username hash to the constraint table, wiping out the ttl if we had already reserved the hash
+    writeItems.add(TransactWriteItem.builder()
+        .put(Put.builder()
+            .tableName(usernamesConstraintTableName)
+            .item(Map.of(
+                KEY_ACCOUNT_UUID, AttributeValues.fromUUID(updatedAccount.getUuid()),
+                ATTR_USERNAME_HASH, AttributeValues.fromByteArray(usernameHash),
+                ATTR_CONFIRMED, AttributeValues.fromBool(true)))
+            // it's not in the constraint table OR it's expired OR it was reserved by us
+            .conditionExpression("attribute_not_exists(#username_hash) OR #ttl < :now OR (#aci = :aci AND #confirmed = :confirmed)")
+            .expressionAttributeNames(Map.of("#username_hash", ATTR_USERNAME_HASH, "#ttl", ATTR_TTL, "#aci", KEY_ACCOUNT_UUID, "#confirmed", ATTR_CONFIRMED))
+            .expressionAttributeValues(Map.of(
+                ":now", AttributeValues.fromLong(clock.instant().getEpochSecond()),
+                ":aci", AttributeValues.fromUUID(updatedAccount.getUuid()),
+                ":confirmed", AttributeValues.fromBool(false)))
+            .returnValuesOnConditionCheckFailure(ReturnValuesOnConditionCheckFailure.ALL_OLD)
+            .build())
+        .build());
+
+    final StringBuilder updateExpr = new StringBuilder("SET #data = :data, #username_hash = :username_hash");
+    final Map<String, AttributeValue> expressionAttributeValues = new HashMap<>(Map.of(
+        ":data", accountDataAttributeValue(updatedAccount),
+        ":username_hash", AttributeValues.fromByteArray(usernameHash),
+        ":version", AttributeValues.fromInt(updatedAccount.getVersion()),
+        ":version_increment", AttributeValues.fromInt(1)));
+    if (updatedAccount.getUsernameLinkHandle() != null) {
+      updateExpr.append(", #ul = :ul");
+      expressionAttributeValues.put(":ul", AttributeValues.fromUUID(updatedAccount.getUsernameLinkHandle()));
+    } else {
+      updateExpr.append(" REMOVE #ul");
+    }
+    updateExpr.append(" ADD #version :version_increment");
+
+    writeItems.add(
+        TransactWriteItem.builder()
+            .update(Update.builder()
+                .tableName(accountsTableName)
+                .key(Map.of(KEY_ACCOUNT_UUID, AttributeValues.fromUUID(updatedAccount.getUuid())))
+                .updateExpression(updateExpr.toString())
+                .conditionExpression("#version = :version")
+                .expressionAttributeNames(Map.of("#data", ATTR_ACCOUNT_DATA,
+                    "#username_hash", ATTR_USERNAME_HASH,
+                    "#ul", ATTR_USERNAME_LINK_UUID,
+                    "#version", ATTR_VERSION))
+                .expressionAttributeValues(expressionAttributeValues)
+                .build())
+            .build());
+
+    maybeOriginalUsernameHash.ifPresent(originalUsernameHash -> writeItems.add(
+        buildDelete(usernamesConstraintTableName, ATTR_USERNAME_HASH, originalUsernameHash)));
+
+    return TransactWriteItemsRequest.builder()
+        .transactItems(writeItems)
+        .build();
+  }
+
   public CompletableFuture<Void> clearUsernameHash(final Account account) {
     return account.getUsernameHash().map(usernameHash -> {
       final Timer.Sample sample = Timer.start();
@@ -543,39 +542,13 @@ public class Accounts extends AbstractDynamoDbStore {
       final TransactWriteItemsRequest request;
 
       try {
-        final List<TransactWriteItem> writeItems = new ArrayList<>();
-
-        account.setUsernameHash(null);
-        account.setUsernameLinkDetails(null, null);
+        final Account updatedAccount = AccountUtil.cloneAccountAsNotStale(account);
+        updatedAccount.setUsernameHash(null);
+        updatedAccount.setUsernameLinkDetails(null, null);
 
-        writeItems.add(
-            TransactWriteItem.builder()
-                .update(Update.builder()
-                    .tableName(accountsTableName)
-                    .key(Map.of(KEY_ACCOUNT_UUID, AttributeValues.fromUUID(account.getUuid())))
-                    .updateExpression("SET #data = :data REMOVE #username_hash, #username_link ADD #version :version_increment")
-                    .conditionExpression("#version = :version")
-                    .expressionAttributeNames(Map.of("#data", ATTR_ACCOUNT_DATA,
-                        "#username_hash", ATTR_USERNAME_HASH,
-                        "#username_link", ATTR_USERNAME_LINK_UUID,
-                        "#version", ATTR_VERSION))
-                    .expressionAttributeValues(Map.of(
-                        ":data", accountDataAttributeValue(account),
-                        ":version", AttributeValues.fromInt(account.getVersion()),
-                        ":version_increment", AttributeValues.fromInt(1)))
-                    .build())
-                .build());
-
-        writeItems.add(buildDelete(usernamesConstraintTableName, ATTR_USERNAME_HASH, usernameHash));
-
-        request = TransactWriteItemsRequest.builder()
-            .transactItems(writeItems)
-            .build();
+        request = buildClearUsernameHashRequest(updatedAccount, usernameHash);
       } catch (final JsonProcessingException e) {
         throw new IllegalArgumentException(e);
-      } finally {
-        account.setUsernameHash(usernameHash);
-        account.setUsernameLinkDetails(originalLinkHandle, originalEncryptedUsername);
       }
 
       return asyncClient.transactWriteItems(request)
@@ -598,6 +571,36 @@ public class Accounts extends AbstractDynamoDbStore {
     }).orElseGet(() -> CompletableFuture.completedFuture(null));
   }
 
+  private TransactWriteItemsRequest buildClearUsernameHashRequest(final Account updatedAccount, final byte[] originalUsernameHash)
+      throws JsonProcessingException {
+
+    final List<TransactWriteItem> writeItems = new ArrayList<>();
+
+    writeItems.add(
+        TransactWriteItem.builder()
+            .update(Update.builder()
+                .tableName(accountsTableName)
+                .key(Map.of(KEY_ACCOUNT_UUID, AttributeValues.fromUUID(updatedAccount.getUuid())))
+                .updateExpression("SET #data = :data REMOVE #username_hash, #username_link ADD #version :version_increment")
+                .conditionExpression("#version = :version")
+                .expressionAttributeNames(Map.of("#data", ATTR_ACCOUNT_DATA,
+                    "#username_hash", ATTR_USERNAME_HASH,
+                    "#username_link", ATTR_USERNAME_LINK_UUID,
+                    "#version", ATTR_VERSION))
+                .expressionAttributeValues(Map.of(
+                    ":data", accountDataAttributeValue(updatedAccount),
+                    ":version", AttributeValues.fromInt(updatedAccount.getVersion()),
+                    ":version_increment", AttributeValues.fromInt(1)))
+                .build())
+            .build());
+
+    writeItems.add(buildDelete(usernamesConstraintTableName, ATTR_USERNAME_HASH, originalUsernameHash));
+
+    return TransactWriteItemsRequest.builder()
+        .transactItems(writeItems)
+        .build();
+  }
+
   @Nonnull
   public CompletionStage<Void> updateAsync(final Account account) {
     return AsyncTimerUtil.record(UPDATE_TIMER, () -> {
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java
index 0b2559589..1536c74b8 100644
--- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java
+++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java
@@ -660,7 +660,7 @@ public class AccountsManager {
       final Supplier<Account> retriever,
       final AccountChangeValidator changeValidator) throws UsernameHashNotAvailableException {
 
-    Account originalAccount = cloneAccountAsNotStale(account);
+    Account originalAccount = AccountUtil.cloneAccountAsNotStale(account);
 
     if (!updater.apply(account)) {
       return account;
@@ -674,7 +674,7 @@ public class AccountsManager {
       try {
         persister.persistAccount(account);
 
-        final Account updatedAccount = cloneAccountAsNotStale(account);
+        final Account updatedAccount = AccountUtil.cloneAccountAsNotStale(account);
         account.markStale();
 
         changeValidator.validateChange(originalAccount, updatedAccount);
@@ -684,7 +684,7 @@ public class AccountsManager {
         tries++;
 
         account = retriever.get();
-        originalAccount = cloneAccountAsNotStale(account);
+        originalAccount = AccountUtil.cloneAccountAsNotStale(account);
 
         if (!updater.apply(account)) {
           return account;
@@ -702,7 +702,7 @@ public class AccountsManager {
       final AccountChangeValidator changeValidator,
       final int remainingTries) {
 
-    final Account originalAccount = cloneAccountAsNotStale(account);
+    final Account originalAccount = AccountUtil.cloneAccountAsNotStale(account);
 
     if (!updater.apply(account)) {
       return CompletableFuture.completedFuture(account);
@@ -711,7 +711,7 @@ public class AccountsManager {
     if (remainingTries > 0) {
       return persister.apply(account)
           .thenApply(ignored -> {
-            final Account updatedAccount = cloneAccountAsNotStale(account);
+            final Account updatedAccount = AccountUtil.cloneAccountAsNotStale(account);
             account.markStale();
 
             changeValidator.validateChange(originalAccount, updatedAccount);
@@ -731,16 +731,6 @@ public class AccountsManager {
     return CompletableFuture.failedFuture(new OptimisticLockRetryLimitExceededException());
   }
 
-  private static Account cloneAccountAsNotStale(final Account account) {
-    try {
-      return SystemMapper.jsonMapper().readValue(
-          SystemMapper.jsonMapper().writeValueAsBytes(account), Account.class);
-    } catch (final IOException e) {
-      // this should really, truly, never happen
-      throw new IllegalArgumentException(e);
-    }
-  }
-
   public Account updateDevice(Account account, long deviceId, Consumer<Device> deviceUpdater) {
     return update(account, a -> {
       a.getDevice(deviceId).ifPresent(deviceUpdater);
-- 
GitLab