diff --git a/core/trino-main/src/main/java/io/trino/SystemSessionProperties.java b/core/trino-main/src/main/java/io/trino/SystemSessionProperties.java index c0f47c1f8576..bc71a6b3efe7 100644 --- a/core/trino-main/src/main/java/io/trino/SystemSessionProperties.java +++ b/core/trino-main/src/main/java/io/trino/SystemSessionProperties.java @@ -144,6 +144,10 @@ public final class SystemSessionProperties public static final String TIME_ZONE_ID = "time_zone_id"; public static final String LEGACY_CATALOG_ROLES = "legacy_catalog_roles"; public static final String INCREMENTAL_HASH_ARRAY_LOAD_FACTOR_ENABLED = "incremental_hash_array_load_factor_enabled"; + public static final String ENABLE_ADAPTIVE_REMOTE_TASK_REQUEST_SIZE = "enable_adaptive_remote_task_request_size"; + public static final String MAX_REMOTE_TASK_REQUEST_SIZE = "max_remote_task_request_size"; + public static final String REMOTE_TASK_REQUEST_SIZE_HEADROOM = "remote_task_request_size_headroom"; + public static final String REMOTE_TASK_GUARANTEED_SPLITS_PER_REQUEST = "remote_task_guaranteed_splits_per_request"; private final List> sessionProperties; @@ -666,6 +670,26 @@ public SystemSessionProperties( INCREMENTAL_HASH_ARRAY_LOAD_FACTOR_ENABLED, "Use smaller load factor for small hash arrays in order to improve performance", featuresConfig.isIncrementalHashArrayLoadFactorEnabled(), + false), + booleanProperty( + ENABLE_ADAPTIVE_REMOTE_TASK_REQUEST_SIZE, + "Experimental: Enable adaptive adjustment for size of remote task update request", + queryManagerConfig.isEnabledAdaptiveTaskRequestSize(), + false), + dataSizeProperty( + MAX_REMOTE_TASK_REQUEST_SIZE, + "Experimental: Max size of remote task update request", + queryManagerConfig.getMaxRemoteTaskRequestSize(), + false), + dataSizeProperty( + REMOTE_TASK_REQUEST_SIZE_HEADROOM, + "Experimental: Headroom for size of remote task update request", + queryManagerConfig.getRemoteTaskRequestSizeHeadroom(), + false), + integerProperty( + 
REMOTE_TASK_GUARANTEED_SPLITS_PER_REQUEST, + "Guaranteed splits per remote task request", + queryManagerConfig.getRemoteTaskGuaranteedSplitPerTask(), false)); } @@ -1184,4 +1208,24 @@ public static boolean isIncrementalHashArrayLoadFactorEnabled(Session session) { return session.getSystemProperty(INCREMENTAL_HASH_ARRAY_LOAD_FACTOR_ENABLED, Boolean.class); } + + public static boolean isEnableAdaptiveTaskRequestSize(Session session) + { + return session.getSystemProperty(ENABLE_ADAPTIVE_REMOTE_TASK_REQUEST_SIZE, Boolean.class); + } + + public static DataSize getMaxRemoteTaskRequestSize(Session session) + { + return session.getSystemProperty(MAX_REMOTE_TASK_REQUEST_SIZE, DataSize.class); + } + + public static DataSize getRemoteTaskRequestSizeHeadroom(Session session) + { + return session.getSystemProperty(REMOTE_TASK_REQUEST_SIZE_HEADROOM, DataSize.class); + } + + public static int getRemoteTaskGuaranteedSplitsPerRequest(Session session) + { + return session.getSystemProperty(REMOTE_TASK_GUARANTEED_SPLITS_PER_REQUEST, Integer.class); + } } diff --git a/core/trino-main/src/main/java/io/trino/execution/QueryManagerConfig.java b/core/trino-main/src/main/java/io/trino/execution/QueryManagerConfig.java index 7a3468663fd2..d85a533ef0fb 100644 --- a/core/trino-main/src/main/java/io/trino/execution/QueryManagerConfig.java +++ b/core/trino-main/src/main/java/io/trino/execution/QueryManagerConfig.java @@ -58,6 +58,11 @@ public class QueryManagerConfig private Duration remoteTaskMaxErrorDuration = new Duration(5, TimeUnit.MINUTES); private int remoteTaskMaxCallbackThreads = 1000; + private boolean enabledAdaptiveTaskRequestSize; + private DataSize maxRemoteTaskRequestSize = DataSize.of(8, DataSize.Unit.MEGABYTE); + private DataSize remoteTaskRequestSizeHeadroom = DataSize.of(2, DataSize.Unit.MEGABYTE); + private int remoteTaskGuaranteedSplitPerTask = 3; + private String queryExecutionPolicy = "all-at-once"; private Duration queryMaxRunTime = new Duration(100, TimeUnit.DAYS); 
private Duration queryMaxExecutionTime = new Duration(100, TimeUnit.DAYS); @@ -338,6 +343,57 @@ public QueryManagerConfig setRemoteTaskMaxCallbackThreads(int remoteTaskMaxCallb return this; } + public boolean isEnabledAdaptiveTaskRequestSize() + { + return enabledAdaptiveTaskRequestSize; + } + + @Config("query.remote-task.enable-adaptive-request-size") + public QueryManagerConfig setEnabledAdaptiveTaskRequestSize(boolean enabledAdaptiveTaskRequestSize) + { + this.enabledAdaptiveTaskRequestSize = enabledAdaptiveTaskRequestSize; + return this; + } + + @NotNull + public DataSize getMaxRemoteTaskRequestSize() + { + return maxRemoteTaskRequestSize; + } + + @Config("query.remote-task.max-request-size") + public QueryManagerConfig setMaxRemoteTaskRequestSize(DataSize maxRemoteTaskRequestSize) + { + this.maxRemoteTaskRequestSize = maxRemoteTaskRequestSize; + return this; + } + + @NotNull + public DataSize getRemoteTaskRequestSizeHeadroom() + { + return remoteTaskRequestSizeHeadroom; + } + + @Config("query.remote-task.request-size-headroom") + public QueryManagerConfig setRemoteTaskRequestSizeHeadroom(DataSize remoteTaskRequestSizeHeadroom) + { + this.remoteTaskRequestSizeHeadroom = remoteTaskRequestSizeHeadroom; + return this; + } + + @Min(1) + public int getRemoteTaskGuaranteedSplitPerTask() + { + return remoteTaskGuaranteedSplitPerTask; + } + + @Config("query.remote-task.guaranteed-splits-per-task") + public QueryManagerConfig setRemoteTaskGuaranteedSplitPerTask(int remoteTaskGuaranteedSplitPerTask) + { + this.remoteTaskGuaranteedSplitPerTask = remoteTaskGuaranteedSplitPerTask; + return this; + } + @NotNull public String getQueryExecutionPolicy() { diff --git a/core/trino-main/src/main/java/io/trino/execution/SqlTaskExecution.java b/core/trino-main/src/main/java/io/trino/execution/SqlTaskExecution.java index 15529df5aa74..8622ce4230b0 100644 --- a/core/trino-main/src/main/java/io/trino/execution/SqlTaskExecution.java +++ 
b/core/trino-main/src/main/java/io/trino/execution/SqlTaskExecution.java @@ -133,7 +133,7 @@ public class SqlTaskExecution private final ConcurrentMap unpartitionedSources = new ConcurrentHashMap<>(); @GuardedBy("this") - private long maxAcknowledgedSplit = Long.MIN_VALUE; + private final Map maxAcknowledgedSplitByPlanNode = new HashMap<>(); @GuardedBy("this") private final SchedulingLifespanManager schedulingLifespanManager; @@ -325,12 +325,11 @@ private synchronized Map updateSources(List Map updatedUnpartitionedSources = new HashMap<>(); // first remove any split that was already acknowledged - long currentMaxAcknowledgedSplit = this.maxAcknowledgedSplit; sources = sources.stream() .map(source -> new TaskSource( source.getPlanNodeId(), source.getSplits().stream() - .filter(scheduledSplit -> scheduledSplit.getSequenceId() > currentMaxAcknowledgedSplit) + .filter(scheduledSplit -> scheduledSplit.getSequenceId() > maxAcknowledgedSplitByPlanNode.getOrDefault(source.getPlanNodeId(), Long.MIN_VALUE)) .collect(Collectors.toSet()), // Like splits, noMoreSplitsForLifespan could be pruned so that only new items will be processed. // This is not happening here because correctness won't be compromised due to duplicate events for noMoreSplitsForLifespan. 
@@ -354,11 +353,18 @@ private synchronized Map updateSources(List } // update maxAcknowledgedSplit - maxAcknowledgedSplit = sources.stream() - .flatMap(source -> source.getSplits().stream()) - .mapToLong(ScheduledSplit::getSequenceId) - .max() - .orElse(maxAcknowledgedSplit); + for (TaskSource taskSource : sources) { + long maxAcknowledgedSplit = taskSource.getSplits().stream() + .mapToLong(ScheduledSplit::getSequenceId) + .max() + .orElse(Long.MIN_VALUE); + PlanNodeId planNodeId = taskSource.getPlanNodeId(); + + if (!maxAcknowledgedSplitByPlanNode.containsKey(planNodeId)) { + maxAcknowledgedSplitByPlanNode.put(planNodeId, Long.MIN_VALUE); + } + maxAcknowledgedSplitByPlanNode.computeIfPresent(planNodeId, (key, val) -> maxAcknowledgedSplit > val ? maxAcknowledgedSplit : val); + } return updatedUnpartitionedSources; } diff --git a/core/trino-main/src/main/java/io/trino/server/remotetask/HttpRemoteTask.java b/core/trino-main/src/main/java/io/trino/server/remotetask/HttpRemoteTask.java index 7b28f4bab16d..a70df50218f0 100644 --- a/core/trino-main/src/main/java/io/trino/server/remotetask/HttpRemoteTask.java +++ b/core/trino-main/src/main/java/io/trino/server/remotetask/HttpRemoteTask.java @@ -15,6 +15,7 @@ import com.google.common.collect.HashMultimap; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; import com.google.common.collect.Multimap; import com.google.common.collect.SetMultimap; import com.google.common.net.HttpHeaders; @@ -64,6 +65,7 @@ import java.net.URI; import java.util.Collection; +import java.util.Comparator; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -79,6 +81,7 @@ import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicReference; +import java.util.stream.Collectors; import java.util.stream.Stream; import static com.google.common.base.MoreObjects.toStringHelper; @@ -91,7 +94,11 @@ import static 
io.airlift.http.client.Request.Builder.prepareDelete; import static io.airlift.http.client.Request.Builder.preparePost; import static io.airlift.http.client.StaticBodyGenerator.createStaticBodyGenerator; +import static io.trino.SystemSessionProperties.getMaxRemoteTaskRequestSize; import static io.trino.SystemSessionProperties.getMaxUnacknowledgedSplitsPerTask; +import static io.trino.SystemSessionProperties.getRemoteTaskGuaranteedSplitsPerRequest; +import static io.trino.SystemSessionProperties.getRemoteTaskRequestSizeHeadroom; +import static io.trino.SystemSessionProperties.isEnableAdaptiveTaskRequestSize; import static io.trino.execution.DynamicFiltersCollector.INITIAL_DYNAMIC_FILTERS_VERSION; import static io.trino.execution.TaskInfo.createInitialTask; import static io.trino.execution.TaskState.ABORTED; @@ -172,6 +179,14 @@ public final class HttpRemoteTask private final AtomicBoolean started = new AtomicBoolean(false); private final AtomicBoolean aborting = new AtomicBoolean(false); + @GuardedBy("this") + private int splitBatchSize; + + private final int guaranteedSplitsPerRequest; + private final long maxRequestSize; + private final long requestSizeHeadroom; + private final boolean enableAdaptiveTaskRequestSize; + public HttpRemoteTask( Session session, TaskId taskId, @@ -235,6 +250,12 @@ public HttpRemoteTask( } maxUnacknowledgedSplits = getMaxUnacknowledgedSplitsPerTask(session); + this.guaranteedSplitsPerRequest = getRemoteTaskGuaranteedSplitsPerRequest(session); + this.maxRequestSize = getMaxRemoteTaskRequestSize(session).toBytes(); + this.requestSizeHeadroom = getRemoteTaskRequestSizeHeadroom(session).toBytes(); + this.splitBatchSize = maxUnacknowledgedSplits; + this.enableAdaptiveTaskRequestSize = isEnableAdaptiveTaskRequestSize(session); + int pendingSourceSplitCount = 0; long pendingSourceSplitsWeight = 0; for (PlanNodeId planNodeId : planFragment.getPartitionedSources()) { @@ -557,6 +578,10 @@ private synchronized void processTaskUpdate(TaskInfo 
newValue, List pendingSourceSplitsWeight -= removedWeight; } } + // set needsUpdate to true when there are still pending splits + if (pendingSplits.size() > 0) { + needsUpdate.set(true); + } // Update node level split tracker before split queue space to ensure it's up to date before waking up the scheduler partitionedSplitCountTracker.setPartitionedSplits(getPartitionedSplitsInfo()); updateSplitQueueSpace(); @@ -580,6 +605,27 @@ private synchronized void triggerUpdate() sendUpdate(); } + /** + * Adaptively adjust batch size to meet expected request size: + * If requestSize is not equal to expectedSize, this function will try to estimate and adjust the batch size proportionally based on + * current number of splits and size of request. + */ + private synchronized void adjustSplitBatchSize(List sources, long requestSize, long expectedSize) + { + int numSplits = 0; + for (TaskSource taskSource : sources) { + numSplits = Math.max(numSplits, taskSource.getSplits().size()); + } + if (requestSize <= 0 || numSplits == 0) { + return; + } + if ((requestSize > expectedSize && splitBatchSize > guaranteedSplitsPerRequest) || (requestSize < expectedSize && splitBatchSize < maxUnacknowledgedSplits)) { + int newSplitBatchSize = (int) (numSplits * ((double) expectedSize / requestSize)); + newSplitBatchSize = Math.max(guaranteedSplitsPerRequest, Math.min(maxUnacknowledgedSplits, newSplitBatchSize)); + splitBatchSize = newSplitBatchSize; + } + } + private synchronized void sendUpdate() { TaskStatus taskStatus = getTaskStatus(); @@ -614,6 +660,20 @@ private synchronized void sendUpdate() outputBuffers.get(), dynamicFilterDomains.getDynamicFilterDomains()); byte[] taskUpdateRequestJson = taskUpdateRequestCodec.toJsonBytes(updateRequest); + + if (enableAdaptiveTaskRequestSize) { + int oldSplitBatchSize = splitBatchSize; + // try to adjust batch size to meet expected request size: (requestSizeLimit - requestSizeLimitHeadroom) + adjustSplitBatchSize(sources, taskUpdateRequestJson.length, 
maxRequestSize - requestSizeHeadroom); + // abandon current request and reschedule update if size of request body exceeds requestSizeLimit + // and splitBatchSize is updated + if (taskUpdateRequestJson.length > maxRequestSize && splitBatchSize < oldSplitBatchSize) { + log.debug("%s - current taskUpdateRequestJson exceeded limit: %d, abandon.", taskId, taskUpdateRequestJson.length); + scheduleUpdate(); + return; + } + } + if (fragment.isPresent()) { stats.updateWithPlanBytes(taskUpdateRequestJson.length); } @@ -646,21 +706,36 @@ private synchronized void sendUpdate() private synchronized List getSources() { - return Stream.concat(planFragment.getPartitionedSourceNodes().stream(), planFragment.getRemoteSourceNodes().stream()) - .filter(Objects::nonNull) - .map(PlanNode::getId) - .map(this::getSource) - .filter(Objects::nonNull) - .collect(toImmutableList()); + return Stream.concat( + planFragment.getPartitionedSourceNodes().stream() + .filter(Objects::nonNull) + .map(PlanNode::getId) + .map(planNodeId -> getSource(planNodeId, true)), + planFragment.getRemoteSourceNodes().stream() + .filter(Objects::nonNull) + .map(PlanNode::getId) + .map(planNodeId -> getSource(planNodeId, false)) + ).filter(Objects::nonNull).collect(toImmutableList()); } - private synchronized TaskSource getSource(PlanNodeId planNodeId) + private synchronized TaskSource getSource(PlanNodeId planNodeId, boolean isPartitionedSource) { Set splits = pendingSplits.get(planNodeId); boolean pendingNoMoreSplits = Boolean.TRUE.equals(this.noMoreSplits.get(planNodeId)); boolean noMoreSplits = this.noMoreSplits.containsKey(planNodeId); Set noMoreSplitsForLifespan = pendingNoMoreSplitsForLifespan.get(planNodeId); + // only apply batchSize to partitioned sources + if (isPartitionedSource && splitBatchSize < splits.size()) { + splits = splits.stream() + .sorted(Comparator.comparingLong(ScheduledSplit::getSequenceId)) + .limit(splitBatchSize) + .collect(Collectors.toSet()); + // if not last batch, we need to defer 
setting no more splits + noMoreSplits = false; + noMoreSplitsForLifespan = ImmutableSet.of(); + } + TaskSource element = null; if (!splits.isEmpty() || !noMoreSplitsForLifespan.isEmpty() || pendingNoMoreSplits) { element = new TaskSource(planNodeId, splits, noMoreSplitsForLifespan, noMoreSplits); diff --git a/core/trino-main/src/test/java/io/trino/execution/TestQueryManagerConfig.java b/core/trino-main/src/test/java/io/trino/execution/TestQueryManagerConfig.java index 15a26cee4484..36cac2f0af28 100644 --- a/core/trino-main/src/test/java/io/trino/execution/TestQueryManagerConfig.java +++ b/core/trino-main/src/test/java/io/trino/execution/TestQueryManagerConfig.java @@ -57,7 +57,11 @@ public void testDefaults() .setQueryMaxCpuTime(new Duration(1_000_000_000, DAYS)) .setQueryMaxScanPhysicalBytes(null) .setRequiredWorkers(1) - .setRequiredWorkersMaxWait(new Duration(5, MINUTES))); + .setRequiredWorkersMaxWait(new Duration(5, MINUTES)) + .setEnabledAdaptiveTaskRequestSize(false) + .setMaxRemoteTaskRequestSize(DataSize.of(8, DataSize.Unit.MEGABYTE)) + .setRemoteTaskRequestSizeHeadroom(DataSize.of(2, DataSize.Unit.MEGABYTE)) + .setRemoteTaskGuaranteedSplitPerTask(3)); } @Test @@ -87,6 +91,10 @@ public void testExplicitPropertyMappings() .put("query.max-scan-physical-bytes", "1kB") .put("query-manager.required-workers", "333") .put("query-manager.required-workers-max-wait", "33m") + .put("query.remote-task.enable-adaptive-request-size", "true") + .put("query.remote-task.max-request-size", "10MB") + .put("query.remote-task.request-size-headroom", "1MB") + .put("query.remote-task.guaranteed-splits-per-task", "5") .build(); QueryManagerConfig expected = new QueryManagerConfig() @@ -112,7 +120,11 @@ public void testExplicitPropertyMappings() .setQueryMaxCpuTime(new Duration(2, DAYS)) .setQueryMaxScanPhysicalBytes(DataSize.of(1, KILOBYTE)) .setRequiredWorkers(333) - .setRequiredWorkersMaxWait(new Duration(33, MINUTES)); + .setRequiredWorkersMaxWait(new Duration(33, MINUTES)) + 
.setEnabledAdaptiveTaskRequestSize(true) + .setMaxRemoteTaskRequestSize(DataSize.of(10, DataSize.Unit.MEGABYTE)) + .setRemoteTaskRequestSizeHeadroom(DataSize.of(1, DataSize.Unit.MEGABYTE)) + .setRemoteTaskGuaranteedSplitPerTask(5); assertFullMapping(properties, expected); } diff --git a/core/trino-main/src/test/java/io/trino/server/remotetask/TestHttpRemoteTask.java b/core/trino-main/src/test/java/io/trino/server/remotetask/TestHttpRemoteTask.java index a259b8bec336..53555625be74 100644 --- a/core/trino-main/src/test/java/io/trino/server/remotetask/TestHttpRemoteTask.java +++ b/core/trino-main/src/test/java/io/trino/server/remotetask/TestHttpRemoteTask.java @@ -13,10 +13,12 @@ */ package io.trino.server.remotetask; +import com.google.common.collect.HashMultimap; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableMultimap; import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Multimap; import com.google.inject.Binder; import com.google.inject.Injector; import com.google.inject.Module; @@ -28,6 +30,7 @@ import io.airlift.json.JsonCodec; import io.airlift.json.JsonModule; import io.airlift.units.Duration; +import io.trino.Session; import io.trino.block.BlockJsonSerde; import io.trino.client.NodeVersion; import io.trino.connector.CatalogName; @@ -112,15 +115,21 @@ import static io.airlift.json.JsonCodecBinder.jsonCodecBinder; import static io.airlift.testing.Assertions.assertGreaterThanOrEqual; import static io.trino.SessionTestUtils.TEST_SESSION; +import static io.trino.SystemSessionProperties.ENABLE_ADAPTIVE_REMOTE_TASK_REQUEST_SIZE; +import static io.trino.SystemSessionProperties.MAX_REMOTE_TASK_REQUEST_SIZE; +import static io.trino.SystemSessionProperties.REMOTE_TASK_GUARANTEED_SPLITS_PER_REQUEST; +import static io.trino.SystemSessionProperties.REMOTE_TASK_REQUEST_SIZE_HEADROOM; import static 
io.trino.execution.DynamicFiltersCollector.INITIAL_DYNAMIC_FILTERS_VERSION; import static io.trino.execution.TaskTestUtils.TABLE_SCAN_NODE_ID; import static io.trino.execution.buffer.OutputBuffers.createInitialEmptyOutputBuffers; import static io.trino.metadata.MetadataManager.createTestMetadataManager; +import static io.trino.plugin.tpch.TpchMetadata.TINY_SCHEMA_NAME; import static io.trino.server.InternalHeaders.TRINO_CURRENT_VERSION; import static io.trino.server.InternalHeaders.TRINO_MAX_WAIT; import static io.trino.spi.StandardErrorCode.REMOTE_TASK_ERROR; import static io.trino.spi.StandardErrorCode.REMOTE_TASK_MISMATCH; import static io.trino.spi.type.BigintType.BIGINT; +import static io.trino.testing.TestingSession.testSessionBuilder; import static io.trino.testing.assertions.Assert.assertEquals; import static io.trino.testing.assertions.Assert.assertEventually; import static java.lang.Math.min; @@ -365,6 +374,52 @@ public void testOutboundDynamicFilters() dynamicFilterService.stop(); } + @Test(timeOut = 300000) + public void testAdaptiveRemoteTaskRequestSize() + throws Exception + { + AtomicLong lastActivityNanos = new AtomicLong(System.nanoTime()); + TestingTaskResource testingTaskResource = new TestingTaskResource(lastActivityNanos, FailureScenario.NO_FAILURE); + + Session session = testSessionBuilder() + .setCatalog("tpch") + .setSchema(TINY_SCHEMA_NAME) + .setSystemProperty(ENABLE_ADAPTIVE_REMOTE_TASK_REQUEST_SIZE, "true") + .setSystemProperty(MAX_REMOTE_TASK_REQUEST_SIZE, "8kB") + .setSystemProperty(REMOTE_TASK_REQUEST_SIZE_HEADROOM, "2kB") + .setSystemProperty(REMOTE_TASK_GUARANTEED_SPLITS_PER_REQUEST, "2") + .build(); + HttpRemoteTaskFactory httpRemoteTaskFactory = createHttpRemoteTaskFactory(testingTaskResource); + + RemoteTask remoteTask = createRemoteTask(httpRemoteTaskFactory, ImmutableSet.of(), session); + + testingTaskResource.setInitialTaskInfo(remoteTask.getTaskInfo()); + remoteTask.start(); + + Lifespan lifespan = Lifespan.driverGroup(3); + 
+ Multimap splits = HashMultimap.create(); + for (int i = 0; i < 1000; i++) { + splits.put(TABLE_SCAN_NODE_ID, new Split(new CatalogName("test"), TestingSplit.createLocalSplit(), lifespan)); + } + remoteTask.addSplits(splits); + + poll(() -> testingTaskResource.getTaskSource(TABLE_SCAN_NODE_ID) != null); + poll(() -> testingTaskResource.getTaskSource(TABLE_SCAN_NODE_ID).getSplits().size() == 1000); + + remoteTask.noMoreSplits(TABLE_SCAN_NODE_ID, lifespan); + poll(() -> testingTaskResource.getTaskSource(TABLE_SCAN_NODE_ID).getNoMoreSplitsForLifespan().size() == 1); + + remoteTask.noMoreSplits(TABLE_SCAN_NODE_ID); + poll(() -> testingTaskResource.getTaskSource(TABLE_SCAN_NODE_ID).isNoMoreSplits()); + + remoteTask.cancel(); + poll(() -> remoteTask.getTaskStatus().getState().isDone()); + poll(() -> remoteTask.getTaskInfo().getTaskStatus().getState().isDone()); + + httpRemoteTaskFactory.stop(); + } + private void runTest(FailureScenario failureScenario) throws Exception { @@ -409,9 +464,14 @@ private void addSplit(RemoteTask remoteTask, TestingTaskResource testingTaskReso } private RemoteTask createRemoteTask(HttpRemoteTaskFactory httpRemoteTaskFactory, Set outboundDynamicFilterIds) + { + return createRemoteTask(httpRemoteTaskFactory, outboundDynamicFilterIds, TEST_SESSION); + } + + private RemoteTask createRemoteTask(HttpRemoteTaskFactory httpRemoteTaskFactory, Set outboundDynamicFilterIds, Session session) { return httpRemoteTaskFactory.createRemoteTask( - TEST_SESSION, + session, new TaskId("test", 1, 2), new InternalNode("node-id", URI.create("http://fake.invalid/"), new NodeVersion("version"), false), TaskTestUtils.PLAN_FRAGMENT,