@@ -72,7 +72,18 @@ public Result optimizeAndMarkPlanChanges(PlanNode plan, Context context)
         int runtimeAdaptivePartitioningPartitionCount = getFaultTolerantExecutionRuntimeAdaptivePartitioningPartitionCount(context.session());
         long runtimeAdaptivePartitioningMaxTaskSizeInBytes = getFaultTolerantExecutionRuntimeAdaptivePartitioningMaxTaskSize(context.session()).toBytes();
         RuntimeInfoProvider runtimeInfoProvider = context.runtimeInfoProvider();
-        for (PlanFragment fragment : runtimeInfoProvider.getAllPlanFragments()) {
+        List<PlanFragment> fragments = runtimeInfoProvider.getAllPlanFragments();
+
+        // Skip if there are already some fragments with the maximum partition count. This is to avoid re-planning
+        // since currently we apply this rule on the entire plan. Once we have a granular way of applying this rule,
+        // we can remove this check.
+        if (fragments.stream()
+                .anyMatch(fragment ->
+                        fragment.getPartitionCount().orElse(maxPartitionCount) >= runtimeAdaptivePartitioningPartitionCount)) {
+            return new Result(plan, ImmutableSet.of());
+        }
+
+        for (PlanFragment fragment : fragments) {
             // Skip if the stage is not consuming hash partitioned input or if the runtime stats are accurate, which
             // basically means that the stage can't be re-planned in the current implementation of AdaptivePlanner.
             // TODO: We need to add an ability to re-plan fragments whose stats are estimated by progress.
@@ -81,11 +92,6 @@ public Result optimizeAndMarkPlanChanges(PlanNode plan, Context context)
             }

             int partitionCount = fragment.getPartitionCount().orElse(maxPartitionCount);
-            // Skip if partition count is already at the maximum
-            if (partitionCount >= runtimeAdaptivePartitioningPartitionCount) {
-                continue;
-            }
-
             // calculate (estimated) input data size to determine if we want to change number of partitions at runtime
             List<Long> partitionedInputBytes = fragment.getRemoteSourceNodes().stream()
                     // skip for replicate exchange since it's assumed that broadcast join will be chosen by
@@ -159,8 +165,9 @@ public PlanNode visitExchange(ExchangeNode node, RewriteContext<Void> context)
                     .collect(toImmutableList());
            PartitioningScheme partitioningScheme = node.getPartitioningScheme();

-            // for FTE it only makes sense to set partition count fot hash partitioned fragments
-            if (node.getPartitioningScheme().getPartitioning().getHandle() == FIXED_HASH_DISTRIBUTION) {
+            // for FTE it only makes sense to set partition count for hash partitioned fragments
+            if (node.getScope() == REMOTE
+                    && node.getPartitioningScheme().getPartitioning().getHandle() == FIXED_HASH_DISTRIBUTION) {
                 partitioningScheme = partitioningScheme.withPartitionCount(Optional.of(partitionCount));
                 changedPlanIds.add(node.getId());
             }
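
Taken together, the two changes above make the rule safe to re-run: the new early return stops the optimizer from rescaling a plan whose fragments already reached the runtime-adaptive target, and the added REMOTE-scope check keeps partition counts off local exchanges, where they have no meaning. Below is a minimal runnable sketch of the guard logic in isolation; the Fragment record and the numeric defaults are stand-ins invented for illustration, not Trino's actual classes.

    import java.util.List;
    import java.util.Optional;

    public class RuntimePartitioningGuard
    {
        // Hypothetical stand-in for PlanFragment: a fragment may carry an explicit partition count.
        record Fragment(String id, Optional<Integer> partitionCount) {}

        // Mirrors the early-return guard: bail out as soon as any fragment already runs
        // at (or above) the runtime-adaptive target, so a second pass is a no-op.
        static boolean alreadyRescaled(List<Fragment> fragments, int maxPartitionCount, int targetPartitionCount)
        {
            return fragments.stream()
                    .anyMatch(fragment -> fragment.partitionCount().orElse(maxPartitionCount) >= targetPartitionCount);
        }

        public static void main(String[] args)
        {
            // Fragments without an explicit count fall back to the max partition count (50 here).
            List<Fragment> firstPass = List.of(
                    new Fragment("1", Optional.empty()),
                    new Fragment("2", Optional.of(10)));
            List<Fragment> secondPass = List.of(
                    new Fragment("1", Optional.of(100)),
                    new Fragment("2", Optional.of(100)));

            System.out.println(alreadyRescaled(firstPass, 50, 100));  // false -> the rule may rescale
            System.out.println(alreadyRescaled(secondPass, 50, 100)); // true  -> the rule returns the plan unchanged
        }
    }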
@@ -89,7 +89,8 @@ public void testJoinOrderSwitchRule()
                 ImmutableMap.of(
                         new PlanFragmentId("1"), createRuntimeStats(ImmutableLongArray.of(10000L, 10000L, 10000L), 10000),
                         new PlanFragmentId("2"), createRuntimeStats(ImmutableLongArray.of(200L, 2000L, 1000L), 500)),
-                matcher);
+                matcher,
+                false);
     }

     @Test
@@ -159,8 +160,8 @@ SELECT max(s.nationkey), sum(t.regionkey)
                 ImmutableMap.of(
                         new PlanFragmentId("3"), createRuntimeStats(ImmutableLongArray.of(10000L, 10000L, 10000L), 10000),
                         new PlanFragmentId("2"), createRuntimeStats(ImmutableLongArray.of(200L, 2000L, 1000L), 500)),
-                matcher
-        );
+                matcher,
+                false);
     }

     @Test
@@ -199,7 +200,8 @@ public void testNoChangeToRootSubPlanIfStatsAreAccurate()
                         new PlanFragmentId("2"), createRuntimeStats(ImmutableLongArray.of(200L, 2000L, 1000L), 500),
                         // Since the runtime stats are accurate, adaptivePlanner will not change this subplan
                         new PlanFragmentId("0"), createRuntimeStats(ImmutableLongArray.of(10000L, 10000L, 10000L), 10000)),
-                matcher);
+                matcher,
+                false);
     }

     @Test
@@ -267,8 +269,8 @@ SELECT max(s.nationkey), sum(t.regionkey)
                         new PlanFragmentId("3"), createRuntimeStats(ImmutableLongArray.of(10000L, 10000L, 10000L), 10000),
                         new PlanFragmentId("4"), createRuntimeStats(ImmutableLongArray.of(10000L, 10000L, 10000L), 10000),
                         new PlanFragmentId("2"), createRuntimeStats(ImmutableLongArray.of(200L, 2000L, 1000L), 500)),
-                matcher
-        );
+                matcher,
+                false);
     }

     private OutputStatsEstimateResult createRuntimeStats(ImmutableLongArray partitionDataSizes, long outputRowCountEstimate)
@@ -281,12 +281,12 @@ protected SubPlan subplan(@Language("SQL") String sql, LogicalPlanner.Stage stag
         }
     }

-    protected void assertAdaptivePlan(@Language("SQL") String sql, Session session, Map<PlanFragmentId, OutputStatsEstimateResult> completeStageStats, SubPlanMatcher subPlanMatcher)
+    protected void assertAdaptivePlan(@Language("SQL") String sql, Session session, Map<PlanFragmentId, OutputStatsEstimateResult> completeStageStats, SubPlanMatcher subPlanMatcher, boolean checkIdempotence)
     {
-        assertAdaptivePlan(sql, session, planTester.getAdaptivePlanOptimizers(), completeStageStats, subPlanMatcher);
+        assertAdaptivePlan(sql, session, planTester.getAdaptivePlanOptimizers(), completeStageStats, subPlanMatcher, checkIdempotence);
     }

-    protected void assertAdaptivePlan(@Language("SQL") String sql, Session session, List<AdaptivePlanOptimizer> optimizers, Map<PlanFragmentId, OutputStatsEstimateResult> completeStageStats, SubPlanMatcher subPlanMatcher)
+    protected void assertAdaptivePlan(@Language("SQL") String sql, Session session, List<AdaptivePlanOptimizer> optimizers, Map<PlanFragmentId, OutputStatsEstimateResult> completeStageStats, SubPlanMatcher subPlanMatcher, boolean checkIdempotence)
     {
         try {
             planTester.inTransaction(session, transactionSession -> {
@@ -300,6 +300,16 @@ protected void assertAdaptivePlan(@Language("SQL") String sql, Session session,
                             subPlanMatcher,
                             formattedPlan));
                 }
+                if (checkIdempotence) {
+                    SubPlan idempotentPlan = planTester.createAdaptivePlan(transactionSession, adaptivePlan, optimizers, WarningCollector.NOOP, createPlanOptimizersStatsCollector(), createRuntimeInfoProvider(adaptivePlan, completeStageStats));
+                    String formattedIdempotentPlan = textDistributedPlan(idempotentPlan, planTester.getPlannerContext().getMetadata(), planTester.getPlannerContext().getFunctionManager(), transactionSession, false, UNKNOWN);
+                    if (!subPlanMatcher.matches(idempotentPlan, planTester.getStatsCalculator(), transactionSession, planTester.getPlannerContext().getMetadata())) {
+                        throw new AssertionError(format(
+                                "Adaptive plan is not idempotent, expected [\n\n%s\n] but found [\n\n%s\n]",
+                                subPlanMatcher,
+                                formattedIdempotentPlan));
+                    }
+                }
                 return null;
             });
         }
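
The new checkIdempotence branch simply feeds the adaptive plan back through createAdaptivePlan and asserts that the same matcher still holds, i.e. that the optimizers have reached a fixed point. The same assertion pattern in isolation, as a minimal sketch with a toy optimizer; assertIdempotent and the lambdas here are hypothetical illustrations, not part of the PlanTester API.

    import java.util.function.Predicate;
    import java.util.function.UnaryOperator;

    public final class IdempotenceCheck
    {
        // Applies an optimizer twice and verifies the matcher holds after each pass;
        // if the second pass breaks the match, the optimizer is not idempotent.
        static <P> void assertIdempotent(UnaryOperator<P> optimizer, Predicate<P> matcher, P plan)
        {
            P once = optimizer.apply(plan);
            if (!matcher.test(once)) {
                throw new AssertionError("optimized plan does not match");
            }
            P twice = optimizer.apply(once);
            if (!matcher.test(twice)) {
                throw new AssertionError("optimizer is not idempotent");
            }
        }

        public static void main(String[] args)
        {
            // Clamping a partition count to 100 is idempotent; unconditionally adding 10 is not.
            assertIdempotent(x -> Math.min(x, 100), x -> x <= 100, 250);
            try {
                assertIdempotent(x -> x + 10, x -> x <= 260, 250); // second pass yields 270
            }
            catch (AssertionError expected) {
                System.out.println("caught: " + expected.getMessage());
            }
        }
    }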
@@ -13,27 +13,43 @@
  */
 package io.trino.sql.planner.optimizations;

+import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableMap;
 import com.google.common.primitives.ImmutableLongArray;
 import io.airlift.units.DataSize;
 import io.trino.Session;
 import io.trino.execution.scheduler.faulttolerant.OutputStatsEstimator;
 import io.trino.sql.planner.OptimizerConfig;
+import io.trino.sql.planner.Symbol;
 import io.trino.sql.planner.assertions.BasePlanTest;
 import io.trino.sql.planner.assertions.SubPlanMatcher;
+import io.trino.sql.planner.plan.AdaptivePlanNode;
+import io.trino.sql.planner.plan.JoinNode;
 import io.trino.sql.planner.plan.PlanFragmentId;
 import org.junit.jupiter.api.Test;

+import java.util.Optional;
+
 import static io.airlift.units.DataSize.Unit.MEGABYTE;
 import static io.trino.SystemSessionProperties.FAULT_TOLERANT_EXECUTION_MAX_PARTITION_COUNT;
 import static io.trino.SystemSessionProperties.FAULT_TOLERANT_EXECUTION_MIN_PARTITION_COUNT;
 import static io.trino.SystemSessionProperties.FAULT_TOLERANT_EXECUTION_RUNTIME_ADAPTIVE_PARTITIONING_ENABLED;
 import static io.trino.SystemSessionProperties.FAULT_TOLERANT_EXECUTION_RUNTIME_ADAPTIVE_PARTITIONING_MAX_TASK_SIZE;
 import static io.trino.SystemSessionProperties.FAULT_TOLERANT_EXECUTION_RUNTIME_ADAPTIVE_PARTITIONING_PARTITION_COUNT;
 import static io.trino.SystemSessionProperties.JOIN_DISTRIBUTION_TYPE;
+import static io.trino.SystemSessionProperties.JOIN_PARTITIONED_BUILD_MIN_ROW_COUNT;
 import static io.trino.SystemSessionProperties.JOIN_REORDERING_STRATEGY;
 import static io.trino.SystemSessionProperties.RETRY_POLICY;
+import static io.trino.SystemSessionProperties.TASK_CONCURRENCY;
 import static io.trino.operator.RetryPolicy.TASK;
+import static io.trino.sql.planner.assertions.PlanMatchPattern.exchange;
+import static io.trino.sql.planner.assertions.PlanMatchPattern.join;
+import static io.trino.sql.planner.assertions.PlanMatchPattern.node;
+import static io.trino.sql.planner.assertions.PlanMatchPattern.output;
+import static io.trino.sql.planner.assertions.PlanMatchPattern.remoteSource;
+import static io.trino.sql.planner.plan.ExchangeNode.Scope.LOCAL;
+import static io.trino.sql.planner.plan.JoinType.LEFT;
+import static io.trino.type.UnknownType.UNKNOWN;

 public class TestAdaptivePartitioning
         extends BasePlanTest
@@ -78,7 +94,49 @@ public void testCreateTableAs()
                         new PlanFragmentId("3"), createRuntimeStats(ImmutableLongArray.of(ONE_MB, ONE_MB * 2, ONE_MB), 10000),
                         new PlanFragmentId("4"), createRuntimeStats(ImmutableLongArray.of(ONE_MB, ONE_MB, ONE_MB), 500),
                         new PlanFragmentId("1"), createRuntimeStats(ImmutableLongArray.of(ONE_MB, ONE_MB, ONE_MB), 500)),
-                matcher);
+                matcher,
+                true);
     }

+    @Test
+    public void testNoPartitionCountInLocalExchange()
+    {
+        SubPlanMatcher matcher = SubPlanMatcher.builder()
+                .fragmentMatcher(fm -> fm
+                        .fragmentId(3)
+                        .inputPartitionCount(10)
+                        .planPattern(
+                                output(
+                                        join(LEFT, builder -> builder
+                                                .equiCriteria(ImmutableList.of(aliases ->
+                                                        new JoinNode.EquiJoinClause(
+                                                                new Symbol(UNKNOWN, "suppkey"),
+                                                                new Symbol(UNKNOWN, "nationkey"))))
+                                                .left(node(AdaptivePlanNode.class,
+                                                        remoteSource(ImmutableList.of(new PlanFragmentId("4")))))
+                                                // validate no partitionCount in local exchange
+                                                .right(exchange(LOCAL, Optional.empty(),
+                                                        node(AdaptivePlanNode.class,
+                                                                remoteSource(ImmutableList.of(new PlanFragmentId("5"))))))))))
+                .children(
+                        sb1 -> sb1.fragmentMatcher(fm -> fm.fragmentId(4).outputPartitionCount(10).inputPartitionCount(1))
+                                .children(sb2 -> sb2.fragmentMatcher(fm -> fm.fragmentId(1).outputPartitionCount(1))),
+                        sb1 -> sb1.fragmentMatcher(fm -> fm.fragmentId(5).outputPartitionCount(10).inputPartitionCount(1))
+                                .children(sb2 -> sb2.fragmentMatcher(fm -> fm.fragmentId(2).outputPartitionCount(1))))
+                .build();
+
+        assertAdaptivePlan(
+                """
+                SELECT l.* FROM lineitem l
+                LEFT JOIN nation n
+                ON l.suppkey = n.nationkey
+                """,
+                getSession(),
+                ImmutableMap.of(
+                        new PlanFragmentId("1"), createRuntimeStats(ImmutableLongArray.of(ONE_MB, ONE_MB, ONE_MB), 500),
+                        new PlanFragmentId("2"), createRuntimeStats(ImmutableLongArray.of(ONE_MB, ONE_MB * 2, ONE_MB), 10000)),
+                matcher,
+                true);
+    }
+
     @Test
@@ -156,7 +214,8 @@ public void testSkipBroadcastSubtree()
                         new PlanFragmentId("10"), createRuntimeStats(ImmutableLongArray.of(ONE_MB, ONE_MB, ONE_MB), 500),
                         new PlanFragmentId("11"), createRuntimeStats(ImmutableLongArray.of(ONE_MB, ONE_MB, ONE_MB), 500),
                         new PlanFragmentId("12"), createRuntimeStats(ImmutableLongArray.of(ONE_MB, ONE_MB, ONE_MB), 500)),
-                matcher);
+                matcher,
+                true);
     }

@@ -165,6 +224,8 @@ private Session getSession()
     {
         return Session.builder(getPlanTester().getDefaultSession())
                 .setSystemProperty(RETRY_POLICY, TASK.name())
+                .setSystemProperty(TASK_CONCURRENCY, "4")
+                .setSystemProperty(JOIN_PARTITIONED_BUILD_MIN_ROW_COUNT, "0")
                 .setSystemProperty(JOIN_REORDERING_STRATEGY, OptimizerConfig.JoinReorderingStrategy.NONE.name())
                 .setSystemProperty(JOIN_DISTRIBUTION_TYPE, OptimizerConfig.JoinDistributionType.PARTITIONED.name())
                 .setSystemProperty(FAULT_TOLERANT_EXECUTION_RUNTIME_ADAPTIVE_PARTITIONING_ENABLED, "true")