From 9de6190ec4cf48b8a79cd91f99723fb8cae882ee Mon Sep 17 00:00:00 2001
From: Robin Appelman <robin@icewind.nl>
Date: Thu, 4 Jul 2024 17:18:22 +0200
Subject: [PATCH] feat: allow running QueryBuilder queries on different
 connections

Signed-off-by: Robin Appelman <robin@icewind.nl>
---
 .../DB/QueryBuilder/ExtendedQueryBuilder.php  | 15 ++--
 lib/private/DB/QueryBuilder/QueryBuilder.php  | 76 +++++++++-------
 lib/public/DB/QueryBuilder/IQueryBuilder.php  | 10 ++-
 .../lib/DB/QueryBuilder/QueryBuilderTest.php  | 89 ++++++++++++-------
 4 files changed, 115 insertions(+), 75 deletions(-)

diff --git a/lib/private/DB/QueryBuilder/ExtendedQueryBuilder.php b/lib/private/DB/QueryBuilder/ExtendedQueryBuilder.php
index ab58773dfd3..bde6523567f 100644
--- a/lib/private/DB/QueryBuilder/ExtendedQueryBuilder.php
+++ b/lib/private/DB/QueryBuilder/ExtendedQueryBuilder.php
@@ -11,6 +11,7 @@ namespace OC\DB\QueryBuilder;
 use OC\DB\Exceptions\DbalException;
 use OCP\DB\IResult;
 use OCP\DB\QueryBuilder\IQueryBuilder;
