diff --git a/core/trino-main/src/main/java/io/trino/execution/TaskManagerConfig.java b/core/trino-main/src/main/java/io/trino/execution/TaskManagerConfig.java index e19f460c8cba..888eb7ebfa14 100644 --- a/core/trino-main/src/main/java/io/trino/execution/TaskManagerConfig.java +++ b/core/trino-main/src/main/java/io/trino/execution/TaskManagerConfig.java @@ -81,6 +81,11 @@ public class TaskManagerConfig private BigDecimal levelTimeMultiplier = new BigDecimal(2.0); + private Duration longRunningSplitWarningThreshold = new Duration(10, TimeUnit.MINUTES); + private boolean enableInterruptStuckSplits = true; + private Duration interruptStuckSplitsTimeout = new Duration(10, TimeUnit.MINUTES); + private Duration stuckSplitDetectionInterval = new Duration(1, TimeUnit.MINUTES); + @MinDuration("1ms") @MaxDuration("10s") @NotNull @@ -463,4 +468,58 @@ public TaskManagerConfig setTaskYieldThreads(int taskYieldThreads) this.taskYieldThreads = taskYieldThreads; return this; } + + @MinDuration("1ms") + public Duration getLongRunningSplitWarningThreshold() + { + return longRunningSplitWarningThreshold; + } + + @Config("task.long-running-split-warning-threshold") + @ConfigDescription("Print out split call stack when it runs longer than this threshold") + public TaskManagerConfig setLongRunningSplitWarningThreshold(Duration longRunningSplitWarningThreshold) + { + this.longRunningSplitWarningThreshold = longRunningSplitWarningThreshold; + return this; + } + + public boolean isEnableInterruptStuckSplits() + { + return enableInterruptStuckSplits; + } + + @Config("task.enable-interrupt-stuck-splits") + public TaskManagerConfig setEnableInterruptStuckSplits(boolean enableInterruptStuckSplits) + { + this.enableInterruptStuckSplits = enableInterruptStuckSplits; + return this; + } + + @MinDuration("1s") + public Duration getInterruptStuckSplitsTimeout() + { + return interruptStuckSplitsTimeout; + } + + @Config("task.interrupt-stuck-splits-timeout") + @ConfigDescription("Interrupt task 
processing thread after this timeout if the thread is stuck in certain external libraries used by Trino functions") + public TaskManagerConfig setInterruptStuckSplitsTimeout(Duration interruptStuckSplitsTimeout) + { + this.interruptStuckSplitsTimeout = interruptStuckSplitsTimeout; + return this; + } + + @MinDuration("1ms") + public Duration getStuckSplitDetectionInterval() + { + return stuckSplitDetectionInterval; + } + + @Config("task.stuck-split-detection-interval") + @ConfigDescription("Interval between detecting stuck split") + public TaskManagerConfig setStuckSplitDetectionInterval(Duration stuckSplitDetectionInterval) + { + this.stuckSplitDetectionInterval = stuckSplitDetectionInterval; + return this; + } } diff --git a/core/trino-main/src/main/java/io/trino/execution/executor/TaskExecutor.java b/core/trino-main/src/main/java/io/trino/execution/executor/TaskExecutor.java index c4a3b04119b4..33ba0b6f4e8f 100644 --- a/core/trino-main/src/main/java/io/trino/execution/executor/TaskExecutor.java +++ b/core/trino-main/src/main/java/io/trino/execution/executor/TaskExecutor.java @@ -27,6 +27,8 @@ import io.trino.execution.SplitRunner; import io.trino.execution.TaskId; import io.trino.execution.TaskManagerConfig; +import io.trino.operator.scalar.JoniRegexpFunctions; +import io.trino.operator.scalar.JoniRegexpReplaceLambdaFunction; import io.trino.spi.TrinoException; import io.trino.spi.VersionEmbedder; import org.weakref.jmx.Managed; @@ -44,6 +46,7 @@ import java.util.LinkedList; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.OptionalInt; import java.util.Set; import java.util.SortedSet; @@ -54,10 +57,11 @@ import java.util.concurrent.RejectedExecutionException; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ThreadPoolExecutor; -import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicLongArray; +import 
java.util.concurrent.atomic.AtomicReference; import java.util.function.DoubleSupplier; +import java.util.function.Predicate; import static com.google.common.base.MoreObjects.toStringHelper; import static com.google.common.base.Preconditions.checkArgument; @@ -66,23 +70,27 @@ import static io.airlift.concurrent.Threads.daemonThreadsNamed; import static io.airlift.concurrent.Threads.threadsNamed; import static io.trino.execution.executor.MultilevelSplitQueue.computeLevel; +import static io.trino.execution.executor.PrioritizedSplitRunner.SPLIT_RUN_QUANTA; import static io.trino.version.EmbedVersion.testingVersionEmbedder; import static java.lang.Math.min; import static java.lang.String.format; +import static java.lang.System.lineSeparator; +import static java.util.Arrays.asList; import static java.util.Objects.requireNonNull; import static java.util.concurrent.Executors.newCachedThreadPool; import static java.util.concurrent.Executors.newSingleThreadScheduledExecutor; import static java.util.concurrent.TimeUnit.MICROSECONDS; import static java.util.concurrent.TimeUnit.NANOSECONDS; +import static java.util.concurrent.TimeUnit.SECONDS; +import static java.util.stream.Collectors.joining; @ThreadSafe public class TaskExecutor { private static final Logger log = Logger.get(TaskExecutor.class); - // print out split call stack if it has been running for a certain amount of time - private static final Duration LONG_SPLIT_WARNING_THRESHOLD = new Duration(600, TimeUnit.SECONDS); - + private static final Predicate> POTENTIALLY_STUCK_SPLIT_PREDICATE = + elements -> elements.stream().anyMatch(TaskExecutor::usedJoniRegexpFunctions); private static final AtomicLong NEXT_RUNNER_ID = new AtomicLong(); private final ExecutorService executor; @@ -92,10 +100,13 @@ public class TaskExecutor private final int minimumNumberOfDrivers; private final int guaranteedNumberOfDriversPerTask; private final int maximumNumberOfDriversPerTask; + private final Duration longRunningSplitWarningThreshold; 
private final VersionEmbedder versionEmbedder; private final Ticker ticker; + private final Optional stuckSplitInterrupter; + private final ScheduledExecutorService splitMonitorExecutor = newSingleThreadScheduledExecutor(daemonThreadsNamed("TaskExecutor")); private final SortedSet runningSplitInfos = new ConcurrentSkipListSet<>(); @@ -163,21 +174,59 @@ public TaskExecutor(TaskManagerConfig config, VersionEmbedder versionEmbedder, M config.getMinDrivers(), config.getMinDriversPerTask(), config.getMaxDriversPerTask(), + config.getLongRunningSplitWarningThreshold(), + createStuckSplitInterrupter( + config.getLongRunningSplitWarningThreshold(), + config.isEnableInterruptStuckSplits(), + config.getInterruptStuckSplitsTimeout(), + config.getStuckSplitDetectionInterval(), + POTENTIALLY_STUCK_SPLIT_PREDICATE), versionEmbedder, splitQueue, Ticker.systemTicker()); } @VisibleForTesting - public TaskExecutor(int runnerThreads, int minDrivers, int guaranteedNumberOfDriversPerTask, int maximumNumberOfDriversPerTask, Ticker ticker) + public TaskExecutor( + int runnerThreads, + int minDrivers, + int guaranteedNumberOfDriversPerTask, + int maximumNumberOfDriversPerTask, + Duration longRunningSplitWarningThreshold, + Optional stuckSplitInterrupter, + Ticker ticker) { - this(runnerThreads, minDrivers, guaranteedNumberOfDriversPerTask, maximumNumberOfDriversPerTask, testingVersionEmbedder(), new MultilevelSplitQueue(2), ticker); + this(runnerThreads, + minDrivers, + guaranteedNumberOfDriversPerTask, + maximumNumberOfDriversPerTask, + longRunningSplitWarningThreshold, + stuckSplitInterrupter, + testingVersionEmbedder(), + new MultilevelSplitQueue(2), + ticker); } @VisibleForTesting - public TaskExecutor(int runnerThreads, int minDrivers, int guaranteedNumberOfDriversPerTask, int maximumNumberOfDriversPerTask, MultilevelSplitQueue splitQueue, Ticker ticker) + public TaskExecutor( + int runnerThreads, + int minDrivers, + int guaranteedNumberOfDriversPerTask, + int 
maximumNumberOfDriversPerTask, + Duration longRunningSplitWarningThreshold, + Optional stuckSplitInterrupter, + MultilevelSplitQueue splitQueue, + Ticker ticker) { - this(runnerThreads, minDrivers, guaranteedNumberOfDriversPerTask, maximumNumberOfDriversPerTask, testingVersionEmbedder(), splitQueue, ticker); + this(runnerThreads, + minDrivers, + guaranteedNumberOfDriversPerTask, + maximumNumberOfDriversPerTask, + longRunningSplitWarningThreshold, + stuckSplitInterrupter, + testingVersionEmbedder(), + splitQueue, + ticker); } @VisibleForTesting @@ -186,6 +235,8 @@ public TaskExecutor( int minDrivers, int guaranteedNumberOfDriversPerTask, int maximumNumberOfDriversPerTask, + Duration longRunningSplitWarningThreshold, + Optional stuckSplitInterrupter, VersionEmbedder versionEmbedder, MultilevelSplitQueue splitQueue, Ticker ticker) @@ -199,13 +250,17 @@ public TaskExecutor( this.executor = newCachedThreadPool(threadsNamed("task-processor-%s")); this.executorMBean = new ThreadPoolExecutorMBean((ThreadPoolExecutor) executor); this.runnerThreads = runnerThreads; + this.versionEmbedder = requireNonNull(versionEmbedder, "versionEmbedder is null"); + this.stuckSplitInterrupter = requireNonNull(stuckSplitInterrupter, "stuckSplitInterrupter is null"); + this.ticker = requireNonNull(ticker, "ticker is null"); this.minimumNumberOfDrivers = minDrivers; this.guaranteedNumberOfDriversPerTask = guaranteedNumberOfDriversPerTask; this.maximumNumberOfDriversPerTask = maximumNumberOfDriversPerTask; + this.longRunningSplitWarningThreshold = longRunningSplitWarningThreshold; this.waitingSplits = requireNonNull(splitQueue, "splitQueue is null"); this.tasks = new LinkedList<>(); } @@ -217,6 +272,10 @@ public synchronized void start() for (int i = 0; i < runnerThreads; i++) { addRunnerThread(); } + stuckSplitInterrupter.ifPresent(interrupter -> { + long intervalSeconds = (long) interrupter.getStuckSplitDetectionInterval().getValue(SECONDS); + splitMonitorExecutor.scheduleAtFixedRate(() -> 
interrupter.interruptStuckSplits(runningSplitInfos, ticker), intervalSeconds, intervalSeconds, SECONDS); + }); } @PreDestroy @@ -457,10 +516,90 @@ private synchronized PrioritizedSplitRunner pollNextSplitWorker() return null; } + private static String toStackString(List stackTraceElements) + { + String stackString = stackTraceElements.stream() + .map(Object::toString) + .collect(joining(lineSeparator())); + + return stackString; + } + + private static boolean usedJoniRegexpFunctions(StackTraceElement stackTraceElement) + { + String className = stackTraceElement.getClassName(); + return JoniRegexpFunctions.class.getName().equals(className) + || JoniRegexpReplaceLambdaFunction.class.getName().equals(className); + } + + public static Optional createStuckSplitInterrupter( + Duration longRunningSplitWarningThreshold, + boolean enableInterruptStuckSplits, + Duration interruptStuckSplitsTimeout, + Duration stuckSplitDetectionInterval, + Predicate> potentiallyStuckSplitPredicate) + { + if (!enableInterruptStuckSplits) { + return Optional.empty(); + } + + return Optional.of(new StuckSplitInterrupter(longRunningSplitWarningThreshold, interruptStuckSplitsTimeout, stuckSplitDetectionInterval, potentiallyStuckSplitPredicate)); + } + + private static class StuckSplitInterrupter + { + private final Duration interruptStuckSplitsTimeout; + private final Duration stuckSplitDetectionInterval; + private final Predicate> potentiallyStuckSplitPredicate; + + public StuckSplitInterrupter(Duration longRunningSplitWarningThreshold, Duration interruptStuckSplitsTimeout, Duration stuckSplitDetectionInterval, Predicate> potentiallyStuckSplitPredicate) + { + requireNonNull(longRunningSplitWarningThreshold, "longRunningSplitWarningThreshold is null"); + requireNonNull(interruptStuckSplitsTimeout, "interruptStuckSplitsTimeout is null"); + requireNonNull(stuckSplitDetectionInterval, "stuckSplitDetectionInterval is null"); + requireNonNull(potentiallyStuckSplitPredicate, 
"potentiallyStuckSplitPredicate is null"); + checkArgument(interruptStuckSplitsTimeout.compareTo(SPLIT_RUN_QUANTA) >= 0, "interruptStuckSplitsTimeout must be at least %s", SPLIT_RUN_QUANTA); + checkArgument(longRunningSplitWarningThreshold.compareTo(interruptStuckSplitsTimeout) <= 0, "longRunningSplitWarningThreshold cannot be greater than interruptStuckSplitsTimeout"); + + this.interruptStuckSplitsTimeout = interruptStuckSplitsTimeout; + this.stuckSplitDetectionInterval = stuckSplitDetectionInterval; + this.potentiallyStuckSplitPredicate = potentiallyStuckSplitPredicate; + } + + public Duration getStuckSplitDetectionInterval() + { + return stuckSplitDetectionInterval; + } + + private void interruptStuckSplits(SortedSet runningSplitInfos, Ticker ticker) + { + for (RunningSplitInfo splitInfo : runningSplitInfos) { + Duration duration = Duration.succinctNanos(ticker.read() - splitInfo.getStartTime()); + + if (duration.compareTo(interruptStuckSplitsTimeout) > 0) { + List stackTraceElements = asList(splitInfo.getThread().getStackTrace()); + + if (!splitInfo.isPrinted()) { + splitInfo.setPrinted(); + + log.warn("Split %s has been long running %n%s", splitInfo.getSplitInfo(), toStackString(stackTraceElements)); + } + + if (potentiallyStuckSplitPredicate.test(stackTraceElements)) { + log.warn("Potentially interrupting stuck split matching the predicate %s", splitInfo.getSplitInfo()); + + splitInfo.getTaskRunner().interruptSplit(splitInfo); + } + } + } + } + } + private class TaskRunner implements Runnable { private final long runnerId = NEXT_RUNNER_ID.getAndIncrement(); + private final AtomicReference currentThreadSplit = new AtomicReference<>(); @Override public void run() @@ -478,10 +617,15 @@ public void run() } String threadId = split.getTaskHandle().getTaskId() + "-" + split.getSplitId(); + try (SetThreadName splitName = new SetThreadName(threadId)) { - RunningSplitInfo splitInfo = new RunningSplitInfo(ticker.read(), threadId, Thread.currentThread()); - 
runningSplitInfos.add(splitInfo); + RunningSplitInfo splitInfo = new RunningSplitInfo(ticker.read(), threadId, Thread.currentThread(), split, this); + + synchronized (this) { + currentThreadSplit.set(splitInfo); + } runningSplits.add(split); + runningSplitInfos.add(splitInfo); ListenableFuture blocked; try { @@ -524,6 +668,11 @@ public void run() } splitFinished(split); } + finally { + synchronized (this) { + currentThreadSplit.set(null); + } + } } } finally { @@ -533,6 +682,14 @@ public void run() } } } + + private synchronized void interruptSplit(RunningSplitInfo stuckSplit) + { + RunningSplitInfo runningSplit = stuckSplit.getTaskRunner().currentThreadSplit.get(); + if (runningSplit == stuckSplit) { + stuckSplit.getThread().interrupt(); + } + } } // @@ -786,7 +943,7 @@ private synchronized int getRunningTasksForLevel(int level) return count; } - public String getMaxActiveSplitsInfo() + public String getStuckSplitsInfo() { // Sample output: // @@ -804,12 +961,12 @@ public String getMaxActiveSplitsInfo() // at java.util.Formatter.format(Formatter.java:2501) // at ... 
(more line of stacktrace) StringBuilder stackTrace = new StringBuilder(); - int maxActiveSplitCount = 0; + int stuckSplitsCount = 0; String message = "%s splits have been continuously active for more than %s seconds\n"; for (RunningSplitInfo splitInfo : runningSplitInfos) { Duration duration = Duration.succinctNanos(ticker.read() - splitInfo.getStartTime()); - if (duration.compareTo(LONG_SPLIT_WARNING_THRESHOLD) >= 0) { - maxActiveSplitCount++; + if (duration.compareTo(longRunningSplitWarningThreshold) >= 0) { + stuckSplitsCount++; stackTrace.append("\n"); stackTrace.append(format("\"%s\" tid=%s", splitInfo.getThreadId(), splitInfo.getThread().getId())).append("\n"); for (StackTraceElement traceElement : splitInfo.getThread().getStackTrace()) { @@ -818,16 +975,16 @@ public String getMaxActiveSplitsInfo() } } - return format(message, maxActiveSplitCount, LONG_SPLIT_WARNING_THRESHOLD).concat(stackTrace.toString()); + return format(message, stuckSplitsCount, longRunningSplitWarningThreshold).concat(stackTrace.toString()); } @Managed - public long getRunAwaySplitCount() + public long getStuckSplitsCount() { int count = 0; for (RunningSplitInfo splitInfo : runningSplitInfos) { Duration duration = Duration.succinctNanos(ticker.read() - splitInfo.getStartTime()); - if (duration.compareTo(LONG_SPLIT_WARNING_THRESHOLD) > 0) { + if (duration.compareTo(longRunningSplitWarningThreshold) > 0) { count++; } } @@ -840,14 +997,23 @@ private static class RunningSplitInfo private final long startTime; private final String threadId; private final Thread thread; + private final PrioritizedSplitRunner split; private boolean printed; + private final TaskRunner taskRunner; - public RunningSplitInfo(long startTime, String threadId, Thread thread) + public RunningSplitInfo(long startTime, String threadId, Thread thread, PrioritizedSplitRunner split, TaskRunner taskRunner) { this.startTime = startTime; this.threadId = threadId; this.thread = thread; + this.split = split; this.printed = 
false; + this.taskRunner = taskRunner; + } + + public TaskRunner getTaskRunner() + { + return taskRunner; } public long getStartTime() @@ -875,6 +1041,11 @@ public void setPrinted() printed = true; } + public String getSplitInfo() + { + return split.getInfo(); + } + @Override public int compareTo(RunningSplitInfo o) { diff --git a/core/trino-main/src/main/java/io/trino/server/TaskExecutorResource.java b/core/trino-main/src/main/java/io/trino/server/TaskExecutorResource.java index 11ff77eafe01..c5e077d58373 100644 --- a/core/trino-main/src/main/java/io/trino/server/TaskExecutorResource.java +++ b/core/trino-main/src/main/java/io/trino/server/TaskExecutorResource.java @@ -25,7 +25,7 @@ import static io.trino.server.security.ResourceSecurity.AccessType.MANAGEMENT_READ; import static java.util.Objects.requireNonNull; -@Path("/v1/maxActiveSplits") +@Path("/v1/stuckSplits") public class TaskExecutorResource { private final TaskExecutor taskExecutor; @@ -40,8 +40,8 @@ public TaskExecutorResource( @ResourceSecurity(MANAGEMENT_READ) @GET @Produces(MediaType.TEXT_PLAIN) - public String getMaxActiveSplit() + public String getStuckSplits() { - return taskExecutor.getMaxActiveSplitsInfo(); + return taskExecutor.getStuckSplitsInfo(); } } diff --git a/core/trino-main/src/test/java/io/trino/execution/TestMemoryRevokingScheduler.java b/core/trino-main/src/test/java/io/trino/execution/TestMemoryRevokingScheduler.java index 31932e6c88e2..a2c26baa5118 100644 --- a/core/trino-main/src/test/java/io/trino/execution/TestMemoryRevokingScheduler.java +++ b/core/trino-main/src/test/java/io/trino/execution/TestMemoryRevokingScheduler.java @@ -21,6 +21,7 @@ import io.airlift.stats.CounterStat; import io.airlift.stats.TestingGcMonitor; import io.airlift.units.DataSize; +import io.airlift.units.Duration; import io.trino.exchange.ExchangeManagerRegistry; import io.trino.execution.executor.TaskExecutor; import io.trino.memory.MemoryPool; @@ -44,6 +45,7 @@ import java.util.HashMap; import 
java.util.List; import java.util.Map; +import java.util.Optional; import java.util.Set; import java.util.concurrent.Callable; import java.util.concurrent.ScheduledExecutorService; @@ -61,6 +63,7 @@ import static io.trino.execution.buffer.OutputBuffers.createInitialEmptyOutputBuffers; import static java.util.Collections.singletonList; import static java.util.concurrent.Executors.newScheduledThreadPool; +import static java.util.concurrent.TimeUnit.MINUTES; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertFalse; import static org.testng.Assert.assertTrue; @@ -68,7 +71,7 @@ @Test(singleThreaded = true) public class TestMemoryRevokingScheduler { - private final AtomicInteger idGeneator = new AtomicInteger(); + private final AtomicInteger idGenerator = new AtomicInteger(); private final SpillSpaceTracker spillSpaceTracker = new SpillSpaceTracker(DataSize.of(10, GIGABYTE)); private final Map queryContexts = new HashMap<>(); @@ -84,7 +87,7 @@ public void setUp() { memoryPool = new MemoryPool(DataSize.ofBytes(10)); - TaskExecutor taskExecutor = new TaskExecutor(8, 16, 3, 4, Ticker.systemTicker()); + TaskExecutor taskExecutor = new TaskExecutor(8, 16, 3, 4, new Duration(10, MINUTES), Optional.empty(), Ticker.systemTicker()); taskExecutor.start(); // Must be single threaded @@ -219,9 +222,8 @@ private OperatorContext createContexts(SqlTask sqlTask) TaskContext taskContext = getOrCreateTaskContext(sqlTask); PipelineContext pipelineContext = taskContext.addPipelineContext(0, false, false, false); DriverContext driverContext = pipelineContext.addDriverContext(); - OperatorContext operatorContext = driverContext.addOperatorContext(1, new PlanNodeId("na"), "na"); - return operatorContext; + return driverContext.addOperatorContext(1, new PlanNodeId("na"), "na"); } private void requestMemoryRevoking(MemoryRevokingScheduler scheduler) @@ -256,7 +258,7 @@ private SqlTask newSqlTask(QueryId queryId) { QueryContext queryContext = 
getOrCreateQueryContext(queryId); - TaskId taskId = new TaskId(new StageId(queryId.getId(), 0), idGeneator.incrementAndGet(), 0); + TaskId taskId = new TaskId(new StageId(queryId.getId(), 0), idGenerator.incrementAndGet(), 0); URI location = URI.create("fake://task/" + taskId); return createSqlTask( diff --git a/core/trino-main/src/test/java/io/trino/execution/TestSqlTask.java b/core/trino-main/src/test/java/io/trino/execution/TestSqlTask.java index d300aa7e516a..f57b57cfd4fa 100644 --- a/core/trino-main/src/test/java/io/trino/execution/TestSqlTask.java +++ b/core/trino-main/src/test/java/io/trino/execution/TestSqlTask.java @@ -21,6 +21,7 @@ import io.airlift.stats.CounterStat; import io.airlift.stats.TestingGcMonitor; import io.airlift.units.DataSize; +import io.airlift.units.Duration; import io.trino.exchange.ExchangeManagerRegistry; import io.trino.execution.buffer.BufferResult; import io.trino.execution.buffer.BufferState; @@ -65,6 +66,7 @@ import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.testing.TestingSession.testSessionBuilder; import static java.util.concurrent.Executors.newScheduledThreadPool; +import static java.util.concurrent.TimeUnit.MINUTES; import static java.util.concurrent.TimeUnit.SECONDS; import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.testng.Assert.assertEquals; @@ -86,7 +88,7 @@ public class TestSqlTask public TestSqlTask() { - taskExecutor = new TaskExecutor(8, 16, 3, 4, Ticker.systemTicker()); + taskExecutor = new TaskExecutor(8, 16, 3, 4, new Duration(10, MINUTES), Optional.empty(), Ticker.systemTicker()); taskExecutor.start(); taskNotificationExecutor = newScheduledThreadPool(10, threadsNamed("task-notification-%s")); @@ -282,7 +284,7 @@ public void testBufferCloseOnCancel() sqlTask.cancel(); assertEquals(sqlTask.getTaskInfo().getTaskStatus().getState(), TaskState.CANCELED); - // buffer future will complete.. 
the event is async so wait a bit for event to propagate + // buffer future will complete, the event is async so wait a bit for event to propagate bufferResult.get(1, SECONDS); bufferResult = sqlTask.getTaskResults(OUT, 0, DataSize.of(1, MEGABYTE)); @@ -332,7 +334,7 @@ public void testDynamicFilters() ListenableFuture future = sqlTask.getTaskStatus(STARTING_VERSION); assertFalse(future.isDone()); - // make sure future gets unblocked when dynamic filters version is updated + // make sure future gets unblocked when dynamic filter's version is updated taskContext.updateDomains(ImmutableMap.of(new DynamicFilterId("filter"), Domain.none(BIGINT))); assertEquals(sqlTask.getTaskStatus().getVersion(), STARTING_VERSION + 1); assertEquals(sqlTask.getTaskStatus().getDynamicFiltersVersion(), INITIAL_DYNAMIC_FILTERS_VERSION + 1); diff --git a/core/trino-main/src/test/java/io/trino/execution/TestSqlTaskExecution.java b/core/trino-main/src/test/java/io/trino/execution/TestSqlTaskExecution.java index 85011910961a..939efcd12596 100644 --- a/core/trino-main/src/test/java/io/trino/execution/TestSqlTaskExecution.java +++ b/core/trino-main/src/test/java/io/trino/execution/TestSqlTaskExecution.java @@ -50,7 +50,6 @@ import io.trino.spi.block.TestingBlockEncodingSerde; import io.trino.spi.connector.ConnectorSplit; import io.trino.spi.connector.UpdatablePageSource; -import io.trino.spi.type.Type; import io.trino.spiller.SpillSpaceTracker; import io.trino.sql.planner.LocalExecutionPlanner.LocalExecutionPlan; import io.trino.sql.planner.plan.PlanNodeId; @@ -82,7 +81,6 @@ import static io.trino.execution.buffer.OutputBuffers.createInitialEmptyOutputBuffers; import static io.trino.execution.buffer.PagesSerde.getSerializedPagePositionCount; import static io.trino.memory.context.AggregatedMemoryContext.newSimpleAggregatedMemoryContext; -import static io.trino.spi.type.VarcharType.VARCHAR; import static java.util.Objects.requireNonNull; import static 
java.util.concurrent.Executors.newScheduledThreadPool; import static java.util.concurrent.TimeUnit.HOURS; @@ -104,7 +102,7 @@ public void testSimple() { ScheduledExecutorService taskNotificationExecutor = newScheduledThreadPool(10, threadsNamed("task-notification-%s")); ScheduledExecutorService driverYieldExecutor = newScheduledThreadPool(2, threadsNamed("driver-yield-%s")); - TaskExecutor taskExecutor = new TaskExecutor(5, 10, 3, 4, Ticker.systemTicker()); + TaskExecutor taskExecutor = new TaskExecutor(5, 10, 3, 4, new Duration(10, TimeUnit.MINUTES), Optional.empty(), Ticker.systemTicker()); taskExecutor.start(); try { @@ -122,7 +120,7 @@ public void testSimple() // | // Scan // - TestingScanOperatorFactory testingScanOperatorFactory = new TestingScanOperatorFactory(0, TABLE_SCAN_NODE_ID, ImmutableList.of(VARCHAR)); + TestingScanOperatorFactory testingScanOperatorFactory = new TestingScanOperatorFactory(0, TABLE_SCAN_NODE_ID); TaskOutputOperatorFactory taskOutputOperatorFactory = new TaskOutputOperatorFactory( 1, TABLE_SCAN_NODE_ID, @@ -337,8 +335,7 @@ public static class TestingScanOperatorFactory public TestingScanOperatorFactory( int operatorId, - PlanNodeId sourceId, - List types) + PlanNodeId sourceId) { this.operatorId = operatorId; this.sourceId = requireNonNull(sourceId, "sourceId is null"); diff --git a/core/trino-main/src/test/java/io/trino/execution/TestSqlTaskManager.java b/core/trino-main/src/test/java/io/trino/execution/TestSqlTaskManager.java index c8eadb64075f..ad5b0f90a7f0 100644 --- a/core/trino-main/src/test/java/io/trino/execution/TestSqlTaskManager.java +++ b/core/trino-main/src/test/java/io/trino/execution/TestSqlTaskManager.java @@ -82,7 +82,7 @@ public TestSqlTaskManager() { localMemoryManager = new LocalMemoryManager(new NodeMemoryConfig()); localSpillManager = new LocalSpillManager(new NodeSpillConfig()); - taskExecutor = new TaskExecutor(8, 16, 3, 4, Ticker.systemTicker()); + taskExecutor = new TaskExecutor(8, 16, 3, 4, new Duration(10, 
TimeUnit.MINUTES), Optional.empty(), Ticker.systemTicker()); taskExecutor.start(); taskManagementExecutor = new TaskManagementExecutor(); } diff --git a/core/trino-main/src/test/java/io/trino/execution/TestTaskManagerConfig.java b/core/trino-main/src/test/java/io/trino/execution/TestTaskManagerConfig.java index 07e0de14fdf2..83331aa1a1ba 100644 --- a/core/trino-main/src/test/java/io/trino/execution/TestTaskManagerConfig.java +++ b/core/trino-main/src/test/java/io/trino/execution/TestTaskManagerConfig.java @@ -65,7 +65,11 @@ public void testDefaults() .setTaskNotificationThreads(5) .setTaskYieldThreads(3) .setLevelTimeMultiplier(new BigDecimal("2")) - .setStatisticsCpuTimerEnabled(true)); + .setStatisticsCpuTimerEnabled(true) + .setLongRunningSplitWarningThreshold(new Duration(10, TimeUnit.MINUTES)) + .setEnableInterruptStuckSplits(true) + .setInterruptStuckSplitsTimeout(new Duration(10, TimeUnit.MINUTES)) + .setStuckSplitDetectionInterval(new Duration(1, TimeUnit.MINUTES))); } @Test @@ -101,6 +105,10 @@ public void testExplicitPropertyMappings() .put("task.task-yield-threads", "8") .put("task.level-time-multiplier", "2.1") .put("task.statistics-cpu-timer-enabled", "false") + .put("task.long-running-split-warning-threshold", "9m") + .put("task.enable-interrupt-stuck-splits", "false") + .put("task.interrupt-stuck-splits-timeout", "9m") + .put("task.stuck-split-detection-interval", "1s") .buildOrThrow(); TaskManagerConfig expected = new TaskManagerConfig() @@ -131,7 +139,11 @@ public void testExplicitPropertyMappings() .setTaskNotificationThreads(13) .setTaskYieldThreads(8) .setLevelTimeMultiplier(new BigDecimal("2.1")) - .setStatisticsCpuTimerEnabled(false); + .setStatisticsCpuTimerEnabled(false) + .setLongRunningSplitWarningThreshold(new Duration(9, TimeUnit.MINUTES)) + .setEnableInterruptStuckSplits(false) + .setInterruptStuckSplitsTimeout(new Duration(9, TimeUnit.MINUTES)) + .setStuckSplitDetectionInterval(new Duration(1, TimeUnit.SECONDS)); 
assertFullMapping(properties, expected); } diff --git a/core/trino-main/src/test/java/io/trino/execution/executor/TaskExecutorSimulator.java b/core/trino-main/src/test/java/io/trino/execution/executor/TaskExecutorSimulator.java index 23649bc8445b..a01b3f0cae31 100644 --- a/core/trino-main/src/test/java/io/trino/execution/executor/TaskExecutorSimulator.java +++ b/core/trino-main/src/test/java/io/trino/execution/executor/TaskExecutorSimulator.java @@ -33,6 +33,7 @@ import java.util.List; import java.util.LongSummaryStatistics; import java.util.Map; +import java.util.Optional; import java.util.OptionalInt; import java.util.Set; import java.util.concurrent.ScheduledExecutorService; @@ -79,7 +80,7 @@ public static void main(String[] args) private TaskExecutorSimulator() { splitQueue = new MultilevelSplitQueue(2); - taskExecutor = new TaskExecutor(36, 72, 3, 8, splitQueue, Ticker.systemTicker()); + taskExecutor = new TaskExecutor(36, 72, 3, 8, new Duration(10, MINUTES), Optional.empty(), splitQueue, Ticker.systemTicker()); taskExecutor.start(); } @@ -154,7 +155,7 @@ private void runExperimentOverloadedCluster(SimulationController controller) SECONDS.sleep(30); - // this gets the executor into a more realistic point-in-time state, where long running tasks start to make progress + // this gets the executor into a more realistic point-in-time state, where long-running tasks start to make progress for (int i = 0; i < 20; i++) { controller.clearPendingQueue(); MINUTES.sleep(1); diff --git a/core/trino-main/src/test/java/io/trino/execution/executor/TestTaskExecutor.java b/core/trino-main/src/test/java/io/trino/execution/executor/TestTaskExecutor.java index 47465890f76f..d44fa851dcad 100644 --- a/core/trino-main/src/test/java/io/trino/execution/executor/TestTaskExecutor.java +++ b/core/trino-main/src/test/java/io/trino/execution/executor/TestTaskExecutor.java @@ -13,6 +13,7 @@ */ package io.trino.execution.executor; +import com.google.common.base.Ticker; import 
com.google.common.collect.ImmutableList; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.SettableFuture; @@ -26,6 +27,7 @@ import java.util.Arrays; import java.util.List; +import java.util.Optional; import java.util.OptionalInt; import java.util.concurrent.Future; import java.util.concurrent.Phaser; @@ -38,6 +40,8 @@ import static io.airlift.testing.Assertions.assertLessThan; import static io.trino.execution.executor.MultilevelSplitQueue.LEVEL_CONTRIBUTION_CAP; import static io.trino.execution.executor.MultilevelSplitQueue.LEVEL_THRESHOLD_SECONDS; +import static io.trino.execution.executor.TaskExecutor.createStuckSplitInterrupter; +import static io.trino.version.EmbedVersion.testingVersionEmbedder; import static java.util.concurrent.TimeUnit.MILLISECONDS; import static java.util.concurrent.TimeUnit.MINUTES; import static java.util.concurrent.TimeUnit.SECONDS; @@ -52,7 +56,7 @@ public void testTasksComplete() throws Exception { TestingTicker ticker = new TestingTicker(); - TaskExecutor taskExecutor = new TaskExecutor(4, 8, 3, 4, ticker); + TaskExecutor taskExecutor = new TaskExecutor(4, 8, 3, 4, new Duration(10, MINUTES), Optional.empty(), ticker); taskExecutor.start(); ticker.increment(20, MILLISECONDS); @@ -78,9 +82,9 @@ public void testTasksComplete() assertEquals(driver1.getCompletedPhases(), 0); assertEquals(driver2.getCompletedPhases(), 0); ticker.increment(60, SECONDS); - assertEquals(taskExecutor.getRunAwaySplitCount(), 0); + assertEquals(taskExecutor.getStuckSplitsCount(), 0); ticker.increment(600, SECONDS); - assertEquals(taskExecutor.getRunAwaySplitCount(), 2); + assertEquals(taskExecutor.getStuckSplitsCount(), 2); verificationComplete.arriveAndAwaitAdvance(); // advance one phase and verify @@ -135,7 +139,7 @@ public void testTasksComplete() // no splits remaining ticker.increment(610, SECONDS); - assertEquals(taskExecutor.getRunAwaySplitCount(), 0); + assertEquals(taskExecutor.getStuckSplitsCount(), 
0); } finally { taskExecutor.stop(); @@ -146,7 +150,7 @@ public void testTasksComplete() public void testQuantaFairness() { TestingTicker ticker = new TestingTicker(); - TaskExecutor taskExecutor = new TaskExecutor(1, 2, 3, 4, ticker); + TaskExecutor taskExecutor = new TaskExecutor(1, 2, 3, 4, new Duration(10, MINUTES), Optional.empty(), ticker); taskExecutor.start(); ticker.increment(20, MILLISECONDS); @@ -180,7 +184,7 @@ public void testQuantaFairness() public void testLevelMovement() { TestingTicker ticker = new TestingTicker(); - TaskExecutor taskExecutor = new TaskExecutor(2, 2, 3, 4, ticker); + TaskExecutor taskExecutor = new TaskExecutor(2, 2, 3, 4, new Duration(10, MINUTES), Optional.empty(), ticker); taskExecutor.start(); ticker.increment(20, MILLISECONDS); @@ -219,7 +223,7 @@ public void testLevelMultipliers() throws Exception { TestingTicker ticker = new TestingTicker(); - TaskExecutor taskExecutor = new TaskExecutor(6, 3, 3, 4, new MultilevelSplitQueue(2), ticker); + TaskExecutor taskExecutor = new TaskExecutor(6, 3, 3, 4, new Duration(10, MINUTES), Optional.empty(), new MultilevelSplitQueue(2), ticker); taskExecutor.start(); ticker.increment(20, MILLISECONDS); @@ -297,7 +301,7 @@ public void testLevelMultipliers() public void testTaskHandle() { TestingTicker ticker = new TestingTicker(); - TaskExecutor taskExecutor = new TaskExecutor(4, 8, 3, 4, ticker); + TaskExecutor taskExecutor = new TaskExecutor(4, 8, 3, 4, new Duration(10, MINUTES), Optional.empty(), ticker); taskExecutor.start(); try { @@ -372,7 +376,7 @@ public void testMinMaxDriversPerTask() int maxDriversPerTask = 2; MultilevelSplitQueue splitQueue = new MultilevelSplitQueue(2); TestingTicker ticker = new TestingTicker(); - TaskExecutor taskExecutor = new TaskExecutor(4, 16, 1, maxDriversPerTask, splitQueue, ticker); + TaskExecutor taskExecutor = new TaskExecutor(4, 16, 1, maxDriversPerTask, new Duration(10, MINUTES), Optional.empty(), splitQueue, ticker); taskExecutor.start(); try { 
TaskHandle testTaskHandle = taskExecutor.addTask(new TaskId(new StageId("test", 0), 0, 0), () -> 0, 10, new Duration(1, MILLISECONDS), OptionalInt.empty()); @@ -412,7 +416,7 @@ public void testUserSpecifiedMaxDriversPerTask() MultilevelSplitQueue splitQueue = new MultilevelSplitQueue(2); TestingTicker ticker = new TestingTicker(); // create a task executor with min/max drivers per task to be 2 and 4 - TaskExecutor taskExecutor = new TaskExecutor(4, 16, 2, 4, splitQueue, ticker); + TaskExecutor taskExecutor = new TaskExecutor(4, 16, 2, 4, new Duration(10, MINUTES), Optional.empty(), splitQueue, ticker); taskExecutor.start(); try { // overwrite the max drivers per task to be 1 @@ -451,7 +455,7 @@ public void testMinDriversPerTaskWhenTargetConcurrencyIncreases() MultilevelSplitQueue splitQueue = new MultilevelSplitQueue(2); TestingTicker ticker = new TestingTicker(); // create a task executor with min/max drivers per task to be 2 - TaskExecutor taskExecutor = new TaskExecutor(4, 1, 2, 2, splitQueue, ticker); + TaskExecutor taskExecutor = new TaskExecutor(4, 1, 2, 2, new Duration(10, MINUTES), Optional.empty(), splitQueue, ticker); taskExecutor.start(); try { @@ -491,6 +495,39 @@ public void testMinDriversPerTaskWhenTargetConcurrencyIncreases() } } + @Test + public void testTaskExecutorStuckSplitInterrupt() + throws Exception + { + TaskExecutor taskExecutor = new TaskExecutor( + 8, + 16, + 3, + 4, + new Duration(1, SECONDS), + createStuckSplitInterrupter(new Duration(1, SECONDS), true, new Duration(1, SECONDS), new Duration(1, SECONDS), elements -> elements.stream().anyMatch(element -> element.getFileName().equals("TestTaskExecutor.java"))), + testingVersionEmbedder(), + new MultilevelSplitQueue(2), + Ticker.systemTicker()); + taskExecutor.start(); + + try { + TaskId taskId = new TaskId(new StageId("foo", 0), 0, 0); + TaskHandle taskHandle = taskExecutor.addTask( + taskId, + () -> 1.0, + 1, + new Duration(1, SECONDS), + OptionalInt.of(1)); + MockSplitRunner 
mockSplitRunner = new MockSplitRunner(); + taskExecutor.enqueueSplits(taskHandle, false, ImmutableList.of(mockSplitRunner)); + mockSplitRunner.interrupted.get(60, SECONDS); + } + finally { + taskExecutor.stop(); + } + } + private void assertSplitStates(int endIndex, TestingJob[] splits) { // assert that splits up to and including endIndex are all started @@ -517,6 +554,44 @@ private static void waitUntilSplitsStart(List splits) } } + private static class MockSplitRunner + implements SplitRunner + { + private SettableFuture interrupted = SettableFuture.create(); + + @Override + public boolean isFinished() + { + return interrupted.isDone(); + } + + @Override + public ListenableFuture processFor(Duration duration) + { + while (true) { + try { + Thread.sleep(1000); + } + catch (InterruptedException e) { + break; + } + } + interrupted.set(true); + return immediateVoidFuture(); + } + + @Override + public String getInfo() + { + return ""; + } + + @Override + public void close() + { + } + } + private static class TestingJob implements SplitRunner { diff --git a/docs/src/main/sphinx/admin/properties-task.rst b/docs/src/main/sphinx/admin/properties-task.rst index f93fd3af2676..b45f17a4bc37 100644 --- a/docs/src/main/sphinx/admin/properties-task.rst +++ b/docs/src/main/sphinx/admin/properties-task.rst @@ -124,3 +124,23 @@ of additional CPU for parallel writes. Some connectors can be bottlenecked on CP writing due to compression or other factors. Setting this too high may cause the cluster to become overloaded due to excessive resource utilization. This can also be specified on a per-query basis using the ``task_writer_count`` session property. + +``task.long-running-split-warning-threshold`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +* **Type:** :ref:`prop-type-duration` +* **Minimum value:** ``1ms`` +* **Default value:** ``10m`` + +When a split runs longer than this threshold, its call stack can be retrieved via the +``/v1/stuckSplits`` endpoint on the coordinator. 
+ +``task.interrupt-stuck-splits-timeout`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +* **Type:** :ref:`prop-type-duration` +* **Minimum value:** ``1s`` +* **Default value:** ``10m`` + +The length of time Trino waits for a blocked split processing thread before interrupting the thread. +Only applies to threads that are blocked by the third-party Joni regular expression library.