diff --git a/core/trino-main/src/main/java/io/trino/SystemSessionProperties.java b/core/trino-main/src/main/java/io/trino/SystemSessionProperties.java index eef590c6c3af..fd56c8d89ae3 100644 --- a/core/trino-main/src/main/java/io/trino/SystemSessionProperties.java +++ b/core/trino-main/src/main/java/io/trino/SystemSessionProperties.java @@ -118,6 +118,7 @@ public final class SystemSessionProperties public static final String STATISTICS_PRECALCULATION_FOR_PUSHDOWN_ENABLED = "statistics_precalculation_for_pushdown_enabled"; public static final String COLLECT_PLAN_STATISTICS_FOR_ALL_QUERIES = "collect_plan_statistics_for_all_queries"; public static final String IGNORE_STATS_CALCULATOR_FAILURES = "ignore_stats_calculator_failures"; + public static final String MAX_DRIVERS_PER_QUERY = "max_drivers_per_query"; public static final String MAX_DRIVERS_PER_TASK = "max_drivers_per_task"; public static final String DEFAULT_FILTER_FACTOR_ENABLED = "default_filter_factor_enabled"; public static final String SKIP_REDUNDANT_SORT = "skip_redundant_sort"; @@ -519,6 +520,11 @@ public SystemSessionProperties( "Collect plan statistics for non-EXPLAIN queries", featuresConfig.isCollectPlanStatisticsForAllQueries(), false), + integerProperty( + MAX_DRIVERS_PER_QUERY, + "Maximum number of drivers per query", + taskManagerConfig.getMaxDriversPerQuery(), + false), new PropertyMetadata<>( MAX_DRIVERS_PER_TASK, "Maximum number of drivers per task", @@ -1034,6 +1040,11 @@ private static void validateHideInaccesibleColumns(boolean value, boolean defaul } } + public static int getMaxDriversPerQuery(Session session) + { + return session.getSystemProperty(MAX_DRIVERS_PER_QUERY, Integer.class); + } + public static OptionalInt getMaxDriversPerTask(Session session) { Integer value = session.getSystemProperty(MAX_DRIVERS_PER_TASK, Integer.class); diff --git a/core/trino-main/src/main/java/io/trino/execution/SqlTaskExecution.java b/core/trino-main/src/main/java/io/trino/execution/SqlTaskExecution.java index c636b20aa1fb..c7c8d2b35e20 100644 --- a/core/trino-main/src/main/java/io/trino/execution/SqlTaskExecution.java +++ b/core/trino-main/src/main/java/io/trino/execution/SqlTaskExecution.java @@ -71,6 +71,7 @@ import static com.google.common.collect.ImmutableMap.toImmutableMap; import static com.google.common.util.concurrent.Futures.immediateVoidFuture; import static io.trino.SystemSessionProperties.getInitialSplitsPerNode; +import static io.trino.SystemSessionProperties.getMaxDriversPerQuery; import static io.trino.SystemSessionProperties.getMaxDriversPerTask; import static io.trino.SystemSessionProperties.getSplitConcurrencyAdjustmentInterval; import static io.trino.execution.SqlTaskExecution.SplitsState.ADDING_SPLITS; @@ -263,7 +264,8 @@ private static TaskHandle createTaskHandle( outputBuffer::getUtilization, getInitialSplitsPerNode(taskContext.getSession()), getSplitConcurrencyAdjustmentInterval(taskContext.getSession()), - getMaxDriversPerTask(taskContext.getSession())); + getMaxDriversPerTask(taskContext.getSession()), + getMaxDriversPerQuery(taskContext.getSession())); taskStateMachine.addStateChangeListener(state -> { if (state.isDone()) { taskExecutor.removeTask(taskHandle); diff --git a/core/trino-main/src/main/java/io/trino/execution/TaskManagerConfig.java b/core/trino-main/src/main/java/io/trino/execution/TaskManagerConfig.java index e19f460c8cba..e0fcc0f40158 100644 --- a/core/trino-main/src/main/java/io/trino/execution/TaskManagerConfig.java +++ b/core/trino-main/src/main/java/io/trino/execution/TaskManagerConfig.java @@ -55,6 +55,7 @@ public class TaskManagerConfig private boolean shareIndexLoading; private int maxWorkerThreads = Runtime.getRuntime().availableProcessors() * 2; private Integer minDrivers; + private Integer maxDriversPerQuery; private Integer initialSplitsPerNode; private int minDriversPerTask = 3; private int maxDriversPerTask = Integer.MAX_VALUE; @@ -287,6 +288,24 @@ public TaskManagerConfig setMinDrivers(int minDrivers) return this; } + @Min(1) + public int getMaxDriversPerQuery() + { + if (maxDriversPerQuery == null) { + // minDrivers has the higher priority over maxDriversPerQuery. + // That means maxDriversPerQuery is capped by minDrivers. + return getMinDrivers(); + } + return maxDriversPerQuery; + } + + @Config("task.max-drivers-per-query") + public TaskManagerConfig setMaxDriversPerQuery(int maxDrivers) + { + this.maxDriversPerQuery = maxDrivers; + return this; + } + @Min(1) public int getMaxDriversPerTask() { diff --git a/core/trino-main/src/main/java/io/trino/execution/executor/DriverLimitPerQuery.java b/core/trino-main/src/main/java/io/trino/execution/executor/DriverLimitPerQuery.java new file mode 100644 index 000000000000..918ea9486cf5 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/execution/executor/DriverLimitPerQuery.java @@ -0,0 +1,85 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.execution.executor; + +import java.util.concurrent.atomic.AtomicInteger; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkState; + +public class DriverLimitPerQuery +{ + private static final String ACCESS_TO_UNREFERENCED_INSTANCE_ERROR_MESSAGE = "Attempt to access unreferenced instance"; + + private final int maxDriversPerQuery; + private final AtomicInteger numberOfDrivers; + private final AtomicInteger referenceCount; + + public DriverLimitPerQuery(int maxDriversPerQuery) + { + this.maxDriversPerQuery = maxDriversPerQuery; + numberOfDrivers = new AtomicInteger(0); + referenceCount = new AtomicInteger(0); + } + + public void increase() + { + checkSanity(); + checkState(numberOfDrivers.getAndIncrement() >= 0, "numberOfDrivers is a negative number"); + } + + public void decrease() + { + checkSanity(); + checkState(numberOfDrivers.decrementAndGet() >= 0, "numberOfDrivers turned into a negative number"); + } + + public void subtract(int delta) + { + checkSanity(); + checkArgument(delta > 0, "delta is equal to or less than zero"); + checkState(numberOfDrivers.addAndGet(delta * -1) >= 0, "numberOfDrivers turned into a negative number"); + } + + public boolean isFull() + { + checkSanity(); + return numberOfDrivers.get() >= maxDriversPerQuery; + } + + public void addInitialReference() + { + checkState(referenceCount.getAndIncrement() == 0, "referenceCount is non-zero when initialize"); + } + + public void addReference() + { + checkState(referenceCount.getAndIncrement() > 0, ACCESS_TO_UNREFERENCED_INSTANCE_ERROR_MESSAGE); + } + + public boolean dereference() + { + int currentCount = referenceCount.decrementAndGet(); + if (currentCount < 0) { + throw new IllegalStateException(ACCESS_TO_UNREFERENCED_INSTANCE_ERROR_MESSAGE); + } + // is unreferenced? + return currentCount == 0; + } + + private void checkSanity() + { + checkState(referenceCount.get() > 0, ACCESS_TO_UNREFERENCED_INSTANCE_ERROR_MESSAGE); + } +} diff --git a/core/trino-main/src/main/java/io/trino/execution/executor/TaskExecutor.java b/core/trino-main/src/main/java/io/trino/execution/executor/TaskExecutor.java index c4a3b04119b4..ffc0a4b4d8f0 100644 --- a/core/trino-main/src/main/java/io/trino/execution/executor/TaskExecutor.java +++ b/core/trino-main/src/main/java/io/trino/execution/executor/TaskExecutor.java @@ -27,6 +27,7 @@ import io.trino.execution.SplitRunner; import io.trino.execution.TaskId; import io.trino.execution.TaskManagerConfig; +import io.trino.spi.QueryId; import io.trino.spi.TrinoException; import io.trino.spi.VersionEmbedder; import org.weakref.jmx.Managed; @@ -39,6 +40,7 @@ import javax.inject.Inject; import java.util.ArrayList; +import java.util.HashMap; import java.util.HashSet; import java.util.Iterator; import java.util.LinkedList; @@ -154,6 +156,9 @@ public class TaskExecutor private final TimeStat blockedQuantaWallTime = new TimeStat(MICROSECONDS); private final TimeStat unblockedQuantaWallTime = new TimeStat(MICROSECONDS); + @GuardedBy("this") + private final Map queryIdDriverLimitPerQueryMap = new HashMap<>(); + private volatile boolean closed; @Inject @@ -254,7 +259,8 @@ public synchronized TaskHandle addTask( DoubleSupplier utilizationSupplier, int initialSplitConcurrency, Duration splitConcurrencyAdjustFrequency, - OptionalInt maxDriversPerTask) + OptionalInt maxDriversPerTask, + int maxDriversPerQuery) { requireNonNull(taskId, "taskId is null"); requireNonNull(utilizationSupplier, "utilizationSupplier is null"); @@ -263,7 +269,24 @@ public synchronized TaskHandle addTask( log.debug("Task scheduled %s", taskId); - TaskHandle taskHandle = new TaskHandle(taskId, waitingSplits, utilizationSupplier, initialSplitConcurrency, splitConcurrencyAdjustFrequency, maxDriversPerTask); + QueryId queryId = taskId.getQueryId(); + DriverLimitPerQuery driverLimitPerQuery = queryIdDriverLimitPerQueryMap.get(queryId); + if (driverLimitPerQuery == null) { + driverLimitPerQuery = new DriverLimitPerQuery(maxDriversPerQuery); + driverLimitPerQuery.addInitialReference(); + queryIdDriverLimitPerQueryMap.put(queryId, driverLimitPerQuery); + } + else { + driverLimitPerQuery.addReference(); + } + TaskHandle taskHandle = new TaskHandle( + taskId, + waitingSplits, + utilizationSupplier, + initialSplitConcurrency, + splitConcurrencyAdjustFrequency, + maxDriversPerTask, + driverLimitPerQuery); tasks.add(taskHandle); return taskHandle; @@ -285,6 +308,9 @@ private void doRemoveTask(TaskHandle taskHandle) synchronized (this) { tasks.remove(taskHandle); splits = taskHandle.destroy(); + if (taskHandle.dereferenceFromDriverLimitPerQuery()) { + queryIdDriverLimitPerQueryMap.remove(taskHandle.getTaskId().getQueryId()); + } // stop tracking splits (especially blocked splits which may never unblock) allSplits.removeAll(splits); @@ -443,6 +469,10 @@ private synchronized PrioritizedSplitRunner pollNextSplitWorker() if (task.getRunningLeafSplits() >= task.getMaxDriversPerTask().orElse(maximumNumberOfDriversPerTask)) { continue; } + // skip tasks whose max number of drivers per query value is equal to or less than the current total running leaf splits + if (task.isDriverLimitPerQueryExceeded()) { + continue; + } PrioritizedSplitRunner split = task.pollNextSplit(); if (split != null) { // move task to end of list diff --git a/core/trino-main/src/main/java/io/trino/execution/executor/TaskHandle.java b/core/trino-main/src/main/java/io/trino/execution/executor/TaskHandle.java index b07e21067775..76d0fc538af0 100644 --- a/core/trino-main/src/main/java/io/trino/execution/executor/TaskHandle.java +++ b/core/trino-main/src/main/java/io/trino/execution/executor/TaskHandle.java @@ -58,6 +58,7 @@ public class TaskHandle protected final AtomicReference priority = new AtomicReference<>(new Priority(0, 0)); private final MultilevelSplitQueue splitQueue; private final OptionalInt maxDriversPerTask; + private final DriverLimitPerQuery driverLimitPerQuery; public TaskHandle( TaskId taskId, @@ -65,12 +66,14 @@ public TaskHandle( DoubleSupplier utilizationSupplier, int initialSplitConcurrency, Duration splitConcurrencyAdjustFrequency, - OptionalInt maxDriversPerTask) + OptionalInt maxDriversPerTask, + DriverLimitPerQuery driverLimitPerQuery) { this.taskId = requireNonNull(taskId, "taskId is null"); this.splitQueue = requireNonNull(splitQueue, "splitQueue is null"); this.utilizationSupplier = requireNonNull(utilizationSupplier, "utilizationSupplier is null"); this.maxDriversPerTask = requireNonNull(maxDriversPerTask, "maxDriversPerTask is null"); + this.driverLimitPerQuery = requireNonNull(driverLimitPerQuery, "driverLimitPerQuery is null"); this.concurrencyController = new SplitConcurrencyController( initialSplitConcurrency, requireNonNull(splitConcurrencyAdjustFrequency, "splitConcurrencyAdjustFrequency is null")); @@ -114,6 +117,16 @@ public TaskId getTaskId() return taskId; } + public boolean isDriverLimitPerQueryExceeded() + { + return driverLimitPerQuery.isFull(); + } + + public boolean dereferenceFromDriverLimitPerQuery() + { + return driverLimitPerQuery.dereference(); + } + public OptionalInt getMaxDriversPerTask() { return maxDriversPerTask; @@ -123,6 +136,10 @@ public OptionalInt getMaxDriversPerTask() public synchronized List destroy() { destroyed = true; + int runningLeafSplitsSize = runningLeafSplits.size(); + if (runningLeafSplitsSize > 0) { + driverLimitPerQuery.subtract(runningLeafSplitsSize); + } ImmutableList.Builder builder = ImmutableList.builder(); builder.addAll(runningIntermediateSplits); @@ -169,6 +186,7 @@ public synchronized PrioritizedSplitRunner pollNextSplit() PrioritizedSplitRunner split = queuedLeafSplits.poll(); if (split != null) { runningLeafSplits.add(split); + driverLimitPerQuery.increase(); } return split; } @@ -177,7 +195,9 @@ public synchronized void splitComplete(PrioritizedSplitRunner split) { concurrencyController.splitFinished(split.getScheduledNanos(), utilizationSupplier.getAsDouble(), runningLeafSplits.size()); runningIntermediateSplits.remove(split); - runningLeafSplits.remove(split); + if (runningLeafSplits.remove(split)) { + driverLimitPerQuery.decrease(); + } } public int getNextSplitId() diff --git a/core/trino-main/src/test/java/io/trino/execution/TestTaskManagerConfig.java b/core/trino-main/src/test/java/io/trino/execution/TestTaskManagerConfig.java index 2061d6d7600b..32d19f581692 100644 --- a/core/trino-main/src/test/java/io/trino/execution/TestTaskManagerConfig.java +++ b/core/trino-main/src/test/java/io/trino/execution/TestTaskManagerConfig.java @@ -46,6 +46,7 @@ public void testDefaults() .setTaskCpuTimerEnabled(true) .setMaxWorkerThreads(Runtime.getRuntime().availableProcessors() * 2) .setMinDrivers(Runtime.getRuntime().availableProcessors() * 2 * 2) + .setMaxDriversPerQuery(Runtime.getRuntime().availableProcessors() * 2 * 2) .setMinDriversPerTask(3) .setMaxDriversPerTask(Integer.MAX_VALUE) .setInfoMaxAge(new Duration(15, TimeUnit.MINUTES)) @@ -86,6 +87,7 @@ public void testExplicitPropertyMappings() .put("task.max-local-exchange-buffer-size", "33MB") .put("task.max-worker-threads", "3") .put("task.min-drivers", "2") + .put("task.max-drivers-per-query", "2") .put("task.min-drivers-per-task", "5") .put("task.max-drivers-per-task", "13") .put("task.info.max-age", "22m") @@ -117,6 +119,7 @@ public void testExplicitPropertyMappings() .setMaxLocalExchangeBufferSize(DataSize.of(33, Unit.MEGABYTE)) .setMaxWorkerThreads(3) .setMinDrivers(2) + .setMaxDriversPerQuery(2) .setMinDriversPerTask(5) .setMaxDriversPerTask(13) .setInfoMaxAge(new Duration(22, TimeUnit.MINUTES)) diff --git a/core/trino-main/src/test/java/io/trino/execution/executor/SimulationTask.java b/core/trino-main/src/test/java/io/trino/execution/executor/SimulationTask.java index 58a3a640283a..02574c957140 100644 --- a/core/trino-main/src/test/java/io/trino/execution/executor/SimulationTask.java +++ b/core/trino-main/src/test/java/io/trino/execution/executor/SimulationTask.java @@ -40,7 +40,7 @@ public SimulationTask(TaskExecutor taskExecutor, TaskSpecification specification { this.specification = specification; this.taskId = taskId; - taskHandle = taskExecutor.addTask(taskId, () -> 0, 10, new Duration(1, SECONDS), OptionalInt.empty()); + taskHandle = taskExecutor.addTask(taskId, () -> 0, 10, new Duration(1, SECONDS), OptionalInt.empty(), Integer.MAX_VALUE); } public void setKilled() diff --git a/core/trino-main/src/test/java/io/trino/execution/executor/TestTaskExecutor.java b/core/trino-main/src/test/java/io/trino/execution/executor/TestTaskExecutor.java index 47465890f76f..4c9287ec1894 100644 --- a/core/trino-main/src/test/java/io/trino/execution/executor/TestTaskExecutor.java +++ b/core/trino-main/src/test/java/io/trino/execution/executor/TestTaskExecutor.java @@ -58,7 +58,7 @@ public void testTasksComplete() try { TaskId taskId = new TaskId(new StageId("test", 0), 0, 0); - TaskHandle taskHandle = taskExecutor.addTask(taskId, () -> 0, 10, new Duration(1, MILLISECONDS), OptionalInt.empty()); + TaskHandle taskHandle = taskExecutor.addTask(taskId, () -> 0, 10, new Duration(1, MILLISECONDS), OptionalInt.empty(), Integer.MAX_VALUE); Phaser beginPhase = new Phaser(); beginPhase.register(); @@ -151,8 +151,8 @@ public void testQuantaFairness() ticker.increment(20, MILLISECONDS); try { - TaskHandle shortQuantaTaskHandle = taskExecutor.addTask(new TaskId(new StageId("short_quanta", 0), 0, 0), () -> 0, 10, new Duration(1, MILLISECONDS), OptionalInt.empty()); - TaskHandle longQuantaTaskHandle = taskExecutor.addTask(new TaskId(new StageId("long_quanta", 0), 0, 0), () -> 0, 10, new Duration(1, MILLISECONDS), OptionalInt.empty()); + TaskHandle shortQuantaTaskHandle = taskExecutor.addTask(new TaskId(new StageId("short_quanta", 0), 0, 0), () -> 0, 10, new Duration(1, MILLISECONDS), OptionalInt.empty(), Integer.MAX_VALUE); + TaskHandle longQuantaTaskHandle = taskExecutor.addTask(new TaskId(new StageId("long_quanta", 0), 0, 0), () -> 0, 10, new Duration(1, MILLISECONDS), OptionalInt.empty(), Integer.MAX_VALUE); Phaser endQuantaPhaser = new Phaser(); @@ -185,7 +185,7 @@ public void testLevelMovement() ticker.increment(20, MILLISECONDS); try { - TaskHandle testTaskHandle = taskExecutor.addTask(new TaskId(new StageId("test", 0), 0, 0), () -> 0, 10, new Duration(1, MILLISECONDS), OptionalInt.empty()); + TaskHandle testTaskHandle = taskExecutor.addTask(new TaskId(new StageId("test", 0), 0, 0), () -> 0, 10, new Duration(1, MILLISECONDS), OptionalInt.empty(), Integer.MAX_VALUE); Phaser globalPhaser = new Phaser(); globalPhaser.bulkRegister(3); // 2 taskExecutor threads + test thread @@ -226,9 +226,9 @@ public void testLevelMultipliers() try { for (int i = 0; i < (LEVEL_THRESHOLD_SECONDS.length - 1); i++) { TaskHandle[] taskHandles = { - taskExecutor.addTask(new TaskId(new StageId("test1", 0), 0, 0), () -> 0, 10, new Duration(1, MILLISECONDS), OptionalInt.empty()), - taskExecutor.addTask(new TaskId(new StageId("test2", 0), 0, 0), () -> 0, 10, new Duration(1, MILLISECONDS), OptionalInt.empty()), - taskExecutor.addTask(new TaskId(new StageId("test3", 0), 0, 0), () -> 0, 10, new Duration(1, MILLISECONDS), OptionalInt.empty()) + taskExecutor.addTask(new TaskId(new StageId("test1", 0), 0, 0), () -> 0, 10, new Duration(1, MILLISECONDS), OptionalInt.empty(), Integer.MAX_VALUE), + taskExecutor.addTask(new TaskId(new StageId("test2", 0), 0, 0), () -> 0, 10, new Duration(1, MILLISECONDS), OptionalInt.empty(), Integer.MAX_VALUE), + taskExecutor.addTask(new TaskId(new StageId("test3", 0), 0, 0), () -> 0, 10, new Duration(1, MILLISECONDS), OptionalInt.empty(), Integer.MAX_VALUE) }; // move task 0 to next level @@ -302,7 +302,7 @@ public void testTaskHandle() try { TaskId taskId = new TaskId(new StageId("test", 0), 0, 0); - TaskHandle taskHandle = taskExecutor.addTask(taskId, () -> 0, 10, new Duration(1, MILLISECONDS), OptionalInt.empty()); + TaskHandle taskHandle = taskExecutor.addTask(taskId, () -> 0, 10, new Duration(1, MILLISECONDS), OptionalInt.empty(), Integer.MAX_VALUE); Phaser beginPhase = new Phaser(); beginPhase.register(); @@ -333,8 +333,8 @@ public void testTaskHandle() public void testLevelContributionCap() { MultilevelSplitQueue splitQueue = new MultilevelSplitQueue(2); - TaskHandle handle0 = new TaskHandle(new TaskId(new StageId("test0", 0), 0, 0), splitQueue, () -> 1, 1, new Duration(1, SECONDS), OptionalInt.empty()); - TaskHandle handle1 = new TaskHandle(new TaskId(new StageId("test1", 0), 0, 0), splitQueue, () -> 1, 1, new Duration(1, SECONDS), OptionalInt.empty()); + TaskHandle handle0 = new TaskHandle(new TaskId(new StageId("test0", 0), 0, 0), splitQueue, () -> 1, 1, new Duration(1, SECONDS), OptionalInt.empty(), new DriverLimitPerQuery(Integer.MAX_VALUE)); + TaskHandle handle1 = new TaskHandle(new TaskId(new StageId("test1", 0), 0, 0), splitQueue, () -> 1, 1, new Duration(1, SECONDS), OptionalInt.empty(), new DriverLimitPerQuery(Integer.MAX_VALUE)); for (int i = 0; i < (LEVEL_THRESHOLD_SECONDS.length - 1); i++) { long levelAdvanceTime = SECONDS.toNanos(LEVEL_THRESHOLD_SECONDS[i + 1] - LEVEL_THRESHOLD_SECONDS[i]); @@ -353,7 +353,7 @@ public void testLevelContributionCap() public void testUpdateLevelWithCap() { MultilevelSplitQueue splitQueue = new MultilevelSplitQueue(2); - TaskHandle handle0 = new TaskHandle(new TaskId(new StageId("test0", 0), 0, 0), splitQueue, () -> 1, 1, new Duration(1, SECONDS), OptionalInt.empty()); + TaskHandle handle0 = new TaskHandle(new TaskId(new StageId("test0", 0), 0, 0), splitQueue, () -> 1, 1, new Duration(1, SECONDS), OptionalInt.empty(), new DriverLimitPerQuery(Integer.MAX_VALUE)); long quantaNanos = MINUTES.toNanos(10); handle0.addScheduledNanos(quantaNanos); @@ -375,7 +375,7 @@ public void testMinMaxDriversPerTask() TaskExecutor taskExecutor = new TaskExecutor(4, 16, 1, maxDriversPerTask, splitQueue, ticker); taskExecutor.start(); try { - TaskHandle testTaskHandle = taskExecutor.addTask(new TaskId(new StageId("test", 0), 0, 0), () -> 0, 10, new Duration(1, MILLISECONDS), OptionalInt.empty()); + TaskHandle testTaskHandle = taskExecutor.addTask(new TaskId(new StageId("test", 0), 0, 0), () -> 0, 10, new Duration(1, MILLISECONDS), OptionalInt.empty(), Integer.MAX_VALUE); // enqueue all batches of splits int batchCount = 4; @@ -416,7 +416,7 @@ public void testUserSpecifiedMaxDriversPerTask() taskExecutor.start(); try { // overwrite the max drivers per task to be 1 - TaskHandle testTaskHandle = taskExecutor.addTask(new TaskId(new StageId("test", 0), 0, 0), () -> 0, 10, new Duration(1, MILLISECONDS), OptionalInt.of(1)); + TaskHandle testTaskHandle = taskExecutor.addTask(new TaskId(new StageId("test", 0), 0, 0), () -> 0, 10, new Duration(1, MILLISECONDS), OptionalInt.of(1), Integer.MAX_VALUE); // enqueue all batches of splits int batchCount = 4; @@ -461,7 +461,8 @@ public void testMinDriversPerTaskWhenTargetConcurrencyIncreases() () -> 0, 1, new Duration(1, MILLISECONDS), - OptionalInt.of(2)); + OptionalInt.of(2), + Integer.MAX_VALUE); // create 3 splits int batchCount = 3; @@ -491,6 +492,55 @@ public void testMinDriversPerTaskWhenTargetConcurrencyIncreases() } } + @Test + public void testMaxDriversPerQuery() + { + TestingTicker ticker = new TestingTicker(); + TaskExecutor taskExecutor = new TaskExecutor(4, 8, 1, 10, ticker); + taskExecutor.start(); + + try { + TaskId taskId1 = new TaskId(new StageId("test1", 0), 0, 0); + TaskId taskId2 = new TaskId(new StageId("test2", 0), 0, 0); + TaskHandle taskHandle1 = taskExecutor.addTask(taskId1, () -> 0, 10, new Duration(1, MILLISECONDS), OptionalInt.empty(), 2); + TaskHandle taskHandle2 = taskExecutor.addTask(taskId2, () -> 0, 10, new Duration(1, MILLISECONDS), OptionalInt.empty(), 1); + TaskHandle taskHandle3 = taskExecutor.addTask(taskId1, () -> 0, 10, new Duration(1, MILLISECONDS), OptionalInt.empty(), 5); + TaskHandle taskHandle4 = taskExecutor.addTask(taskId2, () -> 0, 10, new Duration(1, MILLISECONDS), OptionalInt.empty(), 5); + + Phaser beginPhase = new Phaser(); + beginPhase.register(); + Phaser verificationComplete = new Phaser(); + verificationComplete.register(); + + TestingJob driver1 = new TestingJob(ticker, new Phaser(), beginPhase, verificationComplete, 10, 0); + TestingJob driver2 = new TestingJob(ticker, new Phaser(), beginPhase, verificationComplete, 10, 0); + TestingJob driver3 = new TestingJob(ticker, new Phaser(), beginPhase, verificationComplete, 10, 0); + TestingJob driver4 = new TestingJob(ticker, new Phaser(), beginPhase, verificationComplete, 10, 0); + TestingJob driver5 = new TestingJob(ticker, new Phaser(), beginPhase, verificationComplete, 10, 0); + TestingJob driver6 = new TestingJob(ticker, new Phaser(), beginPhase, verificationComplete, 10, 0); + + taskExecutor.enqueueSplits(taskHandle1, false, ImmutableList.of(driver1)); + assertEquals(taskHandle1.getRunningLeafSplits(), 1); + taskExecutor.enqueueSplits(taskHandle2, false, ImmutableList.of(driver2)); + assertEquals(taskHandle2.getRunningLeafSplits(), 1); + taskExecutor.enqueueSplits(taskHandle3, false, ImmutableList.of(driver3)); + assertEquals(taskHandle3.getRunningLeafSplits(), 1); + taskExecutor.enqueueSplits(taskHandle3, false, ImmutableList.of(driver5)); + assertEquals(taskHandle3.getRunningLeafSplits(), 1); // Query "test1" exceeds maxDriversPerQuery + taskExecutor.enqueueSplits(taskHandle4, false, ImmutableList.of(driver4)); + assertEquals(taskHandle4.getRunningLeafSplits(), 1); // guaranteedNumberOfDriversPerTask has a higher priority than maxDriversPerQuery + taskExecutor.enqueueSplits(taskHandle4, false, ImmutableList.of(driver6)); + assertEquals(taskHandle4.getRunningLeafSplits(), 1); // Query "test2" exceeds maxDriversPerQuery + + // let the split continue to run + beginPhase.arriveAndDeregister(); + verificationComplete.arriveAndDeregister(); + } + finally { + taskExecutor.stop(); + } + } + private void assertSplitStates(int endIndex, TestingJob[] splits) { // assert that splits up to and including endIndex are all started