diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/BaseJdbcClient.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/BaseJdbcClient.java index 9128f9fdd1a4..6703811a8d3d 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/BaseJdbcClient.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/BaseJdbcClient.java @@ -61,6 +61,7 @@ import java.util.Map; import java.util.Objects; import java.util.Optional; +import java.util.OptionalInt; import java.util.OptionalLong; import java.util.Set; import java.util.function.BiFunction; @@ -81,6 +82,7 @@ import static io.trino.plugin.jdbc.CaseSensitivity.CASE_SENSITIVE; import static io.trino.plugin.jdbc.JdbcErrorCode.JDBC_ERROR; import static io.trino.plugin.jdbc.JdbcWriteSessionProperties.getWriteBatchSize; +import static io.trino.plugin.jdbc.JdbcWriteSessionProperties.getWriteParallelism; import static io.trino.plugin.jdbc.JdbcWriteSessionProperties.isNonTransactionalInsert; import static io.trino.plugin.jdbc.PredicatePushdownController.DISABLE_PUSHDOWN; import static io.trino.plugin.jdbc.StandardColumnMappings.varcharReadFunction; @@ -1342,6 +1344,12 @@ public void truncateTable(ConnectorSession session, JdbcTableHandle handle) execute(session, sql); } + @Override + public OptionalInt getMaxWriteParallelism(ConnectorSession session) + { + return OptionalInt.of(getWriteParallelism(session)); + } + protected void verifySchemaName(DatabaseMetaData databaseMetadata, String schemaName) throws SQLException { diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/CachingJdbcClient.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/CachingJdbcClient.java index 165babb7316a..71ea97ac126a 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/CachingJdbcClient.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/CachingJdbcClient.java @@ -54,6 +54,7 @@ import java.util.List; import java.util.Map; import 
java.util.Optional; +import java.util.OptionalInt; import java.util.OptionalLong; import java.util.Set; import java.util.concurrent.Callable; @@ -565,6 +566,12 @@ public Optional getTableScanRedirection(Conn return delegate.getTableScanRedirection(session, tableHandle); } + @Override + public OptionalInt getMaxWriteParallelism(ConnectorSession session) + { + return delegate.getMaxWriteParallelism(session); + } + public void onDataChanged(SchemaTableName table) { invalidateAllIf(statisticsCache, key -> key.mayReference(table)); diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/DefaultJdbcMetadata.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/DefaultJdbcMetadata.java index 5f171d81bcd7..8d49c4f79861 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/DefaultJdbcMetadata.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/DefaultJdbcMetadata.java @@ -72,6 +72,7 @@ import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.OptionalInt; import java.util.OptionalLong; import java.util.Set; import java.util.concurrent.atomic.AtomicReference; @@ -994,6 +995,12 @@ public void renameSchema(ConnectorSession session, String schemaName, String new jdbcClient.renameSchema(session, schemaName, newSchemaName); } + @Override + public OptionalInt getMaxWriterTasks(ConnectorSession session) + { + return jdbcClient.getMaxWriteParallelism(session); + } + private static boolean isTableHandleForProcedure(ConnectorTableHandle tableHandle) { return tableHandle instanceof JdbcProcedureHandle; diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/ForwardingJdbcClient.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/ForwardingJdbcClient.java index 794244cc9ec2..9c576d18fe73 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/ForwardingJdbcClient.java +++ 
b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/ForwardingJdbcClient.java @@ -39,6 +39,7 @@ import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.OptionalInt; import java.util.OptionalLong; import java.util.Set; import java.util.function.Supplier; @@ -431,4 +432,10 @@ public void truncateTable(ConnectorSession session, JdbcTableHandle handle) { delegate().truncateTable(session, handle); } + + @Override + public OptionalInt getMaxWriteParallelism(ConnectorSession session) + { + return delegate().getMaxWriteParallelism(session); + } } diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcClient.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcClient.java index 5a8e2ff3df48..0fa5b42716d8 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcClient.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcClient.java @@ -40,6 +40,7 @@ import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.OptionalInt; import java.util.OptionalLong; import java.util.Set; @@ -229,4 +230,6 @@ default Optional getTableScanRedirection(Con OptionalLong delete(ConnectorSession session, JdbcTableHandle handle); void truncateTable(ConnectorSession session, JdbcTableHandle handle); + + OptionalInt getMaxWriteParallelism(ConnectorSession session); } diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcWriteConfig.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcWriteConfig.java index 84ef6eae998c..1a08c572e7f4 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcWriteConfig.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcWriteConfig.java @@ -22,8 +22,10 @@ public class JdbcWriteConfig { public static final int MAX_ALLOWED_WRITE_BATCH_SIZE = 10_000_000; + static final int DEFAULT_WRITE_PARALLELISM = 8; private int
writeParallelism = DEFAULT_WRITE_PARALLELISM; // Do not create temporary table during insert. // This means that the write operation can fail and leave the table in an inconsistent state. @@ -57,4 +59,19 @@ public JdbcWriteConfig setNonTransactionalInsert(boolean nonTransactionalInsert) this.nonTransactionalInsert = nonTransactionalInsert; return this; } + + @Min(1) + @Max(128) + public int getWriteParallelism() + { + return writeParallelism; + } + + @Config("write.parallelism") + @ConfigDescription("Maximum number of parallel write tasks") + public JdbcWriteConfig setWriteParallelism(int writeParallelism) + { + this.writeParallelism = writeParallelism; + return this; + } } diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcWriteSessionProperties.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcWriteSessionProperties.java index f1b30fd65360..78e6d12d2298 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcWriteSessionProperties.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcWriteSessionProperties.java @@ -33,6 +33,7 @@ public class JdbcWriteSessionProperties { public static final String WRITE_BATCH_SIZE = "write_batch_size"; public static final String NON_TRANSACTIONAL_INSERT = "non_transactional_insert"; + public static final String WRITE_PARALLELISM = "write_parallelism"; private final List> properties; @@ -51,6 +52,11 @@ public JdbcWriteSessionProperties(JdbcWriteConfig writeConfig) "Do not use temporary table on insert to table", writeConfig.isNonTransactionalInsert(), false)) + .add(integerProperty( + WRITE_PARALLELISM, + "Maximum number of parallel write tasks", + writeConfig.getWriteParallelism(), + false)) .build(); } @@ -65,6 +71,11 @@ public static int getWriteBatchSize(ConnectorSession session) return session.getProperty(WRITE_BATCH_SIZE, Integer.class); } + public static int getWriteParallelism(ConnectorSession session) + { + return
session.getProperty(WRITE_PARALLELISM, Integer.class); + } + public static boolean isNonTransactionalInsert(ConnectorSession session) { return session.getProperty(NON_TRANSACTIONAL_INSERT, Boolean.class); diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/jmx/StatisticsAwareJdbcClient.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/jmx/StatisticsAwareJdbcClient.java index 62e70f57f36f..3441d62d40ad 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/jmx/StatisticsAwareJdbcClient.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/jmx/StatisticsAwareJdbcClient.java @@ -56,6 +56,7 @@ import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.OptionalInt; import java.util.OptionalLong; import java.util.Set; @@ -451,4 +452,10 @@ public void truncateTable(ConnectorSession session, JdbcTableHandle handle) { stats.getTruncateTable().wrap(() -> delegate().truncateTable(session, handle)); } + + @Override + public OptionalInt getMaxWriteParallelism(ConnectorSession session) + { + return delegate().getMaxWriteParallelism(session); + } } diff --git a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/BaseJdbcConnectorTest.java b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/BaseJdbcConnectorTest.java index 6d757d5498a0..be061743f986 100644 --- a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/BaseJdbcConnectorTest.java +++ b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/BaseJdbcConnectorTest.java @@ -29,6 +29,7 @@ import io.trino.sql.planner.plan.OutputNode; import io.trino.sql.planner.plan.ProjectNode; import io.trino.sql.planner.plan.TableScanNode; +import io.trino.sql.planner.plan.TableWriterNode; import io.trino.sql.planner.plan.TopNNode; import io.trino.sql.planner.plan.ValuesNode; import io.trino.sql.query.QueryAssertions.QueryAssert; @@ -73,6 +74,7 @@ import static 
io.trino.sql.planner.OptimizerConfig.JoinDistributionType.PARTITIONED; import static io.trino.sql.planner.assertions.PlanMatchPattern.anyTree; import static io.trino.sql.planner.assertions.PlanMatchPattern.node; +import static io.trino.sql.planner.optimizations.PlanNodeSearcher.searchFrom; import static io.trino.testing.DataProviders.toDataProvider; import static io.trino.testing.QueryAssertions.assertEqualsIgnoreOrder; import static io.trino.testing.TestingConnectorBehavior.SUPPORTS_AGGREGATION_PUSHDOWN; @@ -1683,6 +1685,46 @@ public void testWriteBatchSizeSessionProperty(Integer batchSize, Integer numberO } } + @Test(dataProvider = "writeTaskParallelismDataProvider") + public void testWriteTaskParallelismSessionProperty(int parallelism, int numberOfRows) + { + if (!hasBehavior(SUPPORTS_CREATE_TABLE)) { + throw new SkipException("CREATE TABLE is required for write_parallelism test but is not supported"); + } + + Session session = Session.builder(getSession()) + .setCatalogSessionProperty(getSession().getCatalog().orElseThrow(), "write_parallelism", String.valueOf(parallelism)) + .build(); + + try (TestTable table = new TestTable( + getQueryRunner()::execute, + "write_parallelism", + "(a varchar(128), b bigint)")) { + assertUpdate(session, "INSERT INTO " + table.getName() + " (a, b) SELECT clerk, orderkey FROM tpch.sf100.orders LIMIT " + numberOfRows, numberOfRows, plan -> { + TableWriterNode.WriterTarget target = searchFrom(plan.getRoot()) + .where(node -> node instanceof TableWriterNode) + .findFirst() + .map(TableWriterNode.class::cast) + .map(TableWriterNode::getTarget) + .orElseThrow(); + + assertThat(target.getMaxWriterTasks(getQueryRunner().getMetadata(), session)) + .hasValue(parallelism); + }); + } + } + + @DataProvider + public static Object[][] writeTaskParallelismDataProvider() + { + return new Object[][]{ + {1, 10_000}, + {2, 10_000}, + {4, 10_000}, + {16, 10_000}, + {32, 10_000}}; + } + private static List buildRowsForInsert(int numberOfRows) {
List result = new ArrayList<>(numberOfRows); diff --git a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestJdbcWriteConfig.java b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestJdbcWriteConfig.java index fa71f8beb9e3..6a3b46178d58 100644 --- a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestJdbcWriteConfig.java +++ b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestJdbcWriteConfig.java @@ -33,6 +33,7 @@ public void testDefaults() { assertRecordedDefaults(recordDefaults(JdbcWriteConfig.class) .setWriteBatchSize(1000) + .setWriteParallelism(8) .setNonTransactionalInsert(false)); } @@ -42,11 +43,13 @@ public void testExplicitPropertyMappings() Map properties = ImmutableMap.builder() .put("write.batch-size", "24") .put("insert.non-transactional-insert.enabled", "true") + .put("write.parallelism", "16") .buildOrThrow(); JdbcWriteConfig expected = new JdbcWriteConfig() .setWriteBatchSize(24) - .setNonTransactionalInsert(true); + .setNonTransactionalInsert(true) + .setWriteParallelism(16); assertFullMapping(properties, expected); }