+use OCP\IDBConnection;
 
 /**
  * Base class for creating classes that extend the builtin query builder
@@ -46,12 +47,12 @@ abstract class ExtendedQueryBuilder implements IQueryBuilder {
 		return $this->builder->getState();
 	}
 
-	public function execute() {
+	public function execute(?IDBConnection $connection = null) {
 		try {
 			if ($this->getType() === \Doctrine\DBAL\Query\QueryBuilder::SELECT) {
-				return $this->executeQuery();
+				return $this->executeQuery($connection);
 			} else {
-				return $this->executeStatement();
+				return $this->executeStatement($connection);
 			}
 		} catch (DBALException $e) {
 			// `IQueryBuilder->execute` never wrapped the exception, but `executeQuery` and `executeStatement` do
@@ -280,11 +281,11 @@ abstract class ExtendedQueryBuilder implements IQueryBuilder {
 		return $this->builder->getColumnName($column, $tableAlias);
 	}
 
-	public function executeQuery(): IResult {
-		return $this->builder->executeQuery();
+	public function executeQuery(?IDBConnection $connection = null): IResult {
+		return $this->builder->executeQuery($connection);
 	}
 
-	public function executeStatement(): int {
-		return $this->builder->executeStatement();
+	public function executeStatement(?IDBConnection $connection = null): int {
+		return $this->builder->executeStatement($connection);
 	}
 }
diff --git a/lib/private/DB/QueryBuilder/QueryBuilder.php b/lib/private/DB/QueryBuilder/QueryBuilder.php
index 0e7d8d2ff3e..82127078d06 100644
--- a/lib/private/DB/QueryBuilder/QueryBuilder.php
+++ b/lib/private/DB/QueryBuilder/QueryBuilder.php
@@ -13,6 +13,7 @@ use Doctrine\DBAL\Platforms\PostgreSQL94Platform;
 use Doctrine\DBAL\Platforms\SqlitePlatform;
 use Doctrine\DBAL\Query\QueryException;
 use OC\DB\ConnectionAdapter;
+use OC\DB\Exceptions\DbalException;
 use OC\DB\QueryBuilder\ExpressionBuilder\ExpressionBuilder;
 use OC\DB\QueryBuilder\ExpressionBuilder\MySqlExpressionBuilder;
 use OC\DB\QueryBuilder\ExpressionBuilder\OCIExpressionBuilder;
@@ -22,7 +23,6 @@ use OC\DB\QueryBuilder\FunctionBuilder\FunctionBuilder;
 use OC\DB\QueryBuilder\FunctionBuilder\OCIFunctionBuilder;
 use OC\DB\QueryBuilder\FunctionBuilder\PgSqlFunctionBuilder;
 use OC\DB\QueryBuilder\FunctionBuilder\SqliteFunctionBuilder;
-use OC\DB\ResultAdapter;
 use OC\SystemConfig;
 use OCP\DB\IResult;
 use OCP\DB\QueryBuilder\ICompositeExpression;
@@ -30,6 +30,7 @@ use OCP\DB\QueryBuilder\ILiteral;
 use OCP\DB\QueryBuilder\IParameter;
 use OCP\DB\QueryBuilder\IQueryBuilder;
 use OCP\DB\QueryBuilder\IQueryFunction;
+use OCP\IDBConnection;
 use Psr\Log\LoggerInterface;
 
 class QueryBuilder implements IQueryBuilder {
@@ -168,15 +169,7 @@ class QueryBuilder implements IQueryBuilder {
 		return $this->queryBuilder->getState();
 	}
 
-	/**
-	 * Executes this query using the bound parameters and their types.
-	 *
-	 * Uses {@see Connection::executeQuery} for select statements and {@see Connection::executeUpdate}
-	 * for insert, update and delete statements.
-	 *
-	 * @return IResult|int
-	 */
-	public function execute() {
+	private function prepareForExecute() {
 		if ($this->systemConfig->getValue('log_query', false)) {
 			try {
 				$params = [];
@@ -253,48 +246,63 @@ class QueryBuilder implements IQueryBuilder {
 				'exception' => $exception,
 			]);
 		}
+	}
 
-		$result = $this->queryBuilder->execute();
-		if (is_int($result)) {
-			return $result;
+	/**
+	 * Executes this query using the bound parameters and their types.
+	 *
+	 * Uses {@see Connection::executeQuery} for select statements and {@see Connection::executeUpdate}
+	 * for insert, update and delete statements.
+	 *
+	 * @return IResult|int
+	 */
+	public function execute(?IDBConnection $connection = null) {
+		try {
+			if ($this->getType() === \Doctrine\DBAL\Query\QueryBuilder::SELECT) {
+				return $this->executeQuery($connection);
+			} else {
+				return $this->executeStatement($connection);
+			}
+		} catch (DBALException $e) {
+			// `IQueryBuilder->execute` never wrapped the exception, but `executeQuery` and `executeStatement` do
+			/** @var \Doctrine\DBAL\Exception $previous */
+			$previous = $e->getPrevious();
+			throw $previous;
 		}
-		return new ResultAdapter($result);
 	}
 
-	public function executeQuery(): IResult {
+	public function executeQuery(?IDBConnection $connection = null): IResult {
 		if ($this->getType() !== \Doctrine\DBAL\Query\QueryBuilder::SELECT) {
 			throw new \RuntimeException('Invalid query type, expected SELECT query');
 		}
 
-		try {
-			$result = $this->execute();
-		} catch (\Doctrine\DBAL\Exception $e) {
-			throw \OC\DB\Exceptions\DbalException::wrap($e);
+		$this->prepareForExecute();
+		if (!$connection) {
+			$connection = $this->connection;
 		}
 
-		if ($result instanceof IResult) {
-			return $result;
-		}
-
-		throw new \RuntimeException('Invalid return type for query');
+		return $connection->executeQuery(
+			$this->getSQL(),
+			$this->getParameters(),
+			$this->getParameterTypes(),
+		);
 	}
 
-	public function executeStatement(): int {
+	public function executeStatement(?IDBConnection $connection = null): int {
 		if ($this->getType() === \Doctrine\DBAL\Query\QueryBuilder::SELECT) {
 			throw new \RuntimeException('Invalid query type, expected INSERT, DELETE or UPDATE statement');
 		}
 
-		try {
-			$result = $this->execute();
-		} catch (\Doctrine\DBAL\Exception $e) {
-			throw \OC\DB\Exceptions\DbalException::wrap($e);
-		}
-
-		if (!is_int($result)) {
-			throw new \RuntimeException('Invalid return type for statement');
+		$this->prepareForExecute();
+		if (!$connection) {
+			$connection = $this->connection;
 		}
 
-		return $result;
+		return $connection->executeStatement(
+			$this->getSQL(),
+			$this->getParameters(),
+			$this->getParameterTypes(),
+		);
 	}
 
 
diff --git a/lib/public/DB/QueryBuilder/IQueryBuilder.php b/lib/public/DB/QueryBuilder/IQueryBuilder.php
index 94ab796adf4..c736d3094e5 100644
--- a/lib/public/DB/QueryBuilder/IQueryBuilder.php
+++ b/lib/public/DB/QueryBuilder/IQueryBuilder.php
@@ -12,6 +12,7 @@ use Doctrine\DBAL\Connection;
 use Doctrine\DBAL\ParameterType;
 use OCP\DB\Exception;
 use OCP\DB\IResult;
+use OCP\IDBConnection;
 
 /**
  * This class provides a wrapper around Doctrine's QueryBuilder
@@ -146,34 +147,37 @@ interface IQueryBuilder {
 	 *          that interface changed in a breaking way the adapter \OCP\DB\QueryBuilder\IStatement is returned
 	 *          to bridge old code to the new API
 	 *
+	 * @param ?IDBConnection $connection (optional) the connection to run the query against. since 30.0
 	 * @return IResult|int
 	 * @throws Exception since 21.0.0
 	 * @since 8.2.0
 	 * @deprecated 22.0.0 Use executeQuery or executeStatement
 	 */
-	public function execute();
+	public function execute(?IDBConnection $connection = null);
 
 	/**
 	 * Execute for select statements
 	 *
+	 * @param ?IDBConnection $connection (optional) the connection to run the query against. since 30.0
 	 * @return IResult
 	 * @since 22.0.0
 	 *
 	 * @throws Exception
 	 * @throws \RuntimeException in case of usage with non select query
 	 */
-	public function executeQuery(): IResult;
+	public function executeQuery(?IDBConnection $connection = null): IResult;
 
 	/**
 	 * Execute insert, update and delete statements
 	 *
+	 * @param ?IDBConnection $connection (optional) the connection to run the query against. since 30.0
 	 * @return int the number of affected rows
 	 * @since 22.0.0
 	 *
 	 * @throws Exception
 	 * @throws \RuntimeException in case of usage with select query
 	 */
-	public function executeStatement(): int;
+	public function executeStatement(?IDBConnection $connection = null): int;
 
 	/**
 	 * Gets the complete SQL string formed by the current specifications of this QueryBuilder.
diff --git a/tests/lib/DB/QueryBuilder/QueryBuilderTest.php b/tests/lib/DB/QueryBuilder/QueryBuilderTest.php
index 96cde8ba1f9..335666b54fd 100644
--- a/tests/lib/DB/QueryBuilder/QueryBuilderTest.php
+++ b/tests/lib/DB/QueryBuilder/QueryBuilderTest.php
@@ -9,11 +9,12 @@ namespace Test\DB\QueryBuilder;
 
 use Doctrine\DBAL\Query\Expression\CompositeExpression;
 use Doctrine\DBAL\Query\QueryException;
-use Doctrine\DBAL\Result;
 use OC\DB\QueryBuilder\Literal;
 use OC\DB\QueryBuilder\Parameter;
 use OC\DB\QueryBuilder\QueryBuilder;
 use OC\SystemConfig;
+use OCP\DB\IResult;
+use OCP\DB\QueryBuilder\IQueryBuilder;
 use OCP\DB\QueryBuilder\IQueryFunction;
 use OCP\IDBConnection;
 use Psr\Log\LoggerInterface;
@@ -1253,16 +1254,29 @@ class QueryBuilderTest extends \Test\TestCase {
 		);
 	}
 
+	private function getConnection(): IDBConnection {
+		$connection = $this->createMock(IDBConnection::class);
+		$connection->method('executeStatement')
+			->willReturn(3);
+		$connection->method('executeQuery')
+			->willReturn($this->createMock(IResult::class));
+		return $connection;
+	}
+
 	public function testExecuteWithoutLogger() {
 		$queryBuilder = $this->createMock(\Doctrine\DBAL\Query\QueryBuilder::class);
 		$queryBuilder
-			->expects($this->once())
-			->method('execute')
-			->willReturn(3);
+			->method('getSQL')
+			->willReturn('');
 		$queryBuilder
-			->expects($this->any())
 			->method('getParameters')
 			->willReturn([]);
+		$queryBuilder
+			->method('getParameterTypes')
+			->willReturn([]);
+		$queryBuilder
+			->method('getType')
+			->willReturn(\Doctrine\DBAL\Query\QueryBuilder::UPDATE);
 		$this->logger
 			->expects($this->never())
 			->method('debug');
@@ -1273,6 +1287,7 @@ class QueryBuilderTest extends \Test\TestCase {
 			->willReturn(false);
 
 		$this->invokePrivate($this->queryBuilder, 'queryBuilder', [$queryBuilder]);
+		$this->invokePrivate($this->queryBuilder, 'connection', [$this->getConnection()]);
 		$this->assertEquals(3, $this->queryBuilder->execute());
 	}
 
@@ -1285,21 +1300,26 @@ class QueryBuilderTest extends \Test\TestCase {
 				'foo' => 'bar',
 				'key' => 'value',
 			]);
+		$queryBuilder
+			->method('getParameterTypes')
+			->willReturn([
+				'foo' => IQueryBuilder::PARAM_STR,
+				'key' => IQueryBuilder::PARAM_STR,
+			]);
+		$queryBuilder
+			->method('getType')
+			->willReturn(\Doctrine\DBAL\Query\QueryBuilder::UPDATE);
 		$queryBuilder
 			->expects($this->any())
 			->method('getSQL')
-			->willReturn('SELECT * FROM FOO WHERE BAR = ?');
-		$queryBuilder
-			->expects($this->once())
-			->method('execute')
-			->willReturn(3);
+			->willReturn('UPDATE FOO SET bar = 1 WHERE BAR = ?');
 		$this->logger
 			->expects($this->once())
 			->method('debug')
 			->with(
 				'DB QueryBuilder: \'{query}\' with parameters: {params}',
 				[
-					'query' => 'SELECT * FROM FOO WHERE BAR = ?',
+					'query' => 'UPDATE FOO SET bar = 1 WHERE BAR = ?',
 					'params' => 'foo => \'bar\', key => \'value\'',
 					'app' => 'core',
 				]
@@ -1311,6 +1331,7 @@ class QueryBuilderTest extends \Test\TestCase {
 			->willReturn(true);
 
 		$this->invokePrivate($this->queryBuilder, 'queryBuilder', [$queryBuilder]);
+		$this->invokePrivate($this->queryBuilder, 'connection', [$this->getConnection()]);
 		$this->assertEquals(3, $this->queryBuilder->execute());
 	}
 
@@ -1320,21 +1341,23 @@ class QueryBuilderTest extends \Test\TestCase {
 			->expects($this->any())
 			->method('getParameters')
 			->willReturn(['Bar']);
+		$queryBuilder
+			->method('getParameterTypes')
+			->willReturn([IQueryBuilder::PARAM_STR]);
+		$queryBuilder
+			->method('getType')
+			->willReturn(\Doctrine\DBAL\Query\QueryBuilder::UPDATE);
 		$queryBuilder
 			->expects($this->any())
 			->method('getSQL')
-			->willReturn('SELECT * FROM FOO WHERE BAR = ?');
-		$queryBuilder
-			->expects($this->once())
-			->method('execute')
-			->willReturn(3);
+			->willReturn('UPDATE FOO SET bar = false WHERE BAR = ?');
 		$this->logger
 			->expects($this->once())
 			->method('debug')
 			->with(
 				'DB QueryBuilder: \'{query}\' with parameters: {params}',
 				[
-					'query' => 'SELECT * FROM FOO WHERE BAR = ?',
+					'query' => 'UPDATE FOO SET bar = false WHERE BAR = ?',
 					'params' => '0 => \'Bar\'',
 					'app' => 'core',
 				]
@@ -1346,6 +1369,7 @@ class QueryBuilderTest extends \Test\TestCase {
 			->willReturn(true);
 
 		$this->invokePrivate($this->queryBuilder, 'queryBuilder', [$queryBuilder]);
+		$this->invokePrivate($this->queryBuilder, 'connection', [$this->getConnection()]);
 		$this->assertEquals(3, $this->queryBuilder->execute());
 	}
 
@@ -1355,21 +1379,23 @@ class QueryBuilderTest extends \Test\TestCase {
 			->expects($this->any())
 			->method('getParameters')
 			->willReturn([]);
+		$queryBuilder
+			->method('getParameterTypes')
+			->willReturn([]);
+		$queryBuilder
+			->method('getType')
+			->willReturn(\Doctrine\DBAL\Query\QueryBuilder::UPDATE);
 		$queryBuilder
 			->expects($this->any())
 			->method('getSQL')
-			->willReturn('SELECT * FROM FOO WHERE BAR = ?');
-		$queryBuilder
-			->expects($this->once())
-			->method('execute')
-			->willReturn(3);
+			->willReturn('UPDATE FOO SET bar = false WHERE BAR = ?');
 		$this->logger
 			->expects($this->once())
 			->method('debug')
 			->with(
 				'DB QueryBuilder: \'{query}\'',
 				[
-					'query' => 'SELECT * FROM FOO WHERE BAR = ?',
+					'query' => 'UPDATE FOO SET bar = false WHERE BAR = ?',
 					'app' => 'core',
 				]
 			);
@@ -1380,6 +1406,7 @@ class QueryBuilderTest extends \Test\TestCase {
 			->willReturn(true);
 
 		$this->invokePrivate($this->queryBuilder, 'queryBuilder', [$queryBuilder]);
+		$this->invokePrivate($this->queryBuilder, 'connection', [$this->getConnection()]);
 		$this->assertEquals(3, $this->queryBuilder->execute());
 	}
 
@@ -1390,14 +1417,13 @@ class QueryBuilderTest extends \Test\TestCase {
 			->expects($this->any())
 			->method('getParameters')
 			->willReturn([$p]);
+		$queryBuilder
+			->method('getParameterTypes')
+			->willReturn([IQueryBuilder::PARAM_STR_ARRAY]);
 		$queryBuilder
 			->expects($this->any())
 			->method('getSQL')
 			->willReturn('SELECT * FROM FOO WHERE BAR IN (?)');
-		$queryBuilder
-			->expects($this->once())
-			->method('execute')
-			->willReturn($this->createMock(Result::class));
 		$this->logger
 			->expects($this->once())
 			->method('error')
@@ -1415,6 +1441,7 @@ class QueryBuilderTest extends \Test\TestCase {
 			->willReturn(false);
 
 		$this->invokePrivate($this->queryBuilder, 'queryBuilder', [$queryBuilder]);
+		$this->invokePrivate($this->queryBuilder, 'connection', [$this->getConnection()]);
 		$this->queryBuilder->execute();
 	}
 
@@ -1425,14 +1452,13 @@ class QueryBuilderTest extends \Test\TestCase {
 			->expects($this->any())
 			->method('getParameters')
 			->willReturn(array_fill(0, 66, $p));
+		$queryBuilder
+			->method('getParameterTypes')
+			->willReturn([IQueryBuilder::PARAM_STR_ARRAY]);
 		$queryBuilder
 			->expects($this->any())
 			->method('getSQL')
 			->willReturn('SELECT * FROM FOO WHERE BAR IN (?) OR BAR IN (?)');
-		$queryBuilder
-			->expects($this->once())
-			->method('execute')
-			->willReturn($this->createMock(Result::class));
 		$this->logger
 			->expects($this->once())
 			->method('error')
@@ -1450,6 +1476,7 @@ class QueryBuilderTest extends \Test\TestCase {
 			->willReturn(false);
 
 		$this->invokePrivate($this->queryBuilder, 'queryBuilder', [$queryBuilder]);
+		$this->invokePrivate($this->queryBuilder, 'connection', [$this->getConnection()]);
 		$this->queryBuilder->execute();
 	}
 }
-- 
GitLab