diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveConnectorTest.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveConnectorTest.java
index 84ba4f8d4c52..e87f6403fc45 100644
--- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveConnectorTest.java
+++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveConnectorTest.java
@@ -269,6 +269,17 @@ public void testUpdate()
.hasMessage("Hive update is only supported for ACID transactional tables");
}
+ @Override
+ public void testUpdateRowConcurrently()
+ throws Exception
+ {
+ // TODO (https://github.com/trinodb/trino/issues/10518) test this with a TestHiveConnectorTest version that creates ACID tables by default, or in some other way
+ assertThatThrownBy(super::testUpdateRowConcurrently)
+ .hasMessage("Unexpected concurrent update failure")
+ .getCause()
+ .hasMessage("Hive update is only supported for ACID transactional tables");
+ }
+
@Override
public void testExplainAnalyzeWithDeleteWithSubquery()
{
diff --git a/testing/trino-testing/pom.xml b/testing/trino-testing/pom.xml
index c1b42f0f1bb9..3955ecc4c8ac 100644
--- a/testing/trino-testing/pom.xml
+++ b/testing/trino-testing/pom.xml
@@ -53,6 +53,11 @@
tpch
+
+ io.airlift
+ concurrent
+
+
io.airlift
log
diff --git a/testing/trino-testing/src/main/java/io/trino/testing/BaseConnectorTest.java b/testing/trino-testing/src/main/java/io/trino/testing/BaseConnectorTest.java
index a46cf38655fb..c0365a082481 100644
--- a/testing/trino-testing/src/main/java/io/trino/testing/BaseConnectorTest.java
+++ b/testing/trino-testing/src/main/java/io/trino/testing/BaseConnectorTest.java
@@ -26,16 +26,24 @@
import org.testng.annotations.DataProvider;
import org.testng.annotations.Test;
+import java.util.List;
import java.util.Optional;
+import java.util.concurrent.CyclicBarrier;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Future;
import java.util.function.Consumer;
+import java.util.stream.IntStream;
import java.util.stream.Stream;
+import static com.google.common.collect.ImmutableList.toImmutableList;
+import static io.airlift.concurrent.MoreFutures.tryGetFutureValue;
import static io.trino.SystemSessionProperties.IGNORE_STATS_CALCULATOR_FAILURES;
import static io.trino.spi.type.VarcharType.VARCHAR;
import static io.trino.sql.planner.optimizations.PlanNodeSearcher.searchFrom;
import static io.trino.sql.planner.planprinter.PlanPrinter.textLogicalPlan;
import static io.trino.testing.DataProviders.toDataProvider;
import static io.trino.testing.QueryAssertions.assertContains;
+import static io.trino.testing.QueryAssertions.getTrinoExceptionCause;
import static io.trino.testing.TestingConnectorBehavior.SUPPORTS_ADD_COLUMN;
import static io.trino.testing.TestingConnectorBehavior.SUPPORTS_ARRAY;
import static io.trino.testing.TestingConnectorBehavior.SUPPORTS_COMMENT_ON_COLUMN;
@@ -66,6 +74,9 @@
import static java.lang.String.join;
import static java.util.Collections.nCopies;
import static java.util.Locale.ENGLISH;
+import static java.util.concurrent.Executors.newFixedThreadPool;
+import static java.util.concurrent.TimeUnit.SECONDS;
+import static java.util.stream.Collectors.joining;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.testng.Assert.assertFalse;
@@ -1280,6 +1291,72 @@ public void testUpdate()
}
}
+ // Repeat test with invocationCount for better test coverage, since the tested aspect is inherently non-deterministic.
+ @Test(timeOut = 60_000, invocationCount = 4)
+ public void testUpdateRowConcurrently()
+ throws Exception
+ {
+ if (!hasBehavior(SUPPORTS_UPDATE)) {
+ // Covered by testUpdate
+ return;
+ }
+
+ int threads = 4;
+ CyclicBarrier barrier = new CyclicBarrier(threads);
+ ExecutorService executor = newFixedThreadPool(threads);
+ try (TestTable table = new TestTable(
+ getQueryRunner()::execute,
+ "test_concurrent_update",
+ IntStream.range(0, threads)
+ .mapToObj(i -> format("col%s integer", i))
+ .collect(joining(", ", "(", ")")))) {
+ String tableName = table.getName();
+ assertUpdate(format("INSERT INTO %s VALUES (%s)", tableName, join(",", nCopies(threads, "0"))), 1);
+
+ List> futures = IntStream.range(0, threads)
+ .mapToObj(threadNumber -> executor.submit(() -> {
+ barrier.await(10, SECONDS);
+ try {
+ String columnName = "col" + threadNumber;
+ getQueryRunner().execute(format("UPDATE %s SET %s = %s + 1", tableName, columnName, columnName));
+ return true;
+ }
+ catch (Exception e) {
+ RuntimeException trinoException = getTrinoExceptionCause(e);
+ try {
+ verifyConcurrentUpdateFailurePermissible(trinoException);
+ }
+ catch (Throwable verifyFailure) {
+ if (trinoException != e && verifyFailure != e) {
+ verifyFailure.addSuppressed(e);
+ }
+ throw verifyFailure;
+ }
+ return false;
+ }
+ }))
+ .collect(toImmutableList());
+
+ String expected = futures.stream()
+ .map(future -> tryGetFutureValue(future, 10, SECONDS).orElseThrow(() -> new RuntimeException("Wait timed out")))
+ .map(success -> success ? "1" : "0")
+ .collect(joining(",", "VALUES (", ")"));
+
+ assertThat(query("TABLE " + tableName))
+ .matches(expected);
+ }
+ finally {
+ executor.shutdownNow();
+ executor.awaitTermination(10, SECONDS);
+ }
+ }
+
+ protected void verifyConcurrentUpdateFailurePermissible(Exception e)
+ {
+ // By default, do not expect UPDATE to fail in case of concurrent updates
+ throw new AssertionError("Unexpected concurrent update failure", e);
+ }
+
@Test
public void testTruncateTable()
{