diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 9115f719e2de..a20bf398fc5c 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -92,7 +92,7 @@ jobs:
if: steps.cache.outputs.cache-hit != 'true'
run: rm -rf ~/.m2/repository/io/trino/trino-*
- check-commits:
+ check-commits-dispatcher:
runs-on: ubuntu-latest
if: github.event_name == 'pull_request'
outputs:
@@ -101,12 +101,12 @@ jobs:
- uses: actions/checkout@v3
with:
fetch-depth: 0 # checkout all commits to be able to determine merge base
- - name: Check Commits
+ - name: Block illegal commits
uses: trinodb/github-actions/block-commits@c2991972560c5219d9ae5fb68c0c9d687ffcdd10
with:
action-merge: fail
action-fixup: none
- - name: Set matrix
+ - name: Set matrix (dispatch commit checks)
id: set-matrix
run: |
# The output from rev-list ends with a newline, so we have to filter out index -1 in jq since it's an empty string
@@ -125,12 +125,13 @@ jobs:
echo "Commit matrix: $(jq '.' commit-matrix.json)"
echo "matrix=$(jq -c '.' commit-matrix.json)" >> $GITHUB_OUTPUT
- check-commits-dispatcher:
+ check-commit:
runs-on: ubuntu-latest
- needs: check-commits
- if: github.event_name == 'pull_request' && needs.check-commits.outputs.matrix != ''
+ needs: check-commits-dispatcher
+ if: github.event_name == 'pull_request' && needs.check-commits-dispatcher.outputs.matrix != ''
strategy:
- matrix: ${{ fromJson(needs.check-commits.outputs.matrix) }}
+ fail-fast: false
+ matrix: ${{ fromJson(needs.check-commits-dispatcher.outputs.matrix) }}
steps:
- uses: actions/checkout@v3
with:
diff --git a/client/trino-cli/pom.xml b/client/trino-cli/pom.xml
index 10d868304314..5146a7f34e92 100644
--- a/client/trino-cli/pom.xml
+++ b/client/trino-cli/pom.xml
@@ -5,7 +5,7 @@
<groupId>io.trino</groupId>
<artifactId>trino-root</artifactId>
- <version>406-SNAPSHOT</version>
+ <version>407-SNAPSHOT</version>
<relativePath>../../pom.xml</relativePath>
diff --git a/client/trino-client/pom.xml b/client/trino-client/pom.xml
index b14dcc92265a..07ab050377c6 100644
--- a/client/trino-client/pom.xml
+++ b/client/trino-client/pom.xml
@@ -5,7 +5,7 @@
<groupId>io.trino</groupId>
<artifactId>trino-root</artifactId>
- <version>406-SNAPSHOT</version>
+ <version>407-SNAPSHOT</version>
<relativePath>../../pom.xml</relativePath>
diff --git a/client/trino-jdbc/pom.xml b/client/trino-jdbc/pom.xml
index 08647ebe6bc7..893d5cc0049b 100644
--- a/client/trino-jdbc/pom.xml
+++ b/client/trino-jdbc/pom.xml
@@ -5,7 +5,7 @@
<groupId>io.trino</groupId>
<artifactId>trino-root</artifactId>
- <version>406-SNAPSHOT</version>
+ <version>407-SNAPSHOT</version>
<relativePath>../../pom.xml</relativePath>
diff --git a/client/trino-jdbc/src/test/java/io/trino/jdbc/TestTrinoDatabaseMetaData.java b/client/trino-jdbc/src/test/java/io/trino/jdbc/TestTrinoDatabaseMetaData.java
index 5677816e5eaa..509e939a0481 100644
--- a/client/trino-jdbc/src/test/java/io/trino/jdbc/TestTrinoDatabaseMetaData.java
+++ b/client/trino-jdbc/src/test/java/io/trino/jdbc/TestTrinoDatabaseMetaData.java
@@ -813,7 +813,7 @@ public void testGetColumns()
assertColumnSpec(rs, Types.TIME, 15L, null, 6L, null, createTimeType(6));
assertColumnSpec(rs, Types.TIME, 18L, null, 9L, null, createTimeType(9));
assertColumnSpec(rs, Types.TIME, 21L, null, 12L, null, createTimeType(12));
- assertColumnSpec(rs, Types.TIME_WITH_TIMEZONE, 18L, null, 3L, null, TimeWithTimeZoneType.TIME_WITH_TIME_ZONE);
+ assertColumnSpec(rs, Types.TIME_WITH_TIMEZONE, 18L, null, 3L, null, TimeWithTimeZoneType.TIME_TZ_MILLIS);
assertColumnSpec(rs, Types.TIME_WITH_TIMEZONE, 14L, null, 0L, null, createTimeWithTimeZoneType(0));
assertColumnSpec(rs, Types.TIME_WITH_TIMEZONE, 18L, null, 3L, null, createTimeWithTimeZoneType(3));
assertColumnSpec(rs, Types.TIME_WITH_TIMEZONE, 21L, null, 6L, null, createTimeWithTimeZoneType(6));
diff --git a/core/trino-main/pom.xml b/core/trino-main/pom.xml
index 03619df45be0..073fb15da5f5 100644
--- a/core/trino-main/pom.xml
+++ b/core/trino-main/pom.xml
@@ -5,7 +5,7 @@
<groupId>io.trino</groupId>
<artifactId>trino-root</artifactId>
- <version>406-SNAPSHOT</version>
+ <version>407-SNAPSHOT</version>
<relativePath>../../pom.xml</relativePath>
diff --git a/core/trino-main/src/main/java/io/trino/SystemSessionProperties.java b/core/trino-main/src/main/java/io/trino/SystemSessionProperties.java
index 9b3a2f90a6bd..49c7cfeb1825 100644
--- a/core/trino-main/src/main/java/io/trino/SystemSessionProperties.java
+++ b/core/trino-main/src/main/java/io/trino/SystemSessionProperties.java
@@ -60,7 +60,8 @@ public final class SystemSessionProperties
public static final String JOIN_MAX_BROADCAST_TABLE_SIZE = "join_max_broadcast_table_size";
public static final String JOIN_MULTI_CLAUSE_INDEPENDENCE_FACTOR = "join_multi_clause_independence_factor";
public static final String DISTRIBUTED_INDEX_JOIN = "distributed_index_join";
- public static final String HASH_PARTITION_COUNT = "hash_partition_count";
+ public static final String MAX_HASH_PARTITION_COUNT = "max_hash_partition_count";
+ public static final String MIN_HASH_PARTITION_COUNT = "min_hash_partition_count";
public static final String PREFER_STREAMING_OPERATORS = "prefer_streaming_operators";
public static final String TASK_WRITER_COUNT = "task_writer_count";
public static final String TASK_PARTITIONED_WRITER_COUNT = "task_partitioned_writer_count";
@@ -80,6 +81,7 @@ public final class SystemSessionProperties
public static final String PREFERRED_WRITE_PARTITIONING_MIN_NUMBER_OF_PARTITIONS = "preferred_write_partitioning_min_number_of_partitions";
public static final String SCALE_WRITERS = "scale_writers";
public static final String TASK_SCALE_WRITERS_ENABLED = "task_scale_writers_enabled";
+ public static final String MAX_WRITERS_NODES_COUNT = "max_writers_nodes_count";
public static final String TASK_SCALE_WRITERS_MAX_WRITER_COUNT = "task_scale_writers_max_writer_count";
public static final String WRITER_MIN_SIZE = "writer_min_size";
public static final String PUSH_TABLE_WRITE_THROUGH_UNION = "push_table_write_through_union";
@@ -173,10 +175,13 @@ public final class SystemSessionProperties
public static final String ADAPTIVE_PARTIAL_AGGREGATION_MIN_ROWS = "adaptive_partial_aggregation_min_rows";
public static final String ADAPTIVE_PARTIAL_AGGREGATION_UNIQUE_ROWS_RATIO_THRESHOLD = "adaptive_partial_aggregation_unique_rows_ratio_threshold";
public static final String JOIN_PARTITIONED_BUILD_MIN_ROW_COUNT = "join_partitioned_build_min_row_count";
+ public static final String MIN_INPUT_SIZE_PER_TASK = "min_input_size_per_task";
+ public static final String MIN_INPUT_ROWS_PER_TASK = "min_input_rows_per_task";
public static final String USE_EXACT_PARTITIONING = "use_exact_partitioning";
public static final String FORCE_SPILLING_JOIN = "force_spilling_join";
public static final String FAULT_TOLERANT_EXECUTION_EVENT_DRIVEN_SCHEDULER_ENABLED = "fault_tolerant_execution_event_driven_scheduler_enabled";
public static final String FAULT_TOLERANT_EXECUTION_FORCE_PREFERRED_WRITE_PARTITIONING_ENABLED = "fault_tolerant_execution_force_preferred_write_partitioning_enabled";
+ public static final String PAGE_PARTITIONING_BUFFER_POOL_SIZE = "page_partitioning_buffer_pool_size";
private final List<PropertyMetadata<?>> sessionProperties;
@@ -241,9 +246,14 @@ public SystemSessionProperties(
optimizerConfig.isDistributedIndexJoinsEnabled(),
false),
integerProperty(
- HASH_PARTITION_COUNT,
- "Number of partitions for distributed joins and aggregations",
- queryManagerConfig.getHashPartitionCount(),
+ MAX_HASH_PARTITION_COUNT,
+ "Maximum number of partitions for distributed joins and aggregations",
+ queryManagerConfig.getMaxHashPartitionCount(),
+ false),
+ integerProperty(
+ MIN_HASH_PARTITION_COUNT,
+ "Minimum number of partitions for distributed joins and aggregations",
+ queryManagerConfig.getMinHashPartitionCount(),
false),
booleanProperty(
PREFER_STREAMING_OPERATORS,
@@ -286,6 +296,11 @@ public SystemSessionProperties(
"Scale out writers based on throughput (use minimum necessary)",
featuresConfig.isScaleWriters(),
false),
+ integerProperty(
+ MAX_WRITERS_NODES_COUNT,
+ "Set upper limit on number of nodes that take part in writing if task.scale-writers.enabled is set",
+ queryManagerConfig.getMaxWritersNodesCount(),
+ false),
booleanProperty(
TASK_SCALE_WRITERS_ENABLED,
"Scale the number of concurrent table writers per task based on throughput",
@@ -860,6 +875,16 @@ public SystemSessionProperties(
optimizerConfig.getJoinPartitionedBuildMinRowCount(),
value -> validateNonNegativeLongValue(value, JOIN_PARTITIONED_BUILD_MIN_ROW_COUNT),
false),
+ dataSizeProperty(
+ MIN_INPUT_SIZE_PER_TASK,
+ "Minimum input data size required per task. This will help optimizer determine hash partition count for joins and aggregations",
+ optimizerConfig.getMinInputSizePerTask(),
+ false),
+ longProperty(
+ MIN_INPUT_ROWS_PER_TASK,
+ "Minimum input rows required per task. This will help optimizer determine hash partition count for joins and aggregations",
+ optimizerConfig.getMinInputRowsPerTask(),
+ false),
booleanProperty(
USE_EXACT_PARTITIONING,
"When enabled this forces data repartitioning unless the partitioning of upstream stage matches exactly what downstream stage expects",
@@ -879,6 +904,10 @@ public SystemSessionProperties(
FAULT_TOLERANT_EXECUTION_FORCE_PREFERRED_WRITE_PARTITIONING_ENABLED,
"Force preferred write partitioning for fault tolerant execution",
queryManagerConfig.isFaultTolerantExecutionForcePreferredWritePartitioningEnabled(),
+ true),
+ integerProperty(PAGE_PARTITIONING_BUFFER_POOL_SIZE,
+ "Maximum number of free buffers in the per task partitioned page buffer pool. Setting this to zero effectively disables the pool",
+ taskManagerConfig.getPagePartitioningBufferPoolSize(),
true));
}
@@ -918,9 +947,14 @@ public static boolean isDistributedIndexJoinEnabled(Session session)
return session.getSystemProperty(DISTRIBUTED_INDEX_JOIN, Boolean.class);
}
- public static int getHashPartitionCount(Session session)
+ public static int getMaxHashPartitionCount(Session session)
+ {
+ return session.getSystemProperty(MAX_HASH_PARTITION_COUNT, Integer.class);
+ }
+
+ public static int getMinHashPartitionCount(Session session)
{
- return session.getSystemProperty(HASH_PARTITION_COUNT, Integer.class);
+ return session.getSystemProperty(MIN_HASH_PARTITION_COUNT, Integer.class);
}
public static boolean preferStreamingOperators(Session session)
@@ -968,6 +1002,11 @@ public static int getTaskScaleWritersMaxWriterCount(Session session)
return session.getSystemProperty(TASK_SCALE_WRITERS_MAX_WRITER_COUNT, Integer.class);
}
+ public static int getMaxWritersNodesCount(Session session)
+ {
+ return session.getSystemProperty(MAX_WRITERS_NODES_COUNT, Integer.class);
+ }
+
public static DataSize getWriterMinSize(Session session)
{
return session.getSystemProperty(WRITER_MIN_SIZE, DataSize.class);
@@ -1548,6 +1587,16 @@ public static long getJoinPartitionedBuildMinRowCount(Session session)
return session.getSystemProperty(JOIN_PARTITIONED_BUILD_MIN_ROW_COUNT, Long.class);
}
+ public static DataSize getMinInputSizePerTask(Session session)
+ {
+ return session.getSystemProperty(MIN_INPUT_SIZE_PER_TASK, DataSize.class);
+ }
+
+ public static long getMinInputRowsPerTask(Session session)
+ {
+ return session.getSystemProperty(MIN_INPUT_ROWS_PER_TASK, Long.class);
+ }
+
public static boolean isUseExactPartitioning(Session session)
{
return session.getSystemProperty(USE_EXACT_PARTITIONING, Boolean.class);
@@ -1571,4 +1620,9 @@ public static boolean isFaultTolerantExecutionForcePreferredWritePartitioningEna
}
return session.getSystemProperty(FAULT_TOLERANT_EXECUTION_FORCE_PREFERRED_WRITE_PARTITIONING_ENABLED, Boolean.class);
}
+
+ public static int getPagePartitioningBufferPoolSize(Session session)
+ {
+ return session.getSystemProperty(PAGE_PARTITIONING_BUFFER_POOL_SIZE, Integer.class);
+ }
}
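
The split of `hash_partition_count` into a min/max pair means the optimizer now picks a partition count inside a window rather than a fixed value. A minimal sketch of the clamping idea, under assumed names (this helper is not code from the PR):

```java
final class PartitionCountClamp
{
    private PartitionCountClamp() {}

    // min_hash_partition_count acts as a floor and max_hash_partition_count as a
    // ceiling around whatever partition count the optimizer estimates from input size.
    static int clamp(int estimatedPartitionCount, int minHashPartitionCount, int maxHashPartitionCount)
    {
        return Math.max(minHashPartitionCount, Math.min(estimatedPartitionCount, maxHashPartitionCount));
    }
}
```
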
diff --git a/core/trino-main/src/main/java/io/trino/connector/ConnectorContextInstance.java b/core/trino-main/src/main/java/io/trino/connector/ConnectorContextInstance.java
index cf7549f1c235..486cd3b23d2e 100644
--- a/core/trino-main/src/main/java/io/trino/connector/ConnectorContextInstance.java
+++ b/core/trino-main/src/main/java/io/trino/connector/ConnectorContextInstance.java
@@ -17,6 +17,7 @@
import io.trino.spi.PageIndexerFactory;
import io.trino.spi.PageSorter;
import io.trino.spi.VersionEmbedder;
+import io.trino.spi.connector.CatalogHandle;
import io.trino.spi.connector.ConnectorContext;
import io.trino.spi.connector.MetadataProvider;
import io.trino.spi.type.TypeManager;
@@ -38,8 +39,10 @@ public class ConnectorContextInstance
private final PageIndexerFactory pageIndexerFactory;
private final Supplier<ClassLoader> duplicatePluginClassLoaderFactory;
private final AtomicBoolean pluginClassLoaderDuplicated = new AtomicBoolean();
+ private final CatalogHandle catalogHandle;
public ConnectorContextInstance(
+ CatalogHandle catalogHandle,
NodeManager nodeManager,
VersionEmbedder versionEmbedder,
TypeManager typeManager,
@@ -55,6 +58,13 @@ public ConnectorContextInstance(
this.pageSorter = requireNonNull(pageSorter, "pageSorter is null");
this.pageIndexerFactory = requireNonNull(pageIndexerFactory, "pageIndexerFactory is null");
this.duplicatePluginClassLoaderFactory = requireNonNull(duplicatePluginClassLoaderFactory, "duplicatePluginClassLoaderFactory is null");
+ this.catalogHandle = requireNonNull(catalogHandle, "catalogHandle is null");
+ }
+
+ @Override
+ public CatalogHandle getCatalogHandle()
+ {
+ return catalogHandle;
}
@Override
diff --git a/core/trino-main/src/main/java/io/trino/connector/DefaultCatalogFactory.java b/core/trino-main/src/main/java/io/trino/connector/DefaultCatalogFactory.java
index f16872bc4e9d..712ec147f958 100644
--- a/core/trino-main/src/main/java/io/trino/connector/DefaultCatalogFactory.java
+++ b/core/trino-main/src/main/java/io/trino/connector/DefaultCatalogFactory.java
@@ -190,6 +190,7 @@ private Connector createConnector(
Map<String, String> properties)
{
ConnectorContext context = new ConnectorContextInstance(
+ catalogHandle,
new ConnectorAwareNodeManager(nodeManager, nodeInfo.getEnvironment(), catalogHandle, schedulerIncludeCoordinator),
versionEmbedder,
typeManager,
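
With the catalog handle threaded into `ConnectorContextInstance`, a connector can now learn which catalog instance it is being created for. A hedged sketch of connector-side usage, assuming the matching getter is also added to the SPI's `ConnectorContext` interface (that change is not shown in this excerpt):

```java
import io.trino.spi.connector.CatalogHandle;
import io.trino.spi.connector.Connector;
import io.trino.spi.connector.ConnectorContext;
import io.trino.spi.connector.ConnectorFactory;

import java.util.Map;

public class ExampleConnectorFactory
        implements ConnectorFactory
{
    @Override
    public String getName()
    {
        return "example";
    }

    @Override
    public Connector create(String catalogName, Map<String, String> config, ConnectorContext context)
    {
        // The handle identifies the catalog instance this connector serves,
        // which is useful for keying caches or metrics per catalog.
        CatalogHandle catalogHandle = context.getCatalogHandle();
        throw new UnsupportedOperationException("sketch only, catalog: " + catalogHandle);
    }
}
```
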
diff --git a/core/trino-main/src/main/java/io/trino/connector/InternalMetadataProvider.java b/core/trino-main/src/main/java/io/trino/connector/InternalMetadataProvider.java
index 2f76f9778429..5eeb0447bf74 100644
--- a/core/trino-main/src/main/java/io/trino/connector/InternalMetadataProvider.java
+++ b/core/trino-main/src/main/java/io/trino/connector/InternalMetadataProvider.java
@@ -13,6 +13,7 @@
*/
package io.trino.connector;
+import com.google.common.collect.ImmutableList;
import io.trino.FullConnectorSession;
import io.trino.Session;
import io.trino.metadata.MaterializedViewDefinition;
@@ -54,12 +55,12 @@ public Optional<ConnectorTableSchema> getRelationMetadata(ConnectorSession conne
Optional<MaterializedViewDefinition> materializedView = metadata.getMaterializedView(session, qualifiedName);
if (materializedView.isPresent()) {
- return Optional.of(new ConnectorTableSchema(tableName.getSchemaTableName(), toColumnSchema(materializedView.get().getColumns())));
+ return Optional.of(new ConnectorTableSchema(tableName.getSchemaTableName(), toColumnSchema(materializedView.get().getColumns()), ImmutableList.of()));
}
Optional<ViewDefinition> view = metadata.getView(session, qualifiedName);
if (view.isPresent()) {
- return Optional.of(new ConnectorTableSchema(tableName.getSchemaTableName(), toColumnSchema(view.get().getColumns())));
+ return Optional.of(new ConnectorTableSchema(tableName.getSchemaTableName(), toColumnSchema(view.get().getColumns()), ImmutableList.of()));
}
Optional<TableHandle> tableHandle = metadata.getTableHandle(session, qualifiedName);
diff --git a/core/trino-main/src/main/java/io/trino/connector/informationschema/InformationSchemaPageSource.java b/core/trino-main/src/main/java/io/trino/connector/informationschema/InformationSchemaPageSource.java
index 29998ce48fad..c1c5ca958a99 100644
--- a/core/trino-main/src/main/java/io/trino/connector/informationschema/InformationSchemaPageSource.java
+++ b/core/trino-main/src/main/java/io/trino/connector/informationschema/InformationSchemaPageSource.java
@@ -54,6 +54,7 @@
import static io.trino.connector.informationschema.InformationSchemaMetadata.defaultPrefixes;
import static io.trino.connector.informationschema.InformationSchemaMetadata.isTablesEnumeratingTable;
import static io.trino.metadata.MetadataListing.getViews;
+import static io.trino.metadata.MetadataListing.listMaterializedViews;
import static io.trino.metadata.MetadataListing.listSchemas;
import static io.trino.metadata.MetadataListing.listTableColumns;
import static io.trino.metadata.MetadataListing.listTablePrivileges;
@@ -271,12 +272,18 @@ private void addColumnsRecords(QualifiedTablePrefix prefix)
private void addTablesRecords(QualifiedTablePrefix prefix)
{
Set<SchemaTableName> tables = listTables(session, metadata, accessControl, prefix);
+ Set<SchemaTableName> materializedViews = listMaterializedViews(session, metadata, accessControl, prefix);
Set<SchemaTableName> views = listViews(session, metadata, accessControl, prefix);
- // TODO (https://github.com/trinodb/trino/issues/8207) define a type for materialized views
- for (SchemaTableName name : union(tables, views)) {
+ for (SchemaTableName name : union(union(tables, materializedViews), views)) {
// if names overlap, the materialized view wins over the view, and the view wins over the table
- String type = views.contains(name) ? "VIEW" : "BASE TABLE";
+ String type = "BASE TABLE";
+ if (materializedViews.contains(name)) {
+ type = "MATERIALIZED VIEW";
+ }
+ else if (views.contains(name)) {
+ type = "VIEW";
+ }
addRecord(
prefix.getCatalogName(),
name.getSchemaName(),
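
The precedence in `addTablesRecords` is worth spelling out: when a name appears in more than one listing, the materialized view label wins, then the view label. A standalone restatement of that rule:

```java
import java.util.Set;

final class TableTypeResolution
{
    private TableTypeResolution() {}

    // Mirrors the logic above: materialized view > view > base table.
    static String tableType(String name, Set<String> materializedViews, Set<String> views)
    {
        if (materializedViews.contains(name)) {
            return "MATERIALIZED VIEW";
        }
        if (views.contains(name)) {
            return "VIEW";
        }
        return "BASE TABLE";
    }
}
```
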
diff --git a/core/trino-main/src/main/java/io/trino/cost/TaskCountEstimator.java b/core/trino-main/src/main/java/io/trino/cost/TaskCountEstimator.java
index 3eaa86ecac2f..488bb5749281 100644
--- a/core/trino-main/src/main/java/io/trino/cost/TaskCountEstimator.java
+++ b/core/trino-main/src/main/java/io/trino/cost/TaskCountEstimator.java
@@ -26,7 +26,7 @@
import static io.trino.SystemSessionProperties.getCostEstimationWorkerCount;
import static io.trino.SystemSessionProperties.getFaultTolerantExecutionPartitionCount;
-import static io.trino.SystemSessionProperties.getHashPartitionCount;
+import static io.trino.SystemSessionProperties.getMaxHashPartitionCount;
import static io.trino.SystemSessionProperties.getRetryPolicy;
import static java.lang.Math.min;
import static java.lang.Math.toIntExact;
@@ -73,7 +73,7 @@ public int estimateHashedTaskCount(Session session)
partitionCount = getFaultTolerantExecutionPartitionCount(session);
}
else {
- partitionCount = getHashPartitionCount(session);
+ partitionCount = getMaxHashPartitionCount(session);
}
return min(estimateSourceDistributedTaskCount(session), partitionCount);
}
diff --git a/core/trino-main/src/main/java/io/trino/execution/QueryManagerConfig.java b/core/trino-main/src/main/java/io/trino/execution/QueryManagerConfig.java
index e391706a4c45..5ffe096f1b96 100644
--- a/core/trino-main/src/main/java/io/trino/execution/QueryManagerConfig.java
+++ b/core/trino-main/src/main/java/io/trino/execution/QueryManagerConfig.java
@@ -53,7 +53,10 @@ public class QueryManagerConfig
private int maxConcurrentQueries = 1000;
private int maxQueuedQueries = 5000;
- private int hashPartitionCount = 100;
+ private int maxHashPartitionCount = 100;
+ private int minHashPartitionCount = 4;
+ private int maxWritersNodesCount = 100;
+
private Duration minQueryExpireAge = new Duration(15, TimeUnit.MINUTES);
private int maxQueryHistory = 100;
private int maxQueryLength = 1_000_000;
@@ -159,17 +162,46 @@ public QueryManagerConfig setMaxQueuedQueries(int maxQueuedQueries)
}
@Min(1)
- public int getHashPartitionCount()
+ public int getMaxHashPartitionCount()
+ {
+ return maxHashPartitionCount;
+ }
+
+ @Config("query.max-hash-partition-count")
+ @LegacyConfig({"query.initial-hash-partitions", "query.hash-partition-count"})
+ @ConfigDescription("Maximum number of partitions for distributed joins and aggregations")
+ public QueryManagerConfig setMaxHashPartitionCount(int maxHashPartitionCount)
+ {
+ this.maxHashPartitionCount = maxHashPartitionCount;
+ return this;
+ }
+
+ @Min(1)
+ public int getMinHashPartitionCount()
+ {
+ return minHashPartitionCount;
+ }
+
+ @Config("query.min-hash-partition-count")
+ @ConfigDescription("Minimum number of partitions for distributed joins and aggregations")
+ public QueryManagerConfig setMinHashPartitionCount(int minHashPartitionCount)
+ {
+ this.minHashPartitionCount = minHashPartitionCount;
+ return this;
+ }
+
+ @Min(1)
+ public int getMaxWritersNodesCount()
{
- return hashPartitionCount;
+ return maxWritersNodesCount;
}
- @Config("query.hash-partition-count")
- @LegacyConfig("query.initial-hash-partitions")
- @ConfigDescription("Number of partitions for distributed joins and aggregations")
- public QueryManagerConfig setHashPartitionCount(int hashPartitionCount)
+ @Config("query.max-writer-node-count")
+ @ConfigDescription("Maximum number of nodes that will take part in writer tasks. It is an upper bound on scaling of writers " +
+ "and works only if task.scale-writers.enabled is set")
+ public QueryManagerConfig setMaxWritersNodesCount(int maxWritersNodesCount)
{
- this.hashPartitionCount = hashPartitionCount;
+ this.maxWritersNodesCount = maxWritersNodesCount;
return this;
}
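
The renamed config keeps its old default (100) and stays reachable under both legacy names, while the new minimum and writer-node settings default to 4 and 100. A quick sketch of the resulting defaults (real Trino config tests use airlift's ConfigAssertions rather than prints):

```java
import io.trino.execution.QueryManagerConfig;

public final class QueryManagerConfigDefaults
{
    public static void main(String[] args)
    {
        QueryManagerConfig config = new QueryManagerConfig();
        // query.max-hash-partition-count also answers to the legacy names
        // query.hash-partition-count and query.initial-hash-partitions.
        System.out.println(config.getMaxHashPartitionCount()); // 100
        System.out.println(config.getMinHashPartitionCount()); // 4
        System.out.println(config.getMaxWritersNodesCount());  // 100
    }
}
```
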
diff --git a/core/trino-main/src/main/java/io/trino/execution/SqlTask.java b/core/trino-main/src/main/java/io/trino/execution/SqlTask.java
index afcbc76d6997..975c6d310d8f 100644
--- a/core/trino-main/src/main/java/io/trino/execution/SqlTask.java
+++ b/core/trino-main/src/main/java/io/trino/execution/SqlTask.java
@@ -278,6 +278,7 @@ private TaskStatus createTaskStatus(TaskHolder taskHolder)
long runningPartitionedSplitsWeight = 0L;
DataSize outputDataSize = DataSize.ofBytes(0);
DataSize physicalWrittenDataSize = DataSize.ofBytes(0);
+ Optional<Integer> writerCount = Optional.empty();
DataSize userMemoryReservation = DataSize.ofBytes(0);
DataSize peakUserMemoryReservation = DataSize.ofBytes(0);
DataSize revocableMemoryReservation = DataSize.ofBytes(0);
@@ -292,6 +293,7 @@ private TaskStatus createTaskStatus(TaskHolder taskHolder)
runningPartitionedDrivers = taskStats.getRunningPartitionedDrivers();
runningPartitionedSplitsWeight = taskStats.getRunningPartitionedSplitsWeight();
physicalWrittenDataSize = taskStats.getPhysicalWrittenDataSize();
+ writerCount = taskStats.getMaxWriterCount();
userMemoryReservation = taskStats.getUserMemoryReservation();
peakUserMemoryReservation = taskStats.getPeakUserMemoryReservation();
revocableMemoryReservation = taskStats.getRevocableMemoryReservation();
@@ -312,6 +314,7 @@ else if (taskHolder.getTaskExecution() != null) {
physicalWrittenBytes += pipelineContext.getPhysicalWrittenDataSize();
}
physicalWrittenDataSize = succinctBytes(physicalWrittenBytes);
+ writerCount = taskContext.getMaxWriterCount();
userMemoryReservation = taskContext.getMemoryReservation();
peakUserMemoryReservation = taskContext.getPeakMemoryReservation();
revocableMemoryReservation = taskContext.getRevocableMemoryReservation();
@@ -334,6 +337,7 @@ else if (taskHolder.getTaskExecution() != null) {
outputBuffer.getStatus(),
outputDataSize,
physicalWrittenDataSize,
+ writerCount,
userMemoryReservation,
peakUserMemoryReservation,
revocableMemoryReservation,
diff --git a/core/trino-main/src/main/java/io/trino/execution/TaskManagerConfig.java b/core/trino-main/src/main/java/io/trino/execution/TaskManagerConfig.java
index 846e17136931..dc27dea56ca3 100644
--- a/core/trino-main/src/main/java/io/trino/execution/TaskManagerConfig.java
+++ b/core/trino-main/src/main/java/io/trino/execution/TaskManagerConfig.java
@@ -64,6 +64,7 @@ public class TaskManagerConfig
private DataSize sinkMaxBufferSize = DataSize.of(32, Unit.MEGABYTE);
private DataSize sinkMaxBroadcastBufferSize = DataSize.of(200, Unit.MEGABYTE);
private DataSize maxPagePartitioningBufferSize = DataSize.of(32, Unit.MEGABYTE);
+ private int pagePartitioningBufferPoolSize = 8;
private Duration clientTimeout = new Duration(2, TimeUnit.MINUTES);
private Duration infoMaxAge = new Duration(15, TimeUnit.MINUTES);
@@ -77,10 +78,11 @@ public class TaskManagerConfig
private Duration interruptStuckSplitTasksDetectionInterval = new Duration(2, TimeUnit.MINUTES);
private boolean scaleWritersEnabled = true;
- // The default value is 8 because it is better in performance compare to 2 or 4
- // and acceptable in terms of resource utilization since values like 32 or higher could take
- // more resources, hence potentially affect the other concurrent queries in the cluster.
- private int scaleWritersMaxWriterCount = 8;
+ // Set the default max writer count to the number of processors, capped at 32. We can do this
+ // because preferred write partitioning is always enabled for local exchange, so partitioned inserts
+ // will never use this property. Hence there is no risk of a large number of physical writers
+ // causing high resource utilization.
+ private int scaleWritersMaxWriterCount = min(getAvailablePhysicalProcessorCount(), 32);
private int writerCount = 1;
// Default value of partitioned task writer count should be above 1, otherwise it can create a plan
// with a single gather exchange node on the coordinator due to a single available processor. Whereas,
@@ -377,6 +379,20 @@ public TaskManagerConfig setMaxPagePartitioningBufferSize(DataSize size)
return this;
}
+ @Min(0)
+ public int getPagePartitioningBufferPoolSize()
+ {
+ return pagePartitioningBufferPoolSize;
+ }
+
+ @Config("driver.page-partitioning-buffer-pool-size")
+ @ConfigDescription("Maximum number of free buffers in the per task partitioned page buffer pool. Setting this to zero effectively disables the pool")
+ public TaskManagerConfig setPagePartitioningBufferPoolSize(int pagePartitioningBufferPoolSize)
+ {
+ this.pagePartitioningBufferPoolSize = pagePartitioningBufferPoolSize;
+ return this;
+ }
+
@MinDuration("5s")
@NotNull
public Duration getClientTimeout()
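
`driver.page-partitioning-buffer-pool-size` bounds how many released buffers a task keeps around for reuse; zero disables pooling. A minimal generic sketch of that bounded free-list idea (not the PR's actual pool class):

```java
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.function.Supplier;

final class BoundedBufferPool<T>
{
    private final Deque<T> free = new ArrayDeque<>();
    private final Supplier<T> factory;
    private final int maxFree;

    BoundedBufferPool(Supplier<T> factory, int maxFree)
    {
        this.factory = factory;
        this.maxFree = maxFree;
    }

    synchronized T acquire()
    {
        // Reuse a pooled buffer when available, otherwise allocate a fresh one.
        T buffer = free.pollFirst();
        return buffer != null ? buffer : factory.get();
    }

    synchronized void release(T buffer)
    {
        // Keep at most maxFree buffers; with maxFree == 0 every buffer is dropped,
        // which effectively disables pooling.
        if (free.size() < maxFree) {
            free.addFirst(buffer);
        }
    }
}
```
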
diff --git a/core/trino-main/src/main/java/io/trino/execution/TaskStatus.java b/core/trino-main/src/main/java/io/trino/execution/TaskStatus.java
index 839154c85a39..50f1aac49b0d 100644
--- a/core/trino-main/src/main/java/io/trino/execution/TaskStatus.java
+++ b/core/trino-main/src/main/java/io/trino/execution/TaskStatus.java
@@ -22,6 +22,7 @@
import java.net.URI;
import java.util.List;
+import java.util.Optional;
import static com.google.common.base.MoreObjects.toStringHelper;
import static com.google.common.base.Preconditions.checkArgument;
@@ -59,6 +60,7 @@ public class TaskStatus
private final OutputBufferStatus outputBufferStatus;
private final DataSize outputDataSize;
private final DataSize physicalWrittenDataSize;
+ private final Optional<Integer> maxWriterCount;
private final DataSize memoryReservation;
private final DataSize peakMemoryReservation;
private final DataSize revocableMemoryReservation;
@@ -84,6 +86,7 @@ public TaskStatus(
@JsonProperty("outputBufferStatus") OutputBufferStatus outputBufferStatus,
@JsonProperty("outputDataSize") DataSize outputDataSize,
@JsonProperty("physicalWrittenDataSize") DataSize physicalWrittenDataSize,
+ @JsonProperty("writerCount") Optional maxWriterCount,
@JsonProperty("memoryReservation") DataSize memoryReservation,
@JsonProperty("peakMemoryReservation") DataSize peakMemoryReservation,
@JsonProperty("revocableMemoryReservation") DataSize revocableMemoryReservation,
@@ -116,6 +119,7 @@ public TaskStatus(
this.outputDataSize = requireNonNull(outputDataSize, "outputDataSize is null");
this.physicalWrittenDataSize = requireNonNull(physicalWrittenDataSize, "physicalWrittenDataSize is null");
+ this.maxWriterCount = requireNonNull(maxWriterCount, "maxWriterCount is null");
this.memoryReservation = requireNonNull(memoryReservation, "memoryReservation is null");
this.peakMemoryReservation = requireNonNull(peakMemoryReservation, "peakMemoryReservation is null");
@@ -189,6 +193,12 @@ public DataSize getPhysicalWrittenDataSize()
return physicalWrittenDataSize;
}
+ @JsonProperty
+ public Optional<Integer> getMaxWriterCount()
+ {
+ return maxWriterCount;
+ }
+
@JsonProperty
public OutputBufferStatus getOutputBufferStatus()
{
@@ -273,6 +283,7 @@ public static TaskStatus initialTaskStatus(TaskId taskId, URI location, String n
OutputBufferStatus.initial(),
DataSize.ofBytes(0),
DataSize.ofBytes(0),
+ Optional.empty(),
DataSize.ofBytes(0),
DataSize.ofBytes(0),
DataSize.ofBytes(0),
@@ -298,6 +309,7 @@ public static TaskStatus failWith(TaskStatus taskStatus, TaskState state, List<ExecutionFailureInfo> failures)
diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/PipelinedQueryScheduler.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/PipelinedQueryScheduler.java
--- a/core/trino-main/src/main/java/io/trino/execution/scheduler/PipelinedQueryScheduler.java
+++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/PipelinedQueryScheduler.java
Map<PartitioningHandle, NodePartitionMap> partitioningCacheMap = new HashMap<>();
- Function<PartitioningHandle, NodePartitionMap> partitioningCache = partitioningHandle ->
+ BiFunction<PartitioningHandle, Optional<Integer>, NodePartitionMap> partitioningCache = (partitioningHandle, partitionCount) ->
partitioningCacheMap.computeIfAbsent(partitioningHandle, handle -> nodePartitioningManager.getNodePartitioningMap(
queryStateMachine.getSession(),
// TODO: support hash distributed writer scaling (https://github.com/trinodb/trino/issues/10791)
- handle.equals(SCALED_WRITER_HASH_DISTRIBUTION) ? FIXED_HASH_DISTRIBUTION : handle));
+ handle.equals(SCALED_WRITER_HASH_DISTRIBUTION) ? FIXED_HASH_DISTRIBUTION : handle,
+ partitionCount));
Map<PlanFragmentId, Optional<int[]>> bucketToPartitionMap = createBucketToPartitionMap(
coordinatorStagesScheduler.getBucketToPartitionForStagesConsumedByCoordinator(),
@@ -937,13 +936,18 @@ public static DistributedStagesScheduler create(
private static Map<PlanFragmentId, Optional<int[]>> createBucketToPartitionMap(
Map<PlanFragmentId, Optional<int[]>> bucketToPartitionForStagesConsumedByCoordinator,
StageManager stageManager,
- Function<PartitioningHandle, NodePartitionMap> partitioningCache)
+ BiFunction<PartitioningHandle, Optional<Integer>, NodePartitionMap> partitioningCache)
{
ImmutableMap.Builder<PlanFragmentId, Optional<int[]>> result = ImmutableMap.builder();
result.putAll(bucketToPartitionForStagesConsumedByCoordinator);
for (SqlStage stage : stageManager.getDistributedStagesInTopologicalOrder()) {
PlanFragment fragment = stage.getFragment();
- Optional<int[]> bucketToPartition = getBucketToPartition(fragment.getPartitioning(), partitioningCache, fragment.getRoot(), fragment.getRemoteSourceNodes());
+ Optional<int[]> bucketToPartition = getBucketToPartition(
+ fragment.getPartitioning(),
+ partitioningCache,
+ fragment.getRoot(),
+ fragment.getRemoteSourceNodes(),
+ fragment.getPartitioningScheme().getPartitionCount());
for (SqlStage childStage : stageManager.getChildren(stage.getStageId())) {
result.put(childStage.getFragment().getId(), bucketToPartition);
}
@@ -953,9 +957,10 @@ private static Map<PlanFragmentId, Optional<int[]>> createBucketToPartitionMap(
private static Optional<int[]> getBucketToPartition(
PartitioningHandle partitioningHandle,
- Function<PartitioningHandle, NodePartitionMap> partitioningCache,
+ BiFunction<PartitioningHandle, Optional<Integer>, NodePartitionMap> partitioningCache,
PlanNode fragmentRoot,
- List<RemoteSourceNode> remoteSourceNodes)
+ List<RemoteSourceNode> remoteSourceNodes,
+ Optional<Integer> partitionCount)
{
if (partitioningHandle.equals(SOURCE_DISTRIBUTION) || partitioningHandle.equals(SCALED_WRITER_ROUND_ROBIN_DISTRIBUTION)) {
return Optional.of(new int[1]);
@@ -965,10 +970,10 @@ private static Optional getBucketToPartition(
return Optional.empty();
}
// remote source requires nodePartitionMap
- NodePartitionMap nodePartitionMap = partitioningCache.apply(partitioningHandle);
+ NodePartitionMap nodePartitionMap = partitioningCache.apply(partitioningHandle, partitionCount);
return Optional.of(nodePartitionMap.getBucketToPartition());
}
- NodePartitionMap nodePartitionMap = partitioningCache.apply(partitioningHandle);
+ NodePartitionMap nodePartitionMap = partitioningCache.apply(partitioningHandle, partitionCount);
List<InternalNode> partitionToNode = nodePartitionMap.getPartitionToNode();
// todo this should asynchronously wait a standard timeout period before failing
checkCondition(!partitionToNode.isEmpty(), NO_NODES_AVAILABLE, "No worker nodes available");
@@ -1011,7 +1016,7 @@ private static StageScheduler createStageScheduler(
StageExecution stageExecution,
SplitSourceFactory splitSourceFactory,
List<StageExecution> childStageExecutions,
- Function<PartitioningHandle, NodePartitionMap> partitioningCache,
+ BiFunction<PartitioningHandle, Optional<Integer>, NodePartitionMap> partitioningCache,
NodeScheduler nodeScheduler,
NodePartitioningManager nodePartitioningManager,
int splitBatchSize,
@@ -1022,6 +1027,7 @@ private static StageScheduler createStageScheduler(
Session session = queryStateMachine.getSession();
PlanFragment fragment = stageExecution.getFragment();
PartitioningHandle partitioningHandle = fragment.getPartitioning();
+ Optional<Integer> partitionCount = fragment.getPartitioningScheme().getPartitionCount();
Map<PlanNodeId, SplitSource> splitSources = splitSourceFactory.createSplitSources(session, fragment);
if (!splitSources.isEmpty()) {
queryStateMachine.addStateChangeListener(new StateChangeListener<>()
@@ -1077,7 +1083,7 @@ public void stateChanged(QueryState newState)
nodeScheduler.createNodeSelector(session, Optional.empty()),
executor,
getWriterMinSize(session),
- isTaskScaleWritersEnabled(session) ? getTaskScaleWritersMaxWriterCount(session) : getTaskWriterCount(session));
+ getMaxWritersNodesCount(session));
whenAllStages(childStageExecutions, StageExecution.State::isDone)
.addListener(scheduler::finish, directExecutor());
@@ -1087,7 +1093,7 @@ public void stateChanged(QueryState newState)
if (splitSources.isEmpty()) {
// all sources are remote
- NodePartitionMap nodePartitionMap = partitioningCache.apply(partitioningHandle);
+ NodePartitionMap nodePartitionMap = partitioningCache.apply(partitioningHandle, partitionCount);
List<InternalNode> partitionToNode = nodePartitionMap.getPartitionToNode();
// todo this should asynchronously wait a standard timeout period before failing
checkCondition(!partitionToNode.isEmpty(), NO_NODES_AVAILABLE, "No worker nodes available");
@@ -1109,7 +1115,7 @@ public void stateChanged(QueryState newState)
}
else {
// remote source requires nodePartitionMap
- NodePartitionMap nodePartitionMap = partitioningCache.apply(partitioningHandle);
+ NodePartitionMap nodePartitionMap = partitioningCache.apply(partitioningHandle, partitionCount);
stageNodeList = nodePartitionMap.getPartitionToNode();
bucketNodeMap = nodePartitionMap.asBucketNodeMap();
}
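
One subtlety in the scheduler change: `partitioningCache` now takes a partition count, but the backing map is still keyed by the handle alone, so the count passed on the first (cache-miss) call for a handle is the one that sticks. A toy demonstration of that memoization shape:

```java
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import java.util.function.BiFunction;

public final class MemoizedLookup
{
    public static void main(String[] args)
    {
        Map<String, Integer> cache = new HashMap<>();
        // The cache key is the handle alone; the second argument only influences
        // the value computed on the first call for that handle.
        BiFunction<String, Optional<Integer>, Integer> lookup = (handle, partitionCount) ->
                cache.computeIfAbsent(handle, h -> partitionCount.orElse(100));

        System.out.println(lookup.apply("h1", Optional.of(8)));  // 8 (computed)
        System.out.println(lookup.apply("h1", Optional.of(99))); // 8 (cached, 99 ignored)
    }
}
```
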
diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/ScaledWriterScheduler.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/ScaledWriterScheduler.java
index 828cdba1bf79..39dd15a8a281 100644
--- a/core/trino-main/src/main/java/io/trino/execution/scheduler/ScaledWriterScheduler.java
+++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/ScaledWriterScheduler.java
@@ -48,9 +48,9 @@ public class ScaledWriterScheduler
private final NodeSelector nodeSelector;
private final ScheduledExecutorService executor;
private final long writerMinSizeBytes;
- private final int maxTaskWriterCount;
private final Set<InternalNode> scheduledNodes = new HashSet<>();
private final AtomicBoolean done = new AtomicBoolean();
+ private final int maxWritersNodesCount;
private volatile SettableFuture<Void> future = SettableFuture.create();
public ScaledWriterScheduler(
@@ -60,7 +60,7 @@ public ScaledWriterScheduler(
NodeSelector nodeSelector,
ScheduledExecutorService executor,
DataSize writerMinSize,
- int maxTaskWriterCount)
+ int maxWritersNodesCount)
{
this.stage = requireNonNull(stage, "stage is null");
this.sourceTasksProvider = requireNonNull(sourceTasksProvider, "sourceTasksProvider is null");
@@ -68,7 +68,7 @@ public ScaledWriterScheduler(
this.nodeSelector = requireNonNull(nodeSelector, "nodeSelector is null");
this.executor = requireNonNull(executor, "executor is null");
this.writerMinSizeBytes = writerMinSize.toBytes();
- this.maxTaskWriterCount = maxTaskWriterCount;
+ this.maxWritersNodesCount = maxWritersNodesCount;
}
public void finish()
@@ -95,21 +95,44 @@ private int getNewTaskCount()
return 1;
}
- long writtenBytes = writerTasksProvider.get().stream()
- .map(TaskStatus::getPhysicalWrittenDataSize)
- .mapToLong(DataSize::toBytes)
- .sum();
+ Collection<TaskStatus> writerTasks = writerTasksProvider.get();
+ // Do not scale tasks until all existing writer tasks are initialized with maxWriterCount
+ if (writerTasks.size() != scheduledNodes.size()
+ || writerTasks.stream().map(TaskStatus::getMaxWriterCount).anyMatch(Optional::isEmpty)) {
+ return 0;
+ }
// When data is heavily skewed, the skewed workers can become a bottleneck even if most workers are not over-utilized.
// Check both the weighted and the average output buffer over-utilization rate, to cover the case where many small
// tasks are over-utilized alongside a few big skewed tasks that are not.
- if ((isWeightedBufferFull() || isAverageBufferFull()) && (writtenBytes >= (writerMinSizeBytes * maxTaskWriterCount * scheduledNodes.size()))) {
+ if (isSourceTasksBufferFull() && isWriteThroughputSufficient() && scheduledNodes.size() < maxWritersNodesCount) {
return 1;
}
return 0;
}
+ private boolean isSourceTasksBufferFull()
+ {
+ return isAverageBufferFull() || isWeightedBufferFull();
+ }
+
+ private boolean isWriteThroughputSufficient()
+ {
+ Collection<TaskStatus> writerTasks = writerTasksProvider.get();
+ long writtenBytes = writerTasks.stream()
+ .map(TaskStatus::getPhysicalWrittenDataSize)
+ .mapToLong(DataSize::toBytes)
+ .sum();
+
+ long minWrittenBytesToScaleUp = writerTasks.stream()
+ .map(TaskStatus::getMaxWriterCount)
+ .map(Optional::get)
+ .mapToLong(writerCount -> writerMinSizeBytes * writerCount)
+ .sum();
+ return writtenBytes >= minWrittenBytesToScaleUp;
+ }
+
private boolean isWeightedBufferFull()
{
double totalOutputSize = 0.0;
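
The scale-up rule now reads: add one writer node only when the source output buffers are saturated, the observed write throughput clears the per-task minimum, and the node limit is not yet reached. A compact paraphrase under assumed names (not the class itself):

```java
final class ScaleUpDecision
{
    private ScaleUpDecision() {}

    static int newTaskCount(
            boolean sourceBuffersFull,
            long writtenBytes,
            long writerMinSizeBytes,
            int[] perTaskMaxWriterCounts, // one entry per initialized writer task
            int scheduledNodes,
            int maxWritersNodesCount)
    {
        // Each task must have written at least writerMinSize per writer before scaling.
        long minWrittenBytesToScaleUp = 0;
        for (int writerCount : perTaskMaxWriterCounts) {
            minWrittenBytesToScaleUp += writerMinSizeBytes * writerCount;
        }
        boolean throughputSufficient = writtenBytes >= minWrittenBytesToScaleUp;
        if (sourceBuffersFull && throughputSufficient && scheduledNodes < maxWritersNodesCount) {
            return 1; // add a single node at a time
        }
        return 0;
    }
}
```
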
diff --git a/core/trino-main/src/main/java/io/trino/metadata/DisabledSystemSecurityMetadata.java b/core/trino-main/src/main/java/io/trino/metadata/DisabledSystemSecurityMetadata.java
index efedd3b9e545..26c5ca2dc8a3 100644
--- a/core/trino-main/src/main/java/io/trino/metadata/DisabledSystemSecurityMetadata.java
+++ b/core/trino-main/src/main/java/io/trino/metadata/DisabledSystemSecurityMetadata.java
@@ -184,6 +184,15 @@ public void tableRenamed(Session session, CatalogSchemaTableName sourceTable, Ca
@Override
public void tableDropped(Session session, CatalogSchemaTableName table) {}
+ @Override
+ public void columnCreated(Session session, CatalogSchemaTableName table, String column) {}
+
+ @Override
+ public void columnRenamed(Session session, CatalogSchemaTableName table, String oldColumn, String newColumn) {}
+
+ @Override
+ public void columnDropped(Session session, CatalogSchemaTableName table, String column) {}
+
private static TrinoException notSupportedException(String catalogName)
{
return new TrinoException(NOT_SUPPORTED, "Catalog does not support permission management: " + catalogName);
diff --git a/core/trino-main/src/main/java/io/trino/metadata/MetadataManager.java b/core/trino-main/src/main/java/io/trino/metadata/MetadataManager.java
index 5d3396433e95..42efa512ece5 100644
--- a/core/trino-main/src/main/java/io/trino/metadata/MetadataManager.java
+++ b/core/trino-main/src/main/java/io/trino/metadata/MetadataManager.java
@@ -726,6 +726,14 @@ public void renameColumn(Session session, TableHandle tableHandle, ColumnHandle
CatalogHandle catalogHandle = tableHandle.getCatalogHandle();
ConnectorMetadata metadata = getMetadataForWrite(session, catalogHandle);
metadata.renameColumn(session.toConnectorSession(catalogHandle), tableHandle.getConnectorHandle(), source, target.toLowerCase(ENGLISH));
+
+ CatalogMetadata catalogMetadata = getCatalogMetadataForWrite(session, catalogHandle);
+ ColumnMetadata sourceColumnMetadata = getColumnMetadata(session, tableHandle, source);
+ if (catalogMetadata.getSecurityManagement() != CONNECTOR) {
+ TableMetadata tableMetadata = getTableMetadata(session, tableHandle);
+ CatalogSchemaTableName sourceTableName = new CatalogSchemaTableName(catalogHandle.getCatalogName(), tableMetadata.getTable());
+ systemSecurityMetadata.columnRenamed(session, sourceTableName, sourceColumnMetadata.getName(), target);
+ }
}
@Override
@@ -734,6 +742,13 @@ public void addColumn(Session session, TableHandle tableHandle, ColumnMetadata c
CatalogHandle catalogHandle = tableHandle.getCatalogHandle();
ConnectorMetadata metadata = getMetadataForWrite(session, catalogHandle);
metadata.addColumn(session.toConnectorSession(catalogHandle), tableHandle.getConnectorHandle(), column);
+
+ CatalogMetadata catalogMetadata = getCatalogMetadataForWrite(session, catalogHandle);
+ if (catalogMetadata.getSecurityManagement() != CONNECTOR) {
+ TableMetadata tableMetadata = getTableMetadata(session, tableHandle);
+ CatalogSchemaTableName sourceTableName = new CatalogSchemaTableName(catalogHandle.getCatalogName(), tableMetadata.getTable());
+ systemSecurityMetadata.columnCreated(session, sourceTableName, column.getName());
+ }
}
@Override
@@ -741,7 +756,14 @@ public void dropColumn(Session session, TableHandle tableHandle, ColumnHandle co
{
CatalogHandle catalogHandle = tableHandle.getCatalogHandle();
ConnectorMetadata metadata = getMetadataForWrite(session, catalogHandle);
+ CatalogMetadata catalogMetadata = getCatalogMetadataForWrite(session, catalogHandle);
metadata.dropColumn(session.toConnectorSession(catalogHandle), tableHandle.getConnectorHandle(), column);
+ if (catalogMetadata.getSecurityManagement() != CONNECTOR) {
+ String columnName = getColumnMetadata(session, tableHandle, column).getName();
+ TableMetadata tableMetadata = getTableMetadata(session, tableHandle);
+ CatalogSchemaTableName sourceTableName = new CatalogSchemaTableName(catalogHandle.getCatalogName(), tableMetadata.getTable());
+ systemSecurityMetadata.columnDropped(session, sourceTableName, columnName);
+ }
}
@Override
diff --git a/core/trino-main/src/main/java/io/trino/metadata/SystemSecurityMetadata.java b/core/trino-main/src/main/java/io/trino/metadata/SystemSecurityMetadata.java
index 7faa8d06d795..b8ba4de03796 100644
--- a/core/trino-main/src/main/java/io/trino/metadata/SystemSecurityMetadata.java
+++ b/core/trino-main/src/main/java/io/trino/metadata/SystemSecurityMetadata.java
@@ -167,4 +167,19 @@ public interface SystemSecurityMetadata
* A table or view was dropped
*/
void tableDropped(Session session, CatalogSchemaTableName table);
+
+ /**
+ * A column was created
+ */
+ void columnCreated(Session session, CatalogSchemaTableName table, String column);
+
+ /**
+ * A column was renamed
+ */
+ void columnRenamed(Session session, CatalogSchemaTableName table, String oldColumn, String newColumn);
+
+ /**
+ * A column was dropped
+ */
+ void columnDropped(Session session, CatalogSchemaTableName table, String column);
}
diff --git a/core/trino-main/src/main/java/io/trino/operator/DirectExchangeClient.java b/core/trino-main/src/main/java/io/trino/operator/DirectExchangeClient.java
index 5eed6d8fe2d6..85c726045e13 100644
--- a/core/trino-main/src/main/java/io/trino/operator/DirectExchangeClient.java
+++ b/core/trino-main/src/main/java/io/trino/operator/DirectExchangeClient.java
@@ -19,6 +19,7 @@
import io.airlift.http.client.HttpClient;
import io.airlift.log.Logger;
import io.airlift.slice.Slice;
+import io.airlift.stats.TDigest;
import io.airlift.units.DataSize;
import io.airlift.units.Duration;
import io.trino.FeaturesConfig.DataIntegrityVerification;
@@ -27,6 +28,7 @@
import io.trino.memory.context.LocalMemoryContext;
import io.trino.operator.HttpPageBufferClient.ClientCallback;
import io.trino.operator.WorkProcessor.ProcessState;
+import io.trino.plugin.base.metrics.TDigestHistogram;
import javax.annotation.Nullable;
import javax.annotation.concurrent.GuardedBy;
@@ -83,6 +85,8 @@ public class DirectExchangeClient
private long averageBytesPerRequest;
@GuardedBy("this")
private boolean closed;
+ @GuardedBy("this")
+ private final TDigest requestDuration = new TDigest();
@GuardedBy("memoryContextLock")
@Nullable
@@ -143,7 +147,8 @@ public DirectExchangeClientStatus getStatus()
buffer.getSpilledPageCount(),
buffer.getSpilledBytes(),
noMoreLocations,
- pageBufferClientStatus);
+ pageBufferClientStatus,
+ new TDigestHistogram(TDigest.copyOf(requestDuration)));
}
}
@@ -369,6 +374,7 @@ private void releaseMemoryContext()
private synchronized void requestComplete(HttpPageBufferClient client)
{
+ requestDuration.add(client.getLastRequestDurationMillis());
if (!completedClients.contains(client) && !queuedClients.contains(client)) {
queuedClients.add(client);
}
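
Request latencies now feed a t-digest, so the exchange status can report percentiles rather than just averages. A small sketch of that recording pattern, assuming airlift's `io.airlift.stats.TDigest` as imported above:

```java
import io.airlift.stats.TDigest;

public final class RequestDurationDigest
{
    public static void main(String[] args)
    {
        TDigest requestDuration = new TDigest();
        for (long millis : new long[] {12, 15, 9, 250, 14}) {
            requestDuration.add(millis); // one sample per completed request
        }
        // Copy before publishing, as in getStatus(), so readers see a stable snapshot.
        TDigest snapshot = TDigest.copyOf(requestDuration);
        System.out.printf("p50=%.0fms p95=%.0fms%n", snapshot.valueAt(0.5), snapshot.valueAt(0.95));
    }
}
```
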
diff --git a/core/trino-main/src/main/java/io/trino/operator/DirectExchangeClientStatus.java b/core/trino-main/src/main/java/io/trino/operator/DirectExchangeClientStatus.java
index 6b130814e9f0..4d1bf5efd568 100644
--- a/core/trino-main/src/main/java/io/trino/operator/DirectExchangeClientStatus.java
+++ b/core/trino-main/src/main/java/io/trino/operator/DirectExchangeClientStatus.java
@@ -16,6 +16,7 @@
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.collect.ImmutableList;
+import io.trino.plugin.base.metrics.TDigestHistogram;
import io.trino.spi.Mergeable;
import java.util.List;
@@ -35,6 +36,7 @@ public class DirectExchangeClientStatus
private final long spilledBytes;
private final boolean noMoreLocations;
private final List<PageBufferClientStatus> pageBufferClientStatuses;
+ private final TDigestHistogram requestDuration;
@JsonCreator
public DirectExchangeClientStatus(
@@ -46,7 +48,8 @@ public DirectExchangeClientStatus(
@JsonProperty("spilledPages") int spilledPages,
@JsonProperty("spilledBytes") long spilledBytes,
@JsonProperty("noMoreLocations") boolean noMoreLocations,
- @JsonProperty("pageBufferClientStatuses") List pageBufferClientStatuses)
+ @JsonProperty("pageBufferClientStatuses") List pageBufferClientStatuses,
+ @JsonProperty("requestDuration") TDigestHistogram requestDuration)
{
this.bufferedBytes = bufferedBytes;
this.maxBufferedBytes = maxBufferedBytes;
@@ -57,6 +60,7 @@ public DirectExchangeClientStatus(
this.spilledBytes = spilledBytes;
this.noMoreLocations = noMoreLocations;
this.pageBufferClientStatuses = ImmutableList.copyOf(requireNonNull(pageBufferClientStatuses, "pageBufferClientStatuses is null"));
+ this.requestDuration = requireNonNull(requestDuration, "requestDuration is null");
}
@JsonProperty
@@ -113,6 +117,12 @@ public List<PageBufferClientStatus> getPageBufferClientStatuses()
return pageBufferClientStatuses;
}
+ @JsonProperty
+ public TDigestHistogram getRequestDuration()
+ {
+ return requestDuration;
+ }
+
@Override
public boolean isFinal()
{
@@ -132,6 +142,7 @@ public String toString()
.add("spilledBytes", spilledBytes)
.add("noMoreLocations", noMoreLocations)
.add("pageBufferClientStatuses", pageBufferClientStatuses)
+ .add("requestDuration", requestDuration)
.toString();
}
@@ -147,7 +158,8 @@ public DirectExchangeClientStatus mergeWith(DirectExchangeClientStatus other)
spilledPages + other.spilledPages,
spilledBytes + other.spilledBytes,
noMoreLocations && other.noMoreLocations, // if at least one has some locations, mergee has some too
- ImmutableList.of()); // pageBufferClientStatuses may be long, so we don't want to combine the lists
+ ImmutableList.of(), // pageBufferClientStatuses may be long, so we don't want to combine the lists
+ requestDuration.mergeWith(other.requestDuration)); // this is correct as long as all clients have the same shape of histogram
}
private static long mergeAvgs(long value1, long count1, long value2, long count2)
diff --git a/core/trino-main/src/main/java/io/trino/operator/Driver.java b/core/trino-main/src/main/java/io/trino/operator/Driver.java
index 03ae061225c7..c525f2d8c523 100644
--- a/core/trino-main/src/main/java/io/trino/operator/Driver.java
+++ b/core/trino-main/src/main/java/io/trino/operator/Driver.java
@@ -336,6 +336,7 @@ private OperationTimer createTimer()
driverContext.isCpuTimerEnabled() && driverContext.isPerOperatorCpuTimerEnabled());
}
+ // sourceBlockedFuture represents the blocked state of the driver's source
private ListenableFuture<Void> updateDriverBlockedFuture(ListenableFuture<Void> sourceBlockedFuture)
{
// driverBlockedFuture will be completed as soon as the sourceBlockedFuture is completed
diff --git a/core/trino-main/src/main/java/io/trino/operator/HttpPageBufferClient.java b/core/trino-main/src/main/java/io/trino/operator/HttpPageBufferClient.java
index 3e9b0dcfd72c..66004b83a462 100644
--- a/core/trino-main/src/main/java/io/trino/operator/HttpPageBufferClient.java
+++ b/core/trino-main/src/main/java/io/trino/operator/HttpPageBufferClient.java
@@ -145,7 +145,8 @@ public interface ClientCallback
private boolean completed;
@GuardedBy("this")
private String taskInstanceId;
-
+ private volatile long lastRequestStartNanos;
+ private volatile long lastRequestDurationMillis;
// it is synchronized on `this` for update
private volatile long averageRequestSizeInBytes;
@@ -161,6 +162,7 @@ public interface ClientCallback
private final AtomicInteger requestsFailed = new AtomicInteger();
private final Executor pageBufferClientCallbackExecutor;
+ private final Ticker ticker;
public HttpPageBufferClient(
String selfAddress,
@@ -217,6 +219,7 @@ public HttpPageBufferClient(
requireNonNull(maxErrorDuration, "maxErrorDuration is null");
requireNonNull(ticker, "ticker is null");
this.backoff = new Backoff(maxErrorDuration, ticker);
+ this.ticker = ticker;
}
public synchronized PageBufferClientStatus getStatus()
@@ -327,6 +330,11 @@ public synchronized void scheduleRequest()
requestsScheduled.incrementAndGet();
}
+ public long getLastRequestDurationMillis()
+ {
+ return lastRequestDurationMillis;
+ }
+
private synchronized void initiateRequest()
{
scheduled = false;
@@ -347,6 +355,7 @@ private synchronized void initiateRequest()
private synchronized void sendGetResults()
{
URI uri = HttpUriBuilder.uriBuilderFrom(location).appendPath(String.valueOf(token)).build();
+ lastRequestStartNanos = ticker.read();
HttpResponseFuture<PagesResponse> resultFuture = httpClient.executeAsync(
prepareGet()
.setHeader(TRINO_MAX_SIZE, maxResponseSize.toString())
@@ -360,7 +369,7 @@ private synchronized void sendGetResults()
public void onSuccess(PagesResponse result)
{
assertNotHoldsLock(this);
-
+ lastRequestDurationMillis = (ticker.read() - lastRequestStartNanos) / 1_000_000;
backoff.success();
List<Slice> pages;
@@ -467,6 +476,8 @@ public void onFailure(Throwable t)
log.debug("Request to %s failed %s", uri, t);
assertNotHoldsLock(this);
+ lastRequestDurationMillis = (ticker.read() - lastRequestStartNanos) / 1_000_000;
+
if (t instanceof ChecksumVerificationException) {
switch (dataIntegrityVerification) {
case NONE:
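
The timing fields are written on the request path and read lock-free by status reporting, hence the volatiles. The measurement pattern, extracted into a self-contained sketch using Guava's `Ticker`:

```java
import com.google.common.base.Ticker;

public final class RequestTimer
{
    private final Ticker ticker;
    private volatile long lastRequestStartNanos;
    private volatile long lastRequestDurationMillis;

    public RequestTimer(Ticker ticker)
    {
        this.ticker = ticker;
    }

    public void requestStarted()
    {
        lastRequestStartNanos = ticker.read();
    }

    public void requestFinished()
    {
        // Same conversion as onSuccess/onFailure above; integer division
        // truncates sub-millisecond remainders.
        lastRequestDurationMillis = (ticker.read() - lastRequestStartNanos) / 1_000_000;
    }

    public long getLastRequestDurationMillis()
    {
        return lastRequestDurationMillis;
    }
}
```

In production code the ticker would be `Ticker.systemTicker()`; tests can substitute a fake ticker, which is why the class takes it as a constructor argument.
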
diff --git a/core/trino-main/src/main/java/io/trino/operator/TaskContext.java b/core/trino-main/src/main/java/io/trino/operator/TaskContext.java
index a55111fdfed8..b8380ad23afa 100644
--- a/core/trino-main/src/main/java/io/trino/operator/TaskContext.java
+++ b/core/trino-main/src/main/java/io/trino/operator/TaskContext.java
@@ -31,6 +31,7 @@
import io.trino.execution.buffer.LazyOutputBuffer;
import io.trino.memory.QueryContext;
import io.trino.memory.QueryContextVisitor;
+import io.trino.memory.context.AggregatedMemoryContext;
import io.trino.memory.context.LocalMemoryContext;
import io.trino.memory.context.MemoryTrackingContext;
import io.trino.spi.predicate.Domain;
@@ -43,9 +44,11 @@
import java.util.List;
import java.util.Map;
+import java.util.Optional;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.Executor;
import java.util.concurrent.ScheduledExecutorService;
+import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;
@@ -92,6 +95,8 @@ public class TaskContext
private final Object cumulativeMemoryLock = new Object();
private final AtomicDouble cumulativeUserMemory = new AtomicDouble(0.0);
+ private final AtomicInteger maxWriterCount = new AtomicInteger(-1);
+
@GuardedBy("cumulativeMemoryLock")
private long lastUserMemoryReservation;
@@ -285,6 +290,11 @@ public LocalMemoryContext localMemoryContext()
return taskMemoryContext.localUserMemoryContext();
}
+ public AggregatedMemoryContext newAggregateMemoryContext()
+ {
+ return taskMemoryContext.newAggregateUserMemoryContext();
+ }
+
public boolean isPerOperatorCpuTimerEnabled()
{
return perOperatorCpuTimerEnabled;
@@ -349,6 +359,20 @@ public long getPhysicalWrittenDataSize()
return physicalWrittenBytes;
}
+ public void setMaxWriterCount(int maxWriterCount)
+ {
+ checkArgument(maxWriterCount > 0, "maxWriterCount must be > 0");
+
+ int oldMaxWriterCount = this.maxWriterCount.getAndSet(maxWriterCount);
+ checkArgument(oldMaxWriterCount == -1 || oldMaxWriterCount == maxWriterCount, "maxWriterCount already set to " + oldMaxWriterCount);
+ }
+
+ public Optional<Integer> getMaxWriterCount()
+ {
+ int value = maxWriterCount.get();
+ return value == -1 ? Optional.empty() : Optional.of(value);
+ }
+
public Duration getFullGcTime()
{
long startFullGcTimeNanos = this.startFullGcTimeNanos.get();
@@ -570,6 +594,7 @@ public TaskStats getTaskStats()
outputPositions,
new Duration(outputBlockedTime, NANOSECONDS).convertToMostSuccinctTimeUnit(),
succinctBytes(physicalWrittenDataSize),
+ getMaxWriterCount(),
fullGcCount,
fullGcTime,
pipelineStats);
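
`maxWriterCount` uses -1 as an "unset" sentinel and tolerates only idempotent re-sets, which is what lets `TaskStatus` distinguish "not yet initialized" from a real count. The pattern in isolation:

```java
import java.util.Optional;
import java.util.concurrent.atomic.AtomicInteger;

public final class SetOnceCounter
{
    private final AtomicInteger value = new AtomicInteger(-1);

    public void set(int newValue)
    {
        if (newValue <= 0) {
            throw new IllegalArgumentException("value must be > 0");
        }
        // As in TaskContext above, a repeated set is allowed only with the same value.
        int old = value.getAndSet(newValue);
        if (old != -1 && old != newValue) {
            throw new IllegalStateException("value already set to " + old);
        }
    }

    public Optional<Integer> get()
    {
        int v = value.get();
        return v == -1 ? Optional.empty() : Optional.of(v);
    }
}
```
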
diff --git a/core/trino-main/src/main/java/io/trino/operator/TaskStats.java b/core/trino-main/src/main/java/io/trino/operator/TaskStats.java
index 1cd7b01c1a53..be7bcb0067f4 100644
--- a/core/trino-main/src/main/java/io/trino/operator/TaskStats.java
+++ b/core/trino-main/src/main/java/io/trino/operator/TaskStats.java
@@ -24,6 +24,7 @@
import javax.annotation.Nullable;
import java.util.List;
+import java.util.Optional;
import java.util.Set;
import static com.google.common.base.Preconditions.checkArgument;
@@ -83,6 +84,7 @@ public class TaskStats
private final Duration outputBlockedTime;
private final DataSize physicalWrittenDataSize;
+ private final Optional<Integer> maxWriterCount;
private final int fullGcCount;
private final Duration fullGcTime;
@@ -130,6 +132,7 @@ public TaskStats(DateTime createTime, DateTime endTime)
0,
new Duration(0, MILLISECONDS),
DataSize.ofBytes(0),
+ Optional.empty(),
0,
new Duration(0, MILLISECONDS),
ImmutableList.of());
@@ -187,6 +190,7 @@ public TaskStats(
@JsonProperty("outputBlockedTime") Duration outputBlockedTime,
@JsonProperty("physicalWrittenDataSize") DataSize physicalWrittenDataSize,
+ @JsonProperty("writerCount") Optional writerCount,
@JsonProperty("fullGcCount") int fullGcCount,
@JsonProperty("fullGcTime") Duration fullGcTime,
@@ -260,6 +264,7 @@ public TaskStats(
this.outputBlockedTime = requireNonNull(outputBlockedTime, "outputBlockedTime is null");
this.physicalWrittenDataSize = requireNonNull(physicalWrittenDataSize, "physicalWrittenDataSize is null");
+ this.maxWriterCount = requireNonNull(writerCount, "writerCount is null");
checkArgument(fullGcCount >= 0, "fullGcCount is negative");
this.fullGcCount = fullGcCount;
@@ -482,6 +487,12 @@ public DataSize getPhysicalWrittenDataSize()
return physicalWrittenDataSize;
}
+ @JsonProperty
+ public Optional<Integer> getMaxWriterCount()
+ {
+ return maxWriterCount;
+ }
+
@JsonProperty
public List getPipelines()
{
@@ -566,6 +577,7 @@ public TaskStats summarize()
outputPositions,
outputBlockedTime,
physicalWrittenDataSize,
+ maxWriterCount,
fullGcCount,
fullGcTime,
ImmutableList.of());
@@ -613,6 +625,7 @@ public TaskStats summarizeFinal()
outputPositions,
outputBlockedTime,
physicalWrittenDataSize,
+ maxWriterCount,
fullGcCount,
fullGcTime,
summarizePipelineStats(pipelines));
diff --git a/core/trino-main/src/main/java/io/trino/operator/exchange/LocalExchange.java b/core/trino-main/src/main/java/io/trino/operator/exchange/LocalExchange.java
index 5915f30399e9..2e94a11f746f 100644
--- a/core/trino-main/src/main/java/io/trino/operator/exchange/LocalExchange.java
+++ b/core/trino-main/src/main/java/io/trino/operator/exchange/LocalExchange.java
@@ -285,7 +285,7 @@ private static PartitionFunction createPartitionFunction(
// The same bucket function (with the same bucket count) as for node
// partitioning must be used. This way rows within a single bucket
// will be being processed by single thread.
- int bucketCount = nodePartitioningManager.getBucketCount(session, partitioning);
+ int bucketCount = getBucketCount(session, nodePartitioningManager, partitioning);
int[] bucketToPartition = new int[bucketCount];
for (int bucket = 0; bucket < bucketCount; bucket++) {
@@ -306,6 +306,15 @@ private static PartitionFunction createPartitionFunction(
bucketToPartition);
}
+ public static int getBucketCount(Session session, NodePartitioningManager nodePartitioningManager, PartitioningHandle partitioning)
+ {
+ if (partitioning.getConnectorHandle() instanceof MergePartitioningHandle) {
+ // TODO: can we always use this code path?
+ return nodePartitioningManager.getNodePartitioningMap(session, partitioning).getBucketToPartition().length;
+ }
+ return nodePartitioningManager.getBucketNodeMap(session, partitioning).getBucketCount();
+ }
+
private static boolean isSystemPartitioning(PartitioningHandle partitioning)
{
return partitioning.getConnectorHandle() instanceof SystemPartitioningHandle;
diff --git a/core/trino-main/src/main/java/io/trino/operator/join/NestedLoopJoinOperator.java b/core/trino-main/src/main/java/io/trino/operator/join/NestedLoopJoinOperator.java
index 20ff30a1ad35..534c932d47a3 100644
--- a/core/trino-main/src/main/java/io/trino/operator/join/NestedLoopJoinOperator.java
+++ b/core/trino-main/src/main/java/io/trino/operator/join/NestedLoopJoinOperator.java
@@ -279,7 +279,8 @@ static NestedLoopOutputIterator createNestedLoopOutputIterator(Page probePage, P
// bi-morphic parent class for the two implementations allowed. Adding a third implementation will make getOutput megamorphic and
// should be avoided
@VisibleForTesting
- abstract static class NestedLoopOutputIterator
+ abstract static sealed class NestedLoopOutputIterator
+ permits PageRepeatingIterator, NestedLoopPageBuilder
{
public abstract boolean hasNext();
diff --git a/core/trino-main/src/main/java/io/trino/operator/output/PagePartitioner.java b/core/trino-main/src/main/java/io/trino/operator/output/PagePartitioner.java
index aba8b1fd62a6..0a002032654c 100644
--- a/core/trino-main/src/main/java/io/trino/operator/output/PagePartitioner.java
+++ b/core/trino-main/src/main/java/io/trino/operator/output/PagePartitioner.java
@@ -15,12 +15,13 @@
import com.google.common.collect.ImmutableList;
import com.google.common.primitives.Ints;
-import com.google.common.util.concurrent.ListenableFuture;
import io.airlift.slice.Slice;
import io.airlift.units.DataSize;
import io.trino.execution.buffer.OutputBuffer;
import io.trino.execution.buffer.PageSerializer;
import io.trino.execution.buffer.PagesSerdeFactory;
+import io.trino.memory.context.AggregatedMemoryContext;
+import io.trino.memory.context.LocalMemoryContext;
import io.trino.operator.OperatorContext;
import io.trino.operator.PartitionFunction;
import io.trino.spi.Page;
@@ -62,6 +63,7 @@ public class PagePartitioner
private final Type[] sourceTypes;
private final PartitionFunction partitionFunction;
private final int[] partitionChannels;
+ private final LocalMemoryContext memoryContext;
@Nullable
private final Block[] partitionConstantBlocks; // when null, no constants are present. Only non-null elements are constants
private final PageSerializer serializer;
@@ -69,9 +71,8 @@ public class PagePartitioner
private final PositionsAppenderPageBuilder[] positionsAppenders;
private final boolean replicatesAnyRow;
private final int nullChannel; // when >= 0, send the position to every partition if this channel is null
- private final AtomicLong rowsAdded = new AtomicLong();
- private final AtomicLong pagesAdded = new AtomicLong();
- private final OperatorContext operatorContext;
+ private final long partitionsInitialRetainedSize;
+ private PartitionedOutputInfoSupplier partitionedOutputInfoSupplier;
private boolean hasAnyRowBeenReplicated;
@@ -85,9 +86,9 @@ public PagePartitioner(
PagesSerdeFactory serdeFactory,
List<Type> sourceTypes,
DataSize maxMemory,
- OperatorContext operatorContext,
PositionsAppenderFactory positionsAppenderFactory,
- Optional<Slice> exchangeEncryptionKey)
+ Optional<Slice> exchangeEncryptionKey,
+ AggregatedMemoryContext aggregatedMemoryContext)
{
this.partitionFunction = requireNonNull(partitionFunction, "partitionFunction is null");
this.partitionChannels = Ints.toArray(requireNonNull(partitionChannels, "partitionChannels is null"));
@@ -106,7 +107,6 @@ public PagePartitioner(
this.outputBuffer = requireNonNull(outputBuffer, "outputBuffer is null");
this.sourceTypes = sourceTypes.toArray(new Type[0]);
this.serializer = serdeFactory.createSerializer(exchangeEncryptionKey.map(Ciphers::deserializeAesEncryptionKey));
- this.operatorContext = requireNonNull(operatorContext, "operatorContext is null");
// Ensure partition channels align with constant arguments provided
for (int i = 0; i < this.partitionChannels.length; i++) {
@@ -128,55 +128,17 @@ public PagePartitioner(
for (int i = 0; i < partitionCount; i++) {
pageBuilders[i] = PageBuilder.withMaxPageSize(pageSize, sourceTypes);
}
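+ // register this partitioner's retained memory with the task-level aggregated memory context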
+ this.memoryContext = aggregatedMemoryContext.newLocalMemoryContext(PagePartitioner.class.getSimpleName());
+ this.partitionsInitialRetainedSize = getRetainedSizeInBytes();
+ this.memoryContext.setBytes(partitionsInitialRetainedSize);
}
- public ListenableFuture<Void> isFull()
+ // sets up this partitioner for the new operator
+ public void setupOperator(OperatorContext operatorContext)
{
- return outputBuffer.isFull();
- }
-
- public long getSizeInBytes()
- {
- // We use a foreach loop instead of streams
- // as it has much better performance.
- long sizeInBytes = 0;
- for (PositionsAppenderPageBuilder pageBuilder : positionsAppenders) {
- sizeInBytes += pageBuilder.getSizeInBytes();
- }
- for (PageBuilder pageBuilder : pageBuilders) {
- sizeInBytes += pageBuilder.getSizeInBytes();
- }
- return sizeInBytes;
- }
-
- /**
- * This method can be expensive for complex types.
- */
- public long getRetainedSizeInBytes()
- {
- long sizeInBytes = 0;
- for (PositionsAppenderPageBuilder pageBuilder : positionsAppenders) {
- sizeInBytes += pageBuilder.getRetainedSizeInBytes();
- }
- for (PageBuilder pageBuilder : pageBuilders) {
- sizeInBytes += pageBuilder.getRetainedSizeInBytes();
- }
- sizeInBytes += serializer.getRetainedSizeInBytes();
- return sizeInBytes;
- }
-
- public Supplier<PartitionedOutputInfo> getOperatorInfoSupplier()
- {
- return createPartitionedOutputOperatorInfoSupplier(rowsAdded, pagesAdded, outputBuffer);
- }
-
- private static Supplier<PartitionedOutputInfo> createPartitionedOutputOperatorInfoSupplier(AtomicLong rowsAdded, AtomicLong pagesAdded, OutputBuffer outputBuffer)
- {
- // Must be a separate static method to avoid embedding references to "this" in the supplier
- requireNonNull(rowsAdded, "rowsAdded is null");
- requireNonNull(pagesAdded, "pagesAdded is null");
- requireNonNull(outputBuffer, "outputBuffer is null");
- return () -> new PartitionedOutputInfo(rowsAdded.get(), pagesAdded.get(), outputBuffer.getPeakMemoryUsage());
+ // for a new operator we need to reset the stats gathered by this PagePartitioner
+ partitionedOutputInfoSupplier = new PartitionedOutputInfoSupplier(outputBuffer, operatorContext);
+ operatorContext.setInfoSupplier(partitionedOutputInfoSupplier);
}
public void partitionPage(Page page)
@@ -195,6 +157,18 @@ public void partitionPage(Page page)
else {
partitionPageByColumn(page);
}
+ updateMemoryUsage();
+ }
+
+ private void updateMemoryUsage()
+ {
+ // We use getSizeInBytes() here instead of getRetainedSizeInBytes() for an approximation of
+ // the amount of memory used by the pageBuilders, because calculating the retained
+ // size can be expensive especially for complex types.
+ long partitionsSizeInBytes = getSizeInBytes();
+
+ // We also add partitionsInitialRetainedSize as an approximation of the object overhead of the partitions.
+ memoryContext.setBytes(partitionsSizeInBytes + partitionsInitialRetainedSize);
}
public void partitionPageByRow(Page page)
@@ -471,10 +445,11 @@ private Page getPartitionFunctionArguments(Page page)
return new Page(page.getPositionCount(), blocks);
}
- public void forceFlush()
+ public void close()
{
flushPositionsAppenders(true);
flushPageBuilders(true);
+ memoryContext.close();
}
private void flushPageBuilders(boolean force)
@@ -505,11 +480,8 @@ private void flushPositionsAppenders(boolean force)
private void enqueuePage(Page pagePartition, int partition)
{
- operatorContext.recordOutput(pagePartition.getSizeInBytes(), pagePartition.getPositionCount());
-
outputBuffer.enqueue(partition, splitAndSerializePage(pagePartition));
- pagesAdded.incrementAndGet();
- rowsAdded.addAndGet(pagePartition.getPositionCount());
+ partitionedOutputInfoSupplier.recordPage(pagePartition);
}
private List<Slice> splitAndSerializePage(Page page)
@@ -521,4 +493,67 @@ private List<Slice> splitAndSerializePage(Page page)
}
return builder.build();
}
+
+ private long getSizeInBytes()
+ {
+ // We use a foreach loop instead of streams
+ // as it has much better performance.
+ long sizeInBytes = 0;
+ for (PositionsAppenderPageBuilder pageBuilder : positionsAppenders) {
+ sizeInBytes += pageBuilder.getSizeInBytes();
+ }
+ for (PageBuilder pageBuilder : pageBuilders) {
+ sizeInBytes += pageBuilder.getSizeInBytes();
+ }
+ return sizeInBytes;
+ }
+
+ /**
+ * This method can be expensive for complex types.
+ */
+ private long getRetainedSizeInBytes()
+ {
+ long sizeInBytes = 0;
+ for (PositionsAppenderPageBuilder pageBuilder : positionsAppenders) {
+ sizeInBytes += pageBuilder.getRetainedSizeInBytes();
+ }
+ for (PageBuilder pageBuilder : pageBuilders) {
+ sizeInBytes += pageBuilder.getRetainedSizeInBytes();
+ }
+ sizeInBytes += serializer.getRetainedSizeInBytes();
+ return sizeInBytes;
+ }
+
+ /**
+ * Keeps statistics about output pages produced by the partitioner and updates the stats in the operatorContext.
+ */
+ private static class PartitionedOutputInfoSupplier
+ implements Supplier<PartitionedOutputInfo>
+ {
+ private final AtomicLong rowsAdded = new AtomicLong();
+ private final AtomicLong pagesAdded = new AtomicLong();
+ private final OutputBuffer outputBuffer;
+ private final OperatorContext operatorContext;
+
+ private PartitionedOutputInfoSupplier(OutputBuffer outputBuffer, OperatorContext operatorContext)
+ {
+ this.outputBuffer = requireNonNull(outputBuffer, "outputBuffer is null");
+ this.operatorContext = requireNonNull(operatorContext, "operatorContext is null");
+ }
+
+ @Override
+ public PartitionedOutputInfo get()
+ {
+ // note that outputBuffer.getPeakMemoryUsage() returns the peak across many operators,
+ // which is suboptimal but hard to fix properly
+ return new PartitionedOutputInfo(rowsAdded.get(), pagesAdded.get(), outputBuffer.getPeakMemoryUsage());
+ }
+
+ public void recordPage(Page pagePartition)
+ {
+ operatorContext.recordOutput(pagePartition.getSizeInBytes(), pagePartition.getPositionCount());
+ pagesAdded.incrementAndGet();
+ rowsAdded.addAndGet(pagePartition.getPositionCount());
+ }
+ }
}
diff --git a/core/trino-main/src/main/java/io/trino/operator/output/PagePartitionerPool.java b/core/trino-main/src/main/java/io/trino/operator/output/PagePartitionerPool.java
new file mode 100644
index 000000000000..f50efbc46f8f
--- /dev/null
+++ b/core/trino-main/src/main/java/io/trino/operator/output/PagePartitionerPool.java
@@ -0,0 +1,85 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.trino.operator.output;
+
+import com.google.common.collect.ImmutableList;
+
+import javax.annotation.concurrent.GuardedBy;
+
+import java.util.ArrayDeque;
+import java.util.Collection;
+import java.util.Queue;
+import java.util.function.Supplier;
+
+import static com.google.common.base.Preconditions.checkArgument;
+import static java.util.Objects.requireNonNull;
+
+public class PagePartitionerPool
+{
+ private final Supplier<PagePartitioner> pagePartitionerSupplier;
+ /**
+ * Maximum number of free {@link PagePartitioner}s.
+ * In normal conditions, in the steady state,
+ * the number of free {@link PagePartitioner}s is going to be close to 0.
+ * There is a possible case, though, where an initially large number of concurrent drivers, say 128,
+ * drops to a small number, e.g. 32, in the steady state. This could cause a lot of memory
+ * to be retained by the unused buffers.
+ * To defend against that, {@link #maxFree} limits the number of free buffers,
+ * thus limiting unused memory.
+ */
+ private final int maxFree;
+ @GuardedBy("this")
+ private final Queue<PagePartitioner> free = new ArrayDeque<>();
+ @GuardedBy("this")
+ private boolean closed;
+
+ public PagePartitionerPool(int maxFree, Supplier<PagePartitioner> pagePartitionerSupplier)
+ {
+ this.maxFree = maxFree;
+ this.pagePartitionerSupplier = requireNonNull(pagePartitionerSupplier, "pagePartitionerSupplier is null");
+ }
+
+ public synchronized PagePartitioner poll()
+ {
+ checkArgument(!closed, "The pool is already closed");
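+ // reuse a pooled partitioner when one is available; otherwise create a new one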
+ return free.isEmpty() ? pagePartitionerSupplier.get() : free.poll();
+ }
+
+ public void release(PagePartitioner pagePartitioner)
+ {
+ // pagePartitioner.close can take a long time (flush -> serialization), so we keep it out of the synchronized block
+ boolean shouldRetain;
+ synchronized (this) {
+ shouldRetain = !closed && free.size() < maxFree;
+ if (shouldRetain) {
+ free.add(pagePartitioner);
+ }
+ }
+ if (!shouldRetain) {
+ pagePartitioner.close();
+ }
+ }
+
+ public void close()
+ {
+ // pagePartitioner.close can take a long time (flush -> serialization), so we keep it out of the synchronized block
+ markClosed().forEach(PagePartitioner::close);
+ }
+
+ private synchronized Collection<PagePartitioner> markClosed()
+ {
+ closed = true;
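+ // copy the free list under the lock; the actual closing happens outside the synchronized block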
+ return ImmutableList.copyOf(free);
+ }
+}
diff --git a/core/trino-main/src/main/java/io/trino/operator/output/PartitionedOutputOperator.java b/core/trino-main/src/main/java/io/trino/operator/output/PartitionedOutputOperator.java
index c1964bfda65a..5b5aa956bcfc 100644
--- a/core/trino-main/src/main/java/io/trino/operator/output/PartitionedOutputOperator.java
+++ b/core/trino-main/src/main/java/io/trino/operator/output/PartitionedOutputOperator.java
@@ -16,10 +16,11 @@
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.util.concurrent.ListenableFuture;
+import io.airlift.slice.Slice;
import io.airlift.units.DataSize;
import io.trino.execution.buffer.OutputBuffer;
import io.trino.execution.buffer.PagesSerdeFactory;
-import io.trino.memory.context.LocalMemoryContext;
+import io.trino.memory.context.AggregatedMemoryContext;
import io.trino.operator.DriverContext;
import io.trino.operator.Operator;
import io.trino.operator.OperatorContext;
@@ -39,6 +40,7 @@
import java.util.function.Function;
import static com.google.common.base.MoreObjects.toStringHelper;
+import static com.google.common.base.Preconditions.checkState;
import static java.util.Objects.requireNonNull;
public class PartitionedOutputOperator
@@ -55,6 +57,9 @@ public static class PartitionedOutputFactory
private final OptionalInt nullChannel;
private final DataSize maxMemory;
private final PositionsAppenderFactory positionsAppenderFactory;
+ private final Optional<Slice> exchangeEncryptionKey;
+ private final AggregatedMemoryContext memoryContext;
+ private final int pagePartitionerPoolSize;
public PartitionedOutputFactory(
PartitionFunction partitionFunction,
@@ -64,7 +69,10 @@ public PartitionedOutputFactory(
OptionalInt nullChannel,
OutputBuffer outputBuffer,
DataSize maxMemory,
- PositionsAppenderFactory positionsAppenderFactory)
+ PositionsAppenderFactory positionsAppenderFactory,
+ Optional<Slice> exchangeEncryptionKey,
+ AggregatedMemoryContext memoryContext,
+ int pagePartitionerPoolSize)
{
this.partitionFunction = requireNonNull(partitionFunction, "partitionFunction is null");
this.partitionChannels = requireNonNull(partitionChannels, "partitionChannels is null");
@@ -74,6 +82,9 @@ public PartitionedOutputFactory(
this.outputBuffer = requireNonNull(outputBuffer, "outputBuffer is null");
this.maxMemory = requireNonNull(maxMemory, "maxMemory is null");
this.positionsAppenderFactory = requireNonNull(positionsAppenderFactory, "positionsAppenderFactory is null");
+ this.exchangeEncryptionKey = requireNonNull(exchangeEncryptionKey, "exchangeEncryptionKey is null");
+ this.memoryContext = requireNonNull(memoryContext, "memoryContext is null");
+ this.pagePartitionerPoolSize = pagePartitionerPoolSize;
}
@Override
@@ -97,7 +108,10 @@ public OperatorFactory createOutputOperator(
outputBuffer,
serdeFactory,
maxMemory,
- positionsAppenderFactory);
+ positionsAppenderFactory,
+ exchangeEncryptionKey,
+ memoryContext,
+ pagePartitionerPoolSize);
}
}
@@ -117,6 +131,10 @@ public static class PartitionedOutputOperatorFactory
private final PagesSerdeFactory serdeFactory;
private final DataSize maxMemory;
private final PositionsAppenderFactory positionsAppenderFactory;
+ private final Optional<Slice> exchangeEncryptionKey;
+ private final AggregatedMemoryContext memoryContext;
+ private final int pagePartitionerPoolSize;
+ private final PagePartitionerPool pagePartitionerPool;
public PartitionedOutputOperatorFactory(
int operatorId,
@@ -131,7 +149,10 @@ public PartitionedOutputOperatorFactory(
OutputBuffer outputBuffer,
PagesSerdeFactory serdeFactory,
DataSize maxMemory,
- PositionsAppenderFactory positionsAppenderFactory)
+ PositionsAppenderFactory positionsAppenderFactory,
+ Optional<Slice> exchangeEncryptionKey,
+ AggregatedMemoryContext memoryContext,
+ int pagePartitionerPoolSize)
{
this.operatorId = operatorId;
this.planNodeId = requireNonNull(planNodeId, "planNodeId is null");
@@ -146,6 +167,24 @@ public PartitionedOutputOperatorFactory(
this.serdeFactory = requireNonNull(serdeFactory, "serdeFactory is null");
this.maxMemory = requireNonNull(maxMemory, "maxMemory is null");
this.positionsAppenderFactory = requireNonNull(positionsAppenderFactory, "positionsAppenderFactory is null");
+ this.exchangeEncryptionKey = requireNonNull(exchangeEncryptionKey, "exchangeEncryptionKey is null");
+ this.memoryContext = requireNonNull(memoryContext, "memoryContext is null");
+ this.pagePartitionerPoolSize = pagePartitionerPoolSize;
+ this.pagePartitionerPool = new PagePartitionerPool(
+ pagePartitionerPoolSize,
+ () -> new PagePartitioner(
+ partitionFunction,
+ partitionChannels,
+ partitionConstants,
+ replicatesAnyRow,
+ nullChannel,
+ outputBuffer,
+ serdeFactory,
+ sourceTypes,
+ maxMemory,
+ positionsAppenderFactory,
+ exchangeEncryptionKey,
+ memoryContext));
}
@Override
@@ -154,22 +193,15 @@ public Operator createOperator(DriverContext driverContext)
OperatorContext operatorContext = driverContext.addOperatorContext(operatorId, planNodeId, PartitionedOutputOperator.class.getSimpleName());
return new PartitionedOutputOperator(
operatorContext,
- sourceTypes,
pagePreprocessor,
- partitionFunction,
- partitionChannels,
- partitionConstants,
- replicatesAnyRow,
- nullChannel,
outputBuffer,
- serdeFactory,
- maxMemory,
- positionsAppenderFactory);
+ pagePartitionerPool);
}
@Override
public void noMoreOperators()
{
+ pagePartitionerPool.close();
}
@Override
@@ -188,52 +220,34 @@ public OperatorFactory duplicate()
outputBuffer,
serdeFactory,
maxMemory,
- positionsAppenderFactory);
+ positionsAppenderFactory,
+ exchangeEncryptionKey,
+ memoryContext,
+ pagePartitionerPoolSize);
}
}
private final OperatorContext operatorContext;
private final Function<Page, Page> pagePreprocessor;
+ private final PagePartitionerPool pagePartitionerPool;
private final PagePartitioner partitionFunction;
- private final LocalMemoryContext memoryContext;
- private final long partitionsInitialRetainedSize;
+ // outputBuffer is used only to block the operator from finishing if the outputBuffer is full
+ private final OutputBuffer outputBuffer;
private ListenableFuture<Void> isBlocked = NOT_BLOCKED;
private boolean finished;
public PartitionedOutputOperator(
OperatorContext operatorContext,
- List<Type> sourceTypes,
Function<Page, Page> pagePreprocessor,
- PartitionFunction partitionFunction,
- List<Integer> partitionChannels,
- List<Optional<NullableValue>> partitionConstants,
- boolean replicatesAnyRow,
- OptionalInt nullChannel,
OutputBuffer outputBuffer,
- PagesSerdeFactory serdeFactory,
- DataSize maxMemory,
- PositionsAppenderFactory positionsAppenderFactory)
+ PagePartitionerPool pagePartitionerPool)
{
this.operatorContext = requireNonNull(operatorContext, "operatorContext is null");
this.pagePreprocessor = requireNonNull(pagePreprocessor, "pagePreprocessor is null");
- this.partitionFunction = new PagePartitioner(
- partitionFunction,
- partitionChannels,
- partitionConstants,
- replicatesAnyRow,
- nullChannel,
- outputBuffer,
- serdeFactory,
- sourceTypes,
- maxMemory,
- operatorContext,
- positionsAppenderFactory,
- operatorContext.getSession().getExchangeEncryptionKey());
-
- operatorContext.setInfoSupplier(this.partitionFunction.getOperatorInfoSupplier());
- this.memoryContext = operatorContext.newLocalUserMemoryContext(PartitionedOutputOperator.class.getSimpleName());
- this.partitionsInitialRetainedSize = this.partitionFunction.getRetainedSizeInBytes();
- this.memoryContext.setBytes(partitionsInitialRetainedSize);
+ this.pagePartitionerPool = requireNonNull(pagePartitionerPool, "pagePartitionerPool is null");
+ this.outputBuffer = requireNonNull(outputBuffer, "outputBuffer is null");
+ this.partitionFunction = requireNonNull(pagePartitionerPool.poll(), "partitionFunction is null");
+ this.partitionFunction.setupOperator(operatorContext);
}
@Override
@@ -245,8 +259,10 @@ public OperatorContext getOperatorContext()
@Override
public void finish()
{
- finished = true;
- partitionFunction.forceFlush();
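+ // the partitioner is returned to the pool without an explicit flush; buffered rows are flushed when the pool closes it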
+ if (!finished) {
+ pagePartitionerPool.release(partitionFunction);
+ finished = true;
+ }
}
@Override
@@ -255,12 +271,20 @@ public boolean isFinished()
return finished && isBlocked().isDone();
}
+ @Override
+ public void close()
+ throws Exception
+ {
+ // make sure the operator is finished and the partitionFunction is released back to the pool
+ finish();
+ }
+
@Override
public ListenableFuture isBlocked()
{
// Avoid re-synchronizing on the output buffer when operator is already blocked
if (isBlocked.isDone()) {
- isBlocked = partitionFunction.isFull();
+ isBlocked = outputBuffer.isFull();
if (isBlocked.isDone()) {
isBlocked = NOT_BLOCKED;
}
@@ -278,6 +302,7 @@ public boolean needsInput()
public void addInput(Page page)
{
requireNonNull(page, "page is null");
+ checkState(!finished);
if (page.getPositionCount() == 0) {
return;
@@ -285,14 +310,6 @@ public void addInput(Page page)
page = pagePreprocessor.apply(page);
partitionFunction.partitionPage(page);
-
- // We use getSizeInBytes() here instead of getRetainedSizeInBytes() for an approximation of
- // the amount of memory used by the pageBuilders, because calculating the retained
- // size can be expensive especially for complex types.
- long partitionsSizeInBytes = partitionFunction.getSizeInBytes();
-
- // We also add partitionsInitialRetainedSize as an approximation of the object overhead of the partitions.
- memoryContext.setBytes(partitionsSizeInBytes + partitionsInitialRetainedSize);
}
@Override
@@ -301,12 +318,6 @@ public Page getOutput()
return null;
}
- @Override
- public void close()
- {
- memoryContext.close();
- }
-
public static class PartitionedOutputInfo
implements Mergeable<PartitionedOutputInfo>, OperatorInfo
{
diff --git a/core/trino-main/src/main/java/io/trino/sql/analyzer/Analysis.java b/core/trino-main/src/main/java/io/trino/sql/analyzer/Analysis.java
index 9920ac65bf80..2ca69e48f9c1 100644
--- a/core/trino-main/src/main/java/io/trino/sql/analyzer/Analysis.java
+++ b/core/trino-main/src/main/java/io/trino/sql/analyzer/Analysis.java
@@ -89,7 +89,6 @@
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collection;
-import java.util.Collections;
import java.util.Deque;
import java.util.HashSet;
import java.util.LinkedHashMap;
@@ -112,6 +111,7 @@
import static java.lang.Boolean.FALSE;
import static java.lang.String.format;
import static java.util.Collections.emptyList;
+import static java.util.Collections.unmodifiableList;
import static java.util.Collections.unmodifiableMap;
import static java.util.Collections.unmodifiableSet;
import static java.util.Objects.requireNonNull;
@@ -212,6 +212,7 @@ public class Analysis
private final Multiset<RowFilterScopeEntry> rowFilterScopes = HashMultiset.create();
private final Map<NodeRef<Table>, List<Expression>> rowFilters = new LinkedHashMap<>();
+ private final Map<NodeRef<Table>, List<Expression>> checkConstraints = new LinkedHashMap<>();
private final Multiset<ColumnMaskScopeEntry> columnMaskScopes = HashMultiset.create();
private final Map<NodeRef<Table>, Map<String, Expression>> columnMasks = new LinkedHashMap<>();
@@ -1071,9 +1072,20 @@ public void addRowFilter(Table table, Expression filter)
.add(filter);
}
+ public void addCheckConstraints(Table table, Expression constraint)
+ {
+ checkConstraints.computeIfAbsent(NodeRef.of(table), node -> new ArrayList<>())
+ .add(constraint);
+ }
+
public List<Expression> getRowFilters(Table node)
{
- return rowFilters.getOrDefault(NodeRef.of(node), ImmutableList.of());
+ return unmodifiableList(rowFilters.getOrDefault(NodeRef.of(node), ImmutableList.of()));
+ }
+
+ public List<Expression> getCheckConstraints(Table node)
+ {
+ return unmodifiableList(checkConstraints.getOrDefault(NodeRef.of(node), ImmutableList.of()));
}
public boolean hasColumnMask(QualifiedObjectName table, String column, String identity)
@@ -1101,7 +1113,7 @@ public void addColumnMask(Table table, String column, Expression mask)
public Map<String, Expression> getColumnMasks(Table table)
{
- return columnMasks.getOrDefault(NodeRef.of(table), ImmutableMap.of());
+ return unmodifiableMap(columnMasks.getOrDefault(NodeRef.of(table), ImmutableMap.of()));
}
public List getReferencedTables()
@@ -1571,22 +1583,22 @@ public void addQuantifiedComparisons(List expres
public List getInPredicatesSubqueries()
{
- return Collections.unmodifiableList(inPredicatesSubqueries);
+ return unmodifiableList(inPredicatesSubqueries);
}
public List getSubqueries()
{
- return Collections.unmodifiableList(subqueries);
+ return unmodifiableList(subqueries);
}
public List getExistsSubqueries()
{
- return Collections.unmodifiableList(existsSubqueries);
+ return unmodifiableList(existsSubqueries);
}
public List getQuantifiedComparisonSubqueries()
{
- return Collections.unmodifiableList(quantifiedComparisonSubqueries);
+ return unmodifiableList(quantifiedComparisonSubqueries);
}
}
diff --git a/core/trino-main/src/main/java/io/trino/sql/analyzer/ExpressionAnalyzer.java b/core/trino-main/src/main/java/io/trino/sql/analyzer/ExpressionAnalyzer.java
index 4a4924604e45..28bdf9a97af0 100644
--- a/core/trino-main/src/main/java/io/trino/sql/analyzer/ExpressionAnalyzer.java
+++ b/core/trino-main/src/main/java/io/trino/sql/analyzer/ExpressionAnalyzer.java
@@ -226,7 +226,7 @@
import static io.trino.spi.type.SmallintType.SMALLINT;
import static io.trino.spi.type.TimeType.TIME_MILLIS;
import static io.trino.spi.type.TimeType.createTimeType;
-import static io.trino.spi.type.TimeWithTimeZoneType.TIME_WITH_TIME_ZONE;
+import static io.trino.spi.type.TimeWithTimeZoneType.TIME_TZ_MILLIS;
import static io.trino.spi.type.TimeWithTimeZoneType.createTimeWithTimeZoneType;
import static io.trino.spi.type.TimestampType.TIMESTAMP_MILLIS;
import static io.trino.spi.type.TimestampType.createTimestampType;
@@ -627,7 +627,7 @@ protected Type visitCurrentTime(CurrentTime node, StackableAstVisitorContext<Context> context)
if (node.getPrecision() != null) {
diff --git a/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java b/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java
index d1c77fe68921..854edf086328 100644
--- a/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java
+++ b/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java
@@ -115,7 +115,6 @@
import io.trino.sql.analyzer.Scope.AsteriskedIdentifierChainBasis;
import io.trino.sql.parser.ParsingException;
import io.trino.sql.parser.SqlParser;
-import io.trino.sql.planner.DeterminismEvaluator;
import io.trino.sql.planner.ExpressionInterpreter;
import io.trino.sql.planner.PartitioningHandle;
import io.trino.sql.planner.ScopeAware;
@@ -296,6 +295,7 @@
import static io.trino.spi.StandardErrorCode.FUNCTION_NOT_FOUND;
import static io.trino.spi.StandardErrorCode.FUNCTION_NOT_WINDOW;
import static io.trino.spi.StandardErrorCode.INVALID_ARGUMENTS;
+import static io.trino.spi.StandardErrorCode.INVALID_CHECK_CONSTRAINT;
import static io.trino.spi.StandardErrorCode.INVALID_COLUMN_REFERENCE;
import static io.trino.spi.StandardErrorCode.INVALID_COPARTITIONING;
import static io.trino.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT;
@@ -364,6 +364,8 @@
import static io.trino.sql.analyzer.SemanticExceptions.semanticException;
import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes;
import static io.trino.sql.analyzer.TypeSignatureTranslator.toTypeSignature;
+import static io.trino.sql.planner.DeterminismEvaluator.containsCurrentTimeFunctions;
+import static io.trino.sql.planner.DeterminismEvaluator.isDeterministic;
import static io.trino.sql.planner.ExpressionInterpreter.evaluateConstantExpression;
import static io.trino.sql.tree.BooleanLiteral.TRUE_LITERAL;
import static io.trino.sql.tree.DereferenceExpression.getQualifiedName;
@@ -541,6 +543,7 @@ protected Scope visitInsert(Insert insert, Optional<Scope> scope)
List<ColumnSchema> columns = tableSchema.getColumns().stream()
.filter(column -> !column.isHidden())
.collect(toImmutableList());
+ List<String> checkConstraints = tableSchema.getTableSchema().getCheckConstraints();
for (ColumnSchema column : columns) {
if (accessControl.getColumnMask(session.toSecurityContext(), targetTable, column.getName(), column.getType()).isPresent()) {
@@ -550,7 +553,12 @@ protected Scope visitInsert(Insert insert, Optional scope)
Map<String, ColumnHandle> columnHandles = metadata.getColumnHandles(session, targetTableHandle.get());
List<Field> tableFields = analyzeTableOutputFields(insert.getTable(), targetTable, tableSchema, columnHandles);
- analyzeFiltersAndMasks(insert.getTable(), targetTable, targetTableHandle, tableFields, session.getIdentity().getUser());
+ Scope accessControlScope = Scope.builder()
+ .withRelationType(RelationId.anonymous(), new RelationType(tableFields))
+ .build();
+ analyzeFiltersAndMasks(insert.getTable(), targetTable, new RelationType(tableFields), accessControlScope);
+ analyzeCheckConstraints(insert.getTable(), targetTable, accessControlScope, checkConstraints);
+ analysis.registerTable(insert.getTable(), targetTableHandle, targetTable, session.getIdentity().getUser(), accessControlScope);
List<String> tableColumns = columns.stream()
.map(ColumnSchema::getName)
@@ -618,9 +626,9 @@ protected Scope visitInsert(Insert insert, Optional scope)
targetTable,
Optional.empty(),
Optional.of(Streams.zip(
- columnStream,
- queryScope.getRelationType().getVisibleFields().stream(),
- (column, field) -> new OutputColumn(column, analysis.getSourceColumns(field)))
+ columnStream,
+ queryScope.getRelationType().getVisibleFields().stream(),
+ (column, field) -> new OutputColumn(column, analysis.getSourceColumns(field)))
.collect(toImmutableList())));
return createAndAssignScope(insert, scope, Field.newUnqualified("rows", BIGINT));
@@ -701,9 +709,9 @@ protected Scope visitRefreshMaterializedView(RefreshMaterializedView refreshMate
targetTable,
Optional.empty(),
Optional.of(Streams.zip(
- columns,
- queryScope.getRelationType().getVisibleFields().stream(),
- (column, field) -> new OutputColumn(column, analysis.getSourceColumns(field)))
+ columns,
+ queryScope.getRelationType().getVisibleFields().stream(),
+ (column, field) -> new OutputColumn(column, analysis.getSourceColumns(field)))
.collect(toImmutableList())));
return createAndAssignScope(refreshMaterializedView, scope, Field.newUnqualified("rows", BIGINT));
@@ -800,7 +808,12 @@ protected Scope visitDelete(Delete node, Optional<Scope> scope)
analysis.setUpdateType("DELETE");
analysis.setUpdateTarget(tableName, Optional.of(table), Optional.empty());
- analyzeFiltersAndMasks(table, tableName, Optional.of(handle), analysis.getScope(table).getRelationType(), session.getIdentity().getUser());
+ Scope accessControlScope = Scope.builder()
+ .withRelationType(RelationId.anonymous(), analysis.getScope(table).getRelationType())
+ .build();
+ analyzeFiltersAndMasks(table, tableName, analysis.getScope(table).getRelationType(), accessControlScope);
+ analyzeCheckConstraints(table, tableName, accessControlScope, tableSchema.getTableSchema().getCheckConstraints());
+ analysis.registerTable(table, Optional.of(handle), tableName, session.getIdentity().getUser(), accessControlScope);
createMergeAnalysis(table, handle, tableSchema, tableScope, tableScope, ImmutableList.of());
@@ -1180,10 +1193,10 @@ protected Scope visitTableExecute(TableExecute node, Optional<Scope> scope)
TableExecuteHandle executeHandle =
metadata.getTableHandleForExecute(
- session,
- tableHandle,
- procedureName,
- tableProperties)
+ session,
+ tableHandle,
+ procedureName,
+ tableProperties)
.orElseThrow(() -> semanticException(NOT_SUPPORTED, node, "Procedure '%s' cannot be executed on table '%s'", procedureName, tableName));
analysis.setTableExecuteReadsData(procedureMetadata.getExecutionMode().isReadsData());
@@ -2164,7 +2177,12 @@ protected Scope visitTable(Table table, Optional<Scope> scope)
List<Field> outputFields = fields.build();
- analyzeFiltersAndMasks(table, targetTableName, tableHandle, outputFields, session.getIdentity().getUser());
+ Scope accessControlScope = Scope.builder()
+ .withRelationType(RelationId.anonymous(), new RelationType(outputFields))
+ .build();
+ analyzeFiltersAndMasks(table, targetTableName, new RelationType(outputFields), accessControlScope);
+ analyzeCheckConstraints(table, targetTableName, accessControlScope, tableSchema.getTableSchema().getCheckConstraints());
+ analysis.registerTable(table, tableHandle, targetTableName, session.getIdentity().getUser(), accessControlScope);
Scope tableScope = createAndAssignScope(table, scope, outputFields);
@@ -2184,17 +2202,8 @@ private void checkStorageTableNotRedirected(QualifiedObjectName source)
});
}
- private void analyzeFiltersAndMasks(Table table, QualifiedObjectName name, Optional<TableHandle> tableHandle, List<Field> fields, String authorization)
- {
- analyzeFiltersAndMasks(table, name, tableHandle, new RelationType(fields), authorization);
- }
-
- private void analyzeFiltersAndMasks(Table table, QualifiedObjectName name, Optional<TableHandle> tableHandle, RelationType relationType, String authorization)
+ private void analyzeFiltersAndMasks(Table table, QualifiedObjectName name, RelationType relationType, Scope accessControlScope)
{
- Scope accessControlScope = Scope.builder()
- .withRelationType(RelationId.anonymous(), relationType)
- .build();
-
for (int index = 0; index < relationType.getAllFieldCount(); index++) {
Field field = relationType.getFieldByIndex(index);
if (field.getName().isPresent()) {
@@ -2208,8 +2217,14 @@ private void analyzeFiltersAndMasks(Table table, QualifiedObjectName name, Optional<TableHandle> tableHandle, RelationType relationType, String authorization)
accessControl.getRowFilters(session.toSecurityContext(), name)
.forEach(filter -> analyzeRowFilter(session.getIdentity().getUser(), table, name, accessControlScope, filter));
+ }
- analysis.registerTable(table, tableHandle, name, authorization, accessControlScope);
+ private void analyzeCheckConstraints(Table table, QualifiedObjectName name, Scope accessControlScope, List<String> constraints)
+ {
+ for (String constraint : constraints) {
+ ViewExpression expression = new ViewExpression(session.getIdentity().getUser(), Optional.of(name.getCatalogName()), Optional.of(name.getSchemaName()), constraint);
+ analyzeCheckConstraint(table, name, accessControlScope, expression);
+ }
}
private boolean checkCanSelectFromColumn(QualifiedObjectName name, String column)
@@ -2351,13 +2366,21 @@ private Scope createScopeForView(
if (storageTable.isPresent()) {
List<Field> storageTableFields = analyzeStorageTable(table, viewFields, storageTable.get());
- analyzeFiltersAndMasks(table, name, storageTable, viewFields, session.getIdentity().getUser());
+ Scope accessControlScope = Scope.builder()
+ .withRelationType(RelationId.anonymous(), new RelationType(viewFields))
+ .build();
+ analyzeFiltersAndMasks(table, name, new RelationType(viewFields), accessControlScope);
+ analysis.registerTable(table, storageTable, name, session.getIdentity().getUser(), accessControlScope);
analysis.addRelationCoercion(table, viewFields.stream().map(Field::getType).toArray(Type[]::new));
// use storage table output fields as they contain ColumnHandles
return createAndAssignScope(table, scope, storageTableFields);
}
- analyzeFiltersAndMasks(table, name, storageTable, viewFields, session.getIdentity().getUser());
+ Scope accessControlScope = Scope.builder()
+ .withRelationType(RelationId.anonymous(), new RelationType(viewFields))
+ .build();
+ analyzeFiltersAndMasks(table, name, new RelationType(viewFields), accessControlScope);
+ analysis.registerTable(table, storageTable, name, session.getIdentity().getUser(), accessControlScope);
viewFields.forEach(field -> analysis.addSourceColumns(field, ImmutableSet.of(new SourceColumn(name, field.getName().orElseThrow()))));
analysis.registerNamedQuery(table, query);
return createAndAssignScope(table, scope, viewFields);
@@ -2775,15 +2798,15 @@ protected Scope visitSampledRelation(SampledRelation relation, Optional<Scope> scope)
}
Map<NodeRef<Expression>, Type> expressionTypes = ExpressionAnalyzer.analyzeExpressions(
- session,
- plannerContext,
- statementAnalyzerFactory,
- accessControl,
- TypeProvider.empty(),
- ImmutableList.of(samplePercentage),
- analysis.getParameters(),
- WarningCollector.NOOP,
- analysis.getQueryType())
+ session,
+ plannerContext,
+ statementAnalyzerFactory,
+ accessControl,
+ TypeProvider.empty(),
+ ImmutableList.of(samplePercentage),
+ analysis.getParameters(),
+ WarningCollector.NOOP,
+ analysis.getQueryType())
.getExpressionTypes();
Type samplePercentageType = expressionTypes.get(NodeRef.of(samplePercentage));
@@ -3150,6 +3173,10 @@ protected Scope visitUpdate(Update update, Optional<Scope> scope)
if (!accessControl.getRowFilters(session.toSecurityContext(), tableName).isEmpty()) {
throw semanticException(NOT_SUPPORTED, update, "Updating a table with a row filter is not supported");
}
+ if (!tableSchema.getTableSchema().getCheckConstraints().isEmpty()) {
+ // TODO https://github.com/trinodb/trino/issues/15411 Add support for CHECK constraint to UPDATE statement
+ throw semanticException(NOT_SUPPORTED, update, "Updating a table with a check constraint is not supported");
+ }
// TODO: how to deal with connectors that need to see the pre-image of rows to perform the update without
// flowing that data through the masking logic
@@ -3277,6 +3304,10 @@ protected Scope visitMerge(Merge merge, Optional<Scope> scope)
if (!accessControl.getRowFilters(session.toSecurityContext(), tableName).isEmpty()) {
throw semanticException(NOT_SUPPORTED, merge, "Cannot merge into a table with row filters");
}
+ if (!tableSchema.getTableSchema().getCheckConstraints().isEmpty()) {
+ // TODO https://github.com/trinodb/trino/issues/15411 Add support for CHECK constraint to MERGE statement
+ throw semanticException(NOT_SUPPORTED, merge, "Cannot merge into a table with check constraints");
+ }
Scope targetTableScope = analyzer.analyzeForUpdate(relation, scope, UpdateKind.MERGE);
Scope sourceTableScope = process(merge.getSource(), scope);
@@ -4622,6 +4653,62 @@ private void analyzeRowFilter(String currentIdentity, Table table, QualifiedObjectName name, Scope scope, ViewExpression filter)
analysis.addRowFilter(table, expression);
}
+ private void analyzeCheckConstraint(Table table, QualifiedObjectName name, Scope scope, ViewExpression constraint)
+ {
+ Expression expression;
+ try {
+ expression = sqlParser.createExpression(constraint.getExpression(), createParsingOptions(session));
+ }
+ catch (ParsingException e) {
+ throw new TrinoException(INVALID_CHECK_CONSTRAINT, extractLocation(table), format("Invalid check constraint for '%s': %s", name, e.getErrorMessage()), e);
+ }
+
+ verifyNoAggregateWindowOrGroupingFunctions(session, metadata, expression, format("Check constraint for '%s'", name));
+
+ ExpressionAnalysis expressionAnalysis;
+ try {
+ Identity filterIdentity = Identity.forUser(constraint.getIdentity())
+ .withGroups(groupProvider.getGroups(constraint.getIdentity()))
+ .build();
+ expressionAnalysis = ExpressionAnalyzer.analyzeExpression(
+ createViewSession(constraint.getCatalog(), constraint.getSchema(), filterIdentity, session.getPath()),
+ plannerContext,
+ statementAnalyzerFactory,
+ accessControl,
+ scope,
+ analysis,
+ expression,
+ warningCollector,
+ correlationSupport);
+ }
+ catch (TrinoException e) {
+ throw new TrinoException(e::getErrorCode, extractLocation(table), format("Invalid check constraint for '%s': %s", name, e.getRawMessage()), e);
+ }
+
+ // Ensure that the expression doesn't contain non-deterministic functions. This should be "retrospectively deterministic" per the SQL standard.
+ if (!isDeterministic(expression, this::getResolvedFunction)) {
+ throw semanticException(INVALID_CHECK_CONSTRAINT, expression, "Check constraint expression should be deterministic");
+ }
+ if (containsCurrentTimeFunctions(expression)) {
+ throw semanticException(INVALID_CHECK_CONSTRAINT, expression, "Check constraint expression should not contain temporal expression");
+ }
+
+ analysis.recordSubqueries(expression, expressionAnalysis);
+
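+ // a check constraint must evaluate to BOOLEAN; add an implicit coercion when one exists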
+ Type actualType = expressionAnalysis.getType(expression);
+ if (!actualType.equals(BOOLEAN)) {
+ TypeCoercion coercion = new TypeCoercion(plannerContext.getTypeManager()::getType);
+
+ if (!coercion.canCoerce(actualType, BOOLEAN)) {
+ throw new TrinoException(TYPE_MISMATCH, extractLocation(table), format("Expected check constraint for '%s' to be of type BOOLEAN, but was %s", name, actualType), null);
+ }
+
+ analysis.addCoercion(expression, BOOLEAN, coercion.isTypeOnlyCoercion(actualType, BOOLEAN));
+ }
+
+ analysis.addCheckConstraints(table, expression);
+ }
+
private void analyzeColumnMask(String currentIdentity, Table table, QualifiedObjectName tableName, Field field, Scope scope, ViewExpression mask)
{
String column = field.getName().orElseThrow();
@@ -5007,7 +5094,7 @@ private void verifySelectDistinct(QuerySpecification node, List<Expression> orderByExpressions, List<Expression> outputExpressions, Scope sourceScope, Scope orderByScope)
}
for (Expression expression : orderByExpressions) {
- if (!DeterminismEvaluator.isDeterministic(expression, this::getResolvedFunction)) {
+ if (!isDeterministic(expression, this::getResolvedFunction)) {
throw semanticException(EXPRESSION_NOT_IN_DISTINCT, expression, "Non deterministic ORDER BY expression is not supported with SELECT DISTINCT");
}
}
diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/DeterminismEvaluator.java b/core/trino-main/src/main/java/io/trino/sql/planner/DeterminismEvaluator.java
index 25a10a35cb82..ef00948c3acc 100644
--- a/core/trino-main/src/main/java/io/trino/sql/planner/DeterminismEvaluator.java
+++ b/core/trino-main/src/main/java/io/trino/sql/planner/DeterminismEvaluator.java
@@ -15,6 +15,7 @@
import io.trino.metadata.Metadata;
import io.trino.metadata.ResolvedFunction;
+import io.trino.sql.tree.CurrentTime;
import io.trino.sql.tree.DefaultExpressionTraversalVisitor;
import io.trino.sql.tree.Expression;
import io.trino.sql.tree.FunctionCall;
@@ -65,4 +66,24 @@ protected Void visitFunctionCall(FunctionCall node, AtomicBoolean deterministic)
return super.visitFunctionCall(node, deterministic);
}
}
+
+ public static boolean containsCurrentTimeFunctions(Expression expression)
+ {
+ requireNonNull(expression, "expression is null");
+
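+ // visit the expression tree and flag any CurrentTime node (CURRENT_DATE, CURRENT_TIME, CURRENT_TIMESTAMP, LOCALTIME, LOCALTIMESTAMP)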
+ AtomicBoolean currentTime = new AtomicBoolean(false);
+ new CurrentTimeVisitor().process(expression, currentTime);
+ return currentTime.get();
+ }
+
+ private static class CurrentTimeVisitor
+ extends DefaultExpressionTraversalVisitor<AtomicBoolean>
+ {
+ @Override
+ protected Void visitCurrentTime(CurrentTime node, AtomicBoolean currentTime)
+ {
+ currentTime.set(true);
+ return super.visitCurrentTime(node, currentTime);
+ }
+ }
}
diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java
index ff7f39936223..f125ad9aaa01 100644
--- a/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java
+++ b/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java
@@ -292,6 +292,7 @@
import static io.trino.SystemSessionProperties.getAggregationOperatorUnspillMemoryLimit;
import static io.trino.SystemSessionProperties.getFilterAndProjectMinOutputPageRowCount;
import static io.trino.SystemSessionProperties.getFilterAndProjectMinOutputPageSize;
+import static io.trino.SystemSessionProperties.getPagePartitioningBufferPoolSize;
import static io.trino.SystemSessionProperties.getTaskConcurrency;
import static io.trino.SystemSessionProperties.getTaskPartitionedWriterCount;
import static io.trino.SystemSessionProperties.getTaskScaleWritersMaxWriterCount;
@@ -580,7 +581,10 @@ public LocalExecutionPlan plan(
nullChannel,
outputBuffer,
maxPagePartitioningBufferSize,
- positionsAppenderFactory));
+ positionsAppenderFactory,
+ taskContext.getSession().getExchangeEncryptionKey(),
+ taskContext.newAggregateMemoryContext(),
+ getPagePartitioningBufferPoolSize(taskContext.getSession())));
}
public LocalExecutionPlan plan(
@@ -3171,11 +3175,9 @@ public PhysicalOperation visitRefreshMaterializedView(RefreshMaterializedViewNod
public PhysicalOperation visitTableWriter(TableWriterNode node, LocalExecutionPlanContext context)
{
// Set table writer count
- context.setDriverInstanceCount(getWriterCount(
- session,
- node.getPartitioningScheme(),
- node.getPreferredPartitioningScheme(),
- node.getSource()));
+ int maxWriterCount = getWriterCount(session, node.getPartitioningScheme(), node.getPreferredPartitioningScheme(), node.getSource());
+ context.setDriverInstanceCount(maxWriterCount);
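+ // record the writer count in the task context so it can be reported through TaskStats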
+ context.taskContext.setMaxWriterCount(maxWriterCount);
PhysicalOperation source = node.getSource().accept(this, context);
@@ -3331,11 +3333,9 @@ public PhysicalOperation visitSimpleTableExecuteNode(SimpleTableExecuteNode node
public PhysicalOperation visitTableExecute(TableExecuteNode node, LocalExecutionPlanContext context)
{
// Set table writer count
- context.setDriverInstanceCount(getWriterCount(
- session,
- node.getPartitioningScheme(),
- node.getPreferredPartitioningScheme(),
- node.getSource()));
+ int maxWriterCount = getWriterCount(session, node.getPartitioningScheme(), node.getPreferredPartitioningScheme(), node.getSource());
+ context.setDriverInstanceCount(maxWriterCount);
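+ // as in visitTableWriter, expose the writer count via the task context for TaskStats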
+ context.taskContext.setMaxWriterCount(maxWriterCount);
PhysicalOperation source = node.getSource().accept(this, context);
diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/LogicalPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/LogicalPlanner.java
index 3befb4cc93c2..7e201215047c 100644
--- a/core/trino-main/src/main/java/io/trino/sql/planner/LogicalPlanner.java
+++ b/core/trino-main/src/main/java/io/trino/sql/planner/LogicalPlanner.java
@@ -496,6 +496,18 @@ private RelationPlan getInsertPlan(
.withRelationType(accessControlScope.getRelationId(), accessControlScope.getRelationType().withOnlyVisibleFields())
.build();
});
+ plan = planner.addCheckConstraints(
+ analysis.getCheckConstraints(table),
+ table,
+ plan,
+ node -> {
+ Scope accessControlScope = analysis.getAccessControlScope(table);
+ // hidden fields are not accessible in insert
+ return Scope.builder()
+ .like(accessControlScope)
+ .withRelationType(accessControlScope.getRelationId(), accessControlScope.getRelationType().withOnlyVisibleFields())
+ .build();
+ });
List<String> insertedTableColumnNames = insertedColumns.stream()
.map(ColumnMetadata::getName)
diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/NodePartitioningManager.java b/core/trino-main/src/main/java/io/trino/sql/planner/NodePartitioningManager.java
index 1244586b624b..7fd019e203ea 100644
--- a/core/trino-main/src/main/java/io/trino/sql/planner/NodePartitioningManager.java
+++ b/core/trino-main/src/main/java/io/trino/sql/planner/NodePartitioningManager.java
@@ -51,7 +51,7 @@
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.collect.ImmutableList.toImmutableList;
-import static io.trino.SystemSessionProperties.getHashPartitionCount;
+import static io.trino.SystemSessionProperties.getMaxHashPartitionCount;
import static io.trino.spi.StandardErrorCode.NO_NODES_AVAILABLE;
import static io.trino.sql.planner.SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION;
import static io.trino.util.Failures.checkCondition;
@@ -134,7 +134,12 @@ public BucketFunction getBucketFunction(Session session, PartitioningHandle part
public NodePartitionMap getNodePartitioningMap(Session session, PartitioningHandle partitioningHandle)
{
- return getNodePartitioningMap(session, partitioningHandle, new HashMap<>(), new AtomicReference<>());
+ return getNodePartitioningMap(session, partitioningHandle, new HashMap<>(), new AtomicReference<>(), Optional.empty());
+ }
+
+ public NodePartitionMap getNodePartitioningMap(Session session, PartitioningHandle partitioningHandle, Optional<Integer> partitionCount)
+ {
+ return getNodePartitioningMap(session, partitioningHandle, new HashMap<>(), new AtomicReference<>(), partitionCount);
}
/**
@@ -145,22 +150,24 @@ private NodePartitionMap getNodePartitioningMap(
Session session,
PartitioningHandle partitioningHandle,
Map<Integer, List<InternalNode>> bucketToNodeCache,
- AtomicReference<List<InternalNode>> systemPartitioningCache)
+ AtomicReference<List<InternalNode>> systemPartitioningCache,
+ Optional<Integer> partitionCount)
{
requireNonNull(session, "session is null");
requireNonNull(partitioningHandle, "partitioningHandle is null");
if (partitioningHandle.getConnectorHandle() instanceof SystemPartitioningHandle) {
- return systemNodePartitionMap(session, partitioningHandle, systemPartitioningCache);
+ return systemNodePartitionMap(session, partitioningHandle, systemPartitioningCache, partitionCount);
}
if (partitioningHandle.getConnectorHandle() instanceof MergePartitioningHandle mergeHandle) {
- return mergeHandle.getNodePartitioningMap(handle -> getNodePartitioningMap(session, handle, bucketToNodeCache, systemPartitioningCache));
+ return mergeHandle.getNodePartitioningMap(handle ->
+ getNodePartitioningMap(session, handle, bucketToNodeCache, systemPartitioningCache, partitionCount));
}
Optional<ConnectorBucketNodeMap> optionalMap = getConnectorBucketNodeMap(session, partitioningHandle);
if (optionalMap.isEmpty()) {
- return systemNodePartitionMap(session, FIXED_HASH_DISTRIBUTION, systemPartitioningCache);
+ return systemNodePartitionMap(session, FIXED_HASH_DISTRIBUTION, systemPartitioningCache, partitionCount);
}
ConnectorBucketNodeMap connectorBucketNodeMap = optionalMap.get();
@@ -199,7 +206,11 @@ private NodePartitionMap getNodePartitioningMap(
return new NodePartitionMap(partitionToNode, bucketToPartition, getSplitToBucket(session, partitioningHandle));
}
- private NodePartitionMap systemNodePartitionMap(Session session, PartitioningHandle partitioningHandle, AtomicReference<List<InternalNode>> nodesCache)
+ private NodePartitionMap systemNodePartitionMap(
+ Session session,
+ PartitioningHandle partitioningHandle,
+ AtomicReference<List<InternalNode>> nodesCache,
+ Optional<Integer> partitionCount)
{
SystemPartitioning partitioning = ((SystemPartitioningHandle) partitioningHandle.getConnectorHandle()).getPartitioning();
@@ -211,7 +222,7 @@ private NodePartitionMap systemNodePartitionMap(Session session, PartitioningHandle partitioningHandle, AtomicReference<List<InternalNode>> nodesCache)
case FIXED -> {
List<InternalNode> value = nodesCache.get();
if (value == null) {
- value = nodeSelector.selectRandomNodes(getHashPartitionCount(session));
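+ // prefer the partition count chosen by the optimizer; fall back to the session's max hash partition count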
+ value = nodeSelector.selectRandomNodes(partitionCount.orElse(getMaxHashPartitionCount(session)));
nodesCache.set(value);
}
yield value;
@@ -239,15 +250,6 @@ public BucketNodeMap getBucketNodeMap(Session session, PartitioningHandle partitioningHandle)
return new BucketNodeMap(splitToBucket, createArbitraryBucketToNode(nodes, bucketCount));
}
- public int getBucketCount(Session session, PartitioningHandle partitioning)
- {
- if (partitioning.getConnectorHandle() instanceof MergePartitioningHandle) {
- // TODO: can we always use this code path?
- return getNodePartitioningMap(session, partitioning).getBucketToPartition().length;
- }
- return getBucketNodeMap(session, partitioning).getBucketCount();
- }
-
public int getNodeCount(Session session, PartitioningHandle partitioningHandle)
{
return getAllNodes(session, requiredCatalogHandle(partitioningHandle)).size();
diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/OptimizerConfig.java b/core/trino-main/src/main/java/io/trino/sql/planner/OptimizerConfig.java
index faf821a80da4..01339bb064d7 100644
--- a/core/trino-main/src/main/java/io/trino/sql/planner/OptimizerConfig.java
+++ b/core/trino-main/src/main/java/io/trino/sql/planner/OptimizerConfig.java
@@ -23,6 +23,7 @@
import javax.validation.constraints.Min;
import javax.validation.constraints.NotNull;
+import static io.airlift.units.DataSize.Unit.GIGABYTE;
import static io.airlift.units.DataSize.Unit.MEGABYTE;
import static java.util.Objects.requireNonNull;
import static java.util.concurrent.TimeUnit.MINUTES;
@@ -88,6 +89,8 @@ public class OptimizerConfig
private long adaptivePartialAggregationMinRows = 100_000;
private double adaptivePartialAggregationUniqueRowsRatioThreshold = 0.8;
private long joinPartitionedBuildMinRowCount = 1_000_000L;
+ private DataSize minInputSizePerTask = DataSize.of(5, GIGABYTE);
+ private long minInputRowsPerTask = 10_000_000L;
public enum JoinReorderingStrategy
{
@@ -744,6 +747,34 @@ public OptimizerConfig setJoinPartitionedBuildMinRowCount(long joinPartitionedBu
return this;
}
+ @NotNull
+ public DataSize getMinInputSizePerTask()
+ {
+ return minInputSizePerTask;
+ }
+
+ @Config("optimizer.min-input-size-per-task")
+ @ConfigDescription("Minimum input data size required per task. This will help optimizer determine hash partition count for joins and aggregations")
+ public OptimizerConfig setMinInputSizePerTask(DataSize minInputSizePerTask)
+ {
+ this.minInputSizePerTask = minInputSizePerTask;
+ return this;
+ }
+
+ @Min(0)
+ public long getMinInputRowsPerTask()
+ {
+ return minInputRowsPerTask;
+ }
+
+ @Config("optimizer.min-input-rows-per-task")
+ @ConfigDescription("Minimum input rows required per task. This will help optimizer determine hash partition count for joins and aggregations")
+ public OptimizerConfig setMinInputRowsPerTask(long minInputRowsPerTask)
+ {
+ this.minInputRowsPerTask = minInputRowsPerTask;
+ return this;
+ }
+
public boolean isUseExactPartitioning()
{
return useExactPartitioning;
diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/PartitioningScheme.java b/core/trino-main/src/main/java/io/trino/sql/planner/PartitioningScheme.java
index fad8e028ab08..76c5d00124af 100644
--- a/core/trino-main/src/main/java/io/trino/sql/planner/PartitioningScheme.java
+++ b/core/trino-main/src/main/java/io/trino/sql/planner/PartitioningScheme.java
@@ -34,6 +34,7 @@ public class PartitioningScheme
private final Optional<Symbol> hashColumn;
private final boolean replicateNullsAndAny;
private final Optional<int[]> bucketToPartition;
+ private final Optional<Integer> partitionCount;
public PartitioningScheme(Partitioning partitioning, List<Symbol> outputLayout)
{
@@ -42,6 +43,7 @@ public PartitioningScheme(Partitioning partitioning, List<Symbol> outputLayout)
outputLayout,
Optional.empty(),
false,
+ Optional.empty(),
Optional.empty());
}
@@ -52,6 +54,7 @@ public PartitioningScheme(Partitioning partitioning, List<Symbol> outputLayout,
outputLayout,
hashColumn,
false,
+ Optional.empty(),
Optional.empty());
}
@@ -61,7 +64,8 @@ public PartitioningScheme(
@JsonProperty("outputLayout") List outputLayout,
@JsonProperty("hashColumn") Optional hashColumn,
@JsonProperty("replicateNullsAndAny") boolean replicateNullsAndAny,
- @JsonProperty("bucketToPartition") Optional bucketToPartition)
+ @JsonProperty("bucketToPartition") Optional bucketToPartition,
+ @JsonProperty("partitionCount") Optional partitionCount)
{
this.partitioning = requireNonNull(partitioning, "partitioning is null");
this.outputLayout = ImmutableList.copyOf(requireNonNull(outputLayout, "outputLayout is null"));
@@ -77,6 +81,10 @@ public PartitioningScheme(
checkArgument(!replicateNullsAndAny || columns.size() <= 1, "Must have at most one partitioning column when nullPartition is REPLICATE.");
this.replicateNullsAndAny = replicateNullsAndAny;
this.bucketToPartition = requireNonNull(bucketToPartition, "bucketToPartition is null");
+ this.partitionCount = requireNonNull(partitionCount, "partitionCount is null");
+ checkArgument(
+ partitionCount.isEmpty() || partitioning.getHandle().getConnectorHandle() instanceof SystemPartitioningHandle,
+ "Connector partitioning handle should be of type system partitioning when partitionCount is present");
}
@JsonProperty
@@ -109,15 +117,26 @@ public Optional<int[]> getBucketToPartition()
return bucketToPartition;
}
+ @JsonProperty
+ public Optional<Integer> getPartitionCount()
+ {
+ return partitionCount;
+ }
+
public PartitioningScheme withBucketToPartition(Optional<int[]> bucketToPartition)
{
- return new PartitioningScheme(partitioning, outputLayout, hashColumn, replicateNullsAndAny, bucketToPartition);
+ return new PartitioningScheme(partitioning, outputLayout, hashColumn, replicateNullsAndAny, bucketToPartition, partitionCount);
}
public PartitioningScheme withPartitioningHandle(PartitioningHandle partitioningHandle)
{
Partitioning newPartitioning = partitioning.withAlternativePartitioningHandle(partitioningHandle);
- return new PartitioningScheme(newPartitioning, outputLayout, hashColumn, replicateNullsAndAny, bucketToPartition);
+ return new PartitioningScheme(newPartitioning, outputLayout, hashColumn, replicateNullsAndAny, bucketToPartition, partitionCount);
+ }
+
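+ // partitionCount may only be set for system partitioning (enforced in the constructor)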
+ public PartitioningScheme withPartitionCount(int partitionCount)
+ {
+ return new PartitioningScheme(partitioning, outputLayout, hashColumn, replicateNullsAndAny, bucketToPartition, Optional.of(partitionCount));
}
public PartitioningScheme translateOutputLayout(List<Symbol> newOutputLayout)
@@ -132,7 +151,7 @@ public PartitioningScheme translateOutputLayout(List<Symbol> newOutputLayout)
.map(outputLayout::indexOf)
.map(newOutputLayout::get);
- return new PartitioningScheme(newPartitioning, newOutputLayout, newHashSymbol, replicateNullsAndAny, bucketToPartition);
+ return new PartitioningScheme(newPartitioning, newOutputLayout, newHashSymbol, replicateNullsAndAny, bucketToPartition, partitionCount);
}
@Override
@@ -148,13 +167,14 @@ public boolean equals(Object o)
return Objects.equals(partitioning, that.partitioning) &&
Objects.equals(outputLayout, that.outputLayout) &&
replicateNullsAndAny == that.replicateNullsAndAny &&
- Objects.equals(bucketToPartition, that.bucketToPartition);
+ Objects.equals(bucketToPartition, that.bucketToPartition) &&
+ Objects.equals(partitionCount, that.partitionCount);
}
@Override
public int hashCode()
{
- return Objects.hash(partitioning, outputLayout, replicateNullsAndAny, bucketToPartition);
+ return Objects.hash(partitioning, outputLayout, replicateNullsAndAny, bucketToPartition, partitionCount);
}
@Override
@@ -166,6 +186,7 @@ public String toString()
.add("hashChannel", hashColumn)
.add("replicateNullsAndAny", replicateNullsAndAny)
.add("bucketToPartition", bucketToPartition)
+ .add("partitionCount", partitionCount)
.toString();
}
}
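For orientation, a minimal sketch of how the new field travels, assuming partitioningSymbols and outputLayout are in scope (a hypothetical call site, not part of this change): partitionCount starts empty and is filled in later by the DeterminePartitionCount optimizer added below.

    PartitioningScheme scheme = new PartitioningScheme(
            Partitioning.create(FIXED_HASH_DISTRIBUTION, partitioningSymbols),
            outputLayout,
            Optional.empty(),  // hashColumn
            false,             // replicateNullsAndAny
            Optional.empty(),  // bucketToPartition
            Optional.empty()); // partitionCount: only legal with a SystemPartitioningHandle
    PartitioningScheme resized = scheme.withPartitionCount(32); // getPartitionCount() => Optional.of(32)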
diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/PlanFragmenter.java b/core/trino-main/src/main/java/io/trino/sql/planner/PlanFragmenter.java
index b28cf94c3689..512a4ac4ae07 100644
--- a/core/trino-main/src/main/java/io/trino/sql/planner/PlanFragmenter.java
+++ b/core/trino-main/src/main/java/io/trino/sql/planner/PlanFragmenter.java
@@ -188,7 +188,8 @@ private SubPlan reassignPartitioningHandleIfNecessaryHelper(Session session, Sub
outputPartitioningScheme.getOutputLayout(),
outputPartitioningScheme.getHashColumn(),
outputPartitioningScheme.isReplicateNullsAndAny(),
- outputPartitioningScheme.getBucketToPartition()),
+ outputPartitioningScheme.getBucketToPartition(),
+ outputPartitioningScheme.getPartitionCount()),
fragment.getStatsAndCosts(),
fragment.getActiveCatalogs(),
fragment.getJsonRepresentation());
diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/PlanOptimizers.java b/core/trino-main/src/main/java/io/trino/sql/planner/PlanOptimizers.java
index b50ee5f15513..71e27a4b3a39 100644
--- a/core/trino-main/src/main/java/io/trino/sql/planner/PlanOptimizers.java
+++ b/core/trino-main/src/main/java/io/trino/sql/planner/PlanOptimizers.java
@@ -236,6 +236,8 @@
import io.trino.sql.planner.optimizations.AddLocalExchanges;
import io.trino.sql.planner.optimizations.BeginTableWrite;
import io.trino.sql.planner.optimizations.CheckSubqueryNodesAreRewritten;
+import io.trino.sql.planner.optimizations.DeterminePartitionCount;
+import io.trino.sql.planner.optimizations.DetermineWritersNodesCount;
import io.trino.sql.planner.optimizations.HashGenerationOptimizer;
import io.trino.sql.planner.optimizations.IndexJoinOptimizer;
import io.trino.sql.planner.optimizations.LimitPushDown;
@@ -841,6 +843,9 @@ public PlanOptimizers(
// operators that require node partitioning
builder.add(new UnaliasSymbolReferences(metadata));
builder.add(new StatsRecordingPlanOptimizer(optimizerStats, new AddExchanges(plannerContext, typeAnalyzer, statsCalculator)));
+ // It can only run after AddExchanges since it estimates the hash partition count for all remote exchanges
+ builder.add(new StatsRecordingPlanOptimizer(optimizerStats, new DeterminePartitionCount(statsCalculator)));
+ builder.add(new DetermineWritersNodesCount());
}
// use cost calculator without estimated exchanges after AddExchanges
@@ -859,7 +864,6 @@ public PlanOptimizers(
.build()));
// Run predicate push down one more time in case we can leverage new information from layouts' effective predicate
- // and to pushdown dynamic filters
builder.add(new StatsRecordingPlanOptimizer(
optimizerStats,
new PredicatePushDown(plannerContext, typeAnalyzer, true, false)));
@@ -881,7 +885,7 @@ public PlanOptimizers(
ImmutableSet.copyOf(new PushInequalityFilterExpressionBelowJoinRuleSet(metadata, typeAnalyzer).rules())));
// Projection pushdown rules may push reducing projections (e.g. dereferences) below filters for potential
// pushdown into the connectors. Invoke PredicatePushdown and PushPredicateIntoTableScan after this
- // to leverage predicate pushdown on projected columns.
+ // to leverage predicate pushdown on projected columns and to push down dynamic filters.
builder.add(new StatsRecordingPlanOptimizer(optimizerStats, new PredicatePushDown(plannerContext, typeAnalyzer, true, true)));
builder.add(new RemoveUnsupportedDynamicFilters(plannerContext)); // Remove unsupported dynamic filters introduced by PredicatePushdown
builder.add(new IterativeOptimizer(
diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/QueryPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/QueryPlanner.java
index b1fbf93b1e42..70bc48c193ff 100644
--- a/core/trino-main/src/main/java/io/trino/sql/planner/QueryPlanner.java
+++ b/core/trino-main/src/main/java/io/trino/sql/planner/QueryPlanner.java
@@ -1033,7 +1033,7 @@ private PlanBuilder filter(PlanBuilder subPlan, Expression predicate, Node node)
subPlan = subqueryPlanner.handleSubqueries(subPlan, predicate, analysis.getSubqueries(node));
- return subPlan.withNewRoot(new FilterNode(idAllocator.getNextId(), subPlan.getRoot(), subPlan.rewrite(predicate)));
+ return subPlan.withNewRoot(new FilterNode(idAllocator.getNextId(), subPlan.getRoot(), coerceIfNecessary(analysis, predicate, subPlan.rewrite(predicate))));
}
private PlanBuilder aggregate(PlanBuilder subPlan, QuerySpecification node)
diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java
index d3e7c56fb3dd..756841691a25 100644
--- a/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java
+++ b/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java
@@ -65,6 +65,7 @@
import io.trino.sql.tree.Except;
import io.trino.sql.tree.Expression;
import io.trino.sql.tree.Identifier;
+import io.trino.sql.tree.IfExpression;
import io.trino.sql.tree.Intersect;
import io.trino.sql.tree.Join;
import io.trino.sql.tree.JoinCriteria;
@@ -112,10 +113,13 @@
import static com.google.common.collect.ImmutableMap.toImmutableMap;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static com.google.common.collect.Iterables.getOnlyElement;
+import static io.trino.spi.StandardErrorCode.CONSTRAINT_VIOLATION;
import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED;
+import static io.trino.spi.type.BooleanType.BOOLEAN;
import static io.trino.sql.NodeUtils.getSortItemsFromOrderBy;
import static io.trino.sql.analyzer.SemanticExceptions.semanticException;
import static io.trino.sql.analyzer.TypeSignatureTranslator.toSqlType;
+import static io.trino.sql.planner.LogicalPlanner.failFunction;
import static io.trino.sql.planner.PlanBuilder.newPlanBuilder;
import static io.trino.sql.planner.QueryPlanner.coerce;
import static io.trino.sql.planner.QueryPlanner.coerceIfNecessary;
@@ -279,7 +283,7 @@ public RelationPlan addRowFilters(Table node, RelationPlan plan, Function<Table, Scope> accessControlScope)
+ private RelationPlan addCheckConstraints(List<Expression> constraints, Table node, RelationPlan plan, Function<Table, Scope> accessControlScope)
+ {
+ if (constraints.isEmpty()) {
+ return plan;
+ }
+
+ PlanBuilder planBuilder = newPlanBuilder(plan, analysis, lambdaDeclarationToSymbolMap, session, plannerContext)
+ .withScope(accessControlScope.apply(node), plan.getFieldMappings()); // The fields in the access control scope have the same layout as those for the table scope
+
+ for (Expression constraint : constraints) {
+ planBuilder = subqueryPlanner.handleSubqueries(planBuilder, constraint, analysis.getSubqueries(constraint));
+
+ Expression predicate = new IfExpression(
+ // When predicate evaluates to UNKNOWN (e.g. NULL > 100), it should not violate the check constraint.
+ new CoalesceExpression(coerceIfNecessary(analysis, constraint, planBuilder.rewrite(constraint)), TRUE_LITERAL),
+ TRUE_LITERAL,
+ new Cast(failFunction(plannerContext.getMetadata(), session, CONSTRAINT_VIOLATION, "Check constraint violation: " + constraint), toSqlType(BOOLEAN)));
+
+ planBuilder = planBuilder.withNewRoot(new FilterNode(
+ idAllocator.getNextId(),
+ planBuilder.getRoot(),
+ predicate));
+ }
+
+ return new RelationPlan(planBuilder.getRoot(), plan.getScope(), plan.getFieldMappings(), outerContext);
+ }
+
private RelationPlan addColumnMasks(Table table, RelationPlan plan)
{
Map columnMasks = analysis.getColumnMasks(table);
@@ -809,7 +840,7 @@ else if (firstDependencies.stream().allMatch(right::canResolve) && secondDepende
rootPlanBuilder = subqueryPlanner.handleSubqueries(rootPlanBuilder, complexJoinExpressions, subqueries);
for (Expression expression : complexJoinExpressions) {
- postInnerJoinConditions.add(rootPlanBuilder.rewrite(expression));
+ postInnerJoinConditions.add(coerceIfNecessary(analysis, expression, rootPlanBuilder.rewrite(expression)));
}
root = rootPlanBuilder.getRoot();
@@ -994,7 +1025,7 @@ private RelationPlan planCorrelatedJoin(Join join, RelationPlan leftPlan, Latera
.withAdditionalMappings(leftPlanBuilder.getTranslations().getMappings())
.withAdditionalMappings(rightPlanBuilder.getTranslations().getMappings());
- Expression rewrittenFilterCondition = translationMap.rewrite(filterExpression);
+ Expression rewrittenFilterCondition = coerceIfNecessary(analysis, filterExpression, translationMap.rewrite(filterExpression));
PlanBuilder planBuilder = subqueryPlanner.appendCorrelatedJoin(
leftPlanBuilder,
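The check-constraint filter built in addCheckConstraints treats UNKNOWN as satisfied: the rewritten constraint is wrapped in COALESCE(..., true), so only a definite false reaches the fail function. A small model of the per-row semantics, with checkRow as a hypothetical stand-in (a null Boolean models SQL UNKNOWN, e.g. NULL > 100):

    // Models IF(coalesce(constraint, true), true, CAST(fail(...) AS boolean))
    static boolean checkRow(Boolean constraintResult)
    {
        if (constraintResult == null || constraintResult) {
            return true; // UNKNOWN or true: the row does not violate the constraint
        }
        throw new IllegalStateException("Check constraint violation");
    }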
diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/AddExchangesBelowPartialAggregationOverGroupIdRuleSet.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/AddExchangesBelowPartialAggregationOverGroupIdRuleSet.java
index a666f6daeec5..f0bbd4088452 100644
--- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/AddExchangesBelowPartialAggregationOverGroupIdRuleSet.java
+++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/AddExchangesBelowPartialAggregationOverGroupIdRuleSet.java
@@ -98,11 +98,12 @@ public class AddExchangesBelowPartialAggregationOverGroupIdRuleSet
private static final Capture<ProjectNode> PROJECTION = newCapture();
private static final Capture<AggregationNode> AGGREGATION = newCapture();
private static final Capture<GroupIdNode> GROUP_ID = newCapture();
+ private static final Capture<ExchangeNode> REMOTE_EXCHANGE = newCapture();
private static final Pattern<ExchangeNode> WITH_PROJECTION =
// If there was no exchange here, adding new exchanges could break property derivations logic of AddExchanges, AddLocalExchanges
typeOf(ExchangeNode.class)
- .with(scope().equalTo(REMOTE))
+ .with(scope().equalTo(REMOTE)).capturedAs(REMOTE_EXCHANGE)
.with(source().matching(
// PushPartialAggregationThroughExchange adds a projection. However, it can be removed if RemoveRedundantIdentityProjections is run in the mean-time.
typeOf(ProjectNode.class).capturedAs(PROJECTION)
@@ -116,7 +117,7 @@ public class AddExchangesBelowPartialAggregationOverGroupIdRuleSet
private static final Pattern<ExchangeNode> WITHOUT_PROJECTION =
// If there was no exchange here, adding new exchanges could break property derivations logic of AddExchanges, AddLocalExchanges
typeOf(ExchangeNode.class)
- .with(scope().equalTo(REMOTE))
+ .with(scope().equalTo(REMOTE)).capturedAs(REMOTE_EXCHANGE)
.with(source().matching(
typeOf(AggregationNode.class).capturedAs(AGGREGATION)
.with(step().equalTo(AggregationNode.Step.PARTIAL))
@@ -166,7 +167,8 @@ public Result apply(ExchangeNode exchange, Captures captures, Context context)
ProjectNode project = captures.get(PROJECTION);
AggregationNode aggregation = captures.get(AGGREGATION);
GroupIdNode groupId = captures.get(GROUP_ID);
- return transform(aggregation, groupId, context)
+ ExchangeNode remoteExchange = captures.get(REMOTE_EXCHANGE);
+ return transform(aggregation, groupId, remoteExchange.getPartitioningScheme().getPartitionCount(), context)
.map(newAggregation -> Result.ofPlanNode(
exchange.replaceChildren(ImmutableList.of(
project.replaceChildren(ImmutableList.of(
@@ -189,7 +191,8 @@ public Result apply(ExchangeNode exchange, Captures captures, Context context)
{
AggregationNode aggregation = captures.get(AGGREGATION);
GroupIdNode groupId = captures.get(GROUP_ID);
- return transform(aggregation, groupId, context)
+ ExchangeNode remoteExchange = captures.get(REMOTE_EXCHANGE);
+ return transform(aggregation, groupId, remoteExchange.getPartitioningScheme().getPartitionCount(), context)
.map(newAggregation -> {
PlanNode newExchange = exchange.replaceChildren(ImmutableList.of(newAggregation));
return Result.ofPlanNode(newExchange);
@@ -212,7 +215,7 @@ public boolean isEnabled(Session session)
return isEnableForcedExchangeBelowGroupId(session);
}
- protected Optional<PlanNode> transform(AggregationNode aggregation, GroupIdNode groupId, Context context)
+ protected Optional<PlanNode> transform(AggregationNode aggregation, GroupIdNode groupId, Optional<Integer> partitionCount, Context context)
{
if (groupId.getGroupingSets().size() < 2) {
return Optional.empty();
@@ -276,7 +279,12 @@ protected Optional<PlanNode> transform(AggregationNode aggregation, GroupIdNode
source,
new PartitioningScheme(
Partitioning.create(FIXED_HASH_DISTRIBUTION, desiredHashSymbols),
- source.getOutputSymbols()));
+ source.getOutputSymbols(),
+ Optional.empty(),
+ false,
+ Optional.empty(),
+ // It's fine to reuse partitionCount since that is computed by considering all the expanding nodes and table scans in a query
+ partitionCount));
source = partitionedExchange(
context.getIdAllocator().getNextId(),
diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ApplyPreferredTableExecutePartitioning.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ApplyPreferredTableExecutePartitioning.java
index 285929c303a8..e3105bac822b 100644
--- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ApplyPreferredTableExecutePartitioning.java
+++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ApplyPreferredTableExecutePartitioning.java
@@ -61,11 +61,6 @@ public Result apply(TableExecuteNode node, Captures captures, Context context)
return enable(node);
}
int minimumNumberOfPartitions = getPreferredWritePartitioningMinNumberOfPartitions(context.getSession());
- if (minimumNumberOfPartitions <= 1) {
- // Force 'preferred write partitioning' even if stats are missing or broken
- return enable(node);
- }
-
double expectedNumberOfPartitions = getRowsCount(
context.getStatsProvider().getStats(node.getSource()),
node.getPreferredPartitioningScheme().get().getPartitioning().getColumns());
diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ApplyPreferredTableWriterPartitioning.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ApplyPreferredTableWriterPartitioning.java
index 9745122240fe..e2147052880d 100644
--- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ApplyPreferredTableWriterPartitioning.java
+++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ApplyPreferredTableWriterPartitioning.java
@@ -67,11 +67,6 @@ public Result apply(TableWriterNode node, Captures captures, Context context)
}
int minimumNumberOfPartitions = getPreferredWritePartitioningMinNumberOfPartitions(context.getSession());
- if (minimumNumberOfPartitions <= 1) {
- // Force 'preferred write partitioning' even if stats are missing or broken
- return enable(node);
- }
-
double expectedNumberOfPartitions = getRowsCount(
context.getStatsProvider().getStats(node.getSource()),
node.getPreferredPartitioningScheme().get().getPartitioning().getColumns());
diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PruneExchangeColumns.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PruneExchangeColumns.java
index fc6d1ca38f9f..2122ec834764 100644
--- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PruneExchangeColumns.java
+++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PruneExchangeColumns.java
@@ -105,7 +105,8 @@ protected Optional<PlanNode> pushDownProjectOff(Context context, ExchangeNode ex
newOutputs.build(),
exchangeNode.getPartitioningScheme().getHashColumn(),
exchangeNode.getPartitioningScheme().isReplicateNullsAndAny(),
- exchangeNode.getPartitioningScheme().getBucketToPartition());
+ exchangeNode.getPartitioningScheme().getBucketToPartition(),
+ exchangeNode.getPartitioningScheme().getPartitionCount());
return Optional.of(new ExchangeNode(
exchangeNode.getId(),
diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushPartialAggregationThroughExchange.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushPartialAggregationThroughExchange.java
index 97c28495a32b..5fc335b8a6c5 100644
--- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushPartialAggregationThroughExchange.java
+++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushPartialAggregationThroughExchange.java
@@ -178,7 +178,8 @@ private PlanNode pushPartial(AggregationNode aggregation, ExchangeNode exchange,
aggregation.getOutputSymbols(),
exchange.getPartitioningScheme().getHashColumn(),
exchange.getPartitioningScheme().isReplicateNullsAndAny(),
- exchange.getPartitioningScheme().getBucketToPartition());
+ exchange.getPartitioningScheme().getBucketToPartition(),
+ exchange.getPartitioningScheme().getPartitionCount());
return new ExchangeNode(
context.getIdAllocator().getNextId(),
diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushProjectionThroughExchange.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushProjectionThroughExchange.java
index 26edf6900b8b..57d4ff82ea77 100644
--- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushProjectionThroughExchange.java
+++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushProjectionThroughExchange.java
@@ -172,7 +172,8 @@ public Result apply(ProjectNode project, Captures captures, Context context)
outputBuilder.build(),
exchange.getPartitioningScheme().getHashColumn(),
exchange.getPartitioningScheme().isReplicateNullsAndAny(),
- exchange.getPartitioningScheme().getBucketToPartition());
+ exchange.getPartitioningScheme().getBucketToPartition(),
+ exchange.getPartitioningScheme().getPartitionCount());
PlanNode result = new ExchangeNode(
exchange.getId(),
diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushRemoteExchangeThroughAssignUniqueId.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushRemoteExchangeThroughAssignUniqueId.java
index 64eb1a387672..44f5dcd887ff 100644
--- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushRemoteExchangeThroughAssignUniqueId.java
+++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushRemoteExchangeThroughAssignUniqueId.java
@@ -80,7 +80,8 @@ public Result apply(ExchangeNode node, Captures captures, Context context)
removeSymbol(partitioningScheme.getOutputLayout(), assignUniqueId.getIdColumn()),
partitioningScheme.getHashColumn(),
partitioningScheme.isReplicateNullsAndAny(),
- partitioningScheme.getBucketToPartition()),
+ partitioningScheme.getBucketToPartition(),
+ partitioningScheme.getPartitionCount()),
ImmutableList.of(assignUniqueId.getSource()),
ImmutableList.of(removeSymbol(getOnlyElement(node.getInputs()), assignUniqueId.getIdColumn())),
Optional.empty()),
diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RemoveTrivialFilters.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RemoveTrivialFilters.java
index 1277f932c688..29e88d5f5820 100644
--- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RemoveTrivialFilters.java
+++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RemoveTrivialFilters.java
@@ -18,9 +18,11 @@
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.plan.FilterNode;
import io.trino.sql.planner.plan.ValuesNode;
+import io.trino.sql.tree.Cast;
import io.trino.sql.tree.Expression;
import io.trino.sql.tree.NullLiteral;
+import static com.google.common.base.Preconditions.checkArgument;
import static io.trino.sql.planner.plan.Patterns.filter;
import static io.trino.sql.tree.BooleanLiteral.FALSE_LITERAL;
import static io.trino.sql.tree.BooleanLiteral.TRUE_LITERAL;
@@ -41,12 +43,14 @@ public Pattern getPattern()
public Result apply(FilterNode filterNode, Captures captures, Context context)
{
Expression predicate = filterNode.getPredicate();
+ checkArgument(!(predicate instanceof NullLiteral), "Unexpected null literal without a cast to boolean");
if (predicate.equals(TRUE_LITERAL)) {
return Result.ofPlanNode(filterNode.getSource());
}
- if (predicate.equals(FALSE_LITERAL) || predicate instanceof NullLiteral) {
+ if (predicate.equals(FALSE_LITERAL) ||
+ (predicate instanceof Cast cast && cast.getExpression() instanceof NullLiteral)) {
return Result.ofPlanNode(new ValuesNode(context.getIdAllocator().getNextId(), filterNode.getOutputSymbols(), emptyList()));
}
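With the NullLiteral guards added to FilterNode and JoinNode below, an always-null predicate can only reach this rule wrapped in a boolean cast, which the new branch recognizes. An illustrative check, reusing tree constructors seen earlier in this patch:

    // A trivially-null predicate is now represented as CAST(NULL AS boolean)
    Expression predicate = new Cast(new NullLiteral(), toSqlType(BOOLEAN));
    // matches the new branch, so the filter collapses to an empty ValuesNode
    boolean prunable = predicate instanceof Cast cast && cast.getExpression() instanceof NullLiteral; // true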
diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RewriteSpatialPartitioningAggregation.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RewriteSpatialPartitioningAggregation.java
index 0e64c07fdfd5..0bd95781c443 100644
--- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RewriteSpatialPartitioningAggregation.java
+++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RewriteSpatialPartitioningAggregation.java
@@ -38,7 +38,7 @@
import static com.google.common.collect.Iterables.getOnlyElement;
import static io.trino.SystemSessionProperties.getFaultTolerantExecutionPartitionCount;
-import static io.trino.SystemSessionProperties.getHashPartitionCount;
+import static io.trino.SystemSessionProperties.getMaxHashPartitionCount;
import static io.trino.SystemSessionProperties.getRetryPolicy;
import static io.trino.spi.type.IntegerType.INTEGER;
import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypeSignatures;
@@ -130,7 +130,7 @@ public Result apply(AggregationNode node, Captures captures, Context context)
partitionCount = getFaultTolerantExecutionPartitionCount(context.getSession());
}
else {
- partitionCount = getHashPartitionCount(context.getSession());
+ partitionCount = getMaxHashPartitionCount(context.getSession());
}
return Result.ofPlanNode(
AggregationNode.builderFrom(node)
diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/AddExchanges.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/AddExchanges.java
index 667ab168b37c..5ac1f938156d 100644
--- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/AddExchanges.java
+++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/AddExchanges.java
@@ -975,6 +975,7 @@ public PlanWithProperties visitSemiJoin(SemiJoinNode node, PreferredProperties p
filteringSource.getNode().getOutputSymbols(),
Optional.empty(),
true,
+ Optional.empty(),
Optional.empty())),
filteringSource.getProperties());
}
@@ -1009,6 +1010,7 @@ public PlanWithProperties visitSemiJoin(SemiJoinNode node, PreferredProperties p
filteringSource.getNode().getOutputSymbols(),
Optional.empty(),
true,
+ Optional.empty(),
Optional.empty())),
filteringSource.getProperties());
}
@@ -1179,6 +1181,7 @@ public PlanWithProperties visitUnion(UnionNode node, PreferredProperties parentP
source.getNode().getOutputSymbols(),
Optional.empty(),
nullsAndAnyReplicated,
+ Optional.empty(),
Optional.empty())),
source.getProperties());
}
diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/DeterminePartitionCount.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/DeterminePartitionCount.java
new file mode 100644
index 000000000000..3aa4593cc2ca
--- /dev/null
+++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/DeterminePartitionCount.java
@@ -0,0 +1,321 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.trino.sql.planner.optimizations;
+
+import com.google.common.collect.ImmutableList;
+import io.airlift.log.Logger;
+import io.trino.Session;
+import io.trino.cost.CachingStatsProvider;
+import io.trino.cost.StatsCalculator;
+import io.trino.cost.StatsProvider;
+import io.trino.cost.TableStatsProvider;
+import io.trino.execution.warnings.WarningCollector;
+import io.trino.operator.RetryPolicy;
+import io.trino.sql.planner.PartitioningHandle;
+import io.trino.sql.planner.PlanNodeIdAllocator;
+import io.trino.sql.planner.SymbolAllocator;
+import io.trino.sql.planner.SystemPartitioningHandle;
+import io.trino.sql.planner.TypeProvider;
+import io.trino.sql.planner.plan.ExchangeNode;
+import io.trino.sql.planner.plan.JoinNode;
+import io.trino.sql.planner.plan.MergeWriterNode;
+import io.trino.sql.planner.plan.PlanNode;
+import io.trino.sql.planner.plan.SimplePlanRewriter;
+import io.trino.sql.planner.plan.TableExecuteNode;
+import io.trino.sql.planner.plan.TableScanNode;
+import io.trino.sql.planner.plan.TableWriterNode;
+import io.trino.sql.planner.plan.UnionNode;
+import io.trino.sql.planner.plan.UnnestNode;
+import io.trino.sql.planner.plan.ValuesNode;
+
+import java.util.List;
+import java.util.Optional;
+import java.util.function.ToDoubleFunction;
+
+import static com.google.common.collect.ImmutableList.toImmutableList;
+import static io.trino.SystemSessionProperties.MAX_WRITERS_NODES_COUNT;
+import static io.trino.SystemSessionProperties.getMaxHashPartitionCount;
+import static io.trino.SystemSessionProperties.getMinHashPartitionCount;
+import static io.trino.SystemSessionProperties.getMinInputRowsPerTask;
+import static io.trino.SystemSessionProperties.getMinInputSizePerTask;
+import static io.trino.SystemSessionProperties.getQueryMaxMemoryPerNode;
+import static io.trino.SystemSessionProperties.getRetryPolicy;
+import static io.trino.sql.planner.optimizations.QueryCardinalityUtil.isAtMostScalar;
+import static io.trino.sql.planner.plan.ExchangeNode.Scope.REMOTE;
+import static io.trino.sql.planner.plan.SimplePlanRewriter.rewriteWith;
+import static java.lang.Double.isNaN;
+import static java.lang.Math.incrementExact;
+import static java.lang.Math.max;
+import static java.util.Objects.requireNonNull;
+
+/**
+ * This rule looks at the amount of data read and processed by the query to determine the partition count
+ * used for remote exchanges. It helps to increase the concurrency of the engine on large clusters.
+ * This rule is also cautious about missing or incorrect statistics, therefore it skips plans containing
+ * input-multiplying nodes like CROSS JOIN or UNNEST.
+ *
+ * E.g. 1:
+ * Given query: SELECT count(column_a) FROM table_with_stats_a
+ * config:
+ * MIN_INPUT_SIZE_PER_TASK: 500 MB
+ * Input table data size: 1000 MB
+ * Estimated partition count: Input table data size / MIN_INPUT_SIZE_PER_TASK => 2
+ *
+ * E.g. 2:
+ * Given query: SELECT * FROM table_with_stats_a as a JOIN table_with_stats_b as b ON a.column_b = b.column_b
+ * config:
+ * MIN_INPUT_SIZE_PER_TASK: 500 MB
+ * Input tables data size: 1000 MB
+ * Join output data size: 5000 MB
+ * Estimated partition count: max((Input table data size / MIN_INPUT_SIZE_PER_TASK), (Join output data size / MIN_INPUT_SIZE_PER_TASK)) => 10
+ */
+public class DeterminePartitionCount
+ implements PlanOptimizer
+{
+ private static final Logger log = Logger.get(DeterminePartitionCount.class);
+ private static final List<Class<? extends PlanNode>> SKIP_PLAN_NODES = ImmutableList.of(TableExecuteNode.class, MergeWriterNode.class);
+
+ private final StatsCalculator statsCalculator;
+
+ public DeterminePartitionCount(StatsCalculator statsCalculator)
+ {
+ this.statsCalculator = requireNonNull(statsCalculator, "statsCalculator is null");
+ }
+
+ @Override
+ public PlanNode optimize(
+ PlanNode plan,
+ Session session,
+ TypeProvider types,
+ SymbolAllocator symbolAllocator,
+ PlanNodeIdAllocator idAllocator,
+ WarningCollector warningCollector,
+ TableStatsProvider tableStatsProvider)
+ {
+ requireNonNull(plan, "plan is null");
+ requireNonNull(session, "session is null");
+ requireNonNull(types, "types is null");
+ requireNonNull(tableStatsProvider, "tableStatsProvider is null");
+
+ // Skip for write nodes since writing partitioned data with a small number of nodes could cause
+ // memory-related issues even when the amount of data is small. Additionally, skip for FTE mode since
+ // the estimated partitionCount is not used by the FTE scheduler.
+ if (PlanNodeSearcher.searchFrom(plan).whereIsInstanceOfAny(SKIP_PLAN_NODES).matches()
+ || getRetryPolicy(session) == RetryPolicy.TASK) {
+ return plan;
+ }
+
+ List<PlanNode> tableWriterSources = PlanNodeSearcher
+ .searchFrom(plan)
+ .recurseOnlyWhen(planNode -> planNode instanceof TableWriterNode)
+ .where(planNode -> planNode instanceof ExchangeNode)
+ .findAll();
+
+ Optional<Integer> partitionCount = tableWriterSources.isEmpty()
+ ? determinePartitionCount(plan, session, types, tableStatsProvider)
+ : Optional.of(session.getSystemProperty(MAX_WRITERS_NODES_COUNT, Integer.class));
+ try {
+ return partitionCount
+ .map(count -> rewriteWith(new Rewriter(count), plan))
+ .orElse(plan);
+ }
+ catch (RuntimeException e) {
+ log.warn(e, "Error occurred when determining hash partition count for query %s", session.getQueryId());
+ }
+
+ return plan;
+ }
+
+ private Optional<Integer> determinePartitionCount(PlanNode plan, Session session, TypeProvider types, TableStatsProvider tableStatsProvider)
+ {
+ long minInputSizePerTask = getMinInputSizePerTask(session).toBytes();
+ long minInputRowsPerTask = getMinInputRowsPerTask(session);
+ if (minInputSizePerTask == 0 || minInputRowsPerTask == 0) {
+ return Optional.empty();
+ }
+
+ // Skip for expanding plan nodes like CROSS JOIN or UNNEST which can substantially increase the amount of data.
+ if (isInputMultiplyingPlanNodePresent(plan)) {
+ return Optional.empty();
+ }
+
+ StatsProvider statsProvider = new CachingStatsProvider(statsCalculator, session, types, tableStatsProvider);
+ long queryMaxMemoryPerNode = getQueryMaxMemoryPerNode(session).toBytes();
+
+ // Calculate partition count based on nodes output data size and rows
+ Optional<Integer> partitionCountBasedOnOutputSize = getPartitionCountBasedOnOutputSize(
+ plan,
+ statsProvider,
+ types,
+ minInputSizePerTask,
+ queryMaxMemoryPerNode);
+ Optional<Integer> partitionCountBasedOnRows = getPartitionCountBasedOnRows(plan, statsProvider, minInputRowsPerTask);
+
+ if (partitionCountBasedOnOutputSize.isEmpty() || partitionCountBasedOnRows.isEmpty()) {
+ return Optional.empty();
+ }
+
+ int partitionCount = max(
+ // Consider both output size and row count to estimate the partition count. This is essential
+ // because a huge number of small rows can be CPU-intensive for some operators, while on the other
+ // hand a small number of large rows can be memory-intensive.
+ max(partitionCountBasedOnOutputSize.get(), partitionCountBasedOnRows.get()),
+ getMinHashPartitionCount(session));
+
+ int maxHashPartitionCount = getMaxHashPartitionCount(session);
+ if (partitionCount >= maxHashPartitionCount) {
+ return Optional.empty();
+ }
+
+ log.debug("Estimated remote exchange partition count for query %s is %s", session.getQueryId(), partitionCount);
+ return Optional.of(partitionCount);
+ }
+
+ private static Optional<Integer> getPartitionCountBasedOnOutputSize(
+ PlanNode plan,
+ StatsProvider statsProvider,
+ TypeProvider types,
+ long minInputSizePerTask,
+ long queryMaxMemoryPerNode)
+ {
+ double sourceTablesOutputSize = getSourceNodesOutputStats(
+ plan,
+ node -> statsProvider.getStats(node).getOutputSizeInBytes(node.getOutputSymbols(), types));
+ double expandingNodesMaxOutputSize = getExpandingNodesMaxOutputStats(
+ plan,
+ node -> statsProvider.getStats(node).getOutputSizeInBytes(node.getOutputSymbols(), types));
+ if (isNaN(sourceTablesOutputSize) || isNaN(expandingNodesMaxOutputSize)) {
+ return Optional.empty();
+ }
+ int partitionCountBasedOnOutputSize = getPartitionCount(
+ max(sourceTablesOutputSize, expandingNodesMaxOutputSize), minInputSizePerTask);
+
+ // Calculate partition count based on maximum memory usage. This is based on the assumption that
+ // generally operators won't keep data in memory more than the size of input data.
+ int partitionCountBasedOnMemory = (int) ((max(sourceTablesOutputSize, expandingNodesMaxOutputSize) * 2) / queryMaxMemoryPerNode);
+
+ return Optional.of(max(partitionCountBasedOnOutputSize, partitionCountBasedOnMemory));
+ }
+
+ private static Optional<Integer> getPartitionCountBasedOnRows(PlanNode plan, StatsProvider statsProvider, long minInputRowsPerTask)
+ {
+ double sourceTablesRowCount = getSourceNodesOutputStats(plan, node -> statsProvider.getStats(node).getOutputRowCount());
+ double expandingNodesMaxRowCount = getExpandingNodesMaxOutputStats(plan, node -> statsProvider.getStats(node).getOutputRowCount());
+ if (isNaN(sourceTablesRowCount) || isNaN(expandingNodesMaxRowCount)) {
+ return Optional.empty();
+ }
+
+ return Optional.of(getPartitionCount(
+ max(sourceTablesRowCount, expandingNodesMaxRowCount), minInputRowsPerTask));
+ }
+
+ private static int getPartitionCount(double outputStats, long minInputStatsPerTask)
+ {
+ return max((int) (outputStats / minInputStatsPerTask), 1);
+ }
+
+ private static boolean isInputMultiplyingPlanNodePresent(PlanNode root)
+ {
+ return PlanNodeSearcher.searchFrom(root)
+ .where(DeterminePartitionCount::isInputMultiplyingPlanNode)
+ .matches();
+ }
+
+ private static boolean isInputMultiplyingPlanNode(PlanNode node)
+ {
+ if (node instanceof UnnestNode) {
+ return true;
+ }
+
+ if (node instanceof JoinNode joinNode) {
+ // Skip for cross join
+ if (joinNode.isCrossJoin()) {
+ // If any of the input nodes is scalar, then there's no need to skip the cross join
+ return !isAtMostScalar(joinNode.getRight()) && !isAtMostScalar(joinNode.getLeft());
+ }
+
+ // Skip for joins with multiple keys since the output row count estimation can be wrong due to
+ // low correlation between the join keys.
+ return joinNode.getCriteria().size() > 1;
+ }
+
+ return false;
+ }
+
+ private static double getExpandingNodesMaxOutputStats(PlanNode root, ToDoubleFunction<PlanNode> statsMapper)
+ {
+ List<PlanNode> expandingNodes = PlanNodeSearcher.searchFrom(root)
+ .where(DeterminePartitionCount::isExpandingPlanNode)
+ .findAll();
+
+ return expandingNodes.stream()
+ .mapToDouble(statsMapper)
+ .max()
+ .orElse(0);
+ }
+
+ private static boolean isExpandingPlanNode(PlanNode node)
+ {
+ return node instanceof JoinNode
+ // consider a union node or an exchange node with multiple sources as expanding since it merges the rows
+ // from multiple sources, thus more data is transferred over the network.
+ || node instanceof UnionNode
+ || (node instanceof ExchangeNode && node.getSources().size() > 1);
+ }
+
+ private static double getSourceNodesOutputStats(PlanNode root, ToDoubleFunction<PlanNode> statsMapper)
+ {
+ List<PlanNode> sourceNodes = PlanNodeSearcher.searchFrom(root)
+ .whereIsInstanceOfAny(TableScanNode.class, ValuesNode.class)
+ .findAll();
+
+ return sourceNodes.stream()
+ .mapToDouble(statsMapper)
+ .sum();
+ }
+
+ private static class Rewriter
+ extends SimplePlanRewriter<Void>
+ {
+ private final int partitionCount;
+
+ private Rewriter(int partitionCount)
+ {
+ this.partitionCount = partitionCount;
+ }
+
+ @Override
+ public PlanNode visitExchange(ExchangeNode node, RewriteContext<Void> context)
+ {
+ PartitioningHandle handle = node.getPartitioningScheme().getPartitioning().getHandle();
+ if (!(node.getScope() == REMOTE && handle.getConnectorHandle() instanceof SystemPartitioningHandle)) {
+ return node;
+ }
+
+ List<PlanNode> sources = node.getSources().stream()
+ .map(context::rewrite)
+ .collect(toImmutableList());
+
+ return new ExchangeNode(
+ node.getId(),
+ node.getType(),
+ node.getScope(),
+ node.getPartitioningScheme().withPartitionCount(partitionCount),
+ sources,
+ node.getInputs(),
+ node.getOrderingScheme());
+ }
+ }
+}
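To make the size-based estimate concrete, a worked version of the javadoc's second example (figures assumed): 1000 MB scanned from tables, a join producing 5000 MB, and min_input_size_per_task at 500 MB.

    long minInputSizePerTask = 500L << 20;                   // 500 MB
    double sourceTablesOutputSize = 1000d * (1 << 20);       // scans: 1000 MB
    double expandingNodesMaxOutputSize = 5000d * (1 << 20);  // join output: 5000 MB
    int bySize = Math.max((int) (Math.max(sourceTablesOutputSize, expandingNodesMaxOutputSize) / minInputSizePerTask), 1); // => 10
    // The final count is additionally floored by min_hash_partition_count, combined with the
    // row-based estimate via max(), and left unset once it reaches max_hash_partition_count.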
diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/DetermineWritersNodesCount.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/DetermineWritersNodesCount.java
new file mode 100644
index 000000000000..128b117527ff
--- /dev/null
+++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/DetermineWritersNodesCount.java
@@ -0,0 +1,116 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.trino.sql.planner.optimizations;
+
+import io.trino.Session;
+import io.trino.cost.TableStatsProvider;
+import io.trino.execution.warnings.WarningCollector;
+import io.trino.operator.RetryPolicy;
+import io.trino.sql.planner.PartitioningHandle;
+import io.trino.sql.planner.PlanNodeIdAllocator;
+import io.trino.sql.planner.SymbolAllocator;
+import io.trino.sql.planner.SystemPartitioningHandle;
+import io.trino.sql.planner.TypeProvider;
+import io.trino.sql.planner.plan.ExchangeNode;
+import io.trino.sql.planner.plan.PlanNode;
+import io.trino.sql.planner.plan.SimplePlanRewriter;
+import io.trino.sql.planner.plan.TableWriterNode;
+
+import static io.trino.SystemSessionProperties.MAX_HASH_PARTITION_COUNT;
+import static io.trino.SystemSessionProperties.MAX_WRITERS_NODES_COUNT;
+import static io.trino.SystemSessionProperties.getRetryPolicy;
+import static io.trino.sql.planner.SystemPartitioningHandle.SCALED_WRITER_HASH_DISTRIBUTION;
+import static io.trino.sql.planner.plan.ExchangeNode.Scope.REMOTE;
+import static io.trino.sql.planner.plan.SimplePlanRewriter.rewriteWith;
+import static java.util.Objects.requireNonNull;
+
+public class DetermineWritersNodesCount
+ implements PlanOptimizer
+{
+ @Override
+ public PlanNode optimize(
+ PlanNode plan,
+ Session session,
+ TypeProvider types,
+ SymbolAllocator symbolAllocator,
+ PlanNodeIdAllocator idAllocator,
+ WarningCollector warningCollector,
+ TableStatsProvider tableStatsProvider)
+ {
+ requireNonNull(plan, "plan is null");
+ requireNonNull(session, "session is null");
+
+ // Skip for plans without writing stages. Additionally, skip for FTE mode since
+ // the estimated partitionCount is not used by the FTE scheduler.
+
+ if (!PlanNodeSearcher.searchFrom(plan).whereIsInstanceOfAny(TableWriterNode.class).matches() || getRetryPolicy(session) == RetryPolicy.TASK) {
+ return plan;
+ }
+
+ return rewriteWith(new Rewriter(
+ session.getSystemProperty(MAX_WRITERS_NODES_COUNT, Integer.class),
+ session.getSystemProperty(MAX_HASH_PARTITION_COUNT, Integer.class)), plan);
+ }
+
+ private static class Rewriter
+ extends SimplePlanRewriter<Void>
+ {
+ private final int maxWriterNodesCount;
+ private final int maxHashPartitionCount;
+
+ private Rewriter(int maxWriterNodesCount, int maxHashPartitionCount)
+ {
+ this.maxWriterNodesCount = maxWriterNodesCount;
+ this.maxHashPartitionCount = maxHashPartitionCount;
+ }
+
+ @Override
+ public PlanNode visitTableWriter(TableWriterNode node, RewriteContext<Void> context)
+ {
+ if (!(node.getSource() instanceof ExchangeNode)) {
+ return node;
+ }
+
+ ExchangeNode source = (ExchangeNode) node.getSource();
+ PartitioningHandle handle = source.getPartitioningScheme().getPartitioning().getHandle();
+
+ if (source.getScope() != REMOTE || !(handle.getConnectorHandle() instanceof SystemPartitioningHandle) || handle != SCALED_WRITER_HASH_DISTRIBUTION) {
+ return node;
+ }
+
+ // For a TableWriterNode's source exchanges there is no adaptive hash partition count, so max-hash-partition-count is used.
+ // We cap that value with maxWriterNodesCount for writing stages, and only for them.
+
+ return new TableWriterNode(
+ node.getId(),
+ new ExchangeNode(
+ source.getId(),
+ source.getType(),
+ source.getScope(),
+ source.getPartitioningScheme().withPartitionCount(Math.min(maxWriterNodesCount, maxHashPartitionCount)),
+ source.getSources(),
+ source.getInputs(),
+ source.getOrderingScheme()),
+ node.getTarget(),
+ node.getRowCountSymbol(),
+ node.getFragmentSymbol(),
+ node.getColumns(),
+ node.getColumnNames(),
+ node.getPartitioningScheme(),
+ node.getPreferredPartitioningScheme(),
+ node.getStatisticsAggregation(),
+ node.getStatisticsAggregationDescriptor());
+ }
+ }
+}
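For example (values assumed): with max_writers_nodes_count set to 8 and max_hash_partition_count at 100, the scaled-writer exchange feeding the TableWriterNode is rewritten with

    int partitionCount = Math.min(8, 100); // => 8, so the writing stage never fans out beyond the writer node limit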
diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/HashGenerationOptimizer.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/HashGenerationOptimizer.java
index d1146bb91035..e4f9cadf04ec 100644
--- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/HashGenerationOptimizer.java
+++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/HashGenerationOptimizer.java
@@ -539,7 +539,8 @@ public PlanWithProperties visitExchange(ExchangeNode node, HashComputationSet pa
.build(),
partitionSymbols.map(newHashSymbols::get),
partitioningScheme.isReplicateNullsAndAny(),
- partitioningScheme.getBucketToPartition());
+ partitioningScheme.getBucketToPartition(),
+ partitioningScheme.getPartitionCount());
// add hash symbols to sources
ImmutableList.Builder<List<Symbol>> newInputs = ImmutableList.builder();
diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/SymbolMapper.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/SymbolMapper.java
index a8f379b57d2d..94f66ffb6daf 100644
--- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/SymbolMapper.java
+++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/SymbolMapper.java
@@ -482,7 +482,8 @@ public PartitioningScheme map(PartitioningScheme scheme, List sourceLayo
mapAndDistinct(sourceLayout),
scheme.getHashColumn().map(this::map),
scheme.isReplicateNullsAndAny(),
- scheme.getBucketToPartition());
+ scheme.getBucketToPartition(),
+ scheme.getPartitionCount());
}
public TableFinishNode map(TableFinishNode node, PlanNode source)
diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/plan/CorrelatedJoinNode.java b/core/trino-main/src/main/java/io/trino/sql/planner/plan/CorrelatedJoinNode.java
index ed935944f519..60441272eacf 100644
--- a/core/trino-main/src/main/java/io/trino/sql/planner/plan/CorrelatedJoinNode.java
+++ b/core/trino-main/src/main/java/io/trino/sql/planner/plan/CorrelatedJoinNode.java
@@ -20,6 +20,7 @@
import io.trino.sql.tree.Expression;
import io.trino.sql.tree.Join;
import io.trino.sql.tree.Node;
+import io.trino.sql.tree.NullLiteral;
import javax.annotation.concurrent.Immutable;
@@ -100,6 +101,9 @@ public CorrelatedJoinNode(
requireNonNull(subquery, "subquery is null");
requireNonNull(correlation, "correlation is null");
requireNonNull(filter, "filter is null");
+ // The condition doesn't guarantee that filter is of type boolean, but was found to be a practical way to identify
+ // places where CorrelatedJoinNode could be created without appropriate coercions.
+ checkArgument(!(filter instanceof NullLiteral), "Filter must be an expression of boolean type: %s", filter);
requireNonNull(originSubquery, "originSubquery is null");
checkArgument(input.getOutputSymbols().containsAll(correlation), "Input does not contain symbols from correlation");
diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/plan/ExchangeNode.java b/core/trino-main/src/main/java/io/trino/sql/planner/plan/ExchangeNode.java
index f5b8fe0fe3b2..52f7e10432b8 100644
--- a/core/trino-main/src/main/java/io/trino/sql/planner/plan/ExchangeNode.java
+++ b/core/trino-main/src/main/java/io/trino/sql/planner/plan/ExchangeNode.java
@@ -132,6 +132,7 @@ public static ExchangeNode partitionedExchange(PlanNodeId id, Scope scope, PlanN
child.getOutputSymbols(),
hashColumns,
replicateNullsAndAny,
+ Optional.empty(),
Optional.empty()));
}
diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/plan/FilterNode.java b/core/trino-main/src/main/java/io/trino/sql/planner/plan/FilterNode.java
index a589aa5f5634..e7e8d935829d 100644
--- a/core/trino-main/src/main/java/io/trino/sql/planner/plan/FilterNode.java
+++ b/core/trino-main/src/main/java/io/trino/sql/planner/plan/FilterNode.java
@@ -19,11 +19,15 @@
import com.google.common.collect.Iterables;
import io.trino.sql.planner.Symbol;
import io.trino.sql.tree.Expression;
+import io.trino.sql.tree.NullLiteral;
import javax.annotation.concurrent.Immutable;
import java.util.List;
+import static com.google.common.base.Preconditions.checkArgument;
+import static java.util.Objects.requireNonNull;
+
@Immutable
public class FilterNode
extends PlanNode
@@ -39,6 +43,10 @@ public FilterNode(@JsonProperty("id") PlanNodeId id,
super(id);
this.source = source;
+ requireNonNull(predicate, "predicate is null");
+ // The condition doesn't guarantee that predicate is of type boolean, but was found to be a practical way to identify
+ // places where FilterNode was created without appropriate coercions.
+ checkArgument(!(predicate instanceof NullLiteral), "Predicate must be an expression of boolean type: %s", predicate);
this.predicate = predicate;
}
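The guard turns a miscoerced predicate into a failure at plan-construction time rather than at execution. A hypothetical call site illustrating both sides:

    // throws IllegalArgumentException via the checkArgument above
    new FilterNode(idAllocator.getNextId(), source, new NullLiteral());
    // accepted: the null is explicitly coerced to boolean first
    new FilterNode(idAllocator.getNextId(), source, new Cast(new NullLiteral(), toSqlType(BOOLEAN)));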
diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/plan/JoinNode.java b/core/trino-main/src/main/java/io/trino/sql/planner/plan/JoinNode.java
index 6b8c00f7f49c..e470e443d4e3 100644
--- a/core/trino-main/src/main/java/io/trino/sql/planner/plan/JoinNode.java
+++ b/core/trino-main/src/main/java/io/trino/sql/planner/plan/JoinNode.java
@@ -23,6 +23,7 @@
import io.trino.sql.tree.ComparisonExpression;
import io.trino.sql.tree.Expression;
import io.trino.sql.tree.Join;
+import io.trino.sql.tree.NullLiteral;
import javax.annotation.concurrent.Immutable;
@@ -89,6 +90,9 @@ public JoinNode(
requireNonNull(leftOutputSymbols, "leftOutputSymbols is null");
requireNonNull(rightOutputSymbols, "rightOutputSymbols is null");
requireNonNull(filter, "filter is null");
+ // The condition doesn't guarantee that filter is of type boolean, but was found to be a practical way to identify
+ // places where JoinNode could be created without appropriate coercions.
+ checkArgument(filter.isEmpty() || !(filter.get() instanceof NullLiteral), "Filter must be an expression of boolean type: %s", filter);
requireNonNull(leftHashSymbol, "leftHashSymbol is null");
requireNonNull(rightHashSymbol, "rightHashSymbol is null");
requireNonNull(distributionType, "distributionType is null");
diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/PlanPrinter.java b/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/PlanPrinter.java
index dcfa3fc2767d..1ab5824bbc65 100644
--- a/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/PlanPrinter.java
+++ b/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/PlanPrinter.java
@@ -583,6 +583,8 @@ private static String formatFragment(
hashColumn));
}
+ partitioningScheme.getPartitionCount().ifPresent(partitionCount -> builder.append(format("Partition count: %s\n", partitionCount)));
+
builder.append(
new PlanPrinter(
fragment.getRoot(),
@@ -1639,6 +1641,7 @@ else if (node.getScope() == Scope.LOCAL) {
addNode(node,
format("%sExchange", UPPER_UNDERSCORE.to(CaseFormat.UPPER_CAMEL, node.getScope().toString())),
ImmutableMap.of(
+ "partitionCount", node.getPartitioningScheme().getPartitionCount().map(String::valueOf).orElse(""),
"type", node.getType().name(),
"isReplicateNullsAndAny", formatBoolean(node.getPartitioningScheme().isReplicateNullsAndAny()),
"hashColumn", formatHash(node.getPartitioningScheme().getHashColumn())),
diff --git a/core/trino-main/src/main/java/io/trino/testing/TestingMetadata.java b/core/trino-main/src/main/java/io/trino/testing/TestingMetadata.java
index f4c74af43983..46320df3c19d 100644
--- a/core/trino-main/src/main/java/io/trino/testing/TestingMetadata.java
+++ b/core/trino-main/src/main/java/io/trino/testing/TestingMetadata.java
@@ -60,6 +60,8 @@
import static com.google.common.base.MoreObjects.toStringHelper;
import static com.google.common.base.Preconditions.checkArgument;
import static io.trino.spi.StandardErrorCode.ALREADY_EXISTS;
+import static io.trino.spi.connector.MaterializedViewFreshness.Freshness.FRESH;
+import static io.trino.spi.connector.MaterializedViewFreshness.Freshness.STALE;
import static java.util.Collections.synchronizedSet;
import static java.util.Objects.requireNonNull;
@@ -269,7 +271,7 @@ public void dropMaterializedView(ConnectorSession session, SchemaTableName viewN
@Override
public MaterializedViewFreshness getMaterializedViewFreshness(ConnectorSession session, SchemaTableName name)
{
- return new MaterializedViewFreshness(freshMaterializedViews.contains(name));
+ return new MaterializedViewFreshness(freshMaterializedViews.contains(name) ? FRESH : STALE);
}
public void markMaterializedViewIsFresh(SchemaTableName name)
diff --git a/core/trino-main/src/main/java/io/trino/type/TypeCoercion.java b/core/trino-main/src/main/java/io/trino/type/TypeCoercion.java
index 4be3958939d7..b1455c76d18a 100644
--- a/core/trino-main/src/main/java/io/trino/type/TypeCoercion.java
+++ b/core/trino-main/src/main/java/io/trino/type/TypeCoercion.java
@@ -46,10 +46,8 @@
import static io.trino.spi.type.RowType.Field;
import static io.trino.spi.type.SmallintType.SMALLINT;
import static io.trino.spi.type.TimeType.createTimeType;
-import static io.trino.spi.type.TimeWithTimeZoneType.TIME_WITH_TIME_ZONE;
import static io.trino.spi.type.TimeWithTimeZoneType.createTimeWithTimeZoneType;
import static io.trino.spi.type.TimestampType.createTimestampType;
-import static io.trino.spi.type.TimestampWithTimeZoneType.TIMESTAMP_TZ_MILLIS;
import static io.trino.spi.type.TimestampWithTimeZoneType.createTimestampWithTimeZoneType;
import static io.trino.spi.type.VarcharType.createUnboundedVarcharType;
import static io.trino.spi.type.VarcharType.createVarcharType;
@@ -453,7 +451,7 @@ public Optional<Type> coerceTypeBase(Type sourceType, String resultTypeBase)
case StandardTypes.TIMESTAMP:
return Optional.of(createTimestampType(0));
case StandardTypes.TIMESTAMP_WITH_TIME_ZONE:
- return Optional.of(TIMESTAMP_TZ_MILLIS);
+ return Optional.of(createTimestampWithTimeZoneType(0));
default:
return Optional.empty();
}
@@ -461,7 +459,7 @@ public Optional<Type> coerceTypeBase(Type sourceType, String resultTypeBase)
case StandardTypes.TIME: {
switch (resultTypeBase) {
case StandardTypes.TIME_WITH_TIME_ZONE:
- return Optional.of(TIME_WITH_TIME_ZONE);
+ return Optional.of(createTimeWithTimeZoneType(((TimeType) sourceType).getPrecision()));
default:
return Optional.empty();
}
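The net effect of both replacements is that these coercions no longer force millisecond precision. An illustrative expectation, assuming a TypeCoercion instance named typeCoercion:

    // TIME(6) now coerces to TIME(6) WITH TIME ZONE rather than the fixed-millis TIME_WITH_TIME_ZONE
    Optional<Type> zoned = typeCoercion.coerceTypeBase(createTimeType(6), StandardTypes.TIME_WITH_TIME_ZONE);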
diff --git a/core/trino-main/src/test/java/io/trino/block/TestRowBlock.java b/core/trino-main/src/test/java/io/trino/block/TestRowBlock.java
index 224e3eef9a7a..bdc07dd4d62d 100644
--- a/core/trino-main/src/test/java/io/trino/block/TestRowBlock.java
+++ b/core/trino-main/src/test/java/io/trino/block/TestRowBlock.java
@@ -108,8 +108,8 @@ public void testCompactBlock()
Block emptyBlock = new ByteArrayBlock(0, Optional.empty(), new byte[0]);
Block compactFieldBlock1 = new ByteArrayBlock(5, Optional.empty(), createExpectedValue(5).getBytes());
Block compactFieldBlock2 = new ByteArrayBlock(5, Optional.empty(), createExpectedValue(5).getBytes());
- Block incompactFiledBlock1 = new ByteArrayBlock(5, Optional.empty(), createExpectedValue(6).getBytes());
- Block incompactFiledBlock2 = new ByteArrayBlock(5, Optional.empty(), createExpectedValue(6).getBytes());
+ Block incompactFieldBlock1 = new ByteArrayBlock(5, Optional.empty(), createExpectedValue(6).getBytes());
+ Block incompactFieldBlock2 = new ByteArrayBlock(5, Optional.empty(), createExpectedValue(6).getBytes());
boolean[] rowIsNull = {false, true, false, false, false, false};
assertCompact(fromFieldBlocks(0, Optional.empty(), new Block[] {emptyBlock, emptyBlock}));
@@ -117,8 +117,8 @@ public void testCompactBlock()
// TODO: add test case for a sliced RowBlock
// underlying field blocks are not compact
- testIncompactBlock(fromFieldBlocks(rowIsNull.length, Optional.of(rowIsNull), new Block[] {incompactFiledBlock1, incompactFiledBlock2}));
- testIncompactBlock(fromFieldBlocks(rowIsNull.length, Optional.of(rowIsNull), new Block[] {incompactFiledBlock1, incompactFiledBlock2}));
+ testIncompactBlock(fromFieldBlocks(rowIsNull.length, Optional.of(rowIsNull), new Block[] {incompactFieldBlock1, incompactFieldBlock2}));
+ testIncompactBlock(fromFieldBlocks(rowIsNull.length, Optional.of(rowIsNull), new Block[] {incompactFieldBlock1, incompactFieldBlock2}));
}
private void testWith(List fieldTypes, List