diff --git a/core/trino-main/pom.xml b/core/trino-main/pom.xml index 7f9062077538..cbef1e8ca8d3 100644 --- a/core/trino-main/pom.xml +++ b/core/trino-main/pom.xml @@ -43,6 +43,12 @@ jackson-databind + + com.github.ishugaliy + allgood-consistent-hash + 1.0.0 + + com.github.oshi oshi-core @@ -269,6 +275,11 @@ trino-matching + + io.trino + trino-memory-cache + + io.trino trino-memory-context diff --git a/core/trino-main/src/main/java/io/trino/SystemSessionProperties.java b/core/trino-main/src/main/java/io/trino/SystemSessionProperties.java index 59be671b8ce1..acb22ba2198b 100644 --- a/core/trino-main/src/main/java/io/trino/SystemSessionProperties.java +++ b/core/trino-main/src/main/java/io/trino/SystemSessionProperties.java @@ -17,6 +17,7 @@ import com.google.inject.Inject; import io.airlift.units.DataSize; import io.airlift.units.Duration; +import io.trino.cache.CacheConfig; import io.trino.execution.DynamicFilterConfig; import io.trino.execution.QueryManagerConfig; import io.trino.execution.TaskManagerConfig; @@ -215,6 +216,12 @@ public final class SystemSessionProperties public static final String USE_COST_BASED_PARTITIONING = "use_cost_based_partitioning"; public static final String PUSH_FILTER_INTO_VALUES_MAX_ROW_COUNT = "push_filter_into_values_max_row_count"; public static final String FORCE_SPILLING_JOIN = "force_spilling_join"; + public static final String CACHE_ENABLED = "cache_enabled"; + public static final String CACHE_COMMON_SUBQUERIES_ENABLED = "cache_common_subqueries_enabled"; + public static final String CACHE_AGGREGATIONS_ENABLED = "cache_aggregations_enabled"; + public static final String CACHE_PROJECTIONS_ENABLED = "cache_projections_enabled"; + public static final String CACHE_MAX_SPLIT_SIZE = "cache_max_split_size"; + public static final String CACHE_MIN_WORKER_SPLIT_SEPARATION = "cache_min_worker_split_separation"; public static final String PAGE_PARTITIONING_BUFFER_POOL_SIZE = "page_partitioning_buffer_pool_size"; public static final String 
IDLE_WRITER_MIN_DATA_SIZE_THRESHOLD = "idle_writer_min_data_size_threshold"; public static final String CLOSE_IDLE_WRITERS_TRIGGER_DURATION = "close_idle_writers_trigger_duration"; @@ -232,6 +239,7 @@ public SystemSessionProperties() new OptimizerConfig(), new NodeMemoryConfig(), new DynamicFilterConfig(), + new CacheConfig(), new NodeSchedulerConfig()); } @@ -244,6 +252,7 @@ public SystemSessionProperties( OptimizerConfig optimizerConfig, NodeMemoryConfig nodeMemoryConfig, DynamicFilterConfig dynamicFilterConfig, + CacheConfig cacheConfig, NodeSchedulerConfig nodeSchedulerConfig) { sessionProperties = ImmutableList.of( @@ -1112,6 +1121,41 @@ public SystemSessionProperties( "Enables columnar evaluation of filters", featuresConfig.isColumnarFilterEvaluationEnabled(), false), + booleanProperty( + CACHE_ENABLED, + "Enables subquery caching", + cacheConfig.isEnabled(), + enabled -> { + if (enabled && !cacheConfig.isEnabled()) { + throw new TrinoException(INVALID_SESSION_PROPERTY, "Subquery cache must be enabled via feature config"); + } + }, + true), + booleanProperty( + CACHE_COMMON_SUBQUERIES_ENABLED, + "Enables caching of common subqueries when running a single query", + cacheConfig.isEnabled() && cacheConfig.isCacheCommonSubqueriesEnabled(), + true), + booleanProperty( + CACHE_AGGREGATIONS_ENABLED, + "Enables caching of aggregations", + cacheConfig.isEnabled() && cacheConfig.isCacheAggregationsEnabled(), + true), + booleanProperty( + CACHE_PROJECTIONS_ENABLED, + "Enables caching of projections", + cacheConfig.isEnabled() && cacheConfig.isCacheProjectionsEnabled(), + true), + dataSizeProperty( + CACHE_MAX_SPLIT_SIZE, + "Max size of cached split", + cacheConfig.getMaxSplitSize(), + true), + integerProperty( + CACHE_MIN_WORKER_SPLIT_SEPARATION, + "The minimum separation (in terms of processed splits) between two splits with same cache split id being scheduled on the single worker", + cacheConfig.getCacheMinWorkerSplitSeparation(), + true), 
integerProperty(PAGE_PARTITIONING_BUFFER_POOL_SIZE, "Maximum number of free buffers in the per task partitioned page buffer pool. Setting this to zero effectively disables the pool", taskManagerConfig.getPagePartitioningBufferPoolSize(), @@ -1998,6 +2042,36 @@ public static boolean isForceSpillingOperator(Session session) return session.getSystemProperty(FORCE_SPILLING_JOIN, Boolean.class); } + public static boolean isCacheEnabled(Session session) + { + return session.getSystemProperty(CACHE_ENABLED, Boolean.class); + } + + public static boolean isCacheCommonSubqueriesEnabled(Session session) + { + return session.getSystemProperty(CACHE_COMMON_SUBQUERIES_ENABLED, Boolean.class); + } + + public static boolean isCacheAggregationsEnabled(Session session) + { + return session.getSystemProperty(CACHE_AGGREGATIONS_ENABLED, Boolean.class); + } + + public static boolean isCacheProjectionsEnabled(Session session) + { + return session.getSystemProperty(CACHE_PROJECTIONS_ENABLED, Boolean.class); + } + + public static DataSize getCacheMaxSplitSize(Session session) + { + return session.getSystemProperty(CACHE_MAX_SPLIT_SIZE, DataSize.class); + } + + public static int getCacheMinWorkerSplitSeparation(Session session) + { + return session.getSystemProperty(CACHE_MIN_WORKER_SPLIT_SEPARATION, Integer.class); + } + public static int getPagePartitioningBufferPoolSize(Session session) { return session.getSystemProperty(PAGE_PARTITIONING_BUFFER_POOL_SIZE, Integer.class); diff --git a/core/trino-main/src/main/java/io/trino/cache/CacheCommonSubqueries.java b/core/trino-main/src/main/java/io/trino/cache/CacheCommonSubqueries.java new file mode 100644 index 000000000000..30c838a4611f --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/cache/CacheCommonSubqueries.java @@ -0,0 +1,173 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.cache; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.trino.Session; +import io.trino.spi.cache.CacheManager; +import io.trino.sql.PlannerContext; +import io.trino.sql.planner.PlanNodeIdAllocator; +import io.trino.sql.planner.SymbolAllocator; +import io.trino.sql.planner.iterative.Lookup; +import io.trino.sql.planner.optimizations.PlanNodeSearcher; +import io.trino.sql.planner.plan.CacheDataPlanNode; +import io.trino.sql.planner.plan.ChooseAlternativeNode; +import io.trino.sql.planner.plan.LoadCachedDataPlanNode; +import io.trino.sql.planner.plan.PlanNode; +import io.trino.sql.planner.plan.SimplePlanRewriter; + +import java.util.Map; + +import static com.google.common.base.Preconditions.checkArgument; +import static io.trino.SystemSessionProperties.isCacheEnabled; +import static io.trino.cache.CommonSubqueriesExtractor.extractCommonSubqueries; +import static io.trino.sql.planner.iterative.Lookup.noLookup; +import static java.util.Objects.requireNonNull; + +/** + * Extracts common subqueries and substitutes each subquery with {@link ChooseAlternativeNode} + * consisting of 3 alternatives: + * * original subplan + * * subplan that caches data with {@link CacheManager} + * * subplan that reads data from {@link CacheManager} + */ +public class CacheCommonSubqueries +{ + public static final int ORIGINAL_PLAN_ALTERNATIVE = 0; + public static final int STORE_PAGES_ALTERNATIVE = 1; + public static final int LOAD_PAGES_ALTERNATIVE = 2; + + private final boolean 
cacheEnabled; + private final CacheController cacheController; + private final PlannerContext plannerContext; + private final Session session; + private final PlanNodeIdAllocator idAllocator; + private final SymbolAllocator symbolAllocator; + + public CacheCommonSubqueries( + CacheController cacheController, + PlannerContext plannerContext, + Session session, + PlanNodeIdAllocator idAllocator, + SymbolAllocator symbolAllocator) + { + this.cacheController = requireNonNull(cacheController, "cacheController is null"); + this.plannerContext = requireNonNull(plannerContext, "plannerContext is null"); + this.session = requireNonNull(session, "session is null"); + this.idAllocator = requireNonNull(idAllocator, "idAllocator is null"); + this.symbolAllocator = requireNonNull(symbolAllocator, "symbolAllocator is null"); + this.cacheEnabled = isCacheEnabled(session); + } + + public PlanNode cacheSubqueries(PlanNode node) + { + if (!cacheEnabled) { + return node; + } + + Map adaptations = extractCommonSubqueries( + cacheController, + plannerContext, + session, + idAllocator, + symbolAllocator, + node); + + // add alternatives for each adaptation + ImmutableMap.Builder nodeMapping = ImmutableMap.builder(); + for (Map.Entry entry : adaptations.entrySet()) { + CommonPlanAdaptation adaptation = entry.getValue(); + + PlanNode storePagesAlternative = + adaptation.adaptCommonSubplan( + new CacheDataPlanNode( + idAllocator.getNextId(), + adaptation.getCommonSubplan()), + idAllocator); + + PlanNode loadPagesAlternative = + adaptation.adaptCommonSubplan( + new LoadCachedDataPlanNode( + idAllocator.getNextId(), + adaptation.getCommonSubplanSignature(), + adaptation.getCommonDynamicFilterDisjuncts(), + adaptation.getCommonColumnHandles(), + adaptation.getCommonSubplan().getOutputSymbols()), + idAllocator); + + PlanNode[] alternatives = new PlanNode[3]; + // use static indexes explicitly to make ensure code stays consistent with static indexes + alternatives[ORIGINAL_PLAN_ALTERNATIVE] = 
entry.getKey(); + alternatives[STORE_PAGES_ALTERNATIVE] = storePagesAlternative; + alternatives[LOAD_PAGES_ALTERNATIVE] = loadPagesAlternative; + + nodeMapping.put(entry.getKey(), new ChooseAlternativeNode( + idAllocator.getNextId(), + ImmutableList.copyOf(alternatives), + adaptation.getCommonSubplanFilteredTableScan())); + } + + return SimplePlanRewriter.rewriteWith(new PlanReplacer(nodeMapping.buildOrThrow()), node); + } + + public static boolean isCacheChooseAlternativeNode(PlanNode node) + { + return isCacheChooseAlternativeNode(node, noLookup()); + } + + public static boolean isCacheChooseAlternativeNode(PlanNode node, Lookup lookup) + { + if (!(node instanceof ChooseAlternativeNode chooseAlternativeNode)) { + return false; + } + + if (chooseAlternativeNode.getSources().size() != 3) { + return false; + } + + return PlanNodeSearcher.searchFrom(chooseAlternativeNode.getSources().get(LOAD_PAGES_ALTERNATIVE), lookup) + .whereIsInstanceOfAny(LoadCachedDataPlanNode.class) + .matches(); + } + + public static LoadCachedDataPlanNode getLoadCachedDataPlanNode(ChooseAlternativeNode node) + { + checkArgument(isCacheChooseAlternativeNode(node), "ChooseAlternativeNode should contain cache alternatives"); + return (LoadCachedDataPlanNode) PlanNodeSearcher.searchFrom(node.getSources().get(LOAD_PAGES_ALTERNATIVE)) + .whereIsInstanceOfAny(LoadCachedDataPlanNode.class) + .findOnlyElement(); + } + + private static class PlanReplacer + extends SimplePlanRewriter + { + private final Map mapping; + + public PlanReplacer(Map mapping) + { + this.mapping = requireNonNull(mapping, "mapping is null"); + } + + @Override + protected PlanNode visitPlan(PlanNode node, RewriteContext context) + { + if (mapping.containsKey(node)) { + return mapping.get(node); + } + + return context.defaultRewrite(node, context.get()); + } + } +} diff --git a/core/trino-main/src/main/java/io/trino/cache/CacheConfig.java b/core/trino-main/src/main/java/io/trino/cache/CacheConfig.java new file mode 100644 index 
000000000000..fe78359407c5 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/cache/CacheConfig.java @@ -0,0 +1,148 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.cache; + +import io.airlift.configuration.Config; +import io.airlift.configuration.ConfigDescription; +import io.airlift.configuration.LegacyConfig; +import io.airlift.units.DataSize; +import jakarta.validation.constraints.DecimalMax; +import jakarta.validation.constraints.DecimalMin; +import jakarta.validation.constraints.Min; + +public class CacheConfig +{ + private boolean enabled; + private double revokingThreshold = 0.9; + private double revokingTarget = 0.7; + private boolean cacheCommonSubqueriesEnabled = true; + private boolean cacheAggregationsEnabled = true; + private boolean cacheProjectionsEnabled = true; + private DataSize maxSplitSize = DataSize.of(256, DataSize.Unit.MEGABYTE); + // The minimum number of splits with distinct CacheSplitID that should be processed by a worker + // before scheduling splits with the same CacheSplitID on the same worker again. + // Since split scheduling is not fully deterministic, the default value is set to 500 + // which keeps cache collisions to a minimum, but avoids excessive fetching of splits. 
+ private int cacheMinWorkerSplitSeparation = 500; + + public boolean isEnabled() + { + return enabled; + } + + @Config("cache.enabled") + @ConfigDescription("Enables pipeline level cache") + public CacheConfig setEnabled(boolean enabled) + { + this.enabled = enabled; + return this; + } + + @DecimalMin("0.0") + @DecimalMax("1.0") + public double getRevokingThreshold() + { + return revokingThreshold; + } + + @Config("cache.revoking-threshold") + @ConfigDescription("Revoke cache memory when memory pool is filled over threshold") + public CacheConfig setRevokingThreshold(double revokingThreshold) + { + this.revokingThreshold = revokingThreshold; + return this; + } + + @DecimalMin("0.0") + @DecimalMax("1.0") + public double getRevokingTarget() + { + return revokingTarget; + } + + @Config("cache.revoking-target") + @ConfigDescription("When revoking cache memory, revoke so much that cache memory reservation is below target at the end") + public CacheConfig setRevokingTarget(double revokingTarget) + { + this.revokingTarget = revokingTarget; + return this; + } + + public boolean isCacheCommonSubqueriesEnabled() + { + return cacheCommonSubqueriesEnabled; + } + + @Config("cache.common-subqueries.enabled") + @LegacyConfig("cache.subqueries.enabled") + @ConfigDescription("Enables caching of common subqueries when running a single query") + public CacheConfig setCacheCommonSubqueriesEnabled(boolean cacheCommonSubqueriesEnabled) + { + this.cacheCommonSubqueriesEnabled = cacheCommonSubqueriesEnabled; + return this; + } + + public boolean isCacheAggregationsEnabled() + { + return cacheAggregationsEnabled; + } + + @Config("cache.aggregations.enabled") + @ConfigDescription("Enables caching of aggregations") + public CacheConfig setCacheAggregationsEnabled(boolean cacheAggregationsEnabled) + { + this.cacheAggregationsEnabled = cacheAggregationsEnabled; + return this; + } + + public boolean isCacheProjectionsEnabled() + { + return cacheProjectionsEnabled; + } + + 
@Config("cache.projections.enabled") + @ConfigDescription("Enables caching of projections") + public CacheConfig setCacheProjectionsEnabled(boolean cacheProjectionsEnabled) + { + this.cacheProjectionsEnabled = cacheProjectionsEnabled; + return this; + } + + public DataSize getMaxSplitSize() + { + return maxSplitSize; + } + + @Config("cache.max-split-size") + @ConfigDescription("Upper bound for size of cached split") + public CacheConfig setMaxSplitSize(DataSize cacheSubqueriesSize) + { + this.maxSplitSize = cacheSubqueriesSize; + return this; + } + + @Min(0) + public int getCacheMinWorkerSplitSeparation() + { + return cacheMinWorkerSplitSeparation; + } + + @Config("cache.min-worker-split-separation") + @ConfigDescription("The minimum separation (in terms of processed splits) between two splits with same cache split id being scheduled on the single worker") + public CacheConfig setCacheMinWorkerSplitSeparation(int cacheMinWorkerSplitSeparation) + { + this.cacheMinWorkerSplitSeparation = cacheMinWorkerSplitSeparation; + return this; + } +} diff --git a/core/trino-main/src/main/java/io/trino/cache/CacheController.java b/core/trino-main/src/main/java/io/trino/cache/CacheController.java new file mode 100644 index 000000000000..7dcd1837b188 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/cache/CacheController.java @@ -0,0 +1,130 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.cache; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Multimap; +import io.trino.Session; +import io.trino.spi.cache.CacheColumnId; +import io.trino.spi.predicate.TupleDomain; + +import java.util.AbstractMap.SimpleEntry; +import java.util.Collection; +import java.util.Comparator; +import java.util.List; + +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.ImmutableListMultimap.toImmutableListMultimap; +import static io.trino.SystemSessionProperties.isCacheAggregationsEnabled; +import static io.trino.SystemSessionProperties.isCacheCommonSubqueriesEnabled; +import static io.trino.SystemSessionProperties.isCacheProjectionsEnabled; +import static io.trino.cache.CanonicalSubplan.Key; +import static io.trino.cache.CanonicalSubplan.TopNKey; +import static io.trino.cache.CanonicalSubplan.TopNRankingKey; + +public class CacheController +{ + /** + * Logic for cache decision (what to cache, order or caching candidates). 
+ */ + public List getCachingCandidates(Session session, List canonicalSubplans) + { + Multimap groupedSubplans = canonicalSubplans.stream() + .map(subplan -> new SimpleEntry<>(new SubplanKey(subplan), subplan)) + .sorted(Comparator.comparing(entry -> entry.getKey().getPriority())) + .collect(toImmutableListMultimap(SimpleEntry::getKey, SimpleEntry::getValue)); + + List commonSubplans = groupedSubplans.asMap().values().stream() + .filter(subplans -> subplans.size() > 1) + // split grouped subplans by intersection of enforced constraints + .map(this::splitByIntersection) + .flatMap(Collection::stream) + // filter out cache candidates which don't share common subplan + .map(subplans -> new CacheCandidate(ImmutableList.copyOf(subplans), 2)) + .collect(toImmutableList()); + List aggregationSubplans = groupedSubplans.entries().stream() + .filter(entry -> entry.getKey().aggregation()) + .map(entry -> new CacheCandidate(ImmutableList.of(entry.getValue()), 1)) + .collect(toImmutableList()); + List projectionSubplans = groupedSubplans.entries().stream() + .filter(entry -> !entry.getKey().aggregation()) + .map(entry -> new CacheCandidate(ImmutableList.of(entry.getValue()), 1)) + .collect(toImmutableList()); + + ImmutableList.Builder cacheCandidates = ImmutableList.builder(); + + if (isCacheCommonSubqueriesEnabled(session)) { + cacheCandidates.addAll(commonSubplans); + } + + if (isCacheAggregationsEnabled(session)) { + cacheCandidates.addAll(aggregationSubplans); + } + + if (isCacheProjectionsEnabled(session)) { + cacheCandidates.addAll(projectionSubplans); + } + + return cacheCandidates.build(); + } + + private List> splitByIntersection(Collection subplans) + { + ImmutableList.Builder intersectionBuilder = ImmutableList.builder(); + ImmutableList.Builder excludingBuilder = ImmutableList.builder(); + + TupleDomain currentConstraint = TupleDomain.all(); + for (CanonicalSubplan subplan : subplans) { + TupleDomain testConstraint = 
currentConstraint.intersect(subplan.getEnforcedConstraint()); + if (testConstraint.isNone()) { + excludingBuilder.add(subplan); + continue; + } + currentConstraint = testConstraint; + intersectionBuilder.add(subplan); + } + List excludingSubplans = excludingBuilder.build(); + List intersectedSubplans = intersectionBuilder.build(); + + ImmutableList.Builder> intersectedSubplansBuilder = ImmutableList.builder(); + + if (intersectedSubplans.size() > 1) { + intersectedSubplansBuilder.add(intersectedSubplans); + } + if (excludingSubplans.size() > 1 && excludingSubplans.size() != subplans.size()) { + intersectedSubplansBuilder.addAll(splitByIntersection(excludingSubplans)); + } + + return intersectedSubplansBuilder.build(); + } + + record CacheCandidate(List subplans, int minSubplans) {} + + record SubplanKey(List keyChain, boolean aggregation) + { + public SubplanKey(CanonicalSubplan subplan) + { + this( + subplan.getKeyChain(), + // TopN and TopNRanking are treated as aggregations because of an assumption of a significant reduction of output rows + subplan.getGroupByColumns().isPresent() || subplan.getKey() instanceof TopNKey || subplan.getKey() instanceof TopNRankingKey); + } + + public int getPriority() + { + // prefer deeper plans to be cached first + return -keyChain.size(); + } + } +} diff --git a/core/trino-main/src/main/java/io/trino/cache/CacheDataOperator.java b/core/trino-main/src/main/java/io/trino/cache/CacheDataOperator.java new file mode 100644 index 000000000000..28e797156dd7 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/cache/CacheDataOperator.java @@ -0,0 +1,181 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.cache; + +import io.trino.memory.context.LocalMemoryContext; +import io.trino.operator.DriverContext; +import io.trino.operator.Operator; +import io.trino.operator.OperatorContext; +import io.trino.operator.OperatorFactory; +import io.trino.spi.Page; +import io.trino.spi.connector.ConnectorPageSink; +import io.trino.sql.planner.plan.PlanNodeId; +import jakarta.annotation.Nullable; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkState; +import static java.util.Objects.requireNonNull; + +public class CacheDataOperator + implements Operator +{ + public static class CacheDataOperatorFactory + implements OperatorFactory + { + private final int operatorId; + private final PlanNodeId planNodeId; + private boolean closed; + private final long maxSplitSizeInBytes; + + public CacheDataOperatorFactory( + int operatorId, + PlanNodeId planNodeId, + long maxSplitSizeInBytes) + { + this.operatorId = operatorId; + this.planNodeId = requireNonNull(planNodeId, "planNodeId is null"); + this.maxSplitSizeInBytes = maxSplitSizeInBytes; + } + + @Override + public Operator createOperator(DriverContext driverContext) + { + checkArgument(driverContext.getCacheDriverContext().isPresent(), "cacheDriverContext is empty"); + checkState(!closed, "Factory is already closed"); + OperatorContext operatorContext = driverContext.addOperatorContext(operatorId, planNodeId, CacheDataOperator.class.getSimpleName()); + CacheDriverContext cacheDriverContext = 
driverContext.getCacheDriverContext().get(); + return new CacheDataOperator(operatorContext, maxSplitSizeInBytes, cacheDriverContext.cacheMetrics(), cacheDriverContext.cacheStats()); + } + + @Override + public void noMoreOperators() + { + closed = true; + } + + @Override + public OperatorFactory duplicate() + { + return new CacheDataOperatorFactory(operatorId, planNodeId, maxSplitSizeInBytes); + } + } + + private final OperatorContext operatorContext; + private final CacheMetrics cacheMetrics; + private final CacheStats cacheStats; + private final LocalMemoryContext memoryContext; + private final long maxCacheSizeInBytes; + + @Nullable + private ConnectorPageSink pageSink; + @Nullable + private Page page; + private long cachedDataSize; + private boolean finishing; + + private CacheDataOperator(OperatorContext operatorContext, long maxCacheSizeInBytes, CacheMetrics cacheMetrics, CacheStats cacheStats) + { + this.operatorContext = requireNonNull(operatorContext, "operatorContext is null"); + this.memoryContext = operatorContext.newLocalUserMemoryContext(CacheDataOperator.class.getSimpleName()); + CacheDriverContext cacheContext = operatorContext.getDriverContext().getCacheDriverContext() + .orElseThrow(() -> new IllegalArgumentException("Cache context is not present")); + this.pageSink = cacheContext + .pageSink() + .orElseThrow(() -> new IllegalArgumentException("Cache page sink is not present")); + memoryContext.setBytes(pageSink.getMemoryUsage()); + this.maxCacheSizeInBytes = maxCacheSizeInBytes; + this.cacheMetrics = requireNonNull(cacheMetrics, "cacheMetrics is null"); + this.cacheStats = requireNonNull(cacheStats, "cacheStats is null"); + operatorContext.setLatestMetrics(cacheContext.metrics()); + } + + @Override + public OperatorContext getOperatorContext() + { + return operatorContext; + } + + @Override + public boolean needsInput() + { + return !finishing && page == null; + } + + @Override + public void addInput(Page page) + { + checkState(needsInput()); + 
this.page = page; + + if (pageSink == null) { + // caching was aborted + return; + } + + checkState(pageSink.appendPage(page).isDone(), "appendPage future must be done"); + cachedDataSize += page.getSizeInBytes(); + memoryContext.setBytes(pageSink.getMemoryUsage()); + + // If there is no space for a page in a cache, stop caching this split and abort pageSink + if (pageSink.getMemoryUsage() > maxCacheSizeInBytes) { + abort(); + cacheMetrics.incrementSplitsNotCached(); + } + } + + @Override + public Page getOutput() + { + Page page = this.page; + this.page = null; + return page; + } + + @Override + public void finish() + { + finishing = true; + if (pageSink != null) { + checkState(pageSink.finish().isDone(), "finish future must be done"); + pageSink = null; + memoryContext.close(); + + cacheMetrics.incrementSplitsCached(); + cacheStats.recordCacheData(cachedDataSize); + } + } + + @Override + public boolean isFinished() + { + return finishing && page == null; + } + + @Override + public void close() + throws Exception + { + if (pageSink != null) { + abort(); + } + } + + private void abort() + { + requireNonNull(pageSink, "pageSink is null"); + pageSink.abort(); + pageSink = null; + memoryContext.close(); + } +} diff --git a/core/trino-main/src/main/java/io/trino/cache/CacheDriverContext.java b/core/trino-main/src/main/java/io/trino/cache/CacheDriverContext.java new file mode 100644 index 000000000000..c4b3cf001140 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/cache/CacheDriverContext.java @@ -0,0 +1,61 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.cache; + +import io.trino.operator.OperatorContext; +import io.trino.spi.connector.ConnectorPageSink; +import io.trino.spi.connector.ConnectorPageSource; +import io.trino.spi.connector.DynamicFilter; +import io.trino.spi.metrics.Metrics; + +import java.util.Optional; + +import static java.util.Objects.requireNonNull; + +public record CacheDriverContext( + Optional pageSource, + Optional pageSink, + DynamicFilter dynamicFilter, + CacheMetrics cacheMetrics, + CacheStats cacheStats, + Metrics metrics) +{ + public CacheDriverContext( + Optional pageSource, + Optional pageSink, + DynamicFilter dynamicFilter, + CacheMetrics cacheMetrics, + CacheStats cacheStats, + Metrics metrics) + { + this.pageSource = requireNonNull(pageSource, "pageSource is null"); + this.pageSink = requireNonNull(pageSink, "pageSink is null"); + this.dynamicFilter = requireNonNull(dynamicFilter, "dynamicFilter is null"); + this.cacheMetrics = requireNonNull(cacheMetrics, "cacheMetrics is null"); + this.cacheStats = requireNonNull(cacheStats, "cacheStats is null"); + this.metrics = requireNonNull(metrics, "metrics is null"); + } + + public CacheDriverContext withMetrics(Metrics metrics) + { + return new CacheDriverContext(pageSource, pageSink, dynamicFilter, cacheMetrics, cacheStats, metrics); + } + + public static DynamicFilter getDynamicFilter(OperatorContext context, DynamicFilter originalDynamicFilter) + { + return context.getDriverContext().getCacheDriverContext() + .map(CacheDriverContext::dynamicFilter) + .orElse(originalDynamicFilter); + } +} diff --git a/core/trino-main/src/main/java/io/trino/cache/CacheDriverFactory.java b/core/trino-main/src/main/java/io/trino/cache/CacheDriverFactory.java new file mode 100644 index 000000000000..58418e8a3f78 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/cache/CacheDriverFactory.java @@ -0,0 +1,415 @@ +/* + * 
Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.cache; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Ticker; +import com.google.common.collect.BiMap; +import com.google.common.collect.ImmutableBiMap; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.airlift.json.JsonCodec; +import io.airlift.units.Duration; +import io.trino.Session; +import io.trino.cache.CommonPlanAdaptation.PlanSignatureWithPredicate; +import io.trino.execution.ScheduledSplit; +import io.trino.metadata.Split; +import io.trino.metadata.TableHandle; +import io.trino.operator.Driver; +import io.trino.operator.DriverContext; +import io.trino.operator.DriverFactory; +import io.trino.plugin.base.cache.CacheUtils; +import io.trino.plugin.base.metrics.TDigestHistogram; +import io.trino.spi.TrinoException; +import io.trino.spi.cache.CacheColumnId; +import io.trino.spi.cache.CacheManager.SplitCache; +import io.trino.spi.cache.CacheSplitId; +import io.trino.spi.connector.ColumnHandle; +import io.trino.spi.connector.ConnectorPageSink; +import io.trino.spi.connector.ConnectorPageSource; +import io.trino.spi.metrics.Metrics; +import io.trino.spi.predicate.DiscreteValues; +import io.trino.spi.predicate.Domain; +import io.trino.spi.predicate.Ranges; +import io.trino.spi.predicate.TupleDomain; +import io.trino.split.PageSourceProvider; +import io.trino.sql.planner.plan.PlanNodeId; + +import 
/**
 * {@link DriverFactory} that picks, per split, one of the plan alternatives produced for a
 * cacheable subquery: the original plan, a plan that loads pages from the {@link SplitCache},
 * or a plan that stores produced pages into the cache.
 *
 * NOTE(review): generic type parameters below were reconstructed from usage because the
 * original text had them stripped — confirm against upstream source.
 */
public class CacheDriverFactory
        implements DriverFactory
{
    // Unenforced predicates with more values than this are not cached because such
    // large predicates are unlikely to be reused by other subqueries.
    static final int MAX_UNENFORCED_PREDICATE_VALUE_COUNT = 1_000_000;
    // Heuristic factor used when comparing original vs common dynamic filter sizes.
    static final double DYNAMIC_FILTER_VALUES_HEURISTIC = 0.05;

    // If the fraction of successfully cached splits drops below this ratio,
    // storing further splits is skipped to avoid cache-thrashing overhead.
    public static final float THRASHING_CACHE_THRESHOLD = 0.7f;
    // Minimum number of processed splits before the thrashing heuristic kicks in.
    public static final int MIN_PROCESSED_SPLITS = 16;

    private final int pipelineId;
    private final boolean inputDriver;
    private final boolean outputDriver;
    private final OptionalInt driverInstances;
    private final PlanNodeId alternativeSourceId;
    private final Session session;
    private final PageSourceProvider pageSourceProvider;
    // Per-signature cache handle obtained from the CacheManager; closed in noMoreDrivers().
    private final SplitCache splitCache;
    private final JsonCodec<TupleDomain<CacheColumnId>> tupleDomainCodec;
    private final TableHandle originalTableHandle;
    // Predicate enforced by the plan signature (in terms of cache column ids).
    private final TupleDomain<CacheColumnId> enforcedPredicate;
    // Maps connector ColumnHandle -> CacheColumnId (inverse of the constructor argument).
    private final BiMap<ColumnHandle, CacheColumnId> commonColumnHandles;
    // Maps each signature output column id to its position in the signature column list.
    private final Map<CacheColumnId, Integer> projectedColumns;
    private final Supplier<StaticDynamicFilter> commonDynamicFilterSupplier;
    private final Supplier<StaticDynamicFilter> originalDynamicFilterSupplier;
    // Alternative factories indexed by ORIGINAL_PLAN/LOAD_PAGES/STORE_PAGES constants.
    private final List<DriverFactory> alternatives;
    private final CacheMetrics cacheMetrics = new CacheMetrics();
    private final CacheStats cacheStats;
    private final Ticker ticker = Ticker.systemTicker();

    public CacheDriverFactory(
            int pipelineId,
            boolean inputDriver,
            boolean outputDriver,
            OptionalInt driverInstances,
            PlanNodeId alternativeSourceId,
            Session session,
            PageSourceProvider pageSourceProvider,
            CacheManagerRegistry cacheManagerRegistry,
            JsonCodec<TupleDomain<CacheColumnId>> tupleDomainCodec,
            TableHandle originalTableHandle,
            PlanSignatureWithPredicate planSignature,
            Map<CacheColumnId, ColumnHandle> commonColumnHandles,
            Supplier<StaticDynamicFilter> commonDynamicFilterSupplier,
            Supplier<StaticDynamicFilter> originalDynamicFilterSupplier,
            List<DriverFactory> alternatives,
            CacheStats cacheStats)
    {
        requireNonNull(planSignature, "planSignature is null");
        this.pipelineId = pipelineId;
        this.inputDriver = inputDriver;
        this.outputDriver = outputDriver;
        this.driverInstances = requireNonNull(driverInstances, "driverInstances is null");
        this.alternativeSourceId = requireNonNull(alternativeSourceId, "alternativeSourceId is null");
        this.session = requireNonNull(session, "session is null");
        this.pageSourceProvider = requireNonNull(pageSourceProvider, "pageSourceProvider is null");
        this.splitCache = requireNonNull(cacheManagerRegistry, "cacheManagerRegistry is null").getCacheManager().getSplitCache(planSignature.signature());
        this.tupleDomainCodec = requireNonNull(tupleDomainCodec, "tupleDomainCodec is null");
        this.originalTableHandle = requireNonNull(originalTableHandle, "originalTableHandle is null");
        this.enforcedPredicate = planSignature.predicate();
        // Invert so the map can translate connector column handles back to cache column ids.
        this.commonColumnHandles = ImmutableBiMap.copyOf(requireNonNull(commonColumnHandles, "commonColumnHandles is null")).inverse();
        List<CacheColumnId> columns = planSignature.signature().getColumns();
        this.projectedColumns = IntStream.range(0, columns.size()).boxed()
                .collect(toImmutableMap(columns::get, identity()));
        this.commonDynamicFilterSupplier = requireNonNull(commonDynamicFilterSupplier, "commonDynamicFilterSupplier is null");
        this.originalDynamicFilterSupplier = requireNonNull(originalDynamicFilterSupplier, "originalDynamicFilterSupplier is null");
        this.alternatives = requireNonNull(alternatives, "alternatives is null");
        this.cacheStats = requireNonNull(cacheStats, "cacheStats is null");
    }

    @Override
    public int getPipelineId()
    {
        return pipelineId;
    }

    @Override
    public boolean isInputDriver()
    {
        return inputDriver;
    }

    @Override
    public boolean isOutputDriver()
    {
        return outputDriver;
    }

    @Override
    public OptionalInt getDriverInstances()
    {
        return driverInstances;
    }

    /**
     * Creates a driver for the chosen alternative, recording how long the cache
     * lookup (alternative selection) took and attaching it as a driver metric.
     */
    @Override
    public Driver createDriver(DriverContext driverContext, Optional<ScheduledSplit> optionalSplit)
    {
        checkArgument(optionalSplit.isPresent());
        ScheduledSplit split = optionalSplit.get();

        DriverFactoryWithCacheContext driverFactory;
        long lookupStartNanos = ticker.read();
        try {
            driverFactory = chooseDriverFactory(split);
        }
        catch (Throwable t) {
            throw new TrinoException(GENERIC_INTERNAL_ERROR, "SUBQUERY CACHE: create driver exception", t);
        }
        long lookupDurationNanos = ticker.read() - lookupStartNanos;
        cacheStats.getCacheLookupTime().addNanos(lookupDurationNanos);

        driverFactory.context()
                .map(context -> context.withMetrics(new Metrics(ImmutableMap.of(
                        "Cache lookup time (ms)", TDigestHistogram.fromValue(new Duration(lookupDurationNanos, NANOSECONDS).convertTo(MILLISECONDS).getValue())))))
                .ifPresent(driverContext::setCacheDriverContext);
        return driverFactory.factory().createDriver(driverContext, optionalSplit);
    }

    @Override
    public void noMoreDrivers()
    {
        alternatives.forEach(DriverFactory::noMoreDrivers);
        closeSplitCache();
    }

    @Override
    public boolean isNoMoreDrivers()
    {
        // noMoreDrivers is called on each alternative, so we can use any alternative here
        return alternatives.iterator().next().isNoMoreDrivers();
    }

    @Override
    public void localPlannerComplete()
    {
        alternatives.forEach(DriverFactory::localPlannerComplete);
    }

    @Override
    public Optional<PlanNodeId> getSourceId()
    {
        return Optional.of(alternativeSourceId);
    }

    /**
     * Core selection logic: falls back to the original plan when the split cannot be
     * cached (no split id, failover, fully filtered, predicate too big), otherwise tries
     * to load cached pages, then tries to store pages unless the cache is thrashing.
     */
    private DriverFactoryWithCacheContext chooseDriverFactory(ScheduledSplit split)
    {
        Optional<CacheSplitId> cacheSplitIdOptional = split.getSplit().getCacheSplitId();
        if (cacheSplitIdOptional.isEmpty()) {
            // no split id, fallback to original plan
            cacheStats.recordMissingSplitId();
            return new DriverFactoryWithCacheContext(alternatives.get(ORIGINAL_PLAN_ALTERNATIVE), Optional.empty());
        }
        if (!split.getSplit().isSplitAddressEnforced()) {
            // failed to schedule split on the preferred node, fallback to original plan
            cacheStats.recordSplitFailoverHappened();
            return new DriverFactoryWithCacheContext(alternatives.get(ORIGINAL_PLAN_ALTERNATIVE), Optional.empty());
        }
        CacheSplitId splitId = cacheSplitIdOptional.get();

        StaticDynamicFilter originalDynamicFilter = originalDynamicFilterSupplier.get();
        StaticDynamicFilter commonDynamicFilter = commonDynamicFilterSupplier.get();
        StaticDynamicFilter dynamicFilter = resolveDynamicFilter(originalDynamicFilter, commonDynamicFilter);

        TupleDomain<CacheColumnId> enforcedPredicate = pruneEnforcedPredicate(split);
        TupleDomain<CacheColumnId> unenforcedPredicate = getDynamicRowFilteringUnenforcedPredicate(
                pageSourceProvider,
                session,
                split.getSplit(),
                originalTableHandle,
                dynamicFilter.getCurrentPredicate())
                // translate connector column handles into cache column ids
                .transformKeys(handle -> requireNonNull(commonColumnHandles.get(handle)));

        // skip caching of completely filtered out splits
        if (enforcedPredicate.isNone() || unenforcedPredicate.isNone()) {
            return new DriverFactoryWithCacheContext(alternatives.get(ORIGINAL_PLAN_ALTERNATIVE), Optional.empty());
        }

        // skip caching if unenforced predicate becomes too big,
        // because large predicates are not likely to be reused in other subqueries
        if (getTupleDomainValueCount(unenforcedPredicate) > MAX_UNENFORCED_PREDICATE_VALUE_COUNT) {
            cacheStats.recordPredicateTooBig();
            return new DriverFactoryWithCacheContext(alternatives.get(ORIGINAL_PLAN_ALTERNATIVE), Optional.empty());
        }

        ProjectPredicate projectedEnforcedPredicate = projectPredicate(enforcedPredicate);
        ProjectPredicate projectedUnenforcedPredicate = projectPredicate(unenforcedPredicate);
        CacheSplitId splitIdWithPredicates = appendRemainingPredicates(splitId, projectedEnforcedPredicate, projectedUnenforcedPredicate);

        // load data from cache
        Optional<ConnectorPageSource> pageSource = splitCache.loadPages(splitIdWithPredicates, projectedEnforcedPredicate.predicate(), projectedUnenforcedPredicate.predicate());
        if (pageSource.isPresent()) {
            cacheStats.recordCacheHit();
            return new DriverFactoryWithCacheContext(
                    alternatives.get(LOAD_PAGES_ALTERNATIVE),
                    Optional.of(new CacheDriverContext(pageSource, Optional.empty(), dynamicFilter, cacheMetrics, cacheStats, Metrics.EMPTY)));
        }
        else {
            cacheStats.recordCacheMiss();
        }

        int processedSplitCount = cacheMetrics.getSplitNotCachedCount() + cacheMetrics.getSplitCachedCount();
        // assume a perfect caching ratio until enough splits were processed
        float cachingRatio = processedSplitCount > MIN_PROCESSED_SPLITS ? cacheMetrics.getSplitCachedCount() / (float) processedSplitCount : 1.0f;
        // try storing results instead
        // if splits are too large to be cached then do not try caching data as it adds extra computational cost
        if (cachingRatio > THRASHING_CACHE_THRESHOLD) {
            Optional<ConnectorPageSink> pageSink = splitCache.storePages(splitIdWithPredicates, projectedEnforcedPredicate.predicate(), projectedUnenforcedPredicate.predicate());
            if (pageSink.isPresent()) {
                return new DriverFactoryWithCacheContext(
                        alternatives.get(STORE_PAGES_ALTERNATIVE),
                        Optional.of(new CacheDriverContext(Optional.empty(), pageSink, dynamicFilter, cacheMetrics, cacheStats, Metrics.EMPTY)));
            }
            else {
                cacheStats.recordSplitRejected();
            }
        }
        else {
            cacheStats.recordSplitsTooBig();
        }

        // fallback to original subplan
        return new DriverFactoryWithCacheContext(alternatives.get(ORIGINAL_PLAN_ALTERNATIVE), Optional.empty());
    }

    // Pairs the chosen alternative factory with an optional cache context to attach to the driver.
    private record DriverFactoryWithCacheContext(DriverFactory factory, Optional<CacheDriverContext> context) {}

    /**
     * Prunes the enforced predicate against the connector for the columns the connector
     * knows about, keeping domains on non-connector (synthetic) columns untouched.
     */
    private TupleDomain<CacheColumnId> pruneEnforcedPredicate(ScheduledSplit split)
    {
        return TupleDomain.intersect(ImmutableList.of(
                // prune scan domains of enforced predicate
                pageSourceProvider.prunePredicate(
                        session,
                        split.getSplit(),
                        originalTableHandle,
                        enforcedPredicate
                                .filter((columnId, domain) -> commonColumnHandles.containsValue(columnId))
                                .transformKeys(columnId -> commonColumnHandles.inverse().get(columnId)))
                        .transformKeys(commonColumnHandles::get),
                enforcedPredicate.filter((columnId, domain) -> !commonColumnHandles.containsValue(columnId))));
    }

    /**
     * Splits a predicate into the part covered by the signature's projected columns and a
     * JSON-serialized "remaining" part on non-projected columns (absent when trivially "all").
     */
    private ProjectPredicate projectPredicate(TupleDomain<CacheColumnId> predicate)
    {
        return new ProjectPredicate(
                predicate.filter((columnId, domain) -> projectedColumns.containsKey(columnId)),
                Optional.of(predicate.filter((columnId, domain) -> !projectedColumns.containsKey(columnId)))
                        .filter(domain -> !domain.isAll())
                        .map(CacheUtils::normalizeTupleDomain)
                        .map(tupleDomainCodec::toJson));
    }

    private record ProjectPredicate(TupleDomain<CacheColumnId> predicate, Optional<String> remainingPredicate) {}

    private static CacheSplitId appendRemainingPredicates(CacheSplitId splitId, ProjectPredicate enforcedPredicate, ProjectPredicate unenforcedPredicate)
    {
        return appendRemainingPredicates(splitId, enforcedPredicate.remainingPredicate(), unenforcedPredicate.remainingPredicate());
    }

    /**
     * Folds predicates that are not part of the plan signature into the split id so
     * cache entries with different residual predicates do not collide.
     */
    @VisibleForTesting
    static CacheSplitId appendRemainingPredicates(CacheSplitId splitId, Optional<String> remainingEnforcedPredicate, Optional<String> remainingUnenforcedPredicate)
    {
        if (remainingEnforcedPredicate.isEmpty() && remainingUnenforcedPredicate.isEmpty()) {
            return splitId;
        }
        return new CacheSplitId(toStringHelper("SplitId")
                .add("splitId", splitId)
                .add("enforcedPredicate", remainingEnforcedPredicate.orElse("all"))
                .add("unenforcedPredicate", remainingUnenforcedPredicate.orElse("all"))
                .toString());
    }

    public void closeSplitCache()
    {
        try {
            splitCache.close();
        }
        catch (IOException exception) {
            throw new UncheckedIOException(exception);
        }
    }

    /**
     * Chooses between the original and the common (shared) dynamic filter: prefers the
     * original one when the common one is absent, has fewer domains, or is much larger
     * (per DYNAMIC_FILTER_VALUES_HEURISTIC).
     */
    private StaticDynamicFilter resolveDynamicFilter(StaticDynamicFilter originalDynamicFilter, StaticDynamicFilter commonDynamicFilter)
    {
        TupleDomain<ColumnHandle> originalPredicate = originalDynamicFilter.getCurrentPredicate();
        TupleDomain<ColumnHandle> commonPredicate = commonDynamicFilter.getCurrentPredicate();

        if (commonPredicate.isNone() || originalPredicate.isNone()) {
            // prefer original DF when common DF is absent
            return originalDynamicFilter;
        }

        if (originalPredicate.getDomains().get().size() > commonPredicate.getDomains().get().size() ||
                getTupleDomainValueCount(originalPredicate) < getTupleDomainValueCount(commonPredicate) * DYNAMIC_FILTER_VALUES_HEURISTIC) {
            // prefer original DF when it contains more domains or original DF size is much smaller
            return originalDynamicFilter;
        }

        return commonDynamicFilter;
    }

    /**
     * Computes the unenforced predicate taking dynamic row filtering into account:
     * when row filtering is enabled the dynamic predicate is not simplified, but
     * ineffective columns are still pruned from it.
     */
    @VisibleForTesting
    public static TupleDomain<ColumnHandle> getDynamicRowFilteringUnenforcedPredicate(
            PageSourceProvider delegatePageSourceProvider,
            Session session,
            Split split,
            TableHandle table,
            TupleDomain<ColumnHandle> dynamicFilter)
    {
        if (!isEnableDynamicRowFiltering(session)) {
            return delegatePageSourceProvider.getUnenforcedPredicate(session, split, table, dynamicFilter);
        }

        TupleDomain<ColumnHandle> unenforcedPredicate = delegatePageSourceProvider.getUnenforcedPredicate(session, split, table, dynamicFilter);
        if (unenforcedPredicate.isNone()) {
            // split is fully filtered out
            return TupleDomain.none();
        }

        // DynamicRowFilteringPageSourceProvider doesn't simplify dynamic predicate,
        // but we can still prune columns from dynamic filter, which are ineffective
        // in filtering split data
        return unenforcedPredicate.intersect(delegatePageSourceProvider.prunePredicate(session, split, table, dynamicFilter));
    }

    @VisibleForTesting
    public CacheMetrics getCacheMetrics()
    {
        return cacheMetrics;
    }

    // Total number of discrete values/ranges across all domains; 0 for "none".
    private static int getTupleDomainValueCount(TupleDomain<?> tupleDomain)
    {
        return tupleDomain.getDomains()
                .map(domains -> domains.values().stream()
                        .mapToInt(CacheDriverFactory::getDomainValueCount)
                        .sum())
                .orElse(0);
    }

    private static int getDomainValueCount(Domain domain)
    {
        return domain.getValues().getValuesProcessor().transform(
                Ranges::getRangeCount,
                DiscreteValues::getValuesCount,
                // "all or none" value sets carry no enumerable values
                ignored -> 0);
    }
}
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.cache; + +import io.trino.sql.ir.Expression; + +import java.util.Optional; + +import static com.google.common.base.Preconditions.checkArgument; + +public record CacheExpression(Optional projection, Optional aggregation) +{ + public static CacheExpression ofProjection(Expression projection) + { + return new CacheExpression(Optional.of(projection), Optional.empty()); + } + + public static CacheExpression ofAggregation(CanonicalAggregation aggregation) + { + return new CacheExpression(Optional.empty(), Optional.of(aggregation)); + } + + public CacheExpression + { + checkArgument(projection.isPresent() != aggregation.isPresent(), "Expected exactly one to be present, got %s and %s", projection, aggregation); + } +} diff --git a/core/trino-main/src/main/java/io/trino/cache/CacheManagerModule.java b/core/trino-main/src/main/java/io/trino/cache/CacheManagerModule.java new file mode 100644 index 000000000000..8e985a0d9d1c --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/cache/CacheManagerModule.java @@ -0,0 +1,37 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.cache; + +import com.google.inject.Binder; +import com.google.inject.Module; +import com.google.inject.Scopes; + +import static io.airlift.configuration.ConfigBinder.configBinder; +import static org.weakref.jmx.guice.ExportBinder.newExporter; + +public class CacheManagerModule + implements Module +{ + @Override + public void configure(Binder binder) + { + configBinder(binder).bindConfig(CacheConfig.class); + binder.bind(CacheStats.class).in(Scopes.SINGLETON); + newExporter(binder).export(CacheStats.class).withGeneratedName(); + binder.bind(CacheManagerRegistry.class).in(Scopes.SINGLETON); + binder.bind(ConnectorAwareAddressProvider.class).in(Scopes.SINGLETON); + binder.bind(CacheController.class).in(Scopes.SINGLETON); + newExporter(binder).export(CacheManagerRegistry.class).withGeneratedName(); + } +} diff --git a/core/trino-main/src/main/java/io/trino/cache/CacheManagerRegistry.java b/core/trino-main/src/main/java/io/trino/cache/CacheManagerRegistry.java new file mode 100644 index 000000000000..6e9c93ae32fa --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/cache/CacheManagerRegistry.java @@ -0,0 +1,315 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
/**
 * {@link CacheManagerRegistry} is responsible for instantiation of {@link CacheManager}.
 * Additionally {@link CacheManagerRegistry} manages revoking of {@link CacheManager}
 * memory whenever necessary.
 *
 * NOTE(review): generic type parameters below were reconstructed from usage because the
 * original text had them stripped — confirm against upstream source.
 */
public class CacheManagerRegistry
{
    private static final Logger log = Logger.get(CacheManagerRegistry.class);

    static final File CONFIG_FILE = new File("etc/cache-manager.properties");
    private static final String CACHE_MANAGER_NAME_PROPERTY = "cache-manager.name";

    private final MemoryPool memoryPool;
    private final boolean enabled;
    // Free-memory fraction thresholds: revoking starts above revokingThreshold usage
    // and proceeds until usage drops to revokingTarget.
    private final double revokingThreshold;
    private final double revokingTarget;
    private final Map<String, CacheManagerFactory> cacheManagerFactories = new ConcurrentHashMap<>();
    // Single-threaded executor serializing memory-revoke requests.
    private final ExecutorService executor;
    private final BlockEncodingSerde blockEncodingSerde;
    // Ensures at most one revoke task is queued at a time.
    private final AtomicBoolean revokeRequested = new AtomicBoolean();
    private final Distribution sizeOfRevokedMemoryDistribution = new Distribution();
    private final AtomicInteger nonEmptyRevokeCount = new AtomicInteger();
    private final CacheStats cacheStats;

    // Set once by loadCacheManager; null until a cache manager is configured.
    private volatile CacheManager cacheManager;
    private volatile LocalMemoryContext revocableMemoryContext;

    @Inject
    public CacheManagerRegistry(CacheConfig cacheConfig, LocalMemoryManager localMemoryManager, BlockEncodingSerde blockEncodingSerde, CacheStats cacheStats)
    {
        this(cacheConfig, localMemoryManager, newSingleThreadExecutor(daemonThreadsNamed("cache-manager-registry")), blockEncodingSerde, cacheStats);
    }

    @VisibleForTesting
    CacheManagerRegistry(CacheConfig cacheConfig, LocalMemoryManager localMemoryManager, ExecutorService executor, BlockEncodingSerde blockEncodingSerde, CacheStats cacheStats)
    {
        requireNonNull(cacheConfig, "cacheConfig is null");
        requireNonNull(localMemoryManager, "localMemoryManager is null");
        requireNonNull(executor, "executor is null");
        requireNonNull(blockEncodingSerde, "blockEncodingSerde is null");
        this.enabled = cacheConfig.isEnabled();
        this.revokingThreshold = cacheConfig.getRevokingThreshold();
        this.revokingTarget = cacheConfig.getRevokingTarget();
        this.memoryPool = localMemoryManager.getMemoryPool();
        this.executor = executor;
        this.blockEncodingSerde = blockEncodingSerde;
        this.cacheStats = cacheStats;
    }

    /**
     * Registers a cache manager factory by name; rejects duplicate names.
     */
    public void addCacheManagerFactory(CacheManagerFactory factory)
    {
        requireNonNull(factory, "factory is null");
        if (cacheManagerFactories.putIfAbsent(factory.getName(), factory) != null) {
            throw new IllegalArgumentException(format("Cache manager factory '%s' is already registered", factory.getName()));
        }
    }

    /**
     * Loads the cache manager from etc/cache-manager.properties, or falls back to the
     * in-memory implementation when no config file exists. No-op when caching is disabled.
     */
    public void loadCacheManager()
    {
        if (!enabled) {
            // don't load CacheManager when caching is not enabled
            return;
        }

        if (!CONFIG_FILE.exists()) {
            // use MemoryCacheManager by default
            loadCacheManager(new MemoryCacheManagerFactory(), ImmutableMap.of());
            return;
        }

        Map<String, String> properties = loadProperties(CONFIG_FILE);
        String name = properties.remove(CACHE_MANAGER_NAME_PROPERTY);
        checkArgument(!isNullOrEmpty(name), "Cache manager configuration %s does not contain %s", CONFIG_FILE, CACHE_MANAGER_NAME_PROPERTY);
        loadCacheManager(name, properties);
    }

    public synchronized void loadCacheManager(String name, Map<String, String> properties)
    {
        CacheManagerFactory factory = cacheManagerFactories.get(name);
        checkArgument(factory != null, "Cache manager factory '%s' is not registered. Available factories: %s", name, cacheManagerFactories.keySet());
        loadCacheManager(factory, properties);
    }

    /**
     * Instantiates the cache manager with a revocable-memory allocator backed by the
     * node memory pool, and installs a pool listener that triggers revocation when
     * free memory drops below the revoking threshold.
     */
    public synchronized void loadCacheManager(CacheManagerFactory factory, Map<String, String> properties)
    {
        requireNonNull(factory, "cacheManagerFactory is null");
        log.info("-- Loading cache manager %s --", factory.getName());

        checkState(cacheManager == null, "cacheManager is already loaded");

        revocableMemoryContext = newRootAggregatedMemoryContext(
                createReservationHandler(bytes -> {
                    // do not allocate more memory if it would exceed revoking threshold
                    if (memoryRevokingNeeded(bytes)) {
                        // schedule memory revoke to free up some space for new splits to be cached
                        scheduleMemoryRevoke();
                        return false;
                    }

                    return memoryPool.tryReserveRevocable(bytes);
                }, memoryPool::freeRevocable), 0)
                .newLocalMemoryContext("CacheManager");
        CacheManagerContext context = new CacheManagerContext()
        {
            @Override
            public MemoryAllocator revocableMemoryAllocator()
            {
                return revocableMemoryContext::trySetBytes;
            }

            @Override
            public BlockEncodingSerde blockEncodingSerde()
            {
                return blockEncodingSerde;
            }
        };
        CacheManager cacheManager;
        // use the factory's classloader so plugin-provided managers resolve their own classes
        try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(factory.getClass().getClassLoader())) {
            cacheManager = factory.create(properties, context);
        }
        this.cacheManager = cacheManager;

        // revoke cache memory when revoking target is reached
        memoryPool.addListener(pool -> {
            if (memoryRevokingNeeded(0)) {
                scheduleMemoryRevoke();
            }
        });

        log.info("-- Loaded cache manager %s --", factory.getName());
    }

    /**
     * @throws TrinoException when no cache manager has been configured/loaded
     */
    public CacheManager getCacheManager()
    {
        CacheManager cacheManager = this.cacheManager;
        if (cacheManager == null) {
            throw new TrinoException(CACHE_MANAGER_NOT_CONFIGURED, "Cache manager must be configured for cache capabilities to be fully functional");
        }
        return cacheManager;
    }

    /**
     * Synchronously revokes all used pool memory from the cache manager.
     *
     * NOTE(review): reads the {@code cacheManager} field directly — if called before
     * loadCacheManager this NPEs inside the executor task instead of throwing
     * CACHE_MANAGER_NOT_CONFIGURED like getCacheManager(); confirm this is intended.
     */
    public void flushCache()
    {
        getFutureValue(executor.submit(() -> {
            long bytesToRevoke = memoryPool.getMaxBytes() - memoryPool.getFreeBytes();
            if (bytesToRevoke > 0) {
                cacheManager.revokeMemory(bytesToRevoke);
            }
        }));
    }

    @Managed
    public long getRevocableBytes()
    {
        if (revocableMemoryContext == null) {
            return 0;
        }
        return revocableMemoryContext.getBytes();
    }

    @Managed
    @Nested
    public Distribution getDistributionSizeRevokedMemory()
    {
        return sizeOfRevokedMemoryDistribution;
    }

    @Managed
    public int getNonEmptyRevokeCount()
    {
        return nonEmptyRevokeCount.get();
    }

    private static Map<String, String> loadProperties(File configFile)
    {
        try {
            return new HashMap<>(loadPropertiesFrom(configFile.getPath()));
        }
        catch (IOException e) {
            throw new UncheckedIOException("Failed to read configuration file: " + configFile, e);
        }
    }

    /**
     * Queues a revoke task; coalesces concurrent requests into a single queued task.
     */
    private void scheduleMemoryRevoke()
    {
        // allow at most one revoke request to be scheduled
        if (revokeRequested.getAndSet(true)) {
            return;
        }
        executor.submit(() -> {
            revokeRequested.set(false);
            do {
                if (!revokeMemory()) {
                    return;
                }
            }
            while (memoryRevokingNeeded(0));
        });
    }

    /**
     * Asks the cache manager to revoke enough memory to reach the revoking target.
     *
     * @return true when any bytes were actually revoked
     */
    private boolean revokeMemory()
    {
        long bytesToRevoke = -memoryPool.getFreeBytes() + (long) (memoryPool.getMaxBytes() * (1.0 - revokingTarget));
        if (bytesToRevoke <= 0) {
            return false;
        }

        long revokedBytes;
        try (TimeStat.BlockTimer ignore = cacheStats.recordRevokeMemoryTime()) {
            revokedBytes = cacheManager.revokeMemory(bytesToRevoke);
        }

        if (revokedBytes > 0) {
            sizeOfRevokedMemoryDistribution.add(revokedBytes);
            nonEmptyRevokeCount.incrementAndGet();
            return true;
        }

        return false;
    }

    // True when free memory (minus a hypothetical extra allocation) falls below the
    // revoking-threshold headroom.
    private boolean memoryRevokingNeeded(long additionalRevocableBytes)
    {
        return memoryPool.getFreeBytes() - additionalRevocableBytes < memoryPool.getMaxBytes() * (1.0 - revokingThreshold);
    }

    /**
     * Adapter turning try-reserve/free callbacks into a {@link MemoryReservationHandler};
     * blocking (future-based) reservation is unsupported for revocable cache memory.
     */
    private static MemoryReservationHandler createReservationHandler(Function<Long, Boolean> tryReserveHandler, Consumer<Long> freeHandler)
    {
        return new MemoryReservationHandler()
        {
            @Override
            public ListenableFuture<Void> reserveMemory(String allocationTag, long delta)
            {
                throw new IllegalStateException();
            }

            @Override
            public boolean tryReserveMemory(String allocationTag, long delta)
            {
                if (delta == 0) {
                    // empty allocation should always succeed
                    return true;
                }

                if (delta > 0) {
                    return tryReserveHandler.apply(delta);
                }

                freeHandler.accept(-delta);
                return true;
            }
        };
    }
}
+ */ +package io.trino.cache; + +import com.google.inject.Inject; +import io.trino.Session; +import io.trino.connector.CatalogServiceProvider; +import io.trino.metadata.TableHandle; +import io.trino.spi.cache.CacheColumnId; +import io.trino.spi.cache.CacheTableId; +import io.trino.spi.cache.ConnectorCacheMetadata; +import io.trino.spi.connector.CatalogHandle; +import io.trino.spi.connector.ColumnHandle; + +import java.util.Optional; + +import static java.util.Objects.requireNonNull; + +public class CacheMetadata +{ + private final CatalogServiceProvider> cacheMetadataProvider; + + @Inject + public CacheMetadata(CatalogServiceProvider> cacheMetadataProvider) + { + this.cacheMetadataProvider = requireNonNull(cacheMetadataProvider, "cacheMetadataProvider is null"); + } + + public Optional getCacheTableId(Session session, TableHandle tableHandle) + { + CatalogHandle catalogHandle = tableHandle.catalogHandle(); + Optional service = cacheMetadataProvider.getService(catalogHandle); + + return service.flatMap(cacheMetadata -> cacheMetadata.getCacheTableId(tableHandle.connectorHandle())); + } + + public Optional getCacheColumnId(Session session, TableHandle tableHandle, ColumnHandle columnHandle) + { + CatalogHandle catalogHandle = tableHandle.catalogHandle(); + Optional service = cacheMetadataProvider.getService(catalogHandle); + + return service.flatMap(cacheMetadata -> cacheMetadata.getCacheColumnId(tableHandle.connectorHandle(), columnHandle)); + } + + public TableHandle getCanonicalTableHandle(Session session, TableHandle tableHandle) + { + CatalogHandle catalogHandle = tableHandle.catalogHandle(); + Optional service = cacheMetadataProvider.getService(catalogHandle); + return service + .map(connectorCacheMetadata -> new TableHandle( + tableHandle.catalogHandle(), + connectorCacheMetadata.getCanonicalTableHandle(tableHandle.connectorHandle()), + tableHandle.transaction())) + .orElse(tableHandle); + } +} diff --git 
/**
 * Thread-safe counters tracking how many splits were and were not cached.
 */
public class CacheMetrics
{
    // Splits skipped because their data was too large to cache.
    private final AtomicInteger notCachedSplits = new AtomicInteger();
    // Splits whose data was successfully cached.
    private final AtomicInteger cachedSplits = new AtomicInteger();

    /**
     * @return number of splits not cached due to excessive split data size
     */
    public int getSplitNotCachedCount()
    {
        return notCachedSplits.get();
    }

    /**
     * @return number of splits that were cached
     */
    public int getSplitCachedCount()
    {
        return cachedSplits.get();
    }

    public void incrementSplitsNotCached()
    {
        notCachedSplits.incrementAndGet();
    }

    public void incrementSplitsCached()
    {
        cachedSplits.incrementAndGet();
    }
}
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.cache; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableList; +import com.google.common.util.concurrent.ListenableFuture; +import io.airlift.node.NodeInfo; +import io.trino.metadata.Split; +import io.trino.spi.HostAddress; +import io.trino.spi.cache.CacheManager; +import io.trino.spi.cache.CacheSplitId; +import io.trino.spi.cache.PlanSignature; +import io.trino.spi.connector.CatalogHandle; +import io.trino.spi.connector.ConnectorSplitManager; +import io.trino.split.SplitSource; + +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Queue; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.Executor; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Function; + +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.collect.Iterators.cycle; +import static com.google.common.util.concurrent.Futures.immediateFuture; +import static com.google.common.util.concurrent.Futures.transformAsync; +import static com.google.common.util.concurrent.MoreExecutors.directExecutor; +import static io.trino.spi.cache.PlanSignature.canonicalizePlanSignature; +import static java.util.Collections.shuffle; +import static java.util.Objects.requireNonNull; + +/** + * Assigns addresses provided by {@link CacheManager} to splits that + * are to be cached. 
+ */ +public class CacheSplitSource + implements SplitSource +{ + private final ConnectorSplitManager splitManager; + private final SplitSource delegate; + private final Function> addressProvider; + private final String canonicalSignature; + private final Map> splitQueuePerWorker = new ConcurrentHashMap<>(); + private final SplitAdmissionController splitAdmissionController; + private final int minSplitBatchSize; + private final Executor executor; + private final AtomicBoolean isLastBatchProcessed = new AtomicBoolean(false); + + public CacheSplitSource( + PlanSignature signature, + ConnectorSplitManager splitManager, + SplitSource delegate, + ConnectorAwareAddressProvider connectorAwareAddressProvider, + NodeInfo nodeInfo, + SplitAdmissionControllerProvider splitAdmissionControllerProvider, + boolean schedulerIncludeCoordinator, + int minSplitBatchSize, + Executor executor) + { + this.splitManager = requireNonNull(splitManager, "splitManager is null"); + this.delegate = requireNonNull(delegate, "delegate is null"); + ConsistentHashingAddressProvider consistentHashingAddressProvider = connectorAwareAddressProvider.getAddressProvider(nodeInfo, delegate.getCatalogHandle(), schedulerIncludeCoordinator); + consistentHashingAddressProvider.refreshHashRingIfNeeded(); + this.addressProvider = consistentHashingAddressProvider::getPreferredAddress; + this.canonicalSignature = canonicalizePlanSignature(signature).toString(); + this.splitAdmissionController = requireNonNull(splitAdmissionControllerProvider, "splitAdmissionControllerProvider is null").get(signature); + this.minSplitBatchSize = minSplitBatchSize; + this.executor = requireNonNull(executor, "executor is null"); + } + + @VisibleForTesting + CacheSplitSource( + PlanSignature signature, + ConnectorSplitManager splitManager, + SplitSource delegate, + Function> addressProvider, + SplitAdmissionController splitAdmissionController, + int minSplitBatchSize) + { + this.splitManager = requireNonNull(splitManager, 
"splitManager is null"); + this.delegate = requireNonNull(delegate, "delegate is null"); + this.addressProvider = requireNonNull(addressProvider, "addressProvider is null"); + this.canonicalSignature = canonicalizePlanSignature(signature).toString(); + this.splitAdmissionController = requireNonNull(splitAdmissionController, "splitAdmissionController is null"); + this.minSplitBatchSize = minSplitBatchSize; + // Set the executor to direct executor for testing purposes + this.executor = directExecutor(); + } + + @Override + public CatalogHandle getCatalogHandle() + { + return delegate.getCatalogHandle(); + } + + @Override + public ListenableFuture getNextBatch(int maxSize) + { + return assignAddressesAndGetMoreSplits(ImmutableList.of(), getSplitsFromQueue(maxSize), maxSize); + } + + private ListenableFuture assignAddressesAndGetMoreSplits(List newBatch, List currentBatch, int maxSize) + { + // Assign addresses to splits that are cacheable and don't have preferred addresses. + // Additionally, add splits to the queue if they cannot be scheduled at the moment. + ImmutableList.Builder batchBuilder = ImmutableList.builder(); + batchBuilder.addAll(currentBatch); + int currentSize = currentBatch.size(); + checkState(newBatch.size() <= maxSize - currentSize, "New split batch size exceeds the remaining capacity"); + for (Split split : newBatch) { + Optional splitId = splitManager.getCacheSplitId(split.getConnectorSplit()); + if (splitId.isEmpty()) { + batchBuilder.add(split); + currentSize++; + } + else { + Optional preferredAddress; + if (!split.isRemotelyAccessible()) { + // Choose first address from connector provided worker addresses, so that split is + // scheduled deterministically on the worker node. This is such that we reuse the cached splits + // on the worker nodes. 
+ preferredAddress = Optional.of(split.getAddresses().getFirst()); + } + else { + // Get the preferred address for the split using consistent hashing + preferredAddress = addressProvider.apply(canonicalSignature + splitId.get()); + } + if (preferredAddress.isPresent()) { + Split splitWithPreferredAddress = new Split( + split.getCatalogHandle(), + split.getConnectorSplit(), + splitId, + Optional.of(ImmutableList.of(preferredAddress.get())), + split.isSplitAddressEnforced()); + if (splitAdmissionController.canScheduleSplit(splitId.get(), preferredAddress.get())) { + batchBuilder.add(splitWithPreferredAddress); + currentSize++; + } + else { + splitQueuePerWorker.computeIfAbsent(preferredAddress.get(), _ -> new ConcurrentLinkedQueue<>()) + .add(splitWithPreferredAddress); + } + } + else { + // Skip caching if no preferred address could be located which could be due to no available nodes + batchBuilder.add(split); + currentSize++; + } + } + } + + // If the current batch is not full, try fetching more splits from the queue in case some + // splits became free to be scheduled. 
+ List splitsFromQueue = getSplitsFromQueue(maxSize - currentSize); + batchBuilder.addAll(splitsFromQueue); + currentSize += splitsFromQueue.size(); + + int remainingSize = maxSize - currentSize; + // If the current batch is still not full, fetch more splits from the source + if ((remainingSize > 0 && currentSize < minSplitBatchSize) && !isLastBatchProcessed.get()) { + return transformAsync( + delegate.getNextBatch(remainingSize), + nextBatch -> { + isLastBatchProcessed.set(nextBatch.isLastBatch()); + return assignAddressesAndGetMoreSplits(nextBatch.getSplits(), batchBuilder.build(), maxSize); + }, + executor); + } + + return immediateFuture(createSplitBatch(batchBuilder.build())); + } + + private List getSplitsFromQueue(int maxSize) + { + int currentSize = 0; + ImmutableList.Builder batchBuilder = ImmutableList.builder(); + List>> queues = new ArrayList<>(splitQueuePerWorker.entrySet()); + // randomize queue order to prevent scheduling skewness + shuffle(queues); + + // When there are no more new splits (i.e. isLastBatchProcessed=true), forcefully release queued + // splits in order to avoid increasing of query latency at the cost of potential + // cache rejections. Additionally, if we don't do it, there is a possibility that the + // splits in the queue will never be scheduled (deadlock). For example, during self-join. 
+ boolean forceRelease = isLastBatchProcessed.get() || getSplitQueueSize() > 1_000_000; + + for (Iterator>> iter = cycle(queues); + iter.hasNext() && currentSize < maxSize; ) { + Map.Entry> entry = iter.next(); + HostAddress address = entry.getKey(); + Queue splitQueue = entry.getValue(); + Split split = splitQueue.peek(); + if (split == null + || !(forceRelease + || splitAdmissionController.canScheduleSplit(split.getCacheSplitId().orElseThrow(), address))) { + iter.remove(); + continue; + } + + batchBuilder.add(split); + splitQueue.remove(); + currentSize++; + } + return batchBuilder.build(); + } + + private SplitBatch createSplitBatch(List splits) + { + return new SplitBatch(splits, isLastBatchProcessed.get() && getSplitQueueSize() == 0); + } + + private int getSplitQueueSize() + { + return splitQueuePerWorker.values().stream() + .mapToInt(Queue::size) + .sum(); + } + + @Override + public void close() + { + delegate.close(); + } + + @Override + public boolean isFinished() + { + return delegate.isFinished(); + } + + @Override + public Optional> getTableExecuteSplitsInfo() + { + return delegate.getTableExecuteSplitsInfo(); + } +} diff --git a/core/trino-main/src/main/java/io/trino/cache/CacheStats.java b/core/trino-main/src/main/java/io/trino/cache/CacheStats.java new file mode 100644 index 000000000000..19cbe33449d5 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/cache/CacheStats.java @@ -0,0 +1,165 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.cache; + +import com.google.errorprone.annotations.MustBeClosed; +import io.airlift.stats.CounterStat; +import io.airlift.stats.DistributionStat; +import io.airlift.stats.TimeStat; +import io.airlift.stats.TimeStat.BlockTimer; +import org.weakref.jmx.Managed; +import org.weakref.jmx.Nested; + +public class CacheStats +{ + private final CounterStat cacheHits = new CounterStat(); + private final CounterStat cacheMiss = new CounterStat(); + private final CounterStat splitRejected = new CounterStat(); + private final CounterStat splitFailoverHappened = new CounterStat(); + private final CounterStat missingSplitId = new CounterStat(); + private final CounterStat predicateTooBig = new CounterStat(); + private final CounterStat splitsTooBig = new CounterStat(); + private final DistributionStat readFromCacheData = new DistributionStat(); + private final DistributionStat cachedData = new DistributionStat(); + private final TimeStat revokeMemoryTime = new TimeStat(); + private final TimeStat cacheLookupTime = new TimeStat(); + + @Managed + @Nested + public CounterStat getCacheHits() + { + return cacheHits; + } + + @Managed + @Nested + public CounterStat getCacheMiss() + { + return cacheMiss; + } + + @Managed + @Nested + public CounterStat getSplitRejected() + { + return splitRejected; + } + + @Managed + @Nested + public CounterStat getSplitFailoverHappened() + { + return splitFailoverHappened; + } + + @Managed + @Nested + public CounterStat getMissingSplitId() + { + return missingSplitId; + } + + @Managed + @Nested + public CounterStat getPredicateTooBig() + { + return predicateTooBig; + } + + @Managed + @Nested + public CounterStat getSplitsTooBig() + { + return splitsTooBig; + } + + @Managed + @Nested + public DistributionStat getReadFromCacheData() + { + return readFromCacheData; + } + + @Managed + @Nested + public DistributionStat 
getCachedData() + { + return cachedData; + } + + @Managed + @Nested + public TimeStat getRevokeMemoryTime() + { + return revokeMemoryTime; + } + + @Managed + @Nested + public TimeStat getCacheLookupTime() + { + return cacheLookupTime; + } + + public void recordCacheMiss() + { + cacheMiss.update(1); + } + + public void recordCacheHit() + { + cacheHits.update(1); + } + + public void recordSplitRejected() + { + splitRejected.update(1); + } + + public void recordSplitFailoverHappened() + { + splitFailoverHappened.update(1); + } + + public void recordMissingSplitId() + { + missingSplitId.update(1); + } + + public void recordPredicateTooBig() + { + predicateTooBig.update(1); + } + + public void recordSplitsTooBig() + { + splitsTooBig.update(1); + } + + public void recordReadFromCacheData(long bytes) + { + readFromCacheData.add(bytes); + } + + public void recordCacheData(long bytes) + { + cachedData.add(bytes); + } + + @MustBeClosed + public BlockTimer recordRevokeMemoryTime() + { + return revokeMemoryTime.time(); + } +} diff --git a/core/trino-main/src/main/java/io/trino/cache/CanonicalAggregation.java b/core/trino-main/src/main/java/io/trino/cache/CanonicalAggregation.java new file mode 100644 index 000000000000..6111342bad75 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/cache/CanonicalAggregation.java @@ -0,0 +1,37 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.cache; + +import com.google.common.collect.ImmutableList; +import io.trino.metadata.ResolvedFunction; +import io.trino.sql.ir.Expression; +import io.trino.sql.planner.Symbol; + +import java.util.List; +import java.util.Optional; + +import static java.util.Objects.requireNonNull; + +public record CanonicalAggregation( + ResolvedFunction resolvedFunction, + Optional mask, + List arguments) +{ + public CanonicalAggregation + { + requireNonNull(resolvedFunction, "resolvedFunction is null"); + requireNonNull(mask, "mask is null"); + arguments = ImmutableList.copyOf(requireNonNull(arguments, "arguments is null")); + } +} diff --git a/core/trino-main/src/main/java/io/trino/cache/CanonicalSubplan.java b/core/trino-main/src/main/java/io/trino/cache/CanonicalSubplan.java new file mode 100644 index 000000000000..df0924da86a4 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/cache/CanonicalSubplan.java @@ -0,0 +1,431 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.cache; + +import com.google.common.collect.BiMap; +import com.google.common.collect.ImmutableBiMap; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import com.google.errorprone.annotations.CanIgnoreReturnValue; +import io.trino.metadata.TableHandle; +import io.trino.spi.cache.CacheColumnId; +import io.trino.spi.cache.CacheTableId; +import io.trino.spi.connector.ColumnHandle; +import io.trino.spi.connector.SortOrder; +import io.trino.spi.predicate.TupleDomain; +import io.trino.sql.ir.Expression; +import io.trino.sql.planner.Symbol; +import io.trino.sql.planner.optimizations.SymbolMapper; +import io.trino.sql.planner.plan.PlanNode; +import io.trino.sql.planner.plan.PlanNodeId; +import io.trino.sql.planner.plan.TopNRankingNode.RankingType; + +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.collect.Iterables.getLast; +import static io.trino.cache.CanonicalSubplanExtractor.columnIdToSymbol; +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; + +/** + * This class represents a canonical subplan. Canonical subplan nodes can either be + * leaf nodes ({@link CanonicalSubplan#tableScan} is present) or intermediate nodes + * ({@link CanonicalSubplan#childSubplan} is present). + * Canonical subplan nodes map to a plan subgraph for which a common subplan can be extracted. + * For instance, a canonical subplan can be extracted for "scan -> filter -> project" chain of operators + * or for "aggregation" operator. + * Canonical symbol names are derived from {@link CacheColumnId} (if symbol represents + * {@link ColumnHandle}). {@link CacheColumnId} for complex projections will use canonicalized and formatted + * version of projection expression. 
+ */ +public class CanonicalSubplan +{ + /** + * Keychain of a {@link CanonicalSubplan} tree. Plans that can be adapted + * to produce the same results will have the same keychain. + */ + private final List keyChain; + /** + * {@link PlanNodeId} of original table scan that can be used to identify subquery. + */ + private final PlanNodeId tableScanId; + /** + * Mapped and propagated enforced constraint from original table scan node. + */ + private final TupleDomain enforcedConstraint; + /** + * Reference to {@link PlanNode} from which {@link CanonicalSubplan} was derived. + */ + private final PlanNode originalPlanNode; + /** + * Mapping from {@link CacheColumnId} to original {@link Symbol} for entire subplan + * (including child subplans). + */ + private final BiMap originalSymbolMapping; + + /** + * Group by columns that are part of {@link CanonicalSubplan#assignments}. + */ + private final Optional> groupByColumns; + /** + * Output projections with iteration order matching list of output columns. + * Symbol names are canonicalized as {@link CacheColumnId}. + */ + private final Map assignments; + /** + * Filtering conjuncts. Symbol names are canonicalized as {@link CacheColumnId}. + */ + private final List conjuncts; + /** + * Set of conjuncts that can be pulled up though intermediate nodes as a top level filter node. + */ + private final Set pullableConjuncts; + /** + * List of dynamic filters. + */ + private final List dynamicConjuncts; + /** + * If present then this {@link CanonicalSubplan} was created on top of table scan. + */ + private final Optional tableScan; + /** + * If present then this {@link CanonicalSubplan} has another {@link CanonicalSubplan} as it's direct child. 
+ */ + private final Optional childSubplan; + + private CanonicalSubplan( + Key key, + PlanNodeId tableScanId, + TupleDomain enforcedConstraint, + PlanNode originalPlanNode, + BiMap originalSymbolMapping, + Optional> groupByColumns, + Map assignments, + List conjuncts, + Set pullableConjuncts, + List dynamicConjuncts, + Optional tableScan, + Optional childSubplan) + { + this.tableScanId = requireNonNull(tableScanId, "tableScanId is null"); + this.enforcedConstraint = requireNonNull(enforcedConstraint, "enforcedConstraint is null"); + this.originalPlanNode = requireNonNull(originalPlanNode, "originalPlanNode is null"); + this.originalSymbolMapping = ImmutableBiMap.copyOf(requireNonNull(originalSymbolMapping, "originalSymbolMapping is null")); + this.groupByColumns = requireNonNull(groupByColumns, "groupByColumns is null").map(ImmutableSet::copyOf); + this.assignments = ImmutableMap.copyOf(requireNonNull(assignments, "assignments is null")); + this.conjuncts = ImmutableList.copyOf(requireNonNull(conjuncts, "conjuncts is null")); + this.pullableConjuncts = ImmutableSet.copyOf(requireNonNull(pullableConjuncts, "pullableConjuncts is null")); + this.dynamicConjuncts = ImmutableList.copyOf(requireNonNull(dynamicConjuncts, "dynamicConjuncts is null")); + this.tableScan = requireNonNull(tableScan, "tableScan is null"); + this.childSubplan = requireNonNull(childSubplan, "childSubplan is null"); + checkArgument(tableScan.isPresent() != childSubplan.isPresent(), "Source must be either table scan or child subplan"); + keyChain = ImmutableList.builder() + .addAll(childSubplan + .map(CanonicalSubplan::getKeyChain) + .orElse(ImmutableList.of())) + .add(requireNonNull(key, "key is null")) + .build(); + } + + public SymbolMapper canonicalSymbolMapper() + { + return new SymbolMapper(symbol -> { + CacheColumnId columnId = originalSymbolMapping.inverse().get(symbol); + requireNonNull(columnId, format("No column id for symbol %s", symbol)); + return columnIdToSymbol(columnId, 
symbol.type()); + }); + } + + public Key getKey() + { + return getLast(keyChain); + } + + public List getKeyChain() + { + return keyChain; + } + + public PlanNodeId getTableScanId() + { + return tableScanId; + } + + public TupleDomain getEnforcedConstraint() + { + return enforcedConstraint; + } + + public PlanNode getOriginalPlanNode() + { + return originalPlanNode; + } + + public BiMap getOriginalSymbolMapping() + { + return originalSymbolMapping; + } + + public Optional> getGroupByColumns() + { + return groupByColumns; + } + + public Map getAssignments() + { + return assignments; + } + + public List getConjuncts() + { + return conjuncts; + } + + public Set getPullableConjuncts() + { + return pullableConjuncts; + } + + public List getDynamicConjuncts() + { + return dynamicConjuncts; + } + + public Optional getTableScan() + { + return tableScan; + } + + public Optional getChildSubplan() + { + return childSubplan; + } + + public static class TableScan + { + /** + * Mapping from {@link CacheColumnId} to {@link ColumnHandle}. + */ + private final Map columnHandles; + /** + * Original table handle. + */ + private final TableHandle table; + /** + * {@link CacheTableId} of scanned table. + */ + private final CacheTableId tableId; + /** + * Whether to use connector provided node partitioning for table scan. 
+ */ + private final boolean useConnectorNodePartitioning; + + public TableScan( + Map columnHandles, + TableHandle table, + CacheTableId tableId, + boolean useConnectorNodePartitioning) + { + this.columnHandles = ImmutableMap.copyOf(requireNonNull(columnHandles, "columnHandles is null")); + this.table = requireNonNull(table, "table is null"); + this.tableId = requireNonNull(tableId, "tableId is null"); + this.useConnectorNodePartitioning = useConnectorNodePartitioning; + } + + public Map getColumnHandles() + { + return columnHandles; + } + + public TableHandle getTable() + { + return table; + } + + public CacheTableId getTableId() + { + return tableId; + } + + public boolean isUseConnectorNodePartitioning() + { + return useConnectorNodePartitioning; + } + } + + public static CanonicalSubplanBuilder builderForTableScan( + Key key, + Map columnHandles, + TableHandle table, + CacheTableId tableId, + TupleDomain enforcedConstraint, + boolean useConnectorNodePartitioning, + PlanNodeId tableScanId) + { + return new CanonicalSubplanBuilder( + key, + tableScanId, + enforcedConstraint, + Optional.of(new TableScan(columnHandles, table, tableId, useConnectorNodePartitioning)), + Optional.empty()); + } + + public static CanonicalSubplanBuilder builderForChildSubplan(Key key, CanonicalSubplan childSubplan) + { + requireNonNull(childSubplan, "childSubplan is null"); + return new CanonicalSubplanBuilder( + key, + childSubplan.getTableScanId(), + childSubplan.getEnforcedConstraint(), + Optional.empty(), + Optional.of(childSubplan)); + } + + public static CanonicalSubplanBuilder builderExtending(CanonicalSubplan subplan) + { + return builderExtending(subplan.getKey(), subplan); + } + + public static CanonicalSubplanBuilder builderExtending(Key key, CanonicalSubplan subplan) + { + requireNonNull(subplan, "subplan is null"); + return new CanonicalSubplanBuilder(key, subplan.getTableScanId(), subplan.getEnforcedConstraint(), subplan.getTableScan(), subplan.getChildSubplan()) + 
.conjuncts(subplan.getConjuncts()) + .dynamicConjuncts(subplan.getDynamicConjuncts()); + } + + public static class CanonicalSubplanBuilder + { + private final Key key; + private final PlanNodeId tableScanId; + private final TupleDomain enforcedConstraint; + private final Optional tableScan; + private final Optional childSubplan; + private PlanNode originalPlanNode; + private BiMap originalSymbolMapping; + private Optional> groupByColumns = Optional.empty(); + private Map assignments; + private List conjuncts = ImmutableList.of(); + private Set pullableConjuncts; + private List dynamicConjuncts = ImmutableList.of(); + + private CanonicalSubplanBuilder( + Key key, + PlanNodeId tableScanId, + TupleDomain enforcedConstraint, + Optional tableScan, + Optional childSubplan) + { + this.key = requireNonNull(key, "key is null"); + this.tableScanId = requireNonNull(tableScanId, "tableScanId is null"); + this.enforcedConstraint = requireNonNull(enforcedConstraint, "enforcedConstraint is null"); + this.tableScan = requireNonNull(tableScan, "tableScan is null"); + this.childSubplan = requireNonNull(childSubplan, "childSubplan is null"); + } + + @CanIgnoreReturnValue + public CanonicalSubplanBuilder originalPlanNode(PlanNode originalPlanNode) + { + this.originalPlanNode = requireNonNull(originalPlanNode, "originalPlanNode is null"); + return this; + } + + @CanIgnoreReturnValue + public CanonicalSubplanBuilder originalSymbolMapping(BiMap originalSymbolMapping) + { + this.originalSymbolMapping = requireNonNull(originalSymbolMapping, "originalSymbolMapping is null"); + return this; + } + + @CanIgnoreReturnValue + public CanonicalSubplanBuilder groupByColumns(Set groupByColumns) + { + this.groupByColumns = Optional.of(requireNonNull(groupByColumns, "groupByColumns is null")); + return this; + } + + @CanIgnoreReturnValue + public CanonicalSubplanBuilder assignments(Map assignments) + { + this.assignments = requireNonNull(assignments, "assignments is null"); + return this; + } + + 
@CanIgnoreReturnValue + public CanonicalSubplanBuilder conjuncts(List conjuncts) + { + this.conjuncts = requireNonNull(conjuncts, "conjuncts is null"); + return this; + } + + @CanIgnoreReturnValue + public CanonicalSubplanBuilder pullableConjuncts(Set pullableConjuncts) + { + this.pullableConjuncts = requireNonNull(pullableConjuncts, "pullableConjuncts is null"); + return this; + } + + @CanIgnoreReturnValue + public CanonicalSubplanBuilder dynamicConjuncts(List dynamicConjuncts) + { + this.dynamicConjuncts = requireNonNull(dynamicConjuncts, "dynamicConjuncts is null"); + return this; + } + + public CanonicalSubplan build() + { + return new CanonicalSubplan( + key, + tableScanId, + enforcedConstraint, + originalPlanNode, + originalSymbolMapping, + groupByColumns, + assignments, + conjuncts, + pullableConjuncts, + dynamicConjuncts, + tableScan, + childSubplan); + } + } + + public record ScanFilterProjectKey(CacheTableId tableId, Set requiredConjuncts) + implements Key {} + + public record FilterProjectKey(Set requiredConjuncts) + implements Key {} + + public record AggregationKey(Set groupByColumns, Set nonPullableConjuncts) + implements Key {} + + public record TopNKey(List orderBy, Map orderings, long count, Set nonPullableConjuncts) + implements Key {} + + public record TopNRankingKey( + List partitionBy, + List orderBy, + Map orderings, + RankingType rankingType, + int maxRankingPerPartition, + Set nonPullableConjuncts) + implements Key {} + + public interface Key {} +} diff --git a/core/trino-main/src/main/java/io/trino/cache/CanonicalSubplanExtractor.java b/core/trino-main/src/main/java/io/trino/cache/CanonicalSubplanExtractor.java new file mode 100644 index 000000000000..5df3264ca662 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/cache/CanonicalSubplanExtractor.java @@ -0,0 +1,645 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.cache; + +import com.google.common.collect.BiMap; +import com.google.common.collect.ImmutableBiMap; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import com.google.common.graph.Traverser; +import io.trino.Session; +import io.trino.cache.CanonicalSubplan.AggregationKey; +import io.trino.cache.CanonicalSubplan.CanonicalSubplanBuilder; +import io.trino.cache.CanonicalSubplan.FilterProjectKey; +import io.trino.cache.CanonicalSubplan.Key; +import io.trino.cache.CanonicalSubplan.ScanFilterProjectKey; +import io.trino.cache.CanonicalSubplan.TopNKey; +import io.trino.metadata.TableHandle; +import io.trino.spi.cache.CacheColumnId; +import io.trino.spi.cache.CacheTableId; +import io.trino.spi.connector.ColumnHandle; +import io.trino.spi.connector.SortOrder; +import io.trino.spi.predicate.Domain; +import io.trino.spi.predicate.TupleDomain; +import io.trino.spi.type.Type; +import io.trino.sql.PlannerContext; +import io.trino.sql.ir.Expression; +import io.trino.sql.ir.ExpressionFormatter; +import io.trino.sql.ir.Lambda; +import io.trino.sql.ir.Reference; +import io.trino.sql.planner.DeterminismEvaluator; +import io.trino.sql.planner.OrderingScheme; +import io.trino.sql.planner.Symbol; +import io.trino.sql.planner.SymbolsExtractor; +import io.trino.sql.planner.optimizations.SymbolMapper; +import io.trino.sql.planner.plan.AggregationNode; +import io.trino.sql.planner.plan.AggregationNode.Aggregation; +import 
io.trino.sql.planner.plan.ChooseAlternativeNode; +import io.trino.sql.planner.plan.FilterNode; +import io.trino.sql.planner.plan.PlanNode; +import io.trino.sql.planner.plan.PlanVisitor; +import io.trino.sql.planner.plan.ProjectNode; +import io.trino.sql.planner.plan.TableScanNode; +import io.trino.sql.planner.plan.TopNNode; +import io.trino.sql.planner.plan.TopNRankingNode; + +import java.util.Comparator; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.base.Predicates.instanceOf; +import static com.google.common.base.Verify.verify; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.ImmutableMap.toImmutableMap; +import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static com.google.common.collect.Streams.stream; +import static io.trino.cache.CanonicalSubplan.TopNRankingKey; +import static io.trino.sql.DynamicFilters.extractDynamicFilters; +import static io.trino.sql.DynamicFilters.isDynamicFilter; +import static io.trino.sql.ir.ExpressionFormatter.formatExpression; +import static io.trino.sql.ir.IrExpressions.mayFail; +import static io.trino.sql.ir.IrUtils.extractConjuncts; +import static io.trino.sql.planner.DeterminismEvaluator.isDeterministic; +import static io.trino.sql.planner.ExpressionExtractor.extractExpressions; +import static io.trino.sql.planner.plan.AggregationNode.Step.PARTIAL; +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; +import static java.util.function.Function.identity; +import static java.util.stream.Collectors.joining; +import static java.util.stream.Collectors.partitioningBy; + +public final class CanonicalSubplanExtractor +{ + private 
CanonicalSubplanExtractor() {} + + /** + * Extracts a list of {@link CanonicalSubplan} for a given plan. + */ + public static List extractCanonicalSubplans(PlannerContext plannerContext, Session session, PlanNode root) + { + ImmutableList.Builder canonicalSubplans = ImmutableList.builder(); + root.accept(new Visitor(plannerContext, session, canonicalSubplans), null).ifPresent(canonicalSubplans::add); + return canonicalSubplans.build(); + } + + public static CacheColumnId canonicalSymbolToColumnId(Symbol symbol) + { + requireNonNull(symbol, "symbol is null"); + return canonicalExpressionToColumnId(symbol.toSymbolReference()); + } + + public static CacheColumnId canonicalAggregationToColumnId(CanonicalAggregation aggregation) + { + StringBuilder builder = new StringBuilder(); + builder.append("aggregation ") + .append(aggregation.resolvedFunction().name().toString()) + .append('(') + .append(aggregation.arguments().stream() + .map(ExpressionFormatter::formatExpression) + .collect(joining(", "))) + .append(')'); + aggregation.mask().ifPresent(mask -> { + builder.append(" FILTER (WHERE ").append(formatExpression(mask.toSymbolReference())).append(')'); + }); + return new CacheColumnId("(" + builder + ")"); + } + + public static CacheColumnId canonicalExpressionToColumnId(Expression expression) + { + requireNonNull(expression, "expression is null"); + if (expression instanceof Reference symbolReference) { + // symbol -> column id translation should be reversible via columnIdToSymbol method + return new CacheColumnId(symbolReference.name()); + } + + // Make CacheColumnIds for complex expressions always wrapped in '()' so they are distinguishable from + // CacheColumnIds derived from connectors. 
+ return new CacheColumnId("(" + formatExpression(expression) + ")"); + } + + public static Symbol columnIdToSymbol(CacheColumnId columnId, Type type) + { + requireNonNull(columnId, "columnId is null"); + return new Symbol(type, columnId.toString()); + } + + private static class Visitor + extends PlanVisitor, Void> + { + private final PlannerContext plannerContext; + private final Session session; + private final ImmutableList.Builder canonicalSubplans; + + public Visitor(PlannerContext plannerContext, Session session, ImmutableList.Builder canonicalSubplans) + { + this.plannerContext = requireNonNull(plannerContext, "plannerContext is null"); + this.session = requireNonNull(session, "session is null"); + this.canonicalSubplans = requireNonNull(canonicalSubplans, "canonicalSubplans is null"); + } + + @Override + protected Optional visitPlan(PlanNode node, Void context) + { + node.getSources().forEach(this::canonicalizeRecursively); + return Optional.empty(); + } + + @Override + public Optional visitChooseAlternativeNode(ChooseAlternativeNode node, Void context) + { + // do not canonicalize plans that already contain alternative + return Optional.empty(); + } + + @Override + public Optional visitAggregation(AggregationNode node, Void context) + { + PlanNode source = node.getSource(); + // only subset of aggregations is supported + if (!(node.getGroupingSetCount() == 1 + && node.getPreGroupedSymbols().isEmpty() + && node.getStep() == PARTIAL + && node.getGroupIdSymbol().isEmpty() + && node.getHashSymbol().isEmpty())) { + canonicalizeRecursively(source); + return Optional.empty(); + } + + // only subset of aggregation functions are supported + boolean allSupportedAggregations = node.getAggregations().values().stream().allMatch(aggregation -> + // only symbol arguments are supported (no lambdas yet) + aggregation.getArguments().stream().allMatch(argument -> argument instanceof Reference) + && !aggregation.isDistinct() + && aggregation.getFilter().isEmpty() + && 
aggregation.getOrderingScheme().isEmpty() + && aggregation.getResolvedFunction().deterministic()); + + if (!allSupportedAggregations) { + canonicalizeRecursively(source); + return Optional.empty(); + } + + // always add non-aggregated canonical subplan so that it can be matched against other + // non-aggregated subqueries + Optional subplanOptional = canonicalizeRecursively(source); + if (subplanOptional.isEmpty()) { + return Optional.empty(); + } + + // evaluate mapping from subplan symbols to canonical expressions + CanonicalSubplan subplan = subplanOptional.get(); + SymbolMapper canonicalSymbolMapper = subplan.canonicalSymbolMapper(); + BiMap originalSymbolMapping = subplan.getOriginalSymbolMapping(); + Map assignments = new LinkedHashMap<>(); + + // canonicalize grouping columns + ImmutableSet.Builder groupByColumnsBuilder = ImmutableSet.builder(); + for (Symbol groupingKey : node.getGroupingKeys()) { + CacheColumnId columnId = requireNonNull(originalSymbolMapping.inverse().get(groupingKey)); + groupByColumnsBuilder.add(columnId); + if (assignments.put(columnId, CacheExpression.ofProjection(columnIdToSymbol(columnId, groupingKey.type()).toSymbolReference())) != null) { + // duplicated column ids are not supported + return Optional.empty(); + } + } + + // canonicalize aggregation functions + ImmutableBiMap.Builder symbolMappingBuilder = ImmutableBiMap.builder() + .putAll(originalSymbolMapping); + for (Map.Entry entry : node.getAggregations().entrySet()) { + Symbol symbol = entry.getKey(); + Aggregation aggregation = entry.getValue(); + CanonicalAggregation canonicalAggregation = new CanonicalAggregation( + aggregation.getResolvedFunction(), + aggregation.getMask().map(canonicalSymbolMapper::map), + aggregation.getArguments().stream() + .map(canonicalSymbolMapper::map) + .collect(toImmutableList())); + CacheColumnId columnId = canonicalAggregationToColumnId(canonicalAggregation); + if (assignments.put(columnId, CacheExpression.ofAggregation(canonicalAggregation)) 
!= null) { + // duplicated column ids are not supported + return Optional.empty(); + } + if (originalSymbolMapping.containsKey(columnId)) { + // might happen if function call is projected by user explicitly + return Optional.empty(); + } + symbolMappingBuilder.put(columnId, symbol); + } + + // conjuncts that only contain group by symbols are pullable + BiMap symbolMapping = symbolMappingBuilder.buildOrThrow(); + Set groupByColumns = groupByColumnsBuilder.build(); + Set groupBySymbols = groupByColumns.stream() + .map(id -> columnIdToSymbol(id, symbolMapping.get(id).type())) + .collect(toImmutableSet()); + Map> conjuncts = subplan.getPullableConjuncts().stream() + .collect(partitioningBy(expression -> groupBySymbols.containsAll(SymbolsExtractor.extractAll(expression)))); + Set pullableConjuncts = ImmutableSet.copyOf(conjuncts.get(true)); + Set nonPullableConjuncts = ImmutableSet.copyOf(conjuncts.get(false)); + + // validate order of assignments with aggregation output columns + verify(ImmutableList.copyOf(assignments.keySet()) + .equals(node.getOutputSymbols().stream() + .map(symbol -> requireNonNull(symbolMapping.inverse().get(symbol))) + .collect(toImmutableList())), + "Assignments order doesn't match aggregation output symbols order"); + + return Optional.of(CanonicalSubplan.builderForChildSubplan(new AggregationKey(groupByColumns, nonPullableConjuncts), subplan) + .originalPlanNode(node) + .originalSymbolMapping(symbolMapping) + .groupByColumns(groupByColumns) + .assignments(assignments) + .pullableConjuncts(pullableConjuncts) + .build()); + } + + @Override + public Optional visitTopNRanking(TopNRankingNode node, Void context) + { + PlanNode source = node.getSource(); + + if (!node.isPartial() || node.getHashSymbol().isPresent() || node.getSpecification().orderingScheme().isEmpty()) { + canonicalizeRecursively(source); + return Optional.empty(); + } + + Optional subplanOptional = canonicalizeRecursively(source); + if (subplanOptional.isEmpty()) { + return 
Optional.empty(); + } + + CanonicalSubplan subplan = subplanOptional.get(); + BiMap originalSymbolMapping = subplan.getOriginalSymbolMapping(); + + // Sorting partition columns increases hit ratio and does not affect output rows + List partitionBy = node.getPartitionBy() + .stream().map(partitionKey -> originalSymbolMapping.inverse().get(partitionKey)) + .sorted(Comparator.comparing(CacheColumnId::toString)) + .collect(toImmutableList()); + + Optional> orderings = canonicalizeOrderingScheme(node.getOrderingScheme(), originalSymbolMapping); + return orderings.map(orderBy -> CanonicalSubplan.builderForChildSubplan( + new TopNRankingKey( + partitionBy, + ImmutableList.copyOf(orderBy.keySet()), + orderBy, + node.getRankingType(), + node.getMaxRankingPerPartition(), + ImmutableSet.copyOf(subplan.getPullableConjuncts())), + subplanOptional.get()) + .originalPlanNode(node) + .originalSymbolMapping(originalSymbolMapping) + .assignments(subplan.getAssignments()) + .pullableConjuncts(ImmutableSet.of()) + .build()); + } + + @Override + public Optional visitTopN(TopNNode node, Void context) + { + PlanNode source = node.getSource(); + + if (node.getStep() != TopNNode.Step.PARTIAL) { + canonicalizeRecursively(source); + return Optional.empty(); + } + + Optional subplanOptional = canonicalizeRecursively(source); + if (subplanOptional.isEmpty()) { + return Optional.empty(); + } + + CanonicalSubplan subplan = subplanOptional.get(); + BiMap originalSymbolMapping = subplan.getOriginalSymbolMapping(); + Optional> orderings = canonicalizeOrderingScheme(node.getOrderingScheme(), originalSymbolMapping); + + return orderings.map(orderBy -> CanonicalSubplan.builderForChildSubplan( + new TopNKey( + ImmutableList.copyOf(orderBy.keySet()), + orderBy, + node.getCount(), + ImmutableSet.copyOf(subplan.getPullableConjuncts())), + subplanOptional.get()) + .originalPlanNode(node) + .originalSymbolMapping(originalSymbolMapping) + .assignments(subplan.getAssignments()) + 
.pullableConjuncts(ImmutableSet.of()) + .build()); + } + + @Override + public Optional visitProject(ProjectNode node, Void context) + { + PlanNode source = node.getSource(); + + if (containsLambdaExpression(node)) { + // lambda expressions are not supported + canonicalizeRecursively(source); + return Optional.empty(); + } + + if (!node.getAssignments().getExpressions().stream().allMatch(DeterminismEvaluator::isDeterministic)) { + canonicalizeRecursively(source); + return Optional.empty(); + } + + Optional subplanOptional; + boolean extendSubplan; + if (source instanceof FilterNode || source instanceof TableScanNode) { + // subplans consisting of scan <- filter <- project can be represented as one CanonicalSubplan object + subplanOptional = source.accept(this, null); + extendSubplan = true; + } + else { + subplanOptional = canonicalizeRecursively(source); + extendSubplan = false; + } + + if (subplanOptional.isEmpty()) { + return Optional.empty(); + } + + CanonicalSubplan subplan = subplanOptional.get(); + SymbolMapper canonicalSymbolMapper = subplan.canonicalSymbolMapper(); + // canonicalize projection assignments + Map assignments = new LinkedHashMap<>(); + ImmutableBiMap.Builder symbolMappingBuilder = ImmutableBiMap.builder() + .putAll(subplan.getOriginalSymbolMapping()); + for (Symbol symbol : node.getOutputSymbols()) { + // use formatted canonical expression as column id for non-identity projections + Expression canonicalExpression = canonicalSymbolMapper.map(node.getAssignments().get(symbol)); + CacheColumnId columnId = canonicalExpressionToColumnId(canonicalExpression); + if (assignments.put(columnId, CacheExpression.ofProjection(canonicalExpression)) != null) { + // duplicated column ids are not supported + if (extendSubplan) { + canonicalSubplans.add(subplan); + } + return Optional.empty(); + } + // columnId -> symbol could be "identity" and already added by table scan canonicalization + Symbol originalSymbol = 
subplan.getOriginalSymbolMapping().get(columnId); + if (originalSymbol == null) { + symbolMappingBuilder.put(columnId, symbol); + } + else if (!originalSymbol.equals(symbol)) { + // aliasing of column id to multiple symbols is not supported + if (extendSubplan) { + canonicalSubplans.add(subplan); + } + return Optional.empty(); + } + } + + // Unsafe expressions that could throw an error should be evaluated only for the rows from + // original subquery. Therefore, common subplan predicate must match original subplan predicate. + // If common subplan predicate is wider, then unsafe expressions could fail even though + // evaluation of the original subplan would be successful. + boolean safeProjections = node.getAssignments().getExpressions().stream().noneMatch(expression -> mayFail(plannerContext, expression)); + Set requiredConjuncts = !safeProjections ? subplan.getPullableConjuncts() : ImmutableSet.of(); + + CanonicalSubplanBuilder builder = extendSubplan ? + CanonicalSubplan.builderExtending(setRequiredConjuncts(subplan.getKey(), requiredConjuncts), subplan) : + CanonicalSubplan.builderForChildSubplan(new FilterProjectKey(requiredConjuncts), subplan); + return Optional.of(builder + .originalPlanNode(node) + .originalSymbolMapping(symbolMappingBuilder.buildOrThrow()) + .assignments(assignments) + // all symbols (and thus conjuncts) are pullable through projection + .pullableConjuncts(subplan.getPullableConjuncts()) + .build()); + } + + private Key setRequiredConjuncts(Key key, Set requiredConjuncts) + { + switch (key) { + case ScanFilterProjectKey scanFilterProjectKey -> { + checkArgument(scanFilterProjectKey.requiredConjuncts().isEmpty()); + return new ScanFilterProjectKey(scanFilterProjectKey.tableId(), requiredConjuncts); + } + case FilterProjectKey filterProjectKey -> { + checkArgument(filterProjectKey.requiredConjuncts().isEmpty()); + return new FilterProjectKey(requiredConjuncts); + } + default -> throw new IllegalStateException("Unsupported key type: " + 
key); + } + } + + @Override + public Optional visitFilter(FilterNode node, Void context) + { + PlanNode source = node.getSource(); + + if (containsLambdaExpression(node)) { + // lambda expressions are not supported + canonicalizeFilterSource(node); + return Optional.empty(); + } + + if (!isDeterministic(node.getPredicate())) { + canonicalizeFilterSource(node); + return Optional.empty(); + } + + Optional subplanOptional; + boolean extendSubplan; + if (source instanceof TableScanNode) { + // subplans consisting of scan <- filter <- project can be represented as one CanonicalSubplan object + subplanOptional = source.accept(this, null); + extendSubplan = true; + } + else { + subplanOptional = canonicalizeRecursively(source); + extendSubplan = false; + } + + if (subplanOptional.isEmpty()) { + return Optional.empty(); + } + + CanonicalSubplan subplan = subplanOptional.get(); + + // extract dynamic and static conjuncts + SymbolMapper canonicalSymbolMapper = subplan.canonicalSymbolMapper(); + ImmutableList.Builder conjuncts = ImmutableList.builder(); + ImmutableList.Builder dynamicConjuncts = ImmutableList.builder(); + for (Expression expression : extractConjuncts(node.getPredicate())) { + if (isDynamicFilter(expression)) { + dynamicConjuncts.add(canonicalSymbolMapper.map(expression)); + } + else { + conjuncts.add(canonicalSymbolMapper.map(expression)); + } + } + + CanonicalSubplanBuilder builder = extendSubplan ? 
+ CanonicalSubplan.builderExtending(subplan) : + CanonicalSubplan.builderForChildSubplan(new FilterProjectKey(ImmutableSet.of()), subplan); + return Optional.of(builder + .originalPlanNode(node) + .originalSymbolMapping(subplan.getOriginalSymbolMapping()) + // assignments from subplan are preserved through filtering + .assignments(subplan.getAssignments()) + .conjuncts(conjuncts.build()) + .dynamicConjuncts(dynamicConjuncts.build()) + // all symbols (and thus conjuncts) are projected through filter node + .pullableConjuncts(ImmutableSet.builder() + .addAll(subplan.getPullableConjuncts()) + .addAll(conjuncts.build()) + .build()) + .build()); + } + + private void canonicalizeFilterSource(FilterNode node) + { + if (!containsDynamicFilter(node.getPredicate())) { + // Filter source must be a table scan if filter predicate contains dynamic filter. + // Hence, such filter nodes cannot be canonicalized separately from the table scan. + // Otherwise, there is a possibility that "load from cache" alternative is created below the filter node. 
+ canonicalizeRecursively(node.getSource()); + } + } + + private boolean containsDynamicFilter(Expression expression) + { + return !extractDynamicFilters(expression).getDynamicConjuncts().isEmpty(); + } + + private boolean containsLambdaExpression(PlanNode node) + { + return extractExpressions(node).stream().anyMatch(this::containsLambdaExpression); + } + + private boolean containsLambdaExpression(Expression expression) + { + return stream(Traverser.forTree(Expression::children).depthFirstPreOrder(expression)) + .anyMatch(instanceOf(Lambda.class)); + } + + @Override + public Optional visitTableScan(TableScanNode node, Void context) + { + if (node.isUpdateTarget()) { + // inserts are not supported + return Optional.empty(); + } + + if (node.isUseConnectorNodePartitioning()) { + // TODO: add support for node partitioning + return Optional.empty(); + } + + CacheMetadata cacheMetadata = plannerContext.getCacheMetadata(); + TableHandle canonicalTableHandle = cacheMetadata.getCanonicalTableHandle(session, node.getTable()); + Optional tableId = cacheMetadata.getCacheTableId(session, canonicalTableHandle) + // prepend catalog id + .map(id -> new CacheTableId(node.getTable().catalogHandle().getId() + ":" + id)); + if (tableId.isEmpty()) { + return Optional.empty(); + } + + // canonicalize output symbols using column ids + ImmutableBiMap.Builder symbolMappingBuilder = ImmutableBiMap.builder(); + Map columnHandles = new LinkedHashMap<>(); + for (Symbol outputSymbol : node.getOutputSymbols()) { + ColumnHandle columnHandle = node.getAssignments().get(outputSymbol); + Optional columnId = cacheMetadata.getCacheColumnId(session, node.getTable(), columnHandle) + // Make connector ids always wrapped in '[]' so they are distinguishable from + // CacheColumnIds derived from complex expressions. 
+ .map(id -> new CacheColumnId("[" + id + "]")); + if (columnId.isEmpty()) { + return Optional.empty(); + } + symbolMappingBuilder.put(columnId.get(), outputSymbol); + if (columnHandles.put(columnId.get(), columnHandle) != null) { + // duplicated column handles are not supported + return Optional.empty(); + } + } + BiMap symbolMapping = symbolMappingBuilder.build(); + + // pass-through canonical output symbols + Map assignments = columnHandles.keySet().stream().collect(toImmutableMap( + identity(), + id -> CacheExpression.ofProjection(columnIdToSymbol(id, symbolMapping.get(id).type()).toSymbolReference()))); + + return Optional.of(CanonicalSubplan.builderForTableScan( + new ScanFilterProjectKey(tableId.get(), ImmutableSet.of()), + columnHandles, + canonicalTableHandle, + tableId.get(), + canonicalizeEnforcedConstraint(node), + node.isUseConnectorNodePartitioning(), + node.getId()) + .originalPlanNode(node) + .originalSymbolMapping(symbolMapping) + .assignments(assignments) + .pullableConjuncts(ImmutableSet.of()) + .build()); + } + + private TupleDomain canonicalizeEnforcedConstraint(TableScanNode node) + { + // table predicate might contain all pushed down predicates to connector whereas TableScanNode#enforcedConstraints are pruned by visibility as output + TupleDomain tablePredicate = plannerContext.getMetadata().getTableProperties(session, node.getTable()).getPredicate(); + if (tablePredicate.isNone()) { + return TupleDomain.none(); + } + if (tablePredicate.isAll()) { + return TupleDomain.all(); + } + + Map domains = tablePredicate.getDomains().get(); + HashMap result = new LinkedHashMap<>(domains.size()); + for (Map.Entry entry : domains.entrySet()) { + Optional columnId = plannerContext.getCacheMetadata().getCacheColumnId(session, node.getTable(), entry.getKey()); + if (columnId.isEmpty()) { + return TupleDomain.all(); + } + + Domain domain = entry.getValue(); + checkState(result.put(columnId.get(), domain) == null || result.get(columnId.get()).equals(domain), 
+ format("Columns with same ids should have same domains: %s maps to %s and %s", entry.getKey(), entry.getValue(), domain)); + } + return TupleDomain.withColumnDomains(result); + } + + private Optional> canonicalizeOrderingScheme(OrderingScheme orderingScheme, BiMap originalSymbolMapping) + { + Map orderings = new LinkedHashMap<>(); + for (Symbol orderKey : orderingScheme.orderBy()) { + CacheColumnId columnId = requireNonNull(originalSymbolMapping.inverse().get(orderKey)); + if (orderings.put(columnId, orderingScheme.ordering(orderKey)) != null) { + // duplicated column ids are not supported + return Optional.empty(); + } + } + return Optional.of(ImmutableMap.copyOf(orderings)); + } + + private Optional canonicalizeRecursively(PlanNode node) + { + Optional subplan = node.accept(this, null); + subplan.ifPresent(canonicalSubplans::add); + return subplan; + } + } +} diff --git a/core/trino-main/src/main/java/io/trino/cache/CommonPlanAdaptation.java b/core/trino-main/src/main/java/io/trino/cache/CommonPlanAdaptation.java new file mode 100644 index 000000000000..43654a8ea905 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/cache/CommonPlanAdaptation.java @@ -0,0 +1,205 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.cache; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.trino.spi.cache.CacheColumnId; +import io.trino.spi.cache.PlanSignature; +import io.trino.spi.connector.ColumnHandle; +import io.trino.spi.predicate.TupleDomain; +import io.trino.sql.ir.Expression; +import io.trino.sql.planner.PlanNodeIdAllocator; +import io.trino.sql.planner.Symbol; +import io.trino.sql.planner.plan.Assignments; +import io.trino.sql.planner.plan.ChooseAlternativeNode.FilteredTableScan; +import io.trino.sql.planner.plan.FilterNode; +import io.trino.sql.planner.plan.PlanNode; +import io.trino.sql.planner.plan.ProjectNode; + +import java.util.List; +import java.util.Map; +import java.util.Optional; + +import static com.google.common.base.Preconditions.checkArgument; +import static java.util.Objects.requireNonNull; + +/** + * This class provides a common subplan (shared between different subplans in a query) and a way + * to adapt it to original plan. + */ +public class CommonPlanAdaptation +{ + /** + * Common subplan (shared between different subplans in a query) + */ + private final PlanNode commonSubplan; + /** + * Signature of common subplan. + */ + private final PlanSignatureWithPredicate commonSubplanSignature; + /** + * Common subplan {@link FilteredTableScan}. + */ + private final FilteredTableScan commonSubplanFilteredTableScan; + /** + * Dynamic filter disjuncts from all common subplans. + */ + private final Expression commonDynamicFilterDisjuncts; + /** + * Mapping from {@link CacheColumnId} to {@link ColumnHandle}. + */ + private final Map commonColumnHandles; + /** + * Optional predicate that needs to be applied in order to adapt common subplan to + * original plan. 
+ */ + private final Optional adaptationPredicate; + /** + * Optional projections that need to applied in order to adapt common subplan + * to original plan. + */ + private final Optional adaptationAssignments; + /** + * Mapping between {@link CacheColumnId} and symbols. + */ + private final Map columnIdMapping; + /** + * Adaptation conjuncts with symbol names canonicalized as {@link CacheColumnId}. + */ + private final List canonicalAdaptationConjuncts; + + public CommonPlanAdaptation( + PlanNode commonSubplan, + PlanSignatureWithPredicate commonSubplanSignature, + CommonPlanAdaptation childAdaptation, + Optional adaptationPredicate, + Optional adaptationAssignments, + Map columnIdMapping, + List canonicalAdaptationConjuncts) + { + this( + commonSubplan, + commonSubplanSignature, + childAdaptation.getCommonSubplanFilteredTableScan(), + childAdaptation.getCommonDynamicFilterDisjuncts(), + childAdaptation.getCommonColumnHandles(), + adaptationPredicate, + adaptationAssignments, + columnIdMapping, + canonicalAdaptationConjuncts); + } + + public CommonPlanAdaptation( + PlanNode commonSubplan, + PlanSignatureWithPredicate commonSubplanSignature, + FilteredTableScan commonSubplanFilteredTableScan, + Expression commonDynamicFilterDisjuncts, + Map commonColumnHandles, + Optional adaptationPredicate, + Optional adaptationAssignments, + Map columnIdMapping, + List canonicalAdaptationConjuncts) + { + this.commonSubplan = requireNonNull(commonSubplan, "commonSubplan is null"); + this.commonSubplanSignature = requireNonNull(commonSubplanSignature, "commonSubplanSignature is null"); + this.commonSubplanFilteredTableScan = requireNonNull(commonSubplanFilteredTableScan, "commonSubplanFilteredTableScan is null"); + this.commonDynamicFilterDisjuncts = requireNonNull(commonDynamicFilterDisjuncts, "commonDynamicFilterDisjuncts is null"); + this.commonColumnHandles = requireNonNull(commonColumnHandles, "commonColumnHandles is null"); + this.adaptationPredicate = 
requireNonNull(adaptationPredicate, "adaptationPredicate is null"); + this.adaptationAssignments = requireNonNull(adaptationAssignments, "adaptationAssignments is null"); + this.columnIdMapping = ImmutableMap.copyOf(requireNonNull(columnIdMapping, "columnIdMapping is null")); + this.canonicalAdaptationConjuncts = ImmutableList.copyOf(requireNonNull(canonicalAdaptationConjuncts, "canonicalAdaptationConjuncts is null")); + } + + public PlanNode adaptCommonSubplan(PlanNode commonSubplan, PlanNodeIdAllocator idAllocator) + { + checkArgument(this.commonSubplan.getOutputSymbols().equals(commonSubplan.getOutputSymbols())); + PlanNode adaptedPlan = commonSubplan; + if (adaptationPredicate.isPresent()) { + adaptedPlan = new FilterNode( + idAllocator.getNextId(), + adaptedPlan, + adaptationPredicate.get()); + } + if (adaptationAssignments.isPresent()) { + adaptedPlan = new ProjectNode( + idAllocator.getNextId(), + adaptedPlan, + adaptationAssignments.get()); + } + return adaptedPlan; + } + + public PlanNode getCommonSubplan() + { + return commonSubplan; + } + + public PlanSignatureWithPredicate getCommonSubplanSignature() + { + return commonSubplanSignature; + } + + public FilteredTableScan getCommonSubplanFilteredTableScan() + { + return commonSubplanFilteredTableScan; + } + + public Expression getCommonDynamicFilterDisjuncts() + { + return commonDynamicFilterDisjuncts; + } + + public Map getCommonColumnHandles() + { + return commonColumnHandles; + } + + public Map getColumnIdMapping() + { + return columnIdMapping; + } + + public List getCanonicalAdaptationConjuncts() + { + return canonicalAdaptationConjuncts; + } + + public record PlanSignatureWithPredicate(PlanSignature signature, TupleDomain predicate) + { + @JsonCreator + public PlanSignatureWithPredicate(PlanSignature signature, TupleDomain predicate) + { + this.signature = requireNonNull(signature, "signature is null"); + this.predicate = requireNonNull(predicate, "predicate is null"); + } + + @JsonProperty + 
@Override + public PlanSignature signature() + { + return signature; + } + + @JsonProperty + @Override + public TupleDomain predicate() + { + return predicate; + } + } +} diff --git a/core/trino-main/src/main/java/io/trino/cache/CommonSubqueriesExtractor.java b/core/trino-main/src/main/java/io/trino/cache/CommonSubqueriesExtractor.java new file mode 100644 index 000000000000..92ee3f566f9d --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/cache/CommonSubqueriesExtractor.java @@ -0,0 +1,941 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.cache; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Ordering; +import com.google.common.collect.Sets; +import com.google.common.collect.Streams; +import io.trino.Session; +import io.trino.cache.CacheController.CacheCandidate; +import io.trino.cache.CanonicalSubplan.AggregationKey; +import io.trino.cache.CanonicalSubplan.FilterProjectKey; +import io.trino.cache.CanonicalSubplan.Key; +import io.trino.cache.CanonicalSubplan.ScanFilterProjectKey; +import io.trino.cache.CanonicalSubplan.TableScan; +import io.trino.cache.CanonicalSubplan.TopNKey; +import io.trino.cache.CanonicalSubplan.TopNRankingKey; +import io.trino.cache.CommonPlanAdaptation.PlanSignatureWithPredicate; +import io.trino.cost.PlanNodeStatsEstimate; +import io.trino.metadata.ResolvedFunction; +import io.trino.spi.cache.CacheColumnId; +import io.trino.spi.cache.CacheManager; +import io.trino.spi.cache.CacheTableId; +import io.trino.spi.cache.PlanSignature; +import io.trino.spi.cache.SignatureKey; +import io.trino.spi.connector.ColumnHandle; +import io.trino.spi.connector.SortOrder; +import io.trino.spi.predicate.TupleDomain; +import io.trino.spi.type.Type; +import io.trino.sql.PlannerContext; +import io.trino.sql.ir.Expression; +import io.trino.sql.ir.IrUtils; +import io.trino.sql.ir.Reference; +import io.trino.sql.planner.DomainTranslator; +import io.trino.sql.planner.DomainTranslator.ExtractionResult; +import io.trino.sql.planner.PlanNodeIdAllocator; +import io.trino.sql.planner.Symbol; +import io.trino.sql.planner.SymbolAllocator; +import io.trino.sql.planner.SymbolsExtractor; +import io.trino.sql.planner.optimizations.SymbolMapper; +import io.trino.sql.planner.plan.AggregationNode; +import io.trino.sql.planner.plan.AggregationNode.Aggregation; +import io.trino.sql.planner.plan.Assignments; 
+import io.trino.sql.planner.plan.ChooseAlternativeNode.FilteredTableScan; +import io.trino.sql.planner.plan.DataOrganizationSpecification; +import io.trino.sql.planner.plan.FilterNode; +import io.trino.sql.planner.plan.PlanNode; +import io.trino.sql.planner.plan.PlanNodeId; +import io.trino.sql.planner.plan.ProjectNode; +import io.trino.sql.planner.plan.TableScanNode; +import io.trino.sql.planner.plan.TopNNode; +import io.trino.sql.planner.plan.TopNRankingNode; +import io.trino.sql.planner.plan.ValuesNode; + +import java.util.Collection; +import java.util.HashSet; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.stream.Stream; + +import static com.google.common.base.MoreObjects.toStringHelper; +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.ImmutableMap.toImmutableMap; +import static com.google.common.collect.Streams.forEachPair; +import static com.google.common.collect.Streams.zip; +import static io.trino.cache.CanonicalSubplanExtractor.canonicalExpressionToColumnId; +import static io.trino.cache.CanonicalSubplanExtractor.canonicalSymbolToColumnId; +import static io.trino.cache.CanonicalSubplanExtractor.columnIdToSymbol; +import static io.trino.cache.CanonicalSubplanExtractor.extractCanonicalSubplans; +import static io.trino.sql.ir.Booleans.TRUE; +import static io.trino.sql.ir.ExpressionFormatter.formatExpression; +import static io.trino.sql.ir.IrUtils.and; +import static io.trino.sql.ir.IrUtils.combineConjuncts; +import static io.trino.sql.ir.IrUtils.combineDisjuncts; +import static io.trino.sql.ir.IrUtils.extractConjuncts; +import static io.trino.sql.ir.IrUtils.or; +import static 
io.trino.sql.planner.iterative.rule.ExtractCommonPredicatesExpressionRewriter.extractCommonPredicates; +import static io.trino.sql.planner.iterative.rule.NormalizeOrExpressionRewriter.normalizeOrExpression; +import static io.trino.sql.planner.iterative.rule.PushPredicateIntoTableScan.pushFilterIntoTableScan; +import static io.trino.sql.planner.plan.AggregationNode.Step.PARTIAL; +import static io.trino.sql.planner.plan.AggregationNode.singleGroupingSet; +import static java.lang.String.format; +import static java.util.Comparator.comparing; +import static java.util.Objects.requireNonNull; +import static java.util.function.Predicate.not; + +/** + * Identifies common subqueries and provides adaptation to original query plan. Result of common + * subquery evaluation is cached with {@link CacheManager}. Therefore, IO and computations are + * performed only once and are reused within query execution. + *

+ * The general idea is that if there are two subqueries, e.g: + * {@code subquery1: table_scan(table) <- filter(col1 = 1) <- projection(y := col2 + 1)} + * {@code subquery2: table_scan(table) <- filter(col1 = 2) <- projection(z := col2 * 2)} + *

+ * Then such subqueries can be transformed into: + * {@code subquery1: table_scan(table) <- filter(col1 = 1 OR col1 = 2) <- projection(y := col2 + 1, z := col2 * 2) + * <- filter(col1 = 1) <- projection(y := y)} + * {@code subquery2: table_scan(table) <- filter(col1 = 1 OR col1 = 2) <- projection(y := col2 + 1, z := col2 * 2) + * <- filter(col1 = 2) <- projection(z := z)} + *

+ * {@code where: table_scan(table) <- filter(col1 = 1 OR col1 = 2) <- projection(y := col2 + 1, z := col2 * 2)} + * is a common subquery for which the results can be cached and evaluated only once. + */ +public final class CommonSubqueriesExtractor +{ + private final CacheController cacheController; + private final PlannerContext plannerContext; + private final Session session; + private final PlanNodeIdAllocator idAllocator; + private final SymbolAllocator symbolAllocator; + private final PlanNode root; + private final DomainTranslator domainTranslator; + + public static Map extractCommonSubqueries( + CacheController cacheController, + PlannerContext plannerContext, + Session session, + PlanNodeIdAllocator idAllocator, + SymbolAllocator symbolAllocator, + PlanNode root) + { + return new CommonSubqueriesExtractor(cacheController, plannerContext, session, idAllocator, symbolAllocator, root) + .extractCommonSubqueries(); + } + + public CommonSubqueriesExtractor( + CacheController cacheController, + PlannerContext plannerContext, + Session session, + PlanNodeIdAllocator idAllocator, + SymbolAllocator symbolAllocator, + PlanNode root) + { + this.cacheController = requireNonNull(cacheController, "cacheController is null"); + this.plannerContext = requireNonNull(plannerContext, "plannerContext is null"); + this.session = requireNonNull(session, "session is null"); + this.idAllocator = requireNonNull(idAllocator, "idAllocator is null"); + this.symbolAllocator = requireNonNull(symbolAllocator, "symbolAllocator is null"); + this.root = requireNonNull(root, "root is null"); + this.domainTranslator = new DomainTranslator(plannerContext.getMetadata()); + } + + public Map extractCommonSubqueries() + { + ImmutableMap.Builder planAdaptations = ImmutableMap.builder(); + List cacheCandidates = cacheController.getCachingCandidates( + session, + extractCanonicalSubplans(plannerContext, session, root)); + + // extract common subplan adaptations + Set processedSubplans = new 
HashSet<>(); + for (CacheCandidate cacheCandidate : cacheCandidates) { + List subplans = cacheCandidate.subplans().stream() + // skip subqueries for which common subplan was already extracted + .filter(subplan -> !processedSubplans.contains(subplan.getTableScanId())) + .collect(toImmutableList()); + + if (subplans.size() < cacheCandidate.minSubplans()) { + // skip if not enough subplans + continue; + } + + subplans.forEach(subplan -> processedSubplans.add(subplan.getTableScanId())); + List adaptations = adaptSubplans(subplans); + checkState(adaptations.size() == subplans.size()); + forEachPair(subplans.stream(), adaptations.stream(), (subplan, adaptation) -> + planAdaptations.put(subplan.getOriginalPlanNode(), adaptation)); + } + return planAdaptations.buildOrThrow(); + } + + private List adaptSubplans(List subplans) + { + checkArgument(!subplans.isEmpty()); + checkArgument(subplans.stream().map(CanonicalSubplan::getKeyChain).distinct().count() == 1, "All subplans should have the same keychain"); + + Key key = subplans.get(0).getKey(); + if (key instanceof ScanFilterProjectKey) { + return adaptScanFilterProject(subplans); + } + else { + List childAdaptations = adaptSubplans(subplans.stream() + .map(subplan -> subplan.getChildSubplan().orElseThrow()) + .collect(toImmutableList())); + return adaptSubplans(key, childAdaptations, subplans); + } + } + + private List adaptSubplans(Key key, List childAdaptations, List subplans) + { + checkArgument(childAdaptations.size() == subplans.size()); + if (key instanceof FilterProjectKey) { + return adaptFilterProject(childAdaptations, subplans); + } + else if (key instanceof AggregationKey) { + return adaptAggregation(childAdaptations, subplans); + } + if (key instanceof TopNKey) { + return adaptTopN(childAdaptations, subplans); + } + if (key instanceof TopNRankingKey) { + return adaptTopNRanking(childAdaptations, subplans); + } + else { + throw new UnsupportedOperationException(format("Unsupported key: %s", key)); + } + } + + 
private List adaptScanFilterProject(List subplans) + { + checkArgument(subplans.stream().allMatch(subplan -> subplan.getGroupByColumns().isEmpty()), "Group by columns are not allowed"); + checkArgument(subplans.stream().map(subplan -> subplan.getTableScan().orElseThrow().getTableId()).distinct().count() == 1, "All subplans should have the same table id"); + CacheTableId tableId = subplans.get(0).getTableScan().orElseThrow().getTableId(); + + Expression commonPredicate = extractCommonPredicate(subplans); + Set intersectingConjuncts = extractIntersectingConjuncts(subplans); + Map commonProjections = extractCommonProjections(subplans, commonPredicate, intersectingConjuncts, Optional.empty(), Optional.empty()); + Map commonColumnHandles = extractCommonColumnHandles(subplans); + Map commonColumnIds = extractCommonColumnIds(subplans); + Expression commonDynamicFilterDisjuncts = extractCommonDynamicFilterDisjuncts(subplans); + PlanSignatureWithPredicate planSignature = computePlanSignature( + commonColumnIds, + tableId, + commonPredicate, + commonProjections.keySet().stream() + .collect(toImmutableList()), + commonColumnHandles.keySet()); + + return subplans.stream() + .map(subplan -> { + Map columnIdMapping = createSubplanColumnIdMapping(subplan, commonColumnIds, Optional.empty()); + SymbolMapper symbolMapper = createSymbolMapper(columnIdMapping); + SubplanFilter commonSubplanFilter = createSubplanFilter( + subplan, + commonPredicate, + createSubplanTableScan(subplan.getTableScan().orElseThrow(), commonColumnHandles, columnIdMapping), + symbolMapper); + PlanNode commonSubplan = createSubplanProjection(commonSubplanFilter.subplan(), commonProjections, columnIdMapping, symbolMapper); + List adaptationConjuncts = createAdaptationConjuncts(subplan, commonPredicate, intersectingConjuncts, Optional.empty()); + return new CommonPlanAdaptation( + commonSubplan, + planSignature, + new FilteredTableScan(commonSubplanFilter.tableScan(), commonSubplanFilter.predicate()), + 
symbolMapper.map(commonDynamicFilterDisjuncts), + commonColumnHandles, + createAdaptationPredicate(adaptationConjuncts, symbolMapper), + createAdaptationAssignments(commonSubplan, subplan, columnIdMapping), + columnIdMapping, + adaptationConjuncts); + }) + .collect(toImmutableList()); + } + + private List adaptFilterProject(List childAdaptations, List subplans) + { + checkArgument(subplans.stream().allMatch(subplan -> subplan.getGroupByColumns().isEmpty()), "Group by columns are not allowed"); + checkArgument(subplans.stream().allMatch(subplan -> subplan.getDynamicConjuncts().isEmpty()), "Dynamic filters are only allowed above table scan"); + + Expression commonPredicate = extractCommonPredicate(subplans); + Set intersectingConjuncts = extractIntersectingConjuncts(subplans); + Map commonProjections = extractCommonProjections(subplans, commonPredicate, intersectingConjuncts, Optional.empty(), Optional.of(childAdaptations)); + Map commonColumnIds = extractCommonColumnIds(subplans); + PlanSignatureWithPredicate planSignature = computePlanSignature( + commonColumnIds, + filterProjectKey(childAdaptations.get(0).getCommonSubplanSignature().signature().getKey()), + childAdaptations.get(0), + commonPredicate, + commonProjections.keySet().stream() + .collect(toImmutableList()), + Optional.empty()); + + return zip(subplans.stream(), childAdaptations.stream(), + (subplan, childAdaptation) -> { + Map columnIdMapping = createSubplanColumnIdMapping(subplan, commonColumnIds, Optional.of(childAdaptation)); + SymbolMapper symbolMapper = createSymbolMapper(columnIdMapping); + PlanNode commonSubplan = createSubplanProjection( + createSubplanFilter(commonPredicate, childAdaptation.getCommonSubplan(), symbolMapper), + commonProjections, + columnIdMapping, + symbolMapper); + List adaptationConjuncts = createAdaptationConjuncts(subplan, commonPredicate, intersectingConjuncts, Optional.of(childAdaptation)); + return new CommonPlanAdaptation( + commonSubplan, + planSignature, + 
childAdaptation, + createAdaptationPredicate(adaptationConjuncts, symbolMapper), + createAdaptationAssignments(commonSubplan, subplan, columnIdMapping), + columnIdMapping, + adaptationConjuncts); + }).collect(toImmutableList()); + } + + private List adaptTopNRanking(List childAdaptations, List subplans) + { + CanonicalSubplan canonicalSubplan = subplans.get(0); + TopNRankingKey topNRankingKey = (TopNRankingKey) canonicalSubplan.getKey(); + + Map commonColumnIds = extractCommonColumnIds(subplans); + Map commonProjections = extractCommonProjections(subplans, TRUE, ImmutableSet.of(), Optional.of(ImmutableSet.of()), Optional.of(childAdaptations)); + PlanSignatureWithPredicate planSignature = computePlanSignature( + commonColumnIds, + topNRankingKey(childAdaptations.get(0).getCommonSubplanSignature().signature().getKey(), topNRankingKey.partitionBy(), topNRankingKey.orderings(), topNRankingKey.rankingType(), topNRankingKey.maxRankingPerPartition()), + childAdaptations.get(0), + TRUE, + commonProjections.keySet().stream() + .collect(toImmutableList()), + Optional.empty()); + + return zip(subplans.stream(), childAdaptations.stream(), + (subplan, childAdaptation) -> { + Map columnIdMapping = createSubplanColumnIdMapping(subplan, commonColumnIds, Optional.of(childAdaptation)); + TopNRankingNode originalPlanNode = (TopNRankingNode) subplan.getOriginalPlanNode(); + TopNRankingKey key = (TopNRankingKey) subplan.getKey(); + // recreate specification that matches common partition column order + DataOrganizationSpecification specification = new DataOrganizationSpecification( + key.partitionBy().stream() + .map(columnIdMapping::get) + .collect(toImmutableList()), + Optional.of(originalPlanNode.getOrderingScheme())); + PlanNode commonSubplan = new TopNRankingNode( + idAllocator.getNextId(), + childAdaptation.getCommonSubplan(), + specification, + originalPlanNode.getRankingType(), + originalPlanNode.getRankingSymbol(), + originalPlanNode.getMaxRankingPerPartition(), + true, + 
Optional.empty()); + List adaptationConjuncts = createAdaptationConjuncts(subplan, TRUE, ImmutableSet.of(), Optional.of(childAdaptation)); + return new CommonPlanAdaptation( + commonSubplan, + planSignature, + childAdaptation, + createAdaptationPredicate(adaptationConjuncts, createSymbolMapper(columnIdMapping)), + createAdaptationAssignments(commonSubplan, subplan, columnIdMapping), + columnIdMapping, + adaptationConjuncts); + }).collect(toImmutableList()); + } + + private List adaptTopN(List childAdaptations, List subplans) + { + CanonicalSubplan canonicalSubplan = subplans.get(0); + TopNKey topNKey = (TopNKey) canonicalSubplan.getKey(); + + Map commonColumnIds = extractCommonColumnIds(subplans); + Map commonProjections = extractCommonProjections(subplans, TRUE, ImmutableSet.of(), Optional.of(ImmutableSet.of()), Optional.of(childAdaptations)); + PlanSignatureWithPredicate planSignature = computePlanSignature( + commonColumnIds, + topNKey(childAdaptations.get(0).getCommonSubplanSignature().signature().getKey(), topNKey.orderings(), topNKey.count()), + childAdaptations.get(0), + TRUE, + commonProjections.keySet().stream() + .collect(toImmutableList()), + Optional.empty()); + return zip(subplans.stream(), childAdaptations.stream(), + (subplan, childAdaptation) -> { + Map columnIdMapping = createSubplanColumnIdMapping(subplan, commonColumnIds, Optional.of(childAdaptation)); + TopNNode originalPlanNode = (TopNNode) subplan.getOriginalPlanNode(); + PlanNode commonSubplan = new TopNNode( + idAllocator.getNextId(), + childAdaptation.getCommonSubplan(), + originalPlanNode.getCount(), + originalPlanNode.getOrderingScheme(), + TopNNode.Step.PARTIAL); + List adaptationConjuncts = createAdaptationConjuncts(subplan, TRUE, ImmutableSet.of(), Optional.of(childAdaptation)); + return new CommonPlanAdaptation( + commonSubplan, + planSignature, + childAdaptation, + createAdaptationPredicate(adaptationConjuncts, createSymbolMapper(columnIdMapping)), + 
createAdaptationAssignments(commonSubplan, subplan, columnIdMapping), + columnIdMapping, + adaptationConjuncts); + }).collect(toImmutableList()); + } + + private List adaptAggregation(List childAdaptations, List subplans) + { + checkArgument(subplans.stream().allMatch(subplan -> subplan.getConjuncts().isEmpty()), "Conjuncts are not allowed in aggregation canonical subplan"); + checkArgument(subplans.get(0).getGroupByColumns().isPresent(), "Group by columns are not present"); + checkArgument(subplans.stream().map(CanonicalSubplan::getGroupByColumns).distinct().count() == 1, "Group by columns must be the same for all subplans"); + + Set groupByColumns = subplans.get(0).getGroupByColumns().orElseThrow(); + Map commonProjections = extractCommonProjections(subplans, TRUE, ImmutableSet.of(), Optional.of(groupByColumns), Optional.of(childAdaptations)); + Map commonColumnIds = extractCommonColumnIds(subplans); + PlanSignatureWithPredicate planSignature = computePlanSignature( + commonColumnIds, + aggregationKey(childAdaptations.get(0).getCommonSubplanSignature().signature().getKey()), + childAdaptations.get(0), + TRUE, + commonProjections.keySet().stream() + .collect(toImmutableList()), + Optional.of(groupByColumns)); + + return zip(subplans.stream(), childAdaptations.stream(), + (subplan, childAdaptation) -> { + Map columnIdMapping = createSubplanColumnIdMapping(subplan, commonColumnIds, Optional.of(childAdaptation)); + PlanNode commonSubplan = createSubplanAggregation( + childAdaptation.getCommonSubplan(), + commonProjections, + groupByColumns, + columnIdMapping); + List adaptationConjuncts = createAdaptationConjuncts(subplan, TRUE, ImmutableSet.of(), Optional.of(childAdaptation)); + return new CommonPlanAdaptation( + commonSubplan, + planSignature, + childAdaptation, + createAdaptationPredicate(adaptationConjuncts, createSymbolMapper(columnIdMapping)), + createAdaptationAssignments(commonSubplan, subplan, columnIdMapping), + columnIdMapping, + adaptationConjuncts); + 
}).collect(toImmutableList()); + } + + private Expression extractCommonPredicate(List subplans) + { + // When two similar subqueries have different predicates, e.g: subquery1: col = 1, subquery2: col = 2 + // then common subquery must have predicate "col = 1 OR col = 2". Narrowing adaptation predicate is then + // created for each subquery on top of common subquery. + return normalizeOrExpression( + extractCommonPredicates( + or(subplans.stream() + .map(subplan -> and( + subplan.getConjuncts())) + .collect(toImmutableList())))); + } + + private static Set extractIntersectingConjuncts(List subplans) + { + return subplans.stream() + .map(subplan -> (Set) ImmutableSet.copyOf(subplan.getConjuncts())) + .reduce(Sets::intersection) + .map(ImmutableSet::copyOf) + .orElse(ImmutableSet.of()); + } + + private static Map extractCommonProjections( + List subplans, + Expression commonPredicate, + Set intersectingConjuncts, + Optional> pullupColumns, + Optional> childAdaptations) + { + // Extract common projections. Common (cached) subquery must contain projections from all subqueries. + // Pruning adaptation projection is then created for each subquery on top of common subplan. + Map commonProjections = subplans.stream() + .flatMap(subplan -> subplan.getAssignments().entrySet().stream()) + .distinct() + .collect(toImmutableMap(Map.Entry::getKey, Map.Entry::getValue, CommonSubqueriesExtractor::getSmallerExpression)); + // Common subquery must propagate all symbols used in adaptation predicates. 
+ Map propagatedSymbols = childAdaptations + // Append adaptation conjuncts from child adaptations (if present) + .map(adaptations -> zip(subplans.stream(), adaptations.stream(), CommonSubqueriesExtractor::appendAdaptationConjuncts)) + .orElse(subplans.stream().map(CanonicalSubplan::getConjuncts)) + .filter(conjuncts -> isAdaptationPredicateNeeded(conjuncts, commonPredicate)) + .flatMap(Collection::stream) + // Use only conjuncts that are not enforced by intersecting predicate + .filter(conjunct -> !intersectingConjuncts.contains(conjunct)) + .map(SymbolsExtractor::extractAll) + .flatMap(Collection::stream) + .distinct() + .collect(toImmutableMap(CanonicalSubplanExtractor::canonicalSymbolToColumnId, symbol -> CacheExpression.ofProjection(symbol.toSymbolReference()))); + pullupColumns.ifPresent(columns -> checkState(columns.containsAll(propagatedSymbols.keySet()), "pullup columns don't contain all propagated symbols")); + return Streams.concat(commonProjections.entrySet().stream(), propagatedSymbols.entrySet().stream()) + .distinct() + .collect(toImmutableMap(Map.Entry::getKey, Map.Entry::getValue, CommonSubqueriesExtractor::getSmallerExpression)); + } + + private static CacheExpression getSmallerExpression(CacheExpression first, CacheExpression second) + { + checkArgument(first.projection().isPresent() == second.projection().isPresent(), "One of the projection expressions is missing. first expression: %s, second expression: %s", first.projection(), second.projection()); + if (first.aggregation().isPresent()) { + // Supported aggregations can only have canonical symbol arguments, hence there should be no evaluation ambiguity. + checkArgument(first.aggregation().equals(second.aggregation()), "Aggregation expressions are not the same. first expression: %s, second expression: %s", first.aggregation(), second.aggregation()); + return first; + } + // Prefer smaller expression trees. 
+ if (IrUtils.preOrder(first.projection().get()).count() <= IrUtils.preOrder(second.projection().get()).count()) { + return first; + } + return second; + } + + private static Map extractCommonColumnHandles(List subplans) + { + // Common subquery must select column handles from all subqueries. + return subplans.stream() + .flatMap(subplan -> subplan.getTableScan().orElseThrow().getColumnHandles().entrySet().stream()) + .distinct() + .collect(toImmutableMap(Map.Entry::getKey, Map.Entry::getValue)); + } + + private static Map extractCommonColumnIds(List subplans) + { + Map commonColumnIds = new LinkedHashMap<>(); + subplans.stream() + .flatMap(subplan -> subplan.getOriginalSymbolMapping().entrySet().stream()) + .forEach(entry -> commonColumnIds.putIfAbsent(entry.getKey(), entry.getValue())); + return commonColumnIds; + } + + private Expression extractCommonDynamicFilterDisjuncts(List subplans) + { + return combineDisjuncts( + subplans.stream() + .map(subplan -> combineConjuncts(subplan.getDynamicConjuncts())) + .collect(toImmutableList())); + } + + private PlanSignatureWithPredicate computePlanSignature( + Map commonColumnIds, + CacheTableId tableId, + Expression predicate, + List projections, + Set scanColumnIds) + { + return computePlanSignature( + commonColumnIds, + scanFilterProjectKey(tableId), + TupleDomain.all(), + predicate, + projections, + Optional.empty(), + scanColumnIds); + } + + private PlanSignatureWithPredicate computePlanSignature( + Map commonColumnIds, + SignatureKey signatureKey, + CommonPlanAdaptation childAdaptation, + Expression predicate, + List projections, + Optional> groupByColumns) + { + return computePlanSignature( + commonColumnIds, + // Append child group by columns to signature key (if present) + combine(signatureKey, childAdaptation.getCommonSubplanSignature().signature().getGroupByColumns().map(columns -> "groupByColumns=" + columns).orElse("")), + childAdaptation.getCommonSubplanSignature().predicate(), + predicate, + projections, + 
groupByColumns, + childAdaptation.getCommonColumnHandles().keySet()); + } + + private PlanSignatureWithPredicate computePlanSignature( + Map commonColumnIds, + SignatureKey signatureKey, + TupleDomain tupleDomain, + Expression predicate, + List projections, + Optional> groupByColumns, + Set scanColumnIds) + { + Set projectionSet = ImmutableSet.copyOf(projections); + checkArgument(groupByColumns.isEmpty() || projectionSet.containsAll(groupByColumns.get())); + + // Order group by columns by name + Optional> orderedGroupByColumns = groupByColumns.map( + Ordering.from(comparing(CacheColumnId::toString))::immutableSortedCopy); + List projectionColumnsTypes = projections.stream() + .map(cacheColumnId -> commonColumnIds.get(cacheColumnId).type()) + .collect(toImmutableList()); + if (tupleDomain.isAll() && predicate.equals(TRUE)) { + return new PlanSignatureWithPredicate( + new PlanSignature( + signatureKey, + orderedGroupByColumns, + projections, + projectionColumnsTypes), + TupleDomain.all()); + } + + ExtractionResult extractionResult = DomainTranslator.getExtractionResult( + plannerContext, + session, + predicate); + // Only domains for projected columns can be part of signature predicate + TupleDomain extractedTupleDomain = extractionResult.getTupleDomain() + .transformKeys(CanonicalSubplanExtractor::canonicalSymbolToColumnId) + .intersect(tupleDomain); + Set retainedColumnIds = ImmutableSet.builder() + // retain projected and table scan domains for per split simplification and pruning on worker node + .addAll(projectionSet) + .addAll(scanColumnIds) + .build(); + TupleDomain retainedTupleDomain = extractedTupleDomain + .filter((columnId, domain) -> retainedColumnIds.contains(columnId)); + // Remaining expression and non-projected domains must be part of signature key + TupleDomain remainingTupleDomain = extractedTupleDomain + .filter((columnId, domain) -> !retainedColumnIds.contains(columnId)); + if (!remainingTupleDomain.isAll() || 
!extractionResult.getRemainingExpression().equals(TRUE)) { + Expression remainingDomainExpression = domainTranslator.toPredicate( + remainingTupleDomain.transformKeys(id -> columnIdToSymbol(id, commonColumnIds.get(id).type()))); + signatureKey = combine( + signatureKey, + "filters=" + formatExpression(combineConjuncts( + // Order remaining expressions alphabetically to improve signature generalisation + Stream.of(remainingDomainExpression, extractionResult.getRemainingExpression()) + .map(IrUtils::extractConjuncts) + .flatMap(Collection::stream) + .sorted(comparing(Expression::toString)) + .collect(toImmutableList())))); + } + + return new PlanSignatureWithPredicate( + new PlanSignature( + signatureKey, + orderedGroupByColumns, + projections, + projectionColumnsTypes), + retainedTupleDomain); + } + + private Map createSubplanColumnIdMapping(CanonicalSubplan subplan, Map commonColumnIds, Optional childAdaptation) + { + // Propagate column id<->symbol mappings from child adaptation + Map columnIdMapping = new LinkedHashMap<>(childAdaptation + .map(CommonPlanAdaptation::getColumnIdMapping) + .orElse(ImmutableMap.of())); + // Propagate original symbol names only if they don't override symbols from child adaptation since + // child adaptation might use different symbols for column ids than original subplan. 
+ subplan.getOriginalSymbolMapping().forEach(columnIdMapping::putIfAbsent); + // Create new symbols for column ids that were not used in original subplan, but are part of common subquery now + commonColumnIds + .forEach((key, value) -> columnIdMapping.computeIfAbsent(key, ignored -> symbolAllocator.newSymbol(value))); + return ImmutableMap.copyOf(columnIdMapping); + } + + private SymbolMapper createSymbolMapper(Map columnIdMapping) + { + return new SymbolMapper(symbol -> requireNonNull(columnIdMapping.get(canonicalSymbolToColumnId(symbol)))); + } + + private PlanNode createSubplanAggregation( + PlanNode subplan, + Map projections, + Set groupByColumns, + Map columnIdMapping) + { + ImmutableList.Builder groupByColumnSymbols = ImmutableList.builder(); + ImmutableMap.Builder aggregations = ImmutableMap.builder(); + + for (Map.Entry entry : projections.entrySet()) { + CacheColumnId id = entry.getKey(); + if (groupByColumns.contains(entry.getKey())) { + groupByColumnSymbols.add(columnIdMapping.get(id)); + } + else { + CanonicalAggregation aggregationCall = entry.getValue().aggregation().orElseThrow(); + + // Resolve filter expression in terms of subplan symbols + Optional mask = aggregationCall.mask() + .map(filter -> requireNonNull(columnIdMapping.get(canonicalExpressionToColumnId(filter.toSymbolReference())))); + + // Resolve arguments in terms of subplan symbols + ResolvedFunction resolvedFunction = aggregationCall.resolvedFunction(); + List arguments = aggregationCall.arguments().stream() + .peek(argument -> checkState(argument instanceof Reference)) + .map(argument -> columnIdMapping.get(canonicalExpressionToColumnId(argument)).toSymbolReference()) + .collect(toImmutableList()); + + // Re-create aggregation using subquery specific symbols + aggregations.put( + columnIdMapping.get(id), + new Aggregation( + resolvedFunction, + arguments, + false, + Optional.empty(), + Optional.empty(), + mask)); + } + } + + AggregationNode aggregation = new AggregationNode( + 
idAllocator.getNextId(), + subplan, + aggregations.buildOrThrow(), + singleGroupingSet(groupByColumnSymbols.build()), + ImmutableList.of(), + PARTIAL, + Optional.empty(), + Optional.empty()); + List expectedSymbols = projections.keySet().stream() + .map(columnIdMapping::get) + .collect(toImmutableList()); + checkState(aggregation.getOutputSymbols().equals(expectedSymbols), "Aggregation symbols (%s) don't match expected symbols (%s)", aggregation.getOutputSymbols(), expectedSymbols); + return aggregation; + } + + private TableScanNode createSubplanTableScan( + TableScan tableScan, + Map columnHandles, + Map columnIdMapping) + { + return new TableScanNode( + idAllocator.getNextId(), + // use original table handle as it contains information about + // split enumeration (e.g. enforced partition or bucket filter) for + // a given subquery + tableScan.getTable(), + // Remap column ids into specific subquery symbols + columnHandles.keySet().stream() + .map(columnIdMapping::get) + .collect(toImmutableList()), + columnHandles.entrySet().stream() + .collect(toImmutableMap(entry -> columnIdMapping.get(entry.getKey()), Map.Entry::getValue)), + // Enforced constraint is not important at this stage of planning + TupleDomain.all(), + // Stats are not important at this stage of planning + Optional.empty(), + false, + Optional.of(tableScan.isUseConnectorNodePartitioning())); + } + + private SubplanFilter createSubplanFilter( + CanonicalSubplan subplan, + Expression predicate, + TableScanNode tableScan, + SymbolMapper symbolMapper) + { + if (predicate.equals(TRUE) && subplan.getDynamicConjuncts().isEmpty()) { + return new SubplanFilter(tableScan, Optional.empty(), tableScan); + } + + Expression predicateWithDynamicFilters = + // Subquery specific dynamic filters need to be added back to subplan. + // Actual dynamic filter domains are accounted for in PlanSignature on worker nodes. 
+ symbolMapper.map(combineConjuncts( + predicate, + and(subplan.getDynamicConjuncts()))); + FilterNode filterNode = new FilterNode( + idAllocator.getNextId(), + tableScan, + predicateWithDynamicFilters); + + // Try to push down predicates to table scan + Optional rewritten = pushFilterIntoTableScan( + filterNode, + tableScan, + false, + session, + plannerContext, + node -> PlanNodeStatsEstimate.unknown()); + + // If ValuesNode was returned as a result of pushing down predicates we fall back + // to filterNode to avoid introducing significant changes in plan. Changing node from TableScan to ValuesNode + // potentially interfere with partitioning - note that this step is executed after planning. + rewritten = rewritten.filter(not(ValuesNode.class::isInstance)); + + if (rewritten.isPresent()) { + PlanNode node = rewritten.get(); + if (node instanceof FilterNode rewrittenFilterNode) { + checkState(rewrittenFilterNode.getSource() instanceof TableScanNode, "Expected filter source to be TableScanNode"); + return new SubplanFilter(node, Optional.of(rewrittenFilterNode.getPredicate()), (TableScanNode) rewrittenFilterNode.getSource()); + } + checkState(node instanceof TableScanNode, "Expected rewritten node to be TableScanNode"); + return new SubplanFilter(node, Optional.empty(), (TableScanNode) node); + } + + return new SubplanFilter(filterNode, Optional.of(predicateWithDynamicFilters), tableScan); + } + + private PlanNode createSubplanFilter(Expression predicate, PlanNode source, SymbolMapper symbolMapper) + { + if (predicate.equals(TRUE)) { + return source; + } + return new FilterNode(idAllocator.getNextId(), source, symbolMapper.map(predicate)); + } + + private record SubplanFilter(PlanNode subplan, Optional predicate, TableScanNode tableScan) {} + + private PlanNode createSubplanProjection( + PlanNode subplan, + Map projections, + Map columnIdMapping, + SymbolMapper symbolMapper) + { + return createSubplanAssignments( + subplan, + projections.entrySet().stream() + 
.collect(toImmutableMap( + entry -> columnIdMapping.get(entry.getKey()), + entry -> symbolMapper.map(entry.getValue().projection().orElseThrow())))) + .map(assignments -> (PlanNode) new ProjectNode(idAllocator.getNextId(), subplan, assignments)) + .orElse(subplan); + } + + private static Optional createAdaptationAssignments( + PlanNode subplan, + CanonicalSubplan canonicalSubplan, + Map columnIdMapping) + { + // Prune and order common subquery output in order to match original subquery. + Map projections = canonicalSubplan.getAssignments().keySet().stream() + .collect(toImmutableMap( + // Use original output symbols for adaptation projection. + id -> canonicalSubplan.getOriginalSymbolMapping().get(id), + id -> columnIdMapping.get(id).toSymbolReference())); + return createSubplanAssignments(subplan, projections); + } + + private static Optional createSubplanAssignments(PlanNode subplan, Map projections) + { + Assignments assignments = Assignments.copyOf(projections); + + // cache is sensitive to output symbols order + if (subplan.getOutputSymbols().equals(assignments.getOutputs())) { + return Optional.empty(); + } + + return Optional.of(assignments); + } + + private static List createAdaptationConjuncts( + CanonicalSubplan subplan, + Expression commonPredicate, + Set intersectingConjuncts, + Optional childAdaptation) + { + List conjuncts = childAdaptation + .map(adaptation -> appendAdaptationConjuncts(subplan, adaptation)) + .orElse(subplan.getConjuncts()); + + if (!isAdaptationPredicateNeeded(conjuncts, commonPredicate)) { + return ImmutableList.of(); + } + + return conjuncts.stream() + // Use only conjuncts that are not enforced by common predicate + .filter(conjunct -> !intersectingConjuncts.contains(conjunct)) + .collect(toImmutableList()); + } + + private static Optional createAdaptationPredicate(List conjuncts, SymbolMapper symbolMapper) + { + if (conjuncts.isEmpty()) { + return Optional.empty(); + } + + return Optional.of(symbolMapper.map(and(conjuncts))); + 
} + + private static List appendAdaptationConjuncts(CanonicalSubplan subplan, CommonPlanAdaptation childAdaptation) + { + return ImmutableList.builder() + .addAll(subplan.getConjuncts()) + .addAll(childAdaptation.getCanonicalAdaptationConjuncts()) + .build(); + } + + private static boolean isAdaptationPredicateNeeded(List conjuncts, Expression commonPredicate) + { + Set commonConjuncts = ImmutableSet.copyOf(extractConjuncts(commonPredicate)); + return !conjuncts.isEmpty() && !ImmutableSet.copyOf(conjuncts).equals(commonConjuncts); + } + + @VisibleForTesting + public static SignatureKey scanFilterProjectKey(CacheTableId tableId) + { + return new SignatureKey(toStringHelper("ScanFilterProject") + .add("tableId", tableId) + .toString()); + } + + @VisibleForTesting + public static SignatureKey filterProjectKey(SignatureKey childKey) + { + return new SignatureKey(toStringHelper("FilterProject") + .add("childKey", childKey) + .toString()); + } + + @VisibleForTesting + public static SignatureKey aggregationKey(SignatureKey childKey) + { + return new SignatureKey(toStringHelper("Aggregation") + .add("childKey", childKey) + .toString()); + } + + @VisibleForTesting + public static SignatureKey topNRankingKey( + SignatureKey childKey, + List partitionBy, + Map orderBy, + TopNRankingNode.RankingType rankingType, + int maxRankingPerPartition) + { + return new SignatureKey(toStringHelper("TopNRanking") + .add("childKey", childKey) + .add("partitionBy", partitionBy) + .add("orderBy", orderBy) + .add("rankingType", rankingType) + .add("maxRankingPerPartition", maxRankingPerPartition) + .toString()); + } + + @VisibleForTesting + public static SignatureKey topNKey(SignatureKey childKey, Map orderBy, long count) + { + return new SignatureKey(toStringHelper("TopN") + .add("childKey", childKey) + .add("orderBy", orderBy) + .add("count", count) + .toString()); + } + + @VisibleForTesting + public static SignatureKey combine(SignatureKey key, String tail) + { + if (tail.isEmpty()) { + 
return key; + } + + return new SignatureKey(key + ":" + tail); + } +} diff --git a/core/trino-main/src/main/java/io/trino/cache/ConnectorAwareAddressProvider.java b/core/trino-main/src/main/java/io/trino/cache/ConnectorAwareAddressProvider.java new file mode 100644 index 000000000000..37b92f07b691 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/cache/ConnectorAwareAddressProvider.java @@ -0,0 +1,46 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.cache; + +import com.google.common.cache.CacheBuilder; +import com.google.inject.Inject; +import io.airlift.node.NodeInfo; +import io.trino.connector.ConnectorAwareNodeManager; +import io.trino.metadata.InternalNodeManager; +import io.trino.spi.connector.CatalogHandle; + +import static io.trino.cache.CacheUtils.uncheckedCacheGet; +import static io.trino.cache.SafeCaches.buildNonEvictableCache; +import static java.util.Objects.requireNonNull; + +public class ConnectorAwareAddressProvider +{ + private final NonEvictableCache catalogAddressProvider; + private final InternalNodeManager nodeManager; + + @Inject + public ConnectorAwareAddressProvider(InternalNodeManager nodeManager) + { + this.nodeManager = requireNonNull(nodeManager, "nodeManager is null"); + this.catalogAddressProvider = buildNonEvictableCache(CacheBuilder.newBuilder().maximumSize(1000)); + } + + public ConsistentHashingAddressProvider getAddressProvider(NodeInfo nodeInfo, CatalogHandle catalogHandle, boolean 
schedulerIncludeCoordinator) + { + return uncheckedCacheGet( + catalogAddressProvider, + catalogHandle, + () -> new ConsistentHashingAddressProvider(new ConnectorAwareNodeManager(nodeManager, nodeInfo.getEnvironment(), catalogHandle, schedulerIncludeCoordinator))); + } +} diff --git a/core/trino-main/src/main/java/io/trino/cache/ConsistentHashingAddressProvider.java b/core/trino-main/src/main/java/io/trino/cache/ConsistentHashingAddressProvider.java new file mode 100644 index 000000000000..bdc05d14c1c2 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/cache/ConsistentHashingAddressProvider.java @@ -0,0 +1,112 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.cache; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.Sets; +import io.airlift.log.Logger; +import io.trino.spi.HostAddress; +import io.trino.spi.Node; +import io.trino.spi.NodeManager; +import org.ishugaliy.allgood.consistent.hash.ConsistentHash; +import org.ishugaliy.allgood.consistent.hash.HashRing; +import org.ishugaliy.allgood.consistent.hash.hasher.DefaultHasher; + +import java.util.Optional; +import java.util.Set; + +import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static io.airlift.units.Duration.nanosSince; +import static java.util.Objects.requireNonNull; +import static java.util.concurrent.TimeUnit.SECONDS; + +public class ConsistentHashingAddressProvider +{ + private static final Logger log = Logger.get(ConsistentHashingAddressProvider.class); + private static final long WORKER_NODES_CACHE_TIMEOUT_SECS = 5; + + private final ConsistentHash consistentHashRing = HashRing.newBuilder() + .hasher(DefaultHasher.METRO_HASH) + .build(); + private final NodeManager nodeManager; + private volatile long lastRefreshTime; + + public ConsistentHashingAddressProvider(NodeManager nodeManager) + { + this.nodeManager = requireNonNull(nodeManager, "nodeManager is null"); + refreshHashRing(); + } + + public Optional getPreferredAddress(String key) + { + refreshHashRingIfNeeded(); + return consistentHashRing.locate(key) + .map(TrinoNode::getHostAndPort); + } + + public void refreshHashRingIfNeeded() + { + if (nanosSince(lastRefreshTime).getValue(SECONDS) > WORKER_NODES_CACHE_TIMEOUT_SECS) { + /// Double lock checking pattern to reduce lock contention + synchronized (this) { + if (nanosSince(lastRefreshTime).getValue(SECONDS) > WORKER_NODES_CACHE_TIMEOUT_SECS) { + refreshHashRing(); + } + } + } + } + + @VisibleForTesting + synchronized void refreshHashRing() + { + try { + Set trinoNodes = nodeManager.getWorkerNodes().stream() + .map(TrinoNode::of) + .collect(toImmutableSet()); + 
lastRefreshTime = System.nanoTime(); + Set hashRingNodes = consistentHashRing.getNodes(); + Set removedNodes = Sets.difference(hashRingNodes, trinoNodes); + Set newNodes = Sets.difference(trinoNodes, hashRingNodes); + if (!newNodes.isEmpty()) { + consistentHashRing.addAll(newNodes); + } + if (!removedNodes.isEmpty()) { + removedNodes.forEach(consistentHashRing::remove); + } + } + catch (Exception e) { + log.error(e, "Error refreshing hash ring"); + } + } + + private record TrinoNode(String nodeIdentifier, HostAddress hostAndPort) + implements org.ishugaliy.allgood.consistent.hash.node.Node + { + public static TrinoNode of(Node node) + { + return new TrinoNode(node.getNodeIdentifier(), node.getHostAndPort()); + } + + public HostAddress getHostAndPort() + { + return hostAndPort; + } + + @Override + public String getKey() + { + return nodeIdentifier; + } + } +} diff --git a/core/trino-main/src/main/java/io/trino/cache/LoadCachedDataOperator.java b/core/trino-main/src/main/java/io/trino/cache/LoadCachedDataOperator.java new file mode 100644 index 000000000000..7957a78b0a8c --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/cache/LoadCachedDataOperator.java @@ -0,0 +1,187 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.cache; + +import io.trino.memory.context.LocalMemoryContext; +import io.trino.metadata.Split; +import io.trino.operator.DriverContext; +import io.trino.operator.OperatorContext; +import io.trino.operator.OperatorFactory; +import io.trino.operator.SourceOperator; +import io.trino.operator.SourceOperatorFactory; +import io.trino.spi.Page; +import io.trino.spi.connector.ConnectorPageSource; +import io.trino.sql.planner.plan.PlanNodeId; +import jakarta.annotation.Nullable; + +import java.io.IOException; +import java.io.UncheckedIOException; + +import static com.google.common.base.Preconditions.checkState; +import static java.util.Objects.requireNonNull; + +public class LoadCachedDataOperator + implements SourceOperator +{ + public static class LoadCachedDataOperatorFactory + implements SourceOperatorFactory + { + private final int operatorId; + private final PlanNodeId planNodeId; + private final PlanNodeId sourceId; + private boolean closed; + + public LoadCachedDataOperatorFactory( + int operatorId, + PlanNodeId planNodeId, + PlanNodeId sourceId) + { + this.operatorId = operatorId; + this.planNodeId = requireNonNull(planNodeId, "planNodeId is null"); + this.sourceId = requireNonNull(sourceId, "sourceId is null"); + } + + @Override + public PlanNodeId getSourceId() + { + return sourceId; + } + + @Override + public SourceOperator createOperator(DriverContext driverContext) + { + checkState(!closed, "Factory is already closed"); + OperatorContext operatorContext = driverContext.addOperatorContext(operatorId, planNodeId, LoadCachedDataOperator.class.getSimpleName()); + return new LoadCachedDataOperator(operatorContext, sourceId); + } + + @Override + public void noMoreOperators() + { + closed = true; + } + + @Override + public OperatorFactory duplicate() + { + return new LoadCachedDataOperatorFactory(operatorId, planNodeId, sourceId); + } + } + + private final OperatorContext operatorContext; + private final PlanNodeId sourceId; + private final 
CacheStats cacheStats; + private final LocalMemoryContext memoryContext; + + @Nullable + private ConnectorPageSource pageSource; + + private LoadCachedDataOperator(OperatorContext operatorContext, PlanNodeId sourceId) + { + this.operatorContext = requireNonNull(operatorContext, "operatorContext is null"); + this.sourceId = requireNonNull(sourceId, "sourceId is null"); + this.memoryContext = operatorContext.newLocalUserMemoryContext(LoadCachedDataOperator.class.getSimpleName()); + CacheDriverContext cacheContext = operatorContext.getDriverContext().getCacheDriverContext() + .orElseThrow(() -> new IllegalArgumentException("Cache context is not present")); + this.cacheStats = cacheContext.cacheStats(); + this.pageSource = cacheContext + .pageSource() + .orElseThrow(() -> new IllegalArgumentException("Cache page sink is not present")); + memoryContext.setBytes(pageSource.getMemoryUsage()); + operatorContext.setLatestMetrics(cacheContext.metrics()); + } + + @Override + public OperatorContext getOperatorContext() + { + return operatorContext; + } + + @Override + public PlanNodeId getSourceId() + { + return sourceId; + } + + @Override + public void addSplit(Split split) + { + // noop + } + + @Override + public void noMoreSplits() + { + // noop + } + + @Override + public boolean needsInput() + { + return false; + } + + @Override + public void addInput(Page page) + { + throw new UnsupportedOperationException(getClass().getName() + " cannot take input"); + } + + @Override + public Page getOutput() + { + if (pageSource == null) { + return null; + } + + Page page = pageSource.getNextPage(); + if (page == null) { + return null; + } + + cacheStats.recordReadFromCacheData(page.getSizeInBytes()); + operatorContext.recordProcessedInput(page.getSizeInBytes(), page.getPositionCount()); + memoryContext.setBytes(pageSource.getMemoryUsage()); + return page.getLoadedPage(); + } + + @Override + public boolean isFinished() + { + return pageSource == null || pageSource.isFinished(); + } + + 
@Override + public void close() + throws Exception + { + finish(); + } + + @Override + public void finish() + { + try { + if (pageSource != null) { + pageSource.close(); + operatorContext.setLatestConnectorMetrics(pageSource.getMetrics()); + pageSource = null; + memoryContext.close(); + } + } + catch (IOException exception) { + throw new UncheckedIOException(exception); + } + } +} diff --git a/core/trino-main/src/main/java/io/trino/cache/MinSeparationSplitAdmissionController.java b/core/trino-main/src/main/java/io/trino/cache/MinSeparationSplitAdmissionController.java new file mode 100644 index 000000000000..5559c212ecbb --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/cache/MinSeparationSplitAdmissionController.java @@ -0,0 +1,118 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.cache; + +import com.google.errorprone.annotations.concurrent.GuardedBy; +import io.trino.metadata.Split; +import io.trino.spi.HostAddress; +import io.trino.spi.cache.CacheSplitId; +import it.unimi.dsi.fastutil.ints.Int2LongArrayMap; +import it.unimi.dsi.fastutil.ints.Int2LongMap; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.base.Verify.verify; +import static com.google.common.collect.Iterables.getOnlyElement; + +/** + * A simple split admission controller that ensures that a worker processes a minimum number of splits + * with distinct CacheSplitID before scheduling the next batch of splits which can contain splits with the + * same CacheSplitID. + */ +public class MinSeparationSplitAdmissionController + implements SplitAdmissionController +{ + private static final int MAX_CACHE_SPLITS = 1_000_000; + private static final long PENDING_SPLIT_MARKER = Long.MAX_VALUE; + private static final long NOT_SCHEDULED_SPLIT_MARKER = -1; + + private final int minSplitSeparation; + + @GuardedBy("this") + private final Map workerInfos = new HashMap<>(); + + public MinSeparationSplitAdmissionController(int minSplitSeparation) + { + verify(minSplitSeparation > 0, "minSplitSeparation must be greater than 0"); + this.minSplitSeparation = minSplitSeparation; + } + + @Override + public synchronized boolean canScheduleSplit(CacheSplitId splitId, HostAddress address) + { + WorkerInfo workerInfo = workerInfos.computeIfAbsent(address, _ -> new WorkerInfo()); + int cacheSplitKey = getCacheSplitKey(splitId); + long splitSequenceId = workerInfo.scheduledSplits.get(cacheSplitKey); + if (splitSequenceId == NOT_SCHEDULED_SPLIT_MARKER) { + // We use PENDING_SPLIT_MARKER to indicate that a split with a given CacheSplitId has been added + // to the queue for scheduling. However, it has not been scheduled yet. 
This marker prevents returning + // splits with the same CacheSplitId from concurrently executing CacheSplitSources. + workerInfo.scheduledSplits.put(cacheSplitKey, PENDING_SPLIT_MARKER); + return true; + } + + // Count-based heuristic + return workerInfo.processedSplitCount - splitSequenceId >= minSplitSeparation; + } + + @Override + public synchronized void splitsScheduled(List splits) + { + for (Split split : splits) { + Optional cacheSplitId = split.getCacheSplitId(); + List addresses = split.getAddresses(); + // We only care about splits that are cacheable and have preferred addresses (worker) + if (cacheSplitId.isPresent()) { + HostAddress address = getOnlyElement(addresses); + WorkerInfo workerInfo = workerInfos.computeIfAbsent(address, _ -> new WorkerInfo()); + int cacheSplitKey = getCacheSplitKey(cacheSplitId.get()); + // Do not update split sequence id if the split was already executed. This way subsequent splits with same + // split id won't have to hold execution of the split + long splitSequenceId = workerInfo.scheduledSplits.get(cacheSplitKey); + checkState( + splitSequenceId != NOT_SCHEDULED_SPLIT_MARKER, + "Expected sequence ID for a split: %s", + cacheSplitId.get()); + if (splitSequenceId == PENDING_SPLIT_MARKER) { + workerInfo.scheduledSplits.put(cacheSplitKey, workerInfo.processedSplitCount++); + } + } + } + } + + private static int getCacheSplitKey(CacheSplitId splitId) + { + // Use the hash code of the split id as the key to cap memory usage + return splitId.hashCode() % MAX_CACHE_SPLITS; + } + + private static final class WorkerInfo + { + private final Int2LongMap scheduledSplits; + private long processedSplitCount; + + public WorkerInfo() + { + scheduledSplits = new Int2LongArrayMap(); + // Set the default return value to NOT_SCHEDULED_SPLIT_MARKER to indicate that the split + // has not been scheduled yet. 
+ scheduledSplits.defaultReturnValue(NOT_SCHEDULED_SPLIT_MARKER); + processedSplitCount = 0; + } + } +} diff --git a/core/trino-main/src/main/java/io/trino/cache/SplitAdmissionController.java b/core/trino-main/src/main/java/io/trino/cache/SplitAdmissionController.java new file mode 100644 index 000000000000..419f4d86dfec --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/cache/SplitAdmissionController.java @@ -0,0 +1,37 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.cache; + +import io.trino.metadata.Split; +import io.trino.spi.HostAddress; +import io.trino.spi.cache.CacheSplitId; + +import java.util.List; + +public interface SplitAdmissionController +{ + /** + * Determines if a split can be scheduled on a worker. + * + * @param splitId the split to schedule + * @param address the worker to schedule the split on + * @return true if the split can be schedule; false otherwise + */ + boolean canScheduleSplit(CacheSplitId splitId, HostAddress address); + + /** + * Notifies the manager that a list of splits have been scheduled. 
+ */ + void splitsScheduled(List splits); +} diff --git a/core/trino-main/src/main/java/io/trino/cache/SplitAdmissionControllerProvider.java b/core/trino-main/src/main/java/io/trino/cache/SplitAdmissionControllerProvider.java new file mode 100644 index 000000000000..e4a996ce721b --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/cache/SplitAdmissionControllerProvider.java @@ -0,0 +1,96 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.cache; + +import com.google.common.collect.ImmutableMap; +import io.trino.Session; +import io.trino.metadata.Split; +import io.trino.spi.HostAddress; +import io.trino.spi.cache.CacheSplitId; +import io.trino.spi.cache.PlanSignature; +import io.trino.sql.planner.PlanFragment; +import io.trino.sql.planner.optimizations.PlanNodeSearcher; +import io.trino.sql.planner.plan.ChooseAlternativeNode; +import io.trino.sql.planner.plan.PlanNode; +import io.trino.sql.planner.plan.PlanNodeId; + +import java.util.List; +import java.util.Map; + +import static com.google.common.collect.ImmutableMap.toImmutableMap; +import static io.trino.SystemSessionProperties.getCacheMinWorkerSplitSeparation; +import static io.trino.cache.CacheCommonSubqueries.getLoadCachedDataPlanNode; +import static java.util.Objects.requireNonNull; +import static java.util.function.Function.identity; +import static java.util.stream.Collectors.counting; +import static java.util.stream.Collectors.groupingBy; + +public class 
SplitAdmissionControllerProvider +{ + public static final SplitAdmissionController NOOP = new SplitAdmissionController() + { + @Override + public boolean canScheduleSplit(CacheSplitId splitId, HostAddress address) + { + return true; + } + + @Override + public void splitsScheduled(List splits) {} + }; + + private final Map planSignatures; + private final Map splitSchedulerManagers; + + public SplitAdmissionControllerProvider(List fragments, Session session) + { + requireNonNull(fragments, "fragments is null"); + requireNonNull(session, "session is null"); + + int cacheMinWorkerSplitSeparation = getCacheMinWorkerSplitSeparation(session); + if (cacheMinWorkerSplitSeparation > 0) { + planSignatures = fragments.stream() + .flatMap(fragment -> PlanNodeSearcher.searchFrom(fragment.getRoot()) + .where(CacheCommonSubqueries::isCacheChooseAlternativeNode) + .findAll() + .stream() + .map(ChooseAlternativeNode.class::cast)) + .collect(toImmutableMap(PlanNode::getId, node -> getLoadCachedDataPlanNode(node).getPlanSignature().signature())); + splitSchedulerManagers = planSignatures.values().stream() + .collect(groupingBy(identity(), counting())) + .entrySet().stream() + // only create admission controller for table scans with repeating signatures + .filter(entry -> entry.getValue() > 1) + .collect(toImmutableMap(Map.Entry::getKey, _ -> new MinSeparationSplitAdmissionController(cacheMinWorkerSplitSeparation))); + } + else { + planSignatures = ImmutableMap.of(); + splitSchedulerManagers = ImmutableMap.of(); + } + } + + public SplitAdmissionController get(PlanNodeId nodeId) + { + PlanSignature planSignature = planSignatures.get(nodeId); + if (planSignature == null) { + return NOOP; + } + return get(planSignature); + } + + public SplitAdmissionController get(PlanSignature signature) + { + return splitSchedulerManagers.getOrDefault(signature, NOOP); + } +} diff --git a/core/trino-main/src/main/java/io/trino/cache/StaticDynamicFilter.java 
b/core/trino-main/src/main/java/io/trino/cache/StaticDynamicFilter.java new file mode 100644 index 000000000000..594fadb3d72c --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/cache/StaticDynamicFilter.java @@ -0,0 +1,167 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.cache; + +import com.google.common.util.concurrent.ListenableFuture; +import io.trino.spi.connector.ColumnHandle; +import io.trino.spi.connector.DynamicFilter; +import io.trino.spi.predicate.TupleDomain; + +import java.util.List; +import java.util.Objects; +import java.util.Set; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Supplier; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static com.google.common.util.concurrent.MoreExecutors.directExecutor; +import static io.airlift.concurrent.MoreFutures.toListenableFuture; +import static java.util.Objects.requireNonNull; +import static java.util.concurrent.CompletableFuture.anyOf; + +/** + * Implementation of dynamic filter that is not awaitable. 
+ */ +public class StaticDynamicFilter + implements DynamicFilter +{ + private final Set columnsCovered; + private final boolean isComplete; + private final TupleDomain tupleDomain; + private volatile int hashCode; + + /** + * Creates a {@code Supplier} that caches the {@link StaticDynamicFilter}. + * Cache is reset whenever any underlying dynamic filter gets updated. + */ + public static Supplier createStaticDynamicFilterSupplier(List disjunctiveDynamicFilters) + { + AtomicReference> dynamicFilterCache = new AtomicReference<>(); + resetDynamicFilterCache(dynamicFilterCache, disjunctiveDynamicFilters); + return () -> { + AtomicReference dynamicFilterReference = requireNonNull(dynamicFilterCache.get()); + StaticDynamicFilter dynamicFilter = dynamicFilterReference.get(); + if (dynamicFilter == null) { + dynamicFilter = createStaticDynamicFilter(disjunctiveDynamicFilters); + if (!dynamicFilterReference.compareAndSet(null, dynamicFilter)) { + return dynamicFilterReference.get(); + } + } + return dynamicFilter; + }; + } + + private static void resetDynamicFilterCache( + AtomicReference> dynamicFilterCache, + List disjunctiveDynamicFilters) + { + if (areAwaitable(disjunctiveDynamicFilters)) { + // reset dynamic filter cache whenever any underlying dynamic filter gets updated + whenAnyUpdates(disjunctiveDynamicFilters).addListener(() -> resetDynamicFilterCache(dynamicFilterCache, disjunctiveDynamicFilters), directExecutor()); + } + dynamicFilterCache.set(new AtomicReference<>()); + } + + private static boolean areAwaitable(List disjunctiveDynamicFilters) + { + return disjunctiveDynamicFilters.stream().anyMatch(DynamicFilter::isAwaitable); + } + + private static ListenableFuture whenAnyUpdates(List disjunctiveDynamicFilters) + { + return toListenableFuture(anyOf(disjunctiveDynamicFilters.stream() + .filter(DynamicFilter::isAwaitable) + .map(DynamicFilter::isBlocked) + .toArray(CompletableFuture[]::new))); + } + + public static StaticDynamicFilter 
createStaticDynamicFilter(List disjunctiveDynamicFilters) + { + requireNonNull(disjunctiveDynamicFilters, "disjunctiveDynamicFilters is null"); + checkArgument(!disjunctiveDynamicFilters.isEmpty()); + return new StaticDynamicFilter( + disjunctiveDynamicFilters.stream() + .flatMap(filter -> filter.getColumnsCovered().stream()) + .collect(toImmutableSet()), + // isComplete needs to be called before getCurrentPredicate + disjunctiveDynamicFilters.stream().allMatch(DynamicFilter::isComplete), + TupleDomain.columnWiseUnion(disjunctiveDynamicFilters.stream() + .map(DynamicFilter::getCurrentPredicate) + .collect(toImmutableList()))); + } + + private StaticDynamicFilter(Set columnsCovered, boolean isComplete, TupleDomain tupleDomain) + { + this.columnsCovered = requireNonNull(columnsCovered, "columnsCovered is null"); + this.isComplete = isComplete; + this.tupleDomain = requireNonNull(tupleDomain, "tupleDomain is null"); + } + + @Override + public Set getColumnsCovered() + { + return columnsCovered; + } + + @Override + public CompletableFuture isBlocked() + { + return NOT_BLOCKED; + } + + @Override + public boolean isComplete() + { + return isComplete; + } + + @Override + public boolean isAwaitable() + { + return false; + } + + @Override + public TupleDomain getCurrentPredicate() + { + return tupleDomain; + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + StaticDynamicFilter that = (StaticDynamicFilter) o; + return isComplete == that.isComplete + && Objects.equals(columnsCovered, that.columnsCovered) + && Objects.equals(tupleDomain, that.tupleDomain); + } + + @Override + public int hashCode() + { + if (hashCode == 0) { + hashCode = Objects.hash(columnsCovered, isComplete, tupleDomain); + } + return hashCode; + } +} diff --git a/core/trino-main/src/main/java/io/trino/connector/CatalogServiceProviderModule.java 
b/core/trino-main/src/main/java/io/trino/connector/CatalogServiceProviderModule.java index a6bc0e1c60c9..e1f5b91a613c 100644 --- a/core/trino-main/src/main/java/io/trino/connector/CatalogServiceProviderModule.java +++ b/core/trino-main/src/main/java/io/trino/connector/CatalogServiceProviderModule.java @@ -31,6 +31,7 @@ import io.trino.metadata.TablePropertyManager; import io.trino.metadata.ViewPropertyManager; import io.trino.security.AccessControlManager; +import io.trino.spi.cache.ConnectorCacheMetadata; import io.trino.spi.connector.ConnectorAccessControl; import io.trino.spi.connector.ConnectorIndexProvider; import io.trino.spi.connector.ConnectorNodePartitioningProvider; @@ -58,6 +59,13 @@ public static CatalogServiceProvider createSplitManagerPr return new ConnectorCatalogServiceProvider<>("split manager", connectorServicesProvider, connector -> connector.getSplitManager().orElse(null)); } + @Provides + @Singleton + public static CatalogServiceProvider> createCacheMetadata(ConnectorServicesProvider connectorServicesProvider) + { + return new ConnectorCatalogServiceProvider<>("cache metadata", connectorServicesProvider, ConnectorServices::getCacheMetadata); + } + @Provides @Singleton public static CatalogServiceProvider createPageSourceProviderFactory(ConnectorServicesProvider connectorServicesProvider) diff --git a/core/trino-main/src/main/java/io/trino/connector/ConnectorServices.java b/core/trino-main/src/main/java/io/trino/connector/ConnectorServices.java index 898f4b60c49b..3d53bd90586c 100644 --- a/core/trino-main/src/main/java/io/trino/connector/ConnectorServices.java +++ b/core/trino-main/src/main/java/io/trino/connector/ConnectorServices.java @@ -22,6 +22,7 @@ import io.trino.metadata.CatalogProcedures; import io.trino.metadata.CatalogTableFunctions; import io.trino.metadata.CatalogTableProcedures; +import io.trino.spi.cache.ConnectorCacheMetadata; import io.trino.spi.classloader.ThreadContextClassLoader; import io.trino.spi.connector.CatalogHandle; 
import io.trino.spi.connector.Connector; @@ -78,6 +79,7 @@ public class ConnectorServices private final Optional functionProvider; private final CatalogTableFunctions tableFunctions; private final Optional splitManager; + private final Optional cacheMetadata; private final Optional pageSourceProviderFactory; private final Optional pageSinkProvider; private final Optional indexProvider; @@ -129,6 +131,14 @@ public ConnectorServices(Tracer tracer, CatalogHandle catalogHandle, Connector c } this.splitManager = Optional.ofNullable(splitManager); + ConnectorCacheMetadata cacheMetadata = null; + try { + cacheMetadata = connector.getCacheMetadata(); + } + catch (UnsupportedOperationException ignored) { + } + this.cacheMetadata = Optional.ofNullable(cacheMetadata); + ConnectorPageSourceProviderFactory connectorPageSourceProviderFactory = null; try { connectorPageSourceProviderFactory = connector.getPageSourceProviderFactory(); @@ -267,6 +277,11 @@ public Optional getSplitManager() return splitManager; } + public Optional getCacheMetadata() + { + return cacheMetadata; + } + public Optional getPageSourceProviderFactory() { return pageSourceProviderFactory; diff --git a/core/trino-main/src/main/java/io/trino/cost/ChooseAlternativeRule.java b/core/trino-main/src/main/java/io/trino/cost/ChooseAlternativeRule.java new file mode 100644 index 000000000000..8931382bf07e --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/cost/ChooseAlternativeRule.java @@ -0,0 +1,46 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.cost; + +import io.trino.matching.Pattern; +import io.trino.sql.planner.plan.ChooseAlternativeNode; + +import java.util.Optional; + +import static io.trino.sql.planner.plan.Patterns.chooseAlternative; + +public class ChooseAlternativeRule + extends SimpleStatsRule +{ + private static final Pattern PATTERN = chooseAlternative(); + + public ChooseAlternativeRule(StatsNormalizer normalizer) + { + super(normalizer); + } + + @Override + public Pattern getPattern() + { + return PATTERN; + } + + @Override + protected Optional doCalculate(ChooseAlternativeNode node, StatsCalculator.Context context) + { + // All alternatives describe the same dataset, therefore it would be wasteful to calculate stats for each alternative. + // Instead, we calculate only for the first alternative which is considered the default alternative. + return Optional.of(context.statsProvider().getStats(node.getSources().get(0))); + } +} diff --git a/core/trino-main/src/main/java/io/trino/cost/CostCalculatorUsingExchanges.java b/core/trino-main/src/main/java/io/trino/cost/CostCalculatorUsingExchanges.java index b2802118c746..1385a30b63f4 100644 --- a/core/trino-main/src/main/java/io/trino/cost/CostCalculatorUsingExchanges.java +++ b/core/trino-main/src/main/java/io/trino/cost/CostCalculatorUsingExchanges.java @@ -22,6 +22,7 @@ import io.trino.sql.planner.iterative.GroupReference; import io.trino.sql.planner.plan.AggregationNode; import io.trino.sql.planner.plan.AssignUniqueId; +import io.trino.sql.planner.plan.ChooseAlternativeNode; import io.trino.sql.planner.plan.EnforceSingleRowNode; import io.trino.sql.planner.plan.ExchangeNode; import io.trino.sql.planner.plan.FilterNode; @@ -108,6 +109,15 @@ public PlanCostEstimate visitGroupReference(GroupReference node, Void context) throw new UnsupportedOperationException(); } + @Override + public PlanCostEstimate 
visitChooseAlternativeNode(ChooseAlternativeNode node, Void context) + { + // It's unknown which alternatives will get executed. The first alternative should be + // the most pessimistic one and therefore probably incurs the maximal cost. + // Note that there is no local cost of ChooseAlternativeNode, because there is no actual execution for it. + return sourcesCosts.getCost(node.getSources().get(0)); + } + @Override public PlanCostEstimate visitAssignUniqueId(AssignUniqueId node, Void context) { diff --git a/core/trino-main/src/main/java/io/trino/cost/StatsCalculatorModule.java b/core/trino-main/src/main/java/io/trino/cost/StatsCalculatorModule.java index 35ae05a43890..467e082cbd40 100644 --- a/core/trino-main/src/main/java/io/trino/cost/StatsCalculatorModule.java +++ b/core/trino-main/src/main/java/io/trino/cost/StatsCalculatorModule.java @@ -81,6 +81,7 @@ public List> get() rules.add(new RowNumberStatsRule(normalizer)); rules.add(new SampleStatsRule(normalizer)); rules.add(new SortStatsRule()); + rules.add(new ChooseAlternativeRule(normalizer)); rules.add(new DynamicFilterSourceStatsRule()); rules.add(new RemoteSourceStatsRule(normalizer)); rules.add(new TopNRankingStatsRule(normalizer)); diff --git a/core/trino-main/src/main/java/io/trino/execution/DistributionSnapshot.java b/core/trino-main/src/main/java/io/trino/execution/DistributionSnapshot.java index 5323c5f08a54..4820b88b7a8a 100644 --- a/core/trino-main/src/main/java/io/trino/execution/DistributionSnapshot.java +++ b/core/trino-main/src/main/java/io/trino/execution/DistributionSnapshot.java @@ -50,6 +50,7 @@ public static OperatorStats pruneOperatorStats(OperatorStats operatorStats) return new OperatorStats( operatorStats.getStageId(), operatorStats.getPipelineId(), + operatorStats.getAlternativeId(), operatorStats.getOperatorId(), operatorStats.getPlanNodeId(), operatorStats.getOperatorType(), diff --git a/core/trino-main/src/main/java/io/trino/execution/MemoryTrackingRemoteTaskFactory.java 
b/core/trino-main/src/main/java/io/trino/execution/MemoryTrackingRemoteTaskFactory.java index 8c645315a9b3..266e99affdba 100644 --- a/core/trino-main/src/main/java/io/trino/execution/MemoryTrackingRemoteTaskFactory.java +++ b/core/trino-main/src/main/java/io/trino/execution/MemoryTrackingRemoteTaskFactory.java @@ -17,6 +17,7 @@ import io.airlift.units.DataSize; import io.opentelemetry.api.trace.Span; import io.trino.Session; +import io.trino.cache.SplitAdmissionControllerProvider; import io.trino.execution.NodeTaskMap.PartitionedSplitCountTracker; import io.trino.execution.StateMachine.StateChangeListener; import io.trino.execution.buffer.OutputBuffers; @@ -56,7 +57,8 @@ public RemoteTask createRemoteTask( PartitionedSplitCountTracker partitionedSplitCountTracker, Set outboundDynamicFilterIds, Optional estimatedMemory, - boolean summarizeTaskInfo) + boolean summarizeTaskInfo, + SplitAdmissionControllerProvider splitAdmissionControllerProvider) { RemoteTask task = remoteTaskFactory.createRemoteTask( session, @@ -70,7 +72,8 @@ public RemoteTask createRemoteTask( partitionedSplitCountTracker, outboundDynamicFilterIds, estimatedMemory, - summarizeTaskInfo); + summarizeTaskInfo, + splitAdmissionControllerProvider); task.addStateChangeListener(new UpdatePeakMemory(stateMachine)); return task; diff --git a/core/trino-main/src/main/java/io/trino/execution/RemoteTaskFactory.java b/core/trino-main/src/main/java/io/trino/execution/RemoteTaskFactory.java index 4b232c9e377d..3ea4abd1f612 100644 --- a/core/trino-main/src/main/java/io/trino/execution/RemoteTaskFactory.java +++ b/core/trino-main/src/main/java/io/trino/execution/RemoteTaskFactory.java @@ -17,6 +17,7 @@ import io.airlift.units.DataSize; import io.opentelemetry.api.trace.Span; import io.trino.Session; +import io.trino.cache.SplitAdmissionControllerProvider; import io.trino.execution.NodeTaskMap.PartitionedSplitCountTracker; import io.trino.execution.buffer.OutputBuffers; import io.trino.metadata.InternalNode; @@ 
-42,5 +43,6 @@ RemoteTask createRemoteTask( PartitionedSplitCountTracker partitionedSplitCountTracker, Set outboundDynamicFilterIds, Optional estimatedMemory, - boolean summarizeTaskInfo); + boolean summarizeTaskInfo, + SplitAdmissionControllerProvider splitAdmissionControllerProvider); } diff --git a/core/trino-main/src/main/java/io/trino/execution/SplitRunner.java b/core/trino-main/src/main/java/io/trino/execution/SplitRunner.java index 5b12171230ec..83550f49c59d 100644 --- a/core/trino-main/src/main/java/io/trino/execution/SplitRunner.java +++ b/core/trino-main/src/main/java/io/trino/execution/SplitRunner.java @@ -16,8 +16,10 @@ import com.google.common.util.concurrent.ListenableFuture; import io.airlift.units.Duration; import io.opentelemetry.api.trace.Span; +import io.trino.spi.cache.CacheSplitId; import java.io.Closeable; +import java.util.Optional; public interface SplitRunner extends Closeable @@ -32,6 +34,8 @@ public interface SplitRunner String getInfo(); + Optional getCacheSplitId(); + @Override void close(); } diff --git a/core/trino-main/src/main/java/io/trino/execution/SqlStage.java b/core/trino-main/src/main/java/io/trino/execution/SqlStage.java index e9fc952953dd..ac047fe637b7 100644 --- a/core/trino-main/src/main/java/io/trino/execution/SqlStage.java +++ b/core/trino-main/src/main/java/io/trino/execution/SqlStage.java @@ -21,6 +21,7 @@ import io.opentelemetry.api.trace.Span; import io.opentelemetry.api.trace.Tracer; import io.trino.Session; +import io.trino.cache.SplitAdmissionControllerProvider; import io.trino.execution.StateMachine.StateChangeListener; import io.trino.execution.buffer.OutputBuffers; import io.trino.execution.scheduler.SplitSchedulerStats; @@ -62,6 +63,7 @@ public final class SqlStage private final RemoteTaskFactory remoteTaskFactory; private final NodeTaskMap nodeTaskMap; private final boolean summarizeTaskInfo; + private final SplitAdmissionControllerProvider splitAdmissionControllerProvider; private final Set 
outboundDynamicFilterIds; @@ -84,7 +86,8 @@ public static SqlStage createSqlStage( Executor stateMachineExecutor, Tracer tracer, Span schedulerSpan, - SplitSchedulerStats schedulerStats) + SplitSchedulerStats schedulerStats, + SplitAdmissionControllerProvider splitAdmissionControllerProvider) { requireNonNull(stageId, "stageId is null"); requireNonNull(fragment, "fragment is null"); @@ -111,7 +114,8 @@ public static SqlStage createSqlStage( stateMachine, remoteTaskFactory, nodeTaskMap, - summarizeTaskInfo); + summarizeTaskInfo, + splitAdmissionControllerProvider); sqlStage.initialize(); return sqlStage; } @@ -121,13 +125,15 @@ private SqlStage( StageStateMachine stateMachine, RemoteTaskFactory remoteTaskFactory, NodeTaskMap nodeTaskMap, - boolean summarizeTaskInfo) + boolean summarizeTaskInfo, + SplitAdmissionControllerProvider splitAdmissionControllerProvider) { this.session = requireNonNull(session, "session is null"); this.stateMachine = stateMachine; this.remoteTaskFactory = requireNonNull(remoteTaskFactory, "remoteTaskFactory is null"); this.nodeTaskMap = requireNonNull(nodeTaskMap, "nodeTaskMap is null"); this.summarizeTaskInfo = summarizeTaskInfo; + this.splitAdmissionControllerProvider = requireNonNull(splitAdmissionControllerProvider, "splitAdmissionControllerProvider is null"); this.outboundDynamicFilterIds = getOutboundDynamicFilters(stateMachine.getFragment()); } @@ -268,7 +274,8 @@ public synchronized Optional createTask( nodeTaskMap.createPartitionedSplitCountTracker(node, taskId), outboundDynamicFilterIds, estimatedMemory, - summarizeTaskInfo); + summarizeTaskInfo, + splitAdmissionControllerProvider); noMoreSplits.forEach(task::noMoreSplits); diff --git a/core/trino-main/src/main/java/io/trino/execution/SqlTaskExecution.java b/core/trino-main/src/main/java/io/trino/execution/SqlTaskExecution.java index d1295bcbabcb..885896b40403 100644 --- a/core/trino-main/src/main/java/io/trino/execution/SqlTaskExecution.java +++ 
b/core/trino-main/src/main/java/io/trino/execution/SqlTaskExecution.java @@ -42,6 +42,7 @@ import io.trino.operator.TaskContext; import io.trino.spi.SplitWeight; import io.trino.spi.TrinoException; +import io.trino.spi.cache.CacheSplitId; import io.trino.sql.planner.LocalExecutionPlanner.LocalExecutionPlan; import io.trino.sql.planner.plan.PlanNodeId; import io.trino.tracing.TrinoAttributes; @@ -664,7 +665,7 @@ public Driver createDriver(DriverContext driverContext, @Nullable ScheduledSplit } Driver driver; try { - driver = driverFactory.createDriver(driverContext); + driver = driverFactory.createDriver(driverContext, Optional.ofNullable(partitionedSplit)); Span.fromContext(Context.current()).addEvent("driver-created"); } catch (Throwable t) { @@ -896,6 +897,14 @@ public String getInfo() return (partitionedSplit == null) ? "" : formatSplitInfo(partitionedSplit.getSplit()); } + @Override + public Optional getCacheSplitId() + { + return Optional.ofNullable(partitionedSplit) + .map(ScheduledSplit::getSplit) + .flatMap(Split::getCacheSplitId); + } + @Override public void close() { diff --git a/core/trino-main/src/main/java/io/trino/execution/StageStateMachine.java b/core/trino-main/src/main/java/io/trino/execution/StageStateMachine.java index 3c801f9bfe56..2eb384602e66 100644 --- a/core/trino-main/src/main/java/io/trino/execution/StageStateMachine.java +++ b/core/trino-main/src/main/java/io/trino/execution/StageStateMachine.java @@ -35,7 +35,7 @@ import io.trino.sql.planner.plan.PlanNodeId; import io.trino.tracing.TrinoAttributes; import io.trino.util.Failures; -import it.unimi.dsi.fastutil.longs.Long2ObjectOpenHashMap; +import it.unimi.dsi.fastutil.objects.Object2ObjectOpenHashMap; import org.joda.time.DateTime; import java.util.ArrayList; @@ -49,7 +49,7 @@ import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicReference; -import java.util.function.LongFunction; +import 
java.util.function.Function; import java.util.function.Supplier; import static com.google.common.base.MoreObjects.toStringHelper; @@ -706,23 +706,21 @@ public BasicStageInfo getBasicStageInfo(Supplier> taskInfosSu private static List combineTaskOperatorSummaries(List taskInfos, int maxTaskOperatorSummaries) { - // Group each unique pipelineId + operatorId combination into lists - Long2ObjectOpenHashMap> pipelineAndOperatorToStats = new Long2ObjectOpenHashMap<>(maxTaskOperatorSummaries); + // Group each unique operatorId + pipelineId + alternativeId combination into lists + Object2ObjectOpenHashMap> pipelineAndOperatorToStats = new Object2ObjectOpenHashMap<>(maxTaskOperatorSummaries); // Expect to have one operator stats entry for each taskInfo int taskInfoCount = taskInfos.size(); - LongFunction> statsListCreator = key -> new ArrayList<>(taskInfoCount); + Function> statsListCreator = key -> new ArrayList<>(taskInfoCount); for (TaskInfo taskInfo : taskInfos) { for (PipelineStats pipeline : taskInfo.stats().getPipelines()) { - // Place the pipelineId in the high bits of the combinedKey mask - long pipelineKeyMask = Integer.toUnsignedLong(pipeline.getPipelineId()) << 32; for (OperatorStats operator : pipeline.getOperatorSummaries()) { - // Place the operatorId into the low bits of the combined key - long combinedKey = pipelineKeyMask | Integer.toUnsignedLong(operator.getOperatorId()); + // Place operatorId, pipelineId and alternativeId in the combined key + String combinedKey = pipeline.getPipelineId() + "." + operator.getOperatorId() + "." 
+ operator.getAlternativeId(); pipelineAndOperatorToStats.computeIfAbsent(combinedKey, statsListCreator).add(operator); } } } - // Merge the list of operator stats from each pipelineId + operatorId into a single entry + // Merge the list of operator stats from each operatorId + pipelineId + alternativeId into a single entry ImmutableList.Builder operatorStatsBuilder = ImmutableList.builderWithExpectedSize(pipelineAndOperatorToStats.size()); for (List operators : pipelineAndOperatorToStats.values()) { OperatorStats stats = operators.get(0); diff --git a/core/trino-main/src/main/java/io/trino/execution/executor/timesharing/PrioritizedSplitRunner.java b/core/trino-main/src/main/java/io/trino/execution/executor/timesharing/PrioritizedSplitRunner.java index 32acd8e96847..889cadf5f700 100644 --- a/core/trino-main/src/main/java/io/trino/execution/executor/timesharing/PrioritizedSplitRunner.java +++ b/core/trino-main/src/main/java/io/trino/execution/executor/timesharing/PrioritizedSplitRunner.java @@ -25,8 +25,10 @@ import io.opentelemetry.api.trace.Tracer; import io.opentelemetry.context.Context; import io.trino.execution.SplitRunner; +import io.trino.spi.cache.CacheSplitId; import io.trino.tracing.TrinoAttributes; +import java.util.Optional; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicLong; @@ -285,6 +287,11 @@ public String getInfo() processCalls.get()); } + public Optional getCacheSplitId() + { + return split.getCacheSplitId(); + } + @Override public String toString() { diff --git a/core/trino-main/src/main/java/io/trino/execution/executor/timesharing/TimeSharingTaskExecutor.java b/core/trino-main/src/main/java/io/trino/execution/executor/timesharing/TimeSharingTaskExecutor.java index 7d7d31a50b4a..5e9d25c1b8e0 100644 --- a/core/trino-main/src/main/java/io/trino/execution/executor/timesharing/TimeSharingTaskExecutor.java +++ 
b/core/trino-main/src/main/java/io/trino/execution/executor/timesharing/TimeSharingTaskExecutor.java @@ -39,18 +39,23 @@ import io.trino.spi.TrinoException; import io.trino.spi.VersionEmbedder; import io.trino.tracing.TrinoAttributes; +import it.unimi.dsi.fastutil.ints.IntOpenHashSet; +import it.unimi.dsi.fastutil.ints.IntSet; import jakarta.annotation.PostConstruct; import jakarta.annotation.PreDestroy; import org.weakref.jmx.Managed; import org.weakref.jmx.Nested; +import java.util.ArrayDeque; import java.util.ArrayList; import java.util.HashSet; import java.util.Iterator; import java.util.LinkedList; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.OptionalInt; +import java.util.Queue; import java.util.Set; import java.util.SortedSet; import java.util.concurrent.ConcurrentHashMap; @@ -133,6 +138,14 @@ public class TimeSharingTaskExecutor */ private final Map> blockedSplits = new ConcurrentHashMap<>(); + /** + * CacheSplitIds for splits that are currently registered with the task executor. We use the hash code + * of the split id as the key to speed up the lookup during split scheduling. This is especially needed + * when there are a large number of super short splits since they are cached. 
+ */ + @GuardedBy("this") + private final IntSet runningCacheSplitIds = new IntOpenHashSet(); + private final AtomicLongArray completedTasksPerLevel = new AtomicLongArray(5); private final AtomicLongArray completedSplitsPerLevel = new AtomicLongArray(5); @@ -329,6 +342,12 @@ private boolean doRemoveTask(TimeSharingTaskHandle taskHandle) intermediateSplits.removeAll(splits); blockedSplits.keySet().removeAll(splits); waitingSplits.removeAll(splits); + // Remove running cache split ids + splits.stream() + .map(PrioritizedSplitRunner::getCacheSplitId) + .filter(Optional::isPresent) + .map(Optional::get) + .forEach(id -> runningCacheSplitIds.remove(id.hashCode())); recordLeafSplitsSize(); } @@ -415,6 +434,7 @@ private void splitFinished(PrioritizedSplitRunner split) completedSplitsPerLevel.incrementAndGet(split.getPriority().getLevel()); synchronized (this) { allSplits.remove(split); + split.getCacheSplitId().ifPresent(id -> runningCacheSplitIds.remove(id.hashCode())); long wallNanos = System.nanoTime() - split.getCreatedNanos(); splitWallTime.add(Duration.succinctNanos(wallNanos)); @@ -445,43 +465,94 @@ private void splitFinished(PrioritizedSplitRunner split) } private synchronized void scheduleTaskIfNecessary(TimeSharingTaskHandle taskHandle) + { + int scheduledSplits = scheduleTaskIfNecessary(0, taskHandle, runningCacheSplitIds); + // If we have less than the minimum number of drivers running per task, force start some splits even if + // they are not cached yet. This is so that the system always has some drivers running to make progress. + scheduleTaskIfNecessary(scheduledSplits, taskHandle, IntSet.of()); + recordLeafSplitsSize(); + } + + private synchronized int scheduleTaskIfNecessary( + int scheduledSplits, + TimeSharingTaskHandle taskHandle, + IntSet runningCacheSplitIds) { // if task has less than the minimum guaranteed splits running, // immediately schedule new splits for this task.
This assures // that a task gets its fair amount of consideration (you have to // have splits to be considered for running on a thread). - int splitsToSchedule = min(guaranteedNumberOfDriversPerTask, taskHandle.getMaxDriversPerTask().orElse(Integer.MAX_VALUE)) - taskHandle.getRunningLeafSplits(); - for (int i = 0; i < splitsToSchedule; ++i) { + int splitsToSchedule = min(guaranteedNumberOfDriversPerTask, + taskHandle.getMaxDriversPerTask().orElse(Integer.MAX_VALUE)) - taskHandle.getRunningLeafSplits(); + + Queue unscheduledSplits = new ArrayDeque<>(); + while (scheduledSplits < splitsToSchedule) { PrioritizedSplitRunner split = taskHandle.pollNextSplit(); if (split == null) { // no more splits to schedule - return; + break; } - startSplit(split); - splitQueuedTime.add(Duration.nanosSince(split.getCreatedNanos())); + // Do not start the split if there's already a split with the same cache id running. This is done to + // improve cache utilization. + if (split.getCacheSplitId().isPresent() + && runningCacheSplitIds.contains(split.getCacheSplitId().get().hashCode())) { + unscheduledSplits.add(split); + continue; + } + + startLeafSplit(split); + scheduledSplits++; } - recordLeafSplitsSize(); + // put back unscheduled splits + for (PrioritizedSplitRunner split : unscheduledSplits) { + taskHandle.enqueueSplit(split); + } + + return scheduledSplits; } private synchronized void addNewEntrants() { // Ignore intermediate splits when checking minimumNumberOfDrivers. - // Otherwise with (for example) minimumNumberOfDrivers = 100, 200 intermediate splits + // Otherwise, with (for example) minimumNumberOfDrivers = 100, 200 intermediate splits // and 100 leaf splits, depending on order of appearing splits, number of // simultaneously running splits may vary. If leaf splits start first, there will // be 300 running splits. If intermediate splits start first, there will be only // 200 running splits. 
- int running = allSplits.size() - intermediateSplits.size(); - for (int i = 0; i < minimumNumberOfDrivers - running; i++) { + int runningSplits = allSplits.size() - intermediateSplits.size(); + + runningSplits = addNewEntrants(runningSplits, runningCacheSplitIds); + // If we have less than the minimum number of drivers running, force start some splits even if + // they are not cached yet. This is so that the system always has some drivers running to make progress. + addNewEntrants(runningSplits, IntSet.of()); + } + + private synchronized int addNewEntrants(int runningSplits, IntSet runningCacheSplitIds) + { + Queue unscheduledSplits = new ArrayDeque<>(); + while (runningSplits < minimumNumberOfDrivers) { PrioritizedSplitRunner split = pollNextSplitWorker(); if (split == null) { break; } - splitQueuedTime.add(Duration.nanosSince(split.getCreatedNanos())); - startSplit(split); + // Do not start the split if there's already a split with the same cache id running. This is done to + // improve cache utilization.
+ if (split.getCacheSplitId().isPresent() + && runningCacheSplitIds.contains(split.getCacheSplitId().get().hashCode())) { + unscheduledSplits.add(split); + continue; + } + + startLeafSplit(split); + runningSplits++; } + // put back unscheduled splits + for (PrioritizedSplitRunner split : unscheduledSplits) { + split.getTaskHandle().enqueueSplit(split); + } + return runningSplits; } private synchronized void startIntermediateSplit(PrioritizedSplitRunner split) @@ -490,9 +561,17 @@ private synchronized void startIntermediateSplit(PrioritizedSplitRunner split) intermediateSplits.add(split); } + private synchronized void startLeafSplit(PrioritizedSplitRunner split) + { + split.getTaskHandle().splitStarted(split); + startSplit(split); + splitQueuedTime.add(Duration.nanosSince(split.getCreatedNanos())); + } + private synchronized void startSplit(PrioritizedSplitRunner split) { allSplits.add(split); + split.getCacheSplitId().ifPresent(splitId -> runningCacheSplitIds.add(splitId.hashCode())); waitingSplits.offer(split); } diff --git a/core/trino-main/src/main/java/io/trino/execution/executor/timesharing/TimeSharingTaskHandle.java b/core/trino-main/src/main/java/io/trino/execution/executor/timesharing/TimeSharingTaskHandle.java index b6b86bb3e89d..d278c598cdd8 100644 --- a/core/trino-main/src/main/java/io/trino/execution/executor/timesharing/TimeSharingTaskHandle.java +++ b/core/trino-main/src/main/java/io/trino/execution/executor/timesharing/TimeSharingTaskHandle.java @@ -173,11 +173,12 @@ public synchronized PrioritizedSplitRunner pollNextSplit() return null; } - PrioritizedSplitRunner split = queuedLeafSplits.poll(); - if (split != null) { - runningLeafSplits.add(split); - } - return split; + return queuedLeafSplits.poll(); + } + + public synchronized void splitStarted(PrioritizedSplitRunner split) + { + runningLeafSplits.add(split); } public synchronized void splitComplete(PrioritizedSplitRunner split) diff --git 
a/core/trino-main/src/main/java/io/trino/execution/scheduler/PipelinedQueryScheduler.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/PipelinedQueryScheduler.java index 7f119afc89cc..9d8048b22c8c 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/PipelinedQueryScheduler.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/PipelinedQueryScheduler.java @@ -18,6 +18,7 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableMultimap; import com.google.common.collect.Sets; +import com.google.common.graph.Traverser; import com.google.common.primitives.Ints; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.SettableFuture; @@ -30,6 +31,7 @@ import io.opentelemetry.api.trace.Tracer; import io.opentelemetry.context.Context; import io.trino.Session; +import io.trino.cache.SplitAdmissionControllerProvider; import io.trino.exchange.DirectExchangeInput; import io.trino.execution.BasicStageInfo; import io.trino.execution.BasicStageStats; @@ -98,6 +100,7 @@ import java.util.function.Predicate; import java.util.function.Supplier; import java.util.stream.Stream; +import java.util.stream.StreamSupport; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; @@ -183,6 +186,8 @@ public class PipelinedQueryScheduler private final double retryDelayScaleFactor; private final Span schedulerSpan; + private final SplitAdmissionControllerProvider splitAdmissionControllerProvider; + @GuardedBy("this") private boolean started; @@ -229,6 +234,7 @@ public PipelinedQueryScheduler( .setAttribute(TrinoAttributes.QUERY_ID, queryStateMachine.getQueryId().toString()) .startSpan(); + splitAdmissionControllerProvider = createSplitAdmissionControllerProvider(queryStateMachine.getSession(), plan); stageManager = StageManager.create( queryStateMachine, metadata, @@ -238,7 +244,8 @@ public 
PipelinedQueryScheduler( schedulerSpan, schedulerStats, plan, - summarizeTaskInfo); + summarizeTaskInfo, + splitAdmissionControllerProvider); coordinatorStagesScheduler = CoordinatorStagesScheduler.create( queryStateMachine, @@ -257,6 +264,17 @@ public PipelinedQueryScheduler( retryDelayScaleFactor = getRetryDelayScaleFactor(queryStateMachine.getSession()); } + private static SplitAdmissionControllerProvider createSplitAdmissionControllerProvider( + Session session, + SubPlan planTree) + { + Iterable iterable = Traverser.forTree(SubPlan::getChildren).breadthFirst(planTree); + List planFragments = StreamSupport.stream(iterable.spliterator(), false) + .map(SubPlan::getFragment) + .collect(toImmutableList()); + return new SplitAdmissionControllerProvider(planFragments, session); + } + @Override public synchronized void start() { @@ -331,6 +349,7 @@ private synchronized Optional createDistributedStage splitBatchSize, dynamicFilterService, tableExecuteContextManager, + splitAdmissionControllerProvider, retryPolicy, attempt); } @@ -861,6 +880,7 @@ public static DistributedStagesScheduler create( int splitBatchSize, DynamicFilterService dynamicFilterService, TableExecuteContextManager tableExecuteContextManager, + SplitAdmissionControllerProvider splitAdmissionControllerProvider, RetryPolicy retryPolicy, int attempt) { @@ -938,7 +958,8 @@ public static DistributedStagesScheduler create( splitBatchSize, dynamicFilterService, executor, - tableExecuteContextManager); + tableExecuteContextManager, + splitAdmissionControllerProvider); stageSchedulers.put(stageExecution.getStageId(), scheduler); } @@ -1044,14 +1065,15 @@ private static StageScheduler createStageScheduler( int splitBatchSize, DynamicFilterService dynamicFilterService, ScheduledExecutorService executor, - TableExecuteContextManager tableExecuteContextManager) + TableExecuteContextManager tableExecuteContextManager, + SplitAdmissionControllerProvider splitAdmissionControllerProvider) { Session session = 
queryStateMachine.getSession(); Span stageSpan = stageExecution.getStageSpan(); PlanFragment fragment = stageExecution.getFragment(); PartitioningHandle partitioningHandle = fragment.getPartitioning(); Optional partitionCount = fragment.getPartitionCount(); - Map splitSources = splitSourceFactory.createSplitSources(session, stageSpan, fragment); + Map splitSources = splitSourceFactory.createSplitSources(session, stageSpan, fragment, splitAdmissionControllerProvider); if (!splitSources.isEmpty()) { queryStateMachine.addStateChangeListener(new StateChangeListener<>() { diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/StageManager.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/StageManager.java index 0034c2713629..dfbfd709ac56 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/StageManager.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/StageManager.java @@ -20,6 +20,7 @@ import io.opentelemetry.api.trace.Span; import io.opentelemetry.api.trace.Tracer; import io.trino.Session; +import io.trino.cache.SplitAdmissionControllerProvider; import io.trino.execution.BasicStageInfo; import io.trino.execution.BasicStageStats; import io.trino.execution.NodeTaskMap; @@ -72,7 +73,8 @@ static StageManager create( Span schedulerSpan, SplitSchedulerStats schedulerStats, SubPlan planTree, - boolean summarizeTaskInfo) + boolean summarizeTaskInfo, + SplitAdmissionControllerProvider splitAdmissionControllerProvider) { Session session = queryStateMachine.getSession(); ImmutableMap.Builder stages = ImmutableMap.builder(); @@ -95,7 +97,8 @@ static StageManager create( queryStateMachine.getStateMachineExecutor(), tracer, schedulerSpan, - schedulerStats); + schedulerStats, + splitAdmissionControllerProvider); StageId stageId = stage.getStageId(); stages.put(stageId, stage); stagesInTopologicalOrder.add(stage); diff --git 
a/core/trino-main/src/main/java/io/trino/execution/scheduler/faulttolerant/EventDrivenFaultTolerantQueryScheduler.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/faulttolerant/EventDrivenFaultTolerantQueryScheduler.java index ed1113a939c2..ea42961f1124 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/faulttolerant/EventDrivenFaultTolerantQueryScheduler.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/faulttolerant/EventDrivenFaultTolerantQueryScheduler.java @@ -42,6 +42,7 @@ import io.opentelemetry.api.trace.Tracer; import io.opentelemetry.context.Context; import io.trino.Session; +import io.trino.cache.SplitAdmissionControllerProvider; import io.trino.cost.RuntimeInfoProvider; import io.trino.cost.StaticRuntimeInfoProvider; import io.trino.exchange.ExchangeContextInstance; @@ -741,6 +742,8 @@ private static class Scheduler private boolean queryOutputSet; + private final SplitAdmissionControllerProvider splitAdmissionControllerProvider; + public Scheduler( QueryStateMachine queryStateMachine, Metadata metadata, @@ -821,6 +824,9 @@ public Scheduler( } planInTopologicalOrder = sortPlanInTopologicalOrder(plan); + splitAdmissionControllerProvider = new SplitAdmissionControllerProvider( + planInTopologicalOrder.stream().map(SubPlan::getFragment).collect(toImmutableList()), + queryStateMachine.getSession()); noEventsStopwatch.start(); } @@ -1414,7 +1420,8 @@ private void createStageExecution( queryStateMachine.getStateMachineExecutor(), tracer, schedulerSpan, - schedulerStats); + schedulerStats, + splitAdmissionControllerProvider); closer.register(stage::abort); stageRegistry.add(stage); stage.addFinalStageInfoListener(_ -> queryStateMachine.updateQueryInfo(Optional.ofNullable(stageRegistry.getStageInfo()))); @@ -1452,7 +1459,8 @@ private void createStageExecution( sourceExchanges, partitioningSchemeFactory.get(fragment.getPartitioning(), fragment.getPartitionCount()), stage::recordGetSplitTime, - 
outputDataSizeEstimates.buildOrThrow())); + outputDataSizeEstimates.buildOrThrow(), + splitAdmissionControllerProvider)); FaultTolerantPartitioningScheme sinkPartitioningScheme = partitioningSchemeFactory.get( fragment.getOutputPartitioningScheme().getPartitioning().getHandle(), diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/faulttolerant/EventDrivenTaskSourceFactory.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/faulttolerant/EventDrivenTaskSourceFactory.java index 452893ce6279..801097a18d7a 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/faulttolerant/EventDrivenTaskSourceFactory.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/faulttolerant/EventDrivenTaskSourceFactory.java @@ -18,6 +18,7 @@ import com.google.inject.Inject; import io.opentelemetry.api.trace.Span; import io.trino.Session; +import io.trino.cache.SplitAdmissionControllerProvider; import io.trino.execution.ForQueryExecution; import io.trino.execution.QueryManagerConfig; import io.trino.execution.TableExecuteContextManager; @@ -117,7 +118,8 @@ public EventDrivenTaskSource create( Map sourceExchanges, FaultTolerantPartitioningScheme sourcePartitioningScheme, LongConsumer getSplitTimeRecorder, - Map outputDataSizeEstimates) + Map outputDataSizeEstimates, + SplitAdmissionControllerProvider splitAdmissionControllerProvider) { ImmutableSetMultimap.Builder remoteSources = ImmutableSetMultimap.builder(); for (RemoteSourceNode remoteSource : fragment.getRemoteSourceNodes()) { @@ -132,7 +134,7 @@ public EventDrivenTaskSource create( tableExecuteContextManager, sourceExchanges, remoteSources.build(), - () -> splitSourceFactory.createSplitSources(session, stageSpan, fragment), + () -> splitSourceFactory.createSplitSources(session, stageSpan, fragment, splitAdmissionControllerProvider), createSplitAssigner( session, fragment, diff --git a/core/trino-main/src/main/java/io/trino/metadata/Split.java 
b/core/trino-main/src/main/java/io/trino/metadata/Split.java index 79851dde5630..4e39e6f17b5e 100644 --- a/core/trino-main/src/main/java/io/trino/metadata/Split.java +++ b/core/trino-main/src/main/java/io/trino/metadata/Split.java @@ -19,15 +19,19 @@ import com.google.common.collect.ImmutableMap; import io.trino.spi.HostAddress; import io.trino.spi.SplitWeight; +import io.trino.spi.cache.CacheSplitId; import io.trino.spi.connector.CatalogHandle; import io.trino.spi.connector.ConnectorSplit; import java.util.List; import java.util.Map; +import java.util.Optional; import static com.google.common.base.MoreObjects.firstNonNull; import static com.google.common.base.MoreObjects.toStringHelper; +import static io.airlift.slice.SizeOf.estimatedSizeOf; import static io.airlift.slice.SizeOf.instanceSize; +import static io.airlift.slice.SizeOf.sizeOf; import static java.util.Objects.requireNonNull; public final class Split @@ -36,14 +40,40 @@ public final class Split private final CatalogHandle catalogHandle; private final ConnectorSplit connectorSplit; + private final Optional cacheSplitId; + private final Optional> addresses; + /** + * true if the split is executed on its preferred node (from its {@link ConnectorSplit#getAddresses()}. 
+ */ + private final boolean splitAddressEnforced; + + public Split(CatalogHandle catalogHandle, ConnectorSplit connectorSplit) + { + this(catalogHandle, connectorSplit, Optional.empty(), Optional.empty(), false); + } @JsonCreator public Split( @JsonProperty("catalogHandle") CatalogHandle catalogHandle, - @JsonProperty("connectorSplit") ConnectorSplit connectorSplit) + @JsonProperty("connectorSplit") ConnectorSplit connectorSplit, + @JsonProperty("cacheSplitId") Optional cacheSplitId, + @JsonProperty("splitAddressEnforced") boolean splitAddressEnforced) + { + this(catalogHandle, connectorSplit, cacheSplitId, Optional.empty(), splitAddressEnforced); + } + + public Split( + CatalogHandle catalogHandle, + ConnectorSplit connectorSplit, + Optional cacheSplitId, + Optional> addresses, + boolean splitAddressEnforced) { this.catalogHandle = requireNonNull(catalogHandle, "catalogHandle is null"); this.connectorSplit = requireNonNull(connectorSplit, "connectorSplit is null"); + this.cacheSplitId = requireNonNull(cacheSplitId, "cacheSplitId is null"); + this.addresses = requireNonNull(addresses, "addresses is null"); + this.splitAddressEnforced = splitAddressEnforced; } @JsonProperty @@ -58,33 +88,57 @@ public ConnectorSplit getConnectorSplit() return connectorSplit; } + @JsonProperty + public Optional getCacheSplitId() + { + return cacheSplitId; + } + @JsonIgnore public Map getInfo() { return firstNonNull(connectorSplit.getSplitInfo(), ImmutableMap.of()); } + // do not serialize addresses as they are not needed on workers + @JsonIgnore public List getAddresses() { - return connectorSplit.getAddresses(); + return addresses.orElse(connectorSplit.getAddresses()); } + // do not serialize remotelyAccessible flag as it is not needed on workers + @JsonIgnore public boolean isRemotelyAccessible() { return connectorSplit.isRemotelyAccessible(); } + @JsonProperty + public boolean isSplitAddressEnforced() + { + return splitAddressEnforced; + } + public SplitWeight getSplitWeight() { 
return connectorSplit.getSplitWeight(); } + public Split withSplitAddressEnforced(boolean splitAddressEnforced) + { + return new Split(this.catalogHandle, this.connectorSplit, this.cacheSplitId, this.addresses, splitAddressEnforced); + } + @Override public String toString() { return toStringHelper(this) .add("catalogHandle", catalogHandle) .add("connectorSplit", connectorSplit) + .add("cacheSplitId", cacheSplitId) + .add("addresses", addresses) + .add("splitAddressEnforced", splitAddressEnforced) .toString(); } @@ -92,6 +146,9 @@ public long getRetainedSizeInBytes() { return INSTANCE_SIZE + catalogHandle.getRetainedSizeInBytes() - + connectorSplit.getRetainedSizeInBytes(); + + connectorSplit.getRetainedSizeInBytes() + + sizeOf(cacheSplitId, CacheSplitId::getRetainedSizeInBytes) + + sizeOf(addresses, value -> estimatedSizeOf(value, HostAddress::getRetainedSizeInBytes)) + + sizeOf(splitAddressEnforced); } } diff --git a/core/trino-main/src/main/java/io/trino/operator/DriverContext.java b/core/trino-main/src/main/java/io/trino/operator/DriverContext.java index e087c609bb39..e7402b57aa6a 100644 --- a/core/trino-main/src/main/java/io/trino/operator/DriverContext.java +++ b/core/trino-main/src/main/java/io/trino/operator/DriverContext.java @@ -20,6 +20,7 @@ import io.airlift.units.DataSize; import io.airlift.units.Duration; import io.trino.Session; +import io.trino.cache.CacheDriverContext; import io.trino.execution.TaskId; import io.trino.memory.QueryContextVisitor; import io.trino.memory.context.MemoryTrackingContext; @@ -37,6 +38,7 @@ import java.util.concurrent.atomic.AtomicReference; import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkState; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.Iterables.getFirst; import static com.google.common.collect.Iterables.getLast; @@ -53,6 +55,7 @@ public class DriverContext { private final 
PipelineContext pipelineContext; + private final AtomicReference alternativeId = new AtomicReference<>(); private final Executor notificationExecutor; private final ScheduledExecutorService yieldExecutor; private final ScheduledExecutorService timeoutExecutor; @@ -80,6 +83,7 @@ public class DriverContext private final List operatorContexts = new CopyOnWriteArrayList<>(); private final long splitWeight; + private final AtomicReference> cacheDriverContext = new AtomicReference<>(Optional.empty()); public DriverContext( PipelineContext pipelineContext, @@ -138,6 +142,17 @@ public PipelineContext getPipelineContext() return pipelineContext; } + public void setAlternativeId(int alternativeId) + { + checkState(this.alternativeId.get() == null, "alternativeId is already set"); + this.alternativeId.set(alternativeId); + } + + public int getAlternativeId() + { + return Optional.ofNullable(alternativeId.get()).orElse(0); + } + public Session getSession() { return pipelineContext.getSession(); @@ -445,6 +460,18 @@ public List acceptChildren(QueryContextVisitor visitor, C contex .collect(toList()); } + public Optional getCacheDriverContext() + { + return cacheDriverContext.get(); + } + + public void setCacheDriverContext(CacheDriverContext cacheDriverContext) + { + if (!this.cacheDriverContext.compareAndSet(Optional.empty(), Optional.of(cacheDriverContext))) { + throw new IllegalStateException("CacheDriverContext is already set"); + } + } + public ScheduledExecutorService getYieldExecutor() { return yieldExecutor; diff --git a/core/trino-main/src/main/java/io/trino/operator/DriverFactory.java b/core/trino-main/src/main/java/io/trino/operator/DriverFactory.java index 3456c5800097..1b65974436d2 100644 --- a/core/trino-main/src/main/java/io/trino/operator/DriverFactory.java +++ b/core/trino-main/src/main/java/io/trino/operator/DriverFactory.java @@ -13,146 +13,34 @@ */ package io.trino.operator; -import com.google.common.collect.ImmutableList; -import 
com.google.errorprone.annotations.concurrent.GuardedBy; +import io.trino.execution.ScheduledSplit; import io.trino.sql.planner.plan.PlanNodeId; -import jakarta.annotation.Nullable; -import java.util.ArrayList; -import java.util.List; import java.util.Optional; import java.util.OptionalInt; -import static com.google.common.base.Preconditions.checkArgument; -import static com.google.common.base.Preconditions.checkState; -import static com.google.common.collect.ImmutableList.toImmutableList; -import static java.util.Objects.requireNonNull; - -public class DriverFactory +public interface DriverFactory { - private final int pipelineId; - private final boolean inputDriver; - private final boolean outputDriver; - private final Optional sourceId; - private final OptionalInt driverInstances; + int getPipelineId(); + + boolean isInputDriver(); - // must synchronize between createDriver() and noMoreDrivers(), but isNoMoreDrivers() is safe without synchronizing - @GuardedBy("this") - private volatile boolean noMoreDrivers; - private volatile List operatorFactories; + boolean isOutputDriver(); - public DriverFactory(int pipelineId, boolean inputDriver, boolean outputDriver, List operatorFactories, OptionalInt driverInstances) - { - this.pipelineId = pipelineId; - this.inputDriver = inputDriver; - this.outputDriver = outputDriver; - this.operatorFactories = ImmutableList.copyOf(requireNonNull(operatorFactories, "operatorFactories is null")); - checkArgument(!operatorFactories.isEmpty(), "There must be at least one operator"); - this.driverInstances = requireNonNull(driverInstances, "driverInstances is null"); + OptionalInt getDriverInstances(); - List sourceIds = operatorFactories.stream() - .filter(SourceOperatorFactory.class::isInstance) - .map(SourceOperatorFactory.class::cast) - .map(SourceOperatorFactory::getSourceId) - .collect(toImmutableList()); - checkArgument(sourceIds.size() <= 1, "Expected at most one source operator in driver factory, but found %s", sourceIds); - 
this.sourceId = sourceIds.isEmpty() ? Optional.empty() : Optional.of(sourceIds.get(0)); - } + Driver createDriver(DriverContext driverContext, Optional split); - public int getPipelineId() - { - return pipelineId; - } + void noMoreDrivers(); - public boolean isInputDriver() - { - return inputDriver; - } + boolean isNoMoreDrivers(); - public boolean isOutputDriver() - { - return outputDriver; - } + void localPlannerComplete(); /** * return the sourceId of this DriverFactory. * A DriverFactory doesn't always have source node. * For example, ValuesNode is not a source node. */ - public Optional getSourceId() - { - return sourceId; - } - - public OptionalInt getDriverInstances() - { - return driverInstances; - } - - @Nullable - public List getOperatorFactories() - { - return operatorFactories; - } - - public Driver createDriver(DriverContext driverContext) - { - requireNonNull(driverContext, "driverContext is null"); - List operators = new ArrayList<>(operatorFactories.size()); - try { - synchronized (this) { - // must check noMoreDrivers after acquiring the lock - checkState(!noMoreDrivers, "noMoreDrivers is already set"); - for (OperatorFactory operatorFactory : operatorFactories) { - Operator operator = operatorFactory.createOperator(driverContext); - operators.add(operator); - } - } - // Driver creation can continue without holding the lock - return Driver.createDriver(driverContext, operators); - } - catch (Throwable failure) { - for (Operator operator : operators) { - try { - operator.close(); - } - catch (Throwable closeFailure) { - if (failure != closeFailure) { - failure.addSuppressed(closeFailure); - } - } - } - for (OperatorContext operatorContext : driverContext.getOperatorContexts()) { - try { - operatorContext.destroy(); - } - catch (Throwable destroyFailure) { - if (failure != destroyFailure) { - failure.addSuppressed(destroyFailure); - } - } - } - driverContext.failed(failure); - throw failure; - } - } - - public synchronized void noMoreDrivers() - { - 
if (noMoreDrivers) { - return; - } - for (OperatorFactory operatorFactory : operatorFactories) { - operatorFactory.noMoreOperators(); - } - operatorFactories = null; - noMoreDrivers = true; - } - - // no need to synchronize when just checking the boolean flag - @SuppressWarnings("GuardedBy") - public boolean isNoMoreDrivers() - { - return noMoreDrivers; - } + Optional getSourceId(); } diff --git a/core/trino-main/src/main/java/io/trino/operator/OperatorContext.java b/core/trino-main/src/main/java/io/trino/operator/OperatorContext.java index 6b268c5849a0..b1a63e714bfb 100644 --- a/core/trino-main/src/main/java/io/trino/operator/OperatorContext.java +++ b/core/trino-main/src/main/java/io/trino/operator/OperatorContext.java @@ -243,7 +243,7 @@ public void setLatestConnectorMetrics(Metrics metrics) public void setPipelineOperatorMetrics(Metrics metrics) { - getDriverContext().getPipelineContext().setPipelineOperatorMetrics(operatorId, metrics); + getDriverContext().getPipelineContext().setPipelineOperatorMetrics(operatorId, driverContext.getAlternativeId(), metrics); } Optional> getFinishedFuture() @@ -530,6 +530,7 @@ public OperatorStats getOperatorStats() return new OperatorStats( driverContext.getTaskId().getStageId().getId(), driverContext.getPipelineContext().getPipelineId(), + driverContext.getAlternativeId(), operatorId, planNodeId, operatorType, diff --git a/core/trino-main/src/main/java/io/trino/operator/OperatorDriverFactory.java b/core/trino-main/src/main/java/io/trino/operator/OperatorDriverFactory.java new file mode 100644 index 000000000000..205e31047344 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/operator/OperatorDriverFactory.java @@ -0,0 +1,174 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.operator; + +import com.google.common.collect.ImmutableList; +import com.google.errorprone.annotations.concurrent.GuardedBy; +import io.trino.execution.ScheduledSplit; +import io.trino.sql.planner.plan.PlanNodeId; + +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; +import java.util.OptionalInt; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static java.util.Objects.requireNonNull; + +/** + * {@link DriverFactory} that has predefined list of {@link OperatorFactory}ies that does not depend on a particular split. 
+ */ +public class OperatorDriverFactory + implements DriverFactory +{ + private final int pipelineId; + private final boolean inputDriver; + private final boolean outputDriver; + private final Optional sourceId; + private final OptionalInt driverInstances; + + // must synchronize between createDriver() and noMoreDrivers(), but isNoMoreDrivers() is safe without synchronizing + @GuardedBy("this") + private volatile boolean noMoreDrivers; + private volatile List operatorFactories; + + public OperatorDriverFactory(int pipelineId, boolean inputDriver, boolean outputDriver, List operatorFactories, OptionalInt driverInstances) + { + this.pipelineId = pipelineId; + this.inputDriver = inputDriver; + this.outputDriver = outputDriver; + this.operatorFactories = ImmutableList.copyOf(requireNonNull(operatorFactories, "operatorFactories is null")); + checkArgument(!operatorFactories.isEmpty(), "There must be at least one operator"); + this.driverInstances = requireNonNull(driverInstances, "driverInstances is null"); + + List sourceIds = operatorFactories.stream() + .filter(SourceOperatorFactory.class::isInstance) + .map(SourceOperatorFactory.class::cast) + .map(SourceOperatorFactory::getSourceId) + .collect(toImmutableList()); + checkArgument(sourceIds.size() <= 1, "Expected at most one source operator in driver factory, but found %s", sourceIds); + this.sourceId = sourceIds.isEmpty() ? Optional.empty() : Optional.of(sourceIds.get(0)); + } + + @Override + public int getPipelineId() + { + return pipelineId; + } + + @Override + public boolean isInputDriver() + { + return inputDriver; + } + + @Override + public boolean isOutputDriver() + { + return outputDriver; + } + + /** + * return the sourceId of this DriverFactory. + * A DriverFactory doesn't always have source node. + * For example, ValuesNode is not a source node. 
+ */ + @Override + public Optional getSourceId() + { + return sourceId; + } + + @Override + public OptionalInt getDriverInstances() + { + return driverInstances; + } + + @Override + public Driver createDriver(DriverContext driverContext, Optional split) + { + requireNonNull(driverContext, "driverContext is null"); + List operators = new ArrayList<>(operatorFactories.size()); + try { + synchronized (this) { + // must check noMoreDrivers after acquiring the lock + checkState(!noMoreDrivers, "noMoreDrivers is already set"); + for (OperatorFactory operatorFactory : operatorFactories) { + Operator operator = operatorFactory.createOperator(driverContext); + operators.add(operator); + } + } + // Driver creation can continue without holding the lock + return Driver.createDriver(driverContext, operators); + } + catch (Throwable failure) { + for (Operator operator : operators) { + try { + operator.close(); + } + catch (Throwable closeFailure) { + if (failure != closeFailure) { + failure.addSuppressed(closeFailure); + } + } + } + for (OperatorContext operatorContext : driverContext.getOperatorContexts()) { + try { + operatorContext.destroy(); + } + catch (Throwable destroyFailure) { + if (failure != destroyFailure) { + failure.addSuppressed(destroyFailure); + } + } + } + driverContext.failed(failure); + throw failure; + } + } + + @Override + public synchronized void noMoreDrivers() + { + if (noMoreDrivers) { + return; + } + for (OperatorFactory operatorFactory : operatorFactories) { + operatorFactory.noMoreOperators(); + } + operatorFactories = null; + noMoreDrivers = true; + } + + @Override + // no need to synchronize when just checking the boolean flag + @SuppressWarnings("GuardedBy") + public boolean isNoMoreDrivers() + { + return noMoreDrivers; + } + + @Override + public void localPlannerComplete() + { + operatorFactories + .stream() + .filter(LocalPlannerAware.class::isInstance) + .map(LocalPlannerAware.class::cast) + .forEach(LocalPlannerAware::localPlannerComplete); + 
} +} diff --git a/core/trino-main/src/main/java/io/trino/operator/OperatorStats.java b/core/trino-main/src/main/java/io/trino/operator/OperatorStats.java index a0ec50f4e85d..9b9d111a6eeb 100644 --- a/core/trino-main/src/main/java/io/trino/operator/OperatorStats.java +++ b/core/trino-main/src/main/java/io/trino/operator/OperatorStats.java @@ -38,6 +38,7 @@ public class OperatorStats { private final int stageId; private final int pipelineId; + private final int alternativeId; private final int operatorId; private final PlanNodeId planNodeId; private final String operatorType; @@ -93,6 +94,7 @@ public class OperatorStats public OperatorStats( @JsonProperty("stageId") int stageId, @JsonProperty("pipelineId") int pipelineId, + @JsonProperty("alternativeId") int alternativeId, @JsonProperty("operatorId") int operatorId, @JsonProperty("planNodeId") PlanNodeId planNodeId, @JsonProperty("operatorType") String operatorType, @@ -146,6 +148,7 @@ public OperatorStats( { this.stageId = stageId; this.pipelineId = pipelineId; + this.alternativeId = alternativeId; checkArgument(operatorId >= 0, "operatorId is negative"); this.operatorId = operatorId; @@ -214,6 +217,12 @@ public int getPipelineId() return pipelineId; } + @JsonProperty + public int getAlternativeId() + { + return alternativeId; + } + @JsonProperty public int getOperatorId() { @@ -513,8 +522,9 @@ private OperatorStats add(Iterable operators, Optional p ImmutableList.Builder operatorInfos = ImmutableList.builder(); for (OperatorStats operator : operators) { checkArgument(operator.getOperatorId() == operatorId, "Expected operatorId to be %s but was %s", operatorId, operator.getOperatorId()); + checkArgument(operator.getPipelineId() == pipelineId, "Expected pipelineId to be %s but was %s", pipelineId, operator.getPipelineId()); + checkArgument(operator.getAlternativeId() == alternativeId, "Expected alternativeId to be %s but was %s", alternativeId, operator.getAlternativeId()); 
checkArgument(operator.getOperatorType().equals(operatorType), "Expected operatorType to be %s but was %s", operatorType, operator.getOperatorType()); - totalDrivers += operator.totalDrivers; addInputCalls += operator.getAddInputCalls(); @@ -572,6 +582,7 @@ private OperatorStats add(Iterable operators, Optional p return new OperatorStats( stageId, pipelineId, + alternativeId, operatorId, planNodeId, operatorType, @@ -651,6 +662,7 @@ public OperatorStats summarize() return new OperatorStats( stageId, pipelineId, + alternativeId, operatorId, planNodeId, operatorType, @@ -696,6 +708,7 @@ public OperatorStats withPipelineMetrics(Metrics pipelineMetrics) return new OperatorStats( stageId, pipelineId, + alternativeId, operatorId, planNodeId, operatorType, diff --git a/core/trino-main/src/main/java/io/trino/operator/PipelineContext.java b/core/trino-main/src/main/java/io/trino/operator/PipelineContext.java index 9bc61bc88708..c3bc21c03c04 100644 --- a/core/trino-main/src/main/java/io/trino/operator/PipelineContext.java +++ b/core/trino-main/src/main/java/io/trino/operator/PipelineContext.java @@ -103,9 +103,9 @@ public class PipelineContext private final AtomicLong physicalWrittenDataSize = new AtomicLong(); - private final ConcurrentMap operatorSummaries = new ConcurrentHashMap<>(); + private final ConcurrentMap operatorSummaries = new ConcurrentHashMap<>(); // pre-merged metrics which are shared among instances of given operator within pipeline - private final ConcurrentMap pipelineOperatorMetrics = new ConcurrentHashMap<>(); + private final ConcurrentMap pipelineOperatorMetrics = new ConcurrentHashMap<>(); private final MemoryTrackingContext pipelineMemoryContext; @@ -182,9 +182,9 @@ public void splitsAdded(int count, long weightSum) } } - public void setPipelineOperatorMetrics(int operatorId, Metrics metrics) + public void setPipelineOperatorMetrics(int operatorId, int alternativeId, Metrics metrics) { - pipelineOperatorMetrics.put(operatorId, metrics); + 
pipelineOperatorMetrics.put(new AlternativeOperatorId(operatorId, alternativeId), metrics); } public void driverFinished(DriverContext driverContext) @@ -216,8 +216,9 @@ public void driverFinished(DriverContext driverContext) // merge the operator stats into the operator summary List operators = driverStats.getOperatorStats(); for (OperatorStats operator : operators) { - Metrics pipelineLevelMetrics = pipelineOperatorMetrics.getOrDefault(operator.getOperatorId(), Metrics.EMPTY); - operatorSummaries.merge(operator.getOperatorId(), operator, (first, second) -> first.addFillingPipelineMetrics(second, pipelineLevelMetrics)); + AlternativeOperatorId alternativeOperatorId = new AlternativeOperatorId(operator.getOperatorId(), operator.getAlternativeId()); + Metrics pipelineLevelMetrics = pipelineOperatorMetrics.getOrDefault(alternativeOperatorId, Metrics.EMPTY); + operatorSummaries.merge(alternativeOperatorId, operator, (first, second) -> first.addFillingPipelineMetrics(second, pipelineLevelMetrics)); } physicalInputDataSize.update(driverStats.getPhysicalInputDataSize().toBytes()); @@ -412,9 +413,9 @@ public PipelineStats getPipelineStats() boolean hasUnfinishedDrivers = false; boolean unfinishedDriversFullyBlocked = true; - TreeMap operatorSummaries = new TreeMap<>(this.operatorSummaries); + TreeMap operatorSummaries = new TreeMap<>(this.operatorSummaries); // Expect the same number of operators as existing summaries, with one operator per driver context in the resulting multimap - ListMultimap runningOperators = ArrayListMultimap.create(operatorSummaries.size(), driverContexts.size()); + ListMultimap runningOperators = ArrayListMultimap.create(operatorSummaries.size(), driverContexts.size()); ImmutableList.Builder drivers = ImmutableList.builderWithExpectedSize(driverContexts.size()); for (DriverContext driverContext : driverContexts) { DriverStats driverStats = driverContext.getDriverStats(); @@ -435,7 +436,7 @@ public PipelineStats getPipelineStats() totalBlockedTime 
+= driverStats.getTotalBlockedTime().roundTo(NANOSECONDS); for (OperatorStats operatorStats : driverStats.getOperatorStats()) { - runningOperators.put(operatorStats.getOperatorId(), operatorStats); + runningOperators.put(new AlternativeOperatorId(operatorStats.getOperatorId(), operatorStats.getAlternativeId()), operatorStats); } physicalInputDataSize += driverStats.getPhysicalInputDataSize().toBytes(); @@ -462,7 +463,7 @@ public PipelineStats getPipelineStats() } // Computes the combined stats from existing completed operators and those still running - BiFunction combineOperatorStats = (operatorId, current) -> { + BiFunction combineOperatorStats = (operatorId, current) -> { List runningStats = runningOperators.get(operatorId); if (runningStats.isEmpty()) { return current; @@ -482,7 +483,7 @@ else if (pipelineLevelMetrics != Metrics.EMPTY) { return combined; } }; - for (Integer operatorId : runningOperators.keySet()) { + for (AlternativeOperatorId operatorId : runningOperators.keySet()) { operatorSummaries.compute(operatorId, combineOperatorStats); } @@ -671,4 +672,18 @@ public PipelineStatus build() return new PipelineStatus(queuedDrivers, runningDrivers, blockedDrivers, queuedPartitionedSplits, queuedPartitionedSplitsWeight, runningPartitionedSplits, runningPartitionedSplitsWeight); } } + + private record AlternativeOperatorId(int operatorId, int alternativeId) + implements Comparable + { + @Override + public int compareTo(AlternativeOperatorId o) + { + if (alternativeId != o.alternativeId) { + return alternativeId - o.alternativeId; + } + + return operatorId - o.operatorId; + } + } } diff --git a/core/trino-main/src/main/java/io/trino/operator/ScanFilterAndProjectOperator.java b/core/trino-main/src/main/java/io/trino/operator/ScanFilterAndProjectOperator.java index ba154d044e0b..59126ef5d409 100644 --- a/core/trino-main/src/main/java/io/trino/operator/ScanFilterAndProjectOperator.java +++ 
b/core/trino-main/src/main/java/io/trino/operator/ScanFilterAndProjectOperator.java @@ -57,6 +57,7 @@ import static com.google.common.base.Preconditions.checkState; import static com.google.common.util.concurrent.MoreExecutors.directExecutor; import static io.airlift.concurrent.MoreFutures.toListenableFuture; +import static io.trino.cache.CacheDriverContext.getDynamicFilter; import static io.trino.memory.context.AggregatedMemoryContext.newSimpleAggregatedMemoryContext; import static io.trino.operator.PageUtils.recordMaterializedBytes; import static io.trino.operator.WorkProcessor.TransformationState.finished; @@ -500,7 +501,7 @@ public WorkProcessorSourceOperator create( pageProcessor.apply(dynamicFilter), table, columns, - dynamicFilter, + getDynamicFilter(operatorContext, dynamicFilter), types, minOutputPageSize, minOutputPageRowCount); diff --git a/core/trino-main/src/main/java/io/trino/operator/TableScanOperator.java b/core/trino-main/src/main/java/io/trino/operator/TableScanOperator.java index 277a71af1dba..0d6bc3de903a 100644 --- a/core/trino-main/src/main/java/io/trino/operator/TableScanOperator.java +++ b/core/trino-main/src/main/java/io/trino/operator/TableScanOperator.java @@ -41,6 +41,7 @@ import static com.google.common.base.Preconditions.checkState; import static com.google.common.util.concurrent.MoreExecutors.directExecutor; import static io.airlift.concurrent.MoreFutures.toListenableFuture; +import static io.trino.cache.CacheDriverContext.getDynamicFilter; import static java.util.Objects.requireNonNull; public class TableScanOperator @@ -93,7 +94,7 @@ public SourceOperator createOperator(DriverContext driverContext) pageSourceProvider, table, columns, - dynamicFilter); + getDynamicFilter(operatorContext, dynamicFilter)); } @Override diff --git a/core/trino-main/src/main/java/io/trino/operator/index/IndexBuildDriverFactoryProvider.java b/core/trino-main/src/main/java/io/trino/operator/index/IndexBuildDriverFactoryProvider.java index 
221c7fe27b75..39ba4a682af4 100644 --- a/core/trino-main/src/main/java/io/trino/operator/index/IndexBuildDriverFactoryProvider.java +++ b/core/trino-main/src/main/java/io/trino/operator/index/IndexBuildDriverFactoryProvider.java @@ -14,7 +14,7 @@ package io.trino.operator.index; import com.google.common.collect.ImmutableList; -import io.trino.operator.DriverFactory; +import io.trino.operator.OperatorDriverFactory; import io.trino.operator.OperatorFactory; import io.trino.spi.Page; import io.trino.spi.type.Type; @@ -73,10 +73,10 @@ public List getOutputTypes() return outputTypes; } - public DriverFactory createSnapshot(int pipelineId, IndexSnapshotBuilder indexSnapshotBuilder) + public OperatorDriverFactory createSnapshot(int pipelineId, IndexSnapshotBuilder indexSnapshotBuilder) { checkArgument(indexSnapshotBuilder.getOutputTypes().equals(outputTypes)); - return new DriverFactory( + return new OperatorDriverFactory( pipelineId, inputDriver, false, @@ -87,7 +87,7 @@ public DriverFactory createSnapshot(int pipelineId, IndexSnapshotBuilder indexSn OptionalInt.empty()); } - public DriverFactory createStreaming(PageBuffer pageBuffer, Page indexKeyTuple) + public OperatorDriverFactory createStreaming(PageBuffer pageBuffer, Page indexKeyTuple) { ImmutableList.Builder operatorFactories = ImmutableList.builder() .addAll(coreOperatorFactories); @@ -99,6 +99,6 @@ public DriverFactory createStreaming(PageBuffer pageBuffer, Page indexKeyTuple) operatorFactories.add(new PageBufferOperatorFactory(outputOperatorId, planNodeId, pageBuffer, "IndexBuilder")); - return new DriverFactory(pipelineId, inputDriver, false, operatorFactories.build(), OptionalInt.empty()); + return new OperatorDriverFactory(pipelineId, inputDriver, false, operatorFactories.build(), OptionalInt.empty()); } } diff --git a/core/trino-main/src/main/java/io/trino/operator/index/IndexLoader.java b/core/trino-main/src/main/java/io/trino/operator/index/IndexLoader.java index 3a7bb3e5e997..22481748a334 100644 --- 
a/core/trino-main/src/main/java/io/trino/operator/index/IndexLoader.java +++ b/core/trino-main/src/main/java/io/trino/operator/index/IndexLoader.java @@ -24,8 +24,8 @@ import io.trino.execution.SplitAssignment; import io.trino.metadata.Split; import io.trino.operator.Driver; -import io.trino.operator.DriverFactory; import io.trino.operator.FlatHashStrategyCompiler; +import io.trino.operator.OperatorDriverFactory; import io.trino.operator.PagesIndex; import io.trino.operator.PipelineContext; import io.trino.operator.TaskContext; @@ -42,6 +42,7 @@ import java.util.ArrayList; import java.util.List; +import java.util.Optional; import java.util.OptionalInt; import java.util.Set; import java.util.concurrent.BlockingQueue; @@ -234,8 +235,8 @@ private IndexedData streamIndexDataForSingleKey(UpdateRequest updateRequest) Page indexKeyTuple = updateRequest.getPage().getRegion(0, 1); PageBuffer pageBuffer = new PageBuffer(100); - DriverFactory driverFactory = indexBuildDriverFactoryProvider.createStreaming(pageBuffer, indexKeyTuple); - Driver driver = driverFactory.createDriver(pipelineContext.addDriverContext()); + OperatorDriverFactory driverFactory = indexBuildDriverFactoryProvider.createStreaming(pageBuffer, indexKeyTuple); + Driver driver = driverFactory.createDriver(pipelineContext.addDriverContext(), Optional.empty()); PageRecordSet pageRecordSet = new PageRecordSet(keyTypes, indexKeyTuple); PlanNodeId planNodeId = driverFactory.getSourceId().get(); @@ -271,7 +272,7 @@ private synchronized void initializeStateIfNecessary() @NotThreadSafe private static class IndexSnapshotLoader { - private final DriverFactory driverFactory; + private final OperatorDriverFactory driverFactory; private final PipelineContext pipelineContext; private final Set lookupSourceInputChannels; private final Set allInputChannels; @@ -329,7 +330,7 @@ public boolean load(List requests) UnloadedIndexKeyRecordSet recordSetForLookupSource = new UnloadedIndexKeyRecordSet(pipelineContext.getSession(), 
indexSnapshotReference.get(), lookupSourceInputChannels, indexTypes, requests, hashStrategyCompiler); // Drive index lookup to produce the output (landing in indexSnapshotBuilder) - try (Driver driver = driverFactory.createDriver(pipelineContext.addDriverContext())) { + try (Driver driver = driverFactory.createDriver(pipelineContext.addDriverContext(), Optional.empty())) { PlanNodeId sourcePlanNodeId = driverFactory.getSourceId().get(); ScheduledSplit split = new ScheduledSplit(0, sourcePlanNodeId, new Split(INDEX_CATALOG_HANDLE, new IndexSplit(recordSetForLookupSource))); driver.updateSplitAssignment(new SplitAssignment(sourcePlanNodeId, ImmutableSet.of(split), true)); diff --git a/core/trino-main/src/main/java/io/trino/server/DynamicFilterService.java b/core/trino-main/src/main/java/io/trino/server/DynamicFilterService.java index d24691b1313a..90366469ddd1 100644 --- a/core/trino-main/src/main/java/io/trino/server/DynamicFilterService.java +++ b/core/trino-main/src/main/java/io/trino/server/DynamicFilterService.java @@ -52,6 +52,7 @@ import io.trino.sql.planner.plan.DynamicFilterId; import io.trino.sql.planner.plan.DynamicFilterSourceNode; import io.trino.sql.planner.plan.JoinNode; +import io.trino.sql.planner.plan.LoadCachedDataPlanNode; import io.trino.sql.planner.plan.PlanNode; import io.trino.sql.planner.plan.SemiJoinNode; import org.roaringbitmap.RoaringBitmap; @@ -95,6 +96,7 @@ import static io.trino.spi.predicate.Domain.union; import static io.trino.sql.DynamicFilters.extractDynamicFilters; import static io.trino.sql.DynamicFilters.extractSourceSymbols; +import static io.trino.sql.ir.IrUtils.extractDisjuncts; import static io.trino.sql.planner.DomainCoercer.applySaturatedCasts; import static io.trino.sql.planner.ExpressionExtractor.extractExpressions; import static io.trino.sql.planner.SystemPartitioningHandle.SOURCE_DISTRIBUTION; @@ -418,10 +420,23 @@ public static Set getOutboundDynamicFilters(PlanFragment plan) { // dynamic filters which are consumed by 
the given stage but produced by a different stage return ImmutableSet.copyOf(difference( - getConsumedDynamicFilters(plan.getRoot()), + union(getConsumedDynamicFilters(plan.getRoot()), getCacheDynamicFilters(plan.getRoot())), getProducedDynamicFilters(plan.getRoot()))); } + @VisibleForTesting + static Set getCacheDynamicFilters(PlanNode planNode) + { + return PlanNodeSearcher.searchFrom(planNode) + .whereIsInstanceOfAny(LoadCachedDataPlanNode.class) + .findAll().stream() + .map(LoadCachedDataPlanNode.class::cast) + .flatMap(node -> extractDisjuncts(node.getDynamicFilterDisjuncts()).stream()) + .flatMap(expression -> extractDynamicFilters(expression).getDynamicConjuncts().stream()) + .map(DynamicFilters.Descriptor::getId) + .collect(toImmutableSet()); + } + @VisibleForTesting Optional getSummary(QueryId queryId, DynamicFilterId filterId) { diff --git a/core/trino-main/src/main/java/io/trino/server/HttpRemoteTaskFactory.java b/core/trino-main/src/main/java/io/trino/server/HttpRemoteTaskFactory.java index f7d4deb214f9..5183c59808fa 100644 --- a/core/trino-main/src/main/java/io/trino/server/HttpRemoteTaskFactory.java +++ b/core/trino-main/src/main/java/io/trino/server/HttpRemoteTaskFactory.java @@ -24,6 +24,7 @@ import io.opentelemetry.api.trace.Span; import io.opentelemetry.api.trace.Tracer; import io.trino.Session; +import io.trino.cache.SplitAdmissionControllerProvider; import io.trino.execution.DynamicFiltersCollector.VersionedDynamicFilterDomains; import io.trino.execution.LocationFactory; import io.trino.execution.NodeTaskMap.PartitionedSplitCountTracker; @@ -146,7 +147,8 @@ public HttpRemoteTask createRemoteTask( PartitionedSplitCountTracker partitionedSplitCountTracker, Set outboundDynamicFilterIds, Optional estimatedMemory, - boolean summarizeTaskInfo) + boolean summarizeTaskInfo, + SplitAdmissionControllerProvider splitAdmissionControllerProvider) { return new HttpRemoteTask( session, @@ -177,6 +179,7 @@ public HttpRemoteTask createRemoteTask( stats, 
dynamicFilterService, outboundDynamicFilterIds, - estimatedMemory); + estimatedMemory, + splitAdmissionControllerProvider); } } diff --git a/core/trino-main/src/main/java/io/trino/server/PluginManager.java b/core/trino-main/src/main/java/io/trino/server/PluginManager.java index cf7cf280c65a..a55096978d45 100644 --- a/core/trino-main/src/main/java/io/trino/server/PluginManager.java +++ b/core/trino-main/src/main/java/io/trino/server/PluginManager.java @@ -17,6 +17,7 @@ import com.google.errorprone.annotations.ThreadSafe; import com.google.inject.Inject; import io.airlift.log.Logger; +import io.trino.cache.CacheManagerRegistry; import io.trino.connector.CatalogFactory; import io.trino.connector.CatalogStoreManager; import io.trino.eventlistener.EventListenerManager; @@ -36,6 +37,7 @@ import io.trino.server.security.PasswordAuthenticatorManager; import io.trino.spi.Plugin; import io.trino.spi.block.BlockEncoding; +import io.trino.spi.cache.CacheManagerFactory; import io.trino.spi.catalog.CatalogStoreFactory; import io.trino.spi.classloader.ThreadContextClassLoader; import io.trino.spi.connector.ConnectorFactory; @@ -92,6 +94,7 @@ public class PluginManager private final GroupProviderManager groupProviderManager; private final ExchangeManagerRegistry exchangeManagerRegistry; private final SpoolingManagerRegistry spoolingManagerRegistry; + private final CacheManagerRegistry cacheManagerRegistry; private final SessionPropertyDefaults sessionPropertyDefaults; private final TypeRegistry typeRegistry; private final BlockEncodingManager blockEncodingManager; @@ -116,7 +119,8 @@ public PluginManager( BlockEncodingManager blockEncodingManager, HandleResolver handleResolver, ExchangeManagerRegistry exchangeManagerRegistry, - SpoolingManagerRegistry spoolingManagerRegistry) + SpoolingManagerRegistry spoolingManagerRegistry, + CacheManagerRegistry cacheManagerRegistry) { this.pluginsProvider = requireNonNull(pluginsProvider, "pluginsProvider is null"); this.catalogStoreManager = 
requireNonNull(catalogStoreManager, "catalogStoreManager is null"); @@ -135,6 +139,7 @@ public PluginManager( this.handleResolver = requireNonNull(handleResolver, "handleResolver is null"); this.exchangeManagerRegistry = requireNonNull(exchangeManagerRegistry, "exchangeManagerRegistry is null"); this.spoolingManagerRegistry = requireNonNull(spoolingManagerRegistry, "spoolingManagerRegistry is null"); + this.cacheManagerRegistry = requireNonNull(cacheManagerRegistry, "cacheManagerRegistry is null"); } @Override @@ -277,6 +282,10 @@ private void installPluginInternal(Plugin plugin) log.info("Registering spooling manager %s", spoolingManagerFactory.getName()); spoolingManagerRegistry.addSpoolingManagerFactory(spoolingManagerFactory); } + for (CacheManagerFactory cacheManagerFactory : plugin.getCacheManagerFactories()) { + log.info("Registering cache manager %s", cacheManagerFactory.getName()); + cacheManagerRegistry.addCacheManagerFactory(cacheManagerFactory); + } } public static PluginClassLoader createClassLoader(String pluginName, List urls) diff --git a/core/trino-main/src/main/java/io/trino/server/Server.java b/core/trino-main/src/main/java/io/trino/server/Server.java index f5ba8fbb1c55..939f7355e774 100644 --- a/core/trino-main/src/main/java/io/trino/server/Server.java +++ b/core/trino-main/src/main/java/io/trino/server/Server.java @@ -39,6 +39,8 @@ import io.airlift.openmetrics.JmxOpenMetricsModule; import io.airlift.tracing.TracingModule; import io.airlift.units.Duration; +import io.trino.cache.CacheManagerModule; +import io.trino.cache.CacheManagerRegistry; import io.trino.client.NodeVersion; import io.trino.connector.CatalogManagerConfig; import io.trino.connector.CatalogManagerConfig.CatalogMangerKind; @@ -124,6 +126,7 @@ private void doStart(String trinoVersion) new AccessControlModule(), new EventListenerModule(), new ExchangeManagerModule(), + new CacheManagerModule(), new CoordinatorDiscoveryModule(), new CatalogManagerModule(), new 
TransactionManagerModule(), @@ -177,6 +180,7 @@ private void doStart(String trinoVersion) injector.getInstance(GroupProviderManager.class).loadConfiguredGroupProvider(); injector.getInstance(ExchangeManagerRegistry.class).loadExchangeManager(); injector.getInstance(SpoolingManagerRegistry.class).loadSpoolingManager(); + injector.getInstance(CacheManagerRegistry.class).loadCacheManager(); injector.getInstance(CertificateAuthenticatorManager.class).loadCertificateAuthenticator(); injector.getInstance(Key.get(new TypeLiteral>() {})) .ifPresent(HeaderAuthenticatorManager::loadHeaderAuthenticator); diff --git a/core/trino-main/src/main/java/io/trino/server/ServerMainModule.java b/core/trino-main/src/main/java/io/trino/server/ServerMainModule.java index b3ebe209e70b..914ceecea27a 100644 --- a/core/trino-main/src/main/java/io/trino/server/ServerMainModule.java +++ b/core/trino-main/src/main/java/io/trino/server/ServerMainModule.java @@ -33,6 +33,7 @@ import io.trino.SystemSessionProperties; import io.trino.SystemSessionPropertiesProvider; import io.trino.block.BlockJsonSerde; +import io.trino.cache.CacheMetadata; import io.trino.client.NodeVersion; import io.trino.connector.system.SystemConnectorModule; import io.trino.dispatcher.DispatchManager; @@ -108,6 +109,7 @@ import io.trino.spi.VersionEmbedder; import io.trino.spi.block.Block; import io.trino.spi.block.BlockEncodingSerde; +import io.trino.spi.predicate.TupleDomain; import io.trino.spi.type.Type; import io.trino.spi.type.TypeManager; import io.trino.spi.type.TypeOperators; @@ -301,6 +303,7 @@ protected void setup(Binder binder) binder.bind(MultilevelSplitQueue.class).in(Scopes.SINGLETON); newExporter(binder).export(MultilevelSplitQueue.class).withGeneratedName(); + jsonCodecBinder(binder).bindJsonCodec(TupleDomain.class); binder.bind(LocalExecutionPlanner.class).in(Scopes.SINGLETON); configBinder(binder).bindConfig(CompilerConfig.class); binder.bind(ExpressionCompiler.class).in(Scopes.SINGLETON); @@ -423,6 +426,9 
@@ protected void setup(Binder binder) // split manager binder.bind(SplitManager.class).in(Scopes.SINGLETON); + // cache metadata + binder.bind(CacheMetadata.class).in(Scopes.SINGLETON); + // node partitioning manager binder.bind(NodePartitioningManager.class).in(Scopes.SINGLETON); diff --git a/core/trino-main/src/main/java/io/trino/server/remotetask/HttpRemoteTask.java b/core/trino-main/src/main/java/io/trino/server/remotetask/HttpRemoteTask.java index 13eb1e7e63c0..cacfd3de6409 100644 --- a/core/trino-main/src/main/java/io/trino/server/remotetask/HttpRemoteTask.java +++ b/core/trino-main/src/main/java/io/trino/server/remotetask/HttpRemoteTask.java @@ -39,6 +39,7 @@ import io.opentelemetry.api.trace.Tracer; import io.opentelemetry.context.Context; import io.trino.Session; +import io.trino.cache.SplitAdmissionControllerProvider; import io.trino.execution.DynamicFiltersCollector; import io.trino.execution.DynamicFiltersCollector.VersionedDynamicFilterDomains; import io.trino.execution.ExecutionFailureInfo; @@ -64,6 +65,7 @@ import io.trino.server.DynamicFilterService; import io.trino.server.FailTaskRequest; import io.trino.server.TaskUpdateRequest; +import io.trino.spi.HostAddress; import io.trino.spi.SplitWeight; import io.trino.spi.TrinoException; import io.trino.spi.TrinoTransportException; @@ -142,6 +144,7 @@ public final class HttpRemoteTask private final Session session; private final Span stageSpan; private final String nodeId; + private final HostAddress nodeAddress; private final AtomicBoolean speculative; private final PlanFragment planFragment; @@ -209,6 +212,8 @@ public final class HttpRemoteTask private final long requestSizeHeadroomInBytes; private final boolean adaptiveUpdateRequestSizeEnabled; + private final SplitAdmissionControllerProvider splitAdmissionControllerProvider; + public HttpRemoteTask( Session session, Span stageSpan, @@ -238,7 +243,8 @@ public HttpRemoteTask( RemoteTaskStats stats, DynamicFilterService dynamicFilterService, Set 
outboundDynamicFilterIds, - Optional estimatedMemory) + Optional estimatedMemory, + SplitAdmissionControllerProvider splitAdmissionControllerProvider) { requireNonNull(session, "session is null"); requireNonNull(stageSpan, "stageSpan is null"); @@ -262,6 +268,7 @@ public HttpRemoteTask( this.session = session; this.stageSpan = stageSpan; this.nodeId = node.getNodeIdentifier(); + this.nodeAddress = node.getHostAndPort(); this.speculative = new AtomicBoolean(speculative); this.planFragment = planFragment; this.outputBuffers.set(outputBuffers); @@ -281,7 +288,9 @@ public HttpRemoteTask( this.stats = stats; for (Entry entry : initialSplits.entries()) { - ScheduledSplit scheduledSplit = new ScheduledSplit(nextSplitId.getAndIncrement(), entry.getKey(), entry.getValue()); + Split split = entry.getValue(); + split = split.withSplitAddressEnforced(split.getAddresses().contains(nodeAddress)); + ScheduledSplit scheduledSplit = new ScheduledSplit(nextSplitId.getAndIncrement(), entry.getKey(), split); pendingSplits.put(entry.getKey(), scheduledSplit); } maxUnacknowledgedSplits = getMaxUnacknowledgedSplitsPerTask(session); @@ -389,6 +398,7 @@ public HttpRemoteTask( outboundDynamicFilterIds, outboundDynamicFiltersCollector::updateDomains); + this.splitAdmissionControllerProvider = requireNonNull(splitAdmissionControllerProvider, "splitAdmissionControllerProvider is null"); partitionedSplitCountTracker.setPartitionedSplits(getPartitionedSplitsInfo()); updateSplitQueueSpace(); } @@ -452,6 +462,7 @@ public synchronized void addSplits(Multimap splitsBySource) int added = 0; long addedWeight = 0; for (Split split : splits) { + split = split.withSplitAddressEnforced(split.getAddresses().contains(nodeAddress)); if (pendingSplits.put(sourceId, new ScheduledSplit(nextSplitId.getAndIncrement(), sourceId, split))) { if (isPartitionedSource) { added++; @@ -464,6 +475,9 @@ public synchronized void addSplits(Multimap splitsBySource) pendingSourceSplitsWeight = 
addExact(pendingSourceSplitsWeight, addedWeight); partitionedSplitCountTracker.setPartitionedSplits(getPartitionedSplitsInfo()); } + // Notify that splits have been scheduled. This is needed such that no two same splits are scheduled on + // the same worker at the same time thus, to effectively utilize the cache. + splitAdmissionControllerProvider.get(sourceId).splitsScheduled(ImmutableList.copyOf(splits)); needsUpdate = true; } updateSplitQueueSpace(); diff --git a/core/trino-main/src/main/java/io/trino/server/testing/TestingTrinoServer.java b/core/trino-main/src/main/java/io/trino/server/testing/TestingTrinoServer.java index 9d0f010c604e..dc2246fcbfdb 100644 --- a/core/trino-main/src/main/java/io/trino/server/testing/TestingTrinoServer.java +++ b/core/trino-main/src/main/java/io/trino/server/testing/TestingTrinoServer.java @@ -44,6 +44,9 @@ import io.opentelemetry.sdk.trace.SpanProcessor; import io.trino.Session; import io.trino.SystemSessionPropertiesProvider; +import io.trino.cache.CacheManagerModule; +import io.trino.cache.CacheManagerRegistry; +import io.trino.cache.CacheMetadata; import io.trino.connector.CatalogManagerConfig.CatalogMangerKind; import io.trino.connector.CatalogManagerModule; import io.trino.connector.CatalogStoreManager; @@ -187,6 +190,7 @@ public static Builder builder() private final TestingHttpServer server; private final TransactionManager transactionManager; private final TablePropertyManager tablePropertyManager; + private final CacheMetadata cacheMetadata; private final PlannerContext plannerContext; private final QueryExplainer queryExplainer; private final SessionPropertyManager sessionPropertyManager; @@ -214,6 +218,7 @@ public static Builder builder() private final FailureInjector failureInjector; private final ExchangeManagerRegistry exchangeManagerRegistry; private final SpoolingManagerRegistry spoolingManagerRegistry; + private final CacheManagerRegistry cacheManagerRegistry; public static class TestShutdownAction implements 
ShutdownAction @@ -307,6 +312,7 @@ private TestingTrinoServer( .add(new CatalogManagerModule()) .add(new TransactionManagerModule()) .add(new ServerMainModule(VERSION)) + .add(new CacheManagerModule()) .add(new TestingWarningCollectorModule()) .add(binder -> { binder.bind(EventListenerConfig.class).in(Scopes.SINGLETON); @@ -388,6 +394,7 @@ private TestingTrinoServer( server = injector.getInstance(TestingHttpServer.class); transactionManager = injector.getInstance(TransactionManager.class); tablePropertyManager = injector.getInstance(TablePropertyManager.class); + cacheMetadata = injector.getInstance(CacheMetadata.class); globalFunctionCatalog = injector.getInstance(GlobalFunctionCatalog.class); plannerContext = injector.getInstance(PlannerContext.class); accessControl = injector.getInstance(TestingAccessControlManager.class); @@ -428,6 +435,8 @@ private TestingTrinoServer( failureInjector = injector.getInstance(FailureInjector.class); exchangeManagerRegistry = injector.getInstance(ExchangeManagerRegistry.class); spoolingManagerRegistry = injector.getInstance(SpoolingManagerRegistry.class); + cacheManagerRegistry = injector.getInstance(CacheManagerRegistry.class); + cacheManagerRegistry.loadCacheManager(); systemAccessControlConfiguration.ifPresentOrElse( configuration -> { @@ -524,6 +533,11 @@ public void loadSpoolingManager(String name, Map properties) spoolingManagerRegistry.loadSpoolingManager(name, properties); } + public CacheManagerRegistry getCacheManagerRegistry() + { + return cacheManagerRegistry; + } + /** * Add the event listeners from connectors. 
Connector event listeners are * only supported for statically loaded catalogs, and this doesn't match up @@ -579,6 +593,11 @@ public TablePropertyManager getTablePropertyManager() return tablePropertyManager; } + public CacheMetadata getCacheMetadata() + { + return cacheMetadata; + } + public PlannerContext getPlannerContext() { return plannerContext; @@ -677,11 +696,20 @@ public ShutdownAction getShutdownAction() } public Connector getConnector(String catalogName) + { + return getConnector(getCatalogHandle(catalogName)); + } + + public CatalogHandle getCatalogHandle(String catalogName) { checkState(coordinator, "not a coordinator"); - CatalogHandle catalogHandle = catalogManager.orElseThrow().getCatalog(new CatalogName(catalogName)) + return catalogManager.orElseThrow().getCatalog(new CatalogName(catalogName)) .orElseThrow(() -> new IllegalArgumentException("Catalog '%s' not found".formatted(catalogName))) .getCatalogHandle(); + } + + public Connector getConnector(CatalogHandle catalogHandle) + { return injector.getInstance(ConnectorServicesProvider.class) .getConnectorServices(catalogHandle) .getConnector(); diff --git a/core/trino-main/src/main/java/io/trino/split/PageSourceManager.java b/core/trino-main/src/main/java/io/trino/split/PageSourceManager.java index 8be210388fd2..d15fe7781fdb 100644 --- a/core/trino-main/src/main/java/io/trino/split/PageSourceManager.java +++ b/core/trino-main/src/main/java/io/trino/split/PageSourceManager.java @@ -13,6 +13,7 @@ */ package io.trino.split; +import com.google.common.annotations.VisibleForTesting; import com.google.inject.Inject; import io.trino.Session; import io.trino.connector.CatalogServiceProvider; @@ -23,6 +24,7 @@ import io.trino.spi.connector.ConnectorPageSource; import io.trino.spi.connector.ConnectorPageSourceProvider; import io.trino.spi.connector.ConnectorPageSourceProviderFactory; +import io.trino.spi.connector.ConnectorSession; import io.trino.spi.connector.DynamicFilter; import 
io.trino.spi.connector.EmptyPageSource; import io.trino.spi.predicate.TupleDomain; @@ -51,10 +53,11 @@ public PageSourceProvider createPageSourceProvider(CatalogHandle catalogHandle) return new PageSourceProviderInstance(provider.createPageSourceProvider()); } - private record PageSourceProviderInstance(ConnectorPageSourceProvider pageSourceProvider) + @VisibleForTesting + public record PageSourceProviderInstance(ConnectorPageSourceProvider pageSourceProvider) implements PageSourceProvider { - private PageSourceProviderInstance + public PageSourceProviderInstance { requireNonNull(pageSourceProvider, "pageSourceProvider is null"); } @@ -84,5 +87,33 @@ public ConnectorPageSource createPageSource(Session session, columns, dynamicFilter); } + + @Override + public TupleDomain getUnenforcedPredicate( + Session session, + Split split, + TableHandle table, + TupleDomain dynamicFilter) + { + checkArgument(split.getCatalogHandle().equals(table.catalogHandle()), "mismatched split and table"); + + CatalogHandle catalogHandle = split.getCatalogHandle(); + ConnectorSession connectorSession = session.toConnectorSession(catalogHandle); + return pageSourceProvider.getUnenforcedPredicate(connectorSession, split.getConnectorSplit(), table.connectorHandle(), dynamicFilter); + } + + @Override + public TupleDomain prunePredicate( + Session session, + Split split, + TableHandle table, + TupleDomain predicate) + { + checkArgument(split.getCatalogHandle().equals(table.catalogHandle()), "mismatched split and table"); + + CatalogHandle catalogHandle = split.getCatalogHandle(); + ConnectorSession connectorSession = session.toConnectorSession(catalogHandle); + return pageSourceProvider.prunePredicate(connectorSession, split.getConnectorSplit(), table.connectorHandle(), predicate); + } } } diff --git a/core/trino-main/src/main/java/io/trino/split/PageSourceProvider.java b/core/trino-main/src/main/java/io/trino/split/PageSourceProvider.java index 5a103103c2ed..a2584fd9d3a7 100644 --- 
a/core/trino-main/src/main/java/io/trino/split/PageSourceProvider.java +++ b/core/trino-main/src/main/java/io/trino/split/PageSourceProvider.java @@ -19,6 +19,7 @@ import io.trino.spi.connector.ColumnHandle; import io.trino.spi.connector.ConnectorPageSource; import io.trino.spi.connector.DynamicFilter; +import io.trino.spi.predicate.TupleDomain; import java.util.List; @@ -30,4 +31,22 @@ ConnectorPageSource createPageSource( TableHandle table, List columns, DynamicFilter dynamicFilter); + + default TupleDomain getUnenforcedPredicate( + Session session, + Split split, + TableHandle table, + TupleDomain dynamicFilter) + { + throw new UnsupportedOperationException(); + } + + default TupleDomain prunePredicate( + Session session, + Split split, + TableHandle table, + TupleDomain predicate) + { + throw new UnsupportedOperationException(); + } } diff --git a/core/trino-main/src/main/java/io/trino/split/RecordPageSourceProvider.java b/core/trino-main/src/main/java/io/trino/split/RecordPageSourceProvider.java index 1a905b23db5e..65809a9f41b6 100644 --- a/core/trino-main/src/main/java/io/trino/split/RecordPageSourceProvider.java +++ b/core/trino-main/src/main/java/io/trino/split/RecordPageSourceProvider.java @@ -23,6 +23,7 @@ import io.trino.spi.connector.ConnectorTransactionHandle; import io.trino.spi.connector.DynamicFilter; import io.trino.spi.connector.RecordPageSource; +import io.trino.spi.predicate.TupleDomain; import java.util.List; @@ -49,4 +50,26 @@ public ConnectorPageSource createPageSource( { return new RecordPageSource(recordSetProvider.getRecordSet(transaction, session, split, table, columns)); } + + @Override + public TupleDomain getUnenforcedPredicate( + ConnectorSession session, + ConnectorSplit split, + ConnectorTableHandle table, + TupleDomain dynamicFilter) + { + // record page source doesn't support dynamic predicates + return TupleDomain.all(); + } + + @Override + public TupleDomain prunePredicate( + ConnectorSession session, + ConnectorSplit split, + 
ConnectorTableHandle tableHandle, + TupleDomain predicate) + { + // record page source doesn't support pruning of predicates + return predicate; + } } diff --git a/core/trino-main/src/main/java/io/trino/split/SplitManager.java b/core/trino-main/src/main/java/io/trino/split/SplitManager.java index 12d9e1491d4c..0eb2580a909c 100644 --- a/core/trino-main/src/main/java/io/trino/split/SplitManager.java +++ b/core/trino-main/src/main/java/io/trino/split/SplitManager.java @@ -15,14 +15,19 @@ import com.google.inject.Inject; import io.airlift.concurrent.BoundedExecutor; +import io.airlift.node.NodeInfo; import io.opentelemetry.api.trace.Span; import io.opentelemetry.api.trace.Tracer; import io.opentelemetry.context.Context; import io.trino.Session; +import io.trino.cache.CacheSplitSource; +import io.trino.cache.ConnectorAwareAddressProvider; +import io.trino.cache.SplitAdmissionControllerProvider; import io.trino.connector.CatalogServiceProvider; import io.trino.execution.QueryManagerConfig; import io.trino.metadata.TableFunctionHandle; import io.trino.metadata.TableHandle; +import io.trino.spi.cache.PlanSignature; import io.trino.spi.connector.CatalogHandle; import io.trino.spi.connector.ConnectorSession; import io.trino.spi.connector.ConnectorSplitManager; @@ -132,6 +137,34 @@ public SplitSource getSplits(Session session, Span parentSpan, TableFunctionHand return new TracingSplitSource(splitSource, tracer, Optional.of(span), "split-buffer"); } + public ConnectorSplitManager getConnectorSplitManager(TableHandle tableHandle) + { + return splitManagerProvider.getService(tableHandle.catalogHandle()); + } + + public CacheSplitSource getCacheSplitSource( + PlanSignature signature, + TableHandle tableHandle, + SplitSource delegate, + ConnectorAwareAddressProvider connectorAwareAddressProvider, + NodeInfo nodeInfo, + SplitAdmissionControllerProvider splitAdmissionControllerProvider, + boolean schedulerIncludeCoordinator, + int minScheduleSplitBatchSize) + { + return new 
CacheSplitSource( + signature, + getConnectorSplitManager(tableHandle), + delegate, + connectorAwareAddressProvider, + nodeInfo, + splitAdmissionControllerProvider, + schedulerIncludeCoordinator, + minScheduleSplitBatchSize, + // Use the same executor as the one used by BufferingSplitSource + executor); + } + private Span splitSourceSpan(Span parentSpan, CatalogHandle catalogHandle) { return tracer.spanBuilder("split-source") diff --git a/core/trino-main/src/main/java/io/trino/sql/PlannerContext.java b/core/trino-main/src/main/java/io/trino/sql/PlannerContext.java index acd256228ad2..f68461f52cd7 100644 --- a/core/trino-main/src/main/java/io/trino/sql/PlannerContext.java +++ b/core/trino-main/src/main/java/io/trino/sql/PlannerContext.java @@ -15,6 +15,7 @@ import com.google.inject.Inject; import io.opentelemetry.api.trace.Tracer; +import io.trino.cache.CacheMetadata; import io.trino.execution.warnings.WarningCollector; import io.trino.metadata.FunctionManager; import io.trino.metadata.FunctionResolver; @@ -38,6 +39,7 @@ public class PlannerContext // throughout the analyzer and planner, so it is easy to create // circular dependencies, just create a junk drawer of services. 
private final Metadata metadata; + private final CacheMetadata cacheMetadata; private final TypeOperators typeOperators; private final BlockEncodingSerde blockEncodingSerde; private final TypeManager typeManager; @@ -47,6 +49,7 @@ public class PlannerContext @Inject public PlannerContext(Metadata metadata, + CacheMetadata cacheMetadata, TypeOperators typeOperators, BlockEncodingSerde blockEncodingSerde, TypeManager typeManager, @@ -55,6 +58,7 @@ public PlannerContext(Metadata metadata, Tracer tracer) { this.metadata = requireNonNull(metadata, "metadata is null"); + this.cacheMetadata = requireNonNull(cacheMetadata, "cacheMetadata is null"); this.typeOperators = requireNonNull(typeOperators, "typeOperators is null"); this.blockEncodingSerde = requireNonNull(blockEncodingSerde, "blockEncodingSerde is null"); this.typeManager = requireNonNull(typeManager, "typeManager is null"); @@ -68,6 +72,11 @@ public Metadata getMetadata() return metadata; } + public CacheMetadata getCacheMetadata() + { + return cacheMetadata; + } + public TypeOperators getTypeOperators() { return typeOperators; diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/LocalDynamicFiltersCollector.java b/core/trino-main/src/main/java/io/trino/sql/planner/LocalDynamicFiltersCollector.java index 1e31c6009cc8..a865010d91cb 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/LocalDynamicFiltersCollector.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/LocalDynamicFiltersCollector.java @@ -61,14 +61,7 @@ public LocalDynamicFiltersCollector(Session session) // Called during JoinNode planning (no need to be synchronized as local planning is single threaded) public void register(Set filterIds) { - filterIds.forEach(filterId -> verify( - futures.put(filterId, SettableFuture.create()) == null, - "LocalDynamicFiltersCollector: duplicate filter %s", filterId)); - } - - public Set getRegisteredDynamicFilterIds() - { - return futures.keySet(); + filterIds.forEach(filterId -> 
futures.putIfAbsent(filterId, SettableFuture.create())); } // Used during execution (after build-side dynamic filter collection is over). diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java index 2ec18558dd32..1237ffe94ca7 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java @@ -29,11 +29,19 @@ import com.google.common.collect.SetMultimap; import com.google.common.primitives.Ints; import com.google.inject.Inject; +import io.airlift.json.JsonCodec; import io.airlift.log.Logger; import io.airlift.units.DataSize; import io.trino.Session; import io.trino.SystemSessionProperties; +import io.trino.cache.CacheDataOperator; +import io.trino.cache.CacheDriverFactory; +import io.trino.cache.CacheManagerRegistry; +import io.trino.cache.CacheStats; +import io.trino.cache.CommonPlanAdaptation; +import io.trino.cache.LoadCachedDataOperator.LoadCachedDataOperatorFactory; import io.trino.cache.NonEvictableCache; +import io.trino.cache.StaticDynamicFilter; import io.trino.client.NodeVersion; import io.trino.exchange.ExchangeManagerRegistry; import io.trino.execution.DynamicFilterConfig; @@ -67,11 +75,11 @@ import io.trino.operator.JoinOperatorType; import io.trino.operator.LeafTableFunctionOperator.LeafTableFunctionOperatorFactory; import io.trino.operator.LimitOperator.LimitOperatorFactory; -import io.trino.operator.LocalPlannerAware; import io.trino.operator.MarkDistinctOperator.MarkDistinctOperatorFactory; import io.trino.operator.MergeOperator.MergeOperatorFactory; import io.trino.operator.MergeProcessorOperator; import io.trino.operator.MergeWriterOperator.MergeWriterOperatorFactory; +import io.trino.operator.OperatorDriverFactory; import io.trino.operator.OperatorFactory; import io.trino.operator.OrderByOperator.OrderByOperatorFactory; 
import io.trino.operator.OutputFactory; @@ -163,6 +171,7 @@ import io.trino.spi.TrinoException; import io.trino.spi.block.Block; import io.trino.spi.block.SqlRow; +import io.trino.spi.cache.CacheColumnId; import io.trino.spi.connector.ColumnHandle; import io.trino.spi.connector.ConnectorIndex; import io.trino.spi.connector.ConnectorSession; @@ -181,6 +190,7 @@ import io.trino.spi.function.table.TableFunctionProcessorProvider; import io.trino.spi.predicate.Domain; import io.trino.spi.predicate.NullableValue; +import io.trino.spi.predicate.TupleDomain; import io.trino.spi.spool.SpoolingManager; import io.trino.spi.type.RowType; import io.trino.spi.type.Type; @@ -213,6 +223,8 @@ import io.trino.sql.planner.plan.AggregationNode.Step; import io.trino.sql.planner.plan.AssignUniqueId; import io.trino.sql.planner.plan.Assignments; +import io.trino.sql.planner.plan.CacheDataPlanNode; +import io.trino.sql.planner.plan.ChooseAlternativeNode; import io.trino.sql.planner.plan.DataOrganizationSpecification; import io.trino.sql.planner.plan.DistinctLimitNode; import io.trino.sql.planner.plan.DynamicFilterId; @@ -226,6 +238,7 @@ import io.trino.sql.planner.plan.IndexSourceNode; import io.trino.sql.planner.plan.JoinNode; import io.trino.sql.planner.plan.LimitNode; +import io.trino.sql.planner.plan.LoadCachedDataPlanNode; import io.trino.sql.planner.plan.MarkDistinctNode; import io.trino.sql.planner.plan.MergeProcessorNode; import io.trino.sql.planner.plan.MergeWriterNode; @@ -309,9 +322,9 @@ import static com.google.common.collect.ImmutableSet.toImmutableSet; import static com.google.common.collect.Iterables.getOnlyElement; import static com.google.common.collect.Range.closedOpen; -import static com.google.common.collect.Sets.difference; import static io.trino.SystemSessionProperties.getAdaptivePartialAggregationUniqueRowsRatioThreshold; import static io.trino.SystemSessionProperties.getAggregationOperatorUnspillMemoryLimit; +import static 
io.trino.SystemSessionProperties.getCacheMaxSplitSize; import static io.trino.SystemSessionProperties.getDynamicRowFilterSelectivityThreshold; import static io.trino.SystemSessionProperties.getExchangeCompressionCodec; import static io.trino.SystemSessionProperties.getFilterAndProjectMinOutputPageRowCount; @@ -328,8 +341,12 @@ import static io.trino.SystemSessionProperties.isEnableLargeDynamicFilters; import static io.trino.SystemSessionProperties.isForceSpillingOperator; import static io.trino.SystemSessionProperties.isSpillEnabled; +import static io.trino.cache.CacheCommonSubqueries.getLoadCachedDataPlanNode; +import static io.trino.cache.CacheCommonSubqueries.isCacheChooseAlternativeNode; import static io.trino.cache.CacheUtils.uncheckedCacheGet; import static io.trino.cache.SafeCaches.buildNonEvictableCache; +import static io.trino.cache.StaticDynamicFilter.createStaticDynamicFilter; +import static io.trino.cache.StaticDynamicFilter.createStaticDynamicFilterSupplier; import static io.trino.metadata.GlobalFunctionCatalog.builtinFunctionName; import static io.trino.operator.DistinctLimitOperator.DistinctLimitOperatorFactory; import static io.trino.operator.HashArraySizeSupplier.incrementalLoadFactorHashArraySizeSupplier; @@ -370,6 +387,7 @@ import static io.trino.sql.ir.Comparison.Operator.LESS_THAN; import static io.trino.sql.ir.Comparison.Operator.LESS_THAN_OR_EQUAL; import static io.trino.sql.ir.IrUtils.combineConjuncts; +import static io.trino.sql.ir.IrUtils.extractDisjuncts; import static io.trino.sql.planner.ExpressionExtractor.extractExpressions; import static io.trino.sql.planner.ExpressionNodeInliner.replaceExpression; import static io.trino.sql.planner.SortExpressionExtractor.extractSortExpression; @@ -417,6 +435,9 @@ public class LocalExecutionPlanner private final Metadata metadata; private final Optional explainAnalyzeContext; private final PageSourceManager pageSourceManager; + private final CacheManagerRegistry cacheManagerRegistry; + private final 
JsonCodec tupleDomainCodec; + private final CacheStats cacheStats; private final IndexManager indexManager; private final NodePartitioningManager nodePartitioningManager; private final PageSinkManager pageSinkManager; @@ -475,6 +496,9 @@ public LocalExecutionPlanner( PlannerContext plannerContext, Optional explainAnalyzeContext, PageSourceManager pageSourceManager, + CacheManagerRegistry cacheManagerRegistry, + JsonCodec tupleDomainCodec, + CacheStats cacheStats, IndexManager indexManager, NodePartitioningManager nodePartitioningManager, PageSinkManager pageSinkManager, @@ -506,6 +530,9 @@ public LocalExecutionPlanner( this.metadata = plannerContext.getMetadata(); this.explainAnalyzeContext = requireNonNull(explainAnalyzeContext, "explainAnalyzeContext is null"); this.pageSourceManager = requireNonNull(pageSourceManager, "pageSourceManager is null"); + this.cacheManagerRegistry = requireNonNull(cacheManagerRegistry, "cacheManagerRegistry is null"); + this.tupleDomainCodec = requireNonNull(tupleDomainCodec, "tupleDomainCodec is null"); + this.cacheStats = requireNonNull(cacheStats, "cacheStats is null"); this.indexManager = requireNonNull(indexManager, "indexManager is null"); this.nodePartitioningManager = requireNonNull(nodePartitioningManager, "nodePartitioningManager is null"); this.directExchangeClientSupplier = directExchangeClientSupplier; @@ -680,21 +707,15 @@ public LocalExecutionPlan plan( outputTypes, pagePreprocessor, new PagesSerdeFactory(plannerContext.getBlockEncodingSerde(), getExchangeCompressionCodec(session))), - physicalOperation), - context); + physicalOperation)); // notify operator factories that planning has completed - context.getDriverFactories().stream() - .map(DriverFactory::getOperatorFactories) - .flatMap(List::stream) - .filter(LocalPlannerAware.class::isInstance) - .map(LocalPlannerAware.class::cast) - .forEach(LocalPlannerAware::localPlannerComplete); + context.getDriverFactories().forEach(DriverFactory::localPlannerComplete); return 
new LocalExecutionPlan(context.getDriverFactories(), partitionedSourceOrder); } - private static class LocalExecutionPlanContext + private class LocalExecutionPlanContext { private final TaskContext taskContext; private final List driverFactories; @@ -706,6 +727,8 @@ private static class LocalExecutionPlanContext private int nextOperatorId; private boolean inputDriver = true; private OptionalInt driverInstanceCount = OptionalInt.empty(); + private Optional cacheContext = Optional.empty(); + private Optional alternativeSourceId = Optional.empty(); public LocalExecutionPlanContext(TaskContext taskContext) { @@ -728,13 +751,50 @@ private LocalExecutionPlanContext( this.nextPipelineId = nextPipelineId; } - public void addDriverFactory(boolean outputDriver, PhysicalOperation physicalOperation, LocalExecutionPlanContext context) + public void addDriverFactory(boolean outputDriver, PhysicalOperation physicalOperation) { - boolean inputDriver = context.isInputDriver(); - OptionalInt driverInstances = context.getDriverInstanceCount(); - List operatorFactories = physicalOperation.getOperatorFactories(); + List operatorFactories = physicalOperation.getPipelineTail(); addLookupOuterDrivers(outputDriver, operatorFactories); - addDriverFactory(inputDriver, outputDriver, operatorFactories, driverInstances); + if (physicalOperation.getPipelineHeadAlternatives().isEmpty()) { + addDriverFactory(inputDriver, outputDriver, operatorFactories, driverInstanceCount); + } + else { + // we have alternatives, we need to extend them to the end of the pipeline and create CacheDriverFactory + List commonOperators = physicalOperation.getPipelineTail().stream() + .map(SharedOperatorFactory::new) + .collect(toImmutableList()); + int pipelineId = getNextPipelineId(); + List alternatives = physicalOperation.getPipelineHeadAlternatives().stream() + .map(alternative -> new OperatorDriverFactory( + pipelineId, + inputDriver, + outputDriver, + ImmutableList.builder() + .addAll(alternative) + 
.addAll(commonOperators) + .build(), + driverInstanceCount)) + .collect(toImmutableList()); + driverFactories.add(cacheContext + .map(cacheContext -> new CacheDriverFactory( + pipelineId, + inputDriver, + outputDriver, + driverInstanceCount, + alternativeSourceId.orElseThrow(), + taskContext.getSession(), + pageSourceManager.createPageSourceProvider(cacheContext.getOriginalTableHandle().catalogHandle()), + cacheManagerRegistry, + tupleDomainCodec, + cacheContext.getOriginalTableHandle(), + cacheContext.getPlanSignature(), + cacheContext.getCommonColumnHandles(), + cacheContext.getCommonDynamicFilterSupplier(), + cacheContext.getOriginalDynamicFilterSupplier(), + ImmutableList.copyOf(alternatives), + cacheStats)) + .orElseThrow(() -> new IllegalStateException("Cache context is not set"))); + } } private void addLookupOuterDrivers(boolean isOutputDriver, List operatorFactories) @@ -765,7 +825,7 @@ private void addLookupOuterDrivers(boolean isOutputDriver, List private void addDriverFactory(boolean inputDriver, boolean outputDriver, List operatorFactories, OptionalInt driverInstances) { - driverFactories.add(new DriverFactory(getNextPipelineId(), inputDriver, outputDriver, operatorFactories, driverInstances)); + driverFactories.add(new OperatorDriverFactory(getNextPipelineId(), inputDriver, outputDriver, operatorFactories, driverInstances)); } private List getDriverFactories() @@ -793,10 +853,7 @@ private void registerCoordinatorDynamicFilters(List d Set consumedFilterIds = dynamicFilters.stream() .map(DynamicFilters.Descriptor::getId) .collect(toImmutableSet()); - LocalDynamicFiltersCollector dynamicFiltersCollector = getDynamicFiltersCollector(); - // Don't repeat registration of node-local filters or those already registered by another scan (e.g. 
co-located joins) - dynamicFiltersCollector.register( - difference(consumedFilterIds, dynamicFiltersCollector.getRegisteredDynamicFilterIds())); + getDynamicFiltersCollector().register(consumedFilterIds); } private TaskContext getTaskContext() @@ -853,6 +910,71 @@ public void setDriverInstanceCount(int driverInstanceCount) } this.driverInstanceCount = OptionalInt.of(driverInstanceCount); } + + public void setCacheContext(CacheContext cacheContext) + { + checkState(this.cacheContext.isEmpty(), "cacheContext is already set"); + this.cacheContext = Optional.of(requireNonNull(cacheContext, "cacheContext is null")); + } + + public Optional getAlternativeSourceId() + { + return alternativeSourceId; + } + + public void setAlternativeSourceId(PlanNodeId alternativeSourceId) + { + checkState(this.alternativeSourceId.isEmpty()); + this.alternativeSourceId = Optional.of(requireNonNull(alternativeSourceId, "alternativeSourceId is null")); + } + } + + private static class CacheContext + { + private final TableHandle originalTableHandle; + private final CommonPlanAdaptation.PlanSignatureWithPredicate planSignature; + private final Map commonColumnHandles; + private final Supplier commonDynamicFilterSupplier; + private final Supplier originalDynamicFilterSupplier; + + public CacheContext( + TableHandle originalTableHandle, + LoadCachedDataPlanNode loadCacheData, + Supplier commonDynamicFilterSupplier, + Supplier originalDynamicFilterSupplier) + { + requireNonNull(loadCacheData, "loadCacheData is null"); + this.originalTableHandle = requireNonNull(originalTableHandle, "originalTableHandle is null"); + this.planSignature = loadCacheData.getPlanSignature(); + this.commonColumnHandles = loadCacheData.getCommonColumnHandles(); + this.commonDynamicFilterSupplier = requireNonNull(commonDynamicFilterSupplier, "commonDynamicFilterSupplier is null"); + this.originalDynamicFilterSupplier = requireNonNull(originalDynamicFilterSupplier, "originalDynamicFilterSupplier is null"); + } + + public 
TableHandle getOriginalTableHandle() + { + return originalTableHandle; + } + + public CommonPlanAdaptation.PlanSignatureWithPredicate getPlanSignature() + { + return planSignature; + } + + public Map getCommonColumnHandles() + { + return commonColumnHandles; + } + + public Supplier getCommonDynamicFilterSupplier() + { + return commonDynamicFilterSupplier; + } + + public Supplier getOriginalDynamicFilterSupplier() + { + return originalDynamicFilterSupplier; + } } private static class IndexSourceContext @@ -1735,6 +1857,70 @@ private Supplier prepareArgumentProjection(Expression argument, return pageFunctionCompiler.compileProjection(rowExpression, Optional.empty()); } + @Override + public PhysicalOperation visitChooseAlternativeNode(ChooseAlternativeNode node, LocalExecutionPlanContext context) + { + checkArgument(isCacheChooseAlternativeNode(node)); + context.setAlternativeSourceId(node.getId()); + // when splits are cached dynamic filter needs to be static during split processing + LoadCachedDataPlanNode loadCachedData = getLoadCachedDataPlanNode(node); + TableScanNode commonTableScan = node.getOriginalTableScan().tableScanNode(); + List commonDynamicFilters = extractDisjuncts(loadCachedData.getDynamicFilterDisjuncts()).stream() + .map(predicate -> getDynamicFilter(commonTableScan, predicate, context)) + .collect(toImmutableList()); + Supplier commonDynamicFilterSupplier = createStaticDynamicFilterSupplier(commonDynamicFilters); + Supplier originalDynamicFilterSupplier = node.getOriginalTableScan().filterPredicate() + .map(predicate -> getDynamicFilter(commonTableScan, predicate, context)) + .map(dynamicFilter -> createStaticDynamicFilterSupplier(ImmutableList.of(dynamicFilter))) + .orElse(() -> createStaticDynamicFilter(ImmutableList.of(DynamicFilter.EMPTY))); + context.setCacheContext(new CacheContext( + node.getOriginalTableScan().tableHandle(), + loadCachedData, + commonDynamicFilterSupplier, + originalDynamicFilterSupplier)); + + ImmutableList.Builder 
alternatives = ImmutableList.builder(); + Map outputLayout = null; + for (PlanNode alternative : node.getSources()) { + PhysicalOperation alternativeOperation = alternative.accept(this, context); + if (outputLayout == null) { + // we need an output layout, we may as well take it from the first alternative. + // this is consistent with ChooseAlternativeNode.getOutputSymbols + outputLayout = alternativeOperation.getLayout(); + alternatives.add(alternativeOperation); + } + else { + checkArgument( + outputLayout.equals(alternativeOperation.getLayout()), + "All alternatives should have the same layout but %s != %s", + outputLayout, + alternativeOperation.getLayout()); + // we don't need channel reordering if layout matches exactly + alternatives.add(alternativeOperation); + } + } + + return new PhysicalOperation(outputLayout, alternatives.build()); + } + + @Override + public PhysicalOperation visitCacheDataPlanNode(CacheDataPlanNode node, LocalExecutionPlanContext context) + { + PhysicalOperation source = node.getSource().accept(this, context); + return new PhysicalOperation( + new CacheDataOperator.CacheDataOperatorFactory(context.getNextOperatorId(), node.getId(), getCacheMaxSplitSize(session).toBytes()), + source.getLayout(), + source); + } + + @Override + public PhysicalOperation visitLoadCachedDataPlanNode(LoadCachedDataPlanNode node, LocalExecutionPlanContext context) + { + return new PhysicalOperation( + new LoadCachedDataOperatorFactory(context.getNextOperatorId(), node.getId(), context.getAlternativeSourceId().orElseThrow()), + makeLayout(node)); + } + @Override public PhysicalOperation visitTableFunction(TableFunctionNode node, LocalExecutionPlanContext context) { @@ -2148,7 +2334,7 @@ else if (sourceNode instanceof SampleNode sampleNode) { SourceOperatorFactory operatorFactory = new ScanFilterAndProjectOperatorFactory( context.getNextOperatorId(), planNodeId, - sourceNode.getId(), + context.getAlternativeSourceId().orElse(sourceNode.getId()), 
pageSourceManager, cursorProcessor, pageProcessor, @@ -2204,7 +2390,14 @@ private PhysicalOperation visitTableScan(PlanNodeId planNodeId, TableScanNode no } DynamicFilter dynamicFilter = getDynamicFilter(node, filterExpression, context); - OperatorFactory operatorFactory = new TableScanOperatorFactory(context.getNextOperatorId(), planNodeId, node.getId(), pageSourceManager, node.getTable(), columns, dynamicFilter); + OperatorFactory operatorFactory = new TableScanOperatorFactory( + context.getNextOperatorId(), + planNodeId, + context.getAlternativeSourceId().orElse(node.getId()), + pageSourceManager, + node.getTable(), + columns, + dynamicFilter); return new PhysicalOperation(operatorFactory, makeLayout(node)); } @@ -2746,10 +2939,9 @@ private PhysicalOperation createNestedLoopJoin(JoinNode node, Set outputMappings = ImmutableMap.builder(); @@ -2878,10 +3070,9 @@ private PagesSpatialIndexFactory createPagesSpatialIndexFactory( 10_000, pagesIndexFactory); - context.addDriverFactory( + buildContext.addDriverFactory( false, - new PhysicalOperation(builderOperatorFactory, buildSource), - buildContext); + new PhysicalOperation(builderOperatorFactory, buildSource)); return builderOperatorFactory.getPagesSpatialIndexFactory(); } @@ -3020,10 +3211,9 @@ private PhysicalOperation createLookupJoin( // is reduced (e.g. by plan rule) with respect to default task concurrency taskConcurrency / partitionCount)); - context.addDriverFactory( + buildContext.addDriverFactory( false, - new PhysicalOperation(hashBuilderOperatorFactory, buildSource), - buildContext); + new PhysicalOperation(hashBuilderOperatorFactory, buildSource)); JoinOperatorType joinType = JoinOperatorType.ofJoinNodeType(node.getType(), outputSingleMatch, waitForBuild); operator = spillingJoin( @@ -3071,10 +3261,9 @@ private PhysicalOperation createLookupJoin( // is reduced (e.g. 
by plan rule) with respect to default task concurrency taskConcurrency / partitionCount)); - context.addDriverFactory( + buildContext.addDriverFactory( false, - new PhysicalOperation(hashBuilderOperatorFactory, buildSource), - buildContext); + new PhysicalOperation(hashBuilderOperatorFactory, buildSource)); JoinOperatorType joinType = JoinOperatorType.ofJoinNodeType(node.getType(), outputSingleMatch, waitForBuild); operator = join( @@ -3310,10 +3499,9 @@ public PhysicalOperation visitSemiJoin(SemiJoinNode node, LocalExecutionPlanCont joinCompiler, typeOperators); SetSupplier setProvider = setBuilderOperatorFactory.getSetProvider(); - context.addDriverFactory( + buildContext.addDriverFactory( false, - new PhysicalOperation(setBuilderOperatorFactory, buildSource), - buildContext); + new PhysicalOperation(setBuilderOperatorFactory, buildSource)); // Source channels are always laid out first, followed by the boolean output symbol Map outputMappings = ImmutableMap.builder() @@ -3736,7 +3924,7 @@ private PhysicalOperation createLocalMerge(ExchangeNode node, LocalExecutionPlan List expectedLayout = getOnlyElement(node.getInputs()); Function pagePreprocessor = enforceLoadedLayoutProcessor(expectedLayout, source.getLayout()); - context.addDriverFactory( + subContext.addDriverFactory( false, new PhysicalOperation( new LocalExchangeSinkOperatorFactory( @@ -3744,8 +3932,7 @@ private PhysicalOperation createLocalMerge(ExchangeNode node, LocalExecutionPlan subContext.getNextOperatorId(), node.getId(), pagePreprocessor), - source), - subContext); + source)); // the main driver is not an input... 
the exchange sources are the input for the plan context.setInputDriver(false); @@ -3818,7 +4005,7 @@ else if (context.getDriverInstanceCount().isPresent()) { List expectedLayout = node.getInputs().get(i); Function pagePreprocessor = enforceLoadedLayoutProcessor(expectedLayout, source.getLayout()); - context.addDriverFactory( + subContext.addDriverFactory( false, new PhysicalOperation( new LocalExchangeSinkOperatorFactory( @@ -3826,8 +4013,7 @@ else if (context.getDriverInstanceCount().isPresent()) { subContext.getNextOperatorId(), node.getId(), pagePreprocessor), - source), - subContext); + source)); } // the main driver is not an input... the exchange sources are the input for the plan @@ -4321,7 +4507,8 @@ private static Set getConsumedDynamicFilterIds(PlanNode node) */ private static class PhysicalOperation { - private final List operatorFactories; + private final List pipelineTail; + private final List> pipelineHeadAlternatives; private final Map layout; private final List types; @@ -4340,20 +4527,42 @@ public PhysicalOperation(OperatorFactory outputOperatorFactory, PhysicalOperatio this(outputOperatorFactory, ImmutableMap.of(), Optional.of(requireNonNull(source, "source is null"))); } + public PhysicalOperation( + Map layout, + List pipelineHeadAlternatives) + { + this( + layout, + ImmutableList.of(), + pipelineHeadAlternatives.stream() + .map(PhysicalOperation::getPipelineTail) + .collect(toImmutableList())); + } + private PhysicalOperation( OperatorFactory operatorFactory, Map layout, Optional source) { - requireNonNull(operatorFactory, "operatorFactory is null"); + this( + layout, + ImmutableList.builder() + .addAll(source.map(PhysicalOperation::getPipelineTail).orElse(ImmutableList.of())) + .add(operatorFactory) + .build(), + source.map(operation -> operation.pipelineHeadAlternatives).orElse(ImmutableList.of())); + } + + private PhysicalOperation( + Map layout, + List pipelineTail, + List> pipelineHeadAlternatives) + { requireNonNull(layout, "layout is 
null"); - requireNonNull(source, "source is null"); this.types = toTypes(layout); - this.operatorFactories = ImmutableList.builder() - .addAll(source.map(PhysicalOperation::getOperatorFactories).orElse(ImmutableList.of())) - .add(operatorFactory) - .build(); + this.pipelineTail = requireNonNull(pipelineTail, "pipelineTail is null"); + this.pipelineHeadAlternatives = requireNonNull(pipelineHeadAlternatives, "pipelineHeadAlternatives is null"); this.layout = ImmutableMap.copyOf(layout); } @@ -4390,7 +4599,18 @@ public Map getLayout() private List getOperatorFactories() { - return operatorFactories; + checkArgument(pipelineHeadAlternatives.isEmpty()); + return pipelineTail; + } + + private List getPipelineTail() + { + return pipelineTail; + } + + private List> getPipelineHeadAlternatives() + { + return pipelineHeadAlternatives; } } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/LogicalPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/LogicalPlanner.java index cdea3859c6e9..53318d3dbb1b 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/LogicalPlanner.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/LogicalPlanner.java @@ -22,6 +22,8 @@ import io.opentelemetry.api.trace.SpanBuilder; import io.opentelemetry.context.Context; import io.trino.Session; +import io.trino.cache.CacheCommonSubqueries; +import io.trino.cache.CacheController; import io.trino.cost.CachingCostProvider; import io.trino.cost.CachingStatsProvider; import io.trino.cost.CachingTableStatsProvider; @@ -131,12 +133,14 @@ import static com.google.common.collect.Streams.zip; import static io.trino.SystemSessionProperties.getMaxWriterTaskCount; import static io.trino.SystemSessionProperties.getRetryPolicy; +import static io.trino.SystemSessionProperties.isCacheEnabled; import static io.trino.SystemSessionProperties.isCollectPlanStatisticsForAllQueries; import static io.trino.SystemSessionProperties.isUsePreferredWritePartitioning; import static
io.trino.metadata.MetadataUtil.createQualifiedObjectName; import static io.trino.server.protocol.spooling.SpooledBlock.SPOOLING_METADATA_SYMBOL; import static io.trino.spi.StandardErrorCode.CATALOG_NOT_FOUND; import static io.trino.spi.StandardErrorCode.CONSTRAINT_VIOLATION; +import static io.trino.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; import static io.trino.spi.StandardErrorCode.INVALID_CAST_ARGUMENT; import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED; import static io.trino.spi.StandardErrorCode.PERMISSION_DENIED; @@ -187,6 +191,8 @@ public enum Stage private final StatisticsAggregationPlanner statisticsAggregationPlanner; private final StatsCalculator statsCalculator; private final CostCalculator costCalculator; + private final boolean cacheEnabled; + private final CacheCommonSubqueries cacheCommonSubqueries; private final WarningCollector warningCollector; private final PlanOptimizersStatsCollector planOptimizersStatsCollector; private final CachingTableStatsProvider tableStatsProvider; @@ -229,6 +235,13 @@ public LogicalPlanner( this.statisticsAggregationPlanner = new StatisticsAggregationPlanner(symbolAllocator, plannerContext, session); this.statsCalculator = requireNonNull(statsCalculator, "statsCalculator is null"); this.costCalculator = requireNonNull(costCalculator, "costCalculator is null"); + this.cacheEnabled = isCacheEnabled(session); + this.cacheCommonSubqueries = new CacheCommonSubqueries( + new CacheController(), + plannerContext, + session, + idAllocator, + symbolAllocator); this.warningCollector = requireNonNull(warningCollector, "warningCollector is null"); this.planOptimizersStatsCollector = requireNonNull(planOptimizersStatsCollector, "planOptimizersStatsCollector is null"); this.tableStatsProvider = requireNonNull(tableStatsProvider, "tableStatsProvider is null"); @@ -281,6 +294,20 @@ public Plan plan(Analysis analysis, Stage stage, boolean collectPlanStatistics) } } + if (cacheEnabled) { + try (var ignored = 
scopedSpan(plannerContext.getTracer(), "cache-subqueries")) { + root = cacheCommonSubqueries.cacheSubqueries(root); + if (stage.ordinal() >= OPTIMIZED_AND_VALIDATED.ordinal()) { + try (var span = scopedSpan(plannerContext.getTracer(), "validate-alternatives")) { + planSanityChecker.validatePlanWithAlternatives(root, session, plannerContext, warningCollector); + } + } + } + catch (Throwable t) { + throw new TrinoException(GENERIC_INTERNAL_ERROR, "SUBQUERY CACHE: planning exception", t); + } + } + TableStatsProvider collectTableStatsProvider; if (collectPlanStatistics) { collectTableStatsProvider = tableStatsProvider; diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/PlanCopier.java b/core/trino-main/src/main/java/io/trino/sql/planner/PlanCopier.java index 079f64482063..32277aeb2a8a 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/PlanCopier.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/PlanCopier.java @@ -18,6 +18,8 @@ import io.trino.sql.planner.optimizations.UnaliasSymbolReferences; import io.trino.sql.planner.plan.AggregationNode; import io.trino.sql.planner.plan.ApplyNode; +import io.trino.sql.planner.plan.CacheDataPlanNode; +import io.trino.sql.planner.plan.ChooseAlternativeNode; import io.trino.sql.planner.plan.CorrelatedJoinNode; import io.trino.sql.planner.plan.DynamicFilterSourceNode; import io.trino.sql.planner.plan.EnforceSingleRowNode; @@ -27,6 +29,7 @@ import io.trino.sql.planner.plan.IntersectNode; import io.trino.sql.planner.plan.JoinNode; import io.trino.sql.planner.plan.LimitNode; +import io.trino.sql.planner.plan.LoadCachedDataPlanNode; import io.trino.sql.planner.plan.OffsetNode; import io.trino.sql.planner.plan.PatternRecognitionNode; import io.trino.sql.planner.plan.PlanNode; @@ -216,6 +219,25 @@ public PlanNode visitPatternRecognition(PatternRecognitionNode node, RewriteCont node.getVariableDefinitions()); } + @Override + public PlanNode visitChooseAlternativeNode(ChooseAlternativeNode node, 
RewriteContext context) + { + List alternatives = node.getSources().stream().map(context::rewrite).collect(toImmutableList()); + return new ChooseAlternativeNode(idAllocator.getNextId(), alternatives, node.getOriginalTableScan()); + } + + @Override + public PlanNode visitCacheDataPlanNode(CacheDataPlanNode node, RewriteContext context) + { + return new CacheDataPlanNode(idAllocator.getNextId(), context.rewrite(node.getSource())); + } + + @Override + public PlanNode visitLoadCachedDataPlanNode(LoadCachedDataPlanNode node, RewriteContext context) + { + return new LoadCachedDataPlanNode(idAllocator.getNextId(), node.getPlanSignature(), node.getDynamicFilterDisjuncts(), node.getCommonColumnHandles(), node.getOutputSymbols()); + } + @Override public PlanNode visitUnion(UnionNode node, RewriteContext context) { diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/PlanFragmenter.java b/core/trino-main/src/main/java/io/trino/sql/planner/PlanFragmenter.java index cd4e3cdc69d0..9a4f4dfa911c 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/PlanFragmenter.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/PlanFragmenter.java @@ -36,6 +36,8 @@ import io.trino.spi.function.FunctionId; import io.trino.sql.planner.optimizations.PlanNodeSearcher; import io.trino.sql.planner.plan.AdaptivePlanNode; +import io.trino.sql.planner.plan.ChooseAlternativeNode; +import io.trino.sql.planner.plan.ChooseAlternativeNode.FilteredTableScan; import io.trino.sql.planner.plan.ExchangeNode; import io.trino.sql.planner.plan.ExplainAnalyzeNode; import io.trino.sql.planner.plan.MergeWriterNode; @@ -435,6 +437,22 @@ public PlanNode visitTableFunctionProcessor(TableFunctionProcessorNode node, Rew return context.defaultRewrite(node, context.get()); } + @Override + public PlanNode visitChooseAlternativeNode(ChooseAlternativeNode node, RewriteContext context) + { + // All alternatives and original table scan should have the same partitioning + TableScanNode scan = 
node.getOriginalTableScan().tableScanNode(); + PartitioningHandle partitioning = metadata.getTableProperties(session, scan.getTable()) + .getTablePartitioning() + .filter(value -> scan.isUseConnectorNodePartitioning()) + .map(TablePartitioning::partitioningHandle) + .orElse(SOURCE_DISTRIBUTION); + context.get().addSourceDistribution(node.getId(), partitioning, metadata, session); + + // stop the process in order not to add the underlying TableScanNodes as well + return node; + } + @Override public PlanNode visitAdaptivePlanNode(AdaptivePlanNode node, RewriteContext context) { @@ -819,6 +837,17 @@ public PlanNode visitTableScan(TableScanNode node, RewriteContext context) // and new partitioning is compatible with previous one node.getUseConnectorNodePartitioning()); } + + @Override + public PlanNode visitChooseAlternativeNode(ChooseAlternativeNode node, RewriteContext context) + { + List newAlternatives = node.getSources().stream() + .map(alternative -> context.defaultRewrite(alternative, context.get())) + .toList(); + TableScanNode newTableScan = (TableScanNode) context.rewrite(node.getOriginalTableScan().tableScanNode()); + FilteredTableScan newFilteredTableScan = new FilteredTableScan(newTableScan, node.getOriginalTableScan().filterPredicate()); + return new ChooseAlternativeNode(node.getId(), newAlternatives, newFilteredTableScan); + } } private static final class ExchangeNodeToRemoteSourceRewriter diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/SchedulingOrderVisitor.java b/core/trino-main/src/main/java/io/trino/sql/planner/SchedulingOrderVisitor.java index 2a4275c9fe3e..a3d13b0e85b8 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/SchedulingOrderVisitor.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/SchedulingOrderVisitor.java @@ -15,6 +15,7 @@ package io.trino.sql.planner; import com.google.common.collect.ImmutableList; +import io.trino.sql.planner.plan.ChooseAlternativeNode; import 
io.trino.sql.planner.plan.PlanNode; import io.trino.sql.planner.plan.PlanNodeId; import io.trino.sql.planner.plan.TableFunctionProcessorNode; @@ -46,6 +47,13 @@ public Visitor(Consumer schedulingOrder) this.schedulingOrder = requireNonNull(schedulingOrder, "schedulingOrder is null"); } + @Override + public Void visitChooseAlternativeNode(ChooseAlternativeNode node, Void context) + { + schedulingOrder.accept(node.getId()); + return null; + } + @Override public Void visitTableScan(TableScanNode node, Void context) { diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/SharedOperatorFactory.java b/core/trino-main/src/main/java/io/trino/sql/planner/SharedOperatorFactory.java new file mode 100644 index 000000000000..0a32e496f361 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/sql/planner/SharedOperatorFactory.java @@ -0,0 +1,69 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.planner; + +import io.trino.operator.DriverContext; +import io.trino.operator.LocalPlannerAware; +import io.trino.operator.Operator; +import io.trino.operator.OperatorFactory; + +import static java.util.Objects.requireNonNull; + +/** + * OperatorFactory that can be reused between DriverFactory instances that share the same lifecycle. + * DriverFactory instances are reused by different sub-plan alternatives for the same pipeline. 
+ * This class makes sure noMoreOperators is called only once for the delegate + * as some OperatorFactory implementations fail if noMoreOperators is called twice. + */ +public class SharedOperatorFactory + implements OperatorFactory, LocalPlannerAware +{ + private final OperatorFactory delegate; + private boolean noMoreOperators; + + public SharedOperatorFactory(OperatorFactory delegate) + { + this.delegate = requireNonNull(delegate, "delegate is null"); + } + + @Override + public Operator createOperator(DriverContext driverContext) + { + return delegate.createOperator(driverContext); + } + + @Override + public synchronized void noMoreOperators() + { + if (noMoreOperators) { + return; + } + delegate.noMoreOperators(); + noMoreOperators = true; + } + + @Override + public OperatorFactory duplicate() + { + return new SharedOperatorFactory(delegate.duplicate()); + } + + @Override + public void localPlannerComplete() + { + if (delegate instanceof LocalPlannerAware localPlannerAware) { + localPlannerAware.localPlannerComplete(); + } + } +} diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/SplitSourceFactory.java b/core/trino-main/src/main/java/io/trino/sql/planner/SplitSourceFactory.java index 99a041cfc13a..c36da4ceab09 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/SplitSourceFactory.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/SplitSourceFactory.java @@ -17,8 +17,14 @@ import com.google.common.collect.ImmutableMap; import com.google.inject.Inject; import io.airlift.log.Logger; +import io.airlift.node.NodeInfo; import io.opentelemetry.api.trace.Span; import io.trino.Session; +import io.trino.cache.CacheSplitSource; +import io.trino.cache.ConnectorAwareAddressProvider; +import io.trino.cache.SplitAdmissionControllerProvider; +import io.trino.execution.QueryManagerConfig; +import io.trino.execution.scheduler.NodeSchedulerConfig; import io.trino.metadata.TableHandle; import io.trino.server.DynamicFilterService; import 
io.trino.spi.connector.ColumnHandle; @@ -34,6 +40,7 @@ import io.trino.sql.planner.plan.AdaptivePlanNode; import io.trino.sql.planner.plan.AggregationNode; import io.trino.sql.planner.plan.AssignUniqueId; +import io.trino.sql.planner.plan.ChooseAlternativeNode; import io.trino.sql.planner.plan.DistinctLimitNode; import io.trino.sql.planner.plan.DynamicFilterSourceNode; import io.trino.sql.planner.plan.EnforceSingleRowNode; @@ -44,6 +51,7 @@ import io.trino.sql.planner.plan.IndexJoinNode; import io.trino.sql.planner.plan.JoinNode; import io.trino.sql.planner.plan.LimitNode; +import io.trino.sql.planner.plan.LoadCachedDataPlanNode; import io.trino.sql.planner.plan.MarkDistinctNode; import io.trino.sql.planner.plan.MergeProcessorNode; import io.trino.sql.planner.plan.MergeWriterNode; @@ -80,7 +88,10 @@ import java.util.Map; import java.util.Optional; +import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.collect.Iterables.getOnlyElement; +import static io.trino.cache.CacheCommonSubqueries.getLoadCachedDataPlanNode; +import static io.trino.cache.CacheCommonSubqueries.isCacheChooseAlternativeNode; import static io.trino.spi.connector.Constraint.alwaysTrue; import static io.trino.spi.connector.DynamicFilter.EMPTY; import static io.trino.sql.ir.IrUtils.filterConjuncts; @@ -93,22 +104,41 @@ public class SplitSourceFactory private final SplitManager splitManager; private final PlannerContext plannerContext; private final DynamicFilterService dynamicFilterService; + private final ConnectorAwareAddressProvider connectorAwareAddressProvider; + private final NodeInfo nodeInfo; + private final boolean schedulerIncludeCoordinator; + private final int minScheduleSplitBatchSize; @Inject - public SplitSourceFactory(SplitManager splitManager, PlannerContext plannerContext, DynamicFilterService dynamicFilterService) + public SplitSourceFactory( + SplitManager splitManager, + PlannerContext plannerContext, + DynamicFilterService 
dynamicFilterService, + ConnectorAwareAddressProvider connectorAwareAddressProvider, + NodeInfo nodeInfo, + NodeSchedulerConfig nodeSchedulerConfig, + QueryManagerConfig queryManagerConfig) { this.splitManager = requireNonNull(splitManager, "splitManager is null"); this.plannerContext = requireNonNull(plannerContext, "metadata is null"); this.dynamicFilterService = requireNonNull(dynamicFilterService, "dynamicFilterService is null"); + this.connectorAwareAddressProvider = requireNonNull(connectorAwareAddressProvider, "connectorAwareAddressProvider is null"); + this.nodeInfo = requireNonNull(nodeInfo, "nodeInfo is null"); + this.schedulerIncludeCoordinator = requireNonNull(nodeSchedulerConfig, "nodeSchedulerConfig is null").isIncludeCoordinator(); + this.minScheduleSplitBatchSize = requireNonNull(queryManagerConfig, "queryManagerConfig is null").getMinScheduleSplitBatchSize(); } - public Map createSplitSources(Session session, Span stageSpan, PlanFragment fragment) + public Map createSplitSources( + Session session, + Span stageSpan, + PlanFragment fragment, + SplitAdmissionControllerProvider splitAdmissionControllerProvider) { ImmutableList.Builder allSplitSources = ImmutableList.builder(); try { // get splits for this fragment, this is lazy so split assignments aren't actually calculated here return fragment.getRoot().accept( - new Visitor(session, stageSpan, allSplitSources), + new Visitor(session, stageSpan, allSplitSources, splitAdmissionControllerProvider), null); } catch (Throwable t) { @@ -133,15 +163,18 @@ private final class Visitor private final Session session; private final Span stageSpan; private final ImmutableList.Builder splitSources; + private final SplitAdmissionControllerProvider splitAdmissionControllerProvider; private Visitor( Session session, Span stageSpan, - ImmutableList.Builder allSplitSources) + ImmutableList.Builder allSplitSources, + SplitAdmissionControllerProvider splitAdmissionControllerProvider) { this.session = session; 
this.stageSpan = stageSpan; this.splitSources = allSplitSources; + this.splitAdmissionControllerProvider = splitAdmissionControllerProvider; } @Override @@ -324,6 +357,32 @@ public Map visitTableFunctionProcessor(TableFunctionPro return node.getSource().orElseThrow().accept(this, context); } + @Override + public Map visitChooseAlternativeNode(ChooseAlternativeNode node, Void context) + { + checkArgument(isCacheChooseAlternativeNode(node)); + + TableHandle originalTableHandle = node.getOriginalTableScan().tableHandle(); + SplitSource splitSource = createSplitSource( + originalTableHandle, + node.getOriginalTableScan().assignments(), + node.getOriginalTableScan().filterPredicate()); + LoadCachedDataPlanNode loadCachedDataNode = getLoadCachedDataPlanNode(node); + CacheSplitSource cacheSplitSource = splitManager.getCacheSplitSource( + loadCachedDataNode.getPlanSignature().signature(), + originalTableHandle, + splitSource, + connectorAwareAddressProvider, + nodeInfo, + splitAdmissionControllerProvider, + schedulerIncludeCoordinator, + minScheduleSplitBatchSize); + + splitSources.add(cacheSplitSource); + + return ImmutableMap.of(node.getId(), cacheSplitSource); + } + @Override public Map visitRowNumber(RowNumberNode node, Void context) { diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/plan/CacheDataPlanNode.java b/core/trino-main/src/main/java/io/trino/sql/planner/plan/CacheDataPlanNode.java new file mode 100644 index 000000000000..2ef78050935b --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/sql/planner/plan/CacheDataPlanNode.java @@ -0,0 +1,71 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.planner.plan; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Iterables; +import com.google.errorprone.annotations.Immutable; +import io.trino.sql.planner.Symbol; + +import java.util.List; + +import static java.util.Objects.requireNonNull; + +@Immutable +public class CacheDataPlanNode + extends PlanNode +{ + private final PlanNode source; + + @JsonCreator + public CacheDataPlanNode( + @JsonProperty("id") PlanNodeId id, + @JsonProperty("source") PlanNode source) + { + super(id); + this.source = requireNonNull(source, "source is null"); + } + + @Override + public List getOutputSymbols() + { + return source.getOutputSymbols(); + } + + @Override + public List getSources() + { + return ImmutableList.of(source); + } + + @JsonProperty + public PlanNode getSource() + { + return source; + } + + @Override + public R accept(PlanVisitor visitor, C context) + { + return visitor.visitCacheDataPlanNode(this, context); + } + + @Override + public PlanNode replaceChildren(List newChildren) + { + return new CacheDataPlanNode(getId(), Iterables.getOnlyElement(newChildren)); + } +} diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/plan/ChooseAlternativeNode.java b/core/trino-main/src/main/java/io/trino/sql/planner/plan/ChooseAlternativeNode.java new file mode 100644 index 000000000000..f2d9179d1b09 --- /dev/null +++ 
b/core/trino-main/src/main/java/io/trino/sql/planner/plan/ChooseAlternativeNode.java @@ -0,0 +1,151 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.planner.plan; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.common.collect.ImmutableList; +import com.google.errorprone.annotations.Immutable; +import io.trino.metadata.TableHandle; +import io.trino.spi.connector.ColumnHandle; +import io.trino.sql.ir.Expression; +import io.trino.sql.planner.Symbol; + +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; + +import static com.google.common.base.MoreObjects.toStringHelper; +import static com.google.common.base.Preconditions.checkArgument; +import static java.util.Objects.requireNonNull; + +@Immutable +public class ChooseAlternativeNode + extends PlanNode +{ + private final List alternatives; + + private final FilteredTableScan originalTableScan; + + @JsonCreator + public ChooseAlternativeNode( + @JsonProperty("id") PlanNodeId id, + @JsonProperty("sources") List alternatives, + @JsonProperty("originalTableScan") FilteredTableScan originalTableScan) + { + super(id); + + requireNonNull(alternatives, "alternatives is null"); + checkArgument(alternatives.size() > 1, "Expected at least two alternative"); + checkArgument(sameOutputSymbols(alternatives), "All alternatives should have the same output symbols"); + 
this.alternatives = ImmutableList.copyOf(alternatives); + + this.originalTableScan = requireNonNull(originalTableScan, "originalTableScan is null"); + } + + private boolean sameOutputSymbols(List alternatives) + { + List outputSymbols = alternatives.get(0).getOutputSymbols(); + for (int i = 1; i < alternatives.size(); i++) { + if (!outputSymbols.equals(alternatives.get(i).getOutputSymbols())) { + return false; + } + } + return true; + } + + @JsonProperty + public FilteredTableScan getOriginalTableScan() + { + return originalTableScan; + } + + @JsonProperty + @Override + public List getSources() + { + return alternatives; + } + + @Override + public List getOutputSymbols() + { + // all alternatives must have the same output symbols but can theoretically differ on the order. + // any order would work here, so we pick the order from the first alternative. + // this is consistent with LocalExecutionPlanner.Visitor.visitChooseAlternativeNode + return alternatives.get(0).getOutputSymbols(); + } + + @Override + public PlanNode replaceChildren(List newChildren) + { + checkArgument(newChildren.size() == alternatives.size(), "expected newChildren to contain %s nodes", alternatives.size()); + return new ChooseAlternativeNode(getId(), newChildren, originalTableScan); + } + + @Override + public R accept(PlanVisitor visitor, C context) + { + return visitor.visitChooseAlternativeNode(this, context); + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + ChooseAlternativeNode chooseAlternativeNode = (ChooseAlternativeNode) o; + return Objects.equals(alternatives, chooseAlternativeNode.alternatives) && + Objects.equals(originalTableScan, chooseAlternativeNode.originalTableScan); + } + + @Override + public int hashCode() + { + return Objects.hash(alternatives, originalTableScan); + } + + @Override + public String toString() + { + return toStringHelper(this) + 
.add("alternatives", alternatives) + .add("originalTableScan", originalTableScan) + .toString(); + } + + public record FilteredTableScan(TableScanNode tableScanNode, Optional filterPredicate) + { + public FilteredTableScan + { + requireNonNull(tableScanNode, "tableScanNode is null"); + requireNonNull(filterPredicate, "filterPredicate is null"); + } + + public TableHandle tableHandle() + { + return tableScanNode.getTable(); + } + + public Map assignments() + { + return tableScanNode.getAssignments(); + } + } +} diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/plan/LoadCachedDataPlanNode.java b/core/trino-main/src/main/java/io/trino/sql/planner/plan/LoadCachedDataPlanNode.java new file mode 100644 index 000000000000..ab89dbda20a6 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/sql/planner/plan/LoadCachedDataPlanNode.java @@ -0,0 +1,105 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.sql.planner.plan; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.common.collect.ImmutableList; +import com.google.errorprone.annotations.Immutable; +import io.trino.cache.CommonPlanAdaptation.PlanSignatureWithPredicate; +import io.trino.spi.cache.CacheColumnId; +import io.trino.spi.connector.ColumnHandle; +import io.trino.sql.ir.Expression; +import io.trino.sql.planner.Symbol; + +import java.util.List; +import java.util.Map; + +import static java.util.Objects.requireNonNull; + +@Immutable +public class LoadCachedDataPlanNode + extends PlanNode +{ + private final PlanSignatureWithPredicate planSignature; + /** + * Dynamic filter disjuncts from all common subplans. + */ + private final Expression dynamicFilterDisjuncts; + private final Map commonColumnHandles; + private final List outputSymbols; + + @JsonCreator + public LoadCachedDataPlanNode( + @JsonProperty PlanNodeId id, + @JsonProperty PlanSignatureWithPredicate planSignature, + @JsonProperty Expression dynamicFilterDisjuncts, + @JsonProperty Map commonColumnHandles, + @JsonProperty List outputSymbols) + { + super(id); + this.planSignature = requireNonNull(planSignature, "planSignature is null"); + this.dynamicFilterDisjuncts = requireNonNull(dynamicFilterDisjuncts, "dynamicFilterDisjuncts is null"); + this.commonColumnHandles = requireNonNull(commonColumnHandles, "commonColumnHandles is null"); + this.outputSymbols = requireNonNull(outputSymbols, "outputSymbols is null"); + } + + @JsonProperty + public PlanSignatureWithPredicate getPlanSignature() + { + return planSignature; + } + + @JsonProperty + public Expression getDynamicFilterDisjuncts() + { + return dynamicFilterDisjuncts; + } + + @JsonProperty + public Map getCommonColumnHandles() + { + return commonColumnHandles; + } + + @Override + @JsonProperty + public List getOutputSymbols() + { + return outputSymbols; + } + + @Override + public List 
getSources() + { + return ImmutableList.of(); + } + + @Override + public R accept(PlanVisitor visitor, C context) + { + return visitor.visitLoadCachedDataPlanNode(this, context); + } + + @Override + public PlanNode replaceChildren(List newChildren) + { + return new LoadCachedDataPlanNode( + getId(), + planSignature, + dynamicFilterDisjuncts, + commonColumnHandles, + outputSymbols); + } +} diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/plan/Patterns.java b/core/trino-main/src/main/java/io/trino/sql/planner/plan/Patterns.java index e4ed765f9dd6..be1f93129129 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/plan/Patterns.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/plan/Patterns.java @@ -243,6 +243,11 @@ public static Pattern except() return typeOf(ExceptNode.class); } + public static Pattern chooseAlternative() + { + return typeOf(ChooseAlternativeNode.class); + } + public static Pattern remoteSourceNode() { return typeOf(RemoteSourceNode.class); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/plan/PlanNode.java b/core/trino-main/src/main/java/io/trino/sql/planner/plan/PlanNode.java index 258837d805d7..cca00f91d5b9 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/plan/PlanNode.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/plan/PlanNode.java @@ -62,6 +62,9 @@ @JsonSubTypes.Type(value = TableDeleteNode.class, name = "tableDelete"), @JsonSubTypes.Type(value = TableExecuteNode.class, name = "tableExecute"), @JsonSubTypes.Type(value = TableFinishNode.class, name = "tableCommit"), + @JsonSubTypes.Type(value = ChooseAlternativeNode.class, name = "chooseAlternative"), + @JsonSubTypes.Type(value = CacheDataPlanNode.class, name = "cacheData"), + @JsonSubTypes.Type(value = LoadCachedDataPlanNode.class, name = "loadCachedData"), @JsonSubTypes.Type(value = TableFunctionNode.class, name = "tableFunction"), @JsonSubTypes.Type(value = TableFunctionProcessorNode.class, name = 
"tableFunctionProcessor"), @JsonSubTypes.Type(value = TableScanNode.class, name = "tableScan"), diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/plan/PlanVisitor.java b/core/trino-main/src/main/java/io/trino/sql/planner/plan/PlanVisitor.java index 45ab769d0f3c..1eceee68d961 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/plan/PlanVisitor.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/plan/PlanVisitor.java @@ -249,6 +249,21 @@ public R visitPatternRecognition(PatternRecognitionNode node, C context) return visitPlan(node, context); } + public R visitChooseAlternativeNode(ChooseAlternativeNode node, C context) + { + return visitPlan(node, context); + } + + public R visitCacheDataPlanNode(CacheDataPlanNode node, C context) + { + return visitPlan(node, context); + } + + public R visitLoadCachedDataPlanNode(LoadCachedDataPlanNode node, C context) + { + return visitPlan(node, context); + } + public R visitTableFunction(TableFunctionNode node, C context) { return visitPlan(node, context); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/PlanPrinter.java b/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/PlanPrinter.java index 589e97e7fa7d..b7dacb350d53 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/PlanPrinter.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/PlanPrinter.java @@ -73,6 +73,8 @@ import io.trino.sql.planner.plan.ApplyNode; import io.trino.sql.planner.plan.AssignUniqueId; import io.trino.sql.planner.plan.Assignments; +import io.trino.sql.planner.plan.CacheDataPlanNode; +import io.trino.sql.planner.plan.ChooseAlternativeNode; import io.trino.sql.planner.plan.CorrelatedJoinNode; import io.trino.sql.planner.plan.DistinctLimitNode; import io.trino.sql.planner.plan.DynamicFilterId; @@ -89,6 +91,7 @@ import io.trino.sql.planner.plan.IntersectNode; import io.trino.sql.planner.plan.JoinNode; import 
io.trino.sql.planner.plan.LimitNode; +import io.trino.sql.planner.plan.LoadCachedDataPlanNode; import io.trino.sql.planner.plan.MarkDistinctNode; import io.trino.sql.planner.plan.MergeProcessorNode; import io.trino.sql.planner.plan.MergeWriterNode; @@ -177,6 +180,7 @@ import static io.trino.sql.planner.planprinter.TextRenderer.formatPositions; import static io.trino.sql.planner.planprinter.TextRenderer.indentString; import static java.lang.Math.abs; +import static java.lang.Math.min; import static java.lang.String.format; import static java.util.Arrays.stream; import static java.util.Objects.requireNonNull; @@ -197,6 +201,7 @@ public class PlanPrinter private final Map dynamicFilterDomainStats; private final ValuePrinter valuePrinter; private final Anonymizer anonymizer; + private final boolean verbose; // NOTE: do NOT add Metadata or Session to this class. The plan printer must be usable outside of a transaction. @VisibleForTesting @@ -207,7 +212,8 @@ public class PlanPrinter ValuePrinter valuePrinter, StatsAndCosts estimatedStatsAndCosts, Optional> stats, - Anonymizer anonymizer) + Anonymizer anonymizer, + boolean verbose) { requireNonNull(planRoot, "planRoot is null"); requireNonNull(tableInfoSupplier, "tableInfoSupplier is null"); @@ -221,6 +227,7 @@ public class PlanPrinter this.dynamicFilterDomainStats = ImmutableMap.copyOf(dynamicFilterDomainStats); this.valuePrinter = valuePrinter; this.anonymizer = anonymizer; + this.verbose = verbose; Optional totalScheduledTime = stats.map(s -> new Duration(s.values().stream() .mapToLong(planNode -> planNode.getPlanNodeScheduledTime().toMillis()) @@ -267,7 +274,8 @@ public static String jsonFragmentPlan(PlanNode root, Metadata metadata, Function valuePrinter, StatsAndCosts.empty(), Optional.empty(), - new NoOpAnonymizer()) + new NoOpAnonymizer(), + false) .toJson(); } @@ -287,7 +295,8 @@ public static String jsonLogicalPlan( valuePrinter, estimatedStatsAndCosts, Optional.empty(), - new NoOpAnonymizer()) + new 
NoOpAnonymizer(), + false) .toJson(); } @@ -345,7 +354,8 @@ private static String jsonDistributedPlan( valuePrinter, planFragment.getStatsAndCosts(), Optional.empty(), - anonymizer) + anonymizer, + false) .toJsonRenderedNode())); return DISTRIBUTED_PLAN_CODEC.toJson(anonymizedPlan); } @@ -383,7 +393,8 @@ public static String textLogicalPlan( valuePrinter, estimatedStatsAndCosts, Optional.empty(), - new NoOpAnonymizer()) + new NoOpAnonymizer(), + verbose) .toText(verbose, level)); return builder.toString(); } @@ -589,7 +600,8 @@ private static String formatFragment( valuePrinter, fragment.getStatsAndCosts(), planNodeStats, - anonymizer).toText(verbose, 1)) + anonymizer, + verbose).toText(verbose, 1)) .append("\n"); return builder.toString(); @@ -1940,6 +1952,45 @@ public Void visitTableFunctionProcessor(TableFunctionProcessorNode node, Context return processChildren(node, new Context(context.isInitialPlan())); } + @Override + public Void visitChooseAlternativeNode(ChooseAlternativeNode node, Context context) + { + List alternatives = node.getSources(); + addNode(node, "ChooseAlternativeNode", ImmutableMap.of("alternativesCount", String.valueOf(alternatives.size())), context); + Context childContext = new Context(context.isInitialPlan()); + if (stats.isEmpty()) { + // print no more than 10 alternatives for EXPLAIN ANALYZE + for (int i = 0; i < min(alternatives.size(), 10); i++) { + PlanNode child = alternatives.get(i); + child.accept(this, childContext); + } + } + else { + for (PlanNode child : alternatives) { + Optional childStats = stats.map(s -> s.get(child.getId())); + // print alternative if it was used or verbose + if (verbose || childStats.isPresent()) { + child.accept(this, childContext); + } + } + } + return null; + } + + @Override + public Void visitLoadCachedDataPlanNode(LoadCachedDataPlanNode node, Context context) + { + addNode(node, "LoadCachedData", context); + return null; + } + + @Override + public Void visitCacheDataPlanNode(CacheDataPlanNode node, 
Context context) + { + addNode(node, "CacheData", context); + return processChildren(node, new Context(context.isInitialPlan())); + } + @Override protected Void visitPlan(PlanNode node, Context context) { diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/sanity/PlanSanityChecker.java b/core/trino-main/src/main/java/io/trino/sql/planner/sanity/PlanSanityChecker.java index e1f5b8fdca11..d42ade570ff5 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/sanity/PlanSanityChecker.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/sanity/PlanSanityChecker.java @@ -55,6 +55,18 @@ public PlanSanityChecker(boolean forceSingleNode) new DynamicFiltersChecker(), new TableScanValidator(), new TableExecuteStructureValidator()) + .putAll( + Stage.AFTER_ALTERNATIVES_PLANNING, + new ValidateDependenciesChecker(), + new NoDuplicatePlanNodeIdsChecker(), + new TypeValidator(), + new VerifyOnlyOneOutputNode(), + new VerifyNoFilteredAggregations(), + new VerifyUseConnectorNodePartitioningSet(), + new ValidateScaledWritersUsage(), + new DynamicFiltersChecker(), + new TableScanValidator(), + new TableExecuteStructureValidator()) .putAll( Stage.AFTER_ADAPTIVE_PLANNING, new ValidateDependenciesChecker(), @@ -87,6 +99,15 @@ public void validateIntermediatePlan( validate(Stage.INTERMEDIATE, planNode, session, plannerContext, warningCollector); } + public void validatePlanWithAlternatives( + PlanNode planNode, + Session session, + PlannerContext plannerContext, + WarningCollector warningCollector) + { + validate(Stage.AFTER_ALTERNATIVES_PLANNING, planNode, session, plannerContext, warningCollector); + } + public void validateAdaptivePlan( PlanNode planNode, Session session, @@ -137,6 +158,6 @@ void validate( private enum Stage { - INTERMEDIATE, FINAL, AFTER_ADAPTIVE_PLANNING + INTERMEDIATE, FINAL, AFTER_ALTERNATIVES_PLANNING, AFTER_ADAPTIVE_PLANNING } } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/sanity/ValidateDependenciesChecker.java 
b/core/trino-main/src/main/java/io/trino/sql/planner/sanity/ValidateDependenciesChecker.java index e78e9b88611f..d3ab0127ee46 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/sanity/ValidateDependenciesChecker.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/sanity/ValidateDependenciesChecker.java @@ -25,6 +25,8 @@ import io.trino.sql.planner.plan.AggregationNode.Aggregation; import io.trino.sql.planner.plan.ApplyNode; import io.trino.sql.planner.plan.AssignUniqueId; +import io.trino.sql.planner.plan.CacheDataPlanNode; +import io.trino.sql.planner.plan.ChooseAlternativeNode; import io.trino.sql.planner.plan.CorrelatedJoinNode; import io.trino.sql.planner.plan.DistinctLimitNode; import io.trino.sql.planner.plan.DynamicFilterSourceNode; @@ -39,6 +41,7 @@ import io.trino.sql.planner.plan.IntersectNode; import io.trino.sql.planner.plan.JoinNode; import io.trino.sql.planner.plan.LimitNode; +import io.trino.sql.planner.plan.LoadCachedDataPlanNode; import io.trino.sql.planner.plan.MarkDistinctNode; import io.trino.sql.planner.plan.MergeProcessorNode; import io.trino.sql.planner.plan.MergeWriterNode; @@ -344,6 +347,31 @@ public Void visitTableFunctionProcessor(TableFunctionProcessorNode node, Set boundSymbols) + { + for (int i = 0; i < node.getSources().size(); i++) { + PlanNode source = node.getSources().get(i); + source.accept(this, boundSymbols); // visit child + } + + return null; + } + + @Override + public Void visitLoadCachedDataPlanNode(LoadCachedDataPlanNode node, Set boundSymbols) + { + return null; + } + + @Override + public Void visitCacheDataPlanNode(CacheDataPlanNode node, Set boundSymbols) + { + node.getSource().accept(this, boundSymbols); // visit child + + return null; + } + @Override public Void visitWindow(WindowNode node, Set boundSymbols) { diff --git a/core/trino-main/src/main/java/io/trino/testing/PlanTester.java b/core/trino-main/src/main/java/io/trino/testing/PlanTester.java index bfce4411861b..101d467c2a14 100644 --- 
a/core/trino-main/src/main/java/io/trino/testing/PlanTester.java +++ b/core/trino-main/src/main/java/io/trino/testing/PlanTester.java @@ -18,6 +18,9 @@ import com.google.common.collect.ImmutableSet; import com.google.common.io.Closer; import io.airlift.configuration.secrets.SecretsResolver; +import io.airlift.json.JsonCodec; +import io.airlift.json.JsonCodecFactory; +import io.airlift.json.ObjectMapperProvider; import io.airlift.node.NodeInfo; import io.airlift.units.Duration; import io.opentelemetry.api.trace.Span; @@ -25,6 +28,11 @@ import io.trino.FeaturesConfig; import io.trino.Session; import io.trino.SystemSessionProperties; +import io.trino.block.BlockJsonSerde; +import io.trino.cache.CacheConfig; +import io.trino.cache.CacheManagerRegistry; +import io.trino.cache.CacheMetadata; +import io.trino.cache.CacheStats; import io.trino.client.NodeVersion; import io.trino.connector.CatalogFactory; import io.trino.connector.CatalogServiceProviderModule; @@ -138,11 +146,15 @@ import io.trino.spi.PageIndexerFactory; import io.trino.spi.PageSorter; import io.trino.spi.Plugin; +import io.trino.spi.block.Block; +import io.trino.spi.block.BlockEncodingSerde; import io.trino.spi.catalog.CatalogName; import io.trino.spi.connector.CatalogHandle; import io.trino.spi.connector.Connector; import io.trino.spi.connector.ConnectorFactory; import io.trino.spi.connector.ConnectorName; +import io.trino.spi.predicate.TupleDomain; +import io.trino.spi.type.Type; import io.trino.spi.type.TypeManager; import io.trino.spi.type.TypeOperators; import io.trino.spiller.GenericSpillerFactory; @@ -226,6 +238,7 @@ import static io.opentelemetry.api.OpenTelemetry.noop; import static io.trino.connector.CatalogServiceProviderModule.createAccessControlProvider; import static io.trino.connector.CatalogServiceProviderModule.createAnalyzePropertyManager; +import static io.trino.connector.CatalogServiceProviderModule.createCacheMetadata; import static 
io.trino.connector.CatalogServiceProviderModule.createColumnPropertyManager; import static io.trino.connector.CatalogServiceProviderModule.createFunctionProvider; import static io.trino.connector.CatalogServiceProviderModule.createIndexProvider; @@ -303,6 +316,9 @@ public class PlanTester private final PluginManager pluginManager; private final ExchangeManagerRegistry exchangeManagerRegistry; private final SpoolingManagerRegistry spoolingManagerRegistry; + private final CacheManagerRegistry cacheManagerRegistry; + private final JsonCodec tupleDomainCodec; + private final TaskManagerConfig taskManagerConfig; private final OptimizerConfig optimizerConfig; private final StatementAnalyzerFactory statementAnalyzerFactory; @@ -310,15 +326,20 @@ public class PlanTester public static PlanTester create(Session defaultSession) { - return new PlanTester(defaultSession, 1); + return create(defaultSession, 1); } public static PlanTester create(Session defaultSession, int nodeCountForStats) { - return new PlanTester(defaultSession, nodeCountForStats); + return create(defaultSession, nodeCountForStats, new CacheConfig()); + } + + public static PlanTester create(Session defaultSession, int nodeCountForStats, CacheConfig cacheConfig) + { + return new PlanTester(defaultSession, nodeCountForStats, cacheConfig); } - private PlanTester(Session defaultSession, int nodeCountForStats) + private PlanTester(Session defaultSession, int nodeCountForStats, CacheConfig cacheConfig) { requireNonNull(defaultSession, "defaultSession is null"); @@ -397,7 +418,7 @@ private PlanTester(Session defaultSession, int nodeCountForStats) this.pageSinkManager = new PageSinkManager(createPageSinkProvider(catalogManager)); this.indexManager = new IndexManager(createIndexProvider(catalogManager)); NodeScheduler nodeScheduler = new NodeScheduler(new UniformNodeSelectorFactory(nodeManager, nodeSchedulerConfig, new NodeTaskMap(finalizerService))); - this.sessionPropertyManager = 
createSessionPropertyManager(catalogManager, taskManagerConfig, optimizerConfig); + this.sessionPropertyManager = createSessionPropertyManager(catalogManager, taskManagerConfig, cacheConfig, optimizerConfig); this.nodePartitioningManager = new NodePartitioningManager(nodeScheduler, typeOperators, createNodePartitioningProvider(catalogManager)); TableProceduresRegistry tableProceduresRegistry = new TableProceduresRegistry(createTableProceduresProvider(catalogManager)); FunctionManager functionManager = new FunctionManager(createFunctionProvider(catalogManager), globalFunctionCatalog, languageFunctionManager); @@ -416,7 +437,8 @@ private PlanTester(Session defaultSession, int nodeCountForStats) new JsonValueFunction(functionManager, metadata, typeManager), new JsonQueryFunction(functionManager, metadata, typeManager))); - this.plannerContext = new PlannerContext(metadata, typeOperators, blockEncodingSerde, typeManager, functionManager, languageFunctionManager, tracer); + CacheMetadata cacheMetadata = new CacheMetadata(createCacheMetadata(catalogManager)); + this.plannerContext = new PlannerContext(metadata, cacheMetadata, typeOperators, blockEncodingSerde, typeManager, functionManager, languageFunctionManager, tracer); this.pageFunctionCompiler = new PageFunctionCompiler(functionManager, 0); this.filterCompiler = new ColumnarFilterCompiler(functionManager, 0); this.expressionCompiler = new ExpressionCompiler(functionManager, pageFunctionCompiler, filterCompiler); @@ -458,6 +480,8 @@ private PlanTester(Session defaultSession, int nodeCountForStats) exchangeManagerRegistry = new ExchangeManagerRegistry(noop(), noopTracer(), secretsResolver); spoolingManagerRegistry = new SpoolingManagerRegistry(new ServerConfig(), new SpoolingEnabledConfig(), noop(), noopTracer()); + cacheManagerRegistry = new CacheManagerRegistry(cacheConfig, new LocalMemoryManager(new NodeMemoryConfig()), plannerContext.getBlockEncodingSerde(), new CacheStats()); + tupleDomainCodec = 
getTupleDomainJsonCodec(blockEncodingSerde, typeManager); this.pluginManager = new PluginManager( (loader, createClassLoader) -> {}, Optional.empty(), @@ -475,7 +499,8 @@ private PlanTester(Session defaultSession, int nodeCountForStats) blockEncodingManager, new HandleResolver(), exchangeManagerRegistry, - spoolingManagerRegistry); + spoolingManagerRegistry, + cacheManagerRegistry); catalogManager.registerGlobalSystemConnector(globalSystemConnector); languageFunctionManager.setPlannerContext(plannerContext); @@ -511,9 +536,21 @@ private PlanTester(Session defaultSession, int nodeCountForStats) defaultSession.getQueryDataEncoding()); } + public static JsonCodec getTupleDomainJsonCodec(BlockEncodingSerde blockEncodingSerde, TypeManager typeManager) + { + ObjectMapperProvider objectMapperProvider = new ObjectMapperProvider(); + objectMapperProvider.setJsonDeserializers(ImmutableMap.of( + Block.class, new BlockJsonSerde.Deserializer(blockEncodingSerde), + Type.class, new TypeDeserializer(typeManager))); + objectMapperProvider.setJsonSerializers(ImmutableMap.of( + Block.class, new BlockJsonSerde.Serializer(blockEncodingSerde))); + return new JsonCodecFactory(objectMapperProvider).jsonCodec(TupleDomain.class); + } + private static SessionPropertyManager createSessionPropertyManager( ConnectorServicesProvider connectorServicesProvider, TaskManagerConfig taskManagerConfig, + CacheConfig cacheConfig, OptimizerConfig optimizerConfig) { SystemSessionProperties sessionProperties = new SystemSessionProperties( @@ -524,6 +561,7 @@ private static SessionPropertyManager createSessionPropertyManager( optimizerConfig, new NodeMemoryConfig(), new DynamicFilterConfig(), + cacheConfig, new NodeSchedulerConfig()); return CatalogServiceProviderModule.createSessionPropertyManager(ImmutableSet.of(sessionProperties), connectorServicesProvider); } @@ -735,6 +773,9 @@ private List createDrivers(Session session, @Language("SQL") String sql) plannerContext, Optional.empty(), pageSourceManager, 
+ cacheManagerRegistry, + tupleDomainCodec, + new CacheStats(), indexManager, nodePartitioningManager, pageSinkManager, @@ -803,7 +844,7 @@ private List createDrivers(Session session, @Language("SQL") String sql) } else { DriverContext driverContext = taskContext.addPipelineContext(driverFactory.getPipelineId(), driverFactory.isInputDriver(), driverFactory.isOutputDriver(), false).addDriverContext(); - Driver driver = driverFactory.createDriver(driverContext); + Driver driver = driverFactory.createDriver(driverContext, Optional.empty()); drivers.add(driver); } } @@ -817,7 +858,7 @@ private List createDrivers(Session session, @Language("SQL") String sql) boolean partitioned = partitionedSources.contains(driverFactory.getSourceId().orElseThrow()); for (ScheduledSplit split : splitAssignment.getSplits()) { DriverContext driverContext = taskContext.addPipelineContext(driverFactory.getPipelineId(), driverFactory.isInputDriver(), driverFactory.isOutputDriver(), partitioned).addDriverContext(); - Driver driver = driverFactory.createDriver(driverContext); + Driver driver = driverFactory.createDriver(driverContext, Optional.of(split)); driver.updateSplitAssignment(new SplitAssignment(split.getPlanNodeId(), ImmutableSet.of(split), true)); drivers.add(driver); } diff --git a/core/trino-main/src/main/java/io/trino/testing/QueryRunner.java b/core/trino-main/src/main/java/io/trino/testing/QueryRunner.java index d3e92ee9ef59..2fc12765ea78 100644 --- a/core/trino-main/src/main/java/io/trino/testing/QueryRunner.java +++ b/core/trino-main/src/main/java/io/trino/testing/QueryRunner.java @@ -16,6 +16,7 @@ import com.google.common.collect.ImmutableMap; import io.opentelemetry.sdk.trace.data.SpanData; import io.trino.Session; +import io.trino.cache.CacheMetadata; import io.trino.cost.StatsCalculator; import io.trino.execution.FailureInjector.InjectedFailureType; import io.trino.metadata.FunctionBundle; @@ -57,6 +58,8 @@ public interface QueryRunner TransactionManager 
getTransactionManager(); + CacheMetadata getCacheMetadata(); + PlannerContext getPlannerContext(); QueryExplainer getQueryExplainer(); diff --git a/core/trino-main/src/main/java/io/trino/testing/StandaloneQueryRunner.java b/core/trino-main/src/main/java/io/trino/testing/StandaloneQueryRunner.java index 1bbeda0f0a2b..9aa911236474 100644 --- a/core/trino-main/src/main/java/io/trino/testing/StandaloneQueryRunner.java +++ b/core/trino-main/src/main/java/io/trino/testing/StandaloneQueryRunner.java @@ -20,6 +20,7 @@ import io.opentelemetry.sdk.trace.data.SpanData; import io.opentelemetry.sdk.trace.export.SimpleSpanProcessor; import io.trino.Session; +import io.trino.cache.CacheMetadata; import io.trino.cost.StatsCalculator; import io.trino.execution.FailureInjector.InjectedFailureType; import io.trino.execution.warnings.WarningCollector; @@ -189,6 +190,12 @@ public PlannerContext getPlannerContext() return server.getPlannerContext(); } + @Override + public CacheMetadata getCacheMetadata() + { + return server.getCacheMetadata(); + } + @Override public QueryExplainer getQueryExplainer() { diff --git a/core/trino-main/src/main/java/io/trino/testing/TestingHandles.java b/core/trino-main/src/main/java/io/trino/testing/TestingHandles.java index 31ca8cfe32aa..a045ea83af45 100644 --- a/core/trino-main/src/main/java/io/trino/testing/TestingHandles.java +++ b/core/trino-main/src/main/java/io/trino/testing/TestingHandles.java @@ -17,6 +17,7 @@ import io.trino.spi.catalog.CatalogName; import io.trino.spi.connector.CatalogHandle; import io.trino.spi.connector.CatalogHandle.CatalogVersion; +import io.trino.spi.connector.SchemaTableName; import io.trino.testing.TestingMetadata.TestingTableHandle; import static io.trino.spi.connector.CatalogHandle.createRootCatalogHandle; @@ -37,4 +38,9 @@ public static CatalogHandle createTestCatalogHandle(String catalogName) { return createRootCatalogHandle(new CatalogName(catalogName), TEST_CATALOG_VERSION); } + + public static TableHandle 
createTestTableHandle(SchemaTableName schemaTableName) + { + return new TableHandle(TEST_CATALOG_HANDLE, new TestingTableHandle(schemaTableName), TestingTransactionHandle.create()); + } } diff --git a/core/trino-main/src/test/java/io/trino/cache/TestCacheCommonSubqueries.java b/core/trino-main/src/test/java/io/trino/cache/TestCacheCommonSubqueries.java new file mode 100644 index 000000000000..fbd1f5680ab1 --- /dev/null +++ b/core/trino-main/src/test/java/io/trino/cache/TestCacheCommonSubqueries.java @@ -0,0 +1,473 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.cache; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.trino.Session; +import io.trino.cache.CommonPlanAdaptation.PlanSignatureWithPredicate; +import io.trino.metadata.ResolvedFunction; +import io.trino.metadata.TestingFunctionResolution; +import io.trino.plugin.tpch.TpchColumnHandle; +import io.trino.plugin.tpch.TpchConnectorFactory; +import io.trino.spi.cache.CacheColumnId; +import io.trino.spi.cache.CacheTableId; +import io.trino.spi.cache.PlanSignature; +import io.trino.spi.connector.ColumnHandle; +import io.trino.spi.function.OperatorType; +import io.trino.spi.predicate.Domain; +import io.trino.spi.predicate.TupleDomain; +import io.trino.spi.predicate.ValueSet; +import io.trino.spi.type.RowType; +import io.trino.spi.type.Type; +import io.trino.sql.DynamicFilters; +import io.trino.sql.analyzer.TypeSignatureProvider; +import io.trino.sql.ir.Call; +import io.trino.sql.ir.Comparison; +import io.trino.sql.ir.Constant; +import io.trino.sql.ir.Expression; +import io.trino.sql.ir.Logical; +import io.trino.sql.ir.Reference; +import io.trino.sql.planner.assertions.BasePlanTest; +import io.trino.sql.planner.plan.FilterNode; +import io.trino.sql.planner.plan.JoinNode; +import io.trino.sql.planner.plan.ValuesNode; +import io.trino.testing.PlanTester; +import org.junit.jupiter.api.Test; + +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.function.Predicate; + +import static io.trino.SystemSessionProperties.JOIN_REORDERING_STRATEGY; +import static io.trino.cache.CanonicalSubplanExtractor.canonicalAggregationToColumnId; +import static io.trino.cache.CanonicalSubplanExtractor.canonicalExpressionToColumnId; +import static io.trino.cache.CanonicalSubplanExtractor.columnIdToSymbol; +import static io.trino.cache.CommonSubqueriesExtractor.aggregationKey; +import static io.trino.cache.CommonSubqueriesExtractor.combine; +import static 
io.trino.cache.CommonSubqueriesExtractor.filterProjectKey; +import static io.trino.cache.CommonSubqueriesExtractor.scanFilterProjectKey; +import static io.trino.spi.predicate.Range.greaterThan; +import static io.trino.spi.predicate.Range.lessThan; +import static io.trino.spi.type.BigintType.BIGINT; +import static io.trino.spi.type.DoubleType.DOUBLE; +import static io.trino.sql.ir.Booleans.TRUE; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; +import static io.trino.sql.ir.Comparison.Operator.LESS_THAN; +import static io.trino.sql.ir.Logical.Operator.OR; +import static io.trino.sql.planner.assertions.PlanMatchPattern.aggregation; +import static io.trino.sql.planner.assertions.PlanMatchPattern.aggregationFunction; +import static io.trino.sql.planner.assertions.PlanMatchPattern.anyTree; +import static io.trino.sql.planner.assertions.PlanMatchPattern.cacheDataPlanNode; +import static io.trino.sql.planner.assertions.PlanMatchPattern.chooseAlternativeNode; +import static io.trino.sql.planner.assertions.PlanMatchPattern.exchange; +import static io.trino.sql.planner.assertions.PlanMatchPattern.expression; +import static io.trino.sql.planner.assertions.PlanMatchPattern.filter; +import static io.trino.sql.planner.assertions.PlanMatchPattern.loadCachedDataPlanNode; +import static io.trino.sql.planner.assertions.PlanMatchPattern.node; +import static io.trino.sql.planner.assertions.PlanMatchPattern.singleGroupingSet; +import static io.trino.sql.planner.assertions.PlanMatchPattern.strictProject; +import static io.trino.sql.planner.assertions.PlanMatchPattern.symbol; +import static io.trino.sql.planner.assertions.PlanMatchPattern.tableScan; +import static io.trino.sql.planner.plan.AggregationNode.Step.FINAL; +import static io.trino.sql.planner.plan.AggregationNode.Step.PARTIAL; +import static io.trino.sql.planner.plan.ExchangeNode.Scope.LOCAL; +import static io.trino.testing.TestingHandles.TEST_CATALOG_NAME; +import static 
io.trino.testing.TestingSession.testSessionBuilder; + +public class TestCacheCommonSubqueries + extends BasePlanTest +{ + private static final Session TEST_SESSION = testSessionBuilder() + .setCatalog(TEST_CATALOG_NAME) + .setSchema("tiny") + // disable so join order is not changed in tests + .setSystemProperty(JOIN_REORDERING_STRATEGY, "none") + .build(); + private static final CacheColumnId NATIONKEY_COLUMN_ID = new CacheColumnId("[nationkey:bigint]"); + private static final CacheColumnId REGIONKEY_COLUMN_ID = new CacheColumnId("[regionkey:bigint]"); + private static final TestingFunctionResolution FUNCTIONS = new TestingFunctionResolution(); + private static final ResolvedFunction MULTIPLY_BIGINT = FUNCTIONS.resolveOperator(OperatorType.MULTIPLY, ImmutableList.of(BIGINT, BIGINT)); + + private String testCatalogId; + + @Override + protected PlanTester createPlanTester() + { + PlanTester planTester = PlanTester.create( + TEST_SESSION, + 1, + new CacheConfig() + .setEnabled(true) + .setCacheCommonSubqueriesEnabled(true)); + + planTester.createCatalog(planTester.getDefaultSession().getCatalog().get(), + new TpchConnectorFactory(1, false), + ImmutableMap.of()); + testCatalogId = planTester.getCatalogHandle(TEST_SESSION.getCatalog().orElseThrow()).getId(); + + return planTester; + } + + @Test + public void testCacheCommonSubqueries() + { + PlanSignatureWithPredicate signature = new PlanSignatureWithPredicate( + new PlanSignature( + scanFilterProjectKey(new CacheTableId(testCatalogId + ":tiny:nation:0.01")), + Optional.empty(), + ImmutableList.of(REGIONKEY_COLUMN_ID, NATIONKEY_COLUMN_ID), + ImmutableList.of(BIGINT, BIGINT)), + TupleDomain.withColumnDomains(ImmutableMap.of( + NATIONKEY_COLUMN_ID, Domain.create(ValueSet.ofRanges(lessThan(BIGINT, 5L), greaterThan(BIGINT, 10L)), false)))); + Map columnHandles = ImmutableMap.of( + NATIONKEY_COLUMN_ID, new TpchColumnHandle("nationkey", BIGINT), + REGIONKEY_COLUMN_ID, new TpchColumnHandle("regionkey", BIGINT)); + 
assertPlan(""" + SELECT * FROM + (SELECT regionkey FROM nation WHERE nationkey > 10) + UNION ALL + (SELECT regionkey FROM nation WHERE nationkey < 5) + """, + anyTree(exchange(LOCAL, + chooseAlternativeNode( + // original subplan + strictProject(ImmutableMap.of("REGIONKEY_A", expression(new Reference(BIGINT, "REGIONKEY_A"))), + filter( + new Comparison(GREATER_THAN, new Reference(BIGINT, "NATIONKEY_A"), new Constant(BIGINT, 10L)), + tableScan("nation", ImmutableMap.of("NATIONKEY_A", "nationkey", "REGIONKEY_A", "regionkey")))), + // store data in cache alternative + strictProject(ImmutableMap.of("REGIONKEY_A", expression(new Reference(BIGINT, "REGIONKEY_A"))), + filter( + new Comparison(GREATER_THAN, new Reference(BIGINT, "NATIONKEY_A"), new Constant(BIGINT, 10L)), + cacheDataPlanNode( + strictProject(ImmutableMap.of("REGIONKEY_A", expression(new Reference(BIGINT, "REGIONKEY_A")), "NATIONKEY_A", expression(new Reference(BIGINT, "NATIONKEY_A"))), + filter( + new Logical(OR, ImmutableList.of( + new Comparison(GREATER_THAN, new Reference(BIGINT, "NATIONKEY_A"), new Constant(BIGINT, 10L)), + new Comparison(LESS_THAN, new Reference(BIGINT, "NATIONKEY_A"), new Constant(BIGINT, 5L)))), + tableScan("nation", ImmutableMap.of("NATIONKEY_A", "nationkey", "REGIONKEY_A", "regionkey"))))))), + // load data from cache alternative + strictProject(ImmutableMap.of("REGIONKEY_A", expression(new Reference(BIGINT, "REGIONKEY_A"))), + filter( + new Comparison(GREATER_THAN, new Reference(BIGINT, "NATIONKEY_A"), new Constant(BIGINT, 10L)), + loadCachedDataPlanNode(signature, columnHandles, "REGIONKEY_A", "NATIONKEY_A")))), + chooseAlternativeNode( + // original subplan + strictProject(ImmutableMap.of("REGIONKEY_B", expression(new Reference(BIGINT, "REGIONKEY_B"))), + filter( + new Comparison(LESS_THAN, new Reference(BIGINT, "NATIONKEY_B"), new Constant(BIGINT, 5L)), + tableScan("nation", ImmutableMap.of("NATIONKEY_B", "nationkey", "REGIONKEY_B", "regionkey")))), + // store data in cache 
alternative + strictProject(ImmutableMap.of("REGIONKEY_B", expression(new Reference(BIGINT, "REGIONKEY_B"))), + filter( + new Comparison(LESS_THAN, new Reference(BIGINT, "NATIONKEY_B"), new Constant(BIGINT, 5L)), + cacheDataPlanNode( + strictProject(ImmutableMap.of("REGIONKEY_B", expression(new Reference(BIGINT, "REGIONKEY_B")), "NATIONKEY_B", expression(new Reference(BIGINT, "NATIONKEY_B"))), + filter( + new Logical(OR, ImmutableList.of( + new Comparison(GREATER_THAN, new Reference(BIGINT, "NATIONKEY_B"), new Constant(BIGINT, 10L)), + new Comparison(LESS_THAN, new Reference(BIGINT, "NATIONKEY_B"), new Constant(BIGINT, 5L)))), + tableScan("nation", ImmutableMap.of("NATIONKEY_B", "nationkey", "REGIONKEY_B", "regionkey"))))))), + // load data from cache alternative + strictProject(ImmutableMap.of("REGIONKEY_B", expression(new Reference(BIGINT, "REGIONKEY_B"))), + filter( + new Comparison(LESS_THAN, new Reference(BIGINT, "NATIONKEY_B"), new Constant(BIGINT, 5L)), + loadCachedDataPlanNode(signature, columnHandles, "REGIONKEY_B", "NATIONKEY_B"))))))); + } + + @Test + public void testJoinQuery() + { + List cacheColumnIds = ImmutableList.of(NATIONKEY_COLUMN_ID, REGIONKEY_COLUMN_ID); + List cacheColumnTypes = ImmutableList.of(BIGINT, BIGINT); + PlanSignatureWithPredicate signature = new PlanSignatureWithPredicate( + new PlanSignature( + scanFilterProjectKey(new CacheTableId(testCatalogId + ":tiny:nation:0.01")), + Optional.empty(), + cacheColumnIds, + cacheColumnTypes), + TupleDomain.all()); + Predicate isNationKeyDynamicFilter = node -> DynamicFilters.getDescriptor(node.getPredicate()) + .map(descriptor -> descriptor.getInput().equals(new Reference(BIGINT, "nationkey"))) + .orElse(false); + assertPlan(""" + SELECT * FROM + (SELECT nationkey FROM nation) + JOIN + (SELECT regionkey FROM nation) + ON nationkey = regionkey + """, + anyTree(node(JoinNode.class, + chooseAlternativeNode( + // original subplan + filter(TRUE, // for DF on nationkey + tableScan("nation", 
ImmutableMap.of("NATIONKEY", "nationkey"))) + .with(FilterNode.class, isNationKeyDynamicFilter), + // store data in cache alternative + strictProject(ImmutableMap.of("NATIONKEY", expression(new Reference(BIGINT, "NATIONKEY"))), + cacheDataPlanNode( + filter(TRUE, // for DF on nationkey + tableScan("nation", ImmutableMap.of("NATIONKEY", "nationkey", "REGIONKEY", "regionkey"))) + .with(FilterNode.class, isNationKeyDynamicFilter))), + // load data from cache alternative + strictProject(ImmutableMap.of("NATIONKEY", expression(new Reference(BIGINT, "NATIONKEY"))), + loadCachedDataPlanNode( + signature, + dfDisjuncts -> dfDisjuncts.size() == 1, + "NATIONKEY", "REGIONKEY"))), + anyTree( + chooseAlternativeNode( + // original subplan + tableScan("nation", ImmutableMap.of("REGIONKEY", "regionkey")), + // store data in cache alternative + strictProject(ImmutableMap.of("REGIONKEY", expression(new Reference(BIGINT, "REGIONKEY"))), + cacheDataPlanNode( + tableScan("nation", ImmutableMap.of("NATIONKEY", "nationkey", "REGIONKEY", "regionkey")))), + // load data from cache alternative + strictProject(ImmutableMap.of("REGIONKEY", expression(new Reference(BIGINT, "REGIONKEY"))), + loadCachedDataPlanNode(signature, "NATIONKEY", "REGIONKEY"))))))); + } + + @Test + public void testJoinQueryWithCommonDynamicFilters() + { + List cacheColumnIds = ImmutableList.of(NATIONKEY_COLUMN_ID, REGIONKEY_COLUMN_ID); + List cacheColumnTypes = ImmutableList.of(BIGINT, BIGINT); + PlanSignatureWithPredicate signature = new PlanSignatureWithPredicate( + new PlanSignature( + combine( + scanFilterProjectKey(new CacheTableId(testCatalogId + ":tiny:nation:0.01")), + "filters=((\"[nationkey:bigint]\" BETWEEN bigint '0' AND bigint '1') OR (\"[regionkey:bigint]\" BETWEEN bigint '0' AND bigint '1'))"), + Optional.empty(), + cacheColumnIds, + cacheColumnTypes), + TupleDomain.all()); + Map columnHandles = ImmutableMap.of( + NATIONKEY_COLUMN_ID, new TpchColumnHandle("nationkey", BIGINT), + REGIONKEY_COLUMN_ID, new 
TpchColumnHandle("regionkey", BIGINT)); + assertPlan(""" + (SELECT nationkey FROM nation n JOIN (SELECT * FROM (VALUES 0, 1) t(a)) t ON n.nationkey = t.a) + UNION ALL + (SELECT regionkey FROM nation n JOIN (SELECT * FROM (VALUES 0, 1) t(a)) t ON n.regionkey = t.a) + """, + anyTree(exchange(LOCAL, + node(JoinNode.class, + chooseAlternativeNode( + anyTree(tableScan("nation")), + anyTree(cacheDataPlanNode( + anyTree(tableScan("nation")))), + anyTree(loadCachedDataPlanNode(signature, columnHandles, dfDisjuncts -> dfDisjuncts.size() == 2, "NATIONKEY", "REGIONKEY"))), + anyTree(node(ValuesNode.class))), + node(JoinNode.class, + chooseAlternativeNode( + anyTree(tableScan("nation")), + anyTree(cacheDataPlanNode( + anyTree(tableScan("nation")))), + anyTree(loadCachedDataPlanNode(signature, columnHandles, dfDisjuncts -> dfDisjuncts.size() == 2, "NATIONKEY", "REGIONKEY"))), + anyTree(node(ValuesNode.class)))))); + } + + @Test + public void testCommonProjectionOnDifferentLevels() + { + Expression mul2 = new Call(MULTIPLY_BIGINT, ImmutableList.of(new Reference(BIGINT, "NATIONKEY"), new Constant(BIGINT, 2L))); + Expression mul4 = new Call(MULTIPLY_BIGINT, ImmutableList.of(new Reference(BIGINT, "NATIONKEY"), new Constant(BIGINT, 4L))); + Expression aMul2Mul2 = new Call(MULTIPLY_BIGINT, ImmutableList.of(new Reference(BIGINT, "MUL2_A"), new Reference(BIGINT, "MUL2_A"))); + Expression aMul4Mul4 = new Call(MULTIPLY_BIGINT, ImmutableList.of(new Reference(BIGINT, "MUL4_A"), new Reference(BIGINT, "MUL4_A"))); + Expression bMul2Mul2 = new Call(MULTIPLY_BIGINT, ImmutableList.of(new Reference(BIGINT, "MUL2_B_NEW"), new Reference(BIGINT, "MUL2_B_NEW"))); + Expression bMul4Mul4 = new Call(MULTIPLY_BIGINT, ImmutableList.of(new Reference(BIGINT, "MUL4_B"), new Reference(BIGINT, "MUL4_B"))); + + Reference canonicalMul2 = columnIdToSymbol(canonicalExpressionToColumnId(new Call(MULTIPLY_BIGINT, ImmutableList.of(new Reference(BIGINT, "[nationkey:bigint]"), new Constant(BIGINT, 2L)))), 
BIGINT).toSymbolReference(); + Reference canonicalMul4 = columnIdToSymbol(canonicalExpressionToColumnId(new Call(MULTIPLY_BIGINT, ImmutableList.of(new Reference(BIGINT, "[nationkey:bigint]"), new Constant(BIGINT, 4L)))), BIGINT).toSymbolReference(); + Expression canonicalMul2Mul2 = new Call(MULTIPLY_BIGINT, ImmutableList.of(canonicalMul2, canonicalMul2)); + Expression canonicalMul4Mul4 = new Call(MULTIPLY_BIGINT, ImmutableList.of(canonicalMul4, canonicalMul4)); + PlanSignatureWithPredicate signature = new PlanSignatureWithPredicate( + new PlanSignature( + filterProjectKey(scanFilterProjectKey(new CacheTableId(testCatalogId + ":tiny:nation:0.01"))), + Optional.empty(), + ImmutableList.of( + canonicalExpressionToColumnId(canonicalMul2), + canonicalExpressionToColumnId(canonicalMul2Mul2), + canonicalExpressionToColumnId(canonicalMul4Mul4)), + ImmutableList.of(BIGINT, BIGINT, BIGINT)), + TupleDomain.all()); + + // Make sure that plan with ambiguous projections is correctly planned and validated + assertPlan(""" + SELECT nationkey_mul, nationkey_mul * nationkey_mul FROM (SELECT nationkey * 2 AS nationkey_mul FROM nation) + UNION ALL + SELECT nationkey * 2, nationkey_mul * nationkey_mul FROM (SELECT nationkey, nationkey * 4 AS nationkey_mul FROM nation)""", + anyTree(exchange(LOCAL, + chooseAlternativeNode( + strictProject(ImmutableMap.of( + "MUL2_A", expression(new Reference(BIGINT, "MUL2_A")), + "MUL2_MUL2_A", expression(aMul2Mul2)), + strictProject(ImmutableMap.of("MUL2_A", expression(mul2)), + tableScan("nation", ImmutableMap.of("NATIONKEY", "nationkey")))), + strictProject(ImmutableMap.of( + "MUL2_A", expression(new Reference(BIGINT, "MUL2_A")), + "MUL2_MUL2_A", expression(new Reference(BIGINT, "MUL2_MUL2_A"))), + cacheDataPlanNode( + strictProject(ImmutableMap.of( + "MUL2_A", expression(new Reference(BIGINT, "MUL2_A")), + "MUL2_MUL2_A", expression(aMul2Mul2), + "MUL4_MUL4_A", expression(aMul4Mul4)), + strictProject(ImmutableMap.of( + "MUL2_A", expression(mul2), + 
"MUL4_A", expression(mul4), + "NATIONKEY", expression(new Reference(BIGINT, "NATIONKEY"))), + tableScan("nation", ImmutableMap.of("NATIONKEY", "nationkey")))))), + strictProject(ImmutableMap.of( + "MUL2_A", expression(new Reference(BIGINT, "MUL2_A")), + "MUL2_MUL2_A", expression(new Reference(BIGINT, "MUL2_MUL2_A"))), + loadCachedDataPlanNode( + signature, + "MUL2_A", "MUL2_MUL2_A", "MUL4_MUL4_A"))), + chooseAlternativeNode( + strictProject(ImmutableMap.of( + "MUL2_B", expression(mul2), + "MUL4_MUL4_B", expression(bMul4Mul4)), + strictProject(ImmutableMap.of( + "NATIONKEY", expression(new Reference(BIGINT, "NATIONKEY")), + "MUL4_B", expression(mul4)), + tableScan("nation", ImmutableMap.of("NATIONKEY", "nationkey")))), + strictProject(ImmutableMap.of( + "MUL2_B", expression(new Reference(BIGINT, "MUL2_B_NEW")), + "MUL4_MUL4_B", expression(new Reference(BIGINT, "MUL4_MUL4_B"))), + cacheDataPlanNode( + strictProject(ImmutableMap.of( + "MUL2_B_NEW", expression(new Reference(BIGINT, "MUL2_B_NEW")), + "MUL2_MUL2_B", expression(bMul2Mul2), + "MUL4_MUL4_B", expression(bMul4Mul4)), + strictProject(ImmutableMap.of( + "MUL2_B_NEW", expression(mul2), + "MUL4_B", expression(mul4), + "NATIONKEY", expression(new Reference(BIGINT, "NATIONKEY"))), + tableScan("nation", ImmutableMap.of("NATIONKEY", "nationkey")))))), + strictProject(ImmutableMap.of( + "MUL2_B", expression(new Reference(BIGINT, "MUL2_B_NEW")), + "MUL4_MUL4_B", expression(new Reference(BIGINT, "MUL4_MUL4_B"))), + loadCachedDataPlanNode( + signature, + "MUL2_B_NEW", "MUL2_MUL2_B", "MUL4_MUL4_B")))))); + } + + @Test + public void testAggregationQuery() + { + Reference nationkey = new Reference(BIGINT, "[nationkey:bigint]"); + CanonicalAggregation max = canonicalAggregation("max", nationkey); + CanonicalAggregation sum = canonicalAggregation("sum", nationkey); + CanonicalAggregation avg = canonicalAggregation("avg", nationkey); + List cacheColumnIds = ImmutableList.of(REGIONKEY_COLUMN_ID, 
canonicalAggregationToColumnId(sum), canonicalAggregationToColumnId(max), canonicalAggregationToColumnId(avg)); + List cacheColumnTypes = ImmutableList.of(BIGINT, BIGINT, BIGINT, RowType.anonymousRow(DOUBLE, BIGINT)); + PlanSignatureWithPredicate signature = new PlanSignatureWithPredicate( + new PlanSignature( + aggregationKey(scanFilterProjectKey(new CacheTableId(testCatalogId + ":tiny:nation:0.01"))), + Optional.of(ImmutableList.of(REGIONKEY_COLUMN_ID)), + cacheColumnIds, + cacheColumnTypes), + TupleDomain.all()); + assertPlan(""" + SELECT sum(nationkey), max(nationkey) FROM nation GROUP BY regionkey + UNION ALL + SELECT avg(nationkey), sum(nationkey) FROM nation GROUP BY regionkey""", + anyTree(anyTree(aggregation( + singleGroupingSet("REGIONKEY_A"), + ImmutableMap.of( + Optional.of("SUM_A"), aggregationFunction("sum", false, ImmutableList.of(symbol("SUM_PARTIAL_A"))), + Optional.of("MAX_A"), aggregationFunction("max", false, ImmutableList.of(symbol("MAX_PARTIAL_A")))), + Optional.empty(), + FINAL, + anyTree( + chooseAlternativeNode( + // original subplan + aggregation( + singleGroupingSet("REGIONKEY_A"), + ImmutableMap.of( + Optional.of("MAX_PARTIAL_A"), aggregationFunction("max", false, ImmutableList.of(symbol("NATIONKEY_A"))), + Optional.of("SUM_PARTIAL_A"), aggregationFunction("sum", false, ImmutableList.of(symbol("NATIONKEY_A")))), + Optional.empty(), + PARTIAL, + tableScan("nation", ImmutableMap.of("NATIONKEY_A", "nationkey", "REGIONKEY_A", "regionkey"))), + // store data in cache alternative + strictProject(ImmutableMap.of( + "REGIONKEY_A", expression(new Reference(BIGINT, "REGIONKEY_A")), + "MAX_PARTIAL_A", expression(new Reference(BIGINT, "MAX_PARTIAL_A")), + "SUM_PARTIAL_A", expression(new Reference(BIGINT, "SUM_PARTIAL_A"))), + cacheDataPlanNode( + aggregation( + singleGroupingSet("REGIONKEY_A"), + ImmutableMap.of( + Optional.of("MAX_PARTIAL_A"), aggregationFunction("max", false, ImmutableList.of(symbol("NATIONKEY_A"))), + 
Optional.of("SUM_PARTIAL_A"), aggregationFunction("sum", false, ImmutableList.of(symbol("NATIONKEY_A"))), + Optional.of("AVG_PARTIAL_A"), aggregationFunction("avg", false, ImmutableList.of(symbol("NATIONKEY_A")))), + Optional.empty(), + PARTIAL, + tableScan("nation", ImmutableMap.of("NATIONKEY_A", "nationkey", "REGIONKEY_A", "regionkey"))))), + // load data from cache alternative + strictProject(ImmutableMap.of( + "REGIONKEY_A", expression(new Reference(BIGINT, "REGIONKEY_A")), + "MAX_PARTIAL_A", expression(new Reference(BIGINT, "MAX_PARTIAL_A")), + "SUM_PARTIAL_A", expression(new Reference(BIGINT, "SUM_PARTIAL_A"))), + loadCachedDataPlanNode(signature, "REGIONKEY_A", "SUM_PARTIAL_A", "MAX_PARTIAL_A", "AVG_PARTIAL_A")))))), + anyTree(aggregation( + singleGroupingSet("REGIONKEY_B"), + ImmutableMap.of( + Optional.of("AVG_B"), aggregationFunction("avg", false, ImmutableList.of(symbol("AVG_PARTIAL_B"))), + Optional.of("SUM_B"), aggregationFunction("sum", false, ImmutableList.of(symbol("SUM_PARTIAL_B")))), + Optional.empty(), + FINAL, + anyTree( + chooseAlternativeNode( + // original subplan + aggregation( + singleGroupingSet("REGIONKEY_B"), + ImmutableMap.of( + Optional.of("SUM_PARTIAL_B"), aggregationFunction("sum", false, ImmutableList.of(symbol("NATIONKEY_B"))), + Optional.of("AVG_PARTIAL_B"), aggregationFunction("avg", false, ImmutableList.of(symbol("NATIONKEY_B")))), + Optional.empty(), + PARTIAL, + tableScan("nation", ImmutableMap.of("NATIONKEY_B", "nationkey", "REGIONKEY_B", "regionkey"))), + // store data in cache alternative + strictProject(ImmutableMap.of( + "REGIONKEY_B", expression(new Reference(BIGINT, "REGIONKEY_B")), + "SUM_PARTIAL_B", expression(new Reference(BIGINT, "SUM_PARTIAL_B")), + "AVG_PARTIAL_B", expression(new Reference(DOUBLE, "AVG_PARTIAL_B"))), + cacheDataPlanNode( + aggregation( + singleGroupingSet("REGIONKEY_B"), + ImmutableMap.of( + Optional.of("MAX_PARTIAL_B"), aggregationFunction("max", false, ImmutableList.of(symbol("NATIONKEY_B"))), + 
Optional.of("SUM_PARTIAL_B"), aggregationFunction("sum", false, ImmutableList.of(symbol("NATIONKEY_B"))), + Optional.of("AVG_PARTIAL_B"), aggregationFunction("avg", false, ImmutableList.of(symbol("NATIONKEY_B")))), + Optional.empty(), + PARTIAL, + tableScan("nation", ImmutableMap.of("NATIONKEY_B", "nationkey", "REGIONKEY_B", "regionkey"))))), + // load data from cache alternative + strictProject(ImmutableMap.of( + "REGIONKEY_B", expression(new Reference(BIGINT, "REGIONKEY_B")), + "SUM_PARTIAL_B", expression(new Reference(BIGINT, "SUM_PARTIAL_B")), + "AVG_PARTIAL_B", expression(new Reference(DOUBLE, "AVG_PARTIAL_B"))), + loadCachedDataPlanNode(signature, "REGIONKEY_B", "SUM_PARTIAL_B", "MAX_PARTIAL_B", "AVG_PARTIAL_B")))))))); + } + + private CanonicalAggregation canonicalAggregation(String name, Expression input) + { + ResolvedFunction resolvedFunction = getPlanTester().getPlannerContext().getMetadata().resolveBuiltinFunction(name, TypeSignatureProvider.fromTypes(input.type())); + return new CanonicalAggregation(resolvedFunction, Optional.empty(), ImmutableList.of(input)); + } +} diff --git a/core/trino-main/src/test/java/io/trino/cache/TestCacheConfig.java b/core/trino-main/src/test/java/io/trino/cache/TestCacheConfig.java new file mode 100644 index 000000000000..7c938923b8b4 --- /dev/null +++ b/core/trino-main/src/test/java/io/trino/cache/TestCacheConfig.java @@ -0,0 +1,67 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.cache; + +import com.google.common.collect.ImmutableMap; +import io.airlift.units.DataSize; +import org.junit.jupiter.api.Test; + +import java.util.Map; + +import static io.airlift.configuration.testing.ConfigAssertions.assertFullMapping; +import static io.airlift.configuration.testing.ConfigAssertions.assertRecordedDefaults; +import static io.airlift.configuration.testing.ConfigAssertions.recordDefaults; + +public class TestCacheConfig +{ + @Test + public void testDefaults() + { + assertRecordedDefaults(recordDefaults(CacheConfig.class) + .setEnabled(false) + .setRevokingThreshold(0.9) + .setRevokingTarget(0.7) + .setCacheCommonSubqueriesEnabled(true) + .setCacheAggregationsEnabled(true) + .setCacheProjectionsEnabled(true) + .setMaxSplitSize(DataSize.of(256, DataSize.Unit.MEGABYTE)) + .setCacheMinWorkerSplitSeparation(500)); + } + + @Test + public void testExplicitPropertyMappings() + { + Map properties = ImmutableMap.builder() + .put("cache.enabled", "true") + .put("cache.revoking-threshold", "0.6") + .put("cache.revoking-target", "0.5") + .put("cache.common-subqueries.enabled", "false") + .put("cache.aggregations.enabled", "false") + .put("cache.projections.enabled", "false") + .put("cache.max-split-size", "64MB") + .put("cache.min-worker-split-separation", "10000") + .buildOrThrow(); + + CacheConfig expected = new CacheConfig() + .setEnabled(true) + .setRevokingThreshold(0.6) + .setRevokingTarget(0.5) + .setCacheAggregationsEnabled(false) + .setCacheProjectionsEnabled(false) + .setCacheCommonSubqueriesEnabled(false) + .setMaxSplitSize(DataSize.of(64, DataSize.Unit.MEGABYTE)) + .setCacheMinWorkerSplitSeparation(10000); + assertFullMapping(properties, expected); + } +} diff --git a/core/trino-main/src/test/java/io/trino/cache/TestCacheController.java b/core/trino-main/src/test/java/io/trino/cache/TestCacheController.java new file mode 100644 index 000000000000..847419e81ecc --- /dev/null +++ 
b/core/trino-main/src/test/java/io/trino/cache/TestCacheController.java @@ -0,0 +1,331 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.cache; + +import com.google.common.collect.ImmutableBiMap; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import io.trino.Session; +import io.trino.cache.CacheController.CacheCandidate; +import io.trino.cache.CanonicalSubplan.AggregationKey; +import io.trino.cache.CanonicalSubplan.ScanFilterProjectKey; +import io.trino.cache.CanonicalSubplan.TopNRankingKey; +import io.trino.metadata.TableHandle; +import io.trino.spi.cache.CacheColumnId; +import io.trino.spi.cache.CacheTableId; +import io.trino.spi.catalog.CatalogName; +import io.trino.spi.connector.CatalogHandle.CatalogVersion; +import io.trino.spi.connector.ConnectorTableHandle; +import io.trino.spi.connector.ConnectorTransactionHandle; +import io.trino.spi.connector.SortOrder; +import io.trino.spi.predicate.TupleDomain; +import io.trino.sql.planner.plan.PlanNodeId; +import io.trino.sql.planner.plan.ValuesNode; +import org.junit.jupiter.api.Test; + +import java.util.List; +import java.util.Map; +import java.util.Set; + +import static io.trino.SystemSessionProperties.CACHE_AGGREGATIONS_ENABLED; +import static io.trino.SystemSessionProperties.CACHE_COMMON_SUBQUERIES_ENABLED; +import static io.trino.SystemSessionProperties.CACHE_PROJECTIONS_ENABLED; 
+import static io.trino.spi.connector.CatalogHandle.createRootCatalogHandle; +import static io.trino.spi.predicate.Domain.multipleValues; +import static io.trino.spi.type.IntegerType.INTEGER; +import static io.trino.sql.planner.plan.TopNRankingNode.RankingType; +import static io.trino.testing.TestingSession.testSessionBuilder; +import static org.assertj.core.api.Assertions.assertThat; + +public class TestCacheController +{ + private static final PlanNodeId PLAN_NODE_ID = new PlanNodeId("id"); + private static final CacheTableId TABLE_ID = new CacheTableId("table"); + private static final CacheColumnId COLUMN_A = new CacheColumnId("A"); + private static final CacheColumnId COLUMN_B = new CacheColumnId("B"); + public static final TableHandle TABLE_HANDLE = new TableHandle(createRootCatalogHandle(new CatalogName("catalog"), new CatalogVersion("version")), new ConnectorTableHandle() {}, new ConnectorTransactionHandle() {}); + + @Test + public void testCacheController() + { + CanonicalSubplan firstGroupByAB = createCanonicalAggregationSubplan(ImmutableSet.of(COLUMN_A, COLUMN_B)); + CanonicalSubplan secondGroupByAB = createCanonicalAggregationSubplan(ImmutableSet.of(COLUMN_A, COLUMN_B)); + CanonicalSubplan groupByA = createCanonicalAggregationSubplan(ImmutableSet.of(COLUMN_A)); + CanonicalSubplan firstProjection = createCanonicalTableScanSubplan(); + CanonicalSubplan secondProjection = createCanonicalTableScanSubplan(); + CanonicalSubplan topN = createCanonicalTopNSubplan(ImmutableMap.of(COLUMN_A, SortOrder.ASC_NULLS_FIRST), 10); + CanonicalSubplan topNRanking = createCanonicalTopNRankingSubplan(ImmutableList.of(COLUMN_B), ImmutableMap.of(COLUMN_A, SortOrder.ASC_NULLS_FIRST), RankingType.ROW_NUMBER, 10); + List subplans = ImmutableList.of(secondProjection, firstProjection, groupByA, secondGroupByAB, firstGroupByAB); + + CacheController cacheController = new CacheController(); + assertThat(cacheController.getCachingCandidates(cacheProperties(true, true, true), subplans)) 
+ .containsExactly( + // common aggregations are first + new CacheCandidate(ImmutableList.of(secondGroupByAB, firstGroupByAB), 2), + // then common projections + new CacheCandidate(ImmutableList.of(secondProjection, firstProjection), 2), + // then single aggregations + new CacheCandidate(ImmutableList.of(groupByA), 1), + new CacheCandidate(ImmutableList.of(secondGroupByAB), 1), + new CacheCandidate(ImmutableList.of(firstGroupByAB), 1), + // then single projections + new CacheCandidate(ImmutableList.of(secondProjection), 1), + new CacheCandidate(ImmutableList.of(firstProjection), 1)); + + assertThat(cacheController.getCachingCandidates(cacheProperties(true, false, false), subplans)) + .containsExactly( + new CacheCandidate(ImmutableList.of(secondGroupByAB, firstGroupByAB), 2), + new CacheCandidate(ImmutableList.of(secondProjection, firstProjection), 2)); + + assertThat(cacheController.getCachingCandidates(cacheProperties(false, true, false), subplans)) + .containsExactly( + new CacheCandidate(ImmutableList.of(groupByA), 1), + new CacheCandidate(ImmutableList.of(secondGroupByAB), 1), + new CacheCandidate(ImmutableList.of(firstGroupByAB), 1)); + + assertThat(cacheController.getCachingCandidates(cacheProperties(false, false, true), subplans)) + .containsExactly( + new CacheCandidate(ImmutableList.of(secondProjection), 1), + new CacheCandidate(ImmutableList.of(firstProjection), 1)); + + subplans = ImmutableList.of(secondProjection, firstProjection, topN); + assertThat(cacheController.getCachingCandidates(cacheProperties(true, true, true), subplans)) + .containsExactly( + // common projections are first + new CacheCandidate(ImmutableList.of(secondProjection, firstProjection), 2), + // then single topN + new CacheCandidate(ImmutableList.of(topN), 1), + // then single projections + new CacheCandidate(ImmutableList.of(secondProjection), 1), + new CacheCandidate(ImmutableList.of(firstProjection), 1)); + + assertThat(cacheController.getCachingCandidates(cacheProperties(true, 
false, false), subplans)) + .containsExactly(new CacheCandidate(ImmutableList.of(secondProjection, firstProjection), 2)); + assertThat(cacheController.getCachingCandidates(cacheProperties(false, true, false), subplans)) + .containsExactly(new CacheCandidate(ImmutableList.of(topN), 1)); + assertThat(cacheController.getCachingCandidates(cacheProperties(false, false, true), subplans)) + .containsExactly( + new CacheCandidate(ImmutableList.of(secondProjection), 1), + new CacheCandidate(ImmutableList.of(firstProjection), 1)); + + subplans = ImmutableList.of(secondProjection, firstProjection, topNRanking); + + assertThat(cacheController.getCachingCandidates(cacheProperties(true, true, true), subplans)) + .containsExactly( + // common projections are first + new CacheCandidate(ImmutableList.of(secondProjection, firstProjection), 2), + // then single topNRanking + new CacheCandidate(ImmutableList.of(topNRanking), 1), + // then single projections + new CacheCandidate(ImmutableList.of(secondProjection), 1), + new CacheCandidate(ImmutableList.of(firstProjection), 1)); + + assertThat(cacheController.getCachingCandidates(cacheProperties(true, false, false), subplans)) + .containsExactly(new CacheCandidate(ImmutableList.of(secondProjection, firstProjection), 2)); + assertThat(cacheController.getCachingCandidates(cacheProperties(false, true, false), subplans)) + .containsExactly(new CacheCandidate(ImmutableList.of(topNRanking), 1)); + assertThat(cacheController.getCachingCandidates(cacheProperties(false, false, true), subplans)) + .containsExactly( + new CacheCandidate(ImmutableList.of(secondProjection), 1), + new CacheCandidate(ImmutableList.of(firstProjection), 1)); + } + + @Test + public void testExcludingCommonSubqueriesPlansWithTableEnforcedConstraint() + { + PlanNodeId firstId = new PlanNodeId("first"); + PlanNodeId secondId = new PlanNodeId("second"); + PlanNodeId thirdId = new PlanNodeId("third"); + CanonicalSubplan firstProjection = 
createCanonicalTableScanSubplan(firstId, TupleDomain.all()); + CanonicalSubplan secondProjection = createCanonicalTableScanSubplan(secondId, TupleDomain.all()); + CanonicalSubplan topNRanking = createCanonicalTopNRankingSubplan(ImmutableList.of(COLUMN_B), ImmutableMap.of(COLUMN_A, SortOrder.ASC_NULLS_FIRST), RankingType.ROW_NUMBER, 10); + List subplans = ImmutableList.of(secondProjection, firstProjection, topNRanking); + + CacheController cacheController = new CacheController(); + // full intersection with Tuple.all + assertThat(cacheController.getCachingCandidates(cacheProperties(true, false, true), subplans)).containsExactly( + new CacheCandidate(ImmutableList.of(secondProjection, firstProjection), 2), + // then single projections + new CacheCandidate(ImmutableList.of(secondProjection), 1), + new CacheCandidate(ImmutableList.of(firstProjection), 1)); + assertThat(cacheController.getCachingCandidates(cacheProperties(true, true, true), subplans)).containsExactly( + new CacheCandidate(ImmutableList.of(secondProjection, firstProjection), 2), + new CacheCandidate(ImmutableList.of(topNRanking), 1), + // then single projections + new CacheCandidate(ImmutableList.of(secondProjection), 1), + new CacheCandidate(ImmutableList.of(firstProjection), 1)); + + // intersection between firstProjection and secondProjection via 3L value + firstProjection = createCanonicalTableScanSubplan(firstId, TupleDomain.withColumnDomains(ImmutableMap.of(new CacheColumnId("column1"), multipleValues(INTEGER, ImmutableList.of(1L, 2L, 3L))))); + secondProjection = createCanonicalTableScanSubplan(secondId, TupleDomain.withColumnDomains(ImmutableMap.of(new CacheColumnId("column1"), multipleValues(INTEGER, ImmutableList.of(3L, 4L, 5L))))); + subplans = ImmutableList.of(secondProjection, firstProjection, topNRanking); + assertThat(cacheController.getCachingCandidates(cacheProperties(true, false, true), subplans)).containsExactly( + new CacheCandidate(ImmutableList.of(secondProjection, firstProjection), 
2), + // then single projections + new CacheCandidate(ImmutableList.of(secondProjection), 1), + new CacheCandidate(ImmutableList.of(firstProjection), 1)); + assertThat(cacheController.getCachingCandidates(cacheProperties(true, true, true), subplans)).containsExactly( + new CacheCandidate(ImmutableList.of(secondProjection, firstProjection), 2), + new CacheCandidate(ImmutableList.of(topNRanking), 1), + // then single projections + new CacheCandidate(ImmutableList.of(secondProjection), 1), + new CacheCandidate(ImmutableList.of(firstProjection), 1)); + + // full exclude by Tuple.none + firstProjection = createCanonicalTableScanSubplan(firstId, TupleDomain.none()); + secondProjection = createCanonicalTableScanSubplan(secondId, TupleDomain.none()); + subplans = ImmutableList.of(secondProjection, firstProjection, topNRanking); + assertThat(cacheController.getCachingCandidates(cacheProperties(true, false, true), subplans)).containsExactly( + // then single projections + new CacheCandidate(ImmutableList.of(secondProjection), 1), + new CacheCandidate(ImmutableList.of(firstProjection), 1)); + assertThat(cacheController.getCachingCandidates(cacheProperties(true, true, true), subplans)).containsExactly( + new CacheCandidate(ImmutableList.of(topNRanking), 1), + // then single projections + new CacheCandidate(ImmutableList.of(secondProjection), 1), + new CacheCandidate(ImmutableList.of(firstProjection), 1)); + + // no intersection exclude between firstProjection and secondProjection + firstProjection = createCanonicalTableScanSubplan(firstId, TupleDomain.withColumnDomains(ImmutableMap.of(new CacheColumnId("column1"), multipleValues(INTEGER, ImmutableList.of(1L, 2L, 3L))))); + secondProjection = createCanonicalTableScanSubplan(secondId, TupleDomain.withColumnDomains(ImmutableMap.of(new CacheColumnId("column1"), multipleValues(INTEGER, ImmutableList.of(7L))))); + subplans = ImmutableList.of(secondProjection, firstProjection, topNRanking); + 
assertThat(cacheController.getCachingCandidates(cacheProperties(true, false, true), subplans)).containsExactly( + // then single projections + new CacheCandidate(ImmutableList.of(secondProjection), 1), + new CacheCandidate(ImmutableList.of(firstProjection), 1)); + assertThat(cacheController.getCachingCandidates(cacheProperties(true, true, true), subplans)).containsExactly( + new CacheCandidate(ImmutableList.of(topNRanking), 1), + // then single projections + new CacheCandidate(ImmutableList.of(secondProjection), 1), + new CacheCandidate(ImmutableList.of(firstProjection), 1)); + + // intersection between 3. plans via value 3L + firstProjection = createCanonicalTableScanSubplan(firstId, TupleDomain.withColumnDomains(ImmutableMap.of(new CacheColumnId("column1"), multipleValues(INTEGER, ImmutableList.of(1L, 2L, 3L))))); + CanonicalSubplan thirdProjection = createCanonicalTableScanSubplan(thirdId, TupleDomain.withColumnDomains(ImmutableMap.of(new CacheColumnId("column1"), multipleValues(INTEGER, ImmutableList.of(3L, 4L, 5L))))); + secondProjection = createCanonicalTableScanSubplan(secondId, TupleDomain.withColumnDomains(ImmutableMap.of(new CacheColumnId("column1"), multipleValues(INTEGER, ImmutableList.of(3L, 4L, 5L))))); + subplans = ImmutableList.of(secondProjection, firstProjection, thirdProjection); + + assertThat(cacheController.getCachingCandidates(cacheProperties(true, false, true), subplans)).containsExactly( + new CacheCandidate(ImmutableList.of(secondProjection, firstProjection, thirdProjection), 2), + // then single projections + new CacheCandidate(ImmutableList.of(secondProjection), 1), + new CacheCandidate(ImmutableList.of(firstProjection), 1), + new CacheCandidate(ImmutableList.of(thirdProjection), 1)); + + // intersection between firstProjection and thirdProjection, but not with secondProjection + firstProjection = createCanonicalTableScanSubplan(firstId, TupleDomain.withColumnDomains(ImmutableMap.of(new CacheColumnId("column1"), multipleValues(INTEGER, 
ImmutableList.of(1L, 2L, 3L))))); + thirdProjection = createCanonicalTableScanSubplan(thirdId, TupleDomain.withColumnDomains(ImmutableMap.of(new CacheColumnId("column1"), multipleValues(INTEGER, ImmutableList.of(3L, 4L, 5L))))); + secondProjection = createCanonicalTableScanSubplan(secondId, TupleDomain.withColumnDomains(ImmutableMap.of(new CacheColumnId("column1"), multipleValues(INTEGER, ImmutableList.of(7L))))); + subplans = ImmutableList.of(thirdProjection, firstProjection, secondProjection); + assertThat(cacheController.getCachingCandidates(cacheProperties(true, false, true), subplans)).containsExactly( + new CacheCandidate(ImmutableList.of(thirdProjection, firstProjection), 2), + // then single projections + new CacheCandidate(ImmutableList.of(thirdProjection), 1), + new CacheCandidate(ImmutableList.of(firstProjection), 1), + new CacheCandidate(ImmutableList.of(secondProjection), 1)); + + // similar case as above, first element in subplans does not intersect with rest + subplans = ImmutableList.of(secondProjection, thirdProjection, firstProjection); + assertThat(cacheController.getCachingCandidates(cacheProperties(true, false, true), subplans)).containsExactly( + new CacheCandidate(ImmutableList.of(thirdProjection, firstProjection), 2), + // then single projections + new CacheCandidate(ImmutableList.of(secondProjection), 1), + new CacheCandidate(ImmutableList.of(thirdProjection), 1), + new CacheCandidate(ImmutableList.of(firstProjection), 1)); + + // split common subplans by intersection into two commonSubplans + firstProjection = createCanonicalTableScanSubplan(firstId, TupleDomain.withColumnDomains(ImmutableMap.of(new CacheColumnId("column1"), multipleValues(INTEGER, ImmutableList.of(1L, 2L, 3L))))); + thirdProjection = createCanonicalTableScanSubplan(thirdId, TupleDomain.withColumnDomains(ImmutableMap.of(new CacheColumnId("column1"), multipleValues(INTEGER, ImmutableList.of(3L, 4L, 5L))))); + CanonicalSubplan forthProjection = 
createCanonicalTableScanSubplan(thirdId, TupleDomain.withColumnDomains(ImmutableMap.of(new CacheColumnId("column1"), multipleValues(INTEGER, ImmutableList.of(7L, 8L, 9L))))); + secondProjection = createCanonicalTableScanSubplan(secondId, TupleDomain.withColumnDomains(ImmutableMap.of(new CacheColumnId("column1"), multipleValues(INTEGER, ImmutableList.of(7L, 0L))))); + subplans = ImmutableList.of(secondProjection, thirdProjection, firstProjection, forthProjection); + assertThat(cacheController.getCachingCandidates(cacheProperties(true, false, true), subplans)).containsExactly( + new CacheCandidate(ImmutableList.of(secondProjection, forthProjection), 2), + new CacheCandidate(ImmutableList.of(thirdProjection, firstProjection), 2), + // then single projections + new CacheCandidate(ImmutableList.of(secondProjection), 1), + new CacheCandidate(ImmutableList.of(thirdProjection), 1), + new CacheCandidate(ImmutableList.of(firstProjection), 1), + new CacheCandidate(ImmutableList.of(forthProjection), 1)); + } + + private CanonicalSubplan createCanonicalAggregationSubplan(Set groupByColumns) + { + CanonicalSubplan tableScanPlan = createCanonicalTableScanSubplan(); + + return CanonicalSubplan.builderForChildSubplan(new AggregationKey(groupByColumns, ImmutableSet.of()), tableScanPlan) + .originalPlanNode(new ValuesNode(PLAN_NODE_ID, 0)) + .originalSymbolMapping(ImmutableBiMap.of()) + .assignments(ImmutableMap.of()) + .pullableConjuncts(ImmutableSet.of()) + .groupByColumns(groupByColumns) + .build(); + } + + private CanonicalSubplan createCanonicalTopNRankingSubplan(List partitionBy, Map orderBy, RankingType rankingType, int maxRankingPerPartition) + { + CanonicalSubplan tableScanPlan = createCanonicalTableScanSubplan(); + + return CanonicalSubplan.builderForChildSubplan(new TopNRankingKey(partitionBy, orderBy.keySet().stream().toList(), orderBy, rankingType, maxRankingPerPartition, ImmutableSet.of()), tableScanPlan) + .originalPlanNode(new ValuesNode(PLAN_NODE_ID, 0)) + 
.originalSymbolMapping(ImmutableBiMap.of()) + .assignments(ImmutableMap.of()) + .pullableConjuncts(ImmutableSet.of()) + .build(); + } + + private CanonicalSubplan createCanonicalTopNSubplan(Map orderBy, long count) + { + CanonicalSubplan tableScanPlan = createCanonicalTableScanSubplan(); + + return CanonicalSubplan.builderForChildSubplan(new CanonicalSubplan.TopNKey(orderBy.keySet().stream().toList(), orderBy, count, ImmutableSet.of()), tableScanPlan) + .originalPlanNode(new ValuesNode(PLAN_NODE_ID, 0)) + .originalSymbolMapping(ImmutableBiMap.of()) + .assignments(ImmutableMap.of()) + .pullableConjuncts(ImmutableSet.of()) + .build(); + } + + private static CanonicalSubplan createCanonicalTableScanSubplan() + { + return createCanonicalTableScanSubplan(PLAN_NODE_ID, TupleDomain.all()); + } + + private static CanonicalSubplan createCanonicalTableScanSubplan(PlanNodeId planNodeId, TupleDomain enforcedConstraint) + { + return CanonicalSubplan.builderForTableScan( + new ScanFilterProjectKey(TABLE_ID, ImmutableSet.of()), + ImmutableMap.of(), + TABLE_HANDLE, + TABLE_ID, + enforcedConstraint, + false, + planNodeId) + .originalPlanNode(new ValuesNode(PLAN_NODE_ID, 0)) + .originalSymbolMapping(ImmutableBiMap.of()) + .assignments(ImmutableMap.of()) + .pullableConjuncts(ImmutableSet.of()) + .build(); + } + + private Session cacheProperties(boolean cacheSubqueries, boolean cacheAggregations, boolean cacheProjections) + { + return testSessionBuilder() + .setSystemProperty(CACHE_COMMON_SUBQUERIES_ENABLED, Boolean.toString(cacheSubqueries)) + .setSystemProperty(CACHE_AGGREGATIONS_ENABLED, Boolean.toString(cacheAggregations)) + .setSystemProperty(CACHE_PROJECTIONS_ENABLED, Boolean.toString(cacheProjections)) + .build(); + } +} diff --git a/core/trino-main/src/test/java/io/trino/cache/TestCacheDataOperator.java b/core/trino-main/src/test/java/io/trino/cache/TestCacheDataOperator.java new file mode 100644 index 000000000000..636d42ed371a --- /dev/null +++ 
b/core/trino-main/src/test/java/io/trino/cache/TestCacheDataOperator.java @@ -0,0 +1,391 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.cache; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.airlift.json.JsonCodec; +import io.airlift.units.DataSize; +import io.airlift.units.Duration; +import io.trino.Session; +import io.trino.cache.CommonPlanAdaptation.PlanSignatureWithPredicate; +import io.trino.execution.ScheduledSplit; +import io.trino.memory.LocalMemoryManager; +import io.trino.memory.NodeMemoryConfig; +import io.trino.metadata.BlockEncodingManager; +import io.trino.metadata.InternalBlockEncodingSerde; +import io.trino.metadata.Split; +import io.trino.metadata.TableHandle; +import io.trino.operator.Driver; +import io.trino.operator.DriverContext; +import io.trino.operator.DriverFactory; +import io.trino.operator.Operator; +import io.trino.operator.OperatorContext; +import io.trino.operator.OperatorDriverFactory; +import io.trino.operator.OperatorFactory; +import io.trino.spi.Page; +import io.trino.spi.block.TestingBlockEncodingSerde; +import io.trino.spi.cache.CacheColumnId; +import io.trino.spi.cache.CacheManager; +import io.trino.spi.cache.CacheSplitId; +import io.trino.spi.cache.PlanSignature; +import io.trino.spi.cache.SignatureKey; +import io.trino.spi.connector.ColumnHandle; +import io.trino.spi.connector.ConnectorPageSource; +import 
io.trino.spi.connector.DynamicFilter; +import io.trino.spi.connector.FixedPageSource; +import io.trino.spi.metrics.Metrics; +import io.trino.spi.predicate.TupleDomain; +import io.trino.spi.type.TestingTypeManager; +import io.trino.spi.type.TypeManager; +import io.trino.split.PageSourceProvider; +import io.trino.sql.planner.PlanNodeIdAllocator; +import io.trino.sql.planner.plan.PlanNodeId; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; +import org.junit.jupiter.api.parallel.Execution; + +import java.util.List; +import java.util.Optional; +import java.util.OptionalInt; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Function; +import java.util.function.Supplier; +import java.util.stream.IntStream; + +import static io.airlift.units.DataSize.Unit.MEGABYTE; +import static io.trino.block.BlockAssertions.createLongSequenceBlock; +import static io.trino.cache.CacheDriverFactory.MIN_PROCESSED_SPLITS; +import static io.trino.cache.CacheDriverFactory.THRASHING_CACHE_THRESHOLD; +import static io.trino.cache.StaticDynamicFilter.createStaticDynamicFilterSupplier; +import static io.trino.operator.PageTestUtils.createPage; +import static io.trino.spi.connector.DynamicFilter.EMPTY; +import static io.trino.spi.type.BigintType.BIGINT; +import static io.trino.spi.type.IntegerType.INTEGER; +import static io.trino.testing.PlanTester.getTupleDomainJsonCodec; +import static io.trino.testing.TestingHandles.TEST_CATALOG_HANDLE; +import static io.trino.testing.TestingHandles.TEST_TABLE_HANDLE; +import static io.trino.testing.TestingSession.testSessionBuilder; +import static io.trino.testing.TestingSplit.createRemoteSplit; +import static io.trino.testing.TestingTaskContext.createTaskContext; +import static java.util.Objects.requireNonNull; +import static org.assertj.core.api.Assertions.assertThat; +import static 
org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS; +import static org.junit.jupiter.api.parallel.ExecutionMode.SAME_THREAD; + +@TestInstance(PER_CLASS) +@Execution(SAME_THREAD) +public class TestCacheDataOperator +{ + private static final Session TEST_SESSION = testSessionBuilder().build(); + private final PlanNodeIdAllocator planNodeIdAllocator = new PlanNodeIdAllocator(); + private CacheManagerRegistry registry; + private JsonCodec tupleDomainCodec; + + @BeforeEach + public void setUp() + { + NodeMemoryConfig config = new NodeMemoryConfig() + .setHeapHeadroom(DataSize.of(10, MEGABYTE)) + .setMaxQueryMemoryPerNode(DataSize.of(100, MEGABYTE)); + LocalMemoryManager memoryManager = new LocalMemoryManager(config, DataSize.of(110, MEGABYTE).toBytes()); + CacheConfig cacheConfig = new CacheConfig(); + cacheConfig.setEnabled(true); + registry = new CacheManagerRegistry(cacheConfig, memoryManager, new TestingBlockEncodingSerde(), new CacheStats()); + registry.loadCacheManager(); + TypeManager typeManager = new TestingTypeManager(); + tupleDomainCodec = getTupleDomainJsonCodec(new InternalBlockEncodingSerde(new BlockEncodingManager(), typeManager), typeManager); + } + + @Test + public void testLimitCache() + { + PlanSignature signature = createPlanSignature("sig"); + CacheManager.SplitCache splitCache = registry.getCacheManager().getSplitCache(signature); + PlanNodeIdAllocator planNodeIdAllocator = new PlanNodeIdAllocator(); + + CacheDataOperator.CacheDataOperatorFactory operatorFactory = new CacheDataOperator.CacheDataOperatorFactory( + 0, + planNodeIdAllocator.getNextId(), + DataSize.of(1024, DataSize.Unit.BYTE).toBytes()); + DriverContext driverContext = createTaskContext(Executors.newSingleThreadExecutor(), Executors.newScheduledThreadPool(1), TEST_SESSION) + .addPipelineContext(0, true, true, false) + .addDriverContext(); + + CacheMetrics cacheMetrics = new CacheMetrics(); + CacheStats cacheStats = new CacheStats(); + + driverContext.setCacheDriverContext(new 
CacheDriverContext(Optional.empty(), splitCache.storePages(new CacheSplitId("split1"), TupleDomain.all(), TupleDomain.all()), EMPTY, cacheMetrics, cacheStats, Metrics.EMPTY)); + CacheDataOperator cacheDataOperator = (CacheDataOperator) operatorFactory.createOperator(driverContext); + + // sink was not aborted - there is a space in a cache. The page was passed through and split is going to be cached + Page smallPage = createPage(ImmutableList.of(BIGINT), 1, Optional.empty(), ImmutableList.of(createLongSequenceBlock(0, 10))); + cacheDataOperator.addInput(smallPage); + assertThat(cacheDataOperator.getOutput()).isEqualTo(smallPage); + cacheDataOperator.finish(); + assertThat(cacheMetrics.getSplitNotCachedCount()).isEqualTo(0); + assertThat(cacheMetrics.getSplitCachedCount()).isEqualTo(1); + + // sink was aborted - there is no sufficient space in a cache. The page was passed through but split is not going to be cached + Page bigPage = createPage(ImmutableList.of(BIGINT), 1, Optional.empty(), ImmutableList.of(createLongSequenceBlock(0, 2_000))); + driverContext = createTaskContext(Executors.newSingleThreadExecutor(), Executors.newScheduledThreadPool(1), TEST_SESSION) + .addPipelineContext(0, true, true, false) + .addDriverContext(); + driverContext.setCacheDriverContext(new CacheDriverContext(Optional.empty(), splitCache.storePages(new CacheSplitId("split2"), TupleDomain.all(), TupleDomain.all()), EMPTY, cacheMetrics, cacheStats, Metrics.EMPTY)); + cacheDataOperator = (CacheDataOperator) operatorFactory.createOperator(driverContext); + + cacheDataOperator.addInput(bigPage); + cacheDataOperator.finish(); + + assertThat(cacheDataOperator.getOutput()).isEqualTo(bigPage); + assertThat(cacheMetrics.getSplitNotCachedCount()).isEqualTo(1); + assertThat(cacheMetrics.getSplitCachedCount()).isEqualTo(1); + } + + @Test + public void testCachingThreshold() + { + PlanSignature signature = createPlanSignature("sig"); + Page bigPage = createPage(ImmutableList.of(BIGINT), 1, 
Optional.empty(), ImmutableList.of(createLongSequenceBlock(0, 128))); + Page smallPage = createPage(ImmutableList.of(BIGINT), 1, Optional.empty(), ImmutableList.of(createLongSequenceBlock(0, 16))); + AtomicInteger operatorIdAllocator = new AtomicInteger(); + CacheDataOperator.CacheDataOperatorFactory cacheDataOperatorFactory = new CacheDataOperator.CacheDataOperatorFactory( + operatorIdAllocator.incrementAndGet(), + planNodeIdAllocator.getNextId(), + DataSize.of(1024, DataSize.Unit.BYTE).toBytes()); + + PassThroughOperator.PassThroughOperatorFactory passThroughOperatorFactory = + new PassThroughOperator.PassThroughOperatorFactory(operatorIdAllocator.incrementAndGet(), planNodeIdAllocator.getNextId(), () -> smallPage); + + List driverFactories = ImmutableList.of( + prepareDriverFactory(operatorIdAllocator, 2, preparePassThroughOperator(() -> smallPage)), + new OperatorDriverFactory( + operatorIdAllocator.incrementAndGet(), + true, + false, + ImmutableList.of(passThroughOperatorFactory, cacheDataOperatorFactory), + OptionalInt.empty()), + prepareDriverFactory(operatorIdAllocator, 2, preparePassThroughOperator(() -> smallPage))); + + CacheDriverFactory cacheDriverFactory = new CacheDriverFactory( + 0, + true, + true, + OptionalInt.empty(), + new PlanNodeId("test"), + TEST_SESSION, + new TestPageSourceProvider(), + registry, + tupleDomainCodec, + TEST_TABLE_HANDLE, + new PlanSignatureWithPredicate(signature, TupleDomain.all()), + ImmutableMap.of(), + createStaticDynamicFilterSupplier(ImmutableList.of(EMPTY)), + createStaticDynamicFilterSupplier(ImmutableList.of(EMPTY)), + driverFactories, + new CacheStats()); + + // process splits where split's page is small. 
All splits will be successfully cached + createAndRunDriver(0, MIN_PROCESSED_SPLITS, cacheDriverFactory); + assertThat(cacheDriverFactory.getCacheMetrics().getSplitCachedCount()).isEqualTo(MIN_PROCESSED_SPLITS); + assertThat(cacheDriverFactory.getCacheMetrics().getSplitNotCachedCount()).isEqualTo(0); + + int splitToBeRejectedCount = (int) Math.ceil((MIN_PROCESSED_SPLITS * (1.0f - THRASHING_CACHE_THRESHOLD)) / THRASHING_CACHE_THRESHOLD); + + // try to process splits that cannot be cached because its page sizes exceeds threshold size + // caching is not going to be "disabled" because threshold was not exceeded. + passThroughOperatorFactory.setPageSupplier(() -> bigPage); + createAndRunDriver(MIN_PROCESSED_SPLITS, MIN_PROCESSED_SPLITS + splitToBeRejectedCount, cacheDriverFactory); + assertThat(cacheDriverFactory.getCacheMetrics().getSplitCachedCount()).isEqualTo(MIN_PROCESSED_SPLITS); + assertThat(cacheDriverFactory.getCacheMetrics().getSplitNotCachedCount()).isEqualTo(splitToBeRejectedCount); + + // exceed threshold + CacheSplitId splitId = new CacheSplitId(String.format("split_%d", MIN_PROCESSED_SPLITS + splitToBeRejectedCount)); + DriverContext driverContext = createTaskContext(Executors.newSingleThreadExecutor(), Executors.newScheduledThreadPool(1), TEST_SESSION) + .addPipelineContext(0, true, true, false) + .addDriverContext(); + + Split split = new Split(TEST_CATALOG_HANDLE, createRemoteSplit(), Optional.of(splitId), true); + try (Driver driver = cacheDriverFactory.createDriver(driverContext, Optional.of(new ScheduledSplit(0, planNodeIdAllocator.getNextId(), split)))) { + assertThat(driver.getDriverContext().getCacheDriverContext()).isEmpty(); + } + assertThat(cacheDriverFactory.getCacheMetrics().getSplitCachedCount()).isEqualTo(MIN_PROCESSED_SPLITS); + assertThat(cacheDriverFactory.getCacheMetrics().getSplitNotCachedCount()).isEqualTo(splitToBeRejectedCount); + } + + private static PlanSignature createPlanSignature(String signature) + { + return new 
PlanSignature( + new SignatureKey(signature), + Optional.empty(), + ImmutableList.of(new CacheColumnId("id")), + ImmutableList.of(INTEGER)); + } + + private void createAndRunDriver(int start, int end, CacheDriverFactory cacheDriverFactory) + { + for (int i = start; i < end; i++) { + CacheSplitId splitId = new CacheSplitId(String.format("split_%d", i)); + DriverContext driverContext = createTaskContext(Executors.newSingleThreadExecutor(), Executors.newScheduledThreadPool(1), TEST_SESSION) + .addPipelineContext(0, true, true, false) + .addDriverContext(); + Split split = new Split(TEST_CATALOG_HANDLE, createRemoteSplit(), Optional.of(splitId), true); + try (Driver driver = cacheDriverFactory.createDriver(driverContext, Optional.of(new ScheduledSplit(0, planNodeIdAllocator.getNextId(), split)))) { + driver.process(new Duration(10.0, TimeUnit.SECONDS), 100); + assertThat(driver.getDriverContext().getCacheDriverContext()).isPresent(); + } + } + } + + private DriverFactory prepareDriverFactory(AtomicInteger operatorIdAllocator, int operatorsCount, Function operatorFactoryProvider) + { + return new OperatorDriverFactory( + operatorIdAllocator.incrementAndGet(), + true, + false, + IntStream.range(0, operatorsCount).mapToObj(i -> operatorFactoryProvider.apply(operatorIdAllocator.incrementAndGet())).toList(), + OptionalInt.empty()); + } + + private Function preparePassThroughOperator(Supplier pageSupplier) + { + return (operatorId) -> new PassThroughOperator.PassThroughOperatorFactory(operatorId, planNodeIdAllocator.getNextId(), pageSupplier); + } + + private static class TestPageSourceProvider + implements PageSourceProvider + { + @Override + public ConnectorPageSource createPageSource( + Session session, + Split split, + TableHandle table, + List columns, + DynamicFilter dynamicFilter) + { + return new FixedPageSource(ImmutableList.of()); + } + + @Override + public TupleDomain getUnenforcedPredicate( + Session session, + Split split, + TableHandle table, + TupleDomain 
predicate) + { + return TupleDomain.all(); + } + + @Override + public TupleDomain prunePredicate( + Session session, + Split split, + TableHandle table, + TupleDomain predicate) + { + return predicate; + } + } + + private static class PassThroughOperator + implements Operator + { + private static class PassThroughOperatorFactory + implements OperatorFactory + { + private final int operatorId; + private final PlanNodeId planNodeId; + private Supplier pageSupplier; + + public PassThroughOperatorFactory(int operatorId, PlanNodeId planNodeId, Supplier pageSupplier) + { + this.operatorId = operatorId; + this.planNodeId = requireNonNull(planNodeId, "planNodeId is null"); + this.pageSupplier = requireNonNull(pageSupplier, "pageSupplier is null"); + } + + @Override + public Operator createOperator(DriverContext driverContext) + { + return new PassThroughOperator(driverContext.addOperatorContext(operatorId, planNodeId, PassThroughOperator.class.getSimpleName()), pageSupplier.get()); + } + + @Override + public void noMoreOperators() + { + } + + @Override + public OperatorFactory duplicate() + { + return new PassThroughOperatorFactory(operatorId, planNodeId, pageSupplier); + } + + void setPageSupplier(Supplier pageSupplier) + { + this.pageSupplier = pageSupplier; + } + } + + private final OperatorContext context; + private final Page page; + private boolean finished; + + public PassThroughOperator(OperatorContext context, Page page) + { + this.context = requireNonNull(context, "context is null"); + this.page = requireNonNull(page, "page is null"); + } + + @Override + public OperatorContext getOperatorContext() + { + return context; + } + + @Override + public boolean needsInput() + { + return !finished; + } + + @Override + public void addInput(Page page) + { + throw new UnsupportedOperationException(); + } + + @Override + public Page getOutput() + { + finished = true; + return page; + } + + @Override + public void finish() + { + finished = true; + } + + @Override + public 
boolean isFinished() + { + return finished; + } + } +} diff --git a/core/trino-main/src/test/java/io/trino/cache/TestCacheDriverFactory.java b/core/trino-main/src/test/java/io/trino/cache/TestCacheDriverFactory.java new file mode 100644 index 000000000000..fae1bcf1752b --- /dev/null +++ b/core/trino-main/src/test/java/io/trino/cache/TestCacheDriverFactory.java @@ -0,0 +1,717 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.cache; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import io.airlift.json.JsonCodec; +import io.airlift.units.DataSize; +import io.trino.Session; +import io.trino.cache.CommonPlanAdaptation.PlanSignatureWithPredicate; +import io.trino.execution.ScheduledSplit; +import io.trino.memory.LocalMemoryManager; +import io.trino.memory.NodeMemoryConfig; +import io.trino.metadata.BlockEncodingManager; +import io.trino.metadata.InternalBlockEncodingSerde; +import io.trino.metadata.Split; +import io.trino.metadata.TableHandle; +import io.trino.operator.DevNullOperator; +import io.trino.operator.Driver; +import io.trino.operator.DriverContext; +import io.trino.operator.DriverFactory; +import io.trino.operator.OperatorDriverFactory; +import io.trino.spi.block.TestingBlockEncodingSerde; +import io.trino.spi.cache.CacheColumnId; +import io.trino.spi.cache.CacheManager; +import io.trino.spi.cache.CacheManager.SplitCache; +import 
io.trino.spi.cache.CacheManagerContext; +import io.trino.spi.cache.CacheManagerFactory; +import io.trino.spi.cache.CacheSplitId; +import io.trino.spi.cache.PlanSignature; +import io.trino.spi.cache.SignatureKey; +import io.trino.spi.connector.ColumnHandle; +import io.trino.spi.connector.ConnectorPageSink; +import io.trino.spi.connector.ConnectorPageSource; +import io.trino.spi.connector.DynamicFilter; +import io.trino.spi.connector.EmptyPageSource; +import io.trino.spi.connector.FixedPageSource; +import io.trino.spi.connector.TestingColumnHandle; +import io.trino.spi.predicate.Domain; +import io.trino.spi.predicate.TupleDomain; +import io.trino.spi.type.TestingTypeManager; +import io.trino.spi.type.TypeManager; +import io.trino.split.PageSourceProvider; +import io.trino.sql.planner.PlanNodeIdAllocator; +import io.trino.sql.planner.plan.PlanNodeId; +import io.trino.testing.TestingTaskContext; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; +import org.junit.jupiter.api.parallel.Execution; + +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.OptionalInt; +import java.util.Set; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Function; +import java.util.function.Supplier; +import java.util.stream.LongStream; + +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.ImmutableMap.toImmutableMap; +import static com.google.common.util.concurrent.MoreExecutors.directExecutor; +import static io.airlift.units.DataSize.Unit.MEGABYTE; +import static io.trino.RowPagesBuilder.rowPagesBuilder; +import static 
io.trino.SystemSessionProperties.ENABLE_DYNAMIC_ROW_FILTERING; +import static io.trino.cache.CacheDriverFactory.MAX_UNENFORCED_PREDICATE_VALUE_COUNT; +import static io.trino.cache.CacheDriverFactory.appendRemainingPredicates; +import static io.trino.cache.CacheDriverFactory.getDynamicRowFilteringUnenforcedPredicate; +import static io.trino.cache.StaticDynamicFilter.createStaticDynamicFilter; +import static io.trino.cache.StaticDynamicFilter.createStaticDynamicFilterSupplier; +import static io.trino.plugin.base.cache.CacheUtils.normalizeTupleDomain; +import static io.trino.spi.connector.DynamicFilter.EMPTY; +import static io.trino.spi.predicate.Domain.multipleValues; +import static io.trino.spi.predicate.Domain.singleValue; +import static io.trino.spi.type.BigintType.BIGINT; +import static io.trino.testing.PlanTester.getTupleDomainJsonCodec; +import static io.trino.testing.TestingHandles.TEST_CATALOG_HANDLE; +import static io.trino.testing.TestingHandles.TEST_TABLE_HANDLE; +import static io.trino.testing.TestingSession.testSessionBuilder; +import static io.trino.testing.TestingSplit.createRemoteSplit; +import static java.util.function.Function.identity; +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS; +import static org.junit.jupiter.api.parallel.ExecutionMode.SAME_THREAD; + +@TestInstance(PER_CLASS) +@Execution(SAME_THREAD) +public class TestCacheDriverFactory +{ + private static final Session TEST_SESSION = testSessionBuilder().build(); + private static final SignatureKey SIGNATURE_KEY = new SignatureKey("key"); + private static final CacheSplitId SPLIT_ID = new CacheSplitId("split"); + private static final PlanNodeId PLAN_NODE_ID = new PlanNodeId("test"); + + private final PlanNodeIdAllocator planNodeIdAllocator = new PlanNodeIdAllocator(); + + private TestSplitCache splitCache; + private CacheManagerRegistry registry; + private JsonCodec tupleDomainCodec; + private 
ScheduledExecutorService scheduledExecutor; + + @BeforeEach + public void setUp() + { + NodeMemoryConfig config = new NodeMemoryConfig() + .setHeapHeadroom(DataSize.of(16, MEGABYTE)) + .setMaxQueryMemoryPerNode(DataSize.of(32, MEGABYTE)); + CacheConfig cacheConfig = new CacheConfig(); + cacheConfig.setEnabled(true); + registry = new CacheManagerRegistry(cacheConfig, new LocalMemoryManager(config, DataSize.of(1024, MEGABYTE).toBytes()), new TestingBlockEncodingSerde(), new CacheStats()); + TestCacheManagerFactory cacheManagerFactory = new TestCacheManagerFactory(); + registry.loadCacheManager(cacheManagerFactory, ImmutableMap.of()); + splitCache = cacheManagerFactory.getCacheManager().getSplitCache(); + TypeManager typeManager = new TestingTypeManager(); + tupleDomainCodec = getTupleDomainJsonCodec(new InternalBlockEncodingSerde(new BlockEncodingManager(), typeManager), typeManager); + scheduledExecutor = Executors.newScheduledThreadPool(1); + } + + @AfterEach + public void tearDown() + { + scheduledExecutor.shutdownNow(); + } + + @Test + public void testCreateDriverForOriginalPlan() + { + PlanSignatureWithPredicate signature = new PlanSignatureWithPredicate( + new PlanSignature(SIGNATURE_KEY, Optional.empty(), ImmutableList.of(), ImmutableList.of()), + TupleDomain.all()); + AtomicInteger operatorIdAllocator = new AtomicInteger(); + + // expect driver for original plan because cacheSplit is empty + CacheDriverFactory cacheDriverFactory = createCacheDriverFactory(new TestPageSourceProvider(), signature, operatorIdAllocator); + Driver driver = cacheDriverFactory.createDriver(createDriverContext(), createSplit(Optional.empty())); + assertThat(driver.getDriverContext().getCacheDriverContext()).isEmpty(); + + // expect driver for original plan because split got scheduled on non-preferred node + cacheDriverFactory = createCacheDriverFactory(new TestPageSourceProvider(), signature, operatorIdAllocator); + driver = cacheDriverFactory.createDriver( + createDriverContext(), + 
createSplit(Optional.of(new CacheSplitId("split")), false)); + assertThat(driver.getDriverContext().getCacheDriverContext()).isEmpty(); + + // expect driver for original plan because dynamic filter filters data completely + cacheDriverFactory = createCacheDriverFactory(new TestPageSourceProvider(input -> TupleDomain.none(), identity()), signature, operatorIdAllocator); + driver = cacheDriverFactory.createDriver(createDriverContext(), createSplit(Optional.of(SPLIT_ID))); + assertThat(driver.getDriverContext().getCacheDriverContext()).isEmpty(); + + // expect driver for original plan because enforced predicate is pruned to empty tuple domain + cacheDriverFactory = createCacheDriverFactory(new TestPageSourceProvider(identity(), input -> TupleDomain.none()), signature, operatorIdAllocator); + driver = cacheDriverFactory.createDriver(createDriverContext(), createSplit(Optional.of(SPLIT_ID))); + assertThat(driver.getDriverContext().getCacheDriverContext()).isEmpty(); + + // expect driver for original plan because dynamic filter is too big + Domain bigDomain = multipleValues(BIGINT, LongStream.range(0, MAX_UNENFORCED_PREDICATE_VALUE_COUNT + 1) + .boxed() + .collect(toImmutableList())); + cacheDriverFactory = createCacheDriverFactory( + new TestPageSourceProvider(_ -> TupleDomain.withColumnDomains(ImmutableMap.of(new TestingColumnHandle("column"), bigDomain)), identity()), + signature, + operatorIdAllocator); + driver = cacheDriverFactory.createDriver(createDriverContext(), createSplit(Optional.of(SPLIT_ID))); + assertThat(driver.getDriverContext().getCacheDriverContext()).isEmpty(); + } + + @Test + public void testCreateDriverWithSmallerDynamicFilter() + { + CacheColumnId cacheColumnId = new CacheColumnId("cacheColumn"); + ColumnHandle columnHandle = new TestingColumnHandle("column"); + + PlanSignatureWithPredicate signature = new PlanSignatureWithPredicate( + new PlanSignature(SIGNATURE_KEY, Optional.empty(), ImmutableList.of(cacheColumnId), ImmutableList.of(BIGINT)), + 
TupleDomain.withColumnDomains( + ImmutableMap.of(cacheColumnId, multipleValues(BIGINT, LongStream.range(0L, 100L).boxed().toList())))); + DriverFactory driverFactory = createDriverFactory(new AtomicInteger()); + Map columnHandles = ImmutableMap.of(columnHandle, cacheColumnId); + + // use original dynamic filter + CacheDriverFactory cacheDriverFactory = createCacheDriverFactory( + signature, + columnHandles, + createStaticDynamicFilterSupplier(ImmutableList.of(new TestDynamicFilter(TupleDomain.withColumnDomains( + ImmutableMap.of(columnHandle, multipleValues(BIGINT, LongStream.range(0L, 5000L).boxed().toList()))), true))), + createStaticDynamicFilterSupplier(ImmutableList.of(new TestDynamicFilter(TupleDomain.withColumnDomains( + ImmutableMap.of(columnHandle, multipleValues(BIGINT, LongStream.range(0L, 100L).boxed().toList()))), true))), + driverFactory); + + Optional pageSource = Optional.of(new EmptyPageSource()); + splitCache.addExpectedCacheLookup( + Optional.of(SPLIT_ID), + Optional.of(TupleDomain.withColumnDomains( + ImmutableMap.of(cacheColumnId, multipleValues(BIGINT, LongStream.range(0L, 100L).boxed().toList())))), + Optional.of(TupleDomain.withColumnDomains( + ImmutableMap.of(cacheColumnId, multipleValues(BIGINT, LongStream.range(0L, 100L).boxed().toList())))), + pageSource); + Driver driver = cacheDriverFactory.createDriver(createDriverContext(), createSplit(Optional.of(SPLIT_ID))); + assertThat(driver.getDriverContext().getCacheDriverContext()).isPresent(); + assertThat(driver.getDriverContext().getCacheDriverContext().get().pageSource()).isEqualTo(pageSource); + + // use common dynamic filter + cacheDriverFactory = createCacheDriverFactory( + signature, + columnHandles, + createStaticDynamicFilterSupplier(ImmutableList.of(new TestDynamicFilter(TupleDomain.withColumnDomains( + ImmutableMap.of(columnHandle, multipleValues(BIGINT, LongStream.range(0L, 150L).boxed().toList()))), true))), + createStaticDynamicFilterSupplier(ImmutableList.of(new 
TestDynamicFilter(TupleDomain.withColumnDomains( + ImmutableMap.of(columnHandle, multipleValues(BIGINT, LongStream.range(0L, 100L).boxed().toList()))), true))), + driverFactory); + + splitCache.addExpectedCacheLookup( + Optional.of(SPLIT_ID), + Optional.of(TupleDomain.withColumnDomains( + ImmutableMap.of(cacheColumnId, multipleValues(BIGINT, LongStream.range(0L, 100L).boxed().toList())))), + Optional.of(TupleDomain.withColumnDomains( + ImmutableMap.of(cacheColumnId, multipleValues(BIGINT, LongStream.range(0L, 150L).boxed().toList())))), + pageSource); + driver = cacheDriverFactory.createDriver(createDriverContext(), createSplit(Optional.of(SPLIT_ID))); + assertThat(driver.getDriverContext().getCacheDriverContext()).isPresent(); + assertThat(driver.getDriverContext().getCacheDriverContext().get().pageSource()).isEqualTo(pageSource); + } + + @Test + public void testCreateDriverWhenDynamicFilterWasChanged() + { + CacheColumnId columnId = new CacheColumnId("cacheColumnId"); + PlanSignatureWithPredicate signature = new PlanSignatureWithPredicate( + new PlanSignature(SIGNATURE_KEY, Optional.empty(), ImmutableList.of(columnId), ImmutableList.of(BIGINT)), + TupleDomain.all()); + ColumnHandle columnHandle = new TestingColumnHandle("column"); + TupleDomain originalDynamicPredicate = TupleDomain.withColumnDomains(ImmutableMap.of(columnHandle, singleValue(BIGINT, 0L))); + + DriverFactory driverFactory = createDriverFactory(new AtomicInteger()); + TestDynamicFilter commonDynamicFilter = new TestDynamicFilter(TupleDomain.all(), false); + Map columnHandles = ImmutableMap.of(columnHandle, columnId); + CacheDriverFactory cacheDriverFactory = createCacheDriverFactory( + signature, + columnHandles, + createStaticDynamicFilterSupplier(ImmutableList.of(commonDynamicFilter)), + createStaticDynamicFilterSupplier(ImmutableList.of(new TestDynamicFilter(originalDynamicPredicate, true))), + driverFactory); + + // baseSignature should use original dynamic filter because it contains more 
domains + splitCache.addExpectedCacheLookup(TupleDomain.all(), originalDynamicPredicate.transformKeys(columnHandles::get)); + Driver driver = cacheDriverFactory.createDriver(createDriverContext(), createSplit(Optional.of(SPLIT_ID))); + assertThat(driver.getDriverContext().getCacheDriverContext()).isPresent(); + + // baseSignature should use common dynamic filter as it uses same domains + commonDynamicFilter.setDynamicPredicate( + TupleDomain.withColumnDomains(ImmutableMap.of(columnHandle, multipleValues(BIGINT, ImmutableList.of(0L, 1L)))), + true); + splitCache.addExpectedCacheLookup(TupleDomain.all(), commonDynamicFilter.getCurrentPredicate().transformKeys(columnHandles::get)); + cacheDriverFactory.createDriver(createDriverContext(), createSplit(Optional.of(SPLIT_ID))); + assertThat(driver.getDriverContext().getCacheDriverContext()).isPresent(); + } + + @Test + public void testPrunesAndProjectsPredicates() + { + CacheColumnId projectedScanColumnId = new CacheColumnId("projectedScanColumnId"); + CacheColumnId projectedColumnId = new CacheColumnId("projectedColumnId"); + CacheColumnId nonProjectedScanColumnId = new CacheColumnId("nonProjectedScanColumnId"); + + ColumnHandle projectedScanColumnHandle = new TestingColumnHandle("projectedScanColumnId"); + ColumnHandle nonProjectedScanColumnHandle = new TestingColumnHandle("nonProjectedScanColumnHandle"); + Map columnHandles = ImmutableMap.of(projectedScanColumnId, projectedScanColumnHandle, nonProjectedScanColumnId, nonProjectedScanColumnHandle); + + PlanSignatureWithPredicate signature = new PlanSignatureWithPredicate( + new PlanSignature(SIGNATURE_KEY, Optional.empty(), ImmutableList.of(projectedScanColumnId, projectedColumnId), ImmutableList.of(BIGINT, BIGINT)), + TupleDomain.withColumnDomains(ImmutableMap.of( + projectedScanColumnId, singleValue(BIGINT, 100L), + projectedColumnId, singleValue(BIGINT, 110L), + nonProjectedScanColumnId, singleValue(BIGINT, 120L)))); + + StaticDynamicFilter dynamicFilter = 
createStaticDynamicFilter(ImmutableList.of(new TestDynamicFilter( + TupleDomain.withColumnDomains(ImmutableMap.of( + projectedScanColumnHandle, singleValue(BIGINT, 200L), + nonProjectedScanColumnHandle, singleValue(BIGINT, 220L))), + true))); + + PageSourceProvider pageSourceProvider = new TestPageSourceProvider( + // unenforcedPredicateSupplier + _ -> TupleDomain.withColumnDomains(ImmutableMap.of( + projectedScanColumnHandle, singleValue(BIGINT, 300L), + nonProjectedScanColumnHandle, singleValue(BIGINT, 310L))), + // prunePredicateSupplier + _ -> TupleDomain.withColumnDomains(ImmutableMap.of( + projectedScanColumnHandle, singleValue(BIGINT, 400L), + nonProjectedScanColumnHandle, singleValue(BIGINT, 410L)))); + DriverFactory driverFactory = createDriverFactory(new AtomicInteger()); + CacheDriverFactory cacheDriverFactory = new CacheDriverFactory( + 0, + true, + true, + OptionalInt.empty(), + PLAN_NODE_ID, + Session.builder(TEST_SESSION) + // dynamic row filtering prevents propagation of domain values + .setSystemProperty(ENABLE_DYNAMIC_ROW_FILTERING, "false") + .build(), + pageSourceProvider, + registry, + tupleDomainCodec, + TEST_TABLE_HANDLE, + signature, + columnHandles, + createStaticDynamicFilterSupplier(ImmutableList.of(dynamicFilter)), + createStaticDynamicFilterSupplier(ImmutableList.of(EMPTY)), + ImmutableList.of(driverFactory, driverFactory, driverFactory), + new CacheStats()); + + splitCache.addExpectedCacheLookup( + // cacheId + appendRemainingPredicates( + SPLIT_ID, + Optional.of(tupleDomainCodec.toJson(normalizeTupleDomain(TupleDomain.withColumnDomains(ImmutableMap.of(nonProjectedScanColumnId, singleValue(BIGINT, 410L)))))), + Optional.of(tupleDomainCodec.toJson(normalizeTupleDomain(TupleDomain.withColumnDomains(ImmutableMap.of(nonProjectedScanColumnId, singleValue(BIGINT, 310L))))))), + // predicate + TupleDomain.withColumnDomains(ImmutableMap.of( + projectedColumnId, singleValue(BIGINT, 110L), + projectedScanColumnId, singleValue(BIGINT, 400L))), + 
// unenforcedPredicate + TupleDomain.withColumnDomains(ImmutableMap.of(projectedScanColumnId, singleValue(BIGINT, 300L)))); + Driver driver = cacheDriverFactory.createDriver(createDriverContext(), createSplit(Optional.of(SPLIT_ID))); + assertThat(driver.getDriverContext().getCacheDriverContext()).isPresent(); + } + + @Test + public void testDynamicRowFilteringUnenforcedPredicate() + { + Split split = new Split(TEST_CATALOG_HANDLE, createRemoteSplit(), Optional.of(SPLIT_ID), true); + CacheColumnId nonDfColumnId = new CacheColumnId("nonDfColumnId"); + CacheColumnId dfColumnId1 = new CacheColumnId("dfColumnId"); + CacheColumnId dfColumnId2 = new CacheColumnId("dfColumnId2"); + + ColumnHandle nonDfColumnHandle = new TestingColumnHandle("nonDfColumnHandle"); + ColumnHandle dfColumnHandle1 = new TestingColumnHandle("dfColumnHandle1"); + ColumnHandle dfColumnHandle2 = new TestingColumnHandle("dfColumnHandle2"); + + PlanSignatureWithPredicate signature = new PlanSignatureWithPredicate( + new PlanSignature(SIGNATURE_KEY, Optional.empty(), ImmutableList.of(nonDfColumnId, dfColumnId1, dfColumnId2), ImmutableList.of(BIGINT, BIGINT, BIGINT)), + TupleDomain.all()); + TestPageSourceProvider pageSourceProvider = new TestPageSourceProvider( + (_) -> TupleDomain.withColumnDomains(ImmutableMap.of( + nonDfColumnHandle, singleValue(BIGINT, 10L), + dfColumnHandle1, multipleValues(BIGINT, ImmutableList.of(20L, 22L, 23L)))), + (_) -> TupleDomain.withColumnDomains(ImmutableMap.of( + dfColumnHandle1, multipleValues(BIGINT, ImmutableList.of(20L, 21L))))); + DriverFactory driverFactory = createDriverFactory(new AtomicInteger()); + CacheDriverFactory cacheDriverFactory = createCacheDriverFactory( + pageSourceProvider, + signature, + ImmutableMap.of(nonDfColumnHandle, nonDfColumnId, dfColumnHandle1, dfColumnId1, dfColumnHandle2, dfColumnId2), + createStaticDynamicFilterSupplier(ImmutableList.of(new TestDynamicFilter(TupleDomain.withColumnDomains(ImmutableMap.of( + dfColumnHandle1, 
multipleValues(BIGINT, ImmutableList.of(20L, 21L, 23L)), + dfColumnHandle2, singleValue(BIGINT, 30L))), true))), + createStaticDynamicFilterSupplier(ImmutableList.of(EMPTY)), + driverFactory); + // signature unenforced predicate consists of pruned DF columns which are intersected with unenforced predicate of delegate page source provider + splitCache.addExpectedCacheLookup( + TupleDomain.withColumnDomains(ImmutableMap.of( + dfColumnId1, multipleValues(BIGINT, ImmutableList.of(20L, 21L)))), + TupleDomain.withColumnDomains(ImmutableMap.of( + nonDfColumnId, singleValue(BIGINT, 10L), + dfColumnId1, singleValue(BIGINT, 20L)))); + Driver driver = cacheDriverFactory.createDriver(createDriverContext(), createSplit(Optional.of(SPLIT_ID))); + assertThat(driver.getDriverContext().getCacheDriverContext()).isPresent(); + + // delegate provider returns TupleDomain.none() + assertThat(getDynamicRowFilteringUnenforcedPredicate( + new TestPageSourceProvider((_) -> TupleDomain.none(), (_) -> TupleDomain.none()), + TEST_SESSION, + split, + TEST_TABLE_HANDLE, + TupleDomain.withColumnDomains(ImmutableMap.of(dfColumnHandle1, singleValue(BIGINT, 1L))))) + .isEqualTo(TupleDomain.none()); + + // delegate provider returns TupleDomain.all() + assertThat(getDynamicRowFilteringUnenforcedPredicate( + new TestPageSourceProvider((_) -> TupleDomain.all(), (_) -> TupleDomain.all()), + TEST_SESSION, + split, + TEST_TABLE_HANDLE, + TupleDomain.withColumnDomains(ImmutableMap.of(dfColumnHandle1, singleValue(BIGINT, 1L))))) + .isEqualTo(TupleDomain.all()); + } + + private Optional createSplit(Optional cacheSplitId) + { + return createSplit(cacheSplitId, true); + } + + private Optional createSplit(Optional cacheSplitId, boolean splitAddressEnforced) + { + return Optional.of(new ScheduledSplit( + 0, + PLAN_NODE_ID, + new Split(TEST_CATALOG_HANDLE, createRemoteSplit(), cacheSplitId, splitAddressEnforced))); + } + + private static class TestPageSourceProvider + implements PageSourceProvider + { + private 
final Function, TupleDomain> unenforcedPredicateSupplier; + private final Function, TupleDomain> prunePredicateSupplier; + + public TestPageSourceProvider( + Function, TupleDomain> unenforcedPredicateSupplier, + Function, TupleDomain> prunePredicateSupplier) + { + this.unenforcedPredicateSupplier = unenforcedPredicateSupplier; + this.prunePredicateSupplier = prunePredicateSupplier; + } + + public TestPageSourceProvider() + { + // mimic connector returning compact effective predicate on extra column + this(identity(), identity()); + } + + @Override + public ConnectorPageSource createPageSource( + Session session, + Split split, + TableHandle table, + List columns, + DynamicFilter dynamicFilter) + { + return new FixedPageSource(rowPagesBuilder(BIGINT).build()); + } + + @Override + public TupleDomain getUnenforcedPredicate( + Session session, + Split split, + TableHandle table, + TupleDomain predicate) + { + return unenforcedPredicateSupplier.apply(predicate); + } + + @Override + public TupleDomain prunePredicate( + Session session, + Split split, + TableHandle table, + TupleDomain predicate) + { + return prunePredicateSupplier.apply(predicate); + } + } + + private CacheDriverFactory createCacheDriverFactory( + PlanSignatureWithPredicate signature, + Map columnHandles, + Supplier commonDynamicFilterSupplier, + Supplier originalDynamicFilterSupplier, + DriverFactory driverFactory) + { + return createCacheDriverFactory( + new TestPageSourceProvider(), + signature, + columnHandles, + commonDynamicFilterSupplier, + originalDynamicFilterSupplier, + driverFactory); + } + + private CacheDriverFactory createCacheDriverFactory( + TestPageSourceProvider pageSourceProvider, + PlanSignatureWithPredicate signature, + Map columnHandles, + Supplier commonDynamicFilterSupplier, + Supplier originalDynamicFilterSupplier, + DriverFactory driverFactory) + { + return new CacheDriverFactory( + 0, + true, + true, + OptionalInt.empty(), + PLAN_NODE_ID, + TEST_SESSION, + pageSourceProvider, + 
registry, + tupleDomainCodec, + TEST_TABLE_HANDLE, + signature, + columnHandles.entrySet().stream().collect(toImmutableMap(Map.Entry::getValue, Map.Entry::getKey)), + commonDynamicFilterSupplier, + originalDynamicFilterSupplier, + ImmutableList.of(driverFactory, driverFactory, driverFactory), + new CacheStats()); + } + + private CacheDriverFactory createCacheDriverFactory( + TestPageSourceProvider pageSourceProvider, + PlanSignatureWithPredicate signature, + AtomicInteger operatorIdAllocator) + { + DriverFactory driverFactory = createDriverFactory(operatorIdAllocator); + return new CacheDriverFactory(0, + true, + true, + OptionalInt.empty(), + PLAN_NODE_ID, + TEST_SESSION, + pageSourceProvider, + registry, + tupleDomainCodec, + TEST_TABLE_HANDLE, + signature, + ImmutableMap.of(new CacheColumnId("column"), new TestingColumnHandle("column")), + createStaticDynamicFilterSupplier(ImmutableList.of(EMPTY)), + createStaticDynamicFilterSupplier(ImmutableList.of(EMPTY)), + ImmutableList.of(driverFactory, driverFactory, driverFactory), + new CacheStats()); + } + + private DriverFactory createDriverFactory(AtomicInteger operatorIdAllocator) + { + return new OperatorDriverFactory( + operatorIdAllocator.incrementAndGet(), + true, + false, + ImmutableList.of( + new DevNullOperator.DevNullOperatorFactory(0, planNodeIdAllocator.getNextId()), + new DevNullOperator.DevNullOperatorFactory(1, planNodeIdAllocator.getNextId())), + OptionalInt.empty()); + } + + private DriverContext createDriverContext() + { + return TestingTaskContext.createTaskContext(directExecutor(), scheduledExecutor, TEST_SESSION) + .addPipelineContext(0, true, true, false) + .addDriverContext(); + } + + private static class TestDynamicFilter + implements DynamicFilter + { + private TupleDomain dynamicPredicate; + private boolean complete; + private CompletableFuture blocked = new CompletableFuture<>(); + + public TestDynamicFilter(TupleDomain dynamicPredicate, boolean complete) + { + this.dynamicPredicate = 
dynamicPredicate; + this.complete = complete; + } + + public void setDynamicPredicate(TupleDomain dynamicPredicate, boolean complete) + { + checkState(!this.complete); + this.dynamicPredicate = dynamicPredicate; + this.complete = complete; + CompletableFuture blocked = this.blocked; + if (!complete) { + this.blocked = new CompletableFuture<>(); + } + blocked.complete(null); + } + + @Override + public Set getColumnsCovered() + { + return ImmutableSet.of(); + } + + @Override + public CompletableFuture isBlocked() + { + return blocked; + } + + @Override + public boolean isComplete() + { + return complete; + } + + @Override + public boolean isAwaitable() + { + return !complete; + } + + @Override + public TupleDomain getCurrentPredicate() + { + return dynamicPredicate; + } + } + + private static class TestCacheManagerFactory + implements CacheManagerFactory + { + private final TestCacheManager cacheManager = new TestCacheManager(); + + public TestCacheManager getCacheManager() + { + return cacheManager; + } + + @Override + public String getName() + { + return "TestCacheManager"; + } + + @Override + public TestCacheManager create(Map config, CacheManagerContext context) + { + return cacheManager; + } + } + + private static class TestCacheManager + implements CacheManager + { + private final TestSplitCache splitCache = new TestSplitCache(); + + public TestSplitCache getSplitCache() + { + return splitCache; + } + + @Override + public TestSplitCache getSplitCache(PlanSignature signature) + { + return splitCache; + } + + @Override + public long revokeMemory(long bytesToRevoke) + { + throw new UnsupportedOperationException(); + } + } + + private static class TestSplitCache + implements SplitCache + { + private final LinkedList expectedCacheLookups = new LinkedList<>(); + + public void addExpectedCacheLookup(TupleDomain predicate, TupleDomain unenforcedPredicate) + { + addExpectedCacheLookup(Optional.empty(), Optional.of(predicate), Optional.of(unenforcedPredicate), 
Optional.of(new EmptyPageSource())); + } + + public void addExpectedCacheLookup(CacheSplitId cacheSplitId, TupleDomain predicate, TupleDomain unenforcedPredicate) + { + addExpectedCacheLookup(Optional.of(cacheSplitId), Optional.of(predicate), Optional.of(unenforcedPredicate), Optional.of(new EmptyPageSource())); + } + + public void addExpectedCacheLookup( + Optional expectedCacheSplitId, + Optional> expectedPredicate, + Optional> expectedUnenforcedPredicate, + Optional pageSource) + { + expectedCacheLookups.add(new CacheLookup(expectedCacheSplitId, expectedPredicate, expectedUnenforcedPredicate, pageSource)); + } + + @Override + public Optional loadPages(CacheSplitId splitId, TupleDomain predicate, TupleDomain unenforcedPredicate) + { + assertThat(expectedCacheLookups.size()).isGreaterThan(0); + CacheLookup cacheLookup = expectedCacheLookups.pollFirst(); + cacheLookup.expectedPredicate.ifPresent(expected -> assertThat(predicate).isEqualTo(expected)); + cacheLookup.expectedUnenforcedPredicate.ifPresent(expected -> assertThat(unenforcedPredicate).isEqualTo(expected)); + cacheLookup.expectedCacheSplitId.ifPresent(expected -> assertThat(splitId).isEqualTo(expected)); + return cacheLookup.pageSource; + } + + @Override + public Optional storePages(CacheSplitId splitId, TupleDomain predicate, TupleDomain unenforcedPredicate) + { + assertThat(expectedCacheLookups.size()).isGreaterThan(0); + CacheLookup cacheLookup = expectedCacheLookups.pollFirst(); + cacheLookup.expectedPredicate.ifPresent(expected -> assertThat(predicate).isEqualTo(expected)); + cacheLookup.expectedUnenforcedPredicate.ifPresent(expected -> assertThat(unenforcedPredicate).isEqualTo(expected)); + cacheLookup.expectedCacheSplitId.ifPresent(expected -> assertThat(splitId).isEqualTo(expected)); + return Optional.empty(); + } + + @Override + public void close() {} + + record CacheLookup( + Optional expectedCacheSplitId, + Optional> expectedPredicate, + Optional> expectedUnenforcedPredicate, + Optional 
pageSource) {} + } +} diff --git a/core/trino-main/src/test/java/io/trino/cache/TestCacheManagerRegistry.java b/core/trino-main/src/test/java/io/trino/cache/TestCacheManagerRegistry.java new file mode 100644 index 000000000000..03bfffa1188f --- /dev/null +++ b/core/trino-main/src/test/java/io/trino/cache/TestCacheManagerRegistry.java @@ -0,0 +1,158 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.cache; + +import com.google.common.collect.ImmutableMap; +import com.google.common.util.concurrent.ListenableFuture; +import io.airlift.units.DataSize; +import io.trino.execution.StageId; +import io.trino.execution.TaskId; +import io.trino.memory.LocalMemoryManager; +import io.trino.memory.NodeMemoryConfig; +import io.trino.spi.block.TestingBlockEncodingSerde; +import io.trino.spi.cache.CacheManager; +import io.trino.spi.cache.CacheManagerContext; +import io.trino.spi.cache.CacheManagerFactory; +import io.trino.spi.cache.MemoryAllocator; +import io.trino.spi.cache.PlanSignature; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; +import org.junit.jupiter.api.parallel.Execution; + +import java.util.Map; +import java.util.OptionalLong; + +import static com.google.common.util.concurrent.MoreExecutors.newDirectExecutorService; +import static io.airlift.units.DataSize.Unit.MEGABYTE; +import static java.util.Objects.requireNonNull; +import static 
org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS; +import static org.junit.jupiter.api.parallel.ExecutionMode.SAME_THREAD; + +@TestInstance(PER_CLASS) +@Execution(SAME_THREAD) +public class TestCacheManagerRegistry +{ + private static final TaskId TASK_ID = new TaskId(new StageId("id", 0), 1, 2); + private static final String TEST_CACHE_MANAGER = "test-manager"; + + private LocalMemoryManager memoryManager; + private TestCacheManager cacheManager; + private CacheManagerRegistry registry; + + @BeforeEach + public void setup() + { + NodeMemoryConfig config = new NodeMemoryConfig() + .setHeapHeadroom(DataSize.of(10, MEGABYTE)) + .setMaxQueryMemoryPerNode(DataSize.of(100, MEGABYTE)); + + memoryManager = new LocalMemoryManager(config, DataSize.of(110, MEGABYTE).toBytes()); + registry = new CacheManagerRegistry(new CacheConfig(), memoryManager, newDirectExecutorService(), new TestingBlockEncodingSerde(), new CacheStats()); + registry.addCacheManagerFactory(new TestCacheManagerFactory()); + registry.loadCacheManager(TEST_CACHE_MANAGER, ImmutableMap.of()); + } + + @Test + public void testRevokeMemoryOnListener() + { + assertThat(cacheManager.tryAllocateMemory(DataSize.of(90, MEGABYTE).toBytes())).isTrue(); + + // revoke should not be triggered + assertThat(cacheManager.getBytesToRevoke()).isEmpty(); + + // allocating query memory should trigger cache revoke + ListenableFuture memoryFuture = memoryManager.getMemoryPool().reserve(TASK_ID, "allocation", DataSize.of(80, MEGABYTE).toBytes()); + assertThat(memoryFuture).isNotDone(); + assertThat(cacheManager.getBytesToRevoke()).hasValue(DataSize.of(100, MEGABYTE).toBytes()); + + // freeing memory should unblock memory future + assertThat(cacheManager.tryAllocateMemory(20)).isTrue(); + assertThat(memoryFuture).isDone(); + } + + @Test + public void testRevokeMemoryOnBigAllocation() + { + assertThat(cacheManager.tryAllocateMemory(DataSize.of(90, 
MEGABYTE).toBytes())).isTrue(); + assertThat(cacheManager.getBytesToRevoke()).isEmpty(); + assertThat(registry.getNonEmptyRevokeCount()).isEqualTo(0); + assertThat(registry.getDistributionSizeRevokedMemory().getCount()).isEqualTo(0); + + assertThat(cacheManager.tryAllocateMemory(DataSize.of(95, MEGABYTE).toBytes())).isFalse(); + assertThat(cacheManager.getBytesToRevoke()).hasValue(DataSize.of(20, MEGABYTE).toBytes()); + assertThat(registry.getNonEmptyRevokeCount()).isEqualTo(1); + assertThat(registry.getDistributionSizeRevokedMemory().getCount()).isEqualTo(1); + assertThat(registry.getDistributionSizeRevokedMemory().getAvg()).isEqualTo(DataSize.of(20, MEGABYTE).toBytes()); + } + + private class TestCacheManagerFactory + implements CacheManagerFactory + { + @Override + public String getName() + { + return TEST_CACHE_MANAGER; + } + + @Override + public CacheManager create(Map config, CacheManagerContext context) + { + requireNonNull(context, "context is null"); + requireNonNull(context.revocableMemoryAllocator(), "revocableMemoryAllocator is null"); + requireNonNull(context.blockEncodingSerde(), "revocableMemoryAllocator is null"); + cacheManager = new TestCacheManager(context.revocableMemoryAllocator()); + return cacheManager; + } + } + + private static class TestCacheManager + implements CacheManager + { + private final MemoryAllocator allocator; + private OptionalLong bytesToRevoke = OptionalLong.empty(); + + private TestCacheManager(MemoryAllocator allocator) + { + this.allocator = requireNonNull(allocator, "allocator is null"); + } + + @Override + public SplitCache getSplitCache(PlanSignature signature) + { + throw new UnsupportedOperationException(); + } + + @Override + public long revokeMemory(long bytesToRevoke) + { + if (this.bytesToRevoke.isPresent()) { + return 0L; + } + this.bytesToRevoke = OptionalLong.of(bytesToRevoke); + return bytesToRevoke; + } + + private boolean tryAllocateMemory(long bytes) + { + return allocator.trySetBytes(bytes); + } + + 
private OptionalLong getBytesToRevoke() + { + return bytesToRevoke; + } + } +} diff --git a/core/trino-main/src/test/java/io/trino/cache/TestCacheSplitAdmissionController.java b/core/trino-main/src/test/java/io/trino/cache/TestCacheSplitAdmissionController.java new file mode 100644 index 000000000000..b147311af59f --- /dev/null +++ b/core/trino-main/src/test/java/io/trino/cache/TestCacheSplitAdmissionController.java @@ -0,0 +1,279 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.cache; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Iterables; +import io.trino.client.NodeVersion; +import io.trino.metadata.InternalNode; +import io.trino.metadata.Split; +import io.trino.spi.HostAddress; +import io.trino.spi.Node; +import io.trino.spi.cache.CacheSplitId; +import io.trino.spi.cache.PlanSignature; +import io.trino.spi.cache.SignatureKey; +import io.trino.spi.connector.ConnectorSplit; +import io.trino.spi.connector.ConnectorSplitManager; +import io.trino.split.MockSplitSource; +import io.trino.split.SplitSource; +import org.junit.jupiter.api.Test; + +import java.net.URI; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.ExecutionException; +import java.util.function.Function; +import java.util.stream.IntStream; + +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static com.google.common.collect.Iterables.getOnlyElement; +import static io.trino.split.MockSplitSource.Action.FINISH; +import static io.trino.split.SplitSource.SplitBatch; +import static java.lang.Math.abs; +import static org.assertj.core.api.Assertions.assertThat; + +public class TestCacheSplitAdmissionController +{ + @Test + public void testMinSplitProcessedPerWorkerThreshold() + throws ExecutionException, InterruptedException + { + int minSplitSeparation = 5; + int availableSplits = 50; + int minSplitBatchSize = 20; + MinSeparationSplitAdmissionController controller = new MinSeparationSplitAdmissionController(minSplitSeparation); + Function> addressProvider = createAddressProvider(2); + SplitSource cacheSplitSourceA = createCacheSplitSource(addressProvider, controller, availableSplits, 
minSplitBatchSize); + SplitSource cacheSplitSourceB = createCacheSplitSource(addressProvider, controller, availableSplits, minSplitBatchSize); + SplitSource cacheSplitSourceC = createCacheSplitSource(addressProvider, controller, availableSplits, minSplitBatchSize); + + // Batch 1: + SplitBatch batchA1 = cacheSplitSourceA.getNextBatch(10).get(); + assertCacheSplitIds(batchA1, ImmutableMap.of( + "node0", createCacheSplitIds(IntStream.of(1, 3, 5, 7, 9)), + "node1", createCacheSplitIds(IntStream.of(0, 2, 4, 6, 8)))); + assertThat(batchA1.isLastBatch()).isFalse(); + + SplitBatch batchB1 = cacheSplitSourceB.getNextBatch(10).get(); + assertCacheSplitIds(batchB1, ImmutableMap.of( + "node0", createCacheSplitIds(IntStream.of(10, 12, 14, 16, 18)), + "node1", createCacheSplitIds(IntStream.of(11, 13, 15, 17, 19)))); + assertThat(batchB1.isLastBatch()).isFalse(); + + SplitBatch batchC1 = cacheSplitSourceC.getNextBatch(10).get(); + assertCacheSplitIds(batchC1, ImmutableMap.of( + "node0", createCacheSplitIds(IntStream.of(21, 23, 25, 27, 29)), + "node1", createCacheSplitIds(IntStream.of(20, 22, 24, 26, 28)))); + assertThat(batchC1.isLastBatch()).isFalse(); + + controller.splitsScheduled(batchA1.getSplits()); + controller.splitsScheduled(batchB1.getSplits()); + controller.splitsScheduled(batchC1.getSplits()); + + // Batch 2: + SplitBatch batchA2 = cacheSplitSourceA.getNextBatch(20).get(); + // Since we have crossed the gap limit for some splits, we will get those splits from queue + assertCacheSplitIds(batchA2, ImmutableMap.of( + "node0", createCacheSplitIds(IntStream.of(10, 12, 14, 16, 18, 21, 30, 32, 34, 36)), + "node1", createCacheSplitIds(IntStream.of(11, 13, 15, 17, 19, 20, 31, 33, 35, 37)))); + assertThat(batchA2.isLastBatch()).isFalse(); + + SplitBatch batchB2 = cacheSplitSourceB.getNextBatch(19).get(); + // Since we have crossed the gap limit for some splits, we will get those splits from queue + assertCacheSplitIds(batchB2, ImmutableMap.of( + "node0", 
createCacheSplitIds(IntStream.of(1, 3, 5, 7, 9, 21, 38, 41, 43)), + "node1", createCacheSplitIds(IntStream.of(0, 2, 4, 6, 8, 20, 39, 40, 42, 44)))); + assertThat(batchB2.isLastBatch()).isFalse(); + + SplitBatch batchC2 = cacheSplitSourceC.getNextBatch(20).get(); + // Since we have crossed the gap limit for some splits, we will get those splits from queue + assertCacheSplitIds(batchC2, ImmutableMap.of( + "node0", createCacheSplitIds(IntStream.of(1, 3, 5, 7, 9, 10, 12, 14, 16, 18)), + "node1", createCacheSplitIds(IntStream.of(0, 2, 4, 6, 8, 11, 13, 15, 17, 19)))); + assertThat(batchC2.isLastBatch()).isFalse(); + + controller.splitsScheduled(batchA2.getSplits()); + + // Batch3 + // Since we have scheduled all the splits, we will get the remaining splits from the queue + SplitBatch batchA3 = cacheSplitSourceA.getNextBatch(30).get(); + assertCacheSplitIds(batchA3, ImmutableMap.of( + "node0", createCacheSplitIds(IntStream.of(23, 25, 27, 29, 45, 47, 49, 38, 41, 43)), + "node1", createCacheSplitIds(IntStream.of(22, 24, 26, 28, 46, 48, 39, 40, 42, 44)))); + assertThat(batchA3.isLastBatch()).isTrue(); + + SplitBatch batchB3 = cacheSplitSourceB.getNextBatch(30).get(); + assertCacheSplitIds(batchB3, ImmutableMap.of( + "node0", createCacheSplitIds(IntStream.of(23, 25, 27, 29, 30, 32, 34, 36, 45, 47, 49)), + "node1", createCacheSplitIds(IntStream.of(22, 24, 26, 28, 31, 33, 35, 37, 46, 48)))); + assertThat(batchB3.isLastBatch()).isTrue(); + + SplitBatch batchC3 = cacheSplitSourceC.getNextBatch(30).get(); + assertCacheSplitIds(batchC3, ImmutableMap.of( + "node0", createCacheSplitIds(IntStream.of(30, 32, 34, 36, 38, 41, 43, 45, 47, 49)), + "node1", createCacheSplitIds(IntStream.of(31, 33, 35, 37, 39, 40, 42, 44, 46, 48)))); + assertThat(batchC3.isLastBatch()).isTrue(); + + // Assert that all splits are scheduled + assertThat(batchA1.getSplits().size() + batchA2.getSplits().size() + batchA3.getSplits().size()).isEqualTo(50); + assertThat(batchB1.getSplits().size() + 
batchB2.getSplits().size() + batchB3.getSplits().size()).isEqualTo(50); + assertThat(batchC1.getSplits().size() + batchC2.getSplits().size() + batchC3.getSplits().size()).isEqualTo(50); + + cacheSplitSourceA.close(); + cacheSplitSourceB.close(); + cacheSplitSourceC.close(); + } + + @Test + public void testSplitsFromQueueAreFairlyDistributedAmongNodes() + throws ExecutionException, InterruptedException + { + int minSplitSeparation = 1; + int availableSplits = 100; + int minSplitBatchSize = 40; + MinSeparationSplitAdmissionController controller = new MinSeparationSplitAdmissionController(minSplitSeparation); + Function> addressProvider = createAddressProvider(3); + SplitSource cacheSplitSourceA = createCacheSplitSource(addressProvider, controller, availableSplits, minSplitBatchSize); + SplitSource cacheSplitSourceB = createCacheSplitSource(addressProvider, controller, availableSplits, minSplitBatchSize); + + // Batch 1: + SplitBatch batchA1 = cacheSplitSourceA.getNextBatch(40).get(); + assertCacheSplitIds(batchA1, createCacheSplitIds(IntStream.range(0, 40))); + assertThat(batchA1.isLastBatch()).isFalse(); + + SplitBatch batchB1 = cacheSplitSourceB.getNextBatch(40).get(); + assertCacheSplitIds(batchB1, createCacheSplitIds(IntStream.range(40, 80))); + assertThat(batchB1.isLastBatch()).isFalse(); + + controller.splitsScheduled(batchA1.getSplits()); + controller.splitsScheduled(batchB1.getSplits()); + + // Batch 2: + // This will be mostly from the queue since the minSplitSeparation is quite low, and we have already + // scheduled 30 splits + + // We will get 5 splits from the queue. However, these 5 splits should belong to diverse nodes such that + // we don't end up scheduling all the splits from a single node to avoid skewness. 
+ SplitBatch batchB2 = cacheSplitSourceB.getNextBatch(5).get(); + // splits are diverse and belong to all available nodes + assertHostAddress(batchB2, ImmutableSet.of("node0", "node1", "node2")); + assertThat(batchB1.isLastBatch()).isFalse(); + + cacheSplitSourceA.close(); + cacheSplitSourceB.close(); + } + + private static List createCacheSplitIds(IntStream splitIds) + { + return splitIds.boxed() + .map(i -> new CacheSplitId("split" + i)) + .collect(toImmutableList()); + } + + private static void assertCacheSplitIds(SplitBatch batch, Map> expected) + { + Map> actual = new HashMap<>(); + for (Split split : batch.getSplits()) { + String address = getOnlyElement(split.getAddresses()).toString(); + actual.computeIfAbsent(address, _ -> new ArrayList<>()).add(split.getCacheSplitId().get()); + } + assertThat(actual).isEqualTo(expected); + } + + private static void assertCacheSplitIds(SplitBatch batch, List expected) + { + List actual = batch.getSplits().stream() + .map(Split::getCacheSplitId) + .filter(Optional::isPresent) + .map(Optional::get) + .collect(toImmutableList()); + assertThat(actual).containsExactlyInAnyOrderElementsOf(expected); + } + + private static void assertHostAddress(SplitBatch batch, Set addresses) + { + Set hostAddresses = batch.getSplits().stream() + .map(Split::getAddresses) + .map(Iterables::getOnlyElement) + .map(HostAddress::toString) + .collect(toImmutableSet()); + assertThat(hostAddresses).containsExactlyInAnyOrderElementsOf(addresses); + } + + private static CacheSplitSource createCacheSplitSource( + Function> addressProvider, + SplitAdmissionController scheduler, + int availableSplits, + int minSplitBatchSize) + { + return new CacheSplitSource( + createPlanSignature("signature1"), + new TestingSplitManager(), + createMockSplitSource(availableSplits), + addressProvider, + scheduler, + minSplitBatchSize); + } + + private static MockSplitSource createMockSplitSource(int availableSplits) + { + MockSplitSource mockSplitSource = new 
MockSplitSource(); + mockSplitSource.setBatchSize(10); + mockSplitSource.increaseAvailableSplits(availableSplits); + mockSplitSource.atSplitCompletion(FINISH); + return mockSplitSource; + } + + private static Function> createAddressProvider(int numNodes) + { + List nodes = IntStream.range(0, numNodes) + .mapToObj(i -> node("node" + i)) + .collect(toImmutableList()); + return value -> Optional.of(nodes.get(abs(value.hashCode()) % numNodes)) + .map(Node::getHostAndPort); + } + + private static Node node(String nodeName) + { + return new InternalNode(nodeName, URI.create("http://" + nodeName + "/"), NodeVersion.UNKNOWN, false); + } + + private static PlanSignature createPlanSignature(String signature) + { + return new PlanSignature( + new SignatureKey(signature), + Optional.empty(), + ImmutableList.of(), + ImmutableList.of()); + } + + private static class TestingSplitManager + implements ConnectorSplitManager + { + private int splitCount; + + @Override + public Optional getCacheSplitId(ConnectorSplit split) + { + return Optional.of(new CacheSplitId("split" + splitCount++)); + } + } +} diff --git a/core/trino-main/src/test/java/io/trino/cache/TestCanonicalSubplanExtractor.java b/core/trino-main/src/test/java/io/trino/cache/TestCanonicalSubplanExtractor.java new file mode 100644 index 000000000000..894df366b4e2 --- /dev/null +++ b/core/trino-main/src/test/java/io/trino/cache/TestCanonicalSubplanExtractor.java @@ -0,0 +1,1030 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.cache; + +import com.google.common.base.Functions; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableListMultimap; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import io.trino.Session; +import io.trino.cache.CanonicalSubplan.AggregationKey; +import io.trino.cache.CanonicalSubplan.FilterProjectKey; +import io.trino.cache.CanonicalSubplan.Key; +import io.trino.cache.CanonicalSubplan.ScanFilterProjectKey; +import io.trino.cache.CanonicalSubplan.TableScan; +import io.trino.cache.CanonicalSubplan.TopNKey; +import io.trino.cache.CanonicalSubplan.TopNRankingKey; +import io.trino.metadata.AbstractMockMetadata; +import io.trino.metadata.MetadataManager; +import io.trino.metadata.ResolvedFunction; +import io.trino.metadata.TableHandle; +import io.trino.metadata.TableProperties; +import io.trino.metadata.TestingFunctionResolution; +import io.trino.plugin.tpch.TpchColumnHandle; +import io.trino.spi.cache.CacheColumnId; +import io.trino.spi.cache.CacheTableId; +import io.trino.spi.connector.ColumnHandle; +import io.trino.spi.connector.ConnectorTableProperties; +import io.trino.spi.connector.SchemaTableName; +import io.trino.spi.connector.SortOrder; +import io.trino.spi.connector.TestingColumnHandle; +import io.trino.spi.function.OperatorType; +import io.trino.spi.predicate.TupleDomain; +import io.trino.sql.DynamicFilters; +import io.trino.sql.PlannerContext; +import io.trino.sql.analyzer.TypeSignatureProvider; +import io.trino.sql.ir.Call; +import io.trino.sql.ir.Comparison; +import io.trino.sql.ir.Constant; +import io.trino.sql.ir.Expression; +import io.trino.sql.ir.Reference; +import io.trino.sql.planner.Plan; +import io.trino.sql.planner.PlanNodeIdAllocator; +import io.trino.sql.planner.Symbol; +import io.trino.sql.planner.TestingPlannerContext; +import 
io.trino.sql.planner.assertions.BasePlanTest; +import io.trino.sql.planner.iterative.rule.test.PlanBuilder; +import io.trino.sql.planner.plan.AggregationNode; +import io.trino.sql.planner.plan.Assignments; +import io.trino.sql.planner.plan.DynamicFilterId; +import io.trino.sql.planner.plan.FilterNode; +import io.trino.sql.planner.plan.PlanNode; +import io.trino.sql.planner.plan.PlanNodeId; +import io.trino.sql.planner.plan.ProjectNode; +import io.trino.sql.planner.plan.TableScanNode; +import io.trino.testing.PlanTester; +import io.trino.testing.TestingHandles; +import io.trino.testing.TestingMetadata; +import org.intellij.lang.annotations.Language; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; + +import java.util.List; +import java.util.Optional; +import java.util.function.Function; + +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.Iterables.getLast; +import static com.google.common.collect.Iterables.getOnlyElement; +import static io.airlift.slice.Slices.utf8Slice; +import static io.trino.SystemSessionProperties.TASK_CONCURRENCY; +import static io.trino.cache.CanonicalSubplanExtractor.canonicalAggregationToColumnId; +import static io.trino.cache.CanonicalSubplanExtractor.canonicalExpressionToColumnId; +import static io.trino.cache.CanonicalSubplanExtractor.columnIdToSymbol; +import static io.trino.cache.CanonicalSubplanExtractor.extractCanonicalSubplans; +import static io.trino.metadata.TestMetadataManager.createTestMetadataManager; +import static io.trino.spi.type.BigintType.BIGINT; +import static io.trino.spi.type.BooleanType.BOOLEAN; +import static io.trino.spi.type.VarcharType.createVarcharType; +import static io.trino.sql.DynamicFilters.createDynamicFilterExpression; +import static io.trino.sql.ir.Comparison.Operator.EQUAL; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; +import static io.trino.sql.ir.IrUtils.and; +import static 
io.trino.sql.planner.LogicalPlanner.Stage.OPTIMIZED_AND_VALIDATED; +import static io.trino.sql.planner.plan.TopNRankingNode.RankingType.RANK; +import static io.trino.sql.planner.plan.TopNRankingNode.RankingType.ROW_NUMBER; +import static io.trino.testing.TestingHandles.TEST_TABLE_HANDLE; +import static io.trino.testing.TestingSession.testSessionBuilder; +import static java.util.Map.entry; +import static java.util.Objects.requireNonNull; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.fail; + +public class TestCanonicalSubplanExtractor + extends BasePlanTest +{ + private static final Session TEST_SESSION = testSessionBuilder().build(); + private static final CacheTableId CACHE_TABLE_ID = new CacheTableId("cache_table_id"); + private static final PlanNodeId SCAN_NODE_ID = new PlanNodeId("scan_id"); + private static final String CATALOG_ID = TEST_TABLE_HANDLE.catalogHandle().getId(); + private static final CacheTableId CATALOG_CACHE_TABLE_ID = new CacheTableId(CATALOG_ID + ":" + CACHE_TABLE_ID); + private static final CacheColumnId CACHE_COL1 = new CacheColumnId("[cache_column1]"); + private static final CacheColumnId CACHE_COL2 = new CacheColumnId("[cache_column2]"); + private static final CacheColumnId REGIONKEY_ID = new CacheColumnId("[regionkey:bigint]"); + private static final CacheColumnId NATIONKEY_ID = new CacheColumnId("[nationkey:bigint]"); + private static final CacheColumnId NAME_ID = new CacheColumnId("[name:varchar(25)]"); + private static final Reference REGIONKEY_REF = new Reference(BIGINT, "[regionkey:bigint]"); + private static final Reference NATIONKEY_REF = new Reference(BIGINT, "[nationkey:bigint]"); + private static final Reference NAME_REF = new Reference(createVarcharType(25), "[name:varchar(25)]"); + private static final Reference CACHE_COL1_REF = new Reference(BIGINT, "[cache_column1]"); + private static final Reference CACHE_COL2_REF = new Reference(BIGINT, "[cache_column2]"); + + 
private static final PlannerContext PLANNER_CONTEXT = TestingPlannerContext.plannerContextBuilder() + .withMetadata(new MockMetadata()) + .withCacheMetadata(new TestCacheMetadata()) + .build(); + private static final TestingFunctionResolution FUNCTIONS = new TestingFunctionResolution(); + private static final ResolvedFunction MULTIPLY_BIGINT = FUNCTIONS.resolveOperator(OperatorType.MULTIPLY, ImmutableList.of(BIGINT, BIGINT)); + private static final ResolvedFunction ADD_BIGINT = FUNCTIONS.resolveOperator(OperatorType.ADD, ImmutableList.of(BIGINT, BIGINT)); + + private PlanBuilder planBuilder; + private String tpchCatalogId; + + public TestCanonicalSubplanExtractor() + { + super(ImmutableMap.of( + // increase task concurrency to get parallel plans + TASK_CONCURRENCY, "4")); + } + + @BeforeAll + public void setup() + { + planBuilder = new PlanBuilder(new PlanNodeIdAllocator(), PLANNER_CONTEXT, TEST_SESSION); + tpchCatalogId = getPlanTester().getCatalogHandle(getPlanTester().getDefaultSession().getCatalog().orElseThrow()).getId(); + } + + @Test + public void testAggregationWithMultipleGroupByColumnsAndPredicate() + { + List subplans = extractCanonicalSubplansForQuery(""" + SELECT sum(nationkey), sum(nationkey) filter(where nationkey > 10) + FROM nation + WHERE regionkey > BIGINT '10' + GROUP BY name, regionkey * 2 + HAVING name = '0123456789012345689012345' AND sum(nationkey) > BIGINT '10'"""); + assertThat(subplans).hasSize(2); + + CacheTableId tableId = new CacheTableId(tpchCatalogId + ":tiny:nation:0.01"); + Expression nonPullableConjunct = new Comparison(GREATER_THAN, REGIONKEY_REF, new Constant(BIGINT, 10L)); + Expression pullableConjunct = new Comparison(EQUAL, NAME_REF, new Constant(createVarcharType(25), utf8Slice("0123456789012345689012345"))); + CanonicalSubplan nonAggregatedSubplan = subplans.get(0); + ScanFilterProjectKey scanFilterProjectKey = new ScanFilterProjectKey(tableId, ImmutableSet.of(nonPullableConjunct, pullableConjunct)); + 
assertThat(nonAggregatedSubplan.getKeyChain()).containsExactly(scanFilterProjectKey); + assertThat(nonAggregatedSubplan.getGroupByColumns()).isEmpty(); + assertThat(nonAggregatedSubplan.getConjuncts()).containsExactly(nonPullableConjunct, pullableConjunct); + assertThat(nonAggregatedSubplan.getPullableConjuncts()).containsExactlyElementsOf(nonAggregatedSubplan.getConjuncts()); + assertThat(nonAggregatedSubplan.getDynamicConjuncts()).isEmpty(); + assertThat(nonAggregatedSubplan.getTableScan()).isPresent(); + assertThat(nonAggregatedSubplan.getChildSubplan()).isEmpty(); + CacheColumnId regionKeyGreaterThan10 = canonicalExpressionToColumnId(new Comparison(GREATER_THAN, NATIONKEY_REF, new Constant(BIGINT, 10L))); + CacheColumnId regionKeyMultiplyBy2 = canonicalExpressionToColumnId(new Call(MULTIPLY_BIGINT, ImmutableList.of(REGIONKEY_REF, new Constant(BIGINT, 2L)))); + assertThat(nonAggregatedSubplan.getAssignments()).containsExactly( + entry(NATIONKEY_ID, CacheExpression.ofProjection(NATIONKEY_REF)), + entry(regionKeyGreaterThan10, CacheExpression.ofProjection(new Comparison(GREATER_THAN, NATIONKEY_REF, new Constant(BIGINT, 10L)))), + entry(NAME_ID, CacheExpression.ofProjection(NAME_REF)), + entry(regionKeyMultiplyBy2, CacheExpression.ofProjection(new Call(MULTIPLY_BIGINT, ImmutableList.of(REGIONKEY_REF, new Constant(BIGINT, 2L)))))); + assertThat(nonAggregatedSubplan.getTableScan().get().getColumnHandles()).containsExactly( + entry(NATIONKEY_ID, new TpchColumnHandle("nationkey", BIGINT)), + entry(NAME_ID, new TpchColumnHandle("name", createVarcharType(25))), + entry(REGIONKEY_ID, new TpchColumnHandle("regionkey", BIGINT))); + assertThat(nonAggregatedSubplan.getTableScan().get().getTableId()).isEqualTo(tableId); + + CanonicalAggregation sum = sumNationkey(); + CanonicalAggregation filteredSum = new CanonicalAggregation( + sumBigint(), + Optional.of(columnIdToSymbol(regionKeyGreaterThan10, BOOLEAN)), + List.of(NATIONKEY_REF)); + CanonicalSubplan aggregatedSubplan = 
subplans.get(1); + assertThat(aggregatedSubplan.getKeyChain()).containsExactly(scanFilterProjectKey, new AggregationKey(aggregatedSubplan.getGroupByColumns().get(), ImmutableSet.of(nonPullableConjunct))); + assertThat(aggregatedSubplan.getConjuncts()).isEmpty(); + assertThat(aggregatedSubplan.getPullableConjuncts()).containsExactly(pullableConjunct); + assertThat(aggregatedSubplan.getDynamicConjuncts()).isEmpty(); + assertThat(aggregatedSubplan.getChildSubplan()).contains(nonAggregatedSubplan); + assertThat(aggregatedSubplan.getOriginalPlanNode()).isInstanceOf(AggregationNode.class); + assertThat(getGroupByExpressions(aggregatedSubplan)).contains(ImmutableList.of( + NAME_REF, + columnIdToSymbol(regionKeyMultiplyBy2, BIGINT).toSymbolReference())); + assertThat(aggregatedSubplan.getOriginalSymbolMapping()).containsOnlyKeys( + NATIONKEY_ID, + NAME_ID, + REGIONKEY_ID, + canonicalExpressionToColumnId(new Comparison(GREATER_THAN, NATIONKEY_REF, new Constant(BIGINT, 10L))), + canonicalExpressionToColumnId(new Call(MULTIPLY_BIGINT, ImmutableList.of(REGIONKEY_REF, new Constant(BIGINT, 2L)))), + canonicalAggregationToColumnId(filteredSum), + canonicalAggregationToColumnId(sum)); + assertThat(aggregatedSubplan.getAssignments()).containsExactlyInAnyOrderEntriesOf(ImmutableMap.builder() + .put(NAME_ID, CacheExpression.ofProjection(NAME_REF)) + .put(canonicalExpressionToColumnId(new Call(MULTIPLY_BIGINT, ImmutableList.of(REGIONKEY_REF, new Constant(BIGINT, 2L)))), CacheExpression.ofProjection(columnIdToSymbol(regionKeyMultiplyBy2, BIGINT).toSymbolReference())) + .put(canonicalAggregationToColumnId(filteredSum), CacheExpression.ofAggregation(filteredSum)) + .put(canonicalAggregationToColumnId(sum), CacheExpression.ofAggregation(sum)) + .buildOrThrow()); + } + + @Test + public void testAggregationWithMultipleGroupByColumns() + { + List subplans = extractCanonicalSubplansForQuery(""" + SELECT sum(nationkey + 1) + FROM nation + GROUP BY name, regionkey"""); + 
assertThat(subplans).hasSize(2); + + CacheTableId tableId = new CacheTableId(tpchCatalogId + ":tiny:nation:0.01"); + CacheColumnId nationKeyPlusOne = canonicalExpressionToColumnId(new Call(ADD_BIGINT, ImmutableList.of(NATIONKEY_REF, new Constant(BIGINT, 1L)))); + CanonicalSubplan nonAggregatedSubplan = subplans.get(0); + assertThat(nonAggregatedSubplan.getKeyChain()).containsExactly(new ScanFilterProjectKey(tableId, ImmutableSet.of())); + assertThat(nonAggregatedSubplan.getGroupByColumns()).isEmpty(); + assertThat(nonAggregatedSubplan.getAssignments()).containsExactly( + entry(nationKeyPlusOne, CacheExpression.ofProjection(new Call(ADD_BIGINT, ImmutableList.of(NATIONKEY_REF, new Constant(BIGINT, 1L))))), + entry(NAME_ID, CacheExpression.ofProjection(NAME_REF)), + entry(REGIONKEY_ID, CacheExpression.ofProjection(REGIONKEY_REF))); + assertThat(nonAggregatedSubplan.getTableScan()).isPresent(); + assertThat(nonAggregatedSubplan.getChildSubplan()).isEmpty(); + assertThat(nonAggregatedSubplan.getTableScan().get().getColumnHandles()).containsExactly( + entry(NATIONKEY_ID, new TpchColumnHandle("nationkey", BIGINT)), + entry(NAME_ID, new TpchColumnHandle("name", createVarcharType(25))), + entry(REGIONKEY_ID, new TpchColumnHandle("regionkey", BIGINT))); + assertThat(nonAggregatedSubplan.getTableScan().get().getTableId()).isEqualTo(tableId); + + CanonicalAggregation sum = new CanonicalAggregation( + sumBigint(), + Optional.empty(), + List.of(columnIdToSymbol(nationKeyPlusOne, BIGINT).toSymbolReference())); + CanonicalSubplan aggregatedSubplan = subplans.get(1); + assertThat(aggregatedSubplan.getKeyChain()).containsExactly(new ScanFilterProjectKey(tableId, ImmutableSet.of()), new AggregationKey(aggregatedSubplan.getGroupByColumns().get(), ImmutableSet.of())); + assertThat(aggregatedSubplan.getOriginalPlanNode()).isInstanceOf(AggregationNode.class); + assertThat(getGroupByExpressions(aggregatedSubplan)).contains(ImmutableList.of(NAME_REF, REGIONKEY_REF)); + 
assertThat(aggregatedSubplan.getOriginalSymbolMapping()).containsOnlyKeys( + NATIONKEY_ID, + NAME_ID, + REGIONKEY_ID, + canonicalExpressionToColumnId(new Call(ADD_BIGINT, ImmutableList.of(NATIONKEY_REF, new Constant(BIGINT, 1L)))), + canonicalAggregationToColumnId(sum)); + assertThat(aggregatedSubplan.getAssignments()).containsExactly( + entry(NAME_ID, CacheExpression.ofProjection(NAME_REF)), + entry(REGIONKEY_ID, CacheExpression.ofProjection(REGIONKEY_REF)), + entry(canonicalAggregationToColumnId(sum), CacheExpression.ofAggregation(sum))); + assertThat(aggregatedSubplan.getConjuncts()).isEmpty(); + assertThat(aggregatedSubplan.getPullableConjuncts()).isEmpty(); + assertThat(aggregatedSubplan.getDynamicConjuncts()).isEmpty(); + assertThat(aggregatedSubplan.getTableScan()).isEmpty(); + assertThat(aggregatedSubplan.getChildSubplan()).contains(nonAggregatedSubplan); + } + + @Test + public void testNestedProjections() + { + List subplans = extractCanonicalSubplansForQuery(""" + SELECT regionkey + FROM (SELECT nationkey * 2 as nationkey_mul, regionkey FROM nation) + WHERE nationkey_mul + nationkey_mul > BIGINT '10' AND regionkey > BIGINT '10'"""); + assertThat(subplans).hasSize(2); + + Expression nationKeyMultiplyBy2 = new Call(MULTIPLY_BIGINT, ImmutableList.of(NATIONKEY_REF, new Constant(BIGINT, 2L))); + Expression regionKeyPredicate = new Comparison(GREATER_THAN, REGIONKEY_REF, new Constant(BIGINT, 10L)); + CacheTableId tableId = new CacheTableId(tpchCatalogId + ":tiny:nation:0.01"); + CanonicalSubplan nestedSubplan = subplans.get(0); + ScanFilterProjectKey scanFilterProjectKey = new ScanFilterProjectKey(tableId, ImmutableSet.of(regionKeyPredicate)); + assertThat(nestedSubplan.getKeyChain()).containsExactly(scanFilterProjectKey); + assertThat(nestedSubplan.getGroupByColumns()).isEmpty(); + assertThat(nestedSubplan.getConjuncts()).containsExactly(regionKeyPredicate); + assertThat(nestedSubplan.getPullableConjuncts()).containsExactly(regionKeyPredicate); + 
assertThat(nestedSubplan.getDynamicConjuncts()).isEmpty(); + assertThat(nestedSubplan.getTableScan()).isPresent(); + assertThat(nestedSubplan.getChildSubplan()).isEmpty(); + assertThat(nestedSubplan.getAssignments()).containsExactly( + entry(REGIONKEY_ID, CacheExpression.ofProjection(REGIONKEY_REF)), + entry(canonicalExpressionToColumnId(nationKeyMultiplyBy2), CacheExpression.ofProjection(nationKeyMultiplyBy2))); + assertThat(nestedSubplan.getTableScan().get().getTableId()).isEqualTo(tableId); + + Reference nationKeyMultiplyBy2Reference = columnIdToSymbol(canonicalExpressionToColumnId(nationKeyMultiplyBy2), BIGINT).toSymbolReference(); + Expression nationKeyPredicate = new Comparison(GREATER_THAN, new Call(ADD_BIGINT, ImmutableList.of(nationKeyMultiplyBy2Reference, nationKeyMultiplyBy2Reference)), new Constant(BIGINT, 10L)); + CanonicalSubplan topSubplan = subplans.get(1); + assertThat(topSubplan.getKeyChain()).containsExactly(scanFilterProjectKey, new FilterProjectKey(ImmutableSet.of())); + assertThat(topSubplan.getConjuncts()).containsExactly(nationKeyPredicate); + assertThat(topSubplan.getPullableConjuncts()).containsExactly(regionKeyPredicate, nationKeyPredicate); + assertThat(topSubplan.getDynamicConjuncts()).isEmpty(); + assertThat(topSubplan.getTableScan()).isEmpty(); + assertThat(topSubplan.getChildSubplan()).contains(nestedSubplan); + assertThat(topSubplan.getAssignments()).containsExactly( + entry(REGIONKEY_ID, CacheExpression.ofProjection(REGIONKEY_REF))); + } + + @Test + public void testUnsafeProjections() + { + // nationkey * 2 is unsafe expression + assertRequiredConjuncts( + "SELECT nationkey * 2 FROM nation WHERE regionkey > 10", + ScanFilterProjectKey.class, + new Comparison(GREATER_THAN, REGIONKEY_REF, new Constant(BIGINT, 10L))); + // safe expressions + assertRequiredConjuncts( + "SELECT nationkey, 42 FROM nation WHERE regionkey > 10", + ScanFilterProjectKey.class); + // nested projection; nationkey_mul is reference; "nationkey_mul * 
nationkey_mul > 10" is not pushed to table scan level + // therefore nationkey * nationkey (potentially unsafe) is evaluated for every input row + assertRequiredConjuncts( + "SELECT nationkey_mul FROM (SELECT nationkey * nationkey as nationkey_mul FROM nation) WHERE nationkey_mul * nationkey_mul > 10", + FilterProjectKey.class); + // nationkey * nationkey is unsafe expression + Symbol nationKeyMul = columnIdToSymbol(canonicalExpressionToColumnId(new Call(MULTIPLY_BIGINT, ImmutableList.of(NATIONKEY_REF, NATIONKEY_REF))), BIGINT); + Expression nationKeyMulMul = new Call(MULTIPLY_BIGINT, ImmutableList.of(nationKeyMul.toSymbolReference(), nationKeyMul.toSymbolReference())); + assertRequiredConjuncts( + "SELECT nationkey_mul * nationkey_mul FROM (SELECT nationkey * nationkey as nationkey_mul FROM nation) WHERE nationkey_mul * nationkey_mul > 10", + FilterProjectKey.class, + new Comparison(GREATER_THAN, nationKeyMulMul, new Constant(BIGINT, 10L))); + } + + private void assertRequiredConjuncts(@Language("SQL") String query, Class keyType, Expression... 
expectedConjuncts) + { + List subplans = extractCanonicalSubplansForQuery(query); + assertThat(subplans).isNotEmpty(); + CanonicalSubplan topLevelSubplan = getLast(subplans); + Key topLevelKey = topLevelSubplan.getKey(); + assertThat(topLevelKey).isInstanceOf(keyType); + switch (topLevelKey) { + case ScanFilterProjectKey scanFilterProjectKey -> { + assertThat(scanFilterProjectKey.requiredConjuncts()).containsExactly(expectedConjuncts); + } + case FilterProjectKey filterProjectKey -> { + assertThat(filterProjectKey.requiredConjuncts()).containsExactly(expectedConjuncts); + } + default -> fail("Unexpected key type: " + topLevelKey.getClass()); + } + } + + @Test + public void testBigintAggregation() + { + List subplans = extractCanonicalSubplansForQuery(""" + SELECT sum(nationkey) + FROM nation + GROUP BY regionkey"""); + assertThat(subplans).hasSize(2); + + CanonicalSubplan nonAggregatedSubplan = subplans.get(0); + assertThat(nonAggregatedSubplan.getGroupByColumns()).isEmpty(); + assertThat(nonAggregatedSubplan.getAssignments()).containsExactly( + entry(NATIONKEY_ID, CacheExpression.ofProjection(NATIONKEY_REF)), + entry(REGIONKEY_ID, CacheExpression.ofProjection(REGIONKEY_REF))); + + CanonicalAggregation sum = sumNationkey(); + CanonicalSubplan aggregatedSubplan = subplans.get(1); + assertThat(aggregatedSubplan.getOriginalPlanNode()).isInstanceOf(AggregationNode.class); + assertThat(getGroupByExpressions(aggregatedSubplan)).contains(ImmutableList.of(REGIONKEY_REF)); + assertThat(aggregatedSubplan.getOriginalSymbolMapping()).containsOnlyKeys( + NATIONKEY_ID, + REGIONKEY_ID, + canonicalAggregationToColumnId(sum)); + assertThat(aggregatedSubplan.getAssignments()).containsExactly( + entry(REGIONKEY_ID, CacheExpression.ofProjection(REGIONKEY_REF)), + entry(canonicalAggregationToColumnId(sum), CacheExpression.ofAggregation(sum))); + assertThat(aggregatedSubplan.getConjuncts()).isEmpty(); + assertThat(aggregatedSubplan.getDynamicConjuncts()).isEmpty(); + } + + @Test + 
public void testGlobalAggregation() + { + List subplans = extractCanonicalSubplansForQuery(""" + SELECT sum(nationkey) + FROM nation"""); + assertThat(subplans).hasSize(2); + + CanonicalSubplan nonAggregatedSubplan = subplans.get(0); + assertThat(nonAggregatedSubplan.getGroupByColumns()).isEmpty(); + + CanonicalAggregation sum = sumNationkey(); + CanonicalSubplan aggregatedSubplan = subplans.get(1); + assertThat(aggregatedSubplan.getOriginalPlanNode()).isInstanceOf(AggregationNode.class); + assertThat(aggregatedSubplan.getGroupByColumns()).contains(ImmutableSet.of()); + assertThat(aggregatedSubplan.getOriginalSymbolMapping()).containsOnlyKeys( + NATIONKEY_ID, + canonicalAggregationToColumnId(sum)); + assertThat(aggregatedSubplan.getAssignments()).containsExactly( + entry(canonicalAggregationToColumnId(sum), CacheExpression.ofAggregation(sum))); + assertThat(aggregatedSubplan.getConjuncts()).isEmpty(); + assertThat(aggregatedSubplan.getDynamicConjuncts()).isEmpty(); + } + + @Test + public void testNestedAggregations() + { + List subplans = extractCanonicalSubplansForQuery(""" + SELECT sum(sum_nationkey) + FROM (SELECT sum(nationkey) sum_nationkey, name + FROM nation + GROUP BY name, regionkey) + GROUP BY name || 'abc'"""); + assertThat(subplans).hasSize(2); + + CanonicalSubplan nonAggregatedSubplan = subplans.get(0); + assertThat(nonAggregatedSubplan.getGroupByColumns()).isEmpty(); + + CanonicalSubplan aggregatedSubplan = subplans.get(1); + assertThat(aggregatedSubplan.getOriginalPlanNode()).isInstanceOf(AggregationNode.class); + assertThat(getGroupByExpressions(aggregatedSubplan)).contains(ImmutableList.of(NAME_REF, REGIONKEY_REF)); + } + + @Test + public void testUnsupportedAggregations() + { + assertUnsupportedAggregation("SELECT array_agg(nationkey order by nationkey) FROM nation"); + assertUnsupportedAggregation("SELECT sum(nationkey) FROM nation GROUP BY ROLLUP (nationkey)"); + assertUnsupportedAggregation("SELECT sum(distinct nationkey), sum(distinct 
regionkey) FROM nation"); + } + + private void assertUnsupportedAggregation(@Language("SQL") String query) + { + List subplans = extractCanonicalSubplansForQuery(query); + assertThat(subplans).hasSize(1); + assertThat(getOnlyElement(subplans).getGroupByColumns()).isEmpty(); + } + + @Test + public void testTopNRankingRank() + { + List subplans = extractCanonicalSubplansForQuery("SELECT name, regionkey FROM nation ORDER BY regionkey FETCH FIRST 6 ROWS WITH TIES", false); + assertThat(subplans).hasSize(2); + CanonicalSubplan scanSubplan = subplans.get(0); + assertThat(scanSubplan.getAssignments()).containsExactly( + entry(NAME_ID, CacheExpression.ofProjection(NAME_REF)), + entry(REGIONKEY_ID, CacheExpression.ofProjection(REGIONKEY_REF))); + CanonicalSubplan topNSubplan = subplans.get(1); + assertThat(topNSubplan.getChildSubplan().get()).isEqualTo(scanSubplan); + assertThat(topNSubplan.getKey()).isInstanceOf(TopNRankingKey.class); + TopNRankingKey key = (TopNRankingKey) topNSubplan.getKey(); + assertThat(key.partitionBy()).isEqualTo(ImmutableList.of()); + assertThat(key.orderings()).containsExactly(entry(REGIONKEY_ID, SortOrder.ASC_NULLS_LAST)); + assertThat(key.rankingType()).isEqualTo(RANK); + assertThat(key.maxRankingPerPartition()).isEqualTo(6); + } + + @Test + public void testTopNRankingRowNumber() + { + List subplans = extractCanonicalSubplansForQuery(""" + SELECT * + FROM (SELECT nationkey, ROW_NUMBER () OVER (PARTITION BY name, nationkey ORDER BY regionkey DESC) update_rank FROM nation) AS t + WHERE t.update_rank = 1""", false); + assertThat(subplans).hasSize(2); + CanonicalSubplan scanSubplan = subplans.get(0); + assertThat(scanSubplan.getAssignments()).containsExactly( + entry(NATIONKEY_ID, CacheExpression.ofProjection(NATIONKEY_REF)), + entry(NAME_ID, CacheExpression.ofProjection(NAME_REF)), + entry(REGIONKEY_ID, CacheExpression.ofProjection(REGIONKEY_REF))); + CanonicalSubplan topNSubplan = subplans.get(1); + 
assertThat(topNSubplan.getChildSubplan().get()).isEqualTo(scanSubplan); + assertThat(topNSubplan.getKey()).isInstanceOf(TopNRankingKey.class); + TopNRankingKey key = (TopNRankingKey) topNSubplan.getKey(); + assertThat(key.partitionBy().stream().toList()).isEqualTo(ImmutableList.of(NAME_ID, NATIONKEY_ID)); + assertThat(key.orderings()).containsExactly(entry(REGIONKEY_ID, SortOrder.DESC_NULLS_LAST)); + assertThat(key.rankingType()).isEqualTo(ROW_NUMBER); + assertThat(key.maxRankingPerPartition()).isEqualTo(1); + } + + @Test + public void testProjectionWithLambdas() + { + List subplans = extractCanonicalSubplansForQuery("SELECT any_match(array[nationkey], x -> x > 5) FROM nation"); + assertThat(subplans).hasSize(1); + assertThat(subplans.get(0).getOriginalPlanNode()).isInstanceOf(TableScanNode.class); + } + + @Test + public void testFilterWithLambdas() + { + List subplans = extractCanonicalSubplansForQuery("SELECT nationkey FROM nation WHERE any_match(array[nationkey], x -> x > 5)"); + assertThat(subplans).hasSize(1); + assertThat(subplans.get(0).getOriginalPlanNode()).isInstanceOf(TableScanNode.class); + } + + @Test + public void testTopN() + { + List subplans = extractCanonicalSubplansForQuery("SELECT nationkey FROM nation ORDER BY name LIMIT 5"); + assertThat(subplans).hasSize(2); + CanonicalSubplan scanSubplan = subplans.get(0); + assertThat(scanSubplan.getAssignments()).containsExactly( + entry(NATIONKEY_ID, CacheExpression.ofProjection(NATIONKEY_REF)), + entry(NAME_ID, CacheExpression.ofProjection(NAME_REF))); + CanonicalSubplan topNSubplan = subplans.get(1); + assertThat(topNSubplan.getChildSubplan().get()).isEqualTo(scanSubplan); + assertThat(topNSubplan.getKey()).isInstanceOf(TopNKey.class); + } + + @Test + public void testTopNWithMultipleOrderByColumns() + { + List subplans = extractCanonicalSubplansForQuery("SELECT nationkey FROM nation ORDER BY regionkey, nationkey DESC offset 10 LIMIT 5"); + CanonicalSubplan scanSubplan = subplans.get(0); + 
assertThat(scanSubplan.getAssignments()).containsExactly( + entry(NATIONKEY_ID, CacheExpression.ofProjection(NATIONKEY_REF)), + entry(REGIONKEY_ID, CacheExpression.ofProjection(REGIONKEY_REF))); + CanonicalSubplan topNSubplan = subplans.get(1); + assertThat(topNSubplan.getKey()).isInstanceOf(TopNKey.class); + TopNKey key = (TopNKey) topNSubplan.getKey(); + assertThat(key.orderings()).containsExactly( + entry(REGIONKEY_ID, SortOrder.ASC_NULLS_LAST), + entry(NATIONKEY_ID, SortOrder.DESC_NULLS_LAST)); + assertThat(key.count()).isEqualTo(15); + } + + @Test + public void testTopNWithExpressionInOrderByColumn() + { + List subplans = extractCanonicalSubplansForQuery("SELECT nationkey FROM nation ORDER BY regionkey + 5 offset 10 LIMIT 5"); + CanonicalSubplan scanSubplan = subplans.get(0); + CacheColumnId regionKeyAdded5 = canonicalExpressionToColumnId(new Call(ADD_BIGINT, ImmutableList.of(REGIONKEY_REF, new Constant(BIGINT, 5L)))); + assertThat(scanSubplan.getAssignments()).containsExactly( + entry(NATIONKEY_ID, CacheExpression.ofProjection(NATIONKEY_REF)), + entry(regionKeyAdded5, CacheExpression.ofProjection(new Call(ADD_BIGINT, ImmutableList.of(REGIONKEY_REF, new Constant(BIGINT, 5L)))))); + CanonicalSubplan topNSubplan = subplans.get(1); + assertThat(topNSubplan.getKey()).isInstanceOf(TopNKey.class); + TopNKey key = (TopNKey) topNSubplan.getKey(); + assertThat(key.orderings()).containsExactly(entry(regionKeyAdded5, SortOrder.ASC_NULLS_LAST)); + assertThat(key.count()).isEqualTo(15); + } + + @Test + public void testNestedTopN() + { + List subplans = extractCanonicalSubplansForQuery("SELECT nationkey FROM (SELECT nationkey, name FROM nation ORDER BY nationkey limit 5) ORDER BY 1 limit 15"); + assertThat(subplans).hasSize(2); + CanonicalSubplan scanSubplan = subplans.get(0); + assertThat(scanSubplan.getAssignments()).containsExactly(entry(NATIONKEY_ID, CacheExpression.ofProjection(NATIONKEY_REF))); + CanonicalSubplan topNSubplan = subplans.get(1); + 
assertThat(topNSubplan.getKey()).isInstanceOf(TopNKey.class); + TopNKey key = (TopNKey) topNSubplan.getKey(); + assertThat(key.orderings()).containsExactly(entry(NATIONKEY_ID, SortOrder.ASC_NULLS_LAST)); + assertThat(key.count()).isEqualTo(5); + } + + @Test + public void testUnsupportedTopNWithGroupBy() + { + // unsupported final aggregation and exchanges between two topN + List subplans = extractCanonicalSubplansForQuery("SELECT max(nationkey) FROM nation GROUP BY name ORDER BY name LIMIT 1"); + assertThat(subplans).hasSize(2); + assertThat(subplans).noneMatch((subplan) -> subplan.getKey() instanceof TopNKey); + } + + @Test + public void testNondeterministicTopN() + { + List subplans = extractCanonicalSubplansForQuery("SELECT * FROM nation ORDER BY RANDOM() LIMIT 1"); + assertThat(subplans).hasSize(1); + assertThat(getOnlyElement(subplans).getKey()).isNotExactlyInstanceOf(TopNKey.class); + } + + @Test + public void testConjunctionOfNonDeterministicPredicateAndDynamicFilter() + { + List subplans = extractCanonicalSubplansForQuery(""" + SELECT n.comment + FROM + (SELECT * FROM nation WHERE random(regionkey) > 20) n + JOIN + (SELECT * FROM region WHERE random(regionkey) > 15) r + ON + n.regionkey = r.regionkey"""); + assertThat(subplans).hasSize(1); + CanonicalSubplan subplan = getOnlyElement(subplans); + assertThat(subplan.getDynamicConjuncts()).isEmpty(); + assertThat(subplan.getTableScan()).isPresent(); + TableScan scan = subplan.getTableScan().get(); + assertThat(scan.getTableId()).isEqualTo(new CacheTableId(tpchCatalogId + ":tiny:region:0.01")); + } + + private Optional> getGroupByExpressions(CanonicalSubplan subplan) + { + return subplan.getGroupByColumns() + .map(columns -> columns.stream() + .map(column -> requireNonNull(subplan.getAssignments().get(column), "No assignment for column: " + column)) + .map(cacheExpression -> cacheExpression.projection().orElseThrow()) + .collect(toImmutableList())); + } + + @Test + public void 
testExtractCanonicalScanAndProject() + { + ProjectNode projectNode = createScanAndProjectNode(); + List subplans = extractCanonicalSubplans( + PLANNER_CONTEXT, + TEST_SESSION, + projectNode); + assertThat(subplans).hasSize(1); + + CanonicalSubplan subplan = getOnlyElement(subplans); + assertThat(subplan.getOriginalPlanNode()).isEqualTo(projectNode); + assertThat(subplan.getOriginalSymbolMapping()).containsExactly( + entry(CACHE_COL1, new Symbol(BIGINT, "symbol1")), + entry(CACHE_COL2, new Symbol(BIGINT, "symbol2")), + entry(canonicalExpressionToColumnId(new Call(ADD_BIGINT, ImmutableList.of(CACHE_COL1_REF, new Constant(BIGINT, 1L)))), new Symbol(BIGINT, "projection1"))); + assertThat(subplan.getAssignments()).containsExactly( + entry(canonicalExpressionToColumnId(new Call(ADD_BIGINT, ImmutableList.of(CACHE_COL1_REF, new Constant(BIGINT, 1L)))), CacheExpression.ofProjection(new Call(ADD_BIGINT, ImmutableList.of(CACHE_COL1_REF, new Constant(BIGINT, 1L))))), + entry(CACHE_COL2, CacheExpression.ofProjection(CACHE_COL2_REF))); + + assertThat(subplan.getConjuncts()).isEmpty(); + assertThat(subplan.getDynamicConjuncts()).isEmpty(); + TableScan tableScan = subplan.getTableScan().orElseThrow(); + assertThat(tableScan.getColumnHandles()).containsExactly( + entry(CACHE_COL1, new TestingColumnHandle("column1")), + entry(CACHE_COL2, new TestingColumnHandle("column2"))); + assertThat(tableScan.getTableId()).isEqualTo(CATALOG_CACHE_TABLE_ID); + assertThat(tableScan.getTable()).isEqualTo(TEST_TABLE_HANDLE); + assertThat(subplan.getTableScanId()).isEqualTo(SCAN_NODE_ID); + } + + @Test + public void testExtractCanonicalFilterAndProject() + { + ProjectNode projectNode = createFilterAndProjectNode(); + List subplans = extractCanonicalSubplans( + PLANNER_CONTEXT, + TEST_SESSION, + projectNode); + assertThat(subplans).hasSize(1); + + CanonicalSubplan subplan = getOnlyElement(subplans); + assertThat(subplan.getOriginalPlanNode()).isEqualTo(projectNode); + 
assertThat(subplan.getOriginalSymbolMapping()).containsExactly( + entry(CACHE_COL1, new Symbol(BIGINT, "symbol1")), + entry(CACHE_COL2, new Symbol(BIGINT, "symbol2")), + entry(canonicalExpressionToColumnId(new Call(ADD_BIGINT, ImmutableList.of(CACHE_COL1_REF, new Constant(BIGINT, 1L)))), new Symbol(BIGINT, "projection1"))); + assertThat(subplan.getAssignments()).containsExactly( + entry(canonicalExpressionToColumnId(new Call(ADD_BIGINT, ImmutableList.of(CACHE_COL1_REF, new Constant(BIGINT, 1L)))), CacheExpression.ofProjection(new Call(ADD_BIGINT, ImmutableList.of(CACHE_COL1_REF, new Constant(BIGINT, 1L))))), + entry(CACHE_COL2, CacheExpression.ofProjection(CACHE_COL2_REF))); + + assertThat(subplan.getConjuncts()).hasSize(1); + Expression predicate = getOnlyElement(subplan.getConjuncts()); + assertThat(predicate).isEqualTo(new Comparison(GREATER_THAN, new Call(ADD_BIGINT, ImmutableList.of(CACHE_COL1_REF, CACHE_COL2_REF)), new Constant(BIGINT, 0L))); + + assertThat(subplan.getDynamicConjuncts()).hasSize(1); + Expression dynamicFilterExpression = getOnlyElement(subplan.getDynamicConjuncts()); + assertThat(DynamicFilters.getDescriptor(dynamicFilterExpression)).contains( + new DynamicFilters.Descriptor(new DynamicFilterId("dynamic_filter_id"), CACHE_COL1_REF)); + + TableScan tableScan = subplan.getTableScan().orElseThrow(); + assertThat(tableScan.getColumnHandles()).containsExactly( + entry(CACHE_COL1, new TestingColumnHandle("column1")), + entry(CACHE_COL2, new TestingColumnHandle("column2"))); + assertThat(tableScan.getTableId()).isEqualTo(CATALOG_CACHE_TABLE_ID); + assertThat(tableScan.getTable()).isEqualTo(TEST_TABLE_HANDLE); + assertThat(subplan.getTableScanId()).isEqualTo(SCAN_NODE_ID); + } + + @Test + public void testExtractCanonicalFilter() + { + FilterNode filterNode = createFilterNode(); + List subplans = extractCanonicalSubplans( + PLANNER_CONTEXT, + TEST_SESSION, + filterNode); + assertThat(subplans).hasSize(1); + + CanonicalSubplan subplan = 
getOnlyElement(subplans); + assertThat(subplan.getOriginalPlanNode()).isEqualTo(filterNode); + assertThat(subplan.getOriginalSymbolMapping()).containsExactly( + entry(CACHE_COL1, new Symbol(BIGINT, "symbol1")), + entry(CACHE_COL2, new Symbol(BIGINT, "symbol2"))); + assertThat(subplan.getAssignments()).containsExactly( + entry(CACHE_COL1, CacheExpression.ofProjection(CACHE_COL1_REF)), + entry(CACHE_COL2, CacheExpression.ofProjection(CACHE_COL2_REF))); + + assertThat(subplan.getConjuncts()).hasSize(1); + Expression predicate = getOnlyElement(subplan.getConjuncts()); + assertThat(predicate).isEqualTo(new Comparison(GREATER_THAN, new Call(ADD_BIGINT, ImmutableList.of(CACHE_COL1_REF, CACHE_COL2_REF)), new Constant(BIGINT, 0L))); + + assertThat(subplan.getDynamicConjuncts()).hasSize(1); + Expression dynamicFilterExpression = getOnlyElement(subplan.getDynamicConjuncts()); + assertThat(DynamicFilters.getDescriptor(dynamicFilterExpression)).contains( + new DynamicFilters.Descriptor(new DynamicFilterId("dynamic_filter_id"), CACHE_COL1_REF)); + + TableScan tableScan = subplan.getTableScan().orElseThrow(); + assertThat(tableScan.getColumnHandles()).containsExactly( + entry(CACHE_COL1, new TestingColumnHandle("column1")), + entry(CACHE_COL2, new TestingColumnHandle("column2"))); + assertThat(tableScan.getTableId()).isEqualTo(CATALOG_CACHE_TABLE_ID); + assertThat(tableScan.getTable()).isEqualTo(TEST_TABLE_HANDLE); + assertThat(subplan.getTableScanId()).isEqualTo(SCAN_NODE_ID); + } + + @Test + public void testExtractCanonicalTableScan() + { + // no cache id, therefore no canonical plan + TableScanNode tableScanNode = createTableScan(); + assertThat(extractCanonicalSubplans( + TestingPlannerContext.plannerContextBuilder() + .withMetadata(new MockMetadata()) + .withCacheMetadata(new TestCacheMetadata(Optional.empty(), handle -> Optional.of(new CacheColumnId(handle.getName())))) + .build(), + TEST_SESSION, + tableScanNode)) + .isEmpty(); + + // no column id, therefore no canonical 
plan + assertThat(extractCanonicalSubplans( + TestingPlannerContext.plannerContextBuilder() + .withMetadata(new MockMetadata()) + .withCacheMetadata(new TestCacheMetadata(Optional.of(CACHE_TABLE_ID), handle -> Optional.empty())) + .build(), + TEST_SESSION, + tableScanNode)) + .isEmpty(); + + List subplans = extractCanonicalSubplans( + PLANNER_CONTEXT, + TEST_SESSION, + tableScanNode); + assertThat(subplans).hasSize(1); + + CanonicalSubplan subplan = getOnlyElement(subplans); + assertThat(subplan.getOriginalPlanNode()).isEqualTo(tableScanNode); + assertThat(subplan.getOriginalSymbolMapping()).containsExactly( + entry(CACHE_COL1, new Symbol(BIGINT, "symbol1")), + entry(CACHE_COL2, new Symbol(BIGINT, "symbol2"))); + assertThat(subplan.getAssignments()).containsExactly( + entry(CACHE_COL1, CacheExpression.ofProjection(CACHE_COL1_REF)), + entry(CACHE_COL2, CacheExpression.ofProjection(CACHE_COL2_REF))); + assertThat(subplan.getConjuncts()).isEmpty(); + assertThat(subplan.getDynamicConjuncts()).isEmpty(); + + TableScan tableScan = subplan.getTableScan().orElseThrow(); + assertThat(tableScan.getColumnHandles()).containsExactly( + entry(CACHE_COL1, new TestingColumnHandle("column1")), + entry(CACHE_COL2, new TestingColumnHandle("column2"))); + assertThat(tableScan.getTableId()).isEqualTo(CATALOG_CACHE_TABLE_ID); + assertThat(tableScan.getTable()).isEqualTo(TEST_TABLE_HANDLE); + assertThat(subplan.getTableScanId()).isEqualTo(SCAN_NODE_ID); + } + + @Test + public void testProjectionWithDuplicatedExpressions() + { + assertThatCanonicalSubplanIsForTableScan(new ProjectNode( + new PlanNodeId("project_node"), + createTableScan(), + Assignments.of( + new Symbol(BIGINT, "alias1"), + new Call(MULTIPLY_BIGINT, ImmutableList.of(new Reference(BIGINT, "symbol1"), new Constant(BIGINT, 2L))), + new Symbol(BIGINT, "alias2"), + new Call(MULTIPLY_BIGINT, ImmutableList.of(new Reference(BIGINT, "symbol1"), new Constant(BIGINT, 2L)))))); + } + + @Test + public void testAliasingProjection() + { 
+ assertThatCanonicalSubplanIsForTableScan(new ProjectNode( + new PlanNodeId("project_node"), + createTableScan(), + Assignments.of( + new Symbol(BIGINT, "alias"), + new Reference(BIGINT, "symbol1")))); + assertThatCanonicalSubplanIsForTableScan(new ProjectNode( + new PlanNodeId("project_node"), + createTableScan(), + Assignments.of( + new Symbol(BIGINT, "symbol1"), + new Reference(BIGINT, "symbol1"), + new Symbol(BIGINT, "alias"), + new Reference(BIGINT, "symbol1")))); + assertThatCanonicalSubplanIsForTableScan(new ProjectNode( + new PlanNodeId("project_node"), + createTableScan(), + Assignments.of( + new Symbol(BIGINT, "alias"), + new Reference(BIGINT, "symbol1"), + new Symbol(BIGINT, "symbol1"), + new Reference(BIGINT, "symbol1")))); + } + + private void assertThatCanonicalSubplanIsForTableScan(PlanNode root) + { + List subplans = extractCanonicalSubplans(PLANNER_CONTEXT, TEST_SESSION, root); + assertThat(subplans).hasSize(1); + assertThat(getOnlyElement(subplans).getOriginalPlanNode()).isInstanceOf(TableScanNode.class); + } + + @Test + public void testTableScanWithDuplicatedColumnHandle() + { + Symbol symbol1 = new Symbol(BIGINT, "symbol1"); + Symbol symbol2 = new Symbol(BIGINT, "symbol2"); + TestingColumnHandle columnHandle = new TestingColumnHandle("column1"); + TableScanNode tableScanNode = new TableScanNode( + SCAN_NODE_ID, + TEST_TABLE_HANDLE, + ImmutableList.of(symbol1, symbol2), + ImmutableMap.of(symbol2, columnHandle, symbol1, columnHandle), + TupleDomain.all(), + Optional.empty(), + false, + Optional.of(false)); + assertThat(extractCanonicalSubplans(PLANNER_CONTEXT, TEST_SESSION, tableScanNode)).isEmpty(); + } + + @Test + public void testTableHandlesCanonization() + { + TableHandle tableHandle1 = TestingHandles.createTestTableHandle(SchemaTableName.schemaTableName("schema", "table1")); + TableHandle tableHandle2 = TestingHandles.createTestTableHandle(SchemaTableName.schemaTableName("schema", "table2")); + + PlanNode root = 
planBuilder.union(ImmutableListMultimap.of(), ImmutableList.of( + planBuilder.tableScan(tableHandle1, ImmutableList.of(), ImmutableMap.of(), Optional.of(false)), + planBuilder.tableScan(tableHandle2, ImmutableList.of(), ImmutableMap.of(), Optional.of(false)))); + + // TableHandles will be turned into common canonical version + TableHandle canonicalTableHandle = TestingHandles.createTestTableHandle(SchemaTableName.schemaTableName("schema", "common")); + List canonicalTableScans = extractCanonicalSubplans( + TestingPlannerContext.plannerContextBuilder() + .withMetadata(new MockMetadata()) + .withCacheMetadata(new TestCacheMetadata( + handle -> Optional.of(new CacheColumnId(handle.getName())), + (tableHandle) -> canonicalTableHandle, + (tableHandle) -> Optional.of(new CacheTableId(tableHandle.connectorHandle().toString())))) + .build(), + TEST_SESSION, + root).stream() + .map(subplan -> subplan.getTableScan().orElseThrow()) + .collect(toImmutableList()); + List tableIds = canonicalTableScans.stream() + .map(TableScan::getTableId) + .collect(toImmutableList()); + CacheTableId schemaCommonId = new CacheTableId(CATALOG_ID + ":schema.common"); + assertThat(tableIds).isEqualTo(ImmutableList.of(schemaCommonId, schemaCommonId)); + assertThat(canonicalTableScans).allMatch(scan -> scan.getTable().equals(canonicalTableHandle)); + + // TableHandles will not be turned into common canonical version + tableIds = extractCanonicalSubplans( + TestingPlannerContext.plannerContextBuilder() + .withMetadata(new MockMetadata()) + .withCacheMetadata(new TestCacheMetadata( + handle -> Optional.of(new CacheColumnId(handle.getName())), + (tableHandle) -> { + TestingMetadata.TestingTableHandle handle = (TestingMetadata.TestingTableHandle) tableHandle.connectorHandle(); + if (handle.getTableName().getTableName().equals("table1")) { + return TestingHandles.createTestTableHandle(SchemaTableName.schemaTableName("schema", "common1")); + } + else { + return 
TestingHandles.createTestTableHandle(SchemaTableName.schemaTableName("schema", "common2")); + } + }, + (tableHandle) -> Optional.of(new CacheTableId(tableHandle.connectorHandle().toString())))) + .build(), + TEST_SESSION, + root).stream() + .map(subplan -> subplan.getTableScan().orElseThrow().getTableId()) + .collect(toImmutableList()); + assertThat(tableIds).isEqualTo(ImmutableList.of( + new CacheTableId(CATALOG_ID + ":schema.common1"), + new CacheTableId(CATALOG_ID + ":schema.common2"))); + } + + private List extractCanonicalSubplansForQuery(@Language("SQL") String query) + { + return extractCanonicalSubplansForQuery(query, true); + } + + private List extractCanonicalSubplansForQuery(@Language("SQL") String query, boolean forceSingleNode) + { + Plan plan = plan(query, OPTIMIZED_AND_VALIDATED, forceSingleNode); + PlanTester planTester = getPlanTester(); + return planTester.inTransaction(session -> { + // metadata.getCatalogHandle() registers the catalog for the transaction + session.getCatalog().ifPresent(catalog -> planTester.getPlannerContext().getMetadata().getCatalogHandle(session, catalog)); + return extractCanonicalSubplans(getPlanTester().getPlannerContext(), session, plan.getRoot()); + }); + } + + private CanonicalAggregation sumNationkey() + { + return new CanonicalAggregation( + sumBigint(), + Optional.empty(), + List.of(NATIONKEY_REF)); + } + + private ResolvedFunction sumBigint() + { + return getPlanTester().getPlannerContext().getMetadata().resolveBuiltinFunction("sum", TypeSignatureProvider.fromTypes(BIGINT)); + } + + private ProjectNode createScanAndProjectNode() + { + return new ProjectNode( + new PlanNodeId("project_node"), + createTableScan(), + Assignments.of( + new Symbol(BIGINT, "projection1"), + new Call(ADD_BIGINT, ImmutableList.of(new Reference(BIGINT, "symbol1"), new Constant(BIGINT, 1L))), + new Symbol(BIGINT, "symbol2"), + new Reference(BIGINT, "symbol2"))); + } + + private ProjectNode createFilterAndProjectNode() + { + return new 
ProjectNode( + new PlanNodeId("project_node"), + createFilterNode(), + Assignments.of( + new Symbol(BIGINT, "projection1"), + new Call(ADD_BIGINT, ImmutableList.of(new Reference(BIGINT, "symbol1"), new Constant(BIGINT, 1L))), + new Symbol(BIGINT, "symbol2"), + new Reference(BIGINT, "symbol2"))); + } + + private FilterNode createFilterNode() + { + MetadataManager metadataManager = createTestMetadataManager(); + return new FilterNode( + new PlanNodeId("filter_node"), + createTableScan(), + and( + new Comparison(GREATER_THAN, new Call(ADD_BIGINT, ImmutableList.of(new Reference(BIGINT, "symbol1"), new Reference(BIGINT, "symbol2"))), new Constant(BIGINT, 0L)), + createDynamicFilterExpression( + metadataManager, + new DynamicFilterId("dynamic_filter_id"), + BIGINT, + new Reference(BIGINT, "symbol1")))); + } + + private TableScanNode createTableScan() + { + Symbol symbol1 = new Symbol(BIGINT, "symbol1"); + Symbol symbol2 = new Symbol(BIGINT, "symbol2"); + TestingColumnHandle handle1 = new TestingColumnHandle("column1"); + TestingColumnHandle handle2 = new TestingColumnHandle("column2"); + return new TableScanNode( + SCAN_NODE_ID, + TEST_TABLE_HANDLE, + ImmutableList.of(symbol1, symbol2), + ImmutableMap.of(symbol2, handle2, symbol1, handle1), + TupleDomain.all(), + Optional.empty(), + false, + Optional.of(false)); + } + + private static class TestCacheMetadata + extends CacheMetadata + { + private final Function> tableHandleCacheTableIdMapper; + + private final Function> cacheColumnIdMapper; + private final Function canonicalizeTableHande; + + private TestCacheMetadata() + { + this(handle -> Optional.of(new CacheColumnId("cache_" + handle.getName())), Functions.identity(), (any) -> Optional.of(CACHE_TABLE_ID)); + } + + private TestCacheMetadata( + Optional cacheTableId, + Function> cacheColumnIdMapper) + { + this(cacheColumnIdMapper, Function.identity(), (any) -> cacheTableId); + } + + private TestCacheMetadata( + Function> cacheColumnIdMapper, + Function 
canonicalizeTableHande, + Function> tableHandleCacheTableIdMapper) + { + super(catalogHandle -> Optional.empty()); + this.cacheColumnIdMapper = cacheColumnIdMapper; + this.canonicalizeTableHande = canonicalizeTableHande; + this.tableHandleCacheTableIdMapper = tableHandleCacheTableIdMapper; + } + + @Override + public Optional getCacheTableId(Session session, TableHandle tableHandle) + { + return tableHandleCacheTableIdMapper.apply(tableHandle); + } + + @Override + public Optional getCacheColumnId(Session session, TableHandle tableHandle, ColumnHandle columnHandle) + { + return cacheColumnIdMapper.apply((TestingColumnHandle) columnHandle); + } + + @Override + public TableHandle getCanonicalTableHandle(Session session, TableHandle tableHandle) + { + return canonicalizeTableHande.apply(tableHandle); + } + } + + protected static class MockMetadata + extends AbstractMockMetadata + { + @Override + public TableProperties getTableProperties(Session session, TableHandle handle) + { + return new TableProperties( + handle.catalogHandle(), + handle.transaction(), + new ConnectorTableProperties( + TupleDomain.all(), + Optional.empty(), + Optional.empty(), ImmutableList.of())); + } + } +} diff --git a/core/trino-main/src/test/java/io/trino/cache/TestCommonSubqueriesExtractor.java b/core/trino-main/src/test/java/io/trino/cache/TestCommonSubqueriesExtractor.java new file mode 100644 index 000000000000..74ed6b0c1ba5 --- /dev/null +++ b/core/trino-main/src/test/java/io/trino/cache/TestCommonSubqueriesExtractor.java @@ -0,0 +1,2395 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.cache; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableListMultimap; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Iterables; +import io.trino.Session; +import io.trino.cache.CommonPlanAdaptation.PlanSignatureWithPredicate; +import io.trino.connector.MockConnectorColumnHandle; +import io.trino.connector.MockConnectorFactory; +import io.trino.connector.MockConnectorTableHandle; +import io.trino.cost.StatsAndCosts; +import io.trino.execution.warnings.WarningCollector; +import io.trino.metadata.ResolvedFunction; +import io.trino.metadata.TableHandle; +import io.trino.metadata.TestingFunctionResolution; +import io.trino.plugin.tpch.TpchColumnHandle; +import io.trino.plugin.tpch.TpchConnectorFactory; +import io.trino.spi.block.LongArrayBlock; +import io.trino.spi.cache.CacheColumnId; +import io.trino.spi.cache.CacheTableId; +import io.trino.spi.cache.PlanSignature; +import io.trino.spi.connector.ColumnHandle; +import io.trino.spi.connector.ColumnMetadata; +import io.trino.spi.connector.ConnectorTableProperties; +import io.trino.spi.connector.ConstraintApplicationResult; +import io.trino.spi.connector.SchemaTableName; +import io.trino.spi.function.OperatorType; +import io.trino.spi.predicate.Domain; +import io.trino.spi.predicate.Range; +import io.trino.spi.predicate.SortedRangeSet; +import io.trino.spi.predicate.TupleDomain; +import io.trino.spi.predicate.ValueSet; +import io.trino.spi.type.RowType; +import io.trino.spi.type.Type; +import io.trino.sql.DynamicFilters.Descriptor; +import io.trino.sql.analyzer.TypeSignatureProvider; +import io.trino.sql.ir.Booleans; +import io.trino.sql.ir.Call; +import io.trino.sql.ir.Comparison; +import io.trino.sql.ir.Constant; +import io.trino.sql.ir.Expression; +import 
io.trino.sql.ir.Logical; +import io.trino.sql.ir.Reference; +import io.trino.sql.planner.Plan; +import io.trino.sql.planner.PlanNodeIdAllocator; +import io.trino.sql.planner.Symbol; +import io.trino.sql.planner.SymbolAllocator; +import io.trino.sql.planner.assertions.BasePlanTest; +import io.trino.sql.planner.assertions.PlanAssert; +import io.trino.sql.planner.assertions.PlanMatchPattern; +import io.trino.sql.planner.iterative.rule.test.PlanBuilder; +import io.trino.sql.planner.optimizations.PlanNodeSearcher; +import io.trino.sql.planner.plan.AggregationNode; +import io.trino.sql.planner.plan.AggregationNode.Aggregation; +import io.trino.sql.planner.plan.Assignments; +import io.trino.sql.planner.plan.DynamicFilterId; +import io.trino.sql.planner.plan.FilterNode; +import io.trino.sql.planner.plan.JoinNode; +import io.trino.sql.planner.plan.PlanNode; +import io.trino.sql.planner.plan.PlanNodeId; +import io.trino.sql.planner.plan.ProjectNode; +import io.trino.sql.planner.plan.TableScanNode; +import io.trino.sql.planner.plan.TopNNode; +import io.trino.sql.planner.plan.TopNRankingNode; +import io.trino.sql.planner.plan.UnionNode; +import io.trino.sql.tree.SortItem.NullOrdering; +import io.trino.sql.tree.SortItem.Ordering; +import io.trino.testing.PlanTester; +import io.trino.testing.TestingTransactionHandle; +import org.intellij.lang.annotations.Language; +import org.junit.jupiter.api.Test; + +import java.util.AbstractMap.SimpleEntry; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.function.BiFunction; +import java.util.function.Function; +import java.util.stream.Stream; + +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static io.trino.SystemSessionProperties.CACHE_AGGREGATIONS_ENABLED; +import static io.trino.SystemSessionProperties.CACHE_COMMON_SUBQUERIES_ENABLED; +import 
static io.trino.SystemSessionProperties.CACHE_PROJECTIONS_ENABLED; +import static io.trino.SystemSessionProperties.JOIN_REORDERING_STRATEGY; +import static io.trino.cache.CanonicalSubplanExtractor.canonicalAggregationToColumnId; +import static io.trino.cache.CanonicalSubplanExtractor.canonicalExpressionToColumnId; +import static io.trino.cache.CanonicalSubplanExtractor.columnIdToSymbol; +import static io.trino.cache.CommonSubqueriesExtractor.aggregationKey; +import static io.trino.cache.CommonSubqueriesExtractor.combine; +import static io.trino.cache.CommonSubqueriesExtractor.filterProjectKey; +import static io.trino.cache.CommonSubqueriesExtractor.scanFilterProjectKey; +import static io.trino.cache.CommonSubqueriesExtractor.topNKey; +import static io.trino.cache.CommonSubqueriesExtractor.topNRankingKey; +import static io.trino.cost.StatsCalculator.noopStatsCalculator; +import static io.trino.execution.querystats.PlanOptimizersStatsCollector.createPlanOptimizersStatsCollector; +import static io.trino.metadata.FunctionManager.createTestingFunctionManager; +import static io.trino.spi.block.BlockTestUtils.assertBlockEquals; +import static io.trino.spi.connector.SortOrder.ASC_NULLS_LAST; +import static io.trino.spi.connector.SortOrder.DESC_NULLS_LAST; +import static io.trino.spi.predicate.Range.greaterThan; +import static io.trino.spi.predicate.Range.lessThan; +import static io.trino.spi.type.BigintType.BIGINT; +import static io.trino.spi.type.BooleanType.BOOLEAN; +import static io.trino.spi.type.DoubleType.DOUBLE; +import static io.trino.spi.type.VarcharType.createVarcharType; +import static io.trino.sql.DynamicFilters.createDynamicFilterExpression; +import static io.trino.sql.DynamicFilters.extractDynamicFilters; +import static io.trino.sql.ir.Comparison.Operator.EQUAL; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; +import static io.trino.sql.ir.Comparison.Operator.LESS_THAN; +import static io.trino.sql.ir.ExpressionFormatter.formatExpression; 
+import static io.trino.sql.ir.IrUtils.and; +import static io.trino.sql.ir.IrUtils.extractDisjuncts; +import static io.trino.sql.ir.Logical.Operator.AND; +import static io.trino.sql.ir.Logical.Operator.OR; +import static io.trino.sql.planner.ExpressionExtractor.extractExpressions; +import static io.trino.sql.planner.LogicalPlanner.Stage.OPTIMIZED_AND_VALIDATED; +import static io.trino.sql.planner.SymbolsExtractor.extractOutputSymbols; +import static io.trino.sql.planner.SymbolsExtractor.extractUnique; +import static io.trino.sql.planner.assertions.PlanMatchPattern.aggregation; +import static io.trino.sql.planner.assertions.PlanMatchPattern.aggregationFunction; +import static io.trino.sql.planner.assertions.PlanMatchPattern.filter; +import static io.trino.sql.planner.assertions.PlanMatchPattern.globalAggregation; +import static io.trino.sql.planner.assertions.PlanMatchPattern.identityProject; +import static io.trino.sql.planner.assertions.PlanMatchPattern.project; +import static io.trino.sql.planner.assertions.PlanMatchPattern.singleGroupingSet; +import static io.trino.sql.planner.assertions.PlanMatchPattern.sort; +import static io.trino.sql.planner.assertions.PlanMatchPattern.strictProject; +import static io.trino.sql.planner.assertions.PlanMatchPattern.strictTableScan; +import static io.trino.sql.planner.assertions.PlanMatchPattern.symbol; +import static io.trino.sql.planner.assertions.PlanMatchPattern.tableScan; +import static io.trino.sql.planner.assertions.PlanMatchPattern.topN; +import static io.trino.sql.planner.assertions.PlanMatchPattern.topNRanking; +import static io.trino.sql.planner.plan.TopNRankingNode.RankingType.RANK; +import static io.trino.sql.planner.plan.TopNRankingNode.RankingType.ROW_NUMBER; +import static io.trino.testing.TestingHandles.TEST_CATALOG_NAME; +import static io.trino.testing.TestingSession.testSessionBuilder; +import static java.util.Collections.emptyList; +import static java.util.Map.entry; +import static 
org.assertj.core.api.Assertions.assertThat; + +public class TestCommonSubqueriesExtractor + extends BasePlanTest +{ + private static final TestingFunctionResolution FUNCTIONS = new TestingFunctionResolution(); + private static final ResolvedFunction MULTIPLY_BIGINT = FUNCTIONS.resolveOperator(OperatorType.MULTIPLY, ImmutableList.of(BIGINT, BIGINT)); + private static final ResolvedFunction ADD_BIGINT = FUNCTIONS.resolveOperator(OperatorType.ADD, ImmutableList.of(BIGINT, BIGINT)); + private static final ResolvedFunction MODULUS_BIGINT = FUNCTIONS.resolveOperator(OperatorType.MODULUS, ImmutableList.of(BIGINT, BIGINT)); + + private static final CacheTableId CACHE_TABLE_ID = new CacheTableId("cache_table_id"); + private static final CacheColumnId REGIONKEY_ID = new CacheColumnId("[regionkey:bigint]"); + private static final CacheColumnId NATIONKEY_ID = new CacheColumnId("[nationkey:bigint]"); + private static final CacheColumnId NAME_ID = new CacheColumnId("[name:varchar(25)]"); + private static final String TEST_SCHEMA = "test_schema"; + private static final String TEST_TABLE = "test_table"; + private static final Session TEST_SESSION = testSessionBuilder() + .setCatalog(TEST_CATALOG_NAME) + .setSchema(TEST_SCHEMA) + .setSystemProperty(CACHE_COMMON_SUBQUERIES_ENABLED, "true") + .build(); + private static final Session TPCH_SESSION = testSessionBuilder() + .setCatalog("tpch") + .setSchema("tiny") + // prevent CBO from interfering with tests + .setSystemProperty(JOIN_REORDERING_STRATEGY, "none") + .build(); + private static final MockConnectorColumnHandle HANDLE_1 = new MockConnectorColumnHandle("column1", BIGINT); + private static final MockConnectorColumnHandle HANDLE_2 = new MockConnectorColumnHandle("column2", BIGINT); + private static final TupleDomain CONSTRAINT_1 = TupleDomain.withColumnDomains(ImmutableMap.of( + HANDLE_1, + Domain.create(ValueSet.ofRanges( + Range.lessThan(BIGINT, 50L), + Range.greaterThan(BIGINT, 150L)), false))); + private static final 
TupleDomain CONSTRAINT_2 = TupleDomain.withColumnDomains(ImmutableMap.of( + HANDLE_1, + Domain.create(ValueSet.ofRanges( + Range.lessThan(BIGINT, 20L), + Range.greaterThan(BIGINT, 40L)), false))); + + private static final TupleDomain CONSTRAINT_3 = TupleDomain.withColumnDomains(ImmutableMap.of( + HANDLE_1, + Domain.create(ValueSet.ofRanges( + Range.lessThan(BIGINT, 30L), + Range.greaterThan(BIGINT, 70L)), false))); + private static final SchemaTableName TABLE_NAME = new SchemaTableName(TEST_SCHEMA, TEST_TABLE); + private static final Expression NATIONKEY_EXPRESSION = new Reference(BIGINT, "[nationkey:bigint]"); + + private TableHandle testTableHandle; + private String tpchCatalogId; + + @Override + protected PlanTester createPlanTester() + { + PlanTester planTester = PlanTester.create(TEST_SESSION); + planTester.createCatalog( + TEST_CATALOG_NAME, + MockConnectorFactory.builder() + .withGetColumns(handle -> ImmutableList.of( + new ColumnMetadata("column1", BIGINT), + new ColumnMetadata("column2", BIGINT))) + .withGetCacheTableId(handle -> Optional.of(CACHE_TABLE_ID)) + .withGetCanonicalTableHandle(Function.identity()) + .withGetCacheColumnId(handle -> { + MockConnectorColumnHandle column = (MockConnectorColumnHandle) handle; + return Optional.of(new CacheColumnId("cache_" + column.getName())); + }) + .withApplyFilter((session, tableHandle, constraint) -> { + // predicate is fully subsumed + if (constraint.getSummary().equals(CONSTRAINT_1)) { + return Optional.of(new ConstraintApplicationResult<>(new MockConnectorTableHandle(TABLE_NAME, CONSTRAINT_1, Optional.of(ImmutableList.of(HANDLE_1))), TupleDomain.all(), io.trino.spi.expression.Constant.TRUE, false)); + } + // predicate is rejected + else if (constraint.getSummary().equals(CONSTRAINT_2)) { + return Optional.of(new ConstraintApplicationResult<>(new MockConnectorTableHandle(TABLE_NAME, TupleDomain.all(), Optional.empty()), CONSTRAINT_2, io.trino.spi.expression.Constant.TRUE, false)); + } + // predicate is 
subsumed opportunistically + else if (constraint.getSummary().equals(CONSTRAINT_3)) { + return Optional.of(new ConstraintApplicationResult<>(new MockConnectorTableHandle(TABLE_NAME, CONSTRAINT_3, Optional.empty()), CONSTRAINT_3, io.trino.spi.expression.Constant.TRUE, false)); + } + return Optional.empty(); + }) + .withGetTableProperties((session, tableHandle) -> { + MockConnectorTableHandle handle = (MockConnectorTableHandle) tableHandle; + if (handle.getConstraint().equals(CONSTRAINT_2)) { + return new ConnectorTableProperties(TupleDomain.none(), Optional.empty(), Optional.empty(), emptyList()); + } + return new ConnectorTableProperties(handle.getConstraint(), Optional.empty(), Optional.empty(), emptyList()); + }) + .build(), + ImmutableMap.of()); + planTester.createCatalog(TPCH_SESSION.getCatalog().get(), + new TpchConnectorFactory(1), + ImmutableMap.of()); + testTableHandle = new TableHandle( + planTester.getCatalogHandle(TEST_CATALOG_NAME), + new MockConnectorTableHandle(TABLE_NAME), + TestingTransactionHandle.create()); + tpchCatalogId = planTester.getCatalogHandle(TPCH_SESSION.getCatalog().get()).getId(); + return planTester; + } + + @Test + public void testCommonDynamicFilters() + { + CommonSubqueries commonSubqueries = extractTpchCommonSubqueries(""" + SELECT nationkey FROM + ((SELECT nationkey, regionkey FROM nation n JOIN (SELECT * FROM (VALUES 0, 1) t(a)) t ON n.nationkey = t.a) + UNION ALL + (SELECT nationkey, regionkey FROM nation n JOIN (SELECT * FROM (VALUES 0, 1) t(a)) t ON n.regionkey = t.a)) l(nationkey, regionkey) + JOIN (SELECT * FROM (VALUES 0, 1, 2) t(a)) t ON l.nationkey = t.a"""); + + Map planAdaptations = commonSubqueries.planAdaptations(); + assertThat(planAdaptations).hasSize(2); + assertThat(planAdaptations).allSatisfy((node, adaptation) -> assertThat(node).isInstanceOf(FilterNode.class)); + + CommonPlanAdaptation projectionA = Iterables.get(planAdaptations.values(), 0); + CommonPlanAdaptation projectionB = 
Iterables.get(planAdaptations.values(), 1); + + // extract dynamic filter ids + List dynamicFilterIds = PlanNodeSearcher.searchFrom(commonSubqueries.plan()) + .whereIsInstanceOfAny(JoinNode.class) + .findAll().stream() + .map(JoinNode.class::cast) + .flatMap(join -> join.getDynamicFilters().keySet().stream()) + .collect(toImmutableList()); + DynamicFilterId topId = dynamicFilterIds.get(0); + DynamicFilterId leftId = dynamicFilterIds.get(1); + DynamicFilterId rightId = dynamicFilterIds.get(2); + + List symbols = commonSubqueries.planAdaptations.values().stream() + .map(subplan -> subplan.getCommonSubplanFilteredTableScan().tableScanNode()) + .flatMap(scan -> scan.getOutputSymbols().stream()) + .map(symbol -> formatExpression(symbol.toSymbolReference())) + .collect(toImmutableList()); + String leftNationkey = symbols.get(0); + String leftRegionkey = symbols.get(1); + String rightNationkey = symbols.get(2); + String rightRegionkey = symbols.get(3); + + // assert that common subplan have dynamic filter preserved in both FilterNode and FilteredTableScan + assertThat(extractExpressions(projectionA.getCommonSubplan()).stream() + .flatMap(expression -> extractDynamicFilters(expression).getDynamicConjuncts().stream()) + .collect(toImmutableList())) + .containsExactly( + new Descriptor(topId, new Reference(BIGINT, leftNationkey)), + new Descriptor(leftId, new Reference(BIGINT, leftNationkey))); + assertThat(projectionA.getCommonSubplanFilteredTableScan().filterPredicate().stream() + .flatMap(expression -> extractDynamicFilters(expression).getDynamicConjuncts().stream()) + .collect(toImmutableList())) + .containsExactly( + new Descriptor(topId, new Reference(BIGINT, leftNationkey)), + new Descriptor(leftId, new Reference(BIGINT, leftNationkey))); + + assertThat(extractExpressions(projectionB.getCommonSubplan()).stream() + .flatMap(expression -> extractDynamicFilters(expression).getDynamicConjuncts().stream()) + .collect(toImmutableList())) + .containsExactly( + new 
Descriptor(topId, new Reference(BIGINT, rightNationkey)), + new Descriptor(rightId, new Reference(BIGINT, rightRegionkey))); + assertThat(projectionB.getCommonSubplanFilteredTableScan().filterPredicate().stream() + .flatMap(expression -> extractDynamicFilters(expression).getDynamicConjuncts().stream()) + .collect(toImmutableList())) + .containsExactly( + new Descriptor(topId, new Reference(BIGINT, rightNationkey)), + new Descriptor(rightId, new Reference(BIGINT, rightRegionkey))); + + // assert that common dynamic filter is extracted for both subplans + assertThat(extractDisjuncts(projectionA.getCommonDynamicFilterDisjuncts()).stream() + .map(expression -> extractDynamicFilters(expression).getDynamicConjuncts())) + .containsExactly( + ImmutableList.of(new Descriptor(topId, new Reference(BIGINT, leftNationkey)), new Descriptor(leftId, new Reference(BIGINT, leftNationkey))), + ImmutableList.of(new Descriptor(topId, new Reference(BIGINT, leftNationkey)), new Descriptor(rightId, new Reference(BIGINT, leftRegionkey)))); + + assertThat(extractDisjuncts(projectionB.getCommonDynamicFilterDisjuncts()).stream() + .map(expression -> extractDynamicFilters(expression).getDynamicConjuncts())) + .containsExactly( + ImmutableList.of(new Descriptor(topId, new Reference(BIGINT, rightNationkey)), new Descriptor(leftId, new Reference(BIGINT, rightNationkey))), + ImmutableList.of(new Descriptor(topId, new Reference(BIGINT, rightNationkey)), new Descriptor(rightId, new Reference(BIGINT, rightRegionkey)))); + + // verify DF mappings for common dynamic filter + TpchColumnHandle nationkeyHandle = new TpchColumnHandle("nationkey", BIGINT); + TpchColumnHandle regionkeyHandle = new TpchColumnHandle("regionkey", BIGINT); + assertThat(projectionA.getCommonColumnHandles()) + .containsExactly(new SimpleEntry<>(NATIONKEY_ID, nationkeyHandle), new SimpleEntry<>(REGIONKEY_ID, regionkeyHandle)); + assertThat(projectionA.getCommonColumnHandles()).isEqualTo(projectionB.getCommonColumnHandles()); + } + 
+ @Test + public void testUnsafeProjections() + { + // safe projections are matched when predicate is different + assertAdaptationCount(2, """ + SELECT nationkey FROM nation WHERE regionkey > 20 + UNION ALL + SELECT regionkey FROM nation WHERE regionkey > 10 + """); + assertAdaptationCount(2, """ + SELECT nationkey FROM nation WHERE regionkey > 20 + UNION ALL + SELECT regionkey FROM nation + """); + // different unsafe projections are not matched when predicate is different + assertAdaptationCount(0, """ + SELECT nationkey * 2 FROM nation WHERE regionkey > 20 + UNION ALL + SELECT regionkey * 2 FROM nation WHERE regionkey > 10 + """); + assertAdaptationCount(0, """ + SELECT nationkey * 2 FROM nation WHERE regionkey > 20 + UNION ALL + SELECT regionkey * 2 FROM nation + """); + // same unsafe projections are not matched when predicate is different + assertAdaptationCount(0, """ + SELECT nationkey * 2 FROM nation WHERE regionkey > 20 + UNION ALL + SELECT nationkey * 2 FROM nation WHERE regionkey > 10 + """); + // common subquery for different unsafe projections with same predicate is extracted + assertAdaptationCount(2, """ + SELECT nationkey * 2 FROM nation WHERE regionkey > 20 + UNION ALL + SELECT nationkey * 2 FROM nation WHERE regionkey > 10 + UNION ALL + SELECT regionkey * 2 FROM nation WHERE regionkey > 10 + """); + // different unsafe projections with same predicate are matched + assertAdaptationCount(2, """ + SELECT nationkey * 2 FROM nation WHERE regionkey > 20 + UNION ALL + SELECT regionkey * 2 FROM nation WHERE regionkey > 20 + """); + // unsafe projections are not matched with safe projections even if predicates are same + assertAdaptationCount(0, """ + SELECT nationkey * 2 FROM nation WHERE regionkey > 20 + UNION ALL + SELECT regionkey FROM nation WHERE regionkey > 20 + """); + // common subquery for safe projections with different predicates is extracted + assertAdaptationCount(2, """ + SELECT nationkey * 2 FROM nation WHERE regionkey > 20 + UNION ALL + 
SELECT regionkey FROM nation WHERE regionkey > 20 + UNION ALL + SELECT nationkey FROM nation WHERE regionkey > 10 + """); + // unsafe projections are matched with safe projections if there is no predicate + assertAdaptationCount(2, """ + SELECT nationkey * 2 FROM nation + UNION ALL + SELECT regionkey FROM nation + """); + } + + private void assertAdaptationCount(int size, @Language("SQL") String query) + { + CommonSubqueries commonSubqueries = extractTpchCommonSubqueries(query); + assertThat(commonSubqueries.planAdaptations()).hasSize(size); + } + + @Test + public void testCacheWithExcludingEnforcedConstraints() + { + SymbolAllocator symbolAllocator = new SymbolAllocator(); + Symbol subqueryColumn = symbolAllocator.newSymbol("subquery_column", BIGINT); + + BiFunction, PlanNode> getPlan = (Long tableScanId, TupleDomain enforcedConstraint) -> { + PlanNode scanA = new TableScanNode( + new PlanNodeId(tableScanId.toString()), + new TableHandle( + getPlanTester().getCatalogHandle(TEST_CATALOG_NAME), + new MockConnectorTableHandle(TABLE_NAME, enforcedConstraint, Optional.empty()), + TestingTransactionHandle.create()), + ImmutableList.of(subqueryColumn), + ImmutableMap.of(subqueryColumn, HANDLE_1), + enforcedConstraint, + Optional.empty(), + false, + Optional.of(false)); + return new FilterNode( + new PlanNodeId("filter" + tableScanId), + scanA, + new Comparison(GREATER_THAN, new Reference(BIGINT, "subquery_column"), new Constant(BIGINT, tableScanId))); + }; + PlanNode filter1 = getPlan.apply(1L, TupleDomain.all()); + PlanNode filter2 = getPlan.apply(2L, TupleDomain.all()); + PlanNodeIdAllocator idAllocator = new PlanNodeIdAllocator(); + Map planAdaptations = extractCommonSubqueries( + idAllocator, + symbolAllocator, + new UnionNode( + new PlanNodeId("union"), + ImmutableList.of(filter1, filter2), + ImmutableListMultimap.of(), + ImmutableList.of())); + + assertThat(planAdaptations).hasSize(2); + assertThat(planAdaptations).containsKey(filter1); + 
assertThat(planAdaptations).containsKey(filter2); + + // with enforced constraint intersection + filter1 = getPlan.apply(1L, TupleDomain.withColumnDomains(ImmutableMap.of(HANDLE_1, Domain.multipleValues(BIGINT, ImmutableList.of(1L, 2L, 3L))))); + filter2 = getPlan.apply(2L, TupleDomain.withColumnDomains(ImmutableMap.of(HANDLE_1, Domain.multipleValues(BIGINT, ImmutableList.of(3L, 4L, 5L))))); + planAdaptations = extractCommonSubqueries( + idAllocator, + symbolAllocator, + new UnionNode( + new PlanNodeId("union"), + ImmutableList.of(filter1, filter2), + ImmutableListMultimap.of(), + ImmutableList.of())); + + assertThat(planAdaptations).hasSize(2); + assertThat(planAdaptations).containsKey(filter1); + assertThat(planAdaptations).containsKey(filter2); + + // no common subplans with excluding enforced constraints + planAdaptations = extractCommonSubqueries( + idAllocator, + symbolAllocator, + new UnionNode( + new PlanNodeId("union"), + ImmutableList.of( + getPlan.apply(1L, TupleDomain.withColumnDomains(ImmutableMap.of(HANDLE_1, Domain.singleValue(BIGINT, 1L)))), + getPlan.apply(2L, TupleDomain.withColumnDomains(ImmutableMap.of(HANDLE_1, Domain.singleValue(BIGINT, 2L))))), + ImmutableListMultimap.of(), + ImmutableList.of())); + assertThat(planAdaptations).hasSize(0); + + // 2 groups with intersecting enforced constraint + filter1 = getPlan.apply(1L, TupleDomain.withColumnDomains(ImmutableMap.of(HANDLE_1, Domain.multipleValues(BIGINT, ImmutableList.of(1L, 2L, 3L))))); + filter2 = getPlan.apply(2L, TupleDomain.withColumnDomains(ImmutableMap.of(HANDLE_1, Domain.multipleValues(BIGINT, ImmutableList.of(3L, 4L, 5L))))); + PlanNode filter3 = getPlan.apply(3L, TupleDomain.withColumnDomains(ImmutableMap.of(HANDLE_1, Domain.multipleValues(BIGINT, ImmutableList.of(6L, 7L, 8L))))); + PlanNode filter4 = getPlan.apply(4L, TupleDomain.withColumnDomains(ImmutableMap.of(HANDLE_1, Domain.multipleValues(BIGINT, ImmutableList.of(8L, 9L, 0L))))); + planAdaptations = extractCommonSubqueries( 
+ idAllocator, + symbolAllocator, + new UnionNode( + new PlanNodeId("union"), + ImmutableList.of(filter1, filter4, filter3, filter2), + ImmutableListMultimap.of(), + ImmutableList.of())); + assertThat(planAdaptations).hasSize(4); + assertThat(planAdaptations).containsKeys(filter1, filter2, filter3, filter4); + PlanMatchPattern commonSubplan = filter( + new Logical(OR, ImmutableList.of( + new Comparison(GREATER_THAN, new Reference(BIGINT, "subquery_column"), new Constant(BIGINT, 1L)), + new Comparison(GREATER_THAN, new Reference(BIGINT, "subquery_column"), new Constant(BIGINT, 2L)))), + tableScan(TABLE_NAME.getTableName(), ImmutableMap.of("subquery_column", "column1"))); + assertPlan(planAdaptations.get(filter1).getCommonSubplan(), commonSubplan); + assertPlan(planAdaptations.get(filter2).getCommonSubplan(), commonSubplan); + commonSubplan = filter( + new Logical(OR, ImmutableList.of( + new Comparison(GREATER_THAN, new Reference(BIGINT, "subquery_column"), new Constant(BIGINT, 4L)), + new Comparison(GREATER_THAN, new Reference(BIGINT, "subquery_column"), new Constant(BIGINT, 3L)))), + tableScan(TABLE_NAME.getTableName(), ImmutableMap.of("subquery_column", "column1"))); + assertPlan(planAdaptations.get(filter3).getCommonSubplan(), commonSubplan); + assertPlan(planAdaptations.get(filter4).getCommonSubplan(), commonSubplan); + } + + @Test + public void testCacheTopNRankingRank() + { + CommonSubqueries commonSubqueries = extractTpchCommonSubqueries(""" + SELECT name, regionkey FROM nation WHERE nationkey > 10 ORDER BY regionkey FETCH FIRST 6 ROWS WITH TIES + """, + false, true, false, false); + Map planAdaptations = commonSubqueries.planAdaptations(); + assertThat(planAdaptations).hasSize(1); + assertThat(planAdaptations).allSatisfy((node, adaptation) -> assertThat(node).isInstanceOf(TopNRankingNode.class)); + CommonPlanAdaptation topNRanking = planAdaptations.values().stream().findFirst().get(); + + PlanMatchPattern commonSubplan = topNRanking(pattern -> 
pattern.specification( + ImmutableList.of(), + ImmutableList.of("REGIONKEY"), + ImmutableMap.of("REGIONKEY", ASC_NULLS_LAST)) + .rankingType(RANK) + .maxRankingPerPartition(6) + .partial(true), + strictProject(ImmutableMap.of( + "NAME", PlanMatchPattern.expression(new Reference(createVarcharType(25), "NAME")), + "REGIONKEY", PlanMatchPattern.expression(new Reference(BIGINT, "REGIONKEY"))), + filter( + new Comparison(GREATER_THAN, new Reference(BIGINT, "NATIONKEY"), new Constant(BIGINT, 10L)), + tableScan("nation", ImmutableMap.of("NATIONKEY", "nationkey", "NAME", "name", "REGIONKEY", "regionkey"))))); + assertTpchPlan(topNRanking.getCommonSubplan(), commonSubplan); + + PlanNodeIdAllocator idAllocator = commonSubqueries.idAllocator(); + // validate no adaptation is required + assertThat(topNRanking.adaptCommonSubplan(topNRanking.getCommonSubplan(), idAllocator)).isEqualTo(topNRanking.getCommonSubplan()); + + List cacheColumnIds = ImmutableList.of(NAME_ID, REGIONKEY_ID); + List cacheColumnsTypes = ImmutableList.of(createVarcharType(25), BIGINT); + assertThat(topNRanking.getCommonSubplanSignature()).isEqualTo(new PlanSignatureWithPredicate( + new PlanSignature( + topNRankingKey( + scanFilterProjectKey(new CacheTableId(tpchCatalogId + ":tiny:nation:0.01")), + ImmutableList.of(), + ImmutableMap.of(REGIONKEY_ID, ASC_NULLS_LAST), + RANK, 6), + Optional.empty(), + cacheColumnIds, + cacheColumnsTypes), + TupleDomain.withColumnDomains(ImmutableMap.of( + NATIONKEY_ID, Domain.create(ValueSet.ofRanges(greaterThan(BIGINT, 10L)), false))))); + } + + @Test + public void testCacheTopNRankingRankWithPullableConjuncts() + { + CommonSubqueries commonSubqueries = extractTpchCommonSubqueries(""" + (SELECT name, regionkey FROM nation WHERE nationkey > 10 ORDER BY regionkey FETCH FIRST 6 ROWS WITH TIES) + UNION ALL + (SELECT name, regionkey FROM nation WHERE nationkey > 10 ORDER BY regionkey FETCH FIRST 6 ROWS WITH TIES) + """, + false, true, false, false); + Map planAdaptations = 
commonSubqueries.planAdaptations(); + assertThat(planAdaptations).hasSize(2); + assertThat(planAdaptations).allSatisfy((node, adaptation) -> assertThat(node).isInstanceOf(TopNRankingNode.class)); + CommonPlanAdaptation topNA = Iterables.get(planAdaptations.values(), 0); + CommonPlanAdaptation topNB = Iterables.get(planAdaptations.values(), 1); + + PlanMatchPattern commonSubplan = topNRanking(pattern -> pattern.specification( + ImmutableList.of(), + ImmutableList.of("REGIONKEY"), + ImmutableMap.of("REGIONKEY", ASC_NULLS_LAST)) + .rankingType(RANK) + .maxRankingPerPartition(6) + .partial(true), + strictProject(ImmutableMap.of( + "NAME", PlanMatchPattern.expression(new Reference(createVarcharType(25), "NAME")), + "REGIONKEY", PlanMatchPattern.expression(new Reference(BIGINT, "REGIONKEY"))), + filter( + new Comparison(GREATER_THAN, new Reference(BIGINT, "NATIONKEY"), new Constant(BIGINT, 10L)), + tableScan("nation", ImmutableMap.of("NATIONKEY", "nationkey", "NAME", "name", "REGIONKEY", "regionkey"))))); + assertTpchPlan(topNA.getCommonSubplan(), commonSubplan); + assertTpchPlan(topNB.getCommonSubplan(), commonSubplan); + + PlanNodeIdAllocator idAllocator = commonSubqueries.idAllocator(); + // validate no adaptation is required + assertThat(topNA.adaptCommonSubplan(topNA.getCommonSubplan(), idAllocator)).isEqualTo(topNA.getCommonSubplan()); + assertThat(topNB.adaptCommonSubplan(topNB.getCommonSubplan(), idAllocator)).isEqualTo(topNB.getCommonSubplan()); + + List cacheColumnIds = ImmutableList.of(NAME_ID, REGIONKEY_ID); + List cacheColumnsTypes = ImmutableList.of(createVarcharType(25), BIGINT); + assertThat(topNA.getCommonSubplanSignature()).isEqualTo(topNB.getCommonSubplanSignature()); + assertThat(topNA.getCommonSubplanSignature()).isEqualTo(new PlanSignatureWithPredicate( + new PlanSignature( + topNRankingKey( + scanFilterProjectKey(new CacheTableId(tpchCatalogId + ":tiny:nation:0.01")), + ImmutableList.of(), + ImmutableMap.of(REGIONKEY_ID, ASC_NULLS_LAST), + RANK, 
6), + Optional.empty(), + cacheColumnIds, + cacheColumnsTypes), + TupleDomain.withColumnDomains(ImmutableMap.of( + NATIONKEY_ID, Domain.create(ValueSet.ofRanges(greaterThan(BIGINT, 10L)), false))))); + } + + @Test + public void testCacheTopNRankingRankWithNonPullableConjuncts() + { + CommonSubqueries commonSubqueries = extractTpchCommonSubqueries(""" + (SELECT name, regionkey FROM nation WHERE nationkey > 10 ORDER BY regionkey FETCH FIRST 6 ROWS WITH TIES) + UNION ALL + (SELECT name, regionkey FROM nation WHERE nationkey > 11 ORDER BY regionkey FETCH FIRST 6 ROWS WITH TIES) + """, + false, true, false, false); + Map planAdaptations = commonSubqueries.planAdaptations(); + assertThat(planAdaptations).hasSize(2); + assertThat(planAdaptations).allSatisfy((node, adaptation) -> assertThat(node).isInstanceOf(TopNRankingNode.class)); + CommonPlanAdaptation topNA = Iterables.get(planAdaptations.values(), 0); + CommonPlanAdaptation topNB = Iterables.get(planAdaptations.values(), 1); + + PlanMatchPattern commonSubplanA = topNRanking(pattern -> pattern.specification( + ImmutableList.of(), + ImmutableList.of("REGIONKEY"), + ImmutableMap.of("REGIONKEY", ASC_NULLS_LAST)) + .rankingType(RANK) + .maxRankingPerPartition(6) + .partial(true), + strictProject(ImmutableMap.of( + "NAME", PlanMatchPattern.expression(new Reference(createVarcharType(25), "NAME")), + "REGIONKEY", PlanMatchPattern.expression(new Reference(BIGINT, "REGIONKEY"))), + filter( + new Comparison(GREATER_THAN, new Reference(BIGINT, "NATIONKEY"), new Constant(BIGINT, 10L)), + tableScan("nation", ImmutableMap.of("NATIONKEY", "nationkey", "NAME", "name", "REGIONKEY", "regionkey"))))); + assertTpchPlan(topNA.getCommonSubplan(), commonSubplanA); + PlanMatchPattern commonSubplanB = topNRanking(pattern -> pattern.specification( + ImmutableList.of(), + ImmutableList.of("REGIONKEY"), + ImmutableMap.of("REGIONKEY", ASC_NULLS_LAST)) + .rankingType(RANK) + .maxRankingPerPartition(6) + .partial(true), + 
strictProject(ImmutableMap.of( + "NAME", PlanMatchPattern.expression(new Reference(createVarcharType(25), "NAME")), + "REGIONKEY", PlanMatchPattern.expression(new Reference(BIGINT, "REGIONKEY"))), + filter( + new Comparison(GREATER_THAN, new Reference(BIGINT, "NATIONKEY"), new Constant(BIGINT, 11L)), + tableScan("nation", ImmutableMap.of("NATIONKEY", "nationkey", "NAME", "name", "REGIONKEY", "regionkey"))))); + assertTpchPlan(topNB.getCommonSubplan(), commonSubplanB); + + PlanNodeIdAllocator idAllocator = commonSubqueries.idAllocator(); + // validate no adaptation is required + assertThat(topNA.adaptCommonSubplan(topNA.getCommonSubplan(), idAllocator)).isEqualTo(topNA.getCommonSubplan()); + assertThat(topNB.adaptCommonSubplan(topNB.getCommonSubplan(), idAllocator)).isEqualTo(topNB.getCommonSubplan()); + + List cacheColumnIds = ImmutableList.of(NAME_ID, REGIONKEY_ID); + List cacheColumnsTypes = ImmutableList.of(createVarcharType(25), BIGINT); + assertThat(topNA.getCommonSubplanSignature()).isNotEqualTo(topNB.getCommonSubplanSignature()); + assertThat(topNA.getCommonSubplanSignature()).isEqualTo(new PlanSignatureWithPredicate( + new PlanSignature( + topNRankingKey( + scanFilterProjectKey(new CacheTableId(tpchCatalogId + ":tiny:nation:0.01")), + ImmutableList.of(), + ImmutableMap.of(REGIONKEY_ID, ASC_NULLS_LAST), + RANK, 6), + Optional.empty(), + cacheColumnIds, + cacheColumnsTypes), + TupleDomain.withColumnDomains(ImmutableMap.of( + NATIONKEY_ID, Domain.create(ValueSet.ofRanges(greaterThan(BIGINT, 10L)), false))))); + } + + @Test + public void testCacheTopNRankingRow() + { + CommonSubqueries commonSubqueries = extractTpchCommonSubqueries(""" + SELECT * + FROM (SELECT nationkey, ROW_NUMBER () OVER (PARTITION BY name, nationkey ORDER BY regionkey DESC) update_rank FROM nation) AS t + WHERE t.update_rank = 1""", + false, true, false, false); + Map planAdaptations = commonSubqueries.planAdaptations(); + assertThat(planAdaptations).hasSize(1); + 
assertThat(planAdaptations).allSatisfy((node, adaptation) -> assertThat(node).isInstanceOf(TopNRankingNode.class)); + CommonPlanAdaptation topNRanking = planAdaptations.values().stream().findFirst().get(); + PlanMatchPattern commonSubplan = topNRanking(pattern -> pattern.specification( + ImmutableList.of("NAME", "NATIONKEY"), + ImmutableList.of("REGIONKEY"), + ImmutableMap.of("REGIONKEY", DESC_NULLS_LAST)) + .rankingType(ROW_NUMBER) + .maxRankingPerPartition(1) + .partial(true), + tableScan("nation", ImmutableMap.of("NATIONKEY", "nationkey", "NAME", "name", "REGIONKEY", "regionkey"))); + assertTpchPlan(topNRanking.getCommonSubplan(), commonSubplan); + PlanNodeIdAllocator idAllocator = commonSubqueries.idAllocator(); + + // validate no adaptation is required + assertThat(topNRanking.adaptCommonSubplan(topNRanking.getCommonSubplan(), idAllocator)).isEqualTo(topNRanking.getCommonSubplan()); + + List cacheColumnIds = ImmutableList.of(NATIONKEY_ID, NAME_ID, REGIONKEY_ID); + List cacheColumnsTypes = ImmutableList.of(BIGINT, createVarcharType(25), BIGINT); + assertThat(topNRanking.getCommonSubplanSignature()).isEqualTo(new PlanSignatureWithPredicate( + new PlanSignature( + topNRankingKey( + scanFilterProjectKey(new CacheTableId(tpchCatalogId + ":tiny:nation:0.01")), + ImmutableList.of(NAME_ID, NATIONKEY_ID), + ImmutableMap.of(REGIONKEY_ID, DESC_NULLS_LAST), + ROW_NUMBER, 1), + Optional.empty(), + cacheColumnIds, + cacheColumnsTypes), + TupleDomain.all())); + } + + @Test + public void testCacheTopNRankingRowUnionWithSwappedPartitionBy() + { + CommonSubqueries commonSubqueries = extractTpchCommonSubqueries(""" + (SELECT * + FROM (SELECT nationkey, ROW_NUMBER () OVER (PARTITION BY nationkey, name ORDER BY regionkey DESC) update_rank + FROM nation WHERE regionkey < 10) AS t + WHERE t.update_rank = 1) + UNION ALL (SELECT * + FROM (SELECT nationkey, ROW_NUMBER () OVER (PARTITION BY name, nationkey ORDER BY regionkey DESC) update_rank + FROM nation WHERE regionkey < 10) AS t + 
WHERE t.update_rank = 1)""", + false, true, false, false); + Map planAdaptations = commonSubqueries.planAdaptations(); + assertThat(planAdaptations).hasSize(2); + assertThat(planAdaptations).allSatisfy((node, adaptation) -> assertThat(node).isInstanceOf(TopNRankingNode.class)); + + CommonPlanAdaptation topNA = Iterables.get(planAdaptations.values(), 0); + CommonPlanAdaptation topNB = Iterables.get(planAdaptations.values(), 1); + + PlanMatchPattern commonSubplan = topNRanking(pattern -> pattern.specification( + ImmutableList.of("NAME", "NATIONKEY"), + ImmutableList.of("REGIONKEY"), + ImmutableMap.of("REGIONKEY", DESC_NULLS_LAST)) + .rankingType(ROW_NUMBER) + .maxRankingPerPartition(1) + .partial(true), + filter( + new Comparison(LESS_THAN, new Reference(BIGINT, "REGIONKEY"), new Constant(BIGINT, 10L)), + tableScan("nation", ImmutableMap.of("NATIONKEY", "nationkey", "NAME", "name", "REGIONKEY", "regionkey")))); + assertTpchPlan(topNA.getCommonSubplan(), commonSubplan); + assertTpchPlan(topNB.getCommonSubplan(), commonSubplan); + + PlanNodeIdAllocator idAllocator = commonSubqueries.idAllocator(); + assertTpchPlan(topNA.adaptCommonSubplan(topNA.getCommonSubplan(), idAllocator), commonSubplan); + assertTpchPlan(topNB.adaptCommonSubplan(topNB.getCommonSubplan(), idAllocator), commonSubplan); + List cacheColumnIds = ImmutableList.of(NATIONKEY_ID, NAME_ID, REGIONKEY_ID); + List cacheColumnsTypes = ImmutableList.of(BIGINT, createVarcharType(25), BIGINT); + assertThat(topNA.getCommonSubplanSignature()).isEqualTo(topNB.getCommonSubplanSignature()); + assertThat(topNA.getCommonSubplanSignature()).isEqualTo(new PlanSignatureWithPredicate( + new PlanSignature( + topNRankingKey( + scanFilterProjectKey(new CacheTableId(tpchCatalogId + ":tiny:nation:0.01")), + ImmutableList.of(NAME_ID, NATIONKEY_ID), + ImmutableMap.of(REGIONKEY_ID, DESC_NULLS_LAST), + ROW_NUMBER, 1), + Optional.empty(), + cacheColumnIds, + cacheColumnsTypes), + TupleDomain.withColumnDomains(ImmutableMap.of( + 
REGIONKEY_ID, Domain.create(ValueSet.ofRanges(lessThan(BIGINT, 10L)), false))))); + } + + @Test + public void testTopNRankingRowWithNonPullableConjuncts() + { + @Language("SQL") String query = """ + (SELECT * + FROM (SELECT nationkey, ROW_NUMBER () OVER (PARTITION BY nationkey ORDER BY regionkey DESC) update_rank + FROM nation WHERE regionkey < 11) AS t + WHERE t.update_rank = 1) + UNION ALL (SELECT * + FROM (SELECT nationkey, ROW_NUMBER () OVER (PARTITION BY nationkey ORDER BY regionkey DESC) update_rank + FROM nation WHERE regionkey < 10) AS t + WHERE t.update_rank = 1)"""; + CommonSubqueries commonSubqueries = extractTpchCommonSubqueries(query, true, true, false, false); + Map planAdaptations = commonSubqueries.planAdaptations(); + // common subplans have higher priority than aggregations + assertThat(planAdaptations).noneSatisfy((node, adaptation) -> + assertThat(node).isInstanceOf(TopNRankingNode.class)); + commonSubqueries = extractTpchCommonSubqueries(query, false, true, false, false); + planAdaptations = commonSubqueries.planAdaptations(); + assertThat(planAdaptations).allSatisfy((node, adaptation) -> + assertThat(node).isInstanceOf(TopNRankingNode.class)); + CommonPlanAdaptation topNRankingA = Iterables.get(planAdaptations.values(), 0); + CommonPlanAdaptation topNRankingB = Iterables.get(planAdaptations.values(), 1); + + PlanMatchPattern commonSubplanA = topNRanking(pattern -> pattern.specification( + ImmutableList.of("NATIONKEY"), + ImmutableList.of("REGIONKEY"), + ImmutableMap.of("REGIONKEY", DESC_NULLS_LAST)) + .rankingType(ROW_NUMBER) + .maxRankingPerPartition(1) + .partial(true), + filter( + new Comparison(LESS_THAN, new Reference(BIGINT, "REGIONKEY"), new Constant(BIGINT, 10L)), + tableScan("nation", ImmutableMap.of("NATIONKEY", "nationkey", "REGIONKEY", "regionkey")))); + PlanMatchPattern commonSubplanB = topNRanking(pattern -> pattern.specification( + ImmutableList.of("NATIONKEY"), + ImmutableList.of("REGIONKEY"), + ImmutableMap.of("REGIONKEY", 
DESC_NULLS_LAST)) + .rankingType(ROW_NUMBER) + .maxRankingPerPartition(1) + .partial(true), + filter( + new Comparison(LESS_THAN, new Reference(BIGINT, "REGIONKEY"), new Constant(BIGINT, 11L)), + tableScan("nation", ImmutableMap.of("NATIONKEY", "nationkey", "REGIONKEY", "regionkey")))); + assertTpchPlan(topNRankingB.getCommonSubplan(), commonSubplanA); + assertTpchPlan(topNRankingA.getCommonSubplan(), commonSubplanB); + assertThat(topNRankingA.getCommonSubplanSignature()).isNotEqualTo(topNRankingB.getCommonSubplanSignature()); + } + + @Test + public void testCacheTopN() + { + CommonSubqueries commonSubqueries = extractTpchCommonSubqueries(""" + SELECT nationkey FROM nation + WHERE regionkey > 10 and nationkey > 2 + ORDER BY name ASC, regionkey DESC OFFSET 5 LIMIT 5""", + false, true, false); + Map planAdaptations = commonSubqueries.planAdaptations(); + assertThat(planAdaptations).hasSize(1); + assertThat(planAdaptations).allSatisfy((node, adaptation) -> assertThat(node).isInstanceOf(TopNNode.class)); + CommonPlanAdaptation topN = planAdaptations.values().stream().findFirst().get(); + PlanMatchPattern commonSubplan = topN( + 10, + ImmutableList.of(sort("NAME", Ordering.ASCENDING, NullOrdering.LAST), + sort("REGIONKEY", Ordering.DESCENDING, NullOrdering.LAST)), + TopNNode.Step.PARTIAL, + filter( + new Logical(AND, ImmutableList.of( + new Comparison(GREATER_THAN, new Reference(BIGINT, "REGIONKEY"), new Constant(BIGINT, 10L)), + new Comparison(GREATER_THAN, new Reference(BIGINT, "NATIONKEY"), new Constant(BIGINT, 2L)))), + tableScan("nation", ImmutableMap.of("NATIONKEY", "nationkey", "NAME", "name", "REGIONKEY", "regionkey")))); + assertTpchPlan(topN.getCommonSubplan(), commonSubplan); + + // validate no adaptation is required + PlanNodeIdAllocator idAllocator = commonSubqueries.idAllocator(); + assertThat(topN.adaptCommonSubplan(topN.getCommonSubplan(), idAllocator)).isEqualTo(topN.getCommonSubplan()); + + List cacheColumnIds = ImmutableList.of(NATIONKEY_ID, NAME_ID, 
REGIONKEY_ID); + List cacheColumnsTypes = ImmutableList.of(BIGINT, createVarcharType(25), BIGINT); + assertThat(topN.getCommonSubplanSignature()).isEqualTo(new PlanSignatureWithPredicate( + new PlanSignature( + topNKey( + scanFilterProjectKey(new CacheTableId(tpchCatalogId + ":tiny:nation:0.01")), + ImmutableMap.of(NAME_ID, ASC_NULLS_LAST, REGIONKEY_ID, DESC_NULLS_LAST), 10), + Optional.empty(), + cacheColumnIds, + cacheColumnsTypes), + TupleDomain.withColumnDomains(ImmutableMap.of( + REGIONKEY_ID, Domain.create(ValueSet.ofRanges(greaterThan(BIGINT, 10L)), false), + NATIONKEY_ID, Domain.create(ValueSet.ofRanges(greaterThan(BIGINT, 2L)), false))))); + } + + @Test + public void testTopNWithNonPullableConjuncts() + { + CommonSubqueries commonSubqueries = extractTpchCommonSubqueries(""" + (SELECT nationkey FROM nation WHERE regionkey > 1 ORDER BY name ASC OFFSET 5 LIMIT 5) + UNION ALL + (SELECT regionkey FROM nation WHERE regionkey > 2 ORDER BY name ASC OFFSET 5 LIMIT 5)""", + true, false, false); + Map planAdaptations = commonSubqueries.planAdaptations(); + assertThat(planAdaptations).noneSatisfy((node, adaptation) -> + assertThat(node).isInstanceOf(TopNNode.class)); + } + + @Test + public void testMultipleCacheTopN() + { + CommonSubqueries commonSubqueries = extractTpchCommonSubqueries(""" + (SELECT nationkey FROM nation WHERE regionkey < 10 ORDER BY name ASC OFFSET 5 LIMIT 5) + UNION ALL + (SELECT regionkey FROM nation WHERE regionkey < 10 ORDER BY name ASC OFFSET 5 LIMIT 5)""", + true, false, false); + Map planAdaptations = commonSubqueries.planAdaptations(); + assertThat(planAdaptations).hasSize(2); + assertThat(planAdaptations).allSatisfy((node, adaptation) -> + assertThat(node).isInstanceOf(TopNNode.class)); + + CommonPlanAdaptation topNA = Iterables.get(planAdaptations.values(), 0); + CommonPlanAdaptation topNB = Iterables.get(planAdaptations.values(), 1); + + PlanMatchPattern commonSubplan = topN( + 10, + ImmutableList.of(sort("NAME", Ordering.ASCENDING, 
NullOrdering.LAST)), + TopNNode.Step.PARTIAL, + filter( + new Comparison(LESS_THAN, new Reference(BIGINT, "REGIONKEY"), new Constant(BIGINT, 10L)), + tableScan("nation", ImmutableMap.of("NATIONKEY", "nationkey", "NAME", "name", "REGIONKEY", "regionkey")))); + assertTpchPlan(topNA.getCommonSubplan(), commonSubplan); + assertTpchPlan(topNB.getCommonSubplan(), commonSubplan); + + PlanNodeIdAllocator idAllocator = commonSubqueries.idAllocator(); + assertTpchPlan(topNA.adaptCommonSubplan(topNA.getCommonSubplan(), idAllocator), + strictProject(ImmutableMap.of( + "NATIONKEY", PlanMatchPattern.expression(new Reference(BIGINT, "NATIONKEY")), + "NAME", PlanMatchPattern.expression(new Reference(createVarcharType(25), "NAME"))), + commonSubplan)); + assertTpchPlan(topNB.adaptCommonSubplan(topNB.getCommonSubplan(), idAllocator), + strictProject(ImmutableMap.of( + "REGIONKEY", PlanMatchPattern.expression(new Reference(BIGINT, "REGIONKEY")), + "NAME", PlanMatchPattern.expression(new Reference(createVarcharType(25), "NAME"))), + commonSubplan)); + assertThat(topNA.getCommonSubplanSignature()).isEqualTo(topNB.getCommonSubplanSignature()); + List cacheColumnIds = ImmutableList.of(NATIONKEY_ID, NAME_ID, REGIONKEY_ID); + List cacheColumnsTypes = ImmutableList.of(BIGINT, createVarcharType(25), BIGINT); + assertThat(topNB.getCommonSubplanSignature()).isEqualTo(new PlanSignatureWithPredicate( + new PlanSignature( + topNKey( + scanFilterProjectKey(new CacheTableId(tpchCatalogId + ":tiny:nation:0.01")), + ImmutableMap.of(NAME_ID, ASC_NULLS_LAST), 10), + Optional.empty(), + cacheColumnIds, + cacheColumnsTypes), + TupleDomain.withColumnDomains(ImmutableMap.of( + REGIONKEY_ID, Domain.create(ValueSet.ofRanges(lessThan(BIGINT, 10L)), false))))); + } + + @Test + public void testCacheSingleAggregation() + { + CommonSubqueries commonSubqueries = extractTpchCommonSubqueries(""" + SELECT sum(nationkey) FROM nation + WHERE regionkey > 10 + GROUP BY name""", + true, true, true); + + Map 
planAdaptations = commonSubqueries.planAdaptations(); + assertThat(planAdaptations).hasSize(1); + assertThat(planAdaptations).allSatisfy((node, adaptation) -> assertThat(node).isInstanceOf(AggregationNode.class)); + + CommonPlanAdaptation aggregation = Iterables.get(planAdaptations.values(), 0); + PlanMatchPattern commonSubplan = aggregation( + singleGroupingSet("NAME"), + ImmutableMap.of( + Optional.of("SUM"), aggregationFunction("sum", false, ImmutableList.of(symbol("NATIONKEY")))), + Optional.empty(), + AggregationNode.Step.PARTIAL, + identityProject( + filter( + new Comparison(GREATER_THAN, new Reference(BIGINT, "REGIONKEY"), new Constant(BIGINT, 10L)), + tableScan("nation", ImmutableMap.of("NATIONKEY", "nationkey", "NAME", "name", "REGIONKEY", "regionkey"))))); + + // validate common subplan + assertTpchPlan(aggregation.getCommonSubplan(), commonSubplan); + + // validate no adaptation is required + PlanNodeIdAllocator idAllocator = commonSubqueries.idAllocator(); + assertThat(aggregation.adaptCommonSubplan(aggregation.getCommonSubplan(), idAllocator)).isEqualTo(aggregation.getCommonSubplan()); + + // validate signature + CanonicalAggregation sum = canonicalAggregation("sum", NATIONKEY_EXPRESSION); + List cacheColumnIds = ImmutableList.of(NAME_ID, canonicalAggregationToColumnId(sum)); + List cacheColumnsTypes = ImmutableList.of(createVarcharType(25), BIGINT); + assertThat(aggregation.getCommonSubplanSignature()).isEqualTo(new PlanSignatureWithPredicate( + new PlanSignature( + aggregationKey(scanFilterProjectKey(new CacheTableId(tpchCatalogId + ":tiny:nation:0.01"))), + Optional.of(ImmutableList.of(NAME_ID)), + cacheColumnIds, + cacheColumnsTypes), + TupleDomain.withColumnDomains(ImmutableMap.of( + REGIONKEY_ID, Domain.create(ValueSet.ofRanges(greaterThan(BIGINT, 10L)), false))))); + } + + @Test + public void testCacheSingleProjection() + { + CommonSubqueries commonSubqueries = extractTpchCommonSubqueries(""" + SELECT sum(nationkey) FROM nation + WHERE regionkey 
> 10 + GROUP BY name""", + true, false, true); + + Map planAdaptations = commonSubqueries.planAdaptations(); + assertThat(planAdaptations).hasSize(1); + assertThat(planAdaptations).allSatisfy((node, adaptation) -> assertThat(node).isInstanceOf(ProjectNode.class)); + + CommonPlanAdaptation projection = Iterables.get(planAdaptations.values(), 0); + PlanMatchPattern commonSubplan = + identityProject( + filter( + new Comparison(GREATER_THAN, new Reference(BIGINT, "REGIONKEY"), new Constant(BIGINT, 10L)), + tableScan("nation", ImmutableMap.of("NATIONKEY", "nationkey", "NAME", "name", "REGIONKEY", "regionkey")))); + + // validate common subplan + assertTpchPlan(projection.getCommonSubplan(), commonSubplan); + + // validate no adaptation is required + PlanNodeIdAllocator idAllocator = commonSubqueries.idAllocator(); + assertThat(projection.adaptCommonSubplan(projection.getCommonSubplan(), idAllocator)).isEqualTo(projection.getCommonSubplan()); + + // validate signature + List cacheColumnIds = ImmutableList.of(NATIONKEY_ID, NAME_ID); + List cacheColumnsTypes = ImmutableList.of(BIGINT, createVarcharType(25)); + assertThat(projection.getCommonSubplanSignature()).isEqualTo(new PlanSignatureWithPredicate( + new PlanSignature( + scanFilterProjectKey(new CacheTableId(tpchCatalogId + ":tiny:nation:0.01")), + Optional.empty(), + cacheColumnIds, + cacheColumnsTypes), + TupleDomain.withColumnDomains(ImmutableMap.of( + REGIONKEY_ID, Domain.create(ValueSet.ofRanges(greaterThan(BIGINT, 10L)), false))))); + } + + @Test + public void testSimpleAggregation() + { + CommonSubqueries commonSubqueries = extractTpchCommonSubqueries(""" + SELECT sum(nationkey) FROM nation + UNION ALL + SELECT sum(nationkey) FROM nation"""); + Map planAdaptations = commonSubqueries.planAdaptations(); + assertThat(planAdaptations).hasSize(2); + assertThat(planAdaptations).allSatisfy((node, adaptation) -> + assertThat(node).isInstanceOf(AggregationNode.class)); + + CommonPlanAdaptation aggregationA = 
Iterables.get(planAdaptations.values(), 0); + CommonPlanAdaptation aggregationB = Iterables.get(planAdaptations.values(), 1); + + PlanMatchPattern commonSubplan = aggregation( + globalAggregation(), + ImmutableMap.of(Optional.of("SUM"), aggregationFunction("sum", false, ImmutableList.of(symbol("NATIONKEY")))), + Optional.empty(), + AggregationNode.Step.PARTIAL, + tableScan("nation", ImmutableMap.of("NATIONKEY", "nationkey"))); + + assertTpchPlan(aggregationA.getCommonSubplan(), commonSubplan); + assertTpchPlan(aggregationB.getCommonSubplan(), commonSubplan); + + PlanNodeIdAllocator idAllocator = commonSubqueries.idAllocator(); + assertThat(aggregationA.adaptCommonSubplan(aggregationA.getCommonSubplan(), idAllocator)).isEqualTo(aggregationA.getCommonSubplan()); + assertThat(aggregationB.adaptCommonSubplan(aggregationB.getCommonSubplan(), idAllocator)).isEqualTo(aggregationB.getCommonSubplan()); + + // make sure plan signatures are same + CanonicalAggregation sum = canonicalAggregation("sum", NATIONKEY_EXPRESSION); + assertThat(aggregationA.getCommonSubplanSignature()).isEqualTo(aggregationB.getCommonSubplanSignature()); + List cacheColumnIds = ImmutableList.of(canonicalAggregationToColumnId(sum)); + List cacheColumnsTypes = ImmutableList.of(BIGINT); + assertThat(aggregationB.getCommonSubplanSignature()).isEqualTo(new PlanSignatureWithPredicate( + new PlanSignature( + aggregationKey(scanFilterProjectKey(new CacheTableId(tpchCatalogId + ":tiny:nation:0.01"))), + Optional.of(ImmutableList.of()), + cacheColumnIds, + cacheColumnsTypes), + TupleDomain.all())); + } + + @Test + public void testGlobalAggregation() + { + CommonSubqueries commonSubqueries = extractTpchCommonSubqueries(""" + SELECT * FROM + (SELECT sum(nationkey), max(regionkey) FILTER(WHERE nationkey > 10) FROM nation) + CROSS JOIN + (SELECT + sum(nationkey), + avg(nationkey * 2) FILTER(WHERE nationkey > 10) + FROM nation)"""); + Map planAdaptations = commonSubqueries.planAdaptations(); + 
assertThat(planAdaptations).hasSize(2); + assertThat(planAdaptations).allSatisfy((node, adaptation) -> + assertThat(node).isInstanceOf(AggregationNode.class)); + + CommonPlanAdaptation aggregationA = Iterables.get(planAdaptations.values(), 0); + CommonPlanAdaptation aggregationB = Iterables.get(planAdaptations.values(), 1); + + PlanMatchPattern commonSubplan = aggregation( + globalAggregation(), + ImmutableMap.of( + Optional.of("MAX_FILTERED"), aggregationFunction("max", false, ImmutableList.of(symbol("REGIONKEY"))), + Optional.of("SUM"), aggregationFunction("sum", false, ImmutableList.of(symbol("NATIONKEY"))), + Optional.of("AVG_FILTERED"), aggregationFunction("avg", false, ImmutableList.of(symbol("MULTIPLICATION")))), + ImmutableList.of(), + ImmutableList.of("MASK"), + Optional.empty(), + AggregationNode.Step.PARTIAL, + project(ImmutableMap.of( + "MULTIPLICATION", PlanMatchPattern.expression(new Call(MULTIPLY_BIGINT, ImmutableList.of(new Reference(BIGINT, "NATIONKEY"), new Constant(BIGINT, 2L)))), + "MASK", PlanMatchPattern.expression(new Comparison(GREATER_THAN, new Reference(BIGINT, "NATIONKEY"), new Constant(BIGINT, 10L)))), + tableScan("nation", ImmutableMap.of("NATIONKEY", "nationkey", "REGIONKEY", "regionkey")))); + + assertTpchPlan(aggregationA.getCommonSubplan(), commonSubplan); + assertTpchPlan(aggregationB.getCommonSubplan(), commonSubplan); + assertAggregationsWithMasks(aggregationA.getCommonSubplan(), 1, 2); + assertAggregationsWithMasks(aggregationB.getCommonSubplan(), 1, 2); + + PlanNodeIdAllocator idAllocator = commonSubqueries.idAllocator(); + assertTpchPlan(aggregationA.adaptCommonSubplan(aggregationA.getCommonSubplan(), idAllocator), + strictProject(ImmutableMap.of( + "SUM", PlanMatchPattern.expression(new Reference(BIGINT, "SUM")), + "MAX_FILTERED", PlanMatchPattern.expression(new Reference(BIGINT, "MAX_FILTERED"))), + commonSubplan)); + assertTpchPlan(aggregationB.adaptCommonSubplan(aggregationB.getCommonSubplan(), idAllocator), + 
strictProject(ImmutableMap.of( + "SUM", PlanMatchPattern.expression(new Reference(BIGINT, "SUM")), + "AVG_FILTERED", PlanMatchPattern.expression(new Reference(DOUBLE, "AVG_FILTERED"))), + commonSubplan)); + + // make sure plan signatures are same + CacheColumnId nationKeyGreaterThan10 = canonicalExpressionToColumnId(new Comparison(GREATER_THAN, new Reference(BIGINT, "[nationkey:bigint]"), new Constant(BIGINT, 10L))); + CacheColumnId nationKeyMultiplyBy2 = canonicalExpressionToColumnId(new Call(MULTIPLY_BIGINT, ImmutableList.of(new Reference(BIGINT, "[nationkey:bigint]"), new Constant(BIGINT, 2L)))); + CanonicalAggregation max = canonicalAggregation( + "max", + Optional.of(columnIdToSymbol(nationKeyGreaterThan10, BOOLEAN)), + new Reference(BIGINT, "[regionkey:bigint]")); + CanonicalAggregation sum = canonicalAggregation("sum", NATIONKEY_EXPRESSION); + CanonicalAggregation avg = canonicalAggregation( + "avg", + Optional.of(columnIdToSymbol(nationKeyGreaterThan10, BOOLEAN)), + columnIdToSymbol(nationKeyMultiplyBy2, BIGINT).toSymbolReference()); + assertThat(aggregationA.getCommonSubplanSignature()).isEqualTo(aggregationB.getCommonSubplanSignature()); + List cacheColumnIds = ImmutableList.of(canonicalAggregationToColumnId(sum), canonicalAggregationToColumnId(max), canonicalAggregationToColumnId(avg)); + List cacheColumnsTypes = ImmutableList.of( + BIGINT, + BIGINT, + RowType.from(List.of(RowType.field(DOUBLE), RowType.field(BIGINT)))); + //columnTypes=[bigint, row(bigint, bigint), row(double, bigint)], + assertThat(aggregationB.getCommonSubplanSignature()).isEqualTo(new PlanSignatureWithPredicate( + new PlanSignature( + aggregationKey(scanFilterProjectKey(new CacheTableId(tpchCatalogId + ":tiny:nation:0.01"))), + Optional.of(ImmutableList.of()), + cacheColumnIds, + cacheColumnsTypes), + TupleDomain.all())); + } + + @Test + public void testBigintGroupByColumnAggregation() + { + CommonSubqueries commonSubqueries = extractTpchCommonSubqueries(""" + SELECT sum(nationkey) 
FROM nation GROUP BY regionkey * 2 + UNION ALL + SELECT sum(nationkey) FROM nation GROUP BY regionkey * 2 + UNION ALL + SELECT nationkey FROM nation"""); + Map planAdaptations = commonSubqueries.planAdaptations(); + assertThat(planAdaptations).hasSize(2); + assertThat(planAdaptations).allSatisfy((node, adaptation) -> + assertThat(node).isInstanceOf(AggregationNode.class)); + + CommonPlanAdaptation aggregationA = Iterables.get(planAdaptations.values(), 0); + CommonPlanAdaptation aggregationB = Iterables.get(planAdaptations.values(), 1); + + PlanMatchPattern commonSubplan = aggregation( + singleGroupingSet("MULTIPLICATION"), + ImmutableMap.of(Optional.of("SUM"), aggregationFunction("sum", false, ImmutableList.of(symbol("NATIONKEY")))), + Optional.empty(), + AggregationNode.Step.PARTIAL, + project(ImmutableMap.of( + "MULTIPLICATION", PlanMatchPattern.expression(new Call(MULTIPLY_BIGINT, ImmutableList.of(new Reference(BIGINT, "REGIONKEY"), new Constant(BIGINT, 2L))))), + tableScan("nation", ImmutableMap.of("NATIONKEY", "nationkey", "REGIONKEY", "regionkey")))); + + assertTpchPlan(aggregationA.getCommonSubplan(), commonSubplan); + assertTpchPlan(aggregationB.getCommonSubplan(), commonSubplan); + + PlanNodeIdAllocator idAllocator = commonSubqueries.idAllocator(); + assertThat(aggregationA.adaptCommonSubplan(aggregationA.getCommonSubplan(), idAllocator)).isEqualTo(aggregationA.getCommonSubplan()); + assertThat(aggregationB.adaptCommonSubplan(aggregationB.getCommonSubplan(), idAllocator)).isEqualTo(aggregationB.getCommonSubplan()); + + // make sure plan signatures are same + CacheColumnId groupByColumn = canonicalExpressionToColumnId(new Call(MULTIPLY_BIGINT, ImmutableList.of(new Reference(BIGINT, "[regionkey:bigint]"), new Constant(BIGINT, 2L)))); + CanonicalAggregation sum = canonicalAggregation("sum", NATIONKEY_EXPRESSION); + assertThat(aggregationA.getCommonSubplanSignature()).isEqualTo(aggregationB.getCommonSubplanSignature()); + List cacheColumnIds = 
ImmutableList.of(groupByColumn, canonicalAggregationToColumnId(sum)); + List cacheColumnsTypes = ImmutableList.of(BIGINT, BIGINT); + assertThat(aggregationB.getCommonSubplanSignature()).isEqualTo(new PlanSignatureWithPredicate( + new PlanSignature( + aggregationKey(scanFilterProjectKey(new CacheTableId(tpchCatalogId + ":tiny:nation:0.01"))), + Optional.of(ImmutableList.of(groupByColumn)), + cacheColumnIds, + cacheColumnsTypes), + TupleDomain.all())); + } + + @Test + public void testMultiColumnGroupByAggregation() + { + CommonSubqueries commonSubqueries = extractTpchCommonSubqueries(""" + SELECT sum(nationkey) FROM nation + WHERE regionkey > 10 AND nationkey > 10 + GROUP BY regionkey, name + UNION ALL + SELECT max(nationkey) + FROM nation + WHERE regionkey < 5 AND nationkey > 10 + GROUP BY name, regionkey + UNION ALL + SELECT avg(nationkey) FROM nation + WHERE regionkey > 10 AND nationkey > 11 + GROUP BY regionkey, name"""); + Map planAdaptations = commonSubqueries.planAdaptations(); + // only aggregations with "nationkey > 10" predicate share common subqueries + assertThat(planAdaptations).hasSize(2); + assertThat(planAdaptations).allSatisfy((node, adaptation) -> + assertThat(node).isInstanceOf(AggregationNode.class)); + + CommonPlanAdaptation aggregationA = Iterables.get(planAdaptations.values(), 0); + CommonPlanAdaptation aggregationB = Iterables.get(planAdaptations.values(), 1); + + PlanMatchPattern commonSubplan = aggregation( + singleGroupingSet("REGIONKEY", "NAME"), + ImmutableMap.of( + Optional.of("SUM"), aggregationFunction("sum", false, ImmutableList.of(symbol("NATIONKEY"))), + Optional.of("MAX"), aggregationFunction("max", false, ImmutableList.of(symbol("NATIONKEY")))), + Optional.empty(), + AggregationNode.Step.PARTIAL, + filter( + new Logical(AND, ImmutableList.of( + new Comparison(GREATER_THAN, new Reference(BIGINT, "NATIONKEY"), new Constant(BIGINT, 10L)), + new Logical(OR, ImmutableList.of( + new Comparison(GREATER_THAN, new Reference(BIGINT, 
"REGIONKEY"), new Constant(BIGINT, 10L)), + new Comparison(LESS_THAN, new Reference(BIGINT, "REGIONKEY"), new Constant(BIGINT, 5L)))))), + tableScan("nation", ImmutableMap.of("NATIONKEY", "nationkey", "REGIONKEY", "regionkey", "NAME", "name")))); + + assertTpchPlan(aggregationA.getCommonSubplan(), commonSubplan); + assertTpchPlan(aggregationB.getCommonSubplan(), commonSubplan); + + PlanNodeIdAllocator idAllocator = commonSubqueries.idAllocator(); + assertTpchPlan(aggregationA.adaptCommonSubplan(aggregationA.getCommonSubplan(), idAllocator), + strictProject(ImmutableMap.of( + "REGIONKEY", PlanMatchPattern.expression(new Reference(BIGINT, "REGIONKEY")), + "NAME", PlanMatchPattern.expression(new Reference(createVarcharType(25), "NAME")), + "SUM", PlanMatchPattern.expression(new Reference(BIGINT, "SUM"))), + filter(new Comparison(GREATER_THAN, new Reference(BIGINT, "REGIONKEY"), new Constant(BIGINT, 10L)), commonSubplan))); + assertTpchPlan(aggregationB.adaptCommonSubplan(aggregationB.getCommonSubplan(), idAllocator), + strictProject(ImmutableMap.of( + "REGIONKEY", PlanMatchPattern.expression(new Reference(BIGINT, "REGIONKEY")), + "NAME", PlanMatchPattern.expression(new Reference(createVarcharType(25), "NAME")), + "MAX", PlanMatchPattern.expression(new Reference(BIGINT, "MAX"))), + filter(new Comparison(LESS_THAN, new Reference(BIGINT, "REGIONKEY"), new Constant(BIGINT, 5L)), commonSubplan))); + + // make sure plan signatures are same + CanonicalAggregation sum = canonicalAggregation("sum", NATIONKEY_EXPRESSION); + CanonicalAggregation max = canonicalAggregation("max", NATIONKEY_EXPRESSION); + assertThat(aggregationA.getCommonSubplanSignature()).isEqualTo(aggregationB.getCommonSubplanSignature()); + List cacheColumnIds = ImmutableList.of(REGIONKEY_ID, NAME_ID, canonicalAggregationToColumnId(sum), canonicalAggregationToColumnId(max)); + List cacheColumnsTypes = ImmutableList.of(BIGINT, createVarcharType(25), BIGINT, BIGINT); + 
assertThat(aggregationB.getCommonSubplanSignature()).isEqualTo(new PlanSignatureWithPredicate( + new PlanSignature( + aggregationKey(scanFilterProjectKey(new CacheTableId(tpchCatalogId + ":tiny:nation:0.01"))), + Optional.of(ImmutableList.of(NAME_ID, REGIONKEY_ID)), + cacheColumnIds, + cacheColumnsTypes), + TupleDomain.withColumnDomains(ImmutableMap.of( + REGIONKEY_ID, Domain.create(ValueSet.ofRanges(lessThan(BIGINT, 5L), greaterThan(BIGINT, 10L)), false), + NATIONKEY_ID, Domain.create(ValueSet.ofRanges(greaterThan(BIGINT, 10L)), false))))); + } + + @Test + public void testAggregationWithComplexAggregationExpression() + { + CommonSubqueries commonSubqueries = extractTpchCommonSubqueries(""" + SELECT sum(nationkey + 1) FROM nation GROUP BY name, regionkey + UNION ALL + SELECT sum(nationkey + 1) FROM nation GROUP BY regionkey, name"""); + + Map planAdaptations = commonSubqueries.planAdaptations(); + assertThat(planAdaptations).hasSize(2); + assertThat(planAdaptations).allSatisfy((node, adaptation) -> + assertThat(node).isInstanceOf(AggregationNode.class)); + + CommonPlanAdaptation aggregationA = Iterables.get(planAdaptations.values(), 0); + CommonPlanAdaptation aggregationB = Iterables.get(planAdaptations.values(), 1); + + PlanMatchPattern commonSubplan = aggregation( + singleGroupingSet("NAME", "REGIONKEY"), + ImmutableMap.of( + Optional.of("SUM"), aggregationFunction("sum", false, ImmutableList.of(symbol("EXPR")))), + Optional.empty(), + AggregationNode.Step.PARTIAL, + strictProject(ImmutableMap.of( + "NAME", PlanMatchPattern.expression(new Reference(createVarcharType(25), "NAME")), + "REGIONKEY", PlanMatchPattern.expression(new Reference(BIGINT, "REGIONKEY")), + "EXPR", PlanMatchPattern.expression(new Call(ADD_BIGINT, ImmutableList.of(new Reference(BIGINT, "NATIONKEY"), new Constant(BIGINT, 1L))))), + tableScan("nation", ImmutableMap.of("NATIONKEY", "nationkey", "NAME", "name", "REGIONKEY", "regionkey")))); + + assertTpchPlan(aggregationA.getCommonSubplan(), 
commonSubplan); + assertTpchPlan(aggregationB.getCommonSubplan(), commonSubplan); + + // only subplan B required adaptation (different order for group by columns) + PlanNodeIdAllocator idAllocator = commonSubqueries.idAllocator(); + assertThat(aggregationA.adaptCommonSubplan(aggregationA.getCommonSubplan(), idAllocator)).isEqualTo(aggregationA.getCommonSubplan()); + assertTpchPlan(aggregationB.adaptCommonSubplan(aggregationB.getCommonSubplan(), idAllocator), + strictProject(ImmutableMap.of( + "REGIONKEY", PlanMatchPattern.expression(new Reference(BIGINT, "REGIONKEY")), + "NAME", PlanMatchPattern.expression(new Reference(createVarcharType(25), "NAME")), + "SUM", PlanMatchPattern.expression(new Reference(BIGINT, "SUM"))), + commonSubplan)); + + // make sure plan signatures are same + CacheColumnId nationKeyPlusOne = canonicalExpressionToColumnId(new Call(ADD_BIGINT, ImmutableList.of(new Reference(BIGINT, "[nationkey:bigint]"), new Constant(BIGINT, 1L)))); + CanonicalAggregation sum = canonicalAggregation("sum", columnIdToSymbol(nationKeyPlusOne, BIGINT).toSymbolReference()); + assertThat(aggregationA.getCommonSubplanSignature()).isEqualTo(aggregationB.getCommonSubplanSignature()); + List cacheColumnIds = ImmutableList.of(NAME_ID, REGIONKEY_ID, canonicalAggregationToColumnId(sum)); + List cacheColumnsTypes = ImmutableList.of(createVarcharType(25), BIGINT, BIGINT); + assertThat(aggregationB.getCommonSubplanSignature()).isEqualTo(new PlanSignatureWithPredicate( + new PlanSignature( + aggregationKey(scanFilterProjectKey(new CacheTableId(tpchCatalogId + ":tiny:nation:0.01"))), + Optional.of(ImmutableList.of(NAME_ID, REGIONKEY_ID)), + cacheColumnIds, + cacheColumnsTypes), + TupleDomain.all())); + } + + @Test + public void testNestedFilters() + { + CommonSubqueries commonSubqueries = extractTpchCommonSubqueries(""" + SELECT nationkey_mul FROM (SELECT nationkey * 2 AS nationkey_mul FROM nation) WHERE nationkey_mul * nationkey_mul = nationkey_mul + UNION ALL + SELECT 
nationkey_add FROM (SELECT nationkey + 2 AS nationkey_add FROM nation) WHERE nationkey_add + nationkey_add = nationkey_add"""); + + Map planAdaptations = commonSubqueries.planAdaptations(); + assertThat(planAdaptations).hasSize(2); + assertThat(planAdaptations).allSatisfy((node, adaptation) -> + assertThat(node).isInstanceOf(FilterNode.class)); + + CommonPlanAdaptation filterA = Iterables.get(planAdaptations.values(), 0); + CommonPlanAdaptation filterB = Iterables.get(planAdaptations.values(), 1); + + PlanMatchPattern commonSubplan = filter( + new Logical(OR, ImmutableList.of( + new Comparison(EQUAL, new Call(MULTIPLY_BIGINT, ImmutableList.of(new Reference(BIGINT, "NATIONKEY_MUL"), new Reference(BIGINT, "NATIONKEY_MUL"))), new Reference(BIGINT, "NATIONKEY_MUL")), + new Comparison(EQUAL, new Call(ADD_BIGINT, ImmutableList.of(new Reference(BIGINT, "NATIONKEY_ADD"), new Reference(BIGINT, "NATIONKEY_ADD"))), new Reference(BIGINT, "NATIONKEY_ADD")))), + strictProject( + ImmutableMap.of( + "NATIONKEY_MUL", PlanMatchPattern.expression(new Call(MULTIPLY_BIGINT, ImmutableList.of(new Reference(BIGINT, "NATIONKEY"), new Constant(BIGINT, 2L)))), + "NATIONKEY_ADD", PlanMatchPattern.expression(new Call(ADD_BIGINT, ImmutableList.of(new Reference(BIGINT, "NATIONKEY"), new Constant(BIGINT, 2L))))), + tableScan("nation", ImmutableMap.of("NATIONKEY", "nationkey")))); + + assertTpchPlan(filterA.getCommonSubplan(), commonSubplan); + assertTpchPlan(filterB.getCommonSubplan(), commonSubplan); + + // validate adaptations + PlanNodeIdAllocator idAllocator = commonSubqueries.idAllocator(); + assertTpchPlan(filterA.adaptCommonSubplan(filterA.getCommonSubplan(), idAllocator), + strictProject(ImmutableMap.of("NATIONKEY_MUL", PlanMatchPattern.expression(new Reference(BIGINT, "NATIONKEY_MUL"))), + filter( + new Comparison(EQUAL, new Call(MULTIPLY_BIGINT, ImmutableList.of(new Reference(BIGINT, "NATIONKEY_MUL"), new Reference(BIGINT, "NATIONKEY_MUL"))), new Reference(BIGINT, "NATIONKEY_MUL")), + 
commonSubplan))); + assertTpchPlan(filterB.adaptCommonSubplan(filterB.getCommonSubplan(), idAllocator), + strictProject(ImmutableMap.of("NATIONKEY_ADD", PlanMatchPattern.expression(new Reference(BIGINT, "NATIONKEY_ADD"))), + filter( + new Comparison(EQUAL, new Call(ADD_BIGINT, ImmutableList.of(new Reference(BIGINT, "NATIONKEY_ADD"), new Reference(BIGINT, "NATIONKEY_ADD"))), new Reference(BIGINT, "NATIONKEY_ADD")), + commonSubplan))); + + // make sure plan signatures are same + Reference nationKeyMultiplyReference = columnIdToSymbol(canonicalExpressionToColumnId(new Call(MULTIPLY_BIGINT, ImmutableList.of(new Reference(BIGINT, "[nationkey:bigint]"), new Constant(BIGINT, 2L)))), BIGINT).toSymbolReference(); + Reference nationKeyAddReference = columnIdToSymbol(canonicalExpressionToColumnId(new Call(ADD_BIGINT, ImmutableList.of(new Reference(BIGINT, "[nationkey:bigint]"), new Constant(BIGINT, 2L)))), BIGINT).toSymbolReference(); + assertThat(filterA.getCommonSubplanSignature()).isEqualTo(filterB.getCommonSubplanSignature()); + assertThat(filterB.getCommonSubplanSignature()).isEqualTo(new PlanSignatureWithPredicate( + new PlanSignature( + combine(filterProjectKey(scanFilterProjectKey(new CacheTableId(tpchCatalogId + ":tiny:nation:0.01"))), + "filters=(($operator$multiply(\"($operator$multiply(\"\"[nationkey:bigint]\"\", bigint '2'))\", \"($operator$multiply(\"\"[nationkey:bigint]\"\", bigint '2'))\") = \"($operator$multiply(\"\"[nationkey:bigint]\"\", bigint '2'))\") " + + "OR ($operator$add(\"($operator$add(\"\"[nationkey:bigint]\"\", bigint '2'))\", \"($operator$add(\"\"[nationkey:bigint]\"\", bigint '2'))\") = \"($operator$add(\"\"[nationkey:bigint]\"\", bigint '2'))\"))"), + Optional.empty(), + ImmutableList.of(canonicalExpressionToColumnId(nationKeyMultiplyReference), canonicalExpressionToColumnId(nationKeyAddReference)), + ImmutableList.of(BIGINT, BIGINT)), + TupleDomain.all())); + } + + @Test + public void testNestedProjections() + { + CommonSubqueries 
commonSubqueries = extractTpchCommonSubqueries(""" + SELECT nationkey_mul * nationkey_mul FROM (SELECT nationkey * 2 AS nationkey_mul FROM nation) + UNION ALL + SELECT nationkey_add + nationkey_add FROM (SELECT nationkey + 2 AS nationkey_add FROM nation)"""); + + Map planAdaptations = commonSubqueries.planAdaptations(); + assertThat(planAdaptations).hasSize(2); + assertThat(planAdaptations).allSatisfy((node, adaptation) -> + assertThat(node).isInstanceOf(ProjectNode.class)); + + CommonPlanAdaptation projectionA = Iterables.get(planAdaptations.values(), 0); + CommonPlanAdaptation projectionB = Iterables.get(planAdaptations.values(), 1); + + PlanMatchPattern commonSubplan = strictProject( + ImmutableMap.of( + "MUL", PlanMatchPattern.expression(new Call(MULTIPLY_BIGINT, ImmutableList.of(new Reference(BIGINT, "NATIONKEY_MUL"), new Reference(BIGINT, "NATIONKEY_MUL")))), + "ADD", PlanMatchPattern.expression(new Call(ADD_BIGINT, ImmutableList.of(new Reference(BIGINT, "NATIONKEY_ADD"), new Reference(BIGINT, "NATIONKEY_ADD"))))), + strictProject( + ImmutableMap.of( + "NATIONKEY_MUL", PlanMatchPattern.expression(new Call(MULTIPLY_BIGINT, ImmutableList.of(new Reference(BIGINT, "NATIONKEY"), new Constant(BIGINT, 2L)))), + "NATIONKEY_ADD", PlanMatchPattern.expression(new Call(ADD_BIGINT, ImmutableList.of(new Reference(BIGINT, "NATIONKEY"), new Constant(BIGINT, 2L))))), + tableScan("nation", ImmutableMap.of("NATIONKEY", "nationkey")))); + + assertTpchPlan(projectionA.getCommonSubplan(), commonSubplan); + assertTpchPlan(projectionB.getCommonSubplan(), commonSubplan); + + // validate adaptations + PlanNodeIdAllocator idAllocator = commonSubqueries.idAllocator(); + assertTpchPlan(projectionA.adaptCommonSubplan(projectionA.getCommonSubplan(), idAllocator), + strictProject(ImmutableMap.of("MUL", PlanMatchPattern.expression(new Reference(BIGINT, "MUL"))), + commonSubplan)); + assertTpchPlan(projectionB.adaptCommonSubplan(projectionB.getCommonSubplan(), idAllocator), + 
strictProject(ImmutableMap.of("ADD", PlanMatchPattern.expression(new Reference(BIGINT, "ADD"))), + commonSubplan)); + + // make sure plan signatures are same + Reference nationKeyMultiplyReference = columnIdToSymbol(canonicalExpressionToColumnId(new Call(MULTIPLY_BIGINT, ImmutableList.of(new Reference(BIGINT, "[nationkey:bigint]"), new Constant(BIGINT, 2L)))), BIGINT).toSymbolReference(); + Reference nationKeyAddReference = columnIdToSymbol(canonicalExpressionToColumnId(new Call(ADD_BIGINT, ImmutableList.of(new Reference(BIGINT, "[nationkey:bigint]"), new Constant(BIGINT, 2L)))), BIGINT).toSymbolReference(); + Expression multiplyProjection = new Call(MULTIPLY_BIGINT, ImmutableList.of(nationKeyMultiplyReference, nationKeyMultiplyReference)); + Expression addProjection = new Call(ADD_BIGINT, ImmutableList.of(nationKeyAddReference, nationKeyAddReference)); + assertThat(projectionA.getCommonSubplanSignature()).isEqualTo(projectionB.getCommonSubplanSignature()); + assertThat(projectionB.getCommonSubplanSignature()).isEqualTo(new PlanSignatureWithPredicate( + new PlanSignature( + filterProjectKey(scanFilterProjectKey(new CacheTableId(tpchCatalogId + ":tiny:nation:0.01"))), + Optional.empty(), + ImmutableList.of(canonicalExpressionToColumnId(multiplyProjection), canonicalExpressionToColumnId(addProjection)), + ImmutableList.of(BIGINT, BIGINT)), + TupleDomain.all())); + } + + @Test + public void testCommonProjectionOnDifferentLevels() + { + // First subquery evaluates "nationkey * 2" in child projection, the second subquery evaluates + // "nationkey * 2" in parent projection. This leads to a conflict about how "nationkey * 2" should be + // evaluated in common subplan. By default, simpler projection is chosen. 
+ CommonSubqueries commonSubqueries = extractTpchCommonSubqueries(""" + SELECT nationkey_mul, nationkey_mul * nationkey_mul FROM (SELECT nationkey * 2 AS nationkey_mul FROM nation) + UNION ALL + SELECT nationkey * 2, nationkey_mul * nationkey_mul FROM (SELECT nationkey, nationkey * 4 AS nationkey_mul FROM nation)"""); + Map planAdaptations = commonSubqueries.planAdaptations(); + assertThat(planAdaptations).hasSize(2); + assertThat(planAdaptations).allSatisfy((node, adaptation) -> + assertThat(node).isInstanceOf(ProjectNode.class)); + + CommonPlanAdaptation projectionA = Iterables.get(planAdaptations.values(), 0); + CommonPlanAdaptation projectionB = Iterables.get(planAdaptations.values(), 1); + + PlanMatchPattern commonSubplan = strictProject( + ImmutableMap.of( + "MUL2", PlanMatchPattern.expression(new Reference(BIGINT, "MUL2")), + "MUL2_MUL2", PlanMatchPattern.expression(new Call(MULTIPLY_BIGINT, ImmutableList.of(new Reference(BIGINT, "MUL2"), new Reference(BIGINT, "MUL2")))), + "MUL4_MUL4", PlanMatchPattern.expression(new Call(MULTIPLY_BIGINT, ImmutableList.of(new Reference(BIGINT, "MUL4"), new Reference(BIGINT, "MUL4"))))), + strictProject( + ImmutableMap.of( + "MUL2", PlanMatchPattern.expression(new Call(MULTIPLY_BIGINT, ImmutableList.of(new Reference(BIGINT, "NATIONKEY"), new Constant(BIGINT, 2L)))), + "MUL4", PlanMatchPattern.expression(new Call(MULTIPLY_BIGINT, ImmutableList.of(new Reference(BIGINT, "NATIONKEY"), new Constant(BIGINT, 4L)))), + // NATIONKEY is actually unused by parent projection, but it's an artifact of common subplan extraction + "NATIONKEY", PlanMatchPattern.expression(new Reference(BIGINT, "NATIONKEY"))), + tableScan("nation", ImmutableMap.of("NATIONKEY", "nationkey")))); + + assertTpchPlan(projectionA.getCommonSubplan(), commonSubplan); + assertTpchPlan(projectionB.getCommonSubplan(), commonSubplan); + + // validate adaptations + PlanNodeIdAllocator idAllocator = commonSubqueries.idAllocator(); + 
assertTpchPlan(projectionA.adaptCommonSubplan(projectionA.getCommonSubplan(), idAllocator), + strictProject(ImmutableMap.of( + "MUL2", PlanMatchPattern.expression(new Reference(BIGINT, "MUL2")), + "MUL2_MUL2", PlanMatchPattern.expression(new Reference(BIGINT, "MUL2_MUL2"))), + commonSubplan)); + assertTpchPlan(projectionB.adaptCommonSubplan(projectionB.getCommonSubplan(), idAllocator), + project(ImmutableMap.of( + "MUL2_ORIGINAL", PlanMatchPattern.expression(new Reference(BIGINT, "MUL2")), + "MUL4_MUL4", PlanMatchPattern.expression(new Reference(BIGINT, "MUL4_MUL4"))), + commonSubplan)); + ProjectNode adaptationB = (ProjectNode) projectionB.adaptCommonSubplan(projectionB.getCommonSubplan(), idAllocator); + // output symbols are remapped to match original subplan + assertThat(adaptationB.isIdentity()).isFalse(); + assertThat(adaptationB.getAssignments().getExpressions()).allMatch(expression -> expression instanceof Reference); + + // make sure plan signatures are same + Reference mul2 = columnIdToSymbol(canonicalExpressionToColumnId(new Call(MULTIPLY_BIGINT, ImmutableList.of(new Reference(BIGINT, "[nationkey:bigint]"), new Constant(BIGINT, 2L)))), BIGINT).toSymbolReference(); + Reference mul4 = columnIdToSymbol(canonicalExpressionToColumnId(new Call(MULTIPLY_BIGINT, ImmutableList.of(new Reference(BIGINT, "[nationkey:bigint]"), new Constant(BIGINT, 4L)))), BIGINT).toSymbolReference(); + Expression mul2_mul2 = new Call(MULTIPLY_BIGINT, ImmutableList.of(mul2, mul2)); + Expression mul4_mul4 = new Call(MULTIPLY_BIGINT, ImmutableList.of(mul4, mul4)); + assertThat(projectionA.getCommonSubplanSignature()).isEqualTo(projectionB.getCommonSubplanSignature()); + assertThat(projectionB.getCommonSubplanSignature()).isEqualTo(new PlanSignatureWithPredicate( + new PlanSignature( + filterProjectKey(scanFilterProjectKey(new CacheTableId(tpchCatalogId + ":tiny:nation:0.01"))), + Optional.empty(), + ImmutableList.of(canonicalExpressionToColumnId(mul2), 
canonicalExpressionToColumnId(mul2_mul2), canonicalExpressionToColumnId(mul4_mul4)), + ImmutableList.of(BIGINT, BIGINT, BIGINT)), + TupleDomain.all())); + } + + @Test + public void testQueryWithAggregatedAndNonAggregatedSubqueries() + { + // data should be cached on table scan level + CommonSubqueries commonSubqueries = extractTpchCommonSubqueries(""" + SELECT sum(nationkey) FROM nation GROUP BY regionkey + UNION ALL + SELECT nationkey FROM nation"""); + + Map planAdaptations = commonSubqueries.planAdaptations(); + assertThat(planAdaptations).hasSize(2); + assertThat(planAdaptations).allSatisfy((node, adaptation) -> + assertThat(node).isInstanceOf(TableScanNode.class)); + + CommonPlanAdaptation aggregationA = Iterables.get(planAdaptations.values(), 0); + CommonPlanAdaptation aggregationB = Iterables.get(planAdaptations.values(), 1); + + PlanMatchPattern commonSubplan = tableScan("nation", ImmutableMap.of("NATIONKEY", "nationkey", "REGIONKEY", "regionkey")); + + assertTpchPlan(aggregationA.getCommonSubplan(), commonSubplan); + assertTpchPlan(aggregationB.getCommonSubplan(), commonSubplan); + + // make sure plan signatures are same + assertThat(aggregationA.getCommonSubplanSignature()).isEqualTo(aggregationB.getCommonSubplanSignature()); + List cacheColumnIds = ImmutableList.of(NATIONKEY_ID, REGIONKEY_ID); + List cacheColumnsTypes = ImmutableList.of(BIGINT, BIGINT); + assertThat(aggregationB.getCommonSubplanSignature()).isEqualTo(new PlanSignatureWithPredicate( + new PlanSignature( + scanFilterProjectKey(new CacheTableId(tpchCatalogId + ":tiny:nation:0.01")), + Optional.empty(), + cacheColumnIds, + cacheColumnsTypes), + TupleDomain.all())); + } + + @Test + public void testExtractCommonSubqueries() + { + SymbolAllocator symbolAllocator = new SymbolAllocator(); + Symbol subqueryAColumn1 = symbolAllocator.newSymbol("subquery_a_column1", BIGINT); + Symbol subqueryAColumn2 = symbolAllocator.newSymbol("subquery_a_column2", BIGINT); + Symbol subqueryAProjection1 = 
symbolAllocator.newSymbol("subquery_a_projection1", BIGINT); + // subquery A scans column1 and column2 + PlanNode scanA = new TableScanNode( + new PlanNodeId("scanA"), + testTableHandle, + ImmutableList.of(subqueryAColumn1, subqueryAColumn2), + ImmutableMap.of(subqueryAColumn1, HANDLE_1, subqueryAColumn2, HANDLE_2), + TupleDomain.all(), + Optional.empty(), + false, + Optional.of(false)); + // subquery A has complex predicate, but no DF + FilterNode filterA = new FilterNode( + new PlanNodeId("filterA"), + scanA, + new Logical(OR, ImmutableList.of( + new Comparison(EQUAL, new Call(MODULUS_BIGINT, ImmutableList.of(new Reference(BIGINT, "subquery_a_column1"), new Constant(BIGINT, 4L))), new Constant(BIGINT, 0L)), + new Comparison(EQUAL, new Call(MODULUS_BIGINT, ImmutableList.of(new Reference(BIGINT, "subquery_a_column2"), new Constant(BIGINT, 2L))), new Constant(BIGINT, 0L))))); + ProjectNode projectA = new ProjectNode( + new PlanNodeId("projectA"), + filterA, + Assignments.of( + subqueryAProjection1, new Constant(BIGINT, 10L), + subqueryAColumn1, new Reference(BIGINT, "subquery_a_column1"))); + + Symbol subqueryBColumn1 = symbolAllocator.newSymbol("subquery_b_column1", BIGINT); + Symbol subqueryBProjection1 = symbolAllocator.newSymbol("subquery_b_projection1", BIGINT); + // subquery B scans just column 1 + PlanNode scanB = new TableScanNode( + new PlanNodeId("scanB"), + testTableHandle, + ImmutableList.of(subqueryBColumn1), + ImmutableMap.of(subqueryBColumn1, HANDLE_1), + TupleDomain.all(), + Optional.empty(), + false, + Optional.of(false)); + // Subquery B predicate is subset of subquery A predicate. 
Subquery B has dynamic filter + FilterNode filterB = new FilterNode( + new PlanNodeId("filterB"), + scanB, + and( + new Comparison(EQUAL, new Call(MODULUS_BIGINT, ImmutableList.of(new Reference(BIGINT, "subquery_b_column1"), new Constant(BIGINT, 4L))), new Constant(BIGINT, 0L)), + createDynamicFilterExpression( + getPlanTester().getPlannerContext().getMetadata(), + new DynamicFilterId("subquery_b_dynamic_id"), + BIGINT, + new Reference(BIGINT, "subquery_b_column1")))); + // Subquery B projection is subset of subquery A projection + ProjectNode projectB = new ProjectNode( + new PlanNodeId("projectB"), + filterB, + Assignments.of( + subqueryBProjection1, new Constant(BIGINT, 10L))); + + PlanNodeIdAllocator idAllocator = new PlanNodeIdAllocator(); + Map planAdaptations = extractCommonSubqueries( + idAllocator, + symbolAllocator, + new UnionNode( + new PlanNodeId("union"), + ImmutableList.of(projectA, projectB), + ImmutableListMultimap.of(), + ImmutableList.of())); + + // there should be a common subquery found for both subplans + assertThat(planAdaptations).hasSize(2); + assertThat(planAdaptations).containsKey(projectA); + assertThat(planAdaptations).containsKey(projectB); + + CommonPlanAdaptation subqueryA = planAdaptations.get(projectA); + CommonPlanAdaptation subqueryB = planAdaptations.get(projectB); + + // common subplan should be identical for both subqueries + PlanMatchPattern commonSubplanTableScan = strictTableScan( + TEST_TABLE, + ImmutableMap.of( + "column1", "column1", + "column2", "column2")); + PlanMatchPattern commonSubplan = strictProject( + ImmutableMap.of( + "column1", PlanMatchPattern.expression(new Reference(BIGINT, "column1")), + "projection", PlanMatchPattern.expression(new Constant(BIGINT, 10L))), + filter( + new Logical(OR, ImmutableList.of( + new Comparison(EQUAL, new Call(MODULUS_BIGINT, ImmutableList.of(new Reference(BIGINT, "column1"), new Constant(BIGINT, 4L))), new Constant(BIGINT, 0L)), + new Comparison(EQUAL, new Call(MODULUS_BIGINT, 
ImmutableList.of(new Reference(BIGINT, "column2"), new Constant(BIGINT, 2L))), new Constant(BIGINT, 0L)))), + commonSubplanTableScan)); + assertPlan(subqueryA.getCommonSubplan(), commonSubplan); + assertPlan(subqueryB.getCommonSubplan(), commonSubplan); + + // assert that FilteredTableScan has correct table and predicate for both subplans + assertPlan(subqueryA.getCommonSubplanFilteredTableScan().tableScanNode(), commonSubplanTableScan); + assertPlan(subqueryB.getCommonSubplanFilteredTableScan().tableScanNode(), commonSubplanTableScan); + assertThat(subqueryA.getCommonSubplanFilteredTableScan().filterPredicate()).hasValue( + ((FilterNode) PlanNodeSearcher.searchFrom(subqueryA.getCommonSubplan()) + .whereIsInstanceOfAny(FilterNode.class) + .findOnlyElement()) + .getPredicate()); + assertThat(subqueryB.getCommonSubplanFilteredTableScan().filterPredicate()).hasValue( + ((FilterNode) PlanNodeSearcher.searchFrom(subqueryB.getCommonSubplan()) + .whereIsInstanceOfAny(FilterNode.class) + .findOnlyElement()) + .getPredicate()); + + // assert that useConnectorNodePartitioning is propagated correctly + assertThat(((TableScanNode) PlanNodeSearcher.searchFrom(subqueryA.getCommonSubplan()) + .whereIsInstanceOfAny(TableScanNode.class) + .findOnlyElement()) + .isUseConnectorNodePartitioning()) + .isFalse(); + assertThat(((TableScanNode) PlanNodeSearcher.searchFrom(subqueryB.getCommonSubplan()) + .whereIsInstanceOfAny(TableScanNode.class) + .findOnlyElement()) + .isUseConnectorNodePartitioning()) + .isFalse(); + + // assert that common subplan for subquery A doesn't have dynamic filter + assertThat(extractExpressions(subqueryA.getCommonSubplan()).stream() + .flatMap(expression -> extractDynamicFilters(expression).getDynamicConjuncts().stream())) + .isEmpty(); + + CacheColumnId column1 = new CacheColumnId("[cache_column1]"); + CacheColumnId column2 = new CacheColumnId("[cache_column2]"); + assertThat(subqueryA.getCommonColumnHandles()).containsExactly( + entry(column1, HANDLE_1), + 
entry(column2, HANDLE_2)); + + // assert that common subplan for subquery B has dynamic filter preserved + assertThat(extractExpressions(subqueryB.getCommonSubplan()).stream() + .flatMap(expression -> extractDynamicFilters(expression).getDynamicConjuncts().stream()) + .collect(toImmutableList())) + .containsExactly(new Descriptor( + new DynamicFilterId("subquery_b_dynamic_id"), + new Reference(BIGINT, "subquery_b_column1"))); + assertThat(subqueryB.getCommonColumnHandles()).isEqualTo(subqueryA.getCommonColumnHandles()); + + // common DF is true since subqueryA doesn't have DF + assertThat(subqueryA.getCommonDynamicFilterDisjuncts()).isEqualTo(Booleans.TRUE); + assertThat(subqueryB.getCommonDynamicFilterDisjuncts()).isEqualTo(Booleans.TRUE); + + // symbols used in common subplans for both subqueries should be unique + assertThat(extractUnique(subqueryA.getCommonSubplan())) + .doesNotContainAnyElementsOf(extractUnique(subqueryB.getCommonSubplan())); + + // since subqueryA has the same predicate and projections as common subquery, then no adaptation is required + PlanNode subqueryACommonSubplan = subqueryA.getCommonSubplan(); + assertThat(subqueryA.adaptCommonSubplan(subqueryACommonSubplan, idAllocator)).isEqualTo(subqueryACommonSubplan); + + assertPlan(subqueryB.adaptCommonSubplan(subqueryB.getCommonSubplan(), idAllocator), + strictProject(ImmutableMap.of("projection", PlanMatchPattern.expression(new Reference(BIGINT, "projection"))), + filter( + new Comparison(EQUAL, new Call(MODULUS_BIGINT, ImmutableList.of(new Reference(BIGINT, "column1"), new Constant(BIGINT, 4L))), new Constant(BIGINT, 0L)), + commonSubplan))); + + // make sure plan signatures are same + assertThat(subqueryA.getCommonSubplanSignature()).isEqualTo(subqueryB.getCommonSubplanSignature()); + List cacheColumnIds = ImmutableList.of(canonicalExpressionToColumnId(new Constant(BIGINT, 10L)), column1); + List cacheColumnsTypes = ImmutableList.of(BIGINT, BIGINT); + 
assertThat(subqueryA.getCommonSubplanSignature()).isEqualTo(new PlanSignatureWithPredicate(new PlanSignature( + combine(scanFilterProjectKey(new CacheTableId(testTableHandle.catalogHandle().getId() + ":cache_table_id")), "filters=(($operator$modulus(\"[cache_column1]\", bigint '4') = bigint '0') OR ($operator$modulus(\"[cache_column2]\", bigint '2') = bigint '0'))"), + Optional.empty(), + cacheColumnIds, + cacheColumnsTypes), + TupleDomain.all())); + } + + @Test + public void testCommonPredicateWasPushedDownAndDynamicFilter() + { + SymbolAllocator symbolAllocator = new SymbolAllocator(); + PlanBuilder planBuilder = new PlanBuilder(new PlanNodeIdAllocator(), getPlanTester().getPlannerContext(), TEST_SESSION); + PlanNodeIdAllocator idAllocator = new PlanNodeIdAllocator(); + + // subquery A + Symbol subqueryAColumn1 = symbolAllocator.newSymbol("subquery_a_column1", BIGINT); + + PlanNode planA = planBuilder.filter( + new Comparison(GREATER_THAN, new Reference(BIGINT, "subquery_a_column1"), new Constant(BIGINT, 150L)), + planBuilder.tableScan( + tableScan -> tableScan + .setTableHandle(testTableHandle) + .setSymbols(ImmutableList.of(subqueryAColumn1)) + .setAssignments(ImmutableMap.of(subqueryAColumn1, HANDLE_1)) + .setEnforcedConstraint(TupleDomain.all()) + .setUseConnectorNodePartitioning(Optional.of(false)))); + + // subquery B + Symbol subqueryBColumn1 = symbolAllocator.newSymbol("subquery_b_column1", BIGINT); + Symbol subqueryBColumn2 = symbolAllocator.newSymbol("subquery_b_column2", BIGINT); + + PlanNode planB = planBuilder.filter( + and( + new Comparison(LESS_THAN, new Reference(BIGINT, "subquery_b_column1"), new Constant(BIGINT, 50L)), + createDynamicFilterExpression( + getPlanTester().getPlannerContext().getMetadata(), + new DynamicFilterId("subquery_b_dynamic_id"), + BIGINT, + new Reference(BIGINT, "subquery_b_column2"))), + planBuilder.tableScan( + tableScan -> tableScan + .setTableHandle(testTableHandle) + .setSymbols(ImmutableList.of(subqueryBColumn1, 
subqueryBColumn2)) + .setAssignments(ImmutableMap.of(subqueryBColumn1, HANDLE_1, subqueryBColumn2, HANDLE_2)) + .setEnforcedConstraint(TupleDomain.all()) + .setUseConnectorNodePartitioning(Optional.of(false)))); + + // create a plan + PlanNode root = planBuilder.union(ImmutableListMultimap.of(), ImmutableList.of(planA, planB)); + + // extract common subqueries + Map planAdaptations = extractCommonSubqueries(idAllocator, symbolAllocator, root); + CommonPlanAdaptation subqueryA = planAdaptations.get(planA); + CommonPlanAdaptation subqueryB = planAdaptations.get(planB); + PlanMatchPattern commonTableScan = tableScan(TEST_TABLE, ImmutableMap.of("column2", "column2")) + .with(TableScanNode.class, tableScan -> ((MockConnectorTableHandle) tableScan.getTable().connectorHandle()).getConstraint().equals(CONSTRAINT_1)); + + // check whether common predicates were pushed down to common table scan + assertPlan(subqueryA.getCommonSubplan(), commonTableScan); + + // There is a FilterNode because of dynamic filters + PlanMatchPattern commonSubplanB = filter(Booleans.TRUE, createDynamicFilterExpression( + getPlanTester().getPlannerContext().getMetadata(), + new DynamicFilterId("subquery_b_dynamic_id"), + BIGINT, + new Reference(BIGINT, "column2")), + commonTableScan); + assertPlan(subqueryB.getCommonSubplan(), commonSubplanB); + } + + @Test + public void testCommonPredicateWasFullyPushedDown() + { + SymbolAllocator symbolAllocator = new SymbolAllocator(); + PlanBuilder planBuilder = new PlanBuilder(new PlanNodeIdAllocator(), getPlanTester().getPlannerContext(), TEST_SESSION); + PlanNodeIdAllocator idAllocator = new PlanNodeIdAllocator(); + + // subquery A + Symbol subqueryAColumn1 = symbolAllocator.newSymbol("subquery_a_column1", BIGINT); + + PlanNode planA = planBuilder.filter( + new Comparison(GREATER_THAN, new Reference(BIGINT, "subquery_a_column1"), new Constant(BIGINT, 150L)), + planBuilder.tableScan( + tableScan -> tableScan + .setTableHandle(testTableHandle) + 
.setSymbols(ImmutableList.of(subqueryAColumn1)) + .setAssignments(ImmutableMap.of(subqueryAColumn1, HANDLE_1)) + .setEnforcedConstraint(TupleDomain.all()) + .setUseConnectorNodePartitioning(Optional.of(false)))); + + // subquery B + Symbol subqueryBColumn1 = symbolAllocator.newSymbol("subquery_b_column1", BIGINT); + + PlanNode planB = planBuilder.filter( + new Comparison(LESS_THAN, new Reference(BIGINT, "subquery_b_column1"), new Constant(BIGINT, 50L)), + planBuilder.tableScan( + tableScan -> tableScan + .setTableHandle(testTableHandle) + .setSymbols(ImmutableList.of(subqueryBColumn1)) + .setAssignments(ImmutableMap.of(subqueryBColumn1, HANDLE_1)) + .setEnforcedConstraint(TupleDomain.all()) + .setUseConnectorNodePartitioning(Optional.of(false)))); + + // create a plan + PlanNode root = planBuilder.union(ImmutableListMultimap.of(), ImmutableList.of(planA, planB)); + + // extract common subqueries + Map planAdaptations = extractCommonSubqueries(idAllocator, symbolAllocator, root); + CommonPlanAdaptation subqueryA = planAdaptations.get(planA); + CommonPlanAdaptation subqueryB = planAdaptations.get(planB); + + // check whether common predicates were pushed down to common table scan + PlanMatchPattern commonSubplan = tableScan(TEST_TABLE) + .with(TableScanNode.class, tableScan -> ((MockConnectorTableHandle) tableScan.getTable().connectorHandle()).getConstraint().equals(CONSTRAINT_1)); + assertPlan(subqueryA.getCommonSubplan(), commonSubplan); + assertPlan(subqueryB.getCommonSubplan(), commonSubplan); + } + + @Test + public void testCommonPredicateWasPartiallyPushedDown() + { + SymbolAllocator symbolAllocator = new SymbolAllocator(); + PlanBuilder planBuilder = new PlanBuilder(new PlanNodeIdAllocator(), getPlanTester().getPlannerContext(), TEST_SESSION); + PlanNodeIdAllocator idAllocator = new PlanNodeIdAllocator(); + + // subquery A + Symbol subqueryAColumn1 = symbolAllocator.newSymbol("subquery_a_column1", BIGINT); + + PlanNode planA = planBuilder.filter( + new 
Comparison(GREATER_THAN, new Reference(BIGINT, "subquery_a_column1"), new Constant(BIGINT, 70L)), + planBuilder.tableScan( + tableScan -> tableScan + .setTableHandle(testTableHandle) + .setSymbols(ImmutableList.of(subqueryAColumn1)) + .setAssignments(ImmutableMap.of(subqueryAColumn1, HANDLE_1)) + .setEnforcedConstraint(TupleDomain.all()) + .setUseConnectorNodePartitioning(Optional.of(false)))); + + // subquery B + Symbol subqueryBColumn1 = symbolAllocator.newSymbol("subquery_b_column1", BIGINT); + + PlanNode planB = planBuilder.filter( + new Comparison(LESS_THAN, new Reference(BIGINT, "subquery_b_column1"), new Constant(BIGINT, 30L)), + planBuilder.tableScan( + tableScan -> tableScan + .setTableHandle(testTableHandle) + .setSymbols(ImmutableList.of(subqueryBColumn1)) + .setAssignments(ImmutableMap.of(subqueryBColumn1, HANDLE_1)) + .setEnforcedConstraint(TupleDomain.all()) + .setUseConnectorNodePartitioning(Optional.of(false)))); + + // create a plan + PlanNode root = planBuilder.union(ImmutableListMultimap.of(), ImmutableList.of(planA, planB)); + + // extract common subqueries + Map planAdaptations = extractCommonSubqueries(idAllocator, symbolAllocator, root); + CommonPlanAdaptation subqueryA = planAdaptations.get(planA); + CommonPlanAdaptation subqueryB = planAdaptations.get(planB); + + // check whether common predicates were partially pushed down (there is remaining filter and pushed down filter to table handle) + // to common table scan + PlanMatchPattern commonSubplan = filter( + new Logical(OR, ImmutableList.of( + new Comparison(LESS_THAN, new Reference(BIGINT, "column1"), new Constant(BIGINT, 30L)), + new Comparison(GREATER_THAN, new Reference(BIGINT, "column1"), new Constant(BIGINT, 70L)))), + tableScan(TEST_TABLE, ImmutableMap.of("column1", "column1")) + .with(TableScanNode.class, tableScan -> ((MockConnectorTableHandle) tableScan.getTable().connectorHandle()).getConstraint().equals(CONSTRAINT_3))); + + assertPlan(subqueryA.getCommonSubplan(), 
commonSubplan); + assertPlan(subqueryB.getCommonSubplan(), commonSubplan); + } + + @Test + public void testCommonPredicateWasNotPushedDownWhenValuesNode() + { + SymbolAllocator symbolAllocator = new SymbolAllocator(); + PlanBuilder planBuilder = new PlanBuilder(new PlanNodeIdAllocator(), getPlanTester().getPlannerContext(), TEST_SESSION); + PlanNodeIdAllocator idAllocator = new PlanNodeIdAllocator(); + + // subquery A + Symbol subqueryAColumn1 = symbolAllocator.newSymbol("subquery_a_column1", BIGINT); + + PlanNode planA = planBuilder.filter( + new Comparison(GREATER_THAN, new Reference(BIGINT, "subquery_a_column1"), new Constant(BIGINT, 40L)), + planBuilder.tableScan( + tableScan -> tableScan + .setTableHandle(testTableHandle) + .setSymbols(ImmutableList.of(subqueryAColumn1)) + .setAssignments(ImmutableMap.of(subqueryAColumn1, HANDLE_1)) + .setEnforcedConstraint(TupleDomain.all()) + .setUseConnectorNodePartitioning(Optional.of(false)))); + + // subquery B + Symbol subqueryBColumn1 = symbolAllocator.newSymbol("subquery_b_column1", BIGINT); + + PlanNode planB = planBuilder.filter( + new Comparison(LESS_THAN, new Reference(BIGINT, "subquery_b_column1"), new Constant(BIGINT, 20L)), + planBuilder.tableScan( + tableScan -> tableScan + .setTableHandle(testTableHandle) + .setSymbols(ImmutableList.of(subqueryBColumn1)) + .setAssignments(ImmutableMap.of(subqueryBColumn1, HANDLE_1)) + .setEnforcedConstraint(TupleDomain.all()) + .setUseConnectorNodePartitioning(Optional.of(false)))); + + // create a plan + PlanNode root = planBuilder.union(ImmutableListMultimap.of(), ImmutableList.of(planA, planB)); + + // extract common subqueries + Map planAdaptations = extractCommonSubqueries(idAllocator, symbolAllocator, root); + CommonPlanAdaptation subqueryA = planAdaptations.get(planA); + CommonPlanAdaptation subqueryB = planAdaptations.get(planB); + + PlanMatchPattern commonSubplan = filter( + new Logical(OR, ImmutableList.of( + new Comparison(LESS_THAN, new Reference(BIGINT, 
"column1"), new Constant(BIGINT, 20L)), + new Comparison(GREATER_THAN, new Reference(BIGINT, "column1"), new Constant(BIGINT, 40L)))), + tableScan(TEST_TABLE, ImmutableMap.of("column1", "column1")) + .with(TableScanNode.class, tableScan -> ((MockConnectorTableHandle) tableScan.getTable().connectorHandle()).getConstraint().equals(TupleDomain.all()))); + + assertPlan(subqueryA.getCommonSubplan(), commonSubplan); + assertPlan(subqueryB.getCommonSubplan(), commonSubplan); + } + + @Test + public void testExtractDomain() + { + // both subqueries contain simple predicate that can be translated into tuple domain in plan signature + SymbolAllocator symbolAllocator = new SymbolAllocator(); + Symbol subqueryAColumn1 = symbolAllocator.newSymbol("subquery_a_column1", BIGINT); + PlanNode scanA = new TableScanNode( + new PlanNodeId("scanA"), + testTableHandle, + ImmutableList.of(subqueryAColumn1), + ImmutableMap.of(subqueryAColumn1, HANDLE_1), + TupleDomain.all(), + Optional.empty(), + false, + Optional.of(false)); + FilterNode filterA = new FilterNode( + new PlanNodeId("filterA"), + scanA, + new Comparison(GREATER_THAN, new Reference(BIGINT, "subquery_a_column1"), new Constant(BIGINT, 42L))); + + Symbol subqueryBColumn1 = symbolAllocator.newSymbol("subquery_b_column1", BIGINT); + PlanNode scanB = new TableScanNode( + new PlanNodeId("scanB"), + testTableHandle, + ImmutableList.of(subqueryBColumn1), + ImmutableMap.of(subqueryBColumn1, HANDLE_1), + TupleDomain.all(), + Optional.empty(), + false, + Optional.of(false)); + FilterNode filterB = new FilterNode( + new PlanNodeId("filterB"), + scanB, + new Comparison(LESS_THAN, new Reference(BIGINT, "subquery_b_column1"), new Constant(BIGINT, 0L))); + + PlanNodeIdAllocator idAllocator = new PlanNodeIdAllocator(); + Map planAdaptations = extractCommonSubqueries( + idAllocator, + symbolAllocator, + new UnionNode( + new PlanNodeId("union"), + ImmutableList.of(filterA, filterB), + ImmutableListMultimap.of(), + ImmutableList.of())); + + // 
there should be a common subquery found for both subplans + assertThat(planAdaptations).hasSize(2); + assertThat(planAdaptations).containsKey(filterA); + assertThat(planAdaptations).containsKey(filterB); + + CommonPlanAdaptation subqueryA = planAdaptations.get(filterA); + CommonPlanAdaptation subqueryB = planAdaptations.get(filterB); + + // common subplan should be identical for both subqueries + PlanMatchPattern commonSubplan = + filter( + new Logical(OR, ImmutableList.of( + new Comparison(GREATER_THAN, new Reference(BIGINT, "column1"), new Constant(BIGINT, 42L)), + new Comparison(LESS_THAN, new Reference(BIGINT, "column1"), new Constant(BIGINT, 0L)))), + strictTableScan( + TEST_TABLE, + ImmutableMap.of( + "column1", "column1"))); + assertPlan(subqueryA.getCommonSubplan(), commonSubplan); + assertPlan(subqueryB.getCommonSubplan(), commonSubplan); + + // filtering adaptation is required + assertPlan(subqueryA.adaptCommonSubplan(subqueryA.getCommonSubplan(), idAllocator), + filter( + new Comparison(GREATER_THAN, new Reference(BIGINT, "column1"), new Constant(BIGINT, 42L)), + commonSubplan)); + + assertPlan(subqueryB.adaptCommonSubplan(subqueryB.getCommonSubplan(), idAllocator), + filter( + new Comparison(LESS_THAN, new Reference(BIGINT, "column1"), new Constant(BIGINT, 0L)), + commonSubplan)); + + // make sure plan signatures are same and contain domain + SortedRangeSet expectedValues = (SortedRangeSet) ValueSet.ofRanges(lessThan(BIGINT, 0L), greaterThan(BIGINT, 42L)); + TupleDomain expectedTupleDomain = TupleDomain.withColumnDomains(ImmutableMap.of( + new CacheColumnId("[cache_column1]"), Domain.create(expectedValues, false))); + assertThat(subqueryA.getCommonSubplanSignature()).isEqualTo(subqueryB.getCommonSubplanSignature()); + List cacheColumnIds = ImmutableList.of(new CacheColumnId("[cache_column1]")); + List cacheColumnsTypes = ImmutableList.of(BIGINT); + assertThat(subqueryA.getCommonSubplanSignature()).isEqualTo(new PlanSignatureWithPredicate( + new 
PlanSignature( + scanFilterProjectKey(new CacheTableId(testTableHandle.catalogHandle().getId() + ":cache_table_id")), + Optional.empty(), + cacheColumnIds, + cacheColumnsTypes), + expectedTupleDomain)); + + // make sure signature tuple domain is normalized + SortedRangeSet actualValues = (SortedRangeSet) subqueryA.getCommonSubplanSignature() + .predicate() + .getDomains() + .orElseThrow() + .get(new CacheColumnId("[cache_column1]")) + .getValues(); + assertBlockEquals(BIGINT, actualValues.getSortedRanges(), expectedValues.getSortedRanges()); + assertThat(actualValues.getSortedRanges()).isInstanceOf(LongArrayBlock.class); + } + + @Test + public void testSimpleSubqueries() + { + // both subqueries are just table scans + SymbolAllocator symbolAllocator = new SymbolAllocator(); + Symbol subqueryAColumn1 = symbolAllocator.newSymbol("subquery_a_column1", BIGINT); + PlanNode scanA = new TableScanNode( + new PlanNodeId("scanA"), + testTableHandle, + ImmutableList.of(subqueryAColumn1), + ImmutableMap.of(subqueryAColumn1, HANDLE_1), + TupleDomain.all(), + Optional.empty(), + false, + Optional.of(false)); + + Symbol subqueryBColumn1 = symbolAllocator.newSymbol("subquery_b_column1", BIGINT); + Symbol subqueryBColumn2 = symbolAllocator.newSymbol("subquery_b_column2", BIGINT); + PlanNode scanB = new TableScanNode( + new PlanNodeId("scanB"), + testTableHandle, + ImmutableList.of(subqueryBColumn2, subqueryBColumn1), + ImmutableMap.of(subqueryBColumn2, HANDLE_2, subqueryBColumn1, HANDLE_1), + TupleDomain.all(), + Optional.empty(), + false, + Optional.of(false)); + + PlanNodeIdAllocator idAllocator = new PlanNodeIdAllocator(); + Map planAdaptations = extractCommonSubqueries( + idAllocator, + symbolAllocator, + new UnionNode( + new PlanNodeId("union"), + ImmutableList.of(scanA, scanB), + ImmutableListMultimap.of(), + ImmutableList.of())); + + // there should be a common subquery found for both subplans + assertThat(planAdaptations).hasSize(2); + 
assertThat(planAdaptations).containsKey(scanA); + assertThat(planAdaptations).containsKey(scanB); + + CommonPlanAdaptation subqueryA = planAdaptations.get(scanA); + CommonPlanAdaptation subqueryB = planAdaptations.get(scanB); + + // common subplan should be identical for both subqueries + PlanMatchPattern commonSubplan = + strictTableScan( + TEST_TABLE, + ImmutableMap.of( + "column1", "column1", + "column2", "column2")); + assertPlan(subqueryA.getCommonSubplan(), commonSubplan); + assertPlan(subqueryB.getCommonSubplan(), commonSubplan); + + // only projection adaptation is required + assertPlan(subqueryA.adaptCommonSubplan(subqueryA.getCommonSubplan(), idAllocator), + strictProject(ImmutableMap.of("column1", PlanMatchPattern.expression(new Reference(BIGINT, "column1"))), + commonSubplan)); + + assertPlan(subqueryB.adaptCommonSubplan(subqueryB.getCommonSubplan(), idAllocator), + // order of common subquery output needs to shuffled to match original query + strictProject(ImmutableMap.of( + "column2", PlanMatchPattern.expression(new Reference(BIGINT, "column2")), + "column1", PlanMatchPattern.expression(new Reference(BIGINT, "column1"))), + commonSubplan)); + + // make sure plan signatures are same and contain domain + assertThat(subqueryA.getCommonSubplanSignature()).isEqualTo(subqueryB.getCommonSubplanSignature()); + List cacheColumnIds = ImmutableList.of(new CacheColumnId("[cache_column1]"), new CacheColumnId("[cache_column2]")); + List cacheColumnsTypes = ImmutableList.of(BIGINT, BIGINT); + assertThat(subqueryA.getCommonSubplanSignature()).isEqualTo(new PlanSignatureWithPredicate( + new PlanSignature( + scanFilterProjectKey(new CacheTableId(testTableHandle.catalogHandle().getId() + ":cache_table_id")), + Optional.empty(), + cacheColumnIds, + cacheColumnsTypes), + TupleDomain.all())); + } + + @Test + public void testPredicateInSingleSubquery() + { + // one subquery has filter, the other does not + // common subquery shouldn't have any predicate + SymbolAllocator 
symbolAllocator = new SymbolAllocator(); + Symbol subqueryAColumn1 = symbolAllocator.newSymbol("subquery_a_column1", BIGINT); + PlanNode scanA = new TableScanNode( + new PlanNodeId("scanA"), + testTableHandle, + ImmutableList.of(subqueryAColumn1), + ImmutableMap.of(subqueryAColumn1, HANDLE_1), + TupleDomain.all(), + Optional.empty(), + false, + Optional.of(false)); + FilterNode filterA = new FilterNode( + new PlanNodeId("filterA"), + scanA, + new Comparison(EQUAL, new Call(MODULUS_BIGINT, ImmutableList.of(new Reference(BIGINT, "subquery_a_column1"), new Constant(BIGINT, 4L))), new Constant(BIGINT, 0L))); + + Symbol subqueryBColumn1 = symbolAllocator.newSymbol("subquery_b_column1", BIGINT); + PlanNode scanB = new TableScanNode( + new PlanNodeId("scanB"), + testTableHandle, + ImmutableList.of(subqueryBColumn1), + ImmutableMap.of(subqueryBColumn1, HANDLE_1), + TupleDomain.all(), + Optional.empty(), + false, + Optional.of(false)); + + PlanNodeIdAllocator idAllocator = new PlanNodeIdAllocator(); + Map planAdaptations = extractCommonSubqueries( + idAllocator, + symbolAllocator, + new UnionNode( + new PlanNodeId("union"), + ImmutableList.of(filterA, scanB), + ImmutableListMultimap.of(), + ImmutableList.of())); + + // there should be a common subquery found for both subplans + assertThat(planAdaptations).hasSize(2); + assertThat(planAdaptations).containsKey(filterA); + assertThat(planAdaptations).containsKey(scanB); + + CommonPlanAdaptation subqueryA = planAdaptations.get(filterA); + CommonPlanAdaptation subqueryB = planAdaptations.get(scanB); + + // common subplan should consist on only table scan + PlanMatchPattern commonSubplan = strictTableScan( + TEST_TABLE, + ImmutableMap.of("column1", "column1")); + assertPlan(subqueryA.getCommonSubplan(), commonSubplan); + assertPlan(subqueryB.getCommonSubplan(), commonSubplan); + + // only filtering adaptation is required on subplan a + assertPlan(subqueryA.adaptCommonSubplan(subqueryA.getCommonSubplan(), idAllocator), + filter( + 
new Comparison(EQUAL, new Call(MODULUS_BIGINT, ImmutableList.of(new Reference(BIGINT, "column1"), new Constant(BIGINT, 4L))), new Constant(BIGINT, 0L)), + commonSubplan)); + + assertPlan(subqueryB.adaptCommonSubplan(subqueryB.getCommonSubplan(), idAllocator), commonSubplan); + } + + @Test + public void testSharedConjunct() + { + SymbolAllocator symbolAllocator = new SymbolAllocator(); + + // subquery A scans column1 and column2 + Symbol subqueryAColumn1 = symbolAllocator.newSymbol("subquery_a_column1", BIGINT); + Symbol subqueryAColumn2 = symbolAllocator.newSymbol("subquery_a_column2", BIGINT); + PlanNode scanA = new TableScanNode( + new PlanNodeId("scanA"), + testTableHandle, + ImmutableList.of(subqueryAColumn1, subqueryAColumn2), + ImmutableMap.of(subqueryAColumn1, HANDLE_1, subqueryAColumn2, HANDLE_2), + TupleDomain.all(), + Optional.empty(), + false, + Optional.of(false)); + // subquery A has predicate on both columns + FilterNode filterA = new FilterNode( + new PlanNodeId("filterA"), + scanA, + new Logical(AND, ImmutableList.of( + new Comparison(LESS_THAN, new Reference(BIGINT, "subquery_a_column1"), new Constant(BIGINT, 42L)), + new Comparison(GREATER_THAN, new Reference(BIGINT, "subquery_a_column2"), new Constant(BIGINT, 24L))))); + ProjectNode projectA = new ProjectNode( + new PlanNodeId("projectA"), + filterA, + Assignments.of( + subqueryAColumn2, new Reference(BIGINT, "subquery_a_column2"))); + + // subquery B scans column1 and column2 + Symbol subqueryBColumn1 = symbolAllocator.newSymbol("subquery_b_column1", BIGINT); + Symbol subqueryBColumn2 = symbolAllocator.newSymbol("subquery_b_column2", BIGINT); + PlanNode scanB = new TableScanNode( + new PlanNodeId("scanB"), + testTableHandle, + ImmutableList.of(subqueryBColumn1, subqueryBColumn2), + ImmutableMap.of(subqueryBColumn1, HANDLE_1, subqueryBColumn2, HANDLE_2), + TupleDomain.all(), + Optional.empty(), + false, + Optional.of(false)); + // subquery B has predicate on column1 only + FilterNode filterB = 
new FilterNode( + new PlanNodeId("filterB"), + scanB, + new Comparison(LESS_THAN, new Reference(BIGINT, "subquery_b_column1"), new Constant(BIGINT, 42L))); + ProjectNode projectB = new ProjectNode( + new PlanNodeId("projectB"), + filterB, + Assignments.of(subqueryBColumn2, new Reference(BIGINT, "subquery_b_column2"))); + + PlanNodeIdAllocator idAllocator = new PlanNodeIdAllocator(); + Map planAdaptations = extractCommonSubqueries( + idAllocator, + symbolAllocator, + new UnionNode( + new PlanNodeId("union"), + ImmutableList.of(projectA, projectB), + ImmutableListMultimap.of(), + ImmutableList.of())); + + // there should be a common subquery found for both subplans + assertThat(planAdaptations).hasSize(2); + assertThat(planAdaptations).containsKey(projectA); + assertThat(planAdaptations).containsKey(projectB); + + CommonPlanAdaptation subqueryA = planAdaptations.get(projectA); + CommonPlanAdaptation subqueryB = planAdaptations.get(projectB); + + // common subplan should be identical for both subqueries + PlanMatchPattern commonSubplanTableScan = strictTableScan( + TEST_TABLE, + ImmutableMap.of( + "column1", "column1", + "column2", "column2")); + PlanMatchPattern commonSubplan = strictProject( + ImmutableMap.of( + "column2", PlanMatchPattern.expression(new Reference(BIGINT, "column2"))), + filter( + new Comparison(LESS_THAN, new Reference(BIGINT, "column1"), new Constant(BIGINT, 42L)), + commonSubplanTableScan)); + assertPlan(subqueryA.getCommonSubplan(), commonSubplan); + assertPlan(subqueryB.getCommonSubplan(), commonSubplan); + + // subquery A should have predicate adaptation + assertPlan(subqueryA.adaptCommonSubplan(subqueryA.getCommonSubplan(), idAllocator), + filter(new Comparison(GREATER_THAN, new Reference(BIGINT, "column2"), new Constant(BIGINT, 24L)), commonSubplan)); + + PlanNode subqueryBCommonSubplan = subqueryB.getCommonSubplan(); + assertThat(subqueryB.adaptCommonSubplan(subqueryBCommonSubplan, idAllocator)).isEqualTo(subqueryBCommonSubplan); + + //
make sure plan signatures are same + assertThat(subqueryA.getCommonSubplanSignature()).isEqualTo(subqueryB.getCommonSubplanSignature()); + List cacheColumnIds = ImmutableList.of(new CacheColumnId("[cache_column2]")); + List cacheColumnsTypes = ImmutableList.of(BIGINT); + assertThat(subqueryA.getCommonSubplanSignature()).isEqualTo(new PlanSignatureWithPredicate( + new PlanSignature( + scanFilterProjectKey(new CacheTableId(testTableHandle.catalogHandle().getId() + ":cache_table_id")), + Optional.empty(), + cacheColumnIds, + cacheColumnsTypes), + TupleDomain.withColumnDomains(ImmutableMap.of( + new CacheColumnId("[cache_column1]"), Domain.create(ValueSet.ofRanges(lessThan(BIGINT, 5L), lessThan(BIGINT, 42L)), false))))); + } + + private CanonicalAggregation canonicalAggregation(String name, Expression... arguments) + { + return canonicalAggregation(name, Optional.empty(), arguments); + } + + private CanonicalAggregation canonicalAggregation(String name, Optional mask, Expression... arguments) + { + ResolvedFunction resolvedFunction = getPlanTester().getPlannerContext().getMetadata().resolveBuiltinFunction( + name, + TypeSignatureProvider.fromTypes(Stream.of(arguments) + .map(Expression::type) + .collect(toImmutableList()))); + return new CanonicalAggregation( + resolvedFunction, + mask, + ImmutableList.copyOf(arguments)); + } + + private CommonSubqueries extractTpchCommonSubqueries(@Language("SQL") String query) + { + return extractTpchCommonSubqueries(query, true, false, false); + } + + private CommonSubqueries extractTpchCommonSubqueries(@Language("SQL") String query, boolean cacheSubqueries, boolean cacheAggregations, boolean cacheProjections) + { + return extractTpchCommonSubqueries(query, cacheSubqueries, cacheAggregations, cacheProjections, true); + } + + private CommonSubqueries extractTpchCommonSubqueries(@Language("SQL") String query, boolean cacheSubqueries, boolean cacheAggregations, boolean cacheProjections, boolean forceSingleNode) + { + Session 
tpchSession = Session.builder(TPCH_SESSION) + .setSystemProperty(CACHE_COMMON_SUBQUERIES_ENABLED, Boolean.toString(cacheSubqueries)) + .setSystemProperty(CACHE_AGGREGATIONS_ENABLED, Boolean.toString(cacheAggregations)) + .setSystemProperty(CACHE_PROJECTIONS_ENABLED, Boolean.toString(cacheProjections)) + .build(); + PlanTester planTester = getPlanTester(); + return planTester.inTransaction(tpchSession, session -> { + Plan plan = planTester.createPlan(session, query, planTester.getPlanOptimizers(forceSingleNode), OPTIMIZED_AND_VALIDATED, WarningCollector.NOOP, createPlanOptimizersStatsCollector()); + // metadata.getCatalogHandle() registers the catalog for the transaction + session.getCatalog().ifPresent(catalog -> getPlanTester().getPlannerContext().getMetadata().getCatalogHandle(session, catalog)); + SymbolAllocator symbolAllocator = new SymbolAllocator(ImmutableSet.builder() + .addAll(extractUnique(plan.getRoot())) + .addAll(extractOutputSymbols(plan.getRoot())).build()); + PlanNodeIdAllocator idAllocator = new PlanNodeIdAllocator(); + return new CommonSubqueries( + CommonSubqueriesExtractor.extractCommonSubqueries( + new CacheController(), + getPlanTester().getPlannerContext(), + session, + idAllocator, + symbolAllocator, + plan.getRoot()), + symbolAllocator, + idAllocator, + plan.getRoot()); + }); + } + + record CommonSubqueries(Map planAdaptations, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator, PlanNode plan) {} + + private Map extractCommonSubqueries( + PlanNodeIdAllocator idAllocator, + SymbolAllocator symbolAllocator, + PlanNode root) + { + return getPlanTester().inTransaction(TEST_SESSION, session -> { + // metadata.getCatalogHandle() registers the catalog for the transaction + session.getCatalog().ifPresent(catalog -> getPlanTester().getPlannerContext().getMetadata().getCatalogHandle(session, catalog)); + return CommonSubqueriesExtractor.extractCommonSubqueries( + new CacheController(), + getPlanTester().getPlannerContext(), + session, + 
idAllocator, + symbolAllocator, + root); + }); + } + + private void assertAggregationsWithMasks(PlanNode node, int... indexes) + { + // assert aggregations at given indexes are masked/unmasked + assertThat(node).isInstanceOf(AggregationNode.class); + AggregationNode aggregation = (AggregationNode) node; + List aggregations = ImmutableList.copyOf(aggregation.getAggregations().values()); + Set maskedAggregations = Arrays.stream(indexes).boxed().collect(toImmutableSet()); + for (int i = 0; i < aggregations.size(); ++i) { + if (maskedAggregations.contains(i)) { + assertThat(aggregations.get(i).getMask()).isPresent(); + } + else { + assertThat(aggregations.get(i).getMask()).isEmpty(); + } + } + } + + private void assertPlan(PlanNode root, PlanMatchPattern expected) + { + assertPlan(TEST_SESSION, root, expected); + } + + private void assertTpchPlan(PlanNode root, PlanMatchPattern expected) + { + assertPlan(TPCH_SESSION, root, expected); + } + + private void assertPlan(Session customSession, PlanNode root, PlanMatchPattern expected) + { + getPlanTester().inTransaction(customSession, session -> { + // metadata.getCatalogHandle() registers the catalog for the transaction + session.getCatalog().ifPresent(catalog -> getPlanTester().getPlannerContext().getMetadata().getCatalogHandle(session, catalog)); + Plan plan = new Plan(root, StatsAndCosts.empty()); + PlanAssert.assertPlan(session, getPlanTester().getPlannerContext().getMetadata(), createTestingFunctionManager(), noopStatsCalculator(), plan, expected); + return null; + }); + } +} diff --git a/core/trino-main/src/test/java/io/trino/cache/TestConsistentHashingAddressProvider.java b/core/trino-main/src/test/java/io/trino/cache/TestConsistentHashingAddressProvider.java new file mode 100644 index 000000000000..ec731e76fb08 --- /dev/null +++ b/core/trino-main/src/test/java/io/trino/cache/TestConsistentHashingAddressProvider.java @@ -0,0 +1,153 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you 
may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.cache; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Sets; +import io.trino.client.NodeVersion; +import io.trino.metadata.InternalNode; +import io.trino.spi.HostAddress; +import io.trino.spi.Node; +import io.trino.spi.cache.CacheColumnId; +import io.trino.spi.cache.CacheSplitId; +import io.trino.spi.cache.PlanSignature; +import io.trino.spi.cache.SignatureKey; +import io.trino.spi.type.Type; +import io.trino.testing.TestingNodeManager; +import org.assertj.core.data.Percentage; +import org.junit.jupiter.api.Test; + +import java.net.URI; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.stream.Stream; + +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static io.trino.spi.cache.PlanSignature.canonicalizePlanSignature; +import static io.trino.spi.type.IntegerType.INTEGER; +import static org.assertj.core.api.Assertions.assertThat; + +public class TestConsistentHashingAddressProvider +{ + private static final CacheColumnId COLUMN1 = new CacheColumnId("col1"); + private static final CacheColumnId COLUMN2 = new CacheColumnId("col2"); + private static final CacheSplitId SPLIT1 = new CacheSplitId("split1"); + private static final CacheSplitId SPLIT2 = new CacheSplitId("split2"); + + @Test + public void testAddressProvider() + { + 
TestingNodeManager nodeManager = new TestingNodeManager(); + nodeManager.addNode(node("node1")); + nodeManager.addNode(node("node2")); + nodeManager.addNode(node("node3")); + nodeManager.addNode(node("node4")); + + ConsistentHashingAddressProvider addressProvider = new ConsistentHashingAddressProvider(nodeManager); + String signature1 = canonicalizePlanSignature(createPlanSignature("signature1", COLUMN1)).toString(); + String signature2 = canonicalizePlanSignature(createPlanSignature("signature2", COLUMN1)).toString(); + String signature3 = canonicalizePlanSignature(createPlanSignature("signature1", COLUMN2)).toString(); + + // assert that both different signature or split id affects preferred address + assertThat(getPreferredAddress(addressProvider, signature1, SPLIT1)) + .isNotEqualTo(getPreferredAddress(addressProvider, signature1, SPLIT2)); + assertThat(getPreferredAddress(addressProvider, signature2, SPLIT1)) + .isNotEqualTo(getPreferredAddress(addressProvider, signature2, SPLIT2)); + assertThat(getPreferredAddress(addressProvider, signature1, SPLIT1)) + .isNotEqualTo(getPreferredAddress(addressProvider, signature2, SPLIT1)); + + // make sure that columns don't affect preferred address + assertThat(getPreferredAddress(addressProvider, signature1, SPLIT1)) + .isEqualTo(getPreferredAddress(addressProvider, signature3, SPLIT1)); + assertThat(getPreferredAddress(addressProvider, signature1, SPLIT2)) + .isEqualTo(getPreferredAddress(addressProvider, signature3, SPLIT2)); + + assertFairDistribution(addressProvider, signature1, nodeManager.getWorkerNodes()); + + Map> distribution = getDistribution(addressProvider, signature1); + nodeManager.removeNode(node("node2")); + addressProvider.refreshHashRing(); + assertFairDistribution(addressProvider, signature1, nodeManager.getWorkerNodes()); + Map> removeOne = getDistribution(addressProvider, signature1); + assertMinimalRedistribution(distribution, removeOne); + + nodeManager.addNode(node("node5")); + 
addressProvider.refreshHashRing(); + assertFairDistribution(addressProvider, signature1, nodeManager.getWorkerNodes()); + Map> addOne = getDistribution(addressProvider, signature1); + assertMinimalRedistribution(removeOne, addOne); + } + + private static HostAddress getPreferredAddress( + ConsistentHashingAddressProvider addressProvider, + String planSignature, + CacheSplitId splitId) + { + String key = planSignature + splitId; + return addressProvider.getPreferredAddress(key) + .orElseThrow(() -> new RuntimeException("Unable to locate key: " + key)); + } + + private static void assertFairDistribution(ConsistentHashingAddressProvider addressProvider, String planSignature, Set nodeNames) + { + int totalSplits = 1000; + Map counts = new HashMap<>(); + for (int i = 0; i < totalSplits; i++) { + counts.merge(getPreferredAddress(addressProvider, planSignature, new CacheSplitId("split" + i)).getHostText(), 1, Math::addExact); + } + assertThat(counts.keySet()).isEqualTo(nodeNames.stream().map(node -> node.getHostAndPort().getHostText()).collect(toImmutableSet())); + int expectedSplitsPerNode = totalSplits / nodeNames.size(); + counts.values().forEach(splitsPerNode -> assertThat(splitsPerNode).isCloseTo(expectedSplitsPerNode, Percentage.withPercentage(20))); + } + + private void assertMinimalRedistribution(Map> oldDistribution, Map> newDistribution) + { + oldDistribution.entrySet().stream().filter(e -> newDistribution.containsKey(e.getKey())).forEach(entry -> { + Set oldNodeBuckets = entry.getValue(); + Set newNodeBuckets = newDistribution.get(entry.getKey()); + int redDistributedBucketsCount = Sets.difference(oldNodeBuckets, newNodeBuckets).size(); + int oldClusterSize = oldDistribution.size(); + assertThat(redDistributedBucketsCount).isLessThan(oldNodeBuckets.size() / oldClusterSize); + }); + } + + private Map> getDistribution(ConsistentHashingAddressProvider addressProvider, String planSignature) + { + int totalSplits = 1000; + Map> distribution = new HashMap<>(); + for 
(int i = 0; i < totalSplits; i++) { + String host = getPreferredAddress(addressProvider, planSignature, new CacheSplitId("split" + i)).getHostText(); + distribution.computeIfAbsent(host, (key) -> new HashSet<>()).add(i); + } + return distribution; + } + + private static Node node(String nodeName) + { + return new InternalNode(nodeName, URI.create("http://" + nodeName + "/"), NodeVersion.UNKNOWN, false); + } + + private static PlanSignature createPlanSignature(String signature, CacheColumnId... ids) + { + return new PlanSignature( + new SignatureKey(signature), + Optional.empty(), + ImmutableList.copyOf(ids), + Stream.of(ids).map(ignore -> (Type) INTEGER).collect(toImmutableList())); + } +} diff --git a/core/trino-main/src/test/java/io/trino/connector/MockConnector.java b/core/trino-main/src/test/java/io/trino/connector/MockConnector.java index 345ed0c580f4..de57c8596695 100644 --- a/core/trino-main/src/test/java/io/trino/connector/MockConnector.java +++ b/core/trino-main/src/test/java/io/trino/connector/MockConnector.java @@ -29,6 +29,9 @@ import io.trino.spi.Page; import io.trino.spi.RefreshType; import io.trino.spi.TrinoException; +import io.trino.spi.cache.CacheColumnId; +import io.trino.spi.cache.CacheTableId; +import io.trino.spi.cache.ConnectorCacheMetadata; import io.trino.spi.connector.AggregateFunction; import io.trino.spi.connector.AggregationApplicationResult; import io.trino.spi.connector.BeginTableExecuteResult; @@ -193,6 +196,9 @@ public class MockConnector private final List> sessionProperties; private final Function tableFunctionSplitsSources; private final OptionalInt maxWriterTasks; + private final Function> getCacheTableId; + private final Function> getCacheColumnId; + private final Function getCanonicalTableHandle; private final BiFunction> getLayoutForTableExecute; private final WriterScalingOptions writerScalingOptions; private final Supplier> capabilities; @@ -247,6 +253,9 @@ public class MockConnector Supplier>> columnProperties, Function 
tableFunctionSplitsSources, OptionalInt maxWriterTasks, + Function> getCacheTableId, + Function> getCacheColumnId, + Function getCanonicalTableHandle, BiFunction> getLayoutForTableExecute, WriterScalingOptions writerScalingOptions, Supplier> capabilities, @@ -300,6 +309,9 @@ public class MockConnector this.columnProperties = requireNonNull(columnProperties, "columnProperties is null"); this.tableFunctionSplitsSources = requireNonNull(tableFunctionSplitsSources, "tableFunctionSplitsSources is null"); this.maxWriterTasks = requireNonNull(maxWriterTasks, "maxWriterTasks is null"); + this.getCacheTableId = requireNonNull(getCacheTableId, "getCacheTableId is null"); + this.getCacheColumnId = requireNonNull(getCacheColumnId, "getCacheColumnId is null"); + this.getCanonicalTableHandle = requireNonNull(getCanonicalTableHandle, "getCanonicalTableHandle is null"); this.getLayoutForTableExecute = requireNonNull(getLayoutForTableExecute, "getLayoutForTableExecute is null"); this.writerScalingOptions = requireNonNull(writerScalingOptions, "writerScalingOptions is null"); this.capabilities = requireNonNull(capabilities, "capabilities is null"); @@ -361,6 +373,12 @@ public ConnectorSplitSource getSplits(ConnectorTransactionHandle transaction, Co }; } + @Override + public ConnectorCacheMetadata getCacheMetadata() + { + return new MockCacheMetadata(); + } + @Override public ConnectorNodePartitioningProvider getNodePartitioningProvider() { @@ -715,7 +733,8 @@ public void createMaterializedView( ConnectorMaterializedViewDefinition definition, Map properties, boolean replace, - boolean ignoreExisting) {} + boolean ignoreExisting) + {} @Override public List listMaterializedViews(ConnectorSession session, Optional schemaName) @@ -1026,6 +1045,28 @@ private MockConnectorAccessControl getMockAccessControl() } } + private class MockCacheMetadata + implements ConnectorCacheMetadata + { + @Override + public Optional getCacheTableId(ConnectorTableHandle tableHandle) + { + return 
getCacheTableId.apply(tableHandle); + } + + @Override + public Optional getCacheColumnId(ConnectorTableHandle tableHandle, ColumnHandle columnHandle) + { + return getCacheColumnId.apply(columnHandle); + } + + @Override + public ConnectorTableHandle getCanonicalTableHandle(ConnectorTableHandle tableHandle) + { + return getCanonicalTableHandle.apply(tableHandle); + } + } + private static class MockPageSinkProvider implements ConnectorPageSinkProvider { diff --git a/core/trino-main/src/test/java/io/trino/connector/MockConnectorFactory.java b/core/trino-main/src/test/java/io/trino/connector/MockConnectorFactory.java index 7a0c8e090360..ed350eb4ead5 100644 --- a/core/trino-main/src/test/java/io/trino/connector/MockConnectorFactory.java +++ b/core/trino-main/src/test/java/io/trino/connector/MockConnectorFactory.java @@ -16,6 +16,8 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; +import io.trino.spi.cache.CacheColumnId; +import io.trino.spi.cache.CacheTableId; import io.trino.spi.connector.AggregateFunction; import io.trino.spi.connector.AggregationApplicationResult; import io.trino.spi.connector.CatalogSchemaTableName; @@ -145,6 +147,9 @@ public class MockConnectorFactory private final ListRoleGrants roleGrants; private final Optional accessControl; private final OptionalInt maxWriterTasks; + private final Function> getCacheTableId; + private final Function> getCacheColumnId; + private final Function getCanonicalTableHandle; private final BiFunction> getLayoutForTableExecute; private final WriterScalingOptions writerScalingOptions; @@ -201,6 +206,9 @@ private MockConnectorFactory( boolean allowMissingColumnsOnInsert, Function tableFunctionSplitsSources, OptionalInt maxWriterTasks, + Function> getCacheTableId, + Function> getCacheColumnId, + Function getCanonicalTableHandle, BiFunction> getLayoutForTableExecute, WriterScalingOptions writerScalingOptions, Supplier> 
capabilities, @@ -255,6 +263,9 @@ private MockConnectorFactory( this.allowMissingColumnsOnInsert = allowMissingColumnsOnInsert; this.tableFunctionSplitsSources = requireNonNull(tableFunctionSplitsSources, "tableFunctionSplitsSources is null"); this.maxWriterTasks = maxWriterTasks; + this.getCacheTableId = requireNonNull(getCacheTableId, "getCacheTableId is null"); + this.getCacheColumnId = requireNonNull(getCacheColumnId, "getCacheColumnId is null"); + this.getCanonicalTableHandle = requireNonNull(getCanonicalTableHandle, "getCanonicalTableHandle is null"); this.getLayoutForTableExecute = requireNonNull(getLayoutForTableExecute, "getLayoutForTableExecute is null"); this.writerScalingOptions = requireNonNull(writerScalingOptions, "writerScalingOptions is null"); this.capabilities = requireNonNull(capabilities, "capabilities is null"); @@ -319,6 +330,9 @@ public Connector create(String catalogName, Map config, Connecto columnProperties, tableFunctionSplitsSources, maxWriterTasks, + getCacheTableId, + getCacheColumnId, + getCanonicalTableHandle, getLayoutForTableExecute, writerScalingOptions, capabilities, @@ -475,6 +489,9 @@ public static final class Builder private BiFunction columnMask = (tableName, columnName) -> null; private boolean allowMissingColumnsOnInsert; private OptionalInt maxWriterTasks = OptionalInt.empty(); + private Function> getCacheTableId = handle -> Optional.empty(); + private Function> getCacheColumnId = handle -> Optional.empty(); + private Function getCanonicalTableHandle = Function.identity(); private BiFunction> getLayoutForTableExecute = (session, handle) -> Optional.empty(); private WriterScalingOptions writerScalingOptions = WriterScalingOptions.DISABLED; private Supplier> capabilities = ImmutableSet::of; @@ -820,6 +837,24 @@ public Builder withMaxWriterTasks(OptionalInt maxWriterTasks) return this; } + public Builder withGetCacheTableId(Function> getCacheTableId) + { + this.getCacheTableId = requireNonNull(getCacheTableId, "getCacheTableId is 
null"); + return this; + } + + public Builder withGetCacheColumnId(Function> getCacheColumnId) + { + this.getCacheColumnId = requireNonNull(getCacheColumnId, "getCacheColumnId is null"); + return this; + } + + public Builder withGetCanonicalTableHandle(Function getCanonicalTableHandle) + { + this.getCanonicalTableHandle = requireNonNull(getCanonicalTableHandle, "getCanonicalTableHandle is null"); + return this; + } + public Builder withAllowMissingColumnsOnInsert(boolean allowMissingColumnsOnInsert) { this.allowMissingColumnsOnInsert = allowMissingColumnsOnInsert; @@ -900,6 +935,9 @@ public MockConnectorFactory build() allowMissingColumnsOnInsert, tableFunctionSplitsSources, maxWriterTasks, + getCacheTableId, + getCacheColumnId, + getCanonicalTableHandle, getLayoutForTableExecute, writerScalingOptions, capabilities, diff --git a/core/trino-main/src/test/java/io/trino/cost/TestChooseAlternativeRule.java b/core/trino-main/src/test/java/io/trino/cost/TestChooseAlternativeRule.java new file mode 100644 index 000000000000..ad2e34ccc7f3 --- /dev/null +++ b/core/trino-main/src/test/java/io/trino/cost/TestChooseAlternativeRule.java @@ -0,0 +1,98 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.cost; + +import io.trino.sql.ir.Comparison; +import io.trino.sql.ir.Constant; +import io.trino.sql.ir.Reference; +import io.trino.sql.planner.Symbol; +import io.trino.sql.planner.plan.ChooseAlternativeNode.FilteredTableScan; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; + +import java.util.List; +import java.util.Optional; + +import static io.trino.spi.type.BigintType.BIGINT; +import static io.trino.sql.ir.Comparison.Operator.EQUAL; +import static java.lang.Double.NaN; +import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS; + +@TestInstance(PER_CLASS) +class TestChooseAlternativeRule + extends BaseStatsCalculatorTest +{ + @Test + public void testStatsForChooseAlternative() + { + tester().assertStatsFor(builder -> builder + .chooseAlternative( + List.of( + builder.filter( + new Comparison(EQUAL, new Reference(BIGINT, "i1"), new Constant(BIGINT, 5L)), + builder.values(builder.symbol("i1"), builder.symbol("i2"))), + builder.filter( + new Comparison(EQUAL, new Reference(BIGINT, "i1"), new Constant(BIGINT, 10L)), + builder.values(builder.symbol("i1"), builder.symbol("i2")))), + new FilteredTableScan(builder.tableScan(List.of(builder.symbol("i1"), builder.symbol("i2")), false), Optional.empty()))) + .withSourceStats(0, PlanNodeStatsEstimate.builder() + .setOutputRowCount(10) + .addSymbolStatistics(new Symbol(BIGINT, "i1"), SymbolStatsEstimate.builder() + .setLowValue(1) + .setHighValue(10) + .setAverageRowSize(NaN) + .setDistinctValuesCount(5) + .setNullsFraction(0) + .build()) + .addSymbolStatistics(new Symbol(BIGINT, "i2"), SymbolStatsEstimate.builder() + .setLowValue(0) + .setHighValue(3) + .setAverageRowSize(25) + .setDistinctValuesCount(4) + .setNullsFraction(0.5) + .build()) + .build()) + .withSourceStats(1, PlanNodeStatsEstimate.builder() + .setOutputRowCount(5) + .addSymbolStatistics(new Symbol(BIGINT, "i1"), SymbolStatsEstimate.builder() + .setLowValue(7) + .setHighValue(9) + 
.setAverageRowSize(3) + .setDistinctValuesCount(NaN) + .setNullsFraction(0.2) + .build()) + .addSymbolStatistics(new Symbol(BIGINT, "i2"), SymbolStatsEstimate.builder() + .setLowValue(-5) + .setHighValue(12) + .setAverageRowSize(NaN) + .setDistinctValuesCount(NaN) + .setNullsFraction(0.1) + .build()) + .build()) + .check(check -> check + .outputRowsCount(10) + .symbolStats("i1", BIGINT, assertion -> assertion + .lowValue(1) + .highValue(10) + .averageRowSize(NaN) + .distinctValuesCount(5) + .nullsFraction(0)) + .symbolStats("i2", BIGINT, assertion -> assertion + .lowValue(0) + .highValue(3) + .averageRowSize(25) + .distinctValuesCount(4) + .nullsFraction(0.5))); + } +} diff --git a/core/trino-main/src/test/java/io/trino/cost/TestCostCalculator.java b/core/trino-main/src/test/java/io/trino/cost/TestCostCalculator.java index 237e26c16ae7..e6fa251d92c8 100644 --- a/core/trino-main/src/test/java/io/trino/cost/TestCostCalculator.java +++ b/core/trino-main/src/test/java/io/trino/cost/TestCostCalculator.java @@ -38,6 +38,7 @@ import io.trino.sql.planner.Symbol; import io.trino.sql.planner.plan.AggregationNode; import io.trino.sql.planner.plan.Assignments; +import io.trino.sql.planner.plan.ChooseAlternativeNode; import io.trino.sql.planner.plan.EnforceSingleRowNode; import io.trino.sql.planner.plan.ExchangeNode; import io.trino.sql.planner.plan.FilterNode; @@ -65,6 +66,7 @@ import static com.google.common.base.Preconditions.checkArgument; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.VarcharType.VARCHAR; +import static io.trino.sql.ir.Booleans.TRUE; import static io.trino.sql.planner.plan.AggregationNode.singleAggregation; import static io.trino.sql.planner.plan.AggregationNode.singleGroupingSet; import static io.trino.sql.planner.plan.ExchangeNode.Scope.LOCAL; @@ -74,6 +76,7 @@ import static io.trino.testing.TestingHandles.TEST_CATALOG_NAME; import static io.trino.testing.TestingSession.testSessionBuilder; import static 
io.trino.testing.TransactionBuilder.transaction; +import static java.lang.Double.NaN; import static java.lang.String.format; import static java.util.Objects.requireNonNull; import static org.assertj.core.api.Assertions.assertThat; @@ -209,6 +212,43 @@ public void testFilter() assertCostHasUnknownComponentsForUnknownStats(filter); } + @Test + public void testChooseAlternative() + { + ChooseAlternativeNode chooseAlternativeNode = new ChooseAlternativeNode( + new PlanNodeId("chooseAlternative"), + List.of( + new FilterNode(new PlanNodeId("alternative1"), tableScan("ts1", new Symbol(VARCHAR, "string")), TRUE), + tableScan("alternative2", new Symbol(VARCHAR, "string"))), + new ChooseAlternativeNode.FilteredTableScan(tableScan("ts_original", new Symbol(VARCHAR, "string")), Optional.empty())); + + Map costs = ImmutableMap.of( + "alternative1", new PlanCostEstimate(1000, 3000, NaN, 0), + "alternative2", new PlanCostEstimate(2000, 1500, 1000, NaN)); + Map stats = ImmutableMap.of( + "chooseAlternative", statsEstimate(chooseAlternativeNode, 5000)); + + assertCost(chooseAlternativeNode, costs, stats) + .cpu(1000) + .memory(3000) + .memoryWhenOutputting(NaN) + .network(0); + + assertCostEstimatedExchanges(chooseAlternativeNode, costs, stats) + .cpu(1000) + .memory(3000) + .memoryWhenOutputting(NaN) + .network(0); + + assertCostFragmentedPlan(chooseAlternativeNode, costs, stats) + .cpu(1000) + .memory(3000) + .memoryWhenOutputting(NaN) + .network(0); + + assertCostHasUnknownComponentsForUnknownStats(chooseAlternativeNode); + } + @Test public void testRepartitionedJoin() { diff --git a/core/trino-main/src/test/java/io/trino/execution/MockRemoteTaskFactory.java b/core/trino-main/src/test/java/io/trino/execution/MockRemoteTaskFactory.java index 0878e53f7bb5..dcaec0de542b 100644 --- a/core/trino-main/src/test/java/io/trino/execution/MockRemoteTaskFactory.java +++ b/core/trino-main/src/test/java/io/trino/execution/MockRemoteTaskFactory.java @@ -30,6 +30,7 @@ import 
io.opentelemetry.api.OpenTelemetry; import io.opentelemetry.api.trace.Span; import io.trino.Session; +import io.trino.cache.SplitAdmissionControllerProvider; import io.trino.cost.StatsAndCosts; import io.trino.exchange.ExchangeManagerRegistry; import io.trino.execution.NodeTaskMap.PartitionedSplitCountTracker; @@ -146,7 +147,8 @@ public MockRemoteTask createTableScanTask(TaskId taskId, InternalNode newNode, L partitionedSplitCountTracker, ImmutableSet.of(), Optional.empty(), - true); + true, + new SplitAdmissionControllerProvider(ImmutableList.of(), TEST_SESSION)); } @Override @@ -162,7 +164,8 @@ public MockRemoteTask createRemoteTask( PartitionedSplitCountTracker partitionedSplitCountTracker, Set outboundDynamicFilterIds, Optional estimatedMemory, - boolean summarizeTaskInfo) + boolean summarizeTaskInfo, + SplitAdmissionControllerProvider splitAdmissionControllerProvider) { return new MockRemoteTask(taskId, fragment, node.getNodeIdentifier(), executor, scheduledExecutor, initialSplits, partitionedSplitCountTracker); } diff --git a/core/trino-main/src/test/java/io/trino/execution/TaskTestUtils.java b/core/trino-main/src/test/java/io/trino/execution/TaskTestUtils.java index d3568e4d5432..365a6b64f3b7 100644 --- a/core/trino-main/src/test/java/io/trino/execution/TaskTestUtils.java +++ b/core/trino-main/src/test/java/io/trino/execution/TaskTestUtils.java @@ -17,8 +17,12 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import io.airlift.configuration.secrets.SecretsResolver; +import io.airlift.json.JsonCodecFactory; import io.airlift.json.ObjectMapperProvider; import io.opentelemetry.api.trace.Span; +import io.trino.cache.CacheConfig; +import io.trino.cache.CacheManagerRegistry; +import io.trino.cache.CacheStats; import io.trino.client.NodeVersion; import io.trino.connector.CatalogServiceProvider; import io.trino.cost.StatsAndCosts; @@ -31,6 +35,8 @@ import io.trino.execution.scheduler.NodeScheduler; import 
io.trino.execution.scheduler.NodeSchedulerConfig; import io.trino.execution.scheduler.UniformNodeSelectorFactory; +import io.trino.memory.LocalMemoryManager; +import io.trino.memory.NodeMemoryConfig; import io.trino.metadata.InMemoryNodeManager; import io.trino.metadata.Split; import io.trino.operator.FlatHashStrategyCompiler; @@ -39,6 +45,7 @@ import io.trino.operator.index.IndexManager; import io.trino.server.protocol.spooling.QueryDataEncoders; import io.trino.server.protocol.spooling.SpoolingEnabledConfig; +import io.trino.spi.block.TestingBlockEncodingSerde; import io.trino.spi.connector.CatalogHandle; import io.trino.spi.predicate.TupleDomain; import io.trino.spiller.GenericSpillerFactory; @@ -162,10 +169,14 @@ public static LocalExecutionPlanner createTestingPlanner() PageFunctionCompiler pageFunctionCompiler = new PageFunctionCompiler(PLANNER_CONTEXT.getFunctionManager(), 0); ColumnarFilterCompiler columnarFilterCompiler = new ColumnarFilterCompiler(PLANNER_CONTEXT.getFunctionManager(), 0); + CacheStats cacheStats = new CacheStats(); return new LocalExecutionPlanner( PLANNER_CONTEXT, Optional.empty(), pageSourceManager, + new CacheManagerRegistry(new CacheConfig(), new LocalMemoryManager(new NodeMemoryConfig()), new TestingBlockEncodingSerde(), cacheStats), + new JsonCodecFactory(new ObjectMapperProvider()).jsonCodec(TupleDomain.class), + cacheStats, new IndexManager(CatalogServiceProvider.fail()), nodePartitioningManager, new PageSinkManager(CatalogServiceProvider.fail()), diff --git a/core/trino-main/src/test/java/io/trino/execution/TestQueryStats.java b/core/trino-main/src/test/java/io/trino/execution/TestQueryStats.java index c389d31bc680..2ba334366693 100644 --- a/core/trino-main/src/test/java/io/trino/execution/TestQueryStats.java +++ b/core/trino-main/src/test/java/io/trino/execution/TestQueryStats.java @@ -45,6 +45,7 @@ public class TestQueryStats new OperatorStats( 10, 11, + 0, 12, new PlanNodeId("13"), TableWriterOperator.class.getSimpleName(), @@ 
-86,6 +87,7 @@ public class TestQueryStats new OperatorStats( 20, 21, + 0, 22, new PlanNodeId("23"), FilterAndProjectOperator.class.getSimpleName(), @@ -127,6 +129,7 @@ public class TestQueryStats new OperatorStats( 30, 31, + 0, 32, new PlanNodeId("33"), TableWriterOperator.class.getSimpleName(), diff --git a/core/trino-main/src/test/java/io/trino/execution/TestSqlStage.java b/core/trino-main/src/test/java/io/trino/execution/TestSqlStage.java index ea4715360693..4819ea1b54b3 100644 --- a/core/trino-main/src/test/java/io/trino/execution/TestSqlStage.java +++ b/core/trino-main/src/test/java/io/trino/execution/TestSqlStage.java @@ -19,6 +19,7 @@ import com.google.common.collect.ImmutableSet; import com.google.common.util.concurrent.SettableFuture; import io.opentelemetry.api.trace.Span; +import io.trino.cache.SplitAdmissionControllerProvider; import io.trino.client.NodeVersion; import io.trino.cost.StatsAndCosts; import io.trino.execution.buffer.PipelinedOutputBuffers; @@ -125,7 +126,8 @@ private void testFinalStageInfoInternal() executor, noopTracer(), Span.getInvalid(), - new SplitSchedulerStats()); + new SplitSchedulerStats(), + new SplitAdmissionControllerProvider(ImmutableList.of(), TEST_SESSION)); // add listener that fetches stage info when the final status is available SettableFuture finalStageInfo = SettableFuture.create(); diff --git a/core/trino-main/src/test/java/io/trino/execution/TestSqlTaskExecution.java b/core/trino-main/src/test/java/io/trino/execution/TestSqlTaskExecution.java index 94b8c0c5dc67..a52f0919d043 100644 --- a/core/trino-main/src/test/java/io/trino/execution/TestSqlTaskExecution.java +++ b/core/trino-main/src/test/java/io/trino/execution/TestSqlTaskExecution.java @@ -42,8 +42,8 @@ import io.trino.memory.context.SimpleLocalMemoryContext; import io.trino.metadata.Split; import io.trino.operator.DriverContext; -import io.trino.operator.DriverFactory; import io.trino.operator.OperatorContext; +import io.trino.operator.OperatorDriverFactory; 
import io.trino.operator.SourceOperator; import io.trino.operator.SourceOperatorFactory; import io.trino.operator.TaskContext; @@ -129,7 +129,7 @@ public void testSimple() Function.identity(), new PagesSerdeFactory(new TestingBlockEncodingSerde(), NONE)); LocalExecutionPlan localExecutionPlan = new LocalExecutionPlan( - ImmutableList.of(new DriverFactory( + ImmutableList.of(new OperatorDriverFactory( 0, true, true, diff --git a/core/trino-main/src/test/java/io/trino/execution/TestStageStateMachine.java b/core/trino-main/src/test/java/io/trino/execution/TestStageStateMachine.java index b0f6114c08e4..da5d61ff5632 100644 --- a/core/trino-main/src/test/java/io/trino/execution/TestStageStateMachine.java +++ b/core/trino-main/src/test/java/io/trino/execution/TestStageStateMachine.java @@ -22,6 +22,8 @@ import io.opentelemetry.api.trace.Span; import io.trino.cost.StatsAndCosts; import io.trino.execution.scheduler.SplitSchedulerStats; +import io.trino.operator.DriverContext; +import io.trino.operator.OperatorStats; import io.trino.operator.PipelineContext; import io.trino.operator.TaskStats; import io.trino.operator.TestingOperatorContext; @@ -46,7 +48,7 @@ import java.sql.SQLException; import java.util.List; import java.util.Optional; -import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledThreadPoolExecutor; @@ -59,7 +61,6 @@ import static io.trino.spi.type.VarcharType.VARCHAR; import static io.trino.sql.planner.SystemPartitioningHandle.SINGLE_DISTRIBUTION; import static io.trino.sql.planner.SystemPartitioningHandle.SOURCE_DISTRIBUTION; -import static java.util.concurrent.Executors.newCachedThreadPool; import static java.util.concurrent.TimeUnit.MILLISECONDS; import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS; @@ -78,7 +79,7 @@ public class TestStageStateMachine 
FAILED_CAUSE.setStackTrace(new StackTraceElement[0]); } - private ExecutorService executor = newCachedThreadPool(daemonThreadsNamed(getClass().getSimpleName() + "-%s")); + private ScheduledExecutorService executor = Executors.newScheduledThreadPool(0, daemonThreadsNamed(getClass().getSimpleName() + "-%s")); @AfterAll public void tearDown() @@ -267,6 +268,89 @@ public void testGetBasicStageInfo() assertThat(stats.getSpilledDataSize()).isEqualTo(succinctBytes(0)); } + @Test + public void testAlternativeOperatorsNotMerged() + { + StageStateMachine stateMachine = createStageStateMachine(); + PipelineContext pipeline0Context = TestingOperatorContext.createDriverContext(executor).getPipelineContext(); + DriverContext alternative0DriverContext = pipeline0Context.addDriverContext(); + alternative0DriverContext.setAlternativeId(0); + alternative0DriverContext.addOperatorContext(0, new PlanNodeId("0"), "operator"); + DriverContext alternative1DriverContext = pipeline0Context.addDriverContext(); + alternative1DriverContext.setAlternativeId(1); + alternative1DriverContext.addOperatorContext(0, new PlanNodeId("0"), "operator"); + pipeline0Context.driverFinished(alternative0DriverContext); + pipeline0Context.driverFinished(alternative1DriverContext); + + PipelineContext pipeline1Context = TestingOperatorContext.createDriverContext(executor).getPipelineContext(); + DriverContext alternative10DriverContext = pipeline1Context.addDriverContext(); + alternative10DriverContext.setAlternativeId(0); + alternative10DriverContext.addOperatorContext(0, new PlanNodeId("0"), "operator"); + DriverContext alternative11DriverContext = pipeline1Context.addDriverContext(); + alternative11DriverContext.setAlternativeId(1); + alternative11DriverContext.addOperatorContext(0, new PlanNodeId("0"), "operator"); + pipeline1Context.driverFinished(alternative10DriverContext); + pipeline1Context.driverFinished(alternative11DriverContext); + + StageId stageId = new StageId(new QueryId("0"), 0); + List 
taskInfoList = ImmutableList.of( + TaskInfo.createInitialTask( + new TaskId(stageId, 0, 0), + URI.create(""), + "0", + false, + Optional.empty(), + taskStats(ImmutableList.of(pipeline0Context, pipeline1Context)))); + StageInfo stageInfo = stateMachine.getStageInfo(() -> taskInfoList); + + List operatorSummaries = stageInfo.getStageStats().getOperatorSummaries(); + assertThat(operatorSummaries).hasSize(2); + assertThat(operatorSummaries.get(0).getOperatorId()).isEqualTo(0); + assertThat(operatorSummaries.get(1).getOperatorId()).isEqualTo(0); + assertThat(operatorSummaries.get(0).getAlternativeId()).isNotEqualTo(operatorSummaries.get(1).getAlternativeId()); + } + + @Test + public void testOperatorsMerged() + { + StageStateMachine stateMachine = createStageStateMachine(); + PipelineContext pipeline0Context = TestingOperatorContext.createDriverContext(executor).getPipelineContext(); + DriverContext alternative0DriverContext = pipeline0Context.addDriverContext(); + alternative0DriverContext.setAlternativeId(2); + alternative0DriverContext.addOperatorContext(0, new PlanNodeId("0"), "operator"); + DriverContext alternative1DriverContext = pipeline0Context.addDriverContext(); + alternative1DriverContext.setAlternativeId(2); + alternative1DriverContext.addOperatorContext(0, new PlanNodeId("0"), "operator"); + pipeline0Context.driverFinished(alternative0DriverContext); + pipeline0Context.driverFinished(alternative1DriverContext); + + PipelineContext pipeline1Context = TestingOperatorContext.createDriverContext(executor).getPipelineContext(); + DriverContext alternative10DriverContext = pipeline1Context.addDriverContext(); + alternative10DriverContext.setAlternativeId(2); + alternative10DriverContext.addOperatorContext(0, new PlanNodeId("0"), "operator"); + DriverContext alternative11DriverContext = pipeline1Context.addDriverContext(); + alternative11DriverContext.setAlternativeId(2); + alternative11DriverContext.addOperatorContext(0, new PlanNodeId("0"), "operator"); + 
pipeline1Context.driverFinished(alternative10DriverContext); + pipeline1Context.driverFinished(alternative11DriverContext); + + StageId stageId = new StageId(new QueryId("0"), 0); + List taskInfoList = ImmutableList.of( + TaskInfo.createInitialTask( + new TaskId(stageId, 0, 0), + URI.create(""), + "0", + false, + Optional.empty(), + taskStats(ImmutableList.of(pipeline0Context, pipeline1Context)))); + StageInfo stageInfo = stateMachine.getStageInfo(() -> taskInfoList); + + List operatorSummaries = stageInfo.getStageStats().getOperatorSummaries(); + assertThat(operatorSummaries).hasSize(1); + assertThat(operatorSummaries.get(0).getOperatorId()).isEqualTo(0); + assertThat(operatorSummaries.get(0).getAlternativeId()).isEqualTo(2); + } + private static TaskStats taskStats(List pipelineContexts) { return taskStats(pipelineContexts, 0); diff --git a/core/trino-main/src/test/java/io/trino/execution/TestTaskExecutorStuckSplits.java b/core/trino-main/src/test/java/io/trino/execution/TestTaskExecutorStuckSplits.java index 8f207aa261eb..0db7b18af2df 100644 --- a/core/trino-main/src/test/java/io/trino/execution/TestTaskExecutorStuckSplits.java +++ b/core/trino-main/src/test/java/io/trino/execution/TestTaskExecutorStuckSplits.java @@ -36,6 +36,7 @@ import io.trino.memory.LocalMemoryManager; import io.trino.memory.NodeMemoryConfig; import io.trino.metadata.WorkerLanguageFunctionProvider; +import io.trino.spi.cache.CacheSplitId; import io.trino.spi.catalog.CatalogProperties; import io.trino.spi.connector.CatalogHandle; import io.trino.spiller.LocalSpillManager; @@ -44,6 +45,7 @@ import org.junit.jupiter.api.Test; import java.util.List; +import java.util.Optional; import java.util.OptionalInt; import java.util.Set; import java.util.concurrent.ExecutionException; @@ -248,6 +250,12 @@ public ListenableFuture processFor(Duration duration) return immediateVoidFuture(); } + @Override + public Optional getCacheSplitId() + { + return Optional.empty(); + } + @Override public String 
getInfo() { diff --git a/core/trino-main/src/test/java/io/trino/execution/TestingRemoteTaskFactory.java b/core/trino-main/src/test/java/io/trino/execution/TestingRemoteTaskFactory.java index 9f7e71fe9228..0b7d3bd85083 100644 --- a/core/trino-main/src/test/java/io/trino/execution/TestingRemoteTaskFactory.java +++ b/core/trino-main/src/test/java/io/trino/execution/TestingRemoteTaskFactory.java @@ -26,6 +26,7 @@ import io.airlift.units.Duration; import io.opentelemetry.api.trace.Span; import io.trino.Session; +import io.trino.cache.SplitAdmissionControllerProvider; import io.trino.execution.NodeTaskMap.PartitionedSplitCountTracker; import io.trino.execution.StateMachine.StateChangeListener; import io.trino.execution.buffer.BufferState; @@ -81,7 +82,8 @@ public synchronized RemoteTask createRemoteTask( PartitionedSplitCountTracker partitionedSplitCountTracker, Set outboundDynamicFilterIds, Optional estimatedMemory, - boolean summarizeTaskInfo) + boolean summarizeTaskInfo, + SplitAdmissionControllerProvider splitAdmissionControllerProvider) { TestingRemoteTask task = new TestingRemoteTask(taskId, node.getNodeIdentifier(), fragment); task.addSplits(initialSplits); diff --git a/core/trino-main/src/test/java/io/trino/execution/executor/dedicated/TestThreadPerDriverTaskExecutor.java b/core/trino-main/src/test/java/io/trino/execution/executor/dedicated/TestThreadPerDriverTaskExecutor.java index 5eded88984e8..f1219e3bc3f9 100644 --- a/core/trino-main/src/test/java/io/trino/execution/executor/dedicated/TestThreadPerDriverTaskExecutor.java +++ b/core/trino-main/src/test/java/io/trino/execution/executor/dedicated/TestThreadPerDriverTaskExecutor.java @@ -26,10 +26,12 @@ import io.trino.execution.TaskManagerConfig; import io.trino.execution.executor.TaskHandle; import io.trino.execution.executor.scheduler.FairScheduler; +import io.trino.spi.cache.CacheSplitId; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Timeout; import java.util.List; +import 
java.util.Optional; import java.util.OptionalInt; import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutionException; @@ -237,6 +239,12 @@ public final ListenableFuture processFor(Duration duration) return blocked; } + @Override + public Optional getCacheSplitId() + { + return Optional.empty(); + } + @Override public final String getInfo() { diff --git a/core/trino-main/src/test/java/io/trino/execution/executor/timesharing/SimulationSplit.java b/core/trino-main/src/test/java/io/trino/execution/executor/timesharing/SimulationSplit.java index 47b2870ffd57..3b80bfb44cd8 100644 --- a/core/trino-main/src/test/java/io/trino/execution/executor/timesharing/SimulationSplit.java +++ b/core/trino-main/src/test/java/io/trino/execution/executor/timesharing/SimulationSplit.java @@ -18,7 +18,9 @@ import io.airlift.units.Duration; import io.opentelemetry.api.trace.Span; import io.trino.execution.SplitRunner; +import io.trino.spi.cache.CacheSplitId; +import java.util.Optional; import java.util.concurrent.RejectedExecutionException; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.atomic.AtomicBoolean; @@ -149,6 +151,12 @@ public ListenableFuture processFor(Duration duration) return processResult; } + @Override + public Optional getCacheSplitId() + { + return Optional.empty(); + } + static class LeafSplit extends SimulationSplit { diff --git a/core/trino-main/src/test/java/io/trino/execution/executor/timesharing/TestTimeSharingTaskExecutor.java b/core/trino-main/src/test/java/io/trino/execution/executor/timesharing/TestTimeSharingTaskExecutor.java index 3242857199ac..e602d8d52f3c 100644 --- a/core/trino-main/src/test/java/io/trino/execution/executor/timesharing/TestTimeSharingTaskExecutor.java +++ b/core/trino-main/src/test/java/io/trino/execution/executor/timesharing/TestTimeSharingTaskExecutor.java @@ -26,12 +26,14 @@ import io.trino.execution.executor.TaskExecutor; import io.trino.execution.executor.TaskHandle; import 
io.trino.spi.QueryId; +import io.trino.spi.cache.CacheSplitId; import org.junit.jupiter.api.RepeatedTest; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Timeout; import java.util.Arrays; import java.util.List; +import java.util.Optional; import java.util.OptionalInt; import java.util.concurrent.Future; import java.util.concurrent.Phaser; @@ -44,6 +46,8 @@ import static io.trino.execution.executor.timesharing.MultilevelSplitQueue.LEVEL_CONTRIBUTION_CAP; import static io.trino.execution.executor.timesharing.MultilevelSplitQueue.LEVEL_THRESHOLD_SECONDS; import static java.lang.Double.isNaN; +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; import static java.util.concurrent.TimeUnit.MILLISECONDS; import static java.util.concurrent.TimeUnit.MINUTES; import static java.util.concurrent.TimeUnit.SECONDS; @@ -72,9 +76,9 @@ public void testTasksComplete() verificationComplete.register(); // add two jobs - TestingJob driver1 = new TestingJob(ticker, new Phaser(), beginPhase, verificationComplete, 10, 0); + TestingJob driver1 = new TestingJob(ticker, new Phaser(), beginPhase, verificationComplete, 10, 0, Optional.empty()); ListenableFuture future1 = getOnlyElement(taskExecutor.enqueueSplits(taskHandle, true, ImmutableList.of(driver1))); - TestingJob driver2 = new TestingJob(ticker, new Phaser(), beginPhase, verificationComplete, 10, 0); + TestingJob driver2 = new TestingJob(ticker, new Phaser(), beginPhase, verificationComplete, 10, 0, Optional.empty()); ListenableFuture future2 = getOnlyElement(taskExecutor.enqueueSplits(taskHandle, true, ImmutableList.of(driver2))); assertThat(driver1.getCompletedPhases()).isEqualTo(0); assertThat(driver2.getCompletedPhases()).isEqualTo(0); @@ -100,7 +104,7 @@ public void testTasksComplete() verificationComplete.arriveAndAwaitAdvance(); // add one more job - TestingJob driver3 = new TestingJob(ticker, new Phaser(), beginPhase, verificationComplete, 10, 0); + TestingJob driver3 = new 
TestingJob(ticker, new Phaser(), beginPhase, verificationComplete, 10, 0, Optional.empty()); ListenableFuture future3 = getOnlyElement(taskExecutor.enqueueSplits(taskHandle, false, ImmutableList.of(driver3))); // advance one phase and verify @@ -166,8 +170,8 @@ public void testQuantaFairness() Phaser endQuantaPhaser = new Phaser(); - TestingJob shortQuantaDriver = new TestingJob(ticker, new Phaser(), new Phaser(), endQuantaPhaser, 10, 10); - TestingJob longQuantaDriver = new TestingJob(ticker, new Phaser(), new Phaser(), endQuantaPhaser, 10, 20); + TestingJob shortQuantaDriver = new TestingJob(ticker, new Phaser(), new Phaser(), endQuantaPhaser, 10, 10, Optional.empty()); + TestingJob longQuantaDriver = new TestingJob(ticker, new Phaser(), new Phaser(), endQuantaPhaser, 10, 20, Optional.empty()); taskExecutor.enqueueSplits(shortQuantaTaskHandle, true, ImmutableList.of(shortQuantaDriver)); taskExecutor.enqueueSplits(longQuantaTaskHandle, true, ImmutableList.of(longQuantaDriver)); @@ -203,8 +207,8 @@ public void testLevelMovement() int quantaTimeMills = 500; int phasesPerSecond = 1000 / quantaTimeMills; int totalPhases = LEVEL_THRESHOLD_SECONDS[LEVEL_THRESHOLD_SECONDS.length - 1] * phasesPerSecond; - TestingJob driver1 = new TestingJob(ticker, globalPhaser, new Phaser(), new Phaser(), totalPhases, quantaTimeMills); - TestingJob driver2 = new TestingJob(ticker, globalPhaser, new Phaser(), new Phaser(), totalPhases, quantaTimeMills); + TestingJob driver1 = new TestingJob(ticker, globalPhaser, new Phaser(), new Phaser(), totalPhases, quantaTimeMills, Optional.empty()); + TestingJob driver2 = new TestingJob(ticker, globalPhaser, new Phaser(), new Phaser(), totalPhases, quantaTimeMills, Optional.empty()); taskExecutor.enqueueSplits(testTaskHandle, true, ImmutableList.of(driver1, driver2)); @@ -242,18 +246,18 @@ public void testLevelMultipliers() }; // move task 0 to next level - TestingJob task0Job = new TestingJob(ticker, new Phaser(), new Phaser(), new Phaser(), 1, 
LEVEL_THRESHOLD_SECONDS[i + 1] * 1000); + TestingJob task0Job = new TestingJob(ticker, new Phaser(), new Phaser(), new Phaser(), 1, LEVEL_THRESHOLD_SECONDS[i + 1] * 1000, Optional.empty()); taskExecutor.enqueueSplits( taskHandles[0], true, ImmutableList.of(task0Job)); // move tasks 1 and 2 to this level - TestingJob task1Job = new TestingJob(ticker, new Phaser(), new Phaser(), new Phaser(), 1, LEVEL_THRESHOLD_SECONDS[i] * 1000); + TestingJob task1Job = new TestingJob(ticker, new Phaser(), new Phaser(), new Phaser(), 1, LEVEL_THRESHOLD_SECONDS[i] * 1000, Optional.empty()); taskExecutor.enqueueSplits( taskHandles[1], true, ImmutableList.of(task1Job)); - TestingJob task2Job = new TestingJob(ticker, new Phaser(), new Phaser(), new Phaser(), 1, LEVEL_THRESHOLD_SECONDS[i] * 1000); + TestingJob task2Job = new TestingJob(ticker, new Phaser(), new Phaser(), new Phaser(), 1, LEVEL_THRESHOLD_SECONDS[i] * 1000, Optional.empty()); taskExecutor.enqueueSplits( taskHandles[2], true, @@ -268,7 +272,7 @@ public void testLevelMultipliers() int phasesForNextLevel = LEVEL_THRESHOLD_SECONDS[i + 1] - LEVEL_THRESHOLD_SECONDS[i]; TestingJob[] drivers = new TestingJob[6]; for (int j = 0; j < 6; j++) { - drivers[j] = new TestingJob(ticker, globalPhaser, new Phaser(), new Phaser(), phasesForNextLevel, 1000); + drivers[j] = new TestingJob(ticker, globalPhaser, new Phaser(), new Phaser(), phasesForNextLevel, 1000, Optional.empty()); } taskExecutor.enqueueSplits(taskHandles[0], true, ImmutableList.of(drivers[0], drivers[1])); @@ -319,8 +323,8 @@ public void testTaskHandle() Phaser verificationComplete = new Phaser(); verificationComplete.register(); - TestingJob driver1 = new TestingJob(ticker, new Phaser(), beginPhase, verificationComplete, 10, 0); - TestingJob driver2 = new TestingJob(ticker, new Phaser(), beginPhase, verificationComplete, 10, 0); + TestingJob driver1 = new TestingJob(ticker, new Phaser(), beginPhase, verificationComplete, 10, 0, Optional.empty()); + TestingJob driver2 = new 
TestingJob(ticker, new Phaser(), beginPhase, verificationComplete, 10, 0, Optional.empty()); // force enqueue a split taskExecutor.enqueueSplits(taskHandle, true, ImmutableList.of(driver1)); @@ -396,8 +400,8 @@ public void testMinMaxDriversPerTask() for (int batch = 0; batch < batchCount; batch++) { phasers[batch] = new Phaser(); phasers[batch].register(); - TestingJob split1 = new TestingJob(ticker, new Phaser(), new Phaser(), phasers[batch], 1, 0); - TestingJob split2 = new TestingJob(ticker, new Phaser(), new Phaser(), phasers[batch], 1, 0); + TestingJob split1 = new TestingJob(ticker, new Phaser(), new Phaser(), phasers[batch], 1, 0, Optional.empty()); + TestingJob split2 = new TestingJob(ticker, new Phaser(), new Phaser(), phasers[batch], 1, 0, Optional.empty()); splits[2 * batch] = split1; splits[2 * batch + 1] = split2; taskExecutor.enqueueSplits(testTaskHandle, false, ImmutableList.of(split1, split2)); @@ -439,7 +443,7 @@ public void testUserSpecifiedMaxDriversPerTask() for (int batch = 0; batch < batchCount; batch++) { phasers[batch] = new Phaser(); phasers[batch].register(); - TestingJob split = new TestingJob(ticker, new Phaser(), new Phaser(), phasers[batch], 1, 0); + TestingJob split = new TestingJob(ticker, new Phaser(), new Phaser(), phasers[batch], 1, 0, Optional.empty()); splits[batch] = split; taskExecutor.enqueueSplits(testTaskHandle, false, ImmutableList.of(split)); } @@ -484,7 +488,7 @@ public void testMinDriversPerTaskWhenTargetConcurrencyIncreases() for (int batch = 0; batch < batchCount; batch++) { phasers[batch] = new Phaser(); phasers[batch].register(); - TestingJob split = new TestingJob(ticker, new Phaser(), new Phaser(), phasers[batch], 1, 0); + TestingJob split = new TestingJob(ticker, new Phaser(), new Phaser(), phasers[batch], 1, 0, Optional.empty()); splits[batch] = split; } @@ -513,8 +517,8 @@ public void testLeafSplitsSize() TimeSharingTaskExecutor taskExecutor = new TimeSharingTaskExecutor(4, 1, 2, 2, splitQueue, ticker); 
TaskHandle testTaskHandle = taskExecutor.addTask(new TaskId(new StageId("test", 0), 0, 0), () -> 0, 10, new Duration(1, MILLISECONDS), OptionalInt.empty()); - TestingJob driver1 = new TestingJob(ticker, new Phaser(), new Phaser(), new Phaser(), 1, 500); - TestingJob driver2 = new TestingJob(ticker, new Phaser(), new Phaser(), new Phaser(), 1, 1000 / 500); + TestingJob driver1 = new TestingJob(ticker, new Phaser(), new Phaser(), new Phaser(), 1, 500, Optional.empty()); + TestingJob driver2 = new TestingJob(ticker, new Phaser(), new Phaser(), new Phaser(), 1, 1000 / 500, Optional.empty()); ticker.increment(0, TimeUnit.SECONDS); taskExecutor.enqueueSplits(testTaskHandle, false, ImmutableList.of(driver1, driver2)); @@ -529,6 +533,103 @@ public void testLeafSplitsSize() assertThat(taskExecutor.getLeafSplitsSize().getAllTime().getMax()).isEqualTo(2.0); } + @Test + public void testSplitsWithSameCacheSplitId() + { + MultilevelSplitQueue splitQueue = new MultilevelSplitQueue(2); + TestingTicker ticker = new TestingTicker(); + TimeSharingTaskExecutor taskExecutor = new TimeSharingTaskExecutor(6, 6, 1, 4, splitQueue, ticker); + taskExecutor.start(); + + try { + TaskHandle taskA = taskExecutor.addTask(new TaskId(new StageId("test1", 0), 0, 0), () -> 0, 10, new Duration(1, MILLISECONDS), OptionalInt.empty()); + TestingJob[] taskADrivers = new TestingJob[6]; + Phaser[] taskAPhasers = new Phaser[6]; + for (int i = 0; i < 6; i++) { + taskAPhasers[i] = new Phaser(); + taskAPhasers[i].register(); + // Create a driver with a unique cache split id + TestingJob driver = new TestingJob(ticker, new Phaser(), new Phaser(), taskAPhasers[i], 1, 1000, Optional.of(new CacheSplitId(format("cache-split-id-%d", i)))); + taskADrivers[i] = driver; + } + + TaskHandle taskB = taskExecutor.addTask(new TaskId(new StageId("test2", 0), 0, 0), () -> 0, 10, new Duration(1, MILLISECONDS), OptionalInt.empty()); + TestingJob[] taskBDrivers = new TestingJob[6]; + Phaser[] taskBPhasers = new Phaser[6]; + for 
(int i = 0; i < 6; i++) { + taskBPhasers[i] = new Phaser(); + taskBPhasers[i].register(); + // Create a driver with a unique cache split id + TestingJob driver = new TestingJob(ticker, new Phaser(), new Phaser(), taskBPhasers[i], 1, 1000, Optional.of(new CacheSplitId(format("cache-split-id-%d", i)))); + taskBDrivers[i] = driver; + } + + taskExecutor.enqueueSplits(taskA, false, ImmutableList.copyOf(taskADrivers)); + taskExecutor.enqueueSplits(taskB, false, ImmutableList.copyOf(taskBDrivers)); + + waitUntilSplitsStart(ImmutableList.of( + taskADrivers[0], + taskADrivers[1], + taskADrivers[2], + taskADrivers[3], + // Since we are below the min drivers limit, we will start the first 2 splits of taskB even + // though the splits with same cache split id from taskA are not finished yet. This is to + // ensure progress is made on all tasks + taskBDrivers[0], + taskBDrivers[1])); + // Verify that the first 4 splits are running and the last 2 splits are not running from taskA + assertSplitStates(3, taskADrivers); + // Verify that the first 2 splits are running and the last 4 splits are not running from taskB + assertSplitStates(1, taskBDrivers); + + // Finish the first two taskB splits + for (int i = 0; i < 2; i++) { + taskBPhasers[i].arriveAndDeregister(); + } + + // Now, splits with cache-split-id-4 and cache-split-id-5 from taskB should start instead of splits + // with cache-split-id-2 and cache-split-id-3 since they are already running from taskA + waitUntilSplitsStart(ImmutableList.of(taskBDrivers[4], taskBDrivers[5])); + // Verify that the first 4 splits are running and the last 2 splits are not running from taskA + assertSplitStates(3, taskADrivers); + // Verify that the last 2 splits are running and the middle 2 splits are not running from taskB + for (int i = 4; i < 6; i++) { + assertThat(taskBDrivers[i].isStarted()).isTrue(); + } + for (int i = 2; i < 4; i++) { + assertThat(taskBDrivers[i].isStarted()).isFalse(); + } + + // Finish two splits from taskA + for (int i = 0; i < 2; i++)
{ + taskAPhasers[i].arriveAndDeregister(); + } + waitUntilSplitsStart(ImmutableList.of(taskADrivers[4], taskADrivers[5])); + // Verify that all splits from taskA are running + assertSplitStates(5, taskADrivers); + // Verify that the middle 2 splits are still not running from taskB + for (int i = 2; i < 4; i++) { + assertThat(taskBDrivers[i].isStarted()).isFalse(); + } + + // Finish the remaining splits from taskA + for (int i = 2; i < 6; i++) { + taskAPhasers[i].arriveAndDeregister(); + } + waitUntilSplitsStart(ImmutableList.of(taskBDrivers[2], taskBDrivers[3])); + // Verify that all splits from taskB are running + assertSplitStates(5, taskBDrivers); + + // Finish the remaining splits from taskB + for (int i = 2; i < 6; i++) { + taskBPhasers[i].arriveAndDeregister(); + } + } + finally { + taskExecutor.stop(); + } + } + private void assertSplitStates(int endIndex, TestingJob[] splits) { // assert that splits up to and including endIndex are all started @@ -564,6 +665,7 @@ private static class TestingJob private final Phaser endQuantaPhaser; private final int requiredPhases; private final int quantaTimeMillis; + private final Optional cacheSplitId; private final AtomicInteger completedPhases = new AtomicInteger(); private final AtomicInteger firstPhase = new AtomicInteger(-1); @@ -572,7 +674,14 @@ private static class TestingJob private final AtomicBoolean started = new AtomicBoolean(); private final SettableFuture completed = SettableFuture.create(); - public TestingJob(TestingTicker ticker, Phaser globalPhaser, Phaser beginQuantaPhaser, Phaser endQuantaPhaser, int requiredPhases, int quantaTimeMillis) + public TestingJob( + TestingTicker ticker, + Phaser globalPhaser, + Phaser beginQuantaPhaser, + Phaser endQuantaPhaser, + int requiredPhases, + int quantaTimeMillis, + Optional cacheSplitId) { this.ticker = ticker; this.globalPhaser = globalPhaser; @@ -580,6 +689,7 @@ public TestingJob(TestingTicker ticker, Phaser globalPhaser, Phaser beginQuantaP 
this.endQuantaPhaser = endQuantaPhaser; this.requiredPhases = requiredPhases; this.quantaTimeMillis = quantaTimeMillis; + this.cacheSplitId = requireNonNull(cacheSplitId, "cacheSplitId is null"); beginQuantaPhaser.register(); endQuantaPhaser.register(); @@ -624,6 +734,12 @@ public ListenableFuture processFor(Duration duration) return immediateVoidFuture(); } + @Override + public Optional getCacheSplitId() + { + return cacheSplitId; + } + @Override public String getInfo() { diff --git a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestMultiSourcePartitionedScheduler.java b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestMultiSourcePartitionedScheduler.java index e651317bab6c..791bbf07569b 100644 --- a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestMultiSourcePartitionedScheduler.java +++ b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestMultiSourcePartitionedScheduler.java @@ -19,6 +19,7 @@ import io.airlift.units.Duration; import io.opentelemetry.api.trace.Span; import io.trino.Session; +import io.trino.cache.SplitAdmissionControllerProvider; import io.trino.client.NodeVersion; import io.trino.cost.StatsAndCosts; import io.trino.execution.DynamicFilterConfig; @@ -629,7 +630,8 @@ TABLE_SCAN_2_NODE_ID, new TableInfo(Optional.of("test"), new QualifiedObjectName queryExecutor, noopTracer(), Span.getInvalid(), - new SplitSchedulerStats()); + new SplitSchedulerStats(), + new SplitAdmissionControllerProvider(ImmutableList.of(), TEST_SESSION)); ImmutableMap.Builder outputBuffers = ImmutableMap.builder(); outputBuffers.put(fragment.getId(), new PartitionedPipelinedOutputBufferManager(FIXED_HASH_DISTRIBUTION, 1)); fragment.getRemoteSourceNodes().stream() diff --git a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestSourcePartitionedScheduler.java b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestSourcePartitionedScheduler.java index df69d4813ed1..688b002df8ac 100644 --- 
a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestSourcePartitionedScheduler.java +++ b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestSourcePartitionedScheduler.java @@ -19,6 +19,7 @@ import io.airlift.units.Duration; import io.opentelemetry.api.trace.Span; import io.trino.Session; +import io.trino.cache.SplitAdmissionControllerProvider; import io.trino.client.NodeVersion; import io.trino.cost.StatsAndCosts; import io.trino.execution.DynamicFilterConfig; @@ -765,7 +766,8 @@ private StageExecution createStageExecution(PlanFragment fragment, NodeTaskMap n queryExecutor, noopTracer(), Span.getInvalid(), - new SplitSchedulerStats()); + new SplitSchedulerStats(), + new SplitAdmissionControllerProvider(ImmutableList.of(), TEST_SESSION)); ImmutableMap.Builder outputBuffers = ImmutableMap.builder(); outputBuffers.put(fragment.getId(), new PartitionedPipelinedOutputBufferManager(FIXED_HASH_DISTRIBUTION, 1)); fragment.getRemoteSourceNodes().stream() diff --git a/core/trino-main/src/test/java/io/trino/operator/TestOperatorStats.java b/core/trino-main/src/test/java/io/trino/operator/TestOperatorStats.java index d16cee4e9c96..ec472d98aaa0 100644 --- a/core/trino-main/src/test/java/io/trino/operator/TestOperatorStats.java +++ b/core/trino-main/src/test/java/io/trino/operator/TestOperatorStats.java @@ -39,6 +39,7 @@ public class TestOperatorStats public static final OperatorStats EXPECTED = new OperatorStats( 0, 1, + 51, 41, new PlanNodeId("test"), "test", @@ -88,6 +89,7 @@ public class TestOperatorStats public static final OperatorStats MERGEABLE = new OperatorStats( 0, 1, + 0, 41, new PlanNodeId("test"), "test", diff --git a/core/trino-main/src/test/java/io/trino/server/TestDynamicFilterService.java b/core/trino-main/src/test/java/io/trino/server/TestDynamicFilterService.java index 0cf131ef186d..42dd583def1b 100644 --- a/core/trino-main/src/test/java/io/trino/server/TestDynamicFilterService.java +++ 
b/core/trino-main/src/test/java/io/trino/server/TestDynamicFilterService.java @@ -18,12 +18,16 @@ import com.google.common.collect.ImmutableSet; import io.airlift.units.DataSize; import io.trino.Session; +import io.trino.cache.CommonPlanAdaptation; import io.trino.cost.StatsAndCosts; import io.trino.execution.DynamicFilterConfig; import io.trino.execution.StageId; import io.trino.execution.TaskId; +import io.trino.metadata.Metadata; import io.trino.operator.RetryPolicy; import io.trino.spi.QueryId; +import io.trino.spi.cache.PlanSignature; +import io.trino.spi.cache.SignatureKey; import io.trino.spi.connector.ColumnHandle; import io.trino.spi.connector.DynamicFilter; import io.trino.spi.connector.TestingColumnHandle; @@ -33,18 +37,22 @@ import io.trino.sql.DynamicFilters; import io.trino.sql.ir.Cast; import io.trino.sql.ir.Expression; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.Partitioning; import io.trino.sql.planner.PartitioningHandle; import io.trino.sql.planner.PartitioningScheme; import io.trino.sql.planner.PlanFragment; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.SymbolAllocator; +import io.trino.sql.planner.plan.Assignments; import io.trino.sql.planner.plan.DynamicFilterId; import io.trino.sql.planner.plan.ExchangeNode; import io.trino.sql.planner.plan.FilterNode; import io.trino.sql.planner.plan.JoinNode; +import io.trino.sql.planner.plan.LoadCachedDataPlanNode; import io.trino.sql.planner.plan.PlanFragmentId; import io.trino.sql.planner.plan.PlanNodeId; +import io.trino.sql.planner.plan.ProjectNode; import io.trino.sql.planner.plan.RemoteSourceNode; import io.trino.sql.planner.plan.TableScanNode; import io.trino.testing.TestingMetadata; @@ -68,6 +76,7 @@ import static io.trino.metadata.TestMetadataManager.createTestMetadataManager; import static io.trino.server.DynamicFilterService.DynamicFilterDomainStats; import static io.trino.server.DynamicFilterService.DynamicFiltersStats; +import static 
io.trino.server.DynamicFilterService.getCacheDynamicFilters; import static io.trino.server.DynamicFilterService.getOutboundDynamicFilters; import static io.trino.server.DynamicFilterService.getSourceStageInnerLazyDynamicFilters; import static io.trino.spi.predicate.Domain.multipleValues; @@ -78,6 +87,8 @@ import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.spi.type.VarcharType.VARCHAR; import static io.trino.sql.DynamicFilters.createDynamicFilterExpression; +import static io.trino.sql.ir.IrUtils.and; +import static io.trino.sql.ir.IrUtils.or; import static io.trino.sql.planner.SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION; import static io.trino.sql.planner.SystemPartitioningHandle.SINGLE_DISTRIBUTION; import static io.trino.sql.planner.SystemPartitioningHandle.SOURCE_DISTRIBUTION; @@ -1003,6 +1014,26 @@ public void testMultipleTaskAttempts() getSimplifiedDomainString(1L, 6L, 3, INTEGER)))); } + @Test + public void testCacheDynamicFilters() + { + Metadata metadata = createTestMetadataManager(); + assertThat(getCacheDynamicFilters( + new ProjectNode(new PlanNodeId("0"), + new LoadCachedDataPlanNode( + new PlanNodeId("1"), + new CommonPlanAdaptation.PlanSignatureWithPredicate( + new PlanSignature(new SignatureKey("test"), Optional.empty(), ImmutableList.of(), ImmutableList.of()), + TupleDomain.all()), + or(and(createDynamicFilterExpression(metadata, new DynamicFilterId("0"), BIGINT, new Reference(BIGINT, "symbol0")), + createDynamicFilterExpression(metadata, new DynamicFilterId("1"), BIGINT, new Reference(BIGINT, "symbol1"))), + createDynamicFilterExpression(metadata, new DynamicFilterId("2"), BIGINT, new Reference(BIGINT, "symbol2"))), + ImmutableMap.of(), + ImmutableList.of()), + Assignments.of()))) + .isEqualTo(ImmutableSet.of(new DynamicFilterId("0"), new DynamicFilterId("1"), new DynamicFilterId("2"))); + } + private static DynamicFilterService createDynamicFilterService() { return new DynamicFilterService( diff --git 
a/core/trino-main/src/test/java/io/trino/server/remotetask/TestHttpRemoteTask.java b/core/trino-main/src/test/java/io/trino/server/remotetask/TestHttpRemoteTask.java index 0b63590dffd3..fd061465f852 100644 --- a/core/trino-main/src/test/java/io/trino/server/remotetask/TestHttpRemoteTask.java +++ b/core/trino-main/src/test/java/io/trino/server/remotetask/TestHttpRemoteTask.java @@ -34,6 +34,7 @@ import io.opentelemetry.api.trace.Span; import io.trino.Session; import io.trino.block.BlockJsonSerde; +import io.trino.cache.SplitAdmissionControllerProvider; import io.trino.client.NodeVersion; import io.trino.execution.BaseTestSqlTaskManager; import io.trino.execution.DynamicFilterConfig; @@ -636,7 +637,8 @@ private HttpRemoteTask createRemoteTask(HttpRemoteTaskFactory httpRemoteTaskFact new NodeTaskMap.PartitionedSplitCountTracker(i -> {}), outboundDynamicFilterIds, Optional.empty(), - true); + true, + new SplitAdmissionControllerProvider(ImmutableList.of(), session)); } private static HttpRemoteTaskFactory createHttpRemoteTaskFactory(TestingTaskResource testingTaskResource) diff --git a/core/trino-main/src/test/java/io/trino/sql/analyzer/TestAnalyzer.java b/core/trino-main/src/test/java/io/trino/sql/analyzer/TestAnalyzer.java index 448479d97a6d..c7d540eb6cc1 100644 --- a/core/trino-main/src/test/java/io/trino/sql/analyzer/TestAnalyzer.java +++ b/core/trino-main/src/test/java/io/trino/sql/analyzer/TestAnalyzer.java @@ -23,6 +23,7 @@ import io.trino.FeaturesConfig; import io.trino.Session; import io.trino.SystemSessionProperties; +import io.trino.cache.CacheConfig; import io.trino.client.NodeVersion; import io.trino.connector.CatalogServiceProvider; import io.trino.connector.MockConnectorFactory; @@ -1146,6 +1147,7 @@ public void testTooManyGroupingElements() new OptimizerConfig(), new NodeMemoryConfig(), new DynamicFilterConfig(), + new CacheConfig(), new NodeSchedulerConfig()))).build(); analyze(session, "SELECT a, b, c, d, e, f, g, h, i, j, k, SUM(l)" + "FROM (VALUES 
(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12))\n" + diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestingPlannerContext.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestingPlannerContext.java index 140744adf5d9..806140cf97da 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestingPlannerContext.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestingPlannerContext.java @@ -15,6 +15,7 @@ import com.google.common.collect.ImmutableSet; import io.trino.FeaturesConfig; +import io.trino.cache.CacheMetadata; import io.trino.connector.CatalogServiceProvider; import io.trino.metadata.BlockEncodingManager; import io.trino.metadata.FunctionBundle; @@ -46,6 +47,7 @@ import java.util.ArrayList; import java.util.List; +import java.util.Optional; import static com.google.common.base.Preconditions.checkState; import static io.airlift.tracing.Tracing.noopTracer; @@ -66,6 +68,7 @@ public static Builder plannerContextBuilder() public static class Builder { private Metadata metadata; + private CacheMetadata cacheMetadata; private TransactionManager transactionManager; private final List types = new ArrayList<>(); private final List parametricTypes = new ArrayList<>(); @@ -81,6 +84,13 @@ public Builder withMetadata(Metadata metadata) return this; } + public Builder withCacheMetadata(CacheMetadata cacheMetadata) + { + checkState(this.cacheMetadata == null, "cacheMetadata already set"); + this.cacheMetadata = requireNonNull(cacheMetadata, "cacheMetadata is null"); + return this; + } + public Builder withTransactionManager(TransactionManager transactionManager) { checkState(this.metadata == null, "metadata already set"); @@ -147,8 +157,13 @@ public PlannerContext build() new JsonQueryFunction(functionManager, metadata, typeManager))); typeRegistry.addType(new JsonPath2016Type(new TypeDeserializer(typeManager), blockEncodingSerde)); + CacheMetadata cacheMetadata = this.cacheMetadata; + if (cacheMetadata == null) { + cacheMetadata = new 
CacheMetadata(catalogHandle -> Optional.empty()); + } return new PlannerContext( metadata, + cacheMetadata, typeOperators, blockEncodingSerde, typeManager, diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/PlanMatchPattern.java b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/PlanMatchPattern.java index 5898f558c393..58db9d4b0a8c 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/PlanMatchPattern.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/PlanMatchPattern.java @@ -18,14 +18,17 @@ import com.google.common.collect.ImmutableSet; import com.google.common.collect.Maps; import io.trino.Session; +import io.trino.cache.CommonPlanAdaptation; import io.trino.cost.PlanNodeStatsEstimate; import io.trino.cost.StatsProvider; import io.trino.metadata.Metadata; +import io.trino.spi.cache.CacheColumnId; import io.trino.spi.connector.ColumnHandle; import io.trino.spi.connector.ConnectorTableHandle; import io.trino.spi.connector.SortOrder; import io.trino.spi.predicate.Domain; import io.trino.spi.predicate.TupleDomain; +import io.trino.sql.DynamicFilters; import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Expression; import io.trino.sql.ir.Row; @@ -38,6 +41,8 @@ import io.trino.sql.planner.plan.AggregationNode.Step; import io.trino.sql.planner.plan.ApplyNode; import io.trino.sql.planner.plan.AssignUniqueId; +import io.trino.sql.planner.plan.CacheDataPlanNode; +import io.trino.sql.planner.plan.ChooseAlternativeNode; import io.trino.sql.planner.plan.CorrelatedJoinNode; import io.trino.sql.planner.plan.DataOrganizationSpecification; import io.trino.sql.planner.plan.DistinctLimitNode; @@ -52,6 +57,7 @@ import io.trino.sql.planner.plan.JoinNode; import io.trino.sql.planner.plan.JoinType; import io.trino.sql.planner.plan.LimitNode; +import io.trino.sql.planner.plan.LoadCachedDataPlanNode; import io.trino.sql.planner.plan.MarkDistinctNode; import io.trino.sql.planner.plan.MergeWriterNode; 
import io.trino.sql.planner.plan.OffsetNode; @@ -91,7 +97,9 @@ import static io.trino.spi.connector.SortOrder.ASC_NULLS_LAST; import static io.trino.spi.connector.SortOrder.DESC_NULLS_FIRST; import static io.trino.spi.connector.SortOrder.DESC_NULLS_LAST; +import static io.trino.sql.DynamicFilters.extractDynamicFilters; import static io.trino.sql.ir.Comparison.Operator.IDENTICAL; +import static io.trino.sql.ir.IrUtils.extractDisjuncts; import static io.trino.sql.planner.assertions.MatchResult.NO_MATCH; import static io.trino.sql.planner.assertions.MatchResult.match; import static io.trino.sql.planner.assertions.StrictAssignedSymbolsMatcher.actualAssignments; @@ -138,6 +146,57 @@ public static PlanMatchPattern anyNot(Class excludeNodeClass return any(sources).with(new NotPlanNodeMatcher(excludeNodeClass)); } + public static PlanMatchPattern chooseAlternativeNode(PlanMatchPattern... sources) + { + return node(ChooseAlternativeNode.class, sources); + } + + public static PlanMatchPattern cacheDataPlanNode(PlanMatchPattern source) + { + return node(CacheDataPlanNode.class, source); + } + + public static PlanMatchPattern loadCachedDataPlanNode(CommonPlanAdaptation.PlanSignatureWithPredicate signature, String... outputSymbolAliases) + { + return loadCachedDataPlanNode(signature, Optional.empty(), dynamicFilters -> true, outputSymbolAliases); + } + + public static PlanMatchPattern loadCachedDataPlanNode(CommonPlanAdaptation.PlanSignatureWithPredicate signature, Map commonColumnHandles, String... outputSymbolAliases) + { + return loadCachedDataPlanNode(signature, Optional.of(commonColumnHandles), dynamicFilters -> true, outputSymbolAliases); + } + + public static PlanMatchPattern loadCachedDataPlanNode(CommonPlanAdaptation.PlanSignatureWithPredicate signature, Predicate>> dynamicFiltersPredicate, String... 
outputSymbolAliases) + { + return loadCachedDataPlanNode(signature, Optional.empty(), dynamicFiltersPredicate, outputSymbolAliases); + } + + public static PlanMatchPattern loadCachedDataPlanNode(CommonPlanAdaptation.PlanSignatureWithPredicate signature, Map commonColumnHandles, Predicate>> dynamicFiltersPredicate, String... outputSymbolAliases) + { + return loadCachedDataPlanNode(signature, Optional.of(commonColumnHandles), dynamicFiltersPredicate, outputSymbolAliases); + } + + private static PlanMatchPattern loadCachedDataPlanNode(CommonPlanAdaptation.PlanSignatureWithPredicate signature, Optional> commonColumnHandles, Predicate>> dynamicFiltersPredicate, String... outputSymbolAliases) + { + PlanMatchPattern result = node(LoadCachedDataPlanNode.class); + for (int i = 0; i < outputSymbolAliases.length; i++) { + String outputSymbol = outputSymbolAliases[i]; + int index = i; + result.withAlias(outputSymbol, (node, session, metadata, symbolAliases) -> { + List outputSymbols = node.getOutputSymbols(); + checkState(index < outputSymbols.size(), "outputSymbolAliases size is more than LoadCachedDataPlanNode output symbols"); + return Optional.ofNullable(outputSymbols.get(index)); + }); + } + result.with(LoadCachedDataPlanNode.class, node -> node.getPlanSignature().equals(signature)); + result.with(LoadCachedDataPlanNode.class, node -> commonColumnHandles.map(handles -> node.getCommonColumnHandles().equals(handles)).orElse(true)); + result.with(LoadCachedDataPlanNode.class, node -> dynamicFiltersPredicate.test( + extractDisjuncts(node.getDynamicFilterDisjuncts()).stream() + .map(expression -> extractDynamicFilters(expression).getDynamicConjuncts()) + .collect(toImmutableList()))); + return result; + } + public static PlanMatchPattern adaptivePlan(PlanMatchPattern initialPlan, PlanMatchPattern currentPlan) { return node(AdaptivePlanNode.class, currentPlan).with(new AdaptivePlanMatcher(initialPlan)); diff --git 
a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/test/PlanBuilder.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/test/PlanBuilder.java index 55dbd99998b7..ca8ed3004610 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/test/PlanBuilder.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/test/PlanBuilder.java @@ -59,6 +59,7 @@ import io.trino.sql.planner.plan.ApplyNode; import io.trino.sql.planner.plan.AssignUniqueId; import io.trino.sql.planner.plan.Assignments; +import io.trino.sql.planner.plan.ChooseAlternativeNode; import io.trino.sql.planner.plan.CorrelatedJoinNode; import io.trino.sql.planner.plan.DataOrganizationSpecification; import io.trino.sql.planner.plan.DistinctLimitNode; @@ -1225,6 +1226,11 @@ public ExceptNode except(ListMultimap outputsToInputs, List alternatives, ChooseAlternativeNode.FilteredTableScan originalTableScan) + { + return new ChooseAlternativeNode(idAllocator.getNextId(), alternatives, originalTableScan); + } + public TableWriterNode tableWriter(List columns, List columnNames, PlanNode source) { return tableWriter(columns, columnNames, Optional.empty(), Optional.empty(), Optional.empty(), source); diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/planprinter/TestAnonymizeJsonRepresentation.java b/core/trino-main/src/test/java/io/trino/sql/planner/planprinter/TestAnonymizeJsonRepresentation.java index 9cc20709a75c..e9a408bedc17 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/planprinter/TestAnonymizeJsonRepresentation.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/planprinter/TestAnonymizeJsonRepresentation.java @@ -249,7 +249,8 @@ private void assertAnonymizedRepresentation(Function sour valuePrinter, StatsAndCosts.empty(), Optional.empty(), - new CounterBasedAnonymizer()) + new CounterBasedAnonymizer(), + false) .toJson(); 
assertThat(jsonRenderedNode).isEqualTo(JSON_RENDERED_NODE_CODEC.toJson(expectedRepresentation)); return null; diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/planprinter/TestJsonRepresentation.java b/core/trino-main/src/test/java/io/trino/sql/planner/planprinter/TestJsonRepresentation.java index b0b518c1399d..3df2d760386e 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/planprinter/TestJsonRepresentation.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/planprinter/TestJsonRepresentation.java @@ -225,7 +225,8 @@ private void assertJsonRepresentation(Function sourceNode valuePrinter, StatsAndCosts.empty(), Optional.empty(), - new NoOpAnonymizer()) + new NoOpAnonymizer(), + false) .toJson(); assertThat(jsonRenderedNode).isEqualTo(JSON_RENDERED_NODE_CODEC.toJson(expectedRepresentation)); return null; diff --git a/core/trino-main/src/test/java/io/trino/sql/query/TestFilterHideInacessibleColumnsSession.java b/core/trino-main/src/test/java/io/trino/sql/query/TestFilterHideInacessibleColumnsSession.java index 681ead0a02c2..0ac7e5e6cf45 100644 --- a/core/trino-main/src/test/java/io/trino/sql/query/TestFilterHideInacessibleColumnsSession.java +++ b/core/trino-main/src/test/java/io/trino/sql/query/TestFilterHideInacessibleColumnsSession.java @@ -15,6 +15,7 @@ import io.trino.FeaturesConfig; import io.trino.SystemSessionProperties; +import io.trino.cache.CacheConfig; import io.trino.execution.DynamicFilterConfig; import io.trino.execution.QueryManagerConfig; import io.trino.execution.TaskManagerConfig; @@ -74,6 +75,7 @@ private SessionPropertyManager createSessionPropertyManager(FeaturesConfig featu new OptimizerConfig(), new NodeMemoryConfig(), new DynamicFilterConfig(), + new CacheConfig(), new NodeSchedulerConfig())); } } diff --git a/core/trino-spi/src/main/java/io/trino/spi/Plugin.java b/core/trino-spi/src/main/java/io/trino/spi/Plugin.java index cf6c9dbe84cd..a51b786b143f 100644 --- 
a/core/trino-spi/src/main/java/io/trino/spi/Plugin.java +++ b/core/trino-spi/src/main/java/io/trino/spi/Plugin.java @@ -14,6 +14,7 @@ package io.trino.spi; import io.trino.spi.block.BlockEncoding; +import io.trino.spi.cache.CacheManagerFactory; import io.trino.spi.catalog.CatalogStoreFactory; import io.trino.spi.connector.ConnectorFactory; import io.trino.spi.eventlistener.EventListenerFactory; @@ -115,4 +116,9 @@ default Iterable getSpoolingManagerFactories() { return emptyList(); } + + default Iterable getCacheManagerFactories() + { + return emptyList(); + } } diff --git a/core/trino-spi/src/main/java/io/trino/spi/StandardErrorCode.java b/core/trino-spi/src/main/java/io/trino/spi/StandardErrorCode.java index f4397d4d013d..28ed71276b44 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/StandardErrorCode.java +++ b/core/trino-spi/src/main/java/io/trino/spi/StandardErrorCode.java @@ -186,6 +186,7 @@ public enum StandardErrorCode EXCHANGE_MANAGER_NOT_CONFIGURED(65564, INTERNAL_ERROR), CATALOG_NOT_AVAILABLE(65565, INTERNAL_ERROR), CATALOG_STORE_ERROR(65566, INTERNAL_ERROR), + CACHE_MANAGER_NOT_CONFIGURED(65567, INTERNAL_ERROR), GENERIC_INSUFFICIENT_RESOURCES(131072, INSUFFICIENT_RESOURCES), EXCEEDED_GLOBAL_MEMORY_LIMIT(131073, INSUFFICIENT_RESOURCES), diff --git a/core/trino-spi/src/main/java/io/trino/spi/cache/CacheColumnId.java b/core/trino-spi/src/main/java/io/trino/spi/cache/CacheColumnId.java new file mode 100644 index 000000000000..5fca0fc27a03 --- /dev/null +++ b/core/trino-spi/src/main/java/io/trino/spi/cache/CacheColumnId.java @@ -0,0 +1,65 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.spi.cache; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonValue; + +import static io.airlift.slice.SizeOf.estimatedSizeOf; +import static io.airlift.slice.SizeOf.instanceSize; +import static java.util.Objects.requireNonNull; + +public class CacheColumnId +{ + private static final int INSTANCE_SIZE = instanceSize(CacheColumnId.class); + + private final String id; + + @JsonCreator + public CacheColumnId(String id) + { + this.id = requireNonNull(id, "id is null"); + } + + @Override + @JsonValue + public String toString() + { + return id; + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + CacheColumnId that = (CacheColumnId) o; + return id.equals(that.id); + } + + @Override + public int hashCode() + { + return id.hashCode(); + } + + public long getRetainedSizeInBytes() + { + return INSTANCE_SIZE + estimatedSizeOf(id); + } +} diff --git a/core/trino-spi/src/main/java/io/trino/spi/cache/CacheManager.java b/core/trino-spi/src/main/java/io/trino/spi/cache/CacheManager.java new file mode 100644 index 000000000000..a778818a4b5e --- /dev/null +++ b/core/trino-spi/src/main/java/io/trino/spi/cache/CacheManager.java @@ -0,0 +1,69 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.spi.cache; + +import io.trino.spi.connector.ConnectorPageSink; +import io.trino.spi.connector.ConnectorPageSource; +import io.trino.spi.predicate.TupleDomain; + +import java.io.Closeable; +import java.util.Optional; + +public interface CacheManager +{ + /** + * @return {@link SplitCache} for a given {@link PlanSignature}. + * Matching of {@link PlanSignature} per split could be expensive, + * therefore {@link SplitCache} is used to load or store data per split. + */ + SplitCache getSplitCache(PlanSignature signature); + + /** + * Triggers a memory revoke. {@link CacheManager} should revoke + * at least {@code bytesToRevoke} bytes (if it has allocated + * that much revocable memory) before allocating new memory. + * + * @return the number of revoked bytes + */ + long revokeMemory(long bytesToRevoke); + + interface SplitCache + extends Closeable + { + /** + * @param predicate Predicate that should be enforced on cached rows. + * Output of {@code cachedSplitA} can be used to derive output of matching {@code cachedSplitB} + * (with corresponding {@link PlanSignature}) as long as {@code cachedSplitB.predicate} is a strict + * subset of {@code cachedSplitA.predicate}. To do so, {@code cachedSplitB.predicate} must be + * applied on output of {@code cachedSplitA}. Before serialization as a cache key, predicate + * needs to be normalized using {@code io.trino.plugin.base.cache.CacheUtils#normalizeTupleDomain(TupleDomain)}. + * @param unenforcedPredicate Unenforced (best-effort) predicate that should be applied on cached rows. 
+ * Output of {@code cachedSplitA} can be used to derive output of matching {@code cachedSplitB} + * (with corresponding {@link PlanSignature}) as long as {@code cachedSplitB.unenforcedPredicate} + * is a subset of {@code cachedSplitA.unenforcedPredicate}. Before serialization as a cache key, predicate + * needs to be normalized by {@link CacheManager} using {@code io.trino.plugin.base.cache.CacheUtils#normalizeTupleDomain(TupleDomain)}. + * @return cached pages for a given split. + */ + Optional loadPages(CacheSplitId splitId, TupleDomain predicate, TupleDomain unenforcedPredicate); + + /** + * @param predicate Predicate that was enforced on cached rows. + * @param unenforcedPredicate Best-effort predicate that was applied on cached rows. + * @return {@link ConnectorPageSink} for caching pages for a given split. + * Might be empty if there isn't sufficient memory or split data is + * already cached. + */ + Optional storePages(CacheSplitId splitId, TupleDomain predicate, TupleDomain unenforcedPredicate); + } +} diff --git a/core/trino-spi/src/main/java/io/trino/spi/cache/CacheManagerContext.java b/core/trino-spi/src/main/java/io/trino/spi/cache/CacheManagerContext.java new file mode 100644 index 000000000000..40d569cbda9d --- /dev/null +++ b/core/trino-spi/src/main/java/io/trino/spi/cache/CacheManagerContext.java @@ -0,0 +1,30 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.spi.cache; + +import io.trino.spi.block.BlockEncodingSerde; + +public interface CacheManagerContext +{ + /** + * @return {@link MemoryAllocator} that {@link CacheManager} can use to allocate revocable memory + * from the engine. + */ + MemoryAllocator revocableMemoryAllocator(); + + /** + * @return {@link BlockEncodingSerde} that {@link CacheManager} can use to compress cached data. + */ + BlockEncodingSerde blockEncodingSerde(); +} diff --git a/core/trino-spi/src/main/java/io/trino/spi/cache/CacheManagerFactory.java b/core/trino-spi/src/main/java/io/trino/spi/cache/CacheManagerFactory.java new file mode 100644 index 000000000000..01307d37278b --- /dev/null +++ b/core/trino-spi/src/main/java/io/trino/spi/cache/CacheManagerFactory.java @@ -0,0 +1,23 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.spi.cache; + +import java.util.Map; + +public interface CacheManagerFactory +{ + String getName(); + + CacheManager create(Map config, CacheManagerContext context); +} diff --git a/core/trino-spi/src/main/java/io/trino/spi/cache/CacheSplitId.java b/core/trino-spi/src/main/java/io/trino/spi/cache/CacheSplitId.java new file mode 100644 index 000000000000..427ab3be70a4 --- /dev/null +++ b/core/trino-spi/src/main/java/io/trino/spi/cache/CacheSplitId.java @@ -0,0 +1,65 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.spi.cache; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonValue; + +import static io.airlift.slice.SizeOf.estimatedSizeOf; +import static io.airlift.slice.SizeOf.instanceSize; +import static java.util.Objects.requireNonNull; + +public class CacheSplitId +{ + private static final int INSTANCE_SIZE = instanceSize(CacheSplitId.class); + + private final String id; + + @JsonCreator + public CacheSplitId(String id) + { + this.id = requireNonNull(id, "id is null"); + } + + @Override + @JsonValue + public String toString() + { + return id; + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + CacheSplitId that = (CacheSplitId) o; + return id.equals(that.id); + } + + @Override + public int hashCode() + { + return id.hashCode(); + } + + public long getRetainedSizeInBytes() + { + return INSTANCE_SIZE + estimatedSizeOf(id); + } +} diff --git a/core/trino-spi/src/main/java/io/trino/spi/cache/CacheTableId.java b/core/trino-spi/src/main/java/io/trino/spi/cache/CacheTableId.java new file mode 100644 index 000000000000..753c5af4cc98 --- /dev/null +++ b/core/trino-spi/src/main/java/io/trino/spi/cache/CacheTableId.java @@ -0,0 +1,56 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.spi.cache; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonValue; + +import static java.util.Objects.requireNonNull; + +public class CacheTableId +{ + private final String id; + + @JsonCreator + public CacheTableId(String id) + { + this.id = requireNonNull(id, "id is null"); + } + + @Override + @JsonValue + public String toString() + { + return id; + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + CacheTableId that = (CacheTableId) o; + return id.equals(that.id); + } + + @Override + public int hashCode() + { + return id.hashCode(); + } +} diff --git a/core/trino-spi/src/main/java/io/trino/spi/cache/ConnectorCacheMetadata.java b/core/trino-spi/src/main/java/io/trino/spi/cache/ConnectorCacheMetadata.java new file mode 100644 index 000000000000..e374e01c38b2 --- /dev/null +++ b/core/trino-spi/src/main/java/io/trino/spi/cache/ConnectorCacheMetadata.java @@ -0,0 +1,52 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.spi.cache; + +import io.trino.spi.connector.ColumnHandle; +import io.trino.spi.connector.ConnectorPageSource; +import io.trino.spi.connector.ConnectorTableHandle; + +import java.util.Optional; + +public interface ConnectorCacheMetadata +{ + /** + * Returns a table identifier for the purpose of caching with {@link CacheManager}. + * {@link CacheTableId} together with {@link CacheSplitId} and {@link CacheColumnId}s represents + * rows produced by {@link ConnectorPageSource} for a given split. Local table properties + * (e.g. rows order) must be part of {@link CacheTableId} if they are present. {@link CacheTableId} + * shouldn't contain elements that are specific to splits or columns. For example, partition predicate + * can be implicitly derived from the list of enumerated splits. Hence, partition predicate shouldn't + * be part of {@link CacheTableId}. + */ + Optional getCacheTableId(ConnectorTableHandle tableHandle); + + /** + * Returns a column identifier for the purpose of caching with {@link CacheManager}. + * {@link CacheTableId} together with {@link CacheSplitId} and {@link CacheColumnId}s represents + * rows produced by {@link ConnectorPageSource} for a given split. {@link CacheColumnId} can represent + * a simple, base column or a more complex reference (e.g. map or array dereference expressions). + */ + Optional getCacheColumnId(ConnectorTableHandle tableHandle, ColumnHandle columnHandle); + + /** + * Returns a canonical {@link ConnectorTableHandle}. + * If any property of {@link ConnectorTableHandle} affects final query result when underlying table + * is queried, then such property is considered canonical. Otherwise, the property is non-canonical. + * Canonical {@link ConnectorTableHandle}s make it possible to match more similar subqueries that + * are eligible for caching with {@link CacheManager}.
Connector should convert provided + * {@link ConnectorTableHandle} into a canonical one by pruning every non-canonical field. + */ + ConnectorTableHandle getCanonicalTableHandle(ConnectorTableHandle handle); +} diff --git a/core/trino-spi/src/main/java/io/trino/spi/cache/MemoryAllocator.java b/core/trino-spi/src/main/java/io/trino/spi/cache/MemoryAllocator.java new file mode 100644 index 000000000000..0f69e7f9ffc6 --- /dev/null +++ b/core/trino-spi/src/main/java/io/trino/spi/cache/MemoryAllocator.java @@ -0,0 +1,23 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.spi.cache; + +public interface MemoryAllocator +{ + /** + * @return true if the bytes tracked by this {@link MemoryAllocator} can be set to {@code bytes}. + * This method can return false when there is not enough memory available to satisfy a positive delta allocation. + */ + boolean trySetBytes(long bytes); +} diff --git a/core/trino-spi/src/main/java/io/trino/spi/cache/PlanSignature.java b/core/trino-spi/src/main/java/io/trino/spi/cache/PlanSignature.java new file mode 100644 index 000000000000..6be3dc1e2ceb --- /dev/null +++ b/core/trino-spi/src/main/java/io/trino/spi/cache/PlanSignature.java @@ -0,0 +1,154 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.spi.cache; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import io.trino.spi.type.Type; + +import java.util.List; +import java.util.Objects; +import java.util.Optional; +import java.util.StringJoiner; + +import static io.airlift.slice.SizeOf.estimatedSizeOf; +import static io.airlift.slice.SizeOf.instanceSize; +import static io.airlift.slice.SizeOf.sizeOf; +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; + +/** + * Plan signature is a normalized and canonicalized representation of a subplan. + * Plan signatures make it possible to identify, match and adapt similar subqueries. + * The concept of plan signatures is described in http://www.cs.columbia.edu/~jrzhou/pub/cse.pdf + */ +public class PlanSignature +{ + private static final int INSTANCE_SIZE = instanceSize(PlanSignature.class); + + /** + * Key of a plan signature. Plans that can be potentially adapted + * to produce the same results (e.g. using column pruning, filtering or aggregation) + * will share the same key. + */ + private final SignatureKey key; + /** + * List of group by columns if the plan signature represents an aggregation. + */ + private final Optional> groupByColumns; + /** + * List of output columns. + */ + private final List columns; + /** + * List of output column types parallel to {@link PlanSignature#columns}.
+ */ + private final List columnsTypes; + + private volatile int hashCode; + + @JsonCreator + public PlanSignature( + SignatureKey key, + Optional> groupByColumns, + List columns, + List columnsTypes) + { + this.key = requireNonNull(key, "key is null"); + this.groupByColumns = requireNonNull(groupByColumns, "groupByColumns is null").map(List::copyOf); + this.columns = List.copyOf(requireNonNull(columns, "columns is null")); + this.columnsTypes = requireNonNull(columnsTypes, "columns types is null"); + if (columns.size() != columnsTypes.size()) { + throw new IllegalArgumentException(format("Column list has different length (%s) from type list (%s)", columns.size(), columnsTypes.size())); + } + } + + @JsonProperty + public SignatureKey getKey() + { + return key; + } + + @JsonProperty + public Optional> getGroupByColumns() + { + return groupByColumns; + } + + @JsonProperty + public List getColumns() + { + return columns; + } + + @JsonProperty + public List getColumnsTypes() + { + return columnsTypes; + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + PlanSignature signature = (PlanSignature) o; + return key.equals(signature.key) + && groupByColumns.equals(signature.groupByColumns) + && columns.equals(signature.columns) + && columnsTypes.equals(signature.columnsTypes); + } + + @Override + public int hashCode() + { + if (hashCode == 0) { + hashCode = Objects.hash(key, groupByColumns, columns, columnsTypes); + } + return hashCode; + } + + @Override + public String toString() + { + return new StringJoiner(", ", PlanSignature.class.getSimpleName() + "[", "]") + .add("key=" + key) + .add("groupByColumns=" + groupByColumns) + .add("columns=" + columns) + .add("columnTypes=" + columnsTypes) + .toString(); + } + + public long getRetainedSizeInBytes() + { + return INSTANCE_SIZE + + key.getRetainedSizeInBytes() + + sizeOf(groupByColumns, cols -> estimatedSizeOf(cols, 
CacheColumnId::getRetainedSizeInBytes)) + + estimatedSizeOf(columns, CacheColumnId::getRetainedSizeInBytes); + } + + public static PlanSignature canonicalizePlanSignature(PlanSignature signature) + { + return new PlanSignature( + signature.getKey(), + signature.getGroupByColumns(), + // columns are stored independently + List.of(), + List.of()); + } +} diff --git a/core/trino-spi/src/main/java/io/trino/spi/cache/SignatureKey.java b/core/trino-spi/src/main/java/io/trino/spi/cache/SignatureKey.java new file mode 100644 index 000000000000..95bbf9d009d1 --- /dev/null +++ b/core/trino-spi/src/main/java/io/trino/spi/cache/SignatureKey.java @@ -0,0 +1,65 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.spi.cache; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonValue; + +import static io.airlift.slice.SizeOf.estimatedSizeOf; +import static io.airlift.slice.SizeOf.instanceSize; +import static java.util.Objects.requireNonNull; + +public class SignatureKey +{ + private static final int INSTANCE_SIZE = instanceSize(SignatureKey.class); + + private final String key; + + @JsonCreator + public SignatureKey(String key) + { + this.key = requireNonNull(key, "key is null"); + } + + @Override + @JsonValue + public String toString() + { + return key; + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + SignatureKey that = (SignatureKey) o; + return key.equals(that.key); + } + + @Override + public int hashCode() + { + return key.hashCode(); + } + + public long getRetainedSizeInBytes() + { + return INSTANCE_SIZE + estimatedSizeOf(key); + } +} diff --git a/core/trino-spi/src/main/java/io/trino/spi/connector/Connector.java b/core/trino-spi/src/main/java/io/trino/spi/connector/Connector.java index a6bf5755b44b..045c62887874 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/connector/Connector.java +++ b/core/trino-spi/src/main/java/io/trino/spi/connector/Connector.java @@ -14,6 +14,7 @@ package io.trino.spi.connector; import io.trino.spi.Experimental; +import io.trino.spi.cache.ConnectorCacheMetadata; import io.trino.spi.eventlistener.EventListener; import io.trino.spi.function.FunctionProvider; import io.trino.spi.function.table.ConnectorTableFunction; @@ -67,6 +68,14 @@ default ConnectorSplitManager getSplitManager() throw new UnsupportedOperationException(); } + /** + * @throws UnsupportedOperationException if this connector does not support cache ids + */ + default ConnectorCacheMetadata getCacheMetadata() + { + throw new UnsupportedOperationException(); + } + /** * @throws 
UnsupportedOperationException if this connector does not support reading tables page at a time */ diff --git a/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorPageSourceProvider.java b/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorPageSourceProvider.java index b589c79b4740..c670d001577c 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorPageSourceProvider.java +++ b/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorPageSourceProvider.java @@ -13,6 +13,8 @@ */ package io.trino.spi.connector; +import io.trino.spi.predicate.TupleDomain; + import java.util.List; public interface ConnectorPageSourceProvider @@ -28,4 +30,31 @@ ConnectorPageSource createPageSource( ConnectorTableHandle table, List columns, DynamicFilter dynamicFilter); + + /** + * Returns unenforced predicate that {@link ConnectorPageSource} would use to filter split data. + * If split is completely filtered out, then this method should return {@link TupleDomain#none}. + */ + default TupleDomain getUnenforcedPredicate( + ConnectorSession session, + ConnectorSplit split, + ConnectorTableHandle table, + TupleDomain dynamicFilter) + { + throw new UnsupportedOperationException(); + } + + /** + * Prunes columns from predicate that are not effective in filtering split data. + * If split is completely filtered out by given predicate, then this + * method must return {@link TupleDomain#none}. 
+ */ + default TupleDomain prunePredicate( + ConnectorSession session, + ConnectorSplit split, + ConnectorTableHandle table, + TupleDomain predicate) + { + throw new UnsupportedOperationException(); + } } diff --git a/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorSplitManager.java b/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorSplitManager.java index 36a141b50db0..9d6f0974e07a 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorSplitManager.java +++ b/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorSplitManager.java @@ -14,8 +14,14 @@ package io.trino.spi.connector; import io.trino.spi.Experimental; +import io.trino.spi.cache.CacheColumnId; +import io.trino.spi.cache.CacheManager; +import io.trino.spi.cache.CacheSplitId; +import io.trino.spi.cache.CacheTableId; import io.trino.spi.function.table.ConnectorTableFunctionHandle; +import java.util.Optional; + public interface ConnectorSplitManager { default ConnectorSplitSource getSplits( @@ -36,4 +42,14 @@ default ConnectorSplitSource getSplits( { throw new UnsupportedOperationException(); } + + /** + * Returns a split identifier for the purpose of caching with {@link CacheManager}. + * {@link CacheSplitId} together with {@link CacheTableId} and {@link CacheColumnId}s + * represents rows produced by {@link ConnectorPageSource} for a given split. 
+ */ + default Optional getCacheSplitId(ConnectorSplit split) + { + return Optional.empty(); + } } diff --git a/core/trino-spi/src/main/java/io/trino/spi/predicate/SortedRangeSet.java b/core/trino-spi/src/main/java/io/trino/spi/predicate/SortedRangeSet.java index 152dd13b21d4..b1954454c909 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/predicate/SortedRangeSet.java +++ b/core/trino-spi/src/main/java/io/trino/spi/predicate/SortedRangeSet.java @@ -444,6 +444,30 @@ public boolean containsValue(Object value) return getRangeView(lowRangeIndex).overlaps(valueRange); } + public SortedRangeSet normalize() + { + switch (sortedRanges) { + case ValueBlock _ -> { + return this; + } + case DictionaryBlock dictionary -> { + // unwrap dictionary block + int[] positions = new int[dictionary.getPositionCount()]; + for (int position = 0; position < positions.length; position++) { + positions[position] = dictionary.getUnderlyingValuePosition(position); + } + return new SortedRangeSet(type, inclusive, dictionary.getUnderlyingValueBlock().copyPositions(positions, 0, positions.length), discreteSetMarker); + } + case RunLengthEncodedBlock rleBlock -> { + // unwrap RLE block + int[] positions = new int[rleBlock.getPositionCount()]; + Arrays.fill(positions, 0); + return new SortedRangeSet(type, inclusive, rleBlock.getUnderlyingValueBlock().copyPositions(positions, 0, positions.length), discreteSetMarker); + } + case LazyBlock _ -> throw new IllegalArgumentException("Did not expect LazyBlock"); + } + } + public Range getSpan() { if (isNone()) { diff --git a/core/trino-spi/src/test/java/io/trino/spi/predicate/TestSortedRangeSet.java b/core/trino-spi/src/test/java/io/trino/spi/predicate/TestSortedRangeSet.java index 6743d3ce41a6..185411c9b1b4 100644 --- a/core/trino-spi/src/test/java/io/trino/spi/predicate/TestSortedRangeSet.java +++ b/core/trino-spi/src/test/java/io/trino/spi/predicate/TestSortedRangeSet.java @@ -20,6 +20,9 @@ import io.airlift.slice.Slice; import 
io.airlift.slice.Slices; import io.trino.spi.block.Block; +import io.trino.spi.block.DictionaryBlock; +import io.trino.spi.block.LongArrayBlock; +import io.trino.spi.block.RunLengthEncodedBlock; import io.trino.spi.block.TestingBlockEncodingSerde; import io.trino.spi.block.TestingBlockJsonSerde; import io.trino.spi.type.TestingTypeDeserializer; @@ -36,6 +39,7 @@ import java.util.stream.Stream; import static io.airlift.slice.Slices.utf8Slice; +import static io.trino.spi.block.BlockTestUtils.assertBlockEquals; import static io.trino.spi.predicate.SortedRangeSet.DiscreteSetMarker.UNKNOWN; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.BooleanType.BOOLEAN; @@ -805,6 +809,37 @@ public void testSubtract() assertThat(SortedRangeSet.of(Range.greaterThan(BIGINT, 0L)).subtract(SortedRangeSet.of(Range.greaterThan(BIGINT, 0L)))).isEqualTo(SortedRangeSet.none(BIGINT)); } + @Test + public void testNormalizeDictionarySortedRanges() + { + SortedRangeSet values = (SortedRangeSet) ValueSet.of(BIGINT, 0L, -1L); + + // make sure normalization preserves equality of TupleDomains + SortedRangeSet normalizedValues = values.normalize(); + assertThat(normalizedValues).isEqualTo(values); + assertBlockEquals(BIGINT, normalizedValues.getSortedRanges(), values.getSortedRanges()); + assertThat(values.getSortedRanges()).isInstanceOf(DictionaryBlock.class); + assertThat(normalizedValues.getSortedRanges()).isInstanceOf(LongArrayBlock.class); + + // further normalization shouldn't change SortedRangeSet underlying block + SortedRangeSet doubleNormalizedValues = normalizedValues.normalize(); + assertThat(doubleNormalizedValues.getSortedRanges()).isInstanceOf(LongArrayBlock.class); + assertBlockEquals(BIGINT, doubleNormalizedValues.getSortedRanges(), normalizedValues.getSortedRanges()); + } + + @Test + public void testNormalizeRleSortedRanges() + { + SortedRangeSet values = (SortedRangeSet) ValueSet.of(BIGINT, 0L); + + // make sure normalization preserves equality of 
TupleDomains + SortedRangeSet normalizedValues = values.normalize(); + assertThat(normalizedValues).isEqualTo(values); + assertBlockEquals(BIGINT, normalizedValues.getSortedRanges(), values.getSortedRanges()); + assertThat(values.getSortedRanges()).isInstanceOf(RunLengthEncodedBlock.class); + assertThat(normalizedValues.getSortedRanges()).isInstanceOf(LongArrayBlock.class); + } + @Test public void testJsonSerialization() throws Exception diff --git a/core/trino-web-ui/src/main/resources/webapp/src/components/StageDetail.jsx b/core/trino-web-ui/src/main/resources/webapp/src/components/StageDetail.jsx index 67d29b2296a5..dfc019b810ed 100644 --- a/core/trino-web-ui/src/main/resources/webapp/src/components/StageDetail.jsx +++ b/core/trino-web-ui/src/main/resources/webapp/src/components/StageDetail.jsx @@ -385,18 +385,23 @@ class StageOperatorGraph extends React.Component { ) } + // returns map from pipelineId to map from alternativeId to the sinkOperator computeOperatorGraphs(planNode, operatorMap) { const sources = getChildren(planNode) const sourceResults = new Map() sources.forEach((source) => { const sourceResult = this.computeOperatorGraphs(source, operatorMap) - sourceResult.forEach((operator, pipelineId) => { - if (sourceResults.has(pipelineId)) { - console.error('Multiple sources for ', planNode['@type'], ' had the same pipeline ID') - return sourceResults + sourceResult.forEach((alternatives, pipelineId) => { + if (!sourceResults.has(pipelineId)) { + sourceResults.set(pipelineId, new Map()) } - sourceResults.set(pipelineId, operator) + + const mergedAlternatives = sourceResults.get(pipelineId) + + alternatives.forEach(function (operator, alternativeId) { + mergedAlternatives.set(alternativeId, operator) + }) }) }) @@ -408,40 +413,57 @@ class StageOperatorGraph extends React.Component { const pipelineOperators = new Map() nodeOperators.forEach((operator) => { if (!pipelineOperators.has(operator.pipelineId)) { - pipelineOperators.set(operator.pipelineId, []) + 
pipelineOperators.set(operator.pipelineId, new Map()) } - pipelineOperators.get(operator.pipelineId).push(operator) + const pipeline = pipelineOperators.get(operator.pipelineId) + if (!pipeline.has(operator.alternativeId)) { + pipeline.set(operator.alternativeId, []) + } + pipeline.get(operator.alternativeId).push(operator) }) const result = new Map() - pipelineOperators.forEach((pipelineOperators, pipelineId) => { - // sort deep-copied operators in this pipeline from source to sink - const linkedOperators = pipelineOperators - .map((a) => Object.assign({}, a)) - .sort((a, b) => a.operatorId - b.operatorId) - const sinkOperator = linkedOperators[linkedOperators.length - 1] - const sourceOperator = linkedOperators[0] - - if (sourceResults.has(pipelineId)) { - const pipelineChildResult = sourceResults.get(pipelineId) - if (pipelineChildResult) { - sourceOperator.child = pipelineChildResult + pipelineOperators.forEach((alternatives, pipelineId) => { + const sourceAlternatives = sourceResults.get(pipelineId) + + const linkedAlternatives = new Map() + alternatives.forEach((pipelineOperators, alternativeId) => { + // sort deep-copied operators in this pipeline from source to sink + const linkedOperators = pipelineOperators + .map((a) => Object.assign({}, a)) + .sort((a, b) => a.operatorId - b.operatorId) + const sinkOperator = linkedOperators[linkedOperators.length - 1] + const sourceOperator = linkedOperators[0] + + if (sourceAlternatives && sourceAlternatives.has(alternativeId)) { + const pipelineChildResult = sourceAlternatives.get(alternativeId) + if (pipelineChildResult) { + sourceOperator.child = pipelineChildResult + } } - } - // chain operators at this level - let currentOperator = sourceOperator - linkedOperators.slice(1).forEach((source) => { - source.child = currentOperator - currentOperator = source - }) + // chain operators at this level + let currentOperator = sourceOperator + linkedOperators.slice(1).forEach((source) => { + source.child = currentOperator + 
currentOperator = source + }) - result.set(pipelineId, sinkOperator) + linkedAlternatives.set(alternativeId, sinkOperator) + }) + result.set(pipelineId, linkedAlternatives) }) - sourceResults.forEach((operator, pipelineId) => { + sourceResults.forEach((sourceAlternatives, pipelineId) => { if (!result.has(pipelineId)) { - result.set(pipelineId, operator) + result.set(pipelineId, sourceAlternatives) + } else { + const alternatives = result.get(pipelineId) + sourceAlternatives.forEach((operator, alternativeId) => { + if (!alternatives.has(alternativeId)) { + alternatives.set(alternativeId, operator) + } + }) } }) @@ -462,7 +484,8 @@ class StageOperatorGraph extends React.Component { } computeD3StageOperatorGraph(graph, operator, sink, pipelineNode) { - const operatorNodeId = 'operator-' + operator.pipelineId + '-' + operator.operatorId + const operatorNodeId = + 'operator-' + operator.pipelineId + '-' + operator.alternativeId + '-' + operator.operatorId // this is a non-standard use of ReactDOMServer, but it's the cleanest way to unify DagreD3 with React const html = ReactDOMServer.renderToString( @@ -498,7 +521,7 @@ class StageOperatorGraph extends React.Component { const operatorGraphs = this.computeOperatorGraphs(stage.plan.root, operatorMap) const graph = initializeGraph() - operatorGraphs.forEach((operator, pipelineId) => { + operatorGraphs.forEach((alternatives, pipelineId) => { const pipelineNodeId = 'pipeline-' + pipelineId graph.setNode(pipelineNodeId, { label: 'Pipeline ' + pipelineId + ' ', @@ -506,7 +529,24 @@ class StageOperatorGraph extends React.Component { style: 'fill: #2b2b2b', labelStyle: 'fill: #fff', }) - this.computeD3StageOperatorGraph(graph, operator, null, pipelineNodeId) + if (alternatives.size === 1) { + this.computeD3StageOperatorGraph(graph, alternatives.get(0), null, pipelineNodeId) + } else { + const sortedAlternatives = Array.from(alternatives).sort((a, b) => a[0] - b[0]) + sortedAlternatives.forEach((entry) => { + const alternativeId = 
entry[0] + const operator = entry[1] + const alternativeNodeId = 'alternative-' + alternativeId + '-' + pipelineId + graph.setNode(alternativeNodeId, { + label: 'Alternative ' + alternativeId + ' ', + clusterLabelPos: 'top', + style: 'fill: #262626', + labelStyle: 'fill: #fff', + }) + this.computeD3StageOperatorGraph(graph, operator, null, alternativeNodeId) + graph.setParent(alternativeNodeId, pipelineNodeId) + }) + } }) $('#operator-canvas').html('') diff --git a/core/trino-web-ui/src/main/resources/webapp/src/utils.js b/core/trino-web-ui/src/main/resources/webapp/src/utils.js index ee9a6df1e57b..58cba13a663e 100644 --- a/core/trino-web-ui/src/main/resources/webapp/src/utils.js +++ b/core/trino-web-ui/src/main/resources/webapp/src/utils.js @@ -219,7 +219,6 @@ export function getChildren(nodeInfo: any): any { switch (nodeInfo['@type']) { case 'aggregation': case 'assignUniqueId': - case 'cacheData': case 'delete': case 'distinctLimit': case 'dynamicFilterSource': @@ -244,6 +243,7 @@ export function getChildren(nodeInfo: any): any { case 'topNRanking': case 'unnest': case 'window': + case 'cacheData': return [nodeInfo.source] case 'join': return [nodeInfo.left, nodeInfo.right] @@ -256,15 +256,16 @@ export function getChildren(nodeInfo: any): any { case 'exchange': case 'intersect': case 'union': + case 'chooseAlternative': return nodeInfo.sources case 'indexSource': - case 'loadCachedData': case 'refreshMaterializedView': case 'remoteSource': case 'tableDelete': case 'tableScan': case 'tableUpdate': case 'values': + case 'loadCachedData': break default: console.log('NOTE: Unhandled PlanNode: ' + nodeInfo['@type']) diff --git a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/cache/CacheUtils.java b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/cache/CacheUtils.java new file mode 100644 index 000000000000..37acc77617d0 --- /dev/null +++ b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/cache/CacheUtils.java @@ -0,0 +1,71 @@ +/* + 
* Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.base.cache; + +import io.trino.spi.cache.CacheColumnId; +import io.trino.spi.predicate.Domain; +import io.trino.spi.predicate.SortedRangeSet; +import io.trino.spi.predicate.TupleDomain; + +import java.util.LinkedHashMap; +import java.util.Map; +import java.util.function.Function; +import java.util.stream.Collector; + +import static java.lang.String.format; +import static java.util.Comparator.comparing; +import static java.util.stream.Collectors.toMap; + +public final class CacheUtils +{ + private CacheUtils() {} + + /** + * Normalizes {@link TupleDomain} so that equal tuple domains are serialized in the same way. + * Without normalization 1) domain map entry order might differ 2) {@link SortedRangeSet} + * {@code sortedRanges} block might differ.
+ */ + public static TupleDomain normalizeTupleDomain(TupleDomain tupleDomain) + { + if (tupleDomain.getDomains().isEmpty()) { + return tupleDomain; + } + + Map domains = tupleDomain.getDomains().get(); + return TupleDomain.withColumnDomains(domains.entrySet().stream() + // sort domains by string representation of column + .sorted(comparing(domainEntry -> domainEntry.getKey().toString())) + .collect(toLinkedMap( + Map.Entry::getKey, + entry -> { + Domain domain = entry.getValue(); + if (domain.getValues() instanceof SortedRangeSet values) { + // normalize sorted range set + domain = Domain.create(values.normalize(), domain.isNullAllowed()); + } + return domain; + }))); + } + + private static Collector> toLinkedMap(Function keyMapper, Function valueMapper) + { + return toMap( + keyMapper, + valueMapper, + (u, v) -> { + throw new IllegalStateException(format("Duplicate values for a key: %s and %s", u, v)); + }, + LinkedHashMap::new); + } +} diff --git a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/classloader/ClassLoaderSafeConnectorCacheMetadata.java b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/classloader/ClassLoaderSafeConnectorCacheMetadata.java new file mode 100644 index 000000000000..44d738f07bda --- /dev/null +++ b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/classloader/ClassLoaderSafeConnectorCacheMetadata.java @@ -0,0 +1,69 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.plugin.base.classloader; + +import com.google.inject.Inject; +import io.trino.spi.cache.CacheColumnId; +import io.trino.spi.cache.CacheTableId; +import io.trino.spi.cache.ConnectorCacheMetadata; +import io.trino.spi.classloader.ThreadContextClassLoader; +import io.trino.spi.connector.ColumnHandle; +import io.trino.spi.connector.ConnectorTableHandle; + +import java.util.Optional; + +import static java.util.Objects.requireNonNull; + +public class ClassLoaderSafeConnectorCacheMetadata + implements ConnectorCacheMetadata +{ + private final ConnectorCacheMetadata delegate; + private final ClassLoader classLoader; + + @Inject + public ClassLoaderSafeConnectorCacheMetadata(@ForClassLoaderSafe ConnectorCacheMetadata delegate, ClassLoader classLoader) + { + this.delegate = requireNonNull(delegate, "delegate is null"); + this.classLoader = requireNonNull(classLoader, "classLoader is null"); + } + + @Override + public Optional getCacheTableId(ConnectorTableHandle tableHandle) + { + try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(classLoader)) { + return delegate.getCacheTableId(tableHandle); + } + } + + @Override + public Optional getCacheColumnId(ConnectorTableHandle tableHandle, ColumnHandle columnHandle) + { + try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(classLoader)) { + return delegate.getCacheColumnId(tableHandle, columnHandle); + } + } + + @Override + public ConnectorTableHandle getCanonicalTableHandle(ConnectorTableHandle tableHandle) + { + try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(classLoader)) { + return delegate.getCanonicalTableHandle(tableHandle); + } + } + + public ConnectorCacheMetadata unwrap() + { + return delegate; + } +} diff --git a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/classloader/ClassLoaderSafeConnectorPageSourceProvider.java 
b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/classloader/ClassLoaderSafeConnectorPageSourceProvider.java index 5fa55cdb36ad..f7fb955058f2 100644 --- a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/classloader/ClassLoaderSafeConnectorPageSourceProvider.java +++ b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/classloader/ClassLoaderSafeConnectorPageSourceProvider.java @@ -23,6 +23,7 @@ import io.trino.spi.connector.ConnectorTableHandle; import io.trino.spi.connector.ConnectorTransactionHandle; import io.trino.spi.connector.DynamicFilter; +import io.trino.spi.predicate.TupleDomain; import java.util.List; @@ -48,4 +49,28 @@ public ConnectorPageSource createPageSource(ConnectorTransactionHandle transacti return delegate.createPageSource(transaction, session, split, table, columns, dynamicFilter); } } + + @Override + public TupleDomain getUnenforcedPredicate( + ConnectorSession session, + ConnectorSplit split, + ConnectorTableHandle table, + TupleDomain dynamicFilter) + { + try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(classLoader)) { + return delegate.getUnenforcedPredicate(session, split, table, dynamicFilter); + } + } + + @Override + public TupleDomain prunePredicate( + ConnectorSession session, + ConnectorSplit split, + ConnectorTableHandle table, + TupleDomain predicate) + { + try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(classLoader)) { + return delegate.prunePredicate(session, split, table, predicate); + } + } } diff --git a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/classloader/ClassLoaderSafeConnectorSplitManager.java b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/classloader/ClassLoaderSafeConnectorSplitManager.java index 85ac0b24efe3..6d4c5d0f5942 100644 --- a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/classloader/ClassLoaderSafeConnectorSplitManager.java +++ 
b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/classloader/ClassLoaderSafeConnectorSplitManager.java @@ -14,8 +14,10 @@ package io.trino.plugin.base.classloader; import com.google.inject.Inject; +import io.trino.spi.cache.CacheSplitId; import io.trino.spi.classloader.ThreadContextClassLoader; import io.trino.spi.connector.ConnectorSession; +import io.trino.spi.connector.ConnectorSplit; import io.trino.spi.connector.ConnectorSplitManager; import io.trino.spi.connector.ConnectorSplitSource; import io.trino.spi.connector.ConnectorTableHandle; @@ -24,6 +26,8 @@ import io.trino.spi.connector.DynamicFilter; import io.trino.spi.function.table.ConnectorTableFunctionHandle; +import java.util.Optional; + import static java.util.Objects.requireNonNull; public final class ClassLoaderSafeConnectorSplitManager @@ -62,4 +66,12 @@ public ConnectorSplitSource getSplits( return delegate.getSplits(transaction, session, function); } } + + @Override + public Optional getCacheSplitId(ConnectorSplit split) + { + try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(classLoader)) { + return delegate.getCacheSplitId(split); + } + } } diff --git a/lib/trino-plugin-toolkit/src/test/java/io/trino/plugin/base/cache/TestCacheUtils.java b/lib/trino-plugin-toolkit/src/test/java/io/trino/plugin/base/cache/TestCacheUtils.java new file mode 100644 index 000000000000..fbdd17137c56 --- /dev/null +++ b/lib/trino-plugin-toolkit/src/test/java/io/trino/plugin/base/cache/TestCacheUtils.java @@ -0,0 +1,83 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.base.cache; + +import com.google.common.collect.ImmutableMap; +import io.trino.spi.block.DictionaryBlock; +import io.trino.spi.block.LongArrayBlock; +import io.trino.spi.cache.CacheColumnId; +import io.trino.spi.predicate.Domain; +import io.trino.spi.predicate.SortedRangeSet; +import io.trino.spi.predicate.ValueSet; +import org.junit.jupiter.api.Test; + +import java.util.Map; +import java.util.Optional; + +import static io.trino.plugin.base.cache.CacheUtils.normalizeTupleDomain; +import static io.trino.spi.block.BlockTestUtils.assertBlockEquals; +import static io.trino.spi.predicate.Domain.singleValue; +import static io.trino.spi.predicate.TupleDomain.none; +import static io.trino.spi.predicate.TupleDomain.withColumnDomains; +import static io.trino.spi.type.BigintType.BIGINT; +import static org.assertj.core.api.Assertions.assertThat; + +public class TestCacheUtils +{ + @Test + public void testNormalizeTupleDomainEmptyTupleDomain() + { + assertThat(none()).isSameAs(none()); + } + + @Test + public void testNormalizeTupleDomainKeyOrder() + { + Optional> domains = normalizeTupleDomain(withColumnDomains(ImmutableMap.of( + new CacheColumnId("col2"), singleValue(BIGINT, 1L), + new CacheColumnId("col1"), singleValue(BIGINT, 2L)))) + .getDomains(); + assertThat(domains).isPresent(); + assertThat(domains.get()).containsExactlyEntriesOf(ImmutableMap.of( + new CacheColumnId("col1"), singleValue(BIGINT, 2L), + new CacheColumnId("col2"), singleValue(BIGINT, 1L))); + } + + @Test + public void testNormalizeTupleDomainSortedRanges() + { + SortedRangeSet values = (SortedRangeSet) ValueSet.of(BIGINT, 0L, -1L); + SortedRangeSet normalizedValues = (SortedRangeSet) normalizeTupleDomain(withColumnDomains(ImmutableMap.of( + new CacheColumnId("col1"), Domain.create(values, false)))) + .getDomains() + .orElseThrow() + .get(new CacheColumnId("col1")) + 
.getValues(); + // make sure normalization preserves equality of TupleDomains + assertThat(normalizedValues).isEqualTo(values); + assertBlockEquals(BIGINT, normalizedValues.getSortedRanges(), values.getSortedRanges()); + assertThat(values.getSortedRanges()).isInstanceOf(DictionaryBlock.class); + assertThat(normalizedValues.getSortedRanges()).isInstanceOf(LongArrayBlock.class); + + // further normalization shouldn't change SortedRangeSet underlying block + SortedRangeSet doubleNormalizedValues = (SortedRangeSet) normalizeTupleDomain(withColumnDomains(ImmutableMap.of( + new CacheColumnId("col1"), Domain.create(normalizedValues, false)))) + .getDomains() + .orElseThrow() + .get(new CacheColumnId("col1")) + .getValues(); + assertThat(doubleNormalizedValues.getSortedRanges()).isInstanceOf(LongArrayBlock.class); + assertBlockEquals(BIGINT, doubleNormalizedValues.getSortedRanges(), normalizedValues.getSortedRanges()); + } +} diff --git a/lib/trino-plugin-toolkit/src/test/java/io/trino/plugin/base/classloader/TestClassLoaderSafeWrappers.java b/lib/trino-plugin-toolkit/src/test/java/io/trino/plugin/base/classloader/TestClassLoaderSafeWrappers.java index 66a1412e61a5..f1f66030413e 100644 --- a/lib/trino-plugin-toolkit/src/test/java/io/trino/plugin/base/classloader/TestClassLoaderSafeWrappers.java +++ b/lib/trino-plugin-toolkit/src/test/java/io/trino/plugin/base/classloader/TestClassLoaderSafeWrappers.java @@ -13,6 +13,7 @@ */ package io.trino.plugin.base.classloader; +import io.trino.spi.cache.ConnectorCacheMetadata; import io.trino.spi.connector.ConnectorAccessControl; import io.trino.spi.connector.ConnectorMergeSink; import io.trino.spi.connector.ConnectorMetadata; @@ -48,6 +49,7 @@ public void test() throws Exception { testClassLoaderSafe(ConnectorAccessControl.class, ClassLoaderSafeConnectorAccessControl.class); + testClassLoaderSafe(ConnectorCacheMetadata.class, ClassLoaderSafeConnectorCacheMetadata.class); testClassLoaderSafe(ConnectorMetadata.class, 
ClassLoaderSafeConnectorMetadata.class); testClassLoaderSafe(ConnectorMergeSink.class, ClassLoaderSafeConnectorMergeSink.class); testClassLoaderSafe(ConnectorPageSink.class, ClassLoaderSafeConnectorPageSink.class); diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeCacheMetadata.java b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeCacheMetadata.java new file mode 100644 index 000000000000..fde6e968e9ec --- /dev/null +++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeCacheMetadata.java @@ -0,0 +1,115 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.plugin.deltalake; + +import com.google.inject.Inject; +import io.airlift.json.JsonCodec; +import io.trino.spi.cache.CacheColumnId; +import io.trino.spi.cache.CacheTableId; +import io.trino.spi.cache.ConnectorCacheMetadata; +import io.trino.spi.connector.ColumnHandle; +import io.trino.spi.connector.ConnectorTableHandle; +import io.trino.spi.predicate.TupleDomain; + +import java.util.Optional; + +import static java.util.Objects.requireNonNull; + +public class DeltaLakeCacheMetadata + implements ConnectorCacheMetadata +{ + private final JsonCodec tableIdCodec; + private final JsonCodec columnCodec; + + @Inject + public DeltaLakeCacheMetadata(JsonCodec tableIdCodec, JsonCodec columnCodec) + { + this.tableIdCodec = requireNonNull(tableIdCodec, "tableIdCodec is null"); + this.columnCodec = requireNonNull(columnCodec, "columnCodec is null"); + } + + @Override + public Optional getCacheTableId(ConnectorTableHandle tableHandle) + { + DeltaLakeTableHandle deltaLakeTableHandle = (DeltaLakeTableHandle) tableHandle; + + // skip caching if it is UPDATE / INSERT query + if (((DeltaLakeTableHandle) tableHandle).getWriteType().isPresent()) { + return Optional.empty(); + } + + // skip caching of analyze queries + if (deltaLakeTableHandle.getAnalyzeHandle().isPresent()) { + return Optional.empty(); + } + + // Ensure cache id generation is revisited whenever handle classes change. 
+ DeltaLakeTableHandle handle = new DeltaLakeTableHandle( + deltaLakeTableHandle.getSchemaName(), + deltaLakeTableHandle.getTableName(), + deltaLakeTableHandle.isManaged(), + deltaLakeTableHandle.getLocation(), + deltaLakeTableHandle.getMetadataEntry(), + deltaLakeTableHandle.getProtocolEntry(), + deltaLakeTableHandle.getEnforcedPartitionConstraint(), + deltaLakeTableHandle.getNonPartitionConstraint(), + deltaLakeTableHandle.getWriteType(), + deltaLakeTableHandle.getProjectedColumns(), + deltaLakeTableHandle.getUpdatedColumns(), + deltaLakeTableHandle.getUpdateRowIdColumns(), + deltaLakeTableHandle.getAnalyzeHandle(), + deltaLakeTableHandle.getReadVersion(), + deltaLakeTableHandle.isTimeTravel()); + + DeltaLakeCacheTableId tableId = new DeltaLakeCacheTableId( + handle.getSchemaName(), + handle.getTableName(), + handle.getLocation(), + handle.getMetadataEntry()); + return Optional.of(new CacheTableId(tableIdCodec.toJson(tableId))); + } + + @Override + public Optional getCacheColumnId(ConnectorTableHandle tableHandle, ColumnHandle columnHandle) + { + return Optional.of(new CacheColumnId(columnCodec.toJson((DeltaLakeColumnHandle) columnHandle))); + } + + @Override + public ConnectorTableHandle getCanonicalTableHandle(ConnectorTableHandle tableHandle) + { + DeltaLakeTableHandle deltaLakeTableHandle = (DeltaLakeTableHandle) tableHandle; + return new DeltaLakeTableHandle( + deltaLakeTableHandle.getSchemaName(), + deltaLakeTableHandle.getTableName(), + deltaLakeTableHandle.isManaged(), + deltaLakeTableHandle.getLocation(), + deltaLakeTableHandle.getMetadataEntry(), + deltaLakeTableHandle.getProtocolEntry(), + deltaLakeTableHandle.getEnforcedPartitionConstraint(), + /* + It overwrites `nonPartitionConstraint` because setting this property to `TupleDomain.all()` does not affect + final result when table is queried. It allows to match more similar subqueries that reads from same table + but has different predicates. 
+ */ + TupleDomain.all(), + deltaLakeTableHandle.getWriteType(), + deltaLakeTableHandle.getProjectedColumns(), + deltaLakeTableHandle.getUpdatedColumns(), + deltaLakeTableHandle.getUpdateRowIdColumns(), + deltaLakeTableHandle.getAnalyzeHandle(), + deltaLakeTableHandle.getReadVersion(), + deltaLakeTableHandle.isTimeTravel()); + } +} diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeCacheSplitId.java b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeCacheSplitId.java new file mode 100644 index 000000000000..1442a495ec61 --- /dev/null +++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeCacheSplitId.java @@ -0,0 +1,103 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.deltalake; + +import com.fasterxml.jackson.annotation.JsonProperty; +import io.trino.plugin.deltalake.transactionlog.DeletionVectorEntry; + +import java.util.Map; +import java.util.Optional; + +import static java.util.Objects.requireNonNull; + +public class DeltaLakeCacheSplitId +{ + private final String path; + private final long start; + private final long length; + private final long fileSize; + // DeltaLakePageSourceProvider.createPageSource uses fileRowCount field. 
This field is mutable and affects split result even though data files in DeltaLake are immutable + private final Optional fileRowCount; + private final long fileModifiedTime; + private final Map> partitionKeys; + private final Optional deletionVector; + + public DeltaLakeCacheSplitId( + String path, + long start, + long length, + long fileSize, + Optional fileRowCount, + long fileModifiedTime, + Map> partitionKeys, + Optional deletionVector) + { + this.path = requireNonNull(path, "path is null"); + this.start = start; + this.length = length; + this.fileSize = fileSize; + this.fileRowCount = requireNonNull(fileRowCount, "rowCount is null"); + this.fileModifiedTime = fileModifiedTime; + this.partitionKeys = requireNonNull(partitionKeys, "partitionKeys is null"); + this.deletionVector = requireNonNull(deletionVector, "deletionVector is null"); + } + + @JsonProperty + public String getPath() + { + return path; + } + + @JsonProperty + public long getStart() + { + return start; + } + + @JsonProperty + public long getLength() + { + return length; + } + + @JsonProperty + public long getFileSize() + { + return fileSize; + } + + @JsonProperty + public Optional getFileRowCount() + { + return fileRowCount; + } + + @JsonProperty + public long getFileModifiedTime() + { + return fileModifiedTime; + } + + @JsonProperty + public Map> getPartitionKeys() + { + return partitionKeys; + } + + @JsonProperty + public Optional getDeletionVector() + { + return deletionVector; + } +} diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeCacheTableId.java b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeCacheTableId.java new file mode 100644 index 000000000000..5d3d34436999 --- /dev/null +++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeCacheTableId.java @@ -0,0 +1,63 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the 
License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.deltalake; + +import com.fasterxml.jackson.annotation.JsonProperty; +import io.trino.plugin.deltalake.transactionlog.MetadataEntry; + +import static java.util.Objects.requireNonNull; + +public class DeltaLakeCacheTableId +{ + private final String schemaName; + private final String tableName; + private final String location; + private final MetadataEntry metadataEntry; + + public DeltaLakeCacheTableId( + String schemaName, + String tableName, + String location, + MetadataEntry metadataEntry) + { + this.schemaName = requireNonNull(schemaName, "schemaName is null"); + this.tableName = requireNonNull(tableName, "tableName is null"); + this.location = requireNonNull(location, "location is null"); + this.metadataEntry = requireNonNull(metadataEntry, "metadataEntry is null"); + } + + @JsonProperty + public String getSchemaName() + { + return schemaName; + } + + @JsonProperty + public String getTableName() + { + return tableName; + } + + @JsonProperty + public String getLocation() + { + return location; + } + + @JsonProperty + public MetadataEntry getMetadataEntry() + { + return metadataEntry; + } +} diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeConnector.java b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeConnector.java index db556ffb1e54..85f9d3772e4b 100644 --- a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeConnector.java +++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeConnector.java @@ 
-21,6 +21,7 @@ import io.trino.plugin.base.classloader.ClassLoaderSafeConnectorMetadata; import io.trino.plugin.base.session.SessionPropertiesProvider; import io.trino.plugin.hive.HiveTransactionHandle; +import io.trino.spi.cache.ConnectorCacheMetadata; import io.trino.spi.connector.Connector; import io.trino.spi.connector.ConnectorAccessControl; import io.trino.spi.connector.ConnectorCapabilities; @@ -57,6 +58,7 @@ public class DeltaLakeConnector private final Injector injector; private final LifeCycleManager lifeCycleManager; private final ConnectorSplitManager splitManager; + private final ConnectorCacheMetadata cacheMetadata; private final ConnectorPageSourceProvider pageSourceProvider; private final ConnectorPageSinkProvider pageSinkProvider; private final ConnectorNodePartitioningProvider nodePartitioningProvider; @@ -78,6 +80,7 @@ public DeltaLakeConnector( Injector injector, LifeCycleManager lifeCycleManager, ConnectorSplitManager splitManager, + ConnectorCacheMetadata cacheMetadata, ConnectorPageSourceProvider pageSourceProvider, ConnectorPageSinkProvider pageSinkProvider, ConnectorNodePartitioningProvider nodePartitioningProvider, @@ -96,6 +99,7 @@ public DeltaLakeConnector( this.injector = requireNonNull(injector, "injector is null"); this.lifeCycleManager = requireNonNull(lifeCycleManager, "lifeCycleManager is null"); this.splitManager = requireNonNull(splitManager, "splitManager is null"); + this.cacheMetadata = requireNonNull(cacheMetadata, "cacheMetadata is null"); this.pageSourceProvider = requireNonNull(pageSourceProvider, "pageSourceProvider is null"); this.pageSinkProvider = requireNonNull(pageSinkProvider, "pageSinkProvider is null"); this.nodePartitioningProvider = requireNonNull(nodePartitioningProvider, "nodePartitioningProvider is null"); @@ -128,6 +132,12 @@ public ConnectorSplitManager getSplitManager() return splitManager; } + @Override + public ConnectorCacheMetadata getCacheMetadata() + { + return cacheMetadata; + } + @Override public 
ConnectorPageSourceProvider getPageSourceProvider() { diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeConnectorFactory.java b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeConnectorFactory.java index 93f11c401674..9ea2bbceb83a 100644 --- a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeConnectorFactory.java +++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeConnectorFactory.java @@ -25,6 +25,7 @@ import io.trino.filesystem.manager.FileSystemModule; import io.trino.plugin.base.CatalogNameModule; import io.trino.plugin.base.classloader.ClassLoaderSafeConnectorAccessControl; +import io.trino.plugin.base.classloader.ClassLoaderSafeConnectorCacheMetadata; import io.trino.plugin.base.classloader.ClassLoaderSafeConnectorPageSinkProvider; import io.trino.plugin.base.classloader.ClassLoaderSafeConnectorPageSourceProvider; import io.trino.plugin.base.classloader.ClassLoaderSafeConnectorSplitManager; @@ -37,6 +38,7 @@ import io.trino.plugin.hive.NodeVersion; import io.trino.spi.NodeManager; import io.trino.spi.PageIndexerFactory; +import io.trino.spi.cache.ConnectorCacheMetadata; import io.trino.spi.catalog.CatalogName; import io.trino.spi.classloader.ThreadContextClassLoader; import io.trino.spi.connector.Connector; @@ -118,6 +120,7 @@ public static Connector createConnector( LifeCycleManager lifeCycleManager = injector.getInstance(LifeCycleManager.class); ConnectorSplitManager splitManager = injector.getInstance(ConnectorSplitManager.class); + ConnectorCacheMetadata cacheMetadata = injector.getInstance(ConnectorCacheMetadata.class); ConnectorPageSourceProvider connectorPageSource = injector.getInstance(ConnectorPageSourceProvider.class); ConnectorPageSinkProvider connectorPageSink = injector.getInstance(ConnectorPageSinkProvider.class); ConnectorNodePartitioningProvider connectorDistributionProvider = 
injector.getInstance(ConnectorNodePartitioningProvider.class); @@ -143,6 +146,7 @@ public static Connector createConnector( injector, lifeCycleManager, new ClassLoaderSafeConnectorSplitManager(splitManager, classLoader), + new ClassLoaderSafeConnectorCacheMetadata(cacheMetadata, classLoader), new ClassLoaderSafeConnectorPageSourceProvider(connectorPageSource, classLoader), new ClassLoaderSafeConnectorPageSinkProvider(connectorPageSink, classLoader), new ClassLoaderSafeNodePartitioningProvider(connectorDistributionProvider, classLoader), diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeModule.java b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeModule.java index 7c62bdb87765..03983e416f02 100644 --- a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeModule.java +++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeModule.java @@ -53,6 +53,10 @@ import io.trino.plugin.hive.metastore.thrift.TranslateHiveViews; import io.trino.plugin.hive.parquet.ParquetReaderConfig; import io.trino.plugin.hive.parquet.ParquetWriterConfig; +import io.trino.plugin.hive.util.BlockJsonSerde; +import io.trino.plugin.hive.util.HiveBlockEncodingSerde; +import io.trino.spi.block.Block; +import io.trino.spi.cache.ConnectorCacheMetadata; import io.trino.spi.catalog.CatalogName; import io.trino.spi.connector.ConnectorNodePartitioningProvider; import io.trino.spi.connector.ConnectorPageSinkProvider; @@ -70,6 +74,7 @@ import static com.google.inject.multibindings.OptionalBinder.newOptionalBinder; import static io.airlift.concurrent.Threads.daemonThreadsNamed; import static io.airlift.configuration.ConfigBinder.configBinder; +import static io.airlift.json.JsonBinder.jsonBinder; import static io.airlift.json.JsonCodecBinder.jsonCodecBinder; import static io.trino.plugin.base.ClosingBinder.closingBinder; import static io.trino.plugin.deltalake.DeltaLakeAccessControlMetadataFactory.SYSTEM; 
@@ -137,6 +142,18 @@ public void setup(Binder binder) newExporter(binder).export(FileFormatDataSourceStats.class) .as(generator -> generator.generatedNameOf(FileFormatDataSourceStats.class, catalogName.get().toString())); + binder.bind(ConnectorCacheMetadata.class).to(DeltaLakeCacheMetadata.class).in(Scopes.SINGLETON); + + // for table handle, column handle and split ids + jsonCodecBinder(binder).bindJsonCodec(DeltaLakeCacheTableId.class); + jsonCodecBinder(binder).bindJsonCodec(DeltaLakeCacheSplitId.class); + jsonCodecBinder(binder).bindJsonCodec(DeltaLakeColumnHandle.class); + + // bind block serializers for the purpose of TupleDomain serde + binder.bind(HiveBlockEncodingSerde.class).in(Scopes.SINGLETON); + jsonBinder(binder).addSerializerBinding(Block.class).to(BlockJsonSerde.Serializer.class); + jsonBinder(binder).addDeserializerBinding(Block.class).to(BlockJsonSerde.Deserializer.class); + Multibinder procedures = newSetBinder(binder, Procedure.class); procedures.addBinding().toProvider(DropExtendedStatsProcedure.class).in(Scopes.SINGLETON); procedures.addBinding().toProvider(VacuumProcedure.class).in(Scopes.SINGLETON); diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakePageSourceProvider.java b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakePageSourceProvider.java index 80d32276fb8a..fd2cb2ad527b 100644 --- a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakePageSourceProvider.java +++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakePageSourceProvider.java @@ -83,6 +83,7 @@ import static io.trino.plugin.deltalake.DeltaHiveTypeTranslator.toHiveType; import static io.trino.plugin.deltalake.DeltaLakeColumnHandle.ROW_ID_COLUMN_NAME; import static io.trino.plugin.deltalake.DeltaLakeColumnHandle.rowPositionColumnHandle; +import static io.trino.plugin.deltalake.DeltaLakeColumnType.PARTITION_KEY; import static 
io.trino.plugin.deltalake.DeltaLakeColumnType.REGULAR; import static io.trino.plugin.deltalake.DeltaLakeErrorCode.DELTA_LAKE_INVALID_SCHEMA; import static io.trino.plugin.deltalake.DeltaLakeSessionProperties.getParquetMaxReadBlockRowCount; @@ -179,20 +180,16 @@ public ConnectorPageSource createPageSource( // and the dynamic filter in the coordinator during split generation. The file level stats // in DeltaLakeSplit#statisticsPredicate could help to prune this split when a more selective dynamic filter // is available now, without having to access parquet file footer for row-group stats. - TupleDomain filteredSplitPredicate = TupleDomain.intersect(ImmutableList.of( - table.getNonPartitionConstraint(), - split.getStatisticsPredicate(), - dynamicFilter.getCurrentPredicate().transformKeys(DeltaLakeColumnHandle.class::cast))); - if (filteredSplitPredicate.isNone()) { - return new EmptyPageSource(); - } - Map partitionColumnDomains = filteredSplitPredicate.getDomains().orElseThrow().entrySet().stream() - .filter(entry -> entry.getKey().columnType() == DeltaLakeColumnType.PARTITION_KEY) - .collect(toImmutableMap(Map.Entry::getKey, Map.Entry::getValue)); - if (!partitionMatchesPredicate(split.getPartitionKeys(), partitionColumnDomains)) { + TupleDomain effectivePredicate = getUnenforcedPredicate( + session, + split, + table, + dynamicFilter.getCurrentPredicate()) + .transformKeys(DeltaLakeColumnHandle.class::cast); + if (effectivePredicate.isNone()) { return new EmptyPageSource(); } - if (filteredSplitPredicate.isAll() && + if (effectivePredicate.isAll() && split.getStart() == 0 && split.getLength() == split.getFileSize() && split.getFileRowCount().isPresent() && split.getDeletionVector().isEmpty() && @@ -237,7 +234,7 @@ public ConnectorPageSource createPageSource( hiveColumnHandles.add(PARQUET_ROW_INDEX_COLUMN); } - TupleDomain parquetPredicate = getParquetTupleDomain(filteredSplitPredicate.simplify(domainCompactionThreshold), columnMappingMode, parquetFieldIdToName); + 
TupleDomain parquetPredicate = getParquetTupleDomain(effectivePredicate, columnMappingMode, parquetFieldIdToName); ReaderPageSource pageSource = ParquetPageSourceFactory.createPageSource( inputFile, @@ -300,6 +297,46 @@ private PositionDeleteFilter readDeletes( } } + @Override + public TupleDomain getUnenforcedPredicate( + ConnectorSession connectorSession, + ConnectorSplit connectorSplit, + ConnectorTableHandle connectorTable, + TupleDomain dynamicFilter) + { + DeltaLakeSplit split = (DeltaLakeSplit) connectorSplit; + DeltaLakeTableHandle table = (DeltaLakeTableHandle) connectorTable; + + TupleDomain prunedPredicate = prunePredicate(connectorSession, connectorSplit, connectorTable, + TupleDomain.intersect(ImmutableList.of( + table.getNonPartitionConstraint(), + split.getStatisticsPredicate(), + dynamicFilter))); + return prunedPredicate.simplify(domainCompactionThreshold); + } + + @Override + public TupleDomain prunePredicate( + ConnectorSession connectorSession, + ConnectorSplit connectorSplit, + ConnectorTableHandle connectorTable, + TupleDomain predicate) + { + DeltaLakeSplit split = (DeltaLakeSplit) connectorSplit; + + TupleDomain predicateOnPartitioningColumn = predicate + .transformKeys(DeltaLakeColumnHandle.class::cast) + .filter((columnHandle, domain) -> columnHandle.columnType() == PARTITION_KEY); + + if (predicateOnPartitioningColumn.getDomains().isPresent() && !partitionMatchesPredicate(split.getPartitionKeys(), predicateOnPartitioningColumn.getDomains().get())) { + return TupleDomain.none(); + } + + return predicate.filter((columnHandle, domain) -> ((DeltaLakeColumnHandle) columnHandle).columnType() != PARTITION_KEY) + // remove domains from predicate that fully contain split data because they are irrelevant for filtering + .filter((handle, domain) -> !domain.contains(split.getStatisticsPredicate().getDomain((DeltaLakeColumnHandle) handle, domain.getType()))); + } + public Map loadParquetIdAndNameMapping(TrinoInputFile inputFile, ParquetReaderOptions 
options) { try (ParquetDataSource dataSource = new TrinoParquetDataSource(inputFile, options, fileFormatDataSourceStats)) { diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeSplitManager.java b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeSplitManager.java index 0d2c05cb0762..eca1aa8d682f 100644 --- a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeSplitManager.java +++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeSplitManager.java @@ -16,6 +16,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import com.google.inject.Inject; +import io.airlift.json.JsonCodec; import io.airlift.units.DataSize; import io.trino.filesystem.Location; import io.trino.filesystem.TrinoFileSystemFactory; @@ -29,8 +30,10 @@ import io.trino.plugin.deltalake.transactionlog.TransactionLogAccess; import io.trino.plugin.deltalake.transactionlog.statistics.DeltaLakeFileStatistics; import io.trino.spi.SplitWeight; +import io.trino.spi.cache.CacheSplitId; import io.trino.spi.connector.ColumnHandle; import io.trino.spi.connector.ConnectorSession; +import io.trino.spi.connector.ConnectorSplit; import io.trino.spi.connector.ConnectorSplitManager; import io.trino.spi.connector.ConnectorSplitSource; import io.trino.spi.connector.ConnectorTableHandle; @@ -81,6 +84,7 @@ public class DeltaLakeSplitManager private final int maxOutstandingSplits; private final double minimumAssignedSplitWeight; private final TrinoFileSystemFactory fileSystemFactory; + private final JsonCodec splitIdCodec; private final DeltaLakeTransactionManager deltaLakeTransactionManager; private final CachingHostAddressProvider cachingHostAddressProvider; @@ -91,6 +95,7 @@ public DeltaLakeSplitManager( ExecutorService executor, DeltaLakeConfig config, TrinoFileSystemFactory fileSystemFactory, + JsonCodec splitIdCodec, DeltaLakeTransactionManager 
deltaLakeTransactionManager, CachingHostAddressProvider cachingHostAddressProvider) { @@ -101,6 +106,7 @@ public DeltaLakeSplitManager( this.maxOutstandingSplits = config.getMaxOutstandingSplits(); this.minimumAssignedSplitWeight = config.getMinimumAssignedSplitWeight(); this.fileSystemFactory = requireNonNull(fileSystemFactory, "fileSystemFactory is null"); + this.splitIdCodec = requireNonNull(splitIdCodec, "splitIdCodec is null"); this.deltaLakeTransactionManager = requireNonNull(deltaLakeTransactionManager, "deltaLakeTransactionManager is null"); this.cachingHostAddressProvider = requireNonNull(cachingHostAddressProvider, "cacheHostAddressProvider is null"); } @@ -248,6 +254,36 @@ private Stream getSplits( }); } + @Override + public Optional getCacheSplitId(ConnectorSplit split) + { + DeltaLakeSplit deltaLakeSplit = (DeltaLakeSplit) split; + + // ensure cache id generation is revisited whenever split classes change + deltaLakeSplit = new DeltaLakeSplit( + deltaLakeSplit.getPath(), + deltaLakeSplit.getStart(), + deltaLakeSplit.getLength(), + deltaLakeSplit.getFileSize(), + deltaLakeSplit.getFileRowCount(), + deltaLakeSplit.getFileModifiedTime(), + deltaLakeSplit.getDeletionVector(), + // weight does not impact split rows + SplitWeight.standard(), + deltaLakeSplit.getStatisticsPredicate(), + deltaLakeSplit.getPartitionKeys()); + + return Optional.of(new CacheSplitId(splitIdCodec.toJson(new DeltaLakeCacheSplitId( + deltaLakeSplit.getPath(), + deltaLakeSplit.getStart(), + deltaLakeSplit.getLength(), + deltaLakeSplit.getFileSize(), + deltaLakeSplit.getFileRowCount(), + deltaLakeSplit.getFileModifiedTime(), + deltaLakeSplit.getPartitionKeys(), + deltaLakeSplit.getDeletionVector())))); + } + private static Stream filterValidDataFilesForOptimize(Stream validDataFiles, long maxScannedFileSizeInBytes) { // Value being present is a pending file (potentially the only one) for a given partition. 
diff --git a/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeAlluxioCacheFileOperations.java b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeAlluxioCacheFileOperations.java index 32bdc7c2cb7c..7e384f4e4f48 100644 --- a/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeAlluxioCacheFileOperations.java +++ b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeAlluxioCacheFileOperations.java @@ -46,7 +46,7 @@ public class TestDeltaLakeAlluxioCacheFileOperations protected DistributedQueryRunner createQueryRunner() throws Exception { - Path cacheDirectory = Files.createTempDirectory("cache"); + Path cacheDirectory = Files.createTempDirectory("deltalake-cache"); closeAfterClass(() -> deleteRecursively(cacheDirectory, ALLOW_INSECURE)); return DeltaLakeQueryRunner.builder() diff --git a/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeAlluxioCacheMinioAndHmsConnectorSmokeTest.java b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeAlluxioCacheMinioAndHmsConnectorSmokeTest.java index 77c16689cbae..9cf527945745 100644 --- a/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeAlluxioCacheMinioAndHmsConnectorSmokeTest.java +++ b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeAlluxioCacheMinioAndHmsConnectorSmokeTest.java @@ -39,7 +39,7 @@ public class TestDeltaLakeAlluxioCacheMinioAndHmsConnectorSmokeTest public void init() throws Exception { - cacheDirectory = Files.createTempDirectory("cache"); + cacheDirectory = Files.createTempDirectory("deltalake-cache"); super.init(); } diff --git a/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeCacheIds.java b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeCacheIds.java new file mode 100644 index 000000000000..ea37616f4d81 --- /dev/null +++ 
b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeCacheIds.java @@ -0,0 +1,460 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.deltalake; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import io.airlift.json.JsonCodec; +import io.airlift.json.JsonCodecFactory; +import io.airlift.json.ObjectMapperProvider; +import io.trino.filesystem.Location; +import io.trino.filesystem.cache.DefaultCachingHostAddressProvider; +import io.trino.filesystem.hdfs.HdfsFileSystemFactory; +import io.trino.filesystem.memory.MemoryFileSystemFactory; +import io.trino.hdfs.HdfsConfig; +import io.trino.hdfs.HdfsConfiguration; +import io.trino.hdfs.HdfsEnvironment; +import io.trino.hdfs.authentication.NoHdfsAuthentication; +import io.trino.plugin.base.TypeDeserializer; +import io.trino.plugin.base.metrics.FileFormatDataSourceStats; +import io.trino.plugin.deltalake.metastore.DeltaLakeTableMetadataScheduler; +import io.trino.plugin.deltalake.metastore.file.DeltaLakeFileMetastoreTableOperationsProvider; +import io.trino.plugin.deltalake.statistics.CachingExtendedStatisticsAccess; +import io.trino.plugin.deltalake.statistics.ExtendedStatistics; +import io.trino.plugin.deltalake.statistics.MetaDirStatisticsAccess; +import io.trino.plugin.deltalake.transactionlog.DeletionVectorEntry; +import 
io.trino.plugin.deltalake.transactionlog.MetadataEntry; +import io.trino.plugin.deltalake.transactionlog.ProtocolEntry; +import io.trino.plugin.deltalake.transactionlog.TransactionLogAccess; +import io.trino.plugin.deltalake.transactionlog.checkpoint.CheckpointSchemaManager; +import io.trino.plugin.deltalake.transactionlog.checkpoint.CheckpointWriterManager; +import io.trino.plugin.deltalake.transactionlog.checkpoint.LastCheckpoint; +import io.trino.plugin.deltalake.transactionlog.writer.NoIsolationSynchronizer; +import io.trino.plugin.deltalake.transactionlog.writer.TransactionLogSynchronizerManager; +import io.trino.plugin.deltalake.transactionlog.writer.TransactionLogWriterFactory; +import io.trino.plugin.hive.NodeVersion; +import io.trino.plugin.hive.metastore.HiveMetastoreFactory; +import io.trino.plugin.hive.parquet.ParquetReaderConfig; +import io.trino.spi.SplitWeight; +import io.trino.spi.block.Block; +import io.trino.spi.block.TestingBlockEncodingSerde; +import io.trino.spi.block.TestingBlockJsonSerde; +import io.trino.spi.predicate.TupleDomain; +import io.trino.spi.type.TestingTypeManager; +import io.trino.spi.type.Type; +import io.trino.spi.type.TypeManager; +import io.trino.testing.TestingConnectorContext; +import io.trino.testing.TestingNodeManager; +import org.apache.hadoop.conf.Configuration; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; +import org.junit.jupiter.api.parallel.Execution; + +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.OptionalInt; +import java.util.Set; +import java.util.concurrent.ScheduledExecutorService; + +import static com.google.common.util.concurrent.MoreExecutors.newDirectExecutorService; +import static io.trino.plugin.deltalake.DeltaLakeAnalyzeProperties.AnalyzeMode.FULL_REFRESH; +import static io.trino.plugin.deltalake.DeltaLakeTableHandle.WriteType.UPDATE; +import 
static io.trino.plugin.hive.HiveTestUtils.HDFS_FILE_SYSTEM_FACTORY; +import static io.trino.plugin.hive.HiveTestUtils.HDFS_FILE_SYSTEM_STATS; +import static io.trino.plugin.hive.metastore.file.TestingFileHiveMetastore.createTestingFileHiveMetastore; +import static io.trino.spi.predicate.Domain.singleValue; +import static io.trino.spi.type.BigintType.BIGINT; +import static io.trino.spi.type.IntegerType.INTEGER; +import static io.trino.type.InternalTypeManager.TESTING_TYPE_MANAGER; +import static java.util.concurrent.Executors.newScheduledThreadPool; +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS; +import static org.junit.jupiter.api.parallel.ExecutionMode.SAME_THREAD; + +@TestInstance(PER_CLASS) +@Execution(SAME_THREAD) +public class TestDeltaLakeCacheIds +{ + private static ScheduledExecutorService executorService = newScheduledThreadPool(1); + private DeltaLakeCacheMetadata metadata; + private DeltaLakeSplitManager splitManager; + + @BeforeAll + public void setup() + { + DeltaLakeConfig config = new DeltaLakeConfig(); + HdfsConfiguration hdfsConfiguration = (_, _) -> new Configuration(false); + HdfsEnvironment hdfsEnvironment = new HdfsEnvironment(hdfsConfiguration, new HdfsConfig(), new NoHdfsAuthentication()); + TestingConnectorContext context = new TestingConnectorContext(); + TypeManager typeManager = context.getTypeManager(); + CheckpointSchemaManager checkpointSchemaManager = new CheckpointSchemaManager(typeManager); + HdfsFileSystemFactory hdfsFileSystemFactory = new HdfsFileSystemFactory(hdfsEnvironment, HDFS_FILE_SYSTEM_STATS); + + FileFormatDataSourceStats fileFormatDataSourceStats = new FileFormatDataSourceStats(); + + TransactionLogAccess transactionLogAccess = new TransactionLogAccess( + typeManager, + checkpointSchemaManager, + config, + fileFormatDataSourceStats, + HDFS_FILE_SYSTEM_FACTORY, + new ParquetReaderConfig()); + CheckpointWriterManager checkpointWriterManager 
= new CheckpointWriterManager( + typeManager, + new CheckpointSchemaManager(typeManager), + hdfsFileSystemFactory, + new NodeVersion("test_version"), + transactionLogAccess, + new FileFormatDataSourceStats(), + JsonCodec.jsonCodec(LastCheckpoint.class)); + + HiveMetastoreFactory hiveMetastoreFactory = HiveMetastoreFactory.ofInstance(createTestingFileHiveMetastore(new MemoryFileSystemFactory(), Location.of("memory:///"))); + DeltaLakeMetadataFactory metadataFactory = new DeltaLakeMetadataFactory( + hiveMetastoreFactory, + hdfsFileSystemFactory, + transactionLogAccess, + typeManager, + DeltaLakeAccessControlMetadataFactory.DEFAULT, + config, + JsonCodec.jsonCodec(DataFileInfo.class), + JsonCodec.jsonCodec(DeltaLakeMergeResult.class), + new TransactionLogWriterFactory( + new TransactionLogSynchronizerManager(ImmutableMap.of(), new NoIsolationSynchronizer(hdfsFileSystemFactory))), + new TestingNodeManager(), + checkpointWriterManager, + DeltaLakeRedirectionsProvider.NOOP, + new CachingExtendedStatisticsAccess(new MetaDirStatisticsAccess(HDFS_FILE_SYSTEM_FACTORY, new JsonCodecFactory().jsonCodec(ExtendedStatistics.class))), + true, + new NodeVersion("test_version"), + new DeltaLakeTableMetadataScheduler(new TestingNodeManager(), TESTING_TYPE_MANAGER, new DeltaLakeFileMetastoreTableOperationsProvider(hiveMetastoreFactory), Integer.MAX_VALUE, new DeltaLakeConfig()), + newDirectExecutorService()); + metadata = new DeltaLakeCacheMetadata( + createJsonCodec(DeltaLakeCacheTableId.class), + createJsonCodec(DeltaLakeColumnHandle.class)); + splitManager = new DeltaLakeSplitManager( + typeManager, + transactionLogAccess, + newDirectExecutorService(), + config, + hdfsFileSystemFactory, + createJsonCodec(DeltaLakeCacheSplitId.class), + new DeltaLakeTransactionManager(metadataFactory), + new DefaultCachingHostAddressProvider()); + } + + @AfterAll + public void tearDown() + { + if (executorService != null) { + executorService.shutdownNow(); + executorService = null; + } + } + + @Test 
+ public void testTableId() + { + DeltaLakeColumnHandle partitionColumn = new DeltaLakeColumnHandle("col1", BIGINT, OptionalInt.empty(), "base_col1", BIGINT, DeltaLakeColumnType.PARTITION_KEY, Optional.empty()); + String schema = "{\"fields\": [{\"name\": \"value\", \"metadata\": {}}]}"; + + // table id for updating query is empty + assertThat(metadata.getCacheTableId(createDeltaLakeTableHandle( + createMetadataEntry("id", schema), + ImmutableSet.of(partitionColumn), + Optional.of(ImmutableList.of(partitionColumn)), + Optional.of(UPDATE))) + ).isEqualTo(Optional.empty()); + + // `managed` shouldn't be part of table id + assertThat(metadata.getCacheTableId(createDeltaLakeTableHandle("schema", "table", false, "location", 0))) + .isEqualTo(metadata.getCacheTableId(createDeltaLakeTableHandle("schema", "table", true, "location", 0))); + + // `location` should be part of table id - it is part of split + assertThat(metadata.getCacheTableId(createDeltaLakeTableHandle("schema", "table", false, "location", 0))) + .isNotEqualTo(metadata.getCacheTableId(createDeltaLakeTableHandle("schema", "table", false, "location2", 0))); + + // enforced predicate shouldn't be part of table id + assertThat(metadata.getCacheTableId(createDeltaLakeTableHandle(TupleDomain.withColumnDomains(ImmutableMap.of(partitionColumn, singleValue(INTEGER, 1L))), TupleDomain.all()))) + .isEqualTo(metadata.getCacheTableId(createDeltaLakeTableHandle(TupleDomain.withColumnDomains(ImmutableMap.of(partitionColumn, singleValue(INTEGER, 2L))), TupleDomain.all()))); + + // metadataEntry should be part of table id + assertThat(metadata.getCacheTableId(createDeltaLakeTableHandle( + createMetadataEntry("id1", schema), + ImmutableSet.of(), + Optional.empty(), + Optional.empty())) + ).isNotEqualTo(metadata.getCacheTableId(createDeltaLakeTableHandle(createMetadataEntry("id2", schema), ImmutableSet.of(), Optional.empty(), Optional.empty()))); + + // projectedColumns shouldn't be part of table id + 
assertThat(metadata.getCacheTableId(createDeltaLakeTableHandle(createMetadataEntry("id", schema), ImmutableSet.of(partitionColumn), Optional.empty(), Optional.empty()))) + .isEqualTo(metadata.getCacheTableId(createDeltaLakeTableHandle(createMetadataEntry("id", schema), ImmutableSet.of(), Optional.empty(), Optional.empty()))); + + // nonPartitionConstraint should not be part of table id + assertThat(metadata.getCacheTableId(createDeltaLakeTableHandle(TupleDomain.all(), TupleDomain.withColumnDomains(ImmutableMap.of(partitionColumn, singleValue(BIGINT, 1L)))))) + .isEqualTo(metadata.getCacheTableId(createDeltaLakeTableHandle(TupleDomain.all(), TupleDomain.withColumnDomains(ImmutableMap.of(partitionColumn, singleValue(INTEGER, 2L)))))); + + // readVersion predicate should not be part of table id + assertThat(metadata.getCacheTableId(createDeltaLakeTableHandle("schema", "table", false, "location", 0))) + .isEqualTo(metadata.getCacheTableId(createDeltaLakeTableHandle("schema", "table", false, "location", 1))); + + // schema should be part of table id + assertThat(metadata.getCacheTableId(createDeltaLakeTableHandle("schema", "table", false, "location", 0))) + .isNotEqualTo(metadata.getCacheTableId(createDeltaLakeTableHandle("schema2", "table", false, "location", 0))); + + // table should be part of table id + assertThat(metadata.getCacheTableId(createDeltaLakeTableHandle("schema", "table", false, "location", 0))) + .isNotEqualTo(metadata.getCacheTableId(createDeltaLakeTableHandle("schema", "table2", false, "location", 0))); + + // writing queries should result in empty cacheTableId + assertThat(metadata.getCacheTableId(createDeltaLakeTableHandle(createMetadataEntry("id", schema), ImmutableSet.of(), Optional.of(ImmutableList.of(partitionColumn)), Optional.of(UPDATE)))) + .isEmpty(); + + // analyze queries should result in empty cacheTableId + assertThat(metadata.getCacheTableId(createDeltaLakeTableHandle( + "schema", + "table", + false, + "location", + 0, + 
TupleDomain.all(), + TupleDomain.all(), + createMetadataEntry("id", schema), + Optional.empty(), + Optional.empty(), + Optional.empty(), + Optional.of(new AnalyzeHandle(FULL_REFRESH, Optional.empty(), Optional.empty()))))) + .isEmpty(); + } + + @Test + public void testSplitId() + { + // different path should make ids different + assertThat(splitManager.getCacheSplitId(createDeltaLakeSplit("path", 10, 0, 1024, 10))) + .isNotEqualTo(splitManager.getCacheSplitId(createDeltaLakeSplit("path2", 10, 0, 1024, 10))); + + // different length should make ids different + assertThat(splitManager.getCacheSplitId(createDeltaLakeSplit("path", 10, 0, 1024, 10))) + .isNotEqualTo(splitManager.getCacheSplitId(createDeltaLakeSplit("path", 11, 0, 1024, 10))); + + // different start position should make ids different + assertThat(splitManager.getCacheSplitId(createDeltaLakeSplit("path", 10, 0, 1024, 10))) + .isNotEqualTo(splitManager.getCacheSplitId(createDeltaLakeSplit("path", 10, 1, 1024, 10))); + + // different fileSize should make ids different + assertThat(splitManager.getCacheSplitId(createDeltaLakeSplit("path", 10, 0, 1024, 10))) + .isNotEqualTo(splitManager.getCacheSplitId(createDeltaLakeSplit("path", 10, 0, 512, 10))); + + // different fileModifiedTime should make ids different + assertThat(splitManager.getCacheSplitId(createDeltaLakeSplit("path", 10, 0, 1024, 10))) + .isNotEqualTo(splitManager.getCacheSplitId(createDeltaLakeSplit("path", 10, 0, 1024, 20))); + + // different partitionKeys should make ids different + assertThat(splitManager.getCacheSplitId(createDeltaLakeSplit(SplitWeight.standard(), TupleDomain.all(), ImmutableMap.of("partitionKet", Optional.empty())))) + .isNotEqualTo(splitManager.getCacheSplitId(createDeltaLakeSplit(SplitWeight.standard(), TupleDomain.all(), ImmutableMap.of()))); + + // different split weight should not make ids different + assertThat(splitManager.getCacheSplitId(createDeltaLakeSplit(SplitWeight.fromProportion(0.1), TupleDomain.all(), 
ImmutableMap.of()))) + .isEqualTo(splitManager.getCacheSplitId(createDeltaLakeSplit(SplitWeight.fromProportion(0.11), TupleDomain.all(), ImmutableMap.of()))); + + // different row count should make ids different + assertThat(splitManager.getCacheSplitId(createDeltaLakeSplit(SplitWeight.fromProportion(0.1), TupleDomain.all(), ImmutableMap.of(), Optional.of(10L), Optional.empty()))) + .isNotEqualTo(splitManager.getCacheSplitId(createDeltaLakeSplit(SplitWeight.fromProportion(0.1), TupleDomain.all(), ImmutableMap.of(), Optional.of(11L), Optional.empty()))); + + // different deletion vector should make ids different + assertThat(splitManager.getCacheSplitId(createDeltaLakeSplit(SplitWeight.standard(), TupleDomain.all(), ImmutableMap.of(), Optional.of(10L), Optional.of(new DeletionVectorEntry("", "", OptionalInt.of(10), 0, 0L))))) + .isNotEqualTo(splitManager.getCacheSplitId(createDeltaLakeSplit(SplitWeight.standard(), TupleDomain.all(), ImmutableMap.of(), Optional.of(10L), Optional.of(new DeletionVectorEntry("", "", OptionalInt.of(11), 0, 0L))))); + } + + private static DeltaLakeSplit createDeltaLakeSplit( + SplitWeight splitWeight, + TupleDomain statisticsPredicate, + Map> partitionKeys) + { + return new DeltaLakeSplit( + "path", + 0, + 1024, + 1024, + Optional.of(10L), + 10, + Optional.empty(), + splitWeight, + statisticsPredicate, + partitionKeys); + } + + private static DeltaLakeSplit createDeltaLakeSplit( + SplitWeight splitWeight, + TupleDomain statisticsPredicate, + Map> partitionKeys, + Optional rowCount, + Optional deletionVectorEntry) + { + return new DeltaLakeSplit( + "path", + 0, + 1024, + 1024, + rowCount, + 10, + deletionVectorEntry, + splitWeight, + statisticsPredicate, + partitionKeys); + } + + private static DeltaLakeSplit createDeltaLakeSplit( + String path, + long length, + long start, + long fileSize, + long fileModifiedTime) + { + return new DeltaLakeSplit( + path, + start, + length, + fileSize, + Optional.empty(), + fileModifiedTime, + 
Optional.empty(), + SplitWeight.standard(), + TupleDomain.all(), + ImmutableMap.of()); + } + + private static DeltaLakeTableHandle createDeltaLakeTableHandle( + String schemaName, + String tableName, + boolean managed, + String location, + long readVersion) + { + return createDeltaLakeTableHandle( + schemaName, + tableName, + managed, + location, + readVersion, + TupleDomain.all(), + TupleDomain.all(), + createMetadataEntry("id", "{\"fields\": [{\"name\": \"value\", \"metadata\": {}}]}"), + Optional.empty(), + Optional.empty(), + Optional.empty(), + Optional.empty()); + } + + private static DeltaLakeTableHandle createDeltaLakeTableHandle( + MetadataEntry entry, + Set projectedColumns, + Optional> updatedColumns, + Optional writeType) + { + return createDeltaLakeTableHandle( + "schema", + "table", + true, + "location", + 0, + TupleDomain.all(), + TupleDomain.all(), + entry, + writeType, + Optional.of(projectedColumns), + updatedColumns, + Optional.empty()); + } + + private static DeltaLakeTableHandle createDeltaLakeTableHandle( + TupleDomain enforcedPartitionConstraint, + TupleDomain nonPartitionConstraint) + { + return createDeltaLakeTableHandle( + "schema", + "table", + true, + "location", + 0, + enforcedPartitionConstraint, + nonPartitionConstraint, + createMetadataEntry("id", "{\"fields\": [{\"name\": \"value\", \"metadata\": {}}]}"), + Optional.empty(), + Optional.empty(), + Optional.empty(), + Optional.empty()); + } + + private static DeltaLakeTableHandle createDeltaLakeTableHandle( + String schemaName, + String tableName, + boolean managed, + String location, + long readVersion, + TupleDomain enforcedPartitionConstraint, + TupleDomain nonPartitionConstraint, + MetadataEntry metadataEntry, + Optional writeType, + Optional> projectedColumns, + Optional> updatedColumns, + Optional analyzeHandle) + { + return new DeltaLakeTableHandle( + schemaName, + tableName, + managed, + location, + metadataEntry, + new ProtocolEntry(3, 7, Optional.empty(), Optional.empty()), 
+ enforcedPartitionConstraint, + nonPartitionConstraint, + writeType, + projectedColumns, + updatedColumns, + updatedColumns, + analyzeHandle, + readVersion, + false); + } + + private static MetadataEntry createMetadataEntry(String id, String schema) + { + return new MetadataEntry( + id, + "name", + "description", + new MetadataEntry.Format("provider", ImmutableMap.of()), + schema, + ImmutableList.of(), + ImmutableMap.of(), + 0); + } + + static JsonCodec createJsonCodec(Class clazz) + { + ObjectMapperProvider objectMapperProvider = new ObjectMapperProvider(); + TypeDeserializer typeDeserializer = new TypeDeserializer(new TestingTypeManager()); + objectMapperProvider.setJsonDeserializers( + ImmutableMap.of( + Block.class, new TestingBlockJsonSerde.Deserializer(new TestingBlockEncodingSerde()), + Type.class, typeDeserializer)); + objectMapperProvider.setJsonSerializers(ImmutableMap.of(Block.class, new TestingBlockJsonSerde.Serializer(new TestingBlockEncodingSerde()))); + return new JsonCodecFactory(objectMapperProvider).jsonCodec(clazz); + } +} diff --git a/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeCacheSubqueriesTest.java b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeCacheSubqueriesTest.java new file mode 100644 index 000000000000..c97e937727df --- /dev/null +++ b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeCacheSubqueriesTest.java @@ -0,0 +1,200 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.deltalake; + +import com.google.common.collect.ImmutableList; +import io.trino.Session; +import io.trino.operator.TableScanOperator; +import io.trino.testing.BaseCacheSubqueriesTest; +import io.trino.testing.QueryRunner; +import io.trino.testing.QueryRunner.MaterializedResultWithPlan; +import io.trino.testing.sql.TestTable; +import org.intellij.lang.annotations.Language; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; +import org.junit.jupiter.api.parallel.Execution; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.List; +import java.util.Objects; +import java.util.stream.Stream; + +import static io.trino.plugin.tpch.TpchMetadata.TINY_SCHEMA_NAME; +import static io.trino.testing.QueryAssertions.assertEqualsIgnoreOrder; +import static io.trino.testing.QueryAssertions.copyTpchTables; +import static java.lang.String.format; +import static java.nio.file.StandardCopyOption.REPLACE_EXISTING; +import static java.util.stream.Collectors.joining; +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS; +import static org.junit.jupiter.api.parallel.ExecutionMode.SAME_THREAD; + +@TestInstance(PER_CLASS) +@Execution(SAME_THREAD) +public class TestDeltaLakeCacheSubqueriesTest + extends BaseCacheSubqueriesTest +{ + private Path deletionVectorTablePath; + + @Override + protected QueryRunner createQueryRunner() + throws Exception + { + QueryRunner queryRunner = DeltaLakeQueryRunner.builder() + .setExtraProperties(EXTRA_PROPERTIES) + .addDeltaProperty("delta.register-table-procedure.enabled", "true") + .addDeltaProperty("delta.enable-non-concurrent-writes", "true") + .addDeltaProperty("delta.dynamic-filtering.wait-timeout", "20s") + .build(); + 
copyTpchTables(queryRunner, "tpch", TINY_SCHEMA_NAME, queryRunner.getDefaultSession(), REQUIRED_TABLES); + return queryRunner; + } + + @BeforeAll + public void registerTables() + throws IOException + { + String deletionVectors = "deletion_vectors"; + Path tempDirectory = Files.createTempDirectory(deletionVectors); + deletionVectorTablePath = tempDirectory.resolve("deltalake/cache/table_with_deletion_vector"); + copyResources(Path.of(getClass().getResource("/deltalake/cache/table_with_deletion_vector").getPath()), tempDirectory); + getQueryRunner().execute(format("CALL system.register_table('%s', '%s', '%s')", getSession().getSchema().orElseThrow(), deletionVectors, deletionVectorTablePath.toString())); + } + + @Test + public void testDoNotUseCacheAfterInsert() + { + try (TestTable testTable = new TestTable( + getQueryRunner()::execute, + "test_do_not_use_cache", + "(name VARCHAR)", + ImmutableList.of("'value1'", "'value2'"))) { + @Language("SQL") String selectQuery = "select name from %s union all select name from %s".formatted(testTable.getName(), testTable.getName()); + MaterializedResultWithPlan result = executeWithPlan(withCacheEnabled(), selectQuery); + assertEqualsIgnoreOrder(result.result().getMaterializedRows().stream().map(row -> row.getField(0)).toList(), ImmutableList.of("value1", "value2", "value1", "value2")); + assertThat(getOperatorInputPositions(result.queryId(), TableScanOperator.class.getSimpleName())).isPositive(); + + assertUpdate("insert into %s(name) values ('value3')".formatted(testTable.getName()), 1); + result = executeWithPlan(withCacheEnabled(), selectQuery); + + assertThat(result.result().getRowCount()).isEqualTo(6); + assertThat(getScanOperatorInputPositions(result.queryId())).isPositive(); + assertThat(getLoadCachedDataOperatorInputPositions(result.queryId())).isPositive(); + + result = executeWithPlan(withCacheEnabled(), selectQuery); + assertThat(getLoadCachedDataOperatorInputPositions(result.queryId())).isEqualTo(6); + 
assertThat(getScanOperatorInputPositions(result.queryId())).isZero(); + } + } + + @Test + public void testCacheWhenSchemaEvolved() + { + computeActual("create table orders2 with (column_mapping_mode='name') as select orderkey, orderdate, orderpriority from orders limit 100"); + @Language("SQL") String query = "select * from orders2 union all select * from orders2"; + + MaterializedResultWithPlan result = executeWithPlan(withCacheEnabled(), query); + assertThat(result.result().getMaterializedRows().get(0).getFieldCount()).isEqualTo(3); + assertThat(getCacheDataOperatorInputPositions(result.queryId())).isPositive(); + + // add a nullable column - schema will be evolved + assertUpdate("alter table orders2 add column c varchar"); + + // should not use cache because of schema was evolved + result = executeWithPlan(withCacheEnabled(), query); + assertThat(result.result().getMaterializedRows().get(0).getFieldCount()).isEqualTo(4); + assertThat(result.result().getMaterializedRows().stream().map(row -> row.getField(3))).allMatch(Objects::isNull); + + // drop a column - schema will be evolved + assertUpdate("alter table orders2 drop column c"); + + // should not use cache because of schema was evolved + result = executeWithPlan(withCacheEnabled(), query); + assertThat(result.result().getMaterializedRows().get(0).getFieldCount()).isEqualTo(3); + + assertUpdate("drop table orders2"); + } + + @Test + public void testDeletionVectors() + throws IOException + { + @Language("SQL") String query = "select * from (select * from deletion_vectors union all select * from deletion_vectors) order by 1"; + MaterializedResultWithPlan result = executeWithPlan(withCacheEnabled(), query); + assertThat(result.result().getRowCount()).isEqualTo(4); + assertThat(getCacheDataOperatorInputPositions(result.queryId())).isPositive(); + + // should use only cache + result = executeWithPlan(withCacheEnabled(), query); + assertThat(result.result().getRowCount()).isEqualTo(4); + 
assertThat(getScanOperatorInputPositions(result.queryId())).isZero(); + assertThat(getLoadCachedDataOperatorInputPositions(result.queryId())).isPositive(); + + // simulate that row was deleted with deletion vector + Files.copy( + Path.of(getClass().getResource("/deltalake/cache/00000000000000000002.json").getPath()), + Path.of(deletionVectorTablePath.toString(), "_delta_log", "00000000000000000002.json"), + REPLACE_EXISTING); + + // cache should not be used because of deletion vector presence + result = executeWithPlan(withCacheEnabled(), query); + assertThat(result.result().getRowCount()).isEqualTo(2); + assertThat(result.result().getMaterializedRows()).map(row -> row.getField(0)).matches(values -> !values.contains(2)); + assertThat(getScanOperatorInputPositions(result.queryId())).isPositive(); + } + + @Override + protected void createPartitionedTableAsSelect(String tableName, List partitionColumns, String asSelect) + { + @Language("SQL") String sql = format( + "CREATE TABLE %s WITH (partitioned_by=array[%s]) as %s", + tableName, + partitionColumns.stream().map(column -> "'" + column + "'").collect(joining(",")), + asSelect); + + getQueryRunner().execute(sql); + } + + @Override + protected Session withProjectionPushdownEnabled(Session session, boolean projectionPushdownEnabled) + { + return Session.builder(session) + .setSystemProperty("delta.projection_pushdown_enabled", String.valueOf(projectionPushdownEnabled)) + .build(); + } + + private void copyResources(Path resourceDirectory, Path destinationDirectory) + throws IOException + { + try (Stream files = Files.walk(resourceDirectory)) { + files.forEach(input -> { + try { + Path destination = destinationDirectory.resolve(input.subpath(resourceDirectory.getNameCount() - 3, input.getNameCount()).toString()); + if (Files.isDirectory(input)) { + Files.createDirectories(destination); + } + else { + Files.copy(input, destination); + } + } + catch (IOException e) { + throw new RuntimeException(e); + } + }); + } + } +} 
diff --git a/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeSplitManager.java b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeSplitManager.java index 779318497889..2432a5f10352 100644 --- a/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeSplitManager.java +++ b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeSplitManager.java @@ -65,6 +65,7 @@ import java.util.stream.Stream; import static com.google.common.util.concurrent.MoreExecutors.newDirectExecutorService; +import static io.trino.plugin.deltalake.TestDeltaLakeCacheIds.createJsonCodec; import static io.trino.plugin.hive.HiveTestUtils.HDFS_ENVIRONMENT; import static io.trino.plugin.hive.HiveTestUtils.HDFS_FILE_SYSTEM_FACTORY; import static io.trino.plugin.hive.HiveTestUtils.HDFS_FILE_SYSTEM_STATS; @@ -247,6 +248,7 @@ public Stream getActiveFiles( newDirectExecutorService(), deltaLakeConfig, HDFS_FILE_SYSTEM_FACTORY, + createJsonCodec(DeltaLakeCacheSplitId.class), deltaLakeTransactionManager, new DefaultCachingHostAddressProvider()); } diff --git a/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaPageSourceProvider.java b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaPageSourceProvider.java new file mode 100644 index 000000000000..143f5cdcc026 --- /dev/null +++ b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaPageSourceProvider.java @@ -0,0 +1,226 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.deltalake; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.trino.filesystem.local.LocalFileSystemFactory; +import io.trino.plugin.base.metrics.FileFormatDataSourceStats; +import io.trino.plugin.deltalake.transactionlog.MetadataEntry; +import io.trino.plugin.deltalake.transactionlog.ProtocolEntry; +import io.trino.plugin.hive.parquet.ParquetReaderConfig; +import io.trino.spi.SplitWeight; +import io.trino.spi.connector.ColumnHandle; +import io.trino.spi.predicate.Domain; +import io.trino.spi.predicate.TupleDomain; +import io.trino.spi.type.Type; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; +import org.junit.jupiter.api.parallel.Execution; + +import java.io.IOException; +import java.nio.file.Files; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.OptionalInt; + +import static io.trino.SessionTestUtils.TEST_SESSION; +import static io.trino.spi.type.BigintType.BIGINT; +import static io.trino.type.InternalTypeManager.TESTING_TYPE_MANAGER; +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS; +import static org.junit.jupiter.api.parallel.ExecutionMode.CONCURRENT; + +@TestInstance(PER_CLASS) +@Execution(CONCURRENT) +public class TestDeltaPageSourceProvider +{ + private DeltaLakePageSourceProvider pageSourceProvider; + + @BeforeAll + public void setUp() + throws IOException + { + pageSourceProvider = new DeltaLakePageSourceProvider( + new LocalFileSystemFactory(Files.createTempDirectory("prefix")), + new FileFormatDataSourceStats(), + new ParquetReaderConfig(), + new DeltaLakeConfig(), + TESTING_TYPE_MANAGER); + } + + @Test + public void testPrunePredicate() + { + String 
partitionedColumn = "partitionedColumn"; + ColumnHandle partitionedColumnHandle = prepareColumnHandle(partitionedColumn, BIGINT, DeltaLakeColumnType.PARTITION_KEY); + ColumnHandle regularColumnHandle = prepareColumnHandle("regular", BIGINT, DeltaLakeColumnType.REGULAR); + MetadataEntry metadataEntry = createMetadataEntry( + ImmutableList.of(partitionedColumn), + """ + {"fields":[{"name":"partitionedColumn","type":"long","nullable":false,"metadata":{}}]}" + """); + DeltaLakeTableHandle tableHandle = createDeltaLakeTableHandle(metadataEntry, TupleDomain.all()); + TupleDomain predicate = TupleDomain.withColumnDomains(ImmutableMap.of( + partitionedColumnHandle, Domain.singleValue(BIGINT, 0L), + regularColumnHandle, Domain.singleValue(BIGINT, 0L))); + TupleDomain prunedPredicate = pageSourceProvider.prunePredicate( + TEST_SESSION.toConnectorSession(), + prepareSplit(TupleDomain.all(), ImmutableMap.of("partitionedColumn", Optional.of("0"))), + tableHandle, + predicate); + assertThat(prunedPredicate).isEqualTo(TupleDomain.withColumnDomains(ImmutableMap.of(regularColumnHandle, Domain.singleValue(BIGINT, 0L)))); + + // prune data column domain if domain fully contains split data + assertThat(pageSourceProvider.prunePredicate( + TEST_SESSION.toConnectorSession(), + prepareSplit( + TupleDomain.withColumnDomains(ImmutableMap.of(regularColumnHandle, Domain.multipleValues(BIGINT, ImmutableList.of(0L)))), + ImmutableMap.of("partitionedColumn", Optional.of("0"))), + tableHandle, + predicate)) + .isEqualTo(TupleDomain.all()); + } + + @Test + public void testPrunePredicateWhenSplitIsFilteredOut() + { + String partitionedColumn = "partitionedColumn"; + ColumnHandle partitionedColumnHandle = prepareColumnHandle(partitionedColumn, BIGINT, DeltaLakeColumnType.PARTITION_KEY); + ColumnHandle regularColumnHandle = prepareColumnHandle("regular", BIGINT, DeltaLakeColumnType.REGULAR); + MetadataEntry metadataEntry = createMetadataEntry( + ImmutableList.of(partitionedColumn), + """ + 
{"fields":[{"name":"partitionedColumn","type":"long","nullable":false,"metadata":{}}]}" + """); + DeltaLakeTableHandle tableHandle = createDeltaLakeTableHandle(metadataEntry, TupleDomain.all()); + TupleDomain predicate = TupleDomain.withColumnDomains(ImmutableMap.of( + partitionedColumnHandle, Domain.singleValue(BIGINT, 0L), + regularColumnHandle, Domain.singleValue(BIGINT, 0L))); + TupleDomain prunedPredicate = pageSourceProvider.prunePredicate( + TEST_SESSION.toConnectorSession(), + prepareSplit(TupleDomain.all(), ImmutableMap.of("partitionedColumn", Optional.of("1"))), + tableHandle, + predicate); + assertThat(prunedPredicate).isEqualTo(TupleDomain.none()); + prunedPredicate = pageSourceProvider.prunePredicate( + TEST_SESSION.toConnectorSession(), + prepareSplit(TupleDomain.all(), ImmutableMap.of("partitionedColumn", Optional.of("0"))), + tableHandle, + predicate); + assertThat(prunedPredicate).isNotEqualTo(TupleDomain.none()); + } + + @Test + public void testGetUnenforcedPredicate() + { + ColumnHandle regularColumnHandle = prepareColumnHandle("regular", BIGINT, DeltaLakeColumnType.REGULAR); + MetadataEntry metadataEntry = createMetadataEntry( + ImmutableList.of(), + """ + {"fields":[{"name":"partitionedColumn","type":"long","nullable":false,"metadata":{}}]}" + """); + + TupleDomain unenforcedPredicate = pageSourceProvider.getUnenforcedPredicate( + TEST_SESSION.toConnectorSession(), + prepareSplit(TupleDomain.withColumnDomains(ImmutableMap.of(regularColumnHandle, Domain.multipleValues(BIGINT, ImmutableList.of(0L, 100L)))), ImmutableMap.of()), + createDeltaLakeTableHandle(metadataEntry, TupleDomain.withColumnDomains(ImmutableMap.of(regularColumnHandle, Domain.singleValue(BIGINT, 0L)))), + TupleDomain.all()); + assertThat(unenforcedPredicate).isEqualTo(TupleDomain.withColumnDomains(ImmutableMap.of(regularColumnHandle, Domain.singleValue(BIGINT, 0L)))); + + unenforcedPredicate = pageSourceProvider.getUnenforcedPredicate( + TEST_SESSION.toConnectorSession(), + 
prepareSplit(TupleDomain.withColumnDomains(ImmutableMap.of(regularColumnHandle, Domain.multipleValues(BIGINT, ImmutableList.of(0L, 100L)))), ImmutableMap.of()), + createDeltaLakeTableHandle(metadataEntry, TupleDomain.all()), + TupleDomain.withColumnDomains(ImmutableMap.of(regularColumnHandle, Domain.multipleValues(BIGINT, ImmutableList.of(0L, 1L, 2L))))); + assertThat(unenforcedPredicate).isEqualTo(TupleDomain.withColumnDomains(ImmutableMap.of(regularColumnHandle, Domain.singleValue(BIGINT, 0L)))); + + unenforcedPredicate = pageSourceProvider.getUnenforcedPredicate( + TEST_SESSION.toConnectorSession(), + prepareSplit(TupleDomain.all(), ImmutableMap.of()), + createDeltaLakeTableHandle(metadataEntry, TupleDomain.withColumnDomains(ImmutableMap.of(regularColumnHandle, Domain.multipleValues(BIGINT, ImmutableList.of(0L))))), + TupleDomain.withColumnDomains(ImmutableMap.of(regularColumnHandle, Domain.multipleValues(BIGINT, ImmutableList.of(0L, 1L, 2L))))); + assertThat(unenforcedPredicate).isEqualTo(TupleDomain.withColumnDomains(ImmutableMap.of(regularColumnHandle, Domain.singleValue(BIGINT, 0L)))); + + unenforcedPredicate = pageSourceProvider.getUnenforcedPredicate( + TEST_SESSION.toConnectorSession(), + prepareSplit(TupleDomain.all(), ImmutableMap.of()), + createDeltaLakeTableHandle(metadataEntry, TupleDomain.withColumnDomains(ImmutableMap.of(regularColumnHandle, Domain.multipleValues(BIGINT, ImmutableList.of(0L, 1L))))), + TupleDomain.withColumnDomains(ImmutableMap.of(regularColumnHandle, Domain.multipleValues(BIGINT, ImmutableList.of(0L))))); + assertThat(unenforcedPredicate).isEqualTo(TupleDomain.withColumnDomains(ImmutableMap.of(regularColumnHandle, Domain.singleValue(BIGINT, 0L)))); + } + + private static ColumnHandle prepareColumnHandle(String name, Type type, DeltaLakeColumnType columnType) + { + return new DeltaLakeColumnHandle( + name, + type, + OptionalInt.empty(), + name, + type, + columnType, + Optional.empty() + ); + } + + private static DeltaLakeSplit 
prepareSplit(TupleDomain statisticsPredicate, Map> partitioningKeys) + { + return new DeltaLakeSplit( + "", + 0, + 0, + 0, + Optional.empty(), + 0, + Optional.empty(), + SplitWeight.standard(), + statisticsPredicate.transformKeys(DeltaLakeColumnHandle.class::cast), + partitioningKeys); + } + + private static DeltaLakeTableHandle createDeltaLakeTableHandle(MetadataEntry metadataEntry, TupleDomain nonPartitionConstraint) + { + return new DeltaLakeTableHandle( + "schema", + "table", + false, + "test_location", + metadataEntry, + new ProtocolEntry(1, 2, Optional.empty(), Optional.empty()), + TupleDomain.all(), + nonPartitionConstraint.transformKeys(DeltaLakeColumnHandle.class::cast), + Optional.empty(), + Optional.empty(), + Optional.empty(), + Optional.empty(), + Optional.empty(), + 0, + false); + } + + private static MetadataEntry createMetadataEntry(List partitionedColumns, String schema) + { + return new MetadataEntry( + "test_id", + "test_name", + "test_description", + new MetadataEntry.Format("test_provider", ImmutableMap.of()), + schema, + partitionedColumns, + ImmutableMap.of(), + 1); + } +} diff --git a/plugin/trino-delta-lake/src/test/resources/deltalake/cache/00000000000000000002.json b/plugin/trino-delta-lake/src/test/resources/deltalake/cache/00000000000000000002.json new file mode 100644 index 000000000000..00f135f1c8d2 --- /dev/null +++ b/plugin/trino-delta-lake/src/test/resources/deltalake/cache/00000000000000000002.json @@ -0,0 +1,3 @@ +{"commitInfo":{"timestamp":1682326592314,"operation":"DELETE","operationParameters":{"predicate":"[\"(spark_catalog.default.test_deletion_vectors_vsipbnhjjg.a = 
2)\"]"},"readVersion":1,"isolationLevel":"WriteSerializable","isBlindAppend":false,"operationMetrics":{"numRemovedFiles":"0","numRemovedBytes":"0","numCopiedRows":"0","numDeletionVectorsAdded":"1","numDeletionVectorsRemoved":"0","numAddedChangeFiles":"0","executionTimeMs":"2046","numDeletedRows":"1","scanTimeMs":"1335","numAddedFiles":"0","numAddedBytes":"0","rewriteTimeMs":"709"},"engineInfo":"Databricks-Runtime/12.2.x-scala2.12","txnId":"219ffc4f-ff84-49d6-98a3-b0b105ce2a1e"}} +{"remove":{"path":"part-00000-0aa47759-3062-4e53-94c8-2e20a0796fee-c000.snappy.parquet","deletionTimestamp":1682326592313,"dataChange":true,"extendedFileMetadata":true,"partitionValues":{},"size":796,"tags":{"INSERTION_TIME":"1682326588000000","MIN_INSERTION_TIME":"1682326588000000","MAX_INSERTION_TIME":"1682326588000000","OPTIMIZE_TARGET_SIZE":"268435456"}}} +{"add":{"path":"part-00000-0aa47759-3062-4e53-94c8-2e20a0796fee-c000.snappy.parquet","partitionValues":{},"size":796,"modificationTime":1682326588000,"dataChange":true,"stats":"{\"numRecords\":2,\"minValues\":{\"a\":1,\"b\":11},\"maxValues\":{\"a\":2,\"b\":22},\"nullCount\":{\"a\":0,\"b\":0},\"tightBounds\":false}","tags":{"INSERTION_TIME":"1682326588000000","MIN_INSERTION_TIME":"1682326588000000","MAX_INSERTION_TIME":"1682326588000000","OPTIMIZE_TARGET_SIZE":"268435456"},"deletionVector":{"storageType":"u","pathOrInlineDv":"R7QFX3rGXPFLhHGq&7g<","offset":1,"sizeInBytes":34,"cardinality":1}}} diff --git a/plugin/trino-delta-lake/src/test/resources/deltalake/cache/table_with_deletion_vector/README.md b/plugin/trino-delta-lake/src/test/resources/deltalake/cache/table_with_deletion_vector/README.md new file mode 100644 index 000000000000..f30f3b1279ae --- /dev/null +++ b/plugin/trino-delta-lake/src/test/resources/deltalake/cache/table_with_deletion_vector/README.md @@ -0,0 +1,13 @@ +Data generated using Databricks 12.2: + +```sql +CREATE TABLE default.test_deletion_vectors ( + a INT, + b INT) +USING delta +LOCATION 
's3://trino-ci-test/test_deletion_vectors' +TBLPROPERTIES ('delta.enableDeletionVectors' = true); + +INSERT INTO default.test_deletion_vectors VALUES (1, 11), (2, 22); +DELETE FROM default.test_deletion_vectors WHERE a = 2; +``` diff --git a/plugin/trino-delta-lake/src/test/resources/deltalake/cache/table_with_deletion_vector/_delta_log/00000000000000000000.json b/plugin/trino-delta-lake/src/test/resources/deltalake/cache/table_with_deletion_vector/_delta_log/00000000000000000000.json new file mode 100644 index 000000000000..4a5d53407173 --- /dev/null +++ b/plugin/trino-delta-lake/src/test/resources/deltalake/cache/table_with_deletion_vector/_delta_log/00000000000000000000.json @@ -0,0 +1,3 @@ +{"commitInfo":{"timestamp":1682326581374,"operation":"CREATE TABLE","operationParameters":{"isManaged":"false","description":null,"partitionBy":"[]","properties":"{\"delta.enableDeletionVectors\":\"true\"}"},"isolationLevel":"WriteSerializable","isBlindAppend":true,"operationMetrics":{},"engineInfo":"Databricks-Runtime/12.2.x-scala2.12","txnId":"2cbfa481-d2b0-4f59-83f9-1261492dfd46"}} +{"protocol":{"minReaderVersion":3,"minWriterVersion":7,"readerFeatures":["deletionVectors"],"writerFeatures":["deletionVectors"]}} +{"metaData":{"id":"32f26f4b-95ba-4980-b209-0132e949b3e4","format":{"provider":"parquet","options":{}},"schemaString":"{\"type\":\"struct\",\"fields\":[{\"name\":\"a\",\"type\":\"integer\",\"nullable\":true,\"metadata\":{}},{\"name\":\"b\",\"type\":\"integer\",\"nullable\":true,\"metadata\":{}}]}","partitionColumns":[],"configuration":{"delta.enableDeletionVectors":"true"},"createdTime":1682326580906}} diff --git a/plugin/trino-delta-lake/src/test/resources/deltalake/cache/table_with_deletion_vector/_delta_log/00000000000000000001.json b/plugin/trino-delta-lake/src/test/resources/deltalake/cache/table_with_deletion_vector/_delta_log/00000000000000000001.json new file mode 100644 index 000000000000..7a5e8e6418b8 --- /dev/null +++ 
b/plugin/trino-delta-lake/src/test/resources/deltalake/cache/table_with_deletion_vector/_delta_log/00000000000000000001.json @@ -0,0 +1,2 @@ +{"commitInfo":{"timestamp":1682326587253,"operation":"WRITE","operationParameters":{"mode":"Append","partitionBy":"[]"},"readVersion":0,"isolationLevel":"WriteSerializable","isBlindAppend":true,"operationMetrics":{"numFiles":"1","numOutputRows":"2","numOutputBytes":"796"},"engineInfo":"Databricks-Runtime/12.2.x-scala2.12","txnId":"99cd5421-a1b9-40c6-8063-7298ec935fd6"}} +{"add":{"path":"part-00000-0aa47759-3062-4e53-94c8-2e20a0796fee-c000.snappy.parquet","partitionValues":{},"size":796,"modificationTime":1682326588000,"dataChange":true,"stats":"{\"numRecords\":2,\"minValues\":{\"a\":1,\"b\":11},\"maxValues\":{\"a\":2,\"b\":22},\"nullCount\":{\"a\":0,\"b\":0},\"tightBounds\":true}","tags":{"INSERTION_TIME":"1682326588000000","MIN_INSERTION_TIME":"1682326588000000","MAX_INSERTION_TIME":"1682326588000000","OPTIMIZE_TARGET_SIZE":"268435456"}}} diff --git a/plugin/trino-delta-lake/src/test/resources/deltalake/cache/table_with_deletion_vector/deletion_vector_a52eda8c-0a57-4636-814b-9c165388f7ca.bin b/plugin/trino-delta-lake/src/test/resources/deltalake/cache/table_with_deletion_vector/deletion_vector_a52eda8c-0a57-4636-814b-9c165388f7ca.bin new file mode 100644 index 000000000000..66b4b7369d9f Binary files /dev/null and b/plugin/trino-delta-lake/src/test/resources/deltalake/cache/table_with_deletion_vector/deletion_vector_a52eda8c-0a57-4636-814b-9c165388f7ca.bin differ diff --git a/plugin/trino-delta-lake/src/test/resources/deltalake/cache/table_with_deletion_vector/part-00000-0aa47759-3062-4e53-94c8-2e20a0796fee-c000.snappy.parquet b/plugin/trino-delta-lake/src/test/resources/deltalake/cache/table_with_deletion_vector/part-00000-0aa47759-3062-4e53-94c8-2e20a0796fee-c000.snappy.parquet new file mode 100644 index 000000000000..b4fbdc1f40bd Binary files /dev/null and 
b/plugin/trino-delta-lake/src/test/resources/deltalake/cache/table_with_deletion_vector/part-00000-0aa47759-3062-4e53-94c8-2e20a0796fee-c000.snappy.parquet differ diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveCacheMetadata.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveCacheMetadata.java new file mode 100644 index 000000000000..3b3d0c562cc1 --- /dev/null +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveCacheMetadata.java @@ -0,0 +1,131 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.plugin.hive; + +import com.google.common.collect.ImmutableList; +import com.google.inject.Inject; +import io.airlift.json.JsonCodec; +import io.trino.spi.cache.CacheColumnId; +import io.trino.spi.cache.CacheTableId; +import io.trino.spi.cache.ConnectorCacheMetadata; +import io.trino.spi.connector.ColumnHandle; +import io.trino.spi.connector.ConnectorTableHandle; +import io.trino.spi.predicate.TupleDomain; + +import java.util.Optional; + +import static io.trino.plugin.hive.acid.AcidTransaction.NO_ACID_TRANSACTION; +import static java.util.Objects.requireNonNull; + +public class HiveCacheMetadata + implements ConnectorCacheMetadata +{ + private final JsonCodec tableIdCodec; + private final JsonCodec columnHandleCodec; + + @Inject + public HiveCacheMetadata(JsonCodec tableIdCodec, JsonCodec columnHandleCodec) + { + this.tableIdCodec = requireNonNull(tableIdCodec, "tableIdCodec is null"); + this.columnHandleCodec = requireNonNull(columnHandleCodec, "columnHandleCodec is null"); + } + + @Override + public Optional getCacheTableId(ConnectorTableHandle tableHandle) + { + HiveTableHandle hiveTableHandle = (HiveTableHandle) tableHandle; + + if (hiveTableHandle.getTransaction().isAcidTransactionRunning()) { + // skip caching of transactional tables as transaction affects how split rows are read + return Optional.empty(); + } + + if (hiveTableHandle.getAnalyzePartitionValues().isPresent()) { + // skip caching of analyze queries + return Optional.empty(); + } + + // Ensure cache id generation is revisited whenever handle classes change. + // Only fields that are sent to worker matter for CacheTableId. + // This constructor is used as JSON deserializer on worker nodes. 
+ hiveTableHandle = new HiveTableHandle( + hiveTableHandle.getSchemaName(), + hiveTableHandle.getTableName(), + // columns can be skipped from table id as they are obtained separately + ImmutableList.of(), + ImmutableList.of(), + // compactEffectivePredicate is returned as part of ConnectorPageSourceProvider#getUnenforcedPredicate + TupleDomain.all(), + // enforced constraint is only enforced on partition columns, therefore it can be skipped + TupleDomain.all(), + hiveTableHandle.getBucketHandle(), + // skip bucket filter as splits are entirely embedded within buckets + Optional.empty(), + Optional.empty(), + NO_ACID_TRANSACTION); + + HiveCacheTableId tableId = new HiveCacheTableId( + hiveTableHandle.getSchemaName(), + hiveTableHandle.getTableName(), + hiveTableHandle.getBucketHandle()); + return Optional.of(new CacheTableId(tableIdCodec.toJson(tableId))); + } + + @Override + public Optional getCacheColumnId(ConnectorTableHandle tableHandle, ColumnHandle columnHandle) + { + HiveColumnHandle hiveColumnHandle = (HiveColumnHandle) columnHandle; + + // ensure cache id generation is revisited whenever handle classes change + HiveColumnHandle canonicalizedHandle = new HiveColumnHandle( + hiveColumnHandle.getBaseColumnName(), + hiveColumnHandle.getBaseHiveColumnIndex(), + hiveColumnHandle.getBaseHiveType(), + hiveColumnHandle.getBaseType(), + hiveColumnHandle.getHiveColumnProjectionInfo(), + hiveColumnHandle.getColumnType(), + // comment is irrelevant + Optional.empty()); + return Optional.of(new CacheColumnId(columnHandleCodec.toJson(canonicalizedHandle))); + } + + @Override + public ConnectorTableHandle getCanonicalTableHandle(ConnectorTableHandle tableHandle) + { + HiveTableHandle hiveTableHandle = (HiveTableHandle) tableHandle; + return new HiveTableHandle( + hiveTableHandle.getSchemaName(), + hiveTableHandle.getTableName(), + hiveTableHandle.getTableParameters(), + hiveTableHandle.getPartitionColumns(), + hiveTableHandle.getDataColumns(), + 
hiveTableHandle.getPartitionNames(), + hiveTableHandle.getPartitions(), + /* + It overwrites `compactEffectivePredicate` because setting this property to `TupleDomain.all()` does not affect + final result when table is queried. It allows to match more similar subqueries that reads from same table + but has different predicates. + */ + TupleDomain.all(), + hiveTableHandle.getEnforcedConstraint(), + hiveTableHandle.getBucketHandle(), + hiveTableHandle.getBucketFilter(), + hiveTableHandle.getAnalyzePartitionValues(), + hiveTableHandle.getConstraintColumns(), + hiveTableHandle.getProjectedColumns(), + hiveTableHandle.getTransaction(), + hiveTableHandle.isRecordScannedFiles(), + hiveTableHandle.getMaxScannedFileSize()); + } +} diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveCacheSplitId.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveCacheSplitId.java new file mode 100644 index 000000000000..519031548127 --- /dev/null +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveCacheSplitId.java @@ -0,0 +1,152 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.plugin.hive; + +import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.common.collect.ImmutableMap; +import io.trino.metastore.HiveTypeName; +import io.trino.plugin.hive.HiveSplit.BucketConversion; +import io.trino.plugin.hive.HiveSplit.BucketValidation; + +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.OptionalInt; + +import static java.util.Objects.requireNonNull; + +public class HiveCacheSplitId +{ + private final String path; + private final long start; + private final long length; + private final long estimatedFileSize; + private final long fileModifiedTime; + private final List partitionKeys; + private final String partitionName; + private final OptionalInt readBucketNumber; + private final OptionalInt tableBucketNumber; + private final Map hiveColumnCoercions; + private final Optional bucketConversion; + private final Optional bucketValidation; + private final Schema schema; + + public HiveCacheSplitId( + String path, + long start, + long length, + long estimatedFileSize, + long fileModifiedTime, + List partitionKeys, + String partitionName, + OptionalInt readBucketNumber, + OptionalInt tableBucketNumber, + Map hiveColumnCoercions, + Optional bucketConversion, + Optional bucketValidation, + Schema schema) + { + this.path = requireNonNull(path, "path is null"); + this.start = start; + this.length = length; + this.estimatedFileSize = estimatedFileSize; + this.fileModifiedTime = fileModifiedTime; + this.partitionKeys = requireNonNull(partitionKeys, "partitionKeys is null"); + this.partitionName = requireNonNull(partitionName, "partitionName is null"); + this.readBucketNumber = requireNonNull(readBucketNumber, "readBucketNumber is null"); + this.tableBucketNumber = requireNonNull(tableBucketNumber, "tableBucketNumber is null"); + this.hiveColumnCoercions = ImmutableMap.copyOf(requireNonNull(hiveColumnCoercions, "hiveColumnCoercions is null")); + this.bucketConversion = 
requireNonNull(bucketConversion, "bucketConversion is null"); + this.bucketValidation = requireNonNull(bucketValidation, "bucketValidation is null"); + this.schema = requireNonNull(schema, "schema is null"); + } + + @JsonProperty + public String getPath() + { + return path; + } + + @JsonProperty + public long getStart() + { + return start; + } + + @JsonProperty + public long getLength() + { + return length; + } + + @JsonProperty + public long getEstimatedFileSize() + { + return estimatedFileSize; + } + + @JsonProperty + public long getFileModifiedTime() + { + return fileModifiedTime; + } + + @JsonProperty + public List getPartitionKeys() + { + return partitionKeys; + } + + @JsonProperty + public String getPartitionName() + { + return partitionName; + } + + @JsonProperty + public OptionalInt getReadBucketNumber() + { + return readBucketNumber; + } + + @JsonProperty + public OptionalInt getTableBucketNumber() + { + return tableBucketNumber; + } + + @JsonProperty + public Map getHiveColumnCoercions() + { + return hiveColumnCoercions; + } + + @JsonProperty + public Optional getBucketConversion() + { + return bucketConversion; + } + + @JsonProperty + public Optional getBucketValidation() + { + return bucketValidation; + } + + @JsonProperty + public Schema getSchema() + { + return schema; + } +} diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveCacheTableId.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveCacheTableId.java new file mode 100644 index 000000000000..f87247b04e52 --- /dev/null +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveCacheTableId.java @@ -0,0 +1,55 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.hive; + +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.util.Optional; + +import static java.util.Objects.requireNonNull; + +public class HiveCacheTableId +{ + private final String schemaName; + private final String tableName; + private final Optional bucketHandle; + + public HiveCacheTableId( + String schemaName, + String tableName, + Optional bucketHandle) + { + this.schemaName = requireNonNull(schemaName, "schemaName is null"); + this.tableName = requireNonNull(tableName, "tableName is null"); + this.bucketHandle = requireNonNull(bucketHandle, "bucketHandle is null"); + } + + @JsonProperty + public String getSchemaName() + { + return schemaName; + } + + @JsonProperty + public String getTableName() + { + return tableName; + } + + @JsonProperty + public Optional getBucketHandle() + { + return bucketHandle; + } +} diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveConnector.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveConnector.java index 578a6f1e9c93..bdcac0ebf341 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveConnector.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveConnector.java @@ -19,6 +19,7 @@ import io.airlift.bootstrap.LifeCycleManager; import io.trino.plugin.base.classloader.ClassLoaderSafeConnectorMetadata; import io.trino.plugin.base.session.SessionPropertiesProvider; +import io.trino.spi.cache.ConnectorCacheMetadata; import io.trino.spi.connector.Connector; import 
io.trino.spi.connector.ConnectorAccessControl; import io.trino.spi.connector.ConnectorMetadata; @@ -51,6 +52,7 @@ public class HiveConnector private final Injector injector; private final LifeCycleManager lifeCycleManager; private final ConnectorSplitManager splitManager; + private final ConnectorCacheMetadata cacheMetadata; private final ConnectorPageSourceProvider pageSourceProvider; private final ConnectorPageSinkProvider pageSinkProvider; private final ConnectorNodePartitioningProvider nodePartitioningProvider; @@ -77,6 +79,7 @@ public HiveConnector( LifeCycleManager lifeCycleManager, HiveTransactionManager transactionManager, ConnectorSplitManager splitManager, + ConnectorCacheMetadata cacheMetadata, ConnectorPageSourceProvider pageSourceProvider, ConnectorPageSinkProvider pageSinkProvider, ConnectorNodePartitioningProvider nodePartitioningProvider, @@ -99,6 +102,7 @@ public HiveConnector( this.lifeCycleManager = requireNonNull(lifeCycleManager, "lifeCycleManager is null"); this.transactionManager = requireNonNull(transactionManager, "transactionManager is null"); this.splitManager = requireNonNull(splitManager, "splitManager is null"); + this.cacheMetadata = requireNonNull(cacheMetadata, "cacheMetadata is null"); this.pageSourceProvider = requireNonNull(pageSourceProvider, "pageSourceProvider is null"); this.pageSinkProvider = requireNonNull(pageSinkProvider, "pageSinkProvider is null"); this.nodePartitioningProvider = requireNonNull(nodePartitioningProvider, "nodePartitioningProvider is null"); @@ -134,6 +138,12 @@ public ConnectorSplitManager getSplitManager() return splitManager; } + @Override + public ConnectorCacheMetadata getCacheMetadata() + { + return cacheMetadata; + } + @Override public ConnectorPageSourceProvider getPageSourceProvider() { diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveConnectorFactory.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveConnectorFactory.java index 2b1eba0220f2..25c7b4c3d266 100644 
--- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveConnectorFactory.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveConnectorFactory.java @@ -28,6 +28,7 @@ import io.trino.plugin.base.CatalogNameModule; import io.trino.plugin.base.TypeDeserializerModule; import io.trino.plugin.base.classloader.ClassLoaderSafeConnectorAccessControl; +import io.trino.plugin.base.classloader.ClassLoaderSafeConnectorCacheMetadata; import io.trino.plugin.base.classloader.ClassLoaderSafeConnectorPageSinkProvider; import io.trino.plugin.base.classloader.ClassLoaderSafeConnectorPageSourceProvider; import io.trino.plugin.base.classloader.ClassLoaderSafeConnectorSplitManager; @@ -43,6 +44,7 @@ import io.trino.spi.PageIndexerFactory; import io.trino.spi.PageSorter; import io.trino.spi.VersionEmbedder; +import io.trino.spi.cache.ConnectorCacheMetadata; import io.trino.spi.catalog.CatalogName; import io.trino.spi.classloader.ThreadContextClassLoader; import io.trino.spi.connector.Connector; @@ -129,6 +131,7 @@ public static Connector createConnector( LifeCycleManager lifeCycleManager = injector.getInstance(LifeCycleManager.class); HiveTransactionManager transactionManager = injector.getInstance(HiveTransactionManager.class); ConnectorSplitManager splitManager = injector.getInstance(ConnectorSplitManager.class); + ConnectorCacheMetadata cacheMetadata = injector.getInstance(ConnectorCacheMetadata.class); ConnectorPageSourceProvider connectorPageSource = injector.getInstance(ConnectorPageSourceProvider.class); ConnectorPageSinkProvider pageSinkProvider = injector.getInstance(ConnectorPageSinkProvider.class); ConnectorNodePartitioningProvider connectorDistributionProvider = injector.getInstance(ConnectorNodePartitioningProvider.class); @@ -150,6 +153,7 @@ public static Connector createConnector( lifeCycleManager, transactionManager, new ClassLoaderSafeConnectorSplitManager(splitManager, classLoader), + new ClassLoaderSafeConnectorCacheMetadata(cacheMetadata, 
classLoader), new ClassLoaderSafeConnectorPageSourceProvider(connectorPageSource, classLoader), new ClassLoaderSafeConnectorPageSinkProvider(pageSinkProvider, classLoader), new ClassLoaderSafeNodePartitioningProvider(connectorDistributionProvider, classLoader), diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveModule.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveModule.java index 4aedf0a0786e..98efaec82a3b 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveModule.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveModule.java @@ -50,6 +50,10 @@ import io.trino.plugin.hive.parquet.ParquetReaderConfig; import io.trino.plugin.hive.parquet.ParquetWriterConfig; import io.trino.plugin.hive.rcfile.RcFilePageSourceFactory; +import io.trino.plugin.hive.util.BlockJsonSerde; +import io.trino.plugin.hive.util.HiveBlockEncodingSerde; +import io.trino.spi.block.Block; +import io.trino.spi.cache.ConnectorCacheMetadata; import io.trino.spi.catalog.CatalogName; import io.trino.spi.connector.ConnectorNodePartitioningProvider; import io.trino.spi.connector.ConnectorPageSinkProvider; @@ -65,6 +69,7 @@ import static com.google.inject.multibindings.OptionalBinder.newOptionalBinder; import static io.airlift.concurrent.Threads.daemonThreadsNamed; import static io.airlift.configuration.ConfigBinder.configBinder; +import static io.airlift.json.JsonBinder.jsonBinder; import static io.airlift.json.JsonCodecBinder.jsonCodecBinder; import static io.trino.plugin.base.ClosingBinder.closingBinder; import static java.util.concurrent.Executors.newCachedThreadPool; @@ -144,6 +149,18 @@ public void configure(Binder binder) fileWriterFactoryBinder.addBinding().to(RcFileFileWriterFactory.class).in(Scopes.SINGLETON); fileWriterFactoryBinder.addBinding().to(AvroFileWriterFactory.class).in(Scopes.SINGLETON); + binder.bind(ConnectorCacheMetadata.class).to(HiveCacheMetadata.class).in(Scopes.SINGLETON); + + // for table handle, column 
handle and split ids + jsonCodecBinder(binder).bindJsonCodec(HiveCacheTableId.class); + jsonCodecBinder(binder).bindJsonCodec(HiveCacheSplitId.class); + jsonCodecBinder(binder).bindJsonCodec(HiveColumnHandle.class); + + // bind block serializers for the purpose of TupleDomain serde + binder.bind(HiveBlockEncodingSerde.class).in(Scopes.SINGLETON); + jsonBinder(binder).addSerializerBinding(Block.class).to(BlockJsonSerde.Serializer.class); + jsonBinder(binder).addDeserializerBinding(Block.class).to(BlockJsonSerde.Deserializer.class); + configBinder(binder).bindConfig(ParquetReaderConfig.class); configBinder(binder).bindConfig(ParquetWriterConfig.class); fileWriterFactoryBinder.addBinding().to(ParquetFileWriterFactory.class).in(Scopes.SINGLETON); diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HivePageSourceProvider.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HivePageSourceProvider.java index 1ce82c26c55e..73cd11438263 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HivePageSourceProvider.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HivePageSourceProvider.java @@ -58,6 +58,7 @@ import static com.google.common.base.Preconditions.checkState; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableMap.toImmutableMap; +import static com.google.common.collect.ImmutableSet.toImmutableSet; import static com.google.common.collect.Maps.uniqueIndex; import static io.trino.plugin.hive.HiveColumnHandle.ColumnType.PARTITION_KEY; import static io.trino.plugin.hive.HiveColumnHandle.ColumnType.REGULAR; @@ -108,7 +109,8 @@ public ConnectorPageSource createPageSource( HiveTableHandle hiveTable = (HiveTableHandle) tableHandle; HiveSplit hiveSplit = (HiveSplit) split; - if (shouldSkipBucket(hiveTable, hiveSplit, dynamicFilter)) { + TupleDomain effectivePredicate = getUnenforcedPredicate(session, split, tableHandle, dynamicFilter.getCurrentPredicate()); + 
if (effectivePredicate.isNone()) { return new EmptyPageSource(); } @@ -129,12 +131,6 @@ public ConnectorPageSource createPageSource( hiveSplit.getEstimatedFileSize(), hiveSplit.getFileModifiedTime()); - // Perform dynamic partition pruning in case coordinator didn't prune split. - // This can happen when dynamic filters are collected after partition splits were listed. - if (shouldSkipSplit(columnMappings, dynamicFilter)) { - return new EmptyPageSource(); - } - Optional pageSource = createHivePageSource( pageSourceFactories, session, @@ -145,9 +141,7 @@ public ConnectorPageSource createPageSource( hiveSplit.getEstimatedFileSize(), hiveSplit.getFileModifiedTime(), hiveSplit.getSchema(), - hiveTable.getCompactEffectivePredicate().intersect( - dynamicFilter.getCurrentPredicate().transformKeys(HiveColumnHandle.class::cast)) - .simplify(domainCompactionThreshold), + effectivePredicate.transformKeys(HiveColumnHandle.class::cast), typeManager, hiveSplit.getBucketConversion(), hiveSplit.getBucketValidation(), @@ -238,18 +232,83 @@ public static Optional createHivePageSource( return Optional.empty(); } - private static boolean shouldSkipBucket(HiveTableHandle hiveTable, HiveSplit hiveSplit, DynamicFilter dynamicFilter) + @Override + public TupleDomain getUnenforcedPredicate( + ConnectorSession session, + ConnectorSplit split, + ConnectorTableHandle tableHandle, + TupleDomain dynamicFilter) + { + HiveTableHandle hiveTableHandle = (HiveTableHandle) tableHandle; + return prunePredicate( + session, + split, + hiveTableHandle, + dynamicFilter.intersect(hiveTableHandle.getCompactEffectivePredicate())) + .simplify(domainCompactionThreshold); + } + + @Override + public TupleDomain prunePredicate( + ConnectorSession session, + ConnectorSplit split, + ConnectorTableHandle tableHandle, + TupleDomain predicate) + { + if (predicate.isNone()) { + return TupleDomain.none(); + } + + HiveTableHandle hiveTable = (HiveTableHandle) tableHandle; + HiveSplit hiveSplit = (HiveSplit) split; + + if 
(shouldSkipBucket(hiveTable, hiveSplit, predicate)) { + return TupleDomain.none(); + } + + List hiveColumns = predicate.getDomains().orElseThrow().keySet().stream() + .map(HiveColumnHandle.class::cast) + .collect(toImmutableList()); + + List columnMappings = ColumnMapping.buildColumnMappings( + hiveSplit.getPartitionName(), + hiveSplit.getPartitionKeys(), + hiveColumns, + hiveSplit.getBucketConversion().map(BucketConversion::bucketColumnHandles).orElse(ImmutableList.of()), + hiveSplit.getHiveColumnCoercions(), + hiveSplit.getPath(), + hiveSplit.getTableBucketNumber(), + hiveSplit.getEstimatedFileSize(), + hiveSplit.getFileModifiedTime()); + + // Perform dynamic partition pruning in case coordinator didn't prune split. + // This can happen when dynamic filters are collected after partition splits were listed. + if (shouldSkipSplit(columnMappings, predicate)) { + return TupleDomain.none(); + } + + Set prefilledColumns = columnMappings.stream() + .filter(mapping -> mapping.getKind() == PREFILLED) + .map(ColumnMapping::getHiveColumnHandle) + .collect(toImmutableSet()); + + // Exclude prefilled columns because such columns won't be used + // to filter split data when reading files. 
+ return predicate + .filter((columnHandle, domain) -> !prefilledColumns.contains(columnHandle)); + } + + private static boolean shouldSkipBucket(HiveTableHandle hiveTable, HiveSplit hiveSplit, TupleDomain predicate) { if (hiveSplit.getTableBucketNumber().isEmpty()) { return false; } - Optional hiveBucketFilter = getHiveBucketFilter(hiveTable, dynamicFilter.getCurrentPredicate()); + Optional hiveBucketFilter = getHiveBucketFilter(hiveTable, predicate); return hiveBucketFilter.map(filter -> !filter.getBucketsToKeep().contains(hiveSplit.getTableBucketNumber().getAsInt())).orElse(false); } - private static boolean shouldSkipSplit(List columnMappings, DynamicFilter dynamicFilter) + private static boolean shouldSkipSplit(List columnMappings, TupleDomain predicate) { - TupleDomain predicate = dynamicFilter.getCurrentPredicate(); if (predicate.isNone()) { return true; } diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveSplitManager.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveSplitManager.java index 0266fef30d2d..5fe74d04611a 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveSplitManager.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveSplitManager.java @@ -21,6 +21,7 @@ import com.google.common.collect.Streams; import com.google.inject.Inject; import io.airlift.concurrent.BoundedExecutor; +import io.airlift.json.JsonCodec; import io.airlift.stats.CounterStat; import io.airlift.units.DataSize; import io.trino.filesystem.TrinoFileSystemFactory; @@ -36,10 +37,13 @@ import io.trino.plugin.hive.metastore.SemiTransactionalHiveMetastore; import io.trino.plugin.hive.util.HiveBucketing.HiveBucketFilter; import io.trino.plugin.hive.util.HiveUtil; +import io.trino.spi.SplitWeight; import io.trino.spi.TrinoException; import io.trino.spi.VersionEmbedder; +import io.trino.spi.cache.CacheSplitId; import io.trino.spi.connector.ColumnHandle; import io.trino.spi.connector.ConnectorSession; +import 
io.trino.spi.connector.ConnectorSplit; import io.trino.spi.connector.ConnectorSplitManager; import io.trino.spi.connector.ConnectorSplitSource; import io.trino.spi.connector.ConnectorTableHandle; @@ -70,6 +74,7 @@ import static com.google.common.base.Strings.isNullOrEmpty; import static com.google.common.base.Verify.verify; import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.ImmutableMap.toImmutableMap; import static com.google.common.collect.ImmutableSet.toImmutableSet; import static com.google.common.collect.Iterators.peekingIterator; import static com.google.common.collect.Iterators.singletonIterator; @@ -122,6 +127,7 @@ public class HiveSplitManager private final boolean recursiveDfsWalkerEnabled; private final CounterStat highMemorySplitSourceCounter; private final TypeManager typeManager; + private final JsonCodec splitIdCodec; private final CachingHostAddressProvider cachingHostAddressProvider; private final int maxPartitionsPerScan; @@ -134,6 +140,7 @@ public HiveSplitManager( ExecutorService executorService, VersionEmbedder versionEmbedder, TypeManager typeManager, + JsonCodec splitIdCodec, CachingHostAddressProvider cachingHostAddressProvider) { this( @@ -151,6 +158,7 @@ public HiveSplitManager( hiveConfig.getMaxSplitsPerSecond(), hiveConfig.getRecursiveDirWalkerEnabled(), typeManager, + splitIdCodec, cachingHostAddressProvider, hiveConfig.getMaxPartitionsPerScan()); } @@ -170,6 +178,7 @@ public HiveSplitManager( @Nullable Integer maxSplitsPerSecond, boolean recursiveDfsWalkerEnabled, TypeManager typeManager, + JsonCodec splitIdCodec, CachingHostAddressProvider cachingHostAddressProvider, int maxPartitionsPerScan) { @@ -188,6 +197,7 @@ public HiveSplitManager( this.maxSplitsPerSecond = firstNonNull(maxSplitsPerSecond, Integer.MAX_VALUE); this.recursiveDfsWalkerEnabled = recursiveDfsWalkerEnabled; this.typeManager = requireNonNull(typeManager, "typeManager is null"); + this.splitIdCodec = 
requireNonNull(splitIdCodec, "splitIdCodec is null"); this.cachingHostAddressProvider = requireNonNull(cachingHostAddressProvider, "cachingHostAddressProvider is null"); this.maxPartitionsPerScan = maxPartitionsPerScan; } @@ -297,6 +307,62 @@ public ConnectorSplitSource getSplits( return splitSource; } + @Override + public Optional getCacheSplitId(ConnectorSplit split) + { + HiveSplit hiveSplit = (HiveSplit) split; + + if (hiveSplit.getAcidInfo().isPresent()) { + // skip caching of transactional tables as transactions affect how split rows are read + return Optional.empty(); + } + + // ensure cache id generation is revisited whenever handle classes change + hiveSplit = new HiveSplit( + // database and table names are already part of table id + hiveSplit.getPartitionName(), + hiveSplit.getPath(), + hiveSplit.getStart(), + hiveSplit.getLength(), + hiveSplit.getEstimatedFileSize(), + hiveSplit.getFileModifiedTime(), + hiveSplit.getSchema(), + hiveSplit.getPartitionKeys(), + // addresses can be ignored + ImmutableList.of(), + hiveSplit.getReadBucketNumber(), + hiveSplit.getTableBucketNumber(), + // force local scheduling can be skipped + false, + hiveSplit.getHiveColumnCoercions(), + hiveSplit.getBucketConversion(), + hiveSplit.getBucketValidation(), + Optional.empty(), + // weight does not impact split rows + SplitWeight.standard()); + + return Optional.of(new CacheSplitId(splitIdCodec.toJson(new HiveCacheSplitId( + hiveSplit.getPath(), + hiveSplit.getStart(), + hiveSplit.getLength(), + hiveSplit.getEstimatedFileSize(), + hiveSplit.getFileModifiedTime(), + hiveSplit.getPartitionKeys(), + hiveSplit.getPartitionName(), + hiveSplit.getReadBucketNumber(), + hiveSplit.getTableBucketNumber(), + hiveSplit.getHiveColumnCoercions(), + hiveSplit.getBucketConversion(), + hiveSplit.getBucketValidation(), + new Schema( + hiveSplit.getSchema().serializationLibraryName(), + hiveSplit.getSchema().isFullAcidTable(), + // order schema keys to canonicalize schema map + 
hiveSplit.getSchema().serdeProperties().entrySet().stream() + .sorted(Map.Entry.comparingByKey()) + .collect(toImmutableMap((Map.Entry entry) -> entry.getKey().toString(), Map.Entry::getValue))))))); + } + @Managed @Nested public CounterStat getHighMemorySplitSource() diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveCacheIds.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveCacheIds.java new file mode 100644 index 000000000000..cb1ab6cd732f --- /dev/null +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveCacheIds.java @@ -0,0 +1,374 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
package io.trino.plugin.hive;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import io.airlift.json.JsonCodec;
import io.airlift.json.JsonCodecFactory;
import io.airlift.json.ObjectMapperProvider;
import io.trino.filesystem.Location;
import io.trino.filesystem.cache.DefaultCachingHostAddressProvider;
import io.trino.filesystem.hdfs.HdfsFileSystemFactory;
import io.trino.filesystem.memory.MemoryFileSystemFactory;
import io.trino.hdfs.HdfsConfig;
import io.trino.hdfs.HdfsConfiguration;
import io.trino.hdfs.HdfsEnvironment;
import io.trino.hdfs.authentication.NoHdfsAuthentication;
import io.trino.plugin.base.TypeDeserializer;
import io.trino.plugin.hive.acid.AcidTransaction;
import io.trino.plugin.hive.fs.CachingDirectoryLister;
import io.trino.plugin.hive.fs.TransactionScopeCachingDirectoryListerFactory;
import io.trino.plugin.hive.metastore.HiveMetastoreConfig;
import io.trino.plugin.hive.metastore.HiveMetastoreFactory;
import io.trino.plugin.hive.security.SqlStandardAccessControlMetadata;
import io.trino.plugin.hive.util.HiveBlockEncodingSerde;
import io.trino.spi.SplitWeight;
import io.trino.spi.block.Block;
import io.trino.spi.block.TestingBlockJsonSerde;
import io.trino.spi.catalog.CatalogName;
import io.trino.spi.connector.ColumnHandle;
import io.trino.spi.predicate.Domain;
import io.trino.spi.predicate.TupleDomain;
import io.trino.spi.predicate.ValueSet;
import io.trino.spi.type.TestingTypeManager;
import io.trino.spi.type.Type;
import io.trino.util.EmbedVersion;
import org.apache.hadoop.conf.Configuration;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;
import org.junit.jupiter.api.parallel.Execution;

import java.util.List;
import java.util.Optional;
import java.util.OptionalInt;
import java.util.concurrent.ScheduledExecutorService;

import static com.google.common.util.concurrent.MoreExecutors.newDirectExecutorService;
import static io.trino.metastore.HiveType.HIVE_INT;
import static io.trino.metastore.HiveType.HIVE_STRING;
import static io.trino.plugin.hive.HiveColumnHandle.ColumnType.PARTITION_KEY;
import static io.trino.plugin.hive.HiveColumnHandle.createBaseColumn;
import static io.trino.plugin.hive.HiveTestUtils.HDFS_FILE_SYSTEM_STATS;
import static io.trino.plugin.hive.HiveTestUtils.getDefaultHiveFileWriterFactories;
import static io.trino.plugin.hive.acid.AcidTransaction.NO_ACID_TRANSACTION;
import static io.trino.plugin.hive.metastore.file.TestingFileHiveMetastore.createTestingFileHiveMetastore;
import static io.trino.spi.connector.MetadataProvider.NOOP_METADATA_PROVIDER;
import static io.trino.spi.predicate.Domain.singleValue;
import static io.trino.spi.type.IntegerType.INTEGER;
import static io.trino.spi.type.VarcharType.VARCHAR;
import static io.trino.type.InternalTypeManager.TESTING_TYPE_MANAGER;
import static java.util.concurrent.Executors.newScheduledThreadPool;
import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS;
import static org.junit.jupiter.api.parallel.ExecutionMode.CONCURRENT;

/**
 * Unit tests for the Hive cache identity classes: verifies that
 * {@code HiveCacheMetadata} table/column ids and {@code HiveSplitManager} split ids
 * include exactly the properties that affect scan results (path, length, schema, buckets)
 * and exclude everything that does not (predicates, comments, column lists, addresses).
 */
@TestInstance(PER_CLASS)
@Execution(CONCURRENT)
public class TestHiveCacheIds
{
    // Shared scheduler for the metadata/split-manager fixtures; created in setup(), stopped in tearDown()
    private ScheduledExecutorService executorService;
    private HiveCacheMetadata metadata;
    private HiveSplitManager splitManager;

    /**
     * Builds a minimal in-memory Hive stack (memory filesystem, file metastore, no-op HDFS)
     * just so HiveCacheMetadata and HiveSplitManager can be constructed with real codecs.
     */
    @BeforeAll
    public void setup()
    {
        executorService = newScheduledThreadPool(1);
        HiveConfig config = new HiveConfig();
        // Empty Hadoop configuration — nothing here talks to a real cluster
        HdfsConfiguration hdfsConfiguration = (context, uri) -> new Configuration(false);
        HdfsEnvironment hdfsEnvironment = new HdfsEnvironment(hdfsConfiguration, new HdfsConfig(), new NoHdfsAuthentication());
        HivePartitionManager hivePartitionManager = new HivePartitionManager(config);
        HiveMetadataFactory metadataFactory = new HiveMetadataFactory(
                new CatalogName("hive"),
                config,
                new HiveMetastoreConfig(),
                HiveMetastoreFactory.ofInstance(createTestingFileHiveMetastore(new MemoryFileSystemFactory(), Location.of("memory:///"))),
                getDefaultHiveFileWriterFactories(config, hdfsEnvironment),
                new HdfsFileSystemFactory(hdfsEnvironment, HDFS_FILE_SYSTEM_STATS),
                hivePartitionManager,
                newDirectExecutorService(),
                executorService,
                TESTING_TYPE_MANAGER,
                NOOP_METADATA_PROVIDER,
                new HiveLocationService(new HdfsFileSystemFactory(hdfsEnvironment, HDFS_FILE_SYSTEM_STATS), config),
                JsonCodec.jsonCodec(PartitionUpdate.class),
                new NodeVersion("test_version"),
                new NoneHiveRedirectionsProvider(),
                ImmutableSet.of(
                        new PartitionsSystemTableProvider(hivePartitionManager, TESTING_TYPE_MANAGER),
                        new PropertiesSystemTableProvider()),
                SqlStandardAccessControlMetadata::new,
                new CachingDirectoryLister(config),
                new TransactionScopeCachingDirectoryListerFactory(config),
                true);

        // Codecs mirror the bindings registered in HiveModule for cache id serde
        metadata = new HiveCacheMetadata(
                createJsonCodec(HiveCacheTableId.class),
                createJsonCodec(HiveColumnHandle.class));
        splitManager = new HiveSplitManager(
                config,
                new HiveTransactionManager(metadataFactory),
                hivePartitionManager,
                new MemoryFileSystemFactory(),
                executorService,
                new EmbedVersion("test"),
                new TestingTypeManager(),
                createJsonCodec(HiveCacheSplitId.class),
                new DefaultCachingHostAddressProvider());
    }

    @AfterAll
    public void tearDown()
    {
        if (executorService != null) {
            executorService.shutdownNow();
            executorService = null;
        }
    }

    /**
     * Table id must depend only on schema/table identity (and bucketing), never on
     * projected columns or on enforced/effective predicates.
     */
    @Test
    public void testTableId()
    {
        HiveColumnHandle partitionColumn = createBaseColumn("col1", 0, HIVE_INT, INTEGER, PARTITION_KEY, Optional.empty());
        // column list shouldn't be part of table id
        assertThat(metadata.getCacheTableId(createHiveTableHandle(
                "schema",
                "table",
                ImmutableList.of(partitionColumn),
                TupleDomain.all(),
                TupleDomain.all())))
                .isEqualTo(metadata.getCacheTableId(createHiveTableHandle(
                        "schema",
                        "table",
                        ImmutableList.of(),
                        TupleDomain.all(),
                        TupleDomain.all())));

        // enforced predicate shouldn't be part of table id
        assertThat(metadata.getCacheTableId(createHiveTableHandle(
                "schema",
                "table",
                ImmutableList.of(partitionColumn),
                TupleDomain.all(),
                TupleDomain.withColumnDomains(ImmutableMap.of(partitionColumn, singleValue(INTEGER, 1L))))))
                .isEqualTo(metadata.getCacheTableId(createHiveTableHandle(
                        "schema",
                        "table",
                        ImmutableList.of(partitionColumn),
                        TupleDomain.all(),
                        TupleDomain.withColumnDomains(ImmutableMap.of(partitionColumn, singleValue(INTEGER, 2L))))));

        // effective predicate should not be part of table id
        assertThat(metadata.getCacheTableId(createHiveTableHandle(
                "schema",
                "table",
                ImmutableList.of(),
                TupleDomain.withColumnDomains(ImmutableMap.of(partitionColumn, singleValue(INTEGER, 1L))),
                TupleDomain.all())))
                .isEqualTo(metadata.getCacheTableId(createHiveTableHandle(
                        "schema",
                        "table",
                        ImmutableList.of(),
                        TupleDomain.withColumnDomains(ImmutableMap.of(partitionColumn, singleValue(INTEGER, 2L))),
                        TupleDomain.all())));
    }

    /**
     * Column id must be determined by the column's name/type identity, not by its comment.
     */
    @Test
    public void testColumnId()
    {
        HiveTableHandle tableHandle = createHiveTableHandle(
                "schema",
                "table",
                ImmutableList.of(),
                TupleDomain.all(),
                TupleDomain.all());
        // comment shouldn't be part of column id
        assertThat(metadata.getCacheColumnId(
                tableHandle,
                createBaseColumn(
                        "col",
                        0,
                        HIVE_INT,
                        INTEGER,
                        PARTITION_KEY,
                        Optional.of("comment"))))
                .isEqualTo(metadata.getCacheColumnId(
                        tableHandle,
                        createBaseColumn(
                                "col",
                                0,
                                HIVE_INT,
                                INTEGER,
                                PARTITION_KEY,
                                Optional.of("other comment"))));

        // different column names should change column id
        assertThat(metadata.getCacheColumnId(
                tableHandle,
                createBaseColumn(
                        "col1",
                        0,
                        HIVE_INT,
                        INTEGER,
                        PARTITION_KEY,
                        Optional.empty())))
                .isNotEqualTo(metadata.getCacheColumnId(
                        tableHandle,
                        createBaseColumn(
                                "col2",
                                0,
                                HIVE_INT,
                                INTEGER,
                                PARTITION_KEY,
                                Optional.empty())));
    }

    /**
     * Split id must be stable under irrelevant differences (schema property order) and
     * sensitive to anything that changes which bytes/rows are read
     * (path, length, partition, schema values, read bucket).
     */
    @Test
    public void testSplitId()
    {
        Schema emptySchema = new Schema("table", true, ImmutableMap.of());

        // table name should be stripped from id
        assertThat(splitManager.getCacheSplitId(createHiveSplit("path", 10, "part", emptySchema, OptionalInt.empty())))
                .isEqualTo(splitManager.getCacheSplitId(createHiveSplit("path", 10, "part", emptySchema, OptionalInt.empty())));

        // different properties order in schema shouldn't make ids different
        Schema schema1 = new Schema("table", true, ImmutableMap.of(
                "key1", "value1",
                "key2", "value2"));
        Schema schema2 = new Schema("table", true, ImmutableMap.of(
                "key2", "value2",
                "key1", "value1"));

        assertThat(splitManager.getCacheSplitId(createHiveSplit("path", 10, "part", schema1, OptionalInt.empty())))
                .isEqualTo(splitManager.getCacheSplitId(createHiveSplit("path", 10, "part", schema2, OptionalInt.empty())));

        // different path should make ids different
        assertThat(splitManager.getCacheSplitId(createHiveSplit("path1", 10, "part", emptySchema, OptionalInt.empty())))
                .isNotEqualTo(splitManager.getCacheSplitId(createHiveSplit("path2", 10, "part", emptySchema, OptionalInt.empty())));

        // different length should make ids different
        assertThat(splitManager.getCacheSplitId(createHiveSplit("path", 10, "part", emptySchema, OptionalInt.empty())))
                .isNotEqualTo(splitManager.getCacheSplitId(createHiveSplit("path", 11, "part", emptySchema, OptionalInt.empty())));

        // different partition name should make ids different
        assertThat(splitManager.getCacheSplitId(createHiveSplit("path", 10, "part1", emptySchema, OptionalInt.empty())))
                .isNotEqualTo(splitManager.getCacheSplitId(createHiveSplit("path", 10, "part2", emptySchema, OptionalInt.empty())));

        // different schema should make ids different
        schema1 = new Schema("table", true, ImmutableMap.of("key", "value1"));
        schema2 = new Schema("table", true, ImmutableMap.of("key", "value2"));

        assertThat(splitManager.getCacheSplitId(createHiveSplit("path", 10, "part", schema1, OptionalInt.empty())))
                .isNotEqualTo(splitManager.getCacheSplitId(createHiveSplit("path", 10, "part", schema2, OptionalInt.empty())));

        // different read bucket number should make ids different
        assertThat(splitManager.getCacheSplitId(createHiveSplit("path", 10, "part", emptySchema, OptionalInt.empty())))
                .isNotEqualTo(splitManager.getCacheSplitId(createHiveSplit("path", 10, "part", emptySchema, OptionalInt.of(1))));
    }

    /**
     * Canonicalization must strip only the compact effective predicate (reset to all())
     * and leave every other table-handle property untouched.
     */
    @Test
    public void testGetCanonicalTableHandle()
    {
        HiveColumnHandle hiveColumnHandle = createBaseColumn("any", 0, HIVE_STRING, VARCHAR, PARTITION_KEY, Optional.empty());
        TupleDomain compactEffectivePredicate = TupleDomain.withColumnDomains(ImmutableMap.of(hiveColumnHandle, Domain.create(ValueSet.none(VARCHAR), false)));
        // NOTE(review): qualified AcidTransaction.NO_ACID_TRANSACTION here vs the static import
        // used elsewhere in this class — consider unifying on the static import
        HiveTableHandle handle = new HiveTableHandle(
                "schema",
                "table",
                ImmutableList.of(),
                ImmutableList.of(),
                compactEffectivePredicate,
                TupleDomain.all(),
                Optional.empty(),
                Optional.empty(),
                Optional.empty(),
                AcidTransaction.NO_ACID_TRANSACTION);

        HiveTableHandle canonicalHandle = (HiveTableHandle) metadata.getCanonicalTableHandle(handle);

        assertThat(canonicalHandle.getSchemaName()).isEqualTo(handle.getSchemaName());
        assertThat(canonicalHandle.getTableName()).isEqualTo(handle.getTableName());
        assertThat(canonicalHandle.getPartitionColumns()).isEqualTo(handle.getPartitionColumns());
        assertThat(canonicalHandle.getDataColumns()).isEqualTo(handle.getDataColumns());
        assertThat(canonicalHandle.getCompactEffectivePredicate()).isEqualTo(TupleDomain.all());
        assertThat(canonicalHandle.getEnforcedConstraint()).isEqualTo(handle.getEnforcedConstraint());
        assertThat(canonicalHandle.getBucketHandle()).isEqualTo(handle.getBucketHandle());
        assertThat(canonicalHandle.getBucketFilter()).isEqualTo(handle.getBucketFilter());
        assertThat(canonicalHandle.getAnalyzePartitionValues()).isEqualTo(handle.getAnalyzePartitionValues());
        assertThat(canonicalHandle.getTransaction()).isEqualTo(handle.getTransaction());
    }

    // Fixture: split with fixed start/size/mtime so only the parameters under test vary
    private static HiveSplit createHiveSplit(
            String path,
            long length,
            String partitionName,
            Schema schema,
            OptionalInt readBucketNumber)
    {
        return new HiveSplit(
                partitionName,
                path,
                0,
                length,
                10,
                12,
                schema,
                ImmutableList.of(),
                ImmutableList.of(),
                readBucketNumber,
                OptionalInt.empty(),
                false,
                ImmutableMap.of(),
                Optional.empty(),
                Optional.empty(),
                Optional.empty(),
                SplitWeight.standard());
    }

    // Fixture: table handle with no bucketing/analyze state and no ACID transaction
    private static HiveTableHandle createHiveTableHandle(
            String schemaName,
            String tableName,
            List partitionColumns,
            TupleDomain compactEffectivePredicate,
            TupleDomain enforcedConstraint)
    {
        return new HiveTableHandle(
                schemaName,
                tableName,
                partitionColumns,
                ImmutableList.of(),
                compactEffectivePredicate,
                enforcedConstraint,
                Optional.empty(),
                Optional.empty(),
                Optional.empty(),
                NO_ACID_TRANSACTION);
    }

    // Codec with Block/Type serde wired in, mirroring the production ObjectMapper bindings
    private static JsonCodec createJsonCodec(Class clazz)
    {
        ObjectMapperProvider objectMapperProvider = new ObjectMapperProvider();
        TypeDeserializer typeDeserializer = new TypeDeserializer(new TestingTypeManager());
        objectMapperProvider.setJsonDeserializers(
                ImmutableMap.of(
                        Block.class, new TestingBlockJsonSerde.Deserializer(new HiveBlockEncodingSerde()),
                        Type.class, typeDeserializer));
        objectMapperProvider.setJsonSerializers(ImmutableMap.of(Block.class, new TestingBlockJsonSerde.Serializer(new HiveBlockEncodingSerde())));
        return new JsonCodecFactory(objectMapperProvider).jsonCodec(clazz);
    }
}
b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveCacheSubqueriesTest.java @@ -0,0 +1,108 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.hive; + +import com.google.common.collect.ImmutableList; +import io.trino.Session; +import io.trino.operator.TableScanOperator; +import io.trino.testing.BaseCacheSubqueriesTest; +import io.trino.testing.QueryRunner; +import io.trino.testing.QueryRunner.MaterializedResultWithPlan; +import io.trino.testing.sql.TestTable; +import org.intellij.lang.annotations.Language; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.parallel.Execution; + +import java.util.List; +import java.util.function.Function; + +import static java.lang.String.format; +import static java.util.stream.Collectors.joining; +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.parallel.ExecutionMode.SAME_THREAD; + +@Execution(SAME_THREAD) +public class TestHiveCacheSubqueriesTest + extends BaseCacheSubqueriesTest +{ + @Override + protected QueryRunner createQueryRunner() + throws Exception + { + return HiveQueryRunner.builder() + .setExtraProperties(EXTRA_PROPERTIES) + .setInitialTables(REQUIRED_TABLES) + .addHiveProperty("hive.dynamic-filtering.wait-timeout", "20s") + .build(); + } + + @Test + public void testDoNotUseCacheForNewData() + { + try (TestTable testTable = new TestTable( + getQueryRunner()::execute, + "test_do_not_use_cache", + "(name 
VARCHAR)", + ImmutableList.of( + "'value1'", + "'value2'"))) { + @Language("SQL") String selectQuery = "select name from %s union all select name from %s".formatted(testTable.getName(), testTable.getName()); + + MaterializedResultWithPlan result = executeWithPlan(withCacheEnabled(), selectQuery); + assertThat(result.result().getRowCount()).isEqualTo(4); + assertThat(getOperatorInputPositions(result.queryId(), TableScanOperator.class.getSimpleName())).isPositive(); + + assertUpdate("insert into %s(name) values ('value3')".formatted(testTable.getName()), 1); + result = executeWithPlan(withCacheEnabled(), selectQuery); + + // make sure that if underlying data was changed the second query sees changes + // and data was read from both table (newly inserted data) and from cache (existing data) + assertThat(result.result().getRowCount()).isEqualTo(6); + assertThat(getLoadCachedDataOperatorInputPositions(result.queryId())).isPositive(); + assertThat(getScanOperatorInputPositions(result.queryId())).isPositive(); + } + } + + @Override + protected void createPartitionedTableAsSelect(String tableName, List partitionColumns, String asSelect) + { + @Language("SQL") String sql = format( + "CREATE TABLE %s WITH (partitioned_by=array[%s]) as %s", + tableName, + partitionColumns.stream().map(column -> "'" + column + "'").collect(joining(",")), + asSelect); + + getQueryRunner().execute(sql); + } + + @Override + protected Session withProjectionPushdownEnabled(Session session, boolean projectionPushdownEnabled) + { + return Session.builder(session) + .setCatalogSessionProperty("hive", "projection_pushdown_enabled", String.valueOf(projectionPushdownEnabled)) + .build(); + } + + @Override + protected T withTransaction(Function transactionSessionConsumer) + { + return newTransaction().execute(getSession(), transactionSessionConsumer); + } + + @Override + protected boolean supportsDataColumnPruning() + { + return false; + } +} diff --git 
a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHivePageSourceProvider.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHivePageSourceProvider.java new file mode 100644 index 000000000000..febe0fdcbdf5 --- /dev/null +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHivePageSourceProvider.java @@ -0,0 +1,196 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.hive; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import io.trino.spi.SplitWeight; +import io.trino.spi.connector.ColumnHandle; +import io.trino.spi.predicate.Domain; +import io.trino.spi.predicate.Range; +import io.trino.spi.predicate.TupleDomain; +import io.trino.spi.predicate.ValueSet; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; +import org.junit.jupiter.api.parallel.Execution; + +import java.util.Optional; +import java.util.OptionalInt; + +import static io.airlift.slice.Slices.utf8Slice; +import static io.trino.metastore.HiveType.HIVE_INT; +import static io.trino.metastore.HiveType.HIVE_STRING; +import static io.trino.plugin.hive.HiveColumnHandle.ColumnType.PARTITION_KEY; +import static io.trino.plugin.hive.HiveColumnHandle.ColumnType.REGULAR; +import static io.trino.plugin.hive.HiveColumnHandle.createBaseColumn; +import static 
io.trino.plugin.hive.acid.AcidTransaction.NO_ACID_TRANSACTION; +import static io.trino.plugin.hive.util.HiveBucketing.BucketingVersion.BUCKETING_V1; +import static io.trino.spi.type.IntegerType.INTEGER; +import static io.trino.spi.type.VarcharType.VARCHAR; +import static io.trino.testing.TestingConnectorSession.SESSION; +import static io.trino.type.InternalTypeManager.TESTING_TYPE_MANAGER; +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS; +import static org.junit.jupiter.api.parallel.ExecutionMode.CONCURRENT; + +@TestInstance(PER_CLASS) +@Execution(CONCURRENT) +public class TestHivePageSourceProvider +{ + private static final HiveColumnHandle PARTITION_COLUMN = createBaseColumn("partition_col", 0, HIVE_STRING, VARCHAR, PARTITION_KEY, Optional.empty()); + private static final HiveColumnHandle DATA_COLUMN = createBaseColumn("data_col", 0, HIVE_INT, INTEGER, REGULAR, Optional.empty()); + private static final HiveColumnHandle BUCKET_COLUMN = createBaseColumn("bucket_col", 1, HIVE_INT, INTEGER, REGULAR, Optional.empty()); + private static final String PARTITION_NAME = "part1"; + private static final HiveBucketHandle HIVE_BUCKET_HANDLE = new HiveBucketHandle( + ImmutableList.of(BUCKET_COLUMN), + BUCKETING_V1, + 10, + 10, + ImmutableList.of()); + private static final Domain DATA_DOMAIN = Domain.create(ValueSet.ofRanges(Range.range(INTEGER, 1L, true, 100L, true)), false); + private static final Domain PARTITION_DOMAIN = Domain.create(ValueSet.of(VARCHAR, utf8Slice("part1")), false); + private static final HiveTableHandle HIVE_TABLE_HANDLE = new HiveTableHandle( + "schema", + "table", + ImmutableList.of(PARTITION_COLUMN), + ImmutableList.of(BUCKET_COLUMN, DATA_COLUMN), + TupleDomain.withColumnDomains(ImmutableMap.of( + DATA_COLUMN, DATA_DOMAIN, + PARTITION_COLUMN, PARTITION_DOMAIN)), + TupleDomain.all(), + Optional.of(HIVE_BUCKET_HANDLE), + Optional.empty(), + Optional.empty(), + 
NO_ACID_TRANSACTION); + private static final HiveSplit HIVE_SPLIT = new HiveSplit( + PARTITION_NAME, + "path", + 0, + 100, + 10, + 12, + new Schema("abc", true, ImmutableMap.of()), + ImmutableList.of(new HivePartitionKey(PARTITION_COLUMN.getName(), PARTITION_NAME)), + ImmutableList.of(), + OptionalInt.empty(), + OptionalInt.of(1), + false, + ImmutableMap.of(), + Optional.empty(), + Optional.empty(), + Optional.empty(), + SplitWeight.standard()); + private HivePageSourceProvider pageSourceProvider; + + @BeforeAll + public void setup() + { + HiveConfig config = new HiveConfig() + .setDomainCompactionThreshold(2); + pageSourceProvider = new HivePageSourceProvider( + TESTING_TYPE_MANAGER, + config, + ImmutableSet.of()); + } + + @Test + public void testGetUnenforcedPredicateCompactsData() + { + assertThat(pageSourceProvider.getUnenforcedPredicate( + SESSION, + HIVE_SPLIT, + HIVE_TABLE_HANDLE, + TupleDomain.withColumnDomains(ImmutableMap.of( + DATA_COLUMN, Domain.create(ValueSet.of(INTEGER, 1L, 10L, 20L), false))))) + .isEqualTo(TupleDomain.withColumnDomains(ImmutableMap.of( + DATA_COLUMN, Domain.create(ValueSet.ofRanges(Range.range(INTEGER, 1L, true, 20L, true)), false)))); + } + + @Test + public void testGetUnenforcedPredicateConsidersEffectivePredicate() + { + assertThat(pageSourceProvider.getUnenforcedPredicate( + SESSION, + HIVE_SPLIT, + HIVE_TABLE_HANDLE, + TupleDomain.all())) + .isEqualTo(TupleDomain.withColumnDomains(ImmutableMap.of(DATA_COLUMN, DATA_DOMAIN))); + + assertThat(pageSourceProvider.getUnenforcedPredicate( + SESSION, + HIVE_SPLIT, + HIVE_TABLE_HANDLE, + TupleDomain.withColumnDomains(ImmutableMap.of( + DATA_COLUMN, Domain.create(ValueSet.of(INTEGER, 1L, 10L, 110L), false))))) + // data column domain should not be simplified because it contains only 2 values after intersection with effective predicate + .isEqualTo(TupleDomain.withColumnDomains(ImmutableMap.of( + DATA_COLUMN, Domain.create(ValueSet.of(INTEGER, 1L, 10L), false)))); + + 
assertThat(pageSourceProvider.getUnenforcedPredicate( + SESSION, + HIVE_SPLIT, + HIVE_TABLE_HANDLE, + TupleDomain.withColumnDomains(ImmutableMap.of( + DATA_COLUMN, Domain.create(ValueSet.of(INTEGER, 1L, 10L, 12L), false))))) + // data column domain should be simplified because it contains 3 values after intersection with effective predicate + .isEqualTo(TupleDomain.withColumnDomains(ImmutableMap.of( + DATA_COLUMN, Domain.create(ValueSet.ofRanges(Range.range(INTEGER, 1L, true, 12L, true)), false)))); + } + + @Test + public void testGetUnenforcedPredicatePrunesPartitionColumn() + { + assertThat(pageSourceProvider.getUnenforcedPredicate( + SESSION, + HIVE_SPLIT, + HIVE_TABLE_HANDLE, + TupleDomain.withColumnDomains(ImmutableMap.of( + PARTITION_COLUMN, Domain.create(ValueSet.of(VARCHAR, utf8Slice("part1")), false))))) + .isEqualTo(TupleDomain.withColumnDomains(ImmutableMap.of(DATA_COLUMN, DATA_DOMAIN))); + + assertThat(pageSourceProvider.getUnenforcedPredicate( + SESSION, + HIVE_SPLIT, + HIVE_TABLE_HANDLE, + TupleDomain.withColumnDomains(ImmutableMap.of( + PARTITION_COLUMN, Domain.create(ValueSet.of(VARCHAR, utf8Slice("part2")), false))))) + .isEqualTo(TupleDomain.none()); + } + + @Test + public void testGetUnenforcedPredicateSkipsBucket() + { + Domain bucketDomain = Domain.create(ValueSet.of(INTEGER, 1L), false); + TupleDomain bucketTupleDomain = TupleDomain.withColumnDomains(ImmutableMap.of( + BUCKET_COLUMN, bucketDomain)); + assertThat(pageSourceProvider.getUnenforcedPredicate( + SESSION, + HIVE_SPLIT, + HIVE_TABLE_HANDLE, + bucketTupleDomain)) + .isEqualTo(TupleDomain.withColumnDomains(ImmutableMap.of( + BUCKET_COLUMN, bucketDomain, + DATA_COLUMN, DATA_DOMAIN))); + + assertThat(pageSourceProvider.getUnenforcedPredicate( + SESSION, + HIVE_SPLIT, + HIVE_TABLE_HANDLE, + TupleDomain.withColumnDomains(ImmutableMap.of( + BUCKET_COLUMN, Domain.create(ValueSet.of(INTEGER, 2L), false))))) + .isEqualTo(TupleDomain.none()); + } +} diff --git a/plugin/trino-iceberg/pom.xml 
b/plugin/trino-iceberg/pom.xml index 517e3f7d5030..8a4bc11e22bf 100644 --- a/plugin/trino-iceberg/pom.xml +++ b/plugin/trino-iceberg/pom.xml @@ -353,6 +353,12 @@ runtime + + io.airlift + tracing + runtime + + io.opentelemetry opentelemetry-sdk-trace diff --git a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergCacheMetadata.java b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergCacheMetadata.java new file mode 100644 index 000000000000..36cd4051893d --- /dev/null +++ b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergCacheMetadata.java @@ -0,0 +1,138 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.plugin.iceberg; + +import com.google.inject.Inject; +import io.airlift.json.JsonCodec; +import io.trino.spi.cache.CacheColumnId; +import io.trino.spi.cache.CacheTableId; +import io.trino.spi.cache.ConnectorCacheMetadata; +import io.trino.spi.connector.ColumnHandle; +import io.trino.spi.connector.ConnectorTableHandle; +import io.trino.spi.predicate.TupleDomain; + +import java.util.Map; +import java.util.Optional; + +import static com.google.common.collect.ImmutableMap.toImmutableMap; +import static java.util.Objects.requireNonNull; + +public class IcebergCacheMetadata + implements ConnectorCacheMetadata +{ + private final JsonCodec tableIdCodec; + private final JsonCodec columnHandleCodec; + + @Inject + public IcebergCacheMetadata(JsonCodec tableIdCodec, JsonCodec columnHandleCodec) + { + this.tableIdCodec = requireNonNull(tableIdCodec, "tableIdCodec is null"); + this.columnHandleCodec = requireNonNull(columnHandleCodec, "columnHandleCodec is null"); + } + + @Override + public Optional getCacheTableId(ConnectorTableHandle tableHandle) + { + IcebergTableHandle icebergTableHandle = (IcebergTableHandle) tableHandle; + + if (icebergTableHandle.getSnapshotId().isEmpty()) { + // A table with missing snapshot id produces no splits + return Optional.empty(); + } + + // Ensure cache id generation is revisited whenever handle classes change. 
+ IcebergTableHandle handle = new IcebergTableHandle( + icebergTableHandle.getCatalog(), + icebergTableHandle.getSchemaName(), + icebergTableHandle.getTableName(), + icebergTableHandle.getTableType(), + icebergTableHandle.getSnapshotId(), + icebergTableHandle.getTableSchemaJson(), + icebergTableHandle.getPartitionSpecJson(), + icebergTableHandle.getFormatVersion(), + icebergTableHandle.getUnenforcedPredicate(), + icebergTableHandle.getEnforcedPredicate(), + icebergTableHandle.getLimit(), + icebergTableHandle.getProjectedColumns(), + icebergTableHandle.getNameMappingJson(), + icebergTableHandle.getTableLocation(), + icebergTableHandle.getStorageProperties(), + icebergTableHandle.isRecordScannedFiles(), + icebergTableHandle.getMaxScannedFileSize(), + icebergTableHandle.getConstraintColumns(), + icebergTableHandle.getForAnalyze()); + + IcebergCacheTableId tableId = new IcebergCacheTableId( + handle.getCatalog(), + handle.getSchemaName(), + handle.getTableName(), + handle.getTableLocation(), + handle.getStorageProperties().entrySet().stream() + .filter(IcebergCacheTableId::isCacheableStorageProperty) + .sorted(Map.Entry.comparingByKey()) + .collect(toImmutableMap(Map.Entry::getKey, Map.Entry::getValue))); + + return Optional.of(new CacheTableId(tableIdCodec.toJson(tableId))); + } + + @Override + public ConnectorTableHandle getCanonicalTableHandle(ConnectorTableHandle tableHandle) + { + IcebergTableHandle icebergTableHandle = (IcebergTableHandle) tableHandle; + + return new IcebergTableHandle( + icebergTableHandle.getCatalog(), + icebergTableHandle.getSchemaName(), + icebergTableHandle.getTableName(), + icebergTableHandle.getTableType(), + icebergTableHandle.getSnapshotId(), + icebergTableHandle.getTableSchemaJson(), + icebergTableHandle.getPartitionSpecJson(), + icebergTableHandle.getFormatVersion(), + /* + It overwrites `unenforcedPredicate` because setting this property to `TupleDomain.all()` does not affect + final result when table is queried. 
It allows to match more similar subqueries that reads from same table + but has different predicates. + */ + TupleDomain.all(), + icebergTableHandle.getEnforcedPredicate(), + icebergTableHandle.getLimit(), + icebergTableHandle.getProjectedColumns(), + icebergTableHandle.getNameMappingJson(), + icebergTableHandle.getTableLocation(), + icebergTableHandle.getStorageProperties(), + icebergTableHandle.isRecordScannedFiles(), + icebergTableHandle.getMaxScannedFileSize(), + icebergTableHandle.getConstraintColumns(), + icebergTableHandle.getForAnalyze()); + } + + @Override + public Optional getCacheColumnId(ConnectorTableHandle tableHandle, ColumnHandle columnHandle) + { + IcebergColumnHandle icebergColumnHandle = (IcebergColumnHandle) columnHandle; + + // ensure cache id generation is revisited whenever handle classes change + IcebergColumnHandle canonicalizedHandle = new IcebergColumnHandle( + icebergColumnHandle.getBaseColumnIdentity(), + icebergColumnHandle.getBaseType(), + icebergColumnHandle.getPath(), + icebergColumnHandle.getType(), + icebergColumnHandle.isNullable(), + // comment is irrelevant + Optional.empty()); + + return Optional.of(new CacheColumnId(columnHandleCodec.toJson(canonicalizedHandle))); + } +} diff --git a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergCacheSplitId.java b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergCacheSplitId.java new file mode 100644 index 000000000000..5bbc56cefdc5 --- /dev/null +++ b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergCacheSplitId.java @@ -0,0 +1,93 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.iceberg; + +import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.common.collect.ImmutableList; +import io.trino.plugin.iceberg.delete.DeleteFile; + +import java.util.List; + +import static java.util.Objects.requireNonNull; + +public class IcebergCacheSplitId +{ + private final String path; + private final long start; + private final long length; + private final long fileSize; + private final String partitionSpecJson; + private final String partitionDataJson; + private final List deletes; + + public IcebergCacheSplitId( + String path, + long start, + long length, + long fileSize, + String partitionSpecJson, + String partitionDataJson, + List deletes) + { + this.path = requireNonNull(path, "path is null"); + this.start = start; + this.length = length; + this.fileSize = fileSize; + this.partitionSpecJson = requireNonNull(partitionSpecJson, "partitionSpecJson is null"); + this.partitionDataJson = requireNonNull(partitionDataJson, "partitionDataJson is null"); + this.deletes = ImmutableList.copyOf(requireNonNull(deletes, "deletes is null")); + } + + @JsonProperty + public String getPath() + { + return path; + } + + @JsonProperty + public long getStart() + { + return start; + } + + @JsonProperty + public long getLength() + { + return length; + } + + @JsonProperty + public long getFileSize() + { + return fileSize; + } + + @JsonProperty + public String getPartitionSpecJson() + { + return partitionSpecJson; + } + + @JsonProperty + public String getPartitionDataJson() + { + return partitionDataJson; + } + + 
@JsonProperty + public List getDeletes() + { + return deletes; + } +} diff --git a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergCacheTableId.java b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergCacheTableId.java new file mode 100644 index 000000000000..34c03c293f09 --- /dev/null +++ b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergCacheTableId.java @@ -0,0 +1,81 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.iceberg; + +import com.fasterxml.jackson.annotation.JsonProperty; +import io.trino.spi.connector.CatalogHandle; + +import java.util.Map; + +import static java.util.Objects.requireNonNull; + +public class IcebergCacheTableId +{ + private static final String STORAGE_PROPERTIES_READ_PREFIX = "read."; + + private final CatalogHandle catalog; + private final String schemaName; + private final String tableName; + private final String tableLocation; + private final Map storageProperties; + + public IcebergCacheTableId( + CatalogHandle catalog, + String schemaName, + String tableName, + String tableLocation, + Map storageProperties) + { + this.catalog = requireNonNull(catalog, "catalog is null"); + this.schemaName = requireNonNull(schemaName, "schemaName is null"); + this.tableName = requireNonNull(tableName, "tableName is null"); + this.tableLocation = requireNonNull(tableLocation, "tableLocation is null"); + this.storageProperties = requireNonNull(storageProperties, 
"storageProperties is null"); + } + + @JsonProperty + public CatalogHandle getCatalog() + { + return catalog; + } + + @JsonProperty + public String getSchemaName() + { + return schemaName; + } + + @JsonProperty + public String getTableName() + { + return tableName; + } + + @JsonProperty + public String getTableLocation() + { + return tableLocation; + } + + @JsonProperty + public Map getStorageProperties() + { + return storageProperties; + } + + public static boolean isCacheableStorageProperty(Map.Entry entry) + { + return entry.getKey().startsWith(STORAGE_PROPERTIES_READ_PREFIX); + } +} diff --git a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergConnector.java b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergConnector.java index 8f8c53b4929a..d9849d67169f 100644 --- a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergConnector.java +++ b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergConnector.java @@ -22,6 +22,7 @@ import io.trino.plugin.base.classloader.ClassLoaderSafeConnectorMetadata; import io.trino.plugin.base.session.SessionPropertiesProvider; import io.trino.plugin.hive.HiveTransactionHandle; +import io.trino.spi.cache.ConnectorCacheMetadata; import io.trino.spi.connector.Connector; import io.trino.spi.connector.ConnectorAccessControl; import io.trino.spi.connector.ConnectorCapabilities; @@ -58,6 +59,7 @@ public class IcebergConnector private final LifeCycleManager lifeCycleManager; private final IcebergTransactionManager transactionManager; private final ConnectorSplitManager splitManager; + private final ConnectorCacheMetadata cacheMetadata; private final ConnectorPageSourceProviderFactory pageSourceProviderFactory; private final ConnectorPageSinkProvider pageSinkProvider; private final ConnectorNodePartitioningProvider nodePartitioningProvider; @@ -78,6 +80,7 @@ public IcebergConnector( LifeCycleManager lifeCycleManager, IcebergTransactionManager transactionManager, 
ConnectorSplitManager splitManager, + ConnectorCacheMetadata cacheMetadata, ConnectorPageSourceProviderFactory pageSourceProviderFactory, ConnectorPageSinkProvider pageSinkProvider, ConnectorNodePartitioningProvider nodePartitioningProvider, @@ -96,6 +99,7 @@ public IcebergConnector( this.lifeCycleManager = requireNonNull(lifeCycleManager, "lifeCycleManager is null"); this.transactionManager = requireNonNull(transactionManager, "transactionManager is null"); this.splitManager = requireNonNull(splitManager, "splitManager is null"); + this.cacheMetadata = requireNonNull(cacheMetadata, "cacheMetadata is null"); this.pageSourceProviderFactory = requireNonNull(pageSourceProviderFactory, "pageSourceProviderFactory is null"); this.pageSinkProvider = requireNonNull(pageSinkProvider, "pageSinkProvider is null"); this.nodePartitioningProvider = requireNonNull(nodePartitioningProvider, "nodePartitioningProvider is null"); @@ -134,6 +138,12 @@ public ConnectorSplitManager getSplitManager() return splitManager; } + @Override + public ConnectorCacheMetadata getCacheMetadata() + { + return cacheMetadata; + } + @Override public ConnectorPageSourceProviderFactory getPageSourceProviderFactory() { diff --git a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergModule.java b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergModule.java index 0992e51d2375..8db02f1d6583 100644 --- a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergModule.java +++ b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergModule.java @@ -22,6 +22,7 @@ import com.google.inject.Singleton; import com.google.inject.multibindings.Multibinder; import io.trino.filesystem.cache.CacheKeyProvider; +import io.trino.plugin.base.classloader.ClassLoaderSafeConnectorCacheMetadata; import io.trino.plugin.base.classloader.ClassLoaderSafeConnectorPageSinkProvider; import io.trino.plugin.base.classloader.ClassLoaderSafeConnectorPageSourceProviderFactory; import 
io.trino.plugin.base.classloader.ClassLoaderSafeConnectorSplitManager; @@ -37,6 +38,8 @@ import io.trino.plugin.hive.orc.OrcWriterConfig; import io.trino.plugin.hive.parquet.ParquetReaderConfig; import io.trino.plugin.hive.parquet.ParquetWriterConfig; +import io.trino.plugin.hive.util.BlockJsonSerde; +import io.trino.plugin.hive.util.HiveBlockEncodingSerde; import io.trino.plugin.iceberg.cache.IcebergCacheKeyProvider; import io.trino.plugin.iceberg.catalog.rest.DefaultIcebergFileSystemFactory; import io.trino.plugin.iceberg.functions.IcebergFunctionProvider; @@ -51,6 +54,8 @@ import io.trino.plugin.iceberg.procedure.RemoveOrphanFilesTableProcedure; import io.trino.plugin.iceberg.procedure.RollbackToSnapshotProcedure; import io.trino.plugin.iceberg.procedure.UnregisterTableProcedure; +import io.trino.spi.block.Block; +import io.trino.spi.cache.ConnectorCacheMetadata; import io.trino.spi.catalog.CatalogName; import io.trino.spi.connector.ConnectorNodePartitioningProvider; import io.trino.spi.connector.ConnectorPageSinkProvider; @@ -69,6 +74,7 @@ import static com.google.inject.multibindings.OptionalBinder.newOptionalBinder; import static io.airlift.concurrent.Threads.daemonThreadsNamed; import static io.airlift.configuration.ConfigBinder.configBinder; +import static io.airlift.json.JsonBinder.jsonBinder; import static io.airlift.json.JsonCodecBinder.jsonCodecBinder; import static io.trino.plugin.base.ClosingBinder.closingBinder; import static java.util.concurrent.Executors.newCachedThreadPool; @@ -117,6 +123,19 @@ public void configure(Binder binder) binder.bind(FileFormatDataSourceStats.class).in(Scopes.SINGLETON); newExporter(binder).export(FileFormatDataSourceStats.class).withGeneratedName(); + binder.bind(ConnectorCacheMetadata.class).annotatedWith(ForClassLoaderSafe.class).to(IcebergCacheMetadata.class).in(Scopes.SINGLETON); + binder.bind(ConnectorCacheMetadata.class).to(ClassLoaderSafeConnectorCacheMetadata.class).in(Scopes.SINGLETON); + + // for table handle, 
column handle and split ids + jsonCodecBinder(binder).bindJsonCodec(IcebergCacheTableId.class); + jsonCodecBinder(binder).bindJsonCodec(IcebergCacheSplitId.class); + jsonCodecBinder(binder).bindJsonCodec(IcebergColumnHandle.class); + + // bind block serializers for the purpose of TupleDomain serde + binder.bind(HiveBlockEncodingSerde.class).in(Scopes.SINGLETON); + jsonBinder(binder).addSerializerBinding(Block.class).to(BlockJsonSerde.Serializer.class); + jsonBinder(binder).addDeserializerBinding(Block.class).to(BlockJsonSerde.Deserializer.class); + binder.bind(IcebergFileWriterFactory.class).in(Scopes.SINGLETON); newExporter(binder).export(IcebergFileWriterFactory.class).withGeneratedName(); diff --git a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergPageSourceProvider.java b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergPageSourceProvider.java index 4b5abd936d98..44ca44f2bc14 100644 --- a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergPageSourceProvider.java +++ b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergPageSourceProvider.java @@ -299,10 +299,9 @@ public ConnectorPageSource createPageSource( .forEach(requiredColumns::add); TupleDomain effectivePredicate = getUnenforcedPredicate( - tableSchema, - partitionKeys, - dynamicFilter, + new SplitSpec(tableSchema, partitionSpec, partitionKeys), unenforcedPredicate, + dynamicFilter.getCurrentPredicate().transformKeys(IcebergColumnHandle.class::cast), fileStatisticsDomain); if (effectivePredicate.isNone()) { return new EmptyPageSource(); @@ -481,7 +480,94 @@ private ConnectorPageSource openDeletes( .get(); } - private ReaderPageSourceWithRowPositions createDataPageSource( + @Override + public TupleDomain getUnenforcedPredicate( + ConnectorSession session, + ConnectorSplit split, + ConnectorTableHandle tableHandle, + TupleDomain dynamicFilter) + { + IcebergSplit icebergSplit = (IcebergSplit) split; + return getUnenforcedPredicate( + 
getSplitSpec(tableHandle, split), + ((IcebergTableHandle) tableHandle).getUnenforcedPredicate(), + dynamicFilter, + icebergSplit.getFileStatisticsDomain()) + .transformKeys(ColumnHandle.class::cast); + } + + private TupleDomain getUnenforcedPredicate( + SplitSpec splitSpec, + TupleDomain unenforcedPredicate, + TupleDomain dynamicFilter, + TupleDomain fileStatisticsDomain) + { + return prunePredicate( + splitSpec, + // We reach here when we could not prune the split using file level stats, table predicate + // and the dynamic filter in the coordinator during split generation. The file level stats + // in IcebergSplit#fileStatisticsDomain could help to prune this split when a more selective dynamic filter + // is available now, without having to access parquet/orc file footer for row-group/stripe stats. + TupleDomain.intersect(ImmutableList.of( + unenforcedPredicate, + fileStatisticsDomain, + dynamicFilter)), + fileStatisticsDomain) + .simplify(ICEBERG_DOMAIN_COMPACTION_THRESHOLD); + } + + @Override + public TupleDomain prunePredicate( + ConnectorSession session, + ConnectorSplit split, + ConnectorTableHandle tableHandle, + TupleDomain predicate) + { + IcebergSplit icebergSplit = (IcebergSplit) split; + return prunePredicate( + getSplitSpec(tableHandle, split), + predicate.transformKeys(IcebergColumnHandle.class::cast), + icebergSplit.getFileStatisticsDomain()) + .transformKeys(ColumnHandle.class::cast); + } + + private TupleDomain prunePredicate(SplitSpec splitSpec, TupleDomain predicate, TupleDomain fileStatisticsDomain) + { + if (predicate.isNone() || predicate.isAll()) { + return predicate.transformKeys(IcebergColumnHandle.class::cast); + } + + Set partitionColumns = splitSpec.partitionKeys.keySet().stream() + .map(fieldId -> getColumnHandle(splitSpec.tableSchema.findField(fieldId), typeManager)) + .collect(toImmutableSet()); + Supplier> partitionValues = memoize(() -> getPartitionValues(partitionColumns, splitSpec.partitionKeys)); + + if 
(!partitionMatchesPredicate(partitionColumns, partitionValues, predicate.transformKeys(IcebergColumnHandle.class::cast))) { + return TupleDomain.none(); + } + + return predicate.transformKeys(IcebergColumnHandle.class::cast) + .filter((columnHandle, domain) -> !partitionColumns.contains(columnHandle)) + // remove domains from predicate that fully contain split data because they are irrelevant for filtering + .filter((handle, domain) -> !domain.contains(fileStatisticsDomain.getDomain(handle, domain.getType()))); + } + + private static SplitSpec getSplitSpec(ConnectorTableHandle tableHandle, ConnectorSplit split) + { + IcebergTableHandle icebergTableHandle = (IcebergTableHandle) tableHandle; + IcebergSplit icebergSplit = (IcebergSplit) split; + Schema tableSchema = SchemaParser.fromJson(icebergTableHandle.getTableSchemaJson()); + PartitionSpec partitionSpec = PartitionSpecParser.fromJson(tableSchema, icebergSplit.getPartitionSpecJson()); + Map> partitionKeys = getPartitionKeys(tableSchema, partitionSpec, icebergSplit.getPartitionDataJson()); + + return new SplitSpec(tableSchema, partitionSpec, partitionKeys); + } + + private record SplitSpec(Schema tableSchema, PartitionSpec partitionSpec, Map> partitionKeys) + { + } + + public ReaderPageSourceWithRowPositions createDataPageSource( ConnectorSession session, TrinoInputFile inputFile, long start, diff --git a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergSplitManager.java b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergSplitManager.java index f23983a4a2c3..000eb1f8815a 100644 --- a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergSplitManager.java +++ b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergSplitManager.java @@ -16,12 +16,16 @@ import com.google.common.collect.ImmutableList; import com.google.common.util.concurrent.ListeningExecutorService; import com.google.inject.Inject; +import io.airlift.json.JsonCodec; import 
io.airlift.units.Duration; import io.trino.filesystem.cache.CachingHostAddressProvider; import io.trino.plugin.base.classloader.ClassLoaderSafeConnectorSplitSource; import io.trino.plugin.iceberg.functions.tablechanges.TableChangesFunctionHandle; import io.trino.plugin.iceberg.functions.tablechanges.TableChangesSplitSource; +import io.trino.spi.SplitWeight; +import io.trino.spi.cache.CacheSplitId; import io.trino.spi.connector.ConnectorSession; +import io.trino.spi.connector.ConnectorSplit; import io.trino.spi.connector.ConnectorSplitManager; import io.trino.spi.connector.ConnectorSplitSource; import io.trino.spi.connector.ConnectorTableHandle; @@ -39,6 +43,7 @@ import org.apache.iceberg.Table; import org.apache.iceberg.util.SnapshotUtil; +import java.util.Optional; import java.util.concurrent.ExecutorService; import static io.trino.plugin.iceberg.IcebergSessionProperties.getDynamicFilteringWaitTimeout; @@ -56,6 +61,7 @@ public class IcebergSplitManager private final IcebergFileSystemFactory fileSystemFactory; private final ListeningExecutorService splitSourceExecutor; private final ExecutorService icebergPlanningExecutor; + private final JsonCodec splitIdCodec; private final CachingHostAddressProvider cachingHostAddressProvider; @Inject @@ -65,6 +71,7 @@ public IcebergSplitManager( IcebergFileSystemFactory fileSystemFactory, @ForIcebergSplitManager ListeningExecutorService splitSourceExecutor, @ForIcebergScanPlanning ExecutorService icebergPlanningExecutor, + JsonCodec splitIdCodec, CachingHostAddressProvider cachingHostAddressProvider) { this.transactionManager = requireNonNull(transactionManager, "transactionManager is null"); @@ -72,6 +79,7 @@ public IcebergSplitManager( this.fileSystemFactory = requireNonNull(fileSystemFactory, "fileSystemFactory is null"); this.splitSourceExecutor = requireNonNull(splitSourceExecutor, "splitSourceExecutor is null"); this.icebergPlanningExecutor = requireNonNull(icebergPlanningExecutor, "icebergPlanningExecutor is null"); + 
this.splitIdCodec = requireNonNull(splitIdCodec, "splitIdCodec is null"); this.cachingHostAddressProvider = requireNonNull(cachingHostAddressProvider, "cachingHostAddressProvider is null"); } @@ -160,4 +168,37 @@ public ConnectorSplitSource getSplits( throw new IllegalStateException("Unknown table function: " + function); } + + @Override + public Optional getCacheSplitId(ConnectorSplit split) + { + IcebergSplit icebergSplit = (IcebergSplit) split; + + // ensure cache id generation is revisited whenever handle classes change + icebergSplit = new IcebergSplit( + // database and table names are already part of table id + icebergSplit.getPath(), + icebergSplit.getStart(), + icebergSplit.getLength(), + icebergSplit.getFileSize(), + icebergSplit.getFileRecordCount(), + icebergSplit.getFileFormat(), + icebergSplit.getPartitionSpecJson(), + icebergSplit.getPartitionDataJson(), + icebergSplit.getDeletes(), + // weight does not impact split rows + SplitWeight.standard(), + icebergSplit.getFileStatisticsDomain(), + icebergSplit.getFileIoProperties(), + icebergSplit.getDataSequenceNumber()); + + return Optional.of(new CacheSplitId(splitIdCodec.toJson(new IcebergCacheSplitId( + icebergSplit.getPath(), + icebergSplit.getStart(), + icebergSplit.getLength(), + icebergSplit.getFileSize(), + icebergSplit.getPartitionSpecJson(), + icebergSplit.getPartitionDataJson(), + icebergSplit.getDeletes())))); + } } diff --git a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergUtil.java b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergUtil.java index ecd82d188260..20ee397a0346 100644 --- a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergUtil.java +++ b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergUtil.java @@ -733,6 +733,15 @@ public static Map> getPartitionKeys(FileScanTask scanT return getPartitionKeys(scanTask.file().partition(), scanTask.spec()); } + public static Map> getPartitionKeys(Schema tableSchema, 
PartitionSpec partitionSpec, String partitionDataJson) + { + org.apache.iceberg.types.Type[] partitionColumnTypes = partitionSpec.fields().stream() + .map(field -> field.transform().getResultType(tableSchema.findType(field.sourceId()))) + .toArray(org.apache.iceberg.types.Type[]::new); + PartitionData partitionData = PartitionData.fromJson(partitionDataJson, partitionColumnTypes); + return getPartitionKeys(partitionData, partitionSpec); + } + public static Map> getPartitionKeys(StructLike partition, PartitionSpec spec) { Map fieldToIndex = getIdentityPartitions(spec); diff --git a/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergCacheIds.java b/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergCacheIds.java new file mode 100644 index 000000000000..c41c7028b9b8 --- /dev/null +++ b/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergCacheIds.java @@ -0,0 +1,488 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.plugin.iceberg; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import io.airlift.json.JsonCodec; +import io.airlift.json.JsonCodecFactory; +import io.airlift.json.ObjectMapperProvider; +import io.airlift.tracing.Tracing; +import io.trino.filesystem.cache.DefaultCachingHostAddressProvider; +import io.trino.plugin.base.TypeDeserializer; +import io.trino.plugin.hive.NodeVersion; +import io.trino.plugin.hive.metastore.file.FileHiveMetastoreConfig; +import io.trino.plugin.hive.metastore.file.FileHiveMetastoreFactory; +import io.trino.plugin.hive.util.HiveBlockEncodingSerde; +import io.trino.plugin.iceberg.catalog.file.FileMetastoreTableOperationsProvider; +import io.trino.plugin.iceberg.catalog.hms.TrinoHiveCatalogFactory; +import io.trino.plugin.iceberg.catalog.rest.DefaultIcebergFileSystemFactory; +import io.trino.plugin.iceberg.delete.DeleteFile; +import io.trino.spi.SplitWeight; +import io.trino.spi.block.Block; +import io.trino.spi.block.TestingBlockJsonSerde; +import io.trino.spi.cache.CacheTableId; +import io.trino.spi.catalog.CatalogName; +import io.trino.spi.connector.CatalogHandle; +import io.trino.spi.connector.SchemaTableName; +import io.trino.spi.predicate.Domain; +import io.trino.spi.predicate.Range; +import io.trino.spi.predicate.TupleDomain; +import io.trino.spi.predicate.ValueSet; +import io.trino.spi.type.LongTimestampWithTimeZone; +import io.trino.spi.type.TestingTypeManager; +import io.trino.spi.type.Type; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.PartitionSpecParser; +import org.apache.iceberg.Schema; +import org.apache.iceberg.types.Types; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; +import org.junit.jupiter.api.parallel.Execution; + +import java.io.File; +import java.io.IOException; 
+import java.nio.file.Files; +import java.time.LocalDate; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.OptionalLong; +import java.util.Set; +import java.util.concurrent.atomic.AtomicInteger; + +import static com.google.common.io.MoreFiles.deleteRecursively; +import static com.google.common.io.RecursiveDeleteOption.ALLOW_INSECURE; +import static com.google.common.util.concurrent.MoreExecutors.newDirectExecutorService; +import static io.trino.plugin.hive.HiveTestUtils.HDFS_FILE_SYSTEM_FACTORY; +import static io.trino.plugin.iceberg.ColumnIdentity.primitiveColumnIdentity; +import static io.trino.spi.predicate.Domain.singleValue; +import static io.trino.spi.type.BigintType.BIGINT; +import static io.trino.spi.type.TimeZoneKey.UTC_KEY; +import static io.trino.spi.type.TimestampWithTimeZoneType.TIMESTAMP_TZ_MICROS; +import static io.trino.spi.type.Timestamps.MILLISECONDS_PER_DAY; +import static io.trino.spi.type.Timestamps.MILLISECONDS_PER_SECOND; +import static io.trino.type.InternalTypeManager.TESTING_TYPE_MANAGER; +import static java.time.ZoneOffset.UTC; +import static java.util.Map.entry; +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS; +import static org.junit.jupiter.api.parallel.ExecutionMode.CONCURRENT; + +@TestInstance(PER_CLASS) +@Execution(CONCURRENT) +public class TestIcebergCacheIds +{ + private static final String DATABASE_NAME = "iceberg_cache"; + private IcebergCacheMetadata icebergMetadata; + private IcebergSplitManager splitManager; + private File tempDir; + private static final AtomicInteger nextColumnId = new AtomicInteger(1); + + @BeforeAll + public void setup() + throws IOException + { + tempDir = Files.createTempDirectory(null).toFile(); + FileMetastoreTableOperationsProvider tableOperationsProvider = new FileMetastoreTableOperationsProvider(HDFS_FILE_SYSTEM_FACTORY); + IcebergConfig icebergConfig = new IcebergConfig(); + 
IcebergMetadataFactory icebergMetadataFactory = new IcebergMetadataFactory( + TESTING_TYPE_MANAGER, + CatalogHandle.fromId("iceberg:NORMAL:v12345"), + createJsonCodec(CommitTaskData.class), + new TrinoHiveCatalogFactory( + icebergConfig, + new CatalogName("iceberg"), + new FileHiveMetastoreFactory( + new NodeVersion("test_version"), + HDFS_FILE_SYSTEM_FACTORY, + true, + new FileHiveMetastoreConfig() + .setCatalogDirectory(tempDir.toURI().toString()) + .setMetastoreUser("user"), + Tracing.noopTracer()), + HDFS_FILE_SYSTEM_FACTORY, + TESTING_TYPE_MANAGER, + tableOperationsProvider, + new NodeVersion("test_version"), + new IcebergSecurityConfig(), + newDirectExecutorService()), + new DefaultIcebergFileSystemFactory(HDFS_FILE_SYSTEM_FACTORY), + new TableStatisticsWriter(new NodeVersion("test-version")), + Optional.empty(), + icebergConfig); + icebergMetadata = new IcebergCacheMetadata( + createJsonCodec(IcebergCacheTableId.class), + createJsonCodec(IcebergColumnHandle.class)); + splitManager = new IcebergSplitManager( + new IcebergTransactionManager(icebergMetadataFactory), + TESTING_TYPE_MANAGER, + new DefaultIcebergFileSystemFactory(HDFS_FILE_SYSTEM_FACTORY), + newDirectExecutorService(), + newDirectExecutorService(), + createJsonCodec(IcebergCacheSplitId.class), + new DefaultCachingHostAddressProvider()); + } + + @AfterAll + public void tearDown() + throws IOException + { + if (tempDir != null) { + deleteRecursively(tempDir.toPath(), ALLOW_INSECURE); + } + } + + @Test + public void testTableId() + { + IcebergColumnHandle bigIntColumnHandle = newPrimitiveColumn(BIGINT); + IcebergColumnHandle timestampColumnHandle = newPrimitiveColumn(TIMESTAMP_TZ_MICROS); + Optional partitionSpecJson = Optional.of("partitionSpecJson"); + SchemaTableName schemaTableName = new SchemaTableName(DATABASE_NAME, "testing"); + CatalogHandle catalogHandle = CatalogHandle.fromId("iceberg:NORMAL:v12345"); + + // table id without snapshot id is empty + assertThat(icebergMetadata.getCacheTableId( 
+ createIcebergTableHandle(catalogHandle, schemaTableName, Optional.empty(), "tableSchemaJson", partitionSpecJson, Set.of(), Optional.empty(), "location"))) + .isEqualTo(Optional.empty()); + + // `catalogHandle` should be part of table id + assertThat(icebergMetadata.getCacheTableId( + createIcebergTableHandle(catalogHandle, schemaTableName, "tableSchemaJson", partitionSpecJson, Set.of(), Optional.empty(), "location"))) + .isNotEqualTo(icebergMetadata.getCacheTableId( + createIcebergTableHandle(CatalogHandle.fromId("iceberg:NORMAL:v12346"), schemaTableName, "tableSchemaJson", partitionSpecJson, Set.of(), Optional.empty(), "location"))); + + // `schemaName` should be part of table id + assertThat(icebergMetadata.getCacheTableId( + createIcebergTableHandle(catalogHandle, schemaTableName, "tableSchemaJson", partitionSpecJson, Set.of(), Optional.empty(), "location"))) + .isNotEqualTo(icebergMetadata.getCacheTableId( + createIcebergTableHandle(catalogHandle, new SchemaTableName("different", "testing"), "tableSchemaJson", partitionSpecJson, Set.of(), Optional.empty(), "location"))); + + // `tableName` should be part of table id + assertThat(icebergMetadata.getCacheTableId( + createIcebergTableHandle(catalogHandle, schemaTableName, "tableSchemaJson", partitionSpecJson, Set.of(), Optional.empty(), "location"))) + .isNotEqualTo(icebergMetadata.getCacheTableId( + createIcebergTableHandle(catalogHandle, new SchemaTableName(DATABASE_NAME, "different"), "tableSchemaJson", partitionSpecJson, Set.of(), Optional.empty(), "location"))); + + // `tableSchemaJson` is not part of table id + assertThat(icebergMetadata.getCacheTableId( + createIcebergTableHandle(catalogHandle, schemaTableName, "tableSchemaJson", partitionSpecJson, Set.of(), Optional.empty(), "location"))) + .isEqualTo(icebergMetadata.getCacheTableId( + createIcebergTableHandle(catalogHandle, schemaTableName, "different", partitionSpecJson, Set.of(), Optional.empty(), "location"))); + + // `partitionSpecJson` is not part 
of table id + assertThat(icebergMetadata.getCacheTableId( + createIcebergTableHandle(catalogHandle, schemaTableName, "tableSchemaJson", partitionSpecJson, Set.of(), Optional.empty(), "location"))) + .isEqualTo(icebergMetadata.getCacheTableId( + createIcebergTableHandle(catalogHandle, schemaTableName, "tableSchemaJson", Optional.of("different"), Set.of(), Optional.empty(), "location"))); + + // `projectedColumns` is not part of table id + assertThat(icebergMetadata.getCacheTableId( + createIcebergTableHandle(catalogHandle, schemaTableName, "tableSchemaJson", partitionSpecJson, Set.of(bigIntColumnHandle), Optional.empty(), "location"))) + .isEqualTo(icebergMetadata.getCacheTableId( + createIcebergTableHandle(catalogHandle, schemaTableName, "tableSchemaJson", Optional.of("different"), Set.of(), Optional.empty(), "location"))); + + // `nameMappingJson` is not part of table id + assertThat(icebergMetadata.getCacheTableId( + createIcebergTableHandle(catalogHandle, schemaTableName, "tableSchemaJson", partitionSpecJson, Set.of(), Optional.empty(), "location"))) + .isEqualTo(icebergMetadata.getCacheTableId( + createIcebergTableHandle(catalogHandle, schemaTableName, "tableSchemaJson", Optional.of("different"), Set.of(), Optional.of("different"), "location"))); + + // `location` should be part of table id + assertThat(icebergMetadata.getCacheTableId( + createIcebergTableHandle(catalogHandle, schemaTableName, "tableSchemaJson", partitionSpecJson, Set.of(), Optional.empty(), "location"))) + .isNotEqualTo(icebergMetadata.getCacheTableId( + createIcebergTableHandle(catalogHandle, schemaTableName, "tableSchemaJson", partitionSpecJson, Set.of(), Optional.empty(), "different"))); + + // unenforce predicate should not be part of table id + assertThat(icebergMetadata.getCacheTableId(createIcebergTableHandle(catalogHandle, schemaTableName, "tableSchemaJson", partitionSpecJson, TupleDomain.withColumnDomains(ImmutableMap.of(bigIntColumnHandle, singleValue(BIGINT, 1L))), 
TupleDomain.all(), Set.of(), Optional.empty(), "location"))) + .isEqualTo(icebergMetadata.getCacheTableId( + createIcebergTableHandle(catalogHandle, schemaTableName, "tableSchemaJson", partitionSpecJson, Set.of(), Optional.empty(), "location"))); + + // unenforce predicate timestamp(6) should be part of table id + LocalDate someDate = LocalDate.of(2022, 3, 22); + + long startOfDateUtcEpochMillis = someDate.atStartOfDay().toEpochSecond(UTC) * MILLISECONDS_PER_SECOND; + LongTimestampWithTimeZone startOfDateUtc = timestampTzFromEpochMillis(startOfDateUtcEpochMillis); + LongTimestampWithTimeZone startOfNextDateUtc = timestampTzFromEpochMillis(startOfDateUtcEpochMillis + MILLISECONDS_PER_DAY); + assertThat(icebergMetadata.getCacheTableId(createIcebergTableHandle( + catalogHandle, + schemaTableName, + "tableSchemaJson", + partitionSpecJson, + TupleDomain.withColumnDomains(Map.of(timestampColumnHandle, Domain.create(ValueSet.ofRanges(Range.range(TIMESTAMP_TZ_MICROS, startOfDateUtc, true, startOfNextDateUtc, false)), false))), + TupleDomain.all(), + Set.of(), + Optional.empty(), + "location"))) + .isEqualTo(Optional.of(new CacheTableId("{\"catalog\":\"iceberg:normal:v12345\",\"schemaName\":\"iceberg_cache\",\"tableName\":\"testing\",\"tableLocation\":\"location\",\"storageProperties\":{}}"))); + + // enforce predicate is not part of table id + assertThat(icebergMetadata.getCacheTableId( + createIcebergTableHandle(catalogHandle, schemaTableName, "tableSchemaJson", partitionSpecJson, TupleDomain.all(), TupleDomain.withColumnDomains(ImmutableMap.of(bigIntColumnHandle, singleValue(BIGINT, 1L))), Set.of(), Optional.empty(), "location"))) + .isEqualTo(icebergMetadata.getCacheTableId( + createIcebergTableHandle(catalogHandle, schemaTableName, "tableSchemaJson", partitionSpecJson, Set.of(), Optional.empty(), "location"))); + + // storage options is part of table id + assertThat(icebergMetadata.getCacheTableId( + createIcebergTableHandle(catalogHandle, schemaTableName, 
Map.of("read.split.target-size", "1")))) + .isNotEqualTo(icebergMetadata.getCacheTableId( + createIcebergTableHandle(catalogHandle, schemaTableName, "tableSchemaJson", partitionSpecJson, Set.of(), Optional.empty(), "location"))); + + // statistics in storage options support was dropped in https://github.com/trinodb/trino/pull/19803, so it is part of table id if exists + assertThat(icebergMetadata.getCacheTableId( + createIcebergTableHandle(catalogHandle, schemaTableName, Map.of("trino.stats.ndv.1231.ndv", "111", "other", "other")))) + .isEqualTo(icebergMetadata.getCacheTableId( + createIcebergTableHandle(catalogHandle, schemaTableName, Map.of("trino.stats.ndv.1231.ndv", "111", "other", "other")))); + } + + @Test + public void testSplitId() + { + String unpartitionedPartitionSpecJson = PartitionSpecParser.toJson(PartitionSpec.unpartitioned()); + String unpartitionedPartitionDataJson = PartitionData.toJson(new PartitionData(new Object[] {})); + assertThat(splitManager.getCacheSplitId(createIcebergSplit("path", 0, 10, 10, IcebergFileFormat.ORC, unpartitionedPartitionSpecJson, unpartitionedPartitionDataJson, List.of()))) + .isEqualTo(splitManager.getCacheSplitId(createIcebergSplit("path", 0, 10, 10, IcebergFileFormat.ORC, unpartitionedPartitionSpecJson, unpartitionedPartitionDataJson, List.of()))); + + // different path should make ids different + assertThat(splitManager.getCacheSplitId(createIcebergSplit("path1", 0, 10, 10, IcebergFileFormat.ORC, unpartitionedPartitionSpecJson, unpartitionedPartitionDataJson, List.of()))) + .isNotEqualTo(splitManager.getCacheSplitId(createIcebergSplit("path2", 0, 10, 10, IcebergFileFormat.ORC, unpartitionedPartitionSpecJson, unpartitionedPartitionDataJson, List.of()))); + + // different start position should make ids different + assertThat(splitManager.getCacheSplitId(createIcebergSplit("path", 0, 10, 10, IcebergFileFormat.ORC, unpartitionedPartitionSpecJson, unpartitionedPartitionDataJson, List.of()))) + 
.isNotEqualTo(splitManager.getCacheSplitId(createIcebergSplit("path", 10, 10, 10, IcebergFileFormat.ORC, unpartitionedPartitionSpecJson, unpartitionedPartitionDataJson, List.of()))); + + // different length should make ids different + assertThat(splitManager.getCacheSplitId(createIcebergSplit("path", 0, 10, 10, IcebergFileFormat.ORC, unpartitionedPartitionSpecJson, unpartitionedPartitionDataJson, List.of()))) + .isNotEqualTo(splitManager.getCacheSplitId(createIcebergSplit("path", 0, 20, 10, IcebergFileFormat.ORC, unpartitionedPartitionSpecJson, unpartitionedPartitionDataJson, List.of()))); + + // different fileSize should make ids different + assertThat(splitManager.getCacheSplitId(createIcebergSplit("path", 0, 10, 10, IcebergFileFormat.ORC, unpartitionedPartitionSpecJson, unpartitionedPartitionDataJson, List.of()))) + .isNotEqualTo(splitManager.getCacheSplitId(createIcebergSplit("path", 0, 10, 100, IcebergFileFormat.ORC, unpartitionedPartitionSpecJson, unpartitionedPartitionDataJson, List.of()))); + + // different file format should make ids different + assertThat(splitManager.getCacheSplitId(createIcebergSplit("path", 0, 10, 10, IcebergFileFormat.ORC, unpartitionedPartitionSpecJson, unpartitionedPartitionDataJson, List.of()))) + .isNotEqualTo(splitManager.getCacheSplitId(createIcebergSplit("path", 0, 10, 100, IcebergFileFormat.PARQUET, unpartitionedPartitionSpecJson, unpartitionedPartitionDataJson, List.of()))); + + // different partitionSpecJson should make ids different + String partitionSpecJson1 = PartitionSpecParser.toJson( + PartitionSpec.builderFor(new Schema( + List.of(Types.NestedField.required(0, "field 1", Types.IntegerType.get())))) + .build()); + String partitionSpecJson2 = PartitionSpecParser.toJson( + PartitionSpec.builderFor(new Schema( + List.of(Types.NestedField.required(0, "field 1", Types.IntegerType.get()), + Types.NestedField.required(1, "field 2", Types.IntegerType.get())))) + .build()); + 
assertThat(splitManager.getCacheSplitId(createIcebergSplit("path", 0, 10, 10, IcebergFileFormat.ORC, partitionSpecJson1, unpartitionedPartitionDataJson, List.of()))) + .isNotEqualTo(splitManager.getCacheSplitId(createIcebergSplit("path", 0, 10, 100, IcebergFileFormat.PARQUET, partitionSpecJson2, unpartitionedPartitionDataJson, List.of()))); + + // different partitionDataJson should make ids different + assertThat(splitManager.getCacheSplitId(createIcebergSplit("path", 0, 10, 10, IcebergFileFormat.ORC, unpartitionedPartitionSpecJson, unpartitionedPartitionDataJson, List.of()))) + .isNotEqualTo(splitManager.getCacheSplitId(createIcebergSplit("path", 0, 10, 10, IcebergFileFormat.PARQUET, unpartitionedPartitionSpecJson, PartitionData.toJson(new PartitionData(new Long[] { + 1L})), List.of()))); + } + + @Test + public void testCacheableStorageProperty() + { + assertThat(IcebergCacheTableId.isCacheableStorageProperty(entry("not-cacheable", "test"))).isFalse(); + assertThat(IcebergCacheTableId.isCacheableStorageProperty(entry("trino.stats.ndv.1231.ndv", "test"))).isFalse(); + assertThat(IcebergCacheTableId.isCacheableStorageProperty(entry("fileloader.enabled", "test"))).isFalse(); + assertThat(IcebergCacheTableId.isCacheableStorageProperty(entry("read.parquet.vectorization.batch-size", "test"))).isTrue(); + assertThat(IcebergCacheTableId.isCacheableStorageProperty(entry("read.split.target-size", "test"))).isTrue(); + } + + private static IcebergSplit createIcebergSplit( + String path, + long start, + long length, + long fileSize, + IcebergFileFormat fileFormat, + String partitionSpecJson, + String partitionDataJson, + List deletes) + { + return new IcebergSplit( + path, + start, + length, + fileSize, + 0L, + fileFormat, + partitionSpecJson, + partitionDataJson, + deletes, + SplitWeight.standard(), + TupleDomain.all(), + ImmutableMap.of(), + 0L); + } + + private static IcebergTableHandle createIcebergTableHandle( + CatalogHandle catalogHandle, + SchemaTableName 
schemaTableName, + Optional snapshotId, + String tableSchemaJson, + Optional partitionSpecJson, + Set projectedColumns, + Optional nameMappingJson, + String tableLocation) + { + return new IcebergTableHandle( + catalogHandle, + schemaTableName.getSchemaName(), + schemaTableName.getTableName(), + TableType.DATA, + snapshotId, + tableSchemaJson, + partitionSpecJson, + 2, + TupleDomain.all(), + TupleDomain.all(), + OptionalLong.empty(), + projectedColumns, + nameMappingJson, + tableLocation, + Map.of(), + true, + Optional.empty(), + ImmutableSet.of(), + Optional.empty()); + } + + private static IcebergTableHandle createIcebergTableHandle( + CatalogHandle catalogHandle, + SchemaTableName schemaTableName, + String tableSchemaJson, + Optional partitionSpecJson, + Set projectedColumns, + Optional nameMappingJson, + String tableLocation) + { + return new IcebergTableHandle( + catalogHandle, + schemaTableName.getSchemaName(), + schemaTableName.getTableName(), + TableType.DATA, + Optional.of(1L), + tableSchemaJson, + partitionSpecJson, + 2, + TupleDomain.all(), + TupleDomain.all(), + OptionalLong.empty(), + projectedColumns, + nameMappingJson, + tableLocation, + Map.of(), + true, + Optional.empty(), + ImmutableSet.of(), + Optional.empty()); + } + + private static IcebergTableHandle createIcebergTableHandle( + CatalogHandle catalogHandle, + SchemaTableName schemaTableName, + String tableSchemaJson, + Optional partitionSpecJson, + TupleDomain unenforcedPredicate, + TupleDomain enforcedPredicate, + Set projectedColumns, + Optional nameMappingJson, + String tableLocation) + { + return new IcebergTableHandle( + catalogHandle, + schemaTableName.getSchemaName(), + schemaTableName.getTableName(), + TableType.DATA, + Optional.of(1L), + tableSchemaJson, + partitionSpecJson, + 2, + unenforcedPredicate, + enforcedPredicate, + OptionalLong.empty(), + projectedColumns, + nameMappingJson, + tableLocation, + Map.of(), + true, + Optional.empty(), + ImmutableSet.of(), + Optional.empty()); + } 
+ + private static IcebergTableHandle createIcebergTableHandle( + CatalogHandle catalogHandle, + SchemaTableName schemaTableName, + Map storageProperties) + { + return new IcebergTableHandle( + catalogHandle, + schemaTableName.getSchemaName(), + schemaTableName.getTableName(), + TableType.DATA, + Optional.of(1L), + "tableSchemaJson", + Optional.of("partitionSpecJson"), + 2, + TupleDomain.all(), + TupleDomain.all(), + OptionalLong.empty(), + Set.of(), + Optional.empty(), + "tableLocation", + storageProperties, + true, + Optional.empty(), + ImmutableSet.of(), + Optional.empty()); + } + + private static LongTimestampWithTimeZone timestampTzFromEpochMillis(long epochMillis) + { + return LongTimestampWithTimeZone.fromEpochMillisAndFraction(epochMillis, 0, UTC_KEY); + } + + private static IcebergColumnHandle newPrimitiveColumn(Type type) + { + int id = nextColumnId.getAndIncrement(); + return new IcebergColumnHandle( + primitiveColumnIdentity(id, "column_" + id), + type, + ImmutableList.of(), + type, + true, + Optional.empty()); + } + + private static JsonCodec createJsonCodec(Class clazz) + { + ObjectMapperProvider objectMapperProvider = new ObjectMapperProvider(); + TypeDeserializer typeDeserializer = new TypeDeserializer(new TestingTypeManager()); + objectMapperProvider.setJsonDeserializers( + ImmutableMap.of( + Block.class, new TestingBlockJsonSerde.Deserializer(new HiveBlockEncodingSerde()), + Type.class, typeDeserializer)); + objectMapperProvider.setJsonSerializers(ImmutableMap.of(Block.class, new TestingBlockJsonSerde.Serializer(new HiveBlockEncodingSerde()))); + return new JsonCodecFactory(objectMapperProvider).jsonCodec(clazz); + } +} diff --git a/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergCacheSubqueriesTest.java b/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergCacheSubqueriesTest.java new file mode 100644 index 000000000000..28ac3d6c5ceb --- /dev/null +++ 
b/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergCacheSubqueriesTest.java @@ -0,0 +1,155 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.iceberg; + +import com.google.common.collect.ImmutableList; +import io.trino.metadata.QualifiedObjectName; +import io.trino.metadata.TableHandle; +import io.trino.testing.BaseCacheSubqueriesTest; +import io.trino.testing.QueryRunner; +import io.trino.testing.QueryRunner.MaterializedResultWithPlan; +import io.trino.testing.sql.TestTable; +import org.intellij.lang.annotations.Language; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.parallel.Execution; + +import java.util.List; +import java.util.Optional; + +import static io.trino.plugin.iceberg.IcebergQueryRunner.ICEBERG_CATALOG; +import static java.lang.String.format; +import static java.util.stream.Collectors.joining; +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.parallel.ExecutionMode.SAME_THREAD; + +@Execution(SAME_THREAD) +public class TestIcebergCacheSubqueriesTest + extends BaseCacheSubqueriesTest +{ + @Override + protected QueryRunner createQueryRunner() + throws Exception + { + return IcebergQueryRunner.builder() + .setExtraProperties(EXTRA_PROPERTIES) + .setInitialTables(REQUIRED_TABLES) + .addIcebergProperty("iceberg.dynamic-filtering.wait-timeout", "20s") + .build(); + } + + @Test + public void 
testDoUsePartiallyCachedResultsWhenDataWasDeletedFromUnpartitionedTable() + { + try (TestTable testTable = new TestTable( + getQueryRunner()::execute, + "iceberg_do_not_cache", + "(name VARCHAR)", + ImmutableList.of("'value1'", "'value2'"))) { + // multi insert to place 2 values in single split + assertUpdate("insert into %s(name) values ('value3'), ('value4')".formatted(testTable.getName()), 2); + @Language("SQL") String selectQuery = "select name from %s union all select name from %s".formatted(testTable.getName(), testTable.getName()); + MaterializedResultWithPlan result = executeWithPlan(withCacheEnabled(), selectQuery); + assertThat(result.result().getRowCount()).isEqualTo(8); + assertThat(getScanOperatorInputPositions(result.queryId())).isPositive(); + + assertUpdate("delete from %s where name='value3'".formatted(testTable.getName()), 1); + result = executeWithPlan(withCacheEnabled(), selectQuery); + + assertThat(result.result().getRowCount()).isEqualTo(6); + assertThat(result.result().getMaterializedRows().stream().noneMatch(row -> row.getField(0).equals("value3"))).isTrue(); + // split with deleted file should trigger table scan + assertThat(getScanOperatorInputPositions(result.queryId())).isPositive(); + // after deletion cached data is partially reused + assertThat(getLoadCachedDataOperatorInputPositions(result.queryId())).isPositive(); + + result = executeWithPlan(withCacheEnabled(), selectQuery); + assertThat(getLoadCachedDataOperatorInputPositions(result.queryId())).isEqualTo(6); + assertThat(getScanOperatorInputPositions(result.queryId())).isZero(); + } + } + + @Test + public void testTimeTravelQueryCache() + { + try (TestTable testTable = new TestTable( + getQueryRunner()::execute, + "iceberg_timetravel", + "(year INT, name VARCHAR) with (partitioning = ARRAY['year'])", + ImmutableList.of("2000, 'value1'", "2001, 'value2'"))) { + Optional tableHandler = withTransaction(session -> getDistributedQueryRunner().getCoordinator() + 
.getPlannerContext().getMetadata() + .getTableHandle(session, new QualifiedObjectName(ICEBERG_CATALOG, session.getSchema().get(), testTable.getName()))); + IcebergTableHandle icebergTableHandle = (IcebergTableHandle) tableHandler.get().connectorHandle(); + + @Language("SQL") String selectQuery = """ + select name from %s where year = 2000 + union all + select name from %s FOR VERSION AS OF %s where year = 2000 + """.formatted(testTable.getName(), testTable.getName(), icebergTableHandle.getSnapshotId().get()); + + assertUpdate("insert into %s(year, name) values (2000, 'value3'), (2001, 'value4')".formatted(testTable.getName()), 2); + MaterializedResultWithPlan result = executeWithPlan(withCacheEnabled(), selectQuery); + // two rows from current snapshot and one row from previous + assertThat(result.result().getRowCount()).isEqualTo(3); + assertThat(getCacheDataOperatorInputPositions(result.queryId())).isEqualTo(2); + assertThat(getScanOperatorInputPositions(result.queryId())).isPositive(); + + assertUpdate("delete from %s where year = 2000".formatted(testTable.getName()), 2); + result = executeWithPlan(withCacheEnabled(), selectQuery); + assertThat(result.result().getRowCount()).isEqualTo(1); + assertThat(getLoadCachedDataOperatorInputPositions(result.queryId())).isEqualTo(1); + } + } + + @Test + public void testChangeWhenSchemaEvolved() + { + try (TestTable testTable = new TestTable( + getQueryRunner()::execute, + "iceberg_schema_evolved", + "(year INT, name VARCHAR) with (partitioning = ARRAY['year'])", + ImmutableList.of("2001, 'value1'", "2001, 'value2'"))) { + + // cache data with first query + @Language("SQL") String selectQuery = "SELECT name FROM %s WHERE year = 2001".formatted(testTable.getName()); + MaterializedResultWithPlan result = executeWithPlan(withCacheEnabled(), selectQuery); + assertThat(result.result().getRowCount()).isEqualTo(2); + assertThat(getLoadCachedDataOperatorInputPositions(result.queryId())).isEqualTo(0); + 
assertThat(getCacheDataOperatorInputPositions(result.queryId())).isEqualTo(2); + + // partitioning is changed and a new files are added with new partitioning + assertUpdate("ALTER TABLE %s SET PROPERTIES partitioning = ARRAY['name']".formatted(testTable.getName())); + assertUpdate("INSERT INTO %s(year, name) VALUES (2000, 'value5'), (2001, 'value1')".formatted(testTable.getName()), 2); + + // data for split with old partitioning should still be read from cache + result = executeWithPlan(withCacheEnabled(), selectQuery); + assertThat(result.result().getRowCount()).isEqualTo(3); + assertThat(getLoadCachedDataOperatorInputPositions(result.queryId())).isEqualTo(2); + assertThat(getScanOperatorInputPositions(result.queryId())).isPositive(); + } + } + + @Override + protected void createPartitionedTableAsSelect(String tableName, List partitionColumns, String asSelect) + { + + @Language("SQL") String sql = format( + "CREATE TABLE %s WITH (partitioning=array[%s]) as %s", + tableName, + partitionColumns.stream().map(column -> "'" + column + "'").collect(joining(",")), + asSelect); + + getQueryRunner().execute(sql); + } +} diff --git a/plugin/trino-memory-cache/pom.xml b/plugin/trino-memory-cache/pom.xml new file mode 100644 index 000000000000..e720187fd9ee --- /dev/null +++ b/plugin/trino-memory-cache/pom.xml @@ -0,0 +1,194 @@ + + + 4.0.0 + + + io.trino + trino-root + 467-SNAPSHOT + ../../pom.xml + + + trino-memory-cache + trino-plugin + Trino - Memory Cache + + + ${project.parent.basedir} + + + + + com.google.errorprone + error_prone_annotations + + + + com.google.guava + guava + + + + com.google.inject + guice + + + + io.airlift + bootstrap + + + + io.trino + trino-memory-context + + + + io.trino + trino-plugin-toolkit + + + + it.unimi.dsi + fastutil + + + + org.weakref + jmxutils + + + + com.fasterxml.jackson.core + jackson-annotations + provided + + + + io.airlift + slice + provided + + + + io.opentelemetry + opentelemetry-api + provided + + + + io.opentelemetry + 
opentelemetry-context + provided + + + + io.trino + trino-spi + provided + + + + org.openjdk.jol + jol-core + provided + + + + io.airlift + log + runtime + + + + io.airlift + log-manager + runtime + + + + io.airlift + node + runtime + + + + io.airlift + junit-extensions + test + + + + io.airlift + testing + test + + + + io.trino + trino-client + test + + + + io.trino + trino-spi + test-jar + test + + + + io.trino + trino-testing-services + test + + + + io.trino + trino-tpch + test + + + + io.trino.tpch + tpch + test + + + + org.assertj + assertj-core + test + + + + org.jetbrains + annotations + test + + + + org.junit.jupiter + junit-jupiter-api + test + + + + org.junit.jupiter + junit-jupiter-engine + test + + + + org.openjdk.jmh + jmh-core + test + + + + org.openjdk.jmh + jmh-generator-annprocess + test + + + diff --git a/plugin/trino-memory-cache/src/main/java/io/trino/plugin/memory/ConcurrentCacheManager.java b/plugin/trino-memory-cache/src/main/java/io/trino/plugin/memory/ConcurrentCacheManager.java new file mode 100644 index 000000000000..ff4023e6c8f7 --- /dev/null +++ b/plugin/trino-memory-cache/src/main/java/io/trino/plugin/memory/ConcurrentCacheManager.java @@ -0,0 +1,263 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.plugin.memory; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.util.concurrent.ListenableFuture; +import com.google.errorprone.annotations.concurrent.GuardedBy; +import com.google.inject.Inject; +import io.trino.memory.context.AggregatedMemoryContext; +import io.trino.memory.context.MemoryReservationHandler; +import io.trino.spi.cache.CacheColumnId; +import io.trino.spi.cache.CacheManager; +import io.trino.spi.cache.CacheManagerContext; +import io.trino.spi.cache.CacheSplitId; +import io.trino.spi.cache.MemoryAllocator; +import io.trino.spi.cache.PlanSignature; +import io.trino.spi.connector.ConnectorPageSink; +import io.trino.spi.connector.ConnectorPageSource; +import io.trino.spi.predicate.TupleDomain; +import org.weakref.jmx.Managed; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Objects; +import java.util.Optional; +import java.util.concurrent.atomic.AtomicReferenceArray; + +import static io.trino.memory.context.AggregatedMemoryContext.newRootAggregatedMemoryContext; +import static io.trino.spi.cache.PlanSignature.canonicalizePlanSignature; +import static java.lang.Math.floorMod; +import static java.lang.Math.min; +import static java.util.Collections.shuffle; +import static java.util.Objects.requireNonNull; + +/** + * Distributed cache requests between {@link MemoryCacheManager}s thus reducing locking pressure. 
+ */ +public class ConcurrentCacheManager + implements CacheManager +{ + private static final int CACHE_MANAGERS_COUNT = 128; + + private final MemoryAllocator revocableMemoryAllocator; + private final MemoryCacheManager[] cacheManagers; + @GuardedBy("this") + private long allocatedMemory; + + @Inject + public ConcurrentCacheManager(CacheManagerContext context) + { + this(context, false); + } + + @VisibleForTesting + ConcurrentCacheManager(CacheManagerContext context, boolean forceStore) + { + requireNonNull(context, "context is null"); + this.revocableMemoryAllocator = context.revocableMemoryAllocator(); + AggregatedMemoryContext memoryContext = newRootAggregatedMemoryContext(new CacheMemoryReservationHandler(), 0L); + cacheManagers = new MemoryCacheManager[CACHE_MANAGERS_COUNT]; + for (int i = 0; i < CACHE_MANAGERS_COUNT; i++) { + cacheManagers[i] = new MemoryCacheManager(memoryContext.newLocalMemoryContext("ignored")::trySetBytes, forceStore); + } + } + + @Managed + public int getCachedPlanSignaturesCount() + { + return Arrays.stream(cacheManagers) + .mapToInt(MemoryCacheManager::getCachedPlanSignaturesCount) + .sum(); + } + + @Managed + public int getCachedColumnIdsCount() + { + return Arrays.stream(cacheManagers) + .mapToInt(MemoryCacheManager::getCachedColumnIdsCount) + .sum(); + } + + @Managed + public int getCachedSplitsCount() + { + return Arrays.stream(cacheManagers) + .mapToInt(MemoryCacheManager::getCachedSplitsCount) + .sum(); + } + + @Override + public SplitCache getSplitCache(PlanSignature signature) + { + return new ConcurrentSplitCache(signature); + } + + @Override + public long revokeMemory(long bytesToRevoke) + { + return revokeMemory(bytesToRevoke, 10); + } + + @VisibleForTesting + long revokeMemory(long bytesToRevoke, int minElementsToRevoke) + { + // shuffle managers to prevent bias when revoking + List shuffledManagers = new ArrayList<>(Arrays.asList(cacheManagers)); + shuffle(shuffledManagers); + + // Acquire revoke lock for each 
MemoryCacheManager to prevent opening of new page sinks using storePages method + // (data might still be cached through opened page sinks). + for (MemoryCacheManager manager : cacheManagers) { + manager.startRevoke(); + } + try { + int elementsToRevoke = minElementsToRevoke; + long initialAllocatedMemory = getAllocatedMemory(); + long revokedMemory = 0; + // The loop below is racy against MemoryCacheManager#finishStoreChannels. However, + // because revoke locks are acquired, only a limited numbers of calls to finishStoreChannels + // can happen. Hence, the loop is guaranteed to finish. + while (revokedMemory < bytesToRevoke && getAllocatedMemory() > 0) { + for (MemoryCacheManager manager : shuffledManagers) { + manager.revokeMemory(bytesToRevoke, elementsToRevoke); + } + // increase the revoke batch size + elementsToRevoke = min(elementsToRevoke * 2, 10_000); + revokedMemory = initialAllocatedMemory - getAllocatedMemory(); + } + return revokedMemory; + } + finally { + for (MemoryCacheManager manager : cacheManagers) { + manager.finishRevoke(); + } + } + } + + private synchronized long getAllocatedMemory() + { + return allocatedMemory; + } + + private class ConcurrentSplitCache + implements SplitCache + { + private final PlanSignature signature; + private final int signatureHash; + private final AtomicReferenceArray splitCaches = new AtomicReferenceArray<>(CACHE_MANAGERS_COUNT); + + public ConcurrentSplitCache(PlanSignature signature) + { + this.signature = requireNonNull(signature, "signature is null"); + this.signatureHash = canonicalizePlanSignature(signature).hashCode(); + } + + @Override + public Optional loadPages(CacheSplitId splitId, TupleDomain predicate, TupleDomain unenforcedPredicate) + { + return getSplitCache(splitId).loadPages(splitId, predicate, unenforcedPredicate); + } + + @Override + public Optional storePages(CacheSplitId splitId, TupleDomain predicate, TupleDomain unenforcedPredicate) + { + return getSplitCache(splitId).storePages(splitId, 
predicate, unenforcedPredicate); + } + + @Override + public void close() + throws IOException + { + for (int i = 0; i < splitCaches.length(); i++) { + SplitCache cache = splitCaches.getAndSet(i, null); + if (cache != null) { + cache.close(); + } + } + } + + private SplitCache getSplitCache(CacheSplitId splitId) + { + int index = getCacheManagerIndex(signatureHash, splitId); + SplitCache splitCache = splitCaches.get(index); + if (splitCache != null) { + return splitCache; + } + + splitCache = cacheManagers[index].getSplitCache(signature); + if (!splitCaches.compareAndSet(index, null, splitCache)) { + // split cache instance was set concurrently + try { + splitCache.close(); + } + catch (IOException e) { + throw new UncheckedIOException(e); + } + splitCache = requireNonNull(splitCaches.get(index)); + } + + return splitCache; + } + } + + @VisibleForTesting + MemoryCacheManager getCacheManager(PlanSignature signature, CacheSplitId splitId) + { + int signatureHash = canonicalizePlanSignature(signature).hashCode(); + return cacheManagers[getCacheManagerIndex(signatureHash, splitId)]; + } + + @VisibleForTesting + MemoryCacheManager[] getCacheManagers() + { + return cacheManagers; + } + + private static int getCacheManagerIndex(int signatureHash, CacheSplitId splitId) + { + return floorMod(Objects.hash(signatureHash, splitId.hashCode()), CACHE_MANAGERS_COUNT); + } + + private class CacheMemoryReservationHandler + implements MemoryReservationHandler + { + @Override + public ListenableFuture reserveMemory(String allocationTag, long delta) + { + throw new UnsupportedOperationException(); + } + + @Override + public boolean tryReserveMemory(String allocationTag, long delta) + { + if (delta == 0) { + // empty allocation should always succeed + return true; + } + + synchronized (ConcurrentCacheManager.this) { + if (!revocableMemoryAllocator.trySetBytes(allocatedMemory + delta)) { + return false; + } + + allocatedMemory += delta; + return true; + } + } + } +} diff --git 
a/plugin/trino-memory-cache/src/main/java/io/trino/plugin/memory/MemoryCacheManager.java b/plugin/trino-memory-cache/src/main/java/io/trino/plugin/memory/MemoryCacheManager.java new file mode 100644 index 000000000000..1fd4421e9120 --- /dev/null +++ b/plugin/trino-memory-cache/src/main/java/io/trino/plugin/memory/MemoryCacheManager.java @@ -0,0 +1,817 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.memory; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.LinkedListMultimap; +import com.google.errorprone.annotations.concurrent.GuardedBy; +import io.airlift.slice.Slice; +import io.trino.spi.Page; +import io.trino.spi.block.Block; +import io.trino.spi.cache.CacheColumnId; +import io.trino.spi.cache.CacheManager; +import io.trino.spi.cache.CacheSplitId; +import io.trino.spi.cache.MemoryAllocator; +import io.trino.spi.cache.PlanSignature; +import io.trino.spi.connector.ConnectorPageSink; +import io.trino.spi.connector.ConnectorPageSource; +import io.trino.spi.predicate.TupleDomain; +import it.unimi.dsi.fastutil.longs.Long2ObjectMap; +import it.unimi.dsi.fastutil.longs.Long2ObjectOpenHashMap; + +import java.util.AbstractMap; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Iterator; +import java.util.List; +import 
java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.locks.Lock; +import java.util.concurrent.locks.ReentrantReadWriteLock; +import java.util.function.BooleanSupplier; +import java.util.function.Supplier; + +import static com.google.common.base.MoreObjects.toStringHelper; +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.collect.Lists.reverse; +import static io.airlift.slice.SizeOf.instanceSize; +import static io.airlift.slice.SizeOf.sizeOf; +import static io.trino.spi.cache.PlanSignature.canonicalizePlanSignature; +import static java.util.Comparator.comparing; +import static java.util.Objects.requireNonNull; +import static java.util.concurrent.CompletableFuture.completedFuture; + +/** + * {@link CacheManager} implementation that caches split pages in revocable memory. + *

+ * Cache structure essentially consists of multimap: + *

+ * (CanonicalPlanSignature, ColumnID, SplitID) -> [(StoreID1, ColumnBlocks1), (StoreID2, ColumnBlocks2), ...]
+ * 
+ * Therefore, cache operates at column level rather than page level. Hence, cache can serve requests + * for subset of cached columns that share same {@code StoreID}. + *

+ * Whenever pages are cached a unique {@code StoreID} is assigned and all cached columns share that ID. + * This is required, because two table scans (for different subset of columns) could produce slightly + * different blocks (e.g. due to adaptive dynamic row filtering). It also means that there might be multiple + * entries for single {@code (CanonicalPlanSignature, ColumnID, SplitID)}. When fetching pages from cache, + * all cached entries for all columns must share same {@code StoreID}. + *

+ * {@link MemoryCacheManager} does not have support for any filtering adaptation. + */ +public class MemoryCacheManager + implements CacheManager +{ + // based on SizeOf.estimatedSizeOf(java.util.Map, java.util.function.ToLongFunction, java.util.function.ToLongFunction) + static final int MAP_ENTRY_SIZE = instanceSize(AbstractMap.SimpleEntry.class); + + private static final int MAP_SIZE_LIMIT = 1_000_000_000; + static final int MAX_CACHED_CHANNELS_PER_COLUMN = 20; + + private final MemoryAllocator revocableMemoryAllocator; + private final boolean forceStore; + private final ReentrantReadWriteLock lock = new ReentrantReadWriteLock(); + private final ReentrantReadWriteLock revokeLock = new ReentrantReadWriteLock(); + @GuardedBy("lock") + private final LinkedListMultimap splitCache = LinkedListMultimap.create(); + @GuardedBy("lock") + private final ObjectToIdMap signatureToId = new ObjectToIdMap<>(PlanSignature::getRetainedSizeInBytes); + @GuardedBy("lock") + private final ObjectToIdMap columnToId = new ObjectToIdMap<>(CacheColumnId::getRetainedSizeInBytes); + @GuardedBy("lock") + private final ObjectToIdMap> predicateToId = new ObjectToIdMap<>(predicate -> predicate.getRetainedSizeInBytes(CacheColumnId::getRetainedSizeInBytes)); + private final AtomicLong nextStoreId = new AtomicLong(); + @GuardedBy("lock") + private long cacheRevocableBytes; + + public MemoryCacheManager(MemoryAllocator memoryAllocator, boolean forceStore) + { + this.revocableMemoryAllocator = requireNonNull(memoryAllocator, "memoryAllocator is null"); + this.forceStore = forceStore; + } + + @Override + public SplitCache getSplitCache(PlanSignature signature) + { + return new MemorySplitCache(allocateSignatureId(signature)); + } + + @Override + public long revokeMemory(long bytesToRevoke) + { + return revokeMemory(bytesToRevoke, Integer.MAX_VALUE); + } + + public long revokeMemory(long bytesToRevoke, int maxElementsToRevoke) + { + checkArgument(bytesToRevoke >= 0); + return 
runWithLock(lock.writeLock(), () -> { + long initialRevocableBytes = getRevocableBytes(); + return removeEldestSplits(() -> initialRevocableBytes - getRevocableBytes() >= bytesToRevoke, maxElementsToRevoke); + }); + } + + public void startRevoke() + { + revokeLock.writeLock().lock(); + } + + public void finishRevoke() + { + revokeLock.writeLock().unlock(); + } + + public long getRevocableBytes() + { + return runWithLock(lock.readLock(), () -> + cacheRevocableBytes + + signatureToId.getRevocableBytes() + + columnToId.getRevocableBytes() + + predicateToId.getRevocableBytes()); + } + + public int getCachedPlanSignaturesCount() + { + return runWithLock(lock.readLock(), signatureToId::size); + } + + public int getCachedColumnIdsCount() + { + return runWithLock(lock.readLock(), columnToId::size); + } + + public int getCachedSplitsCount() + { + return runWithLock(lock.readLock(), splitCache::size); + } + + private Optional loadPages(SignatureIds ids, CacheSplitId splitId, TupleDomain predicate, TupleDomain unenforcedPredicate) + { + checkPredicates(ids, predicate, unenforcedPredicate); + return getLoadedChannelsWithSameStoreId(ids, splitId, predicate, unenforcedPredicate) + .map(channels -> new MemoryCachePageSource(updateChannels(channels.toArray(new Channel[0])))); + } + + private Optional storePages(SignatureIds ids, CacheSplitId splitId, TupleDomain predicate, TupleDomain unenforcedPredicate) + { + checkPredicates(ids, predicate, unenforcedPredicate); + + // no column queries cannot be cached + if (ids.columnIds().length == 0) { + return Optional.empty(); + } + + if (hasChannelsWithSameStoreId(ids, splitId, predicate, unenforcedPredicate) && !forceStore) { + // split is already cached or currently being stored + return Optional.empty(); + } + + return createPageSink(ids, splitId, predicate, unenforcedPredicate, nextStoreId.getAndIncrement()); + } + + private void checkPredicates(SignatureIds ids, TupleDomain predicate, TupleDomain unenforcedPredicate) + { + 
checkArgument(ids.columnSet().containsAll(predicate.getDomains().orElse(ImmutableMap.of()).keySet()), "Predicate references missing column"); + checkArgument(ids.columnSet().containsAll(unenforcedPredicate.getDomains().orElse(ImmutableMap.of()).keySet()), "Unenforced predicate references missing column"); + } + + private Channel[] updateChannels(Channel[] channels) + { + // make channels the freshest in cache + runWithLock(lock.writeLock(), () -> { + for (Channel channel : channels) { + if (removeChannel(channel)) { + // channel might have been purged in the meantime + splitCache.put(channel.getKey(), channel); + } + } + }); + return channels; + } + + private Optional createPageSink( + SignatureIds ids, + CacheSplitId splitId, + TupleDomain predicate, + TupleDomain unenforcedPredicate, + long storeId) + { + return runWithLock(lock.writeLock(), () -> { + PredicateIds predicateIds = allocatePredicateIds(predicate, unenforcedPredicate, ids.columnIds().length); + + SplitKey[] keys = getSplitKeys(ids, predicateIds, splitId); + Channel[] channels = new Channel[keys.length]; + for (int i = 0; i < keys.length; i++) { + channels[i] = new Channel(keys[i], storeId); + } + + // increment non-revocable reference count for ids used by sink channels + signatureToId.acquireId(ids.signatureId(), keys.length); + for (int i = 0; i < keys.length; i++) { + columnToId.acquireId(keys[i].columnId()); + splitCache.put(keys[i], channels[i]); + } + + return Optional.of(new MemoryCachePageSink(ids, predicateIds, channels)); + }); + } + + private boolean hasChannelsWithSameStoreId(SignatureIds ids, CacheSplitId splitId, TupleDomain predicate, TupleDomain unenforcedPredicate) + { + return getChannelsWithSameStoreId(ids, splitId, predicate, unenforcedPredicate, false).long2ObjectEntrySet().stream() + // find store id that contain channels for all columns + .anyMatch(entry -> entry.getValue().size() == ids.columnIds().length); + } + + private Optional> getLoadedChannelsWithSameStoreId( + 
SignatureIds ids, + CacheSplitId splitId, + TupleDomain predicate, + TupleDomain unenforcedPredicate) + { + return getChannelsWithSameStoreId(ids, splitId, predicate, unenforcedPredicate, true).long2ObjectEntrySet().stream() + // filter store ids that contain channels for all columns + .filter(entry -> entry.getValue().size() == ids.columnIds().length) + // get channels with the newest store id + .sorted(comparing(entry -> -entry.getLongKey())) + .map(Map.Entry::getValue) + .findAny(); + } + + private Long2ObjectMap> getChannelsWithSameStoreId( + SignatureIds ids, + CacheSplitId splitId, + TupleDomain predicate, + TupleDomain unenforcedPredicate, + boolean onlyLoaded) + { + if (ids.columnIds().length == 0) { + return new Long2ObjectOpenHashMap<>(); + } + + Long2ObjectMap> channels = new Long2ObjectOpenHashMap<>(ids.columnIds().length); + runWithLock(lock.readLock(), () -> { + Optional predicateIds = getPredicateIds(predicate, unenforcedPredicate); + if (predicateIds.isEmpty()) { + // missing predicate ids, hence data cannot be cached + return; + } + + SplitKey[] keys = getSplitKeys(ids, predicateIds.get(), splitId); + getColumnChannels(keys[0], 0, onlyLoaded, channels); + for (int i = 1; i < keys.length; i++) { + getColumnChannels(keys[i], i, onlyLoaded, channels); + } + }); + + return channels; + } + + private SplitKey[] getSplitKeys(SignatureIds ids, PredicateIds predicateIds, CacheSplitId splitId) + { + SplitKey[] keys = new SplitKey[ids.columnIds().length]; + for (int i = 0; i < ids.columnIds().length; i++) { + long columnId = ids.columnIds()[i]; + keys[i] = new SplitKey(ids.signatureId(), columnId, splitId, predicateIds.predicateId(), predicateIds.unenforcedPredicateId()); + } + return keys; + } + + private void getColumnChannels(SplitKey key, int channelIndex, boolean onlyLoaded, Long2ObjectMap> channels) + { + // fetch MAX_CACHED_CHANNELS_PER_COLUMN latest cached channels + reverse(splitCache.get(key)).stream() + .limit(MAX_CACHED_CHANNELS_PER_COLUMN) + 
.forEach(channel -> { + if (onlyLoaded && !channel.isLoaded()) { + return; + } + + if (channelIndex == 0) { + List list = new ArrayList<>(); + list.add(channel); + checkState(channels.put(channel.getStoreId(), list) == null); + } + else { + List list = channels.get(channel.getStoreId()); + if (list != null && list.size() == channelIndex) { + list.add(channel); + } + } + }); + } + + private void finishStoreChannels(SignatureIds signatureIds, PredicateIds predicateIds, Channel[] channels) + { + runWithLock(lock.writeLock(), () -> { + checkState(signatureToId.getTotalUsageCount(signatureIds.signatureId()) >= channels.length, "Signature id must not be released while split is cached"); + long initialRevocableBytes = getRevocableBytes(); + + // make memory retained by ids revocable + signatureToId.acquireRevocableId(signatureIds.signatureId(), channels.length); + predicateToId.acquireRevocableId(predicateIds.predicateId(), channels.length); + predicateToId.acquireRevocableId(predicateIds.unenforcedPredicateId(), channels.length); + for (long columnId : signatureIds.columnIds()) { + columnToId.acquireRevocableId(columnId); + } + + // account for channels retained size + long entriesSize = 0L; + for (Channel channel : channels) { + channel.setLoaded(); + entriesSize += getCacheEntrySize(channel); + } + + long currentRevocableBytes = getRevocableBytes(); + checkState(currentRevocableBytes >= initialRevocableBytes); + if (!revocableMemoryAllocator.trySetBytes(currentRevocableBytes + entriesSize)) { + // not sufficient memory to store split pages + signatureToId.releaseRevocableId(signatureIds.signatureId(), channels.length); + predicateToId.releaseRevocableId(predicateIds.predicateId(), channels.length); + predicateToId.releaseRevocableId(predicateIds.unenforcedPredicateId(), channels.length); + for (long columnId : signatureIds.columnIds()) { + columnToId.releaseRevocableId(columnId); + } + abortStoreChannels(signatureIds, predicateIds, channels); + return; + } + + // 
dereference non-revocable ids + signatureToId.releaseId(signatureIds.signatureId(), channels.length); + predicateToId.releaseId(predicateIds.predicateId(), channels.length); + predicateToId.releaseId(predicateIds.unenforcedPredicateId(), channels.length); + for (long columnId : signatureIds.columnIds()) { + columnToId.releaseId(columnId); + } + + cacheRevocableBytes += entriesSize; + removeEldestChannels(channels); + }); + } + + private void abortStoreChannels(SignatureIds signatureIds, PredicateIds predicateIds, Channel[] channels) + { + runWithLock(lock.writeLock(), () -> { + signatureToId.releaseId(signatureIds.signatureId(), channels.length); + predicateToId.releaseId(predicateIds.predicateId(), channels.length); + predicateToId.releaseId(predicateIds.unenforcedPredicateId(), channels.length); + for (Channel channel : channels) { + checkState(removeChannel(channel), "Expected channel to be removed"); + columnToId.releaseId(channel.getKey().columnId()); + } + }); + } + + private boolean removeChannel(Channel channel) + { + boolean removed = false; + // Multimap remove(key, elem) can take significant about of time if list of elements + // for a given key is large. However, aborted channels are usually the latest elements, + // therefore we can search for a given channel by reversing the elements list. + // Ideally, we could keep pointer to a Channel entry in a LinkedListMultimap, but the API + // doesn't expose that. + for (Iterator iterator = reverse(splitCache.get(channel.getKey())).iterator(); iterator.hasNext(); ) { + if (iterator.next() == channel) { + iterator.remove(); + removed = true; + break; + } + } + return removed; + } + + /** + * Removes the eldest channels for a given split that exceed MAX_CACHED_CHANNELS_PER_COLUMN size threshold. 
+ */ + private void removeEldestChannels(Channel[] splitChannels) + { + runWithLock(lock.writeLock(), () -> { + long initialRevocableBytes = getRevocableBytes(); + for (Channel splitChannel : splitChannels) { + SplitKey key = splitChannel.getKey(); + List channels = splitCache.get(key); + int counter = channels.size() - MAX_CACHED_CHANNELS_PER_COLUMN; + for (Iterator iterator = channels.iterator(); iterator.hasNext() && counter > 0; counter--) { + Channel channel = iterator.next(); + + if (!channel.isLoaded()) { + continue; + } + + iterator.remove(); + + signatureToId.releaseRevocableId(key.signatureId()); + columnToId.releaseRevocableId(key.columnId()); + predicateToId.releaseRevocableId(key.predicateId()); + predicateToId.releaseRevocableId(key.unenforcedPredicateId()); + + cacheRevocableBytes -= getCacheEntrySize(channel); + } + } + checkState(cacheRevocableBytes >= 0); + long currentRevocableBytes = getRevocableBytes(); + checkState(initialRevocableBytes >= currentRevocableBytes); + checkState(revocableMemoryAllocator.trySetBytes(currentRevocableBytes)); + }); + } + + private long removeEldestSplits(BooleanSupplier stopCondition, int maxElementsToRevoke) + { + return runWithLock(lock.writeLock(), () -> { + if (splitCache.isEmpty()) { + // no splits to remove + return 0L; + } + + long initialRevocableBytes = getRevocableBytes(); + int elementsToRevoke = maxElementsToRevoke; + for (Iterator> iterator = splitCache.entries().iterator(); iterator.hasNext(); ) { + if (stopCondition.getAsBoolean() || elementsToRevoke <= 0) { + break; + } + + Map.Entry entry = iterator.next(); + SplitKey key = entry.getKey(); + Channel channel = entry.getValue(); + + // skip unloaded entries + if (!channel.isLoaded()) { + continue; + } + + iterator.remove(); + + signatureToId.releaseRevocableId(key.signatureId()); + columnToId.releaseRevocableId(key.columnId()); + predicateToId.releaseRevocableId(key.predicateId()); + predicateToId.releaseRevocableId(key.unenforcedPredicateId()); + + 
elementsToRevoke--; + cacheRevocableBytes -= getCacheEntrySize(channel); + } + checkState(cacheRevocableBytes >= 0); + + // freeing memory should always succeed, while any non-negative allocation might return false + long currentRevocableBytes = getRevocableBytes(); + checkState(initialRevocableBytes >= currentRevocableBytes); + checkState(revocableMemoryAllocator.trySetBytes(currentRevocableBytes)); + return initialRevocableBytes - currentRevocableBytes; + }); + } + + private SignatureIds allocateSignatureId(PlanSignature signature) + { + PlanSignature canonicalSignature = canonicalizePlanSignature(signature); + Set columnSet = ImmutableSet.copyOf(signature.getColumns()); + return runWithLock(lock.writeLock(), () -> { + // allocate non-revocable ids for signature and columns + long signatureId = signatureToId.allocateId(canonicalSignature); + long[] columnIds = new long[signature.getColumns().size()]; + for (int i = 0; i < columnIds.length; i++) { + columnIds[i] = columnToId.allocateId(signature.getColumns().get(i)); + } + return new SignatureIds(signatureId, columnSet, columnIds, signature.getColumns()); + }); + } + + private record SignatureIds(long signatureId, Set columnSet, long[] columnIds, List columns) {} + + private void releaseSignatureIds(SignatureIds ids) + { + runWithLock(lock.writeLock(), () -> { + signatureToId.releaseId(ids.signatureId()); + for (long columnId : ids.columnIds()) { + columnToId.releaseId(columnId); + } + }); + } + + private Optional getPredicateIds(TupleDomain predicate, TupleDomain unenforcedPredicate) + { + return runWithLock(lock.readLock(), () -> predicateToId.getId(predicate) + .flatMap(predicateId -> predicateToId.getId(unenforcedPredicate) + .map(unenforcedPredicateId -> new PredicateIds(predicateId, unenforcedPredicateId)))); + } + + private PredicateIds allocatePredicateIds(TupleDomain predicate, TupleDomain unenforcedPredicate, int count) + { + // allocate non-revocable ids for predicates + return 
runWithLock(lock.writeLock(), () -> new PredicateIds(predicateToId.allocateId(predicate, count), predicateToId.allocateId(unenforcedPredicate, count))); + } + + private record PredicateIds(long predicateId, long unenforcedPredicateId) {} + + private static long getCacheEntrySize(Channel channel) + { + return MAP_ENTRY_SIZE + channel.getKey().getRetainedSizeInBytes() + channel.getRetainedSizeInBytes(); + } + + private static void runWithLock(Lock lock, Runnable runnable) + { + lock.lock(); + try { + runnable.run(); + } + finally { + lock.unlock(); + } + } + + private static T runWithLock(Lock lock, Supplier supplier) + { + lock.lock(); + try { + return supplier.get(); + } + finally { + lock.unlock(); + } + } + + private class MemorySplitCache + implements SplitCache + { + private final SignatureIds ids; + private volatile boolean closed; + + private MemorySplitCache(SignatureIds ids) + { + this.ids = ids; + } + + @Override + public Optional loadPages(CacheSplitId splitId, TupleDomain predicate, TupleDomain unenforcedPredicate) + { + checkState(!closed, "MemorySplitCache already closed"); + return MemoryCacheManager.this.loadPages(ids, splitId, predicate, unenforcedPredicate); + } + + @Override + public Optional storePages(CacheSplitId splitId, TupleDomain predicate, TupleDomain unenforcedPredicate) + { + checkState(!closed, "MemorySplitCache already closed"); + if (!revokeLock.readLock().tryLock()) { + return Optional.empty(); + } + try { + return MemoryCacheManager.this.storePages(ids, splitId, predicate, unenforcedPredicate); + } + finally { + revokeLock.readLock().unlock(); + } + } + + @Override + public void close() + { + checkState(!closed, "MemorySplitCache already closed"); + closed = true; + releaseSignatureIds(ids); + // prevent cache overflow + removeEldestSplits(() -> splitCache.size() <= MAP_SIZE_LIMIT && signatureToId.size() <= MAP_SIZE_LIMIT && columnToId.size() <= MAP_SIZE_LIMIT, Integer.MAX_VALUE); + } + } + + private class MemoryCachePageSink + 
implements ConnectorPageSink + { + private final SignatureIds signatureIds; + private final PredicateIds predicateIds; + private final Channel[] channels; + private final List[] blocks; + private long memoryUsageBytes; + private boolean finished; + + public MemoryCachePageSink(SignatureIds signatureIds, PredicateIds predicateIds, Channel[] channels) + { + this.signatureIds = requireNonNull(signatureIds, "signatureIds is null"); + this.predicateIds = requireNonNull(predicateIds, "predicateIds is null"); + this.channels = requireNonNull(channels, "channels is null"); + // noinspection unchecked + this.blocks = (List[]) new List[channels.length]; + for (int i = 0; i < blocks.length; i++) { + blocks[i] = new ArrayList<>(); + } + } + + @Override + public long getMemoryUsage() + { + return memoryUsageBytes; + } + + @Override + public CompletableFuture appendPage(Page page) + { + for (int i = 0; i < channels.length; i++) { + // Compact the block + Block block = page.getBlock(i); + block = block.copyRegion(0, block.getPositionCount()); + blocks[i].add(block); + memoryUsageBytes += block.getRetainedSizeInBytes(); + } + return completedFuture(null); + } + + @Override + public CompletableFuture> finish() + { + checkState(!finished); + for (int i = 0; i < channels.length; i++) { + channels[i].setBlocks(blocks[i].toArray(new Block[0])); + } + finishStoreChannels(signatureIds, predicateIds, channels); + finished = true; + return completedFuture(ImmutableList.of()); + } + + @Override + public void abort() + { + checkState(!finished); + abortStoreChannels(signatureIds, predicateIds, channels); + finished = true; + } + } + + @VisibleForTesting + static final class SplitKey + { + static final int INSTANCE_SIZE = instanceSize(SplitKey.class); + private final long signatureId; + private final long columnId; + private final CacheSplitId splitId; + private final long predicateId; + private final long unenforcedPredicateId; + + private volatile int hashCode; + + SplitKey(long 
signatureId, long columnId, CacheSplitId splitId, long predicateId, long unenforcedPredicateId) + { + this.signatureId = signatureId; + this.columnId = columnId; + this.splitId = requireNonNull(splitId, "splitId is null"); + this.predicateId = predicateId; + this.unenforcedPredicateId = unenforcedPredicateId; + } + + public long signatureId() + { + return signatureId; + } + + public long columnId() + { + return columnId; + } + + public long predicateId() + { + return predicateId; + } + + public long unenforcedPredicateId() + { + return unenforcedPredicateId; + } + + public long getRetainedSizeInBytes() + { + return INSTANCE_SIZE + splitId.getRetainedSizeInBytes(); + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + SplitKey splitKey = (SplitKey) o; + return signatureId == splitKey.signatureId + && columnId == splitKey.columnId + && splitId.equals(splitKey.splitId) + && predicateId == splitKey.predicateId + && unenforcedPredicateId == splitKey.unenforcedPredicateId; + } + + @Override + public int hashCode() + { + if (hashCode == 0) { + hashCode = Objects.hash(signatureId, columnId, splitId, predicateId, unenforcedPredicateId); + } + return hashCode; + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("signatureId", signatureId) + .add("columnId", columnId) + .add("splitId", splitId) + .add("predicateId", predicateId) + .add("unenforcedPredicateId", unenforcedPredicateId) + .toString(); + } + } + + static class Channel + { + private static final int INSTANCE_SIZE = instanceSize(Channel.class); + + private final SplitKey key; + private final long storeId; + private volatile boolean loaded; + private volatile Block[] blocks; + private volatile long blocksRetainedSizeInBytes; + private volatile long positionCount; + + public Channel(SplitKey key, long storeId) + { + this.key = requireNonNull(key, "key is null"); + 
this.storeId = storeId; + } + + public boolean isLoaded() + { + return loaded; + } + + public void setLoaded() + { + checkState(!loaded); + loaded = true; + } + + public Block[] getBlocks() + { + checkState(loaded); + return blocks; + } + + public void setBlocks(Block[] blocks) + { + checkState(!loaded); + this.blocks = requireNonNull(blocks, "blocks is null"); + long blocksRetainedSizeInBytes = 0; + long positionCount = 0L; + for (Block block : blocks) { + blocksRetainedSizeInBytes += block.getRetainedSizeInBytes(); + positionCount += block.getPositionCount(); + } + this.blocksRetainedSizeInBytes = blocksRetainedSizeInBytes; + this.positionCount = positionCount; + } + + public long getRetainedSizeInBytes() + { + checkState(loaded); + return INSTANCE_SIZE + sizeOf(blocks) + blocksRetainedSizeInBytes; + } + + public long getBlocksRetainedSizeInBytes() + { + checkState(loaded); + return blocksRetainedSizeInBytes; + } + + public long getPositionCount() + { + checkState(loaded); + return positionCount; + } + + public SplitKey getKey() + { + return key; + } + + public long getStoreId() + { + return storeId; + } + } +} diff --git a/plugin/trino-memory-cache/src/main/java/io/trino/plugin/memory/MemoryCacheManagerFactory.java b/plugin/trino-memory-cache/src/main/java/io/trino/plugin/memory/MemoryCacheManagerFactory.java new file mode 100644 index 000000000000..4b429fdc7d04 --- /dev/null +++ b/plugin/trino-memory-cache/src/main/java/io/trino/plugin/memory/MemoryCacheManagerFactory.java @@ -0,0 +1,58 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.memory; + +import com.google.inject.Injector; +import io.airlift.bootstrap.Bootstrap; +import io.trino.plugin.base.jmx.MBeanServerModule; +import io.trino.spi.cache.CacheManager; +import io.trino.spi.cache.CacheManagerContext; +import io.trino.spi.cache.CacheManagerFactory; +import org.weakref.jmx.guice.MBeanModule; + +import java.util.Map; + +import static java.util.Objects.requireNonNull; + +public class MemoryCacheManagerFactory + implements CacheManagerFactory +{ + public static final String NAME = "memory-cache"; + + @Override + public String getName() + { + return NAME; + } + + @Override + public CacheManager create(Map config, CacheManagerContext context) + { + requireNonNull(config, "requiredConfig is null"); + + // A plugin is not required to use Guice; it is just very convenient + Bootstrap app = new Bootstrap( + new MemoryCacheModule(), + new MBeanModule(), + new MBeanServerModule(), + binder -> binder.bind(CacheManagerContext.class).toInstance(context)); + + Injector injector = app + .doNotInitializeLogging() + .setRequiredConfigurationProperties(config) + .initialize(); + + return injector.getInstance(ConcurrentCacheManager.class); + } +} diff --git a/plugin/trino-memory-cache/src/main/java/io/trino/plugin/memory/MemoryCacheModule.java b/plugin/trino-memory-cache/src/main/java/io/trino/plugin/memory/MemoryCacheModule.java new file mode 100644 index 000000000000..0c60b269b1ec --- /dev/null +++ b/plugin/trino-memory-cache/src/main/java/io/trino/plugin/memory/MemoryCacheModule.java @@ -0,0 +1,31 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.memory; + +import com.google.inject.Binder; +import com.google.inject.Module; +import com.google.inject.Scopes; + +import static org.weakref.jmx.guice.ExportBinder.newExporter; + +public class MemoryCacheModule + implements Module +{ + @Override + public void configure(Binder binder) + { + binder.bind(ConcurrentCacheManager.class).in(Scopes.SINGLETON); + newExporter(binder).export(ConcurrentCacheManager.class).withGeneratedName(); + } +} diff --git a/plugin/trino-memory-cache/src/main/java/io/trino/plugin/memory/MemoryCachePageSource.java b/plugin/trino-memory-cache/src/main/java/io/trino/plugin/memory/MemoryCachePageSource.java new file mode 100644 index 000000000000..b00606fc7995 --- /dev/null +++ b/plugin/trino-memory-cache/src/main/java/io/trino/plugin/memory/MemoryCachePageSource.java @@ -0,0 +1,102 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.plugin.memory; + +import io.trino.plugin.memory.MemoryCacheManager.Channel; +import io.trino.spi.Page; +import io.trino.spi.block.Block; +import io.trino.spi.connector.ConnectorPageSource; + +import static com.google.common.base.Preconditions.checkArgument; +import static java.util.Objects.requireNonNull; + +public class MemoryCachePageSource + implements ConnectorPageSource +{ + private final long memoryUsageBytes; + private final Block[][] channels; + private final long positionCount; + private int currentBlock; + private long currentPosition; + private long completedBytes; + private boolean closed; + + public MemoryCachePageSource(Channel[] channels) + { + checkArgument(channels.length > 0); + requireNonNull(channels); + this.positionCount = channels[0].getPositionCount(); + this.channels = new Block[channels.length][]; + long memoryUsageBytes = 0L; + long storeId = channels[0].getStoreId(); + for (int i = 0; i < channels.length; i++) { + Channel channel = channels[i]; + memoryUsageBytes += channel.getBlocksRetainedSizeInBytes(); + checkArgument(positionCount == channel.getPositionCount(), "Position count (%s) doesn't match channel position count (%s)", positionCount, channel.getPositionCount()); + checkArgument(storeId == channel.getStoreId(), "Store ids don't match"); + this.channels[i] = channel.getBlocks(); + } + this.memoryUsageBytes = memoryUsageBytes; + } + + @Override + public void close() + { + closed = true; + } + + @Override + public long getCompletedBytes() + { + return completedBytes; + } + + @Override + public long getReadTimeNanos() + { + return 0; + } + + @Override + public boolean isFinished() + { + return closed || currentPosition >= positionCount; + } + + @Override + public Page getNextPage() + { + if (isFinished()) { + return null; + } + + // extract current blocks + Block[] blocks = new Block[channels.length]; + for (int channel = 0; channel < channels.length; channel++) { + blocks[channel] = 
channels[channel][currentBlock]; + completedBytes += blocks[channel].getSizeInBytes(); + } + + // extract next page position count + currentBlock++; + currentPosition += blocks[0].getPositionCount(); + return new Page(blocks); + } + + @Override + public long getMemoryUsage() + { + return memoryUsageBytes; + } +} diff --git a/plugin/trino-memory-cache/src/main/java/io/trino/plugin/memory/MemoryCachePlugin.java b/plugin/trino-memory-cache/src/main/java/io/trino/plugin/memory/MemoryCachePlugin.java new file mode 100644 index 000000000000..e1e33fa64c31 --- /dev/null +++ b/plugin/trino-memory-cache/src/main/java/io/trino/plugin/memory/MemoryCachePlugin.java @@ -0,0 +1,28 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.plugin.memory; + +import com.google.common.collect.ImmutableList; +import io.trino.spi.Plugin; +import io.trino.spi.cache.CacheManagerFactory; + +public final class MemoryCachePlugin + implements Plugin +{ + @Override + public Iterable getCacheManagerFactories() + { + return ImmutableList.of(new MemoryCacheManagerFactory()); + } +} diff --git a/plugin/trino-memory-cache/src/main/java/io/trino/plugin/memory/ObjectToIdMap.java b/plugin/trino-memory-cache/src/main/java/io/trino/plugin/memory/ObjectToIdMap.java new file mode 100644 index 000000000000..5735685a2f24 --- /dev/null +++ b/plugin/trino-memory-cache/src/main/java/io/trino/plugin/memory/ObjectToIdMap.java @@ -0,0 +1,202 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.plugin.memory; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.BiMap; +import com.google.common.collect.HashBiMap; +import io.trino.spi.cache.PlanSignature; +import it.unimi.dsi.fastutil.longs.Long2LongMap; +import it.unimi.dsi.fastutil.longs.Long2LongOpenHashMap; + +import java.util.Optional; +import java.util.function.Function; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkState; +import static io.airlift.slice.SizeOf.LONG_INSTANCE_SIZE; +import static io.trino.plugin.memory.MemoryCacheManager.MAP_ENTRY_SIZE; +import static java.util.Objects.requireNonNull; + +/** + * Maps objects to numeric id. Comparing of big objects like {@link PlanSignature} can be expensive. + * Therefore, it's more efficient to map objects to numerical ids and use them for comparison instead. + */ +public class ObjectToIdMap +{ + private final Function retainedSizeInBytesProvider; + private final BiMap objectToId = HashBiMap.create(); + /** + * Usage count per id. When non-revocable and revocable usage count + * for particular id drops to 0, then corresponding mapping from + * {@link ObjectToIdMap#objectToId} can be dropped. 
+ */ + private final Long2LongMap idUsageCount = new Long2LongOpenHashMap(); + private final Long2LongMap idRevocableUsageCount = new Long2LongOpenHashMap(); + private long revocableBytes; + private long nextId; + + public ObjectToIdMap(Function retainedSizeInBytesProvider) + { + this.retainedSizeInBytesProvider = requireNonNull(retainedSizeInBytesProvider, "retainedSizeInBytesProvider is null"); + } + + public Optional getId(T object) + { + return Optional.ofNullable(objectToId.get(object)); + } + + public long allocateId(T object) + { + return allocateId(object, 1L); + } + + public long allocateId(T object, long delta) + { + return allocateId(object, delta, 0L); + } + + public long allocateRevocableId(T object) + { + return allocateRevocableId(object, 1L); + } + + public long allocateRevocableId(T object, long delta) + { + return allocateId(object, 0L, delta); + } + + private long allocateId(T object, long delta, long revocableDelta) + { + checkArgument(delta >= 0, "delta is negative"); + checkArgument(revocableDelta >= 0, "revocableDelta is negative"); + Long id = objectToId.get(object); + if (id == null) { + id = nextId++; + objectToId.put(object, id); + idUsageCount.put((long) id, delta); + idRevocableUsageCount.put((long) id, revocableDelta); + if (revocableDelta > 0) { + revocableBytes += getEntrySize(object); + } + return id; + } + + acquireId(id, delta, revocableDelta); + return id; + } + + public void acquireId(long id) + { + acquireId(id, 1L); + } + + public void acquireId(long id, long delta) + { + acquireId(id, delta, 0L); + } + + public void acquireRevocableId(long id) + { + acquireRevocableId(id, 1L); + } + + public void acquireRevocableId(long id, long delta) + { + acquireId(id, 0L, delta); + } + + private void acquireId(long id, long delta, long revocableDelta) + { + checkArgument(delta >= 0, "delta is negative"); + checkArgument(revocableDelta >= 0, "revocableDelta is negative"); + checkArgument(objectToId.inverse().containsKey(id), "Trying to 
acquire missing id"); + + idUsageCount.mergeLong(id, delta, Long::sum); + if (revocableDelta > 0 && idRevocableUsageCount.mergeLong(id, revocableDelta, Long::sum) == revocableDelta) { + revocableBytes += getEntrySize(objectToId.inverse().get(id)); + } + } + + public void releaseId(long id) + { + releaseId(id, 1L); + } + + public void releaseId(long id, long delta) + { + releaseId(id, delta, 0L); + } + + public void releaseRevocableId(long id) + { + releaseRevocableId(id, 1L); + } + + public void releaseRevocableId(long id, long delta) + { + releaseId(id, 0L, delta); + } + + private void releaseId(long id, long delta, long revocableDelta) + { + checkArgument(delta >= 0, "delta is negative"); + checkArgument(revocableDelta >= 0, "revocableDelta is negative"); + + long usageCount = idUsageCount.mergeLong(id, -delta, Long::sum); + checkState(usageCount >= 0, "Usage count is negative"); + + long revocableUsageCount = idRevocableUsageCount.mergeLong(id, -revocableDelta, Long::sum); + checkState(revocableUsageCount >= 0, "Revocable usage count is negative"); + if (revocableDelta > 0 && revocableUsageCount == 0) { + revocableBytes -= getEntrySize(objectToId.inverse().get(id)); + } + + if (usageCount == 0 && revocableUsageCount == 0) { + requireNonNull(objectToId.inverse().remove(id)); + idUsageCount.remove(id); + idRevocableUsageCount.remove(id); + } + } + + public long getTotalUsageCount(long id) + { + return idUsageCount.getOrDefault(id, 0L) + idRevocableUsageCount.getOrDefault(id, 0L); + } + + public int size() + { + return objectToId.size(); + } + + public long getRevocableBytes() + { + return revocableBytes; + } + + private long getEntrySize(T object) + { + return getEntrySize(object, retainedSizeInBytesProvider); + } + + @VisibleForTesting + static long getEntrySize(T object, Function retainedSizeInBytesProvider) + { + requireNonNull(object, "object is null"); + // account for objectToId + return MAP_ENTRY_SIZE + retainedSizeInBytesProvider.apply(object) + 
LONG_INSTANCE_SIZE + + // account for idUsageCount + MAP_ENTRY_SIZE + 2L * LONG_INSTANCE_SIZE; + } +} diff --git a/plugin/trino-memory-cache/src/test/java/io/trino/plugin/memory/BenchmarkMemoryCacheManager.java b/plugin/trino-memory-cache/src/test/java/io/trino/plugin/memory/BenchmarkMemoryCacheManager.java new file mode 100644 index 000000000000..1b0266718137 --- /dev/null +++ b/plugin/trino-memory-cache/src/test/java/io/trino/plugin/memory/BenchmarkMemoryCacheManager.java @@ -0,0 +1,236 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.plugin.memory; + +import io.trino.spi.Page; +import io.trino.spi.block.Block; +import io.trino.spi.block.BlockEncodingSerde; +import io.trino.spi.block.IntArrayBlock; +import io.trino.spi.block.TestingBlockEncodingSerde; +import io.trino.spi.cache.CacheColumnId; +import io.trino.spi.cache.CacheManager; +import io.trino.spi.cache.CacheManagerContext; +import io.trino.spi.cache.CacheSplitId; +import io.trino.spi.cache.MemoryAllocator; +import io.trino.spi.cache.PlanSignature; +import io.trino.spi.cache.SignatureKey; +import io.trino.spi.connector.ConnectorPageSink; +import io.trino.spi.connector.ConnectorPageSource; +import io.trino.spi.predicate.TupleDomain; +import io.trino.spi.type.Type; +import org.junit.jupiter.api.Test; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.Threads; +import org.openjdk.jmh.annotations.Warmup; +import org.openjdk.jmh.runner.RunnerException; + +import java.io.IOException; +import java.util.List; +import java.util.Optional; +import java.util.concurrent.atomic.AtomicLong; +import java.util.stream.IntStream; + +import static com.google.common.collect.ImmutableList.toImmutableList; +import static io.trino.jmh.Benchmarks.benchmark; +import static io.trino.spi.type.IntegerType.INTEGER; +import static java.util.Collections.nCopies; +import static org.assertj.core.api.Assertions.assertThat; +import static org.openjdk.jmh.annotations.Mode.Throughput; + +@State(Scope.Benchmark) +@Fork(2) +@Warmup(iterations = 6, time = 2) +@Measurement(iterations = 6, time = 2) +@BenchmarkMode(Throughput) +public class BenchmarkMemoryCacheManager +{ + @State(Scope.Benchmark) + public static class 
Context + { + @Param({"false", "true"}) + private boolean polluteCache; + @Param({"false", "true"}) + private boolean changeSignatures; + + private final MemoryCacheManager memoryCacheManager = new MemoryCacheManager(bytes -> bytes <= 4_000_000_000L, true); + private final ConcurrentCacheManager concurrentCacheManager = new ConcurrentCacheManager( + new CacheManagerContext() + { + @Override + public MemoryAllocator revocableMemoryAllocator() + { + return bytes -> bytes <= 4_000_000_000L; + } + + @Override + public BlockEncodingSerde blockEncodingSerde() + { + return new TestingBlockEncodingSerde(); + } + }, + true); + private final CacheSplitId splitId = new CacheSplitId("split"); + private final List columnIds = IntStream.range(0, 64) + .mapToObj(i -> new CacheColumnId("column" + i)) + .collect(toImmutableList()); + + private final List columnTypes = columnIds.stream() + .map(column -> INTEGER) + .collect(toImmutableList()); + private final PlanSignature[] signatures = IntStream.range(0, 200) + .mapToObj(i -> new PlanSignature( + new SignatureKey("key" + i), + Optional.empty(), + columnIds, + columnTypes)) + .toArray(PlanSignature[]::new); + private final AtomicLong nextSignature = new AtomicLong(); + + private final Page page = new Page(nCopies( + columnIds.size(), + new IntArrayBlock(4, Optional.empty(), new int[] {0, 1, 2, 3})) + .toArray(new Block[0])); + + @Setup + public void setup() + throws IOException + { + if (polluteCache) { + for (int i = 0; i < 100; i++) { + storeCachedData(memoryCacheManager); + storeCachedData(concurrentCacheManager); + } + } + storeCachedData(memoryCacheManager); + storeCachedData(concurrentCacheManager); + } + + public MemoryCacheManager memoryCacheManager() + { + return memoryCacheManager; + } + + public ConcurrentCacheManager concurrentCacheManager() + { + return concurrentCacheManager; + } + + public Optional loadCachedData(CacheManager cacheManager) + throws IOException + { + try (CacheManager.SplitCache splitCache = 
cacheManager.getSplitCache(getSignature())) { + return splitCache.loadPages(splitId, TupleDomain.all(), TupleDomain.all()); + } + } + + public void storeCachedData(CacheManager cacheManager) + throws IOException + { + try (CacheManager.SplitCache splitCache = cacheManager.getSplitCache(getSignature())) { + ConnectorPageSink sink = splitCache.storePages(splitId, TupleDomain.all(), TupleDomain.all()).orElseThrow(); + sink.appendPage(page); + sink.finish(); + } + } + + private PlanSignature getSignature() + { + if (changeSignatures) { + return signatures[(int) (nextSignature.getAndIncrement() % signatures.length)]; + } + return signatures[0]; + } + } + + @Threads(10) + @Benchmark + public Optional benchmarkMemoryLoadPages(Context context) + throws IOException + { + return context.loadCachedData(context.memoryCacheManager()); + } + + @Threads(10) + @Benchmark + public void benchmarkMemoryStorePages(Context context) + throws IOException + { + context.storeCachedData(context.memoryCacheManager()); + } + + @Threads(10) + @Benchmark + public Optional benchmarkConcurrentLoadPages(Context context) + throws IOException + { + return context.loadCachedData(context.concurrentCacheManager()); + } + + @Threads(10) + @Benchmark + public void benchmarkConcurrentStorePages(Context context) + throws IOException + { + context.storeCachedData(context.concurrentCacheManager()); + } + + @Test + public void testBenchmarkMemoryLoadPages() + throws IOException + { + Context context = new Context(); + context.setup(); + assertThat(benchmarkMemoryLoadPages(context)).isPresent(); + } + + @Test + public void testBenchmarkMemoryStorePages() + throws IOException + { + Context context = new Context(); + context.setup(); + benchmarkMemoryStorePages(context); + } + + @Test + public void testBenchmarkConcurrentLoadPages() + throws IOException + { + Context context = new Context(); + context.setup(); + assertThat(benchmarkConcurrentLoadPages(context)).isPresent(); + } + + @Test + public void 
testBenchmarkConcurrentStorePages() + throws IOException + { + Context context = new Context(); + context.setup(); + benchmarkConcurrentStorePages(context); + } + + public static void main(String[] args) + throws RunnerException + { + benchmark(BenchmarkMemoryCacheManager.class).run(); + } +} diff --git a/plugin/trino-memory-cache/src/test/java/io/trino/plugin/memory/TestConcurrentCacheManager.java b/plugin/trino-memory-cache/src/test/java/io/trino/plugin/memory/TestConcurrentCacheManager.java new file mode 100644 index 000000000000..bdd88d15e8b0 --- /dev/null +++ b/plugin/trino-memory-cache/src/test/java/io/trino/plugin/memory/TestConcurrentCacheManager.java @@ -0,0 +1,152 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.plugin.memory; + +import io.trino.spi.Page; +import io.trino.spi.block.BlockEncodingSerde; +import io.trino.spi.block.TestingBlockEncodingSerde; +import io.trino.spi.cache.CacheManager.SplitCache; +import io.trino.spi.cache.CacheManagerContext; +import io.trino.spi.cache.CacheSplitId; +import io.trino.spi.cache.MemoryAllocator; +import io.trino.spi.cache.PlanSignature; +import io.trino.spi.connector.ConnectorPageSink; +import io.trino.spi.predicate.TupleDomain; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.parallel.Execution; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Optional; + +import static com.google.common.base.Preconditions.checkArgument; +import static io.trino.plugin.memory.TestMemoryCacheManager.createOneMegaBytePage; +import static io.trino.plugin.memory.TestMemoryCacheManager.createPlanSignature; +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.parallel.ExecutionMode.SAME_THREAD; + +@Execution(SAME_THREAD) +public class TestConcurrentCacheManager +{ + private static final PlanSignature SIGNATURE_1 = createPlanSignature("sig1"); + private static final PlanSignature SIGNATURE_2 = createPlanSignature("sig2"); + private static final CacheSplitId SPLIT1 = new CacheSplitId("1"); + private static final CacheSplitId SPLIT2 = new CacheSplitId("122"); + private static final CacheSplitId SPLIT3 = new CacheSplitId("2"); + private static final CacheSplitId SPLIT4 = new CacheSplitId("123"); + + private Page oneMegabytePage; + private ConcurrentCacheManager cacheManager; + private long allocatedRevocableMemory; + + @BeforeEach + public void setup() + { + oneMegabytePage = createOneMegaBytePage(); + allocatedRevocableMemory = 0; + CacheManagerContext context = new CacheManagerContext() + { + @Override + public MemoryAllocator revocableMemoryAllocator() + { + return bytes -> { + checkArgument(bytes >= 0); + 
allocatedRevocableMemory = bytes;
+                    return true;
+                };
+            }
+
+            @Override
+            public BlockEncodingSerde blockEncodingSerde()
+            {
+                return new TestingBlockEncodingSerde();
+            }
+        };
+        cacheManager = new ConcurrentCacheManager(context);
+    }
+
+    @Test
+    public void testConcurrentManagerRevoke()
+            throws IOException
+    {
+        // make sure splits are cached in two MemoryCacheManagers
+        assertThat(cacheManager.getCacheManager(SIGNATURE_1, SPLIT1)).isEqualTo(cacheManager.getCacheManager(SIGNATURE_1, SPLIT2));
+        assertThat(cacheManager.getCacheManager(SIGNATURE_2, SPLIT3)).isEqualTo(cacheManager.getCacheManager(SIGNATURE_2, SPLIT4));
+        assertThat(cacheManager.getCacheManager(SIGNATURE_1, SPLIT1)).isNotEqualTo(cacheManager.getCacheManager(SIGNATURE_2, SPLIT3));
+        assertThat(allocatedRevocableMemory).isEqualTo(0);
+
+        // cache some splits
+        storePage(SIGNATURE_1, SPLIT1);
+        storePage(SIGNATURE_1, SPLIT2);
+        storePage(SIGNATURE_2, SPLIT1);
+        storePage(SIGNATURE_2, SPLIT2);
+        assertThatMemoryMatches();
+
+        // revoke ~1.5MBs, the oldest splits should be purged from both sub-managers
+        long initialAllocatedMemory = allocatedRevocableMemory;
+        long revokedMemory = cacheManager.revokeMemory(1_500_000, 1);
+
+        assertThat(revokedMemory).isPositive();
+        assertThat(initialAllocatedMemory - revokedMemory).isEqualTo(allocatedRevocableMemory);
+
+        assertSplitIsCached(SIGNATURE_1, SPLIT2);
+        assertSplitIsCached(SIGNATURE_2, SPLIT2);
+
+        assertSplitIsNotCached(SIGNATURE_1, SPLIT1);
+        assertSplitIsNotCached(SIGNATURE_2, SPLIT1);
+
+        // revoke everything
+        assertThat(cacheManager.revokeMemory(10_000_000)).isPositive();
+        assertThat(allocatedRevocableMemory).isZero();
+
+        // nothing to revoke
+        assertThat(cacheManager.revokeMemory(10_000_000)).isZero();
+    }
+
+    private void assertThatMemoryMatches()
+    {
+        long totalManagersMemory = Arrays.stream(cacheManager.getCacheManagers())
+                .mapToLong(MemoryCacheManager::getRevocableBytes)
+                .sum();
assertThat(allocatedRevocableMemory).isEqualTo(totalManagersMemory); + } + + private void assertSplitIsNotCached(PlanSignature signature, CacheSplitId splitId) + throws IOException + { + try (SplitCache cache = cacheManager.getSplitCache(signature)) { + assertThat(cache.loadPages(splitId, TupleDomain.all(), TupleDomain.all())).isEmpty(); + } + } + + private void assertSplitIsCached(PlanSignature signature, CacheSplitId splitId) + throws IOException + { + try (SplitCache cache = cacheManager.getSplitCache(signature)) { + assertThat(cache.loadPages(splitId, TupleDomain.all(), TupleDomain.all())).isPresent(); + } + } + + private void storePage(PlanSignature signature, CacheSplitId splitId) + throws IOException + { + try (SplitCache cache = cacheManager.getSplitCache(signature)) { + Optional sinkOptional = cache.storePages(splitId, TupleDomain.all(), TupleDomain.all()); + assertThat(sinkOptional).isPresent(); + sinkOptional.get().appendPage(oneMegabytePage); + sinkOptional.get().finish(); + } + } +} diff --git a/plugin/trino-memory-cache/src/test/java/io/trino/plugin/memory/TestMemoryCacheManager.java b/plugin/trino-memory-cache/src/test/java/io/trino/plugin/memory/TestMemoryCacheManager.java new file mode 100644 index 000000000000..5e3d290e3058 --- /dev/null +++ b/plugin/trino-memory-cache/src/test/java/io/trino/plugin/memory/TestMemoryCacheManager.java @@ -0,0 +1,502 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.plugin.memory; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.trino.plugin.memory.MemoryCacheManager.Channel; +import io.trino.plugin.memory.MemoryCacheManager.SplitKey; +import io.trino.spi.Page; +import io.trino.spi.block.Block; +import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.IntArrayBlock; +import io.trino.spi.cache.CacheColumnId; +import io.trino.spi.cache.CacheManager.SplitCache; +import io.trino.spi.cache.CacheSplitId; +import io.trino.spi.cache.PlanSignature; +import io.trino.spi.cache.SignatureKey; +import io.trino.spi.connector.ConnectorPageSink; +import io.trino.spi.connector.ConnectorPageSource; +import io.trino.spi.predicate.Domain; +import io.trino.spi.predicate.TupleDomain; +import io.trino.spi.type.Type; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.parallel.Execution; + +import java.io.IOException; +import java.util.List; +import java.util.Optional; +import java.util.stream.IntStream; +import java.util.stream.Stream; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static io.trino.plugin.memory.MemoryCacheManager.MAP_ENTRY_SIZE; +import static io.trino.plugin.memory.MemoryCacheManager.MAX_CACHED_CHANNELS_PER_COLUMN; +import static io.trino.plugin.memory.TestUtils.assertBlockEquals; +import static io.trino.spi.cache.PlanSignature.canonicalizePlanSignature; +import static io.trino.spi.type.BigintType.BIGINT; +import static io.trino.spi.type.IntegerType.INTEGER; +import static java.util.Collections.nCopies; +import static java.util.stream.Collectors.joining; +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.parallel.ExecutionMode.SAME_THREAD; + +@Execution(SAME_THREAD) +public class TestMemoryCacheManager +{ + private static final CacheColumnId 
COLUMN1 = new CacheColumnId("col1"); + private static final CacheColumnId COLUMN2 = new CacheColumnId("col2"); + private static final CacheColumnId COLUMN3 = new CacheColumnId("col3"); + private static final CacheSplitId SPLIT1 = new CacheSplitId("split1"); + private static final CacheSplitId SPLIT2 = new CacheSplitId("split2"); + + private Page oneMegabytePage; + private MemoryCacheManager cacheManager; + private long allocatedRevocableMemory; + private long memoryLimit; + + @BeforeEach + public void setup() + { + oneMegabytePage = createOneMegaBytePage(); + allocatedRevocableMemory = 0; + memoryLimit = Long.MAX_VALUE; + cacheManager = new MemoryCacheManager( + bytes -> { + checkArgument(bytes >= 0); + if (bytes > memoryLimit) { + return false; + } + allocatedRevocableMemory = bytes; + return true; + }, + false); + } + + @Test + public void testCachePages() + throws IOException + { + PlanSignature signature = createPlanSignature("sig"); + + // split data should not be cached yet + SplitCache cache = cacheManager.getSplitCache(signature); + assertThat(cache.loadPages(SPLIT1, TupleDomain.all(), TupleDomain.all())).isEmpty(); + long idSize = ObjectToIdMap.getEntrySize(canonicalizePlanSignature(signature), PlanSignature::getRetainedSizeInBytes) + + ObjectToIdMap.getEntrySize(COLUMN1, CacheColumnId::getRetainedSizeInBytes); + long tupleDomainIdSize = ObjectToIdMap.getEntrySize(TupleDomain.all(), tupleDomain -> tupleDomain.getRetainedSizeInBytes(CacheColumnId::getRetainedSizeInBytes)); + // SplitCache doesn't allocate any revocable memory + assertThat(allocatedRevocableMemory).isEqualTo(0L); + + Optional sinkOptional = cache.storePages(SPLIT1, TupleDomain.all(), TupleDomain.all()); + assertThat(sinkOptional).isPresent(); + // active sink doesn't allocate any revocable memory + assertThat(allocatedRevocableMemory).isEqualTo(0L); + + // second sink should not be present as split data is already being cached + assertThat(cache.storePages(SPLIT1, TupleDomain.all(), 
TupleDomain.all())).isEmpty(); + assertThat(cache.loadPages(SPLIT1, TupleDomain.all(), TupleDomain.all())).isEmpty(); + + Block block = oneMegabytePage.getBlock(0); + ConnectorPageSink sink = sinkOptional.get(); + sink.appendPage(oneMegabytePage); + + // make sure memory usage is accounted for page sink + assertThat(sink.getMemoryUsage()).isEqualTo(block.getRetainedSizeInBytes()); + assertThat(allocatedRevocableMemory).isEqualTo(0L); + + // make sure memory is transferred to cacheManager after sink is finished + sink.finish(); + long channelSize = getChannelRetainedSizeInBytes(block); + long cacheEntrySize = MAP_ENTRY_SIZE + SplitKey.INSTANCE_SIZE + SPLIT1.getRetainedSizeInBytes() + channelSize; + assertThat(allocatedRevocableMemory).isEqualTo(cacheEntrySize + idSize + tupleDomainIdSize); + + // split data should be available now + Optional sourceOptional = cache.loadPages(SPLIT1, TupleDomain.all(), TupleDomain.all()); + assertThat(sourceOptional).isPresent(); + + // ensure cached pages are correct + ConnectorPageSource source = sourceOptional.get(); + assertThat(source.getMemoryUsage()).isEqualTo(block.getRetainedSizeInBytes()); + assertBlockEquals(source.getNextPage().getBlock(0), block); + assertThat(source.isFinished()).isTrue(); + + // make sure no data is available for other signatures + PlanSignature anotherSignature = createPlanSignature("sig2"); + SplitCache anotherCache = cacheManager.getSplitCache(anotherSignature); + assertThat(anotherCache.loadPages(SPLIT1, TupleDomain.all(), TupleDomain.all())).isEmpty(); + assertThat(allocatedRevocableMemory).isEqualTo(cacheEntrySize + idSize + tupleDomainIdSize); + anotherCache.close(); + + // store data for another split + sink = cache.storePages(SPLIT2, TupleDomain.all(), TupleDomain.all()).orElseThrow(); + sink.appendPage(oneMegabytePage); + sink.finish(); + assertThat(allocatedRevocableMemory).isEqualTo(2 * cacheEntrySize + idSize + tupleDomainIdSize); + + // data for both splits should be cached + 
assertThat(cache.loadPages(SPLIT1, TupleDomain.all(), TupleDomain.all())).isPresent();
+        assertThat(cache.loadPages(SPLIT2, TupleDomain.all(), TupleDomain.all())).isPresent();
+
+        // revoke memory and make sure only the most recently used split is left
+        cacheManager.revokeMemory(500_000);
+        assertThat(allocatedRevocableMemory).isEqualTo(cacheEntrySize + idSize + tupleDomainIdSize);
+        assertThat(cache.loadPages(SPLIT1, TupleDomain.all(), TupleDomain.all())).isEmpty();
+        assertThat(cache.loadPages(SPLIT2, TupleDomain.all(), TupleDomain.all())).isPresent();
+
+        // make sure no new split data is cached when memory limit is lowered
+        memoryLimit = 1_500_000;
+        sink = cache.storePages(SPLIT1, TupleDomain.all(), TupleDomain.all()).orElseThrow();
+        sink.appendPage(oneMegabytePage);
+        sink.finish();
+        assertThat(allocatedRevocableMemory).isEqualTo(cacheEntrySize + idSize + tupleDomainIdSize);
+        assertThat(cache.loadPages(SPLIT1, TupleDomain.all(), TupleDomain.all())).isEmpty();
+
+        cache.close();
+    }
+
+    @Test
+    public void testPredicate()
+            throws IOException
+    {
+        PlanSignature signature = createPlanSignature("sig", COLUMN1, COLUMN2);
+        SplitCache cache = cacheManager.getSplitCache(signature);
+
+        // append a split with predicate
+        Domain domain = Domain.singleValue(INTEGER, 42L);
+        TupleDomain predicate = TupleDomain.withColumnDomains(ImmutableMap.of(COLUMN1, domain));
+        ConnectorPageSink sink = cache.storePages(
+                SPLIT1,
+                predicate,
+                TupleDomain.all()).orElseThrow();
+        Block col = new IntArrayBlock(1, Optional.empty(), new int[] {42});
+        sink.appendPage(new Page(col, col));
+        sink.finish();
+
+        long idSize = ObjectToIdMap.getEntrySize(canonicalizePlanSignature(signature), PlanSignature::getRetainedSizeInBytes)
+                + ObjectToIdMap.getEntrySize(COLUMN1, CacheColumnId::getRetainedSizeInBytes)
+                + ObjectToIdMap.getEntrySize(COLUMN2, CacheColumnId::getRetainedSizeInBytes)
+                + ObjectToIdMap.getEntrySize(predicate, tupleDomain -> 
tupleDomain.getRetainedSizeInBytes(CacheColumnId::getRetainedSizeInBytes))
+                + ObjectToIdMap.getEntrySize(TupleDomain.all(), tupleDomain -> tupleDomain.getRetainedSizeInBytes(CacheColumnId::getRetainedSizeInBytes));
+        long cacheEntrySize = MAP_ENTRY_SIZE + SplitKey.INSTANCE_SIZE + SPLIT1.getRetainedSizeInBytes() + getChannelRetainedSizeInBytes(col);
+        assertThat(allocatedRevocableMemory).isEqualTo(idSize + cacheEntrySize * 2);
+
+        // entire tuple domain must match
+        assertThat(cache.loadPages(SPLIT1, TupleDomain.all(), TupleDomain.all())).isEmpty();
+        assertThat(cache.loadPages(SPLIT1, TupleDomain.withColumnDomains(ImmutableMap.of(
+                COLUMN1, domain,
+                COLUMN2, Domain.singleValue(INTEGER, 43L))),
+                TupleDomain.all())).isEmpty();
+        assertThat(cache.loadPages(SPLIT1, predicate, predicate)).isEmpty();
+        assertThat(cache.loadPages(SPLIT1, TupleDomain.all(), predicate)).isEmpty();
+
+        assertThat(cache.loadPages(SPLIT1, predicate, TupleDomain.all())).isPresent();
+
+        // revoking should remove tuple domain ids
+        cache.close();
+        cacheManager.revokeMemory(1_000_000);
+        assertThat(allocatedRevocableMemory).isEqualTo(0L);
+    }
+
+    @Test
+    public void testColumnCaching()
+            throws IOException
+    {
+        // split data should not be cached yet
+        SplitCache cacheCol12 = cacheManager.getSplitCache(createPlanSignature("sig", COLUMN1, COLUMN2));
+        assertThat(cacheManager.getCachedPlanSignaturesCount()).isEqualTo(1);
+        assertThat(cacheManager.getCachedColumnIdsCount()).isEqualTo(2);
+        assertThat(cacheManager.getCachedSplitsCount()).isEqualTo(0);
+        assertThat(cacheCol12.loadPages(SPLIT1, TupleDomain.all(), TupleDomain.all())).isEmpty();
+
+        Optional sinkOptional = cacheCol12.storePages(SPLIT1, TupleDomain.all(), TupleDomain.all());
+        assertThat(sinkOptional).isPresent();
+        ConnectorPageSink sink = sinkOptional.get();
+
+        // create another cache with reverse column order
+        SplitCache cacheCol21 = cacheManager.getSplitCache(createPlanSignature("sig", COLUMN2, COLUMN1));
assertThat(cacheManager.getCachedPlanSignaturesCount()).isEqualTo(1); + assertThat(cacheManager.getCachedColumnIdsCount()).isEqualTo(2); + + // split data with reverse column order should be available after sink is finished + Block col1BlockStore1 = new IntArrayBlock(2, Optional.empty(), new int[] {0, 1}); + Block col2BlockStore1 = new IntArrayBlock(2, Optional.empty(), new int[] {10, 11}); + sink.appendPage(new Page(col1BlockStore1, col2BlockStore1)); + sink.finish(); + assertThat(cacheManager.getCachedSplitsCount()).isEqualTo(2); + assertPageSourceEquals(cacheCol21.loadPages(SPLIT1, TupleDomain.all(), TupleDomain.all()), col2BlockStore1, col1BlockStore1); + + // subset of columns should also be cached + SplitCache cacheCol2 = cacheManager.getSplitCache(createPlanSignature("sig", COLUMN2)); + assertPageSourceEquals(cacheCol2.loadPages(SPLIT1, TupleDomain.all(), TupleDomain.all()), col2BlockStore1); + + // data for column1 and column3 should be cached together with separate store id + SplitCache cacheCol13 = cacheManager.getSplitCache(createPlanSignature("sig", COLUMN1, COLUMN3)); + assertThat(cacheManager.getCachedPlanSignaturesCount()).isEqualTo(1); + assertThat(cacheManager.getCachedColumnIdsCount()).isEqualTo(3); + assertThat(cacheCol13.loadPages(SPLIT1, TupleDomain.all(), TupleDomain.all())).isEmpty(); + + Block col1BlockStore2 = new IntArrayBlock(2, Optional.empty(), new int[] {20, 21}); + Block col3BlockStore2 = new IntArrayBlock(2, Optional.empty(), new int[] {30, 31}); + sinkOptional = cacheCol13.storePages(SPLIT1, TupleDomain.all(), TupleDomain.all()); + assertThat(sinkOptional).isPresent(); + sink = sinkOptional.get(); + sink.appendPage(new Page(col1BlockStore2, col3BlockStore2)); + sink.finish(); + assertThat(cacheManager.getCachedSplitsCount()).isEqualTo(4); + + // (col1, col2) page source should still use "store no 1" blocks + assertPageSourceEquals(cacheCol12.loadPages(SPLIT1, TupleDomain.all(), TupleDomain.all()), col1BlockStore1, col2BlockStore1); 
+ + // (col1, col3) page source should use "store no 2" blocks + assertPageSourceEquals(cacheCol13.loadPages(SPLIT1, TupleDomain.all(), TupleDomain.all()), col1BlockStore2, col3BlockStore2); + + // cache should return the newest entries + SplitCache cacheCol123 = cacheManager.getSplitCache(createPlanSignature("sig", COLUMN1, COLUMN2, COLUMN3)); + Block col1BlockStore3 = new IntArrayBlock(2, Optional.empty(), new int[] {50, 51}); + Block col2BlockStore3 = new IntArrayBlock(2, Optional.empty(), new int[] {60, 61}); + Block col3BlockStore3 = new IntArrayBlock(2, Optional.empty(), new int[] {70, 71}); + sinkOptional = cacheCol123.storePages(SPLIT1, TupleDomain.all(), TupleDomain.all()); + assertThat(sinkOptional).isPresent(); + sink = sinkOptional.get(); + sink.appendPage(new Page(col1BlockStore3, col2BlockStore3, col3BlockStore3)); + sink.finish(); + assertThat(cacheManager.getCachedSplitsCount()).isEqualTo(7); + assertPageSourceEquals(cacheCol13.loadPages(SPLIT1, TupleDomain.all(), TupleDomain.all()), col1BlockStore3, col3BlockStore3); + + // make sure group by columns do not use non-aggregated cached column data + SplitCache groupByCacheCol1 = cacheManager.getSplitCache(new PlanSignature( + new SignatureKey("sig"), + Optional.of(ImmutableList.of(COLUMN1)), + ImmutableList.of(COLUMN1), + ImmutableList.of(INTEGER))); + assertThat(groupByCacheCol1.loadPages(SPLIT1, TupleDomain.all(), TupleDomain.all())).isEmpty(); + + // make sure all ids are removed after revoke + cacheCol12.close(); + cacheCol123.close(); + cacheCol2.close(); + cacheCol13.close(); + cacheCol21.close(); + groupByCacheCol1.close(); + cacheManager.revokeMemory(1_000_000); + assertThat(cacheManager.getCachedPlanSignaturesCount()).isEqualTo(0); + assertThat(cacheManager.getCachedColumnIdsCount()).isEqualTo(0); + assertThat(cacheManager.getCachedSplitsCount()).isEqualTo(0); + assertThat(cacheManager.getRevocableBytes()).isEqualTo(0); + } + + @Test + public void testLruCache() + { + SplitCache cacheA = 
cacheManager.getSplitCache(createPlanSignature("sigA", COLUMN1)); + SplitCache cacheB = cacheManager.getSplitCache(createPlanSignature("sigB", COLUMN1)); + + // cache two pages to different sinks + Optional sinkOptional = cacheA.storePages(SPLIT1, TupleDomain.all(), TupleDomain.all()); + assertThat(sinkOptional).isPresent(); + sinkOptional.get().appendPage(oneMegabytePage); + sinkOptional.get().finish(); + + sinkOptional = cacheB.storePages(SPLIT1, TupleDomain.all(), TupleDomain.all()); + assertThat(sinkOptional).isPresent(); + sinkOptional.get().appendPage(oneMegabytePage); + sinkOptional.get().finish(); + + // both pages should be cached + assertThat(cacheB.loadPages(SPLIT1, TupleDomain.all(), TupleDomain.all())).isPresent(); + assertThat(cacheA.loadPages(SPLIT1, TupleDomain.all(), TupleDomain.all())).isPresent(); + + // only latest used page should be cached after revoke + cacheManager.revokeMemory(500_000); + assertThat(cacheA.loadPages(SPLIT1, TupleDomain.all(), TupleDomain.all())).isPresent(); + assertThat(cacheB.loadPages(SPLIT1, TupleDomain.all(), TupleDomain.all())).isEmpty(); + } + + @Test + public void testMaxChannelsPerColumn() + throws IOException + { + // store MAX_CACHED_CHANNELS_PER_COLUMN column ids for col1 + for (int i = 1; i <= MAX_CACHED_CHANNELS_PER_COLUMN; i++) { + List columns = IntStream.range(0, i + 1) + .mapToObj(col -> new CacheColumnId("col" + col)) + .collect(toImmutableList()); + List columnsTypes = columns + .stream().map(col -> INTEGER) + .collect(toImmutableList()); + PlanSignature signature = new PlanSignature( + new SignatureKey("sig"), + Optional.empty(), + columns, + columnsTypes); + try (SplitCache cache = cacheManager.getSplitCache(signature)) { + Optional sinkOptional = cache.storePages(SPLIT1, TupleDomain.all(), TupleDomain.all()); + assertThat(sinkOptional).isPresent(); + ConnectorPageSink sink = sinkOptional.get(); + sink.appendPage(new Page(nCopies( + columns.size(), + new IntArrayBlock(1, Optional.empty(), new int[] 
{i})) + .toArray(new Block[0]))); + sink.finish(); + } + } + + int splitCount = MAX_CACHED_CHANNELS_PER_COLUMN * (MAX_CACHED_CHANNELS_PER_COLUMN + 1) / 2 + MAX_CACHED_CHANNELS_PER_COLUMN; + assertThat(cacheManager.getCachedPlanSignaturesCount()).isEqualTo(1); + assertThat(cacheManager.getCachedColumnIdsCount()).isEqualTo(MAX_CACHED_CHANNELS_PER_COLUMN + 1); + assertThat(cacheManager.getCachedSplitsCount()).isEqualTo(splitCount); + + // add another channel for col1 + List columns = ImmutableList.of(new CacheColumnId("col1"), new CacheColumnId("col100")); + List columnsTypes = columns + .stream().map(col -> INTEGER) + .collect(toImmutableList()); + PlanSignature signature = new PlanSignature( + new SignatureKey("sig"), + Optional.empty(), + columns, + columnsTypes); + SplitCache cache = cacheManager.getSplitCache(signature); + Optional sinkOptional = cache.storePages(SPLIT1, TupleDomain.all(), TupleDomain.all()); + assertThat(sinkOptional).isPresent(); + ConnectorPageSink sink = sinkOptional.get(); + Block block = new IntArrayBlock(1, Optional.empty(), new int[] {0}); + sink.appendPage(new Page(nCopies(columns.size(), block).toArray(new Block[0]))); + sink.finish(); + + // oldest column from col1 should be purged + assertThat(cacheManager.getCachedPlanSignaturesCount()).isEqualTo(1); + assertThat(cacheManager.getCachedColumnIdsCount()).isEqualTo(MAX_CACHED_CHANNELS_PER_COLUMN + 2); + assertThat(cacheManager.getCachedSplitsCount()).isEqualTo(splitCount + 1); + assertPageSourceEquals(cache.loadPages(SPLIT1, TupleDomain.all(), TupleDomain.all()), block, block); + } + + private void assertPageSourceEquals(Optional sourceOptional, Block... 
expectedBlocks) + { + assertThat(sourceOptional).isPresent(); + ConnectorPageSource source = sourceOptional.get(); + Page actualPage = source.getNextPage(); + assertThat(source.isFinished()).isTrue(); + assertThat(actualPage.getChannelCount()).isEqualTo(expectedBlocks.length); + for (int i = 0; i < actualPage.getChannelCount(); i++) { + assertBlockEquals(actualPage.getBlock(i), expectedBlocks[i]); + } + } + + @Test + public void testSinkAbort() + throws IOException + { + PlanSignature signature = createPlanSignature("sig"); + + // create new SplitCache + SplitCache cache = cacheManager.getSplitCache(signature); + + // start caching of new split + Optional sinkOptional = cache.storePages(SPLIT1, TupleDomain.all(), TupleDomain.all()); + assertThat(sinkOptional).isPresent(); + ConnectorPageSink sink = sinkOptional.get(); + sink.appendPage(oneMegabytePage); + assertThat(sink.getMemoryUsage()).isEqualTo(oneMegabytePage.getBlock(0).getRetainedSizeInBytes()); + assertThat(cache.loadPages(SPLIT1, TupleDomain.all(), TupleDomain.all())).isEmpty(); + + // active sink shouldn't allocate any revocable memory + cache.close(); + assertThat(allocatedRevocableMemory).isEqualTo(0L); + + // no data should be cached after abort + sink.abort(); + assertThat(allocatedRevocableMemory).isEqualTo(0L); + assertThat(cacheManager.getSplitCache(signature).loadPages(SPLIT1, TupleDomain.all(), TupleDomain.all())).isEmpty(); + } + + @Test + public void testPlanSignatureRevoke() + throws IOException + { + Page smallPage = new Page(new IntArrayBlock(1, Optional.empty(), new int[] {0})); + PlanSignature bigSignature = createPlanSignature(IntStream.range(0, 500_000).mapToObj(Integer::toString).collect(joining())); + PlanSignature secondBigSignature = createPlanSignature(IntStream.range(0, 500_001).mapToObj(Integer::toString).collect(joining())); + + // cache some data for first signature + assertThat(allocatedRevocableMemory).isEqualTo(0); + SplitCache cache = 
cacheManager.getSplitCache(bigSignature); + ConnectorPageSink sink = cache.storePages(SPLIT1, TupleDomain.all(), TupleDomain.all()).orElseThrow(); + sink.appendPage(smallPage); + sink.finish(); + cache.close(); + + // make sure page is present with new SplitCache instance + SplitCache anotherCache = cacheManager.getSplitCache(bigSignature); + assertThat(anotherCache.loadPages(SPLIT1, TupleDomain.all(), TupleDomain.all())).isPresent(); + + // cache data for another signature + SplitCache cacheForSecondSignature = cacheManager.getSplitCache(secondBigSignature); + sink = cacheForSecondSignature.storePages(SPLIT1, TupleDomain.all(), TupleDomain.all()).orElseThrow(); + sink.appendPage(smallPage); + sink.finish(); + + // both splits should be still cached + assertThat(anotherCache.loadPages(SPLIT1, TupleDomain.all(), TupleDomain.all())).isPresent(); + assertThat(cacheForSecondSignature.loadPages(SPLIT1, TupleDomain.all(), TupleDomain.all())).isPresent(); + anotherCache.close(); + cacheForSecondSignature.close(); + + // revoke some small amount of memory + assertThat(cacheManager.revokeMemory(100)).isPositive(); + + // only one split (for secondBigSignature signature) should be cached, because initial bigSignature was purged + anotherCache = cacheManager.getSplitCache(bigSignature); + cacheForSecondSignature = cacheManager.getSplitCache(secondBigSignature); + assertThat(anotherCache.loadPages(SPLIT1, TupleDomain.all(), TupleDomain.all())).isEmpty(); + assertThat(cacheForSecondSignature.loadPages(SPLIT1, TupleDomain.all(), TupleDomain.all())).isPresent(); + anotherCache.close(); + } + + static long getChannelRetainedSizeInBytes(Block block) + { + Channel channel = new Channel(new SplitKey(0, 0, new CacheSplitId("id"), 0, 0), 0); + channel.setBlocks(new Block[] {block}); + channel.setLoaded(); + return channel.getRetainedSizeInBytes(); + } + + static PlanSignature createPlanSignature(String signature) + { + return createPlanSignature(signature, COLUMN1); + } + + private 
static PlanSignature createPlanSignature(String signature, CacheColumnId... ids) + { + return new PlanSignature( + new SignatureKey(signature), + Optional.empty(), + ImmutableList.copyOf(ids), + Stream.of(ids).map(ignore -> (Type) INTEGER).collect(toImmutableList())); + } + + static Page createOneMegaBytePage() + { + BlockBuilder blockBuilder = BIGINT.createFixedSizeBlockBuilder(0); + while (blockBuilder.getRetainedSizeInBytes() < 1024 * 1024) { + BIGINT.writeLong(blockBuilder, 42L); + } + Page page = new Page(blockBuilder.getPositionCount(), blockBuilder.build()); + page.compact(); + return page; + } +} diff --git a/plugin/trino-memory-cache/src/test/java/io/trino/plugin/memory/TestMemoryCachePageSource.java b/plugin/trino-memory-cache/src/test/java/io/trino/plugin/memory/TestMemoryCachePageSource.java new file mode 100644 index 000000000000..240481f5225d --- /dev/null +++ b/plugin/trino-memory-cache/src/test/java/io/trino/plugin/memory/TestMemoryCachePageSource.java @@ -0,0 +1,70 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.plugin.memory; + +import io.trino.plugin.memory.MemoryCacheManager.Channel; +import io.trino.spi.Page; +import io.trino.spi.block.Block; +import io.trino.spi.block.IntArrayBlock; +import io.trino.spi.block.LongArrayBlock; +import io.trino.spi.cache.CacheSplitId; +import org.junit.jupiter.api.Test; + +import java.util.Optional; + +import static io.trino.plugin.memory.TestUtils.assertBlockEquals; +import static org.assertj.core.api.Assertions.assertThat; + +public class TestMemoryCachePageSource +{ + @Test + public void testPageSource() + { + Channel firstChannel = createChannel( + new IntArrayBlock(4, Optional.empty(), new int[] {0, 1, 2, 3}), + new IntArrayBlock(2, Optional.empty(), new int[] {4, 5})); + Channel secondChannel = createChannel( + new LongArrayBlock(4, Optional.empty(), new long[] {10L, 11L, 12L, 13L}), + new LongArrayBlock(2, Optional.empty(), new long[] {14L, 15L})); + + MemoryCachePageSource pageSource = new MemoryCachePageSource(new Channel[] {firstChannel, secondChannel}); + assertThat(pageSource.isFinished()).isFalse(); + assertThat(pageSource.getMemoryUsage()).isEqualTo(firstChannel.getBlocksRetainedSizeInBytes() + secondChannel.getBlocksRetainedSizeInBytes()); + assertThat(pageSource.getCompletedBytes()).isEqualTo(0L); + + Page page = pageSource.getNextPage(); + assertThat(page.getChannelCount()).isEqualTo(2); + assertThat(page.getPositionCount()).isEqualTo(4); + assertBlockEquals(page.getBlock(0), firstChannel.getBlocks()[0]); + assertBlockEquals(page.getBlock(1), secondChannel.getBlocks()[0]); + assertThat(pageSource.getCompletedBytes()).isEqualTo(56); + assertThat(pageSource.isFinished()).isFalse(); + + page = pageSource.getNextPage(); + assertThat(page.getChannelCount()).isEqualTo(2); + assertThat(page.getPositionCount()).isEqualTo(2); + assertBlockEquals(page.getBlock(0), firstChannel.getBlocks()[1]); + assertBlockEquals(page.getBlock(1), secondChannel.getBlocks()[1]); + 
assertThat(pageSource.getCompletedBytes()).isEqualTo(84); + assertThat(pageSource.isFinished()).isTrue(); + } + + private static Channel createChannel(Block... blocks) + { + Channel channel = new Channel(new MemoryCacheManager.SplitKey(0, 0, new CacheSplitId("id"), 0, 0), 0L); + channel.setBlocks(blocks); + channel.setLoaded(); + return channel; + } +} diff --git a/plugin/trino-memory-cache/src/test/java/io/trino/plugin/memory/TestMemoryCachePlugin.java b/plugin/trino-memory-cache/src/test/java/io/trino/plugin/memory/TestMemoryCachePlugin.java new file mode 100644 index 000000000000..ba9abb46bf8c --- /dev/null +++ b/plugin/trino-memory-cache/src/test/java/io/trino/plugin/memory/TestMemoryCachePlugin.java @@ -0,0 +1,51 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.plugin.memory; + +import com.google.common.collect.ImmutableMap; +import io.trino.spi.Plugin; +import io.trino.spi.block.BlockEncodingSerde; +import io.trino.spi.block.TestingBlockEncodingSerde; +import io.trino.spi.cache.CacheManagerContext; +import io.trino.spi.cache.CacheManagerFactory; +import io.trino.spi.cache.MemoryAllocator; +import org.junit.jupiter.api.Test; + +import static com.google.common.collect.Iterables.getOnlyElement; + +public class TestMemoryCachePlugin +{ + @Test + public void testCreateCacheManager() + { + Plugin plugin = new MemoryCachePlugin(); + CacheManagerFactory factory = getOnlyElement(plugin.getCacheManagerFactories()); + factory.create( + ImmutableMap.of(), + new CacheManagerContext() + { + @Override + public MemoryAllocator revocableMemoryAllocator() + { + return null; + } + + @Override + public BlockEncodingSerde blockEncodingSerde() + { + return new TestingBlockEncodingSerde(); + } + }); + } +} diff --git a/plugin/trino-memory-cache/src/test/java/io/trino/plugin/memory/TestObjectToIdMap.java b/plugin/trino-memory-cache/src/test/java/io/trino/plugin/memory/TestObjectToIdMap.java new file mode 100644 index 000000000000..2783001350f7 --- /dev/null +++ b/plugin/trino-memory-cache/src/test/java/io/trino/plugin/memory/TestObjectToIdMap.java @@ -0,0 +1,78 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.plugin.memory; + +import org.junit.jupiter.api.Test; + +import static io.airlift.slice.SizeOf.LONG_INSTANCE_SIZE; +import static io.trino.plugin.memory.MemoryCacheManager.MAP_ENTRY_SIZE; +import static org.assertj.core.api.Assertions.assertThat; + +public class TestObjectToIdMap +{ + @Test + public void testObjectToIdMap() + { + ObjectToIdMap idMap = new ObjectToIdMap<>(string -> (long) string.length()); + + assertThat(idMap.getRevocableBytes()).isEqualTo(0L); + assertThat(idMap.getTotalUsageCount(42L)).isEqualTo(0L); + + long cacheEntrySize = 2L * MAP_ENTRY_SIZE + 3L * LONG_INSTANCE_SIZE + "A".length(); + long idA = idMap.allocateRevocableId("A"); + assertThat(idA).isEqualTo(0L); + assertThat(idMap.getTotalUsageCount(idA)).isEqualTo(1L); + assertThat(idMap.getRevocableBytes()).isEqualTo(cacheEntrySize); + + idMap.acquireRevocableId(idA); + assertThat(idMap.getTotalUsageCount(idA)).isEqualTo(2L); + + long idB = idMap.allocateRevocableId("B"); + assertThat(idB).isEqualTo(1L); + assertThat(idMap.getTotalUsageCount(idB)).isEqualTo(1L); + assertThat(idMap.getRevocableBytes()).isEqualTo(2 * cacheEntrySize); + + idMap.releaseRevocableId(idB); + assertThat(idMap.getTotalUsageCount(idB)).isEqualTo(0L); + assertThat(idMap.getRevocableBytes()).isEqualTo(cacheEntrySize); + } + + @Test + public void testRevocableAllocations() + { + ObjectToIdMap idMap = new ObjectToIdMap<>(string -> (long) string.length()); + + // non-revocable allocation + long idA = idMap.allocateId("A"); + assertThat(idA).isEqualTo(0L); + assertThat(idMap.getTotalUsageCount(idA)).isEqualTo(1L); + assertThat(idMap.getRevocableBytes()).isEqualTo(0); + + // revocable allocation + long cacheEntrySize = 2L * MAP_ENTRY_SIZE + 3L * LONG_INSTANCE_SIZE + "A".length(); + assertThat(idMap.allocateRevocableId("A")).isEqualTo(idA); + assertThat(idMap.getTotalUsageCount(idA)).isEqualTo(2L); + assertThat(idMap.getRevocableBytes()).isEqualTo(cacheEntrySize); + + // revocable free + 
idMap.releaseRevocableId(idA); + assertThat(idMap.getTotalUsageCount(idA)).isEqualTo(1L); + assertThat(idMap.getRevocableBytes()).isEqualTo(0); + + // non-revocable free + idMap.releaseId(idA); + assertThat(idMap.getTotalUsageCount(idA)).isEqualTo(0L); + assertThat(idMap.getRevocableBytes()).isEqualTo(0); + } +} diff --git a/plugin/trino-memory-cache/src/test/java/io/trino/plugin/memory/TestUtils.java b/plugin/trino-memory-cache/src/test/java/io/trino/plugin/memory/TestUtils.java new file mode 100644 index 000000000000..79fae2d7df4b --- /dev/null +++ b/plugin/trino-memory-cache/src/test/java/io/trino/plugin/memory/TestUtils.java @@ -0,0 +1,44 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.plugin.memory; + +import io.trino.spi.block.Block; +import io.trino.spi.block.IntArrayBlock; +import io.trino.spi.block.LongArrayBlock; + +import static io.trino.spi.type.BigintType.BIGINT; +import static io.trino.spi.type.IntegerType.INTEGER; +import static org.assertj.core.api.Assertions.assertThat; + +public final class TestUtils +{ + private TestUtils() {} + + public static void assertBlockEquals(Block actual, Block expected) + { + assertThat(actual.getPositionCount()).isEqualTo(expected.getPositionCount()); + if (expected instanceof IntArrayBlock) { + assertThat(actual).isInstanceOf(IntArrayBlock.class); + for (int position = 0; position < actual.getPositionCount(); position++) { + assertThat(INTEGER.getInt(actual, position)).isEqualTo(INTEGER.getInt(expected, position)); + } + } + else { + assertThat(actual).isInstanceOf(LongArrayBlock.class); + for (int position = 0; position < actual.getPositionCount(); position++) { + assertThat(BIGINT.getLong(actual, position)).isEqualTo(BIGINT.getLong(expected, position)); + } + } + } +} diff --git a/plugin/trino-tpch/src/main/java/io/trino/plugin/tpch/TpchCacheMetadata.java b/plugin/trino-tpch/src/main/java/io/trino/plugin/tpch/TpchCacheMetadata.java new file mode 100644 index 000000000000..e1cd4f56f15a --- /dev/null +++ b/plugin/trino-tpch/src/main/java/io/trino/plugin/tpch/TpchCacheMetadata.java @@ -0,0 +1,64 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.plugin.tpch; + +import io.trino.spi.cache.CacheColumnId; +import io.trino.spi.cache.CacheTableId; +import io.trino.spi.cache.ConnectorCacheMetadata; +import io.trino.spi.connector.ColumnHandle; +import io.trino.spi.connector.ConnectorTableHandle; + +import java.util.Optional; + +public class TpchCacheMetadata + implements ConnectorCacheMetadata +{ + @Override + public Optional getCacheTableId(ConnectorTableHandle table) + { + TpchTableHandle handle = (TpchTableHandle) table; + if (!handle.constraint().isAll()) { + // lossless conversion of TupleDomain to string requires JSON serialization + return Optional.empty(); + } + + // ensure cache id generation is revisited whenever handle classes change + handle = new TpchTableHandle( + handle.schemaName(), + handle.tableName(), + handle.scaleFactor(), + handle.constraint()); + + return Optional.of(new CacheTableId(handle.schemaName() + ":" + handle.tableName() + ":" + handle.scaleFactor())); + } + + @Override + public Optional getCacheColumnId(ConnectorTableHandle tableHandle, ColumnHandle column) + { + TpchColumnHandle handle = (TpchColumnHandle) column; + + // ensure cache id generation is revisited whenever handle classes change + handle = new TpchColumnHandle( + handle.columnName(), + handle.type()); + + return Optional.of(new CacheColumnId(handle.columnName() + ":" + handle.type())); + } + + @Override + public ConnectorTableHandle getCanonicalTableHandle(ConnectorTableHandle handle) + { + return handle; + } +} diff --git a/plugin/trino-tpch/src/main/java/io/trino/plugin/tpch/TpchConnectorFactory.java b/plugin/trino-tpch/src/main/java/io/trino/plugin/tpch/TpchConnectorFactory.java index ec06eecfaefc..26305f70ffa9 100644 --- a/plugin/trino-tpch/src/main/java/io/trino/plugin/tpch/TpchConnectorFactory.java +++ b/plugin/trino-tpch/src/main/java/io/trino/plugin/tpch/TpchConnectorFactory.java @@ -14,6 +14,7 @@ package io.trino.plugin.tpch; import io.trino.spi.NodeManager; +import 
io.trino.spi.cache.ConnectorCacheMetadata; import io.trino.spi.connector.Connector; import io.trino.spi.connector.ConnectorContext; import io.trino.spi.connector.ConnectorFactory; @@ -110,6 +111,12 @@ public ConnectorSplitManager getSplitManager() return new TpchSplitManager(nodeManager, splitsPerNode); } + @Override + public ConnectorCacheMetadata getCacheMetadata() + { + return new TpchCacheMetadata(); + } + @Override public ConnectorPageSourceProvider getPageSourceProvider() { diff --git a/plugin/trino-tpch/src/main/java/io/trino/plugin/tpch/TpchPageSourceProvider.java b/plugin/trino-tpch/src/main/java/io/trino/plugin/tpch/TpchPageSourceProvider.java index a5df782d2d45..f02079dcbfc1 100644 --- a/plugin/trino-tpch/src/main/java/io/trino/plugin/tpch/TpchPageSourceProvider.java +++ b/plugin/trino-tpch/src/main/java/io/trino/plugin/tpch/TpchPageSourceProvider.java @@ -21,6 +21,7 @@ import io.trino.spi.connector.ConnectorTableHandle; import io.trino.spi.connector.ConnectorTransactionHandle; import io.trino.spi.connector.DynamicFilter; +import io.trino.spi.predicate.TupleDomain; import java.util.List; @@ -47,4 +48,26 @@ public ConnectorPageSource createPageSource( { return new LazyRecordPageSource(maxRowsPerPage, tpchRecordSetProvider.getRecordSet(transaction, session, split, table, columns)); } + + @Override + public TupleDomain getUnenforcedPredicate( + ConnectorSession session, + ConnectorSplit split, + ConnectorTableHandle tableHandle, + TupleDomain dynamicFilter) + { + // tpch connector doesn't support unenforced (effective) predicates + return TupleDomain.all(); + } + + @Override + public TupleDomain prunePredicate( + ConnectorSession session, + ConnectorSplit split, + ConnectorTableHandle tableHandle, + TupleDomain predicate) + { + // tpch connector doesn't support pruning of predicates + return predicate; + } } diff --git a/plugin/trino-tpch/src/main/java/io/trino/plugin/tpch/TpchSplitManager.java 
b/plugin/trino-tpch/src/main/java/io/trino/plugin/tpch/TpchSplitManager.java index 146ad44c389b..c2445de5b187 100644 --- a/plugin/trino-tpch/src/main/java/io/trino/plugin/tpch/TpchSplitManager.java +++ b/plugin/trino-tpch/src/main/java/io/trino/plugin/tpch/TpchSplitManager.java @@ -16,6 +16,7 @@ import com.google.common.collect.ImmutableList; import io.trino.spi.Node; import io.trino.spi.NodeManager; +import io.trino.spi.cache.CacheSplitId; import io.trino.spi.connector.ConnectorSession; import io.trino.spi.connector.ConnectorSplit; import io.trino.spi.connector.ConnectorSplitManager; @@ -27,6 +28,7 @@ import io.trino.spi.connector.FixedSplitSource; import java.util.List; +import java.util.Optional; import java.util.Set; import static com.google.common.base.Preconditions.checkArgument; @@ -74,4 +76,19 @@ public ConnectorSplitSource getSplits( } return new FixedSplitSource(splits.build()); } + + @Override + public Optional getCacheSplitId(ConnectorSplit split) + { + TpchSplit tpchSplit = (TpchSplit) split; + + // this ensures that when split changes someone will look if id has to updated + tpchSplit = new TpchSplit( + tpchSplit.getPartNumber(), + tpchSplit.getTotalParts(), + // ignore host addresses as it's irrelevant for ID + ImmutableList.of()); + + return Optional.of(new CacheSplitId(tpchSplit.getTotalParts() + ":" + tpchSplit.getPartNumber())); + } } diff --git a/pom.xml b/pom.xml index 90f1aba36cc0..7375862afc67 100644 --- a/pom.xml +++ b/pom.xml @@ -88,6 +88,7 @@ plugin/trino-kudu plugin/trino-mariadb plugin/trino-memory + plugin/trino-memory-cache plugin/trino-ml plugin/trino-mongodb plugin/trino-mysql @@ -1264,6 +1265,12 @@ test-jar + + io.trino + trino-memory-cache + ${project.version} + + io.trino trino-memory-context diff --git a/testing/trino-testing/pom.xml b/testing/trino-testing/pom.xml index c6c2490746d0..940c282f89c3 100644 --- a/testing/trino-testing/pom.xml +++ b/testing/trino-testing/pom.xml @@ -226,6 +226,11 @@ junit-jupiter-api + + 
org.junit.jupiter + junit-jupiter-params + + org.jetbrains annotations diff --git a/testing/trino-testing/src/main/java/io/trino/testing/BaseCacheSubqueriesTest.java b/testing/trino-testing/src/main/java/io/trino/testing/BaseCacheSubqueriesTest.java new file mode 100644 index 000000000000..2d462996e814 --- /dev/null +++ b/testing/trino-testing/src/main/java/io/trino/testing/BaseCacheSubqueriesTest.java @@ -0,0 +1,954 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.testing; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Streams; +import io.airlift.slice.Slices; +import io.opentelemetry.api.trace.Span; +import io.trino.Session; +import io.trino.cache.CacheDataOperator; +import io.trino.cache.CacheMetadata; +import io.trino.cache.CommonPlanAdaptation.PlanSignatureWithPredicate; +import io.trino.cache.LoadCachedDataOperator; +import io.trino.metadata.Metadata; +import io.trino.metadata.QualifiedObjectName; +import io.trino.metadata.Split; +import io.trino.metadata.TableHandle; +import io.trino.operator.OperatorStats; +import io.trino.operator.ScanFilterAndProjectOperator; +import io.trino.operator.TableScanOperator; +import io.trino.server.testing.TestingTrinoServer; +import io.trino.spi.QueryId; +import io.trino.spi.cache.CacheColumnId; +import io.trino.spi.cache.CacheTableId; +import 
io.trino.spi.cache.PlanSignature; +import io.trino.spi.connector.CatalogHandle; +import io.trino.spi.connector.ColumnHandle; +import io.trino.spi.connector.ConnectorPageSourceProvider; +import io.trino.spi.connector.ConnectorSession; +import io.trino.spi.connector.ConnectorSplit; +import io.trino.spi.connector.ConnectorTableHandle; +import io.trino.spi.connector.ConnectorTransactionHandle; +import io.trino.spi.connector.Constraint; +import io.trino.spi.connector.ConstraintApplicationResult; +import io.trino.spi.connector.DynamicFilter; +import io.trino.spi.predicate.Domain; +import io.trino.spi.predicate.Range; +import io.trino.spi.predicate.TupleDomain; +import io.trino.spi.predicate.ValueSet; +import io.trino.spi.type.VarcharType; +import io.trino.split.PageSourceManager.PageSourceProviderInstance; +import io.trino.split.PageSourceProvider; +import io.trino.split.SplitSource; +import io.trino.sql.planner.Plan; +import io.trino.sql.planner.assertions.PlanAssert; +import io.trino.sql.planner.assertions.PlanMatchPattern; +import io.trino.sql.planner.plan.ExchangeNode; +import io.trino.sql.planner.plan.LoadCachedDataPlanNode; +import io.trino.testing.QueryRunner.MaterializedResultWithPlan; +import io.trino.tpch.TpchTable; +import org.intellij.lang.annotations.Language; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.CompletableFuture; +import java.util.function.Function; +import java.util.stream.LongStream; +import java.util.stream.Stream; + +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static io.airlift.concurrent.MoreFutures.getFutureValue; +import static 
io.trino.SystemSessionProperties.CACHE_AGGREGATIONS_ENABLED; +import static io.trino.SystemSessionProperties.CACHE_COMMON_SUBQUERIES_ENABLED; +import static io.trino.SystemSessionProperties.CACHE_PROJECTIONS_ENABLED; +import static io.trino.SystemSessionProperties.ENABLE_DYNAMIC_ROW_FILTERING; +import static io.trino.SystemSessionProperties.ENABLE_LARGE_DYNAMIC_FILTERS; +import static io.trino.SystemSessionProperties.JOIN_DISTRIBUTION_TYPE; +import static io.trino.SystemSessionProperties.JOIN_REORDERING_STRATEGY; +import static io.trino.cache.CacheDriverFactory.getDynamicRowFilteringUnenforcedPredicate; +import static io.trino.cache.CommonSubqueriesExtractor.scanFilterProjectKey; +import static io.trino.cost.StatsCalculator.noopStatsCalculator; +import static io.trino.metadata.FunctionManager.createTestingFunctionManager; +import static io.trino.spi.connector.Constraint.alwaysTrue; +import static io.trino.spi.predicate.Range.range; +import static io.trino.spi.type.BigintType.BIGINT; +import static io.trino.sql.planner.OptimizerConfig.JoinDistributionType.BROADCAST; +import static io.trino.sql.planner.OptimizerConfig.JoinReorderingStrategy.NONE; +import static io.trino.sql.planner.assertions.PlanMatchPattern.anyTree; +import static io.trino.sql.planner.assertions.PlanMatchPattern.cacheDataPlanNode; +import static io.trino.sql.planner.assertions.PlanMatchPattern.chooseAlternativeNode; +import static io.trino.sql.planner.assertions.PlanMatchPattern.node; +import static io.trino.sql.planner.assertions.PlanMatchPattern.tableScan; +import static io.trino.sql.planner.optimizations.PlanNodeSearcher.searchFrom; +import static io.trino.sql.planner.plan.ExchangeNode.Scope.REMOTE; +import static io.trino.sql.planner.plan.ExchangeNode.Type.REPLICATE; +import static io.trino.testing.QueryAssertions.assertEqualsIgnoreOrder; +import static io.trino.testing.TransactionBuilder.transaction; +import static io.trino.tpch.TpchTable.CUSTOMER; +import static 
io.trino.tpch.TpchTable.LINE_ITEM; +import static io.trino.tpch.TpchTable.NATION; +import static io.trino.tpch.TpchTable.ORDERS; +import static java.util.concurrent.CompletableFuture.completedFuture; +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assumptions.abort; + +public abstract class BaseCacheSubqueriesTest + extends AbstractTestQueryFramework +{ + protected static final Set> REQUIRED_TABLES = ImmutableSet.of(NATION, LINE_ITEM, ORDERS, CUSTOMER); + protected static final Map EXTRA_PROPERTIES = ImmutableMap.of("cache.enabled", "true"); + + @BeforeEach + public void flushCache() + { + getDistributedQueryRunner().getServers().forEach(server -> server.getCacheManagerRegistry().flushCache()); + } + + public static Object[][] isDynamicRowFilteringEnabled() + { + return new Object[][] {{true}, {false}}; + } + + @Test + public void testShowStats() + { + assertThat(query("SHOW STATS FOR nation")) + .result() + // Not testing average length and min/max, as this would make the test less reusable and is not that important to test. 
+ .exceptColumns("data_size", "low_value", "high_value") + .skippingTypesCheck() + .matches("VALUES " + + "('nationkey', 25e0, 0e0, null)," + + "('name', 25e0, 0e0, null)," + + "('regionkey', 5e0, 0e0, null)," + + "('comment', 25e0, 0e0, null)," + + "(null, null, null, 25e0)"); + } + + @Test + public void testUnionWithJoinQuery() + { + @Language("SQL") String selectQuery = """ + select c.custkey from ( + select custkey, nationkey from (select c.custkey, c.nationkey from customer c, nation n where c.nationkey = n.nationkey) + union all + select custkey, nationkey from (select c.custkey, c.nationkey from customer c, nation n where c.nationkey = n.nationkey)) c + join nation n on c.nationkey = n.nationkey + """; + MaterializedResultWithPlan resultWithCache = executeWithPlan(withBroadcastJoin(withCacheEnabled()), selectQuery); + MaterializedResultWithPlan resultWithoutCache = executeWithPlan(withBroadcastJoin(withCacheDisabled()), selectQuery); + assertEqualsIgnoreOrder(resultWithCache.result(), resultWithoutCache.result()); + // make sure data was cached and query succeeds + assertThat(getCacheDataOperatorInputPositions(resultWithCache.queryId())).isPositive(); + + // make sure plan runs local UNION ALL source stage (no repartition remote exchanges) + Plan plan = getDistributedQueryRunner().getQueryPlan(resultWithCache.queryId()); + int actualRemoteExchangesCount = searchFrom(plan.getRoot()) + .where(node -> node instanceof ExchangeNode exchangeNode + && exchangeNode.getScope() == REMOTE + // exchanges for distributing nation build tables + && exchangeNode.getType() != REPLICATE) + .findAll() + .size(); + assertThat(actualRemoteExchangesCount).isEqualTo(0); + } + + @Test + public void testJoinQuery() + { + @Language("SQL") String selectQuery = "select count(l.orderkey) from lineitem l, lineitem r where l.orderkey = r.orderkey"; + MaterializedResultWithPlan resultWithCache = executeWithPlan(withCacheEnabled(), selectQuery); + MaterializedResultWithPlan 
resultWithoutCache = executeWithPlan(withCacheDisabled(), selectQuery); + assertEqualsIgnoreOrder(resultWithCache.result(), resultWithoutCache.result()); + + // make sure data was read from cache + assertThat(getLoadCachedDataOperatorInputPositions(resultWithCache.queryId())).isPositive(); + + // make sure data was cached + assertThat(getCacheDataOperatorInputPositions(resultWithCache.queryId())).isPositive(); + + // make sure less data is read from source when caching is on + assertThat(getScanOperatorInputPositions(resultWithCache.queryId())) + .isLessThan(getScanOperatorInputPositions(resultWithoutCache.queryId())); + } + + @Test + public void testAggregationQuery() + { + @Language("SQL") String countQuery = """ + SELECT * FROM + (SELECT count(orderkey), orderkey FROM lineitem GROUP BY orderkey) a + JOIN + (SELECT count(orderkey), orderkey FROM lineitem GROUP BY orderkey) b + ON a.orderkey = b.orderkey"""; + @Language("SQL") String sumQuery = """ + SELECT * FROM + (SELECT sum(orderkey), orderkey FROM lineitem GROUP BY orderkey) a + JOIN + (SELECT sum(orderkey), orderkey FROM lineitem GROUP BY orderkey) b + ON a.orderkey = b.orderkey"""; + MaterializedResultWithPlan countWithCache = executeWithPlan(withCacheEnabled(), countQuery); + MaterializedResultWithPlan countWithoutCache = executeWithPlan(withCacheDisabled(), countQuery); + assertEqualsIgnoreOrder(countWithCache.result(), countWithoutCache.result()); + + // make sure data was read from cache + assertThat(getLoadCachedDataOperatorInputPositions(countWithCache.queryId())).isPositive(); + + // make sure data was cached + assertThat(getCacheDataOperatorInputPositions(countWithCache.queryId())).isPositive(); + + // make sure less data is read from source when caching is on + assertThat(getScanOperatorInputPositions(countWithCache.queryId())) + .isLessThan(getScanOperatorInputPositions(countWithoutCache.queryId())); + + // subsequent count aggregation query should use cached data only + countWithCache = 
executeWithPlan(withCacheEnabled(), countQuery); + assertThat(getLoadCachedDataOperatorInputPositions(countWithCache.queryId())).isPositive(); + assertThat(getScanOperatorInputPositions(countWithCache.queryId())).isZero(); + + // subsequent sum aggregation query should read from source as it doesn't match count plan signature + MaterializedResultWithPlan sumWithCache = executeWithPlan(withCacheEnabled(), sumQuery); + assertThat(getScanOperatorInputPositions(sumWithCache.queryId())).isPositive(); + } + + @Test + public void testSubsequentQueryReadsFromCache() + { + @Language("SQL") String selectQuery = "select orderkey from lineitem union all (select orderkey from lineitem union all select orderkey from lineitem)"; + MaterializedResultWithPlan resultWithCache = executeWithPlan(withCacheEnabled(), selectQuery); + + // make sure data was cached + assertThat(getCacheDataOperatorInputPositions(resultWithCache.queryId())).isPositive(); + + resultWithCache = executeWithPlan(withCacheEnabled(), "select orderkey from lineitem union all select orderkey from lineitem"); + // make sure data was read from cache as data should be cached across queries + assertThat(getLoadCachedDataOperatorInputPositions(resultWithCache.queryId())).isPositive(); + assertThat(getScanOperatorInputPositions(resultWithCache.queryId())).isZero(); + } + + @Test + public void testSubsequentQueryReadsFromCacheWithPredicateOnDataColumn() + { + if (!supportsDataColumnPruning()) { + abort("Data column pruning is not supported"); + } + + MaterializedResultWithPlan resultWithCache = executeWithPlan( + withCacheEnabled(), + "SELECT partkey FROM lineitem WHERE orderkey BETWEEN 0 AND 1000000000"); + + // make sure data was cached + assertThat(getCacheDataOperatorInputPositions(resultWithCache.queryId())).isPositive(); + + resultWithCache = executeWithPlan( + withCacheEnabled(), + "SELECT partkey FROM lineitem WHERE orderkey BETWEEN 0 AND 1000000001"); + // make sure data was read from cache because both 
"orderkey BETWEEN 0 AND 1000000000" + // and "orderkey BETWEEN 0 AND 1000000001" should evaluate to TRUE for lineitem splits + assertThat(getLoadCachedDataOperatorInputPositions(resultWithCache.queryId())).isPositive(); + assertThat(getScanOperatorInputPositions(resultWithCache.queryId())).isZero(); + + // query with predicate that doesn't evaluate to TRUE for lineitem splits shouldn't read from cache + resultWithCache = executeWithPlan( + withCacheEnabled(), + "SELECT partkey FROM lineitem WHERE orderkey BETWEEN 0 AND 1000"); + assertThat(getLoadCachedDataOperatorInputPositions(resultWithCache.queryId())).isZero(); + } + + @Test + public void testSubsequentQueryReadsFromCacheWithDynamicFilterOnDataColumn() + { + if (!supportsDataColumnPruning()) { + abort("Data column pruning is not supported"); + } + + MaterializedResultWithPlan resultWithCache = executeWithPlan( + withCacheEnabled(), + """ + SELECT partkey FROM lineitem l JOIN + (SELECT suppkey, orderkey FROM (VALUES (2, 17125), (3, 60000), (4, 60000)) t(suppkey, orderkey)) o + ON l.suppkey = o.suppkey AND l.orderkey <= o.orderkey + """); + // make sure data was cached + assertThat(getCacheDataOperatorInputPositions(resultWithCache.queryId())).isPositive(); + assertThat(getScanSplitsWithDynamicFiltersApplied(resultWithCache.queryId())).isPositive(); + + resultWithCache = executeWithPlan( + withCacheEnabled(), + """ + SELECT partkey FROM lineitem l JOIN + (SELECT suppkey, orderkey FROM (VALUES (2, 17125), (3, 60000), (4, 60001)) t(suppkey, orderkey)) o + ON l.suppkey = o.suppkey AND l.orderkey <= o.orderkey + """); + // make sure data was read from cache because dynamic filters for "l.orderkey < o.orderkey" + // should evaluate to TRUE for both queries since the highest lineitem "orderkey" value is 60000 + assertThat(getLoadCachedDataOperatorInputPositions(resultWithCache.queryId())).isPositive(); + assertThat(getScanOperatorInputPositions(resultWithCache.queryId())).isZero(); + + // query with dynamic filter 
that doesn't evaluate to TRUE for lineitem splits shouldn't read from cache + resultWithCache = executeWithPlan( + withCacheEnabled(), + """ + SELECT partkey FROM lineitem l JOIN + (SELECT suppkey, orderkey FROM (VALUES (2, 17125), (3, 59999), (4, 59999)) t(suppkey, orderkey)) o + ON l.suppkey = o.suppkey AND l.orderkey <= o.orderkey + """); + assertThat(getLoadCachedDataOperatorInputPositions(resultWithCache.queryId())).isZero(); + assertThat(getScanSplitsWithDynamicFiltersApplied(resultWithCache.queryId())).isPositive(); + } + + @ParameterizedTest + @MethodSource("isDynamicRowFilteringEnabled") + public void testDynamicFilterCache(boolean isDynamicRowFilteringEnabled) + { + createPartitionedTableAsSelect("orders_part", ImmutableList.of("custkey"), "select orderkey, orderdate, orderpriority, mod(custkey, 10) as custkey from orders"); + @Language("SQL") String totalScanOrdersQuery = "select count(orderkey) from orders_part"; + @Language("SQL") String firstJoinQuery = """ + select count(orderkey) from orders_part o join (select * from (values 0, 1, 2) t(custkey)) t on o.custkey = t.custkey + union all + select count(orderkey) from orders_part o join (select * from (values 0, 1, 2) t(custkey)) t on o.custkey = t.custkey + """; + @Language("SQL") String secondJoinQuery = """ + select count(orderkey) from orders_part o join (select * from (values 0, 1, 2, 4) t(custkey)) t on o.custkey = t.custkey + union all + select count(orderkey) from orders_part o join (select * from (values 0, 1, 2, 3) t(custkey)) t on o.custkey = t.custkey + """; + @Language("SQL") String thirdJoinQuery = """ + select count(orderkey) from orders_part o join (select * from (values 0, 1) t(custkey)) t on o.custkey = t.custkey + union all + select count(orderkey) from orders_part o join (select * from (values 0, 1) t(custkey)) t on o.custkey = t.custkey + """; + + Session cacheSubqueriesEnabled = withDynamicRowFiltering(withCacheEnabled(), isDynamicRowFilteringEnabled); + Session 
cacheSubqueriesDisabled = withDynamicRowFiltering(withCacheDisabled(), isDynamicRowFilteringEnabled); + MaterializedResultWithPlan totalScanOrdersExecution = executeWithPlan(cacheSubqueriesDisabled, totalScanOrdersQuery); + MaterializedResultWithPlan firstJoinExecution = executeWithPlan(cacheSubqueriesEnabled, firstJoinQuery); + MaterializedResultWithPlan anotherFirstJoinExecution = executeWithPlan(cacheSubqueriesEnabled, firstJoinQuery); + MaterializedResultWithPlan secondJoinExecution = executeWithPlan(cacheSubqueriesEnabled, secondJoinQuery); + MaterializedResultWithPlan thirdJoinExecution = executeWithPlan(cacheSubqueriesEnabled, thirdJoinQuery); + + // firstJoinQuery does not read whole probe side as some splits were pruned by dynamic filters + assertThat(getScanOperatorInputPositions(firstJoinExecution.queryId())).isLessThan(getScanOperatorInputPositions(totalScanOrdersExecution.queryId())); + assertThat(getCacheDataOperatorInputPositions(firstJoinExecution.queryId())).isPositive(); + // firstJoinQuery reads from table + assertThat(getScanOperatorInputPositions(firstJoinExecution.queryId())).isPositive(); + // second run of firstJoinQuery reads only from cache + assertThat(getScanOperatorInputPositions(anotherFirstJoinExecution.queryId())).isZero(); + assertThat(getLoadCachedDataOperatorInputPositions(anotherFirstJoinExecution.queryId())).isPositive(); + + // secondJoinQuery reads from table and cache because its predicate is wider that firstJoinQuery's predicate + assertThat(getCacheDataOperatorInputPositions(secondJoinExecution.queryId())).isPositive(); + assertThat(getLoadCachedDataOperatorInputPositions(secondJoinExecution.queryId())).isPositive(); + assertThat(getScanOperatorInputPositions(secondJoinExecution.queryId())).isPositive(); + + // thirdJoinQuery reads only from cache + assertThat(getLoadCachedDataOperatorInputPositions(thirdJoinExecution.queryId())).isPositive(); + 
assertThat(getScanOperatorInputPositions(thirdJoinExecution.queryId())).isZero(); + + assertUpdate("drop table orders_part"); + } + + @Test + public void testPredicateOnPartitioningColumnThatWasNotFullyPushed() + { + createPartitionedTableAsSelect("orders_part", ImmutableList.of("orderkey"), "select orderdate, orderpriority, mod(orderkey, 50) as orderkey from orders"); + // mod predicate will be not pushed to connector + @Language("SQL") String query = + """ + select * from ( + select orderdate from orders_part where orderkey > 5 and mod(orderkey, 10) = 0 and orderpriority = '1-MEDIUM' + union all + select orderdate from orders_part where orderkey > 10 and mod(orderkey, 10) = 1 and orderpriority = '3-MEDIUM' + ) order by orderdate + """; + MaterializedResultWithPlan cacheDisabledResult = executeWithPlan(withCacheDisabled(), query); + executeWithPlan(withCacheEnabled(), query); + MaterializedResultWithPlan cacheEnabledResult = executeWithPlan(withCacheEnabled(), query); + + assertThat(getLoadCachedDataOperatorInputPositions(cacheEnabledResult.queryId())).isPositive(); + assertThat(cacheDisabledResult.result()).isEqualTo(cacheEnabledResult.result()); + assertUpdate("drop table orders_part"); + } + + @Test + public void testCacheWhenProjectionsWerePushedDown() + { + computeActual("create table orders_with_row (c row(name varchar, lastname varchar, age integer))"); + computeActual("insert into orders_with_row values (row (row ('any_name', 'any_lastname', 25)))"); + + @Language("SQL") String query = "select c.name, c.age from orders_with_row union all select c.name, c.age from orders_with_row"; + @Language("SQL") String secondQuery = "select c.lastname, c.age from orders_with_row union all select c.lastname, c.age from orders_with_row"; + + Session cacheEnabledProjectionDisabled = withProjectionPushdownEnabled(withCacheEnabled(), false); + + MaterializedResultWithPlan firstRun = executeWithPlan(withCacheEnabled(), query); + 
assertThat(firstRun.result().getRowCount()).isEqualTo(2); + assertThat(firstRun.result().getMaterializedRows().get(0).getFieldCount()).isEqualTo(2); + assertThat(getCacheDataOperatorInputPositions(firstRun.queryId())).isPositive(); + + // should use cache + MaterializedResultWithPlan secondRun = executeWithPlan(withCacheEnabled(), query); + assertThat(secondRun.result().getRowCount()).isEqualTo(2); + assertThat(secondRun.result().getMaterializedRows().get(0).getFieldCount()).isEqualTo(2); + assertThat(getLoadCachedDataOperatorInputPositions(secondRun.queryId())).isPositive(); + + // shouldn't use cache because selected cacheColumnIds were different in the first case as projections were pushed down + MaterializedResultWithPlan pushDownProjectionDisabledRun = executeWithPlan(cacheEnabledProjectionDisabled, query); + assertThat(pushDownProjectionDisabledRun.result()).isEqualTo(firstRun.result()); + + // shouldn't use cache because selected columns are different + MaterializedResultWithPlan thirdRun = executeWithPlan(withCacheEnabled(), secondQuery); + assertThat(getLoadCachedDataOperatorInputPositions(thirdRun.queryId())).isLessThanOrEqualTo(1); + + assertUpdate("drop table orders_with_row"); + } + + @Test + public void testPartitionedQueryCache() + { + createPartitionedTableAsSelect("orders_part", ImmutableList.of("orderpriority"), "select orderkey, orderdate, orderpriority from orders"); + @Language("SQL") String selectTwoPartitions = """ + select orderkey from orders_part where orderpriority IN ('3-MEDIUM', '1-URGENT') + union all + select orderkey from orders_part where orderpriority IN ('3-MEDIUM', '1-URGENT') + """; + @Language("SQL") String selectAllPartitions = """ + select orderkey from orders_part + union all + select orderkey from orders_part + """; + @Language("SQL") String selectSinglePartition = """ + select orderkey from orders_part where orderpriority = '3-MEDIUM' + union all + select orderkey from orders_part where orderpriority = '3-MEDIUM' + """; + 
+ MaterializedResultWithPlan twoPartitionsQueryFirst = executeWithPlan(withCacheEnabled(), selectTwoPartitions); + Plan twoPartitionsQueryPlan = getDistributedQueryRunner().getQueryPlan(twoPartitionsQueryFirst.queryId()); + MaterializedResultWithPlan twoPartitionsQuerySecond = executeWithPlan(withCacheEnabled(), selectTwoPartitions); + + MaterializedResultWithPlan allPartitionsQuery = executeWithPlan(withCacheEnabled(), selectAllPartitions); + Plan allPartitionsQueryPlan = getDistributedQueryRunner().getQueryPlan(allPartitionsQuery.queryId()); + + String catalogId = withTransaction(session -> getDistributedQueryRunner().getCoordinator() + .getPlannerContext().getMetadata() + .getCatalogHandle(session, session.getCatalog().get()) + .orElseThrow() + .getId()); + + PlanSignatureWithPredicate signature = new PlanSignatureWithPredicate( + new PlanSignature( + scanFilterProjectKey(new CacheTableId(catalogId + ":" + getCacheTableId(getSession(), "orders_part"))), + Optional.empty(), + ImmutableList.of(getCacheColumnId(getSession(), "orders_part", "orderkey")), + ImmutableList.of(BIGINT)), + TupleDomain.all()); + + PlanMatchPattern chooseAlternativeNode = chooseAlternativeNode( + tableScan("orders_part"), + cacheDataPlanNode(tableScan("orders_part")), + node(LoadCachedDataPlanNode.class) + .with(LoadCachedDataPlanNode.class, node -> node.getPlanSignature().equals(signature))); + + PlanMatchPattern originalPlanPattern = anyTree(chooseAlternativeNode, chooseAlternativeNode); + + // predicates for both original plans were pushed down to the tableHandle, which means that there are no + filter nodes.
As a result, the plan signatures are the same for both (actually different) queries + assertPlan(getSession(), twoPartitionsQueryPlan, originalPlanPattern); + assertPlan(getSession(), allPartitionsQueryPlan, originalPlanPattern); + + // make sure that full scan reads data from table instead of basing on cache even though + // plan signature is same + assertThat(getScanOperatorInputPositions(twoPartitionsQueryFirst.queryId())).isPositive(); + assertThat(getScanOperatorInputPositions(twoPartitionsQuerySecond.queryId())).isZero(); + assertThat(getScanOperatorInputPositions(allPartitionsQuery.queryId())).isPositive(); + + // notFilteringExecution should read from both cache (for partitions pre-loaded by filtering executions) and + // from source table + assertThat(getLoadCachedDataOperatorInputPositions(allPartitionsQuery.queryId())).isPositive(); + + // single partition query should read from cache only because data for all partitions have been pre-loaded + MaterializedResultWithPlan singlePartitionQuery = executeWithPlan(withCacheEnabled(), selectSinglePartition); + assertThat(getScanOperatorInputPositions(singlePartitionQuery.queryId())).isZero(); + assertThat(getLoadCachedDataOperatorInputPositions(singlePartitionQuery.queryId())).isPositive(); + + // make sure that adding new partition doesn't invalidate existing cache entries + computeActual("insert into orders_part values (-42, date '1991-01-01', 'foo')"); + singlePartitionQuery = executeWithPlan(withCacheEnabled(), selectSinglePartition); + assertThat(getScanOperatorInputPositions(singlePartitionQuery.queryId())).isZero(); + assertThat(getLoadCachedDataOperatorInputPositions(singlePartitionQuery.queryId())).isPositive(); + + // validate results + int twoPartitionsRowCount = twoPartitionsQueryFirst.result().getRowCount(); + assertThat(twoPartitionsRowCount).isEqualTo(twoPartitionsQuerySecond.result().getRowCount()); + assertThat(twoPartitionsRowCount).isLessThan(allPartitionsQuery.result().getRowCount()); + 
assertThat(singlePartitionQuery.result().getRowCount()).isLessThan(twoPartitionsRowCount); + assertUpdate("drop table orders_part"); + } + + @Test + public void testCommonSubqueryCacheSplitByIntersectionOfEnforcedConstraint() + { + createPartitionedTableAsSelect("orders_part", ImmutableList.of("orderpriority"), "select orderkey, orderdate, orderpriority from orders"); + @Language("SQL") String query = """ + select orderkey from orders_part where orderpriority = '3-MEDIUM' + union all + select orderkey from orders_part where orderpriority = '1-URGENT' + """; + // no caching because enforced constraint does not intersect between subplans + MaterializedResultWithPlan result = executeWithPlan(withCommonSubqueryCacheEnabled(), query); + assertThat(getScanOperatorInputPositions(result.queryId())).isPositive(); + assertThat(getLoadCachedDataOperatorInputPositions(result.queryId())).isZero(); + result = executeWithPlan(withCommonSubqueryCacheEnabled(), query); + assertThat(getScanOperatorInputPositions(result.queryId())).isPositive(); + assertThat(getLoadCachedDataOperatorInputPositions(result.queryId())).isZero(); + query = """ + select orderkey from orders_part where orderpriority = '1-URGENT' + union all + select orderkey from orders_part where orderpriority = '1-URGENT' + """; + executeWithPlan(withCommonSubqueryCacheEnabled(), query); + result = executeWithPlan(withCommonSubqueryCacheEnabled(), query); + assertThat(getScanOperatorInputPositions(result.queryId())).isZero(); + assertThat(getLoadCachedDataOperatorInputPositions(result.queryId())).isPositive(); + assertUpdate("drop table orders_part"); + } + + @ParameterizedTest + @MethodSource("isDynamicRowFilteringEnabled") + public void testGetUnenforcedPredicateAndPrunePredicate(boolean isDynamicRowFilteringEnabled) + { + String tableName = "get_unenforced_predicate_is_prune_and_prune_orders_part_" + isDynamicRowFilteringEnabled; + createPartitionedTableAsSelect(tableName, ImmutableList.of("orderpriority"), "select 
orderkey, orderdate, '9876' as orderpriority from orders"); + DistributedQueryRunner runner = getDistributedQueryRunner(); + Session session = withDynamicRowFiltering( + Session.builder(getSession()) + .setQueryId(new QueryId("prune_predicate_" + isDynamicRowFilteringEnabled)) + .build(), + isDynamicRowFilteringEnabled); + transaction(runner.getTransactionManager(), runner.getPlannerContext().getMetadata(), runner.getAccessControl()) + .singleStatement() + .execute(session, transactionSession -> { + TestingTrinoServer coordinator = runner.getCoordinator(); + TestingTrinoServer worker = runner.getServers().get(0); + checkState(!worker.isCoordinator()); + String catalog = transactionSession.getCatalog().orElseThrow(); + CatalogHandle catalogHandle = coordinator.getCatalogHandle(catalog); + // metadata.getCatalogHandle() registers the catalog for the transaction + coordinator.getPlannerContext().getMetadata().getCatalogHandle(transactionSession, catalog); + ConnectorTransactionHandle catalogTransaction = coordinator.getTransactionManager().getConnectorTransaction(transactionSession.getTransactionId().orElseThrow(), catalogHandle); + Metadata metadata = coordinator.getPlannerContext().getMetadata(); + TableHandle handle = metadata.getTableHandle( + transactionSession, + new QualifiedObjectName(catalog, transactionSession.getSchema().orElseThrow(), tableName)).orElseThrow(); + ConnectorTableHandle connectorTableHandle = handle.connectorHandle(); + + SplitSource splitSource = coordinator.getSplitManager().getSplits(transactionSession, Span.current(), handle, DynamicFilter.EMPTY, alwaysTrue()); + ConnectorSplit split = getFutureValue(splitSource.getNextBatch(1000)).getSplits().get(0).getConnectorSplit(); + + ColumnHandle partitionColumn = metadata.getColumnHandles(transactionSession, handle).get("orderpriority"); + assertThat(partitionColumn).isNotNull(); + ColumnHandle dataColumn = metadata.getColumnHandles(transactionSession, handle).get("orderkey"); + 
assertThat(dataColumn).isNotNull(); + + ConnectorPageSourceProvider pageSourceProvider = worker.getConnector(coordinator.getCatalogHandle(catalog)).getPageSourceProviderFactory().createPageSourceProvider(); + VarcharType type = VarcharType.createVarcharType(4); + + // getUnenforcedPredicate and prunePredicate should return none if predicate is exclusive on partition column + ConnectorSession connectorSession = transactionSession.toConnectorSession(metadata.getCatalogHandle(transactionSession, catalog).orElseThrow()); + Domain nonPartitionDomain = Domain.multipleValues(type, Streams.concat(LongStream.range(0, 9_000), LongStream.of(9_999)) + .boxed() + .map(value -> Slices.utf8Slice(value.toString())) + .collect(toImmutableList())); + assertThat(pageSourceProvider.prunePredicate( + connectorSession, + split, + connectorTableHandle, + TupleDomain.withColumnDomains(ImmutableMap.of(partitionColumn, nonPartitionDomain)))) + .matches(TupleDomain::isNone); + assertThat(getUnenforcedPredicate( + new PageSourceProviderInstance(pageSourceProvider), + isDynamicRowFilteringEnabled, + session, + new Split(catalogHandle, split), + new TableHandle(catalogHandle, connectorTableHandle, catalogTransaction), + TupleDomain.withColumnDomains(ImmutableMap.of(partitionColumn, nonPartitionDomain)))) + .matches(TupleDomain::isNone); + + // getUnenforcedPredicate and prunePredicate should prune prefilled column that matches given predicate fully + Domain partitionDomain = Domain.singleValue(type, Slices.utf8Slice("9876")); + assertThat(pageSourceProvider.prunePredicate( + connectorSession, + split, + connectorTableHandle, + TupleDomain.withColumnDomains(ImmutableMap.of(partitionColumn, partitionDomain)))) + .matches(TupleDomain::isAll); + assertThat(getUnenforcedPredicate( + new PageSourceProviderInstance(pageSourceProvider), + isDynamicRowFilteringEnabled, + session, + new Split(catalogHandle, split), + new TableHandle(catalogHandle, connectorTableHandle, catalogTransaction), + 
TupleDomain.withColumnDomains(ImmutableMap.of(partitionColumn, partitionDomain)))) + .matches(TupleDomain::isAll); + + // prunePredicate should not prune or simplify data column if there was no predicate on data column + Domain dataDomain = Domain.multipleValues(BIGINT, LongStream.range(0, 10_000) + .boxed() + .collect(toImmutableList())); + assertThat(pageSourceProvider.prunePredicate( + connectorSession, + split, + connectorTableHandle, + TupleDomain.withColumnDomains(ImmutableMap.of(dataColumn, dataDomain)))) + .isEqualTo(TupleDomain.withColumnDomains(ImmutableMap.of(dataColumn, dataDomain))); + + if (supportsDataColumnPruning()) { + SplitSource splitSourceWithDfOnDataColumn = coordinator.getSplitManager().getSplits( + transactionSession, + Span.current(), + handle, + getDynamicFilter(TupleDomain.withColumnDomains(ImmutableMap.of( + dataColumn, + Domain.create(ValueSet.ofRanges(Range.lessThan(BIGINT, 1_000_000L)), false)))), + alwaysTrue()); + ConnectorSplit splitWithDfOnDataColumn = getFutureValue(splitSourceWithDfOnDataColumn.getNextBatch(1000)).getSplits().get(0).getConnectorSplit(); + // getUnenforcedPredicate and prunePredicate should prune data column if there is dynamic filter on that column + Domain containingRange = Domain.create(ValueSet.ofRanges(Range.lessThanOrEqual(BIGINT, 60_000L)), false); + assertThat(pageSourceProvider.getUnenforcedPredicate( + connectorSession, + splitWithDfOnDataColumn, + connectorTableHandle, + TupleDomain.withColumnDomains(ImmutableMap.of(dataColumn, containingRange)))) + .isEqualTo(TupleDomain.all()); + assertThat(pageSourceProvider.prunePredicate( + connectorSession, + splitWithDfOnDataColumn, + connectorTableHandle, + TupleDomain.withColumnDomains(ImmutableMap.of(dataColumn, containingRange)))) + .isEqualTo(TupleDomain.all()); + } + + if (isDynamicRowFilteringEnabled || getUnenforcedPredicateIsPrune()) { + // getUnenforcedPredicate should not prune or simplify data column + assertThat(getUnenforcedPredicate( + new 
PageSourceProviderInstance(pageSourceProvider), + isDynamicRowFilteringEnabled, + session, + new Split(catalogHandle, split), + new TableHandle(catalogHandle, connectorTableHandle, catalogTransaction), + TupleDomain.withColumnDomains(ImmutableMap.of(dataColumn, dataDomain)))) + .isEqualTo(TupleDomain.withColumnDomains(ImmutableMap.of(dataColumn, dataDomain))); + } + else { + // getUnenforcedPredicate should not prune but simplify data column + assertThat(pageSourceProvider.getUnenforcedPredicate( + connectorSession, + split, + connectorTableHandle, + TupleDomain.withColumnDomains(ImmutableMap.of(dataColumn, dataDomain)))) + .isEqualTo(TupleDomain.withColumnDomains(ImmutableMap.of(dataColumn, Domain.create(ValueSet.ofRanges(range(BIGINT, 0L, true, 9_999L, true)), false)))); + } + }); + assertUpdate("drop table " + tableName); + } + + @Test + public void testEffectivePredicateReturnedPerSplit() + { + if (!effectivePredicateReturnedPerSplit()) { + abort("Effective predicate is not returned per split"); + } + + DistributedQueryRunner runner = getDistributedQueryRunner(); + transaction(runner.getTransactionManager(), runner.getPlannerContext().getMetadata(), runner.getAccessControl()) + .singleStatement() + .execute(getSession(), transactionSession -> { + TestingTrinoServer coordinator = runner.getCoordinator(); + TestingTrinoServer worker = runner.getServers().get(0); + checkState(!worker.isCoordinator()); + String catalog = transactionSession.getCatalog().orElseThrow(); + String schema = transactionSession.getSchema().orElseThrow(); + Metadata metadata = coordinator.getPlannerContext().getMetadata(); + TableHandle handle = metadata.getTableHandle( + transactionSession, + new QualifiedObjectName(catalog, schema, "lineitem")).orElseThrow(); + ConnectorTableHandle connectorTableHandle = handle.connectorHandle(); + ColumnHandle orderKeyColumn = metadata.getColumnHandles(transactionSession, handle).get("orderkey"); + + // get table handle with filter applied + TupleDomain 
effectivePredicate = TupleDomain.withColumnDomains(ImmutableMap.of( + orderKeyColumn, Domain.singleValue(BIGINT, 17125L))); + Optional> filterResult = metadata.applyFilter( + transactionSession, + handle, + new Constraint(effectivePredicate)); + assertThat(filterResult).isPresent(); + TableHandle handleWithFilter = filterResult.get().getHandle(); + ConnectorTableHandle connectorTableHandleWithFilter = handleWithFilter.connectorHandle(); + + // make sure cache table ids are same for both table handles + CacheMetadata cacheMetadata = runner.getCacheMetadata(); + assertThat(cacheMetadata.getCacheTableId(transactionSession, handle)).isEqualTo(cacheMetadata.getCacheTableId(transactionSession, handleWithFilter)); + + // make sure effective predicate is propagated as part of split id + SplitSource splitSource = coordinator.getSplitManager().getSplits(transactionSession, Span.current(), handle, DynamicFilter.EMPTY, alwaysTrue()); + ConnectorSplit split = getFutureValue(splitSource.getNextBatch(1000)).getSplits().get(0).getConnectorSplit(); + + SplitSource splitSourceWithFilter = coordinator.getSplitManager().getSplits(transactionSession, Span.current(), handleWithFilter, DynamicFilter.EMPTY, alwaysTrue()); + ConnectorSplit splitWithFilter = getFutureValue(splitSourceWithFilter.getNextBatch(1000)).getSplits().get(0).getConnectorSplit(); + + ConnectorPageSourceProvider pageSourceProvider = worker.getConnector(coordinator.getCatalogHandle(catalog)).getPageSourceProviderFactory().createPageSourceProvider(); + ConnectorSession connectorSession = transactionSession.toConnectorSession(metadata.getCatalogHandle(transactionSession, catalog).orElseThrow()); + + // split for original table handle doesn't propagate any effective predicate + assertThat(pageSourceProvider.getUnenforcedPredicate(connectorSession, split, connectorTableHandle, TupleDomain.all())) + .isEqualTo(TupleDomain.all()); + // split for filtered table handle should propagate effective predicate + 
assertThat(pageSourceProvider.getUnenforcedPredicate(connectorSession, splitWithFilter, connectorTableHandleWithFilter, TupleDomain.all())) + .isEqualTo(effectivePredicate); + + if (supportsDataColumnPruning()) { + // make sure prunePredicate removes predicates that evaluate to ALL for a split + assertThat(pageSourceProvider.prunePredicate( + connectorSession, + splitWithFilter, + connectorTableHandleWithFilter, + TupleDomain.withColumnDomains(ImmutableMap.of( + orderKeyColumn, Domain.create(ValueSet.ofRanges(Range.lessThanOrEqual(BIGINT, 60_000L)), false))))) + .isEqualTo(TupleDomain.all()); + } + }); + } + + private TupleDomain getUnenforcedPredicate( + PageSourceProvider pageSourceProvider, + boolean isDynamicRowFilteringEnabled, + Session session, + Split split, + TableHandle table, + TupleDomain predicate) + { + if (isDynamicRowFilteringEnabled) { + return getDynamicRowFilteringUnenforcedPredicate(pageSourceProvider, session, split, table, predicate); + } + return pageSourceProvider.getUnenforcedPredicate(session, split, table, predicate); + } + + protected CacheColumnId getCacheColumnId(Session session, String tableName, String columnName) + { + QueryRunner runner = getQueryRunner(); + QualifiedObjectName table = new QualifiedObjectName(session.getCatalog().orElseThrow(), session.getSchema().orElseThrow(), tableName); + return transaction(runner.getTransactionManager(), runner.getPlannerContext().getMetadata(), runner.getAccessControl()) + .singleStatement() + .execute(session, transactionSession -> { + Metadata metadata = runner.getPlannerContext().getMetadata(); + CacheMetadata cacheMetadata = runner.getCacheMetadata(); + TableHandle tableHandle = metadata.getTableHandle(transactionSession, table).get(); + return new CacheColumnId("[" + cacheMetadata.getCacheColumnId(transactionSession, tableHandle, metadata.getColumnHandles(transactionSession, tableHandle).get(columnName)).get() + "]"); + }); + } + + protected boolean effectivePredicateReturnedPerSplit() + 
{ + return true; + } + + protected boolean supportsDataColumnPruning() + { + return true; + } + + protected boolean getUnenforcedPredicateIsPrune() + { + return false; + } + + protected CacheTableId getCacheTableId(Session session, String tableName) + { + QueryRunner runner = getQueryRunner(); + QualifiedObjectName table = new QualifiedObjectName(session.getCatalog().orElseThrow(), session.getSchema().orElseThrow(), tableName); + return transaction(runner.getTransactionManager(), runner.getPlannerContext().getMetadata(), runner.getAccessControl()) + .singleStatement() + .execute(session, transactionSession -> { + Metadata metadata = runner.getPlannerContext().getMetadata(); + CacheMetadata cacheMetadata = runner.getCacheMetadata(); + TableHandle tableHandle = metadata.getTableHandle(transactionSession, table).get(); + return cacheMetadata.getCacheTableId(transactionSession, tableHandle).get(); + }); + } + + protected void assertPlan(Session session, Plan plan, PlanMatchPattern pattern) + { + QueryRunner runner = getQueryRunner(); + transaction(runner.getTransactionManager(), runner.getPlannerContext().getMetadata(), runner.getAccessControl()) + .singleStatement() + .execute(session, transactionSession -> { + runner.getTransactionManager().getCatalogHandle(transactionSession.getTransactionId().get(), transactionSession.getCatalog().orElseThrow()); + PlanAssert.assertPlan(transactionSession, getQueryRunner().getPlannerContext().getMetadata(), createTestingFunctionManager(), noopStatsCalculator(), plan, pattern); + }); + } + + protected T withTransaction(Function transactionSessionConsumer) + { + return newTransaction().execute(getSession(), transactionSessionConsumer); + } + + protected MaterializedResultWithPlan executeWithPlan(Session session, @Language("SQL") String sql) + { + return getDistributedQueryRunner().executeWithPlan(session, sql); + } + + protected Long getScanSplitsWithDynamicFiltersApplied(QueryId queryId) + { + return getOperatorStats(queryId, 
TableScanOperator.class.getSimpleName(), ScanFilterAndProjectOperator.class.getSimpleName()) + .map(OperatorStats::getDynamicFilterSplitsProcessed) + .mapToLong(Long::valueOf) + .sum(); + } + + protected Long getScanOperatorInputPositions(QueryId queryId) + { + return getOperatorInputPositions(queryId, TableScanOperator.class.getSimpleName(), ScanFilterAndProjectOperator.class.getSimpleName()); + } + + protected Long getCacheDataOperatorInputPositions(QueryId queryId) + { + return getOperatorInputPositions(queryId, CacheDataOperator.class.getSimpleName()); + } + + protected Long getLoadCachedDataOperatorInputPositions(QueryId queryId) + { + return getOperatorInputPositions(queryId, LoadCachedDataOperator.class.getSimpleName()); + } + + protected Long getOperatorInputPositions(QueryId queryId, String... operatorType) + { + return getOperatorStats(queryId, operatorType) + .map(OperatorStats::getInputPositions) + .mapToLong(Long::valueOf) + .sum(); + } + + protected Stream getOperatorStats(QueryId queryId, String... 
operatorType) + { + ImmutableSet operatorTypes = ImmutableSet.copyOf(operatorType); + return getDistributedQueryRunner().getCoordinator() + .getQueryManager() + .getFullQueryInfo(queryId) + .getQueryStats() + .getOperatorSummaries() + .stream() + .filter(summary -> operatorTypes.contains(summary.getOperatorType())); + } + + protected Session withCacheEnabled() + { + return Session.builder(getSession()) + .setSystemProperty(ENABLE_LARGE_DYNAMIC_FILTERS, "false") + .setSystemProperty(CACHE_COMMON_SUBQUERIES_ENABLED, "true") + .setSystemProperty(CACHE_AGGREGATIONS_ENABLED, "true") + .setSystemProperty(CACHE_PROJECTIONS_ENABLED, "true") + .build(); + } + + protected Session withCacheDisabled() + { + return Session.builder(getSession()) + .setSystemProperty(ENABLE_LARGE_DYNAMIC_FILTERS, "false") + .setSystemProperty(CACHE_COMMON_SUBQUERIES_ENABLED, "false") + .setSystemProperty(CACHE_AGGREGATIONS_ENABLED, "false") + .setSystemProperty(CACHE_PROJECTIONS_ENABLED, "false") + .build(); + } + + protected Session withCommonSubqueryCacheEnabled() + { + return Session.builder(getSession()) + .setSystemProperty(ENABLE_LARGE_DYNAMIC_FILTERS, "false") + .setSystemProperty(CACHE_COMMON_SUBQUERIES_ENABLED, "true") + .setSystemProperty(CACHE_AGGREGATIONS_ENABLED, "false") + .setSystemProperty(CACHE_PROJECTIONS_ENABLED, "false") + .build(); + } + + protected Session withDynamicRowFiltering(Session baseSession, boolean enabled) + { + return Session.builder(baseSession) + .setSystemProperty(ENABLE_DYNAMIC_ROW_FILTERING, String.valueOf(enabled)) + .build(); + } + + protected Session withBroadcastJoin(Session baseSession) + { + return Session.builder(baseSession) + .setSystemProperty(JOIN_DISTRIBUTION_TYPE, BROADCAST.name()) + .setSystemProperty(JOIN_REORDERING_STRATEGY, NONE.name()) + .build(); + } + + abstract protected void createPartitionedTableAsSelect(String tableName, List partitionColumns, String asSelect); + + protected Session withProjectionPushdownEnabled(Session session, 
boolean projectionPushdownEnabled) + { + return session; + } + + private static DynamicFilter getDynamicFilter(TupleDomain tupleDomain) + { + return new DynamicFilter() + { + @Override + public Set getColumnsCovered() + { + return tupleDomain.getDomains().map(Map::keySet) + .orElseGet(ImmutableSet::of); + } + + @Override + public CompletableFuture isBlocked() + { + return completedFuture(null); + } + + @Override + public boolean isComplete() + { + return true; + } + + @Override + public boolean isAwaitable() + { + return false; + } + + @Override + public TupleDomain getCurrentPredicate() + { + return tupleDomain; + } + }; + } +} diff --git a/testing/trino-testing/src/main/java/io/trino/testing/DistributedQueryRunner.java b/testing/trino-testing/src/main/java/io/trino/testing/DistributedQueryRunner.java index df29a287aa8e..a1fc2ff88b98 100644 --- a/testing/trino-testing/src/main/java/io/trino/testing/DistributedQueryRunner.java +++ b/testing/trino-testing/src/main/java/io/trino/testing/DistributedQueryRunner.java @@ -29,6 +29,7 @@ import io.opentelemetry.sdk.trace.export.SimpleSpanProcessor; import io.trino.Session; import io.trino.Session.SessionBuilder; +import io.trino.cache.CacheMetadata; import io.trino.client.ClientSession; import io.trino.client.StatementClient; import io.trino.connector.CoordinatorDynamicCatalogManager; @@ -425,6 +426,12 @@ public TransactionManager getTransactionManager() return coordinator.getTransactionManager(); } + @Override + public CacheMetadata getCacheMetadata() + { + return coordinator.getCacheMetadata(); + } + @Override public PlannerContext getPlannerContext() { @@ -994,7 +1001,8 @@ public interface TestingTrinoClientFactory private static TestingTrinoClient createClient(TestingTrinoServer testingTrinoServer, Session session, String encoding) { - return new TestingTrinoClient(testingTrinoServer, new TestingStatementClientFactory() { + return new TestingTrinoClient(testingTrinoServer, new TestingStatementClientFactory() + { 
@Override public StatementClient create(OkHttpClient httpClient, Session session, ClientSession clientSession, String query) { diff --git a/testing/trino-tests/src/test/java/io/trino/execution/TestCacheDynamicFiltering.java b/testing/trino-tests/src/test/java/io/trino/execution/TestCacheDynamicFiltering.java new file mode 100644 index 000000000000..2d7282e22af8 --- /dev/null +++ b/testing/trino-tests/src/test/java/io/trino/execution/TestCacheDynamicFiltering.java @@ -0,0 +1,321 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.execution; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.trino.Session; +import io.trino.plugin.tpch.TpchPlugin; +import io.trino.spi.Plugin; +import io.trino.spi.cache.CacheColumnId; +import io.trino.spi.cache.CacheSplitId; +import io.trino.spi.cache.CacheTableId; +import io.trino.spi.cache.ConnectorCacheMetadata; +import io.trino.spi.connector.ColumnHandle; +import io.trino.spi.connector.Connector; +import io.trino.spi.connector.ConnectorContext; +import io.trino.spi.connector.ConnectorFactory; +import io.trino.spi.connector.ConnectorMetadata; +import io.trino.spi.connector.ConnectorPageSinkProvider; +import io.trino.spi.connector.ConnectorPageSource; +import io.trino.spi.connector.ConnectorPageSourceProvider; +import io.trino.spi.connector.ConnectorSession; +import io.trino.spi.connector.ConnectorSplit; +import io.trino.spi.connector.ConnectorSplitManager; +import io.trino.spi.connector.ConnectorSplitSource; +import io.trino.spi.connector.ConnectorTableHandle; +import io.trino.spi.connector.ConnectorTransactionHandle; +import io.trino.spi.connector.Constraint; +import io.trino.spi.connector.DynamicFilter; +import io.trino.spi.connector.EmptyPageSource; +import io.trino.spi.predicate.Domain; +import io.trino.spi.predicate.TupleDomain; +import io.trino.spi.predicate.ValueSet; +import io.trino.spi.transaction.IsolationLevel; +import io.trino.testing.AbstractTestQueryFramework; +import io.trino.testing.DistributedQueryRunner; +import io.trino.testing.QueryRunner; +import io.trino.testing.TestingMetadata; +import io.trino.testing.TestingMetadata.TestingColumnHandle; +import io.trino.testing.TestingMetadata.TestingTableHandle; +import io.trino.testing.TestingPageSinkProvider; +import io.trino.testing.TestingTransactionHandle; +import org.intellij.lang.annotations.Language; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; + +import java.util.List; +import 
java.util.Map; +import java.util.Optional; +import java.util.OptionalInt; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.atomic.AtomicLong; +import java.util.function.Consumer; +import java.util.function.Predicate; + +import static io.trino.spi.type.BigintType.BIGINT; +import static io.trino.testing.TestingSession.testSessionBuilder; +import static io.trino.testing.TestingSplit.createRemoteSplit; +import static java.util.Objects.requireNonNull; +import static java.util.concurrent.CompletableFuture.completedFuture; +import static org.assertj.core.api.Assertions.assertThat; + +public class TestCacheDynamicFiltering + extends AbstractTestQueryFramework +{ + private volatile Consumer> expectedCoordinatorDynamicFilterAssertion; + private volatile Predicate> expectedTableScanDynamicFilter; + + @Override + protected QueryRunner createQueryRunner() + throws Exception + { + Session session = testSessionBuilder() + .setCatalog("test") + .setSchema("default") + .build(); + return DistributedQueryRunner.builder(session) + .setExtraProperties(ImmutableMap.of( + "cache.enabled", "true", + "query.schedule-split-batch-size", "1")) + .build(); + } + + @BeforeAll + public void setup() + { + getQueryRunner().installPlugin(new TestingPlugin()); + getQueryRunner().installPlugin(new TpchPlugin()); + getQueryRunner().createCatalog("test", "test", ImmutableMap.of()); + getQueryRunner().createCatalog("tpch", "tpch", ImmutableMap.of()); + computeActual("CREATE TABLE orders AS SELECT * FROM tpch.tiny.orders"); + } + + @Test + public void testCacheDynamicFiltering() + { + @Language("SQL") String query = """ + select count(orderkey) from orders o join (select * from (values 0, 1) t(custkey)) t on o.custkey = t.custkey + union all + select count(orderkey) from orders o join (select * from (values 0, 2) t(custkey)) t on o.custkey = t.custkey + """; + TupleDomain firstScanDomain = TupleDomain.withColumnDomains(ImmutableMap.of( + new TestingColumnHandle("custkey", 
OptionalInt.of(1), Optional.of(BIGINT)), Domain.create(ValueSet.of(BIGINT, 0L, 1L), false))); + TupleDomain secondScanDomain = TupleDomain.withColumnDomains(ImmutableMap.of( + new TestingColumnHandle("custkey", OptionalInt.of(1), Optional.of(BIGINT)), Domain.create(ValueSet.of(BIGINT, 0L, 2L), false))); + // DF on worker nodes should eventually be union of DFs from first and second orders table scan + expectedTableScanDynamicFilter = tuple -> tuple.equals(TupleDomain.columnWiseUnion(firstScanDomain, secondScanDomain)); + // Coordinator should only use original DF for split enumeration + expectedCoordinatorDynamicFilterAssertion = tuple -> assertThat(tuple).isIn(firstScanDomain, secondScanDomain); + computeActual(query); + } + + private class TestingPlugin + implements Plugin + { + @Override + public Iterable getConnectorFactories() + { + return ImmutableList.of(new ConnectorFactory() + { + private final ConnectorCacheMetadata metadata = new ConnectorCacheMetadata() + { + @Override + public Optional getCacheTableId(ConnectorTableHandle tableHandle) + { + return Optional.of(new CacheTableId(((TestingTableHandle) tableHandle).getTableName().getTableName())); + } + + @Override + public Optional getCacheColumnId(ConnectorTableHandle tableHandle, ColumnHandle columnHandle) + { + return Optional.of(new CacheColumnId(((TestingColumnHandle) columnHandle).getName())); + } + + @Override + public ConnectorTableHandle getCanonicalTableHandle(ConnectorTableHandle handle) + { + return handle; + } + }; + + @Override + public String getName() + { + return "test"; + } + + @Override + public Connector create(String catalogName, Map config, ConnectorContext context) + { + return new TestConnector(new TestingMetadata(), metadata); + } + }); + } + } + + private class TestConnector + implements Connector + { + private final ConnectorMetadata metadata; + private final ConnectorCacheMetadata cacheMetadata; + private final AtomicLong splitCount = new AtomicLong(); + private volatile boolean 
finished; + + private TestConnector(ConnectorMetadata metadata, ConnectorCacheMetadata cacheMetadata) + { + this.metadata = requireNonNull(metadata, "metadata is null"); + this.cacheMetadata = requireNonNull(cacheMetadata, "cacheMetadata is null"); + } + + @Override + public ConnectorTransactionHandle beginTransaction(IsolationLevel isolationLevel, boolean readOnly, boolean autoCommit) + { + return TestingTransactionHandle.create(); + } + + @Override + public ConnectorMetadata getMetadata(ConnectorSession session, ConnectorTransactionHandle transactionHandle) + { + return metadata; + } + + @Override + public ConnectorCacheMetadata getCacheMetadata() + { + return cacheMetadata; + } + + @Override + public ConnectorSplitManager getSplitManager() + { + return new ConnectorSplitManager() + { + @Override + public ConnectorSplitSource getSplits( + ConnectorTransactionHandle transaction, + ConnectorSession session, + ConnectorTableHandle table, + DynamicFilter dynamicFilter, + Constraint constraint) + { + return new ConnectorSplitSource() + { + @Override + public CompletableFuture getNextBatch(int maxSize) + { + CompletableFuture blocked = dynamicFilter.isBlocked(); + + if (blocked.isDone()) { + // prevent active looping + try { + Thread.sleep(100); + } + catch (InterruptedException e) { + throw new RuntimeException(e); + } + return completedFuture(new ConnectorSplitBatch(ImmutableList.of(createRemoteSplit()), isFinished())); + } + + return blocked.thenApply(ignored -> { + // yield until dynamic filter is fully loaded + return new ConnectorSplitBatch(ImmutableList.of(), false); + }); + } + + @Override + public void close() + { + } + + @Override + public boolean isFinished() + { + if (!finished) { + return false; + } + + expectedCoordinatorDynamicFilterAssertion.accept(dynamicFilter.getCurrentPredicate()); + return true; + } + }; + } + + @Override + public Optional getCacheSplitId(ConnectorSplit split) + { + return Optional.of(new 
CacheSplitId(Long.toString(splitCount.incrementAndGet()))); + } + }; + } + + @Override + public ConnectorPageSourceProvider getPageSourceProvider() + { + return new ConnectorPageSourceProvider() + { + @Override + public ConnectorPageSource createPageSource( + ConnectorTransactionHandle transaction, + ConnectorSession session, + ConnectorSplit split, + ConnectorTableHandle table, + List columns, + DynamicFilter dynamicFilter) + { + return new EmptyPageSource() + { + @Override + public boolean isFinished() + { + // cache DF on worker should not block + assertThat(dynamicFilter.isBlocked()).isDone(); + if (expectedTableScanDynamicFilter.test(dynamicFilter.getCurrentPredicate())) { + finished = true; + } + + return true; + } + }; + } + + @Override + public TupleDomain getUnenforcedPredicate( + ConnectorSession session, + ConnectorSplit split, + ConnectorTableHandle table, + TupleDomain dynamicFilter) + { + return dynamicFilter; + } + + @Override + public TupleDomain prunePredicate( + ConnectorSession session, + ConnectorSplit split, + ConnectorTableHandle table, + TupleDomain predicate) + { + return predicate; + } + }; + } + + @Override + public ConnectorPageSinkProvider getPageSinkProvider() + { + return new TestingPageSinkProvider(); + } + } +} diff --git a/testing/trino-tests/src/test/java/io/trino/tests/tpch/TestTpchCacheSubqueriesTest.java b/testing/trino-tests/src/test/java/io/trino/tests/tpch/TestTpchCacheSubqueriesTest.java new file mode 100644 index 000000000000..1b27c2d172ea --- /dev/null +++ b/testing/trino-tests/src/test/java/io/trino/tests/tpch/TestTpchCacheSubqueriesTest.java @@ -0,0 +1,108 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.tests.tpch; + +import io.trino.testing.BaseCacheSubqueriesTest; +import io.trino.testing.QueryRunner; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.parallel.Execution; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import java.util.List; +import java.util.Map; + +import static io.trino.plugin.tpch.TpchConnectorFactory.TPCH_PARTITIONING_ENABLED; +import static io.trino.plugin.tpch.TpchConnectorFactory.TPCH_SPLITS_PER_NODE; +import static org.junit.jupiter.api.Assumptions.abort; +import static org.junit.jupiter.api.parallel.ExecutionMode.SAME_THREAD; + +@Execution(SAME_THREAD) +public class TestTpchCacheSubqueriesTest + extends BaseCacheSubqueriesTest +{ + @Override + protected QueryRunner createQueryRunner() + throws Exception + { + return TpchQueryRunner.builder() + .addExtraProperties(EXTRA_PROPERTIES) + // cache doesn't support table partitioning yet + // create enough splits for caching to be effective + .withConnectorProperties(Map.of(TPCH_PARTITIONING_ENABLED, "false", TPCH_SPLITS_PER_NODE, "100")) + .build(); + } + + @Test + @Override + public void testCacheWhenProjectionsWerePushedDown() + { + abort("tpch does not support for pushing down projections"); + } + + @Override + @ParameterizedTest + @MethodSource("isDynamicRowFilteringEnabled") + public void testDynamicFilterCache(boolean isDynamicRowFilteringEnabled) + { + abort("tpch does not support for partitioned tables"); + } + + @Override + @Test + public void 
testPredicateOnPartitioningColumnThatWasNotFullyPushed() + { + abort("tpch does not support for partitioned tables"); + } + + @Override + @Test + public void testPartitionedQueryCache() + { + abort("tpch does not support for partitioned tables"); + } + + @Override + @Test + public void testCommonSubqueryCacheSplitByIntersectionOfEnforcedConstraint() + { + abort("tpch does not support for partitioned tables"); + } + + @Override + @ParameterizedTest + @MethodSource("isDynamicRowFilteringEnabled") + public void testGetUnenforcedPredicateAndPrunePredicate(boolean isDynamicRowFilteringEnabled) + { + abort("tpch does not support for partitioned tables"); + } + + @Override + protected void createPartitionedTableAsSelect(String tableName, List partitionColumns, String asSelect) + { + throw new UnsupportedOperationException("tpch does not support for partitioned tables"); + } + + @Override + protected boolean supportsDataColumnPruning() + { + return false; + } + + @Override + protected boolean effectivePredicateReturnedPerSplit() + { + return false; + } +}